/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Diff of /trunk/escript/py_src/util.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 341 by gross, Mon Dec 12 05:26:10 2005 UTC revision 637 by gross, Thu Mar 23 10:55:31 2006 UTC
# Line 1  Line 1 
1  # $Id$  # $Id$
 #  
 #      COPYRIGHT ACcESS 2004 -  All Rights Reserved  
 #  
 #   This software is the property of ACcESS.  No part of this code  
 #   may be copied in any form or by any means without the expressed written  
 #   consent of ACcESS.  Copying, use or modification of this software  
 #   by any unauthorised person is illegal unless that  
 #   person has a software license agreement with ACcESS.  
 #  
2    
3  """  """
4  Utility functions for escript  Utility functions for escript
5    
 @remark:  This module is under construction and is still tested!!!  
   
6  @var __author__: name of author  @var __author__: name of author
7  @var __licence__: licence agreement  @var __copyright__: copyrights
8    @var __license__: licence agreement
9  @var __url__: url entry point on documentation  @var __url__: url entry point on documentation
10  @var __version__: version  @var __version__: version
11  @var __date__: date of the version  @var __date__: date of the version
12  """  """
13                                                                                                                                                                                                                                                                                                                                                                                                            
14  __author__="Lutz Gross, l.gross@uq.edu.au"  __author__="Lutz Gross, l.gross@uq.edu.au"
15  __licence__="contact: esys@access.uq.edu.au"  __copyright__="""  Copyright (c) 2006 by ACcESS MNRF
16                        http://www.access.edu.au
17                    Primary Business: Queensland, Australia"""
18    __license__="""Licensed under the Open Software License version 3.0
19                 http://www.opensource.org/licenses/osl-3.0.php"""
20  __url__="http://www.iservo.edu.au/esys/escript"  __url__="http://www.iservo.edu.au/esys/escript"
21  __version__="$Revision: 329 $"  __version__="$Revision$"
22  __date__="$Date$"  __date__="$Date$"
23    
24    
# Line 33  import numarray Line 27  import numarray
27  import escript  import escript
28  import os  import os
29    
 # missing tests:  
   
 # def pokeShape(arg):  
 # def pokeDim(arg):  
 # def commonShape(arg0,arg1):  
 # def commonDim(*args):  
 # def testForZero(arg):  
 # def matchType(arg0=0.,arg1=0.):  
 # def matchShape(arg0,arg1):  
   
 # def maximum(arg0,arg1):  
 # def minimum(arg0,arg1):  
   
 # def transpose(arg,axis=None):  
 # def trace(arg,axis0=0,axis1=1):  
 # def reorderComponents(arg,index):  
   
 # def integrate(arg,where=None):  
 # def interpolate(arg,where):  
 # def div(arg,where=None):  
 # def grad(arg,where=None):  
   
 #  
 # slicing: get  
 #          set  
 #  
 # and derivatives  
   
30  #=========================================================  #=========================================================
31  #   some helpers:  #   some helpers:
32  #=========================================================  #=========================================================
# Line 125  def kronecker(d=3): Line 91  def kronecker(d=3):
91     return the kronecker S{delta}-symbol     return the kronecker S{delta}-symbol
92    
93     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
94     @type d: C{int} or any object with a C{getDim} method     @type d: C{int}, L{escript.Domain} or L{escript.FunctionSpace}
95     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise
96     @rtype d: L{numarray.NumArray} of rank 2.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 2.
    @remark: the function is identical L{identity}  
97     """     """
98     return identityTensor(d)     return identityTensor(d)
99    
# Line 143  def identity(shape=()): Line 108  def identity(shape=()):
108     @raise ValueError: if len(shape)>2.     @raise ValueError: if len(shape)>2.
109     """     """
110     if len(shape)>0:     if len(shape)>0:
111        out=numarray.zeros(shape+shape,numarray.Float)        out=numarray.zeros(shape+shape,numarray.Float64)
112        if len(shape)==1:        if len(shape)==1:
113            for i0 in range(shape[0]):            for i0 in range(shape[0]):
114               out[i0,i0]=1.               out[i0,i0]=1.
   
115        elif len(shape)==2:        elif len(shape)==2:
116            for i0 in range(shape[0]):            for i0 in range(shape[0]):
117               for i1 in range(shape[1]):               for i1 in range(shape[1]):
# Line 163  def identityTensor(d=3): Line 127  def identityTensor(d=3):
127     return the dxd identity matrix     return the dxd identity matrix
128    
129     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
130     @type d: C{int} or any object with a C{getDim} method     @type d: C{int}, L{escript.Domain} or L{escript.FunctionSpace}
131     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise
132     @rtype: L{numarray.NumArray} of rank 2.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 2
133     """     """
134     if hasattr(d,"getDim"):     if isinstance(d,escript.FunctionSpace):
135        d=d.getDim()         return escript.Data(identity((d.getDim(),)),d)
136     return identity(shape=(d,))     elif isinstance(d,escript.Domain):
137           return identity((d.getDim(),))
138       else:
139           return identity((d,))
140    
141  def identityTensor4(d=3):  def identityTensor4(d=3):
142     """     """
# Line 178  def identityTensor4(d=3): Line 145  def identityTensor4(d=3):
145     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
146     @type d: C{int} or any object with a C{getDim} method     @type d: C{int} or any object with a C{getDim} method
147     @return: the object u of rank 4 with M{u[i,j,k,l]=1} for M{i=k and j=l} and M{u[i,j,k,l]=0} otherwise     @return: the object u of rank 4 with M{u[i,j,k,l]=1} for M{i=k and j=l} and M{u[i,j,k,l]=0} otherwise
148     @rtype: L{numarray.NumArray} of rank 4.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 4.
149     """     """
150     if hasattr(d,"getDim"):     if isinstance(d,escript.FunctionSpace):
151        d=d.getDim()         return escript.Data(identity((d.getDim(),d.getDim())),d)
152     return identity((d,d))     elif isinstance(d,escript.Domain):
153           return identity((d.getDim(),d.getDim()))
154       else:
155           return identity((d,d))
156    
157  def unitVector(i=0,d=3):  def unitVector(i=0,d=3):
158     """     """
# Line 191  def unitVector(i=0,d=3): Line 161  def unitVector(i=0,d=3):
161     @param i: index     @param i: index
162     @type i: C{int}     @type i: C{int}
163     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
164     @type d: C{int} or any object with a C{getDim} method     @type d: C{int}, L{escript.Domain} or L{escript.FunctionSpace}
165     @return: the object u of rank 1 with M{u[j]=1} for M{j=i} and M{u[i]=0} otherwise     @return: the object u of rank 1 with M{u[j]=1} for M{j=i} and M{u[i]=0} otherwise
166     @rtype: L{numarray.NumArray} of rank 1.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 1
167     """     """
168     return kronecker(d)[i]     return kronecker(d)[i]
169    
# Line 363  def testForZero(arg): Line 333  def testForZero(arg):
333      @return : True if the argument is identical to zero.      @return : True if the argument is identical to zero.
334      @rtype : C{bool}      @rtype : C{bool}
335      """      """
336      try:      if isinstance(arg,numarray.NumArray):
337           return not Lsup(arg)>0.
338        elif isinstance(arg,escript.Data):
339           return False
340        elif isinstance(arg,float):
341           return not Lsup(arg)>0.
342        elif isinstance(arg,int):
343         return not Lsup(arg)>0.         return not Lsup(arg)>0.
344      except TypeError:      elif isinstance(arg,Symbol):
345           return False
346        else:
347         return False         return False
348    
349  def matchType(arg0=0.,arg1=0.):  def matchType(arg0=0.,arg1=0.):
# Line 386  def matchType(arg0=0.,arg1=0.): Line 364  def matchType(arg0=0.,arg1=0.):
364         elif isinstance(arg1,escript.Data):         elif isinstance(arg1,escript.Data):
365            arg0=escript.Data(arg0,arg1.getFunctionSpace())            arg0=escript.Data(arg0,arg1.getFunctionSpace())
366         elif isinstance(arg1,float):         elif isinstance(arg1,float):
367            arg1=numarray.array(arg1)            arg1=numarray.array(arg1,type=numarray.Float64)
368         elif isinstance(arg1,int):         elif isinstance(arg1,int):
369            arg1=numarray.array(float(arg1))            arg1=numarray.array(float(arg1),type=numarray.Float64)
370         elif isinstance(arg1,Symbol):         elif isinstance(arg1,Symbol):
371            pass            pass
372         else:         else:
# Line 412  def matchType(arg0=0.,arg1=0.): Line 390  def matchType(arg0=0.,arg1=0.):
390         elif isinstance(arg1,escript.Data):         elif isinstance(arg1,escript.Data):
391            pass            pass
392         elif isinstance(arg1,float):         elif isinstance(arg1,float):
393            arg1=numarray.array(arg1)            arg1=numarray.array(arg1,type=numarray.Float64)
394         elif isinstance(arg1,int):         elif isinstance(arg1,int):
395            arg1=numarray.array(float(arg1))            arg1=numarray.array(float(arg1),type=numarray.Float64)
396         elif isinstance(arg1,Symbol):         elif isinstance(arg1,Symbol):
397            pass            pass
398         else:         else:
399            raise TypeError,"function: Unknown type of second argument."                raise TypeError,"function: Unknown type of second argument."    
400      elif isinstance(arg0,float):      elif isinstance(arg0,float):
401         if isinstance(arg1,numarray.NumArray):         if isinstance(arg1,numarray.NumArray):
402            arg0=numarray.array(arg0)            arg0=numarray.array(arg0,type=numarray.Float64)
403         elif isinstance(arg1,escript.Data):         elif isinstance(arg1,escript.Data):
404            arg0=escript.Data(arg0,arg1.getFunctionSpace())            arg0=escript.Data(arg0,arg1.getFunctionSpace())
405         elif isinstance(arg1,float):         elif isinstance(arg1,float):
406            arg0=numarray.array(arg0)            arg0=numarray.array(arg0,type=numarray.Float64)
407            arg1=numarray.array(arg1)            arg1=numarray.array(arg1,type=numarray.Float64)
408         elif isinstance(arg1,int):         elif isinstance(arg1,int):
409            arg0=numarray.array(arg0)            arg0=numarray.array(arg0,type=numarray.Float64)
410            arg1=numarray.array(float(arg1))            arg1=numarray.array(float(arg1),type=numarray.Float64)
411         elif isinstance(arg1,Symbol):         elif isinstance(arg1,Symbol):
412            arg0=numarray.array(arg0)            arg0=numarray.array(arg0,type=numarray.Float64)
413         else:         else:
414            raise TypeError,"function: Unknown type of second argument."                raise TypeError,"function: Unknown type of second argument."    
415      elif isinstance(arg0,int):      elif isinstance(arg0,int):
416         if isinstance(arg1,numarray.NumArray):         if isinstance(arg1,numarray.NumArray):
417            arg0=numarray.array(float(arg0))            arg0=numarray.array(float(arg0),type=numarray.Float64)
418         elif isinstance(arg1,escript.Data):         elif isinstance(arg1,escript.Data):
419            arg0=escript.Data(float(arg0),arg1.getFunctionSpace())            arg0=escript.Data(float(arg0),arg1.getFunctionSpace())
420         elif isinstance(arg1,float):         elif isinstance(arg1,float):
421            arg0=numarray.array(float(arg0))            arg0=numarray.array(float(arg0),type=numarray.Float64)
422            arg1=numarray.array(arg1)            arg1=numarray.array(arg1,type=numarray.Float64)
423         elif isinstance(arg1,int):         elif isinstance(arg1,int):
424            arg0=numarray.array(float(arg0))            arg0=numarray.array(float(arg0),type=numarray.Float64)
425            arg1=numarray.array(float(arg1))            arg1=numarray.array(float(arg1),type=numarray.Float64)
426         elif isinstance(arg1,Symbol):         elif isinstance(arg1,Symbol):
427            arg0=numarray.array(float(arg0))            arg0=numarray.array(float(arg0),type=numarray.Float64)
428         else:         else:
429            raise TypeError,"function: Unknown type of second argument."                raise TypeError,"function: Unknown type of second argument."    
430      else:      else:
# Line 469  def matchShape(arg0,arg1): Line 447  def matchShape(arg0,arg1):
447      sh0=pokeShape(arg0)      sh0=pokeShape(arg0)
448      sh1=pokeShape(arg1)      sh1=pokeShape(arg1)
449      if len(sh0)<len(sh):      if len(sh0)<len(sh):
450         return outer(arg0,numarray.ones(sh[len(sh0):],numarray.Float)),arg1         return outer(arg0,numarray.ones(sh[len(sh0):],numarray.Float64)),arg1
451      elif len(sh1)<len(sh):      elif len(sh1)<len(sh):
452         return arg0,outer(arg1,numarray.ones(sh[len(sh1):],numarray.Float))         return arg0,outer(arg1,numarray.ones(sh[len(sh1):],numarray.Float64))
453      else:      else:
454         return arg0,arg1         return arg0,arg1
455  #=========================================================  #=========================================================
# Line 597  class Symbol(object): Line 575  class Symbol(object):
575            else:            else:
576                s=pokeShape(s)+arg.getShape()                s=pokeShape(s)+arg.getShape()
577                if len(s)>0:                if len(s)>0:
578                   out.append(numarray.zeros(s),numarray.Float)                   out.append(numarray.zeros(s),numarray.Float64)
579                else:                else:
580                   out.append(a)                   out.append(a)
581         return out         return out
# Line 687  class Symbol(object): Line 665  class Symbol(object):
665         else:         else:
666            s=self.getShape()+arg.getShape()            s=self.getShape()+arg.getShape()
667            if len(s)>0:            if len(s)>0:
668               return numarray.zeros(s,numarray.Float)               return numarray.zeros(s,numarray.Float64)
669            else:            else:
670               return 0.               return 0.
671    
# Line 825  class Symbol(object): Line 803  class Symbol(object):
803         """         """
804         return power(other,self)         return power(other,self)
805    
806       def __getitem__(self,index):
807           """
808           returns the slice defined by index
809    
810           @param index: defines a
811           @type index: C{slice} or C{int} or a C{tuple} of them
812           @return: a S{Symbol} representing the slice defined by index
813           @rtype: L{DependendSymbol}
814           """
815           return GetSlice_Symbol(self,index)
816    
817  class DependendSymbol(Symbol):  class DependendSymbol(Symbol):
818     """     """
819     DependendSymbol extents L{Symbol} by modifying the == operator to allow two instances to be equal.     DependendSymbol extents L{Symbol} by modifying the == operator to allow two instances to be equal.
# Line 875  class DependendSymbol(Symbol): Line 864  class DependendSymbol(Symbol):
864  #=========================================================  #=========================================================
865  #  Unary operations prserving the shape  #  Unary operations prserving the shape
866  #========================================================  #========================================================
867    class GetSlice_Symbol(DependendSymbol):
868       """
869       L{Symbol} representing getting a slice for a L{Symbol}
870       """
871       def __init__(self,arg,index):
872          """
873          initialization of wherePositive L{Symbol} with argument arg
874          @param arg: argument
875          @type arg: L{Symbol}.
876          @param index: defines index
877          @type index: C{slice} or C{int} or a C{tuple} of them
878          @raises IndexError: if length of index is larger than rank of arg or a index start or stop is out of range
879          @raises ValueError: if a step is given
880          """
881          if not isinstance(index,tuple): index=(index,)
882          if len(index)>arg.getRank():
883               raise IndexError,"GetSlice_Symbol: index out of range."
884          sh=()
885          index2=()
886          for i in range(len(index)):
887             ix=index[i]
888             if isinstance(ix,int):
889                if ix<0 or ix>=arg.getShape()[i]:
890                   raise ValueError,"GetSlice_Symbol: index out of range."
891                index2=index2+(ix,)
892             else:
893               if not ix.step==None:
894                 raise ValueError,"GetSlice_Symbol: steping is not supported."
895               if ix.start==None:
896                  s=0
897               else:
898                  s=ix.start
899               if ix.stop==None:
900                  e=arg.getShape()[i]
901               else:
902                  e=ix.stop
903                  if e>arg.getShape()[i]:
904                     raise IndexError,"GetSlice_Symbol: index out of range."
905               index2=index2+(slice(s,e),)
906               if e>s:
907                   sh=sh+(e-s,)
908               elif s>e:
909                   raise IndexError,"GetSlice_Symbol: slice start must be less or equal slice end"
910          for i in range(len(index),arg.getRank()):
911              index2=index2+(slice(0,arg.getShape()[i]),)
912              sh=sh+(arg.getShape()[i],)
913          super(GetSlice_Symbol, self).__init__(args=[arg,index2],shape=sh,dim=arg.getDim())
914    
915       def getMyCode(self,argstrs,format="escript"):
916          """
917          returns a program code that can be used to evaluate the symbol.
918    
919          @param argstrs: gives for each argument a string representing the argument for the evaluation.
920          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
921          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
922          @type format: C{str}
923          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
924          @rtype: C{str}
925          @raise: NotImplementedError: if the requested format is not available
926          """
927          if format=="escript" or format=="str"  or format=="text":
928             return "%s.__getitem__(%s)"%(argstrs[0],argstrs[1])
929          else:
930             raise NotImplementedError,"GetItem_Symbol does not provide program code for format %s."%format
931    
932       def substitute(self,argvals):
933          """
934          assigns new values to symbols in the definition of the symbol.
935          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
936    
937          @param argvals: new values assigned to symbols
938          @type argvals: C{dict} with keywords of type L{Symbol}.
939          @return: result of the substitution process. Operations are executed as much as possible.
940          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
941          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
942          """
943          if argvals.has_key(self):
944             arg=argvals[self]
945             if self.isAppropriateValue(arg):
946                return arg
947             else:
948                raise TypeError,"%s: new value is not appropriate."%str(self)
949          else:
950             args=self.getSubstitutedArguments(argvals)
951             arg=args[0]
952             index=args[1]
953             return arg.__getitem__(index)
954    
955  def log10(arg):  def log10(arg):
956     """     """
957     returns base-10 logarithm of argument arg     returns base-10 logarithm of argument arg
# Line 907  def wherePositive(arg): Line 984  def wherePositive(arg):
984     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
985     """     """
986     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
987        if arg.rank==0:        out=numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float64))*1.
988           if arg>0:        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
989             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))  
990     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
991        return arg._wherePositive()        return arg._wherePositive()
992     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 993  def whereNegative(arg): Line 1066  def whereNegative(arg):
1066     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1067     """     """
1068     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1069        if arg.rank==0:        out=numarray.less(arg,numarray.zeros(arg.shape,numarray.Float64))*1.
1070           if arg<0:        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1071             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))  
1072     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1073        return arg._whereNegative()        return arg._whereNegative()
1074     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1079  def whereNonNegative(arg): Line 1148  def whereNonNegative(arg):
1148     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1149     """     """
1150     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1151        if arg.rank==0:        out=numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float64))*1.
1152           if arg<0:        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1153             return numarray.array(0.)        return out
          else:  
            return numarray.array(1.)  
       else:  
          return numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float))  
1154     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1155        return arg._whereNonNegative()        return arg._whereNonNegative()
1156     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1113  def whereNonPositive(arg): Line 1178  def whereNonPositive(arg):
1178     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1179     """     """
1180     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1181        if arg.rank==0:        out=numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float64))*1.
1182           if arg>0:        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1183             return numarray.array(0.)        return out
          else:  
            return numarray.array(1.)  
       else:  
          return numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.  
1184     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1185        return arg._whereNonPositive()        return arg._whereNonPositive()
1186     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1149  def whereZero(arg,tol=0.): Line 1210  def whereZero(arg,tol=0.):
1210     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1211     """     """
1212     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1213        if arg.rank==0:        out=numarray.less_equal(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float64))*1.
1214           if abs(arg)<=tol:        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1215             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.less_equal(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.  
1216     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1217        if tol>0.:        if tol>0.:
1218           return whereNegative(abs(arg)-tol)           return whereNegative(abs(arg)-tol)
# Line 1236  def whereNonZero(arg,tol=0.): Line 1293  def whereNonZero(arg,tol=0.):
1293     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1294     """     """
1295     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1296        if arg.rank==0:        out=numarray.greater(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float64))*1.
1297          if abs(arg)>tol:        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1298             return numarray.array(1.)        return out
         else:  
            return numarray.array(0.)  
       else:  
          return numarray.greater(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.  
1299     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1300        if tol>0.:        if tol>0.:
1301           return 1.-whereZero(arg,tol)           return 1.-whereZero(arg,tol)
# Line 2877  def length(arg): Line 2930  def length(arg):
2930     """     """
2931     return sqrt(inner(arg,arg))     return sqrt(inner(arg,arg))
2932    
2933    def trace(arg,axis_offset=0):
2934       """
2935       returns the trace of arg which the sum of arg[k,k] over k.
2936    
2937       @param arg: argument
2938       @type arg: L{escript.Data}, L{Symbol}, L{numarray.NumArray}.
2939       @param axis_offset: axis_offset to components to sum over. C{axis_offset} must be non-negative and less than the rank of arg +1. The dimensions on component
2940                      axis_offset and axis_offset+1 must be equal.
2941       @type axis_offset: C{int}
2942       @return: trace of arg. The rank of the returned object is minus 2 of the rank of arg.
2943       @rtype: L{escript.Data}, L{Symbol}, L{numarray.NumArray} depending on the type of arg.
2944       """
2945       if isinstance(arg,numarray.NumArray):
2946          sh=arg.shape
2947          if len(sh)<2:
2948            raise ValueError,"trace: rank of argument must be greater than 1"
2949          if axis_offset<0 or axis_offset>len(sh)-2:
2950            raise ValueError,"trace: axis_offset must be between 0 and %s"%len(sh)-2
2951          s1=1
2952          for i in range(axis_offset): s1*=sh[i]
2953          s2=1
2954          for i in range(axis_offset+2,len(sh)): s2*=sh[i]
2955          if not sh[axis_offset] == sh[axis_offset+1]:
2956            raise ValueError,"trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2957          arg_reshaped=numarray.reshape(arg,(s1,sh[axis_offset],sh[axis_offset],s2))
2958          out=numarray.zeros([s1,s2],numarray.Float64)
2959          for i1 in range(s1):
2960            for i2 in range(s2):
2961                for j in range(sh[axis_offset]): out[i1,i2]+=arg_reshaped[i1,j,j,i2]
2962          out.resize(sh[:axis_offset]+sh[axis_offset+2:])
2963          return out
2964       elif isinstance(arg,escript.Data):
2965          return escript_trace(arg,axis_offset)
2966       elif isinstance(arg,float):
2967          raise TypeError,"trace: illegal argument type float."
2968       elif isinstance(arg,int):
2969          raise TypeError,"trace: illegal argument type int."
2970       elif isinstance(arg,Symbol):
2971          return Trace_Symbol(arg,axis_offset)
2972       else:
2973          raise TypeError,"trace: Unknown argument type."
2974    
2975    def escript_trace(arg,axis_offset): # this should be escript._trace
2976          "arg si a Data objects!!!"
2977          if arg.getRank()<2:
2978            raise ValueError,"escript_trace: rank of argument must be greater than 1"
2979          if axis_offset<0 or axis_offset>arg.getRank()-2:
2980            raise ValueError,"escript_trace: axis_offset must be between 0 and %s"%arg.getRank()-2
2981          s=list(arg.getShape())        
2982          if not s[axis_offset] == s[axis_offset+1]:
2983            raise ValueError,"escript_trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2984          out=escript.Data(0.,tuple(s[0:axis_offset]+s[axis_offset+2:]),arg.getFunctionSpace())
2985          if arg.getRank()==2:
2986             for i0 in range(s[0]):
2987                out+=arg[i0,i0]
2988          elif arg.getRank()==3:
2989             if axis_offset==0:
2990                for i0 in range(s[0]):
2991                      for i2 in range(s[2]):
2992                             out[i2]+=arg[i0,i0,i2]
2993             elif axis_offset==1:
2994                for i0 in range(s[0]):
2995                   for i1 in range(s[1]):
2996                             out[i0]+=arg[i0,i1,i1]
2997          elif arg.getRank()==4:
2998             if axis_offset==0:
2999                for i0 in range(s[0]):
3000                      for i2 in range(s[2]):
3001                         for i3 in range(s[3]):
3002                             out[i2,i3]+=arg[i0,i0,i2,i3]
3003             elif axis_offset==1:
3004                for i0 in range(s[0]):
3005                   for i1 in range(s[1]):
3006                         for i3 in range(s[3]):
3007                             out[i0,i3]+=arg[i0,i1,i1,i3]
3008             elif axis_offset==2:
3009                for i0 in range(s[0]):
3010                   for i1 in range(s[1]):
3011                      for i2 in range(s[2]):
3012                             out[i0,i1]+=arg[i0,i1,i2,i2]
3013          return out
3014    class Trace_Symbol(DependendSymbol):
3015       """
3016       L{Symbol} representing the result of the trace function
3017       """
3018       def __init__(self,arg,axis_offset=0):
3019          """
3020          initialization of trace L{Symbol} with argument arg
3021          @param arg: argument of function
3022          @type arg: L{Symbol}.
3023          @param axis_offset: axis_offset to components to sum over. C{axis_offset} must be non-negative and less than the rank of arg +1. The dimensions on component
3024                      axis_offset and axis_offset+1 must be equal.
3025          @type axis_offset: C{int}
3026          """
3027          if arg.getRank()<2:
3028            raise ValueError,"Trace_Symbol: rank of argument must be greater than 1"
3029          if axis_offset<0 or axis_offset>arg.getRank()-2:
3030            raise ValueError,"Trace_Symbol: axis_offset must be between 0 and %s"%arg.getRank()-2
3031          s=list(arg.getShape())        
3032          if not s[axis_offset] == s[axis_offset+1]:
3033            raise ValueError,"Trace_Symbol: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
3034          super(Trace_Symbol,self).__init__(args=[arg,axis_offset],shape=tuple(s[0:axis_offset]+s[axis_offset+2:]),dim=arg.getDim())
3035    
3036       def getMyCode(self,argstrs,format="escript"):
3037          """
3038          returns a program code that can be used to evaluate the symbol.
3039    
3040          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3041          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3042          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3043          @type format: C{str}
3044          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3045          @rtype: C{str}
3046          @raise: NotImplementedError: if the requested format is not available
3047          """
3048          if format=="escript" or format=="str"  or format=="text":
3049             return "trace(%s,axis_offset=%s)"%(argstrs[0],argstrs[1])
3050          else:
3051             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
3052    
3053       def substitute(self,argvals):
3054          """
3055          assigns new values to symbols in the definition of the symbol.
3056          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3057    
3058          @param argvals: new values assigned to symbols
3059          @type argvals: C{dict} with keywords of type L{Symbol}.
3060          @return: result of the substitution process. Operations are executed as much as possible.
3061          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3062          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3063          """
3064          if argvals.has_key(self):
3065             arg=argvals[self]
3066             if self.isAppropriateValue(arg):
3067                return arg
3068             else:
3069                raise TypeError,"%s: new value is not appropriate."%str(self)
3070          else:
3071             arg=self.getSubstitutedArguments(argvals)
3072             return trace(arg[0],axis_offset=arg[1])
3073    
3074       def diff(self,arg):
3075          """
3076          differential of this object
3077    
3078          @param arg: the derivative is calculated with respect to arg
3079          @type arg: L{escript.Symbol}
3080          @return: derivative with respect to C{arg}
3081          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3082          """
3083          if arg==self:
3084             return identity(self.getShape())
3085          else:
3086             return trace(self.getDifferentiatedArguments(arg)[0],axis_offset=self.getArgument()[1])
3087    
3088    def transpose(arg,axis_offset=None):
3089       """
3090       returns the transpose of arg by swaping the first axis_offset and the last rank-axis_offset components.
3091    
3092       @param arg: argument
3093       @type arg: L{escript.Data}, L{Symbol}, L{numarray.NumArray}, C{float}, C{int}
3094       @param axis_offset: the first axis_offset components are swapped with rest. If C{axis_offset} must be non-negative and less or equal the rank of arg.
3095                           if axis_offset is not present C{int(r/2)} where r is the rank of arg is used.
3096       @type axis_offset: C{int}
3097       @return: transpose of arg
3098       @rtype: L{escript.Data}, L{Symbol}, L{numarray.NumArray},C{float}, C{int} depending on the type of arg.
3099       """
3100       if isinstance(arg,numarray.NumArray):
3101          if axis_offset==None: axis_offset=int(arg.rank/2)
3102          return numarray.transpose(arg,axes=range(axis_offset,arg.rank)+range(0,axis_offset))
3103       elif isinstance(arg,escript.Data):
3104          if axis_offset==None: axis_offset=int(arg.getRank()/2)
3105          return escript_transpose(arg,axis_offset)
3106       elif isinstance(arg,float):
3107          if not ( axis_offset==0 or axis_offset==None):
3108            raise ValueError,"transpose: axis_offset must be 0 for float argument"
3109          return arg
3110       elif isinstance(arg,int):
3111          if not ( axis_offset==0 or axis_offset==None):
3112            raise ValueError,"transpose: axis_offset must be 0 for int argument"
3113          return float(arg)
3114       elif isinstance(arg,Symbol):
3115          if axis_offset==None: axis_offset=int(arg.getRank()/2)
3116          return Transpose_Symbol(arg,axis_offset)
3117       else:
3118          raise TypeError,"transpose: Unknown argument type."
3119    
3120    def escript_transpose(arg,axis_offset): # this should be escript._transpose
3121          "arg si a Data objects!!!"
3122          r=arg.getRank()
3123          if axis_offset<0 or axis_offset>r:
3124            raise ValueError,"escript_transpose: axis_offset must be between 0 and %s"%r
3125          s=arg.getShape()
3126          s_out=s[axis_offset:]+s[:axis_offset]
3127          out=escript.Data(0.,s_out,arg.getFunctionSpace())
3128          if r==4:
3129             if axis_offset==1:
3130                for i0 in range(s_out[0]):
3131                   for i1 in range(s_out[1]):
3132                      for i2 in range(s_out[2]):
3133                         for i3 in range(s_out[3]):
3134                             out[i0,i1,i2,i3]=arg[i3,i0,i1,i2]
3135             elif axis_offset==2:
3136                for i0 in range(s_out[0]):
3137                   for i1 in range(s_out[1]):
3138                      for i2 in range(s_out[2]):
3139                         for i3 in range(s_out[3]):
3140                             out[i0,i1,i2,i3]=arg[i2,i3,i0,i1]
3141             elif axis_offset==3:
3142                for i0 in range(s_out[0]):
3143                   for i1 in range(s_out[1]):
3144                      for i2 in range(s_out[2]):
3145                         for i3 in range(s_out[3]):
3146                             out[i0,i1,i2,i3]=arg[i1,i2,i3,i0]
3147             else:
3148                for i0 in range(s_out[0]):
3149                   for i1 in range(s_out[1]):
3150                      for i2 in range(s_out[2]):
3151                         for i3 in range(s_out[3]):
3152                             out[i0,i1,i2,i3]=arg[i0,i1,i2,i3]
3153          elif r==3:
3154             if axis_offset==1:
3155                for i0 in range(s_out[0]):
3156                   for i1 in range(s_out[1]):
3157                      for i2 in range(s_out[2]):
3158                             out[i0,i1,i2]=arg[i2,i0,i1]
3159             elif axis_offset==2:
3160                for i0 in range(s_out[0]):
3161                   for i1 in range(s_out[1]):
3162                      for i2 in range(s_out[2]):
3163                             out[i0,i1,i2]=arg[i1,i2,i0]
3164             else:
3165                for i0 in range(s_out[0]):
3166                   for i1 in range(s_out[1]):
3167                      for i2 in range(s_out[2]):
3168                             out[i0,i1,i2]=arg[i0,i1,i2]
3169          elif r==2:
3170             if axis_offset==1:
3171                for i0 in range(s_out[0]):
3172                   for i1 in range(s_out[1]):
3173                             out[i0,i1]=arg[i1,i0]
3174             else:
3175                for i0 in range(s_out[0]):
3176                   for i1 in range(s_out[1]):
3177                             out[i0,i1]=arg[i0,i1]
3178          elif r==1:
3179              for i0 in range(s_out[0]):
3180                   out[i0]=arg[i0]
3181          elif r==0:
3182                 out=arg+0.
3183          return out
3184    class Transpose_Symbol(DependendSymbol):
3185       """
3186       L{Symbol} representing the result of the transpose function
3187       """
3188       def __init__(self,arg,axis_offset=None):
3189          """
3190          initialization of transpose L{Symbol} with argument arg
3191    
3192          @param arg: argument of function
3193          @type arg: L{Symbol}.
3194           @param axis_offset: the first axis_offset components are swapped with rest. If C{axis_offset} must be non-negative and less or equal the rank of arg.
3195                           if axis_offset is not present C{int(r/2)} where r is the rank of arg is used.
3196          @type axis_offset: C{int}
3197          """
3198          if axis_offset==None: axis_offset=int(arg.getRank()/2)
3199          if axis_offset<0 or axis_offset>arg.getRank():
3200            raise ValueError,"escript_transpose: axis_offset must be between 0 and %s"%r
3201          s=arg.getShape()
3202          super(Transpose_Symbol,self).__init__(args=[arg,axis_offset],shape=s[axis_offset:]+s[:axis_offset],dim=arg.getDim())
3203    
3204       def getMyCode(self,argstrs,format="escript"):
3205          """
3206          returns a program code that can be used to evaluate the symbol.
3207    
3208          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3209          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3210          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3211          @type format: C{str}
3212          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3213          @rtype: C{str}
3214          @raise: NotImplementedError: if the requested format is not available
3215          """
3216          if format=="escript" or format=="str"  or format=="text":
3217             return "transpose(%s,axis_offset=%s)"%(argstrs[0],argstrs[1])
3218          else:
3219             raise NotImplementedError,"Transpose_Symbol does not provide program code for format %s."%format
3220    
3221       def substitute(self,argvals):
3222          """
3223          assigns new values to symbols in the definition of the symbol.
3224          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3225    
3226          @param argvals: new values assigned to symbols
3227          @type argvals: C{dict} with keywords of type L{Symbol}.
3228          @return: result of the substitution process. Operations are executed as much as possible.
3229          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3230          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3231          """
3232          if argvals.has_key(self):
3233             arg=argvals[self]
3234             if self.isAppropriateValue(arg):
3235                return arg
3236             else:
3237                raise TypeError,"%s: new value is not appropriate."%str(self)
3238          else:
3239             arg=self.getSubstitutedArguments(argvals)
3240             return transpose(arg[0],axis_offset=arg[1])
3241    
3242       def diff(self,arg):
3243          """
3244          differential of this object
3245    
3246          @param arg: the derivative is calculated with respect to arg
3247          @type arg: L{escript.Symbol}
3248          @return: derivative with respect to C{arg}
3249          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3250          """
3251          if arg==self:
3252             return identity(self.getShape())
3253          else:
3254             return transpose(self.getDifferentiatedArguments(arg)[0],axis_offset=self.getArgument()[1])
3255    def symmetric(arg):
3256        """
3257        returns the symmetric part of the square matrix arg. This is (arg+transpose(arg))/2
3258    
3259        @param arg: square matrix. Must have rank 2 or 4 and be square.
3260        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}
3261        @return: symmetric part of arg
3262        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
3263        """
3264        if isinstance(arg,numarray.NumArray):
3265          if arg.rank==2:
3266            if not (arg.shape[0]==arg.shape[1]):
3267               raise ValueError,"symmetric: argument must be square."
3268          elif arg.rank==4:
3269            if not (arg.shape[0]==arg.shape[2] and arg.shape[1]==arg.shape[3]):
3270               raise ValueError,"symmetric: argument must be square."
3271          else:
3272            raise ValueError,"symmetric: rank 2 or 4 is required."
3273          return (arg+transpose(arg))/2
3274        elif isinstance(arg,escript.Data):
3275          return escript_symmetric(arg)
3276        elif isinstance(arg,float):
3277          return arg
3278        elif isinstance(arg,int):
3279          return float(arg)
3280        elif isinstance(arg,Symbol):
3281          if arg.getRank()==2:
3282            if not (arg.getShape()[0]==arg.getShape()[1]):
3283               raise ValueError,"symmetric: argument must be square."
3284          elif arg.getRank()==4:
3285            if not (arg.getShape()[0]==arg.getShape()[2] and arg.getShape()[1]==arg.getShape()[3]):
3286               raise ValueError,"symmetric: argument must be square."
3287          else:
3288            raise ValueError,"symmetric: rank 2 or 4 is required."
3289          return (arg+transpose(arg))/2
3290        else:
3291          raise TypeError,"symmetric: Unknown argument type."
3292    
3293    def escript_symmetric(arg): # this should be implemented in c++
3294          if arg.getRank()==2:
3295            if not (arg.getShape()[0]==arg.getShape()[1]):
3296               raise ValueError,"escript_symmetric: argument must be square."
3297            out=escript.Data(0.,arg.getShape(),arg.getFunctionSpace())
3298            for i0 in range(arg.getShape()[0]):
3299               for i1 in range(arg.getShape()[1]):
3300                  out[i0,i1]=(arg[i0,i1]+arg[i1,i0])/2.
3301          elif arg.getRank()==4:
3302            if not (arg.getShape()[0]==arg.getShape()[2] and arg.getShape()[1]==arg.getShape()[3]):
3303               raise ValueError,"escript_symmetric: argument must be square."
3304            out=escript.Data(0.,arg.getShape(),arg.getFunctionSpace())
3305            for i0 in range(arg.getShape()[0]):
3306               for i1 in range(arg.getShape()[1]):
3307                  for i2 in range(arg.getShape()[2]):
3308                     for i3 in range(arg.getShape()[3]):
3309                         out[i0,i1,i2,i3]=(arg[i0,i1,i2,i3]+arg[i2,i3,i0,i1])/2.
3310          else:
3311            raise ValueError,"escript_symmetric: rank 2 or 4 is required."
3312          return out
3313    
3314    def nonsymmetric(arg):
3315        """
3316        returns the nonsymmetric part of the square matrix arg. This is (arg-transpose(arg))/2
3317    
3318        @param arg: square matrix. Must have rank 2 or 4 and be square.
3319        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}
3320        @return: nonsymmetric part of arg
3321        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
3322        """
3323        if isinstance(arg,numarray.NumArray):
3324          if arg.rank==2:
3325            if not (arg.shape[0]==arg.shape[1]):
3326               raise ValueError,"nonsymmetric: argument must be square."
3327          elif arg.rank==4:
3328            if not (arg.shape[0]==arg.shape[2] and arg.shape[1]==arg.shape[3]):
3329               raise ValueError,"nonsymmetric: argument must be square."
3330          else:
3331            raise ValueError,"nonsymmetric: rank 2 or 4 is required."
3332          return (arg-transpose(arg))/2
3333        elif isinstance(arg,escript.Data):
3334          return escript_nonsymmetric(arg)
3335        elif isinstance(arg,float):
3336          return arg
3337        elif isinstance(arg,int):
3338          return float(arg)
3339        elif isinstance(arg,Symbol):
3340          if arg.getRank()==2:
3341            if not (arg.getShape()[0]==arg.getShape()[1]):
3342               raise ValueError,"nonsymmetric: argument must be square."
3343          elif arg.getRank()==4:
3344            if not (arg.getShape()[0]==arg.getShape()[2] and arg.getShape()[1]==arg.getShape()[3]):
3345               raise ValueError,"nonsymmetric: argument must be square."
3346          else:
3347            raise ValueError,"nonsymmetric: rank 2 or 4 is required."
3348          return (arg-transpose(arg))/2
3349        else:
3350          raise TypeError,"nonsymmetric: Unknown argument type."
3351    
3352    def escript_nonsymmetric(arg): # this should be implemented in c++
3353          if arg.getRank()==2:
3354            if not (arg.getShape()[0]==arg.getShape()[1]):
3355               raise ValueError,"escript_nonsymmetric: argument must be square."
3356            out=escript.Data(0.,arg.getShape(),arg.getFunctionSpace())
3357            for i0 in range(arg.getShape()[0]):
3358               for i1 in range(arg.getShape()[1]):
3359                  out[i0,i1]=(arg[i0,i1]-arg[i1,i0])/2.
3360          elif arg.getRank()==4:
3361            if not (arg.getShape()[0]==arg.getShape()[2] and arg.getShape()[1]==arg.getShape()[3]):
3362               raise ValueError,"escript_nonsymmetric: argument must be square."
3363            out=escript.Data(0.,arg.getShape(),arg.getFunctionSpace())
3364            for i0 in range(arg.getShape()[0]):
3365               for i1 in range(arg.getShape()[1]):
3366                  for i2 in range(arg.getShape()[2]):
3367                     for i3 in range(arg.getShape()[3]):
3368                         out[i0,i1,i2,i3]=(arg[i0,i1,i2,i3]-arg[i2,i3,i0,i1])/2.
3369          else:
3370            raise ValueError,"escript_nonsymmetric: rank 2 or 4 is required."
3371          return out
3372    
3373    
3374    def inverse(arg):
3375        """
3376        returns the inverse of the square matrix arg.
3377    
3378        @param arg: square matrix. Must have rank 2 and the first and second dimension must be equal.
3379        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}
3380        @return: inverse arg_inv of the argument. It will be matrixmul(inverse(arg),arg) almost equal to kronecker(arg.getShape()[0])
3381        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
3382        @remark: for L{escript.Data} objects the dimension is restricted to 3.
3383        """
3384        import numarray.linear_algebra # This statement should be after the next statement but then somehow numarray is gone.
3385        if isinstance(arg,numarray.NumArray):
3386          return numarray.linear_algebra.inverse(arg)
3387        elif isinstance(arg,escript.Data):
3388          return escript_inverse(arg)
3389        elif isinstance(arg,float):
3390          return 1./arg
3391        elif isinstance(arg,int):
3392          return 1./float(arg)
3393        elif isinstance(arg,Symbol):
3394          return Inverse_Symbol(arg)
3395        else:
3396          raise TypeError,"inverse: Unknown argument type."
3397    
3398    def escript_inverse(arg): # this should be escript._inverse and use LAPACK
3399          "arg is a Data objects!!!"
3400          if not arg.getRank()==2:
3401            raise ValueError,"escript_inverse: argument must have rank 2"
3402          s=arg.getShape()      
3403          if not s[0] == s[1]:
3404            raise ValueError,"escript_inverse: argument must be a square matrix."
3405          out=escript.Data(0.,s,arg.getFunctionSpace())
3406          if s[0]==1:
3407              if inf(abs(arg[0,0]))==0: # in c this should be done point wise as abs(arg[0,0](i))<=0.
3408                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3409              out[0,0]=1./arg[0,0]
3410          elif s[0]==2:
3411              A11=arg[0,0]
3412              A12=arg[0,1]
3413              A21=arg[1,0]
3414              A22=arg[1,1]
3415              D = A11*A22-A12*A21
3416              if inf(abs(D))==0: # in c this should be done point wise as abs(D(i))<=0.
3417                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3418              D=1./D
3419              out[0,0]= A22*D
3420              out[1,0]=-A21*D
3421              out[0,1]=-A12*D
3422              out[1,1]= A11*D
3423          elif s[0]==3:
3424              A11=arg[0,0]
3425              A21=arg[1,0]
3426              A31=arg[2,0]
3427              A12=arg[0,1]
3428              A22=arg[1,1]
3429              A32=arg[2,1]
3430              A13=arg[0,2]
3431              A23=arg[1,2]
3432              A33=arg[2,2]
3433              D  =  A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
3434              if inf(abs(D))==0: # in c this should be done point wise as abs(D(i))<=0.
3435                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3436              D=1./D
3437              out[0,0]=(A22*A33-A23*A32)*D
3438              out[1,0]=(A31*A23-A21*A33)*D
3439              out[2,0]=(A21*A32-A31*A22)*D
3440              out[0,1]=(A13*A32-A12*A33)*D
3441              out[1,1]=(A11*A33-A31*A13)*D
3442              out[2,1]=(A12*A31-A11*A32)*D
3443              out[0,2]=(A12*A23-A13*A22)*D
3444              out[1,2]=(A13*A21-A11*A23)*D
3445              out[2,2]=(A11*A22-A12*A21)*D
3446          else:
3447             raise TypeError,"escript_inverse: only matrix dimensions 1,2,3 are supported right now."
3448          return out
3449    
3450    class Inverse_Symbol(DependendSymbol):
3451       """
3452       L{Symbol} representing the result of the inverse function
3453       """
3454       def __init__(self,arg):
3455          """
3456          initialization of inverse L{Symbol} with argument arg
3457          @param arg: argument of function
3458          @type arg: L{Symbol}.
3459          """
3460          if not arg.getRank()==2:
3461            raise ValueError,"Inverse_Symbol:: argument must have rank 2"
3462          s=arg.getShape()
3463          if not s[0] == s[1]:
3464            raise ValueError,"Inverse_Symbol:: argument must be a square matrix."
3465          super(Inverse_Symbol,self).__init__(args=[arg],shape=s,dim=arg.getDim())
3466    
3467       def getMyCode(self,argstrs,format="escript"):
3468          """
3469          returns a program code that can be used to evaluate the symbol.
3470    
3471          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3472          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3473          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3474          @type format: C{str}
3475          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3476          @rtype: C{str}
3477          @raise: NotImplementedError: if the requested format is not available
3478          """
3479          if format=="escript" or format=="str"  or format=="text":
3480             return "inverse(%s)"%argstrs[0]
3481          else:
3482             raise NotImplementedError,"Inverse_Symbol does not provide program code for format %s."%format
3483    
3484       def substitute(self,argvals):
3485          """
3486          assigns new values to symbols in the definition of the symbol.
3487          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3488    
3489          @param argvals: new values assigned to symbols
3490          @type argvals: C{dict} with keywords of type L{Symbol}.
3491          @return: result of the substitution process. Operations are executed as much as possible.
3492          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3493          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3494          """
3495          if argvals.has_key(self):
3496             arg=argvals[self]
3497             if self.isAppropriateValue(arg):
3498                return arg
3499             else:
3500                raise TypeError,"%s: new value is not appropriate."%str(self)
3501          else:
3502             arg=self.getSubstitutedArguments(argvals)
3503             return inverse(arg[0])
3504    
3505       def diff(self,arg):
3506          """
3507          differential of this object
3508    
3509          @param arg: the derivative is calculated with respect to arg
3510          @type arg: L{escript.Symbol}
3511          @return: derivative with respect to C{arg}
3512          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3513          """
3514          if arg==self:
3515             return identity(self.getShape())
3516          else:
3517             return -matrixmult(matrixmult(self,self.getDifferentiatedArguments(arg)[0]),self)
3518    
3519    def eigenvalues(arg):
3520        """
3521        returns the eigenvalues of the square matrix arg.
3522    
3523        @param arg: square matrix. Must have rank 2 and the first and second dimension must be equal.
3524                    arg must be symmetric, ie. transpose(arg)==arg (this is not checked).
3525        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}
3526        @return: the eigenvalues in increasing order.
3527        @rtype: L{numarray.NumArray},L{escript.Data}, L{Symbol} depending on the input.
3528        @remark: for L{escript.Data} and L{Symbol} objects the dimension is restricted to 3.
3529        """
3530        if isinstance(arg,numarray.NumArray):
3531          out=numarray.linear_algebra.eigenvalues((arg+numarray.transpose(arg))/2.)
3532          out.sort()
3533          return out
3534        elif isinstance(arg,escript.Data):
3535          return arg._eigenvalues()
3536        elif isinstance(arg,Symbol):
3537          if not arg.getRank()==2:
3538            raise ValueError,"eigenvalues: argument must have rank 2"
3539          s=arg.getShape()      
3540          if not s[0] == s[1]:
3541            raise ValueError,"eigenvalues: argument must be a square matrix."
3542          if s[0]==1:
3543              return arg[0]
3544          elif s[0]==2:
3545              arg1=symmetric(arg)
3546              A11=arg1[0,0]
3547              A12=arg1[0,1]
3548              A22=arg1[1,1]
3549              trA=(A11+A22)/2.
3550              A11-=trA
3551              A22-=trA
3552              s=sqrt(A12**2-A11*A22)
3553              return trA+s*numarray.array([-1.,1.],type=numarray.Float64)
3554          elif s[0]==3:
3555              arg1=symmetric(arg)
3556              A11=arg1[0,0]
3557              A12=arg1[0,1]
3558              A22=arg1[1,1]
3559              A13=arg1[0,2]
3560              A23=arg1[1,2]
3561              A33=arg1[2,2]
3562              trA=(A11+A22+A33)/3.
3563              A11-=trA
3564              A22-=trA
3565              A33-=trA
3566              A13_2=A13**2
3567              A23_2=A23**2
3568              A12_2=A12**2
3569              p=A13_2+A23_2+A12_2+(A11**2+A22**2+A33**2)/2.
3570              q=A13_2*A22+A23_2*A11+A12_2*A33-A11*A22*A33-2*A12*A23*A13
3571              sq_p=sqrt(p/3.)
3572              alpha_3=acos(clip(-q*(sq_p+whereZero(p,0.)*1.e-15)**(-3.)/2.,-1.,1.))/3.  # whereZero is protection against divison by zero
3573              sq_p*=2.
3574              f=cos(alpha_3)               *numarray.array([0.,0.,1.],type=numarray.Float64) \
3575               -cos(alpha_3+numarray.pi/3.)*numarray.array([0.,1.,0.],type=numarray.Float64) \
3576               -cos(alpha_3-numarray.pi/3.)*numarray.array([1.,0.,0.],type=numarray.Float64)
3577              return trA+sq_p*f
3578          else:
3579             raise TypeError,"eigenvalues: only matrix dimensions 1,2,3 are supported right now."
3580        elif isinstance(arg,float):
3581          return arg
3582        elif isinstance(arg,int):
3583          return float(arg)
3584        else:
3585          raise TypeError,"eigenvalues: Unknown argument type."
3586    
3587    def eigenvalues_and_eigenvectors(arg):
3588        """
3589        returns the eigenvalues and eigenvectors of the square matrix arg.
3590    
3591        @param arg: square matrix. Must have rank 2 and the first and second dimension must be equal.
3592                    arg must be symmetric, ie. transpose(arg)==arg (this is not checked).
3593        @type arg: L{escript.Data}
3594        @return: the eigenvalues and eigenvectors. The eigenvalues are ordered by increasing value. The
3595                 eigenvectors are orthogonal and normalized. If V are the eigenvectors than V[:,i] is
3596                 the eigenvector coresponding to the i-th eigenvalue.
3597        @rtype: L{tuple} of L{escript.Data}.
3598        @remark: The dimension is restricted to 3.
3599        """
3600        if isinstance(arg,numarray.NumArray):
3601          raise TypeError,"eigenvalues_and_eigenvectors is not supporting numarray arguments"
3602        elif isinstance(arg,escript.Data):
3603          return arg._eigenvalues_and_eigenvectors()
3604        elif isinstance(arg,Symbol):
3605          raise TypeError,"eigenvalues_and_eigenvectors is not supporting Symbol arguments"
3606        elif isinstance(arg,float):
3607          return (numarray.array([[arg]],numarray.Float),numarray.ones((1,1),numarray.Float))
3608        elif isinstance(arg,int):
3609          return (numarray.array([[arg]],numarray.Float),numarray.ones((1,1),numarray.Float))
3610        else:
3611          raise TypeError,"eigenvalues: Unknown argument type."
3612  #=======================================================  #=======================================================
3613  #  Binary operations:  #  Binary operations:
3614  #=======================================================  #=======================================================
# Line 2995  def mult(arg0,arg1): Line 3727  def mult(arg0,arg1):
3727         """         """
3728         args=matchShape(arg0,arg1)         args=matchShape(arg0,arg1)
3729         if testForZero(args[0]) or testForZero(args[1]):         if testForZero(args[0]) or testForZero(args[1]):
3730            return numarray.zeros(pokeShape(args[0]),numarray.Float)            return numarray.zeros(pokeShape(args[0]),numarray.Float64)
3731         else:         else:
3732            if isinstance(args[0],Symbol) or isinstance(args[1],Symbol) :            if isinstance(args[0],Symbol) or isinstance(args[1],Symbol) :
3733                return Mult_Symbol(args[0],args[1])                return Mult_Symbol(args[0],args[1])
# Line 3095  def quotient(arg0,arg1): Line 3827  def quotient(arg0,arg1):
3827         """         """
3828         args=matchShape(arg0,arg1)         args=matchShape(arg0,arg1)
3829         if testForZero(args[0]):         if testForZero(args[0]):
3830            return numarray.zeros(pokeShape(args[0]),numarray.Float)            return numarray.zeros(pokeShape(args[0]),numarray.Float64)
3831         elif isinstance(args[0],Symbol):         elif isinstance(args[0],Symbol):
3832            if isinstance(args[1],Symbol):            if isinstance(args[1],Symbol):
3833               return Quotient_Symbol(args[0],args[1])               return Quotient_Symbol(args[0],args[1])
# Line 3201  def power(arg0,arg1): Line 3933  def power(arg0,arg1):
3933         """         """
3934         args=matchShape(arg0,arg1)         args=matchShape(arg0,arg1)
3935         if testForZero(args[0]):         if testForZero(args[0]):
3936            return numarray.zeros(args[0],numarray.Float)            return numarray.zeros(pokeShape(args[0]),numarray.Float64)
3937         elif testForZero(args[1]):         elif testForZero(args[1]):
3938            return numarray.ones(args[0],numarray.Float)            return numarray.ones(pokeShape(args[1]),numarray.Float64)
3939         elif isinstance(args[0],Symbol) or isinstance(args[1],Symbol):         elif isinstance(args[0],Symbol) or isinstance(args[1],Symbol):
3940            return Power_Symbol(args[0],args[1])            return Power_Symbol(args[0],args[1])
3941         elif isinstance(args[0],numarray.NumArray) and not isinstance(args[1],numarray.NumArray):         elif isinstance(args[0],numarray.NumArray) and not isinstance(args[1],numarray.NumArray):
# Line 3304  def maximum(*args): Line 4036  def maximum(*args):
4036         if out==None:         if out==None:
4037            out=a            out=a
4038         else:         else:
4039            m=whereNegative(out-a)            diff=add(a,-out)
4040            out=m*a+(1.-m)*out            out=add(out,mult(wherePositive(diff),diff))
4041      return out      return out
4042        
4043  def minimum(*arg):  def minimum(*args):
4044      """      """
4045      the minimum over arguments args      the minimum over arguments args
4046    
# Line 3322  def minimum(*arg): Line 4054  def minimum(*arg):
4054         if out==None:         if out==None:
4055            out=a            out=a
4056         else:         else:
4057            m=whereNegative(out-a)            diff=add(a,-out)
4058            out=m*out+(1.-m)*a            out=add(out,mult(whereNegative(diff),diff))
4059      return out      return out
4060    
4061    def clip(arg,minval=0.,maxval=1.):
4062        """
4063        cuts the values of arg between minval and maxval
4064    
4065        @param arg: argument
4066        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}, C{int} or C{float}
4067        @param minval: lower range
4068        @type arg: C{float}
4069        @param maxval: upper range
4070        @type arg: C{float}
4071        @return: is on object with all its value between minval and maxval. value of the argument that greater then minval and
4072                 less then maxval are unchanged.
4073        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol}, C{int} or C{float} depending on the input
4074        @raise ValueError: if minval>maxval
4075        """
4076        if minval>maxval:
4077           raise ValueError,"minval = %s must be less then maxval %s"%(minval,maxval)
4078        return minimum(maximum(minval,arg),maxval)
4079    
4080        
4081  def inner(arg0,arg1):  def inner(arg0,arg1):
4082      """      """
# Line 3348  def inner(arg0,arg1): Line 4100  def inner(arg0,arg1):
4100      sh1=pokeShape(arg1)      sh1=pokeShape(arg1)
4101      if not sh0==sh1:      if not sh0==sh1:
4102          raise ValueError,"inner: shape of arguments does not match"          raise ValueError,"inner: shape of arguments does not match"
4103      return generalTensorProduct(arg0,arg1,offset=len(sh0))      return generalTensorProduct(arg0,arg1,axis_offset=len(sh0))
4104    
4105  def matrixmult(arg0,arg1):  def matrixmult(arg0,arg1):
4106      """      """
# Line 3376  def matrixmult(arg0,arg1): Line 4128  def matrixmult(arg0,arg1):
4128          raise ValueError,"first argument must have rank 2"          raise ValueError,"first argument must have rank 2"
4129      if not len(sh1)==2 and not len(sh1)==1:      if not len(sh1)==2 and not len(sh1)==1:
4130          raise ValueError,"second argument must have rank 1 or 2"          raise ValueError,"second argument must have rank 1 or 2"
4131      return generalTensorProduct(arg0,arg1,offset=1)      return generalTensorProduct(arg0,arg1,axis_offset=1)
4132    
4133  def outer(arg0,arg1):  def outer(arg0,arg1):
4134      """      """
# Line 3394  def outer(arg0,arg1): Line 4146  def outer(arg0,arg1):
4146      @return: the outer product of arg0 and arg1 at each data point      @return: the outer product of arg0 and arg1 at each data point
4147      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
4148      """      """
4149      return generalTensorProduct(arg0,arg1,offset=0)      return generalTensorProduct(arg0,arg1,axis_offset=0)
4150    
4151    
4152  def tensormult(arg0,arg1):  def tensormult(arg0,arg1):
# Line 3436  def tensormult(arg0,arg1): Line 4188  def tensormult(arg0,arg1):
4188      sh0=pokeShape(arg0)      sh0=pokeShape(arg0)
4189      sh1=pokeShape(arg1)      sh1=pokeShape(arg1)
4190      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):
4191         return generalTensorProduct(arg0,arg1,offset=1)         return generalTensorProduct(arg0,arg1,axis_offset=1)
4192      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):
4193         return generalTensorProduct(arg0,arg1,offset=2)         return generalTensorProduct(arg0,arg1,axis_offset=2)
4194      else:      else:
4195          raise ValueError,"tensormult: first argument must have rank 2 or 4"          raise ValueError,"tensormult: first argument must have rank 2 or 4"
4196    
4197  def generalTensorProduct(arg0,arg1,offset=0):  def generalTensorProduct(arg0,arg1,axis_offset=0):
4198      """      """
4199      generalized tensor product      generalized tensor product
4200    
4201      out[s,t]=S{Sigma}_r arg0[s,r]*arg1[r,t]      out[s,t]=S{Sigma}_r arg0[s,r]*arg1[r,t]
4202    
4203      where s runs through arg0.Shape[:arg0.Rank-offset]      where s runs through arg0.Shape[:arg0.Rank-axis_offset]
4204            r runs trough arg0.Shape[:offset]            r runs trough arg0.Shape[:axis_offset]
4205            t runs through arg1.Shape[offset:]            t runs through arg1.Shape[axis_offset:]
4206    
4207      In the first case the the second dimension of arg0 and the length of arg1 must match and        In the first case the the second dimension of arg0 and the length of arg1 must match and  
4208      in the second case the two last dimensions of arg0 must match the shape of arg1.      in the second case the two last dimensions of arg0 must match the shape of arg1.
# Line 3467  def generalTensorProduct(arg0,arg1,offse Line 4219  def generalTensorProduct(arg0,arg1,offse
4219      # at this stage arg0 and arg0 are both numarray.NumArray or escript.Data or Symbols      # at this stage arg0 and arg0 are both numarray.NumArray or escript.Data or Symbols
4220      if isinstance(arg0,numarray.NumArray):      if isinstance(arg0,numarray.NumArray):
4221         if isinstance(arg1,Symbol):         if isinstance(arg1,Symbol):
4222             return GeneralTensorProduct_Symbol(arg0,arg1,offset)             return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
4223         else:         else:
4224             if not arg0.shape[arg0.rank-offset:]==arg1.shape[:offset]:             if not arg0.shape[arg0.rank-axis_offset:]==arg1.shape[:axis_offset]:
4225                 raise ValueError,"generalTensorProduct: dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)                 raise ValueError,"generalTensorProduct: dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
4226             arg0_c=arg0.copy()             arg0_c=arg0.copy()
4227             arg1_c=arg1.copy()             arg1_c=arg1.copy()
4228             sh0,sh1=arg0.shape,arg1.shape             sh0,sh1=arg0.shape,arg1.shape
4229             d0,d1,d01=1,1,1             d0,d1,d01=1,1,1
4230             for i in sh0[:arg0.rank-offset]: d0*=i             for i in sh0[:arg0.rank-axis_offset]: d0*=i
4231             for i in sh1[offset:]: d1*=i             for i in sh1[axis_offset:]: d1*=i
4232             for i in sh1[:offset]: d01*=i             for i in sh1[:axis_offset]: d01*=i
4233             arg0_c.resize((d0,d01))             arg0_c.resize((d0,d01))
4234             arg1_c.resize((d01,d1))             arg1_c.resize((d01,d1))
4235             out=numarray.zeros((d0,d1),numarray.Float)             out=numarray.zeros((d0,d1),numarray.Float64)
4236             for i0 in range(d0):             for i0 in range(d0):
4237                      for i1 in range(d1):                      for i1 in range(d1):
4238                           out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1])                           out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1])
4239             out.resize(sh0[:arg0.rank-offset]+sh1[offset:])             out.resize(sh0[:arg0.rank-axis_offset]+sh1[axis_offset:])
4240             return out             return out
4241      elif isinstance(arg0,escript.Data):      elif isinstance(arg0,escript.Data):
4242         if isinstance(arg1,Symbol):         if isinstance(arg1,Symbol):
4243             return GeneralTensorProduct_Symbol(arg0,arg1,offset)             return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
4244         else:         else:
4245             return escript_generalTensorProduct(arg0,arg1,offset) # this calls has to be replaced by escript._generalTensorProduct(arg0,arg1,offset)             return escript_generalTensorProduct(arg0,arg1,axis_offset) # this calls has to be replaced by escript._generalTensorProduct(arg0,arg1,axis_offset)
4246      else:            else:      
4247         return GeneralTensorProduct_Symbol(arg0,arg1,offset)         return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
4248                                    
4249  class GeneralTensorProduct_Symbol(DependendSymbol):  class GeneralTensorProduct_Symbol(DependendSymbol):
4250     """     """
4251     Symbol representing the quotient of two arguments.     Symbol representing the quotient of two arguments.
4252     """     """
4253     def __init__(self,arg0,arg1,offset=0):     def __init__(self,arg0,arg1,axis_offset=0):
4254         """         """
4255         initialization of L{Symbol} representing the quotient of two arguments         initialization of L{Symbol} representing the quotient of two arguments
4256    
# Line 3511  class GeneralTensorProduct_Symbol(Depend Line 4263  class GeneralTensorProduct_Symbol(Depend
4263         """         """
4264         sh_arg0=pokeShape(arg0)         sh_arg0=pokeShape(arg0)
4265         sh_arg1=pokeShape(arg1)         sh_arg1=pokeShape(arg1)
4266         sh0=sh_arg0[:len(sh_arg0)-offset]         sh0=sh_arg0[:len(sh_arg0)-axis_offset]
4267         sh01=sh_arg0[len(sh_arg0)-offset:]         sh01=sh_arg0[len(sh_arg0)-axis_offset:]
4268         sh10=sh_arg1[:offset]         sh10=sh_arg1[:axis_offset]
4269         sh1=sh_arg1[offset:]         sh1=sh_arg1[axis_offset:]
4270         if not sh01==sh10:         if not sh01==sh10:
4271             raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)             raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
4272         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0+sh1,args=[arg0,arg1,offset])         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0+sh1,args=[arg0,arg1,axis_offset])
4273    
4274     def getMyCode(self,argstrs,format="escript"):     def getMyCode(self,argstrs,format="escript"):
4275        """        """
# Line 3532  class GeneralTensorProduct_Symbol(Depend Line 4284  class GeneralTensorProduct_Symbol(Depend
4284        @raise: NotImplementedError: if the requested format is not available        @raise: NotImplementedError: if the requested format is not available
4285        """        """
4286        if format=="escript" or format=="str" or format=="text":        if format=="escript" or format=="str" or format=="text":
4287           return "generalTensorProduct(%s,%s,offset=%s)"%(argstrs[0],argstrs[1],argstrs[2])           return "generalTensorProduct(%s,%s,axis_offset=%s)"%(argstrs[0],argstrs[1],argstrs[2])
4288        else:        else:
4289           raise NotImplementedError,"%s does not provide program code for format %s."%(str(self),format)           raise NotImplementedError,"%s does not provide program code for format %s."%(str(self),format)
4290    
# Line 3557  class GeneralTensorProduct_Symbol(Depend Line 4309  class GeneralTensorProduct_Symbol(Depend
4309           args=self.getSubstitutedArguments(argvals)           args=self.getSubstitutedArguments(argvals)
4310           return generalTensorProduct(args[0],args[1],args[2])           return generalTensorProduct(args[0],args[1],args[2])
4311    
4312  def escript_generalTensorProduct(arg0,arg1,offset): # this should be escript._generalTensorProduct  def escript_generalTensorProduct(arg0,arg1,axis_offset): # this should be escript._generalTensorProduct
4313      "arg0 and arg1 are both Data objects but not neccesrily on the same function space. they could be identical!!!"      "arg0 and arg1 are both Data objects but not neccesrily on the same function space. they could be identical!!!"
4314      # calculate the return shape:      # calculate the return shape:
4315      shape0=arg0.getShape()[:arg0.getRank()-offset]      shape0=arg0.getShape()[:arg0.getRank()-axis_offset]
4316      shape01=arg0.getShape()[arg0.getRank()-offset:]      shape01=arg0.getShape()[arg0.getRank()-axis_offset:]
4317      shape10=arg1.getShape()[:offset]      shape10=arg1.getShape()[:axis_offset]
4318      shape1=arg1.getShape()[offset:]      shape1=arg1.getShape()[axis_offset:]
4319      if not shape01==shape10:      if not shape01==shape10:
4320          raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)          raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
4321    
4322        # whatr function space should be used? (this here is not good!)
4323        fs=(escript.Scalar(0.,arg0.getFunctionSpace())+escript.Scalar(0.,arg1.getFunctionSpace())).getFunctionSpace()
4324      # create return value:      # create return value:
4325      out=escript.Data(0.,tuple(shape0+shape1),arg0.getFunctionSpace())      out=escript.Data(0.,tuple(shape0+shape1),fs)
4326      #      #
4327      s0=[[]]      s0=[[]]
4328      for k in shape0:      for k in shape0:
# Line 3591  def escript_generalTensorProduct(arg0,ar Line 4345  def escript_generalTensorProduct(arg0,ar
4345    
4346      for i0 in s0:      for i0 in s0:
4347         for i1 in s1:         for i1 in s1:
4348           s=escript.Scalar(0.,arg0.getFunctionSpace())           s=escript.Scalar(0.,fs)
4349           for i01 in s01:           for i01 in s01:
4350              s+=arg0.__getitem__(tuple(i0+i01))*arg1.__getitem__(tuple(i01+i1))              s+=arg0.__getitem__(tuple(i0+i01))*arg1.__getitem__(tuple(i01+i1))
4351           out.__setitem__(tuple(i0+i1),s)           out.__setitem__(tuple(i0+i1),s)
4352      return out      return out
4353    
4354    
4355  #=========================================================  #=========================================================
4356  #   some little helpers  #  functions dealing with spatial dependency
4357  #=========================================================  #=========================================================
4358  def grad(arg,where=None):  def grad(arg,where=None):
4359      """      """
4360      Returns the spatial gradient of arg at where.      Returns the spatial gradient of arg at where.
4361    
4362        If C{g} is the returned object, then
4363    
4364          - if C{arg} is rank 0 C{g[s]} is the derivative of C{arg} with respect to the C{s}-th spatial dimension.
4365          - if C{arg} is rank 1 C{g[i,s]} is the derivative of C{arg[i]} with respect to the C{s}-th spatial dimension.
4366          - if C{arg} is rank 2 C{g[i,j,s]} is the derivative of C{arg[i,j]} with respect to the C{s}-th spatial dimension.
4367          - if C{arg} is rank 3 C{g[i,j,k,s]} is the derivative of C{arg[i,j,k]} with respect to the C{s}-th spatial dimension.
4368    
4369      @param arg:   Data object representing the function which gradient      @param arg: function which gradient to be calculated. Its rank has to be less than 3.
4370                    to be calculated.      @type arg: L{escript.Data} or L{Symbol}
4371      @param where: FunctionSpace in which the gradient will be calculated.      @param where: FunctionSpace in which the gradient will be calculated.
4372                    If not present or C{None} an appropriate default is used.                    If not present or C{None} an appropriate default is used.
4373        @type where: C{None} or L{escript.FunctionSpace}
4374        @return: gradient of arg.
4375        @rtype:  L{escript.Data} or L{Symbol}
4376      """      """
4377      if isinstance(arg,Symbol):      if isinstance(arg,Symbol):
4378         return Grad_Symbol(arg,where)         return Grad_Symbol(arg,where)
# Line 3617  def grad(arg,where=None): Line 4382  def grad(arg,where=None):
4382         else:         else:
4383            return arg._grad(where)            return arg._grad(where)
4384      else:      else:
4385        raise TypeError,"grad: Unknown argument type."         raise TypeError,"grad: Unknown argument type."
4386    
4387    class Grad_Symbol(DependendSymbol):
4388       """
4389       L{Symbol} representing the result of the gradient operator
4390       """
4391       def __init__(self,arg,where=None):
4392          """
4393          initialization of gradient L{Symbol} with argument arg
4394          @param arg: argument of function
4395          @type arg: L{Symbol}.
4396          @param where: FunctionSpace in which the gradient will be calculated.
4397                      If not present or C{None} an appropriate default is used.
4398          @type where: C{None} or L{escript.FunctionSpace}
4399          """
4400          d=arg.getDim()
4401          if d==None:
4402             raise ValueError,"argument must have a spatial dimension"
4403          super(Grad_Symbol,self).__init__(args=[arg,where],shape=arg.getShape()+(d,),dim=d)
4404    
4405       def getMyCode(self,argstrs,format="escript"):
4406          """
4407          returns a program code that can be used to evaluate the symbol.
4408    
4409          @param argstrs: gives for each argument a string representing the argument for the evaluation.
4410          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
4411          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
4412          @type format: C{str}
4413          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
4414          @rtype: C{str}
4415          @raise: NotImplementedError: if the requested format is not available
4416          """
4417          if format=="escript" or format=="str"  or format=="text":
4418             return "grad(%s,where=%s)"%(argstrs[0],argstrs[1])
4419          else:
4420             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
4421    
4422       def substitute(self,argvals):
4423          """
4424          assigns new values to symbols in the definition of the symbol.
4425          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
4426    
4427          @param argvals: new values assigned to symbols
4428          @type argvals: C{dict} with keywords of type L{Symbol}.
4429          @return: result of the substitution process. Operations are executed as much as possible.
4430          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
4431          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
4432          """
4433          if argvals.has_key(self):
4434             arg=argvals[self]
4435             if self.isAppropriateValue(arg):
4436                return arg
4437             else:
4438                raise TypeError,"%s: new value is not appropriate."%str(self)
4439          else:
4440             arg=self.getSubstitutedArguments(argvals)
4441             return grad(arg[0],where=arg[1])
4442    
4443       def diff(self,arg):
4444          """
4445          differential of this object
4446    
4447          @param arg: the derivative is calculated with respect to arg
4448          @type arg: L{escript.Symbol}
4449          @return: derivative with respect to C{arg}
4450          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
4451          """
4452          if arg==self:
4453             return identity(self.getShape())
4454          else:
4455             return grad(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
4456    
4457  def integrate(arg,where=None):  def integrate(arg,where=None):
4458      """      """
4459      Return the integral if the function represented by Data object arg over      Return the integral of the function C{arg} over its domain. If C{where} is present C{arg} is interpolated to C{where}
4460      its domain.      before integration.
4461    
4462      @param arg:   Data object representing the function which is integrated.      @param arg:   the function which is integrated.
4463        @type arg: L{escript.Data} or L{Symbol}
4464      @param where: FunctionSpace in which the integral is calculated.      @param where: FunctionSpace in which the integral is calculated.
4465                    If not present or C{None} an appropriate default is used.                    If not present or C{None} an appropriate default is used.
4466        @type where: C{None} or L{escript.FunctionSpace}
4467        @return: integral of arg.
4468        @rtype:  C{float}, C{numarray.NumArray} or L{Symbol}
4469      """      """
4470      if isinstance(arg,numarray.NumArray):      if isinstance(arg,Symbol):
         if checkForZero(arg):  
            return arg  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,float):  
         if checkForZero(arg):  
            return arg  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,int):  
         if checkForZero(arg):  
            return float(arg)  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,Symbol):  
4471         return Integrate_Symbol(arg,where)         return Integrate_Symbol(arg,where)
4472      elif isinstance(arg,escript.Data):      elif isinstance(arg,escript.Data):
4473         if not where==None: arg=escript.Data(arg,where)         if not where==None: arg=escript.Data(arg,where)
# Line 3654  def integrate(arg,where=None): Line 4478  def integrate(arg,where=None):
4478      else:      else:
4479        raise TypeError,"integrate: Unknown argument type."        raise TypeError,"integrate: Unknown argument type."
4480    
4481    class Integrate_Symbol(DependendSymbol):
4482       """
4483       L{Symbol} representing the result of the spatial integration operator
4484       """
4485       def __init__(self,arg,where=None):
4486          """
4487          initialization of integration L{Symbol} with argument arg
4488          @param arg: argument of the integration
4489          @type arg: L{Symbol}.
4490          @param where: FunctionSpace in which the integration will be calculated.
4491                      If not present or C{None} an appropriate default is used.
4492          @type where: C{None} or L{escript.FunctionSpace}
4493          """
4494          super(Integrate_Symbol,self).__init__(args=[arg,where],shape=arg.getShape(),dim=arg.getDim())
4495    
4496       def getMyCode(self,argstrs,format="escript"):
4497          """
4498          returns a program code that can be used to evaluate the symbol.
4499    
4500          @param argstrs: gives for each argument a string representing the argument for the evaluation.
4501          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
4502          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
4503          @type format: C{str}
4504          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
4505          @rtype: C{str}
4506          @raise: NotImplementedError: if the requested format is not available
4507          """
4508          if format=="escript" or format=="str"  or format=="text":
4509             return "integrate(%s,where=%s)"%(argstrs[0],argstrs[1])
4510          else:
4511             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
4512    
4513       def substitute(self,argvals):
4514          """
4515          assigns new values to symbols in the definition of the symbol.
4516          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
4517    
4518          @param argvals: new values assigned to symbols
4519          @type argvals: C{dict} with keywords of type L{Symbol}.
4520          @return: result of the substitution process. Operations are executed as much as possible.
4521          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
4522          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
4523          """
4524          if argvals.has_key(self):
4525             arg=argvals[self]
4526             if self.isAppropriateValue(arg):
4527                return arg
4528             else:
4529                raise TypeError,"%s: new value is not appropriate."%str(self)
4530          else:
4531             arg=self.getSubstitutedArguments(argvals)
4532             return integrate(arg[0],where=arg[1])
4533    
4534       def diff(self,arg):
4535          """
4536          differential of this object
4537    
4538          @param arg: the derivative is calculated with respect to arg
4539          @type arg: L{escript.Symbol}
4540          @return: derivative with respect to C{arg}
4541          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
4542          """
4543          if arg==self:
4544             return identity(self.getShape())
4545          else:
4546             return integrate(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
4547    
4548    
4549  def interpolate(arg,where):  def interpolate(arg,where):
4550      """      """
4551      Interpolates the function into the FunctionSpace where.      interpolates the function into the FunctionSpace where.
4552    
4553      @param arg:    interpolant      @param arg: interpolant
4554      @param where:  FunctionSpace to interpolate to      @type arg: L{escript.Data} or L{Symbol}
4555        @param where: FunctionSpace to be interpolated to
4556        @type where: L{escript.FunctionSpace}
4557        @return: interpolated argument
4558        @rtype:  C{escript.Data} or L{Symbol}
4559      """      """
4560      if testForZero(arg):      if isinstance(arg,Symbol):
4561        return 0         return Interpolate_Symbol(arg,where)
     elif isinstance(arg,Symbol):  
        return Interpolated_Symbol(arg,where)  
4562      else:      else:
4563         return escript.Data(arg,where)         return escript.Data(arg,where)
4564    
4565  def div(arg,where=None):  class Interpolate_Symbol(DependendSymbol):
4566      """     """
4567      Returns the divergence of arg at where.     L{Symbol} representing the result of the interpolation operator
4568       """
4569       def __init__(self,arg,where):
4570          """
4571          initialization of interpolation L{Symbol} with argument arg
4572          @param arg: argument of the interpolation
4573          @type arg: L{Symbol}.
4574          @param where: FunctionSpace into which the argument is interpolated.
4575          @type where: L{escript.FunctionSpace}
4576          """
4577          super(Interpolate_Symbol,self).__init__(args=[arg,where],shape=arg.getShape(),dim=arg.getDim())
4578    
4579      @param arg:   Data object representing the function which gradient to     def getMyCode(self,argstrs,format="escript"):
4580                    be calculated.        """
4581      @param where: FunctionSpace in which the gradient will be calculated.        returns a program code that can be used to evaluate the symbol.
                   If not present or C{None} an appropriate default is used.  
     """  
     g=grad(arg,where)  
     return trace(g,axis0=g.getRank()-2,axis1=g.getRank()-1)  
4582    
4583  def jump(arg):        @param argstrs: gives for each argument a string representing the argument for the evaluation.
4584      """        @type argstrs: C{str} or a C{list} of length 1 of C{str}.
4585      Returns the jump of arg across a continuity.        @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
4586          @type format: C{str}
4587          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
4588          @rtype: C{str}
4589          @raise: NotImplementedError: if the requested format is not available
4590          """
4591          if format=="escript" or format=="str"  or format=="text":
4592             return "interpolate(%s,where=%s)"%(argstrs[0],argstrs[1])
4593          else:
4594             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
4595    
4596      @param arg:   Data object representing the function which gradient     def substitute(self,argvals):
4597                    to be calculated.        """
4598      """        assigns new values to symbols in the definition of the symbol.
4599      d=arg.getDomain()        The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
     return arg.interpolate(escript.FunctionOnContactOne())-arg.interpolate(escript.FunctionOnContactZero())  
4600    
4601  #=============================        @param argvals: new values assigned to symbols
4602  #        @type argvals: C{dict} with keywords of type L{Symbol}.
4603  # wrapper for various functions: if the argument has attribute the function name        @return: result of the substitution process. Operations are executed as much as possible.
4604  # as an argument it calls the corresponding methods. Otherwise the corresponding        @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
4605  # numarray function is called.        @raise TypeError: if a value for a L{Symbol} cannot be substituted.
4606          """
4607          if argvals.has_key(self):
4608             arg=argvals[self]
4609             if self.isAppropriateValue(arg):
4610                return arg
4611             else:
4612                raise TypeError,"%s: new value is not appropriate."%str(self)
4613          else:
4614             arg=self.getSubstitutedArguments(argvals)
4615             return interpolate(arg[0],where=arg[1])
4616    
4617  # functions involving the underlying Domain:     def diff(self,arg):
4618          """
4619          differential of this object
4620    
4621          @param arg: the derivative is calculated with respect to arg
4622          @type arg: L{escript.Symbol}
4623          @return: derivative with respect to C{arg}
4624          @rtype: L{Symbol} but other types such as L{escript.Data}, L{numarray.NumArray}  are possible.
4625          """
4626          if arg==self:
4627             return identity(self.getShape())
4628          else:
4629             return interpolate(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
4630    
4631  def transpose(arg,axis=None):  
4632    def div(arg,where=None):
4633      """      """
4634      Returns the transpose of the Data object arg.      returns the divergence of arg at where.
4635    
4636      @param arg:      @param arg: function which divergence to be calculated. Its shape has to be (d,) where d is the spatial dimension.
4637        @type arg: L{escript.Data} or L{Symbol}
4638        @param where: FunctionSpace in which the divergence will be calculated.
4639                      If not present or C{None} an appropriate default is used.
4640        @type where: C{None} or L{escript.FunctionSpace}
4641        @return: divergence of arg.
4642        @rtype:  L{escript.Data} or L{Symbol}
4643      """      """
     if axis==None:  
        r=0  
        if hasattr(arg,"getRank"): r=arg.getRank()  
        if hasattr(arg,"rank"): r=arg.rank  
        axis=r/2  
4644      if isinstance(arg,Symbol):      if isinstance(arg,Symbol):
4645         return Transpose_Symbol(arg,axis=r)          dim=arg.getDim()
4646      if isinstance(arg,escript.Data):      elif isinstance(arg,escript.Data):
4647         # hack for transpose          dim=arg.getDomain().getDim()
        r=arg.getRank()  
        if r!=2: raise ValueError,"Tranpose only avalaible for rank 2 objects"  
        s=arg.getShape()  
        out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())  
        for i in range(s[0]):  
           for j in range(s[1]):  
              out[j,i]=arg[i,j]  
        return out  
        # end hack for transpose  
        return arg.transpose(axis)  
4648      else:      else:
4649         return numarray.transpose(arg,axis=axis)          raise TypeError,"div: argument type not supported"
4650        if not arg.getShape()==(dim,):
4651          raise ValueError,"div: expected shape is (%s,)"%dim
4652        return trace(grad(arg,where))
4653    
4654  def trace(arg,axis0=0,axis1=1):  def jump(arg,domain=None):
4655      """      """
4656      Return      returns the jump of arg across the continuity of the domain
4657    
4658      @param arg:      @param arg: argument
4659        @type arg: L{escript.Data} or L{Symbol}
4660        @param domain: the domain where the discontinuity is located. If domain is not present or equal to C{None}
4661                       the domain of arg is used. If arg is a L{Symbol} the domain must be present.
4662        @type domain: C{None} or L{escript.Domain}
4663        @return: jump of arg
4664        @rtype:  L{escript.Data} or L{Symbol}
4665      """      """
4666      if isinstance(arg,Symbol):      if domain==None: domain=arg.getDomain()
4667         s=list(arg.getShape())              return interpolate(arg,escript.FunctionOnContactOne(domain))-interpolate(arg,escript.FunctionOnContactZero(domain))
        s=tuple(s[0:axis0]+s[axis0+1:axis1]+s[axis1+1:])  
        return Trace_Symbol(arg,axis0=axis0,axis1=axis1)  
     elif isinstance(arg,escript.Data):  
        # hack for trace  
        s=arg.getShape()  
        if s[axis0]!=s[axis1]:  
            raise ValueError,"illegal axis in trace"  
        out=escript.Scalar(0.,arg.getFunctionSpace())  
        for i in range(s[axis0]):  
           out+=arg[i,i]  
        return out  
        # end hack for trace  
     else:  
        return numarray.trace(arg,axis0=axis0,axis1=axis1)  
4668    
4669    def L2(arg):
4670        """
4671        returns the L2 norm of arg at where
4672        
4673        @param arg: function which L2 to be calculated.
4674        @type arg: L{escript.Data} or L{Symbol}
4675        @return: L2 norm of arg.
4676        @rtype:  L{float} or L{Symbol}
4677        @note: L2(arg) is equivalent to sqrt(integrate(inner(arg,arg)))
4678        """
4679        return sqrt(integrate(inner(arg,arg)))
4680    #=============================
4681    #
4682    
4683  def reorderComponents(arg,index):  def reorderComponents(arg,index):
4684      """      """
4685      resorts the component of arg according to index      resorts the component of arg according to index
4686    
4687      """      """
4688      pass      raise NotImplementedError
4689  #  #
4690  # $Log: util.py,v $  # $Log: util.py,v $
4691  # Revision 1.14.2.16  2005/10/19 06:09:57  gross  # Revision 1.14.2.16  2005/10/19 06:09:57  gross

Legend:
Removed from v.341  
changed lines
  Added in v.637

  ViewVC Help
Powered by ViewVC 1.1.26