/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Diff of /trunk/escript/py_src/util.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 341 by gross, Mon Dec 12 05:26:10 2005 UTC revision 443 by gross, Fri Jan 20 06:22:38 2006 UTC
# Line 24  Utility functions for escript Line 24  Utility functions for escript
24  __author__="Lutz Gross, l.gross@uq.edu.au"  __author__="Lutz Gross, l.gross@uq.edu.au"
25  __licence__="contact: esys@access.uq.edu.au"  __licence__="contact: esys@access.uq.edu.au"
26  __url__="http://www.iservo.edu.au/esys/escript"  __url__="http://www.iservo.edu.au/esys/escript"
27  __version__="$Revision: 329 $"  __version__="$Revision$"
28  __date__="$Date$"  __date__="$Date$"
29    
30    
31  import math  import math
32  import numarray  import numarray
33    import numarray.linear_algebra
34  import escript  import escript
35  import os  import os
36    
# Line 43  import os Line 44  import os
44  # def matchType(arg0=0.,arg1=0.):  # def matchType(arg0=0.,arg1=0.):
45  # def matchShape(arg0,arg1):  # def matchShape(arg0,arg1):
46    
 # def maximum(arg0,arg1):  
 # def minimum(arg0,arg1):  
   
47  # def transpose(arg,axis=None):  # def transpose(arg,axis=None):
 # def trace(arg,axis0=0,axis1=1):  
48  # def reorderComponents(arg,index):  # def reorderComponents(arg,index):
49    
 # def integrate(arg,where=None):  
 # def interpolate(arg,where):  
 # def div(arg,where=None):  
 # def grad(arg,where=None):  
   
50  #  #
51  # slicing: get  # slicing: get
52  #          set  #          set
# Line 125  def kronecker(d=3): Line 117  def kronecker(d=3):
117     return the kronecker S{delta}-symbol     return the kronecker S{delta}-symbol
118    
119     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
120     @type d: C{int} or any object with a C{getDim} method     @type d: C{int}, L{escript.Domain} or L{escript.FunctionSpace}
121     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise
122     @rtype d: L{numarray.NumArray} of rank 2.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 2.
    @remark: the function is identical L{identity}  
123     """     """
124     return identityTensor(d)     return identityTensor(d)
125    
# Line 147  def identity(shape=()): Line 138  def identity(shape=()):
138        if len(shape)==1:        if len(shape)==1:
139            for i0 in range(shape[0]):            for i0 in range(shape[0]):
140               out[i0,i0]=1.               out[i0,i0]=1.
   
141        elif len(shape)==2:        elif len(shape)==2:
142            for i0 in range(shape[0]):            for i0 in range(shape[0]):
143               for i1 in range(shape[1]):               for i1 in range(shape[1]):
# Line 163  def identityTensor(d=3): Line 153  def identityTensor(d=3):
153     return the dxd identity matrix     return the dxd identity matrix
154    
155     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
156     @type d: C{int} or any object with a C{getDim} method     @type d: C{int}, L{escript.Domain} or L{escript.FunctionSpace}
157     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise
158     @rtype: L{numarray.NumArray} of rank 2.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 2
159     """     """
160     if hasattr(d,"getDim"):     if isinstance(d,escript.FunctionSpace):
161        d=d.getDim()         return escript.Data(identity((d.getDim(),)),d)
162     return identity(shape=(d,))     elif isinstance(d,escript.Domain):
163           return identity((d.getDim(),))
164       else:
165           return identity((d,))
166    
167  def identityTensor4(d=3):  def identityTensor4(d=3):
168     """     """
# Line 178  def identityTensor4(d=3): Line 171  def identityTensor4(d=3):
171     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
172     @type d: C{int} or any object with a C{getDim} method     @type d: C{int} or any object with a C{getDim} method
173     @return: the object u of rank 4 with M{u[i,j,k,l]=1} for M{i=k and j=l} and M{u[i,j,k,l]=0} otherwise     @return: the object u of rank 4 with M{u[i,j,k,l]=1} for M{i=k and j=l} and M{u[i,j,k,l]=0} otherwise
174     @rtype: L{numarray.NumArray} of rank 4.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 4.
175     """     """
176     if hasattr(d,"getDim"):     if isinstance(d,escript.FunctionSpace):
177        d=d.getDim()         return escript.Data(identity((d.getDim(),d.getDim())),d)
178     return identity((d,d))     elif isinstance(d,escript.Domain):
179           return identity((d.getDim(),d.getDim()))
180       else:
181           return identity((d,d))
182    
183  def unitVector(i=0,d=3):  def unitVector(i=0,d=3):
184     """     """
# Line 191  def unitVector(i=0,d=3): Line 187  def unitVector(i=0,d=3):
187     @param i: index     @param i: index
188     @type i: C{int}     @type i: C{int}
189     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
190     @type d: C{int} or any object with a C{getDim} method     @type d: C{int}, L{escript.Domain} or L{escript.FunctionSpace}
191     @return: the object u of rank 1 with M{u[j]=1} for M{j=i} and M{u[i]=0} otherwise     @return: the object u of rank 1 with M{u[j]=1} for M{j=i} and M{u[i]=0} otherwise
192     @rtype: L{numarray.NumArray} of rank 1.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 1
193     """     """
194     return kronecker(d)[i]     return kronecker(d)[i]
195    
# Line 363  def testForZero(arg): Line 359  def testForZero(arg):
359      @return : True if the argument is identical to zero.      @return : True if the argument is identical to zero.
360      @rtype : C{bool}      @rtype : C{bool}
361      """      """
362      try:      if isinstance(arg,numarray.NumArray):
363         return not Lsup(arg)>0.         return not Lsup(arg)>0.
364      except TypeError:      elif isinstance(arg,escript.Data):
365           return False
366        elif isinstance(arg,float):
367           return not Lsup(arg)>0.
368        elif isinstance(arg,int):
369           return not Lsup(arg)>0.
370        elif isinstance(arg,Symbol):
371           return False
372        else:
373         return False         return False
374    
375  def matchType(arg0=0.,arg1=0.):  def matchType(arg0=0.,arg1=0.):
# Line 907  def wherePositive(arg): Line 911  def wherePositive(arg):
911     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
912     """     """
913     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
914        if arg.rank==0:        out=numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))*1.
915           if arg>0:        if isinstance(out,float): out=numarray.array(out)
916             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))  
917     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
918        return arg._wherePositive()        return arg._wherePositive()
919     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 993  def whereNegative(arg): Line 993  def whereNegative(arg):
993     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
994     """     """
995     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
996        if arg.rank==0:        out=numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))*1.
997           if arg<0:        if isinstance(out,float): out=numarray.array(out)
998             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))  
999     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1000        return arg._whereNegative()        return arg._whereNegative()
1001     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1079  def whereNonNegative(arg): Line 1075  def whereNonNegative(arg):
1075     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1076     """     """
1077     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1078        if arg.rank==0:        out=numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.
1079           if arg<0:        if isinstance(out,float): out=numarray.array(out)
1080             return numarray.array(0.)        return out
          else:  
            return numarray.array(1.)  
       else:  
          return numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float))  
1081     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1082        return arg._whereNonNegative()        return arg._whereNonNegative()
1083     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1113  def whereNonPositive(arg): Line 1105  def whereNonPositive(arg):
1105     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1106     """     """
1107     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1108        if arg.rank==0:        out=numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.
1109           if arg>0:        if isinstance(out,float): out=numarray.array(out)
1110             return numarray.array(0.)        return out
          else:  
            return numarray.array(1.)  
       else:  
          return numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.  
1111     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1112        return arg._whereNonPositive()        return arg._whereNonPositive()
1113     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1149  def whereZero(arg,tol=0.): Line 1137  def whereZero(arg,tol=0.):
1137     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1138     """     """
1139     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1140        if arg.rank==0:        out=numarray.less_equal(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.
1141           if abs(arg)<=tol:        if isinstance(out,float): out=numarray.array(out)
1142             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.less_equal(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.  
1143     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1144        if tol>0.:        if tol>0.:
1145           return whereNegative(abs(arg)-tol)           return whereNegative(abs(arg)-tol)
# Line 1236  def whereNonZero(arg,tol=0.): Line 1220  def whereNonZero(arg,tol=0.):
1220     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1221     """     """
1222     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1223        if arg.rank==0:        out=numarray.greater(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.
1224          if abs(arg)>tol:        if isinstance(out,float): out=numarray.array(out)
1225             return numarray.array(1.)        return out
         else:  
            return numarray.array(0.)  
       else:  
          return numarray.greater(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.  
1226     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1227        if tol>0.:        if tol>0.:
1228           return 1.-whereZero(arg,tol)           return 1.-whereZero(arg,tol)
# Line 2877  def length(arg): Line 2857  def length(arg):
2857     """     """
2858     return sqrt(inner(arg,arg))     return sqrt(inner(arg,arg))
2859    
2860    def trace(arg,axis_offset=0):
2861       """
2862       returns the trace of arg which the sum of arg[k,k] over k.
2863    
2864       @param arg: argument
2865       @type arg: L{escript.Data}, L{Symbol}, L{numarray.NumArray}.
2866       @param axis_offset: axis_offset to components to sum over. C{axis_offset} must be non-negative and less than the rank of arg +1. The dimensions on component
2867                      axis_offset and axis_offset+1 must be equal.
2868       @type axis_offset: C{int}
2869       @return: trace of arg. The rank of the returned object is minus 2 of the rank of arg.
2870       @rtype: L{escript.Data}, L{Symbol}, L{numarray.NumArray} depending on the type of arg.
2871       """
2872       if isinstance(arg,numarray.NumArray):
2873          sh=arg.shape
2874          if len(sh)<2:
2875            raise ValueError,"trace: rank of argument must be greater than 1"
2876          if axis_offset<0 or axis_offset>len(sh)-2:
2877            raise ValueError,"trace: axis_offset must be between 0 and %s"%len(sh)-2
2878          s1=1
2879          for i in range(axis_offset): s1*=sh[i]
2880          s2=1
2881          for i in range(axis_offset+2,len(sh)): s2*=sh[i]
2882          if not sh[axis_offset] == sh[axis_offset+1]:
2883            raise ValueError,"trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2884          arg_reshaped=numarray.reshape(arg,(s1,sh[axis_offset],sh[axis_offset],s2))
2885          out=numarray.zeros([s1,s2],numarray.Float)
2886          for i1 in range(s1):
2887            for i2 in range(s2):
2888                for j in range(sh[axis_offset]): out[i1,i2]+=arg_reshaped[i1,j,j,i2]
2889          out.resize(sh[:axis_offset]+sh[axis_offset+2:])
2890          return out
2891       elif isinstance(arg,escript.Data):
2892          return escript_trace(arg,axis_offset)
2893       elif isinstance(arg,float):
2894          raise TypeError,"trace: illegal argument type float."
2895       elif isinstance(arg,int):
2896          raise TypeError,"trace: illegal argument type int."
2897       elif isinstance(arg,Symbol):
2898          return Trace_Symbol(arg,axis_offset)
2899       else:
2900          raise TypeError,"trace: Unknown argument type."
2901    
2902    def escript_trace(arg,axis_offset): # this should be escript._trace
2903          "arg si a Data objects!!!"
2904          if arg.getRank()<2:
2905            raise ValueError,"escript_trace: rank of argument must be greater than 1"
2906          if axis_offset<0 or axis_offset>arg.getRank()-2:
2907            raise ValueError,"escript_trace: axis_offset must be between 0 and %s"%arg.getRank()-2
2908          s=list(arg.getShape())        
2909          if not s[axis_offset] == s[axis_offset+1]:
2910            raise ValueError,"escript_trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2911          out=escript.Data(0.,tuple(s[0:axis_offset]+s[axis_offset+2:]),arg.getFunctionSpace())
2912          if arg.getRank()==2:
2913             for i0 in range(s[0]):
2914                out+=arg[i0,i0]
2915          elif arg.getRank()==3:
2916             if axis_offset==0:
2917                for i0 in range(s[0]):
2918                      for i2 in range(s[2]):
2919                             out[i2]+=arg[i0,i0,i2]
2920             elif axis_offset==1:
2921                for i0 in range(s[0]):
2922                   for i1 in range(s[1]):
2923                             out[i0]+=arg[i0,i1,i1]
2924          elif arg.getRank()==4:
2925             if axis_offset==0:
2926                for i0 in range(s[0]):
2927                      for i2 in range(s[2]):
2928                         for i3 in range(s[3]):
2929                             out[i2,i3]+=arg[i0,i0,i2,i3]
2930             elif axis_offset==1:
2931                for i0 in range(s[0]):
2932                   for i1 in range(s[1]):
2933                         for i3 in range(s[3]):
2934                             out[i0,i3]+=arg[i0,i1,i1,i3]
2935             elif axis_offset==2:
2936                for i0 in range(s[0]):
2937                   for i1 in range(s[1]):
2938                      for i2 in range(s[2]):
2939                             out[i0,i1]+=arg[i0,i1,i2,i2]
2940          return out
2941    class Trace_Symbol(DependendSymbol):
2942       """
2943       L{Symbol} representing the result of the trace function
2944       """
2945       def __init__(self,arg,axis_offset=0):
2946          """
2947          initialization of trace L{Symbol} with argument arg
2948          @param arg: argument of function
2949          @type arg: L{Symbol}.
2950          @param axis_offset: axis_offset to components to sum over. C{axis_offset} must be non-negative and less than the rank of arg +1. The dimensions on component
2951                      axis_offset and axis_offset+1 must be equal.
2952          @type axis_offset: C{int}
2953          """
2954          if arg.getRank()<2:
2955            raise ValueError,"Trace_Symbol: rank of argument must be greater than 1"
2956          if axis_offset<0 or axis_offset>arg.getRank()-2:
2957            raise ValueError,"Trace_Symbol: axis_offset must be between 0 and %s"%arg.getRank()-2
2958          s=list(arg.getShape())        
2959          if not s[axis_offset] == s[axis_offset+1]:
2960            raise ValueError,"Trace_Symbol: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2961          super(Trace_Symbol,self).__init__(args=[arg,axis_offset],shape=tuple(s[0:axis_offset]+s[axis_offset+2:]),dim=arg.getDim())
2962    
2963       def getMyCode(self,argstrs,format="escript"):
2964          """
2965          returns a program code that can be used to evaluate the symbol.
2966    
2967          @param argstrs: gives for each argument a string representing the argument for the evaluation.
2968          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
2969          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
2970          @type format: C{str}
2971          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
2972          @rtype: C{str}
2973          @raise: NotImplementedError: if the requested format is not available
2974          """
2975          if format=="escript" or format=="str"  or format=="text":
2976             return "trace(%s,axis_offset=%s)"%(argstrs[0],argstrs[1])
2977          else:
2978             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
2979    
2980       def substitute(self,argvals):
2981          """
2982          assigns new values to symbols in the definition of the symbol.
2983          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
2984    
2985          @param argvals: new values assigned to symbols
2986          @type argvals: C{dict} with keywords of type L{Symbol}.
2987          @return: result of the substitution process. Operations are executed as much as possible.
2988          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
2989          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
2990          """
2991          if argvals.has_key(self):
2992             arg=argvals[self]
2993             if self.isAppropriateValue(arg):
2994                return arg
2995             else:
2996                raise TypeError,"%s: new value is not appropriate."%str(self)
2997          else:
2998             arg=self.getSubstitutedArguments(argvals)
2999             return trace(arg[0],axis_offset=arg[1])
3000    
3001       def diff(self,arg):
3002          """
3003          differential of this object
3004    
3005          @param arg: the derivative is calculated with respect to arg
3006          @type arg: L{escript.Symbol}
3007          @return: derivative with respect to C{arg}
3008          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3009          """
3010          if arg==self:
3011             return identity(self.getShape())
3012          else:
3013             return trace(self.getDifferentiatedArguments(arg)[0],axis_offset=self.getArgument()[1])
3014    
3015    def inverse(arg):
3016        """
3017        returns the inverse of the square matrix arg.
3018    
3019        @param arg: square matrix. Must have rank 2 and the first and second dimension must be equal
3020        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}
3021        @return: inverse arg_inv of the argument. It will be matrixmul(inverse(arg),arg) almost equal to kronecker(arg.getShape()[0])
3022        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
3023        """
3024        if isinstance(arg,numarray.NumArray):
3025          return numarray.linear_algebra.inverse(arg)
3026        elif isinstance(arg,escript.Data):
3027          return escript_inverse(arg)
3028        elif isinstance(arg,float):
3029          return 1./arg
3030        elif isinstance(arg,int):
3031          return 1./float(arg)
3032        elif isinstance(arg,Symbol):
3033          return Inverse_Symbol(arg)
3034        else:
3035          raise TypeError,"inverse: Unknown argument type."
3036    
3037    def escript_inverse(arg): # this should be escript._inverse and use LAPACK
3038          "arg is a Data objects!!!"
3039          if not arg.getRank()==2:
3040            raise ValueError,"escript_inverse: argument must have rank 2"
3041          s=arg.getShape()      
3042          if not s[0] == s[1]:
3043            raise ValueError,"escript_inverse: argument must be a square matrix."
3044          out=escript.Data(0.,s,arg.getFunctionSpace())
3045          if s[0]==1:
3046              if inf(abs(arg[0,0]))==0: # in c this should be done point wise as abs(arg[0,0](i))<=0.
3047                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3048              out[0,0]=1./arg[0,0]
3049          elif s[0]==2:
3050              A11=arg[0,0]
3051              A12=arg[0,1]
3052              A21=arg[1,0]
3053              A22=arg[1,1]
3054              D = A11*A22-A12*A21
3055              if inf(abs(D))==0: # in c this should be done point wise as abs(D(i))<=0.
3056                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3057              D=1./D
3058              out[0,0]= A22*D
3059              out[1,0]=-A21*D
3060              out[0,1]=-A12*D
3061              out[1,1]= A11*D
3062          elif s[0]==3:
3063              A11=arg[0,0]
3064              A21=arg[1,0]
3065              A31=arg[2,0]
3066              A12=arg[0,1]
3067              A22=arg[1,1]
3068              A32=arg[2,1]
3069              A13=arg[0,2]
3070              A23=arg[1,2]
3071              A33=arg[2,2]
3072              D  =  A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
3073              if inf(abs(D))==0: # in c this should be done point wise as abs(D(i))<=0.
3074                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3075              D=1./D
3076              out[0,0]=(A22*A33-A23*A32)*D
3077              out[1,0]=(A31*A23-A21*A33)*D
3078              out[2,0]=(A21*A32-A31*A22)*D
3079              out[0,1]=(A13*A32-A12*A33)*D
3080              out[1,1]=(A11*A33-A31*A13)*D
3081              out[2,1]=(A12*A31-A11*A32)*D
3082              out[0,2]=(A12*A23-A13*A22)*D
3083              out[1,2]=(A13*A21-A11*A23)*D
3084              out[2,2]=(A11*A22-A12*A21)*D
3085          else:
3086             raise TypeError,"escript_inverse: only matrix dimensions 1,2,3 are supported right now."
3087          return out
3088    
3089    class Inverse_Symbol(DependendSymbol):
3090       """
3091       L{Symbol} representing the result of the inverse function
3092       """
3093       def __init__(self,arg):
3094          """
3095          initialization of inverse L{Symbol} with argument arg
3096          @param arg: argument of function
3097          @type arg: L{Symbol}.
3098          """
3099          if not arg.getRank()==2:
3100            raise ValueError,"Inverse_Symbol:: argument must have rank 2"
3101          s=arg.getShape()
3102          if not s[0] == s[1]:
3103            raise ValueError,"Inverse_Symbol:: argument must be a square matrix."
3104          super(Inverse_Symbol,self).__init__(args=[arg],shape=s,dim=arg.getDim())
3105    
3106       def getMyCode(self,argstrs,format="escript"):
3107          """
3108          returns a program code that can be used to evaluate the symbol.
3109    
3110          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3111          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3112          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3113          @type format: C{str}
3114          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3115          @rtype: C{str}
3116          @raise: NotImplementedError: if the requested format is not available
3117          """
3118          if format=="escript" or format=="str"  or format=="text":
3119             return "inverse(%s)"%argstrs[0]
3120          else:
3121             raise NotImplementedError,"Inverse_Symbol does not provide program code for format %s."%format
3122    
3123       def substitute(self,argvals):
3124          """
3125          assigns new values to symbols in the definition of the symbol.
3126          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3127    
3128          @param argvals: new values assigned to symbols
3129          @type argvals: C{dict} with keywords of type L{Symbol}.
3130          @return: result of the substitution process. Operations are executed as much as possible.
3131          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3132          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3133          """
3134          if argvals.has_key(self):
3135             arg=argvals[self]
3136             if self.isAppropriateValue(arg):
3137                return arg
3138             else:
3139                raise TypeError,"%s: new value is not appropriate."%str(self)
3140          else:
3141             arg=self.getSubstitutedArguments(argvals)
3142             return inverse(arg[0])
3143    
3144       def diff(self,arg):
3145          """
3146          differential of this object
3147    
3148          @param arg: the derivative is calculated with respect to arg
3149          @type arg: L{escript.Symbol}
3150          @return: derivative with respect to C{arg}
3151          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3152          """
3153          if arg==self:
3154             return identity(self.getShape())
3155          else:
3156             return -matrixmult(matrixmult(self,self.getDifferentiatedArguments(arg)[0]),self)
3157  #=======================================================  #=======================================================
3158  #  Binary operations:  #  Binary operations:
3159  #=======================================================  #=======================================================
# Line 3304  def maximum(*args): Line 3581  def maximum(*args):
3581         if out==None:         if out==None:
3582            out=a            out=a
3583         else:         else:
3584            m=whereNegative(out-a)            diff=add(a,-out)
3585            out=m*a+(1.-m)*out            out=add(out,mult(wherePositive(diff),diff))
3586      return out      return out
3587        
3588  def minimum(*arg):  def minimum(*args):
3589      """      """
3590      the minimum over arguments args      the minimum over arguments args
3591    
# Line 3322  def minimum(*arg): Line 3599  def minimum(*arg):
3599         if out==None:         if out==None:
3600            out=a            out=a
3601         else:         else:
3602            m=whereNegative(out-a)            diff=add(a,-out)
3603            out=m*out+(1.-m)*a            out=add(out,mult(whereNegative(diff),diff))
3604      return out      return out
3605    
3606    def clip(arg,minval=0.,maxval=1.):
3607        """
3608        cuts the values of arg between minval and maxval
3609    
3610        @param arg: argument
3611        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}, C{int} or C{float}
3612        @param minval: lower range
3613        @type arg: C{float}
3614        @param maxval: upper range
3615        @type arg: C{float}
3616        @return: is on object with all its value between minval and maxval. value of the argument that greater then minval and
3617                 less then maxval are unchanged.
3618        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol}, C{int} or C{float} depending on the input
3619        @raise ValueError: if minval>maxval
3620        """
3621        if minval>maxval:
3622           raise ValueError,"minval = %s must be less then maxval %s"%(minval,maxval)
3623        return minimum(maximum(minval,arg),maxval)
3624    
3625        
3626  def inner(arg0,arg1):  def inner(arg0,arg1):
3627      """      """
# Line 3348  def inner(arg0,arg1): Line 3645  def inner(arg0,arg1):
3645      sh1=pokeShape(arg1)      sh1=pokeShape(arg1)
3646      if not sh0==sh1:      if not sh0==sh1:
3647          raise ValueError,"inner: shape of arguments does not match"          raise ValueError,"inner: shape of arguments does not match"
3648      return generalTensorProduct(arg0,arg1,offset=len(sh0))      return generalTensorProduct(arg0,arg1,axis_offset=len(sh0))
3649    
3650  def matrixmult(arg0,arg1):  def matrixmult(arg0,arg1):
3651      """      """
# Line 3376  def matrixmult(arg0,arg1): Line 3673  def matrixmult(arg0,arg1):
3673          raise ValueError,"first argument must have rank 2"          raise ValueError,"first argument must have rank 2"
3674      if not len(sh1)==2 and not len(sh1)==1:      if not len(sh1)==2 and not len(sh1)==1:
3675          raise ValueError,"second argument must have rank 1 or 2"          raise ValueError,"second argument must have rank 1 or 2"
3676      return generalTensorProduct(arg0,arg1,offset=1)      return generalTensorProduct(arg0,arg1,axis_offset=1)
3677    
3678  def outer(arg0,arg1):  def outer(arg0,arg1):
3679      """      """
# Line 3394  def outer(arg0,arg1): Line 3691  def outer(arg0,arg1):
3691      @return: the outer product of arg0 and arg1 at each data point      @return: the outer product of arg0 and arg1 at each data point
3692      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
3693      """      """
3694      return generalTensorProduct(arg0,arg1,offset=0)      return generalTensorProduct(arg0,arg1,axis_offset=0)
3695    
3696    
3697  def tensormult(arg0,arg1):  def tensormult(arg0,arg1):
# Line 3436  def tensormult(arg0,arg1): Line 3733  def tensormult(arg0,arg1):
3733      sh0=pokeShape(arg0)      sh0=pokeShape(arg0)
3734      sh1=pokeShape(arg1)      sh1=pokeShape(arg1)
3735      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):
3736         return generalTensorProduct(arg0,arg1,offset=1)         return generalTensorProduct(arg0,arg1,axis_offset=1)
3737      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):
3738         return generalTensorProduct(arg0,arg1,offset=2)         return generalTensorProduct(arg0,arg1,axis_offset=2)
3739      else:      else:
3740          raise ValueError,"tensormult: first argument must have rank 2 or 4"          raise ValueError,"tensormult: first argument must have rank 2 or 4"
3741    
3742  def generalTensorProduct(arg0,arg1,offset=0):  def generalTensorProduct(arg0,arg1,axis_offset=0):
3743      """      """
3744      generalized tensor product      generalized tensor product
3745    
3746      out[s,t]=S{Sigma}_r arg0[s,r]*arg1[r,t]      out[s,t]=S{Sigma}_r arg0[s,r]*arg1[r,t]
3747    
3748      where s runs through arg0.Shape[:arg0.Rank-offset]      where s runs through arg0.Shape[:arg0.Rank-axis_offset]
3749            r runs trough arg0.Shape[:offset]            r runs trough arg0.Shape[:axis_offset]
3750            t runs through arg1.Shape[offset:]            t runs through arg1.Shape[axis_offset:]
3751    
3752      In the first case the the second dimension of arg0 and the length of arg1 must match and        In the first case the the second dimension of arg0 and the length of arg1 must match and  
3753      in the second case the two last dimensions of arg0 must match the shape of arg1.      in the second case the two last dimensions of arg0 must match the shape of arg1.
# Line 3467  def generalTensorProduct(arg0,arg1,offse Line 3764  def generalTensorProduct(arg0,arg1,offse
3764      # at this stage arg0 and arg0 are both numarray.NumArray or escript.Data or Symbols      # at this stage arg0 and arg0 are both numarray.NumArray or escript.Data or Symbols
3765      if isinstance(arg0,numarray.NumArray):      if isinstance(arg0,numarray.NumArray):
3766         if isinstance(arg1,Symbol):         if isinstance(arg1,Symbol):
3767             return GeneralTensorProduct_Symbol(arg0,arg1,offset)             return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
3768         else:         else:
3769             if not arg0.shape[arg0.rank-offset:]==arg1.shape[:offset]:             if not arg0.shape[arg0.rank-axis_offset:]==arg1.shape[:axis_offset]:
3770                 raise ValueError,"generalTensorProduct: dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)                 raise ValueError,"generalTensorProduct: dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
3771             arg0_c=arg0.copy()             arg0_c=arg0.copy()
3772             arg1_c=arg1.copy()             arg1_c=arg1.copy()
3773             sh0,sh1=arg0.shape,arg1.shape             sh0,sh1=arg0.shape,arg1.shape
3774             d0,d1,d01=1,1,1             d0,d1,d01=1,1,1
3775             for i in sh0[:arg0.rank-offset]: d0*=i             for i in sh0[:arg0.rank-axis_offset]: d0*=i
3776             for i in sh1[offset:]: d1*=i             for i in sh1[axis_offset:]: d1*=i
3777             for i in sh1[:offset]: d01*=i             for i in sh1[:axis_offset]: d01*=i
3778             arg0_c.resize((d0,d01))             arg0_c.resize((d0,d01))
3779             arg1_c.resize((d01,d1))             arg1_c.resize((d01,d1))
3780             out=numarray.zeros((d0,d1),numarray.Float)             out=numarray.zeros((d0,d1),numarray.Float)
3781             for i0 in range(d0):             for i0 in range(d0):
3782                      for i1 in range(d1):                      for i1 in range(d1):
3783                           out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1])                           out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1])
3784             out.resize(sh0[:arg0.rank-offset]+sh1[offset:])             out.resize(sh0[:arg0.rank-axis_offset]+sh1[axis_offset:])
3785             return out             return out
3786      elif isinstance(arg0,escript.Data):      elif isinstance(arg0,escript.Data):
3787         if isinstance(arg1,Symbol):         if isinstance(arg1,Symbol):
3788             return GeneralTensorProduct_Symbol(arg0,arg1,offset)             return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
3789         else:         else:
3790             return escript_generalTensorProduct(arg0,arg1,offset) # this calls has to be replaced by escript._generalTensorProduct(arg0,arg1,offset)             return escript_generalTensorProduct(arg0,arg1,axis_offset) # this calls has to be replaced by escript._generalTensorProduct(arg0,arg1,axis_offset)
3791      else:            else:      
3792         return GeneralTensorProduct_Symbol(arg0,arg1,offset)         return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
3793                                    
3794  class GeneralTensorProduct_Symbol(DependendSymbol):  class GeneralTensorProduct_Symbol(DependendSymbol):
3795     """     """
3796     Symbol representing the quotient of two arguments.     Symbol representing the quotient of two arguments.
3797     """     """
3798     def __init__(self,arg0,arg1,offset=0):     def __init__(self,arg0,arg1,axis_offset=0):
3799         """         """
3800         initialization of L{Symbol} representing the quotient of two arguments         initialization of L{Symbol} representing the quotient of two arguments
3801    
# Line 3511  class GeneralTensorProduct_Symbol(Depend Line 3808  class GeneralTensorProduct_Symbol(Depend
3808         """         """
3809         sh_arg0=pokeShape(arg0)         sh_arg0=pokeShape(arg0)
3810         sh_arg1=pokeShape(arg1)         sh_arg1=pokeShape(arg1)
3811         sh0=sh_arg0[:len(sh_arg0)-offset]         sh0=sh_arg0[:len(sh_arg0)-axis_offset]
3812         sh01=sh_arg0[len(sh_arg0)-offset:]         sh01=sh_arg0[len(sh_arg0)-axis_offset:]
3813         sh10=sh_arg1[:offset]         sh10=sh_arg1[:axis_offset]
3814         sh1=sh_arg1[offset:]         sh1=sh_arg1[axis_offset:]
3815         if not sh01==sh10:         if not sh01==sh10:
3816             raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)             raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
3817         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0+sh1,args=[arg0,arg1,offset])         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0+sh1,args=[arg0,arg1,axis_offset])
3818    
3819     def getMyCode(self,argstrs,format="escript"):     def getMyCode(self,argstrs,format="escript"):
3820        """        """
# Line 3532  class GeneralTensorProduct_Symbol(Depend Line 3829  class GeneralTensorProduct_Symbol(Depend
3829        @raise: NotImplementedError: if the requested format is not available        @raise: NotImplementedError: if the requested format is not available
3830        """        """
3831        if format=="escript" or format=="str" or format=="text":        if format=="escript" or format=="str" or format=="text":
3832           return "generalTensorProduct(%s,%s,offset=%s)"%(argstrs[0],argstrs[1],argstrs[2])           return "generalTensorProduct(%s,%s,axis_offset=%s)"%(argstrs[0],argstrs[1],argstrs[2])
3833        else:        else:
3834           raise NotImplementedError,"%s does not provide program code for format %s."%(str(self),format)           raise NotImplementedError,"%s does not provide program code for format %s."%(str(self),format)
3835    
# Line 3557  class GeneralTensorProduct_Symbol(Depend Line 3854  class GeneralTensorProduct_Symbol(Depend
3854           args=self.getSubstitutedArguments(argvals)           args=self.getSubstitutedArguments(argvals)
3855           return generalTensorProduct(args[0],args[1],args[2])           return generalTensorProduct(args[0],args[1],args[2])
3856    
3857  def escript_generalTensorProduct(arg0,arg1,offset): # this should be escript._generalTensorProduct  def escript_generalTensorProduct(arg0,arg1,axis_offset): # this should be escript._generalTensorProduct
3858      "arg0 and arg1 are both Data objects but not neccesrily on the same function space. they could be identical!!!"      "arg0 and arg1 are both Data objects but not neccesrily on the same function space. they could be identical!!!"
3859      # calculate the return shape:      # calculate the return shape:
3860      shape0=arg0.getShape()[:arg0.getRank()-offset]      shape0=arg0.getShape()[:arg0.getRank()-axis_offset]
3861      shape01=arg0.getShape()[arg0.getRank()-offset:]      shape01=arg0.getShape()[arg0.getRank()-axis_offset:]
3862      shape10=arg1.getShape()[:offset]      shape10=arg1.getShape()[:axis_offset]
3863      shape1=arg1.getShape()[offset:]      shape1=arg1.getShape()[axis_offset:]
3864      if not shape01==shape10:      if not shape01==shape10:
3865          raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)          raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
3866    
3867        # whatr function space should be used? (this here is not good!)
3868        fs=(escript.Scalar(0.,arg0.getFunctionSpace())+escript.Scalar(0.,arg1.getFunctionSpace())).getFunctionSpace()
3869      # create return value:      # create return value:
3870      out=escript.Data(0.,tuple(shape0+shape1),arg0.getFunctionSpace())      out=escript.Data(0.,tuple(shape0+shape1),fs)
3871      #      #
3872      s0=[[]]      s0=[[]]
3873      for k in shape0:      for k in shape0:
# Line 3591  def escript_generalTensorProduct(arg0,ar Line 3890  def escript_generalTensorProduct(arg0,ar
3890    
3891      for i0 in s0:      for i0 in s0:
3892         for i1 in s1:         for i1 in s1:
3893           s=escript.Scalar(0.,arg0.getFunctionSpace())           s=escript.Scalar(0.,fs)
3894           for i01 in s01:           for i01 in s01:
3895              s+=arg0.__getitem__(tuple(i0+i01))*arg1.__getitem__(tuple(i01+i1))              s+=arg0.__getitem__(tuple(i0+i01))*arg1.__getitem__(tuple(i01+i1))
3896           out.__setitem__(tuple(i0+i1),s)           out.__setitem__(tuple(i0+i1),s)
3897      return out      return out
3898    
3899    
3900  #=========================================================  #=========================================================
3901  #   some little helpers  #  functions dealing with spatial dependency
3902  #=========================================================  #=========================================================
3903  def grad(arg,where=None):  def grad(arg,where=None):
3904      """      """
3905      Returns the spatial gradient of arg at where.      Returns the spatial gradient of arg at where.
3906    
3907        If C{g} is the returned object, then
3908    
3909          - if C{arg} is rank 0 C{g[s]} is the derivative of C{arg} with respect to the C{s}-th spatial dimension.
3910          - if C{arg} is rank 1 C{g[i,s]} is the derivative of C{arg[i]} with respect to the C{s}-th spatial dimension.
3911          - if C{arg} is rank 2 C{g[i,j,s]} is the derivative of C{arg[i,j]} with respect to the C{s}-th spatial dimension.
3912          - if C{arg} is rank 3 C{g[i,j,k,s]} is the derivative of C{arg[i,j,k]} with respect to the C{s}-th spatial dimension.
3913    
3914      @param arg:   Data object representing the function which gradient      @param arg: function which gradient to be calculated. Its rank has to be less than 3.
3915                    to be calculated.      @type arg: L{escript.Data} or L{Symbol}
3916      @param where: FunctionSpace in which the gradient will be calculated.      @param where: FunctionSpace in which the gradient will be calculated.
3917                    If not present or C{None} an appropriate default is used.                    If not present or C{None} an appropriate default is used.
3918        @type where: C{None} or L{escript.FunctionSpace}
3919        @return: gradient of arg.
3920        @rtype:  L{escript.Data} or L{Symbol}
3921      """      """
3922      if isinstance(arg,Symbol):      if isinstance(arg,Symbol):
3923         return Grad_Symbol(arg,where)         return Grad_Symbol(arg,where)
# Line 3617  def grad(arg,where=None): Line 3927  def grad(arg,where=None):
3927         else:         else:
3928            return arg._grad(where)            return arg._grad(where)
3929      else:      else:
3930        raise TypeError,"grad: Unknown argument type."         raise TypeError,"grad: Unknown argument type."
3931    
3932    class Grad_Symbol(DependendSymbol):
3933       """
3934       L{Symbol} representing the result of the gradient operator
3935       """
3936       def __init__(self,arg,where=None):
3937          """
3938          initialization of gradient L{Symbol} with argument arg
3939          @param arg: argument of function
3940          @type arg: L{Symbol}.
3941          @param where: FunctionSpace in which the gradient will be calculated.
3942                      If not present or C{None} an appropriate default is used.
3943          @type where: C{None} or L{escript.FunctionSpace}
3944          """
3945          d=arg.getDim()
3946          if d==None:
3947             raise ValueError,"argument must have a spatial dimension"
3948          super(Grad_Symbol,self).__init__(args=[arg,where],shape=arg.getShape()+(d,),dim=d)
3949    
3950       def getMyCode(self,argstrs,format="escript"):
3951          """
3952          returns a program code that can be used to evaluate the symbol.
3953    
3954          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3955          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3956          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3957          @type format: C{str}
3958          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3959          @rtype: C{str}
3960          @raise: NotImplementedError: if the requested format is not available
3961          """
3962          if format=="escript" or format=="str"  or format=="text":
3963             return "grad(%s,where=%s)"%(argstrs[0],argstrs[1])
3964          else:
3965             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
3966    
3967       def substitute(self,argvals):
3968          """
3969          assigns new values to symbols in the definition of the symbol.
3970          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3971    
3972          @param argvals: new values assigned to symbols
3973          @type argvals: C{dict} with keywords of type L{Symbol}.
3974          @return: result of the substitution process. Operations are executed as much as possible.
3975          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3976          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3977          """
3978          if argvals.has_key(self):
3979             arg=argvals[self]
3980             if self.isAppropriateValue(arg):
3981                return arg
3982             else:
3983                raise TypeError,"%s: new value is not appropriate."%str(self)
3984          else:
3985             arg=self.getSubstitutedArguments(argvals)
3986             return grad(arg[0],where=arg[1])
3987    
3988       def diff(self,arg):
3989          """
3990          differential of this object
3991    
3992          @param arg: the derivative is calculated with respect to arg
3993          @type arg: L{escript.Symbol}
3994          @return: derivative with respect to C{arg}
3995          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3996          """
3997          if arg==self:
3998             return identity(self.getShape())
3999          else:
4000             return grad(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
4001    
4002  def integrate(arg,where=None):  def integrate(arg,where=None):
4003      """      """
4004      Return the integral if the function represented by Data object arg over      Return the integral of the function C{arg} over its domain. If C{where} is present C{arg} is interpolated to C{where}
4005      its domain.      before integration.
4006    
4007      @param arg:   Data object representing the function which is integrated.      @param arg:   the function which is integrated.
4008        @type arg: L{escript.Data} or L{Symbol}
4009      @param where: FunctionSpace in which the integral is calculated.      @param where: FunctionSpace in which the integral is calculated.
4010                    If not present or C{None} an appropriate default is used.                    If not present or C{None} an appropriate default is used.
4011        @type where: C{None} or L{escript.FunctionSpace}
4012        @return: integral of arg.
4013        @rtype:  C{float}, C{numarray.NumArray} or L{Symbol}
4014      """      """
4015      if isinstance(arg,numarray.NumArray):      if isinstance(arg,Symbol):
         if checkForZero(arg):  
            return arg  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,float):  
         if checkForZero(arg):  
            return arg  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,int):  
         if checkForZero(arg):  
            return float(arg)  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,Symbol):  
4016         return Integrate_Symbol(arg,where)         return Integrate_Symbol(arg,where)
4017      elif isinstance(arg,escript.Data):      elif isinstance(arg,escript.Data):
4018         if not where==None: arg=escript.Data(arg,where)         if not where==None: arg=escript.Data(arg,where)
# Line 3654  def integrate(arg,where=None): Line 4023  def integrate(arg,where=None):
4023      else:      else:
4024        raise TypeError,"integrate: Unknown argument type."        raise TypeError,"integrate: Unknown argument type."
4025    
4026    class Integrate_Symbol(DependendSymbol):
4027       """
4028       L{Symbol} representing the result of the spatial integration operator
4029       """
4030       def __init__(self,arg,where=None):
4031          """
4032          initialization of integration L{Symbol} with argument arg
4033          @param arg: argument of the integration
4034          @type arg: L{Symbol}.
4035          @param where: FunctionSpace in which the integration will be calculated.
4036                      If not present or C{None} an appropriate default is used.
4037          @type where: C{None} or L{escript.FunctionSpace}
4038          """
4039          super(Integrate_Symbol,self).__init__(args=[arg,where],shape=arg.getShape(),dim=arg.getDim())
4040    
4041       def getMyCode(self,argstrs,format="escript"):
4042          """
4043          returns a program code that can be used to evaluate the symbol.
4044    
4045          @param argstrs: gives for each argument a string representing the argument for the evaluation.
4046          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
4047          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
4048          @type format: C{str}
4049          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
4050          @rtype: C{str}
4051          @raise: NotImplementedError: if the requested format is not available
4052          """
4053          if format=="escript" or format=="str"  or format=="text":
4054             return "integrate(%s,where=%s)"%(argstrs[0],argstrs[1])
4055          else:
4056             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
4057    
4058       def substitute(self,argvals):
4059          """
4060          assigns new values to symbols in the definition of the symbol.
4061          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
4062    
4063          @param argvals: new values assigned to symbols
4064          @type argvals: C{dict} with keywords of type L{Symbol}.
4065          @return: result of the substitution process. Operations are executed as much as possible.
4066          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
4067          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
4068          """
4069          if argvals.has_key(self):
4070             arg=argvals[self]
4071             if self.isAppropriateValue(arg):
4072                return arg
4073             else:
4074                raise TypeError,"%s: new value is not appropriate."%str(self)
4075          else:
4076             arg=self.getSubstitutedArguments(argvals)
4077             return integrate(arg[0],where=arg[1])
4078    
4079       def diff(self,arg):
4080          """
4081          differential of this object
4082    
4083          @param arg: the derivative is calculated with respect to arg
4084          @type arg: L{escript.Symbol}
4085          @return: derivative with respect to C{arg}
4086          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
4087          """
4088          if arg==self:
4089             return identity(self.getShape())
4090          else:
4091             return integrate(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
4092    
4093    
4094  def interpolate(arg,where):  def interpolate(arg,where):
4095      """      """
4096      Interpolates the function into the FunctionSpace where.      interpolates the function into the FunctionSpace where.
4097    
4098      @param arg:    interpolant      @param arg: interpolant
4099      @param where:  FunctionSpace to interpolate to      @type arg: L{escript.Data} or L{Symbol}
4100        @param where: FunctionSpace to be interpolated to
4101        @type where: L{escript.FunctionSpace}
4102        @return: interpolated argument
4103        @rtype:  C{escript.Data} or L{Symbol}
4104      """      """
4105      if testForZero(arg):      if isinstance(arg,Symbol):
4106        return 0         return Interpolate_Symbol(arg,where)
     elif isinstance(arg,Symbol):  
        return Interpolated_Symbol(arg,where)  
4107      else:      else:
4108         return escript.Data(arg,where)         return escript.Data(arg,where)
4109    
4110    class Interpolate_Symbol(DependendSymbol):
4111       """
4112       L{Symbol} representing the result of the interpolation operator
4113       """
4114       def __init__(self,arg,where):
4115          """
4116          initialization of interpolation L{Symbol} with argument arg
4117          @param arg: argument of the interpolation
4118          @type arg: L{Symbol}.
4119          @param where: FunctionSpace into which the argument is interpolated.
4120          @type where: L{escript.FunctionSpace}
4121          """
4122          super(Interpolate_Symbol,self).__init__(args=[arg,where],shape=arg.getShape(),dim=arg.getDim())
4123    
4124       def getMyCode(self,argstrs,format="escript"):
4125          """
4126          returns a program code that can be used to evaluate the symbol.
4127    
4128          @param argstrs: gives for each argument a string representing the argument for the evaluation.
4129          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
4130          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
4131          @type format: C{str}
4132          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
4133          @rtype: C{str}
4134          @raise: NotImplementedError: if the requested format is not available
4135          """
4136          if format=="escript" or format=="str"  or format=="text":
4137             return "interpolate(%s,where=%s)"%(argstrs[0],argstrs[1])
4138          else:
4139             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
4140    
4141       def substitute(self,argvals):
4142          """
4143          assigns new values to symbols in the definition of the symbol.
4144          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
4145    
4146          @param argvals: new values assigned to symbols
4147          @type argvals: C{dict} with keywords of type L{Symbol}.
4148          @return: result of the substitution process. Operations are executed as much as possible.
4149          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
4150          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
4151          """
4152          if argvals.has_key(self):
4153             arg=argvals[self]
4154             if self.isAppropriateValue(arg):
4155                return arg
4156             else:
4157                raise TypeError,"%s: new value is not appropriate."%str(self)
4158          else:
4159             arg=self.getSubstitutedArguments(argvals)
4160             return interpolate(arg[0],where=arg[1])
4161    
4162       def diff(self,arg):
4163          """
4164          differential of this object
4165    
4166          @param arg: the derivative is calculated with respect to arg
4167          @type arg: L{escript.Symbol}
4168          @return: derivative with respect to C{arg}
4169          @rtype: L{Symbol} but other types such as L{escript.Data}, L{numarray.NumArray}  are possible.
4170          """
4171          if arg==self:
4172             return identity(self.getShape())
4173          else:
4174             return interpolate(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
4175    
4176    
4177  def div(arg,where=None):  def div(arg,where=None):
4178      """      """
4179      Returns the divergence of arg at where.      returns the divergence of arg at where.
4180    
4181      @param arg:   Data object representing the function which gradient to      @param arg: function which divergence to be calculated. Its shape has to be (d,) where d is the spatial dimension.
4182                    be calculated.      @type arg: L{escript.Data} or L{Symbol}
4183      @param where: FunctionSpace in which the gradient will be calculated.      @param where: FunctionSpace in which the divergence will be calculated.
4184                    If not present or C{None} an appropriate default is used.                    If not present or C{None} an appropriate default is used.
4185        @type where: C{None} or L{escript.FunctionSpace}
4186        @return: divergence of arg.
4187        @rtype:  L{escript.Data} or L{Symbol}
4188      """      """
4189      g=grad(arg,where)      if isinstance(arg,Symbol):
4190      return trace(g,axis0=g.getRank()-2,axis1=g.getRank()-1)          dim=arg.getDim()
4191        elif isinstance(arg,escript.Data):
4192            dim=arg.getDomain().getDim()
4193        else:
4194            raise TypeError,"div: argument type not supported"
4195        if not arg.getShape()==(dim,):
4196          raise ValueError,"div: expected shape is (%s,)"%dim
4197        return trace(grad(arg,where))
4198    
4199  def jump(arg):  def jump(arg,domain=None):
4200      """      """
4201      Returns the jump of arg across a continuity.      returns the jump of arg across the continuity of the domain
4202    
4203      @param arg:   Data object representing the function which gradient      @param arg: argument
4204                    to be calculated.      @type arg: L{escript.Data} or L{Symbol}
4205        @param domain: the domain where the discontinuity is located. If domain is not present or equal to C{None}
4206                       the domain of arg is used. If arg is a L{Symbol} the domain must be present.
4207        @type domain: C{None} or L{escript.Domain}
4208        @return: jump of arg
4209        @rtype:  L{escript.Data} or L{Symbol}
4210      """      """
4211      d=arg.getDomain()      if domain==None: domain=arg.getDomain()
4212      return arg.interpolate(escript.FunctionOnContactOne())-arg.interpolate(escript.FunctionOnContactZero())      return interpolate(arg,escript.FunctionOnContactOne(domain))-interpolate(arg,escript.FunctionOnContactZero(domain))
4213    
4214    def L2(arg):
4215        """
4216        returns the L2 norm of arg at where
4217        
4218        @param arg: function which L2 to be calculated.
4219        @type arg: L{escript.Data} or L{Symbol}
4220        @return: L2 norm of arg.
4221        @rtype:  L{float} or L{Symbol}
4222        @note: L2(arg) is equivalent to sqrt(integrate(inner(arg,arg)))
4223        """
4224        return sqrt(integrate(inner(arg,arg)))
4225  #=============================  #=============================
4226  #  #
4227  # wrapper for various functions: if the argument has attribute the function name  # wrapper for various functions: if the argument has attribute the function name
# Line 3726  def transpose(arg,axis=None): Line 4258  def transpose(arg,axis=None):
4258      else:      else:
4259         return numarray.transpose(arg,axis=axis)         return numarray.transpose(arg,axis=axis)
4260    
 def trace(arg,axis0=0,axis1=1):  
     """  
     Return  
   
     @param arg:  
     """  
     if isinstance(arg,Symbol):  
        s=list(arg.getShape())          
        s=tuple(s[0:axis0]+s[axis0+1:axis1]+s[axis1+1:])  
        return Trace_Symbol(arg,axis0=axis0,axis1=axis1)  
     elif isinstance(arg,escript.Data):  
        # hack for trace  
        s=arg.getShape()  
        if s[axis0]!=s[axis1]:  
            raise ValueError,"illegal axis in trace"  
        out=escript.Scalar(0.,arg.getFunctionSpace())  
        for i in range(s[axis0]):  
           out+=arg[i,i]  
        return out  
        # end hack for trace  
     else:  
        return numarray.trace(arg,axis0=axis0,axis1=axis1)  
4261    
4262    
4263  def reorderComponents(arg,index):  def reorderComponents(arg,index):

Legend:
Removed from v.341  
changed lines
  Added in v.443

  ViewVC Help
Powered by ViewVC 1.1.26