/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Annotation of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3864 - (hide annotations)
Mon Mar 12 05:18:16 2012 UTC (7 years, 8 months ago) by caltinay
Original Path: branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py
File MIME type: text/x-python
File size: 27492 byte(s)
Symbols now allow direct operations with Data objects and grad() et al allow
specifying FunctionSpace objects directly, without having to use temporary
symbols :-)


1 caltinay 3507
2     ########################################################
3     #
4 caltinay 3815 # Copyright (c) 2003-2012 by University of Queensland
5 caltinay 3507 # Earth Systems Science Computational Center (ESSCC)
6     # http://www.uq.edu.au/esscc
7     #
8     # Primary Business: Queensland, Australia
9     # Licensed under the Open Software License version 3.0
10     # http://www.opensource.org/licenses/osl-3.0.php
11     #
12     ########################################################
13    
14 caltinay 3815 __copyright__="""Copyright (c) 2003-2012 by University of Queensland
15 caltinay 3507 Earth Systems Science Computational Center (ESSCC)
16     http://www.uq.edu.au/esscc
17     Primary Business: Queensland, Australia"""
18     __license__="""Licensed under the Open Software License version 3.0
19     http://www.opensource.org/licenses/osl-3.0.php"""
20     __url__="https://launchpad.net/escript-finley"
21 caltinay 3864 __author__="Cihan Altinay"
22 caltinay 3507
23     """
24     :var __author__: name of author
25     :var __copyright__: copyrights
26     :var __license__: licence agreement
27     :var __url__: url entry point on documentation
28     :var __version__: version
29     :var __date__: date of the version
30     """
31    
32     import numpy
33     import sympy
34 caltinay 3864 from esys.escript import Data, FunctionSpace
35 caltinay 3507
36 gross 3856
class Symbol(object):
    """
    `Symbol` objects are placeholders for a single mathematical symbol, such as
    'x', or for arbitrarily complex mathematical expressions such as
    'c*x**4+alpha*exp(x)-2*sin(beta*x)', where 'alpha', 'beta', 'c', and 'x'
    are also `Symbol`s (the symbolic 'atoms' of the expression).

    With the help of the 'Evaluator' class these symbols and expressions can
    be resolved by substituting numeric values and/or escript `Data` objects
    for the atoms. To facilitate the use of `Data` objects a `Symbol` has a
    shape (and thus a rank) as well as a dimension (see constructor).
    `Symbol`s are useful to perform mathematical simplifications, compute
    derivatives and as coefficients for nonlinear PDEs which can be solved by
    the `NonlinearPDE` class.

    Internally the components are kept in a numpy object array (``_arr``),
    the spatial dimension in ``_dim`` and the `Data`/`FunctionSpace`
    substitutions in the dictionary ``_subs``.
    """

    # these are for compatibility with sympy.Symbol. lambdify checks these.
    is_Add=False
    is_Float=False

    # Symbols are used as dictionary keys (see subs(), grad(), __binaryop()).
    # Defining __eq__ removes the default hash on Python 3, so the
    # identity-based hash that Python 2 provided implicitly is kept
    # explicitly. NOTE: two equal Symbols may therefore hash differently.
    __hash__=object.__hash__

    def __init__(self, *args, **kwargs):
        """
        Initialises a new `Symbol` object in one of three ways::

            u=Symbol('u')

        returns a scalar symbol by the name 'u'.

            a=Symbol('alpha', (4,3))

        returns a rank 2 symbol with the shape (4,3), whose elements are
        named '[alpha]_i_j' (with i=0..3, j=0..2).

            a,b,c=symbols('a,b,c')
            x=Symbol([[a+b,0,0],[0,b-c,0],[0,0,c-a]])

        returns a rank 2 symbol with the shape (3,3) whose elements are
        explicitly specified by numeric values and other symbols/expressions
        within a list or numpy array.

        The dimensionality of the symbol can be specified through the `dim`
        keyword. All other keywords are passed to the underlying symbolic
        library (currently sympy).

        :param args: initialisation arguments as described above
        :keyword dim: dimensionality of the new Symbol (default: 2)
        :type dim: ``int``
        :keyword subs: initial dictionary of data substitutions (copied)
        :type subs: ``dict``
        :raise TypeError: if the arguments have an unsupported type or count
        :raise ValueError: if the name contains brackets or the rank is >4
        """
        self._dim=kwargs.pop('dim', 2)
        # copy so that the caller's dictionary is never modified by the
        # update() below and is not shared between independent symbols
        subs=kwargs.pop('subs', None)
        self._subs=dict(subs) if subs else {}

        if len(args)==1:
            arg=args[0]
            if isinstance(arg, str):
                if arg.find('[')>=0 or arg.find(']')>=0:
                    raise ValueError("Name must not contain '[' or ']'")
                self._arr=numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__") or isinstance(arg, list):
                if isinstance(arg, list):
                    arg=numpy.array(arg)
                arr=arg.__array__()
                if len(arr.shape)>4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                res=numpy.empty(arr.shape, dtype=object)
                for idx in numpy.ndindex(arr.shape):
                    # unwrap numpy scalars into plain python numbers
                    if hasattr(arr[idx], "item"):
                        res[idx]=arr[idx].item()
                    else:
                        res[idx]=arr[idx]
                self._arr=res
                if isinstance(arg, Symbol):
                    # a copy inherits substitutions and dimension
                    self._subs.update(arg._subs)
                    self._dim=arg._dim
            elif isinstance(arg, sympy.Basic):
                self._arr=numpy.array(arg)
            else:
                raise TypeError("Unsupported argument type %s"%str(type(arg)))
        elif len(args)==2:
            if not isinstance(args[0], str):
                raise TypeError("First argument must be a string")
            if not isinstance(args[1], tuple):
                raise TypeError("Second argument must be a tuple")
            name=args[0]
            shape=args[1]
            if name.find('[')>=0 or name.find(']')>=0:
                raise ValueError("Name must not contain '[' or ']'")
            if len(shape)>4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape)==0:
                self._arr=numpy.array(sympy.Symbol(name, **kwargs))
            else:
                # the argument order of symarray differs between sympy
                # versions, so both orders are tried
                try:
                    self._arr=sympy.symarray(shape, '['+name+']')
                except TypeError:
                    self._arr=sympy.symarray('['+name+']', shape)
        else:
            raise TypeError("Unsupported number of arguments")
        if self._arr.ndim==0:
            self.name=str(self._arr.item())
        else:
            self.name=str(self._arr.tolist())

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        """
        Two symbols are equal if they have the same type, rank and shape and
        all their components compare equal.
        """
        if type(self) is not type(other):
            return False
        if self.getRank()!=other.getRank():
            return False
        if self.getShape()!=other.getShape():
            return False
        return bool((self._arr==other._arr).all())

    def __ne__(self, other):
        # Python 2 does not derive != from ==, so define it explicitly
        return not self.__eq__(other)

    def __getitem__(self, key):
        """
        Returns an element of this symbol which must have rank >0.
        Unlike item() this method converts sympy objects and numpy arrays into
        escript Symbols in order to facilitate expressions that require
        element access, such as: grad(u)[1]+x

        The returned `Symbol` keeps the dimension and data substitutions of
        this symbol.

        :param key: (nd-)index of item to be returned
        :return: the requested element
        :rtype: ``Symbol``, ``int``, or ``float``
        """
        res=self._arr[key]
        # replace sympy Symbols/expressions by escript Symbols
        if isinstance(res, numpy.ndarray) or isinstance(res, sympy.Basic):
            res=Symbol(res, dim=self._dim, subs=self._subs)
        return res

    def __setitem__(self, key, value):
        """
        Sets the element(s) selected by `key`.

        :param key: (nd-)index of the element(s) to replace
        :param value: new value; a `Symbol`, sympy expression, array-like
                      object or anything sympy can sympify
        :raise ValueError: if a non-scalar `Symbol` does not match the shape
                           of the selected elements
        """
        if isinstance(value, Symbol):
            if value.getRank()==0:
                self._arr[key]=value.item()
            else:
                target=self._arr[key]
                if not hasattr(target, "shape") or target.shape!=value.getShape():
                    raise ValueError("Wrong shape of value")
                # copy the raw components; they may be plain numbers which
                # do not provide item()
                self._arr[key]=value._arr
        elif isinstance(value, sympy.Basic):
            self._arr[key]=value
        elif hasattr(value, "__array__"):
            src=value.__array__()
            # sympify component by component and keep the shape of value
            converted=numpy.empty(src.shape, dtype=object)
            for idx in numpy.ndindex(src.shape):
                converted[idx]=sympy.sympify(src[idx])
            self._arr[key]=converted
        else:
            self._arr[key]=sympy.sympify(value)

    def __iter__(self):
        """
        Iterates over the first axis of the underlying component array.
        """
        return iter(self._arr)

    def __array__(self, dtype=None, copy=None):
        """
        Returns the underlying numpy object array (numpy array protocol).

        :param dtype: optional dtype requested by numpy
        :param copy: accepted for compatibility with the numpy protocol
        """
        if dtype is not None and numpy.dtype(dtype)!=self._arr.dtype:
            return self._arr.astype(dtype)
        return self._arr

    def _sympy_(self):
        """
        Hook used by sympy.sympify(): returns a `Symbol` whose components
        have all been passed through sympy.sympify.
        """
        return self.applyfunc(sympy.sympify)

    def getDim(self):
        """
        Returns the spatial dimensionality of this symbol.

        :return: the symbol's spatial dimensionality
        :rtype: ``int``
        """
        return self._dim

    def getRank(self):
        """
        Returns the rank of this symbol.

        :return: the symbol's rank which is equal to the length of the shape.
        :rtype: ``int``
        """
        return self._arr.ndim

    def getShape(self):
        """
        Returns the shape of this symbol.

        :return: the symbol's shape
        :rtype: ``tuple`` of ``int``
        """
        return self._arr.shape

    def getDataSubstitutions(self):
        """
        Returns a dictionary of symbol names and the escript ``Data`` objects
        they represent within this Symbol.

        :return: the dictionary of substituted ``Data`` objects
        :rtype: ``dict``
        """
        return self._subs

    def item(self, *args):
        """
        Returns an element of this symbol.
        This method behaves like the item() method of numpy.ndarray.
        If this is a scalar Symbol, no arguments are allowed and the only
        element in this Symbol is returned.
        Otherwise, 'args' specifies a flat or nd-index and the element at
        that index is returned.

        :param args: index of item to be returned
        :return: the requested element
        :rtype: ``sympy.Symbol``, ``int``, or ``float``
        """
        return self._arr.item(*args)

    def atoms(self, *types):
        """
        Returns the atoms that form the current Symbol.

        By default, only objects that are truly atomic and cannot be divided
        into smaller pieces are returned: symbols, numbers, and number
        symbols like I and pi. It is possible to request atoms of any type,
        however.

        Note that if this symbol contains components such as [x]_i_j then
        only their main symbol 'x' is returned.

        :param types: types to restrict result to
        :return: set of atoms of specified type
        :rtype: ``set``
        """
        s=set()
        for el in self._arr.flat:
            if isinstance(el, sympy.Basic):
                for a in el.atoms(*types):
                    if a.is_Symbol:
                        # collapse components [x]_i_j to their base name x
                        n,c=Symbol._symComp(a)
                        s.add(sympy.Symbol(n))
                    else:
                        s.add(a)
            elif len(types)==0 or type(el) in types:
                s.add(el)
        return s

    def _sympystr_(self, printer):
        # alternate name of the sympy printer hook, kept for compatibility
        # with other sympy versions
        return self._sympystr(printer)

    def _sympystr(self, printer):
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string representation suitable for sympy's lambdify in
        which component names like '[x]_0_0' are written as 'x[0,0]'.
        Non-scalar symbols are wrapped in a 'combineData(...)' call.

        :rtype: ``str``
        """
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr=numpy.empty(self.getShape(), dtype=object)
        for idx,el in numpy.ndenumerate(self._arr):
            atoms=el.atoms(sympy.Symbol) if isinstance(el, sympy.Basic) else []
            # create a dictionary to convert names like [x]_0_0 to x[0,0]
            symdict={}
            for a in atoms:
                n,c=Symbol._symComp(a)
                if len(c)>0:
                    symstr=n+'['+','.join([str(i) for i in c])+']'
                else:
                    symstr=n
                symdict[a.name]=symstr
            s=lambdarepr(el)
            # replace longer names first so that e.g. '[x]_1' does not
            # corrupt '[x]_10'
            for key in sorted(symdict, key=len, reverse=True):
                s=s.replace(key, symdict[key])
            temp_arr[idx]=s
        if self.getRank()==0:
            return temp_arr.item()
        return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))

    def coeff(self, x, expand=True):
        """
        Returns the coefficient of the term "x" or 0 if there is no "x".

        If "x" is a scalar symbol then "x" is searched in all components of
        this symbol. Otherwise the shapes must match and the coefficients are
        checked component by component.

        Example::

            x=Symbol('x', (2,2))
            y=3*x
            print(y.coeff(x))
            print(y.coeff(x[1,1]))

        will print::

            [[3 3]
             [3 3]]

            [[0 0]
             [0 3]]

        :param x: the term whose coefficients are to be found
        :type x: ``Symbol``, ``numpy.ndarray``, `list`
        :param expand: passed as second argument to the sympy coeff() method
        :return: the coefficient(s) of the term
        :rtype: ``Symbol``
        :raise TypeError: if the shape of `x` is not compatible
        """
        # NOTE(review): the meaning of the second positional argument of
        # sympy's coeff() differs between sympy versions - confirm against
        # the sympy release in use.
        self._ensureShapeCompatible(x)
        if hasattr(x, '__array__'):
            y=x.__array__()
        else:
            y=numpy.array(x)

        def coeff_of(element, term):
            # components may be plain numbers, hence the sympify
            return sympy.sympify(element).coeff(term, expand)

        if y.ndim>0:
            result=numpy.zeros(self.getShape(), dtype=object)
            for idx in numpy.ndindex(y.shape):
                if y[idx]!=0:
                    res=coeff_of(self._arr[idx], y[idx])
                    if res is not None:
                        result[idx]=res
        elif y.item()==0:
            result=numpy.zeros(self.getShape(), dtype=object)
        else:
            term=y.item()
            def coeff_item(element):
                res=coeff_of(element, term)
                # some sympy versions return None for "no such term"
                return sympy.Integer(0) if res is None else res
            result=self.applyfunc(coeff_item)
        # carry the data substitutions over to the result
        return Symbol(result, dim=self._dim, subs=self._subs)

    def subs(self, old, new):
        """
        Substitutes an expression.

        If `new` is an escript ``Data`` object the substitution is only
        recorded (see `getDataSubstitutions`), otherwise `old` is replaced by
        `new` in all components. For a non-scalar `old` the components are
        replaced one after the other (not simultaneously).

        :param old: the symbol/expression to be replaced
        :param new: the replacement
        :return: a new `Symbol` with the substitution applied
        :rtype: ``Symbol``
        :raise TypeError: if the shapes of `old` and `new` are incompatible
        """
        if isinstance(old, Symbol):
            old._ensureShapeCompatible(new)
        if isinstance(new, Data):
            subs=self._subs.copy()
            if isinstance(old, Symbol) and old.getRank()>0:
                # use the base symbol (e.g. 'x' for [x]_i_j) as the key
                atoms=old.atoms(sympy.Symbol)
                if len(atoms)==0:
                    raise ValueError("subs: symbol to be replaced has no atoms")
                old=Symbol(min(atoms, key=str))
            subs[old]=new
            return Symbol(self._arr, dim=self._dim, subs=subs)

        if isinstance(old, Symbol) and old.getRank()>0:
            if hasattr(new, '__array__'):
                new=new.__array__()
            else:
                new=numpy.array(new)
            # pair every component of old with its replacement; a scalar
            # replacement is used for all components
            pairs=[]
            for symidx in numpy.ndindex(old.getShape()):
                if new.ndim>0:
                    pairs.append((old._arr[symidx], new[symidx]))
                else:
                    pairs.append((old._arr[symidx], new.item()))
            result=numpy.empty(self.getShape(), dtype=object)
            for idx in numpy.ndindex(self.getShape()):
                el=sympy.sympify(self._arr[idx])
                # accumulate so that every replacement takes effect
                for o,n in pairs:
                    el=el.subs(o, n)
                result[idx]=el
            return Symbol(result, dim=self._dim, subs=self._subs)

        # scalar
        if isinstance(new, Symbol):
            new=new.item()
        if isinstance(old, Symbol):
            old=old.item()
        subs_item=lambda item: sympy.sympify(item).subs(old, new)
        return self.applyfunc(subs_item)

    def diff(self, *symbols, **assumptions):
        """
        Differentiates this symbol with respect to the given symbols.

        Differentiation with respect to a rank 1 `Symbol` appends an axis of
        the length of that symbol to the result. An integer following a
        symbol is interpreted as a repeat count (see `_symbolgen`).

        :param symbols: symbols (and optional repeat counts) to derive by
        :param assumptions: passed on to the sympy diff() method
        :return: the derivative
        :rtype: ``Symbol``
        :raise ValueError: if a `Symbol` argument has rank >1
        """
        result=Symbol(self._arr, dim=self._dim, subs=self._subs)
        for s in Symbol._symbolgen(*symbols):
            if isinstance(s, Symbol):
                if s.getRank()==0:
                    wrt=s._arr.item()
                    diff_item=lambda item, wrt=wrt: sympy.sympify(item).diff(wrt, **assumptions)
                    result=result.applyfunc(diff_item)
                elif s.getRank()==1:
                    dim=s.getShape()[0]
                    # use the shape of the running result, which may have
                    # grown by an axis in a previous iteration
                    shape=result.getShape()
                    out=result._arr.copy().reshape(shape+(1,)).repeat(dim, axis=len(shape))
                    for d in range(dim):
                        wrt=s._arr[d]
                        for idx in numpy.ndindex(shape):
                            index=idx+(d,)
                            out[index]=sympy.sympify(out[index]).diff(wrt, **assumptions)
                    result=Symbol(out, dim=self._dim, subs=self._subs)
                else:
                    raise ValueError("diff: argument must have rank 0 or 1")
            else:
                diff_item=lambda item, wrt=s: sympy.sympify(item).diff(wrt, **assumptions)
                result=result.applyfunc(diff_item)
        return result

    def grad(self, where=None):
        """
        Returns a symbol with one additional trailing axis of length
        `getDim()` whose components are 'grad_n' expressions of the
        components of this symbol.

        :param where: optional location passed on to grad_n; a
                      ``FunctionSpace`` is recorded as a substitution
        :type where: ``Symbol``, ``FunctionSpace``
        :rtype: ``Symbol``
        :raise ValueError: if `where` is a non-scalar `Symbol`
        """
        subs=self._subs
        if isinstance(where, Symbol):
            if where.getRank()>0:
                raise ValueError("grad: 'where' must be a scalar symbol")
            where=where._arr.item()
        elif isinstance(where, FunctionSpace):
            # represent the function space by a uniquely named placeholder
            name='fs'+str(id(where))
            subs=self._subs.copy()
            subs[Symbol(name)]=where
            where=name

        # NOTE(review): implicit relative import (Python 2 style) of the
        # sibling module - confirm the package layout before changing it.
        from functions import grad_n
        out=self._arr.copy().reshape(self.getShape()+(1,)).repeat(self._dim, axis=self.getRank())
        for d in range(self._dim):
            for idx in numpy.ndindex(self.getShape()):
                index=idx+(d,)
                if where is None:
                    out[index]=grad_n(out[index], d)
                else:
                    out[index]=grad_n(out[index], d, where)
        return Symbol(out, dim=self._dim, subs=subs)

    def inverse(self):
        """
        Returns the inverse of this symbol which must be a square matrix of
        size 1, 2 or 3.

        :rtype: ``Symbol``
        :raise TypeError: if the rank is not 2 or the size is not supported
        :raise ValueError: if the matrix is not square
        :raise ZeroDivisionError: if the determinant is known to be zero
        """
        if not self.getRank()==2:
            raise TypeError("inverse: Only rank 2 supported")
        s=self.getShape()
        if not s[0] == s[1]:
            raise ValueError("inverse: Only square shapes supported")
        out=numpy.zeros(s, object)
        arr=self._arr
        if s[0]==1:
            # sympify so that plain numbers provide is_zero as well
            D=sympy.sympify(arr[0,0])
            if D.is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            out[0,0]=1./D
        elif s[0]==2:
            A11=arr[0,0]
            A12=arr[0,1]
            A21=arr[1,0]
            A22=arr[1,1]
            D=sympy.sympify(A11*A22-A12*A21)
            if D.is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]= A22*D
            out[1,0]=-A21*D
            out[0,1]=-A12*D
            out[1,1]= A11*D
        elif s[0]==3:
            A11=arr[0,0]
            A21=arr[1,0]
            A31=arr[2,0]
            A12=arr[0,1]
            A22=arr[1,1]
            A32=arr[2,1]
            A13=arr[0,2]
            A23=arr[1,2]
            A33=arr[2,2]
            D=sympy.sympify(A11*(A22*A33-A23*A32)+A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22))
            if D.is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            # adjugate matrix scaled by the inverse determinant
            out[0,0]=(A22*A33-A23*A32)*D
            out[1,0]=(A31*A23-A21*A33)*D
            out[2,0]=(A21*A32-A31*A22)*D
            out[0,1]=(A13*A32-A12*A33)*D
            out[1,1]=(A11*A33-A31*A13)*D
            out[2,1]=(A12*A31-A11*A32)*D
            out[0,2]=(A12*A23-A13*A22)*D
            out[1,2]=(A13*A21-A11*A23)*D
            out[2,2]=(A11*A22-A12*A21)*D
        else:
            raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
        return Symbol(out, dim=self._dim, subs=self._subs)

    def swap_axes(self, axis0, axis1):
        """
        Returns a symbol with the two given axes interchanged.

        :rtype: ``Symbol``
        """
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self._dim, subs=self._subs)

    def _mergedSubs(self, other):
        """
        Returns a copy of this symbol's substitutions, extended by those of
        `other` if `other` is a `Symbol`.
        """
        merged=self._subs.copy()
        if isinstance(other, Symbol):
            merged.update(other._subs)
        return merged

    def tensorProduct(self, other, axis_offset):
        """
        Returns the product of this symbol and `other` contracting the last
        `axis_offset` axes of this symbol with the first `axis_offset` axes
        of `other`.

        :param other: second factor
        :type other: ``Symbol`` or ``numpy.ndarray``
        :param axis_offset: number of axes to contract
        :rtype: ``Symbol``
        """
        arg0_c=self._arr.copy()
        sh0=self.getShape()
        if isinstance(other, Symbol):
            arg1_c=other._arr.copy()
            sh1=other.getShape()
        else:
            arg1_c=other.copy()
            sh1=other.shape
        d0,d1,d01=1,1,1
        for i in sh0[:self._arr.ndim-axis_offset]: d0*=i
        for i in sh1[axis_offset:]: d1*=i
        for i in sh1[:axis_offset]: d01*=i
        # flatten both factors to matrices and multiply them
        arg0_c.resize((d0,d01))
        arg1_c.resize((d01,d1))
        out=numpy.zeros((d0,d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[:,i1])
        out=out.reshape(sh0[:self._arr.ndim-axis_offset]+sh1[axis_offset:])
        return Symbol(out, dim=self._dim, subs=self._mergedSubs(other))

    def transposedTensorProduct(self, other, axis_offset):
        """
        Returns the product of the transposed of this symbol and `other`,
        contracting the first `axis_offset` axes of both factors.

        :param other: second factor
        :type other: ``Symbol`` or ``numpy.ndarray``
        :param axis_offset: number of axes to contract
        :rtype: ``Symbol``
        """
        arg0_c=self._arr.copy()
        sh0=self.getShape()
        if isinstance(other, Symbol):
            arg1_c=other._arr.copy()
            sh1=other.getShape()
        else:
            arg1_c=other.copy()
            sh1=other.shape
        d0,d1,d01=1,1,1
        for i in sh0[axis_offset:]: d0*=i
        for i in sh1[axis_offset:]: d1*=i
        for i in sh1[:axis_offset]: d01*=i
        arg0_c.resize((d01,d0))
        arg1_c.resize((d01,d1))
        out=numpy.zeros((d0,d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(arg0_c[:,i0]*arg1_c[:,i1])
        out=out.reshape(sh0[axis_offset:]+sh1[axis_offset:])
        return Symbol(out, dim=self._dim, subs=self._mergedSubs(other))

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Returns the product of this symbol and the transposed of `other`,
        contracting the last `axis_offset` axes of both factors.

        :param other: second factor
        :type other: ``Symbol`` or ``numpy.ndarray``
        :param axis_offset: number of axes to contract
        :rtype: ``Symbol``
        """
        arg0_c=self._arr.copy()
        sh0=self.getShape()
        if isinstance(other, Symbol):
            arg1_c=other._arr.copy()
            sh1=other.getShape()
            r1=other.getRank()
        else:
            arg1_c=other.copy()
            sh1=other.shape
            r1=other.ndim
        d0,d1,d01=1,1,1
        for i in sh0[:self._arr.ndim-axis_offset]: d0*=i
        for i in sh1[:r1-axis_offset]: d1*=i
        for i in sh1[r1-axis_offset:]: d01*=i
        arg0_c.resize((d0,d01))
        arg1_c.resize((d1,d01))
        out=numpy.zeros((d0,d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[i1,:])
        out=out.reshape(sh0[:self._arr.ndim-axis_offset]+sh1[:r1-axis_offset])
        return Symbol(out, dim=self._dim, subs=self._mergedSubs(other))

    def trace(self, axis_offset):
        """
        Returns the trace of this symbol taken over the axes `axis_offset`
        and `axis_offset`+1.

        :param axis_offset: first of the two axes to sum over
        :rtype: ``Symbol``
        """
        sh=self.getShape()
        s1=1
        for i in range(axis_offset): s1*=sh[i]
        s2=1
        for i in range(axis_offset+2,len(sh)): s2*=sh[i]
        arr_r=numpy.reshape(self._arr,(s1,sh[axis_offset],sh[axis_offset],s2))
        out=numpy.zeros([s1,s2],object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(sh[axis_offset]):
                    out[i1,i2]+=arr_r[i1,j,j,i2]
        out=out.reshape(sh[:axis_offset]+sh[axis_offset+2:])
        return Symbol(out, dim=self._dim, subs=self._subs)

    def transpose(self, axis_offset):
        """
        Returns the transposed of this symbol, i.e. the axes are rotated so
        that axis `axis_offset` becomes the first axis.

        :param axis_offset: number of leading axes moved to the end; if
                            ``None``, half of the rank (rounded down) is used
        :rtype: ``Symbol``
        """
        if axis_offset is None:
            axis_offset=int(self._arr.ndim/2)
        axes=list(range(axis_offset, self._arr.ndim))+list(range(0,axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes), dim=self._dim, subs=self._subs)

    def applyfunc(self, f, on_type=None):
        """
        Applies the function `f` to all components of this symbol.

        :param f: callable taking one component and returning the new one
        :param on_type: if given, `f` is only applied to components of this
                        type; other components are copied unchanged
        :return: the resulting symbol; for a scalar symbol ``None`` is
                 returned if the function returned ``None``
        :rtype: ``Symbol`` or ``None``
        """
        assert callable(f)
        if self._arr.ndim==0:
            el=self._arr.item()
            if on_type is None or isinstance(el, on_type):
                el=f(el)
            if el is None:
                return None
            return Symbol(el, dim=self._dim, subs=self._subs)
        out=numpy.empty(self.getShape(), dtype=object)
        for idx in numpy.ndindex(self.getShape()):
            if on_type is None or isinstance(self._arr[idx], on_type):
                out[idx]=f(self._arr[idx])
            else:
                out[idx]=self._arr[idx]
        return Symbol(out, dim=self._dim, subs=self._subs)

    def expand(self):
        """
        Applies sympy.expand to all sympy components of this symbol.
        """
        return self.applyfunc(sympy.expand, sympy.Basic)

    def simplify(self):
        """
        Applies sympy.simplify to all sympy components of this symbol.
        """
        return self.applyfunc(sympy.simplify, sympy.Basic)

    # unary/binary operations follow

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr, dim=self._dim, subs=self._subs)

    def __abs__(self):
        return Symbol(abs(self._arr), dim=self._dim, subs=self._subs)

    def _ensureShapeCompatible(self, other):
        """
        Checks for compatible shapes for binary operations.
        Raises TypeError if not compatible.
        """
        sh0=self.getShape()
        if isinstance(other, Symbol) or isinstance(other, Data):
            sh1=other.getShape()
        elif isinstance(other, numpy.ndarray):
            sh1=other.shape
        elif isinstance(other, list):
            sh1=numpy.array(other).shape
        elif isinstance(other,int) or isinstance(other,float) or isinstance(other,sympy.Basic):
            sh1=()
        else:
            raise TypeError("Unsupported argument type '%s' for operation"%other.__class__.__name__)
        # scalars are compatible with everything
        if not sh0==sh1 and not sh0==() and not sh1==():
            raise TypeError("Incompatible shapes for operation")

    def __binaryop(self, op, other):
        """
        Applies the numpy array method named `op` to the components of this
        symbol and `other` and merges the data substitutions. A ``Data``
        operand is replaced by a placeholder symbol which is recorded in the
        substitutions of the result.
        """
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(getattr(self._arr, op)(other._arr), dim=self._dim, subs=self._mergedSubs(other))
        if isinstance(other, Data):
            name='data'+str(id(other))
            othersym=Symbol(name, other.getShape(), dim=self._dim)
            subs=self._subs.copy()
            subs[Symbol(name)]=other
            return Symbol(getattr(self._arr, op)(othersym._arr), dim=self._dim, subs=subs)
        return Symbol(getattr(self._arr, op)(other), dim=self._dim, subs=self._subs)

    def __add__(self, other):
        return self.__binaryop('__add__', other)

    def __radd__(self, other):
        return self.__binaryop('__radd__', other)

    def __sub__(self, other):
        return self.__binaryop('__sub__', other)

    def __rsub__(self, other):
        return self.__binaryop('__rsub__', other)

    def __mul__(self, other):
        return self.__binaryop('__mul__', other)

    def __rmul__(self, other):
        return self.__binaryop('__rmul__', other)

    def __div__(self, other):
        return self.__binaryop('__div__', other)

    def __rdiv__(self, other):
        return self.__binaryop('__rdiv__', other)

    # the '/' operator maps to these under true division

    def __truediv__(self, other):
        return self.__binaryop('__truediv__', other)

    def __rtruediv__(self, other):
        return self.__binaryop('__rtruediv__', other)

    def __pow__(self, other):
        return self.__binaryop('__pow__', other)

    def __rpow__(self, other):
        return self.__binaryop('__rpow__', other)

    @staticmethod
    def _symComp(sym):
        """
        Splits the name of a component symbol such as '[x]_0_1' into the
        base name and the component index.

        :param sym: object with a ``name`` attribute
        :return: tuple (name, index), e.g. ('x', (0,1)); the index is empty
                 if the name is not of the form '[name]_i_j...'
        :rtype: ``tuple``
        """
        n=sym.name
        a=n.split('[')
        if len(a)!=2:
            return n,()
        a=a[1].split(']')
        if len(a)!=2:
            return n,()
        name=a[0]
        comps=[int(i) for i in a[1].split('_')[1:]]
        return name,tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat

        def normalise(s):
            # escript Symbols are kept, everything else becomes a sympy object
            return s if isinstance(s, Symbol) else sympy.sympify(s)

        count=len(symbols)
        for i in range(count):
            s=normalise(symbols[i])
            if isinstance(s, sympy.Integer):
                # repeat counts are consumed together with their symbol
                continue
            # look ahead by position, not by comparing with the last argument
            next_s=normalise(symbols[i+1]) if i+1<count else None
            is_symbol=isinstance(s, Symbol) or isinstance(s, sympy.Symbol)
            if is_symbol and isinstance(next_s, sympy.Integer):
                # handle cases like (x, 3): yield (x, x, x)
                for copy_s in repeat(s, int(next_s)):
                    yield copy_s
            else:
                yield s
784    

  ViewVC Help
Powered by ViewVC 1.1.26