/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Diff of /trunk/escript/py_src/util.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 341 by gross, Mon Dec 12 05:26:10 2005 UTC revision 492 by gross, Fri Feb 3 02:07:24 2006 UTC
# Line 24  Utility functions for escript Line 24  Utility functions for escript
24  __author__="Lutz Gross, l.gross@uq.edu.au"  __author__="Lutz Gross, l.gross@uq.edu.au"
25  __licence__="contact: esys@access.uq.edu.au"  __licence__="contact: esys@access.uq.edu.au"
26  __url__="http://www.iservo.edu.au/esys/escript"  __url__="http://www.iservo.edu.au/esys/escript"
27  __version__="$Revision: 329 $"  __version__="$Revision$"
28  __date__="$Date$"  __date__="$Date$"
29    
30    
31  import math  import math
32  import numarray  import numarray
33    import numarray.linear_algebra
34  import escript  import escript
35  import os  import os
36    
# Line 43  import os Line 44  import os
44  # def matchType(arg0=0.,arg1=0.):  # def matchType(arg0=0.,arg1=0.):
45  # def matchShape(arg0,arg1):  # def matchShape(arg0,arg1):
46    
 # def maximum(arg0,arg1):  
 # def minimum(arg0,arg1):  
   
 # def transpose(arg,axis=None):  
 # def trace(arg,axis0=0,axis1=1):  
47  # def reorderComponents(arg,index):  # def reorderComponents(arg,index):
48    
 # def integrate(arg,where=None):  
 # def interpolate(arg,where):  
 # def div(arg,where=None):  
 # def grad(arg,where=None):  
   
49  #  #
50  # slicing: get  # slicing: get
51  #          set  #          set
# Line 125  def kronecker(d=3): Line 116  def kronecker(d=3):
116     return the kronecker S{delta}-symbol     return the kronecker S{delta}-symbol
117    
118     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
119     @type d: C{int} or any object with a C{getDim} method     @type d: C{int}, L{escript.Domain} or L{escript.FunctionSpace}
120     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise
121     @rtype d: L{numarray.NumArray} of rank 2.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 2.
    @remark: the function is identical L{identity}  
122     """     """
123     return identityTensor(d)     return identityTensor(d)
124    
# Line 147  def identity(shape=()): Line 137  def identity(shape=()):
137        if len(shape)==1:        if len(shape)==1:
138            for i0 in range(shape[0]):            for i0 in range(shape[0]):
139               out[i0,i0]=1.               out[i0,i0]=1.
   
140        elif len(shape)==2:        elif len(shape)==2:
141            for i0 in range(shape[0]):            for i0 in range(shape[0]):
142               for i1 in range(shape[1]):               for i1 in range(shape[1]):
# Line 163  def identityTensor(d=3): Line 152  def identityTensor(d=3):
152     return the dxd identity matrix     return the dxd identity matrix
153    
154     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
155     @type d: C{int} or any object with a C{getDim} method     @type d: C{int}, L{escript.Domain} or L{escript.FunctionSpace}
156     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise
157     @rtype: L{numarray.NumArray} of rank 2.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 2
158     """     """
159     if hasattr(d,"getDim"):     if isinstance(d,escript.FunctionSpace):
160        d=d.getDim()         return escript.Data(identity((d.getDim(),)),d)
161     return identity(shape=(d,))     elif isinstance(d,escript.Domain):
162           return identity((d.getDim(),))
163       else:
164           return identity((d,))
165    
166  def identityTensor4(d=3):  def identityTensor4(d=3):
167     """     """
# Line 178  def identityTensor4(d=3): Line 170  def identityTensor4(d=3):
170     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
171     @type d: C{int} or any object with a C{getDim} method     @type d: C{int} or any object with a C{getDim} method
172     @return: the object u of rank 4 with M{u[i,j,k,l]=1} for M{i=k and j=l} and M{u[i,j,k,l]=0} otherwise     @return: the object u of rank 4 with M{u[i,j,k,l]=1} for M{i=k and j=l} and M{u[i,j,k,l]=0} otherwise
173     @rtype: L{numarray.NumArray} of rank 4.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 4.
174     """     """
175     if hasattr(d,"getDim"):     if isinstance(d,escript.FunctionSpace):
176        d=d.getDim()         return escript.Data(identity((d.getDim(),d.getDim())),d)
177     return identity((d,d))     elif isinstance(d,escript.Domain):
178           return identity((d.getDim(),d.getDim()))
179       else:
180           return identity((d,d))
181    
182  def unitVector(i=0,d=3):  def unitVector(i=0,d=3):
183     """     """
# Line 191  def unitVector(i=0,d=3): Line 186  def unitVector(i=0,d=3):
186     @param i: index     @param i: index
187     @type i: C{int}     @type i: C{int}
188     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
189     @type d: C{int} or any object with a C{getDim} method     @type d: C{int}, L{escript.Domain} or L{escript.FunctionSpace}
190     @return: the object u of rank 1 with M{u[j]=1} for M{j=i} and M{u[i]=0} otherwise     @return: the object u of rank 1 with M{u[j]=1} for M{j=i} and M{u[i]=0} otherwise
191     @rtype: L{numarray.NumArray} of rank 1.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 1
192     """     """
193     return kronecker(d)[i]     return kronecker(d)[i]
194    
# Line 363  def testForZero(arg): Line 358  def testForZero(arg):
358      @return : True if the argument is identical to zero.      @return : True if the argument is identical to zero.
359      @rtype : C{bool}      @rtype : C{bool}
360      """      """
361      try:      if isinstance(arg,numarray.NumArray):
362           return not Lsup(arg)>0.
363        elif isinstance(arg,escript.Data):
364           return False
365        elif isinstance(arg,float):
366           return not Lsup(arg)>0.
367        elif isinstance(arg,int):
368         return not Lsup(arg)>0.         return not Lsup(arg)>0.
369      except TypeError:      elif isinstance(arg,Symbol):
370           return False
371        else:
372         return False         return False
373    
374  def matchType(arg0=0.,arg1=0.):  def matchType(arg0=0.,arg1=0.):
# Line 907  def wherePositive(arg): Line 910  def wherePositive(arg):
910     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
911     """     """
912     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
913        if arg.rank==0:        out=numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))*1.
914           if arg>0:        if isinstance(out,float): out=numarray.array(out)
915             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))  
916     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
917        return arg._wherePositive()        return arg._wherePositive()
918     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 993  def whereNegative(arg): Line 992  def whereNegative(arg):
992     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
993     """     """
994     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
995        if arg.rank==0:        out=numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))*1.
996           if arg<0:        if isinstance(out,float): out=numarray.array(out)
997             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))  
998     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
999        return arg._whereNegative()        return arg._whereNegative()
1000     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1079  def whereNonNegative(arg): Line 1074  def whereNonNegative(arg):
1074     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1075     """     """
1076     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1077        if arg.rank==0:        out=numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.
1078           if arg<0:        if isinstance(out,float): out=numarray.array(out)
1079             return numarray.array(0.)        return out
          else:  
            return numarray.array(1.)  
       else:  
          return numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float))  
1080     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1081        return arg._whereNonNegative()        return arg._whereNonNegative()
1082     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1113  def whereNonPositive(arg): Line 1104  def whereNonPositive(arg):
1104     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1105     """     """
1106     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1107        if arg.rank==0:        out=numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.
1108           if arg>0:        if isinstance(out,float): out=numarray.array(out)
1109             return numarray.array(0.)        return out
          else:  
            return numarray.array(1.)  
       else:  
          return numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.  
1110     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1111        return arg._whereNonPositive()        return arg._whereNonPositive()
1112     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1149  def whereZero(arg,tol=0.): Line 1136  def whereZero(arg,tol=0.):
1136     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1137     """     """
1138     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1139        if arg.rank==0:        out=numarray.less_equal(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.
1140           if abs(arg)<=tol:        if isinstance(out,float): out=numarray.array(out)
1141             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.less_equal(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.  
1142     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1143        if tol>0.:        if tol>0.:
1144           return whereNegative(abs(arg)-tol)           return whereNegative(abs(arg)-tol)
# Line 1236  def whereNonZero(arg,tol=0.): Line 1219  def whereNonZero(arg,tol=0.):
1219     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1220     """     """
1221     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1222        if arg.rank==0:        out=numarray.greater(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.
1223          if abs(arg)>tol:        if isinstance(out,float): out=numarray.array(out)
1224             return numarray.array(1.)        return out
         else:  
            return numarray.array(0.)  
       else:  
          return numarray.greater(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.  
1225     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1226        if tol>0.:        if tol>0.:
1227           return 1.-whereZero(arg,tol)           return 1.-whereZero(arg,tol)
# Line 2877  def length(arg): Line 2856  def length(arg):
2856     """     """
2857     return sqrt(inner(arg,arg))     return sqrt(inner(arg,arg))
2858    
2859    def trace(arg,axis_offset=0):
2860       """
2861       returns the trace of arg which the sum of arg[k,k] over k.
2862    
2863       @param arg: argument
2864       @type arg: L{escript.Data}, L{Symbol}, L{numarray.NumArray}.
2865       @param axis_offset: axis_offset to components to sum over. C{axis_offset} must be non-negative and less than the rank of arg +1. The dimensions on component
2866                      axis_offset and axis_offset+1 must be equal.
2867       @type axis_offset: C{int}
2868       @return: trace of arg. The rank of the returned object is minus 2 of the rank of arg.
2869       @rtype: L{escript.Data}, L{Symbol}, L{numarray.NumArray} depending on the type of arg.
2870       """
2871       if isinstance(arg,numarray.NumArray):
2872          sh=arg.shape
2873          if len(sh)<2:
2874            raise ValueError,"trace: rank of argument must be greater than 1"
2875          if axis_offset<0 or axis_offset>len(sh)-2:
2876            raise ValueError,"trace: axis_offset must be between 0 and %s"%len(sh)-2
2877          s1=1
2878          for i in range(axis_offset): s1*=sh[i]
2879          s2=1
2880          for i in range(axis_offset+2,len(sh)): s2*=sh[i]
2881          if not sh[axis_offset] == sh[axis_offset+1]:
2882            raise ValueError,"trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2883          arg_reshaped=numarray.reshape(arg,(s1,sh[axis_offset],sh[axis_offset],s2))
2884          out=numarray.zeros([s1,s2],numarray.Float)
2885          for i1 in range(s1):
2886            for i2 in range(s2):
2887                for j in range(sh[axis_offset]): out[i1,i2]+=arg_reshaped[i1,j,j,i2]
2888          out.resize(sh[:axis_offset]+sh[axis_offset+2:])
2889          return out
2890       elif isinstance(arg,escript.Data):
2891          return escript_trace(arg,axis_offset)
2892       elif isinstance(arg,float):
2893          raise TypeError,"trace: illegal argument type float."
2894       elif isinstance(arg,int):
2895          raise TypeError,"trace: illegal argument type int."
2896       elif isinstance(arg,Symbol):
2897          return Trace_Symbol(arg,axis_offset)
2898       else:
2899          raise TypeError,"trace: Unknown argument type."
2900    
2901    def escript_trace(arg,axis_offset): # this should be escript._trace
2902          "arg si a Data objects!!!"
2903          if arg.getRank()<2:
2904            raise ValueError,"escript_trace: rank of argument must be greater than 1"
2905          if axis_offset<0 or axis_offset>arg.getRank()-2:
2906            raise ValueError,"escript_trace: axis_offset must be between 0 and %s"%arg.getRank()-2
2907          s=list(arg.getShape())        
2908          if not s[axis_offset] == s[axis_offset+1]:
2909            raise ValueError,"escript_trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2910          out=escript.Data(0.,tuple(s[0:axis_offset]+s[axis_offset+2:]),arg.getFunctionSpace())
2911          if arg.getRank()==2:
2912             for i0 in range(s[0]):
2913                out+=arg[i0,i0]
2914          elif arg.getRank()==3:
2915             if axis_offset==0:
2916                for i0 in range(s[0]):
2917                      for i2 in range(s[2]):
2918                             out[i2]+=arg[i0,i0,i2]
2919             elif axis_offset==1:
2920                for i0 in range(s[0]):
2921                   for i1 in range(s[1]):
2922                             out[i0]+=arg[i0,i1,i1]
2923          elif arg.getRank()==4:
2924             if axis_offset==0:
2925                for i0 in range(s[0]):
2926                      for i2 in range(s[2]):
2927                         for i3 in range(s[3]):
2928                             out[i2,i3]+=arg[i0,i0,i2,i3]
2929             elif axis_offset==1:
2930                for i0 in range(s[0]):
2931                   for i1 in range(s[1]):
2932                         for i3 in range(s[3]):
2933                             out[i0,i3]+=arg[i0,i1,i1,i3]
2934             elif axis_offset==2:
2935                for i0 in range(s[0]):
2936                   for i1 in range(s[1]):
2937                      for i2 in range(s[2]):
2938                             out[i0,i1]+=arg[i0,i1,i2,i2]
2939          return out
2940    class Trace_Symbol(DependendSymbol):
2941       """
2942       L{Symbol} representing the result of the trace function
2943       """
2944       def __init__(self,arg,axis_offset=0):
2945          """
2946          initialization of trace L{Symbol} with argument arg
2947          @param arg: argument of function
2948          @type arg: L{Symbol}.
2949          @param axis_offset: axis_offset to components to sum over. C{axis_offset} must be non-negative and less than the rank of arg +1. The dimensions on component
2950                      axis_offset and axis_offset+1 must be equal.
2951          @type axis_offset: C{int}
2952          """
2953          if arg.getRank()<2:
2954            raise ValueError,"Trace_Symbol: rank of argument must be greater than 1"
2955          if axis_offset<0 or axis_offset>arg.getRank()-2:
2956            raise ValueError,"Trace_Symbol: axis_offset must be between 0 and %s"%arg.getRank()-2
2957          s=list(arg.getShape())        
2958          if not s[axis_offset] == s[axis_offset+1]:
2959            raise ValueError,"Trace_Symbol: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2960          super(Trace_Symbol,self).__init__(args=[arg,axis_offset],shape=tuple(s[0:axis_offset]+s[axis_offset+2:]),dim=arg.getDim())
2961    
2962       def getMyCode(self,argstrs,format="escript"):
2963          """
2964          returns a program code that can be used to evaluate the symbol.
2965    
2966          @param argstrs: gives for each argument a string representing the argument for the evaluation.
2967          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
2968          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
2969          @type format: C{str}
2970          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
2971          @rtype: C{str}
2972          @raise: NotImplementedError: if the requested format is not available
2973          """
2974          if format=="escript" or format=="str"  or format=="text":
2975             return "trace(%s,axis_offset=%s)"%(argstrs[0],argstrs[1])
2976          else:
2977             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
2978    
2979       def substitute(self,argvals):
2980          """
2981          assigns new values to symbols in the definition of the symbol.
2982          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
2983    
2984          @param argvals: new values assigned to symbols
2985          @type argvals: C{dict} with keywords of type L{Symbol}.
2986          @return: result of the substitution process. Operations are executed as much as possible.
2987          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
2988          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
2989          """
2990          if argvals.has_key(self):
2991             arg=argvals[self]
2992             if self.isAppropriateValue(arg):
2993                return arg
2994             else:
2995                raise TypeError,"%s: new value is not appropriate."%str(self)
2996          else:
2997             arg=self.getSubstitutedArguments(argvals)
2998             return trace(arg[0],axis_offset=arg[1])
2999    
3000       def diff(self,arg):
3001          """
3002          differential of this object
3003    
3004          @param arg: the derivative is calculated with respect to arg
3005          @type arg: L{escript.Symbol}
3006          @return: derivative with respect to C{arg}
3007          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3008          """
3009          if arg==self:
3010             return identity(self.getShape())
3011          else:
3012             return trace(self.getDifferentiatedArguments(arg)[0],axis_offset=self.getArgument()[1])
3013    
3014    def transpose(arg,axis_offset=None):
3015       """
3016       returns the transpose of arg by swaping the first axis_offset and the last rank-axis_offset components.
3017    
3018       @param arg: argument
3019       @type arg: L{escript.Data}, L{Symbol}, L{numarray.NumArray}, C{float}, C{int}
3020       @param axis_offset: the first axis_offset components are swapped with rest. If C{axis_offset} must be non-negative and less or equal the rank of arg.
3021                           if axis_offset is not present C{int(r/2)} where r is the rank of arg is used.
3022       @type axis_offset: C{int}
3023       @return: transpose of arg
3024       @rtype: L{escript.Data}, L{Symbol}, L{numarray.NumArray},C{float}, C{int} depending on the type of arg.
3025       """
3026       if isinstance(arg,numarray.NumArray):
3027          if axis_offset==None: axis_offset=int(arg.rank/2)
3028          return numarray.transpose(arg,axes=range(axis_offset,arg.rank)+range(0,axis_offset))
3029       elif isinstance(arg,escript.Data):
3030          if axis_offset==None: axis_offset=int(arg.getRank()/2)
3031          return escript_transpose(arg,axis_offset)
3032       elif isinstance(arg,float):
3033          if not ( axis_offset==0 or axis_offset==None):
3034            raise ValueError,"transpose: axis_offset must be 0 for float argument"
3035          return arg
3036       elif isinstance(arg,int):
3037          if not ( axis_offset==0 or axis_offset==None):
3038            raise ValueError,"transpose: axis_offset must be 0 for int argument"
3039          return float(arg)
3040       elif isinstance(arg,Symbol):
3041          if axis_offset==None: axis_offset=int(arg.getRank()/2)
3042          return Transpose_Symbol(arg,axis_offset)
3043       else:
3044          raise TypeError,"transpose: Unknown argument type."
3045    
3046    def escript_transpose(arg,axis_offset): # this should be escript._transpose
3047          "arg si a Data objects!!!"
3048          r=arg.getRank()
3049          if axis_offset<0 or axis_offset>r:
3050            raise ValueError,"escript_transpose: axis_offset must be between 0 and %s"%r
3051          s=arg.getShape()
3052          s_out=s[axis_offset:]+s[:axis_offset]
3053          out=escript.Data(0.,s_out,arg.getFunctionSpace())
3054          if r==4:
3055             if axis_offset==1:
3056                for i0 in range(s_out[0]):
3057                   for i1 in range(s_out[1]):
3058                      for i2 in range(s_out[2]):
3059                         for i3 in range(s_out[3]):
3060                             out[i0,i1,i2,i3]=arg[i3,i0,i1,i2]
3061             elif axis_offset==2:
3062                for i0 in range(s_out[0]):
3063                   for i1 in range(s_out[1]):
3064                      for i2 in range(s_out[2]):
3065                         for i3 in range(s_out[3]):
3066                             out[i0,i1,i2,i3]=arg[i2,i3,i0,i1]
3067             elif axis_offset==3:
3068                for i0 in range(s_out[0]):
3069                   for i1 in range(s_out[1]):
3070                      for i2 in range(s_out[2]):
3071                         for i3 in range(s_out[3]):
3072                             out[i0,i1,i2,i3]=arg[i1,i2,i3,i0]
3073             else:
3074                for i0 in range(s_out[0]):
3075                   for i1 in range(s_out[1]):
3076                      for i2 in range(s_out[2]):
3077                         for i3 in range(s_out[3]):
3078                             out[i0,i1,i2,i3]=arg[i0,i1,i2,i3]
3079          elif r==3:
3080             if axis_offset==1:
3081                for i0 in range(s_out[0]):
3082                   for i1 in range(s_out[1]):
3083                      for i2 in range(s_out[2]):
3084                             out[i0,i1,i2]=arg[i2,i0,i1]
3085             elif axis_offset==2:
3086                for i0 in range(s_out[0]):
3087                   for i1 in range(s_out[1]):
3088                      for i2 in range(s_out[2]):
3089                             out[i0,i1,i2]=arg[i1,i2,i0]
3090             else:
3091                for i0 in range(s_out[0]):
3092                   for i1 in range(s_out[1]):
3093                      for i2 in range(s_out[2]):
3094                             out[i0,i1,i2]=arg[i0,i1,i2]
3095          elif r==2:
3096             if axis_offset==1:
3097                for i0 in range(s_out[0]):
3098                   for i1 in range(s_out[1]):
3099                             out[i0,i1]=arg[i1,i0]
3100             else:
3101                for i0 in range(s_out[0]):
3102                   for i1 in range(s_out[1]):
3103                             out[i0,i1]=arg[i0,i1]
3104          elif r==1:
3105              for i0 in range(s_out[0]):
3106                   out[i0]=arg[i0]
3107          elif r==0:
3108                 out=arg+0.
3109          return out
3110    class Transpose_Symbol(DependendSymbol):
3111       """
3112       L{Symbol} representing the result of the transpose function
3113       """
3114       def __init__(self,arg,axis_offset=None):
3115          """
3116          initialization of transpose L{Symbol} with argument arg
3117    
3118          @param arg: argument of function
3119          @type arg: L{Symbol}.
3120           @param axis_offset: the first axis_offset components are swapped with rest. If C{axis_offset} must be non-negative and less or equal the rank of arg.
3121                           if axis_offset is not present C{int(r/2)} where r is the rank of arg is used.
3122          @type axis_offset: C{int}
3123          """
3124          if axis_offset==None: axis_offset=int(arg.getRank()/2)
3125          if axis_offset<0 or axis_offset>arg.getRank():
3126            raise ValueError,"escript_transpose: axis_offset must be between 0 and %s"%r
3127          s=arg.getShape()
3128          super(Transpose_Symbol,self).__init__(args=[arg,axis_offset],shape=s[axis_offset:]+s[:axis_offset],dim=arg.getDim())
3129    
3130       def getMyCode(self,argstrs,format="escript"):
3131          """
3132          returns a program code that can be used to evaluate the symbol.
3133    
3134          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3135          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3136          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3137          @type format: C{str}
3138          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3139          @rtype: C{str}
3140          @raise: NotImplementedError: if the requested format is not available
3141          """
3142          if format=="escript" or format=="str"  or format=="text":
3143             return "transpose(%s,axis_offset=%s)"%(argstrs[0],argstrs[1])
3144          else:
3145             raise NotImplementedError,"Transpose_Symbol does not provide program code for format %s."%format
3146    
3147       def substitute(self,argvals):
3148          """
3149          assigns new values to symbols in the definition of the symbol.
3150          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3151    
3152          @param argvals: new values assigned to symbols
3153          @type argvals: C{dict} with keywords of type L{Symbol}.
3154          @return: result of the substitution process. Operations are executed as much as possible.
3155          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3156          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3157          """
3158          if argvals.has_key(self):
3159             arg=argvals[self]
3160             if self.isAppropriateValue(arg):
3161                return arg
3162             else:
3163                raise TypeError,"%s: new value is not appropriate."%str(self)
3164          else:
3165             arg=self.getSubstitutedArguments(argvals)
3166             return transpose(arg[0],axis_offset=arg[1])
3167    
3168       def diff(self,arg):
3169          """
3170          differential of this object
3171    
3172          @param arg: the derivative is calculated with respect to arg
3173          @type arg: L{escript.Symbol}
3174          @return: derivative with respect to C{arg}
3175          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3176          """
3177          if arg==self:
3178             return identity(self.getShape())
3179          else:
3180             return transpose(self.getDifferentiatedArguments(arg)[0],axis_offset=self.getArgument()[1])
3181    
3182    def inverse(arg):
3183        """
3184        returns the inverse of the square matrix arg.
3185    
3186        @param arg: square matrix. Must have rank 2 and the first and second dimension must be equal
3187        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}
3188        @return: inverse arg_inv of the argument. It will be matrixmul(inverse(arg),arg) almost equal to kronecker(arg.getShape()[0])
3189        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
3190        """
3191        if isinstance(arg,numarray.NumArray):
3192          return numarray.linear_algebra.inverse(arg)
3193        elif isinstance(arg,escript.Data):
3194          return escript_inverse(arg)
3195        elif isinstance(arg,float):
3196          return 1./arg
3197        elif isinstance(arg,int):
3198          return 1./float(arg)
3199        elif isinstance(arg,Symbol):
3200          return Inverse_Symbol(arg)
3201        else:
3202          raise TypeError,"inverse: Unknown argument type."
3203    
3204    def escript_inverse(arg): # this should be escript._inverse and use LAPACK
3205          "arg is a Data objects!!!"
3206          if not arg.getRank()==2:
3207            raise ValueError,"escript_inverse: argument must have rank 2"
3208          s=arg.getShape()      
3209          if not s[0] == s[1]:
3210            raise ValueError,"escript_inverse: argument must be a square matrix."
3211          out=escript.Data(0.,s,arg.getFunctionSpace())
3212          if s[0]==1:
3213              if inf(abs(arg[0,0]))==0: # in c this should be done point wise as abs(arg[0,0](i))<=0.
3214                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3215              out[0,0]=1./arg[0,0]
3216          elif s[0]==2:
3217              A11=arg[0,0]
3218              A12=arg[0,1]
3219              A21=arg[1,0]
3220              A22=arg[1,1]
3221              D = A11*A22-A12*A21
3222              if inf(abs(D))==0: # in c this should be done point wise as abs(D(i))<=0.
3223                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3224              D=1./D
3225              out[0,0]= A22*D
3226              out[1,0]=-A21*D
3227              out[0,1]=-A12*D
3228              out[1,1]= A11*D
3229          elif s[0]==3:
3230              A11=arg[0,0]
3231              A21=arg[1,0]
3232              A31=arg[2,0]
3233              A12=arg[0,1]
3234              A22=arg[1,1]
3235              A32=arg[2,1]
3236              A13=arg[0,2]
3237              A23=arg[1,2]
3238              A33=arg[2,2]
3239              D  =  A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
3240              if inf(abs(D))==0: # in c this should be done point wise as abs(D(i))<=0.
3241                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3242              D=1./D
3243              out[0,0]=(A22*A33-A23*A32)*D
3244              out[1,0]=(A31*A23-A21*A33)*D
3245              out[2,0]=(A21*A32-A31*A22)*D
3246              out[0,1]=(A13*A32-A12*A33)*D
3247              out[1,1]=(A11*A33-A31*A13)*D
3248              out[2,1]=(A12*A31-A11*A32)*D
3249              out[0,2]=(A12*A23-A13*A22)*D
3250              out[1,2]=(A13*A21-A11*A23)*D
3251              out[2,2]=(A11*A22-A12*A21)*D
3252          else:
3253             raise TypeError,"escript_inverse: only matrix dimensions 1,2,3 are supported right now."
3254          return out
3255    
3256    class Inverse_Symbol(DependendSymbol):
3257       """
3258       L{Symbol} representing the result of the inverse function
3259       """
3260       def __init__(self,arg):
3261          """
3262          initialization of inverse L{Symbol} with argument arg
3263          @param arg: argument of function
3264          @type arg: L{Symbol}.
3265          """
3266          if not arg.getRank()==2:
3267            raise ValueError,"Inverse_Symbol:: argument must have rank 2"
3268          s=arg.getShape()
3269          if not s[0] == s[1]:
3270            raise ValueError,"Inverse_Symbol:: argument must be a square matrix."
3271          super(Inverse_Symbol,self).__init__(args=[arg],shape=s,dim=arg.getDim())
3272    
3273       def getMyCode(self,argstrs,format="escript"):
3274          """
3275          returns a program code that can be used to evaluate the symbol.
3276    
3277          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3278          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3279          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3280          @type format: C{str}
3281          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3282          @rtype: C{str}
3283          @raise: NotImplementedError: if the requested format is not available
3284          """
3285          if format=="escript" or format=="str"  or format=="text":
3286             return "inverse(%s)"%argstrs[0]
3287          else:
3288             raise NotImplementedError,"Inverse_Symbol does not provide program code for format %s."%format
3289    
3290       def substitute(self,argvals):
3291          """
3292          assigns new values to symbols in the definition of the symbol.
3293          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3294    
3295          @param argvals: new values assigned to symbols
3296          @type argvals: C{dict} with keywords of type L{Symbol}.
3297          @return: result of the substitution process. Operations are executed as much as possible.
3298          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3299          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3300          """
3301          if argvals.has_key(self):
3302             arg=argvals[self]
3303             if self.isAppropriateValue(arg):
3304                return arg
3305             else:
3306                raise TypeError,"%s: new value is not appropriate."%str(self)
3307          else:
3308             arg=self.getSubstitutedArguments(argvals)
3309             return inverse(arg[0])
3310    
3311       def diff(self,arg):
3312          """
3313          differential of this object
3314    
3315          @param arg: the derivative is calculated with respect to arg
3316          @type arg: L{escript.Symbol}
3317          @return: derivative with respect to C{arg}
3318          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3319          """
3320          if arg==self:
3321             return identity(self.getShape())
3322          else:
3323             return -matrixmult(matrixmult(self,self.getDifferentiatedArguments(arg)[0]),self)
3324  #=======================================================  #=======================================================
3325  #  Binary operations:  #  Binary operations:
3326  #=======================================================  #=======================================================
# Line 3304  def maximum(*args): Line 3748  def maximum(*args):
3748         if out==None:         if out==None:
3749            out=a            out=a
3750         else:         else:
3751            m=whereNegative(out-a)            diff=add(a,-out)
3752            out=m*a+(1.-m)*out            out=add(out,mult(wherePositive(diff),diff))
3753      return out      return out
3754        
3755  def minimum(*arg):  def minimum(*args):
3756      """      """
3757      the minimum over arguments args      the minimum over arguments args
3758    
# Line 3322  def minimum(*arg): Line 3766  def minimum(*arg):
3766         if out==None:         if out==None:
3767            out=a            out=a
3768         else:         else:
3769            m=whereNegative(out-a)            diff=add(a,-out)
3770            out=m*out+(1.-m)*a            out=add(out,mult(whereNegative(diff),diff))
3771      return out      return out
3772    
3773    def clip(arg,minval=0.,maxval=1.):
3774        """
3775        cuts the values of arg between minval and maxval
3776    
3777        @param arg: argument
3778        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}, C{int} or C{float}
3779        @param minval: lower range
3780        @type arg: C{float}
3781        @param maxval: upper range
3782        @type arg: C{float}
3783        @return: is on object with all its value between minval and maxval. value of the argument that greater then minval and
3784                 less then maxval are unchanged.
3785        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol}, C{int} or C{float} depending on the input
3786        @raise ValueError: if minval>maxval
3787        """
3788        if minval>maxval:
3789           raise ValueError,"minval = %s must be less then maxval %s"%(minval,maxval)
3790        return minimum(maximum(minval,arg),maxval)
3791    
3792        
3793  def inner(arg0,arg1):  def inner(arg0,arg1):
3794      """      """
# Line 3348  def inner(arg0,arg1): Line 3812  def inner(arg0,arg1):
3812      sh1=pokeShape(arg1)      sh1=pokeShape(arg1)
3813      if not sh0==sh1:      if not sh0==sh1:
3814          raise ValueError,"inner: shape of arguments does not match"          raise ValueError,"inner: shape of arguments does not match"
3815      return generalTensorProduct(arg0,arg1,offset=len(sh0))      return generalTensorProduct(arg0,arg1,axis_offset=len(sh0))
3816    
3817  def matrixmult(arg0,arg1):  def matrixmult(arg0,arg1):
3818      """      """
# Line 3376  def matrixmult(arg0,arg1): Line 3840  def matrixmult(arg0,arg1):
3840          raise ValueError,"first argument must have rank 2"          raise ValueError,"first argument must have rank 2"
3841      if not len(sh1)==2 and not len(sh1)==1:      if not len(sh1)==2 and not len(sh1)==1:
3842          raise ValueError,"second argument must have rank 1 or 2"          raise ValueError,"second argument must have rank 1 or 2"
3843      return generalTensorProduct(arg0,arg1,offset=1)      return generalTensorProduct(arg0,arg1,axis_offset=1)
3844    
3845  def outer(arg0,arg1):  def outer(arg0,arg1):
3846      """      """
# Line 3394  def outer(arg0,arg1): Line 3858  def outer(arg0,arg1):
3858      @return: the outer product of arg0 and arg1 at each data point      @return: the outer product of arg0 and arg1 at each data point
3859      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
3860      """      """
3861      return generalTensorProduct(arg0,arg1,offset=0)      return generalTensorProduct(arg0,arg1,axis_offset=0)
3862    
3863    
3864  def tensormult(arg0,arg1):  def tensormult(arg0,arg1):
# Line 3436  def tensormult(arg0,arg1): Line 3900  def tensormult(arg0,arg1):
3900      sh0=pokeShape(arg0)      sh0=pokeShape(arg0)
3901      sh1=pokeShape(arg1)      sh1=pokeShape(arg1)
3902      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):
3903         return generalTensorProduct(arg0,arg1,offset=1)         return generalTensorProduct(arg0,arg1,axis_offset=1)
3904      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):
3905         return generalTensorProduct(arg0,arg1,offset=2)         return generalTensorProduct(arg0,arg1,axis_offset=2)
3906      else:      else:
3907          raise ValueError,"tensormult: first argument must have rank 2 or 4"          raise ValueError,"tensormult: first argument must have rank 2 or 4"
3908    
3909  def generalTensorProduct(arg0,arg1,offset=0):  def generalTensorProduct(arg0,arg1,axis_offset=0):
3910      """      """
3911      generalized tensor product      generalized tensor product
3912    
3913      out[s,t]=S{Sigma}_r arg0[s,r]*arg1[r,t]      out[s,t]=S{Sigma}_r arg0[s,r]*arg1[r,t]
3914    
3915      where s runs through arg0.Shape[:arg0.Rank-offset]      where s runs through arg0.Shape[:arg0.Rank-axis_offset]
3916            r runs trough arg0.Shape[:offset]            r runs trough arg0.Shape[:axis_offset]
3917            t runs through arg1.Shape[offset:]            t runs through arg1.Shape[axis_offset:]
3918    
3919      In the first case the the second dimension of arg0 and the length of arg1 must match and        In the first case the the second dimension of arg0 and the length of arg1 must match and  
3920      in the second case the two last dimensions of arg0 must match the shape of arg1.      in the second case the two last dimensions of arg0 must match the shape of arg1.
# Line 3467  def generalTensorProduct(arg0,arg1,offse Line 3931  def generalTensorProduct(arg0,arg1,offse
3931      # at this stage arg0 and arg0 are both numarray.NumArray or escript.Data or Symbols      # at this stage arg0 and arg0 are both numarray.NumArray or escript.Data or Symbols
3932      if isinstance(arg0,numarray.NumArray):      if isinstance(arg0,numarray.NumArray):
3933         if isinstance(arg1,Symbol):         if isinstance(arg1,Symbol):
3934             return GeneralTensorProduct_Symbol(arg0,arg1,offset)             return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
3935         else:         else:
3936             if not arg0.shape[arg0.rank-offset:]==arg1.shape[:offset]:             if not arg0.shape[arg0.rank-axis_offset:]==arg1.shape[:axis_offset]:
3937                 raise ValueError,"generalTensorProduct: dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)                 raise ValueError,"generalTensorProduct: dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
3938             arg0_c=arg0.copy()             arg0_c=arg0.copy()
3939             arg1_c=arg1.copy()             arg1_c=arg1.copy()
3940             sh0,sh1=arg0.shape,arg1.shape             sh0,sh1=arg0.shape,arg1.shape
3941             d0,d1,d01=1,1,1             d0,d1,d01=1,1,1
3942             for i in sh0[:arg0.rank-offset]: d0*=i             for i in sh0[:arg0.rank-axis_offset]: d0*=i
3943             for i in sh1[offset:]: d1*=i             for i in sh1[axis_offset:]: d1*=i
3944             for i in sh1[:offset]: d01*=i             for i in sh1[:axis_offset]: d01*=i
3945             arg0_c.resize((d0,d01))             arg0_c.resize((d0,d01))
3946             arg1_c.resize((d01,d1))             arg1_c.resize((d01,d1))
3947             out=numarray.zeros((d0,d1),numarray.Float)             out=numarray.zeros((d0,d1),numarray.Float)
3948             for i0 in range(d0):             for i0 in range(d0):
3949                      for i1 in range(d1):                      for i1 in range(d1):
3950                           out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1])                           out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1])
3951             out.resize(sh0[:arg0.rank-offset]+sh1[offset:])             out.resize(sh0[:arg0.rank-axis_offset]+sh1[axis_offset:])
3952             return out             return out
3953      elif isinstance(arg0,escript.Data):      elif isinstance(arg0,escript.Data):
3954         if isinstance(arg1,Symbol):         if isinstance(arg1,Symbol):
3955             return GeneralTensorProduct_Symbol(arg0,arg1,offset)             return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
3956         else:         else:
3957             return escript_generalTensorProduct(arg0,arg1,offset) # this calls has to be replaced by escript._generalTensorProduct(arg0,arg1,offset)             return escript_generalTensorProduct(arg0,arg1,axis_offset) # this calls has to be replaced by escript._generalTensorProduct(arg0,arg1,axis_offset)
3958      else:            else:      
3959         return GeneralTensorProduct_Symbol(arg0,arg1,offset)         return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
3960                                    
3961  class GeneralTensorProduct_Symbol(DependendSymbol):  class GeneralTensorProduct_Symbol(DependendSymbol):
3962     """     """
3963     Symbol representing the quotient of two arguments.     Symbol representing the quotient of two arguments.
3964     """     """
3965     def __init__(self,arg0,arg1,offset=0):     def __init__(self,arg0,arg1,axis_offset=0):
3966         """         """
3967         initialization of L{Symbol} representing the quotient of two arguments         initialization of L{Symbol} representing the quotient of two arguments
3968    
# Line 3511  class GeneralTensorProduct_Symbol(Depend Line 3975  class GeneralTensorProduct_Symbol(Depend
3975         """         """
3976         sh_arg0=pokeShape(arg0)         sh_arg0=pokeShape(arg0)
3977         sh_arg1=pokeShape(arg1)         sh_arg1=pokeShape(arg1)
3978         sh0=sh_arg0[:len(sh_arg0)-offset]         sh0=sh_arg0[:len(sh_arg0)-axis_offset]
3979         sh01=sh_arg0[len(sh_arg0)-offset:]         sh01=sh_arg0[len(sh_arg0)-axis_offset:]
3980         sh10=sh_arg1[:offset]         sh10=sh_arg1[:axis_offset]
3981         sh1=sh_arg1[offset:]         sh1=sh_arg1[axis_offset:]
3982         if not sh01==sh10:         if not sh01==sh10:
3983             raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)             raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
3984         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0+sh1,args=[arg0,arg1,offset])         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0+sh1,args=[arg0,arg1,axis_offset])
3985    
3986     def getMyCode(self,argstrs,format="escript"):     def getMyCode(self,argstrs,format="escript"):
3987        """        """
# Line 3532  class GeneralTensorProduct_Symbol(Depend Line 3996  class GeneralTensorProduct_Symbol(Depend
3996        @raise: NotImplementedError: if the requested format is not available        @raise: NotImplementedError: if the requested format is not available
3997        """        """
3998        if format=="escript" or format=="str" or format=="text":        if format=="escript" or format=="str" or format=="text":
3999           return "generalTensorProduct(%s,%s,offset=%s)"%(argstrs[0],argstrs[1],argstrs[2])           return "generalTensorProduct(%s,%s,axis_offset=%s)"%(argstrs[0],argstrs[1],argstrs[2])
4000        else:        else:
4001           raise NotImplementedError,"%s does not provide program code for format %s."%(str(self),format)           raise NotImplementedError,"%s does not provide program code for format %s."%(str(self),format)
4002    
# Line 3557  class GeneralTensorProduct_Symbol(Depend Line 4021  class GeneralTensorProduct_Symbol(Depend
4021           args=self.getSubstitutedArguments(argvals)           args=self.getSubstitutedArguments(argvals)
4022           return generalTensorProduct(args[0],args[1],args[2])           return generalTensorProduct(args[0],args[1],args[2])
4023    
4024  def escript_generalTensorProduct(arg0,arg1,offset): # this should be escript._generalTensorProduct  def escript_generalTensorProduct(arg0,arg1,axis_offset): # this should be escript._generalTensorProduct
4025      "arg0 and arg1 are both Data objects but not neccesrily on the same function space. they could be identical!!!"      "arg0 and arg1 are both Data objects but not neccesrily on the same function space. they could be identical!!!"
4026      # calculate the return shape:      # calculate the return shape:
4027      shape0=arg0.getShape()[:arg0.getRank()-offset]      shape0=arg0.getShape()[:arg0.getRank()-axis_offset]
4028      shape01=arg0.getShape()[arg0.getRank()-offset:]      shape01=arg0.getShape()[arg0.getRank()-axis_offset:]
4029      shape10=arg1.getShape()[:offset]      shape10=arg1.getShape()[:axis_offset]
4030      shape1=arg1.getShape()[offset:]      shape1=arg1.getShape()[axis_offset:]
4031      if not shape01==shape10:      if not shape01==shape10:
4032          raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)          raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
4033    
4034        # whatr function space should be used? (this here is not good!)
4035        fs=(escript.Scalar(0.,arg0.getFunctionSpace())+escript.Scalar(0.,arg1.getFunctionSpace())).getFunctionSpace()
4036      # create return value:      # create return value:
4037      out=escript.Data(0.,tuple(shape0+shape1),arg0.getFunctionSpace())      out=escript.Data(0.,tuple(shape0+shape1),fs)
4038      #      #
4039      s0=[[]]      s0=[[]]
4040      for k in shape0:      for k in shape0:
# Line 3591  def escript_generalTensorProduct(arg0,ar Line 4057  def escript_generalTensorProduct(arg0,ar
4057    
4058      for i0 in s0:      for i0 in s0:
4059         for i1 in s1:         for i1 in s1:
4060           s=escript.Scalar(0.,arg0.getFunctionSpace())           s=escript.Scalar(0.,fs)
4061           for i01 in s01:           for i01 in s01:
4062              s+=arg0.__getitem__(tuple(i0+i01))*arg1.__getitem__(tuple(i01+i1))              s+=arg0.__getitem__(tuple(i0+i01))*arg1.__getitem__(tuple(i01+i1))
4063           out.__setitem__(tuple(i0+i1),s)           out.__setitem__(tuple(i0+i1),s)
4064      return out      return out
4065    
4066    
4067  #=========================================================  #=========================================================
4068  #   some little helpers  #  functions dealing with spatial dependency
4069  #=========================================================  #=========================================================
4070  def grad(arg,where=None):  def grad(arg,where=None):
4071      """      """
4072      Returns the spatial gradient of arg at where.      Returns the spatial gradient of arg at where.
4073    
4074        If C{g} is the returned object, then
4075    
4076      @param arg:   Data object representing the function which gradient        - if C{arg} is rank 0 C{g[s]} is the derivative of C{arg} with respect to the C{s}-th spatial dimension.
4077                    to be calculated.        - if C{arg} is rank 1 C{g[i,s]} is the derivative of C{arg[i]} with respect to the C{s}-th spatial dimension.
4078          - if C{arg} is rank 2 C{g[i,j,s]} is the derivative of C{arg[i,j]} with respect to the C{s}-th spatial dimension.
4079          - if C{arg} is rank 3 C{g[i,j,k,s]} is the derivative of C{arg[i,j,k]} with respect to the C{s}-th spatial dimension.
4080    
4081        @param arg: function which gradient to be calculated. Its rank has to be less than 3.
4082        @type arg: L{escript.Data} or L{Symbol}
4083      @param where: FunctionSpace in which the gradient will be calculated.      @param where: FunctionSpace in which the gradient will be calculated.
4084                    If not present or C{None} an appropriate default is used.                    If not present or C{None} an appropriate default is used.
4085        @type where: C{None} or L{escript.FunctionSpace}
4086        @return: gradient of arg.
4087        @rtype:  L{escript.Data} or L{Symbol}
4088      """      """
4089      if isinstance(arg,Symbol):      if isinstance(arg,Symbol):
4090         return Grad_Symbol(arg,where)         return Grad_Symbol(arg,where)
# Line 3617  def grad(arg,where=None): Line 4094  def grad(arg,where=None):
4094         else:         else:
4095            return arg._grad(where)            return arg._grad(where)
4096      else:      else:
4097        raise TypeError,"grad: Unknown argument type."         raise TypeError,"grad: Unknown argument type."
4098    
4099    class Grad_Symbol(DependendSymbol):
4100       """
4101       L{Symbol} representing the result of the gradient operator
4102       """
4103       def __init__(self,arg,where=None):
4104          """
4105          initialization of gradient L{Symbol} with argument arg
4106          @param arg: argument of function
4107          @type arg: L{Symbol}.
4108          @param where: FunctionSpace in which the gradient will be calculated.
4109                      If not present or C{None} an appropriate default is used.
4110          @type where: C{None} or L{escript.FunctionSpace}
4111          """
4112          d=arg.getDim()
4113          if d==None:
4114             raise ValueError,"argument must have a spatial dimension"
4115          super(Grad_Symbol,self).__init__(args=[arg,where],shape=arg.getShape()+(d,),dim=d)
4116    
4117       def getMyCode(self,argstrs,format="escript"):
4118          """
4119          returns a program code that can be used to evaluate the symbol.
4120    
4121          @param argstrs: gives for each argument a string representing the argument for the evaluation.
4122          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
4123          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
4124          @type format: C{str}
4125          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
4126          @rtype: C{str}
4127          @raise: NotImplementedError: if the requested format is not available
4128          """
4129          if format=="escript" or format=="str"  or format=="text":
4130             return "grad(%s,where=%s)"%(argstrs[0],argstrs[1])
4131          else:
4132             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
4133    
4134       def substitute(self,argvals):
4135          """
4136          assigns new values to symbols in the definition of the symbol.
4137          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
4138    
4139          @param argvals: new values assigned to symbols
4140          @type argvals: C{dict} with keywords of type L{Symbol}.
4141          @return: result of the substitution process. Operations are executed as much as possible.
4142          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
4143          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
4144          """
4145          if argvals.has_key(self):
4146             arg=argvals[self]
4147             if self.isAppropriateValue(arg):
4148                return arg
4149             else:
4150                raise TypeError,"%s: new value is not appropriate."%str(self)
4151          else:
4152             arg=self.getSubstitutedArguments(argvals)
4153             return grad(arg[0],where=arg[1])
4154    
4155       def diff(self,arg):
4156          """
4157          differential of this object
4158    
4159          @param arg: the derivative is calculated with respect to arg
4160          @type arg: L{escript.Symbol}
4161          @return: derivative with respect to C{arg}
4162          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
4163          """
4164          if arg==self:
4165             return identity(self.getShape())
4166          else:
4167             return grad(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
4168    
4169  def integrate(arg,where=None):  def integrate(arg,where=None):
4170      """      """
4171      Return the integral if the function represented by Data object arg over      Return the integral of the function C{arg} over its domain. If C{where} is present C{arg} is interpolated to C{where}
4172      its domain.      before integration.
4173    
4174      @param arg:   Data object representing the function which is integrated.      @param arg:   the function which is integrated.
4175        @type arg: L{escript.Data} or L{Symbol}
4176      @param where: FunctionSpace in which the integral is calculated.      @param where: FunctionSpace in which the integral is calculated.
4177                    If not present or C{None} an appropriate default is used.                    If not present or C{None} an appropriate default is used.
4178        @type where: C{None} or L{escript.FunctionSpace}
4179        @return: integral of arg.
4180        @rtype:  C{float}, C{numarray.NumArray} or L{Symbol}
4181      """      """
4182      if isinstance(arg,numarray.NumArray):      if isinstance(arg,Symbol):
         if checkForZero(arg):  
            return arg  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,float):  
         if checkForZero(arg):  
            return arg  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,int):  
         if checkForZero(arg):  
            return float(arg)  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,Symbol):  
4183         return Integrate_Symbol(arg,where)         return Integrate_Symbol(arg,where)
4184      elif isinstance(arg,escript.Data):      elif isinstance(arg,escript.Data):
4185         if not where==None: arg=escript.Data(arg,where)         if not where==None: arg=escript.Data(arg,where)
# Line 3654  def integrate(arg,where=None): Line 4190  def integrate(arg,where=None):
4190      else:      else:
4191        raise TypeError,"integrate: Unknown argument type."        raise TypeError,"integrate: Unknown argument type."
4192    
4193    class Integrate_Symbol(DependendSymbol):
4194       """
4195       L{Symbol} representing the result of the spatial integration operator
4196       """
4197       def __init__(self,arg,where=None):
4198          """
4199          initialization of integration L{Symbol} with argument arg
4200          @param arg: argument of the integration
4201          @type arg: L{Symbol}.
4202          @param where: FunctionSpace in which the integration will be calculated.
4203                      If not present or C{None} an appropriate default is used.
4204          @type where: C{None} or L{escript.FunctionSpace}
4205          """
4206          super(Integrate_Symbol,self).__init__(args=[arg,where],shape=arg.getShape(),dim=arg.getDim())
4207    
4208       def getMyCode(self,argstrs,format="escript"):
4209          """
4210          returns a program code that can be used to evaluate the symbol.
4211    
4212          @param argstrs: gives for each argument a string representing the argument for the evaluation.
4213          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
4214          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
4215          @type format: C{str}
4216          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
4217          @rtype: C{str}
4218          @raise: NotImplementedError: if the requested format is not available
4219          """
4220          if format=="escript" or format=="str"  or format=="text":
4221             return "integrate(%s,where=%s)"%(argstrs[0],argstrs[1])
4222          else:
4223             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
4224    
4225       def substitute(self,argvals):
4226          """
4227          assigns new values to symbols in the definition of the symbol.
4228          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
4229    
4230          @param argvals: new values assigned to symbols
4231          @type argvals: C{dict} with keywords of type L{Symbol}.
4232          @return: result of the substitution process. Operations are executed as much as possible.
4233          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
4234          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
4235          """
4236          if argvals.has_key(self):
4237             arg=argvals[self]
4238             if self.isAppropriateValue(arg):
4239                return arg
4240             else:
4241                raise TypeError,"%s: new value is not appropriate."%str(self)
4242          else:
4243             arg=self.getSubstitutedArguments(argvals)
4244             return integrate(arg[0],where=arg[1])
4245    
4246       def diff(self,arg):
4247          """
4248          differential of this object
4249    
4250          @param arg: the derivative is calculated with respect to arg
4251          @type arg: L{escript.Symbol}
4252          @return: derivative with respect to C{arg}
4253          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
4254          """
4255          if arg==self:
4256             return identity(self.getShape())
4257          else:
4258             return integrate(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
4259    
4260    
4261  def interpolate(arg,where):  def interpolate(arg,where):
4262      """      """
4263      Interpolates the function into the FunctionSpace where.      interpolates the function into the FunctionSpace where.
4264    
4265      @param arg:    interpolant      @param arg: interpolant
4266      @param where:  FunctionSpace to interpolate to      @type arg: L{escript.Data} or L{Symbol}
4267        @param where: FunctionSpace to be interpolated to
4268        @type where: L{escript.FunctionSpace}
4269        @return: interpolated argument
4270        @rtype:  C{escript.Data} or L{Symbol}
4271      """      """
4272      if testForZero(arg):      if isinstance(arg,Symbol):
4273        return 0         return Interpolate_Symbol(arg,where)
     elif isinstance(arg,Symbol):  
        return Interpolated_Symbol(arg,where)  
4274      else:      else:
4275         return escript.Data(arg,where)         return escript.Data(arg,where)
4276    
4277  def div(arg,where=None):  class Interpolate_Symbol(DependendSymbol):
4278      """     """
4279      Returns the divergence of arg at where.     L{Symbol} representing the result of the interpolation operator
4280       """
4281       def __init__(self,arg,where):
4282          """
4283          initialization of interpolation L{Symbol} with argument arg
4284          @param arg: argument of the interpolation
4285          @type arg: L{Symbol}.
4286          @param where: FunctionSpace into which the argument is interpolated.
4287          @type where: L{escript.FunctionSpace}
4288          """
4289          super(Interpolate_Symbol,self).__init__(args=[arg,where],shape=arg.getShape(),dim=arg.getDim())
4290    
4291      @param arg:   Data object representing the function which gradient to     def getMyCode(self,argstrs,format="escript"):
4292                    be calculated.        """
4293      @param where: FunctionSpace in which the gradient will be calculated.        returns a program code that can be used to evaluate the symbol.
                   If not present or C{None} an appropriate default is used.  
     """  
     g=grad(arg,where)  
     return trace(g,axis0=g.getRank()-2,axis1=g.getRank()-1)  
4294    
4295  def jump(arg):        @param argstrs: gives for each argument a string representing the argument for the evaluation.
4296      """        @type argstrs: C{str} or a C{list} of length 1 of C{str}.
4297      Returns the jump of arg across a continuity.        @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
4298          @type format: C{str}
4299          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
4300          @rtype: C{str}
4301          @raise: NotImplementedError: if the requested format is not available
4302          """
4303          if format=="escript" or format=="str"  or format=="text":
4304             return "interpolate(%s,where=%s)"%(argstrs[0],argstrs[1])
4305          else:
4306             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
4307    
4308      @param arg:   Data object representing the function which gradient     def substitute(self,argvals):
4309                    to be calculated.        """
4310      """        assigns new values to symbols in the definition of the symbol.
4311      d=arg.getDomain()        The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
     return arg.interpolate(escript.FunctionOnContactOne())-arg.interpolate(escript.FunctionOnContactZero())  
4312    
4313  #=============================        @param argvals: new values assigned to symbols
4314  #        @type argvals: C{dict} with keywords of type L{Symbol}.
4315  # wrapper for various functions: if the argument has attribute the function name        @return: result of the substitution process. Operations are executed as much as possible.
4316  # as an argument it calls the corresponding methods. Otherwise the corresponding        @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
4317  # numarray function is called.        @raise TypeError: if a value for a L{Symbol} cannot be substituted.
4318          """
4319          if argvals.has_key(self):
4320             arg=argvals[self]
4321             if self.isAppropriateValue(arg):
4322                return arg
4323             else:
4324                raise TypeError,"%s: new value is not appropriate."%str(self)
4325          else:
4326             arg=self.getSubstitutedArguments(argvals)
4327             return interpolate(arg[0],where=arg[1])
4328    
4329  # functions involving the underlying Domain:     def diff(self,arg):
4330          """
4331          differential of this object
4332    
4333          @param arg: the derivative is calculated with respect to arg
4334          @type arg: L{escript.Symbol}
4335          @return: derivative with respect to C{arg}
4336          @rtype: L{Symbol} but other types such as L{escript.Data}, L{numarray.NumArray}  are possible.
4337          """
4338          if arg==self:
4339             return identity(self.getShape())
4340          else:
4341             return interpolate(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
4342    
4343  def transpose(arg,axis=None):  
4344    def div(arg,where=None):
4345      """      """
4346      Returns the transpose of the Data object arg.      returns the divergence of arg at where.
4347    
4348      @param arg:      @param arg: function which divergence to be calculated. Its shape has to be (d,) where d is the spatial dimension.
4349        @type arg: L{escript.Data} or L{Symbol}
4350        @param where: FunctionSpace in which the divergence will be calculated.
4351                      If not present or C{None} an appropriate default is used.
4352        @type where: C{None} or L{escript.FunctionSpace}
4353        @return: divergence of arg.
4354        @rtype:  L{escript.Data} or L{Symbol}
4355      """      """
     if axis==None:  
        r=0  
        if hasattr(arg,"getRank"): r=arg.getRank()  
        if hasattr(arg,"rank"): r=arg.rank  
        axis=r/2  
4356      if isinstance(arg,Symbol):      if isinstance(arg,Symbol):
4357         return Transpose_Symbol(arg,axis=r)          dim=arg.getDim()
4358      if isinstance(arg,escript.Data):      elif isinstance(arg,escript.Data):
4359         # hack for transpose          dim=arg.getDomain().getDim()
        r=arg.getRank()  
        if r!=2: raise ValueError,"Tranpose only avalaible for rank 2 objects"  
        s=arg.getShape()  
        out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())  
        for i in range(s[0]):  
           for j in range(s[1]):  
              out[j,i]=arg[i,j]  
        return out  
        # end hack for transpose  
        return arg.transpose(axis)  
4360      else:      else:
4361         return numarray.transpose(arg,axis=axis)          raise TypeError,"div: argument type not supported"
4362        if not arg.getShape()==(dim,):
4363          raise ValueError,"div: expected shape is (%s,)"%dim
4364        return trace(grad(arg,where))
4365    
4366  def trace(arg,axis0=0,axis1=1):  def jump(arg,domain=None):
4367      """      """
4368      Return      returns the jump of arg across the continuity of the domain
4369    
4370      @param arg:      @param arg: argument
4371        @type arg: L{escript.Data} or L{Symbol}
4372        @param domain: the domain where the discontinuity is located. If domain is not present or equal to C{None}
4373                       the domain of arg is used. If arg is a L{Symbol} the domain must be present.
4374        @type domain: C{None} or L{escript.Domain}
4375        @return: jump of arg
4376        @rtype:  L{escript.Data} or L{Symbol}
4377      """      """
4378      if isinstance(arg,Symbol):      if domain==None: domain=arg.getDomain()
4379         s=list(arg.getShape())              return interpolate(arg,escript.FunctionOnContactOne(domain))-interpolate(arg,escript.FunctionOnContactZero(domain))
        s=tuple(s[0:axis0]+s[axis0+1:axis1]+s[axis1+1:])  
        return Trace_Symbol(arg,axis0=axis0,axis1=axis1)  
     elif isinstance(arg,escript.Data):  
        # hack for trace  
        s=arg.getShape()  
        if s[axis0]!=s[axis1]:  
            raise ValueError,"illegal axis in trace"  
        out=escript.Scalar(0.,arg.getFunctionSpace())  
        for i in range(s[axis0]):  
           out+=arg[i,i]  
        return out  
        # end hack for trace  
     else:  
        return numarray.trace(arg,axis0=axis0,axis1=axis1)  
4380    
4381    def L2(arg):
4382        """
4383        returns the L2 norm of arg at where
4384        
4385        @param arg: function which L2 to be calculated.
4386        @type arg: L{escript.Data} or L{Symbol}
4387        @return: L2 norm of arg.
4388        @rtype:  L{float} or L{Symbol}
4389        @note: L2(arg) is equivalent to sqrt(integrate(inner(arg,arg)))
4390        """
4391        return sqrt(integrate(inner(arg,arg)))
4392    #=============================
4393    #
4394    
4395  def reorderComponents(arg,index):  def reorderComponents(arg,index):
4396      """      """
4397      resorts the component of arg according to index      resorts the component of arg according to index
4398    
4399      """      """
4400      pass      raise NotImplementedError
4401  #  #
4402  # $Log: util.py,v $  # $Log: util.py,v $
4403  # Revision 1.14.2.16  2005/10/19 06:09:57  gross  # Revision 1.14.2.16  2005/10/19 06:09:57  gross

Legend:
Removed from v.341  
changed lines
  Added in v.492

  ViewVC Help
Powered by ViewVC 1.1.26