# Diff of /trunk/escript/py_src/util.py

revision 429 by gross, Wed Jan 11 05:53:40 2006 UTC revision 433 by gross, Tue Jan 17 23:54:38 2006 UTC
# Line 30  __date__="\$Date\$" Line 30  __date__="\$Date\$"
30
31  import math  import math
32  import numarray  import numarray
33    import numarray.linear_algebra
34  import escript  import escript
35  import os  import os
36
# Line 3013  class Trace_Symbol(DependendSymbol): Line 3014  class Trace_Symbol(DependendSymbol):
3014        else:        else:
3015           return trace(self.getDifferentiatedArguments(arg)[0],axis_offset=self.getArgument()[1])           return trace(self.getDifferentiatedArguments(arg)[0],axis_offset=self.getArgument()[1])
3016
3017    def inverse(arg):
3018        """
3019        returns the inverse of the square matrix arg.
3020
3021        @param arg: square matrix. Must have rank 2 and the first and second dimension must be equal
3022        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}
3023        @return: inverse arg_inv of the argument. It will be matrixmul(inverse(arg),arg) almost equal to kronecker(arg.getShape()[0])
3024        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
3025        """
3026        if isinstance(arg,numarray.NumArray):
3027          return numarray.linear_algebra.inverse(arg)
3028        elif isinstance(arg,escript.Data):
3029          return escript_inverse(arg)
3030        elif isinstance(arg,float):
3031          return 1./arg
3032        elif isinstance(arg,int):
3033          return 1./float(arg)
3034        elif isinstance(arg,Symbol):
3035          return Inverse_Symbol(arg)
3036        else:
3037          raise TypeError,"inverse: Unknown argument type."
3038
3039    def escript_inverse(arg): # this should be escript._inverse and use LAPACK
3040          "arg is a Data objects!!!"
3041          if not arg.getRank()==2:
3042            raise ValueError,"escript_inverse: argument must have rank 2"
3043          s=arg.getShape()
3044          if not s[0] == s[1]:
3045            raise ValueError,"escript_inverse: argument must be a square matrix."
3046          out=escript.Data(0.,s,arg.getFunctionSpace())
3047          if s[0]==1:
3048              if inf(abs(arg[0,0]))==0: # in c this should be done point wise as abs(arg[0,0](i))<=0.
3049                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3050              out[0,0]=1./arg[0,0]
3051          elif s[0]==2:
3052              A11=arg[0,0]
3053              A12=arg[0,1]
3054              A21=arg[1,0]
3055              A22=arg[1,1]
3056              D = A11*A22-A12*A21
3057              if inf(abs(D))==0: # in c this should be done point wise as abs(D(i))<=0.
3058                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3059              D=1./D
3060              out[0,0]= A22*D
3061              out[1,0]=-A21*D
3062              out[0,1]=-A12*D
3063              out[1,1]= A11*D
3064          elif s[0]==3:
3065              A11=arg[0,0]
3066              A21=arg[1,0]
3067              A31=arg[2,0]
3068              A12=arg[0,1]
3069              A22=arg[1,1]
3070              A32=arg[2,1]
3071              A13=arg[0,2]
3072              A23=arg[1,2]
3073              A33=arg[2,2]
3074              D  =  A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
3075              if inf(abs(D))==0: # in c this should be done point wise as abs(D(i))<=0.
3076                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3077              D=1./D
3078              out[0,0]=(A22*A33-A23*A32)*D
3079              out[1,0]=(A31*A23-A21*A33)*D
3080              out[2,0]=(A21*A32-A31*A22)*D
3081              out[0,1]=(A13*A32-A12*A33)*D
3082              out[1,1]=(A11*A33-A31*A13)*D
3083              out[2,1]=(A12*A31-A11*A32)*D
3084              out[0,2]=(A12*A23-A13*A22)*D
3085              out[1,2]=(A13*A21-A11*A23)*D
3086              out[2,2]=(A11*A22-A12*A21)*D
3087          else:
3088             raise TypeError,"escript_inverse: only matrix dimensions 1,2,3 are supported right now."
3089          return out
3090
3091    class Inverse_Symbol(DependendSymbol):
3092       """
3093       L{Symbol} representing the result of the inverse function
3094       """
3095       def __init__(self,arg):
3096          """
3097          initialization of inverse L{Symbol} with argument arg
3098          @param arg: argument of function
3099          @type arg: L{Symbol}.
3100          """
3101          if not arg.getRank()==2:
3102            raise ValueError,"Inverse_Symbol:: argument must have rank 2"
3103          s=arg.getShape()
3104          if not s[0] == s[1]:
3105            raise ValueError,"Inverse_Symbol:: argument must be a square matrix."
3106          super(Inverse_Symbol,self).__init__(args=[arg],shape=s,dim=arg.getDim())
3107
3108       def getMyCode(self,argstrs,format="escript"):
3109          """
3110          returns a program code that can be used to evaluate the symbol.
3111
3112          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3113          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3114          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3115          @type format: C{str}
3116          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3117          @rtype: C{str}
3118          @raise: NotImplementedError: if the requested format is not available
3119          """
3120          if format=="escript" or format=="str"  or format=="text":
3121             return "inverse(%s)"%argstrs[0]
3122          else:
3123             raise NotImplementedError,"Inverse_Symbol does not provide program code for format %s."%format
3124
3125       def substitute(self,argvals):
3126          """
3127          assigns new values to symbols in the definition of the symbol.
3128          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3129
3130          @param argvals: new values assigned to symbols
3131          @type argvals: C{dict} with keywords of type L{Symbol}.
3132          @return: result of the substitution process. Operations are executed as much as possible.
3133          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3134          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3135          """
3136          if argvals.has_key(self):
3137             arg=argvals[self]
3138             if self.isAppropriateValue(arg):
3139                return arg
3140             else:
3141                raise TypeError,"%s: new value is not appropriate."%str(self)
3142          else:
3143             arg=self.getSubstitutedArguments(argvals)
3144             return inverse(arg[0])
3145
3146       def diff(self,arg):
3147          """
3148          differential of this object
3149
3150          @param arg: the derivative is calculated with respect to arg
3151          @type arg: L{escript.Symbol}
3152          @return: derivative with respect to C{arg}
3153          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3154          """
3155          if arg==self:
3156             return identity(self.getShape())
3157          else:
3158             return -matrixmult(matrixmult(self,self.getDifferentiatedArguments(arg)[0]),self)
3159  #=======================================================  #=======================================================
3160  #  Binary operations:  #  Binary operations:
3161  #=======================================================  #=======================================================
# Line 4045  def div(arg,where=None): Line 4188  def div(arg,where=None):
4188      @return: divergence of arg.      @return: divergence of arg.
4189      @rtype:  L{escript.Data} or L{Symbol}      @rtype:  L{escript.Data} or L{Symbol}
4190      """      """
4191      if not arg.getShape()==(arg.getDim(),):      if not arg.getShape()==(arg.getDomain().getDim(),):
4192        raise ValueError,"div: expected shape is (%s,)"%arg.getDim()        raise ValueError,"div: expected shape is (%s,)"%arg.getDomain().getDim()