/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Diff of /trunk/escript/py_src/util.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 341 by gross, Mon Dec 12 05:26:10 2005 UTC revision 698 by gross, Fri Mar 31 04:52:55 2006 UTC
# Line 1  Line 1 
1  # $Id$  # $Id$
 #  
 #      COPYRIGHT ACcESS 2004 -  All Rights Reserved  
 #  
 #   This software is the property of ACcESS.  No part of this code  
 #   may be copied in any form or by any means without the expressed written  
 #   consent of ACcESS.  Copying, use or modification of this software  
 #   by any unauthorised person is illegal unless that  
 #   person has a software license agreement with ACcESS.  
 #  
2    
3  """  """
4  Utility functions for escript  Utility functions for escript
5    
 @remark:  This module is under construction and is still tested!!!  
   
6  @var __author__: name of author  @var __author__: name of author
7  @var __licence__: licence agreement  @var __copyright__: copyrights
8    @var __license__: licence agreement
9  @var __url__: url entry point on documentation  @var __url__: url entry point on documentation
10  @var __version__: version  @var __version__: version
11  @var __date__: date of the version  @var __date__: date of the version
12  """  """
13                                                                                                                                                                                                                                                                                                                                                                                                            
14  __author__="Lutz Gross, l.gross@uq.edu.au"  __author__="Lutz Gross, l.gross@uq.edu.au"
15  __licence__="contact: esys@access.uq.edu.au"  __copyright__="""  Copyright (c) 2006 by ACcESS MNRF
16                        http://www.access.edu.au
17                    Primary Business: Queensland, Australia"""
18    __license__="""Licensed under the Open Software License version 3.0
19                 http://www.opensource.org/licenses/osl-3.0.php"""
20  __url__="http://www.iservo.edu.au/esys/escript"  __url__="http://www.iservo.edu.au/esys/escript"
21  __version__="$Revision: 329 $"  __version__="$Revision$"
22  __date__="$Date$"  __date__="$Date$"
23    
24    
# Line 33  import numarray Line 27  import numarray
27  import escript  import escript
28  import os  import os
29    
 # missing tests:  
   
 # def pokeShape(arg):  
 # def pokeDim(arg):  
 # def commonShape(arg0,arg1):  
 # def commonDim(*args):  
 # def testForZero(arg):  
 # def matchType(arg0=0.,arg1=0.):  
 # def matchShape(arg0,arg1):  
   
 # def maximum(arg0,arg1):  
 # def minimum(arg0,arg1):  
   
 # def transpose(arg,axis=None):  
 # def trace(arg,axis0=0,axis1=1):  
 # def reorderComponents(arg,index):  
   
 # def integrate(arg,where=None):  
 # def interpolate(arg,where):  
 # def div(arg,where=None):  
 # def grad(arg,where=None):  
   
 #  
 # slicing: get  
 #          set  
 #  
 # and derivatives  
   
30  #=========================================================  #=========================================================
31  #   some helpers:  #   some helpers:
32  #=========================================================  #=========================================================
# Line 125  def kronecker(d=3): Line 91  def kronecker(d=3):
91     return the kronecker S{delta}-symbol     return the kronecker S{delta}-symbol
92    
93     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
94     @type d: C{int} or any object with a C{getDim} method     @type d: C{int}, L{escript.Domain} or L{escript.FunctionSpace}
95     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise
96     @rtype d: L{numarray.NumArray} of rank 2.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 2.
    @remark: the function is identical L{identity}  
97     """     """
98     return identityTensor(d)     return identityTensor(d)
99    
# Line 143  def identity(shape=()): Line 108  def identity(shape=()):
108     @raise ValueError: if len(shape)>2.     @raise ValueError: if len(shape)>2.
109     """     """
110     if len(shape)>0:     if len(shape)>0:
111        out=numarray.zeros(shape+shape,numarray.Float)        out=numarray.zeros(shape+shape,numarray.Float64)
112        if len(shape)==1:        if len(shape)==1:
113            for i0 in range(shape[0]):            for i0 in range(shape[0]):
114               out[i0,i0]=1.               out[i0,i0]=1.
   
115        elif len(shape)==2:        elif len(shape)==2:
116            for i0 in range(shape[0]):            for i0 in range(shape[0]):
117               for i1 in range(shape[1]):               for i1 in range(shape[1]):
# Line 163  def identityTensor(d=3): Line 127  def identityTensor(d=3):
127     return the dxd identity matrix     return the dxd identity matrix
128    
129     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
130     @type d: C{int} or any object with a C{getDim} method     @type d: C{int}, L{escript.Domain} or L{escript.FunctionSpace}
131     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise     @return: the object u of rank 2 with M{u[i,j]=1} for M{i=j} and M{u[i,j]=0} otherwise
132     @rtype: L{numarray.NumArray} of rank 2.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 2
133     """     """
134     if hasattr(d,"getDim"):     if isinstance(d,escript.FunctionSpace):
135        d=d.getDim()         return escript.Data(identity((d.getDim(),)),d)
136     return identity(shape=(d,))     elif isinstance(d,escript.Domain):
137           return identity((d.getDim(),))
138       else:
139           return identity((d,))
140    
141  def identityTensor4(d=3):  def identityTensor4(d=3):
142     """     """
# Line 178  def identityTensor4(d=3): Line 145  def identityTensor4(d=3):
145     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
146     @type d: C{int} or any object with a C{getDim} method     @type d: C{int} or any object with a C{getDim} method
147     @return: the object u of rank 4 with M{u[i,j,k,l]=1} for M{i=k and j=l} and M{u[i,j,k,l]=0} otherwise     @return: the object u of rank 4 with M{u[i,j,k,l]=1} for M{i=k and j=l} and M{u[i,j,k,l]=0} otherwise
148     @rtype: L{numarray.NumArray} of rank 4.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 4.
149     """     """
150     if hasattr(d,"getDim"):     if isinstance(d,escript.FunctionSpace):
151        d=d.getDim()         return escript.Data(identity((d.getDim(),d.getDim())),d)
152     return identity((d,d))     elif isinstance(d,escript.Domain):
153           return identity((d.getDim(),d.getDim()))
154       else:
155           return identity((d,d))
156    
157  def unitVector(i=0,d=3):  def unitVector(i=0,d=3):
158     """     """
# Line 191  def unitVector(i=0,d=3): Line 161  def unitVector(i=0,d=3):
161     @param i: index     @param i: index
162     @type i: C{int}     @type i: C{int}
163     @param d: dimension or an object that has the C{getDim} method defining the dimension     @param d: dimension or an object that has the C{getDim} method defining the dimension
164     @type d: C{int} or any object with a C{getDim} method     @type d: C{int}, L{escript.Domain} or L{escript.FunctionSpace}
165     @return: the object u of rank 1 with M{u[j]=1} for M{j=i} and M{u[i]=0} otherwise     @return: the object u of rank 1 with M{u[j]=1} for M{j=i} and M{u[i]=0} otherwise
166     @rtype: L{numarray.NumArray} of rank 1.     @rtype d: L{numarray.NumArray} or L{escript.Data} of rank 1
167     """     """
168     return kronecker(d)[i]     return kronecker(d)[i]
169    
# Line 363  def testForZero(arg): Line 333  def testForZero(arg):
333      @return : True if the argument is identical to zero.      @return : True if the argument is identical to zero.
334      @rtype : C{bool}      @rtype : C{bool}
335      """      """
336      try:      if isinstance(arg,numarray.NumArray):
337           return not Lsup(arg)>0.
338        elif isinstance(arg,escript.Data):
339           return False
340        elif isinstance(arg,float):
341           return not Lsup(arg)>0.
342        elif isinstance(arg,int):
343         return not Lsup(arg)>0.         return not Lsup(arg)>0.
344      except TypeError:      elif isinstance(arg,Symbol):
345           return False
346        else:
347         return False         return False
348    
349  def matchType(arg0=0.,arg1=0.):  def matchType(arg0=0.,arg1=0.):
# Line 386  def matchType(arg0=0.,arg1=0.): Line 364  def matchType(arg0=0.,arg1=0.):
364         elif isinstance(arg1,escript.Data):         elif isinstance(arg1,escript.Data):
365            arg0=escript.Data(arg0,arg1.getFunctionSpace())            arg0=escript.Data(arg0,arg1.getFunctionSpace())
366         elif isinstance(arg1,float):         elif isinstance(arg1,float):
367            arg1=numarray.array(arg1)            arg1=numarray.array(arg1,type=numarray.Float64)
368         elif isinstance(arg1,int):         elif isinstance(arg1,int):
369            arg1=numarray.array(float(arg1))            arg1=numarray.array(float(arg1),type=numarray.Float64)
370         elif isinstance(arg1,Symbol):         elif isinstance(arg1,Symbol):
371            pass            pass
372         else:         else:
# Line 412  def matchType(arg0=0.,arg1=0.): Line 390  def matchType(arg0=0.,arg1=0.):
390         elif isinstance(arg1,escript.Data):         elif isinstance(arg1,escript.Data):
391            pass            pass
392         elif isinstance(arg1,float):         elif isinstance(arg1,float):
393            arg1=numarray.array(arg1)            arg1=numarray.array(arg1,type=numarray.Float64)
394         elif isinstance(arg1,int):         elif isinstance(arg1,int):
395            arg1=numarray.array(float(arg1))            arg1=numarray.array(float(arg1),type=numarray.Float64)
396         elif isinstance(arg1,Symbol):         elif isinstance(arg1,Symbol):
397            pass            pass
398         else:         else:
399            raise TypeError,"function: Unknown type of second argument."                raise TypeError,"function: Unknown type of second argument."    
400      elif isinstance(arg0,float):      elif isinstance(arg0,float):
401         if isinstance(arg1,numarray.NumArray):         if isinstance(arg1,numarray.NumArray):
402            arg0=numarray.array(arg0)            arg0=numarray.array(arg0,type=numarray.Float64)
403         elif isinstance(arg1,escript.Data):         elif isinstance(arg1,escript.Data):
404            arg0=escript.Data(arg0,arg1.getFunctionSpace())            arg0=escript.Data(arg0,arg1.getFunctionSpace())
405         elif isinstance(arg1,float):         elif isinstance(arg1,float):
406            arg0=numarray.array(arg0)            arg0=numarray.array(arg0,type=numarray.Float64)
407            arg1=numarray.array(arg1)            arg1=numarray.array(arg1,type=numarray.Float64)
408         elif isinstance(arg1,int):         elif isinstance(arg1,int):
409            arg0=numarray.array(arg0)            arg0=numarray.array(arg0,type=numarray.Float64)
410            arg1=numarray.array(float(arg1))            arg1=numarray.array(float(arg1),type=numarray.Float64)
411         elif isinstance(arg1,Symbol):         elif isinstance(arg1,Symbol):
412            arg0=numarray.array(arg0)            arg0=numarray.array(arg0,type=numarray.Float64)
413         else:         else:
414            raise TypeError,"function: Unknown type of second argument."                raise TypeError,"function: Unknown type of second argument."    
415      elif isinstance(arg0,int):      elif isinstance(arg0,int):
416         if isinstance(arg1,numarray.NumArray):         if isinstance(arg1,numarray.NumArray):
417            arg0=numarray.array(float(arg0))            arg0=numarray.array(float(arg0),type=numarray.Float64)
418         elif isinstance(arg1,escript.Data):         elif isinstance(arg1,escript.Data):
419            arg0=escript.Data(float(arg0),arg1.getFunctionSpace())            arg0=escript.Data(float(arg0),arg1.getFunctionSpace())
420         elif isinstance(arg1,float):         elif isinstance(arg1,float):
421            arg0=numarray.array(float(arg0))            arg0=numarray.array(float(arg0),type=numarray.Float64)
422            arg1=numarray.array(arg1)            arg1=numarray.array(arg1,type=numarray.Float64)
423         elif isinstance(arg1,int):         elif isinstance(arg1,int):
424            arg0=numarray.array(float(arg0))            arg0=numarray.array(float(arg0),type=numarray.Float64)
425            arg1=numarray.array(float(arg1))            arg1=numarray.array(float(arg1),type=numarray.Float64)
426         elif isinstance(arg1,Symbol):         elif isinstance(arg1,Symbol):
427            arg0=numarray.array(float(arg0))            arg0=numarray.array(float(arg0),type=numarray.Float64)
428         else:         else:
429            raise TypeError,"function: Unknown type of second argument."                raise TypeError,"function: Unknown type of second argument."    
430      else:      else:
# Line 469  def matchShape(arg0,arg1): Line 447  def matchShape(arg0,arg1):
447      sh0=pokeShape(arg0)      sh0=pokeShape(arg0)
448      sh1=pokeShape(arg1)      sh1=pokeShape(arg1)
449      if len(sh0)<len(sh):      if len(sh0)<len(sh):
450         return outer(arg0,numarray.ones(sh[len(sh0):],numarray.Float)),arg1         return outer(arg0,numarray.ones(sh[len(sh0):],numarray.Float64)),arg1
451      elif len(sh1)<len(sh):      elif len(sh1)<len(sh):
452         return arg0,outer(arg1,numarray.ones(sh[len(sh1):],numarray.Float))         return arg0,outer(arg1,numarray.ones(sh[len(sh1):],numarray.Float64))
453      else:      else:
454         return arg0,arg1         return arg0,arg1
455  #=========================================================  #=========================================================
# Line 597  class Symbol(object): Line 575  class Symbol(object):
575            else:            else:
576                s=pokeShape(s)+arg.getShape()                s=pokeShape(s)+arg.getShape()
577                if len(s)>0:                if len(s)>0:
578                   out.append(numarray.zeros(s),numarray.Float)                   out.append(numarray.zeros(s),numarray.Float64)
579                else:                else:
580                   out.append(a)                   out.append(a)
581         return out         return out
# Line 687  class Symbol(object): Line 665  class Symbol(object):
665         else:         else:
666            s=self.getShape()+arg.getShape()            s=self.getShape()+arg.getShape()
667            if len(s)>0:            if len(s)>0:
668               return numarray.zeros(s,numarray.Float)               return numarray.zeros(s,numarray.Float64)
669            else:            else:
670               return 0.               return 0.
671    
# Line 825  class Symbol(object): Line 803  class Symbol(object):
803         """         """
804         return power(other,self)         return power(other,self)
805    
806       def __getitem__(self,index):
807           """
808           returns the slice defined by index
809    
810           @param index: defines a
811           @type index: C{slice} or C{int} or a C{tuple} of them
812           @return: a S{Symbol} representing the slice defined by index
813           @rtype: L{DependendSymbol}
814           """
815           return GetSlice_Symbol(self,index)
816    
817  class DependendSymbol(Symbol):  class DependendSymbol(Symbol):
818     """     """
819     DependendSymbol extents L{Symbol} by modifying the == operator to allow two instances to be equal.     DependendSymbol extents L{Symbol} by modifying the == operator to allow two instances to be equal.
# Line 875  class DependendSymbol(Symbol): Line 864  class DependendSymbol(Symbol):
864  #=========================================================  #=========================================================
865  #  Unary operations prserving the shape  #  Unary operations prserving the shape
866  #========================================================  #========================================================
867    class GetSlice_Symbol(DependendSymbol):
868       """
869       L{Symbol} representing getting a slice for a L{Symbol}
870       """
871       def __init__(self,arg,index):
872          """
873          initialization of wherePositive L{Symbol} with argument arg
874          @param arg: argument
875          @type arg: L{Symbol}.
876          @param index: defines index
877          @type index: C{slice} or C{int} or a C{tuple} of them
878          @raises IndexError: if length of index is larger than rank of arg or a index start or stop is out of range
879          @raises ValueError: if a step is given
880          """
881          if not isinstance(index,tuple): index=(index,)
882          if len(index)>arg.getRank():
883               raise IndexError,"GetSlice_Symbol: index out of range."
884          sh=()
885          index2=()
886          for i in range(len(index)):
887             ix=index[i]
888             if isinstance(ix,int):
889                if ix<0 or ix>=arg.getShape()[i]:
890                   raise ValueError,"GetSlice_Symbol: index out of range."
891                index2=index2+(ix,)
892             else:
893               if not ix.step==None:
894                 raise ValueError,"GetSlice_Symbol: steping is not supported."
895               if ix.start==None:
896                  s=0
897               else:
898                  s=ix.start
899               if ix.stop==None:
900                  e=arg.getShape()[i]
901               else:
902                  e=ix.stop
903                  if e>arg.getShape()[i]:
904                     raise IndexError,"GetSlice_Symbol: index out of range."
905               index2=index2+(slice(s,e),)
906               if e>s:
907                   sh=sh+(e-s,)
908               elif s>e:
909                   raise IndexError,"GetSlice_Symbol: slice start must be less or equal slice end"
910          for i in range(len(index),arg.getRank()):
911              index2=index2+(slice(0,arg.getShape()[i]),)
912              sh=sh+(arg.getShape()[i],)
913          super(GetSlice_Symbol, self).__init__(args=[arg,index2],shape=sh,dim=arg.getDim())
914    
915       def getMyCode(self,argstrs,format="escript"):
916          """
917          returns a program code that can be used to evaluate the symbol.
918    
919          @param argstrs: gives for each argument a string representing the argument for the evaluation.
920          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
921          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
922          @type format: C{str}
923          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
924          @rtype: C{str}
925          @raise: NotImplementedError: if the requested format is not available
926          """
927          if format=="escript" or format=="str"  or format=="text":
928             return "%s.__getitem__(%s)"%(argstrs[0],argstrs[1])
929          else:
930             raise NotImplementedError,"GetItem_Symbol does not provide program code for format %s."%format
931    
932       def substitute(self,argvals):
933          """
934          assigns new values to symbols in the definition of the symbol.
935          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
936    
937          @param argvals: new values assigned to symbols
938          @type argvals: C{dict} with keywords of type L{Symbol}.
939          @return: result of the substitution process. Operations are executed as much as possible.
940          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
941          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
942          """
943          if argvals.has_key(self):
944             arg=argvals[self]
945             if self.isAppropriateValue(arg):
946                return arg
947             else:
948                raise TypeError,"%s: new value is not appropriate."%str(self)
949          else:
950             args=self.getSubstitutedArguments(argvals)
951             arg=args[0]
952             index=args[1]
953             return arg.__getitem__(index)
954    
955  def log10(arg):  def log10(arg):
956     """     """
957     returns base-10 logarithm of argument arg     returns base-10 logarithm of argument arg
# Line 907  def wherePositive(arg): Line 984  def wherePositive(arg):
984     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
985     """     """
986     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
987        if arg.rank==0:        out=numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float64))*1.
988           if arg>0:        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
989             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))  
990     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
991        return arg._wherePositive()        return arg._wherePositive()
992     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 993  def whereNegative(arg): Line 1066  def whereNegative(arg):
1066     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1067     """     """
1068     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1069        if arg.rank==0:        out=numarray.less(arg,numarray.zeros(arg.shape,numarray.Float64))*1.
1070           if arg<0:        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1071             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))  
1072     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1073        return arg._whereNegative()        return arg._whereNegative()
1074     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1079  def whereNonNegative(arg): Line 1148  def whereNonNegative(arg):
1148     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1149     """     """
1150     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1151        if arg.rank==0:        out=numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float64))*1.
1152           if arg<0:        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1153             return numarray.array(0.)        return out
          else:  
            return numarray.array(1.)  
       else:  
          return numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float))  
1154     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1155        return arg._whereNonNegative()        return arg._whereNonNegative()
1156     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1113  def whereNonPositive(arg): Line 1178  def whereNonPositive(arg):
1178     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1179     """     """
1180     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1181        if arg.rank==0:        out=numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float64))*1.
1182           if arg>0:        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1183             return numarray.array(0.)        return out
          else:  
            return numarray.array(1.)  
       else:  
          return numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.  
1184     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1185        return arg._whereNonPositive()        return arg._whereNonPositive()
1186     elif isinstance(arg,float):     elif isinstance(arg,float):
# Line 1149  def whereZero(arg,tol=0.): Line 1210  def whereZero(arg,tol=0.):
1210     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1211     """     """
1212     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1213        if arg.rank==0:        out=numarray.less_equal(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float64))*1.
1214           if abs(arg)<=tol:        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1215             return numarray.array(1.)        return out
          else:  
            return numarray.array(0.)  
       else:  
          return numarray.less_equal(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.  
1216     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1217        if tol>0.:        return arg._whereZero(tol)
          return whereNegative(abs(arg)-tol)  
       else:  
          return arg._whereZero()  
1218     elif isinstance(arg,float):     elif isinstance(arg,float):
1219        if abs(arg)<=tol:        if abs(arg)<=tol:
1220          return 1.          return 1.
# Line 1236  def whereNonZero(arg,tol=0.): Line 1290  def whereNonZero(arg,tol=0.):
1290     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1291     """     """
1292     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1293        if arg.rank==0:        out=numarray.greater(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float64))*1.
1294          if abs(arg)>tol:        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1295             return numarray.array(1.)        return out
         else:  
            return numarray.array(0.)  
       else:  
          return numarray.greater(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.  
1296     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1297        if tol>0.:        return arg._whereNonZero(tol)
          return 1.-whereZero(arg,tol)  
       else:  
          return arg._whereNonZero()  
1298     elif isinstance(arg,float):     elif isinstance(arg,float):
1299        if abs(arg)>tol:        if abs(arg)>tol:
1300          return 1.          return 1.
# Line 2877  def length(arg): Line 2924  def length(arg):
2924     """     """
2925     return sqrt(inner(arg,arg))     return sqrt(inner(arg,arg))
2926    
2927    def trace(arg,axis_offset=0):
2928       """
2929       returns the trace of arg which the sum of arg[k,k] over k.
2930    
2931       @param arg: argument
2932       @type arg: L{escript.Data}, L{Symbol}, L{numarray.NumArray}.
2933       @param axis_offset: axis_offset to components to sum over. C{axis_offset} must be non-negative and less than the rank of arg +1. The dimensions on component
2934                      axis_offset and axis_offset+1 must be equal.
2935       @type axis_offset: C{int}
2936       @return: trace of arg. The rank of the returned object is minus 2 of the rank of arg.
2937       @rtype: L{escript.Data}, L{Symbol}, L{numarray.NumArray} depending on the type of arg.
2938       """
2939       if isinstance(arg,numarray.NumArray):
2940          sh=arg.shape
2941          if len(sh)<2:
2942            raise ValueError,"trace: rank of argument must be greater than 1"
2943          if axis_offset<0 or axis_offset>len(sh)-2:
2944            raise ValueError,"trace: axis_offset must be between 0 and %s"%len(sh)-2
2945          s1=1
2946          for i in range(axis_offset): s1*=sh[i]
2947          s2=1
2948          for i in range(axis_offset+2,len(sh)): s2*=sh[i]
2949          if not sh[axis_offset] == sh[axis_offset+1]:
2950            raise ValueError,"trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2951          arg_reshaped=numarray.reshape(arg,(s1,sh[axis_offset],sh[axis_offset],s2))
2952          out=numarray.zeros([s1,s2],numarray.Float64)
2953          for i1 in range(s1):
2954            for i2 in range(s2):
2955                for j in range(sh[axis_offset]): out[i1,i2]+=arg_reshaped[i1,j,j,i2]
2956          out.resize(sh[:axis_offset]+sh[axis_offset+2:])
2957          return out
2958       elif isinstance(arg,escript.Data):
2959          return escript_trace(arg,axis_offset)
2960       elif isinstance(arg,float):
2961          raise TypeError,"trace: illegal argument type float."
2962       elif isinstance(arg,int):
2963          raise TypeError,"trace: illegal argument type int."
2964       elif isinstance(arg,Symbol):
2965          return Trace_Symbol(arg,axis_offset)
2966       else:
2967          raise TypeError,"trace: Unknown argument type."
2968    
2969    def escript_trace(arg,axis_offset): # this should be escript._trace
2970          "arg si a Data objects!!!"
2971          if arg.getRank()<2:
2972            raise ValueError,"escript_trace: rank of argument must be greater than 1"
2973          if axis_offset<0 or axis_offset>arg.getRank()-2:
2974            raise ValueError,"escript_trace: axis_offset must be between 0 and %s"%arg.getRank()-2
2975          s=list(arg.getShape())        
2976          if not s[axis_offset] == s[axis_offset+1]:
2977            raise ValueError,"escript_trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2978          out=escript.Data(0.,tuple(s[0:axis_offset]+s[axis_offset+2:]),arg.getFunctionSpace())
2979          if arg.getRank()==2:
2980             for i0 in range(s[0]):
2981                out+=arg[i0,i0]
2982          elif arg.getRank()==3:
2983             if axis_offset==0:
2984                for i0 in range(s[0]):
2985                      for i2 in range(s[2]):
2986                             out[i2]+=arg[i0,i0,i2]
2987             elif axis_offset==1:
2988                for i0 in range(s[0]):
2989                   for i1 in range(s[1]):
2990                             out[i0]+=arg[i0,i1,i1]
2991          elif arg.getRank()==4:
2992             if axis_offset==0:
2993                for i0 in range(s[0]):
2994                      for i2 in range(s[2]):
2995                         for i3 in range(s[3]):
2996                             out[i2,i3]+=arg[i0,i0,i2,i3]
2997             elif axis_offset==1:
2998                for i0 in range(s[0]):
2999                   for i1 in range(s[1]):
3000                         for i3 in range(s[3]):
3001                             out[i0,i3]+=arg[i0,i1,i1,i3]
3002             elif axis_offset==2:
3003                for i0 in range(s[0]):
3004                   for i1 in range(s[1]):
3005                      for i2 in range(s[2]):
3006                             out[i0,i1]+=arg[i0,i1,i2,i2]
3007          return out
3008    class Trace_Symbol(DependendSymbol):
3009       """
3010       L{Symbol} representing the result of the trace function
3011       """
3012       def __init__(self,arg,axis_offset=0):
3013          """
3014          initialization of trace L{Symbol} with argument arg
3015          @param arg: argument of function
3016          @type arg: L{Symbol}.
3017          @param axis_offset: axis_offset to components to sum over. C{axis_offset} must be non-negative and less than the rank of arg +1. The dimensions on component
3018                      axis_offset and axis_offset+1 must be equal.
3019          @type axis_offset: C{int}
3020          """
3021          if arg.getRank()<2:
3022            raise ValueError,"Trace_Symbol: rank of argument must be greater than 1"
3023          if axis_offset<0 or axis_offset>arg.getRank()-2:
3024            raise ValueError,"Trace_Symbol: axis_offset must be between 0 and %s"%arg.getRank()-2
3025          s=list(arg.getShape())        
3026          if not s[axis_offset] == s[axis_offset+1]:
3027            raise ValueError,"Trace_Symbol: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
3028          super(Trace_Symbol,self).__init__(args=[arg,axis_offset],shape=tuple(s[0:axis_offset]+s[axis_offset+2:]),dim=arg.getDim())
3029    
3030       def getMyCode(self,argstrs,format="escript"):
3031          """
3032          returns a program code that can be used to evaluate the symbol.
3033    
3034          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3035          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3036          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3037          @type format: C{str}
3038          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3039          @rtype: C{str}
3040          @raise: NotImplementedError: if the requested format is not available
3041          """
3042          if format=="escript" or format=="str"  or format=="text":
3043             return "trace(%s,axis_offset=%s)"%(argstrs[0],argstrs[1])
3044          else:
3045             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
3046    
3047       def substitute(self,argvals):
3048          """
3049          assigns new values to symbols in the definition of the symbol.
3050          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3051    
3052          @param argvals: new values assigned to symbols
3053          @type argvals: C{dict} with keywords of type L{Symbol}.
3054          @return: result of the substitution process. Operations are executed as much as possible.
3055          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3056          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3057          """
3058          if argvals.has_key(self):
3059             arg=argvals[self]
3060             if self.isAppropriateValue(arg):
3061                return arg
3062             else:
3063                raise TypeError,"%s: new value is not appropriate."%str(self)
3064          else:
3065             arg=self.getSubstitutedArguments(argvals)
3066             return trace(arg[0],axis_offset=arg[1])
3067    
3068       def diff(self,arg):
3069          """
3070          differential of this object
3071    
3072          @param arg: the derivative is calculated with respect to arg
3073          @type arg: L{escript.Symbol}
3074          @return: derivative with respect to C{arg}
3075          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3076          """
3077          if arg==self:
3078             return identity(self.getShape())
3079          else:
3080             return trace(self.getDifferentiatedArguments(arg)[0],axis_offset=self.getArgument()[1])
3081    
3082    def transpose(arg,axis_offset=None):
3083       """
3084       returns the transpose of arg by swaping the first axis_offset and the last rank-axis_offset components.
3085    
3086       @param arg: argument
3087       @type arg: L{escript.Data}, L{Symbol}, L{numarray.NumArray}, C{float}, C{int}
3088       @param axis_offset: the first axis_offset components are swapped with rest. If C{axis_offset} must be non-negative and less or equal the rank of arg.
3089                           if axis_offset is not present C{int(r/2)} where r is the rank of arg is used.
3090       @type axis_offset: C{int}
3091       @return: transpose of arg
3092       @rtype: L{escript.Data}, L{Symbol}, L{numarray.NumArray},C{float}, C{int} depending on the type of arg.
3093       """
3094       if isinstance(arg,numarray.NumArray):
3095          if axis_offset==None: axis_offset=int(arg.rank/2)
3096          return numarray.transpose(arg,axes=range(axis_offset,arg.rank)+range(0,axis_offset))
3097       elif isinstance(arg,escript.Data):
3098          if axis_offset==None: axis_offset=int(arg.getRank()/2)
3099          return escript_transpose(arg,axis_offset)
3100       elif isinstance(arg,float):
3101          if not ( axis_offset==0 or axis_offset==None):
3102            raise ValueError,"transpose: axis_offset must be 0 for float argument"
3103          return arg
3104       elif isinstance(arg,int):
3105          if not ( axis_offset==0 or axis_offset==None):
3106            raise ValueError,"transpose: axis_offset must be 0 for int argument"
3107          return float(arg)
3108       elif isinstance(arg,Symbol):
3109          if axis_offset==None: axis_offset=int(arg.getRank()/2)
3110          return Transpose_Symbol(arg,axis_offset)
3111       else:
3112          raise TypeError,"transpose: Unknown argument type."
3113    
3114    def escript_transpose(arg,axis_offset): # this should be escript._transpose
3115          "arg si a Data objects!!!"
3116          r=arg.getRank()
3117          if axis_offset<0 or axis_offset>r:
3118            raise ValueError,"escript_transpose: axis_offset must be between 0 and %s"%r
3119          s=arg.getShape()
3120          s_out=s[axis_offset:]+s[:axis_offset]
3121          out=escript.Data(0.,s_out,arg.getFunctionSpace())
3122          if r==4:
3123             if axis_offset==1:
3124                for i0 in range(s_out[0]):
3125                   for i1 in range(s_out[1]):
3126                      for i2 in range(s_out[2]):
3127                         for i3 in range(s_out[3]):
3128                             out[i0,i1,i2,i3]=arg[i3,i0,i1,i2]
3129             elif axis_offset==2:
3130                for i0 in range(s_out[0]):
3131                   for i1 in range(s_out[1]):
3132                      for i2 in range(s_out[2]):
3133                         for i3 in range(s_out[3]):
3134                             out[i0,i1,i2,i3]=arg[i2,i3,i0,i1]
3135             elif axis_offset==3:
3136                for i0 in range(s_out[0]):
3137                   for i1 in range(s_out[1]):
3138                      for i2 in range(s_out[2]):
3139                         for i3 in range(s_out[3]):
3140                             out[i0,i1,i2,i3]=arg[i1,i2,i3,i0]
3141             else:
3142                for i0 in range(s_out[0]):
3143                   for i1 in range(s_out[1]):
3144                      for i2 in range(s_out[2]):
3145                         for i3 in range(s_out[3]):
3146                             out[i0,i1,i2,i3]=arg[i0,i1,i2,i3]
3147          elif r==3:
3148             if axis_offset==1:
3149                for i0 in range(s_out[0]):
3150                   for i1 in range(s_out[1]):
3151                      for i2 in range(s_out[2]):
3152                             out[i0,i1,i2]=arg[i2,i0,i1]
3153             elif axis_offset==2:
3154                for i0 in range(s_out[0]):
3155                   for i1 in range(s_out[1]):
3156                      for i2 in range(s_out[2]):
3157                             out[i0,i1,i2]=arg[i1,i2,i0]
3158             else:
3159                for i0 in range(s_out[0]):
3160                   for i1 in range(s_out[1]):
3161                      for i2 in range(s_out[2]):
3162                             out[i0,i1,i2]=arg[i0,i1,i2]
3163          elif r==2:
3164             if axis_offset==1:
3165                for i0 in range(s_out[0]):
3166                   for i1 in range(s_out[1]):
3167                             out[i0,i1]=arg[i1,i0]
3168             else:
3169                for i0 in range(s_out[0]):
3170                   for i1 in range(s_out[1]):
3171                             out[i0,i1]=arg[i0,i1]
3172          elif r==1:
3173              for i0 in range(s_out[0]):
3174                   out[i0]=arg[i0]
3175          elif r==0:
3176                 out=arg+0.
3177          return out
3178    class Transpose_Symbol(DependendSymbol):
3179       """
3180       L{Symbol} representing the result of the transpose function
3181       """
3182       def __init__(self,arg,axis_offset=None):
3183          """
3184          initialization of transpose L{Symbol} with argument arg
3185    
3186          @param arg: argument of function
3187          @type arg: L{Symbol}.
3188           @param axis_offset: the first axis_offset components are swapped with rest. If C{axis_offset} must be non-negative and less or equal the rank of arg.
3189                           if axis_offset is not present C{int(r/2)} where r is the rank of arg is used.
3190          @type axis_offset: C{int}
3191          """
3192          if axis_offset==None: axis_offset=int(arg.getRank()/2)
3193          if axis_offset<0 or axis_offset>arg.getRank():
3194            raise ValueError,"escript_transpose: axis_offset must be between 0 and %s"%r
3195          s=arg.getShape()
3196          super(Transpose_Symbol,self).__init__(args=[arg,axis_offset],shape=s[axis_offset:]+s[:axis_offset],dim=arg.getDim())
3197    
3198       def getMyCode(self,argstrs,format="escript"):
3199          """
3200          returns a program code that can be used to evaluate the symbol.
3201    
3202          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3203          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3204          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3205          @type format: C{str}
3206          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3207          @rtype: C{str}
3208          @raise: NotImplementedError: if the requested format is not available
3209          """
3210          if format=="escript" or format=="str"  or format=="text":
3211             return "transpose(%s,axis_offset=%s)"%(argstrs[0],argstrs[1])
3212          else:
3213             raise NotImplementedError,"Transpose_Symbol does not provide program code for format %s."%format
3214    
3215       def substitute(self,argvals):
3216          """
3217          assigns new values to symbols in the definition of the symbol.
3218          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3219    
3220          @param argvals: new values assigned to symbols
3221          @type argvals: C{dict} with keywords of type L{Symbol}.
3222          @return: result of the substitution process. Operations are executed as much as possible.
3223          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3224          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3225          """
3226          if argvals.has_key(self):
3227             arg=argvals[self]
3228             if self.isAppropriateValue(arg):
3229                return arg
3230             else:
3231                raise TypeError,"%s: new value is not appropriate."%str(self)
3232          else:
3233             arg=self.getSubstitutedArguments(argvals)
3234             return transpose(arg[0],axis_offset=arg[1])
3235    
3236       def diff(self,arg):
3237          """
3238          differential of this object
3239    
3240          @param arg: the derivative is calculated with respect to arg
3241          @type arg: L{escript.Symbol}
3242          @return: derivative with respect to C{arg}
3243          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3244          """
3245          if arg==self:
3246             return identity(self.getShape())
3247          else:
3248             return transpose(self.getDifferentiatedArguments(arg)[0],axis_offset=self.getArgument()[1])
3249    def symmetric(arg):
3250        """
3251        returns the symmetric part of the square matrix arg. This is (arg+transpose(arg))/2
3252    
3253        @param arg: square matrix. Must have rank 2 or 4 and be square.
3254        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}
3255        @return: symmetric part of arg
3256        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
3257        """
3258        if isinstance(arg,numarray.NumArray):
3259          if arg.rank==2:
3260            if not (arg.shape[0]==arg.shape[1]):
3261               raise ValueError,"symmetric: argument must be square."
3262          elif arg.rank==4:
3263            if not (arg.shape[0]==arg.shape[2] and arg.shape[1]==arg.shape[3]):
3264               raise ValueError,"symmetric: argument must be square."
3265          else:
3266            raise ValueError,"symmetric: rank 2 or 4 is required."
3267          return (arg+transpose(arg))/2
3268        elif isinstance(arg,escript.Data):
3269          return escript_symmetric(arg)
3270        elif isinstance(arg,float):
3271          return arg
3272        elif isinstance(arg,int):
3273          return float(arg)
3274        elif isinstance(arg,Symbol):
3275          if arg.getRank()==2:
3276            if not (arg.getShape()[0]==arg.getShape()[1]):
3277               raise ValueError,"symmetric: argument must be square."
3278          elif arg.getRank()==4:
3279            if not (arg.getShape()[0]==arg.getShape()[2] and arg.getShape()[1]==arg.getShape()[3]):
3280               raise ValueError,"symmetric: argument must be square."
3281          else:
3282            raise ValueError,"symmetric: rank 2 or 4 is required."
3283          return (arg+transpose(arg))/2
3284        else:
3285          raise TypeError,"symmetric: Unknown argument type."
3286    
3287    def escript_symmetric(arg): # this should be implemented in c++
3288          if arg.getRank()==2:
3289            if not (arg.getShape()[0]==arg.getShape()[1]):
3290               raise ValueError,"escript_symmetric: argument must be square."
3291            out=escript.Data(0.,arg.getShape(),arg.getFunctionSpace())
3292            for i0 in range(arg.getShape()[0]):
3293               for i1 in range(arg.getShape()[1]):
3294                  out[i0,i1]=(arg[i0,i1]+arg[i1,i0])/2.
3295          elif arg.getRank()==4:
3296            if not (arg.getShape()[0]==arg.getShape()[2] and arg.getShape()[1]==arg.getShape()[3]):
3297               raise ValueError,"escript_symmetric: argument must be square."
3298            out=escript.Data(0.,arg.getShape(),arg.getFunctionSpace())
3299            for i0 in range(arg.getShape()[0]):
3300               for i1 in range(arg.getShape()[1]):
3301                  for i2 in range(arg.getShape()[2]):
3302                     for i3 in range(arg.getShape()[3]):
3303                         out[i0,i1,i2,i3]=(arg[i0,i1,i2,i3]+arg[i2,i3,i0,i1])/2.
3304          else:
3305            raise ValueError,"escript_symmetric: rank 2 or 4 is required."
3306          return out
3307    
3308    def nonsymmetric(arg):
3309        """
3310        returns the nonsymmetric part of the square matrix arg. This is (arg-transpose(arg))/2
3311    
3312        @param arg: square matrix. Must have rank 2 or 4 and be square.
3313        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}
3314        @return: nonsymmetric part of arg
3315        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
3316        """
3317        if isinstance(arg,numarray.NumArray):
3318          if arg.rank==2:
3319            if not (arg.shape[0]==arg.shape[1]):
3320               raise ValueError,"nonsymmetric: argument must be square."
3321          elif arg.rank==4:
3322            if not (arg.shape[0]==arg.shape[2] and arg.shape[1]==arg.shape[3]):
3323               raise ValueError,"nonsymmetric: argument must be square."
3324          else:
3325            raise ValueError,"nonsymmetric: rank 2 or 4 is required."
3326          return (arg-transpose(arg))/2
3327        elif isinstance(arg,escript.Data):
3328          return escript_nonsymmetric(arg)
3329        elif isinstance(arg,float):
3330          return arg
3331        elif isinstance(arg,int):
3332          return float(arg)
3333        elif isinstance(arg,Symbol):
3334          if arg.getRank()==2:
3335            if not (arg.getShape()[0]==arg.getShape()[1]):
3336               raise ValueError,"nonsymmetric: argument must be square."
3337          elif arg.getRank()==4:
3338            if not (arg.getShape()[0]==arg.getShape()[2] and arg.getShape()[1]==arg.getShape()[3]):
3339               raise ValueError,"nonsymmetric: argument must be square."
3340          else:
3341            raise ValueError,"nonsymmetric: rank 2 or 4 is required."
3342          return (arg-transpose(arg))/2
3343        else:
3344          raise TypeError,"nonsymmetric: Unknown argument type."
3345    
3346    def escript_nonsymmetric(arg): # this should be implemented in c++
3347          if arg.getRank()==2:
3348            if not (arg.getShape()[0]==arg.getShape()[1]):
3349               raise ValueError,"escript_nonsymmetric: argument must be square."
3350            out=escript.Data(0.,arg.getShape(),arg.getFunctionSpace())
3351            for i0 in range(arg.getShape()[0]):
3352               for i1 in range(arg.getShape()[1]):
3353                  out[i0,i1]=(arg[i0,i1]-arg[i1,i0])/2.
3354          elif arg.getRank()==4:
3355            if not (arg.getShape()[0]==arg.getShape()[2] and arg.getShape()[1]==arg.getShape()[3]):
3356               raise ValueError,"escript_nonsymmetric: argument must be square."
3357            out=escript.Data(0.,arg.getShape(),arg.getFunctionSpace())
3358            for i0 in range(arg.getShape()[0]):
3359               for i1 in range(arg.getShape()[1]):
3360                  for i2 in range(arg.getShape()[2]):
3361                     for i3 in range(arg.getShape()[3]):
3362                         out[i0,i1,i2,i3]=(arg[i0,i1,i2,i3]-arg[i2,i3,i0,i1])/2.
3363          else:
3364            raise ValueError,"escript_nonsymmetric: rank 2 or 4 is required."
3365          return out
3366    
3367    
3368    def inverse(arg):
3369        """
3370        returns the inverse of the square matrix arg.
3371    
3372        @param arg: square matrix. Must have rank 2 and the first and second dimension must be equal.
3373        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}
3374        @return: inverse arg_inv of the argument. It will be matrixmul(inverse(arg),arg) almost equal to kronecker(arg.getShape()[0])
3375        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
3376        @remark: for L{escript.Data} objects the dimension is restricted to 3.
3377        """
3378        import numarray.linear_algebra # This statement should be after the next statement but then somehow numarray is gone.
3379        if isinstance(arg,numarray.NumArray):
3380          return numarray.linear_algebra.inverse(arg)
3381        elif isinstance(arg,escript.Data):
3382          return escript_inverse(arg)
3383        elif isinstance(arg,float):
3384          return 1./arg
3385        elif isinstance(arg,int):
3386          return 1./float(arg)
3387        elif isinstance(arg,Symbol):
3388          return Inverse_Symbol(arg)
3389        else:
3390          raise TypeError,"inverse: Unknown argument type."
3391    
3392    def escript_inverse(arg): # this should be escript._inverse and use LAPACK
3393          "arg is a Data objects!!!"
3394          if not arg.getRank()==2:
3395            raise ValueError,"escript_inverse: argument must have rank 2"
3396          s=arg.getShape()      
3397          if not s[0] == s[1]:
3398            raise ValueError,"escript_inverse: argument must be a square matrix."
3399          out=escript.Data(0.,s,arg.getFunctionSpace())
3400          if s[0]==1:
3401              if inf(abs(arg[0,0]))==0: # in c this should be done point wise as abs(arg[0,0](i))<=0.
3402                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3403              out[0,0]=1./arg[0,0]
3404          elif s[0]==2:
3405              A11=arg[0,0]
3406              A12=arg[0,1]
3407              A21=arg[1,0]
3408              A22=arg[1,1]
3409              D = A11*A22-A12*A21
3410              if inf(abs(D))==0: # in c this should be done point wise as abs(D(i))<=0.
3411                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3412              D=1./D
3413              out[0,0]= A22*D
3414              out[1,0]=-A21*D
3415              out[0,1]=-A12*D
3416              out[1,1]= A11*D
3417          elif s[0]==3:
3418              A11=arg[0,0]
3419              A21=arg[1,0]
3420              A31=arg[2,0]
3421              A12=arg[0,1]
3422              A22=arg[1,1]
3423              A32=arg[2,1]
3424              A13=arg[0,2]
3425              A23=arg[1,2]
3426              A33=arg[2,2]
3427              D  =  A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22)
3428              if inf(abs(D))==0: # in c this should be done point wise as abs(D(i))<=0.
3429                  raise ZeroDivisionError,"escript_inverse: argument not invertible"
3430              D=1./D
3431              out[0,0]=(A22*A33-A23*A32)*D
3432              out[1,0]=(A31*A23-A21*A33)*D
3433              out[2,0]=(A21*A32-A31*A22)*D
3434              out[0,1]=(A13*A32-A12*A33)*D
3435              out[1,1]=(A11*A33-A31*A13)*D
3436              out[2,1]=(A12*A31-A11*A32)*D
3437              out[0,2]=(A12*A23-A13*A22)*D
3438              out[1,2]=(A13*A21-A11*A23)*D
3439              out[2,2]=(A11*A22-A12*A21)*D
3440          else:
3441             raise TypeError,"escript_inverse: only matrix dimensions 1,2,3 are supported right now."
3442          return out
3443    
3444    class Inverse_Symbol(DependendSymbol):
3445       """
3446       L{Symbol} representing the result of the inverse function
3447       """
3448       def __init__(self,arg):
3449          """
3450          initialization of inverse L{Symbol} with argument arg
3451          @param arg: argument of function
3452          @type arg: L{Symbol}.
3453          """
3454          if not arg.getRank()==2:
3455            raise ValueError,"Inverse_Symbol:: argument must have rank 2"
3456          s=arg.getShape()
3457          if not s[0] == s[1]:
3458            raise ValueError,"Inverse_Symbol:: argument must be a square matrix."
3459          super(Inverse_Symbol,self).__init__(args=[arg],shape=s,dim=arg.getDim())
3460    
3461       def getMyCode(self,argstrs,format="escript"):
3462          """
3463          returns a program code that can be used to evaluate the symbol.
3464    
3465          @param argstrs: gives for each argument a string representing the argument for the evaluation.
3466          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
3467          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
3468          @type format: C{str}
3469          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
3470          @rtype: C{str}
3471          @raise: NotImplementedError: if the requested format is not available
3472          """
3473          if format=="escript" or format=="str"  or format=="text":
3474             return "inverse(%s)"%argstrs[0]
3475          else:
3476             raise NotImplementedError,"Inverse_Symbol does not provide program code for format %s."%format
3477    
3478       def substitute(self,argvals):
3479          """
3480          assigns new values to symbols in the definition of the symbol.
3481          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
3482    
3483          @param argvals: new values assigned to symbols
3484          @type argvals: C{dict} with keywords of type L{Symbol}.
3485          @return: result of the substitution process. Operations are executed as much as possible.
3486          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
3487          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
3488          """
3489          if argvals.has_key(self):
3490             arg=argvals[self]
3491             if self.isAppropriateValue(arg):
3492                return arg
3493             else:
3494                raise TypeError,"%s: new value is not appropriate."%str(self)
3495          else:
3496             arg=self.getSubstitutedArguments(argvals)
3497             return inverse(arg[0])
3498    
3499       def diff(self,arg):
3500          """
3501          differential of this object
3502    
3503          @param arg: the derivative is calculated with respect to arg
3504          @type arg: L{escript.Symbol}
3505          @return: derivative with respect to C{arg}
3506          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
3507          """
3508          if arg==self:
3509             return identity(self.getShape())
3510          else:
3511             return -matrixmult(matrixmult(self,self.getDifferentiatedArguments(arg)[0]),self)
3512    
3513    def eigenvalues(arg):
3514        """
3515        returns the eigenvalues of the square matrix arg.
3516    
3517        @param arg: square matrix. Must have rank 2 and the first and second dimension must be equal.
3518                    arg must be symmetric, ie. transpose(arg)==arg (this is not checked).
3519        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}
3520        @return: the eigenvalues in increasing order.
3521        @rtype: L{numarray.NumArray},L{escript.Data}, L{Symbol} depending on the input.
3522        @remark: for L{escript.Data} and L{Symbol} objects the dimension is restricted to 3.
3523        """
3524        if isinstance(arg,numarray.NumArray):
3525          out=numarray.linear_algebra.eigenvalues((arg+numarray.transpose(arg))/2.)
3526          out.sort()
3527          return out
3528        elif isinstance(arg,escript.Data):
3529          return arg._eigenvalues()
3530        elif isinstance(arg,Symbol):
3531          if not arg.getRank()==2:
3532            raise ValueError,"eigenvalues: argument must have rank 2"
3533          s=arg.getShape()      
3534          if not s[0] == s[1]:
3535            raise ValueError,"eigenvalues: argument must be a square matrix."
3536          if s[0]==1:
3537              return arg[0]
3538          elif s[0]==2:
3539              arg1=symmetric(arg)
3540              A11=arg1[0,0]
3541              A12=arg1[0,1]
3542              A22=arg1[1,1]
3543              trA=(A11+A22)/2.
3544              A11-=trA
3545              A22-=trA
3546              s=sqrt(A12**2-A11*A22)
3547              return trA+s*numarray.array([-1.,1.],type=numarray.Float64)
3548          elif s[0]==3:
3549              arg1=symmetric(arg)
3550              A11=arg1[0,0]
3551              A12=arg1[0,1]
3552              A22=arg1[1,1]
3553              A13=arg1[0,2]
3554              A23=arg1[1,2]
3555              A33=arg1[2,2]
3556              trA=(A11+A22+A33)/3.
3557              A11-=trA
3558              A22-=trA
3559              A33-=trA
3560              A13_2=A13**2
3561              A23_2=A23**2
3562              A12_2=A12**2
3563              p=A13_2+A23_2+A12_2+(A11**2+A22**2+A33**2)/2.
3564              q=A13_2*A22+A23_2*A11+A12_2*A33-A11*A22*A33-2*A12*A23*A13
3565              sq_p=sqrt(p/3.)
3566              alpha_3=acos(clip(-q*(sq_p+whereZero(p,0.)*1.e-15)**(-3.)/2.,-1.,1.))/3.  # whereZero is protection against divison by zero
3567              sq_p*=2.
3568              f=cos(alpha_3)               *numarray.array([0.,0.,1.],type=numarray.Float64) \
3569               -cos(alpha_3+numarray.pi/3.)*numarray.array([0.,1.,0.],type=numarray.Float64) \
3570               -cos(alpha_3-numarray.pi/3.)*numarray.array([1.,0.,0.],type=numarray.Float64)
3571              return trA+sq_p*f
3572          else:
3573             raise TypeError,"eigenvalues: only matrix dimensions 1,2,3 are supported right now."
3574        elif isinstance(arg,float):
3575          return arg
3576        elif isinstance(arg,int):
3577          return float(arg)
3578        else:
3579          raise TypeError,"eigenvalues: Unknown argument type."
3580    
3581    def eigenvalues_and_eigenvectors(arg):
3582        """
3583        returns the eigenvalues and eigenvectors of the square matrix arg.
3584    
3585        @param arg: square matrix. Must have rank 2 and the first and second dimension must be equal.
3586                    arg must be symmetric, ie. transpose(arg)==arg (this is not checked).
3587        @type arg: L{escript.Data}
3588        @return: the eigenvalues and eigenvectors. The eigenvalues are ordered by increasing value. The
3589                 eigenvectors are orthogonal and normalized. If V are the eigenvectors than V[:,i] is
3590                 the eigenvector coresponding to the i-th eigenvalue.
3591        @rtype: L{tuple} of L{escript.Data}.
3592        @remark: The dimension is restricted to 3.
3593        """
3594        if isinstance(arg,numarray.NumArray):
3595          raise TypeError,"eigenvalues_and_eigenvectors is not supporting numarray arguments"
3596        elif isinstance(arg,escript.Data):
3597          return arg._eigenvalues_and_eigenvectors()
3598        elif isinstance(arg,Symbol):
3599          raise TypeError,"eigenvalues_and_eigenvectors is not supporting Symbol arguments"
3600        elif isinstance(arg,float):
3601          return (numarray.array([[arg]],numarray.Float),numarray.ones((1,1),numarray.Float))
3602        elif isinstance(arg,int):
3603          return (numarray.array([[arg]],numarray.Float),numarray.ones((1,1),numarray.Float))
3604        else:
3605          raise TypeError,"eigenvalues: Unknown argument type."
3606  #=======================================================  #=======================================================
3607  #  Binary operations:  #  Binary operations:
3608  #=======================================================  #=======================================================
# Line 2995  def mult(arg0,arg1): Line 3721  def mult(arg0,arg1):
3721         """         """
3722         args=matchShape(arg0,arg1)         args=matchShape(arg0,arg1)
3723         if testForZero(args[0]) or testForZero(args[1]):         if testForZero(args[0]) or testForZero(args[1]):
3724            return numarray.zeros(pokeShape(args[0]),numarray.Float)            return numarray.zeros(pokeShape(args[0]),numarray.Float64)
3725         else:         else:
3726            if isinstance(args[0],Symbol) or isinstance(args[1],Symbol) :            if isinstance(args[0],Symbol) or isinstance(args[1],Symbol) :
3727                return Mult_Symbol(args[0],args[1])                return Mult_Symbol(args[0],args[1])
# Line 3095  def quotient(arg0,arg1): Line 3821  def quotient(arg0,arg1):
3821         """         """
3822         args=matchShape(arg0,arg1)         args=matchShape(arg0,arg1)
3823         if testForZero(args[0]):         if testForZero(args[0]):
3824            return numarray.zeros(pokeShape(args[0]),numarray.Float)            return numarray.zeros(pokeShape(args[0]),numarray.Float64)
3825         elif isinstance(args[0],Symbol):         elif isinstance(args[0],Symbol):
3826            if isinstance(args[1],Symbol):            if isinstance(args[1],Symbol):
3827               return Quotient_Symbol(args[0],args[1])               return Quotient_Symbol(args[0],args[1])
# Line 3201  def power(arg0,arg1): Line 3927  def power(arg0,arg1):
3927         """         """
3928         args=matchShape(arg0,arg1)         args=matchShape(arg0,arg1)
3929         if testForZero(args[0]):         if testForZero(args[0]):
3930            return numarray.zeros(args[0],numarray.Float)            return numarray.zeros(pokeShape(args[0]),numarray.Float64)
3931         elif testForZero(args[1]):         elif testForZero(args[1]):
3932            return numarray.ones(args[0],numarray.Float)            return numarray.ones(pokeShape(args[1]),numarray.Float64)
3933         elif isinstance(args[0],Symbol) or isinstance(args[1],Symbol):         elif isinstance(args[0],Symbol) or isinstance(args[1],Symbol):
3934            return Power_Symbol(args[0],args[1])            return Power_Symbol(args[0],args[1])
3935         elif isinstance(args[0],numarray.NumArray) and not isinstance(args[1],numarray.NumArray):         elif isinstance(args[0],numarray.NumArray) and not isinstance(args[1],numarray.NumArray):
# Line 3304  def maximum(*args): Line 4030  def maximum(*args):
4030         if out==None:         if out==None:
4031            out=a            out=a
4032         else:         else:
4033            m=whereNegative(out-a)            diff=add(a,-out)
4034            out=m*a+(1.-m)*out            out=add(out,mult(wherePositive(diff),diff))
4035      return out      return out
4036        
4037  def minimum(*arg):  def minimum(*args):
4038      """      """
4039      the minimum over arguments args      the minimum over arguments args
4040    
# Line 3322  def minimum(*arg): Line 4048  def minimum(*arg):
4048         if out==None:         if out==None:
4049            out=a            out=a
4050         else:         else:
4051            m=whereNegative(out-a)            diff=add(a,-out)
4052            out=m*out+(1.-m)*a            out=add(out,mult(whereNegative(diff),diff))
4053      return out      return out
4054    
4055    def clip(arg,minval=0.,maxval=1.):
4056        """
4057        cuts the values of arg between minval and maxval
4058    
4059        @param arg: argument
4060        @type arg: L{numarray.NumArray}, L{escript.Data}, L{Symbol}, C{int} or C{float}
4061        @param minval: lower range
4062        @type arg: C{float}
4063        @param maxval: upper range
4064        @type arg: C{float}
4065        @return: is on object with all its value between minval and maxval. value of the argument that greater then minval and
4066                 less then maxval are unchanged.
4067        @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol}, C{int} or C{float} depending on the input
4068        @raise ValueError: if minval>maxval
4069        """
4070        if minval>maxval:
4071           raise ValueError,"minval = %s must be less then maxval %s"%(minval,maxval)
4072        return minimum(maximum(minval,arg),maxval)
4073    
4074        
4075  def inner(arg0,arg1):  def inner(arg0,arg1):
4076      """      """
# Line 3348  def inner(arg0,arg1): Line 4094  def inner(arg0,arg1):
4094      sh1=pokeShape(arg1)      sh1=pokeShape(arg1)
4095      if not sh0==sh1:      if not sh0==sh1:
4096          raise ValueError,"inner: shape of arguments does not match"          raise ValueError,"inner: shape of arguments does not match"
4097      return generalTensorProduct(arg0,arg1,offset=len(sh0))      return generalTensorProduct(arg0,arg1,axis_offset=len(sh0))
4098    
4099  def matrixmult(arg0,arg1):  def matrixmult(arg0,arg1):
4100      """      """
# Line 3376  def matrixmult(arg0,arg1): Line 4122  def matrixmult(arg0,arg1):
4122          raise ValueError,"first argument must have rank 2"          raise ValueError,"first argument must have rank 2"
4123      if not len(sh1)==2 and not len(sh1)==1:      if not len(sh1)==2 and not len(sh1)==1:
4124          raise ValueError,"second argument must have rank 1 or 2"          raise ValueError,"second argument must have rank 1 or 2"
4125      return generalTensorProduct(arg0,arg1,offset=1)      return generalTensorProduct(arg0,arg1,axis_offset=1)
4126    
4127  def outer(arg0,arg1):  def outer(arg0,arg1):
4128      """      """
# Line 3394  def outer(arg0,arg1): Line 4140  def outer(arg0,arg1):
4140      @return: the outer product of arg0 and arg1 at each data point      @return: the outer product of arg0 and arg1 at each data point
4141      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input      @rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input
4142      """      """
4143      return generalTensorProduct(arg0,arg1,offset=0)      return generalTensorProduct(arg0,arg1,axis_offset=0)
4144    
4145    
4146  def tensormult(arg0,arg1):  def tensormult(arg0,arg1):
# Line 3436  def tensormult(arg0,arg1): Line 4182  def tensormult(arg0,arg1):
4182      sh0=pokeShape(arg0)      sh0=pokeShape(arg0)
4183      sh1=pokeShape(arg1)      sh1=pokeShape(arg1)
4184      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):      if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ):
4185         return generalTensorProduct(arg0,arg1,offset=1)         return generalTensorProduct(arg0,arg1,axis_offset=1)
4186      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):      elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4):
4187         return generalTensorProduct(arg0,arg1,offset=2)         return generalTensorProduct(arg0,arg1,axis_offset=2)
4188      else:      else:
4189          raise ValueError,"tensormult: first argument must have rank 2 or 4"          raise ValueError,"tensormult: first argument must have rank 2 or 4"
4190    
4191  def generalTensorProduct(arg0,arg1,offset=0):  def generalTensorProduct(arg0,arg1,axis_offset=0):
4192      """      """
4193      generalized tensor product      generalized tensor product
4194    
4195      out[s,t]=S{Sigma}_r arg0[s,r]*arg1[r,t]      out[s,t]=S{Sigma}_r arg0[s,r]*arg1[r,t]
4196    
4197      where s runs through arg0.Shape[:arg0.Rank-offset]      where s runs through arg0.Shape[:arg0.Rank-axis_offset]
4198            r runs trough arg0.Shape[:offset]            r runs trough arg0.Shape[:axis_offset]
4199            t runs through arg1.Shape[offset:]            t runs through arg1.Shape[axis_offset:]
4200    
4201      In the first case the the second dimension of arg0 and the length of arg1 must match and        In the first case the the second dimension of arg0 and the length of arg1 must match and  
4202      in the second case the two last dimensions of arg0 must match the shape of arg1.      in the second case the two last dimensions of arg0 must match the shape of arg1.
# Line 3467  def generalTensorProduct(arg0,arg1,offse Line 4213  def generalTensorProduct(arg0,arg1,offse
4213      # at this stage arg0 and arg0 are both numarray.NumArray or escript.Data or Symbols      # at this stage arg0 and arg0 are both numarray.NumArray or escript.Data or Symbols
4214      if isinstance(arg0,numarray.NumArray):      if isinstance(arg0,numarray.NumArray):
4215         if isinstance(arg1,Symbol):         if isinstance(arg1,Symbol):
4216             return GeneralTensorProduct_Symbol(arg0,arg1,offset)             return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
4217         else:         else:
4218             if not arg0.shape[arg0.rank-offset:]==arg1.shape[:offset]:             if not arg0.shape[arg0.rank-axis_offset:]==arg1.shape[:axis_offset]:
4219                 raise ValueError,"generalTensorProduct: dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)                 raise ValueError,"generalTensorProduct: dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
4220             arg0_c=arg0.copy()             arg0_c=arg0.copy()
4221             arg1_c=arg1.copy()             arg1_c=arg1.copy()
4222             sh0,sh1=arg0.shape,arg1.shape             sh0,sh1=arg0.shape,arg1.shape
4223             d0,d1,d01=1,1,1             d0,d1,d01=1,1,1
4224             for i in sh0[:arg0.rank-offset]: d0*=i             for i in sh0[:arg0.rank-axis_offset]: d0*=i
4225             for i in sh1[offset:]: d1*=i             for i in sh1[axis_offset:]: d1*=i
4226             for i in sh1[:offset]: d01*=i             for i in sh1[:axis_offset]: d01*=i
4227             arg0_c.resize((d0,d01))             arg0_c.resize((d0,d01))
4228             arg1_c.resize((d01,d1))             arg1_c.resize((d01,d1))
4229             out=numarray.zeros((d0,d1),numarray.Float)             out=numarray.zeros((d0,d1),numarray.Float64)
4230             for i0 in range(d0):             for i0 in range(d0):
4231                      for i1 in range(d1):                      for i1 in range(d1):
4232                           out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1])                           out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1])
4233             out.resize(sh0[:arg0.rank-offset]+sh1[offset:])             out.resize(sh0[:arg0.rank-axis_offset]+sh1[axis_offset:])
4234             return out             return out
4235      elif isinstance(arg0,escript.Data):      elif isinstance(arg0,escript.Data):
4236         if isinstance(arg1,Symbol):         if isinstance(arg1,Symbol):
4237             return GeneralTensorProduct_Symbol(arg0,arg1,offset)             return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
4238         else:         else:
4239             return escript_generalTensorProduct(arg0,arg1,offset) # this calls has to be replaced by escript._generalTensorProduct(arg0,arg1,offset)             return escript_generalTensorProduct(arg0,arg1,axis_offset) # this calls has to be replaced by escript._generalTensorProduct(arg0,arg1,axis_offset)
4240      else:            else:      
4241         return GeneralTensorProduct_Symbol(arg0,arg1,offset)         return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset)
4242                                    
4243  class GeneralTensorProduct_Symbol(DependendSymbol):  class GeneralTensorProduct_Symbol(DependendSymbol):
4244     """     """
4245     Symbol representing the quotient of two arguments.     Symbol representing the quotient of two arguments.
4246     """     """
4247     def __init__(self,arg0,arg1,offset=0):     def __init__(self,arg0,arg1,axis_offset=0):
4248         """         """
4249         initialization of L{Symbol} representing the quotient of two arguments         initialization of L{Symbol} representing the quotient of two arguments
4250    
# Line 3511  class GeneralTensorProduct_Symbol(Depend Line 4257  class GeneralTensorProduct_Symbol(Depend
4257         """         """
4258         sh_arg0=pokeShape(arg0)         sh_arg0=pokeShape(arg0)
4259         sh_arg1=pokeShape(arg1)         sh_arg1=pokeShape(arg1)
4260         sh0=sh_arg0[:len(sh_arg0)-offset]         sh0=sh_arg0[:len(sh_arg0)-axis_offset]
4261         sh01=sh_arg0[len(sh_arg0)-offset:]         sh01=sh_arg0[len(sh_arg0)-axis_offset:]
4262         sh10=sh_arg1[:offset]         sh10=sh_arg1[:axis_offset]
4263         sh1=sh_arg1[offset:]         sh1=sh_arg1[axis_offset:]
4264         if not sh01==sh10:         if not sh01==sh10:
4265             raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)             raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
4266         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0+sh1,args=[arg0,arg1,offset])         DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0+sh1,args=[arg0,arg1,axis_offset])
4267    
4268     def getMyCode(self,argstrs,format="escript"):     def getMyCode(self,argstrs,format="escript"):
4269        """        """
# Line 3532  class GeneralTensorProduct_Symbol(Depend Line 4278  class GeneralTensorProduct_Symbol(Depend
4278        @raise: NotImplementedError: if the requested format is not available        @raise: NotImplementedError: if the requested format is not available
4279        """        """
4280        if format=="escript" or format=="str" or format=="text":        if format=="escript" or format=="str" or format=="text":
4281           return "generalTensorProduct(%s,%s,offset=%s)"%(argstrs[0],argstrs[1],argstrs[2])           return "generalTensorProduct(%s,%s,axis_offset=%s)"%(argstrs[0],argstrs[1],argstrs[2])
4282        else:        else:
4283           raise NotImplementedError,"%s does not provide program code for format %s."%(str(self),format)           raise NotImplementedError,"%s does not provide program code for format %s."%(str(self),format)
4284    
# Line 3557  class GeneralTensorProduct_Symbol(Depend Line 4303  class GeneralTensorProduct_Symbol(Depend
4303           args=self.getSubstitutedArguments(argvals)           args=self.getSubstitutedArguments(argvals)
4304           return generalTensorProduct(args[0],args[1],args[2])           return generalTensorProduct(args[0],args[1],args[2])
4305    
4306  def escript_generalTensorProduct(arg0,arg1,offset): # this should be escript._generalTensorProduct  def escript_generalTensorProduct(arg0,arg1,axis_offset): # this should be escript._generalTensorProduct
4307      "arg0 and arg1 are both Data objects but not neccesrily on the same function space. they could be identical!!!"      "arg0 and arg1 are both Data objects but not neccesrily on the same function space. they could be identical!!!"
4308      # calculate the return shape:      # calculate the return shape:
4309      shape0=arg0.getShape()[:arg0.getRank()-offset]      shape0=arg0.getShape()[:arg0.getRank()-axis_offset]
4310      shape01=arg0.getShape()[arg0.getRank()-offset:]      shape01=arg0.getShape()[arg0.getRank()-axis_offset:]
4311      shape10=arg1.getShape()[:offset]      shape10=arg1.getShape()[:axis_offset]
4312      shape1=arg1.getShape()[offset:]      shape1=arg1.getShape()[axis_offset:]
4313      if not shape01==shape10:      if not shape01==shape10:
4314          raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset)          raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset)
4315    
4316        # whatr function space should be used? (this here is not good!)
4317        fs=(escript.Scalar(0.,arg0.getFunctionSpace())+escript.Scalar(0.,arg1.getFunctionSpace())).getFunctionSpace()
4318      # create return value:      # create return value:
4319      out=escript.Data(0.,tuple(shape0+shape1),arg0.getFunctionSpace())      out=escript.Data(0.,tuple(shape0+shape1),fs)
4320      #      #
4321      s0=[[]]      s0=[[]]
4322      for k in shape0:      for k in shape0:
# Line 3591  def escript_generalTensorProduct(arg0,ar Line 4339  def escript_generalTensorProduct(arg0,ar
4339    
4340      for i0 in s0:      for i0 in s0:
4341         for i1 in s1:         for i1 in s1:
4342           s=escript.Scalar(0.,arg0.getFunctionSpace())           s=escript.Scalar(0.,fs)
4343           for i01 in s01:           for i01 in s01:
4344              s+=arg0.__getitem__(tuple(i0+i01))*arg1.__getitem__(tuple(i01+i1))              s+=arg0.__getitem__(tuple(i0+i01))*arg1.__getitem__(tuple(i01+i1))
4345           out.__setitem__(tuple(i0+i1),s)           out.__setitem__(tuple(i0+i1),s)
4346      return out      return out
4347    
4348    
4349  #=========================================================  #=========================================================
4350  #   some little helpers  #  functions dealing with spatial dependency
4351  #=========================================================  #=========================================================
4352  def grad(arg,where=None):  def grad(arg,where=None):
4353      """      """
4354      Returns the spatial gradient of arg at where.      Returns the spatial gradient of arg at where.
4355    
4356        If C{g} is the returned object, then
4357    
4358          - if C{arg} is rank 0 C{g[s]} is the derivative of C{arg} with respect to the C{s}-th spatial dimension.
4359          - if C{arg} is rank 1 C{g[i,s]} is the derivative of C{arg[i]} with respect to the C{s}-th spatial dimension.
4360          - if C{arg} is rank 2 C{g[i,j,s]} is the derivative of C{arg[i,j]} with respect to the C{s}-th spatial dimension.
4361          - if C{arg} is rank 3 C{g[i,j,k,s]} is the derivative of C{arg[i,j,k]} with respect to the C{s}-th spatial dimension.
4362    
4363      @param arg:   Data object representing the function which gradient      @param arg: function which gradient to be calculated. Its rank has to be less than 3.
4364                    to be calculated.      @type arg: L{escript.Data} or L{Symbol}
4365      @param where: FunctionSpace in which the gradient will be calculated.      @param where: FunctionSpace in which the gradient will be calculated.
4366                    If not present or C{None} an appropriate default is used.                    If not present or C{None} an appropriate default is used.
4367        @type where: C{None} or L{escript.FunctionSpace}
4368        @return: gradient of arg.
4369        @rtype:  L{escript.Data} or L{Symbol}
4370      """      """
4371      if isinstance(arg,Symbol):      if isinstance(arg,Symbol):
4372         return Grad_Symbol(arg,where)         return Grad_Symbol(arg,where)
# Line 3617  def grad(arg,where=None): Line 4376  def grad(arg,where=None):
4376         else:         else:
4377            return arg._grad(where)            return arg._grad(where)
4378      else:      else:
4379        raise TypeError,"grad: Unknown argument type."         raise TypeError,"grad: Unknown argument type."
4380    
4381    class Grad_Symbol(DependendSymbol):
4382       """
4383       L{Symbol} representing the result of the gradient operator
4384       """
4385       def __init__(self,arg,where=None):
4386          """
4387          initialization of gradient L{Symbol} with argument arg
4388          @param arg: argument of function
4389          @type arg: L{Symbol}.
4390          @param where: FunctionSpace in which the gradient will be calculated.
4391                      If not present or C{None} an appropriate default is used.
4392          @type where: C{None} or L{escript.FunctionSpace}
4393          """
4394          d=arg.getDim()
4395          if d==None:
4396             raise ValueError,"argument must have a spatial dimension"
4397          super(Grad_Symbol,self).__init__(args=[arg,where],shape=arg.getShape()+(d,),dim=d)
4398    
4399       def getMyCode(self,argstrs,format="escript"):
4400          """
4401          returns a program code that can be used to evaluate the symbol.
4402    
4403          @param argstrs: gives for each argument a string representing the argument for the evaluation.
4404          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
4405          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
4406          @type format: C{str}
4407          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
4408          @rtype: C{str}
4409          @raise: NotImplementedError: if the requested format is not available
4410          """
4411          if format=="escript" or format=="str"  or format=="text":
4412             return "grad(%s,where=%s)"%(argstrs[0],argstrs[1])
4413          else:
4414             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
4415    
4416       def substitute(self,argvals):
4417          """
4418          assigns new values to symbols in the definition of the symbol.
4419          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
4420    
4421          @param argvals: new values assigned to symbols
4422          @type argvals: C{dict} with keywords of type L{Symbol}.
4423          @return: result of the substitution process. Operations are executed as much as possible.
4424          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
4425          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
4426          """
4427          if argvals.has_key(self):
4428             arg=argvals[self]
4429             if self.isAppropriateValue(arg):
4430                return arg
4431             else:
4432                raise TypeError,"%s: new value is not appropriate."%str(self)
4433          else:
4434             arg=self.getSubstitutedArguments(argvals)
4435             return grad(arg[0],where=arg[1])
4436    
4437       def diff(self,arg):
4438          """
4439          differential of this object
4440    
4441          @param arg: the derivative is calculated with respect to arg
4442          @type arg: L{escript.Symbol}
4443          @return: derivative with respect to C{arg}
4444          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
4445          """
4446          if arg==self:
4447             return identity(self.getShape())
4448          else:
4449             return grad(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
4450    
4451  def integrate(arg,where=None):  def integrate(arg,where=None):
4452      """      """
4453      Return the integral if the function represented by Data object arg over      Return the integral of the function C{arg} over its domain. If C{where} is present C{arg} is interpolated to C{where}
4454      its domain.      before integration.
4455    
4456      @param arg:   Data object representing the function which is integrated.      @param arg:   the function which is integrated.
4457        @type arg: L{escript.Data} or L{Symbol}
4458      @param where: FunctionSpace in which the integral is calculated.      @param where: FunctionSpace in which the integral is calculated.
4459                    If not present or C{None} an appropriate default is used.                    If not present or C{None} an appropriate default is used.
4460        @type where: C{None} or L{escript.FunctionSpace}
4461        @return: integral of arg.
4462        @rtype:  C{float}, C{numarray.NumArray} or L{Symbol}
4463      """      """
4464      if isinstance(arg,numarray.NumArray):      if isinstance(arg,Symbol):
         if checkForZero(arg):  
            return arg  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,float):  
         if checkForZero(arg):  
            return arg  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,int):  
         if checkForZero(arg):  
            return float(arg)  
         else:  
            raise TypeError,"integrate: cannot intergrate argument"  
     elif isinstance(arg,Symbol):  
4465         return Integrate_Symbol(arg,where)         return Integrate_Symbol(arg,where)
4466      elif isinstance(arg,escript.Data):      elif isinstance(arg,escript.Data):
4467         if not where==None: arg=escript.Data(arg,where)         if not where==None: arg=escript.Data(arg,where)
# Line 3654  def integrate(arg,where=None): Line 4472  def integrate(arg,where=None):
4472      else:      else:
4473        raise TypeError,"integrate: Unknown argument type."        raise TypeError,"integrate: Unknown argument type."
4474    
4475    class Integrate_Symbol(DependendSymbol):
4476       """
4477       L{Symbol} representing the result of the spatial integration operator
4478       """
4479       def __init__(self,arg,where=None):
4480          """
4481          initialization of integration L{Symbol} with argument arg
4482          @param arg: argument of the integration
4483          @type arg: L{Symbol}.
4484          @param where: FunctionSpace in which the integration will be calculated.
4485                      If not present or C{None} an appropriate default is used.
4486          @type where: C{None} or L{escript.FunctionSpace}
4487          """
4488          super(Integrate_Symbol,self).__init__(args=[arg,where],shape=arg.getShape(),dim=arg.getDim())
4489    
4490       def getMyCode(self,argstrs,format="escript"):
4491          """
4492          returns a program code that can be used to evaluate the symbol.
4493    
4494          @param argstrs: gives for each argument a string representing the argument for the evaluation.
4495          @type argstrs: C{str} or a C{list} of length 1 of C{str}.
4496          @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
4497          @type format: C{str}
4498          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
4499          @rtype: C{str}
4500          @raise: NotImplementedError: if the requested format is not available
4501          """
4502          if format=="escript" or format=="str"  or format=="text":
4503             return "integrate(%s,where=%s)"%(argstrs[0],argstrs[1])
4504          else:
4505             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
4506    
4507       def substitute(self,argvals):
4508          """
4509          assigns new values to symbols in the definition of the symbol.
4510          The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
4511    
4512          @param argvals: new values assigned to symbols
4513          @type argvals: C{dict} with keywords of type L{Symbol}.
4514          @return: result of the substitution process. Operations are executed as much as possible.
4515          @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
4516          @raise TypeError: if a value for a L{Symbol} cannot be substituted.
4517          """
4518          if argvals.has_key(self):
4519             arg=argvals[self]
4520             if self.isAppropriateValue(arg):
4521                return arg
4522             else:
4523                raise TypeError,"%s: new value is not appropriate."%str(self)
4524          else:
4525             arg=self.getSubstitutedArguments(argvals)
4526             return integrate(arg[0],where=arg[1])
4527    
4528       def diff(self,arg):
4529          """
4530          differential of this object
4531    
4532          @param arg: the derivative is calculated with respect to arg
4533          @type arg: L{escript.Symbol}
4534          @return: derivative with respect to C{arg}
4535          @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray}  are possible.
4536          """
4537          if arg==self:
4538             return identity(self.getShape())
4539          else:
4540             return integrate(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
4541    
4542    
4543  def interpolate(arg,where):  def interpolate(arg,where):
4544      """      """
4545      Interpolates the function into the FunctionSpace where.      interpolates the function into the FunctionSpace where.
4546    
4547      @param arg:    interpolant      @param arg: interpolant
4548      @param where:  FunctionSpace to interpolate to      @type arg: L{escript.Data} or L{Symbol}
4549        @param where: FunctionSpace to be interpolated to
4550        @type where: L{escript.FunctionSpace}
4551        @return: interpolated argument
4552        @rtype:  C{escript.Data} or L{Symbol}
4553      """      """
4554      if testForZero(arg):      if isinstance(arg,Symbol):
4555        return 0         return Interpolate_Symbol(arg,where)
     elif isinstance(arg,Symbol):  
        return Interpolated_Symbol(arg,where)  
4556      else:      else:
4557         return escript.Data(arg,where)         return escript.Data(arg,where)
4558    
4559  def div(arg,where=None):  class Interpolate_Symbol(DependendSymbol):
4560      """     """
4561      Returns the divergence of arg at where.     L{Symbol} representing the result of the interpolation operator
4562       """
4563       def __init__(self,arg,where):
4564          """
4565          initialization of interpolation L{Symbol} with argument arg
4566          @param arg: argument of the interpolation
4567          @type arg: L{Symbol}.
4568          @param where: FunctionSpace into which the argument is interpolated.
4569          @type where: L{escript.FunctionSpace}
4570          """
4571          super(Interpolate_Symbol,self).__init__(args=[arg,where],shape=arg.getShape(),dim=arg.getDim())
4572    
4573      @param arg:   Data object representing the function which gradient to     def getMyCode(self,argstrs,format="escript"):
4574                    be calculated.        """
4575      @param where: FunctionSpace in which the gradient will be calculated.        returns a program code that can be used to evaluate the symbol.
                   If not present or C{None} an appropriate default is used.  
     """  
     g=grad(arg,where)  
     return trace(g,axis0=g.getRank()-2,axis1=g.getRank()-1)  
4576    
4577  def jump(arg):        @param argstrs: gives for each argument a string representing the argument for the evaluation.
4578      """        @type argstrs: C{str} or a C{list} of length 1 of C{str}.
4579      Returns the jump of arg across a continuity.        @param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported.
4580          @type format: C{str}
4581          @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.
4582          @rtype: C{str}
4583          @raise: NotImplementedError: if the requested format is not available
4584          """
4585          if format=="escript" or format=="str"  or format=="text":
4586             return "interpolate(%s,where=%s)"%(argstrs[0],argstrs[1])
4587          else:
4588             raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format
4589    
4590      @param arg:   Data object representing the function which gradient     def substitute(self,argvals):
4591                    to be calculated.        """
4592      """        assigns new values to symbols in the definition of the symbol.
4593      d=arg.getDomain()        The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.
     return arg.interpolate(escript.FunctionOnContactOne())-arg.interpolate(escript.FunctionOnContactZero())  
4594    
4595  #=============================        @param argvals: new values assigned to symbols
4596  #        @type argvals: C{dict} with keywords of type L{Symbol}.
4597  # wrapper for various functions: if the argument has attribute the function name        @return: result of the substitution process. Operations are executed as much as possible.
4598  # as an argument it calls the corresponding methods. Otherwise the corresponding        @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution
4599  # numarray function is called.        @raise TypeError: if a value for a L{Symbol} cannot be substituted.
4600          """
4601          if argvals.has_key(self):
4602             arg=argvals[self]
4603             if self.isAppropriateValue(arg):
4604                return arg
4605             else:
4606                raise TypeError,"%s: new value is not appropriate."%str(self)
4607          else:
4608             arg=self.getSubstitutedArguments(argvals)
4609             return interpolate(arg[0],where=arg[1])
4610    
4611  # functions involving the underlying Domain:     def diff(self,arg):
4612          """
4613          differential of this object
4614    
4615          @param arg: the derivative is calculated with respect to arg
4616          @type arg: L{escript.Symbol}
4617          @return: derivative with respect to C{arg}
4618          @rtype: L{Symbol} but other types such as L{escript.Data}, L{numarray.NumArray}  are possible.
4619          """
4620          if arg==self:
4621             return identity(self.getShape())
4622          else:
4623             return interpolate(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1])
4624    
4625  def transpose(arg,axis=None):  
4626    def div(arg,where=None):
4627      """      """
4628      Returns the transpose of the Data object arg.      returns the divergence of arg at where.
4629    
4630      @param arg:      @param arg: function which divergence to be calculated. Its shape has to be (d,) where d is the spatial dimension.
4631        @type arg: L{escript.Data} or L{Symbol}
4632        @param where: FunctionSpace in which the divergence will be calculated.
4633                      If not present or C{None} an appropriate default is used.
4634        @type where: C{None} or L{escript.FunctionSpace}
4635        @return: divergence of arg.
4636        @rtype:  L{escript.Data} or L{Symbol}
4637      """      """
     if axis==None:  
        r=0  
        if hasattr(arg,"getRank"): r=arg.getRank()  
        if hasattr(arg,"rank"): r=arg.rank  
        axis=r/2  
4638      if isinstance(arg,Symbol):      if isinstance(arg,Symbol):
4639         return Transpose_Symbol(arg,axis=r)          dim=arg.getDim()
4640      if isinstance(arg,escript.Data):      elif isinstance(arg,escript.Data):
4641         # hack for transpose          dim=arg.getDomain().getDim()
        r=arg.getRank()  
        if r!=2: raise ValueError,"Tranpose only avalaible for rank 2 objects"  
        s=arg.getShape()  
        out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())  
        for i in range(s[0]):  
           for j in range(s[1]):  
              out[j,i]=arg[i,j]  
        return out  
        # end hack for transpose  
        return arg.transpose(axis)  
4642      else:      else:
4643         return numarray.transpose(arg,axis=axis)          raise TypeError,"div: argument type not supported"
4644        if not arg.getShape()==(dim,):
4645          raise ValueError,"div: expected shape is (%s,)"%dim
4646        return trace(grad(arg,where))
4647    
4648  def trace(arg,axis0=0,axis1=1):  def jump(arg,domain=None):
4649      """      """
4650      Return      returns the jump of arg across the continuity of the domain
4651    
4652      @param arg:      @param arg: argument
4653        @type arg: L{escript.Data} or L{Symbol}
4654        @param domain: the domain where the discontinuity is located. If domain is not present or equal to C{None}
4655                       the domain of arg is used. If arg is a L{Symbol} the domain must be present.
4656        @type domain: C{None} or L{escript.Domain}
4657        @return: jump of arg
4658        @rtype:  L{escript.Data} or L{Symbol}
4659      """      """
4660      if isinstance(arg,Symbol):      if domain==None: domain=arg.getDomain()
4661         s=list(arg.getShape())              return interpolate(arg,escript.FunctionOnContactOne(domain))-interpolate(arg,escript.FunctionOnContactZero(domain))
        s=tuple(s[0:axis0]+s[axis0+1:axis1]+s[axis1+1:])  
        return Trace_Symbol(arg,axis0=axis0,axis1=axis1)  
     elif isinstance(arg,escript.Data):  
        # hack for trace  
        s=arg.getShape()  
        if s[axis0]!=s[axis1]:  
            raise ValueError,"illegal axis in trace"  
        out=escript.Scalar(0.,arg.getFunctionSpace())  
        for i in range(s[axis0]):  
           out+=arg[i,i]  
        return out  
        # end hack for trace  
     else:  
        return numarray.trace(arg,axis0=axis0,axis1=axis1)  
4662    
4663    def L2(arg):
4664        """
4665        returns the L2 norm of arg at where
4666        
4667        @param arg: function which L2 to be calculated.
4668        @type arg: L{escript.Data} or L{Symbol}
4669        @return: L2 norm of arg.
4670        @rtype:  L{float} or L{Symbol}
4671        @note: L2(arg) is equivalent to sqrt(integrate(inner(arg,arg)))
4672        """
4673        return sqrt(integrate(inner(arg,arg)))
4674    #=============================
4675    #
4676    
4677  def reorderComponents(arg,index):  def reorderComponents(arg,index):
4678      """      """
4679      resorts the component of arg according to index      resorts the component of arg according to index
4680    
4681      """      """
4682      pass      raise NotImplementedError
4683  #  #
4684  # $Log: util.py,v $  # $Log: util.py,v $
4685  # Revision 1.14.2.16  2005/10/19 06:09:57  gross  # Revision 1.14.2.16  2005/10/19 06:09:57  gross

Legend:
Removed from v.341  
changed lines
  Added in v.698

  ViewVC Help
Powered by ViewVC 1.1.26