/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Diff of /trunk/escript/py_src/util.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 414 by gross, Wed Jan 4 05:29:05 2006 UTC revision 429 by gross, Wed Jan 11 05:53:40 2006 UTC
# Line 2858  def length(arg): Line 2858  def length(arg):
2858     """     """
2859     return sqrt(inner(arg,arg))     return sqrt(inner(arg,arg))
2860    
2861    def trace(arg,axis_offset=0):
2862       """
2863       returns the trace of arg which the sum of arg[k,k] over k.
2864    
2865       @param arg: argument
2866       @type arg: L{escript.Data}, L{Symbol}, L{numarray.NumArray}.
2867       @param axis_offset: axis_offset to components to sum over. C{axis_offset} must be non-negative and less than the rank of arg +1. The dimensions on component
2868                      axis_offset and axis_offset+1 must be equal.
2869       @type axis_offset: C{int}
2870       @return: trace of arg. The rank of the returned object is minus 2 of the rank of arg.
2871       @rtype: L{escript.Data}, L{Symbol}, L{numarray.NumArray} depending on the type of arg.
2872       """
2873       if isinstance(arg,numarray.NumArray):
2874          sh=arg.shape
2875          if len(sh)<2:
2876            raise ValueError,"trace: rank of argument must be greater than 1"
2877          if axis_offset<0 or axis_offset>len(sh)-2:
2878            raise ValueError,"trace: axis_offset must be between 0 and %s"%len(sh)-2
2879          s1=1
2880          for i in range(axis_offset): s1*=sh[i]
2881          s2=1
2882          for i in range(axis_offset+2,len(sh)): s2*=sh[i]
2883          if not sh[axis_offset] == sh[axis_offset+1]:
2884            raise ValueError,"trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2885          arg_reshaped=numarray.reshape(arg,(s1,sh[axis_offset],sh[axis_offset],s2))
2886          out=numarray.zeros([s1,s2],numarray.Float)
2887          for i1 in range(s1):
2888            for i2 in range(s2):
2889                for j in range(sh[axis_offset]): out[i1,i2]+=arg_reshaped[i1,j,j,i2]
2890          out.resize(sh[:axis_offset]+sh[axis_offset+2:])
2891          return out
2892       elif isinstance(arg,escript.Data):
2893          return escript_trace(arg,axis_offset)
2894       elif isinstance(arg,float):
2895          raise TypeError,"trace: illegal argument type float."
2896       elif isinstance(arg,int):
2897          raise TypeError,"trace: illegal argument type int."
2898       elif isinstance(arg,Symbol):
2899          return Trace_Symbol(arg,axis_offset)
2900       else:
2901          raise TypeError,"trace: Unknown argument type."
2902    
2903    def escript_trace(arg,axis_offset): # this should be escript._trace
2904          "arg si a Data objects!!!"
2905          if arg.getRank()<2:
2906            raise ValueError,"escript_trace: rank of argument must be greater than 1"
2907          if axis_offset<0 or axis_offset>arg.getRank()-2:
2908            raise ValueError,"escript_trace: axis_offset must be between 0 and %s"%arg.getRank()-2
2909          s=list(arg.getShape())        
2910          if not s[axis_offset] == s[axis_offset+1]:
2911            raise ValueError,"escript_trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2912          out=escript.Data(0.,tuple(s[0:axis_offset]+s[axis_offset+2:]),arg.getFunctionSpace())
2913          if arg.getRank()==2:
2914             for i0 in range(s[0]):
2915                out+=arg[i0,i0]
2916          elif arg.getRank()==3:
2917             if axis_offset==0:
2918                for i0 in range(s[0]):
2919                      for i2 in range(s[2]):
2920                             out[i2]+=arg[i0,i0,i2]
2921             elif axis_offset==1:
2922                for i0 in range(s[0]):
2923                   for i1 in range(s[1]):
2924                             out[i0]+=arg[i0,i1,i1]
2925          elif arg.getRank()==4:
2926             if axis_offset==0:
2927                for i0 in range(s[0]):
2928                      for i2 in range(s[2]):
2929                         for i3 in range(s[3]):
2930                             out[i2,i3]+=arg[i0,i0,i2,i3]
2931             elif axis_offset==1:
2932                for i0 in range(s[0]):
2933                   for i1 in range(s[1]):
2934                         for i3 in range(s[3]):
2935                             out[i0,i3]+=arg[i0,i1,i1,i3]
2936             elif axis_offset==2:
2937                for i0 in range(s[0]):
2938                   for i1 in range(s[1]):
2939                      for i2 in range(s[2]):
2940                             out[i0,i1]+=arg[i0,i1,i2,i2]
2941          return out
2942    class Trace_Symbol(DependendSymbol):
2943       """
2944       L{Symbol} representing the result of the trace function
2945       """
2946       def __init__(self,arg,axis_offset=0):
2947          """
2948          initialization of trace L{Symbol} with argument arg
2949          @param arg: argument of function
2950          @type arg: L{Symbol}.
2951          @param axis_offset: axis_offset to components to sum over. C{axis_offset} must be non-negative and less than the rank of arg +1. The dimensions on component
2952                      axis_offset and axis_offset+1 must be equal.
2953          @type axis_offset: C{int}
2954          """
2955          if arg.getRank()<2:
2956            raise ValueError,"Trace_Symbol: rank of argument must be greater than 1"
2957          if axis_offset<0 or axis_offset>arg.getRank()-2:
2958            raise ValueError,"Trace_Symbol: axis_offset must be between 0 and %s"%arg.getRank()-2
2959          s=list(arg.getShape())        
2960          if not s[axis_offset] == s[axis_offset+1]:
2961            raise ValueError,"Trace_Symbol: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2962          super(Trace_Symbol,self).__init__(args=[arg,axis_offset],shape=tuple(s[0:axis_offset]+s[axis_offset+2:]),dim=arg.getDim())
2963    
2964       def getMyCode(self,argstrs,format="escript"):
2965          """
2966          returns a program code that can be used to evaluate the symbol.
2967    
2968          @param argstrs: gives for each argument a string representing the argument for the evaluation.
2969          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
2970          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
2971          @type format: C{str}
2972          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
2973          @rtype: C{str}
2974          @raise: NotImplementedError: if the requested format is not available
2975          """
2976          if format=="escript" or format=="str"  or format=="text":
2977             return "trace(%s,axis_offset=%s)"%(argstrs[0],argstrs[1])
2978          else:
2979             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
2980    
2981       def substitute(self,argvals):
2982          """
2983          assigns new values to symbols in the definition of the symbol.
2984          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
2985    
2986          @param argvals: new values assigned to symbols
2987          @type argvals: C{dict} with keywords of type L{Symbol}.
2988          @return: result of the substitution process. Operations are executed as much as possible.
2989          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
2990          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
2991          """
2992          if argvals.has_key(self):
2993             arg=argvals[self]
2994             if self.isAppropriateValue(arg):
2995                return arg
2996             else:
2997                raise TypeError,"%s: new value is not appropriate."%str(self)
2998          else:
2999             arg=self.getSubstitutedArguments(argvals)
3000             return trace(arg[0],axis_offset=arg[1])
3001    
3002       def diff(self,arg):
3003          """
3004          differential of this object
3005    
3006          @param arg: the derivative is calculated with respect to arg
3007          @type arg: L{escript.Symbol}
3008          @return: derivative with respect to C{arg}
3009          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3010          """
3011          if arg==self:
3012             return identity(self.getShape())
3013          else:
3014             return trace(self.getDifferentiatedArguments(arg)[0],axis_offset=self.getArgument()[1])
3015    
3016  #=======================================================  #=======================================================
3017  #  Binary operations:  #  Binary operations:
3018  #=======================================================  #=======================================================
# Line 3349  def inner(arg0,arg1): Line 3504  def inner(arg0,arg1):
3504      sh1=pokeShape(arg1)      sh1=pokeShape(arg1)
3505      if not sh0==sh1:      if not sh0==sh1:
3506          raise ValueError,"inner: shape of arguments does not match"          raise ValueError,"inner: shape of arguments does not match"
3507      return generalTensorProduct(arg0,arg1,offset=len(sh0))      return generalTensorProduct(arg0,arg1,axis_offset=len(sh0))
3508    
3509  def matrixmult(arg0,arg1):  def matrixmult(arg0,arg1):
3510      """      """
# Line 3377  def matrixmult(arg0,arg1): Line 3532  def matrixmult(arg0,arg1):
3532          raise ValueError,"first argument must have rank 2"          raise ValueError,"first argument must have rank 2"
3533      if not len(sh1)==2 and not len(sh1)==1:      if not len(sh1)==2 and not len(sh1)==1:
3534          raise ValueError,"second argument must have rank 1 or 2"          raise ValueError,"second argument must have rank 1 or 2"
3535      return generalTensorProduct(arg0,arg1,offset=1)      return generalTensorProduct(arg0,arg1,axis_offset=1)
3536    
3537  def outer(arg0,arg1):  def outer(arg0,arg1):
3538      """      """
# Line 3395  def outer(arg0,arg1): Line 3550  def outer(arg0,arg1):
3550      @return: the outer product of arg0 and arg1 at each data point      @return: the outer product of arg0 and arg1 at each data point
3551      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
3552      """      """
3553      return generalTensorProduct(arg0,arg1,offset=0)      return generalTensorProduct(arg0,arg1,axis_offset=0)
3554    
3555    
3556  def tensormult(arg0,arg1):  def tensormult(arg0,arg1):
# Line 3437  def tensormult(arg0,arg1): Line 3592  def tensormult(arg0,arg1):
3592      sh0=pokeShape(arg0)      sh0=pokeShape(arg0)
3593      sh1=pokeShape(arg1)      sh1=pokeShape(arg1)
3594      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):
3595         return generalTensorProduct(arg0,arg1,offset=1)         return generalTensorProduct(arg0,arg1,axis_offset=1)
3596      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):
3597         return generalTensorProduct(arg0,arg1,offset=2)         return generalTensorProduct(arg0,arg1,axis_offset=2)
3598      else:      else:
3599          raise ValueError,"tensormult: first argument must have rank 2 or 4"          raise ValueError,"tensormult: first argument must have rank 2 or 4"
3600    
3601  def generalTensorProduct(arg0,arg1,offset=0):  def generalTensorProduct(arg0,arg1,axis_offset=0):
3602      """      """
3603      generalized tensor product      generalized tensor product
3604    
3605      out[s,t]=S{Sigma}_r arg0[s,r]*arg1[r,t]      out[s,t]=S{Sigma}_r arg0[s,r]*arg1[r,t]
3606    
3607      where s runs through arg0.Shape[:arg0.Rank-offset]      where s runs through arg0.Shape[:arg0.Rank-axis_offset]
3608            r runs trough arg0.Shape[:offset]            r runs trough arg0.Shape[:axis_offset]
3609            t runs through arg1.Shape[offset:]            t runs through arg1.Shape[axis_offset:]
3610    
3611      In the first case the the second dimension of arg0 and the length of arg1 must match and        In the first case the the second dimension of arg0 and the length of arg1 must match and  
3612      in the second case the two last dimensions of arg0 must match the shape of arg1.      in the second case the two last dimensions of arg0 must match the shape of arg1.
# Line 3468  def generalTensorProduct(arg0,arg1,offse Line 3623  def generalTensorProduct(arg0,arg1,offse
3623      # at this stage arg0 and arg0 are both numarray.NumArray or escript.Data or Symbols      # at this stage arg0 and arg0 are both numarray.NumArray or escript.Data or Symbols
3624      if isinstance(arg0,numarray.NumArray):      if isinstance(arg0,numarray.NumArray):
3625         if isinstance(arg1,Symbol):         if isinstance(arg1,Symbol):
3626             return GeneralTensorProduct_Symbol(arg0,arg1,offset)             return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
3627         else:         else:
3628             if not arg0.shape[arg0.rank-offset:]==arg1.shape[:offset]:             if not arg0.shape[arg0.rank-axis_offset:]==arg1.shape[:axis_offset]:
3629                 raise ValueError,"generalTensorProduct: dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)                 raise ValueError,"generalTensorProduct: dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
3630             arg0_c=arg0.copy()             arg0_c=arg0.copy()
3631             arg1_c=arg1.copy()             arg1_c=arg1.copy()
3632             sh0,sh1=arg0.shape,arg1.shape             sh0,sh1=arg0.shape,arg1.shape
3633             d0,d1,d01=1,1,1             d0,d1,d01=1,1,1
3634             for i in sh0[:arg0.rank-offset]: d0*=i             for i in sh0[:arg0.rank-axis_offset]: d0*=i
3635             for i in sh1[offset:]: d1*=i             for i in sh1[axis_offset:]: d1*=i
3636             for i in sh1[:offset]: d01*=i             for i in sh1[:axis_offset]: d01*=i
3637             arg0_c.resize((d0,d01))             arg0_c.resize((d0,d01))
3638             arg1_c.resize((d01,d1))             arg1_c.resize((d01,d1))
3639             out=numarray.zeros((d0,d1),numarray.Float)             out=numarray.zeros((d0,d1),numarray.Float)
3640             for i0 in range(d0):             for i0 in range(d0):
3641                      for i1 in range(d1):                      for i1 in range(d1):
3642                           out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1])                           out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1])
3643             out.resize(sh0[:arg0.rank-offset]+sh1[offset:])             out.resize(sh0[:arg0.rank-axis_offset]+sh1[axis_offset:])
3644             return out             return out
3645      elif isinstance(arg0,escript.Data):      elif isinstance(arg0,escript.Data):
3646         if isinstance(arg1,Symbol):         if isinstance(arg1,Symbol):
3647             return GeneralTensorProduct_Symbol(arg0,arg1,offset)             return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
3648         else:         else:
3649             return escript_generalTensorProduct(arg0,arg1,offset) # this calls has to be replaced by escript._generalTensorProduct(arg0,arg1,offset)             return escript_generalTensorProduct(arg0,arg1,axis_offset) # this calls has to be replaced by escript._generalTensorProduct(arg0,arg1,axis_offset)
3650      else:            else:      
3651         return GeneralTensorProduct_Symbol(arg0,arg1,offset)         return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
3652                                    
3653  class GeneralTensorProduct_Symbol(DependendSymbol):  class GeneralTensorProduct_Symbol(DependendSymbol):
3654     """     """
3655     Symbol representing the quotient of two arguments.     Symbol representing the quotient of two arguments.
3656     """     """
3657     def __init__(self,arg0,arg1,offset=0):     def __init__(self,arg0,arg1,axis_offset=0):
3658         """         """
3659         initialization of L{Symbol} representing the quotient of two arguments         initialization of L{Symbol} representing the quotient of two arguments
3660    
# Line 3512  class GeneralTensorProduct_Symbol(Depend Line 3667  class GeneralTensorProduct_Symbol(Depend
3667         """         """
3668         sh_arg0=pokeShape(arg0)         sh_arg0=pokeShape(arg0)
3669         sh_arg1=pokeShape(arg1)         sh_arg1=pokeShape(arg1)
3670         sh0=sh_arg0[:len(sh_arg0)-offset]         sh0=sh_arg0[:len(sh_arg0)-axis_offset]
3671         sh01=sh_arg0[len(sh_arg0)-offset:]         sh01=sh_arg0[len(sh_arg0)-axis_offset:]
3672         sh10=sh_arg1[:offset]         sh10=sh_arg1[:axis_offset]
3673         sh1=sh_arg1[offset:]         sh1=sh_arg1[axis_offset:]
3674         if not sh01==sh10:         if not sh01==sh10:
3675             raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)             raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
3676         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0+sh1,args=[arg0,arg1,offset])         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0+sh1,args=[arg0,arg1,axis_offset])
3677    
3678     def getMyCode(self,argstrs,format="escript"):     def getMyCode(self,argstrs,format="escript"):
3679        """        """
# Line 3533  class GeneralTensorProduct_Symbol(Depend Line 3688  class GeneralTensorProduct_Symbol(Depend
3688        @raise: NotImplementedError: if the requested format is not available        @raise: NotImplementedError: if the requested format is not available
3689        """        """
3690        if format=="escript" or format=="str" or format=="text":        if format=="escript" or format=="str" or format=="text":
3691           return "generalTensorProduct(%s,%s,offset=%s)"%(argstrs[0],argstrs[1],argstrs[2])           return "generalTensorProduct(%s,%s,axis_offset=%s)"%(argstrs[0],argstrs[1],argstrs[2])
3692        else:        else:
3693           raise NotImplementedError,"%s does not provide program code for format %s."%(str(self),format)           raise NotImplementedError,"%s does not provide program code for format %s."%(str(self),format)
3694    
# Line 3558  class GeneralTensorProduct_Symbol(Depend Line 3713  class GeneralTensorProduct_Symbol(Depend
3713           args=self.getSubstitutedArguments(argvals)           args=self.getSubstitutedArguments(argvals)
3714           return generalTensorProduct(args[0],args[1],args[2])           return generalTensorProduct(args[0],args[1],args[2])
3715    
3716  def escript_generalTensorProduct(arg0,arg1,offset): # this should be escript._generalTensorProduct  def escript_generalTensorProduct(arg0,arg1,axis_offset): # this should be escript._generalTensorProduct
3717      "arg0 and arg1 are both Data objects but not neccesrily on the same function space. they could be identical!!!"      "arg0 and arg1 are both Data objects but not neccesrily on the same function space. they could be identical!!!"
3718      # calculate the return shape:      # calculate the return shape:
3719      shape0=arg0.getShape()[:arg0.getRank()-offset]      shape0=arg0.getShape()[:arg0.getRank()-axis_offset]
3720      shape01=arg0.getShape()[arg0.getRank()-offset:]      shape01=arg0.getShape()[arg0.getRank()-axis_offset:]
3721      shape10=arg1.getShape()[:offset]      shape10=arg1.getShape()[:axis_offset]
3722      shape1=arg1.getShape()[offset:]      shape1=arg1.getShape()[axis_offset:]
3723      if not shape01==shape10:      if not shape01==shape10:
3724          raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)          raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
3725    
3726      # whatr function space should be used? (this here is not good!)      # whatr function space should be used? (this here is not good!)
3727      fs=(escript.Scalar(0.,arg0.getFunctionSpace())+escript.Scalar(0.,arg1.getFunctionSpace())).getFunctionSpace()      fs=(escript.Scalar(0.,arg0.getFunctionSpace())+escript.Scalar(0.,arg1.getFunctionSpace())).getFunctionSpace()
# Line 3600  def escript_generalTensorProduct(arg0,ar Line 3755  def escript_generalTensorProduct(arg0,ar
3755           out.__setitem__(tuple(i0+i1),s)           out.__setitem__(tuple(i0+i1),s)
3756      return out      return out
3757    
3758    
3759  #=========================================================  #=========================================================
3760  #   some little helpers  #  functions dealing with spatial dependency
3761  #=========================================================  #=========================================================
3762  def grad(arg,where=None):  def grad(arg,where=None):
3763      """      """
3764      Returns the spatial gradient of arg at where.      Returns the spatial gradient of arg at where.
3765    
3766      @param arg:   Data object representing the function which gradient      If C{g} is the returned object, then
3767                    to be calculated.  
3768          - if C{arg} is rank 0 C{g[s]} is the derivative of C{arg} with respect to the C{s}-th spatial dimension.
3769          - if C{arg} is rank 1 C{g[i,s]} is the derivative of C{arg[i]} with respect to the C{s}-th spatial dimension.
3770          - if C{arg} is rank 2 C{g[i,j,s]} is the derivative of C{arg[i,j]} with respect to the C{s}-th spatial dimension.
3771          - if C{arg} is rank 3 C{g[i,j,k,s]} is the derivative of C{arg[i,j,k]} with respect to the C{s}-th spatial dimension.
3772    
3773        @param arg: function which gradient to be calculated. Its rank has to be less than 3.
3774        @type arg: L{escript.Data} or L{Symbol}
3775      @param where: FunctionSpace in which the gradient will be calculated.      @param where: FunctionSpace in which the gradient will be calculated.
3776                    If not present or C{None} an appropriate default is used.                    If not present or C{None} an appropriate default is used.
3777        @type where: C{None} or L{escript.FunctionSpace}
3778        @return: gradient of arg.
3779        @rtype:  L{escript.Data} or L{Symbol}
3780      """      """
3781      if isinstance(arg,Symbol):      if isinstance(arg,Symbol):
3782         return Grad_Symbol(arg,where)         return Grad_Symbol(arg,where)
# Line 3620  def grad(arg,where=None): Line 3786  def grad(arg,where=None):
3786         else:         else:
3787            return arg._grad(where)            return arg._grad(where)
3788      else:      else:
3789        raise TypeError,"grad: Unknown argument type."         raise TypeError,"grad: Unknown argument type."
3790    
3791    class Grad_Symbol(DependendSymbol):
3792       """
3793       L{Symbol} representing the result of the gradient operator
3794       """
3795       def __init__(self,arg,where=None):
3796          """
3797          initialization of gradient L{Symbol} with argument arg
3798          @param arg: argument of function
3799          @type arg: L{Symbol}.
3800          @param where: FunctionSpace in which the gradient will be calculated.
3801                      If not present or C{None} an appropriate default is used.
3802          @type where: C{None} or L{escript.FunctionSpace}
3803          """
3804          d=arg.getDim()
3805          if d==None:
3806             raise ValueError,"argument must have a spatial dimension"
3807          super(Grad_Symbol,self).__init__(args=[arg,where],shape=tuple(list(arg.getShape()).extend(d)),dim=d)
3808    
3809       def getMyCode(self,argstrs,format="escript"):
3810          """
3811          returns a program code that can be used to evaluate the symbol.
3812    
3813          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3814          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3815          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3816          @type format: C{str}
3817          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3818          @rtype: C{str}
3819          @raise: NotImplementedError: if the requested format is not available
3820          """
3821          if format=="escript" or format=="str"  or format=="text":
3822             return "grad(%s,where=%s)"%(argstrs[0],argstrs[1])
3823          else:
3824             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
3825    
3826       def substitute(self,argvals):
3827          """
3828          assigns new values to symbols in the definition of the symbol.
3829          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3830    
3831          @param argvals: new values assigned to symbols
3832          @type argvals: C{dict} with keywords of type L{Symbol}.
3833          @return: result of the substitution process. Operations are executed as much as possible.
3834          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3835          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3836          """
3837          if argvals.has_key(self):
3838             arg=argvals[self]
3839             if self.isAppropriateValue(arg):
3840                return arg
3841             else:
3842                raise TypeError,"%s: new value is not appropriate."%str(self)
3843          else:
3844             arg=self.getSubstitutedArguments(argvals)
3845             return grad(arg[0],where=arg[1])
3846    
3847       def diff(self,arg):
3848          """
3849          differential of this object
3850    
3851          @param arg: the derivative is calculated with respect to arg
3852          @type arg: L{escript.Symbol}
3853          @return: derivative with respect to C{arg}
3854          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3855          """
3856          if arg==self:
3857             return identity(self.getShape())
3858          else:
3859             return grad(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
3860    
3861  def integrate(arg,where=None):  def integrate(arg,where=None):
3862      """      """
3863      Return the integral if the function represented by Data object arg over      Return the integral of the function C{arg} over its domain. If C{where} is present C{arg} is interpolated to C{where}
3864      its domain.      before integration.
3865    
3866      @param arg:   Data object representing the function which is integrated.      @param arg:   the function which is integrated.
3867        @type arg: L{escript.Data} or L{Symbol}
3868      @param where: FunctionSpace in which the integral is calculated.      @param where: FunctionSpace in which the integral is calculated.
3869                    If not present or C{None} an appropriate default is used.                    If not present or C{None} an appropriate default is used.
3870        @type where: C{None} or L{escript.FunctionSpace}
3871        @return: integral of arg.
3872        @rtype:  C{float}, C{numarray.NumArray} or L{Symbol}
3873      """      """
3874      if isinstance(arg,numarray.NumArray):      if isinstance(arg,Symbol):
         if checkForZero(arg):  
            return arg  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,float):  
         if checkForZero(arg):  
            return arg  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,int):  
         if checkForZero(arg):  
            return float(arg)  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,Symbol):  
3875         return Integrate_Symbol(arg,where)         return Integrate_Symbol(arg,where)
3876      elif isinstance(arg,escript.Data):      elif isinstance(arg,escript.Data):
3877         if not where==None: arg=escript.Data(arg,where)         if not where==None: arg=escript.Data(arg,where)
# Line 3657  def integrate(arg,where=None): Line 3882  def integrate(arg,where=None):
3882      else:      else:
3883        raise TypeError,"integrate: Unknown argument type."        raise TypeError,"integrate: Unknown argument type."
3884    
3885    class Integrate_Symbol(DependendSymbol):
3886       """
3887       L{Symbol} representing the result of the spatial integration operator
3888       """
3889       def __init__(self,arg,where=None):
3890          """
3891          initialization of integration L{Symbol} with argument arg
3892          @param arg: argument of the integration
3893          @type arg: L{Symbol}.
3894          @param where: FunctionSpace in which the integration will be calculated.
3895                      If not present or C{None} an appropriate default is used.
3896          @type where: C{None} or L{escript.FunctionSpace}
3897          """
3898          super(Integrate_Symbol,self).__init__(args=[arg,where],shape=arg.getShape(),dim=arg.getDim())
3899    
3900       def getMyCode(self,argstrs,format="escript"):
3901          """
3902          returns a program code that can be used to evaluate the symbol.
3903    
3904          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3905          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3906          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3907          @type format: C{str}
3908          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3909          @rtype: C{str}
3910          @raise: NotImplementedError: if the requested format is not available
3911          """
3912          if format=="escript" or format=="str"  or format=="text":
3913             return "integrate(%s,where=%s)"%(argstrs[0],argstrs[1])
3914          else:
3915             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
3916    
3917       def substitute(self,argvals):
3918          """
3919          assigns new values to symbols in the definition of the symbol.
3920          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3921    
3922          @param argvals: new values assigned to symbols
3923          @type argvals: C{dict} with keywords of type L{Symbol}.
3924          @return: result of the substitution process. Operations are executed as much as possible.
3925          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3926          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3927          """
3928          if argvals.has_key(self):
3929             arg=argvals[self]
3930             if self.isAppropriateValue(arg):
3931                return arg
3932             else:
3933                raise TypeError,"%s: new value is not appropriate."%str(self)
3934          else:
3935             arg=self.getSubstitutedArguments(argvals)
3936             return integrate(arg[0],where=arg[1])
3937    
3938       def diff(self,arg):
3939          """
3940          differential of this object
3941    
3942          @param arg: the derivative is calculated with respect to arg
3943          @type arg: L{escript.Symbol}
3944          @return: derivative with respect to C{arg}
3945          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3946          """
3947          if arg==self:
3948             return identity(self.getShape())
3949          else:
3950             return integrate(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
3951    
3952    
3953  def interpolate(arg,where):  def interpolate(arg,where):
3954      """      """
3955      Interpolates the function into the FunctionSpace where.      interpolates the function into the FunctionSpace where.
3956    
3957      @param arg:    interpolant      @param arg: interpolant
3958      @param where:  FunctionSpace to interpolate to      @type arg: L{escript.Data} or L{Symbol}
3959        @param where: FunctionSpace to be interpolated to
3960        @type where: L{escript.FunctionSpace}
3961        @return: interpolated argument
3962        @rtype:  C{escript.Data} or L{Symbol}
3963      """      """
3964      if isinstance(arg,Symbol):      if isinstance(arg,Symbol):
3965         return Interpolated_Symbol(arg,where)         return Interpolate_Symbol(arg,where)
3966      else:      else:
3967         return escript.Data(arg,where)         return escript.Data(arg,where)
3968    
3969    class Interpolate_Symbol(DependendSymbol):
3970       """
3971       L{Symbol} representing the result of the interpolation operator
3972       """
3973       def __init__(self,arg,where):
3974          """
3975          initialization of interpolation L{Symbol} with argument arg
3976          @param arg: argument of the interpolation
3977          @type arg: L{Symbol}.
3978          @param where: FunctionSpace into which the argument is interpolated.
3979          @type where: L{escript.FunctionSpace}
3980          """
3981          super(Interpolate_Symbol,self).__init__(args=[arg,where],shape=arg.getShape(),dim=arg.getDim())
3982    
3983       def getMyCode(self,argstrs,format="escript"):
3984          """
3985          returns a program code that can be used to evaluate the symbol.
3986    
3987          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3988          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3989          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3990          @type format: C{str}
3991          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3992          @rtype: C{str}
3993          @raise: NotImplementedError: if the requested format is not available
3994          """
3995          if format=="escript" or format=="str"  or format=="text":
3996             return "interpolate(%s,where=%s)"%(argstrs[0],argstrs[1])
3997          else:
3998             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
3999    
4000       def substitute(self,argvals):
4001          """
4002          assigns new values to symbols in the definition of the symbol.
4003          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
4004    
4005          @param argvals: new values assigned to symbols
4006          @type argvals: C{dict} with keywords of type L{Symbol}.
4007          @return: result of the substitution process. Operations are executed as much as possible.
4008          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
4009          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
4010          """
4011          if argvals.has_key(self):
4012             arg=argvals[self]
4013             if self.isAppropriateValue(arg):
4014                return arg
4015             else:
4016                raise TypeError,"%s: new value is not appropriate."%str(self)
4017          else:
4018             arg=self.getSubstitutedArguments(argvals)
4019             return interpolate(arg[0],where=arg[1])
4020    
4021       def diff(self,arg):
4022          """
4023          differential of this object
4024    
4025          @param arg: the derivative is calculated with respect to arg
4026          @type arg: L{escript.Symbol}
4027          @return: derivative with respect to C{arg}
4028          @rtype: L{Symbol} but other types such as L{escript.Data}, L{numarray.NumArray}  are possible.
4029          """
4030          if arg==self:
4031             return identity(self.getShape())
4032          else:
4033             return interpolate(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
4034    
4035    
4036  def div(arg,where=None):  def div(arg,where=None):
4037      """      """
4038      Returns the divergence of arg at where.      returns the divergence of arg at where.
4039    
4040      @param arg:   Data object representing the function which gradient to      @param arg: function which divergence to be calculated. Its shape has to be (d,) where d is the spatial dimension.
4041                    be calculated.      @type arg: L{escript.Data} or L{Symbol}
4042      @param where: FunctionSpace in which the gradient will be calculated.      @param where: FunctionSpace in which the divergence will be calculated.
4043                    If not present or C{None} an appropriate default is used.                    If not present or C{None} an appropriate default is used.
4044        @type where: C{None} or L{escript.FunctionSpace}
4045        @return: divergence of arg.
4046        @rtype:  L{escript.Data} or L{Symbol}
4047      """      """
4048      g=grad(arg,where)      if not arg.getShape()==(arg.getDim(),):
4049      return trace(g,axis0=g.getRank()-2,axis1=g.getRank()-1)        raise ValueError,"div: expected shape is (%s,)"%arg.getDim()
4050        return trace(grad(arg,where))
4051    
4052  def jump(arg):  def jump(arg,domain=None):
4053      """      """
4054      Returns the jump of arg across a continuity.      returns the jump of arg across the continuity of the domain
4055    
4056      @param arg:   Data object representing the function which gradient      @param arg: argument
4057                    to be calculated.      @type arg: L{escript.Data} or L{Symbol}
4058        @param domain: the domain where the discontinuity is located. If domain is not present or equal to C{None}
4059                       the domain of arg is used. If arg is a L{Symbol} the domain must be present.
4060        @type domain: C{None} or L{escript.Domain}
4061        @return: jump of arg
4062        @rtype:  L{escript.Data} or L{Symbol}
4063      """      """
4064      d=arg.getDomain()      if domain==None: domain=arg.getDomain()
4065      return arg.interpolate(escript.FunctionOnContactOne(d))-arg.interpolate(escript.FunctionOnContactZero(d))      return interpolate(arg,escript.FunctionOnContactOne(domain))-interpolate(arg,escript.FunctionOnContactZero(domain))
   
4066  #=============================  #=============================
4067  #  #
4068  # wrapper for various functions: if the argument has attribute the function name  # wrapper for various functions: if the argument has attribute the function name
# Line 3727  def transpose(arg,axis=None): Line 4099  def transpose(arg,axis=None):
4099      else:      else:
4100         return numarray.transpose(arg,axis=axis)         return numarray.transpose(arg,axis=axis)
4101    
 def trace(arg,axis0=0,axis1=1):  
     """  
     Return  
   
     @param arg:  
     """  
     if isinstance(arg,Symbol):  
        s=list(arg.getShape())          
        s=tuple(s[0:axis0]+s[axis0+1:axis1]+s[axis1+1:])  
        return Trace_Symbol(arg,axis0=axis0,axis1=axis1)  
     elif isinstance(arg,escript.Data):  
        # hack for trace  
        s=arg.getShape()  
        if s[axis0]!=s[axis1]:  
            raise ValueError,"illegal axis in trace"  
        out=escript.Scalar(0.,arg.getFunctionSpace())  
        for i in range(s[axis0]):  
           out+=arg[i,i]  
        return out  
        # end hack for trace  
     else:  
        return numarray.trace(arg,axis0=axis0,axis1=axis1)  
4102    
4103    
4104  def reorderComponents(arg,index):  def reorderComponents(arg,index):

Legend:
Removed from v.414  
changed lines
  Added in v.429

  ViewVC Help
Powered by ViewVC 1.1.26