/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Diff of /trunk/escript/py_src/util.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 341 by gross, Mon Dec 12 05:26:10 2005 UTC revision 433 by gross, Tue Jan 17 23:54:38 2006 UTC
# Line 24  Utility functions for escript Line 24  Utility functions for escript
24  __author__="Lutz Gross, l.gross@uq.edu.au"  __author__="Lutz Gross, l.gross@uq.edu.au"
25  __licence__="contact: esys@access.uq.edu.au"  __licence__="contact: esys@access.uq.edu.au"
26  __url__="http://www.iservo.edu.au/esys/escript"  __url__="http://www.iservo.edu.au/esys/escript"
27  __version__="$Revision: 329 $"  __version__="$Revision$"
28  __date__="$Date$"  __date__="$Date$"
29    
30    
31  import math  import math
32  import numarray  import numarray
33    import numarray.linear_algebra
34  import escript  import escript
35  import os  import os
36    
# Line 43  import os Line 44  import os
44  # def matchType(arg0=0.,arg1=0.):  # def matchType(arg0=0.,arg1=0.):
45  # def matchShape(arg0,arg1):  # def matchShape(arg0,arg1):
46    
 # def maximum(arg0,arg1):  
 # def minimum(arg0,arg1):  
   
47  # def transpose(arg,axis=None):  # def transpose(arg,axis=None):
48  # def trace(arg,axis0=0,axis1=1):  # def trace(arg,axis0=0,axis1=1):
49  # def reorderComponents(arg,index):  # def reorderComponents(arg,index):
# Line 363  def testForZero(arg): Line 361  def testForZero(arg):
361      @return : True if the argument is identical to zero.      @return : True if the argument is identical to zero.
362      @rtype : C{bool}      @rtype : C{bool}
363      """      """
364      try:      if isinstance(arg,numarray.NumArray):
365           return not Lsup(arg)>0.
366        elif isinstance(arg,escript.Data):
367           return False
368        elif isinstance(arg,float):
369         return not Lsup(arg)>0.         return not Lsup(arg)>0.
370      except TypeError:      elif isinstance(arg,int):
371           return not Lsup(arg)>0.
372        elif isinstance(arg,Symbol):
373           return False
374        else:
375         return False         return False
376    
377  def matchType(arg0=0.,arg1=0.):  def matchType(arg0=0.,arg1=0.):
# Line 907  def wherePositive(arg): Line 913  def wherePositive(arg):
913     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
914     """     """
915     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
916        if arg.rank==0:        out=numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))*1.
917           if arg>0:        if isinstance(out,float): out=numarray.array(out)
918             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))  
919     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
920        return arg._wherePositive()        return arg._wherePositive()
921     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 993  def whereNegative(arg): Line 995  def whereNegative(arg):
995     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
996     """     """
997     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
998        if arg.rank==0:        out=numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))*1.
999           if arg<0:        if isinstance(out,float): out=numarray.array(out)
1000             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))  
1001     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1002        return arg._whereNegative()        return arg._whereNegative()
1003     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1079  def whereNonNegative(arg): Line 1077  def whereNonNegative(arg):
1077     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1078     """     """
1079     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1080        if arg.rank==0:        out=numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.
1081           if arg<0:        if isinstance(out,float): out=numarray.array(out)
1082             return numarray.array(0.)        return out
          else:  
            return numarray.array(1.)  
       else:  
          return numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float))  
1083     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1084        return arg._whereNonNegative()        return arg._whereNonNegative()
1085     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1113  def whereNonPositive(arg): Line 1107  def whereNonPositive(arg):
1107     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1108     """     """
1109     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1110        if arg.rank==0:        out=numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.
1111           if arg>0:        if isinstance(out,float): out=numarray.array(out)
1112             return numarray.array(0.)        return out
          else:  
            return numarray.array(1.)  
       else:  
          return numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.  
1113     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1114        return arg._whereNonPositive()        return arg._whereNonPositive()
1115     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1149  def whereZero(arg,tol=0.): Line 1139  def whereZero(arg,tol=0.):
1139     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1140     """     """
1141     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1142        if arg.rank==0:        out=numarray.less_equal(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.
1143           if abs(arg)<=tol:        if isinstance(out,float): out=numarray.array(out)
1144             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.less_equal(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.  
1145     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1146        if tol>0.:        if tol>0.:
1147           return whereNegative(abs(arg)-tol)           return whereNegative(abs(arg)-tol)
# Line 1236  def whereNonZero(arg,tol=0.): Line 1222  def whereNonZero(arg,tol=0.):
1222     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1223     """     """
1224     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1225        if arg.rank==0:        out=numarray.greater(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.
1226          if abs(arg)>tol:        if isinstance(out,float): out=numarray.array(out)
1227             return numarray.array(1.)        return out
         else:  
            return numarray.array(0.)  
       else:  
          return numarray.greater(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.  
1228     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1229        if tol>0.:        if tol>0.:
1230           return 1.-whereZero(arg,tol)           return 1.-whereZero(arg,tol)
# Line 2877  def length(arg): Line 2859  def length(arg):
2859     """     """
2860     return sqrt(inner(arg,arg))     return sqrt(inner(arg,arg))
2861    
2862    def trace(arg,axis_offset=0):
2863       """
2864       returns the trace of arg which the sum of arg[k,k] over k.
2865    
2866       @param arg: argument
2867       @type arg: L{escript.Data}, L{Symbol}, L{numarray.NumArray}.
2868       @param axis_offset: axis_offset to components to sum over. C{axis_offset} must be non-negative and less than the rank of arg +1. The dimensions on component
2869                      axis_offset and axis_offset+1 must be equal.
2870       @type axis_offset: C{int}
2871       @return: trace of arg. The rank of the returned object is minus 2 of the rank of arg.
2872       @rtype: L{escript.Data}, L{Symbol}, L{numarray.NumArray} depending on the type of arg.
2873       """
2874       if isinstance(arg,numarray.NumArray):
2875          sh=arg.shape
2876          if len(sh)<2:
2877            raise ValueError,"trace: rank of argument must be greater than 1"
2878          if axis_offset<0 or axis_offset>len(sh)-2:
2879            raise ValueError,"trace: axis_offset must be between 0 and %s"%len(sh)-2
2880          s1=1
2881          for i in range(axis_offset): s1*=sh[i]
2882          s2=1
2883          for i in range(axis_offset+2,len(sh)): s2*=sh[i]
2884          if not sh[axis_offset] == sh[axis_offset+1]:
2885            raise ValueError,"trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2886          arg_reshaped=numarray.reshape(arg,(s1,sh[axis_offset],sh[axis_offset],s2))
2887          out=numarray.zeros([s1,s2],numarray.Float)
2888          for i1 in range(s1):
2889            for i2 in range(s2):
2890                for j in range(sh[axis_offset]): out[i1,i2]+=arg_reshaped[i1,j,j,i2]
2891          out.resize(sh[:axis_offset]+sh[axis_offset+2:])
2892          return out
2893       elif isinstance(arg,escript.Data):
2894          return escript_trace(arg,axis_offset)
2895       elif isinstance(arg,float):
2896          raise TypeError,"trace: illegal argument type float."
2897       elif isinstance(arg,int):
2898          raise TypeError,"trace: illegal argument type int."
2899       elif isinstance(arg,Symbol):
2900          return Trace_Symbol(arg,axis_offset)
2901       else:
2902          raise TypeError,"trace: Unknown argument type."
2903    
2904    def escript_trace(arg,axis_offset): # this should be escript._trace
2905          "arg si a Data objects!!!"
2906          if arg.getRank()<2:
2907            raise ValueError,"escript_trace: rank of argument must be greater than 1"
2908          if axis_offset<0 or axis_offset>arg.getRank()-2:
2909            raise ValueError,"escript_trace: axis_offset must be between 0 and %s"%arg.getRank()-2
2910          s=list(arg.getShape())        
2911          if not s[axis_offset] == s[axis_offset+1]:
2912            raise ValueError,"escript_trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2913          out=escript.Data(0.,tuple(s[0:axis_offset]+s[axis_offset+2:]),arg.getFunctionSpace())
2914          if arg.getRank()==2:
2915             for i0 in range(s[0]):
2916                out+=arg[i0,i0]
2917          elif arg.getRank()==3:
2918             if axis_offset==0:
2919                for i0 in range(s[0]):
2920                      for i2 in range(s[2]):
2921                             out[i2]+=arg[i0,i0,i2]
2922             elif axis_offset==1:
2923                for i0 in range(s[0]):
2924                   for i1 in range(s[1]):
2925                             out[i0]+=arg[i0,i1,i1]
2926          elif arg.getRank()==4:
2927             if axis_offset==0:
2928                for i0 in range(s[0]):
2929                      for i2 in range(s[2]):
2930                         for i3 in range(s[3]):
2931                             out[i2,i3]+=arg[i0,i0,i2,i3]
2932             elif axis_offset==1:
2933                for i0 in range(s[0]):
2934                   for i1 in range(s[1]):
2935                         for i3 in range(s[3]):
2936                             out[i0,i3]+=arg[i0,i1,i1,i3]
2937             elif axis_offset==2:
2938                for i0 in range(s[0]):
2939                   for i1 in range(s[1]):
2940                      for i2 in range(s[2]):
2941                             out[i0,i1]+=arg[i0,i1,i2,i2]
2942          return out
2943    class Trace_Symbol(DependendSymbol):
2944       """
2945       L{Symbol} representing the result of the trace function
2946       """
2947       def __init__(self,arg,axis_offset=0):
2948          """
2949          initialization of trace L{Symbol} with argument arg
2950          @param arg: argument of function
2951          @type arg: L{Symbol}.
2952          @param axis_offset: axis_offset to components to sum over. C{axis_offset} must be non-negative and less than the rank of arg +1. The dimensions on component
2953                      axis_offset and axis_offset+1 must be equal.
2954          @type axis_offset: C{int}
2955          """
2956          if arg.getRank()<2:
2957            raise ValueError,"Trace_Symbol: rank of argument must be greater than 1"
2958          if axis_offset<0 or axis_offset>arg.getRank()-2:
2959            raise ValueError,"Trace_Symbol: axis_offset must be between 0 and %s"%arg.getRank()-2
2960          s=list(arg.getShape())        
2961          if not s[axis_offset] == s[axis_offset+1]:
2962            raise ValueError,"Trace_Symbol: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2963          super(Trace_Symbol,self).__init__(args=[arg,axis_offset],shape=tuple(s[0:axis_offset]+s[axis_offset+2:]),dim=arg.getDim())
2964    
2965       def getMyCode(self,argstrs,format="escript"):
2966          """
2967          returns a program code that can be used to evaluate the symbol.
2968    
2969          @param argstrs: gives for each argument a string representing the argument for the evaluation.
2970          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
2971          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
2972          @type format: C{str}
2973          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
2974          @rtype: C{str}
2975          @raise: NotImplementedError: if the requested format is not available
2976          """
2977          if format=="escript" or format=="str"  or format=="text":
2978             return "trace(%s,axis_offset=%s)"%(argstrs[0],argstrs[1])
2979          else:
2980             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
2981    
2982       def substitute(self,argvals):
2983          """
2984          assigns new values to symbols in the definition of the symbol.
2985          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
2986    
2987          @param argvals: new values assigned to symbols
2988          @type argvals: C{dict} with keywords of type L{Symbol}.
2989          @return: result of the substitution process. Operations are executed as much as possible.
2990          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
2991          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
2992          """
2993          if argvals.has_key(self):
2994             arg=argvals[self]
2995             if self.isAppropriateValue(arg):
2996                return arg
2997             else:
2998                raise TypeError,"%s: new value is not appropriate."%str(self)
2999          else:
3000             arg=self.getSubstitutedArguments(argvals)
3001             return trace(arg[0],axis_offset=arg[1])
3002    
3003       def diff(self,arg):
3004          """
3005          differential of this object
3006    
3007          @param arg: the derivative is calculated with respect to arg
3008          @type arg: L{escript.Symbol}
3009          @return: derivative with respect to C{arg}
3010          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3011          """
3012          if arg==self:
3013             return identity(self.getShape())
3014          else:
3015             return trace(self.getDifferentiatedArguments(arg)[0],axis_offset=self.getArgument()[1])
3016    
3017    def inverse(arg):
3018        """
3019        returns the inverse of the square matrix arg.
3020    
3021        @param arg: square matrix. Must have rank 2 and the first and second dimension must be equal
3022        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}
3023        @return: inverse arg_inv of the argument. It will be matrixmul(inverse(arg),arg) almost equal to kronecker(arg.getShape()[0])
3024        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
3025        """
3026        if isinstance(arg,numarray.NumArray):
3027          return numarray.linear_algebra.inverse(arg)
3028        elif isinstance(arg,escript.Data):
3029          return escript_inverse(arg)
3030        elif isinstance(arg,float):
3031          return 1./arg
3032        elif isinstance(arg,int):
3033          return 1./float(arg)
3034        elif isinstance(arg,Symbol):
3035          return Inverse_Symbol(arg)
3036        else:
3037          raise TypeError,"inverse: Unknown argument type."
3038    
3039    def escript_inverse(arg): # this should be escript._inverse and use LAPACK
3040          "arg is a Data objects!!!"
3041          if not arg.getRank()==2:
3042            raise ValueError,"escript_inverse: argument must have rank 2"
3043          s=arg.getShape()      
3044          if not s[0] == s[1]:
3045            raise ValueError,"escript_inverse: argument must be a square matrix."
3046          out=escript.Data(0.,s,arg.getFunctionSpace())
3047          if s[0]==1:
3048              if inf(abs(arg[0,0]))==0: # in c this should be done point wise as abs(arg[0,0](i))<=0.
3049                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3050              out[0,0]=1./arg[0,0]
3051          elif s[0]==2:
3052              A11=arg[0,0]
3053              A12=arg[0,1]
3054              A21=arg[1,0]
3055              A22=arg[1,1]
3056              D = A11*A22-A12*A21
3057              if inf(abs(D))==0: # in c this should be done point wise as abs(D(i))<=0.
3058                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3059              D=1./D
3060              out[0,0]= A22*D
3061              out[1,0]=-A21*D
3062              out[0,1]=-A12*D
3063              out[1,1]= A11*D
3064          elif s[0]==3:
3065              A11=arg[0,0]
3066              A21=arg[1,0]
3067              A31=arg[2,0]
3068              A12=arg[0,1]
3069              A22=arg[1,1]
3070              A32=arg[2,1]
3071              A13=arg[0,2]
3072              A23=arg[1,2]
3073              A33=arg[2,2]
3074              D  =  A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
3075              if inf(abs(D))==0: # in c this should be done point wise as abs(D(i))<=0.
3076                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3077              D=1./D
3078              out[0,0]=(A22*A33-A23*A32)*D
3079              out[1,0]=(A31*A23-A21*A33)*D
3080              out[2,0]=(A21*A32-A31*A22)*D
3081              out[0,1]=(A13*A32-A12*A33)*D
3082              out[1,1]=(A11*A33-A31*A13)*D
3083              out[2,1]=(A12*A31-A11*A32)*D
3084              out[0,2]=(A12*A23-A13*A22)*D
3085              out[1,2]=(A13*A21-A11*A23)*D
3086              out[2,2]=(A11*A22-A12*A21)*D
3087          else:
3088             raise TypeError,"escript_inverse: only matrix dimensions 1,2,3 are supported right now."
3089          return out
3090    
3091    class Inverse_Symbol(DependendSymbol):
3092       """
3093       L{Symbol} representing the result of the inverse function
3094       """
3095       def __init__(self,arg):
3096          """
3097          initialization of inverse L{Symbol} with argument arg
3098          @param arg: argument of function
3099          @type arg: L{Symbol}.
3100          """
3101          if not arg.getRank()==2:
3102            raise ValueError,"Inverse_Symbol:: argument must have rank 2"
3103          s=arg.getShape()
3104          if not s[0] == s[1]:
3105            raise ValueError,"Inverse_Symbol:: argument must be a square matrix."
3106          super(Inverse_Symbol,self).__init__(args=[arg],shape=s,dim=arg.getDim())
3107    
3108       def getMyCode(self,argstrs,format="escript"):
3109          """
3110          returns a program code that can be used to evaluate the symbol.
3111    
3112          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3113          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3114          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3115          @type format: C{str}
3116          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3117          @rtype: C{str}
3118          @raise: NotImplementedError: if the requested format is not available
3119          """
3120          if format=="escript" or format=="str"  or format=="text":
3121             return "inverse(%s)"%argstrs[0]
3122          else:
3123             raise NotImplementedError,"Inverse_Symbol does not provide program code for format %s."%format
3124    
3125       def substitute(self,argvals):
3126          """
3127          assigns new values to symbols in the definition of the symbol.
3128          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3129    
3130          @param argvals: new values assigned to symbols
3131          @type argvals: C{dict} with keywords of type L{Symbol}.
3132          @return: result of the substitution process. Operations are executed as much as possible.
3133          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3134          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3135          """
3136          if argvals.has_key(self):
3137             arg=argvals[self]
3138             if self.isAppropriateValue(arg):
3139                return arg
3140             else:
3141                raise TypeError,"%s: new value is not appropriate."%str(self)
3142          else:
3143             arg=self.getSubstitutedArguments(argvals)
3144             return inverse(arg[0])
3145    
3146       def diff(self,arg):
3147          """
3148          differential of this object
3149    
3150          @param arg: the derivative is calculated with respect to arg
3151          @type arg: L{escript.Symbol}
3152          @return: derivative with respect to C{arg}
3153          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3154          """
3155          if arg==self:
3156             return identity(self.getShape())
3157          else:
3158             return -matrixmult(matrixmult(self,self.getDifferentiatedArguments(arg)[0]),self)
3159  #=======================================================  #=======================================================
3160  #  Binary operations:  #  Binary operations:
3161  #=======================================================  #=======================================================
# Line 3304  def maximum(*args): Line 3583  def maximum(*args):
3583         if out==None:         if out==None:
3584            out=a            out=a
3585         else:         else:
3586            m=whereNegative(out-a)            diff=add(a,-out)
3587            out=m*a+(1.-m)*out            out=add(out,mult(wherePositive(diff),diff))
3588      return out      return out
3589        
3590  def minimum(*arg):  def minimum(*args):
3591      """      """
3592      the minimum over arguments args      the minimum over arguments args
3593    
# Line 3322  def minimum(*arg): Line 3601  def minimum(*arg):
3601         if out==None:         if out==None:
3602            out=a            out=a
3603         else:         else:
3604            m=whereNegative(out-a)            diff=add(a,-out)
3605            out=m*out+(1.-m)*a            out=add(out,mult(whereNegative(diff),diff))
3606      return out      return out
3607    
3608    def clip(arg,minval=0.,maxval=1.):
3609        """
3610        cuts the values of arg between minval and maxval
3611    
3612        @param arg: argument
3613        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}, C{int} or C{float}
3614        @param minval: lower range
3615        @type arg: C{float}
3616        @param maxval: upper range
3617        @type arg: C{float}
3618        @return: is on object with all its value between minval and maxval. value of the argument that greater then minval and
3619                 less then maxval are unchanged.
3620        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol}, C{int} or C{float} depending on the input
3621        @raise ValueError: if minval>maxval
3622        """
3623        if minval>maxval:
3624           raise ValueError,"minval = %s must be less then maxval %s"%(minval,maxval)
3625        return minimum(maximum(minval,arg),maxval)
3626    
3627        
3628  def inner(arg0,arg1):  def inner(arg0,arg1):
3629      """      """
# Line 3348  def inner(arg0,arg1): Line 3647  def inner(arg0,arg1):
3647      sh1=pokeShape(arg1)      sh1=pokeShape(arg1)
3648      if not sh0==sh1:      if not sh0==sh1:
3649          raise ValueError,"inner: shape of arguments does not match"          raise ValueError,"inner: shape of arguments does not match"
3650      return generalTensorProduct(arg0,arg1,offset=len(sh0))      return generalTensorProduct(arg0,arg1,axis_offset=len(sh0))
3651    
3652  def matrixmult(arg0,arg1):  def matrixmult(arg0,arg1):
3653      """      """
# Line 3376  def matrixmult(arg0,arg1): Line 3675  def matrixmult(arg0,arg1):
3675          raise ValueError,"first argument must have rank 2"          raise ValueError,"first argument must have rank 2"
3676      if not len(sh1)==2 and not len(sh1)==1:      if not len(sh1)==2 and not len(sh1)==1:
3677          raise ValueError,"second argument must have rank 1 or 2"          raise ValueError,"second argument must have rank 1 or 2"
3678      return generalTensorProduct(arg0,arg1,offset=1)      return generalTensorProduct(arg0,arg1,axis_offset=1)
3679    
3680  def outer(arg0,arg1):  def outer(arg0,arg1):
3681      """      """
# Line 3394  def outer(arg0,arg1): Line 3693  def outer(arg0,arg1):
3693      @return: the outer product of arg0 and arg1 at each data point      @return: the outer product of arg0 and arg1 at each data point
3694      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
3695      """      """
3696      return generalTensorProduct(arg0,arg1,offset=0)      return generalTensorProduct(arg0,arg1,axis_offset=0)
3697    
3698    
3699  def tensormult(arg0,arg1):  def tensormult(arg0,arg1):
# Line 3436  def tensormult(arg0,arg1): Line 3735  def tensormult(arg0,arg1):
3735      sh0=pokeShape(arg0)      sh0=pokeShape(arg0)
3736      sh1=pokeShape(arg1)      sh1=pokeShape(arg1)
3737      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):
3738         return generalTensorProduct(arg0,arg1,offset=1)         return generalTensorProduct(arg0,arg1,axis_offset=1)
3739      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):
3740         return generalTensorProduct(arg0,arg1,offset=2)         return generalTensorProduct(arg0,arg1,axis_offset=2)
3741      else:      else:
3742          raise ValueError,"tensormult: first argument must have rank 2 or 4"          raise ValueError,"tensormult: first argument must have rank 2 or 4"
3743    
3744  def generalTensorProduct(arg0,arg1,offset=0):  def generalTensorProduct(arg0,arg1,axis_offset=0):
3745      """      """
3746      generalized tensor product      generalized tensor product
3747    
3748      out[s,t]=S{Sigma}_r arg0[s,r]*arg1[r,t]      out[s,t]=S{Sigma}_r arg0[s,r]*arg1[r,t]
3749    
3750      where s runs through arg0.Shape[:arg0.Rank-offset]      where s runs through arg0.Shape[:arg0.Rank-axis_offset]
3751            r runs trough arg0.Shape[:offset]            r runs trough arg0.Shape[:axis_offset]
3752            t runs through arg1.Shape[offset:]            t runs through arg1.Shape[axis_offset:]
3753    
3754      In the first case the the second dimension of arg0 and the length of arg1 must match and        In the first case the the second dimension of arg0 and the length of arg1 must match and  
3755      in the second case the two last dimensions of arg0 must match the shape of arg1.      in the second case the two last dimensions of arg0 must match the shape of arg1.
# Line 3467  def generalTensorProduct(arg0,arg1,offse Line 3766  def generalTensorProduct(arg0,arg1,offse
3766      # at this stage arg0 and arg0 are both numarray.NumArray or escript.Data or Symbols      # at this stage arg0 and arg0 are both numarray.NumArray or escript.Data or Symbols
3767      if isinstance(arg0,numarray.NumArray):      if isinstance(arg0,numarray.NumArray):
3768         if isinstance(arg1,Symbol):         if isinstance(arg1,Symbol):
3769             return GeneralTensorProduct_Symbol(arg0,arg1,offset)             return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
3770         else:         else:
3771             if not arg0.shape[arg0.rank-offset:]==arg1.shape[:offset]:             if not arg0.shape[arg0.rank-axis_offset:]==arg1.shape[:axis_offset]:
3772                 raise ValueError,"generalTensorProduct: dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)                 raise ValueError,"generalTensorProduct: dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
3773             arg0_c=arg0.copy()             arg0_c=arg0.copy()
3774             arg1_c=arg1.copy()             arg1_c=arg1.copy()
3775             sh0,sh1=arg0.shape,arg1.shape             sh0,sh1=arg0.shape,arg1.shape
3776             d0,d1,d01=1,1,1             d0,d1,d01=1,1,1
3777             for i in sh0[:arg0.rank-offset]: d0*=i             for i in sh0[:arg0.rank-axis_offset]: d0*=i
3778             for i in sh1[offset:]: d1*=i             for i in sh1[axis_offset:]: d1*=i
3779             for i in sh1[:offset]: d01*=i             for i in sh1[:axis_offset]: d01*=i
3780             arg0_c.resize((d0,d01))             arg0_c.resize((d0,d01))
3781             arg1_c.resize((d01,d1))             arg1_c.resize((d01,d1))
3782             out=numarray.zeros((d0,d1),numarray.Float)             out=numarray.zeros((d0,d1),numarray.Float)
3783             for i0 in range(d0):             for i0 in range(d0):
3784                      for i1 in range(d1):                      for i1 in range(d1):
3785                           out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1])                           out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1])
3786             out.resize(sh0[:arg0.rank-offset]+sh1[offset:])             out.resize(sh0[:arg0.rank-axis_offset]+sh1[axis_offset:])
3787             return out             return out
3788      elif isinstance(arg0,escript.Data):      elif isinstance(arg0,escript.Data):
3789         if isinstance(arg1,Symbol):         if isinstance(arg1,Symbol):
3790             return GeneralTensorProduct_Symbol(arg0,arg1,offset)             return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
3791         else:         else:
3792             return escript_generalTensorProduct(arg0,arg1,offset) # this calls has to be replaced by escript._generalTensorProduct(arg0,arg1,offset)             return escript_generalTensorProduct(arg0,arg1,axis_offset) # this calls has to be replaced by escript._generalTensorProduct(arg0,arg1,axis_offset)
3793      else:            else:      
3794         return GeneralTensorProduct_Symbol(arg0,arg1,offset)         return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
3795                                    
3796  class GeneralTensorProduct_Symbol(DependendSymbol):  class GeneralTensorProduct_Symbol(DependendSymbol):
3797     """     """
3798     Symbol representing the quotient of two arguments.     Symbol representing the quotient of two arguments.
3799     """     """
3800     def __init__(self,arg0,arg1,offset=0):     def __init__(self,arg0,arg1,axis_offset=0):
3801         """         """
3802         initialization of L{Symbol} representing the quotient of two arguments         initialization of L{Symbol} representing the quotient of two arguments
3803    
# Line 3511  class GeneralTensorProduct_Symbol(Depend Line 3810  class GeneralTensorProduct_Symbol(Depend
3810         """         """
3811         sh_arg0=pokeShape(arg0)         sh_arg0=pokeShape(arg0)
3812         sh_arg1=pokeShape(arg1)         sh_arg1=pokeShape(arg1)
3813         sh0=sh_arg0[:len(sh_arg0)-offset]         sh0=sh_arg0[:len(sh_arg0)-axis_offset]
3814         sh01=sh_arg0[len(sh_arg0)-offset:]         sh01=sh_arg0[len(sh_arg0)-axis_offset:]
3815         sh10=sh_arg1[:offset]         sh10=sh_arg1[:axis_offset]
3816         sh1=sh_arg1[offset:]         sh1=sh_arg1[axis_offset:]
3817         if not sh01==sh10:         if not sh01==sh10:
3818             raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)             raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
3819         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0+sh1,args=[arg0,arg1,offset])         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0+sh1,args=[arg0,arg1,axis_offset])
3820    
3821     def getMyCode(self,argstrs,format="escript"):     def getMyCode(self,argstrs,format="escript"):
3822        """        """
# Line 3532  class GeneralTensorProduct_Symbol(Depend Line 3831  class GeneralTensorProduct_Symbol(Depend
3831        @raise: NotImplementedError: if the requested format is not available        @raise: NotImplementedError: if the requested format is not available
3832        """        """
3833        if format=="escript" or format=="str" or format=="text":        if format=="escript" or format=="str" or format=="text":
3834           return "generalTensorProduct(%s,%s,offset=%s)"%(argstrs[0],argstrs[1],argstrs[2])           return "generalTensorProduct(%s,%s,axis_offset=%s)"%(argstrs[0],argstrs[1],argstrs[2])
3835        else:        else:
3836           raise NotImplementedError,"%s does not provide program code for format %s."%(str(self),format)           raise NotImplementedError,"%s does not provide program code for format %s."%(str(self),format)
3837    
# Line 3557  class GeneralTensorProduct_Symbol(Depend Line 3856  class GeneralTensorProduct_Symbol(Depend
3856           args=self.getSubstitutedArguments(argvals)           args=self.getSubstitutedArguments(argvals)
3857           return generalTensorProduct(args[0],args[1],args[2])           return generalTensorProduct(args[0],args[1],args[2])
3858    
3859  def escript_generalTensorProduct(arg0,arg1,offset): # this should be escript._generalTensorProduct  def escript_generalTensorProduct(arg0,arg1,axis_offset): # this should be escript._generalTensorProduct
3860      "arg0 and arg1 are both Data objects but not neccesrily on the same function space. they could be identical!!!"      "arg0 and arg1 are both Data objects but not neccesrily on the same function space. they could be identical!!!"
3861      # calculate the return shape:      # calculate the return shape:
3862      shape0=arg0.getShape()[:arg0.getRank()-offset]      shape0=arg0.getShape()[:arg0.getRank()-axis_offset]
3863      shape01=arg0.getShape()[arg0.getRank()-offset:]      shape01=arg0.getShape()[arg0.getRank()-axis_offset:]
3864      shape10=arg1.getShape()[:offset]      shape10=arg1.getShape()[:axis_offset]
3865      shape1=arg1.getShape()[offset:]      shape1=arg1.getShape()[axis_offset:]
3866      if not shape01==shape10:      if not shape01==shape10:
3867          raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)          raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
3868    
3869        # whatr function space should be used? (this here is not good!)
3870        fs=(escript.Scalar(0.,arg0.getFunctionSpace())+escript.Scalar(0.,arg1.getFunctionSpace())).getFunctionSpace()
3871      # create return value:      # create return value:
3872      out=escript.Data(0.,tuple(shape0+shape1),arg0.getFunctionSpace())      out=escript.Data(0.,tuple(shape0+shape1),fs)
3873      #      #
3874      s0=[[]]      s0=[[]]
3875      for k in shape0:      for k in shape0:
# Line 3591  def escript_generalTensorProduct(arg0,ar Line 3892  def escript_generalTensorProduct(arg0,ar
3892    
3893      for i0 in s0:      for i0 in s0:
3894         for i1 in s1:         for i1 in s1:
3895           s=escript.Scalar(0.,arg0.getFunctionSpace())           s=escript.Scalar(0.,fs)
3896           for i01 in s01:           for i01 in s01:
3897              s+=arg0.__getitem__(tuple(i0+i01))*arg1.__getitem__(tuple(i01+i1))              s+=arg0.__getitem__(tuple(i0+i01))*arg1.__getitem__(tuple(i01+i1))
3898           out.__setitem__(tuple(i0+i1),s)           out.__setitem__(tuple(i0+i1),s)
3899      return out      return out
3900    
3901    
3902  #=========================================================  #=========================================================
3903  #   some little helpers  #  functions dealing with spatial dependency
3904  #=========================================================  #=========================================================
3905  def grad(arg,where=None):  def grad(arg,where=None):
3906      """      """
3907      Returns the spatial gradient of arg at where.      Returns the spatial gradient of arg at where.
3908    
3909      @param arg:   Data object representing the function which gradient      If C{g} is the returned object, then
3910                    to be calculated.  
3911          - if C{arg} is rank 0 C{g[s]} is the derivative of C{arg} with respect to the C{s}-th spatial dimension.
3912          - if C{arg} is rank 1 C{g[i,s]} is the derivative of C{arg[i]} with respect to the C{s}-th spatial dimension.
3913          - if C{arg} is rank 2 C{g[i,j,s]} is the derivative of C{arg[i,j]} with respect to the C{s}-th spatial dimension.
3914          - if C{arg} is rank 3 C{g[i,j,k,s]} is the derivative of C{arg[i,j,k]} with respect to the C{s}-th spatial dimension.
3915    
3916        @param arg: function which gradient to be calculated. Its rank has to be less than 3.
3917        @type arg: L{escript.Data} or L{Symbol}
3918      @param where: FunctionSpace in which the gradient will be calculated.      @param where: FunctionSpace in which the gradient will be calculated.
3919                    If not present or C{None} an appropriate default is used.                    If not present or C{None} an appropriate default is used.
3920        @type where: C{None} or L{escript.FunctionSpace}
3921        @return: gradient of arg.
3922        @rtype:  L{escript.Data} or L{Symbol}
3923      """      """
3924      if isinstance(arg,Symbol):      if isinstance(arg,Symbol):
3925         return Grad_Symbol(arg,where)         return Grad_Symbol(arg,where)
# Line 3617  def grad(arg,where=None): Line 3929  def grad(arg,where=None):
3929         else:         else:
3930            return arg._grad(where)            return arg._grad(where)
3931      else:      else:
3932        raise TypeError,"grad: Unknown argument type."         raise TypeError,"grad: Unknown argument type."
3933    
3934    class Grad_Symbol(DependendSymbol):
3935       """
3936       L{Symbol} representing the result of the gradient operator
3937       """
3938       def __init__(self,arg,where=None):
3939          """
3940          initialization of gradient L{Symbol} with argument arg
3941          @param arg: argument of function
3942          @type arg: L{Symbol}.
3943          @param where: FunctionSpace in which the gradient will be calculated.
3944                      If not present or C{None} an appropriate default is used.
3945          @type where: C{None} or L{escript.FunctionSpace}
3946          """
3947          d=arg.getDim()
3948          if d==None:
3949             raise ValueError,"argument must have a spatial dimension"
3950          super(Grad_Symbol,self).__init__(args=[arg,where],shape=tuple(list(arg.getShape()).extend(d)),dim=d)
3951    
3952       def getMyCode(self,argstrs,format="escript"):
3953          """
3954          returns a program code that can be used to evaluate the symbol.
3955    
3956          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3957          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3958          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3959          @type format: C{str}
3960          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3961          @rtype: C{str}
3962          @raise: NotImplementedError: if the requested format is not available
3963          """
3964          if format=="escript" or format=="str"  or format=="text":
3965             return "grad(%s,where=%s)"%(argstrs[0],argstrs[1])
3966          else:
3967             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
3968    
3969       def substitute(self,argvals):
3970          """
3971          assigns new values to symbols in the definition of the symbol.
3972          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3973    
3974          @param argvals: new values assigned to symbols
3975          @type argvals: C{dict} with keywords of type L{Symbol}.
3976          @return: result of the substitution process. Operations are executed as much as possible.
3977          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3978          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3979          """
3980          if argvals.has_key(self):
3981             arg=argvals[self]
3982             if self.isAppropriateValue(arg):
3983                return arg
3984             else:
3985                raise TypeError,"%s: new value is not appropriate."%str(self)
3986          else:
3987             arg=self.getSubstitutedArguments(argvals)
3988             return grad(arg[0],where=arg[1])
3989    
3990       def diff(self,arg):
3991          """
3992          differential of this object
3993    
3994          @param arg: the derivative is calculated with respect to arg
3995          @type arg: L{escript.Symbol}
3996          @return: derivative with respect to C{arg}
3997          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3998          """
3999          if arg==self:
4000             return identity(self.getShape())
4001          else:
4002             return grad(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
4003    
4004  def integrate(arg,where=None):  def integrate(arg,where=None):
4005      """      """
4006      Return the integral if the function represented by Data object arg over      Return the integral of the function C{arg} over its domain. If C{where} is present C{arg} is interpolated to C{where}
4007      its domain.      before integration.
4008    
4009      @param arg:   Data object representing the function which is integrated.      @param arg:   the function which is integrated.
4010        @type arg: L{escript.Data} or L{Symbol}
4011      @param where: FunctionSpace in which the integral is calculated.      @param where: FunctionSpace in which the integral is calculated.
4012                    If not present or C{None} an appropriate default is used.                    If not present or C{None} an appropriate default is used.
4013        @type where: C{None} or L{escript.FunctionSpace}
4014        @return: integral of arg.
4015        @rtype:  C{float}, C{numarray.NumArray} or L{Symbol}
4016      """      """
4017      if isinstance(arg,numarray.NumArray):      if isinstance(arg,Symbol):
         if checkForZero(arg):  
            return arg  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,float):  
         if checkForZero(arg):  
            return arg  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,int):  
         if checkForZero(arg):  
            return float(arg)  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,Symbol):  
4018         return Integrate_Symbol(arg,where)         return Integrate_Symbol(arg,where)
4019      elif isinstance(arg,escript.Data):      elif isinstance(arg,escript.Data):
4020         if not where==None: arg=escript.Data(arg,where)         if not where==None: arg=escript.Data(arg,where)
# Line 3654  def integrate(arg,where=None): Line 4025  def integrate(arg,where=None):
4025      else:      else:
4026        raise TypeError,"integrate: Unknown argument type."        raise TypeError,"integrate: Unknown argument type."
4027    
4028    class Integrate_Symbol(DependendSymbol):
4029       """
4030       L{Symbol} representing the result of the spatial integration operator
4031       """
4032       def __init__(self,arg,where=None):
4033          """
4034          initialization of integration L{Symbol} with argument arg
4035          @param arg: argument of the integration
4036          @type arg: L{Symbol}.
4037          @param where: FunctionSpace in which the integration will be calculated.
4038                      If not present or C{None} an appropriate default is used.
4039          @type where: C{None} or L{escript.FunctionSpace}
4040          """
4041          super(Integrate_Symbol,self).__init__(args=[arg,where],shape=arg.getShape(),dim=arg.getDim())
4042    
4043       def getMyCode(self,argstrs,format="escript"):
4044          """
4045          returns a program code that can be used to evaluate the symbol.
4046    
4047          @param argstrs: gives for each argument a string representing the argument for the evaluation.
4048          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
4049          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
4050          @type format: C{str}
4051          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
4052          @rtype: C{str}
4053          @raise: NotImplementedError: if the requested format is not available
4054          """
4055          if format=="escript" or format=="str"  or format=="text":
4056             return "integrate(%s,where=%s)"%(argstrs[0],argstrs[1])
4057          else:
4058             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
4059    
4060       def substitute(self,argvals):
4061          """
4062          assigns new values to symbols in the definition of the symbol.
4063          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
4064    
4065          @param argvals: new values assigned to symbols
4066          @type argvals: C{dict} with keywords of type L{Symbol}.
4067          @return: result of the substitution process. Operations are executed as much as possible.
4068          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
4069          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
4070          """
4071          if argvals.has_key(self):
4072             arg=argvals[self]
4073             if self.isAppropriateValue(arg):
4074                return arg
4075             else:
4076                raise TypeError,"%s: new value is not appropriate."%str(self)
4077          else:
4078             arg=self.getSubstitutedArguments(argvals)
4079             return integrate(arg[0],where=arg[1])
4080    
4081       def diff(self,arg):
4082          """
4083          differential of this object
4084    
4085          @param arg: the derivative is calculated with respect to arg
4086          @type arg: L{escript.Symbol}
4087          @return: derivative with respect to C{arg}
4088          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
4089          """
4090          if arg==self:
4091             return identity(self.getShape())
4092          else:
4093             return integrate(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
4094    
4095    
4096  def interpolate(arg,where):  def interpolate(arg,where):
4097      """      """
4098      Interpolates the function into the FunctionSpace where.      interpolates the function into the FunctionSpace where.
4099    
4100      @param arg:    interpolant      @param arg: interpolant
4101      @param where:  FunctionSpace to interpolate to      @type arg: L{escript.Data} or L{Symbol}
4102        @param where: FunctionSpace to be interpolated to
4103        @type where: L{escript.FunctionSpace}
4104        @return: interpolated argument
4105        @rtype:  C{escript.Data} or L{Symbol}
4106      """      """
4107      if testForZero(arg):      if isinstance(arg,Symbol):
4108        return 0         return Interpolate_Symbol(arg,where)
     elif isinstance(arg,Symbol):  
        return Interpolated_Symbol(arg,where)  
4109      else:      else:
4110         return escript.Data(arg,where)         return escript.Data(arg,where)
4111    
4112    class Interpolate_Symbol(DependendSymbol):
4113       """
4114       L{Symbol} representing the result of the interpolation operator
4115       """
4116       def __init__(self,arg,where):
4117          """
4118          initialization of interpolation L{Symbol} with argument arg
4119          @param arg: argument of the interpolation
4120          @type arg: L{Symbol}.
4121          @param where: FunctionSpace into which the argument is interpolated.
4122          @type where: L{escript.FunctionSpace}
4123          """
4124          super(Interpolate_Symbol,self).__init__(args=[arg,where],shape=arg.getShape(),dim=arg.getDim())
4125    
4126       def getMyCode(self,argstrs,format="escript"):
4127          """
4128          returns a program code that can be used to evaluate the symbol.
4129    
4130          @param argstrs: gives for each argument a string representing the argument for the evaluation.
4131          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
4132          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
4133          @type format: C{str}
4134          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
4135          @rtype: C{str}
4136          @raise: NotImplementedError: if the requested format is not available
4137          """
4138          if format=="escript" or format=="str"  or format=="text":
4139             return "interpolate(%s,where=%s)"%(argstrs[0],argstrs[1])
4140          else:
4141             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
4142    
4143       def substitute(self,argvals):
4144          """
4145          assigns new values to symbols in the definition of the symbol.
4146          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
4147    
4148          @param argvals: new values assigned to symbols
4149          @type argvals: C{dict} with keywords of type L{Symbol}.
4150          @return: result of the substitution process. Operations are executed as much as possible.
4151          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
4152          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
4153          """
4154          if argvals.has_key(self):
4155             arg=argvals[self]
4156             if self.isAppropriateValue(arg):
4157                return arg
4158             else:
4159                raise TypeError,"%s: new value is not appropriate."%str(self)
4160          else:
4161             arg=self.getSubstitutedArguments(argvals)
4162             return interpolate(arg[0],where=arg[1])
4163    
4164       def diff(self,arg):
4165          """
4166          differential of this object
4167    
4168          @param arg: the derivative is calculated with respect to arg
4169          @type arg: L{escript.Symbol}
4170          @return: derivative with respect to C{arg}
4171          @rtype: L{Symbol} but other types such as L{escript.Data}, L{numarray.NumArray}  are possible.
4172          """
4173          if arg==self:
4174             return identity(self.getShape())
4175          else:
4176             return interpolate(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
4177    
4178    
4179  def div(arg,where=None):  def div(arg,where=None):
4180      """      """
4181      Returns the divergence of arg at where.      returns the divergence of arg at where.
4182    
4183      @param arg:   Data object representing the function which gradient to      @param arg: function which divergence to be calculated. Its shape has to be (d,) where d is the spatial dimension.
4184                    be calculated.      @type arg: L{escript.Data} or L{Symbol}
4185      @param where: FunctionSpace in which the gradient will be calculated.      @param where: FunctionSpace in which the divergence will be calculated.
4186                    If not present or C{None} an appropriate default is used.                    If not present or C{None} an appropriate default is used.
4187        @type where: C{None} or L{escript.FunctionSpace}
4188        @return: divergence of arg.
4189        @rtype:  L{escript.Data} or L{Symbol}
4190      """      """
4191      g=grad(arg,where)      if not arg.getShape()==(arg.getDomain().getDim(),):
4192      return trace(g,axis0=g.getRank()-2,axis1=g.getRank()-1)        raise ValueError,"div: expected shape is (%s,)"%arg.getDomain().getDim()
4193        return trace(grad(arg,where))
4194    
4195  def jump(arg):  def jump(arg,domain=None):
4196      """      """
4197      Returns the jump of arg across a continuity.      returns the jump of arg across the continuity of the domain
4198    
4199      @param arg:   Data object representing the function which gradient      @param arg: argument
4200                    to be calculated.      @type arg: L{escript.Data} or L{Symbol}
4201        @param domain: the domain where the discontinuity is located. If domain is not present or equal to C{None}
4202                       the domain of arg is used. If arg is a L{Symbol} the domain must be present.
4203        @type domain: C{None} or L{escript.Domain}
4204        @return: jump of arg
4205        @rtype:  L{escript.Data} or L{Symbol}
4206      """      """
4207      d=arg.getDomain()      if domain==None: domain=arg.getDomain()
4208      return arg.interpolate(escript.FunctionOnContactOne())-arg.interpolate(escript.FunctionOnContactZero())      return interpolate(arg,escript.FunctionOnContactOne(domain))-interpolate(arg,escript.FunctionOnContactZero(domain))
   
4209  #=============================  #=============================
4210  #  #
4211  # wrapper for various functions: if the argument has attribute the function name  # wrapper for various functions: if the argument has attribute the function name
# Line 3726  def transpose(arg,axis=None): Line 4242  def transpose(arg,axis=None):
4242      else:      else:
4243         return numarray.transpose(arg,axis=axis)         return numarray.transpose(arg,axis=axis)
4244    
 def trace(arg,axis0=0,axis1=1):  
     """  
     Return  
   
     @param arg:  
     """  
     if isinstance(arg,Symbol):  
        s=list(arg.getShape())          
        s=tuple(s[0:axis0]+s[axis0+1:axis1]+s[axis1+1:])  
        return Trace_Symbol(arg,axis0=axis0,axis1=axis1)  
     elif isinstance(arg,escript.Data):  
        # hack for trace  
        s=arg.getShape()  
        if s[axis0]!=s[axis1]:  
            raise ValueError,"illegal axis in trace"  
        out=escript.Scalar(0.,arg.getFunctionSpace())  
        for i in range(s[axis0]):  
           out+=arg[i,i]  
        return out  
        # end hack for trace  
     else:  
        return numarray.trace(arg,axis0=axis0,axis1=axis1)  
4245    
4246    
4247  def reorderComponents(arg,index):  def reorderComponents(arg,index):

Legend:
Removed from v.341  
changed lines
  Added in v.433

  ViewVC Help
Powered by ViewVC 1.1.26