/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Diff of /trunk/escript/py_src/util.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 881 by gross, Thu Oct 26 02:54:47 2006 UTC revision 912 by gross, Wed Dec 6 03:29:49 2006 UTC
# Line 239  def inf(arg): Line 239  def inf(arg):
239  #=========================================================================  #=========================================================================
240  #   some little helpers  #   some little helpers
241  #=========================================================================  #=========================================================================
242  def pokeShape(arg):  def getRank(arg):
243        """
244        identifies the rank of its argument
245    
246        @param arg: a given object
247        @type arg: L{numarray.NumArray},L{escript.Data},C{float}, C{int}, C{Symbol}
248        @return: the rank of the argument
249        @rtype: C{int}
250        @raise TypeError: if type of arg cannot be processed
251        """
252    
253        if isinstance(arg,numarray.NumArray):
254            return arg.rank
255        elif isinstance(arg,escript.Data):
256            return arg.getRank()
257        elif isinstance(arg,float):
258            return 0
259        elif isinstance(arg,int):
260            return 0
261        elif isinstance(arg,Symbol):
262            return arg.getRank()
263        else:
264          raise TypeError,"getShape: cannot identify shape"
265    def getShape(arg):
266      """      """
267      identifies the shape of its argument      identifies the shape of its argument
268    
# Line 261  def pokeShape(arg): Line 284  def pokeShape(arg):
284      elif isinstance(arg,Symbol):      elif isinstance(arg,Symbol):
285          return arg.getShape()          return arg.getShape()
286      else:      else:
287        raise TypeError,"pokeShape: cannot identify shape"        raise TypeError,"getShape: cannot identify shape"
288    
289  def pokeDim(arg):  def pokeDim(arg):
290      """      """
# Line 284  def commonShape(arg0,arg1): Line 307  def commonShape(arg0,arg1):
307      """      """
308      returns a shape to which arg0 can be extendent from the right and arg1 can be extended from the left.      returns a shape to which arg0 can be extendent from the right and arg1 can be extended from the left.
309    
310      @param arg0: an object with a shape (see L{pokeShape})      @param arg0: an object with a shape (see L{getShape})
311      @param arg1: an object with a shape (see L{pokeShape})      @param arg1: an object with a shape (see L{getShape})
312      @return: the shape of arg0 or arg1 such that the left port equals the shape of arg0 and the right end equals the shape of arg1.      @return: the shape of arg0 or arg1 such that the left port equals the shape of arg0 and the right end equals the shape of arg1.
313      @rtype: C{tuple} of C{int}      @rtype: C{tuple} of C{int}
314      @raise ValueError: if no shape can be found.      @raise ValueError: if no shape can be found.
315      """      """
316      sh0=pokeShape(arg0)      sh0=getShape(arg0)
317      sh1=pokeShape(arg1)      sh1=getShape(arg1)
318      if len(sh0)<len(sh1):      if len(sh0)<len(sh1):
319         if not sh0==sh1[:len(sh0)]:         if not sh0==sh1[:len(sh0)]:
320               raise ValueError,"argument 0 cannot be extended to the shape of argument 1"               raise ValueError,"argument 0 cannot be extended to the shape of argument 1"
# Line 445  def matchShape(arg0,arg1): Line 468  def matchShape(arg0,arg1):
468      @rtype: C{tuple}      @rtype: C{tuple}
469      """      """
470      sh=commonShape(arg0,arg1)      sh=commonShape(arg0,arg1)
471      sh0=pokeShape(arg0)      sh0=getShape(arg0)
472      sh1=pokeShape(arg1)      sh1=getShape(arg1)
473      if len(sh0)<len(sh):      if len(sh0)<len(sh):
474         return outer(arg0,numarray.ones(sh[len(sh0):],numarray.Float64)),arg1         return outer(arg0,numarray.ones(sh[len(sh0):],numarray.Float64)),arg1
475      elif len(sh1)<len(sh):      elif len(sh1)<len(sh):
# Line 574  class Symbol(object): Line 597  class Symbol(object):
597            if isinstance(a,Symbol):            if isinstance(a,Symbol):
598               out.append(a.substitute(argvals))               out.append(a.substitute(argvals))
599            else:            else:
600                s=pokeShape(s)+arg.getShape()                s=getShape(s)+arg.getShape()
601                if len(s)>0:                if len(s)>0:
602                   out.append(numarray.zeros(s),numarray.Float64)                   out.append(numarray.zeros(s),numarray.Float64)
603                else:                else:
# Line 3660  class Add_Symbol(DependendSymbol): Line 3683  class Add_Symbol(DependendSymbol):
3683         @raise ValueError: if both arguments do not have the same shape.         @raise ValueError: if both arguments do not have the same shape.
3684         @note: if both arguments have a spatial dimension, they must equal.         @note: if both arguments have a spatial dimension, they must equal.
3685         """         """
3686         sh0=pokeShape(arg0)         sh0=getShape(arg0)
3687         sh1=pokeShape(arg1)         sh1=getShape(arg1)
3688         if not sh0==sh1:         if not sh0==sh1:
3689            raise ValueError,"Add_Symbol: shape of arguments must match"            raise ValueError,"Add_Symbol: shape of arguments must match"
3690         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0,args=[arg0,arg1])         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0,args=[arg0,arg1])
# Line 3735  def mult(arg0,arg1): Line 3758  def mult(arg0,arg1):
3758         """         """
3759         args=matchShape(arg0,arg1)         args=matchShape(arg0,arg1)
3760         if testForZero(args[0]) or testForZero(args[1]):         if testForZero(args[0]) or testForZero(args[1]):
3761            return numarray.zeros(pokeShape(args[0]),numarray.Float64)            return numarray.zeros(getShape(args[0]),numarray.Float64)
3762         else:         else:
3763            if isinstance(args[0],Symbol) or isinstance(args[1],Symbol) :            if isinstance(args[0],Symbol) or isinstance(args[1],Symbol) :
3764                return Mult_Symbol(args[0],args[1])                return Mult_Symbol(args[0],args[1])
# Line 3759  class Mult_Symbol(DependendSymbol): Line 3782  class Mult_Symbol(DependendSymbol):
3782         @raise ValueError: if both arguments do not have the same shape.         @raise ValueError: if both arguments do not have the same shape.
3783         @note: if both arguments have a spatial dimension, they must equal.         @note: if both arguments have a spatial dimension, they must equal.
3784         """         """
3785         sh0=pokeShape(arg0)         sh0=getShape(arg0)
3786         sh1=pokeShape(arg1)         sh1=getShape(arg1)
3787         if not sh0==sh1:         if not sh0==sh1:
3788            raise ValueError,"Mult_Symbol: shape of arguments must match"            raise ValueError,"Mult_Symbol: shape of arguments must match"
3789         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0,args=[arg0,arg1])         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0,args=[arg0,arg1])
# Line 3835  def quotient(arg0,arg1): Line 3858  def quotient(arg0,arg1):
3858         """         """
3859         args=matchShape(arg0,arg1)         args=matchShape(arg0,arg1)
3860         if testForZero(args[0]):         if testForZero(args[0]):
3861            return numarray.zeros(pokeShape(args[0]),numarray.Float64)            return numarray.zeros(getShape(args[0]),numarray.Float64)
3862         elif isinstance(args[0],Symbol):         elif isinstance(args[0],Symbol):
3863            if isinstance(args[1],Symbol):            if isinstance(args[1],Symbol):
3864               return Quotient_Symbol(args[0],args[1])               return Quotient_Symbol(args[0],args[1])
# Line 3864  class Quotient_Symbol(DependendSymbol): Line 3887  class Quotient_Symbol(DependendSymbol):
3887         @raise ValueError: if both arguments do not have the same shape.         @raise ValueError: if both arguments do not have the same shape.
3888         @note: if both arguments have a spatial dimension, they must equal.         @note: if both arguments have a spatial dimension, they must equal.
3889         """         """
3890         sh0=pokeShape(arg0)         sh0=getShape(arg0)
3891         sh1=pokeShape(arg1)         sh1=getShape(arg1)
3892         if not sh0==sh1:         if not sh0==sh1:
3893            raise ValueError,"Quotient_Symbol: shape of arguments must match"            raise ValueError,"Quotient_Symbol: shape of arguments must match"
3894         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0,args=[arg0,arg1])         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0,args=[arg0,arg1])
# Line 3941  def power(arg0,arg1): Line 3964  def power(arg0,arg1):
3964         """         """
3965         args=matchShape(arg0,arg1)         args=matchShape(arg0,arg1)
3966         if testForZero(args[0]):         if testForZero(args[0]):
3967            return numarray.zeros(pokeShape(args[0]),numarray.Float64)            return numarray.zeros(getShape(args[0]),numarray.Float64)
3968         elif testForZero(args[1]):         elif testForZero(args[1]):
3969            return numarray.ones(pokeShape(args[1]),numarray.Float64)            return numarray.ones(getShape(args[1]),numarray.Float64)
3970         elif isinstance(args[0],Symbol) or isinstance(args[1],Symbol):         elif isinstance(args[0],Symbol) or isinstance(args[1],Symbol):
3971            return Power_Symbol(args[0],args[1])            return Power_Symbol(args[0],args[1])
3972         elif isinstance(args[0],numarray.NumArray) and not isinstance(args[1],numarray.NumArray):         elif isinstance(args[0],numarray.NumArray) and not isinstance(args[1],numarray.NumArray):
# Line 3966  class Power_Symbol(DependendSymbol): Line 3989  class Power_Symbol(DependendSymbol):
3989         @raise ValueError: if both arguments do not have the same shape.         @raise ValueError: if both arguments do not have the same shape.
3990         @note: if both arguments have a spatial dimension, they must equal.         @note: if both arguments have a spatial dimension, they must equal.
3991         """         """
3992         sh0=pokeShape(arg0)         sh0=getShape(arg0)
3993         sh1=pokeShape(arg1)         sh1=getShape(arg1)
3994         if not sh0==sh1:         if not sh0==sh1:
3995            raise ValueError,"Power_Symbol: shape of arguments must match"            raise ValueError,"Power_Symbol: shape of arguments must match"
3996         d0=pokeDim(arg0)         d0=pokeDim(arg0)
# Line 4112  def inner(arg0,arg1): Line 4135  def inner(arg0,arg1):
4135      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol}, C{float} depending on the input      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol}, C{float} depending on the input
4136      @raise ValueError: if the shapes of the arguments are not identical      @raise ValueError: if the shapes of the arguments are not identical
4137      """      """
4138      sh0=pokeShape(arg0)      sh0=getShape(arg0)
4139      sh1=pokeShape(arg1)      sh1=getShape(arg1)
4140      if not sh0==sh1:      if not sh0==sh1:
4141          raise ValueError,"inner: shape of arguments does not match"          raise ValueError,"inner: shape of arguments does not match"
4142      return generalTensorProduct(arg0,arg1,axis_offset=len(sh0))      return generalTensorProduct(arg0,arg1,axis_offset=len(sh0))
# Line 4164  def matrix_mult(arg0,arg1): Line 4187  def matrix_mult(arg0,arg1):
4187      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
4188      @raise ValueError: if the shapes of the arguments are not appropriate      @raise ValueError: if the shapes of the arguments are not appropriate
4189      """      """
4190      sh0=pokeShape(arg0)      sh0=getShape(arg0)
4191      sh1=pokeShape(arg1)      sh1=getShape(arg1)
4192      if not len(sh0)==2 :      if not len(sh0)==2 :
4193          raise ValueError,"first argument must have rank 2"          raise ValueError,"first argument must have rank 2"
4194      if not len(sh1)==2 and not len(sh1)==1:      if not len(sh1)==2 and not len(sh1)==1:
# Line 4213  def tensor_mult(arg0,arg1): Line 4236  def tensor_mult(arg0,arg1):
4236      @return: the tensor product of arg0 and arg1 at each data point      @return: the tensor product of arg0 and arg1 at each data point
4237      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
4238      """      """
4239      sh0=pokeShape(arg0)      sh0=getShape(arg0)
4240      sh1=pokeShape(arg1)      sh1=getShape(arg1)
4241      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):
4242         return generalTensorProduct(arg0,arg1,axis_offset=1)         return generalTensorProduct(arg0,arg1,axis_offset=1)
4243      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):
# Line 4288  class GeneralTensorProduct_Symbol(Depend Line 4311  class GeneralTensorProduct_Symbol(Depend
4311         @raise ValueError: illegal dimension         @raise ValueError: illegal dimension
4312         @note: if both arguments have a spatial dimension, they must equal.         @note: if both arguments have a spatial dimension, they must equal.
4313         """         """
4314         sh_arg0=pokeShape(arg0)         sh_arg0=getShape(arg0)
4315         sh_arg1=pokeShape(arg1)         sh_arg1=getShape(arg1)
4316         sh0=sh_arg0[:len(sh_arg0)-axis_offset]         sh0=sh_arg0[:len(sh_arg0)-axis_offset]
4317         sh01=sh_arg0[len(sh_arg0)-axis_offset:]         sh01=sh_arg0[len(sh_arg0)-axis_offset:]
4318         sh10=sh_arg1[:axis_offset]         sh10=sh_arg1[:axis_offset]
# Line 4362  def transposed_matrix_mult(arg0,arg1): Line 4385  def transposed_matrix_mult(arg0,arg1):
4385      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
4386      @raise ValueError: if the shapes of the arguments are not appropriate      @raise ValueError: if the shapes of the arguments are not appropriate
4387      """      """
4388      sh0=pokeShape(arg0)      sh0=getShape(arg0)
4389      sh1=pokeShape(arg1)      sh1=getShape(arg1)
4390      if not len(sh0)==2 :      if not len(sh0)==2 :
4391          raise ValueError,"first argument must have rank 2"          raise ValueError,"first argument must have rank 2"
4392      if not len(sh1)==2 and not len(sh1)==1:      if not len(sh1)==2 and not len(sh1)==1:
# Line 4407  def transposed_tensor_mult(arg0,arg1): Line 4430  def transposed_tensor_mult(arg0,arg1):
4430      @return: the tensor product of tarnsposed of arg0 and arg1 at each data point      @return: the tensor product of tarnsposed of arg0 and arg1 at each data point
4431      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
4432      """      """
4433      sh0=pokeShape(arg0)      sh0=getShape(arg0)
4434      sh1=pokeShape(arg1)      sh1=getShape(arg1)
4435      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):
4436         return generalTransposedTensorProduct(arg0,arg1,axis_offset=1)         return generalTransposedTensorProduct(arg0,arg1,axis_offset=1)
4437      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):
# Line 4485  class GeneralTransposedTensorProduct_Sym Line 4508  class GeneralTransposedTensorProduct_Sym
4508         @raise ValueError: inconsistent dimensions of arguments.         @raise ValueError: inconsistent dimensions of arguments.
4509         @note: if both arguments have a spatial dimension, they must equal.         @note: if both arguments have a spatial dimension, they must equal.
4510         """         """
4511         sh_arg0=pokeShape(arg0)         sh_arg0=getShape(arg0)
4512         sh_arg1=pokeShape(arg1)         sh_arg1=getShape(arg1)
4513         sh01=sh_arg0[:axis_offset]         sh01=sh_arg0[:axis_offset]
4514         sh10=sh_arg1[:axis_offset]         sh10=sh_arg1[:axis_offset]
4515         sh0=sh_arg0[axis_offset:]         sh0=sh_arg0[axis_offset:]
# Line 4555  def matrix_transposed_mult(arg0,arg1): Line 4578  def matrix_transposed_mult(arg0,arg1):
4578      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
4579      @raise ValueError: if the shapes of the arguments are not appropriate      @raise ValueError: if the shapes of the arguments are not appropriate
4580      """      """
4581      sh0=pokeShape(arg0)      sh0=getShape(arg0)
4582      sh1=pokeShape(arg1)      sh1=getShape(arg1)
4583      if not len(sh0)==2 :      if not len(sh0)==2 :
4584          raise ValueError,"first argument must have rank 2"          raise ValueError,"first argument must have rank 2"
4585      if not len(sh1)==2 and not len(sh1)==1:      if not len(sh1)==2 and not len(sh1)==1:
# Line 4591  def tensor_transposed_mult(arg0,arg1): Line 4614  def tensor_transposed_mult(arg0,arg1):
4614      @return: the tensor product of tarnsposed of arg0 and arg1 at each data point      @return: the tensor product of tarnsposed of arg0 and arg1 at each data point
4615      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
4616      """      """
4617      sh0=pokeShape(arg0)      sh0=getShape(arg0)
4618      sh1=pokeShape(arg1)      sh1=getShape(arg1)
4619      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):
4620         return generalTensorTransposedProduct(arg0,arg1,axis_offset=1)         return generalTensorTransposedProduct(arg0,arg1,axis_offset=1)
4621      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):
# Line 4669  class GeneralTensorTransposedProduct_Sym Line 4692  class GeneralTensorTransposedProduct_Sym
4692         @raise ValueError: inconsistent dimensions of arguments.         @raise ValueError: inconsistent dimensions of arguments.
4693         @note: if both arguments have a spatial dimension, they must equal.         @note: if both arguments have a spatial dimension, they must equal.
4694         """         """
4695         sh_arg0=pokeShape(arg0)         sh_arg0=getShape(arg0)
4696         sh_arg1=pokeShape(arg1)         sh_arg1=getShape(arg1)
4697         sh0=sh_arg0[:len(sh_arg0)-axis_offset]         sh0=sh_arg0[:len(sh_arg0)-axis_offset]
4698         sh01=sh_arg0[len(sh_arg0)-axis_offset:]         sh01=sh_arg0[len(sh_arg0)-axis_offset:]
4699         sh10=sh_arg1[len(sh_arg1)-axis_offset:]         sh10=sh_arg1[len(sh_arg1)-axis_offset:]

Legend:
Removed from v.881  
changed lines
  Added in v.912

  ViewVC Help
Powered by ViewVC 1.1.26