/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Diff of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3532 by caltinay, Mon Jun 20 04:14:42 2011 UTC revision 3533 by caltinay, Wed Jun 22 04:40:08 2011 UTC
# Line 37  __author__="Cihan Altinay" Line 37  __author__="Cihan Altinay"
37    
38  class Symbol(object):  class Symbol(object):
39      """      """
40        `Symbol` objects are placeholders for a single mathematic symbol, such as
41        'x', or for arbitrarily complex mathematic expressions such as
42        'c*x**4+alpha*exp(x)-2*sin(beta*x)', where 'alpha', 'beta', 'c', and 'x'
43        are also `Symbol`s (the symbolic 'atoms' of the expression).
44    
45        With the help of the 'Evaluator' class these symbols and expressions can
46        be resolved by substituting numeric values and/or escript `Data` objects
47        for the atoms. To facilitate the use of `Data` objects a `Symbol` has a
48        shape (and thus a rank) as well as a dimension (see constructor).
49        `Symbol`s are useful to perform mathematic simplifications, compute
50        derivatives and as coefficients for nonlinear PDEs which can be solved by
51        the `NonlinearPDE` class.
52      """      """
53    
54      def __init__(self, *args, **kwargs):      def __init__(self, *args, **kwargs):
55          """          """
56          Initializes a new Symbol object.          Initialises a new `Symbol` object in one of three ways::
57    
58                u=Symbol('u')
59    
60            returns a scalar symbol by the name 'u'.
61    
62                a=Symbol('alpha', (3,2))
63    
64            returns a rank 2 symbol with the shape (4,3), whose elements are
65            named '[alpha]_i_j' (with i=0..3, j=0..2).
66    
67                a,b,c=symbols('a,b,c')
68                x=Symbol([[a+b,0,0],[0,b-c,0],[0,0,c-a]])
69    
70            returns a rank 2 symbol with the shape (3,3) whose elements are
71            explicitely specified by numeric values and other symbols/expressions
72            within a list or numpy array.
73    
74            The dimensionality of the symbol can be specified through the `dim`
75            keyword. All other keywords are passed to the underlying symbolic
76            library (currently sympy).
77    
78            :param args: initialisation arguments as described above
79            :keyword dim: dimensionality of the new Symbol (default: 2)
80            :type dim: ``int``
81          """          """
82          if 'dim' in kwargs:          if 'dim' in kwargs:
83              self.dim=kwargs.pop('dim')              self.dim=kwargs.pop('dim')
# Line 52  class Symbol(object): Line 88  class Symbol(object):
88              arg=args[0]              arg=args[0]
89              if isinstance(arg, str):              if isinstance(arg, str):
90                  if arg.find('[')>=0 or arg.find(']')>=0:                  if arg.find('[')>=0 or arg.find(']')>=0:
91                      raise TypeError("Name must not contain '[' or ']'")                      raise ValueError("Name must not contain '[' or ']'")
92                  self._arr=numpy.array(sympy.Symbol(arg, **kwargs))                  self._arr=numpy.array(sympy.Symbol(arg, **kwargs))
93              elif hasattr(arg, "__array__"):              elif hasattr(arg, "__array__") or isinstance(arg, list):
94                    if isinstance(arg, list): arg=numpy.array(arg)
95                  arr=arg.__array__()                  arr=arg.__array__()
96                  if len(arr.shape)>4:                  if len(arr.shape)>4:
97                      raise ValueError("Symbol only supports tensors up to order 4")                      raise ValueError("Symbol only supports tensors up to order 4")
98                  self._arr=arr.copy()                  res=numpy.empty(arr.shape, dtype=object)
99              elif isinstance(arg, list) or isinstance(arg, sympy.Basic):                  for idx in numpy.ndindex(arr.shape):
100                        if hasattr(arr[idx], "item"):
101                            res[idx]=arr[idx].item()
102                        else:
103                            res[idx]=arr[idx]
104                    self._arr=res
105                elif isinstance(arg, sympy.Basic):
106                  self._arr=numpy.array(arg)                  self._arr=numpy.array(arg)
107              else:              else:
108                  raise TypeError("Unsupported argument type %s"%str(type(arg)))                  raise TypeError("Unsupported argument type %s"%str(type(arg)))
109          elif len(args)==2:          elif len(args)==2:
110              if not isinstance(args[0], str):              if not isinstance(args[0], str):
111                  raise TypeError("First argument must be a string")                  raise TypeError("First argument must be a string")
             if args[0].find('[')>=0 or args[0].find(']')>=0:  
                 raise TypeError("Name must not contain '[' or ']'")  
112              if not isinstance(args[1], tuple):              if not isinstance(args[1], tuple):
113                  raise TypeError("Second argument must be a tuple")                  raise TypeError("Second argument must be a tuple")
114              name=args[0]              name=args[0]
115              shape=args[1]              shape=args[1]
116                if name.find('[')>=0 or name.find(']')>=0:
117                    raise ValueError("Name must not contain '[' or ']'")
118              if len(shape)>4:              if len(shape)>4:
119                  raise ValueError("Symbol only supports tensors up to order 4")                  raise ValueError("Symbol only supports tensors up to order 4")
120              if len(shape)==0:              if len(shape)==0:
# Line 104  class Symbol(object): Line 147  class Symbol(object):
147          return self._arr[key]          return self._arr[key]
148    
149      def __setitem__(self, key, value):      def __setitem__(self, key, value):
150          if isinstance(value,Symbol):          if isinstance(value, Symbol):
151              if value.getRank()==0:              if value.getRank()==0:
152                  self._arr[key]=value                  self._arr[key]=value.item()
153              elif hasattr(self._arr[key], "shape"):              elif hasattr(self._arr[key], "shape"):
154                  if self._arr[key].shape==value.getShape():                  if self._arr[key].shape==value.getShape():
155                      self._arr[key]=value                      for idx in numpy.ndindex(self._arr[key].shape):
156                            self._arr[key][idx]=value[idx]
157                  else:                  else:
158                      raise ValueError("Wrong shape of value")                      raise ValueError("Wrong shape of value")
159              else:              else:
160                  raise ValueError("Wrong shape of value")                  raise ValueError("Wrong shape of value")
161          elif isinstance(value,sympy.Basic):          elif isinstance(value, sympy.Basic):
162              self._arr[key]=value              self._arr[key]=value
163          elif hasattr(value, "__array__"):          elif hasattr(value, "__array__"):
164              self._arr[key]=map(sympy.sympify,value.flat)              self._arr[key]=map(sympy.sympify,value.flat)
165          else:          else:
166              self._arr[key]=sympy.sympify(value)              self._arr[key]=sympy.sympify(value)
167    
168        def getDim(self):
169            """
170            Returns the spatial dimensionality of this symbol.
171    
172            :return: the symbol's spatial dimensionality
173            :rtype: ``int``
174            """
175            return self.dim
176    
177      def getRank(self):      def getRank(self):
178            """
179            Returns the rank of this symbol.
180    
181            :return: the symbol's rank which is equal to the length of the shape.
182            :rtype: ``int``
183            """
184          return self._arr.ndim          return self._arr.ndim
185    
186      def getShape(self):      def getShape(self):
187            """
188            Returns the shape of this symbol.
189    
190            :return: the symbol's shape
191            :rtype: ``tuple`` of ``int``
192            """
193          return self._arr.shape          return self._arr.shape
194    
195        def item(self, *args):
196            """
197            Returns an element of this symbol.
198            This method behaves like the item() method of numpy.ndarray.
199            If this is a scalar Symbol, no arguments are allowed and the only
200            element in this Symbol is returned.
201            Otherwise, 'args' specifies a flat or nd-index and the element at
202            that index is returned.
203    
204            :param args: index of item to be returned
205            :return: the requested element
206            :rtype: ``sympy.Symbol``, ``int``, or ``float``
207            """
208            return self._arr.item(args)
209    
210      def atoms(self, *types):      def atoms(self, *types):
211            """
212            Returns the atoms that form the current Symbol.
213    
214            By default, only objects that are truly atomic and cannot be divided
215            into smaller pieces are returned: symbols, numbers, and number
216            symbols like I and pi. It is possible to request atoms of any type,
217            however.
218    
219            Note that if this symbol contains components such as [x]_i_j then
220            only their main symbol 'x' is returned.
221    
222            :param types: types to restrict result to
223            :return: list of atoms of specified type
224            :rtype: ``set``
225            """
226          s=set()          s=set()
227          for el in self._arr.flat:          for el in self._arr.flat:
228              if isinstance(el,sympy.Basic):              if isinstance(el,sympy.Basic):
# Line 138  class Symbol(object): Line 233  class Symbol(object):
233                          s.add(sympy.Symbol(n))                          s.add(sympy.Symbol(n))
234                      else:                      else:
235                          s.add(a)                          s.add(a)
236              else:              elif len(types)==0 or type(el) in types:
237                  # TODO: Numbers?                  s.add(el)
                 pass  
238          return s          return s
239    
240      def _sympystr_(self, printer):      def _sympystr_(self, printer):
# Line 203  class Symbol(object): Line 297  class Symbol(object):
297          return result          return result
298    
299      def diff(self, *symbols, **assumptions):      def diff(self, *symbols, **assumptions):
300            """
301            """
302          symbols=Symbol._symbolgen(*symbols)          symbols=Symbol._symbolgen(*symbols)
303          result=Symbol(self._arr, dim=self.dim)          result=Symbol(self._arr, dim=self.dim)
304          for s in symbols:          for s in symbols:
# Line 226  class Symbol(object): Line 322  class Symbol(object):
322          return result          return result
323    
324      def grad(self, where=None):      def grad(self, where=None):
325            """
326            """
327          if isinstance(where, Symbol):          if isinstance(where, Symbol):
328              if where.getRank()>0:              if where.getRank()>0:
329                  raise ValueError("grad: 'where' must be a scalar symbol")                  raise ValueError("grad: 'where' must be a scalar symbol")
# Line 243  class Symbol(object): Line 341  class Symbol(object):
341          return Symbol(out, dim=self.dim)          return Symbol(out, dim=self.dim)
342    
343      def inverse(self):      def inverse(self):
344            """
345            """
346          if not self.getRank()==2:          if not self.getRank()==2:
347              raise ValueError("inverse: Only rank 2 supported")              raise TypeError("inverse: Only rank 2 supported")
348          s=self.getShape()          s=self.getShape()
349          if not s[0] == s[1]:          if not s[0] == s[1]:
350              raise ValueError("inverse: Only square shapes supported")              raise ValueError("inverse: Only square shapes supported")
# Line 295  class Symbol(object): Line 395  class Symbol(object):
395          return Symbol(out, dim=self.dim)          return Symbol(out, dim=self.dim)
396    
397      def swap_axes(self, axis0, axis1):      def swap_axes(self, axis0, axis1):
398            """
399            """
400          return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self.dim)          return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self.dim)
401    
402      def tensorProduct(self, other, axis_offset):      def tensorProduct(self, other, axis_offset):
403            """
404            """
405          arg0_c=self._arr.copy()          arg0_c=self._arr.copy()
406          sh0=self.getShape()          sh0=self.getShape()
407          if isinstance(other, Symbol):          if isinstance(other, Symbol):
# Line 320  class Symbol(object): Line 424  class Symbol(object):
424          return Symbol(out, dim=self.dim)          return Symbol(out, dim=self.dim)
425    
426      def transposedTensorProduct(self, other, axis_offset):      def transposedTensorProduct(self, other, axis_offset):
427            """
428            """
429          arg0_c=self._arr.copy()          arg0_c=self._arr.copy()
430          sh0=self.getShape()          sh0=self.getShape()
431          if isinstance(other, Symbol):          if isinstance(other, Symbol):
# Line 342  class Symbol(object): Line 448  class Symbol(object):
448          return Symbol(out, dim=self.dim)          return Symbol(out, dim=self.dim)
449    
450      def tensorTransposedProduct(self, other, axis_offset):      def tensorTransposedProduct(self, other, axis_offset):
451            """
452            """
453          arg0_c=self._arr.copy()          arg0_c=self._arr.copy()
454          sh0=self.getShape()          sh0=self.getShape()
455          if isinstance(other, Symbol):          if isinstance(other, Symbol):
# Line 366  class Symbol(object): Line 474  class Symbol(object):
474          return Symbol(out, dim=self.dim)          return Symbol(out, dim=self.dim)
475    
476      def trace(self, axis_offset):      def trace(self, axis_offset):
477            """
478            """
479          sh=self.getShape()          sh=self.getShape()
480          s1=1          s1=1
481          for i in range(axis_offset): s1*=sh[i]          for i in range(axis_offset): s1*=sh[i]
# Line 381  class Symbol(object): Line 491  class Symbol(object):
491          return Symbol(out, dim=self.dim)          return Symbol(out, dim=self.dim)
492    
493      def transpose(self, axis_offset):      def transpose(self, axis_offset):
494            """
495            """
496          if axis_offset is None:          if axis_offset is None:
497              axis_offset=int(self._arr.ndim/2)              axis_offset=int(self._arr.ndim/2)
498          axes=range(axis_offset, self._arr.ndim)+range(0,axis_offset)          axes=range(axis_offset, self._arr.ndim)+range(0,axis_offset)
499          return Symbol(numpy.transpose(self._arr, axes=axes), dim=self.dim)          return Symbol(numpy.transpose(self._arr, axes=axes), dim=self.dim)
500    
501      def applyfunc(self, f, on_type=None):      def applyfunc(self, f, on_type=None):
502            """
503            """
504          assert callable(f)          assert callable(f)
505          if self._arr.ndim==0:          if self._arr.ndim==0:
506              if on_type is None or isinstance(self._arr.item(),on_type):              if on_type is None or isinstance(self._arr.item(),on_type):
# Line 408  class Symbol(object): Line 522  class Symbol(object):
522          return out          return out
523    
524      def simplify(self):      def simplify(self):
525            """
526            """
527          return self.applyfunc(sympy.simplify, sympy.Basic)          return self.applyfunc(sympy.simplify, sympy.Basic)
528    
529      def _sympy_(self):      def _sympy_(self):
530            """
531            """
532          return self.applyfunc(sympy.sympify)          return self.applyfunc(sympy.sympify)
533    
534      def _ensureShapeCompatible(self, other):      def _ensureShapeCompatible(self, other):
# Line 581  def symbols(*names, **kwargs): Line 699  def symbols(*names, **kwargs):
699      return res      return res
700    
701  def combineData(array, shape):  def combineData(array, shape):
702        """
703        """
704    
705      # array could just be a single value      # array could just be a single value
706      if not hasattr(array,'__len__') and shape==():      if not hasattr(array,'__len__') and shape==():
707          return array          return array
# Line 622  class SymFunction(Symbol): Line 743  class SymFunction(Symbol):
743      """      """
744      def __init__(self, *args, **kwargs):      def __init__(self, *args, **kwargs):
745          """          """
746          Initializes a new symbolic function object.          Initialises a new symbolic function object.
747          """          """
748          super(SymFunction, self).__init__(self.__class__.__name__, **kwargs)          super(SymFunction, self).__init__(self.__class__.__name__, **kwargs)
749          self.args=args          self.args=args

Legend:
Removed from v.3532  
changed lines
  Added in v.3533

  ViewVC Help
Powered by ViewVC 1.1.26