/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Contents of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory | Revision Log


Revision 3512 - (show annotations)
Wed May 18 06:22:46 2011 UTC (8 years, 5 months ago) by caltinay
Original Path: branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py
File MIME type: text/x-python
File size: 15772 byte(s)
Implementation of symbolic grad() and a few fixes.

1 # -*- coding: utf-8 -*-
2
3 ########################################################
4 #
5 # Copyright (c) 2003-2010 by University of Queensland
6 # Earth Systems Science Computational Center (ESSCC)
7 # http://www.uq.edu.au/esscc
8 #
9 # Primary Business: Queensland, Australia
10 # Licensed under the Open Software License version 3.0
11 # http://www.opensource.org/licenses/osl-3.0.php
12 #
13 ########################################################
14
15 __copyright__="""Copyright (c) 2003-2010 by University of Queensland
16 Earth Systems Science Computational Center (ESSCC)
17 http://www.uq.edu.au/esscc
18 Primary Business: Queensland, Australia"""
19 __license__="""Licensed under the Open Software License version 3.0
20 http://www.opensource.org/licenses/osl-3.0.php"""
21 __url__="https://launchpad.net/escript-finley"
22
23 """
24 :var __author__: name of author
25 :var __copyright__: copyrights
26 :var __license__: licence agreement
27 :var __url__: url entry point on documentation
28 :var __version__: version
29 :var __date__: date of the version
30 """
31
32 import numpy
33 import sympy
34
35 __author__="Cihan Altinay"
36
37
class Symbol(object):
    """
    Symbolic tensor: a thin wrapper around a numpy array (``self._arr``)
    whose entries are normally sympy expressions.

    A scalar symbol is stored as a 0-dimensional array. Components of a
    shaped symbol created from a name are individual ``sympy.Symbol``
    objects named ``[name]_i_j...``; `_symComp` decodes such names again.

    :ivar _arr: the underlying numpy array
    :ivar name: string form of the array contents, computed once in
                ``__init__`` (it is not refreshed by ``__setitem__``)
    """

    def __init__(self, *args, **kwargs):
        """
        Initializes a new Symbol object.

        Supported call forms:

        - ``Symbol(name)`` -- scalar symbol called ``name``; ``kwargs`` are
          forwarded to ``sympy.Symbol``.
        - ``Symbol(obj)`` -- ``obj`` either provides ``__array__`` (the
          array is copied and may have at most 4 dimensions), or is a list
          or a ``sympy.Basic`` instance. ``kwargs`` are not used here.
        - ``Symbol(name, shape)`` -- ``shape`` is a tuple of at most 4
          entries; an empty tuple gives a scalar symbol, otherwise one
          component symbol per index is created and ``kwargs`` are
          forwarded to each of them.

        :raise TypeError: if the number or the types of the arguments are
                          not supported or the name contains '[' or ']'
        :raise ValueError: if more than 4 dimensions are requested
        """
        if len(args) == 1:
            arg = args[0]
            if isinstance(arg, str):
                Symbol._checkName(arg)
                self._arr = numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__"):
                arr = arg.__array__()
                if len(arr.shape) > 4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                self._arr = arr.copy()
            elif isinstance(arg, list) or isinstance(arg, sympy.Basic):
                self._arr = numpy.array(arg)
            else:
                raise TypeError("Unsupported argument type %s" % str(type(arg)))
        elif len(args) == 2:
            name, shape = args
            if not isinstance(name, str):
                raise TypeError("First argument must be a string")
            Symbol._checkName(name)
            if not isinstance(shape, tuple):
                raise TypeError("Second argument must be a tuple")
            if len(shape) > 4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape) == 0:
                self._arr = numpy.array(sympy.Symbol(name, **kwargs))
            else:
                self._arr = Symbol._componentArray(name, shape, **kwargs)
        else:
            raise TypeError("Unsupported number of arguments")

        if self._arr.ndim == 0:
            self.name = str(self._arr.item())
        else:
            self.name = str(self._arr.tolist())

    @staticmethod
    def _checkName(name):
        """
        Raises ``TypeError`` if ``name`` contains a square bracket; brackets
        are reserved for the generated component names.
        """
        if name.find('[') >= 0 or name.find(']') >= 0:
            raise TypeError("Name must not contain '[' or ']'")

    @staticmethod
    def _componentArray(name, shape, **kwargs):
        """
        Returns an object array of the given shape filled with sympy symbols
        named ``[name]_i_j...`` (the format `_symComp` parses).

        The array is built here rather than through ``sympy.symarray`` so
        that the result does not depend on the argument order of that
        function, which is not the same in all sympy releases.
        """
        arr = numpy.empty(shape, dtype=object)
        for idx in numpy.ndindex(shape):
            suffix = '_'.join([str(i) for i in idx])
            arr[idx] = sympy.Symbol('[%s]_%s' % (name, suffix), **kwargs)
        return arr

    @staticmethod
    def _raw(value):
        """
        Returns the underlying array if ``value`` is a Symbol and ``value``
        itself otherwise.
        """
        if isinstance(value, Symbol):
            return value._arr
        return value

    @staticmethod
    def _count(dims):
        """
        Returns the product of the entries of ``dims`` (1 if it is empty).
        """
        total = 1
        for n in dims:
            total *= n
        return total

    @staticmethod
    def _contract(left, right, shape):
        """
        Multiplies the 2-D arrays ``left`` (d0 x d01) and ``right``
        (d01 x d1) entry by entry into an object array and returns the
        result as a Symbol reshaped to ``shape``.
        """
        d0 = left.shape[0]
        d1 = right.shape[1]
        out = numpy.zeros((d0, d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0, i1] = numpy.sum(left[i0, :] * right[:, i1])
        return Symbol(out.reshape(shape))

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        """
        Two Symbols are equal if they have the same type, rank and shape and
        all their entries compare equal.
        """
        if type(self) is not type(other):
            return False
        if self.getRank() != other.getRank():
            return False
        if self.getShape() != other.getShape():
            return False
        return bool((self._arr == other._arr).all())

    def __ne__(self, other):
        # defined explicitly so that != is always the negation of ==
        # (Python 2 does not derive it from __eq__)
        return not self.__eq__(other)

    def __getitem__(self, key):
        """
        Returns ``self._arr[key]``, i.e. the raw entry or sub-array and not
        a Symbol.
        """
        return self._arr[key]

    def __setitem__(self, key, value):
        """
        Assigns ``value`` to the entries selected by ``key``.

        A Symbol contributes its expressions (not the wrapper object); a
        shaped Symbol must match the shape of the selected entries. sympy
        objects are stored as they are, anything else is passed through
        ``sympy.sympify`` first.

        :raise ValueError: if a shaped Symbol does not fit the target
        """
        if isinstance(value, Symbol):
            if value.getRank() == 0:
                self._arr[key] = value._arr.item()
                return
            target = self._arr[key]
            if not hasattr(target, "shape") or target.shape != value.getShape():
                raise ValueError("Wrong shape of value")
            self._arr[key] = value._arr
        elif isinstance(value, sympy.Basic):
            self._arr[key] = value
        elif hasattr(value, "__array__"):
            # convert entry by entry and keep the shape of the source so
            # that sub-arrays of any rank can be assigned
            src = numpy.asarray(value)
            converted = numpy.empty(src.shape, dtype=object)
            for idx in numpy.ndindex(src.shape):
                converted[idx] = sympy.sympify(src[idx])
            self._arr[key] = converted
        else:
            self._arr[key] = sympy.sympify(value)

    def getRank(self):
        """
        Returns the number of dimensions of the underlying array.
        """
        return self._arr.ndim

    def getShape(self):
        """
        Returns the shape of the underlying array as a tuple.
        """
        return self._arr.shape

    def atoms(self, *types):
        """
        Returns the set of atoms of all entries (see ``atoms()`` of sympy
        expressions). Symbols are reported under the base name returned by
        `_symComp`, so the components of a shaped symbol collapse into a
        single ``sympy.Symbol``.
        """
        s = set()
        for el in self._arr.flat:
            for a in el.atoms(*types):
                if a.is_Symbol:
                    n, c = Symbol._symComp(a)
                    s.add(sympy.Symbol(n))
                else:
                    s.add(a)
        return s

    def _sympystr_(self, printer):
        # NOTE(review): presumably the hook looked up by sympy's string
        # printer -- confirm the hook name for the sympy version in use
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string representation based on sympy's ``lambdarepr``.

        For a scalar this is the ``lambdarepr`` of the single entry. For a
        shaped symbol the result has the form ``combineData([...],shape)``
        in which component names such as ``[x]_0_1`` are written ``x[0,1]``.
        """
        from sympy.printing.lambdarepr import lambdarepr
        if self.getRank() == 0:
            return lambdarepr(self._arr.item())
        temp_arr = numpy.empty(self.getShape(), dtype=object)
        for idx, el in numpy.ndenumerate(self._arr):
            # create a dictionary to convert names like [x]_0_0 to x[0,0]
            symdict = {}
            for a in el.atoms(sympy.Symbol):
                n, c = Symbol._symComp(a)
                if len(c) > 0:
                    symstr = n + '[' + ','.join([str(i) for i in c]) + ']'
                else:
                    symstr = n
                symdict[a.name] = symstr
            s = lambdarepr(el)
            # longest names first: otherwise replacing '[x]_1' would also
            # rewrite the beginning of '[x]_10'
            for key in sorted(symdict, key=len, reverse=True):
                s = s.replace(key, symdict[key])
            temp_arr[idx] = s
        return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))

    def diff(self, *symbols, **assumptions):
        """
        Returns the derivative with respect to ``symbols`` as a new Symbol.

        ``symbols`` follows the convention of `_symbolgen`: an integer after
        a symbol repeats that symbol. A shaped Symbol as variable must have
        the shape of ``self`` and is applied entry by entry; any other
        variable is applied to every entry. ``assumptions`` are forwarded to
        the ``diff`` method of the entries.

        :raise ValueError: if no symbol is given or the shapes of ``self``
                           and a shaped variable differ
        """
        if not symbols:
            raise ValueError("diff() requires at least one symbol")
        result = Symbol(self._arr)
        for s in Symbol._symbolgen(*symbols):
            if isinstance(s, Symbol):
                if s.getRank() > 0:
                    if s.getShape() != self.getShape():
                        raise ValueError("Incompatible shapes")
                    flat = result._arr.flat
                    for idx, var in enumerate(s._arr.flat):
                        flat[idx] = flat[idx].diff(var, **assumptions)
                else:
                    var = s._arr.item()
                    result = result.applyfunc(
                        lambda item, var=var: item.diff(var, **assumptions))
            else:
                result = result.applyfunc(
                    lambda item, var=s: item.diff(var, **assumptions))
        return result

    def grad(self, where=None, dim=2):
        """
        Returns a Symbol with one additional trailing axis of length ``dim``
        whose entry ``[..., d]`` is ``grad_n(entry, d)``, or
        ``grad_n(entry, d, where)`` if ``where`` is given.

        :param where: optional third argument for ``grad_n``; a scalar
                      Symbol is replaced by its single entry
        :param dim: length of the added axis (2 was previously hard-coded)
        :raise ValueError: if ``where`` is a Symbol of rank greater than 0
        """
        if isinstance(where, Symbol):
            if where.getRank() > 0:
                raise ValueError("'where' must be a scalar symbol")
            where = where._arr.item()

        # NOTE(review): implicit relative import of a sibling module; this
        # form is only resolved that way by Python 2
        from functions import grad_n
        out = self._arr.copy().reshape(self.getShape() + (1,)).repeat(dim, axis=self.getRank())
        for d in range(dim):
            for idx in numpy.ndindex(self.getShape()):
                index = idx + (d,)
                if where is None:
                    out[index] = grad_n(out[index], d)
                else:
                    out[index] = grad_n(out[index], d, where)
        return Symbol(out)

    def swap_axes(self, axis0, axis1):
        """
        Returns a Symbol with the two given axes interchanged.
        """
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1))

    def tensorProduct(self, other, axis_offset):
        """
        Contracts the last ``axis_offset`` axes of ``self`` with the first
        ``axis_offset`` axes of ``other`` (a Symbol or an array).

        :raise ValueError: if the operands cannot be reshaped consistently,
                           i.e. the contracted sizes do not match
        """
        arg0 = self._arr
        arg1 = numpy.asarray(Symbol._raw(other))
        keep0 = arg0.shape[:arg0.ndim - axis_offset]
        keep1 = arg1.shape[axis_offset:]
        d0 = Symbol._count(keep0)
        d1 = Symbol._count(keep1)
        d01 = Symbol._count(arg1.shape[:axis_offset])
        # reshape (unlike an in-place resize) refuses operands whose size
        # does not fit instead of padding or truncating them
        left = numpy.reshape(arg0, (d0, d01))
        right = numpy.reshape(arg1, (d01, d1))
        return Symbol._contract(left, right, keep0 + keep1)

    def transposedTensorProduct(self, other, axis_offset):
        """
        Contracts the first ``axis_offset`` axes of ``self`` with the first
        ``axis_offset`` axes of ``other`` (a Symbol or an array).

        :raise ValueError: if the operands cannot be reshaped consistently
        """
        arg0 = self._arr
        arg1 = numpy.asarray(Symbol._raw(other))
        keep0 = arg0.shape[axis_offset:]
        keep1 = arg1.shape[axis_offset:]
        d0 = Symbol._count(keep0)
        d1 = Symbol._count(keep1)
        d01 = Symbol._count(arg1.shape[:axis_offset])
        left = numpy.reshape(arg0, (d01, d0)).T
        right = numpy.reshape(arg1, (d01, d1))
        return Symbol._contract(left, right, keep0 + keep1)

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Contracts the last ``axis_offset`` axes of ``self`` with the last
        ``axis_offset`` axes of ``other`` (a Symbol or an array).

        :raise ValueError: if the operands cannot be reshaped consistently
        """
        arg0 = self._arr
        arg1 = numpy.asarray(Symbol._raw(other))
        r1 = arg1.ndim
        keep0 = arg0.shape[:arg0.ndim - axis_offset]
        keep1 = arg1.shape[:r1 - axis_offset]
        d0 = Symbol._count(keep0)
        d1 = Symbol._count(keep1)
        d01 = Symbol._count(arg1.shape[r1 - axis_offset:])
        left = numpy.reshape(arg0, (d0, d01))
        right = numpy.reshape(arg1, (d1, d01)).T
        return Symbol._contract(left, right, keep0 + keep1)

    def trace(self, axis_offset):
        """
        Sums the entries with equal indices on the axes ``axis_offset`` and
        ``axis_offset+1`` and returns the result as a Symbol without these
        two axes.
        """
        sh = self.getShape()
        n = sh[axis_offset]
        lead = sh[:axis_offset]
        tail = sh[axis_offset + 2:]
        s1 = Symbol._count(lead)
        s2 = Symbol._count(tail)
        arr_r = numpy.reshape(self._arr, (s1, n, n, s2))
        out = numpy.zeros([s1, s2], object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(n):
                    out[i1, i2] += arr_r[i1, j, j, i2]
        return Symbol(out.reshape(lead + tail))

    def transpose(self, axis_offset=None):
        """
        Returns a Symbol whose axes are rotated so that axis ``axis_offset``
        becomes the first one. If ``axis_offset`` is ``None`` half of the
        rank (rounded down) is used.
        """
        if axis_offset is None:
            axis_offset = self._arr.ndim // 2
        axes = list(range(axis_offset, self._arr.ndim)) + list(range(0, axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes))

    def applyfunc(self, f):
        """
        Returns a new Symbol whose entries are ``f(entry)``.

        :raise TypeError: if ``f`` is not callable
        """
        if not callable(f):
            raise TypeError("applyfunc() requires a callable")
        if self._arr.ndim == 0:
            return Symbol(f(self._arr.item()))
        out = numpy.empty(self.getShape(), dtype=object)
        for idx in numpy.ndindex(self.getShape()):
            out[idx] = f(self._arr[idx])
        return Symbol(out)

    def _sympy_(self):
        # sympy.sympify() calls a _sympy_ method of objects it does not know
        return self.applyfunc(sympy.sympify)

    @staticmethod
    def _symComp(sym):
        """
        Splits the name of a sympy symbol into base name and component
        indices: ``[x]_0_1`` gives ``('x', (0, 1))``. Names which do not
        contain exactly one '[' and one ']' are returned unchanged together
        with an empty tuple.
        """
        n = sym.name
        a = n.split('[')
        if len(a) != 2:
            return n, ()
        a = a[1].split(']')
        if len(a) != 2:
            return n, ()
        name = a[0]
        comps = [int(i) for i in a[1].split('_')[1:]]
        return name, tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        Arguments that are not Symbol instances are passed through
        ``sympy.sympify``. An integer is never yielded itself; if it follows
        a symbol, that symbol is yielded that many times.

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat
        items = []
        for s in symbols:
            if not isinstance(s, Symbol):
                s = sympy.sympify(s)
            items.append(s)
        for pos, s in enumerate(items):
            if isinstance(s, sympy.Integer):
                continue
            # look at the following argument by position; comparing with
            # the last argument misfires if a symbol occurs more than once
            next_s = items[pos + 1] if pos + 1 < len(items) else None
            if isinstance(s, (Symbol, sympy.Symbol)) and isinstance(next_s, sympy.Integer):
                # handle cases like (x, 3): yield (x, x, x)
                for copy_s in repeat(s, int(next_s)):
                    yield copy_s
            else:
                yield s

    # unary/binary operations follow; a Symbol operand contributes its array

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr)

    def __abs__(self):
        return Symbol(abs(self._arr))

    def __add__(self, other):
        return Symbol(self._arr + Symbol._raw(other))

    def __radd__(self, other):
        return Symbol(Symbol._raw(other) + self._arr)

    def __sub__(self, other):
        return Symbol(self._arr - Symbol._raw(other))

    def __rsub__(self, other):
        return Symbol(Symbol._raw(other) - self._arr)

    def __mul__(self, other):
        return Symbol(self._arr * Symbol._raw(other))

    def __rmul__(self, other):
        return Symbol(Symbol._raw(other) * self._arr)

    def __div__(self, other):
        return Symbol(self._arr / Symbol._raw(other))

    def __rdiv__(self, other):
        return Symbol(Symbol._raw(other) / self._arr)

    # the '/' operator is dispatched to __truediv__ when true division is
    # in effect
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(self, other):
        return Symbol(self._arr ** Symbol._raw(other))

    def __rpow__(self, other):
        return Symbol(Symbol._raw(other) ** self._arr)
423
424
def symbols(*names, **kwargs):
    """
    Emulates the behaviour of sympy.symbols.

    Every positional argument is either a string of names separated by
    whitespace and/or commas, or a list/tuple of names. Empty names are
    skipped. One ``Symbol(name, shape, **kwargs)`` is created per name.

    :param names: the names of the symbols to create
    :keyword shape: shape tuple used for every symbol (default ``()``);
                    all other keyword arguments are passed on to ``Symbol``
    :return: ``None`` if no name was given, the ``Symbol`` itself if there
             is exactly one name, otherwise a tuple of ``Symbol`` objects
    """
    import re

    shape = kwargs.pop('shape', ())

    tokens = []
    for entry in names:
        if isinstance(entry, (list, tuple)):
            tokens.extend(entry)
        else:
            # raw string: '\s' is a regex class, not a string escape
            tokens.extend(re.split(r'\s|,', entry))

    # skip empty strings produced by consecutive separators
    res = tuple([Symbol(t, shape, **kwargs) for t in tokens if t])
    if len(res) == 0:    # symbols('')
        return None
    if len(res) == 1:    # symbols('x')
        return res[0]
    return res           # symbols('a b ...')
450
def combineData(array, shape):
    """
    Assembles the entries of ``array`` into one object of the given shape.

    This is the function named in the strings produced by
    ``Symbol.lambdarepr`` for shaped symbols.

    :param array: a single value or a (nested) sequence of entries
    :param shape: the shape of the result
    :return: ``array`` itself if it has no ``__len__`` and ``shape`` is
             ``()``; otherwise an ``esys.escript.Data`` object if any entry
             is a ``Data`` object, else a numpy array of zeros of the given
             shape filled with the entries
    :raise ValueError: if the ``Data`` entries live on different domains
    """
    # array could just be a single value
    is_sized = hasattr(array, '__len__')
    if not is_sized and shape == ():
        return array

    from esys.escript import Data
    entries = numpy.array(array)  # for indexing

    # find function space if any
    domains = set()
    spaces = set()
    for idx in numpy.ndindex(shape):
        item = entries[idx]
        if not isinstance(item, Data):
            continue
        spaces.add(item.getFunctionSpace())
        domains.add(item.getDomain())

    if len(domains) > 1:
        reference = domains.pop()
        while domains:
            if reference != domains.pop():
                raise ValueError("Mixing of domains not supported")

    if spaces:
        result = Data(0., shape, spaces.pop())  #FIXME: interpolate instead of using first?
    else:
        result = numpy.zeros(shape)

    for idx in numpy.ndindex(shape):
        item = entries[idx]
        # entries exposing ndim == 0 are converted to a plain float
        if hasattr(item, "ndim") and item.ndim == 0:
            result[idx] = float(item)
        else:
            result[idx] = item
    return result
486

  ViewVC Help
Powered by ViewVC 1.1.26