/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Diff of /trunk/escript/py_src/util.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 108 by jgs, Thu Jan 27 06:21:59 2005 UTC revision 124 by jgs, Wed Jul 20 06:14:58 2005 UTC
# Line 4  Line 4 
4    
5  """  """
6  @brief Utility functions for escript  @brief Utility functions for escript
7    
8    TODO for Data:
9    
10      * binary operations @:               (a@b)[:,*]=a[:]@b[:,*] if rank(a)<rank(b)
11                    @=+,-,*,/,**           (a@b)[:]=a[:]@b[:] if rank(a)=rank(b)
12                                           (a@b)[*,:]=a[*,:]@b[:] if rank(a)>rank(b)
13      
14      * implementation of outer outer(a,b)[:,*]=a[:]*b[*]
15      * trace: trace(arg,axis0=a0,axis1=a1)(:,&,*)=sum_i trace(:,i,&,i,*) (i are at index a0 and a1)
16  """  """
17    
18  import numarray  import numarray
19  import escript  import escript
20    
21  #  #
22  #   escript constants (have to be consistent witj utilC.h  #   escript constants (have to be consistent with utilC.h )
23  #  #
24  UNKNOWN=-1  UNKNOWN=-1
25  EPSILON=1.e-15  EPSILON=1.e-15
# Line 49  TIFF=7 Line 59  TIFF=7
59  OPENINVENTOR=8  OPENINVENTOR=8
60  RENDERMAN=9  RENDERMAN=9
61  PNM=10  PNM=10
 #  
 # wrapper for various functions: if the argument has attribute the function name  
 # as an argument it calls the correspong methods. Otherwise the coresponsing numarray  
 # function is called.  
 #  
 # functions involving the underlying Domain:  
 #  
 def grad(arg,where=None):  
     """  
     @brief returns the spatial gradient of the Data object arg  
62    
63      @param arg: Data object representing the function which gradient to be calculated.  #===========================================================
64      @param where: FunctionSpace in which the gradient will be. If None Function(dom) where dom is the  # a simple tool box to deal with _differentials of functions
65                    domain of the Data object arg.  #===========================================================
66      """  
67      if isinstance(arg,escript.Data):  class Symbol:
68         if where==None:     """symbol class"""
69            return arg.grad()     def __init__(self,name="symbol",shape=(),dim=3,args=[]):
70           """creates an instance of a symbol of shape shape and spatial dimension dim
71              The symbol may depending on a list of arguments args which
72              may be symbols or other objects. name gives the name of the symbol."""
73           self.__args=args
74           self.__name=name
75           self.__shape=shape
76           if hasattr(dim,"getDim"):
77               self.__dim=dim.getDim()
78           else:    
79               self.__dim=dim
80           #
81           self.__cache_val=None
82           self.__cache_argval=None
83    
84       def getArgument(self,i):
85           """returns the i-th argument"""
86           return self.__args[i]
87    
88       def getDim(self):
89           """returns the spatial dimension of the symbol"""
90           return self.__dim
91    
92       def getRank(self):
93           """returns the rank of the symbol"""
94           return len(self.getShape())
95    
96       def getShape(self):
97           """returns the shape of the symbol"""
98           return self.__shape
99    
100       def getEvaluatedArguments(self,argval):
101           """returns the list of evaluated arguments by subsituting symbol u by argval[u]."""
102           if argval==self.__cache_argval:
103               print "%s: cached value used"%self
104               return self.__cache_val
105           else:
106               out=[]
107               for a  in self.__args:
108                  if isinstance(a,Symbol):
109                    out.append(a.eval(argval))
110                  else:
111                    out.append(a)
112               self.__cache_argval=argval
113               self.__cache_val=out
114               return out
115    
116       def getDifferentiatedArguments(self,arg):
117           """returns the list of the arguments _differentiated by arg"""
118           out=[]
119           for a in self.__args:
120              if isinstance(a,Symbol):
121                out.append(a.diff(arg))
122              else:
123                out.append(0)
124           return out
125    
126       def diff(self,arg):
127           """returns the _differention of self by arg."""
128           if self==arg:
129              out=numarray.zeros(tuple(2*list(self.getShape())),numarray.Float)
130              if self.getRank()==0:
131                 out=1.
132              elif self.getRank()==1:
133                  for i0 in range(self.getShape()[0]):
134                     out[i0,i0]=1.  
135              elif self.getRank()==2:
136                  for i0 in range(self.getShape()[0]):
137                    for i1 in range(self.getShape()[1]):
138                         out[i0,i1,i0,i1]=1.  
139              elif self.getRank()==3:
140                  for i0 in range(self.getShape()[0]):
141                    for i1 in range(self.getShape()[1]):
142                      for i2 in range(self.getShape()[2]):
143                         out[i0,i1,i2,i0,i1,i2]=1.  
144              elif self.getRank()==4:
145                  for i0 in range(self.getShape()[0]):
146                    for i1 in range(self.getShape()[1]):
147                      for i2 in range(self.getShape()[2]):
148                        for i3 in range(self.getShape()[3]):
149                           out[i0,i1,i2,i3,i0,i1,i2,i3]=1.  
150              else:
151                 raise ValueError,"differential support rank<5 only."
152              return out
153         else:         else:
154            return arg.grad(where)            return self._diff(arg)
155      else:  
156         return arg*0.     def _diff(self,arg):
157           """return derivate of self with respect to arg (!=self).
158              This method is overwritten by a particular symbol"""
159           return 0
160    
161       def eval(self,argval):
162           """subsitutes symbol u in self by argval[u] and returns the result. If
163              self is not a key of argval then self is returned."""
164           if argval.has_key(self):
165             return argval[self]
166           else:
167             return self
168    
169       def __str__(self):
170           """returns a string representation of the symbol"""
171           return self.__name
172    
173       def __add__(self,other):
174           """adds other to symbol self. if _testForZero(other) self is returned."""
175           if _testForZero(other):
176              return self
177           else:
178              a=_matchShape([self,other])
179              return Add_Symbol(a[0],a[1])
180    
181       def __radd__(self,other):
182           """adds other to symbol self. if _testForZero(other) self is returned."""
183           return self+other
184    
185       def __neg__(self):
186           """returns -self."""
187           return self*(-1.)
188    
189       def __pos__(self):
190           """returns +self."""
191           return self
192    
193       def __abs__(self):
194           """returns absolute value"""
195           return Abs_Symbol(self)
196    
197       def __sub__(self,other):
198           """subtracts other from symbol self. if _testForZero(other) self is returned."""
199           if _testForZero(other):
200              return self
201           else:
202              return self+(-other)
203    
204       def __rsub__(self,other):
205           """subtracts symbol self from other."""
206           return -self+other
207    
208       def __div__(self,other):
209           """divides symbol self by other."""
210           if isinstance(other,Symbol):
211              a=_matchShape([self,other])
212              return Div_Symbol(a[0],a[1])
213           else:
214              return self*(1./other)
215    
216       def __rdiv__(self,other):
217           """dived other by symbol self. if _testForZero(other) 0 is returned."""
218           if _testForZero(other):
219              return 0
220           else:
221              a=_matchShape([self,other])
222              return Div_Symbol(a[0],a[1])
223    
224       def __pow__(self,other):
225           """raises symbol self to the power of other"""
226           a=_matchShape([self,other])
227           return Power_Symbol(a[0],a[1])
228    
229       def __rpow__(self,other):
230           """raises other to the symbol self"""
231           a=_matchShape([self,other])
232           return Power_Symbol(a[1],a[0])
233    
234       def __mul__(self,other):
235           """multiplies other by symbol self. if _testForZero(other) 0 is returned."""
236           if _testForZero(other):
237              return 0
238           else:
239              a=_matchShape([self,other])
240              return Mult_Symbol(a[0],a[1])
241    
242  def integrate(arg,what=None):     def __rmul__(self,other):
243           """multiplies other by symbol self. if _testSForZero(other) 0 is returned."""
244           return self*other
245    
246       def __getitem__(self,sl):
247              print sl
248    
249    def Float_Symbol(Symbol):
250        def __init__(self,name="symbol",shape=(),args=[]):
251            Symbol.__init__(self,dim=0,name="symbol",shape=(),args=[])
252    
253    class ScalarSymbol(Symbol):
254       """a scalar symbol"""
255       def __init__(self,dim=3,name="scalar"):
256          """creates a scalar symbol of spatial dimension dim"""
257          if hasattr(dim,"getDim"):
258               d=dim.getDim()
259          else:    
260               d=dim
261          Symbol.__init__(self,shape=(),dim=d,name=name)
262    
263    class VectorSymbol(Symbol):
264       """a vector symbol"""
265       def __init__(self,dim=3,name="vector"):
266          """creates a vector symbol of spatial dimension dim"""
267          if hasattr(dim,"getDim"):
268               d=dim.getDim()
269          else:    
270               d=dim
271          Symbol.__init__(self,shape=(d,),dim=d,name=name)
272    
273    class TensorSymbol(Symbol):
274       """a tensor symbol"""
275       def __init__(self,dim=3,name="tensor"):
276          """creates a tensor symbol of spatial dimension dim"""
277          if hasattr(dim,"getDim"):
278               d=dim.getDim()
279          else:    
280               d=dim
281          Symbol.__init__(self,shape=(d,d),dim=d,name=name)
282    
283    class Tensor3Symbol(Symbol):
284       """a tensor order 3 symbol"""
285       def __init__(self,dim=3,name="tensor3"):
286          """creates a tensor order 3 symbol of spatial dimension dim"""
287          if hasattr(dim,"getDim"):
288               d=dim.getDim()
289          else:    
290               d=dim
291          Symbol.__init__(self,shape=(d,d,d),dim=d,name=name)
292    
293    class Tensor4Symbol(Symbol):
294       """a tensor order 4 symbol"""
295       def __init__(self,dim=3,name="tensor4"):
296          """creates a tensor order 4 symbol of spatial dimension dim"""    
297          if hasattr(dim,"getDim"):
298               d=dim.getDim()
299          else:    
300               d=dim
301          Symbol.__init__(self,shape=(d,d,d,d),dim=d,name=name)
302    
303    class Add_Symbol(Symbol):
304       """symbol representing the sum of two arguments"""
305       def __init__(self,arg0,arg1):
306           a=[arg0,arg1]
307           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
308       def __str__(self):
309          return "(%s+%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
310       def eval(self,argval):
311           a=self.getEvaluatedArguments(argval)
312           return a[0]+a[1]
313       def _diff(self,arg):
314           a=self.getDifferentiatedArguments(arg)
315           return a[0]+a[1]
316    
317    class Mult_Symbol(Symbol):
318       """symbol representing the product of two arguments"""
319       def __init__(self,arg0,arg1):
320           a=[arg0,arg1]
321           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
322       def __str__(self):
323          return "(%s*%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
324       def eval(self,argval):
325           a=self.getEvaluatedArguments(argval)
326           return a[0]*a[1]
327       def _diff(self,arg):
328           a=self.getDifferentiatedArguments(arg)
329           return self.getArgument(1)*a[0]+self.getArgument(0)*a[1]
330    
331    class Div_Symbol(Symbol):
332       """symbol representing the quotient of two arguments"""
333       def __init__(self,arg0,arg1):
334           a=[arg0,arg1]
335           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
336       def __str__(self):
337          return "(%s/%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
338       def eval(self,argval):
339           a=self.getEvaluatedArguments(argval)
340           return a[0]/a[1]
341       def _diff(self,arg):
342           a=self.getDifferentiatedArguments(arg)
343           return (a[0]*self.getArgument(1)-self.getArgument(0)*a[1])/ \
344                              (self.getArgument(1)*self.getArgument(1))
345    
346    class Power_Symbol(Symbol):
347       """symbol representing the power of the first argument to the power of the second argument"""
348       def __init__(self,arg0,arg1):
349           a=[arg0,arg1]
350           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
351       def __str__(self):
352          return "(%s**%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
353       def eval(self,argval):
354           a=self.getEvaluatedArguments(argval)
355           return a[0]**a[1]
356       def _diff(self,arg):
357           a=self.getDifferentiatedArguments(arg)
358           return self*(a[1]*log(self.getArgument(0))+self.getArgument(1)/self.getArgument(0)*a[0])
359    
360    class Abs_Symbol(Symbol):
361       """symbol representing absolute value of its argument"""
362       def __init__(self,arg):
363           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
364       def __str__(self):
365          return "abs(%s)"%str(self.getArgument(0))
366       def eval(self,argval):
367           return abs(self.getEvaluatedArguments(argval)[0])
368       def _diff(self,arg):
369           return sign(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
370    
371    #=========================================================
372    #   some little helpers
373    #=========================================================
374    def _testForZero(arg):
375       """returns True is arg is considered of being zero"""
376       if isinstance(arg,int):
377          return not arg>0
378       elif isinstance(arg,float):
379          return not arg>0.
380       elif isinstance(arg,numarray.NumArray):
381          a=abs(arg)
382          while isinstance(a,numarray.NumArray): a=numarray.sometrue(a)
383          return not a>0
384       else:
385          return False
386    
387    def _extractDim(args):
388        dim=None
389        for a in args:
390           if hasattr(a,"getDim"):
391              d=a.getDim()
392              if dim==None:
393                 dim=d
394              else:
395                 if dim!=d: raise ValueError,"inconsistent spatial dimension of arguments"
396        if dim==None:
397           raise ValueError,"cannot recover spatial dimension"
398        return dim
399    
400    def _identifyShape(arg):
401       """identifies the shape of arg."""
402       if hasattr(arg,"getShape"):
403           arg_shape=arg.getShape()
404       elif hasattr(arg,"shape"):
405         s=arg.shape
406         if callable(s):
407           arg_shape=s()
408         else:
409           arg_shape=s
410       else:
411           arg_shape=()
412       return arg_shape
413    
414    def _extractShape(args):
415        """extracts the common shape of the list of arguments args"""
416        shape=None
417        for a in args:
418           a_shape=_identifyShape(a)
419           if shape==None: shape=a_shape
420           if shape!=a_shape: raise ValueError,"inconsistent shape"
421        if shape==None:
422           raise ValueError,"cannot recover shape"
423        return shape
424    
425    def _matchShape(args,shape=None):
426        """returns the list of arguments args as object which have all the specified shape.
427           if shape is not given the shape "largest" shape of args is used."""
428        # identify the list of shapes:
429        arg_shapes=[]
430        for a in args: arg_shapes.append(_identifyShape(a))
431        # get the largest shape (currently the longest shape):
432        if shape==None: shape=max(arg_shapes)
433        
434        out=[]
435        for i in range(len(args)):
436           if shape==arg_shapes[i]:
437              out.append(args[i])
438           else:
439              if len(shape)==0: # then len(arg_shapes[i])>0
440                raise ValueError,"cannot adopt shape of %s to %s"%(str(args[i]),str(shape))
441              else:
442                if len(arg_shapes[i])==0:
443                    out.append(outer(args[i],numarray.ones(shape)))        
444                else:  
445                    raise ValueError,"cannot adopt shape of %s to %s"%(str(args[i]),str(shape))
446        return out  
447    
448    #=========================================================
449    #   wrappers for various mathematical functions:
450    #=========================================================
451    def diff(arg,dep):
452        """returns the derivative of arg with respect to dep. If arg is not Symbol object
453           0 is returned"""
454        if isinstance(arg,Symbol):
455           return arg.diff(dep)
456        elif hasattr(arg,"shape"):
457              if callable(arg.shape):
458                  return numarray.zeros(arg.shape(),numarray.Float)
459              else:
460                  return numarray.zeros(arg.shape,numarray.Float)
461        else:
462           return 0
463    
464    def exp(arg):
465      """      """
466      @brief return the integral if the function represented by Data object arg over its domain.      @brief applies the exponential function to arg
467        @param arg (input): argument
468        """
469        if isinstance(arg,Symbol):
470           return Exp_Symbol(arg)
471        elif hasattr(arg,"exp"):
472           return arg.exp()
473        else:
474           return numarray.exp(arg)
475    
476      @param arg  class Exp_Symbol(Symbol):
477       """symbol representing the power of the first argument to the power of the second argument"""
478       def __init__(self,arg):
479           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
480       def __str__(self):
481          return "exp(%s)"%str(self.getArgument(0))
482       def eval(self,argval):
483           return exp(self.getEvaluatedArguments(argval)[0])
484       def _diff(self,arg):
485           return self*self.getDifferentiatedArguments(arg)[0]
486    
487    def sqrt(arg):
488        """
489        @brief applies the squre root function to arg
490        @param arg (input): argument
491      """      """
492      if not what==None:      if isinstance(arg,Symbol):
493         arg2=escript.Data(arg,what)         return Sqrt_Symbol(arg)
494        elif hasattr(arg,"sqrt"):
495           return arg.sqrt()
496      else:      else:
497         arg2=arg         return numarray.sqrt(arg)      
498      if arg2.getRank()==0:  
499          return arg2.integrate()[0]  class Sqrt_Symbol(Symbol):
500       """symbol representing square root of argument"""
501       def __init__(self,arg):
502           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
503       def __str__(self):
504          return "sqrt(%s)"%str(self.getArgument(0))
505       def eval(self,argval):
506           return sqrt(self.getEvaluatedArguments(argval)[0])
507       def _diff(self,arg):
508           return (-0.5)/self*self.getDifferentiatedArguments(arg)[0]
509    
510    def log(arg):
511        """
512        @brief applies the logarithmic function bases exp(1.) to arg
513        @param arg (input): argument
514        """
515        if isinstance(arg,Symbol):
516           return Log_Symbol(arg)
517        elif hasattr(arg,"log"):
518           return arg.log()
519      else:      else:
520          return arg2.integrate()         return numarray.log(arg)
521    
522  def interpolate(arg,where):  class Log_Symbol(Symbol):
523      """     """symbol representing logarithm of the argument"""
524      @brief interpolates the function represented by Data object arg into the FunctionSpace where.     def __init__(self,arg):
525           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
526       def __str__(self):
527          return "log(%s)"%str(self.getArgument(0))
528       def eval(self,argval):
529           return log(self.getEvaluatedArguments(argval)[0])
530       def _diff(self,arg):
531           return self.getDifferentiatedArguments(arg)[0]/self.getArgument(0)
532    
533    def ln(arg):
534        """
535        @brief applies the natural logarithmic function to arg
536        @param arg (input): argument
537        """
538        if isinstance(arg,Symbol):
539           return Ln_Symbol(arg)
540        elif hasattr(arg,"ln"):
541           return arg.log()
542        else:
543           return numarray.log(arg)
544    
545      @param arg  class Ln_Symbol(Symbol):
546      @param where     """symbol representing natural logarithm of the argument"""
547       def __init__(self,arg):
548           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
549       def __str__(self):
550          return "ln(%s)"%str(self.getArgument(0))
551       def eval(self,argval):
552           return ln(self.getEvaluatedArguments(argval)[0])
553       def _diff(self,arg):
554           return self.getDifferentiatedArguments(arg)[0]/self.getArgument(0)
555    
556    def sin(arg):
557      """      """
558      if isinstance(arg,escript.Data):      @brief applies the sinus function to arg
559         return arg.interpolate(where)      @param arg (input): argument
560        """
561        if isinstance(arg,Symbol):
562           return Sin_Symbol(arg)
563        elif hasattr(arg,"sin"):
564           return arg.sin()
565      else:      else:
566         return arg         return numarray.sin(arg)
567    
568  # functions returning Data objects:  class Sin_Symbol(Symbol):
569       """symbol representing logarithm of the argument"""
570       def __init__(self,arg):
571           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
572       def __str__(self):
573          return "sin(%s)"%str(self.getArgument(0))
574       def eval(self,argval):
575           return sin(self.getEvaluatedArguments(argval)[0])
576       def _diff(self,arg):
577           return cos(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
578    
579  def transpose(arg,axis=None):  def cos(arg):
580      """      """
581      @brief returns the transpose of the Data object arg.      @brief applies the sinus function to arg
582        @param arg (input): argument
     @param arg  
583      """      """
584      if isinstance(arg,escript.Data):      if isinstance(arg,Symbol):
585         # hack for transpose         return Cos_Symbol(arg)
586         r=arg.getRank()      elif hasattr(arg,"cos"):
587         if r!=2: raise ValueError,"Tranpose only avalaible for rank 2 objects"         return arg.cos()
        s=arg.getShape()  
        out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())  
        for i in range(s[0]):  
           for j in range(s[1]):  
              out[j,i]=arg[i,j]  
        return out  
        # end hack for transpose  
        if axis==None: axis=arg.getRank()/2  
        return arg.transpose(axis)  
588      else:      else:
589         if axis==None: axis=arg.rank/2         return numarray.cos(arg)
        return numarray.transpose(arg,axis=axis)  
590    
591  def trace(arg):  class Cos_Symbol(Symbol):
592      """     """symbol representing logarithm of the argument"""
593      @brief     def __init__(self,arg):
594           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
595       def __str__(self):
596          return "cos(%s)"%str(self.getArgument(0))
597       def eval(self,argval):
598           return cos(self.getEvaluatedArguments(argval)[0])
599       def _diff(self,arg):
600           return -sin(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
601    
602      @param arg  def tan(arg):
603      """      """
604      if isinstance(arg,escript.Data):      @brief applies the sinus function to arg
605         # hack for trace      @param arg (input): argument
606         r=arg.getRank()      """
607         if r!=2: raise ValueError,"trace only avalaible for rank 2 objects"      if isinstance(arg,Symbol):
608         s=arg.getShape()         return Tan_Symbol(arg)
609         out=escript.Scalar(0,arg.getFunctionSpace())      elif hasattr(arg,"tan"):
610         for i in range(min(s)):         return arg.tan()
              out+=arg[i,i]  
        return out  
        # end hack for trace  
        return arg.trace()  
611      else:      else:
612         return numarray.trace(arg)         return numarray.tan(arg)
613    
614  def exp(arg):  class Tan_Symbol(Symbol):
615      """     """symbol representing logarithm of the argument"""
616      @brief     def __init__(self,arg):
617           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
618       def __str__(self):
619          return "tan(%s)"%str(self.getArgument(0))
620       def eval(self,argval):
621           return tan(self.getEvaluatedArguments(argval)[0])
622       def _diff(self,arg):
623           s=cos(self.getArgument(0))
624           return 1./(s*s)*self.getDifferentiatedArguments(arg)[0]
625    
626      @param arg  def sign(arg):
627      """      """
628      if isinstance(arg,escript.Data):      @brief applies the sign function to arg
629         return arg.exp()      @param arg (input): argument
630        """
631        if isinstance(arg,Symbol):
632           return Sign_Symbol(arg)
633        elif hasattr(arg,"sign"):
634           return arg.sign()
635      else:      else:
636         return numarray.exp(arg)         return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))- \
637                  numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
638    
639  def sqrt(arg):  class Sign_Symbol(Symbol):
640      """     """symbol representing the sign of the argument"""
641      @brief     def __init__(self,arg):
642           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
643       def __str__(self):
644          return "sign(%s)"%str(self.getArgument(0))
645       def eval(self,argval):
646           return sign(self.getEvaluatedArguments(argval)[0])
647    
648      @param arg  def maxval(arg):
649      """      """
650      if isinstance(arg,escript.Data):      @brief returns the maximum value of argument arg""
651         return arg.sqrt()      @param arg (input): argument
652        """
653        if isinstance(arg,Symbol):
654           return Max_Symbol(arg)
655        elif hasattr(arg,"maxval"):
656           return arg.maxval()
657        elif hasattr(arg,"max"):
658           return arg.max()
659      else:      else:
660         return numarray.sqrt(arg)         return arg
661    
662  def sin(arg):  class Max_Symbol(Symbol):
663      """     """symbol representing the sign of the argument"""
664      @brief     def __init__(self,arg):
665           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
666       def __str__(self):
667          return "maxval(%s)"%str(self.getArgument(0))
668       def eval(self,argval):
669           return maxval(self.getEvaluatedArguments(argval)[0])
670    
671      @param arg  def minval(arg):
672      """      """
673      if isinstance(arg,escript.Data):      @brief returns the maximum value of argument arg""
674         return arg.sin()      @param arg (input): argument
675        """
676        if isinstance(arg,Symbol):
677           return Min_Symbol(arg)
678        elif hasattr(arg,"maxval"):
679           return arg.minval()
680        elif hasattr(arg,"min"):
681           return arg.min()
682      else:      else:
683         return numarray.sin(arg)         return arg
684    
685  def tan(arg):  class Min_Symbol(Symbol):
686       """symbol representing the sign of the argument"""
687       def __init__(self,arg):
688           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
689       def __str__(self):
690          return "minval(%s)"%str(self.getArgument(0))
691       def eval(self,argval):
692           return minval(self.getEvaluatedArguments(argval)[0])
693    
694    def wherePositive(arg):
695        """
696        @brief returns the maximum value of argument arg""
697        @param arg (input): argument
698        """
699        if _testForZero(arg):
700          return 0
701        elif isinstance(arg,Symbol):
702           return WherePositive_Symbol(arg)
703        elif hasattr(arg,"wherePositive"):
704           return arg.minval()
705        elif hasattr(arg,"wherePositive"):
706           numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))
707        else:
708           if arg>0:
709              return 1.
710           else:
711              return 0.
712    
713    class WherePositive_Symbol(Symbol):
714       """symbol representing the wherePositive function"""
715       def __init__(self,arg):
716           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
717       def __str__(self):
718          return "wherePositive(%s)"%str(self.getArgument(0))
719       def eval(self,argval):
720           return wherePositive(self.getEvaluatedArguments(argval)[0])
721    
722    def whereNegative(arg):
723        """
724        @brief returns the maximum value of argument arg""
725        @param arg (input): argument
726        """
727        if _testForZero(arg):
728          return 0
729        elif isinstance(arg,Symbol):
730           return WhereNegative_Symbol(arg)
731        elif hasattr(arg,"whereNegative"):
732           return arg.whereNegative()
733        elif hasattr(arg,"shape"):
734           numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
735        else:
736           if arg<0:
737              return 1.
738           else:
739              return 0.
740    
741    class WhereNegative_Symbol(Symbol):
742       """symbol representing the whereNegative function"""
743       def __init__(self,arg):
744           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
745       def __str__(self):
746          return "whereNegative(%s)"%str(self.getArgument(0))
747       def eval(self,argval):
748           return whereNegative(self.getEvaluatedArguments(argval)[0])
749    
750    def outer(arg0,arg1):
751       if _testForZero(arg0) or _testForZero(arg1):
752          return 0
753       else:
754          if isinstance(arg0,Symbol) or isinstance(arg1,Symbol):
755            return Outer_Symbol(arg0,arg1)
756          elif _identifyShape(arg0)==() or _identifyShape(arg1)==():
757            return arg0*arg1
758          elif isinstance(arg0,numarray.NumArray) and isinstance(arg1,numarray.NumArray):
759            return numarray.outer(arg0,arg1)
760          else:
761            if arg0.getRank()==1 and arg1.getRank()==1:
762              out=escript.Data(0,(arg0.getShape()[0],arg1.getShape()[0]),arg1.getFunctionSpace())
763              for i in range(arg0.getShape()[0]):
764                for j in range(arg1.getShape()[0]):
765                    out[i,j]=arg0[i]*arg1[j]
766              return out
767            else:
768              raise ValueError,"outer is not fully implemented yet."
769    
770    class Outer_Symbol(Symbol):
771       """symbol representing the outer product of its two argument"""
772       def __init__(self,arg0,arg1):
773           a=[arg0,arg1]
774           s=tuple(list(_identifyShape(arg0))+list(_identifyShape(arg1)))
775           Symbol.__init__(self,shape=s,dim=_extractDim(a),args=a)
776       def __str__(self):
777          return "outer(%s,%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
778       def eval(self,argval):
779           a=self.getEvaluatedArguments(argval)
780           return outer(a[0],a[1])
781       def _diff(self,arg):
782           a=self.getDifferentiatedArguments(arg)
783           return outer(a[0],self.getArgument(1))+outer(self.getArgument(0),a[1])
784    
785    def interpolate(arg,where):
786      """      """
787      @brief      @brief interpolates the function into the FunctionSpace where.
788    
789      @param arg      @param arg    interpolant
790        @param where  FunctionSpace to interpolate to
791      """      """
792      if isinstance(arg,escript.Data):      if _testForZero(arg):
793         return arg.tan()        return 0
794        elif isinstance(arg,Symbol):
795           return Interpolated_Symbol(arg,where)
796      else:      else:
797         return numarray.tan(arg)         return escript.Data(arg,where)
798    
799  def cos(arg):  def Interpolated_Symbol(Symbol):
800      """     """symbol representing the integral of the argument"""
801      @brief     def __init__(self,arg,where):
802            Symbol.__init__(self,shape=_extractShape(arg),dim=_extractDim([arg]),args=[arg,where])
803       def __str__(self):
804          return "interpolated(%s)"%(str(self.getArgument(0)))
805       def eval(self,argval):
806           a=self.getEvaluatedArguments(argval)
807           return integrate(a[0],where=self.getArgument(1))
808       def _diff(self,arg):
809           a=self.getDifferentiatedArguments(arg)
810           return integrate(a[0],where=self.getArgument(1))
811    
812      @param arg  def grad(arg,where=None):
813      """      """
814      if isinstance(arg,escript.Data):      @brief returns the spatial gradient of arg at where.
815         return arg.cos()  
816        @param arg:   Data object representing the function which gradient to be calculated.
817        @param where: FunctionSpace in which the gradient will be calculated. If not present or
818                      None an appropriate default is used.
819        """
820        if _testForZero(arg):
821          return 0
822        elif isinstance(arg,Symbol):
823           return Grad_Symbol(arg,where)
824        elif hasattr(arg,"grad"):
825           if where==None:
826              return arg.grad()
827           else:
828              return arg.grad(where)
829      else:      else:
830         return numarray.cos(arg)         return arg*0.
831    
832  def maxval(arg):  def Grad_Symbol(Symbol):
833       """symbol representing the gradient of the argument"""
834       def __init__(self,arg,where=None):
835           d=_extractDim([arg])
836           s=tuple(list(_identifyShape([arg])).append(d))
837           Symbol.__init__(self,shape=s,dim=_extractDim([arg]),args=[arg,where])
838       def __str__(self):
839          return "grad(%s)"%(str(self.getArgument(0)))
840       def eval(self,argval):
841           a=self.getEvaluatedArguments(argval)
842           return grad(a[0],where=self.getArgument(1))
843       def _diff(self,arg):
844           a=self.getDifferentiatedArguments(arg)
845           return grad(a[0],where=self.getArgument(1))
846    
847    def integrate(arg,where=None):
848      """      """
849      @brief      @brief return the integral if the function represented by Data object arg over its domain.
850    
851        @param arg:   Data object representing the function which is integrated.
852        @param where: FunctionSpace in which the integral is calculated. If not present or
853                      None an appropriate default is used.
854        """
855        if _testForZero(arg):
856          return 0
857        elif isinstance(arg,Symbol):
858           return Integral_Symbol(arg,where)
859        else:    
860           if not where==None: arg=escript.Data(arg,where)
861           if arg.getRank()==0:
862             return arg.integrate()[0]
863           else:
864             return arg.integrate()
865    
866    def Integral_Symbol(Float_Symbol):
867       """symbol representing the integral of the argument"""
868       def __init__(self,arg,where=None):
869           Float_Symbol.__init__(self,shape=_identifyShape([arg]),args=[arg,where])
870       def __str__(self):
871          return "integral(%s)"%(str(self.getArgument(0)))
872       def eval(self,argval):
873           a=self.getEvaluatedArguments(argval)
874           return integrate(a[0],where=self.getArgument(1))
875       def _diff(self,arg):
876           a=self.getDifferentiatedArguments(arg)
877           return integrate(a[0],where=self.getArgument(1))
878    
879    #=============================
880    #
881    # wrapper for various functions: if the argument has attribute the function name
882    # as an argument it calls the corresponding methods. Otherwise the corresponding
883    # numarray function is called.
884    
885    # functions involving the underlying Domain:
886    
887    
888    # functions returning Data objects:
889    
890    def transpose(arg,axis=None):
891        """
892        @brief returns the transpose of the Data object arg.
893    
894      @param arg      @param arg
895      """      """
896        if axis==None:
897           r=0
898           if hasattr(arg,"getRank"): r=arg.getRank()
899           if hasattr(arg,"rank"): r=arg.rank
900           axis=r/2
901        if isinstance(arg,Symbol):
902           return Transpose_Symbol(arg,axis=r)
903      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
904         return arg.maxval()         # hack for transpose
905      elif isinstance(arg,float) or isinstance(arg,int):         r=arg.getRank()
906         return arg         if r!=2: raise ValueError,"Tranpose only avalaible for rank 2 objects"
907           s=arg.getShape()
908           out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())
909           for i in range(s[0]):
910              for j in range(s[1]):
911                 out[j,i]=arg[i,j]
912           return out
913           # end hack for transpose
914           return arg.transpose(axis)
915      else:      else:
916         return arg.max()         return numarray.transpose(arg,axis=axis)
917    
918  def minval(arg):  def trace(arg,axis0=0,axis1=1):
919      """      """
920      @brief      @brief return
921    
922      @param arg      @param arg
923      """      """
924      if isinstance(arg,escript.Data):      if isinstance(arg,Symbol):
925         return arg.minval()         s=list(arg.getShape())        
926      elif isinstance(arg,float) or isinstance(arg,int):         s=tuple(s[0:axis0]+s[axis0+1:axis1]+s[axis1+1:])
927         return arg         return Trace_Symbol(arg,axis0=axis0,axis1=axis1)
928        elif isinstance(arg,escript.Data):
929           # hack for trace
930           s=arg.getShape()
931           if s[axis0]!=s[axis1]:
932               raise ValueError,"illegal axis in trace"
933           out=escript.Scalar(0.,arg.getFunctionSpace())
934           for i in range(s[0]):
935              for j in range(s[1]):
936                 out+=arg[i,j]
937           return out
938           # end hack for transpose
939           return arg.transpose(axis0=axis0,axis1=axis1)
940      else:      else:
941         return arg.min()         return numarray.trace(arg,axis0=axis0,axis1=axis1)
942    
943    def Trace_Symbol(Symbol):
944        pass
945    
946  def length(arg):  def length(arg):
947      """      """
# Line 267  def length(arg): Line 985  def length(arg):
985      else:      else:
986         return sqrt((arg**2).sum())         return sqrt((arg**2).sum())
987    
988  def sign(arg):  def deviator(arg):
989      """      """
990      @brief      @brief
991    
992      @param arg      @param arg0
993      """      """
994      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
995         return arg.sign()          shape=arg.getShape()
996      else:      else:
997         return numarray.greater(arg,numarray.zeros(arg.shape))-numarray.less(arg,numarray.zeros(arg.shape))          shape=arg.shape
998        if len(shape)!=2:
999              raise ValueError,"Deviator requires rank 2 object"
1000        if shape[0]!=shape[1]:
1001              raise ValueError,"Deviator requires a square matrix"
1002        return arg-1./(shape[0]*1.)*trace(arg)*kronecker(shape[0])
1003    
1004  # reduction operations:  def inner(arg0,arg1):
1005        """
1006        @brief
1007    
1008        @param arg0, arg1
1009        """
1010        sum=escript.Scalar(0,arg0.getFunctionSpace())
1011        if arg.getRank()==0:
1012              return arg0*arg1
1013        elif arg.getRank()==1:
1014             sum=escript.Scalar(0,arg.getFunctionSpace())
1015             for i in range(arg.getShape()[0]):
1016                sum+=arg0[i]*arg1[i]
1017        elif arg.getRank()==2:
1018            sum=escript.Scalar(0,arg.getFunctionSpace())
1019            for i in range(arg.getShape()[0]):
1020               for j in range(arg.getShape()[1]):
1021                  sum+=arg0[i,j]*arg1[i,j]
1022        elif arg.getRank()==3:
1023            sum=escript.Scalar(0,arg.getFunctionSpace())
1024            for i in range(arg.getShape()[0]):
1025                for j in range(arg.getShape()[1]):
1026                   for k in range(arg.getShape()[2]):
1027                      sum+=arg0[i,j,k]*arg1[i,j,k]
1028        elif arg.getRank()==4:
1029            sum=escript.Scalar(0,arg.getFunctionSpace())
1030            for i in range(arg.getShape()[0]):
1031               for j in range(arg.getShape()[1]):
1032                  for k in range(arg.getShape()[2]):
1033                     for l in range(arg.getShape()[3]):
1034                        sum+=arg0[i,j,k,l]*arg1[i,j,k,l]
1035        else:
1036              raise SystemError,"inner is not been implemented yet"
1037        return sum
1038    
1039    def matrixmult(arg0,arg1):
1040    
1041        if isinstance(arg1,numarray.NumArray) and isinstance(arg0,numarray.NumArray):
1042            numarray.matrixmult(arg0,arg1)
1043        else:
1044          # escript.matmult(arg0,arg1)
1045          if isinstance(arg1,escript.Data) and not isinstance(arg0,escript.Data):
1046            arg0=escript.Data(arg0,arg1.getFunctionSpace())
1047          elif isinstance(arg0,escript.Data) and not isinstance(arg1,escript.Data):
1048            arg1=escript.Data(arg1,arg0.getFunctionSpace())
1049          if arg0.getRank()==2 and arg1.getRank()==1:
1050              out=escript.Data(0,(arg0.getShape()[0],),arg0.getFunctionSpace())
1051              for i in range(arg0.getShape()[0]):
1052                 for j in range(arg0.getShape()[1]):
1053                   out[i]+=arg0[i,j]*arg1[j]
1054              return out
1055          else:
1056              raise SystemError,"matrixmult is not fully implemented yet!"
1057    
1058    #=========================================================
1059    # reduction operations:
1060    #=========================================================
1061  def sum(arg):  def sum(arg):
1062      """      """
1063      @brief      @brief
# Line 340  def Lsup(arg): Line 1118  def Lsup(arg):
1118      else:      else:
1119         return max(numarray.abs(arg))         return max(numarray.abs(arg))
1120    
1121  def dot(arg1,arg2):  def dot(arg0,arg1):
1122      """      """
1123      @brief      @brief
1124    
1125      @param arg      @param arg
1126      """      """
1127      if isinstance(arg1,escript.Data):      if isinstance(arg0,escript.Data):
1128         return arg1.dot(arg2)         return arg0.dot(arg1)
1129      elif isinstance(arg1,escript.Data):      elif isinstance(arg1,escript.Data):
1130         return arg2.dot(arg1)         return arg1.dot(arg0)
1131      else:      else:
1132         return numarray.dot(arg1,arg2)         return numarray.dot(arg0,arg1)
1133    
1134    def kronecker(d):
1135       if hasattr(d,"getDim"):
1136          return numarray.identity(d.getDim())
1137       else:
1138          return numarray.identity(d)
1139    
1140    def unit(i,d):
1141       """
1142       @brief return a unit vector of dimension d with nonzero index i
1143       @param d dimension
1144       @param i index
1145       """
1146       e = numarray.zeros((d,),numarray.Float)
1147       e[i] = 1.0
1148       return e
1149    
1150    # ============================================
1151    #   testing
1152    # ============================================
1153    
1154    if __name__=="__main__":
1155      u=ScalarSymbol(dim=2,name="u")
1156      v=ScalarSymbol(dim=2,name="v")
1157      v=VectorSymbol(2,"v")
1158      u=VectorSymbol(2,"u")
1159    
1160      print u+5,(u+5).diff(u)
1161      print 5+u,(5+u).diff(u)
1162      print u+v,(u+v).diff(u)
1163      print v+u,(v+u).diff(u)
1164    
1165      print u*5,(u*5).diff(u)
1166      print 5*u,(5*u).diff(u)
1167      print u*v,(u*v).diff(u)
1168      print v*u,(v*u).diff(u)
1169    
1170      print u-5,(u-5).diff(u)
1171      print 5-u,(5-u).diff(u)
1172      print u-v,(u-v).diff(u)
1173      print v-u,(v-u).diff(u)
1174    
1175      print u/5,(u/5).diff(u)
1176      print 5/u,(5/u).diff(u)
1177      print u/v,(u/v).diff(u)
1178      print v/u,(v/u).diff(u)
1179    
1180      print u**5,(u**5).diff(u)
1181      print 5**u,(5**u).diff(u)
1182      print u**v,(u**v).diff(u)
1183      print v**u,(v**u).diff(u)
1184    
1185      print exp(u),exp(u).diff(u)
1186      print sqrt(u),sqrt(u).diff(u)
1187      print log(u),log(u).diff(u)
1188      print sin(u),sin(u).diff(u)
1189      print cos(u),cos(u).diff(u)
1190      print tan(u),tan(u).diff(u)
1191      print sign(u),sign(u).diff(u)
1192      print abs(u),abs(u).diff(u)
1193      print wherePositive(u),wherePositive(u).diff(u)
1194      print whereNegative(u),whereNegative(u).diff(u)
1195      print maxval(u),maxval(u).diff(u)
1196      print minval(u),minval(u).diff(u)
1197    
1198      g=grad(u)
1199      print diff(5*g,g)
1200      4*(g+transpose(g))/2+6*trace(g)*kronecker(3)
1201    
1202    #
1203    # $Log$
1204    # Revision 1.12  2005/07/20 06:14:58  jgs
1205    # added ln(data) style wrapper for data.ln() - also added corresponding
1206    # implementation of Ln_Symbol class (not sure if this is right though)
1207    #
1208    # Revision 1.11  2005/07/08 04:07:35  jgs
1209    # Merge of development branch back to main trunk on 2005-07-08
1210    #
1211    # Revision 1.10  2005/06/09 05:37:59  jgs
1212    # Merge of development branch back to main trunk on 2005-06-09
1213    #
1214    # Revision 1.2.2.17  2005/07/07 07:28:58  gross
1215    # some stuff added to util.py to improve functionality
1216    #
1217    # Revision 1.2.2.16  2005/06/30 01:53:55  gross
1218    # a bug in coloring fixed
1219    #
1220    # Revision 1.2.2.15  2005/06/29 02:36:43  gross
1221    # Symbols have been introduced and some function clarified. needs much more work
1222    #
1223    # Revision 1.2.2.14  2005/05/20 04:05:23  gross
1224    # some work on a darcy flow started
1225    #
1226    # Revision 1.2.2.13  2005/03/16 05:17:58  matt
1227    # Implemented unit(idx, dim) to create cartesian unit basis vectors to
1228    # complement kronecker(dim) function.
1229    #
1230    # Revision 1.2.2.12  2005/03/10 08:14:37  matt
1231    # Added non-member Linf utility function to complement Data::Linf().
1232    #
1233    # Revision 1.2.2.11  2005/02/17 05:53:25  gross
1234    # some bug in saveDX fixed: in fact the bug was in
1235    # DataC/getDataPointShape
1236    #
1237    # Revision 1.2.2.10  2005/01/11 04:59:36  gross
1238    # automatic interpolation in integrate switched off
1239    #
1240    # Revision 1.2.2.9  2005/01/11 03:38:13  gross
1241    # Bug in Data.integrate() fixed for the case of rank 0. The problem is not totallly resolved as the method should return a scalar rather than a numarray object in the case of rank 0. This problem is fixed by the util.integrate wrapper.
1242    #
1243    # Revision 1.2.2.8  2005/01/05 04:21:41  gross
1244    # FunctionSpace checking/matchig in slicing added
1245    #
1246    # Revision 1.2.2.7  2004/12/29 05:29:59  gross
1247    # AdvectivePDE successfully tested for Peclet number 1000000. there is still a problem with setValue and Data()
1248    #
1249    # Revision 1.2.2.6  2004/12/24 06:05:41  gross
1250    # some changes in linearPDEs to add AdevectivePDE
1251    #
1252    # Revision 1.2.2.5  2004/12/17 00:06:53  gross
1253    # mk sets ESYS_ROOT is undefined
1254    #
1255    # Revision 1.2.2.4  2004/12/07 03:19:51  gross
1256    # options for GMRES and PRES20 added
1257    #
1258    # Revision 1.2.2.3  2004/12/06 04:55:18  gross
1259    # function wraper extended
1260    #
1261    # Revision 1.2.2.2  2004/11/22 05:44:07  gross
1262    # a few more unitary functions have been added but not implemented in Data yet
1263    #
1264    # Revision 1.2.2.1  2004/11/12 06:58:15  gross
1265    # a lot of changes to get the linearPDE class running: most important change is that there is no matrix format exposed to the user anymore. the format is chosen by the Domain according to the solver and symmetry
1266    #
1267    # Revision 1.2  2004/10/27 00:23:36  jgs
1268    # fixed minor syntax error
1269    #
1270    # Revision 1.1.1.1  2004/10/26 06:53:56  jgs
1271    # initial import of project esys2
1272    #
1273    # Revision 1.1.2.3  2004/10/26 06:43:48  jgs
1274    # committing Lutz's and Paul's changes to brach jgs
1275    #
1276    # Revision 1.1.4.1  2004/10/20 05:32:51  cochrane
1277    # Added incomplete Doxygen comments to files, or merely put the docstrings that already exist into Doxygen form.
1278    #
1279    # Revision 1.1  2004/08/05 03:58:27  gross
1280    # Bug in Assemble_NodeCoordinates fixed
1281    #
1282    #

Legend:
Removed from v.108  
changed lines
  Added in v.124

  ViewVC Help
Powered by ViewVC 1.1.26