/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Annotation of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3507 - (hide annotations)
Wed May 11 06:04:52 2011 UTC (8 years, 6 months ago) by caltinay
Original Path: branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py
File MIME type: text/x-python
File size: 13218 byte(s)
New approach with own Symbol class, symbolic components and looser dependency
on sympy.

1 caltinay 3507 # -*- coding: utf-8 -*-
2    
3     ########################################################
4     #
5     # Copyright (c) 2003-2010 by University of Queensland
6     # Earth Systems Science Computational Center (ESSCC)
7     # http://www.uq.edu.au/esscc
8     #
9     # Primary Business: Queensland, Australia
10     # Licensed under the Open Software License version 3.0
11     # http://www.opensource.org/licenses/osl-3.0.php
12     #
13     ########################################################
14    
15     __copyright__="""Copyright (c) 2003-2010 by University of Queensland
16     Earth Systems Science Computational Center (ESSCC)
17     http://www.uq.edu.au/esscc
18     Primary Business: Queensland, Australia"""
19     __license__="""Licensed under the Open Software License version 3.0
20     http://www.opensource.org/licenses/osl-3.0.php"""
21     __url__="https://launchpad.net/escript-finley"
22    
23     """
24     :var __author__: name of author
25     :var __copyright__: copyrights
26     :var __license__: licence agreement
27     :var __url__: url entry point on documentation
28     :var __version__: version
29     :var __date__: date of the version
30     """
31    
32     import numpy
33     import sympy
34    
35     __author__="Cihan Altinay"
36    
37    
class Symbol(object):
    """
    Symbolic tensor whose components are held in a numpy array.

    A ``Symbol`` wraps a numpy array (normally of dtype ``object`` holding
    sympy expressions). Tensors created by name or from an array-like are
    limited to rank 4. The arithmetic operators work component-wise on the
    wrapped array and always return a new ``Symbol``.

    Components generated from a name and a shape are sympy symbols called
    ``[name]_i_j...``; `_symComp` splits such a name into its base name and
    its index tuple.
    """

    def __init__(self, *args, **kwargs):
        """
        Initializes a new Symbol object.

        Supported call forms:

        - ``Symbol(name)`` -- scalar symbol; ``kwargs`` are passed on to
          ``sympy.Symbol``.
        - ``Symbol(obj)`` -- ``obj`` offers ``__array__`` (its array is
          copied), or is a ``list`` or a ``sympy.Basic`` instance.
        - ``Symbol(name, shape)`` -- tensor of the given ``shape`` (a tuple)
          filled with generated component symbols.

        :raise TypeError: if the number or the types of arguments are not
                          supported, or if ``name`` contains ``[`` or ``]``
        :raise ValueError: if the requested rank is larger than 4
        """
        #from esys.escript import Data
        if len(args) == 1:
            arg = args[0]
            if isinstance(arg, str):
                self.__arr = numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__"):
                arr = arg.__array__()
                if len(arr.shape) > 4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                self.__arr = arr.copy()
            elif isinstance(arg, list) or isinstance(arg, sympy.Basic):
                arr = numpy.array(arg)
                # same rank limit as for the other construction paths
                if arr.ndim > 4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                self.__arr = arr
            #elif isinstance(arg, Data):
            #    self.__arr=arg
            else:
                raise TypeError("Unsupported argument type %s" % str(type(arg)))
        elif len(args) == 2:
            if not isinstance(args[0], str):
                raise TypeError("First argument must be a string")
            if args[0].find('[') >= 0 or args[0].find(']') >= 0:
                raise TypeError("Name must not contain '[' or ']'")
            if not isinstance(args[1], tuple):
                raise TypeError("Second argument must be a tuple")
            name = args[0]
            shape = args[1]
            if len(shape) > 4:
                raise ValueError("Symbol only supports tensors up to order 4")
            # The brackets mark the base name so that _symComp can recover it.
            # NOTE(review): symarray is called as (shape, prefix) here; newer
            # sympy releases appear to expect (prefix, shape) -- confirm
            # against the sympy version in use.
            self.__arr = sympy.symarray(shape, '[' + name + ']')
        else:
            raise TypeError("Unsupported number of arguments")
        if self.__arr.ndim == 0:
            self.name = self.__arr.item()
        else:
            self.name = str(self.__arr.tolist())

    def __repr__(self):
        return str(self.__arr)

    def __str__(self):
        return str(self.__arr)

    def __eq__(self, other):
        """
        Two symbols are equal if they have the same type, rank and shape and
        all components compare equal.
        """
        if type(self) is not type(other):
            return False
        if self.getRank() != other.getRank():
            return False
        if self.getShape() != other.getShape():
            return False
        return (self.__arr == other.__arr).all()

    def __ne__(self, other):
        # Python 2 does not derive != from ==, so define it explicitly
        return not self.__eq__(other)

    def __getitem__(self, key):
        """
        Returns the component(s) selected by ``key`` from the wrapped array.
        """
        return self.__arr[key]

    def __setitem__(self, key, value):
        """
        Assigns ``value`` to the component(s) selected by ``key``.

        ``value`` may be a ``Symbol`` (rank 0, or of the same shape as the
        selected slice), a sympy expression, an array-like or anything that
        ``sympy.sympify`` accepts.

        :raise ValueError: if a ``Symbol`` of rank > 0 does not match the
                           shape of the selected slice
        """
        if isinstance(value, Symbol):
            # store the components, not the Symbol wrapper object itself
            if value.getRank() == 0:
                self.__arr[key] = value.__arr.item()
            elif hasattr(self.__arr[key], "shape"):
                if self.__arr[key].shape == value.getShape():
                    self.__arr[key] = value.__arr
                else:
                    raise ValueError("Wrong shape of value")
            else:
                raise ValueError("Wrong shape of value")
        elif isinstance(value, sympy.Basic):
            self.__arr[key] = value
        elif hasattr(value, "__array__"):
            # sympify component-wise into a real array; a bare map() is a
            # lazy iterator on Python 3 and cannot be assigned to a slice
            src = numpy.asarray(value)
            converted = numpy.empty(src.shape, dtype=object)
            for idx in numpy.ndindex(src.shape):
                converted[idx] = sympy.sympify(src[idx])
            if src.ndim == 0:
                self.__arr[key] = converted.item()
            else:
                self.__arr[key] = converted
        else:
            self.__arr[key] = sympy.sympify(value)

    def getRank(self):
        """
        Returns the number of dimensions of the wrapped array.
        """
        return self.__arr.ndim

    def getShape(self):
        """
        Returns the shape tuple of the wrapped array.
        """
        return self.__arr.shape

    def atoms(self, *types):
        """
        Returns the set of atoms of all components. Atoms that are symbols
        are replaced by a ``sympy.Symbol`` carrying only the base name, i.e.
        all components ``[x]_i_j`` collapse into the single symbol ``x``.

        :param types: passed on to the ``atoms`` method of every component
        """
        found = set()
        for el in self.__arr.flat:
            for a in el.atoms(*types):
                if a.is_Symbol:
                    base_name = Symbol._symComp(a)[0]
                    found.add(sympy.Symbol(base_name))
                else:
                    found.add(a)
        return found

    def _sympystr_(self, printer):
        return self.lambdarepr()

    def lambdarepr(self):
        """
        Returns a string of the form ``combineData([...],shape)`` in which
        every component is printed by sympy's ``lambdarepr`` and component
        symbol names such as ``[x]_0_0`` are rewritten to ``x[0,0]``.
        """
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr = numpy.empty(self.__arr.shape, dtype=object)
        for idx, el in numpy.ndenumerate(self.__arr):
            # map generated names like [x]_0_0 to the indexed form x[0,0]
            symdict = {}
            for a in el.atoms(sympy.Symbol):
                n, c = Symbol._symComp(a)
                if len(c) > 0:
                    symstr = n + '[' + ','.join([str(i) for i in c]) + ']'
                else:
                    symstr = n
                symdict[a.name] = symstr
            s = lambdarepr(el)
            # Longest names first: otherwise '[x]_1' would be replaced
            # inside '[x]_10' and corrupt the latter.
            for key in sorted(symdict, key=len, reverse=True):
                s = s.replace(key, symdict[key])
            temp_arr[idx] = s
        res = 'combineData(%s,%s)' % (str(temp_arr.tolist()).replace("'", ""),
                                      str(self.__arr.shape))
        return res

    def diff(self, *symbols, **assumptions):
        """
        Differentiates every component with respect to the given symbols.

        ``symbols`` follows the sympy convention where a symbol may be
        followed by an integer repeat count (see `_symbolgen`). A ``Symbol``
        of rank > 0 must have the same shape as this object; each component
        is then differentiated with respect to the component at the same
        index.

        :param assumptions: passed on to the ``diff`` method of the
                            components
        :return: a new ``Symbol`` holding the derivatives
        :raise ValueError: if a ``Symbol`` of rank > 0 has a different shape
        """
        result = Symbol(self.__arr)
        for s in Symbol._symbolgen(*symbols):
            if isinstance(s, Symbol):
                if s.getRank() > 0:
                    if s.getShape() != self.getShape():
                        raise ValueError("Incompatible shapes")
                    target = result.__arr
                    for idx in numpy.ndindex(target.shape):
                        target[idx] = target[idx].diff(s.__arr[idx],
                                                       **assumptions)
                else:
                    wrt = s.__arr.item()
                    result = result.applyfunc(
                        lambda item, wrt=wrt: item.diff(wrt, **assumptions))
            else:
                result = result.applyfunc(
                    lambda item, wrt=s: item.diff(wrt, **assumptions))
        return result

    def swap_axes(self, axis0, axis1):
        """
        Returns a new ``Symbol`` with the two given axes interchanged.
        """
        return Symbol(numpy.swapaxes(self.__arr, axis0, axis1))

    @staticmethod
    def _count(dims):
        """
        Returns the product of the entries of ``dims`` (1 if empty).
        """
        total = 1
        for d in dims:
            total *= d
        return total

    def tensorproduct(self, other, axis_offset):
        """
        Contracts the last ``axis_offset`` axes of this object with the first
        ``axis_offset`` axes of ``other``.

        :param other: a ``Symbol`` or an array-like
        :param axis_offset: number of axes to sum over
        :return: a new ``Symbol`` of shape
                 ``self.shape[:rank-axis_offset]+other.shape[axis_offset:]``
        :raise ValueError: if ``axis_offset`` is out of range or the
                           contracted axes do not have the same shape
        """
        arg0 = self.__arr
        if isinstance(other, Symbol):
            arg1 = other.__arr
        else:
            arg1 = numpy.asarray(other)
        sh0 = arg0.shape
        sh1 = arg1.shape
        if axis_offset < 0 or axis_offset > min(len(sh0), len(sh1)):
            raise ValueError("axis_offset out of range")
        split = len(sh0) - axis_offset
        if sh0[split:] != sh1[:axis_offset]:
            # an in-place resize would silently pad or truncate here
            raise ValueError("Incompatible shapes")
        d0 = Symbol._count(sh0[:split])
        d1 = Symbol._count(sh1[axis_offset:])
        d01 = Symbol._count(sh1[:axis_offset])
        # reshape returns a new view/copy, the operands stay untouched
        left = arg0.reshape((d0, d01))
        right = arg1.reshape((d01, d1))
        out = numpy.empty((d0, d1), dtype=object)
        for i0 in range(d0):
            for i1 in range(d1):
                out[i0, i1] = numpy.sum(left[i0, :] * right[:, i1])
        return Symbol(out.reshape(sh0[:split] + sh1[axis_offset:]))

    def trace(self, axis_offset):
        """
        Sums the components over the two equal axes ``axis_offset`` and
        ``axis_offset+1``.

        :return: a new ``Symbol`` whose shape is this shape with the two
                 traced axes removed
        :raise ValueError: if the two traced axes differ in length
        """
        sh = self.__arr.shape
        n = sh[axis_offset]
        if sh[axis_offset + 1] != n:
            raise ValueError("Incompatible shapes")
        s1 = Symbol._count(sh[:axis_offset])
        s2 = Symbol._count(sh[axis_offset + 2:])
        arr_r = numpy.reshape(self.__arr, (s1, n, n, s2))
        out = numpy.zeros([s1, s2], object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(n):
                    out[i1, i2] += arr_r[i1, j, j, i2]
        return Symbol(out.reshape(sh[:axis_offset] + sh[axis_offset + 2:]))

    def transpose(self, axis_offset=None):
        """
        Moves the first ``axis_offset`` axes behind the remaining ones.

        :param axis_offset: number of leading axes to move; half of the rank
                            (rounded down) if ``None``
        :return: a new ``Symbol`` with the permuted axes
        """
        if axis_offset is None:
            axis_offset = self.__arr.ndim // 2
        # lists, because range objects cannot be concatenated on Python 3
        axes = (list(range(axis_offset, self.__arr.ndim))
                + list(range(0, axis_offset)))
        return Symbol(numpy.transpose(self.__arr, axes=axes))

    def applyfunc(self, f):
        """
        Applies the callable ``f`` to every component and returns the results
        as a new ``Symbol`` of the same shape.
        """
        assert callable(f)
        if self.__arr.ndim == 0:
            return Symbol(f(self.__arr.item()))
        out = numpy.empty(self.__arr.shape, dtype=object)
        for idx in numpy.ndindex(self.__arr.shape):
            out[idx] = f(self.__arr[idx])
        return Symbol(out)

    def _sympy_(self):
        return self.applyfunc(sympy.sympify)

    @staticmethod
    def _symComp(sym):
        """
        Splits the name of a component symbol into base name and indices.

        A name of the form ``[x]_0_1`` yields ``('x', (0, 1))``; any other
        name is returned unchanged together with an empty tuple.

        :param sym: object with a ``name`` attribute
        """
        n = sym.name
        a = n.split('[')
        if len(a) != 2:
            return n, ()
        a = a[1].split(']')
        if len(a) != 2:
            return n, ()
        name = a[0]
        comps = [int(i) for i in a[1].split('_')[1:]]
        return name, tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        Arguments that are not ``Symbol`` instances are sympified. An integer
        that follows a symbol is taken as its repeat count. No arguments
        yield nothing.

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat

        def normalise(item):
            if isinstance(item, Symbol):
                return item
            return sympy.sympify(item)

        count = len(symbols)
        for i in range(count):
            s = normalise(symbols[i])
            if isinstance(s, sympy.Integer):
                # repeat counts were consumed together with their symbol
                continue
            # Look ahead by position; comparing against the last argument
            # by value would lose the count of e.g. the first x in
            # (x, 2, y, x).
            next_s = None
            if i + 1 < count:
                next_s = normalise(symbols[i + 1])
            is_symbol = isinstance(s, Symbol) or isinstance(s, sympy.Symbol)
            if is_symbol and isinstance(next_s, sympy.Integer):
                # handle cases like (x, 3): yield (x, x, x)
                for copy_s in repeat(s, int(next_s)):
                    yield copy_s
            else:
                yield s

    @staticmethod
    def _operand(value):
        """
        Returns the wrapped array of a ``Symbol``; any other value is
        returned unchanged.
        """
        if isinstance(value, Symbol):
            return value.__arr
        return value

    # unary/binary operations follow

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self.__arr)

    def __abs__(self):
        return Symbol(abs(self.__arr))

    def __add__(self, other):
        return Symbol(self.__arr + Symbol._operand(other))

    def __radd__(self, other):
        return Symbol(Symbol._operand(other) + self.__arr)

    def __sub__(self, other):
        return Symbol(self.__arr - Symbol._operand(other))

    def __rsub__(self, other):
        return Symbol(Symbol._operand(other) - self.__arr)

    def __mul__(self, other):
        return Symbol(self.__arr * Symbol._operand(other))

    def __rmul__(self, other):
        return Symbol(Symbol._operand(other) * self.__arr)

    def __div__(self, other):
        return Symbol(self.__arr / Symbol._operand(other))

    def __rdiv__(self, other):
        return Symbol(Symbol._operand(other) / self.__arr)

    # The / operator maps to __truediv__ on Python 3 (and on Python 2 with
    # "from __future__ import division").
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(self, other):
        return Symbol(self.__arr ** Symbol._operand(other))

    def __rpow__(self, other):
        return Symbol(Symbol._operand(other) ** self.__arr)
353    
354    
def symbols(*names, **kwargs):
    """
    Emulates the behaviour of sympy.symbols.

    Only the first positional argument is evaluated. It is either a list of
    names or a string of names separated by whitespace and/or commas. Every
    non-empty name is turned into a ``Symbol`` of the requested shape.

    :param names: names of the symbols to create (first argument only)
    :keyword shape: shape tuple used for every created symbol, ``()`` by
                    default; the remaining keyword arguments are passed on
                    to the ``Symbol`` constructor
    :return: ``None`` if no name was given, a single ``Symbol`` for one name,
             otherwise a tuple of ``Symbol`` objects
    """
    import re

    shape = kwargs.pop('shape', ())

    s = names[0]
    if not isinstance(s, list):
        # raw string: '\s' in a plain literal is an invalid escape sequence
        s = re.split(r'\s|,', s)
    # empty strings (e.g. produced by consecutive separators) are skipped
    res = tuple([Symbol(t, shape, **kwargs) for t in s if t])
    if len(res) == 0:    # var('')
        return None
    if len(res) == 1:    # var('x')
        return res[0]
    # otherwise var('a b ...')
    return res
380    
def combineData(array, shape):
    """
    Combines the components in ``array`` into a single object of the given
    ``shape``.

    If ``array`` has no length and ``shape`` is ``()`` it is returned as is.
    If at least one component is an escript ``Data`` object the result is a
    ``Data`` object on the function space of one of these components;
    otherwise the result is a numpy array.

    :param array: a single value or a (nested) sequence of components
    :param shape: shape tuple of the result
    :raise ValueError: if the ``Data`` components live on different domains
    """
    # array could just be a single value
    if not hasattr(array, '__len__') and shape == ():
        return array

    from esys.escript import Data
    n = numpy.array(array)  # for indexing

    # find function space if any
    dom = set()
    fs = set()
    for idx in numpy.ndindex(shape):
        if isinstance(n[idx], Data):
            fs.add(n[idx].getFunctionSpace())
            dom.add(n[idx].getDomain())

    if len(dom) > 1:
        domain = dom.pop()
        while len(dom) > 0:
            if domain != dom.pop():
                raise ValueError("Mixing of domains not supported")

    if len(fs) > 0:
        d = Data(0., shape, fs.pop())  #FIXME: interpolate instead of using first?
    else:
        # no Data component: a plain float cannot be filled by index, so
        # collect the values in a numpy array instead
        d = numpy.zeros(shape)
    for idx in numpy.ndindex(shape):
        # filling by index is much faster than adding n[idx]*unit_tensor
        if hasattr(n[idx], "ndim") and n[idx].ndim == 0:
            d[idx] = float(n[idx])
        else:
            d[idx] = n[idx]
    return d
416    

  ViewVC Help
Powered by ViewVC 1.1.26