/[escript]/trunk-mpi-branch/escript/py_src/util.py
ViewVC logotype

Diff of /trunk-mpi-branch/escript/py_src/util.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 149 by jgs, Thu Sep 1 03:31:39 2005 UTC revision 150 by jgs, Thu Sep 15 03:44:45 2005 UTC
# Line 19  Utility functions for escript Line 19  Utility functions for escript
19    
20  import numarray  import numarray
21  import escript  import escript
22    import symbols
23  #===========================================================  import os
 # a simple tool box to deal with _differentials of functions  
 #===========================================================  
   
 class Symbol:  
    """  
    Symbol class.  
    """  
    def __init__(self,name="symbol",shape=(),dim=3,args=[]):  
        """  
        Creates an instance of a symbol of shape shape and spatial dimension  
        dim.  
         
        The symbol may depending on a list of arguments args which may be  
        symbols or other objects. name gives the name of the symbol.  
        """  
   
        self.__args=args  
        self.__name=name  
        self.__shape=shape  
        if hasattr(dim,"getDim"):  
            self.__dim=dim.getDim()  
        else:      
            self.__dim=dim  
        #  
        self.__cache_val=None  
        self.__cache_argval=None  
   
    def getArgument(self,i):  
        """  
        Returns the i-th argument.  
        """  
        return self.__args[i]  
   
    def getDim(self):  
        """  
        Returns the spatial dimension of the symbol.  
        """  
        return self.__dim  
   
    def getRank(self):  
        """  
        Returns the rank of the symbol.  
        """  
        return len(self.getShape())  
   
    def getShape(self):  
        """  
        Returns the shape of the symbol.  
        """  
        return self.__shape  
   
    def getEvaluatedArguments(self,argval):  
        """  
        Returns the list of evaluated arguments by subsituting symbol u by  
        argval[u].  
        """  
        if argval==self.__cache_argval:  
            print "%s: cached value used"%self  
            return self.__cache_val  
        else:  
            out=[]  
            for a  in self.__args:  
               if isinstance(a,Symbol):  
                 out.append(a.eval(argval))  
               else:  
                 out.append(a)  
            self.__cache_argval=argval  
            self.__cache_val=out  
            return out  
   
    def getDifferentiatedArguments(self,arg):  
        """  
        Returns the list of the arguments _differentiated by arg.  
        """  
        out=[]  
        for a in self.__args:  
           if isinstance(a,Symbol):  
             out.append(a.diff(arg))  
           else:  
             out.append(0)  
        return out  
   
    def diff(self,arg):  
        """  
        Returns the _differention of self by arg.  
        """  
        if self==arg:  
           out=numarray.zeros(tuple(2*list(self.getShape())),numarray.Float)  
           if self.getRank()==0:  
              out=1.  
           elif self.getRank()==1:  
               for i0 in range(self.getShape()[0]):  
                  out[i0,i0]=1.    
           elif self.getRank()==2:  
               for i0 in range(self.getShape()[0]):  
                 for i1 in range(self.getShape()[1]):  
                      out[i0,i1,i0,i1]=1.    
           elif self.getRank()==3:  
               for i0 in range(self.getShape()[0]):  
                 for i1 in range(self.getShape()[1]):  
                   for i2 in range(self.getShape()[2]):  
                      out[i0,i1,i2,i0,i1,i2]=1.    
           elif self.getRank()==4:  
               for i0 in range(self.getShape()[0]):  
                 for i1 in range(self.getShape()[1]):  
                   for i2 in range(self.getShape()[2]):  
                     for i3 in range(self.getShape()[3]):  
                        out[i0,i1,i2,i3,i0,i1,i2,i3]=1.    
           else:  
              raise ValueError,"differential support rank<5 only."  
           return out  
        else:  
           return self._diff(arg)  
   
    def _diff(self,arg):  
        """  
        Return derivate of self with respect to arg (!=self).  
   
        This method is overwritten by a particular symbol.  
        """  
        return 0  
   
    def eval(self,argval):  
        """  
        Subsitutes symbol u in self by argval[u] and returns the result. If  
        self is not a key of argval then self is returned.  
        """  
        if argval.has_key(self):  
          return argval[self]  
        else:  
          return self  
   
    def __str__(self):  
        """  
        Returns a string representation of the symbol.  
        """  
        return self.__name  
   
    def __add__(self,other):  
        """  
        Adds other to symbol self. if _testForZero(other) self is returned.  
        """  
        if _testForZero(other):  
           return self  
        else:  
           a=_matchShape([self,other])  
           return Add_Symbol(a[0],a[1])  
   
    def __radd__(self,other):  
        """  
        Adds other to symbol self. if _testForZero(other) self is returned.  
        """  
        return self+other  
   
    def __neg__(self):  
        """  
        Returns -self.  
        """  
        return self*(-1.)  
   
    def __pos__(self):  
        """  
        Returns +self.  
        """  
        return self  
   
    def __abs__(self):  
        """  
        Returns absolute value.  
        """  
        return Abs_Symbol(self)  
   
    def __sub__(self,other):  
        """  
        Subtracts other from symbol self.  
         
        If _testForZero(other) self is returned.  
        """  
        if _testForZero(other):  
           return self  
        else:  
           return self+(-other)  
   
    def __rsub__(self,other):  
        """  
        Subtracts symbol self from other.  
        """  
        return -self+other  
   
    def __div__(self,other):  
        """  
        Divides symbol self by other.  
        """  
        if isinstance(other,Symbol):  
           a=_matchShape([self,other])  
           return Div_Symbol(a[0],a[1])  
        else:  
           return self*(1./other)  
   
    def __rdiv__(self,other):  
        """  
        Dived other by symbol self. if _testForZero(other) 0 is returned.  
        """  
        if _testForZero(other):  
           return 0  
        else:  
           a=_matchShape([self,other])  
           return Div_Symbol(a[0],a[1])  
   
    def __pow__(self,other):  
        """  
        Raises symbol self to the power of other.  
        """  
        a=_matchShape([self,other])  
        return Power_Symbol(a[0],a[1])  
   
    def __rpow__(self,other):  
        """  
        Raises other to the symbol self.  
        """  
        a=_matchShape([self,other])  
        return Power_Symbol(a[1],a[0])  
   
    def __mul__(self,other):  
        """  
        Multiplies other by symbol self. if _testForZero(other) 0 is returned.  
        """  
        if _testForZero(other):  
           return 0  
        else:  
           a=_matchShape([self,other])  
           return Mult_Symbol(a[0],a[1])  
   
    def __rmul__(self,other):  
        """  
        Multiplies other by symbol self. if _testSForZero(other) 0 is returned.  
        """  
        return self*other  
   
    def __getitem__(self,sl):  
           print sl  
   
 class Float_Symbol(Symbol):  
     def __init__(self,name="symbol",shape=(),args=[]):  
         Symbol.__init__(self,dim=0,name="symbol",shape=(),args=[])  
   
 class ScalarSymbol(Symbol):  
    """  
    A scalar symbol.  
    """  
    def __init__(self,dim=3,name="scalar"):  
       """  
       Creates a scalar symbol of spatial dimension dim.  
       """  
       if hasattr(dim,"getDim"):  
            d=dim.getDim()  
       else:      
            d=dim  
       Symbol.__init__(self,shape=(),dim=d,name=name)  
   
 class VectorSymbol(Symbol):  
    """  
    A vector symbol.  
    """  
    def __init__(self,dim=3,name="vector"):  
       """  
       Creates a vector symbol of spatial dimension dim.  
       """  
       if hasattr(dim,"getDim"):  
            d=dim.getDim()  
       else:      
            d=dim  
       Symbol.__init__(self,shape=(d,),dim=d,name=name)  
   
 class TensorSymbol(Symbol):  
    """  
    A tensor symbol.  
    """  
    def __init__(self,dim=3,name="tensor"):  
       """  
       Creates a tensor symbol of spatial dimension dim.  
       """  
       if hasattr(dim,"getDim"):  
            d=dim.getDim()  
       else:      
            d=dim  
       Symbol.__init__(self,shape=(d,d),dim=d,name=name)  
   
 class Tensor3Symbol(Symbol):  
    """  
    A tensor order 3 symbol.  
    """  
    def __init__(self,dim=3,name="tensor3"):  
       """  
       Creates a tensor order 3 symbol of spatial dimension dim.  
       """  
       if hasattr(dim,"getDim"):  
            d=dim.getDim()  
       else:      
            d=dim  
       Symbol.__init__(self,shape=(d,d,d),dim=d,name=name)  
   
 class Tensor4Symbol(Symbol):  
    """  
    A tensor order 4 symbol.  
    """  
    def __init__(self,dim=3,name="tensor4"):  
       """  
       Creates a tensor order 4 symbol of spatial dimension dim.  
       """  
       if hasattr(dim,"getDim"):  
            d=dim.getDim()  
       else:      
            d=dim  
       Symbol.__init__(self,shape=(d,d,d,d),dim=d,name=name)  
   
 class Add_Symbol(Symbol):  
    """  
    Symbol representing the sum of two arguments.  
    """  
    def __init__(self,arg0,arg1):  
        a=[arg0,arg1]  
        Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)  
    def __str__(self):  
       return "(%s+%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))  
    def eval(self,argval):  
        a=self.getEvaluatedArguments(argval)  
        return a[0]+a[1]  
    def _diff(self,arg):  
        a=self.getDifferentiatedArguments(arg)  
        return a[0]+a[1]  
   
 class Mult_Symbol(Symbol):  
    """  
    Symbol representing the product of two arguments.  
    """  
    def __init__(self,arg0,arg1):  
        a=[arg0,arg1]  
        Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)  
    def __str__(self):  
       return "(%s*%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))  
    def eval(self,argval):  
        a=self.getEvaluatedArguments(argval)  
        return a[0]*a[1]  
    def _diff(self,arg):  
        a=self.getDifferentiatedArguments(arg)  
        return self.getArgument(1)*a[0]+self.getArgument(0)*a[1]  
   
 class Div_Symbol(Symbol):  
    """  
    Symbol representing the quotient of two arguments.  
    """  
    def __init__(self,arg0,arg1):  
        a=[arg0,arg1]  
        Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)  
    def __str__(self):  
       return "(%s/%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))  
    def eval(self,argval):  
        a=self.getEvaluatedArguments(argval)  
        return a[0]/a[1]  
    def _diff(self,arg):  
        a=self.getDifferentiatedArguments(arg)  
        return (a[0]*self.getArgument(1)-self.getArgument(0)*a[1])/ \  
                           (self.getArgument(1)*self.getArgument(1))  
   
 class Power_Symbol(Symbol):  
    """  
    Symbol representing the power of the first argument to the power of the  
    second argument.  
    """  
    def __init__(self,arg0,arg1):  
        a=[arg0,arg1]  
        Symbol.__init__(self,dim=_extractDim(a),shape=_extractShape(a),args=a)  
    def __str__(self):  
       return "(%s**%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))  
    def eval(self,argval):  
        a=self.getEvaluatedArguments(argval)  
        return a[0]**a[1]  
    def _diff(self,arg):  
        a=self.getDifferentiatedArguments(arg)  
        return self*(a[1]*log(self.getArgument(0))+self.getArgument(1)/self.getArgument(0)*a[0])  
   
 class Abs_Symbol(Symbol):  
    """  
    Symbol representing absolute value of its argument.  
    """  
    def __init__(self,arg):  
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])  
    def __str__(self):  
       return "abs(%s)"%str(self.getArgument(0))  
    def eval(self,argval):  
        return abs(self.getEvaluatedArguments(argval)[0])  
    def _diff(self,arg):  
        return sign(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]  
24    
25  #=========================================================  #=========================================================
26  #   some little helpers  #   some little helpers
# Line 427  def _testForZero(arg): Line 33  def _testForZero(arg):
33        return not arg>0        return not arg>0
34     elif isinstance(arg,float):     elif isinstance(arg,float):
35        return not arg>0.        return not arg>0.
36     elif isinstance(arg,numarray.NumArray):     elif isinstance(arg,numarray.NumArray):
37        a=abs(arg)        a=abs(arg)
38        while isinstance(a,numarray.NumArray): a=numarray.sometrue(a)        while isinstance(a,numarray.NumArray): a=numarray.sometrue(a)
39        return not a>0        return not a>0
40     else:     else:
41        return False        return False
42    
43  def _extractDim(args):  #=========================================================
44      dim=None  def saveVTK(filename,**data):
45      for a in args:      """
46         if hasattr(a,"getDim"):      writes arg into files in the vtk file format
           d=a.getDim()  
           if dim==None:  
              dim=d  
           else:  
              if dim!=d: raise ValueError,"inconsistent spatial dimension of arguments"  
     if dim==None:  
        raise ValueError,"cannot recover spatial dimension"  
     return dim  
47    
48  def _identifyShape(arg):             saveVTK(<filename>,<data name 1>=<data object 1>,...,<data name n>=<data object n>)  
49     """  
50     Identifies the shape of arg.        This will create VTK files of the name <dir name>+<data name i>+"."+<extension> where <filename>=<dir name>+<extension>
    """  
    if hasattr(arg,"getShape"):  
        arg_shape=arg.getShape()  
    elif hasattr(arg,"shape"):  
      s=arg.shape  
      if callable(s):  
        arg_shape=s()  
      else:  
        arg_shape=s  
    else:  
        arg_shape=()  
    return arg_shape  
51    
 def _extractShape(args):  
     """  
     Extracts the common shape of the list of arguments args.  
52      """      """
53      shape=None      ex=os.path.split(filename)
54      for a in args:      for i in data.keys():
55         a_shape=_identifyShape(a)         data[i].saveVTK(os.path.join(ex[0],i+"."+ex[1]))
        if shape==None: shape=a_shape  
        if shape!=a_shape: raise ValueError,"inconsistent shape"  
     if shape==None:  
        raise ValueError,"cannot recover shape"  
     return shape  
   
 def _matchShape(args,shape=None):  
     """  
     Returns the list of arguments args as object which have all the  
     specified shape.  
   
     If shape is not given the shape "largest" shape of args is used.  
     """  
     # identify the list of shapes:  
     arg_shapes=[]  
     for a in args: arg_shapes.append(_identifyShape(a))  
     # get the largest shape (currently the longest shape):  
     if shape==None: shape=max(arg_shapes)  
       
     out=[]  
     for i in range(len(args)):  
        if shape==arg_shapes[i]:  
           out.append(args[i])  
        else:  
           if len(shape)==0: # then len(arg_shapes[i])>0  
             raise ValueError,"cannot adopt shape of %s to %s"%(str(args[i]),str(shape))  
           else:  
             if len(arg_shapes[i])==0:  
                 out.append(outer(args[i],numarray.ones(shape)))          
             else:    
                 raise ValueError,"cannot adopt shape of %s to %s"%(str(args[i]),str(shape))  
     return out    
56    
57  #=========================================================  #=========================================================
58  #   wrappers for various mathematical functions:  def saveDX(filename,**data):
 #=========================================================  
 def diff(arg,dep):  
59      """      """
60      Returns the derivative of arg with respect to dep.      writes arg into file in the openDX file format
61        
62      If arg is not Symbol object 0 is returned.             saveDX(<filename>,<data name 1>=<data object 1>,...,<data name n>=<data object n>)  
63    
64          This will create DX files of the name <dir name>+<data name i>+"."+<extension> where <filename>=<dir name>+<extension>
65    
66      """      """
67      if isinstance(arg,Symbol):      ex=os.path.split(filename)
68         return arg.diff(dep)      for i in data.keys():
69      elif hasattr(arg,"shape"):         data[i].saveDX(os.path.join(ex[0],i+"."+ex[1]))
70            if callable(arg.shape):  
71                return numarray.zeros(arg.shape(),numarray.Float)  #=========================================================
           else:  
               return numarray.zeros(arg.shape,numarray.Float)  
     else:  
        return 0  
72    
73  def exp(arg):  def exp(arg):
74      """      """
# Line 528  def exp(arg): Line 76  def exp(arg):
76    
77      @param arg: argument      @param arg: argument
78      """      """
79      if isinstance(arg,Symbol):      if isinstance(arg,symbols.Symbol):
80         return Exp_Symbol(arg)         return symbols.Exp_Symbol(arg)
81      elif hasattr(arg,"exp"):      elif hasattr(arg,"exp"):
82         return arg.exp()         return arg.exp()
83      else:      else:
84         return numarray.exp(arg)         return numarray.exp(arg)
85    
 class Exp_Symbol(Symbol):  
    """  
    Symbol representing the power of the first argument to the power of the  
    second argument.  
    """  
    def __init__(self,arg):  
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])  
    def __str__(self):  
       return "exp(%s)"%str(self.getArgument(0))  
    def eval(self,argval):  
        return exp(self.getEvaluatedArguments(argval)[0])  
    def _diff(self,arg):  
        return self*self.getDifferentiatedArguments(arg)[0]  
   
86  def sqrt(arg):  def sqrt(arg):
87      """      """
88      Applies the squre root function to arg.      Applies the squre root function to arg.
89    
90      @param arg: argument      @param arg: argument
91      """      """
92      if isinstance(arg,Symbol):      if isinstance(arg,symbols.Symbol):
93         return Sqrt_Symbol(arg)         return symbols.Sqrt_Symbol(arg)
94      elif hasattr(arg,"sqrt"):      elif hasattr(arg,"sqrt"):
95         return arg.sqrt()         return arg.sqrt()
96      else:      else:
97         return numarray.sqrt(arg)               return numarray.sqrt(arg)      
98    
 class Sqrt_Symbol(Symbol):  
    """  
    Symbol representing square root of argument.  
    """  
    def __init__(self,arg):  
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])  
    def __str__(self):  
       return "sqrt(%s)"%str(self.getArgument(0))  
    def eval(self,argval):  
        return sqrt(self.getEvaluatedArguments(argval)[0])  
    def _diff(self,arg):  
        return (-0.5)/self*self.getDifferentiatedArguments(arg)[0]  
   
99  def log(arg):  def log(arg):
100      """      """
101      Applies the logarithmic function bases exp(1.) to arg      Applies the logarithmic function bases exp(1.) to arg
102    
103      @param arg: argument      @param arg: argument
104      """      """
105      if isinstance(arg,Symbol):      if isinstance(arg,symbols.Symbol):
106         return Log_Symbol(arg)         return symbols.Log_Symbol(arg)
107      elif hasattr(arg,"log"):      elif hasattr(arg,"log"):
108         return arg.log()         return arg.log()
109      else:      else:
110         return numarray.log(arg)         return numarray.log(arg)
111    
 class Log_Symbol(Symbol):  
    """  
    Symbol representing logarithm of the argument.  
    """  
    def __init__(self,arg):  
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])  
    def __str__(self):  
       return "log(%s)"%str(self.getArgument(0))  
    def eval(self,argval):  
        return log(self.getEvaluatedArguments(argval)[0])  
    def _diff(self,arg):  
        return self.getDifferentiatedArguments(arg)[0]/self.getArgument(0)  
   
112  def ln(arg):  def ln(arg):
113      """      """
114      Applies the natural logarithmic function to arg.      Applies the natural logarithmic function to arg.
115    
116      @param arg: argument      @param arg: argument
117      """      """
118      if isinstance(arg,Symbol):      if isinstance(arg,symbols.Symbol):
119         return Ln_Symbol(arg)         return symbols.Ln_Symbol(arg)
120      elif hasattr(arg,"ln"):      elif hasattr(arg,"ln"):
121         return arg.log()         return arg.log()
122      else:      else:
123         return numarray.log(arg)         return numarray.log(arg)
124    
 class Ln_Symbol(Symbol):  
    """  
    Symbol representing natural logarithm of the argument.  
    """  
    def __init__(self,arg):  
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])  
    def __str__(self):  
       return "ln(%s)"%str(self.getArgument(0))  
    def eval(self,argval):  
        return ln(self.getEvaluatedArguments(argval)[0])  
    def _diff(self,arg):  
        return self.getDifferentiatedArguments(arg)[0]/self.getArgument(0)  
   
125  def sin(arg):  def sin(arg):
126      """      """
127      Applies the sin function to arg.      Applies the sin function to arg.
128    
129      @param arg: argument      @param arg: argument
130      """      """
131      if isinstance(arg,Symbol):      if isinstance(arg,symbols.Symbol):
132         return Sin_Symbol(arg)         return symbols.Sin_Symbol(arg)
133      elif hasattr(arg,"sin"):      elif hasattr(arg,"sin"):
134         return arg.sin()         return arg.sin()
135      else:      else:
136         return numarray.sin(arg)         return numarray.sin(arg)
137    
 class Sin_Symbol(Symbol):  
    """  
    Symbol representing sin of the argument.  
    """  
    def __init__(self,arg):  
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])  
    def __str__(self):  
       return "sin(%s)"%str(self.getArgument(0))  
    def eval(self,argval):  
        return sin(self.getEvaluatedArguments(argval)[0])  
    def _diff(self,arg):  
        return cos(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]  
   
138  def cos(arg):  def cos(arg):
139      """      """
140      Applies the cos function to arg.      Applies the cos function to arg.
141    
142      @param arg: argument      @param arg: argument
143      """      """
144      if isinstance(arg,Symbol):      if isinstance(arg,symbols.Symbol):
145         return Cos_Symbol(arg)         return symbols.Cos_Symbol(arg)
146      elif hasattr(arg,"cos"):      elif hasattr(arg,"cos"):
147         return arg.cos()         return arg.cos()
148      else:      else:
149         return numarray.cos(arg)         return numarray.cos(arg)
150    
 class Cos_Symbol(Symbol):  
    """  
    Symbol representing cos of the argument.  
    """  
    def __init__(self,arg):  
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])  
    def __str__(self):  
       return "cos(%s)"%str(self.getArgument(0))  
    def eval(self,argval):  
        return cos(self.getEvaluatedArguments(argval)[0])  
    def _diff(self,arg):  
        return -sin(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]  
   
151  def tan(arg):  def tan(arg):
152      """      """
153      Applies the tan function to arg.      Applies the tan function to arg.
154    
155      @param arg: argument      @param arg: argument
156      """      """
157      if isinstance(arg,Symbol):      if isinstance(arg,symbols.Symbol):
158         return Tan_Symbol(arg)         return symbols.Tan_Symbol(arg)
159      elif hasattr(arg,"tan"):      elif hasattr(arg,"tan"):
160         return arg.tan()         return arg.tan()
161      else:      else:
162         return numarray.tan(arg)         return numarray.tan(arg)
163    
164  class Tan_Symbol(Symbol):  def asin(arg):
165     """      """
166     Symbol representing tan of the argument.      Applies the asin function to arg.
167     """  
168     def __init__(self,arg):      @param arg: argument
169         Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])      """
170     def __str__(self):      if isinstance(arg,symbols.Symbol):
171        return "tan(%s)"%str(self.getArgument(0))         return symbols.Asin_Symbol(arg)
172     def eval(self,argval):      elif hasattr(arg,"asin"):
173         return tan(self.getEvaluatedArguments(argval)[0])         return arg.asin()
174     def _diff(self,arg):      else:
175         s=cos(self.getArgument(0))         return numarray.asin(arg)
176         return 1./(s*s)*self.getDifferentiatedArguments(arg)[0]  
177    def acos(arg):
178        """
179        Applies the acos function to arg.
180    
181        @param arg: argument
182        """
183        if isinstance(arg,symbols.Symbol):
184           return symbols.Acos_Symbol(arg)
185        elif hasattr(arg,"acos"):
186           return arg.acos()
187        else:
188           return numarray.acos(arg)
189    
190    def atan(arg):
191        """
192        Applies the atan function to arg.
193    
194        @param arg: argument
195        """
196        if isinstance(arg,symbols.Symbol):
197           return symbols.Atan_Symbol(arg)
198        elif hasattr(arg,"atan"):
199           return arg.atan()
200        else:
201           return numarray.atan(arg)
202    
203    def sinh(arg):
204        """
205        Applies the sinh function to arg.
206    
207        @param arg: argument
208        """
209        if isinstance(arg,symbols.Symbol):
210           return symbols.Sinh_Symbol(arg)
211        elif hasattr(arg,"sinh"):
212           return arg.sinh()
213        else:
214           return numarray.sinh(arg)
215    
216    def cosh(arg):
217        """
218        Applies the cosh function to arg.
219    
220        @param arg: argument
221        """
222        if isinstance(arg,symbols.Symbol):
223           return symbols.Cosh_Symbol(arg)
224        elif hasattr(arg,"cosh"):
225           return arg.cosh()
226        else:
227           return numarray.cosh(arg)
228    
229    def tanh(arg):
230        """
231        Applies the tanh function to arg.
232    
233        @param arg: argument
234        """
235        if isinstance(arg,symbols.Symbol):
236           return symbols.Tanh_Symbol(arg)
237        elif hasattr(arg,"tanh"):
238           return arg.tanh()
239        else:
240           return numarray.tanh(arg)
241    
242    def asinh(arg):
243        """
244        Applies the asinh function to arg.
245    
246        @param arg: argument
247        """
248        if isinstance(arg,symbols.Symbol):
249           return symbols.Asinh_Symbol(arg)
250        elif hasattr(arg,"asinh"):
251           return arg.asinh()
252        else:
253           return numarray.asinh(arg)
254    
255    def acosh(arg):
256        """
257        Applies the acosh function to arg.
258    
259        @param arg: argument
260        """
261        if isinstance(arg,symbols.Symbol):
262           return symbols.Acosh_Symbol(arg)
263        elif hasattr(arg,"acosh"):
264           return arg.acosh()
265        else:
266           return numarray.acosh(arg)
267    
268    def atanh(arg):
269        """
270        Applies the atanh function to arg.
271    
272        @param arg: argument
273        """
274        if isinstance(arg,symbols.Symbol):
275           return symbols.Atanh_Symbol(arg)
276        elif hasattr(arg,"atanh"):
277           return arg.atanh()
278        else:
279           return numarray.atanh(arg)
280    
281  def sign(arg):  def sign(arg):
282      """      """
# Line 712  def sign(arg): Line 284  def sign(arg):
284    
285      @param arg: argument      @param arg: argument
286      """      """
287      if isinstance(arg,Symbol):      if isinstance(arg,symbols.Symbol):
288         return Sign_Symbol(arg)         return symbols.Sign_Symbol(arg)
289      elif hasattr(arg,"sign"):      elif hasattr(arg,"sign"):
290         return arg.sign()         return arg.sign()
291      else:      else:
292         return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))- \         return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))- \
293                numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))                numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
294    
 class Sign_Symbol(Symbol):  
    """  
    Symbol representing the sign of the argument.  
    """  
    def __init__(self,arg):  
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])  
    def __str__(self):  
       return "sign(%s)"%str(self.getArgument(0))  
    def eval(self,argval):  
        return sign(self.getEvaluatedArguments(argval)[0])  
   
295  def maxval(arg):  def maxval(arg):
296      """      """
297      Returns the maximum value of argument arg.      Returns the maximum value of argument arg.
298    
299      @param arg: argument      @param arg: argument
300      """      """
301      if isinstance(arg,Symbol):      if isinstance(arg,symbols.Symbol):
302         return Max_Symbol(arg)         return symbols.Max_Symbol(arg)
303      elif hasattr(arg,"maxval"):      elif hasattr(arg,"maxval"):
304         return arg.maxval()         return arg.maxval()
305      elif hasattr(arg,"max"):      elif hasattr(arg,"max"):
# Line 746  def maxval(arg): Line 307  def maxval(arg):
307      else:      else:
308         return arg         return arg
309    
 class Max_Symbol(Symbol):  
    """  
    Symbol representing the maximum value of the argument.  
    """  
    def __init__(self,arg):  
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])  
    def __str__(self):  
       return "maxval(%s)"%str(self.getArgument(0))  
    def eval(self,argval):  
        return maxval(self.getEvaluatedArguments(argval)[0])  
   
310  def minval(arg):  def minval(arg):
311      """      """
312      Returns the minimum value of argument arg.      Returns the minimum value of argument arg.
313    
314      @param arg: argument      @param arg: argument
315      """      """
316      if isinstance(arg,Symbol):      if isinstance(arg,symbols.Symbol):
317         return Min_Symbol(arg)         return symbols.Min_Symbol(arg)
318      elif hasattr(arg,"maxval"):      elif hasattr(arg,"maxval"):
319         return arg.minval()         return arg.minval()
320      elif hasattr(arg,"min"):      elif hasattr(arg,"min"):
# Line 772  def minval(arg): Line 322  def minval(arg):
322      else:      else:
323         return arg         return arg
324    
 class Min_Symbol(Symbol):  
    """  
    Symbol representing the minimum value of the argument.  
    """  
    def __init__(self,arg):  
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])  
    def __str__(self):  
       return "minval(%s)"%str(self.getArgument(0))  
    def eval(self,argval):  
        return minval(self.getEvaluatedArguments(argval)[0])  
   
325  def wherePositive(arg):  def wherePositive(arg):
326      """      """
327      Returns the positive values of argument arg.      Returns the positive values of argument arg.
# Line 791  def wherePositive(arg): Line 330  def wherePositive(arg):
330      """      """
331      if _testForZero(arg):      if _testForZero(arg):
332        return 0        return 0
333      elif isinstance(arg,Symbol):      elif isinstance(arg,symbols.Symbol):
334         return WherePositive_Symbol(arg)         return symbols.WherePositive_Symbol(arg)
335      elif hasattr(arg,"wherePositive"):      elif hasattr(arg,"wherePositive"):
336         return arg.minval()         return arg.minval()
337      elif hasattr(arg,"wherePositive"):      elif hasattr(arg,"wherePositive"):
# Line 803  def wherePositive(arg): Line 342  def wherePositive(arg):
342         else:         else:
343            return 0.            return 0.
344    
 class WherePositive_Symbol(Symbol):  
    """  
    Symbol representing the wherePositive function.  
    """  
    def __init__(self,arg):  
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])  
    def __str__(self):  
       return "wherePositive(%s)"%str(self.getArgument(0))  
    def eval(self,argval):  
        return wherePositive(self.getEvaluatedArguments(argval)[0])  
   
345  def whereNegative(arg):  def whereNegative(arg):
346      """      """
347      Returns the negative values of argument arg.      Returns the negative values of argument arg.
# Line 822  def whereNegative(arg): Line 350  def whereNegative(arg):
350      """      """
351      if _testForZero(arg):      if _testForZero(arg):
352        return 0        return 0
353      elif isinstance(arg,Symbol):      elif isinstance(arg,symbols.Symbol):
354         return WhereNegative_Symbol(arg)         return symbols.WhereNegative_Symbol(arg)
355      elif hasattr(arg,"whereNegative"):      elif hasattr(arg,"whereNegative"):
356         return arg.whereNegative()         return arg.whereNegative()
357      elif hasattr(arg,"shape"):      elif hasattr(arg,"shape"):
# Line 834  def whereNegative(arg): Line 362  def whereNegative(arg):
362         else:         else:
363            return 0.            return 0.
364    
 class WhereNegative_Symbol(Symbol):  
    """  
    Symbol representing the whereNegative function.  
    """  
    def __init__(self,arg):  
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])  
    def __str__(self):  
       return "whereNegative(%s)"%str(self.getArgument(0))  
    def eval(self,argval):  
        return whereNegative(self.getEvaluatedArguments(argval)[0])  
   
365  def maximum(arg0,arg1):  def maximum(arg0,arg1):
366      """      """
367      Return arg1 where arg1 is bigger then arg0 otherwise arg0 is returned.      Return arg1 where arg1 is bigger then arg0 otherwise arg0 is returned.
# Line 863  def outer(arg0,arg1): Line 380  def outer(arg0,arg1):
380     if _testForZero(arg0) or _testForZero(arg1):     if _testForZero(arg0) or _testForZero(arg1):
381        return 0        return 0
382     else:     else:
383        if isinstance(arg0,Symbol) or isinstance(arg1,Symbol):        if isinstance(arg0,symbols.Symbol) or isinstance(arg1,symbols.Symbol):
384          return Outer_Symbol(arg0,arg1)          return symbols.Outer_Symbol(arg0,arg1)
385        elif _identifyShape(arg0)==() or _identifyShape(arg1)==():        elif _identifyShape(arg0)==() or _identifyShape(arg1)==():
386          return arg0*arg1          return arg0*arg1
387        elif isinstance(arg0,numarray.NumArray) and isinstance(arg1,numarray.NumArray):        elif isinstance(arg0,numarray.NumArray) and isinstance(arg1,numarray.NumArray):
# Line 879  def outer(arg0,arg1): Line 396  def outer(arg0,arg1):
396          else:          else:
397            raise ValueError,"outer is not fully implemented yet."            raise ValueError,"outer is not fully implemented yet."
398    
 class Outer_Symbol(Symbol):  
    """  
    Symbol representing the outer product of its two arguments.  
    """  
    def __init__(self,arg0,arg1):  
        a=[arg0,arg1]  
        s=tuple(list(_identifyShape(arg0))+list(_identifyShape(arg1)))  
        Symbol.__init__(self,shape=s,dim=_extractDim(a),args=a)  
    def __str__(self):  
       return "outer(%s,%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))  
    def eval(self,argval):  
        a=self.getEvaluatedArguments(argval)  
        return outer(a[0],a[1])  
    def _diff(self,arg):  
        a=self.getDifferentiatedArguments(arg)  
        return outer(a[0],self.getArgument(1))+outer(self.getArgument(0),a[1])  
   
399  def interpolate(arg,where):  def interpolate(arg,where):
400      """      """
401      Interpolates the function into the FunctionSpace where.      Interpolates the function into the FunctionSpace where.
# Line 905  def interpolate(arg,where): Line 405  def interpolate(arg,where):
405      """      """
406      if _testForZero(arg):      if _testForZero(arg):
407        return 0        return 0
408      elif isinstance(arg,Symbol):      elif isinstance(arg,symbols.Symbol):
409         return Interpolated_Symbol(arg,where)         return symbols.Interpolated_Symbol(arg,where)
410      else:      else:
411         return escript.Data(arg,where)         return escript.Data(arg,where)
412    
 class Interpolated_Symbol(Symbol):  
    """  
    Symbol representing the integral of the argument.  
    """  
    def __init__(self,arg,where):  
         Symbol.__init__(self,shape=_extractShape(arg),dim=_extractDim([arg]),args=[arg,where])  
    def __str__(self):  
       return "interpolated(%s)"%(str(self.getArgument(0)))  
    def eval(self,argval):  
        a=self.getEvaluatedArguments(argval)  
        return integrate(a[0],where=self.getArgument(1))  
    def _diff(self,arg):  
        a=self.getDifferentiatedArguments(arg)  
        return integrate(a[0],where=self.getArgument(1))  
   
413  def div(arg,where=None):  def div(arg,where=None):
414      """      """
415      Returns the divergence of arg at where.      Returns the divergence of arg at where.
# Line 958  def grad(arg,where=None): Line 443  def grad(arg,where=None):
443      """      """
444      if _testForZero(arg):      if _testForZero(arg):
445        return 0        return 0
446      elif isinstance(arg,Symbol):      elif isinstance(arg,symbols.Symbol):
447         return Grad_Symbol(arg,where)         return symbols.Grad_Symbol(arg,where)
448      elif hasattr(arg,"grad"):      elif hasattr(arg,"grad"):
449         if where==None:         if where==None:
450            return arg.grad()            return arg.grad()
# Line 968  def grad(arg,where=None): Line 453  def grad(arg,where=None):
453      else:      else:
454         return arg*0.         return arg*0.
455    
 class Grad_Symbol(Symbol):  
    """  
    Symbol representing the gradient of the argument.  
    """  
    def __init__(self,arg,where=None):  
        d=_extractDim([arg])  
        s=tuple(list(_identifyShape([arg])).append(d))  
        Symbol.__init__(self,shape=s,dim=_extractDim([arg]),args=[arg,where])  
    def __str__(self):  
       return "grad(%s)"%(str(self.getArgument(0)))  
    def eval(self,argval):  
        a=self.getEvaluatedArguments(argval)  
        return grad(a[0],where=self.getArgument(1))  
    def _diff(self,arg):  
        a=self.getDifferentiatedArguments(arg)  
        return grad(a[0],where=self.getArgument(1))  
   
456  def integrate(arg,where=None):  def integrate(arg,where=None):
457      """      """
458      Return the integral if the function represented by Data object arg over      Return the integral if the function represented by Data object arg over
# Line 996  def integrate(arg,where=None): Line 464  def integrate(arg,where=None):
464      """      """
465      if _testForZero(arg):      if _testForZero(arg):
466        return 0        return 0
467      elif isinstance(arg,Symbol):      elif isinstance(arg,symbols.Symbol):
468         return Integral_Symbol(arg,where)         return symbols.Integral_Symbol(arg,where)
469      else:          else:    
470         if not where==None: arg=escript.Data(arg,where)         if not where==None: arg=escript.Data(arg,where)
471         if arg.getRank()==0:         if arg.getRank()==0:
# Line 1005  def integrate(arg,where=None): Line 473  def integrate(arg,where=None):
473         else:         else:
474           return arg.integrate()           return arg.integrate()
475    
 class Integral_Symbol(Float_Symbol):  
    """  
    Symbol representing the integral of the argument.  
    """  
    def __init__(self,arg,where=None):  
        Float_Symbol.__init__(self,shape=_identifyShape([arg]),args=[arg,where])  
    def __str__(self):  
       return "integral(%s)"%(str(self.getArgument(0)))  
    def eval(self,argval):  
        a=self.getEvaluatedArguments(argval)  
        return integrate(a[0],where=self.getArgument(1))  
    def _diff(self,arg):  
        a=self.getDifferentiatedArguments(arg)  
        return integrate(a[0],where=self.getArgument(1))  
   
476  #=============================  #=============================
477  #  #
478  # wrapper for various functions: if the argument has attribute the function name  # wrapper for various functions: if the argument has attribute the function name
# Line 1042  def transpose(arg,axis=None): Line 495  def transpose(arg,axis=None):
495         if hasattr(arg,"getRank"): r=arg.getRank()         if hasattr(arg,"getRank"): r=arg.getRank()
496         if hasattr(arg,"rank"): r=arg.rank         if hasattr(arg,"rank"): r=arg.rank
497         axis=r/2         axis=r/2
498      if isinstance(arg,Symbol):      if isinstance(arg,symbols.Symbol):
499         return Transpose_Symbol(arg,axis=r)         return symbols.Transpose_Symbol(arg,axis=r)
500      if isinstance(arg,escript.Data):      if isinstance(arg,escript.Data):
501         # hack for transpose         # hack for transpose
502         r=arg.getRank()         r=arg.getRank()
# Line 1065  def trace(arg,axis0=0,axis1=1): Line 518  def trace(arg,axis0=0,axis1=1):
518    
519      @param arg:      @param arg:
520      """      """
521      if isinstance(arg,Symbol):      if isinstance(arg,symbols.Symbol):
522         s=list(arg.getShape())                 s=list(arg.getShape())        
523         s=tuple(s[0:axis0]+s[axis0+1:axis1]+s[axis1+1:])         s=tuple(s[0:axis0]+s[axis0+1:axis1]+s[axis1+1:])
524         return Trace_Symbol(arg,axis0=axis0,axis1=axis1)         return symbols.Trace_Symbol(arg,axis0=axis0,axis1=axis1)
525      elif isinstance(arg,escript.Data):      elif isinstance(arg,escript.Data):
526         # hack for trace         # hack for trace
527         s=arg.getShape()         s=arg.getShape()
# Line 1082  def trace(arg,axis0=0,axis1=1): Line 535  def trace(arg,axis0=0,axis1=1):
535      else:      else:
536         return numarray.trace(arg,axis0=axis0,axis1=axis1)         return numarray.trace(arg,axis0=axis0,axis1=axis1)
537    
 def Trace_Symbol(Symbol):  
     pass  
   
538  def length(arg):  def length(arg):
539      """      """
540    
# Line 1290  def unit(i,d): Line 740  def unit(i,d):
740     e = numarray.zeros((d,),numarray.Float)     e = numarray.zeros((d,),numarray.Float)
741     e[i] = 1.0     e[i] = 1.0
742     return e     return e
   
 # ============================================  
 #   testing  
 # ============================================  
   
 if __name__=="__main__":  
   u=ScalarSymbol(dim=2,name="u")  
   v=ScalarSymbol(dim=2,name="v")  
   v=VectorSymbol(2,"v")  
   u=VectorSymbol(2,"u")  
   
   print u+5,(u+5).diff(u)  
   print 5+u,(5+u).diff(u)  
   print u+v,(u+v).diff(u)  
   print v+u,(v+u).diff(u)  
   
   print u*5,(u*5).diff(u)  
   print 5*u,(5*u).diff(u)  
   print u*v,(u*v).diff(u)  
   print v*u,(v*u).diff(u)  
   
   print u-5,(u-5).diff(u)  
   print 5-u,(5-u).diff(u)  
   print u-v,(u-v).diff(u)  
   print v-u,(v-u).diff(u)  
   
   print u/5,(u/5).diff(u)  
   print 5/u,(5/u).diff(u)  
   print u/v,(u/v).diff(u)  
   print v/u,(v/u).diff(u)  
   
   print u**5,(u**5).diff(u)  
   print 5**u,(5**u).diff(u)  
   print u**v,(u**v).diff(u)  
   print v**u,(v**u).diff(u)  
   
   print exp(u),exp(u).diff(u)  
   print sqrt(u),sqrt(u).diff(u)  
   print log(u),log(u).diff(u)  
   print sin(u),sin(u).diff(u)  
   print cos(u),cos(u).diff(u)  
   print tan(u),tan(u).diff(u)  
   print sign(u),sign(u).diff(u)  
   print abs(u),abs(u).diff(u)  
   print wherePositive(u),wherePositive(u).diff(u)  
   print whereNegative(u),whereNegative(u).diff(u)  
   print maxval(u),maxval(u).diff(u)  
   print minval(u),minval(u).diff(u)  
   
   g=grad(u)  
   print diff(5*g,g)  
   4*(g+transpose(g))/2+6*trace(g)*kronecker(3)  
   
743  #  #
744  # $Log$  # $Log$
745    # Revision 1.18  2005/09/15 03:44:19  jgs
746    # Merge of development branch dev-02 back to main trunk on 2005-09-15
747    #
748  # Revision 1.17  2005/09/01 03:31:28  jgs  # Revision 1.17  2005/09/01 03:31:28  jgs
749  # Merge of development branch dev-02 back to main trunk on 2005-09-01  # Merge of development branch dev-02 back to main trunk on 2005-09-01
750  #  #
# Line 1354  if __name__=="__main__": Line 754  if __name__=="__main__":
754  # Revision 1.15  2005/08/12 01:45:36  jgs  # Revision 1.15  2005/08/12 01:45:36  jgs
755  # erge of development branch dev-02 back to main trunk on 2005-08-12  # erge of development branch dev-02 back to main trunk on 2005-08-12
756  #  #
757    # Revision 1.14.2.13  2005/09/12 03:32:14  gross
758    # test_visualiztion has been aded to mk
759    #
760    # Revision 1.14.2.12  2005/09/09 01:56:24  jgs
761    # added implementations of acos asin atan sinh cosh tanh asinh acosh atanh
762    # and some associated testing
763    #
764    # Revision 1.14.2.11  2005/09/08 08:28:39  gross
765    # some cleanup in savevtk
766    #
767    # Revision 1.14.2.10  2005/09/08 00:25:32  gross
768    # test for finley mesh generators added
769    #
770    # Revision 1.14.2.9  2005/09/07 10:32:05  gross
771    # Symbols removed from util and put into symmbols.py.
772    #
773  # Revision 1.14.2.8  2005/08/26 05:06:37  cochrane  # Revision 1.14.2.8  2005/08/26 05:06:37  cochrane
774  # Corrected errors in docstrings.  Improved output formatting of docstrings.  # Corrected errors in docstrings.  Improved output formatting of docstrings.
775  # Other minor improvements to code and docs (eg spelling etc).  # Other minor improvements to code and docs (eg spelling etc).

Legend:
Removed from v.149  
changed lines
  Added in v.150

  ViewVC Help
Powered by ViewVC 1.1.26