/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Annotation of /trunk/escript/py_src/util.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 123 - (hide annotations)
Fri Jul 8 04:08:13 2005 UTC (14 years, 5 months ago) by jgs
Original Path: trunk/esys2/escript/py_src/util.py
File MIME type: text/x-python
File size: 39052 byte(s)
Merge of development branch back to main trunk on 2005-07-08

1 jgs 82 # $Id$
2    
3     ## @file util.py
4    
5     """
6     @brief Utility functions for escript
7 jgs 123
8     TODO for Data:
9    
10     * binary operations @: (a@b)[:,*]=a[:]@b[:,*] if rank(a)<rank(b)
11     @=+,-,*,/,** (a@b)[:]=a[:]@b[:] if rank(a)=rank(b)
12     (a@b)[*,:]=a[*,:]@b[:] if rank(a)>rank(b)
13    
14     * implementation of outer outer(a,b)[:,*]=a[:]*b[*]
15     * trace: trace(arg,axis0=a0,axis1=a1)(:,&,*)=sum_i trace(:,i,&,i,*) (i are at index a0 and a1)
16    
17 jgs 82 """
18    
19     import numarray
20 jgs 102 import escript
#
# escript constants (have to be consistent with utilC.h)
#
UNKNOWN=-1
EPSILON=1.e-15
Pi=numarray.pi
# some solver options:
NO_REORDERING=0
MINIMUM_FILL_IN=1
NESTED_DISSECTION=2
# solver methods
DEFAULT_METHOD=0
DIRECT=1
CHOLEVSKY=2  # spelling kept as is: the name is part of the public interface
PCG=3
CR=4
CGS=5
BICGSTAB=6
SSOR=7
ILU0=8
ILUT=9
JACOBI=10
GMRES=11
PRES20=12

# option keys
METHOD_KEY="method"
SYMMETRY_KEY="symmetric"
TOLERANCE_KEY="tolerance"

# supported file formats:
VRML=1
PNG=2
JPEG=3
JPG=3   # alias of JPEG
PS=4
OOGL=5
BMP=6
TIFF=7
OPENINVENTOR=8
RENDERMAN=9
PNM=10
62 jgs 123 #===========================================================
63     # a simple tool box to deal with _differentials of functions
64     #===========================================================
65 jgs 113
class Symbol:
    """symbol class: base class of a simple symbolic toolbox used to build
    expressions of symbols and to differentiate them."""

    def __init__(self,name="symbol",shape=(),dim=3,args=None):
        """creates an instance of a symbol of shape shape and spatial dimension dim.

        The symbol may depend on a list of arguments args which may be
        symbols or other objects. name gives the name of the symbol.
        dim may also be an object providing a getDim() method."""
        # fresh list per instance instead of a shared mutable default
        if args is None:
            args=[]
        self.__args=args
        self.__name=name
        self.__shape=shape
        if hasattr(dim,"getDim"):
            self.__dim=dim.getDim()
        else:
            self.__dim=dim
        # cache of getEvaluatedArguments: snapshot of the last argval and its result
        self.__cache_val=None
        self.__cache_argval=None

    def getArgument(self,i):
        """returns the i-th argument"""
        return self.__args[i]

    def getDim(self):
        """returns the spatial dimension of the symbol"""
        return self.__dim

    def getRank(self):
        """returns the rank of the symbol"""
        return len(self.getShape())

    def getShape(self):
        """returns the shape of the symbol"""
        return self.__shape

    def getEvaluatedArguments(self,argval):
        """returns the list of evaluated arguments by substituting symbol u by argval[u].

        The result is cached. The cache is reused only if argval has exactly
        the same keys mapped to the very same (identical) objects as in the
        previous call, so replacing a value in a reused dictionary forces a
        re-evaluation."""
        if self.__matchesCachedArgval(argval):
            return self.__cache_val
        out=[]
        for a in self.__args:
            if isinstance(a,Symbol):
                out.append(a.eval(argval))
            else:
                out.append(a)
        # keep a snapshot: the caller may modify argval in place afterwards
        self.__cache_argval=dict(argval)
        self.__cache_val=out
        return out

    def __matchesCachedArgval(self,argval):
        """returns True if argval matches the cached snapshot entry by entry.

        Values are compared by identity so that array-valued entries are
        never compared with == (whose truth value may be undefined)."""
        cached=self.__cache_argval
        if cached is None or len(cached)!=len(argval):
            return False
        for key in argval.keys():
            if key not in cached or cached[key] is not argval[key]:
                return False
        return True

    def getDifferentiatedArguments(self,arg):
        """returns the list of the arguments differentiated by arg (0 for non-symbol arguments)"""
        out=[]
        for a in self.__args:
            if isinstance(a,Symbol):
                out.append(a.diff(arg))
            else:
                out.append(0)
        return out

    def diff(self,arg):
        """returns the differentiation of self by arg.

        If arg is self the identity of shape shape+shape is returned (1. for
        a scalar symbol), otherwise the result of _diff(arg)."""
        if self is arg:
            if self.getRank()==0:
                return 1.
            shape=self.getShape()
            out=numarray.zeros(tuple(2*list(shape)),numarray.Float)
            # enumerate all index tuples idx of shape and set out[idx+idx]=1.
            indices=[()]
            for n in shape:
                indices=[idx+(i,) for idx in indices for i in range(n)]
            for idx in indices:
                out[idx+idx]=1.
            return out
        else:
            return self._diff(arg)

    def _diff(self,arg):
        """returns the derivative of self with respect to arg (!=self).
        This method is overwritten by a particular symbol; the default is 0."""
        return 0

    def eval(self,argval):
        """substitutes symbol u in self by argval[u] and returns the result. If
        self is not a key of argval then self is returned."""
        if self in argval:
            return argval[self]
        else:
            return self

    def __str__(self):
        """returns a string representation of the symbol"""
        return self.__name

    def __add__(self,other):
        """adds other to symbol self. if _testForZero(other) self is returned."""
        if _testForZero(other):
            return self
        else:
            a=_matchShape([self,other])
            return Add_Symbol(a[0],a[1])

    def __radd__(self,other):
        """adds other to symbol self. if _testForZero(other) self is returned."""
        return self+other

    def __neg__(self):
        """returns -self."""
        return self*(-1.)

    def __pos__(self):
        """returns +self."""
        return self

    def __abs__(self):
        """returns absolute value"""
        return Abs_Symbol(self)

    def __sub__(self,other):
        """subtracts other from symbol self. if _testForZero(other) self is returned."""
        if _testForZero(other):
            return self
        else:
            return self+(-other)

    def __rsub__(self,other):
        """subtracts symbol self from other."""
        return -self+other

    def __div__(self,other):
        """divides symbol self by other."""
        if isinstance(other,Symbol):
            a=_matchShape([self,other])
            return Div_Symbol(a[0],a[1])
        else:
            return self*(1./other)

    def __rdiv__(self,other):
        """divides other by symbol self. if _testForZero(other) 0 is returned."""
        if _testForZero(other):
            return 0
        else:
            # other is the numerator, self the denominator
            a=_matchShape([other,self])
            return Div_Symbol(a[0],a[1])

    # true division ("from __future__ import division" and Python 3)
    __truediv__=__div__
    __rtruediv__=__rdiv__

    def __pow__(self,other):
        """raises symbol self to the power of other"""
        a=_matchShape([self,other])
        return Power_Symbol(a[0],a[1])

    def __rpow__(self,other):
        """raises other to the symbol self"""
        a=_matchShape([self,other])
        return Power_Symbol(a[1],a[0])

    def __mul__(self,other):
        """multiplies other by symbol self. if _testForZero(other) 0 is returned."""
        if _testForZero(other):
            return 0
        else:
            a=_matchShape([self,other])
            return Mult_Symbol(a[0],a[1])

    def __rmul__(self,other):
        """multiplies other by symbol self. if _testForZero(other) 0 is returned."""
        return self*other

    def __getitem__(self,sl):
        """slicing of symbols is not supported yet.

        Raising here (instead of returning None) also stops Python's legacy
        iteration protocol, which calls __getitem__(0), (1), ... until an
        exception is raised and would otherwise never terminate."""
        raise NotImplementedError("slicing of symbol %s is not supported yet."%self)
249    
class Float_Symbol(Symbol):
    """a symbol of spatial dimension 0."""
    def __init__(self,name="symbol",shape=(),args=None):
        """creates a symbol of spatial dimension 0 with the given name, shape and arguments."""
        if args is None:
            args=[]
        Symbol.__init__(self,dim=0,name=name,shape=shape,args=args)
253    
def _symbolDim(dim):
    """returns dim.getDim() if dim provides it, otherwise dim itself."""
    if hasattr(dim,"getDim"):
        return dim.getDim()
    return dim

class ScalarSymbol(Symbol):
    """a scalar symbol (shape ())"""
    def __init__(self,dim=3,name="scalar"):
        """creates a scalar symbol; dim is the spatial dimension or an object with getDim()"""
        Symbol.__init__(self,name=name,dim=_symbolDim(dim),shape=())

class VectorSymbol(Symbol):
    """a vector symbol (shape (d,))"""
    def __init__(self,dim=3,name="vector"):
        """creates a vector symbol; dim is the spatial dimension or an object with getDim()"""
        n=_symbolDim(dim)
        Symbol.__init__(self,name=name,dim=n,shape=(n,))

class TensorSymbol(Symbol):
    """a tensor symbol of order 2 (shape (d,d))"""
    def __init__(self,dim=3,name="tensor"):
        """creates a tensor symbol; dim is the spatial dimension or an object with getDim()"""
        n=_symbolDim(dim)
        Symbol.__init__(self,name=name,dim=n,shape=(n,)*2)

class Tensor3Symbol(Symbol):
    """a tensor symbol of order 3 (shape (d,d,d))"""
    def __init__(self,dim=3,name="tensor3"):
        """creates a tensor order 3 symbol; dim is the spatial dimension or an object with getDim()"""
        n=_symbolDim(dim)
        Symbol.__init__(self,name=name,dim=n,shape=(n,)*3)

class Tensor4Symbol(Symbol):
    """a tensor symbol of order 4 (shape (d,d,d,d))"""
    def __init__(self,dim=3,name="tensor4"):
        """creates a tensor order 4 symbol; dim is the spatial dimension or an object with getDim()"""
        n=_symbolDim(dim)
        Symbol.__init__(self,name=name,dim=n,shape=(n,)*4)
303    
304    
class Add_Symbol(Symbol):
    """symbol for the sum arg0+arg1 of two arguments of equal shape"""
    def __init__(self,arg0,arg1):
        args=[arg0,arg1]
        Symbol.__init__(self,dim=_extractDim(args),shape=_extractShape(args),args=args)
    def __str__(self):
        return "(%s+%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
    def eval(self,argval):
        left,right=self.getEvaluatedArguments(argval)
        return left+right
    def _diff(self,arg):
        d0,d1=self.getDifferentiatedArguments(arg)
        return d0+d1

class Mult_Symbol(Symbol):
    """symbol for the product arg0*arg1 of two arguments of equal shape"""
    def __init__(self,arg0,arg1):
        args=[arg0,arg1]
        Symbol.__init__(self,dim=_extractDim(args),shape=_extractShape(args),args=args)
    def __str__(self):
        return "(%s*%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
    def eval(self,argval):
        left,right=self.getEvaluatedArguments(argval)
        return left*right
    def _diff(self,arg):
        # product rule: d(u*v) = v*du + u*dv
        du,dv=self.getDifferentiatedArguments(arg)
        return self.getArgument(1)*du+self.getArgument(0)*dv

class Div_Symbol(Symbol):
    """symbol for the quotient arg0/arg1 of two arguments of equal shape"""
    def __init__(self,arg0,arg1):
        args=[arg0,arg1]
        Symbol.__init__(self,dim=_extractDim(args),shape=_extractShape(args),args=args)
    def __str__(self):
        return "(%s/%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
    def eval(self,argval):
        numerator,denominator=self.getEvaluatedArguments(argval)
        return numerator/denominator
    def _diff(self,arg):
        # quotient rule: d(u/v) = (du*v - u*dv)/(v*v)
        du,dv=self.getDifferentiatedArguments(arg)
        u=self.getArgument(0)
        v=self.getArgument(1)
        return (du*v-u*dv)/(v*v)

class Power_Symbol(Symbol):
    """symbol for the first argument raised to the power of the second argument"""
    def __init__(self,arg0,arg1):
        args=[arg0,arg1]
        Symbol.__init__(self,dim=_extractDim(args),shape=_extractShape(args),args=args)
    def __str__(self):
        return "(%s**%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
    def eval(self,argval):
        base,exponent=self.getEvaluatedArguments(argval)
        return base**exponent
    def _diff(self,arg):
        # d(u**v) = u**v * (dv*log(u) + v/u*du)
        du,dv=self.getDifferentiatedArguments(arg)
        u=self.getArgument(0)
        v=self.getArgument(1)
        return self*(dv*log(u)+v/u*du)

class Abs_Symbol(Symbol):
    """symbol for the absolute value of its argument"""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "abs(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return abs(value)
    def _diff(self,arg):
        s=sign(self.getArgument(0))
        return s*self.getDifferentiatedArguments(arg)[0]
372    
373     #=========================================================
374     # some little helpers
375     #=========================================================
def _testForZero(arg):
    """returns True if arg is considered to be zero.

    A Python int or float is zero if it equals 0. A numarray array is zero if
    all its entries are zero. Any other object (e.g. a Symbol) is never
    considered to be zero."""
    if isinstance(arg,(int,float)):
        # compare with 0: "not arg>0" would also classify negative numbers
        # as zero, making e.g. symbol*(-1.) collapse to 0
        return arg==0
    elif isinstance(arg,numarray.NumArray):
        a=abs(arg)
        # reduce over the axes until a single value is left
        while isinstance(a,numarray.NumArray): a=numarray.sometrue(a)
        return not a>0
    else:
        return False
388    
def _extractDim(args):
    """returns the common spatial dimension of the arguments in the list args.

    Only arguments providing a getDim() method are taken into account.

    @raise ValueError: if the spatial dimensions are inconsistent or if no
                       argument provides getDim()"""
    dim=None
    for a in args:
        if hasattr(a,"getDim"):
            d=a.getDim()
            if dim is None:
                dim=d
            elif dim!=d:
                raise ValueError("inconsistent spatial dimension of arguments")
    if dim is None:
        raise ValueError("cannot recover spatial dimension")
    return dim
401    
def _identifyShape(arg):
    """returns the shape of arg.

    arg.getShape() is preferred; otherwise arg.shape is used, called if it
    is a method. Objects with neither are treated as scalars (shape ())."""
    if hasattr(arg,"getShape"):
        return arg.getShape()
    if not hasattr(arg,"shape"):
        return ()
    s=arg.shape
    if callable(s):
        return s()
    return s
415    
def _extractShape(args):
    """extracts the common shape of the list of arguments args.

    @raise ValueError: if the shapes are inconsistent or args is empty"""
    shape=None
    for a in args:
        a_shape=_identifyShape(a)
        if shape is None:
            shape=a_shape
        elif shape!=a_shape:
            raise ValueError("inconsistent shape")
    if shape is None:
        raise ValueError("cannot recover shape")
    return shape
426    
def _matchShape(args,shape=None):
    """returns the list of arguments args as objects which all have the shape shape.

    If shape is not given, the longest shape among args is used (the first
    one if several have that length). Only scalar arguments (shape ()) can be
    adapted; they are replaced by outer(arg,numarray.ones(shape)).

    @raise ValueError: if an argument cannot be adapted to shape"""
    # identify the list of shapes:
    arg_shapes=[]
    for a in args: arg_shapes.append(_identifyShape(a))
    if not arg_shapes:
        return []
    if shape is None:
        # longest shape; max() on tuples would compare lexicographically
        shape=arg_shapes[0]
        for s in arg_shapes[1:]:
            if len(s)>len(shape): shape=s
    out=[]
    for a,a_shape in zip(args,arg_shapes):
        if a_shape==shape:
            out.append(a)
        elif len(shape)>0 and len(a_shape)==0:
            out.append(outer(a,numarray.ones(shape)))
        else:
            raise ValueError("cannot adopt shape of %s to %s"%(str(a),str(shape)))
    return out
449     #=========================================================
450     # wrapper for various mathematical functions:
451     #=========================================================
def diff(arg,dep):
    """returns the derivative of arg with respect to dep.

    A Symbol is differentiated via arg.diff(dep). Any other object with a
    shape (attribute or method) gives a zero numarray array of that shape;
    everything else gives 0."""
    if isinstance(arg,Symbol):
        return arg.diff(dep)
    if not hasattr(arg,"shape"):
        return 0
    shape=arg.shape
    if callable(shape):
        shape=shape()
    return numarray.zeros(shape,numarray.Float)
464    
def exp(arg):
    """
    @brief applies the exponential function to arg
    @param arg (input): argument; a Symbol, an object with an exp() method,
                        or a value accepted by numarray.exp
    """
    if isinstance(arg,Symbol):
        return Exp_Symbol(arg)
    elif hasattr(arg,"exp"):
        return arg.exp()
    else:
        return numarray.exp(arg)

class Exp_Symbol(Symbol):
    """symbol representing the exponential function applied to its argument"""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "exp(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return exp(self.getEvaluatedArguments(argval)[0])
    def _diff(self,arg):
        # d exp(u) = exp(u)*du
        return self*self.getDifferentiatedArguments(arg)[0]
487    
def sqrt(arg):
    """
    @brief applies the square root function to arg
    @param arg (input): argument; a Symbol, an object with a sqrt() method,
                        or a value accepted by numarray.sqrt
    """
    if isinstance(arg,Symbol):
        return Sqrt_Symbol(arg)
    elif hasattr(arg,"sqrt"):
        return arg.sqrt()
    else:
        return numarray.sqrt(arg)

class Sqrt_Symbol(Symbol):
    """symbol representing the square root of its argument"""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "sqrt(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return sqrt(self.getEvaluatedArguments(argval)[0])
    def _diff(self,arg):
        # d sqrt(u) = du/(2*sqrt(u)) -- positive factor; the old -0.5 had the wrong sign.
        # The quotient is built directly so it does not rely on the
        # reflected division number/Symbol.
        du=self.getDifferentiatedArguments(arg)[0]
        if _testForZero(du):
            return 0
        a=_matchShape([du,2.*self])
        return Div_Symbol(a[0],a[1])
510    
def log(arg):
    """
    @brief applies the natural logarithm (base exp(1.)) to arg
    @param arg (input): argument; a Symbol, an object with a log() method,
                        or a value accepted by numarray.log
    """
    if isinstance(arg,Symbol):
        return Log_Symbol(arg)
    if hasattr(arg,"log"):
        return arg.log()
    return numarray.log(arg)

class Log_Symbol(Symbol):
    """symbol representing the natural logarithm of its argument"""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "log(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return log(value)
    def _diff(self,arg):
        # d log(u) = du/u
        du=self.getDifferentiatedArguments(arg)[0]
        return du/self.getArgument(0)
533 jgs 102
def sin(arg):
    """
    @brief applies the sine function to arg
    @param arg (input): argument; a Symbol, an object with a sin() method,
                        or a value accepted by numarray.sin
    """
    if isinstance(arg,Symbol):
        return Sin_Symbol(arg)
    elif hasattr(arg,"sin"):
        return arg.sin()
    else:
        return numarray.sin(arg)

class Sin_Symbol(Symbol):
    """symbol representing the sine of its argument"""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "sin(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return sin(self.getEvaluatedArguments(argval)[0])
    def _diff(self,arg):
        # d sin(u) = cos(u)*du
        return cos(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
556    
def cos(arg):
    """
    @brief applies the cosine function to arg
    @param arg (input): argument; a Symbol, an object with a cos() method,
                        or a value accepted by numarray.cos
    """
    if isinstance(arg,Symbol):
        return Cos_Symbol(arg)
    elif hasattr(arg,"cos"):
        return arg.cos()
    else:
        return numarray.cos(arg)

class Cos_Symbol(Symbol):
    """symbol representing the cosine of its argument"""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "cos(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return cos(self.getEvaluatedArguments(argval)[0])
    def _diff(self,arg):
        # d cos(u) = -sin(u)*du
        return -sin(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
579    
def tan(arg):
    """
    @brief applies the tangent function to arg
    @param arg (input): argument; a Symbol, an object with a tan() method,
                        or a value accepted by numarray.tan
    """
    if isinstance(arg,Symbol):
        return Tan_Symbol(arg)
    elif hasattr(arg,"tan"):
        return arg.tan()
    else:
        return numarray.tan(arg)

class Tan_Symbol(Symbol):
    """symbol representing the tangent of its argument"""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "tan(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return tan(self.getEvaluatedArguments(argval)[0])
    def _diff(self,arg):
        # d tan(u) = du/cos(u)**2
        s=cos(self.getArgument(0))
        return 1./(s*s)*self.getDifferentiatedArguments(arg)[0]
603    
def sign(arg):
    """
    @brief applies the sign function to arg: 1 where arg>0, -1 where arg<0 and 0 elsewhere
    @param arg (input): argument; a Symbol, an object with a sign() method,
                        an array with a shape attribute, or a Python number
    """
    if isinstance(arg,Symbol):
        return Sign_Symbol(arg)
    elif hasattr(arg,"sign"):
        return arg.sign()
    elif hasattr(arg,"shape"):
        return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))- \
               numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
    else:
        # plain Python numbers have no shape attribute
        if arg>0:
            return 1.
        elif arg<0:
            return -1.
        else:
            return 0.

class Sign_Symbol(Symbol):
    """symbol representing the sign of the argument (its derivative is the default 0)"""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "sign(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return sign(self.getEvaluatedArguments(argval)[0])
625    
def maxval(arg):
    """
    @brief returns the maximum value of argument arg
    @param arg (input): argument; a Symbol, an object with a maxval() or
                        max() method, or any other value (returned unchanged)
    """
    if isinstance(arg,Symbol):
        return Max_Symbol(arg)
    elif hasattr(arg,"maxval"):
        return arg.maxval()
    elif hasattr(arg,"max"):
        return arg.max()
    else:
        return arg

class Max_Symbol(Symbol):
    """symbol representing the maximum value of its argument"""
    def __init__(self,arg):
        # NOTE(review): the symbol keeps the shape of arg although max() of a
        # numarray array presumably yields a single value -- confirm the intended shape
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "maxval(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return maxval(self.getEvaluatedArguments(argval)[0])
648    
def minval(arg):
    """
    @brief returns the minimum value of argument arg
    @param arg (input): argument; a Symbol, an object with a minval() or
                        min() method, or any other value (returned unchanged)
    """
    if isinstance(arg,Symbol):
        return Min_Symbol(arg)
    elif hasattr(arg,"minval"):
        # test for the method which is actually called (not "maxval")
        return arg.minval()
    elif hasattr(arg,"min"):
        return arg.min()
    else:
        return arg

class Min_Symbol(Symbol):
    """symbol representing the minimum value of its argument"""
    def __init__(self,arg):
        # NOTE(review): the symbol keeps the shape of arg although min() of a
        # numarray array presumably yields a single value -- confirm the intended shape
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "minval(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return minval(self.getEvaluatedArguments(argval)[0])
671    
def wherePositive(arg):
    """
    @brief returns 1 where arg is positive and 0 elsewhere
    @param arg (input): argument; a Symbol, an object with a wherePositive()
                        method, an array with a shape attribute, or a number
    """
    if _testForZero(arg):
        return 0
    elif isinstance(arg,Symbol):
        return WherePositive_Symbol(arg)
    elif hasattr(arg,"wherePositive"):
        return arg.wherePositive()
    elif hasattr(arg,"shape"):
        return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))
    else:
        if arg>0:
            return 1.
        else:
            return 0.

class WherePositive_Symbol(Symbol):
    """symbol representing the wherePositive function (its derivative is the default 0)"""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "wherePositive(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return wherePositive(self.getEvaluatedArguments(argval)[0])
699    
def whereNegative(arg):
    """
    @brief returns 1 where arg is negative and 0 elsewhere
    @param arg (input): argument; a Symbol, an object with a whereNegative()
                        method, an array with a shape attribute, or a number
    """
    if _testForZero(arg):
        return 0
    elif isinstance(arg,Symbol):
        return WhereNegative_Symbol(arg)
    elif hasattr(arg,"whereNegative"):
        return arg.whereNegative()
    elif hasattr(arg,"shape"):
        return numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
    else:
        if arg<0:
            return 1.
        else:
            return 0.

class WhereNegative_Symbol(Symbol):
    """symbol representing the whereNegative function (its derivative is the default 0)"""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "whereNegative(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return whereNegative(self.getEvaluatedArguments(argval)[0])
727    
def outer(arg0,arg1):
    """returns the outer product of arg0 and arg1.

    0 is returned if one of the arguments is zero, an Outer_Symbol if one of
    them is a Symbol, the plain product if one of them has shape (), and
    numarray.outer for two numarray arrays. Otherwise only two rank 1
    objects are supported; the result is assembled entry by entry into an
    escript.Data object on the FunctionSpace of arg1.

    @raise ValueError: for any other combination of ranks"""
    if _testForZero(arg0) or _testForZero(arg1):
        return 0
    if isinstance(arg0,Symbol) or isinstance(arg1,Symbol):
        return Outer_Symbol(arg0,arg1)
    if _identifyShape(arg0)==() or _identifyShape(arg1)==():
        return arg0*arg1
    if isinstance(arg0,numarray.NumArray) and isinstance(arg1,numarray.NumArray):
        return numarray.outer(arg0,arg1)
    if not (arg0.getRank()==1 and arg1.getRank()==1):
        raise ValueError("outer is not fully implemented yet.")
    n0=arg0.getShape()[0]
    n1=arg1.getShape()[0]
    result=escript.Data(0,(n0,n1),arg1.getFunctionSpace())
    for i in range(n0):
        for j in range(n1):
            result[i,j]=arg0[i]*arg1[j]
    return result

class Outer_Symbol(Symbol):
    """symbol representing the outer product of its two arguments"""
    def __init__(self,arg0,arg1):
        args=[arg0,arg1]
        # the result shape is the concatenation of both argument shapes
        shape=tuple(list(_identifyShape(arg0))+list(_identifyShape(arg1)))
        Symbol.__init__(self,shape=shape,dim=_extractDim(args),args=args)
    def __str__(self):
        return "outer(%s,%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
    def eval(self,argval):
        left,right=self.getEvaluatedArguments(argval)
        return outer(left,right)
    def _diff(self,arg):
        # product rule
        # NOTE(review): for an arg of rank>0 the two terms seem to place the
        # axes of arg at different positions -- confirm the intended index order
        d0,d1=self.getDifferentiatedArguments(arg)
        return outer(d0,self.getArgument(1))+outer(self.getArgument(0),d1)
762    
def interpolate(arg,where):
    """
    @brief interpolates the function into the FunctionSpace where.

    @param arg interpolant
    @param where FunctionSpace to interpolate to
    @return 0 if arg is zero, an Interpolated_Symbol if arg is a Symbol,
            escript.Data(arg,where) otherwise
    """
    if _testForZero(arg):
        return 0
    elif isinstance(arg,Symbol):
        return Interpolated_Symbol(arg,where)
    else:
        return escript.Data(arg,where)

class Interpolated_Symbol(Symbol):
    """symbol representing the interpolation of its argument into a FunctionSpace"""
    def __init__(self,arg,where):
        # arguments: [interpolant, target FunctionSpace]
        Symbol.__init__(self,shape=_identifyShape(arg),dim=_extractDim([arg]),args=[arg,where])
    def __str__(self):
        return "interpolated(%s)"%(str(self.getArgument(0)))
    def eval(self,argval):
        a=self.getEvaluatedArguments(argval)
        return interpolate(a[0],where=self.getArgument(1))
    def _diff(self,arg):
        # assumes interpolation is linear: the derivative of the interpolant is interpolated
        a=self.getDifferentiatedArguments(arg)
        return interpolate(a[0],where=self.getArgument(1))
789    
def grad(arg,where=None):
    """
    @brief returns the spatial gradient of arg at where.

    @param arg: Data object representing the function which gradient to be calculated.
    @param where: FunctionSpace in which the gradient will be calculated. If not present or
                  None an appropriate default is used.
    """
    if _testForZero(arg):
        return 0
    elif isinstance(arg,Symbol):
        return Grad_Symbol(arg,where)
    elif hasattr(arg,"grad"):
        if where is None:
            return arg.grad()
        else:
            return arg.grad(where)
    else:
        return arg*0.

class Grad_Symbol(Symbol):
    """symbol representing the gradient of the argument"""
    def __init__(self,arg,where=None):
        d=_extractDim([arg])
        # the gradient appends a trailing axis of length d to the shape of arg
        s=tuple(list(_identifyShape(arg))+[d])
        Symbol.__init__(self,shape=s,dim=d,args=[arg,where])
    def __str__(self):
        return "grad(%s)"%(str(self.getArgument(0)))
    def eval(self,argval):
        a=self.getEvaluatedArguments(argval)
        return grad(a[0],where=self.getArgument(1))
    def _diff(self,arg):
        a=self.getDifferentiatedArguments(arg)
        return grad(a[0],where=self.getArgument(1))
824    
def integrate(arg,where=None):
    """
    @brief returns the integral of the function represented by arg over its domain.

    @param arg: Data object (or Symbol) representing the function which is integrated.
    @param where: FunctionSpace in which the integral is calculated. If not present or
                  None an appropriate default is used.
    @return: 0 if arg is zero, an Integral_Symbol if arg is a Symbol,
             arg.integrate()[0] for rank 0 arguments and arg.integrate() otherwise
    """
    if _testForZero(arg):
        return 0
    elif isinstance(arg,Symbol):
        return Integral_Symbol(arg,where)
    else:
        if where is not None: arg=escript.Data(arg,where)
        if arg.getRank()==0:
            return arg.integrate()[0]
        else:
            return arg.integrate()

class Integral_Symbol(Symbol):
    """symbol representing the integral of the argument"""
    def __init__(self,arg,where=None):
        # spatial dimension 0 (as for Float_Symbol) and the shape of the integrand;
        # arguments: [integrand, FunctionSpace or None]
        Symbol.__init__(self,shape=_identifyShape(arg),dim=0,args=[arg,where])
    def __str__(self):
        return "integral(%s)"%(str(self.getArgument(0)))
    def eval(self,argval):
        a=self.getEvaluatedArguments(argval)
        return integrate(a[0],where=self.getArgument(1))
    def _diff(self,arg):
        a=self.getDifferentiatedArguments(arg)
        return integrate(a[0],where=self.getArgument(1))
856    
857     #=============================
858     #
859     # wrapper for various functions: if the argument has attribute the function name
860     # as an argument it calls the correspong methods. Otherwise the coresponsing numarray
861     # function is called.
862     #
863     # functions involving the underlying Domain:
864     #
865    
866    
867     # functions returning Data objects:
868    
def transpose(arg,axis=None):
    """
    @brief returns the transpose of arg.

    @param arg: Symbol, escript.Data object (rank 2 only) or numarray array
    @param axis: passed on to the transposition; if None, rank(arg)//2 is used
                 (with rank 0 if the rank of arg cannot be determined)
    @raise ValueError: if arg is an escript.Data object whose rank is not 2
    """
    if axis is None:
        r=0
        if hasattr(arg,"getRank"): r=arg.getRank()
        if hasattr(arg,"rank"): r=arg.rank
        axis=r//2
    if isinstance(arg,Symbol):
        # pass the axis itself (r is undefined when axis is given explicitly)
        # NOTE(review): Transpose_Symbol is not defined in the visible part of this file
        return Transpose_Symbol(arg,axis=axis)
    if isinstance(arg,escript.Data):
        # hack for transpose: rank 2 objects are copied entry by entry
        r=arg.getRank()
        if r!=2: raise ValueError("Tranpose only avalaible for rank 2 objects")
        s=arg.getShape()
        out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())
        for i in range(s[0]):
            for j in range(s[1]):
                out[j,i]=arg[i,j]
        return out
    else:
        # NOTE(review): confirm numarray.transpose accepts the keyword axis
        return numarray.transpose(arg,axis=axis)
896 jgs 82
def trace(arg,axis0=0,axis1=1):
    """
    @brief returns the trace of arg with respect to the axes axis0 and axis1.

    For a rank 2 object a this is sum_i a[i,i].

    @param arg a Symbol, escript.Data or numarray object
    @param axis0 first axis of the trace
    @param axis1 second axis of the trace
    @exception ValueError if arg is an escript.Data object that is not of
                          rank 2, is not square, or the axes are illegal
    """
    if isinstance(arg,Symbol):
        return Trace_Symbol(arg,axis0=axis0,axis1=axis1)
    elif isinstance(arg,escript.Data):
        # hack for trace: only rank 2 escript.Data objects are supported
        if arg.getRank()!=2:
            raise ValueError("trace only available for rank 2 objects")
        if (axis0,axis1) not in ((0,1),(1,0)):
            raise ValueError("illegal axis in trace")
        s=arg.getShape()
        if s[0]!=s[1]:
            raise ValueError("illegal axis in trace")
        out=escript.Scalar(0.,arg.getFunctionSpace())
        # sum the diagonal only (previously all entries were summed)
        for i in range(s[0]):
            out+=arg[i,i]
        return out
    else:
        # numarray.trace names the two axes axis1 and axis2
        return numarray.trace(arg,axis1=axis0,axis2=axis1)
921 jgs 82
922 jgs 123
923    
class Trace_Symbol(Symbol):
    """
    @brief placeholder for the trace of a Symbol; not implemented yet.

    This was previously declared as a function taking a single argument,
    so any call from trace() with axis keywords failed with a TypeError.
    It now reports clearly that the operation is missing.
    """
    def __init__(self,arg,axis0=0,axis1=1):
        raise NotImplementedError("trace is not implemented for Symbols yet")
926    
def length(arg):
    """
    @brief returns the length of arg: the square root of the sum of the
           squares of all its components.

    @param arg an escript.Data object of rank 0 to 4, a python float/int, or
               an object supporting arg**2 and .sum() (e.g. a numarray array)
    @return an escript.Data object for escript.Data input (an empty
            escript.Data if arg is empty), a number otherwise
    @exception SystemError if arg is an escript.Data object of rank > 4
    """
    if isinstance(arg,escript.Data):
        if arg.isEmpty(): return escript.Data()
        rank=arg.getRank()
        if rank==0:
            return abs(arg)
        if rank>4:
            raise SystemError("length is not been implemented yet")
        shape=arg.getShape()
        # accumulator; not called "sum" to avoid hiding the module's sum()
        out=escript.Scalar(0,arg.getFunctionSpace())
        if rank==1:
            for i in range(shape[0]):
                out+=arg[i]**2
        elif rank==2:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    out+=arg[i,j]**2
        elif rank==3:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    for k in range(shape[2]):
                        out+=arg[i,j,k]**2
        else:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    for k in range(shape[2]):
                        for l in range(shape[3]):
                            out+=arg[i,j,k,l]**2
        return sqrt(out)
    elif isinstance(arg,float) or isinstance(arg,int):
        # python numbers have no sum() method; consistent with L2 and Lsup
        return abs(arg)
    else:
        return sqrt((arg**2).sum())
968    
def deviator(arg):
    """
    @brief returns the deviatoric part of the square rank 2 object arg,
           i.e. arg - trace(arg)/d * kronecker(d) with d the number of rows.

    @param arg escript.Data object, or an object with a shape attribute
               (e.g. a numarray array), of rank 2 and square shape
    @exception ValueError if arg is not of rank 2 or not square
    """
    if isinstance(arg,escript.Data):
        s=arg.getShape()
    else:
        s=arg.shape
    if len(s)!=2:
        raise ValueError("Deviator requires rank 2 object")
    d=s[0]
    if d!=s[1]:
        raise ValueError("Deviator requires a square matrix")
    # subtract the isotropic (mean diagonal) part
    return arg-(1./d)*trace(arg)*kronecker(d)
984    
def inner(arg0,arg1):
    """
    @brief returns the inner product of arg0 and arg1: the sum over all
           components of arg0[...]*arg1[...].

    If only one argument is an escript.Data object, the other one is converted
    to escript.Data on the same function space.

    @param arg0, arg1 objects of the same shape and rank 0 to 4; at least one
                      of them must be an escript.Data object
    @return arg0*arg1 for rank 0, otherwise an escript.Data scalar
    @exception ValueError if arg0 and arg1 have different shapes
    @exception SystemError if the rank is greater than 4
    """
    # bring both arguments onto the same function space (as in matrixmult)
    if isinstance(arg1,escript.Data) and not isinstance(arg0,escript.Data):
        arg0=escript.Data(arg0,arg1.getFunctionSpace())
    elif isinstance(arg0,escript.Data) and not isinstance(arg1,escript.Data):
        arg1=escript.Data(arg1,arg0.getFunctionSpace())
    shape=arg0.getShape()
    if shape!=arg1.getShape():
        raise ValueError("inner: arguments must have the same shape")
    # previously every rank test used the undefined name arg
    rank=arg0.getRank()
    if rank==0:
        return arg0*arg1
    if rank>4:
        raise SystemError("inner is not been implemented yet")
    out=escript.Scalar(0,arg0.getFunctionSpace())
    if rank==1:
        for i in range(shape[0]):
            out+=arg0[i]*arg1[i]
    elif rank==2:
        for i in range(shape[0]):
            for j in range(shape[1]):
                out+=arg0[i,j]*arg1[i,j]
    elif rank==3:
        for i in range(shape[0]):
            for j in range(shape[1]):
                for k in range(shape[2]):
                    out+=arg0[i,j,k]*arg1[i,j,k]
    else:
        for i in range(shape[0]):
            for j in range(shape[1]):
                for k in range(shape[2]):
                    for l in range(shape[3]):
                        out+=arg0[i,j,k,l]*arg1[i,j,k,l]
    return out
1019    
def matrixmult(arg0,arg1):
    """
    @brief returns the matrix product of arg0 and arg1.

    For two numarray arrays numarray.dot is used. Otherwise a non-escript.Data
    argument is converted to escript.Data on the function space of the other
    argument; in that case only the product of a rank 2 and a rank 1 object is
    supported so far.

    @param arg0 left factor
    @param arg1 right factor
    @exception ValueError if the inner dimensions of arg0 and arg1 differ
    @exception SystemError for rank combinations that are not implemented
    """
    if isinstance(arg1,numarray.NumArray) and isinstance(arg0,numarray.NumArray):
        # the product was previously computed but not returned
        return numarray.dot(arg0,arg1)
    # escript.matmult(arg0,arg1)
    if isinstance(arg1,escript.Data) and not isinstance(arg0,escript.Data):
        arg0=escript.Data(arg0,arg1.getFunctionSpace())
    elif isinstance(arg0,escript.Data) and not isinstance(arg1,escript.Data):
        arg1=escript.Data(arg1,arg0.getFunctionSpace())
    if arg0.getRank()==2 and arg1.getRank()==1:
        s0=arg0.getShape()
        if s0[1]!=arg1.getShape()[0]:
            raise ValueError("matrixmult: inner dimensions do not match")
        out=escript.Data(0,(s0[0],),arg0.getFunctionSpace())
        for i in range(s0[0]):
            for j in range(s0[1]):
                out[i]+=arg0[i,j]*arg1[j]
        return out
    else:
        raise SystemError("matrixmult is not fully implemented yet!")
1038     #=========================================================
1039 jgs 102 # reduction operations:
1040 jgs 123 #=========================================================
def sum(arg):
    """
    @brief returns the sum of all components of arg.

    Note that this module-level function hides the python builtin sum.

    @param arg a python float/int or an object providing a sum() method
               (e.g. a numarray array)
    @return arg itself for python numbers, arg.sum() otherwise
    """
    if isinstance(arg,float) or isinstance(arg,int):
        # python numbers have no sum() method; consistent with sup and inf
        return arg
    return arg.sum()
1048    
def sup(arg):
    """
    @brief returns the supremum of arg.

    @param arg escript.Data object, python float/int, or an object with a
               max() method (e.g. a numarray array)
    @return arg.sup() for escript.Data, arg itself for python numbers and
            arg.max() otherwise
    """
    if isinstance(arg,escript.Data):
        return arg.sup()
    if isinstance(arg,(float,int)):
        # a plain number is its own supremum
        return arg
    return arg.max()
1061 jgs 82
def inf(arg):
    """
    @brief returns the infimum of arg.

    @param arg escript.Data object, python float/int, or an object with a
               min() method (e.g. a numarray array)
    @return arg.inf() for escript.Data, arg itself for python numbers and
            arg.min() otherwise
    """
    if isinstance(arg,escript.Data):
        return arg.inf()
    if isinstance(arg,(float,int)):
        # a plain number is its own infimum
        return arg
    return arg.min()
1074 jgs 82
def L2(arg):
    """
    @brief returns the L2-norm of arg.

    @param arg escript.Data object, python float/int or numarray object
    @return arg.L2() for escript.Data, abs(arg) for python numbers and the
            square root of the sum of the squares of all components otherwise
    """
    if isinstance(arg,escript.Data):
        return arg.L2()
    elif isinstance(arg,float) or isinstance(arg,int):
        return abs(arg)
    else:
        # flatten first: dot of two rank 2 arrays would be a matrix product
        # rather than the sum of squares (also fixes the numarry typo)
        flat=numarray.ravel(arg)
        return numarray.sqrt(numarray.dot(flat,flat))
1087 jgs 82
def Lsup(arg):
    """
    @brief returns the maximum absolute value of arg.

    @param arg escript.Data object, python float/int or numarray object
    @return arg.Lsup() for escript.Data, abs(arg) for python numbers and the
            largest absolute component value otherwise
    """
    if isinstance(arg,escript.Data):
        return arg.Lsup()
    elif isinstance(arg,float) or isinstance(arg,int):
        return abs(arg)
    else:
        # flatten so that max compares scalars for arrays of any rank
        # (for rank >= 2 max would otherwise compare rows)
        return max(numarray.ravel(numarray.abs(arg)))
1100 jgs 82
def dot(arg0,arg1):
    """
    @brief returns the dot product of arg0 and arg1.

    If arg0 is an escript.Data object, arg0.dot(arg1) is returned. Otherwise,
    if arg1 is an escript.Data object, arg1.dot(arg0) is returned. Otherwise
    numarray.dot(arg0,arg1) is used.

    @param arg0 first factor
    @param arg1 second factor
    """
    if isinstance(arg0,escript.Data):
        return arg0.dot(arg1)
    if isinstance(arg1,escript.Data):
        # NOTE(review): the arguments are swapped here, which is only correct
        # if Data.dot is symmetric in its arguments (e.g. for two vectors);
        # confirm for matrix arguments
        return arg1.dot(arg0)
    return numarray.dot(arg0,arg1)
1113 jgs 113
def kronecker(d):
    """
    @brief returns the identity matrix (Kronecker delta) as a numarray array.

    @param d dimension, or an object with a getDim() method whose result is
             used as the dimension
    """
    dim=d
    if hasattr(d,"getDim"):
        dim=d.getDim()
    return numarray.identity(dim)
1119 jgs 117
def unit(i,d):
    """
    @brief returns the unit vector of dimension d whose only nonzero
           component (equal to 1.0) is at index i.

    @param i index of the nonzero component
    @param d dimension
    @return numarray array of type numarray.Float and shape (d,)
    """
    vec=numarray.zeros((d,),numarray.Float)
    vec[i]=1.0
    return vec
1129 jgs 122
1130     #
1131 jgs 123 # ============================================
1132     # testing
1133     # ============================================
1134    
if __name__=="__main__":
    # Manual smoke test: builds symbolic expressions in u and v and prints
    # each expression followed by its derivative with respect to u.
    u=ScalarSymbol(dim=2,name="u")
    v=ScalarSymbol(dim=2,name="v")
    # NOTE(review): the scalar symbols above are overwritten right away by
    # vector symbols, so only the vector versions are exercised below.
    v=VectorSymbol(2,"v")
    u=VectorSymbol(2,"u")


    # addition
    print u+5,(u+5).diff(u)
    print 5+u,(5+u).diff(u)
    print u+v,(u+v).diff(u)
    print v+u,(v+u).diff(u)

    # multiplication
    print u*5,(u*5).diff(u)
    print 5*u,(5*u).diff(u)
    print u*v,(u*v).diff(u)
    print v*u,(v*u).diff(u)

    # subtraction
    print u-5,(u-5).diff(u)
    print 5-u,(5-u).diff(u)
    print u-v,(u-v).diff(u)
    print v-u,(v-u).diff(u)

    # division
    print u/5,(u/5).diff(u)
    print 5/u,(5/u).diff(u)
    print u/v,(u/v).diff(u)
    print v/u,(v/u).diff(u)

    # power
    print u**5,(u**5).diff(u)
    print 5**u,(5**u).diff(u)
    print u**v,(u**v).diff(u)
    print v**u,(v**u).diff(u)

    # unary functions
    print exp(u),exp(u).diff(u)
    print sqrt(u),sqrt(u).diff(u)
    print log(u),log(u).diff(u)
    print sin(u),sin(u).diff(u)
    print cos(u),cos(u).diff(u)
    print tan(u),tan(u).diff(u)
    print sign(u),sign(u).diff(u)
    print abs(u),abs(u).diff(u)
    print wherePositive(u),wherePositive(u).diff(u)
    print whereNegative(u),whereNegative(u).diff(u)
    print maxval(u),maxval(u).diff(u)
    print minval(u),minval(u).diff(u)

    g=grad(u)
    print diff(5*g,g)
    # NOTE(review): the value of this expression is computed and discarded.
    # kronecker(3) is 3x3 while u was created with dimension 2, so this looks
    # like a shape mismatch; confirm the intent.
    4*(g+transpose(g))/2+6*trace(g)*kronecker(3)
1168     print sqrt(u),sqrt(u).diff(u)
1169     print log(u),log(u).diff(u)
1170     print sin(u),sin(u).diff(u)
1171     print cos(u),cos(u).diff(u)
1172     print tan(u),tan(u).diff(u)
1173     print sign(u),sign(u).diff(u)
1174     print abs(u),abs(u).diff(u)
1175     print wherePositive(u),wherePositive(u).diff(u)
1176     print whereNegative(u),whereNegative(u).diff(u)
1177     print maxval(u),maxval(u).diff(u)
1178     print minval(u),minval(u).diff(u)
1179    
1180     g=grad(u)
1181     print diff(5*g,g)
1182     4*(g+transpose(g))/2+6*trace(g)*kronecker(3)
1183     #
1184 jgs 122 # $Log$
1185 jgs 123 # Revision 1.11 2005/07/08 04:07:35 jgs
1186     # Merge of development branch back to main trunk on 2005-07-08
1187     #
1188 jgs 122 # Revision 1.10 2005/06/09 05:37:59 jgs
1189     # Merge of development branch back to main trunk on 2005-06-09
1190     #
1191 jgs 123 # Revision 1.2.2.17 2005/07/07 07:28:58 gross
1192     # some stuff added to util.py to improve functionality
1193     #
1194     # Revision 1.2.2.16 2005/06/30 01:53:55 gross
1195     # a bug in coloring fixed
1196     #
1197     # Revision 1.2.2.15 2005/06/29 02:36:43 gross
1198     # Symbols have been introduced and some function clarified. needs much more work
1199     #
1200 jgs 122 # Revision 1.2.2.14 2005/05/20 04:05:23 gross
1201     # some work on a darcy flow started
1202     #
1203     # Revision 1.2.2.13 2005/03/16 05:17:58 matt
1204     # Implemented unit(idx, dim) to create cartesian unit basis vectors to
1205     # complement kronecker(dim) function.
1206     #
1207     # Revision 1.2.2.12 2005/03/10 08:14:37 matt
1208     # Added non-member Linf utility function to complement Data::Linf().
1209     #
1210     # Revision 1.2.2.11 2005/02/17 05:53:25 gross
1211     # some bug in saveDX fixed: in fact the bug was in
1212     # DataC/getDataPointShape
1213     #
1214     # Revision 1.2.2.10 2005/01/11 04:59:36 gross
1215     # automatic interpolation in integrate switched off
1216     #
1217     # Revision 1.2.2.9 2005/01/11 03:38:13 gross
1218     # Bug in Data.integrate() fixed for the case of rank 0. The problem is not totallly resolved as the method should return a scalar rather than a numarray object in the case of rank 0. This problem is fixed by the util.integrate wrapper.
1219     #
1220     # Revision 1.2.2.8 2005/01/05 04:21:41 gross
1221     # FunctionSpace checking/matchig in slicing added
1222     #
1223     # Revision 1.2.2.7 2004/12/29 05:29:59 gross
1224     # AdvectivePDE successfully tested for Peclet number 1000000. there is still a problem with setValue and Data()
1225     #
1226     # Revision 1.2.2.6 2004/12/24 06:05:41 gross
1227     # some changes in linearPDEs to add AdevectivePDE
1228     #
1229     # Revision 1.2.2.5 2004/12/17 00:06:53 gross
1230     # mk sets ESYS_ROOT is undefined
1231     #
1232     # Revision 1.2.2.4 2004/12/07 03:19:51 gross
1233     # options for GMRES and PRES20 added
1234     #
1235     # Revision 1.2.2.3 2004/12/06 04:55:18 gross
1236     # function wraper extended
1237     #
1238     # Revision 1.2.2.2 2004/11/22 05:44:07 gross
1239     # a few more unitary functions have been added but not implemented in Data yet
1240     #
1241     # Revision 1.2.2.1 2004/11/12 06:58:15 gross
1242     # a lot of changes to get the linearPDE class running: most important change is that there is no matrix format exposed to the user anymore. the format is chosen by the Domain according to the solver and symmetry
1243     #
1244     # Revision 1.2 2004/10/27 00:23:36 jgs
1245     # fixed minor syntax error
1246     #
1247     # Revision 1.1.1.1 2004/10/26 06:53:56 jgs
1248     # initial import of project esys2
1249     #
1250     # Revision 1.1.2.3 2004/10/26 06:43:48 jgs
1251     # committing Lutz's and Paul's changes to brach jgs
1252     #
1253     # Revision 1.1.4.1 2004/10/20 05:32:51 cochrane
1254     # Added incomplete Doxygen comments to files, or merely put the docstrings that already exist into Doxygen form.
1255     #
1256     # Revision 1.1 2004/08/05 03:58:27 gross
1257     # Bug in Assemble_NodeCoordinates fixed
1258     #
1259     #

Properties

Name Value
svn:eol-style native
svn:keywords Author Date Id Revision

  ViewVC Help
Powered by ViewVC 1.1.26