/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Contents of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 4019 - (show annotations)
Thu Oct 11 08:12:55 2012 UTC (6 years, 8 months ago) by jfenwick
File MIME type: text/x-python
File size: 29146 byte(s)
More tabbing errors,
range/xrange
...
1
2 ##############################################################################
3 #
4 # Copyright (c) 2003-2012 by University of Queensland
5 # http://www.uq.edu.au
6 #
7 # Primary Business: Queensland, Australia
8 # Licensed under the Open Software License version 3.0
9 # http://www.opensource.org/licenses/osl-3.0.php
10 #
11 # Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12 # Development since 2012 by School of Earth Sciences
13 #
14 ##############################################################################
15
16 __copyright__="""Copyright (c) 2003-2012 by University of Queensland
17 http://www.uq.edu.au
18 Primary Business: Queensland, Australia"""
19 __license__="""Licensed under the Open Software License version 3.0
20 http://www.opensource.org/licenses/osl-3.0.php"""
21 __url__="https://launchpad.net/escript-finley"
22 __author__="Cihan Altinay"
23
24 """
25 :var __author__: name of author
26 :var __copyright__: copyrights
27 :var __license__: licence agreement
28 :var __url__: url entry point on documentation
29 :var __version__: version
30 :var __date__: date of the version
31 """
32
33 import numpy
34 from esys.escript import Data, FunctionSpace, HAVE_SYMBOLS
35 if HAVE_SYMBOLS:
36 import sympy
37
38
class Symbol(object):
    """
    `Symbol` objects are placeholders for a single mathematical symbol, such as
    'x', or for arbitrarily complex mathematical expressions such as
    'c*x**4+alpha*exp(x)-2*sin(beta*x)', where 'alpha', 'beta', 'c', and 'x'
    are also Symbols (the symbolic 'atoms' of the expression).

    With the help of the 'Evaluator' class these symbols and expressions can
    be resolved by substituting numeric values and/or escript `Data` objects
    for the atoms. To facilitate the use of `Data` objects a `Symbol` has a
    shape (and thus a rank) as well as a dimension (see constructor).
    Symbols are useful to perform mathematical simplifications, compute
    derivatives and as coefficients for nonlinear PDEs which can be solved by
    the `NonlinearPDE` class.

    Internally the components are stored in a numpy object array (``_arr``),
    the spatial dimensionality in ``_dim`` (-1 if undefined) and the mapping
    of placeholder symbols to escript ``Data`` objects in ``_subs``.
    """

    # these are for compatibility with sympy.Symbol. lambdify checks these.
    is_Add=False
    is_Float=False

    def __init__(self, *args, **kwargs):
        """
        Initialises a new `Symbol` object in one of three ways::

            u=Symbol('u')

        returns a scalar symbol by the name 'u'.

            a=Symbol('alpha', (4,3))

        returns a rank 2 symbol with the shape (4,3), whose elements are
        named '[alpha]_i_j' (with i=0..3, j=0..2).

            a,b,c=symbols('a,b,c')
            x=Symbol([[a+b,0,0],[0,b-c,0],[0,0,c-a]])

        returns a rank 2 symbol with the shape (3,3) whose elements are
        explicitly specified by numeric values and other symbols/expressions
        within a list or numpy array.

        A single sympy expression or a plain Python number is also accepted
        and yields a scalar (rank 0) symbol.

        The dimensionality of the symbol can be specified through the ``dim``
        keyword. All other keywords are passed to the underlying symbolic
        library (currently sympy).

        :param args: initialisation arguments as described above
        :keyword dim: dimensionality of the new Symbol (default: -1, i.e.
                      undefined; taken from the argument if that is a
                      `Symbol` with a defined dimensionality)
        :type dim: ``int``
        :keyword subs: initial dictionary of ``Data`` substitutions (copied)
        :raise RuntimeError: if sympy is not available
        :raise TypeError: if the number or types of arguments are unsupported
        :raise ValueError: if a name contains '[' or ']' or the rank exceeds 4
        """
        if not HAVE_SYMBOLS:
            raise RuntimeError("Trying to instantiate a Symbol but sympy not available")

        self._dim=kwargs.pop('dim', -1) # -1 means undefined
        # take a copy so that merging substitutions below (or any later
        # update) never modifies a dictionary owned by the caller
        self._subs=dict(kwargs.pop('subs', {}))

        if len(args)==1:
            arg=args[0]
            if isinstance(arg, str):
                if '[' in arg or ']' in arg:
                    raise ValueError("Name must not contain '[' or ']'")
                self._arr=numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__") or isinstance(arg, list):
                if isinstance(arg, list): arg=numpy.array(arg)
                arr=arg.__array__()
                if arr.ndim>4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                res=numpy.empty(arr.shape, dtype=object)
                for idx in numpy.ndindex(arr.shape):
                    el=arr[idx]
                    # unwrap numpy scalars into plain Python objects
                    res[idx]=el.item() if hasattr(el, "item") else el
                self._arr=res
                if isinstance(arg, Symbol):
                    self._subs.update(arg._subs)
                    if self._dim==-1:
                        self._dim=arg._dim
            elif isinstance(arg, sympy.Basic):
                self._arr=numpy.array(arg)
            elif isinstance(arg, (int, float, complex)):
                # plain numbers arise e.g. from applyfunc() results
                self._arr=numpy.array(sympy.sympify(arg))
            else:
                raise TypeError("Unsupported argument type %s"%str(type(arg)))
        elif len(args)==2:
            if not isinstance(args[0], str):
                raise TypeError("First argument must be a string")
            if not isinstance(args[1], tuple):
                raise TypeError("Second argument must be a tuple")
            name, shape=args
            if '[' in name or ']' in name:
                raise ValueError("Name must not contain '[' or ']'")
            if len(shape)>4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape)==0:
                self._arr=numpy.array(sympy.Symbol(name, **kwargs))
            else:
                # the argument order of sympy.symarray differs between sympy
                # versions ((shape, prefix) vs. (prefix, shape))
                try:
                    self._arr=sympy.symarray(shape, '['+name+']')
                except TypeError:
                    self._arr=sympy.symarray('['+name+']', shape)
        else:
            raise TypeError("Unsupported number of arguments")
        if self._arr.ndim==0:
            self.name=str(self._arr.item())
        else:
            self.name=str(self._arr.tolist())

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        """
        Returns True if ``other`` is a Symbol of the same shape whose
        components all compare equal to the components of this Symbol.
        """
        if type(self) is not type(other):
            return False
        if self.getShape()!=other.getShape():
            return False
        return bool((self._arr==other._arr).all())

    def __ne__(self, other):
        # explicit for Python 2, which does not derive != from __eq__
        return not self.__eq__(other)

    def __hash__(self):
        # identity hash: Symbols are used as keys of the substitution dicts
        return id(self)

    def __getitem__(self, key):
        """
        Returns an element of this symbol which must have rank >0.
        Unlike item() this method converts sympy objects and numpy arrays into
        escript Symbols in order to facilitate expressions that require
        element access, such as: grad(u)[1]+x

        :param key: (nd-)index of item to be returned
        :return: the requested element
        :rtype: ``Symbol``, ``int``, or ``float``
        """
        res=self._arr[key]
        # replace sympy Symbols/expressions by escript Symbols
        if isinstance(res, sympy.Basic) or isinstance(res, numpy.ndarray):
            res=Symbol(res)
        return res

    def __setitem__(self, key, value):
        """
        Sets the element(s) at ``key`` to ``value``.

        :param key: (nd-)index or slice into this Symbol
        :param value: a `Symbol` (rank 0, or matching the shape at ``key``),
                      a sympy expression, an array-like or a plain value
                      (the latter two are converted with ``sympy.sympify``)
        :raise ValueError: if a non-scalar Symbol's shape does not match
        """
        if isinstance(value, Symbol):
            if value.getRank()==0:
                self._arr[key]=value.item()
            elif hasattr(self._arr[key], "shape") and self._arr[key].shape==value.getShape():
                target=self._arr[key]
                for idx in numpy.ndindex(target.shape):
                    # read raw components: value[idx] returns plain numbers
                    # unwrapped, which have no item() method
                    target[idx]=value._arr[idx]
            else:
                raise ValueError("Wrong shape of value")
        elif isinstance(value, sympy.Basic):
            self._arr[key]=value
        elif hasattr(value, "__array__"):
            vals=numpy.asarray(value)
            conv=numpy.empty(vals.shape, dtype=object)
            for idx in numpy.ndindex(vals.shape):
                conv[idx]=sympy.sympify(vals[idx])
            self._arr[key]=conv.item() if conv.ndim==0 else conv
        else:
            self._arr[key]=sympy.sympify(value)

    def __iter__(self):
        """
        Returns an iterator over the first axis of the underlying array.
        """
        return iter(self._arr)

    def __array__(self, t=None):
        if t:
            return self._arr.astype(t)
        else:
            return self._arr

    def _sympy_(self):
        """
        Hook used by ``sympy.sympify``: returns this Symbol with
        ``sympy.sympify`` applied to every component.
        """
        return self.applyfunc(sympy.sympify)

    def getDim(self):
        """
        Returns the spatial dimensionality of this symbol.

        :return: the symbol's spatial dimensionality, or -1 if undefined
        :rtype: ``int``
        """
        return self._dim

    def getRank(self):
        """
        Returns the rank of this symbol.

        :return: the symbol's rank which is equal to the length of the shape.
        :rtype: ``int``
        """
        return self._arr.ndim

    def getShape(self):
        """
        Returns the shape of this symbol.

        :return: the symbol's shape
        :rtype: ``tuple`` of ``int``
        """
        return self._arr.shape

    def getDataSubstitutions(self):
        """
        Returns a dictionary of symbol names and the escript ``Data`` objects
        they represent within this Symbol.

        :return: the dictionary of substituted ``Data`` objects
        :rtype: ``dict``
        """
        return self._subs

    def item(self, *args):
        """
        Returns an element of this symbol.
        This method behaves like the item() method of numpy.ndarray.
        If this is a scalar Symbol, no arguments are allowed and the only
        element in this Symbol is returned.
        Otherwise, 'args' specifies a flat index, an nd-index given as
        separate integers, or an nd-index given as a single tuple, and the
        element at that index is returned.

        :param args: index of item to be returned
        :return: the requested element
        :rtype: ``sympy.Symbol``, ``int``, or ``float``
        """
        return self._arr.item(*args)

    def atoms(self, *types):
        """
        Returns the atoms that form the current Symbol.

        By default, only objects that are truly atomic and cannot be divided
        into smaller pieces are returned: symbols, numbers, and number
        symbols like I and pi. It is possible to request atoms of any type,
        however.

        Note that if this symbol contains components such as [x]_i_j then
        only their main symbol 'x' is returned.

        :param types: types to restrict result to
        :return: list of atoms of specified type
        :rtype: ``set``
        """
        s=set()
        for el in self._arr.flat:
            if isinstance(el,sympy.Basic):
                for a in el.atoms(*types):
                    if a.is_Symbol:
                        n,_=Symbol._symComp(a)
                        s.add(sympy.Symbol(n))
                    else:
                        s.add(a)
            elif len(types)==0 or type(el) in types:
                s.add(el)
        return s

    def _sympystr_(self, printer):
        # compatibility with sympy 1.6
        return self._sympystr(printer)

    def _sympystr(self, printer):
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string representation of this Symbol for use with lambdify.

        Component symbols named like '[x]_0_1' are rewritten as 'x[0,1]'.
        Non-scalar Symbols are returned as a
        ``combineData(<nested list>,<shape>)`` expression.

        :rtype: ``str``
        """
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr=numpy.empty(self.getShape(), dtype=object)
        for idx,el in numpy.ndenumerate(self._arr):
            atoms=el.atoms(sympy.Symbol) if isinstance(el,sympy.Basic) else []
            # create a dictionary to convert names like [x]_0_0 to x[0,0]
            symdict={}
            for a in atoms:
                n,c=Symbol._symComp(a)
                if len(c)>0:
                    symstr=n+'['+','.join(str(i) for i in c)+']'
                else:
                    symstr=n
                symdict[a.name]=symstr
            s=lambdarepr(el)
            # replace longer names first so that e.g. '[x]_1_1' cannot
            # clobber part of '[x]_1_10'
            for key in sorted(symdict, key=len, reverse=True):
                s=s.replace(key, symdict[key])
            temp_arr[idx]=s
        if self.getRank()==0:
            return temp_arr.item()
        else:
            return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))

    def coeff(self, x, expand=True):
        """
        Returns the coefficient of the term "x" or 0 if there is no "x".

        If "x" is a scalar symbol then "x" is searched in all components of
        this symbol. Otherwise the shapes must match and the coefficients are
        checked component by component.

        Example::

            x=Symbol('x', (2,2))
            y=3*x
            print(y.coeff(x))
            print(y.coeff(x[1,1]))

        will print::

            [[3 3]
             [3 3]]

            [[0 0]
             [0 3]]

        :param x: the term whose coefficients are to be found
        :type x: ``Symbol``, ``numpy.ndarray``, `list`
        :param expand: passed positionally to sympy's ``coeff``
        :return: the coefficient(s) of the term
        :rtype: ``Symbol``
        """
        self._ensureShapeCompatible(x)
        if hasattr(x, '__array__'):
            y=x.__array__()
        else:
            y=numpy.array(x)

        def _coeff(item, term):
            # plain numbers contain no symbolic term
            if not isinstance(item, sympy.Basic):
                return 0
            # NOTE(review): newer sympy versions interpret the second
            # positional argument of Expr.coeff as the power 'n', not
            # 'expand' -- confirm against the supported sympy version
            res=item.coeff(term, expand)
            return 0 if res is None else res

        if y.ndim>0:
            result=numpy.zeros(self.getShape(), dtype=object)
            for idx in numpy.ndindex(y.shape):
                if y[idx]!=0:
                    result[idx]=_coeff(self._arr[idx], y[idx])
        elif y.item()==0:
            result=numpy.zeros(self.getShape(), dtype=object)
        else:
            term=y.item()
            result=self.applyfunc(lambda item: _coeff(item, term))
        # keep the Data substitutions of this Symbol
        return Symbol(result, dim=self._dim, subs=self._subs)

    def subs(self, old, new):
        """
        Substitutes an expression.

        :param old: the expression to be replaced; a `Symbol` of any rank or
                    a sympy expression
        :param new: the replacement. If it is an escript ``Data`` object the
                    substitution is not applied but recorded in the data
                    substitutions of the returned Symbol.
        :return: a new Symbol with the substitution applied or recorded
        :rtype: ``Symbol``
        :raise TypeError: if the shapes of ``old`` and ``new`` are incompatible
        """
        if isinstance(old, Symbol):
            old._ensureShapeCompatible(new)

        def _subs_el(item, o, n):
            # plain numbers cannot contain symbols, leave them untouched
            return item.subs(o, n) if isinstance(item, sympy.Basic) else item

        if isinstance(new, Data):
            subs=self._subs.copy()
            if isinstance(old, Symbol) and old.getRank()>0:
                # atoms() yields the main symbol of a tensor Symbol which
                # is used as the key of the substitution
                atoms=old.atoms(sympy.Symbol)
                if not atoms:
                    raise ValueError("subs: 'old' does not contain a symbol")
                old=Symbol(next(iter(atoms)))
            subs[old]=new
            return Symbol(self._arr, dim=self._dim, subs=subs)

        if isinstance(old, Symbol) and old.getRank()>0:
            if hasattr(new, '__array__'):
                new=new.__array__()
            else:
                new=numpy.array(new)
            if new.ndim>0:
                pairs=[(old._arr[i], new[i]) for i in numpy.ndindex(new.shape)]
            else: # substitute scalar for non-scalar
                pairs=[(old._arr[i], new.item()) for i in numpy.ndindex(old.getShape())]
            result=numpy.empty(self.getShape(), dtype=object)
            for idx in numpy.ndindex(self.getShape()):
                el=self._arr[idx]
                # apply every component substitution, not only the last one
                for o,n in pairs:
                    el=_subs_el(el, o, n)
                result[idx]=el
            return Symbol(result, dim=self._dim, subs=self._subs)

        # scalar
        if isinstance(new, Symbol):
            new=new.item()
        if isinstance(old, Symbol):
            old=old.item()
        return self.applyfunc(lambda item: _subs_el(item, old, new))

    def diff(self, *symbols, **assumptions):
        """
        Returns the derivative of this Symbol with respect to ``symbols``.

        A symbol may be followed by an integer giving the number of times to
        differentiate, e.g. ``u.diff(x, 2, y)``. Differentiating with respect
        to a scalar symbol acts component-wise; differentiating with respect
        to a rank 1 Symbol appends a trailing axis holding the derivatives
        with respect to each of its components.

        :param symbols: the variables to differentiate with respect to
        :param assumptions: passed on to sympy's ``diff``
        :rtype: ``Symbol``
        :raise ValueError: if a variable is a Symbol of rank greater than 1
        """
        def _diff_el(item, var):
            # the derivative of a plain number is zero
            if isinstance(item, sympy.Basic):
                return item.diff(var, **assumptions)
            return 0

        result=Symbol(self._arr, dim=self._dim, subs=self._subs)
        for s in Symbol._symbolgen(*symbols):
            if isinstance(s, Symbol):
                if s.getRank()==0:
                    var=s._arr.item()
                    result=result.applyfunc(lambda item: _diff_el(item, var))
                elif s.getRank()==1:
                    dim=s.getShape()[0]
                    # use the shape of the current result since every rank 1
                    # differentiation adds an axis
                    shape=result.getShape()
                    out=result._arr.copy().reshape(shape+(1,)).repeat(dim,axis=len(shape))
                    for d in range(dim):
                        for idx in numpy.ndindex(shape):
                            index=idx+(d,)
                            out[index]=_diff_el(out[index], s._arr[d])
                    result=Symbol(out, dim=self._dim, subs=result._subs)
                else:
                    raise ValueError("diff: argument must have rank 0 or 1")
            else:
                result=result.applyfunc(lambda item: _diff_el(item, s))
        return result

    def grad(self, where=None):
        """
        Returns a symbol which represents the gradient of this symbol.

        The result has one additional trailing axis of length ``getDim()``
        whose entries are ``grad_n(component, d)`` (or
        ``grad_n(component, d, where)``).

        :param where: optional scalar Symbol or FunctionSpace
        :type where: ``Symbol``, ``FunctionSpace``
        :rtype: ``Symbol``
        :raise ValueError: if the dimensionality is undefined or ``where`` is
                           a non-scalar Symbol
        """
        if self._dim < 0:
            raise ValueError("grad: cannot compute gradient as symbol has undefined dimensionality")
        subs=self._subs
        if isinstance(where, Symbol):
            if where.getRank()>0:
                raise ValueError("grad: 'where' must be a scalar symbol")
            where=where._arr.item()
        elif isinstance(where, FunctionSpace):
            # represent the function space by a placeholder symbol that is
            # recorded in the substitutions
            name='fs'+str(id(where))
            fssym=Symbol(name)
            subs=self._subs.copy()
            subs.update({fssym:where})
            where=name

        from .functions import grad_n
        out=self._arr.copy().reshape(self.getShape()+(1,)).repeat(self._dim,axis=self.getRank())
        for d in range(self._dim):
            for idx in numpy.ndindex(self.getShape()):
                index=idx+(d,)
                if where is None:
                    out[index]=grad_n(out[index],d)
                else:
                    out[index]=grad_n(out[index],d,where)
        return Symbol(out, dim=self._dim, subs=subs)

    @staticmethod
    def _isZero(value):
        """
        Returns True if ``value`` is known to be zero. Objects with an
        ``is_zero`` attribute (sympy expressions) use it, where an undecidable
        ``None`` counts as non-zero; anything else is compared with 0.
        """
        if hasattr(value, 'is_zero'):
            return bool(value.is_zero)
        return value==0

    def inverse(self):
        """
        Returns the inverse of this square rank 2 Symbol, computed with the
        explicit cofactor formulas for sizes 1, 2 and 3.

        :rtype: ``Symbol``
        :raise TypeError: if the rank is not 2 or the size is larger than 3
        :raise ValueError: if the shape is not square
        :raise ZeroDivisionError: if the determinant is known to be zero
        """
        if not self.getRank()==2:
            raise TypeError("inverse: Only rank 2 supported")
        s=self.getShape()
        if not s[0] == s[1]:
            raise ValueError("inverse: Only square shapes supported")
        # 'object' instead of the removed alias numpy.object
        out=numpy.zeros(s, object)
        arr=self._arr
        if s[0]==1:
            if Symbol._isZero(arr[0,0]):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            out[0,0]=1./arr[0,0]
        elif s[0]==2:
            A11=arr[0,0]
            A12=arr[0,1]
            A21=arr[1,0]
            A22=arr[1,1]
            D = A11*A22-A12*A21
            if Symbol._isZero(D):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]= A22*D
            out[1,0]=-A21*D
            out[0,1]=-A12*D
            out[1,1]= A11*D
        elif s[0]==3:
            A11=arr[0,0]
            A21=arr[1,0]
            A31=arr[2,0]
            A12=arr[0,1]
            A22=arr[1,1]
            A32=arr[2,1]
            A13=arr[0,2]
            A23=arr[1,2]
            A33=arr[2,2]
            D = A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
            if Symbol._isZero(D):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]=(A22*A33-A23*A32)*D
            out[1,0]=(A31*A23-A21*A33)*D
            out[2,0]=(A21*A32-A31*A22)*D
            out[0,1]=(A13*A32-A12*A33)*D
            out[1,1]=(A11*A33-A31*A13)*D
            out[2,1]=(A12*A31-A11*A32)*D
            out[0,2]=(A12*A23-A13*A22)*D
            out[1,2]=(A13*A21-A11*A23)*D
            out[2,2]=(A11*A22-A12*A21)*D
        else:
            raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
        return Symbol(out, dim=self._dim, subs=self._subs)

    def swap_axes(self, axis0, axis1):
        """
        Returns a Symbol with axes ``axis0`` and ``axis1`` interchanged.
        """
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self._dim, subs=self._subs)

    def _combinedSubs(self, other):
        """
        Returns a new dictionary with this Symbol's data substitutions,
        updated with those of ``other`` if ``other`` is a `Symbol`.
        """
        subs=self._subs.copy()
        if isinstance(other, Symbol):
            subs.update(other._subs)
        return subs

    def _operandInfo(self, other):
        """
        Returns (array, shape, dim) for the second operand of a tensor
        product. ``other`` may be a `Symbol` or anything numpy.asarray
        accepts.
        """
        if isinstance(other, Symbol):
            dim=other._dim if self._dim < 0 else self._dim
            return other._arr, other.getShape(), dim
        arr=numpy.asarray(other)
        return arr, arr.shape, self._dim

    def tensorProduct(self, other, axis_offset):
        """
        Returns the tensor product of this Symbol and ``other`` contracting
        the last ``axis_offset`` axes of this Symbol with the first
        ``axis_offset`` axes of ``other``.

        :rtype: ``Symbol``
        """
        sh0=self.getShape()
        arg1, sh1, dim=self._operandInfo(other)
        sh_out0=sh0[:self._arr.ndim-axis_offset]
        sh_out1=sh1[axis_offset:]
        d0=int(numpy.prod(sh_out0))
        d1=int(numpy.prod(sh_out1))
        d01=int(numpy.prod(sh1[:axis_offset]))
        # reshape (not resize) so that a size mismatch raises instead of
        # silently padding
        arg0_c=self._arr.reshape((d0,d01))
        arg1_c=arg1.reshape((d01,d1))
        out=numpy.zeros((d0,d1),object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[:,i1])
        out=out.reshape(sh_out0+sh_out1)
        return Symbol(out, dim=dim, subs=self._combinedSubs(other))

    def transposedTensorProduct(self, other, axis_offset):
        """
        Returns the tensor product of the transpose of this Symbol and
        ``other``, contracting the first ``axis_offset`` axes of both.

        :rtype: ``Symbol``
        """
        sh0=self.getShape()
        arg1, sh1, dim=self._operandInfo(other)
        sh_out0=sh0[axis_offset:]
        sh_out1=sh1[axis_offset:]
        d0=int(numpy.prod(sh_out0))
        d1=int(numpy.prod(sh_out1))
        d01=int(numpy.prod(sh1[:axis_offset]))
        arg0_c=self._arr.reshape((d01,d0))
        arg1_c=arg1.reshape((d01,d1))
        out=numpy.zeros((d0,d1),object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(arg0_c[:,i0]*arg1_c[:,i1])
        out=out.reshape(sh_out0+sh_out1)
        return Symbol(out, dim=dim, subs=self._combinedSubs(other))

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Returns the tensor product of this Symbol and the transpose of
        ``other``, contracting the last ``axis_offset`` axes of both.

        :rtype: ``Symbol``
        """
        sh0=self.getShape()
        arg1, sh1, dim=self._operandInfo(other)
        r1=len(sh1)
        sh_out0=sh0[:self._arr.ndim-axis_offset]
        sh_out1=sh1[:r1-axis_offset]
        d0=int(numpy.prod(sh_out0))
        d1=int(numpy.prod(sh_out1))
        d01=int(numpy.prod(sh1[r1-axis_offset:]))
        arg0_c=self._arr.reshape((d0,d01))
        arg1_c=arg1.reshape((d1,d01))
        out=numpy.zeros((d0,d1),object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[i1,:])
        out=out.reshape(sh_out0+sh_out1)
        return Symbol(out, dim=dim, subs=self._combinedSubs(other))

    def trace(self, axis_offset):
        """
        Returns the trace of this Symbol, summing over axes ``axis_offset``
        and ``axis_offset+1`` (which must have equal length).

        :rtype: ``Symbol``
        """
        sh=self.getShape()
        s1=int(numpy.prod(sh[:axis_offset]))
        s2=int(numpy.prod(sh[axis_offset+2:]))
        n=sh[axis_offset]
        arr_r=numpy.reshape(self._arr,(s1,n,n,s2))
        out=numpy.zeros([s1,s2],object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(n):
                    out[i1,i2]+=arr_r[i1,j,j,i2]
        out=out.reshape(sh[:axis_offset]+sh[axis_offset+2:])
        return Symbol(out, dim=self._dim, subs=self._subs)

    def transpose(self, axis_offset=None):
        """
        Returns the transpose of this Symbol: the axes from ``axis_offset``
        onwards are moved to the front. If ``axis_offset`` is None, half the
        rank (rounded down) is used.

        :rtype: ``Symbol``
        """
        if axis_offset is None:
            axis_offset=self._arr.ndim//2
        axes=list(range(axis_offset, self._arr.ndim))+list(range(0,axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes), dim=self._dim, subs=self._subs)

    def applyfunc(self, f, on_type=None):
        """
        Applies the function `f` to all elements (if on_type is None) or to
        all elements of type `on_type`.

        For a scalar Symbol, ``None`` is returned (instead of a Symbol) if the
        result of the application is None.
        """
        assert callable(f)
        if self._arr.ndim==0:
            if on_type is None or isinstance(self._arr.item(), on_type):
                el=f(self._arr.item())
            else:
                el=self._arr.item()
            if el is None:
                return el
            return Symbol(el, dim=self._dim, subs=self._subs)
        out=numpy.empty(self.getShape(), dtype=object)
        for idx in numpy.ndindex(self.getShape()):
            if on_type is None or isinstance(self._arr[idx],on_type):
                out[idx]=f(self._arr[idx])
            else:
                out[idx]=self._arr[idx]
        return Symbol(out, dim=self._dim, subs=self._subs)

    def expand(self):
        """
        Applies the sympy.expand operation on all elements in this symbol
        """
        return self.applyfunc(sympy.expand, sympy.Basic)

    def simplify(self):
        """
        Applies the sympy.simplify operation on all elements in this symbol
        """
        return self.applyfunc(sympy.simplify, sympy.Basic)

    # unary/binary operators follow

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr, dim=self._dim, subs=self._subs)

    def __abs__(self):
        return Symbol(abs(self._arr), dim=self._dim, subs=self._subs)

    def _ensureShapeCompatible(self, other):
        """
        Checks for compatible shapes for binary operations.
        Shapes are compatible if they are equal or if either is scalar.
        Raises TypeError if not compatible or if the type of ``other`` is
        not supported.
        """
        sh0=self.getShape()
        if isinstance(other, (int, float, complex, numpy.number)):
            # numpy.number also covers numpy integer scalars
            sh1=()
        elif isinstance(other, numpy.ndarray):
            sh1=other.shape
        elif isinstance(other, list):
            sh1=numpy.array(other).shape
        elif isinstance(other, Symbol) or isinstance(other, Data):
            sh1=other.getShape()
        elif isinstance(other, sympy.Basic):
            sh1=()
        else:
            raise TypeError("Unsupported argument type '%s' for operation"%other.__class__.__name__)
        if sh0!=sh1 and sh0!=() and sh1!=():
            raise TypeError("Incompatible shapes for operation")

    def __binaryop(self, op, other):
        """
        Helper for binary operations that checks types, shapes etc.
        ``op`` is the name of the numpy array method to apply. A ``Data``
        operand is replaced by a placeholder Symbol recorded in the
        substitutions of the result.
        """
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            dim=other._dim if self._dim < 0 else self._dim
            return Symbol(getattr(self._arr, op)(other._arr), dim=dim, subs=self._combinedSubs(other))
        if isinstance(other, Data):
            name='data'+str(id(other))
            othersym=Symbol(name, other.getShape(), dim=self._dim)
            subs=self._subs.copy()
            subs.update({Symbol(name):other})
            return Symbol(getattr(self._arr, op)(othersym._arr), dim=self._dim, subs=subs)
        return Symbol(getattr(self._arr, op)(other), dim=self._dim, subs=self._subs)

    def __add__(self, other):
        return self.__binaryop('__add__', other)

    def __radd__(self, other):
        return self.__binaryop('__radd__', other)

    def __sub__(self, other):
        return self.__binaryop('__sub__', other)

    def __rsub__(self, other):
        return self.__binaryop('__rsub__', other)

    def __mul__(self, other):
        return self.__binaryop('__mul__', other)

    def __rmul__(self, other):
        return self.__binaryop('__rmul__', other)

    def __div__(self, other):
        # numpy arrays only provide __div__ under Python 2
        op='__div__' if hasattr(self._arr, '__div__') else '__truediv__'
        return self.__binaryop(op, other)

    def __truediv__(self, other):
        return self.__binaryop('__truediv__', other)

    def __rdiv__(self, other):
        op='__rdiv__' if hasattr(self._arr, '__rdiv__') else '__rtruediv__'
        return self.__binaryop(op, other)

    def __rtruediv__(self, other):
        return self.__binaryop('__rtruediv__', other)

    def __pow__(self, other):
        return self.__binaryop('__pow__', other)

    def __rpow__(self, other):
        return self.__binaryop('__rpow__', other)

    # static methods

    @staticmethod
    def _symComp(sym):
        """
        Splits the name of a component symbol such as '[x]_1_2' into the main
        name and the component indices, i.e. ('x', (1,2)). Names that are not
        of this form are returned unchanged together with an empty tuple.
        """
        n=sym.name
        a=n.split('[')
        if len(a)!=2:
            return n,()
        a=a[1].split(']')
        if len(a)!=2:
            return n,()
        name=a[0]
        comps=[int(i) for i in a[1].split('_')[1:]]
        return name,tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        A symbol followed by an integer n is yielded n times; the integers
        themselves are not yielded.

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat
        converted=[s if isinstance(s, Symbol) else sympy.sympify(s) for s in symbols]
        for i, s in enumerate(converted):
            if isinstance(s, sympy.Integer):
                # repeat counts are consumed together with their symbol
                continue
            next_s=converted[i+1] if i+1<len(converted) else None
            is_sym=isinstance(s, Symbol) or isinstance(s, sympy.Symbol)
            if is_sym and isinstance(next_s, sympy.Integer):
                # handle cases like (x, 3) -> x, x, x
                for copy_s in repeat(s, int(next_s)):
                    yield copy_s
            else:
                yield s
825

  ViewVC Help
Powered by ViewVC 1.1.26