Diff of /trunk/escript/py_src/util.py

revision 100 by jgs, Wed Dec 15 03:48:48 2004 UTC revision 123 by jgs, Fri Jul 8 04:08:13 2005 UTC
# Line 4  Line 4
4
5  """  """
6  @brief Utility functions for escript  @brief Utility functions for escript
7
8    TODO for Data:
9
10      * binary operations @:               (a@b)[:,*]=a[:]@b[:,*] if rank(a)<rank(b)
11                    @=+,-,*,/,**           (a@b)[:]=a[:]@b[:] if rank(a)=rank(b)
12                                           (a@b)[*,:]=a[*,:]@b[:] if rank(a)>rank(b)
13
14      * implementation of outer outer(a,b)[:,*]=a[:]*b[*]
15      * trace: trace(arg,axis0=a0,axis1=a1)(:,&,*)=sum_i trace(:,i,&,i,*) (i are at index a0 and a1)
16
17  """  """
18
19  import numarray  import numarray
20    import escript
21  #  #
22  #   escript constants:  #   escript constants (have to be consistent witj utilC.h
23  #  #
FALSE=0
TRUE=1
24  UNKNOWN=-1  UNKNOWN=-1
25  EPSILON=1.e-15  EPSILON=1.e-15
26  Pi=3.1415926535897931  Pi=numarray.pi
# matrix types
CSC=0
CSR=1
LUMPED=10
27  # some solver options:  # some solver options:
28  NO_REORDERING=0  NO_REORDERING=0
29  MINIMUM_FILL_IN=1  MINIMUM_FILL_IN=1
30  NESTED_DISSECTION=2  NESTED_DISSECTION=2
31    # solver methods
32  DEFAULT_METHOD=0  DEFAULT_METHOD=0
33  PCG=1  DIRECT=1
34  CR=2  CHOLEVSKY=2
35  CGS=3  PCG=3
36  BICGSTAB=4  CR=4
37  SSOR=5  CGS=5
38  ILU0=6  BICGSTAB=6
39  ILUT=7  SSOR=7
40  JACOBI=8  ILU0=8
41    ILUT=9
42    JACOBI=10
43    GMRES=11
44    PRES20=12
45
46    METHOD_KEY="method"
47    SYMMETRY_KEY="symmetric"
48    TOLERANCE_KEY="tolerance"
49
50  # supported file formats:  # supported file formats:
51  VRML=1  VRML=1
52  PNG=2  PNG=2
# Line 44  TIFF=7 Line 59  TIFF=7
59  OPENINVENTOR=8  OPENINVENTOR=8
60  RENDERMAN=9  RENDERMAN=9
61  PNM=10  PNM=10
62  #  #===========================================================
63  # wrapper for various functions: if the argument has attribute the function name  #  a simple tool box to deal with _differentials of functions
64  # as an argument it calls the correspong methods. Otherwise the coresponsing numarray  #===========================================================
65  # function is called.
66  #  class Symbol:
67  def L2(arg):     """symbol class"""
68      """     def __init__(self,name="symbol",shape=(),dim=3,args=[]):
69      @brief         """creates an instance of a symbol of shape shape and spatial dimension dim
70              The symbol may depending on a list of arguments args which
71              may be symbols or other objects. name gives the name of the symbol."""
72           self.__args=args
73           self.__name=name
74           self.__shape=shape
75           if hasattr(dim,"getDim"):
76               self.__dim=dim.getDim()
77           else:
78               self.__dim=dim
79           #
80           self.__cache_val=None
81           self.__cache_argval=None
82
83       def getArgument(self,i):
84           """returns the i-th argument"""
85           return self.__args[i]
86
87       def getDim(self):
88           """returns the spatial dimension of the symbol"""
89           return self.__dim
90
91       def getRank(self):
92           """returns the rank of the symbol"""
93           return len(self.getShape())
94
95       def getShape(self):
96           """returns the shape of the symbol"""
97           return self.__shape
98
99       def getEvaluatedArguments(self,argval):
100           """returns the list of evaluated arguments by subsituting symbol u by argval[u]."""
101           if argval==self.__cache_argval:
102               print "%s: cached value used"%self
103               return self.__cache_val
104           else:
105               out=[]
106               for a  in self.__args:
107                  if isinstance(a,Symbol):
108                    out.append(a.eval(argval))
109                  else:
110                    out.append(a)
111               self.__cache_argval=argval
112               self.__cache_val=out
113               return out
114
115       def getDifferentiatedArguments(self,arg):
116           """returns the list of the arguments _differentiated by arg"""
117           out=[]
118           for a in self.__args:
119              if isinstance(a,Symbol):
120                out.append(a.diff(arg))
121              else:
122                out.append(0)
123           return out
124
125       def diff(self,arg):
126           """returns the _differention of self by arg."""
127           if self==arg:
128              out=numarray.zeros(tuple(2*list(self.getShape())),numarray.Float)
129              if self.getRank()==0:
130                 out=1.
131              elif self.getRank()==1:
132                  for i0 in range(self.getShape()[0]):
133                     out[i0,i0]=1.
134              elif self.getRank()==2:
135                  for i0 in range(self.getShape()[0]):
136                    for i1 in range(self.getShape()[1]):
137                         out[i0,i1,i0,i1]=1.
138              elif self.getRank()==3:
139                  for i0 in range(self.getShape()[0]):
140                    for i1 in range(self.getShape()[1]):
141                      for i2 in range(self.getShape()[2]):
142                         out[i0,i1,i2,i0,i1,i2]=1.
143              elif self.getRank()==4:
144                  for i0 in range(self.getShape()[0]):
145                    for i1 in range(self.getShape()[1]):
146                      for i2 in range(self.getShape()[2]):
147                        for i3 in range(self.getShape()[3]):
148                           out[i0,i1,i2,i3,i0,i1,i2,i3]=1.
149              else:
150                 raise ValueError,"differential support rank<5 only."
151              return out
152           else:
153              return self._diff(arg)
154
155       def _diff(self,arg):
156           """return derivate of self with respect to arg (!=self).
157              This method is overwritten by a particular symbol"""
158           return 0
159
160       def eval(self,argval):
161           """subsitutes symbol u in self by argval[u] and returns the result. If
162              self is not a key of argval then self is returned."""
163           if argval.has_key(self):
164             return argval[self]
165           else:
166             return self
167
168
169       def __str__(self):
170           """returns a string representation of the symbol"""
171           return self.__name
172
174           """adds other to symbol self. if _testForZero(other) self is returned."""
175           if _testForZero(other):
176              return self
177           else:
178              a=_matchShape([self,other])
180
181
183           """adds other to symbol self. if _testForZero(other) self is returned."""
184           return self+other
185
186       def __neg__(self):
187           """returns -self."""
188           return self*(-1.)
189
190       def __pos__(self):
191           """returns +self."""
192           return self
193
194       def __abs__(self):
195           """returns absolute value"""
196           return Abs_Symbol(self)
197
198       def __sub__(self,other):
199           """subtracts other from symbol self. if _testForZero(other) self is returned."""
200           if _testForZero(other):
201              return self
202           else:
203              return self+(-other)
204
205       def __rsub__(self,other):
206           """subtracts symbol self from other."""
207           return -self+other
208
209       def __div__(self,other):
210           """divides symbol self by other."""
211           if isinstance(other,Symbol):
212              a=_matchShape([self,other])
213              return Div_Symbol(a[0],a[1])
214           else:
215              return self*(1./other)
216
217       def __rdiv__(self,other):
218           """dived other by symbol self. if _testForZero(other) 0 is returned."""
219           if _testForZero(other):
220              return 0
221           else:
222              a=_matchShape([self,other])
223              return Div_Symbol(a[0],a[1])
224
225       def __pow__(self,other):
226           """raises symbol self to the power of other"""
227           a=_matchShape([self,other])
228           return Power_Symbol(a[0],a[1])
229
230       def __rpow__(self,other):
231           """raises other to the symbol self"""
232           a=_matchShape([self,other])
233           return Power_Symbol(a[1],a[0])
234
235       def __mul__(self,other):
236           """multiplies other by symbol self. if _testForZero(other) 0 is returned."""
237           if _testForZero(other):
238              return 0
239           else:
240              a=_matchShape([self,other])
241              return Mult_Symbol(a[0],a[1])
242
243       def __rmul__(self,other):
244           """multiplies other by symbol self. if _testSForZero(other) 0 is returned."""
245           return self*other
246
247       def __getitem__(self,sl):
248              print sl
249
250    def Float_Symbol(Symbol):
251        def __init__(self,name="symbol",shape=(),args=[]):
252            Symbol.__init__(self,dim=0,name="symbol",shape=(),args=[])
253
254    class ScalarSymbol(Symbol):
255       """a scalar symbol"""
256       def __init__(self,dim=3,name="scalar"):
257          """creates a scalar symbol of spatial dimension dim"""
258          if hasattr(dim,"getDim"):
259               d=dim.getDim()
260          else:
261               d=dim
262          Symbol.__init__(self,shape=(),dim=d,name=name)
263
264    class VectorSymbol(Symbol):
265       """a vector symbol"""
266       def __init__(self,dim=3,name="vector"):
267          """creates a vector symbol of spatial dimension dim"""
268          if hasattr(dim,"getDim"):
269               d=dim.getDim()
270          else:
271               d=dim
272          Symbol.__init__(self,shape=(d,),dim=d,name=name)
273
274    class TensorSymbol(Symbol):
275       """a tensor symbol"""
276       def __init__(self,dim=3,name="tensor"):
277          """creates a tensor symbol of spatial dimension dim"""
278          if hasattr(dim,"getDim"):
279               d=dim.getDim()
280          else:
281               d=dim
282          Symbol.__init__(self,shape=(d,d),dim=d,name=name)
283
284    class Tensor3Symbol(Symbol):
285       """a tensor order 3 symbol"""
286       def __init__(self,dim=3,name="tensor3"):
287          """creates a tensor order 3 symbol of spatial dimension dim"""
288          if hasattr(dim,"getDim"):
289               d=dim.getDim()
290          else:
291               d=dim
292          Symbol.__init__(self,shape=(d,d,d),dim=d,name=name)
293
294    class Tensor4Symbol(Symbol):
295       """a tensor order 4 symbol"""
296       def __init__(self,dim=3,name="tensor4"):
297          """creates a tensor order 4 symbol of spatial dimension dim"""
298          if hasattr(dim,"getDim"):
299               d=dim.getDim()
300          else:
301               d=dim
302          Symbol.__init__(self,shape=(d,d,d,d),dim=d,name=name)
303
304
306       """symbol representing the sum of two arguments"""
307       def __init__(self,arg0,arg1):
308           a=[arg0,arg1]
309           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
310       def __str__(self):
311          return "(%s+%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
312       def eval(self,argval):
313           a=self.getEvaluatedArguments(argval)
314           return a[0]+a[1]
315       def _diff(self,arg):
316           a=self.getDifferentiatedArguments(arg)
317           return a[0]+a[1]
318
319    class Mult_Symbol(Symbol):
320       """symbol representing the product of two arguments"""
321       def __init__(self,arg0,arg1):
322           a=[arg0,arg1]
323           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
324       def __str__(self):
325          return "(%s*%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
326       def eval(self,argval):
327           a=self.getEvaluatedArguments(argval)
328           return a[0]*a[1]
329       def _diff(self,arg):
330           a=self.getDifferentiatedArguments(arg)
331           return self.getArgument(1)*a[0]+self.getArgument(0)*a[1]
332
333    class Div_Symbol(Symbol):
334       """symbol representing the quotient of two arguments"""
335       def __init__(self,arg0,arg1):
336           a=[arg0,arg1]
337           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
338       def __str__(self):
339          return "(%s/%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
340       def eval(self,argval):
341           a=self.getEvaluatedArguments(argval)
342           return a[0]/a[1]
343       def _diff(self,arg):
344           a=self.getDifferentiatedArguments(arg)
345           return (a[0]*self.getArgument(1)-self.getArgument(0)*a[1])/ \
346                              (self.getArgument(1)*self.getArgument(1))
347
348    class Power_Symbol(Symbol):
349       """symbol representing the power of the first argument to the power of the second argument"""
350       def __init__(self,arg0,arg1):
351           a=[arg0,arg1]
352           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
353       def __str__(self):
354          return "(%s**%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
355       def eval(self,argval):
356           a=self.getEvaluatedArguments(argval)
357           return a[0]**a[1]
358       def _diff(self,arg):
359           a=self.getDifferentiatedArguments(arg)
360           return self*(a[1]*log(self.getArgument(0))+self.getArgument(1)/self.getArgument(0)*a[0])
361
362    class Abs_Symbol(Symbol):
363       """symbol representing absolute value of its argument"""
364       def __init__(self,arg):
365           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
366       def __str__(self):
367          return "abs(%s)"%str(self.getArgument(0))
368       def eval(self,argval):
369           return abs(self.getEvaluatedArguments(argval)[0])
370       def _diff(self,arg):
371           return sign(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
372
373    #=========================================================
374    #   some little helpers
375    #=========================================================
376    def _testForZero(arg):
377       """returns True is arg is considered of being zero"""
378       if isinstance(arg,int):
379          return not arg>0
380       elif isinstance(arg,float):
381          return not arg>0.
382       elif isinstance(arg,numarray.NumArray):
383          a=abs(arg)
384          while isinstance(a,numarray.NumArray): a=numarray.sometrue(a)
385          return not a>0
386       else:
387          return False
388
389    def _extractDim(args):
390        dim=None
391        for a in args:
392           if hasattr(a,"getDim"):
393              d=a.getDim()
394              if dim==None:
395                 dim=d
396              else:
397                 if dim!=d: raise ValueError,"inconsistent spatial dimension of arguments"
398        if dim==None:
399           raise ValueError,"cannot recover spatial dimension"
400        return dim
401
402    def _identifyShape(arg):
403       """identifies the shape of arg."""
404       if hasattr(arg,"getShape"):
405           arg_shape=arg.getShape()
406       elif hasattr(arg,"shape"):
407         s=arg.shape
408         if callable(s):
409           arg_shape=s()
410         else:
411           arg_shape=s
412       else:
413           arg_shape=()
414       return arg_shape
415
416    def _extractShape(args):
417        """extracts the common shape of the list of arguments args"""
418        shape=None
419        for a in args:
420           a_shape=_identifyShape(a)
421           if shape==None: shape=a_shape
422           if shape!=a_shape: raise ValueError,"inconsistent shape"
423        if shape==None:
424           raise ValueError,"cannot recover shape"
425        return shape
426
427    def _matchShape(args,shape=None):
428        """returns the list of arguments args as object which have all the specified shape.
429           if shape is not given the shape "largest" shape of args is used."""
430        # identify the list of shapes:
431        arg_shapes=[]
432        for a in args: arg_shapes.append(_identifyShape(a))
433        # get the largest shape (currently the longest shape):
434        if shape==None: shape=max(arg_shapes)
435
436        out=[]
437        for i in range(len(args)):
438           if shape==arg_shapes[i]:
439              out.append(args[i])
440           else:
441              if len(shape)==0: # then len(arg_shapes[i])>0
442                raise ValueError,"cannot adopt shape of %s to %s"%(str(args[i]),str(shape))
443              else:
444                if len(arg_shapes[i])==0:
445                    out.append(outer(args[i],numarray.ones(shape)))
446                else:
447                    raise ValueError,"cannot adopt shape of %s to %s"%(str(args[i]),str(shape))
448        return out
449    #=========================================================
450    #   wrapper for various mathematical functions:
451    #=========================================================
452    def diff(arg,dep):
453        """returns the derivative of arg with respect to dep. If arg is not Symbol object
454           0 is returned"""
455        if isinstance(arg,Symbol):
456           return arg.diff(dep)
457        elif hasattr(arg,"shape"):
458              if callable(arg.shape):
459                  return numarray.zeros(arg.shape(),numarray.Float)
460              else:
461                  return numarray.zeros(arg.shape,numarray.Float)
462        else:
463           return 0
464
465      @param arg  def exp(arg):
466        """
467        @brief applies the exponential function to arg
468        @param arg (input): argument
469      """      """
470      return arg.L2()      if isinstance(arg,Symbol):
471           return Exp_Symbol(arg)
472        elif hasattr(arg,"exp"):
473           return arg.exp()
474        else:
475           return numarray.exp(arg)
476
478       """symbol representing the power of the first argument to the power of the second argument"""
479       def __init__(self,arg):
480           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
481       def __str__(self):
482          return "exp(%s)"%str(self.getArgument(0))
483       def eval(self,argval):
484           return exp(self.getEvaluatedArguments(argval)[0])
485       def _diff(self,arg):
486           return self*self.getDifferentiatedArguments(arg)[0]
487
488    def sqrt(arg):
489      """      """
490      @brief      @brief applies the squre root function to arg
491        @param arg (input): argument
492        """
493        if isinstance(arg,Symbol):
494           return Sqrt_Symbol(arg)
495        elif hasattr(arg,"sqrt"):
496           return arg.sqrt()
497        else:
498           return numarray.sqrt(arg)
499
500      @param arg  class Sqrt_Symbol(Symbol):
501      @param where     """symbol representing square root of argument"""
502       def __init__(self,arg):
503           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
504       def __str__(self):
505          return "sqrt(%s)"%str(self.getArgument(0))
506       def eval(self,argval):
507           return sqrt(self.getEvaluatedArguments(argval)[0])
508       def _diff(self,arg):
509           return (-0.5)/self*self.getDifferentiatedArguments(arg)[0]
510
511    def log(arg):
512        """
513        @brief applies the logarithmic function bases exp(1.) to arg
514        @param arg (input): argument
515        """
516        if isinstance(arg,Symbol):
517           return Log_Symbol(arg)
518        elif hasattr(arg,"log"):
519           return arg.log()
520        else:
521           return numarray.log(arg)
522
523    class Log_Symbol(Symbol):
524       """symbol representing logarithm of the argument"""
525       def __init__(self,arg):
526           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
527       def __str__(self):
528          return "log(%s)"%str(self.getArgument(0))
529       def eval(self,argval):
530           return log(self.getEvaluatedArguments(argval)[0])
531       def _diff(self,arg):
532           return self.getDifferentiatedArguments(arg)[0]/self.getArgument(0)
533
534    def sin(arg):
535      """      """
536      if where==None:      @brief applies the sinus function to arg
537         return arg.grad()      @param arg (input): argument
538        """
539        if isinstance(arg,Symbol):
540           return Sin_Symbol(arg)
541        elif hasattr(arg,"sin"):
542           return arg.sin()
543      else:      else:
545
546    class Sin_Symbol(Symbol):
547       """symbol representing logarithm of the argument"""
548       def __init__(self,arg):
549           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
550       def __str__(self):
551          return "sin(%s)"%str(self.getArgument(0))
552       def eval(self,argval):
553           return sin(self.getEvaluatedArguments(argval)[0])
554       def _diff(self,arg):
555           return cos(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
556
557  def integrate(arg):  def cos(arg):
558      """      """
559      @brief      @brief applies the sinus function to arg
560        @param arg (input): argument
561        """
562        if isinstance(arg,Symbol):
563           return Cos_Symbol(arg)
564        elif hasattr(arg,"cos"):
565           return arg.cos()
566        else:
567           return numarray.cos(arg)
568
569      @param arg  class Cos_Symbol(Symbol):
570       """symbol representing logarithm of the argument"""
571       def __init__(self,arg):
572           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
573       def __str__(self):
574          return "cos(%s)"%str(self.getArgument(0))
575       def eval(self,argval):
576           return cos(self.getEvaluatedArguments(argval)[0])
577       def _diff(self,arg):
578           return -sin(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
579
580    def tan(arg):
581        """
582        @brief applies the sinus function to arg
583        @param arg (input): argument
584        """
585        if isinstance(arg,Symbol):
586           return Tan_Symbol(arg)
587        elif hasattr(arg,"tan"):
588           return arg.tan()
589        else:
590           return numarray.tan(arg)
591
592    class Tan_Symbol(Symbol):
593       """symbol representing logarithm of the argument"""
594       def __init__(self,arg):
595           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
596       def __str__(self):
597          return "tan(%s)"%str(self.getArgument(0))
598       def eval(self,argval):
599           return tan(self.getEvaluatedArguments(argval)[0])
600       def _diff(self,arg):
601           s=cos(self.getArgument(0))
602           return 1./(s*s)*self.getDifferentiatedArguments(arg)[0]
603
604    def sign(arg):
605      """      """
606      return arg.integrate()      @brief applies the sign function to arg
607        @param arg (input): argument
608        """
609        if isinstance(arg,Symbol):
610           return Sign_Symbol(arg)
611        elif hasattr(arg,"sign"):
612           return arg.sign()
613        else:
614           return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))- \
615                  numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
616
617  def interpolate(arg,where):  class Sign_Symbol(Symbol):
618       """symbol representing the sign of the argument"""
619       def __init__(self,arg):
620           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
621       def __str__(self):
622          return "sign(%s)"%str(self.getArgument(0))
623       def eval(self,argval):
624           return sign(self.getEvaluatedArguments(argval)[0])
625
626    def maxval(arg):
627      """      """
628      @brief      @brief returns the maximum value of argument arg""
629        @param arg (input): argument
630        """
631        if isinstance(arg,Symbol):
632           return Max_Symbol(arg)
633        elif hasattr(arg,"maxval"):
634           return arg.maxval()
635        elif hasattr(arg,"max"):
636           return arg.max()
637        else:
638           return arg
639
640      @param arg  class Max_Symbol(Symbol):
641      @param where     """symbol representing the sign of the argument"""
642       def __init__(self,arg):
643           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
644       def __str__(self):
645          return "maxval(%s)"%str(self.getArgument(0))
646       def eval(self,argval):
647           return maxval(self.getEvaluatedArguments(argval)[0])
648
649    def minval(arg):
650      """      """
651      return arg.interpolate(where)      @brief returns the maximum value of argument arg""
652        @param arg (input): argument
653        """
654        if isinstance(arg,Symbol):
655           return Min_Symbol(arg)
656        elif hasattr(arg,"maxval"):
657           return arg.minval()
658        elif hasattr(arg,"min"):
659           return arg.min()
660        else:
661           return arg
662
663  def transpose(arg):  class Min_Symbol(Symbol):
664       """symbol representing the sign of the argument"""
665       def __init__(self,arg):
666           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
667       def __str__(self):
668          return "minval(%s)"%str(self.getArgument(0))
669       def eval(self,argval):
670           return minval(self.getEvaluatedArguments(argval)[0])
671
672    def wherePositive(arg):
673        """
674        @brief returns the maximum value of argument arg""
675        @param arg (input): argument
676        """
677        if _testForZero(arg):
678          return 0
679        elif isinstance(arg,Symbol):
680           return WherePositive_Symbol(arg)
681        elif hasattr(arg,"wherePositive"):
682           return arg.minval()
683        elif hasattr(arg,"wherePositive"):
684           numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))
685        else:
686           if arg>0:
687              return 1.
688           else:
689              return 0.
690
691    class WherePositive_Symbol(Symbol):
692       """symbol representing the wherePositive function"""
693       def __init__(self,arg):
694           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
695       def __str__(self):
696          return "wherePositive(%s)"%str(self.getArgument(0))
697       def eval(self,argval):
698           return wherePositive(self.getEvaluatedArguments(argval)[0])
699
700    def whereNegative(arg):
701        """
702        @brief returns the maximum value of argument arg""
703        @param arg (input): argument
704        """
705        if _testForZero(arg):
706          return 0
707        elif isinstance(arg,Symbol):
708           return WhereNegative_Symbol(arg)
709        elif hasattr(arg,"whereNegative"):
710           return arg.whereNegative()
711        elif hasattr(arg,"shape"):
712           numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
713        else:
714           if arg<0:
715              return 1.
716           else:
717              return 0.
718
719    class WhereNegative_Symbol(Symbol):
720       """symbol representing the whereNegative function"""
721       def __init__(self,arg):
722           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
723       def __str__(self):
724          return "whereNegative(%s)"%str(self.getArgument(0))
725       def eval(self,argval):
726           return whereNegative(self.getEvaluatedArguments(argval)[0])
727
728    def outer(arg0,arg1):
729       if _testForZero(arg0) or _testForZero(arg1):
730          return 0
731       else:
732          if isinstance(arg0,Symbol) or isinstance(arg1,Symbol):
733            return Outer_Symbol(arg0,arg1)
734          elif _identifyShape(arg0)==() or _identifyShape(arg1)==():
735            return arg0*arg1
736          elif isinstance(arg0,numarray.NumArray) and isinstance(arg1,numarray.NumArray):
737            return numarray.outer(arg0,arg1)
738          else:
739            if arg0.getRank()==1 and arg1.getRank()==1:
740              out=escript.Data(0,(arg0.getShape()[0],arg1.getShape()[0]),arg1.getFunctionSpace())
741              for i in range(arg0.getShape()[0]):
742                for j in range(arg1.getShape()[0]):
743                    out[i,j]=arg0[i]*arg1[j]
744              return out
745            else:
746              raise ValueError,"outer is not fully implemented yet."
747
748    class Outer_Symbol(Symbol):
749       """symbol representing the outer product of its two argument"""
750       def __init__(self,arg0,arg1):
751           a=[arg0,arg1]
752           s=tuple(list(_identifyShape(arg0))+list(_identifyShape(arg1)))
753           Symbol.__init__(self,shape=s,dim=_extractDim(a),args=a)
754       def __str__(self):
755          return "outer(%s,%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
756       def eval(self,argval):
757           a=self.getEvaluatedArguments(argval)
758           return outer(a[0],a[1])
759       def _diff(self,arg):
760           a=self.getDifferentiatedArguments(arg)
761           return outer(a[0],self.getArgument(1))+outer(self.getArgument(0),a[1])
762
763    def interpolate(arg,where):
764      """      """
765      @brief      @brief interpolates the function into the FunctionSpace where.
766
767      @param arg      @param arg    interpolant
768        @param where  FunctionSpace to interpolate to
769      """      """
770      if hasattr(arg,"transpose"):      if _testForZero(arg):
771         return arg.transpose()        return 0
772        elif isinstance(arg,Symbol):
773           return Interpolated_Symbol(arg,where)
774      else:      else:
775         return numarray.transpose(arg,axis=None)         return escript.Data(arg,where)
776
777  def trace(arg):  def Interpolated_Symbol(Symbol):
778      """     """symbol representing the integral of the argument"""
779      @brief     def __init__(self,arg,where):
780            Symbol.__init__(self,shape=_extractShape(arg),dim=_extractDim([arg]),args=[arg,where])
781       def __str__(self):
782          return "interpolated(%s)"%(str(self.getArgument(0)))
783       def eval(self,argval):
784           a=self.getEvaluatedArguments(argval)
785           return integrate(a[0],where=self.getArgument(1))
786       def _diff(self,arg):
787           a=self.getDifferentiatedArguments(arg)
788           return integrate(a[0],where=self.getArgument(1))
789
791      """      """
792      if hasattr(arg,"trace"):      @brief returns the spatial gradient of arg at where.
793         return arg.trace()
794        @param arg:   Data object representing the function which gradient to be calculated.
795        @param where: FunctionSpace in which the gradient will be calculated. If not present or
796                      None an appropriate default is used.
797        """
798        if _testForZero(arg):
799          return 0
800        elif isinstance(arg,Symbol):
803           if where==None:
805           else:
807      else:      else:
808         return numarray.trace(arg,k=0)         return arg*0.
809
811       """symbol representing the gradient of the argument"""
812       def __init__(self,arg,where=None):
813           d=_extractDim([arg])
814           s=tuple(list(_identifyShape([arg])).append(d))
815           Symbol.__init__(self,shape=s,dim=_extractDim([arg]),args=[arg,where])
816       def __str__(self):
818       def eval(self,argval):
819           a=self.getEvaluatedArguments(argval)
821       def _diff(self,arg):
822           a=self.getDifferentiatedArguments(arg)
824
825    def integrate(arg,where=None):
826        """
827        @brief return the integral if the function represented by Data object arg over its domain.
828
829        @param arg:   Data object representing the function which is integrated.
830        @param where: FunctionSpace in which the integral is calculated. If not present or
831                      None an appropriate default is used.
832        """
833        if _testForZero(arg):
834          return 0
835        elif isinstance(arg,Symbol):
836           return Integral_Symbol(arg,where)
837        else:
838           if not where==None: arg=escript.Data(arg,where)
839           if arg.getRank()==0:
840             return arg.integrate()[0]
841           else:
842             return arg.integrate()
843
844    def Integral_Symbol(Float_Symbol):
845       """symbol representing the integral of the argument"""
846       def __init__(self,arg,where=None):
847           Float_Symbol.__init__(self,shape=_identifyShape([arg]),args=[arg,where])
848       def __str__(self):
849          return "integral(%s)"%(str(self.getArgument(0)))
850       def eval(self,argval):
851           a=self.getEvaluatedArguments(argval)
852           return integrate(a[0],where=self.getArgument(1))
853       def _diff(self,arg):
854           a=self.getDifferentiatedArguments(arg)
855           return integrate(a[0],where=self.getArgument(1))
856
857    #=============================
858    #
859    # wrapper for various functions: if the argument has attribute the function name
860    # as an argument it calls the correspong methods. Otherwise the coresponsing numarray
861    # function is called.
862    #
863    # functions involving the underlying Domain:
864    #
865
866
867    # functions returning Data objects:
868
869    def transpose(arg,axis=None):
870      """      """
871      @brief      @brief returns the transpose of the Data object arg.
872
873      @param arg      @param arg
874      """      """
875      if hasattr(arg,"exp"):      if axis==None:
876         return arg.exp()         r=0
877           if hasattr(arg,"getRank"): r=arg.getRank()
878           if hasattr(arg,"rank"): r=arg.rank
879           axis=r/2
880        if isinstance(arg,Symbol):
881           return Transpose_Symbol(arg,axis=r)
882        if isinstance(arg,escript.Data):
883           # hack for transpose
884           r=arg.getRank()
885           if r!=2: raise ValueError,"Tranpose only avalaible for rank 2 objects"
886           s=arg.getShape()
887           out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())
888           for i in range(s[0]):
889              for j in range(s[1]):
890                 out[j,i]=arg[i,j]
891           return out
892           # end hack for transpose
893           return arg.transpose(axis)
894      else:      else:
895         return numarray.exp(arg)         return numarray.transpose(arg,axis=axis)
896
897  def sqrt(arg):  def trace(arg,axis0=0,axis1=1):
898      """      """
899      @brief      @brief return
900
901      @param arg      @param arg
902      """      """
903      if hasattr(arg,"sqrt"):      if isinstance(arg,Symbol):
904         return arg.sqrt()         s=list(arg.getShape())
905           s=tuple(s[0:axis0]+s[axis0+1:axis1]+s[axis1+1:])
906           return Trace_Symbol(arg,axis0=axis0,axis1=axis1)
907        elif isinstance(arg,escript.Data):
908           # hack for trace
909           s=arg.getShape()
910           if s[axis0]!=s[axis1]:
911               raise ValueError,"illegal axis in trace"
912           out=escript.Scalar(0.,arg.getFunctionSpace())
913           for i in range(s[0]):
914              for j in range(s[1]):
915                 out+=arg[i,j]
916           return out
917           # end hack for transpose
918           return arg.transpose(axis0=axis0,axis1=axis1)
919      else:      else:
920         return numarray.sqrt(arg)         return numarray.trace(arg,axis0=axis0,axis1=axis1)
921
922  def sin(arg):
923
924    def Trace_Symbol(Symbol):
925        pass
926
927    def length(arg):
928      """      """
929      @brief      @brief
930
931      @param arg      @param arg
932      """      """
933      if hasattr(arg,"sin"):      if isinstance(arg,escript.Data):
934         return arg.sin()         if arg.isEmpty(): return escript.Data()
935           if arg.getRank()==0:
936              return abs(arg)
937           elif arg.getRank()==1:
938              sum=escript.Scalar(0,arg.getFunctionSpace())
939              for i in range(arg.getShape()[0]):
940                 sum+=arg[i]**2
941              return sqrt(sum)
942           elif arg.getRank()==2:
943              sum=escript.Scalar(0,arg.getFunctionSpace())
944              for i in range(arg.getShape()[0]):
945                 for j in range(arg.getShape()[1]):
946                    sum+=arg[i,j]**2
947              return sqrt(sum)
948           elif arg.getRank()==3:
949              sum=escript.Scalar(0,arg.getFunctionSpace())
950              for i in range(arg.getShape()[0]):
951                 for j in range(arg.getShape()[1]):
952                    for k in range(arg.getShape()[2]):
953                       sum+=arg[i,j,k]**2
954              return sqrt(sum)
955           elif arg.getRank()==4:
956              sum=escript.Scalar(0,arg.getFunctionSpace())
957              for i in range(arg.getShape()[0]):
958                 for j in range(arg.getShape()[1]):
959                    for k in range(arg.getShape()[2]):
960                       for l in range(arg.getShape()[3]):
961                          sum+=arg[i,j,k,l]**2
962              return sqrt(sum)
963           else:
964              raise SystemError,"length is not been implemented yet"
965           # return arg.length()
966      else:      else:
967         return numarray.sin(arg)         return sqrt((arg**2).sum())
968
969  def cos(arg):  def deviator(arg):
970      """      """
971      @brief      @brief
972
973      @param arg      @param arg0
974      """      """
975      if hasattr(arg,"cos"):      if isinstance(arg,escript.Data):
976         return arg.cos()          shape=arg.getShape()
977      else:      else:
978         return numarray.cos(arg)          shape=arg.shape
979        if len(shape)!=2:
980              raise ValueError,"Deviator requires rank 2 object"
981        if shape[0]!=shape[1]:
982              raise ValueError,"Deviator requires a square matrix"
983        return arg-1./(shape[0]*1.)*trace(arg)*kronecker(shape[0])
984
985  def maxval(arg):  def inner(arg0,arg1):
986      """      """
987      @brief      @brief
988
989      @param arg      @param arg0, arg1
990      """      """
991      return arg.maxval()      sum=escript.Scalar(0,arg0.getFunctionSpace())
992        if arg.getRank()==0:
993              return arg0*arg1
994        elif arg.getRank()==1:
995             sum=escript.Scalar(0,arg.getFunctionSpace())
996             for i in range(arg.getShape()[0]):
997                sum+=arg0[i]*arg1[i]
998        elif arg.getRank()==2:
999            sum=escript.Scalar(0,arg.getFunctionSpace())
1000            for i in range(arg.getShape()[0]):
1001               for j in range(arg.getShape()[1]):
1002                  sum+=arg0[i,j]*arg1[i,j]
1003        elif arg.getRank()==3:
1004            sum=escript.Scalar(0,arg.getFunctionSpace())
1005            for i in range(arg.getShape()[0]):
1006                for j in range(arg.getShape()[1]):
1007                   for k in range(arg.getShape()[2]):
1008                      sum+=arg0[i,j,k]*arg1[i,j,k]
1009        elif arg.getRank()==4:
1010            sum=escript.Scalar(0,arg.getFunctionSpace())
1011            for i in range(arg.getShape()[0]):
1012               for j in range(arg.getShape()[1]):
1013                  for k in range(arg.getShape()[2]):
1014                     for l in range(arg.getShape()[3]):
1015                        sum+=arg0[i,j,k,l]*arg1[i,j,k,l]
1016        else:
1017              raise SystemError,"inner is not been implemented yet"
1018        return sum
1019
1020  def minval(arg):  def matrixmult(arg0,arg1):
1021
1022        if isinstance(arg1,numarray.NumArray) and isinstance(arg0,numarray.NumArray):
1023            numarray.matrixmult(arg0,arg1)
1024        else:
1025          # escript.matmult(arg0,arg1)
1026          if isinstance(arg1,escript.Data) and not isinstance(arg0,escript.Data):
1027            arg0=escript.Data(arg0,arg1.getFunctionSpace())
1028          elif isinstance(arg0,escript.Data) and not isinstance(arg1,escript.Data):
1029            arg1=escript.Data(arg1,arg0.getFunctionSpace())
1030          if arg0.getRank()==2 and arg1.getRank()==1:
1031              out=escript.Data(0,(arg0.getShape()[0],),arg0.getFunctionSpace())
1032              for i in range(arg0.getShape()[0]):
1033                 for j in range(arg0.getShape()[1]):
1034                   out[i]+=arg0[i,j]*arg1[j]
1035              return out
1036          else:
1037              raise SystemError,"matrixmult is not fully implemented yet!"
1038    #=========================================================
1039    # reduction operations:
1040    #=========================================================
1041    def sum(arg):
1042      """      """
1043      @brief      @brief
1044
1045      @param arg      @param arg
1046      """      """
1047      return arg.minval()      return arg.sum()
1048
1049  def sup(arg):  def sup(arg):
1050      """      """
# Line 174  def sup(arg): Line 1052  def sup(arg):
1052
1053      @param arg      @param arg
1054      """      """
1055      return arg.sup()      if isinstance(arg,escript.Data):
1056           return arg.sup()
1057        elif isinstance(arg,float) or isinstance(arg,int):
1058           return arg
1059        else:
1060           return arg.max()
1061
1062  def inf(arg):  def inf(arg):
1063      """      """
# Line 182  def inf(arg): Line 1065  def inf(arg):
1065
1066      @param arg      @param arg
1067      """      """
1068      return arg.inf()      if isinstance(arg,escript.Data):
1069           return arg.inf()
1070        elif isinstance(arg,float) or isinstance(arg,int):
1071           return arg
1072        else:
1073           return arg.min()
1074
1075  def Lsup(arg):  def L2(arg):
1076      """      """
1077      @brief      @brief returns the L2-norm of the
1078
1079      @param arg      @param arg
1080      """      """
1081      return arg.Lsup()      if isinstance(arg,escript.Data):
1082           return arg.L2()
1083        elif isinstance(arg,float) or isinstance(arg,int):
1084           return abs(arg)
1085        else:
1086           return numarry.sqrt(dot(arg,arg))
1087
1088  def length(arg):  def Lsup(arg):
1089      """      """
1090      @brief      @brief
1091
1092      @param arg      @param arg
1093      """      """
1094      return arg.length()      if isinstance(arg,escript.Data):
1095           return arg.Lsup()
1096        elif isinstance(arg,float) or isinstance(arg,int):
1097           return abs(arg)
1098        else:
1099           return max(numarray.abs(arg))
1100
1101  def sign(arg):  def dot(arg0,arg1):
1102      """      """
1103      @brief      @brief
1104
1105      @param arg      @param arg
1106      """      """
1107      return arg.sign()      if isinstance(arg0,escript.Data):
1108           return arg0.dot(arg1)
1109        elif isinstance(arg1,escript.Data):
1110           return arg1.dot(arg0)
1111        else:
1112           return numarray.dot(arg0,arg1)
1113
1114    def kronecker(d):
1115       if hasattr(d,"getDim"):
1116          return numarray.identity(d.getDim())
1117       else:
1118          return numarray.identity(d)
1119
1120    def unit(i,d):
1121       """
1122       @brief return a unit vector of dimension d with nonzero index i
1123       @param d dimension
1124       @param i index
1125       """
1126       e = numarray.zeros((d,),numarray.Float)
1127       e[i] = 1.0
1128       return e
1129
1130    #
1131    # ============================================
1132    #   testing
1133    # ============================================
1134
1135    if __name__=="__main__":
1136      u=ScalarSymbol(dim=2,name="u")
1137      v=ScalarSymbol(dim=2,name="v")
1138      v=VectorSymbol(2,"v")
1139      u=VectorSymbol(2,"u")
1140
1141
1142      print u+5,(u+5).diff(u)
1143      print 5+u,(5+u).diff(u)
1144      print u+v,(u+v).diff(u)
1145      print v+u,(v+u).diff(u)
1146
1147      print u*5,(u*5).diff(u)
1148      print 5*u,(5*u).diff(u)
1149      print u*v,(u*v).diff(u)
1150      print v*u,(v*u).diff(u)
1151
1152      print u-5,(u-5).diff(u)
1153      print 5-u,(5-u).diff(u)
1154      print u-v,(u-v).diff(u)
1155      print v-u,(v-u).diff(u)
1156
1157      print u/5,(u/5).diff(u)
1158      print 5/u,(5/u).diff(u)
1159      print u/v,(u/v).diff(u)
1160      print v/u,(v/u).diff(u)
1161
1162      print u**5,(u**5).diff(u)
1163      print 5**u,(5**u).diff(u)
1164      print u**v,(u**v).diff(u)
1165      print v**u,(v**u).diff(u)
1166
1167      print exp(u),exp(u).diff(u)
1168      print sqrt(u),sqrt(u).diff(u)
1169      print log(u),log(u).diff(u)
1170      print sin(u),sin(u).diff(u)
1171      print cos(u),cos(u).diff(u)
1172      print tan(u),tan(u).diff(u)
1173      print sign(u),sign(u).diff(u)
1174      print abs(u),abs(u).diff(u)
1175      print wherePositive(u),wherePositive(u).diff(u)
1176      print whereNegative(u),whereNegative(u).diff(u)
1177      print maxval(u),maxval(u).diff(u)
1178      print minval(u),minval(u).diff(u)
1179
1181      print diff(5*g,g)
1182      4*(g+transpose(g))/2+6*trace(g)*kronecker(3)
1183  #  #
1184  # \$Log\$  # \$Log\$
1185  # Revision 1.4  2004/12/15 03:48:42  jgs  # Revision 1.11  2005/07/08 04:07:35  jgs
1186  # *** empty log message ***  # Merge of development branch back to main trunk on 2005-07-08
1187    #
1188    # Revision 1.10  2005/06/09 05:37:59  jgs
1189    # Merge of development branch back to main trunk on 2005-06-09
1190    #
1191    # Revision 1.2.2.17  2005/07/07 07:28:58  gross
1192    # some stuff added to util.py to improve functionality
1193    #
1194    # Revision 1.2.2.16  2005/06/30 01:53:55  gross
1195    # a bug in coloring fixed
1196    #
1197    # Revision 1.2.2.15  2005/06/29 02:36:43  gross
1198    # Symbols have been introduced and some function clarified. needs much more work
1199    #
1200    # Revision 1.2.2.14  2005/05/20 04:05:23  gross
1201    # some work on a darcy flow started
1202    #
1203    # Revision 1.2.2.13  2005/03/16 05:17:58  matt
1204    # Implemented unit(idx, dim) to create cartesian unit basis vectors to
1205    # complement kronecker(dim) function.
1206    #
1207    # Revision 1.2.2.12  2005/03/10 08:14:37  matt
1208    # Added non-member Linf utility function to complement Data::Linf().
1209    #
1210    # Revision 1.2.2.11  2005/02/17 05:53:25  gross
1211    # some bug in saveDX fixed: in fact the bug was in
1212    # DataC/getDataPointShape
1213    #
1214    # Revision 1.2.2.10  2005/01/11 04:59:36  gross
1215    # automatic interpolation in integrate switched off
1216    #
1217    # Revision 1.2.2.9  2005/01/11 03:38:13  gross
1218    # Bug in Data.integrate() fixed for the case of rank 0. The problem is not totallly resolved as the method should return a scalar rather than a numarray object in the case of rank 0. This problem is fixed by the util.integrate wrapper.
1219    #
1220    # Revision 1.2.2.8  2005/01/05 04:21:41  gross
1221    # FunctionSpace checking/matchig in slicing added
1222    #
1223    # Revision 1.2.2.7  2004/12/29 05:29:59  gross
1224    # AdvectivePDE successfully tested for Peclet number 1000000. there is still a problem with setValue and Data()
1225    #
1226    # Revision 1.2.2.6  2004/12/24 06:05:41  gross
1228    #
1229    # Revision 1.2.2.5  2004/12/17 00:06:53  gross
1230    # mk sets ESYS_ROOT is undefined
1231    #
1232    # Revision 1.2.2.4  2004/12/07 03:19:51  gross
1233    # options for GMRES and PRES20 added
1234    #
1235    # Revision 1.2.2.3  2004/12/06 04:55:18  gross
1236    # function wraper extended
1237    #
1238    # Revision 1.2.2.2  2004/11/22 05:44:07  gross
1239    # a few more unitary functions have been added but not implemented in Data yet
1240    #
1241    # Revision 1.2.2.1  2004/11/12 06:58:15  gross
1242    # a lot of changes to get the linearPDE class running: most important change is that there is no matrix format exposed to the user anymore. the format is chosen by the Domain according to the solver and symmetry
1243  #  #
1244  # Revision 1.2  2004/10/27 00:23:36  jgs  # Revision 1.2  2004/10/27 00:23:36  jgs
1245  # fixed minor syntax error  # fixed minor syntax error

Legend:
 Removed from v.100 changed lines Added in v.123