/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Diff of /trunk/escript/py_src/util.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 341 by gross, Mon Dec 12 05:26:10 2005 UTC revision 429 by gross, Wed Jan 11 05:53:40 2006 UTC
# Line 24  Utility functions for escript Line 24  Utility functions for escript
24  __author__="Lutz Gross, l.gross@uq.edu.au"  __author__="Lutz Gross, l.gross@uq.edu.au"
25  __licence__="contact: esys@access.uq.edu.au"  __licence__="contact: esys@access.uq.edu.au"
26  __url__="http://www.iservo.edu.au/esys/escript"  __url__="http://www.iservo.edu.au/esys/escript"
27  __version__="$Revision: 329 $"  __version__="$Revision$"
28  __date__="$Date$"  __date__="$Date$"
29    
30    
# Line 43  import os Line 43  import os
43  # def matchType(arg0=0.,arg1=0.):  # def matchType(arg0=0.,arg1=0.):
44  # def matchShape(arg0,arg1):  # def matchShape(arg0,arg1):
45    
 # def maximum(arg0,arg1):  
 # def minimum(arg0,arg1):  
   
46  # def transpose(arg,axis=None):  # def transpose(arg,axis=None):
47  # def trace(arg,axis0=0,axis1=1):  # def trace(arg,axis0=0,axis1=1):
48  # def reorderComponents(arg,index):  # def reorderComponents(arg,index):
# Line 363  def testForZero(arg): Line 360  def testForZero(arg):
360      @return : True if the argument is identical to zero.      @return : True if the argument is identical to zero.
361      @rtype : C{bool}      @rtype : C{bool}
362      """      """
363      try:      if isinstance(arg,numarray.NumArray):
364         return not Lsup(arg)>0.         return not Lsup(arg)>0.
365      except TypeError:      elif isinstance(arg,escript.Data):
366           return False
367        elif isinstance(arg,float):
368           return not Lsup(arg)>0.
369        elif isinstance(arg,int):
370           return not Lsup(arg)>0.
371        elif isinstance(arg,Symbol):
372           return False
373        else:
374         return False         return False
375    
376  def matchType(arg0=0.,arg1=0.):  def matchType(arg0=0.,arg1=0.):
# Line 907  def wherePositive(arg): Line 912  def wherePositive(arg):
912     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
913     """     """
914     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
915        if arg.rank==0:        out=numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))*1.
916           if arg>0:        if isinstance(out,float): out=numarray.array(out)
917             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))  
918     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
919        return arg._wherePositive()        return arg._wherePositive()
920     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 993  def whereNegative(arg): Line 994  def whereNegative(arg):
994     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
995     """     """
996     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
997        if arg.rank==0:        out=numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))*1.
998           if arg<0:        if isinstance(out,float): out=numarray.array(out)
999             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))  
1000     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1001        return arg._whereNegative()        return arg._whereNegative()
1002     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1079  def whereNonNegative(arg): Line 1076  def whereNonNegative(arg):
1076     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1077     """     """
1078     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1079        if arg.rank==0:        out=numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.
1080           if arg<0:        if isinstance(out,float): out=numarray.array(out)
1081             return numarray.array(0.)        return out
          else:  
            return numarray.array(1.)  
       else:  
          return numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float))  
1082     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1083        return arg._whereNonNegative()        return arg._whereNonNegative()
1084     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1113  def whereNonPositive(arg): Line 1106  def whereNonPositive(arg):
1106     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1107     """     """
1108     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1109        if arg.rank==0:        out=numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.
1110           if arg>0:        if isinstance(out,float): out=numarray.array(out)
1111             return numarray.array(0.)        return out
          else:  
            return numarray.array(1.)  
       else:  
          return numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.  
1112     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1113        return arg._whereNonPositive()        return arg._whereNonPositive()
1114     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1149  def whereZero(arg,tol=0.): Line 1138  def whereZero(arg,tol=0.):
1138     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1139     """     """
1140     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1141        if arg.rank==0:        out=numarray.less_equal(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.
1142           if abs(arg)<=tol:        if isinstance(out,float): out=numarray.array(out)
1143             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.less_equal(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.  
1144     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1145        if tol>0.:        if tol>0.:
1146           return whereNegative(abs(arg)-tol)           return whereNegative(abs(arg)-tol)
# Line 1236  def whereNonZero(arg,tol=0.): Line 1221  def whereNonZero(arg,tol=0.):
1221     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1222     """     """
1223     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1224        if arg.rank==0:        out=numarray.greater(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.
1225          if abs(arg)>tol:        if isinstance(out,float): out=numarray.array(out)
1226             return numarray.array(1.)        return out
         else:  
            return numarray.array(0.)  
       else:  
          return numarray.greater(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.  
1227     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1228        if tol>0.:        if tol>0.:
1229           return 1.-whereZero(arg,tol)           return 1.-whereZero(arg,tol)
# Line 2877  def length(arg): Line 2858  def length(arg):
2858     """     """
2859     return sqrt(inner(arg,arg))     return sqrt(inner(arg,arg))
2860    
2861    def trace(arg,axis_offset=0):
2862       """
2863       returns the trace of arg which the sum of arg[k,k] over k.
2864    
2865       @param arg: argument
2866       @type arg: L{escript.Data}, L{Symbol}, L{numarray.NumArray}.
2867       @param axis_offset: axis_offset to components to sum over. C{axis_offset} must be non-negative and less than the rank of arg +1. The dimensions on component
2868                      axis_offset and axis_offset+1 must be equal.
2869       @type axis_offset: C{int}
2870       @return: trace of arg. The rank of the returned object is minus 2 of the rank of arg.
2871       @rtype: L{escript.Data}, L{Symbol}, L{numarray.NumArray} depending on the type of arg.
2872       """
2873       if isinstance(arg,numarray.NumArray):
2874          sh=arg.shape
2875          if len(sh)<2:
2876            raise ValueError,"trace: rank of argument must be greater than 1"
2877          if axis_offset<0 or axis_offset>len(sh)-2:
2878            raise ValueError,"trace: axis_offset must be between 0 and %s"%len(sh)-2
2879          s1=1
2880          for i in range(axis_offset): s1*=sh[i]
2881          s2=1
2882          for i in range(axis_offset+2,len(sh)): s2*=sh[i]
2883          if not sh[axis_offset] == sh[axis_offset+1]:
2884            raise ValueError,"trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2885          arg_reshaped=numarray.reshape(arg,(s1,sh[axis_offset],sh[axis_offset],s2))
2886          out=numarray.zeros([s1,s2],numarray.Float)
2887          for i1 in range(s1):
2888            for i2 in range(s2):
2889                for j in range(sh[axis_offset]): out[i1,i2]+=arg_reshaped[i1,j,j,i2]
2890          out.resize(sh[:axis_offset]+sh[axis_offset+2:])
2891          return out
2892       elif isinstance(arg,escript.Data):
2893          return escript_trace(arg,axis_offset)
2894       elif isinstance(arg,float):
2895          raise TypeError,"trace: illegal argument type float."
2896       elif isinstance(arg,int):
2897          raise TypeError,"trace: illegal argument type int."
2898       elif isinstance(arg,Symbol):
2899          return Trace_Symbol(arg,axis_offset)
2900       else:
2901          raise TypeError,"trace: Unknown argument type."
2902    
2903    def escript_trace(arg,axis_offset): # this should be escript._trace
2904          "arg si a Data objects!!!"
2905          if arg.getRank()<2:
2906            raise ValueError,"escript_trace: rank of argument must be greater than 1"
2907          if axis_offset<0 or axis_offset>arg.getRank()-2:
2908            raise ValueError,"escript_trace: axis_offset must be between 0 and %s"%arg.getRank()-2
2909          s=list(arg.getShape())        
2910          if not s[axis_offset] == s[axis_offset+1]:
2911            raise ValueError,"escript_trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2912          out=escript.Data(0.,tuple(s[0:axis_offset]+s[axis_offset+2:]),arg.getFunctionSpace())
2913          if arg.getRank()==2:
2914             for i0 in range(s[0]):
2915                out+=arg[i0,i0]
2916          elif arg.getRank()==3:
2917             if axis_offset==0:
2918                for i0 in range(s[0]):
2919                      for i2 in range(s[2]):
2920                             out[i2]+=arg[i0,i0,i2]
2921             elif axis_offset==1:
2922                for i0 in range(s[0]):
2923                   for i1 in range(s[1]):
2924                             out[i0]+=arg[i0,i1,i1]
2925          elif arg.getRank()==4:
2926             if axis_offset==0:
2927                for i0 in range(s[0]):
2928                      for i2 in range(s[2]):
2929                         for i3 in range(s[3]):
2930                             out[i2,i3]+=arg[i0,i0,i2,i3]
2931             elif axis_offset==1:
2932                for i0 in range(s[0]):
2933                   for i1 in range(s[1]):
2934                         for i3 in range(s[3]):
2935                             out[i0,i3]+=arg[i0,i1,i1,i3]
2936             elif axis_offset==2:
2937                for i0 in range(s[0]):
2938                   for i1 in range(s[1]):
2939                      for i2 in range(s[2]):
2940                             out[i0,i1]+=arg[i0,i1,i2,i2]
2941          return out
2942    class Trace_Symbol(DependendSymbol):
2943       """
2944       L{Symbol} representing the result of the trace function
2945       """
2946       def __init__(self,arg,axis_offset=0):
2947          """
2948          initialization of trace L{Symbol} with argument arg
2949          @param arg: argument of function
2950          @type arg: L{Symbol}.
2951          @param axis_offset: axis_offset to components to sum over. C{axis_offset} must be non-negative and less than the rank of arg +1. The dimensions on component
2952                      axis_offset and axis_offset+1 must be equal.
2953          @type axis_offset: C{int}
2954          """
2955          if arg.getRank()<2:
2956            raise ValueError,"Trace_Symbol: rank of argument must be greater than 1"
2957          if axis_offset<0 or axis_offset>arg.getRank()-2:
2958            raise ValueError,"Trace_Symbol: axis_offset must be between 0 and %s"%arg.getRank()-2
2959          s=list(arg.getShape())        
2960          if not s[axis_offset] == s[axis_offset+1]:
2961            raise ValueError,"Trace_Symbol: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2962          super(Trace_Symbol,self).__init__(args=[arg,axis_offset],shape=tuple(s[0:axis_offset]+s[axis_offset+2:]),dim=arg.getDim())
2963    
2964       def getMyCode(self,argstrs,format="escript"):
2965          """
2966          returns a program code that can be used to evaluate the symbol.
2967    
2968          @param argstrs: gives for each argument a string representing the argument for the evaluation.
2969          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
2970          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
2971          @type format: C{str}
2972          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
2973          @rtype: C{str}
2974          @raise: NotImplementedError: if the requested format is not available
2975          """
2976          if format=="escript" or format=="str"  or format=="text":
2977             return "trace(%s,axis_offset=%s)"%(argstrs[0],argstrs[1])
2978          else:
2979             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
2980    
2981       def substitute(self,argvals):
2982          """
2983          assigns new values to symbols in the definition of the symbol.
2984          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
2985    
2986          @param argvals: new values assigned to symbols
2987          @type argvals: C{dict} with keywords of type L{Symbol}.
2988          @return: result of the substitution process. Operations are executed as much as possible.
2989          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
2990          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
2991          """
2992          if argvals.has_key(self):
2993             arg=argvals[self]
2994             if self.isAppropriateValue(arg):
2995                return arg
2996             else:
2997                raise TypeError,"%s: new value is not appropriate."%str(self)
2998          else:
2999             arg=self.getSubstitutedArguments(argvals)
3000             return trace(arg[0],axis_offset=arg[1])
3001    
3002       def diff(self,arg):
3003          """
3004          differential of this object
3005    
3006          @param arg: the derivative is calculated with respect to arg
3007          @type arg: L{escript.Symbol}
3008          @return: derivative with respect to C{arg}
3009          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3010          """
3011          if arg==self:
3012             return identity(self.getShape())
3013          else:
3014             return trace(self.getDifferentiatedArguments(arg)[0],axis_offset=self.getArgument()[1])
3015    
3016  #=======================================================  #=======================================================
3017  #  Binary operations:  #  Binary operations:
3018  #=======================================================  #=======================================================
# Line 3304  def maximum(*args): Line 3440  def maximum(*args):
3440         if out==None:         if out==None:
3441            out=a            out=a
3442         else:         else:
3443            m=whereNegative(out-a)            diff=add(a,-out)
3444            out=m*a+(1.-m)*out            out=add(out,mult(wherePositive(diff),diff))
3445      return out      return out
3446        
3447  def minimum(*arg):  def minimum(*args):
3448      """      """
3449      the minimum over arguments args      the minimum over arguments args
3450    
# Line 3322  def minimum(*arg): Line 3458  def minimum(*arg):
3458         if out==None:         if out==None:
3459            out=a            out=a
3460         else:         else:
3461            m=whereNegative(out-a)            diff=add(a,-out)
3462            out=m*out+(1.-m)*a            out=add(out,mult(whereNegative(diff),diff))
3463      return out      return out
3464    
3465    def clip(arg,minval=0.,maxval=1.):
3466        """
3467        cuts the values of arg between minval and maxval
3468    
3469        @param arg: argument
3470        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}, C{int} or C{float}
3471        @param minval: lower range
3472        @type arg: C{float}
3473        @param maxval: upper range
3474        @type arg: C{float}
3475        @return: is on object with all its value between minval and maxval. value of the argument that greater then minval and
3476                 less then maxval are unchanged.
3477        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol}, C{int} or C{float} depending on the input
3478        @raise ValueError: if minval>maxval
3479        """
3480        if minval>maxval:
3481           raise ValueError,"minval = %s must be less then maxval %s"%(minval,maxval)
3482        return minimum(maximum(minval,arg),maxval)
3483    
3484        
3485  def inner(arg0,arg1):  def inner(arg0,arg1):
3486      """      """
# Line 3348  def inner(arg0,arg1): Line 3504  def inner(arg0,arg1):
3504      sh1=pokeShape(arg1)      sh1=pokeShape(arg1)
3505      if not sh0==sh1:      if not sh0==sh1:
3506          raise ValueError,"inner: shape of arguments does not match"          raise ValueError,"inner: shape of arguments does not match"
3507      return generalTensorProduct(arg0,arg1,offset=len(sh0))      return generalTensorProduct(arg0,arg1,axis_offset=len(sh0))
3508    
3509  def matrixmult(arg0,arg1):  def matrixmult(arg0,arg1):
3510      """      """
# Line 3376  def matrixmult(arg0,arg1): Line 3532  def matrixmult(arg0,arg1):
3532          raise ValueError,"first argument must have rank 2"          raise ValueError,"first argument must have rank 2"
3533      if not len(sh1)==2 and not len(sh1)==1:      if not len(sh1)==2 and not len(sh1)==1:
3534          raise ValueError,"second argument must have rank 1 or 2"          raise ValueError,"second argument must have rank 1 or 2"
3535      return generalTensorProduct(arg0,arg1,offset=1)      return generalTensorProduct(arg0,arg1,axis_offset=1)
3536    
3537  def outer(arg0,arg1):  def outer(arg0,arg1):
3538      """      """
# Line 3394  def outer(arg0,arg1): Line 3550  def outer(arg0,arg1):
3550      @return: the outer product of arg0 and arg1 at each data point      @return: the outer product of arg0 and arg1 at each data point
3551      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
3552      """      """
3553      return generalTensorProduct(arg0,arg1,offset=0)      return generalTensorProduct(arg0,arg1,axis_offset=0)
3554    
3555    
3556  def tensormult(arg0,arg1):  def tensormult(arg0,arg1):
# Line 3436  def tensormult(arg0,arg1): Line 3592  def tensormult(arg0,arg1):
3592      sh0=pokeShape(arg0)      sh0=pokeShape(arg0)
3593      sh1=pokeShape(arg1)      sh1=pokeShape(arg1)
3594      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):
3595         return generalTensorProduct(arg0,arg1,offset=1)         return generalTensorProduct(arg0,arg1,axis_offset=1)
3596      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):
3597         return generalTensorProduct(arg0,arg1,offset=2)         return generalTensorProduct(arg0,arg1,axis_offset=2)
3598      else:      else:
3599          raise ValueError,"tensormult: first argument must have rank 2 or 4"          raise ValueError,"tensormult: first argument must have rank 2 or 4"
3600    
3601  def generalTensorProduct(arg0,arg1,offset=0):  def generalTensorProduct(arg0,arg1,axis_offset=0):
3602      """      """
3603      generalized tensor product      generalized tensor product
3604    
3605      out[s,t]=S{Sigma}_r arg0[s,r]*arg1[r,t]      out[s,t]=S{Sigma}_r arg0[s,r]*arg1[r,t]
3606    
3607      where s runs through arg0.Shape[:arg0.Rank-offset]      where s runs through arg0.Shape[:arg0.Rank-axis_offset]
3608            r runs trough arg0.Shape[:offset]            r runs trough arg0.Shape[:axis_offset]
3609            t runs through arg1.Shape[offset:]            t runs through arg1.Shape[axis_offset:]
3610    
3611      In the first case the the second dimension of arg0 and the length of arg1 must match and        In the first case the the second dimension of arg0 and the length of arg1 must match and  
3612      in the second case the two last dimensions of arg0 must match the shape of arg1.      in the second case the two last dimensions of arg0 must match the shape of arg1.
# Line 3467  def generalTensorProduct(arg0,arg1,offse Line 3623  def generalTensorProduct(arg0,arg1,offse
3623      # at this stage arg0 and arg0 are both numarray.NumArray or escript.Data or Symbols      # at this stage arg0 and arg0 are both numarray.NumArray or escript.Data or Symbols
3624      if isinstance(arg0,numarray.NumArray):      if isinstance(arg0,numarray.NumArray):
3625         if isinstance(arg1,Symbol):         if isinstance(arg1,Symbol):
3626             return GeneralTensorProduct_Symbol(arg0,arg1,offset)             return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
3627         else:         else:
3628             if not arg0.shape[arg0.rank-offset:]==arg1.shape[:offset]:             if not arg0.shape[arg0.rank-axis_offset:]==arg1.shape[:axis_offset]:
3629                 raise ValueError,"generalTensorProduct: dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)                 raise ValueError,"generalTensorProduct: dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
3630             arg0_c=arg0.copy()             arg0_c=arg0.copy()
3631             arg1_c=arg1.copy()             arg1_c=arg1.copy()
3632             sh0,sh1=arg0.shape,arg1.shape             sh0,sh1=arg0.shape,arg1.shape
3633             d0,d1,d01=1,1,1             d0,d1,d01=1,1,1
3634             for i in sh0[:arg0.rank-offset]: d0*=i             for i in sh0[:arg0.rank-axis_offset]: d0*=i
3635             for i in sh1[offset:]: d1*=i             for i in sh1[axis_offset:]: d1*=i
3636             for i in sh1[:offset]: d01*=i             for i in sh1[:axis_offset]: d01*=i
3637             arg0_c.resize((d0,d01))             arg0_c.resize((d0,d01))
3638             arg1_c.resize((d01,d1))             arg1_c.resize((d01,d1))
3639             out=numarray.zeros((d0,d1),numarray.Float)             out=numarray.zeros((d0,d1),numarray.Float)
3640             for i0 in range(d0):             for i0 in range(d0):
3641                      for i1 in range(d1):                      for i1 in range(d1):
3642                           out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1])                           out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1])
3643             out.resize(sh0[:arg0.rank-offset]+sh1[offset:])             out.resize(sh0[:arg0.rank-axis_offset]+sh1[axis_offset:])
3644             return out             return out
3645      elif isinstance(arg0,escript.Data):      elif isinstance(arg0,escript.Data):
3646         if isinstance(arg1,Symbol):         if isinstance(arg1,Symbol):
3647             return GeneralTensorProduct_Symbol(arg0,arg1,offset)             return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
3648         else:         else:
3649             return escript_generalTensorProduct(arg0,arg1,offset) # this calls has to be replaced by escript._generalTensorProduct(arg0,arg1,offset)             return escript_generalTensorProduct(arg0,arg1,axis_offset) # this calls has to be replaced by escript._generalTensorProduct(arg0,arg1,axis_offset)
3650      else:            else:      
3651         return GeneralTensorProduct_Symbol(arg0,arg1,offset)         return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
3652                                    
3653  class GeneralTensorProduct_Symbol(DependendSymbol):  class GeneralTensorProduct_Symbol(DependendSymbol):
3654     """     """
3655     Symbol representing the quotient of two arguments.     Symbol representing the quotient of two arguments.
3656     """     """
3657     def __init__(self,arg0,arg1,offset=0):     def __init__(self,arg0,arg1,axis_offset=0):
3658         """         """
3659         initialization of L{Symbol} representing the quotient of two arguments         initialization of L{Symbol} representing the quotient of two arguments
3660    
# Line 3511  class GeneralTensorProduct_Symbol(Depend Line 3667  class GeneralTensorProduct_Symbol(Depend
3667         """         """
3668         sh_arg0=pokeShape(arg0)         sh_arg0=pokeShape(arg0)
3669         sh_arg1=pokeShape(arg1)         sh_arg1=pokeShape(arg1)
3670         sh0=sh_arg0[:len(sh_arg0)-offset]         sh0=sh_arg0[:len(sh_arg0)-axis_offset]
3671         sh01=sh_arg0[len(sh_arg0)-offset:]         sh01=sh_arg0[len(sh_arg0)-axis_offset:]
3672         sh10=sh_arg1[:offset]         sh10=sh_arg1[:axis_offset]
3673         sh1=sh_arg1[offset:]         sh1=sh_arg1[axis_offset:]
3674         if not sh01==sh10:         if not sh01==sh10:
3675             raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)             raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
3676         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0+sh1,args=[arg0,arg1,offset])         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0+sh1,args=[arg0,arg1,axis_offset])
3677    
3678     def getMyCode(self,argstrs,format="escript"):     def getMyCode(self,argstrs,format="escript"):
3679        """        """
# Line 3532  class GeneralTensorProduct_Symbol(Depend Line 3688  class GeneralTensorProduct_Symbol(Depend
3688        @raise: NotImplementedError: if the requested format is not available        @raise: NotImplementedError: if the requested format is not available
3689        """        """
3690        if format=="escript" or format=="str" or format=="text":        if format=="escript" or format=="str" or format=="text":
3691           return "generalTensorProduct(%s,%s,offset=%s)"%(argstrs[0],argstrs[1],argstrs[2])           return "generalTensorProduct(%s,%s,axis_offset=%s)"%(argstrs[0],argstrs[1],argstrs[2])
3692        else:        else:
3693           raise NotImplementedError,"%s does not provide program code for format %s."%(str(self),format)           raise NotImplementedError,"%s does not provide program code for format %s."%(str(self),format)
3694    
# Line 3557  class GeneralTensorProduct_Symbol(Depend Line 3713  class GeneralTensorProduct_Symbol(Depend
3713           args=self.getSubstitutedArguments(argvals)           args=self.getSubstitutedArguments(argvals)
3714           return generalTensorProduct(args[0],args[1],args[2])           return generalTensorProduct(args[0],args[1],args[2])
3715    
3716  def escript_generalTensorProduct(arg0,arg1,offset): # this should be escript._generalTensorProduct  def escript_generalTensorProduct(arg0,arg1,axis_offset): # this should be escript._generalTensorProduct
3717      "arg0 and arg1 are both Data objects but not neccesrily on the same function space. they could be identical!!!"      "arg0 and arg1 are both Data objects but not neccesrily on the same function space. they could be identical!!!"
3718      # calculate the return shape:      # calculate the return shape:
3719      shape0=arg0.getShape()[:arg0.getRank()-offset]      shape0=arg0.getShape()[:arg0.getRank()-axis_offset]
3720      shape01=arg0.getShape()[arg0.getRank()-offset:]      shape01=arg0.getShape()[arg0.getRank()-axis_offset:]
3721      shape10=arg1.getShape()[:offset]      shape10=arg1.getShape()[:axis_offset]
3722      shape1=arg1.getShape()[offset:]      shape1=arg1.getShape()[axis_offset:]
3723      if not shape01==shape10:      if not shape01==shape10:
3724          raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)          raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
3725    
3726        # whatr function space should be used? (this here is not good!)
3727        fs=(escript.Scalar(0.,arg0.getFunctionSpace())+escript.Scalar(0.,arg1.getFunctionSpace())).getFunctionSpace()
3728      # create return value:      # create return value:
3729      out=escript.Data(0.,tuple(shape0+shape1),arg0.getFunctionSpace())      out=escript.Data(0.,tuple(shape0+shape1),fs)
3730      #      #
3731      s0=[[]]      s0=[[]]
3732      for k in shape0:      for k in shape0:
# Line 3591  def escript_generalTensorProduct(arg0,ar Line 3749  def escript_generalTensorProduct(arg0,ar
3749    
3750      for i0 in s0:      for i0 in s0:
3751         for i1 in s1:         for i1 in s1:
3752           s=escript.Scalar(0.,arg0.getFunctionSpace())           s=escript.Scalar(0.,fs)
3753           for i01 in s01:           for i01 in s01:
3754              s+=arg0.__getitem__(tuple(i0+i01))*arg1.__getitem__(tuple(i01+i1))              s+=arg0.__getitem__(tuple(i0+i01))*arg1.__getitem__(tuple(i01+i1))
3755           out.__setitem__(tuple(i0+i1),s)           out.__setitem__(tuple(i0+i1),s)
3756      return out      return out
3757    
3758    
3759  #=========================================================  #=========================================================
3760  #   some little helpers  #  functions dealing with spatial dependency
3761  #=========================================================  #=========================================================
3762  def grad(arg,where=None):  def grad(arg,where=None):
3763      """      """
3764      Returns the spatial gradient of arg at where.      Returns the spatial gradient of arg at where.
3765    
3766        If C{g} is the returned object, then
3767    
3768          - if C{arg} is rank 0 C{g[s]} is the derivative of C{arg} with respect to the C{s}-th spatial dimension.
3769          - if C{arg} is rank 1 C{g[i,s]} is the derivative of C{arg[i]} with respect to the C{s}-th spatial dimension.
3770          - if C{arg} is rank 2 C{g[i,j,s]} is the derivative of C{arg[i,j]} with respect to the C{s}-th spatial dimension.
3771          - if C{arg} is rank 3 C{g[i,j,k,s]} is the derivative of C{arg[i,j,k]} with respect to the C{s}-th spatial dimension.
3772    
3773      @param arg:   Data object representing the function which gradient      @param arg: function which gradient to be calculated. Its rank has to be less than 3.
3774                    to be calculated.      @type arg: L{escript.Data} or L{Symbol}
3775      @param where: FunctionSpace in which the gradient will be calculated.      @param where: FunctionSpace in which the gradient will be calculated.
3776                    If not present or C{None} an appropriate default is used.                    If not present or C{None} an appropriate default is used.
3777        @type where: C{None} or L{escript.FunctionSpace}
3778        @return: gradient of arg.
3779        @rtype:  L{escript.Data} or L{Symbol}
3780      """      """
3781      if isinstance(arg,Symbol):      if isinstance(arg,Symbol):
3782         return Grad_Symbol(arg,where)         return Grad_Symbol(arg,where)
# Line 3617  def grad(arg,where=None): Line 3786  def grad(arg,where=None):
3786         else:         else:
3787            return arg._grad(where)            return arg._grad(where)
3788      else:      else:
3789        raise TypeError,"grad: Unknown argument type."         raise TypeError,"grad: Unknown argument type."
3790    
3791    class Grad_Symbol(DependendSymbol):
3792       """
3793       L{Symbol} representing the result of the gradient operator
3794       """
3795       def __init__(self,arg,where=None):
3796          """
3797          initialization of gradient L{Symbol} with argument arg
3798          @param arg: argument of function
3799          @type arg: L{Symbol}.
3800          @param where: FunctionSpace in which the gradient will be calculated.
3801                      If not present or C{None} an appropriate default is used.
3802          @type where: C{None} or L{escript.FunctionSpace}
3803          """
3804          d=arg.getDim()
3805          if d==None:
3806             raise ValueError,"argument must have a spatial dimension"
3807          super(Grad_Symbol,self).__init__(args=[arg,where],shape=tuple(list(arg.getShape()).extend(d)),dim=d)
3808    
3809       def getMyCode(self,argstrs,format="escript"):
3810          """
3811          returns a program code that can be used to evaluate the symbol.
3812    
3813          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3814          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3815          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3816          @type format: C{str}
3817          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3818          @rtype: C{str}
3819          @raise: NotImplementedError: if the requested format is not available
3820          """
3821          if format=="escript" or format=="str"  or format=="text":
3822             return "grad(%s,where=%s)"%(argstrs[0],argstrs[1])
3823          else:
3824             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
3825    
3826       def substitute(self,argvals):
3827          """
3828          assigns new values to symbols in the definition of the symbol.
3829          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3830    
3831          @param argvals: new values assigned to symbols
3832          @type argvals: C{dict} with keywords of type L{Symbol}.
3833          @return: result of the substitution process. Operations are executed as much as possible.
3834          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3835          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3836          """
3837          if argvals.has_key(self):
3838             arg=argvals[self]
3839             if self.isAppropriateValue(arg):
3840                return arg
3841             else:
3842                raise TypeError,"%s: new value is not appropriate."%str(self)
3843          else:
3844             arg=self.getSubstitutedArguments(argvals)
3845             return grad(arg[0],where=arg[1])
3846    
3847       def diff(self,arg):
3848          """
3849          differential of this object
3850    
3851          @param arg: the derivative is calculated with respect to arg
3852          @type arg: L{escript.Symbol}
3853          @return: derivative with respect to C{arg}
3854          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3855          """
3856          if arg==self:
3857             return identity(self.getShape())
3858          else:
3859             return grad(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
3860    
3861  def integrate(arg,where=None):  def integrate(arg,where=None):
3862      """      """
3863      Return the integral if the function represented by Data object arg over      Return the integral of the function C{arg} over its domain. If C{where} is present C{arg} is interpolated to C{where}
3864      its domain.      before integration.
3865    
3866      @param arg:   Data object representing the function which is integrated.      @param arg:   the function which is integrated.
3867        @type arg: L{escript.Data} or L{Symbol}
3868      @param where: FunctionSpace in which the integral is calculated.      @param where: FunctionSpace in which the integral is calculated.
3869                    If not present or C{None} an appropriate default is used.                    If not present or C{None} an appropriate default is used.
3870        @type where: C{None} or L{escript.FunctionSpace}
3871        @return: integral of arg.
3872        @rtype:  C{float}, C{numarray.NumArray} or L{Symbol}
3873      """      """
3874      if isinstance(arg,numarray.NumArray):      if isinstance(arg,Symbol):
         if checkForZero(arg):  
            return arg  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,float):  
         if checkForZero(arg):  
            return arg  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,int):  
         if checkForZero(arg):  
            return float(arg)  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,Symbol):  
3875         return Integrate_Symbol(arg,where)         return Integrate_Symbol(arg,where)
3876      elif isinstance(arg,escript.Data):      elif isinstance(arg,escript.Data):
3877         if not where==None: arg=escript.Data(arg,where)         if not where==None: arg=escript.Data(arg,where)
# Line 3654  def integrate(arg,where=None): Line 3882  def integrate(arg,where=None):
3882      else:      else:
3883        raise TypeError,"integrate: Unknown argument type."        raise TypeError,"integrate: Unknown argument type."
3884    
3885    class Integrate_Symbol(DependendSymbol):
3886       """
3887       L{Symbol} representing the result of the spatial integration operator
3888       """
3889       def __init__(self,arg,where=None):
3890          """
3891          initialization of integration L{Symbol} with argument arg
3892          @param arg: argument of the integration
3893          @type arg: L{Symbol}.
3894          @param where: FunctionSpace in which the integration will be calculated.
3895                      If not present or C{None} an appropriate default is used.
3896          @type where: C{None} or L{escript.FunctionSpace}
3897          """
3898          super(Integrate_Symbol,self).__init__(args=[arg,where],shape=arg.getShape(),dim=arg.getDim())
3899    
3900       def getMyCode(self,argstrs,format="escript"):
3901          """
3902          returns a program code that can be used to evaluate the symbol.
3903    
3904          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3905          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3906          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3907          @type format: C{str}
3908          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3909          @rtype: C{str}
3910          @raise: NotImplementedError: if the requested format is not available
3911          """
3912          if format=="escript" or format=="str"  or format=="text":
3913             return "integrate(%s,where=%s)"%(argstrs[0],argstrs[1])
3914          else:
3915             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
3916    
3917       def substitute(self,argvals):
3918          """
3919          assigns new values to symbols in the definition of the symbol.
3920          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3921    
3922          @param argvals: new values assigned to symbols
3923          @type argvals: C{dict} with keywords of type L{Symbol}.
3924          @return: result of the substitution process. Operations are executed as much as possible.
3925          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3926          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3927          """
3928          if argvals.has_key(self):
3929             arg=argvals[self]
3930             if self.isAppropriateValue(arg):
3931                return arg
3932             else:
3933                raise TypeError,"%s: new value is not appropriate."%str(self)
3934          else:
3935             arg=self.getSubstitutedArguments(argvals)
3936             return integrate(arg[0],where=arg[1])
3937    
3938       def diff(self,arg):
3939          """
3940          differential of this object
3941    
3942          @param arg: the derivative is calculated with respect to arg
3943          @type arg: L{escript.Symbol}
3944          @return: derivative with respect to C{arg}
3945          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3946          """
3947          if arg==self:
3948             return identity(self.getShape())
3949          else:
3950             return integrate(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
3951    
3952    
3953  def interpolate(arg,where):  def interpolate(arg,where):
3954      """      """
3955      Interpolates the function into the FunctionSpace where.      interpolates the function into the FunctionSpace where.
3956    
3957      @param arg:    interpolant      @param arg: interpolant
3958      @param where:  FunctionSpace to interpolate to      @type arg: L{escript.Data} or L{Symbol}
3959        @param where: FunctionSpace to be interpolated to
3960        @type where: L{escript.FunctionSpace}
3961        @return: interpolated argument
3962        @rtype:  C{escript.Data} or L{Symbol}
3963      """      """
3964      if testForZero(arg):      if isinstance(arg,Symbol):
3965        return 0         return Interpolate_Symbol(arg,where)
     elif isinstance(arg,Symbol):  
        return Interpolated_Symbol(arg,where)  
3966      else:      else:
3967         return escript.Data(arg,where)         return escript.Data(arg,where)
3968    
3969    class Interpolate_Symbol(DependendSymbol):
3970       """
3971       L{Symbol} representing the result of the interpolation operator
3972       """
3973       def __init__(self,arg,where):
3974          """
3975          initialization of interpolation L{Symbol} with argument arg
3976          @param arg: argument of the interpolation
3977          @type arg: L{Symbol}.
3978          @param where: FunctionSpace into which the argument is interpolated.
3979          @type where: L{escript.FunctionSpace}
3980          """
3981          super(Interpolate_Symbol,self).__init__(args=[arg,where],shape=arg.getShape(),dim=arg.getDim())
3982    
3983       def getMyCode(self,argstrs,format="escript"):
3984          """
3985          returns a program code that can be used to evaluate the symbol.
3986    
3987          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3988          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3989          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3990          @type format: C{str}
3991          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3992          @rtype: C{str}
3993          @raise: NotImplementedError: if the requested format is not available
3994          """
3995          if format=="escript" or format=="str"  or format=="text":
3996             return "interpolate(%s,where=%s)"%(argstrs[0],argstrs[1])
3997          else:
3998             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
3999    
4000       def substitute(self,argvals):
4001          """
4002          assigns new values to symbols in the definition of the symbol.
4003          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
4004    
4005          @param argvals: new values assigned to symbols
4006          @type argvals: C{dict} with keywords of type L{Symbol}.
4007          @return: result of the substitution process. Operations are executed as much as possible.
4008          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
4009          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
4010          """
4011          if argvals.has_key(self):
4012             arg=argvals[self]
4013             if self.isAppropriateValue(arg):
4014                return arg
4015             else:
4016                raise TypeError,"%s: new value is not appropriate."%str(self)
4017          else:
4018             arg=self.getSubstitutedArguments(argvals)
4019             return interpolate(arg[0],where=arg[1])
4020    
4021       def diff(self,arg):
4022          """
4023          differential of this object
4024    
4025          @param arg: the derivative is calculated with respect to arg
4026          @type arg: L{escript.Symbol}
4027          @return: derivative with respect to C{arg}
4028          @rtype: L{Symbol} but other types such as L{escript.Data}, L{numarray.NumArray}  are possible.
4029          """
4030          if arg==self:
4031             return identity(self.getShape())
4032          else:
4033             return interpolate(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
4034    
4035    
4036  def div(arg,where=None):  def div(arg,where=None):
4037      """      """
4038      Returns the divergence of arg at where.      returns the divergence of arg at where.
4039    
4040      @param arg:   Data object representing the function which gradient to      @param arg: function which divergence to be calculated. Its shape has to be (d,) where d is the spatial dimension.
4041                    be calculated.      @type arg: L{escript.Data} or L{Symbol}
4042      @param where: FunctionSpace in which the gradient will be calculated.      @param where: FunctionSpace in which the divergence will be calculated.
4043                    If not present or C{None} an appropriate default is used.                    If not present or C{None} an appropriate default is used.
4044        @type where: C{None} or L{escript.FunctionSpace}
4045        @return: divergence of arg.
4046        @rtype:  L{escript.Data} or L{Symbol}
4047      """      """
4048      g=grad(arg,where)      if not arg.getShape()==(arg.getDim(),):
4049      return trace(g,axis0=g.getRank()-2,axis1=g.getRank()-1)        raise ValueError,"div: expected shape is (%s,)"%arg.getDim()
4050        return trace(grad(arg,where))
4051    
4052  def jump(arg):  def jump(arg,domain=None):
4053      """      """
4054      Returns the jump of arg across a continuity.      returns the jump of arg across the continuity of the domain
4055    
4056      @param arg:   Data object representing the function which gradient      @param arg: argument
4057                    to be calculated.      @type arg: L{escript.Data} or L{Symbol}
4058        @param domain: the domain where the discontinuity is located. If domain is not present or equal to C{None}
4059                       the domain of arg is used. If arg is a L{Symbol} the domain must be present.
4060        @type domain: C{None} or L{escript.Domain}
4061        @return: jump of arg
4062        @rtype:  L{escript.Data} or L{Symbol}
4063      """      """
4064      d=arg.getDomain()      if domain==None: domain=arg.getDomain()
4065      return arg.interpolate(escript.FunctionOnContactOne())-arg.interpolate(escript.FunctionOnContactZero())      return interpolate(arg,escript.FunctionOnContactOne(domain))-interpolate(arg,escript.FunctionOnContactZero(domain))
   
4066  #=============================  #=============================
4067  #  #
4068  # wrapper for various functions: if the argument has attribute the function name  # wrapper for various functions: if the argument has attribute the function name
# Line 3726  def transpose(arg,axis=None): Line 4099  def transpose(arg,axis=None):
4099      else:      else:
4100         return numarray.transpose(arg,axis=axis)         return numarray.transpose(arg,axis=axis)
4101    
 def trace(arg,axis0=0,axis1=1):  
     """  
     Return  
   
     @param arg:  
     """  
     if isinstance(arg,Symbol):  
        s=list(arg.getShape())          
        s=tuple(s[0:axis0]+s[axis0+1:axis1]+s[axis1+1:])  
        return Trace_Symbol(arg,axis0=axis0,axis1=axis1)  
     elif isinstance(arg,escript.Data):  
        # hack for trace  
        s=arg.getShape()  
        if s[axis0]!=s[axis1]:  
            raise ValueError,"illegal axis in trace"  
        out=escript.Scalar(0.,arg.getFunctionSpace())  
        for i in range(s[axis0]):  
           out+=arg[i,i]  
        return out  
        # end hack for trace  
     else:  
        return numarray.trace(arg,axis0=axis0,axis1=axis1)  
4102    
4103    
4104  def reorderComponents(arg,index):  def reorderComponents(arg,index):

Legend:
Removed from v.341  
changed lines
  Added in v.429

  ViewVC Help
Powered by ViewVC 1.1.26