/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Annotation of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory | Revision Log


Revision 3990 - (hide annotations)
Tue Sep 25 05:03:20 2012 UTC (7 years ago) by caltinay
File MIME type: text/x-python
File size: 28839 byte(s)
First set of assorted epydoc fixes/additions.

1 caltinay 3507
2 jfenwick 3981 ##############################################################################
3 caltinay 3507 #
4 caltinay 3815 # Copyright (c) 2003-2012 by University of Queensland
5 jfenwick 3981 # http://www.uq.edu.au
6 caltinay 3507 #
7     # Primary Business: Queensland, Australia
8     # Licensed under the Open Software License version 3.0
9     # http://www.opensource.org/licenses/osl-3.0.php
10     #
11 jfenwick 3981 # Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12     # Development since 2012 by School of Earth Sciences
13     #
14     ##############################################################################
15 caltinay 3507
__copyright__="""Copyright (c) 2003-2012 by University of Queensland
http://www.uq.edu.au
Primary Business: Queensland, Australia"""
__license__="""Licensed under the Open Software License version 3.0
http://www.opensource.org/licenses/osl-3.0.php"""
__url__="https://launchpad.net/escript-finley"
__author__="Cihan Altinay"

# Only the module variables actually defined above are documented here.
"""
:var __author__: name of author
:var __copyright__: copyrights
:var __license__: licence agreement
:var __url__: url entry point on documentation
"""
32    
import numbers
import numpy
from esys.escript import Data, FunctionSpace, HAVE_SYMBOLS
if HAVE_SYMBOLS:
    import sympy
37 caltinay 3507
38 gross 3856
class Symbol(object):
    """
    `Symbol` objects are placeholders for a single mathematical symbol, such as
    'x', or for arbitrarily complex mathematical expressions such as
    'c*x**4+alpha*exp(x)-2*sin(beta*x)', where 'alpha', 'beta', 'c', and 'x'
    are also Symbols (the symbolic 'atoms' of the expression).

    With the help of the 'Evaluator' class these symbols and expressions can
    be resolved by substituting numeric values and/or escript `Data` objects
    for the atoms. To facilitate the use of `Data` objects a `Symbol` has a
    shape (and thus a rank) as well as a dimension (see constructor).
    Symbols are useful to perform mathematical simplifications, compute
    derivatives and as coefficients for nonlinear PDEs which can be solved by
    the `NonlinearPDE` class.
    """

    # these are for compatibility with sympy.Symbol. lambdify checks these.
    is_Add=False
    is_Float=False

    # __eq__ below compares by value, which under Python 3 would make
    # instances unhashable. Symbols are used as keys of the Data substitution
    # dictionaries (see subs(), grad() and the binary operators), so keep the
    # identity-based hash that Python 2 provides implicitly.
    __hash__=object.__hash__

    def __init__(self, *args, **kwargs):
        """
        Initialises a new `Symbol` object in one of three ways::

            u=Symbol('u')

        returns a scalar symbol by the name 'u'.

            a=Symbol('alpha', (4,3))

        returns a rank 2 symbol with the shape (4,3), whose elements are
        named '[alpha]_i_j' (with i=0..3, j=0..2).

            a,b,c=symbols('a,b,c')
            x=Symbol([[a+b,0,0],[0,b-c,0],[0,0,c-a]])

        returns a rank 2 symbol with the shape (3,3) whose elements are
        explicitly specified by numeric values and other symbols/expressions
        within a list or numpy array.

        A single sympy expression or a plain number is also accepted and
        results in a scalar symbol.

        The dimensionality of the symbol can be specified through the ``dim``
        keyword. All other keywords are passed to the underlying symbolic
        library (currently sympy).

        :param args: initialisation arguments as described above
        :keyword dim: dimensionality of the new Symbol (default: -1, i.e.
                      undefined, unless taken over from a `Symbol` argument)
        :type dim: ``int``
        :keyword subs: dictionary of `Data` substitutions. It is copied, so
                       the caller's dictionary is never modified.
        :type subs: ``dict``
        :raise RuntimeError: if sympy is not available
        :raise TypeError: if the arguments are of unsupported type or number
        :raise ValueError: if a name contains '[' or ']' or the rank exceeds 4
        """
        if not HAVE_SYMBOLS:
            raise RuntimeError("Trying to instantiate a Symbol but sympy not available")

        # -1 marks an undefined dimensionality
        self._dim=kwargs.pop('dim', -1)
        # Work on a private copy: callers pass their own dictionaries and the
        # Symbol branch below updates this one in place.
        self._subs=dict(kwargs.pop('subs', {}))

        if len(args)==1:
            arg=args[0]
            if isinstance(arg, str):
                if arg.find('[')>=0 or arg.find(']')>=0:
                    raise ValueError("Name must not contain '[' or ']'")
                self._arr=numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__") or isinstance(arg, list):
                if isinstance(arg, list): arg=numpy.array(arg)
                arr=arg.__array__()
                if len(arr.shape)>4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                res=numpy.empty(arr.shape, dtype=object)
                for idx in numpy.ndindex(arr.shape):
                    # unwrap numpy scalars into native Python numbers
                    if hasattr(arr[idx], "item"):
                        res[idx]=arr[idx].item()
                    else:
                        res[idx]=arr[idx]
                self._arr=res
                if isinstance(arg, Symbol):
                    self._subs.update(arg._subs)
                    if self._dim==-1:
                        self._dim=arg._dim
            elif isinstance(arg, sympy.Basic):
                self._arr=numpy.array(arg)
            elif isinstance(arg, numbers.Number):
                # plain numbers, e.g. results of applyfunc() on a scalar
                self._arr=numpy.empty((), dtype=object)
                self._arr[()]=sympy.sympify(arg)
            else:
                raise TypeError("Unsupported argument type %s"%str(type(arg)))
        elif len(args)==2:
            if not isinstance(args[0], str):
                raise TypeError("First argument must be a string")
            if not isinstance(args[1], tuple):
                raise TypeError("Second argument must be a tuple")
            name=args[0]
            shape=args[1]
            if name.find('[')>=0 or name.find(']')>=0:
                raise ValueError("Name must not contain '[' or ']'")
            if len(shape)>4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape)==0:
                self._arr=numpy.array(sympy.Symbol(name, **kwargs))
            else:
                # support both argument orders of sympy.symarray
                try:
                    self._arr=sympy.symarray(shape, '['+name+']')
                except TypeError:
                    self._arr=sympy.symarray('['+name+']', shape)
        else:
            raise TypeError("Unsupported number of arguments")
        # the name is the string form of the single element, or of the nested
        # list of elements for non-scalar symbols
        if self._arr.ndim==0:
            self.name=str(self._arr.item())
        else:
            self.name=str(self._arr.tolist())

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        """
        Two symbols are equal if they have the same type, rank and shape and
        all their elements compare equal.
        """
        if type(self) is not type(other):
            return False
        if self.getRank()!=other.getRank():
            return False
        if self.getShape()!=other.getShape():
            return False
        return bool(numpy.all(self._arr==other._arr))

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__
        return not self.__eq__(other)

    def __getitem__(self, key):
        """
        Returns an element of this symbol which must have rank >0.
        Unlike item() this method converts sympy objects and numpy arrays into
        escript Symbols in order to facilitate expressions that require
        element access, such as: grad(u)[1]+x
        The returned Symbol keeps the dimensionality and the `Data`
        substitutions of this symbol.

        :param key: (nd-)index of item to be returned
        :return: the requested element
        :rtype: ``Symbol``, ``int``, or ``float``
        """
        res=self._arr[key]
        # replace sympy Symbols/expressions by escript Symbols
        if isinstance(res, sympy.Basic) or isinstance(res, numpy.ndarray):
            res=Symbol(res, dim=self._dim, subs=self._subs)
        return res

    def __setitem__(self, key, value):
        """
        Sets the element(s) at ``key`` to ``value``.

        :param key: (nd-)index or slice into this symbol
        :param value: a `Symbol` (scalar, or of the same shape as the
                      addressed sub-array), a sympy expression, an array-like
                      object, or anything ``sympy.sympify`` accepts
        :raise ValueError: if a non-scalar `Symbol` value has the wrong shape
        """
        if isinstance(value, Symbol):
            if value.getRank()==0:
                self._arr[key]=value.item()
            elif hasattr(self._arr[key], "shape"):
                target=self._arr[key]
                if target.shape==value.getShape():
                    for idx in numpy.ndindex(target.shape):
                        # read the raw element; it may be a plain number
                        target[idx]=value._arr[idx]
                else:
                    raise ValueError("Wrong shape of value")
            else:
                raise ValueError("Wrong shape of value")
            # keep the Data substitutions the assigned symbol depends on
            self._subs.update(value._subs)
        elif isinstance(value, sympy.Basic):
            self._arr[key]=value
        elif hasattr(value, "__array__"):
            self._arr[key]=[sympy.sympify(v) for v in numpy.asarray(value).flat]
        else:
            self._arr[key]=sympy.sympify(value)

    def __iter__(self):
        """
        Iterates over the first axis of the underlying object array, yielding
        raw elements (sympy objects or numbers) or sub-arrays.
        Scalar symbols cannot be iterated (numpy raises TypeError).
        """
        return iter(self._arr)

    def __array__(self, t=None, copy=None):
        """
        Returns the underlying object array so numpy can operate on symbols.

        :param t: optional dtype to convert the array to
        :param copy: if true, a copy is returned instead of the internal
                     array (keyword used by the NumPy 2 ``__array__`` protocol)
        """
        if t:
            return self._arr.astype(t)
        if copy:
            return self._arr.copy()
        return self._arr

    def _sympy_(self):
        """
        Hook consulted by ``sympy.sympify``. Returns a new `Symbol` whose
        elements have all been passed through ``sympy.sympify``.
        """
        return self.applyfunc(sympy.sympify)

    def getDim(self):
        """
        Returns the spatial dimensionality of this symbol.

        :return: the symbol's spatial dimensionality, or -1 if undefined
        :rtype: ``int``
        """
        return self._dim

    def getRank(self):
        """
        Returns the rank of this symbol.

        :return: the symbol's rank which is equal to the length of the shape.
        :rtype: ``int``
        """
        return self._arr.ndim

    def getShape(self):
        """
        Returns the shape of this symbol.

        :return: the symbol's shape
        :rtype: ``tuple`` of ``int``
        """
        return self._arr.shape

    def getDataSubstitutions(self):
        """
        Returns a dictionary of symbol names and the escript ``Data`` objects
        they represent within this Symbol.
        Note that the internal dictionary is returned, not a copy.

        :return: the dictionary of substituted ``Data`` objects
        :rtype: ``dict``
        """
        return self._subs

    def item(self, *args):
        """
        Returns an element of this symbol.
        This method behaves like the item() method of numpy.ndarray.
        If this is a scalar Symbol, no arguments are allowed and the only
        element in this Symbol is returned.
        Otherwise, 'args' specifies a flat or nd-index and the element at
        that index is returned.

        :param args: index of item to be returned
        :return: the requested element
        :rtype: ``sympy.Symbol``, ``int``, or ``float``
        """
        # unpack so that a single flat index works for any rank
        return self._arr.item(*args)

    def atoms(self, *types):
        """
        Returns the atoms that form the current Symbol.

        By default, only objects that are truly atomic and cannot be divided
        into smaller pieces are returned: symbols, numbers, and number
        symbols like I and pi. It is possible to request atoms of any type,
        however.

        Note that if this symbol contains components such as [x]_i_j then
        only their main symbol 'x' is returned.

        :param types: types to restrict result to
        :return: list of atoms of specified type
        :rtype: ``set``
        """
        s=set()
        for el in self._arr.flat:
            if isinstance(el,sympy.Basic):
                for a in el.atoms(*types):
                    if a.is_Symbol:
                        n,c=Symbol._symComp(a)
                        s.add(sympy.Symbol(n))
                    else:
                        s.add(a)
            elif len(types)==0 or type(el) in types:
                s.add(el)
        return s

    def _sympystr_(self, printer):
        # alternative printer hook name; delegates to _sympystr
        return self._sympystr(printer)

    def _sympystr(self, printer):
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string representation of this symbol produced by sympy's
        ``lambdarepr`` printer, with component names such as '[x]_0_1'
        rewritten as 'x[0,1]'.
        For a scalar symbol this is the representation of its single element,
        otherwise it is 'combineData(<nested list>,<shape>)'.

        :rtype: ``str``
        """
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr=numpy.empty(self.getShape(), dtype=object)
        for idx,el in numpy.ndenumerate(self._arr):
            atoms=el.atoms(sympy.Symbol) if isinstance(el,sympy.Basic) else []
            # create a dictionary to convert names like [x]_0_0 to x[0,0]
            symdict={}
            for a in atoms:
                n,c=Symbol._symComp(a)
                if len(c)>0:
                    symstr=n+'['+','.join([str(i) for i in c])+']'
                else:
                    symstr=n
                symdict[a.name]=symstr
            s=lambdarepr(el)
            # replace longer names first so that e.g. '[x]_1_1' cannot
            # clobber part of '[x]_1_10'
            for key in sorted(symdict, key=len, reverse=True):
                s=s.replace(key, symdict[key])
            temp_arr[idx]=s
        if self.getRank()==0:
            return temp_arr.item()
        else:
            return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))

    def coeff(self, x, expand=True):
        """
        Returns the coefficient of the term "x" or 0 if there is no "x".

        If "x" is a scalar symbol then "x" is searched in all components of
        this symbol. Otherwise the shapes must match and the coefficients are
        checked component by component.

        Example::

            x=Symbol('x', (2,2))
            y=3*x
            print(y.coeff(x))
            print(y.coeff(x[1,1]))

        will print::

            [[3 3]
             [3 3]]

            [[0 0]
             [0 3]]

        :param x: the term whose coefficients are to be found
        :type x: ``Symbol``, ``numpy.ndarray``, `list`
        :return: the coefficient(s) of the term, carrying over this symbol's
                 dimensionality and ``Data`` substitutions
        :rtype: ``Symbol``
        """
        self._ensureShapeCompatible(x)
        if hasattr(x, '__array__'):
            y=x.__array__()
        else:
            y=numpy.array(x)

        if y.ndim>0:
            result=numpy.zeros(self.getShape(), dtype=object)
            for idx in numpy.ndindex(y.shape):
                if y[idx]!=0:
                    # sympify so that plain number entries work as well
                    res=sympy.sympify(self._arr[idx]).coeff(y[idx], expand)
                    if res is not None:
                        result[idx]=res
        elif y.item()==0:
            result=numpy.zeros(self.getShape(), dtype=object)
        else:
            term=y.item()
            def coeff_item(item):
                # the result may be None if there is no such term
                res=sympy.sympify(item).coeff(term, expand)
                return 0 if res is None else res
            result=self.applyfunc(coeff_item)
        return Symbol(result, dim=self._dim, subs=self._subs)

    def subs(self, old, new):
        """
        Substitutes an expression and returns the result as a new Symbol.

        - If ``new`` is an escript ``Data`` object the elements are left
          unchanged. Instead, the pair is recorded in the substitution
          dictionary. For a non-scalar ``old``, the key is a scalar Symbol
          made from one of the base names returned by ``old.atoms()``.
        - If ``old`` is a non-scalar Symbol, each of its components is
          replaced by the matching component of ``new``, or by ``new`` itself
          if it is a scalar.
        - Otherwise a scalar substitution is applied to all sympy elements.

        Plain numeric elements are left unchanged.

        :param old: the expression to replace
        :param new: the replacement
        :rtype: ``Symbol``
        :raise TypeError: if the shapes of ``old`` and ``new`` are incompatible
        :raise ValueError: if ``new`` is ``Data`` and a non-scalar ``old``
                           contains no symbols
        """
        if isinstance(old, Symbol):
            old._ensureShapeCompatible(new)
        if isinstance(new, Data):
            subs=self._subs.copy()
            if isinstance(old, Symbol) and old.getRank()>0:
                names=list(old.atoms(sympy.Symbol))
                if len(names)==0:
                    raise ValueError("subs: 'old' does not contain any symbols")
                old=Symbol(names[0])
            subs[old]=new
            return Symbol(self._arr, dim=self._dim, subs=subs)

        if isinstance(old, Symbol) and old.getRank()>0:
            if hasattr(new, '__array__'):
                new=new.__array__()
            else:
                new=numpy.array(new)
            if new.ndim>0:
                pairs=[(old._arr[i], new[i]) for i in numpy.ndindex(new.shape)]
            else: # substitute scalar for non-scalar
                value=new.item()
                pairs=[(old._arr[i], value) for i in numpy.ndindex(old.getShape())]
            result=numpy.empty(self.getShape(), dtype=object)
            for idx in numpy.ndindex(self.getShape()):
                expr=self._arr[idx]
                if isinstance(expr, sympy.Basic):
                    # apply all component substitutions, not just the last
                    for o, n in pairs:
                        expr=expr.subs(o, n)
                result[idx]=expr
            return Symbol(result, dim=self._dim, subs=self._subs)

        # scalar
        if isinstance(new, Symbol):
            new=new.item()
        if isinstance(old, Symbol):
            old=old.item()
        return self.applyfunc(lambda item: item.subs(old, new), sympy.Basic)

    def diff(self, *symbols, **assumptions):
        """
        Returns the derivative of this symbol with respect to the given
        symbols, applied in order. An integer following a symbol repeats the
        derivative, e.g. ``u.diff(x, 2)``. Differentiating with respect to a
        rank 1 `Symbol` appends an axis of that symbol's length holding the
        derivatives with respect to each of its components.

        :param symbols: sympy symbols, scalar or rank 1 `Symbol` objects, and
                        optional repetition counts
        :param assumptions: passed on to sympy's ``diff``
        :rtype: ``Symbol``
        :raise ValueError: if a `Symbol` argument has rank > 1
        """
        symbols=Symbol._symbolgen(*symbols)
        result=Symbol(self._arr, dim=self._dim, subs=self._subs)
        for s in symbols:
            if isinstance(s, Symbol):
                if s.getRank()==0:
                    var=s._arr.item()
                    result=result.applyfunc(
                        lambda item, v=var: sympy.sympify(item).diff(v, **assumptions))
                elif s.getRank()==1:
                    dim=s.getShape()[0]
                    # use the shape of the accumulated result, which grows
                    # with every rank 1 derivative
                    shape=result.getShape()
                    out=result._arr.copy().reshape(shape+(1,)).repeat(dim,axis=result.getRank())
                    for d in range(dim):
                        var=s._arr[d]
                        for idx in numpy.ndindex(shape):
                            index=idx+(d,)
                            # sympify so constants differentiate to 0
                            out[index]=sympy.sympify(out[index]).diff(var, **assumptions)
                    result=Symbol(out, dim=self._dim, subs=self._subs)
                else:
                    raise ValueError("diff: argument must have rank 0 or 1")
            else:
                result=result.applyfunc(
                    lambda item, v=s: sympy.sympify(item).diff(v, **assumptions))
        return result

    def grad(self, where=None):
        """
        Returns a symbol which represents the gradient of this symbol.
        The result has one additional trailing axis of length getDim(); the
        element at [..., d] is ``grad_n(element, d)`` (or
        ``grad_n(element, d, where)``) from the sibling ``functions`` module,
        presumably the derivative with respect to the d-th spatial coordinate.

        :param where: optional location for the gradient. A `FunctionSpace`
                      is recorded in the substitutions under a generated
                      name, which is passed on instead.
        :type where: ``Symbol``, ``FunctionSpace``
        :rtype: ``Symbol``
        :raise ValueError: if the dimensionality is undefined or ``where`` is
                           a non-scalar symbol
        """
        if self._dim < 0:
            raise ValueError("grad: cannot compute gradient as symbol has undefined dimensionality")
        subs=self._subs
        if isinstance(where, Symbol):
            if where.getRank()>0:
                raise ValueError("grad: 'where' must be a scalar symbol")
            where=where._arr.item()
        elif isinstance(where, FunctionSpace):
            name='fs'+str(id(where))
            fssym=Symbol(name)
            subs=self._subs.copy()
            subs.update({fssym:where})
            where=name

        # explicit relative import works on Python 2.5+ and Python 3
        from .functions import grad_n
        out=self._arr.copy().reshape(self.getShape()+(1,)).repeat(self._dim,axis=self.getRank())
        for d in range(self._dim):
            for idx in numpy.ndindex(self.getShape()):
                index=idx+(d,)
                if where is None:
                    out[index]=grad_n(out[index],d)
                else:
                    out[index]=grad_n(out[index],d,where)
        return Symbol(out, dim=self._dim, subs=subs)

    def inverse(self):
        """
        Returns the inverse of this symbol, which must be a square matrix of
        size 1x1, 2x2 or 3x3. The inverse is computed explicitly from the
        cofactors and the determinant.

        :rtype: ``Symbol``
        :raise TypeError: if the rank is not 2 or the size exceeds 3
        :raise ValueError: if the shape is not square
        :raise ZeroDivisionError: if the determinant is known to be zero
        """
        if not self.getRank()==2:
            raise TypeError("inverse: Only rank 2 supported")
        s=self.getShape()
        if not s[0] == s[1]:
            raise ValueError("inverse: Only square shapes supported")
        # entries may be plain Python numbers without sympy's is_zero
        is_zero=lambda v: sympy.sympify(v).is_zero
        out=numpy.zeros(s, object)
        arr=self._arr
        if s[0]==1:
            if is_zero(arr[0,0]):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            out[0,0]=1./arr[0,0]
        elif s[0]==2:
            A11=arr[0,0]
            A12=arr[0,1]
            A21=arr[1,0]
            A22=arr[1,1]
            D = A11*A22-A12*A21
            if is_zero(D):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]= A22*D
            out[1,0]=-A21*D
            out[0,1]=-A12*D
            out[1,1]= A11*D
        elif s[0]==3:
            A11=arr[0,0]
            A21=arr[1,0]
            A31=arr[2,0]
            A12=arr[0,1]
            A22=arr[1,1]
            A32=arr[2,1]
            A13=arr[0,2]
            A23=arr[1,2]
            A33=arr[2,2]
            D = A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
            if is_zero(D):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]=(A22*A33-A23*A32)*D
            out[1,0]=(A31*A23-A21*A33)*D
            out[2,0]=(A21*A32-A31*A22)*D
            out[0,1]=(A13*A32-A12*A33)*D
            out[1,1]=(A11*A33-A31*A13)*D
            out[2,1]=(A12*A31-A11*A32)*D
            out[0,2]=(A12*A23-A13*A22)*D
            out[1,2]=(A13*A21-A11*A23)*D
            out[2,2]=(A11*A22-A12*A21)*D
        else:
            raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
        return Symbol(out, dim=self._dim, subs=self._subs)

    def swap_axes(self, axis0, axis1):
        """
        Returns a copy of this symbol with axes ``axis0`` and ``axis1``
        interchanged (see ``numpy.swapaxes``).

        :rtype: ``Symbol``
        """
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self._dim, subs=self._subs)

    def _tensorProductOperand(self, other):
        """
        Helper for the tensor product methods. Returns a private copy of the
        other operand's array, the dimensionality of the result and the
        substitutions for the result. Only `Symbol` operands contribute
        substitutions; other operands (e.g. numpy arrays) have none.
        """
        if isinstance(other, Symbol):
            subs=self._subs.copy()
            subs.update(other._subs)
            dim=other._dim if self._dim < 0 else self._dim
            return other._arr.copy(), dim, subs
        return numpy.array(other), self._dim, self._subs

    def tensorProduct(self, other, axis_offset):
        """
        Returns the tensor product of this symbol and ``other`` where the last
        ``axis_offset`` axes of this symbol are contracted with the first
        ``axis_offset`` axes of ``other``.

        :param other: second operand
        :type other: ``Symbol`` or ``numpy.ndarray``
        :param axis_offset: number of axes to contract
        :rtype: ``Symbol``
        :raise ValueError: if the contracted axes do not match in size
        """
        arg1, dim, subs=self._tensorProductOperand(other)
        sh0=self.getShape()
        sh1=arg1.shape
        r0=self.getRank()
        d0=int(numpy.prod(sh0[:r0-axis_offset]))
        d1=int(numpy.prod(sh1[axis_offset:]))
        d01=int(numpy.prod(sh1[:axis_offset]))
        arg0=self._arr.reshape((d0,d01))
        arg1=arg1.reshape((d01,d1))
        out=numpy.zeros((d0,d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(arg0[i0,:]*arg1[:,i1])
        out=out.reshape(sh0[:r0-axis_offset]+sh1[axis_offset:])
        return Symbol(out, dim=dim, subs=subs)

    def transposedTensorProduct(self, other, axis_offset):
        """
        Returns the tensor product of the transpose of this symbol and
        ``other``, i.e. the first ``axis_offset`` axes of this symbol are
        contracted with the first ``axis_offset`` axes of ``other``.

        :param other: second operand
        :type other: ``Symbol`` or ``numpy.ndarray``
        :param axis_offset: number of axes to contract
        :rtype: ``Symbol``
        :raise ValueError: if the contracted axes do not match in size
        """
        arg1, dim, subs=self._tensorProductOperand(other)
        sh0=self.getShape()
        sh1=arg1.shape
        d0=int(numpy.prod(sh0[axis_offset:]))
        d1=int(numpy.prod(sh1[axis_offset:]))
        d01=int(numpy.prod(sh1[:axis_offset]))
        arg0=self._arr.reshape((d01,d0))
        arg1=arg1.reshape((d01,d1))
        out=numpy.zeros((d0,d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(arg0[:,i0]*arg1[:,i1])
        out=out.reshape(sh0[axis_offset:]+sh1[axis_offset:])
        return Symbol(out, dim=dim, subs=subs)

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Returns the tensor product of this symbol and the transpose of
        ``other``, i.e. the last ``axis_offset`` axes of this symbol are
        contracted with the last ``axis_offset`` axes of ``other``.

        :param other: second operand
        :type other: ``Symbol`` or ``numpy.ndarray``
        :param axis_offset: number of axes to contract
        :rtype: ``Symbol``
        :raise ValueError: if the contracted axes do not match in size
        """
        arg1, dim, subs=self._tensorProductOperand(other)
        sh0=self.getShape()
        sh1=arg1.shape
        r0=self.getRank()
        r1=arg1.ndim
        d0=int(numpy.prod(sh0[:r0-axis_offset]))
        d1=int(numpy.prod(sh1[:r1-axis_offset]))
        d01=int(numpy.prod(sh1[r1-axis_offset:]))
        arg0=self._arr.reshape((d0,d01))
        arg1=arg1.reshape((d1,d01))
        out=numpy.zeros((d0,d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(arg0[i0,:]*arg1[i1,:])
        out=out.reshape(sh0[:r0-axis_offset]+sh1[:r1-axis_offset])
        return Symbol(out, dim=dim, subs=subs)

    def trace(self, axis_offset):
        """
        Returns the trace of this Symbol, i.e. the sum over the diagonal of
        axes ``axis_offset`` and ``axis_offset+1``, which are removed from
        the result.

        :rtype: ``Symbol``
        """
        sh=self.getShape()
        n=sh[axis_offset]
        s1=int(numpy.prod(sh[:axis_offset]))
        s2=int(numpy.prod(sh[axis_offset+2:]))
        arr_r=numpy.reshape(self._arr,(s1,n,n,s2))
        out=numpy.zeros([s1,s2],object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(n):
                    out[i1,i2]+=arr_r[i1,j,j,i2]
        out=out.reshape(sh[:axis_offset]+sh[axis_offset+2:])
        return Symbol(out, dim=self._dim, subs=self._subs)

    def transpose(self, axis_offset=None):
        """
        Returns the transpose of this Symbol: the first ``axis_offset`` axes
        are moved behind the remaining ones.

        :param axis_offset: number of leading axes to move (default: half the
                            rank, rounded down)
        :rtype: ``Symbol``
        """
        rank=self._arr.ndim
        if axis_offset is None:
            axis_offset=rank//2
        axes=list(range(axis_offset, rank))+list(range(0, axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes), dim=self._dim, subs=self._subs)

    def applyfunc(self, f, on_type=None):
        """
        Applies the function `f` to all elements (if on_type is None) or to
        all elements of type `on_type`. Other elements are copied unchanged.

        :param f: callable taking and returning a single element
        :param on_type: type (or tuple of types) selecting the elements
        :return: a new Symbol with the same dimensionality and substitutions.
                 For a scalar symbol ``None`` is returned if ``f`` returns
                 ``None``.
        :rtype: ``Symbol`` or ``None``
        """
        assert callable(f)
        if self._arr.ndim==0:
            el=self._arr.item()
            if on_type is None or isinstance(el, on_type):
                el=f(el)
            if el is None:
                return el
            return Symbol(el, dim=self._dim, subs=self._subs)
        out=numpy.empty(self.getShape(), dtype=object)
        for idx in numpy.ndindex(self.getShape()):
            el=self._arr[idx]
            if on_type is None or isinstance(el, on_type):
                out[idx]=f(el)
            else:
                out[idx]=el
        return Symbol(out, dim=self._dim, subs=self._subs)

    def expand(self):
        """
        Applies the sympy.expand operation on all elements in this symbol
        """
        return self.applyfunc(sympy.expand, sympy.Basic)

    def simplify(self):
        """
        Applies the sympy.simplify operation on all elements in this symbol
        """
        return self.applyfunc(sympy.simplify, sympy.Basic)

    # unary/binary operators follow

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr, dim=self._dim, subs=self._subs)

    def __abs__(self):
        return Symbol(abs(self._arr), dim=self._dim, subs=self._subs)

    def _ensureShapeCompatible(self, other):
        """
        Checks for compatible shapes for binary operations.
        Shapes are compatible if they are equal or either one is scalar.
        Raises TypeError if not compatible or if ``other`` has an
        unsupported type.
        """
        sh0=self.getShape()
        if isinstance(other, Symbol) or isinstance(other, Data):
            sh1=other.getShape()
        elif isinstance(other, numpy.ndarray):
            sh1=other.shape
        elif isinstance(other, list):
            sh1=numpy.array(other).shape
        elif isinstance(other, numbers.Number) or isinstance(other, sympy.Basic):
            # numbers.Number also covers long, complex and numpy scalars
            sh1=()
        else:
            raise TypeError("Unsupported argument type '%s' for operation"%other.__class__.__name__)
        if not sh0==sh1 and not sh0==() and not sh1==():
            raise TypeError("Incompatible shapes for operation")

    def __binaryop(self, op, other):
        """
        Helper for binary operations that checks types, shapes etc.
        ``op`` is the name of the numpy.ndarray method to apply. A ``Data``
        operand is replaced by a placeholder symbol and recorded in the
        substitutions.
        """
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            subs=self._subs.copy()
            subs.update(other._subs)
            dim=other._dim if self._dim < 0 else self._dim
            return Symbol(getattr(self._arr, op)(other._arr), dim=dim, subs=subs)
        if isinstance(other, Data):
            name='data'+str(id(other))
            othersym=Symbol(name, other.getShape(), dim=self._dim)
            subs=self._subs.copy()
            subs.update({Symbol(name):other})
            return Symbol(getattr(self._arr, op)(othersym._arr), dim=self._dim, subs=subs)
        return Symbol(getattr(self._arr, op)(other), dim=self._dim, subs=self._subs)

    def __add__(self, other):
        return self.__binaryop('__add__', other)

    def __radd__(self, other):
        return self.__binaryop('__radd__', other)

    def __sub__(self, other):
        return self.__binaryop('__sub__', other)

    def __rsub__(self, other):
        return self.__binaryop('__rsub__', other)

    def __mul__(self, other):
        return self.__binaryop('__mul__', other)

    def __rmul__(self, other):
        return self.__binaryop('__rmul__', other)

    def __div__(self, other):
        return self.__binaryop('__div__', other)

    def __rdiv__(self, other):
        return self.__binaryop('__rdiv__', other)

    # Python 3 (and Python 2 with true division) uses these for '/'
    def __truediv__(self, other):
        return self.__binaryop('__truediv__', other)

    def __rtruediv__(self, other):
        return self.__binaryop('__rtruediv__', other)

    def __pow__(self, other):
        return self.__binaryop('__pow__', other)

    def __rpow__(self, other):
        return self.__binaryop('__rpow__', other)

    # static methods

    @staticmethod
    def _symComp(sym):
        """
        Splits the name of a sympy symbol of the form '[name]_i_j...' into
        the base name and a tuple of the integer component indices.
        Names not of this form are returned unchanged with an empty tuple.
        """
        n=sym.name
        a=n.split('[')
        if len(a)!=2:
            return n,()
        a=a[1].split(']')
        if len(a)!=2:
            return n,()
        name=a[0]
        comps=[int(i) for i in a[1].split('_')[1:]]
        return name,tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)
        Arguments that are not escript Symbols are sympified. An integer
        following a symbol repeats that symbol; other integers are skipped.

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat
        def normalize(s):
            return s if isinstance(s, Symbol) else sympy.sympify(s)
        count=len(symbols)
        for i in range(count):
            s=normalize(symbols[i])
            # look ahead by position; comparing with the last argument fails
            # when the same symbol appears more than once
            next_s=normalize(symbols[i+1]) if i+1<count else None
            if isinstance(s, sympy.Integer):
                continue
            elif isinstance(s, Symbol) or isinstance(s, sympy.Symbol):
                # handle cases like (x, 3)
                if isinstance(next_s, sympy.Integer):
                    # yield (x, x, x)
                    for copy_s in repeat(s, int(next_s)):
                        yield copy_s
                else:
                    yield s
            else:
                yield s
816    

  ViewVC Help
Powered by ViewVC 1.1.26