/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Contents of /trunk/escript/py_src/util.py

Parent Directory | Revision Log


Revision 124 - (show annotations)
Wed Jul 20 06:14:58 2005 UTC (14 years, 4 months ago) by jgs
Original Path: trunk/esys2/escript/py_src/util.py
File MIME type: text/x-python
File size: 39957 byte(s)
added ln(data) style wrapper for data.ln() - also added corresponding
implementation of Ln_Symbol class (not sure if this is right though)

1 # $Id$
2
3 ## @file util.py
4
5 """
6 @brief Utility functions for escript
7
8 TODO for Data:
9
10 * binary operations @: (a@b)[:,*]=a[:]@b[:,*] if rank(a)<rank(b)
11 @=+,-,*,/,** (a@b)[:]=a[:]@b[:] if rank(a)=rank(b)
12 (a@b)[*,:]=a[*,:]@b[:] if rank(a)>rank(b)
13
14 * implementation of outer outer(a,b)[:,*]=a[:]*b[*]
15 * trace: trace(arg,axis0=a0,axis1=a1)(:,&,*)=sum_i trace(:,i,&,i,*) (i are at index a0 and a1)
16 """
17
18 import numarray
19 import escript
20
#
# escript constants; the values have to be kept consistent with utilC.h
#
UNKNOWN = -1
EPSILON = 1.e-15
Pi = numarray.pi

# some solver options:
NO_REORDERING, MINIMUM_FILL_IN, NESTED_DISSECTION = 0, 1, 2

# solver methods:
DEFAULT_METHOD = 0
DIRECT = 1
CHOLEVSKY = 2   # (sic) presumably "Cholesky"; name kept for compatibility
PCG = 3
CR = 4
CGS = 5
BICGSTAB = 6
SSOR = 7
ILU0 = 8
ILUT = 9
JACOBI = 10
GMRES = 11
PRES20 = 12

# option keys:
METHOD_KEY, SYMMETRY_KEY, TOLERANCE_KEY = "method", "symmetric", "tolerance"

# supported file formats (JPG is an alias of JPEG):
VRML = 1
PNG = 2
JPEG = JPG = 3
PS = 4
OOGL = 5
BMP = 6
TIFF = 7
OPENINVENTOR = 8
RENDERMAN = 9
PNM = 10
62
63 #===========================================================
64 # a simple toolbox to deal with differentials of functions
65 #===========================================================
66
class Symbol:
    """symbol class

    Base class of the symbolic expressions used for differentiation: a symbol
    has a name, a shape, a spatial dimension and a list of arguments it
    depends on. The arithmetic operators combine symbols into Add_Symbol,
    Mult_Symbol, Div_Symbol, Power_Symbol and Abs_Symbol objects.
    """

    def __init__(self,name="symbol",shape=(),dim=3,args=None):
        """creates an instance of a symbol of shape shape and spatial dimension dim.

        @param name: name of the symbol, returned by __str__
        @param shape: shape of the symbol (a tuple)
        @param dim: spatial dimension, or an object providing getDim()
        @param args: list of arguments the symbol depends on; they may be
                     symbols or other objects. A private copy is stored.
        """
        # None instead of a shared mutable [] default; copy so that later
        # changes to the caller's list do not change the symbol
        if args is None:
            self.__args=[]
        else:
            self.__args=list(args)
        self.__name=name
        self.__shape=shape
        if hasattr(dim,"getDim"):
            self.__dim=dim.getDim()
        else:
            self.__dim=dim
        # result of the last getEvaluatedArguments call and the substitution used for it
        self.__cache_val=None
        self.__cache_argval=None

    def getArgument(self,i):
        """returns the i-th argument"""
        return self.__args[i]

    def getDim(self):
        """returns the spatial dimension of the symbol"""
        return self.__dim

    def getRank(self):
        """returns the rank of the symbol"""
        return len(self.getShape())

    def getShape(self):
        """returns the shape of the symbol"""
        return self.__shape

    def __isCachedArgval(self,argval):
        """returns True if argval maps exactly the same keys to the very same
        value objects as the substitution used for the cached evaluation.

        Values are compared by identity so that arrays are never compared
        element-wise (an element-wise comparison has no single truth value)."""
        cached=self.__cache_argval
        if cached is None or len(cached)!=len(argval):
            return False
        for key in argval:
            if key not in cached or cached[key] is not argval[key]:
                return False
        return True

    def getEvaluatedArguments(self,argval):
        """returns the list of evaluated arguments by substituting symbol u by argval[u].

        Symbol arguments are evaluated through their eval method; all other
        arguments are returned unchanged. The result is cached and reused
        while argval maps the same keys to the same value objects.
        @note: values modified in place (e.g. an array changed element-wise)
               are not detected by the cache.
        """
        if self.__isCachedArgval(argval):
            return self.__cache_val
        out=[]
        for a in self.__args:
            if isinstance(a,Symbol):
                out.append(a.eval(argval))
            else:
                out.append(a)
        # keep a shallow copy: the caller may modify argval after this call,
        # which must not make the cached result look valid
        self.__cache_argval=dict(argval)
        self.__cache_val=out
        return out

    def getDifferentiatedArguments(self,arg):
        """returns the list of the arguments differentiated by arg
        (0 for arguments which are not symbols)"""
        out=[]
        for a in self.__args:
            if isinstance(a,Symbol):
                out.append(a.diff(arg))
            else:
                out.append(0)
        return out

    def diff(self,arg):
        """returns the derivative of self with respect to arg.

        If arg is self, the identity is returned: 1. for a scalar symbol and
        otherwise an array of shape shape+shape with out[i+i]=1. for every
        index tuple i (zero elsewhere). In all other cases _diff is called.
        """
        if self is not arg:
            return self._diff(arg)
        shape=tuple(self.getShape())
        if len(shape)==0:
            return 1.
        out=numarray.zeros(shape+shape,numarray.Float)
        # all index tuples of shape, in row-major order
        indices=[()]
        for n in shape:
            indices=[idx+(i,) for idx in indices for i in range(n)]
        for idx in indices:
            out[idx+idx]=1.
        return out

    def _diff(self,arg):
        """returns the derivative of self with respect to arg (arg is not self).
        This method is overwritten by a particular symbol; the default is 0."""
        return 0

    def eval(self,argval):
        """substitutes symbol u in self by argval[u] and returns the result. If
        self is not a key of argval then self is returned."""
        if self in argval:
            return argval[self]
        return self

    def __str__(self):
        """returns a string representation of the symbol (its name)"""
        return self.__name

    def __add__(self,other):
        """adds other to symbol self. if _testForZero(other) self is returned."""
        if _testForZero(other):
            return self
        a=_matchShape([self,other])
        return Add_Symbol(a[0],a[1])

    def __radd__(self,other):
        """adds other to symbol self. if _testForZero(other) self is returned."""
        return self+other

    def __neg__(self):
        """returns -self."""
        return self*(-1.)

    def __pos__(self):
        """returns +self."""
        return self

    def __abs__(self):
        """returns absolute value"""
        return Abs_Symbol(self)

    def __sub__(self,other):
        """subtracts other from symbol self. if _testForZero(other) self is returned."""
        if _testForZero(other):
            return self
        return self+(-other)

    def __rsub__(self,other):
        """subtracts symbol self from other."""
        return -self+other

    def __div__(self,other):
        """divides symbol self by other."""
        if isinstance(other,Symbol):
            a=_matchShape([self,other])
            return Div_Symbol(a[0],a[1])
        return self*(1./other)

    def __rdiv__(self,other):
        """divides other by symbol self. if _testForZero(other) 0 is returned."""
        if _testForZero(other):
            return 0
        # reflected operation: the result is other/self, so other is the numerator
        a=_matchShape([other,self])
        return Div_Symbol(a[0],a[1])

    # true division (Python 3, or "from __future__ import division") behaves like "/"
    __truediv__=__div__
    __rtruediv__=__rdiv__

    def __pow__(self,other):
        """raises symbol self to the power of other"""
        a=_matchShape([self,other])
        return Power_Symbol(a[0],a[1])

    def __rpow__(self,other):
        """raises other to the symbol self"""
        a=_matchShape([self,other])
        return Power_Symbol(a[1],a[0])

    def __mul__(self,other):
        """multiplies other by symbol self. if _testForZero(other) 0 is returned."""
        if _testForZero(other):
            return 0
        a=_matchShape([self,other])
        return Mult_Symbol(a[0],a[1])

    def __rmul__(self,other):
        """multiplies other by symbol self. if _testForZero(other) 0 is returned."""
        return self*other

    def __getitem__(self,sl):
        """slicing of symbols is not supported yet.

        Raising here (instead of silently returning None) also stops Python's
        legacy iteration protocol. That protocol would otherwise call
        __getitem__(0), __getitem__(1), ... without end when a symbol is
        iterated."""
        raise NotImplementedError("slicing of symbols is not supported yet (index %s)"%(sl,))
248
class Float_Symbol(Symbol):
    """a symbol with spatial dimension 0"""
    def __init__(self,name="symbol",shape=(),args=None):
        """creates a symbol with spatial dimension 0.

        @param name: name of the symbol
        @param shape: shape of the symbol
        @param args: list of arguments the symbol depends on
        """
        # declared with "def" before, so it could not be instantiated or
        # subclassed, and its parameters were ignored
        if args is None:
            args=[]
        Symbol.__init__(self,dim=0,name=name,shape=shape,args=args)
252
class ScalarSymbol(Symbol):
    """a scalar symbol, i.e. a symbol of shape ()"""
    def __init__(self,dim=3,name="scalar"):
        """creates a scalar symbol of spatial dimension dim

        @param dim: spatial dimension or an object providing getDim()
        @param name: name of the symbol
        """
        if hasattr(dim,"getDim"):
            dim=dim.getDim()
        Symbol.__init__(self,name=name,dim=dim,shape=())
262
class VectorSymbol(Symbol):
    """a vector symbol, i.e. a symbol of shape (dim,)"""
    def __init__(self,dim=3,name="vector"):
        """creates a vector symbol of spatial dimension dim

        @param dim: spatial dimension or an object providing getDim()
        @param name: name of the symbol
        """
        n=dim
        if hasattr(n,"getDim"):
            n=n.getDim()
        Symbol.__init__(self,name=name,dim=n,shape=(n,))
272
class TensorSymbol(Symbol):
    """a tensor symbol, i.e. a symbol of shape (dim,dim)"""
    def __init__(self,dim=3,name="tensor"):
        """creates a tensor symbol of spatial dimension dim

        @param dim: spatial dimension or an object providing getDim()
        @param name: name of the symbol
        """
        n=dim
        if hasattr(n,"getDim"):
            n=n.getDim()
        Symbol.__init__(self,name=name,dim=n,shape=(n,)*2)
282
class Tensor3Symbol(Symbol):
    """a tensor order 3 symbol, i.e. a symbol of shape (dim,dim,dim)"""
    def __init__(self,dim=3,name="tensor3"):
        """creates a tensor order 3 symbol of spatial dimension dim

        @param dim: spatial dimension or an object providing getDim()
        @param name: name of the symbol
        """
        n=dim
        if hasattr(n,"getDim"):
            n=n.getDim()
        Symbol.__init__(self,name=name,dim=n,shape=(n,)*3)
292
class Tensor4Symbol(Symbol):
    """a tensor order 4 symbol, i.e. a symbol of shape (dim,dim,dim,dim)"""
    def __init__(self,dim=3,name="tensor4"):
        """creates a tensor order 4 symbol of spatial dimension dim

        @param dim: spatial dimension or an object providing getDim()
        @param name: name of the symbol
        """
        n=dim
        if hasattr(n,"getDim"):
            n=n.getDim()
        Symbol.__init__(self,name=name,dim=n,shape=(n,)*4)
302
class Add_Symbol(Symbol):
    """symbol representing the sum arg0+arg1 of two arguments"""
    def __init__(self,arg0,arg1):
        operands=[arg0,arg1]
        d=_extractDim(operands)
        s=_extractShape(operands)
        Symbol.__init__(self,dim=d,shape=s,args=operands)
    def __str__(self):
        return "(%s+%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
    def eval(self,argval):
        left,right=self.getEvaluatedArguments(argval)
        return left+right
    def _diff(self,arg):
        # (f+g)' = f'+g'
        dleft,dright=self.getDifferentiatedArguments(arg)
        return dleft+dright
316
class Mult_Symbol(Symbol):
    """symbol representing the product arg0*arg1 of two arguments"""
    def __init__(self,arg0,arg1):
        operands=[arg0,arg1]
        d=_extractDim(operands)
        s=_extractShape(operands)
        Symbol.__init__(self,dim=d,shape=s,args=operands)
    def __str__(self):
        return "(%s*%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
    def eval(self,argval):
        left,right=self.getEvaluatedArguments(argval)
        return left*right
    def _diff(self,arg):
        # product rule: (f*g)' = g*f' + f*g'
        df,dg=self.getDifferentiatedArguments(arg)
        f=self.getArgument(0)
        g=self.getArgument(1)
        return g*df+f*dg
330
class Div_Symbol(Symbol):
    """symbol representing the quotient arg0/arg1 of two arguments"""
    def __init__(self,arg0,arg1):
        operands=[arg0,arg1]
        d=_extractDim(operands)
        s=_extractShape(operands)
        Symbol.__init__(self,dim=d,shape=s,args=operands)
    def __str__(self):
        return "(%s/%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
    def eval(self,argval):
        numerator,denominator=self.getEvaluatedArguments(argval)
        return numerator/denominator
    def _diff(self,arg):
        # quotient rule: (f/g)' = (f'*g - f*g')/(g*g)
        df,dg=self.getDifferentiatedArguments(arg)
        f=self.getArgument(0)
        g=self.getArgument(1)
        return (df*g-f*dg)/(g*g)
345
class Power_Symbol(Symbol):
    """symbol representing the first argument raised to the power of the second argument"""
    def __init__(self,arg0,arg1):
        operands=[arg0,arg1]
        d=_extractDim(operands)
        s=_extractShape(operands)
        Symbol.__init__(self,dim=d,shape=s,args=operands)
    def __str__(self):
        return "(%s**%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
    def eval(self,argval):
        base,exponent=self.getEvaluatedArguments(argval)
        return base**exponent
    def _diff(self,arg):
        # (f**g)' = f**g * (g'*log(f) + g/f*f')
        df,dg=self.getDifferentiatedArguments(arg)
        base=self.getArgument(0)
        exponent=self.getArgument(1)
        return self*(dg*log(base)+exponent/base*df)
359
class Abs_Symbol(Symbol):
    """symbol representing the absolute value abs(arg) of its argument"""
    def __init__(self,arg):
        s=arg.getShape()
        d=arg.getDim()
        Symbol.__init__(self,args=[arg],dim=d,shape=s)
    def __str__(self):
        return "abs(%s)"%(str(self.getArgument(0)),)
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return abs(value)
    def _diff(self,arg):
        # abs(f)' = sign(f)*f'
        inner_derivative=self.getDifferentiatedArguments(arg)[0]
        return sign(self.getArgument(0))*inner_derivative
370
371 #=========================================================
372 # some little helpers
373 #=========================================================
374 def _testForZero(arg):
375 """returns True is arg is considered of being zero"""
376 if isinstance(arg,int):
377 return not arg>0
378 elif isinstance(arg,float):
379 return not arg>0.
380 elif isinstance(arg,numarray.NumArray):
381 a=abs(arg)
382 while isinstance(a,numarray.NumArray): a=numarray.sometrue(a)
383 return not a>0
384 else:
385 return False
386
387 def _extractDim(args):
388 dim=None
389 for a in args:
390 if hasattr(a,"getDim"):
391 d=a.getDim()
392 if dim==None:
393 dim=d
394 else:
395 if dim!=d: raise ValueError,"inconsistent spatial dimension of arguments"
396 if dim==None:
397 raise ValueError,"cannot recover spatial dimension"
398 return dim
399
400 def _identifyShape(arg):
401 """identifies the shape of arg."""
402 if hasattr(arg,"getShape"):
403 arg_shape=arg.getShape()
404 elif hasattr(arg,"shape"):
405 s=arg.shape
406 if callable(s):
407 arg_shape=s()
408 else:
409 arg_shape=s
410 else:
411 arg_shape=()
412 return arg_shape
413
def _extractShape(args):
    """extracts the common shape of the list of arguments args.

    @raise ValueError: if two arguments have different shapes or args is empty
    """
    common=None
    for a in args:
        current=_identifyShape(a)
        if common is None:
            common=current
        elif common!=current:
            raise ValueError("inconsistent shape")
    if common is None:
        raise ValueError("cannot recover shape")
    return common
424
def _matchShape(args,shape=None):
    """returns the list of arguments args as objects which all have the specified shape.

    If shape is not given, the longest shape (highest rank) among args is used.
    Arguments which already have this shape are returned unchanged. Scalar
    arguments (shape ()) are expanded through outer(arg,numarray.ones(shape)).
    @raise ValueError: if an argument cannot be adapted to the shape
    """
    arg_shapes=[_identifyShape(a) for a in args]
    if shape is None:
        # max() on bare tuples compares lexicographically, not by length;
        # rank first, then the previous ordering among equally long shapes
        shape=max([(len(s),s) for s in arg_shapes])[1]
    out=[]
    for a,a_shape in zip(args,arg_shapes):
        if a_shape==shape:
            out.append(a)
        elif len(shape)>0 and len(a_shape)==0:
            out.append(outer(a,numarray.ones(shape)))
        else:
            raise ValueError("cannot adopt shape of %s to %s"%(str(a),str(shape)))
    return out
447
448 #=========================================================
449 # wrappers for various mathematical functions:
450 #=========================================================
def diff(arg,dep):
    """returns the derivative of arg with respect to dep.

    Symbols are differentiated with Symbol.diff. For other objects with a
    shape attribute (or shape() method) a zero numarray array of that shape is
    returned, and 0 otherwise."""
    if isinstance(arg,Symbol):
        return arg.diff(dep)
    if not hasattr(arg,"shape"):
        return 0
    if callable(arg.shape):
        zero_shape=arg.shape()
    else:
        zero_shape=arg.shape
    return numarray.zeros(zero_shape,numarray.Float)
463
def exp(arg):
    """
    @brief applies the exponential function to arg
    @param arg (input): argument. A Symbol gives an Exp_Symbol, an object with
           an exp() method uses it, anything else goes to numarray.exp
    """
    if isinstance(arg,Symbol):
        return Exp_Symbol(arg)
    if hasattr(arg,"exp"):
        return arg.exp()
    return numarray.exp(arg)
475
class Exp_Symbol(Symbol):
    """symbol representing the exponential exp(arg) of its argument"""
    def __init__(self,arg):
        s=arg.getShape()
        d=arg.getDim()
        Symbol.__init__(self,args=[arg],dim=d,shape=s)
    def __str__(self):
        return "exp(%s)"%(str(self.getArgument(0)),)
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return exp(value)
    def _diff(self,arg):
        # exp(f)' = exp(f)*f'
        return self*self.getDifferentiatedArguments(arg)[0]
486
def sqrt(arg):
    """
    @brief applies the square root function to arg
    @param arg (input): argument. A Symbol gives a Sqrt_Symbol, an object with
           a sqrt() method uses it, anything else goes to numarray.sqrt
    """
    if isinstance(arg,Symbol):
        return Sqrt_Symbol(arg)
    if hasattr(arg,"sqrt"):
        return arg.sqrt()
    return numarray.sqrt(arg)
498
class Sqrt_Symbol(Symbol):
    """symbol representing square root of argument"""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "sqrt(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return sqrt(self.getEvaluatedArguments(argval)[0])
    def _diff(self,arg):
        # sqrt(f)' = f'/(2*sqrt(f)), i.e. a positive factor 0.5 (it was -0.5).
        # self**(-1.) is used instead of 0.5/self so that the result does not
        # go through the reflected division float/Symbol.
        return 0.5*self**(-1.)*self.getDifferentiatedArguments(arg)[0]
509
def log(arg):
    """
    @brief applies the logarithmic function with base exp(1.) to arg
    @param arg (input): argument. A Symbol gives a Log_Symbol, an object with a
           log() method uses it, anything else goes to numarray.log
    """
    if isinstance(arg,Symbol):
        return Log_Symbol(arg)
    if hasattr(arg,"log"):
        return arg.log()
    return numarray.log(arg)
521
class Log_Symbol(Symbol):
    """symbol representing the logarithm log(arg) of its argument"""
    def __init__(self,arg):
        s=arg.getShape()
        d=arg.getDim()
        Symbol.__init__(self,args=[arg],dim=d,shape=s)
    def __str__(self):
        return "log(%s)"%(str(self.getArgument(0)),)
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return log(value)
    def _diff(self,arg):
        # log(f)' = f'/f
        inner_derivative=self.getDifferentiatedArguments(arg)[0]
        return inner_derivative/self.getArgument(0)
532
def ln(arg):
    """
    @brief applies the natural logarithmic function to arg
    @param arg (input): argument. A Symbol gives an Ln_Symbol. An object with an
           ln() method uses it; otherwise an object with a log() method uses
           that. Anything else goes to numarray.log.
    """
    if isinstance(arg,Symbol):
        return Ln_Symbol(arg)
    elif hasattr(arg,"ln"):
        # call the method that was actually tested for (log() was called here)
        return arg.ln()
    elif hasattr(arg,"log"):
        return arg.log()
    else:
        return numarray.log(arg)
544
class Ln_Symbol(Symbol):
    """symbol representing the natural logarithm ln(arg) of its argument"""
    def __init__(self,arg):
        s=arg.getShape()
        d=arg.getDim()
        Symbol.__init__(self,args=[arg],dim=d,shape=s)
    def __str__(self):
        return "ln(%s)"%(str(self.getArgument(0)),)
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return ln(value)
    def _diff(self,arg):
        # ln(f)' = f'/f
        inner_derivative=self.getDifferentiatedArguments(arg)[0]
        return inner_derivative/self.getArgument(0)
555
def sin(arg):
    """
    @brief applies the sine function to arg
    @param arg (input): argument. A Symbol gives a Sin_Symbol, an object with a
           sin() method uses it, anything else goes to numarray.sin
    """
    if isinstance(arg,Symbol):
        return Sin_Symbol(arg)
    if hasattr(arg,"sin"):
        return arg.sin()
    return numarray.sin(arg)
567
class Sin_Symbol(Symbol):
    """symbol representing the sine sin(arg) of its argument"""
    def __init__(self,arg):
        s=arg.getShape()
        d=arg.getDim()
        Symbol.__init__(self,args=[arg],dim=d,shape=s)
    def __str__(self):
        return "sin(%s)"%(str(self.getArgument(0)),)
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return sin(value)
    def _diff(self,arg):
        # sin(f)' = cos(f)*f'
        inner_derivative=self.getDifferentiatedArguments(arg)[0]
        return cos(self.getArgument(0))*inner_derivative
578
def cos(arg):
    """
    @brief applies the cosine function to arg
    @param arg (input): argument. A Symbol gives a Cos_Symbol, an object with a
           cos() method uses it, anything else goes to numarray.cos
    """
    if isinstance(arg,Symbol):
        return Cos_Symbol(arg)
    if hasattr(arg,"cos"):
        return arg.cos()
    return numarray.cos(arg)
590
class Cos_Symbol(Symbol):
    """symbol representing the cosine cos(arg) of its argument"""
    def __init__(self,arg):
        s=arg.getShape()
        d=arg.getDim()
        Symbol.__init__(self,args=[arg],dim=d,shape=s)
    def __str__(self):
        return "cos(%s)"%(str(self.getArgument(0)),)
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return cos(value)
    def _diff(self,arg):
        # cos(f)' = -sin(f)*f'
        inner_derivative=self.getDifferentiatedArguments(arg)[0]
        return -sin(self.getArgument(0))*inner_derivative
601
def tan(arg):
    """
    @brief applies the tangent function to arg
    @param arg (input): argument. A Symbol gives a Tan_Symbol, an object with a
           tan() method uses it, anything else goes to numarray.tan
    """
    if isinstance(arg,Symbol):
        return Tan_Symbol(arg)
    if hasattr(arg,"tan"):
        return arg.tan()
    return numarray.tan(arg)
613
class Tan_Symbol(Symbol):
    """symbol representing the tangent tan(arg) of its argument"""
    def __init__(self,arg):
        s=arg.getShape()
        d=arg.getDim()
        Symbol.__init__(self,args=[arg],dim=d,shape=s)
    def __str__(self):
        return "tan(%s)"%(str(self.getArgument(0)),)
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return tan(value)
    def _diff(self,arg):
        # tan(f)' = f'/cos(f)**2
        c=cos(self.getArgument(0))
        inner_derivative=self.getDifferentiatedArguments(arg)[0]
        return 1./(c*c)*inner_derivative
625
def sign(arg):
    """
    @brief applies the sign function to arg
    @param arg (input): argument (Python number, Symbol, object with a sign()
           method, or numarray array)
    @return for Python numbers 1. if arg>0, -1. if arg<0 and 0. otherwise;
            a Sign_Symbol for symbols; otherwise the element-wise sign
    """
    if isinstance(arg,(int,float)):
        # plain numbers have no .shape, so the numarray branch below failed for them
        if arg>0:
            return 1.
        if arg<0:
            return -1.
        return 0.
    elif isinstance(arg,Symbol):
        return Sign_Symbol(arg)
    elif hasattr(arg,"sign"):
        return arg.sign()
    else:
        return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))- \
               numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
638
class Sign_Symbol(Symbol):
    """symbol representing the sign sign(arg) of its argument.

    No _diff is defined, so the inherited Symbol._diff (returning 0) is used."""
    def __init__(self,arg):
        s=arg.getShape()
        d=arg.getDim()
        Symbol.__init__(self,args=[arg],dim=d,shape=s)
    def __str__(self):
        return "sign(%s)"%(str(self.getArgument(0)),)
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return sign(value)
647
def maxval(arg):
    """
    @brief returns the maximum value of argument arg
    @param arg (input): argument. A Symbol gives a Max_Symbol. Otherwise
           arg.maxval() or arg.max() is used if available; anything else is
           returned unchanged.
    """
    if isinstance(arg,Symbol):
        return Max_Symbol(arg)
    if hasattr(arg,"maxval"):
        return arg.maxval()
    if hasattr(arg,"max"):
        return arg.max()
    return arg
661
class Max_Symbol(Symbol):
    """symbol representing the maximum value maxval(arg) of its argument"""
    def __init__(self,arg):
        # maxval reduces to a single value (for numarray: arg.max()), so the
        # symbol is a scalar rather than having the argument's shape.
        # NOTE(review): assumes escript Data.maxval() also returns rank 0 data
        Symbol.__init__(self,shape=(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "maxval(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return maxval(self.getEvaluatedArguments(argval)[0])
670
def minval(arg):
    """
    @brief returns the minimum value of argument arg
    @param arg (input): argument. A Symbol gives a Min_Symbol. Otherwise
           arg.minval() or arg.min() is used if available; anything else is
           returned unchanged.
    """
    if isinstance(arg,Symbol):
        return Min_Symbol(arg)
    elif hasattr(arg,"minval"):
        # test for the method that is called (the test was for "maxval")
        return arg.minval()
    elif hasattr(arg,"min"):
        return arg.min()
    else:
        return arg
684
class Min_Symbol(Symbol):
    """symbol representing the minimum value minval(arg) of its argument"""
    def __init__(self,arg):
        # minval reduces to a single value (for numarray: arg.min()), so the
        # symbol is a scalar rather than having the argument's shape.
        # NOTE(review): assumes escript Data.minval() also returns rank 0 data
        Symbol.__init__(self,shape=(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "minval(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return minval(self.getEvaluatedArguments(argval)[0])
693
def wherePositive(arg):
    """
    @brief returns an indicator of where arg is positive: 1 where arg>0 and 0 elsewhere
    @param arg (input): argument (Symbol, object with a wherePositive() method,
           numarray array or Python number)
    @return 0 if arg is zero (see _testForZero), a WherePositive_Symbol for
            symbols, the element-wise comparison arg>0 for arrays, and 1./0.
            for numbers
    """
    if _testForZero(arg):
        return 0
    elif isinstance(arg,Symbol):
        return WherePositive_Symbol(arg)
    elif hasattr(arg,"wherePositive"):
        # delegate to the object's own implementation (minval() was called here)
        return arg.wherePositive()
    elif hasattr(arg,"shape"):
        # this branch tested "wherePositive" a second time and dropped its result
        return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))
    else:
        if arg>0:
            return 1.
        else:
            return 0.
712
class WherePositive_Symbol(Symbol):
    """symbol representing the wherePositive function applied to its argument"""
    def __init__(self,arg):
        s=arg.getShape()
        d=arg.getDim()
        Symbol.__init__(self,args=[arg],dim=d,shape=s)
    def __str__(self):
        return "wherePositive(%s)"%(str(self.getArgument(0)),)
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return wherePositive(value)
721
def whereNegative(arg):
    """
    @brief returns an indicator of where arg is negative: 1 where arg<0 and 0 elsewhere
    @param arg (input): argument (Symbol, object with a whereNegative() method,
           numarray array or Python number)
    @return 0 if arg is zero (see _testForZero), a WhereNegative_Symbol for
            symbols, the element-wise comparison arg<0 for arrays, and 1./0.
            for numbers
    """
    if _testForZero(arg):
        return 0
    elif isinstance(arg,Symbol):
        return WhereNegative_Symbol(arg)
    elif hasattr(arg,"whereNegative"):
        return arg.whereNegative()
    elif hasattr(arg,"shape"):
        # the comparison result was computed but not returned (gave None)
        return numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
    else:
        if arg<0:
            return 1.
        else:
            return 0.
740
class WhereNegative_Symbol(Symbol):
    """symbol representing the whereNegative function applied to its argument"""
    def __init__(self,arg):
        s=arg.getShape()
        d=arg.getDim()
        Symbol.__init__(self,args=[arg],dim=d,shape=s)
    def __str__(self):
        return "whereNegative(%s)"%(str(self.getArgument(0)),)
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return whereNegative(value)
749
def outer(arg0,arg1):
    """returns the outer product of arg0 and arg1.

    Returns 0 if either argument is zero (see _testForZero) and an Outer_Symbol
    if either is a Symbol. If either argument is scalar (shape ()), the plain
    product is returned. Two numarray arrays use numarray.outer. For escript
    Data only the product of two rank 1 objects is implemented.
    @raise ValueError: for other Data ranks
    """
    if _testForZero(arg0) or _testForZero(arg1):
        return 0
    if isinstance(arg0,Symbol) or isinstance(arg1,Symbol):
        return Outer_Symbol(arg0,arg1)
    if _identifyShape(arg0)==() or _identifyShape(arg1)==():
        return arg0*arg1
    if isinstance(arg0,numarray.NumArray) and isinstance(arg1,numarray.NumArray):
        return numarray.outer(arg0,arg1)
    if not (arg0.getRank()==1 and arg1.getRank()==1):
        raise ValueError("outer is not fully implemented yet.")
    n0=arg0.getShape()[0]
    n1=arg1.getShape()[0]
    result=escript.Data(0,(n0,n1),arg1.getFunctionSpace())
    for row in range(n0):
        for col in range(n1):
            result[row,col]=arg0[row]*arg1[col]
    return result
769
class Outer_Symbol(Symbol):
    """symbol representing the outer product outer(arg0,arg1) of its two arguments"""
    def __init__(self,arg0,arg1):
        operands=[arg0,arg1]
        # the shape of an outer product is the concatenation of both shapes
        s=tuple(_identifyShape(arg0))+tuple(_identifyShape(arg1))
        Symbol.__init__(self,args=operands,dim=_extractDim(operands),shape=s)
    def __str__(self):
        return "outer(%s,%s)"%(str(self.getArgument(0)),str(self.getArgument(1)))
    def eval(self,argval):
        left,right=self.getEvaluatedArguments(argval)
        return outer(left,right)
    def _diff(self,arg):
        # product rule for the outer product
        dleft,dright=self.getDifferentiatedArguments(arg)
        return outer(dleft,self.getArgument(1))+outer(self.getArgument(0),dright)
784
def interpolate(arg,where):
    """
    @brief interpolates the function into the FunctionSpace where.

    @param arg interpolant; 0 is returned if it is zero (see _testForZero) and
               an Interpolated_Symbol if it is a Symbol
    @param where FunctionSpace to interpolate to
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg,Symbol):
        return Interpolated_Symbol(arg,where)
    return escript.Data(arg,where)
798
class Interpolated_Symbol(Symbol):
    """symbol representing the interpolation of its argument into a FunctionSpace"""
    def __init__(self,arg,where):
        """@param arg: symbol to be interpolated
           @param where: target FunctionSpace (stored as second argument)"""
        # declared with "def" before, so interpolate() failed with a TypeError.
        # _identifyShape(arg) replaces _extractShape(arg), which iterated over
        # the symbol itself.
        Symbol.__init__(self,shape=_identifyShape(arg),dim=_extractDim([arg]),args=[arg,where])
    def __str__(self):
        return "interpolated(%s)"%(str(self.getArgument(0)))
    def eval(self,argval):
        # interpolate (not integrate) the evaluated argument
        a=self.getEvaluatedArguments(argval)
        return interpolate(a[0],where=self.getArgument(1))
    def _diff(self,arg):
        # interpolation is linear: interpolate the derivative of the argument
        a=self.getDifferentiatedArguments(arg)
        return interpolate(a[0],where=self.getArgument(1))
811
def grad(arg,where=None):
    """
    @brief returns the spatial gradient of arg at where.

    @param arg: Data object representing the function whose gradient is to be
                calculated. 0 is returned if it is zero (see _testForZero) and
                a Grad_Symbol if it is a Symbol. Objects without a grad()
                method give arg*0.
    @param where: FunctionSpace in which the gradient will be calculated. If not
                  present or None an appropriate default is used.
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg,Symbol):
        return Grad_Symbol(arg,where)
    if not hasattr(arg,"grad"):
        return arg*0.
    if where==None:
        return arg.grad()
    return arg.grad(where)
831
class Grad_Symbol(Symbol):
    """symbol representing the gradient of the argument"""
    def __init__(self,arg,where=None):
        """@param arg: symbol whose gradient is taken
           @param where: FunctionSpace of the gradient (or None for the default)"""
        d=_extractDim([arg])
        # the gradient appends one axis of length d to the shape of arg.
        # list.append() returns None, so the tuple(...append(d)) used here
        # raised a TypeError; the shape must come from arg, not from [arg].
        s=tuple(_identifyShape(arg))+(d,)
        Symbol.__init__(self,shape=s,dim=d,args=[arg,where])
    def __str__(self):
        return "grad(%s)"%(str(self.getArgument(0)))
    def eval(self,argval):
        a=self.getEvaluatedArguments(argval)
        return grad(a[0],where=self.getArgument(1))
    def _diff(self,arg):
        # grad is linear: take the gradient of the derivative of the argument
        a=self.getDifferentiatedArguments(arg)
        return grad(a[0],where=self.getArgument(1))
846
def integrate(arg,where=None):
    """
    @brief returns the integral of the function represented by Data object arg over its domain.

    @param arg: Data object representing the function which is integrated. 0 is
                returned if it is zero (see _testForZero) and an
                Integral_Symbol if it is a Symbol.
    @param where: FunctionSpace in which the integral is calculated. If not
                  present or None an appropriate default is used.
    @return for rank 0 arguments the first entry of arg.integrate(), otherwise
            arg.integrate() itself
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg,Symbol):
        return Integral_Symbol(arg,where)
    if not where==None:
        # move arg into the requested FunctionSpace first
        arg=escript.Data(arg,where)
    if arg.getRank()==0:
        return arg.integrate()[0]
    return arg.integrate()
865
class Integral_Symbol(Symbol):
    """symbol representing the integral of the argument"""
    def __init__(self,arg,where=None):
        """@param arg: symbol to be integrated
           @param where: FunctionSpace in which the integral is calculated (or None)"""
        # The integral has the shape of its integrand: integrate() returns
        # arg.integrate() for rank>0. _identifyShape([arg]) always gave ().
        # The spatial dimension is 0.
        # NOTE(review): dim=0 makes _extractDim reject combinations with
        # symbols of dimension >0 -- confirm that this is intended.
        Symbol.__init__(self,dim=0,shape=_identifyShape(arg),args=[arg,where])
    def __str__(self):
        return "integral(%s)"%(str(self.getArgument(0)))
    def eval(self,argval):
        a=self.getEvaluatedArguments(argval)
        return integrate(a[0],where=self.getArgument(1))
    def _diff(self,arg):
        # integration is linear: integrate the derivative of the argument
        a=self.getDifferentiatedArguments(arg)
        return integrate(a[0],where=self.getArgument(1))
878
879 #=============================
880 #
881 # wrappers for various functions: if the argument has a method with the
882 # function's name, that method is called. Otherwise the corresponding
883 # numarray function is called.
884
885 # functions involving the underlying Domain:
886
887
888 # functions returning Data objects:
889
def transpose(arg,axis=None):
    """
    @brief returns the transpose of arg.

    @param arg: Symbol, escript Data object (rank 2 only) or numarray array
    @param axis: passed on to the Symbol / numarray implementation; if None
                 it defaults to rank(arg)//2
    @raise ValueError: for escript Data whose rank is not 2
    """
    if axis==None:
        r=0
        if hasattr(arg,"getRank"): r=arg.getRank()
        if hasattr(arg,"rank"): r=arg.rank
        # integer division; "/" gives a float under true division
        axis=r//2
    if isinstance(arg,Symbol):
        # pass the axis (the rank r was passed here, unbound if axis was given).
        # NOTE(review): Transpose_Symbol is not defined in the visible part of this file.
        return Transpose_Symbol(arg,axis=axis)
    if isinstance(arg,escript.Data):
        # hack for transpose: only rank 2 is supported
        r=arg.getRank()
        if r!=2: raise ValueError("Tranpose only avalaible for rank 2 objects")
        s=arg.getShape()
        out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())
        for i in range(s[0]):
            for j in range(s[1]):
                out[j,i]=arg[i,j]
        return out
    return numarray.transpose(arg,axis=axis)
917
def trace(arg,axis0=0,axis1=1):
    """
    @brief returns the trace of arg with respect to the axes axis0 and axis1.

    @param arg: Symbol (a Trace_Symbol is returned), escript Data object (rank 2
                only) or numarray array
    @param axis0, axis1: the two axes to contract
    @raise ValueError: for Data which is not rank 2 or whose two axes differ in length
    """
    if isinstance(arg,Symbol):
        return Trace_Symbol(arg,axis0=axis0,axis1=axis1)
    elif isinstance(arg,escript.Data):
        # hack for trace: only rank 2 is supported
        s=arg.getShape()
        if len(s)!=2:
            raise ValueError("trace is only implemented for rank 2 Data objects")
        if s[axis0]!=s[axis1]:
            raise ValueError("illegal axis in trace")
        out=escript.Scalar(0.,arg.getFunctionSpace())
        # sum of the diagonal entries only (all entries were summed before)
        for i in range(s[axis0]):
            out+=arg[i,i]
        return out
    else:
        return numarray.trace(arg,axis0=axis0,axis1=axis1)
942
class Trace_Symbol(Symbol):
    """symbol representing the trace of its argument over the axes axis0 and axis1"""
    def __init__(self,arg,axis0=0,axis1=1):
        """@param arg: symbol whose trace is taken
           @param axis0, axis1: the two (different) axes to contract"""
        # was a stub function taking one argument, so trace() on a symbol
        # always raised a TypeError
        lo=min(axis0,axis1)
        hi=max(axis0,axis1)
        s=list(arg.getShape())
        # the two contracted axes are removed from the shape
        s=tuple(s[:lo]+s[lo+1:hi]+s[hi+1:])
        Symbol.__init__(self,shape=s,dim=arg.getDim(),args=[arg,axis0,axis1])
    def __str__(self):
        return "trace(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        a=self.getEvaluatedArguments(argval)
        return trace(a[0],axis0=a[1],axis1=a[2])
    def _diff(self,arg):
        # trace is linear: contract the derivative of the argument over the
        # same axes (its leading axes are the axes of the argument)
        a=self.getDifferentiatedArguments(arg)
        if _testForZero(a[0]):
            return 0
        return trace(a[0],axis0=self.getArgument(1),axis1=self.getArgument(2))
945
def length(arg):
    """
    @brief returns the Euclidean length of arg, i.e. the square root of the sum
    of the squares of all its components.

    @param arg: Python number (its absolute value is returned), escript Data
                object of rank 0 to 4, or an array-like object providing **
                and sum()
    @raise SystemError: for escript Data of rank larger than 4
    """
    if isinstance(arg,(int,float)):
        # plain numbers have no sum(), which made the last branch fail
        return abs(arg)
    if isinstance(arg,escript.Data):
        if arg.isEmpty(): return escript.Data()
        rank=arg.getRank()
        if rank==0:
            return abs(arg)
        if rank>4:
            raise SystemError("length is not been implemented yet")
        # all index tuples of the shape, in row-major order
        indices=[()]
        for n in arg.getShape():
            indices=[idx+(i,) for idx in indices for i in range(n)]
        total=escript.Scalar(0,arg.getFunctionSpace())
        for idx in indices:
            # rank 1 Data is indexed with a plain integer, higher ranks with a tuple
            if rank==1:
                total+=arg[idx[0]]**2
            else:
                total+=arg[idx]**2
        return sqrt(total)
    else:
        return sqrt((arg**2).sum())
987
def deviator(arg):
    """
    @brief returns the deviatoric part of the square rank 2 object arg, i.e.
    arg - trace(arg)/n * kronecker(n) where n is the size of arg.

    @param arg: escript Data object or array with a shape attribute
    @raise ValueError: if arg is not of rank 2 or not square
    NOTE(review): kronecker is not defined in the visible part of this file;
    presumably it returns the n x n identity.
    """
    if not isinstance(arg,escript.Data):
        shape=arg.shape
    else:
        shape=arg.getShape()
    if len(shape)!=2:
        raise ValueError("Deviator requires rank 2 object")
    if shape[0]!=shape[1]:
        raise ValueError("Deviator requires a square matrix")
    n=shape[0]
    return arg-1./(n*1.)*trace(arg)*kronecker(n)
1003
def inner(arg0,arg1):
    """
    @brief returns the inner product of arg0 and arg1, i.e. the sum over all
    index tuples i of arg0[i]*arg1[i].

    @param arg0, arg1: two Python numbers; or at least one escript Data object
                       of rank 0 to 4; or two array-like objects supporting *
                       and sum(). Both arguments are assumed to have the same shape.
    @raise SystemError: for escript Data of rank larger than 4
    """
    if isinstance(arg0,(int,float)) and isinstance(arg1,(int,float)):
        return arg0*arg1
    # the Data argument determines rank, shape and FunctionSpace of the result
    # (the undefined name "arg" was used here, which raised a NameError)
    if isinstance(arg0,escript.Data):
        ref=arg0
    elif isinstance(arg1,escript.Data):
        ref=arg1
    else:
        return (arg0*arg1).sum()
    rank=ref.getRank()
    if rank==0:
        return arg0*arg1
    if rank>4:
        raise SystemError("inner is not been implemented yet")
    # all index tuples of the shape, in row-major order
    indices=[()]
    for n in ref.getShape():
        indices=[idx+(i,) for idx in indices for i in range(n)]
    total=escript.Scalar(0,ref.getFunctionSpace())
    for idx in indices:
        # rank 1 Data is indexed with a plain integer, higher ranks with a tuple
        if rank==1:
            idx=idx[0]
        total+=arg0[idx]*arg1[idx]
    return total
1038
def matrixmult(arg0,arg1):
    """
    @brief returns the matrix product of arg0 and arg1.

    Two numarray arrays are multiplied with numarray.dot. Otherwise a non-Data
    argument is converted to escript Data on the FunctionSpace of the other
    argument. For Data only the product of a rank 2 and a rank 1 object
    (matrix times vector) is implemented.

    @param arg0, arg1: numarray arrays or escript Data objects
    @raise SystemError: for Data ranks other than (2,1)
    """
    if isinstance(arg1,numarray.NumArray) and isinstance(arg0,numarray.NumArray):
        # return the product; numarray.matrixmult does not exist and the
        # result was discarded, so None was returned
        return numarray.dot(arg0,arg1)
    if isinstance(arg1,escript.Data) and not isinstance(arg0,escript.Data):
        arg0=escript.Data(arg0,arg1.getFunctionSpace())
    elif isinstance(arg0,escript.Data) and not isinstance(arg1,escript.Data):
        arg1=escript.Data(arg1,arg0.getFunctionSpace())
    if arg0.getRank()==2 and arg1.getRank()==1:
        rows=arg0.getShape()[0]
        cols=arg0.getShape()[1]
        out=escript.Data(0,(rows,),arg0.getFunctionSpace())
        for i in range(rows):
            for j in range(cols):
                out[i]+=arg0[i,j]*arg1[j]
        return out
    else:
        raise SystemError("matrixmult is not fully implemented yet!")
1057
1058 #=========================================================
1059 # reduction operations:
1060 #=========================================================
def sum(arg):
    """
    @brief returns the sum of all components of arg

    Python numbers (float or int) are returned unchanged, as in sup and inf;
    any other argument must provide a sum() method (e.g. escript.Data).

    @param arg float, int or object with a sum() method
    @return the sum of the components of arg
    """
    if isinstance(arg,float) or isinstance(arg,int):
        return arg
    return arg.sum()
1068
def sup(arg):
    """
    @brief returns the supremum of arg

    @param arg escript.Data object, Python float/int, or an object with a
               max() method (e.g. a numarray array)
    @return arg.sup() for escript.Data, arg itself for a float or int,
            arg.max() otherwise
    """
    if isinstance(arg,escript.Data):
        return arg.sup()
    if isinstance(arg,(float,int)):
        return arg
    return arg.max()
1081
def inf(arg):
    """
    @brief returns the infimum of arg

    @param arg escript.Data object, Python float/int, or an object with a
               min() method (e.g. a numarray array)
    @return arg.inf() for escript.Data, arg itself for a float or int,
            arg.min() otherwise
    """
    if isinstance(arg,escript.Data):
        return arg.inf()
    if isinstance(arg,(float,int)):
        return arg
    return arg.min()
1094
def L2(arg):
    """
    @brief returns the L2-norm of arg

    @param arg escript.Data object, Python float/int, or array
    @return arg.L2() for escript.Data, abs(arg) for a float or int, and the
            square root of the sum of squares of all components otherwise
    """
    if isinstance(arg,escript.Data):
        return arg.L2()
    elif isinstance(arg,float) or isinstance(arg,int):
        return abs(arg)
    else:
        # flatten first: for arrays of rank>1, dot(arg,arg) would be a matrix
        # product rather than the sum of squares of all components
        flat=numarray.ravel(arg)
        return numarray.sqrt(numarray.dot(flat,flat))
1107
def Lsup(arg):
    """
    @brief returns the Lsup-norm (maximum absolute value) of arg

    @param arg escript.Data object, Python float/int, or array/sequence
    @return arg.Lsup() for escript.Data, abs(arg) for a float or int, and the
            largest absolute value over all components otherwise
    """
    if isinstance(arg,escript.Data):
        return arg.Lsup()
    elif isinstance(arg,float) or isinstance(arg,int):
        return abs(arg)
    else:
        # the array's max() method reduces over all components (as used in
        # sup); the builtin max only iterates over the first axis and does
        # not reduce arrays of rank>1
        return numarray.abs(arg).max()
1120
def dot(arg0,arg1):
    """
    @brief returns the dot product of arg0 and arg1

    If arg0 is an escript.Data object, arg0.dot(arg1) is returned; otherwise,
    if arg1 is an escript.Data object, arg1.dot(arg0) is returned; otherwise
    numarray.dot(arg0,arg1) is used.

    @param arg0, arg1 the two factors
    """
    for left,right in ((arg0,arg1),(arg1,arg0)):
        if isinstance(left,escript.Data):
            return left.dot(right)
    return numarray.dot(arg0,arg1)
1133
def kronecker(d):
    """
    @brief returns the identity matrix (Kronecker delta) of dimension d

    @param d integer dimension, or an object with a getDim() method whose
             return value is used as the dimension
    @return numarray.identity of that dimension
    """
    dim=d
    if hasattr(d,"getDim"):
        dim=d.getDim()
    return numarray.identity(dim)
1139
def unit(i,d):
    """
    @brief returns the unit vector of length d whose i-th component is 1

    @param i index of the non-zero component
    @param d dimension (length) of the vector
    @return numarray array of type numarray.Float
    """
    vec=numarray.zeros((d,),numarray.Float)
    vec[i]=1.0
    return vec
1149
1150 # ============================================
1151 # testing
1152 # ============================================
1153
if __name__=="__main__":
    # Smoke test of the Symbol toolbox: every _show call prints an expression
    # followed by its derivative with respect to the symbol u.
    def _show(expr):
        # "%s %s" gives the same output as the Python 2 statement
        # "print expr,expr.diff(u)" and also runs under Python 3
        print("%s %s" % (expr,expr.diff(u)))

    u=VectorSymbol(2,"u")
    v=VectorSymbol(2,"v")

    # binary operators: symbol with constant and symbol with symbol,
    # in both orders
    _show(u+5)
    _show(5+u)
    _show(u+v)
    _show(v+u)

    _show(u*5)
    _show(5*u)
    _show(u*v)
    _show(v*u)

    _show(u-5)
    _show(5-u)
    _show(u-v)
    _show(v-u)

    _show(u/5)
    _show(5/u)
    _show(u/v)
    _show(v/u)

    _show(u**5)
    _show(5**u)
    _show(u**v)
    _show(v**u)

    # unary functions
    _show(exp(u))
    _show(sqrt(u))
    _show(log(u))
    _show(sin(u))
    _show(cos(u))
    _show(tan(u))
    _show(sign(u))
    _show(abs(u))
    _show(wherePositive(u))
    _show(whereNegative(u))
    _show(maxval(u))
    _show(minval(u))

    # differentiation with respect to a derived symbol
    g=grad(u)
    print(diff(5*g,g))
    # NOTE(review): kronecker(3) is a 3x3 array while u was created with
    # VectorSymbol(2,...) -- confirm the shapes are compatible
    print(4*(g+transpose(g))/2+6*trace(g)*kronecker(3))
1201
1202 #
1203 # $Log$
1204 # Revision 1.12 2005/07/20 06:14:58 jgs
1205 # added ln(data) style wrapper for data.ln() - also added corresponding
1206 # implementation of Ln_Symbol class (not sure if this is right though)
1207 #
1208 # Revision 1.11 2005/07/08 04:07:35 jgs
1209 # Merge of development branch back to main trunk on 2005-07-08
1210 #
1211 # Revision 1.10 2005/06/09 05:37:59 jgs
1212 # Merge of development branch back to main trunk on 2005-06-09
1213 #
1214 # Revision 1.2.2.17 2005/07/07 07:28:58 gross
1215 # some stuff added to util.py to improve functionality
1216 #
1217 # Revision 1.2.2.16 2005/06/30 01:53:55 gross
1218 # a bug in coloring fixed
1219 #
1220 # Revision 1.2.2.15 2005/06/29 02:36:43 gross
1221 # Symbols have been introduced and some function clarified. needs much more work
1222 #
1223 # Revision 1.2.2.14 2005/05/20 04:05:23 gross
1224 # some work on a darcy flow started
1225 #
1226 # Revision 1.2.2.13 2005/03/16 05:17:58 matt
1227 # Implemented unit(idx, dim) to create cartesian unit basis vectors to
1228 # complement kronecker(dim) function.
1229 #
1230 # Revision 1.2.2.12 2005/03/10 08:14:37 matt
1231 # Added non-member Linf utility function to complement Data::Linf().
1232 #
1233 # Revision 1.2.2.11 2005/02/17 05:53:25 gross
1234 # some bug in saveDX fixed: in fact the bug was in
1235 # DataC/getDataPointShape
1236 #
1237 # Revision 1.2.2.10 2005/01/11 04:59:36 gross
1238 # automatic interpolation in integrate switched off
1239 #
1240 # Revision 1.2.2.9 2005/01/11 03:38:13 gross
1241 # Bug in Data.integrate() fixed for the case of rank 0. The problem is not totally resolved as the method should return a scalar rather than a numarray object in the case of rank 0. This problem is fixed by the util.integrate wrapper.
1242 #
1243 # Revision 1.2.2.8 2005/01/05 04:21:41 gross
1244 # FunctionSpace checking/matching in slicing added
1245 #
1246 # Revision 1.2.2.7 2004/12/29 05:29:59 gross
1247 # AdvectivePDE successfully tested for Peclet number 1000000. there is still a problem with setValue and Data()
1248 #
1249 # Revision 1.2.2.6 2004/12/24 06:05:41 gross
1250 # some changes in linearPDEs to add AdvectivePDE
1251 #
1252 # Revision 1.2.2.5 2004/12/17 00:06:53 gross
1253 # mk sets ESYS_ROOT is undefined
1254 #
1255 # Revision 1.2.2.4 2004/12/07 03:19:51 gross
1256 # options for GMRES and PRES20 added
1257 #
1258 # Revision 1.2.2.3 2004/12/06 04:55:18 gross
1259 # function wraper extended
1260 #
1261 # Revision 1.2.2.2 2004/11/22 05:44:07 gross
1262 # a few more unitary functions have been added but not implemented in Data yet
1263 #
1264 # Revision 1.2.2.1 2004/11/12 06:58:15 gross
1265 # a lot of changes to get the linearPDE class running: most important change is that there is no matrix format exposed to the user anymore. the format is chosen by the Domain according to the solver and symmetry
1266 #
1267 # Revision 1.2 2004/10/27 00:23:36 jgs
1268 # fixed minor syntax error
1269 #
1270 # Revision 1.1.1.1 2004/10/26 06:53:56 jgs
1271 # initial import of project esys2
1272 #
1273 # Revision 1.1.2.3 2004/10/26 06:43:48 jgs
1274 # committing Lutz's and Paul's changes to branch jgs
1275 #
1276 # Revision 1.1.4.1 2004/10/20 05:32:51 cochrane
1277 # Added incomplete Doxygen comments to files, or merely put the docstrings that already exist into Doxygen form.
1278 #
1279 # Revision 1.1 2004/08/05 03:58:27 gross
1280 # Bug in Assemble_NodeCoordinates fixed
1281 #
1282 #

Properties

Name Value
svn:eol-style native
svn:keywords Author Date Id Revision

  ViewVC Help
Powered by ViewVC 1.1.26