/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Annotation of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory | Revision Log


Revision 3631 - (hide annotations)
Thu Oct 20 03:12:46 2011 UTC (8 years ago) by caltinay
Original Path: branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py
File MIME type: text/x-python
File size: 28318 byte(s)
Fix compatibility with latest sympy version.

1 caltinay 3507 # -*- coding: utf-8 -*-
2    
3     ########################################################
4     #
5     # Copyright (c) 2003-2010 by University of Queensland
6     # Earth Systems Science Computational Center (ESSCC)
7     # http://www.uq.edu.au/esscc
8     #
9     # Primary Business: Queensland, Australia
10     # Licensed under the Open Software License version 3.0
11     # http://www.opensource.org/licenses/osl-3.0.php
12     #
13     ########################################################
14    
15     __copyright__="""Copyright (c) 2003-2010 by University of Queensland
16     Earth Systems Science Computational Center (ESSCC)
17     http://www.uq.edu.au/esscc
18     Primary Business: Queensland, Australia"""
19     __license__="""Licensed under the Open Software License version 3.0
20     http://www.opensource.org/licenses/osl-3.0.php"""
21     __url__="https://launchpad.net/escript-finley"
22    
23     """
24     :var __author__: name of author
25     :var __copyright__: copyrights
26     :var __license__: licence agreement
27     :var __url__: url entry point on documentation
28     :var __version__: version
29     :var __date__: date of the version
30     """
31    
import numbers

import numpy
import sympy
34    
35     __author__="Cihan Altinay"
36    
37    
class Symbol(object):
    """
    `Symbol` objects are placeholders for a single mathematical symbol, such
    as 'x', or for arbitrarily complex mathematical expressions such as
    'c*x**4+alpha*exp(x)-2*sin(beta*x)', where 'alpha', 'beta', 'c', and 'x'
    are also `Symbol`s (the symbolic 'atoms' of the expression).

    With the help of the 'Evaluator' class these symbols and expressions can
    be resolved by substituting numeric values and/or escript `Data` objects
    for the atoms. To facilitate the use of `Data` objects a `Symbol` has a
    shape (and thus a rank) as well as a dimension (see constructor).
    `Symbol`s are useful to perform mathematical simplifications, compute
    derivatives and as coefficients for nonlinear PDEs which can be solved by
    the `NonlinearPDE` class.

    The components are kept in a numpy array (``self._arr``) and all
    operations work component by component on that array.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialises a new `Symbol` object in one of three ways::

            u=Symbol('u')

        returns a scalar symbol by the name 'u'.

            a=Symbol('alpha', (4,3))

        returns a rank 2 symbol with the shape (4,3), whose elements are
        named '[alpha]_i_j' (with i=0..3, j=0..2).

            a,b,c=symbols('a,b,c')
            x=Symbol([[a+b,0,0],[0,b-c,0],[0,0,c-a]])

        returns a rank 2 symbol with the shape (3,3) whose elements are
        explicitly specified by numeric values and other symbols/expressions
        within a list or numpy array.

        A single sympy expression or a plain number is also accepted and
        results in a scalar (rank 0) symbol.

        The dimensionality of the symbol can be specified through the `dim`
        keyword. All other keywords are passed to the underlying symbolic
        library (currently sympy).

        :param args: initialisation arguments as described above
        :keyword dim: dimensionality of the new Symbol (default: 2)
        :type dim: ``int``
        :raise ValueError: if a name contains '[' or ']', or if more than
                           4 axes are requested
        :raise TypeError: if the type or number of arguments is unsupported
        """
        self.dim = kwargs.pop('dim', 2)

        if len(args) == 1:
            arg = args[0]
            if isinstance(arg, str):
                Symbol._checkName(arg)
                self._arr = numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__") or isinstance(arg, list):
                self._arr = Symbol._toObjectArray(arg)
            elif isinstance(arg, numbers.Number):
                # plain numbers, e.g. the result of arithmetic on scalar
                # symbols holding numeric values
                self._arr = numpy.array(arg, dtype=object)
            elif isinstance(arg, sympy.Basic):
                self._arr = numpy.array(arg)
            else:
                raise TypeError("Unsupported argument type %s"%str(type(arg)))
        elif len(args) == 2:
            name, shape = args
            if not isinstance(name, str):
                raise TypeError("First argument must be a string")
            if not isinstance(shape, tuple):
                raise TypeError("Second argument must be a tuple")
            Symbol._checkName(name)
            if len(shape) > 4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape) == 0:
                self._arr = numpy.array(sympy.Symbol(name, **kwargs))
            else:
                # the argument order of sympy.symarray differs between sympy
                # versions, so try both
                try:
                    self._arr = sympy.symarray(shape, '['+name+']')
                except TypeError:
                    self._arr = sympy.symarray('['+name+']', shape)
        else:
            raise TypeError("Unsupported number of arguments")

        if self._arr.ndim == 0:
            self.name = str(self._arr.item())
        else:
            self.name = str(self._arr.tolist())

    @staticmethod
    def _checkName(name):
        """
        Raises ValueError if `name` contains '[' or ']' (these characters
        are reserved for component names such as '[x]_0_1').
        """
        if '[' in name or ']' in name:
            raise ValueError("Name must not contain '[' or ']'")

    @staticmethod
    def _toObjectArray(arg):
        """
        Copies a list or array-like `arg` into a new numpy object array,
        unwrapping elements that provide ``item()`` (numpy scalars, scalar
        Symbols) into plain objects.

        :raise ValueError: if `arg` has more than 4 axes
        """
        if isinstance(arg, list):
            arg = numpy.array(arg)
        arr = arg.__array__()
        if arr.ndim > 4:
            raise ValueError("Symbol only supports tensors up to order 4")
        res = numpy.empty(arr.shape, dtype=object)
        for idx in numpy.ndindex(arr.shape):
            el = arr[idx]
            res[idx] = el.item() if hasattr(el, "item") else el
        return res

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        """
        Two symbols are equal if they have the same type and shape and all
        their components compare equal.
        """
        if type(self) is not type(other):
            return False
        if self.getRank() != other.getRank():
            return False
        if self.getShape() != other.getShape():
            return False
        return bool(numpy.all(self._arr == other._arr))

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__
        return not self.__eq__(other)

    def __getitem__(self, key):
        """
        Returns the component(s) at `key` from the underlying numpy array
        (a single element or a numpy sub-array, not a `Symbol`).
        """
        return self._arr[key]

    def __setitem__(self, key, value):
        """
        Sets the component(s) at `key`.

        :param value: a `Symbol` (rank 0, or with the shape of the addressed
                      sub-array), a sympy expression, an array-like whose
                      elements are sympified, or any value accepted by
                      sympy.sympify
        :raise ValueError: if a non-scalar `Symbol` value has the wrong shape
        """
        if isinstance(value, Symbol):
            if value.getRank() == 0:
                self._arr[key] = value.item()
            elif hasattr(self._arr[key], "shape"):
                target = self._arr[key]
                if target.shape == value.getShape():
                    for idx in numpy.ndindex(target.shape):
                        target[idx] = value[idx]
                else:
                    raise ValueError("Wrong shape of value")
            else:
                raise ValueError("Wrong shape of value")
        elif isinstance(value, sympy.Basic):
            self._arr[key] = value
        elif hasattr(value, "__array__"):
            src = value.__array__()
            if src.ndim == 0:
                self._arr[key] = sympy.sympify(src.item())
            else:
                # keep the shape of `value` so it fits an nd sub-array
                conv = numpy.empty(src.shape, dtype=object)
                for idx in numpy.ndindex(src.shape):
                    conv[idx] = sympy.sympify(src[idx])
                self._arr[key] = conv
        else:
            self._arr[key] = sympy.sympify(value)

    def getDim(self):
        """
        Returns the spatial dimensionality of this symbol.

        :return: the symbol's spatial dimensionality
        :rtype: ``int``
        """
        return self.dim

    def getRank(self):
        """
        Returns the rank of this symbol.

        :return: the symbol's rank which is equal to the length of the shape.
        :rtype: ``int``
        """
        return self._arr.ndim

    def getShape(self):
        """
        Returns the shape of this symbol.

        :return: the symbol's shape
        :rtype: ``tuple`` of ``int``
        """
        return self._arr.shape

    def item(self, *args):
        """
        Returns an element of this symbol.
        This method behaves like the item() method of numpy.ndarray.
        If this is a scalar Symbol, no arguments are allowed and the only
        element in this Symbol is returned.
        Otherwise, 'args' specifies a flat or nd-index and the element at
        that index is returned.

        :param args: index of item to be returned
        :return: the requested element
        :rtype: ``sympy.Symbol``, ``int``, or ``float``
        """
        return self._arr.item(*args)

    def atoms(self, *types):
        """
        Returns the atoms that form the current Symbol.

        By default, only objects that are truly atomic and cannot be divided
        into smaller pieces are returned: symbols, numbers, and number
        symbols like I and pi. It is possible to request atoms of any type,
        however.

        Note that if this symbol contains components such as [x]_i_j then
        only their main symbol 'x' is returned.

        :param types: types to restrict result to
        :return: the atoms of the specified type(s)
        :rtype: ``set``
        """
        result = set()
        for el in self._arr.flat:
            if isinstance(el, sympy.Basic):
                for a in el.atoms(*types):
                    if a.is_Symbol:
                        n, c = Symbol._symComp(a)
                        result.add(sympy.Symbol(n))
                    else:
                        result.add(a)
            elif len(types) == 0 or type(el) in types:
                result.add(el)
        return result

    def _sympystr_(self, printer):
        # NOTE(review): sympy's StrPrinter looks up `_sympystr` (without the
        # trailing underscore); confirm whether this hook is ever reached.
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string representation of this symbol suitable for
        lambdify-style evaluation. Component names like '[x]_0_1' are
        rewritten to indexing syntax 'x[0,1]'. Non-scalar symbols are
        wrapped in a 'combineData(<nested list>,<shape>)' call.

        :rtype: ``str``
        """
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr = numpy.empty(self.getShape(), dtype=object)
        for idx, el in numpy.ndenumerate(self._arr):
            atoms = el.atoms(sympy.Symbol) if isinstance(el, sympy.Basic) else []
            # map names like [x]_0_0 to x[0,0]
            symdict = {}
            for a in atoms:
                n, c = Symbol._symComp(a)
                if len(c) > 0:
                    symdict[a.name] = n+'['+','.join(str(i) for i in c)+']'
                else:
                    symdict[a.name] = n
            s = lambdarepr(el)
            # replace longer names first so that e.g. '[x]_1' cannot clobber
            # the prefix of '[x]_10'
            for key in sorted(symdict, key=len, reverse=True):
                s = s.replace(key, symdict[key])
            temp_arr[idx] = s
        if self.getRank() == 0:
            return temp_arr.item()
        return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))

    def coeff(self, x, expand=True):
        """
        Returns the coefficient of the term "x" or 0 if there is no "x".

        If "x" is a scalar symbol then "x" is searched in all components of
        this symbol. Otherwise the shapes must match and the coefficients are
        checked component by component (a scalar symbol is compared against
        every component of "x"). Numeric components have coefficient 0.

        Example::

            x=Symbol('x', (2,2))
            y=3*x
            print(y.coeff(x))
            print(y.coeff(x[1,1]))

        will print::

            [[3 3]
             [3 3]]

            [[0 0]
             [0 3]]

        :param x: the term whose coefficients are to be found
        :type x: ``Symbol``, ``numpy.ndarray``, `list`
        :param expand: passed on to sympy's coeff()
        :return: the coefficient(s) of the term
        :rtype: ``Symbol``
        :raise TypeError: if the shapes are incompatible
        """
        self._ensureShapeCompatible(x)
        y = x.__array__() if hasattr(x, '__array__') else numpy.array(x)

        def coeff_of(item, term):
            if not isinstance(item, sympy.Basic) and isinstance(item, numbers.Number):
                return 0
            res = item.coeff(term, expand)
            # coeff() may return None if the term does not occur
            return 0 if res is None else res

        if y.ndim > 0:
            scalar_self = self.getRank() == 0
            shape = y.shape if scalar_self else self.getShape()
            result = numpy.zeros(shape, dtype=object)
            for idx in numpy.ndindex(y.shape):
                if y[idx] != 0:
                    item = self._arr.item() if scalar_self else self._arr[idx]
                    result[idx] = coeff_of(item, y[idx])
        elif y.item() == 0:
            result = numpy.zeros(self.getShape(), dtype=object)
        else:
            term = y.item()
            result = numpy.empty(self.getShape(), dtype=object)
            for idx in numpy.ndindex(self.getShape()):
                result[idx] = coeff_of(self._arr[idx], term)
        return Symbol(result, dim=self.dim)

    def subs(self, old, new):
        """
        Substitutes `new` for `old` in every component of this symbol.

        If `old` is a non-scalar Symbol then `new` must either have the same
        shape (components are substituted pairwise) or be a scalar (which
        replaces every component of `old`). Otherwise `old` and `new` must be
        scalars. Components that are not sympy expressions are left
        unchanged.

        :param old: the expression to be replaced
        :param new: the replacement
        :return: a new symbol with the substitution applied
        :rtype: ``Symbol``
        :raise TypeError: if the shapes or ranks of `old` and `new` are
                          incompatible
        """
        def subs_item(item, o, n):
            if isinstance(item, sympy.Basic):
                return item.subs(o, n)
            return item

        if isinstance(old, Symbol) and old.getRank() > 0:
            old._ensureShapeCompatible(new)
            new = new.__array__() if hasattr(new, '__array__') else numpy.array(new)
            if new.ndim > 0:
                pairs = [(old[i], new[i]) for i in numpy.ndindex(new.shape)]
            else:  # substitute scalar for non-scalar
                repl = new.item()
                pairs = [(old[i], repl) for i in numpy.ndindex(old.getShape())]
            result = numpy.empty(self.getShape(), dtype=object)
            for idx in numpy.ndindex(self.getShape()):
                el = self._arr[idx]
                for o, n in pairs:
                    el = subs_item(el, o, n)
                result[idx] = el
            return Symbol(result, dim=self.dim)

        # scalar substitution
        if isinstance(new, Symbol):
            if new.getRank() > 0:
                raise TypeError("Cannot substitute, incompatible ranks.")
            new = new.item()
        if isinstance(old, Symbol):
            old = old.item()
        return self.applyfunc(lambda item: item.subs(old, new), sympy.Basic)

    def diff(self, *symbols, **assumptions):
        """
        Returns the derivative of this symbol with respect to `symbols`.

        The arguments follow sympy's convention: a symbol followed by an
        integer n means differentiating n times (see `_symbolgen`).
        Differentiating with respect to a rank 1 `Symbol` of length n
        appends an axis of length n holding the derivatives with respect to
        each of its components. Numeric components differentiate to 0.

        :param symbols: the variables to differentiate by
        :param assumptions: passed on to sympy's diff()
        :rtype: ``Symbol``
        :raise ValueError: if differentiating by a Symbol of rank > 1
        """
        def diff_item(item, var):
            if not isinstance(item, sympy.Basic) and isinstance(item, numbers.Number):
                return 0
            return item.diff(var, **assumptions)

        result = Symbol(self._arr, dim=self.dim)
        for s in Symbol._symbolgen(*symbols):
            if isinstance(s, Symbol):
                if s.getRank() == 0:
                    var = s._arr.item()
                    result = result.applyfunc(lambda item: diff_item(item, var))
                elif s.getRank() == 1:
                    length = s.getShape()[0]
                    # use the shape of the intermediate result, which grows
                    # with every rank 1 differentiation
                    shape = result.getShape()
                    out = numpy.empty(shape+(length,), dtype=object)
                    for idx in numpy.ndindex(shape):
                        for d in range(length):
                            out[idx+(d,)] = diff_item(result._arr[idx], s[d])
                    result = Symbol(out, dim=self.dim)
                else:
                    raise ValueError("diff: Only rank 0 and 1 supported")
            else:
                result = result.applyfunc(lambda item: diff_item(item, s))
        return result

    def grad(self, where=None):
        """
        Returns the gradient of this symbol.

        `grad_n` from the `functions` module is applied to every component
        for each direction d in range(self.dim); the result has the shape
        ``self.getShape()+(self.dim,)``.

        :param where: optional scalar Symbol (or sympy object) forwarded to
                      grad_n (presumably the location/function space of the
                      gradient - TODO confirm)
        :rtype: ``Symbol``
        :raise ValueError: if `where` is a non-scalar Symbol
        """
        if isinstance(where, Symbol):
            if where.getRank() > 0:
                raise ValueError("grad: 'where' must be a scalar symbol")
            where = where._arr.item()

        from functions import grad_n
        out = self._arr.copy().reshape(self.getShape()+(1,)).repeat(self.dim, axis=self.getRank())
        for d in range(self.dim):
            for idx in numpy.ndindex(self.getShape()):
                index = idx+(d,)
                if where is None:
                    out[index] = grad_n(out[index], d)
                else:
                    out[index] = grad_n(out[index], d, where)
        return Symbol(out, dim=self.dim)

    def inverse(self):
        """
        Returns the inverse of this symbol interpreted as a square matrix.

        Only rank 2 symbols of shape (1,1), (2,2) or (3,3) are supported.
        The inverse is computed explicitly from cofactors and determinant.

        :rtype: ``Symbol``
        :raise TypeError: if the rank is not 2 or the size exceeds 3
        :raise ValueError: if the shape is not square
        :raise ZeroDivisionError: if sympy determines the determinant is zero
        """
        if not self.getRank() == 2:
            raise TypeError("inverse: Only rank 2 supported")
        s = self.getShape()
        if not s[0] == s[1]:
            raise ValueError("inverse: Only square shapes supported")

        def is_zero(value):
            # sympify so that plain numbers can be tested too
            return sympy.sympify(value).is_zero

        out = numpy.zeros(s, object)
        arr = self._arr
        if s[0] == 1:
            if is_zero(arr[0,0]):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            out[0,0] = 1./arr[0,0]
        elif s[0] == 2:
            A11 = arr[0,0]
            A12 = arr[0,1]
            A21 = arr[1,0]
            A22 = arr[1,1]
            D = A11*A22-A12*A21
            if is_zero(D):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D = 1./D
            out[0,0] = A22*D
            out[1,0] = -A21*D
            out[0,1] = -A12*D
            out[1,1] = A11*D
        elif s[0] == 3:
            A11 = arr[0,0]
            A21 = arr[1,0]
            A31 = arr[2,0]
            A12 = arr[0,1]
            A22 = arr[1,1]
            A32 = arr[2,1]
            A13 = arr[0,2]
            A23 = arr[1,2]
            A33 = arr[2,2]
            D = A11*(A22*A33-A23*A32)+A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
            if is_zero(D):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D = 1./D
            out[0,0] = (A22*A33-A23*A32)*D
            out[1,0] = (A31*A23-A21*A33)*D
            out[2,0] = (A21*A32-A31*A22)*D
            out[0,1] = (A13*A32-A12*A33)*D
            out[1,1] = (A11*A33-A31*A13)*D
            out[2,1] = (A12*A31-A11*A32)*D
            out[0,2] = (A12*A23-A13*A22)*D
            out[1,2] = (A13*A21-A11*A23)*D
            out[2,2] = (A11*A22-A12*A21)*D
        else:
            raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
        return Symbol(out, dim=self.dim)

    def swap_axes(self, axis0, axis1):
        """
        Returns a new symbol with the axes `axis0` and `axis1` interchanged.

        :rtype: ``Symbol``
        """
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self.dim)

    @staticmethod
    def _asArray(other):
        """Returns the components of `other` (a Symbol or array-like) as a numpy array."""
        if isinstance(other, Symbol):
            return other._arr
        return numpy.array(other)

    @staticmethod
    def _size(shape):
        """Returns the number of elements described by `shape`."""
        n = 1
        for i in shape:
            n *= i
        return n

    def tensorProduct(self, other, axis_offset):
        """
        Returns the tensor product of this symbol and `other`, contracting
        the last `axis_offset` axes of self with the first `axis_offset`
        axes of `other`. The result has shape
        ``self.getShape()[:rank-axis_offset]+other.shape[axis_offset:]``.

        :param other: a Symbol or array-like
        :rtype: ``Symbol``
        :raise ValueError: if the contracted axes do not match in size
        """
        arg0 = self._arr
        arg1 = Symbol._asArray(other)
        lead = arg0.shape[:arg0.ndim-axis_offset]
        trail = arg1.shape[axis_offset:]
        d0 = Symbol._size(lead)
        d1 = Symbol._size(trail)
        d01 = Symbol._size(arg1.shape[:axis_offset])
        a0 = arg0.reshape((d0, d01))
        a1 = arg1.reshape((d01, d1))
        out = numpy.zeros((d0, d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1] = numpy.sum(a0[i0,:]*a1[:,i1])
        return Symbol(out.reshape(lead+trail), dim=self.dim)

    def transposedTensorProduct(self, other, axis_offset):
        """
        Returns the tensor product of the transpose of this symbol and
        `other`, contracting the first `axis_offset` axes of both. The
        result has shape
        ``self.getShape()[axis_offset:]+other.shape[axis_offset:]``.

        :param other: a Symbol or array-like
        :rtype: ``Symbol``
        :raise ValueError: if the contracted axes do not match in size
        """
        arg0 = self._arr
        arg1 = Symbol._asArray(other)
        trail0 = arg0.shape[axis_offset:]
        trail1 = arg1.shape[axis_offset:]
        d0 = Symbol._size(trail0)
        d1 = Symbol._size(trail1)
        d01 = Symbol._size(arg1.shape[:axis_offset])
        a0 = arg0.reshape((d01, d0))
        a1 = arg1.reshape((d01, d1))
        out = numpy.zeros((d0, d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1] = numpy.sum(a0[:,i0]*a1[:,i1])
        return Symbol(out.reshape(trail0+trail1), dim=self.dim)

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Returns the tensor product of this symbol and the transpose of
        `other`, contracting the last `axis_offset` axes of both. The result
        has shape ``self.getShape()[:r0-axis_offset]+other.shape[:r1-axis_offset]``.

        :param other: a Symbol or array-like
        :rtype: ``Symbol``
        :raise ValueError: if the contracted axes do not match in size
        """
        arg0 = self._arr
        arg1 = Symbol._asArray(other)
        r1 = arg1.ndim
        lead0 = arg0.shape[:arg0.ndim-axis_offset]
        lead1 = arg1.shape[:r1-axis_offset]
        d0 = Symbol._size(lead0)
        d1 = Symbol._size(lead1)
        d01 = Symbol._size(arg1.shape[r1-axis_offset:])
        a0 = arg0.reshape((d0, d01))
        a1 = arg1.reshape((d1, d01))
        out = numpy.zeros((d0, d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1] = numpy.sum(a0[i0,:]*a1[i1,:])
        return Symbol(out.reshape(lead0+lead1), dim=self.dim)

    def trace(self, axis_offset):
        """
        Returns the trace over the two consecutive axes `axis_offset` and
        `axis_offset+1` (which must have equal length). These two axes are
        removed from the result.

        :rtype: ``Symbol``
        """
        sh = self.getShape()
        s1 = Symbol._size(sh[:axis_offset])
        s2 = Symbol._size(sh[axis_offset+2:])
        n = sh[axis_offset]
        arr_r = numpy.reshape(self._arr, (s1, n, n, s2))
        out = numpy.zeros([s1, s2], object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(n):
                    out[i1,i2] += arr_r[i1,j,j,i2]
        return Symbol(out.reshape(sh[:axis_offset]+sh[axis_offset+2:]), dim=self.dim)

    def transpose(self, axis_offset=None):
        """
        Returns the transpose of this symbol: the axes from `axis_offset`
        onwards are moved in front of the first `axis_offset` axes.

        :param axis_offset: defaults to half the rank (rounded down)
        :rtype: ``Symbol``
        """
        rank = self._arr.ndim
        if axis_offset is None:
            axis_offset = rank//2
        axes = list(range(axis_offset, rank))+list(range(0, axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes), dim=self.dim)

    def applyfunc(self, f, on_type=None):
        """
        Applies `f` to every component and returns the results as a new
        symbol.

        :param f: the function to apply
        :param on_type: if given, `f` is only applied to components that are
                        instances of this type; others are copied unchanged
        :return: the new symbol, or None if this symbol is a scalar and `f`
                 returned None
        :rtype: ``Symbol`` or ``None``
        :raise TypeError: if `f` is not callable
        """
        if not callable(f):
            raise TypeError("applyfunc: argument f must be callable")

        def apply(el):
            if on_type is None or isinstance(el, on_type):
                return f(el)
            return el

        if self._arr.ndim == 0:
            el = apply(self._arr.item())
            if el is None:
                return None
            return Symbol(el, dim=self.dim)
        out = numpy.empty(self.getShape(), dtype=object)
        for idx in numpy.ndindex(self.getShape()):
            out[idx] = apply(self._arr[idx])
        return Symbol(out, dim=self.dim)

    def expand(self):
        """
        Returns a new symbol with sympy.expand applied to every sympy
        component; other components are copied unchanged.
        """
        return self.applyfunc(sympy.expand, sympy.Basic)

    def simplify(self):
        """
        Returns a new symbol with sympy.simplify applied to every sympy
        component; other components are copied unchanged.
        """
        return self.applyfunc(sympy.simplify, sympy.Basic)

    def _sympy_(self):
        """
        sympify hook: returns a copy of this symbol with every component
        passed through sympy.sympify.
        """
        return self.applyfunc(sympy.sympify)

    def _ensureShapeCompatible(self, other):
        """
        Checks for compatible shapes for binary operations: the shapes must
        be equal or one of the operands must be a scalar.
        Raises TypeError if not compatible.
        """
        sh0 = self.getShape()
        if isinstance(other, Symbol):
            sh1 = other.getShape()
        elif isinstance(other, numpy.ndarray):
            sh1 = other.shape
        elif isinstance(other, list):
            sh1 = numpy.array(other).shape
        elif isinstance(other, numbers.Number) or isinstance(other, sympy.Basic):
            sh1 = ()
        else:
            raise TypeError("Unsupported argument type '%s' for operation"%other.__class__.__name__)
        if not sh0 == sh1 and not sh0 == () and not sh1 == ():
            raise TypeError("Incompatible shapes for operation")

    @staticmethod
    def _symComp(sym):
        """
        Splits a component name such as '[x]_1_2' into its base name and
        indices, i.e. ('x', (1, 2)). Names that do not follow this pattern
        are returned unchanged together with an empty tuple.
        """
        n = sym.name
        parts = n.split('[')
        if len(parts) != 2:
            return n, ()
        parts = parts[1].split(']')
        if len(parts) != 2:
            return n, ()
        return parts[0], tuple(int(i) for i in parts[1].split('_')[1:])

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat
        items = [s if isinstance(s, Symbol) else sympy.sympify(s) for s in symbols]
        for i, s in enumerate(items):
            if isinstance(s, sympy.Integer):
                continue
            # look ahead by position so repeated symbols are handled as well
            nxt = items[i+1] if i+1 < len(items) else None
            if isinstance(s, (Symbol, sympy.Symbol)) and isinstance(nxt, sympy.Integer):
                # handle cases like (x, 3) -> x, x, x
                for copy_s in repeat(s, int(nxt)):
                    yield copy_s
            else:
                yield s

    def __array__(self, dtype=None, copy=None):
        """
        numpy array protocol: returns the component array (the internal
        array itself unless a dtype conversion or copy is requested).
        """
        arr = self._arr if dtype is None else self._arr.astype(dtype)
        if copy and arr is self._arr:
            arr = arr.copy()
        return arr

    # unary/binary operations follow

    def _operand(self, other):
        """
        Checks that `other` can be combined with this symbol and returns the
        value to use in the numpy operation (the component array for
        Symbols, `other` itself otherwise).
        """
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return other._arr
        return other

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr, dim=self.dim)

    def __abs__(self):
        return Symbol(abs(self._arr), dim=self.dim)

    def __add__(self, other):
        return Symbol(self._arr+self._operand(other), dim=self.dim)

    def __radd__(self, other):
        return Symbol(self._operand(other)+self._arr, dim=self.dim)

    def __sub__(self, other):
        return Symbol(self._arr-self._operand(other), dim=self.dim)

    def __rsub__(self, other):
        return Symbol(self._operand(other)-self._arr, dim=self.dim)

    def __mul__(self, other):
        return Symbol(self._arr*self._operand(other), dim=self.dim)

    def __rmul__(self, other):
        return Symbol(self._operand(other)*self._arr, dim=self.dim)

    def __div__(self, other):
        return Symbol(self._arr/self._operand(other), dim=self.dim)

    def __rdiv__(self, other):
        return Symbol(self._operand(other)/self._arr, dim=self.dim)

    # true division (Python 3, or `from __future__ import division`)
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(self, other):
        return Symbol(self._arr**self._operand(other), dim=self.dim)

    def __rpow__(self, other):
        return Symbol(self._operand(other)**self._arr, dim=self.dim)
736    
def symbols(*names, **kwargs):
    """
    Emulates the behaviour of sympy.symbols.

    Creates one `Symbol` per name. The names are given either as a single
    string separated by commas and/or whitespace (e.g. ``'a,b c'``) or as a
    list or tuple of names. Empty names are skipped.

    :param names: only the first positional argument is used
    :keyword shape: shape of every created symbol (default: ``()``)
    :type shape: ``tuple``
    :return: None if no name was given, a single `Symbol` for exactly one
             name, otherwise a tuple of `Symbol`s
    All remaining keywords (e.g. ``dim``) are passed to the `Symbol`
    constructor.
    """
    import re

    shape = kwargs.pop('shape', ())

    s = names[0]
    if not isinstance(s, (list, tuple)):
        s = re.split(r'\s|,', s)
    res = tuple(Symbol(t, shape, **kwargs) for t in s if t)
    if len(res) == 0:  # var('')
        return None
    if len(res) == 1:  # var('x')
        return res[0]
    return res         # var('a b ...')
762    
def combineData(array, shape):
    """
    Assembles the components in `array` into one object of the given
    `shape`. This is the function named in the 'combineData(...)' strings
    produced by `Symbol.lambdarepr` for non-scalar symbols.

    :param array: a single value (when `shape` is ``()``) or a nested
                  sequence of numbers and/or escript `Data` objects
    :param shape: the shape of the result
    :type shape: ``tuple``
    :return: `array` itself if it is a single value and `shape` is ``()``;
             otherwise an escript `Data` object on the function space of one
             of the `Data` components (an arbitrary one if they differ), or a
             numpy float array if there are no `Data` components
    :raise ValueError: if the `Data` components live on different domains
    """
    # a single value needs no combining
    if shape == () and not hasattr(array, '__len__'):
        return array

    from esys.escript import Data
    elements = numpy.array(array)  # allows nd-indexing

    data_items = [elements[idx] for idx in numpy.ndindex(shape)
                  if isinstance(elements[idx], Data)]
    spaces = set(item.getFunctionSpace() for item in data_items)
    domains = set(item.getDomain() for item in data_items)

    if len(domains) > 1:
        reference = domains.pop()
        for other in domains:
            if reference != other:
                raise ValueError("Mixing of domains not supported")

    if spaces:
        # FIXME: interpolate instead of using first?
        result = Data(0., shape, spaces.pop())
    else:
        result = numpy.zeros(shape)

    # filling component by component is much faster than accumulating
    # value*unit_tensor products
    for idx in numpy.ndindex(shape):
        value = elements[idx]
        if getattr(value, "ndim", None) == 0:
            result[idx] = float(value)
        else:
            result[idx] = value
    return result
801    
802 caltinay 3517
class SymFunction(Symbol):
    """
    Base class for symbolic functions.

    A `SymFunction` is a scalar `Symbol` named after the (sub)class which
    additionally stores the arguments it was created with in ``self.args``.
    Its string and lambda representations have the form
    ``name(arg1, arg2, ...)``.
    """
    def __init__(self, *args, **kwargs):
        """
        Initialises a new symbolic function object.

        :param args: the function arguments, stored in ``self.args``
        :param kwargs: passed on to the `Symbol` constructor (e.g. ``dim``)
        """
        super(SymFunction, self).__init__(self.__class__.__name__, **kwargs)
        self.args = args

    def _callRepr(self, arg_strings):
        """Returns ``name(a1, a2, ...)`` for the given argument strings."""
        return self.name+"("+", ".join(arg_strings)+")"

    def __repr__(self):
        return self._callRepr([str(a) for a in self.args])

    def __str__(self):
        return self._callRepr([str(a) for a in self.args])

    def lambdarepr(self):
        """
        Returns ``name(...)`` with every argument in its lambda
        representation. Sympy expressions are wrapped in a `Symbol` so
        component names like '[x]_0' are converted; other values use str().
        """
        reprs = []
        for a in self.args:
            if hasattr(a, 'lambdarepr'):
                reprs.append(a.lambdarepr())
            elif isinstance(a, sympy.Basic):
                reprs.append(Symbol(a).lambdarepr())
            else:
                reprs.append(str(a))
        return self._callRepr(reprs)

    def atoms(self, *types):
        """
        Returns the atoms of all arguments (optionally restricted to
        `types`). Component symbols such as [x]_i_j are reduced to their
        main symbol 'x'. Arguments without an ``atoms`` method (e.g. plain
        numbers) are included themselves if no types are given or their
        type is listed.

        :rtype: ``set``
        """
        s = set()
        for el in self.args:
            if hasattr(el, 'atoms'):
                found = el.atoms(*types)
            elif len(types) == 0 or type(el) in types:
                found = [el]
            else:
                found = []
            for a in found:
                # Symbol.atoms may return plain numbers without is_Symbol
                if getattr(a, 'is_Symbol', False):
                    n, c = Symbol._symComp(a)
                    s.add(sympy.Symbol(n))
                else:
                    s.add(a)
        return s

    def __neg__(self):
        res = self.__class__(*self.args)
        res.dim = self.dim
        # negating a 0-d array yields a scalar; re-wrap it so that
        # getRank()/getShape() keep working
        res._arr = numpy.array(-res._arr, dtype=object)
        return res
838    

  ViewVC Help
Powered by ViewVC 1.1.26