# Diff of /trunk/escript/py_src/util.py

revision 102 by jgs, Wed Dec 15 07:08:39 2004 UTC revision 126 by jgs, Fri Jul 22 03:53:08 2005 UTC
# Line 4  Line 4
4
5  """  """
6  @brief Utility functions for escript  @brief Utility functions for escript
7
8    TODO for Data:
9
10      * binary operations @:               (a@b)[:,*]=a[:]@b[:,*] if rank(a)<rank(b)
11                    @=+,-,*,/,**           (a@b)[:]=a[:]@b[:] if rank(a)=rank(b)
12                                           (a@b)[*,:]=a[*,:]@b[:] if rank(a)>rank(b)
13
14      * implementation of outer outer(a,b)[:,*]=a[:]*b[*]
15      * trace: trace(arg,axis0=a0,axis1=a1)(:,&,*)=sum_i trace(:,i,&,i,*) (i are at index a0 and a1)
16  """  """
17
18  import numarray  import numarray
19  import escript  import escript
20
21  #  #
22  #   escript constants (have to be consistent witj utilC.h  #   escript constants (have to be consistent with utilC.h )
23  #  #
24  UNKNOWN=-1  UNKNOWN=-1
25  EPSILON=1.e-15  EPSILON=1.e-15
# Line 49  TIFF=7 Line 59  TIFF=7
59  OPENINVENTOR=8  OPENINVENTOR=8
60  RENDERMAN=9  RENDERMAN=9
61  PNM=10  PNM=10
#
# wrapper for various functions: if the argument has attribute the function name
# as an argument it calls the correspong methods. Otherwise the coresponsing numarray
# function is called.
#
# functions involving the underlying Domain:
#
"""
@brief returns the spatial gradient of the Data object arg
62
63      @param arg: Data object representing the function which gradient to be calculated.  #===========================================================
64      @param where: FunctionSpace in which the gradient will be. If None Function(dom) where dom is the  # a simple tool box to deal with _differentials of functions
65                    domain of the Data object arg.  #===========================================================
66
67    class Symbol:
68       """symbol class"""
69       def __init__(self,name="symbol",shape=(),dim=3,args=[]):
70           """creates an instance of a symbol of shape shape and spatial dimension dim
71              The symbol may depending on a list of arguments args which
72              may be symbols or other objects. name gives the name of the symbol."""
73           self.__args=args
74           self.__name=name
75           self.__shape=shape
76           if hasattr(dim,"getDim"):
77               self.__dim=dim.getDim()
78           else:
79               self.__dim=dim
80           #
81           self.__cache_val=None
82           self.__cache_argval=None
83
84       def getArgument(self,i):
85           """returns the i-th argument"""
86           return self.__args[i]
87
88       def getDim(self):
89           """returns the spatial dimension of the symbol"""
90           return self.__dim
91
92       def getRank(self):
93           """returns the rank of the symbol"""
94           return len(self.getShape())
95
96       def getShape(self):
97           """returns the shape of the symbol"""
98           return self.__shape
99
100       def getEvaluatedArguments(self,argval):
101           """returns the list of evaluated arguments by subsituting symbol u by argval[u]."""
102           if argval==self.__cache_argval:
103               print "%s: cached value used"%self
104               return self.__cache_val
105           else:
106               out=[]
107               for a  in self.__args:
108                  if isinstance(a,Symbol):
109                    out.append(a.eval(argval))
110                  else:
111                    out.append(a)
112               self.__cache_argval=argval
113               self.__cache_val=out
114               return out
115
116       def getDifferentiatedArguments(self,arg):
117           """returns the list of the arguments _differentiated by arg"""
118           out=[]
119           for a in self.__args:
120              if isinstance(a,Symbol):
121                out.append(a.diff(arg))
122              else:
123                out.append(0)
124           return out
125
126       def diff(self,arg):
127           """returns the _differention of self by arg."""
128           if self==arg:
129              out=numarray.zeros(tuple(2*list(self.getShape())),numarray.Float)
130              if self.getRank()==0:
131                 out=1.
132              elif self.getRank()==1:
133                  for i0 in range(self.getShape()[0]):
134                     out[i0,i0]=1.
135              elif self.getRank()==2:
136                  for i0 in range(self.getShape()[0]):
137                    for i1 in range(self.getShape()[1]):
138                         out[i0,i1,i0,i1]=1.
139              elif self.getRank()==3:
140                  for i0 in range(self.getShape()[0]):
141                    for i1 in range(self.getShape()[1]):
142                      for i2 in range(self.getShape()[2]):
143                         out[i0,i1,i2,i0,i1,i2]=1.
144              elif self.getRank()==4:
145                  for i0 in range(self.getShape()[0]):
146                    for i1 in range(self.getShape()[1]):
147                      for i2 in range(self.getShape()[2]):
148                        for i3 in range(self.getShape()[3]):
149                           out[i0,i1,i2,i3,i0,i1,i2,i3]=1.
150              else:
151                 raise ValueError,"differential support rank<5 only."
152              return out
153           else:
154              return self._diff(arg)
155
156       def _diff(self,arg):
157           """return derivate of self with respect to arg (!=self).
158              This method is overwritten by a particular symbol"""
159           return 0
160
161       def eval(self,argval):
162           """subsitutes symbol u in self by argval[u] and returns the result. If
163              self is not a key of argval then self is returned."""
164           if argval.has_key(self):
165             return argval[self]
166           else:
167             return self
168
169       def __str__(self):
170           """returns a string representation of the symbol"""
171           return self.__name
172
174           """adds other to symbol self. if _testForZero(other) self is returned."""
175           if _testForZero(other):
176              return self
177           else:
178              a=_matchShape([self,other])
180
182           """adds other to symbol self. if _testForZero(other) self is returned."""
183           return self+other
184
185       def __neg__(self):
186           """returns -self."""
187           return self*(-1.)
188
189       def __pos__(self):
190           """returns +self."""
191           return self
192
193       def __abs__(self):
194           """returns absolute value"""
195           return Abs_Symbol(self)
196
197       def __sub__(self,other):
198           """subtracts other from symbol self. if _testForZero(other) self is returned."""
199           if _testForZero(other):
200              return self
201           else:
202              return self+(-other)
203
204       def __rsub__(self,other):
205           """subtracts symbol self from other."""
206           return -self+other
207
208       def __div__(self,other):
209           """divides symbol self by other."""
210           if isinstance(other,Symbol):
211              a=_matchShape([self,other])
212              return Div_Symbol(a[0],a[1])
213           else:
214              return self*(1./other)
215
216       def __rdiv__(self,other):
217           """dived other by symbol self. if _testForZero(other) 0 is returned."""
218           if _testForZero(other):
219              return 0
220           else:
221              a=_matchShape([self,other])
222              return Div_Symbol(a[0],a[1])
223
224       def __pow__(self,other):
225           """raises symbol self to the power of other"""
226           a=_matchShape([self,other])
227           return Power_Symbol(a[0],a[1])
228
229       def __rpow__(self,other):
230           """raises other to the symbol self"""
231           a=_matchShape([self,other])
232           return Power_Symbol(a[1],a[0])
233
234       def __mul__(self,other):
235           """multiplies other by symbol self. if _testForZero(other) 0 is returned."""
236           if _testForZero(other):
237              return 0
238           else:
239              a=_matchShape([self,other])
240              return Mult_Symbol(a[0],a[1])
241
242       def __rmul__(self,other):
243           """multiplies other by symbol self. if _testSForZero(other) 0 is returned."""
244           return self*other
245
246       def __getitem__(self,sl):
247              print sl
248
249    def Float_Symbol(Symbol):
250        def __init__(self,name="symbol",shape=(),args=[]):
251            Symbol.__init__(self,dim=0,name="symbol",shape=(),args=[])
252
253    class ScalarSymbol(Symbol):
254       """a scalar symbol"""
255       def __init__(self,dim=3,name="scalar"):
256          """creates a scalar symbol of spatial dimension dim"""
257          if hasattr(dim,"getDim"):
258               d=dim.getDim()
259          else:
260               d=dim
261          Symbol.__init__(self,shape=(),dim=d,name=name)
262
263    class VectorSymbol(Symbol):
264       """a vector symbol"""
265       def __init__(self,dim=3,name="vector"):
266          """creates a vector symbol of spatial dimension dim"""
267          if hasattr(dim,"getDim"):
268               d=dim.getDim()
269          else:
270               d=dim
271          Symbol.__init__(self,shape=(d,),dim=d,name=name)
272
273    class TensorSymbol(Symbol):
274       """a tensor symbol"""
275       def __init__(self,dim=3,name="tensor"):
276          """creates a tensor symbol of spatial dimension dim"""
277          if hasattr(dim,"getDim"):
278               d=dim.getDim()
279          else:
280               d=dim
281          Symbol.__init__(self,shape=(d,d),dim=d,name=name)
282
283    class Tensor3Symbol(Symbol):
284       """a tensor order 3 symbol"""
285       def __init__(self,dim=3,name="tensor3"):
286          """creates a tensor order 3 symbol of spatial dimension dim"""
287          if hasattr(dim,"getDim"):
288               d=dim.getDim()
289          else:
290               d=dim
291          Symbol.__init__(self,shape=(d,d,d),dim=d,name=name)
292
293    class Tensor4Symbol(Symbol):
294       """a tensor order 4 symbol"""
295       def __init__(self,dim=3,name="tensor4"):
296          """creates a tensor order 4 symbol of spatial dimension dim"""
297          if hasattr(dim,"getDim"):
298               d=dim.getDim()
299          else:
300               d=dim
301          Symbol.__init__(self,shape=(d,d,d,d),dim=d,name=name)
302
304       """symbol representing the sum of two arguments"""
305       def __init__(self,arg0,arg1):
306           a=[arg0,arg1]
307           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
308       def __str__(self):
309          return "(%s+%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
310       def eval(self,argval):
311           a=self.getEvaluatedArguments(argval)
312           return a[0]+a[1]
313       def _diff(self,arg):
314           a=self.getDifferentiatedArguments(arg)
315           return a[0]+a[1]
316
317    class Mult_Symbol(Symbol):
318       """symbol representing the product of two arguments"""
319       def __init__(self,arg0,arg1):
320           a=[arg0,arg1]
321           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
322       def __str__(self):
323          return "(%s*%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
324       def eval(self,argval):
325           a=self.getEvaluatedArguments(argval)
326           return a[0]*a[1]
327       def _diff(self,arg):
328           a=self.getDifferentiatedArguments(arg)
329           return self.getArgument(1)*a[0]+self.getArgument(0)*a[1]
330
331    class Div_Symbol(Symbol):
332       """symbol representing the quotient of two arguments"""
333       def __init__(self,arg0,arg1):
334           a=[arg0,arg1]
335           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
336       def __str__(self):
337          return "(%s/%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
338       def eval(self,argval):
339           a=self.getEvaluatedArguments(argval)
340           return a[0]/a[1]
341       def _diff(self,arg):
342           a=self.getDifferentiatedArguments(arg)
343           return (a[0]*self.getArgument(1)-self.getArgument(0)*a[1])/ \
344                              (self.getArgument(1)*self.getArgument(1))
345
346    class Power_Symbol(Symbol):
347       """symbol representing the power of the first argument to the power of the second argument"""
348       def __init__(self,arg0,arg1):
349           a=[arg0,arg1]
350           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
351       def __str__(self):
352          return "(%s**%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
353       def eval(self,argval):
354           a=self.getEvaluatedArguments(argval)
355           return a[0]**a[1]
356       def _diff(self,arg):
357           a=self.getDifferentiatedArguments(arg)
358           return self*(a[1]*log(self.getArgument(0))+self.getArgument(1)/self.getArgument(0)*a[0])
359
360    class Abs_Symbol(Symbol):
361       """symbol representing absolute value of its argument"""
362       def __init__(self,arg):
363           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
364       def __str__(self):
365          return "abs(%s)"%str(self.getArgument(0))
366       def eval(self,argval):
367           return abs(self.getEvaluatedArguments(argval)[0])
368       def _diff(self,arg):
369           return sign(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
370
371    #=========================================================
372    #   some little helpers
373    #=========================================================
374    def _testForZero(arg):
375       """returns True is arg is considered of being zero"""
376       if isinstance(arg,int):
377          return not arg>0
378       elif isinstance(arg,float):
379          return not arg>0.
380       elif isinstance(arg,numarray.NumArray):
381          a=abs(arg)
382          while isinstance(a,numarray.NumArray): a=numarray.sometrue(a)
383          return not a>0
384       else:
385          return False
386
387    def _extractDim(args):
388        dim=None
389        for a in args:
390           if hasattr(a,"getDim"):
391              d=a.getDim()
392              if dim==None:
393                 dim=d
394              else:
395                 if dim!=d: raise ValueError,"inconsistent spatial dimension of arguments"
396        if dim==None:
397           raise ValueError,"cannot recover spatial dimension"
398        return dim
399
400    def _identifyShape(arg):
401       """identifies the shape of arg."""
402       if hasattr(arg,"getShape"):
403           arg_shape=arg.getShape()
404       elif hasattr(arg,"shape"):
405         s=arg.shape
406         if callable(s):
407           arg_shape=s()
408         else:
409           arg_shape=s
410       else:
411           arg_shape=()
412       return arg_shape
413
414    def _extractShape(args):
415        """extracts the common shape of the list of arguments args"""
416        shape=None
417        for a in args:
418           a_shape=_identifyShape(a)
419           if shape==None: shape=a_shape
420           if shape!=a_shape: raise ValueError,"inconsistent shape"
421        if shape==None:
422           raise ValueError,"cannot recover shape"
423        return shape
424
425    def _matchShape(args,shape=None):
426        """returns the list of arguments args as object which have all the specified shape.
427           if shape is not given the shape "largest" shape of args is used."""
428        # identify the list of shapes:
429        arg_shapes=[]
430        for a in args: arg_shapes.append(_identifyShape(a))
431        # get the largest shape (currently the longest shape):
432        if shape==None: shape=max(arg_shapes)
433
434        out=[]
435        for i in range(len(args)):
436           if shape==arg_shapes[i]:
437              out.append(args[i])
438           else:
439              if len(shape)==0: # then len(arg_shapes[i])>0
440                raise ValueError,"cannot adopt shape of %s to %s"%(str(args[i]),str(shape))
441              else:
442                if len(arg_shapes[i])==0:
443                    out.append(outer(args[i],numarray.ones(shape)))
444                else:
445                    raise ValueError,"cannot adopt shape of %s to %s"%(str(args[i]),str(shape))
446        return out
447
448    #=========================================================
449    #   wrappers for various mathematical functions:
450    #=========================================================
451    def diff(arg,dep):
452        """returns the derivative of arg with respect to dep. If arg is not Symbol object
453           0 is returned"""
454        if isinstance(arg,Symbol):
455           return arg.diff(dep)
456        elif hasattr(arg,"shape"):
457              if callable(arg.shape):
458                  return numarray.zeros(arg.shape(),numarray.Float)
459              else:
460                  return numarray.zeros(arg.shape,numarray.Float)
461        else:
462           return 0
463
464    def exp(arg):
465      """      """
466      if where==None:      @brief applies the exponential function to arg
467         return arg.grad()      @param arg (input): argument
468        """
469        if isinstance(arg,Symbol):
470           return Exp_Symbol(arg)
471        elif hasattr(arg,"exp"):
472           return arg.exp()
473      else:      else:
475
476  def integrate(arg):  class Exp_Symbol(Symbol):
477      """     """symbol representing the power of the first argument to the power of the second argument"""
478      @brief return the integral if the function represented by Data object arg over its domain.     def __init__(self,arg):
479           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
480       def __str__(self):
481          return "exp(%s)"%str(self.getArgument(0))
482       def eval(self,argval):
483           return exp(self.getEvaluatedArguments(argval)[0])
484       def _diff(self,arg):
485           return self*self.getDifferentiatedArguments(arg)[0]
486
487      @param arg  def sqrt(arg):
488      """      """
489      return arg.integrate()      @brief applies the squre root function to arg
490        @param arg (input): argument
def interpolate(arg,where):
491      """      """
492      @brief interpolates the function represented by Data object arg into the FunctionSpace where.      if isinstance(arg,Symbol):
493           return Sqrt_Symbol(arg)
494        elif hasattr(arg,"sqrt"):
495           return arg.sqrt()
496        else:
497           return numarray.sqrt(arg)
498
499      @param arg  class Sqrt_Symbol(Symbol):
500      @param where     """symbol representing square root of argument"""
501      """     def __init__(self,arg):
502      return arg.interpolate(where)         Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
503       def __str__(self):
504          return "sqrt(%s)"%str(self.getArgument(0))
505       def eval(self,argval):
506           return sqrt(self.getEvaluatedArguments(argval)[0])
507       def _diff(self,arg):
508           return (-0.5)/self*self.getDifferentiatedArguments(arg)[0]
509
510    def log(arg):
511        """
512        @brief applies the logarithmic function bases exp(1.) to arg
513        @param arg (input): argument
514        """
515        if isinstance(arg,Symbol):
516           return Log_Symbol(arg)
517        elif hasattr(arg,"log"):
518           return arg.log()
519        else:
520           return numarray.log(arg)
521
522  # functions returning Data objects:  class Log_Symbol(Symbol):
523       """symbol representing logarithm of the argument"""
524       def __init__(self,arg):
525           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
526       def __str__(self):
527          return "log(%s)"%str(self.getArgument(0))
528       def eval(self,argval):
529           return log(self.getEvaluatedArguments(argval)[0])
530       def _diff(self,arg):
531           return self.getDifferentiatedArguments(arg)[0]/self.getArgument(0)
532
533    def ln(arg):
534        """
535        @brief applies the natural logarithmic function to arg
536        @param arg (input): argument
537        """
538        if isinstance(arg,Symbol):
539           return Ln_Symbol(arg)
540        elif hasattr(arg,"ln"):
541           return arg.log()
542        else:
543           return numarray.log(arg)
544
545  def transpose(arg,axis=None):  class Ln_Symbol(Symbol):
546      """     """symbol representing natural logarithm of the argument"""
547      @brief returns the transpose of the Data object arg.     def __init__(self,arg):
548           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
549       def __str__(self):
550          return "ln(%s)"%str(self.getArgument(0))
551       def eval(self,argval):
552           return ln(self.getEvaluatedArguments(argval)[0])
553       def _diff(self,arg):
554           return self.getDifferentiatedArguments(arg)[0]/self.getArgument(0)
555
556      @param arg  def sin(arg):
557      """      """
558      if isinstance(arg,escript.Data):      @brief applies the sin function to arg
559         if axis==None: axis=arg.getRank()/2      @param arg (input): argument
560         return arg.transpose(axis)      """
561        if isinstance(arg,Symbol):
562           return Sin_Symbol(arg)
563        elif hasattr(arg,"sin"):
564           return arg.sin()
565      else:      else:
566         if axis==None: axis=arg.rank/2         return numarray.sin(arg)
return numarray.transpose(arg,axis=axis)
567
568  def trace(arg):  class Sin_Symbol(Symbol):
569      """     """symbol representing sin of the argument"""
570      @brief     def __init__(self,arg):
571           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
572       def __str__(self):
573          return "sin(%s)"%str(self.getArgument(0))
574       def eval(self,argval):
575           return sin(self.getEvaluatedArguments(argval)[0])
576       def _diff(self,arg):
577           return cos(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
578
579      @param arg  def cos(arg):
580      """      """
581      if isinstance(arg,escript.Data):      @brief applies the cos function to arg
582         return arg.trace()      @param arg (input): argument
583        """
584        if isinstance(arg,Symbol):
585           return Cos_Symbol(arg)
586        elif hasattr(arg,"cos"):
587           return arg.cos()
588      else:      else:
589         return numarray.trace(arg)         return numarray.cos(arg)
590
591  def exp(arg):  class Cos_Symbol(Symbol):
592       """symbol representing cos of the argument"""
593       def __init__(self,arg):
594           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
595       def __str__(self):
596          return "cos(%s)"%str(self.getArgument(0))
597       def eval(self,argval):
598           return cos(self.getEvaluatedArguments(argval)[0])
599       def _diff(self,arg):
600           return -sin(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
601
602    def tan(arg):
603      """      """
604      @brief      @brief applies the tan function to arg
605        @param arg (input): argument
606        """
607        if isinstance(arg,Symbol):
608           return Tan_Symbol(arg)
609        elif hasattr(arg,"tan"):
610           return arg.tan()
611        else:
612           return numarray.tan(arg)
613
614      @param arg  class Tan_Symbol(Symbol):
615       """symbol representing tan of the argument"""
616       def __init__(self,arg):
617           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
618       def __str__(self):
619          return "tan(%s)"%str(self.getArgument(0))
620       def eval(self,argval):
621           return tan(self.getEvaluatedArguments(argval)[0])
622       def _diff(self,arg):
623           s=cos(self.getArgument(0))
624           return 1./(s*s)*self.getDifferentiatedArguments(arg)[0]
625
626    def sign(arg):
627      """      """
628      if isinstance(arg,escript.Data):      @brief applies the sign function to arg
629         return arg.exp()      @param arg (input): argument
630        """
631        if isinstance(arg,Symbol):
632           return Sign_Symbol(arg)
633        elif hasattr(arg,"sign"):
634           return arg.sign()
635      else:      else:
636         return numarray.exp(arg)         return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))- \
637                  numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
638
639  def sqrt(arg):  class Sign_Symbol(Symbol):
640       """symbol representing the sign of the argument"""
641       def __init__(self,arg):
642           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
643       def __str__(self):
644          return "sign(%s)"%str(self.getArgument(0))
645       def eval(self,argval):
646           return sign(self.getEvaluatedArguments(argval)[0])
647
648    def maxval(arg):
649      """      """
650      @brief      @brief returns the maximum value of argument arg""
651        @param arg (input): argument
652        """
653        if isinstance(arg,Symbol):
654           return Max_Symbol(arg)
655        elif hasattr(arg,"maxval"):
656           return arg.maxval()
657        elif hasattr(arg,"max"):
658           return arg.max()
659        else:
660           return arg
661
662      @param arg  class Max_Symbol(Symbol):
663       """symbol representing the maximum value of the argument"""
664       def __init__(self,arg):
665           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
666       def __str__(self):
667          return "maxval(%s)"%str(self.getArgument(0))
668       def eval(self,argval):
669           return maxval(self.getEvaluatedArguments(argval)[0])
670
671    def minval(arg):
672      """      """
673      if isinstance(arg,escript.Data):      @brief returns the minimum value of argument arg""
674         return arg.sqrt()      @param arg (input): argument
675        """
676        if isinstance(arg,Symbol):
677           return Min_Symbol(arg)
678        elif hasattr(arg,"maxval"):
679           return arg.minval()
680        elif hasattr(arg,"min"):
681           return arg.min()
682      else:      else:
683         return numarray.sqrt(arg)         return arg
684
685  def sin(arg):  class Min_Symbol(Symbol):
686       """symbol representing the minimum value of the argument"""
687       def __init__(self,arg):
688           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
689       def __str__(self):
690          return "minval(%s)"%str(self.getArgument(0))
691       def eval(self,argval):
692           return minval(self.getEvaluatedArguments(argval)[0])
693
694    def wherePositive(arg):
695        """
696        @brief returns the positive values of argument arg""
697        @param arg (input): argument
698        """
699        if _testForZero(arg):
700          return 0
701        elif isinstance(arg,Symbol):
702           return WherePositive_Symbol(arg)
703        elif hasattr(arg,"wherePositive"):
704           return arg.minval()
705        elif hasattr(arg,"wherePositive"):
706           numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))
707        else:
708           if arg>0:
709              return 1.
710           else:
711              return 0.
712
713    class WherePositive_Symbol(Symbol):
714       """symbol representing the wherePositive function"""
715       def __init__(self,arg):
716           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
717       def __str__(self):
718          return "wherePositive(%s)"%str(self.getArgument(0))
719       def eval(self,argval):
720           return wherePositive(self.getEvaluatedArguments(argval)[0])
721
722    def whereNegative(arg):
723        """
724        @brief returns the negative values of argument arg""
725        @param arg (input): argument
726        """
727        if _testForZero(arg):
728          return 0
729        elif isinstance(arg,Symbol):
730           return WhereNegative_Symbol(arg)
731        elif hasattr(arg,"whereNegative"):
732           return arg.whereNegative()
733        elif hasattr(arg,"shape"):
734           numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
735        else:
736           if arg<0:
737              return 1.
738           else:
739              return 0.
740
741    class WhereNegative_Symbol(Symbol):
742       """symbol representing the whereNegative function"""
743       def __init__(self,arg):
744           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
745       def __str__(self):
746          return "whereNegative(%s)"%str(self.getArgument(0))
747       def eval(self,argval):
748           return whereNegative(self.getEvaluatedArguments(argval)[0])
749
750    def outer(arg0,arg1):
751       if _testForZero(arg0) or _testForZero(arg1):
752          return 0
753       else:
754          if isinstance(arg0,Symbol) or isinstance(arg1,Symbol):
755            return Outer_Symbol(arg0,arg1)
756          elif _identifyShape(arg0)==() or _identifyShape(arg1)==():
757            return arg0*arg1
758          elif isinstance(arg0,numarray.NumArray) and isinstance(arg1,numarray.NumArray):
759            return numarray.outer(arg0,arg1)
760          else:
761            if arg0.getRank()==1 and arg1.getRank()==1:
762              out=escript.Data(0,(arg0.getShape()[0],arg1.getShape()[0]),arg1.getFunctionSpace())
763              for i in range(arg0.getShape()[0]):
764                for j in range(arg1.getShape()[0]):
765                    out[i,j]=arg0[i]*arg1[j]
766              return out
767            else:
768              raise ValueError,"outer is not fully implemented yet."
769
770    class Outer_Symbol(Symbol):
771       """symbol representing the outer product of its two arguments"""
772       def __init__(self,arg0,arg1):
773           a=[arg0,arg1]
774           s=tuple(list(_identifyShape(arg0))+list(_identifyShape(arg1)))
775           Symbol.__init__(self,shape=s,dim=_extractDim(a),args=a)
776       def __str__(self):
777          return "outer(%s,%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
778       def eval(self,argval):
779           a=self.getEvaluatedArguments(argval)
780           return outer(a[0],a[1])
781       def _diff(self,arg):
782           a=self.getDifferentiatedArguments(arg)
783           return outer(a[0],self.getArgument(1))+outer(self.getArgument(0),a[1])
784
785    def interpolate(arg,where):
786      """      """
787      @brief      @brief interpolates the function into the FunctionSpace where.
788
789      @param arg      @param arg    interpolant
790        @param where  FunctionSpace to interpolate to
791      """      """
792      if isinstance(arg,escript.Data):      if _testForZero(arg):
793         return arg.sin()        return 0
794        elif isinstance(arg,Symbol):
795           return Interpolated_Symbol(arg,where)
796      else:      else:
797         return numarray.sin(arg)         return escript.Data(arg,where)
798
799  def tan(arg):  def Interpolated_Symbol(Symbol):
800      """     """symbol representing the integral of the argument"""
801      @brief     def __init__(self,arg,where):
802            Symbol.__init__(self,shape=_extractShape(arg),dim=_extractDim([arg]),args=[arg,where])
803       def __str__(self):
804          return "interpolated(%s)"%(str(self.getArgument(0)))
805       def eval(self,argval):
806           a=self.getEvaluatedArguments(argval)
807           return integrate(a[0],where=self.getArgument(1))
808       def _diff(self,arg):
809           a=self.getDifferentiatedArguments(arg)
810           return integrate(a[0],where=self.getArgument(1))
811
813      """      """
814      if isinstance(arg,escript.Data):      @brief returns the spatial gradient of arg at where.
815         return arg.tan()
816        @param arg:   Data object representing the function which gradient to be calculated.
817        @param where: FunctionSpace in which the gradient will be calculated. If not present or
818                      None an appropriate default is used.
819        """
820        if _testForZero(arg):
821          return 0
822        elif isinstance(arg,Symbol):
825           if where==None:
827           else:
829      else:      else:
830         return numarray.tan(arg)         return arg*0.
831
833       """symbol representing the gradient of the argument"""
834       def __init__(self,arg,where=None):
835           d=_extractDim([arg])
836           s=tuple(list(_identifyShape([arg])).append(d))
837           Symbol.__init__(self,shape=s,dim=_extractDim([arg]),args=[arg,where])
838       def __str__(self):
840       def eval(self,argval):
841           a=self.getEvaluatedArguments(argval)
843       def _diff(self,arg):
844           a=self.getDifferentiatedArguments(arg)
846
847    def integrate(arg,where=None):
848      """      """
849      @brief      @brief return the integral if the function represented by Data object arg over its domain.
850
851        @param arg:   Data object representing the function which is integrated.
852        @param where: FunctionSpace in which the integral is calculated. If not present or
853                      None an appropriate default is used.
854        """
855        if _testForZero(arg):
856          return 0
857        elif isinstance(arg,Symbol):
858           return Integral_Symbol(arg,where)
859        else:
860           if not where==None: arg=escript.Data(arg,where)
861           if arg.getRank()==0:
862             return arg.integrate()[0]
863           else:
864             return arg.integrate()
865
866    def Integral_Symbol(Float_Symbol):
867       """symbol representing the integral of the argument"""
868       def __init__(self,arg,where=None):
869           Float_Symbol.__init__(self,shape=_identifyShape([arg]),args=[arg,where])
870       def __str__(self):
871          return "integral(%s)"%(str(self.getArgument(0)))
872       def eval(self,argval):
873           a=self.getEvaluatedArguments(argval)
874           return integrate(a[0],where=self.getArgument(1))
875       def _diff(self,arg):
876           a=self.getDifferentiatedArguments(arg)
877           return integrate(a[0],where=self.getArgument(1))
878
879    #=============================
880    #
881    # wrapper for various functions: if the argument has attribute the function name
882    # as an argument it calls the corresponding methods. Otherwise the corresponding
883    # numarray function is called.
884
885    # functions involving the underlying Domain:
886
887
888    # functions returning Data objects:
889
890    def transpose(arg,axis=None):
891        """
892        @brief returns the transpose of the Data object arg.
893
894      @param arg      @param arg
895      """      """
896        if axis==None:
897           r=0
898           if hasattr(arg,"getRank"): r=arg.getRank()
899           if hasattr(arg,"rank"): r=arg.rank
900           axis=r/2
901        if isinstance(arg,Symbol):
902           return Transpose_Symbol(arg,axis=r)
903      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
904         return arg.cos()         # hack for transpose
905           r=arg.getRank()
906           if r!=2: raise ValueError,"Tranpose only avalaible for rank 2 objects"
907           s=arg.getShape()
908           out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())
909           for i in range(s[0]):
910              for j in range(s[1]):
911                 out[j,i]=arg[i,j]
912           return out
913           # end hack for transpose
914           return arg.transpose(axis)
915      else:      else:
916         return numarray.cos(arg)         return numarray.transpose(arg,axis=axis)
917
918  def maxval(arg):  def trace(arg,axis0=0,axis1=1):
919      """      """
920      @brief      @brief return
921
922      @param arg      @param arg
923      """      """
924      if isinstance(arg,escript.Data):      if isinstance(arg,Symbol):
925         return arg.maxval()         s=list(arg.getShape())
926           s=tuple(s[0:axis0]+s[axis0+1:axis1]+s[axis1+1:])
927           return Trace_Symbol(arg,axis0=axis0,axis1=axis1)
928        elif isinstance(arg,escript.Data):
929           # hack for trace
930           s=arg.getShape()
931           if s[axis0]!=s[axis1]:
932               raise ValueError,"illegal axis in trace"
933           out=escript.Scalar(0.,arg.getFunctionSpace())
934           for i in range(s[0]):
935              for j in range(s[1]):
936                 out+=arg[i,j]
937           return out
938           # end hack for trace
939           return arg.transpose(axis0=axis0,axis1=axis1)
940      else:      else:
941         return arg.max()         return numarray.trace(arg,axis0=axis0,axis1=axis1)
942
943  def minval(arg):  def Trace_Symbol(Symbol):
944        pass
945
946    def length(arg):
947      """      """
948      @brief      @brief
949
950      @param arg      @param arg
951      """      """
952      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
953         return arg.minval()         if arg.isEmpty(): return escript.Data()
954           if arg.getRank()==0:
955              return abs(arg)
956           elif arg.getRank()==1:
957              sum=escript.Scalar(0,arg.getFunctionSpace())
958              for i in range(arg.getShape()[0]):
959                 sum+=arg[i]**2
960              return sqrt(sum)
961           elif arg.getRank()==2:
962              sum=escript.Scalar(0,arg.getFunctionSpace())
963              for i in range(arg.getShape()[0]):
964                 for j in range(arg.getShape()[1]):
965                    sum+=arg[i,j]**2
966              return sqrt(sum)
967           elif arg.getRank()==3:
968              sum=escript.Scalar(0,arg.getFunctionSpace())
969              for i in range(arg.getShape()[0]):
970                 for j in range(arg.getShape()[1]):
971                    for k in range(arg.getShape()[2]):
972                       sum+=arg[i,j,k]**2
973              return sqrt(sum)
974           elif arg.getRank()==4:
975              sum=escript.Scalar(0,arg.getFunctionSpace())
976              for i in range(arg.getShape()[0]):
977                 for j in range(arg.getShape()[1]):
978                    for k in range(arg.getShape()[2]):
979                       for l in range(arg.getShape()[3]):
980                          sum+=arg[i,j,k,l]**2
981              return sqrt(sum)
982           else:
983              raise SystemError,"length is not been fully implemented yet"
984              # return arg.length()
985      else:      else:
986         return arg.max()         return sqrt((arg**2).sum())
987
988  def length(arg):  def deviator(arg):
989      """      """
990      @brief      @brief
991
992      @param arg      @param arg0
993      """      """
994      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
995         return arg.length()          shape=arg.getShape()
996      else:      else:
997         return sqrt((arg**2).sum())          shape=arg.shape
998        if len(shape)!=2:
999              raise ValueError,"Deviator requires rank 2 object"
1000        if shape[0]!=shape[1]:
1001              raise ValueError,"Deviator requires a square matrix"
1002        return arg-1./(shape[0]*1.)*trace(arg)*kronecker(shape[0])
1003
1004  def sign(arg):  def inner(arg0,arg1):
1005      """      """
1006      @brief      @brief
1007
1008      @param arg      @param arg0, arg1
1009      """      """
1010      if isinstance(arg,escript.Data):      sum=escript.Scalar(0,arg0.getFunctionSpace())
1011         return arg.sign()      if arg.getRank()==0:
1012              return arg0*arg1
1013        elif arg.getRank()==1:
1014             sum=escript.Scalar(0,arg.getFunctionSpace())
1015             for i in range(arg.getShape()[0]):
1016                sum+=arg0[i]*arg1[i]
1017        elif arg.getRank()==2:
1018            sum=escript.Scalar(0,arg.getFunctionSpace())
1019            for i in range(arg.getShape()[0]):
1020               for j in range(arg.getShape()[1]):
1021                  sum+=arg0[i,j]*arg1[i,j]
1022        elif arg.getRank()==3:
1023            sum=escript.Scalar(0,arg.getFunctionSpace())
1024            for i in range(arg.getShape()[0]):
1025                for j in range(arg.getShape()[1]):
1026                   for k in range(arg.getShape()[2]):
1027                      sum+=arg0[i,j,k]*arg1[i,j,k]
1028        elif arg.getRank()==4:
1029            sum=escript.Scalar(0,arg.getFunctionSpace())
1030            for i in range(arg.getShape()[0]):
1031               for j in range(arg.getShape()[1]):
1032                  for k in range(arg.getShape()[2]):
1033                     for l in range(arg.getShape()[3]):
1034                        sum+=arg0[i,j,k,l]*arg1[i,j,k,l]
1035      else:      else:
1036         return numarray.greater(arg,numarray.zeros(arg.shape))-numarray.less(arg,numarray.zeros(arg.shape))            raise SystemError,"inner is not been implemented yet"
1037        return sum
1038
1039  # reduction operations:  def matrixmult(arg0,arg1):
1040
1041        if isinstance(arg1,numarray.NumArray) and isinstance(arg0,numarray.NumArray):
1042            numarray.matrixmult(arg0,arg1)
1043        else:
1044          # escript.matmult(arg0,arg1)
1045          if isinstance(arg1,escript.Data) and not isinstance(arg0,escript.Data):
1046            arg0=escript.Data(arg0,arg1.getFunctionSpace())
1047          elif isinstance(arg0,escript.Data) and not isinstance(arg1,escript.Data):
1048            arg1=escript.Data(arg1,arg0.getFunctionSpace())
1049          if arg0.getRank()==2 and arg1.getRank()==1:
1050              out=escript.Data(0,(arg0.getShape()[0],),arg0.getFunctionSpace())
1051              for i in range(arg0.getShape()[0]):
1052                 for j in range(arg0.getShape()[1]):
1053                   # uses Data object slicing, plus Data * and += operators
1054                   out[i]+=arg0[i,j]*arg1[j]
1055              return out
1056          else:
1057              raise SystemError,"matrixmult is not fully implemented yet!"
1058
1059    #=========================================================
1060    # reduction operations:
1061    #=========================================================
1062  def sum(arg):  def sum(arg):
1063      """      """
1064      @brief      @brief
# Line 229  def sup(arg): Line 1075  def sup(arg):
1075      """      """
1076      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
1077         return arg.sup()         return arg.sup()
1078        elif isinstance(arg,float) or isinstance(arg,int):
1079           return arg
1080      else:      else:
1081         return arg.max()         return arg.max()
1082
# Line 240  def inf(arg): Line 1088  def inf(arg):
1088      """      """
1089      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
1090         return arg.inf()         return arg.inf()
1091        elif isinstance(arg,float) or isinstance(arg,int):
1092           return arg
1093      else:      else:
1094         return arg.min()         return arg.min()
1095
# Line 249  def L2(arg): Line 1099  def L2(arg):
1099
1100      @param arg      @param arg
1101      """      """
1102      return arg.L2()      if isinstance(arg,escript.Data):
1103           return arg.L2()
1104        elif isinstance(arg,float) or isinstance(arg,int):
1105           return abs(arg)
1106        else:
1107           return numarry.sqrt(dot(arg,arg))
1108
1109  def Lsup(arg):  def Lsup(arg):
1110      """      """
# Line 259  def Lsup(arg): Line 1114  def Lsup(arg):
1114      """      """
1115      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
1116         return arg.Lsup()         return arg.Lsup()
1117        elif isinstance(arg,float) or isinstance(arg,int):
1118           return abs(arg)
1119      else:      else:
1120         return arg.max(numarray.abs(arg))         return max(numarray.abs(arg))
1121
1122  def dot(arg1,arg2):  def dot(arg0,arg1):
1123      """      """
1124      @brief      @brief
1125
1126      @param arg      @param arg
1127      """      """
1128      if isinstance(arg1,escript.Data):      if isinstance(arg0,escript.Data):
1129         return arg1.dot(arg2)         return arg0.dot(arg1)
1130      elif isinstance(arg1,escript.Data):      elif isinstance(arg1,escript.Data):
1131         return arg2.dot(arg1)         return arg1.dot(arg0)
1132      else:      else:
1133         return numarray.dot(arg1,arg2)         return numarray.dot(arg0,arg1)
1134
1135    def kronecker(d):
1136       if hasattr(d,"getDim"):
1137          return numarray.identity(d.getDim())
1138       else:
1139          return numarray.identity(d)
1140
1141    def unit(i,d):
1142       """
1143       @brief return a unit vector of dimension d with nonzero index i
1144       @param d dimension
1145       @param i index
1146       """
1147       e = numarray.zeros((d,),numarray.Float)
1148       e[i] = 1.0
1149       return e
1150
1151    # ============================================
1152    #   testing
1153    # ============================================
1154
1155    if __name__=="__main__":
1156      u=ScalarSymbol(dim=2,name="u")
1157      v=ScalarSymbol(dim=2,name="v")
1158      v=VectorSymbol(2,"v")
1159      u=VectorSymbol(2,"u")
1160
1161      print u+5,(u+5).diff(u)
1162      print 5+u,(5+u).diff(u)
1163      print u+v,(u+v).diff(u)
1164      print v+u,(v+u).diff(u)
1165
1166      print u*5,(u*5).diff(u)
1167      print 5*u,(5*u).diff(u)
1168      print u*v,(u*v).diff(u)
1169      print v*u,(v*u).diff(u)
1170
1171      print u-5,(u-5).diff(u)
1172      print 5-u,(5-u).diff(u)
1173      print u-v,(u-v).diff(u)
1174      print v-u,(v-u).diff(u)
1175
1176      print u/5,(u/5).diff(u)
1177      print 5/u,(5/u).diff(u)
1178      print u/v,(u/v).diff(u)
1179      print v/u,(v/u).diff(u)
1180
1181      print u**5,(u**5).diff(u)
1182      print 5**u,(5**u).diff(u)
1183      print u**v,(u**v).diff(u)
1184      print v**u,(v**u).diff(u)
1185
1186      print exp(u),exp(u).diff(u)
1187      print sqrt(u),sqrt(u).diff(u)
1188      print log(u),log(u).diff(u)
1189      print sin(u),sin(u).diff(u)
1190      print cos(u),cos(u).diff(u)
1191      print tan(u),tan(u).diff(u)
1192      print sign(u),sign(u).diff(u)
1193      print abs(u),abs(u).diff(u)
1194      print wherePositive(u),wherePositive(u).diff(u)
1195      print whereNegative(u),whereNegative(u).diff(u)
1196      print maxval(u),maxval(u).diff(u)
1197      print minval(u),minval(u).diff(u)
1198
1200      print diff(5*g,g)
1201      4*(g+transpose(g))/2+6*trace(g)*kronecker(3)
1202
1203    #
1204    # \$Log\$
1205    # Revision 1.13  2005/07/22 03:53:01  jgs
1206    # Merge of development branch back to main trunk on 2005-07-22
1207    #
1208    # Revision 1.12  2005/07/20 06:14:58  jgs
1209    # added ln(data) style wrapper for data.ln() - also added corresponding
1210    # implementation of Ln_Symbol class (not sure if this is right though)
1211    #
1212    # Revision 1.11  2005/07/08 04:07:35  jgs
1213    # Merge of development branch back to main trunk on 2005-07-08
1214    #
1215    # Revision 1.10  2005/06/09 05:37:59  jgs
1216    # Merge of development branch back to main trunk on 2005-06-09
1217    #
1218    # Revision 1.2.2.19  2005/07/21 04:01:28  jgs
1219    # minor comment fixes
1220    #
1221    # Revision 1.2.2.18  2005/07/21 01:02:43  jgs
1222    # commit ln() updates to development branch version
1223    #
1224    # Revision 1.12  2005/07/20 06:14:58  jgs
1225    # added ln(data) style wrapper for data.ln() - also added corresponding
1226    # implementation of Ln_Symbol class (not sure if this is right though)
1227    #
1228    # Revision 1.11  2005/07/08 04:07:35  jgs
1229    # Merge of development branch back to main trunk on 2005-07-08
1230    #
1231    # Revision 1.10  2005/06/09 05:37:59  jgs
1232    # Merge of development branch back to main trunk on 2005-06-09
1233    #
1234    # Revision 1.2.2.17  2005/07/07 07:28:58  gross
1235    # some stuff added to util.py to improve functionality
1236    #
1237    # Revision 1.2.2.16  2005/06/30 01:53:55  gross
1238    # a bug in coloring fixed
1239    #
1240    # Revision 1.2.2.15  2005/06/29 02:36:43  gross
1241    # Symbols have been introduced and some function clarified. needs much more work
1242    #
1243    # Revision 1.2.2.14  2005/05/20 04:05:23  gross
1244    # some work on a darcy flow started
1245    #
1246    # Revision 1.2.2.13  2005/03/16 05:17:58  matt
1247    # Implemented unit(idx, dim) to create cartesian unit basis vectors to
1248    # complement kronecker(dim) function.
1249    #
1250    # Revision 1.2.2.12  2005/03/10 08:14:37  matt
1251    # Added non-member Linf utility function to complement Data::Linf().
1252    #
1253    # Revision 1.2.2.11  2005/02/17 05:53:25  gross
1254    # some bug in saveDX fixed: in fact the bug was in
1255    # DataC/getDataPointShape
1256    #
1257    # Revision 1.2.2.10  2005/01/11 04:59:36  gross
1258    # automatic interpolation in integrate switched off
1259    #
1260    # Revision 1.2.2.9  2005/01/11 03:38:13  gross
1261    # Bug in Data.integrate() fixed for the case of rank 0. The problem is not totallly resolved as the method should return a scalar rather than a numarray object in the case of rank 0. This problem is fixed by the util.integrate wrapper.
1262    #
1263    # Revision 1.2.2.8  2005/01/05 04:21:41  gross
1264    # FunctionSpace checking/matchig in slicing added
1265    #
1266    # Revision 1.2.2.7  2004/12/29 05:29:59  gross
1267    # AdvectivePDE successfully tested for Peclet number 1000000. there is still a problem with setValue and Data()
1268    #
1269    # Revision 1.2.2.6  2004/12/24 06:05:41  gross
1271    #
1272    # Revision 1.2.2.5  2004/12/17 00:06:53  gross
1273    # mk sets ESYS_ROOT is undefined
1274    #
1275    # Revision 1.2.2.4  2004/12/07 03:19:51  gross
1276    # options for GMRES and PRES20 added
1277    #
1278    # Revision 1.2.2.3  2004/12/06 04:55:18  gross
1279    # function wraper extended
1280    #
1281    # Revision 1.2.2.2  2004/11/22 05:44:07  gross
1282    # a few more unitary functions have been added but not implemented in Data yet
1283    #
1284    # Revision 1.2.2.1  2004/11/12 06:58:15  gross
1285    # a lot of changes to get the linearPDE class running: most important change is that there is no matrix format exposed to the user anymore. the format is chosen by the Domain according to the solver and symmetry
1286    #
1287    # Revision 1.2  2004/10/27 00:23:36  jgs
1288    # fixed minor syntax error
1289    #
1290    # Revision 1.1.1.1  2004/10/26 06:53:56  jgs
1291    # initial import of project esys2
1292    #
1293    # Revision 1.1.2.3  2004/10/26 06:43:48  jgs
1294    # committing Lutz's and Paul's changes to brach jgs
1295    #
1296    # Revision 1.1.4.1  2004/10/20 05:32:51  cochrane
1297    # Added incomplete Doxygen comments to files, or merely put the docstrings that already exist into Doxygen form.
1298    #
1299    # Revision 1.1  2004/08/05 03:58:27  gross
1300    # Bug in Assemble_NodeCoordinates fixed
1301    #
1302    #

Legend:
 Removed from v.102 changed lines Added in v.126