Diff of /trunk/escript/py_src/util.py

revision 122 by jgs, Thu Jun 9 05:38:05 2005 UTC revision 123 by jgs, Fri Jul 8 04:08:13 2005 UTC
# Line 4  Line 4
4
5  """  """
6  @brief Utility functions for escript  @brief Utility functions for escript
7
8    TODO for Data:
9
10      * binary operations @:               (a@b)[:,*]=a[:]@b[:,*] if rank(a)<rank(b)
11                    @=+,-,*,/,**           (a@b)[:]=a[:]@b[:] if rank(a)=rank(b)
12                                           (a@b)[*,:]=a[*,:]@b[:] if rank(a)>rank(b)
13
14      * implementation of outer outer(a,b)[:,*]=a[:]*b[*]
15      * trace: trace(arg,axis0=a0,axis1=a1)(:,&,*)=sum_i trace(:,i,&,i,*) (i are at index a0 and a1)
16
17  """  """
18
19  import numarray  import numarray
# Line 49  TIFF=7 Line 59  TIFF=7
59  OPENINVENTOR=8  OPENINVENTOR=8
60  RENDERMAN=9  RENDERMAN=9
61  PNM=10  PNM=10
62    #===========================================================
63    #  a simple tool box to deal with _differentials of functions
64    #===========================================================
65
66    class Symbol:
67       """symbol class"""
68       def __init__(self,name="symbol",shape=(),dim=3,args=[]):
69           """creates an instance of a symbol of shape shape and spatial dimension dim
70              The symbol may depending on a list of arguments args which
71              may be symbols or other objects. name gives the name of the symbol."""
72           self.__args=args
73           self.__name=name
74           self.__shape=shape
75           if hasattr(dim,"getDim"):
76               self.__dim=dim.getDim()
77           else:
78               self.__dim=dim
79           #
80           self.__cache_val=None
81           self.__cache_argval=None
82
83       def getArgument(self,i):
84           """returns the i-th argument"""
85           return self.__args[i]
86
87       def getDim(self):
88           """returns the spatial dimension of the symbol"""
89           return self.__dim
90
91       def getRank(self):
92           """returns the rank of the symbol"""
93           return len(self.getShape())
94
95       def getShape(self):
96           """returns the shape of the symbol"""
97           return self.__shape
98
99       def getEvaluatedArguments(self,argval):
100           """returns the list of evaluated arguments by subsituting symbol u by argval[u]."""
101           if argval==self.__cache_argval:
102               print "%s: cached value used"%self
103               return self.__cache_val
104           else:
105               out=[]
106               for a  in self.__args:
107                  if isinstance(a,Symbol):
108                    out.append(a.eval(argval))
109                  else:
110                    out.append(a)
111               self.__cache_argval=argval
112               self.__cache_val=out
113               return out
114
115       def getDifferentiatedArguments(self,arg):
116           """returns the list of the arguments _differentiated by arg"""
117           out=[]
118           for a in self.__args:
119              if isinstance(a,Symbol):
120                out.append(a.diff(arg))
121              else:
122                out.append(0)
123           return out
124
125  #     def diff(self,arg):
126  # wrapper for various functions: if the argument has attribute the function name         """returns the _differention of self by arg."""
127  # as an argument it calls the correspong methods. Otherwise the coresponsing numarray         if self==arg:
128  # function is called.            out=numarray.zeros(tuple(2*list(self.getShape())),numarray.Float)
129  #            if self.getRank()==0:
130  # functions involving the underlying Domain:               out=1.
131  #            elif self.getRank()==1:
132  def grad(arg,where=None):                for i0 in range(self.getShape()[0]):
133      """                   out[i0,i0]=1.
134      @brief returns the spatial gradient of the Data object arg            elif self.getRank()==2:
135                  for i0 in range(self.getShape()[0]):
136                    for i1 in range(self.getShape()[1]):
137                         out[i0,i1,i0,i1]=1.
138              elif self.getRank()==3:
139                  for i0 in range(self.getShape()[0]):
140                    for i1 in range(self.getShape()[1]):
141                      for i2 in range(self.getShape()[2]):
142                         out[i0,i1,i2,i0,i1,i2]=1.
143              elif self.getRank()==4:
144                  for i0 in range(self.getShape()[0]):
145                    for i1 in range(self.getShape()[1]):
146                      for i2 in range(self.getShape()[2]):
147                        for i3 in range(self.getShape()[3]):
148                           out[i0,i1,i2,i3,i0,i1,i2,i3]=1.
149              else:
150                 raise ValueError,"differential support rank<5 only."
151              return out
152           else:
153              return self._diff(arg)
154
155      @param arg: Data object representing the function which gradient to be calculated.     def _diff(self,arg):
156      @param where: FunctionSpace in which the gradient will be. If None Function(dom) where dom is the         """return derivate of self with respect to arg (!=self).
157                    domain of the Data object arg.            This method is overwritten by a particular symbol"""
158      """         return 0
159      if isinstance(arg,escript.Data):
160         if where==None:     def eval(self,argval):
161            return arg.grad()         """subsitutes symbol u in self by argval[u] and returns the result. If
162              self is not a key of argval then self is returned."""
163           if argval.has_key(self):
164             return argval[self]
165         else:         else:
else:
return arg*0.
167
def integrate(arg,what=None):
"""
@brief return the integral if the function represented by Data object arg over its domain.
168
169      @param arg     def __str__(self):
170      """         """returns a string representation of the symbol"""
171      if not what==None:         return self.__name
arg2=escript.Data(arg,what)
else:
arg2=arg
if arg2.getRank()==0:
return arg2.integrate()[0]
else:
return arg2.integrate()
172
174      """         """adds other to symbol self. if _testForZero(other) self is returned."""
175      @brief interpolates the function represented by Data object arg into the FunctionSpace where.         if _testForZero(other):
176              return self
177           else:
178              a=_matchShape([self,other])
180
@param arg
@param where
"""
if isinstance(arg,escript.Data):
return arg.interpolate(where)
else:
return arg
181
182  # functions returning Data objects:     def __radd__(self,other):
183           """adds other to symbol self. if _testForZero(other) self is returned."""
184           return self+other
185
186       def __neg__(self):
187           """returns -self."""
188           return self*(-1.)
189
190       def __pos__(self):
191           """returns +self."""
192           return self
193
194       def __abs__(self):
195           """returns absolute value"""
196           return Abs_Symbol(self)
197
198       def __sub__(self,other):
199           """subtracts other from symbol self. if _testForZero(other) self is returned."""
200           if _testForZero(other):
201              return self
202           else:
203              return self+(-other)
204
205  def transpose(arg,axis=None):     def __rsub__(self,other):
206      """         """subtracts symbol self from other."""
207      @brief returns the transpose of the Data object arg.         return -self+other
208
209       def __div__(self,other):
210           """divides symbol self by other."""
211           if isinstance(other,Symbol):
212              a=_matchShape([self,other])
213              return Div_Symbol(a[0],a[1])
214           else:
215              return self*(1./other)
216
217      @param arg     def __rdiv__(self,other):
218      """         """dived other by symbol self. if _testForZero(other) 0 is returned."""
219      if isinstance(arg,escript.Data):         if _testForZero(other):
220         # hack for transpose            return 0
221         r=arg.getRank()         else:
222         if r!=2: raise ValueError,"Tranpose only avalaible for rank 2 objects"            a=_matchShape([self,other])
223         s=arg.getShape()            return Div_Symbol(a[0],a[1])
out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())
for i in range(s[0]):
for j in range(s[1]):
out[j,i]=arg[i,j]
return out
# end hack for transpose
if axis==None: axis=arg.getRank()/2
return arg.transpose(axis)
else:
if axis==None: axis=arg.rank/2
return numarray.transpose(arg,axis=axis)
224
225  def trace(arg):     def __pow__(self,other):
226      """         """raises symbol self to the power of other"""
227      @brief         a=_matchShape([self,other])
228           return Power_Symbol(a[0],a[1])
229
230       def __rpow__(self,other):
231           """raises other to the symbol self"""
232           a=_matchShape([self,other])
233           return Power_Symbol(a[1],a[0])
234
235       def __mul__(self,other):
236           """multiplies other by symbol self. if _testForZero(other) 0 is returned."""
237           if _testForZero(other):
238              return 0
239           else:
240              a=_matchShape([self,other])
241              return Mult_Symbol(a[0],a[1])
242
243      @param arg     def __rmul__(self,other):
244      """         """multiplies other by symbol self. if _testSForZero(other) 0 is returned."""
245      if isinstance(arg,escript.Data):         return self*other
246         # hack for trace
247         r=arg.getRank()     def __getitem__(self,sl):
248         if r!=2: raise ValueError,"trace only avalaible for rank 2 objects"            print sl
249         s=arg.getShape()
250         out=escript.Scalar(0,arg.getFunctionSpace())  def Float_Symbol(Symbol):
251         for i in range(min(s)):      def __init__(self,name="symbol",shape=(),args=[]):
252               out+=arg[i,i]          Symbol.__init__(self,dim=0,name="symbol",shape=(),args=[])
253         return out
254         # end hack for trace  class ScalarSymbol(Symbol):
255         return arg.trace()     """a scalar symbol"""
256      else:     def __init__(self,dim=3,name="scalar"):
257         return numarray.trace(arg)        """creates a scalar symbol of spatial dimension dim"""
258          if hasattr(dim,"getDim"):
259               d=dim.getDim()
260          else:
261               d=dim
262          Symbol.__init__(self,shape=(),dim=d,name=name)
263
264    class VectorSymbol(Symbol):
265       """a vector symbol"""
266       def __init__(self,dim=3,name="vector"):
267          """creates a vector symbol of spatial dimension dim"""
268          if hasattr(dim,"getDim"):
269               d=dim.getDim()
270          else:
271               d=dim
272          Symbol.__init__(self,shape=(d,),dim=d,name=name)
273
274    class TensorSymbol(Symbol):
275       """a tensor symbol"""
276       def __init__(self,dim=3,name="tensor"):
277          """creates a tensor symbol of spatial dimension dim"""
278          if hasattr(dim,"getDim"):
279               d=dim.getDim()
280          else:
281               d=dim
282          Symbol.__init__(self,shape=(d,d),dim=d,name=name)
283
284    class Tensor3Symbol(Symbol):
285       """a tensor order 3 symbol"""
286       def __init__(self,dim=3,name="tensor3"):
287          """creates a tensor order 3 symbol of spatial dimension dim"""
288          if hasattr(dim,"getDim"):
289               d=dim.getDim()
290          else:
291               d=dim
292          Symbol.__init__(self,shape=(d,d,d),dim=d,name=name)
293
294    class Tensor4Symbol(Symbol):
295       """a tensor order 4 symbol"""
296       def __init__(self,dim=3,name="tensor4"):
297          """creates a tensor order 4 symbol of spatial dimension dim"""
298          if hasattr(dim,"getDim"):
299               d=dim.getDim()
300          else:
301               d=dim
302          Symbol.__init__(self,shape=(d,d,d,d),dim=d,name=name)
303
304
306       """symbol representing the sum of two arguments"""
307       def __init__(self,arg0,arg1):
308           a=[arg0,arg1]
309           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
310       def __str__(self):
311          return "(%s+%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
312       def eval(self,argval):
313           a=self.getEvaluatedArguments(argval)
314           return a[0]+a[1]
315       def _diff(self,arg):
316           a=self.getDifferentiatedArguments(arg)
317           return a[0]+a[1]
318
319    class Mult_Symbol(Symbol):
320       """symbol representing the product of two arguments"""
321       def __init__(self,arg0,arg1):
322           a=[arg0,arg1]
323           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
324       def __str__(self):
325          return "(%s*%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
326       def eval(self,argval):
327           a=self.getEvaluatedArguments(argval)
328           return a[0]*a[1]
329       def _diff(self,arg):
330           a=self.getDifferentiatedArguments(arg)
331           return self.getArgument(1)*a[0]+self.getArgument(0)*a[1]
332
333    class Div_Symbol(Symbol):
334       """symbol representing the quotient of two arguments"""
335       def __init__(self,arg0,arg1):
336           a=[arg0,arg1]
337           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
338       def __str__(self):
339          return "(%s/%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
340       def eval(self,argval):
341           a=self.getEvaluatedArguments(argval)
342           return a[0]/a[1]
343       def _diff(self,arg):
344           a=self.getDifferentiatedArguments(arg)
345           return (a[0]*self.getArgument(1)-self.getArgument(0)*a[1])/ \
346                              (self.getArgument(1)*self.getArgument(1))
347
348    class Power_Symbol(Symbol):
349       """symbol representing the power of the first argument to the power of the second argument"""
350       def __init__(self,arg0,arg1):
351           a=[arg0,arg1]
352           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
353       def __str__(self):
354          return "(%s**%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
355       def eval(self,argval):
356           a=self.getEvaluatedArguments(argval)
357           return a[0]**a[1]
358       def _diff(self,arg):
359           a=self.getDifferentiatedArguments(arg)
360           return self*(a[1]*log(self.getArgument(0))+self.getArgument(1)/self.getArgument(0)*a[0])
361
362    class Abs_Symbol(Symbol):
363       """symbol representing absolute value of its argument"""
364       def __init__(self,arg):
365           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
366       def __str__(self):
367          return "abs(%s)"%str(self.getArgument(0))
368       def eval(self,argval):
369           return abs(self.getEvaluatedArguments(argval)[0])
370       def _diff(self,arg):
371           return sign(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
372
373    #=========================================================
374    #   some little helpers
375    #=========================================================
376    def _testForZero(arg):
377       """returns True is arg is considered of being zero"""
378       if isinstance(arg,int):
379          return not arg>0
380       elif isinstance(arg,float):
381          return not arg>0.
382       elif isinstance(arg,numarray.NumArray):
383          a=abs(arg)
384          while isinstance(a,numarray.NumArray): a=numarray.sometrue(a)
385          return not a>0
386       else:
387          return False
388
389    def _extractDim(args):
390        dim=None
391        for a in args:
392           if hasattr(a,"getDim"):
393              d=a.getDim()
394              if dim==None:
395                 dim=d
396              else:
397                 if dim!=d: raise ValueError,"inconsistent spatial dimension of arguments"
398        if dim==None:
399           raise ValueError,"cannot recover spatial dimension"
400        return dim
401
402    def _identifyShape(arg):
403       """identifies the shape of arg."""
404       if hasattr(arg,"getShape"):
405           arg_shape=arg.getShape()
406       elif hasattr(arg,"shape"):
407         s=arg.shape
408         if callable(s):
409           arg_shape=s()
410         else:
411           arg_shape=s
412       else:
413           arg_shape=()
414       return arg_shape
415
416    def _extractShape(args):
417        """extracts the common shape of the list of arguments args"""
418        shape=None
419        for a in args:
420           a_shape=_identifyShape(a)
421           if shape==None: shape=a_shape
422           if shape!=a_shape: raise ValueError,"inconsistent shape"
423        if shape==None:
424           raise ValueError,"cannot recover shape"
425        return shape
426
427    def _matchShape(args,shape=None):
428        """returns the list of arguments args as object which have all the specified shape.
429           if shape is not given the shape "largest" shape of args is used."""
430        # identify the list of shapes:
431        arg_shapes=[]
432        for a in args: arg_shapes.append(_identifyShape(a))
433        # get the largest shape (currently the longest shape):
434        if shape==None: shape=max(arg_shapes)
435
436        out=[]
437        for i in range(len(args)):
438           if shape==arg_shapes[i]:
439              out.append(args[i])
440           else:
441              if len(shape)==0: # then len(arg_shapes[i])>0
442                raise ValueError,"cannot adopt shape of %s to %s"%(str(args[i]),str(shape))
443              else:
444                if len(arg_shapes[i])==0:
445                    out.append(outer(args[i],numarray.ones(shape)))
446                else:
447                    raise ValueError,"cannot adopt shape of %s to %s"%(str(args[i]),str(shape))
448        return out
449    #=========================================================
450    #   wrapper for various mathematical functions:
451    #=========================================================
452    def diff(arg,dep):
453        """returns the derivative of arg with respect to dep. If arg is not Symbol object
454           0 is returned"""
455        if isinstance(arg,Symbol):
456           return arg.diff(dep)
457        elif hasattr(arg,"shape"):
458              if callable(arg.shape):
459                  return numarray.zeros(arg.shape(),numarray.Float)
460              else:
461                  return numarray.zeros(arg.shape,numarray.Float)
462        else:
463           return 0
464
465  def exp(arg):  def exp(arg):
466      """      """
467      @brief      @brief applies the exponential function to arg
468        @param arg (input): argument
@param arg
469      """      """
470      if isinstance(arg,escript.Data):      if isinstance(arg,Symbol):
471           return Exp_Symbol(arg)
472        elif hasattr(arg,"exp"):
473         return arg.exp()         return arg.exp()
474      else:      else:
475         return numarray.exp(arg)         return numarray.exp(arg)
476
477    class Exp_Symbol(Symbol):
478       """symbol representing the power of the first argument to the power of the second argument"""
479       def __init__(self,arg):
480           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
481       def __str__(self):
482          return "exp(%s)"%str(self.getArgument(0))
483       def eval(self,argval):
484           return exp(self.getEvaluatedArguments(argval)[0])
485       def _diff(self,arg):
486           return self*self.getDifferentiatedArguments(arg)[0]
487
488  def sqrt(arg):  def sqrt(arg):
489      """      """
490      @brief      @brief applies the squre root function to arg
491        @param arg (input): argument
@param arg
492      """      """
493      if isinstance(arg,escript.Data):      if isinstance(arg,Symbol):
494           return Sqrt_Symbol(arg)
495        elif hasattr(arg,"sqrt"):
496         return arg.sqrt()         return arg.sqrt()
497      else:      else:
498         return numarray.sqrt(arg)         return numarray.sqrt(arg)
499
500    class Sqrt_Symbol(Symbol):
501       """symbol representing square root of argument"""
502       def __init__(self,arg):
503           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
504       def __str__(self):
505          return "sqrt(%s)"%str(self.getArgument(0))
506       def eval(self,argval):
507           return sqrt(self.getEvaluatedArguments(argval)[0])
508       def _diff(self,arg):
509           return (-0.5)/self*self.getDifferentiatedArguments(arg)[0]
510
511    def log(arg):
512        """
513        @brief applies the logarithmic function bases exp(1.) to arg
514        @param arg (input): argument
515        """
516        if isinstance(arg,Symbol):
517           return Log_Symbol(arg)
518        elif hasattr(arg,"log"):
519           return arg.log()
520        else:
521           return numarray.log(arg)
522
523    class Log_Symbol(Symbol):
524       """symbol representing logarithm of the argument"""
525       def __init__(self,arg):
526           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
527       def __str__(self):
528          return "log(%s)"%str(self.getArgument(0))
529       def eval(self,argval):
530           return log(self.getEvaluatedArguments(argval)[0])
531       def _diff(self,arg):
532           return self.getDifferentiatedArguments(arg)[0]/self.getArgument(0)
533
534  def sin(arg):  def sin(arg):
535      """      """
536      @brief      @brief applies the sinus function to arg
537        @param arg (input): argument
@param arg
538      """      """
539      if isinstance(arg,escript.Data):      if isinstance(arg,Symbol):
540           return Sin_Symbol(arg)
541        elif hasattr(arg,"sin"):
542         return arg.sin()         return arg.sin()
543      else:      else:
544         return numarray.sin(arg)         return numarray.sin(arg)
545
546  def tan(arg):  class Sin_Symbol(Symbol):
547       """symbol representing logarithm of the argument"""
548       def __init__(self,arg):
549           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
550       def __str__(self):
551          return "sin(%s)"%str(self.getArgument(0))
552       def eval(self,argval):
553           return sin(self.getEvaluatedArguments(argval)[0])
554       def _diff(self,arg):
555           return cos(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
556
557    def cos(arg):
558      """      """
559      @brief      @brief applies the sinus function to arg
560        @param arg (input): argument
561        """
562        if isinstance(arg,Symbol):
563           return Cos_Symbol(arg)
564        elif hasattr(arg,"cos"):
565           return arg.cos()
566        else:
567           return numarray.cos(arg)
568
569      @param arg  class Cos_Symbol(Symbol):
570       """symbol representing logarithm of the argument"""
571       def __init__(self,arg):
572           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
573       def __str__(self):
574          return "cos(%s)"%str(self.getArgument(0))
575       def eval(self,argval):
576           return cos(self.getEvaluatedArguments(argval)[0])
577       def _diff(self,arg):
578           return -sin(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
579
580    def tan(arg):
581      """      """
582      if isinstance(arg,escript.Data):      @brief applies the sinus function to arg
583        @param arg (input): argument
584        """
585        if isinstance(arg,Symbol):
586           return Tan_Symbol(arg)
587        elif hasattr(arg,"tan"):
588         return arg.tan()         return arg.tan()
589      else:      else:
590         return numarray.tan(arg)         return numarray.tan(arg)
591
592  def cos(arg):  class Tan_Symbol(Symbol):
593      """     """symbol representing logarithm of the argument"""
594      @brief     def __init__(self,arg):
595           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
596       def __str__(self):
597          return "tan(%s)"%str(self.getArgument(0))
598       def eval(self,argval):
599           return tan(self.getEvaluatedArguments(argval)[0])
600       def _diff(self,arg):
601           s=cos(self.getArgument(0))
602           return 1./(s*s)*self.getDifferentiatedArguments(arg)[0]
603
604      @param arg  def sign(arg):
605      """      """
606      if isinstance(arg,escript.Data):      @brief applies the sign function to arg
607         return arg.cos()      @param arg (input): argument
608        """
609        if isinstance(arg,Symbol):
610           return Sign_Symbol(arg)
611        elif hasattr(arg,"sign"):
612           return arg.sign()
613      else:      else:
614         return numarray.cos(arg)         return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))- \
615                  numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
616
617    class Sign_Symbol(Symbol):
618       """symbol representing the sign of the argument"""
619       def __init__(self,arg):
620           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
621       def __str__(self):
622          return "sign(%s)"%str(self.getArgument(0))
623       def eval(self,argval):
624           return sign(self.getEvaluatedArguments(argval)[0])
625
626  def maxval(arg):  def maxval(arg):
627      """      """
628      @brief      @brief returns the maximum value of argument arg""
629        @param arg (input): argument
@param arg
630      """      """
631      if isinstance(arg,escript.Data):      if isinstance(arg,Symbol):
632           return Max_Symbol(arg)
633        elif hasattr(arg,"maxval"):
634         return arg.maxval()         return arg.maxval()
635      elif isinstance(arg,float) or isinstance(arg,int):      elif hasattr(arg,"max"):
return arg
else:
636         return arg.max()         return arg.max()
637        else:
638           return arg
639
640    class Max_Symbol(Symbol):
641       """symbol representing the sign of the argument"""
642       def __init__(self,arg):
643           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
644       def __str__(self):
645          return "maxval(%s)"%str(self.getArgument(0))
646       def eval(self,argval):
647           return maxval(self.getEvaluatedArguments(argval)[0])
648
649  def minval(arg):  def minval(arg):
650      """      """
651      @brief      @brief returns the maximum value of argument arg""
652        @param arg (input): argument
653        """
654        if isinstance(arg,Symbol):
655           return Min_Symbol(arg)
656        elif hasattr(arg,"maxval"):
657           return arg.minval()
658        elif hasattr(arg,"min"):
659           return arg.min()
660        else:
661           return arg
662
663    class Min_Symbol(Symbol):
664       """symbol representing the sign of the argument"""
665       def __init__(self,arg):
666           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
667       def __str__(self):
668          return "minval(%s)"%str(self.getArgument(0))
669       def eval(self,argval):
670           return minval(self.getEvaluatedArguments(argval)[0])
671
672    def wherePositive(arg):
673        """
674        @brief returns the maximum value of argument arg""
675        @param arg (input): argument
676        """
677        if _testForZero(arg):
678          return 0
679        elif isinstance(arg,Symbol):
680           return WherePositive_Symbol(arg)
681        elif hasattr(arg,"wherePositive"):
682           return arg.minval()
683        elif hasattr(arg,"wherePositive"):
684           numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))
685        else:
686           if arg>0:
687              return 1.
688           else:
689              return 0.
690
691    class WherePositive_Symbol(Symbol):
692       """symbol representing the wherePositive function"""
693       def __init__(self,arg):
694           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
695       def __str__(self):
696          return "wherePositive(%s)"%str(self.getArgument(0))
697       def eval(self,argval):
698           return wherePositive(self.getEvaluatedArguments(argval)[0])
699
700    def whereNegative(arg):
701        """
702        @brief returns the maximum value of argument arg""
703        @param arg (input): argument
704        """
705        if _testForZero(arg):
706          return 0
707        elif isinstance(arg,Symbol):
708           return WhereNegative_Symbol(arg)
709        elif hasattr(arg,"whereNegative"):
710           return arg.whereNegative()
711        elif hasattr(arg,"shape"):
712           numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
713        else:
714           if arg<0:
715              return 1.
716           else:
717              return 0.
718
719    class WhereNegative_Symbol(Symbol):
720       """symbol representing the whereNegative function"""
721       def __init__(self,arg):
722           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
723       def __str__(self):
724          return "whereNegative(%s)"%str(self.getArgument(0))
725       def eval(self,argval):
726           return whereNegative(self.getEvaluatedArguments(argval)[0])
727
728    def outer(arg0,arg1):
729       if _testForZero(arg0) or _testForZero(arg1):
730          return 0
731       else:
732          if isinstance(arg0,Symbol) or isinstance(arg1,Symbol):
733            return Outer_Symbol(arg0,arg1)
734          elif _identifyShape(arg0)==() or _identifyShape(arg1)==():
735            return arg0*arg1
736          elif isinstance(arg0,numarray.NumArray) and isinstance(arg1,numarray.NumArray):
737            return numarray.outer(arg0,arg1)
738          else:
739            if arg0.getRank()==1 and arg1.getRank()==1:
740              out=escript.Data(0,(arg0.getShape()[0],arg1.getShape()[0]),arg1.getFunctionSpace())
741              for i in range(arg0.getShape()[0]):
742                for j in range(arg1.getShape()[0]):
743                    out[i,j]=arg0[i]*arg1[j]
744              return out
745            else:
746              raise ValueError,"outer is not fully implemented yet."
747
748    class Outer_Symbol(Symbol):
749       """symbol representing the outer product of its two argument"""
750       def __init__(self,arg0,arg1):
751           a=[arg0,arg1]
752           s=tuple(list(_identifyShape(arg0))+list(_identifyShape(arg1)))
753           Symbol.__init__(self,shape=s,dim=_extractDim(a),args=a)
754       def __str__(self):
755          return "outer(%s,%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
756       def eval(self,argval):
757           a=self.getEvaluatedArguments(argval)
758           return outer(a[0],a[1])
759       def _diff(self,arg):
760           a=self.getDifferentiatedArguments(arg)
761           return outer(a[0],self.getArgument(1))+outer(self.getArgument(0),a[1])
762
763    def interpolate(arg,where):
764        """
765        @brief interpolates the function into the FunctionSpace where.
766
767        @param arg    interpolant
768        @param where  FunctionSpace to interpolate to
769        """
770        if _testForZero(arg):
771          return 0
772        elif isinstance(arg,Symbol):
773           return Interpolated_Symbol(arg,where)
774        else:
775           return escript.Data(arg,where)
776
777    def Interpolated_Symbol(Symbol):
778       """symbol representing the integral of the argument"""
779       def __init__(self,arg,where):
780            Symbol.__init__(self,shape=_extractShape(arg),dim=_extractDim([arg]),args=[arg,where])
781       def __str__(self):
782          return "interpolated(%s)"%(str(self.getArgument(0)))
783       def eval(self,argval):
784           a=self.getEvaluatedArguments(argval)
785           return integrate(a[0],where=self.getArgument(1))
786       def _diff(self,arg):
787           a=self.getDifferentiatedArguments(arg)
788           return integrate(a[0],where=self.getArgument(1))
789
791        """
792        @brief returns the spatial gradient of arg at where.
793
794        @param arg:   Data object representing the function which gradient to be calculated.
795        @param where: FunctionSpace in which the gradient will be calculated. If not present or
796                      None an appropriate default is used.
797        """
798        if _testForZero(arg):
799          return 0
800        elif isinstance(arg,Symbol):
803           if where==None:
805           else:
807        else:
808           return arg*0.
809
811       """symbol representing the gradient of the argument"""
812       def __init__(self,arg,where=None):
813           d=_extractDim([arg])
814           s=tuple(list(_identifyShape([arg])).append(d))
815           Symbol.__init__(self,shape=s,dim=_extractDim([arg]),args=[arg,where])
816       def __str__(self):
818       def eval(self,argval):
819           a=self.getEvaluatedArguments(argval)
821       def _diff(self,arg):
822           a=self.getDifferentiatedArguments(arg)
824
825    def integrate(arg,where=None):
826        """
827        @brief return the integral if the function represented by Data object arg over its domain.
828
829        @param arg:   Data object representing the function which is integrated.
830        @param where: FunctionSpace in which the integral is calculated. If not present or
831                      None an appropriate default is used.
832        """
833        if _testForZero(arg):
834          return 0
835        elif isinstance(arg,Symbol):
836           return Integral_Symbol(arg,where)
837        else:
838           if not where==None: arg=escript.Data(arg,where)
839           if arg.getRank()==0:
840             return arg.integrate()[0]
841           else:
842             return arg.integrate()
843
844    def Integral_Symbol(Float_Symbol):
845       """symbol representing the integral of the argument"""
846       def __init__(self,arg,where=None):
847           Float_Symbol.__init__(self,shape=_identifyShape([arg]),args=[arg,where])
848       def __str__(self):
849          return "integral(%s)"%(str(self.getArgument(0)))
850       def eval(self,argval):
851           a=self.getEvaluatedArguments(argval)
852           return integrate(a[0],where=self.getArgument(1))
853       def _diff(self,arg):
854           a=self.getDifferentiatedArguments(arg)
855           return integrate(a[0],where=self.getArgument(1))
856
857    #=============================
858    #
859    # wrapper for various functions: if the argument has attribute the function name
860    # as an argument it calls the correspong methods. Otherwise the coresponsing numarray
861    # function is called.
862    #
863    # functions involving the underlying Domain:
864    #
865
866
867    # functions returning Data objects:
868
869    def transpose(arg,axis=None):
870        """
871        @brief returns the transpose of the Data object arg.
872
873      @param arg      @param arg
874      """      """
875        if axis==None:
876           r=0
877           if hasattr(arg,"getRank"): r=arg.getRank()
878           if hasattr(arg,"rank"): r=arg.rank
879           axis=r/2
880        if isinstance(arg,Symbol):
881           return Transpose_Symbol(arg,axis=r)
882      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
883         return arg.minval()         # hack for transpose
884      elif isinstance(arg,float) or isinstance(arg,int):         r=arg.getRank()
885         return arg         if r!=2: raise ValueError,"Tranpose only avalaible for rank 2 objects"
886           s=arg.getShape()
887           out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())
888           for i in range(s[0]):
889              for j in range(s[1]):
890                 out[j,i]=arg[i,j]
891           return out
892           # end hack for transpose
893           return arg.transpose(axis)
894      else:      else:
895         return arg.min()         return numarray.transpose(arg,axis=axis)
896
897    def trace(arg,axis0=0,axis1=1):
898        """
899        @brief return
900
901        @param arg
902        """
903        if isinstance(arg,Symbol):
904           s=list(arg.getShape())
905           s=tuple(s[0:axis0]+s[axis0+1:axis1]+s[axis1+1:])
906           return Trace_Symbol(arg,axis0=axis0,axis1=axis1)
907        elif isinstance(arg,escript.Data):
908           # hack for trace
909           s=arg.getShape()
910           if s[axis0]!=s[axis1]:
911               raise ValueError,"illegal axis in trace"
912           out=escript.Scalar(0.,arg.getFunctionSpace())
913           for i in range(s[0]):
914              for j in range(s[1]):
915                 out+=arg[i,j]
916           return out
917           # end hack for transpose
918           return arg.transpose(axis0=axis0,axis1=axis1)
919        else:
920           return numarray.trace(arg,axis0=axis0,axis1=axis1)
921
922
923
924    def Trace_Symbol(Symbol):
925        pass
926
927  def length(arg):  def length(arg):
928      """      """
# Line 272  def deviator(arg): Line 970  def deviator(arg):
970      """      """
971      @brief      @brief
972
973      @param arg1      @param arg0
974      """      """
975      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
976          shape=arg.getShape()          shape=arg.getShape()
# Line 284  def deviator(arg): Line 982  def deviator(arg):
982            raise ValueError,"Deviator requires a square matrix"            raise ValueError,"Deviator requires a square matrix"
983      return arg-1./(shape[0]*1.)*trace(arg)*kronecker(shape[0])      return arg-1./(shape[0]*1.)*trace(arg)*kronecker(shape[0])
984
985  def inner(arg1,arg2):  def inner(arg0,arg1):
986      """      """
987      @brief      @brief
988
989      @param arg1, arg2      @param arg0, arg1
990      """      """
991      sum=escript.Scalar(0,arg1.getFunctionSpace())      sum=escript.Scalar(0,arg0.getFunctionSpace())
992      if arg.getRank()==0:      if arg.getRank()==0:
993            return arg1*arg2            return arg0*arg1
994      elif arg.getRank()==1:      elif arg.getRank()==1:
995           sum=escript.Scalar(0,arg.getFunctionSpace())           sum=escript.Scalar(0,arg.getFunctionSpace())
996           for i in range(arg.getShape()[0]):           for i in range(arg.getShape()[0]):
997              sum+=arg1[i]*arg2[i]              sum+=arg0[i]*arg1[i]
998      elif arg.getRank()==2:      elif arg.getRank()==2:
999          sum=escript.Scalar(0,arg.getFunctionSpace())          sum=escript.Scalar(0,arg.getFunctionSpace())
1000          for i in range(arg.getShape()[0]):          for i in range(arg.getShape()[0]):
1001             for j in range(arg.getShape()[1]):             for j in range(arg.getShape()[1]):
1002                sum+=arg1[i,j]*arg2[i,j]                sum+=arg0[i,j]*arg1[i,j]
1003      elif arg.getRank()==3:      elif arg.getRank()==3:
1004          sum=escript.Scalar(0,arg.getFunctionSpace())          sum=escript.Scalar(0,arg.getFunctionSpace())
1005          for i in range(arg.getShape()[0]):          for i in range(arg.getShape()[0]):
1006              for j in range(arg.getShape()[1]):              for j in range(arg.getShape()[1]):
1007                 for k in range(arg.getShape()[2]):                 for k in range(arg.getShape()[2]):
1008                    sum+=arg1[i,j,k]*arg2[i,j,k]                    sum+=arg0[i,j,k]*arg1[i,j,k]
1009      elif arg.getRank()==4:      elif arg.getRank()==4:
1010          sum=escript.Scalar(0,arg.getFunctionSpace())          sum=escript.Scalar(0,arg.getFunctionSpace())
1011          for i in range(arg.getShape()[0]):          for i in range(arg.getShape()[0]):
1012             for j in range(arg.getShape()[1]):             for j in range(arg.getShape()[1]):
1013                for k in range(arg.getShape()[2]):                for k in range(arg.getShape()[2]):
1014                   for l in range(arg.getShape()[3]):                   for l in range(arg.getShape()[3]):
1015                      sum+=arg1[i,j,k,l]*arg2[i,j,k,l]                      sum+=arg0[i,j,k,l]*arg1[i,j,k,l]
1016      else:      else:
1017            raise SystemError,"inner is not been implemented yet"            raise SystemError,"inner is not been implemented yet"
1018      return sum      return sum
1019
1020  def sign(arg):  def matrixmult(arg0,arg1):
"""
@brief
1021
1022      @param arg      if isinstance(arg1,numarray.NumArray) and isinstance(arg0,numarray.NumArray):
1023      """          numarray.matrixmult(arg0,arg1)
if isinstance(arg,escript.Data):
return arg.sign()
1024      else:      else:
1025         return numarray.greater(arg,numarray.zeros(arg.shape))-numarray.less(arg,numarray.zeros(arg.shape))        # escript.matmult(arg0,arg1)
1026          if isinstance(arg1,escript.Data) and not isinstance(arg0,escript.Data):
1027            arg0=escript.Data(arg0,arg1.getFunctionSpace())
1028          elif isinstance(arg0,escript.Data) and not isinstance(arg1,escript.Data):
1029            arg1=escript.Data(arg1,arg0.getFunctionSpace())
1030          if arg0.getRank()==2 and arg1.getRank()==1:
1031              out=escript.Data(0,(arg0.getShape()[0],),arg0.getFunctionSpace())
1032              for i in range(arg0.getShape()[0]):
1033                 for j in range(arg0.getShape()[1]):
1034                   out[i]+=arg0[i,j]*arg1[j]
1035              return out
1036          else:
1037              raise SystemError,"matrixmult is not fully implemented yet!"
1038    #=========================================================
1039  # reduction operations:  # reduction operations:
1040    #=========================================================
1041  def sum(arg):  def sum(arg):
1042      """      """
1043      @brief      @brief
# Line 392  def Lsup(arg): Line 1098  def Lsup(arg):
1098      else:      else:
1099         return max(numarray.abs(arg))         return max(numarray.abs(arg))
1100
1101  def Linf(arg):  def dot(arg0,arg1):
"""
@brief

@param arg
"""
if isinstance(arg,escript.Data):
return arg.Linf()
elif isinstance(arg,float) or isinstance(arg,int):
return abs(arg)
else:
return min(numarray.abs(arg))

def dot(arg1,arg2):
1102      """      """
1103      @brief      @brief
1104
1105      @param arg      @param arg
1106      """      """
1107      if isinstance(arg1,escript.Data):      if isinstance(arg0,escript.Data):
1108         return arg1.dot(arg2)         return arg0.dot(arg1)
1109      elif isinstance(arg1,escript.Data):      elif isinstance(arg1,escript.Data):
1110         return arg2.dot(arg1)         return arg1.dot(arg0)
1111      else:      else:
1112         return numarray.dot(arg1,arg2)         return numarray.dot(arg0,arg1)
1113
1114  def kronecker(d):  def kronecker(d):
1115     if hasattr(d,"getDim"):     if hasattr(d,"getDim"):
# Line 430  def unit(i,d): Line 1123  def unit(i,d):
1123     @param d dimension     @param d dimension
1124     @param i index     @param i index
1125     """     """
1126     e = numarray.zeros((d,))     e = numarray.zeros((d,),numarray.Float)
1127     e[i] = 1.0     e[i] = 1.0
1128     return e     return e
1129
1130  #  #
1131    # ============================================
1132    #   testing
1133    # ============================================
1134
1135    if __name__=="__main__":
1136      u=ScalarSymbol(dim=2,name="u")
1137      v=ScalarSymbol(dim=2,name="v")
1138      v=VectorSymbol(2,"v")
1139      u=VectorSymbol(2,"u")
1140
1141
1142      print u+5,(u+5).diff(u)
1143      print 5+u,(5+u).diff(u)
1144      print u+v,(u+v).diff(u)
1145      print v+u,(v+u).diff(u)
1146
1147      print u*5,(u*5).diff(u)
1148      print 5*u,(5*u).diff(u)
1149      print u*v,(u*v).diff(u)
1150      print v*u,(v*u).diff(u)
1151
1152      print u-5,(u-5).diff(u)
1153      print 5-u,(5-u).diff(u)
1154      print u-v,(u-v).diff(u)
1155      print v-u,(v-u).diff(u)
1156
1157      print u/5,(u/5).diff(u)
1158      print 5/u,(5/u).diff(u)
1159      print u/v,(u/v).diff(u)
1160      print v/u,(v/u).diff(u)
1161
1162      print u**5,(u**5).diff(u)
1163      print 5**u,(5**u).diff(u)
1164      print u**v,(u**v).diff(u)
1165      print v**u,(v**u).diff(u)
1166
1167      print exp(u),exp(u).diff(u)
1168      print sqrt(u),sqrt(u).diff(u)
1169      print log(u),log(u).diff(u)
1170      print sin(u),sin(u).diff(u)
1171      print cos(u),cos(u).diff(u)
1172      print tan(u),tan(u).diff(u)
1173      print sign(u),sign(u).diff(u)
1174      print abs(u),abs(u).diff(u)
1175      print wherePositive(u),wherePositive(u).diff(u)
1176      print whereNegative(u),whereNegative(u).diff(u)
1177      print maxval(u),maxval(u).diff(u)
1178      print minval(u),minval(u).diff(u)
1179
1181      print diff(5*g,g)
1182      4*(g+transpose(g))/2+6*trace(g)*kronecker(3)
1183    #
1184  # \$Log\$  # \$Log\$
1185    # Revision 1.11  2005/07/08 04:07:35  jgs
1186    # Merge of development branch back to main trunk on 2005-07-08
1187    #
1188  # Revision 1.10  2005/06/09 05:37:59  jgs  # Revision 1.10  2005/06/09 05:37:59  jgs
1189  # Merge of development branch back to main trunk on 2005-06-09  # Merge of development branch back to main trunk on 2005-06-09
1190  #  #
1191    # Revision 1.2.2.17  2005/07/07 07:28:58  gross
1192    # some stuff added to util.py to improve functionality
1193    #
1194    # Revision 1.2.2.16  2005/06/30 01:53:55  gross
1195    # a bug in coloring fixed
1196    #
1197    # Revision 1.2.2.15  2005/06/29 02:36:43  gross
1198    # Symbols have been introduced and some function clarified. needs much more work
1199    #
1200  # Revision 1.2.2.14  2005/05/20 04:05:23  gross  # Revision 1.2.2.14  2005/05/20 04:05:23  gross
1201  # some work on a darcy flow started  # some work on a darcy flow started
1202  #  #

Legend:
 Removed from v.122 changed lines Added in v.123