/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Annotation of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory | Revision Log


Revision 3857 - (hide annotations)
Tue Mar 6 07:28:22 2012 UTC (7 years, 7 months ago) by caltinay
Original Path: branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py
File MIME type: text/x-python
File size: 25999 byte(s)
Moved symbolic utility functions into separate file.

1 caltinay 3507
2     ########################################################
3     #
4 caltinay 3815 # Copyright (c) 2003-2012 by University of Queensland
5 caltinay 3507 # Earth Systems Science Computational Center (ESSCC)
6     # http://www.uq.edu.au/esscc
7     #
8     # Primary Business: Queensland, Australia
9     # Licensed under the Open Software License version 3.0
10     # http://www.opensource.org/licenses/osl-3.0.php
11     #
12     ########################################################
13    
14 caltinay 3815 __copyright__="""Copyright (c) 2003-2012 by University of Queensland
15 caltinay 3507 Earth Systems Science Computational Center (ESSCC)
16     http://www.uq.edu.au/esscc
17     Primary Business: Queensland, Australia"""
18     __license__="""Licensed under the Open Software License version 3.0
19     http://www.opensource.org/licenses/osl-3.0.php"""
20     __url__="https://launchpad.net/escript-finley"
21    
22     """
23     :var __author__: name of author
24     :var __copyright__: copyrights
25     :var __license__: licence agreement
26     :var __url__: url entry point on documentation
27     :var __version__: version
28     :var __date__: date of the version
29     """
30    
31     import numpy
32     import sympy
33    
34     __author__="Cihan Altinay"
35    
36 gross 3856
37    
class Symbol(object):
    """
    `Symbol` objects are placeholders for a single mathematical symbol, such as
    'x', or for arbitrarily complex mathematical expressions such as
    'c*x**4+alpha*exp(x)-2*sin(beta*x)', where 'alpha', 'beta', 'c', and 'x'
    are also `Symbol`s (the symbolic 'atoms' of the expression).

    With the help of the 'Evaluator' class these symbols and expressions can
    be resolved by substituting numeric values and/or escript `Data` objects
    for the atoms. To facilitate the use of `Data` objects a `Symbol` has a
    shape (and thus a rank) as well as a dimension (see constructor).
    `Symbol`s are useful to perform mathematical simplifications, compute
    derivatives and as coefficients for nonlinear PDEs which can be solved by
    the `NonlinearPDE` class.

    Internally the components are kept in a numpy object array (``_arr``);
    a scalar symbol is stored as a 0-d array.
    """

    # these are for compatibility with sympy.Symbol. lambdify checks these.
    is_Add=False
    is_Float=False

    def __init__(self, *args, **kwargs):
        """
        Initialises a new `Symbol` object in one of the following ways::

            u=Symbol('u')

        returns a scalar symbol by the name 'u'.

            a=Symbol('alpha', (4,3))

        returns a rank 2 symbol with the shape (4,3), whose elements are
        named '[alpha]_i_j' (with i=0..3, j=0..2).

            a,b,c=symbols('a,b,c')
            x=Symbol([[a+b,0,0],[0,b-c,0],[0,0,c-a]])

        returns a rank 2 symbol with the shape (3,3) whose elements are
        explicitly specified by numeric values and other symbols/expressions
        within a list or numpy array.

            c=Symbol(5)

        returns a rank 0 symbol holding the plain number 5.

        The dimensionality of the symbol can be specified through the `dim`
        keyword. All other keywords are passed to ``sympy.Symbol`` when a
        named scalar symbol is created (they are not forwarded for shaped
        symbols).

        :param args: initialisation arguments as described above
        :keyword dim: dimensionality of the new Symbol (default: 2)
        :type dim: ``int``
        :raise TypeError: for unsupported argument types or argument counts
        :raise ValueError: if a name contains '[' or ']' or the order exceeds 4
        """
        self.dim=kwargs.pop('dim', 2)

        if len(args)==1:
            arg=args[0]
            if isinstance(arg, str):
                if '[' in arg or ']' in arg:
                    raise ValueError("Name must not contain '[' or ']'")
                self._arr=numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__") or isinstance(arg, list):
                if isinstance(arg, list):
                    arg=numpy.array(arg)
                arr=arg.__array__()
                if arr.ndim>4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                res=numpy.empty(arr.shape, dtype=object)
                for idx in numpy.ndindex(arr.shape):
                    el=arr[idx]
                    # unwrap numpy scalars into plain Python values
                    res[idx]=el.item() if hasattr(el, "item") else el
                self._arr=res
            elif isinstance(arg, (int, float, complex)):
                # plain numbers become rank 0 constants; applyfunc relies on
                # this when a function maps a scalar component to a number
                self._arr=numpy.array(arg, dtype=object)
            elif isinstance(arg, sympy.Basic):
                self._arr=numpy.array(arg)
            else:
                raise TypeError("Unsupported argument type %s"%str(type(arg)))
        elif len(args)==2:
            name,shape=args
            if not isinstance(name, str):
                raise TypeError("First argument must be a string")
            if not isinstance(shape, tuple):
                raise TypeError("Second argument must be a tuple")
            if '[' in name or ']' in name:
                raise ValueError("Name must not contain '[' or ']'")
            if len(shape)>4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape)==0:
                self._arr=numpy.array(sympy.Symbol(name, **kwargs))
            else:
                # the argument order of sympy.symarray differs between sympy
                # versions: try (shape, prefix) first, then (prefix, shape)
                try:
                    self._arr=sympy.symarray(shape, '['+name+']')
                except TypeError:
                    self._arr=sympy.symarray('['+name+']', shape)
        else:
            raise TypeError("Unsupported number of arguments")

        if self._arr.ndim==0:
            self.name=str(self._arr.item())
        else:
            self.name=str(self._arr.tolist())

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        """
        Two symbols are equal if they are of the same type, have the same
        shape and all components compare equal.
        """
        if type(self) is not type(other):
            return False
        if self.getShape()!=other.getShape():
            return False
        return bool(numpy.all(self._arr==other._arr))

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__
        return not self.__eq__(other)

    def __getitem__(self, key):
        """
        Returns the raw component(s) at `key` (a sympy object, a number or a
        numpy object array for slices).
        """
        return self._arr[key]

    def __iter__(self):
        """
        Iterates over the first axis of the component array.
        """
        return iter(self._arr)

    def __setitem__(self, key, value):
        """
        Sets the component(s) at `key`.

        :param value: a `Symbol` (rank 0, or a shape matching the target
                      slice), a sympy expression, an array-like (components
                      are sympified), or any value accepted by sympy.sympify
        :raise ValueError: if a non-scalar `Symbol` does not match the shape
                           of the target slice
        """
        if isinstance(value, Symbol):
            if value.getRank()==0:
                self._arr[key]=value.item()
            elif getattr(self._arr[key], "shape", None)==value.getShape():
                # direct assignment also works for fancy indices, where
                # self._arr[key] would only be a temporary copy
                self._arr[key]=value._arr
            else:
                raise ValueError("Wrong shape of value")
        elif isinstance(value, sympy.Basic):
            self._arr[key]=value
        elif hasattr(value, "__array__"):
            src=numpy.asarray(value)
            conv=numpy.empty(src.shape, dtype=object)
            for idx in numpy.ndindex(src.shape):
                conv[idx]=sympy.sympify(src[idx])
            self._arr[key]=conv.item() if conv.ndim==0 else conv
        else:
            self._arr[key]=sympy.sympify(value)

    def getDim(self):
        """
        Returns the spatial dimensionality of this symbol.

        :return: the symbol's spatial dimensionality
        :rtype: ``int``
        """
        return self.dim

    def getRank(self):
        """
        Returns the rank of this symbol.

        :return: the symbol's rank which is equal to the length of the shape.
        :rtype: ``int``
        """
        return self._arr.ndim

    def getShape(self):
        """
        Returns the shape of this symbol.

        :return: the symbol's shape
        :rtype: ``tuple`` of ``int``
        """
        return self._arr.shape

    def item(self, *args):
        """
        Returns an element of this symbol.
        This method behaves like the item() method of numpy.ndarray.
        If this is a scalar Symbol, no arguments are allowed and the only
        element in this Symbol is returned.
        Otherwise, 'args' specifies a flat or nd-index (as separate integers
        or as a single tuple) and the element at that index is returned.

        :param args: index of item to be returned
        :return: the requested element
        :rtype: ``sympy.Symbol``, ``int``, or ``float``
        """
        return self._arr.item(*args)

    def atoms(self, *types):
        """
        Returns the atoms that form the current Symbol.

        By default, only objects that are truly atomic and cannot be divided
        into smaller pieces are returned: symbols, numbers, and number
        symbols like I and pi. It is possible to request atoms of any type,
        however.

        Note that if this symbol contains components such as [x]_i_j then
        only their main symbol 'x' is returned. Other symbols are returned
        unchanged (keeping any sympy assumptions they carry).

        :param types: types to restrict result to
        :return: list of atoms of specified type
        :rtype: ``set``
        """
        s=set()
        for el in self._arr.flat:
            if isinstance(el, sympy.Basic):
                for a in el.atoms(*types):
                    if a.is_Symbol:
                        n,c=Symbol._symComp(a)
                        s.add(a if n==a.name else sympy.Symbol(n))
                    else:
                        s.add(a)
            elif len(types)==0 or type(el) in types:
                s.add(el)
        return s

    def _sympystr_(self, printer):
        # alternative printer hook name looked up by some sympy versions
        return self._sympystr(printer)

    def _sympystr(self, printer):
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string representation suitable for lambdify in which
        component symbols like '[x]_0_1' are written as 'x[0,1]'.
        Non-scalar symbols are wrapped as 'combineData(<list>,<shape>)'.

        :rtype: ``str``
        """
        from sympy.printing.lambdarepr import lambdarepr as sympy_lambdarepr
        temp_arr=numpy.empty(self.getShape(), dtype=object)
        for idx,el in numpy.ndenumerate(self._arr):
            atoms=el.atoms(sympy.Symbol) if isinstance(el, sympy.Basic) else []
            # map names like [x]_0_0 to x[0,0]
            symdict={}
            for a in atoms:
                n,c=Symbol._symComp(a)
                if len(c)>0:
                    symdict[a.name]=n+'['+','.join(str(i) for i in c)+']'
                else:
                    symdict[a.name]=n
            s=sympy_lambdarepr(el)
            # replace longer names first so that e.g. '[x]_1' cannot clobber
            # the prefix of '[x]_10'
            for key in sorted(symdict, key=len, reverse=True):
                s=s.replace(key, symdict[key])
            temp_arr[idx]=s
        if self.getRank()==0:
            return temp_arr.item()
        return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))

    def coeff(self, x, expand=True):
        """
        Returns the coefficient of the term "x" or 0 if there is no "x".

        If "x" is a scalar symbol then "x" is searched in all components of
        this symbol. Otherwise the shapes must match and the coefficients are
        checked component by component (a scalar self is compared against
        every component of "x").

        Example::

            x=Symbol('x', (2,2))
            y=3*x
            print y.coeff(x)
            print y.coeff(x[1,1])

        will print::

            [[3 3]
             [3 3]]

            [[0 0]
             [0 3]]

        :param x: the term whose coefficients are to be found
        :type x: ``Symbol``, ``numpy.ndarray``, `list`
        :return: the coefficient(s) of the term
        :rtype: ``Symbol``
        """
        self._ensureShapeCompatible(x)
        y=x.__array__() if hasattr(x, '__array__') else numpy.array(x)

        def coeff_of(el, term):
            # sympify so plain numbers also provide .coeff(); some sympy
            # versions return None when the term is absent.
            # NOTE(review): `expand` is passed positionally as before; newer
            # sympy interprets the 2nd positional argument of Expr.coeff as
            # the power `n` -- confirm against the sympy version in use.
            res=sympy.sympify(el).coeff(term, expand)
            return 0 if res is None else res

        if y.ndim>0:
            scalar_self=(self.getRank()==0)
            result=numpy.zeros(y.shape if scalar_self else self.getShape(), dtype=object)
            for idx in numpy.ndindex(y.shape):
                if y[idx]!=0:
                    el=self._arr.item() if scalar_self else self._arr[idx]
                    result[idx]=coeff_of(el, y[idx])
        else:
            term=y.item()
            result=numpy.zeros(self.getShape(), dtype=object)
            if term!=0:
                for idx in numpy.ndindex(self.getShape()):
                    result[idx]=coeff_of(self._arr[idx], term)
        return Symbol(result, dim=self.dim)

    def subs(self, old, new):
        """
        Substitutes an expression.

        If `old` is a non-scalar `Symbol`, each of its components is replaced
        by the corresponding component of `new` (or by `new` itself if `new`
        is scalar). Otherwise the scalar `old` is replaced by the scalar `new`
        in every component. Components that are not sympy objects (plain
        numbers) are left unchanged.

        :raise TypeError: if a scalar is to be replaced by a non-scalar, or
                          the shapes of `old` and `new` are incompatible
        :rtype: ``Symbol``
        """
        if isinstance(old, Symbol) and old.getRank()>0:
            old._ensureShapeCompatible(new)
            new_arr=new.__array__() if hasattr(new, '__array__') else numpy.array(new)
            if new_arr.ndim>0:
                pairs=[(old[i], new_arr[i]) for i in numpy.ndindex(new_arr.shape)]
            else: # substitute scalar for non-scalar
                pairs=[(old[i], new_arr.item()) for i in numpy.ndindex(old.getShape())]

            def subs_all(item):
                for o,n in pairs:
                    item=item.subs(o, n)
                return item
            return self.applyfunc(subs_all, sympy.Basic)

        # scalar
        if isinstance(new, Symbol):
            if new.getRank()>0:
                raise TypeError("Cannot substitute, incompatible ranks.")
            new=new.item()
        if isinstance(old, Symbol):
            old=old.item()
        return self.applyfunc(lambda item: item.subs(old, new), sympy.Basic)

    def diff(self, *symbols, **assumptions):
        """
        Returns the derivative of this symbol with respect to `symbols`.

        Arguments are processed in order; an integer following a symbol
        repeats that symbol (``diff(x, 2)`` differentiates twice). A rank 0
        `Symbol` or a sympy symbol differentiates every component. A rank 1
        `Symbol` v of length n appends a trailing axis of length n whose entry
        d is the derivative with respect to v[d]. Plain number components
        differentiate to 0. Keyword arguments are passed on to sympy's diff.

        :raise ValueError: for `Symbol` arguments of rank 2 or higher
        :rtype: ``Symbol``
        """
        def diff_el(el, var):
            # sympify so that plain number components differentiate to 0
            return sympy.sympify(el).diff(var, **assumptions)

        result=Symbol(self._arr, dim=self.dim)
        for s in Symbol._symbolgen(*symbols):
            if not isinstance(s, Symbol):
                result=result.applyfunc(lambda item, v=s: diff_el(item, v))
            elif s.getRank()==0:
                var=s._arr.item()
                result=result.applyfunc(lambda item, v=var: diff_el(item, v))
            elif s.getRank()==1:
                n=s.getShape()[0]
                # use the shape of the current result: earlier rank 1
                # derivatives may already have added axes
                shape=result.getShape()
                out=numpy.empty(shape+(n,), dtype=object)
                for d in range(n):
                    for idx in numpy.ndindex(shape):
                        out[idx+(d,)]=diff_el(result._arr[idx], s[d])
                result=Symbol(out, dim=self.dim)
            else:
                raise ValueError("diff: Only rank 0 and 1 supported")
        return result

    def grad(self, where=None):
        """
        Returns a new Symbol with one additional trailing axis of length
        `getDim()`; entry [..., d] is ``grad_n(component, d[, where])``.
        `grad_n` comes from the sibling `functions` module (presumably the
        spatial derivative in direction d -- see that module).

        :param where: optional scalar symbol passed on to grad_n
        :raise ValueError: if `where` is a non-scalar `Symbol`
        :rtype: ``Symbol``
        """
        if isinstance(where, Symbol):
            if where.getRank()>0:
                raise ValueError("grad: 'where' must be a scalar symbol")
            where=where._arr.item()

        # NOTE(review): Python 2 implicit relative import of the sibling
        # `functions` module; Python 3 needs `from .functions import grad_n`.
        # Confirm the package layout before changing.
        from functions import grad_n
        shape=self.getShape()
        out=numpy.empty(shape+(self.dim,), dtype=object)
        for d in range(self.dim):
            for idx in numpy.ndindex(shape):
                if where is None:
                    out[idx+(d,)]=grad_n(self._arr[idx], d)
                else:
                    out[idx+(d,)]=grad_n(self._arr[idx], d, where)
        return Symbol(out, dim=self.dim)

    def inverse(self):
        """
        Returns the inverse of this square rank 2 symbol of size 1, 2 or 3,
        computed with explicit cofactor formulas. The reciprocal of the
        determinant is formed as ``1./D``, so results carry a float factor.

        :raise TypeError: if the rank is not 2 or the size exceeds 3
        :raise ValueError: if the shape is not square
        :raise ZeroDivisionError: if the determinant is known to be zero
        :rtype: ``Symbol``
        """
        if self.getRank()!=2:
            raise TypeError("inverse: Only rank 2 supported")
        s=self.getShape()
        if s[0]!=s[1]:
            raise ValueError("inverse: Only square shapes supported")

        def is_zero(v):
            # sympify so plain Python numbers also provide is_zero
            return sympy.sympify(v).is_zero

        out=numpy.zeros(s, dtype=object)
        arr=self._arr
        if s[0]==1:
            if is_zero(arr[0,0]):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            out[0,0]=1./arr[0,0]
        elif s[0]==2:
            A11=arr[0,0]
            A12=arr[0,1]
            A21=arr[1,0]
            A22=arr[1,1]
            D = A11*A22-A12*A21
            if is_zero(D):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]= A22*D
            out[1,0]=-A21*D
            out[0,1]=-A12*D
            out[1,1]= A11*D
        elif s[0]==3:
            A11=arr[0,0]
            A21=arr[1,0]
            A31=arr[2,0]
            A12=arr[0,1]
            A22=arr[1,1]
            A32=arr[2,1]
            A13=arr[0,2]
            A23=arr[1,2]
            A33=arr[2,2]
            D = A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
            if is_zero(D):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]=(A22*A33-A23*A32)*D
            out[1,0]=(A31*A23-A21*A33)*D
            out[2,0]=(A21*A32-A31*A22)*D
            out[0,1]=(A13*A32-A12*A33)*D
            out[1,1]=(A11*A33-A31*A13)*D
            out[2,1]=(A12*A31-A11*A32)*D
            out[0,2]=(A12*A23-A13*A22)*D
            out[1,2]=(A13*A21-A11*A23)*D
            out[2,2]=(A11*A22-A12*A21)*D
        else:
            raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
        return Symbol(out, dim=self.dim)

    def swap_axes(self, axis0, axis1):
        """
        Returns a new Symbol with axes `axis0` and `axis1` interchanged
        (see numpy.swapaxes).
        """
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self.dim)

    @staticmethod
    def _asArray(other):
        """Returns the component array of a `Symbol`, or `other` as ndarray."""
        if isinstance(other, Symbol):
            return other._arr
        return numpy.asarray(other)

    @staticmethod
    def _prod(dims):
        """Returns the product of the integers in `dims` (1 if empty)."""
        p=1
        for d in dims:
            p*=d
        return p

    def tensorProduct(self, other, axis_offset):
        """
        Contracts the last `axis_offset` axes of this symbol with the first
        `axis_offset` axes of `other`. The result has shape
        ``self.shape[:rank-axis_offset] + other.shape[axis_offset:]``.

        :param other: a `Symbol`, numpy array or nested list
        :raise ValueError: if the contracted axes do not match in size
        :rtype: ``Symbol``
        """
        arg1=Symbol._asArray(other)
        sh0=self.getShape()
        sh1=arg1.shape
        lead=sh0[:self.getRank()-axis_offset]
        trail=sh1[axis_offset:]
        d0=Symbol._prod(lead)
        d1=Symbol._prod(trail)
        d01=Symbol._prod(sh1[:axis_offset])
        # reshape (unlike in-place resize) fails loudly on mismatched shapes
        a0=self._arr.reshape((d0,d01))
        a1=arg1.reshape((d01,d1))
        out=numpy.zeros((d0,d1), dtype=object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(a0[i0,:]*a1[:,i1])
        return Symbol(out.reshape(lead+trail), dim=self.dim)

    def transposedTensorProduct(self, other, axis_offset):
        """
        Contracts the first `axis_offset` axes of this symbol with the first
        `axis_offset` axes of `other`. The result has shape
        ``self.shape[axis_offset:] + other.shape[axis_offset:]``.

        :param other: a `Symbol`, numpy array or nested list
        :raise ValueError: if the contracted axes do not match in size
        :rtype: ``Symbol``
        """
        arg1=Symbol._asArray(other)
        sh0=self.getShape()
        sh1=arg1.shape
        d0=Symbol._prod(sh0[axis_offset:])
        d1=Symbol._prod(sh1[axis_offset:])
        d01=Symbol._prod(sh1[:axis_offset])
        a0=self._arr.reshape((d01,d0))
        a1=arg1.reshape((d01,d1))
        out=numpy.zeros((d0,d1), dtype=object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(a0[:,i0]*a1[:,i1])
        return Symbol(out.reshape(sh0[axis_offset:]+sh1[axis_offset:]), dim=self.dim)

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Contracts the last `axis_offset` axes of this symbol with the last
        `axis_offset` axes of `other`. The result has shape
        ``self.shape[:r0-axis_offset] + other.shape[:r1-axis_offset]`` where
        r0 and r1 are the ranks of this symbol and of `other`.

        :param other: a `Symbol`, numpy array or nested list
        :raise ValueError: if the contracted axes do not match in size
        :rtype: ``Symbol``
        """
        arg1=Symbol._asArray(other)
        sh0=self.getShape()
        sh1=arg1.shape
        r1=arg1.ndim
        lead0=sh0[:self.getRank()-axis_offset]
        lead1=sh1[:r1-axis_offset]
        d0=Symbol._prod(lead0)
        d1=Symbol._prod(lead1)
        d01=Symbol._prod(sh1[r1-axis_offset:])
        a0=self._arr.reshape((d0,d01))
        a1=arg1.reshape((d1,d01))
        out=numpy.zeros((d0,d1), dtype=object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(a0[i0,:]*a1[i1,:])
        return Symbol(out.reshape(lead0+lead1), dim=self.dim)

    def trace(self, axis_offset):
        """
        Returns the sum over the diagonal formed by axes `axis_offset` and
        `axis_offset+1`; the result has rank ``getRank()-2``.

        :raise ValueError: if those two axes differ in length
        :rtype: ``Symbol``
        """
        sh=self.getShape()
        n=sh[axis_offset]
        s1=Symbol._prod(sh[:axis_offset])
        s2=Symbol._prod(sh[axis_offset+2:])
        arr_r=numpy.reshape(self._arr, (s1,n,n,s2))
        out=numpy.zeros((s1,s2), dtype=object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(n):
                    out[i1,i2]+=arr_r[i1,j,j,i2]
        return Symbol(out.reshape(sh[:axis_offset]+sh[axis_offset+2:]), dim=self.dim)

    def transpose(self, axis_offset=None):
        """
        Returns a new Symbol whose first `axis_offset` axes are moved to the
        end. If `axis_offset` is None, half the rank (rounded down) is used.

        :rtype: ``Symbol``
        """
        rank=self.getRank()
        if axis_offset is None:
            axis_offset=rank//2
        axes=list(range(axis_offset, rank))+list(range(0, axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes), dim=self.dim)

    def applyfunc(self, f, on_type=None):
        """
        Applies `f` to every component and returns the results as a new
        Symbol. If `on_type` is given, only components that are instances of
        `on_type` are passed to `f`; all others are copied unchanged.
        For a rank 0 symbol, None is returned if `f` returns None.

        :param f: callable taking one component
        :param on_type: optional type (or tuple of types) to restrict `f` to
        :raise TypeError: if `f` is not callable
        :rtype: ``Symbol`` (or None, see above)
        """
        if not callable(f):
            raise TypeError("applyfunc: argument must be callable")

        def apply(el):
            if on_type is None or isinstance(el, on_type):
                return f(el)
            return el

        if self.getRank()==0:
            el=apply(self._arr.item())
            if el is None:
                return None
            return Symbol(el, dim=self.dim)
        out=numpy.empty(self.getShape(), dtype=object)
        for idx in numpy.ndindex(self.getShape()):
            out[idx]=apply(self._arr[idx])
        return Symbol(out, dim=self.dim)

    def expand(self):
        """
        Returns a new Symbol with sympy.expand applied to all sympy components.
        """
        return self.applyfunc(sympy.expand, sympy.Basic)

    def simplify(self):
        """
        Returns a new Symbol with sympy.simplify applied to all sympy components.
        """
        return self.applyfunc(sympy.simplify, sympy.Basic)

    def _sympy_(self):
        """
        Hook used by sympy.sympify; returns a Symbol with every component
        sympified.
        """
        return self.applyfunc(sympy.sympify)

    def _ensureShapeCompatible(self, other):
        """
        Checks for compatible shapes for binary operations: the shapes must be
        equal or one of the operands must be scalar.

        :raise TypeError: if `other` has an unsupported type or the shapes are
                          not compatible
        """
        sh0=self.getShape()
        if isinstance(other, Symbol):
            sh1=other.getShape()
        elif isinstance(other, numpy.ndarray):
            sh1=other.shape
        elif isinstance(other, list):
            sh1=numpy.array(other).shape
        elif isinstance(other, (int, float, complex, numpy.generic)) or isinstance(other, sympy.Basic):
            sh1=()
        else:
            raise TypeError("Unsupported argument type '%s' for operation"%other.__class__.__name__)
        if sh0!=sh1 and sh0!=() and sh1!=():
            raise TypeError("Incompatible shapes for operation")

    @staticmethod
    def _symComp(sym):
        """
        Splits a component name of the form '[name]_i_j' into
        ``('name', (i, j))``. Names not of that form are returned as
        ``(name, ())``.
        """
        n=sym.name
        a=n.split('[')
        if len(a)!=2:
            return n,()
        a=a[1].split(']')
        if len(a)!=2:
            return n,()
        name=a[0]
        comps=[int(i) for i in a[1].split('_')[1:]]
        return name,tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        An integer directly following a (sympy or escript) symbol repeats
        that symbol; integers are never yielded themselves.

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat
        items=[s if isinstance(s, Symbol) else sympy.sympify(s) for s in symbols]
        for i,s in enumerate(items):
            if isinstance(s, sympy.Integer):
                continue
            # look ahead by position, not by value: the same symbol may
            # appear more than once (e.g. diff(x, 2, x))
            nxt=items[i+1] if i+1<len(items) else None
            if isinstance(s, (Symbol, sympy.Symbol)) and isinstance(nxt, sympy.Integer):
                for copy_s in repeat(s, int(nxt)):
                    yield copy_s
            else:
                yield s

    def __array__(self, dtype=None, copy=None):
        """
        numpy array protocol. Without arguments the internal component array
        itself (not a copy) is returned; `dtype` and `copy` are honoured for
        numpy callers that pass them.
        """
        if dtype is not None:
            return self._arr.astype(dtype)
        if copy:
            return self._arr.copy()
        return self._arr

    # unary/binary operations follow

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr, dim=self.dim)

    def __abs__(self):
        return Symbol(abs(self._arr), dim=self.dim)

    def __add__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr+other._arr, dim=self.dim)
        return Symbol(self._arr+other, dim=self.dim)

    def __radd__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr+self._arr, dim=self.dim)
        return Symbol(other+self._arr, dim=self.dim)

    def __sub__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr-other._arr, dim=self.dim)
        return Symbol(self._arr-other, dim=self.dim)

    def __rsub__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr-self._arr, dim=self.dim)
        return Symbol(other-self._arr, dim=self.dim)

    def __mul__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr*other._arr, dim=self.dim)
        return Symbol(self._arr*other, dim=self.dim)

    def __rmul__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr*self._arr, dim=self.dim)
        return Symbol(other*self._arr, dim=self.dim)

    def __div__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr/other._arr, dim=self.dim)
        return Symbol(self._arr/other, dim=self.dim)

    def __rdiv__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr/self._arr, dim=self.dim)
        return Symbol(other/self._arr, dim=self.dim)

    # Python 3 uses __truediv__/__rtruediv__ for the '/' operator
    __truediv__=__div__
    __rtruediv__=__rdiv__

    def __pow__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr**other._arr, dim=self.dim)
        return Symbol(self._arr**other, dim=self.dim)

    def __rpow__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr**self._arr, dim=self.dim)
        return Symbol(other**self._arr, dim=self.dim)

  ViewVC Help
Powered by ViewVC 1.1.26