/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Annotation of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory | Revision Log


Revision 3530 - (hide annotations)
Wed Jun 15 04:48:53 2011 UTC (8 years, 4 months ago) by caltinay
Original Path: branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py
File MIME type: text/x-python
File size: 21874 byte(s)
Added dimensionality to symbols (default: 2).
Fixed differentiation.
Added coeff() method.
Fixed a few special cases where elements are numbers/zero etc.

1 caltinay 3507 # -*- coding: utf-8 -*-
2    
3     ########################################################
4     #
5     # Copyright (c) 2003-2010 by University of Queensland
6     # Earth Systems Science Computational Center (ESSCC)
7     # http://www.uq.edu.au/esscc
8     #
9     # Primary Business: Queensland, Australia
10     # Licensed under the Open Software License version 3.0
11     # http://www.opensource.org/licenses/osl-3.0.php
12     #
13     ########################################################
14    
15     __copyright__="""Copyright (c) 2003-2010 by University of Queensland
16     Earth Systems Science Computational Center (ESSCC)
17     http://www.uq.edu.au/esscc
18     Primary Business: Queensland, Australia"""
19     __license__="""Licensed under the Open Software License version 3.0
20     http://www.opensource.org/licenses/osl-3.0.php"""
21     __url__="https://launchpad.net/escript-finley"
22    
23     """
24     :var __author__: name of author
25     :var __copyright__: copyrights
26     :var __license__: licence agreement
27     :var __url__: url entry point on documentation
28     :var __version__: version
29     :var __date__: date of the version
30     """
31    
32     import numpy
33     import sympy
34    
35     __author__="Cihan Altinay"
36    
37    
class Symbol(object):
    """
    Symbolic tensor of rank 0 to 4.

    The components are stored in a numpy array (``_arr``); Symbols created
    by name hold sympy expressions. Arithmetic and most other operations act
    component-wise and return new Symbol objects that inherit ``dim``.
    """

    def __init__(self, *args, **kwargs):
        """
        Initializes a new Symbol object.

        Accepted forms:

        - ``Symbol(name)``: a scalar symbol called ``name``;
        - ``Symbol(name, shape)``: a tensor of component symbols derived from
          ``name`` with the given shape tuple;
        - ``Symbol(value)``: wraps a copy of an array-like (anything
          providing ``__array__``), a list, a sympy expression or a plain
          number.

        :keyword dim: dimensionality used by `grad` (default: 2). All other
                      keyword arguments are passed on to ``sympy.Symbol``
                      when a scalar symbol is created by name.
        :raise TypeError: if a name contains '[' or ']', or the arguments
                          are of unsupported type or number
        :raise ValueError: if the rank would exceed 4
        """
        import numbers
        # dimensionality; grad() appends an axis of this length
        self.dim = kwargs.pop('dim', 2)

        if len(args) == 1:
            arg = args[0]
            if isinstance(arg, str):
                if '[' in arg or ']' in arg:
                    raise TypeError("Name must not contain '[' or ']'")
                self._arr = numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__"):
                arr = arg.__array__()
                if len(arr.shape) > 4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                self._arr = arr.copy()
            elif isinstance(arg, (list, numbers.Number)) or isinstance(arg, sympy.Basic):
                self._arr = numpy.array(arg)
            else:
                raise TypeError("Unsupported argument type %s"%str(type(arg)))
        elif len(args) == 2:
            name, shape = args
            if not isinstance(name, str):
                raise TypeError("First argument must be a string")
            if '[' in name or ']' in name:
                raise TypeError("Name must not contain '[' or ']'")
            if not isinstance(shape, tuple):
                raise TypeError("Second argument must be a tuple")
            if len(shape) > 4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape) == 0:
                self._arr = numpy.array(sympy.Symbol(name, **kwargs))
            else:
                # component names carry the '[name]' prefix that _symComp()
                # later splits off again
                self._arr = sympy.symarray(shape, '['+name+']')
        else:
            raise TypeError("Unsupported number of arguments")

        if self._arr.ndim == 0:
            self.name = str(self._arr.item())
        else:
            self.name = str(self._arr.tolist())

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        """
        Two Symbols are equal if they have the same type, shape and
        component-wise equal contents.
        """
        if type(self) is not type(other):
            return False
        if self.getShape() != other.getShape():
            return False
        # numpy.all() copes with both array and plain-bool comparison results
        return bool(numpy.all(self._arr == other._arr))

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__
        return not self.__eq__(other)

    def __getitem__(self, key):
        return self._arr[key]

    def __setitem__(self, key, value):
        """
        Sets the component(s) selected by ``key``.

        ``value`` may be a Symbol (rank 0, or of the same shape as the
        selected components), a sympy expression, an array-like whose
        elements are sympified, or anything accepted by ``sympy.sympify``.

        :raise ValueError: if a Symbol of rank > 0 does not match the shape
                           of the selected components
        """
        if isinstance(value, Symbol):
            if value.getRank() == 0:
                # store the wrapped expression, not the Symbol wrapper itself
                self._arr[key] = value._arr.item()
            elif hasattr(self._arr[key], "shape") and self._arr[key].shape == value.getShape():
                self._arr[key] = value._arr
            else:
                raise ValueError("Wrong shape of value")
        elif isinstance(value, sympy.Basic):
            self._arr[key] = value
        elif hasattr(value, "__array__"):
            # sympify every element but keep the shape of the source array
            src = value.__array__()
            conv = numpy.empty(src.shape, dtype=object)
            for idx in numpy.ndindex(src.shape):
                conv[idx] = sympy.sympify(src[idx])
            self._arr[key] = conv
        else:
            self._arr[key] = sympy.sympify(value)

    def getRank(self):
        return self._arr.ndim

    def getShape(self):
        return self._arr.shape

    def atoms(self, *types):
        """
        Returns the set of atoms found in the components of this Symbol.

        Component symbols named ``[x]_i_j...`` are reported under their base
        name ``x``. Components that are not sympy objects are skipped.

        :param types: optional atom types, passed on to sympy's ``atoms()``
        :rtype: ``set``
        """
        s = set()
        for el in self._arr.flat:
            if not isinstance(el, sympy.Basic):
                # TODO: Numbers?
                continue
            for a in el.atoms(*types):
                if a.is_Symbol:
                    n, c = Symbol._symComp(a)
                    s.add(sympy.Symbol(n))
                else:
                    s.add(a)
        return s

    def _sympystr_(self, printer):
        # NOTE(review): sympy's StrPrinter looks for a method named
        # ``_sympystr`` (no trailing underscore) -- confirm that this hook is
        # actually invoked anywhere.
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string that evaluates to this Symbol's value.

        Component symbols ``[x]_i_j`` are rewritten as ``x[i,j]``. Rank-0
        Symbols yield the plain expression string; higher ranks are rendered
        as ``combineData(<nested list>,<shape>)``.
        """
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr = numpy.empty(self.getShape(), dtype=object)
        for idx, el in numpy.ndenumerate(self._arr):
            atoms = el.atoms(sympy.Symbol) if isinstance(el, sympy.Basic) else []
            # create a dictionary to convert names like [x]_0_0 to x[0,0]
            symdict = {}
            for a in atoms:
                n, c = Symbol._symComp(a)
                if len(c) > 0:
                    symstr = n+'['+','.join(str(i) for i in c)+']'
                else:
                    symstr = n
                symdict[a.name] = symstr
            s = lambdarepr(el)
            # replace longer names first so that e.g. '[x]_1' cannot clobber
            # the prefix of '[x]_10'
            for key in sorted(symdict, key=len, reverse=True):
                s = s.replace(key, symdict[key])
            temp_arr[idx] = s
        if self.getRank() == 0:
            return temp_arr.item()
        return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))

    def coeff(self, x, expand=True):
        """
        Returns the coefficients of ``x`` in the components of this Symbol.

        If ``x`` is a Symbol (or numpy array) of rank > 0 the coefficient is
        taken component-wise: component ``i`` of the result is the
        coefficient of ``x[i]`` in ``self[i]``. Otherwise every component is
        searched for the same ``x``. A zero ``x`` yields zero coefficients,
        and a ``None`` coefficient is reported as 0.

        :param x: the term whose coefficient is requested
        :param expand: passed on as the second positional argument of each
                       component's ``coeff()`` method (NOTE(review): its
                       meaning depends on the sympy version -- confirm)
        :return: a new Symbol of the same shape, or 0 if this Symbol has
                 rank 0 and its coefficient is ``None``
        :raise TypeError: if the shapes are not compatible
        """
        if isinstance(x, numpy.ndarray):
            # treat plain arrays like Symbols so they are handled element-wise
            x = Symbol(x, dim=self.dim)
        self._ensureShapeCompatible(x)
        result = Symbol(self._arr, dim=self.dim)
        if isinstance(x, Symbol):
            if x.getRank() > 0:
                flat = result._arr.flat
                # zip with range() stops after this Symbol's components even
                # if x has more (scalar self, tensor x)
                for idx, s in zip(range(len(flat)), x._arr.flat):
                    if s == 0:
                        c = 0
                    else:
                        c = sympy.sympify(flat[idx]).coeff(s, expand)
                    flat[idx] = 0 if c is None else c
                return result
            x = x._arr.item()

        if x == 0:
            return Symbol(numpy.zeros(self.getShape()), dim=self.dim)

        result = result.applyfunc(lambda item: sympy.sympify(item).coeff(x, expand))
        if result is None:
            return 0
        # a component's coeff() may yield None; report that as 0
        flat = result._arr.flat
        for idx in range(len(flat)):
            if flat[idx] is None:
                flat[idx] = 0
        return result

    def diff(self, *symbols, **assumptions):
        """
        Differentiates every component w.r.t. the given symbols, in order.

        A scalar symbol (rank-0 Symbol or sympy symbol) differentiates
        component-wise. A rank-1 Symbol ``s`` appends a new trailing axis of
        length ``len(s)`` holding the derivatives w.r.t. each component of
        ``s``. An integer after a symbol repeats it (``diff(x, 2)``).

        :param assumptions: passed on to sympy's ``diff()``
        :raise ValueError: for Symbols of rank > 1
        """
        symbols = Symbol._symbolgen(*symbols)
        result = Symbol(self._arr, dim=self.dim)
        for s in symbols:
            if isinstance(s, Symbol):
                if s.getRank() == 0:
                    var = s._arr.item()
                    result = result.applyfunc(
                        lambda item, var=var: sympy.sympify(item).diff(var, **assumptions))
                elif s.getRank() == 1:
                    # use the *current* result's shape: an earlier rank-1
                    # derivative in this loop may already have added an axis
                    shape = result.getShape()
                    rank = result.getRank()
                    n = s.getShape()[0]
                    out = result._arr.copy().reshape(shape+(1,)).repeat(n, axis=rank)
                    for d in range(n):
                        for idx in numpy.ndindex(shape):
                            index = idx+(d,)
                            out[index] = sympy.sympify(out[index]).diff(s[d], **assumptions)
                    result = Symbol(out, dim=self.dim)
                else:
                    raise ValueError("diff: Only rank 0 and 1 supported")
            else:
                result = result.applyfunc(
                    lambda item, var=s: sympy.sympify(item).diff(var, **assumptions))
        return result

    def grad(self, where=None):
        """
        Returns the gradient as a new Symbol with an extra trailing axis.

        The new axis has length ``self.dim``; entry ``d`` along it is
        ``grad_n(component, d)`` (or ``grad_n(component, d, where)``), with
        ``grad_n`` taken from the sibling ``functions`` module.

        :param where: optional scalar Symbol or value passed on to ``grad_n``
        :raise ValueError: if ``where`` is a Symbol of rank > 0
        """
        if isinstance(where, Symbol):
            if where.getRank() > 0:
                raise ValueError("grad: 'where' must be a scalar symbol")
            where = where._arr.item()

        # imported here rather than at module level; the explicit relative
        # form works on Python 3, the fallback keeps the old lookup
        try:
            from .functions import grad_n
        except (ImportError, ValueError):
            from functions import grad_n
        shape = self.getShape()
        out = self._arr.copy().reshape(shape+(1,)).repeat(self.dim, axis=self.getRank())
        for d in range(self.dim):
            for idx in numpy.ndindex(shape):
                index = idx+(d,)
                if where is None:
                    out[index] = grad_n(out[index], d)
                else:
                    out[index] = grad_n(out[index], d, where)
        return Symbol(out, dim=self.dim)

    def inverse(self):
        """
        Returns the inverse of this square matrix (sizes 1x1 to 3x3) using
        the explicit adjugate formulas.

        :raise ValueError: if this is not a square rank-2 Symbol
        :raise ZeroDivisionError: if the determinant is known to be zero
        :raise TypeError: for matrices larger than 3x3
        """
        if not self.getRank() == 2:
            raise ValueError("inverse: Only rank 2 supported")
        s = self.getShape()
        if not s[0] == s[1]:
            raise ValueError("inverse: Only square shapes supported")
        out = numpy.zeros(s, object)
        arr = self._arr
        if s[0] == 1:
            if Symbol._isZero(arr[0,0]):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            out[0,0] = 1./arr[0,0]
        elif s[0] == 2:
            A11 = arr[0,0]
            A12 = arr[0,1]
            A21 = arr[1,0]
            A22 = arr[1,1]
            D = A11*A22-A12*A21
            if Symbol._isZero(D):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D = 1./D
            out[0,0] = A22*D
            out[1,0] = -A21*D
            out[0,1] = -A12*D
            out[1,1] = A11*D
        elif s[0] == 3:
            A11 = arr[0,0]
            A21 = arr[1,0]
            A31 = arr[2,0]
            A12 = arr[0,1]
            A22 = arr[1,1]
            A32 = arr[2,1]
            A13 = arr[0,2]
            A23 = arr[1,2]
            A33 = arr[2,2]
            D = A11*(A22*A33-A23*A32)+A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
            if Symbol._isZero(D):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D = 1./D
            out[0,0] = (A22*A33-A23*A32)*D
            out[1,0] = (A31*A23-A21*A33)*D
            out[2,0] = (A21*A32-A31*A22)*D
            out[0,1] = (A13*A32-A12*A33)*D
            out[1,1] = (A11*A33-A31*A13)*D
            out[2,1] = (A12*A31-A11*A32)*D
            out[0,2] = (A12*A23-A13*A22)*D
            out[1,2] = (A13*A21-A11*A23)*D
            out[2,2] = (A11*A22-A12*A21)*D
        else:
            raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
        return Symbol(out, dim=self.dim)

    def swap_axes(self, axis0, axis1):
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self.dim)

    def tensorProduct(self, other, axis_offset):
        """
        Contracts the last ``axis_offset`` axes of this Symbol with the
        first ``axis_offset`` axes of ``other`` (a Symbol or array-like).

        :raise ValueError: if the contracted axes do not match in size
        """
        arg0 = self._arr
        arg1 = Symbol._operandArray(other)
        sh0 = self.getShape()
        sh1 = arg1.shape
        r0 = arg0.ndim
        d0 = Symbol._sizeOf(sh0[:r0-axis_offset])
        d1 = Symbol._sizeOf(sh1[axis_offset:])
        d01 = Symbol._sizeOf(sh1[:axis_offset])
        # reshape (unlike ndarray.resize) refuses to pad or truncate when
        # the contracted sizes do not match
        a0 = numpy.reshape(arg0, (d0, d01))
        a1 = numpy.reshape(arg1, (d01, d1))
        out = numpy.zeros((d0, d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0, i1] = numpy.sum(a0[i0, :]*a1[:, i1])
        return Symbol(out.reshape(sh0[:r0-axis_offset]+sh1[axis_offset:]), dim=self.dim)

    def transposedTensorProduct(self, other, axis_offset):
        """
        Contracts the first ``axis_offset`` axes of this Symbol with the
        first ``axis_offset`` axes of ``other`` (a Symbol or array-like).

        :raise ValueError: if the contracted axes do not match in size
        """
        arg0 = self._arr
        arg1 = Symbol._operandArray(other)
        sh0 = self.getShape()
        sh1 = arg1.shape
        d0 = Symbol._sizeOf(sh0[axis_offset:])
        d1 = Symbol._sizeOf(sh1[axis_offset:])
        d01 = Symbol._sizeOf(sh1[:axis_offset])
        a0 = numpy.reshape(arg0, (d01, d0))
        a1 = numpy.reshape(arg1, (d01, d1))
        out = numpy.zeros((d0, d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0, i1] = numpy.sum(a0[:, i0]*a1[:, i1])
        return Symbol(out.reshape(sh0[axis_offset:]+sh1[axis_offset:]), dim=self.dim)

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Contracts the last ``axis_offset`` axes of this Symbol with the
        last ``axis_offset`` axes of ``other`` (a Symbol or array-like).

        :raise ValueError: if the contracted axes do not match in size
        """
        arg0 = self._arr
        arg1 = Symbol._operandArray(other)
        sh0 = self.getShape()
        sh1 = arg1.shape
        r0 = arg0.ndim
        r1 = arg1.ndim
        d0 = Symbol._sizeOf(sh0[:r0-axis_offset])
        d1 = Symbol._sizeOf(sh1[:r1-axis_offset])
        d01 = Symbol._sizeOf(sh1[r1-axis_offset:])
        a0 = numpy.reshape(arg0, (d0, d01))
        a1 = numpy.reshape(arg1, (d1, d01))
        out = numpy.zeros((d0, d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0, i1] = numpy.sum(a0[i0, :]*a1[i1, :])
        return Symbol(out.reshape(sh0[:r0-axis_offset]+sh1[:r1-axis_offset]), dim=self.dim)

    def trace(self, axis_offset):
        """
        Returns the trace over axes ``axis_offset`` and ``axis_offset+1``
        (which are assumed to have equal length).
        """
        sh = self.getShape()
        n = sh[axis_offset]
        s1 = Symbol._sizeOf(sh[:axis_offset])
        s2 = Symbol._sizeOf(sh[axis_offset+2:])
        arr_r = numpy.reshape(self._arr, (s1, n, n, s2))
        out = numpy.zeros((s1, s2), object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(n):
                    out[i1, i2] += arr_r[i1, j, j, i2]
        return Symbol(out.reshape(sh[:axis_offset]+sh[axis_offset+2:]), dim=self.dim)

    def transpose(self, axis_offset=None):
        """
        Moves the first ``axis_offset`` axes to the end.

        :param axis_offset: defaults to half the rank (rounded down)
        """
        if axis_offset is None:
            axis_offset = self._arr.ndim//2
        axes = list(range(axis_offset, self._arr.ndim))+list(range(0, axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes), dim=self.dim)

    def applyfunc(self, f):
        """
        Applies ``f`` to every component and returns a new Symbol.

        For a rank-0 Symbol, ``None`` is returned if ``f`` returns ``None``.
        """
        assert callable(f)
        if self._arr.ndim == 0:
            el = f(self._arr.item())
            if el is None:
                return None
            return Symbol(el, dim=self.dim)
        out = numpy.empty(self.getShape(), dtype=object)
        for idx in numpy.ndindex(self.getShape()):
            out[idx] = f(self._arr[idx])
        return Symbol(out, dim=self.dim)

    def _sympy_(self):
        # sympify hook: returns a new Symbol with every component sympified
        return self.applyfunc(sympy.sympify)

    def _ensureShapeCompatible(self, other):
        """
        Checks for compatible shapes for binary operations.
        Raises TypeError if not compatible.

        Shapes are compatible if they are equal or if either side is a
        scalar (plain number, sympy expression or rank-0 tensor).
        """
        import numbers
        sh0 = self.getShape()
        if isinstance(other, Symbol):
            sh1 = other.getShape()
        elif isinstance(other, numpy.ndarray):
            sh1 = other.shape
        elif isinstance(other, numbers.Number) or isinstance(other, sympy.Basic):
            sh1 = ()
        else:
            raise TypeError("Unsupported argument type '%s' for binary operation"%other.__class__.__name__)
        if sh0 != sh1 and sh0 != () and sh1 != ():
            raise TypeError("Incompatible shapes for binary operation")

    @staticmethod
    def _symComp(sym):
        """
        Splits a component symbol's name into base name and indices.

        A name of the form ``[x]_i_j...`` yields ``('x', (i, j, ...))``;
        any other name is returned unchanged with an empty index tuple.
        """
        n = sym.name
        parts = n.split('[')
        if len(parts) != 2:
            return n, ()
        parts = parts[1].split(']')
        if len(parts) != 2:
            return n, ()
        comps = [int(i) for i in parts[1].split('_')[1:]]
        return parts[0], tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        An integer directly after a symbol repeats that symbol; integers are
        otherwise skipped.

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat
        count = len(symbols)
        for i in range(count):
            s = symbols[i]
            if not isinstance(s, Symbol):
                s = sympy.sympify(s)
            # look ahead by position; comparing with the last argument by
            # value would drop the repeat count in calls like (x, 2, x)
            next_s = None
            if i+1 < count:
                next_s = symbols[i+1]
                if not isinstance(next_s, Symbol):
                    next_s = sympy.sympify(next_s)

            if isinstance(s, sympy.Integer):
                continue
            elif isinstance(s, (Symbol, sympy.Symbol)):
                # handle cases like (x, 3)
                if isinstance(next_s, sympy.Integer):
                    # yield (x, x, x)
                    for copy_s in repeat(s, int(next_s)):
                        yield copy_s
                else:
                    yield s
            else:
                yield s

    @staticmethod
    def _isZero(value):
        """
        Returns True if ``value`` is known to be zero.

        sympy objects answer via ``is_zero``; values without that attribute
        (plain numbers) or with an undetermined answer are compared with 0.
        """
        z = getattr(value, 'is_zero', None)
        if z is None:
            return bool(value == 0)
        return bool(z)

    @staticmethod
    def _sizeOf(shape):
        """Returns the number of elements of an array with shape ``shape``."""
        n = 1
        for i in shape:
            n *= i
        return n

    @staticmethod
    def _operandArray(other):
        """Returns the numpy array behind a Symbol or array-like operand."""
        if isinstance(other, Symbol):
            return other._arr
        return numpy.asarray(other)

    @staticmethod
    def _operand(other):
        """Returns the array behind a Symbol, or ``other`` unchanged."""
        if isinstance(other, Symbol):
            return other._arr
        return other

    # unary/binary operations follow

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr, dim=self.dim)

    def __abs__(self):
        return Symbol(abs(self._arr), dim=self.dim)

    def __add__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(self._arr+Symbol._operand(other), dim=self.dim)

    def __radd__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(Symbol._operand(other)+self._arr, dim=self.dim)

    def __sub__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(self._arr-Symbol._operand(other), dim=self.dim)

    def __rsub__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(Symbol._operand(other)-self._arr, dim=self.dim)

    def __mul__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(self._arr*Symbol._operand(other), dim=self.dim)

    def __rmul__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(Symbol._operand(other)*self._arr, dim=self.dim)

    def __div__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(self._arr/Symbol._operand(other), dim=self.dim)

    def __rdiv__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(Symbol._operand(other)/self._arr, dim=self.dim)

    # Python 3 (and 'from __future__ import division') use the true-division
    # hooks for '/'
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(self._arr**Symbol._operand(other), dim=self.dim)

    def __rpow__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(Symbol._operand(other)**self._arr, dim=self.dim)
546 caltinay 3507
547    
def symbols(*names, **kwargs):
    """
    Emulates the behaviour of sympy.symbols.

    Creates one `Symbol` per name. ``names[0]`` is either a string of names
    separated by whitespace and/or commas, or a list/tuple of names. Empty
    names are skipped.

    :keyword shape: shape of every created Symbol (default: scalar)
    :keyword kwargs: all other keywords are passed on to `Symbol`
    :return: ``None`` if there are no names, the Symbol itself for a single
             name, otherwise a tuple of Symbols
    """
    import re
    shape = kwargs.pop('shape', ())

    s = names[0]
    if not isinstance(s, (list, tuple)):
        # raw string: '\s' is a regex class, not a Python escape
        s = re.split(r'\s|,', s)
    res = tuple(Symbol(t, shape, **kwargs) for t in s if t)
    if len(res) == 0:  # var('')
        return None
    if len(res) == 1:  # var('x')
        return res[0]
    # otherwise var('a b ...')
    return res
573    
def combineData(array, shape):
    """
    Assembles the components in ``array`` into a single object of ``shape``.

    This is the function named in the ``combineData(...)`` strings that
    ``Symbol.lambdarepr`` produces. A single value with an empty shape is
    returned unchanged. If any component is an escript ``Data`` object, the
    result is a ``Data`` object on the function space of one of those
    components; otherwise it is a numpy float array.

    :param array: a nested list (or array) of components, or a single value
    :param shape: the target shape
    :raise ValueError: if the ``Data`` components live on different domains
    """
    # a lone value with a scalar shape needs no assembling
    if not hasattr(array, '__len__') and shape == ():
        return array

    from esys.escript import Data
    components = numpy.array(array)  # enables tuple indexing below

    # collect domains and function spaces of all Data components
    domains, spaces = set(), set()
    for index in numpy.ndindex(shape):
        item = components[index]
        if isinstance(item, Data):
            spaces.add(item.getFunctionSpace())
            domains.add(item.getDomain())

    if len(domains) > 1:
        reference = domains.pop()
        if any(reference != other for other in domains):
            raise ValueError("Mixing of domains not supported")

    if spaces:
        # FIXME: interpolate instead of using an arbitrary function space?
        result = Data(0., shape, spaces.pop())
    else:
        result = numpy.zeros(shape)

    # assign component by component; accumulating item*unit_tensor terms
    # instead was found to be much slower
    for index in numpy.ndindex(shape):
        item = components[index]
        if hasattr(item, "ndim") and item.ndim == 0:
            result[index] = float(item)
        else:
            result[index] = item
    return result
609    
610 caltinay 3517
class SymFunction(Symbol):
    """
    Base class for symbolic functions applied to arguments.

    The function name is the (sub)class name; the positional constructor
    arguments are kept in ``self.args`` and are used by the string
    representations and by `atoms`.
    """
    def __init__(self, *args, **kwargs):
        """
        Initializes a new symbolic function object.

        :param args: the function arguments (stored as ``self.args``)
        :keyword kwargs: passed on to `Symbol.__init__` together with the
                         class name (e.g. ``dim``)
        """
        super(SymFunction, self).__init__(self.__class__.__name__, **kwargs)
        self.args = args

    def _formatCall(self, convert):
        # renders "name(arg1, arg2, ...)" with every argument run through convert
        return self.name+"("+", ".join([convert(a) for a in self.args])+")"

    @staticmethod
    def _argLambdarepr(a):
        # Symbol arguments render themselves; anything else (sympy
        # expressions, plain numbers) goes through sympy's lambdarepr printer
        if hasattr(a, 'lambdarepr'):
            return a.lambdarepr()
        from sympy.printing.lambdarepr import lambdarepr
        return lambdarepr(a)

    def __repr__(self):
        return self._formatCall(str)

    def __str__(self):
        return self._formatCall(str)

    def lambdarepr(self):
        return self._formatCall(SymFunction._argLambdarepr)

    def atoms(self, *types):
        """
        Returns the set of atoms of all arguments.

        Component symbols named ``[x]_i_j...`` are reported under their base
        name ``x``; arguments without an ``atoms()`` method (e.g. plain
        numbers) are skipped.

        :param types: optional atom types, passed on to each argument
        :rtype: ``set``
        """
        s = set()
        for el in self.args:
            if not hasattr(el, 'atoms'):
                continue
            for a in el.atoms(*types):
                if a.is_Symbol:
                    n, c = Symbol._symComp(a)
                    s.add(sympy.Symbol(n))
                else:
                    s.add(a)
        return s

    def __neg__(self):
        res = self.__class__(*self.args)
        # the constructor call does not forward dim; keep this object's value
        res.dim = self.dim
        res._arr = -res._arr
        return res
646    

  ViewVC Help
Powered by ViewVC 1.1.26