/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Diff of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3532 by caltinay, Mon Jun 20 04:14:42 2011 UTC revision 3536 by caltinay, Thu Jun 23 04:42:38 2011 UTC
# Line 37  __author__="Cihan Altinay" Line 37  __author__="Cihan Altinay"
37    
38  class Symbol(object):  class Symbol(object):
39      """      """
40        `Symbol` objects are placeholders for a single mathematic symbol, such as
41        'x', or for arbitrarily complex mathematic expressions such as
42        'c*x**4+alpha*exp(x)-2*sin(beta*x)', where 'alpha', 'beta', 'c', and 'x'
43        are also `Symbol`s (the symbolic 'atoms' of the expression).
44    
45        With the help of the 'Evaluator' class these symbols and expressions can
46        be resolved by substituting numeric values and/or escript `Data` objects
47        for the atoms. To facilitate the use of `Data` objects a `Symbol` has a
48        shape (and thus a rank) as well as a dimension (see constructor).
49        `Symbol`s are useful to perform mathematic simplifications, compute
50        derivatives and as coefficients for nonlinear PDEs which can be solved by
51        the `NonlinearPDE` class.
52      """      """
53    
54      def __init__(self, *args, **kwargs):      def __init__(self, *args, **kwargs):
55          """          """
56          Initializes a new Symbol object.          Initialises a new `Symbol` object in one of three ways::
57    
58                u=Symbol('u')
59    
60            returns a scalar symbol by the name 'u'.
61    
62                a=Symbol('alpha', (4,3))
63    
64            returns a rank 2 symbol with the shape (4,3), whose elements are
65            named '[alpha]_i_j' (with i=0..3, j=0..2).
66    
67                a,b,c=symbols('a,b,c')
68                x=Symbol([[a+b,0,0],[0,b-c,0],[0,0,c-a]])
69    
70            returns a rank 2 symbol with the shape (3,3) whose elements are
71            explicitly specified by numeric values and other symbols/expressions
72            within a list or numpy array.
73    
74            The dimensionality of the symbol can be specified through the `dim`
75            keyword. All other keywords are passed to the underlying symbolic
76            library (currently sympy).
77    
78            :param args: initialisation arguments as described above
79            :keyword dim: dimensionality of the new Symbol (default: 2)
80            :type dim: ``int``
81          """          """
82          if 'dim' in kwargs:          if 'dim' in kwargs:
83              self.dim=kwargs.pop('dim')              self.dim=kwargs.pop('dim')
# Line 52  class Symbol(object): Line 88  class Symbol(object):
88              arg=args[0]              arg=args[0]
89              if isinstance(arg, str):              if isinstance(arg, str):
90                  if arg.find('[')>=0 or arg.find(']')>=0:                  if arg.find('[')>=0 or arg.find(']')>=0:
91                      raise TypeError("Name must not contain '[' or ']'")                      raise ValueError("Name must not contain '[' or ']'")
92                  self._arr=numpy.array(sympy.Symbol(arg, **kwargs))                  self._arr=numpy.array(sympy.Symbol(arg, **kwargs))
93              elif hasattr(arg, "__array__"):              elif hasattr(arg, "__array__") or isinstance(arg, list):
94                    if isinstance(arg, list): arg=numpy.array(arg)
95                  arr=arg.__array__()                  arr=arg.__array__()
96                  if len(arr.shape)>4:                  if len(arr.shape)>4:
97                      raise ValueError("Symbol only supports tensors up to order 4")                      raise ValueError("Symbol only supports tensors up to order 4")
98                  self._arr=arr.copy()                  res=numpy.empty(arr.shape, dtype=object)
99              elif isinstance(arg, list) or isinstance(arg, sympy.Basic):                  for idx in numpy.ndindex(arr.shape):
100                        if hasattr(arr[idx], "item"):
101                            res[idx]=arr[idx].item()
102                        else:
103                            res[idx]=arr[idx]
104                    self._arr=res
105                elif isinstance(arg, sympy.Basic):
106                  self._arr=numpy.array(arg)                  self._arr=numpy.array(arg)
107              else:              else:
108                  raise TypeError("Unsupported argument type %s"%str(type(arg)))                  raise TypeError("Unsupported argument type %s"%str(type(arg)))
109          elif len(args)==2:          elif len(args)==2:
110              if not isinstance(args[0], str):              if not isinstance(args[0], str):
111                  raise TypeError("First argument must be a string")                  raise TypeError("First argument must be a string")
             if args[0].find('[')>=0 or args[0].find(']')>=0:  
                 raise TypeError("Name must not contain '[' or ']'")  
112              if not isinstance(args[1], tuple):              if not isinstance(args[1], tuple):
113                  raise TypeError("Second argument must be a tuple")                  raise TypeError("Second argument must be a tuple")
114              name=args[0]              name=args[0]
115              shape=args[1]              shape=args[1]
116                if name.find('[')>=0 or name.find(']')>=0:
117                    raise ValueError("Name must not contain '[' or ']'")
118              if len(shape)>4:              if len(shape)>4:
119                  raise ValueError("Symbol only supports tensors up to order 4")                  raise ValueError("Symbol only supports tensors up to order 4")
120              if len(shape)==0:              if len(shape)==0:
# Line 104  class Symbol(object): Line 147  class Symbol(object):
147          return self._arr[key]          return self._arr[key]
148    
149      def __setitem__(self, key, value):      def __setitem__(self, key, value):
150          if isinstance(value,Symbol):          if isinstance(value, Symbol):
151              if value.getRank()==0:              if value.getRank()==0:
152                  self._arr[key]=value                  self._arr[key]=value.item()
153              elif hasattr(self._arr[key], "shape"):              elif hasattr(self._arr[key], "shape"):
154                  if self._arr[key].shape==value.getShape():                  if self._arr[key].shape==value.getShape():
155                      self._arr[key]=value                      for idx in numpy.ndindex(self._arr[key].shape):
156                            self._arr[key][idx]=value[idx]
157                  else:                  else:
158                      raise ValueError("Wrong shape of value")                      raise ValueError("Wrong shape of value")
159              else:              else:
160                  raise ValueError("Wrong shape of value")                  raise ValueError("Wrong shape of value")
161          elif isinstance(value,sympy.Basic):          elif isinstance(value, sympy.Basic):
162              self._arr[key]=value              self._arr[key]=value
163          elif hasattr(value, "__array__"):          elif hasattr(value, "__array__"):
164              self._arr[key]=map(sympy.sympify,value.flat)              self._arr[key]=map(sympy.sympify,value.flat)
165          else:          else:
166              self._arr[key]=sympy.sympify(value)              self._arr[key]=sympy.sympify(value)
167    
168        def getDim(self):
169            """
170            Returns the spatial dimensionality of this symbol.
171    
172            :return: the symbol's spatial dimensionality
173            :rtype: ``int``
174            """
175            return self.dim
176    
177      def getRank(self):      def getRank(self):
178            """
179            Returns the rank of this symbol.
180    
181            :return: the symbol's rank which is equal to the length of the shape.
182            :rtype: ``int``
183            """
184          return self._arr.ndim          return self._arr.ndim
185    
186      def getShape(self):      def getShape(self):
187            """
188            Returns the shape of this symbol.
189    
190            :return: the symbol's shape
191            :rtype: ``tuple`` of ``int``
192            """
193          return self._arr.shape          return self._arr.shape
194    
195        def item(self, *args):
196            """
197            Returns an element of this symbol.
198            This method behaves like the item() method of numpy.ndarray.
199            If this is a scalar Symbol, no arguments are allowed and the only
200            element in this Symbol is returned.
201            Otherwise, 'args' specifies a flat or nd-index and the element at
202            that index is returned.
203    
204            :param args: index of item to be returned
205            :return: the requested element
206            :rtype: ``sympy.Symbol``, ``int``, or ``float``
207            """
208            return self._arr.item(args)
209    
210      def atoms(self, *types):      def atoms(self, *types):
211            """
212            Returns the atoms that form the current Symbol.
213    
214            By default, only objects that are truly atomic and cannot be divided
215            into smaller pieces are returned: symbols, numbers, and number
216            symbols like I and pi. It is possible to request atoms of any type,
217            however.
218    
219            Note that if this symbol contains components such as [x]_i_j then
220            only their main symbol 'x' is returned.
221    
222            :param types: types to restrict result to
223            :return: list of atoms of specified type
224            :rtype: ``set``
225            """
226          s=set()          s=set()
227          for el in self._arr.flat:          for el in self._arr.flat:
228              if isinstance(el,sympy.Basic):              if isinstance(el,sympy.Basic):
# Line 138  class Symbol(object): Line 233  class Symbol(object):
233                          s.add(sympy.Symbol(n))                          s.add(sympy.Symbol(n))
234                      else:                      else:
235                          s.add(a)                          s.add(a)
236              else:              elif len(types)==0 or type(el) in types:
237                  # TODO: Numbers?                  s.add(el)
                 pass  
238          return s          return s
239    
240      def _sympystr_(self, printer):      def _sympystr_(self, printer):
# Line 171  class Symbol(object): Line 265  class Symbol(object):
265              return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))              return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))
266    
267      def coeff(self, x, expand=True):      def coeff(self, x, expand=True):
268            """
269            Returns the coefficient of the term "x" or 0 if there is no "x".
270    
271            If "x" is a scalar symbol then "x" is searched in all components of
272            this symbol. Otherwise the shapes must match and the coefficients are
273            checked component by component.
274    
275            Example::
276            
277                x=Symbol('x', (2,2))
278                y=3*x
279                print y.coeff(x)
280                print y.coeff(x[1,1])
281    
282            will print::
283    
284                [[3 3]
285                 [3 3]]
286    
287                [[0 0]
288                 [0 3]]
289    
290            :param x: the term whose coefficients are to be found
291            :type x: ``Symbol``, ``numpy.ndarray``, `list`
292            :return: the coefficient(s) of the term
293            :rtype: ``Symbol``
294            """
295          self._ensureShapeCompatible(x)          self._ensureShapeCompatible(x)
296          result=Symbol(self._arr, dim=self.dim)          if hasattr(x, '__array__'):
297          if isinstance(x, Symbol):              y=x.__array__()
298              if x.getRank()>0:          else:
299                  a=result._arr.flat              y=numpy.array(x)
300                  b=x._arr.flat  
301                  for idx in range(len(a)):          if y.ndim>0:
302                      s=b.next()              result=numpy.zeros(self.getShape(), dtype=object)
303                      if s==0:              for idx in numpy.ndindex(y.shape):
304                          a[idx]=0                  if y[idx]!=0:
305                      else:                      res=self[idx].coeff(y[idx], expand)
306                          a[idx]=a[idx].coeff(s, expand)                      if res is not None:
307              else:                          result[idx]=res
308                  if x._arr.item()==0:          elif y.item()==0:
309                      result=Symbol(numpy.zeros(self.getShape()), dim=self.dim)              result=numpy.zeros(self.getShape(), dtype=object)
310                  else:          else:
311                      coeff_item=lambda item: getattr(item, 'coeff')(x._arr.item(), expand)              coeff_item=lambda item: getattr(item, 'coeff')(y.item(), expand)
312                      result=result.applyfunc(coeff_item)              none_to_zero=lambda item: 0 if item is None else item
313          elif x==0:              result=self.applyfunc(coeff_item)
314              result=Symbol(numpy.zeros(self.getShape()), dim=self.dim)              result=result.applyfunc(none_to_zero)._arr
315          else:          return Symbol(result, dim=self.dim)
             coeff_item=lambda item: getattr(item, 'coeff')(x, expand)  
             result=result.applyfunc(coeff_item)  
   
         # replace None by 0  
         if result is None: return 0  
         a=result._arr.flat  
         for idx in range(len(a)):  
             if a[idx] is None: a[idx]=0  
         return result  
316    
317      def diff(self, *symbols, **assumptions):      def diff(self, *symbols, **assumptions):
318            """
319            """
320          symbols=Symbol._symbolgen(*symbols)          symbols=Symbol._symbolgen(*symbols)
321          result=Symbol(self._arr, dim=self.dim)          result=Symbol(self._arr, dim=self.dim)
322          for s in symbols:          for s in symbols:
# Line 226  class Symbol(object): Line 340  class Symbol(object):
340          return result          return result
341    
342      def grad(self, where=None):      def grad(self, where=None):
343            """
344            """
345          if isinstance(where, Symbol):          if isinstance(where, Symbol):
346              if where.getRank()>0:              if where.getRank()>0:
347                  raise ValueError("grad: 'where' must be a scalar symbol")                  raise ValueError("grad: 'where' must be a scalar symbol")
# Line 243  class Symbol(object): Line 359  class Symbol(object):
359          return Symbol(out, dim=self.dim)          return Symbol(out, dim=self.dim)
360    
361      def inverse(self):      def inverse(self):
362            """
363            """
364          if not self.getRank()==2:          if not self.getRank()==2:
365              raise ValueError("inverse: Only rank 2 supported")              raise TypeError("inverse: Only rank 2 supported")
366          s=self.getShape()          s=self.getShape()
367          if not s[0] == s[1]:          if not s[0] == s[1]:
368              raise ValueError("inverse: Only square shapes supported")              raise ValueError("inverse: Only square shapes supported")
# Line 295  class Symbol(object): Line 413  class Symbol(object):
413          return Symbol(out, dim=self.dim)          return Symbol(out, dim=self.dim)
414    
415      def swap_axes(self, axis0, axis1):      def swap_axes(self, axis0, axis1):
416            """
417            """
418          return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self.dim)          return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self.dim)
419    
420      def tensorProduct(self, other, axis_offset):      def tensorProduct(self, other, axis_offset):
421            """
422            """
423          arg0_c=self._arr.copy()          arg0_c=self._arr.copy()
424          sh0=self.getShape()          sh0=self.getShape()
425          if isinstance(other, Symbol):          if isinstance(other, Symbol):
# Line 320  class Symbol(object): Line 442  class Symbol(object):
442          return Symbol(out, dim=self.dim)          return Symbol(out, dim=self.dim)
443    
444      def transposedTensorProduct(self, other, axis_offset):      def transposedTensorProduct(self, other, axis_offset):
445            """
446            """
447          arg0_c=self._arr.copy()          arg0_c=self._arr.copy()
448          sh0=self.getShape()          sh0=self.getShape()
449          if isinstance(other, Symbol):          if isinstance(other, Symbol):
# Line 342  class Symbol(object): Line 466  class Symbol(object):
466          return Symbol(out, dim=self.dim)          return Symbol(out, dim=self.dim)
467    
468      def tensorTransposedProduct(self, other, axis_offset):      def tensorTransposedProduct(self, other, axis_offset):
469            """
470            """
471          arg0_c=self._arr.copy()          arg0_c=self._arr.copy()
472          sh0=self.getShape()          sh0=self.getShape()
473          if isinstance(other, Symbol):          if isinstance(other, Symbol):
# Line 366  class Symbol(object): Line 492  class Symbol(object):
492          return Symbol(out, dim=self.dim)          return Symbol(out, dim=self.dim)
493    
494      def trace(self, axis_offset):      def trace(self, axis_offset):
495            """
496            """
497          sh=self.getShape()          sh=self.getShape()
498          s1=1          s1=1
499          for i in range(axis_offset): s1*=sh[i]          for i in range(axis_offset): s1*=sh[i]
# Line 381  class Symbol(object): Line 509  class Symbol(object):
509          return Symbol(out, dim=self.dim)          return Symbol(out, dim=self.dim)
510    
511      def transpose(self, axis_offset):      def transpose(self, axis_offset):
512            """
513            """
514          if axis_offset is None:          if axis_offset is None:
515              axis_offset=int(self._arr.ndim/2)              axis_offset=int(self._arr.ndim/2)
516          axes=range(axis_offset, self._arr.ndim)+range(0,axis_offset)          axes=range(axis_offset, self._arr.ndim)+range(0,axis_offset)
517          return Symbol(numpy.transpose(self._arr, axes=axes), dim=self.dim)          return Symbol(numpy.transpose(self._arr, axes=axes), dim=self.dim)
518    
519      def applyfunc(self, f, on_type=None):      def applyfunc(self, f, on_type=None):
520            """
521            """
522          assert callable(f)          assert callable(f)
523          if self._arr.ndim==0:          if self._arr.ndim==0:
524              if on_type is None or isinstance(self._arr.item(),on_type):              if on_type is None or isinstance(self._arr.item(),on_type):
# Line 408  class Symbol(object): Line 540  class Symbol(object):
540          return out          return out
541    
542      def simplify(self):      def simplify(self):
543            """
544            """
545          return self.applyfunc(sympy.simplify, sympy.Basic)          return self.applyfunc(sympy.simplify, sympy.Basic)
546    
547      def _sympy_(self):      def _sympy_(self):
548            """
549            """
550          return self.applyfunc(sympy.sympify)          return self.applyfunc(sympy.sympify)
551    
552      def _ensureShapeCompatible(self, other):      def _ensureShapeCompatible(self, other):
# Line 423  class Symbol(object): Line 559  class Symbol(object):
559              sh1=other.getShape()              sh1=other.getShape()
560          elif isinstance(other, numpy.ndarray):          elif isinstance(other, numpy.ndarray):
561              sh1=other.shape              sh1=other.shape
562            elif isinstance(other, list):
563                sh1=numpy.array(other).shape
564          elif isinstance(other,int) or isinstance(other,float) or isinstance(other,sympy.Basic):          elif isinstance(other,int) or isinstance(other,float) or isinstance(other,sympy.Basic):
565              sh1=()              sh1=()
566          else:          else:
# Line 482  class Symbol(object): Line 620  class Symbol(object):
620              else:              else:
621                  yield s                  yield s
622    
623        def __array__(self):
624            return self._arr
625    
626      # unary/binary operations follow      # unary/binary operations follow
627    
628      def __pos__(self):      def __pos__(self):
# Line 581  def symbols(*names, **kwargs): Line 722  def symbols(*names, **kwargs):
722      return res      return res
723    
724  def combineData(array, shape):  def combineData(array, shape):
725        """
726        """
727    
728      # array could just be a single value      # array could just be a single value
729      if not hasattr(array,'__len__') and shape==():      if not hasattr(array,'__len__') and shape==():
730          return array          return array
# Line 622  class SymFunction(Symbol): Line 766  class SymFunction(Symbol):
766      """      """
767      def __init__(self, *args, **kwargs):      def __init__(self, *args, **kwargs):
768          """          """
769          Initializes a new symbolic function object.          Initialises a new symbolic function object.
770          """          """
771          super(SymFunction, self).__init__(self.__class__.__name__, **kwargs)          super(SymFunction, self).__init__(self.__class__.__name__, **kwargs)
772          self.args=args          self.args=args

Legend:
Removed from v.3532  
changed lines
  Added in v.3536

  ViewVC Help
Powered by ViewVC 1.1.26