/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Contents of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory | Revision Log


Revision 3518 - (show annotations)
Fri May 20 06:29:31 2011 UTC (8 years, 5 months ago) by caltinay
Original Path: branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py
File MIME type: text/x-python
File size: 19692 byte(s)
Some cleanup and added tests for symbolic div and grad.

1 # -*- coding: utf-8 -*-
2
3 ########################################################
4 #
5 # Copyright (c) 2003-2010 by University of Queensland
6 # Earth Systems Science Computational Center (ESSCC)
7 # http://www.uq.edu.au/esscc
8 #
9 # Primary Business: Queensland, Australia
10 # Licensed under the Open Software License version 3.0
11 # http://www.opensource.org/licenses/osl-3.0.php
12 #
13 ########################################################
14
15 __copyright__="""Copyright (c) 2003-2010 by University of Queensland
16 Earth Systems Science Computational Center (ESSCC)
17 http://www.uq.edu.au/esscc
18 Primary Business: Queensland, Australia"""
19 __license__="""Licensed under the Open Software License version 3.0
20 http://www.opensource.org/licenses/osl-3.0.php"""
21 __url__="https://launchpad.net/escript-finley"
22
23 """
24 :var __author__: name of author
25 :var __copyright__: copyrights
26 :var __license__: licence agreement
27 :var __url__: url entry point on documentation
28 :var __version__: version
29 :var __date__: date of the version
30 """
31
32 import numpy
33 import sympy
34
35 __author__="Cihan Altinay"
36
37
class Symbol(object):
    """
    Symbolic tensor of order 0 to 4 backed by a numpy object array.

    The components live in ``self._arr`` and are normally sympy objects.
    For a named Symbol of non-empty shape the components are sympy symbols
    called ``[name]_i_j...``, which `_symComp` splits back into the name and
    the component indices. ``self.name`` holds the string form of the
    component(s).
    """

    def __init__(self, *args, **kwargs):
        """
        Initializes a new Symbol object.

        Supported call forms:

        - ``Symbol(name, **kwargs)``: scalar sympy symbol ``name``; `kwargs`
          are passed on to ``sympy.Symbol``.
        - ``Symbol(name, shape, **kwargs)``: `shape` is a tuple of length 0
          to 4. An empty shape behaves like the first form; otherwise the
          components are created with ``sympy.symarray`` using the prefix
          ``[name]`` and `kwargs` are not used.
        - ``Symbol(obj)``: `obj` provides ``__array__`` (the array is
          copied), or is a list or a ``sympy.Basic`` (wrapped in a numpy
          array).

        :raise TypeError: for an unsupported number or type of arguments,
                          or if the name contains '[' or ']'
        :raise ValueError: if the requested tensor order exceeds 4
        """
        if len(args) == 1:
            arg = args[0]
            if isinstance(arg, str):
                Symbol._checkName(arg)
                self._arr = numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__"):
                arr = arg.__array__()
                if len(arr.shape) > 4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                self._arr = arr.copy()
            elif isinstance(arg, list) or isinstance(arg, sympy.Basic):
                self._arr = numpy.array(arg)
            else:
                raise TypeError("Unsupported argument type %s" % str(type(arg)))
        elif len(args) == 2:
            name, shape = args
            if not isinstance(name, str):
                raise TypeError("First argument must be a string")
            Symbol._checkName(name)
            if not isinstance(shape, tuple):
                raise TypeError("Second argument must be a tuple")
            if len(shape) > 4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape) == 0:
                self._arr = numpy.array(sympy.Symbol(name, **kwargs))
            else:
                # Pass by keyword: current sympy declares symarray(prefix,
                # shape, ...), while this call originally passed shape first.
                self._arr = sympy.symarray(prefix='[' + name + ']', shape=shape)
        else:
            raise TypeError("Unsupported number of arguments")
        if self._arr.ndim == 0:
            self.name = str(self._arr.item())
        else:
            self.name = str(self._arr.tolist())

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        """
        Returns True if `other` is a Symbol of exactly the same type and
        shape whose components all compare equal to this Symbol's.
        """
        if type(self) is not type(other):
            return False
        if self.getShape() != other.getShape():
            return False
        return bool((self._arr == other._arr).all())

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self.__eq__(other)

    def __getitem__(self, key):
        """Returns the raw component(s) selected by `key` (not a Symbol)."""
        return self._arr[key]

    def __setitem__(self, key, value):
        """
        Sets the component(s) selected by `key` to `value`.

        A `Symbol` value must be scalar or have exactly the shape of the
        selected sub-array; its components (not the Symbol object itself)
        are stored. ``sympy.Basic`` values are stored as-is, array-like
        values are sympified element-wise and anything else is sympified.

        :raise ValueError: if a Symbol value has the wrong shape
        """
        if isinstance(value, Symbol):
            if value.getRank() == 0:
                self._arr[key] = value._arr.item()
            elif getattr(self._arr[key], "shape", None) == value.getShape():
                self._arr[key] = value._arr
            else:
                raise ValueError("Wrong shape of value")
        elif isinstance(value, sympy.Basic):
            self._arr[key] = value
        elif hasattr(value, "__array__"):
            src = numpy.asarray(value)
            # keep the source shape so multi-dimensional slices line up
            converted = numpy.empty(src.shape, dtype=object)
            for idx in numpy.ndindex(src.shape):
                converted[idx] = sympy.sympify(src[idx])
            self._arr[key] = converted
        else:
            self._arr[key] = sympy.sympify(value)

    def getRank(self):
        """Returns the tensor order (number of array dimensions)."""
        return self._arr.ndim

    def getShape(self):
        """Returns the shape tuple of the component array."""
        return self._arr.shape

    def atoms(self, *types):
        """
        Returns the set of atoms (restricted to the sympy `types`, if given)
        over all components. Component symbols such as ``[x]_0_1`` are
        reported by their base name, i.e. as ``sympy.Symbol('x')``.
        """
        result = set()
        for el in self._arr.flat:
            for a in el.atoms(*types):
                if a.is_Symbol:
                    result.add(sympy.Symbol(Symbol._symComp(a)[0]))
                else:
                    result.add(a)
        return result

    def _sympystr_(self, printer):
        """Printer hook; returns `lambdarepr()` and ignores `printer`."""
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string form of this Symbol suitable for evaluation.

        Each component is printed with sympy's ``lambdarepr`` and component
        symbol names such as ``[x]_0_1`` are rewritten to ``x[0,1]``. A
        rank-0 Symbol yields the component string itself; otherwise the
        result is ``'combineData(<nested list>,<shape>)'``.
        """
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr = numpy.empty(self.getShape(), dtype=object)
        for idx, el in numpy.ndenumerate(self._arr):
            # map names like [x]_0_0 to x[0,0]
            symdict = {}
            for a in el.atoms(sympy.Symbol):
                n, c = Symbol._symComp(a)
                if len(c) > 0:
                    symdict[a.name] = n + '[' + ','.join(str(i) for i in c) + ']'
                else:
                    symdict[a.name] = n
            s = lambdarepr(el)
            # Replace longer names first: "[x]_1" is a prefix of "[x]_10"
            # and must not be substituted inside it.
            for key in sorted(symdict, key=len, reverse=True):
                s = s.replace(key, symdict[key])
            temp_arr[idx] = s
        if self.getRank() == 0:
            return temp_arr.item()
        return 'combineData(%s,%s)' % (str(temp_arr.tolist()).replace("'", ""),
                                       str(self.getShape()))

    def diff(self, *symbols, **assumptions):
        """
        Returns the derivative of this Symbol with respect to `symbols`.

        `symbols` follows the sympy convention, e.g. ``diff(x, 2, y)``
        differentiates twice w.r.t. x and once w.r.t. y. A rank>0 `Symbol`
        in `symbols` must have this Symbol's shape and is applied
        component-wise; scalar variables are applied to every component.
        `assumptions` are passed on to each component's ``diff``.

        :raise ValueError: on incompatible shapes
        """
        result = Symbol(self._arr)
        for s in Symbol._symbolgen(*symbols):
            if isinstance(s, Symbol) and s.getRank() > 0:
                if s.getShape() != self.getShape():
                    raise ValueError("diff: Incompatible shapes")
                flat = result._arr.flat
                for idx, var in enumerate(s._arr.flat):
                    flat[idx] = flat[idx].diff(var, **assumptions)
            else:
                var = s._arr.item() if isinstance(s, Symbol) else s
                result = result.applyfunc(
                    lambda item, v=var: item.diff(v, **assumptions))
        return result

    def grad(self, where=None, dim=2):
        """
        Returns the gradient as a new Symbol of shape
        ``self.getShape() + (dim,)``.

        Component ``[..., d]`` is ``grad_n(component, d)``, or
        ``grad_n(component, d, where)`` when `where` is given, using
        ``grad_n`` from the sibling ``functions`` module.

        :param where: optional; a scalar Symbol is unwrapped to its single
                      component before being passed to ``grad_n``
        :param dim: number of gradient directions (default 2)
        :raise ValueError: if `where` is a Symbol of rank > 0
        """
        if isinstance(where, Symbol):
            if where.getRank() > 0:
                raise ValueError("grad: 'where' must be a scalar symbol")
            where = where._arr.item()
        try:
            from .functions import grad_n
        except (ImportError, ValueError):
            # not imported as part of a package (Python 2 style import)
            from functions import grad_n
        rank = self.getRank()
        out = self._arr.reshape(self.getShape() + (1,)).repeat(dim, axis=rank)
        for d in range(dim):
            for idx in numpy.ndindex(self.getShape()):
                index = idx + (d,)
                if where is None:
                    out[index] = grad_n(out[index], d)
                else:
                    out[index] = grad_n(out[index], d, where)
        return Symbol(out)

    def inverse(self):
        """
        Returns the inverse of this square rank-2 Symbol of size 1, 2 or 3,
        computed with explicit cofactor formulas.

        :raise ValueError: if the Symbol is not a square rank-2 tensor
        :raise ZeroDivisionError: if sympy reports the determinant as zero
                                  (``is_zero``)
        :raise TypeError: for matrix sizes other than 1, 2 or 3
        """
        if not self.getRank() == 2:
            raise ValueError("inverse: Only rank 2 supported")
        s = self.getShape()
        if not s[0] == s[1]:
            raise ValueError("inverse: Only square shapes supported")
        out = numpy.zeros(s, object)
        arr = self._arr
        if s[0] == 1:
            if arr[0, 0].is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            out[0, 0] = 1. / arr[0, 0]
        elif s[0] == 2:
            A11 = arr[0, 0]
            A12 = arr[0, 1]
            A21 = arr[1, 0]
            A22 = arr[1, 1]
            D = A11 * A22 - A12 * A21
            if D.is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D = 1. / D
            out[0, 0] = A22 * D
            out[1, 0] = -A21 * D
            out[0, 1] = -A12 * D
            out[1, 1] = A11 * D
        elif s[0] == 3:
            A11 = arr[0, 0]
            A21 = arr[1, 0]
            A31 = arr[2, 0]
            A12 = arr[0, 1]
            A22 = arr[1, 1]
            A32 = arr[2, 1]
            A13 = arr[0, 2]
            A23 = arr[1, 2]
            A33 = arr[2, 2]
            D = A11 * (A22 * A33 - A23 * A32) + A12 * (A31 * A23 - A21 * A33) \
                + A13 * (A21 * A32 - A31 * A22)
            if D.is_zero:
                raise ZeroDivisionError("inverse: Symbol not invertible")
            D = 1. / D
            out[0, 0] = (A22 * A33 - A23 * A32) * D
            out[1, 0] = (A31 * A23 - A21 * A33) * D
            out[2, 0] = (A21 * A32 - A31 * A22) * D
            out[0, 1] = (A13 * A32 - A12 * A33) * D
            out[1, 1] = (A11 * A33 - A31 * A13) * D
            out[2, 1] = (A12 * A31 - A11 * A32) * D
            out[0, 2] = (A12 * A23 - A13 * A22) * D
            out[1, 2] = (A13 * A21 - A11 * A23) * D
            out[2, 2] = (A11 * A22 - A12 * A21) * D
        else:
            raise TypeError("inverse: Only matrix dimensions 1,2,3 are supported")
        return Symbol(out)

    def swap_axes(self, axis0, axis1):
        """Returns a new Symbol with axes `axis0` and `axis1` interchanged."""
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1))

    def tensorProduct(self, other, axis_offset):
        """
        Contracts the last `axis_offset` axes of this Symbol with the first
        `axis_offset` axes of `other` (a Symbol or array-like).

        :return: Symbol of shape
                 ``self.getShape()[:rank-axis_offset] + other_shape[axis_offset:]``
        :raise ValueError: (from numpy reshape) if the contracted component
                           counts of both operands differ
        """
        arr0 = self._arr
        arr1 = Symbol._asArray(other)
        keep0 = arr0.shape[:arr0.ndim - axis_offset]
        keep1 = arr1.shape[axis_offset:]
        d01 = Symbol._prod(arr1.shape[:axis_offset])
        a = arr0.reshape((Symbol._prod(keep0), d01))
        b = arr1.reshape((d01, Symbol._prod(keep1)))
        return Symbol(Symbol._matrixProduct(a, b).reshape(keep0 + keep1))

    def transposedTensorProduct(self, other, axis_offset):
        """
        Contracts the first `axis_offset` axes of this Symbol with the first
        `axis_offset` axes of `other` (a Symbol or array-like).

        :return: Symbol of shape
                 ``self.getShape()[axis_offset:] + other_shape[axis_offset:]``
        :raise ValueError: (from numpy reshape) if the contracted component
                           counts of both operands differ
        """
        arr0 = self._arr
        arr1 = Symbol._asArray(other)
        keep0 = arr0.shape[axis_offset:]
        keep1 = arr1.shape[axis_offset:]
        d01 = Symbol._prod(arr1.shape[:axis_offset])
        a = arr0.reshape((d01, Symbol._prod(keep0))).T
        b = arr1.reshape((d01, Symbol._prod(keep1)))
        return Symbol(Symbol._matrixProduct(a, b).reshape(keep0 + keep1))

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Contracts the last `axis_offset` axes of this Symbol with the last
        `axis_offset` axes of `other` (a Symbol or array-like).

        :return: Symbol of shape
                 ``self.getShape()[:rank0-axis_offset] + other_shape[:rank1-axis_offset]``
        :raise ValueError: (from numpy reshape) if the contracted component
                           counts of both operands differ
        """
        arr0 = self._arr
        arr1 = Symbol._asArray(other)
        r1 = arr1.ndim
        keep0 = arr0.shape[:arr0.ndim - axis_offset]
        keep1 = arr1.shape[:r1 - axis_offset]
        d01 = Symbol._prod(arr1.shape[r1 - axis_offset:])
        a = arr0.reshape((Symbol._prod(keep0), d01))
        b = arr1.reshape((Symbol._prod(keep1), d01)).T
        return Symbol(Symbol._matrixProduct(a, b).reshape(keep0 + keep1))

    def trace(self, axis_offset):
        """
        Returns the trace over axes `axis_offset` and `axis_offset+1`, which
        must have equal length; the result drops these two axes.
        """
        sh = self.getShape()
        n = sh[axis_offset]
        s1 = Symbol._prod(sh[:axis_offset])
        s2 = Symbol._prod(sh[axis_offset + 2:])
        arr_r = self._arr.reshape((s1, n, n, s2))
        out = numpy.zeros((s1, s2), object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(n):
                    out[i1, i2] += arr_r[i1, j, j, i2]
        return Symbol(out.reshape(sh[:axis_offset] + sh[axis_offset + 2:]))

    def transpose(self, axis_offset=None):
        """
        Returns a new Symbol whose axes are rotated so that axis `axis_offset`
        comes first. `axis_offset` defaults to half the rank (rounded down).
        """
        ndim = self._arr.ndim
        if axis_offset is None:
            axis_offset = ndim // 2
        axes = list(range(axis_offset, ndim)) + list(range(0, axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes))

    def applyfunc(self, f):
        """
        Returns a new Symbol whose components are ``f(component)``.
        For a rank-0 Symbol the result is ``Symbol(f(value))``.
        """
        assert callable(f)
        if self._arr.ndim == 0:
            return Symbol(f(self._arr.item()))
        out = numpy.empty(self.getShape(), dtype=object)
        for idx, el in numpy.ndenumerate(self._arr):
            out[idx] = f(el)
        return Symbol(out)

    def _sympy_(self):
        # NOTE(review): returns a Symbol (not a sympy.Basic) with every
        # component sympified -- confirm this is what callers of
        # sympy.sympify expect.
        return self.applyfunc(sympy.sympify)

    def _ensureShapeCompatible(self, other):
        """
        Checks for compatible shapes for binary operations.
        Raises TypeError if not compatible.

        `other` may be a Symbol, a numpy array, a number (including numpy
        scalars) or a sympy object; numbers and sympy objects count as
        scalars. Shapes are compatible if equal or if either one is scalar.
        """
        import numbers
        sh0 = self.getShape()
        if isinstance(other, Symbol):
            sh1 = other.getShape()
        elif isinstance(other, numpy.ndarray):
            sh1 = other.shape
        elif isinstance(other, numbers.Number):
            sh1 = ()
        elif isinstance(other, sympy.Basic):
            sh1 = ()
        else:
            raise TypeError("Unsupported argument type '%s' for binary operation" % other.__class__.__name__)
        if sh0 != sh1 and sh0 != () and sh1 != ():
            raise TypeError("Incompatible shapes for binary operation")

    def _operandArray(self, other):
        """Validates `other` for a binary operation and returns its raw value."""
        self._ensureShapeCompatible(other)
        if isinstance(other, Symbol):
            return other._arr
        return other

    @staticmethod
    def _checkName(name):
        """Raises TypeError if `name` contains '[' or ']' (reserved for components)."""
        if '[' in name or ']' in name:
            raise TypeError("Name must not contain '[' or ']'")

    @staticmethod
    def _prod(dims):
        """Returns the product of the integers in `dims` (1 for empty)."""
        result = 1
        for d in dims:
            result *= d
        return result

    @staticmethod
    def _asArray(other):
        """Returns the component array of a Symbol, or `other` as an array."""
        if isinstance(other, Symbol):
            return other._arr
        return numpy.asarray(other)

    @staticmethod
    def _matrixProduct(a, b):
        """
        Returns the object array ``out[i, j] = sum(a[i, :] * b[:, j])`` for
        2-d arrays `a` (d0 x d01) and `b` (d01 x d1).
        """
        out = numpy.zeros((a.shape[0], b.shape[1]), object)
        for i0 in range(a.shape[0]):
            for i1 in range(b.shape[1]):
                out[i0, i1] = numpy.sum(a[i0, :] * b[:, i1])
        return out

    @staticmethod
    def _symComp(sym):
        """
        Splits a component symbol name ``[name]_i_j...`` into
        ``(name, (i, j, ...))``. Names not of that form are returned
        unchanged with an empty index tuple.
        """
        n = sym.name
        a = n.split('[')
        if len(a) != 2:
            return n, ()
        a = a[1].split(']')
        if len(a) != 2:
            return n, ()
        comps = [int(i) for i in a[1].split('_')[1:]]
        return a[0], tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        Non-Symbol arguments are sympified. An integer that follows a symbol
        is a repeat count for that symbol; integers are never yielded
        themselves. No arguments yield nothing.

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 2, x)
        (x, x, x)
        """
        from itertools import repeat
        items = [s if isinstance(s, Symbol) else sympy.sympify(s)
                 for s in symbols]
        for i, s in enumerate(items):
            if isinstance(s, sympy.Integer):
                continue
            if isinstance(s, Symbol) or isinstance(s, sympy.Symbol):
                # look ahead by position (not by value) for a repeat count
                next_s = items[i + 1] if i + 1 < len(items) else None
                if isinstance(next_s, sympy.Integer):
                    for copy_s in repeat(s, int(next_s)):
                        yield copy_s
                    continue
            yield s

    # unary/binary operations follow

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr)

    def __abs__(self):
        return Symbol(abs(self._arr))

    def __add__(self, other):
        return Symbol(self._arr + self._operandArray(other))

    def __radd__(self, other):
        return Symbol(self._operandArray(other) + self._arr)

    def __sub__(self, other):
        return Symbol(self._arr - self._operandArray(other))

    def __rsub__(self, other):
        return Symbol(self._operandArray(other) - self._arr)

    def __mul__(self, other):
        return Symbol(self._arr * self._operandArray(other))

    def __rmul__(self, other):
        return Symbol(self._operandArray(other) * self._arr)

    def __div__(self, other):
        return Symbol(self._arr / self._operandArray(other))

    def __rdiv__(self, other):
        return Symbol(self._operandArray(other) / self._arr)

    # Python 3 (and `from __future__ import division`) use true division.
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(self, other):
        return Symbol(self._arr ** self._operandArray(other))

    def __rpow__(self, other):
        return Symbol(self._operandArray(other) ** self._arr)
503
504
def symbols(*names, **kwargs):
    """
    Emulates the behaviour of sympy.symbols.

    Only ``names[0]`` is used: either a string of names separated by
    whitespace and/or commas, or a list/tuple of names. Empty names are
    skipped. The optional ``shape`` keyword (a tuple or list, default
    ``()``) is the shape of every created Symbol; the remaining keyword
    arguments are passed on to `Symbol`.

    :return: None if there are no names, the Symbol itself for exactly one
             name, otherwise a tuple of Symbols
    """
    import re
    if not names:
        return None
    shape = kwargs.pop('shape', ())
    if isinstance(shape, list):
        shape = tuple(shape)

    spec = names[0]
    if isinstance(spec, (list, tuple)):
        tokens = spec
    else:
        # raw string: '\s' is an invalid escape in a plain string literal
        tokens = re.split(r'\s|,', spec)
    res = tuple(Symbol(t, shape, **kwargs) for t in tokens if t)
    if len(res) == 0:  # var('')
        return None
    if len(res) == 1:  # var('x')
        return res[0]
    # otherwise var('a b ...')
    return res
530
def combineData(array, shape):
    """
    Assembles the (possibly nested) sequence `array` into a single object of
    the given `shape`.

    If any component is an escript ``Data`` object, the result is a ``Data``
    object on the function space of the first such component in index
    order; otherwise the result is a float numpy array. A value without
    ``__len__`` combined with ``shape == ()`` is returned unchanged.

    :param array: nested sequence of components (or a single value)
    :param shape: target shape tuple
    :raise ValueError: if Data components live on different domains
    """
    # array could just be a single value
    if not hasattr(array, '__len__') and shape == ():
        return array

    try:
        from esys.escript import Data
    except ImportError:
        # Without escript there can be no Data components; numpy assembly
        # below still works.
        Data = None

    n = numpy.array(array)  # for indexing

    # find domain(s) and the function space to use, if any
    domains = set()
    first_fs = None
    if Data is not None:
        for idx in numpy.ndindex(shape):
            comp = n[idx]
            if isinstance(comp, Data):
                if first_fs is None:
                    first_fs = comp.getFunctionSpace()
                domains.add(comp.getDomain())

    if len(domains) > 1:
        reference = domains.pop()
        if any(reference != other for other in domains):
            raise ValueError("Mixing of domains not supported")

    if first_fs is not None:
        # FIXME: interpolate instead of using the first function space?
        d = Data(0., shape, first_fs)
    else:
        d = numpy.zeros(shape)
    for idx in numpy.ndindex(shape):
        comp = n[idx]
        # Component-wise assignment; building d via d += n[idx]*unit_tensor
        # was found to be much slower. 0-d values are converted to float.
        if hasattr(comp, "ndim") and comp.ndim == 0:
            d[idx] = float(comp)
        else:
            d[idx] = comp
    return d
566
567
class SymFunction(Symbol):
    """
    Base class for symbolic functions applied to arguments.

    The function itself is a scalar Symbol named after the concrete subclass
    (``self.__class__.__name__``); the constructor arguments are kept
    unchanged in ``self.args``.
    """
    def __init__(self, *args, **kwargs):
        """
        Initializes a new symbolic function object.

        :param args: the function arguments, stored as ``self.args``
        :param kwargs: passed on to `Symbol` together with the class name
        """
        super(SymFunction, self).__init__(self.__class__.__name__, **kwargs)
        self.args = args

    @staticmethod
    def _argAsSymbol(arg):
        """Returns `arg` if it is a Symbol, else ``Symbol(sympy.sympify(arg))``."""
        if isinstance(arg, Symbol):
            return arg
        return Symbol(sympy.sympify(arg))

    def _formatCall(self, parts):
        """Returns ``name(part0, part1, ...)``."""
        return self.name + "(" + ", ".join(parts) + ")"

    def __repr__(self):
        return self._formatCall([str(a) for a in self.args])

    def __str__(self):
        return self._formatCall([str(a) for a in self.args])

    def lambdarepr(self):
        """
        Returns ``name(...)`` with each argument in its lambdarepr form.
        Non-Symbol arguments (sympy expressions, numbers) are sympified and
        wrapped in a Symbol first.
        """
        return self._formatCall(
            [self._argAsSymbol(a).lambdarepr() for a in self.args])

    def atoms(self, *types):
        """
        Returns the union of the atoms of all arguments, with component
        symbols reported by their base name as in `Symbol.atoms`.
        Non-Symbol arguments are sympified and wrapped in a Symbol first.
        """
        result = set()
        for arg in self.args:
            result.update(self._argAsSymbol(arg).atoms(*types))
        return result

    def __neg__(self):
        # NOTE(review): only the underlying _arr is negated; `name` and
        # `args` are unchanged, so __str__/__repr__/lambdarepr of the result
        # do not show the sign -- confirm callers rely on _arr for this.
        res = self.__class__(*self.args)
        res._arr = -res._arr
        return res
603

  ViewVC Help
Powered by ViewVC 1.1.26