/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Annotation of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3818 - (hide annotations)
Mon Feb 13 01:12:08 2012 UTC (7 years, 8 months ago) by caltinay
Original Path: branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py
File MIME type: text/x-python
File size: 28602 byte(s)
-Fixed compatibility with sympy 0.7
-Fixed unit tests (1.0*x != x in sympy)

1 caltinay 3507 # -*- coding: utf-8 -*-
2    
3     ########################################################
4     #
5 caltinay 3815 # Copyright (c) 2003-2012 by University of Queensland
6 caltinay 3507 # Earth Systems Science Computational Center (ESSCC)
7     # http://www.uq.edu.au/esscc
8     #
9     # Primary Business: Queensland, Australia
10     # Licensed under the Open Software License version 3.0
11     # http://www.opensource.org/licenses/osl-3.0.php
12     #
13     ########################################################
14    
# Module metadata strings (copyright, licence and project URL).
__copyright__="""Copyright (c) 2003-2012 by University of Queensland
Earth Systems Science Computational Center (ESSCC)
http://www.uq.edu.au/esscc
Primary Business: Queensland, Australia"""
__license__="""Licensed under the Open Software License version 3.0
http://www.opensource.org/licenses/osl-3.0.php"""
__url__="https://launchpad.net/escript-finley"

# Documentation of the module-level metadata variables. This string is not
# the module docstring because it does not come first in the file.
# NOTE(review): __version__ and __date__ are listed here but are not
# assigned anywhere in this file.
"""
:var __author__: name of author
:var __copyright__: copyrights
:var __license__: licence agreement
:var __url__: url entry point on documentation
:var __version__: version
:var __date__: date of the version
"""

import numpy
import sympy

__author__="Cihan Altinay"
36    
class Symbol(object):
    """
    `Symbol` objects are placeholders for a single mathematic symbol, such as
    'x', or for arbitrarily complex mathematic expressions such as
    'c*x**4+alpha*exp(x)-2*sin(beta*x)', where 'alpha', 'beta', 'c', and 'x'
    are also `Symbol`s (the symbolic 'atoms' of the expression).

    With the help of the 'Evaluator' class these symbols and expressions can
    be resolved by substituting numeric values and/or escript `Data` objects
    for the atoms. To facilitate the use of `Data` objects a `Symbol` has a
    shape (and thus a rank) as well as a dimension (see constructor).
    `Symbol`s are useful to perform mathematic simplifications, compute
    derivatives and as coefficients for nonlinear PDEs which can be solved by
    the `NonlinearPDE` class.

    Internally the components are stored in a numpy object array (``_arr``)
    whose elements are sympy expressions or plain Python numbers.
    """

    # these are for compatibility with sympy.Symbol. lambdify checks these.
    is_Add=False
    is_Float=False

    def __init__(self, *args, **kwargs):
        """
        Initialises a new `Symbol` object in one of three ways::

            u=Symbol('u')

        returns a scalar symbol by the name 'u'.

            a=Symbol('alpha', (4,3))

        returns a rank 2 symbol with the shape (4,3), whose elements are
        named '[alpha]_i_j' (with i=0..3, j=0..2).

            a,b,c=symbols('a,b,c')
            x=Symbol([[a+b,0,0],[0,b-c,0],[0,0,c-a]])

        returns a rank 2 symbol with the shape (3,3) whose elements are
        explicitly specified by numeric values and other symbols/expressions
        within a list or numpy array.

        The dimensionality of the symbol can be specified through the `dim`
        keyword. All other keywords are passed to the underlying symbolic
        library (currently sympy).

        :param args: initialisation arguments as described above
        :keyword dim: dimensionality of the new Symbol (default: 2)
        :type dim: ``int``
        :raise ValueError: if a name contains '[' or ']', or if the requested
                           rank is larger than 4
        :raise TypeError: for unsupported argument types or counts
        """
        if 'dim' in kwargs:
            self.dim=kwargs.pop('dim')
        else:
            self.dim=2

        if len(args)==1:
            arg=args[0]
            if isinstance(arg, str):
                if arg.find('[')>=0 or arg.find(']')>=0:
                    raise ValueError("Name must not contain '[' or ']'")
                self._arr=numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__") or isinstance(arg, list):
                if isinstance(arg, list): arg=numpy.array(arg)
                arr=arg.__array__()
                if len(arr.shape)>4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                res=numpy.empty(arr.shape, dtype=object)
                for idx in numpy.ndindex(arr.shape):
                    # convert numpy scalars to plain Python numbers
                    if hasattr(arr[idx], "item"):
                        res[idx]=arr[idx].item()
                    else:
                        res[idx]=arr[idx]
                self._arr=res
            elif isinstance(arg, sympy.Basic):
                self._arr=numpy.array(arg)
            else:
                raise TypeError("Unsupported argument type %s"%str(type(arg)))
        elif len(args)==2:
            if not isinstance(args[0], str):
                raise TypeError("First argument must be a string")
            if not isinstance(args[1], tuple):
                raise TypeError("Second argument must be a tuple")
            name=args[0]
            shape=args[1]
            if name.find('[')>=0 or name.find(']')>=0:
                raise ValueError("Name must not contain '[' or ']'")
            if len(shape)>4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape)==0:
                self._arr=numpy.array(sympy.Symbol(name, **kwargs))
            else:
                # the argument order of symarray differs between sympy versions
                try:
                    self._arr=sympy.symarray(shape, '['+name+']')
                except TypeError:
                    self._arr=sympy.symarray('['+name+']', shape)
        else:
            raise TypeError("Unsupported number of arguments")
        if self._arr.ndim==0:
            self.name=str(self._arr.item())
        else:
            self.name=str(self._arr.tolist())

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        if type(self) is not type(other):
            return False
        if self.getRank()!=other.getRank():
            return False
        if self.getShape()!=other.getShape():
            return False
        return bool((self._arr==other._arr).all())

    def __ne__(self, other):
        # Python 2 does not derive '!=' from __eq__
        return not self.__eq__(other)

    def __getitem__(self, key):
        return self._arr[key]

    def __iter__(self):
        # must return an iterator object, not the bound __iter__ method
        return iter(self._arr)

    def __setitem__(self, key, value):
        if isinstance(value, Symbol):
            if value.getRank()==0:
                self._arr[key]=value.item()
            elif hasattr(self._arr[key], "shape"):
                if self._arr[key].shape==value.getShape():
                    for idx in numpy.ndindex(self._arr[key].shape):
                        self._arr[key][idx]=value[idx]
                else:
                    raise ValueError("Wrong shape of value")
            else:
                raise ValueError("Wrong shape of value")
        elif isinstance(value, sympy.Basic):
            self._arr[key]=value
        elif hasattr(value, "__array__"):
            # sympify every element while keeping the value's shape (a bare
            # map() would be stored as a lazy iterator under Python 3)
            src=numpy.asarray(value)
            conv=numpy.empty(src.shape, dtype=object)
            for idx in numpy.ndindex(src.shape):
                conv[idx]=sympy.sympify(src[idx])
            self._arr[key]=conv if conv.ndim>0 else conv.item()
        else:
            self._arr[key]=sympy.sympify(value)

    def getDim(self):
        """
        Returns the spatial dimensionality of this symbol.

        :return: the symbol's spatial dimensionality
        :rtype: ``int``
        """
        return self.dim

    def getRank(self):
        """
        Returns the rank of this symbol.

        :return: the symbol's rank which is equal to the length of the shape.
        :rtype: ``int``
        """
        return self._arr.ndim

    def getShape(self):
        """
        Returns the shape of this symbol.

        :return: the symbol's shape
        :rtype: ``tuple`` of ``int``
        """
        return self._arr.shape

    def item(self, *args):
        """
        Returns an element of this symbol.
        This method behaves like the item() method of numpy.ndarray.
        If this is a scalar Symbol, no arguments are allowed and the only
        element in this Symbol is returned.
        Otherwise, 'args' specifies a flat or nd-index and the element at
        that index is returned.

        :param args: index of item to be returned
        :return: the requested element
        :rtype: ``sympy.Symbol``, ``int``, or ``float``
        """
        # unpack so that item(3) is a flat index and item(1,0) an nd-index
        return self._arr.item(*args)

    def atoms(self, *types):
        """
        Returns the atoms that form the current Symbol.

        By default, only objects that are truly atomic and cannot be divided
        into smaller pieces are returned: symbols, numbers, and number
        symbols like I and pi. It is possible to request atoms of any type,
        however.

        Note that if this symbol contains components such as [x]_i_j then
        only their main symbol 'x' is returned.

        :param types: types to restrict result to
        :return: list of atoms of specified type
        :rtype: ``set``
        """
        s=set()
        for el in self._arr.flat:
            if isinstance(el,sympy.Basic):
                atoms=el.atoms(*types)
                for a in atoms:
                    if a.is_Symbol:
                        n,c=Symbol._symComp(a)
                        s.add(sympy.Symbol(n))
                    else:
                        s.add(a)
            elif len(types)==0 or type(el) in types:
                s.add(el)
        return s

    def _sympystr_(self, printer):
        # compatibility alias for older sympy (0.6.x) printers
        return self._sympystr(printer)

    def _sympystr(self, printer):
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string representation of this symbol for use with lambdify.

        Component names such as '[x]_0_1' are rewritten as 'x[0,1]'. For
        non-scalar symbols the result is the string
        'combineData(<nested list>,<shape>)'.

        :rtype: ``str``
        """
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr=numpy.empty(self.getShape(), dtype=object)
        for idx,el in numpy.ndenumerate(self._arr):
            atoms=el.atoms(sympy.Symbol) if isinstance(el,sympy.Basic) else []
            # create a dictionary to convert names like [x]_0_0 to x[0,0]
            symdict={}
            for a in atoms:
                n,c=Symbol._symComp(a)
                if len(c)>0:
                    c=[str(i) for i in c]
                    symstr=n+'['+','.join(c)+']'
                else:
                    symstr=n
                symdict[a.name]=symstr
            s=lambdarepr(el)
            # replace longer names first so that e.g. '[x]_1' cannot clobber
            # the prefix of '[x]_10'
            for key in sorted(symdict, key=len, reverse=True):
                s=s.replace(key, symdict[key])
            temp_arr[idx]=s
        if self.getRank()==0:
            return temp_arr.item()
        else:
            return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))

    def coeff(self, x, expand=True):
        """
        Returns the coefficient of the term "x" or 0 if there is no "x".

        If "x" is a scalar symbol then "x" is searched in all components of
        this symbol. Otherwise the shapes must match and the coefficients are
        checked component by component. Plain numeric components have a
        coefficient of 0.

        Example::

            x=Symbol('x', (2,2))
            y=3*x
            print(y.coeff(x))
            print(y.coeff(x[1,1]))

        will print::

            [[3 3]
             [3 3]]

            [[0 0]
             [0 3]]

        :param x: the term whose coefficients are to be found
        :type x: ``Symbol``, ``numpy.ndarray``, `list`
        :return: the coefficient(s) of the term
        :rtype: ``Symbol``
        """
        self._ensureShapeCompatible(x)
        if hasattr(x, '__array__'):
            y=x.__array__()
        else:
            y=numpy.array(x)

        if y.ndim>0:
            result=numpy.zeros(self.getShape(), dtype=object)
            for idx in numpy.ndindex(y.shape):
                if y[idx]!=0:
                    result[idx]=Symbol._coeffItem(self[idx], y[idx], expand)
        elif y.item()==0:
            result=numpy.zeros(self.getShape(), dtype=object)
        else:
            term=y.item()
            result=numpy.empty(self.getShape(), dtype=object)
            for idx in numpy.ndindex(self.getShape()):
                result[idx]=Symbol._coeffItem(self._arr[idx], term, expand)
        return Symbol(result, dim=self.dim)

    def subs(self, old, new):
        """
        Substitutes ``new`` for ``old`` in all elements of this symbol.

        If ``old`` is a non-scalar `Symbol` then each of its components is
        replaced by the matching component of ``new``, or by ``new`` itself
        if ``new`` is a scalar. Elements without a ``subs`` method (plain
        numbers) are left unchanged.

        :param old: the expression or symbol to be replaced
        :param new: the replacement
        :return: a new symbol with the substitution applied
        :rtype: `Symbol`
        :raise TypeError: if the shapes or ranks of ``old`` and ``new`` are
                          incompatible
        """
        if isinstance(old, Symbol) and old.getRank()>0:
            old._ensureShapeCompatible(new)
            if hasattr(new, '__array__'):
                new=new.__array__()
            else:
                new=numpy.array(new)

            result=Symbol(self._arr, dim=self.dim)
            if new.ndim>0:
                for idx in numpy.ndindex(self.getShape()):
                    for symidx in numpy.ndindex(new.shape):
                        result[idx]=Symbol._subsItem(result[idx], old[symidx], new[symidx])
            else: # substitute scalar for non-scalar
                for idx in numpy.ndindex(self.getShape()):
                    for symidx in numpy.ndindex(old.getShape()):
                        result[idx]=Symbol._subsItem(result[idx], old[symidx], new.item())
        else: # scalar
            if isinstance(new, Symbol):
                if new.getRank()>0:
                    raise TypeError("Cannot substitute, incompatible ranks.")
                new=new.item()
            if isinstance(old, Symbol):
                old=old.item()
            result=self.applyfunc(lambda item: Symbol._subsItem(item, old, new))
        return result

    def diff(self, *symbols, **assumptions):
        """
        Returns the derivative of this symbol with respect to ``symbols``.

        Arguments are processed like sympy's ``diff``: ``(x, 3)`` means
        differentiating three times with respect to x. Differentiating with
        respect to a rank 1 `Symbol` appends an axis of that symbol's length
        holding the derivative with respect to each of its components.
        Plain numeric components differentiate to 0.

        :param symbols: symbols (and optional repetition counts)
        :param assumptions: passed on to the elements' ``diff`` method
        :rtype: `Symbol`
        :raise ValueError: if differentiating with respect to a `Symbol` of
                           rank larger than 1
        """
        symbols=Symbol._symbolgen(*symbols)
        result=Symbol(self._arr, dim=self.dim)
        for s in symbols:
            if isinstance(s, Symbol):
                if s.getRank()==0:
                    target=s._arr.item()
                    result=result.applyfunc(lambda item: Symbol._diffItem(item, target, assumptions))
                elif s.getRank()==1:
                    dim=s.getShape()[0]
                    # use the current result's shape: it grows with every
                    # vector symbol already processed
                    shape=result.getShape()
                    out=result._arr.copy().reshape(shape+(1,)).repeat(dim,axis=len(shape))
                    for d in range(dim):
                        for idx in numpy.ndindex(shape):
                            index=idx+(d,)
                            out[index]=Symbol._diffItem(out[index], s[d], assumptions)
                    result=Symbol(out, dim=self.dim)
                else:
                    raise ValueError("diff: Only rank 0 and 1 supported")
            else:
                result=result.applyfunc(lambda item: Symbol._diffItem(item, s, assumptions))
        return result

    def grad(self, where=None):
        """
        Returns the gradient of this symbol.

        The result has one additional trailing axis of length ``self.dim``;
        its component ``[..., d]`` is ``grad_n(component, d)`` (or
        ``grad_n(component, d, where)`` if ``where`` is given).

        :param where: optional scalar symbol passed on to ``grad_n``
        :rtype: `Symbol`
        :raise ValueError: if ``where`` is a non-scalar `Symbol`
        """
        if isinstance(where, Symbol):
            if where.getRank()>0:
                raise ValueError("grad: 'where' must be a scalar symbol")
            where=where._arr.item()

        try:
            from .functions import grad_n
        except (ImportError, ValueError):
            # module not loaded as part of a package (Python 2 raises
            # ValueError, Python 3 ImportError): fall back to a plain import
            from functions import grad_n
        out=self._arr.copy().reshape(self.getShape()+(1,)).repeat(self.dim,axis=self.getRank())
        for d in range(self.dim):
            for idx in numpy.ndindex(self.getShape()):
                index=idx+(d,)
                if where is None:
                    out[index]=grad_n(out[index],d)
                else:
                    out[index]=grad_n(out[index],d,where)
        return Symbol(out, dim=self.dim)

    def inverse(self):
        """
        Returns the inverse of this square rank 2 symbol.

        Only 1x1, 2x2 and 3x3 shapes are supported; the inverse is computed
        with explicit cofactor formulas.

        :rtype: `Symbol`
        :raise TypeError: if the rank is not 2 or the size is larger than 3
        :raise ValueError: if the shape is not square
        :raise ZeroDivisionError: if the determinant is known to be zero
        """
        if not self.getRank()==2:
            raise TypeError("inverse: Only rank 2 supported")
        s=self.getShape()
        if not s[0] == s[1]:
            raise ValueError("inverse: Only square shapes supported")
        out=numpy.zeros(s, dtype=object)
        arr=self._arr
        if s[0]==1:
            if Symbol._isZero(arr[0,0]):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            out[0,0]=1./arr[0,0]
        elif s[0]==2:
            A11=arr[0,0]
            A12=arr[0,1]
            A21=arr[1,0]
            A22=arr[1,1]
            D = A11*A22-A12*A21
            if Symbol._isZero(D):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]= A22*D
            out[1,0]=-A21*D
            out[0,1]=-A12*D
            out[1,1]= A11*D
        elif s[0]==3:
            A11=arr[0,0]
            A21=arr[1,0]
            A31=arr[2,0]
            A12=arr[0,1]
            A22=arr[1,1]
            A32=arr[2,1]
            A13=arr[0,2]
            A23=arr[1,2]
            A33=arr[2,2]
            D = A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
            if Symbol._isZero(D):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]=(A22*A33-A23*A32)*D
            out[1,0]=(A31*A23-A21*A33)*D
            out[2,0]=(A21*A32-A31*A22)*D
            out[0,1]=(A13*A32-A12*A33)*D
            out[1,1]=(A11*A33-A31*A13)*D
            out[2,1]=(A12*A31-A11*A32)*D
            out[0,2]=(A12*A23-A13*A22)*D
            out[1,2]=(A13*A21-A11*A23)*D
            out[2,2]=(A11*A22-A12*A21)*D
        else:
            raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
        return Symbol(out, dim=self.dim)

    def swap_axes(self, axis0, axis1):
        """
        Returns a copy of this symbol with axes ``axis0`` and ``axis1``
        interchanged.

        :rtype: `Symbol`
        """
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self.dim)

    def tensorProduct(self, other, axis_offset):
        """
        Returns the tensor product of this symbol and ``other``, contracting
        the last ``axis_offset`` axes of this symbol with the first
        ``axis_offset`` axes of ``other``.

        :param other: a `Symbol` or array-like object
        :param axis_offset: number of axes to contract
        :rtype: `Symbol`
        """
        sh0=self.getShape()
        if isinstance(other, Symbol):
            arg1=other._arr
        else:
            arg1=numpy.asarray(other)
        sh1=arg1.shape
        r0=len(sh0)
        d0=Symbol._prod(sh0[:r0-axis_offset])
        d1=Symbol._prod(sh1[axis_offset:])
        d01=Symbol._prod(sh1[:axis_offset])
        # reshape (unlike in-place resize) raises on mismatching axes
        # instead of silently padding
        arg0_c=self._arr.reshape((d0,d01))
        arg1_c=arg1.reshape((d01,d1))
        out=numpy.zeros((d0,d1),object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[:,i1])
        out=out.reshape(sh0[:r0-axis_offset]+sh1[axis_offset:])
        return Symbol(out, dim=self.dim)

    def transposedTensorProduct(self, other, axis_offset):
        """
        Returns the tensor product of the transpose of this symbol and
        ``other``, contracting the first ``axis_offset`` axes of both.

        :param other: a `Symbol` or array-like object
        :param axis_offset: number of axes to contract
        :rtype: `Symbol`
        """
        sh0=self.getShape()
        if isinstance(other, Symbol):
            arg1=other._arr
        else:
            arg1=numpy.asarray(other)
        sh1=arg1.shape
        d0=Symbol._prod(sh0[axis_offset:])
        d1=Symbol._prod(sh1[axis_offset:])
        d01=Symbol._prod(sh1[:axis_offset])
        arg0_c=self._arr.reshape((d01,d0))
        arg1_c=arg1.reshape((d01,d1))
        out=numpy.zeros((d0,d1),object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(arg0_c[:,i0]*arg1_c[:,i1])
        out=out.reshape(sh0[axis_offset:]+sh1[axis_offset:])
        return Symbol(out, dim=self.dim)

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Returns the tensor product of this symbol and the transpose of
        ``other``, contracting the last ``axis_offset`` axes of both.

        :param other: a `Symbol` or array-like object
        :param axis_offset: number of axes to contract
        :rtype: `Symbol`
        """
        sh0=self.getShape()
        if isinstance(other, Symbol):
            arg1=other._arr
        else:
            arg1=numpy.asarray(other)
        sh1=arg1.shape
        r0=len(sh0)
        r1=len(sh1)
        d0=Symbol._prod(sh0[:r0-axis_offset])
        d1=Symbol._prod(sh1[:r1-axis_offset])
        d01=Symbol._prod(sh1[r1-axis_offset:])
        arg0_c=self._arr.reshape((d0,d01))
        arg1_c=arg1.reshape((d1,d01))
        out=numpy.zeros((d0,d1),object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[i1,:])
        out=out.reshape(sh0[:r0-axis_offset]+sh1[:r1-axis_offset])
        return Symbol(out, dim=self.dim)

    def trace(self, axis_offset):
        """
        Returns the trace over axes ``axis_offset`` and ``axis_offset+1``,
        i.e. the sum of the diagonal entries of those two (equally long)
        axes. The result has both axes removed.

        :rtype: `Symbol`
        """
        sh=self.getShape()
        s1=Symbol._prod(sh[:axis_offset])
        s2=Symbol._prod(sh[axis_offset+2:])
        arr_r=numpy.reshape(self._arr,(s1,sh[axis_offset],sh[axis_offset],s2))
        out=numpy.zeros([s1,s2],object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(sh[axis_offset]):
                    out[i1,i2]+=arr_r[i1,j,j,i2]
        out=out.reshape(sh[:axis_offset]+sh[axis_offset+2:])
        return Symbol(out, dim=self.dim)

    def transpose(self, axis_offset=None):
        """
        Returns the transpose of this symbol: the axes from ``axis_offset``
        onward are moved in front of the first ``axis_offset`` axes.

        :param axis_offset: split position; defaults to half the rank
        :rtype: `Symbol`
        """
        if axis_offset is None:
            axis_offset=int(self._arr.ndim/2)
        axes=list(range(axis_offset, self._arr.ndim))+list(range(0,axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes), dim=self.dim)

    def applyfunc(self, f, on_type=None):
        """
        Applies ``f`` to every element of this symbol and returns the results
        as a new `Symbol`.

        :param f: callable taking and returning a single element
        :param on_type: if given, only elements that are instances of this
                        type are passed to ``f``; others are copied unchanged
        :return: the new symbol, or ``None`` if this symbol is a scalar and
                 ``f`` returned ``None``
        """
        assert callable(f)
        if self._arr.ndim==0:
            el=self._arr.item()
            if on_type is None or isinstance(el,on_type):
                el=f(el)
            if el is None:
                return el
            if isinstance(el, (int, float, complex)) and not hasattr(el, '__array__'):
                # Symbol() rejects bare Python numbers, wrap in a 0-d array
                el=numpy.array(el, dtype=object)
            out=Symbol(el, dim=self.dim)
        else:
            out=numpy.empty(self.getShape(), dtype=object)
            for idx in numpy.ndindex(self.getShape()):
                if on_type is None or isinstance(self._arr[idx],on_type):
                    out[idx]=f(self._arr[idx])
                else:
                    out[idx]=self._arr[idx]
            out=Symbol(out, dim=self.dim)
        return out

    def expand(self):
        """
        Returns a copy with ``sympy.expand`` applied to all sympy elements.
        """
        return self.applyfunc(sympy.expand, sympy.Basic)

    def simplify(self):
        """
        Returns a copy with ``sympy.simplify`` applied to all sympy elements.
        """
        return self.applyfunc(sympy.simplify, sympy.Basic)

    def _sympy_(self):
        """
        Returns a copy of this symbol with ``sympy.sympify`` applied to
        every element.
        """
        return self.applyfunc(sympy.sympify)

    def _ensureShapeCompatible(self, other):
        """
        Checks for compatible shapes for binary operations.
        Raises TypeError if not compatible.
        """
        sh0=self.getShape()
        if isinstance(other, Symbol):
            sh1=other.getShape()
        elif isinstance(other, numpy.ndarray):
            sh1=other.shape
        elif isinstance(other, list):
            sh1=numpy.array(other).shape
        elif isinstance(other,int) or isinstance(other,float) or isinstance(other,sympy.Basic):
            sh1=()
        else:
            raise TypeError("Unsupported argument type '%s' for operation"%other.__class__.__name__)
        if not sh0==sh1 and not sh0==() and not sh1==():
            raise TypeError("Incompatible shapes for operation")

    @staticmethod
    def _prod(dims):
        """Returns the product of the integers in ``dims`` (1 if empty)."""
        p=1
        for i in dims:
            p*=i
        return p

    @staticmethod
    def _isZero(value):
        """
        Returns True if ``value`` is known to be zero. Objects with an
        ``is_zero`` attribute (sympy) are checked through it, where None
        (undecidable) counts as non-zero; other values are compared to 0.
        """
        if hasattr(value, 'is_zero'):
            return bool(value.is_zero)
        return value==0

    @staticmethod
    def _subsItem(item, old, new):
        """
        Returns ``item.subs(old, new)``, or ``item`` unchanged if it has no
        ``subs`` method (e.g. a plain number).
        """
        if hasattr(item, 'subs'):
            return item.subs(old, new)
        return item

    @staticmethod
    def _diffItem(item, sym, assumptions):
        """
        Returns the derivative of ``item`` with respect to ``sym``; elements
        without a ``diff`` method (plain numbers) are constants, giving 0.
        """
        if hasattr(item, 'diff'):
            return item.diff(sym, **assumptions)
        return 0

    @staticmethod
    def _coeffItem(item, term, expand):
        """
        Returns the coefficient of ``term`` in ``item``, mapping "not found"
        (None) and elements without a ``coeff`` method to 0.
        """
        if not hasattr(item, 'coeff'):
            return 0
        res=item.coeff(term, expand)
        return 0 if res is None else res

    @staticmethod
    def _symComp(sym):
        """
        Splits the name of a sympy symbol of the form '[name]_i_j' into
        ``('name', (i, j))``. Names without that form are returned unchanged
        together with an empty tuple.
        """
        n=sym.name
        a=n.split('[')
        if len(a)!=2:
            return n,()
        a=a[1].split(']')
        if len(a)!=2:
            return n,()
        name=a[0]
        comps=[int(i) for i in a[1].split('_')[1:]]
        return name,tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat
        if len(symbols)==0:
            raise ValueError("diff: no symbols given")
        count=len(symbols)
        for i in range(count):
            s = symbols[i]
            if not isinstance(s, Symbol):
                s=sympy.sympify(s)
            # look ahead by position (not by value) so that repeated symbols
            # such as (x, 2, x) keep their repetition counts
            next_s = None
            if i<count-1:
                next_s = symbols[i+1]
                if not isinstance(next_s, Symbol):
                    next_s=sympy.sympify(next_s)

            if isinstance(s, sympy.Integer):
                continue
            elif isinstance(s, Symbol) or isinstance(s, sympy.Symbol):
                # handle cases like (x, 3)
                if isinstance(next_s, sympy.Integer):
                    # yield (x, x, x)
                    for copy_s in repeat(s,int(next_s)):
                        yield copy_s
                else:
                    yield s
            else:
                yield s

    def __array__(self):
        return self._arr

    # unary/binary operations follow

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr, dim=self.dim)

    def __abs__(self):
        return Symbol(abs(self._arr), dim=self.dim)

    def __add__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr+other._arr, dim=self.dim)
        return Symbol(self._arr+other, dim=self.dim)

    def __radd__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr+self._arr, dim=self.dim)
        return Symbol(other+self._arr, dim=self.dim)

    def __sub__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr-other._arr, dim=self.dim)
        return Symbol(self._arr-other, dim=self.dim)

    def __rsub__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr-self._arr, dim=self.dim)
        return Symbol(other-self._arr, dim=self.dim)

    def __mul__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr*other._arr, dim=self.dim)
        return Symbol(self._arr*other, dim=self.dim)

    def __rmul__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr*self._arr, dim=self.dim)
        return Symbol(other*self._arr, dim=self.dim)

    def __div__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr/other._arr, dim=self.dim)
        return Symbol(self._arr/other, dim=self.dim)

    def __rdiv__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr/self._arr, dim=self.dim)
        return Symbol(other/self._arr, dim=self.dim)

    # Python 3 dispatches '/' to __truediv__/__rtruediv__
    __truediv__=__div__
    __rtruediv__=__rdiv__

    def __pow__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr**other._arr, dim=self.dim)
        return Symbol(self._arr**other, dim=self.dim)

    def __rpow__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr**self._arr, dim=self.dim)
        return Symbol(other**self._arr, dim=self.dim)
745 caltinay 3507
746    
def symbols(*names, **kwargs):
    """
    Emulates the behaviour of sympy.symbols.

    ``names[0]`` is either a string of names separated by whitespace and/or
    commas, or a list (or tuple) of names. One `Symbol` is created for every
    non-empty name, all with the shape given by the ``shape`` keyword
    (default: scalar). All remaining keywords are passed to the `Symbol`
    constructor.

    :return: ``None`` if there are no names, a single `Symbol` for one name,
             otherwise a tuple of `Symbol` objects
    """

    shape=kwargs.pop('shape', ())

    s = names[0]
    if not isinstance(s, (list, tuple)):
        import re
        # raw string: '\s' in a plain string literal is an invalid escape
        s = re.split(r'\s|,', s)
    res = []
    for t in s:
        # skip empty strings
        if not t:
            continue
        sym = Symbol(t, shape, **kwargs)
        res.append(sym)
    res = tuple(res)
    if len(res) == 0: # var('')
        res = None
    elif len(res) == 1: # var('x')
        res = res[0]
    # otherwise var('a b ...')
    return res
772    
def combineData(array, shape):
    """
    Assembles the (possibly nested) sequence ``array`` into a single value
    of the given ``shape``.

    A non-sequence value with scalar shape is returned as is. Otherwise, if
    any entry is an escript `Data` object, the result is a `Data` object on
    the function space of one of those entries; if not, it is a float numpy
    array.

    :param array: a value or nested list of values/`Data` objects
    :param shape: the shape of the result
    :raise ValueError: if the `Data` entries live on different domains
    """

    # array could just be a single value
    if not hasattr(array,'__len__') and shape==():
        return array

    from esys.escript import Data
    values=numpy.array(array) # for indexing

    # collect the function spaces and domains of all Data entries
    domains=set()
    spaces=set()
    for pos in numpy.ndindex(shape):
        entry=values[pos]
        if isinstance(entry, Data):
            spaces.add(entry.getFunctionSpace())
            domains.add(entry.getDomain())

    if len(domains)>1:
        reference=domains.pop()
        while domains:
            if reference!=domains.pop():
                raise ValueError("Mixing of domains not supported")

    if spaces:
        result=Data(0., shape, spaces.pop()) #FIXME: interpolate instead of using first?
    else:
        result=numpy.zeros(shape)
    # assigning entry-wise is much faster than accumulating n[idx]*unit-tensor
    for pos in numpy.ndindex(shape):
        entry=values[pos]
        is_zero_dim=hasattr(entry, "ndim") and entry.ndim==0
        result[pos]=float(entry) if is_zero_dim else entry
    return result
811    
812 caltinay 3517
class SymFunction(Symbol):
    """
    Base class for symbolic functions. An instance is a scalar symbol named
    after its concrete class, and it stores the arguments it was created
    with in ``self.args``.
    """
    def __init__(self, *args, **kwargs):
        """
        Initialises a new symbolic function object.

        :param args: the function arguments, stored in ``self.args``
        :param kwargs: passed on to the `Symbol` constructor (e.g. ``dim``)
        """
        super(SymFunction, self).__init__(self.__class__.__name__, **kwargs)
        self.args=args

    def __repr__(self):
        return self.name+"("+", ".join([str(a) for a in self.args])+")"

    def __str__(self):
        return self.name+"("+", ".join([str(a) for a in self.args])+")"

    def lambdarepr(self):
        """
        Returns 'name(arg1, arg2, ...)' where each argument is rendered with
        its own ``lambdarepr`` method if it has one, with sympy's lambdarepr
        printer for sympy objects, and with ``str`` otherwise.
        """
        from sympy.printing.lambdarepr import lambdarepr

        def _argrepr(a):
            if hasattr(a, 'lambdarepr'):
                return a.lambdarepr()
            if isinstance(a, sympy.Basic):
                return lambdarepr(a)
            # plain numbers (e.g. component indices)
            return str(a)
        return self.name+"("+", ".join([_argrepr(a) for a in self.args])+")"

    def atoms(self, *types):
        """
        Returns the atoms of all arguments of this function. Component
        symbols like '[x]_i_j' are reported by their main symbol 'x'.
        Arguments without an ``atoms`` method are treated as atoms
        themselves, as `Symbol.atoms` does for non-sympy elements.

        :param types: types to restrict result to
        :rtype: ``set``
        """
        s=set()
        for el in self.args:
            if not hasattr(el, 'atoms'):
                if len(types)==0 or type(el) in types:
                    s.add(el)
                continue
            for a in el.atoms(*types):
                # Symbol.atoms may return plain numbers without is_Symbol
                if getattr(a, 'is_Symbol', False):
                    n,c=Symbol._symComp(a)
                    s.add(sympy.Symbol(n))
                else:
                    s.add(a)
        return s

    def __neg__(self):
        res=self.__class__(*self.args)
        # the constructor call does not forward dim, so copy it explicitly
        res.dim=self.dim
        res._arr=-res._arr
        return res
848    

  ViewVC Help
Powered by ViewVC 1.1.26