/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Diff of /trunk/escript/py_src/util.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 341 by gross, Mon Dec 12 05:26:10 2005 UTC revision 530 by gross, Wed Feb 15 07:11:00 2006 UTC
# Line 24  Utility functions for escript Line 24  Utility functions for escript
24  __author__="Lutz Gross, l.gross@uq.edu.au"  __author__="Lutz Gross, l.gross@uq.edu.au"
25  __licence__="contact: esys@access.uq.edu.au"  __licence__="contact: esys@access.uq.edu.au"
26  __url__="http://www.iservo.edu.au/esys/escript"  __url__="http://www.iservo.edu.au/esys/escript"
27  __version__="$Revision: 329 $"  __version__="$Revision$"
28  __date__="$Date$"  __date__="$Date$"
29    
30    
31  import math  import math
32  import numarray  import numarray
33    import numarray.linear_algebra
34  import escript  import escript
35  import os  import os
36    
# Line 43  import os Line 44  import os
44  # def matchType(arg0=0.,arg1=0.):  # def matchType(arg0=0.,arg1=0.):
45  # def matchShape(arg0,arg1):  # def matchShape(arg0,arg1):
46    
 # def maximum(arg0,arg1):  
 # def minimum(arg0,arg1):  
   
 # def transpose(arg,axis=None):  
 # def trace(arg,axis0=0,axis1=1):  
47  # def reorderComponents(arg,index):  # def reorderComponents(arg,index):
48    
 # def integrate(arg,where=None):  
 # def interpolate(arg,where):  
 # def div(arg,where=None):  
 # def grad(arg,where=None):  
   
49  #  #
50  # slicing: get  # slicing: get
51  #          set  #          set
# Line 125  def kronecker(d=3): Line 116  def kronecker(d=3):
116     return the kronecker S{delta}-symbol     return the kronecker S{delta}-symbol
117    
118     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
119     @type d: C{int} or any object with a C{getDim} method     @type d: C{int}, L{escript.Domain} or L{escript.FunctionSpace}
120     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise
121     @rtype d: L{numarray.NumArray} of rank 2.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 2.
    @remark: the function is identical L{identity}  
122     """     """
123     return identityTensor(d)     return identityTensor(d)
124    
# Line 147  def identity(shape=()): Line 137  def identity(shape=()):
137        if len(shape)==1:        if len(shape)==1:
138            for i0 in range(shape[0]):            for i0 in range(shape[0]):
139               out[i0,i0]=1.               out[i0,i0]=1.
   
140        elif len(shape)==2:        elif len(shape)==2:
141            for i0 in range(shape[0]):            for i0 in range(shape[0]):
142               for i1 in range(shape[1]):               for i1 in range(shape[1]):
# Line 163  def identityTensor(d=3): Line 152  def identityTensor(d=3):
152     return the dxd identity matrix     return the dxd identity matrix
153    
154     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
155     @type d: C{int} or any object with a C{getDim} method     @type d: C{int}, L{escript.Domain} or L{escript.FunctionSpace}
156     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise
157     @rtype: L{numarray.NumArray} of rank 2.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 2
158     """     """
159     if hasattr(d,"getDim"):     if isinstance(d,escript.FunctionSpace):
160        d=d.getDim()         return escript.Data(identity((d.getDim(),)),d)
161     return identity(shape=(d,))     elif isinstance(d,escript.Domain):
162           return identity((d.getDim(),))
163       else:
164           return identity((d,))
165    
166  def identityTensor4(d=3):  def identityTensor4(d=3):
167     """     """
# Line 178  def identityTensor4(d=3): Line 170  def identityTensor4(d=3):
170     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
171     @type d: C{int} or any object with a C{getDim} method     @type d: C{int} or any object with a C{getDim} method
172     @return: the object u of rank 4 with M{u[i,j,k,l]=1} for M{i=k and j=l} and M{u[i,j,k,l]=0} otherwise     @return: the object u of rank 4 with M{u[i,j,k,l]=1} for M{i=k and j=l} and M{u[i,j,k,l]=0} otherwise
173     @rtype: L{numarray.NumArray} of rank 4.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 4.
174     """     """
175     if hasattr(d,"getDim"):     if isinstance(d,escript.FunctionSpace):
176        d=d.getDim()         return escript.Data(identity((d.getDim(),d.getDim())),d)
177     return identity((d,d))     elif isinstance(d,escript.Domain):
178           return identity((d.getDim(),d.getDim()))
179       else:
180           return identity((d,d))
181    
182  def unitVector(i=0,d=3):  def unitVector(i=0,d=3):
183     """     """
# Line 191  def unitVector(i=0,d=3): Line 186  def unitVector(i=0,d=3):
186     @param i: index     @param i: index
187     @type i: C{int}     @type i: C{int}
188     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
189     @type d: C{int} or any object with a C{getDim} method     @type d: C{int}, L{escript.Domain} or L{escript.FunctionSpace}
190     @return: the object u of rank 1 with M{u[j]=1} for M{j=i} and M{u[i]=0} otherwise     @return: the object u of rank 1 with M{u[j]=1} for M{j=i} and M{u[i]=0} otherwise
191     @rtype: L{numarray.NumArray} of rank 1.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 1
192     """     """
193     return kronecker(d)[i]     return kronecker(d)[i]
194    
# Line 363  def testForZero(arg): Line 358  def testForZero(arg):
358      @return : True if the argument is identical to zero.      @return : True if the argument is identical to zero.
359      @rtype : C{bool}      @rtype : C{bool}
360      """      """
361      try:      if isinstance(arg,numarray.NumArray):
362           return not Lsup(arg)>0.
363        elif isinstance(arg,escript.Data):
364           return False
365        elif isinstance(arg,float):
366           return not Lsup(arg)>0.
367        elif isinstance(arg,int):
368         return not Lsup(arg)>0.         return not Lsup(arg)>0.
369      except TypeError:      elif isinstance(arg,Symbol):
370           return False
371        else:
372         return False         return False
373    
374  def matchType(arg0=0.,arg1=0.):  def matchType(arg0=0.,arg1=0.):
# Line 825  class Symbol(object): Line 828  class Symbol(object):
828         """         """
829         return power(other,self)         return power(other,self)
830    
831       def __getitem__(self,index):
832           """
833           returns the slice defined by index
834    
835           @param index: defines a
836           @type index: C{slice} or C{int} or a C{tuple} of them
837           @return: a S{Symbol} representing the slice defined by index
838           @rtype: L{DependendSymbol}
839           """
840           return GetSlice_Symbol(self,index)
841    
842  class DependendSymbol(Symbol):  class DependendSymbol(Symbol):
843     """     """
844     DependendSymbol extents L{Symbol} by modifying the == operator to allow two instances to be equal.     DependendSymbol extents L{Symbol} by modifying the == operator to allow two instances to be equal.
# Line 875  class DependendSymbol(Symbol): Line 889  class DependendSymbol(Symbol):
889  #=========================================================  #=========================================================
890  #  Unary operations prserving the shape  #  Unary operations prserving the shape
891  #========================================================  #========================================================
892    class GetSlice_Symbol(DependendSymbol):
893       """
894       L{Symbol} representing getting a slice for a L{Symbol}
895       """
896       def __init__(self,arg,index):
897          """
898          initialization of wherePositive L{Symbol} with argument arg
899          @param arg: argument
900          @type arg: L{Symbol}.
901          @param index: defines index
902          @type index: C{slice} or C{int} or a C{tuple} of them
903          @raises IndexError: if length of index is larger than rank of arg or a index start or stop is out of range
904          @raises ValueError: if a step is given
905          """
906          if not isinstance(index,tuple): index=(index,)
907          if len(index)>arg.getRank():
908               raise IndexError,"GetSlice_Symbol: index out of range."
909          sh=()
910          index2=()
911          for i in range(len(index)):
912             ix=index[i]
913             if isinstance(ix,int):
914                if ix<0 or ix>=arg.getShape()[i]:
915                   raise ValueError,"GetSlice_Symbol: index out of range."
916                index2=index2+(ix,)
917             else:
918               if not ix.step==None:
919                 raise ValueError,"GetSlice_Symbol: steping is not supported."
920               if ix.start==None:
921                  s=0
922               else:
923                  s=ix.start
924               if ix.stop==None:
925                  e=arg.getShape()[i]
926               else:
927                  e=ix.stop
928                  if e>arg.getShape()[i]:
929                     raise IndexError,"GetSlice_Symbol: index out of range."
930               index2=index2+(slice(s,e),)
931               if e>s:
932                   sh=sh+(e-s,)
933               elif s>e:
934                   raise IndexError,"GetSlice_Symbol: slice start must be less or equal slice end"
935          for i in range(len(index),arg.getRank()):
936              index2=index2+(slice(0,arg.getShape()[i]),)
937              sh=sh+(arg.getShape()[i],)
938          super(GetSlice_Symbol, self).__init__(args=[arg,index2],shape=sh,dim=arg.getDim())
939    
940       def getMyCode(self,argstrs,format="escript"):
941          """
942          returns a program code that can be used to evaluate the symbol.
943    
944          @param argstrs: gives for each argument a string representing the argument for the evaluation.
945          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
946          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
947          @type format: C{str}
948          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
949          @rtype: C{str}
950          @raise: NotImplementedError: if the requested format is not available
951          """
952          if format=="escript" or format=="str"  or format=="text":
953             return "%s.__getitem__(%s)"%(argstrs[0],argstrs[1])
954          else:
955             raise NotImplementedError,"GetItem_Symbol does not provide program code for format %s."%format
956    
957       def substitute(self,argvals):
958          """
959          assigns new values to symbols in the definition of the symbol.
960          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
961    
962          @param argvals: new values assigned to symbols
963          @type argvals: C{dict} with keywords of type L{Symbol}.
964          @return: result of the substitution process. Operations are executed as much as possible.
965          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
966          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
967          """
968          if argvals.has_key(self):
969             arg=argvals[self]
970             if self.isAppropriateValue(arg):
971                return arg
972             else:
973                raise TypeError,"%s: new value is not appropriate."%str(self)
974          else:
975             args=self.getSubstitutedArguments(argvals)
976             arg=args[0]
977             index=args[1]
978             return arg.__getitem__(index)
979    
980  def log10(arg):  def log10(arg):
981     """     """
982     returns base-10 logarithm of argument arg     returns base-10 logarithm of argument arg
# Line 907  def wherePositive(arg): Line 1009  def wherePositive(arg):
1009     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1010     """     """
1011     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1012        if arg.rank==0:        out=numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))*1.
1013           if arg>0:        if isinstance(out,float): out=numarray.array(out)
1014             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))  
1015     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1016        return arg._wherePositive()        return arg._wherePositive()
1017     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 993  def whereNegative(arg): Line 1091  def whereNegative(arg):
1091     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1092     """     """
1093     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1094        if arg.rank==0:        out=numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))*1.
1095           if arg<0:        if isinstance(out,float): out=numarray.array(out)
1096             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))  
1097     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1098        return arg._whereNegative()        return arg._whereNegative()
1099     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1079  def whereNonNegative(arg): Line 1173  def whereNonNegative(arg):
1173     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1174     """     """
1175     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1176        if arg.rank==0:        out=numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.
1177           if arg<0:        if isinstance(out,float): out=numarray.array(out)
1178             return numarray.array(0.)        return out
          else:  
            return numarray.array(1.)  
       else:  
          return numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float))  
1179     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1180        return arg._whereNonNegative()        return arg._whereNonNegative()
1181     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1113  def whereNonPositive(arg): Line 1203  def whereNonPositive(arg):
1203     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1204     """     """
1205     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1206        if arg.rank==0:        out=numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.
1207           if arg>0:        if isinstance(out,float): out=numarray.array(out)
1208             return numarray.array(0.)        return out
          else:  
            return numarray.array(1.)  
       else:  
          return numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.  
1209     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1210        return arg._whereNonPositive()        return arg._whereNonPositive()
1211     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1149  def whereZero(arg,tol=0.): Line 1235  def whereZero(arg,tol=0.):
1235     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1236     """     """
1237     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1238        if arg.rank==0:        out=numarray.less_equal(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.
1239           if abs(arg)<=tol:        if isinstance(out,float): out=numarray.array(out)
1240             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.less_equal(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.  
1241     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1242        if tol>0.:        if tol>0.:
1243           return whereNegative(abs(arg)-tol)           return whereNegative(abs(arg)-tol)
# Line 1236  def whereNonZero(arg,tol=0.): Line 1318  def whereNonZero(arg,tol=0.):
1318     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1319     """     """
1320     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1321        if arg.rank==0:        out=numarray.greater(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.
1322          if abs(arg)>tol:        if isinstance(out,float): out=numarray.array(out)
1323             return numarray.array(1.)        return out
         else:  
            return numarray.array(0.)  
       else:  
          return numarray.greater(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.  
1324     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1325        if tol>0.:        if tol>0.:
1326           return 1.-whereZero(arg,tol)           return 1.-whereZero(arg,tol)
# Line 2877  def length(arg): Line 2955  def length(arg):
2955     """     """
2956     return sqrt(inner(arg,arg))     return sqrt(inner(arg,arg))
2957    
2958    def trace(arg,axis_offset=0):
2959       """
2960       returns the trace of arg which the sum of arg[k,k] over k.
2961    
2962       @param arg: argument
2963       @type arg: L{escript.Data}, L{Symbol}, L{numarray.NumArray}.
2964       @param axis_offset: axis_offset to components to sum over. C{axis_offset} must be non-negative and less than the rank of arg +1. The dimensions on component
2965                      axis_offset and axis_offset+1 must be equal.
2966       @type axis_offset: C{int}
2967       @return: trace of arg. The rank of the returned object is minus 2 of the rank of arg.
2968       @rtype: L{escript.Data}, L{Symbol}, L{numarray.NumArray} depending on the type of arg.
2969       """
2970       if isinstance(arg,numarray.NumArray):
2971          sh=arg.shape
2972          if len(sh)<2:
2973            raise ValueError,"trace: rank of argument must be greater than 1"
2974          if axis_offset<0 or axis_offset>len(sh)-2:
2975            raise ValueError,"trace: axis_offset must be between 0 and %s"%len(sh)-2
2976          s1=1
2977          for i in range(axis_offset): s1*=sh[i]
2978          s2=1
2979          for i in range(axis_offset+2,len(sh)): s2*=sh[i]
2980          if not sh[axis_offset] == sh[axis_offset+1]:
2981            raise ValueError,"trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2982          arg_reshaped=numarray.reshape(arg,(s1,sh[axis_offset],sh[axis_offset],s2))
2983          out=numarray.zeros([s1,s2],numarray.Float)
2984          for i1 in range(s1):
2985            for i2 in range(s2):
2986                for j in range(sh[axis_offset]): out[i1,i2]+=arg_reshaped[i1,j,j,i2]
2987          out.resize(sh[:axis_offset]+sh[axis_offset+2:])
2988          return out
2989       elif isinstance(arg,escript.Data):
2990          return escript_trace(arg,axis_offset)
2991       elif isinstance(arg,float):
2992          raise TypeError,"trace: illegal argument type float."
2993       elif isinstance(arg,int):
2994          raise TypeError,"trace: illegal argument type int."
2995       elif isinstance(arg,Symbol):
2996          return Trace_Symbol(arg,axis_offset)
2997       else:
2998          raise TypeError,"trace: Unknown argument type."
2999    
3000    def escript_trace(arg,axis_offset): # this should be escript._trace
3001          "arg si a Data objects!!!"
3002          if arg.getRank()<2:
3003            raise ValueError,"escript_trace: rank of argument must be greater than 1"
3004          if axis_offset<0 or axis_offset>arg.getRank()-2:
3005            raise ValueError,"escript_trace: axis_offset must be between 0 and %s"%arg.getRank()-2
3006          s=list(arg.getShape())        
3007          if not s[axis_offset] == s[axis_offset+1]:
3008            raise ValueError,"escript_trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
3009          out=escript.Data(0.,tuple(s[0:axis_offset]+s[axis_offset+2:]),arg.getFunctionSpace())
3010          if arg.getRank()==2:
3011             for i0 in range(s[0]):
3012                out+=arg[i0,i0]
3013          elif arg.getRank()==3:
3014             if axis_offset==0:
3015                for i0 in range(s[0]):
3016                      for i2 in range(s[2]):
3017                             out[i2]+=arg[i0,i0,i2]
3018             elif axis_offset==1:
3019                for i0 in range(s[0]):
3020                   for i1 in range(s[1]):
3021                             out[i0]+=arg[i0,i1,i1]
3022          elif arg.getRank()==4:
3023             if axis_offset==0:
3024                for i0 in range(s[0]):
3025                      for i2 in range(s[2]):
3026                         for i3 in range(s[3]):
3027                             out[i2,i3]+=arg[i0,i0,i2,i3]
3028             elif axis_offset==1:
3029                for i0 in range(s[0]):
3030                   for i1 in range(s[1]):
3031                         for i3 in range(s[3]):
3032                             out[i0,i3]+=arg[i0,i1,i1,i3]
3033             elif axis_offset==2:
3034                for i0 in range(s[0]):
3035                   for i1 in range(s[1]):
3036                      for i2 in range(s[2]):
3037                             out[i0,i1]+=arg[i0,i1,i2,i2]
3038          return out
3039    class Trace_Symbol(DependendSymbol):
3040       """
3041       L{Symbol} representing the result of the trace function
3042       """
3043       def __init__(self,arg,axis_offset=0):
3044          """
3045          initialization of trace L{Symbol} with argument arg
3046          @param arg: argument of function
3047          @type arg: L{Symbol}.
3048          @param axis_offset: axis_offset to components to sum over. C{axis_offset} must be non-negative and less than the rank of arg +1. The dimensions on component
3049                      axis_offset and axis_offset+1 must be equal.
3050          @type axis_offset: C{int}
3051          """
3052          if arg.getRank()<2:
3053            raise ValueError,"Trace_Symbol: rank of argument must be greater than 1"
3054          if axis_offset<0 or axis_offset>arg.getRank()-2:
3055            raise ValueError,"Trace_Symbol: axis_offset must be between 0 and %s"%arg.getRank()-2
3056          s=list(arg.getShape())        
3057          if not s[axis_offset] == s[axis_offset+1]:
3058            raise ValueError,"Trace_Symbol: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
3059          super(Trace_Symbol,self).__init__(args=[arg,axis_offset],shape=tuple(s[0:axis_offset]+s[axis_offset+2:]),dim=arg.getDim())
3060    
3061       def getMyCode(self,argstrs,format="escript"):
3062          """
3063          returns a program code that can be used to evaluate the symbol.
3064    
3065          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3066          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3067          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3068          @type format: C{str}
3069          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3070          @rtype: C{str}
3071          @raise: NotImplementedError: if the requested format is not available
3072          """
3073          if format=="escript" or format=="str"  or format=="text":
3074             return "trace(%s,axis_offset=%s)"%(argstrs[0],argstrs[1])
3075          else:
3076             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
3077    
3078       def substitute(self,argvals):
3079          """
3080          assigns new values to symbols in the definition of the symbol.
3081          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3082    
3083          @param argvals: new values assigned to symbols
3084          @type argvals: C{dict} with keywords of type L{Symbol}.
3085          @return: result of the substitution process. Operations are executed as much as possible.
3086          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3087          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3088          """
3089          if argvals.has_key(self):
3090             arg=argvals[self]
3091             if self.isAppropriateValue(arg):
3092                return arg
3093             else:
3094                raise TypeError,"%s: new value is not appropriate."%str(self)
3095          else:
3096             arg=self.getSubstitutedArguments(argvals)
3097             return trace(arg[0],axis_offset=arg[1])
3098    
3099       def diff(self,arg):
3100          """
3101          differential of this object
3102    
3103          @param arg: the derivative is calculated with respect to arg
3104          @type arg: L{escript.Symbol}
3105          @return: derivative with respect to C{arg}
3106          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3107          """
3108          if arg==self:
3109             return identity(self.getShape())
3110          else:
3111             return trace(self.getDifferentiatedArguments(arg)[0],axis_offset=self.getArgument()[1])
3112    
3113    def transpose(arg,axis_offset=None):
3114       """
3115       returns the transpose of arg by swaping the first axis_offset and the last rank-axis_offset components.
3116    
3117       @param arg: argument
3118       @type arg: L{escript.Data}, L{Symbol}, L{numarray.NumArray}, C{float}, C{int}
3119       @param axis_offset: the first axis_offset components are swapped with rest. If C{axis_offset} must be non-negative and less or equal the rank of arg.
3120                           if axis_offset is not present C{int(r/2)} where r is the rank of arg is used.
3121       @type axis_offset: C{int}
3122       @return: transpose of arg
3123       @rtype: L{escript.Data}, L{Symbol}, L{numarray.NumArray},C{float}, C{int} depending on the type of arg.
3124       """
3125       if isinstance(arg,numarray.NumArray):
3126          if axis_offset==None: axis_offset=int(arg.rank/2)
3127          return numarray.transpose(arg,axes=range(axis_offset,arg.rank)+range(0,axis_offset))
3128       elif isinstance(arg,escript.Data):
3129          if axis_offset==None: axis_offset=int(arg.getRank()/2)
3130          return escript_transpose(arg,axis_offset)
3131       elif isinstance(arg,float):
3132          if not ( axis_offset==0 or axis_offset==None):
3133            raise ValueError,"transpose: axis_offset must be 0 for float argument"
3134          return arg
3135       elif isinstance(arg,int):
3136          if not ( axis_offset==0 or axis_offset==None):
3137            raise ValueError,"transpose: axis_offset must be 0 for int argument"
3138          return float(arg)
3139       elif isinstance(arg,Symbol):
3140          if axis_offset==None: axis_offset=int(arg.getRank()/2)
3141          return Transpose_Symbol(arg,axis_offset)
3142       else:
3143          raise TypeError,"transpose: Unknown argument type."
3144    
3145    def escript_transpose(arg,axis_offset): # this should be escript._transpose
3146          "arg si a Data objects!!!"
3147          r=arg.getRank()
3148          if axis_offset<0 or axis_offset>r:
3149            raise ValueError,"escript_transpose: axis_offset must be between 0 and %s"%r
3150          s=arg.getShape()
3151          s_out=s[axis_offset:]+s[:axis_offset]
3152          out=escript.Data(0.,s_out,arg.getFunctionSpace())
3153          if r==4:
3154             if axis_offset==1:
3155                for i0 in range(s_out[0]):
3156                   for i1 in range(s_out[1]):
3157                      for i2 in range(s_out[2]):
3158                         for i3 in range(s_out[3]):
3159                             out[i0,i1,i2,i3]=arg[i3,i0,i1,i2]
3160             elif axis_offset==2:
3161                for i0 in range(s_out[0]):
3162                   for i1 in range(s_out[1]):
3163                      for i2 in range(s_out[2]):
3164                         for i3 in range(s_out[3]):
3165                             out[i0,i1,i2,i3]=arg[i2,i3,i0,i1]
3166             elif axis_offset==3:
3167                for i0 in range(s_out[0]):
3168                   for i1 in range(s_out[1]):
3169                      for i2 in range(s_out[2]):
3170                         for i3 in range(s_out[3]):
3171                             out[i0,i1,i2,i3]=arg[i1,i2,i3,i0]
3172             else:
3173                for i0 in range(s_out[0]):
3174                   for i1 in range(s_out[1]):
3175                      for i2 in range(s_out[2]):
3176                         for i3 in range(s_out[3]):
3177                             out[i0,i1,i2,i3]=arg[i0,i1,i2,i3]
3178          elif r==3:
3179             if axis_offset==1:
3180                for i0 in range(s_out[0]):
3181                   for i1 in range(s_out[1]):
3182                      for i2 in range(s_out[2]):
3183                             out[i0,i1,i2]=arg[i2,i0,i1]
3184             elif axis_offset==2:
3185                for i0 in range(s_out[0]):
3186                   for i1 in range(s_out[1]):
3187                      for i2 in range(s_out[2]):
3188                             out[i0,i1,i2]=arg[i1,i2,i0]
3189             else:
3190                for i0 in range(s_out[0]):
3191                   for i1 in range(s_out[1]):
3192                      for i2 in range(s_out[2]):
3193                             out[i0,i1,i2]=arg[i0,i1,i2]
3194          elif r==2:
3195             if axis_offset==1:
3196                for i0 in range(s_out[0]):
3197                   for i1 in range(s_out[1]):
3198                             out[i0,i1]=arg[i1,i0]
3199             else:
3200                for i0 in range(s_out[0]):
3201                   for i1 in range(s_out[1]):
3202                             out[i0,i1]=arg[i0,i1]
3203          elif r==1:
3204              for i0 in range(s_out[0]):
3205                   out[i0]=arg[i0]
3206          elif r==0:
3207                 out=arg+0.
3208          return out
3209    class Transpose_Symbol(DependendSymbol):
3210       """
3211       L{Symbol} representing the result of the transpose function
3212       """
3213       def __init__(self,arg,axis_offset=None):
3214          """
3215          initialization of transpose L{Symbol} with argument arg
3216    
3217          @param arg: argument of function
3218          @type arg: L{Symbol}.
3219           @param axis_offset: the first axis_offset components are swapped with rest. If C{axis_offset} must be non-negative and less or equal the rank of arg.
3220                           if axis_offset is not present C{int(r/2)} where r is the rank of arg is used.
3221          @type axis_offset: C{int}
3222          """
3223          if axis_offset==None: axis_offset=int(arg.getRank()/2)
3224          if axis_offset<0 or axis_offset>arg.getRank():
3225            raise ValueError,"escript_transpose: axis_offset must be between 0 and %s"%r
3226          s=arg.getShape()
3227          super(Transpose_Symbol,self).__init__(args=[arg,axis_offset],shape=s[axis_offset:]+s[:axis_offset],dim=arg.getDim())
3228    
3229       def getMyCode(self,argstrs,format="escript"):
3230          """
3231          returns a program code that can be used to evaluate the symbol.
3232    
3233          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3234          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3235          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3236          @type format: C{str}
3237          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3238          @rtype: C{str}
3239          @raise: NotImplementedError: if the requested format is not available
3240          """
3241          if format=="escript" or format=="str"  or format=="text":
3242             return "transpose(%s,axis_offset=%s)"%(argstrs[0],argstrs[1])
3243          else:
3244             raise NotImplementedError,"Transpose_Symbol does not provide program code for format %s."%format
3245    
3246       def substitute(self,argvals):
3247          """
3248          assigns new values to symbols in the definition of the symbol.
3249          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3250    
3251          @param argvals: new values assigned to symbols
3252          @type argvals: C{dict} with keywords of type L{Symbol}.
3253          @return: result of the substitution process. Operations are executed as much as possible.
3254          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3255          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3256          """
3257          if argvals.has_key(self):
3258             arg=argvals[self]
3259             if self.isAppropriateValue(arg):
3260                return arg
3261             else:
3262                raise TypeError,"%s: new value is not appropriate."%str(self)
3263          else:
3264             arg=self.getSubstitutedArguments(argvals)
3265             return transpose(arg[0],axis_offset=arg[1])
3266    
3267       def diff(self,arg):
3268          """
3269          differential of this object
3270    
3271          @param arg: the derivative is calculated with respect to arg
3272          @type arg: L{escript.Symbol}
3273          @return: derivative with respect to C{arg}
3274          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3275          """
3276          if arg==self:
3277             return identity(self.getShape())
3278          else:
3279             return transpose(self.getDifferentiatedArguments(arg)[0],axis_offset=self.getArgument()[1])
3280    
3281    def inverse(arg):
3282        """
3283        returns the inverse of the square matrix arg.
3284    
3285        @param arg: square matrix. Must have rank 2 and the first and second dimension must be equal.
3286        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}
3287        @return: inverse arg_inv of the argument. It will be matrixmul(inverse(arg),arg) almost equal to kronecker(arg.getShape()[0])
3288        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
3289        @remark: for L{escript.Data} objects the dimension is restricted to 3.
3290        """
3291        if isinstance(arg,numarray.NumArray):
3292          return numarray.linear_algebra.inverse(arg)
3293        elif isinstance(arg,escript.Data):
3294          return escript_inverse(arg)
3295        elif isinstance(arg,float):
3296          return 1./arg
3297        elif isinstance(arg,int):
3298          return 1./float(arg)
3299        elif isinstance(arg,Symbol):
3300          return Inverse_Symbol(arg)
3301        else:
3302          raise TypeError,"inverse: Unknown argument type."
3303    
3304    def escript_inverse(arg): # this should be escript._inverse and use LAPACK
3305          "arg is a Data objects!!!"
3306          if not arg.getRank()==2:
3307            raise ValueError,"escript_inverse: argument must have rank 2"
3308          s=arg.getShape()      
3309          if not s[0] == s[1]:
3310            raise ValueError,"escript_inverse: argument must be a square matrix."
3311          out=escript.Data(0.,s,arg.getFunctionSpace())
3312          if s[0]==1:
3313              if inf(abs(arg[0,0]))==0: # in c this should be done point wise as abs(arg[0,0](i))<=0.
3314                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3315              out[0,0]=1./arg[0,0]
3316          elif s[0]==2:
3317              A11=arg[0,0]
3318              A12=arg[0,1]
3319              A21=arg[1,0]
3320              A22=arg[1,1]
3321              D = A11*A22-A12*A21
3322              if inf(abs(D))==0: # in c this should be done point wise as abs(D(i))<=0.
3323                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3324              D=1./D
3325              out[0,0]= A22*D
3326              out[1,0]=-A21*D
3327              out[0,1]=-A12*D
3328              out[1,1]= A11*D
3329          elif s[0]==3:
3330              A11=arg[0,0]
3331              A21=arg[1,0]
3332              A31=arg[2,0]
3333              A12=arg[0,1]
3334              A22=arg[1,1]
3335              A32=arg[2,1]
3336              A13=arg[0,2]
3337              A23=arg[1,2]
3338              A33=arg[2,2]
3339              D  =  A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
3340              if inf(abs(D))==0: # in c this should be done point wise as abs(D(i))<=0.
3341                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3342              D=1./D
3343              out[0,0]=(A22*A33-A23*A32)*D
3344              out[1,0]=(A31*A23-A21*A33)*D
3345              out[2,0]=(A21*A32-A31*A22)*D
3346              out[0,1]=(A13*A32-A12*A33)*D
3347              out[1,1]=(A11*A33-A31*A13)*D
3348              out[2,1]=(A12*A31-A11*A32)*D
3349              out[0,2]=(A12*A23-A13*A22)*D
3350              out[1,2]=(A13*A21-A11*A23)*D
3351              out[2,2]=(A11*A22-A12*A21)*D
3352          else:
3353             raise TypeError,"escript_inverse: only matrix dimensions 1,2,3 are supported right now."
3354          return out
3355    
3356    class Inverse_Symbol(DependendSymbol):
3357       """
3358       L{Symbol} representing the result of the inverse function
3359       """
3360       def __init__(self,arg):
3361          """
3362          initialization of inverse L{Symbol} with argument arg
3363          @param arg: argument of function
3364          @type arg: L{Symbol}.
3365          """
3366          if not arg.getRank()==2:
3367            raise ValueError,"Inverse_Symbol:: argument must have rank 2"
3368          s=arg.getShape()
3369          if not s[0] == s[1]:
3370            raise ValueError,"Inverse_Symbol:: argument must be a square matrix."
3371          super(Inverse_Symbol,self).__init__(args=[arg],shape=s,dim=arg.getDim())
3372    
3373       def getMyCode(self,argstrs,format="escript"):
3374          """
3375          returns a program code that can be used to evaluate the symbol.
3376    
3377          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3378          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3379          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3380          @type format: C{str}
3381          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3382          @rtype: C{str}
3383          @raise: NotImplementedError: if the requested format is not available
3384          """
3385          if format=="escript" or format=="str"  or format=="text":
3386             return "inverse(%s)"%argstrs[0]
3387          else:
3388             raise NotImplementedError,"Inverse_Symbol does not provide program code for format %s."%format
3389    
3390       def substitute(self,argvals):
3391          """
3392          assigns new values to symbols in the definition of the symbol.
3393          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3394    
3395          @param argvals: new values assigned to symbols
3396          @type argvals: C{dict} with keywords of type L{Symbol}.
3397          @return: result of the substitution process. Operations are executed as much as possible.
3398          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3399          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3400          """
3401          if argvals.has_key(self):
3402             arg=argvals[self]
3403             if self.isAppropriateValue(arg):
3404                return arg
3405             else:
3406                raise TypeError,"%s: new value is not appropriate."%str(self)
3407          else:
3408             arg=self.getSubstitutedArguments(argvals)
3409             return inverse(arg[0])
3410    
3411       def diff(self,arg):
3412          """
3413          differential of this object
3414    
3415          @param arg: the derivative is calculated with respect to arg
3416          @type arg: L{escript.Symbol}
3417          @return: derivative with respect to C{arg}
3418          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3419          """
3420          if arg==self:
3421             return identity(self.getShape())
3422          else:
3423             return -matrixmult(matrixmult(self,self.getDifferentiatedArguments(arg)[0]),self)
3424    
3425    def eigenvalues(arg):
3426        """
3427        returns the eigenvalues of the square matrix arg.
3428    
3429        @param arg: square matrix. Must have rank 2 and the first and second dimension must be equal.
3430                    arg must be symmetric, ie. transpose(arg)==arg (this is not checked).
3431        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}
3432        @return: the eigenvalues in increasing order.
3433        @rtype: L{numarray.NumArray},L{escript.Data}, L{Symbol} depending on the input.
3434        @remark: for L{escript.Data} and L{Symbol} objects the dimension is restricted to 3.
3435        """
3436        if isinstance(arg,numarray.NumArray):
3437          if not arg.rank==2:
3438            raise ValueError,"eigenvalues: argument must have rank 2"
3439          s=arg.shape      
3440          if not s[0] == s[1]:
3441            raise ValueError,"eigenvalues: argument must be a square matrix."
3442          out=numarray.linear_algebra.eigenvalues((arg+numarray.transpose(arg))/2.)
3443          out.sort()
3444          return out
3445        elif isinstance(arg,escript.Data):
3446          return escript_eigenvalues(arg)
3447        elif isinstance(arg,Symbol):
3448          if not arg.getRank()==2:
3449            raise ValueError,"eigenvalues: argument must have rank 2"
3450          s=arg.getShape()      
3451          if not s[0] == s[1]:
3452            raise ValueError,"eigenvalues: argument must be a square matrix."
3453          if s[0]==1:
3454              return arg[0]
3455          elif s[0]==2:
3456              A11=arg[0,0]
3457              A12=arg[0,1]
3458              A22=arg[1,1]
3459              trA=(A11+A22)/2.
3460              A11-=trA
3461              A22-=trA
3462              s=sqrt(A12**2-A11*A22)
3463              return trA+s*numarray.array([-1.,1.])
3464          elif s[0]==3:
3465              A11=arg[0,0]
3466              A12=arg[0,1]
3467              A22=arg[1,1]
3468              A13=arg[0,2]
3469              A23=arg[1,2]
3470              A33=arg[2,2]
3471              trA=(A11+A22+A33)/3.
3472              A11-=trA
3473              A22-=trA
3474              A33-=trA
3475              A13_2=A13**2
3476              A23_2=A23**2
3477              A12_2=A12**2
3478              p=A13_2+A23_2+A12_2+(A11**2+A22**2+A33**2)/2.
3479              q=A13_2*A22+A23_2*A11+A12_2*A33-A11*A22*A33-2*A12*A23*A13
3480              sq_p=sqrt(p/3.)
3481              alpha_3=acos(-q*sq_p**(-3.)/2.)/3.
3482              sq_p*=2.
3483              f=cos(alpha_3)               *numarray.array([0.,0.,1.]) \
3484               -cos(alpha_3+numarray.pi/3.)*numarray.array([0.,1.,0.]) \
3485               -cos(alpha_3-numarray.pi/3.)*numarray.array([1.,0.,0.])
3486              return trA+sq_p*f
3487          else:
3488             raise TypeError,"eigenvalues: only matrix dimensions 1,2,3 are supported right now."
3489        elif isinstance(arg,float):
3490          return arg
3491        elif isinstance(arg,int):
3492          return float(arg)
3493        else:
3494          raise TypeError,"eigenvalues: Unknown argument type."
3495    
3496    def escript_eigenvalues(arg): # this should be implemented in C++ arg and LAPACK is data object
3497          if not arg.getRank()==2:
3498            raise ValueError,"eigenvalues: argument must have rank 2"
3499          s=arg.getShape()      
3500          if not s[0] == s[1]:
3501            raise ValueError,"eigenvalues: argument must be a square matrix."
3502          if s[0]==1:
3503              return arg[0]
3504          elif s[0]==2:
3505              A11=arg[0,0]
3506              A12=arg[0,1]
3507              A22=arg[1,1]
3508              trA=(A11+A22)/2.
3509              A11-=trA
3510              A22-=trA
3511              s=sqrt(A12**2-A11*A22)
3512              return trA+s*numarray.array([-1.,1.])
3513          elif s[0]==3:
3514              A11=arg[0,0]
3515              A12=arg[0,1]
3516              A22=arg[1,1]
3517              A13=arg[0,2]
3518              A23=arg[1,2]
3519              A33=arg[2,2]
3520              trA=(A11+A22+A33)/3.
3521              A11-=trA
3522              A22-=trA
3523              A33-=trA
3524              A13_2=A13**2
3525              A23_2=A23**2
3526              A12_2=A12**2
3527              p=A13_2+A23_2+A12_2+(A11**2+A22**2+A33**2)/2.
3528              q=A13_2*A22+A23_2*A11+A12_2*A33-A11*A22*A33-2*A12*A23*A13
3529              sq_p=sqrt(p/3.)
3530              alpha_3=acos(-q*sq_p**(-3.)/2.)/3.
3531              sq_p*=2.
3532              f=escript.Data(0.,(3,),arg.getFunctionSpace())
3533              f[0]=-cos(alpha_3-numarray.pi/3.)
3534              f[1]=-cos(alpha_3+numarray.pi/3.)
3535              f[2]=cos(alpha_3)
3536              return trA+sq_p*f
3537          else:
3538             raise TypeError,"eigenvalues: only matrix dimensions 1,2,3 are supported right now."
3539  #=======================================================  #=======================================================
3540  #  Binary operations:  #  Binary operations:
3541  #=======================================================  #=======================================================
# Line 3304  def maximum(*args): Line 3963  def maximum(*args):
3963         if out==None:         if out==None:
3964            out=a            out=a
3965         else:         else:
3966            m=whereNegative(out-a)            diff=add(a,-out)
3967            out=m*a+(1.-m)*out            out=add(out,mult(wherePositive(diff),diff))
3968      return out      return out
3969        
3970  def minimum(*arg):  def minimum(*args):
3971      """      """
3972      the minimum over arguments args      the minimum over arguments args
3973    
# Line 3322  def minimum(*arg): Line 3981  def minimum(*arg):
3981         if out==None:         if out==None:
3982            out=a            out=a
3983         else:         else:
3984            m=whereNegative(out-a)            diff=add(a,-out)
3985            out=m*out+(1.-m)*a            out=add(out,mult(whereNegative(diff),diff))
3986      return out      return out
3987    
3988    def clip(arg,minval=0.,maxval=1.):
3989        """
3990        cuts the values of arg between minval and maxval
3991    
3992        @param arg: argument
3993        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}, C{int} or C{float}
3994        @param minval: lower range
3995        @type arg: C{float}
3996        @param maxval: upper range
3997        @type arg: C{float}
3998        @return: is on object with all its value between minval and maxval. value of the argument that greater then minval and
3999                 less then maxval are unchanged.
4000        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol}, C{int} or C{float} depending on the input
4001        @raise ValueError: if minval>maxval
4002        """
4003        if minval>maxval:
4004           raise ValueError,"minval = %s must be less then maxval %s"%(minval,maxval)
4005        return minimum(maximum(minval,arg),maxval)
4006    
4007        
4008  def inner(arg0,arg1):  def inner(arg0,arg1):
4009      """      """
# Line 3348  def inner(arg0,arg1): Line 4027  def inner(arg0,arg1):
4027      sh1=pokeShape(arg1)      sh1=pokeShape(arg1)
4028      if not sh0==sh1:      if not sh0==sh1:
4029          raise ValueError,"inner: shape of arguments does not match"          raise ValueError,"inner: shape of arguments does not match"
4030      return generalTensorProduct(arg0,arg1,offset=len(sh0))      return generalTensorProduct(arg0,arg1,axis_offset=len(sh0))
4031    
4032  def matrixmult(arg0,arg1):  def matrixmult(arg0,arg1):
4033      """      """
# Line 3376  def matrixmult(arg0,arg1): Line 4055  def matrixmult(arg0,arg1):
4055          raise ValueError,"first argument must have rank 2"          raise ValueError,"first argument must have rank 2"
4056      if not len(sh1)==2 and not len(sh1)==1:      if not len(sh1)==2 and not len(sh1)==1:
4057          raise ValueError,"second argument must have rank 1 or 2"          raise ValueError,"second argument must have rank 1 or 2"
4058      return generalTensorProduct(arg0,arg1,offset=1)      return generalTensorProduct(arg0,arg1,axis_offset=1)
4059    
4060  def outer(arg0,arg1):  def outer(arg0,arg1):
4061      """      """
# Line 3394  def outer(arg0,arg1): Line 4073  def outer(arg0,arg1):
4073      @return: the outer product of arg0 and arg1 at each data point      @return: the outer product of arg0 and arg1 at each data point
4074      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
4075      """      """
4076      return generalTensorProduct(arg0,arg1,offset=0)      return generalTensorProduct(arg0,arg1,axis_offset=0)
4077    
4078    
4079  def tensormult(arg0,arg1):  def tensormult(arg0,arg1):
# Line 3436  def tensormult(arg0,arg1): Line 4115  def tensormult(arg0,arg1):
4115      sh0=pokeShape(arg0)      sh0=pokeShape(arg0)
4116      sh1=pokeShape(arg1)      sh1=pokeShape(arg1)
4117      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):
4118         return generalTensorProduct(arg0,arg1,offset=1)         return generalTensorProduct(arg0,arg1,axis_offset=1)
4119      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):
4120         return generalTensorProduct(arg0,arg1,offset=2)         return generalTensorProduct(arg0,arg1,axis_offset=2)
4121      else:      else:
4122          raise ValueError,"tensormult: first argument must have rank 2 or 4"          raise ValueError,"tensormult: first argument must have rank 2 or 4"
4123    
4124  def generalTensorProduct(arg0,arg1,offset=0):  def generalTensorProduct(arg0,arg1,axis_offset=0):
4125      """      """
4126      generalized tensor product      generalized tensor product
4127    
4128      out[s,t]=S{Sigma}_r arg0[s,r]*arg1[r,t]      out[s,t]=S{Sigma}_r arg0[s,r]*arg1[r,t]
4129    
4130      where s runs through arg0.Shape[:arg0.Rank-offset]      where s runs through arg0.Shape[:arg0.Rank-axis_offset]
4131            r runs trough arg0.Shape[:offset]            r runs trough arg0.Shape[:axis_offset]
4132            t runs through arg1.Shape[offset:]            t runs through arg1.Shape[axis_offset:]
4133    
4134      In the first case the the second dimension of arg0 and the length of arg1 must match and        In the first case the the second dimension of arg0 and the length of arg1 must match and  
4135      in the second case the two last dimensions of arg0 must match the shape of arg1.      in the second case the two last dimensions of arg0 must match the shape of arg1.
# Line 3467  def generalTensorProduct(arg0,arg1,offse Line 4146  def generalTensorProduct(arg0,arg1,offse
4146      # at this stage arg0 and arg0 are both numarray.NumArray or escript.Data or Symbols      # at this stage arg0 and arg0 are both numarray.NumArray or escript.Data or Symbols
4147      if isinstance(arg0,numarray.NumArray):      if isinstance(arg0,numarray.NumArray):
4148         if isinstance(arg1,Symbol):         if isinstance(arg1,Symbol):
4149             return GeneralTensorProduct_Symbol(arg0,arg1,offset)             return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
4150         else:         else:
4151             if not arg0.shape[arg0.rank-offset:]==arg1.shape[:offset]:             if not arg0.shape[arg0.rank-axis_offset:]==arg1.shape[:axis_offset]:
4152                 raise ValueError,"generalTensorProduct: dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)                 raise ValueError,"generalTensorProduct: dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
4153             arg0_c=arg0.copy()             arg0_c=arg0.copy()
4154             arg1_c=arg1.copy()             arg1_c=arg1.copy()
4155             sh0,sh1=arg0.shape,arg1.shape             sh0,sh1=arg0.shape,arg1.shape
4156             d0,d1,d01=1,1,1             d0,d1,d01=1,1,1
4157             for i in sh0[:arg0.rank-offset]: d0*=i             for i in sh0[:arg0.rank-axis_offset]: d0*=i
4158             for i in sh1[offset:]: d1*=i             for i in sh1[axis_offset:]: d1*=i
4159             for i in sh1[:offset]: d01*=i             for i in sh1[:axis_offset]: d01*=i
4160             arg0_c.resize((d0,d01))             arg0_c.resize((d0,d01))
4161             arg1_c.resize((d01,d1))             arg1_c.resize((d01,d1))
4162             out=numarray.zeros((d0,d1),numarray.Float)             out=numarray.zeros((d0,d1),numarray.Float)
4163             for i0 in range(d0):             for i0 in range(d0):
4164                      for i1 in range(d1):                      for i1 in range(d1):
4165                           out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1])                           out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1])
4166             out.resize(sh0[:arg0.rank-offset]+sh1[offset:])             out.resize(sh0[:arg0.rank-axis_offset]+sh1[axis_offset:])
4167             return out             return out
4168      elif isinstance(arg0,escript.Data):      elif isinstance(arg0,escript.Data):
4169         if isinstance(arg1,Symbol):         if isinstance(arg1,Symbol):
4170             return GeneralTensorProduct_Symbol(arg0,arg1,offset)             return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
4171         else:         else:
4172             return escript_generalTensorProduct(arg0,arg1,offset) # this calls has to be replaced by escript._generalTensorProduct(arg0,arg1,offset)             return escript_generalTensorProduct(arg0,arg1,axis_offset) # this calls has to be replaced by escript._generalTensorProduct(arg0,arg1,axis_offset)
4173      else:            else:      
4174         return GeneralTensorProduct_Symbol(arg0,arg1,offset)         return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
4175                                    
4176  class GeneralTensorProduct_Symbol(DependendSymbol):  class GeneralTensorProduct_Symbol(DependendSymbol):
4177     """     """
4178     Symbol representing the quotient of two arguments.     Symbol representing the quotient of two arguments.
4179     """     """
4180     def __init__(self,arg0,arg1,offset=0):     def __init__(self,arg0,arg1,axis_offset=0):
4181         """         """
4182         initialization of L{Symbol} representing the quotient of two arguments         initialization of L{Symbol} representing the quotient of two arguments
4183    
# Line 3511  class GeneralTensorProduct_Symbol(Depend Line 4190  class GeneralTensorProduct_Symbol(Depend
4190         """         """
4191         sh_arg0=pokeShape(arg0)         sh_arg0=pokeShape(arg0)
4192         sh_arg1=pokeShape(arg1)         sh_arg1=pokeShape(arg1)
4193         sh0=sh_arg0[:len(sh_arg0)-offset]         sh0=sh_arg0[:len(sh_arg0)-axis_offset]
4194         sh01=sh_arg0[len(sh_arg0)-offset:]         sh01=sh_arg0[len(sh_arg0)-axis_offset:]
4195         sh10=sh_arg1[:offset]         sh10=sh_arg1[:axis_offset]
4196         sh1=sh_arg1[offset:]         sh1=sh_arg1[axis_offset:]
4197         if not sh01==sh10:         if not sh01==sh10:
4198             raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)             raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
4199         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0+sh1,args=[arg0,arg1,offset])         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0+sh1,args=[arg0,arg1,axis_offset])
4200    
4201     def getMyCode(self,argstrs,format="escript"):     def getMyCode(self,argstrs,format="escript"):
4202        """        """
# Line 3532  class GeneralTensorProduct_Symbol(Depend Line 4211  class GeneralTensorProduct_Symbol(Depend
4211        @raise: NotImplementedError: if the requested format is not available        @raise: NotImplementedError: if the requested format is not available
4212        """        """
4213        if format=="escript" or format=="str" or format=="text":        if format=="escript" or format=="str" or format=="text":
4214           return "generalTensorProduct(%s,%s,offset=%s)"%(argstrs[0],argstrs[1],argstrs[2])           return "generalTensorProduct(%s,%s,axis_offset=%s)"%(argstrs[0],argstrs[1],argstrs[2])
4215        else:        else:
4216           raise NotImplementedError,"%s does not provide program code for format %s."%(str(self),format)           raise NotImplementedError,"%s does not provide program code for format %s."%(str(self),format)
4217    
# Line 3557  class GeneralTensorProduct_Symbol(Depend Line 4236  class GeneralTensorProduct_Symbol(Depend
4236           args=self.getSubstitutedArguments(argvals)           args=self.getSubstitutedArguments(argvals)
4237           return generalTensorProduct(args[0],args[1],args[2])           return generalTensorProduct(args[0],args[1],args[2])
4238    
4239  def escript_generalTensorProduct(arg0,arg1,offset): # this should be escript._generalTensorProduct  def escript_generalTensorProduct(arg0,arg1,axis_offset): # this should be escript._generalTensorProduct
4240      "arg0 and arg1 are both Data objects but not neccesrily on the same function space. they could be identical!!!"      "arg0 and arg1 are both Data objects but not neccesrily on the same function space. they could be identical!!!"
4241      # calculate the return shape:      # calculate the return shape:
4242      shape0=arg0.getShape()[:arg0.getRank()-offset]      shape0=arg0.getShape()[:arg0.getRank()-axis_offset]
4243      shape01=arg0.getShape()[arg0.getRank()-offset:]      shape01=arg0.getShape()[arg0.getRank()-axis_offset:]
4244      shape10=arg1.getShape()[:offset]      shape10=arg1.getShape()[:axis_offset]
4245      shape1=arg1.getShape()[offset:]      shape1=arg1.getShape()[axis_offset:]
4246      if not shape01==shape10:      if not shape01==shape10:
4247          raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)          raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
4248    
4249        # whatr function space should be used? (this here is not good!)
4250        fs=(escript.Scalar(0.,arg0.getFunctionSpace())+escript.Scalar(0.,arg1.getFunctionSpace())).getFunctionSpace()
4251      # create return value:      # create return value:
4252      out=escript.Data(0.,tuple(shape0+shape1),arg0.getFunctionSpace())      out=escript.Data(0.,tuple(shape0+shape1),fs)
4253      #      #
4254      s0=[[]]      s0=[[]]
4255      for k in shape0:      for k in shape0:
# Line 3591  def escript_generalTensorProduct(arg0,ar Line 4272  def escript_generalTensorProduct(arg0,ar
4272    
4273      for i0 in s0:      for i0 in s0:
4274         for i1 in s1:         for i1 in s1:
4275           s=escript.Scalar(0.,arg0.getFunctionSpace())           s=escript.Scalar(0.,fs)
4276           for i01 in s01:           for i01 in s01:
4277              s+=arg0.__getitem__(tuple(i0+i01))*arg1.__getitem__(tuple(i01+i1))              s+=arg0.__getitem__(tuple(i0+i01))*arg1.__getitem__(tuple(i01+i1))
4278           out.__setitem__(tuple(i0+i1),s)           out.__setitem__(tuple(i0+i1),s)
4279      return out      return out
4280    
4281    
4282  #=========================================================  #=========================================================
4283  #   some little helpers  #  functions dealing with spatial dependency
4284  #=========================================================  #=========================================================
4285  def grad(arg,where=None):  def grad(arg,where=None):
4286      """      """
4287      Returns the spatial gradient of arg at where.      Returns the spatial gradient of arg at where.
4288    
4289        If C{g} is the returned object, then
4290    
4291      @param arg:   Data object representing the function which gradient        - if C{arg} is rank 0 C{g[s]} is the derivative of C{arg} with respect to the C{s}-th spatial dimension.
4292                    to be calculated.        - if C{arg} is rank 1 C{g[i,s]} is the derivative of C{arg[i]} with respect to the C{s}-th spatial dimension.
4293          - if C{arg} is rank 2 C{g[i,j,s]} is the derivative of C{arg[i,j]} with respect to the C{s}-th spatial dimension.
4294          - if C{arg} is rank 3 C{g[i,j,k,s]} is the derivative of C{arg[i,j,k]} with respect to the C{s}-th spatial dimension.
4295    
4296        @param arg: function which gradient to be calculated. Its rank has to be less than 3.
4297        @type arg: L{escript.Data} or L{Symbol}
4298      @param where: FunctionSpace in which the gradient will be calculated.      @param where: FunctionSpace in which the gradient will be calculated.
4299                    If not present or C{None} an appropriate default is used.                    If not present or C{None} an appropriate default is used.
4300        @type where: C{None} or L{escript.FunctionSpace}
4301        @return: gradient of arg.
4302        @rtype:  L{escript.Data} or L{Symbol}
4303      """      """
4304      if isinstance(arg,Symbol):      if isinstance(arg,Symbol):
4305         return Grad_Symbol(arg,where)         return Grad_Symbol(arg,where)
# Line 3617  def grad(arg,where=None): Line 4309  def grad(arg,where=None):
4309         else:         else:
4310            return arg._grad(where)            return arg._grad(where)
4311      else:      else:
4312        raise TypeError,"grad: Unknown argument type."         raise TypeError,"grad: Unknown argument type."
4313    
4314    class Grad_Symbol(DependendSymbol):
4315       """
4316       L{Symbol} representing the result of the gradient operator
4317       """
4318       def __init__(self,arg,where=None):
4319          """
4320          initialization of gradient L{Symbol} with argument arg
4321          @param arg: argument of function
4322          @type arg: L{Symbol}.
4323          @param where: FunctionSpace in which the gradient will be calculated.
4324                      If not present or C{None} an appropriate default is used.
4325          @type where: C{None} or L{escript.FunctionSpace}
4326          """
4327          d=arg.getDim()
4328          if d==None:
4329             raise ValueError,"argument must have a spatial dimension"
4330          super(Grad_Symbol,self).__init__(args=[arg,where],shape=arg.getShape()+(d,),dim=d)
4331    
4332       def getMyCode(self,argstrs,format="escript"):
4333          """
4334          returns a program code that can be used to evaluate the symbol.
4335    
4336          @param argstrs: gives for each argument a string representing the argument for the evaluation.
4337          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
4338          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
4339          @type format: C{str}
4340          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
4341          @rtype: C{str}
4342          @raise: NotImplementedError: if the requested format is not available
4343          """
4344          if format=="escript" or format=="str"  or format=="text":
4345             return "grad(%s,where=%s)"%(argstrs[0],argstrs[1])
4346          else:
4347             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
4348    
4349       def substitute(self,argvals):
4350          """
4351          assigns new values to symbols in the definition of the symbol.
4352          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
4353    
4354          @param argvals: new values assigned to symbols
4355          @type argvals: C{dict} with keywords of type L{Symbol}.
4356          @return: result of the substitution process. Operations are executed as much as possible.
4357          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
4358          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
4359          """
4360          if argvals.has_key(self):
4361             arg=argvals[self]
4362             if self.isAppropriateValue(arg):
4363                return arg
4364             else:
4365                raise TypeError,"%s: new value is not appropriate."%str(self)
4366          else:
4367             arg=self.getSubstitutedArguments(argvals)
4368             return grad(arg[0],where=arg[1])
4369    
4370       def diff(self,arg):
4371          """
4372          differential of this object
4373    
4374          @param arg: the derivative is calculated with respect to arg
4375          @type arg: L{escript.Symbol}
4376          @return: derivative with respect to C{arg}
4377          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
4378          """
4379          if arg==self:
4380             return identity(self.getShape())
4381          else:
4382             return grad(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
4383    
4384  def integrate(arg,where=None):  def integrate(arg,where=None):
4385      """      """
4386      Return the integral if the function represented by Data object arg over      Return the integral of the function C{arg} over its domain. If C{where} is present C{arg} is interpolated to C{where}
4387      its domain.      before integration.
4388    
4389      @param arg:   Data object representing the function which is integrated.      @param arg:   the function which is integrated.
4390        @type arg: L{escript.Data} or L{Symbol}
4391      @param where: FunctionSpace in which the integral is calculated.      @param where: FunctionSpace in which the integral is calculated.
4392                    If not present or C{None} an appropriate default is used.                    If not present or C{None} an appropriate default is used.
4393        @type where: C{None} or L{escript.FunctionSpace}
4394        @return: integral of arg.
4395        @rtype:  C{float}, C{numarray.NumArray} or L{Symbol}
4396      """      """
4397      if isinstance(arg,numarray.NumArray):      if isinstance(arg,Symbol):
         if checkForZero(arg):  
            return arg  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,float):  
         if checkForZero(arg):  
            return arg  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,int):  
         if checkForZero(arg):  
            return float(arg)  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,Symbol):  
4398         return Integrate_Symbol(arg,where)         return Integrate_Symbol(arg,where)
4399      elif isinstance(arg,escript.Data):      elif isinstance(arg,escript.Data):
4400         if not where==None: arg=escript.Data(arg,where)         if not where==None: arg=escript.Data(arg,where)
# Line 3654  def integrate(arg,where=None): Line 4405  def integrate(arg,where=None):
4405      else:      else:
4406        raise TypeError,"integrate: Unknown argument type."        raise TypeError,"integrate: Unknown argument type."
4407    
4408    class Integrate_Symbol(DependendSymbol):
4409       """
4410       L{Symbol} representing the result of the spatial integration operator
4411       """
4412       def __init__(self,arg,where=None):
4413          """
4414          initialization of integration L{Symbol} with argument arg
4415          @param arg: argument of the integration
4416          @type arg: L{Symbol}.
4417          @param where: FunctionSpace in which the integration will be calculated.
4418                      If not present or C{None} an appropriate default is used.
4419          @type where: C{None} or L{escript.FunctionSpace}
4420          """
4421          super(Integrate_Symbol,self).__init__(args=[arg,where],shape=arg.getShape(),dim=arg.getDim())
4422    
4423       def getMyCode(self,argstrs,format="escript"):
4424          """
4425          returns a program code that can be used to evaluate the symbol.
4426    
4427          @param argstrs: gives for each argument a string representing the argument for the evaluation.
4428          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
4429          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
4430          @type format: C{str}
4431          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
4432          @rtype: C{str}
4433          @raise: NotImplementedError: if the requested format is not available
4434          """
4435          if format=="escript" or format=="str"  or format=="text":
4436             return "integrate(%s,where=%s)"%(argstrs[0],argstrs[1])
4437          else:
4438             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
4439    
4440       def substitute(self,argvals):
4441          """
4442          assigns new values to symbols in the definition of the symbol.
4443          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
4444    
4445          @param argvals: new values assigned to symbols
4446          @type argvals: C{dict} with keywords of type L{Symbol}.
4447          @return: result of the substitution process. Operations are executed as much as possible.
4448          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
4449          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
4450          """
4451          if argvals.has_key(self):
4452             arg=argvals[self]
4453             if self.isAppropriateValue(arg):
4454                return arg
4455             else:
4456                raise TypeError,"%s: new value is not appropriate."%str(self)
4457          else:
4458             arg=self.getSubstitutedArguments(argvals)
4459             return integrate(arg[0],where=arg[1])
4460    
4461       def diff(self,arg):
4462          """
4463          differential of this object
4464    
4465          @param arg: the derivative is calculated with respect to arg
4466          @type arg: L{escript.Symbol}
4467          @return: derivative with respect to C{arg}
4468          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
4469          """
4470          if arg==self:
4471             return identity(self.getShape())
4472          else:
4473             return integrate(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
4474    
4475    
4476  def interpolate(arg,where):  def interpolate(arg,where):
4477      """      """
4478      Interpolates the function into the FunctionSpace where.      interpolates the function into the FunctionSpace where.
4479    
4480      @param arg:    interpolant      @param arg: interpolant
4481      @param where:  FunctionSpace to interpolate to      @type arg: L{escript.Data} or L{Symbol}
4482        @param where: FunctionSpace to be interpolated to
4483        @type where: L{escript.FunctionSpace}
4484        @return: interpolated argument
4485        @rtype:  C{escript.Data} or L{Symbol}
4486      """      """
4487      if testForZero(arg):      if isinstance(arg,Symbol):
4488        return 0         return Interpolate_Symbol(arg,where)
     elif isinstance(arg,Symbol):  
        return Interpolated_Symbol(arg,where)  
4489      else:      else:
4490         return escript.Data(arg,where)         return escript.Data(arg,where)
4491    
4492  def div(arg,where=None):  class Interpolate_Symbol(DependendSymbol):
4493      """     """
4494      Returns the divergence of arg at where.     L{Symbol} representing the result of the interpolation operator
4495       """
4496       def __init__(self,arg,where):
4497          """
4498          initialization of interpolation L{Symbol} with argument arg
4499          @param arg: argument of the interpolation
4500          @type arg: L{Symbol}.
4501          @param where: FunctionSpace into which the argument is interpolated.
4502          @type where: L{escript.FunctionSpace}
4503          """
4504          super(Interpolate_Symbol,self).__init__(args=[arg,where],shape=arg.getShape(),dim=arg.getDim())
4505    
4506      @param arg:   Data object representing the function which gradient to     def getMyCode(self,argstrs,format="escript"):
4507                    be calculated.        """
4508      @param where: FunctionSpace in which the gradient will be calculated.        returns a program code that can be used to evaluate the symbol.
                   If not present or C{None} an appropriate default is used.  
     """  
     g=grad(arg,where)  
     return trace(g,axis0=g.getRank()-2,axis1=g.getRank()-1)  
4509    
4510  def jump(arg):        @param argstrs: gives for each argument a string representing the argument for the evaluation.
4511      """        @type argstrs: C{str} or a C{list} of length 1 of C{str}.
4512      Returns the jump of arg across a continuity.        @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
4513          @type format: C{str}
4514          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
4515          @rtype: C{str}
4516          @raise: NotImplementedError: if the requested format is not available
4517          """
4518          if format=="escript" or format=="str"  or format=="text":
4519             return "interpolate(%s,where=%s)"%(argstrs[0],argstrs[1])
4520          else:
4521             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
4522    
4523      @param arg:   Data object representing the function which gradient     def substitute(self,argvals):
4524                    to be calculated.        """
4525      """        assigns new values to symbols in the definition of the symbol.
4526      d=arg.getDomain()        The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
     return arg.interpolate(escript.FunctionOnContactOne())-arg.interpolate(escript.FunctionOnContactZero())  
4527    
4528  #=============================        @param argvals: new values assigned to symbols
4529  #        @type argvals: C{dict} with keywords of type L{Symbol}.
4530  # wrapper for various functions: if the argument has attribute the function name        @return: result of the substitution process. Operations are executed as much as possible.
4531  # as an argument it calls the corresponding methods. Otherwise the corresponding        @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
4532  # numarray function is called.        @raise TypeError: if a value for a L{Symbol} cannot be substituted.
4533          """
4534          if argvals.has_key(self):
4535             arg=argvals[self]
4536             if self.isAppropriateValue(arg):
4537                return arg
4538             else:
4539                raise TypeError,"%s: new value is not appropriate."%str(self)
4540          else:
4541             arg=self.getSubstitutedArguments(argvals)
4542             return interpolate(arg[0],where=arg[1])
4543    
4544       def diff(self,arg):
4545          """
4546          differential of this object
4547    
4548          @param arg: the derivative is calculated with respect to arg
4549          @type arg: L{escript.Symbol}
4550          @return: derivative with respect to C{arg}
4551          @rtype: L{Symbol} but other types such as L{escript.Data}, L{numarray.NumArray}  are possible.
4552          """
4553          if arg==self:
4554             return identity(self.getShape())
4555          else:
4556             return interpolate(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
4557    
 # functions involving the underlying Domain:  
4558    
4559  def transpose(arg,axis=None):  def div(arg,where=None):
4560      """      """
4561      Returns the transpose of the Data object arg.      returns the divergence of arg at where.
4562    
4563      @param arg:      @param arg: function which divergence to be calculated. Its shape has to be (d,) where d is the spatial dimension.
4564        @type arg: L{escript.Data} or L{Symbol}
4565        @param where: FunctionSpace in which the divergence will be calculated.
4566                      If not present or C{None} an appropriate default is used.
4567        @type where: C{None} or L{escript.FunctionSpace}
4568        @return: divergence of arg.
4569        @rtype:  L{escript.Data} or L{Symbol}
4570      """      """
     if axis==None:  
        r=0  
        if hasattr(arg,"getRank"): r=arg.getRank()  
        if hasattr(arg,"rank"): r=arg.rank  
        axis=r/2  
4571      if isinstance(arg,Symbol):      if isinstance(arg,Symbol):
4572         return Transpose_Symbol(arg,axis=r)          dim=arg.getDim()
4573      if isinstance(arg,escript.Data):      elif isinstance(arg,escript.Data):
4574         # hack for transpose          dim=arg.getDomain().getDim()
        r=arg.getRank()  
        if r!=2: raise ValueError,"Tranpose only avalaible for rank 2 objects"  
        s=arg.getShape()  
        out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())  
        for i in range(s[0]):  
           for j in range(s[1]):  
              out[j,i]=arg[i,j]  
        return out  
        # end hack for transpose  
        return arg.transpose(axis)  
4575      else:      else:
4576         return numarray.transpose(arg,axis=axis)          raise TypeError,"div: argument type not supported"
4577        if not arg.getShape()==(dim,):
4578          raise ValueError,"div: expected shape is (%s,)"%dim
4579        return trace(grad(arg,where))
4580    
4581  def trace(arg,axis0=0,axis1=1):  def jump(arg,domain=None):
4582      """      """
4583      Return      returns the jump of arg across the continuity of the domain
4584    
4585      @param arg:      @param arg: argument
4586        @type arg: L{escript.Data} or L{Symbol}
4587        @param domain: the domain where the discontinuity is located. If domain is not present or equal to C{None}
4588                       the domain of arg is used. If arg is a L{Symbol} the domain must be present.
4589        @type domain: C{None} or L{escript.Domain}
4590        @return: jump of arg
4591        @rtype:  L{escript.Data} or L{Symbol}
4592      """      """
4593      if isinstance(arg,Symbol):      if domain==None: domain=arg.getDomain()
4594         s=list(arg.getShape())              return interpolate(arg,escript.FunctionOnContactOne(domain))-interpolate(arg,escript.FunctionOnContactZero(domain))
        s=tuple(s[0:axis0]+s[axis0+1:axis1]+s[axis1+1:])  
        return Trace_Symbol(arg,axis0=axis0,axis1=axis1)  
     elif isinstance(arg,escript.Data):  
        # hack for trace  
        s=arg.getShape()  
        if s[axis0]!=s[axis1]:  
            raise ValueError,"illegal axis in trace"  
        out=escript.Scalar(0.,arg.getFunctionSpace())  
        for i in range(s[axis0]):  
           out+=arg[i,i]  
        return out  
        # end hack for trace  
     else:  
        return numarray.trace(arg,axis0=axis0,axis1=axis1)  
4595    
4596    def L2(arg):
4597        """
4598        returns the L2 norm of arg at where
4599        
4600        @param arg: function which L2 to be calculated.
4601        @type arg: L{escript.Data} or L{Symbol}
4602        @return: L2 norm of arg.
4603        @rtype:  L{float} or L{Symbol}
4604        @note: L2(arg) is equivalent to sqrt(integrate(inner(arg,arg)))
4605        """
4606        return sqrt(integrate(inner(arg,arg)))
4607    #=============================
4608    #
4609    
4610  def reorderComponents(arg,index):  def reorderComponents(arg,index):
4611      """      """
4612      resorts the component of arg according to index      resorts the component of arg according to index
4613    
4614      """      """
4615      pass      raise NotImplementedError
4616  #  #
4617  # $Log: util.py,v $  # $Log: util.py,v $
4618  # Revision 1.14.2.16  2005/10/19 06:09:57  gross  # Revision 1.14.2.16  2005/10/19 06:09:57  gross

Legend:
Removed from v.341  
changed lines
  Added in v.530

  ViewVC Help
Powered by ViewVC 1.1.26