/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Contents of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory | Revision Log


Revision 3507 - (show annotations)
Wed May 11 06:04:52 2011 UTC (8 years, 5 months ago) by caltinay
Original Path: branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py
File MIME type: text/x-python
File size: 13218 byte(s)
New approach with own Symbol class, symbolic components and looser dependency
on sympy.

1 # -*- coding: utf-8 -*-
2
3 ########################################################
4 #
5 # Copyright (c) 2003-2010 by University of Queensland
6 # Earth Systems Science Computational Center (ESSCC)
7 # http://www.uq.edu.au/esscc
8 #
9 # Primary Business: Queensland, Australia
10 # Licensed under the Open Software License version 3.0
11 # http://www.opensource.org/licenses/osl-3.0.php
12 #
13 ########################################################
14
# Module metadata: copyright notice, licence and project URL.
__copyright__="""Copyright (c) 2003-2010 by University of Queensland
Earth Systems Science Computational Center (ESSCC)
http://www.uq.edu.au/esscc
Primary Business: Queensland, Australia"""
__license__="""Licensed under the Open Software License version 3.0
http://www.opensource.org/licenses/osl-3.0.php"""
__url__="https://launchpad.net/escript-finley"

# NOTE(review): this string is not the first statement of the module, so it is
# not the module docstring -- it is evaluated and discarded. It also lists
# __version__ and __date__, which are not defined anywhere in this file.
"""
:var __author__: name of author
:var __copyright__: copyrights
:var __license__: licence agreement
:var __url__: url entry point on documentation
:var __version__: version
:var __date__: date of the version
"""
31
32 import numpy
33 import sympy
34
# Name of the author of this module.
__author__="Cihan Altinay"
36
37
class Symbol(object):
    """
    Symbolic tensor of rank 0 to 4.

    The components are kept in a numpy array (typically of dtype ``object``
    holding sympy expressions). Arithmetic operators work component-wise and
    return new `Symbol` instances.

    Component symbols created from a name and a shape are named
    ``[name]_i_j...``; `_symComp` splits such names back into the base name
    and the integer indices.
    """

    # highest tensor order accepted by the constructor
    _MAX_RANK = 4

    def __init__(self, *args, **kwargs):
        """
        Initializes a new Symbol object.

        Accepted forms:

        - ``Symbol(name, **kwargs)``: a scalar symbol; ``kwargs`` are passed
          on to ``sympy.Symbol``.
        - ``Symbol(obj)``: wraps a copy of another `Symbol`, of anything that
          exposes ``__array__`` (e.g. a numpy array), or of a (nested) list or
          sympy expression.
        - ``Symbol(name, shape)``: a tensor of component symbols created with
          ``sympy.symarray`` using the prefix ``'[name]'``.

        :raise TypeError: if the arguments are of an unsupported type/number
        :raise ValueError: if the resulting tensor has rank greater than 4
        """
        rank_msg = "Symbol only supports tensors up to order %d" % self._MAX_RANK
        if len(args) == 1:
            arg = args[0]
            if isinstance(arg, Symbol):
                arr = arg.__arr.copy()
            elif isinstance(arg, str):
                arr = numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__"):
                arr = arg.__array__().copy()
            elif isinstance(arg, list) or isinstance(arg, sympy.Basic):
                arr = numpy.array(arg)
            else:
                raise TypeError("Unsupported argument type %s" % str(type(arg)))
        elif len(args) == 2:
            name, shape = args
            if not isinstance(name, str):
                raise TypeError("First argument must be a string")
            if '[' in name or ']' in name:
                raise TypeError("Name must not contain '[' or ']'")
            if not isinstance(shape, tuple):
                raise TypeError("Second argument must be a tuple")
            # reject before creating the component symbols
            if len(shape) > self._MAX_RANK:
                raise ValueError(rank_msg)
            # the brackets delimit the base name so _symComp can recover it
            arr = sympy.symarray(shape, '[' + name + ']')
        else:
            raise TypeError("Unsupported number of arguments")
        # applies to every construction path, including list/sympy input
        if arr.ndim > self._MAX_RANK:
            raise ValueError(rank_msg)
        self.__arr = arr
        if arr.ndim == 0:
            self.name = arr.item()
        else:
            self.name = str(arr.tolist())

    def __repr__(self):
        return str(self.__arr)

    def __str__(self):
        return str(self.__arr)

    def __eq__(self, other):
        """
        Returns True if `other` is a Symbol of the same shape whose components
        all compare equal to the components of this Symbol.
        """
        if type(self) is not type(other):
            return False
        if self.getShape() != other.getShape():
            return False
        return bool(numpy.all(self.__arr == other.__arr))

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__
        return not self.__eq__(other)

    def __getitem__(self, key):
        return self.__arr[key]

    def __setitem__(self, key, value):
        """
        Sets the component(s) selected by `key` to `value`.

        A rank-0 Symbol value is assigned as its single expression. A Symbol
        of higher rank must match the shape of the selected sub-tensor. Other
        values are converted with ``sympy.sympify`` unless they are already
        sympy objects.

        :raise ValueError: if a Symbol value does not match the target shape
        """
        if isinstance(value, Symbol):
            if value.getRank() == 0:
                # store the wrapped expression, not the Symbol wrapper itself
                self.__arr[key] = value.__arr.item()
            else:
                target = self.__arr[key]
                if getattr(target, "shape", None) != value.getShape():
                    raise ValueError("Wrong shape of value")
                # assign the components, not the wrapper object
                self.__arr[key] = value.__arr
        elif isinstance(value, sympy.Basic):
            self.__arr[key] = value
        elif hasattr(value, "__array__"):
            src = numpy.asarray(value)
            converted = numpy.empty(src.shape, dtype=object)
            for idx, v in numpy.ndenumerate(src):
                converted[idx] = sympy.sympify(v)
            # keeps the source shape so multi-dimensional targets line up
            self.__arr[key] = converted
        else:
            self.__arr[key] = sympy.sympify(value)

    def getRank(self):
        return self.__arr.ndim

    def getShape(self):
        return self.__arr.shape

    def atoms(self, *types):
        """
        Returns the set of atoms of all components, optionally restricted to
        `types`. Component symbols such as ``[x]_0_1`` are reported by their
        base name, i.e. as ``sympy.Symbol('x')``.
        """
        result = set()
        for el in self.__arr.flat:
            for atom in sympy.sympify(el).atoms(*types):
                if atom.is_Symbol:
                    base, _ = Symbol._symComp(atom)
                    result.add(sympy.Symbol(base))
                else:
                    result.add(atom)
        return result

    def _sympystr_(self, printer):
        # printing hook; returns the same string as lambdarepr()
        # NOTE(review): confirm the sympy version in use looks up this name
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string of Python code that evaluates this tensor.

        Each component is printed with sympy's ``lambdarepr``. Component
        symbols named like ``[x]_0_1`` are rewritten as ``x[0,1]``. The result
        has the form ``combineData([...], shape)``.
        """
        import re
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr = numpy.empty(self.__arr.shape, dtype=object)
        for idx, el in numpy.ndenumerate(self.__arr):
            el = sympy.sympify(el)
            # map component symbol names like [x]_0_0 to x[0,0]
            symdict = {}
            for a in el.atoms(sympy.Symbol):
                n, c = Symbol._symComp(a)
                if c:
                    symstr = n + '[' + ','.join(str(i) for i in c) + ']'
                else:
                    symstr = n
                if symstr != a.name:
                    symdict[a.name] = symstr
            s = lambdarepr(el)
            if symdict:
                # one pass, longest names first: '[x]_1' must never be
                # substituted inside '[x]_10'
                keys = sorted(symdict, key=len, reverse=True)
                pattern = '|'.join(re.escape(k) for k in keys)
                s = re.sub(pattern, lambda m: symdict[m.group(0)], s)
            temp_arr[idx] = s
        res = 'combineData(%s,%s)' % (str(temp_arr.tolist()).replace("'", ""),
                                      str(self.__arr.shape))
        return res

    def diff(self, *symbols, **assumptions):
        """
        Returns the derivative of this tensor with respect to `symbols`.

        Symbols are processed in order, and an integer after a symbol repeats
        it (see `_symbolgen`). A rank-0 Symbol or a sympy symbol
        differentiates every component. A Symbol of higher rank must have the
        same shape as this tensor; each component is then differentiated with
        respect to the corresponding component of that Symbol.

        :raise ValueError: if no symbols are given or shapes are incompatible
        """
        if not symbols:
            raise ValueError("No symbols given to differentiate with respect to")
        result = Symbol(self.__arr)
        for s in Symbol._symbolgen(*symbols):
            if isinstance(s, Symbol):
                if s.getRank() > 0:
                    if s.getShape() != self.getShape():
                        raise ValueError("Incompatible shapes")
                    # both arrays are C-contiguous copies, so flat orders match
                    target = result.__arr.flat
                    for i, wrt in enumerate(s.__arr.flat):
                        target[i] = target[i].diff(wrt, **assumptions)
                else:
                    wrt = s.__arr.item()
                    result = result.applyfunc(
                        lambda item: item.diff(wrt, **assumptions))
            else:
                result = result.applyfunc(
                    lambda item: item.diff(s, **assumptions))
        return result

    def swap_axes(self, axis0, axis1):
        return Symbol(numpy.swapaxes(self.__arr, axis0, axis1))

    def tensorproduct(self, other, axis_offset):
        """
        Returns the generalized tensor product of this tensor and `other`.

        The last `axis_offset` axes of this tensor are contracted with the
        first `axis_offset` axes of `other`.

        :param other: a Symbol, or anything ``numpy.asarray`` accepts
        :param axis_offset: number of axes to contract
        :raise ValueError: if `axis_offset` is out of range or the contracted
                           axes do not have matching lengths
        """
        arg0 = self.__arr
        if isinstance(other, Symbol):
            arg1 = other.__arr
        else:
            arg1 = numpy.asarray(other)
        sh0, sh1 = arg0.shape, arg1.shape
        if axis_offset < 0 or axis_offset > arg0.ndim or axis_offset > arg1.ndim:
            raise ValueError("Invalid axis_offset")
        split = arg0.ndim - axis_offset
        if sh0[split:] != sh1[:axis_offset]:
            raise ValueError("Incompatible shapes")
        keep0, keep1 = sh0[:split], sh1[axis_offset:]
        d0, d1, d01 = 1, 1, 1
        for i in keep0:
            d0 *= i
        for i in keep1:
            d1 *= i
        for i in sh1[:axis_offset]:
            d01 *= i
        # reshape (not resize): sizes are guaranteed equal by the check above
        a = arg0.reshape((d0, d01))
        b = arg1.reshape((d01, d1))
        out = numpy.zeros((d0, d1), object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0, i1] = numpy.sum(a[i0, :] * b[:, i1])
        return Symbol(out.reshape(keep0 + keep1))

    def trace(self, axis_offset):
        """
        Returns the trace over axes `axis_offset` and `axis_offset+1`.

        :raise ValueError: if those axes do not exist or differ in length
        """
        sh = self.__arr.shape
        if axis_offset < 0 or axis_offset + 1 >= len(sh):
            raise ValueError("axis_offset out of range")
        n = sh[axis_offset]
        if sh[axis_offset + 1] != n:
            raise ValueError("Trace axes must have the same length")
        s1 = 1
        for i in sh[:axis_offset]:
            s1 *= i
        s2 = 1
        for i in sh[axis_offset + 2:]:
            s2 *= i
        arr_r = self.__arr.reshape((s1, n, n, s2))
        out = numpy.zeros((s1, s2), object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(n):
                    out[i1, i2] += arr_r[i1, j, j, i2]
        return Symbol(out.reshape(sh[:axis_offset] + sh[axis_offset + 2:]))

    def transpose(self, axis_offset=None):
        """
        Returns the transpose that moves the first `axis_offset` axes to the
        end. If `axis_offset` is None, half the rank (rounded down) is used.
        """
        ndim = self.__arr.ndim
        if axis_offset is None:
            axis_offset = ndim // 2
        axes = tuple(range(axis_offset, ndim)) + tuple(range(axis_offset))
        return Symbol(numpy.transpose(self.__arr, axes=axes))

    def applyfunc(self, f):
        """
        Returns a new Symbol with `f` applied to every component.

        :raise TypeError: if `f` is not callable
        """
        if not callable(f):
            raise TypeError("applyfunc requires a callable argument")
        if self.__arr.ndim == 0:
            return Symbol(f(self.__arr.item()))
        out = numpy.empty(self.__arr.shape, dtype=object)
        for idx in numpy.ndindex(self.__arr.shape):
            out[idx] = f(self.__arr[idx])
        return Symbol(out)

    def _sympy_(self):
        # returns a Symbol whose components were passed through sympify
        # (not a sympy object)
        return self.applyfunc(sympy.sympify)

    @staticmethod
    def _symComp(sym):
        """
        Splits a component symbol name ``[name]_i_j...`` into
        ``(name, (i, j, ...))``. Names that do not follow this pattern, or
        have non-numeric indices, are returned unchanged with ``()``.
        """
        n = sym.name
        parts = n.split('[')
        if len(parts) != 2:
            return n, ()
        parts = parts[1].split(']')
        if len(parts) != 2:
            return n, ()
        # empty pieces occur e.g. for '[x]_' (a rank-0 symarray entry)
        comps = [c for c in parts[1].split('_')[1:] if c]
        if not all(c.isdigit() for c in comps):
            return n, ()
        return parts[0], tuple(int(c) for c in comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        An integer that follows a symbol repeats that symbol; any other
        integer is skipped.

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat

        def convert(item):
            # Symbol instances are kept, everything else is sympified
            if isinstance(item, Symbol):
                return item
            return sympy.sympify(item)

        last = len(symbols) - 1
        for i in range(len(symbols)):
            s = convert(symbols[i])
            # look ahead by position, not by value: the same symbol may occur
            # more than once, e.g. (x, 2, y, x)
            next_s = convert(symbols[i + 1]) if i < last else None
            if isinstance(s, sympy.Integer):
                continue
            if (isinstance(s, (Symbol, sympy.Symbol))
                    and isinstance(next_s, sympy.Integer)):
                # handle cases like (x, 3) -> x, x, x
                for copy_s in repeat(s, int(next_s)):
                    yield copy_s
            else:
                yield s

    # unary/binary operations follow

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self.__arr)

    def __abs__(self):
        return Symbol(abs(self.__arr))

    def __add__(self, other):
        if isinstance(other, Symbol):
            return Symbol(self.__arr + other.__arr)
        return Symbol(self.__arr + other)

    def __radd__(self, other):
        if isinstance(other, Symbol):
            return Symbol(other.__arr + self.__arr)
        return Symbol(other + self.__arr)

    def __sub__(self, other):
        if isinstance(other, Symbol):
            return Symbol(self.__arr - other.__arr)
        return Symbol(self.__arr - other)

    def __rsub__(self, other):
        if isinstance(other, Symbol):
            return Symbol(other.__arr - self.__arr)
        return Symbol(other - self.__arr)

    def __mul__(self, other):
        if isinstance(other, Symbol):
            return Symbol(self.__arr * other.__arr)
        return Symbol(self.__arr * other)

    def __rmul__(self, other):
        if isinstance(other, Symbol):
            return Symbol(other.__arr * self.__arr)
        return Symbol(other * self.__arr)

    def __div__(self, other):
        if isinstance(other, Symbol):
            return Symbol(self.__arr / other.__arr)
        return Symbol(self.__arr / other)

    def __rdiv__(self, other):
        if isinstance(other, Symbol):
            return Symbol(other.__arr / self.__arr)
        return Symbol(other / self.__arr)

    # Python 3 (and "from __future__ import division") dispatch '/' here
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(self, other):
        if isinstance(other, Symbol):
            return Symbol(self.__arr ** other.__arr)
        return Symbol(self.__arr ** other)

    def __rpow__(self, other):
        if isinstance(other, Symbol):
            return Symbol(other.__arr ** self.__arr)
        return Symbol(other ** self.__arr)
353
354
def symbols(*names, **kwargs):
    """
    Emulates the behaviour of sympy.symbols.

    ``names[0]`` is either a string of names separated by whitespace and/or
    commas, or a list/tuple of names. Empty names are skipped. The optional
    keyword ``shape`` (default ``()``) is passed to every created `Symbol`,
    together with any remaining keyword arguments.

    :return: None if no names were given, the single Symbol for one name,
             otherwise a tuple of Symbols
    """
    import re

    shape = kwargs.pop('shape', ())

    spec = names[0]
    if isinstance(spec, (list, tuple)):
        tokens = spec
    else:
        # raw string: '\s' is an invalid escape in a normal string literal
        tokens = re.split(r'\s|,', spec)
    res = tuple(Symbol(t, shape, **kwargs) for t in tokens if t)
    if len(res) == 0:  # var('')
        return None
    if len(res) == 1:  # var('x')
        return res[0]
    # otherwise var('a b ...')
    return res
380
def combineData(array, shape):
    """
    Assembles a tensor of the given `shape` from the nested sequence `array`.

    This function is called by the code that `Symbol.lambdarepr` generates.
    If any component is an escript ``Data`` object, the result is a ``Data``
    object on that component's function space. Otherwise it is a numpy float
    array. A scalar (an object without ``__len__``) with ``shape == ()`` is
    returned unchanged.

    :raise ValueError: if the ``Data`` components live on different domains
    """
    # array could just be a single value
    if not hasattr(array, '__len__') and shape == ():
        return array

    try:
        from esys.escript import Data
    except ImportError:
        # without escript no component can be a Data object
        Data = None
    n = numpy.array(array)  # for indexing

    # find function space if any
    dom = set()
    fs = set()
    if Data is not None:
        for idx in numpy.ndindex(shape):
            if isinstance(n[idx], Data):
                fs.add(n[idx].getFunctionSpace())
                dom.add(n[idx].getDomain())

    if len(dom) > 1:
        domain = dom.pop()
        if any(domain != other for other in dom):
            raise ValueError("Mixing of domains not supported")

    if fs:
        d = Data(0., shape, fs.pop())  # FIXME: interpolate instead of using first?
    else:
        # purely numeric components: a plain float would not support
        # item assignment below
        d = numpy.zeros(shape)
    for idx in numpy.ndindex(shape):
        # element-wise assignment; accumulating n[idx]*unit-tensor terms
        # instead was found to be much slower
        if hasattr(n[idx], "ndim") and n[idx].ndim == 0:
            d[idx] = float(n[idx])
        else:
            d[idx] = n[idx]
    return d
416

  ViewVC Help
Powered by ViewVC 1.1.26