/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Contents of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory | Revision Log


Revision 3981 - (show annotations)
Fri Sep 21 02:47:54 2012 UTC (7 years ago) by jfenwick
Original Path: trunk/escript/py_src/symbolic/symbols.py
File MIME type: text/x-python
File size: 28841 byte(s)
First pass of updating copyright notices
1
2 ##############################################################################
3 #
4 # Copyright (c) 2003-2012 by University of Queensland
5 # http://www.uq.edu.au
6 #
7 # Primary Business: Queensland, Australia
8 # Licensed under the Open Software License version 3.0
9 # http://www.opensource.org/licenses/osl-3.0.php
10 #
11 # Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12 # Development since 2012 by School of Earth Sciences
13 #
14 ##############################################################################
15
# Module metadata: copyright notice, licence, project URL and author.
__copyright__="""Copyright (c) 2003-2012 by University of Queensland
http://www.uq.edu.au
Primary Business: Queensland, Australia"""
__license__="""Licensed under the Open Software License version 3.0
http://www.opensource.org/licenses/osl-3.0.php"""
__url__="https://launchpad.net/escript-finley"
__author__="Cihan Altinay"

# NOTE: the string below follows executable statements, so Python evaluates
# and discards it; it is not the module docstring (``__doc__`` is not set by
# it). It also lists ``__version__`` and ``__date__``, which are not assigned
# anywhere in the visible part of this file -- TODO confirm whether they are
# meant to be defined.
"""
:var __author__: name of author
:var __copyright__: copyrights
:var __license__: licence agreement
:var __url__: url entry point on documentation
:var __version__: version
:var __date__: date of the version
"""
32
33 import numpy
34 from esys.escript import Data, FunctionSpace, HAVE_SYMBOLS
35 if HAVE_SYMBOLS:
36 import sympy
37
38
class Symbol(object):
    """
    `Symbol` objects are placeholders for a single mathematical symbol, such as
    'x', or for arbitrarily complex mathematical expressions such as
    'c*x**4+alpha*exp(x)-2*sin(beta*x)', where 'alpha', 'beta', 'c', and 'x'
    are also `Symbol`s (the symbolic 'atoms' of the expression).

    With the help of the 'Evaluator' class these symbols and expressions can
    be resolved by substituting numeric values and/or escript `Data` objects
    for the atoms. To facilitate the use of `Data` objects a `Symbol` has a
    shape (and thus a rank) as well as a dimension (see constructor).
    `Symbol`s are useful to perform mathematical simplifications, compute
    derivatives and as coefficients for nonlinear PDEs which can be solved by
    the `NonlinearPDE` class.

    Internally the components are kept in a numpy array of dtype ``object``
    (``self._arr``), the spatial dimension in ``self._dim`` (-1 if undefined)
    and the pending `Data` substitutions in the dictionary ``self._subs``.
    """

    # these are for compatibility with sympy.Symbol. lambdify checks these.
    is_Add = False
    is_Float = False

    def __init__(self, *args, **kwargs):
        """
        Initialises a new `Symbol` object in one of three ways::

            u=Symbol('u')

        returns a scalar symbol by the name 'u'.

            a=Symbol('alpha', (4,3))

        returns a rank 2 symbol with the shape (4,3), whose elements are
        named '[alpha]_i_j' (with i=0..3, j=0..2).

            a,b,c=symbols('a,b,c')
            x=Symbol([[a+b,0,0],[0,b-c,0],[0,0,c-a]])

        returns a rank 2 symbol with the shape (3,3) whose elements are
        explicitly specified by numeric values and other symbols/expressions
        within a list or numpy array.

        The dimensionality of the symbol can be specified through the `dim`
        keyword. All other keywords are passed to the underlying symbolic
        library (currently sympy).

        :param args: initialisation arguments as described above
        :keyword dim: dimensionality of the new Symbol (default: -1, which
                      means undefined)
        :type dim: ``int``
        :keyword subs: dictionary of data substitutions to attach to the new
                       Symbol; the dictionary is copied
        :type subs: ``dict``
        :raise RuntimeError: if sympy is not available
        :raise ValueError: if the name contains '[' or ']' or the requested
                           rank exceeds 4
        :raise TypeError: if the arguments are of an unsupported type or
                          number
        """
        if not HAVE_SYMBOLS:
            raise RuntimeError("Trying to instantiate a Symbol but sympy not available")

        # -1 marks an undefined dimensionality
        self._dim = kwargs.pop('dim', -1)
        # Work on a private copy: the dictionary is updated below when the
        # argument is itself a Symbol, and callers routinely pass their own
        # `_subs` dictionary which must not be modified as a side effect.
        self._subs = dict(kwargs.pop('subs', {}))

        if len(args) == 1:
            arg = args[0]
            if isinstance(arg, str):
                Symbol._checkName(arg)
                self._arr = numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__") or isinstance(arg, list):
                if isinstance(arg, list):
                    arg = numpy.array(arg)
                arr = arg.__array__()
                if len(arr.shape) > 4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                res = numpy.empty(arr.shape, dtype=object)
                for idx in numpy.ndindex(arr.shape):
                    el = arr[idx]
                    # unwrap numpy scalars into plain Python numbers
                    res[idx] = el.item() if hasattr(el, "item") else el
                self._arr = res
                if isinstance(arg, Symbol):
                    # inherit substitutions and (if not given) the dimension
                    self._subs.update(arg._subs)
                    if self._dim == -1:
                        self._dim = arg._dim
            elif isinstance(arg, sympy.Basic):
                self._arr = numpy.array(arg)
            else:
                raise TypeError("Unsupported argument type %s"%str(type(arg)))
        elif len(args) == 2:
            if not isinstance(args[0], str):
                raise TypeError("First argument must be a string")
            if not isinstance(args[1], tuple):
                raise TypeError("Second argument must be a tuple")
            name, shape = args
            Symbol._checkName(name)
            if len(shape) > 4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape) == 0:
                self._arr = numpy.array(sympy.Symbol(name, **kwargs))
            else:
                # the argument order of symarray differs between sympy
                # versions, so try both
                try:
                    self._arr = sympy.symarray(shape, '['+name+']')
                except TypeError:
                    self._arr = sympy.symarray('['+name+']', shape)
        else:
            raise TypeError("Unsupported number of arguments")

        if self._arr.ndim == 0:
            self.name = str(self._arr.item())
        else:
            self.name = str(self._arr.tolist())

    @staticmethod
    def _checkName(name):
        """
        Raises ValueError if `name` contains a square bracket. Brackets are
        reserved for the component names '[name]_i_j' generated for
        non-scalar symbols.
        """
        if name.find('[') >= 0 or name.find(']') >= 0:
            raise ValueError("Name must not contain '[' or ']'")

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        """
        Two Symbols are equal if they have the same type, rank and shape and
        all their components compare equal.
        """
        if type(self) is not type(other):
            return False
        if self.getRank() != other.getRank():
            return False
        if self.getShape() != other.getShape():
            return False
        return (self._arr == other._arr).all()

    def __ne__(self, other):
        # Python 2 does not derive != from ==, so spell it out.
        return not self.__eq__(other)

    def __hash__(self):
        # Symbols are used as dictionary keys for data substitutions.
        # Python 3 removes the inherited hash from classes that define
        # __eq__, so restore the identity based hash that Python 2 provides
        # implicitly. Note that equal Symbols may therefore hash differently.
        return object.__hash__(self)

    def __getitem__(self, key):
        """
        Returns an element of this symbol which must have rank >0.
        Unlike item() this method converts sympy objects and numpy arrays into
        escript Symbols in order to facilitate expressions that require
        element access, such as: grad(u)[1]+x

        The returned Symbol carries the dimension and the data substitutions
        of this symbol.

        :param key: (nd-)index of item to be returned
        :return: the requested element
        :rtype: ``Symbol``, ``int``, or ``float``
        """
        res = self._arr[key]
        # replace sympy Symbols/expressions by escript Symbols
        if isinstance(res, (sympy.Basic, numpy.ndarray)):
            res = Symbol(res, dim=self._dim, subs=self._subs)
        return res

    def __setitem__(self, key, value):
        """
        Assigns `value` to the component(s) selected by `key`.

        :param key: (nd-)index or slice of the component(s) to set
        :param value: a `Symbol`, a sympy object, an array-like object or
                      anything ``sympy.sympify`` understands
        :raise ValueError: if a non-scalar `Symbol` does not match the shape
                           of the selected components
        """
        if isinstance(value, Symbol):
            if value.getRank() == 0:
                self._arr[key] = value.item()
                return
            target = self._arr[key]
            if not hasattr(target, "shape") or target.shape != value.getShape():
                raise ValueError("Wrong shape of value")
            # shapes agree, so this copies component by component
            self._arr[key] = value._arr
        elif isinstance(value, sympy.Basic):
            self._arr[key] = value
        elif hasattr(value, "__array__"):
            source = numpy.asarray(value)
            if source.ndim == 0:
                # store the element itself rather than a 0-d array object
                self._arr[key] = sympy.sympify(source.item())
            else:
                # convert element by element while keeping the shape
                converted = numpy.empty(source.shape, dtype=object)
                for idx in numpy.ndindex(source.shape):
                    converted[idx] = sympy.sympify(source[idx])
                self._arr[key] = converted
        else:
            self._arr[key] = sympy.sympify(value)

    def __iter__(self):
        """
        Iterates over the first axis of the underlying component array.
        """
        return iter(self._arr)

    def __array__(self, t=None, copy=None):
        """
        Returns the components as a numpy array (numpy array protocol).

        :param t: optional dtype to convert the components to
        :param copy: if true a copy is returned instead of the internal array
        :rtype: ``numpy.ndarray``
        """
        # compare with None explicitly: dtype objects may evaluate to False
        if t is not None:
            return self._arr.astype(t)
        if copy:
            return self._arr.copy()
        return self._arr

    def _sympy_(self):
        """
        Returns a Symbol whose components have all been passed through
        ``sympy.sympify``.
        """
        return self.applyfunc(sympy.sympify)

    def getDim(self):
        """
        Returns the spatial dimensionality of this symbol.

        :return: the symbol's spatial dimensionality, or -1 if undefined
        :rtype: ``int``
        """
        return self._dim

    def getRank(self):
        """
        Returns the rank of this symbol.

        :return: the symbol's rank which is equal to the length of the shape.
        :rtype: ``int``
        """
        return self._arr.ndim

    def getShape(self):
        """
        Returns the shape of this symbol.

        :return: the symbol's shape
        :rtype: ``tuple`` of ``int``
        """
        return self._arr.shape

    def getDataSubstitutions(self):
        """
        Returns a dictionary of symbol names and the escript ``Data`` objects
        they represent within this Symbol. The internal dictionary itself is
        returned, not a copy.

        :return: the dictionary of substituted ``Data`` objects
        :rtype: ``dict``
        """
        return self._subs

    def item(self, *args):
        """
        Returns an element of this symbol.
        This method behaves like the item() method of numpy.ndarray.
        If this is a scalar Symbol, no arguments are allowed and the only
        element in this Symbol is returned.
        Otherwise, 'args' specifies a flat or nd-index and the element at
        that index is returned.

        :param args: index of item to be returned
        :return: the requested element
        :rtype: ``sympy.Symbol``, ``int``, or ``float``
        """
        # forward the arguments unchanged so that item(i,j) and item((i,j))
        # both work as they do for numpy.ndarray.item
        return self._arr.item(*args)

    def atoms(self, *types):
        """
        Returns the atoms that form the current Symbol.

        By default, only objects that are truly atomic and cannot be divided
        into smaller pieces are returned: symbols, numbers, and number
        symbols like I and pi. It is possible to request atoms of any type,
        however.

        Note that if this symbol contains components such as [x]_i_j then
        only their main symbol 'x' is returned.

        :param types: types to restrict result to
        :return: set of atoms of specified type
        :rtype: ``set``
        """
        s = set()
        for el in self._arr.flat:
            if isinstance(el, sympy.Basic):
                for a in el.atoms(*types):
                    if a.is_Symbol:
                        # strip the component suffix: '[x]_0_1' -> 'x'
                        n, c = Symbol._symComp(a)
                        s.add(sympy.Symbol(n))
                    else:
                        s.add(a)
            elif len(types) == 0 or type(el) in types:
                s.add(el)
        return s

    def _sympystr_(self, printer):
        # compatibility with sympy 1.6
        return self._sympystr(printer)

    def _sympystr(self, printer):
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string representation of this Symbol in which component
        names such as '[x]_0_1' are written as 'x[0,1]'. Non-scalar Symbols
        are rendered as a 'combineData(<nested list>,<shape>)' call.

        :rtype: ``str``
        """
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr = numpy.empty(self.getShape(), dtype=object)
        for idx, el in numpy.ndenumerate(self._arr):
            atoms = el.atoms(sympy.Symbol) if isinstance(el, sympy.Basic) else []
            # create a dictionary to convert names like [x]_0_0 to x[0,0]
            symdict = {}
            for a in atoms:
                n, c = Symbol._symComp(a)
                if len(c) > 0:
                    symstr = n+'['+','.join([str(i) for i in c])+']'
                else:
                    symstr = n
                symdict[a.name] = symstr
            s = lambdarepr(el)
            # replace longer names first, otherwise '[x]_0_1' would also
            # match the beginning of '[x]_0_10'
            for key in sorted(symdict, key=len, reverse=True):
                s = s.replace(key, symdict[key])
            temp_arr[idx] = s
        if self.getRank() == 0:
            return temp_arr.item()
        else:
            return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))

    def coeff(self, x, expand=True):
        """
        Returns the coefficient of the term "x" or 0 if there is no "x".

        If "x" is a scalar symbol then "x" is searched in all components of
        this symbol. Otherwise the shapes must match and the coefficients are
        checked component by component.

        Example::

            x=Symbol('x', (2,2))
            y=3*x
            print(y.coeff(x))
            print(y.coeff(x[1,1]))

        will print::

            [[3 3]
             [3 3]]

            [[0 0]
             [0 3]]

        :param x: the term whose coefficients are to be found
        :type x: ``Symbol``, ``numpy.ndarray``, `list`
        :param expand: forwarded to the ``coeff`` method of the components
        :return: the coefficient(s) of the term
        :rtype: ``Symbol``
        :raise TypeError: if the shape of "x" is incompatible
        """
        self._ensureShapeCompatible(x)
        if hasattr(x, '__array__'):
            y = x.__array__()
        else:
            y = numpy.array(x)

        def coeff_of(item, term):
            # components that are not sympy objects cannot contain the term
            if not isinstance(item, sympy.Basic):
                return 0
            # NOTE(review): `expand` is passed positionally as in the
            # original code; confirm that the second positional parameter of
            # coeff() in the sympy version in use is still `expand`.
            res = item.coeff(term, expand)
            return 0 if res is None else res

        if y.ndim > 0:
            result = numpy.zeros(self.getShape(), dtype=object)
            for idx in numpy.ndindex(y.shape):
                if y[idx] != 0:
                    result[idx] = coeff_of(self._arr[idx], y[idx])
        elif y.item() == 0:
            result = numpy.zeros(self.getShape(), dtype=object)
        else:
            term = y.item()
            result = self.applyfunc(lambda item: coeff_of(item, term))
        # hand the substitutions over directly; Symbol.subs() returns a new
        # object and cannot be used to attach them in place
        return Symbol(result, dim=self._dim, subs=self._subs)

    def subs(self, old, new):
        """
        Substitutes an expression and returns the result as a new Symbol;
        this Symbol is left unchanged.

        If `new` is an escript ``Data`` object the substitution is only
        recorded in the dictionary of data substitutions. Otherwise every
        occurrence of `old` in the components is replaced by `new`. If `old`
        is a non-scalar `Symbol`, its components are replaced one after the
        other by the matching components of `new` (or by `new` itself if
        `new` is a scalar).

        :param old: the symbol or expression to be replaced
        :param new: the replacement
        :return: the Symbol after substitution
        :rtype: ``Symbol``
        :raise TypeError: if the shapes of `old` and `new` are incompatible
        :raise ValueError: if `new` is a ``Data`` object and the non-scalar
                           `old` is not made of exactly one main symbol
        """
        # plain sympy objects are accepted as `old` as well; the shape check
        # only applies to escript Symbols
        if isinstance(old, Symbol):
            old._ensureShapeCompatible(new)

        if isinstance(new, Data):
            subs = self._subs.copy()
            if isinstance(old, Symbol) and old.getRank() > 0:
                # atoms() returns a set, so the single main symbol has to be
                # extracted by iteration
                names = old.atoms(sympy.Symbol)
                if len(names) != 1:
                    raise ValueError("subs: cannot identify a unique symbol to substitute")
                old = Symbol(next(iter(names)))
            subs[old] = new
            return Symbol(self._arr, dim=self._dim, subs=subs)

        if isinstance(old, Symbol) and old.getRank() > 0:
            if hasattr(new, '__array__'):
                new = new.__array__()
            else:
                new = numpy.array(new)
            if new.ndim > 0:
                pairs = [(old._arr[i], new[i]) for i in numpy.ndindex(new.shape)]
            else:
                # substitute scalar for non-scalar
                value = new.item()
                pairs = [(old._arr[i], value) for i in numpy.ndindex(old.getShape())]

            result = numpy.empty(self.getShape(), dtype=object)
            for idx in numpy.ndindex(self.getShape()):
                el = self._arr[idx]
                if isinstance(el, sympy.Basic):
                    # chain the replacements so that all of them take effect
                    for o, n in pairs:
                        el = el.subs(o, n)
                result[idx] = el
            return Symbol(result, dim=self._dim, subs=self._subs)

        # scalar
        if isinstance(new, Symbol):
            new = new.item()
        if isinstance(old, Symbol):
            old = old.item()
        # plain numbers have no subs() method and are left untouched
        return self.applyfunc(lambda item: item.subs(old, new), sympy.Basic)

    def diff(self, *symbols, **assumptions):
        """
        Returns the derivative of this Symbol with respect to the given
        symbols. An integer following a symbol is interpreted as the order of
        the derivative, e.g. ``u.diff(x, 2, y)``.

        Differentiating with respect to a scalar keeps the shape, while
        differentiating with respect to a rank 1 `Symbol` of length n appends
        an axis of length n to the result.

        :param symbols: symbols (and orders) to differentiate by
        :param assumptions: keywords passed to the ``diff`` method of the
                            components
        :rtype: ``Symbol``
        :raise ValueError: if a `Symbol` argument has a rank greater than 1
        """
        result = Symbol(self._arr, dim=self._dim, subs=self._subs)
        for s in Symbol._symbolgen(*symbols):
            if isinstance(s, Symbol):
                if s.getRank() == 0:
                    var = s._arr.item()
                elif s.getRank() == 1:
                    # use the shape of the running result, it grows with
                    # every rank 1 differentiation
                    shape = result.getShape()
                    n = s.getShape()[0]
                    out = numpy.empty(shape+(n,), dtype=object)
                    for idx in numpy.ndindex(shape):
                        el = sympy.sympify(result._arr[idx])
                        for d in range(n):
                            out[idx+(d,)] = el.diff(s._arr[d], **assumptions)
                    result = Symbol(out, dim=self._dim, subs=self._subs)
                    continue
                else:
                    raise ValueError("diff: argument must have rank 0 or 1")
            else:
                var = s
            # sympify first so that plain numbers can be differentiated too
            result = result.applyfunc(
                lambda item, var=var: sympy.sympify(item).diff(var, **assumptions))
        return result

    def grad(self, where=None):
        """
        Returns a symbol which represents the gradient of this symbol.
        The result has one more axis of length ``getDim()`` than this symbol.

        :param where: optional scalar symbol or function space passed on to
                      the gradient function
        :type where: ``Symbol``, ``FunctionSpace``
        :rtype: ``Symbol``
        :raise ValueError: if the dimensionality of this symbol is undefined
                           or `where` is a non-scalar symbol
        """
        if self._dim < 0:
            raise ValueError("grad: cannot compute gradient as symbol has undefined dimensionality")
        subs = self._subs
        if isinstance(where, Symbol):
            if where.getRank() > 0:
                raise ValueError("grad: 'where' must be a scalar symbol")
            where = where._arr.item()
        elif isinstance(where, FunctionSpace):
            # register the function space as a substitution under a unique
            # name and refer to it by that name
            name = 'fs'+str(id(where))
            fssym = Symbol(name)
            subs = self._subs.copy()
            subs.update({fssym: where})
            where = name

        # NOTE(review): implicit relative import (Python 2 semantics); this
        # presumably refers to the sibling module of this file -- confirm
        # before porting to Python 3.
        from functions import grad_n
        out = self._arr.copy().reshape(self.getShape()+(1,)).repeat(self._dim, axis=self.getRank())
        for d in range(self._dim):
            for idx in numpy.ndindex(self.getShape()):
                index = idx+(d,)
                if where is None:
                    out[index] = grad_n(out[index], d)
                else:
                    out[index] = grad_n(out[index], d, where)
        return Symbol(out, dim=self._dim, subs=subs)

    def inverse(self):
        """
        Returns the inverse of this Symbol which must be a square rank 2
        Symbol with 1, 2 or 3 rows. The inverse is computed from explicit
        determinant/cofactor formulas.

        :rtype: ``Symbol``
        :raise TypeError: if the rank is not 2 or there are more than 3 rows
        :raise ValueError: if the shape is not square
        :raise ZeroDivisionError: if the determinant is known to be zero
        """
        if not self.getRank() == 2:
            raise TypeError("inverse: Only rank 2 supported")
        s = self.getShape()
        if not s[0] == s[1]:
            raise ValueError("inverse: Only square shapes supported")
        # `numpy.object` is no longer available in recent numpy releases
        out = numpy.zeros(s, dtype=object)
        arr = self._arr
        if s[0] == 1:
            D = arr[0,0]
        elif s[0] == 2:
            A11 = arr[0,0]
            A12 = arr[0,1]
            A21 = arr[1,0]
            A22 = arr[1,1]
            D = A11*A22-A12*A21
        elif s[0] == 3:
            A11 = arr[0,0]
            A21 = arr[1,0]
            A31 = arr[2,0]
            A12 = arr[0,1]
            A22 = arr[1,1]
            A32 = arr[2,1]
            A13 = arr[0,2]
            A23 = arr[1,2]
            A33 = arr[2,2]
            D = A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
        else:
            raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")

        # sympify so that the test also works for plain numbers
        if sympy.sympify(D).is_zero:
            raise ZeroDivisionError("inverse: Symbol not invertible")
        D = 1./D

        if s[0] == 1:
            out[0,0] = D
        elif s[0] == 2:
            out[0,0] = A22*D
            out[1,0] = -A21*D
            out[0,1] = -A12*D
            out[1,1] = A11*D
        else:
            out[0,0] = (A22*A33-A23*A32)*D
            out[1,0] = (A31*A23-A21*A33)*D
            out[2,0] = (A21*A32-A31*A22)*D
            out[0,1] = (A13*A32-A12*A33)*D
            out[1,1] = (A11*A33-A31*A13)*D
            out[2,1] = (A12*A31-A11*A32)*D
            out[0,2] = (A12*A23-A13*A22)*D
            out[1,2] = (A13*A21-A11*A23)*D
            out[2,2] = (A11*A22-A12*A21)*D
        return Symbol(out, dim=self._dim, subs=self._subs)

    def swap_axes(self, axis0, axis1):
        """
        Returns a Symbol in which the axes `axis0` and `axis1` are
        interchanged.
        """
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self._dim, subs=self._subs)

    @staticmethod
    def _prod(extents):
        """
        Returns the product of the integers in `extents` (1 if empty).
        """
        result = 1
        for n in extents:
            result *= n
        return result

    @staticmethod
    def _matrixProduct(a, b):
        """
        Returns the matrix product of the 2D object arrays `a` (d0 x d01)
        and `b` (d01 x d1) as a 2D object array.
        """
        d0, d1 = a.shape[0], b.shape[1]
        out = numpy.zeros((d0, d1), dtype=object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1] = numpy.sum(a[i0,:]*b[:,i1])
        return out

    def _productOperand(self, other):
        """
        Prepares the second operand of a tensor product.

        :return: a tuple of the components of `other` as a numpy array, the
                 dimension and the merged data substitutions for the result
        """
        subs = self._subs.copy()
        if isinstance(other, Symbol):
            subs.update(other._subs)
            dim = other._dim if self._dim < 0 else self._dim
            return other._arr.copy(), dim, subs
        # numpy arrays (and lists) carry neither dimension nor substitutions
        return numpy.array(other), self._dim, subs

    def tensorProduct(self, other, axis_offset):
        """
        Returns the tensor product of this Symbol and `other`, where the last
        `axis_offset` axes of this Symbol are contracted with the first
        `axis_offset` axes of `other`.

        :param other: the second operand
        :type other: ``Symbol`` or ``numpy.ndarray``
        :param axis_offset: number of axes to contract
        :type axis_offset: ``int``
        :rtype: ``Symbol``
        :raise ValueError: if the extents of the contracted axes do not match
        """
        arg1, dim, subs = self._productOperand(other)
        sh0, sh1 = self.getShape(), arg1.shape
        keep0 = sh0[:self._arr.ndim-axis_offset]
        keep1 = sh1[axis_offset:]
        d01 = Symbol._prod(sh1[:axis_offset])
        # reshape (unlike resize) fails if the sizes are inconsistent
        a = self._arr.reshape((Symbol._prod(keep0), d01))
        b = arg1.reshape((d01, Symbol._prod(keep1)))
        out = Symbol._matrixProduct(a, b).reshape(keep0+keep1)
        return Symbol(out, dim=dim, subs=subs)

    def transposedTensorProduct(self, other, axis_offset):
        """
        Returns the tensor product of the transpose of this Symbol and
        `other`, i.e. the first `axis_offset` axes of this Symbol are
        contracted with the first `axis_offset` axes of `other`.

        :param other: the second operand
        :type other: ``Symbol`` or ``numpy.ndarray``
        :param axis_offset: number of axes to contract
        :type axis_offset: ``int``
        :rtype: ``Symbol``
        :raise ValueError: if the extents of the contracted axes do not match
        """
        arg1, dim, subs = self._productOperand(other)
        sh0, sh1 = self.getShape(), arg1.shape
        keep0 = sh0[axis_offset:]
        keep1 = sh1[axis_offset:]
        d01 = Symbol._prod(sh1[:axis_offset])
        a = self._arr.reshape((d01, Symbol._prod(keep0))).T
        b = arg1.reshape((d01, Symbol._prod(keep1)))
        out = Symbol._matrixProduct(a, b).reshape(keep0+keep1)
        return Symbol(out, dim=dim, subs=subs)

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Returns the tensor product of this Symbol and the transpose of
        `other`, i.e. the last `axis_offset` axes of this Symbol are
        contracted with the last `axis_offset` axes of `other`.

        :param other: the second operand
        :type other: ``Symbol`` or ``numpy.ndarray``
        :param axis_offset: number of axes to contract
        :type axis_offset: ``int``
        :rtype: ``Symbol``
        :raise ValueError: if the extents of the contracted axes do not match
        """
        arg1, dim, subs = self._productOperand(other)
        sh0, sh1 = self.getShape(), arg1.shape
        r1 = arg1.ndim
        keep0 = sh0[:self._arr.ndim-axis_offset]
        keep1 = sh1[:r1-axis_offset]
        d01 = Symbol._prod(sh1[r1-axis_offset:])
        a = self._arr.reshape((Symbol._prod(keep0), d01))
        b = arg1.reshape((Symbol._prod(keep1), d01)).T
        out = Symbol._matrixProduct(a, b).reshape(keep0+keep1)
        return Symbol(out, dim=dim, subs=subs)

    def trace(self, axis_offset):
        """
        Returns the trace of this Symbol, i.e. the sum over the diagonal of
        the axes `axis_offset` and `axis_offset`+1. These two axes are
        removed from the shape of the result.

        :param axis_offset: the first of the two axes to sum over
        :type axis_offset: ``int``
        :rtype: ``Symbol``
        """
        sh = self.getShape()
        lead, tail = sh[:axis_offset], sh[axis_offset+2:]
        n = sh[axis_offset]
        s1 = Symbol._prod(lead)
        s2 = Symbol._prod(tail)
        arr_r = numpy.reshape(self._arr, (s1, n, n, s2))
        out = numpy.zeros((s1, s2), dtype=object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(n):
                    out[i1,i2] += arr_r[i1,j,j,i2]
        return Symbol(out.reshape(lead+tail), dim=self._dim, subs=self._subs)

    def transpose(self, axis_offset=None):
        """
        Returns the transpose of this Symbol, obtained by moving the first
        `axis_offset` axes to the end.

        :param axis_offset: number of leading axes to move; if ``None`` half
                            of the rank (rounded down) is used
        :type axis_offset: ``int``
        :rtype: ``Symbol``
        """
        if axis_offset is None:
            axis_offset = self._arr.ndim//2
        # build lists explicitly, range objects cannot be concatenated
        axes = list(range(axis_offset, self._arr.ndim))+list(range(0, axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes), dim=self._dim, subs=self._subs)

    def applyfunc(self, f, on_type=None):
        """
        Applies the function `f` to all elements (if on_type is None) or to
        all elements of type `on_type`.

        :param f: the function to apply to the elements
        :param on_type: optional type (or tuple of types) to restrict the
                        application to; other elements are copied unchanged
        :return: a new Symbol holding the results. For a scalar Symbol
                 ``None`` is returned if `f` returned ``None``.
        """
        assert callable(f)
        if self._arr.ndim == 0:
            el = self._arr.item()
            if on_type is None or isinstance(el, on_type):
                el = f(el)
            if el is None:
                return el
            return Symbol(el, dim=self._dim, subs=self._subs)

        out = numpy.empty(self.getShape(), dtype=object)
        for idx in numpy.ndindex(self.getShape()):
            el = self._arr[idx]
            if on_type is None or isinstance(el, on_type):
                out[idx] = f(el)
            else:
                out[idx] = el
        return Symbol(out, dim=self._dim, subs=self._subs)

    def expand(self):
        """
        Applies the sympy.expand operation on all elements in this symbol
        """
        return self.applyfunc(sympy.expand, sympy.Basic)

    def simplify(self):
        """
        Applies the sympy.simplify operation on all elements in this symbol
        """
        return self.applyfunc(sympy.simplify, sympy.Basic)

    # unary/binary operators follow

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr, dim=self._dim, subs=self._subs)

    def __abs__(self):
        return Symbol(abs(self._arr), dim=self._dim, subs=self._subs)

    def _ensureShapeCompatible(self, other):
        """
        Checks for compatible shapes for binary operations.
        Shapes are compatible if they are equal or if one of them is scalar.
        Raises TypeError if not compatible or if the type of `other` is not
        supported.
        """
        sh0 = self.getShape()
        if isinstance(other, (Symbol, Data)):
            sh1 = other.getShape()
        elif isinstance(other, numpy.ndarray):
            sh1 = other.shape
        elif isinstance(other, list):
            sh1 = numpy.array(other).shape
        elif isinstance(other, (int, float, complex, numpy.generic, sympy.Basic)):
            # Python numbers, numpy scalars and sympy objects are scalars
            sh1 = ()
        else:
            raise TypeError("Unsupported argument type '%s' for operation"%other.__class__.__name__)
        if not sh0 == sh1 and not sh0 == () and not sh1 == ():
            raise TypeError("Incompatible shapes for operation")

    def __binaryop(self, op, other):
        """
        Helper for binary operations that checks types, shapes etc.

        :param op: name of the numpy array method that implements the
                   operation, e.g. '__add__'
        :param other: the second operand
        :return: the result of the operation as a new Symbol which carries
                 the data substitutions of both operands
        """
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            subs = self._subs.copy()
            subs.update(other._subs)
            dim = other._dim if self._dim < 0 else self._dim
            return Symbol(getattr(self._arr, op)(other._arr), dim=dim, subs=subs)
        if isinstance(other, Data):
            # represent the Data object by a symbol with a unique name and
            # remember the object as a substitution for that name
            name = 'data'+str(id(other))
            othersym = Symbol(name, other.getShape(), dim=self._dim)
            subs = self._subs.copy()
            subs.update({Symbol(name): other})
            return Symbol(getattr(self._arr, op)(othersym._arr), dim=self._dim, subs=subs)
        return Symbol(getattr(self._arr, op)(other), dim=self._dim, subs=self._subs)

    def __add__(self, other):
        return self.__binaryop('__add__', other)

    def __radd__(self, other):
        return self.__binaryop('__radd__', other)

    def __sub__(self, other):
        return self.__binaryop('__sub__', other)

    def __rsub__(self, other):
        return self.__binaryop('__rsub__', other)

    def __mul__(self, other):
        return self.__binaryop('__mul__', other)

    def __rmul__(self, other):
        return self.__binaryop('__rmul__', other)

    def __div__(self, other):
        return self.__binaryop('__div__', other)

    def __rdiv__(self, other):
        return self.__binaryop('__rdiv__', other)

    # true division is what the '/' operator maps to under Python 3 and
    # under 'from __future__ import division'

    def __truediv__(self, other):
        return self.__binaryop('__truediv__', other)

    def __rtruediv__(self, other):
        return self.__binaryop('__rtruediv__', other)

    def __pow__(self, other):
        return self.__binaryop('__pow__', other)

    def __rpow__(self, other):
        return self.__binaryop('__rpow__', other)

    # static methods

    @staticmethod
    def _symComp(sym):
        """
        Splits the name of a component symbol into the main name and the
        component index, e.g. '[x]_0_1' gives ('x', (0,1)). Names that do
        not follow this pattern are returned unchanged with an empty index.

        :param sym: an object with a ``name`` attribute
        :return: the main name and the tuple of component indices
        :rtype: ``tuple``
        """
        n = sym.name
        a = n.split('[')
        if len(a) != 2:
            return n, ()
        a = a[1].split(']')
        if len(a) != 2:
            return n, ()
        name = a[0]
        comps = [int(i) for i in a[1].split('_')[1:]]
        return name, tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        An integer that follows a symbol is the number of times that symbol
        is yielded; the integer itself is never yielded.

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat
        prepared = [s if isinstance(s, Symbol) else sympy.sympify(s) for s in symbols]
        for i, s in enumerate(prepared):
            if isinstance(s, sympy.Integer):
                # counts are consumed together with the preceding symbol
                continue
            # look ahead by position; comparing values would mistake an
            # earlier occurrence of the last symbol for the last argument
            next_s = prepared[i+1] if i+1 < len(prepared) else None
            if isinstance(s, (Symbol, sympy.Symbol)) and isinstance(next_s, sympy.Integer):
                # handle cases like (x, 3): yield (x, x, x)
                for copy_s in repeat(s, int(next_s)):
                    yield copy_s
            else:
                yield s
816

  ViewVC Help
Powered by ViewVC 1.1.26