/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Diff of /trunk/escript/py_src/util.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 429 by gross, Wed Jan 11 05:53:40 2006 UTC revision 614 by elspeth, Wed Mar 22 01:37:07 2006 UTC
# Line 1  Line 1 
1  # $Id$  # $Id$
 #  
 #      COPYRIGHT ACcESS 2004 -  All Rights Reserved  
 #  
 #   This software is the property of ACcESS.  No part of this code  
 #   may be copied in any form or by any means without the expressed written  
 #   consent of ACcESS.  Copying, use or modification of this software  
 #   by any unauthorised person is illegal unless that  
 #   person has a software license agreement with ACcESS.  
 #  
2    
3  """  """
4  Utility functions for escript  Utility functions for escript
# Line 15  Utility functions for escript Line 6  Utility functions for escript
6  @remark:  This module is under construction and is still tested!!!  @remark:  This module is under construction and is still tested!!!
7    
8  @var __author__: name of author  @var __author__: name of author
9  @var __licence__: licence agreement  @var __license__: licence agreement
10  @var __url__: url entry point on documentation  @var __url__: url entry point on documentation
11  @var __version__: version  @var __version__: version
12  @var __date__: date of the version  @var __date__: date of the version
13  """  """
14                                                                                                                                                                                                                                                                                                                                                                                                            
15  __author__="Lutz Gross, l.gross@uq.edu.au"  __author__="Lutz Gross, l.gross@uq.edu.au"
16  __licence__="contact: esys@access.uq.edu.au"  __copyright__="""  Copyright (c) 2006 by ACcESS MNRF
17                        http://www.access.edu.au
18                    Primary Business: Queensland, Australia"""
19    __license__="""Licensed under the Open Software License version 3.0
20                 http://www.opensource.org/licenses/osl-3.0.php"""
21  __url__="http://www.iservo.edu.au/esys/escript"  __url__="http://www.iservo.edu.au/esys/escript"
22  __version__="$Revision$"  __version__="$Revision$"
23  __date__="$Date$"  __date__="$Date$"
# Line 33  import numarray Line 28  import numarray
28  import escript  import escript
29  import os  import os
30    
 # missing tests:  
   
 # def pokeShape(arg):  
 # def pokeDim(arg):  
 # def commonShape(arg0,arg1):  
 # def commonDim(*args):  
 # def testForZero(arg):  
 # def matchType(arg0=0.,arg1=0.):  
 # def matchShape(arg0,arg1):  
   
 # def transpose(arg,axis=None):  
 # def trace(arg,axis0=0,axis1=1):  
 # def reorderComponents(arg,index):  
   
 # def integrate(arg,where=None):  
 # def interpolate(arg,where):  
 # def div(arg,where=None):  
 # def grad(arg,where=None):  
   
 #  
 # slicing: get  
 #          set  
 #  
 # and derivatives  
   
31  #=========================================================  #=========================================================
32  #   some helpers:  #   some helpers:
33  #=========================================================  #=========================================================
# Line 122  def kronecker(d=3): Line 92  def kronecker(d=3):
92     return the kronecker S{delta}-symbol     return the kronecker S{delta}-symbol
93    
94     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
95     @type d: C{int} or any object with a C{getDim} method     @type d: C{int}, L{escript.Domain} or L{escript.FunctionSpace}
96     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise
97     @rtype d: L{numarray.NumArray} of rank 2.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 2.
    @remark: the function is identical L{identity}  
98     """     """
99     return identityTensor(d)     return identityTensor(d)
100    
# Line 140  def identity(shape=()): Line 109  def identity(shape=()):
109     @raise ValueError: if len(shape)>2.     @raise ValueError: if len(shape)>2.
110     """     """
111     if len(shape)>0:     if len(shape)>0:
112        out=numarray.zeros(shape+shape,numarray.Float)        out=numarray.zeros(shape+shape,numarray.Float64)
113        if len(shape)==1:        if len(shape)==1:
114            for i0 in range(shape[0]):            for i0 in range(shape[0]):
115               out[i0,i0]=1.               out[i0,i0]=1.
   
116        elif len(shape)==2:        elif len(shape)==2:
117            for i0 in range(shape[0]):            for i0 in range(shape[0]):
118               for i1 in range(shape[1]):               for i1 in range(shape[1]):
# Line 160  def identityTensor(d=3): Line 128  def identityTensor(d=3):
128     return the dxd identity matrix     return the dxd identity matrix
129    
130     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
131     @type d: C{int} or any object with a C{getDim} method     @type d: C{int}, L{escript.Domain} or L{escript.FunctionSpace}
132     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise
133     @rtype: L{numarray.NumArray} of rank 2.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 2
134     """     """
135     if hasattr(d,"getDim"):     if isinstance(d,escript.FunctionSpace):
136        d=d.getDim()         return escript.Data(identity((d.getDim(),)),d)
137     return identity(shape=(d,))     elif isinstance(d,escript.Domain):
138           return identity((d.getDim(),))
139       else:
140           return identity((d,))
141    
142  def identityTensor4(d=3):  def identityTensor4(d=3):
143     """     """
# Line 175  def identityTensor4(d=3): Line 146  def identityTensor4(d=3):
146     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
147     @type d: C{int} or any object with a C{getDim} method     @type d: C{int} or any object with a C{getDim} method
148     @return: the object u of rank 4 with M{u[i,j,k,l]=1} for M{i=k and j=l} and M{u[i,j,k,l]=0} otherwise     @return: the object u of rank 4 with M{u[i,j,k,l]=1} for M{i=k and j=l} and M{u[i,j,k,l]=0} otherwise
149     @rtype: L{numarray.NumArray} of rank 4.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 4.
150     """     """
151     if hasattr(d,"getDim"):     if isinstance(d,escript.FunctionSpace):
152        d=d.getDim()         return escript.Data(identity((d.getDim(),d.getDim())),d)
153     return identity((d,d))     elif isinstance(d,escript.Domain):
154           return identity((d.getDim(),d.getDim()))
155       else:
156           return identity((d,d))
157    
158  def unitVector(i=0,d=3):  def unitVector(i=0,d=3):
159     """     """
# Line 188  def unitVector(i=0,d=3): Line 162  def unitVector(i=0,d=3):
162     @param i: index     @param i: index
163     @type i: C{int}     @type i: C{int}
164     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
165     @type d: C{int} or any object with a C{getDim} method     @type d: C{int}, L{escript.Domain} or L{escript.FunctionSpace}
166     @return: the object u of rank 1 with M{u[j]=1} for M{j=i} and M{u[i]=0} otherwise     @return: the object u of rank 1 with M{u[j]=1} for M{j=i} and M{u[i]=0} otherwise
167     @rtype: L{numarray.NumArray} of rank 1.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 1
168     """     """
169     return kronecker(d)[i]     return kronecker(d)[i]
170    
# Line 391  def matchType(arg0=0.,arg1=0.): Line 365  def matchType(arg0=0.,arg1=0.):
365         elif isinstance(arg1,escript.Data):         elif isinstance(arg1,escript.Data):
366            arg0=escript.Data(arg0,arg1.getFunctionSpace())            arg0=escript.Data(arg0,arg1.getFunctionSpace())
367         elif isinstance(arg1,float):         elif isinstance(arg1,float):
368            arg1=numarray.array(arg1)            arg1=numarray.array(arg1,type=numarray.Float64)
369         elif isinstance(arg1,int):         elif isinstance(arg1,int):
370            arg1=numarray.array(float(arg1))            arg1=numarray.array(float(arg1),type=numarray.Float64)
371         elif isinstance(arg1,Symbol):         elif isinstance(arg1,Symbol):
372            pass            pass
373         else:         else:
# Line 417  def matchType(arg0=0.,arg1=0.): Line 391  def matchType(arg0=0.,arg1=0.):
391         elif isinstance(arg1,escript.Data):         elif isinstance(arg1,escript.Data):
392            pass            pass
393         elif isinstance(arg1,float):         elif isinstance(arg1,float):
394            arg1=numarray.array(arg1)            arg1=numarray.array(arg1,type=numarray.Float64)
395         elif isinstance(arg1,int):         elif isinstance(arg1,int):
396            arg1=numarray.array(float(arg1))            arg1=numarray.array(float(arg1),type=numarray.Float64)
397         elif isinstance(arg1,Symbol):         elif isinstance(arg1,Symbol):
398            pass            pass
399         else:         else:
400            raise TypeError,"function: Unknown type of second argument."                raise TypeError,"function: Unknown type of second argument."    
401      elif isinstance(arg0,float):      elif isinstance(arg0,float):
402         if isinstance(arg1,numarray.NumArray):         if isinstance(arg1,numarray.NumArray):
403            arg0=numarray.array(arg0)            arg0=numarray.array(arg0,type=numarray.Float64)
404         elif isinstance(arg1,escript.Data):         elif isinstance(arg1,escript.Data):
405            arg0=escript.Data(arg0,arg1.getFunctionSpace())            arg0=escript.Data(arg0,arg1.getFunctionSpace())
406         elif isinstance(arg1,float):         elif isinstance(arg1,float):
407            arg0=numarray.array(arg0)            arg0=numarray.array(arg0,type=numarray.Float64)
408            arg1=numarray.array(arg1)            arg1=numarray.array(arg1,type=numarray.Float64)
409         elif isinstance(arg1,int):         elif isinstance(arg1,int):
410            arg0=numarray.array(arg0)            arg0=numarray.array(arg0,type=numarray.Float64)
411            arg1=numarray.array(float(arg1))            arg1=numarray.array(float(arg1),type=numarray.Float64)
412         elif isinstance(arg1,Symbol):         elif isinstance(arg1,Symbol):
413            arg0=numarray.array(arg0)            arg0=numarray.array(arg0,type=numarray.Float64)
414         else:         else:
415            raise TypeError,"function: Unknown type of second argument."                raise TypeError,"function: Unknown type of second argument."    
416      elif isinstance(arg0,int):      elif isinstance(arg0,int):
417         if isinstance(arg1,numarray.NumArray):         if isinstance(arg1,numarray.NumArray):
418            arg0=numarray.array(float(arg0))            arg0=numarray.array(float(arg0),type=numarray.Float64)
419         elif isinstance(arg1,escript.Data):         elif isinstance(arg1,escript.Data):
420            arg0=escript.Data(float(arg0),arg1.getFunctionSpace())            arg0=escript.Data(float(arg0),arg1.getFunctionSpace())
421         elif isinstance(arg1,float):         elif isinstance(arg1,float):
422            arg0=numarray.array(float(arg0))            arg0=numarray.array(float(arg0),type=numarray.Float64)
423            arg1=numarray.array(arg1)            arg1=numarray.array(arg1,type=numarray.Float64)
424         elif isinstance(arg1,int):         elif isinstance(arg1,int):
425            arg0=numarray.array(float(arg0))            arg0=numarray.array(float(arg0),type=numarray.Float64)
426            arg1=numarray.array(float(arg1))            arg1=numarray.array(float(arg1),type=numarray.Float64)
427         elif isinstance(arg1,Symbol):         elif isinstance(arg1,Symbol):
428            arg0=numarray.array(float(arg0))            arg0=numarray.array(float(arg0),type=numarray.Float64)
429         else:         else:
430            raise TypeError,"function: Unknown type of second argument."                raise TypeError,"function: Unknown type of second argument."    
431      else:      else:
# Line 474  def matchShape(arg0,arg1): Line 448  def matchShape(arg0,arg1):
448      sh0=pokeShape(arg0)      sh0=pokeShape(arg0)
449      sh1=pokeShape(arg1)      sh1=pokeShape(arg1)
450      if len(sh0)<len(sh):      if len(sh0)<len(sh):
451         return outer(arg0,numarray.ones(sh[len(sh0):],numarray.Float)),arg1         return outer(arg0,numarray.ones(sh[len(sh0):],numarray.Float64)),arg1
452      elif len(sh1)<len(sh):      elif len(sh1)<len(sh):
453         return arg0,outer(arg1,numarray.ones(sh[len(sh1):],numarray.Float))         return arg0,outer(arg1,numarray.ones(sh[len(sh1):],numarray.Float64))
454      else:      else:
455         return arg0,arg1         return arg0,arg1
456  #=========================================================  #=========================================================
# Line 602  class Symbol(object): Line 576  class Symbol(object):
576            else:            else:
577                s=pokeShape(s)+arg.getShape()                s=pokeShape(s)+arg.getShape()
578                if len(s)>0:                if len(s)>0:
579                   out.append(numarray.zeros(s),numarray.Float)                   out.append(numarray.zeros(s),numarray.Float64)
580                else:                else:
581                   out.append(a)                   out.append(a)
582         return out         return out
# Line 692  class Symbol(object): Line 666  class Symbol(object):
666         else:         else:
667            s=self.getShape()+arg.getShape()            s=self.getShape()+arg.getShape()
668            if len(s)>0:            if len(s)>0:
669               return numarray.zeros(s,numarray.Float)               return numarray.zeros(s,numarray.Float64)
670            else:            else:
671               return 0.               return 0.
672    
# Line 830  class Symbol(object): Line 804  class Symbol(object):
804         """         """
805         return power(other,self)         return power(other,self)
806    
807       def __getitem__(self,index):
808           """
809           returns the slice defined by index
810    
811           @param index: defines a
812           @type index: C{slice} or C{int} or a C{tuple} of them
813           @return: a S{Symbol} representing the slice defined by index
814           @rtype: L{DependendSymbol}
815           """
816           return GetSlice_Symbol(self,index)
817    
818  class DependendSymbol(Symbol):  class DependendSymbol(Symbol):
819     """     """
820     DependendSymbol extents L{Symbol} by modifying the == operator to allow two instances to be equal.     DependendSymbol extents L{Symbol} by modifying the == operator to allow two instances to be equal.
# Line 880  class DependendSymbol(Symbol): Line 865  class DependendSymbol(Symbol):
865  #=========================================================  #=========================================================
866  #  Unary operations prserving the shape  #  Unary operations prserving the shape
867  #========================================================  #========================================================
868    class GetSlice_Symbol(DependendSymbol):
869       """
870       L{Symbol} representing getting a slice for a L{Symbol}
871       """
872       def __init__(self,arg,index):
873          """
874          initialization of wherePositive L{Symbol} with argument arg
875          @param arg: argument
876          @type arg: L{Symbol}.
877          @param index: defines index
878          @type index: C{slice} or C{int} or a C{tuple} of them
879          @raises IndexError: if length of index is larger than rank of arg or a index start or stop is out of range
880          @raises ValueError: if a step is given
881          """
882          if not isinstance(index,tuple): index=(index,)
883          if len(index)>arg.getRank():
884               raise IndexError,"GetSlice_Symbol: index out of range."
885          sh=()
886          index2=()
887          for i in range(len(index)):
888             ix=index[i]
889             if isinstance(ix,int):
890                if ix<0 or ix>=arg.getShape()[i]:
891                   raise ValueError,"GetSlice_Symbol: index out of range."
892                index2=index2+(ix,)
893             else:
894               if not ix.step==None:
895                 raise ValueError,"GetSlice_Symbol: steping is not supported."
896               if ix.start==None:
897                  s=0
898               else:
899                  s=ix.start
900               if ix.stop==None:
901                  e=arg.getShape()[i]
902               else:
903                  e=ix.stop
904                  if e>arg.getShape()[i]:
905                     raise IndexError,"GetSlice_Symbol: index out of range."
906               index2=index2+(slice(s,e),)
907               if e>s:
908                   sh=sh+(e-s,)
909               elif s>e:
910                   raise IndexError,"GetSlice_Symbol: slice start must be less or equal slice end"
911          for i in range(len(index),arg.getRank()):
912              index2=index2+(slice(0,arg.getShape()[i]),)
913              sh=sh+(arg.getShape()[i],)
914          super(GetSlice_Symbol, self).__init__(args=[arg,index2],shape=sh,dim=arg.getDim())
915    
916       def getMyCode(self,argstrs,format="escript"):
917          """
918          returns a program code that can be used to evaluate the symbol.
919    
920          @param argstrs: gives for each argument a string representing the argument for the evaluation.
921          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
922          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
923          @type format: C{str}
924          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
925          @rtype: C{str}
926          @raise: NotImplementedError: if the requested format is not available
927          """
928          if format=="escript" or format=="str"  or format=="text":
929             return "%s.__getitem__(%s)"%(argstrs[0],argstrs[1])
930          else:
931             raise NotImplementedError,"GetItem_Symbol does not provide program code for format %s."%format
932    
933       def substitute(self,argvals):
934          """
935          assigns new values to symbols in the definition of the symbol.
936          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
937    
938          @param argvals: new values assigned to symbols
939          @type argvals: C{dict} with keywords of type L{Symbol}.
940          @return: result of the substitution process. Operations are executed as much as possible.
941          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
942          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
943          """
944          if argvals.has_key(self):
945             arg=argvals[self]
946             if self.isAppropriateValue(arg):
947                return arg
948             else:
949                raise TypeError,"%s: new value is not appropriate."%str(self)
950          else:
951             args=self.getSubstitutedArguments(argvals)
952             arg=args[0]
953             index=args[1]
954             return arg.__getitem__(index)
955    
956  def log10(arg):  def log10(arg):
957     """     """
958     returns base-10 logarithm of argument arg     returns base-10 logarithm of argument arg
# Line 912  def wherePositive(arg): Line 985  def wherePositive(arg):
985     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
986     """     """
987     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
988        out=numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))*1.        out=numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float64))*1.
989        if isinstance(out,float): out=numarray.array(out)        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
990        return out        return out
991     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
992        return arg._wherePositive()        return arg._wherePositive()
# Line 994  def whereNegative(arg): Line 1067  def whereNegative(arg):
1067     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1068     """     """
1069     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1070        out=numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))*1.        out=numarray.less(arg,numarray.zeros(arg.shape,numarray.Float64))*1.
1071        if isinstance(out,float): out=numarray.array(out)        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1072        return out        return out
1073     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1074        return arg._whereNegative()        return arg._whereNegative()
# Line 1076  def whereNonNegative(arg): Line 1149  def whereNonNegative(arg):
1149     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1150     """     """
1151     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1152        out=numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.        out=numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float64))*1.
1153        if isinstance(out,float): out=numarray.array(out)        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1154        return out        return out
1155     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1156        return arg._whereNonNegative()        return arg._whereNonNegative()
# Line 1106  def whereNonPositive(arg): Line 1179  def whereNonPositive(arg):
1179     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1180     """     """
1181     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1182        out=numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.        out=numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float64))*1.
1183        if isinstance(out,float): out=numarray.array(out)        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1184        return out        return out
1185     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1186        return arg._whereNonPositive()        return arg._whereNonPositive()
# Line 1138  def whereZero(arg,tol=0.): Line 1211  def whereZero(arg,tol=0.):
1211     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1212     """     """
1213     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1214        out=numarray.less_equal(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.        out=numarray.less_equal(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float64))*1.
1215        if isinstance(out,float): out=numarray.array(out)        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1216        return out        return out
1217     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1218        if tol>0.:        if tol>0.:
# Line 1221  def whereNonZero(arg,tol=0.): Line 1294  def whereNonZero(arg,tol=0.):
1294     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1295     """     """
1296     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1297        out=numarray.greater(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.        out=numarray.greater(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float64))*1.
1298        if isinstance(out,float): out=numarray.array(out)        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1299        return out        return out
1300     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1301        if tol>0.:        if tol>0.:
# Line 2883  def trace(arg,axis_offset=0): Line 2956  def trace(arg,axis_offset=0):
2956        if not sh[axis_offset] == sh[axis_offset+1]:        if not sh[axis_offset] == sh[axis_offset+1]:
2957          raise ValueError,"trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)          raise ValueError,"trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2958        arg_reshaped=numarray.reshape(arg,(s1,sh[axis_offset],sh[axis_offset],s2))        arg_reshaped=numarray.reshape(arg,(s1,sh[axis_offset],sh[axis_offset],s2))
2959        out=numarray.zeros([s1,s2],numarray.Float)        out=numarray.zeros([s1,s2],numarray.Float64)
2960        for i1 in range(s1):        for i1 in range(s1):
2961          for i2 in range(s2):          for i2 in range(s2):
2962              for j in range(sh[axis_offset]): out[i1,i2]+=arg_reshaped[i1,j,j,i2]              for j in range(sh[axis_offset]): out[i1,i2]+=arg_reshaped[i1,j,j,i2]
# Line 3013  class Trace_Symbol(DependendSymbol): Line 3086  class Trace_Symbol(DependendSymbol):
3086        else:        else:
3087           return trace(self.getDifferentiatedArguments(arg)[0],axis_offset=self.getArgument()[1])           return trace(self.getDifferentiatedArguments(arg)[0],axis_offset=self.getArgument()[1])
3088    
3089    def transpose(arg,axis_offset=None):
3090       """
3091       returns the transpose of arg by swaping the first axis_offset and the last rank-axis_offset components.
3092    
3093       @param arg: argument
3094       @type arg: L{escript.Data}, L{Symbol}, L{numarray.NumArray}, C{float}, C{int}
3095       @param axis_offset: the first axis_offset components are swapped with rest. If C{axis_offset} must be non-negative and less or equal the rank of arg.
3096                           if axis_offset is not present C{int(r/2)} where r is the rank of arg is used.
3097       @type axis_offset: C{int}
3098       @return: transpose of arg
3099       @rtype: L{escript.Data}, L{Symbol}, L{numarray.NumArray},C{float}, C{int} depending on the type of arg.
3100       """
3101       if isinstance(arg,numarray.NumArray):
3102          if axis_offset==None: axis_offset=int(arg.rank/2)
3103          return numarray.transpose(arg,axes=range(axis_offset,arg.rank)+range(0,axis_offset))
3104       elif isinstance(arg,escript.Data):
3105          if axis_offset==None: axis_offset=int(arg.getRank()/2)
3106          return escript_transpose(arg,axis_offset)
3107       elif isinstance(arg,float):
3108          if not ( axis_offset==0 or axis_offset==None):
3109            raise ValueError,"transpose: axis_offset must be 0 for float argument"
3110          return arg
3111       elif isinstance(arg,int):
3112          if not ( axis_offset==0 or axis_offset==None):
3113            raise ValueError,"transpose: axis_offset must be 0 for int argument"
3114          return float(arg)
3115       elif isinstance(arg,Symbol):
3116          if axis_offset==None: axis_offset=int(arg.getRank()/2)
3117          return Transpose_Symbol(arg,axis_offset)
3118       else:
3119          raise TypeError,"transpose: Unknown argument type."
3120    
3121    def escript_transpose(arg,axis_offset): # this should be escript._transpose
3122          "arg si a Data objects!!!"
3123          r=arg.getRank()
3124          if axis_offset<0 or axis_offset>r:
3125            raise ValueError,"escript_transpose: axis_offset must be between 0 and %s"%r
3126          s=arg.getShape()
3127          s_out=s[axis_offset:]+s[:axis_offset]
3128          out=escript.Data(0.,s_out,arg.getFunctionSpace())
3129          if r==4:
3130             if axis_offset==1:
3131                for i0 in range(s_out[0]):
3132                   for i1 in range(s_out[1]):
3133                      for i2 in range(s_out[2]):
3134                         for i3 in range(s_out[3]):
3135                             out[i0,i1,i2,i3]=arg[i3,i0,i1,i2]
3136             elif axis_offset==2:
3137                for i0 in range(s_out[0]):
3138                   for i1 in range(s_out[1]):
3139                      for i2 in range(s_out[2]):
3140                         for i3 in range(s_out[3]):
3141                             out[i0,i1,i2,i3]=arg[i2,i3,i0,i1]
3142             elif axis_offset==3:
3143                for i0 in range(s_out[0]):
3144                   for i1 in range(s_out[1]):
3145                      for i2 in range(s_out[2]):
3146                         for i3 in range(s_out[3]):
3147                             out[i0,i1,i2,i3]=arg[i1,i2,i3,i0]
3148             else:
3149                for i0 in range(s_out[0]):
3150                   for i1 in range(s_out[1]):
3151                      for i2 in range(s_out[2]):
3152                         for i3 in range(s_out[3]):
3153                             out[i0,i1,i2,i3]=arg[i0,i1,i2,i3]
3154          elif r==3:
3155             if axis_offset==1:
3156                for i0 in range(s_out[0]):
3157                   for i1 in range(s_out[1]):
3158                      for i2 in range(s_out[2]):
3159                             out[i0,i1,i2]=arg[i2,i0,i1]
3160             elif axis_offset==2:
3161                for i0 in range(s_out[0]):
3162                   for i1 in range(s_out[1]):
3163                      for i2 in range(s_out[2]):
3164                             out[i0,i1,i2]=arg[i1,i2,i0]
3165             else:
3166                for i0 in range(s_out[0]):
3167                   for i1 in range(s_out[1]):
3168                      for i2 in range(s_out[2]):
3169                             out[i0,i1,i2]=arg[i0,i1,i2]
3170          elif r==2:
3171             if axis_offset==1:
3172                for i0 in range(s_out[0]):
3173                   for i1 in range(s_out[1]):
3174                             out[i0,i1]=arg[i1,i0]
3175             else:
3176                for i0 in range(s_out[0]):
3177                   for i1 in range(s_out[1]):
3178                             out[i0,i1]=arg[i0,i1]
3179          elif r==1:
3180              for i0 in range(s_out[0]):
3181                   out[i0]=arg[i0]
3182          elif r==0:
3183                 out=arg+0.
3184          return out
3185    class Transpose_Symbol(DependendSymbol):
3186       """
3187       L{Symbol} representing the result of the transpose function
3188       """
3189       def __init__(self,arg,axis_offset=None):
3190          """
3191          initialization of transpose L{Symbol} with argument arg
3192    
3193          @param arg: argument of function
3194          @type arg: L{Symbol}.
3195           @param axis_offset: the first axis_offset components are swapped with rest. If C{axis_offset} must be non-negative and less or equal the rank of arg.
3196                           if axis_offset is not present C{int(r/2)} where r is the rank of arg is used.
3197          @type axis_offset: C{int}
3198          """
3199          if axis_offset==None: axis_offset=int(arg.getRank()/2)
3200          if axis_offset<0 or axis_offset>arg.getRank():
3201            raise ValueError,"escript_transpose: axis_offset must be between 0 and %s"%r
3202          s=arg.getShape()
3203          super(Transpose_Symbol,self).__init__(args=[arg,axis_offset],shape=s[axis_offset:]+s[:axis_offset],dim=arg.getDim())
3204    
3205       def getMyCode(self,argstrs,format="escript"):
3206          """
3207          returns a program code that can be used to evaluate the symbol.
3208    
3209          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3210          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3211          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3212          @type format: C{str}
3213          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3214          @rtype: C{str}
3215          @raise: NotImplementedError: if the requested format is not available
3216          """
3217          if format=="escript" or format=="str"  or format=="text":
3218             return "transpose(%s,axis_offset=%s)"%(argstrs[0],argstrs[1])
3219          else:
3220             raise NotImplementedError,"Transpose_Symbol does not provide program code for format %s."%format
3221    
3222       def substitute(self,argvals):
3223          """
3224          assigns new values to symbols in the definition of the symbol.
3225          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3226    
3227          @param argvals: new values assigned to symbols
3228          @type argvals: C{dict} with keywords of type L{Symbol}.
3229          @return: result of the substitution process. Operations are executed as much as possible.
3230          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3231          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3232          """
3233          if argvals.has_key(self):
3234             arg=argvals[self]
3235             if self.isAppropriateValue(arg):
3236                return arg
3237             else:
3238                raise TypeError,"%s: new value is not appropriate."%str(self)
3239          else:
3240             arg=self.getSubstitutedArguments(argvals)
3241             return transpose(arg[0],axis_offset=arg[1])
3242    
3243       def diff(self,arg):
3244          """
3245          differential of this object
3246    
3247          @param arg: the derivative is calculated with respect to arg
3248          @type arg: L{escript.Symbol}
3249          @return: derivative with respect to C{arg}
3250          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3251          """
3252          if arg==self:
3253             return identity(self.getShape())
3254          else:
3255             return transpose(self.getDifferentiatedArguments(arg)[0],axis_offset=self.getArgument()[1])
3256    def symmetric(arg):
3257        """
3258        returns the symmetric part of the square matrix arg. This is (arg+transpose(arg))/2
3259    
3260        @param arg: square matrix. Must have rank 2 or 4 and be square.
3261        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}
3262        @return: symmetric part of arg
3263        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
3264        """
3265        if isinstance(arg,numarray.NumArray):
3266          if arg.rank==2:
3267            if not (arg.shape[0]==arg.shape[1]):
3268               raise ValueError,"symmetric: argument must be square."
3269          elif arg.rank==4:
3270            if not (arg.shape[0]==arg.shape[2] and arg.shape[1]==arg.shape[3]):
3271               raise ValueError,"symmetric: argument must be square."
3272          else:
3273            raise ValueError,"symmetric: rank 2 or 4 is required."
3274          return (arg+transpose(arg))/2
3275        elif isinstance(arg,escript.Data):
3276          return escript_symmetric(arg)
3277        elif isinstance(arg,float):
3278          return arg
3279        elif isinstance(arg,int):
3280          return float(arg)
3281        elif isinstance(arg,Symbol):
3282          if arg.getRank()==2:
3283            if not (arg.getShape()[0]==arg.getShape()[1]):
3284               raise ValueError,"symmetric: argument must be square."
3285          elif arg.getRank()==4:
3286            if not (arg.getShape()[0]==arg.getShape()[2] and arg.getShape()[1]==arg.getShape()[3]):
3287               raise ValueError,"symmetric: argument must be square."
3288          else:
3289            raise ValueError,"symmetric: rank 2 or 4 is required."
3290          return (arg+transpose(arg))/2
3291        else:
3292          raise TypeError,"symmetric: Unknown argument type."
3293    
3294    def escript_symmetric(arg): # this should be implemented in c++
3295          if arg.getRank()==2:
3296            if not (arg.getShape()[0]==arg.getShape()[1]):
3297               raise ValueError,"escript_symmetric: argument must be square."
3298            out=escript.Data(0.,arg.getShape(),arg.getFunctionSpace())
3299            for i0 in range(arg.getShape()[0]):
3300               for i1 in range(arg.getShape()[1]):
3301                  out[i0,i1]=(arg[i0,i1]+arg[i1,i0])/2.
3302          elif arg.getRank()==4:
3303            if not (arg.getShape()[0]==arg.getShape()[2] and arg.getShape()[1]==arg.getShape()[3]):
3304               raise ValueError,"escript_symmetric: argument must be square."
3305            out=escript.Data(0.,arg.getShape(),arg.getFunctionSpace())
3306            for i0 in range(arg.getShape()[0]):
3307               for i1 in range(arg.getShape()[1]):
3308                  for i2 in range(arg.getShape()[2]):
3309                     for i3 in range(arg.getShape()[3]):
3310                         out[i0,i1,i2,i3]=(arg[i0,i1,i2,i3]+arg[i2,i3,i0,i1])/2.
3311          else:
3312            raise ValueError,"escript_symmetric: rank 2 or 4 is required."
3313          return out
3314    
3315    def nonsymmetric(arg):
3316        """
3317        returns the nonsymmetric part of the square matrix arg. This is (arg-transpose(arg))/2
3318    
3319        @param arg: square matrix. Must have rank 2 or 4 and be square.
3320        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}
3321        @return: nonsymmetric part of arg
3322        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
3323        """
3324        if isinstance(arg,numarray.NumArray):
3325          if arg.rank==2:
3326            if not (arg.shape[0]==arg.shape[1]):
3327               raise ValueError,"nonsymmetric: argument must be square."
3328          elif arg.rank==4:
3329            if not (arg.shape[0]==arg.shape[2] and arg.shape[1]==arg.shape[3]):
3330               raise ValueError,"nonsymmetric: argument must be square."
3331          else:
3332            raise ValueError,"nonsymmetric: rank 2 or 4 is required."
3333          return (arg-transpose(arg))/2
3334        elif isinstance(arg,escript.Data):
3335          return escript_nonsymmetric(arg)
3336        elif isinstance(arg,float):
3337          return arg
3338        elif isinstance(arg,int):
3339          return float(arg)
3340        elif isinstance(arg,Symbol):
3341          if arg.getRank()==2:
3342            if not (arg.getShape()[0]==arg.getShape()[1]):
3343               raise ValueError,"nonsymmetric: argument must be square."
3344          elif arg.getRank()==4:
3345            if not (arg.getShape()[0]==arg.getShape()[2] and arg.getShape()[1]==arg.getShape()[3]):
3346               raise ValueError,"nonsymmetric: argument must be square."
3347          else:
3348            raise ValueError,"nonsymmetric: rank 2 or 4 is required."
3349          return (arg-transpose(arg))/2
3350        else:
3351          raise TypeError,"nonsymmetric: Unknown argument type."
3352    
3353    def escript_nonsymmetric(arg): # this should be implemented in c++
3354          if arg.getRank()==2:
3355            if not (arg.getShape()[0]==arg.getShape()[1]):
3356               raise ValueError,"escript_nonsymmetric: argument must be square."
3357            out=escript.Data(0.,arg.getShape(),arg.getFunctionSpace())
3358            for i0 in range(arg.getShape()[0]):
3359               for i1 in range(arg.getShape()[1]):
3360                  out[i0,i1]=(arg[i0,i1]-arg[i1,i0])/2.
3361          elif arg.getRank()==4:
3362            if not (arg.getShape()[0]==arg.getShape()[2] and arg.getShape()[1]==arg.getShape()[3]):
3363               raise ValueError,"escript_nonsymmetric: argument must be square."
3364            out=escript.Data(0.,arg.getShape(),arg.getFunctionSpace())
3365            for i0 in range(arg.getShape()[0]):
3366               for i1 in range(arg.getShape()[1]):
3367                  for i2 in range(arg.getShape()[2]):
3368                     for i3 in range(arg.getShape()[3]):
3369                         out[i0,i1,i2,i3]=(arg[i0,i1,i2,i3]-arg[i2,i3,i0,i1])/2.
3370          else:
3371            raise ValueError,"escript_nonsymmetric: rank 2 or 4 is required."
3372          return out
3373    
3374    
3375    def inverse(arg):
3376        """
3377        returns the inverse of the square matrix arg.
3378    
3379        @param arg: square matrix. Must have rank 2 and the first and second dimension must be equal.
3380        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}
3381        @return: inverse arg_inv of the argument. It will be matrixmul(inverse(arg),arg) almost equal to kronecker(arg.getShape()[0])
3382        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
3383        @remark: for L{escript.Data} objects the dimension is restricted to 3.
3384        """
3385        import numarray.linear_algebra # This statement should be after the next statement but then somehow numarray is gone.
3386        if isinstance(arg,numarray.NumArray):
3387          return numarray.linear_algebra.inverse(arg)
3388        elif isinstance(arg,escript.Data):
3389          return escript_inverse(arg)
3390        elif isinstance(arg,float):
3391          return 1./arg
3392        elif isinstance(arg,int):
3393          return 1./float(arg)
3394        elif isinstance(arg,Symbol):
3395          return Inverse_Symbol(arg)
3396        else:
3397          raise TypeError,"inverse: Unknown argument type."
3398    
3399    def escript_inverse(arg): # this should be escript._inverse and use LAPACK
3400          "arg is a Data objects!!!"
3401          if not arg.getRank()==2:
3402            raise ValueError,"escript_inverse: argument must have rank 2"
3403          s=arg.getShape()      
3404          if not s[0] == s[1]:
3405            raise ValueError,"escript_inverse: argument must be a square matrix."
3406          out=escript.Data(0.,s,arg.getFunctionSpace())
3407          if s[0]==1:
3408              if inf(abs(arg[0,0]))==0: # in c this should be done point wise as abs(arg[0,0](i))<=0.
3409                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3410              out[0,0]=1./arg[0,0]
3411          elif s[0]==2:
3412              A11=arg[0,0]
3413              A12=arg[0,1]
3414              A21=arg[1,0]
3415              A22=arg[1,1]
3416              D = A11*A22-A12*A21
3417              if inf(abs(D))==0: # in c this should be done point wise as abs(D(i))<=0.
3418                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3419              D=1./D
3420              out[0,0]= A22*D
3421              out[1,0]=-A21*D
3422              out[0,1]=-A12*D
3423              out[1,1]= A11*D
3424          elif s[0]==3:
3425              A11=arg[0,0]
3426              A21=arg[1,0]
3427              A31=arg[2,0]
3428              A12=arg[0,1]
3429              A22=arg[1,1]
3430              A32=arg[2,1]
3431              A13=arg[0,2]
3432              A23=arg[1,2]
3433              A33=arg[2,2]
3434              D  =  A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
3435              if inf(abs(D))==0: # in c this should be done point wise as abs(D(i))<=0.
3436                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3437              D=1./D
3438              out[0,0]=(A22*A33-A23*A32)*D
3439              out[1,0]=(A31*A23-A21*A33)*D
3440              out[2,0]=(A21*A32-A31*A22)*D
3441              out[0,1]=(A13*A32-A12*A33)*D
3442              out[1,1]=(A11*A33-A31*A13)*D
3443              out[2,1]=(A12*A31-A11*A32)*D
3444              out[0,2]=(A12*A23-A13*A22)*D
3445              out[1,2]=(A13*A21-A11*A23)*D
3446              out[2,2]=(A11*A22-A12*A21)*D
3447          else:
3448             raise TypeError,"escript_inverse: only matrix dimensions 1,2,3 are supported right now."
3449          return out
3450    
3451    class Inverse_Symbol(DependendSymbol):
3452       """
3453       L{Symbol} representing the result of the inverse function
3454       """
3455       def __init__(self,arg):
3456          """
3457          initialization of inverse L{Symbol} with argument arg
3458          @param arg: argument of function
3459          @type arg: L{Symbol}.
3460          """
3461          if not arg.getRank()==2:
3462            raise ValueError,"Inverse_Symbol:: argument must have rank 2"
3463          s=arg.getShape()
3464          if not s[0] == s[1]:
3465            raise ValueError,"Inverse_Symbol:: argument must be a square matrix."
3466          super(Inverse_Symbol,self).__init__(args=[arg],shape=s,dim=arg.getDim())
3467    
3468       def getMyCode(self,argstrs,format="escript"):
3469          """
3470          returns a program code that can be used to evaluate the symbol.
3471    
3472          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3473          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3474          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3475          @type format: C{str}
3476          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3477          @rtype: C{str}
3478          @raise: NotImplementedError: if the requested format is not available
3479          """
3480          if format=="escript" or format=="str"  or format=="text":
3481             return "inverse(%s)"%argstrs[0]
3482          else:
3483             raise NotImplementedError,"Inverse_Symbol does not provide program code for format %s."%format
3484    
3485       def substitute(self,argvals):
3486          """
3487          assigns new values to symbols in the definition of the symbol.
3488          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3489    
3490          @param argvals: new values assigned to symbols
3491          @type argvals: C{dict} with keywords of type L{Symbol}.
3492          @return: result of the substitution process. Operations are executed as much as possible.
3493          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3494          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3495          """
3496          if argvals.has_key(self):
3497             arg=argvals[self]
3498             if self.isAppropriateValue(arg):
3499                return arg
3500             else:
3501                raise TypeError,"%s: new value is not appropriate."%str(self)
3502          else:
3503             arg=self.getSubstitutedArguments(argvals)
3504             return inverse(arg[0])
3505    
3506       def diff(self,arg):
3507          """
3508          differential of this object
3509    
3510          @param arg: the derivative is calculated with respect to arg
3511          @type arg: L{escript.Symbol}
3512          @return: derivative with respect to C{arg}
3513          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3514          """
3515          if arg==self:
3516             return identity(self.getShape())
3517          else:
3518             return -matrixmult(matrixmult(self,self.getDifferentiatedArguments(arg)[0]),self)
3519    
3520    def eigenvalues(arg):
3521        """
3522        returns the eigenvalues of the square matrix arg.
3523    
3524        @param arg: square matrix. Must have rank 2 and the first and second dimension must be equal.
3525                    arg must be symmetric, ie. transpose(arg)==arg (this is not checked).
3526        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}
3527        @return: the eigenvalues in increasing order.
3528        @rtype: L{numarray.NumArray},L{escript.Data}, L{Symbol} depending on the input.
3529        @remark: for L{escript.Data} and L{Symbol} objects the dimension is restricted to 3.
3530        """
3531        if isinstance(arg,numarray.NumArray):
3532          out=numarray.linear_algebra.eigenvalues((arg+numarray.transpose(arg))/2.)
3533          out.sort()
3534          return out
3535        elif isinstance(arg,escript.Data):
3536          return arg._eigenvalues()
3537        elif isinstance(arg,Symbol):
3538          if not arg.getRank()==2:
3539            raise ValueError,"eigenvalues: argument must have rank 2"
3540          s=arg.getShape()      
3541          if not s[0] == s[1]:
3542            raise ValueError,"eigenvalues: argument must be a square matrix."
3543          if s[0]==1:
3544              return arg[0]
3545          elif s[0]==2:
3546              arg1=symmetric(arg)
3547              A11=arg1[0,0]
3548              A12=arg1[0,1]
3549              A22=arg1[1,1]
3550              trA=(A11+A22)/2.
3551              A11-=trA
3552              A22-=trA
3553              s=sqrt(A12**2-A11*A22)
3554              return trA+s*numarray.array([-1.,1.],type=numarray.Float64)
3555          elif s[0]==3:
3556              arg1=symmetric(arg)
3557              A11=arg1[0,0]
3558              A12=arg1[0,1]
3559              A22=arg1[1,1]
3560              A13=arg1[0,2]
3561              A23=arg1[1,2]
3562              A33=arg1[2,2]
3563              trA=(A11+A22+A33)/3.
3564              A11-=trA
3565              A22-=trA
3566              A33-=trA
3567              A13_2=A13**2
3568              A23_2=A23**2
3569              A12_2=A12**2
3570              p=A13_2+A23_2+A12_2+(A11**2+A22**2+A33**2)/2.
3571              q=A13_2*A22+A23_2*A11+A12_2*A33-A11*A22*A33-2*A12*A23*A13
3572              sq_p=sqrt(p/3.)
3573              alpha_3=acos(clip(-q*(sq_p+whereZero(p,0.)*1.e-15)**(-3.)/2.,-1.,1.))/3.  # whereZero is protection against divison by zero
3574              sq_p*=2.
3575              f=cos(alpha_3)               *numarray.array([0.,0.,1.],type=numarray.Float64) \
3576               -cos(alpha_3+numarray.pi/3.)*numarray.array([0.,1.,0.],type=numarray.Float64) \
3577               -cos(alpha_3-numarray.pi/3.)*numarray.array([1.,0.,0.],type=numarray.Float64)
3578              return trA+sq_p*f
3579          else:
3580             raise TypeError,"eigenvalues: only matrix dimensions 1,2,3 are supported right now."
3581        elif isinstance(arg,float):
3582          return arg
3583        elif isinstance(arg,int):
3584          return float(arg)
3585        else:
3586          raise TypeError,"eigenvalues: Unknown argument type."
3587    
3588    def eigenvalues_and_eigenvectors(arg):
3589        """
3590        returns the eigenvalues and eigenvectors of the square matrix arg.
3591    
3592        @param arg: square matrix. Must have rank 2 and the first and second dimension must be equal.
3593                    arg must be symmetric, ie. transpose(arg)==arg (this is not checked).
3594        @type arg: L{escript.Data}
3595        @return: the eigenvalues and eigenvectors. The eigenvalues are ordered by increasing value. The
3596                 eigenvectors are orthogonal and normalized. If V are the eigenvectors than V[:,i] is
3597                 the eigenvector coresponding to the i-th eigenvalue.
3598        @rtype: L{tuple} of L{escript.Data}.
3599        @remark: The dimension is restricted to 3.
3600        """
3601        if isinstance(arg,numarray.NumArray):
3602          raise TypeError,"eigenvalues_and_eigenvectors is not supporting numarray arguments"
3603        elif isinstance(arg,escript.Data):
3604          return arg._eigenvalues_and_eigenvectors()
3605        elif isinstance(arg,Symbol):
3606          raise TypeError,"eigenvalues_and_eigenvectors is not supporting Symbol arguments"
3607        elif isinstance(arg,float):
3608          return (numarray.array([[arg]],numarray.Float),numarray.ones((1,1),numarray.Float))
3609        elif isinstance(arg,int):
3610          return (numarray.array([[arg]],numarray.Float),numarray.ones((1,1),numarray.Float))
3611        else:
3612          raise TypeError,"eigenvalues: Unknown argument type."
3613  #=======================================================  #=======================================================
3614  #  Binary operations:  #  Binary operations:
3615  #=======================================================  #=======================================================
# Line 3131  def mult(arg0,arg1): Line 3728  def mult(arg0,arg1):
3728         """         """
3729         args=matchShape(arg0,arg1)         args=matchShape(arg0,arg1)
3730         if testForZero(args[0]) or testForZero(args[1]):         if testForZero(args[0]) or testForZero(args[1]):
3731            return numarray.zeros(pokeShape(args[0]),numarray.Float)            return numarray.zeros(pokeShape(args[0]),numarray.Float64)
3732         else:         else:
3733            if isinstance(args[0],Symbol) or isinstance(args[1],Symbol) :            if isinstance(args[0],Symbol) or isinstance(args[1],Symbol) :
3734                return Mult_Symbol(args[0],args[1])                return Mult_Symbol(args[0],args[1])
# Line 3231  def quotient(arg0,arg1): Line 3828  def quotient(arg0,arg1):
3828         """         """
3829         args=matchShape(arg0,arg1)         args=matchShape(arg0,arg1)
3830         if testForZero(args[0]):         if testForZero(args[0]):
3831            return numarray.zeros(pokeShape(args[0]),numarray.Float)            return numarray.zeros(pokeShape(args[0]),numarray.Float64)
3832         elif isinstance(args[0],Symbol):         elif isinstance(args[0],Symbol):
3833            if isinstance(args[1],Symbol):            if isinstance(args[1],Symbol):
3834               return Quotient_Symbol(args[0],args[1])               return Quotient_Symbol(args[0],args[1])
# Line 3337  def power(arg0,arg1): Line 3934  def power(arg0,arg1):
3934         """         """
3935         args=matchShape(arg0,arg1)         args=matchShape(arg0,arg1)
3936         if testForZero(args[0]):         if testForZero(args[0]):
3937            return numarray.zeros(args[0],numarray.Float)            return numarray.zeros(pokeShape(args[0]),numarray.Float64)
3938         elif testForZero(args[1]):         elif testForZero(args[1]):
3939            return numarray.ones(args[0],numarray.Float)            return numarray.ones(pokeShape(args[1]),numarray.Float64)
3940         elif isinstance(args[0],Symbol) or isinstance(args[1],Symbol):         elif isinstance(args[0],Symbol) or isinstance(args[1],Symbol):
3941            return Power_Symbol(args[0],args[1])            return Power_Symbol(args[0],args[1])
3942         elif isinstance(args[0],numarray.NumArray) and not isinstance(args[1],numarray.NumArray):         elif isinstance(args[0],numarray.NumArray) and not isinstance(args[1],numarray.NumArray):
# Line 3636  def generalTensorProduct(arg0,arg1,axis_ Line 4233  def generalTensorProduct(arg0,arg1,axis_
4233             for i in sh1[:axis_offset]: d01*=i             for i in sh1[:axis_offset]: d01*=i
4234             arg0_c.resize((d0,d01))             arg0_c.resize((d0,d01))
4235             arg1_c.resize((d01,d1))             arg1_c.resize((d01,d1))
4236             out=numarray.zeros((d0,d1),numarray.Float)             out=numarray.zeros((d0,d1),numarray.Float64)
4237             for i0 in range(d0):             for i0 in range(d0):
4238                      for i1 in range(d1):                      for i1 in range(d1):
4239                           out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1])                           out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1])
# Line 3804  class Grad_Symbol(DependendSymbol): Line 4401  class Grad_Symbol(DependendSymbol):
4401        d=arg.getDim()        d=arg.getDim()
4402        if d==None:        if d==None:
4403           raise ValueError,"argument must have a spatial dimension"           raise ValueError,"argument must have a spatial dimension"
4404        super(Grad_Symbol,self).__init__(args=[arg,where],shape=tuple(list(arg.getShape()).extend(d)),dim=d)        super(Grad_Symbol,self).__init__(args=[arg,where],shape=arg.getShape()+(d,),dim=d)
4405    
4406     def getMyCode(self,argstrs,format="escript"):     def getMyCode(self,argstrs,format="escript"):
4407        """        """
# Line 4045  def div(arg,where=None): Line 4642  def div(arg,where=None):
4642      @return: divergence of arg.      @return: divergence of arg.
4643      @rtype:  L{escript.Data} or L{Symbol}      @rtype:  L{escript.Data} or L{Symbol}
4644      """      """
4645      if not arg.getShape()==(arg.getDim(),):      if isinstance(arg,Symbol):
4646        raise ValueError,"div: expected shape is (%s,)"%arg.getDim()          dim=arg.getDim()
4647        elif isinstance(arg,escript.Data):
4648            dim=arg.getDomain().getDim()
4649        else:
4650            raise TypeError,"div: argument type not supported"
4651        if not arg.getShape()==(dim,):
4652          raise ValueError,"div: expected shape is (%s,)"%dim
4653      return trace(grad(arg,where))      return trace(grad(arg,where))
4654    
4655  def jump(arg,domain=None):  def jump(arg,domain=None):
# Line 4063  def jump(arg,domain=None): Line 4666  def jump(arg,domain=None):
4666      """      """
4667      if domain==None: domain=arg.getDomain()      if domain==None: domain=arg.getDomain()
4668      return interpolate(arg,escript.FunctionOnContactOne(domain))-interpolate(arg,escript.FunctionOnContactZero(domain))      return interpolate(arg,escript.FunctionOnContactOne(domain))-interpolate(arg,escript.FunctionOnContactZero(domain))
 #=============================  
 #  
 # wrapper for various functions: if the argument has attribute the function name  
 # as an argument it calls the corresponding methods. Otherwise the corresponding  
 # numarray function is called.  
   
 # functions involving the underlying Domain:  
4669    
4670  def transpose(arg,axis=None):  def L2(arg):
4671      """      """
4672      Returns the transpose of the Data object arg.      returns the L2 norm of arg at where
4673        
4674      @param arg:      @param arg: function which L2 to be calculated.
4675        @type arg: L{escript.Data} or L{Symbol}
4676        @return: L2 norm of arg.
4677        @rtype:  L{float} or L{Symbol}
4678        @note: L2(arg) is equivalent to sqrt(integrate(inner(arg,arg)))
4679      """      """
4680      if axis==None:      return sqrt(integrate(inner(arg,arg)))
4681         r=0  #=============================
4682         if hasattr(arg,"getRank"): r=arg.getRank()  #
        if hasattr(arg,"rank"): r=arg.rank  
        axis=r/2  
     if isinstance(arg,Symbol):  
        return Transpose_Symbol(arg,axis=r)  
     if isinstance(arg,escript.Data):  
        # hack for transpose  
        r=arg.getRank()  
        if r!=2: raise ValueError,"Tranpose only avalaible for rank 2 objects"  
        s=arg.getShape()  
        out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())  
        for i in range(s[0]):  
           for j in range(s[1]):  
              out[j,i]=arg[i,j]  
        return out  
        # end hack for transpose  
        return arg.transpose(axis)  
     else:  
        return numarray.transpose(arg,axis=axis)  
   
   
4683    
4684  def reorderComponents(arg,index):  def reorderComponents(arg,index):
4685      """      """
4686      resorts the component of arg according to index      resorts the component of arg according to index
4687    
4688      """      """
4689      pass      raise NotImplementedError
4690  #  #
4691  # $Log: util.py,v $  # $Log: util.py,v $
4692  # Revision 1.14.2.16  2005/10/19 06:09:57  gross  # Revision 1.14.2.16  2005/10/19 06:09:57  gross

Legend:
Removed from v.429  
changed lines
  Added in v.614

  ViewVC Help
Powered by ViewVC 1.1.26