/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Contents of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3530 - (show annotations)
Wed Jun 15 04:48:53 2011 UTC (8 years, 4 months ago) by caltinay
Original Path: branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py
File MIME type: text/x-python
File size: 21874 byte(s)
Added dimensionality to symbols (default: 2).
Fixed differentiation.
Added coeff() method.
Fixed a few special cases where elements are numbers/zero etc.

1 # -*- coding: utf-8 -*-
2
3 ########################################################
4 #
5 # Copyright (c) 2003-2010 by University of Queensland
6 # Earth Systems Science Computational Center (ESSCC)
7 # http://www.uq.edu.au/esscc
8 #
9 # Primary Business: Queensland, Australia
10 # Licensed under the Open Software License version 3.0
11 # http://www.opensource.org/licenses/osl-3.0.php
12 #
13 ########################################################
14
15 __copyright__="""Copyright (c) 2003-2010 by University of Queensland
16 Earth Systems Science Computational Center (ESSCC)
17 http://www.uq.edu.au/esscc
18 Primary Business: Queensland, Australia"""
19 __license__="""Licensed under the Open Software License version 3.0
20 http://www.opensource.org/licenses/osl-3.0.php"""
21 __url__="https://launchpad.net/escript-finley"
22
23 """
24 :var __author__: name of author
25 :var __copyright__: copyrights
26 :var __license__: licence agreement
27 :var __url__: url entry point on documentation
28 :var __version__: version
29 :var __date__: date of the version
30 """
31
32 import numpy
33 import sympy
34
35 __author__="Cihan Altinay"
36
37
class Symbol(object):
    """
    Symbolic tensor of rank 0 to 4.

    The components are stored in a numpy object array (``self._arr``) and are
    usually sympy expressions, although plain numbers may occur as well.
    Arithmetic operators work element-wise on that array. ``self.dim`` is the
    dimensionality used by `grad` (default: 2) and is propagated to every
    Symbol derived from this one.
    """

    def __init__(self, *args, **kwargs):
        """
        Initializes a new Symbol object.

        Supported call forms:

        - ``Symbol(name)`` -- scalar symbol called ``name``
        - ``Symbol(obj)`` -- copy of ``obj.__array__()`` for array-like ``obj``
        - ``Symbol(obj)`` -- ``numpy.array(obj)`` for a list or sympy expression
        - ``Symbol(name, shape)`` -- tensor of symbols named ``[name]_i_j...``

        The keyword ``dim`` sets the dimensionality (default: 2). Remaining
        keyword arguments are passed to ``sympy.Symbol`` when a scalar symbol
        is created from a name.

        :raise TypeError: if a name contains '[' or ']', or an argument has an
                          unsupported type
        :raise ValueError: if the rank would exceed 4
        """
        self.dim = kwargs.pop('dim', 2)

        if len(args) == 1:
            arg = args[0]
            if isinstance(arg, str):
                if '[' in arg or ']' in arg:
                    raise TypeError("Name must not contain '[' or ']'")
                self._arr = numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__"):
                arr = arg.__array__()
                if len(arr.shape) > 4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                self._arr = arr.copy()
            elif isinstance(arg, (list, sympy.Basic)):
                self._arr = numpy.array(arg)
            else:
                raise TypeError("Unsupported argument type %s" % str(type(arg)))
        elif len(args) == 2:
            name, shape = args
            if not isinstance(name, str):
                raise TypeError("First argument must be a string")
            if '[' in name or ']' in name:
                raise TypeError("Name must not contain '[' or ']'")
            if not isinstance(shape, tuple):
                raise TypeError("Second argument must be a tuple")
            if len(shape) > 4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape) == 0:
                self._arr = numpy.array(sympy.Symbol(name, **kwargs))
            else:
                # components are named '[name]_i_j...' (see _symComp)
                self._arr = sympy.symarray(shape, '[' + name + ']')
        else:
            raise TypeError("Unsupported number of arguments")

        if self._arr.ndim == 0:
            self.name = str(self._arr.item())
        else:
            self.name = str(self._arr.tolist())

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        """
        Returns True if ``other`` has the same type and shape and all
        components compare equal.
        """
        if type(self) is not type(other):
            return False
        if self.getRank() != other.getRank():
            return False
        if self.getShape() != other.getShape():
            return False
        return bool((self._arr == other._arr).all())

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, != would
        # compare object identity.
        return not self.__eq__(other)

    def __getitem__(self, key):
        return self._arr[key]

    def __setitem__(self, key, value):
        """
        Assigns ``value`` to the component(s) selected by ``key``.

        A rank-0 Symbol stores its single component, and a higher-rank Symbol
        stores its component array. The selected sub-array must then have the
        same shape as the Symbol. Array-like values are converted
        element-wise with ``sympy.sympify``.

        :raise ValueError: if the shape of a Symbol value does not match
        """
        if isinstance(value, Symbol):
            if value.getRank() == 0:
                # store the component, not the Symbol wrapper
                self._arr[key] = value._arr.item()
            else:
                target = self._arr[key]
                if hasattr(target, "shape") and target.shape == value.getShape():
                    self._arr[key] = value._arr
                else:
                    raise ValueError("Wrong shape of value")
        elif isinstance(value, sympy.Basic):
            self._arr[key] = value
        elif hasattr(value, "__array__"):
            src = value.__array__()
            # keep the value's shape so multi-dimensional slots can be filled
            conv = numpy.empty(src.shape, dtype=object)
            for idx in numpy.ndindex(src.shape):
                conv[idx] = sympy.sympify(src[idx])
            self._arr[key] = conv.item() if conv.ndim == 0 else conv
        else:
            self._arr[key] = sympy.sympify(value)

    def getRank(self):
        return self._arr.ndim

    def getShape(self):
        return self._arr.shape

    def atoms(self, *types):
        """
        Returns the set of atoms of all sympy components.

        Component symbols such as ``[x]_0_1`` are reported by their base
        name (``x``). Components that are not sympy objects are ignored.
        """
        s = set()
        for el in self._arr.flat:
            if not isinstance(el, sympy.Basic):
                # TODO: Numbers?
                continue
            for a in el.atoms(*types):
                if a.is_Symbol:
                    n, c = Symbol._symComp(a)
                    s.add(sympy.Symbol(n))
                else:
                    s.add(a)
        return s

    # NOTE(review): sympy's StrPrinter looks up a method called '_sympystr'
    # (no trailing underscore), so this hook may never be invoked -- confirm
    # the intended name before relying on it.
    def _sympystr_(self, printer):
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string representation suitable for evaluation.

        Component symbols named like ``[x]_0_1`` are rewritten to ``x[0,1]``.
        For rank > 0 the result has the form ``combineData([...], shape)``.
        """
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr = numpy.empty(self.getShape(), dtype=object)
        for idx, el in numpy.ndenumerate(self._arr):
            atoms = el.atoms(sympy.Symbol) if isinstance(el, sympy.Basic) else []
            # create a dictionary to convert names like [x]_0_0 to x[0,0]
            symdict = {}
            for a in atoms:
                n, c = Symbol._symComp(a)
                if len(c) > 0:
                    symstr = n + '[' + ','.join(str(i) for i in c) + ']'
                else:
                    symstr = n
                symdict[a.name] = symstr
            s = lambdarepr(el)
            # Replace longer names first so that e.g. '[x]_0_1' cannot
            # clobber the prefix of '[x]_0_10'.
            for key in sorted(symdict, key=len, reverse=True):
                s = s.replace(key, symdict[key])
            temp_arr[idx] = s
        if self.getRank() == 0:
            return temp_arr.item()
        return 'combineData(%s,%s)' % (str(temp_arr.tolist()).replace("'", ""),
                                       str(self.getShape()))

    def coeff(self, x, expand=True):
        """
        Returns the coefficient of ``x`` in every component.

        If ``x`` is a Symbol of rank > 0, the coefficients are taken
        element-wise. A zero ``x`` yields zeros. ``None`` results are
        replaced by 0.

        :raise TypeError: if the shapes are incompatible
        """
        # NOTE(review): passes 'expand' positionally to sympy's Expr.coeff;
        # newer sympy versions use a different signature -- verify.
        self._ensureShapeCompatible(x)
        result = Symbol(self._arr, dim=self.dim)
        if isinstance(x, Symbol):
            if x.getRank() > 0:
                a = result._arr.flat
                b = x._arr.flat
                for idx in range(len(a)):
                    s = next(b)
                    if s == 0:
                        a[idx] = 0
                    else:
                        a[idx] = sympy.sympify(a[idx]).coeff(s, expand)
            elif x._arr.item() == 0:
                result = Symbol(numpy.zeros(self.getShape()), dim=self.dim)
            else:
                var = x._arr.item()
                result = result.applyfunc(
                    lambda item: sympy.sympify(item).coeff(var, expand))
        elif x == 0:
            result = Symbol(numpy.zeros(self.getShape()), dim=self.dim)
        else:
            result = result.applyfunc(
                lambda item: sympy.sympify(item).coeff(x, expand))

        # replace None by 0
        if result is None:
            return 0
        a = result._arr.flat
        for idx in range(len(a)):
            if a[idx] is None:
                a[idx] = 0
        return result

    def diff(self, *symbols, **assumptions):
        """
        Differentiates every component with respect to ``symbols`` in turn.

        An integer following a symbol repeats it (``diff(x, 2)`` is the second
        derivative). Differentiating with respect to a rank-1 Symbol appends a
        trailing axis that holds the derivatives with respect to each of its
        components.

        :raise ValueError: for symbols of rank > 1
        """
        symbols = Symbol._symbolgen(*symbols)
        result = Symbol(self._arr, dim=self.dim)
        for s in symbols:
            if isinstance(s, Symbol):
                if s.getRank() == 0:
                    var = s._arr.item()
                    result = result.applyfunc(
                        lambda item: sympy.sympify(item).diff(var, **assumptions))
                elif s.getRank() == 1:
                    n = s.getShape()[0]
                    # use the current result's shape: earlier rank-1
                    # derivatives may already have added axes
                    shape = result.getShape()
                    out = result._arr.copy().reshape(shape + (1,)).repeat(
                        n, axis=result.getRank())
                    for d in range(n):
                        for idx in numpy.ndindex(shape):
                            index = idx + (d,)
                            out[index] = sympy.sympify(out[index]).diff(
                                s[d], **assumptions)
                    result = Symbol(out, dim=self.dim)
                else:
                    raise ValueError("diff: Only rank 0 and 1 supported")
            else:
                result = result.applyfunc(
                    lambda item: sympy.sympify(item).diff(s, **assumptions))
        return result

    def grad(self, where=None):
        """
        Returns the gradient, i.e. a new trailing axis of length ``self.dim``
        whose entry ``d`` is ``grad_n(component, d[, where])``.

        :raise ValueError: if ``where`` is a Symbol of rank > 0
        """
        if isinstance(where, Symbol):
            if where.getRank() > 0:
                raise ValueError("grad: 'where' must be a scalar symbol")
            where = where._arr.item()

        try:
            from .functions import grad_n
        except (ImportError, ValueError):
            # module not loaded as part of a package
            from functions import grad_n
        shape = self.getShape()
        out = self._arr.copy().reshape(shape + (1,)).repeat(
            self.dim, axis=self.getRank())
        for d in range(self.dim):
            for idx in numpy.ndindex(shape):
                index = idx + (d,)
                if where is None:
                    out[index] = grad_n(out[index], d)
                else:
                    out[index] = grad_n(out[index], d, where)
        return Symbol(out, dim=self.dim)

    def inverse(self):
        """
        Returns the inverse of a square rank-2 Symbol of size 1, 2 or 3.

        :raise ValueError: if the Symbol is not rank 2 or not square
        :raise ZeroDivisionError: if the determinant is known to be zero
        :raise TypeError: for sizes larger than 3
        """
        if not self.getRank() == 2:
            raise ValueError("inverse: Only rank 2 supported")
        s = self.getShape()
        if not s[0] == s[1]:
            raise ValueError("inverse: Only square shapes supported")
        out = numpy.zeros(s, object)
        arr = self._arr
        if s[0] == 1:
            # sympify so that plain numeric components support is_zero
            if sympy.sympify(arr[0, 0]).is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            out[0, 0] = 1. / arr[0, 0]
        elif s[0] == 2:
            A11 = arr[0, 0]
            A12 = arr[0, 1]
            A21 = arr[1, 0]
            A22 = arr[1, 1]
            D = A11 * A22 - A12 * A21
            if sympy.sympify(D).is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D = 1. / D
            out[0, 0] = A22 * D
            out[1, 0] = -A21 * D
            out[0, 1] = -A12 * D
            out[1, 1] = A11 * D
        elif s[0] == 3:
            A11 = arr[0, 0]
            A21 = arr[1, 0]
            A31 = arr[2, 0]
            A12 = arr[0, 1]
            A22 = arr[1, 1]
            A32 = arr[2, 1]
            A13 = arr[0, 2]
            A23 = arr[1, 2]
            A33 = arr[2, 2]
            D = (A11 * (A22 * A33 - A23 * A32) + A12 * (A31 * A23 - A21 * A33)
                 + A13 * (A21 * A32 - A31 * A22))
            if sympy.sympify(D).is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D = 1. / D
            out[0, 0] = (A22 * A33 - A23 * A32) * D
            out[1, 0] = (A31 * A23 - A21 * A33) * D
            out[2, 0] = (A21 * A32 - A31 * A22) * D
            out[0, 1] = (A13 * A32 - A12 * A33) * D
            out[1, 1] = (A11 * A33 - A31 * A13) * D
            out[2, 1] = (A12 * A31 - A11 * A32) * D
            out[0, 2] = (A12 * A23 - A13 * A22) * D
            out[1, 2] = (A13 * A21 - A11 * A23) * D
            out[2, 2] = (A11 * A22 - A12 * A21) * D
        else:
            raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
        return Symbol(out, dim=self.dim)

    def swap_axes(self, axis0, axis1):
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1), dim=self.dim)

    @staticmethod
    def _size(shape):
        """Returns the number of elements of an array with the given shape."""
        size = 1
        for extent in shape:
            size *= extent
        return size

    @staticmethod
    def _operandArray(other):
        """Returns the component array of a Symbol or array-like operand."""
        if isinstance(other, Symbol):
            return other._arr
        return numpy.asarray(other)

    @staticmethod
    def _objectMatmul(a, b):
        """
        Returns the matrix product of the 2-D arrays ``a`` and ``b`` as an
        object array, summing the element-wise products with ``numpy.sum``.
        """
        out = numpy.zeros((a.shape[0], b.shape[1]), object)
        for i0 in range(a.shape[0]):
            for i1 in range(b.shape[1]):
                out[i0, i1] = numpy.sum(a[i0, :] * b[:, i1])
        return out

    def tensorProduct(self, other, axis_offset):
        """
        Contracts the last ``axis_offset`` axes of self with the first
        ``axis_offset`` axes of ``other``.

        ``reshape`` (not ``resize``) is used, so inconsistent shapes raise a
        ValueError instead of being silently padded or truncated.
        """
        arg1 = Symbol._operandArray(other)
        sh0, sh1 = self.getShape(), arg1.shape
        outer0 = sh0[:self._arr.ndim - axis_offset]
        outer1 = sh1[axis_offset:]
        d01 = Symbol._size(sh1[:axis_offset])
        a = self._arr.reshape((Symbol._size(outer0), d01))
        b = arg1.reshape((d01, Symbol._size(outer1)))
        out = Symbol._objectMatmul(a, b)
        return Symbol(out.reshape(outer0 + outer1), dim=self.dim)

    def transposedTensorProduct(self, other, axis_offset):
        """
        Contracts the first ``axis_offset`` axes of self with the first
        ``axis_offset`` axes of ``other``.
        """
        arg1 = Symbol._operandArray(other)
        sh0, sh1 = self.getShape(), arg1.shape
        outer0 = sh0[axis_offset:]
        outer1 = sh1[axis_offset:]
        d01 = Symbol._size(sh1[:axis_offset])
        a = self._arr.reshape((d01, Symbol._size(outer0))).T
        b = arg1.reshape((d01, Symbol._size(outer1)))
        out = Symbol._objectMatmul(a, b)
        return Symbol(out.reshape(outer0 + outer1), dim=self.dim)

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Contracts the last ``axis_offset`` axes of self with the last
        ``axis_offset`` axes of ``other``.
        """
        arg1 = Symbol._operandArray(other)
        sh0, sh1, r1 = self.getShape(), arg1.shape, arg1.ndim
        outer0 = sh0[:self._arr.ndim - axis_offset]
        outer1 = sh1[:r1 - axis_offset]
        d01 = Symbol._size(sh1[r1 - axis_offset:])
        a = self._arr.reshape((Symbol._size(outer0), d01))
        b = arg1.reshape((Symbol._size(outer1), d01)).T
        out = Symbol._objectMatmul(a, b)
        return Symbol(out.reshape(outer0 + outer1), dim=self.dim)

    def trace(self, axis_offset):
        """
        Sums the diagonal over axes ``axis_offset`` and ``axis_offset+1``
        (assumed to have equal length) and removes both axes.
        """
        sh = self.getShape()
        n = sh[axis_offset]
        s1 = Symbol._size(sh[:axis_offset])
        s2 = Symbol._size(sh[axis_offset + 2:])
        arr_r = numpy.reshape(self._arr, (s1, n, n, s2))
        out = numpy.zeros([s1, s2], object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(n):
                    out[i1, i2] += arr_r[i1, j, j, i2]
        return Symbol(out.reshape(sh[:axis_offset] + sh[axis_offset + 2:]),
                      dim=self.dim)

    def transpose(self, axis_offset=None):
        """
        Moves the first ``axis_offset`` axes to the end (default: half the
        rank, rounded down).
        """
        if axis_offset is None:
            axis_offset = self._arr.ndim // 2
        axes = list(range(axis_offset, self._arr.ndim)) + list(range(axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes), dim=self.dim)

    def applyfunc(self, f):
        """
        Returns a new Symbol with ``f`` applied to every component.

        For rank 0, None is returned if ``f`` returns None.
        """
        assert callable(f)
        if self._arr.ndim == 0:
            el = f(self._arr.item())
            if el is None:
                return el
            return Symbol(el, dim=self.dim)
        out = numpy.empty(self.getShape(), dtype=object)
        for idx in numpy.ndindex(self.getShape()):
            out[idx] = f(self._arr[idx])
        return Symbol(out, dim=self.dim)

    def _sympy_(self):
        return self.applyfunc(sympy.sympify)

    def _ensureShapeCompatible(self, other):
        """
        Checks for compatible shapes for binary operations.
        Raises TypeError if not compatible.
        """
        sh0 = self.getShape()
        if isinstance(other, Symbol):
            sh1 = other.getShape()
        elif isinstance(other, numpy.ndarray):
            sh1 = other.shape
        elif isinstance(other, (int, float)) or isinstance(other, sympy.Basic):
            sh1 = ()
        else:
            raise TypeError("Unsupported argument type '%s' for binary operation"
                            % other.__class__.__name__)
        if not sh0 == sh1 and not sh0 == () and not sh1 == ():
            raise TypeError("Incompatible shapes for binary operation")

    @staticmethod
    def _symComp(sym):
        """
        Splits a component symbol name like ``[x]_0_1`` into ``('x', (0, 1))``.
        Other names are returned unchanged with an empty index tuple.
        """
        n = sym.name
        a = n.split('[')
        if len(a) != 2:
            return n, ()
        a = a[1].split(']')
        if len(a) != 2:
            return n, ()
        name = a[0]
        comps = [int(i) for i in a[1].split('_')[1:]]
        return name, tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        An integer following a symbol repeats that symbol, e.g.
        (x, 3, y) yields x, x, x, y. Non-Symbol arguments are sympified.
        No arguments yield nothing.
        """
        from itertools import repeat
        converted = [s if isinstance(s, Symbol) else sympy.sympify(s)
                     for s in symbols]
        for i, s in enumerate(converted):
            if isinstance(s, sympy.Integer):
                # repeat counts are consumed together with the preceding symbol
                continue
            if isinstance(s, (Symbol, sympy.Symbol)):
                # look at the following argument by position, not by value,
                # so (x, 2, x) is handled correctly
                next_s = converted[i + 1] if i + 1 < len(converted) else None
                if isinstance(next_s, sympy.Integer):
                    for copy_s in repeat(s, int(next_s)):
                        yield copy_s
                else:
                    yield s
            else:
                yield s

    # unary/binary operations follow

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr, dim=self.dim)

    def __abs__(self):
        return Symbol(abs(self._arr), dim=self.dim)

    def __add__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr + other._arr, dim=self.dim)
        return Symbol(self._arr + other, dim=self.dim)

    def __radd__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr + self._arr, dim=self.dim)
        return Symbol(other + self._arr, dim=self.dim)

    def __sub__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr - other._arr, dim=self.dim)
        return Symbol(self._arr - other, dim=self.dim)

    def __rsub__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr - self._arr, dim=self.dim)
        return Symbol(other - self._arr, dim=self.dim)

    def __mul__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr * other._arr, dim=self.dim)
        return Symbol(self._arr * other, dim=self.dim)

    def __rmul__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr * self._arr, dim=self.dim)
        return Symbol(other * self._arr, dim=self.dim)

    def __div__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr / other._arr, dim=self.dim)
        return Symbol(self._arr / other, dim=self.dim)

    def __rdiv__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr / self._arr, dim=self.dim)
        return Symbol(other / self._arr, dim=self.dim)

    # Python 3 (and 'from __future__ import division') use these names for '/'
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(self._arr ** other._arr, dim=self.dim)
        return Symbol(self._arr ** other, dim=self.dim)

    def __rpow__(self, other):
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return Symbol(other._arr ** self._arr, dim=self.dim)
        return Symbol(other ** self._arr, dim=self.dim)
546
547
def symbols(*names, **kwargs):
    """
    Emulates the behaviour of sympy.symbols.

    ``names[0]`` is either a string of names separated by whitespace and/or
    commas, or a list/tuple of names. Empty names are skipped. The keyword
    ``shape`` (default ``()``) and all remaining keyword arguments are passed
    to the Symbol constructor for every name.

    :return: None if no names remain, a single Symbol for exactly one name,
             otherwise a tuple of Symbols
    """
    import re
    shape = kwargs.pop('shape', ())

    s = names[0]
    if not isinstance(s, (list, tuple)):
        # raw string: '\s' is an invalid escape in a normal string literal
        s = re.split(r'\s|,', s)
    res = tuple(Symbol(t, shape, **kwargs) for t in s if t)
    if len(res) == 0:  # var('')
        return None
    if len(res) == 1:  # var('x')
        return res[0]
    # otherwise var('a b ...')
    return res
573
def combineData(array, shape):
    """
    Assembles the components in ``array`` into an object of shape ``shape``.

    A value without ``__len__`` combined with ``shape == ()`` is returned
    unchanged. Otherwise, if any component is an escript ``Data`` object, the
    result is a ``Data`` object on one of the function spaces found. If none
    is, the result is a numpy float array.

    :raise ValueError: if the Data components live on different domains
    """
    # array could just be a single value
    if shape == () and not hasattr(array, '__len__'):
        return array

    from esys.escript import Data
    components = numpy.array(array)  # for indexing

    # collect function spaces and domains of all Data components
    spaces = set()
    domains = set()
    for pos in numpy.ndindex(shape):
        item = components[pos]
        if isinstance(item, Data):
            spaces.add(item.getFunctionSpace())
            domains.add(item.getDomain())

    if len(domains) > 1:
        reference = domains.pop()
        if any(reference != other for other in domains):
            raise ValueError("Mixing of domains not supported")

    if not spaces:
        result = numpy.zeros(shape)
    else:
        result = Data(0., shape, spaces.pop())  # FIXME: interpolate instead of using first?

    # Element-wise assignment; accumulating item*unit_tensor terms was
    # much slower.
    for pos in numpy.ndindex(shape):
        item = components[pos]
        is_zero_dim = hasattr(item, "ndim") and item.ndim == 0
        result[pos] = float(item) if is_zero_dim else item
    return result
609
610
class SymFunction(Symbol):
    """
    Base class for symbolic functions.

    The underlying scalar Symbol is named after the concrete class. The
    positional constructor arguments are kept in ``self.args`` and rendered as
    ``Name(arg0, arg1, ...)`` by ``__str__``, ``__repr__`` and ``lambdarepr``.
    """
    def __init__(self, *args, **kwargs):
        """
        Initializes a new symbolic function object.

        :param args: the function arguments (stored in ``self.args``)
        :param kwargs: passed to the Symbol constructor (e.g. ``dim``)
        """
        super(SymFunction, self).__init__(self.__class__.__name__, **kwargs)
        self.args = args

    def _callString(self, parts):
        """Formats ``Name(part0, part1, ...)`` for this function."""
        return self.name + "(" + ", ".join(parts) + ")"

    def __repr__(self):
        return self._callString([str(a) for a in self.args])

    def __str__(self):
        return self._callString([str(a) for a in self.args])

    def lambdarepr(self):
        # NOTE(review): assumes every argument provides lambdarepr()
        return self._callString([a.lambdarepr() for a in self.args])

    def atoms(self, *types):
        """
        Returns the set of atoms of all arguments.

        Component symbols such as ``[x]_0_1`` are reported by their base name
        (``x``). Arguments without an ``atoms`` method (e.g. plain numbers)
        contribute nothing.
        """
        s = set()
        for el in self.args:
            if not hasattr(el, 'atoms'):
                continue
            for a in el.atoms(*types):
                if a.is_Symbol:
                    n, c = Symbol._symComp(a)
                    s.add(sympy.Symbol(n))
                else:
                    s.add(a)
        return s

    def __neg__(self):
        res = self.__class__(*self.args)
        # the constructor does not know our dimensionality
        res.dim = self.dim
        # negate our own components so that -(-f) restores the original sign
        res._arr = -self._arr
        return res
646

  ViewVC Help
Powered by ViewVC 1.1.26