/[escript]/branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py
ViewVC logotype

Contents of /branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py

Parent Directory | Revision Log


Revision 3864 - (show annotations)
Mon Mar 12 05:18:16 2012 UTC (7 years, 5 months ago) by caltinay
File MIME type: text/x-python
File size: 27492 byte(s)
Symbols now allow direct operations with Data objects and grad() et al allow
specifying FunctionSpace objects directly, without having to use temporary
symbols :-)


1
2 ########################################################
3 #
4 # Copyright (c) 2003-2012 by University of Queensland
5 # Earth Systems Science Computational Center (ESSCC)
6 # http://www.uq.edu.au/esscc
7 #
8 # Primary Business: Queensland, Australia
9 # Licensed under the Open Software License version 3.0
10 # http://www.opensource.org/licenses/osl-3.0.php
11 #
12 ########################################################
13
__copyright__="""Copyright (c) 2003-2012 by University of Queensland
Earth Systems Science Computational Center (ESSCC)
http://www.uq.edu.au/esscc
Primary Business: Queensland, Australia"""
__license__="""Licensed under the Open Software License version 3.0
http://www.opensource.org/licenses/osl-3.0.php"""
__url__="https://launchpad.net/escript-finley"
__author__="Cihan Altinay"

# Module metadata (only the variables actually defined above are listed;
# this module defines no __version__ or __date__).
"""
:var __author__: name of author
:var __copyright__: copyrights
:var __license__: licence agreement
:var __url__: url entry point on documentation
"""
31
import numbers

import numpy
import sympy
from esys.escript import Data, FunctionSpace
35
36
class Symbol(object):
    """
    `Symbol` objects are placeholders for a single mathematical symbol, such as
    'x', or for arbitrarily complex mathematical expressions such as
    'c*x**4+alpha*exp(x)-2*sin(beta*x)', where 'alpha', 'beta', 'c', and 'x'
    are also `Symbol`s (the symbolic 'atoms' of the expression).

    With the help of the 'Evaluator' class these symbols and expressions can
    be resolved by substituting numeric values and/or escript `Data` objects
    for the atoms. To facilitate the use of `Data` objects a `Symbol` has a
    shape (and thus a rank) as well as a dimension (see constructor).
    `Symbol`s are useful to perform mathematical simplifications, compute
    derivatives and as coefficients for nonlinear PDEs which can be solved by
    the `NonlinearPDE` class.

    Internally the components are stored in a numpy object array and the
    escript objects referenced by the expression (`Data`, `FunctionSpace`)
    are tracked in a dictionary keyed by placeholder Symbols (see
    `getDataSubstitutions`).
    """

    # these are for compatibility with sympy.Symbol. lambdify checks these.
    is_Add=False
    is_Float=False

    # Symbols are used as dictionary keys in the substitution tables (see
    # __binaryop, subs and grad). Defining __eq__ removes the inherited hash
    # on Python 3, so the identity-based hash is restored explicitly. This is
    # the Python 2 behaviour; note that equal Symbols may hash differently.
    __hash__=object.__hash__

    def __init__(self, *args, **kwargs):
        """
        Initialises a new `Symbol` object in one of the following ways::

            u=Symbol('u')

        returns a scalar symbol by the name 'u'.

            a=Symbol('alpha', (4,3))

        returns a rank 2 symbol with the shape (4,3), whose elements are
        named '[alpha]_i_j' (with i=0..3, j=0..2).

            a,b,c=symbols('a,b,c')
            x=Symbol([[a+b,0,0],[0,b-c,0],[0,0,c-a]])

        returns a rank 2 symbol with the shape (3,3) whose elements are
        explicitly specified by numeric values and other symbols/expressions
        within a list or numpy array.

            s=Symbol(3)

        returns a rank 0 symbol holding a plain number.

        The dimensionality of the symbol can be specified through the `dim`
        keyword. Remaining keywords are passed to ``sympy.Symbol`` when a
        scalar named symbol is created.

        :param args: initialisation arguments as described above
        :keyword dim: dimensionality of the new Symbol (default: the
                      dimensionality of a `Symbol` argument, otherwise 2)
        :type dim: ``int``
        :keyword subs: substitution table to start from (it is copied)
        :type subs: ``dict``
        :raise TypeError: for unsupported argument types or counts
        :raise ValueError: for names containing '[' or ']' and for tensors of
                           order higher than 4
        """
        explicit_dim='dim' in kwargs
        self._dim=kwargs.pop('dim') if explicit_dim else 2
        # always copy: many Symbols are built from another Symbol's table and
        # must not modify each other's substitutions
        self._subs=dict(kwargs.pop('subs', {}))

        if len(args)==1:
            arg=args[0]
            if isinstance(arg, str):
                Symbol._checkName(arg)
                self._arr=numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__") or isinstance(arg, list):
                if isinstance(arg, list):
                    arr=numpy.array(arg)
                else:
                    arr=arg.__array__()
                if arr.ndim>4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                res=numpy.empty(arr.shape, dtype=object)
                for idx in numpy.ndindex(arr.shape):
                    el=arr[idx]
                    # unwrap numpy scalars into plain Python numbers
                    res[idx]=el.item() if hasattr(el, "item") else el
                self._arr=res
                if isinstance(arg, Symbol):
                    self._subs.update(arg._subs)
                    # an explicitly requested dimensionality takes precedence
                    if not explicit_dim:
                        self._dim=arg._dim
            elif isinstance(arg, numbers.Number):
                # plain numbers arise e.g. from unary/binary ops on rank 0
                # Symbols and from coefficients of constant components
                self._arr=numpy.array(arg, dtype=object)
            elif isinstance(arg, sympy.Basic):
                self._arr=numpy.array(arg)
            else:
                raise TypeError("Unsupported argument type %s"%str(type(arg)))
        elif len(args)==2:
            name, shape=args
            if not isinstance(name, str):
                raise TypeError("First argument must be a string")
            if not isinstance(shape, tuple):
                raise TypeError("Second argument must be a tuple")
            Symbol._checkName(name)
            if len(shape)>4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape)==0:
                self._arr=numpy.array(sympy.Symbol(name, **kwargs))
            else:
                # sympy versions disagree on the argument order of symarray
                try:
                    self._arr=sympy.symarray(shape, '['+name+']')
                except TypeError:
                    self._arr=sympy.symarray('['+name+']', shape)
        else:
            raise TypeError("Unsupported number of arguments")
        if self._arr.ndim==0:
            self.name=str(self._arr.item())
        else:
            self.name=str(self._arr.tolist())

    @staticmethod
    def _checkName(name):
        """
        Raises ValueError if 'name' contains square brackets, which are
        reserved for component names such as '[x]_0_1'.
        """
        if '[' in name or ']' in name:
            raise ValueError("Name must not contain '[' or ']'")

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        """
        Two Symbols are equal if they have the same shape and all components
        compare equal. Objects of any other type are never equal.
        """
        if type(self) is not type(other):
            return False
        if self.getShape()!=other.getShape():
            return False
        return bool(numpy.all(self._arr==other._arr))

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__
        return not self.__eq__(other)

    def __getitem__(self, key):
        """
        Returns an element of this symbol which must have rank >0.
        Unlike item() this method converts sympy objects and numpy arrays into
        escript Symbols in order to facilitate expressions that require
        element access, such as: grad(u)[1]+x

        :param key: (nd-)index of item to be returned
        :return: the requested element
        :rtype: ``Symbol``, ``int``, or ``float``
        """
        res=self._arr[key]
        # replace sympy Symbols/expressions by escript Symbols
        if isinstance(res, sympy.Basic) or isinstance(res, numpy.ndarray):
            res=Symbol(res)
        return res

    def __setitem__(self, key, value):
        """
        Sets the element(s) addressed by 'key'.

        'value' may be a `Symbol` (whose substitution table is merged into
        this Symbol), a sympy object, an array-like (elements are sympified,
        the shape is preserved) or anything ``sympy.sympify`` accepts.

        :raise ValueError: if a non-scalar Symbol does not match the shape of
                           the addressed sub-array
        """
        if isinstance(value, Symbol):
            if value.getRank()==0:
                self._arr[key]=value.item()
            else:
                target=self._arr[key]
                if not hasattr(target, "shape") or target.shape!=value.getShape():
                    raise ValueError("Wrong shape of value")
                # for basic indexing 'target' is a view into self._arr;
                # copy raw components (plain numbers have no .item())
                for idx in numpy.ndindex(target.shape):
                    target[idx]=value._arr[idx]
            self._subs.update(value._subs)
        elif isinstance(value, sympy.Basic):
            self._arr[key]=value
        elif hasattr(value, "__array__"):
            src=numpy.asarray(value)
            converted=numpy.empty(src.shape, dtype=object)
            for idx in numpy.ndindex(src.shape):
                converted[idx]=sympy.sympify(src[idx])
            self._arr[key]=converted
        else:
            self._arr[key]=sympy.sympify(value)

    def __iter__(self):
        """
        Iterates over the first axis of the component array, yielding the raw
        elements (or sub-arrays), not Symbols.
        """
        return iter(self._arr)

    def __array__(self, dtype=None, copy=None):
        """
        Returns the numpy object array holding the components. The internal
        array itself is returned unless a dtype or a copy is requested (newer
        numpy versions pass the 'copy' keyword).
        """
        if dtype is None and not copy:
            return self._arr
        return numpy.array(self._arr, dtype=dtype, copy=True)

    def _sympy_(self):
        """
        Hook used by ``sympy.sympify``. Returns a Symbol (not a sympy object)
        whose components have all been passed through ``sympy.sympify``.
        """
        return self.applyfunc(sympy.sympify)

    def getDim(self):
        """
        Returns the spatial dimensionality of this symbol.

        :return: the symbol's spatial dimensionality
        :rtype: ``int``
        """
        return self._dim

    def getRank(self):
        """
        Returns the rank of this symbol.

        :return: the symbol's rank which is equal to the length of the shape.
        :rtype: ``int``
        """
        return self._arr.ndim

    def getShape(self):
        """
        Returns the shape of this symbol.

        :return: the symbol's shape
        :rtype: ``tuple`` of ``int``
        """
        return self._arr.shape

    def getDataSubstitutions(self):
        """
        Returns a dictionary of symbol names and the escript ``Data`` objects
        they represent within this Symbol.

        :return: the dictionary of substituted ``Data`` objects
        :rtype: ``dict``
        """
        return self._subs

    def item(self, *args):
        """
        Returns an element of this symbol.
        This method behaves like the item() method of numpy.ndarray.
        If this is a scalar Symbol, no arguments are allowed and the only
        element in this Symbol is returned.
        Otherwise, 'args' specifies a flat or nd-index (given either as
        separate integers or as a single tuple) and the element at that index
        is returned.

        :param args: index of item to be returned
        :return: the requested element
        :rtype: ``sympy.Symbol``, ``int``, or ``float``
        """
        return self._arr.item(*args)

    def atoms(self, *types):
        """
        Returns the atoms that form the current Symbol.

        By default, only objects that are truly atomic and cannot be divided
        into smaller pieces are returned: symbols, numbers, and number
        symbols like I and pi. It is possible to request atoms of any type,
        however.

        Note that if this symbol contains components such as [x]_i_j then
        only their main symbol 'x' is returned.

        :param types: types to restrict result to
        :return: list of atoms of specified type
        :rtype: ``set``
        """
        s=set()
        for el in self._arr.flat:
            if isinstance(el,sympy.Basic):
                for a in el.atoms(*types):
                    if a.is_Symbol:
                        n,c=Symbol._symComp(a)
                        s.add(sympy.Symbol(n))
                    else:
                        s.add(a)
            elif len(types)==0 or type(el) in types:
                s.add(el)
        return s

    def _sympystr_(self, printer):
        # compatibility with sympy 1.6
        return self._sympystr(printer)

    def _sympystr(self, printer):
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string representation suitable for lambdify. Component
        symbols named like '[x]_0_1' are written as 'x[0,1]'. Rank 0 Symbols
        yield the plain expression, higher ranks are wrapped as
        'combineData(<nested list>,<shape>)'.
        """
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr=numpy.empty(self.getShape(), dtype=object)
        for idx,el in numpy.ndenumerate(self._arr):
            atoms=el.atoms(sympy.Symbol) if isinstance(el,sympy.Basic) else []
            # create a dictionary to convert names like [x]_0_0 to x[0,0]
            symdict={}
            for a in atoms:
                n,c=Symbol._symComp(a)
                if len(c)>0:
                    symdict[a.name]=n+'['+','.join(str(i) for i in c)+']'
                else:
                    symdict[a.name]=n
            s=lambdarepr(el)
            # replace longer names first so that e.g. '[x]_1' cannot clobber
            # part of '[x]_10'
            for key in sorted(symdict, key=len, reverse=True):
                s=s.replace(key, symdict[key])
            temp_arr[idx]=s
        if self.getRank()==0:
            return temp_arr.item()
        else:
            return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))

    def coeff(self, x, expand=True):
        """
        Returns the coefficient of the term "x" or 0 if there is no "x".

        If "x" is a scalar symbol then "x" is searched in all components of
        this symbol. Otherwise the shapes must match and the coefficients are
        checked component by component.

        Example::

            x=Symbol('x', (2,2))
            y=3*x
            print(y.coeff(x))
            print(y.coeff(x[1,1]))

        will print::

            [[3 3]
             [3 3]]

            [[0 0]
             [0 3]]

        :param x: the term whose coefficients are to be found
        :type x: ``Symbol``, ``numpy.ndarray``, `list`
        :param expand: passed on to sympy's coeff() only when False
        :return: the coefficient(s) of the term, carrying this Symbol's
                 substitutions
        :rtype: ``Symbol``
        """
        self._ensureShapeCompatible(x)
        if hasattr(x, '__array__'):
            y=x.__array__()
        else:
            y=numpy.array(x)

        def coeff_of(el, term):
            # plain numbers cannot contain 'term'
            if not isinstance(el, sympy.Basic):
                return 0
            if expand:
                # rely on the default: newer sympy releases interpret the
                # second positional argument of coeff() as the power 'n'
                res=el.coeff(term)
            else:
                res=el.coeff(term, expand)
            # coeff() may report a missing term as None
            return 0 if res is None else res

        result=numpy.zeros(self.getShape(), dtype=object)
        if y.ndim>0:
            for idx in numpy.ndindex(y.shape):
                if y[idx]!=0:
                    result[idx]=coeff_of(self._arr[idx], y[idx])
        elif y.item()!=0:
            term=y.item()
            for idx in numpy.ndindex(self.getShape()):
                result[idx]=coeff_of(self._arr[idx], term)
        return Symbol(result, dim=self._dim, subs=self._subs)

    def subs(self, old, new):
        """
        Substitutes 'new' for 'old' in all components and returns the result
        as a new Symbol.

        If 'new' is an escript `Data` object the components are left as they
        are; the pair is recorded in the substitution table of the result
        instead (a non-scalar 'old' is recorded under its main name, e.g. 'x'
        for components '[x]_i_j').
        If 'old' is a non-scalar Symbol then 'new' must either have the same
        shape (component-wise substitution) or be a scalar (substituted for
        every component of 'old').

        :param old: the symbol/expression to be replaced
        :param new: the replacement
        :rtype: ``Symbol``
        :raise TypeError: if the shapes of 'old' and 'new' are incompatible
        :raise ValueError: if a non-scalar 'old' to be replaced by `Data` does
                           not consist of exactly one main symbol
        """
        if isinstance(old, Symbol):
            old._ensureShapeCompatible(new)
        if isinstance(new, Data):
            subs=self._subs.copy()
            if isinstance(old, Symbol) and old.getRank()>0:
                names=old.atoms(sympy.Symbol)
                if len(names)!=1:
                    raise ValueError("subs: cannot determine the symbol to be replaced by Data")
                old=Symbol(names.pop())
            elif not isinstance(old, Symbol):
                # keep the table keyed by Symbols like everywhere else
                old=Symbol(old)
            subs[old]=new
            return Symbol(self._arr, dim=self._dim, subs=subs)

        subs=self._subs.copy()
        if isinstance(new, Symbol):
            subs.update(new._subs)

        if isinstance(old, Symbol) and old.getRank()>0:
            if hasattr(new, '__array__'):
                new_arr=new.__array__()
            else:
                new_arr=numpy.array(new)
            if new_arr.ndim>0:
                pairs=[(old._arr[i], new_arr[i]) for i in numpy.ndindex(new_arr.shape)]
            else: # substitute scalar for non-scalar
                value=new_arr.item()
                pairs=[(old._arr[i], value) for i in numpy.ndindex(old.getShape())]
            result=numpy.empty(self.getShape(), dtype=object)
            for idx in numpy.ndindex(self.getShape()):
                el=self._arr[idx]
                if isinstance(el, sympy.Basic):
                    # apply every component substitution, not just the last
                    for o, n in pairs:
                        el=el.subs(o, n)
                result[idx]=el
            return Symbol(result, dim=self._dim, subs=subs)

        # scalar
        if isinstance(new, Symbol):
            new=new.item()
        if isinstance(old, Symbol):
            old=old.item()
        result=self.applyfunc(lambda item: item.subs(old, new), sympy.Basic)
        result._subs.update(subs)
        return result

    def diff(self, *symbols, **assumptions):
        """
        Returns the derivative of this Symbol with respect to 'symbols',
        applied one after the other. A symbol may be followed by an integer
        to differentiate repeatedly, e.g. diff(x, 2).
        Differentiating with respect to a rank 1 Symbol appends an axis of
        its length holding the partial derivatives.

        :param symbols: symbols to differentiate with respect to
        :param assumptions: passed on to sympy's diff()
        :rtype: ``Symbol``
        :raise ValueError: for Symbols of rank 2 or higher in 'symbols'
        """
        result=Symbol(self._arr, dim=self._dim, subs=self._subs)
        for s in Symbol._symbolgen(*symbols):
            if isinstance(s, Symbol):
                if s.getRank()==0:
                    var=s._arr.item()
                    # sympify so that constant (plain number) components give 0
                    result=result.applyfunc(
                        lambda item, v=var: sympy.sympify(item).diff(v, **assumptions))
                elif s.getRank()==1:
                    n=s.getShape()[0]
                    # use the current shape: earlier symbols may have added axes
                    shape=result.getShape()
                    out=result._arr.reshape(shape+(1,)).repeat(n, axis=len(shape))
                    for d in range(n):
                        var=s._arr[d]
                        for idx in numpy.ndindex(shape):
                            index=idx+(d,)
                            out[index]=sympy.sympify(out[index]).diff(var, **assumptions)
                    result=Symbol(out, dim=self._dim, subs=self._subs)
                else:
                    raise ValueError("diff: argument must have rank 0 or 1")
            else:
                result=result.applyfunc(
                    lambda item, v=s: sympy.sympify(item).diff(v, **assumptions))
        return result

    def grad(self, where=None):
        """
        Returns the gradient of this Symbol. An axis of length getDim() is
        appended; entry d of each element is grad_n(element, d) (or
        grad_n(element, d, where)) from the sibling 'functions' module.

        :param where: optional scalar Symbol or `FunctionSpace`. A
                      FunctionSpace is recorded in the substitution table
                      under a placeholder named 'fs<id>'.
        :type where: ``Symbol``, ``FunctionSpace``
        :rtype: ``Symbol``
        :raise ValueError: if 'where' is a non-scalar Symbol
        """
        subs=self._subs
        if isinstance(where, Symbol):
            if where.getRank()>0:
                raise ValueError("grad: 'where' must be a scalar symbol")
            where=where._arr.item()
        elif isinstance(where, FunctionSpace):
            name='fs'+str(id(where))
            subs=self._subs.copy()
            subs[Symbol(name)]=where
            where=name

        # explicit relative import (Python 3); fall back to the Python 2 style
        # when this module is not loaded as part of a package
        try:
            from .functions import grad_n
        except (ImportError, ValueError):
            from functions import grad_n
        out=self._arr.copy().reshape(self.getShape()+(1,)).repeat(self._dim,axis=self.getRank())
        for d in range(self._dim):
            for idx in numpy.ndindex(self.getShape()):
                index=idx+(d,)
                if where is None:
                    out[index]=grad_n(out[index],d)
                else:
                    out[index]=grad_n(out[index],d,where)
        return Symbol(out, dim=self._dim, subs=subs)

    def inverse(self):
        """
        Returns the inverse of this square rank 2 Symbol of size 1x1, 2x2 or
        3x3, computed with explicit cofactor formulas.

        :rtype: ``Symbol``
        :raise TypeError: if the rank is not 2 or the size exceeds 3x3
        :raise ValueError: if the Symbol is not square
        :raise ZeroDivisionError: if the determinant is known to be zero
        """
        if self.getRank()!=2:
            raise TypeError("inverse: Only rank 2 supported")
        s=self.getShape()
        if s[0]!=s[1]:
            raise ValueError("inverse: Only square shapes supported")
        # sympify so that plain numeric entries provide .is_zero as well
        arr=numpy.empty(s, dtype=object)
        for idx in numpy.ndindex(s):
            arr[idx]=sympy.sympify(self._arr[idx])
        out=numpy.zeros(s, object)
        if s[0]==1:
            if arr[0,0].is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            out[0,0]=1./arr[0,0]
        elif s[0]==2:
            A11=arr[0,0]
            A12=arr[0,1]
            A21=arr[1,0]
            A22=arr[1,1]
            D = A11*A22-A12*A21
            if D.is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]= A22*D
            out[1,0]=-A21*D
            out[0,1]=-A12*D
            out[1,1]= A11*D
        elif s[0]==3:
            A11=arr[0,0]
            A21=arr[1,0]
            A31=arr[2,0]
            A12=arr[0,1]
            A22=arr[1,1]
            A32=arr[2,1]
            A13=arr[0,2]
            A23=arr[1,2]
            A33=arr[2,2]
            D = A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
            if D.is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]=(A22*A33-A23*A32)*D
            out[1,0]=(A31*A23-A21*A33)*D
            out[2,0]=(A21*A32-A31*A22)*D
            out[0,1]=(A13*A32-A12*A33)*D
            out[1,1]=(A11*A33-A31*A13)*D
            out[2,1]=(A12*A31-A11*A32)*D
            out[0,2]=(A12*A23-A13*A22)*D
            out[1,2]=(A13*A21-A11*A23)*D
            out[2,2]=(A11*A22-A12*A21)*D
        else:
            raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
        return Symbol(out, dim=self._dim, subs=self._subs)

    def swap_axes(self, axis0, axis1):
        """
        Returns a new Symbol with the axes 'axis0' and 'axis1' interchanged.
        """
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self._dim, subs=self._subs)

    @staticmethod
    def _prod(shape):
        """Returns the number of elements described by 'shape' (1 for ())."""
        n=1
        for i in shape:
            n*=i
        return n

    @staticmethod
    def _tensorOperand(other):
        """
        Returns (component array, shape, substitution table) for the second
        operand of the tensor product methods, which may be a `Symbol` or
        anything numpy can convert into an array.
        """
        if isinstance(other, Symbol):
            return other._arr, other.getShape(), other._subs
        arr=numpy.array(other)
        return arr, arr.shape, {}

    def tensorProduct(self, other, axis_offset):
        """
        Returns the product of this Symbol and 'other' contracting the last
        'axis_offset' axes of this Symbol with the first 'axis_offset' axes
        of 'other'.

        :raise ValueError: if the contracted axes do not match in size
        """
        arg1, sh1, subs1=Symbol._tensorOperand(other)
        sh0=self.getShape()
        r0=len(sh0)
        d0=Symbol._prod(sh0[:r0-axis_offset])
        d1=Symbol._prod(sh1[axis_offset:])
        d01=Symbol._prod(sh1[:axis_offset])
        # reshape (unlike resize) refuses mismatching sizes instead of padding
        arg0=self._arr.reshape((d0,d01))
        arg1=arg1.reshape((d01,d1))
        out=numpy.zeros((d0,d1),object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(arg0[i0,:]*arg1[:,i1])
        subs=self._subs.copy()
        subs.update(subs1)
        return Symbol(out.reshape(sh0[:r0-axis_offset]+sh1[axis_offset:]), dim=self._dim, subs=subs)

    def transposedTensorProduct(self, other, axis_offset):
        """
        Returns the product of this Symbol and 'other' contracting the first
        'axis_offset' axes of this Symbol with the first 'axis_offset' axes
        of 'other'.

        :raise ValueError: if the contracted axes do not match in size
        """
        arg1, sh1, subs1=Symbol._tensorOperand(other)
        sh0=self.getShape()
        d0=Symbol._prod(sh0[axis_offset:])
        d1=Symbol._prod(sh1[axis_offset:])
        d01=Symbol._prod(sh1[:axis_offset])
        arg0=self._arr.reshape((d01,d0))
        arg1=arg1.reshape((d01,d1))
        out=numpy.zeros((d0,d1),object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(arg0[:,i0]*arg1[:,i1])
        subs=self._subs.copy()
        subs.update(subs1)
        return Symbol(out.reshape(sh0[axis_offset:]+sh1[axis_offset:]), dim=self._dim, subs=subs)

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Returns the product of this Symbol and 'other' contracting the last
        'axis_offset' axes of this Symbol with the last 'axis_offset' axes of
        'other'.

        :raise ValueError: if the contracted axes do not match in size
        """
        arg1, sh1, subs1=Symbol._tensorOperand(other)
        sh0=self.getShape()
        r0=len(sh0)
        r1=len(sh1)
        d0=Symbol._prod(sh0[:r0-axis_offset])
        d1=Symbol._prod(sh1[:r1-axis_offset])
        d01=Symbol._prod(sh1[r1-axis_offset:])
        arg0=self._arr.reshape((d0,d01))
        arg1=arg1.reshape((d1,d01))
        out=numpy.zeros((d0,d1),object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(arg0[i0,:]*arg1[i1,:])
        subs=self._subs.copy()
        subs.update(subs1)
        return Symbol(out.reshape(sh0[:r0-axis_offset]+sh1[:r1-axis_offset]), dim=self._dim, subs=subs)

    def trace(self, axis_offset):
        """
        Returns the sum over the diagonal formed by the axes 'axis_offset'
        and 'axis_offset'+1, which must have equal length.
        """
        sh=self.getShape()
        n=sh[axis_offset]
        s1=Symbol._prod(sh[:axis_offset])
        s2=Symbol._prod(sh[axis_offset+2:])
        arr_r=numpy.reshape(self._arr,(s1,n,n,s2))
        out=numpy.zeros([s1,s2],object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(n):
                    out[i1,i2]+=arr_r[i1,j,j,i2]
        return Symbol(out.reshape(sh[:axis_offset]+sh[axis_offset+2:]), dim=self._dim, subs=self._subs)

    def transpose(self, axis_offset=None):
        """
        Returns a new Symbol whose axes are cyclically shifted so that the
        first 'axis_offset' axes become the last ones. The default is half
        the rank (rounded down), i.e. the matrix transpose for rank 2.
        """
        rank=self._arr.ndim
        if axis_offset is None:
            axis_offset=rank//2
        axes=list(range(axis_offset, rank))+list(range(0, axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes), dim=self._dim, subs=self._subs)

    def applyfunc(self, f, on_type=None):
        """
        Applies 'f' to every component and returns the results as a new
        Symbol with the same dimensionality and substitutions.

        If 'on_type' is given only components of that type are passed to
        'f'; all others are copied unchanged. For a rank 0 Symbol a result of
        None is returned as None rather than wrapped in a Symbol.

        :param f: function to apply to the components
        :param on_type: optional type (or tuple of types) to restrict 'f' to
        :raise TypeError: if 'f' is not callable
        """
        if not callable(f):
            raise TypeError("applyfunc: argument must be callable")
        if self._arr.ndim==0:
            el=self._arr.item()
            if on_type is None or isinstance(el, on_type):
                el=f(el)
            if el is None:
                return None
            return Symbol(el, dim=self._dim, subs=self._subs)
        out=numpy.empty(self.getShape(), dtype=object)
        for idx in numpy.ndindex(self.getShape()):
            el=self._arr[idx]
            if on_type is None or isinstance(el, on_type):
                out[idx]=f(el)
            else:
                out[idx]=el
        return Symbol(out, dim=self._dim, subs=self._subs)

    def expand(self):
        """
        Returns a new Symbol with sympy.expand applied to all sympy components.
        """
        return self.applyfunc(sympy.expand, sympy.Basic)

    def simplify(self):
        """
        Returns a new Symbol with sympy.simplify applied to all sympy
        components.
        """
        return self.applyfunc(sympy.simplify, sympy.Basic)

    # unary/binary operations follow

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr, dim=self._dim, subs=self._subs)

    def __abs__(self):
        return Symbol(abs(self._arr), dim=self._dim, subs=self._subs)

    def _ensureShapeCompatible(self, other):
        """
        Checks for compatible shapes for binary operations: shapes must be
        equal unless either operand is a scalar.
        Raises TypeError if not compatible or if 'other' has an unsupported
        type.
        """
        sh0=self.getShape()
        if isinstance(other, Symbol) or isinstance(other, Data):
            sh1=other.getShape()
        elif isinstance(other, numpy.ndarray):
            sh1=other.shape
        elif isinstance(other, list):
            sh1=numpy.array(other).shape
        elif isinstance(other, numbers.Number) or isinstance(other, sympy.Basic):
            # numbers.Number also covers numpy scalars and Python 2 longs
            sh1=()
        else:
            raise TypeError("Unsupported argument type '%s' for operation"%other.__class__.__name__)
        if sh0!=sh1 and sh0!=() and sh1!=():
            raise TypeError("Incompatible shapes for operation")

    def __binaryop(self, op, other):
        """
        Applies the numpy array operator 'op' (e.g. '__add__') to the
        components and 'other'. A `Data` operand is represented by a
        placeholder Symbol named 'data<id>' and recorded in the substitution
        table of the result.
        """
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            subs=self._subs.copy()
            subs.update(other._subs)
            return Symbol(getattr(self._arr, op)(other._arr), dim=self._dim, subs=subs)
        if isinstance(other, Data):
            name='data'+str(id(other))
            othersym=Symbol(name, other.getShape(), dim=self._dim)
            subs=self._subs.copy()
            subs[Symbol(name)]=other
            return Symbol(getattr(self._arr, op)(othersym._arr), dim=self._dim, subs=subs)
        return Symbol(getattr(self._arr, op)(other), dim=self._dim, subs=self._subs)

    def __add__(self, other):
        return self.__binaryop('__add__', other)

    def __radd__(self, other):
        return self.__binaryop('__radd__', other)

    def __sub__(self, other):
        return self.__binaryop('__sub__', other)

    def __rsub__(self, other):
        return self.__binaryop('__rsub__', other)

    def __mul__(self, other):
        return self.__binaryop('__mul__', other)

    def __rmul__(self, other):
        return self.__binaryop('__rmul__', other)

    def __div__(self, other):
        return self.__binaryop('__div__', other)

    def __rdiv__(self, other):
        return self.__binaryop('__rdiv__', other)

    # Python 3 (and Python 2 with true division) uses these for '/'
    def __truediv__(self, other):
        return self.__binaryop('__truediv__', other)

    def __rtruediv__(self, other):
        return self.__binaryop('__rtruediv__', other)

    def __pow__(self, other):
        return self.__binaryop('__pow__', other)

    def __rpow__(self, other):
        return self.__binaryop('__rpow__', other)

    @staticmethod
    def _symComp(sym):
        """
        Splits the name of a sympy symbol into its main name and component
        indices, e.g. '[x]_1_2' -> ('x', (1,2)). Names not of the form
        '[name]_i_j...' are returned unchanged together with an empty tuple.
        """
        n=sym.name
        a=n.split('[')
        if len(a)!=2:
            return n,()
        a=a[1].split(']')
        if len(a)!=2:
            return n,()
        name=a[0]
        comps=[int(i) for i in a[1].split('_')[1:]]
        return name,tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        A symbol followed by an integer is repeated that many times; the
        integer itself is not yielded.

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat
        count=len(symbols)
        for i, s in enumerate(symbols):
            if not isinstance(s, Symbol):
                s=sympy.sympify(s)
            if isinstance(s, sympy.Integer):
                # repetition counts are consumed together with their symbol
                continue
            # look ahead by position (not by value) for a repetition count
            next_s=None
            if i+1<count:
                next_s=symbols[i+1]
                if not isinstance(next_s, Symbol):
                    next_s=sympy.sympify(next_s)
            if (isinstance(s, Symbol) or isinstance(s, sympy.Symbol)) and isinstance(next_s, sympy.Integer):
                # handle cases like (x, 3) -> (x, x, x)
                for copy_s in repeat(s, int(next_s)):
                    yield copy_s
            else:
                yield s
784

  ViewVC Help
Powered by ViewVC 1.1.26