/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Contents of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3532 - (show annotations)
Mon Jun 20 04:14:42 2011 UTC (8 years, 4 months ago) by caltinay
Original Path: branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py
File MIME type: text/x-python
File size: 22246 byte(s)
Added simplify method to Symbol class.

1 # -*- coding: utf-8 -*-
2
3 ########################################################
4 #
5 # Copyright (c) 2003-2010 by University of Queensland
6 # Earth Systems Science Computational Center (ESSCC)
7 # http://www.uq.edu.au/esscc
8 #
9 # Primary Business: Queensland, Australia
10 # Licensed under the Open Software License version 3.0
11 # http://www.opensource.org/licenses/osl-3.0.php
12 #
13 ########################################################
14
15 __copyright__="""Copyright (c) 2003-2010 by University of Queensland
16 Earth Systems Science Computational Center (ESSCC)
17 http://www.uq.edu.au/esscc
18 Primary Business: Queensland, Australia"""
19 __license__="""Licensed under the Open Software License version 3.0
20 http://www.opensource.org/licenses/osl-3.0.php"""
21 __url__="https://launchpad.net/escript-finley"
22
23 """
24 :var __author__: name of author
25 :var __copyright__: copyrights
26 :var __license__: licence agreement
27 :var __url__: url entry point on documentation
28 :var __version__: version
29 :var __date__: date of the version
30 """
31
32 import numpy
33 import sympy
34
35 __author__="Cihan Altinay"
36
37
class Symbol(object):
    """
    Symbolic tensor of order 0 to 4 stored as a numpy array whose elements
    are usually sympy expressions (plain numbers are also allowed).

    Components of a tensor created from a name and a shape are sympy
    symbols named ``[name]_i_j...``; `lambdarepr` rewrites them as
    ``name[i,j,...]``. Arithmetic operators work element-wise and accept
    other Symbols, numpy arrays, numbers and sympy expressions of
    compatible shape.

    :ivar dim: spatial dimension used by `grad` (default 2)
    :ivar name: string form of the contents at construction time
    """

    def __init__(self, *args, **kwargs):
        """
        Initializes a new Symbol object.

        Accepted forms:
          Symbol(name)        -- scalar sympy symbol; the remaining keyword
                                 arguments are passed to sympy.Symbol
          Symbol(name, shape) -- tensor of component symbols ``[name]_i...``
          Symbol(obj)         -- copy of a Symbol, or wrapper around (a copy
                                 of) an array-like object, a list, a number
                                 or a sympy expression

        :keyword dim: spatial dimension (default 2; when copying a Symbol
                      its dimension is kept unless ``dim`` is given)
        :raise TypeError: for unsupported arguments or a name containing
                          '[' or ']'
        :raise ValueError: if the tensor order is greater than 4
        """
        import numbers
        dim_given='dim' in kwargs
        self.dim=kwargs.pop('dim') if dim_given else 2

        if len(args)==1:
            arg=args[0]
            if isinstance(arg, Symbol):
                # copy constructor
                self._arr=arg._arr.copy()
                if not dim_given:
                    self.dim=arg.dim
            elif isinstance(arg, str):
                if '[' in arg or ']' in arg:
                    raise TypeError("Name must not contain '[' or ']'")
                self._arr=numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__"):
                arr=arg.__array__()
                if len(arr.shape)>4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                self._arr=arr.copy()
            elif isinstance(arg, (list, numbers.Number)) or isinstance(arg, sympy.Basic):
                # plain numbers arise from 0-d arithmetic and applyfunc()
                self._arr=numpy.array(arg)
            else:
                raise TypeError("Unsupported argument type %s"%str(type(arg)))
        elif len(args)==2:
            name, shape=args
            if not isinstance(name, str):
                raise TypeError("First argument must be a string")
            if '[' in name or ']' in name:
                raise TypeError("Name must not contain '[' or ']'")
            if not isinstance(shape, tuple):
                raise TypeError("Second argument must be a tuple")
            if len(shape)>4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape)==0:
                self._arr=numpy.array(sympy.Symbol(name, **kwargs))
            else:
                self._arr=sympy.symarray(shape, '['+name+']')
        else:
            raise TypeError("Unsupported number of arguments")

        if self._arr.ndim==0:
            self.name=str(self._arr.item())
        else:
            self.name=str(self._arr.tolist())

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        """
        Returns True if `other` is of the same type and shape and all
        elements compare equal.
        """
        if type(self) is not type(other):
            return False
        if self.getShape()!=other.getShape():
            return False
        return bool(numpy.all(self._arr==other._arr))

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__
        return not self.__eq__(other)

    def __getitem__(self, key):
        return self._arr[key]

    def __setitem__(self, key, value):
        """
        Assigns `value` to the element(s) selected by `key`.

        A Symbol value must be scalar or match the shape of the selected
        sub-tensor; its contents (not the wrapper object) are stored.
        sympy expressions are stored as-is, other array-like values are
        converted element-wise with sympy.sympify, and anything else is
        sympified directly.

        :raise ValueError: if a Symbol value has the wrong shape
        """
        if isinstance(value, Symbol):
            if value.getRank()==0:
                self._arr[key]=value._arr.item()
            elif getattr(self._arr[key], "shape", None)==value.getShape():
                self._arr[key]=value._arr
            else:
                raise ValueError("Wrong shape of value")
        elif isinstance(value, sympy.Basic):
            self._arr[key]=value
        elif hasattr(value, "__array__"):
            src=numpy.asarray(value)
            converted=numpy.empty(src.shape, dtype=object)
            for idx, el in numpy.ndenumerate(src):
                converted[idx]=sympy.sympify(el)
            self._arr[key]=converted
        else:
            self._arr[key]=sympy.sympify(value)

    def getRank(self):
        """Returns the tensor order (number of array dimensions)."""
        return self._arr.ndim

    def getShape(self):
        """Returns the shape tuple of the underlying array."""
        return self._arr.shape

    def atoms(self, *types):
        """
        Returns the set of atoms (optionally restricted to `types`) found in
        the sympy elements. Component symbols ``[name]_i...`` are reported
        as the base symbol ``name``; non-sympy elements are ignored.
        """
        result=set()
        for el in self._arr.flat:
            if not isinstance(el, sympy.Basic):
                # TODO: Numbers?
                continue
            for a in el.atoms(*types):
                if a.is_Symbol:
                    result.add(sympy.Symbol(Symbol._symComp(a)[0]))
                else:
                    result.add(a)
        return result

    def _sympystr_(self, printer):
        # NOTE(review): sympy's StrPrinter looks for a method named
        # '_sympystr' (no trailing underscore) -- confirm this hook is used.
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string representation suitable for evaluation.

        Component symbols ``[x]_0_1`` are written as ``x[0,1]``. Scalars
        give the element's expression string; tensors give a
        ``combineData(<nested list>,<shape>)`` expression.
        """
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr=numpy.empty(self.getShape(), dtype=object)
        for idx, el in numpy.ndenumerate(self._arr):
            atoms=el.atoms(sympy.Symbol) if isinstance(el, sympy.Basic) else []
            # map names like [x]_0_0 to x[0,0]
            symdict={}
            for a in atoms:
                n, c=Symbol._symComp(a)
                if len(c)>0:
                    symstr=n+'['+','.join(str(i) for i in c)+']'
                else:
                    symstr=n
                symdict[a.name]=symstr
            s=lambdarepr(el)
            # replace longer names first so that e.g. '[x]_1' cannot clobber
            # the prefix of '[x]_10'
            for key in sorted(symdict, key=len, reverse=True):
                s=s.replace(key, symdict[key])
            temp_arr[idx]=s
        if self.getRank()==0:
            return temp_arr.item()
        return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))

    def coeff(self, x, expand=True):
        """
        Returns the coefficient(s) of the term(s) containing `x`.

        If `x` is a non-scalar Symbol the coefficient of each element is
        taken with respect to the corresponding component of `x`. A zero
        `x` yields zeros. Coefficients reported as None by sympy become 0.
        """
        self._ensureShapeCompatible(x)
        result=Symbol(self._arr, dim=self.dim)
        if isinstance(x, Symbol):
            if x.getRank()>0:
                a=result._arr.flat
                b=x._arr.flat
                for idx in range(len(a)):
                    s=next(b)
                    if s==0:
                        a[idx]=0
                    else:
                        a[idx]=a[idx].coeff(s, expand)
            elif x._arr.item()==0:
                result=Symbol(numpy.zeros(self.getShape()), dim=self.dim)
            else:
                xs=x._arr.item()
                result=result.applyfunc(lambda item: item.coeff(xs, expand))
        elif x==0:
            result=Symbol(numpy.zeros(self.getShape()), dim=self.dim)
        else:
            result=result.applyfunc(lambda item: item.coeff(x, expand))

        # replace None by 0
        if result is None:
            return 0
        flat=result._arr.flat
        for idx in range(len(flat)):
            if flat[idx] is None:
                flat[idx]=0
        return result

    def diff(self, *symbols, **assumptions):
        """
        Returns the derivative with respect to `symbols`.

        `symbols` follows sympy's convention, so ``diff(x, 2, y)`` means
        d^3/dx^2dy. Scalars (sympy symbols or rank-0 Symbols) differentiate
        element-wise; a rank-1 Symbol appends one axis holding the
        derivatives with respect to each of its components.

        :raise ValueError: if a Symbol of rank greater than 1 is given
        """
        symbols=Symbol._symbolgen(*symbols)
        result=Symbol(self._arr, dim=self.dim)
        for s in symbols:
            if isinstance(s, Symbol):
                if s.getRank()==0:
                    var=s._arr.item()
                    result=result.applyfunc(lambda item, v=var: item.diff(v, **assumptions))
                elif s.getRank()==1:
                    dim=s.getShape()[0]
                    # use the shape of the intermediate result, which grows
                    # by one axis per vector derivative
                    shape=result.getShape()
                    rank=result.getRank()
                    out=result._arr.copy().reshape(shape+(1,)).repeat(dim, axis=rank)
                    for d in range(dim):
                        for idx in numpy.ndindex(shape):
                            index=idx+(d,)
                            out[index]=out[index].diff(s[d], **assumptions)
                    result=Symbol(out, dim=self.dim)
                else:
                    raise ValueError("diff: Only rank 0 and 1 supported")
            else:
                result=result.applyfunc(lambda item, v=s: item.diff(v, **assumptions))
        return result

    def grad(self, where=None):
        """
        Returns the gradient as a new Symbol with an extra trailing axis of
        length `self.dim`, built element-wise with ``grad_n`` from the
        sibling ``functions`` module (presumably the derivative with respect
        to the d-th spatial coordinate -- see that module).

        :param where: optional scalar (Symbol or sympy object) forwarded to
                      grad_n
        :raise ValueError: if `where` is a non-scalar Symbol
        """
        if isinstance(where, Symbol):
            if where.getRank()>0:
                raise ValueError("grad: 'where' must be a scalar symbol")
            where=where._arr.item()

        from functions import grad_n
        out=self._arr.copy().reshape(self.getShape()+(1,)).repeat(self.dim, axis=self.getRank())
        for d in range(self.dim):
            for idx in numpy.ndindex(self.getShape()):
                index=idx+(d,)
                if where is None:
                    out[index]=grad_n(out[index], d)
                else:
                    out[index]=grad_n(out[index], d, where)
        return Symbol(out, dim=self.dim)

    @staticmethod
    def _isZero(value):
        """
        Returns True if `value` is known to be zero: sympy's ``is_zero``
        for sympy objects (None, i.e. unknown, counts as non-zero), plain
        comparison with 0 otherwise.
        """
        if hasattr(value, 'is_zero'):
            return bool(value.is_zero)
        return bool(value==0)

    def inverse(self):
        """
        Returns the inverse of a square rank-2 Symbol of size 1, 2 or 3
        using the explicit cofactor formulas.

        :raise ValueError: if the Symbol is not a square rank-2 tensor
        :raise ZeroDivisionError: if the determinant is known to be zero
        :raise TypeError: for matrix sizes other than 1, 2 or 3
        """
        if self.getRank()!=2:
            raise ValueError("inverse: Only rank 2 supported")
        s=self.getShape()
        if s[0]!=s[1]:
            raise ValueError("inverse: Only square shapes supported")
        out=numpy.zeros(s, object)
        arr=self._arr
        if s[0]==1:
            if Symbol._isZero(arr[0,0]):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            out[0,0]=1./arr[0,0]
        elif s[0]==2:
            A11=arr[0,0]
            A12=arr[0,1]
            A21=arr[1,0]
            A22=arr[1,1]
            D=A11*A22-A12*A21
            if Symbol._isZero(D):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]= A22*D
            out[1,0]=-A21*D
            out[0,1]=-A12*D
            out[1,1]= A11*D
        elif s[0]==3:
            A11=arr[0,0]
            A21=arr[1,0]
            A31=arr[2,0]
            A12=arr[0,1]
            A22=arr[1,1]
            A32=arr[2,1]
            A13=arr[0,2]
            A23=arr[1,2]
            A33=arr[2,2]
            D=A11*(A22*A33-A23*A32)+A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
            if Symbol._isZero(D):
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D=1./D
            out[0,0]=(A22*A33-A23*A32)*D
            out[1,0]=(A31*A23-A21*A33)*D
            out[2,0]=(A21*A32-A31*A22)*D
            out[0,1]=(A13*A32-A12*A33)*D
            out[1,1]=(A11*A33-A31*A13)*D
            out[2,1]=(A12*A31-A11*A32)*D
            out[0,2]=(A12*A23-A13*A22)*D
            out[1,2]=(A13*A21-A11*A23)*D
            out[2,2]=(A11*A22-A12*A21)*D
        else:
            raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
        return Symbol(out, dim=self.dim)

    def swap_axes(self, axis0, axis1):
        """Returns a new Symbol with axes `axis0` and `axis1` interchanged."""
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self.dim)

    @staticmethod
    def _shapeSize(shape):
        """Returns the number of elements of an array with shape `shape`."""
        size=1
        for extent in shape:
            size*=extent
        return size

    @staticmethod
    def _toArray(other):
        """Returns the array of a Symbol, or `other` as a numpy array."""
        if isinstance(other, Symbol):
            return other._arr
        return numpy.asarray(other)

    @staticmethod
    def _contract(left, right):
        """
        Returns the object-array matrix product of the 2D arrays `left`
        (d0 x d01) and `right` (d01 x d1).
        """
        d0=left.shape[0]
        d1=right.shape[1]
        out=numpy.zeros((d0, d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0,i1]=numpy.sum(left[i0,:]*right[:,i1])
        return out

    def tensorProduct(self, other, axis_offset):
        """
        Contracts the last `axis_offset` axes of self with the first
        `axis_offset` axes of `other` (a Symbol or array).

        :raise ValueError: if the contracted extents do not match
        """
        arr0=self._arr
        arr1=Symbol._toArray(other)
        sh0=arr0.shape
        sh1=arr1.shape
        keep0=sh0[:arr0.ndim-axis_offset]
        d0=Symbol._shapeSize(keep0)
        d1=Symbol._shapeSize(sh1[axis_offset:])
        d01=Symbol._shapeSize(sh1[:axis_offset])
        # reshape (unlike resize) raises on mismatching extents instead of
        # silently padding/truncating
        out=Symbol._contract(arr0.reshape((d0, d01)), arr1.reshape((d01, d1)))
        return Symbol(out.reshape(keep0+sh1[axis_offset:]), dim=self.dim)

    def transposedTensorProduct(self, other, axis_offset):
        """
        Contracts the first `axis_offset` axes of self with the first
        `axis_offset` axes of `other` (a Symbol or array).

        :raise ValueError: if the contracted extents do not match
        """
        arr0=self._arr
        arr1=Symbol._toArray(other)
        sh0=arr0.shape
        sh1=arr1.shape
        d0=Symbol._shapeSize(sh0[axis_offset:])
        d1=Symbol._shapeSize(sh1[axis_offset:])
        d01=Symbol._shapeSize(sh1[:axis_offset])
        out=Symbol._contract(arr0.reshape((d01, d0)).T, arr1.reshape((d01, d1)))
        return Symbol(out.reshape(sh0[axis_offset:]+sh1[axis_offset:]), dim=self.dim)

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Contracts the last `axis_offset` axes of self with the last
        `axis_offset` axes of `other` (a Symbol or array).

        :raise ValueError: if the contracted extents do not match
        """
        arr0=self._arr
        arr1=Symbol._toArray(other)
        sh0=arr0.shape
        sh1=arr1.shape
        r1=arr1.ndim
        keep0=sh0[:arr0.ndim-axis_offset]
        keep1=sh1[:r1-axis_offset]
        d0=Symbol._shapeSize(keep0)
        d1=Symbol._shapeSize(keep1)
        d01=Symbol._shapeSize(sh1[r1-axis_offset:])
        out=Symbol._contract(arr0.reshape((d0, d01)), arr1.reshape((d1, d01)).T)
        return Symbol(out.reshape(keep0+keep1), dim=self.dim)

    def trace(self, axis_offset):
        """
        Returns the trace over axes `axis_offset` and `axis_offset+1`, which
        must have the same extent.
        """
        sh=self.getShape()
        s1=Symbol._shapeSize(sh[:axis_offset])
        s2=Symbol._shapeSize(sh[axis_offset+2:])
        n=sh[axis_offset]
        arr_r=numpy.reshape(self._arr, (s1, n, n, s2))
        out=numpy.zeros((s1, s2), object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(n):
                    out[i1,i2]+=arr_r[i1,j,j,i2]
        return Symbol(out.reshape(sh[:axis_offset]+sh[axis_offset+2:]), dim=self.dim)

    def transpose(self, axis_offset=None):
        """
        Returns a new Symbol whose first axes are the axes from
        `axis_offset` onwards, followed by the first `axis_offset` axes.
        `axis_offset` defaults to half the rank (rounded down).
        """
        rank=self._arr.ndim
        if axis_offset is None:
            axis_offset=rank//2
        axes=list(range(axis_offset, rank))+list(range(0, axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes), dim=self.dim)

    def applyfunc(self, f, on_type=None):
        """
        Returns a new Symbol with `f` applied to every element (only to
        elements that are instances of `on_type`, if given). For a scalar
        Symbol, None is returned if `f` returns None.

        :raise TypeError: if `f` is not callable
        """
        if not callable(f):
            raise TypeError("applyfunc: f must be callable")
        if self._arr.ndim==0:
            item=self._arr.item()
            if on_type is None or isinstance(item, on_type):
                el=f(item)
            else:
                el=item
            if el is None:
                return el
            return Symbol(el, dim=self.dim)
        out=numpy.empty(self.getShape(), dtype=object)
        for idx in numpy.ndindex(self.getShape()):
            if on_type is None or isinstance(self._arr[idx], on_type):
                out[idx]=f(self._arr[idx])
            else:
                out[idx]=self._arr[idx]
        return Symbol(out, dim=self.dim)

    def simplify(self):
        """Returns a new Symbol with sympy.simplify applied to sympy elements."""
        return self.applyfunc(sympy.simplify, sympy.Basic)

    def _sympy_(self):
        # conversion hook used by sympy.sympify
        return self.applyfunc(sympy.sympify)

    def _ensureShapeCompatible(self, other):
        """
        Checks for compatible shapes for binary operations.
        Scalars (numbers, including numpy scalars, and sympy expressions)
        are compatible with any shape.
        Raises TypeError if not compatible.
        """
        import numbers
        sh0=self.getShape()
        if isinstance(other, Symbol):
            sh1=other.getShape()
        elif isinstance(other, numpy.ndarray):
            sh1=other.shape
        elif isinstance(other, numbers.Number) or isinstance(other, sympy.Basic):
            sh1=()
        else:
            raise TypeError("Unsupported argument type '%s' for binary operation"%other.__class__.__name__)
        if sh0!=sh1 and sh0!=() and sh1!=():
            raise TypeError("Incompatible shapes for binary operation")

    @staticmethod
    def _symComp(sym):
        """
        Splits a component symbol name ``[name]_i_j...`` into
        ``(name, (i, j, ...))``; other names give ``(name, ())``.
        """
        n=sym.name
        a=n.split('[')
        if len(a)!=2:
            return n,()
        a=a[1].split(']')
        if len(a)!=2:
            return n,()
        name=a[0]
        comps=[int(i) for i in a[1].split('_')[1:]]
        return name,tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        An integer following a symbol repeats that symbol; non-Symbol
        arguments are sympified. Empty input yields nothing.

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 3, x)
        (x, x, x, x)
        """
        from itertools import repeat
        converted=[s if isinstance(s, Symbol) else sympy.sympify(s) for s in symbols]
        count=len(converted)
        for i, s in enumerate(converted):
            if isinstance(s, sympy.Integer):
                continue
            # look ahead by position, not by comparing with the last symbol,
            # so a repeated symbol (x, 2, x) is handled correctly
            next_s=converted[i+1] if i+1<count else None
            if (isinstance(s, Symbol) or isinstance(s, sympy.Symbol)) and isinstance(next_s, sympy.Integer):
                # handle cases like (x, 3) -> x, x, x
                for copy_s in repeat(s, int(next_s)):
                    yield copy_s
            else:
                yield s

    # unary/binary operations follow

    @staticmethod
    def _unwrap(value):
        """Returns the array of a Symbol operand, other operands unchanged."""
        return value._arr if isinstance(value, Symbol) else value

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr, dim=self.dim)

    def __abs__(self):
        return Symbol(abs(self._arr), dim=self.dim)

    def __add__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(self._arr+Symbol._unwrap(other), dim=self.dim)

    def __radd__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(Symbol._unwrap(other)+self._arr, dim=self.dim)

    def __sub__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(self._arr-Symbol._unwrap(other), dim=self.dim)

    def __rsub__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(Symbol._unwrap(other)-self._arr, dim=self.dim)

    def __mul__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(self._arr*Symbol._unwrap(other), dim=self.dim)

    def __rmul__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(Symbol._unwrap(other)*self._arr, dim=self.dim)

    def __div__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(self._arr/Symbol._unwrap(other), dim=self.dim)

    def __rdiv__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(Symbol._unwrap(other)/self._arr, dim=self.dim)

    # true division (Python 3, or 'from __future__ import division')
    __truediv__=__div__
    __rtruediv__=__rdiv__

    def __pow__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(self._arr**Symbol._unwrap(other), dim=self.dim)

    def __rpow__(self, other):
        self._ensureShapeCompatible(other)
        return Symbol(Symbol._unwrap(other)**self._arr, dim=self.dim)
555
556
def symbols(*names, **kwargs):
    """
    Emulates the behaviour of sympy.symbols.

    The first positional argument is either a string of names separated by
    whitespace and/or commas, or a list/tuple of names. Empty names are
    skipped. The keyword ``shape`` (default ``()``, a list is accepted as
    well) gives the tensor shape of every created Symbol; any remaining
    keyword arguments are passed to the Symbol constructor.

    :return: None if there are no names, the Symbol itself for a single
             name, otherwise a tuple of Symbols
    """
    import re

    shape=kwargs.pop('shape', ())
    if isinstance(shape, list):
        shape=tuple(shape)

    spec=names[0] if names else ''
    if isinstance(spec, (list, tuple)):
        candidates=spec
    else:
        # raw string: '\s' is a regex class, not a Python escape
        candidates=re.split(r'\s|,', spec)

    res=tuple(Symbol(t, shape, **kwargs) for t in candidates if t)
    if len(res)==0:   # var('')
        return None
    if len(res)==1:   # var('x')
        return res[0]
    return res        # var('a b ...')
582
def combineData(array, shape):
    """
    Builds a container of the given `shape` from the nested sequence `array`
    of per-component values.

    A non-sequence value with an empty shape is returned unchanged. If any
    component is an escript Data object, the result is a Data object on the
    function space of one such component; otherwise it is a numpy array.
    Components with ``ndim == 0`` are stored as floats.

    :raise ValueError: if components live on different domains
    """
    # a single value needs no assembly
    if not hasattr(array, '__len__') and shape==():
        return array

    from esys.escript import Data
    components=numpy.array(array)  # for indexing

    # collect domains and function spaces of Data components
    domains=set()
    spaces=set()
    for idx in numpy.ndindex(shape):
        item=components[idx]
        if isinstance(item, Data):
            spaces.add(item.getFunctionSpace())
            domains.add(item.getDomain())

    if len(domains)>1:
        reference=domains.pop()
        while domains:
            if reference!=domains.pop():
                raise ValueError("Mixing of domains not supported")

    if spaces:
        # FIXME: interpolate instead of using the first function space?
        result=Data(0., shape, spaces.pop())
    else:
        result=numpy.zeros(shape)

    # element-wise assignment (much faster than summing unit tensors)
    for idx in numpy.ndindex(shape):
        item=components[idx]
        if hasattr(item, "ndim") and item.ndim==0:
            result[idx]=float(item)
        else:
            result[idx]=item
    return result
618
619
class SymFunction(Symbol):
    """
    Base class for symbolic functions applied to a list of arguments.

    The underlying array holds a scalar sympy symbol named after the
    concrete class; the arguments are kept in `args` and printed as
    ``ClassName(arg1, arg2, ...)``.
    """
    def __init__(self, *args, **kwargs):
        """
        Initializes a new symbolic function object.

        :param args: the function arguments, stored in `self.args`
        :param kwargs: passed on to the Symbol constructor (e.g. ``dim``)
        """
        super(SymFunction, self).__init__(self.__class__.__name__, **kwargs)
        self.args=args

    def __repr__(self):
        return self.name+"("+", ".join([str(a) for a in self.args])+")"

    def __str__(self):
        return self.name+"("+", ".join([str(a) for a in self.args])+")"

    @staticmethod
    def _argLambdarepr(arg):
        """
        Returns the lambdarepr of one argument: its own lambdarepr() if it
        has one, that of a wrapping Symbol for sympy expressions, and str()
        for anything else (e.g. plain numbers).
        """
        if hasattr(arg, 'lambdarepr'):
            return arg.lambdarepr()
        if isinstance(arg, sympy.Basic):
            return Symbol(arg).lambdarepr()
        return str(arg)

    def lambdarepr(self):
        return self.name+"("+", ".join([SymFunction._argLambdarepr(a) for a in self.args])+")"

    def atoms(self, *types):
        """
        Returns the set of atoms of all arguments. Component symbols
        ``[name]_i...`` are reported as ``name``; arguments without an
        ``atoms`` method (e.g. plain numbers) are ignored.
        """
        s=set()
        for el in self.args:
            if not hasattr(el, 'atoms'):
                continue
            for a in el.atoms(*types):
                if a.is_Symbol:
                    n,c=Symbol._symComp(a)
                    s.add(sympy.Symbol(n))
                else:
                    s.add(a)
        return s

    def __neg__(self):
        res=self.__class__(*self.args)
        # the constructor call above resets dim to its default; keep ours
        res.dim=self.dim
        res._arr=-res._arr
        return res
655

  ViewVC Help
Powered by ViewVC 1.1.26