/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Diff of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3516 by caltinay, Wed May 18 06:22:46 2011 UTC revision 3517 by caltinay, Fri May 20 01:16:41 2011 UTC
# Line 142  class Symbol(object): Line 142  class Symbol(object):
142    
143      def lambdarepr(self):      def lambdarepr(self):
144          from sympy.printing.lambdarepr import lambdarepr          from sympy.printing.lambdarepr import lambdarepr
         if self.getRank()==0:  
             return lambdarepr(self._arr.item())  
145          temp_arr=numpy.empty(self.getShape(), dtype=object)          temp_arr=numpy.empty(self.getShape(), dtype=object)
146          for idx,el in numpy.ndenumerate(self._arr):          for idx,el in numpy.ndenumerate(self._arr):
147              atoms=el.atoms(sympy.Symbol)              atoms=el.atoms(sympy.Symbol)
# Line 161  class Symbol(object): Line 159  class Symbol(object):
159              for key in symdict:              for key in symdict:
160                  s=s.replace(key, symdict[key])                  s=s.replace(key, symdict[key])
161              temp_arr[idx]=s              temp_arr[idx]=s
162          return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))          if self.getRank()==0:
163                return temp_arr.item()
164            else:
165                return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))
166    
167      def diff(self, *symbols, **assumptions):      def diff(self, *symbols, **assumptions):
168          symbols=Symbol._symbolgen(*symbols)          symbols=Symbol._symbolgen(*symbols)
# Line 202  class Symbol(object): Line 203  class Symbol(object):
203                      out[index]=grad_n(out[index],d,where)                      out[index]=grad_n(out[index],d,where)
204          return Symbol(out)          return Symbol(out)
205    
206        def inverse(self):
207            if not self.getRank()==2:
208                raise ValueError("inverse: Only rank 2 supported")
209            s=self.getShape()
210            if not s[0] == s[1]:
211                raise ValueError("inverse: Only square shapes supported")
212            out=numpy.zeros(s, numpy.object)
213            arr=self._arr
214            if s[0]==1:
215                if arr[0,0].is_zero:
216                    raise ZeroDivisionError("inverse: Symbol not invertible")
217                out[0,0]=1./arr[0,0]
218            elif s[0]==2:
219                A11=arr[0,0]
220                A12=arr[0,1]
221                A21=arr[1,0]
222                A22=arr[1,1]
223                D = A11*A22-A12*A21
224                if D.is_zero:
225                    raise ZeroDivisionError("inverse: Symbol not invertible")
226                D=1./D
227                out[0,0]= A22*D
228                out[1,0]=-A21*D
229                out[0,1]=-A12*D
230                out[1,1]= A11*D
231            elif s[0]==3:
232                A11=arr[0,0]
233                A21=arr[1,0]
234                A31=arr[2,0]
235                A12=arr[0,1]
236                A22=arr[1,1]
237                A32=arr[2,1]
238                A13=arr[0,2]
239                A23=arr[1,2]
240                A33=arr[2,2]
241                D = A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
242                if D.is_zero:
243                    raise ZeroDivisionError("inverse: Symbol not invertible")
244                D=1./D
245                out[0,0]=(A22*A33-A23*A32)*D
246                out[1,0]=(A31*A23-A21*A33)*D
247                out[2,0]=(A21*A32-A31*A22)*D
248                out[0,1]=(A13*A32-A12*A33)*D
249                out[1,1]=(A11*A33-A31*A13)*D
250                out[2,1]=(A12*A31-A11*A32)*D
251                out[0,2]=(A12*A23-A13*A22)*D
252                out[1,2]=(A13*A21-A11*A23)*D
253                out[2,2]=(A11*A22-A12*A21)*D
254            else:
255               raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
256            return Symbol(out)
257    
258      def swap_axes(self, axis0, axis1):      def swap_axes(self, axis0, axis1):
259          return Symbol(numpy.swapaxes(self._arr, axis0, axis1))          return Symbol(numpy.swapaxes(self._arr, axis0, axis1))
260    
# Line 308  class Symbol(object): Line 361  class Symbol(object):
361      def _sympy_(self):      def _sympy_(self):
362          return self.applyfunc(sympy.sympify)          return self.applyfunc(sympy.sympify)
363    
364        def _ensureShapeCompatible(self, other):
365            """
366            Checks for compatible shapes for binary operations.
367            Raises TypeError if not compatible.
368            """
369            sh0=self.getShape()
370            if isinstance(other, Symbol):
371                sh1=other.getShape()
372            elif isinstance(other, numpy.ndarray):
373                sh1=other.shape
374            elif isinstance(other,int) or isinstance(other,float):
375                sh1=()
376            else:
377                raise TypeError("Unsupported argument type '%s' for binary operation"%other.__class__.__name__)
378            if not sh0==sh1 and not sh0==() and not sh1==():
379                raise TypeError("Incompatible shapes for binary operation")
380    
381      @staticmethod      @staticmethod
382      def _symComp(sym):      def _symComp(sym):
383          n=sym.name          n=sym.name
# Line 372  class Symbol(object): Line 442  class Symbol(object):
442          return Symbol(abs(self._arr))          return Symbol(abs(self._arr))
443    
444      def __add__(self, other):      def __add__(self, other):
445            self._ensureShapeCompatible(other)
446          if isinstance(other, Symbol):          if isinstance(other, Symbol):
447              return Symbol(self._arr+other._arr)              return Symbol(self._arr+other._arr)
448          return Symbol(self._arr+other)          return Symbol(self._arr+other)
449    
450      def __radd__(self, other):      def __radd__(self, other):
451            self._ensureShapeCompatible(other)
452          if isinstance(other, Symbol):          if isinstance(other, Symbol):
453              return Symbol(other._arr+self._arr)              return Symbol(other._arr+self._arr)
454          return Symbol(other+self._arr)          return Symbol(other+self._arr)
455    
456      def __sub__(self, other):      def __sub__(self, other):
457            self._ensureShapeCompatible(other)
458          if isinstance(other, Symbol):          if isinstance(other, Symbol):
459              return Symbol(self._arr-other._arr)              return Symbol(self._arr-other._arr)
460          return Symbol(self._arr-other)          return Symbol(self._arr-other)
461    
462      def __rsub__(self, other):      def __rsub__(self, other):
463            self._ensureShapeCompatible(other)
464          if isinstance(other, Symbol):          if isinstance(other, Symbol):
465              return Symbol(other._arr-self._arr)              return Symbol(other._arr-self._arr)
466          return Symbol(other-self._arr)          return Symbol(other-self._arr)
467    
468      def __mul__(self, other):      def __mul__(self, other):
469            self._ensureShapeCompatible(other)
470          if isinstance(other, Symbol):          if isinstance(other, Symbol):
471              return Symbol(self._arr*other._arr)              return Symbol(self._arr*other._arr)
472          return Symbol(self._arr*other)          return Symbol(self._arr*other)
473    
474      def __rmul__(self, other):      def __rmul__(self, other):
475            self._ensureShapeCompatible(other)
476          if isinstance(other, Symbol):          if isinstance(other, Symbol):
477              return Symbol(other._arr*self._arr)              return Symbol(other._arr*self._arr)
478          return Symbol(other*self._arr)          return Symbol(other*self._arr)
479    
480      def __div__(self, other):      def __div__(self, other):
481            self._ensureShapeCompatible(other)
482          if isinstance(other, Symbol):          if isinstance(other, Symbol):
483              return Symbol(self._arr/other._arr)              return Symbol(self._arr/other._arr)
484          return Symbol(self._arr/other)          return Symbol(self._arr/other)
485    
486      def __rdiv__(self, other):      def __rdiv__(self, other):
487            self._ensureShapeCompatible(other)
488          if isinstance(other, Symbol):          if isinstance(other, Symbol):
489              return Symbol(other._arr/self._arr)              return Symbol(other._arr/self._arr)
490          return Symbol(other/self._arr)          return Symbol(other/self._arr)
491    
492      def __pow__(self, other):      def __pow__(self, other):
493            self._ensureShapeCompatible(other)
494          if isinstance(other, Symbol):          if isinstance(other, Symbol):
495              return Symbol(self._arr**other._arr)              return Symbol(self._arr**other._arr)
496          return Symbol(self._arr**other)          return Symbol(self._arr**other)
497    
498      def __rpow__(self, other):      def __rpow__(self, other):
499            self._ensureShapeCompatible(other)
500          if isinstance(other, Symbol):          if isinstance(other, Symbol):
501              return Symbol(other._arr**self._arr)              return Symbol(other._arr**self._arr)
502          return Symbol(other**self._arr)          return Symbol(other**self._arr)
# Line 484  def combineData(array, shape): Line 564  def combineData(array, shape):
564              d[idx]=n[idx]              d[idx]=n[idx]
565      return d      return d
566    
567    
568    class SymFunction(Symbol):
569        """
570        """
571        def __init__(self, *args, **kwargs):
572            """
573            Initializes a new symbolic function object.
574            """
575            super(SymFunction, self).__init__(self.__class__.__name__, **kwargs)
576            self.args=args
577    
578        def __repr__(self):
579            return self.name+"("+", ".join([str(a) for a in self.args])+")"
580    
581        def __str__(self):
582            return self.name+"("+", ".join([str(a) for a in self.args])+")"
583    
584        def lambdarepr(self):
585            return self.name+"("+", ".join([a.lambdarepr() for a in self.args])+")"
586    
587        def atoms(self, *types):
588            s=set()
589            for el in self.args:
590                atoms=el.atoms(*types)
591                for a in atoms:
592                    if a.is_Symbol:
593                        n,c=Symbol._symComp(a)
594                        s.add(sympy.Symbol(n))
595                    else:
596                        s.add(a)
597            return s
598    
599        def __neg__(self):
600            res=self.__class__(*self.args)
601            res._arr=-res._arr
602            return res
603    

Legend:
Removed from v.3516  
changed lines
  Added in v.3517

  ViewVC Help
Powered by ViewVC 1.1.26