/[escript]/trunk/esys2/escript/py_src/util.py
ViewVC logotype

Contents of /trunk/esys2/escript/py_src/util.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 147 - (show annotations)
Fri Aug 12 01:45:47 2005 UTC (14 years ago) by jgs
File MIME type: text/x-python
File size: 41602 byte(s)
Merge of development branch dev-02 back to main trunk on 2005-08-12

1 # $Id$
2
3 ## @file util.py
4
5 """
6 @brief Utility functions for escript
7
8 TODO for Data:
9
10 * binary operations @: (a@b)[:,*]=a[:]@b[:,*] if rank(a)<rank(b)
11 @=+,-,*,/,** (a@b)[:]=a[:]@b[:] if rank(a)=rank(b)
12 (a@b)[*,:]=a[*,:]@b[:] if rank(a)>rank(b)
13
14 * implementation of outer outer(a,b)[:,*]=a[:]*b[*]
15 * trace: trace(arg,axis0=a0,axis1=a1)(:,&,*)=sum_i trace(:,i,&,i,*) (i are at index a0 and a1)
16 """
17
18 import numarray
19 import escript
20
#
# escript constants (have to be consistent with utilC.h)
#
UNKNOWN = -1
EPSILON = 1.e-15
Pi = numarray.pi

# some solver options:
NO_REORDERING = 0
MINIMUM_FILL_IN = 1
NESTED_DISSECTION = 2

# solver methods:
DEFAULT_METHOD = 0
DIRECT = 1
CHOLEVSKY = 2
PCG = 3
CR = 4
CGS = 5
BICGSTAB = 6
SSOR = 7
ILU0 = 8
ILUT = 9
JACOBI = 10
GMRES = 11
PRES20 = 12

# option keys:
METHOD_KEY = "method"
SYMMETRY_KEY = "symmetric"
TOLERANCE_KEY = "tolerance"

# supported file formats (JPG and JPEG share the same code):
VRML = 1
PNG = 2
JPEG = 3
JPG = JPEG
PS = 4
OOGL = 5
BMP = 6
TIFF = 7
OPENINVENTOR = 8
RENDERMAN = 9
PNM = 10
62
63 #===========================================================
64 # a simple tool box to deal with _differentials of functions
65 #===========================================================
66
class Symbol:
    """Base class of a small toolbox for symbolic differentiation.

    A symbol has a name, a shape, a spatial dimension and a list of
    arguments it depends on.  Subclasses represent expressions built from
    other symbols and override eval, _diff and __str__.
    """

    def __init__(self, name="symbol", shape=(), dim=3, args=None):
        """Creates a symbol of shape shape and spatial dimension dim.

        @param name: name of the symbol, returned by __str__
        @param shape: shape of the symbol (a tuple)
        @param dim: spatial dimension, or an object providing getDim()
        @param args: list of arguments (symbols or other objects) the symbol
                     depends on; defaults to a new empty list
        """
        if args is None:
            # a fresh list per instance instead of one shared mutable default
            args = []
        self.__args = args
        self.__name = name
        self.__shape = shape
        if hasattr(dim, "getDim"):
            self.__dim = dim.getDim()
        else:
            self.__dim = dim
        # argument dictionary and result of the last getEvaluatedArguments call
        self.__cache_val = None
        self.__cache_argval = None

    def getArgument(self, i):
        """Returns the i-th argument."""
        return self.__args[i]

    def getDim(self):
        """Returns the spatial dimension of the symbol."""
        return self.__dim

    def getRank(self):
        """Returns the rank (length of the shape) of the symbol."""
        return len(self.getShape())

    def getShape(self):
        """Returns the shape of the symbol."""
        return self.__shape

    def getEvaluatedArguments(self, argval):
        """Returns the list of arguments where every Symbol argument a is
        replaced by a.eval(argval); other arguments are returned unchanged.

        The result is cached and reused while argval compares equal to the
        dictionary passed in the previous call.
        NOTE(review): the cache keeps a reference to argval, so mutating that
        dictionary in place between calls returns stale values.
        """
        if argval == self.__cache_argval:
            return self.__cache_val
        out = []
        for a in self.__args:
            if isinstance(a, Symbol):
                out.append(a.eval(argval))
            else:
                out.append(a)
        self.__cache_argval = argval
        self.__cache_val = out
        return out

    def getDifferentiatedArguments(self, arg):
        """Returns the list of the arguments differentiated by arg; arguments
        that are not Symbols contribute 0."""
        out = []
        for a in self.__args:
            if isinstance(a, Symbol):
                out.append(a.diff(arg))
            else:
                out.append(0)
        return out

    def diff(self, arg):
        """Returns the derivative of self with respect to arg.

        If arg is self, the result is 1. for a scalar symbol and otherwise a
        numarray array of shape shape+shape with out[i+i]=1. for every index
        tuple i (zero elsewhere).  For any other arg the work is delegated to
        _diff.

        @raise ValueError: if arg is self and its rank is 5 or more
        """
        # identity test: comparing with == could dispatch to an array's __eq__
        if self is not arg:
            return self._diff(arg)
        shape = self.getShape()
        rank = len(shape)
        if rank == 0:
            return 1.
        if rank > 4:
            raise ValueError("differential support rank<5 only.")
        out = numarray.zeros(tuple(2 * list(shape)), numarray.Float)
        # enumerate every index tuple of the symbol's shape
        indices = [()]
        for n in shape:
            indices = [idx + (i,) for idx in indices for i in range(n)]
        for idx in indices:
            out[idx + idx] = 1.
        return out

    def _diff(self, arg):
        """Returns the derivative of self with respect to arg (arg is not self).
        This method is overwritten by a particular symbol; the default is 0."""
        return 0

    def eval(self, argval):
        """Returns argval[self] if self is a key of argval, otherwise self."""
        if self in argval:
            return argval[self]
        else:
            return self

    def __str__(self):
        """Returns the name of the symbol."""
        return self.__name

    def __add__(self, other):
        """Adds other to symbol self. If _testForZero(other), self is returned."""
        if _testForZero(other):
            return self
        else:
            a = _matchShape([self, other])
            return Add_Symbol(a[0], a[1])

    def __radd__(self, other):
        """Adds symbol self to other."""
        return self + other

    def __neg__(self):
        """Returns -self."""
        return self * (-1.)

    def __pos__(self):
        """Returns self."""
        return self

    def __abs__(self):
        """Returns the absolute value of self as an Abs_Symbol."""
        return Abs_Symbol(self)

    def __sub__(self, other):
        """Subtracts other from symbol self. If _testForZero(other), self is returned."""
        if _testForZero(other):
            return self
        else:
            return self + (-other)

    def __rsub__(self, other):
        """Subtracts symbol self from other."""
        return -self + other

    def __div__(self, other):
        """Divides symbol self by other."""
        if isinstance(other, Symbol):
            a = _matchShape([self, other])
            return Div_Symbol(a[0], a[1])
        else:
            return self * (1. / other)

    def __rdiv__(self, other):
        """Divides other by symbol self. If _testForZero(other), 0 is returned."""
        if _testForZero(other):
            return 0
        else:
            # other is the numerator: keep it in the first position
            a = _matchShape([other, self])
            return Div_Symbol(a[0], a[1])

    def __pow__(self, other):
        """Raises symbol self to the power of other."""
        a = _matchShape([self, other])
        return Power_Symbol(a[0], a[1])

    def __rpow__(self, other):
        """Raises other to the power of symbol self."""
        a = _matchShape([self, other])
        return Power_Symbol(a[1], a[0])

    def __mul__(self, other):
        """Multiplies symbol self by other. If _testForZero(other), 0 is returned."""
        if _testForZero(other):
            return 0
        else:
            a = _matchShape([self, other])
            return Mult_Symbol(a[0], a[1])

    def __rmul__(self, other):
        """Multiplies other by symbol self. If _testForZero(other), 0 is returned."""
        return self * other

    # true division ("/" under Python 3 or "from __future__ import division")
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __getitem__(self, sl):
        """Slicing of symbols is not supported.

        @raise NotImplementedError: always.  Returning a value here would make
        iteration over a symbol (legacy __getitem__ protocol) loop forever.
        """
        raise NotImplementedError("slicing of %s is not supported" % str(self))
248
class Float_Symbol(Symbol):
    """Symbol whose spatial dimension is fixed to 0."""

    def __init__(self, name="symbol", shape=(), args=None):
        """Creates a symbol with spatial dimension 0.

        @param name: name of the symbol
        @param shape: shape of the symbol
        @param args: list of arguments the symbol depends on (default: empty)
        """
        if args is None:
            args = []
        # forward the caller's values (they used to be replaced by the defaults)
        Symbol.__init__(self, dim=0, name=name, shape=shape, args=args)
252
class ScalarSymbol(Symbol):
    """Symbol of rank 0, i.e. of shape ()."""

    def __init__(self, dim=3, name="scalar"):
        """Creates a scalar symbol.

        @param dim: spatial dimension, or an object providing getDim()
        @param name: name of the symbol
        """
        spatial_dim = dim
        if hasattr(dim, "getDim"):
            spatial_dim = dim.getDim()
        Symbol.__init__(self, shape=(), dim=spatial_dim, name=name)
262
class VectorSymbol(Symbol):
    """Symbol of rank 1 whose single axis has the length of the spatial dimension."""

    def __init__(self, dim=3, name="vector"):
        """Creates a vector symbol of shape (d,).

        @param dim: spatial dimension d, or an object providing getDim()
        @param name: name of the symbol
        """
        spatial_dim = dim
        if hasattr(dim, "getDim"):
            spatial_dim = dim.getDim()
        Symbol.__init__(self, shape=(spatial_dim,), dim=spatial_dim, name=name)
272
class TensorSymbol(Symbol):
    """Symbol of rank 2 with both axes of the length of the spatial dimension."""

    def __init__(self, dim=3, name="tensor"):
        """Creates a tensor symbol of shape (d,d).

        @param dim: spatial dimension d, or an object providing getDim()
        @param name: name of the symbol
        """
        spatial_dim = dim
        if hasattr(dim, "getDim"):
            spatial_dim = dim.getDim()
        Symbol.__init__(self, shape=(spatial_dim,) * 2, dim=spatial_dim, name=name)
282
class Tensor3Symbol(Symbol):
    """Symbol of rank 3 with all axes of the length of the spatial dimension."""

    def __init__(self, dim=3, name="tensor3"):
        """Creates a tensor symbol of order 3 and shape (d,d,d).

        @param dim: spatial dimension d, or an object providing getDim()
        @param name: name of the symbol
        """
        spatial_dim = dim
        if hasattr(dim, "getDim"):
            spatial_dim = dim.getDim()
        Symbol.__init__(self, shape=(spatial_dim,) * 3, dim=spatial_dim, name=name)
292
class Tensor4Symbol(Symbol):
    """Symbol of rank 4 with all axes of the length of the spatial dimension."""

    def __init__(self, dim=3, name="tensor4"):
        """Creates a tensor symbol of order 4 and shape (d,d,d,d).

        @param dim: spatial dimension d, or an object providing getDim()
        @param name: name of the symbol
        """
        spatial_dim = dim
        if hasattr(dim, "getDim"):
            spatial_dim = dim.getDim()
        Symbol.__init__(self, shape=(spatial_dim,) * 4, dim=spatial_dim, name=name)
302
class Add_Symbol(Symbol):
    """Symbol for the sum arg0+arg1; dimension and shape are taken from the
    arguments via _extractDim and _extractShape."""

    def __init__(self, arg0, arg1):
        operands = [arg0, arg1]
        Symbol.__init__(self, dim=_extractDim(operands),
                        shape=_extractShape(operands), args=operands)

    def __str__(self):
        return "(%s+%s)" % (str(self.getArgument(0)), str(self.getArgument(1)))

    def eval(self, argval):
        left, right = self.getEvaluatedArguments(argval)
        return left + right

    def _diff(self, arg):
        # the derivative of a sum is the sum of the derivatives
        d_left, d_right = self.getDifferentiatedArguments(arg)
        return d_left + d_right
316
class Mult_Symbol(Symbol):
    """Symbol for the product arg0*arg1; dimension and shape are taken from
    the arguments via _extractDim and _extractShape."""

    def __init__(self, arg0, arg1):
        operands = [arg0, arg1]
        Symbol.__init__(self, dim=_extractDim(operands),
                        shape=_extractShape(operands), args=operands)

    def __str__(self):
        return "(%s*%s)" % (str(self.getArgument(0)), str(self.getArgument(1)))

    def eval(self, argval):
        left, right = self.getEvaluatedArguments(argval)
        return left * right

    def _diff(self, arg):
        # product rule: (u*v)' = v*u' + u*v'
        d_left, d_right = self.getDifferentiatedArguments(arg)
        left = self.getArgument(0)
        right = self.getArgument(1)
        return right * d_left + left * d_right
330
class Div_Symbol(Symbol):
    """Symbol for the quotient arg0/arg1; dimension and shape are taken from
    the arguments via _extractDim and _extractShape."""

    def __init__(self, arg0, arg1):
        operands = [arg0, arg1]
        Symbol.__init__(self, dim=_extractDim(operands),
                        shape=_extractShape(operands), args=operands)

    def __str__(self):
        return "(%s/%s)" % (str(self.getArgument(0)), str(self.getArgument(1)))

    def eval(self, argval):
        numerator, denominator = self.getEvaluatedArguments(argval)
        return numerator / denominator

    def _diff(self, arg):
        # quotient rule: (u/v)' = (u'*v - u*v') / (v*v)
        d_num, d_den = self.getDifferentiatedArguments(arg)
        num = self.getArgument(0)
        den = self.getArgument(1)
        return (d_num * den - num * d_den) / (den * den)
345
class Power_Symbol(Symbol):
    """Symbol for arg0**arg1 (the first argument raised to the power of the
    second); dimension and shape are taken from the arguments."""

    def __init__(self, arg0, arg1):
        operands = [arg0, arg1]
        Symbol.__init__(self, dim=_extractDim(operands),
                        shape=_extractShape(operands), args=operands)

    def __str__(self):
        return "(%s**%s)" % (str(self.getArgument(0)), str(self.getArgument(1)))

    def eval(self, argval):
        base, exponent = self.getEvaluatedArguments(argval)
        return base ** exponent

    def _diff(self, arg):
        # (u**v)' = u**v * (v'*log(u) + v/u*u')
        d_base, d_exponent = self.getDifferentiatedArguments(arg)
        base = self.getArgument(0)
        exponent = self.getArgument(1)
        return self * (d_exponent * log(base) + exponent / base * d_base)
359
class Abs_Symbol(Symbol):
    """Symbol for abs(arg); it has the shape and dimension of its argument."""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "abs(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return abs(value)

    def _diff(self, arg):
        # abs(u)' = sign(u)*u'
        d_arg = self.getDifferentiatedArguments(arg)[0]
        return sign(self.getArgument(0)) * d_arg
370
371 #=========================================================
372 # some little helpers
373 #=========================================================
374 def _testForZero(arg):
375 """returns True is arg is considered of being zero"""
376 if isinstance(arg,int):
377 return not arg>0
378 elif isinstance(arg,float):
379 return not arg>0.
380 elif isinstance(arg,numarray.NumArray):
381 a=abs(arg)
382 while isinstance(a,numarray.NumArray): a=numarray.sometrue(a)
383 return not a>0
384 else:
385 return False
386
387 def _extractDim(args):
388 dim=None
389 for a in args:
390 if hasattr(a,"getDim"):
391 d=a.getDim()
392 if dim==None:
393 dim=d
394 else:
395 if dim!=d: raise ValueError,"inconsistent spatial dimension of arguments"
396 if dim==None:
397 raise ValueError,"cannot recover spatial dimension"
398 return dim
399
400 def _identifyShape(arg):
401 """identifies the shape of arg."""
402 if hasattr(arg,"getShape"):
403 arg_shape=arg.getShape()
404 elif hasattr(arg,"shape"):
405 s=arg.shape
406 if callable(s):
407 arg_shape=s()
408 else:
409 arg_shape=s
410 else:
411 arg_shape=()
412 return arg_shape
413
414 def _extractShape(args):
415 """extracts the common shape of the list of arguments args"""
416 shape=None
417 for a in args:
418 a_shape=_identifyShape(a)
419 if shape==None: shape=a_shape
420 if shape!=a_shape: raise ValueError,"inconsistent shape"
421 if shape==None:
422 raise ValueError,"cannot recover shape"
423 return shape
424
425 def _matchShape(args,shape=None):
426 """returns the list of arguments args as object which have all the specified shape.
427 if shape is not given the shape "largest" shape of args is used."""
428 # identify the list of shapes:
429 arg_shapes=[]
430 for a in args: arg_shapes.append(_identifyShape(a))
431 # get the largest shape (currently the longest shape):
432 if shape==None: shape=max(arg_shapes)
433
434 out=[]
435 for i in range(len(args)):
436 if shape==arg_shapes[i]:
437 out.append(args[i])
438 else:
439 if len(shape)==0: # then len(arg_shapes[i])>0
440 raise ValueError,"cannot adopt shape of %s to %s"%(str(args[i]),str(shape))
441 else:
442 if len(arg_shapes[i])==0:
443 out.append(outer(args[i],numarray.ones(shape)))
444 else:
445 raise ValueError,"cannot adopt shape of %s to %s"%(str(args[i]),str(shape))
446 return out
447
448 #=========================================================
449 # wrappers for various mathematical functions:
450 #=========================================================
def diff(arg, dep):
    """Returns the derivative of arg with respect to dep.

    Symbols are differentiated by arg.diff(dep).  Objects with a shape
    attribute (or shape() method) give a zero numarray Float array of that
    shape.  Anything else gives 0.
    """
    if isinstance(arg, Symbol):
        return arg.diff(dep)
    if not hasattr(arg, "shape"):
        return 0
    shape = arg.shape
    if callable(shape):
        shape = shape()
    return numarray.zeros(shape, numarray.Float)
463
def exp(arg):
    """
    @brief applies the exponential function to arg
    @param arg (input): a Symbol, an object with an exp() method, or anything
                        numarray.exp accepts
    """
    if isinstance(arg, Symbol):
        return Exp_Symbol(arg)
    if hasattr(arg, "exp"):
        return arg.exp()
    return numarray.exp(arg)


class Exp_Symbol(Symbol):
    """Symbol for exp(arg); it has the shape and dimension of its argument."""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "exp(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return exp(value)

    def _diff(self, arg):
        # exp(u)' = exp(u)*u'
        d_arg = self.getDifferentiatedArguments(arg)[0]
        return self * d_arg
486
def sqrt(arg):
    """
    @brief applies the square root function to arg
    @param arg (input): a Symbol, an object with a sqrt() method, or anything
                        numarray.sqrt accepts
    """
    if isinstance(arg, Symbol):
        return Sqrt_Symbol(arg)
    elif hasattr(arg, "sqrt"):
        return arg.sqrt()
    else:
        return numarray.sqrt(arg)


class Sqrt_Symbol(Symbol):
    """Symbol for sqrt(arg); it has the shape and dimension of its argument."""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "sqrt(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        return sqrt(self.getEvaluatedArguments(argval)[0])

    def _diff(self, arg):
        # sqrt(u)' = u' / (2*sqrt(u)) -- a positive factor 1/2
        return self.getDifferentiatedArguments(arg)[0] / (2. * self)
509
def log(arg):
    """
    @brief applies the logarithmic function with base exp(1.) to arg
    @param arg (input): a Symbol, an object with a log() method, or anything
                        numarray.log accepts
    """
    # NOTE(review): for objects with a log() method the base is whatever that
    # method implements -- confirm it is base e for escript Data objects.
    if isinstance(arg, Symbol):
        return Log_Symbol(arg)
    if hasattr(arg, "log"):
        return arg.log()
    return numarray.log(arg)


class Log_Symbol(Symbol):
    """Symbol for log(arg); it has the shape and dimension of its argument."""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "log(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return log(value)

    def _diff(self, arg):
        # log(u)' = u'/u
        d_arg = self.getDifferentiatedArguments(arg)[0]
        return d_arg / self.getArgument(0)
532
def ln(arg):
    """
    @brief applies the natural logarithmic function to arg
    @param arg (input): a Symbol, an object with an ln() method, or anything
                        numarray.log accepts
    """
    if isinstance(arg, Symbol):
        return Ln_Symbol(arg)
    elif hasattr(arg, "ln"):
        # call the method whose presence was just checked (was arg.log())
        return arg.ln()
    else:
        return numarray.log(arg)


class Ln_Symbol(Symbol):
    """Symbol for ln(arg); it has the shape and dimension of its argument."""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "ln(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        return ln(self.getEvaluatedArguments(argval)[0])

    def _diff(self, arg):
        # ln(u)' = u'/u
        return self.getDifferentiatedArguments(arg)[0] / self.getArgument(0)
555
def sin(arg):
    """
    @brief applies the sin function to arg
    @param arg (input): a Symbol, an object with a sin() method, or anything
                        numarray.sin accepts
    """
    if isinstance(arg, Symbol):
        return Sin_Symbol(arg)
    if hasattr(arg, "sin"):
        return arg.sin()
    return numarray.sin(arg)


class Sin_Symbol(Symbol):
    """Symbol for sin(arg); it has the shape and dimension of its argument."""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "sin(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return sin(value)

    def _diff(self, arg):
        # sin(u)' = cos(u)*u'
        d_arg = self.getDifferentiatedArguments(arg)[0]
        return cos(self.getArgument(0)) * d_arg
578
def cos(arg):
    """
    @brief applies the cos function to arg
    @param arg (input): a Symbol, an object with a cos() method, or anything
                        numarray.cos accepts
    """
    if isinstance(arg, Symbol):
        return Cos_Symbol(arg)
    if hasattr(arg, "cos"):
        return arg.cos()
    return numarray.cos(arg)


class Cos_Symbol(Symbol):
    """Symbol for cos(arg); it has the shape and dimension of its argument."""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "cos(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return cos(value)

    def _diff(self, arg):
        # cos(u)' = -sin(u)*u'
        d_arg = self.getDifferentiatedArguments(arg)[0]
        return -sin(self.getArgument(0)) * d_arg
601
def tan(arg):
    """
    @brief applies the tan function to arg
    @param arg (input): a Symbol, an object with a tan() method, or anything
                        numarray.tan accepts
    """
    if isinstance(arg, Symbol):
        return Tan_Symbol(arg)
    if hasattr(arg, "tan"):
        return arg.tan()
    return numarray.tan(arg)


class Tan_Symbol(Symbol):
    """Symbol for tan(arg); it has the shape and dimension of its argument."""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "tan(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return tan(value)

    def _diff(self, arg):
        # tan(u)' = u' / cos(u)**2
        c = cos(self.getArgument(0))
        d_arg = self.getDifferentiatedArguments(arg)[0]
        return 1. / (c * c) * d_arg
625
def sign(arg):
    """
    @brief applies the sign function to arg
    @param arg (input): Python number, Symbol, object with a sign() method,
                        or object with a shape attribute (e.g. numarray array)
    @return: for Python numbers 1. if arg>0, -1. if arg<0, else 0.;
             for arrays greater(arg,0)-less(arg,0) entry by entry
    """
    if isinstance(arg, int) or isinstance(arg, float):
        # Python numbers have neither a sign() method nor a shape
        if arg > 0:
            return 1.
        elif arg < 0:
            return -1.
        else:
            return 0.
    elif isinstance(arg, Symbol):
        return Sign_Symbol(arg)
    elif hasattr(arg, "sign"):
        return arg.sign()
    else:
        zero = numarray.zeros(arg.shape, numarray.Float)
        return numarray.greater(arg, zero) - numarray.less(arg, zero)
638
class Sign_Symbol(Symbol):
    """Symbol for sign(arg); it has the shape and dimension of its argument.
    No _diff override, so its derivative is the base-class default 0."""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "sign(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return sign(value)
647
def maxval(arg):
    """
    @brief returns the maximum value of argument arg
    @param arg (input): a Symbol, an object with a maxval() or max() method,
                        or any other object (returned unchanged)
    """
    if isinstance(arg, Symbol):
        return Max_Symbol(arg)
    # prefer maxval() over max(); the first one found is used
    for method in ("maxval", "max"):
        if hasattr(arg, method):
            return getattr(arg, method)()
    return arg


class Max_Symbol(Symbol):
    """Symbol for maxval(arg)."""

    def __init__(self, arg):
        # NOTE(review): the shape is copied from the argument; confirm this
        # matches the shape of what maxval() returns.
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "maxval(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return maxval(value)
670
def minval(arg):
    """
    @brief returns the minimum value of argument arg
    @param arg (input): object with a minval() or min() method, a Symbol, or
                        any other object (returned unchanged)
    """
    # The duck-typed methods are tried first.  Symbol defines neither
    # minval() nor min(), so Symbols still end up as Min_Symbol.
    if hasattr(arg, "minval"):
        # test for the method that is actually called (was "maxval")
        return arg.minval()
    elif hasattr(arg, "min"):
        return arg.min()
    elif isinstance(arg, Symbol):
        return Min_Symbol(arg)
    else:
        return arg
684
class Min_Symbol(Symbol):
    """Symbol for minval(arg)."""

    def __init__(self, arg):
        # NOTE(review): the shape is copied from the argument; confirm this
        # matches the shape of what minval() returns.
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "minval(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return minval(value)
693
def wherePositive(arg):
    """
    @brief returns an indicator that is 1 where argument arg is positive and 0
           elsewhere
    @param arg (input): Python number, Symbol, object with a wherePositive()
                        method, or object with a shape attribute (e.g. numarray array)
    @return: 1./0. for Python numbers, 0 for arguments that are identically
             zero, numarray.greater(arg,0) for arrays
    """
    if isinstance(arg, int) or isinstance(arg, float):
        if arg > 0:
            return 1.
        return 0.
    if _testForZero(arg):
        return 0
    if isinstance(arg, Symbol):
        return WherePositive_Symbol(arg)
    if hasattr(arg, "wherePositive"):
        # call the matching method (was arg.minval())
        return arg.wherePositive()
    if hasattr(arg, "shape"):
        # this branch was unreachable (duplicated condition) and lacked a return
        return numarray.greater(arg, numarray.zeros(arg.shape, numarray.Float))
    if arg > 0:
        return 1.
    return 0.
712
class WherePositive_Symbol(Symbol):
    """Symbol for wherePositive(arg); it has the shape and dimension of its argument."""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "wherePositive(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return wherePositive(value)
721
def whereNegative(arg):
    """
    @brief returns an indicator that is 1 where argument arg is negative and 0
           elsewhere
    @param arg (input): Python number, Symbol, object with a whereNegative()
                        method, or object with a shape attribute (e.g. numarray array)
    @return: 1./0. for Python numbers, 0 for arguments that are identically
             zero, numarray.less(arg,0) for arrays
    """
    if isinstance(arg, int) or isinstance(arg, float):
        # answered directly: negative numbers must give 1., never the
        # "zero" shortcut below
        if arg < 0:
            return 1.
        return 0.
    if _testForZero(arg):
        return 0
    if isinstance(arg, Symbol):
        return WhereNegative_Symbol(arg)
    if hasattr(arg, "whereNegative"):
        return arg.whereNegative()
    if hasattr(arg, "shape"):
        # the result used to be computed but not returned
        return numarray.less(arg, numarray.zeros(arg.shape, numarray.Float))
    if arg < 0:
        return 1.
    return 0.
740
class WhereNegative_Symbol(Symbol):
    """Symbol for whereNegative(arg); it has the shape and dimension of its argument."""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "whereNegative(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return whereNegative(value)
749
def maximum(arg0, arg1):
    """Returns arg1 where arg1 is bigger than arg0, otherwise arg0.

    Computed as mask*arg1+(1.-mask)*arg0 with mask=whereNegative(arg0-arg1).
    """
    mask = whereNegative(arg0 - arg1)
    return mask * arg1 + (1. - mask) * arg0


def minimum(arg0, arg1):
    """Returns arg0 where arg1 is bigger than arg0, otherwise arg1.

    Computed as mask*arg0+(1.-mask)*arg1 with mask=whereNegative(arg0-arg1).
    """
    mask = whereNegative(arg0 - arg1)
    return mask * arg0 + (1. - mask) * arg1
759
def outer(arg0, arg1):
    """
    @brief returns the outer product of arg0 and arg1

    @param arg0, arg1: arguments
    @return: 0 if either argument is zero (see _testForZero); an Outer_Symbol
             if either argument is a Symbol; arg0*arg1 if either argument has
             shape (); numarray.outer(arg0,arg1) for two numarray arrays; for
             two rank 1 objects with getRank/getShape/getFunctionSpace
             (presumably escript.Data) the Data object out[i,j]=arg0[i]*arg1[j]
    @raise ValueError: for any other combination of arguments
    """
    if _testForZero(arg0) or _testForZero(arg1):
        return 0
    if isinstance(arg0, Symbol) or isinstance(arg1, Symbol):
        return Outer_Symbol(arg0, arg1)
    if _identifyShape(arg0) == () or _identifyShape(arg1) == ():
        return arg0 * arg1
    if isinstance(arg0, numarray.NumArray) and isinstance(arg1, numarray.NumArray):
        return numarray.outer(arg0, arg1)
    if arg0.getRank() == 1 and arg1.getRank() == 1:
        n0 = arg0.getShape()[0]
        n1 = arg1.getShape()[0]
        out = escript.Data(0, (n0, n1), arg1.getFunctionSpace())
        for i in range(n0):
            for j in range(n1):
                out[i, j] = arg0[i] * arg1[j]
        return out
    raise ValueError("outer is not fully implemented yet.")


class Outer_Symbol(Symbol):
    """symbol representing the outer product of its two arguments; its shape
    is the concatenation of the argument shapes"""

    def __init__(self, arg0, arg1):
        a = [arg0, arg1]
        s = tuple(list(_identifyShape(arg0)) + list(_identifyShape(arg1)))
        Symbol.__init__(self, shape=s, dim=_extractDim(a), args=a)

    def __str__(self):
        return "outer(%s,%s)" % (str(self.getArgument(0)), str(self.getArgument(1)))

    def eval(self, argval):
        a = self.getEvaluatedArguments(argval)
        return outer(a[0], a[1])

    def _diff(self, arg):
        # product rule applied to the outer product
        a = self.getDifferentiatedArguments(arg)
        return outer(a[0], self.getArgument(1)) + outer(self.getArgument(0), a[1])
794
def interpolate(arg, where):
    """
    @brief interpolates the function into the FunctionSpace where.

    @param arg interpolant (0 gives 0, a Symbol gives an Interpolated_Symbol)
    @param where FunctionSpace to interpolate to
    """
    if _testForZero(arg):
        return 0
    elif isinstance(arg, Symbol):
        return Interpolated_Symbol(arg, where)
    else:
        return escript.Data(arg, where)


class Interpolated_Symbol(Symbol):
    """symbol representing the interpolation of its argument into a FunctionSpace"""

    def __init__(self, arg, where):
        # arg is a single object: use its own shape rather than iterating it
        Symbol.__init__(self, shape=_identifyShape(arg), dim=_extractDim([arg]),
                        args=[arg, where])

    def __str__(self):
        return "interpolated(%s)" % (str(self.getArgument(0)))

    def eval(self, argval):
        a = self.getEvaluatedArguments(argval)
        return interpolate(a[0], where=self.getArgument(1))

    def _diff(self, arg):
        # interpolation is assumed to be linear, so the derivative is the
        # interpolated derivative
        a = self.getDifferentiatedArguments(arg)
        return interpolate(a[0], where=self.getArgument(1))
821
def div(arg, where=None):
    """
    @brief returns the divergence of arg at where, computed as the trace of
           grad(arg,where).

    @param arg: Data object representing the function whose divergence is calculated.
    @param where: FunctionSpace in which the gradient is calculated. If not present or
                  None an appropriate default is used.
    """
    gradient = grad(arg, where)
    return trace(gradient)
831
def grad(arg, where=None):
    """
    @brief returns the spatial gradient of arg at where.

    @param arg: Data object representing the function which gradient to be calculated.
    @param where: FunctionSpace in which the gradient will be calculated. If not present or
                  None an appropriate default is used.
    @return: 0 for a zero argument, a Grad_Symbol for a Symbol, arg.grad(...)
             for objects with a grad() method, arg*0. otherwise
    """
    if _testForZero(arg):
        return 0
    elif isinstance(arg, Symbol):
        return Grad_Symbol(arg, where)
    elif hasattr(arg, "grad"):
        if where is None:
            return arg.grad()
        else:
            return arg.grad(where)
    else:
        return arg * 0.


class Grad_Symbol(Symbol):
    """symbol representing the gradient of the argument"""

    def __init__(self, arg, where=None):
        d = _extractDim([arg])
        # the gradient appends one axis of length d to the argument's shape
        s = tuple(list(_identifyShape(arg)) + [d])
        Symbol.__init__(self, shape=s, dim=d, args=[arg, where])

    def __str__(self):
        return "grad(%s)" % (str(self.getArgument(0)))

    def eval(self, argval):
        a = self.getEvaluatedArguments(argval)
        return grad(a[0], where=self.getArgument(1))

    def _diff(self, arg):
        # the gradient is linear: differentiate the argument, then take grad
        a = self.getDifferentiatedArguments(arg)
        return grad(a[0], where=self.getArgument(1))
866
def integrate(arg, where=None):
    """
    @brief return the integral if the function represented by Data object arg over its domain.

    @param arg: Data object representing the function which is integrated.
    @param where: FunctionSpace in which the integral is calculated. If not present or
                  None an appropriate default is used.
    @return: 0 for a zero argument, an Integral_Symbol for a Symbol,
             arg.integrate()[0] for rank 0 and arg.integrate() otherwise
    """
    if _testForZero(arg):
        return 0
    elif isinstance(arg, Symbol):
        return Integral_Symbol(arg, where)
    else:
        if where is not None:
            arg = escript.Data(arg, where)
        if arg.getRank() == 0:
            return arg.integrate()[0]
        else:
            return arg.integrate()


class Integral_Symbol(Symbol):
    """symbol representing the integral of the argument"""

    def __init__(self, arg, where=None):
        # spatial dimension 0, as in Float_Symbol; the shape is that of arg
        Symbol.__init__(self, dim=0, shape=_identifyShape(arg), args=[arg, where])

    def __str__(self):
        return "integral(%s)" % (str(self.getArgument(0)))

    def eval(self, argval):
        a = self.getEvaluatedArguments(argval)
        return integrate(a[0], where=self.getArgument(1))

    def _diff(self, arg):
        # integration is linear: integrate the derivative of the argument
        a = self.getDifferentiatedArguments(arg)
        return integrate(a[0], where=self.getArgument(1))
898
899 #=============================
900 #
901 # wrapper for various functions: if the argument has attribute the function name
902 # as an argument it calls the corresponding methods. Otherwise the corresponding
903 # numarray function is called.
904
905 # functions involving the underlying Domain:
906
907
908 # functions returning Data objects:
909
def transpose(arg, axis=None):
    """
    @brief returns the transpose of arg.

    The first axis indices of arg are moved behind the remaining ones: for a
    rank r argument out[i_axis,...,i_{r-1},i_0,...,i_{axis-1}]=arg[i_0,...,i_{r-1}].
    For rank 2 and the default axis this is the usual matrix transpose.

    @param arg: Symbol, escript.Data object of rank 2, or an object accepted
                by numarray.asarray
    @param axis: number of leading axes moved to the back; defaults to rank//2
    @raise ValueError: if arg is an escript.Data object of rank other than 2
    """
    if isinstance(arg, Symbol):
        if axis is None:
            axis = arg.getRank() // 2
        # NOTE(review): Transpose_Symbol is not defined in the visible part of
        # this file -- confirm it exists.
        return Transpose_Symbol(arg, axis=axis)
    if isinstance(arg, escript.Data):
        # hack for transpose: only rank 2 is supported, axis is not used
        if arg.getRank() != 2:
            raise ValueError("Tranpose only avalaible for rank 2 objects")
        s = arg.getShape()
        out = escript.Data(0., (s[1], s[0]), arg.getFunctionSpace())
        for i in range(s[0]):
            for j in range(s[1]):
                out[j, i] = arg[i, j]
        return out
    a = numarray.asarray(arg)
    r = len(a.shape)
    if axis is None:
        axis = r // 2
    # axes permutation (axis,...,r-1,0,...,axis-1), passed positionally
    return numarray.transpose(a, list(range(axis, r)) + list(range(axis)))
937
def trace(arg, axis0=0, axis1=1):
    """
    @brief returns the trace of arg over the axes axis0 and axis1

    @param arg: Symbol, escript.Data object of rank 2, or numarray-compatible object
    @param axis0, axis1: the two axes that are summed over
    @raise ValueError: for an escript.Data object that is not of rank 2 or
                       whose two axes have different lengths
    """
    if isinstance(arg, Symbol):
        return Trace_Symbol(arg, axis0=axis0, axis1=axis1)
    elif isinstance(arg, escript.Data):
        # hack for trace: the loop below sums arg[i,i], which is only the
        # trace for rank 2 objects
        s = arg.getShape()
        if len(s) != 2:
            raise ValueError("trace is only implemented for rank 2 Data objects")
        if s[axis0] != s[axis1]:
            raise ValueError("illegal axis in trace")
        out = escript.Scalar(0., arg.getFunctionSpace())
        for i in range(s[axis0]):
            out += arg[i, i]
        return out
    else:
        # numarray.trace(array, offset, axis1, axis2): passed positionally
        # because its keyword names differ from axis0/axis1
        return numarray.trace(arg, 0, axis0, axis1)


class Trace_Symbol(Symbol):
    """symbol representing the trace of its argument over the axes axis0 and axis1"""

    def __init__(self, arg, axis0=0, axis1=1):
        s = list(arg.getShape())
        lo = min(axis0, axis1)
        hi = max(axis0, axis1)
        # the two traced axes are removed from the shape
        s = tuple(s[:lo] + s[lo + 1:hi] + s[hi + 1:])
        Symbol.__init__(self, shape=s, dim=arg.getDim(), args=[arg, axis0, axis1])

    def __str__(self):
        return "trace(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        a = self.getEvaluatedArguments(argval)
        return trace(a[0], axis0=a[1], axis1=a[2])

    def _diff(self, arg):
        # the trace is linear; derivative axes are appended at the end, so
        # axis0/axis1 still address the traced axes
        d = self.getDifferentiatedArguments(arg)[0]
        if _testForZero(d):
            return 0
        return trace(d, axis0=self.getArgument(1), axis1=self.getArgument(2))
963
def length(arg):
    """
    @brief returns the length of arg: the square root of the sum of the squares
           of all its entries

    @param arg: Python int or float, escript.Data object of rank 0 to 4, or an
                object supporting arg**2 and .sum() (e.g. a numarray array)
    @return: abs(arg) for Python numbers and rank 0 Data objects, an empty
             escript.Data object for an empty Data argument, otherwise
             sqrt(sum of squares)
    @raise SystemError: for escript.Data objects of rank greater than 4
    """
    if isinstance(arg, int) or isinstance(arg, float):
        # Python numbers have no sum() method, so handle them directly
        return abs(arg)
    elif isinstance(arg, escript.Data):
        if arg.isEmpty():
            return escript.Data()
        rank = arg.getRank()
        if rank == 0:
            return abs(arg)
        if rank > 4:
            raise SystemError("length is not been fully implemented yet")
        shape = arg.getShape()
        out = escript.Scalar(0, arg.getFunctionSpace())
        if rank == 1:
            for i in range(shape[0]):
                out += arg[i]**2
        elif rank == 2:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    out += arg[i, j]**2
        elif rank == 3:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    for k in range(shape[2]):
                        out += arg[i, j, k]**2
        else:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    for k in range(shape[2]):
                        for l in range(shape[3]):
                            out += arg[i, j, k, l]**2
        return sqrt(out)
    else:
        return sqrt((arg**2).sum())
1007
def deviator(arg):
    """
    @brief returns the deviatoric part arg-trace(arg)/n*kronecker(n) of a
           square n x n matrix arg

    @param arg: rank 2 escript.Data object or object with a shape attribute
    @raise ValueError: if arg is not of rank 2 or not square
    """
    if isinstance(arg, escript.Data):
        shape = arg.getShape()
    else:
        shape = arg.shape
    if len(shape) != 2:
        raise ValueError("Deviator requires rank 2 object")
    if shape[0] != shape[1]:
        raise ValueError("Deviator requires a square matrix")
    # NOTE(review): kronecker is not defined in the visible part of this file
    return arg - 1. / (shape[0] * 1.) * trace(arg) * kronecker(shape[0])
1023
def inner(arg0, arg1):
    """
    @brief returns the inner product of arg0 and arg1, i.e. the sum over all
           indices of arg0[...]*arg1[...]

    @param arg0, arg1: arguments of the same shape
    @return: if neither argument is an escript.Data object,
             (arg0*arg1).sum() (or arg0*arg1 if the product has no sum());
             otherwise a scalar Data object (arg0*arg1 for rank 0)
    @raise SystemError: for escript.Data arguments of rank greater than 4
    """
    if not (isinstance(arg0, escript.Data) or isinstance(arg1, escript.Data)):
        # no Data object involved, so there is no function space to build an
        # escript.Scalar on: reduce the elementwise product directly
        product = arg0 * arg1
        if hasattr(product, "sum"):
            return product.sum()
        return product
    if isinstance(arg0, escript.Data):
        arg = arg0
    else:
        arg = arg1
    rank = arg.getRank()
    if rank == 0:
        return arg0 * arg1
    if rank > 4:
        raise SystemError("inner is not been implemented yet")
    shape = arg.getShape()
    out = escript.Scalar(0, arg.getFunctionSpace())
    if rank == 1:
        for i in range(shape[0]):
            out += arg0[i] * arg1[i]
    elif rank == 2:
        for i in range(shape[0]):
            for j in range(shape[1]):
                out += arg0[i, j] * arg1[i, j]
    elif rank == 3:
        for i in range(shape[0]):
            for j in range(shape[1]):
                for k in range(shape[2]):
                    out += arg0[i, j, k] * arg1[i, j, k]
    else:
        for i in range(shape[0]):
            for j in range(shape[1]):
                for k in range(shape[2]):
                    for l in range(shape[3]):
                        out += arg0[i, j, k, l] * arg1[i, j, k, l]
    return out
1063
def matrixmult(arg0,arg1):
    """
    @brief returns the matrix product of arg0 and arg1

    If neither argument is an escript.Data object the product is computed
    with numarray.dot and returned. Otherwise a non-Data argument is first
    converted to escript.Data on the function space of the other argument.
    The supported rank combinations of Data arguments are rank 2 times
    rank 1 (matrix-vector product) and rank 1 times rank 1 (inner product).

    @param arg0 left factor
    @param arg1 right factor
    @return the product of arg0 and arg1
    @exception SystemError for unsupported rank combinations of Data arguments
    """
    if not isinstance(arg0,escript.Data) and not isinstance(arg1,escript.Data):
        # plain arrays: numarray.dot is the matrix product for rank 2 and the
        # result must be returned to the caller
        return numarray.dot(arg0,arg1)

    # promote the non-Data operand to Data on the other operand's function space
    if not isinstance(arg0,escript.Data):
        arg0=escript.Data(arg0,arg1.getFunctionSpace())
    elif not isinstance(arg1,escript.Data):
        arg1=escript.Data(arg1,arg0.getFunctionSpace())

    if arg0.getRank()==2 and arg1.getRank()==1:
        rows=arg0.getShape()[0]
        cols=arg0.getShape()[1]
        out=escript.Data(0,(rows,),arg0.getFunctionSpace())
        for i in range(rows):
            for j in range(cols):
                # uses Data object slicing, plus Data * and += operators
                out[i]+=arg0[i,j]*arg1[j]
        return out
    elif arg0.getRank()==1 and arg1.getRank()==1:
        return inner(arg0,arg1)
    else:
        raise SystemError("matrixmult is not fully implemented yet!")
1085
1086 #=========================================================
1087 # reduction operations:
1088 #=========================================================
def sum(arg):
    """
    @brief returns the sum of all entries of arg

    @param arg escript.Data object or array providing a sum() method, or a
               plain float/int, which is returned unchanged
    @return the sum of the entries of arg
    """
    if isinstance(arg,float) or isinstance(arg,int):
        # plain numbers have no sum() method; handle them the same way the
        # other reductions (sup, inf, L2, Lsup) do
        return arg
    return arg.sum()
1096
def sup(arg):
    """
    @brief returns the maximum value of arg

    @param arg escript.Data object (uses its sup() method), a plain
               float/int (returned unchanged) or an array providing max()
    @return the maximum value
    """
    if isinstance(arg,escript.Data):
        return arg.sup()
    if isinstance(arg,(int,float)):
        # a plain number is its own maximum
        return arg
    # anything else is expected to offer max()
    return arg.max()
1109
def inf(arg):
    """
    @brief returns the minimum value of arg

    @param arg escript.Data object (uses its inf() method), a plain
               float/int (returned unchanged) or an array providing min()
    @return the minimum value
    """
    if isinstance(arg,escript.Data):
        return arg.inf()
    if isinstance(arg,(int,float)):
        # a plain number is its own minimum
        return arg
    # anything else is expected to offer min()
    return arg.min()
1122
def L2(arg):
    """
    @brief returns the L2-norm of arg

    For an escript.Data object the call is delegated to arg.L2(). A plain
    float/int gives its absolute value. For an array the square root of the
    sum of the squares of all entries is returned.

    @param arg escript.Data object, float/int or array
    @return the L2-norm of arg
    """
    if isinstance(arg,escript.Data):
        return arg.L2()
    elif isinstance(arg,float) or isinstance(arg,int):
        return abs(arg)
    else:
        # sum of squares over all entries works for any rank, whereas
        # dot(arg,arg) is a matrix product for rank-2 arrays
        return numarray.sqrt((arg**2).sum())
1135
def Lsup(arg):
    """
    @brief returns the maximum absolute value (L-sup norm) of arg

    @param arg escript.Data object (uses its Lsup() method), a plain
               float/int (its absolute value is returned) or an array
    @return the maximum absolute value over all entries
    """
    if isinstance(arg,escript.Data):
        return arg.Lsup()
    elif isinstance(arg,float) or isinstance(arg,int):
        return abs(arg)
    else:
        # reduce over all entries (as sup() does with max()); the builtin
        # max() would iterate over rows of a rank-2 array instead
        return numarray.abs(arg).max()
1148
def dot(arg0,arg1):
    """
    @brief returns the dot product of arg0 and arg1

    If arg0 is an escript.Data object arg0.dot(arg1) is returned; otherwise,
    if arg1 is an escript.Data object, arg1.dot(arg0) is returned; otherwise
    numarray.dot(arg0,arg1) is used.

    @param arg0 first factor
    @param arg1 second factor
    @return the dot product
    """
    if isinstance(arg0,escript.Data):
        return arg0.dot(arg1)
    if isinstance(arg1,escript.Data):
        # NOTE(review): operands are passed in swapped order here; this only
        # equals dot(arg0,arg1) if Data.dot is symmetric in its operands -
        # confirm for non-vector arguments
        return arg1.dot(arg0)
    return numarray.dot(arg0,arg1)
1161
def kronecker(d):
    """
    @brief returns the Kronecker delta (identity matrix) with float entries

    @param d an integer dimension, or an object providing getDim(), whose
             result is then used as the dimension
    @return numarray identity matrix of size dim x dim multiplied by 1.
    """
    dim=d
    if hasattr(d,"getDim"):
        dim=d.getDim()
    # multiplying by 1. yields float entries
    return numarray.identity(dim)*1.
1167
def unit(i,d):
    """
    @brief returns the unit vector of length d whose i-th entry is one
    @param i index of the nonzero entry
    @param d dimension (length) of the vector
    @return numarray array of type numarray.Float and shape (d,)
    """
    shape=(d,)
    basis=numarray.zeros(shape,numarray.Float)
    basis[i]=1.0
    return basis
1177
1178 # ============================================
1179 # testing
1180 # ============================================
1181
if __name__=="__main__":
    # ad-hoc smoke test of the Symbol toolbox: each line prints an expression
    # together with its derivative with respect to u
    # NOTE(review): the two ScalarSymbol assignments are immediately
    # overwritten by the VectorSymbol assignments below
    u=ScalarSymbol(dim=2,name="u")
    v=ScalarSymbol(dim=2,name="v")
    v=VectorSymbol(2,"v")
    u=VectorSymbol(2,"u")

    # addition with a constant and with another symbol, in both orders
    print u+5,(u+5).diff(u)
    print 5+u,(5+u).diff(u)
    print u+v,(u+v).diff(u)
    print v+u,(v+u).diff(u)

    # multiplication
    print u*5,(u*5).diff(u)
    print 5*u,(5*u).diff(u)
    print u*v,(u*v).diff(u)
    print v*u,(v*u).diff(u)

    # subtraction
    print u-5,(u-5).diff(u)
    print 5-u,(5-u).diff(u)
    print u-v,(u-v).diff(u)
    print v-u,(v-u).diff(u)

    # division
    print u/5,(u/5).diff(u)
    print 5/u,(5/u).diff(u)
    print u/v,(u/v).diff(u)
    print v/u,(v/u).diff(u)

    # powers
    print u**5,(u**5).diff(u)
    print 5**u,(5**u).diff(u)
    print u**v,(u**v).diff(u)
    print v**u,(v**u).diff(u)

    # unary functions
    print exp(u),exp(u).diff(u)
    print sqrt(u),sqrt(u).diff(u)
    print log(u),log(u).diff(u)
    print sin(u),sin(u).diff(u)
    print cos(u),cos(u).diff(u)
    print tan(u),tan(u).diff(u)
    print sign(u),sign(u).diff(u)
    print abs(u),abs(u).diff(u)
    print wherePositive(u),wherePositive(u).diff(u)
    print whereNegative(u),whereNegative(u).diff(u)
    print maxval(u),maxval(u).diff(u)
    print minval(u),minval(u).diff(u)

    g=grad(u)
    print diff(5*g,g)
    # NOTE(review): the value of this expression is computed but neither
    # printed nor stored. kronecker(3) is 3x3, while u was presumably
    # created with dimension 2 by VectorSymbol(2,"u") - confirm the
    # intended dimension
    4*(g+transpose(g))/2+6*trace(g)*kronecker(3)
1229
1230 #
1231 # $Log$
1232 # Revision 1.15 2005/08/12 01:45:36 jgs
1233 # Merge of development branch dev-02 back to main trunk on 2005-08-12
1234 #
1235 # Revision 1.14.2.3 2005/08/03 09:55:33 gross
1236 # ContactTest is passing now./mk install!
1237 #
1238 # Revision 1.14.2.2 2005/08/02 03:15:14 gross
1239 # bug in trace fixed!
1240 #
1241 # Revision 1.14.2.1 2005/07/29 07:10:28 gross
1242 # new functions in util and a new pde type in linearPDEs
1243 #
1244 # Revision 1.2.2.21 2005/07/28 04:19:23 gross
1245 # new functions maximum and minimum introduced.
1246 #
1247 # Revision 1.2.2.20 2005/07/25 01:26:27 gross
1248 # bug in inner fixed
1249 #
1250 # Revision 1.2.2.19 2005/07/21 04:01:28 jgs
1251 # minor comment fixes
1252 #
1253 # Revision 1.2.2.18 2005/07/21 01:02:43 jgs
1254 # commit ln() updates to development branch version
1255 #
1256 # Revision 1.12 2005/07/20 06:14:58 jgs
1257 # added ln(data) style wrapper for data.ln() - also added corresponding
1258 # implementation of Ln_Symbol class (not sure if this is right though)
1259 #
1260 # Revision 1.11 2005/07/08 04:07:35 jgs
1261 # Merge of development branch back to main trunk on 2005-07-08
1262 #
1263 # Revision 1.10 2005/06/09 05:37:59 jgs
1264 # Merge of development branch back to main trunk on 2005-06-09
1265 #
1266 # Revision 1.2.2.17 2005/07/07 07:28:58 gross
1267 # some stuff added to util.py to improve functionality
1268 #
1269 # Revision 1.2.2.16 2005/06/30 01:53:55 gross
1270 # a bug in coloring fixed
1271 #
1272 # Revision 1.2.2.15 2005/06/29 02:36:43 gross
1273 # Symbols have been introduced and some function clarified. needs much more work
1274 #
1275 # Revision 1.2.2.14 2005/05/20 04:05:23 gross
1276 # some work on a darcy flow started
1277 #
1278 # Revision 1.2.2.13 2005/03/16 05:17:58 matt
1279 # Implemented unit(idx, dim) to create cartesian unit basis vectors to
1280 # complement kronecker(dim) function.
1281 #
1282 # Revision 1.2.2.12 2005/03/10 08:14:37 matt
1283 # Added non-member Linf utility function to complement Data::Linf().
1284 #
1285 # Revision 1.2.2.11 2005/02/17 05:53:25 gross
1286 # some bug in saveDX fixed: in fact the bug was in
1287 # DataC/getDataPointShape
1288 #
1289 # Revision 1.2.2.10 2005/01/11 04:59:36 gross
1290 # automatic interpolation in integrate switched off
1291 #
1292 # Revision 1.2.2.9 2005/01/11 03:38:13 gross
1293 # Bug in Data.integrate() fixed for the case of rank 0. The problem is not totally resolved as the method should return a scalar rather than a numarray object in the case of rank 0. This problem is fixed by the util.integrate wrapper.
1294 #
1295 # Revision 1.2.2.8 2005/01/05 04:21:41 gross
1296 # FunctionSpace checking/matching in slicing added
1297 #
1298 # Revision 1.2.2.7 2004/12/29 05:29:59 gross
1299 # AdvectivePDE successfully tested for Peclet number 1000000. there is still a problem with setValue and Data()
1300 #
1301 # Revision 1.2.2.6 2004/12/24 06:05:41 gross
1302 # some changes in linearPDEs to add AdvectivePDE
1303 #
1304 # Revision 1.2.2.5 2004/12/17 00:06:53 gross
1305 # mk sets ESYS_ROOT is undefined
1306 #
1307 # Revision 1.2.2.4 2004/12/07 03:19:51 gross
1308 # options for GMRES and PRES20 added
1309 #
1310 # Revision 1.2.2.3 2004/12/06 04:55:18 gross
1311 # function wrapper extended
1312 #
1313 # Revision 1.2.2.2 2004/11/22 05:44:07 gross
1314 # a few more unitary functions have been added but not implemented in Data yet
1315 #
1316 # Revision 1.2.2.1 2004/11/12 06:58:15 gross
1317 # a lot of changes to get the linearPDE class running: most important change is that there is no matrix format exposed to the user anymore. the format is chosen by the Domain according to the solver and symmetry
1318 #
1319 # Revision 1.2 2004/10/27 00:23:36 jgs
1320 # fixed minor syntax error
1321 #
1322 # Revision 1.1.1.1 2004/10/26 06:53:56 jgs
1323 # initial import of project esys2
1324 #
1325 # Revision 1.1.2.3 2004/10/26 06:43:48 jgs
1326 # committing Lutz's and Paul's changes to branch jgs
1327 #
1328 # Revision 1.1.4.1 2004/10/20 05:32:51 cochrane
1329 # Added incomplete Doxygen comments to files, or merely put the docstrings that already exist into Doxygen form.
1330 #
1331 # Revision 1.1 2004/08/05 03:58:27 gross
1332 # Bug in Assemble_NodeCoordinates fixed
1333 #
1334 #

Properties

Name Value
svn:eol-style native
svn:keywords Author Date Id Revision

  ViewVC Help
Powered by ViewVC 1.1.26