/[escript]/trunk/escript/py_src/symbolic/symbol.py
ViewVC logotype

Diff of /trunk/escript/py_src/symbolic/symbol.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3509 by caltinay, Fri May 13 06:01:52 2011 UTC revision 3512 by caltinay, Wed May 18 06:22:46 2011 UTC
# Line 47  class Symbol(object): Line 47  class Symbol(object):
47          if len(args)==1:          if len(args)==1:
48              arg=args[0]              arg=args[0]
49              if isinstance(arg, str):              if isinstance(arg, str):
50                  self.__arr=numpy.array(sympy.Symbol(arg, **kwargs))                  if arg.find('[')>=0 or arg.find(']')>=0:
51                        raise TypeError("Name must not contain '[' or ']'")
52                    self._arr=numpy.array(sympy.Symbol(arg, **kwargs))
53              elif hasattr(arg, "__array__"):              elif hasattr(arg, "__array__"):
54                  arr=arg.__array__()                  arr=arg.__array__()
55                  if len(arr.shape)>4:                  if len(arr.shape)>4:
56                      raise ValueError("Symbol only supports tensors up to order 4")                      raise ValueError("Symbol only supports tensors up to order 4")
57                  self.__arr=arr.copy()                  self._arr=arr.copy()
58              elif isinstance(arg, list) or isinstance(arg, sympy.Basic):              elif isinstance(arg, list) or isinstance(arg, sympy.Basic):
59                  self.__arr=numpy.array(arg)                  self._arr=numpy.array(arg)
60              #elif isinstance(arg, Data):              #elif isinstance(arg, Data):
61              #    self.__arr=arg              #    self._arr=arg
62              else:              else:
63                  raise TypeError("Unsupported argument type %s"%str(type(arg)))                  raise TypeError("Unsupported argument type %s"%str(type(arg)))
64          elif len(args)==2:          elif len(args)==2:
# Line 71  class Symbol(object): Line 73  class Symbol(object):
73              if len(shape)>4:              if len(shape)>4:
74                  raise ValueError("Symbol only supports tensors up to order 4")                  raise ValueError("Symbol only supports tensors up to order 4")
75              if len(shape)==0:              if len(shape)==0:
76                  self.__arr=numpy.array(sympy.Symbol(name, **kwargs))                  self._arr=numpy.array(sympy.Symbol(name, **kwargs))
77              else:              else:
78                  self.__arr=sympy.symarray(shape, '['+name+']')                  self._arr=sympy.symarray(shape, '['+name+']')
79          else:          else:
80              raise TypeError("Unsupported number of arguments")              raise TypeError("Unsupported number of arguments")
81          if self.__arr.ndim==0:          if self._arr.ndim==0:
82              self.name=self.__arr.item()              self.name=str(self._arr.item())
83          else:          else:
84              self.name=str(self.__arr.tolist())              self.name=str(self._arr.tolist())
85    
86      def __repr__(self):      def __repr__(self):
87          return str(self.__arr)          return str(self._arr)
88    
89      def __str__(self):      def __str__(self):
90          return str(self.__arr)          return str(self._arr)
91    
92      def __eq__(self, other):      def __eq__(self, other):
93          if type(self) is not type(other):          if type(self) is not type(other):
# Line 94  class Symbol(object): Line 96  class Symbol(object):
96              return False              return False
97          if self.getShape()!=other.getShape():          if self.getShape()!=other.getShape():
98              return False              return False
99          return (self.__arr==other.__arr).all()          return (self._arr==other._arr).all()
100    
101      def __getitem__(self, key):      def __getitem__(self, key):
102          return self.__arr[key]          return self._arr[key]
103    
104      def __setitem__(self, key, value):      def __setitem__(self, key, value):
105          if isinstance(value,Symbol):          if isinstance(value,Symbol):
106              if value.getRank()==0:              if value.getRank()==0:
107                  self.__arr[key]=value                  self._arr[key]=value
108              elif hasattr(self.__arr[key], "shape"):              elif hasattr(self._arr[key], "shape"):
109                  if self.__arr[key].shape==value.getShape():                  if self._arr[key].shape==value.getShape():
110                      self.__arr[key]=value                      self._arr[key]=value
111                  else:                  else:
112                      raise ValueError("Wrong shape of value")                      raise ValueError("Wrong shape of value")
113              else:              else:
114                  raise ValueError("Wrong shape of value")                  raise ValueError("Wrong shape of value")
115          elif isinstance(value,sympy.Basic):          elif isinstance(value,sympy.Basic):
116              self.__arr[key]=value              self._arr[key]=value
117          elif hasattr(value, "__array__"):          elif hasattr(value, "__array__"):
118              self.__arr[key]=map(sympy.sympify,value.flat)              self._arr[key]=map(sympy.sympify,value.flat)
119          else:          else:
120              self.__arr[key]=sympy.sympify(value)              self._arr[key]=sympy.sympify(value)
121    
122      def getRank(self):      def getRank(self):
123          return self.__arr.ndim          return self._arr.ndim
124    
125      def getShape(self):      def getShape(self):
126          return self.__arr.shape          return self._arr.shape
127    
128      def atoms(self, *types):      def atoms(self, *types):
129          s=set()          s=set()
130          for el in self.__arr.flat:          for el in self._arr.flat:
131              atoms=el.atoms(*types)              atoms=el.atoms(*types)
132              for a in atoms:              for a in atoms:
133                  if a.is_Symbol:                  if a.is_Symbol:
# Line 140  class Symbol(object): Line 142  class Symbol(object):
142    
143      def lambdarepr(self):      def lambdarepr(self):
144          from sympy.printing.lambdarepr import lambdarepr          from sympy.printing.lambdarepr import lambdarepr
145          temp_arr=numpy.empty(self.__arr.shape, dtype=object)          if self.getRank()==0:
146          for idx,el in numpy.ndenumerate(self.__arr):              return lambdarepr(self._arr.item())
147            temp_arr=numpy.empty(self.getShape(), dtype=object)
148            for idx,el in numpy.ndenumerate(self._arr):
149              atoms=el.atoms(sympy.Symbol)              atoms=el.atoms(sympy.Symbol)
150              # create a dictionary to convert names like x_0_0 to x[0,0]              # create a dictionary to convert names like x_0_0 to x[0,0]
151              symdict={}              symdict={}
# Line 157  class Symbol(object): Line 161  class Symbol(object):
161              for key in symdict:              for key in symdict:
162                  s=s.replace(key, symdict[key])                  s=s.replace(key, symdict[key])
163              temp_arr[idx]=s              temp_arr[idx]=s
164          res='combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.__arr.shape))          return 'combineData(%s,%s)'%(str(temp_arr.tolist()).replace("'",""),str(self.getShape()))
         return res  
165    
166      def diff(self, *symbols, **assumptions):      def diff(self, *symbols, **assumptions):
167          symbols=Symbol._symbolgen(*symbols)          symbols=Symbol._symbolgen(*symbols)
168          result=Symbol(self.__arr)          result=Symbol(self._arr)
169          for s in symbols:          for s in symbols:
170              if isinstance(s, Symbol):              if isinstance(s, Symbol):
171                  if s.getRank()>0:                  if s.getRank()>0:
172                      if s.getShape()!=self.getShape():                      if s.getShape()!=self.getShape():
173                          raise ValueError("Incompatible shapes")                          raise ValueError("Incompatible shapes")
174                      a=result.__arr.flat                      a=result._arr.flat
175                      b=s.__arr.flat                      b=s._arr.flat
176                      for idx in range(len(a)):                      for idx in range(len(a)):
177                          a[idx]=a[idx].diff(b.next())                          a[idx]=a[idx].diff(b.next())
178                  else:                  else:
179                      diff_item=lambda item: getattr(item, 'diff')(s.__arr.item(), **assumptions)                      diff_item=lambda item: getattr(item, 'diff')(s._arr.item(), **assumptions)
180                      result=result.applyfunc(diff_item)                      result=result.applyfunc(diff_item)
181    
182              else:              else:
# Line 181  class Symbol(object): Line 184  class Symbol(object):
184                  result=result.applyfunc(diff_item)                  result=result.applyfunc(diff_item)
185          return result          return result
186    
187        def grad(self, where=None):
188            if isinstance(where, Symbol):
189                if where.getRank()>0:
190                    raise ValueError("'where' must be a scalar symbol")
191                where=where._arr.item()
192    
193            from functions import grad_n
194            dim=2
195            out=self._arr.copy().reshape(self.getShape()+(1,)).repeat(dim,axis=self.getRank())
196            for d in range(dim):
197                for idx in numpy.ndindex(self.getShape()):
198                    index=idx+(d,)
199                    if where is None:
200                        out[index]=grad_n(out[index],d)
201                    else:
202                        out[index]=grad_n(out[index],d,where)
203            return Symbol(out)
204    
205      def swap_axes(self, axis0, axis1):      def swap_axes(self, axis0, axis1):
206          return Symbol(numpy.swapaxes(self.__arr, axis0, axis1))          return Symbol(numpy.swapaxes(self._arr, axis0, axis1))
207    
208      def tensorProduct(self, other, axis_offset):      def tensorProduct(self, other, axis_offset):
209          arg0_c=self.__arr.copy()          arg0_c=self._arr.copy()
210          sh0=self.__arr.shape          sh0=self.getShape()
211          if isinstance(other, Symbol):          if isinstance(other, Symbol):
212              arg1_c=other.__arr.copy()              arg1_c=other._arr.copy()
213              sh1=other.getShape()              sh1=other.getShape()
214          else:          else:
215              arg1_c=other.copy()              arg1_c=other.copy()
216              sh1=other.shape              sh1=other.shape
217          d0,d1,d01=1,1,1          d0,d1,d01=1,1,1
218          for i in sh0[:self.__arr.ndim-axis_offset]: d0*=i          for i in sh0[:self._arr.ndim-axis_offset]: d0*=i
219          for i in sh1[axis_offset:]: d1*=i          for i in sh1[axis_offset:]: d1*=i
220          for i in sh1[:axis_offset]: d01*=i          for i in sh1[:axis_offset]: d01*=i
221          arg0_c.resize((d0,d01))          arg0_c.resize((d0,d01))
# Line 203  class Symbol(object): Line 224  class Symbol(object):
224          for i0 in range(d0):          for i0 in range(d0):
225              for i1 in range(d1):              for i1 in range(d1):
226                  out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[:,i1])                  out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[:,i1])
227          out.resize(sh0[:self.__arr.ndim-axis_offset]+sh1[axis_offset:])          out.resize(sh0[:self._arr.ndim-axis_offset]+sh1[axis_offset:])
228          return Symbol(out)          return Symbol(out)
229    
230      def transposedTensorProduct(self, other, axis_offset):      def transposedTensorProduct(self, other, axis_offset):
231          arg0_c=self.__arr.copy()          arg0_c=self._arr.copy()
232          sh0=self.__arr.shape          sh0=self.getShape()
233          if isinstance(other, Symbol):          if isinstance(other, Symbol):
234              arg1_c=other.__arr.copy()              arg1_c=other._arr.copy()
235              sh1=other.getShape()              sh1=other.getShape()
236          else:          else:
237              arg1_c=other.copy()              arg1_c=other.copy()
# Line 229  class Symbol(object): Line 250  class Symbol(object):
250          return Symbol(out)          return Symbol(out)
251    
252      def tensorTransposedProduct(self, other, axis_offset):      def tensorTransposedProduct(self, other, axis_offset):
253          arg0_c=self.__arr.copy()          arg0_c=self._arr.copy()
254          sh0=self.__arr.shape          sh0=self.getShape()
255          if isinstance(other, Symbol):          if isinstance(other, Symbol):
256              arg1_c=other.__arr.copy()              arg1_c=other._arr.copy()
257              sh1=other.getShape()              sh1=other.getShape()
258              r1=other.getRank()              r1=other.getRank()
259          else:          else:
# Line 240  class Symbol(object): Line 261  class Symbol(object):
261              sh1=other.shape              sh1=other.shape
262              r1=other.ndim              r1=other.ndim
263          d0,d1,d01=1,1,1          d0,d1,d01=1,1,1
264          for i in sh0[:self.__arr.ndim-axis_offset]: d0*=i          for i in sh0[:self._arr.ndim-axis_offset]: d0*=i
265          for i in sh1[:r1-axis_offset]: d1*=i          for i in sh1[:r1-axis_offset]: d1*=i
266          for i in sh1[r1-axis_offset:]: d01*=i          for i in sh1[r1-axis_offset:]: d01*=i
267          arg0_c.resize((d0,d01))          arg0_c.resize((d0,d01))
# Line 249  class Symbol(object): Line 270  class Symbol(object):
270          for i0 in range(d0):          for i0 in range(d0):
271              for i1 in range(d1):              for i1 in range(d1):
272                  out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[i1,:])                  out[i0,i1]=numpy.sum(arg0_c[i0,:]*arg1_c[i1,:])
273          out.resize(sh0[:self.__arr.ndim-axis_offset]+sh1[:r1-axis_offset])          out.resize(sh0[:self._arr.ndim-axis_offset]+sh1[:r1-axis_offset])
274          return Symbol(out)          return Symbol(out)
275    
276      def trace(self, axis_offset):      def trace(self, axis_offset):
277          sh=self.__arr.shape          sh=self.getShape()
278          s1=1          s1=1
279          for i in range(axis_offset): s1*=sh[i]          for i in range(axis_offset): s1*=sh[i]
280          s2=1          s2=1
281          for i in range(axis_offset+2,len(sh)): s2*=sh[i]          for i in range(axis_offset+2,len(sh)): s2*=sh[i]
282          arr_r=numpy.reshape(self.__arr,(s1,sh[axis_offset],sh[axis_offset],s2))          arr_r=numpy.reshape(self._arr,(s1,sh[axis_offset],sh[axis_offset],s2))
283          out=numpy.zeros([s1,s2],object)          out=numpy.zeros([s1,s2],object)
284          for i1 in range(s1):          for i1 in range(s1):
285              for i2 in range(s2):              for i2 in range(s2):
# Line 269  class Symbol(object): Line 290  class Symbol(object):
290    
291      def transpose(self, axis_offset):      def transpose(self, axis_offset):
292          if axis_offset is None:          if axis_offset is None:
293              axis_offset=int(self.__arr.ndim/2)              axis_offset=int(self._arr.ndim/2)
294          axes=range(axis_offset, self.__arr.ndim)+range(0,axis_offset)          axes=range(axis_offset, self._arr.ndim)+range(0,axis_offset)
295          return Symbol(numpy.transpose(self.__arr, axes=axes))          return Symbol(numpy.transpose(self._arr, axes=axes))
296    
297      def applyfunc(self, f):      def applyfunc(self, f):
298          assert callable(f)          assert callable(f)
299          if self.__arr.ndim==0:          if self._arr.ndim==0:
300              out=Symbol(f(self.__arr.item()))              out=Symbol(f(self._arr.item()))
301          else:          else:
302              out=numpy.empty(self.__arr.shape, dtype=object)              out=numpy.empty(self.getShape(), dtype=object)
303              for idx in numpy.ndindex(self.__arr.shape):              for idx in numpy.ndindex(self.getShape()):
304                  out[idx]=f(self.__arr[idx])                  out[idx]=f(self._arr[idx])
305              out=Symbol(out)              out=Symbol(out)
306          return out          return out
307    
# Line 345  class Symbol(object): Line 366  class Symbol(object):
366          return self          return self
367    
368      def __neg__(self):      def __neg__(self):
369          return Symbol(-self.__arr)          return Symbol(-self._arr)
370    
371      def __abs__(self):      def __abs__(self):
372          return Symbol(abs(self.__arr))          return Symbol(abs(self._arr))
373    
374      def __add__(self, other):      def __add__(self, other):
375          if isinstance(other, Symbol):          if isinstance(other, Symbol):
376              return Symbol(self.__arr+other.__arr)              return Symbol(self._arr+other._arr)
377          return Symbol(self.__arr+other)          return Symbol(self._arr+other)
378    
379      def __radd__(self, other):      def __radd__(self, other):
380          if isinstance(other, Symbol):          if isinstance(other, Symbol):
381              return Symbol(other.__arr+self.__arr)              return Symbol(other._arr+self._arr)
382          return Symbol(other+self.__arr)          return Symbol(other+self._arr)
383    
384      def __sub__(self, other):      def __sub__(self, other):
385          if isinstance(other, Symbol):          if isinstance(other, Symbol):
386              return Symbol(self.__arr-other.__arr)              return Symbol(self._arr-other._arr)
387          return Symbol(self.__arr-other)          return Symbol(self._arr-other)
388    
389      def __rsub__(self, other):      def __rsub__(self, other):
390          if isinstance(other, Symbol):          if isinstance(other, Symbol):
391              return Symbol(other.__arr-self.__arr)              return Symbol(other._arr-self._arr)
392          return Symbol(other-self.__arr)          return Symbol(other-self._arr)
393    
394      def __mul__(self, other):      def __mul__(self, other):
395          if isinstance(other, Symbol):          if isinstance(other, Symbol):
396              return Symbol(self.__arr*other.__arr)              return Symbol(self._arr*other._arr)
397          return Symbol(self.__arr*other)          return Symbol(self._arr*other)
398    
399      def __rmul__(self, other):      def __rmul__(self, other):
400          if isinstance(other, Symbol):          if isinstance(other, Symbol):
401              return Symbol(other.__arr*self.__arr)              return Symbol(other._arr*self._arr)
402          return Symbol(other*self.__arr)          return Symbol(other*self._arr)
403    
404      def __div__(self, other):      def __div__(self, other):
405          if isinstance(other, Symbol):          if isinstance(other, Symbol):
406              return Symbol(self.__arr/other.__arr)              return Symbol(self._arr/other._arr)
407          return Symbol(self.__arr/other)          return Symbol(self._arr/other)
408    
409      def __rdiv__(self, other):      def __rdiv__(self, other):
410          if isinstance(other, Symbol):          if isinstance(other, Symbol):
411              return Symbol(other.__arr/self.__arr)              return Symbol(other._arr/self._arr)
412          return Symbol(other/self.__arr)          return Symbol(other/self._arr)
413    
414      def __pow__(self, other):      def __pow__(self, other):
415          if isinstance(other, Symbol):          if isinstance(other, Symbol):
416              return Symbol(self.__arr**other.__arr)              return Symbol(self._arr**other._arr)
417          return Symbol(self.__arr**other)          return Symbol(self._arr**other)
418    
419      def __rpow__(self, other):      def __rpow__(self, other):
420          if isinstance(other, Symbol):          if isinstance(other, Symbol):
421              return Symbol(other.__arr**self.__arr)              return Symbol(other._arr**self._arr)
422          return Symbol(other**self.__arr)          return Symbol(other**self._arr)
423    
424    
425  def symbols(*names, **kwargs):  def symbols(*names, **kwargs):

Legend:
Removed from v.3509  
changed lines
  Added in v.3512

  ViewVC Help
Powered by ViewVC 1.1.26