/[escript]/trunk/escript/py_src/symbolic/symbols.py
ViewVC logotype

Contents of /trunk/escript/py_src/symbolic/symbols.py

Parent Directory | Revision Log


Revision 3975 - (show annotations)
Thu Sep 20 01:54:06 2012 UTC (7 years, 1 month ago) by caltinay
File MIME type: text/x-python
File size: 28625 byte(s)
Merged symbolic branch into trunk. Curious what daniel and spartacus have to
say...

1
2 ########################################################
3 #
4 # Copyright (c) 2003-2012 by University of Queensland
5 # Earth Systems Science Computational Center (ESSCC)
6 # http://www.uq.edu.au/esscc
7 #
8 # Primary Business: Queensland, Australia
9 # Licensed under the Open Software License version 3.0
10 # http://www.opensource.org/licenses/osl-3.0.php
11 #
12 ########################################################
13
__copyright__ = """Copyright (c) 2003-2012 by University of Queensland
Earth Systems Science Computational Center (ESSCC)
http://www.uq.edu.au/esscc
Primary Business: Queensland, Australia"""
__license__ = """Licensed under the Open Software License version 3.0
http://www.opensource.org/licenses/osl-3.0.php"""
__url__ = 'https://launchpad.net/escript-finley'
__author__ = 'Cihan Altinay'

"""
Module metadata (epydoc-style field list):

:var __author__: name of author
:var __copyright__: copyright notice
:var __license__: licence agreement
:var __url__: url entry point on documentation
"""
31
32 import numpy
33 import sympy
34 from esys.escript import Data, FunctionSpace
35
36
class Symbol(object):
    """
    `Symbol` objects are placeholders for a single mathematical symbol, such as
    'x', or for arbitrarily complex mathematical expressions such as
    'c*x**4+alpha*exp(x)-2*sin(beta*x)', where 'alpha', 'beta', 'c', and 'x'
    are also `Symbol`s (the symbolic 'atoms' of the expression).

    With the help of the 'Evaluator' class these symbols and expressions can
    be resolved by substituting numeric values and/or escript `Data` objects
    for the atoms. To facilitate the use of `Data` objects a `Symbol` has a
    shape (and thus a rank) as well as a dimension (see constructor).
    `Symbol`s are useful to perform mathematical simplifications, compute
    derivatives and as coefficients for nonlinear PDEs which can be solved by
    the `NonlinearPDE` class.

    Internally the components are kept in a numpy array of dtype ``object``
    (``_arr``), the spatial dimensionality in ``_dim`` (-1 if undefined) and
    a dictionary mapping placeholder symbols to the escript objects they
    stand for in ``_subs`` (see `getDataSubstitutions`).
    """

    # these are for compatibility with sympy.Symbol. lambdify checks these.
    is_Add=False
    is_Float=False

    def __init__(self, *args, **kwargs):
        """
        Initialises a new `Symbol` object in one of three ways::

            u=Symbol('u')

        returns a scalar symbol by the name 'u'.

            a=Symbol('alpha', (4,3))

        returns a rank 2 symbol with the shape (4,3), whose elements are
        named '[alpha]_i_j' (with i=0..3, j=0..2).

            a,b,c=symbols('a,b,c')
            x=Symbol([[a+b,0,0],[0,b-c,0],[0,0,c-a]])

        returns a rank 2 symbol with the shape (3,3) whose elements are
        explicitly specified by numeric values and other symbols/expressions
        within a list or numpy array.

        The dimensionality of the symbol can be specified through the `dim`
        keyword. All other keywords are passed to the underlying symbolic
        library (currently sympy).

        :param args: initialisation arguments as described above
        :keyword dim: dimensionality of the new Symbol (default: -1, i.e.
                      undefined; when initialising from another `Symbol` an
                      undefined dimensionality is taken from that Symbol)
        :type dim: ``int``
        :keyword subs: substitution dictionary carried by the new Symbol
                       (default: empty). When initialising from another
                       `Symbol` its substitutions are merged in; the
                       dictionary passed here is not modified.
        :type subs: ``dict``
        :raise ValueError: if a name contains '[' or ']' or the requested
                           rank exceeds 4
        :raise TypeError: on unsupported argument types or counts
        """
        self._dim=kwargs.pop('dim', -1)   # -1 means undefined
        self._subs=kwargs.pop('subs', {})

        if len(args)==1:
            arg=args[0]
            if isinstance(arg, str):
                if '[' in arg or ']' in arg:
                    raise ValueError("Name must not contain '[' or ']'")
                self._arr=numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__") or isinstance(arg, list):
                if isinstance(arg, list):
                    arg=numpy.array(arg)
                arr=arg.__array__()
                if len(arr.shape)>4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                res=numpy.empty(arr.shape, dtype=object)
                for idx in numpy.ndindex(arr.shape):
                    el=arr[idx]
                    # unwrap numpy scalars into plain Python objects
                    res[idx]=el.item() if hasattr(el, "item") else el
                self._arr=res
                if isinstance(arg, Symbol):
                    # merge into a new dictionary so that a 'subs' dictionary
                    # supplied by the caller is not modified in place
                    merged=dict(self._subs)
                    merged.update(arg._subs)
                    self._subs=merged
                    if self._dim==-1:
                        self._dim=arg._dim
            elif isinstance(arg, sympy.Basic):
                self._arr=numpy.array(arg)
            else:
                raise TypeError("Unsupported argument type %s"%str(type(arg)))
        elif len(args)==2:
            name, shape=args
            if not isinstance(name, str):
                raise TypeError("First argument must be a string")
            if not isinstance(shape, tuple):
                raise TypeError("Second argument must be a tuple")
            if '[' in name or ']' in name:
                raise ValueError("Name must not contain '[' or ']'")
            if len(shape)>4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape)==0:
                self._arr=numpy.array(sympy.Symbol(name, **kwargs))
            else:
                # depending on the sympy version symarray expects either
                # (shape, prefix) or (prefix, shape)
                try:
                    self._arr=sympy.symarray(shape, '['+name+']')
                except TypeError:
                    self._arr=sympy.symarray('['+name+']', shape)
        else:
            raise TypeError("Unsupported number of arguments")

        if self._arr.ndim==0:
            self.name=str(self._arr.item())
        else:
            self.name=str(self._arr.tolist())

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        """
        Two Symbols are equal if they are of the same type, have the same
        shape and all their components compare equal.
        """
        if type(self) is not type(other):
            return False
        if self.getShape()!=other.getShape():
            return False
        return bool(numpy.all(self._arr==other._arr))

    def __ne__(self, other):
        # Python 2 does not derive '!=' from __eq__
        return not self.__eq__(other)

    # Identity-based hashing, as Python 2 implicitly keeps when __eq__ is
    # overridden. Python 3 would otherwise set __hash__ to None, making
    # Symbol unusable as a key of the substitution dictionaries built below.
    __hash__=object.__hash__

    def __getitem__(self, key):
        """
        Returns an element of this symbol which must have rank >0.
        Unlike item() this method converts sympy objects and numpy arrays into
        escript Symbols in order to facilitate expressions that require
        element access, such as: grad(u)[1]+x

        :param key: (nd-)index of item to be returned
        :return: the requested element
        :rtype: ``Symbol``, ``int``, or ``float``
        """
        res=self._arr[key]
        # replace sympy Symbols/expressions and sub-arrays by escript Symbols
        if isinstance(res, (sympy.Basic, numpy.ndarray)):
            return Symbol(res)
        return res

    def __setitem__(self, key, value):
        """
        Assigns `value` to the element(s) selected by `key`.

        :param key: (nd-)index or slice of the element(s) to set
        :param value: a `Symbol` (scalar, or of the same shape as the
                      selection), a sympy object, an array-like whose
                      flattened elements are sympified, or any value accepted
                      by ``sympy.sympify``
        :raise ValueError: if a non-scalar `Symbol` does not match the shape
                           of the selection
        """
        if isinstance(value, Symbol):
            if value.getRank()==0:
                self._arr[key]=value.item()
            elif getattr(self._arr[key], "shape", None)==value.getShape():
                # copy the components directly; this also works for plain
                # numeric components, which have no .item() once unwrapped
                self._arr[key]=value._arr
            else:
                raise ValueError("Wrong shape of value")
        elif isinstance(value, sympy.Basic):
            self._arr[key]=value
        elif hasattr(value, "__array__"):
            # list() is required on Python 3 where map() is lazy
            self._arr[key]=list(map(sympy.sympify, value.flat))
        else:
            self._arr[key]=sympy.sympify(value)

    def __iter__(self):
        """
        Returns an iterator over the first axis of the component array.
        """
        return iter(self._arr)

    def __array__(self, t=None):
        if t:
            return self._arr.astype(t)
        else:
            return self._arr

    def _sympy_(self):
        """
        Conversion hook used by ``sympy.sympify``: returns a copy of this
        Symbol with ``sympy.sympify`` applied to every component.
        """
        return self.applyfunc(sympy.sympify)

    def getDim(self):
        """
        Returns the spatial dimensionality of this symbol.

        :return: the symbol's spatial dimensionality, or -1 if undefined
        :rtype: ``int``
        """
        return self._dim

    def getRank(self):
        """
        Returns the rank of this symbol.

        :return: the symbol's rank which is equal to the length of the shape.
        :rtype: ``int``
        """
        return self._arr.ndim

    def getShape(self):
        """
        Returns the shape of this symbol.

        :return: the symbol's shape
        :rtype: ``tuple`` of ``int``
        """
        return self._arr.shape

    def getDataSubstitutions(self):
        """
        Returns a dictionary of symbol names and the escript ``Data`` objects
        they represent within this Symbol.

        :return: the dictionary of substituted ``Data`` objects
        :rtype: ``dict``
        """
        return self._subs

    def item(self, *args):
        """
        Returns an element of this symbol.
        This method behaves like the item() method of numpy.ndarray.
        If this is a scalar Symbol, no arguments are allowed and the only
        element in this Symbol is returned.
        Otherwise, 'args' specifies a flat or nd-index and the element at
        that index is returned.

        :param args: index of item to be returned
        :return: the requested element
        :rtype: ``sympy.Symbol``, ``int``, or ``float``
        """
        # unpack so that item(3) is a flat index and item(1,2) an nd-index,
        # exactly as with numpy.ndarray.item
        return self._arr.item(*args)

    def atoms(self, *types):
        """
        Returns the atoms that form the current Symbol.

        By default, only objects that are truly atomic and cannot be divided
        into smaller pieces are returned: symbols, numbers, and number
        symbols like I and pi. It is possible to request atoms of any type,
        however.

        Note that if this symbol contains components such as [x]_i_j then
        only their main symbol 'x' is returned.

        :param types: types to restrict result to
        :return: list of atoms of specified type
        :rtype: ``set``
        """
        s=set()
        for el in self._arr.flat:
            if isinstance(el, sympy.Basic):
                for a in el.atoms(*types):
                    if a.is_Symbol:
                        n,c=Symbol._symComp(a)
                        s.add(sympy.Symbol(n))
                    else:
                        s.add(a)
            elif len(types)==0 or type(el) in types:
                s.add(el)
        return s

    def _sympystr_(self, printer):
        # compatibility with sympy 1.6
        return self._sympystr(printer)

    def _sympystr(self, printer):
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string representation of this Symbol for use with lambdify.

        Component symbols named '[x]_i_j' are written as 'x[i,j]'. For a
        non-scalar Symbol the component strings are wrapped in a
        'combineData(<nested list>,<shape>)' call.

        :rtype: ``str``
        """
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr=numpy.empty(self.getShape(), dtype=object)
        for idx,el in numpy.ndenumerate(self._arr):
            atoms=el.atoms(sympy.Symbol) if isinstance(el,sympy.Basic) else []
            # create a dictionary to convert names like [x]_0_0 to x[0,0]
            symdict={}
            for a in atoms:
                n,c=Symbol._symComp(a)
                if len(c)>0:
                    symdict[a.name]=n+'['+','.join(str(i) for i in c)+']'
                else:
                    symdict[a.name]=n
            s=lambdarepr(el)
            # replace longer names first so that e.g. '[x]_1_1' cannot
            # clobber part of '[x]_1_10'
            for key in sorted(symdict, key=len, reverse=True):
                s=s.replace(key, symdict[key])
            temp_arr[idx]=s
        if self.getRank()==0:
            return temp_arr.item()
        else:
            return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))

    def coeff(self, x, expand=True):
        """
        Returns the coefficient of the term "x" or 0 if there is no "x".

        If "x" is a scalar symbol then "x" is searched in all components of
        this symbol. Otherwise the shapes must match and the coefficients are
        checked component by component.

        Example::

            x=Symbol('x', (2,2))
            y=3*x
            print(y.coeff(x))
            print(y.coeff(x[1,1]))

        will print::

            [[3 3]
             [3 3]]

            [[0 0]
             [0 3]]

        :param x: the term whose coefficients are to be found
        :type x: ``Symbol``, ``numpy.ndarray``, `list`
        :param expand: whether to expand each component before extracting
                       the coefficient
        :type expand: ``bool``
        :return: the coefficient(s) of the term; the substitutions of this
                 Symbol are carried over
        :rtype: ``Symbol``
        """
        self._ensureShapeCompatible(x)
        if hasattr(x, '__array__'):
            y=x.__array__()
        else:
            y=numpy.array(x)

        def coeff_of(el, term):
            # Plain numbers contain no term. Expansion is done explicitly
            # because the second positional parameter of sympy's Expr.coeff
            # is the power 'n', not an expand flag.
            if not isinstance(el, sympy.Basic):
                return 0
            if expand:
                el=sympy.expand(el)
            res=el.coeff(term)
            return 0 if res is None else res  # old sympy returned None

        if y.ndim>0:
            result=numpy.zeros(self.getShape(), dtype=object)
            for idx in numpy.ndindex(y.shape):
                if y[idx]!=0:
                    result[idx]=coeff_of(self._arr[idx], y[idx])
        elif y.item()==0:
            result=numpy.zeros(self.getShape(), dtype=object)
        else:
            term=y.item()
            result=numpy.empty(self.getShape(), dtype=object)
            for idx in numpy.ndindex(self.getShape()):
                result[idx]=coeff_of(self._arr[idx], term)
        return Symbol(result, dim=self._dim, subs=self._subs)

    def subs(self, old, new):
        """
        Substitutes an expression.

        If `new` is an escript ``Data`` object, the elements are left
        unchanged and `old` (for a non-scalar `Symbol` its main symbol) is
        recorded as standing for `new` in the substitution dictionary of
        the result. Otherwise every occurrence of `old` is replaced by
        `new`. For a non-scalar `old`, each of its components is replaced by
        the matching component of `new` (or by `new` itself if it is
        scalar).

        :param old: the expression to replace
        :type old: ``Symbol`` or sympy object
        :param new: the replacement
        :return: a new Symbol with the substitution applied
        :rtype: ``Symbol``
        :raise TypeError: if `old` is a `Symbol` whose shape is incompatible
                          with `new`
        :raise ValueError: if `new` is ``Data`` and the non-scalar `old`
                           contains no symbols
        """
        if isinstance(old, Symbol):
            old._ensureShapeCompatible(new)

        if isinstance(new, Data):
            subs=self._subs.copy()
            if isinstance(old, Symbol) and old.getRank()>0:
                atoms=sorted(old.atoms(sympy.Symbol), key=str)
                if not atoms:
                    raise ValueError("subs: 'old' does not contain any symbols")
                old=Symbol(atoms[0])
            subs[old]=new
            return Symbol(self._arr, dim=self._dim, subs=subs)

        if isinstance(old, Symbol) and old.getRank()>0:
            if hasattr(new, '__array__'):
                new=new.__array__()
            else:
                new=numpy.array(new)
            if new.ndim>0:
                pairs=[(old._arr[i], new[i]) for i in numpy.ndindex(new.shape)]
            else: # substitute scalar for non-scalar
                value=new.item()
                pairs=[(old._arr[i], value) for i in numpy.ndindex(old.getShape())]
            # apply ALL component substitutions to each element (previously
            # every substitution restarted from the original element)
            result=numpy.empty(self.getShape(), dtype=object)
            for idx in numpy.ndindex(self.getShape()):
                el=self._arr[idx]
                result[idx]=el.subs(pairs) if isinstance(el, sympy.Basic) else el
            return Symbol(result, dim=self._dim, subs=self._subs)

        # scalar
        if isinstance(new, Symbol):
            new=new.item()
        if isinstance(old, Symbol):
            old=old.item()
        # plain numeric components contain nothing to substitute
        return self.applyfunc(lambda item: item.subs(old, new), sympy.Basic)

    def diff(self, *symbols, **assumptions):
        """
        Returns the derivative of this Symbol with respect to `symbols`.

        As in sympy, a symbol may be followed by an integer giving the order
        of differentiation, e.g. ``u.diff(x, 2, y)``. Differentiating with
        respect to a rank 1 `Symbol` differentiates every component by each
        of its elements and appends a new trailing axis to the result.
        Plain numeric components have derivative 0.

        :param symbols: symbols (scalar or rank 1) and optional orders
        :param assumptions: passed on to sympy's ``diff``
        :return: the derivative
        :rtype: ``Symbol``
        :raise ValueError: if a differentiation variable has rank > 1
        """
        result=Symbol(self._arr, dim=self._dim, subs=self._subs)
        for s in Symbol._symbolgen(*symbols):
            if isinstance(s, Symbol):
                if s.getRank()==0:
                    var=s._arr.item()
                    result=result.applyfunc(
                            lambda item: sympy.sympify(item).diff(var, **assumptions))
                elif s.getRank()==1:
                    # use the shape of the current result, which already
                    # carries axes added by earlier vector derivatives
                    shape=result.getShape()
                    n=s.getShape()[0]
                    out=numpy.empty(shape+(n,), dtype=object)
                    for idx in numpy.ndindex(shape):
                        el=sympy.sympify(result._arr[idx])
                        for d in range(n):
                            out[idx+(d,)]=el.diff(s._arr[d], **assumptions)
                    result=Symbol(out, dim=self._dim, subs=self._subs)
                else:
                    raise ValueError("diff: argument must have rank 0 or 1")
            else:
                result=result.applyfunc(
                        lambda item: sympy.sympify(item).diff(s, **assumptions))
        return result

    def grad(self, where=None):
        """
        Returns a symbol which represents the gradient of this symbol.

        The result has an additional trailing axis of length `getDim()`
        whose entries are built with ``grad_n`` from the sibling
        ``functions`` module.

        :param where: optional location of the gradient. A `FunctionSpace`
                      is replaced by a placeholder name which is recorded in
                      the substitutions of the result.
        :type where: ``Symbol``, ``FunctionSpace``
        :rtype: ``Symbol``
        :raise ValueError: if the dimensionality is undefined or `where` is a
                           non-scalar `Symbol`
        """
        if self._dim < 0:
            raise ValueError("grad: cannot compute gradient as symbol has undefined dimensionality")
        subs=self._subs
        if isinstance(where, Symbol):
            if where.getRank()>0:
                raise ValueError("grad: 'where' must be a scalar symbol")
            where=where._arr.item()
        elif isinstance(where, FunctionSpace):
            name='fs'+str(id(where))
            subs=self._subs.copy()
            subs[Symbol(name)]=where
            where=name

        try:
            from .functions import grad_n
        except (ImportError, ValueError):
            # module not loaded as part of a package (original implicit
            # relative import)
            from functions import grad_n
        out=numpy.empty(self.getShape()+(self._dim,), dtype=object)
        for idx in numpy.ndindex(self.getShape()):
            el=self._arr[idx]
            for d in range(self._dim):
                if where is None:
                    out[idx+(d,)]=grad_n(el,d)
                else:
                    out[idx+(d,)]=grad_n(el,d,where)
        return Symbol(out, dim=self._dim, subs=subs)

    def inverse(self):
        """
        Returns the inverse of this Symbol, which must be a square matrix of
        size 1x1, 2x2 or 3x3. The inverse is computed from the explicit
        cofactor formulas.

        :rtype: ``Symbol``
        :raise TypeError: if the rank is not 2 or the matrix is larger than 3x3
        :raise ValueError: if the matrix is not square
        :raise ZeroDivisionError: if sympy determines the determinant to be zero
        """
        if not self.getRank()==2:
            raise TypeError("inverse: Only rank 2 supported")
        s=self.getShape()
        if not s[0] == s[1]:
            raise ValueError("inverse: Only square shapes supported")
        out=numpy.zeros(s, dtype=object)
        arr=self._arr
        # determinants are sympified so that plain numeric entries also
        # provide the .is_zero attribute
        if s[0]==1:
            A11=sympy.sympify(arr[0,0])
            if A11.is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            out[0,0]=1./A11
        elif s[0]==2:
            A11=arr[0,0]
            A12=arr[0,1]
            A21=arr[1,0]
            A22=arr[1,1]
            D=sympy.sympify(A11*A22-A12*A21)
            if D.is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]= A22*D
            out[1,0]=-A21*D
            out[0,1]=-A12*D
            out[1,1]= A11*D
        elif s[0]==3:
            A11=arr[0,0]
            A21=arr[1,0]
            A31=arr[2,0]
            A12=arr[0,1]
            A22=arr[1,1]
            A32=arr[2,1]
            A13=arr[0,2]
            A23=arr[1,2]
            A33=arr[2,2]
            D=sympy.sympify(A11*(A22*A33-A23*A32)+A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22))
            if D.is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]=(A22*A33-A23*A32)*D
            out[1,0]=(A31*A23-A21*A33)*D
            out[2,0]=(A21*A32-A31*A22)*D
            out[0,1]=(A13*A32-A12*A33)*D
            out[1,1]=(A11*A33-A31*A13)*D
            out[2,1]=(A12*A31-A11*A32)*D
            out[0,2]=(A12*A23-A13*A22)*D
            out[1,2]=(A13*A21-A11*A23)*D
            out[2,2]=(A11*A22-A12*A21)*D
        else:
            raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
        return Symbol(out, dim=self._dim, subs=self._subs)

    def swap_axes(self, axis0, axis1):
        """
        Returns a copy of this Symbol with the axes `axis0` and `axis1`
        interchanged (see ``numpy.swapaxes``).
        """
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self._dim, subs=self._subs)

    @staticmethod
    def _size(shape):
        """
        Returns the number of elements of an array of the given shape (1 for
        the empty shape).
        """
        n=1
        for i in shape:
            n*=i
        return n

    @staticmethod
    def _objectMatmul(a0, a1):
        """
        Returns the matrix product of the 2-d arrays `a0` (m,k) and `a1`
        (k,n) as an object array. The element products are summed with
        numpy.sum so that symbolic entries are supported.
        """
        out=numpy.zeros((a0.shape[0], a1.shape[1]), dtype=object)
        for i0 in range(a0.shape[0]):
            for i1 in range(a1.shape[1]):
                out[i0,i1]=numpy.sum(a0[i0,:]*a1[:,i1])
        return out

    def _productOperand(self, other):
        """
        Prepares the second operand of the tensor product methods.

        :param other: second operand
        :type other: ``Symbol`` or array-like
        :return: tuple (elements, shape, dim, subs) with the operand's element
                 array and shape, the dimensionality of the product and the
                 merged substitution dictionary
        """
        if isinstance(other, Symbol):
            dim=other._dim if self._dim < 0 else self._dim
            subs=self._subs.copy()
            subs.update(other._subs)
            return other._arr, other.getShape(), dim, subs
        # plain arrays carry no substitutions
        arr=numpy.asarray(other)
        return arr, arr.shape, self._dim, self._subs.copy()

    def tensorProduct(self, other, axis_offset):
        """
        Returns the tensor product of this Symbol and `other`, contracting
        the last `axis_offset` axes of this Symbol with the first
        `axis_offset` axes of `other`.

        :type other: ``Symbol`` or ``numpy.ndarray``
        :type axis_offset: ``int``
        :rtype: ``Symbol``
        :raise ValueError: if the contracted dimensions do not match
        """
        arr1, sh1, dim, subs=self._productOperand(other)
        sh0=self.getShape()
        r0=self.getRank()
        d0=Symbol._size(sh0[:r0-axis_offset])
        d1=Symbol._size(sh1[axis_offset:])
        d01=Symbol._size(sh1[:axis_offset])
        out=Symbol._objectMatmul(self._arr.reshape((d0,d01)), arr1.reshape((d01,d1)))
        return Symbol(out.reshape(sh0[:r0-axis_offset]+sh1[axis_offset:]), dim=dim, subs=subs)

    def transposedTensorProduct(self, other, axis_offset):
        """
        Returns the tensor product of the transpose of this Symbol and
        `other`, i.e. the first `axis_offset` axes of both operands are
        contracted.

        :type other: ``Symbol`` or ``numpy.ndarray``
        :type axis_offset: ``int``
        :rtype: ``Symbol``
        :raise ValueError: if the contracted dimensions do not match
        """
        arr1, sh1, dim, subs=self._productOperand(other)
        sh0=self.getShape()
        d0=Symbol._size(sh0[axis_offset:])
        d1=Symbol._size(sh1[axis_offset:])
        d01=Symbol._size(sh1[:axis_offset])
        a0=self._arr.reshape((d01,d0)).T
        a1=arr1.reshape((d01,d1))
        out=Symbol._objectMatmul(a0, a1)
        return Symbol(out.reshape(sh0[axis_offset:]+sh1[axis_offset:]), dim=dim, subs=subs)

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Returns the tensor product of this Symbol and the transpose of
        `other`, i.e. the last `axis_offset` axes of both operands are
        contracted.

        :type other: ``Symbol`` or ``numpy.ndarray``
        :type axis_offset: ``int``
        :rtype: ``Symbol``
        :raise ValueError: if the contracted dimensions do not match
        """
        arr1, sh1, dim, subs=self._productOperand(other)
        sh0=self.getShape()
        r0=self.getRank()
        r1=len(sh1)
        d0=Symbol._size(sh0[:r0-axis_offset])
        d1=Symbol._size(sh1[:r1-axis_offset])
        d01=Symbol._size(sh1[r1-axis_offset:])
        a0=self._arr.reshape((d0,d01))
        a1=arr1.reshape((d1,d01)).T
        out=Symbol._objectMatmul(a0, a1)
        return Symbol(out.reshape(sh0[:r0-axis_offset]+sh1[:r1-axis_offset]), dim=dim, subs=subs)

    def trace(self, axis_offset):
        """
        Returns the trace of this Symbol, summing over the diagonal of the
        axes `axis_offset` and `axis_offset+1`, which are removed from the
        result.

        :raise ValueError: if the two traced axes differ in length
        """
        sh=self.getShape()
        n=sh[axis_offset]
        if sh[axis_offset+1]!=n:
            raise ValueError("trace: axes %d and %d must have the same length"%(axis_offset, axis_offset+1))
        s1=Symbol._size(sh[:axis_offset])
        s2=Symbol._size(sh[axis_offset+2:])
        arr_r=self._arr.reshape((s1,n,n,s2))
        out=numpy.zeros((s1,s2), dtype=object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(n):
                    out[i1,i2]+=arr_r[i1,j,j,i2]
        return Symbol(out.reshape(sh[:axis_offset]+sh[axis_offset+2:]), dim=self._dim, subs=self._subs)

    def transpose(self, axis_offset=None):
        """
        Returns the transpose of this Symbol: the first `axis_offset` axes
        are moved to the end.

        :param axis_offset: number of leading axes to move (default: half the
                            rank, rounded down)
        :type axis_offset: ``int`` or ``None``
        :rtype: ``Symbol``
        """
        rank=self._arr.ndim
        if axis_offset is None:
            axis_offset=rank//2
        # list() is required on Python 3 where range objects cannot be added
        axes=list(range(axis_offset, rank))+list(range(0, axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes), dim=self._dim, subs=self._subs)

    def applyfunc(self, f, on_type=None):
        """
        Applies the function `f` to all elements (if on_type is None) or to
        all elements of type `on_type`.

        For a scalar Symbol ``None`` is returned if `f` returns ``None``.

        :param f: callable taking one element
        :param on_type: optional type (or tuple of types) restricting which
                        elements `f` is applied to
        :rtype: ``Symbol``
        """
        assert callable(f)
        if self._arr.ndim==0:
            if on_type is None or isinstance(self._arr.item(), on_type):
                el=f(self._arr.item())
            else:
                el=self._arr.item()
            if el is not None:
                out=Symbol(el, dim=self._dim, subs=self._subs)
            else:
                return el
        else:
            out=numpy.empty(self.getShape(), dtype=object)
            for idx in numpy.ndindex(self.getShape()):
                if on_type is None or isinstance(self._arr[idx],on_type):
                    out[idx]=f(self._arr[idx])
                else:
                    out[idx]=self._arr[idx]
            out=Symbol(out, dim=self._dim, subs=self._subs)
        return out

    def expand(self):
        """
        Applies the sympy.expand operation on all elements in this symbol
        """
        return self.applyfunc(sympy.expand, sympy.Basic)

    def simplify(self):
        """
        Applies the sympy.simplify operation on all elements in this symbol
        """
        return self.applyfunc(sympy.simplify, sympy.Basic)

    # unary/binary operators follow

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr, dim=self._dim, subs=self._subs)

    def __abs__(self):
        return Symbol(abs(self._arr), dim=self._dim, subs=self._subs)

    def _ensureShapeCompatible(self, other):
        """
        Checks for compatible shapes for binary operations.
        Raises TypeError if not compatible.
        """
        sh0=self.getShape()
        if isinstance(other, Symbol) or isinstance(other, Data):
            sh1=other.getShape()
        elif isinstance(other, numpy.ndarray):
            sh1=other.shape
        elif isinstance(other, list):
            sh1=numpy.array(other).shape
        elif isinstance(other,int) or isinstance(other,float) or isinstance(other,sympy.Basic):
            sh1=()
        else:
            raise TypeError("Unsupported argument type '%s' for operation"%other.__class__.__name__)
        if not sh0==sh1 and not sh0==() and not sh1==():
            raise TypeError("Incompatible shapes for operation")

    def __binaryop(self, op, other):
        """
        Helper for binary operations that checks types, shapes etc.

        A ``Data`` operand is replaced by a placeholder Symbol named after
        the object's id and recorded in the substitutions of the result.
        """
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            subs=self._subs.copy()
            subs.update(other._subs)
            dim=other._dim if self._dim < 0 else self._dim
            return Symbol(getattr(self._arr, op)(other._arr), dim=dim, subs=subs)
        if isinstance(other, Data):
            name='data'+str(id(other))
            othersym=Symbol(name, other.getShape(), dim=self._dim)
            subs=self._subs.copy()
            subs.update({Symbol(name):other})
            return Symbol(getattr(self._arr, op)(othersym._arr), dim=self._dim, subs=subs)
        return Symbol(getattr(self._arr, op)(other), dim=self._dim, subs=self._subs)

    def __add__(self, other):
        return self.__binaryop('__add__', other)

    def __radd__(self, other):
        return self.__binaryop('__radd__', other)

    def __sub__(self, other):
        return self.__binaryop('__sub__', other)

    def __rsub__(self, other):
        return self.__binaryop('__rsub__', other)

    def __mul__(self, other):
        return self.__binaryop('__mul__', other)

    def __rmul__(self, other):
        return self.__binaryop('__rmul__', other)

    def __div__(self, other):
        return self.__binaryop('__div__', other)

    def __rdiv__(self, other):
        return self.__binaryop('__rdiv__', other)

    # '/' maps to __truediv__ on Python 3 (and with true division on 2)
    def __truediv__(self, other):
        return self.__binaryop('__truediv__', other)

    def __rtruediv__(self, other):
        return self.__binaryop('__rtruediv__', other)

    def __pow__(self, other):
        return self.__binaryop('__pow__', other)

    def __rpow__(self, other):
        return self.__binaryop('__rpow__', other)

    # static methods

    @staticmethod
    def _symComp(sym):
        """
        Splits the name of a component symbol such as '[x]_1_2' into its
        main name and indices, returning ('x', (1,2)). Names that do not
        follow this pattern are returned unchanged with an empty tuple.
        """
        n=sym.name
        a=n.split('[')
        if len(a)!=2:
            return n,()
        a=a[1].split(']')
        if len(a)!=2:
            return n,()
        name=a[0]
        comps=[int(i) for i in a[1].split('_')[1:]]
        return name,tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        A symbol followed by an integer is repeated that many times; the
        integer itself is not yielded. Non-`Symbol` arguments are passed
        through ``sympy.sympify``.

        Example:
            >> ._symbolgen(x, 3, y)
            (x, x, x, y)
            >> ._symbolgen(x, 10**6)
            (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat
        items=[s if isinstance(s, Symbol) else sympy.sympify(s) for s in symbols]
        for i, s in enumerate(items):
            if isinstance(s, sympy.Integer):
                continue  # repetition count, consumed with the preceding symbol
            # look at the following argument by position (a value comparison
            # with the last argument fails when a symbol occurs twice)
            next_s=items[i+1] if i+1<len(items) else None
            if (isinstance(s, Symbol) or isinstance(s, sympy.Symbol)) \
                    and isinstance(next_s, sympy.Integer):
                # handle cases like (x, 3) -> (x, x, x)
                for copy_s in repeat(s, int(next_s)):
                    yield copy_s
            else:
                yield s
811

  ViewVC Help
Powered by ViewVC 1.1.26