/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Diff of /trunk/escript/py_src/util.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 573 by gross, Wed Feb 22 02:14:38 2006 UTC revision 574 by gross, Thu Mar 2 06:31:24 2006 UTC
# Line 133  def identity(shape=()): Line 133  def identity(shape=()):
133     @raise ValueError: if len(shape)>2.     @raise ValueError: if len(shape)>2.
134     """     """
135     if len(shape)>0:     if len(shape)>0:
136        out=numarray.zeros(shape+shape,numarray.Float)        out=numarray.zeros(shape+shape,numarray.Float64)
137        if len(shape)==1:        if len(shape)==1:
138            for i0 in range(shape[0]):            for i0 in range(shape[0]):
139               out[i0,i0]=1.               out[i0,i0]=1.
# Line 389  def matchType(arg0=0.,arg1=0.): Line 389  def matchType(arg0=0.,arg1=0.):
389         elif isinstance(arg1,escript.Data):         elif isinstance(arg1,escript.Data):
390            arg0=escript.Data(arg0,arg1.getFunctionSpace())            arg0=escript.Data(arg0,arg1.getFunctionSpace())
391         elif isinstance(arg1,float):         elif isinstance(arg1,float):
392            arg1=numarray.array(arg1)            arg1=numarray.array(arg1,type=numarray.Float64)
393         elif isinstance(arg1,int):         elif isinstance(arg1,int):
394            arg1=numarray.array(float(arg1))            arg1=numarray.array(float(arg1),type=numarray.Float64)
395         elif isinstance(arg1,Symbol):         elif isinstance(arg1,Symbol):
396            pass            pass
397         else:         else:
# Line 415  def matchType(arg0=0.,arg1=0.): Line 415  def matchType(arg0=0.,arg1=0.):
415         elif isinstance(arg1,escript.Data):         elif isinstance(arg1,escript.Data):
416            pass            pass
417         elif isinstance(arg1,float):         elif isinstance(arg1,float):
418            arg1=numarray.array(arg1)            arg1=numarray.array(arg1,type=numarray.Float64)
419         elif isinstance(arg1,int):         elif isinstance(arg1,int):
420            arg1=numarray.array(float(arg1))            arg1=numarray.array(float(arg1),type=numarray.Float64)
421         elif isinstance(arg1,Symbol):         elif isinstance(arg1,Symbol):
422            pass            pass
423         else:         else:
424            raise TypeError,"function: Unknown type of second argument."                raise TypeError,"function: Unknown type of second argument."    
425      elif isinstance(arg0,float):      elif isinstance(arg0,float):
426         if isinstance(arg1,numarray.NumArray):         if isinstance(arg1,numarray.NumArray):
427            arg0=numarray.array(arg0)            arg0=numarray.array(arg0,type=numarray.Float64)
428         elif isinstance(arg1,escript.Data):         elif isinstance(arg1,escript.Data):
429            arg0=escript.Data(arg0,arg1.getFunctionSpace())            arg0=escript.Data(arg0,arg1.getFunctionSpace())
430         elif isinstance(arg1,float):         elif isinstance(arg1,float):
431            arg0=numarray.array(arg0)            arg0=numarray.array(arg0,type=numarray.Float64)
432            arg1=numarray.array(arg1)            arg1=numarray.array(arg1,type=numarray.Float64)
433         elif isinstance(arg1,int):         elif isinstance(arg1,int):
434            arg0=numarray.array(arg0)            arg0=numarray.array(arg0,type=numarray.Float64)
435            arg1=numarray.array(float(arg1))            arg1=numarray.array(float(arg1),type=numarray.Float64)
436         elif isinstance(arg1,Symbol):         elif isinstance(arg1,Symbol):
437            arg0=numarray.array(arg0)            arg0=numarray.array(arg0,type=numarray.Float64)
438         else:         else:
439            raise TypeError,"function: Unknown type of second argument."                raise TypeError,"function: Unknown type of second argument."    
440      elif isinstance(arg0,int):      elif isinstance(arg0,int):
441         if isinstance(arg1,numarray.NumArray):         if isinstance(arg1,numarray.NumArray):
442            arg0=numarray.array(float(arg0))            arg0=numarray.array(float(arg0),type=numarray.Float64)
443         elif isinstance(arg1,escript.Data):         elif isinstance(arg1,escript.Data):
444            arg0=escript.Data(float(arg0),arg1.getFunctionSpace())            arg0=escript.Data(float(arg0),arg1.getFunctionSpace())
445         elif isinstance(arg1,float):         elif isinstance(arg1,float):
446            arg0=numarray.array(float(arg0))            arg0=numarray.array(float(arg0),type=numarray.Float64)
447            arg1=numarray.array(arg1)            arg1=numarray.array(arg1,type=numarray.Float64)
448         elif isinstance(arg1,int):         elif isinstance(arg1,int):
449            arg0=numarray.array(float(arg0))            arg0=numarray.array(float(arg0),type=numarray.Float64)
450            arg1=numarray.array(float(arg1))            arg1=numarray.array(float(arg1),type=numarray.Float64)
451         elif isinstance(arg1,Symbol):         elif isinstance(arg1,Symbol):
452            arg0=numarray.array(float(arg0))            arg0=numarray.array(float(arg0),type=numarray.Float64)
453         else:         else:
454            raise TypeError,"function: Unknown type of second argument."                raise TypeError,"function: Unknown type of second argument."    
455      else:      else:
# Line 472  def matchShape(arg0,arg1): Line 472  def matchShape(arg0,arg1):
472      sh0=pokeShape(arg0)      sh0=pokeShape(arg0)
473      sh1=pokeShape(arg1)      sh1=pokeShape(arg1)
474      if len(sh0)<len(sh):      if len(sh0)<len(sh):
475         return outer(arg0,numarray.ones(sh[len(sh0):],numarray.Float)),arg1         return outer(arg0,numarray.ones(sh[len(sh0):],numarray.Float64)),arg1
476      elif len(sh1)<len(sh):      elif len(sh1)<len(sh):
477         return arg0,outer(arg1,numarray.ones(sh[len(sh1):],numarray.Float))         return arg0,outer(arg1,numarray.ones(sh[len(sh1):],numarray.Float64))
478      else:      else:
479         return arg0,arg1         return arg0,arg1
480  #=========================================================  #=========================================================
# Line 600  class Symbol(object): Line 600  class Symbol(object):
600            else:            else:
601                s=pokeShape(s)+arg.getShape()                s=pokeShape(s)+arg.getShape()
602                if len(s)>0:                if len(s)>0:
603                   out.append(numarray.zeros(s),numarray.Float)                   out.append(numarray.zeros(s),numarray.Float64)
604                else:                else:
605                   out.append(a)                   out.append(a)
606         return out         return out
# Line 690  class Symbol(object): Line 690  class Symbol(object):
690         else:         else:
691            s=self.getShape()+arg.getShape()            s=self.getShape()+arg.getShape()
692            if len(s)>0:            if len(s)>0:
693               return numarray.zeros(s,numarray.Float)               return numarray.zeros(s,numarray.Float64)
694            else:            else:
695               return 0.               return 0.
696    
# Line 1009  def wherePositive(arg): Line 1009  def wherePositive(arg):
1009     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1010     """     """
1011     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1012        out=numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))*1.        out=numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float64))*1.
1013        if isinstance(out,float): out=numarray.array(out)        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1014        return out        return out
1015     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1016        return arg._wherePositive()        return arg._wherePositive()
# Line 1091  def whereNegative(arg): Line 1091  def whereNegative(arg):
1091     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1092     """     """
1093     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1094        out=numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))*1.        out=numarray.less(arg,numarray.zeros(arg.shape,numarray.Float64))*1.
1095        if isinstance(out,float): out=numarray.array(out)        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1096        return out        return out
1097     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1098        return arg._whereNegative()        return arg._whereNegative()
# Line 1173  def whereNonNegative(arg): Line 1173  def whereNonNegative(arg):
1173     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1174     """     """
1175     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1176        out=numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.        out=numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float64))*1.
1177        if isinstance(out,float): out=numarray.array(out)        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1178        return out        return out
1179     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1180        return arg._whereNonNegative()        return arg._whereNonNegative()
# Line 1203  def whereNonPositive(arg): Line 1203  def whereNonPositive(arg):
1203     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1204     """     """
1205     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1206        out=numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float))*1.        out=numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float64))*1.
1207        if isinstance(out,float): out=numarray.array(out)        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1208        return out        return out
1209     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1210        return arg._whereNonPositive()        return arg._whereNonPositive()
# Line 1235  def whereZero(arg,tol=0.): Line 1235  def whereZero(arg,tol=0.):
1235     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1236     """     """
1237     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1238        out=numarray.less_equal(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.        out=numarray.less_equal(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float64))*1.
1239        if isinstance(out,float): out=numarray.array(out)        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1240        return out        return out
1241     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1242        if tol>0.:        if tol>0.:
# Line 1318  def whereNonZero(arg,tol=0.): Line 1318  def whereNonZero(arg,tol=0.):
1318     @raises TypeError: if the type of the argument is not expected.     @raises TypeError: if the type of the argument is not expected.
1319     """     """
1320     if isinstance(arg,numarray.NumArray):     if isinstance(arg,numarray.NumArray):
1321        out=numarray.greater(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float))*1.        out=numarray.greater(abs(arg)-tol,numarray.zeros(arg.shape,numarray.Float64))*1.
1322        if isinstance(out,float): out=numarray.array(out)        if isinstance(out,float): out=numarray.array(out,type=numarray.Float64)
1323        return out        return out
1324     elif isinstance(arg,escript.Data):     elif isinstance(arg,escript.Data):
1325        if tol>0.:        if tol>0.:
# Line 2980  def trace(arg,axis_offset=0): Line 2980  def trace(arg,axis_offset=0):
2980        if not sh[axis_offset] == sh[axis_offset+1]:        if not sh[axis_offset] == sh[axis_offset+1]:
2981          raise ValueError,"trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)          raise ValueError,"trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1)
2982        arg_reshaped=numarray.reshape(arg,(s1,sh[axis_offset],sh[axis_offset],s2))        arg_reshaped=numarray.reshape(arg,(s1,sh[axis_offset],sh[axis_offset],s2))
2983        out=numarray.zeros([s1,s2],numarray.Float)        out=numarray.zeros([s1,s2],numarray.Float64)
2984        for i1 in range(s1):        for i1 in range(s1):
2985          for i2 in range(s2):          for i2 in range(s2):
2986              for j in range(sh[axis_offset]): out[i1,i2]+=arg_reshaped[i1,j,j,i2]              for j in range(sh[axis_offset]): out[i1,i2]+=arg_reshaped[i1,j,j,i2]
# Line 3557  def eigenvalues(arg): Line 3557  def eigenvalues(arg):
3557        s=arg.shape              s=arg.shape      
3558        if not s[0] == s[1]:        if not s[0] == s[1]:
3559          raise ValueError,"eigenvalues: argument must be a square matrix."          raise ValueError,"eigenvalues: argument must be a square matrix."
3560        out=numarray.linear_algebra.eigenvalues((arg+numarray.transpose(arg))/2.)        if s[0]==1:
3561        out.sort()            out=arg[0]
3562          elif s[0]==2:
3563              A11=arg[0,0]
3564              A12=arg[0,1]
3565              A22=arg[1,1]
3566              trA=(A11+A22)/2.
3567              A11-=trA
3568              A22-=trA
3569              s=sqrt(A12**2-A11*A22)
3570              out=trA+numarray.array([-s,s],type=numarray.Float64)
3571          elif s[0]==3:
3572              A11=arg[0,0]
3573              A12=arg[0,1]
3574              A22=arg[1,1]
3575              A13=arg[0,2]
3576              A23=arg[1,2]
3577              A33=arg[2,2]
3578              trA=(A11+A22+A33)/3.
3579              A11-=trA
3580              A22-=trA
3581              A33-=trA
3582              A13_2=A13**2
3583              A23_2=A23**2
3584              A12_2=A12**2
3585              p=A13_2+A23_2+A12_2+(A11**2+A22**2+A33**2)/2.
3586              q=A13_2*A22+A23_2*A11+A12_2*A33-A11*A22*A33-2*A12*A23*A13
3587              sq_p=sqrt(p/3.)
3588              alpha_3=acos(clip(-q*sq_p**(-3.)/2.,-1.,1.))/3.
3589              sq_p*=2.
3590              out=trA+sq_p*numarray.array([-cos(alpha_3-numarray.pi/3.),-cos(alpha_3+numarray.pi/3.),cos(alpha_3)],type=numarray.Float64)
3591          else:
3592             out=numarray.linear_algebra.eigenvalues((arg+numarray.transpose(arg))/2.)
3593             out.sort()
3594        return out        return out
3595      elif isinstance(arg,escript.Data):      elif isinstance(arg,escript.Data):
3596        return escript_eigenvalues(arg)        return escript_eigenvalues(arg)
# Line 3578  def eigenvalues(arg): Line 3610  def eigenvalues(arg):
3610            A11-=trA            A11-=trA
3611            A22-=trA            A22-=trA
3612            s=sqrt(A12**2-A11*A22)            s=sqrt(A12**2-A11*A22)
3613            return trA+s*numarray.array([-1.,1.])            return trA+s*numarray.array([-1.,1.],type=numarray.Float64)
3614        elif s[0]==3:        elif s[0]==3:
3615            A11=arg[0,0]            A11=arg[0,0]
3616            A12=arg[0,1]            A12=arg[0,1]
# Line 3596  def eigenvalues(arg): Line 3628  def eigenvalues(arg):
3628            p=A13_2+A23_2+A12_2+(A11**2+A22**2+A33**2)/2.            p=A13_2+A23_2+A12_2+(A11**2+A22**2+A33**2)/2.
3629            q=A13_2*A22+A23_2*A11+A12_2*A33-A11*A22*A33-2*A12*A23*A13            q=A13_2*A22+A23_2*A11+A12_2*A33-A11*A22*A33-2*A12*A23*A13
3630            sq_p=sqrt(p/3.)            sq_p=sqrt(p/3.)
3631            alpha_3=acos(-q*sq_p**(-3.)/2.)/3.            alpha_3=acos(clip(-q*sq_p**(-3.)/2.,-1.,1.))/3.
3632            sq_p*=2.            sq_p*=2.
3633            f=cos(alpha_3)               *numarray.array([0.,0.,1.]) \            f=cos(alpha_3)               *numarray.array([0.,0.,1.],type=numarray.Float64) \
3634             -cos(alpha_3+numarray.pi/3.)*numarray.array([0.,1.,0.]) \             -cos(alpha_3+numarray.pi/3.)*numarray.array([0.,1.,0.],type=numarray.Float64) \
3635             -cos(alpha_3-numarray.pi/3.)*numarray.array([1.,0.,0.])             -cos(alpha_3-numarray.pi/3.)*numarray.array([1.,0.,0.],type=numarray.Float64)
3636            return trA+sq_p*f            return trA+sq_p*f
3637        else:        else:
3638           raise TypeError,"eigenvalues: only matrix dimensions 1,2,3 are supported right now."           raise TypeError,"eigenvalues: only matrix dimensions 1,2,3 are supported right now."
# Line 3627  def escript_eigenvalues(arg): # this sho Line 3659  def escript_eigenvalues(arg): # this sho
3659            A11-=trA            A11-=trA
3660            A22-=trA            A22-=trA
3661            s=sqrt(A12**2-A11*A22)            s=sqrt(A12**2-A11*A22)
3662            return trA+s*numarray.array([-1.,1.])            return trA+s*numarray.array([-1.,1.],type=numarray.Float64)
3663        elif s[0]==3:        elif s[0]==3:
3664            A11=arg[0,0]            A11=arg[0,0]
3665            A12=arg[0,1]            A12=arg[0,1]
# Line 3645  def escript_eigenvalues(arg): # this sho Line 3677  def escript_eigenvalues(arg): # this sho
3677            p=A13_2+A23_2+A12_2+(A11**2+A22**2+A33**2)/2.            p=A13_2+A23_2+A12_2+(A11**2+A22**2+A33**2)/2.
3678            q=A13_2*A22+A23_2*A11+A12_2*A33-A11*A22*A33-2*A12*A23*A13            q=A13_2*A22+A23_2*A11+A12_2*A33-A11*A22*A33-2*A12*A23*A13
3679            sq_p=sqrt(p/3.)            sq_p=sqrt(p/3.)
3680            alpha_3=acos(-q*sq_p**(-3.)/2.)/3.            alpha_3=acos(clip(-q*sq_p**(-3.)/2.,-1.,1.))/3.
3681            sq_p*=2.            sq_p*=2.
3682            f=escript.Data(0.,(3,),arg.getFunctionSpace())            f=escript.Data(0.,(3,),arg.getFunctionSpace())
3683            f[0]=-cos(alpha_3-numarray.pi/3.)            f[0]=-cos(alpha_3-numarray.pi/3.)
# Line 3772  def mult(arg0,arg1): Line 3804  def mult(arg0,arg1):
3804         """         """
3805         args=matchShape(arg0,arg1)         args=matchShape(arg0,arg1)
3806         if testForZero(args[0]) or testForZero(args[1]):         if testForZero(args[0]) or testForZero(args[1]):
3807            return numarray.zeros(pokeShape(args[0]),numarray.Float)            return numarray.zeros(pokeShape(args[0]),numarray.Float64)
3808         else:         else:
3809            if isinstance(args[0],Symbol) or isinstance(args[1],Symbol) :            if isinstance(args[0],Symbol) or isinstance(args[1],Symbol) :
3810                return Mult_Symbol(args[0],args[1])                return Mult_Symbol(args[0],args[1])
# Line 3872  def quotient(arg0,arg1): Line 3904  def quotient(arg0,arg1):
3904         """         """
3905         args=matchShape(arg0,arg1)         args=matchShape(arg0,arg1)
3906         if testForZero(args[0]):         if testForZero(args[0]):
3907            return numarray.zeros(pokeShape(args[0]),numarray.Float)            return numarray.zeros(pokeShape(args[0]),numarray.Float64)
3908         elif isinstance(args[0],Symbol):         elif isinstance(args[0],Symbol):
3909            if isinstance(args[1],Symbol):            if isinstance(args[1],Symbol):
3910               return Quotient_Symbol(args[0],args[1])               return Quotient_Symbol(args[0],args[1])
# Line 3978  def power(arg0,arg1): Line 4010  def power(arg0,arg1):
4010         """         """
4011         args=matchShape(arg0,arg1)         args=matchShape(arg0,arg1)
4012         if testForZero(args[0]):         if testForZero(args[0]):
4013            return numarray.zeros(args[0],numarray.Float)            return numarray.zeros(args[0],numarray.Float64)
4014         elif testForZero(args[1]):         elif testForZero(args[1]):
4015            return numarray.ones(args[0],numarray.Float)            return numarray.ones(args[0],numarray.Float64)
4016         elif isinstance(args[0],Symbol) or isinstance(args[1],Symbol):         elif isinstance(args[0],Symbol) or isinstance(args[1],Symbol):
4017            return Power_Symbol(args[0],args[1])            return Power_Symbol(args[0],args[1])
4018         elif isinstance(args[0],numarray.NumArray) and not isinstance(args[1],numarray.NumArray):         elif isinstance(args[0],numarray.NumArray) and not isinstance(args[1],numarray.NumArray):
# Line 4277  def generalTensorProduct(arg0,arg1,axis_ Line 4309  def generalTensorProduct(arg0,arg1,axis_
4309             for i in sh1[:axis_offset]: d01*=i             for i in sh1[:axis_offset]: d01*=i
4310             arg0_c.resize((d0,d01))             arg0_c.resize((d0,d01))
4311             arg1_c.resize((d01,d1))             arg1_c.resize((d01,d1))
4312             out=numarray.zeros((d0,d1),numarray.Float)             out=numarray.zeros((d0,d1),numarray.Float64)
4313             for i0 in range(d0):             for i0 in range(d0):
4314                      for i1 in range(d1):                      for i1 in range(d1):
4315                           out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1])                           out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1])

Legend:
Removed from v.573  
changed lines
  Added in v.574

  ViewVC Help
Powered by ViewVC 1.1.26