/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Diff of /trunk/escript/py_src/util.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 102 by jgs, Wed Dec 15 07:08:39 2004 UTC revision 123 by jgs, Fri Jul 8 04:08:13 2005 UTC
# Line 4  Line 4 
4    
5  """  """
6  @brief Utility functions for escript  @brief Utility functions for escript
7    
8    TODO for Data:
9    
10      * binary operations @:               (a@b)[:,*]=a[:]@b[:,*] if rank(a)<rank(b)
11                    @=+,-,*,/,**           (a@b)[:]=a[:]@b[:] if rank(a)=rank(b)
12                                           (a@b)[*,:]=a[*,:]@b[:] if rank(a)>rank(b)
13      
14      * implementation of outer outer(a,b)[:,*]=a[:]*b[*]
15      * trace: trace(arg,axis0=a0,axis1=a1)(:,&,*)=sum_i trace(:,i,&,i,*) (i are at index a0 and a1)
16    
17  """  """
18    
19  import numarray  import numarray
# Line 49  TIFF=7 Line 59  TIFF=7
59  OPENINVENTOR=8  OPENINVENTOR=8
60  RENDERMAN=9  RENDERMAN=9
61  PNM=10  PNM=10
62  #  #===========================================================
63  # wrapper for various functions: if the argument has attribute the function name  #  a simple tool box to deal with _differentials of functions
64  # as an argument it calls the correspong methods. Otherwise the coresponsing numarray  #===========================================================
65  # function is called.  
66  #  class Symbol:
67  # functions involving the underlying Domain:     """symbol class"""
68  #     def __init__(self,name="symbol",shape=(),dim=3,args=[]):
69  def grad(arg,where=None):         """creates an instance of a symbol of shape shape and spatial dimension dim
70      """            The symbol may depending on a list of arguments args which
71      @brief returns the spatial gradient of the Data object arg            may be symbols or other objects. name gives the name of the symbol."""
72           self.__args=args
73           self.__name=name
74           self.__shape=shape
75           if hasattr(dim,"getDim"):
76               self.__dim=dim.getDim()
77           else:    
78               self.__dim=dim
79           #
80           self.__cache_val=None
81           self.__cache_argval=None
82    
83       def getArgument(self,i):
84           """returns the i-th argument"""
85           return self.__args[i]
86    
87       def getDim(self):
88           """returns the spatial dimension of the symbol"""
89           return self.__dim
90    
91       def getRank(self):
92           """returns the rank of the symbol"""
93           return len(self.getShape())
94    
95       def getShape(self):
96           """returns the shape of the symbol"""
97           return self.__shape
98    
99       def getEvaluatedArguments(self,argval):
100           """returns the list of evaluated arguments by subsituting symbol u by argval[u]."""
101           if argval==self.__cache_argval:
102               print "%s: cached value used"%self
103               return self.__cache_val
104           else:
105               out=[]
106               for a  in self.__args:
107                  if isinstance(a,Symbol):
108                    out.append(a.eval(argval))
109                  else:
110                    out.append(a)
111               self.__cache_argval=argval
112               self.__cache_val=out
113               return out
114    
115       def getDifferentiatedArguments(self,arg):
116           """returns the list of the arguments _differentiated by arg"""
117           out=[]
118           for a in self.__args:
119              if isinstance(a,Symbol):
120                out.append(a.diff(arg))
121              else:
122                out.append(0)
123           return out
124    
125       def diff(self,arg):
126           """returns the _differention of self by arg."""
127           if self==arg:
128              out=numarray.zeros(tuple(2*list(self.getShape())),numarray.Float)
129              if self.getRank()==0:
130                 out=1.
131              elif self.getRank()==1:
132                  for i0 in range(self.getShape()[0]):
133                     out[i0,i0]=1.  
134              elif self.getRank()==2:
135                  for i0 in range(self.getShape()[0]):
136                    for i1 in range(self.getShape()[1]):
137                         out[i0,i1,i0,i1]=1.  
138              elif self.getRank()==3:
139                  for i0 in range(self.getShape()[0]):
140                    for i1 in range(self.getShape()[1]):
141                      for i2 in range(self.getShape()[2]):
142                         out[i0,i1,i2,i0,i1,i2]=1.  
143              elif self.getRank()==4:
144                  for i0 in range(self.getShape()[0]):
145                    for i1 in range(self.getShape()[1]):
146                      for i2 in range(self.getShape()[2]):
147                        for i3 in range(self.getShape()[3]):
148                           out[i0,i1,i2,i3,i0,i1,i2,i3]=1.  
149              else:
150                 raise ValueError,"differential support rank<5 only."
151              return out
152           else:
153              return self._diff(arg)
154    
155       def _diff(self,arg):
156           """return derivate of self with respect to arg (!=self).
157              This method is overwritten by a particular symbol"""
158           return 0
159    
160       def eval(self,argval):
161           """subsitutes symbol u in self by argval[u] and returns the result. If
162              self is not a key of argval then self is returned."""
163           if argval.has_key(self):
164             return argval[self]
165           else:
166             return self
167    
168    
169       def __str__(self):
170           """returns a string representation of the symbol"""
171           return self.__name
172    
173       def __add__(self,other):
174           """adds other to symbol self. if _testForZero(other) self is returned."""
175           if _testForZero(other):
176              return self
177           else:
178              a=_matchShape([self,other])
179              return Add_Symbol(a[0],a[1])
180    
181    
182       def __radd__(self,other):
183           """adds other to symbol self. if _testForZero(other) self is returned."""
184           return self+other
185    
186       def __neg__(self):
187           """returns -self."""
188           return self*(-1.)
189    
190       def __pos__(self):
191           """returns +self."""
192           return self
193    
194       def __abs__(self):
195           """returns absolute value"""
196           return Abs_Symbol(self)
197    
198       def __sub__(self,other):
199           """subtracts other from symbol self. if _testForZero(other) self is returned."""
200           if _testForZero(other):
201              return self
202           else:
203              return self+(-other)
204    
205       def __rsub__(self,other):
206           """subtracts symbol self from other."""
207           return -self+other
208    
209       def __div__(self,other):
210           """divides symbol self by other."""
211           if isinstance(other,Symbol):
212              a=_matchShape([self,other])
213              return Div_Symbol(a[0],a[1])
214           else:
215              return self*(1./other)
216    
217       def __rdiv__(self,other):
218           """dived other by symbol self. if _testForZero(other) 0 is returned."""
219           if _testForZero(other):
220              return 0
221           else:
222              a=_matchShape([self,other])
223              return Div_Symbol(a[0],a[1])
224    
225       def __pow__(self,other):
226           """raises symbol self to the power of other"""
227           a=_matchShape([self,other])
228           return Power_Symbol(a[0],a[1])
229    
230       def __rpow__(self,other):
231           """raises other to the symbol self"""
232           a=_matchShape([self,other])
233           return Power_Symbol(a[1],a[0])
234    
235       def __mul__(self,other):
236           """multiplies other by symbol self. if _testForZero(other) 0 is returned."""
237           if _testForZero(other):
238              return 0
239           else:
240              a=_matchShape([self,other])
241              return Mult_Symbol(a[0],a[1])
242    
243       def __rmul__(self,other):
244           """multiplies other by symbol self. if _testSForZero(other) 0 is returned."""
245           return self*other
246    
247       def __getitem__(self,sl):
248              print sl
249    
250    def Float_Symbol(Symbol):
251        def __init__(self,name="symbol",shape=(),args=[]):
252            Symbol.__init__(self,dim=0,name="symbol",shape=(),args=[])
253    
254    class ScalarSymbol(Symbol):
255       """a scalar symbol"""
256       def __init__(self,dim=3,name="scalar"):
257          """creates a scalar symbol of spatial dimension dim"""
258          if hasattr(dim,"getDim"):
259               d=dim.getDim()
260          else:    
261               d=dim
262          Symbol.__init__(self,shape=(),dim=d,name=name)
263    
264    class VectorSymbol(Symbol):
265       """a vector symbol"""
266       def __init__(self,dim=3,name="vector"):
267          """creates a vector symbol of spatial dimension dim"""
268          if hasattr(dim,"getDim"):
269               d=dim.getDim()
270          else:    
271               d=dim
272          Symbol.__init__(self,shape=(d,),dim=d,name=name)
273    
274    class TensorSymbol(Symbol):
275       """a tensor symbol"""
276       def __init__(self,dim=3,name="tensor"):
277          """creates a tensor symbol of spatial dimension dim"""
278          if hasattr(dim,"getDim"):
279               d=dim.getDim()
280          else:    
281               d=dim
282          Symbol.__init__(self,shape=(d,d),dim=d,name=name)
283    
284    class Tensor3Symbol(Symbol):
285       """a tensor order 3 symbol"""
286       def __init__(self,dim=3,name="tensor3"):
287          """creates a tensor order 3 symbol of spatial dimension dim"""
288          if hasattr(dim,"getDim"):
289               d=dim.getDim()
290          else:    
291               d=dim
292          Symbol.__init__(self,shape=(d,d,d),dim=d,name=name)
293    
294    class Tensor4Symbol(Symbol):
295       """a tensor order 4 symbol"""
296       def __init__(self,dim=3,name="tensor4"):
297          """creates a tensor order 4 symbol of spatial dimension dim"""    
298          if hasattr(dim,"getDim"):
299               d=dim.getDim()
300          else:    
301               d=dim
302          Symbol.__init__(self,shape=(d,d,d,d),dim=d,name=name)
303    
304    
305    class Add_Symbol(Symbol):
306       """symbol representing the sum of two arguments"""
307       def __init__(self,arg0,arg1):
308           a=[arg0,arg1]
309           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
310       def __str__(self):
311          return "(%s+%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
312       def eval(self,argval):
313           a=self.getEvaluatedArguments(argval)
314           return a[0]+a[1]
315       def _diff(self,arg):
316           a=self.getDifferentiatedArguments(arg)
317           return a[0]+a[1]
318    
319    class Mult_Symbol(Symbol):
320       """symbol representing the product of two arguments"""
321       def __init__(self,arg0,arg1):
322           a=[arg0,arg1]
323           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
324       def __str__(self):
325          return "(%s*%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
326       def eval(self,argval):
327           a=self.getEvaluatedArguments(argval)
328           return a[0]*a[1]
329       def _diff(self,arg):
330           a=self.getDifferentiatedArguments(arg)
331           return self.getArgument(1)*a[0]+self.getArgument(0)*a[1]
332    
333    class Div_Symbol(Symbol):
334       """symbol representing the quotient of two arguments"""
335       def __init__(self,arg0,arg1):
336           a=[arg0,arg1]
337           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
338       def __str__(self):
339          return "(%s/%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
340       def eval(self,argval):
341           a=self.getEvaluatedArguments(argval)
342           return a[0]/a[1]
343       def _diff(self,arg):
344           a=self.getDifferentiatedArguments(arg)
345           return (a[0]*self.getArgument(1)-self.getArgument(0)*a[1])/ \
346                              (self.getArgument(1)*self.getArgument(1))
347    
348    class Power_Symbol(Symbol):
349       """symbol representing the power of the first argument to the power of the second argument"""
350       def __init__(self,arg0,arg1):
351           a=[arg0,arg1]
352           Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)
353       def __str__(self):
354          return "(%s**%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
355       def eval(self,argval):
356           a=self.getEvaluatedArguments(argval)
357           return a[0]**a[1]
358       def _diff(self,arg):
359           a=self.getDifferentiatedArguments(arg)
360           return self*(a[1]*log(self.getArgument(0))+self.getArgument(1)/self.getArgument(0)*a[0])
361    
362    class Abs_Symbol(Symbol):
363       """symbol representing absolute value of its argument"""
364       def __init__(self,arg):
365           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
366       def __str__(self):
367          return "abs(%s)"%str(self.getArgument(0))
368       def eval(self,argval):
369           return abs(self.getEvaluatedArguments(argval)[0])
370       def _diff(self,arg):
371           return sign(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
372    
373    #=========================================================
374    #   some little helpers
375    #=========================================================
376    def _testForZero(arg):
377       """returns True is arg is considered of being zero"""
378       if isinstance(arg,int):
379          return not arg>0
380       elif isinstance(arg,float):
381          return not arg>0.
382       elif isinstance(arg,numarray.NumArray):
383          a=abs(arg)
384          while isinstance(a,numarray.NumArray): a=numarray.sometrue(a)
385          return not a>0
386       else:
387          return False
388    
389    def _extractDim(args):
390        dim=None
391        for a in args:
392           if hasattr(a,"getDim"):
393              d=a.getDim()
394              if dim==None:
395                 dim=d
396              else:
397                 if dim!=d: raise ValueError,"inconsistent spatial dimension of arguments"
398        if dim==None:
399           raise ValueError,"cannot recover spatial dimension"
400        return dim
401    
402    def _identifyShape(arg):
403       """identifies the shape of arg."""
404       if hasattr(arg,"getShape"):
405           arg_shape=arg.getShape()
406       elif hasattr(arg,"shape"):
407         s=arg.shape
408         if callable(s):
409           arg_shape=s()
410         else:
411           arg_shape=s
412       else:
413           arg_shape=()
414       return arg_shape
415    
416    def _extractShape(args):
417        """extracts the common shape of the list of arguments args"""
418        shape=None
419        for a in args:
420           a_shape=_identifyShape(a)
421           if shape==None: shape=a_shape
422           if shape!=a_shape: raise ValueError,"inconsistent shape"
423        if shape==None:
424           raise ValueError,"cannot recover shape"
425        return shape
426    
427    def _matchShape(args,shape=None):
428        """returns the list of arguments args as object which have all the specified shape.
429           if shape is not given the shape "largest" shape of args is used."""
430        # identify the list of shapes:
431        arg_shapes=[]
432        for a in args: arg_shapes.append(_identifyShape(a))
433        # get the largest shape (currently the longest shape):
434        if shape==None: shape=max(arg_shapes)
435        
436        out=[]
437        for i in range(len(args)):
438           if shape==arg_shapes[i]:
439              out.append(args[i])
440           else:
441              if len(shape)==0: # then len(arg_shapes[i])>0
442                raise ValueError,"cannot adopt shape of %s to %s"%(str(args[i]),str(shape))
443              else:
444                if len(arg_shapes[i])==0:
445                    out.append(outer(args[i],numarray.ones(shape)))        
446                else:  
447                    raise ValueError,"cannot adopt shape of %s to %s"%(str(args[i]),str(shape))
448        return out  
449    #=========================================================
450    #   wrapper for various mathematical functions:
451    #=========================================================
452    def diff(arg,dep):
453        """returns the derivative of arg with respect to dep. If arg is not Symbol object
454           0 is returned"""
455        if isinstance(arg,Symbol):
456           return arg.diff(dep)
457        elif hasattr(arg,"shape"):
458              if callable(arg.shape):
459                  return numarray.zeros(arg.shape(),numarray.Float)
460              else:
461                  return numarray.zeros(arg.shape,numarray.Float)
462        else:
463           return 0
464    
465      @param arg: Data object representing the function which gradient to be calculated.  def exp(arg):
466      @param where: FunctionSpace in which the gradient will be. If None Function(dom) where dom is the      """
467                    domain of the Data object arg.      @brief applies the exponential function to arg
468        @param arg (input): argument
469      """      """
470      if where==None:      if isinstance(arg,Symbol):
471         return arg.grad()         return Exp_Symbol(arg)
472        elif hasattr(arg,"exp"):
473           return arg.exp()
474      else:      else:
475         return arg.grad(where)         return numarray.exp(arg)
476    
477  def integrate(arg):  class Exp_Symbol(Symbol):
478      """     """symbol representing the power of the first argument to the power of the second argument"""
479      @brief return the integral if the function represented by Data object arg over its domain.     def __init__(self,arg):
480           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
481       def __str__(self):
482          return "exp(%s)"%str(self.getArgument(0))
483       def eval(self,argval):
484           return exp(self.getEvaluatedArguments(argval)[0])
485       def _diff(self,arg):
486           return self*self.getDifferentiatedArguments(arg)[0]
487    
488      @param arg  def sqrt(arg):
489      """      """
490      return arg.integrate()      @brief applies the squre root function to arg
491        @param arg (input): argument
 def interpolate(arg,where):  
492      """      """
493      @brief interpolates the function represented by Data object arg into the FunctionSpace where.      if isinstance(arg,Symbol):
494           return Sqrt_Symbol(arg)
495        elif hasattr(arg,"sqrt"):
496           return arg.sqrt()
497        else:
498           return numarray.sqrt(arg)      
499    
500      @param arg  class Sqrt_Symbol(Symbol):
501      @param where     """symbol representing square root of argument"""
502      """     def __init__(self,arg):
503      return arg.interpolate(where)         Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
504       def __str__(self):
505          return "sqrt(%s)"%str(self.getArgument(0))
506       def eval(self,argval):
507           return sqrt(self.getEvaluatedArguments(argval)[0])
508       def _diff(self,arg):
509           return (-0.5)/self*self.getDifferentiatedArguments(arg)[0]
510    
511    def log(arg):
512        """
513        @brief applies the logarithmic function bases exp(1.) to arg
514        @param arg (input): argument
515        """
516        if isinstance(arg,Symbol):
517           return Log_Symbol(arg)
518        elif hasattr(arg,"log"):
519           return arg.log()
520        else:
521           return numarray.log(arg)
522    
523  # functions returning Data objects:  class Log_Symbol(Symbol):
524       """symbol representing logarithm of the argument"""
525       def __init__(self,arg):
526           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
527       def __str__(self):
528          return "log(%s)"%str(self.getArgument(0))
529       def eval(self,argval):
530           return log(self.getEvaluatedArguments(argval)[0])
531       def _diff(self,arg):
532           return self.getDifferentiatedArguments(arg)[0]/self.getArgument(0)
533    
534  def transpose(arg,axis=None):  def sin(arg):
535      """      """
536      @brief returns the transpose of the Data object arg.      @brief applies the sinus function to arg
537        @param arg (input): argument
     @param arg  
538      """      """
539      if isinstance(arg,escript.Data):      if isinstance(arg,Symbol):
540         if axis==None: axis=arg.getRank()/2         return Sin_Symbol(arg)
541         return arg.transpose(axis)      elif hasattr(arg,"sin"):
542           return arg.sin()
543      else:      else:
544         if axis==None: axis=arg.rank/2         return numarray.sin(arg)
        return numarray.transpose(arg,axis=axis)  
545    
546  def trace(arg):  class Sin_Symbol(Symbol):
547      """     """symbol representing logarithm of the argument"""
548      @brief     def __init__(self,arg):
549           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
550       def __str__(self):
551          return "sin(%s)"%str(self.getArgument(0))
552       def eval(self,argval):
553           return sin(self.getEvaluatedArguments(argval)[0])
554       def _diff(self,arg):
555           return cos(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
556    
557      @param arg  def cos(arg):
558      """      """
559      if isinstance(arg,escript.Data):      @brief applies the sinus function to arg
560         return arg.trace()      @param arg (input): argument
561        """
562        if isinstance(arg,Symbol):
563           return Cos_Symbol(arg)
564        elif hasattr(arg,"cos"):
565           return arg.cos()
566      else:      else:
567         return numarray.trace(arg)         return numarray.cos(arg)
568    
569  def exp(arg):  class Cos_Symbol(Symbol):
570       """symbol representing logarithm of the argument"""
571       def __init__(self,arg):
572           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
573       def __str__(self):
574          return "cos(%s)"%str(self.getArgument(0))
575       def eval(self,argval):
576           return cos(self.getEvaluatedArguments(argval)[0])
577       def _diff(self,arg):
578           return -sin(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
579    
580    def tan(arg):
581      """      """
582      @brief      @brief applies the sinus function to arg
583        @param arg (input): argument
584        """
585        if isinstance(arg,Symbol):
586           return Tan_Symbol(arg)
587        elif hasattr(arg,"tan"):
588           return arg.tan()
589        else:
590           return numarray.tan(arg)
591    
592      @param arg  class Tan_Symbol(Symbol):
593       """symbol representing logarithm of the argument"""
594       def __init__(self,arg):
595           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
596       def __str__(self):
597          return "tan(%s)"%str(self.getArgument(0))
598       def eval(self,argval):
599           return tan(self.getEvaluatedArguments(argval)[0])
600       def _diff(self,arg):
601           s=cos(self.getArgument(0))
602           return 1./(s*s)*self.getDifferentiatedArguments(arg)[0]
603    
604    def sign(arg):
605      """      """
606      if isinstance(arg,escript.Data):      @brief applies the sign function to arg
607         return arg.exp()      @param arg (input): argument
608        """
609        if isinstance(arg,Symbol):
610           return Sign_Symbol(arg)
611        elif hasattr(arg,"sign"):
612           return arg.sign()
613      else:      else:
614         return numarray.exp(arg)         return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))- \
615                  numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
616    
617  def sqrt(arg):  class Sign_Symbol(Symbol):
618       """symbol representing the sign of the argument"""
619       def __init__(self,arg):
620           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
621       def __str__(self):
622          return "sign(%s)"%str(self.getArgument(0))
623       def eval(self,argval):
624           return sign(self.getEvaluatedArguments(argval)[0])
625    
626    def maxval(arg):
627      """      """
628      @brief      @brief returns the maximum value of argument arg""
629        @param arg (input): argument
630        """
631        if isinstance(arg,Symbol):
632           return Max_Symbol(arg)
633        elif hasattr(arg,"maxval"):
634           return arg.maxval()
635        elif hasattr(arg,"max"):
636           return arg.max()
637        else:
638           return arg
639    
640      @param arg  class Max_Symbol(Symbol):
641       """symbol representing the sign of the argument"""
642       def __init__(self,arg):
643           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
644       def __str__(self):
645          return "maxval(%s)"%str(self.getArgument(0))
646       def eval(self,argval):
647           return maxval(self.getEvaluatedArguments(argval)[0])
648    
649    def minval(arg):
650      """      """
651      if isinstance(arg,escript.Data):      @brief returns the maximum value of argument arg""
652         return arg.sqrt()      @param arg (input): argument
653        """
654        if isinstance(arg,Symbol):
655           return Min_Symbol(arg)
656        elif hasattr(arg,"maxval"):
657           return arg.minval()
658        elif hasattr(arg,"min"):
659           return arg.min()
660      else:      else:
661         return numarray.sqrt(arg)         return arg
662    
663  def sin(arg):  class Min_Symbol(Symbol):
664       """symbol representing the sign of the argument"""
665       def __init__(self,arg):
666           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
667       def __str__(self):
668          return "minval(%s)"%str(self.getArgument(0))
669       def eval(self,argval):
670           return minval(self.getEvaluatedArguments(argval)[0])
671    
672    def wherePositive(arg):
673        """
674        @brief returns the maximum value of argument arg""
675        @param arg (input): argument
676        """
677        if _testForZero(arg):
678          return 0
679        elif isinstance(arg,Symbol):
680           return WherePositive_Symbol(arg)
681        elif hasattr(arg,"wherePositive"):
682           return arg.minval()
683        elif hasattr(arg,"wherePositive"):
684           numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))
685        else:
686           if arg>0:
687              return 1.
688           else:
689              return 0.
690    
691    class WherePositive_Symbol(Symbol):
692       """symbol representing the wherePositive function"""
693       def __init__(self,arg):
694           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
695       def __str__(self):
696          return "wherePositive(%s)"%str(self.getArgument(0))
697       def eval(self,argval):
698           return wherePositive(self.getEvaluatedArguments(argval)[0])
699    
700    def whereNegative(arg):
701        """
702        @brief returns the maximum value of argument arg""
703        @param arg (input): argument
704        """
705        if _testForZero(arg):
706          return 0
707        elif isinstance(arg,Symbol):
708           return WhereNegative_Symbol(arg)
709        elif hasattr(arg,"whereNegative"):
710           return arg.whereNegative()
711        elif hasattr(arg,"shape"):
712           numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
713        else:
714           if arg<0:
715              return 1.
716           else:
717              return 0.
718    
719    class WhereNegative_Symbol(Symbol):
720       """symbol representing the whereNegative function"""
721       def __init__(self,arg):
722           Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
723       def __str__(self):
724          return "whereNegative(%s)"%str(self.getArgument(0))
725       def eval(self,argval):
726           return whereNegative(self.getEvaluatedArguments(argval)[0])
727    
728    def outer(arg0,arg1):
729       if _testForZero(arg0) or _testForZero(arg1):
730          return 0
731       else:
732          if isinstance(arg0,Symbol) or isinstance(arg1,Symbol):
733            return Outer_Symbol(arg0,arg1)
734          elif _identifyShape(arg0)==() or _identifyShape(arg1)==():
735            return arg0*arg1
736          elif isinstance(arg0,numarray.NumArray) and isinstance(arg1,numarray.NumArray):
737            return numarray.outer(arg0,arg1)
738          else:
739            if arg0.getRank()==1 and arg1.getRank()==1:
740              out=escript.Data(0,(arg0.getShape()[0],arg1.getShape()[0]),arg1.getFunctionSpace())
741              for i in range(arg0.getShape()[0]):
742                for j in range(arg1.getShape()[0]):
743                    out[i,j]=arg0[i]*arg1[j]
744              return out
745            else:
746              raise ValueError,"outer is not fully implemented yet."
747    
748    class Outer_Symbol(Symbol):
749       """symbol representing the outer product of its two argument"""
750       def __init__(self,arg0,arg1):
751           a=[arg0,arg1]
752           s=tuple(list(_identifyShape(arg0))+list(_identifyShape(arg1)))
753           Symbol.__init__(self,shape=s,dim=_extractDim(a),args=a)
754       def __str__(self):
755          return "outer(%s,%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
756       def eval(self,argval):
757           a=self.getEvaluatedArguments(argval)
758           return outer(a[0],a[1])
759       def _diff(self,arg):
760           a=self.getDifferentiatedArguments(arg)
761           return outer(a[0],self.getArgument(1))+outer(self.getArgument(0),a[1])
762    
763    def interpolate(arg,where):
764      """      """
765      @brief      @brief interpolates the function into the FunctionSpace where.
766    
767      @param arg      @param arg    interpolant
768        @param where  FunctionSpace to interpolate to
769      """      """
770      if isinstance(arg,escript.Data):      if _testForZero(arg):
771         return arg.sin()        return 0
772        elif isinstance(arg,Symbol):
773           return Interpolated_Symbol(arg,where)
774      else:      else:
775         return numarray.sin(arg)         return escript.Data(arg,where)
776    
777  def tan(arg):  def Interpolated_Symbol(Symbol):
778      """     """symbol representing the integral of the argument"""
779      @brief     def __init__(self,arg,where):
780            Symbol.__init__(self,shape=_extractShape(arg),dim=_extractDim([arg]),args=[arg,where])
781       def __str__(self):
782          return "interpolated(%s)"%(str(self.getArgument(0)))
783       def eval(self,argval):
784           a=self.getEvaluatedArguments(argval)
785           return integrate(a[0],where=self.getArgument(1))
786       def _diff(self,arg):
787           a=self.getDifferentiatedArguments(arg)
788           return integrate(a[0],where=self.getArgument(1))
789    
790      @param arg  def grad(arg,where=None):
791      """      """
792      if isinstance(arg,escript.Data):      @brief returns the spatial gradient of arg at where.
793         return arg.tan()  
794        @param arg:   Data object representing the function which gradient to be calculated.
795        @param where: FunctionSpace in which the gradient will be calculated. If not present or
796                      None an appropriate default is used.
797        """
798        if _testForZero(arg):
799          return 0
800        elif isinstance(arg,Symbol):
801           return Grad_Symbol(arg,where)
802        elif hasattr(arg,"grad"):
803           if where==None:
804              return arg.grad()
805           else:
806              return arg.grad(where)
807      else:      else:
808         return numarray.tan(arg)         return arg*0.
809    
810  def cos(arg):  def Grad_Symbol(Symbol):
811       """symbol representing the gradient of the argument"""
812       def __init__(self,arg,where=None):
813           d=_extractDim([arg])
814           s=tuple(list(_identifyShape([arg])).append(d))
815           Symbol.__init__(self,shape=s,dim=_extractDim([arg]),args=[arg,where])
816       def __str__(self):
817          return "grad(%s)"%(str(self.getArgument(0)))
818       def eval(self,argval):
819           a=self.getEvaluatedArguments(argval)
820           return grad(a[0],where=self.getArgument(1))
821       def _diff(self,arg):
822           a=self.getDifferentiatedArguments(arg)
823           return grad(a[0],where=self.getArgument(1))
824    
825    def integrate(arg,where=None):
826      """      """
827      @brief      @brief return the integral if the function represented by Data object arg over its domain.
828    
829        @param arg:   Data object representing the function which is integrated.
830        @param where: FunctionSpace in which the integral is calculated. If not present or
831                      None an appropriate default is used.
832        """
833        if _testForZero(arg):
834          return 0
835        elif isinstance(arg,Symbol):
836           return Integral_Symbol(arg,where)
837        else:    
838           if not where==None: arg=escript.Data(arg,where)
839           if arg.getRank()==0:
840             return arg.integrate()[0]
841           else:
842             return arg.integrate()
843    
844    def Integral_Symbol(Float_Symbol):
845       """symbol representing the integral of the argument"""
846       def __init__(self,arg,where=None):
847           Float_Symbol.__init__(self,shape=_identifyShape([arg]),args=[arg,where])
848       def __str__(self):
849          return "integral(%s)"%(str(self.getArgument(0)))
850       def eval(self,argval):
851           a=self.getEvaluatedArguments(argval)
852           return integrate(a[0],where=self.getArgument(1))
853       def _diff(self,arg):
854           a=self.getDifferentiatedArguments(arg)
855           return integrate(a[0],where=self.getArgument(1))
856    
857    #=============================
858    #
859    # wrapper for various functions: if the argument has attribute the function name
860    # as an argument it calls the correspong methods. Otherwise the coresponsing numarray
861    # function is called.
862    #
863    # functions involving the underlying Domain:
864    #
865    
866    
867    # functions returning Data objects:
868    
869    def transpose(arg,axis=None):
870        """
871        @brief returns the transpose of the Data object arg.
872    
873      @param arg      @param arg
874      """      """
875        if axis==None:
876           r=0
877           if hasattr(arg,"getRank"): r=arg.getRank()
878           if hasattr(arg,"rank"): r=arg.rank
879           axis=r/2
880        if isinstance(arg,Symbol):
881           return Transpose_Symbol(arg,axis=r)
882      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
883         return arg.cos()         # hack for transpose
884           r=arg.getRank()
885           if r!=2: raise ValueError,"Tranpose only avalaible for rank 2 objects"
886           s=arg.getShape()
887           out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())
888           for i in range(s[0]):
889              for j in range(s[1]):
890                 out[j,i]=arg[i,j]
891           return out
892           # end hack for transpose
893           return arg.transpose(axis)
894      else:      else:
895         return numarray.cos(arg)         return numarray.transpose(arg,axis=axis)
896    
897  def maxval(arg):  def trace(arg,axis0=0,axis1=1):
898      """      """
899      @brief      @brief return
900    
901      @param arg      @param arg
902      """      """
903      if isinstance(arg,escript.Data):      if isinstance(arg,Symbol):
904         return arg.maxval()         s=list(arg.getShape())        
905           s=tuple(s[0:axis0]+s[axis0+1:axis1]+s[axis1+1:])
906           return Trace_Symbol(arg,axis0=axis0,axis1=axis1)
907        elif isinstance(arg,escript.Data):
908           # hack for trace
909           s=arg.getShape()
910           if s[axis0]!=s[axis1]:
911               raise ValueError,"illegal axis in trace"
912           out=escript.Scalar(0.,arg.getFunctionSpace())
913           for i in range(s[0]):
914              for j in range(s[1]):
915                 out+=arg[i,j]
916           return out
917           # end hack for transpose
918           return arg.transpose(axis0=axis0,axis1=axis1)
919      else:      else:
920         return arg.max()         return numarray.trace(arg,axis0=axis0,axis1=axis1)
921    
922  def minval(arg):  
923    
924    def Trace_Symbol(Symbol):
925        pass
926    
927    def length(arg):
928      """      """
929      @brief      @brief
930    
931      @param arg      @param arg
932      """      """
933      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
934         return arg.minval()         if arg.isEmpty(): return escript.Data()
935           if arg.getRank()==0:
936              return abs(arg)
937           elif arg.getRank()==1:
938              sum=escript.Scalar(0,arg.getFunctionSpace())
939              for i in range(arg.getShape()[0]):
940                 sum+=arg[i]**2
941              return sqrt(sum)
942           elif arg.getRank()==2:
943              sum=escript.Scalar(0,arg.getFunctionSpace())
944              for i in range(arg.getShape()[0]):
945                 for j in range(arg.getShape()[1]):
946                    sum+=arg[i,j]**2
947              return sqrt(sum)
948           elif arg.getRank()==3:
949              sum=escript.Scalar(0,arg.getFunctionSpace())
950              for i in range(arg.getShape()[0]):
951                 for j in range(arg.getShape()[1]):
952                    for k in range(arg.getShape()[2]):
953                       sum+=arg[i,j,k]**2
954              return sqrt(sum)
955           elif arg.getRank()==4:
956              sum=escript.Scalar(0,arg.getFunctionSpace())
957              for i in range(arg.getShape()[0]):
958                 for j in range(arg.getShape()[1]):
959                    for k in range(arg.getShape()[2]):
960                       for l in range(arg.getShape()[3]):
961                          sum+=arg[i,j,k,l]**2
962              return sqrt(sum)
963           else:
964              raise SystemError,"length is not been implemented yet"
965           # return arg.length()
966      else:      else:
967         return arg.max()         return sqrt((arg**2).sum())
968    
969  def length(arg):  def deviator(arg):
970      """      """
971      @brief      @brief
972    
973      @param arg      @param arg0
974      """      """
975      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
976         return arg.length()          shape=arg.getShape()
977      else:      else:
978         return sqrt((arg**2).sum())          shape=arg.shape
979        if len(shape)!=2:
980              raise ValueError,"Deviator requires rank 2 object"
981        if shape[0]!=shape[1]:
982              raise ValueError,"Deviator requires a square matrix"
983        return arg-1./(shape[0]*1.)*trace(arg)*kronecker(shape[0])
984    
985  def sign(arg):  def inner(arg0,arg1):
986      """      """
987      @brief      @brief
988    
989      @param arg      @param arg0, arg1
990      """      """
991      if isinstance(arg,escript.Data):      sum=escript.Scalar(0,arg0.getFunctionSpace())
992         return arg.sign()      if arg.getRank()==0:
993              return arg0*arg1
994        elif arg.getRank()==1:
995             sum=escript.Scalar(0,arg.getFunctionSpace())
996             for i in range(arg.getShape()[0]):
997                sum+=arg0[i]*arg1[i]
998        elif arg.getRank()==2:
999            sum=escript.Scalar(0,arg.getFunctionSpace())
1000            for i in range(arg.getShape()[0]):
1001               for j in range(arg.getShape()[1]):
1002                  sum+=arg0[i,j]*arg1[i,j]
1003        elif arg.getRank()==3:
1004            sum=escript.Scalar(0,arg.getFunctionSpace())
1005            for i in range(arg.getShape()[0]):
1006                for j in range(arg.getShape()[1]):
1007                   for k in range(arg.getShape()[2]):
1008                      sum+=arg0[i,j,k]*arg1[i,j,k]
1009        elif arg.getRank()==4:
1010            sum=escript.Scalar(0,arg.getFunctionSpace())
1011            for i in range(arg.getShape()[0]):
1012               for j in range(arg.getShape()[1]):
1013                  for k in range(arg.getShape()[2]):
1014                     for l in range(arg.getShape()[3]):
1015                        sum+=arg0[i,j,k,l]*arg1[i,j,k,l]
1016      else:      else:
1017         return numarray.greater(arg,numarray.zeros(arg.shape))-numarray.less(arg,numarray.zeros(arg.shape))            raise SystemError,"inner is not been implemented yet"
1018        return sum
1019    
1020  # reduction operations:  def matrixmult(arg0,arg1):
1021    
1022        if isinstance(arg1,numarray.NumArray) and isinstance(arg0,numarray.NumArray):
1023            numarray.matrixmult(arg0,arg1)
1024        else:
1025          # escript.matmult(arg0,arg1)
1026          if isinstance(arg1,escript.Data) and not isinstance(arg0,escript.Data):
1027            arg0=escript.Data(arg0,arg1.getFunctionSpace())
1028          elif isinstance(arg0,escript.Data) and not isinstance(arg1,escript.Data):
1029            arg1=escript.Data(arg1,arg0.getFunctionSpace())
1030          if arg0.getRank()==2 and arg1.getRank()==1:
1031              out=escript.Data(0,(arg0.getShape()[0],),arg0.getFunctionSpace())
1032              for i in range(arg0.getShape()[0]):
1033                 for j in range(arg0.getShape()[1]):
1034                   out[i]+=arg0[i,j]*arg1[j]
1035              return out
1036          else:
1037              raise SystemError,"matrixmult is not fully implemented yet!"
1038    #=========================================================
1039    # reduction operations:
1040    #=========================================================
1041  def sum(arg):  def sum(arg):
1042      """      """
1043      @brief      @brief
# Line 229  def sup(arg): Line 1054  def sup(arg):
1054      """      """
1055      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
1056         return arg.sup()         return arg.sup()
1057        elif isinstance(arg,float) or isinstance(arg,int):
1058           return arg
1059      else:      else:
1060         return arg.max()         return arg.max()
1061    
# Line 240  def inf(arg): Line 1067  def inf(arg):
1067      """      """
1068      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
1069         return arg.inf()         return arg.inf()
1070        elif isinstance(arg,float) or isinstance(arg,int):
1071           return arg
1072      else:      else:
1073         return arg.min()         return arg.min()
1074    
# Line 249  def L2(arg): Line 1078  def L2(arg):
1078    
1079      @param arg      @param arg
1080      """      """
1081      return arg.L2()      if isinstance(arg,escript.Data):
1082           return arg.L2()
1083        elif isinstance(arg,float) or isinstance(arg,int):
1084           return abs(arg)
1085        else:
1086           return numarry.sqrt(dot(arg,arg))
1087    
1088  def Lsup(arg):  def Lsup(arg):
1089      """      """
# Line 259  def Lsup(arg): Line 1093  def Lsup(arg):
1093      """      """
1094      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
1095         return arg.Lsup()         return arg.Lsup()
1096        elif isinstance(arg,float) or isinstance(arg,int):
1097           return abs(arg)
1098      else:      else:
1099         return arg.max(numarray.abs(arg))         return max(numarray.abs(arg))
1100    
1101  def dot(arg1,arg2):  def dot(arg0,arg1):
1102      """      """
1103      @brief      @brief
1104    
1105      @param arg      @param arg
1106      """      """
1107      if isinstance(arg1,escript.Data):      if isinstance(arg0,escript.Data):
1108         return arg1.dot(arg2)         return arg0.dot(arg1)
1109      elif isinstance(arg1,escript.Data):      elif isinstance(arg1,escript.Data):
1110         return arg2.dot(arg1)         return arg1.dot(arg0)
1111      else:      else:
1112         return numarray.dot(arg1,arg2)         return numarray.dot(arg0,arg1)
1113    
1114    def kronecker(d):
1115       if hasattr(d,"getDim"):
1116          return numarray.identity(d.getDim())
1117       else:
1118          return numarray.identity(d)
1119    
1120    def unit(i,d):
1121       """
1122       @brief return a unit vector of dimension d with nonzero index i
1123       @param d dimension
1124       @param i index
1125       """
1126       e = numarray.zeros((d,),numarray.Float)
1127       e[i] = 1.0
1128       return e
1129    
1130    #
1131    # ============================================
1132    #   testing
1133    # ============================================
1134    
1135    if __name__=="__main__":
1136      u=ScalarSymbol(dim=2,name="u")
1137      v=ScalarSymbol(dim=2,name="v")
1138      v=VectorSymbol(2,"v")
1139      u=VectorSymbol(2,"u")
1140    
1141    
1142      print u+5,(u+5).diff(u)
1143      print 5+u,(5+u).diff(u)
1144      print u+v,(u+v).diff(u)
1145      print v+u,(v+u).diff(u)
1146    
1147      print u*5,(u*5).diff(u)
1148      print 5*u,(5*u).diff(u)
1149      print u*v,(u*v).diff(u)
1150      print v*u,(v*u).diff(u)
1151    
1152      print u-5,(u-5).diff(u)
1153      print 5-u,(5-u).diff(u)
1154      print u-v,(u-v).diff(u)
1155      print v-u,(v-u).diff(u)
1156    
1157      print u/5,(u/5).diff(u)
1158      print 5/u,(5/u).diff(u)
1159      print u/v,(u/v).diff(u)
1160      print v/u,(v/u).diff(u)
1161    
1162      print u**5,(u**5).diff(u)
1163      print 5**u,(5**u).diff(u)
1164      print u**v,(u**v).diff(u)
1165      print v**u,(v**u).diff(u)
1166    
1167      print exp(u),exp(u).diff(u)
1168      print sqrt(u),sqrt(u).diff(u)
1169      print log(u),log(u).diff(u)
1170      print sin(u),sin(u).diff(u)
1171      print cos(u),cos(u).diff(u)
1172      print tan(u),tan(u).diff(u)
1173      print sign(u),sign(u).diff(u)
1174      print abs(u),abs(u).diff(u)
1175      print wherePositive(u),wherePositive(u).diff(u)
1176      print whereNegative(u),whereNegative(u).diff(u)
1177      print maxval(u),maxval(u).diff(u)
1178      print minval(u),minval(u).diff(u)
1179    
1180      g=grad(u)
1181      print diff(5*g,g)
1182      4*(g+transpose(g))/2+6*trace(g)*kronecker(3)
1183    #
1184    # $Log$
1185    # Revision 1.11  2005/07/08 04:07:35  jgs
1186    # Merge of development branch back to main trunk on 2005-07-08
1187    #
1188    # Revision 1.10  2005/06/09 05:37:59  jgs
1189    # Merge of development branch back to main trunk on 2005-06-09
1190    #
1191    # Revision 1.2.2.17  2005/07/07 07:28:58  gross
1192    # some stuff added to util.py to improve functionality
1193    #
1194    # Revision 1.2.2.16  2005/06/30 01:53:55  gross
1195    # a bug in coloring fixed
1196    #
1197    # Revision 1.2.2.15  2005/06/29 02:36:43  gross
1198    # Symbols have been introduced and some function clarified. needs much more work
1199    #
1200    # Revision 1.2.2.14  2005/05/20 04:05:23  gross
1201    # some work on a darcy flow started
1202    #
1203    # Revision 1.2.2.13  2005/03/16 05:17:58  matt
1204    # Implemented unit(idx, dim) to create cartesian unit basis vectors to
1205    # complement kronecker(dim) function.
1206    #
1207    # Revision 1.2.2.12  2005/03/10 08:14:37  matt
1208    # Added non-member Linf utility function to complement Data::Linf().
1209    #
1210    # Revision 1.2.2.11  2005/02/17 05:53:25  gross
1211    # some bug in saveDX fixed: in fact the bug was in
1212    # DataC/getDataPointShape
1213    #
1214    # Revision 1.2.2.10  2005/01/11 04:59:36  gross
1215    # automatic interpolation in integrate switched off
1216    #
1217    # Revision 1.2.2.9  2005/01/11 03:38:13  gross
1218    # Bug in Data.integrate() fixed for the case of rank 0. The problem is not totallly resolved as the method should return a scalar rather than a numarray object in the case of rank 0. This problem is fixed by the util.integrate wrapper.
1219    #
1220    # Revision 1.2.2.8  2005/01/05 04:21:41  gross
1221    # FunctionSpace checking/matchig in slicing added
1222    #
1223    # Revision 1.2.2.7  2004/12/29 05:29:59  gross
1224    # AdvectivePDE successfully tested for Peclet number 1000000. there is still a problem with setValue and Data()
1225    #
1226    # Revision 1.2.2.6  2004/12/24 06:05:41  gross
1227    # some changes in linearPDEs to add AdevectivePDE
1228    #
1229    # Revision 1.2.2.5  2004/12/17 00:06:53  gross
1230    # mk sets ESYS_ROOT is undefined
1231    #
1232    # Revision 1.2.2.4  2004/12/07 03:19:51  gross
1233    # options for GMRES and PRES20 added
1234    #
1235    # Revision 1.2.2.3  2004/12/06 04:55:18  gross
1236    # function wraper extended
1237    #
1238    # Revision 1.2.2.2  2004/11/22 05:44:07  gross
1239    # a few more unitary functions have been added but not implemented in Data yet
1240    #
1241    # Revision 1.2.2.1  2004/11/12 06:58:15  gross
1242    # a lot of changes to get the linearPDE class running: most important change is that there is no matrix format exposed to the user anymore. the format is chosen by the Domain according to the solver and symmetry
1243    #
1244    # Revision 1.2  2004/10/27 00:23:36  jgs
1245    # fixed minor syntax error
1246    #
1247    # Revision 1.1.1.1  2004/10/26 06:53:56  jgs
1248    # initial import of project esys2
1249    #
1250    # Revision 1.1.2.3  2004/10/26 06:43:48  jgs
1251    # committing Lutz's and Paul's changes to brach jgs
1252    #
1253    # Revision 1.1.4.1  2004/10/20 05:32:51  cochrane
1254    # Added incomplete Doxygen comments to files, or merely put the docstrings that already exist into Doxygen form.
1255    #
1256    # Revision 1.1  2004/08/05 03:58:27  gross
1257    # Bug in Assemble_NodeCoordinates fixed
1258    #
1259    #

Legend:
Removed from v.102  
changed lines
  Added in v.123

  ViewVC Help
Powered by ViewVC 1.1.26