# Diff of /trunk/escript/py_src/util.py

revision 108 by jgs, Thu Jan 27 06:21:59 2005 UTC revision 123 by jgs, Fri Jul 8 04:08:13 2005 UTC
# Line 4  Line 4
4
5  """  """
6  @brief Utility functions for escript  @brief Utility functions for escript
7
8    TODO for Data:
9
10      * binary operations @:               (a@b)[:,*]=a[:]@b[:,*] if rank(a)<rank(b)
11                    @=+,-,*,/,**           (a@b)[:]=a[:]@b[:] if rank(a)=rank(b)
12                                           (a@b)[*,:]=a[*,:]@b[:] if rank(a)>rank(b)
13
14      * implementation of outer outer(a,b)[:,*]=a[:]*b[*]
15      * trace: trace(arg,axis0=a0,axis1=a1)(:,&,*)=sum_i trace(:,i,&,i,*) (i are at index a0 and a1)
16
17  """  """
18
19  import numarray  import numarray
# Line 49  TIFF=7 Line 59  TIFF=7
59  OPENINVENTOR=8  OPENINVENTOR=8
60  RENDERMAN=9  RENDERMAN=9
61  PNM=10  PNM=10
62  #  #===========================================================
63  # wrapper for various functions: if the argument has attribute the function name  #  a simple tool box to deal with _differentials of functions
64  # as an argument it calls the correspong methods. Otherwise the coresponsing numarray  #===========================================================
65  # function is called.
66  #  class Symbol:
67  # functions involving the underlying Domain:     """symbol class"""
68  #     def __init__(self,name="symbol",shape=(),dim=3,args=[]):
69  def grad(arg,where=None):         """creates an instance of a symbol of shape shape and spatial dimension dim
70      """            The symbol may depending on a list of arguments args which
71      @brief returns the spatial gradient of the Data object arg            may be symbols or other objects. name gives the name of the symbol."""
72           self.__args=args
73           self.__name=name
74           self.__shape=shape
75           if hasattr(dim,"getDim"):
76               self.__dim=dim.getDim()
77           else:
78               self.__dim=dim
79           #
80           self.__cache_val=None
81           self.__cache_argval=None
82
83       def getArgument(self,i):
84           """returns the i-th argument"""
85           return self.__args[i]
86
87       def getDim(self):
88           """returns the spatial dimension of the symbol"""
89           return self.__dim
90
91       def getRank(self):
92           """returns the rank of the symbol"""
93           return len(self.getShape())
94
95       def getShape(self):
96           """returns the shape of the symbol"""
97           return self.__shape
98
99       def getEvaluatedArguments(self,argval):
100           """returns the list of evaluated arguments by subsituting symbol u by argval[u]."""
101           if argval==self.__cache_argval:
102               print "%s: cached value used"%self
103               return self.__cache_val
104           else:
105               out=[]
106               for a  in self.__args:
107                  if isinstance(a,Symbol):
108                    out.append(a.eval(argval))
109                  else:
110                    out.append(a)
111               self.__cache_argval=argval
112               self.__cache_val=out
113               return out
114
115       def getDifferentiatedArguments(self,arg):
116           """returns the list of the arguments _differentiated by arg"""
117           out=[]
118           for a in self.__args:
119              if isinstance(a,Symbol):
120                out.append(a.diff(arg))
121              else:
122                out.append(0)
123           return out
124
125      @param arg: Data object representing the function which gradient to be calculated.     def diff(self,arg):
126      @param where: FunctionSpace in which the gradient will be. If None Function(dom) where dom is the         """returns the _differention of self by arg."""
127                    domain of the Data object arg.         if self==arg:
128      """            out=numarray.zeros(tuple(2*list(self.getShape())),numarray.Float)
129      if isinstance(arg,escript.Data):            if self.getRank()==0:
130         if where==None:               out=1.
132                  for i0 in range(self.getShape()[0]):
133                     out[i0,i0]=1.
134              elif self.getRank()==2:
135                  for i0 in range(self.getShape()[0]):
136                    for i1 in range(self.getShape()[1]):
137                         out[i0,i1,i0,i1]=1.
138              elif self.getRank()==3:
139                  for i0 in range(self.getShape()[0]):
140                    for i1 in range(self.getShape()[1]):
141                      for i2 in range(self.getShape()[2]):
142                         out[i0,i1,i2,i0,i1,i2]=1.
143              elif self.getRank()==4:
144                  for i0 in range(self.getShape()[0]):
145                    for i1 in range(self.getShape()[1]):
146                      for i2 in range(self.getShape()[2]):
147                        for i3 in range(self.getShape()[3]):
148                           out[i0,i1,i2,i3,i0,i1,i2,i3]=1.
149              else:
150                 raise ValueError,"differential support rank<5 only."
151              return out
152         else:         else:
else:
return arg*0.
154
155  def integrate(arg,what=None):     def _diff(self,arg):
156      """         """return derivate of self with respect to arg (!=self).
157      @brief return the integral if the function represented by Data object arg over its domain.            This method is overwritten by a particular symbol"""
158           return 0
159
160       def eval(self,argval):
161           """subsitutes symbol u in self by argval[u] and returns the result. If
162              self is not a key of argval then self is returned."""
163           if argval.has_key(self):
164             return argval[self]
165           else:
166             return self
167
@param arg
"""
if not what==None:
arg2=escript.Data(arg,what)
else:
arg2=arg
if arg2.getRank()==0:
return arg2.integrate()[0]
else:
return arg2.integrate()
168
169  def interpolate(arg,where):     def __str__(self):
170      """         """returns a string representation of the symbol"""
171      @brief interpolates the function represented by Data object arg into the FunctionSpace where.         return self.__name
172
174      @param where         """adds other to symbol self. if _testForZero(other) self is returned."""
175      """         if _testForZero(other):
176      if isinstance(arg,escript.Data):            return self
177         return arg.interpolate(where)         else:
178      else:            a=_matchShape([self,other])
180
# functions returning Data objects:
181
183      """         """adds other to symbol self. if _testForZero(other) self is returned."""
184      @brief returns the transpose of the Data object arg.         return self+other
185
186       def __neg__(self):
187           """returns -self."""
188           return self*(-1.)
189
190       def __pos__(self):
191           """returns +self."""
192           return self
193
194       def __abs__(self):
195           """returns absolute value"""
196           return Abs_Symbol(self)
197
198       def __sub__(self,other):
199           """subtracts other from symbol self. if _testForZero(other) self is returned."""
200           if _testForZero(other):
201              return self
202           else:
203              return self+(-other)
204
205      @param arg     def __rsub__(self,other):
206      """         """subtracts symbol self from other."""
207      if isinstance(arg,escript.Data):         return -self+other
208         # hack for transpose
209         r=arg.getRank()     def __div__(self,other):
210         if r!=2: raise ValueError,"Tranpose only avalaible for rank 2 objects"         """divides symbol self by other."""
211         s=arg.getShape()         if isinstance(other,Symbol):
212         out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())            a=_matchShape([self,other])
213         for i in range(s[0]):            return Div_Symbol(a[0],a[1])
214            for j in range(s[1]):         else:
215               out[j,i]=arg[i,j]            return self*(1./other)
return out
# end hack for transpose
if axis==None: axis=arg.getRank()/2
return arg.transpose(axis)
else:
if axis==None: axis=arg.rank/2
return numarray.transpose(arg,axis=axis)
216
217  def trace(arg):     def __rdiv__(self,other):
218      """         """dived other by symbol self. if _testForZero(other) 0 is returned."""
219      @brief         if _testForZero(other):
220              return 0
221           else:
222              a=_matchShape([self,other])
223              return Div_Symbol(a[0],a[1])
224
225      @param arg     def __pow__(self,other):
226      """         """raises symbol self to the power of other"""
227      if isinstance(arg,escript.Data):         a=_matchShape([self,other])
228         # hack for trace         return Power_Symbol(a[0],a[1])
229         r=arg.getRank()
230         if r!=2: raise ValueError,"trace only avalaible for rank 2 objects"     def __rpow__(self,other):
231         s=arg.getShape()         """raises other to the symbol self"""
232         out=escript.Scalar(0,arg.getFunctionSpace())         a=_matchShape([self,other])
233         for i in range(min(s)):         return Power_Symbol(a[1],a[0])
234               out+=arg[i,i]
235         return out     def __mul__(self,other):
236         # end hack for trace         """multiplies other by symbol self. if _testForZero(other) 0 is returned."""
237         return arg.trace()         if _testForZero(other):
238      else:            return 0
239         return numarray.trace(arg)         else:
240              a=_matchShape([self,other])
241              return Mult_Symbol(a[0],a[1])
242
243       def __rmul__(self,other):
244           """multiplies other by symbol self. if _testSForZero(other) 0 is returned."""
245           return self*other
246
247       def __getitem__(self,sl):
248              print sl
249
250    def Float_Symbol(Symbol):
251        def __init__(self,name="symbol",shape=(),args=[]):
252            Symbol.__init__(self,dim=0,name="symbol",shape=(),args=[])
253
254    class ScalarSymbol(Symbol):
255       """a scalar symbol"""
256       def __init__(self,dim=3,name="scalar"):
257          """creates a scalar symbol of spatial dimension dim"""
258          if hasattr(dim,"getDim"):
259               d=dim.getDim()
260          else:
261               d=dim
262          Symbol.__init__(self,shape=(),dim=d,name=name)
263
264    class VectorSymbol(Symbol):
265       """a vector symbol"""
266       def __init__(self,dim=3,name="vector"):
267          """creates a vector symbol of spatial dimension dim"""
268          if hasattr(dim,"getDim"):
269               d=dim.getDim()
270          else:
271               d=dim
272          Symbol.__init__(self,shape=(d,),dim=d,name=name)
273
274    class TensorSymbol(Symbol):
275       """a tensor symbol"""
276       def __init__(self,dim=3,name="tensor"):
277          """creates a tensor symbol of spatial dimension dim"""
278          if hasattr(dim,"getDim"):
279               d=dim.getDim()
280          else:
281               d=dim
282          Symbol.__init__(self,shape=(d,d),dim=d,name=name)
283
284    class Tensor3Symbol(Symbol):
285       """a tensor order 3 symbol"""
286       def __init__(self,dim=3,name="tensor3"):
287          """creates a tensor order 3 symbol of spatial dimension dim"""
288          if hasattr(dim,"getDim"):
289               d=dim.getDim()
290          else:
291               d=dim
292          Symbol.__init__(self,shape=(d,d,d),dim=d,name=name)
293
294    class Tensor4Symbol(Symbol):
295       """a tensor order 4 symbol"""
296       def __init__(self,dim=3,name="tensor4"):
297          """creates a tensor order 4 symbol of spatial dimension dim"""
298          if hasattr(dim,"getDim"):
299               d=dim.getDim()
300          else:
301               d=dim
302          Symbol.__init__(self,shape=(d,d,d,d),dim=d,name=name)
303
304
306       """symbol representing the sum of two arguments"""
307       def __init__(self,arg0,arg1):
308           a=[arg0,arg1]
309           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
310       def __str__(self):
311          return "(%s+%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
312       def eval(self,argval):
313           a=self.getEvaluatedArguments(argval)
314           return a[0]+a[1]
315       def _diff(self,arg):
316           a=self.getDifferentiatedArguments(arg)
317           return a[0]+a[1]
318
319    class Mult_Symbol(Symbol):
320       """symbol representing the product of two arguments"""
321       def __init__(self,arg0,arg1):
322           a=[arg0,arg1]
323           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
324       def __str__(self):
325          return "(%s*%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
326       def eval(self,argval):
327           a=self.getEvaluatedArguments(argval)
328           return a[0]*a[1]
329       def _diff(self,arg):
330           a=self.getDifferentiatedArguments(arg)
331           return self.getArgument(1)*a[0]+self.getArgument(0)*a[1]
332
333    class Div_Symbol(Symbol):
334       """symbol representing the quotient of two arguments"""
335       def __init__(self,arg0,arg1):
336           a=[arg0,arg1]
337           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
338       def __str__(self):
339          return "(%s/%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
340       def eval(self,argval):
341           a=self.getEvaluatedArguments(argval)
342           return a[0]/a[1]
343       def _diff(self,arg):
344           a=self.getDifferentiatedArguments(arg)
345           return (a[0]*self.getArgument(1)-self.getArgument(0)*a[1])/ \
346                              (self.getArgument(1)*self.getArgument(1))
347
348    class Power_Symbol(Symbol):
349       """symbol representing the power of the first argument to the power of the second argument"""
350       def __init__(self,arg0,arg1):
351           a=[arg0,arg1]
352           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
353       def __str__(self):
354          return "(%s**%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
355       def eval(self,argval):
356           a=self.getEvaluatedArguments(argval)
357           return a[0]**a[1]
358       def _diff(self,arg):
359           a=self.getDifferentiatedArguments(arg)
360           return self*(a[1]*log(self.getArgument(0))+self.getArgument(1)/self.getArgument(0)*a[0])
361
362    class Abs_Symbol(Symbol):
363       """symbol representing absolute value of its argument"""
364       def __init__(self,arg):
365           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
366       def __str__(self):
367          return "abs(%s)"%str(self.getArgument(0))
368       def eval(self,argval):
369           return abs(self.getEvaluatedArguments(argval)[0])
370       def _diff(self,arg):
371           return sign(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
372
373    #=========================================================
374    #   some little helpers
375    #=========================================================
376    def _testForZero(arg):
377       """returns True is arg is considered of being zero"""
378       if isinstance(arg,int):
379          return not arg>0
380       elif isinstance(arg,float):
381          return not arg>0.
382       elif isinstance(arg,numarray.NumArray):
383          a=abs(arg)
384          while isinstance(a,numarray.NumArray): a=numarray.sometrue(a)
385          return not a>0
386       else:
387          return False
388
389    def _extractDim(args):
390        dim=None
391        for a in args:
392           if hasattr(a,"getDim"):
393              d=a.getDim()
394              if dim==None:
395                 dim=d
396              else:
397                 if dim!=d: raise ValueError,"inconsistent spatial dimension of arguments"
398        if dim==None:
399           raise ValueError,"cannot recover spatial dimension"
400        return dim
401
402    def _identifyShape(arg):
403       """identifies the shape of arg."""
404       if hasattr(arg,"getShape"):
405           arg_shape=arg.getShape()
406       elif hasattr(arg,"shape"):
407         s=arg.shape
408         if callable(s):
409           arg_shape=s()
410         else:
411           arg_shape=s
412       else:
413           arg_shape=()
414       return arg_shape
415
416    def _extractShape(args):
417        """extracts the common shape of the list of arguments args"""
418        shape=None
419        for a in args:
420           a_shape=_identifyShape(a)
421           if shape==None: shape=a_shape
422           if shape!=a_shape: raise ValueError,"inconsistent shape"
423        if shape==None:
424           raise ValueError,"cannot recover shape"
425        return shape
426
427    def _matchShape(args,shape=None):
428        """returns the list of arguments args as object which have all the specified shape.
429           if shape is not given the shape "largest" shape of args is used."""
430        # identify the list of shapes:
431        arg_shapes=[]
432        for a in args: arg_shapes.append(_identifyShape(a))
433        # get the largest shape (currently the longest shape):
434        if shape==None: shape=max(arg_shapes)
435
436        out=[]
437        for i in range(len(args)):
438           if shape==arg_shapes[i]:
439              out.append(args[i])
440           else:
441              if len(shape)==0: # then len(arg_shapes[i])>0
442                raise ValueError,"cannot adopt shape of %s to %s"%(str(args[i]),str(shape))
443              else:
444                if len(arg_shapes[i])==0:
445                    out.append(outer(args[i],numarray.ones(shape)))
446                else:
447                    raise ValueError,"cannot adopt shape of %s to %s"%(str(args[i]),str(shape))
448        return out
449    #=========================================================
450    #   wrapper for various mathematical functions:
451    #=========================================================
452    def diff(arg,dep):
453        """returns the derivative of arg with respect to dep. If arg is not Symbol object
454           0 is returned"""
455        if isinstance(arg,Symbol):
456           return arg.diff(dep)
457        elif hasattr(arg,"shape"):
458              if callable(arg.shape):
459                  return numarray.zeros(arg.shape(),numarray.Float)
460              else:
461                  return numarray.zeros(arg.shape,numarray.Float)
462        else:
463           return 0
464
465  def exp(arg):  def exp(arg):
466      """      """
467      @brief      @brief applies the exponential function to arg
468        @param arg (input): argument
@param arg
469      """      """
470      if isinstance(arg,escript.Data):      if isinstance(arg,Symbol):
471           return Exp_Symbol(arg)
472        elif hasattr(arg,"exp"):
473         return arg.exp()         return arg.exp()
474      else:      else:
475         return numarray.exp(arg)         return numarray.exp(arg)
476
477    class Exp_Symbol(Symbol):
478       """symbol representing the power of the first argument to the power of the second argument"""
479       def __init__(self,arg):
480           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
481       def __str__(self):
482          return "exp(%s)"%str(self.getArgument(0))
483       def eval(self,argval):
484           return exp(self.getEvaluatedArguments(argval)[0])
485       def _diff(self,arg):
486           return self*self.getDifferentiatedArguments(arg)[0]
487
488  def sqrt(arg):  def sqrt(arg):
489      """      """
490      @brief      @brief applies the squre root function to arg
491        @param arg (input): argument
@param arg
492      """      """
493      if isinstance(arg,escript.Data):      if isinstance(arg,Symbol):
494           return Sqrt_Symbol(arg)
495        elif hasattr(arg,"sqrt"):
496         return arg.sqrt()         return arg.sqrt()
497      else:      else:
498         return numarray.sqrt(arg)         return numarray.sqrt(arg)
499
500    class Sqrt_Symbol(Symbol):
501       """symbol representing square root of argument"""
502       def __init__(self,arg):
503           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
504       def __str__(self):
505          return "sqrt(%s)"%str(self.getArgument(0))
506       def eval(self,argval):
507           return sqrt(self.getEvaluatedArguments(argval)[0])
508       def _diff(self,arg):
509           return (-0.5)/self*self.getDifferentiatedArguments(arg)[0]
510
511    def log(arg):
512        """
513        @brief applies the logarithmic function bases exp(1.) to arg
514        @param arg (input): argument
515        """
516        if isinstance(arg,Symbol):
517           return Log_Symbol(arg)
518        elif hasattr(arg,"log"):
519           return arg.log()
520        else:
521           return numarray.log(arg)
522
523    class Log_Symbol(Symbol):
524       """symbol representing logarithm of the argument"""
525       def __init__(self,arg):
526           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
527       def __str__(self):
528          return "log(%s)"%str(self.getArgument(0))
529       def eval(self,argval):
530           return log(self.getEvaluatedArguments(argval)[0])
531       def _diff(self,arg):
532           return self.getDifferentiatedArguments(arg)[0]/self.getArgument(0)
533
534  def sin(arg):  def sin(arg):
535      """      """
536      @brief      @brief applies the sinus function to arg
537        @param arg (input): argument
@param arg
538      """      """
539      if isinstance(arg,escript.Data):      if isinstance(arg,Symbol):
540           return Sin_Symbol(arg)
541        elif hasattr(arg,"sin"):
542         return arg.sin()         return arg.sin()
543      else:      else:
544         return numarray.sin(arg)         return numarray.sin(arg)
545
546  def tan(arg):  class Sin_Symbol(Symbol):
547       """symbol representing logarithm of the argument"""
548       def __init__(self,arg):
549           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
550       def __str__(self):
551          return "sin(%s)"%str(self.getArgument(0))
552       def eval(self,argval):
553           return sin(self.getEvaluatedArguments(argval)[0])
554       def _diff(self,arg):
555           return cos(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
556
557    def cos(arg):
558      """      """
559      @brief      @brief applies the sinus function to arg
560        @param arg (input): argument
561        """
562        if isinstance(arg,Symbol):
563           return Cos_Symbol(arg)
564        elif hasattr(arg,"cos"):
565           return arg.cos()
566        else:
567           return numarray.cos(arg)
568
569      @param arg  class Cos_Symbol(Symbol):
570       """symbol representing logarithm of the argument"""
571       def __init__(self,arg):
572           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
573       def __str__(self):
574          return "cos(%s)"%str(self.getArgument(0))
575       def eval(self,argval):
576           return cos(self.getEvaluatedArguments(argval)[0])
577       def _diff(self,arg):
578           return -sin(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
579
580    def tan(arg):
581      """      """
582      if isinstance(arg,escript.Data):      @brief applies the sinus function to arg
583        @param arg (input): argument
584        """
585        if isinstance(arg,Symbol):
586           return Tan_Symbol(arg)
587        elif hasattr(arg,"tan"):
588         return arg.tan()         return arg.tan()
589      else:      else:
590         return numarray.tan(arg)         return numarray.tan(arg)
591
592  def cos(arg):  class Tan_Symbol(Symbol):
593      """     """symbol representing logarithm of the argument"""
594      @brief     def __init__(self,arg):
595           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
596       def __str__(self):
597          return "tan(%s)"%str(self.getArgument(0))
598       def eval(self,argval):
599           return tan(self.getEvaluatedArguments(argval)[0])
600       def _diff(self,arg):
601           s=cos(self.getArgument(0))
602           return 1./(s*s)*self.getDifferentiatedArguments(arg)[0]
603
604      @param arg  def sign(arg):
605      """      """
606      if isinstance(arg,escript.Data):      @brief applies the sign function to arg
607         return arg.cos()      @param arg (input): argument
608        """
609        if isinstance(arg,Symbol):
610           return Sign_Symbol(arg)
611        elif hasattr(arg,"sign"):
612           return arg.sign()
613      else:      else:
614         return numarray.cos(arg)         return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))- \
615                  numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
616
617    class Sign_Symbol(Symbol):
618       """symbol representing the sign of the argument"""
619       def __init__(self,arg):
620           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
621       def __str__(self):
622          return "sign(%s)"%str(self.getArgument(0))
623       def eval(self,argval):
624           return sign(self.getEvaluatedArguments(argval)[0])
625
626  def maxval(arg):  def maxval(arg):
627      """      """
628      @brief      @brief returns the maximum value of argument arg""
629        @param arg (input): argument
@param arg
630      """      """
631      if isinstance(arg,escript.Data):      if isinstance(arg,Symbol):
632           return Max_Symbol(arg)
633        elif hasattr(arg,"maxval"):
634         return arg.maxval()         return arg.maxval()
635      elif isinstance(arg,float) or isinstance(arg,int):      elif hasattr(arg,"max"):
return arg
else:
636         return arg.max()         return arg.max()
637        else:
638           return arg
639
640    class Max_Symbol(Symbol):
641       """symbol representing the sign of the argument"""
642       def __init__(self,arg):
643           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
644       def __str__(self):
645          return "maxval(%s)"%str(self.getArgument(0))
646       def eval(self,argval):
647           return maxval(self.getEvaluatedArguments(argval)[0])
648
649  def minval(arg):  def minval(arg):
650      """      """
651      @brief      @brief returns the maximum value of argument arg""
652        @param arg (input): argument
653        """
654        if isinstance(arg,Symbol):
655           return Min_Symbol(arg)
656        elif hasattr(arg,"maxval"):
657           return arg.minval()
658        elif hasattr(arg,"min"):
659           return arg.min()
660        else:
661           return arg
662
663    class Min_Symbol(Symbol):
664       """symbol representing the sign of the argument"""
665       def __init__(self,arg):
666           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
667       def __str__(self):
668          return "minval(%s)"%str(self.getArgument(0))
669       def eval(self,argval):
670           return minval(self.getEvaluatedArguments(argval)[0])
671
672    def wherePositive(arg):
673        """
674        @brief returns the maximum value of argument arg""
675        @param arg (input): argument
676        """
677        if _testForZero(arg):
678          return 0
679        elif isinstance(arg,Symbol):
680           return WherePositive_Symbol(arg)
681        elif hasattr(arg,"wherePositive"):
682           return arg.minval()
683        elif hasattr(arg,"wherePositive"):
684           numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))
685        else:
686           if arg>0:
687              return 1.
688           else:
689              return 0.
690
691    class WherePositive_Symbol(Symbol):
692       """symbol representing the wherePositive function"""
693       def __init__(self,arg):
694           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
695       def __str__(self):
696          return "wherePositive(%s)"%str(self.getArgument(0))
697       def eval(self,argval):
698           return wherePositive(self.getEvaluatedArguments(argval)[0])
699
700    def whereNegative(arg):
701        """
702        @brief returns the maximum value of argument arg""
703        @param arg (input): argument
704        """
705        if _testForZero(arg):
706          return 0
707        elif isinstance(arg,Symbol):
708           return WhereNegative_Symbol(arg)
709        elif hasattr(arg,"whereNegative"):
710           return arg.whereNegative()
711        elif hasattr(arg,"shape"):
712           numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
713        else:
714           if arg<0:
715              return 1.
716           else:
717              return 0.
718
719    class WhereNegative_Symbol(Symbol):
720       """symbol representing the whereNegative function"""
721       def __init__(self,arg):
722           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
723       def __str__(self):
724          return "whereNegative(%s)"%str(self.getArgument(0))
725       def eval(self,argval):
726           return whereNegative(self.getEvaluatedArguments(argval)[0])
727
728    def outer(arg0,arg1):
729       if _testForZero(arg0) or _testForZero(arg1):
730          return 0
731       else:
732          if isinstance(arg0,Symbol) or isinstance(arg1,Symbol):
733            return Outer_Symbol(arg0,arg1)
734          elif _identifyShape(arg0)==() or _identifyShape(arg1)==():
735            return arg0*arg1
736          elif isinstance(arg0,numarray.NumArray) and isinstance(arg1,numarray.NumArray):
737            return numarray.outer(arg0,arg1)
738          else:
739            if arg0.getRank()==1 and arg1.getRank()==1:
740              out=escript.Data(0,(arg0.getShape()[0],arg1.getShape()[0]),arg1.getFunctionSpace())
741              for i in range(arg0.getShape()[0]):
742                for j in range(arg1.getShape()[0]):
743                    out[i,j]=arg0[i]*arg1[j]
744              return out
745            else:
746              raise ValueError,"outer is not fully implemented yet."
747
748    class Outer_Symbol(Symbol):
749       """symbol representing the outer product of its two argument"""
750       def __init__(self,arg0,arg1):
751           a=[arg0,arg1]
752           s=tuple(list(_identifyShape(arg0))+list(_identifyShape(arg1)))
753           Symbol.__init__(self,shape=s,dim=_extractDim(a),args=a)
754       def __str__(self):
755          return "outer(%s,%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
756       def eval(self,argval):
757           a=self.getEvaluatedArguments(argval)
758           return outer(a[0],a[1])
759       def _diff(self,arg):
760           a=self.getDifferentiatedArguments(arg)
761           return outer(a[0],self.getArgument(1))+outer(self.getArgument(0),a[1])
762
763    def interpolate(arg,where):
764        """
765        @brief interpolates the function into the FunctionSpace where.
766
767        @param arg    interpolant
768        @param where  FunctionSpace to interpolate to
769        """
770        if _testForZero(arg):
771          return 0
772        elif isinstance(arg,Symbol):
773           return Interpolated_Symbol(arg,where)
774        else:
775           return escript.Data(arg,where)
776
777    def Interpolated_Symbol(Symbol):
778       """symbol representing the integral of the argument"""
779       def __init__(self,arg,where):
780            Symbol.__init__(self,shape=_extractShape(arg),dim=_extractDim([arg]),args=[arg,where])
781       def __str__(self):
782          return "interpolated(%s)"%(str(self.getArgument(0)))
783       def eval(self,argval):
784           a=self.getEvaluatedArguments(argval)
785           return integrate(a[0],where=self.getArgument(1))
786       def _diff(self,arg):
787           a=self.getDifferentiatedArguments(arg)
788           return integrate(a[0],where=self.getArgument(1))
789
791        """
792        @brief returns the spatial gradient of arg at where.
793
794        @param arg:   Data object representing the function which gradient to be calculated.
795        @param where: FunctionSpace in which the gradient will be calculated. If not present or
796                      None an appropriate default is used.
797        """
798        if _testForZero(arg):
799          return 0
800        elif isinstance(arg,Symbol):
803           if where==None:
805           else:
807        else:
808           return arg*0.
809
811       """symbol representing the gradient of the argument"""
812       def __init__(self,arg,where=None):
813           d=_extractDim([arg])
814           s=tuple(list(_identifyShape([arg])).append(d))
815           Symbol.__init__(self,shape=s,dim=_extractDim([arg]),args=[arg,where])
816       def __str__(self):
818       def eval(self,argval):
819           a=self.getEvaluatedArguments(argval)
821       def _diff(self,arg):
822           a=self.getDifferentiatedArguments(arg)
824
825    def integrate(arg,where=None):
826        """
827        @brief return the integral if the function represented by Data object arg over its domain.
828
829        @param arg:   Data object representing the function which is integrated.
830        @param where: FunctionSpace in which the integral is calculated. If not present or
831                      None an appropriate default is used.
832        """
833        if _testForZero(arg):
834          return 0
835        elif isinstance(arg,Symbol):
836           return Integral_Symbol(arg,where)
837        else:
838           if not where==None: arg=escript.Data(arg,where)
839           if arg.getRank()==0:
840             return arg.integrate()[0]
841           else:
842             return arg.integrate()
843
844    def Integral_Symbol(Float_Symbol):
845       """symbol representing the integral of the argument"""
846       def __init__(self,arg,where=None):
847           Float_Symbol.__init__(self,shape=_identifyShape([arg]),args=[arg,where])
848       def __str__(self):
849          return "integral(%s)"%(str(self.getArgument(0)))
850       def eval(self,argval):
851           a=self.getEvaluatedArguments(argval)
852           return integrate(a[0],where=self.getArgument(1))
853       def _diff(self,arg):
854           a=self.getDifferentiatedArguments(arg)
855           return integrate(a[0],where=self.getArgument(1))
856
857    #=============================
858    #
859    # wrapper for various functions: if the argument has attribute the function name
860    # as an argument it calls the correspong methods. Otherwise the coresponsing numarray
861    # function is called.
862    #
863    # functions involving the underlying Domain:
864    #
865
866
867    # functions returning Data objects:
868
869    def transpose(arg,axis=None):
870        """
871        @brief returns the transpose of the Data object arg.
872
873      @param arg      @param arg
874      """      """
875        if axis==None:
876           r=0
877           if hasattr(arg,"getRank"): r=arg.getRank()
878           if hasattr(arg,"rank"): r=arg.rank
879           axis=r/2
880        if isinstance(arg,Symbol):
881           return Transpose_Symbol(arg,axis=r)
882      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
883         return arg.minval()         # hack for transpose
884      elif isinstance(arg,float) or isinstance(arg,int):         r=arg.getRank()
885         return arg         if r!=2: raise ValueError,"Tranpose only avalaible for rank 2 objects"
886           s=arg.getShape()
887           out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())
888           for i in range(s[0]):
889              for j in range(s[1]):
890                 out[j,i]=arg[i,j]
891           return out
892           # end hack for transpose
893           return arg.transpose(axis)
894      else:      else:
895         return arg.min()         return numarray.transpose(arg,axis=axis)
896
897    def trace(arg,axis0=0,axis1=1):
898        """
899        @brief return
900
901        @param arg
902        """
903        if isinstance(arg,Symbol):
904           s=list(arg.getShape())
905           s=tuple(s[0:axis0]+s[axis0+1:axis1]+s[axis1+1:])
906           return Trace_Symbol(arg,axis0=axis0,axis1=axis1)
907        elif isinstance(arg,escript.Data):
908           # hack for trace
909           s=arg.getShape()
910           if s[axis0]!=s[axis1]:
911               raise ValueError,"illegal axis in trace"
912           out=escript.Scalar(0.,arg.getFunctionSpace())
913           for i in range(s[0]):
914              for j in range(s[1]):
915                 out+=arg[i,j]
916           return out
917           # end hack for transpose
918           return arg.transpose(axis0=axis0,axis1=axis1)
919        else:
920           return numarray.trace(arg,axis0=axis0,axis1=axis1)
921
922
923
924    def Trace_Symbol(Symbol):
925        pass
926
927  def length(arg):  def length(arg):
928      """      """
# Line 267  def length(arg): Line 966  def length(arg):
966      else:      else:
967         return sqrt((arg**2).sum())         return sqrt((arg**2).sum())
968
969  def sign(arg):  def deviator(arg):
970      """      """
971      @brief      @brief
972
973      @param arg      @param arg0
974      """      """
975      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
976         return arg.sign()          shape=arg.getShape()
977      else:      else:
978         return numarray.greater(arg,numarray.zeros(arg.shape))-numarray.less(arg,numarray.zeros(arg.shape))          shape=arg.shape
979        if len(shape)!=2:
980              raise ValueError,"Deviator requires rank 2 object"
981        if shape[0]!=shape[1]:
982              raise ValueError,"Deviator requires a square matrix"
983        return arg-1./(shape[0]*1.)*trace(arg)*kronecker(shape[0])
984
985  # reduction operations:  def inner(arg0,arg1):
986        """
987        @brief
988
989        @param arg0, arg1
990        """
991        sum=escript.Scalar(0,arg0.getFunctionSpace())
992        if arg.getRank()==0:
993              return arg0*arg1
994        elif arg.getRank()==1:
995             sum=escript.Scalar(0,arg.getFunctionSpace())
996             for i in range(arg.getShape()[0]):
997                sum+=arg0[i]*arg1[i]
998        elif arg.getRank()==2:
999            sum=escript.Scalar(0,arg.getFunctionSpace())
1000            for i in range(arg.getShape()[0]):
1001               for j in range(arg.getShape()[1]):
1002                  sum+=arg0[i,j]*arg1[i,j]
1003        elif arg.getRank()==3:
1004            sum=escript.Scalar(0,arg.getFunctionSpace())
1005            for i in range(arg.getShape()[0]):
1006                for j in range(arg.getShape()[1]):
1007                   for k in range(arg.getShape()[2]):
1008                      sum+=arg0[i,j,k]*arg1[i,j,k]
1009        elif arg.getRank()==4:
1010            sum=escript.Scalar(0,arg.getFunctionSpace())
1011            for i in range(arg.getShape()[0]):
1012               for j in range(arg.getShape()[1]):
1013                  for k in range(arg.getShape()[2]):
1014                     for l in range(arg.getShape()[3]):
1015                        sum+=arg0[i,j,k,l]*arg1[i,j,k,l]
1016        else:
1017              raise SystemError,"inner is not been implemented yet"
1018        return sum
1019
1020    def matrixmult(arg0,arg1):
1021
1022        if isinstance(arg1,numarray.NumArray) and isinstance(arg0,numarray.NumArray):
1023            numarray.matrixmult(arg0,arg1)
1024        else:
1025          # escript.matmult(arg0,arg1)
1026          if isinstance(arg1,escript.Data) and not isinstance(arg0,escript.Data):
1027            arg0=escript.Data(arg0,arg1.getFunctionSpace())
1028          elif isinstance(arg0,escript.Data) and not isinstance(arg1,escript.Data):
1029            arg1=escript.Data(arg1,arg0.getFunctionSpace())
1030          if arg0.getRank()==2 and arg1.getRank()==1:
1031              out=escript.Data(0,(arg0.getShape()[0],),arg0.getFunctionSpace())
1032              for i in range(arg0.getShape()[0]):
1033                 for j in range(arg0.getShape()[1]):
1034                   out[i]+=arg0[i,j]*arg1[j]
1035              return out
1036          else:
1037              raise SystemError,"matrixmult is not fully implemented yet!"
1038    #=========================================================
1039    # reduction operations:
1040    #=========================================================
1041  def sum(arg):  def sum(arg):
1042      """      """
1043      @brief      @brief
# Line 340  def Lsup(arg): Line 1098  def Lsup(arg):
1098      else:      else:
1099         return max(numarray.abs(arg))         return max(numarray.abs(arg))
1100
1101  def dot(arg1,arg2):  def dot(arg0,arg1):
1102      """      """
1103      @brief      @brief
1104
1105      @param arg      @param arg
1106      """      """
1107      if isinstance(arg1,escript.Data):      if isinstance(arg0,escript.Data):
1108         return arg1.dot(arg2)         return arg0.dot(arg1)
1109      elif isinstance(arg1,escript.Data):      elif isinstance(arg1,escript.Data):
1110         return arg2.dot(arg1)         return arg1.dot(arg0)
1111      else:      else:
1112         return numarray.dot(arg1,arg2)         return numarray.dot(arg0,arg1)
1113
1114    def kronecker(d):
1115       if hasattr(d,"getDim"):
1116          return numarray.identity(d.getDim())
1117       else:
1118          return numarray.identity(d)
1119
1120    def unit(i,d):
1121       """
1122       @brief return a unit vector of dimension d with nonzero index i
1123       @param d dimension
1124       @param i index
1125       """
1126       e = numarray.zeros((d,),numarray.Float)
1127       e[i] = 1.0
1128       return e
1129
1130    #
1131    # ============================================
1132    #   testing
1133    # ============================================
1134
1135    if __name__=="__main__":
1136      u=ScalarSymbol(dim=2,name="u")
1137      v=ScalarSymbol(dim=2,name="v")
1138      v=VectorSymbol(2,"v")
1139      u=VectorSymbol(2,"u")
1140
1141
1142      print u+5,(u+5).diff(u)
1143      print 5+u,(5+u).diff(u)
1144      print u+v,(u+v).diff(u)
1145      print v+u,(v+u).diff(u)
1146
1147      print u*5,(u*5).diff(u)
1148      print 5*u,(5*u).diff(u)
1149      print u*v,(u*v).diff(u)
1150      print v*u,(v*u).diff(u)
1151
1152      print u-5,(u-5).diff(u)
1153      print 5-u,(5-u).diff(u)
1154      print u-v,(u-v).diff(u)
1155      print v-u,(v-u).diff(u)
1156
1157      print u/5,(u/5).diff(u)
1158      print 5/u,(5/u).diff(u)
1159      print u/v,(u/v).diff(u)
1160      print v/u,(v/u).diff(u)
1161
1162      print u**5,(u**5).diff(u)
1163      print 5**u,(5**u).diff(u)
1164      print u**v,(u**v).diff(u)
1165      print v**u,(v**u).diff(u)
1166
1167      print exp(u),exp(u).diff(u)
1168      print sqrt(u),sqrt(u).diff(u)
1169      print log(u),log(u).diff(u)
1170      print sin(u),sin(u).diff(u)
1171      print cos(u),cos(u).diff(u)
1172      print tan(u),tan(u).diff(u)
1173      print sign(u),sign(u).diff(u)
1174      print abs(u),abs(u).diff(u)
1175      print wherePositive(u),wherePositive(u).diff(u)
1176      print whereNegative(u),whereNegative(u).diff(u)
1177      print maxval(u),maxval(u).diff(u)
1178      print minval(u),minval(u).diff(u)
1179
1181      print diff(5*g,g)
1182      4*(g+transpose(g))/2+6*trace(g)*kronecker(3)
1183    #
1184    # \$Log\$
1185    # Revision 1.11  2005/07/08 04:07:35  jgs
1186    # Merge of development branch back to main trunk on 2005-07-08
1187    #
1188    # Revision 1.10  2005/06/09 05:37:59  jgs
1189    # Merge of development branch back to main trunk on 2005-06-09
1190    #
1191    # Revision 1.2.2.17  2005/07/07 07:28:58  gross
1192    # some stuff added to util.py to improve functionality
1193    #
1194    # Revision 1.2.2.16  2005/06/30 01:53:55  gross
1195    # a bug in coloring fixed
1196    #
1197    # Revision 1.2.2.15  2005/06/29 02:36:43  gross
1198    # Symbols have been introduced and some function clarified. needs much more work
1199    #
1200    # Revision 1.2.2.14  2005/05/20 04:05:23  gross
1201    # some work on a darcy flow started
1202    #
1203    # Revision 1.2.2.13  2005/03/16 05:17:58  matt
1204    # Implemented unit(idx, dim) to create cartesian unit basis vectors to
1205    # complement kronecker(dim) function.
1206    #
1207    # Revision 1.2.2.12  2005/03/10 08:14:37  matt
1208    # Added non-member Linf utility function to complement Data::Linf().
1209    #
1210    # Revision 1.2.2.11  2005/02/17 05:53:25  gross
1211    # some bug in saveDX fixed: in fact the bug was in
1212    # DataC/getDataPointShape
1213    #
1214    # Revision 1.2.2.10  2005/01/11 04:59:36  gross
1215    # automatic interpolation in integrate switched off
1216    #
1217    # Revision 1.2.2.9  2005/01/11 03:38:13  gross
1218    # Bug in Data.integrate() fixed for the case of rank 0. The problem is not totallly resolved as the method should return a scalar rather than a numarray object in the case of rank 0. This problem is fixed by the util.integrate wrapper.
1219    #
1220    # Revision 1.2.2.8  2005/01/05 04:21:41  gross
1221    # FunctionSpace checking/matchig in slicing added
1222    #
1223    # Revision 1.2.2.7  2004/12/29 05:29:59  gross
1224    # AdvectivePDE successfully tested for Peclet number 1000000. there is still a problem with setValue and Data()
1225    #
1226    # Revision 1.2.2.6  2004/12/24 06:05:41  gross
1228    #
1229    # Revision 1.2.2.5  2004/12/17 00:06:53  gross
1230    # mk sets ESYS_ROOT is undefined
1231    #
1232    # Revision 1.2.2.4  2004/12/07 03:19:51  gross
1233    # options for GMRES and PRES20 added
1234    #
1235    # Revision 1.2.2.3  2004/12/06 04:55:18  gross
1236    # function wraper extended
1237    #
1238    # Revision 1.2.2.2  2004/11/22 05:44:07  gross
1239    # a few more unitary functions have been added but not implemented in Data yet
1240    #
1241    # Revision 1.2.2.1  2004/11/12 06:58:15  gross
1242    # a lot of changes to get the linearPDE class running: most important change is that there is no matrix format exposed to the user anymore. the format is chosen by the Domain according to the solver and symmetry
1243    #
1244    # Revision 1.2  2004/10/27 00:23:36  jgs
1245    # fixed minor syntax error
1246    #
1247    # Revision 1.1.1.1  2004/10/26 06:53:56  jgs
1248    # initial import of project esys2
1249    #
1250    # Revision 1.1.2.3  2004/10/26 06:43:48  jgs
1251    # committing Lutz's and Paul's changes to brach jgs
1252    #
1253    # Revision 1.1.4.1  2004/10/20 05:32:51  cochrane
1254    # Added incomplete Doxygen comments to files, or merely put the docstrings that already exist into Doxygen form.
1255    #
1256    # Revision 1.1  2004/08/05 03:58:27  gross
1257    # Bug in Assemble_NodeCoordinates fixed
1258    #
1259    #

Legend:
 Removed from v.108 changed lines Added in v.123