/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Diff of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3518 by caltinay, Fri May 20 06:29:31 2011 UTC revision 3530 by caltinay, Wed Jun 15 04:48:53 2011 UTC
# Line 43  class Symbol(object): Line 43  class Symbol(object):
43          """          """
44          Initializes a new Symbol object.          Initializes a new Symbol object.
45          """          """
46          #from esys.escript import Data          if 'dim' in kwargs:
47                self.dim=kwargs.pop('dim')
48            else:
49                self.dim=2
50    
51          if len(args)==1:          if len(args)==1:
52              arg=args[0]              arg=args[0]
53              if isinstance(arg, str):              if isinstance(arg, str):
# Line 57  class Symbol(object): Line 61  class Symbol(object):
61                  self._arr=arr.copy()                  self._arr=arr.copy()
62              elif isinstance(arg, list) or isinstance(arg, sympy.Basic):              elif isinstance(arg, list) or isinstance(arg, sympy.Basic):
63                  self._arr=numpy.array(arg)                  self._arr=numpy.array(arg)
             #elif isinstance(arg, Data):  
             #    self._arr=arg  
64              else:              else:
65                  raise TypeError("Unsupported argument type %s"%str(type(arg)))                  raise TypeError("Unsupported argument type %s"%str(type(arg)))
66          elif len(args)==2:          elif len(args)==2:
# Line 128  class Symbol(object): Line 130  class Symbol(object):
130      def atoms(self, *types):      def atoms(self, *types):
131          s=set()          s=set()
132          for el in self._arr.flat:          for el in self._arr.flat:
133              atoms=el.atoms(*types)              if isinstance(el,sympy.Basic):
134              for a in atoms:                  atoms=el.atoms(*types)
135                  if a.is_Symbol:                  for a in atoms:
136                      n,c=Symbol._symComp(a)                      if a.is_Symbol:
137                      s.add(sympy.Symbol(n))                          n,c=Symbol._symComp(a)
138                  else:                          s.add(sympy.Symbol(n))
139                      s.add(a)                      else:
140                            s.add(a)
141                else:
142                    # TODO: Numbers?
143                    pass
144          return s          return s
145    
146      def _sympystr_(self, printer):      def _sympystr_(self, printer):
# Line 144  class Symbol(object): Line 150  class Symbol(object):
150          from sympy.printing.lambdarepr import lambdarepr          from sympy.printing.lambdarepr import lambdarepr
151          temp_arr=numpy.empty(self.getShape(), dtype=object)          temp_arr=numpy.empty(self.getShape(), dtype=object)
152          for idx,el in numpy.ndenumerate(self._arr):          for idx,el in numpy.ndenumerate(self._arr):
153              atoms=el.atoms(sympy.Symbol)              atoms=el.atoms(sympy.Symbol) if isinstance(el,sympy.Basic) else []
154              # create a dictionary to convert names like [x]_0_0 to x[0,0]              # create a dictionary to convert names like [x]_0_0 to x[0,0]
155              symdict={}              symdict={}
156              for a in atoms:              for a in atoms:
# Line 164  class Symbol(object): Line 170  class Symbol(object):
170          else:          else:
171              return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))              return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))
172    
173        def coeff(self, x, expand=True):
174            self._ensureShapeCompatible(x)
175            result=Symbol(self._arr, dim=self.dim)
176            if isinstance(x, Symbol):
177                if x.getRank()>0:
178                    a=result._arr.flat
179                    b=x._arr.flat
180                    for idx in range(len(a)):
181                        s=b.next()
182                        if s==0:
183                            a[idx]=0
184                        else:
185                            a[idx]=a[idx].coeff(s, expand)
186                else:
187                    if x._arr.item()==0:
188                        result=Symbol(numpy.zeros(self.getShape()), dim=self.dim)
189                    else:
190                        coeff_item=lambda item: getattr(item, 'coeff')(x._arr.item(), expand)
191                        result=result.applyfunc(coeff_item)
192            elif x==0:
193                result=Symbol(numpy.zeros(self.getShape()), dim=self.dim)
194            else:
195                coeff_item=lambda item: getattr(item, 'coeff')(x, expand)
196                result=result.applyfunc(coeff_item)
197    
198            # replace None by 0
199            if result is None: return 0
200            a=result._arr.flat
201            for idx in range(len(a)):
202                if a[idx] is None: a[idx]=0
203            return result
204    
205      def diff(self, *symbols, **assumptions):      def diff(self, *symbols, **assumptions):
206          symbols=Symbol._symbolgen(*symbols)          symbols=Symbol._symbolgen(*symbols)
207          result=Symbol(self._arr)          result=Symbol(self._arr, dim=self.dim)
208          for s in symbols:          for s in symbols:
209              if isinstance(s, Symbol):              if isinstance(s, Symbol):
210                  if s.getRank()>0:                  if s.getRank()==0:
                     if s.getShape()!=self.getShape():  
                         raise ValueError("diff: Incompatible shapes")  
                     a=result._arr.flat  
                     b=s._arr.flat  
                     for idx in range(len(a)):  
                         a[idx]=a[idx].diff(b.next())  
                 else:  
211                      diff_item=lambda item: getattr(item, 'diff')(s._arr.item(), **assumptions)                      diff_item=lambda item: getattr(item, 'diff')(s._arr.item(), **assumptions)
212                      result=result.applyfunc(diff_item)                      result=result.applyfunc(diff_item)
213                    elif s.getRank()==1:
214                        dim=s.getShape()[0]
215                        out=result._arr.copy().reshape(self.getShape()+(1,)).repeat(dim,axis=self.getRank())
216                        for d in range(dim):
217                            for idx in numpy.ndindex(self.getShape()):
218                                index=idx+(d,)
219                                out[index]=out[index].diff(s[d], **assumptions)
220                        result=Symbol(out, dim=self.dim)
221                    else:
222                        raise ValueError("diff: Only rank 0 and 1 supported")
223              else:              else:
224                  diff_item=lambda item: getattr(item, 'diff')(s, **assumptions)                  diff_item=lambda item: getattr(item, 'diff')(s, **assumptions)
225                  result=result.applyfunc(diff_item)                  result=result.applyfunc(diff_item)
# Line 192  class Symbol(object): Line 232  class Symbol(object):
232              where=where._arr.item()              where=where._arr.item()
233    
234          from functions import grad_n          from functions import grad_n
235          dim=2          out=self._arr.copy().reshape(self.getShape()+(1,)).repeat(self.dim,axis=self.getRank())
236          out=self._arr.copy().reshape(self.getShape()+(1,)).repeat(dim,axis=self.getRank())          for d in range(self.dim):
         for d in range(dim):  
237              for idx in numpy.ndindex(self.getShape()):              for idx in numpy.ndindex(self.getShape()):
238                  index=idx+(d,)                  index=idx+(d,)
239                  if where is None:                  if where is None:
240                      out[index]=grad_n(out[index],d)                      out[index]=grad_n(out[index],d)
241                  else:                  else:
242                      out[index]=grad_n(out[index],d,where)                      out[index]=grad_n(out[index],d,where)
243          return Symbol(out)          return Symbol(out, dim=self.dim)
244    
245      def inverse(self):      def inverse(self):
246          if not self.getRank()==2:          if not self.getRank()==2:
# Line 253  class Symbol(object): Line 292  class Symbol(object):
292              out[2,2]=(A11*A22-A12*A21)*D              out[2,2]=(A11*A22-A12*A21)*D
293          else:          else:
294             raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")             raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
295          return Symbol(out)          return Symbol(out, dim=self.dim)
296    
297      def swap_axes(self, axis0, axis1):      def swap_axes(self, axis0, axis1):
298          return Symbol(numpy.swapaxes(self._arr, axis0, axis1))          return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self.dim)
299    
300      def tensorProduct(self, other, axis_offset):      def tensorProduct(self, other, axis_offset):
301          arg0_c=self._arr.copy()          arg0_c=self._arr.copy()
# Line 278  class Symbol(object): Line 317  class Symbol(object):
317              for i1 in range(d1):              for i1 in range(d1):
318                  out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[:,i1])                  out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[:,i1])
319          out.resize(sh0[:self._arr.ndim-axis_offset]+sh1[axis_offset:])          out.resize(sh0[:self._arr.ndim-axis_offset]+sh1[axis_offset:])
320          return Symbol(out)          return Symbol(out, dim=self.dim)
321    
322      def transposedTensorProduct(self, other, axis_offset):      def transposedTensorProduct(self, other, axis_offset):
323          arg0_c=self._arr.copy()          arg0_c=self._arr.copy()
# Line 300  class Symbol(object): Line 339  class Symbol(object):
339              for i1 in range(d1):              for i1 in range(d1):
340                  out[i0,i1]=numpy.sum(arg0_c[:,i0]*arg1_c[:,i1])                  out[i0,i1]=numpy.sum(arg0_c[:,i0]*arg1_c[:,i1])
341          out.resize(sh0[axis_offset:]+sh1[axis_offset:])          out.resize(sh0[axis_offset:]+sh1[axis_offset:])
342          return Symbol(out)          return Symbol(out, dim=self.dim)
343    
344      def tensorTransposedProduct(self, other, axis_offset):      def tensorTransposedProduct(self, other, axis_offset):
345          arg0_c=self._arr.copy()          arg0_c=self._arr.copy()
# Line 324  class Symbol(object): Line 363  class Symbol(object):
363              for i1 in range(d1):              for i1 in range(d1):
364                  out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[i1,:])                  out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[i1,:])
365          out.resize(sh0[:self._arr.ndim-axis_offset]+sh1[:r1-axis_offset])          out.resize(sh0[:self._arr.ndim-axis_offset]+sh1[:r1-axis_offset])
366          return Symbol(out)          return Symbol(out, dim=self.dim)
367    
368      def trace(self, axis_offset):      def trace(self, axis_offset):
369          sh=self.getShape()          sh=self.getShape()
# Line 339  class Symbol(object): Line 378  class Symbol(object):
378                  for j in range(sh[axis_offset]):                  for j in range(sh[axis_offset]):
379                      out[i1,i2]+=arr_r[i1,j,j,i2]                      out[i1,i2]+=arr_r[i1,j,j,i2]
380          out.resize(sh[:axis_offset]+sh[axis_offset+2:])          out.resize(sh[:axis_offset]+sh[axis_offset+2:])
381          return Symbol(out)          return Symbol(out, dim=self.dim)
382    
383      def transpose(self, axis_offset):      def transpose(self, axis_offset):
384          if axis_offset is None:          if axis_offset is None:
385              axis_offset=int(self._arr.ndim/2)              axis_offset=int(self._arr.ndim/2)
386          axes=range(axis_offset, self._arr.ndim)+range(0,axis_offset)          axes=range(axis_offset, self._arr.ndim)+range(0,axis_offset)
387          return Symbol(numpy.transpose(self._arr, axes=axes))          return Symbol(numpy.transpose(self._arr, axes=axes), dim=self.dim)
388    
389      def applyfunc(self, f):      def applyfunc(self, f):
390          assert callable(f)          assert callable(f)
391          if self._arr.ndim==0:          if self._arr.ndim==0:
392              out=Symbol(f(self._arr.item()))              el=f(self._arr.item())
393                if el is not None:
394                    out=Symbol(el, dim=self.dim)
395                else:
396                    return el
397          else:          else:
398              out=numpy.empty(self.getShape(), dtype=object)              out=numpy.empty(self.getShape(), dtype=object)
399              for idx in numpy.ndindex(self.getShape()):              for idx in numpy.ndindex(self.getShape()):
400                  out[idx]=f(self._arr[idx])                  out[idx]=f(self._arr[idx])
401              out=Symbol(out)              out=Symbol(out, dim=self.dim)
402          return out          return out
403    
404      def _sympy_(self):      def _sympy_(self):
# Line 371  class Symbol(object): Line 414  class Symbol(object):
414              sh1=other.getShape()              sh1=other.getShape()
415          elif isinstance(other, numpy.ndarray):          elif isinstance(other, numpy.ndarray):
416              sh1=other.shape              sh1=other.shape
417          elif isinstance(other,int) or isinstance(other,float):          elif isinstance(other,int) or isinstance(other,float) or isinstance(other,sympy.Basic):
418              sh1=()              sh1=()
419          else:          else:
420              raise TypeError("Unsupported argument type '%s' for binary operation"%other.__class__.__name__)              raise TypeError("Unsupported argument type '%s' for binary operation"%other.__class__.__name__)
# Line 436  class Symbol(object): Line 479  class Symbol(object):
479          return self          return self
480    
481      def __neg__(self):      def __neg__(self):
482          return Symbol(-self._arr)          return Symbol(-self._arr, dim=self.dim)
483    
484      def __abs__(self):      def __abs__(self):
485          return Symbol(abs(self._arr))          return Symbol(abs(self._arr), dim=self.dim)
486    
487      def __add__(self, other):      def __add__(self, other):
488          self._ensureShapeCompatible(other)          self._ensureShapeCompatible(other)
489          if isinstance(other, Symbol):          if isinstance(other, Symbol):
490              return Symbol(self._arr+other._arr)              return Symbol(self._arr+other._arr, dim=self.dim)
491          return Symbol(self._arr+other)          return Symbol(self._arr+other, dim=self.dim)
492    
493      def __radd__(self, other):      def __radd__(self, other):
494          self._ensureShapeCompatible(other)          self._ensureShapeCompatible(other)
495          if isinstance(other, Symbol):          if isinstance(other, Symbol):
496              return Symbol(other._arr+self._arr)              return Symbol(other._arr+self._arr, dim=self.dim)
497          return Symbol(other+self._arr)          return Symbol(other+self._arr, dim=self.dim)
498    
499      def __sub__(self, other):      def __sub__(self, other):
500          self._ensureShapeCompatible(other)          self._ensureShapeCompatible(other)
501          if isinstance(other, Symbol):          if isinstance(other, Symbol):
502              return Symbol(self._arr-other._arr)              return Symbol(self._arr-other._arr, dim=self.dim)
503          return Symbol(self._arr-other)          return Symbol(self._arr-other, dim=self.dim)
504    
505      def __rsub__(self, other):      def __rsub__(self, other):
506          self._ensureShapeCompatible(other)          self._ensureShapeCompatible(other)
507          if isinstance(other, Symbol):          if isinstance(other, Symbol):
508              return Symbol(other._arr-self._arr)              return Symbol(other._arr-self._arr, dim=self.dim)
509          return Symbol(other-self._arr)          return Symbol(other-self._arr, dim=self.dim)
510    
511      def __mul__(self, other):      def __mul__(self, other):
512          self._ensureShapeCompatible(other)          self._ensureShapeCompatible(other)
513          if isinstance(other, Symbol):          if isinstance(other, Symbol):
514              return Symbol(self._arr*other._arr)              return Symbol(self._arr*other._arr, dim=self.dim)
515          return Symbol(self._arr*other)          return Symbol(self._arr*other, dim=self.dim)
516    
517      def __rmul__(self, other):      def __rmul__(self, other):
518          self._ensureShapeCompatible(other)          self._ensureShapeCompatible(other)
519          if isinstance(other, Symbol):          if isinstance(other, Symbol):
520              return Symbol(other._arr*self._arr)              return Symbol(other._arr*self._arr, dim=self.dim)
521          return Symbol(other*self._arr)          return Symbol(other*self._arr, dim=self.dim)
522    
523      def __div__(self, other):      def __div__(self, other):
524          self._ensureShapeCompatible(other)          self._ensureShapeCompatible(other)
525          if isinstance(other, Symbol):          if isinstance(other, Symbol):
526              return Symbol(self._arr/other._arr)              return Symbol(self._arr/other._arr, dim=self.dim)
527          return Symbol(self._arr/other)          return Symbol(self._arr/other, dim=self.dim)
528    
529      def __rdiv__(self, other):      def __rdiv__(self, other):
530          self._ensureShapeCompatible(other)          self._ensureShapeCompatible(other)
531          if isinstance(other, Symbol):          if isinstance(other, Symbol):
532              return Symbol(other._arr/self._arr)              return Symbol(other._arr/self._arr, dim=self.dim)
533          return Symbol(other/self._arr)          return Symbol(other/self._arr, dim=self.dim)
534    
535      def __pow__(self, other):      def __pow__(self, other):
536          self._ensureShapeCompatible(other)          self._ensureShapeCompatible(other)
537          if isinstance(other, Symbol):          if isinstance(other, Symbol):
538              return Symbol(self._arr**other._arr)              return Symbol(self._arr**other._arr, dim=self.dim)
539          return Symbol(self._arr**other)          return Symbol(self._arr**other, dim=self.dim)
540    
541      def __rpow__(self, other):      def __rpow__(self, other):
542          self._ensureShapeCompatible(other)          self._ensureShapeCompatible(other)
543          if isinstance(other, Symbol):          if isinstance(other, Symbol):
544              return Symbol(other._arr**self._arr)              return Symbol(other._arr**self._arr, dim=self.dim)
545          return Symbol(other**self._arr)          return Symbol(other**self._arr, dim=self.dim)
546    
547    
548  def symbols(*names, **kwargs):  def symbols(*names, **kwargs):

Legend:
Removed from v.3518  
changed lines
  Added in v.3530

  ViewVC Help
Powered by ViewVC 1.1.26