/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Annotation of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3862 - (hide annotations)
Fri Mar 9 06:32:35 2012 UTC (7 years, 7 months ago) by caltinay
Original Path: branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py
File MIME type: text/x-python
File size: 26662 byte(s)
Symbol's item accessor now returns escript Symbols instead of sympy or numpy.

1 caltinay 3507
2     ########################################################
3     #
4 caltinay 3815 # Copyright (c) 2003-2012 by University of Queensland
5 caltinay 3507 # Earth Systems Science Computational Center (ESSCC)
6     # http://www.uq.edu.au/esscc
7     #
8     # Primary Business: Queensland, Australia
9     # Licensed under the Open Software License version 3.0
10     # http://www.opensource.org/licenses/osl-3.0.php
11     #
12     ########################################################
13    
14 caltinay 3815 __copyright__="""Copyright (c) 2003-2012 by University of Queensland
15 caltinay 3507 Earth Systems Science Computational Center (ESSCC)
16     http://www.uq.edu.au/esscc
17     Primary Business: Queensland, Australia"""
18     __license__="""Licensed under the Open Software License version 3.0
19     http://www.opensource.org/licenses/osl-3.0.php"""
20     __url__="https://launchpad.net/escript-finley"
21    
22     """
23     :var __author__: name of author
24     :var __copyright__: copyrights
25     :var __license__: licence agreement
26     :var __url__: url entry point on documentation
27     :var __version__: version
28     :var __date__: date of the version
29     """
30    
31     import numpy
32     import sympy
33    
34     __author__="Cihan Altinay"
35    
36 gross 3856
37    
class Symbol(object):
    """
    `Symbol` objects are placeholders for a single mathematical symbol, such as
    'x', or for arbitrarily complex mathematical expressions such as
    'c*x**4+alpha*exp(x)-2*sin(beta*x)', where 'alpha', 'beta', 'c', and 'x'
    are also `Symbol`s (the symbolic 'atoms' of the expression).

    With the help of the 'Evaluator' class these symbols and expressions can
    be resolved by substituting numeric values and/or escript `Data` objects
    for the atoms. To facilitate the use of `Data` objects a `Symbol` has a
    shape (and thus a rank) as well as a dimension (see constructor).
    `Symbol`s are useful to perform mathematical simplifications, compute
    derivatives and as coefficients for nonlinear PDEs which can be solved by
    the `NonlinearPDE` class.

    Internally the components are stored in a numpy object array (`_arr`)
    whose elements are sympy objects or plain Python numbers.
    """

    # these are for compatibility with sympy.Symbol. lambdify checks these.
    is_Add=False
    is_Float=False

    def __init__(self, *args, **kwargs):
        """
        Initialises a new `Symbol` object in one of three ways::

            u=Symbol('u')

        returns a scalar symbol by the name 'u'.

            a=Symbol('alpha', (4,3))

        returns a rank 2 symbol with the shape (4,3), whose elements are
        named '[alpha]_i_j' (with i=0..3, j=0..2).

            a,b,c=symbols('a,b,c')
            x=Symbol([[a+b,0,0],[0,b-c,0],[0,0,c-a]])

        returns a rank 2 symbol with the shape (3,3) whose elements are
        explicitly specified by numeric values and other symbols/expressions
        within a list or numpy array.

        The dimensionality of the symbol can be specified through the `dim`
        keyword. All other keywords are passed to the underlying symbolic
        library (currently sympy).

        :param args: initialisation arguments as described above
        :keyword dim: dimensionality of the new Symbol (default: 2)
        :type dim: ``int``
        :raise TypeError: if the number or the types of the arguments are
                          not supported
        :raise ValueError: if a name contains '[' or ']' or if the rank
                           exceeds 4
        """
        self.dim = kwargs.pop('dim', 2)

        if len(args) == 1:
            arg = args[0]
            if isinstance(arg, str):
                Symbol._checkName(arg)
                self._arr = numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__") or isinstance(arg, list):
                if isinstance(arg, list):
                    arg = numpy.array(arg)
                arr = arg.__array__()
                if arr.ndim > 4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                res = numpy.empty(arr.shape, dtype=object)
                for idx in numpy.ndindex(arr.shape):
                    el = arr[idx]
                    # unwrap objects such as numpy scalars into Python values
                    res[idx] = el.item() if hasattr(el, "item") else el
                self._arr = res
            elif isinstance(arg, sympy.Basic):
                self._arr = numpy.array(arg)
            else:
                raise TypeError("Unsupported argument type %s"%str(type(arg)))
        elif len(args) == 2:
            name, shape = args
            if not isinstance(name, str):
                raise TypeError("First argument must be a string")
            if not isinstance(shape, tuple):
                raise TypeError("Second argument must be a tuple")
            Symbol._checkName(name)
            if len(shape) > 4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape) == 0:
                self._arr = numpy.array(sympy.Symbol(name, **kwargs))
            else:
                # try both argument orders of sympy.symarray (presumably the
                # order differs between sympy versions)
                try:
                    self._arr = sympy.symarray(shape, '['+name+']')
                except TypeError:
                    self._arr = sympy.symarray('['+name+']', shape)
        else:
            raise TypeError("Unsupported number of arguments")

        if self._arr.ndim == 0:
            self.name = str(self._arr.item())
        else:
            self.name = str(self._arr.tolist())

    @staticmethod
    def _checkName(name):
        """
        Raises ValueError if `name` contains '[' or ']'. These characters are
        used for the component names '[name]_i_j' (see `_symComp`).
        """
        if '[' in name or ']' in name:
            raise ValueError("Name must not contain '[' or ']'")

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        """
        Two Symbols are equal if they have the same shape and all their
        components compare equal.
        """
        if type(self) is not type(other):
            return False
        if self.getRank() != other.getRank():
            return False
        if self.getShape() != other.getShape():
            return False
        return bool((self._arr == other._arr).all())

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__
        return not self.__eq__(other)

    def __getitem__(self, key):
        """
        Returns an element of this symbol which must have rank >0.
        Unlike item() this method converts sympy objects and numpy arrays into
        escript Symbols (with the same dimensionality as this symbol) in order
        to facilitate expressions that require element access, such as:
        grad(u)[1]+x

        :param key: (nd-)index of item to be returned
        :return: the requested element
        :rtype: ``Symbol``, ``int``, or ``float``
        """
        res = self._arr[key]
        # replace sub-arrays and sympy Symbols/expressions by escript Symbols
        if isinstance(res, numpy.ndarray) or isinstance(res, sympy.Basic):
            res = Symbol(res, dim=self.dim)
        return res

    def __setitem__(self, key, value):
        """
        Sets the element(s) at `key` to `value`.

        `value` may be a Symbol (rank 0, or with the same shape as the
        addressed sub-array), a sympy object, an object providing
        `__array__` (whose components are sympified), or anything accepted
        by sympy.sympify.

        :raise ValueError: if a non-scalar Symbol does not match the shape of
                           the addressed sub-array
        """
        if isinstance(value, Symbol):
            if value.getRank() == 0:
                self._arr[key] = value.item()
            elif hasattr(self._arr[key], "shape"):
                target = self._arr[key]
                if target.shape == value.getShape():
                    for idx in numpy.ndindex(target.shape):
                        # copy the raw components; value[idx] would return
                        # plain numbers or wrapped Symbols
                        target[idx] = value._arr[idx]
                else:
                    raise ValueError("Wrong shape of value")
            else:
                raise ValueError("Wrong shape of value")
        elif isinstance(value, sympy.Basic):
            self._arr[key] = value
        elif hasattr(value, "__array__"):
            arr = value.__array__()
            if arr.ndim == 0:
                self._arr[key] = sympy.sympify(arr.item())
            else:
                converted = numpy.empty(arr.shape, dtype=object)
                for idx in numpy.ndindex(arr.shape):
                    converted[idx] = sympy.sympify(arr[idx])
                self._arr[key] = converted
        else:
            self._arr[key] = sympy.sympify(value)

    def __iter__(self):
        """
        Iterates over the first axis of the underlying component array
        (raw components, not wrapped in Symbols).
        """
        return iter(self._arr)

    def getDim(self):
        """
        Returns the spatial dimensionality of this symbol.

        :return: the symbol's spatial dimensionality
        :rtype: ``int``
        """
        return self.dim

    def getRank(self):
        """
        Returns the rank of this symbol.

        :return: the symbol's rank which is equal to the length of the shape.
        :rtype: ``int``
        """
        return self._arr.ndim

    def getShape(self):
        """
        Returns the shape of this symbol.

        :return: the symbol's shape
        :rtype: ``tuple`` of ``int``
        """
        return self._arr.shape

    def item(self, *args):
        """
        Returns an element of this symbol.
        This method behaves like the item() method of numpy.ndarray.
        If this is a scalar Symbol, no arguments are allowed and the only
        element in this Symbol is returned.
        Otherwise, 'args' specifies a flat or nd-index (given either as
        separate integers or as one tuple) and the element at that index is
        returned.

        :param args: index of item to be returned
        :return: the requested element
        :rtype: ``sympy.Symbol``, ``int``, or ``float``
        """
        return self._arr.item(*args)

    def atoms(self, *types):
        """
        Returns the atoms that form the current Symbol.

        By default, only objects that are truly atomic and cannot be divided
        into smaller pieces are returned: symbols, numbers, and number
        symbols like I and pi. It is possible to request atoms of any type,
        however.

        Note that if this symbol contains components such as [x]_i_j then
        only their main symbol 'x' is returned.

        :param types: types to restrict result to
        :return: list of atoms of specified type
        :rtype: ``set``
        """
        s = set()
        for el in self._arr.flat:
            if isinstance(el, sympy.Basic):
                for a in el.atoms(*types):
                    if a.is_Symbol:
                        # map component names like [x]_0_1 to their main symbol
                        n, _ = Symbol._symComp(a)
                        s.add(sympy.Symbol(n))
                    else:
                        s.add(a)
            elif len(types) == 0 or type(el) in types:
                s.add(el)
        return s

    def _sympystr_(self, printer):
        # alias of _sympystr; presumably some sympy versions look up this name
        return self._sympystr(printer)

    def _sympystr(self, printer):
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string representation of this symbol for use with lambdify.

        Component names such as '[x]_0_1' are rewritten as 'x[0,1]'. For
        symbols of rank > 0 the component strings are wrapped as
        'combineData(<nested list>,<shape>)'.

        :rtype: ``str``
        """
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr = numpy.empty(self.getShape(), dtype=object)
        for idx, el in numpy.ndenumerate(self._arr):
            atoms = el.atoms(sympy.Symbol) if isinstance(el, sympy.Basic) else []
            # create a dictionary to convert names like [x]_0_0 to x[0,0]
            symdict = {}
            for a in atoms:
                n, c = Symbol._symComp(a)
                if len(c) > 0:
                    symstr = n+'['+','.join(str(i) for i in c)+']'
                else:
                    symstr = n
                symdict[a.name] = symstr
            s = lambdarepr(el)
            # replace longer names first so that e.g. '[x]_1' cannot clobber
            # part of '[x]_10'
            for key in sorted(symdict, key=len, reverse=True):
                s = s.replace(key, symdict[key])
            temp_arr[idx] = s
        if self.getRank() == 0:
            return temp_arr.item()
        return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))

    def coeff(self, x, expand=True):
        """
        Returns the coefficient of the term "x" or 0 if there is no "x".

        If "x" is a scalar symbol then "x" is searched in all components of
        this symbol. Otherwise the shapes must match and the coefficients are
        checked component by component. Numeric (non-symbolic) components
        have coefficient 0.

        Example::

            x=Symbol('x', (2,2))
            y=3*x
            print(y.coeff(x))
            print(y.coeff(x[1,1]))

        will print::

            [[3 3]
             [3 3]]

            [[0 0]
             [0 3]]

        :param x: the term whose coefficients are to be found
        :type x: ``Symbol``, ``numpy.ndarray``, `list`
        :param expand: passed on to the coeff() method of the components
        :return: the coefficient(s) of the term
        :rtype: ``Symbol``
        """
        self._ensureShapeCompatible(x)
        if hasattr(x, '__array__'):
            y = x.__array__()
        else:
            y = numpy.array(x)

        if y.ndim > 0:
            result = numpy.zeros(self.getShape(), dtype=object)
            for idx in numpy.ndindex(y.shape):
                if y[idx] != 0:
                    result[idx] = Symbol._coeffItem(self._arr[idx], y[idx], expand)
        elif y.item() == 0:
            result = numpy.zeros(self.getShape(), dtype=object)
        else:
            term = y.item()
            result = numpy.empty(self.getShape(), dtype=object)
            for idx in numpy.ndindex(self.getShape()):
                result[idx] = Symbol._coeffItem(self._arr[idx], term, expand)
        return Symbol(result, dim=self.dim)

    def subs(self, old, new):
        """
        Substitutes an expression.

        If `old` is a Symbol of rank > 0, each of its components is replaced
        by the corresponding component of `new` (or by `new` itself if `new`
        is a scalar) throughout all components of this symbol. Otherwise
        `old` and `new` must be scalars. Numeric components are left
        unchanged.

        :param old: the expression to be replaced
        :param new: the replacement
        :return: a new Symbol with the substitution applied
        :rtype: ``Symbol``
        :raise TypeError: if a scalar `old` is to be replaced by a non-scalar
        """
        if isinstance(old, Symbol) and old.getRank() > 0:
            old._ensureShapeCompatible(new)
            if hasattr(new, '__array__'):
                new = new.__array__()
            else:
                new = numpy.array(new)

            if new.ndim > 0:
                pairs = [(old._arr[i], new[i]) for i in numpy.ndindex(new.shape)]
            else: # substitute scalar for non-scalar
                value = new.item()
                pairs = [(old._arr[i], value) for i in numpy.ndindex(old.getShape())]

            out = numpy.empty(self.getShape(), dtype=object)
            for idx in numpy.ndindex(self.getShape()):
                el = self._arr[idx]
                for o, n in pairs:
                    el = Symbol._subsItem(el, o, n)
                out[idx] = el
            return Symbol(out, dim=self.dim)

        # scalar
        if isinstance(new, Symbol):
            if new.getRank() > 0:
                raise TypeError("Cannot substitute, incompatible ranks.")
            new = new.item()
        if isinstance(old, Symbol):
            old = old.item()
        return self.applyfunc(lambda item: Symbol._subsItem(item, old, new))

    def diff(self, *symbols, **assumptions):
        """
        Returns the derivative of this symbol with respect to `symbols`.

        `symbols` is interpreted like the arguments of sympy's diff(): a
        symbol followed by an integer n means differentiating n times with
        respect to that symbol. Differentiating with respect to a scalar
        (rank 0) symbol keeps the shape. Differentiating with respect to a
        rank 1 symbol of length n appends a trailing axis of length n.
        Numeric (non-symbolic) components have derivative 0.

        :param symbols: the symbols to differentiate with respect to
        :param assumptions: keyword arguments passed on to sympy's diff()
        :return: the derivative
        :rtype: ``Symbol``
        :raise ValueError: if a symbol of rank > 1 is given
        """
        result = Symbol(self._arr, dim=self.dim)
        for s in Symbol._symbolgen(*symbols):
            if not isinstance(s, Symbol):
                result = result.applyfunc(
                    lambda item, t=s: Symbol._diffItem(item, t, **assumptions))
            elif s.getRank() == 0:
                target = s._arr.item()
                result = result.applyfunc(
                    lambda item, t=target: Symbol._diffItem(item, t, **assumptions))
            elif s.getRank() == 1:
                n = s.getShape()[0]
                # use the shape of the derivative computed so far, which may
                # already have grown by earlier vector derivatives
                shape = result.getShape()
                out = result._arr.copy().reshape(shape+(1,)).repeat(n, axis=len(shape))
                for d in range(n):
                    target = s._arr[d]
                    for idx in numpy.ndindex(shape):
                        index = idx+(d,)
                        out[index] = Symbol._diffItem(out[index], target, **assumptions)
                result = Symbol(out, dim=self.dim)
            else:
                raise ValueError("diff: Only rank 0 and 1 supported")
        return result

    def grad(self, where=None):
        """
        Returns the gradient of this symbol.

        The result has one additional trailing axis of length `dim`; its
        component [..., d] is grad_n(component, d) (or
        grad_n(component, d, where) if `where` is given), with grad_n taken
        from the `functions` module of this package.

        :param where: optional scalar symbol passed on to grad_n
        :type where: ``Symbol`` or sympy object
        :rtype: ``Symbol``
        :raise ValueError: if `where` is a Symbol of rank > 0
        """
        if isinstance(where, Symbol):
            if where.getRank() > 0:
                raise ValueError("grad: 'where' must be a scalar symbol")
            where = where._arr.item()

        try:
            from .functions import grad_n
        except (ImportError, ValueError):
            # fall back to the implicit relative import of Python 2
            from functions import grad_n

        shape = self.getShape()
        out = self._arr.copy().reshape(shape+(1,)).repeat(self.dim, axis=self.getRank())
        for d in range(self.dim):
            for idx in numpy.ndindex(shape):
                index = idx+(d,)
                if where is None:
                    out[index] = grad_n(out[index], d)
                else:
                    out[index] = grad_n(out[index], d, where)
        return Symbol(out, dim=self.dim)

    def inverse(self):
        """
        Returns the inverse of this square rank 2 symbol, using explicit
        formulas for matrix sizes 1, 2 and 3.

        :rtype: ``Symbol``
        :raise TypeError: if the rank is not 2 or the size exceeds 3
        :raise ValueError: if the shape is not square
        :raise ZeroDivisionError: if the determinant is known to be zero
        """
        if not self.getRank() == 2:
            raise TypeError("inverse: Only rank 2 supported")
        s = self.getShape()
        if not s[0] == s[1]:
            raise ValueError("inverse: Only square shapes supported")
        out = numpy.zeros(s, object)
        arr = self._arr
        if s[0] == 1:
            if Symbol._isZero(arr[0,0]):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            out[0,0] = 1./arr[0,0]
        elif s[0] == 2:
            A11 = arr[0,0]
            A12 = arr[0,1]
            A21 = arr[1,0]
            A22 = arr[1,1]
            D = A11*A22-A12*A21
            if Symbol._isZero(D):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D = 1./D
            out[0,0] = A22*D
            out[1,0] = -A21*D
            out[0,1] = -A12*D
            out[1,1] = A11*D
        elif s[0] == 3:
            A11 = arr[0,0]
            A21 = arr[1,0]
            A31 = arr[2,0]
            A12 = arr[0,1]
            A22 = arr[1,1]
            A32 = arr[2,1]
            A13 = arr[0,2]
            A23 = arr[1,2]
            A33 = arr[2,2]
            D = A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
            if Symbol._isZero(D):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D = 1./D
            out[0,0] = (A22*A33-A23*A32)*D
            out[1,0] = (A31*A23-A21*A33)*D
            out[2,0] = (A21*A32-A31*A22)*D
            out[0,1] = (A13*A32-A12*A33)*D
            out[1,1] = (A11*A33-A31*A13)*D
            out[2,1] = (A12*A31-A11*A32)*D
            out[0,2] = (A12*A23-A13*A22)*D
            out[1,2] = (A13*A21-A11*A23)*D
            out[2,2] = (A11*A22-A12*A21)*D
        else:
            raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
        return Symbol(out, dim=self.dim)

    def swap_axes(self, axis0, axis1):
        """
        Returns a new Symbol with the axes `axis0` and `axis1` interchanged
        (see numpy.swapaxes).

        :rtype: ``Symbol``
        """
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self.dim)

    def tensorProduct(self, other, axis_offset):
        """
        Returns the tensor product of this symbol and `other`. The last
        `axis_offset` axes of this symbol are contracted with the first
        `axis_offset` axes of `other`. The result has the shape
        self.getShape()[:rank-axis_offset]+other.shape[axis_offset:].

        :param other: second operand
        :type other: ``Symbol`` or ``numpy.ndarray``
        :param axis_offset: number of axes to contract
        :type axis_offset: ``int``
        :rtype: ``Symbol``
        :raise ValueError: if the total sizes of the contracted axes differ
        """
        sh0 = self.getShape()
        arg1 = Symbol._operandArray(other)
        sh1 = arg1.shape
        head = sh0[:self.getRank()-axis_offset]
        tail = sh1[axis_offset:]
        d01 = Symbol._prod(sh1[:axis_offset])
        arg0 = self._arr.reshape((Symbol._prod(head), d01))
        arg1 = arg1.reshape((d01, Symbol._prod(tail)))
        out = Symbol._contract(arg0, arg1)
        return Symbol(out.reshape(head+tail), dim=self.dim)

    def transposedTensorProduct(self, other, axis_offset):
        """
        Returns the tensor product of the transpose of this symbol and
        `other`. The first `axis_offset` axes of this symbol are contracted
        with the first `axis_offset` axes of `other`. The result has the shape
        self.getShape()[axis_offset:]+other.shape[axis_offset:].

        :param other: second operand
        :type other: ``Symbol`` or ``numpy.ndarray``
        :param axis_offset: number of axes to contract
        :type axis_offset: ``int``
        :rtype: ``Symbol``
        :raise ValueError: if the total sizes of the contracted axes differ
        """
        sh0 = self.getShape()
        arg1 = Symbol._operandArray(other)
        sh1 = arg1.shape
        head = sh0[axis_offset:]
        tail = sh1[axis_offset:]
        d01 = Symbol._prod(sh1[:axis_offset])
        arg0 = self._arr.reshape((d01, Symbol._prod(head)))
        arg1 = arg1.reshape((d01, Symbol._prod(tail)))
        out = Symbol._contract(arg0.T, arg1)
        return Symbol(out.reshape(head+tail), dim=self.dim)

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Returns the tensor product of this symbol and the transpose of
        `other`. The last `axis_offset` axes of this symbol are contracted
        with the last `axis_offset` axes of `other`. The result has the shape
        self.getShape()[:rank-axis_offset]+other.shape[:other_rank-axis_offset].

        :param other: second operand
        :type other: ``Symbol`` or ``numpy.ndarray``
        :param axis_offset: number of axes to contract
        :type axis_offset: ``int``
        :rtype: ``Symbol``
        :raise ValueError: if the total sizes of the contracted axes differ
        """
        sh0 = self.getShape()
        arg1 = Symbol._operandArray(other)
        sh1 = arg1.shape
        r1 = arg1.ndim
        head = sh0[:self.getRank()-axis_offset]
        tail = sh1[:r1-axis_offset]
        d01 = Symbol._prod(sh1[r1-axis_offset:])
        arg0 = self._arr.reshape((Symbol._prod(head), d01))
        arg1 = arg1.reshape((Symbol._prod(tail), d01))
        out = Symbol._contract(arg0, arg1.T)
        return Symbol(out.reshape(head+tail), dim=self.dim)

    def trace(self, axis_offset):
        """
        Returns the trace over the axes `axis_offset` and `axis_offset+1`,
        which must have the same length. The result has rank
        self.getRank()-2.

        :param axis_offset: index of the first of the two axes
        :type axis_offset: ``int``
        :rtype: ``Symbol``
        """
        sh = self.getShape()
        s1 = Symbol._prod(sh[:axis_offset])
        s2 = Symbol._prod(sh[axis_offset+2:])
        n = sh[axis_offset]
        arr_r = numpy.reshape(self._arr, (s1, n, n, s2))
        out = numpy.zeros([s1, s2], object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(n):
                    out[i1,i2] += arr_r[i1,j,j,i2]
        return Symbol(out.reshape(sh[:axis_offset]+sh[axis_offset+2:]), dim=self.dim)

    def transpose(self, axis_offset=None):
        """
        Returns the transpose of this symbol: the first `axis_offset` axes
        are moved behind the remaining ones. If `axis_offset` is None, half
        the rank (rounded down) is used.

        :param axis_offset: number of leading axes to move to the end
        :type axis_offset: ``int`` or ``None``
        :rtype: ``Symbol``
        """
        rank = self._arr.ndim
        if axis_offset is None:
            axis_offset = rank//2
        axes = list(range(axis_offset, rank))+list(range(0, axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes), dim=self.dim)

    def applyfunc(self, f, on_type=None):
        """
        Applies `f` to each component and returns the results as a new
        Symbol. If `on_type` is given, only components that are instances of
        `on_type` are passed to `f`; the others are copied unchanged.
        For scalar symbols None is returned if `f` returns None.

        :param f: the function to apply
        :type f: callable
        :param on_type: optional type (or tuple of types) to restrict `f` to
        :rtype: ``Symbol`` (or ``None``, see above)
        :raise TypeError: if `f` is not callable
        """
        if not callable(f):
            raise TypeError("applyfunc: argument must be callable")
        if self._arr.ndim == 0:
            item = self._arr.item()
            if on_type is None or isinstance(item, on_type):
                el = f(item)
            else:
                el = item
            if el is None:
                return el
            if isinstance(el, (int, float, complex)):
                # the constructor does not accept plain Python numbers
                el = numpy.array(el, dtype=object)
            return Symbol(el, dim=self.dim)

        out = numpy.empty(self.getShape(), dtype=object)
        for idx in numpy.ndindex(self.getShape()):
            el = self._arr[idx]
            if on_type is None or isinstance(el, on_type):
                out[idx] = f(el)
            else:
                out[idx] = el
        return Symbol(out, dim=self.dim)

    def expand(self):
        """
        Returns a copy with sympy.expand applied to all sympy components.

        :rtype: ``Symbol``
        """
        return self.applyfunc(sympy.expand, sympy.Basic)

    def simplify(self):
        """
        Returns a copy with sympy.simplify applied to all sympy components.

        :rtype: ``Symbol``
        """
        return self.applyfunc(sympy.simplify, sympy.Basic)

    def _sympy_(self):
        """
        Returns a copy with every component passed through sympy.sympify.
        """
        return self.applyfunc(sympy.sympify)

    def _ensureShapeCompatible(self, other):
        """
        Checks for compatible shapes for binary operations.
        Shapes are compatible if they are equal or if either one is scalar.
        Raises TypeError if not compatible or if `other` has an unsupported
        type.
        """
        sh0 = self.getShape()
        if isinstance(other, Symbol):
            # TODO: escript Data objects are not accepted here yet
            sh1 = other.getShape()
        elif isinstance(other, numpy.ndarray):
            sh1 = other.shape
        elif isinstance(other, list):
            sh1 = numpy.array(other).shape
        elif isinstance(other, (int, float, numpy.number)) or isinstance(other, sympy.Basic):
            sh1 = ()
        else:
            raise TypeError("Unsupported argument type '%s' for operation"%other.__class__.__name__)
        if not sh0 == sh1 and not sh0 == () and not sh1 == ():
            raise TypeError("Incompatible shapes for operation")

    @staticmethod
    def _symComp(sym):
        """
        Splits a component name of the form '[name]_i_j...' into
        ('name', (i, j, ...)). Names not of this form are returned unchanged
        together with an empty tuple.
        """
        n = sym.name
        a = n.split('[')
        if len(a) != 2:
            return n, ()
        a = a[1].split(']')
        if len(a) != 2:
            return n, ()
        name = a[0]
        comps = [int(i) for i in a[1].split('_')[1:]]
        return name, tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        A symbol followed by an integer n is yielded n times; the integers
        themselves are never yielded. Arguments that are not escript
        `Symbol`s are passed through sympy.sympify first.

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat
        items = [s if isinstance(s, Symbol) else sympy.sympify(s) for s in symbols]
        for i, s in enumerate(items):
            if isinstance(s, sympy.Integer):
                # repetition counts are consumed together with their symbol
                continue
            # the successor is determined by position, not by value, so that
            # repeated symbols such as (x, 2, y, x) are handled correctly
            next_s = items[i+1] if i+1 < len(items) else None
            is_symbol = isinstance(s, Symbol) or isinstance(s, sympy.Symbol)
            if is_symbol and isinstance(next_s, sympy.Integer):
                # handle cases like (x, 3) -> (x, x, x)
                for copy_s in repeat(s, int(next_s)):
                    yield copy_s
            else:
                yield s

    @staticmethod
    def _isZero(value):
        """
        Returns True if `value` is known to be zero: uses the `is_zero`
        attribute of sympy objects (None counts as not zero) and `== 0` for
        plain numbers.
        """
        try:
            return bool(value.is_zero)
        except AttributeError:
            return value == 0

    @staticmethod
    def _diffItem(item, symbol, **assumptions):
        """
        Differentiates a single component; components without a diff()
        method (plain numbers) are constants and yield 0.
        """
        if hasattr(item, 'diff'):
            return item.diff(symbol, **assumptions)
        return 0

    @staticmethod
    def _subsItem(item, old, new):
        """
        Applies item.subs(old, new) to a single component; components without
        a subs() method (plain numbers) are returned unchanged.
        """
        if hasattr(item, 'subs'):
            return item.subs(old, new)
        return item

    @staticmethod
    def _coeffItem(item, term, expand):
        """
        Returns item.coeff(term, expand) for a single component, or 0 if the
        component has no coeff() method or if coeff() returns None.
        """
        if not hasattr(item, 'coeff'):
            return 0
        res = item.coeff(term, expand)
        return 0 if res is None else res

    @staticmethod
    def _prod(values):
        """Returns the product of the integers in `values` (1 if empty)."""
        result = 1
        for v in values:
            result *= v
        return result

    @staticmethod
    def _operandArray(other):
        """
        Returns a copy of the component array of `other` (a Symbol or an
        array object with a copy() method).
        """
        if isinstance(other, Symbol):
            return other._arr.copy()
        return other.copy()

    @staticmethod
    def _contract(a, b):
        """
        Returns the object array c of shape (a.shape[0], b.shape[1]) with
        c[i,j]=sum_k a[i,k]*b[k,j] for the 2-d arrays `a` and `b`.
        """
        out = numpy.zeros((a.shape[0], b.shape[1]), object)
        for i0 in range(a.shape[0]):
            for i1 in range(b.shape[1]):
                out[i0,i1] = numpy.sum(a[i0,:]*b[:,i1])
        return out

    def __array__(self):
        return self._arr

    # unary/binary operations follow

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr, dim=self.dim)

    def __abs__(self):
        return Symbol(abs(self._arr), dim=self.dim)

    def __add__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr+other._arr, dim=self.dim)
        return Symbol(self._arr+other, dim=self.dim)

    def __radd__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr+self._arr, dim=self.dim)
        return Symbol(other+self._arr, dim=self.dim)

    def __sub__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr-other._arr, dim=self.dim)
        return Symbol(self._arr-other, dim=self.dim)

    def __rsub__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr-self._arr, dim=self.dim)
        return Symbol(other-self._arr, dim=self.dim)

    def __mul__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr*other._arr, dim=self.dim)
        return Symbol(self._arr*other, dim=self.dim)

    def __rmul__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr*self._arr, dim=self.dim)
        return Symbol(other*self._arr, dim=self.dim)

    def __div__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr/other._arr, dim=self.dim)
        return Symbol(self._arr/other, dim=self.dim)

    def __rdiv__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr/self._arr, dim=self.dim)
        return Symbol(other/self._arr, dim=self.dim)

    # Python 3 (and Python 2 with 'from __future__ import division') use these
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr**other._arr, dim=self.dim)
        return Symbol(self._arr**other, dim=self.dim)

    def __rpow__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr**self._arr, dim=self.dim)
        return Symbol(other**self._arr, dim=self.dim)

  ViewVC Help
Powered by ViewVC 1.1.26