/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Annotation of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3815 - (hide annotations)
Thu Feb 9 00:27:46 2012 UTC (7 years, 8 months ago) by caltinay
Original Path: branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py
File MIME type: text/x-python
File size: 28317 byte(s)
Merging trunk 3814 into symbolic to get ripley.

# -*- coding: utf-8 -*-

########################################################
#
# Copyright (c) 2003-2012 by University of Queensland
# Earth Systems Science Computational Center (ESSCC)
# http://www.uq.edu.au/esscc
#
# Primary Business: Queensland, Australia
# Licensed under the Open Software License version 3.0
# http://www.opensource.org/licenses/osl-3.0.php
#
########################################################

__copyright__="""Copyright (c) 2003-2012 by University of Queensland
Earth Systems Science Computational Center (ESSCC)
http://www.uq.edu.au/esscc
Primary Business: Queensland, Australia"""
__license__="""Licensed under the Open Software License version 3.0
http://www.opensource.org/licenses/osl-3.0.php"""
__url__="https://launchpad.net/escript-finley"

# Module metadata:
#   __author__    - name of author
#   __copyright__ - copyrights
#   __license__   - licence agreement
#   __url__       - url entry point on documentation

import numpy
import sympy

__author__="Cihan Altinay"
36    
class Symbol(object):
    """
    `Symbol` objects are placeholders for a single mathematic symbol, such as
    'x', or for arbitrarily complex mathematic expressions such as
    'c*x**4+alpha*exp(x)-2*sin(beta*x)', where 'alpha', 'beta', 'c', and 'x'
    are also `Symbol`s (the symbolic 'atoms' of the expression).

    With the help of the 'Evaluator' class these symbols and expressions can
    be resolved by substituting numeric values and/or escript `Data` objects
    for the atoms. To facilitate the use of `Data` objects a `Symbol` has a
    shape (and thus a rank) as well as a dimension (see constructor).
    `Symbol`s are useful to perform mathematic simplifications, compute
    derivatives and as coefficients for nonlinear PDEs which can be solved by
    the `NonlinearPDE` class.

    Internally the components are kept in a numpy object array (``_arr``)
    whose elements are sympy expressions or plain Python numbers.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialises a new `Symbol` object in one of three ways::

            u=Symbol('u')

        returns a scalar symbol by the name 'u'.

            a=Symbol('alpha', (4,3))

        returns a rank 2 symbol with the shape (4,3), whose elements are
        named '[alpha]_i_j' (with i=0..3, j=0..2).

            a,b,c=symbols('a,b,c')
            x=Symbol([[a+b,0,0],[0,b-c,0],[0,0,c-a]])

        returns a rank 2 symbol with the shape (3,3) whose elements are
        explicitly specified by numeric values and other symbols/expressions
        within a list or numpy array.

        A single sympy expression or a plain number is also accepted and
        yields a scalar (rank 0) symbol.

        The dimensionality of the symbol can be specified through the `dim`
        keyword. All other keywords are passed to the underlying symbolic
        library (currently sympy).

        :param args: initialisation arguments as described above
        :keyword dim: dimensionality of the new Symbol (default: 2)
        :type dim: ``int``
        :raise ValueError: if a name contains '[' or ']' or the rank exceeds 4
        :raise TypeError: for unsupported argument types or counts
        """
        import numbers
        self.dim=kwargs.pop('dim', 2)

        if len(args)==1:
            arg=args[0]
            if isinstance(arg, str):
                if '[' in arg or ']' in arg:
                    raise ValueError("Name must not contain '[' or ']'")
                self._arr=numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__") or isinstance(arg, list):
                arr=numpy.array(arg) if isinstance(arg, list) else arg.__array__()
                if arr.ndim>4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                res=numpy.empty(arr.shape, dtype=object)
                for idx in numpy.ndindex(arr.shape):
                    el=arr[idx]
                    # unwrap numpy scalars into plain Python numbers
                    res[idx]=el.item() if hasattr(el, "item") else el
                self._arr=res
            elif isinstance(arg, sympy.Basic):
                self._arr=numpy.array(arg)
            elif isinstance(arg, numbers.Number):
                # plain numbers become rank 0 symbols (e.g. results of applyfunc)
                res=numpy.empty((), dtype=object)
                res[()]=arg
                self._arr=res
            else:
                raise TypeError("Unsupported argument type %s"%str(type(arg)))
        elif len(args)==2:
            if not isinstance(args[0], str):
                raise TypeError("First argument must be a string")
            if not isinstance(args[1], tuple):
                raise TypeError("Second argument must be a tuple")
            name,shape=args
            if '[' in name or ']' in name:
                raise ValueError("Name must not contain '[' or ']'")
            if len(shape)>4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape)==0:
                self._arr=numpy.array(sympy.Symbol(name, **kwargs))
            else:
                # presumably for compatibility with sympy versions that use
                # symarray(shape, prefix) vs. symarray(prefix, shape)
                try:
                    self._arr=sympy.symarray(shape, '['+name+']')
                except TypeError:
                    self._arr=sympy.symarray('['+name+']', shape)
        else:
            raise TypeError("Unsupported number of arguments")

        if self._arr.ndim==0:
            self.name=str(self._arr.item())
        else:
            self.name=str(self._arr.tolist())

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        """
        Two symbols are equal if they are of the same type and shape and all
        their components compare equal. The dimensionality is not compared.
        """
        if type(self) is not type(other):
            return False
        if self.getShape()!=other.getShape():
            return False
        return (self._arr==other._arr).all()

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__
        return not self.__eq__(other)

    def __getitem__(self, key):
        return self._arr[key]

    def __setitem__(self, key, value):
        """
        Assigns `value` to the component(s) selected by `key`.

        `Symbol` values must be scalar or match the shape of the selected
        sub-array. Array-like values are converted element-wise with
        sympy.sympify; other values are sympified as a whole.

        :raise ValueError: if a non-scalar `Symbol` has the wrong shape
        """
        if isinstance(value, Symbol):
            if value.getRank()==0:
                self._arr[key]=value.item()
                return
            target=self._arr[key]
            if not hasattr(target, "shape") or target.shape!=value.getShape():
                raise ValueError("Wrong shape of value")
            for idx in numpy.ndindex(target.shape):
                target[idx]=value[idx]
        elif isinstance(value, sympy.Basic):
            self._arr[key]=value
        elif hasattr(value, "__array__"):
            src=numpy.asarray(value)
            conv=numpy.empty(src.shape, dtype=object)
            for idx in numpy.ndindex(src.shape):
                conv[idx]=sympy.sympify(src[idx])
            # a 0-d array stands for a single component
            self._arr[key]=conv[()] if conv.ndim==0 else conv
        else:
            self._arr[key]=sympy.sympify(value)

    def getDim(self):
        """
        Returns the spatial dimensionality of this symbol.

        :return: the symbol's spatial dimensionality
        :rtype: ``int``
        """
        return self.dim

    def getRank(self):
        """
        Returns the rank of this symbol.

        :return: the symbol's rank which is equal to the length of the shape.
        :rtype: ``int``
        """
        return self._arr.ndim

    def getShape(self):
        """
        Returns the shape of this symbol.

        :return: the symbol's shape
        :rtype: ``tuple`` of ``int``
        """
        return self._arr.shape

    def item(self, *args):
        """
        Returns an element of this symbol.
        This method behaves like the item() method of numpy.ndarray.
        If this is a scalar Symbol, no arguments are allowed and the only
        element in this Symbol is returned.
        Otherwise, 'args' specifies a flat index (one integer) or an
        nd-index (one integer per axis) and the element at that index is
        returned.

        :param args: index of item to be returned
        :return: the requested element
        :rtype: ``sympy.Symbol``, ``int``, or ``float``
        """
        # forward the arguments unchanged so that numpy can distinguish
        # flat indices from nd-indices
        return self._arr.item(*args)

    def atoms(self, *types):
        """
        Returns the atoms that form the current Symbol.

        By default, only objects that are truly atomic and cannot be divided
        into smaller pieces are returned: symbols, numbers, and number
        symbols like I and pi. It is possible to request atoms of any type,
        however.

        Note that if this symbol contains components such as [x]_i_j then
        only their main symbol 'x' is returned.

        :param types: types to restrict result to
        :return: list of atoms of specified type
        :rtype: ``set``
        """
        s=set()
        for el in self._arr.flat:
            if isinstance(el, sympy.Basic):
                for a in el.atoms(*types):
                    if a.is_Symbol:
                        s.add(sympy.Symbol(Symbol._symComp(a)[0]))
                    else:
                        s.add(a)
            elif len(types)==0 or type(el) in types:
                s.add(el)
        return s

    def _sympystr_(self, printer):
        # NOTE(review): sympy's StrPrinter looks for a method named
        # `_sympystr` (no trailing underscore), so this hook is presumably
        # never invoked by sympy - confirm before relying on it.
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a Python expression string for this symbol in which
        component symbols like [x]_0_1 are written as indexing expressions
        x[0,1]. Symbols of rank > 0 are wrapped in a call
        'combineData(<nested list>,<shape>)'.

        :rtype: ``str``
        """
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr=numpy.empty(self.getShape(), dtype=object)
        for idx,el in numpy.ndenumerate(self._arr):
            atoms=el.atoms(sympy.Symbol) if isinstance(el,sympy.Basic) else []
            # create a dictionary to convert names like [x]_0_0 to x[0,0]
            symdict={}
            for a in atoms:
                n,c=Symbol._symComp(a)
                if len(c)>0:
                    symdict[a.name]=n+'['+','.join(str(i) for i in c)+']'
                else:
                    symdict[a.name]=n
            s=lambdarepr(el)
            # replace longer names first so that e.g. '[x]_1' cannot clobber
            # part of '[x]_10'
            for key in sorted(symdict, key=len, reverse=True):
                s=s.replace(key, symdict[key])
            temp_arr[idx]=s
        if self.getRank()==0:
            return temp_arr.item()
        return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))

    def coeff(self, x, expand=True):
        """
        Returns the coefficient of the term "x" or 0 if there is no "x".

        If "x" is a scalar symbol then "x" is searched in all components of
        this symbol. Otherwise the shapes must match (or this symbol must be
        a scalar) and the coefficients are checked component by component.
        Numeric components have coefficient 0.

        Example::

            x=Symbol('x', (2,2))
            y=3*x
            print(y.coeff(x))
            print(y.coeff(x[1,1]))

        will print::

            [[3 3]
             [3 3]]

            [[0 0]
             [0 3]]

        :param x: the term whose coefficients are to be found
        :type x: ``Symbol``, ``numpy.ndarray``, `list`
        :param expand: passed on (positionally) to sympy's ``coeff``
        :return: the coefficient(s) of the term
        :rtype: ``Symbol``
        """
        self._ensureShapeCompatible(x)
        if hasattr(x, '__array__'):
            y=x.__array__()
        else:
            y=numpy.array(x)

        def _coeffOf(item, term):
            # plain numbers contain no symbolic term
            if not isinstance(item, sympy.Basic):
                return 0
            res=item.coeff(term, expand)
            # older sympy versions return None when the term is absent
            return 0 if res is None else res

        if y.ndim>0:
            scalar_self=(self.getRank()==0)
            result=numpy.zeros(y.shape if scalar_self else self.getShape(), dtype=object)
            for idx in numpy.ndindex(y.shape):
                if y[idx]!=0:
                    item=self._arr[()] if scalar_self else self._arr[idx]
                    result[idx]=_coeffOf(item, y[idx])
        elif y.item()==0:
            result=numpy.zeros(self.getShape(), dtype=object)
        else:
            term=y.item()
            result=numpy.empty(self.getShape(), dtype=object)
            for idx in numpy.ndindex(self.getShape()):
                result[idx]=_coeffOf(self._arr[idx], term)
        return Symbol(result, dim=self.dim)

    def subs(self, old, new):
        """
        Substitutes an expression.

        If `old` is a `Symbol` of rank > 0, each of its components is
        replaced in every component of this symbol by the corresponding
        component of `new` (or by `new` itself if `new` is a scalar).
        Otherwise `old` is replaced by the scalar `new`. Numeric components
        are left unchanged.

        :return: a new symbol with the substitution applied
        :rtype: ``Symbol``
        :raise TypeError: if the shapes/ranks of `old` and `new` do not fit
        """
        if isinstance(old, Symbol) and old.getRank()>0:
            old._ensureShapeCompatible(new)
            if hasattr(new, '__array__'):
                new=new.__array__()
            else:
                new=numpy.array(new)

            if new.ndim>0:
                pairs=[(old[i], new[i]) for i in numpy.ndindex(new.shape)]
            else: # substitute scalar for non-scalar
                pairs=[(old[i], new.item()) for i in numpy.ndindex(old.getShape())]

            out=self._arr.copy()
            for idx in numpy.ndindex(out.shape):
                el=out[idx]
                for o,n in pairs:
                    if isinstance(el, sympy.Basic):
                        el=el.subs(o, n)
                out[idx]=el
            return Symbol(out, dim=self.dim)

        # scalar
        if isinstance(new, Symbol):
            if new.getRank()>0:
                raise TypeError("Cannot substitute, incompatible ranks.")
            new=new.item()
        if isinstance(old, Symbol):
            old=old.item()
        return self.applyfunc(lambda item: item.subs(old, new), sympy.Basic)

    def diff(self, *symbols, **assumptions):
        """
        Returns the derivative of this symbol with respect to `symbols`.

        `symbols` follows the sympy convention, e.g. diff(x, 2, y) means two
        derivatives with respect to x followed by one with respect to y.
        Differentiating with respect to a rank 1 `Symbol` appends an axis
        with one entry per component of that symbol. Numeric components
        differentiate to 0.

        :param assumptions: passed on to sympy's ``diff``
        :rtype: ``Symbol``
        :raise ValueError: if a variable is a `Symbol` of rank > 1
        """
        def _diff(item, var):
            if isinstance(item, sympy.Basic):
                return item.diff(var, **assumptions)
            return 0  # derivative of a numeric constant

        result=Symbol(self._arr, dim=self.dim)
        for s in Symbol._symbolgen(*symbols):
            if isinstance(s, Symbol):
                if s.getRank()==0:
                    var=s._arr.item()
                    result=result.applyfunc(lambda item, v=var: _diff(item, v))
                elif s.getRank()==1:
                    n=s.getShape()[0]
                    # use the shape of the intermediate result, which grows
                    # with each rank 1 differentiation
                    shape=result.getShape()
                    out=numpy.empty(shape+(n,), dtype=object)
                    for idx in numpy.ndindex(shape):
                        for d in range(n):
                            out[idx+(d,)]=_diff(result._arr[idx], s[d])
                    result=Symbol(out, dim=self.dim)
                else:
                    raise ValueError("diff: Only rank 0 and 1 supported")
            else:
                result=result.applyfunc(lambda item, v=s: _diff(item, v))
        return result

    def grad(self, where=None):
        """
        Returns the gradient of this symbol.

        The result has one additional trailing axis of length `dim`; entry d
        of that axis is obtained by applying `grad_n` (from the sibling
        `functions` module) to the corresponding component with index d and,
        if given, `where`.

        :param where: optional scalar `Symbol` or expression passed to grad_n
        :rtype: ``Symbol``
        :raise ValueError: if `where` is a `Symbol` of rank > 0
        """
        if isinstance(where, Symbol):
            if where.getRank()>0:
                raise ValueError("grad: 'where' must be a scalar symbol")
            where=where._arr.item()

        try:
            from .functions import grad_n
        except (ImportError, ValueError):
            # not imported as part of a package
            from functions import grad_n

        shape=self.getShape()
        out=numpy.empty(shape+(self.dim,), dtype=object)
        for idx in numpy.ndindex(shape):
            for d in range(self.dim):
                if where is None:
                    out[idx+(d,)]=grad_n(self._arr[idx], d)
                else:
                    out[idx+(d,)]=grad_n(self._arr[idx], d, where)
        return Symbol(out, dim=self.dim)

    def inverse(self):
        """
        Returns the inverse of this square rank 2 symbol, computed with the
        explicit cofactor formulas for sizes 1, 2 and 3.

        :rtype: ``Symbol``
        :raise TypeError: if the rank is not 2 or the size exceeds 3
        :raise ValueError: if the shape is not square
        :raise ZeroDivisionError: if sympy reports the determinant as zero
                                  (``is_zero``)
        """
        if self.getRank()!=2:
            raise TypeError("inverse: Only rank 2 supported")
        s=self.getShape()
        if s[0]!=s[1]:
            raise ValueError("inverse: Only square shapes supported")
        out=numpy.zeros(s, dtype=object)
        arr=self._arr
        if s[0]==1:
            # sympify so that plain numeric entries also provide .is_zero
            A11=sympy.sympify(arr[0,0])
            if A11.is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            out[0,0]=1./A11
        elif s[0]==2:
            A11=arr[0,0]
            A12=arr[0,1]
            A21=arr[1,0]
            A22=arr[1,1]
            D=sympy.sympify(A11*A22-A12*A21)
            if D.is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]= A22*D
            out[1,0]=-A21*D
            out[0,1]=-A12*D
            out[1,1]= A11*D
        elif s[0]==3:
            A11=arr[0,0]
            A21=arr[1,0]
            A31=arr[2,0]
            A12=arr[0,1]
            A22=arr[1,1]
            A32=arr[2,1]
            A13=arr[0,2]
            A23=arr[1,2]
            A33=arr[2,2]
            D=sympy.sympify(A11*(A22*A33-A23*A32)+A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22))
            if D.is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]=(A22*A33-A23*A32)*D
            out[1,0]=(A31*A23-A21*A33)*D
            out[2,0]=(A21*A32-A31*A22)*D
            out[0,1]=(A13*A32-A12*A33)*D
            out[1,1]=(A11*A33-A31*A13)*D
            out[2,1]=(A12*A31-A11*A32)*D
            out[0,2]=(A12*A23-A13*A22)*D
            out[1,2]=(A13*A21-A11*A23)*D
            out[2,2]=(A11*A22-A12*A21)*D
        else:
            raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
        return Symbol(out, dim=self.dim)

    def swap_axes(self, axis0, axis1):
        """
        Returns a copy of this symbol with axes `axis0` and `axis1`
        interchanged (see numpy.swapaxes).

        :rtype: ``Symbol``
        """
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self.dim)

    @staticmethod
    def _prod(values):
        """Returns the product of the integers in `values` (1 if empty)."""
        p=1
        for v in values:
            p*=v
        return p

    @staticmethod
    def _asArray(other):
        """Returns the component array of a `Symbol` or array-like operand."""
        if isinstance(other, Symbol):
            return other._arr
        return numpy.asarray(other)

    @staticmethod
    def _contract(a, b):
        """
        Returns the (d0,d1) object array with entries sum_k a[i0,k]*b[k,i1]
        for 2-D arrays `a` of shape (d0,d01) and `b` of shape (d01,d1).
        """
        d0,d1=a.shape[0],b.shape[1]
        out=numpy.zeros((d0,d1), dtype=object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(a[i0,:]*b[:,i1])
        return out

    def tensorProduct(self, other, axis_offset):
        """
        Returns the tensor product of this symbol and `other`, contracting
        the last `axis_offset` axes of this symbol with the first
        `axis_offset` axes of `other`.

        :param other: a `Symbol` or array-like operand
        :rtype: ``Symbol``
        :raise ValueError: if the contracted axes do not match in size
        """
        arg0=self._arr
        arg1=Symbol._asArray(other)
        sh0,sh1=arg0.shape,arg1.shape
        r0=arg0.ndim
        d0=Symbol._prod(sh0[:r0-axis_offset])
        d1=Symbol._prod(sh1[axis_offset:])
        d01=Symbol._prod(sh1[:axis_offset])
        # reshape (unlike resize) refuses inconsistent sizes instead of
        # silently padding/truncating
        out=Symbol._contract(arg0.reshape((d0,d01)), arg1.reshape((d01,d1)))
        return Symbol(out.reshape(sh0[:r0-axis_offset]+sh1[axis_offset:]), dim=self.dim)

    def transposedTensorProduct(self, other, axis_offset):
        """
        Returns the tensor product of the transpose of this symbol and
        `other`, contracting the first `axis_offset` axes of both operands.

        :param other: a `Symbol` or array-like operand
        :rtype: ``Symbol``
        :raise ValueError: if the contracted axes do not match in size
        """
        arg0=self._arr
        arg1=Symbol._asArray(other)
        sh0,sh1=arg0.shape,arg1.shape
        d0=Symbol._prod(sh0[axis_offset:])
        d1=Symbol._prod(sh1[axis_offset:])
        d01=Symbol._prod(sh1[:axis_offset])
        out=Symbol._contract(arg0.reshape((d01,d0)).T, arg1.reshape((d01,d1)))
        return Symbol(out.reshape(sh0[axis_offset:]+sh1[axis_offset:]), dim=self.dim)

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Returns the tensor product of this symbol and the transpose of
        `other`, contracting the last `axis_offset` axes of both operands.

        :param other: a `Symbol` or array-like operand
        :rtype: ``Symbol``
        :raise ValueError: if the contracted axes do not match in size
        """
        arg0=self._arr
        arg1=Symbol._asArray(other)
        sh0,sh1=arg0.shape,arg1.shape
        r0,r1=arg0.ndim,arg1.ndim
        d0=Symbol._prod(sh0[:r0-axis_offset])
        d1=Symbol._prod(sh1[:r1-axis_offset])
        d01=Symbol._prod(sh1[r1-axis_offset:])
        out=Symbol._contract(arg0.reshape((d0,d01)), arg1.reshape((d1,d01)).T)
        return Symbol(out.reshape(sh0[:r0-axis_offset]+sh1[:r1-axis_offset]), dim=self.dim)

    def trace(self, axis_offset):
        """
        Returns the trace over the two axes starting at `axis_offset`, i.e.
        the sum over j of the components with index j on both of those axes.

        :rtype: ``Symbol``
        """
        sh=self.getShape()
        n=sh[axis_offset]
        s1=Symbol._prod(sh[:axis_offset])
        s2=Symbol._prod(sh[axis_offset+2:])
        arr_r=self._arr.reshape((s1,n,n,s2))
        out=numpy.zeros((s1,s2), dtype=object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(n):
                    out[i1,i2]+=arr_r[i1,j,j,i2]
        return Symbol(out.reshape(sh[:axis_offset]+sh[axis_offset+2:]), dim=self.dim)

    def transpose(self, axis_offset=None):
        """
        Returns the transpose of this symbol: the axes from `axis_offset`
        onwards are moved in front of the first `axis_offset` axes.

        :param axis_offset: defaults to half the rank (rounded down)
        :rtype: ``Symbol``
        """
        rank=self.getRank()
        if axis_offset is None:
            axis_offset=rank//2
        axes=list(range(axis_offset, rank))+list(range(0, axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes), dim=self.dim)

    def applyfunc(self, f, on_type=None):
        """
        Returns a new symbol obtained by applying `f` to every component.

        If `on_type` is given, only components that are instances of it are
        passed to `f`; all others are copied unchanged. For a rank 0 symbol
        None is returned if the resulting component is None.

        :param f: callable taking and returning a single component
        :rtype: ``Symbol`` or None
        """
        assert callable(f)

        def _apply(el):
            if on_type is None or isinstance(el, on_type):
                return f(el)
            return el

        if self._arr.ndim==0:
            el=_apply(self._arr.item())
            if el is None:
                return None
            return Symbol(el, dim=self.dim)
        out=numpy.empty(self.getShape(), dtype=object)
        for idx in numpy.ndindex(out.shape):
            out[idx]=_apply(self._arr[idx])
        return Symbol(out, dim=self.dim)

    def expand(self):
        """
        Returns a copy with sympy.expand applied to all sympy components;
        numeric components are left untouched.
        """
        return self.applyfunc(sympy.expand, sympy.Basic)

    def simplify(self):
        """
        Returns a copy with sympy.simplify applied to all sympy components;
        numeric components are left untouched.
        """
        return self.applyfunc(sympy.simplify, sympy.Basic)

    def _sympy_(self):
        """
        Conversion hook for sympy.sympify: sympifies every component.
        Note that the result is a `Symbol`, not a sympy object.
        """
        return self.applyfunc(sympy.sympify)

    def _ensureShapeCompatible(self, other):
        """
        Checks for compatible shapes for binary operations.
        Raises TypeError if not compatible.

        Shapes are compatible if they are equal or either operand is a
        scalar (numbers, including numpy scalars, and sympy expressions).
        """
        import numbers
        sh0=self.getShape()
        if isinstance(other, Symbol):
            sh1=other.getShape()
        elif isinstance(other, numpy.ndarray):
            sh1=other.shape
        elif isinstance(other, list):
            sh1=numpy.array(other).shape
        elif isinstance(other, numbers.Number) or isinstance(other, sympy.Basic):
            sh1=()
        else:
            raise TypeError("Unsupported argument type '%s' for operation"%other.__class__.__name__)
        if sh0!=sh1 and sh0!=() and sh1!=():
            raise TypeError("Incompatible shapes for operation")

    @staticmethod
    def _symComp(sym):
        """
        Splits a component symbol name as created by sympy.symarray, e.g.
        '[x]_1_2', into its base name and indices, e.g. ('x', (1,2)).
        Other names are returned unchanged together with an empty tuple.
        """
        n=sym.name
        a=n.split('[')
        if len(a)!=2:
            return n,()
        a=a[1].split(']')
        if len(a)!=2:
            return n,()
        return a[0],tuple(int(i) for i in a[1].split('_')[1:])

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        An integer directly following a symbol repeats that symbol; other
        integers are skipped. Non-`Symbol` arguments are sympified.

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat
        items=[s if isinstance(s, Symbol) else sympy.sympify(s) for s in symbols]
        for i, s in enumerate(items):
            if isinstance(s, sympy.Integer):
                continue
            # look ahead by position (not by value) for a repeat count
            next_s=items[i+1] if i+1<len(items) else None
            if isinstance(s, (Symbol, sympy.Symbol)) and isinstance(next_s, sympy.Integer):
                for copy_s in repeat(s, int(next_s)):
                    yield copy_s
            else:
                yield s

    def __array__(self, dtype=None, copy=None):
        """
        Returns the component array (numpy array protocol). `dtype` and
        `copy` are accepted as passed by numpy.array/numpy.asarray.
        """
        if dtype is not None:
            return self._arr.astype(dtype)
        if copy:
            return self._arr.copy()
        return self._arr

    # unary/binary operations follow

    @staticmethod
    def _operand(other):
        """Returns the component array of a `Symbol`, other values as-is."""
        return other._arr if isinstance(other, Symbol) else other

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr, dim=self.dim)

    def __abs__(self):
        return Symbol(abs(self._arr), dim=self.dim)

    def __add__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(self._arr+Symbol._operand(other), dim=self.dim)

    def __radd__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(Symbol._operand(other)+self._arr, dim=self.dim)

    def __sub__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(self._arr-Symbol._operand(other), dim=self.dim)

    def __rsub__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(Symbol._operand(other)-self._arr, dim=self.dim)

    def __mul__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(self._arr*Symbol._operand(other), dim=self.dim)

    def __rmul__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(Symbol._operand(other)*self._arr, dim=self.dim)

    def __div__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(self._arr/Symbol._operand(other), dim=self.dim)

    def __rdiv__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(Symbol._operand(other)/self._arr, dim=self.dim)

    # true division (Python 3, or Python 2 with `from __future__ import division`)
    __truediv__=__div__
    __rtruediv__=__rdiv__

    def __pow__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(self._arr**Symbol._operand(other), dim=self.dim)

    def __rpow__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(Symbol._operand(other)**self._arr, dim=self.dim)
720 caltinay 3530 return Symbol(other._arr/self._arr, dim=self.dim)
721     return Symbol(other/self._arr, dim=self.dim)
722 caltinay 3507
723     def __pow__(self, other):
724 caltinay 3517 self._ensureShapeCompatible(other)
725 caltinay 3507 if isinstance(other, Symbol):
726 caltinay 3530 return Symbol(self._arr**other._arr, dim=self.dim)
727     return Symbol(self._arr**other, dim=self.dim)
728 caltinay 3507
729     def __rpow__(self, other):
730 caltinay 3517 self._ensureShapeCompatible(other)
731 caltinay 3507 if isinstance(other, Symbol):
732 caltinay 3530 return Symbol(other._arr**self._arr, dim=self.dim)
733     return Symbol(other**self._arr, dim=self.dim)
734 caltinay 3507
735    
736     def symbols(*names, **kwargs):
737     """
738     Emulates the behaviour of sympy.symbols.
739     """
740    
741     shape=kwargs.pop('shape', ())
742    
743     s = names[0]
744     if not isinstance(s, list):
745     import re
746     s = re.split('\s|,', s)
747     res = []
748     for t in s:
749     # skip empty strings
750     if not t:
751     continue
752     sym = Symbol(t, shape, **kwargs)
753     res.append(sym)
754     res = tuple(res)
755     if len(res) == 0: # var('')
756     res = None
757     elif len(res) == 1: # var('x')
758     res = res[0]
759     # otherwise var('a b ...')
760     return res
761    
762     def combineData(array, shape):
763 caltinay 3533 """
764     """
765    
766 caltinay 3507 # array could just be a single value
767     if not hasattr(array,'__len__') and shape==():
768     return array
769    
770     from esys.escript import Data
771     n=numpy.array(array) # for indexing
772    
773     # find function space if any
774     dom=set()
775     fs=set()
776     for idx in numpy.ndindex(shape):
777     if isinstance(n[idx], Data):
778     fs.add(n[idx].getFunctionSpace())
779     dom.add(n[idx].getDomain())
780    
781     if len(dom)>1:
782     domain=dom.pop()
783     while len(dom)>0:
784     if domain!=dom.pop():
785     raise ValueError("Mixing of domains not supported")
786    
787     if len(fs)>0:
788     d=Data(0., shape, fs.pop()) #FIXME: interpolate instead of using first?
789     else:
790 caltinay 3509 d=numpy.zeros(shape)
791 caltinay 3507 for idx in numpy.ndindex(shape):
792     #z=numpy.zeros(shape)
793     #z[idx]=1.
794     #d+=n[idx]*z # much slower!
795     if hasattr(n[idx], "ndim") and n[idx].ndim==0:
796     d[idx]=float(n[idx])
797     else:
798     d[idx]=n[idx]
799     return d
800    
801 caltinay 3517
802     class SymFunction(Symbol):
803     """
804     """
805     def __init__(self, *args, **kwargs):
806     """
807 caltinay 3533 Initialises a new symbolic function object.
808 caltinay 3517 """
809     super(SymFunction, self).__init__(self.__class__.__name__, **kwargs)
810     self.args=args
811    
812     def __repr__(self):
813     return self.name+"("+", ".join([str(a) for a in self.args])+")"
814    
815     def __str__(self):
816     return self.name+"("+", ".join([str(a) for a in self.args])+")"
817    
818     def lambdarepr(self):
819     return self.name+"("+", ".join([a.lambdarepr() for a in self.args])+")"
820    
821     def atoms(self, *types):
822     s=set()
823     for el in self.args:
824     atoms=el.atoms(*types)
825     for a in atoms:
826     if a.is_Symbol:
827     n,c=Symbol._symComp(a)
828     s.add(sympy.Symbol(n))
829     else:
830     s.add(a)
831     return s
832    
833     def __neg__(self):
834     res=self.__class__(*self.args)
835     res._arr=-res._arr
836     return res
837    

  ViewVC Help
Powered by ViewVC 1.1.26