/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Diff of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py revision 3536 by caltinay, Thu Jun 23 04:42:38 2011 UTC trunk/escript/py_src/symbolic/symbols.py revision 3978 by caltinay, Thu Sep 20 04:36:17 2012 UTC
# Line 1  Line 1 
 # -*- coding: utf-8 -*-  
1    
2  ########################################################  ########################################################
3  #  #
4  # Copyright (c) 2003-2010 by University of Queensland  # Copyright (c) 2003-2012 by University of Queensland
5  # Earth Systems Science Computational Center (ESSCC)  # Earth Systems Science Computational Center (ESSCC)
6  # http://www.uq.edu.au/esscc  # http://www.uq.edu.au/esscc
7  #  #
# Line 12  Line 11 
11  #  #
12  ########################################################  ########################################################
13    
14  __copyright__="""Copyright (c) 2003-2010 by University of Queensland  __copyright__="""Copyright (c) 2003-2012 by University of Queensland
15  Earth Systems Science Computational Center (ESSCC)  Earth Systems Science Computational Center (ESSCC)
16  http://www.uq.edu.au/esscc  http://www.uq.edu.au/esscc
17  Primary Business: Queensland, Australia"""  Primary Business: Queensland, Australia"""
18  __license__="""Licensed under the Open Software License version 3.0  __license__="""Licensed under the Open Software License version 3.0
19  http://www.opensource.org/licenses/osl-3.0.php"""  http://www.opensource.org/licenses/osl-3.0.php"""
20  __url__="https://launchpad.net/escript-finley"  __url__="https://launchpad.net/escript-finley"
21    __author__="Cihan Altinay"
22    
23  """  """
24  :var __author__: name of author  :var __author__: name of author
# Line 30  __url__="https://launchpad.net/escript-f Line 30  __url__="https://launchpad.net/escript-f
30  """  """
31    
32  import numpy  import numpy
33  import sympy  from esys.escript import Data, FunctionSpace, HAVE_SYMBOLS
34    if HAVE_SYMBOLS:
35  __author__="Cihan Altinay"      import sympy
   
36    
37      
38  class Symbol(object):  class Symbol(object):
39      """      """
40      `Symbol` objects are placeholders for a single mathematic symbol, such as      `Symbol` objects are placeholders for a single mathematical symbol, such as
41      'x', or for arbitrarily complex mathematic expressions such as      'x', or for arbitrarily complex mathematical expressions such as
42      'c*x**4+alpha*exp(x)-2*sin(beta*x)', where 'alpha', 'beta', 'c', and 'x'      'c*x**4+alpha*exp(x)-2*sin(beta*x)', where 'alpha', 'beta', 'c', and 'x'
43      are also `Symbol`s (the symbolic 'atoms' of the expression).      are also `Symbol`s (the symbolic 'atoms' of the expression).
44    
# Line 46  class Symbol(object): Line 46  class Symbol(object):
46      be resolved by substituting numeric values and/or escript `Data` objects      be resolved by substituting numeric values and/or escript `Data` objects
47      for the atoms. To facilitate the use of `Data` objects a `Symbol` has a      for the atoms. To facilitate the use of `Data` objects a `Symbol` has a
48      shape (and thus a rank) as well as a dimension (see constructor).      shape (and thus a rank) as well as a dimension (see constructor).
49      `Symbol`s are useful to perform mathematic simplifications, compute      `Symbol`s are useful to perform mathematical simplifications, compute
50      derivatives and as coefficients for nonlinear PDEs which can be solved by      derivatives and as coefficients for nonlinear PDEs which can be solved by
51      the `NonlinearPDE` class.      the `NonlinearPDE` class.
52      """      """
53    
54        # these are for compatibility with sympy.Symbol. lambdify checks these.
55        is_Add=False
56        is_Float=False
57    
58      def __init__(self, *args, **kwargs):      def __init__(self, *args, **kwargs):
59          """          """
60          Initialises a new `Symbol` object in one of three ways::          Initialises a new `Symbol` object in one of three ways::
# Line 79  class Symbol(object): Line 83  class Symbol(object):
83          :keyword dim: dimensionality of the new Symbol (default: 2)          :keyword dim: dimensionality of the new Symbol (default: 2)
84          :type dim: ``int``          :type dim: ``int``
85          """          """
86            if not HAVE_SYMBOLS:
87                raise RuntimeError("Trying to instantiate a Symbol but sympy not available")
88    
89          if 'dim' in kwargs:          if 'dim' in kwargs:
90              self.dim=kwargs.pop('dim')              self._dim=kwargs.pop('dim')
91            else:
92                self._dim=-1 # undefined
93    
94            if 'subs' in kwargs:
95                self._subs=kwargs.pop('subs')
96          else:          else:
97              self.dim=2              self._subs={}
98    
99          if len(args)==1:          if len(args)==1:
100              arg=args[0]              arg=args[0]
# Line 102  class Symbol(object): Line 114  class Symbol(object):
114                      else:                      else:
115                          res[idx]=arr[idx]                          res[idx]=arr[idx]
116                  self._arr=res                  self._arr=res
117                    if isinstance(arg, Symbol):
118                        self._subs.update(arg._subs)
119                        if self._dim==-1:
120                            self._dim=arg._dim
121              elif isinstance(arg, sympy.Basic):              elif isinstance(arg, sympy.Basic):
122                  self._arr=numpy.array(arg)                  self._arr=numpy.array(arg)
123              else:              else:
# Line 120  class Symbol(object): Line 136  class Symbol(object):
136              if len(shape)==0:              if len(shape)==0:
137                  self._arr=numpy.array(sympy.Symbol(name, **kwargs))                  self._arr=numpy.array(sympy.Symbol(name, **kwargs))
138              else:              else:
139                  self._arr=sympy.symarray(shape, '['+name+']')                  try:
140                        self._arr=sympy.symarray(shape, '['+name+']')
141                    except TypeError:
142                        self._arr=sympy.symarray('['+name+']', shape)
143          else:          else:
144              raise TypeError("Unsupported number of arguments")              raise TypeError("Unsupported number of arguments")
145          if self._arr.ndim==0:          if self._arr.ndim==0:
# Line 144  class Symbol(object): Line 163  class Symbol(object):
163          return (self._arr==other._arr).all()          return (self._arr==other._arr).all()
164    
165      def __getitem__(self, key):      def __getitem__(self, key):
166          return self._arr[key]          """
167            Returns an element of this symbol which must have rank >0.
168            Unlike item() this method converts sympy objects and numpy arrays into
169            escript Symbols in order to facilitate expressions that require
170            element access, such as: grad(u)[1]+x
171    
172            :param key: (nd-)index of item to be returned
173            :return: the requested element
174            :rtype: ``Symbol``, ``int``, or ``float``
175            """
176            res=self._arr[key]
177            # replace sympy Symbols/expressions by escript Symbols
178            if isinstance(res, sympy.Basic) or isinstance(res, numpy.ndarray):
179                res=Symbol(res)
180            return res
181    
182      def __setitem__(self, key, value):      def __setitem__(self, key, value):
183          if isinstance(value, Symbol):          if isinstance(value, Symbol):
# Line 153  class Symbol(object): Line 186  class Symbol(object):
186              elif hasattr(self._arr[key], "shape"):              elif hasattr(self._arr[key], "shape"):
187                  if self._arr[key].shape==value.getShape():                  if self._arr[key].shape==value.getShape():
188                      for idx in numpy.ndindex(self._arr[key].shape):                      for idx in numpy.ndindex(self._arr[key].shape):
189                          self._arr[key][idx]=value[idx]                          self._arr[key][idx]=value[idx].item()
190                  else:                  else:
191                      raise ValueError("Wrong shape of value")                      raise ValueError("Wrong shape of value")
192              else:              else:
# Line 165  class Symbol(object): Line 198  class Symbol(object):
198          else:          else:
199              self._arr[key]=sympy.sympify(value)              self._arr[key]=sympy.sympify(value)
200    
201        def __iter__(self):
202            return self._arr.__iter__
203    
204        def __array__(self, t=None):
205            if t:
206                return self._arr.astype(t)
207            else:
208                return self._arr
209    
210        def _sympy_(self):
211            """
212            """
213            return self.applyfunc(sympy.sympify)
214    
215      def getDim(self):      def getDim(self):
216          """          """
217          Returns the spatial dimensionality of this symbol.          Returns the spatial dimensionality of this symbol.
218    
219          :return: the symbol's spatial dimensionality          :return: the symbol's spatial dimensionality, or -1 if undefined
220          :rtype: ``int``          :rtype: ``int``
221          """          """
222          return self.dim          return self._dim
223    
224      def getRank(self):      def getRank(self):
225          """          """
# Line 192  class Symbol(object): Line 239  class Symbol(object):
239          """          """
240          return self._arr.shape          return self._arr.shape
241    
242        def getDataSubstitutions(self):
243            """
244            Returns a dictionary of symbol names and the escript ``Data`` objects
245            they represent within this Symbol.
246    
247            :return: the dictionary of substituted ``Data`` objects
248            :rtype: ``dict``
249            """
250            return self._subs
251    
252      def item(self, *args):      def item(self, *args):
253          """          """
254          Returns an element of this symbol.          Returns an element of this symbol.
# Line 238  class Symbol(object): Line 295  class Symbol(object):
295          return s          return s
296    
297      def _sympystr_(self, printer):      def _sympystr_(self, printer):
298            # compatibility with sympy 1.6
299            return self._sympystr(printer)
300    
301        def _sympystr(self, printer):
302          return self.lambdarepr()          return self.lambdarepr()
303    
304      def lambdarepr(self):      def lambdarepr(self):
305            """
306            """
307          from sympy.printing.lambdarepr import lambdarepr          from sympy.printing.lambdarepr import lambdarepr
308          temp_arr=numpy.empty(self.getShape(), dtype=object)          temp_arr=numpy.empty(self.getShape(), dtype=object)
309          for idx,el in numpy.ndenumerate(self._arr):          for idx,el in numpy.ndenumerate(self._arr):
# Line 302  class Symbol(object): Line 365  class Symbol(object):
365              result=numpy.zeros(self.getShape(), dtype=object)              result=numpy.zeros(self.getShape(), dtype=object)
366              for idx in numpy.ndindex(y.shape):              for idx in numpy.ndindex(y.shape):
367                  if y[idx]!=0:                  if y[idx]!=0:
368                      res=self[idx].coeff(y[idx], expand)                      res=self._arr[idx].coeff(y[idx], expand)
369                      if res is not None:                      if res is not None:
370                          result[idx]=res                          result[idx]=res
371          elif y.item()==0:          elif y.item()==0:
# Line 311  class Symbol(object): Line 374  class Symbol(object):
374              coeff_item=lambda item: getattr(item, 'coeff')(y.item(), expand)              coeff_item=lambda item: getattr(item, 'coeff')(y.item(), expand)
375              none_to_zero=lambda item: 0 if item is None else item              none_to_zero=lambda item: 0 if item is None else item
376              result=self.applyfunc(coeff_item)              result=self.applyfunc(coeff_item)
377              result=result.applyfunc(none_to_zero)._arr              result=result.applyfunc(none_to_zero)
378          return Symbol(result, dim=self.dim)          res=Symbol(result, dim=self._dim)
379            for i in self._subs: res.subs(i, self._subs[i])
380            return res
381    
382        def subs(self, old, new):
383            """
384            Substitutes an expression.
385            """
386            old._ensureShapeCompatible(new)
387            if isinstance(new, Data):
388                subs=self._subs.copy()
389                if isinstance(old, Symbol) and old.getRank()>0:
390                    old=Symbol(old.atoms(sympy.Symbol)[0])
391                subs[old]=new
392                result=Symbol(self._arr, dim=self._dim, subs=subs)
393            elif isinstance(old, Symbol) and old.getRank()>0:
394                if hasattr(new, '__array__'):
395                    new=new.__array__()
396                else:
397                    new=numpy.array(new)
398    
399                result=numpy.empty(self.getShape(), dtype=object)
400                if new.ndim>0:
401                    for idx in numpy.ndindex(self.getShape()):
402                        for symidx in numpy.ndindex(new.shape):
403                            result[idx]=self._arr[idx].subs(old._arr[symidx], new[symidx])
404                else: # substitute scalar for non-scalar
405                    for idx in numpy.ndindex(self.getShape()):
406                        for symidx in numpy.ndindex(old.getShape()):
407                            result[idx]=self._arr[idx].subs(old._arr[symidx], new.item())
408                result=Symbol(result, dim=self._dim, subs=self._subs)
409            else: # scalar
410                if isinstance(new, Symbol):
411                    new=new.item()
412                if isinstance(old, Symbol):
413                    old=old.item()
414                subs_item=lambda item: getattr(item, 'subs')(old, new)
415                result=self.applyfunc(subs_item)
416            return result
417    
418      def diff(self, *symbols, **assumptions):      def diff(self, *symbols, **assumptions):
419          """          """
420          """          """
421          symbols=Symbol._symbolgen(*symbols)          symbols=Symbol._symbolgen(*symbols)
422          result=Symbol(self._arr, dim=self.dim)          result=Symbol(self._arr, dim=self._dim, subs=self._subs)
423          for s in symbols:          for s in symbols:
424              if isinstance(s, Symbol):              if isinstance(s, Symbol):
425                  if s.getRank()==0:                  if s.getRank()==0:
# Line 330  class Symbol(object): Line 431  class Symbol(object):
431                      for d in range(dim):                      for d in range(dim):
432                          for idx in numpy.ndindex(self.getShape()):                          for idx in numpy.ndindex(self.getShape()):
433                              index=idx+(d,)                              index=idx+(d,)
434                              out[index]=out[index].diff(s[d], **assumptions)                              out[index]=out[index].diff(s[d].item(), **assumptions)
435                      result=Symbol(out, dim=self.dim)                      result=Symbol(out, dim=self._dim, subs=self._subs)
436                  else:                  else:
437                      raise ValueError("diff: Only rank 0 and 1 supported")                      raise ValueError("diff: argument must have rank 0 or 1")
438              else:              else:
439                  diff_item=lambda item: getattr(item, 'diff')(s, **assumptions)                  diff_item=lambda item: getattr(item, 'diff')(s, **assumptions)
440                  result=result.applyfunc(diff_item)                  result=result.applyfunc(diff_item)
# Line 341  class Symbol(object): Line 442  class Symbol(object):
442    
443      def grad(self, where=None):      def grad(self, where=None):
444          """          """
445            Returns a symbol which represents the gradient of this symbol.
446            :type where: ``Symbol``, ``FunctionSpace``
447          """          """
448            if self._dim < 0:
449                raise ValueError("grad: cannot compute gradient as symbol has undefined dimensionality")
450            subs=self._subs
451          if isinstance(where, Symbol):          if isinstance(where, Symbol):
452              if where.getRank()>0:              if where.getRank()>0:
453                  raise ValueError("grad: 'where' must be a scalar symbol")                  raise ValueError("grad: 'where' must be a scalar symbol")
454              where=where._arr.item()              where=where._arr.item()
455            elif isinstance(where, FunctionSpace):
456                name='fs'+str(id(where))
457                fssym=Symbol(name)
458                subs=self._subs.copy()
459                subs.update({fssym:where})
460                where=name
461    
462          from functions import grad_n          from functions import grad_n
463          out=self._arr.copy().reshape(self.getShape()+(1,)).repeat(self.dim,axis=self.getRank())          out=self._arr.copy().reshape(self.getShape()+(1,)).repeat(self._dim,axis=self.getRank())
464          for d in range(self.dim):          for d in range(self._dim):
465              for idx in numpy.ndindex(self.getShape()):              for idx in numpy.ndindex(self.getShape()):
466                  index=idx+(d,)                  index=idx+(d,)
467                  if where is None:                  if where is None:
468                      out[index]=grad_n(out[index],d)                      out[index]=grad_n(out[index],d)
469                  else:                  else:
470                      out[index]=grad_n(out[index],d,where)                      out[index]=grad_n(out[index],d,where)
471          return Symbol(out, dim=self.dim)          return Symbol(out, dim=self._dim, subs=subs)
472    
473      def inverse(self):      def inverse(self):
474          """          """
# Line 410  class Symbol(object): Line 522  class Symbol(object):
522              out[2,2]=(A11*A22-A12*A21)*D              out[2,2]=(A11*A22-A12*A21)*D
523          else:          else:
524             raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")             raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
525          return Symbol(out, dim=self.dim)          return Symbol(out, dim=self._dim, subs=self._subs)
526    
527      def swap_axes(self, axis0, axis1):      def swap_axes(self, axis0, axis1):
528          """          """
529          """          """
530          return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self.dim)          return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self._dim, subs=self._subs)
531    
532      def tensorProduct(self, other, axis_offset):      def tensorProduct(self, other, axis_offset):
533          """          """
# Line 425  class Symbol(object): Line 537  class Symbol(object):
537          if isinstance(other, Symbol):          if isinstance(other, Symbol):
538              arg1_c=other._arr.copy()              arg1_c=other._arr.copy()
539              sh1=other.getShape()              sh1=other.getShape()
540                dim=other._dim if self._dim < 0 else self._dim
541          else:          else:
542              arg1_c=other.copy()              arg1_c=other.copy()
543              sh1=other.shape              sh1=other.shape
544                dim=self._dim
545          d0,d1,d01=1,1,1          d0,d1,d01=1,1,1
546          for i in sh0[:self._arr.ndim-axis_offset]: d0*=i          for i in sh0[:self._arr.ndim-axis_offset]: d0*=i
547          for i in sh1[axis_offset:]: d1*=i          for i in sh1[axis_offset:]: d1*=i
# Line 439  class Symbol(object): Line 553  class Symbol(object):
553              for i1 in range(d1):              for i1 in range(d1):
554                  out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[:,i1])                  out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[:,i1])
555          out.resize(sh0[:self._arr.ndim-axis_offset]+sh1[axis_offset:])          out.resize(sh0[:self._arr.ndim-axis_offset]+sh1[axis_offset:])
556          return Symbol(out, dim=self.dim)          subs=self._subs.copy()
557            subs.update(other._subs)
558            return Symbol(out, dim=dim, subs=subs)
559    
560      def transposedTensorProduct(self, other, axis_offset):      def transposedTensorProduct(self, other, axis_offset):
561          """          """
# Line 449  class Symbol(object): Line 565  class Symbol(object):
565          if isinstance(other, Symbol):          if isinstance(other, Symbol):
566              arg1_c=other._arr.copy()              arg1_c=other._arr.copy()
567              sh1=other.getShape()              sh1=other.getShape()
568                dim=other._dim if self._dim < 0 else self._dim
569          else:          else:
570              arg1_c=other.copy()              arg1_c=other.copy()
571              sh1=other.shape              sh1=other.shape
572                dim=self._dim
573          d0,d1,d01=1,1,1          d0,d1,d01=1,1,1
574          for i in sh0[axis_offset:]: d0*=i          for i in sh0[axis_offset:]: d0*=i
575          for i in sh1[axis_offset:]: d1*=i          for i in sh1[axis_offset:]: d1*=i
# Line 463  class Symbol(object): Line 581  class Symbol(object):
581              for i1 in range(d1):              for i1 in range(d1):
582                  out[i0,i1]=numpy.sum(arg0_c[:,i0]*arg1_c[:,i1])                  out[i0,i1]=numpy.sum(arg0_c[:,i0]*arg1_c[:,i1])
583          out.resize(sh0[axis_offset:]+sh1[axis_offset:])          out.resize(sh0[axis_offset:]+sh1[axis_offset:])
584          return Symbol(out, dim=self.dim)          subs=self._subs.copy()
585            subs.update(other._subs)
586            return Symbol(out, dim=dim, subs=subs)
587    
588      def tensorTransposedProduct(self, other, axis_offset):      def tensorTransposedProduct(self, other, axis_offset):
589          """          """
# Line 474  class Symbol(object): Line 594  class Symbol(object):
594              arg1_c=other._arr.copy()              arg1_c=other._arr.copy()
595              sh1=other.getShape()              sh1=other.getShape()
596              r1=other.getRank()              r1=other.getRank()
597                dim=other._dim if self._dim < 0 else self._dim
598          else:          else:
599              arg1_c=other.copy()              arg1_c=other.copy()
600              sh1=other.shape              sh1=other.shape
601              r1=other.ndim              r1=other.ndim
602                dim=self._dim
603          d0,d1,d01=1,1,1          d0,d1,d01=1,1,1
604          for i in sh0[:self._arr.ndim-axis_offset]: d0*=i          for i in sh0[:self._arr.ndim-axis_offset]: d0*=i
605          for i in sh1[:r1-axis_offset]: d1*=i          for i in sh1[:r1-axis_offset]: d1*=i
# Line 489  class Symbol(object): Line 611  class Symbol(object):
611              for i1 in range(d1):              for i1 in range(d1):
612                  out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[i1,:])                  out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[i1,:])
613          out.resize(sh0[:self._arr.ndim-axis_offset]+sh1[:r1-axis_offset])          out.resize(sh0[:self._arr.ndim-axis_offset]+sh1[:r1-axis_offset])
614          return Symbol(out, dim=self.dim)          subs=self._subs.copy()
615            subs.update(other._subs)
616            return Symbol(out, dim=dim, subs=subs)
617    
618      def trace(self, axis_offset):      def trace(self, axis_offset):
619          """          """
620            Returns the trace of this Symbol.
621          """          """
622          sh=self.getShape()          sh=self.getShape()
623          s1=1          s1=1
# Line 506  class Symbol(object): Line 631  class Symbol(object):
631                  for j in range(sh[axis_offset]):                  for j in range(sh[axis_offset]):
632                      out[i1,i2]+=arr_r[i1,j,j,i2]                      out[i1,i2]+=arr_r[i1,j,j,i2]
633          out.resize(sh[:axis_offset]+sh[axis_offset+2:])          out.resize(sh[:axis_offset]+sh[axis_offset+2:])
634          return Symbol(out, dim=self.dim)          return Symbol(out, dim=self._dim, subs=self._subs)
635    
636      def transpose(self, axis_offset):      def transpose(self, axis_offset):
637          """          """
638            Returns the transpose of this Symbol.
639          """          """
640          if axis_offset is None:          if axis_offset is None:
641              axis_offset=int(self._arr.ndim/2)              axis_offset=int(self._arr.ndim/2)
642          axes=range(axis_offset, self._arr.ndim)+range(0,axis_offset)          axes=range(axis_offset, self._arr.ndim)+range(0,axis_offset)
643          return Symbol(numpy.transpose(self._arr, axes=axes), dim=self.dim)          return Symbol(numpy.transpose(self._arr, axes=axes), dim=self._dim, subs=self._subs)
644    
645      def applyfunc(self, f, on_type=None):      def applyfunc(self, f, on_type=None):
646          """          """
647            Applies the function `f` to all elements (if on_type is None) or to
648            all elements of type `on_type`.
649          """          """
650          assert callable(f)          assert callable(f)
651          if self._arr.ndim==0:          if self._arr.ndim==0:
652              if on_type is None or isinstance(self._arr.item(),on_type):              if on_type is None or isinstance(self._arr.item(), on_type):
653                  el=f(self._arr.item())                  el=f(self._arr.item())
654              else:              else:
655                  el=self._arr.item()                  el=self._arr.item()
656              if el is not None:              if el is not None:
657                  out=Symbol(el, dim=self.dim)                  out=Symbol(el, dim=self._dim, subs=self._subs)
658              else:              else:
659                  return el                  return el
660          else:          else:
# Line 536  class Symbol(object): Line 664  class Symbol(object):
664                      out[idx]=f(self._arr[idx])                      out[idx]=f(self._arr[idx])
665                  else:                  else:
666                      out[idx]=self._arr[idx]                      out[idx]=self._arr[idx]
667              out=Symbol(out, dim=self.dim)              out=Symbol(out, dim=self._dim, subs=self._subs)
668          return out          return out
669    
670      def simplify(self):      def expand(self):
671          """          """
672            Applies the sympy.expand operation on all elements in this symbol
673          """          """
674          return self.applyfunc(sympy.simplify, sympy.Basic)          return self.applyfunc(sympy.expand, sympy.Basic)
675    
676      def _sympy_(self):      def simplify(self):
677          """          """
678            Applies the sympy.simplify operation on all elements in this symbol
679          """          """
680          return self.applyfunc(sympy.sympify)          return self.applyfunc(sympy.simplify, sympy.Basic)
681    
682        # unary/binary operators follow
683    
684        def __pos__(self):
685            return self
686    
687        def __neg__(self):
688            return Symbol(-self._arr, dim=self._dim, subs=self._subs)
689    
690        def __abs__(self):
691            return Symbol(abs(self._arr), dim=self._dim, subs=self._subs)
692    
693      def _ensureShapeCompatible(self, other):      def _ensureShapeCompatible(self, other):
694          """          """
# Line 555  class Symbol(object): Line 696  class Symbol(object):
696          Raises TypeError if not compatible.          Raises TypeError if not compatible.
697          """          """
698          sh0=self.getShape()          sh0=self.getShape()
699          if isinstance(other, Symbol):          if isinstance(other, Symbol) or isinstance(other, Data):
700              sh1=other.getShape()              sh1=other.getShape()
701          elif isinstance(other, numpy.ndarray):          elif isinstance(other, numpy.ndarray):
702              sh1=other.shape              sh1=other.shape
# Line 564  class Symbol(object): Line 705  class Symbol(object):
705          elif isinstance(other,int) or isinstance(other,float) or isinstance(other,sympy.Basic):          elif isinstance(other,int) or isinstance(other,float) or isinstance(other,sympy.Basic):
706              sh1=()              sh1=()
707          else:          else:
708              raise TypeError("Unsupported argument type '%s' for binary operation"%other.__class__.__name__)              raise TypeError("Unsupported argument type '%s' for operation"%other.__class__.__name__)
709          if not sh0==sh1 and not sh0==() and not sh1==():          if not sh0==sh1 and not sh0==() and not sh1==():
710              raise TypeError("Incompatible shapes for binary operation")              raise TypeError("Incompatible shapes for operation")
711    
712        def __binaryop(self, op, other):
713            """
714            Helper for binary operations that checks types, shapes etc.
715            """
716            self._ensureShapeCompatible(other)
717            if isinstance(other, Symbol):
718                subs=self._subs.copy()
719                subs.update(other._subs)
720                dim=other._dim if self._dim < 0 else self._dim
721                return Symbol(getattr(self._arr, op)(other._arr), dim=dim, subs=subs)
722            if isinstance(other, Data):
723                name='data'+str(id(other))
724                othersym=Symbol(name, other.getShape(), dim=self._dim)
725                subs=self._subs.copy()
726                subs.update({Symbol(name):other})
727                return Symbol(getattr(self._arr, op)(othersym._arr), dim=self._dim, subs=subs)
728            return Symbol(getattr(self._arr, op)(other), dim=self._dim, subs=self._subs)
729    
730        def __add__(self, other):
731            return self.__binaryop('__add__', other)
732    
733        def __radd__(self, other):
734            return self.__binaryop('__radd__', other)
735    
736        def __sub__(self, other):
737            return self.__binaryop('__sub__', other)
738    
739        def __rsub__(self, other):
740            return self.__binaryop('__rsub__', other)
741    
742        def __mul__(self, other):
743            return self.__binaryop('__mul__', other)
744    
745        def __rmul__(self, other):
746            return self.__binaryop('__rmul__', other)
747    
748        def __div__(self, other):
749            return self.__binaryop('__div__', other)
750    
751        def __rdiv__(self, other):
752            return self.__binaryop('__rdiv__', other)
753    
754        def __pow__(self, other):
755            return self.__binaryop('__pow__', other)
756    
757        def __rpow__(self, other):
758            return self.__binaryop('__rpow__', other)
759    
760        # static methods
761    
762      @staticmethod      @staticmethod
763      def _symComp(sym):      def _symComp(sym):
764            """
765            """
766          n=sym.name          n=sym.name
767          a=n.split('[')          a=n.split('[')
768          if len(a)!=2:          if len(a)!=2:
# Line 620  class Symbol(object): Line 813  class Symbol(object):
813              else:              else:
814                  yield s                  yield s
815    
     def __array__(self):  
         return self._arr  
   
     # unary/binary operations follow  
   
     def __pos__(self):  
         return self  
   
     def __neg__(self):  
         return Symbol(-self._arr, dim=self.dim)  
   
     def __abs__(self):  
         return Symbol(abs(self._arr), dim=self.dim)  
   
     def __add__(self, other):  
         self._ensureShapeCompatible(other)  
         if isinstance(other, Symbol):  
             return Symbol(self._arr+other._arr, dim=self.dim)  
         return Symbol(self._arr+other, dim=self.dim)  
   
     def __radd__(self, other):  
         self._ensureShapeCompatible(other)  
         if isinstance(other, Symbol):  
             return Symbol(other._arr+self._arr, dim=self.dim)  
         return Symbol(other+self._arr, dim=self.dim)  
   
     def __sub__(self, other):  
         self._ensureShapeCompatible(other)  
         if isinstance(other, Symbol):  
             return Symbol(self._arr-other._arr, dim=self.dim)  
         return Symbol(self._arr-other, dim=self.dim)  
   
     def __rsub__(self, other):  
         self._ensureShapeCompatible(other)  
         if isinstance(other, Symbol):  
             return Symbol(other._arr-self._arr, dim=self.dim)  
         return Symbol(other-self._arr, dim=self.dim)  
   
     def __mul__(self, other):  
         self._ensureShapeCompatible(other)  
         if isinstance(other, Symbol):  
             return Symbol(self._arr*other._arr, dim=self.dim)  
         return Symbol(self._arr*other, dim=self.dim)  
   
     def __rmul__(self, other):  
         self._ensureShapeCompatible(other)  
         if isinstance(other, Symbol):  
             return Symbol(other._arr*self._arr, dim=self.dim)  
         return Symbol(other*self._arr, dim=self.dim)  
   
     def __div__(self, other):  
         self._ensureShapeCompatible(other)  
         if isinstance(other, Symbol):  
             return Symbol(self._arr/other._arr, dim=self.dim)  
         return Symbol(self._arr/other, dim=self.dim)  
   
     def __rdiv__(self, other):  
         self._ensureShapeCompatible(other)  
         if isinstance(other, Symbol):  
             return Symbol(other._arr/self._arr, dim=self.dim)  
         return Symbol(other/self._arr, dim=self.dim)  
   
     def __pow__(self, other):  
         self._ensureShapeCompatible(other)  
         if isinstance(other, Symbol):  
             return Symbol(self._arr**other._arr, dim=self.dim)  
         return Symbol(self._arr**other, dim=self.dim)  
   
     def __rpow__(self, other):  
         self._ensureShapeCompatible(other)  
         if isinstance(other, Symbol):  
             return Symbol(other._arr**self._arr, dim=self.dim)  
         return Symbol(other**self._arr, dim=self.dim)  
   
   
 def symbols(*names, **kwargs):  
     """  
     Emulates the behaviour of sympy.symbols.  
     """  
   
     shape=kwargs.pop('shape', ())  
   
     s = names[0]  
     if not isinstance(s, list):  
         import re  
         s = re.split('\s|,', s)  
     res = []  
     for t in s:  
         # skip empty strings  
         if not t:  
             continue  
         sym = Symbol(t, shape, **kwargs)  
         res.append(sym)  
     res = tuple(res)  
     if len(res) == 0:   # var('')  
         res = None  
     elif len(res) == 1: # var('x')  
         res = res[0]  
                         # otherwise var('a b ...')  
     return res  
   
 def combineData(array, shape):  
     """  
     """  
   
     # array could just be a single value  
     if not hasattr(array,'__len__') and shape==():  
         return array  
   
     from esys.escript import Data  
     n=numpy.array(array) # for indexing  
   
     # find function space if any  
     dom=set()  
     fs=set()  
     for idx in numpy.ndindex(shape):  
         if isinstance(n[idx], Data):  
             fs.add(n[idx].getFunctionSpace())  
             dom.add(n[idx].getDomain())  
   
     if len(dom)>1:  
         domain=dom.pop()  
         while len(dom)>0:  
             if domain!=dom.pop():  
                 raise ValueError("Mixing of domains not supported")  
   
     if len(fs)>0:  
         d=Data(0., shape, fs.pop()) #FIXME: interpolate instead of using first?  
     else:  
         d=numpy.zeros(shape)  
     for idx in numpy.ndindex(shape):  
         #z=numpy.zeros(shape)  
         #z[idx]=1.  
         #d+=n[idx]*z # much slower!  
         if hasattr(n[idx], "ndim") and n[idx].ndim==0:  
             d[idx]=float(n[idx])  
         else:  
             d[idx]=n[idx]  
     return d  
   
   
 class SymFunction(Symbol):  
     """  
     """  
     def __init__(self, *args, **kwargs):  
         """  
         Initialises a new symbolic function object.  
         """  
         super(SymFunction, self).__init__(self.__class__.__name__, **kwargs)  
         self.args=args  
   
     def __repr__(self):  
         return self.name+"("+", ".join([str(a) for a in self.args])+")"  
   
     def __str__(self):  
         return self.name+"("+", ".join([str(a) for a in self.args])+")"  
   
     def lambdarepr(self):  
         return self.name+"("+", ".join([a.lambdarepr() for a in self.args])+")"  
   
     def atoms(self, *types):  
         s=set()  
         for el in self.args:  
             atoms=el.atoms(*types)  
             for a in atoms:  
                 if a.is_Symbol:  
                     n,c=Symbol._symComp(a)  
                     s.add(sympy.Symbol(n))  
                 else:  
                     s.add(a)  
         return s  
   
     def __neg__(self):  
         res=self.__class__(*self.args)  
         res._arr=-res._arr  
         return res  
   

Legend:
Removed from v.3536  
changed lines
  Added in v.3978

  ViewVC Help
Powered by ViewVC 1.1.26