/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Annotation of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3536 - (hide annotations)
Thu Jun 23 04:42:38 2011 UTC (8 years, 4 months ago) by caltinay
Original Path: branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py
File MIME type: text/x-python
File size: 26840 byte(s)
Rewrote Symbol.coeff() and added __array__ attribute so binary operations with
a numpy array work both ways.

1 caltinay 3507 # -*- coding: utf-8 -*-
2    
3     ########################################################
4     #
5     # Copyright (c) 2003-2010 by University of Queensland
6     # Earth Systems Science Computational Center (ESSCC)
7     # http://www.uq.edu.au/esscc
8     #
9     # Primary Business: Queensland, Australia
10     # Licensed under the Open Software License version 3.0
11     # http://www.opensource.org/licenses/osl-3.0.php
12     #
13     ########################################################
14    
15     __copyright__="""Copyright (c) 2003-2010 by University of Queensland
16     Earth Systems Science Computational Center (ESSCC)
17     http://www.uq.edu.au/esscc
18     Primary Business: Queensland, Australia"""
19     __license__="""Licensed under the Open Software License version 3.0
20     http://www.opensource.org/licenses/osl-3.0.php"""
21     __url__="https://launchpad.net/escript-finley"
22    
23     """
24     :var __author__: name of author
25     :var __copyright__: copyrights
26     :var __license__: licence agreement
27     :var __url__: url entry point on documentation
28     :var __version__: version
29     :var __date__: date of the version
30     """
31    
32     import numpy
33     import sympy
34    
35     __author__="Cihan Altinay"
36    
37    
class Symbol(object):
    """
    `Symbol` objects are placeholders for a single mathematic symbol, such as
    'x', or for arbitrarily complex mathematic expressions such as
    'c*x**4+alpha*exp(x)-2*sin(beta*x)', where 'alpha', 'beta', 'c', and 'x'
    are also `Symbol`s (the symbolic 'atoms' of the expression).

    With the help of the 'Evaluator' class these symbols and expressions can
    be resolved by substituting numeric values and/or escript `Data` objects
    for the atoms. To facilitate the use of `Data` objects a `Symbol` has a
    shape (and thus a rank) as well as a dimension (see constructor).
    `Symbol`s are useful to perform mathematic simplifications, compute
    derivatives and as coefficients for nonlinear PDEs which can be solved by
    the `NonlinearPDE` class.

    The components are stored in the numpy array ``_arr``; a rank 0 symbol
    uses a 0-d array.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialises a new `Symbol` object in one of three ways::

            u=Symbol('u')

        returns a scalar symbol by the name 'u'.

            a=Symbol('alpha', (4,3))

        returns a rank 2 symbol with the shape (4,3), whose elements are
        named '[alpha]_i_j' (with i=0..3, j=0..2).

            a,b,c=symbols('a,b,c')
            x=Symbol([[a+b,0,0],[0,b-c,0],[0,0,c-a]])

        returns a rank 2 symbol with the shape (3,3) whose elements are
        explicitly specified by numeric values and other symbols/expressions
        within a list or numpy array.

        A single plain number (e.g. ``Symbol(0)``) is accepted as well and
        gives a scalar symbol holding that number.

        The dimensionality of the symbol can be specified through the `dim`
        keyword. All other keywords are passed to the underlying symbolic
        library (currently sympy).

        :param args: initialisation arguments as described above
        :keyword dim: dimensionality of the new Symbol (default: 2)
        :type dim: ``int``
        :raise ValueError: if a name contains '[' or ']' or the rank exceeds 4
        :raise TypeError: for unsupported argument types or numbers of
                          arguments
        """
        import numbers
        self.dim=kwargs.pop('dim', 2)

        if len(args)==1:
            arg=args[0]
            if isinstance(arg, str):
                Symbol._checkName(arg)
                self._arr=numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__") or isinstance(arg, list):
                if isinstance(arg, list):
                    arg=numpy.array(arg)
                arr=arg.__array__()
                if arr.ndim>4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                res=numpy.empty(arr.shape, dtype=object)
                for idx in numpy.ndindex(arr.shape):
                    el=arr[idx]
                    # unwrap numpy scalars (and scalar Symbols) into their
                    # single underlying element
                    res[idx]=el.item() if hasattr(el, "item") else el
                self._arr=res
            elif isinstance(arg, sympy.Basic):
                self._arr=numpy.array(arg)
            elif isinstance(arg, numbers.Number):
                # plain number: store it as-is in a 0-d object array
                self._arr=numpy.empty((), dtype=object)
                self._arr[()]=arg
            else:
                raise TypeError("Unsupported argument type %s"%str(type(arg)))
        elif len(args)==2:
            if not isinstance(args[0], str):
                raise TypeError("First argument must be a string")
            if not isinstance(args[1], tuple):
                raise TypeError("Second argument must be a tuple")
            name, shape=args
            Symbol._checkName(name)
            if len(shape)>4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape)==0:
                self._arr=numpy.array(sympy.Symbol(name, **kwargs))
            else:
                # component names have the form '[name]_i_j' (see _symComp)
                self._arr=sympy.symarray(shape, '['+name+']')
        else:
            raise TypeError("Unsupported number of arguments")

        if self._arr.ndim==0:
            self.name=str(self._arr.item())
        else:
            self.name=str(self._arr.tolist())

    @staticmethod
    def _checkName(name):
        """
        Raises ValueError if `name` contains a square bracket. Brackets are
        reserved for the component names '[name]_i_j'.
        """
        if '[' in name or ']' in name:
            raise ValueError("Name must not contain '[' or ']'")

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        """
        Returns True if `other` is a `Symbol` of exactly the same type and
        shape whose components all compare equal to the components of this
        symbol. The dimensionality is not compared.
        """
        if type(self) is not type(other):
            return False
        if self.getRank()!=other.getRank():
            return False
        if self.getShape()!=other.getShape():
            return False
        return (self._arr==other._arr).all()

    def __ne__(self, other):
        # Python 2 does not derive '!=' from __eq__, so define it explicitly
        return not self.__eq__(other)

    def __getitem__(self, key):
        # returns the raw component(s) stored at `key`, not a Symbol
        return self._arr[key]

    def __setitem__(self, key, value):
        """
        Sets the component(s) at `key` to `value`.

        `value` may be a `Symbol` (a scalar one is stored as its single
        element, otherwise its shape must equal the shape of the selected
        components), a sympy expression, an array-like object (whose elements
        are sympified) or any other value accepted by ``sympy.sympify``.

        :raise ValueError: if a non-scalar `Symbol` does not match the shape
                           of the selected components
        """
        if isinstance(value, Symbol):
            if value.getRank()==0:
                self._arr[key]=value.item()
            elif hasattr(self._arr[key], "shape") and \
                    self._arr[key].shape==value.getShape():
                self._arr[key]=value._arr
            else:
                raise ValueError("Wrong shape of value")
        elif isinstance(value, sympy.Basic):
            self._arr[key]=value
        elif hasattr(value, "__array__"):
            src=value.__array__()
            converted=numpy.empty(src.shape, dtype=object)
            for idx in numpy.ndindex(src.shape):
                converted[idx]=sympy.sympify(src[idx])
            # keep the shape of the value so it fits the selected components
            self._arr[key]=converted.item() if converted.ndim==0 else converted
        else:
            self._arr[key]=sympy.sympify(value)

    def getDim(self):
        """
        Returns the spatial dimensionality of this symbol.

        :return: the symbol's spatial dimensionality
        :rtype: ``int``
        """
        return self.dim

    def getRank(self):
        """
        Returns the rank of this symbol.

        :return: the symbol's rank which is equal to the length of the shape.
        :rtype: ``int``
        """
        return self._arr.ndim

    def getShape(self):
        """
        Returns the shape of this symbol.

        :return: the symbol's shape
        :rtype: ``tuple`` of ``int``
        """
        return self._arr.shape

    def item(self, *args):
        """
        Returns an element of this symbol.
        This method behaves like the item() method of numpy.ndarray.
        If this is a scalar Symbol, no arguments are allowed and the only
        element in this Symbol is returned.
        Otherwise, 'args' specifies a flat index (a single int) or an
        nd-index (one int per axis, or a tuple of ints) and the element at
        that index is returned.

        :param args: index of item to be returned
        :return: the requested element
        :rtype: ``sympy.Symbol``, ``int``, or ``float``
        """
        return self._arr.item(*args)

    def atoms(self, *types):
        """
        Returns the atoms that form the current Symbol.

        By default, only objects that are truly atomic and cannot be divided
        into smaller pieces are returned: symbols, numbers, and number
        symbols like I and pi. It is possible to request atoms of any type,
        however.

        Note that if this symbol contains components such as [x]_i_j then
        only their main symbol 'x' is returned.

        :param types: types to restrict result to
        :return: list of atoms of specified type
        :rtype: ``set``
        """
        s=set()
        for el in self._arr.flat:
            if isinstance(el, sympy.Basic):
                for a in el.atoms(*types):
                    if a.is_Symbol:
                        n, _=Symbol._symComp(a)
                        s.add(sympy.Symbol(n))
                    else:
                        s.add(a)
            elif len(types)==0 or type(el) in types:
                s.add(el)
        return s

    def _sympystr_(self, printer):
        # returns the same string as lambdarepr().
        # NOTE(review): sympy's StrPrinter looks up a method called
        # '_sympystr'; confirm whether this name is actually picked up.
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string representation of this symbol produced with sympy's
        ``lambdarepr`` printer, in which component names like '[x]_0_1' are
        rewritten as 'x[0,1]'. For a non-scalar symbol the components are
        wrapped as ``combineData(<nested list>,<shape>)``.

        :rtype: ``str``
        """
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr=numpy.empty(self.getShape(), dtype=object)
        for idx, el in numpy.ndenumerate(self._arr):
            atoms=el.atoms(sympy.Symbol) if isinstance(el, sympy.Basic) else []
            # create a dictionary to convert names like [x]_0_0 to x[0,0]
            symdict={}
            for a in atoms:
                n, c=Symbol._symComp(a)
                if len(c)>0:
                    symdict[a.name]=n+'['+','.join(str(i) for i in c)+']'
                else:
                    symdict[a.name]=n
            s=lambdarepr(el)
            # replace longer names first so that e.g. '[x]_1_1' does not
            # clobber the beginning of '[x]_1_10'
            for key in sorted(symdict, key=len, reverse=True):
                s=s.replace(key, symdict[key])
            temp_arr[idx]=s
        if self.getRank()==0:
            return temp_arr.item()
        else:
            return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))

    def coeff(self, x, expand=True):
        """
        Returns the coefficient of the term "x" or 0 if there is no "x".

        If "x" is a scalar symbol then "x" is searched in all components of
        this symbol. Otherwise the shapes must match and the coefficients are
        checked component by component. Components that have no ``coeff``
        method (plain numbers) have coefficient 0.

        Example::

            x=Symbol('x', (2,2))
            y=3*x
            print(y.coeff(x))
            print(y.coeff(x[1,1]))

        will print::

            [[3 3]
             [3 3]]

            [[0 0]
             [0 3]]

        :param x: the term whose coefficients are to be found
        :type x: ``Symbol``, ``numpy.ndarray``, `list`
        :param expand: passed on to the components' ``coeff`` method
        :return: the coefficient(s) of the term
        :rtype: ``Symbol``
        :raise TypeError: if the shapes of this symbol and "x" do not match
        """
        self._ensureShapeCompatible(x)
        if hasattr(x, '__array__'):
            y=x.__array__()
        else:
            y=numpy.array(x)

        shape=self.getShape()
        result=numpy.zeros(shape, dtype=object)
        if y.ndim>0:
            if y.shape!=shape:
                raise TypeError("Incompatible shapes for binary operation")
            for idx in numpy.ndindex(y.shape):
                # components where x is 0 keep a coefficient of 0
                if y[idx]!=0:
                    result[idx]=Symbol._coeffItem(self._arr[idx], y[idx], expand)
        else:
            term=y.item()
            if term!=0:
                for idx in numpy.ndindex(shape):
                    result[idx]=Symbol._coeffItem(self._arr[idx], term, expand)
        return Symbol(result, dim=self.dim)

    @staticmethod
    def _coeffItem(item, term, expand):
        """
        Returns ``item.coeff(term, expand)`` with None mapped to 0, or 0 if
        `item` has no ``coeff`` method (e.g. a plain number).
        """
        if not hasattr(item, 'coeff'):
            return 0
        res=item.coeff(term, expand)
        return 0 if res is None else res

    def diff(self, *symbols, **assumptions):
        """
        Returns the derivative of this symbol with respect to the given
        symbols, applied one after the other.

        As in sympy, a symbol may be followed by an integer giving how often
        to differentiate with respect to it, e.g. ``u.diff(x, 2)``.
        Differentiating with respect to a scalar symbol keeps the shape.
        Differentiating with respect to a rank 1 `Symbol` of length n appends
        an axis of length n whose entry d is the derivative with respect to
        component d. Components without a ``diff`` method (plain numbers) have
        derivative 0.

        :param symbols: the symbols (optionally followed by orders)
        :param assumptions: passed on to the components' ``diff`` method
        :return: the derivative
        :rtype: `Symbol`
        :raise ValueError: if a `Symbol` of rank 2 or higher is given
        """
        result=Symbol(self._arr, dim=self.dim)
        for s in Symbol._symbolgen(*symbols):
            if isinstance(s, Symbol):
                if s.getRank()==0:
                    var=s._arr.item()
                    result=result.applyfunc(
                        lambda item, var=var: Symbol._diffItem(item, var, assumptions))
                elif s.getRank()==1:
                    length=s.getShape()[0]
                    # use the shape of the current result, which grows with
                    # every derivative taken w.r.t. a rank 1 symbol
                    shape=result.getShape()
                    out=result._arr.reshape(shape+(1,)).repeat(length, axis=result.getRank())
                    for d in range(length):
                        for idx in numpy.ndindex(shape):
                            index=idx+(d,)
                            out[index]=Symbol._diffItem(out[index], s[d], assumptions)
                    result=Symbol(out, dim=self.dim)
                else:
                    raise ValueError("diff: Only rank 0 and 1 supported")
            else:
                result=result.applyfunc(
                    lambda item, var=s: Symbol._diffItem(item, var, assumptions))
        return result

    @staticmethod
    def _diffItem(item, var, assumptions):
        """
        Returns ``item.diff(var, **assumptions)``, or 0 if `item` has no
        ``diff`` method (e.g. a plain number, whose derivative is 0).
        """
        if hasattr(item, 'diff'):
            return item.diff(var, **assumptions)
        return 0

    def grad(self, where=None):
        """
        Returns a symbol representing the spatial gradient of this symbol.

        The result has an additional trailing axis of length `getDim()`.
        Entry d of a component c is ``grad_n(c, d)``, or
        ``grad_n(c, d, where)`` if `where` is given.

        :param where: optional scalar symbol passed on to ``grad_n``
        :return: the gradient
        :rtype: `Symbol`
        :raise ValueError: if `where` is a non-scalar `Symbol`
        """
        if isinstance(where, Symbol):
            if where.getRank()>0:
                raise ValueError("grad: 'where' must be a scalar symbol")
            where=where._arr.item()

        # presumably the sibling 'functions' module of this package (Python 2
        # implicit relative import) -- confirm before porting to Python 3
        from functions import grad_n
        out=self._arr.copy().reshape(self.getShape()+(1,)).repeat(self.dim,axis=self.getRank())
        for d in range(self.dim):
            for idx in numpy.ndindex(self.getShape()):
                index=idx+(d,)
                if where is None:
                    out[index]=grad_n(out[index],d)
                else:
                    out[index]=grad_n(out[index],d,where)
        return Symbol(out, dim=self.dim)

    @staticmethod
    def _isZero(value):
        """
        Returns True if `value` is known to be zero. Symbolic values are
        checked via their ``is_zero`` attribute (None, i.e. "unknown", counts
        as not zero); plain numbers are compared with 0.
        """
        if hasattr(value, 'is_zero'):
            return bool(value.is_zero)
        return value==0

    def inverse(self):
        """
        Returns the inverse of this square matrix symbol, computed with
        explicit cofactor formulas.

        :return: the inverse matrix
        :rtype: `Symbol`
        :raise TypeError: if the rank is not 2 or the size is not 1, 2 or 3
        :raise ValueError: if the matrix is not square
        :raise ZeroDivisionError: if the determinant is known to be zero
        """
        if self.getRank()!=2:
            raise TypeError("inverse: Only rank 2 supported")
        s=self.getShape()
        if s[0]!=s[1]:
            raise ValueError("inverse: Only square shapes supported")
        out=numpy.zeros(s, dtype=object)
        arr=self._arr
        if s[0]==1:
            if Symbol._isZero(arr[0,0]):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            out[0,0]=1./arr[0,0]
        elif s[0]==2:
            A11=arr[0,0]
            A12=arr[0,1]
            A21=arr[1,0]
            A22=arr[1,1]
            D = A11*A22-A12*A21
            if Symbol._isZero(D):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]= A22*D
            out[1,0]=-A21*D
            out[0,1]=-A12*D
            out[1,1]= A11*D
        elif s[0]==3:
            A11=arr[0,0]
            A21=arr[1,0]
            A31=arr[2,0]
            A12=arr[0,1]
            A22=arr[1,1]
            A32=arr[2,1]
            A13=arr[0,2]
            A23=arr[1,2]
            A33=arr[2,2]
            D = A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
            if Symbol._isZero(D):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]=(A22*A33-A23*A32)*D
            out[1,0]=(A31*A23-A21*A33)*D
            out[2,0]=(A21*A32-A31*A22)*D
            out[0,1]=(A13*A32-A12*A33)*D
            out[1,1]=(A11*A33-A31*A13)*D
            out[2,1]=(A12*A31-A11*A32)*D
            out[0,2]=(A12*A23-A13*A22)*D
            out[1,2]=(A13*A21-A11*A23)*D
            out[2,2]=(A11*A22-A12*A21)*D
        else:
            raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
        return Symbol(out, dim=self.dim)

    def swap_axes(self, axis0, axis1):
        """
        Returns a new symbol with the axes `axis0` and `axis1` interchanged
        (see ``numpy.swapaxes``).
        """
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self.dim)

    @staticmethod
    def _prod(dims):
        """Returns the product of the integers in `dims` (1 if empty)."""
        p=1
        for i in dims:
            p*=i
        return p

    @staticmethod
    def _toArray(other):
        """Returns the components of `other` (a Symbol or array-like) as array."""
        if isinstance(other, Symbol):
            return other._arr
        return numpy.asarray(other)

    @staticmethod
    def _matmul(a, b):
        """
        Returns the matrix product of the 2-d arrays `a` (m x k) and `b`
        (k x n) as an object array; entries are summed with ``numpy.sum``.
        """
        out=numpy.zeros((a.shape[0], b.shape[1]), dtype=object)
        for i0 in range(a.shape[0]):
            for i1 in range(b.shape[1]):
                out[i0,i1]=numpy.sum(a[i0,:]*b[:,i1])
        return out

    def tensorProduct(self, other, axis_offset):
        """
        Returns the tensor product of this symbol and `other`, contracting
        the last `axis_offset` axes of this symbol with the first
        `axis_offset` axes of `other`.

        :param other: second operand
        :type other: `Symbol`, ``numpy.ndarray`` or ``list``
        :param axis_offset: number of axes to contract
        :type axis_offset: ``int``
        :return: result of shape sh0[:rank0-axis_offset]+sh1[axis_offset:]
        :rtype: `Symbol`
        :raise ValueError: if the contracted axes do not match in size
        """
        arg1=Symbol._toArray(other)
        sh0=self.getShape()
        sh1=arg1.shape
        keep0=sh0[:self.getRank()-axis_offset]
        d0=Symbol._prod(keep0)
        d1=Symbol._prod(sh1[axis_offset:])
        d01=Symbol._prod(sh1[:axis_offset])
        # reshape (unlike resize) raises instead of padding/truncating data
        out=Symbol._matmul(numpy.reshape(self._arr, (d0,d01)),
                           numpy.reshape(arg1, (d01,d1)))
        return Symbol(out.reshape(keep0+sh1[axis_offset:]), dim=self.dim)

    def transposedTensorProduct(self, other, axis_offset):
        """
        Returns the tensor product of the transpose of this symbol and
        `other`, contracting the first `axis_offset` axes of both operands.

        :param other: second operand
        :type other: `Symbol`, ``numpy.ndarray`` or ``list``
        :param axis_offset: number of axes to contract
        :type axis_offset: ``int``
        :return: result of shape sh0[axis_offset:]+sh1[axis_offset:]
        :rtype: `Symbol`
        :raise ValueError: if the contracted axes do not match in size
        """
        arg1=Symbol._toArray(other)
        sh0=self.getShape()
        sh1=arg1.shape
        d0=Symbol._prod(sh0[axis_offset:])
        d1=Symbol._prod(sh1[axis_offset:])
        d01=Symbol._prod(sh1[:axis_offset])
        a0=numpy.reshape(self._arr, (d01,d0))
        a1=numpy.reshape(arg1, (d01,d1))
        out=Symbol._matmul(a0.T, a1)
        return Symbol(out.reshape(sh0[axis_offset:]+sh1[axis_offset:]), dim=self.dim)

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Returns the tensor product of this symbol and the transpose of
        `other`, contracting the last `axis_offset` axes of both operands.

        :param other: second operand
        :type other: `Symbol`, ``numpy.ndarray`` or ``list``
        :param axis_offset: number of axes to contract
        :type axis_offset: ``int``
        :return: result of shape sh0[:rank0-axis_offset]+sh1[:rank1-axis_offset]
        :rtype: `Symbol`
        :raise ValueError: if the contracted axes do not match in size
        """
        arg1=Symbol._toArray(other)
        sh0=self.getShape()
        sh1=arg1.shape
        r1=arg1.ndim
        keep0=sh0[:self.getRank()-axis_offset]
        keep1=sh1[:r1-axis_offset]
        d0=Symbol._prod(keep0)
        d1=Symbol._prod(keep1)
        d01=Symbol._prod(sh1[r1-axis_offset:])
        a0=numpy.reshape(self._arr, (d0,d01))
        a1=numpy.reshape(arg1, (d1,d01))
        out=Symbol._matmul(a0, a1.T)
        return Symbol(out.reshape(keep0+keep1), dim=self.dim)

    def trace(self, axis_offset):
        """
        Returns the trace of this symbol over the axes `axis_offset` and
        `axis_offset`+1, which are removed from the shape of the result.

        :param axis_offset: index of the first of the two traced axes
        :type axis_offset: ``int``
        :rtype: `Symbol`
        """
        sh=self.getShape()
        n=sh[axis_offset]
        s1=Symbol._prod(sh[:axis_offset])
        s2=Symbol._prod(sh[axis_offset+2:])
        arr_r=numpy.reshape(self._arr,(s1,n,n,s2))
        out=numpy.zeros([s1,s2],object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(n):
                    out[i1,i2]+=arr_r[i1,j,j,i2]
        return Symbol(out.reshape(sh[:axis_offset]+sh[axis_offset+2:]), dim=self.dim)

    def transpose(self, axis_offset=None):
        """
        Returns the transpose of this symbol: the first `axis_offset` axes
        are moved behind the remaining ones. If `axis_offset` is None, half
        the rank (rounded down) is used.

        :rtype: `Symbol`
        """
        rank=self.getRank()
        if axis_offset is None:
            axis_offset=int(rank/2)
        axes=list(range(axis_offset, rank))+list(range(0, axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes), dim=self.dim)

    def applyfunc(self, f, on_type=None):
        """
        Applies the callable `f` to every component and returns the results
        as a new `Symbol`. If `on_type` is given, only components that are
        instances of `on_type` are passed to `f`; others are copied as-is.
        For a scalar symbol a result of None is returned as None.
        """
        assert callable(f)
        if self._arr.ndim==0:
            if on_type is None or isinstance(self._arr.item(),on_type):
                el=f(self._arr.item())
            else:
                el=self._arr.item()
            if el is not None:
                out=Symbol(el, dim=self.dim)
            else:
                return el
        else:
            out=numpy.empty(self.getShape(), dtype=object)
            for idx in numpy.ndindex(self.getShape()):
                if on_type is None or isinstance(self._arr[idx],on_type):
                    out[idx]=f(self._arr[idx])
                else:
                    out[idx]=self._arr[idx]
            out=Symbol(out, dim=self.dim)
        return out

    def simplify(self):
        """
        Returns a new symbol with ``sympy.simplify`` applied to every sympy
        component; other components are copied unchanged.
        """
        return self.applyfunc(sympy.simplify, sympy.Basic)

    def _sympy_(self):
        """
        Returns a new symbol with ``sympy.sympify`` applied to every
        component (presumably used as sympify's conversion hook).
        """
        return self.applyfunc(sympy.sympify)

    def _ensureShapeCompatible(self, other):
        """
        Checks for compatible shapes for binary operations: the shapes must
        be equal or one of the operands must be scalar.
        Raises TypeError if not compatible or if `other` is not a `Symbol`,
        numpy array, list, number or sympy expression.
        """
        import numbers
        sh0=self.getShape()
        if isinstance(other, Symbol):
            sh1=other.getShape()
        elif isinstance(other, numpy.ndarray):
            sh1=other.shape
        elif isinstance(other, list):
            sh1=numpy.array(other).shape
        elif isinstance(other, numbers.Number) or isinstance(other, sympy.Basic):
            # numbers.Number also covers long, complex and numpy scalars
            sh1=()
        else:
            raise TypeError("Unsupported argument type '%s' for binary operation"%other.__class__.__name__)
        if sh0!=sh1 and sh0!=() and sh1!=():
            raise TypeError("Incompatible shapes for binary operation")

    @staticmethod
    def _symComp(sym):
        """
        Splits a component name of the form '[name]_i_j' into
        ('name', (i, j)). Other names are returned unchanged with ().
        """
        n=sym.name
        a=n.split('[')
        if len(a)!=2:
            return n,()
        a=a[1].split(']')
        if len(a)!=2:
            return n,()
        name=a[0]
        comps=[int(i) for i in a[1].split('_')[1:]]
        return name,tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        A symbol directly followed by an integer n is yielded n times,
        integers themselves are skipped.

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat
        converted=[s if isinstance(s, Symbol) else sympy.sympify(s) for s in symbols]
        for i, s in enumerate(converted):
            if isinstance(s, sympy.Integer):
                continue
            # look at the following argument by position (not by value) so
            # that repeated symbols like (x, 2, x) are handled correctly
            next_s=converted[i+1] if i+1<len(converted) else None
            if (isinstance(s, Symbol) or isinstance(s, sympy.Symbol)) and \
                    isinstance(next_s, sympy.Integer):
                # handle cases like (x, 3) -> (x, x, x)
                for copy_s in repeat(s, int(next_s)):
                    yield copy_s
            else:
                yield s

    def __array__(self, dtype=None, copy=None):
        """
        Returns the components as a numpy array so that a `Symbol` can be
        used with numpy functions and in binary operations with numpy arrays.

        :param dtype: optional dtype to convert the components to
        :param copy: if True, a copy is always returned (numpy>=2 protocol)
        """
        arr=self._arr if dtype is None else self._arr.astype(dtype)
        if copy and arr is self._arr:
            arr=arr.copy()
        return arr

    # unary/binary operations follow

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr, dim=self.dim)

    def __abs__(self):
        return Symbol(abs(self._arr), dim=self.dim)

    def __add__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr+other._arr, dim=self.dim)
        return Symbol(self._arr+other, dim=self.dim)

    def __radd__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr+self._arr, dim=self.dim)
        return Symbol(other+self._arr, dim=self.dim)

    def __sub__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr-other._arr, dim=self.dim)
        return Symbol(self._arr-other, dim=self.dim)

    def __rsub__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr-self._arr, dim=self.dim)
        return Symbol(other-self._arr, dim=self.dim)

    def __mul__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr*other._arr, dim=self.dim)
        return Symbol(self._arr*other, dim=self.dim)

    def __rmul__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr*self._arr, dim=self.dim)
        return Symbol(other*self._arr, dim=self.dim)

    def __div__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr/other._arr, dim=self.dim)
        return Symbol(self._arr/other, dim=self.dim)

    def __rdiv__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr/self._arr, dim=self.dim)
        return Symbol(other/self._arr, dim=self.dim)

    # '/' maps to __truediv__ in Python 3 and with 'from __future__ import division'
    __truediv__=__div__
    __rtruediv__=__rdiv__

    def __pow__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr**other._arr, dim=self.dim)
        return Symbol(self._arr**other, dim=self.dim)

    def __rpow__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr**self._arr, dim=self.dim)
        return Symbol(other**self._arr, dim=self.dim)
696 caltinay 3507
697    
def symbols(*names, **kwargs):
    """
    Emulates the behaviour of sympy.symbols.

    Creates one `Symbol` for every name in ``names[0]``, which is either a
    string of names separated by whitespace and/or commas, or a list/tuple
    of names. Empty names are skipped. The optional ``shape`` keyword
    (default: ``()``) is used for all symbols; all other keywords (e.g.
    ``dim``) are passed on to the `Symbol` constructor.

    :return: None if no names were given, a single `Symbol` for exactly one
             name, otherwise a tuple of `Symbol` objects
    """
    import re

    shape=kwargs.pop('shape', ())
    if not names:
        return None

    s=names[0]
    if not isinstance(s, (list, tuple)):
        # raw string: '\s' is not a valid escape in a normal string literal
        s=re.split(r'\s|,', s)
    res=tuple(Symbol(t, shape, **kwargs) for t in s if t)
    if len(res)==0: # var('')
        return None
    if len(res)==1: # var('x')
        return res[0]
    # otherwise var('a b ...')
    return res
723    
def combineData(array, shape):
    """
    Assembles the (possibly nested) list `array` of component values into a
    single value of the given `shape`.

    If `array` has no length and `shape` is ``()`` it is returned unchanged.
    Otherwise, if any component is an escript `Data` object the result is a
    `Data` object on the function space of one of those components, and all
    `Data` components must be defined on one domain. If no component is
    `Data`, the result is a numpy array of floats.

    :param array: single value or nested list of component values
    :param shape: shape of the result
    :type shape: ``tuple``
    :raise ValueError: if the `Data` components live on different domains
    """
    # the result may just be a single value
    if shape==() and not hasattr(array, '__len__'):
        return array

    from esys.escript import Data
    components=numpy.array(array) # for indexing

    # collect the function spaces and domains of all Data components
    domains=set()
    spaces=set()
    for idx in numpy.ndindex(shape):
        item=components[idx]
        if isinstance(item, Data):
            spaces.add(item.getFunctionSpace())
            domains.add(item.getDomain())

    if len(domains)>1:
        reference=domains.pop()
        if any(reference!=other for other in domains):
            raise ValueError("Mixing of domains not supported")

    if spaces:
        result=Data(0., shape, spaces.pop()) #FIXME: interpolate instead of using first?
    else:
        result=numpy.zeros(shape)

    # assigning component by component is much faster than accumulating
    # products with unit tensors
    for idx in numpy.ndindex(shape):
        item=components[idx]
        is_scalar_array=hasattr(item, "ndim") and item.ndim==0
        result[idx]=float(item) if is_scalar_array else item
    return result
762    
763 caltinay 3517
class SymFunction(Symbol):
    """
    Base class for symbolic functions applied to arguments.

    The object itself is a scalar `Symbol` named after its (sub)class; the
    function arguments are kept in ``args`` and are used for printing and
    for determining the atoms.
    """
    def __init__(self, *args, **kwargs):
        """
        Initialises a new symbolic function object.

        :param args: the arguments of the function
        :param kwargs: passed on to the `Symbol` constructor (e.g. ``dim``)
        """
        super(SymFunction, self).__init__(self.__class__.__name__, **kwargs)
        self.args=args

    def __repr__(self):
        return self.name+"("+", ".join([str(a) for a in self.args])+")"

    def __str__(self):
        return self.name+"("+", ".join([str(a) for a in self.args])+")"

    def lambdarepr(self):
        return self.name+"("+", ".join([a.lambdarepr() for a in self.args])+")"

    def atoms(self, *types):
        """
        Returns the atoms of all function arguments. Component symbols like
        '[x]_i_j' are reported by their main symbol 'x'; arguments without an
        ``atoms`` method (e.g. plain numbers) count as atoms themselves.

        :param types: types to restrict result to
        :rtype: ``set``
        """
        s=set()
        for el in self.args:
            if not hasattr(el, 'atoms'):
                if len(types)==0 or type(el) in types:
                    s.add(el)
                continue
            for a in el.atoms(*types):
                # plain numbers returned by Symbol.atoms have no is_Symbol
                if getattr(a, 'is_Symbol', False):
                    n, _=Symbol._symComp(a)
                    s.add(sympy.Symbol(n))
                else:
                    s.add(a)
        return s

    def __neg__(self):
        res=self.__class__(*self.args)
        # the constructor call above does not forward 'dim', so copy it over
        res.dim=self.dim
        res._arr=-res._arr
        return res
799    

  ViewVC Help
Powered by ViewVC 1.1.26