/[escript]/branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py
ViewVC logotype

Annotation of /branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3512 - (hide annotations)
Wed May 18 06:22:46 2011 UTC (8 years, 6 months ago) by caltinay
File MIME type: text/x-python
File size: 15772 byte(s)
Implementation of symbolic grad() and a few fixes.

1 caltinay 3507 # -*- coding: utf-8 -*-
2    
3     ########################################################
4     #
5     # Copyright (c) 2003-2010 by University of Queensland
6     # Earth Systems Science Computational Center (ESSCC)
7     # http://www.uq.edu.au/esscc
8     #
9     # Primary Business: Queensland, Australia
10     # Licensed under the Open Software License version 3.0
11     # http://www.opensource.org/licenses/osl-3.0.php
12     #
13     ########################################################
14    
# Module metadata: copyright, licence and project URL.
__copyright__="""Copyright (c) 2003-2010 by University of Queensland
Earth Systems Science Computational Center (ESSCC)
http://www.uq.edu.au/esscc
Primary Business: Queensland, Australia"""
__license__="""Licensed under the Open Software License version 3.0
http://www.opensource.org/licenses/osl-3.0.php"""
__url__="https://launchpad.net/escript-finley"

# NOTE: the string below follows the assignments above, so it is a plain
# expression statement and not the module docstring. It lists __version__
# and __date__, neither of which is assigned anywhere in this file.
"""
:var __author__: name of author
:var __copyright__: copyrights
:var __license__: licence agreement
:var __url__: url entry point on documentation
:var __version__: version
:var __date__: date of the version
"""

import numpy
import sympy

# Author of this module.
__author__="Cihan Altinay"
36    
37    
class Symbol(object):
    """
    Tensor of symbolic expressions stored component-wise in a numpy array.

    The components live in ``self._arr``; a rank-0 array represents a scalar.
    A named tensor symbol of a given shape consists of one ``sympy.Symbol``
    per component, named ``[name]_i_j...``. `lambdarepr` renders such a
    component as ``name[i,j,...]``.
    """

    def __init__(self, *args, **kwargs):
        """
        Initializes a new Symbol object.

        Supported call forms:

        - ``Symbol(name)``: scalar symbol called ``name``
        - ``Symbol(obj)``: wraps ``obj``, which is an object providing
          ``__array__`` (its array is copied), a list or a sympy expression
        - ``Symbol(name, shape)``: symbol called ``name`` with one component
          per index of the tuple ``shape`` (``()`` yields a scalar)

        :param kwargs: passed on to ``sympy.Symbol`` when a scalar symbol is
                       created from a name
        :raise TypeError: if the number or the types of the arguments are not
                          supported, or if a name contains '[' or ']'
        :raise ValueError: if the rank would exceed 4
        """
        if len(args) == 1:
            arg = args[0]
            if isinstance(arg, str):
                Symbol._checkName(arg)
                self._arr = numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__"):
                arr = arg.__array__()
                if len(arr.shape) > 4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                self._arr = arr.copy()
            elif isinstance(arg, (list, sympy.Basic)):
                arr = numpy.array(arg)
                # nested lists are subject to the same rank limit as arrays
                if arr.ndim > 4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                self._arr = arr
            else:
                raise TypeError("Unsupported argument type %s" % str(type(arg)))
        elif len(args) == 2:
            name, shape = args
            if not isinstance(name, str):
                raise TypeError("First argument must be a string")
            Symbol._checkName(name)
            if not isinstance(shape, tuple):
                raise TypeError("Second argument must be a tuple")
            if len(shape) > 4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape) == 0:
                self._arr = numpy.array(sympy.Symbol(name, **kwargs))
            else:
                self._arr = Symbol._componentArray(name, shape)
        else:
            raise TypeError("Unsupported number of arguments")

        if self._arr.ndim == 0:
            self.name = str(self._arr.item())
        else:
            self.name = str(self._arr.tolist())

    @staticmethod
    def _checkName(name):
        """
        Raises TypeError if ``name`` contains a bracket. Brackets are
        reserved for the component naming scheme decoded by `_symComp`.
        """
        if '[' in name or ']' in name:
            raise TypeError("Name must not contain '[' or ']'")

    @staticmethod
    def _componentArray(name, shape):
        """
        Returns an object array of the given shape whose entry at index
        ``(i,j,...)`` is the sympy symbol named ``[name]_i_j...``.
        """
        # Built by hand instead of calling sympy.symarray: this code was
        # written against symarray(shape, prefix) whereas current sympy
        # documents symarray(prefix, shape).
        # NOTE(review): keyword assumptions are not applied to the
        # components, as before - confirm whether they should be.
        arr = numpy.empty(shape, dtype=object)
        for idx in numpy.ndindex(shape):
            suffix = '_'.join([str(i) for i in idx])
            arr[idx] = sympy.Symbol('[%s]_%s' % (name, suffix))
        return arr

    @staticmethod
    def _sympifyArray(value):
        """
        Returns the sympified content of an array-like ``value``: a single
        sympy object for rank 0, otherwise an object array of equal shape.
        """
        src = value.__array__()
        if src.ndim == 0:
            return sympy.sympify(src.item())
        out = numpy.empty(src.shape, dtype=object)
        for idx in numpy.ndindex(src.shape):
            out[idx] = sympy.sympify(src[idx])
        return out

    @staticmethod
    def _raw(operand):
        """
        Returns the component array of a Symbol and any other operand as is.
        """
        if isinstance(operand, Symbol):
            return operand._arr
        return operand

    @staticmethod
    def _prod(dims):
        """
        Returns the product of the entries of ``dims`` (1 if empty).
        """
        result = 1
        for d in dims:
            result *= d
        return result

    @staticmethod
    def _contract(left, right, shape):
        """
        Multiplies the 2D arrays ``left`` (d0 x d01) and ``right`` (d01 x d1)
        component-wise into an object array and reshapes it to ``shape``.
        """
        rows, cols = left.shape[0], right.shape[1]
        out = numpy.zeros((rows, cols), object)
        for i0 in range(rows):
            for i1 in range(cols):
                out[i0, i1] = numpy.sum(left[i0, :]*right[:, i1])
        return out.reshape(shape)

    def __repr__(self):
        return str(self._arr)

    def __str__(self):
        return str(self._arr)

    def __eq__(self, other):
        """
        Two Symbols are equal if they have the same type, the same shape and
        all components compare equal.
        """
        if type(self) is not type(other):
            return False
        if self.getShape() != other.getShape():
            return False
        return bool((self._arr == other._arr).all())

    def __ne__(self, other):
        # Python 2 does not derive != from ==
        return not self.__eq__(other)

    def __getitem__(self, key):
        """
        Returns the component(s) selected by ``key`` from the underlying
        array (not wrapped in a Symbol).
        """
        return self._arr[key]

    def __setitem__(self, key, value):
        """
        Sets the component(s) selected by ``key``.

        :param value: a Symbol (scalar, or of the same shape as the selected
                      components), a sympy expression, an array-like object
                      or anything ``sympy.sympify`` accepts
        :raise ValueError: if a non-scalar Symbol does not match the shape of
                           the selected components
        """
        if isinstance(value, Symbol):
            # store the components, not the Symbol wrapper itself
            if value.getRank() == 0:
                self._arr[key] = value._arr.item()
            elif hasattr(self._arr[key], "shape"):
                if self._arr[key].shape != value.getShape():
                    raise ValueError("Wrong shape of value")
                self._arr[key] = value._arr
            else:
                raise ValueError("Wrong shape of value")
        elif isinstance(value, sympy.Basic):
            self._arr[key] = value
        elif hasattr(value, "__array__"):
            # keeps the shape of value; a bare map() is an iterator on
            # Python 3 and a flat list only fits one-dimensional targets
            self._arr[key] = Symbol._sympifyArray(value)
        else:
            self._arr[key] = sympy.sympify(value)

    def getRank(self):
        """
        Returns the rank (number of dimensions) of this Symbol.
        """
        return self._arr.ndim

    def getShape(self):
        """
        Returns the shape of this Symbol as a tuple.
        """
        return self._arr.shape

    def atoms(self, *types):
        """
        Returns the set of atoms of all components. Component symbols are
        replaced by a symbol carrying the name returned by `_symComp`.

        :param types: passed on to the ``atoms`` method of each component
        """
        s = set()
        for el in self._arr.flat:
            for a in el.atoms(*types):
                if a.is_Symbol:
                    n, c = Symbol._symComp(a)
                    s.add(sympy.Symbol(n))
                else:
                    s.add(a)
        return s

    def _sympystr_(self, printer):
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string for this Symbol in which component symbols such as
        ``[x]_0_0`` are written as ``x[0,0]``. Non-scalar Symbols are
        rendered as a ``combineData(<nested list>,<shape>)`` call.
        """
        from sympy.printing.lambdarepr import lambdarepr
        if self.getRank() == 0:
            return lambdarepr(self._arr.item())
        temp_arr = numpy.empty(self.getShape(), dtype=object)
        for idx, el in numpy.ndenumerate(self._arr):
            # create a dictionary to convert names like [x]_0_0 to x[0,0]
            symdict = {}
            for a in el.atoms(sympy.Symbol):
                n, c = Symbol._symComp(a)
                if len(c) > 0:
                    symstr = n+'['+','.join([str(i) for i in c])+']'
                else:
                    symstr = n
                symdict[a.name] = symstr
            s = lambdarepr(el)
            # longest names first so that [x]_1 cannot clobber [x]_10
            for key in sorted(symdict, key=len, reverse=True):
                s = s.replace(key, symdict[key])
            temp_arr[idx] = s
        return 'combineData(%s,%s)' % (str(temp_arr.tolist()).replace("'", ""),
                                       str(self.getShape()))

    def diff(self, *symbols, **assumptions):
        """
        Returns the derivative of this Symbol with respect to ``symbols``.

        A non-scalar Symbol argument must have the shape of this Symbol; in
        that case each component is differentiated with respect to the
        corresponding component of the argument. Any other argument is
        applied to all components. Integers following a symbol act as
        repetition counts (see `_symbolgen`).

        :param assumptions: passed on to the ``diff`` method of each component
        :raise ValueError: if a non-scalar Symbol argument has a different
                           shape
        """
        result = Symbol(self._arr)  # the constructor copies the array
        for s in Symbol._symbolgen(*symbols):
            if isinstance(s, Symbol) and s.getRank() > 0:
                if s.getShape() != self.getShape():
                    raise ValueError("Incompatible shapes")
                flat = result._arr.flat
                for idx, wrt in enumerate(s._arr.flat):
                    flat[idx] = flat[idx].diff(wrt, **assumptions)
            else:
                wrt = s._arr.item() if isinstance(s, Symbol) else s
                result = result.applyfunc(
                    lambda item, wrt=wrt: item.diff(wrt, **assumptions))
        return result

    def grad(self, where=None):
        """
        Returns a Symbol with one additional trailing axis whose entry ``d``
        holds ``grad_n(component, d)`` or, if ``where`` is given,
        ``grad_n(component, d, where)``.

        :param where: optional extra argument for ``grad_n``; a Symbol must
                      be scalar and is replaced by its single component
        :raise ValueError: if ``where`` is a non-scalar Symbol
        """
        if isinstance(where, Symbol):
            if where.getRank() > 0:
                raise ValueError("'where' must be a scalar symbol")
            where = where._arr.item()

        try:
            # explicit relative import, required on Python 3
            from .functions import grad_n
        except (ImportError, ValueError, SystemError):
            # not imported as part of a package: original import form
            from functions import grad_n
        # NOTE(review): the number of gradient components is hard-coded
        dim = 2
        out = self._arr.copy().reshape(self.getShape()+(1,)).repeat(dim, axis=self.getRank())
        for d in range(dim):
            for idx in numpy.ndindex(self.getShape()):
                index = idx+(d,)
                if where is None:
                    out[index] = grad_n(out[index], d)
                else:
                    out[index] = grad_n(out[index], d, where)
        return Symbol(out)

    def swap_axes(self, axis0, axis1):
        """
        Returns a Symbol with the axes ``axis0`` and ``axis1`` interchanged.
        """
        return Symbol(numpy.swapaxes(self._arr, axis0, axis1))

    def tensorProduct(self, other, axis_offset):
        """
        Returns the product of this Symbol and ``other`` summed over the last
        ``axis_offset`` axes of this Symbol and the first ``axis_offset``
        axes of ``other``.

        :param other: a Symbol or an array-like object
        :raise ValueError: if the shapes do not fit the contraction
        """
        arr0 = self._arr
        arr1 = numpy.asarray(Symbol._raw(other))
        sh0, sh1 = arr0.shape, arr1.shape
        keep0 = sh0[:arr0.ndim-axis_offset]
        keep1 = sh1[axis_offset:]
        d0, d1 = Symbol._prod(keep0), Symbol._prod(keep1)
        d01 = Symbol._prod(sh1[:axis_offset])
        # reshape (unlike an in-place resize) rejects inconsistent sizes
        # instead of silently padding with zeros
        left = arr0.reshape((d0, d01))
        right = arr1.reshape((d01, d1))
        return Symbol(Symbol._contract(left, right, keep0+keep1))

    def transposedTensorProduct(self, other, axis_offset):
        """
        Returns the product of this Symbol and ``other`` summed over the
        first ``axis_offset`` axes of both.

        :param other: a Symbol or an array-like object
        :raise ValueError: if the shapes do not fit the contraction
        """
        arr0 = self._arr
        arr1 = numpy.asarray(Symbol._raw(other))
        sh0, sh1 = arr0.shape, arr1.shape
        keep0 = sh0[axis_offset:]
        keep1 = sh1[axis_offset:]
        d0, d1 = Symbol._prod(keep0), Symbol._prod(keep1)
        d01 = Symbol._prod(sh1[:axis_offset])
        left = arr0.reshape((d01, d0)).T
        right = arr1.reshape((d01, d1))
        return Symbol(Symbol._contract(left, right, keep0+keep1))

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Returns the product of this Symbol and ``other`` summed over the last
        ``axis_offset`` axes of both.

        :param other: a Symbol or an array-like object
        :raise ValueError: if the shapes do not fit the contraction
        """
        arr0 = self._arr
        arr1 = numpy.asarray(Symbol._raw(other))
        sh0, sh1 = arr0.shape, arr1.shape
        r1 = arr1.ndim
        keep0 = sh0[:arr0.ndim-axis_offset]
        keep1 = sh1[:r1-axis_offset]
        d0, d1 = Symbol._prod(keep0), Symbol._prod(keep1)
        d01 = Symbol._prod(sh1[r1-axis_offset:])
        left = arr0.reshape((d0, d01))
        right = arr1.reshape((d1, d01)).T
        return Symbol(Symbol._contract(left, right, keep0+keep1))

    def trace(self, axis_offset):
        """
        Returns the sum over the diagonal of the axes ``axis_offset`` and
        ``axis_offset+1``; both axes are removed from the result.
        """
        sh = self.getShape()
        n = sh[axis_offset]
        s1 = Symbol._prod(sh[:axis_offset])
        s2 = Symbol._prod(sh[axis_offset+2:])
        arr_r = numpy.reshape(self._arr, (s1, n, n, s2))
        out = numpy.zeros([s1, s2], object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(n):
                    out[i1, i2] += arr_r[i1, j, j, i2]
        return Symbol(out.reshape(sh[:axis_offset]+sh[axis_offset+2:]))

    def transpose(self, axis_offset=None):
        """
        Returns a Symbol whose axes are rotated so that the axes from
        ``axis_offset`` onwards come first.

        :param axis_offset: first axis to be moved to the front; ``None``
                            selects half of the rank (rounded down)
        """
        rank = self._arr.ndim
        if axis_offset is None:
            axis_offset = rank // 2
        # lists, since range objects cannot be concatenated on Python 3
        axes = list(range(axis_offset, rank)) + list(range(axis_offset))
        return Symbol(numpy.transpose(self._arr, axes=axes))

    def applyfunc(self, f):
        """
        Returns a new Symbol holding ``f(component)`` for every component.

        :param f: callable taking a single component
        :raise TypeError: if ``f`` is not callable
        """
        if not callable(f):
            raise TypeError("f must be callable")
        if self._arr.ndim == 0:
            return Symbol(f(self._arr.item()))
        out = numpy.empty(self.getShape(), dtype=object)
        for idx in numpy.ndindex(self.getShape()):
            out[idx] = f(self._arr[idx])
        return Symbol(out)

    def _sympy_(self):
        return self.applyfunc(sympy.sympify)

    @staticmethod
    def _symComp(sym):
        """
        Splits the name of the sympy symbol ``sym`` into base name and
        component index: ``[x]_0_1`` gives ``('x', (0, 1))``. Names that do
        not contain exactly one '[' and one ']' are returned unchanged with
        an empty index tuple.
        """
        n = sym.name
        a = n.split('[')
        if len(a) != 2:
            return n, ()
        a = a[1].split(']')
        if len(a) != 2:
            return n, ()
        name = a[0]
        comps = [int(i) for i in a[1].split('_')[1:]]
        return name, tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat
        items = [s if isinstance(s, Symbol) else sympy.sympify(s)
                 for s in symbols]
        for i, s in enumerate(items):
            if isinstance(s, sympy.Integer):
                # counts are consumed together with the preceding symbol
                continue
            # look ahead by position; comparing s with the last argument
            # would lose the count in cases like (x, 2, x)
            following = items[i+1] if i+1 < len(items) else None
            if isinstance(s, (Symbol, sympy.Symbol)) \
                    and isinstance(following, sympy.Integer):
                # handle cases like (x, 3): yield (x, x, x)
                for copy_s in repeat(s, int(following)):
                    yield copy_s
            else:
                yield s

    # unary/binary operations follow

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self._arr)

    def __abs__(self):
        return Symbol(abs(self._arr))

    def __add__(self, other):
        return Symbol(self._arr+Symbol._raw(other))

    def __radd__(self, other):
        return Symbol(Symbol._raw(other)+self._arr)

    def __sub__(self, other):
        return Symbol(self._arr-Symbol._raw(other))

    def __rsub__(self, other):
        return Symbol(Symbol._raw(other)-self._arr)

    def __mul__(self, other):
        return Symbol(self._arr*Symbol._raw(other))

    def __rmul__(self, other):
        return Symbol(Symbol._raw(other)*self._arr)

    def __div__(self, other):
        return Symbol(self._arr/Symbol._raw(other))

    def __rdiv__(self, other):
        return Symbol(Symbol._raw(other)/self._arr)

    # the / operator looks up these names on Python 3 (and on Python 2 with
    # "from __future__ import division")
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(self, other):
        return Symbol(self._arr**Symbol._raw(other))

    def __rpow__(self, other):
        return Symbol(Symbol._raw(other)**self._arr)
423 caltinay 3507
424    
def symbols(*names, **kwargs):
    """
    Emulates the behaviour of sympy.symbols.

    Every positional argument is either a string of names separated by
    whitespace and/or commas, or a list/tuple of names. One `Symbol` is
    created per non-empty name.

    :param names: names of the symbols to be created
    :keyword shape: shape tuple used for every created symbol (default
                    ``()``, i.e. scalar)
    :param kwargs: all other keyword arguments are passed on to `Symbol`
    :return: ``None`` if no name was given, a single `Symbol` for one name,
             otherwise a tuple of `Symbol` objects
    """
    import re

    shape = kwargs.pop('shape', ())

    tokens = []
    # consider all positional arguments, not just the first one
    for entry in names:
        if isinstance(entry, (list, tuple)):
            tokens.extend(entry)
        else:
            tokens.extend(re.split(r'\s|,', entry))

    # empty strings (from adjacent separators) are skipped
    res = tuple([Symbol(t, shape, **kwargs) for t in tokens if t])
    if len(res) == 0:    # var('')
        return None
    if len(res) == 1:    # var('x')
        return res[0]
    # otherwise var('a b ...')
    return res
450    
def combineData(array, shape):
    """
    Assembles the entries of ``array`` into one object of the given shape.

    A value without ``__len__`` combined with the empty shape ``()`` is
    returned unchanged. Otherwise the result is an ``esys.escript.Data``
    object (if at least one entry is a ``Data`` instance) or a numpy array
    of zeros, filled entry by entry.

    :param array: single value or (nested) sequence of values
    :param shape: shape tuple of the result
    :raise ValueError: if the ``Data`` entries report differing domains
    """
    # array could just be a single value
    is_single = not hasattr(array, '__len__')
    if is_single and shape == ():
        return array

    from esys.escript import Data
    items = numpy.array(array)  # for indexing

    # find function space if any
    domains = set()
    spaces = set()
    for pos in numpy.ndindex(shape):
        entry = items[pos]
        if isinstance(entry, Data):
            spaces.add(entry.getFunctionSpace())
            domains.add(entry.getDomain())

    if len(domains) > 1:
        reference = domains.pop()
        while domains:
            if reference != domains.pop():
                raise ValueError("Mixing of domains not supported")

    if spaces:
        # FIXME: interpolate instead of using first?
        result = Data(0., shape, spaces.pop())
    else:
        result = numpy.zeros(shape)

    for pos in numpy.ndindex(shape):
        entry = items[pos]
        # entries reporting rank 0 are stored as plain floats
        rank_zero = hasattr(entry, "ndim") and entry.ndim == 0
        result[pos] = float(entry) if rank_zero else entry
    return result
486    

  ViewVC Help
Powered by ViewVC 1.1.26