/[escript]/branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py
ViewVC logotype

Contents of /branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3872 - (show annotations)
Fri Mar 16 00:48:46 2012 UTC (5 years, 9 months ago) by caltinay
File MIME type: text/x-python
File size: 28546 byte(s)
Symbols now have undefined dimensionality unless specified.

1
2 ########################################################
3 #
4 # Copyright (c) 2003-2012 by University of Queensland
5 # Earth Systems Science Computational Center (ESSCC)
6 # http://www.uq.edu.au/esscc
7 #
8 # Primary Business: Queensland, Australia
9 # Licensed under the Open Software License version 3.0
10 # http://www.opensource.org/licenses/osl-3.0.php
11 #
12 ########################################################
13
__copyright__="""Copyright (c) 2003-2012 by University of Queensland
Earth Systems Science Computational Center (ESSCC)
http://www.uq.edu.au/esscc
Primary Business: Queensland, Australia"""
__license__="""Licensed under the Open Software License version 3.0
http://www.opensource.org/licenses/osl-3.0.php"""
__url__="https://launchpad.net/escript-finley"
__author__="Cihan Altinay"

# Description of the module-level metadata variables defined above.
"""
:var __author__: name of author
:var __copyright__: copyrights
:var __license__: licence agreement
:var __url__: url entry point on documentation
"""
31
32 import numpy
33 import sympy
34 from esys.escript import Data, FunctionSpace
35
36
class Symbol(object):
    """
    `Symbol` objects are placeholders for a single mathematical symbol, such as
    'x', or for arbitrarily complex mathematical expressions such as
    'c*x**4+alpha*exp(x)-2*sin(beta*x)', where 'alpha', 'beta', 'c', and 'x'
    are also `Symbol`s (the symbolic 'atoms' of the expression).

    With the help of the 'Evaluator' class these symbols and expressions can
    be resolved by substituting numeric values and/or escript `Data` objects
    for the atoms. To facilitate the use of `Data` objects a `Symbol` has a
    shape (and thus a rank) as well as a dimension (see constructor).
    `Symbol`s are useful to perform mathematical simplifications, compute
    derivatives and as coefficients for nonlinear PDEs which can be solved by
    the `NonlinearPDE` class.

    Internally the elements are stored in a numpy object array (``_arr``),
    the spatial dimensionality in ``_dim`` (-1 means undefined) and the
    mapping of placeholder symbols to the escript objects they stand for
    (e.g. ``Data`` or ``FunctionSpace``) in ``_subs``.
    """

    # these are for compatibility with sympy.Symbol. lambdify checks these.
    is_Add = False
    is_Float = False

    def __init__(self, *args, **kwargs):
        """
        Initialises a new `Symbol` object in one of three ways::

            u=Symbol('u')

        returns a scalar symbol by the name 'u'.

            a=Symbol('alpha', (4,3))

        returns a rank 2 symbol with the shape (4,3), whose elements are
        named '[alpha]_i_j' (with i=0..3, j=0..2).

            a,b,c=symbols('a,b,c')
            x=Symbol([[a+b,0,0],[0,b-c,0],[0,0,c-a]])

        returns a rank 2 symbol with the shape (3,3) whose elements are
        explicitly specified by numeric values and other symbols/expressions
        within a list or numpy array.

        A single sympy expression or a plain Python number (int, float,
        complex) is also accepted and yields a scalar symbol.

        The dimensionality of the symbol can be specified through the `dim`
        keyword. All other keywords are passed to the underlying symbolic
        library (currently sympy).

        :param args: initialisation arguments as described above
        :keyword dim: spatial dimensionality of the new Symbol
                      (default: -1, i.e. undefined)
        :type dim: ``int``
        :keyword subs: substitutions (placeholder symbol -> escript object)
                       carried by the new Symbol; the dictionary is copied
        :type subs: ``dict``
        :raise ValueError: if a name contains '[' or ']' or the rank
                           exceeds 4
        :raise TypeError: for unsupported argument types or counts
        """
        self._dim = kwargs.pop('dim', -1)  # -1 means undefined
        subs = kwargs.pop('subs', None)
        # Copy so that this Symbol never shares - and, through update()
        # below, never mutates - a dictionary owned by another Symbol.
        self._subs = {} if subs is None else dict(subs)

        if len(args) == 1:
            arg = args[0]
            if isinstance(arg, str):
                Symbol._checkName(arg)
                self._arr = numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__") or isinstance(arg, list):
                if isinstance(arg, list):
                    arg = numpy.array(arg)
                arr = arg.__array__()
                if len(arr.shape) > 4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                res = numpy.empty(arr.shape, dtype=object)
                for idx in numpy.ndindex(arr.shape):
                    el = arr[idx]
                    # elements exposing item() (e.g. numpy scalars) are
                    # stored as plain Python objects
                    res[idx] = el.item() if hasattr(el, "item") else el
                self._arr = res
                if isinstance(arg, Symbol):
                    self._subs.update(arg._subs)
                    if self._dim == -1:
                        self._dim = arg._dim
            elif isinstance(arg, (int, float, complex)):
                # plain numbers, e.g. results of applyfunc() on constants
                self._arr = numpy.array(arg, dtype=object)
            elif isinstance(arg, sympy.Basic):
                self._arr = numpy.array(arg)
            else:
                raise TypeError("Unsupported argument type %s"%str(type(arg)))
        elif len(args) == 2:
            name, shape = args
            if not isinstance(name, str):
                raise TypeError("First argument must be a string")
            if not isinstance(shape, tuple):
                raise TypeError("Second argument must be a tuple")
            Symbol._checkName(name)
            if len(shape) > 4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape) == 0:
                self._arr = numpy.array(sympy.Symbol(name, **kwargs))
            else:
                # the argument order of symarray differs between sympy
                # versions, so try both
                try:
                    self._arr = sympy.symarray(shape, '['+name+']')
                except TypeError:
                    self._arr = sympy.symarray('['+name+']', shape)
        else:
            raise TypeError("Unsupported number of arguments")

        if self._arr.ndim == 0:
            self.name = str(self._arr.item())
        else:
            self.name = str(self._arr.tolist())

    @staticmethod
    def _checkName(name):
        """
        Raises ValueError if `name` contains '[' or ']'; brackets are reserved
        for the component names ('[name]_i_j') of non-scalar symbols.
        """
        if '[' in name or ']' in name:
            raise ValueError("Name must not contain '[' or ']'")

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        """
        Two symbols are equal if they have the same type and shape and all
        their elements compare equal.
        """
        if type(self) is not type(other):
            return False
        if self.getRank() != other.getRank():
            return False
        if self.getShape() != other.getShape():
            return False
        return (self._arr == other._arr).all()

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self.__eq__(other)

    def __hash__(self):
        # Symbols are used as dictionary keys (the substitution dictionaries)
        # and defining __eq__ alone makes a class unhashable in Python 3.
        # Equal symbols have equal elements and hence (in practice) the same
        # name string.
        # NOTE: `name` is computed by the constructor and is not refreshed by
        # __setitem__, so do not mutate a Symbol while it is used as a key.
        return hash(self.name)

    def __getitem__(self, key):
        """
        Returns an element of this symbol which must have rank >0.
        Unlike item() this method converts sympy objects and numpy arrays into
        escript Symbols in order to facilitate expressions that require
        element access, such as: grad(u)[1]+x

        :param key: (nd-)index of item to be returned
        :return: the requested element
        :rtype: ``Symbol``, ``int``, or ``float``
        """
        res = self._arr[key]
        # replace sympy Symbols/expressions by escript Symbols
        if isinstance(res, sympy.Basic) or isinstance(res, numpy.ndarray):
            res = Symbol(res)
        return res

    def __setitem__(self, key, value):
        """
        Assigns `value` to the element(s) selected by `key`.

        A non-scalar `Symbol` value must have exactly the shape of the
        selected sub-array. Other array-like values are converted element by
        element with sympy.sympify.

        :raise ValueError: if a non-scalar Symbol has the wrong shape
        """
        if isinstance(value, Symbol):
            if value.getRank() == 0:
                self._arr[key] = value.item()
            else:
                target_shape = getattr(self._arr[key], "shape", None)
                if target_shape is None or target_shape != value.getShape():
                    raise ValueError("Wrong shape of value")
                # assign through the parent array so that advanced indexing
                # (which yields copies) works, too
                self._arr[key] = value._arr
        elif isinstance(value, sympy.Basic):
            self._arr[key] = value
        elif hasattr(value, "__array__"):
            src = numpy.asarray(value.__array__())
            conv = numpy.empty(src.size, dtype=object)
            for i, v in enumerate(src.flat):
                conv[i] = sympy.sympify(v)
            target = self._arr[key]
            if isinstance(target, numpy.ndarray) and target.size == conv.size:
                conv = conv.reshape(target.shape)
            self._arr[key] = conv
        else:
            self._arr[key] = sympy.sympify(value)

    def __iter__(self):
        return iter(self._arr)

    def __array__(self):
        return self._arr

    def _sympy_(self):
        """
        Hook used by sympy.sympify(). Applies sympify to every element and
        returns the result as a `Symbol`.
        """
        return self.applyfunc(sympy.sympify)

    def getDim(self):
        """
        Returns the spatial dimensionality of this symbol.

        :return: the symbol's spatial dimensionality, or -1 if undefined
        :rtype: ``int``
        """
        return self._dim

    def getRank(self):
        """
        Returns the rank of this symbol.

        :return: the symbol's rank which is equal to the length of the shape.
        :rtype: ``int``
        """
        return self._arr.ndim

    def getShape(self):
        """
        Returns the shape of this symbol.

        :return: the symbol's shape
        :rtype: ``tuple`` of ``int``
        """
        return self._arr.shape

    def getDataSubstitutions(self):
        """
        Returns the dictionary mapping placeholder symbols to the escript
        objects (e.g. ``Data``) they represent within this Symbol.
        The internal dictionary is returned, not a copy.

        :return: the dictionary of substituted objects
        :rtype: ``dict``
        """
        return self._subs

    def item(self, *args):
        """
        Returns an element of this symbol.
        This method behaves like the item() method of numpy.ndarray.
        If this is a scalar Symbol, no arguments are allowed and the only
        element in this Symbol is returned.
        Otherwise, 'args' specifies a flat index (one integer) or an nd-index
        (several integers, or one tuple) and the element at that index is
        returned.

        :param args: index of item to be returned
        :return: the requested element
        :rtype: ``sympy.Symbol``, ``int``, or ``float``
        """
        return self._arr.item(*args)

    def atoms(self, *types):
        """
        Returns the atoms that form the current Symbol.

        By default, only objects that are truly atomic and cannot be divided
        into smaller pieces are returned: symbols, numbers, and number
        symbols like I and pi. It is possible to request atoms of any type,
        however.

        Note that if this symbol contains components such as [x]_i_j then
        only their main symbol 'x' is returned.

        :param types: types to restrict result to
        :return: set of atoms of the specified type(s)
        :rtype: ``set``
        """
        s = set()
        for el in self._arr.flat:
            if isinstance(el, sympy.Basic):
                for a in el.atoms(*types):
                    if a.is_Symbol:
                        n, c = Symbol._symComp(a)
                        s.add(sympy.Symbol(n))
                    else:
                        s.add(a)
            elif len(types) == 0 or type(el) in types:
                s.add(el)
        return s

    def _sympystr_(self, printer):
        # compatibility with sympy 1.6
        return self._sympystr(printer)

    def _sympystr(self, printer):
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string representation of this symbol for use with lambdify.

        Component symbols named like '[x]_0_1' are rewritten as 'x[0,1]'.
        Non-scalar symbols are rendered as 'combineData(<elements>,<shape>)'.

        :rtype: ``str``
        """
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr = numpy.empty(self.getShape(), dtype=object)
        for idx, el in numpy.ndenumerate(self._arr):
            atoms = el.atoms(sympy.Symbol) if isinstance(el, sympy.Basic) else []
            # create a dictionary to convert names like [x]_0_0 to x[0,0]
            symdict = {}
            for a in atoms:
                n, c = Symbol._symComp(a)
                if len(c) > 0:
                    symdict[a.name] = n+'['+','.join(str(i) for i in c)+']'
                else:
                    symdict[a.name] = n
            s = lambdarepr(el)
            # replace longer names first so that e.g. '[x]_1_1' cannot
            # clobber the prefix of '[x]_1_10'
            for key in sorted(symdict, key=len, reverse=True):
                s = s.replace(key, symdict[key])
            temp_arr[idx] = s
        if self.getRank() == 0:
            return temp_arr.item()
        else:
            return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))

    def coeff(self, x, expand=True):
        """
        Returns the coefficient of the term "x" or 0 if there is no "x".

        If "x" is a scalar symbol then "x" is searched in all components of
        this symbol. Otherwise the shapes must match and the coefficients are
        checked component by component.

        Example::

            x=Symbol('x', (2,2))
            y=3*x
            print(y.coeff(x))
            print(y.coeff(x[1,1]))

        will print::

            [[3 3]
             [3 3]]

            [[0 0]
             [0 3]]

        :param x: the term whose coefficients are to be found
        :type x: ``Symbol``, ``numpy.ndarray``, `list`
        :param expand: passed positionally to sympy's ``coeff``
        :return: the coefficient(s) of the term; the result keeps this
                 symbol's dimensionality and substitutions
        :rtype: ``Symbol``
        """
        # NOTE(review): newer sympy versions declare coeff(x, n=1, ...), so
        # the positional `expand` ends up as the power `n` - confirm intent.
        self._ensureShapeCompatible(x)
        if hasattr(x, '__array__'):
            y = x.__array__()
        else:
            y = numpy.array(x)

        if y.ndim > 0:
            result = numpy.zeros(self.getShape(), dtype=object)
            for idx in numpy.ndindex(y.shape):
                if y[idx] != 0:
                    res = sympy.sympify(self._arr[idx]).coeff(y[idx], expand)
                    if res is not None:
                        result[idx] = res
        elif y.item() == 0:
            result = numpy.zeros(self.getShape(), dtype=object)
        else:
            term = y.item()

            def coeff_item(item):
                res = sympy.sympify(item).coeff(term, expand)
                # coeff may return None if the term is absent
                return 0 if res is None else res

            result = self.applyfunc(coeff_item)
        return Symbol(result, dim=self._dim, subs=self._subs)

    def subs(self, old, new):
        """
        Substitutes an expression.

        If `new` is an escript ``Data`` object the substitution is only
        recorded (see getDataSubstitutions()); a non-scalar `old` is then
        represented by its main symbol. If `old` is a non-scalar Symbol, each
        of its components is replaced by the corresponding component of `new`
        (or by `new` itself if it is scalar). Otherwise a plain scalar
        substitution is performed on every element.

        :param old: the expression/symbol to be replaced
        :param new: the replacement
        :return: a new Symbol with the substitution applied
        :rtype: ``Symbol``
        """
        if isinstance(old, Symbol):
            old._ensureShapeCompatible(new)
        if isinstance(new, Data):
            subs = self._subs.copy()
            if isinstance(old, Symbol) and old.getRank() > 0:
                # atoms() returns a set holding the main symbol(s)
                old = Symbol(next(iter(old.atoms(sympy.Symbol))))
            subs[old] = new
            return Symbol(self._arr, dim=self._dim, subs=subs)

        if isinstance(old, Symbol) and old.getRank() > 0:
            if hasattr(new, '__array__'):
                new = new.__array__()
            else:
                new = numpy.array(new)
            result = numpy.empty(self.getShape(), dtype=object)
            for idx in numpy.ndindex(self.getShape()):
                # accumulate all component substitutions for this element
                expr = sympy.sympify(self._arr[idx])
                if new.ndim > 0:
                    for symidx in numpy.ndindex(new.shape):
                        expr = expr.subs(old._arr[symidx], new[symidx])
                else:  # substitute scalar for non-scalar
                    for symidx in numpy.ndindex(old.getShape()):
                        expr = expr.subs(old._arr[symidx], new.item())
                result[idx] = expr
            return Symbol(result, dim=self._dim, subs=self._subs)

        # scalar
        if isinstance(new, Symbol):
            new = new.item()
        if isinstance(old, Symbol):
            old = old.item()
        return self.applyfunc(lambda item: sympy.sympify(item).subs(old, new))

    def diff(self, *symbols, **assumptions):
        """
        Differentiates this symbol with respect to the given symbols.

        Arguments are processed like sympy's diff (e.g. ``diff(x, 2)``
        differentiates twice w.r.t. x). A rank 1 Symbol appends one axis to
        the result holding the derivatives w.r.t. each of its components.

        :return: the derivative
        :rtype: ``Symbol``
        :raise ValueError: if a Symbol argument has rank >1
        """
        symbols = Symbol._symbolgen(*symbols)
        result = Symbol(self._arr, dim=self._dim, subs=self._subs)
        for s in symbols:
            if isinstance(s, Symbol):
                if s.getRank() == 0:
                    var = s._arr.item()
                    result = result.applyfunc(
                        lambda item, var=var: sympy.sympify(item).diff(var, **assumptions))
                elif s.getRank() == 1:
                    dim = s.getShape()[0]
                    # use the shape of the *current* result, which may
                    # already have grown in a previous iteration
                    shape = result.getShape()
                    out = result._arr.reshape(shape+(1,)).repeat(dim, axis=result.getRank())
                    for d in range(dim):
                        var = s._arr[d]
                        for idx in numpy.ndindex(shape):
                            index = idx+(d,)
                            out[index] = sympy.sympify(out[index]).diff(var, **assumptions)
                    result = Symbol(out, dim=self._dim, subs=self._subs)
                else:
                    raise ValueError("diff: argument must have rank 0 or 1")
            else:
                result = result.applyfunc(
                    lambda item, s=s: sympy.sympify(item).diff(s, **assumptions))
        return result

    def grad(self, where=None):
        """
        Returns a symbol which represents the gradient of this symbol.

        The result has one additional trailing axis of length getDim(); its
        component d is ``grad_n(element, d)`` (or ``grad_n(element, d, where)``).

        :param where: optional location: a scalar Symbol, or a FunctionSpace
                      which is recorded as a substitution of a placeholder
                      symbol named after it
        :type where: ``Symbol``, ``FunctionSpace``
        :rtype: ``Symbol``
        :raise ValueError: if the dimensionality is undefined or `where` is a
                           non-scalar symbol
        """
        if self._dim < 0:
            raise ValueError("grad: cannot compute gradient as symbol has undefined dimensionality")
        subs = self._subs
        if isinstance(where, Symbol):
            if where.getRank() > 0:
                raise ValueError("grad: 'where' must be a scalar symbol")
            where = where._arr.item()
        elif isinstance(where, FunctionSpace):
            name = 'fs'+str(id(where))
            fssym = Symbol(name)
            subs = self._subs.copy()
            subs.update({fssym: where})
            where = name

        try:
            from .functions import grad_n
        except (ImportError, ValueError):
            # not imported as part of a package
            from functions import grad_n
        rank = self.getRank()
        out = self._arr.reshape(self.getShape()+(1,)).repeat(self._dim, axis=rank)
        for d in range(self._dim):
            for idx in numpy.ndindex(self.getShape()):
                index = idx+(d,)
                if where is None:
                    out[index] = grad_n(out[index], d)
                else:
                    out[index] = grad_n(out[index], d, where)
        return Symbol(out, dim=self._dim, subs=subs)

    @staticmethod
    def _checkInvertible(det):
        """Raises ZeroDivisionError if determinant `det` is known to be zero."""
        if sympy.sympify(det).is_zero:
            raise ZeroDivisionError("inverse: Symbol not invertible")

    def inverse(self):
        """
        Returns the inverse of this square rank 2 symbol (size 1, 2 or 3),
        computed with explicit cofactor formulae.

        :rtype: ``Symbol``
        :raise TypeError: if the rank is not 2 or the size is larger than 3
        :raise ValueError: if the shape is not square
        :raise ZeroDivisionError: if the determinant is known to be zero
        """
        if not self.getRank() == 2:
            raise TypeError("inverse: Only rank 2 supported")
        s = self.getShape()
        if not s[0] == s[1]:
            raise ValueError("inverse: Only square shapes supported")
        out = numpy.zeros(s, object)
        arr = self._arr
        if s[0] == 1:
            Symbol._checkInvertible(arr[0,0])
            out[0,0] = 1./arr[0,0]
        elif s[0] == 2:
            A11 = arr[0,0]
            A12 = arr[0,1]
            A21 = arr[1,0]
            A22 = arr[1,1]
            D = A11*A22-A12*A21
            Symbol._checkInvertible(D)
            D = 1./D
            out[0,0] = A22*D
            out[1,0] = -A21*D
            out[0,1] = -A12*D
            out[1,1] = A11*D
        elif s[0] == 3:
            A11 = arr[0,0]
            A21 = arr[1,0]
            A31 = arr[2,0]
            A12 = arr[0,1]
            A22 = arr[1,1]
            A32 = arr[2,1]
            A13 = arr[0,2]
            A23 = arr[1,2]
            A33 = arr[2,2]
            D = A11*(A22*A33-A23*A32)+A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
            Symbol._checkInvertible(D)
            D = 1./D
            out[0,0] = (A22*A33-A23*A32)*D
            out[1,0] = (A31*A23-A21*A33)*D
            out[2,0] = (A21*A32-A31*A22)*D
            out[0,1] = (A13*A32-A12*A33)*D
            out[1,1] = (A11*A33-A31*A13)*D
            out[2,1] = (A12*A31-A11*A32)*D
            out[0,2] = (A12*A23-A13*A22)*D
            out[1,2] = (A13*A21-A11*A23)*D
            out[2,2] = (A11*A22-A12*A21)*D
        else:
            raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
        return Symbol(out, dim=self._dim, subs=self._subs)

    def swap_axes(self, axis0, axis1):
        """
        Returns a copy of this symbol with axes `axis0` and `axis1` swapped.
        """
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self._dim, subs=self._subs)

    def _productOperand(self, other):
        """
        Normalises the second operand of a tensor product.

        :return: tuple (element array, shape, dimensionality of the result,
                 substitutions contributed by `other`)
        """
        if isinstance(other, Symbol):
            dim = other._dim if self._dim < 0 else self._dim
            return other._arr, other.getShape(), dim, other._subs
        arr = numpy.asarray(other)
        return arr, arr.shape, self._dim, {}

    @staticmethod
    def _size(shape):
        """Returns the number of elements of an array of the given shape."""
        n = 1
        for extent in shape:
            n *= extent
        return n

    def tensorProduct(self, other, axis_offset):
        """
        Returns the product of this symbol and `other`, summed over the last
        `axis_offset` axes of this symbol and the first `axis_offset` axes
        of `other`.

        :param other: second operand
        :type other: ``Symbol``, ``numpy.ndarray``
        :rtype: ``Symbol``
        """
        arr1, sh1, dim, other_subs = self._productOperand(other)
        sh0 = self.getShape()
        lead0 = sh0[:self._arr.ndim-axis_offset]
        trail1 = sh1[axis_offset:]
        d0 = Symbol._size(lead0)
        d1 = Symbol._size(trail1)
        d01 = Symbol._size(sh1[:axis_offset])
        a0 = self._arr.reshape((d0, d01))
        a1 = arr1.reshape((d01, d1))
        out = numpy.zeros((d0, d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1] = numpy.sum(a0[i0,:]*a1[:,i1])
        subs = self._subs.copy()
        subs.update(other_subs)
        return Symbol(out.reshape(lead0+trail1), dim=dim, subs=subs)

    def transposedTensorProduct(self, other, axis_offset):
        """
        Returns the product of this symbol and `other`, summed over the first
        `axis_offset` axes of both operands.

        :param other: second operand
        :type other: ``Symbol``, ``numpy.ndarray``
        :rtype: ``Symbol``
        """
        arr1, sh1, dim, other_subs = self._productOperand(other)
        sh0 = self.getShape()
        trail0 = sh0[axis_offset:]
        trail1 = sh1[axis_offset:]
        d0 = Symbol._size(trail0)
        d1 = Symbol._size(trail1)
        d01 = Symbol._size(sh1[:axis_offset])
        a0 = self._arr.reshape((d01, d0))
        a1 = arr1.reshape((d01, d1))
        out = numpy.zeros((d0, d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1] = numpy.sum(a0[:,i0]*a1[:,i1])
        subs = self._subs.copy()
        subs.update(other_subs)
        return Symbol(out.reshape(trail0+trail1), dim=dim, subs=subs)

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Returns the product of this symbol and `other`, summed over the last
        `axis_offset` axes of both operands.

        :param other: second operand
        :type other: ``Symbol``, ``numpy.ndarray``
        :rtype: ``Symbol``
        """
        arr1, sh1, dim, other_subs = self._productOperand(other)
        r1 = len(sh1)
        sh0 = self.getShape()
        lead0 = sh0[:self._arr.ndim-axis_offset]
        lead1 = sh1[:r1-axis_offset]
        d0 = Symbol._size(lead0)
        d1 = Symbol._size(lead1)
        d01 = Symbol._size(sh1[r1-axis_offset:])
        a0 = self._arr.reshape((d0, d01))
        a1 = arr1.reshape((d1, d01))
        out = numpy.zeros((d0, d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1] = numpy.sum(a0[i0,:]*a1[i1,:])
        subs = self._subs.copy()
        subs.update(other_subs)
        return Symbol(out.reshape(lead0+lead1), dim=dim, subs=subs)

    def trace(self, axis_offset):
        """
        Returns the trace of this Symbol, i.e. the sum over the diagonal of
        axes `axis_offset` and `axis_offset`+1 (which must have equal length).

        :rtype: ``Symbol``
        :raise ValueError: if the two axes differ in length
        """
        sh = self.getShape()
        n = sh[axis_offset]
        if sh[axis_offset+1] != n:
            raise ValueError("trace: axes must have equal length")
        s1 = Symbol._size(sh[:axis_offset])
        s2 = Symbol._size(sh[axis_offset+2:])
        arr_r = numpy.reshape(self._arr, (s1, n, n, s2))
        out = numpy.zeros([s1, s2], object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(n):
                    out[i1,i2] += arr_r[i1,j,j,i2]
        out = out.reshape(sh[:axis_offset]+sh[axis_offset+2:])
        return Symbol(out, dim=self._dim, subs=self._subs)

    def transpose(self, axis_offset=None):
        """
        Returns the transpose of this Symbol: axes `axis_offset`.. are moved
        in front of axes 0..`axis_offset`-1. If `axis_offset` is None, half
        the rank (rounded down) is used.

        :rtype: ``Symbol``
        """
        if axis_offset is None:
            axis_offset = int(self._arr.ndim/2)
        axes = list(range(axis_offset, self._arr.ndim))+list(range(0, axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes), dim=self._dim, subs=self._subs)

    def applyfunc(self, f, on_type=None):
        """
        Applies the function `f` to all elements (if on_type is None) or to
        all elements of type `on_type`.

        :return: a new Symbol with the results; for a scalar symbol, None is
                 returned if `f` returns None
        """
        assert callable(f)
        if self._arr.ndim == 0:
            if on_type is None or isinstance(self._arr.item(), on_type):
                el = f(self._arr.item())
            else:
                el = self._arr.item()
            if el is None:
                return el
            return Symbol(el, dim=self._dim, subs=self._subs)
        out = numpy.empty(self.getShape(), dtype=object)
        for idx in numpy.ndindex(self.getShape()):
            if on_type is None or isinstance(self._arr[idx], on_type):
                out[idx] = f(self._arr[idx])
            else:
                out[idx] = self._arr[idx]
        return Symbol(out, dim=self._dim, subs=self._subs)

    def expand(self):
        """
        Applies the sympy.expand operation on all elements in this symbol
        """
        return self.applyfunc(sympy.expand, sympy.Basic)

    def simplify(self):
        """
        Applies the sympy.simplify operation on all elements in this symbol
        """
        return self.applyfunc(sympy.simplify, sympy.Basic)

    # unary/binary operators follow

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr, dim=self._dim, subs=self._subs)

    def __abs__(self):
        return Symbol(abs(self._arr), dim=self._dim, subs=self._subs)

    def _ensureShapeCompatible(self, other):
        """
        Checks for compatible shapes for binary operations: shapes must be
        equal unless one operand is scalar.
        Raises TypeError if not compatible or of an unsupported type.
        """
        sh0 = self.getShape()
        if isinstance(other, Symbol) or isinstance(other, Data):
            sh1 = other.getShape()
        elif isinstance(other, numpy.ndarray):
            sh1 = other.shape
        elif isinstance(other, list):
            sh1 = numpy.array(other).shape
        elif isinstance(other, (int, float, complex, numpy.generic, sympy.Basic)):
            sh1 = ()
        else:
            raise TypeError("Unsupported argument type '%s' for operation"%other.__class__.__name__)
        if not sh0 == sh1 and not sh0 == () and not sh1 == ():
            raise TypeError("Incompatible shapes for operation")

    def __binaryop(self, op, other):
        """
        Helper for binary operations that checks types, shapes etc.
        `op` is the name of the numpy.ndarray operator method to apply.
        """
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            subs = self._subs.copy()
            subs.update(other._subs)
            dim = other._dim if self._dim < 0 else self._dim
            return Symbol(getattr(self._arr, op)(other._arr), dim=dim, subs=subs)
        if isinstance(other, Data):
            # represent the Data object by a placeholder symbol and record
            # the substitution
            name = 'data'+str(id(other))
            othersym = Symbol(name, other.getShape(), dim=self._dim)
            subs = self._subs.copy()
            subs.update({Symbol(name): other})
            return Symbol(getattr(self._arr, op)(othersym._arr), dim=self._dim, subs=subs)
        return Symbol(getattr(self._arr, op)(other), dim=self._dim, subs=self._subs)

    def __add__(self, other):
        return self.__binaryop('__add__', other)

    def __radd__(self, other):
        return self.__binaryop('__radd__', other)

    def __sub__(self, other):
        return self.__binaryop('__sub__', other)

    def __rsub__(self, other):
        return self.__binaryop('__rsub__', other)

    def __mul__(self, other):
        return self.__binaryop('__mul__', other)

    def __rmul__(self, other):
        return self.__binaryop('__rmul__', other)

    def __div__(self, other):
        return self.__binaryop('__div__', other)

    def __rdiv__(self, other):
        return self.__binaryop('__rdiv__', other)

    # Python 3 (and Python 2 with "from __future__ import division") uses
    # the true division protocol for '/'.
    def __truediv__(self, other):
        return self.__binaryop('__truediv__', other)

    def __rtruediv__(self, other):
        return self.__binaryop('__rtruediv__', other)

    def __pow__(self, other):
        return self.__binaryop('__pow__', other)

    def __rpow__(self, other):
        return self.__binaryop('__rpow__', other)

    # static methods

    @staticmethod
    def _symComp(sym):
        """
        Splits a component name of the form '[name]_i_j...' into
        ('name', (i, j, ...)). Any other name is returned unchanged together
        with an empty tuple.
        """
        n = sym.name
        a = n.split('[')
        if len(a) != 2:
            return n, ()
        a = a[1].split(']')
        if len(a) != 2:
            return n, ()
        name = a[0]
        comps = [int(i) for i in a[1].split('_')[1:]]
        return name, tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        An integer following a symbol repeats that symbol; integers are never
        yielded themselves.

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat
        items = [s if isinstance(s, Symbol) else sympy.sympify(s) for s in symbols]
        for i, s in enumerate(items):
            if isinstance(s, sympy.Integer):
                continue
            # look at the following argument by position (not by value) so
            # that repeated symbols such as (x, 2, x) are handled correctly
            next_s = items[i+1] if i+1 < len(items) else None
            if isinstance(s, (Symbol, sympy.Symbol)) and isinstance(next_s, sympy.Integer):
                # handle cases like (x, 3) -> yield (x, x, x)
                for copy_s in repeat(s, int(next_s)):
                    yield copy_s
            else:
                yield s
808

  ViewVC Help
Powered by ViewVC 1.1.26