/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Annotation of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory | Revision Log


Revision 3532 - (hide annotations)
Mon Jun 20 04:14:42 2011 UTC (8 years, 4 months ago) by caltinay
Original Path: branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py
File MIME type: text/x-python
File size: 22246 byte(s)
Added simplify method to Symbol class.

1 caltinay 3507 # -*- coding: utf-8 -*-
2    
3     ########################################################
4     #
5     # Copyright (c) 2003-2010 by University of Queensland
6     # Earth Systems Science Computational Center (ESSCC)
7     # http://www.uq.edu.au/esscc
8     #
9     # Primary Business: Queensland, Australia
10     # Licensed under the Open Software License version 3.0
11     # http://www.opensource.org/licenses/osl-3.0.php
12     #
13     ########################################################
14    
15     __copyright__="""Copyright (c) 2003-2010 by University of Queensland
16     Earth Systems Science Computational Center (ESSCC)
17     http://www.uq.edu.au/esscc
18     Primary Business: Queensland, Australia"""
19     __license__="""Licensed under the Open Software License version 3.0
20     http://www.opensource.org/licenses/osl-3.0.php"""
21     __url__="https://launchpad.net/escript-finley"
22    
23     """
24     :var __author__: name of author
25     :var __copyright__: copyrights
26     :var __license__: licence agreement
27     :var __url__: url entry point on documentation
28     :var __version__: version
29     :var __date__: date of the version
30     """
31    
32     import numpy
33     import sympy
34    
35     __author__="Cihan Altinay"
36    
37    
class Symbol(object):
    """
    Array of symbolic expressions stored in a numpy object array.

    The data is held in ``self._arr`` (rank 0 to 4); its entries are normally
    sympy expressions. ``self.dim`` is the number of spatial dimensions used
    by `grad` and ``self.name`` is a string rendering of the data taken at
    construction time.
    """

    def __init__(self, *args, **kwargs):
        """
        Initializes a new Symbol object.

        Accepted call forms:

        - ``Symbol(name)``: scalar symbol called ``name``
        - ``Symbol(obj)``: copy of any object providing ``__array__``
        - ``Symbol(list_or_sympy_expression)``
        - ``Symbol(name, shape)``: named symbol with the given shape tuple

        :keyword dim: value stored as ``self.dim`` (default 2); all other
                      keyword arguments are passed on to ``sympy.Symbol``
                      when a scalar symbol is created
        :raise TypeError: for unsupported argument types or counts, or if
                          the name contains '[' or ']'
        :raise ValueError: if the rank would exceed 4
        """
        self.dim = kwargs.pop('dim', 2)

        if len(args) == 1:
            arg = args[0]
            if isinstance(arg, str):
                if '[' in arg or ']' in arg:
                    raise TypeError("Name must not contain '[' or ']'")
                self._arr = numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__"):
                arr = arg.__array__()
                if len(arr.shape) > 4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                self._arr = arr.copy()
            elif isinstance(arg, list) or isinstance(arg, sympy.Basic):
                self._arr = numpy.array(arg)
            else:
                raise TypeError("Unsupported argument type %s" % str(type(arg)))
        elif len(args) == 2:
            name, shape = args
            if not isinstance(name, str):
                raise TypeError("First argument must be a string")
            if '[' in name or ']' in name:
                raise TypeError("Name must not contain '[' or ']'")
            if not isinstance(shape, tuple):
                raise TypeError("Second argument must be a tuple")
            if len(shape) > 4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape) == 0:
                self._arr = numpy.array(sympy.Symbol(name, **kwargs))
            else:
                # Components are named '[name]_i_j...'; _symComp() relies on
                # the brackets to split such a name again.
                # NOTE(review): argument order (shape, prefix) is kept from
                # the original; newer sympy releases appear to expect
                # (prefix, shape) - confirm against the sympy version in use.
                self._arr = sympy.symarray(shape, '[' + name + ']')
        else:
            raise TypeError("Unsupported number of arguments")

        if self._arr.ndim == 0:
            self.name = str(self._arr.item())
        else:
            self.name = str(self._arr.tolist())

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        """
        Two symbols are equal if they have the same type, rank and shape and
        all components compare equal.
        """
        if type(self) is not type(other):
            return False
        if self.getRank() != other.getRank():
            return False
        if self.getShape() != other.getShape():
            return False
        return (self._arr == other._arr).all()

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this method two
        # equal symbols would still compare as "not equal".
        return not self.__eq__(other)

    def __getitem__(self, key):
        return self._arr[key]

    def __setitem__(self, key, value):
        """
        Assigns ``value`` to the component(s) selected by ``key``.

        Symbol values are unwrapped so that the underlying expressions (not
        the Symbol wrapper) are stored. Array-like values are converted
        element by element with ``sympy.sympify``.

        :raise ValueError: if a non-scalar Symbol does not match the shape
                           of the selected target
        """
        if isinstance(value, Symbol):
            if value.getRank() == 0:
                self._arr[key] = value._arr.item()
            elif hasattr(self._arr[key], "shape"):
                if self._arr[key].shape == value.getShape():
                    self._arr[key] = value._arr
                else:
                    raise ValueError("Wrong shape of value")
            else:
                raise ValueError("Wrong shape of value")
        elif isinstance(value, sympy.Basic):
            self._arr[key] = value
        elif hasattr(value, "__array__"):
            arr = numpy.asarray(value)
            if arr.ndim == 0:
                self._arr[key] = sympy.sympify(arr.item())
            else:
                # build an object array explicitly; map() is lazy on
                # Python 3 and would lose the shape in any case
                converted = numpy.empty(arr.shape, dtype=object)
                for idx in numpy.ndindex(arr.shape):
                    converted[idx] = sympy.sympify(arr[idx])
                self._arr[key] = converted
        else:
            self._arr[key] = sympy.sympify(value)

    def getRank(self):
        """Returns the number of array dimensions."""
        return self._arr.ndim

    def getShape(self):
        """Returns the shape tuple of the underlying array."""
        return self._arr.shape

    def atoms(self, *types):
        """
        Returns the set of atoms of all sympy components. Component symbols
        such as ``[x]_0_1`` are reported as the plain symbol ``x``.
        """
        s = set()
        for el in self._arr.flat:
            if isinstance(el, sympy.Basic):
                for a in el.atoms(*types):
                    if a.is_Symbol:
                        n, c = Symbol._symComp(a)
                        s.add(sympy.Symbol(n))
                    else:
                        s.add(a)
            else:
                # TODO: Numbers?
                pass
        return s

    def _sympystr_(self, printer):
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string for this symbol in which component names such as
        ``[x]_0_0`` are written as ``x[0,0]``. Non-scalar symbols are wrapped
        in a ``combineData(...)`` call.
        """
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr = numpy.empty(self.getShape(), dtype=object)
        for idx, el in numpy.ndenumerate(self._arr):
            atoms = el.atoms(sympy.Symbol) if isinstance(el, sympy.Basic) else []
            # create a dictionary to convert names like [x]_0_0 to x[0,0]
            symdict = {}
            for a in atoms:
                n, c = Symbol._symComp(a)
                if len(c) > 0:
                    symstr = n + '[' + ','.join([str(i) for i in c]) + ']'
                else:
                    symstr = n
                symdict[a.name] = symstr
            s = lambdarepr(el)
            # longest names first: '[x]_1' is a prefix of '[x]_10' and would
            # otherwise corrupt it
            for key in sorted(symdict, key=len, reverse=True):
                s = s.replace(key, symdict[key])
            temp_arr[idx] = s
        if self.getRank() == 0:
            return temp_arr.item()
        return 'combineData(%s,%s)' % (str(temp_arr.tolist()).replace("'", ""),
                                       str(self.getShape()))

    def coeff(self, x, expand=True):
        """
        Applies ``coeff(x, expand)`` to every component. Where ``x`` is zero
        or no coefficient is found, the result is 0.

        :raise TypeError: if the shape of ``x`` is not compatible
        """
        self._ensureShapeCompatible(x)
        result = Symbol(self._arr, dim=self.dim)
        if isinstance(x, Symbol):
            if x.getRank() > 0:
                a = result._arr.flat
                b = x._arr.flat
                for idx in range(len(a)):
                    s = next(b)
                    if s == 0:
                        a[idx] = 0
                    else:
                        a[idx] = a[idx].coeff(s, expand)
            else:
                if x._arr.item() == 0:
                    result = Symbol(numpy.zeros(self.getShape()), dim=self.dim)
                else:
                    coeff_item = lambda item: item.coeff(x._arr.item(), expand)
                    result = result.applyfunc(coeff_item)
        elif x == 0:
            result = Symbol(numpy.zeros(self.getShape()), dim=self.dim)
        else:
            coeff_item = lambda item: item.coeff(x, expand)
            result = result.applyfunc(coeff_item)

        # replace None by 0 (applyfunc returns None for a scalar None result)
        if result is None:
            return 0
        a = result._arr.flat
        for idx in range(len(a)):
            if a[idx] is None:
                a[idx] = 0
        return result

    def diff(self, *symbols, **assumptions):
        """
        Differentiates with respect to the given symbols in turn. A count
        may follow a symbol, e.g. ``diff(x, 3)``.

        A scalar symbol keeps the shape; a rank-1 Symbol appends one axis
        whose length is the number of its components.

        :raise ValueError: for Symbol arguments of rank 2 or higher
        """
        result = Symbol(self._arr, dim=self.dim)
        for s in Symbol._symbolgen(*symbols):
            if isinstance(s, Symbol):
                if s.getRank() == 0:
                    diff_item = lambda item: item.diff(s._arr.item(), **assumptions)
                    result = result.applyfunc(diff_item)
                elif s.getRank() == 1:
                    dim = s.getShape()[0]
                    # use the shape of the running result (not of self) so
                    # that several rank-1 derivatives can be chained
                    sh = result.getShape()
                    out = result._arr.copy().reshape(sh + (1,)).repeat(dim, axis=len(sh))
                    for d in range(dim):
                        for idx in numpy.ndindex(sh):
                            index = idx + (d,)
                            out[index] = out[index].diff(s[d], **assumptions)
                    result = Symbol(out, dim=self.dim)
                else:
                    raise ValueError("diff: Only rank 0 and 1 supported")
            else:
                diff_item = lambda item: item.diff(s, **assumptions)
                result = result.applyfunc(diff_item)
        return result

    def grad(self, where=None):
        """
        Returns a Symbol with one extra trailing axis of length ``self.dim``
        whose entries are ``grad_n(component, d[, where])``.

        :raise ValueError: if ``where`` is a non-scalar Symbol
        """
        if isinstance(where, Symbol):
            if where.getRank() > 0:
                raise ValueError("grad: 'where' must be a scalar symbol")
            where = where._arr.item()

        # NOTE(review): implicit relative import kept from the original;
        # presumably the sibling module 'functions' - confirm for Python 3.
        from functions import grad_n
        out = self._arr.copy().reshape(self.getShape() + (1,)).repeat(self.dim, axis=self.getRank())
        for d in range(self.dim):
            for idx in numpy.ndindex(self.getShape()):
                index = idx + (d,)
                if where is None:
                    out[index] = grad_n(out[index], d)
                else:
                    out[index] = grad_n(out[index], d, where)
        return Symbol(out, dim=self.dim)

    def inverse(self):
        """
        Returns the inverse of a square rank-2 symbol of size 1, 2 or 3.

        :raise ValueError: if the symbol is not rank 2 or not square
        :raise ZeroDivisionError: if the determinant is known to be zero
        :raise TypeError: for matrix sizes other than 1, 2, 3
        """
        if not self.getRank() == 2:
            raise ValueError("inverse: Only rank 2 supported")
        s = self.getShape()
        if not s[0] == s[1]:
            raise ValueError("inverse: Only square shapes supported")
        # plain 'object' instead of the removed numpy.object alias
        out = numpy.zeros(s, object)
        arr = self._arr
        if s[0] == 1:
            if arr[0, 0].is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            out[0, 0] = 1. / arr[0, 0]
        elif s[0] == 2:
            A11 = arr[0, 0]
            A12 = arr[0, 1]
            A21 = arr[1, 0]
            A22 = arr[1, 1]
            D = A11*A22 - A12*A21
            if D.is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D = 1. / D
            out[0, 0] = A22*D
            out[1, 0] = -A21*D
            out[0, 1] = -A12*D
            out[1, 1] = A11*D
        elif s[0] == 3:
            A11 = arr[0, 0]
            A21 = arr[1, 0]
            A31 = arr[2, 0]
            A12 = arr[0, 1]
            A22 = arr[1, 1]
            A32 = arr[2, 1]
            A13 = arr[0, 2]
            A23 = arr[1, 2]
            A33 = arr[2, 2]
            D = A11*(A22*A33-A23*A32) + A12*(A31*A23-A21*A33) + A13*(A21*A32-A31*A22)
            if D.is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D = 1. / D
            out[0, 0] = (A22*A33-A23*A32)*D
            out[1, 0] = (A31*A23-A21*A33)*D
            out[2, 0] = (A21*A32-A31*A22)*D
            out[0, 1] = (A13*A32-A12*A33)*D
            out[1, 1] = (A11*A33-A31*A13)*D
            out[2, 1] = (A12*A31-A11*A32)*D
            out[0, 2] = (A12*A23-A13*A22)*D
            out[1, 2] = (A13*A21-A11*A23)*D
            out[2, 2] = (A11*A22-A12*A21)*D
        else:
            raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
        return Symbol(out, dim=self.dim)

    def swap_axes(self, axis0, axis1):
        """Returns a new Symbol with the two given axes interchanged."""
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self.dim)

    def tensorProduct(self, other, axis_offset):
        """
        Contracts the last ``axis_offset`` axes of this symbol with the first
        ``axis_offset`` axes of ``other`` (a Symbol or numpy array).
        """
        arg0_c = self._arr.copy()
        sh0 = self.getShape()
        if isinstance(other, Symbol):
            arg1_c = other._arr.copy()
            sh1 = other.getShape()
        else:
            arg1_c = other.copy()
            sh1 = other.shape
        d0, d1, d01 = 1, 1, 1
        for i in sh0[:self._arr.ndim-axis_offset]: d0 *= i
        for i in sh1[axis_offset:]: d1 *= i
        for i in sh1[:axis_offset]: d01 *= i
        # flatten both operands to matrices so the contraction is a plain
        # matrix product
        arg0_c.resize((d0, d01))
        arg1_c.resize((d01, d1))
        out = numpy.zeros((d0, d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0, i1] = numpy.sum(arg0_c[i0, :]*arg1_c[:, i1])
        out.resize(sh0[:self._arr.ndim-axis_offset] + sh1[axis_offset:])
        return Symbol(out, dim=self.dim)

    def transposedTensorProduct(self, other, axis_offset):
        """
        Contracts the first ``axis_offset`` axes of this symbol with the
        first ``axis_offset`` axes of ``other`` (a Symbol or numpy array).
        """
        arg0_c = self._arr.copy()
        sh0 = self.getShape()
        if isinstance(other, Symbol):
            arg1_c = other._arr.copy()
            sh1 = other.getShape()
        else:
            arg1_c = other.copy()
            sh1 = other.shape
        d0, d1, d01 = 1, 1, 1
        for i in sh0[axis_offset:]: d0 *= i
        for i in sh1[axis_offset:]: d1 *= i
        for i in sh1[:axis_offset]: d01 *= i
        arg0_c.resize((d01, d0))
        arg1_c.resize((d01, d1))
        out = numpy.zeros((d0, d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0, i1] = numpy.sum(arg0_c[:, i0]*arg1_c[:, i1])
        out.resize(sh0[axis_offset:] + sh1[axis_offset:])
        return Symbol(out, dim=self.dim)

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Contracts the last ``axis_offset`` axes of this symbol with the last
        ``axis_offset`` axes of ``other`` (a Symbol or numpy array).
        """
        arg0_c = self._arr.copy()
        sh0 = self.getShape()
        if isinstance(other, Symbol):
            arg1_c = other._arr.copy()
            sh1 = other.getShape()
            r1 = other.getRank()
        else:
            arg1_c = other.copy()
            sh1 = other.shape
            r1 = other.ndim
        d0, d1, d01 = 1, 1, 1
        for i in sh0[:self._arr.ndim-axis_offset]: d0 *= i
        for i in sh1[:r1-axis_offset]: d1 *= i
        for i in sh1[r1-axis_offset:]: d01 *= i
        arg0_c.resize((d0, d01))
        arg1_c.resize((d1, d01))
        out = numpy.zeros((d0, d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0, i1] = numpy.sum(arg0_c[i0, :]*arg1_c[i1, :])
        out.resize(sh0[:self._arr.ndim-axis_offset] + sh1[:r1-axis_offset])
        return Symbol(out, dim=self.dim)

    def trace(self, axis_offset):
        """
        Sums the diagonal over axes ``axis_offset`` and ``axis_offset+1``
        and returns a Symbol with those two axes removed.
        """
        sh = self.getShape()
        s1 = 1
        for i in range(axis_offset): s1 *= sh[i]
        s2 = 1
        for i in range(axis_offset+2, len(sh)): s2 *= sh[i]
        arr_r = numpy.reshape(self._arr, (s1, sh[axis_offset], sh[axis_offset], s2))
        out = numpy.zeros([s1, s2], object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(sh[axis_offset]):
                    out[i1, i2] += arr_r[i1, j, j, i2]
        out.resize(sh[:axis_offset] + sh[axis_offset+2:])
        return Symbol(out, dim=self.dim)

    def transpose(self, axis_offset):
        """
        Moves the first ``axis_offset`` axes to the end. If ``axis_offset``
        is None, half the rank (rounded down) is used.
        """
        if axis_offset is None:
            axis_offset = self._arr.ndim // 2
        # list() so the concatenation also works with Python 3 ranges
        axes = list(range(axis_offset, self._arr.ndim)) + list(range(0, axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes), dim=self.dim)

    def applyfunc(self, f, on_type=None):
        """
        Applies ``f`` to every component and returns the results as a new
        Symbol. If ``on_type`` is given, components that are not instances
        of it are copied unchanged.

        For a scalar symbol whose result is None, None is returned instead
        of a Symbol.
        """
        assert callable(f)
        if self._arr.ndim == 0:
            item = self._arr.item()
            if on_type is None or isinstance(item, on_type):
                el = f(item)
            else:
                el = item
            if el is None:
                return el
            return Symbol(el, dim=self.dim)

        out = numpy.empty(self.getShape(), dtype=object)
        for idx in numpy.ndindex(self.getShape()):
            if on_type is None or isinstance(self._arr[idx], on_type):
                out[idx] = f(self._arr[idx])
            else:
                out[idx] = self._arr[idx]
        return Symbol(out, dim=self.dim)

    def simplify(self):
        """Applies ``sympy.simplify`` to all sympy components."""
        return self.applyfunc(sympy.simplify, sympy.Basic)

    def _sympy_(self):
        return self.applyfunc(sympy.sympify)

    def _ensureShapeCompatible(self, other):
        """
        Checks for compatible shapes for binary operations.
        Raises TypeError if not compatible.
        """
        sh0 = self.getShape()
        if isinstance(other, Symbol):
            sh1 = other.getShape()
        elif isinstance(other, numpy.ndarray):
            sh1 = other.shape
        elif isinstance(other, int) or isinstance(other, float) or isinstance(other, sympy.Basic):
            sh1 = ()
        else:
            raise TypeError("Unsupported argument type '%s' for binary operation" % other.__class__.__name__)
        # equal shapes are fine, and a scalar combines with anything
        if not sh0 == sh1 and not sh0 == () and not sh1 == ():
            raise TypeError("Incompatible shapes for binary operation")

    def _operand(self, other):
        """
        Validates ``other`` for a binary operation and returns the object to
        combine with ``self._arr`` (the raw array for Symbol operands).
        """
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return other._arr
        return other

    @staticmethod
    def _symComp(sym):
        """
        Splits a component symbol name: ``[x]_0_1`` gives ``('x', (0, 1))``.
        Names without exactly one '[' and one ']' are returned unchanged
        with an empty component tuple.
        """
        n = sym.name
        a = n.split('[')
        if len(a) != 2:
            return n, ()
        a = a[1].split(']')
        if len(a) != 2:
            return n, ()
        name = a[0]
        comps = [int(i) for i in a[1].split('_')[1:]]
        return name, tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat
        converted = [s if isinstance(s, Symbol) else sympy.sympify(s)
                     for s in symbols]
        for i, s in enumerate(converted):
            if isinstance(s, sympy.Integer):
                # counts are consumed together with the preceding symbol
                continue
            # look ahead by position; comparing values with the last entry
            # would lose the count in cases like (x, 3, x)
            next_s = converted[i+1] if i+1 < len(converted) else None
            if (isinstance(s, Symbol) or isinstance(s, sympy.Symbol)) \
                    and isinstance(next_s, sympy.Integer):
                # handle cases like (x, 3): yield (x, x, x)
                for copy_s in repeat(s, int(next_s)):
                    yield copy_s
            else:
                yield s

    # unary/binary operations follow

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr, dim=self.dim)

    def __abs__(self):
        return Symbol(abs(self._arr), dim=self.dim)

    def __add__(self, other):
        return Symbol(self._arr + self._operand(other), dim=self.dim)

    def __radd__(self, other):
        return Symbol(self._operand(other) + self._arr, dim=self.dim)

    def __sub__(self, other):
        return Symbol(self._arr - self._operand(other), dim=self.dim)

    def __rsub__(self, other):
        return Symbol(self._operand(other) - self._arr, dim=self.dim)

    def __mul__(self, other):
        return Symbol(self._arr * self._operand(other), dim=self.dim)

    def __rmul__(self, other):
        return Symbol(self._operand(other) * self._arr, dim=self.dim)

    def __div__(self, other):
        return Symbol(self._arr / self._operand(other), dim=self.dim)

    def __rdiv__(self, other):
        return Symbol(self._operand(other) / self._arr, dim=self.dim)

    # Python 3 (and 'from __future__ import division') dispatch '/' to the
    # truediv hooks
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(self, other):
        return Symbol(self._arr ** self._operand(other), dim=self.dim)

    def __rpow__(self, other):
        return Symbol(self._operand(other) ** self._arr, dim=self.dim)
556    
def symbols(*names, **kwargs):
    """
    Emulates the behaviour of sympy.symbols.

    Only the first positional argument is used: either a list of names or a
    string of names separated by whitespace and/or commas.

    :keyword shape: shape tuple passed to every created Symbol (default ``()``);
                    all other keyword arguments are passed on to ``Symbol``
    :return: None if no name was given, a single Symbol for one name,
             otherwise a tuple of Symbols
    """
    shape = kwargs.pop('shape', ())

    s = names[0]
    if not isinstance(s, list):
        import re
        # raw string: '\s' in a normal literal is an invalid escape sequence
        s = re.split(r'\s|,', s)

    # empty strings (e.g. from consecutive separators) are skipped
    res = tuple(Symbol(t, shape, **kwargs) for t in s if t)

    if len(res) == 0:    # var('')
        return None
    if len(res) == 1:    # var('x')
        return res[0]
    # otherwise var('a b ...')
    return res
582    
def combineData(array, shape):
    """
    Assembles the (nested) sequence ``array`` into one object of the given
    ``shape``.

    The result is an escript ``Data`` object on the function space of one of
    the ``Data`` entries if there are any, otherwise a numpy array. A single
    value with shape ``()`` is returned unchanged.

    :raise ValueError: if the ``Data`` entries live on different domains
    """
    # array could just be a single value
    is_single_value = not hasattr(array, '__len__') and shape == ()
    if is_single_value:
        return array

    from esys.escript import Data
    entries = numpy.array(array)  # for indexing

    # find function space if any
    domains = set()
    function_spaces = set()
    for position in numpy.ndindex(shape):
        entry = entries[position]
        if isinstance(entry, Data):
            function_spaces.add(entry.getFunctionSpace())
            domains.add(entry.getDomain())

    if len(domains) > 1:
        reference = domains.pop()
        while domains:
            if reference != domains.pop():
                raise ValueError("Mixing of domains not supported")

    if function_spaces:
        # FIXME: interpolate instead of using first?
        combined = Data(0., shape, function_spaces.pop())
    else:
        combined = numpy.zeros(shape)

    # Entries are copied one by one; a multiply-and-add formulation was
    # noted as much slower in the original.
    for position in numpy.ndindex(shape):
        entry = entries[position]
        if hasattr(entry, "ndim") and entry.ndim == 0:
            combined[position] = float(entry)
        else:
            combined[position] = entry
    return combined
618    
619 caltinay 3517
class SymFunction(Symbol):
    """
    Symbolic function: a scalar Symbol named after its class that also
    remembers the arguments it was called with in ``self.args``.
    """

    def __init__(self, *args, **kwargs):
        """
        Initializes a new symbolic function object.

        :param args: the function arguments, stored as ``self.args``
        :keyword kwargs: passed on to ``Symbol.__init__`` (e.g. ``dim``)
        """
        super(SymFunction, self).__init__(self.__class__.__name__, **kwargs)
        self.args = args

    def __repr__(self):
        return self.name + "(" + ", ".join([str(a) for a in self.args]) + ")"

    # both string forms are identical
    __str__ = __repr__

    def lambdarepr(self):
        """
        Returns ``name(arg1, arg2, ...)`` using the ``lambdarepr()`` of each
        argument; every argument must therefore provide that method.
        """
        return self.name + "(" + ", ".join([a.lambdarepr() for a in self.args]) + ")"

    def atoms(self, *types):
        """
        Returns the set of atoms of all arguments. Component symbols such as
        ``[x]_0_1`` are reported as the plain symbol ``x``.
        """
        s = set()
        for el in self.args:
            for a in el.atoms(*types):
                if a.is_Symbol:
                    n, c = Symbol._symComp(a)
                    s.add(sympy.Symbol(n))
                else:
                    s.add(a)
        return s

    def __neg__(self):
        res = self.__class__(*self.args)
        # the fresh object starts with the default dim; carry ours over
        res.dim = self.dim
        res._arr = -res._arr
        return res
655    

  ViewVC Help
Powered by ViewVC 1.1.26