2858 |
""" |
""" |
2859 |
return sqrt(inner(arg,arg)) |
return sqrt(inner(arg,arg)) |
2860 |
|
|
2861 |
|
def trace(arg,axis_offset=0): |
2862 |
|
""" |
2863 |
|
returns the trace of arg which the sum of arg[k,k] over k. |
2864 |
|
|
2865 |
|
@param arg: argument |
2866 |
|
@type arg: L{escript.Data}, L{Symbol}, L{numarray.NumArray}. |
2867 |
|
@param axis_offset: axis_offset to components to sum over. C{axis_offset} must be non-negative and less than the rank of arg +1. The dimensions on component |
2868 |
|
axis_offset and axis_offset+1 must be equal. |
2869 |
|
@type axis_offset: C{int} |
2870 |
|
@return: trace of arg. The rank of the returned object is minus 2 of the rank of arg. |
2871 |
|
@rtype: L{escript.Data}, L{Symbol}, L{numarray.NumArray} depending on the type of arg. |
2872 |
|
""" |
2873 |
|
if isinstance(arg,numarray.NumArray): |
2874 |
|
sh=arg.shape |
2875 |
|
if len(sh)<2: |
2876 |
|
raise ValueError,"trace: rank of argument must be greater than 1" |
2877 |
|
if axis_offset<0 or axis_offset>len(sh)-2: |
2878 |
|
raise ValueError,"trace: axis_offset must be between 0 and %s"%len(sh)-2 |
2879 |
|
s1=1 |
2880 |
|
for i in range(axis_offset): s1*=sh[i] |
2881 |
|
s2=1 |
2882 |
|
for i in range(axis_offset+2,len(sh)): s2*=sh[i] |
2883 |
|
if not sh[axis_offset] == sh[axis_offset+1]: |
2884 |
|
raise ValueError,"trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1) |
2885 |
|
arg_reshaped=numarray.reshape(arg,(s1,sh[axis_offset],sh[axis_offset],s2)) |
2886 |
|
out=numarray.zeros([s1,s2],numarray.Float) |
2887 |
|
for i1 in range(s1): |
2888 |
|
for i2 in range(s2): |
2889 |
|
for j in range(sh[axis_offset]): out[i1,i2]+=arg_reshaped[i1,j,j,i2] |
2890 |
|
out.resize(sh[:axis_offset]+sh[axis_offset+2:]) |
2891 |
|
return out |
2892 |
|
elif isinstance(arg,escript.Data): |
2893 |
|
return escript_trace(arg,axis_offset) |
2894 |
|
elif isinstance(arg,float): |
2895 |
|
raise TypeError,"trace: illegal argument type float." |
2896 |
|
elif isinstance(arg,int): |
2897 |
|
raise TypeError,"trace: illegal argument type int." |
2898 |
|
elif isinstance(arg,Symbol): |
2899 |
|
return Trace_Symbol(arg,axis_offset) |
2900 |
|
else: |
2901 |
|
raise TypeError,"trace: Unknown argument type." |
2902 |
|
|
2903 |
|
def escript_trace(arg,axis_offset): # this should be escript._trace |
2904 |
|
"arg si a Data objects!!!" |
2905 |
|
if arg.getRank()<2: |
2906 |
|
raise ValueError,"escript_trace: rank of argument must be greater than 1" |
2907 |
|
if axis_offset<0 or axis_offset>arg.getRank()-2: |
2908 |
|
raise ValueError,"escript_trace: axis_offset must be between 0 and %s"%arg.getRank()-2 |
2909 |
|
s=list(arg.getShape()) |
2910 |
|
if not s[axis_offset] == s[axis_offset+1]: |
2911 |
|
raise ValueError,"escript_trace: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1) |
2912 |
|
out=escript.Data(0.,tuple(s[0:axis_offset]+s[axis_offset+2:]),arg.getFunctionSpace()) |
2913 |
|
if arg.getRank()==2: |
2914 |
|
for i0 in range(s[0]): |
2915 |
|
out+=arg[i0,i0] |
2916 |
|
elif arg.getRank()==3: |
2917 |
|
if axis_offset==0: |
2918 |
|
for i0 in range(s[0]): |
2919 |
|
for i2 in range(s[2]): |
2920 |
|
out[i2]+=arg[i0,i0,i2] |
2921 |
|
elif axis_offset==1: |
2922 |
|
for i0 in range(s[0]): |
2923 |
|
for i1 in range(s[1]): |
2924 |
|
out[i0]+=arg[i0,i1,i1] |
2925 |
|
elif arg.getRank()==4: |
2926 |
|
if axis_offset==0: |
2927 |
|
for i0 in range(s[0]): |
2928 |
|
for i2 in range(s[2]): |
2929 |
|
for i3 in range(s[3]): |
2930 |
|
out[i2,i3]+=arg[i0,i0,i2,i3] |
2931 |
|
elif axis_offset==1: |
2932 |
|
for i0 in range(s[0]): |
2933 |
|
for i1 in range(s[1]): |
2934 |
|
for i3 in range(s[3]): |
2935 |
|
out[i0,i3]+=arg[i0,i1,i1,i3] |
2936 |
|
elif axis_offset==2: |
2937 |
|
for i0 in range(s[0]): |
2938 |
|
for i1 in range(s[1]): |
2939 |
|
for i2 in range(s[2]): |
2940 |
|
out[i0,i1]+=arg[i0,i1,i2,i2] |
2941 |
|
return out |
2942 |
|
class Trace_Symbol(DependendSymbol): |
2943 |
|
""" |
2944 |
|
L{Symbol} representing the result of the trace function |
2945 |
|
""" |
2946 |
|
def __init__(self,arg,axis_offset=0): |
2947 |
|
""" |
2948 |
|
initialization of trace L{Symbol} with argument arg |
2949 |
|
@param arg: argument of function |
2950 |
|
@type arg: L{Symbol}. |
2951 |
|
@param axis_offset: axis_offset to components to sum over. C{axis_offset} must be non-negative and less than the rank of arg +1. The dimensions on component |
2952 |
|
axis_offset and axis_offset+1 must be equal. |
2953 |
|
@type axis_offset: C{int} |
2954 |
|
""" |
2955 |
|
if arg.getRank()<2: |
2956 |
|
raise ValueError,"Trace_Symbol: rank of argument must be greater than 1" |
2957 |
|
if axis_offset<0 or axis_offset>arg.getRank()-2: |
2958 |
|
raise ValueError,"Trace_Symbol: axis_offset must be between 0 and %s"%arg.getRank()-2 |
2959 |
|
s=list(arg.getShape()) |
2960 |
|
if not s[axis_offset] == s[axis_offset+1]: |
2961 |
|
raise ValueError,"Trace_Symbol: dimensions of component %s and %s must match."%(axis_offset.axis_offset+1) |
2962 |
|
super(Trace_Symbol,self).__init__(args=[arg,axis_offset],shape=tuple(s[0:axis_offset]+s[axis_offset+2:]),dim=arg.getDim()) |
2963 |
|
|
2964 |
|
def getMyCode(self,argstrs,format="escript"): |
2965 |
|
""" |
2966 |
|
returns a program code that can be used to evaluate the symbol. |
2967 |
|
|
2968 |
|
@param argstrs: gives for each argument a string representing the argument for the evaluation. |
2969 |
|
@type argstrs: C{str} or a C{list} of length 1 of C{str}. |
2970 |
|
@param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported. |
2971 |
|
@type format: C{str} |
2972 |
|
@return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available. |
2973 |
|
@rtype: C{str} |
2974 |
|
@raise: NotImplementedError: if the requested format is not available |
2975 |
|
""" |
2976 |
|
if format=="escript" or format=="str" or format=="text": |
2977 |
|
return "trace(%s,axis_offset=%s)"%(argstrs[0],argstrs[1]) |
2978 |
|
else: |
2979 |
|
raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format |
2980 |
|
|
2981 |
|
def substitute(self,argvals): |
2982 |
|
""" |
2983 |
|
assigns new values to symbols in the definition of the symbol. |
2984 |
|
The method replaces the L{Symbol} u by argvals[u] in the expression defining this object. |
2985 |
|
|
2986 |
|
@param argvals: new values assigned to symbols |
2987 |
|
@type argvals: C{dict} with keywords of type L{Symbol}. |
2988 |
|
@return: result of the substitution process. Operations are executed as much as possible. |
2989 |
|
@rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution |
2990 |
|
@raise TypeError: if a value for a L{Symbol} cannot be substituted. |
2991 |
|
""" |
2992 |
|
if argvals.has_key(self): |
2993 |
|
arg=argvals[self] |
2994 |
|
if self.isAppropriateValue(arg): |
2995 |
|
return arg |
2996 |
|
else: |
2997 |
|
raise TypeError,"%s: new value is not appropriate."%str(self) |
2998 |
|
else: |
2999 |
|
arg=self.getSubstitutedArguments(argvals) |
3000 |
|
return trace(arg[0],axis_offset=arg[1]) |
3001 |
|
|
3002 |
|
def diff(self,arg): |
3003 |
|
""" |
3004 |
|
differential of this object |
3005 |
|
|
3006 |
|
@param arg: the derivative is calculated with respect to arg |
3007 |
|
@type arg: L{escript.Symbol} |
3008 |
|
@return: derivative with respect to C{arg} |
3009 |
|
@rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray} are possible. |
3010 |
|
""" |
3011 |
|
if arg==self: |
3012 |
|
return identity(self.getShape()) |
3013 |
|
else: |
3014 |
|
return trace(self.getDifferentiatedArguments(arg)[0],axis_offset=self.getArgument()[1]) |
3015 |
|
|
3016 |
#======================================================= |
#======================================================= |
3017 |
# Binary operations: |
# Binary operations: |
3018 |
#======================================================= |
#======================================================= |
3504 |
sh1=pokeShape(arg1) |
sh1=pokeShape(arg1) |
3505 |
if not sh0==sh1: |
if not sh0==sh1: |
3506 |
raise ValueError,"inner: shape of arguments does not match" |
raise ValueError,"inner: shape of arguments does not match" |
3507 |
return generalTensorProduct(arg0,arg1,offset=len(sh0)) |
return generalTensorProduct(arg0,arg1,axis_offset=len(sh0)) |
3508 |
|
|
3509 |
def matrixmult(arg0,arg1): |
def matrixmult(arg0,arg1): |
3510 |
""" |
""" |
3532 |
raise ValueError,"first argument must have rank 2" |
raise ValueError,"first argument must have rank 2" |
3533 |
if not len(sh1)==2 and not len(sh1)==1: |
if not len(sh1)==2 and not len(sh1)==1: |
3534 |
raise ValueError,"second argument must have rank 1 or 2" |
raise ValueError,"second argument must have rank 1 or 2" |
3535 |
return generalTensorProduct(arg0,arg1,offset=1) |
return generalTensorProduct(arg0,arg1,axis_offset=1) |
3536 |
|
|
3537 |
def outer(arg0,arg1): |
def outer(arg0,arg1): |
3538 |
""" |
""" |
3550 |
@return: the outer product of arg0 and arg1 at each data point |
@return: the outer product of arg0 and arg1 at each data point |
3551 |
@rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input |
@rtype: L{numarray.NumArray}, L{escript.Data}, L{Symbol} depending on the input |
3552 |
""" |
""" |
3553 |
return generalTensorProduct(arg0,arg1,offset=0) |
return generalTensorProduct(arg0,arg1,axis_offset=0) |
3554 |
|
|
3555 |
|
|
3556 |
def tensormult(arg0,arg1): |
def tensormult(arg0,arg1): |
3592 |
sh0=pokeShape(arg0) |
sh0=pokeShape(arg0) |
3593 |
sh1=pokeShape(arg1) |
sh1=pokeShape(arg1) |
3594 |
if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ): |
if len(sh0)==2 and ( len(sh1)==2 or len(sh1)==1 ): |
3595 |
return generalTensorProduct(arg0,arg1,offset=1) |
return generalTensorProduct(arg0,arg1,axis_offset=1) |
3596 |
elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4): |
elif len(sh0)==4 and (len(sh1)==2 or len(sh1)==3 or len(sh1)==4): |
3597 |
return generalTensorProduct(arg0,arg1,offset=2) |
return generalTensorProduct(arg0,arg1,axis_offset=2) |
3598 |
else: |
else: |
3599 |
raise ValueError,"tensormult: first argument must have rank 2 or 4" |
raise ValueError,"tensormult: first argument must have rank 2 or 4" |
3600 |
|
|
3601 |
def generalTensorProduct(arg0,arg1,offset=0): |
def generalTensorProduct(arg0,arg1,axis_offset=0): |
3602 |
""" |
""" |
3603 |
generalized tensor product |
generalized tensor product |
3604 |
|
|
3605 |
out[s,t]=S{Sigma}_r arg0[s,r]*arg1[r,t] |
out[s,t]=S{Sigma}_r arg0[s,r]*arg1[r,t] |
3606 |
|
|
3607 |
where s runs through arg0.Shape[:arg0.Rank-offset] |
where s runs through arg0.Shape[:arg0.Rank-axis_offset] |
3608 |
r runs trough arg0.Shape[:offset] |
r runs trough arg0.Shape[:axis_offset] |
3609 |
t runs through arg1.Shape[offset:] |
t runs through arg1.Shape[axis_offset:] |
3610 |
|
|
3611 |
In the first case the the second dimension of arg0 and the length of arg1 must match and |
In the first case the the second dimension of arg0 and the length of arg1 must match and |
3612 |
in the second case the two last dimensions of arg0 must match the shape of arg1. |
in the second case the two last dimensions of arg0 must match the shape of arg1. |
3623 |
# at this stage arg0 and arg0 are both numarray.NumArray or escript.Data or Symbols |
# at this stage arg0 and arg0 are both numarray.NumArray or escript.Data or Symbols |
3624 |
if isinstance(arg0,numarray.NumArray): |
if isinstance(arg0,numarray.NumArray): |
3625 |
if isinstance(arg1,Symbol): |
if isinstance(arg1,Symbol): |
3626 |
return GeneralTensorProduct_Symbol(arg0,arg1,offset) |
return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset) |
3627 |
else: |
else: |
3628 |
if not arg0.shape[arg0.rank-offset:]==arg1.shape[:offset]: |
if not arg0.shape[arg0.rank-axis_offset:]==arg1.shape[:axis_offset]: |
3629 |
raise ValueError,"generalTensorProduct: dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset) |
raise ValueError,"generalTensorProduct: dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset) |
3630 |
arg0_c=arg0.copy() |
arg0_c=arg0.copy() |
3631 |
arg1_c=arg1.copy() |
arg1_c=arg1.copy() |
3632 |
sh0,sh1=arg0.shape,arg1.shape |
sh0,sh1=arg0.shape,arg1.shape |
3633 |
d0,d1,d01=1,1,1 |
d0,d1,d01=1,1,1 |
3634 |
for i in sh0[:arg0.rank-offset]: d0*=i |
for i in sh0[:arg0.rank-axis_offset]: d0*=i |
3635 |
for i in sh1[offset:]: d1*=i |
for i in sh1[axis_offset:]: d1*=i |
3636 |
for i in sh1[:offset]: d01*=i |
for i in sh1[:axis_offset]: d01*=i |
3637 |
arg0_c.resize((d0,d01)) |
arg0_c.resize((d0,d01)) |
3638 |
arg1_c.resize((d01,d1)) |
arg1_c.resize((d01,d1)) |
3639 |
out=numarray.zeros((d0,d1),numarray.Float) |
out=numarray.zeros((d0,d1),numarray.Float) |
3640 |
for i0 in range(d0): |
for i0 in range(d0): |
3641 |
for i1 in range(d1): |
for i1 in range(d1): |
3642 |
out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1]) |
out[i0,i1]=numarray.sum(arg0_c[i0,:]*arg1_c[:,i1]) |
3643 |
out.resize(sh0[:arg0.rank-offset]+sh1[offset:]) |
out.resize(sh0[:arg0.rank-axis_offset]+sh1[axis_offset:]) |
3644 |
return out |
return out |
3645 |
elif isinstance(arg0,escript.Data): |
elif isinstance(arg0,escript.Data): |
3646 |
if isinstance(arg1,Symbol): |
if isinstance(arg1,Symbol): |
3647 |
return GeneralTensorProduct_Symbol(arg0,arg1,offset) |
return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset) |
3648 |
else: |
else: |
3649 |
return escript_generalTensorProduct(arg0,arg1,offset) # this calls has to be replaced by escript._generalTensorProduct(arg0,arg1,offset) |
return escript_generalTensorProduct(arg0,arg1,axis_offset) # this calls has to be replaced by escript._generalTensorProduct(arg0,arg1,axis_offset) |
3650 |
else: |
else: |
3651 |
return GeneralTensorProduct_Symbol(arg0,arg1,offset) |
return GeneralTensorProduct_Symbol(arg0,arg1,axis_offset) |
3652 |
|
|
3653 |
class GeneralTensorProduct_Symbol(DependendSymbol): |
class GeneralTensorProduct_Symbol(DependendSymbol): |
3654 |
""" |
""" |
3655 |
Symbol representing the quotient of two arguments. |
Symbol representing the quotient of two arguments. |
3656 |
""" |
""" |
3657 |
def __init__(self,arg0,arg1,offset=0): |
def __init__(self,arg0,arg1,axis_offset=0): |
3658 |
""" |
""" |
3659 |
initialization of L{Symbol} representing the quotient of two arguments |
initialization of L{Symbol} representing the quotient of two arguments |
3660 |
|
|
3667 |
""" |
""" |
3668 |
sh_arg0=pokeShape(arg0) |
sh_arg0=pokeShape(arg0) |
3669 |
sh_arg1=pokeShape(arg1) |
sh_arg1=pokeShape(arg1) |
3670 |
sh0=sh_arg0[:len(sh_arg0)-offset] |
sh0=sh_arg0[:len(sh_arg0)-axis_offset] |
3671 |
sh01=sh_arg0[len(sh_arg0)-offset:] |
sh01=sh_arg0[len(sh_arg0)-axis_offset:] |
3672 |
sh10=sh_arg1[:offset] |
sh10=sh_arg1[:axis_offset] |
3673 |
sh1=sh_arg1[offset:] |
sh1=sh_arg1[axis_offset:] |
3674 |
if not sh01==sh10: |
if not sh01==sh10: |
3675 |
raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset) |
raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset) |
3676 |
DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0+sh1,args=[arg0,arg1,offset]) |
DependendSymbol.__init__(self,dim=commonDim(arg0,arg1),shape=sh0+sh1,args=[arg0,arg1,axis_offset]) |
3677 |
|
|
3678 |
def getMyCode(self,argstrs,format="escript"): |
def getMyCode(self,argstrs,format="escript"): |
3679 |
""" |
""" |
3688 |
@raise: NotImplementedError: if the requested format is not available |
@raise: NotImplementedError: if the requested format is not available |
3689 |
""" |
""" |
3690 |
if format=="escript" or format=="str" or format=="text": |
if format=="escript" or format=="str" or format=="text": |
3691 |
return "generalTensorProduct(%s,%s,offset=%s)"%(argstrs[0],argstrs[1],argstrs[2]) |
return "generalTensorProduct(%s,%s,axis_offset=%s)"%(argstrs[0],argstrs[1],argstrs[2]) |
3692 |
else: |
else: |
3693 |
raise NotImplementedError,"%s does not provide program code for format %s."%(str(self),format) |
raise NotImplementedError,"%s does not provide program code for format %s."%(str(self),format) |
3694 |
|
|
3713 |
args=self.getSubstitutedArguments(argvals) |
args=self.getSubstitutedArguments(argvals) |
3714 |
return generalTensorProduct(args[0],args[1],args[2]) |
return generalTensorProduct(args[0],args[1],args[2]) |
3715 |
|
|
3716 |
def escript_generalTensorProduct(arg0,arg1,offset): # this should be escript._generalTensorProduct |
def escript_generalTensorProduct(arg0,arg1,axis_offset): # this should be escript._generalTensorProduct |
3717 |
"arg0 and arg1 are both Data objects but not neccesrily on the same function space. they could be identical!!!" |
"arg0 and arg1 are both Data objects but not neccesrily on the same function space. they could be identical!!!" |
3718 |
# calculate the return shape: |
# calculate the return shape: |
3719 |
shape0=arg0.getShape()[:arg0.getRank()-offset] |
shape0=arg0.getShape()[:arg0.getRank()-axis_offset] |
3720 |
shape01=arg0.getShape()[arg0.getRank()-offset:] |
shape01=arg0.getShape()[arg0.getRank()-axis_offset:] |
3721 |
shape10=arg1.getShape()[:offset] |
shape10=arg1.getShape()[:axis_offset] |
3722 |
shape1=arg1.getShape()[offset:] |
shape1=arg1.getShape()[axis_offset:] |
3723 |
if not shape01==shape10: |
if not shape01==shape10: |
3724 |
raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(offset,offset) |
raise ValueError,"dimensions of last %s components in left argument don't match the first %s components in the right argument."%(axis_offset,axis_offset) |
3725 |
|
|
3726 |
# whatr function space should be used? (this here is not good!) |
# whatr function space should be used? (this here is not good!) |
3727 |
fs=(escript.Scalar(0.,arg0.getFunctionSpace())+escript.Scalar(0.,arg1.getFunctionSpace())).getFunctionSpace() |
fs=(escript.Scalar(0.,arg0.getFunctionSpace())+escript.Scalar(0.,arg1.getFunctionSpace())).getFunctionSpace() |
3755 |
out.__setitem__(tuple(i0+i1),s) |
out.__setitem__(tuple(i0+i1),s) |
3756 |
return out |
return out |
3757 |
|
|
3758 |
|
|
3759 |
#========================================================= |
#========================================================= |
3760 |
# some little helpers |
# functions dealing with spatial dependency |
3761 |
#========================================================= |
#========================================================= |
3762 |
def grad(arg,where=None): |
def grad(arg,where=None): |
3763 |
""" |
""" |
3764 |
Returns the spatial gradient of arg at where. |
Returns the spatial gradient of arg at where. |
3765 |
|
|
3766 |
@param arg: Data object representing the function which gradient |
If C{g} is the returned object, then |
3767 |
to be calculated. |
|
3768 |
|
- if C{arg} is rank 0 C{g[s]} is the derivative of C{arg} with respect to the C{s}-th spatial dimension. |
3769 |
|
- if C{arg} is rank 1 C{g[i,s]} is the derivative of C{arg[i]} with respect to the C{s}-th spatial dimension. |
3770 |
|
- if C{arg} is rank 2 C{g[i,j,s]} is the derivative of C{arg[i,j]} with respect to the C{s}-th spatial dimension. |
3771 |
|
- if C{arg} is rank 3 C{g[i,j,k,s]} is the derivative of C{arg[i,j,k]} with respect to the C{s}-th spatial dimension. |
3772 |
|
|
3773 |
|
@param arg: function which gradient to be calculated. Its rank has to be less than 3. |
3774 |
|
@type arg: L{escript.Data} or L{Symbol} |
3775 |
@param where: FunctionSpace in which the gradient will be calculated. |
@param where: FunctionSpace in which the gradient will be calculated. |
3776 |
If not present or C{None} an appropriate default is used. |
If not present or C{None} an appropriate default is used. |
3777 |
|
@type where: C{None} or L{escript.FunctionSpace} |
3778 |
|
@return: gradient of arg. |
3779 |
|
@rtype: L{escript.Data} or L{Symbol} |
3780 |
""" |
""" |
3781 |
if isinstance(arg,Symbol): |
if isinstance(arg,Symbol): |
3782 |
return Grad_Symbol(arg,where) |
return Grad_Symbol(arg,where) |
3786 |
else: |
else: |
3787 |
return arg._grad(where) |
return arg._grad(where) |
3788 |
else: |
else: |
3789 |
raise TypeError,"grad: Unknown argument type." |
raise TypeError,"grad: Unknown argument type." |
3790 |
|
|
3791 |
|
class Grad_Symbol(DependendSymbol): |
3792 |
|
""" |
3793 |
|
L{Symbol} representing the result of the gradient operator |
3794 |
|
""" |
3795 |
|
def __init__(self,arg,where=None): |
3796 |
|
""" |
3797 |
|
initialization of gradient L{Symbol} with argument arg |
3798 |
|
@param arg: argument of function |
3799 |
|
@type arg: L{Symbol}. |
3800 |
|
@param where: FunctionSpace in which the gradient will be calculated. |
3801 |
|
If not present or C{None} an appropriate default is used. |
3802 |
|
@type where: C{None} or L{escript.FunctionSpace} |
3803 |
|
""" |
3804 |
|
d=arg.getDim() |
3805 |
|
if d==None: |
3806 |
|
raise ValueError,"argument must have a spatial dimension" |
3807 |
|
super(Grad_Symbol,self).__init__(args=[arg,where],shape=tuple(list(arg.getShape()).extend(d)),dim=d) |
3808 |
|
|
3809 |
|
def getMyCode(self,argstrs,format="escript"): |
3810 |
|
""" |
3811 |
|
returns a program code that can be used to evaluate the symbol. |
3812 |
|
|
3813 |
|
@param argstrs: gives for each argument a string representing the argument for the evaluation. |
3814 |
|
@type argstrs: C{str} or a C{list} of length 1 of C{str}. |
3815 |
|
@param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported. |
3816 |
|
@type format: C{str} |
3817 |
|
@return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available. |
3818 |
|
@rtype: C{str} |
3819 |
|
@raise: NotImplementedError: if the requested format is not available |
3820 |
|
""" |
3821 |
|
if format=="escript" or format=="str" or format=="text": |
3822 |
|
return "grad(%s,where=%s)"%(argstrs[0],argstrs[1]) |
3823 |
|
else: |
3824 |
|
raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format |
3825 |
|
|
3826 |
|
def substitute(self,argvals): |
3827 |
|
""" |
3828 |
|
assigns new values to symbols in the definition of the symbol. |
3829 |
|
The method replaces the L{Symbol} u by argvals[u] in the expression defining this object. |
3830 |
|
|
3831 |
|
@param argvals: new values assigned to symbols |
3832 |
|
@type argvals: C{dict} with keywords of type L{Symbol}. |
3833 |
|
@return: result of the substitution process. Operations are executed as much as possible. |
3834 |
|
@rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution |
3835 |
|
@raise TypeError: if a value for a L{Symbol} cannot be substituted. |
3836 |
|
""" |
3837 |
|
if argvals.has_key(self): |
3838 |
|
arg=argvals[self] |
3839 |
|
if self.isAppropriateValue(arg): |
3840 |
|
return arg |
3841 |
|
else: |
3842 |
|
raise TypeError,"%s: new value is not appropriate."%str(self) |
3843 |
|
else: |
3844 |
|
arg=self.getSubstitutedArguments(argvals) |
3845 |
|
return grad(arg[0],where=arg[1]) |
3846 |
|
|
3847 |
|
def diff(self,arg): |
3848 |
|
""" |
3849 |
|
differential of this object |
3850 |
|
|
3851 |
|
@param arg: the derivative is calculated with respect to arg |
3852 |
|
@type arg: L{escript.Symbol} |
3853 |
|
@return: derivative with respect to C{arg} |
3854 |
|
@rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray} are possible. |
3855 |
|
""" |
3856 |
|
if arg==self: |
3857 |
|
return identity(self.getShape()) |
3858 |
|
else: |
3859 |
|
return grad(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1]) |
3860 |
|
|
3861 |
def integrate(arg,where=None): |
def integrate(arg,where=None): |
3862 |
""" |
""" |
3863 |
Return the integral if the function represented by Data object arg over |
Return the integral of the function C{arg} over its domain. If C{where} is present C{arg} is interpolated to C{where} |
3864 |
its domain. |
before integration. |
3865 |
|
|
3866 |
@param arg: Data object representing the function which is integrated. |
@param arg: the function which is integrated. |
3867 |
|
@type arg: L{escript.Data} or L{Symbol} |
3868 |
@param where: FunctionSpace in which the integral is calculated. |
@param where: FunctionSpace in which the integral is calculated. |
3869 |
If not present or C{None} an appropriate default is used. |
If not present or C{None} an appropriate default is used. |
3870 |
|
@type where: C{None} or L{escript.FunctionSpace} |
3871 |
|
@return: integral of arg. |
3872 |
|
@rtype: C{float}, C{numarray.NumArray} or L{Symbol} |
3873 |
""" |
""" |
3874 |
if isinstance(arg,numarray.NumArray): |
if isinstance(arg,Symbol): |
|
if checkForZero(arg): |
|
|
return arg |
|
|
else: |
|
|
raise TypeError,"integrate: cannot intergrate argument" |
|
|
elif isinstance(arg,float): |
|
|
if checkForZero(arg): |
|
|
return arg |
|
|
else: |
|
|
raise TypeError,"integrate: cannot intergrate argument" |
|
|
elif isinstance(arg,int): |
|
|
if checkForZero(arg): |
|
|
return float(arg) |
|
|
else: |
|
|
raise TypeError,"integrate: cannot intergrate argument" |
|
|
elif isinstance(arg,Symbol): |
|
3875 |
return Integrate_Symbol(arg,where) |
return Integrate_Symbol(arg,where) |
3876 |
elif isinstance(arg,escript.Data): |
elif isinstance(arg,escript.Data): |
3877 |
if not where==None: arg=escript.Data(arg,where) |
if not where==None: arg=escript.Data(arg,where) |
3882 |
else: |
else: |
3883 |
raise TypeError,"integrate: Unknown argument type." |
raise TypeError,"integrate: Unknown argument type." |
3884 |
|
|
3885 |
|
class Integrate_Symbol(DependendSymbol): |
3886 |
|
""" |
3887 |
|
L{Symbol} representing the result of the spatial integration operator |
3888 |
|
""" |
3889 |
|
def __init__(self,arg,where=None): |
3890 |
|
""" |
3891 |
|
initialization of integration L{Symbol} with argument arg |
3892 |
|
@param arg: argument of the integration |
3893 |
|
@type arg: L{Symbol}. |
3894 |
|
@param where: FunctionSpace in which the integration will be calculated. |
3895 |
|
If not present or C{None} an appropriate default is used. |
3896 |
|
@type where: C{None} or L{escript.FunctionSpace} |
3897 |
|
""" |
3898 |
|
super(Integrate_Symbol,self).__init__(args=[arg,where],shape=arg.getShape(),dim=arg.getDim()) |
3899 |
|
|
3900 |
|
def getMyCode(self,argstrs,format="escript"): |
3901 |
|
""" |
3902 |
|
returns a program code that can be used to evaluate the symbol. |
3903 |
|
|
3904 |
|
@param argstrs: gives for each argument a string representing the argument for the evaluation. |
3905 |
|
@type argstrs: C{str} or a C{list} of length 1 of C{str}. |
3906 |
|
@param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported. |
3907 |
|
@type format: C{str} |
3908 |
|
@return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available. |
3909 |
|
@rtype: C{str} |
3910 |
|
@raise: NotImplementedError: if the requested format is not available |
3911 |
|
""" |
3912 |
|
if format=="escript" or format=="str" or format=="text": |
3913 |
|
return "integrate(%s,where=%s)"%(argstrs[0],argstrs[1]) |
3914 |
|
else: |
3915 |
|
raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format |
3916 |
|
|
3917 |
|
def substitute(self,argvals): |
3918 |
|
""" |
3919 |
|
assigns new values to symbols in the definition of the symbol. |
3920 |
|
The method replaces the L{Symbol} u by argvals[u] in the expression defining this object. |
3921 |
|
|
3922 |
|
@param argvals: new values assigned to symbols |
3923 |
|
@type argvals: C{dict} with keywords of type L{Symbol}. |
3924 |
|
@return: result of the substitution process. Operations are executed as much as possible. |
3925 |
|
@rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution |
3926 |
|
@raise TypeError: if a value for a L{Symbol} cannot be substituted. |
3927 |
|
""" |
3928 |
|
if argvals.has_key(self): |
3929 |
|
arg=argvals[self] |
3930 |
|
if self.isAppropriateValue(arg): |
3931 |
|
return arg |
3932 |
|
else: |
3933 |
|
raise TypeError,"%s: new value is not appropriate."%str(self) |
3934 |
|
else: |
3935 |
|
arg=self.getSubstitutedArguments(argvals) |
3936 |
|
return integrate(arg[0],where=arg[1]) |
3937 |
|
|
3938 |
|
def diff(self,arg): |
3939 |
|
""" |
3940 |
|
differential of this object |
3941 |
|
|
3942 |
|
@param arg: the derivative is calculated with respect to arg |
3943 |
|
@type arg: L{escript.Symbol} |
3944 |
|
@return: derivative with respect to C{arg} |
3945 |
|
@rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray} are possible. |
3946 |
|
""" |
3947 |
|
if arg==self: |
3948 |
|
return identity(self.getShape()) |
3949 |
|
else: |
3950 |
|
return integrate(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1]) |
3951 |
|
|
3952 |
|
|
3953 |
def interpolate(arg,where): |
def interpolate(arg,where): |
3954 |
""" |
""" |
3955 |
Interpolates the function into the FunctionSpace where. |
interpolates the function into the FunctionSpace where. |
3956 |
|
|
3957 |
@param arg: interpolant |
@param arg: interpolant |
3958 |
@param where: FunctionSpace to interpolate to |
@type arg: L{escript.Data} or L{Symbol} |
3959 |
|
@param where: FunctionSpace to be interpolated to |
3960 |
|
@type where: L{escript.FunctionSpace} |
3961 |
|
@return: interpolated argument |
3962 |
|
@rtype: C{escript.Data} or L{Symbol} |
3963 |
""" |
""" |
3964 |
if isinstance(arg,Symbol): |
if isinstance(arg,Symbol): |
3965 |
return Interpolated_Symbol(arg,where) |
return Interpolate_Symbol(arg,where) |
3966 |
else: |
else: |
3967 |
return escript.Data(arg,where) |
return escript.Data(arg,where) |
3968 |
|
|
3969 |
|
class Interpolate_Symbol(DependendSymbol): |
3970 |
|
""" |
3971 |
|
L{Symbol} representing the result of the interpolation operator |
3972 |
|
""" |
3973 |
|
def __init__(self,arg,where): |
3974 |
|
""" |
3975 |
|
initialization of interpolation L{Symbol} with argument arg |
3976 |
|
@param arg: argument of the interpolation |
3977 |
|
@type arg: L{Symbol}. |
3978 |
|
@param where: FunctionSpace into which the argument is interpolated. |
3979 |
|
@type where: L{escript.FunctionSpace} |
3980 |
|
""" |
3981 |
|
super(Interpolate_Symbol,self).__init__(args=[arg,where],shape=arg.getShape(),dim=arg.getDim()) |
3982 |
|
|
3983 |
|
def getMyCode(self,argstrs,format="escript"): |
3984 |
|
""" |
3985 |
|
returns a program code that can be used to evaluate the symbol. |
3986 |
|
|
3987 |
|
@param argstrs: gives for each argument a string representing the argument for the evaluation. |
3988 |
|
@type argstrs: C{str} or a C{list} of length 1 of C{str}. |
3989 |
|
@param format: specifies the format to be used. At the moment only "escript" ,"text" and "str" are supported. |
3990 |
|
@type format: C{str} |
3991 |
|
@return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available. |
3992 |
|
@rtype: C{str} |
3993 |
|
@raise: NotImplementedError: if the requested format is not available |
3994 |
|
""" |
3995 |
|
if format=="escript" or format=="str" or format=="text": |
3996 |
|
return "interpolate(%s,where=%s)"%(argstrs[0],argstrs[1]) |
3997 |
|
else: |
3998 |
|
raise NotImplementedError,"Trace_Symbol does not provide program code for format %s."%format |
3999 |
|
|
4000 |
|
def substitute(self,argvals): |
4001 |
|
""" |
4002 |
|
assigns new values to symbols in the definition of the symbol. |
4003 |
|
The method replaces the L{Symbol} u by argvals[u] in the expression defining this object. |
4004 |
|
|
4005 |
|
@param argvals: new values assigned to symbols |
4006 |
|
@type argvals: C{dict} with keywords of type L{Symbol}. |
4007 |
|
@return: result of the substitution process. Operations are executed as much as possible. |
4008 |
|
@rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution |
4009 |
|
@raise TypeError: if a value for a L{Symbol} cannot be substituted. |
4010 |
|
""" |
4011 |
|
if argvals.has_key(self): |
4012 |
|
arg=argvals[self] |
4013 |
|
if self.isAppropriateValue(arg): |
4014 |
|
return arg |
4015 |
|
else: |
4016 |
|
raise TypeError,"%s: new value is not appropriate."%str(self) |
4017 |
|
else: |
4018 |
|
arg=self.getSubstitutedArguments(argvals) |
4019 |
|
return interpolate(arg[0],where=arg[1]) |
4020 |
|
|
4021 |
|
def diff(self,arg): |
4022 |
|
""" |
4023 |
|
differential of this object |
4024 |
|
|
4025 |
|
@param arg: the derivative is calculated with respect to arg |
4026 |
|
@type arg: L{escript.Symbol} |
4027 |
|
@return: derivative with respect to C{arg} |
4028 |
|
@rtype: L{Symbol} but other types such as L{escript.Data}, L{numarray.NumArray} are possible. |
4029 |
|
""" |
4030 |
|
if arg==self: |
4031 |
|
return identity(self.getShape()) |
4032 |
|
else: |
4033 |
|
return interpolate(self.getDifferentiatedArguments(arg)[0],where=self.getArgument()[1]) |
4034 |
|
|
4035 |
|
|
4036 |
def div(arg,where=None): |
def div(arg,where=None): |
4037 |
""" |
""" |
4038 |
Returns the divergence of arg at where. |
returns the divergence of arg at where. |
4039 |
|
|
4040 |
@param arg: Data object representing the function which gradient to |
@param arg: function which divergence to be calculated. Its shape has to be (d,) where d is the spatial dimension. |
4041 |
be calculated. |
@type arg: L{escript.Data} or L{Symbol} |
4042 |
@param where: FunctionSpace in which the gradient will be calculated. |
@param where: FunctionSpace in which the divergence will be calculated. |
4043 |
If not present or C{None} an appropriate default is used. |
If not present or C{None} an appropriate default is used. |
4044 |
|
@type where: C{None} or L{escript.FunctionSpace} |
4045 |
|
@return: divergence of arg. |
4046 |
|
@rtype: L{escript.Data} or L{Symbol} |
4047 |
""" |
""" |
4048 |
g=grad(arg,where) |
if not arg.getShape()==(arg.getDim(),): |
4049 |
return trace(g,axis0=g.getRank()-2,axis1=g.getRank()-1) |
raise ValueError,"div: expected shape is (%s,)"%arg.getDim() |
4050 |
|
return trace(grad(arg,where)) |
4051 |
|
|
4052 |
def jump(arg): |
def jump(arg,domain=None): |
4053 |
""" |
""" |
4054 |
Returns the jump of arg across a continuity. |
returns the jump of arg across the continuity of the domain |
4055 |
|
|
4056 |
@param arg: Data object representing the function which gradient |
@param arg: argument |
4057 |
to be calculated. |
@type arg: L{escript.Data} or L{Symbol} |
4058 |
|
@param domain: the domain where the discontinuity is located. If domain is not present or equal to C{None} |
4059 |
|
the domain of arg is used. If arg is a L{Symbol} the domain must be present. |
4060 |
|
@type domain: C{None} or L{escript.Domain} |
4061 |
|
@return: jump of arg |
4062 |
|
@rtype: L{escript.Data} or L{Symbol} |
4063 |
""" |
""" |
4064 |
d=arg.getDomain() |
if domain==None: domain=arg.getDomain() |
4065 |
return arg.interpolate(escript.FunctionOnContactOne(d))-arg.interpolate(escript.FunctionOnContactZero(d)) |
return interpolate(arg,escript.FunctionOnContactOne(domain))-interpolate(arg,escript.FunctionOnContactZero(domain)) |
|
|
|
4066 |
#============================= |
#============================= |
4067 |
# |
# |
4068 |
# wrapper for various functions: if the argument has attribute the function name |
# wrapper for various functions: if the argument has attribute the function name |
4099 |
else: |
else: |
4100 |
return numarray.transpose(arg,axis=axis) |
return numarray.transpose(arg,axis=axis) |
4101 |
|
|
|
def trace(arg,axis0=0,axis1=1): |
|
|
""" |
|
|
Return |
|
|
|
|
|
@param arg: |
|
|
""" |
|
|
if isinstance(arg,Symbol): |
|
|
s=list(arg.getShape()) |
|
|
s=tuple(s[0:axis0]+s[axis0+1:axis1]+s[axis1+1:]) |
|
|
return Trace_Symbol(arg,axis0=axis0,axis1=axis1) |
|
|
elif isinstance(arg,escript.Data): |
|
|
# hack for trace |
|
|
s=arg.getShape() |
|
|
if s[axis0]!=s[axis1]: |
|
|
raise ValueError,"illegal axis in trace" |
|
|
out=escript.Scalar(0.,arg.getFunctionSpace()) |
|
|
for i in range(s[axis0]): |
|
|
out+=arg[i,i] |
|
|
return out |
|
|
# end hack for trace |
|
|
else: |
|
|
return numarray.trace(arg,axis0=axis0,axis1=axis1) |
|
4102 |
|
|
4103 |
|
|
4104 |
def reorderComponents(arg,index): |
def reorderComponents(arg,index): |