/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Annotation of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory | Revision Log


Revision 4019 - (hide annotations)
Thu Oct 11 08:12:55 2012 UTC (7 years, 1 month ago) by jfenwick
File MIME type: text/x-python
File size: 29146 byte(s)
More tabbing errors,
range/xrange
...
1 caltinay 3507
2 jfenwick 3981 ##############################################################################
3 caltinay 3507 #
4 caltinay 3815 # Copyright (c) 2003-2012 by University of Queensland
5 jfenwick 3981 # http://www.uq.edu.au
6 caltinay 3507 #
7     # Primary Business: Queensland, Australia
8     # Licensed under the Open Software License version 3.0
9     # http://www.opensource.org/licenses/osl-3.0.php
10     #
11 jfenwick 3981 # Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12     # Development since 2012 by School of Earth Sciences
13     #
14     ##############################################################################
15 caltinay 3507
16 caltinay 3815 __copyright__="""Copyright (c) 2003-2012 by University of Queensland
17 jfenwick 3981 http://www.uq.edu.au
18 caltinay 3507 Primary Business: Queensland, Australia"""
19     __license__="""Licensed under the Open Software License version 3.0
20     http://www.opensource.org/licenses/osl-3.0.php"""
21     __url__="https://launchpad.net/escript-finley"
22 caltinay 3864 __author__="Cihan Altinay"
23 caltinay 3507
24     """
25     :var __author__: name of author
26     :var __copyright__: copyrights
27     :var __license__: licence agreement
28     :var __url__: url entry point on documentation
29     :var __version__: version
30     :var __date__: date of the version
31     """
32    
33     import numpy
34 caltinay 3978 from esys.escript import Data, FunctionSpace, HAVE_SYMBOLS
35     if HAVE_SYMBOLS:
36     import sympy
37 caltinay 3507
38 gross 3856
class Symbol(object):
    """
    `Symbol` objects are placeholders for a single mathematical symbol, such as
    'x', or for arbitrarily complex mathematical expressions such as
    'c*x**4+alpha*exp(x)-2*sin(beta*x)', where 'alpha', 'beta', 'c', and 'x'
    are also Symbols (the symbolic 'atoms' of the expression).

    With the help of the 'Evaluator' class these symbols and expressions can
    be resolved by substituting numeric values and/or escript `Data` objects
    for the atoms. To facilitate the use of `Data` objects a `Symbol` has a
    shape (and thus a rank) as well as a dimension (see constructor).
    Symbols are useful to perform mathematical simplifications, compute
    derivatives and as coefficients for nonlinear PDEs which can be solved by
    the `NonlinearPDE` class.
    """

    # these are for compatibility with sympy.Symbol. lambdify checks these.
    is_Add=False
    is_Float=False

    def __init__(self, *args, **kwargs):
        """
        Initialises a new `Symbol` object in one of three ways::

            u=Symbol('u')

        returns a scalar symbol by the name 'u'.

            a=Symbol('alpha', (4,3))

        returns a rank 2 symbol with the shape (4,3), whose elements are
        named '[alpha]_i_j' (with i=0..3, j=0..2).

            a,b,c=symbols('a,b,c')
            x=Symbol([[a+b,0,0],[0,b-c,0],[0,0,c-a]])

        returns a rank 2 symbol with the shape (3,3) whose elements are
        explicitly specified by numeric values and other symbols/expressions
        within a list or numpy array.

        The dimensionality of the symbol can be specified through the ``dim``
        keyword. Apart from ``dim`` and ``subs`` all keywords are passed to
        the underlying symbolic library (currently sympy).

        :param args: initialisation arguments as described above
        :keyword dim: dimensionality of the new Symbol (default: -1, meaning
                      undefined)
        :type dim: ``int``
        :keyword subs: substitution dictionary (see `getDataSubstitutions`).
                       It is copied, so the new Symbol never shares it with
                       the caller.
        :type subs: ``dict``
        :raises RuntimeError: if sympy is not available
        :raises ValueError: if a name contains '[' or ']', or the requested
                            rank exceeds 4
        :raises TypeError: for unsupported argument types or counts
        """
        if not HAVE_SYMBOLS:
            raise RuntimeError("Trying to instantiate a Symbol but sympy not available")

        self._dim=kwargs.pop('dim', -1) # -1 means undefined

        # Copy the dictionary. Many methods pass ``subs=self._subs`` and the
        # constructor may update() it below, so sharing would let one Symbol
        # silently modify another Symbol's substitutions.
        self._subs=dict(kwargs.pop('subs', {}))

        if len(args)==1:
            arg=args[0]
            if isinstance(arg, str):
                if arg.find('[')>=0 or arg.find(']')>=0:
                    raise ValueError("Name must not contain '[' or ']'")
                self._arr=numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__") or isinstance(arg, list):
                if isinstance(arg, list): arg=numpy.array(arg)
                arr=arg.__array__()
                if len(arr.shape)>4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                res=numpy.empty(arr.shape, dtype=object)
                for idx in numpy.ndindex(arr.shape):
                    # convert numpy scalars to native Python objects
                    if hasattr(arr[idx], "item"):
                        res[idx]=arr[idx].item()
                    else:
                        res[idx]=arr[idx]
                self._arr=res
                if isinstance(arg, Symbol):
                    self._subs.update(arg._subs)
                    if self._dim==-1:
                        self._dim=arg._dim
            elif isinstance(arg, sympy.Basic):
                self._arr=numpy.array(arg)
            else:
                raise TypeError("Unsupported argument type %s"%str(type(arg)))
        elif len(args)==2:
            if not isinstance(args[0], str):
                raise TypeError("First argument must be a string")
            if not isinstance(args[1], tuple):
                raise TypeError("Second argument must be a tuple")
            name=args[0]
            shape=args[1]
            if name.find('[')>=0 or name.find(']')>=0:
                raise ValueError("Name must not contain '[' or ']'")
            if len(shape)>4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape)==0:
                self._arr=numpy.array(sympy.Symbol(name, **kwargs))
            else:
                # the argument order of symarray differs between sympy versions
                try:
                    self._arr=sympy.symarray(shape, '['+name+']')
                except TypeError:
                    self._arr=sympy.symarray('['+name+']', shape)
        else:
            raise TypeError("Unsupported number of arguments")
        if self._arr.ndim==0:
            self.name=str(self._arr.item())
        else:
            self.name=str(self._arr.tolist())

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        """
        Two Symbols are equal if they have the same type, rank, shape and
        element-wise equal entries.
        """
        if type(self) is not type(other):
            return False
        if self.getRank()!=other.getRank():
            return False
        if self.getShape()!=other.getShape():
            return False
        return (self._arr==other._arr).all()

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so define it explicitly.
        return not self.__eq__(other)

    def __hash__(self):
        # Identity-based. Symbols are mutable (see __setitem__) and are used
        # as keys of the substitution dictionaries.
        # NOTE: consequently two Symbols comparing equal via __eq__ may still
        # hash differently.
        return id(self)

    def __getitem__(self, key):
        """
        Returns an element of this symbol which must have rank >0.
        Unlike item() this method converts sympy objects and numpy arrays into
        escript Symbols in order to facilitate expressions that require
        element access, such as: grad(u)[1]+x

        :param key: (nd-)index of item to be returned
        :return: the requested element
        :rtype: ``Symbol``, ``int``, or ``float``
        """
        res=self._arr[key]
        # replace sympy Symbols/expressions by escript Symbols
        if isinstance(res, sympy.Basic) or isinstance(res, numpy.ndarray):
            res=Symbol(res)
        return res

    def __setitem__(self, key, value):
        """
        Sets the element(s) addressed by ``key`` to ``value``.

        ``value`` may be a `Symbol`, a sympy expression, an array-like object
        (each entry is sympified) or anything sympy can sympify. When a
        `Symbol` is assigned, its data substitutions are merged into this
        Symbol.

        :raises ValueError: if a non-scalar `Symbol` value does not match the
                            shape of the addressed sub-array
        """
        if isinstance(value, Symbol):
            if value.getRank()==0:
                self._arr[key]=value.item()
            elif hasattr(self._arr[key], "shape"):
                target=self._arr[key]
                if target.shape==value.getShape():
                    for idx in numpy.ndindex(target.shape):
                        # Use the raw element. value[idx] would return a
                        # plain int/float for numeric entries, which has no
                        # item() method.
                        target[idx]=value._arr[idx]
                else:
                    raise ValueError("Wrong shape of value")
            else:
                raise ValueError("Wrong shape of value")
            # keep the Data substitutions the assigned expression relies on
            self._subs.update(value._subs)
        elif isinstance(value, sympy.Basic):
            self._arr[key]=value
        elif hasattr(value, "__array__"):
            arr=value.__array__()
            if arr.ndim==0:
                self._arr[key]=sympy.sympify(arr.item())
            else:
                # Build a shaped object array. A flat list (or a lazy map
                # object on Python 3) cannot be assigned to a shaped slice.
                conv=numpy.empty(arr.shape, dtype=object)
                for idx in numpy.ndindex(arr.shape):
                    conv[idx]=sympy.sympify(arr[idx])
                self._arr[key]=conv
        else:
            self._arr[key]=sympy.sympify(value)

    def __iter__(self):
        # __iter__ must return an iterator, not the bound __iter__ method
        return iter(self._arr)

    def __array__(self, t=None, copy=None):
        """
        Returns the underlying numpy object array.

        :param t: optional dtype to convert to (always returns a new array)
        :param copy: if true, a copy is returned; accepted for compatibility
                     with newer numpy versions which pass this keyword
        """
        if t:
            return self._arr.astype(t)
        if copy:
            return self._arr.copy()
        return self._arr

    def _sympy_(self):
        """
        Hook called by ``sympy.sympify``. Returns this Symbol with
        ``sympy.sympify`` applied to each element. Note that the result is a
        `Symbol`, not a sympy object.
        """
        return self.applyfunc(sympy.sympify)

    def getDim(self):
        """
        Returns the spatial dimensionality of this symbol.

        :return: the symbol's spatial dimensionality, or -1 if undefined
        :rtype: ``int``
        """
        return self._dim

    def getRank(self):
        """
        Returns the rank of this symbol.

        :return: the symbol's rank which is equal to the length of the shape.
        :rtype: ``int``
        """
        return self._arr.ndim

    def getShape(self):
        """
        Returns the shape of this symbol.

        :return: the symbol's shape
        :rtype: ``tuple`` of ``int``
        """
        return self._arr.shape

    def getDataSubstitutions(self):
        """
        Returns the dictionary which maps placeholder `Symbol` objects to the
        escript objects they represent within this Symbol.

        The values are escript ``Data`` objects, or a ``FunctionSpace`` for
        symbols produced by `grad`.

        :return: the substitution dictionary (not a copy)
        :rtype: ``dict``
        """
        return self._subs

    def item(self, *args):
        """
        Returns an element of this symbol.
        This method behaves like the item() method of numpy.ndarray.
        If this is a scalar Symbol, no arguments are allowed and the only
        element in this Symbol is returned.
        Otherwise, 'args' specifies a flat or nd-index and the element at
        that index is returned.

        :param args: index of item to be returned
        :return: the requested element
        :rtype: ``sympy.Symbol``, ``int``, or ``float``
        """
        # Unpack so that flat indices (item(3)) and nd-indices (item(1,0))
        # are both forwarded to numpy exactly as given.
        return self._arr.item(*args)

    def atoms(self, *types):
        """
        Returns the atoms that form the current Symbol.

        By default, only objects that are truly atomic and cannot be divided
        into smaller pieces are returned: symbols, numbers, and number
        symbols like I and pi. It is possible to request atoms of any type,
        however.

        Note that if this symbol contains components such as [x]_i_j then
        only their main symbol 'x' is returned.

        :param types: types to restrict result to
        :return: set of atoms of specified type
        :rtype: ``set``
        """
        s=set()
        for el in self._arr.flat:
            if isinstance(el,sympy.Basic):
                atoms=el.atoms(*types)
                for a in atoms:
                    if a.is_Symbol:
                        n,c=Symbol._symComp(a)
                        s.add(sympy.Symbol(n))
                    else:
                        s.add(a)
            elif len(types)==0 or type(el) in types:
                s.add(el)
        return s

    def _sympystr_(self, printer):
        # compatibility with sympy 1.6
        return self._sympystr(printer)

    def _sympystr(self, printer):
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string representation of this Symbol for use with lambdify.

        Component names such as '[x]_0_1' are rewritten as 'x[0,1]'. A scalar
        Symbol yields the representation of its only element. Otherwise the
        result has the form 'combineData(<nested list>,<shape>)'.
        """
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr=numpy.empty(self.getShape(), dtype=object)
        for idx,el in numpy.ndenumerate(self._arr):
            atoms=el.atoms(sympy.Symbol) if isinstance(el,sympy.Basic) else []
            # create a dictionary to convert names like [x]_0_0 to x[0,0]
            symdict={}
            for a in atoms:
                n,c=Symbol._symComp(a)
                if len(c)>0:
                    c=[str(i) for i in c]
                    symstr=n+'['+','.join(c)+']'
                else:
                    symstr=n
                symdict[a.name]=symstr
            s=lambdarepr(el)
            # Replace longer names first so that e.g. '[x]_1' does not
            # corrupt '[x]_10'.
            for key in sorted(symdict, key=len, reverse=True):
                s=s.replace(key, symdict[key])
            temp_arr[idx]=s
        if self.getRank()==0:
            return temp_arr.item()
        else:
            return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))

    def coeff(self, x, expand=True):
        """
        Returns the coefficient of the term "x" or 0 if there is no "x".

        If "x" is a scalar symbol then "x" is searched in all components of
        this symbol. Otherwise the shapes must match and the coefficients are
        checked component by component. Numeric (non-sympy) components have
        coefficient 0.

        Example::

            x=Symbol('x', (2,2))
            y=3*x
            print(y.coeff(x))
            print(y.coeff(x[1,1]))

        will print::

            [[3 3]
             [3 3]]

            [[0 0]
             [0 3]]

        :param x: the term whose coefficients are to be found
        :type x: ``Symbol``, ``numpy.ndarray``, `list`
        :param expand: forwarded positionally to sympy's ``coeff``.
                       NOTE(review): in newer sympy releases the second
                       positional parameter of ``Expr.coeff`` is the exponent
                       ``n``; verify against the installed sympy version.
        :return: the coefficient(s) of the term
        :rtype: ``Symbol``
        """
        self._ensureShapeCompatible(x)
        if hasattr(x, '__array__'):
            y=x.__array__()
        else:
            y=numpy.array(x)

        result=numpy.zeros(self.getShape(), dtype=object)
        if y.ndim>0:
            # component-wise
            for idx in numpy.ndindex(y.shape):
                el=self._arr[idx]
                if y[idx]!=0 and isinstance(el, sympy.Basic):
                    res=el.coeff(y[idx], expand)
                    if res is not None:
                        result[idx]=res
        elif y.item()!=0:
            # scalar term: search all components
            term=y.item()
            for idx in numpy.ndindex(self.getShape()):
                el=self._arr[idx]
                if isinstance(el, sympy.Basic):
                    res=el.coeff(term, expand)
                    if res is not None:
                        result[idx]=res
        # carry the substitutions over unchanged
        return Symbol(result, dim=self._dim, subs=self._subs)

    def subs(self, old, new):
        """
        Substitutes an expression.

        - If ``new`` is an escript ``Data`` object, no symbolic substitution
          takes place. Instead the pair is recorded in the returned Symbol's
          substitution dictionary. For a non-scalar ``old``, the key is a
          scalar Symbol built from an atom of ``old``, presumably its single
          base name.
        - If ``old`` is a non-scalar `Symbol`, each of its components is
          replaced by the matching component of ``new``, or by ``new`` itself
          if that is scalar.
        - Otherwise a scalar substitution is applied to every element.

        :return: a new Symbol; this Symbol is not modified
        :rtype: ``Symbol``
        """
        if isinstance(old, Symbol):
            old._ensureShapeCompatible(new)
        if isinstance(new, Data):
            subs=self._subs.copy()
            if isinstance(old, Symbol) and old.getRank()>0:
                # atoms() returns a set, which cannot be indexed
                old=Symbol(next(iter(old.atoms(sympy.Symbol))))
            subs[old]=new
            result=Symbol(self._arr, dim=self._dim, subs=subs)
        elif isinstance(old, Symbol) and old.getRank()>0:
            if hasattr(new, '__array__'):
                new=new.__array__()
            else:
                new=numpy.array(new)

            # Collect all (component, replacement) pairs first, then apply
            # all of them to each element so no substitution is lost.
            pairs=[]
            for symidx in numpy.ndindex(old.getShape()):
                value=new[symidx] if new.ndim>0 else new.item()
                pairs.append((old._arr[symidx], value))

            result=numpy.empty(self.getShape(), dtype=object)
            for idx in numpy.ndindex(self.getShape()):
                el=self._arr[idx]
                if isinstance(el, sympy.Basic):
                    for old_comp, new_val in pairs:
                        el=el.subs(old_comp, new_val)
                result[idx]=el
            result=Symbol(result, dim=self._dim, subs=self._subs)
        else: # scalar
            if isinstance(new, Symbol):
                new=new.item()
            if isinstance(old, Symbol):
                old=old.item()
            subs_item=lambda item: getattr(item, 'subs')(old, new)
            result=self.applyfunc(subs_item)
        return result

    def diff(self, *symbols, **assumptions):
        """
        Returns the derivative of this Symbol with respect to ``symbols``.

        As in sympy, a symbol may be followed by an integer to differentiate
        repeatedly, e.g. ``diff(x, 2)``. A scalar argument differentiates
        every element. A rank-1 `Symbol` argument of length n appends a new
        trailing axis of length n holding the derivatives with respect to
        each of its components. Numeric elements are sympified first, so
        their derivative is 0.

        :raises ValueError: if a `Symbol` argument has rank greater than 1
        """
        symbols=Symbol._symbolgen(*symbols)
        result=Symbol(self._arr, dim=self._dim, subs=self._subs)
        for s in symbols:
            if isinstance(s, Symbol):
                if s.getRank()==0:
                    diff_item=lambda item: sympy.sympify(item).diff(s._arr.item(), **assumptions)
                    result=result.applyfunc(diff_item)
                elif s.getRank()==1:
                    dim=s.getShape()[0]
                    out=result._arr.copy().reshape(self.getShape()+(1,)).repeat(dim,axis=self.getRank())
                    for d in range(dim):
                        for idx in numpy.ndindex(self.getShape()):
                            index=idx+(d,)
                            out[index]=sympy.sympify(out[index]).diff(s[d].item(), **assumptions)
                    result=Symbol(out, dim=self._dim, subs=self._subs)
                else:
                    raise ValueError("diff: argument must have rank 0 or 1")
            else:
                diff_item=lambda item: sympy.sympify(item).diff(s, **assumptions)
                result=result.applyfunc(diff_item)
        return result

    def grad(self, where=None):
        """
        Returns a symbol which represents the gradient of this symbol.

        A new trailing axis of length `getDim` is appended. Entry d holds
        ``grad_n(element, d)``, or ``grad_n(element, d, where)`` if ``where``
        is given.

        :param where: scalar Symbol, or FunctionSpace. A FunctionSpace is
                      recorded in the substitutions under a placeholder
                      Symbol named 'fs<id>', and that name is passed on to
                      ``grad_n``.
        :type where: ``Symbol``, ``FunctionSpace``
        :raises ValueError: if the dimensionality is undefined or ``where`` is
                            not a scalar Symbol
        """
        if self._dim < 0:
            raise ValueError("grad: cannot compute gradient as symbol has undefined dimensionality")
        subs=self._subs
        if isinstance(where, Symbol):
            if where.getRank()>0:
                raise ValueError("grad: 'where' must be a scalar symbol")
            where=where._arr.item()
        elif isinstance(where, FunctionSpace):
            name='fs'+str(id(where))
            fssym=Symbol(name)
            subs=self._subs.copy()
            subs.update({fssym:where})
            where=name

        from .functions import grad_n
        out=self._arr.copy().reshape(self.getShape()+(1,)).repeat(self._dim,axis=self.getRank())
        for d in range(self._dim):
            for idx in numpy.ndindex(self.getShape()):
                index=idx+(d,)
                if where is None:
                    out[index]=grad_n(out[index],d)
                else:
                    out[index]=grad_n(out[index],d,where)
        return Symbol(out, dim=self._dim, subs=subs)

    def inverse(self):
        """
        Returns the inverse of this square rank 2 Symbol of size 1, 2 or 3.
        The inverse is computed by the explicit adjugate/determinant formula.

        :raises TypeError: if the rank is not 2 or the size exceeds 3
        :raises ValueError: if the shape is not square
        :raises ZeroDivisionError: if the determinant is known to be zero
        """
        if not self.getRank()==2:
            raise TypeError("inverse: Only rank 2 supported")
        s=self.getShape()
        if not s[0] == s[1]:
            raise ValueError("inverse: Only square shapes supported")
        # builtin object: the numpy.object alias was removed from numpy
        out=numpy.zeros(s, object)
        arr=self._arr
        if s[0]==1:
            # sympify so that plain numeric entries also provide is_zero
            D=sympy.sympify(arr[0,0])
            if D.is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            out[0,0]=1./D
        elif s[0]==2:
            A11=arr[0,0]
            A12=arr[0,1]
            A21=arr[1,0]
            A22=arr[1,1]
            D = sympy.sympify(A11*A22-A12*A21)
            if D.is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]= A22*D
            out[1,0]=-A21*D
            out[0,1]=-A12*D
            out[1,1]= A11*D
        elif s[0]==3:
            A11=arr[0,0]
            A21=arr[1,0]
            A31=arr[2,0]
            A12=arr[0,1]
            A22=arr[1,1]
            A32=arr[2,1]
            A13=arr[0,2]
            A23=arr[1,2]
            A33=arr[2,2]
            D = sympy.sympify(A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22))
            if D.is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]=(A22*A33-A23*A32)*D
            out[1,0]=(A31*A23-A21*A33)*D
            out[2,0]=(A21*A32-A31*A22)*D
            out[0,1]=(A13*A32-A12*A33)*D
            out[1,1]=(A11*A33-A31*A13)*D
            out[2,1]=(A12*A31-A11*A32)*D
            out[0,2]=(A12*A23-A13*A22)*D
            out[1,2]=(A13*A21-A11*A23)*D
            out[2,2]=(A11*A22-A12*A21)*D
        else:
            raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
        return Symbol(out, dim=self._dim, subs=self._subs)

    def swap_axes(self, axis0, axis1):
        """
        Returns a new Symbol with the axes ``axis0`` and ``axis1`` swapped.
        """
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self._dim, subs=self._subs)

    @staticmethod
    def _operandInfo(self_sym, other):
        """
        Returns (array, shape, dim, subs) for the second operand of a tensor
        product. ``other`` may be a Symbol or anything numpy.asarray accepts.
        """
        subs=self_sym._subs.copy()
        if isinstance(other, Symbol):
            subs.update(other._subs)
            dim=other._dim if self_sym._dim < 0 else self_sym._dim
            return other._arr, other.getShape(), dim, subs
        # non-Symbol operands carry no substitutions
        arr=numpy.asarray(other)
        return arr, arr.shape, self_sym._dim, subs

    def tensorProduct(self, other, axis_offset):
        """
        Returns the tensor product of this Symbol and ``other``. The last
        ``axis_offset`` axes of this Symbol are contracted with the first
        ``axis_offset`` axes of ``other``.

        :raises ValueError: if the contracted axes do not have the same total
                            size (numpy.reshape fails)
        """
        arr1, sh1, dim, subs=Symbol._operandInfo(self, other)
        sh0=self.getShape()
        d0,d1,d01=1,1,1
        for i in sh0[:self._arr.ndim-axis_offset]: d0*=i
        for i in sh1[axis_offset:]: d1*=i
        for i in sh1[:axis_offset]: d01*=i
        # reshape (unlike resize) refuses to silently pad/truncate
        arg0_c=self._arr.reshape((d0,d01))
        arg1_c=arr1.reshape((d01,d1))
        out=numpy.zeros((d0,d1),object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[:,i1])
        out=out.reshape(sh0[:self._arr.ndim-axis_offset]+sh1[axis_offset:])
        return Symbol(out, dim=dim, subs=subs)

    def transposedTensorProduct(self, other, axis_offset):
        """
        Returns the product of the transpose of this Symbol and ``other``.
        The first ``axis_offset`` axes of this Symbol are contracted with the
        first ``axis_offset`` axes of ``other``.

        :raises ValueError: if the contracted axes do not have the same total
                            size (numpy.reshape fails)
        """
        arr1, sh1, dim, subs=Symbol._operandInfo(self, other)
        sh0=self.getShape()
        d0,d1,d01=1,1,1
        for i in sh0[axis_offset:]: d0*=i
        for i in sh1[axis_offset:]: d1*=i
        for i in sh1[:axis_offset]: d01*=i
        arg0_c=self._arr.reshape((d01,d0))
        arg1_c=arr1.reshape((d01,d1))
        out=numpy.zeros((d0,d1),object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(arg0_c[:,i0]*arg1_c[:,i1])
        out=out.reshape(sh0[axis_offset:]+sh1[axis_offset:])
        return Symbol(out, dim=dim, subs=subs)

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Returns the product of this Symbol and the transpose of ``other``.
        The last ``axis_offset`` axes of this Symbol are contracted with the
        last ``axis_offset`` axes of ``other``.

        :raises ValueError: if the contracted axes do not have the same total
                            size (numpy.reshape fails)
        """
        arr1, sh1, dim, subs=Symbol._operandInfo(self, other)
        r1=len(sh1)
        sh0=self.getShape()
        d0,d1,d01=1,1,1
        for i in sh0[:self._arr.ndim-axis_offset]: d0*=i
        for i in sh1[:r1-axis_offset]: d1*=i
        for i in sh1[r1-axis_offset:]: d01*=i
        arg0_c=self._arr.reshape((d0,d01))
        arg1_c=arr1.reshape((d1,d01))
        out=numpy.zeros((d0,d1),object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[i1,:])
        out=out.reshape(sh0[:self._arr.ndim-axis_offset]+sh1[:r1-axis_offset])
        return Symbol(out, dim=dim, subs=subs)

    def trace(self, axis_offset):
        """
        Returns the trace of this Symbol, summed over the axes
        ``axis_offset`` and ``axis_offset+1``, which must have equal length.
        """
        sh=self.getShape()
        s1=1
        for i in range(axis_offset): s1*=sh[i]
        s2=1
        for i in range(axis_offset+2,len(sh)): s2*=sh[i]
        arr_r=numpy.reshape(self._arr,(s1,sh[axis_offset],sh[axis_offset],s2))
        out=numpy.zeros([s1,s2],object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(sh[axis_offset]):
                    out[i1,i2]+=arr_r[i1,j,j,i2]
        out=out.reshape(sh[:axis_offset]+sh[axis_offset+2:])
        return Symbol(out, dim=self._dim, subs=self._subs)

    def transpose(self, axis_offset):
        """
        Returns the transpose of this Symbol. The axes from ``axis_offset``
        onwards are moved in front of the first ``axis_offset`` axes. If
        ``axis_offset`` is None, half the rank (rounded down) is used.
        """
        if axis_offset is None:
            axis_offset=int(self._arr.ndim/2)
        axes=list(range(axis_offset, self._arr.ndim))+list(range(0,axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes), dim=self._dim, subs=self._subs)

    def applyfunc(self, f, on_type=None):
        """
        Applies the function `f` to all elements (if on_type is None) or to
        all elements of type `on_type`.

        For a scalar Symbol, None is returned if `f` returns None.
        """
        assert callable(f)
        if self._arr.ndim==0:
            if on_type is None or isinstance(self._arr.item(), on_type):
                el=f(self._arr.item())
            else:
                el=self._arr.item()
            if el is not None:
                out=Symbol(el, dim=self._dim, subs=self._subs)
            else:
                return el
        else:
            out=numpy.empty(self.getShape(), dtype=object)
            for idx in numpy.ndindex(self.getShape()):
                if on_type is None or isinstance(self._arr[idx],on_type):
                    out[idx]=f(self._arr[idx])
                else:
                    out[idx]=self._arr[idx]
            out=Symbol(out, dim=self._dim, subs=self._subs)
        return out

    def expand(self):
        """
        Applies the sympy.expand operation on all elements in this symbol
        """
        return self.applyfunc(sympy.expand, sympy.Basic)

    def simplify(self):
        """
        Applies the sympy.simplify operation on all elements in this symbol
        """
        return self.applyfunc(sympy.simplify, sympy.Basic)

    # unary/binary operators follow

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr, dim=self._dim, subs=self._subs)

    def __abs__(self):
        return Symbol(abs(self._arr), dim=self._dim, subs=self._subs)

    def _ensureShapeCompatible(self, other):
        """
        Checks for compatible shapes for binary operations.
        Scalars (Python or numpy numbers, sympy expressions) are compatible
        with any shape.
        Raises TypeError if not compatible.
        """
        sh0=self.getShape()
        if isinstance(other, Symbol) or isinstance(other, Data):
            sh1=other.getShape()
        elif isinstance(other, numpy.ndarray):
            sh1=other.shape
        elif isinstance(other, list):
            sh1=numpy.array(other).shape
        elif isinstance(other, (int, float, numpy.number, sympy.Basic)):
            sh1=()
        else:
            raise TypeError("Unsupported argument type '%s' for operation"%other.__class__.__name__)
        if not sh0==sh1 and not sh0==() and not sh1==():
            raise TypeError("Incompatible shapes for operation")

    def __binaryop(self, op, other):
        """
        Helper for binary operations that checks types, shapes etc.
        A ``Data`` operand is replaced by a placeholder Symbol named
        'data<id>', which is recorded in the substitutions.
        """
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            subs=self._subs.copy()
            subs.update(other._subs)
            dim=other._dim if self._dim < 0 else self._dim
            return Symbol(getattr(self._arr, op)(other._arr), dim=dim, subs=subs)
        if isinstance(other, Data):
            name='data'+str(id(other))
            othersym=Symbol(name, other.getShape(), dim=self._dim)
            subs=self._subs.copy()
            subs.update({Symbol(name):other})
            return Symbol(getattr(self._arr, op)(othersym._arr), dim=self._dim, subs=subs)
        return Symbol(getattr(self._arr, op)(other), dim=self._dim, subs=self._subs)

    def __add__(self, other):
        return self.__binaryop('__add__', other)

    def __radd__(self, other):
        return self.__binaryop('__radd__', other)

    def __sub__(self, other):
        return self.__binaryop('__sub__', other)

    def __rsub__(self, other):
        return self.__binaryop('__rsub__', other)

    def __mul__(self, other):
        return self.__binaryop('__mul__', other)

    def __rmul__(self, other):
        return self.__binaryop('__rmul__', other)

    def __div__(self, other):
        return self.__binaryop('__div__', other)

    def __truediv__(self, other):
        return self.__binaryop('__truediv__', other)

    def __rdiv__(self, other):
        return self.__binaryop('__rdiv__', other)

    def __rtruediv__(self, other):
        return self.__binaryop('__rtruediv__', other)

    def __pow__(self, other):
        return self.__binaryop('__pow__', other)

    def __rpow__(self, other):
        return self.__binaryop('__rpow__', other)

    # static methods

    @staticmethod
    def _symComp(sym):
        """
        Splits the name of a component symbol such as '[x]_1_2' into its base
        name and index tuple, here ('x', (1, 2)). Names that do not follow
        this pattern are returned unchanged together with an empty tuple.

        :param sym: object with a ``name`` attribute (e.g. a sympy Symbol)
        """
        n=sym.name
        a=n.split('[')
        if len(a)!=2:
            return n,()
        a=a[1].split(']')
        if len(a)!=2:
            return n,()
        name=a[0]
        comps=[int(i) for i in a[1].split('_')[1:]]
        return name,tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        A symbol followed by an integer n is repeated n times. Integers
        themselves are not yielded. No arguments yield nothing.

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat
        if len(symbols)==0:
            return
        syms=[s if isinstance(s, Symbol) else sympy.sympify(s) for s in symbols]
        for i, s in enumerate(syms):
            # Look ahead by position, not by value. A symbol equal to the last
            # argument may still be followed by a repeat count.
            next_s=syms[i+1] if i+1<len(syms) else None
            if isinstance(s, sympy.Integer):
                continue
            elif isinstance(s, Symbol) or isinstance(s, sympy.Symbol):
                # handle cases like (x, 3)
                if isinstance(next_s, sympy.Integer):
                    # yield (x, x, x)
                    for copy_s in repeat(s,int(next_s)):
                        yield copy_s
                else:
                    yield s
            else:
                yield s
825    

  ViewVC Help
Powered by ViewVC 1.1.26