/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Diff of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3862 by caltinay, Fri Mar 9 06:32:35 2012 UTC revision 3864 by caltinay, Mon Mar 12 05:18:16 2012 UTC
# Line 18  Primary Business: Queensland, Australia" Line 18  Primary Business: Queensland, Australia"
18  __license__="""Licensed under the Open Software License version 3.0  __license__="""Licensed under the Open Software License version 3.0
19  http://www.opensource.org/licenses/osl-3.0.php"""  http://www.opensource.org/licenses/osl-3.0.php"""
20  __url__="https://launchpad.net/escript-finley"  __url__="https://launchpad.net/escript-finley"
21    __author__="Cihan Altinay"
22    
23  """  """
24  :var __author__: name of author  :var __author__: name of author
# Line 30  __url__="https://launchpad.net/escript-f Line 31  __url__="https://launchpad.net/escript-f
31    
32  import numpy  import numpy
33  import sympy  import sympy
34    from esys.escript import Data, FunctionSpace
 __author__="Cihan Altinay"  
   
35    
36        
37  class Symbol(object):  class Symbol(object):
# Line 84  class Symbol(object): Line 83  class Symbol(object):
83          :type dim: ``int``          :type dim: ``int``
84          """          """
85          if 'dim' in kwargs:          if 'dim' in kwargs:
86              self.dim=kwargs.pop('dim')              self._dim=kwargs.pop('dim')
87          else:          else:
88              self.dim=2              self._dim=2
89    
90            if 'subs' in kwargs:
91                self._subs=kwargs.pop('subs')
92            else:
93                self._subs={}
94    
95          if len(args)==1:          if len(args)==1:
96              arg=args[0]              arg=args[0]
# Line 106  class Symbol(object): Line 110  class Symbol(object):
110                      else:                      else:
111                          res[idx]=arr[idx]                          res[idx]=arr[idx]
112                  self._arr=res                  self._arr=res
113                    if isinstance(arg, Symbol):
114                        self._subs.update(arg._subs)
115                        self._dim=arg._dim
116              elif isinstance(arg, sympy.Basic):              elif isinstance(arg, sympy.Basic):
117                  self._arr=numpy.array(arg)                  self._arr=numpy.array(arg)
118              else:              else:
# Line 189  class Symbol(object): Line 196  class Symbol(object):
196      def __iter__(self):      def __iter__(self):
197          return self._arr.__iter__          return self._arr.__iter__
198    
199        def __array__(self):
200            return self._arr
201    
202        def _sympy_(self):
203            """
204            """
205            return self.applyfunc(sympy.sympify)
206    
207      def getDim(self):      def getDim(self):
208          """          """
209          Returns the spatial dimensionality of this symbol.          Returns the spatial dimensionality of this symbol.
# Line 196  class Symbol(object): Line 211  class Symbol(object):
211          :return: the symbol's spatial dimensionality          :return: the symbol's spatial dimensionality
212          :rtype: ``int``          :rtype: ``int``
213          """          """
214          return self.dim          return self._dim
215    
216      def getRank(self):      def getRank(self):
217          """          """
# Line 216  class Symbol(object): Line 231  class Symbol(object):
231          """          """
232          return self._arr.shape          return self._arr.shape
233    
234        def getDataSubstitutions(self):
235            """
236            Returns a dictionary of symbol names and the escript ``Data`` objects
237            they represent within this Symbol.
238    
239            :return: the dictionary of substituted ``Data`` objects
240            :rtype: ``dict``
241            """
242            return self._subs
243    
244      def item(self, *args):      def item(self, *args):
245          """          """
246          Returns an element of this symbol.          Returns an element of this symbol.
# Line 340  class Symbol(object): Line 365  class Symbol(object):
365              none_to_zero=lambda item: 0 if item is None else item              none_to_zero=lambda item: 0 if item is None else item
366              result=self.applyfunc(coeff_item)              result=self.applyfunc(coeff_item)
367              result=result.applyfunc(none_to_zero)              result=result.applyfunc(none_to_zero)
368          return Symbol(result, dim=self.dim)          res=Symbol(result, dim=self._dim)
369            for i in self._subs: res.subs(i, self._subs[i])
370            return res
371    
372      def subs(self, old, new):      def subs(self, old, new):
373          """          """
374          Substitutes an expression.          Substitutes an expression.
375          """          """
376          if isinstance(old, Symbol) and old.getRank()>0:          old._ensureShapeCompatible(new)
377              old._ensureShapeCompatible(new)          if isinstance(new, Data):
378                subs=self._subs.copy()
379                if isinstance(old, Symbol) and old.getRank()>0:
380                    old=Symbol(old.atoms(sympy.Symbol)[0])
381                subs[old]=new
382                result=Symbol(self._arr, dim=self._dim, subs=subs)
383            elif isinstance(old, Symbol) and old.getRank()>0:
384              if hasattr(new, '__array__'):              if hasattr(new, '__array__'):
385                  new=new.__array__()                  new=new.__array__()
386              else:              else:
387                  new=numpy.array(new)                  new=numpy.array(new)
388    
389              result=Symbol(self._arr, dim=self.dim)              result=numpy.empty(self.getShape(), dtype=object)
390              if new.ndim>0:              if new.ndim>0:
391                  for idx in numpy.ndindex(self.getShape()):                  for idx in numpy.ndindex(self.getShape()):
392                      for symidx in numpy.ndindex(new.shape):                      for symidx in numpy.ndindex(new.shape):
393                          result[idx]=result[idx].subs(old[symidx], new[symidx])                          result[idx]=self._arr[idx].subs(old._arr[symidx], new[symidx])
394              else: # substitute scalar for non-scalar              else: # substitute scalar for non-scalar
395                  for idx in numpy.ndindex(self.getShape()):                  for idx in numpy.ndindex(self.getShape()):
396                      for symidx in numpy.ndindex(old.getShape()):                      for symidx in numpy.ndindex(old.getShape()):
397                          result[idx]=result[idx].subs(old[symidx], new.item())                          result[idx]=self._arr[idx].subs(old._arr[symidx], new.item())
398                result=Symbol(result, dim=self._dim, subs=self._subs)
399          else: # scalar          else: # scalar
400              if isinstance(new, Symbol):              if isinstance(new, Symbol):
                 if new.getRank()>0:  
                     raise TypeError("Cannot substitute, incompatible ranks.")  
401                  new=new.item()                  new=new.item()
402              if isinstance(old, Symbol):              if isinstance(old, Symbol):
403                  old=old.item()                  old=old.item()
# Line 377  class Symbol(object): Line 409  class Symbol(object):
409          """          """
410          """          """
411          symbols=Symbol._symbolgen(*symbols)          symbols=Symbol._symbolgen(*symbols)
412          result=Symbol(self._arr, dim=self.dim)          result=Symbol(self._arr, dim=self._dim, subs=self._subs)
413          for s in symbols:          for s in symbols:
414              if isinstance(s, Symbol):              if isinstance(s, Symbol):
415                  if s.getRank()==0:                  if s.getRank()==0:
# Line 390  class Symbol(object): Line 422  class Symbol(object):
422                          for idx in numpy.ndindex(self.getShape()):                          for idx in numpy.ndindex(self.getShape()):
423                              index=idx+(d,)                              index=idx+(d,)
424                              out[index]=out[index].diff(s[d].item(), **assumptions)                              out[index]=out[index].diff(s[d].item(), **assumptions)
425                      result=Symbol(out, dim=self.dim)                      result=Symbol(out, dim=self._dim, subs=self._subs)
426                  else:                  else:
427                      raise ValueError("diff: Only rank 0 and 1 supported")                      raise ValueError("diff: argument must have rank 0 or 1")
428              else:              else:
429                  diff_item=lambda item: getattr(item, 'diff')(s, **assumptions)                  diff_item=lambda item: getattr(item, 'diff')(s, **assumptions)
430                  result=result.applyfunc(diff_item)                  result=result.applyfunc(diff_item)
# Line 400  class Symbol(object): Line 432  class Symbol(object):
432    
433      def grad(self, where=None):      def grad(self, where=None):
434          """          """
435            :type where: ``Symbol``, ``FunctionSpace``
436          """          """
437            subs=self._subs
438          if isinstance(where, Symbol):          if isinstance(where, Symbol):
439              if where.getRank()>0:              if where.getRank()>0:
440                  raise ValueError("grad: 'where' must be a scalar symbol")                  raise ValueError("grad: 'where' must be a scalar symbol")
441              where=where._arr.item()              where=where._arr.item()
442            elif isinstance(where, FunctionSpace):
443                name='fs'+str(id(where))
444                fssym=Symbol(name)
445                subs=self._subs.copy()
446                subs.update({fssym:where})
447                where=name
448    
449          from functions import grad_n          from functions import grad_n
450          out=self._arr.copy().reshape(self.getShape()+(1,)).repeat(self.dim,axis=self.getRank())          out=self._arr.copy().reshape(self.getShape()+(1,)).repeat(self._dim,axis=self.getRank())
451          for d in range(self.dim):          for d in range(self._dim):
452              for idx in numpy.ndindex(self.getShape()):              for idx in numpy.ndindex(self.getShape()):
453                  index=idx+(d,)                  index=idx+(d,)
454                  if where is None:                  if where is None:
455                      out[index]=grad_n(out[index],d)                      out[index]=grad_n(out[index],d)
456                  else:                  else:
457                      out[index]=grad_n(out[index],d,where)                      out[index]=grad_n(out[index],d,where)
458          return Symbol(out, dim=self.dim)          return Symbol(out, dim=self._dim, subs=subs)
459    
460      def inverse(self):      def inverse(self):
461          """          """
# Line 469  class Symbol(object): Line 509  class Symbol(object):
509              out[2,2]=(A11*A22-A12*A21)*D              out[2,2]=(A11*A22-A12*A21)*D
510          else:          else:
511             raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")             raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
512          return Symbol(out, dim=self.dim)          return Symbol(out, dim=self._dim, subs=self._subs)
513    
514      def swap_axes(self, axis0, axis1):      def swap_axes(self, axis0, axis1):
515          """          """
516          """          """
517          return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self.dim)          return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self._dim, subs=self._subs)
518    
519      def tensorProduct(self, other, axis_offset):      def tensorProduct(self, other, axis_offset):
520          """          """
# Line 498  class Symbol(object): Line 538  class Symbol(object):
538              for i1 in range(d1):              for i1 in range(d1):
539                  out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[:,i1])                  out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[:,i1])
540          out.resize(sh0[:self._arr.ndim-axis_offset]+sh1[axis_offset:])          out.resize(sh0[:self._arr.ndim-axis_offset]+sh1[axis_offset:])
541          return Symbol(out, dim=self.dim)          subs=self._subs.copy()
542            subs.update(other._subs)
543            return Symbol(out, dim=self._dim, subs=subs)
544    
545      def transposedTensorProduct(self, other, axis_offset):      def transposedTensorProduct(self, other, axis_offset):
546          """          """
# Line 522  class Symbol(object): Line 564  class Symbol(object):
564              for i1 in range(d1):              for i1 in range(d1):
565                  out[i0,i1]=numpy.sum(arg0_c[:,i0]*arg1_c[:,i1])                  out[i0,i1]=numpy.sum(arg0_c[:,i0]*arg1_c[:,i1])
566          out.resize(sh0[axis_offset:]+sh1[axis_offset:])          out.resize(sh0[axis_offset:]+sh1[axis_offset:])
567          return Symbol(out, dim=self.dim)          subs=self._subs.copy()
568            subs.update(other._subs)
569            return Symbol(out, dim=self._dim, subs=subs)
570    
571      def tensorTransposedProduct(self, other, axis_offset):      def tensorTransposedProduct(self, other, axis_offset):
572          """          """
# Line 548  class Symbol(object): Line 592  class Symbol(object):
592              for i1 in range(d1):              for i1 in range(d1):
593                  out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[i1,:])                  out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[i1,:])
594          out.resize(sh0[:self._arr.ndim-axis_offset]+sh1[:r1-axis_offset])          out.resize(sh0[:self._arr.ndim-axis_offset]+sh1[:r1-axis_offset])
595          return Symbol(out, dim=self.dim)          subs=self._subs.copy()
596            subs.update(other._subs)
597            return Symbol(out, dim=self._dim, subs=subs)
598    
599      def trace(self, axis_offset):      def trace(self, axis_offset):
600          """          """
# Line 565  class Symbol(object): Line 611  class Symbol(object):
611                  for j in range(sh[axis_offset]):                  for j in range(sh[axis_offset]):
612                      out[i1,i2]+=arr_r[i1,j,j,i2]                      out[i1,i2]+=arr_r[i1,j,j,i2]
613          out.resize(sh[:axis_offset]+sh[axis_offset+2:])          out.resize(sh[:axis_offset]+sh[axis_offset+2:])
614          return Symbol(out, dim=self.dim)          return Symbol(out, dim=self._dim, subs=self._subs)
615    
616      def transpose(self, axis_offset):      def transpose(self, axis_offset):
617          """          """
# Line 573  class Symbol(object): Line 619  class Symbol(object):
619          if axis_offset is None:          if axis_offset is None:
620              axis_offset=int(self._arr.ndim/2)              axis_offset=int(self._arr.ndim/2)
621          axes=range(axis_offset, self._arr.ndim)+range(0,axis_offset)          axes=range(axis_offset, self._arr.ndim)+range(0,axis_offset)
622          return Symbol(numpy.transpose(self._arr, axes=axes), dim=self.dim)          return Symbol(numpy.transpose(self._arr, axes=axes), dim=self._dim, subs=self._subs)
623    
624      def applyfunc(self, f, on_type=None):      def applyfunc(self, f, on_type=None):
625          """          """
626          """          """
627          assert callable(f)          assert callable(f)
628          if self._arr.ndim==0:          if self._arr.ndim==0:
629              if on_type is None or isinstance(self._arr.item(),on_type):              if on_type is None or isinstance(self._arr.item(), on_type):
630                  el=f(self._arr.item())                  el=f(self._arr.item())
631              else:              else:
632                  el=self._arr.item()                  el=self._arr.item()
633              if el is not None:              if el is not None:
634                  out=Symbol(el, dim=self.dim)                  out=Symbol(el, dim=self._dim, subs=self._subs)
635              else:              else:
636                  return el                  return el
637          else:          else:
# Line 595  class Symbol(object): Line 641  class Symbol(object):
641                      out[idx]=f(self._arr[idx])                      out[idx]=f(self._arr[idx])
642                  else:                  else:
643                      out[idx]=self._arr[idx]                      out[idx]=self._arr[idx]
644              out=Symbol(out, dim=self.dim)              out=Symbol(out, dim=self._dim, subs=self._subs)
645          return out          return out
646    
647      def expand(self):      def expand(self):
# Line 608  class Symbol(object): Line 654  class Symbol(object):
654          """          """
655          return self.applyfunc(sympy.simplify, sympy.Basic)          return self.applyfunc(sympy.simplify, sympy.Basic)
656    
657      def _sympy_(self):      # unary/binary operations follow
658          """  
659          """      def __pos__(self):
660          return self.applyfunc(sympy.sympify)          return self
661    
662        def __neg__(self):
663            return Symbol(-self._arr, dim=self._dim, subs=self._subs)
664    
665        def __abs__(self):
666            return Symbol(abs(self._arr), dim=self._dim, subs=self._subs)
667    
668      def _ensureShapeCompatible(self, other):      def _ensureShapeCompatible(self, other):
669          """          """
# Line 619  class Symbol(object): Line 671  class Symbol(object):
671          Raises TypeError if not compatible.          Raises TypeError if not compatible.
672          """          """
673          sh0=self.getShape()          sh0=self.getShape()
674          if isinstance(other, Symbol): # or isinstance(other, Data):          if isinstance(other, Symbol) or isinstance(other, Data):
675              sh1=other.getShape()              sh1=other.getShape()
676          elif isinstance(other, numpy.ndarray):          elif isinstance(other, numpy.ndarray):
677              sh1=other.shape              sh1=other.shape
# Line 632  class Symbol(object): Line 684  class Symbol(object):
684          if not sh0==sh1 and not sh0==() and not sh1==():          if not sh0==sh1 and not sh0==() and not sh1==():
685              raise TypeError("Incompatible shapes for operation")              raise TypeError("Incompatible shapes for operation")
686    
687        def __binaryop(self, op, other):
688            self._ensureShapeCompatible(other)
689            if isinstance(other, Symbol):
690                subs=self._subs.copy()
691                subs.update(other._subs)
692                return Symbol(getattr(self._arr, op)(other._arr), dim=self._dim, subs=subs)
693            if isinstance(other, Data):
694                name='data'+str(id(other))
695                othersym=Symbol(name, other.getShape(), dim=self._dim)
696                subs=self._subs.copy()
697                subs.update({Symbol(name):other})
698                return Symbol(getattr(self._arr, op)(othersym._arr), dim=self._dim, subs=subs)
699            return Symbol(getattr(self._arr, op)(other), dim=self._dim, subs=self._subs)
700    
701        def __add__(self, other):
702            return self.__binaryop('__add__', other)
703    
704        def __radd__(self, other):
705            return self.__binaryop('__radd__', other)
706    
707        def __sub__(self, other):
708            return self.__binaryop('__sub__', other)
709    
710        def __rsub__(self, other):
711            return self.__binaryop('__rsub__', other)
712    
713        def __mul__(self, other):
714            return self.__binaryop('__mul__', other)
715    
716        def __rmul__(self, other):
717            return self.__binaryop('__rmul__', other)
718    
719        def __div__(self, other):
720            return self.__binaryop('__div__', other)
721    
722        def __rdiv__(self, other):
723            return self.__binaryop('__rdiv__', other)
724    
725        def __pow__(self, other):
726            return self.__binaryop('__pow__', other)
727    
728        def __rpow__(self, other):
729            return self.__binaryop('__rpow__', other)
730    
731      @staticmethod      @staticmethod
732      def _symComp(sym):      def _symComp(sym):
733          """          """
# Line 686  class Symbol(object): Line 782  class Symbol(object):
782              else:              else:
783                  yield s                  yield s
784    
     def __array__(self):  
         return self._arr  
   
     # unary/binary operations follow  
   
     def __pos__(self):  
         return self  
   
     def __neg__(self):  
         return Symbol(-self._arr, dim=self.dim)  
   
     def __abs__(self):  
         return Symbol(abs(self._arr), dim=self.dim)  
   
     def __add__(self, other):  
         self._ensureShapeCompatible(other)  
         if isinstance(other, Symbol):  
             return Symbol(self._arr+other._arr, dim=self.dim)  
         return Symbol(self._arr+other, dim=self.dim)  
   
     def __radd__(self, other):  
         self._ensureShapeCompatible(other)  
         if isinstance(other, Symbol):  
             return Symbol(other._arr+self._arr, dim=self.dim)  
         return Symbol(other+self._arr, dim=self.dim)  
   
     def __sub__(self, other):  
         self._ensureShapeCompatible(other)  
         if isinstance(other, Symbol):  
             return Symbol(self._arr-other._arr, dim=self.dim)  
         return Symbol(self._arr-other, dim=self.dim)  
   
     def __rsub__(self, other):  
         self._ensureShapeCompatible(other)  
         if isinstance(other, Symbol):  
             return Symbol(other._arr-self._arr, dim=self.dim)  
         return Symbol(other-self._arr, dim=self.dim)  
   
     def __mul__(self, other):  
         self._ensureShapeCompatible(other)  
         if isinstance(other, Symbol):  
             return Symbol(self._arr*other._arr, dim=self.dim)  
         return Symbol(self._arr*other, dim=self.dim)  
   
     def __rmul__(self, other):  
         self._ensureShapeCompatible(other)  
         if isinstance(other, Symbol):  
             return Symbol(other._arr*self._arr, dim=self.dim)  
         return Symbol(other*self._arr, dim=self.dim)  
   
     def __div__(self, other):  
         self._ensureShapeCompatible(other)  
         if isinstance(other, Symbol):  
             return Symbol(self._arr/other._arr, dim=self.dim)  
         return Symbol(self._arr/other, dim=self.dim)  
   
     def __rdiv__(self, other):  
         self._ensureShapeCompatible(other)  
         if isinstance(other, Symbol):  
             return Symbol(other._arr/self._arr, dim=self.dim)  
         return Symbol(other/self._arr, dim=self.dim)  
   
     def __pow__(self, other):  
         self._ensureShapeCompatible(other)  
         if isinstance(other, Symbol):  
             return Symbol(self._arr**other._arr, dim=self.dim)  
         return Symbol(self._arr**other, dim=self.dim)  
   
     def __rpow__(self, other):  
         self._ensureShapeCompatible(other)  
         if isinstance(other, Symbol):  
             return Symbol(other._arr**self._arr, dim=self.dim)  
         return Symbol(other**self._arr, dim=self.dim)  
   

Legend:
Removed from v.3862  
changed lines
  Added in v.3864

  ViewVC Help
Powered by ViewVC 1.1.26