/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Annotation of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory | Revision Log


Revision 3509 - (hide annotations)
Fri May 13 06:01:52 2011 UTC (8 years, 5 months ago) by caltinay
Original Path: branches/symbolic_from_3470/escript/py_src/symbolic/symbols.py
File MIME type: text/x-python
File size: 15003 byte(s)
Some fixes, additions and changes to unit tests

1 caltinay 3507 # -*- coding: utf-8 -*-
2    
3     ########################################################
4     #
5     # Copyright (c) 2003-2010 by University of Queensland
6     # Earth Systems Science Computational Center (ESSCC)
7     # http://www.uq.edu.au/esscc
8     #
9     # Primary Business: Queensland, Australia
10     # Licensed under the Open Software License version 3.0
11     # http://www.opensource.org/licenses/osl-3.0.php
12     #
13     ########################################################
14    
15     __copyright__="""Copyright (c) 2003-2010 by University of Queensland
16     Earth Systems Science Computational Center (ESSCC)
17     http://www.uq.edu.au/esscc
18     Primary Business: Queensland, Australia"""
19     __license__="""Licensed under the Open Software License version 3.0
20     http://www.opensource.org/licenses/osl-3.0.php"""
21     __url__="https://launchpad.net/escript-finley"
22    
23     """
24     :var __author__: name of author
25     :var __copyright__: copyrights
26     :var __license__: licence agreement
27     :var __url__: url entry point on documentation
28     :var __version__: version
29     :var __date__: date of the version
30     """
31    
32     import numpy
33     import sympy
34    
35     __author__="Cihan Altinay"
36    
37    
class Symbol(object):
    """
    Symbolic tensor of rank 0 to 4.

    A Symbol wraps a numpy array (normally of dtype ``object``) whose
    elements are sympy expressions or plain numbers. It provides element-wise
    arithmetic, differentiation, tensor products and a string representation
    of the form ``combineData(<values>,<shape>)`` (see `lambdarepr`).
    Components of a tensor-valued symbol created from a name are sympy
    symbols named ``[name]_i_j...``.
    """

    def __init__(self, *args, **kwargs):
        """
        Initializes a new Symbol object.

        Supported forms:

        - ``Symbol(name)``: a scalar sympy symbol called ``name``.
        - ``Symbol(obj)``: wraps ``obj``, which may be an object exposing
          ``__array__`` (the array is copied), a (nested) list or a sympy
          expression.
        - ``Symbol(name, shape)``: a tensor of the given shape (a tuple)
          whose components are sympy symbols named ``[name]_i_j...``; an
          empty shape gives a scalar symbol called ``name``.

        Keyword arguments are passed on to ``sympy.Symbol`` whenever symbols
        are created from a name (e.g. assumptions such as ``real=True``).

        :raise TypeError: on unsupported argument types or counts, or if
                          ``name`` contains '[' or ']'
        :raise ValueError: if the rank of the result would exceed 4
        """
        if len(args)==1:
            arg=args[0]
            if isinstance(arg, str):
                self.__arr=numpy.array(sympy.Symbol(arg, **kwargs))
            elif hasattr(arg, "__array__"):
                arr=arg.__array__()
                if len(arr.shape)>4:
                    raise ValueError("Symbol only supports tensors up to order 4")
                self.__arr=arr.copy()
            elif isinstance(arg, list) or isinstance(arg, sympy.Basic):
                self.__arr=numpy.array(arg)
            else:
                raise TypeError("Unsupported argument type %s"%str(type(arg)))
        elif len(args)==2:
            name, shape=args
            if not isinstance(name, str):
                raise TypeError("First argument must be a string")
            if '[' in name or ']' in name:
                raise TypeError("Name must not contain '[' or ']'")
            if not isinstance(shape, tuple):
                raise TypeError("Second argument must be a tuple")
            if len(shape)>4:
                raise ValueError("Symbol only supports tensors up to order 4")
            if len(shape)==0:
                self.__arr=numpy.array(sympy.Symbol(name, **kwargs))
            else:
                # Built by hand instead of with sympy.symarray, whose argument
                # order is (prefix, shape) in current sympy. The naming scheme
                # prefix_i_j... is the same, and kwargs (symbol assumptions)
                # are applied to every component.
                self.__arr=numpy.empty(shape, dtype=object)
                for idx in numpy.ndindex(shape):
                    comp='_'.join(str(i) for i in idx)
                    self.__arr[idx]=sympy.Symbol('[%s]_%s'%(name, comp), **kwargs)
        else:
            raise TypeError("Unsupported number of arguments")
        # the list/sympy.Basic branch above does not check the rank itself
        if self.__arr.ndim>4:
            raise ValueError("Symbol only supports tensors up to order 4")
        if self.__arr.ndim==0:
            self.name=self.__arr.item()
        else:
            self.name=str(self.__arr.tolist())

    def __repr__(self):
        return str(self.__arr)

    def __str__(self):
        return str(self.__arr)

    def __eq__(self, other):
        """
        Returns True if other is a Symbol of the same type and shape whose
        elements all compare equal to the elements of this Symbol.
        """
        if type(self) is not type(other):
            return False
        # equal shapes imply equal ranks
        if self.getShape()!=other.getShape():
            return False
        return (self.__arr==other.__arr).all()

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__
        return not self.__eq__(other)

    def __getitem__(self, key):
        """Returns the raw element or sub-array selected by key."""
        return self.__arr[key]

    def __setitem__(self, key, value):
        """
        Assigns value to the element(s) selected by key.

        A rank-0 Symbol stores its wrapped expression. A higher-rank Symbol
        must match the shape of the selected sub-array. Array-like values are
        sympified element-wise; anything else is sympified.

        :raise ValueError: if a Symbol value has the wrong shape
        """
        if isinstance(value, Symbol):
            if value.getRank()==0:
                # store the wrapped expression, not the Symbol wrapper
                self.__arr[key]=value.__arr.item()
            elif hasattr(self.__arr[key], "shape") \
                    and self.__arr[key].shape==value.getShape():
                # assign element-wise from the wrapped array
                self.__arr[key]=value.__arr
            else:
                raise ValueError("Wrong shape of value")
        elif isinstance(value, sympy.Basic):
            self.__arr[key]=value
        elif hasattr(value, "__array__"):
            src=numpy.asarray(value)
            converted=numpy.empty(src.shape, dtype=object)
            for idx in numpy.ndindex(src.shape):
                converted[idx]=sympy.sympify(src[idx])
            self.__arr[key]=converted
        else:
            self.__arr[key]=sympy.sympify(value)

    def getRank(self):
        """Returns the rank (number of dimensions) of this Symbol."""
        return self.__arr.ndim

    def getShape(self):
        """Returns the shape of this Symbol as a tuple."""
        return self.__arr.shape

    def atoms(self, *types):
        """
        Returns the set of atoms of all elements. Component symbols such as
        ``[x]_0_1`` are reported as the base symbol ``x``.
        """
        s=set()
        for el in self.__arr.flat:
            for a in el.atoms(*types):
                if a.is_Symbol:
                    n,c=Symbol._symComp(a)
                    s.add(sympy.Symbol(n))
                else:
                    s.add(a)
        return s

    def _sympystr_(self, printer):
        """sympy string-printer hook; returns lambdarepr()."""
        return self.lambdarepr()

    # current sympy looks the printer hook up without the trailing underscore
    _sympystr=_sympystr_

    def lambdarepr(self):
        """
        Returns a string ``combineData(<values>,<shape>)`` where <values> is
        the nested list of the lambdarepr of all elements (with quotes
        removed), and component symbols like ``[x]_0_1`` are rewritten as
        ``x[0,1]``.
        """
        import re
        from sympy.printing.lambdarepr import lambdarepr
        temp_arr=numpy.empty(self.__arr.shape, dtype=object)
        for idx,el in numpy.ndenumerate(self.__arr):
            # dictionary to convert names like [x]_0_0 to x[0,0]
            symdict={}
            for a in el.atoms(sympy.Symbol):
                n,c=Symbol._symComp(a)
                if len(c)>0:
                    symstr=n+'['+','.join(str(i) for i in c)+']'
                else:
                    symstr=n
                if symstr!=a.name:
                    symdict[a.name]=symstr
            s=lambdarepr(el)
            if symdict:
                # Replace in a single pass, preferring longer names, so that
                # e.g. [x]_0_1 cannot clobber part of [x]_0_10.
                keys=sorted(symdict, key=len, reverse=True)
                pattern=re.compile('|'.join(re.escape(k) for k in keys))
                s=pattern.sub(lambda m: symdict[m.group(0)], s)
            temp_arr[idx]=s
        res='combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.__arr.shape))
        return res

    def diff(self, *symbols, **assumptions):
        """
        Returns the derivative of this Symbol with respect to symbols.

        As in sympy, a symbol may be followed by an integer giving the order
        of differentiation, e.g. ``diff(x, 2, y)``. Differentiating w.r.t. a
        Symbol of rank > 0 differentiates element i of this Symbol w.r.t.
        element i of that Symbol (in flat order), so both need the same shape.

        :raise ValueError: if a tensor-valued symbol has an incompatible shape
        """
        result=Symbol(self.__arr)
        for s in Symbol._symbolgen(*symbols):
            if isinstance(s, Symbol):
                if s.getRank()>0:
                    if s.getShape()!=self.getShape():
                        raise ValueError("Incompatible shapes")
                    flat=result.__arr.flat
                    for idx, var in enumerate(s.__arr.flat):
                        flat[idx]=flat[idx].diff(var, **assumptions)
                else:
                    var=s.__arr.item()
                    result=result.applyfunc(lambda item: item.diff(var, **assumptions))
            else:
                result=result.applyfunc(lambda item: item.diff(s, **assumptions))
        return result

    def swap_axes(self, axis0, axis1):
        """Returns a new Symbol with axes axis0 and axis1 interchanged."""
        return Symbol(numpy.swapaxes(self.__arr, axis0, axis1))

    @staticmethod
    def _size(shape):
        """Returns the number of elements of an array of the given shape."""
        return int(numpy.prod(shape))

    @staticmethod
    def _operand(other):
        """Returns the array wrapped by a Symbol, or numpy.asarray(other)."""
        if isinstance(other, Symbol):
            return other.__arr
        return numpy.asarray(other)

    @staticmethod
    def _contract(mat0, mat1, shape):
        """
        Returns Symbol(out.reshape(shape)) where out is the object matrix
        out[i0,i1] = sum_k mat0[i0,k]*mat1[k,i1].
        """
        out=numpy.empty((mat0.shape[0], mat1.shape[1]), dtype=object)
        for i0 in range(mat0.shape[0]):
            for i1 in range(mat1.shape[1]):
                out[i0,i1]=numpy.sum(mat0[i0,:]*mat1[:,i1])
        return Symbol(out.reshape(shape))

    def tensorProduct(self, other, axis_offset):
        """
        Contracts the last axis_offset axes of this Symbol with the first
        axis_offset axes of other (a Symbol or array-like).

        :raise ValueError: if the contracted axes have different shapes
        """
        arg0=self.__arr
        arg1=Symbol._operand(other)
        sh0, sh1=arg0.shape, arg1.shape
        r0=arg0.ndim
        if sh0[r0-axis_offset:]!=sh1[:axis_offset]:
            raise ValueError("Incompatible shapes")
        out_sh0, out_sh1=sh0[:r0-axis_offset], sh1[axis_offset:]
        d01=Symbol._size(sh1[:axis_offset])
        mat0=arg0.reshape((Symbol._size(out_sh0), d01))
        mat1=arg1.reshape((d01, Symbol._size(out_sh1)))
        return Symbol._contract(mat0, mat1, out_sh0+out_sh1)

    def transposedTensorProduct(self, other, axis_offset):
        """
        Contracts the first axis_offset axes of this Symbol with the first
        axis_offset axes of other (a Symbol or array-like).

        :raise ValueError: if the contracted axes have different shapes
        """
        arg0=self.__arr
        arg1=Symbol._operand(other)
        sh0, sh1=arg0.shape, arg1.shape
        if sh0[:axis_offset]!=sh1[:axis_offset]:
            raise ValueError("Incompatible shapes")
        out_sh0, out_sh1=sh0[axis_offset:], sh1[axis_offset:]
        d01=Symbol._size(sh1[:axis_offset])
        mat0=arg0.reshape((d01, Symbol._size(out_sh0))).T
        mat1=arg1.reshape((d01, Symbol._size(out_sh1)))
        return Symbol._contract(mat0, mat1, out_sh0+out_sh1)

    def tensorTransposedProduct(self, other, axis_offset):
        """
        Contracts the last axis_offset axes of this Symbol with the last
        axis_offset axes of other (a Symbol or array-like).

        :raise ValueError: if the contracted axes have different shapes
        """
        arg0=self.__arr
        arg1=Symbol._operand(other)
        sh0, sh1=arg0.shape, arg1.shape
        r0, r1=arg0.ndim, arg1.ndim
        if sh0[r0-axis_offset:]!=sh1[r1-axis_offset:]:
            raise ValueError("Incompatible shapes")
        out_sh0, out_sh1=sh0[:r0-axis_offset], sh1[:r1-axis_offset]
        d01=Symbol._size(sh1[r1-axis_offset:])
        mat0=arg0.reshape((Symbol._size(out_sh0), d01))
        mat1=arg1.reshape((Symbol._size(out_sh1), d01)).T
        return Symbol._contract(mat0, mat1, out_sh0+out_sh1)

    def trace(self, axis_offset):
        """
        Returns the trace over axes axis_offset and axis_offset+1; the result
        has rank getRank()-2.

        :raise ValueError: if those axes do not exist or differ in length
        """
        sh=self.__arr.shape
        if len(sh)<axis_offset+2 or sh[axis_offset]!=sh[axis_offset+1]:
            raise ValueError("Trace requires two axes of equal length at axis_offset")
        n=sh[axis_offset]
        s1=Symbol._size(sh[:axis_offset])
        s2=Symbol._size(sh[axis_offset+2:])
        arr_r=numpy.reshape(self.__arr, (s1, n, n, s2))
        out=numpy.zeros((s1, s2), object)
        for i1 in range(s1):
            for i2 in range(s2):
                for j in range(n):
                    out[i1,i2]+=arr_r[i1,j,j,i2]
        return Symbol(out.reshape(sh[:axis_offset]+sh[axis_offset+2:]))

    def transpose(self, axis_offset=None):
        """
        Returns a new Symbol whose axes are cyclically shifted so that axis
        axis_offset becomes the first one. Defaults to half the rank.
        """
        rank=self.__arr.ndim
        if axis_offset is None:
            axis_offset=rank//2
        axes=list(range(axis_offset, rank))+list(range(0, axis_offset))
        return Symbol(numpy.transpose(self.__arr, axes=axes))

    def applyfunc(self, f):
        """Returns a new Symbol with the callable f applied to each element."""
        assert callable(f)
        if self.__arr.ndim==0:
            out=Symbol(f(self.__arr.item()))
        else:
            out=numpy.empty(self.__arr.shape, dtype=object)
            for idx in numpy.ndindex(self.__arr.shape):
                out[idx]=f(self.__arr[idx])
            out=Symbol(out)
        return out

    def _sympy_(self):
        return self.applyfunc(sympy.sympify)

    @staticmethod
    def _symComp(sym):
        """
        Splits the name of sym, e.g. ``[x]_0_1``, into ``('x', (0, 1))``.
        Names not of the form ``[...]...`` are returned as ``(name, ())``.
        """
        n=sym.name
        a=n.split('[')
        if len(a)!=2:
            return n,()
        a=a[1].split(']')
        if len(a)!=2:
            return n,()
        name=a[0]
        comps=[int(i) for i in a[1].split('_')[1:]]
        return name,tuple(comps)

    @staticmethod
    def _symbolgen(*symbols):
        """
        Generator of all symbols in the argument of diff().
        (cf. sympy.Derivative._symbolgen)

        An integer following a symbol is a repeat count for it; other
        integers are ignored.

        Example:
        >> ._symbolgen(x, 3, y)
        (x, x, x, y)
        >> ._symbolgen(x, 10**6)
        (x, x, x, x, x, x, x, ...)
        """
        from itertools import repeat
        for i, s in enumerate(symbols):
            if not isinstance(s, Symbol):
                s=sympy.sympify(s)
            if isinstance(s, sympy.Integer):
                # repeat counts are consumed together with their symbol
                continue
            # look ahead by position, so repeated symbols like (x, 2, x) work
            next_s=None
            if i+1<len(symbols):
                next_s=symbols[i+1]
                if not isinstance(next_s, Symbol):
                    next_s=sympy.sympify(next_s)
            if (isinstance(s, Symbol) or isinstance(s, sympy.Symbol)) \
                    and isinstance(next_s, sympy.Integer):
                # handle cases like (x, 3) -> (x, x, x)
                for copy_s in repeat(s, int(next_s)):
                    yield copy_s
            else:
                yield s

    # unary/binary operations follow

    def __pos__(self):
        return self

    def __neg__(self):
        return Symbol(-self.__arr)

    def __abs__(self):
        return Symbol(abs(self.__arr))

    def __add__(self, other):
        if isinstance(other, Symbol):
            return Symbol(self.__arr+other.__arr)
        return Symbol(self.__arr+other)

    def __radd__(self, other):
        if isinstance(other, Symbol):
            return Symbol(other.__arr+self.__arr)
        return Symbol(other+self.__arr)

    def __sub__(self, other):
        if isinstance(other, Symbol):
            return Symbol(self.__arr-other.__arr)
        return Symbol(self.__arr-other)

    def __rsub__(self, other):
        if isinstance(other, Symbol):
            return Symbol(other.__arr-self.__arr)
        return Symbol(other-self.__arr)

    def __mul__(self, other):
        if isinstance(other, Symbol):
            return Symbol(self.__arr*other.__arr)
        return Symbol(self.__arr*other)

    def __rmul__(self, other):
        if isinstance(other, Symbol):
            return Symbol(other.__arr*self.__arr)
        return Symbol(other*self.__arr)

    def __div__(self, other):
        if isinstance(other, Symbol):
            return Symbol(self.__arr/other.__arr)
        return Symbol(self.__arr/other)

    def __rdiv__(self, other):
        if isinstance(other, Symbol):
            return Symbol(other.__arr/self.__arr)
        return Symbol(other/self.__arr)

    # Python 3 uses __truediv__/__rtruediv__ for the / operator
    __truediv__=__div__
    __rtruediv__=__rdiv__

    def __pow__(self, other):
        if isinstance(other, Symbol):
            return Symbol(self.__arr**other.__arr)
        return Symbol(self.__arr**other)

    def __rpow__(self, other):
        if isinstance(other, Symbol):
            return Symbol(other.__arr**self.__arr)
        return Symbol(other**self.__arr)
402    
403    
def symbols(*names, **kwargs):
    """
    Emulates the behaviour of sympy.symbols.

    The first positional argument is either a string of names separated by
    whitespace and/or commas, or a list/tuple of names. Each (non-empty)
    name becomes a Symbol. The optional keyword ``shape`` (default ``()``)
    is the shape of every Symbol created; all other keyword arguments are
    passed on to the Symbol constructor.

    :return: None if no names were given, the Symbol itself for exactly one
             name, otherwise a tuple of Symbols
    """
    import re
    shape=kwargs.pop('shape', ())

    s=names[0]
    if not isinstance(s, (list, tuple)):
        # raw string: '\s' is an invalid escape sequence in a plain literal
        s=re.split(r'\s|,', s)
    # skip empty strings produced by consecutive separators
    res=tuple(Symbol(t, shape, **kwargs) for t in s if t)
    if len(res)==0: # var('')
        return None
    if len(res)==1: # var('x')
        return res[0]
    # otherwise var('a b ...')
    return res
429    
def combineData(array, shape):
    """
    Builds a value of the given shape from the (nested) entries of array.

    A scalar without ``__len__`` and an empty shape is returned unchanged.
    Otherwise, if any entry is an escript Data object, the result is a Data
    object on the function space of one of those entries; if not, it is a
    numpy float array. Every entry is then copied into the result (0-d array
    entries are converted with float()).

    :raise ValueError: if the Data entries live on different domains
    """
    # a single plain value needs no assembling
    if not hasattr(array,'__len__') and shape==():
        return array

    from esys.escript import Data
    values=numpy.array(array) # for indexing
    indices=list(numpy.ndindex(shape))

    # collect function spaces and domains of all Data entries
    spaces=set()
    domains=set()
    for idx in indices:
        item=values[idx]
        if isinstance(item, Data):
            spaces.add(item.getFunctionSpace())
            domains.add(item.getDomain())

    if len(domains)>1:
        reference=domains.pop()
        if any(reference!=other for other in domains):
            raise ValueError("Mixing of domains not supported")

    if spaces:
        result=Data(0., shape, spaces.pop()) #FIXME: interpolate instead of using first?
    else:
        result=numpy.zeros(shape)

    # element-wise assignment (multiplying by unit tensors is much slower)
    for idx in indices:
        item=values[idx]
        if hasattr(item, "ndim") and item.ndim==0:
            result[idx]=float(item)
        else:
            result[idx]=item
    return result
465    

  ViewVC Help
Powered by ViewVC 1.1.26