/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Contents of /trunk/escript/py_src/util.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 123 - (show annotations)
Fri Jul 8 04:08:13 2005 UTC (14 years, 3 months ago) by jgs
Original Path: trunk/esys2/escript/py_src/util.py
File MIME type: text/x-python
File size: 39052 byte(s)
Merge of development branch back to main trunk on 2005-07-08

1 # $Id$
2
3 ## @file util.py
4
5 """
6 @brief Utility functions for escript
7
8 TODO for Data:
9
10 * binary operations @: (a@b)[:,*]=a[:]@b[:,*] if rank(a)<rank(b)
11 @=+,-,*,/,** (a@b)[:]=a[:]@b[:] if rank(a)=rank(b)
12 (a@b)[*,:]=a[*,:]@b[:] if rank(a)>rank(b)
13
14 * implementation of outer outer(a,b)[:,*]=a[:]*b[*]
15 * trace: trace(arg,axis0=a0,axis1=a1)(:,&,*)=sum_i trace(:,i,&,i,*) (i are at index a0 and a1)
16
17 """
18
19 import numarray
20 import escript
#
# escript constants (have to be consistent with utilC.h)
#
UNKNOWN=-1
EPSILON=1.e-15
Pi=numarray.pi
# some solver options:
NO_REORDERING=0
MINIMUM_FILL_IN=1
NESTED_DISSECTION=2
# solver methods
DEFAULT_METHOD=0
DIRECT=1
CHOLEVSKY=2
PCG=3
CR=4
CGS=5
BICGSTAB=6
SSOR=7
ILU0=8
ILUT=9
JACOBI=10
GMRES=11
PRES20=12

# option keys
METHOD_KEY="method"
SYMMETRY_KEY="symmetric"
TOLERANCE_KEY="tolerance"

# supported file formats:
VRML=1
PNG=2
JPEG=3
JPG=3       # alias of JPEG
PS=4
OOGL=5
BMP=6
TIFF=7
OPENINVENTOR=8
RENDERMAN=9
PNM=10
62 #===========================================================
# a simple tool box to deal with differentials of functions
64 #===========================================================
65
class Symbol:
    """symbol class: a node of a symbolic expression.

    A symbol has a name, a shape, a spatial dimension and a list of arguments
    (other symbols or plain values) it depends on. Subclasses override eval
    and _diff to define the operation they represent.
    """
    def __init__(self,name="symbol",shape=(),dim=3,args=None):
        """creates an instance of a symbol of shape shape and spatial dimension dim.

        The symbol may depend on a list of arguments args which may be symbols
        or other objects. name gives the name of the symbol. If dim has a
        getDim method, the value it returns is used as spatial dimension.
        """
        if args is None:
            # a fresh list per instance rather than a shared mutable default
            args=[]
        self.__args=args
        self.__name=name
        self.__shape=shape
        if hasattr(dim,"getDim"):
            self.__dim=dim.getDim()
        else:
            self.__dim=dim
        # cache of getEvaluatedArguments: a snapshot of the last argval and
        # the list of evaluated arguments computed from it
        self.__cache_val=None
        self.__cache_argval=None

    def getArgument(self,i):
        """returns the i-th argument"""
        return self.__args[i]

    def getDim(self):
        """returns the spatial dimension of the symbol"""
        return self.__dim

    def getRank(self):
        """returns the rank of the symbol"""
        return len(self.getShape())

    def getShape(self):
        """returns the shape of the symbol"""
        return self.__shape

    def __argvalMatchesCache(self,argval):
        """returns True if argval binds exactly the same keys to the very same
        objects as the argval used for the cached evaluation."""
        cached=self.__cache_argval
        if cached is None or len(cached)!=len(argval):
            return False
        for key,value in argval.items():
            # identity instead of ==: == on array-valued entries may not give
            # a single truth value
            if key not in cached or cached[key] is not value:
                return False
        return True

    def getEvaluatedArguments(self,argval):
        """returns the list of evaluated arguments by substituting symbol u by argval[u].

        The result is cached. The cache is reused only if argval binds the same
        keys to the very same objects as in the previous call. In-place
        modification of a bound object (e.g. an array) is not detected.
        """
        if self.__argvalMatchesCache(argval):
            return self.__cache_val
        out=[]
        for a in self.__args:
            if isinstance(a,Symbol):
                out.append(a.eval(argval))
            else:
                out.append(a)
        # keep a snapshot, not argval itself: if the caller modifies argval in
        # place between calls, stale values must not be returned
        self.__cache_argval=dict(argval)
        self.__cache_val=out
        return out

    def getDifferentiatedArguments(self,arg):
        """returns the list of the arguments differentiated by arg (0 for arguments which are not symbols)"""
        out=[]
        for a in self.__args:
            if isinstance(a,Symbol):
                out.append(a.diff(arg))
            else:
                out.append(0)
        return out

    def diff(self,arg):
        """returns the differential of self with respect to arg.

        If arg is self, the identity is returned: 1. for a scalar symbol,
        otherwise an array of shape self.getShape()+self.getShape() with
        out[i+i]=1. for every index tuple i. In all other cases _diff is used.
        """
        if self==arg:
            shape=tuple(self.getShape())
            if len(shape)==0:
                return 1.
            out=numarray.zeros(shape+shape,numarray.Float)
            # all index tuples of an array of the given shape (any rank)
            indices=[()]
            for n in shape:
                indices=[idx+(k,) for idx in indices for k in range(n)]
            for idx in indices:
                out[idx+idx]=1.
            return out
        else:
            return self._diff(arg)

    def _diff(self,arg):
        """returns the derivative of self with respect to arg (!=self).
        This method is overwritten by a particular symbol; the default is 0."""
        return 0

    def eval(self,argval):
        """substitutes symbol u in self by argval[u] and returns the result. If
        self is not a key of argval then self is returned."""
        if self in argval:
            return argval[self]
        else:
            return self

    def __str__(self):
        """returns a string representation of the symbol"""
        return self.__name

    def __add__(self,other):
        """adds other to symbol self. if _testForZero(other) self is returned."""
        if _testForZero(other):
            return self
        else:
            a=_matchShape([self,other])
            return Add_Symbol(a[0],a[1])

    def __radd__(self,other):
        """adds other to symbol self. if _testForZero(other) self is returned."""
        return self+other

    def __neg__(self):
        """returns -self."""
        return self*(-1.)

    def __pos__(self):
        """returns +self."""
        return self

    def __abs__(self):
        """returns absolute value"""
        return Abs_Symbol(self)

    def __sub__(self,other):
        """subtracts other from symbol self. if _testForZero(other) self is returned."""
        if _testForZero(other):
            return self
        else:
            return self+(-other)

    def __rsub__(self,other):
        """subtracts symbol self from other."""
        return -self+other

    def __div__(self,other):
        """divides symbol self by other."""
        if isinstance(other,Symbol):
            a=_matchShape([self,other])
            return Div_Symbol(a[0],a[1])
        else:
            return self*(1./other)

    def __rdiv__(self,other):
        """divides other by symbol self. if _testForZero(other) 0 is returned."""
        if _testForZero(other):
            return 0
        else:
            a=_matchShape([self,other])
            # a[0] is (the reshaped) self, a[1] is other: the result is other/self
            return Div_Symbol(a[1],a[0])

    # true division hooks (Python 3, or "from __future__ import division")
    __truediv__=__div__
    __rtruediv__=__rdiv__

    def __pow__(self,other):
        """raises symbol self to the power of other"""
        a=_matchShape([self,other])
        return Power_Symbol(a[0],a[1])

    def __rpow__(self,other):
        """raises other to the symbol self"""
        a=_matchShape([self,other])
        return Power_Symbol(a[1],a[0])

    def __mul__(self,other):
        """multiplies other by symbol self. if _testForZero(other) 0 is returned."""
        if _testForZero(other):
            return 0
        else:
            a=_matchShape([self,other])
            return Mult_Symbol(a[0],a[1])

    def __rmul__(self,other):
        """multiplies other by symbol self. if _testForZero(other) 0 is returned."""
        return self*other

    def __getitem__(self,sl):
        """slicing of symbols is not supported yet; raises NotImplementedError."""
        raise NotImplementedError("slicing of symbol %s is not supported yet"%str(self))
249
class Float_Symbol(Symbol):
    """a symbol with spatial dimension 0.

    This was previously declared with def instead of class, so it was never
    a Symbol subclass. The original __init__ also ignored name, shape and
    args; these parameters are now passed through.
    """
    def __init__(self,name="symbol",shape=(),args=None):
        if args is None:
            args=[]
        Symbol.__init__(self,dim=0,name=name,shape=shape,args=args)
253
class ScalarSymbol(Symbol):
    """a symbol of rank 0, i.e. of shape ()"""
    def __init__(self,dim=3,name="scalar"):
        """creates a scalar symbol of spatial dimension dim; dim may also be an
        object with a getDim method whose result is then used"""
        spatial_dim=dim
        if hasattr(spatial_dim,"getDim"):
            spatial_dim=spatial_dim.getDim()
        Symbol.__init__(self,name=name,dim=spatial_dim,shape=())
263
class VectorSymbol(Symbol):
    """a symbol of rank 1 whose single axis has the length of the spatial dimension"""
    def __init__(self,dim=3,name="vector"):
        """creates a vector symbol of spatial dimension dim; dim may also be an
        object with a getDim method whose result is then used"""
        spatial_dim=dim
        if hasattr(spatial_dim,"getDim"):
            spatial_dim=spatial_dim.getDim()
        Symbol.__init__(self,name=name,dim=spatial_dim,shape=(spatial_dim,))
273
class TensorSymbol(Symbol):
    """a symbol of rank 2 with both axes of the length of the spatial dimension"""
    def __init__(self,dim=3,name="tensor"):
        """creates a tensor symbol of spatial dimension dim; dim may also be an
        object with a getDim method whose result is then used"""
        spatial_dim=dim
        if hasattr(spatial_dim,"getDim"):
            spatial_dim=spatial_dim.getDim()
        Symbol.__init__(self,name=name,dim=spatial_dim,shape=(spatial_dim,)*2)
283
class Tensor3Symbol(Symbol):
    """a symbol of rank 3 with all axes of the length of the spatial dimension"""
    def __init__(self,dim=3,name="tensor3"):
        """creates a tensor order 3 symbol of spatial dimension dim; dim may also
        be an object with a getDim method whose result is then used"""
        spatial_dim=dim
        if hasattr(spatial_dim,"getDim"):
            spatial_dim=spatial_dim.getDim()
        Symbol.__init__(self,name=name,dim=spatial_dim,shape=(spatial_dim,)*3)
293
class Tensor4Symbol(Symbol):
    """a symbol of rank 4 with all axes of the length of the spatial dimension"""
    def __init__(self,dim=3,name="tensor4"):
        """creates a tensor order 4 symbol of spatial dimension dim; dim may also
        be an object with a getDim method whose result is then used"""
        spatial_dim=dim
        if hasattr(spatial_dim,"getDim"):
            spatial_dim=spatial_dim.getDim()
        Symbol.__init__(self,name=name,dim=spatial_dim,shape=(spatial_dim,)*4)
303
304
class Add_Symbol(Symbol):
    """symbol representing the sum arg0+arg1 of two arguments of common shape"""
    def __init__(self,arg0,arg1):
        operands=[arg0,arg1]
        Symbol.__init__(self,dim=_extractDim(operands),shape=_extractShape(operands),args=operands)
    def __str__(self):
        left=self.getArgument(0)
        right=self.getArgument(1)
        return "(%s+%s)"%(str(left),str(right))
    def eval(self,argval):
        left,right=self.getEvaluatedArguments(argval)
        return left+right
    def _diff(self,arg):
        # d(f+g) = df + dg
        dleft,dright=self.getDifferentiatedArguments(arg)
        return dleft+dright
318
class Mult_Symbol(Symbol):
    """symbol representing the product arg0*arg1 of two arguments of common shape"""
    def __init__(self,arg0,arg1):
        operands=[arg0,arg1]
        Symbol.__init__(self,dim=_extractDim(operands),shape=_extractShape(operands),args=operands)
    def __str__(self):
        left=self.getArgument(0)
        right=self.getArgument(1)
        return "(%s*%s)"%(str(left),str(right))
    def eval(self,argval):
        left,right=self.getEvaluatedArguments(argval)
        return left*right
    def _diff(self,arg):
        # product rule: d(f*g) = g*df + f*dg
        dleft,dright=self.getDifferentiatedArguments(arg)
        return self.getArgument(1)*dleft+self.getArgument(0)*dright
332
class Div_Symbol(Symbol):
    """symbol representing the quotient arg0/arg1 of two arguments of common shape"""
    def __init__(self,arg0,arg1):
        operands=[arg0,arg1]
        Symbol.__init__(self,dim=_extractDim(operands),shape=_extractShape(operands),args=operands)
    def __str__(self):
        numerator=self.getArgument(0)
        denominator=self.getArgument(1)
        return "(%s/%s)"%(str(numerator),str(denominator))
    def eval(self,argval):
        numerator,denominator=self.getEvaluatedArguments(argval)
        return numerator/denominator
    def _diff(self,arg):
        # quotient rule: d(f/g) = (df*g - f*dg)/(g*g)
        dnum,dden=self.getDifferentiatedArguments(arg)
        numerator=self.getArgument(0)
        denominator=self.getArgument(1)
        return (dnum*denominator-numerator*dden)/(denominator*denominator)
347
class Power_Symbol(Symbol):
    """symbol representing arg0**arg1, the first argument raised to the power of the second"""
    def __init__(self,arg0,arg1):
        operands=[arg0,arg1]
        Symbol.__init__(self,dim=_extractDim(operands),shape=_extractShape(operands),args=operands)
    def __str__(self):
        base=self.getArgument(0)
        exponent=self.getArgument(1)
        return "(%s**%s)"%(str(base),str(exponent))
    def eval(self,argval):
        base,exponent=self.getEvaluatedArguments(argval)
        return base**exponent
    def _diff(self,arg):
        # d(f**g) = f**g * (dg*log(f) + g/f*df)
        dbase,dexponent=self.getDifferentiatedArguments(arg)
        base=self.getArgument(0)
        exponent=self.getArgument(1)
        return self*(dexponent*log(base)+exponent/base*dbase)
361
class Abs_Symbol(Symbol):
    """symbol representing the absolute value abs(arg) of its argument"""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        operand=self.getArgument(0)
        return "abs(%s)"%str(operand)
    def eval(self,argval):
        evaluated=self.getEvaluatedArguments(argval)
        return abs(evaluated[0])
    def _diff(self,arg):
        # d|f| = sign(f)*df
        sign_of_operand=sign(self.getArgument(0))
        return sign_of_operand*self.getDifferentiatedArguments(arg)[0]
372
373 #=========================================================
374 # some little helpers
375 #=========================================================
376 def _testForZero(arg):
377 """returns True is arg is considered of being zero"""
378 if isinstance(arg,int):
379 return not arg>0
380 elif isinstance(arg,float):
381 return not arg>0.
382 elif isinstance(arg,numarray.NumArray):
383 a=abs(arg)
384 while isinstance(a,numarray.NumArray): a=numarray.sometrue(a)
385 return not a>0
386 else:
387 return False
388
389 def _extractDim(args):
390 dim=None
391 for a in args:
392 if hasattr(a,"getDim"):
393 d=a.getDim()
394 if dim==None:
395 dim=d
396 else:
397 if dim!=d: raise ValueError,"inconsistent spatial dimension of arguments"
398 if dim==None:
399 raise ValueError,"cannot recover spatial dimension"
400 return dim
401
402 def _identifyShape(arg):
403 """identifies the shape of arg."""
404 if hasattr(arg,"getShape"):
405 arg_shape=arg.getShape()
406 elif hasattr(arg,"shape"):
407 s=arg.shape
408 if callable(s):
409 arg_shape=s()
410 else:
411 arg_shape=s
412 else:
413 arg_shape=()
414 return arg_shape
415
def _extractShape(args):
    """extracts the common shape of the list of arguments args.

    Raises ValueError if two arguments have different shapes or if args is empty.
    """
    shape=None
    for a in args:
        a_shape=_identifyShape(a)
        if shape is None:
            shape=a_shape
        elif shape!=a_shape:
            raise ValueError("inconsistent shape")
    if shape is None:
        raise ValueError("cannot recover shape")
    return shape
426
def _matchShape(args,shape=None):
    """returns the list of arguments args as objects which all have the specified shape.

    If shape is not given, the shape with the most axes among args is used
    (the first one wins on ties). Only scalar arguments (shape ()) can be
    expanded to a non-scalar shape, via outer with an array of ones. Any
    other mismatch raises ValueError.
    """
    arg_shapes=[_identifyShape(a) for a in args]
    if shape is None:
        if len(arg_shapes)==0:
            return []
        # pick the longest shape explicitly; max() on tuples would compare
        # them lexicographically instead
        shape=arg_shapes[0]
        for s in arg_shapes[1:]:
            if len(s)>len(shape): shape=s
    out=[]
    for a,a_shape in zip(args,arg_shapes):
        if a_shape==shape:
            out.append(a)
        elif len(shape)>0 and len(a_shape)==0:
            out.append(outer(a,numarray.ones(shape)))
        else:
            raise ValueError("cannot adopt shape of %s to %s"%(str(a),str(shape)))
    return out
449 #=========================================================
450 # wrapper for various mathematical functions:
451 #=========================================================
def diff(arg,dep):
    """returns the derivative of arg with respect to dep.

    Symbols are differentiated symbolically. Objects with a shape (attribute
    or method) yield a zero Float array of that shape, and anything else
    yields 0.
    """
    if isinstance(arg,Symbol):
        return arg.diff(dep)
    if not hasattr(arg,"shape"):
        return 0
    shape=arg.shape
    if callable(shape):
        shape=shape()
    return numarray.zeros(shape,numarray.Float)
464
def exp(arg):
    """
    @brief applies the exponential function to arg
    @param arg (input): Symbol, object with an exp method, or a value accepted by numarray.exp
    """
    if isinstance(arg,Symbol):
        return Exp_Symbol(arg)
    if hasattr(arg,"exp"):
        return arg.exp()
    return numarray.exp(arg)
476
class Exp_Symbol(Symbol):
    """symbol representing the exponential exp(arg) of its argument"""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "exp(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return exp(self.getEvaluatedArguments(argval)[0])
    def _diff(self,arg):
        # d exp(f) = exp(f)*df
        return self*self.getDifferentiatedArguments(arg)[0]
487
def sqrt(arg):
    """
    @brief applies the square root function to arg
    @param arg (input): Symbol, object with a sqrt method, or a value accepted by numarray.sqrt
    """
    if isinstance(arg,Symbol):
        return Sqrt_Symbol(arg)
    if hasattr(arg,"sqrt"):
        return arg.sqrt()
    return numarray.sqrt(arg)
499
class Sqrt_Symbol(Symbol):
    """symbol representing the square root sqrt(arg) of its argument"""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "sqrt(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return sqrt(self.getEvaluatedArguments(argval)[0])
    def _diff(self,arg):
        # d sqrt(f) = df/(2*sqrt(f)); the factor is +0.5, not -0.5
        return 0.5/self*self.getDifferentiatedArguments(arg)[0]
510
def log(arg):
    """
    @brief applies the natural logarithm (base exp(1.)) to arg
    @param arg (input): Symbol, object with a log method, or a value accepted by numarray.log
    """
    if isinstance(arg,Symbol):
        return Log_Symbol(arg)
    if hasattr(arg,"log"):
        return arg.log()
    return numarray.log(arg)
522
class Log_Symbol(Symbol):
    """symbol representing the natural logarithm log(arg) of its argument"""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        operand=self.getArgument(0)
        return "log(%s)"%str(operand)
    def eval(self,argval):
        evaluated=self.getEvaluatedArguments(argval)
        return log(evaluated[0])
    def _diff(self,arg):
        # d log(f) = df/f
        dfirst=self.getDifferentiatedArguments(arg)[0]
        return dfirst/self.getArgument(0)
533
def sin(arg):
    """
    @brief applies the sine function to arg
    @param arg (input): Symbol, object with a sin method, or a value accepted by numarray.sin
    """
    if isinstance(arg,Symbol):
        return Sin_Symbol(arg)
    if hasattr(arg,"sin"):
        return arg.sin()
    return numarray.sin(arg)
545
class Sin_Symbol(Symbol):
    """symbol representing the sine sin(arg) of its argument"""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "sin(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return sin(self.getEvaluatedArguments(argval)[0])
    def _diff(self,arg):
        # d sin(f) = cos(f)*df
        return cos(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
556
def cos(arg):
    """
    @brief applies the cosine function to arg
    @param arg (input): Symbol, object with a cos method, or a value accepted by numarray.cos
    """
    if isinstance(arg,Symbol):
        return Cos_Symbol(arg)
    elif hasattr(arg,"cos"):
        return arg.cos()
    else:
        return numarray.cos(arg)
568
class Cos_Symbol(Symbol):
    """symbol representing the cosine cos(arg) of its argument"""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "cos(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return cos(self.getEvaluatedArguments(argval)[0])
    def _diff(self,arg):
        # d cos(f) = -sin(f)*df
        return -sin(self.getArgument(0))*self.getDifferentiatedArguments(arg)[0]
579
def tan(arg):
    """
    @brief applies the tangent function to arg
    @param arg (input): Symbol, object with a tan method, or a value accepted by numarray.tan
    """
    if isinstance(arg,Symbol):
        return Tan_Symbol(arg)
    elif hasattr(arg,"tan"):
        return arg.tan()
    else:
        return numarray.tan(arg)
591
class Tan_Symbol(Symbol):
    """symbol representing the tangent tan(arg) of its argument"""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "tan(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return tan(self.getEvaluatedArguments(argval)[0])
    def _diff(self,arg):
        # d tan(f) = df/cos(f)**2
        s=cos(self.getArgument(0))
        return 1./(s*s)*self.getDifferentiatedArguments(arg)[0]
603
def sign(arg):
    """
    @brief applies the sign function to arg: 1 where arg>0, -1 where arg<0 and 0 elsewhere
    @param arg (input): Symbol, object with a sign method, number or numarray array
    """
    if isinstance(arg,Symbol):
        return Sign_Symbol(arg)
    elif hasattr(arg,"sign"):
        return arg.sign()
    else:
        # compare against the scalar 0. (broadcast by numarray) rather than
        # zeros(arg.shape): plain Python numbers have no shape attribute
        return numarray.greater(arg,0.)-numarray.less(arg,0.)
616
class Sign_Symbol(Symbol):
    """symbol representing sign(arg). It keeps the default Symbol._diff, so
    its derivative with respect to any other symbol is 0."""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        operand=self.getArgument(0)
        return "sign(%s)"%str(operand)
    def eval(self,argval):
        evaluated=self.getEvaluatedArguments(argval)
        return sign(evaluated[0])
625
def maxval(arg):
    """
    @brief returns the maximum value of argument arg
    @param arg (input): Symbol, object with a maxval or max method, or a plain value (returned unchanged)
    """
    if isinstance(arg,Symbol):
        return Max_Symbol(arg)
    # prefer maxval over max, as in the escript interface
    for method_name in ("maxval","max"):
        if hasattr(arg,method_name):
            return getattr(arg,method_name)()
    return arg
639
class Max_Symbol(Symbol):
    """symbol representing maxval(arg), the maximum value of its argument"""
    def __init__(self,arg):
        # maxval reduces over the components of arg (for arrays it returns
        # arg.max()), so the symbol is a scalar, not of the shape of arg
        Symbol.__init__(self,shape=(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "maxval(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return maxval(self.getEvaluatedArguments(argval)[0])
648
def minval(arg):
    """
    @brief returns the minimum value of argument arg
    @param arg (input): Symbol, object with a minval or min method, or a plain value (returned unchanged)
    """
    if isinstance(arg,Symbol):
        return Min_Symbol(arg)
    elif hasattr(arg,"minval"):
        # previously tested for "maxval" and then called minval()
        return arg.minval()
    elif hasattr(arg,"min"):
        return arg.min()
    else:
        return arg
662
class Min_Symbol(Symbol):
    """symbol representing minval(arg), the minimum value of its argument"""
    def __init__(self,arg):
        # minval reduces over the components of arg (for arrays it returns
        # arg.min()), so the symbol is a scalar, not of the shape of arg
        Symbol.__init__(self,shape=(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "minval(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return minval(self.getEvaluatedArguments(argval)[0])
671
def wherePositive(arg):
    """
    @brief returns 1 where arg is positive and 0 elsewhere
    @param arg (input): Symbol, object with a wherePositive method, numarray array or number
    """
    if _testForZero(arg):
        return 0
    elif isinstance(arg,Symbol):
        return WherePositive_Symbol(arg)
    elif hasattr(arg,"wherePositive"):
        # previously called arg.minval() here
        return arg.wherePositive()
    elif hasattr(arg,"shape"):
        # previously unreachable (duplicated condition) and missing its return
        return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))
    else:
        if arg>0:
            return 1.
        else:
            return 0.
690
class WherePositive_Symbol(Symbol):
    """symbol representing wherePositive(arg): 1 where arg is positive, 0 elsewhere"""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        operand=self.getArgument(0)
        return "wherePositive(%s)"%str(operand)
    def eval(self,argval):
        evaluated=self.getEvaluatedArguments(argval)
        return wherePositive(evaluated[0])
699
def whereNegative(arg):
    """
    @brief returns 1 where arg is negative and 0 elsewhere
    @param arg (input): Symbol, object with a whereNegative method, numarray array or number
    """
    if _testForZero(arg):
        return 0
    elif isinstance(arg,Symbol):
        return WhereNegative_Symbol(arg)
    elif hasattr(arg,"whereNegative"):
        return arg.whereNegative()
    elif hasattr(arg,"shape"):
        # the result was previously computed but not returned
        return numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
    else:
        if arg<0:
            return 1.
        else:
            return 0.
718
class WhereNegative_Symbol(Symbol):
    """symbol representing whereNegative(arg): 1 where arg is negative, 0 elsewhere"""
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        operand=self.getArgument(0)
        return "whereNegative(%s)"%str(operand)
    def eval(self,argval):
        evaluated=self.getEvaluatedArguments(argval)
        return whereNegative(evaluated[0])
727
def outer(arg0,arg1):
    """
    @brief returns the outer product of arg0 and arg1, out[I,J]=arg0[I]*arg1[J]

    Returns 0 if one factor is zero and an Outer_Symbol if one factor is a
    Symbol. If one factor is a scalar, the plain product is returned. Two
    numarray arrays are combined with numarray.outer. For the remaining
    (escript.Data-like) objects only the rank 1 x rank 1 case is implemented.

    @param arg0, arg1 factors
    """
    if _testForZero(arg0) or _testForZero(arg1):
        return 0
    if isinstance(arg0,Symbol) or isinstance(arg1,Symbol):
        return Outer_Symbol(arg0,arg1)
    if _identifyShape(arg0)==() or _identifyShape(arg1)==():
        return arg0*arg1
    if isinstance(arg0,numarray.NumArray) and isinstance(arg1,numarray.NumArray):
        return numarray.outer(arg0,arg1)
    if arg0.getRank()==1 and arg1.getRank()==1:
        out=escript.Data(0,(arg0.getShape()[0],arg1.getShape()[0]),arg1.getFunctionSpace())
        for i in range(arg0.getShape()[0]):
            for j in range(arg1.getShape()[0]):
                out[i,j]=arg0[i]*arg1[j]
        return out
    raise ValueError("outer is not fully implemented yet.")
747
class Outer_Symbol(Symbol):
    """symbol representing outer(arg0,arg1); its shape is the shape of arg0
    followed by the shape of arg1"""
    def __init__(self,arg0,arg1):
        operands=[arg0,arg1]
        combined_shape=tuple(list(_identifyShape(arg0))+list(_identifyShape(arg1)))
        Symbol.__init__(self,shape=combined_shape,dim=_extractDim(operands),args=operands)
    def __str__(self):
        left=self.getArgument(0)
        right=self.getArgument(1)
        return "outer(%s,%s)"%(str(left),str(right))
    def eval(self,argval):
        left,right=self.getEvaluatedArguments(argval)
        return outer(left,right)
    def _diff(self,arg):
        # product rule applied to the outer product
        dleft,dright=self.getDifferentiatedArguments(arg)
        return outer(dleft,self.getArgument(1))+outer(self.getArgument(0),dright)
762
def interpolate(arg,where):
    """
    @brief interpolates the function into the FunctionSpace where.

    @param arg interpolant
    @param where FunctionSpace to interpolate to
    @return 0 if arg is zero, an Interpolated_Symbol if arg is a Symbol,
            escript.Data(arg,where) otherwise
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg,Symbol):
        return Interpolated_Symbol(arg,where)
    return escript.Data(arg,where)
776
class Interpolated_Symbol(Symbol):
    """symbol representing the interpolation of its argument into the FunctionSpace where"""
    def __init__(self,arg,where):
        # the shape of a single argument is _identifyShape(arg); _extractShape
        # expects a list of arguments
        Symbol.__init__(self,shape=_identifyShape(arg),dim=_extractDim([arg]),args=[arg,where])
    def __str__(self):
        return "interpolated(%s)"%(str(self.getArgument(0)))
    def eval(self,argval):
        a=self.getEvaluatedArguments(argval)
        return interpolate(a[0],where=self.getArgument(1))
    def _diff(self,arg):
        # interpolation is linear: the derivative of the interpolant is the
        # interpolation of the derivative
        a=self.getDifferentiatedArguments(arg)
        return interpolate(a[0],where=self.getArgument(1))
789
def grad(arg,where=None):
    """
    @brief returns the spatial gradient of arg at where.

    @param arg: Data object representing the function which gradient to be calculated.
    @param where: FunctionSpace in which the gradient will be calculated. If not present or
                  None an appropriate default is used.
    @return 0 if arg is zero, a Grad_Symbol if arg is a Symbol, arg.grad(...) if
            arg provides it, otherwise arg*0.
    """
    if _testForZero(arg):
        return 0
    elif isinstance(arg,Symbol):
        return Grad_Symbol(arg,where)
    elif hasattr(arg,"grad"):
        if where is None:
            return arg.grad()
        else:
            return arg.grad(where)
    else:
        # presumably arg is spatially constant here, so its gradient vanishes
        return arg*0.
809
class Grad_Symbol(Symbol):
    """symbol representing the gradient of the argument"""
    def __init__(self,arg,where=None):
        d=_extractDim([arg])
        # the gradient appends one axis of length d to the shape of arg
        # (list.append returns None, so the shape is built by concatenation)
        s=tuple(list(_identifyShape(arg))+[d])
        Symbol.__init__(self,shape=s,dim=d,args=[arg,where])
    def __str__(self):
        return "grad(%s)"%(str(self.getArgument(0)))
    def eval(self,argval):
        a=self.getEvaluatedArguments(argval)
        return grad(a[0],where=self.getArgument(1))
    def _diff(self,arg):
        a=self.getDifferentiatedArguments(arg)
        return grad(a[0],where=self.getArgument(1))
824
def integrate(arg,where=None):
    """
    @brief return the integral of the function represented by Data object arg over its domain.

    @param arg: Data object representing the function which is integrated.
    @param where: FunctionSpace in which the integral is calculated. If not present or
                  None an appropriate default is used.
    @return 0 if arg is zero, an Integral_Symbol if arg is a Symbol, the single
            integral value for a rank 0 argument, otherwise arg.integrate()
    """
    if _testForZero(arg):
        return 0
    elif isinstance(arg,Symbol):
        return Integral_Symbol(arg,where)
    else:
        if where is not None: arg=escript.Data(arg,where)
        if arg.getRank()==0:
            return arg.integrate()[0]
        else:
            return arg.integrate()
843
class Integral_Symbol(Symbol):
    """symbol representing the integral of the argument over its domain"""
    def __init__(self,arg,where=None):
        # spatial dimension 0, following the original Float_Symbol base: the
        # integral is a value, not a spatially varying function.
        # NOTE(review): this makes _extractDim reject combinations of an
        # integral with spatial symbols; confirm this is intended.
        Symbol.__init__(self,dim=0,shape=_identifyShape(arg),args=[arg,where])
    def __str__(self):
        return "integral(%s)"%(str(self.getArgument(0)))
    def eval(self,argval):
        a=self.getEvaluatedArguments(argval)
        return integrate(a[0],where=self.getArgument(1))
    def _diff(self,arg):
        a=self.getDifferentiatedArguments(arg)
        return integrate(a[0],where=self.getArgument(1))
856
857 #=============================
858 #
# wrapper for various functions: if the argument has a method with the function's
# name, that method is called. Otherwise the corresponding numarray
# function is called.
862 #
863 # functions involving the underlying Domain:
864 #
865
866
867 # functions returning Data objects:
868
def transpose(arg,axis=None):
    """
    @brief returns the transpose of arg.

    For a numarray object of rank r, the first axis indices are moved behind
    the remaining ones (axes permutation axis,...,r-1,0,...,axis-1). For a
    rank 2 object with axis=1 this is the usual matrix transpose.
    NOTE(review): the meaning of axis for rank>2 is inferred from the default
    rank/2; confirm against the intended escript semantics.

    @param arg Symbol, escript.Data (rank 2 only, axis is ignored) or numarray object
    @param axis number of leading axes moved to the end; defaults to rank/2
    """
    if axis is None:
        r=0
        if hasattr(arg,"getRank"): r=arg.getRank()
        if hasattr(arg,"rank"): r=arg.rank
        axis=r//2
    if isinstance(arg,Symbol):
        # pass axis (previously the rank, or an undefined name if axis was given)
        return Transpose_Symbol(arg,axis=axis)
    if isinstance(arg,escript.Data):
        # hack for transpose: only rank 2 Data objects are supported
        r=arg.getRank()
        if r!=2: raise ValueError("Tranpose only avalaible for rank 2 objects")
        s=arg.getShape()
        out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())
        for i in range(s[0]):
            for j in range(s[1]):
                out[j,i]=arg[i,j]
        return out
    # numarray.transpose expects a permutation of the axes, not an axis count
    r=len(_identifyShape(arg))
    perm=list(range(axis,r))+list(range(axis))
    return numarray.transpose(arg,perm)
896
def trace(arg,axis0=0,axis1=1):
    """
    @brief returns the trace of arg over the axes axis0 and axis1,
    i.e. the sum over i of the entries whose indices at axis0 and axis1 are both i.

    @param arg Symbol, escript.Data (rank 2 only) or numarray object
    @param axis0 first axis of the trace
    @param axis1 second axis of the trace
    """
    if isinstance(arg,Symbol):
        return Trace_Symbol(arg,axis0=axis0,axis1=axis1)
    elif isinstance(arg,escript.Data):
        # hack for trace: only rank 2 Data objects are supported
        s=arg.getShape()
        if len(s)!=2 or sorted((axis0,axis1))!=[0,1] or s[0]!=s[1]:
            raise ValueError("illegal axis in trace")
        out=escript.Scalar(0.,arg.getFunctionSpace())
        # sum the diagonal only (previously every entry was summed)
        for i in range(s[0]):
            out+=arg[i,i]
        return out
    else:
        # numarray.trace(array, offset, axis1, axis2): pass positionally,
        # it has no axis0/axis1 keywords
        return numarray.trace(arg,0,axis0,axis1)
921
922
923
class Trace_Symbol(Symbol):
    """symbol representing the trace of its argument over the axes axis0 and axis1"""
    def __init__(self,arg,axis0=0,axis1=1):
        s=list(arg.getShape())
        if not (0<=axis0<axis1<len(s)) or s[axis0]!=s[axis1]:
            raise ValueError("illegal axis in trace")
        # the two traced axes are removed from the shape of arg
        shape=tuple(s[:axis0]+s[axis0+1:axis1]+s[axis1+1:])
        Symbol.__init__(self,shape=shape,dim=arg.getDim(),args=[arg])
        self.__axis0=axis0
        self.__axis1=axis1
    def __str__(self):
        return "trace(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        a=self.getEvaluatedArguments(argval)
        return trace(a[0],axis0=self.__axis0,axis1=self.__axis1)
    def _diff(self,arg):
        # the trace is linear; the axes of the derivative of arg start with
        # the axes of arg, so axis0/axis1 still select the traced axes
        d=self.getDifferentiatedArguments(arg)[0]
        if _testForZero(d):
            return 0
        return trace(d,axis0=self.__axis0,axis1=self.__axis1)
926
def length(arg):
    """
    @brief returns the length of arg: the square root of the sum of the squares of its components

    @param arg escript.Data object of rank 0 to 4, number, or array with a sum method
    @return for an empty Data object an empty Data object; abs(arg) for rank 0 and numbers
    """
    if isinstance(arg,escript.Data):
        if arg.isEmpty(): return escript.Data()
        rank=arg.getRank()
        if rank==0:
            return abs(arg)
        shape=arg.getShape()
        total=escript.Scalar(0,arg.getFunctionSpace())
        if rank==1:
            for i in range(shape[0]):
                total+=arg[i]**2
        elif rank==2:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    total+=arg[i,j]**2
        elif rank==3:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    for k in range(shape[2]):
                        total+=arg[i,j,k]**2
        elif rank==4:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    for k in range(shape[2]):
                        for l in range(shape[3]):
                            total+=arg[i,j,k,l]**2
        else:
            raise SystemError("length is not been implemented yet")
        return sqrt(total)
    elif isinstance(arg,(int,float)):
        # plain numbers have no sum method
        return abs(arg)
    else:
        return sqrt((arg**2).sum())
968
def deviator(arg):
    """
    @brief returns the deviatoric part arg-trace(arg)/n*I of a square rank 2 object of shape (n,n)

    @param arg escript.Data object or object with a shape attribute (e.g. numarray array)
    """
    if isinstance(arg,escript.Data):
        shape=arg.getShape()
    else:
        shape=arg.shape
    if len(shape)!=2:
        raise ValueError("Deviator requires rank 2 object")
    if shape[0]!=shape[1]:
        raise ValueError("Deviator requires a square matrix")
    return arg-1./(shape[0]*1.)*trace(arg)*kronecker(shape[0])
984
def inner(arg0,arg1):
    """
    @brief returns the inner product of arg0 and arg1: the sum over all
    components of arg0*arg1 (both arguments are expected to have the same shape)

    @param arg0, arg1 escript.Data objects of rank 0 to 4, numarray arrays or numbers
    """
    if isinstance(arg1,escript.Data) and not isinstance(arg0,escript.Data):
        # the inner product is symmetric; let the Data object drive the loops
        arg0,arg1=arg1,arg0
    if not isinstance(arg0,escript.Data):
        prod=arg0*arg1
        if _identifyShape(prod)==():
            return prod
        return prod.sum()
    # the body previously referred to the undefined name "arg" instead of arg0
    rank=arg0.getRank()
    if rank==0:
        return arg0*arg1
    shape=arg0.getShape()
    total=escript.Scalar(0,arg0.getFunctionSpace())
    if rank==1:
        for i in range(shape[0]):
            total+=arg0[i]*arg1[i]
    elif rank==2:
        for i in range(shape[0]):
            for j in range(shape[1]):
                total+=arg0[i,j]*arg1[i,j]
    elif rank==3:
        for i in range(shape[0]):
            for j in range(shape[1]):
                for k in range(shape[2]):
                    total+=arg0[i,j,k]*arg1[i,j,k]
    elif rank==4:
        for i in range(shape[0]):
            for j in range(shape[1]):
                for k in range(shape[2]):
                    for l in range(shape[3]):
                        total+=arg0[i,j,k,l]*arg1[i,j,k,l]
    else:
        raise SystemError("inner is not been implemented yet")
    return total
1019
def matrixmult(arg0,arg1):
    """
    @brief returns the matrix product of arg0 and arg1.

    Two numarray arrays are multiplied with numarray.dot. If only one argument
    is an escript.Data object, the other one is converted into Data on its
    FunctionSpace. For Data, only the product of a rank 2 and a rank 1 object
    is implemented; otherwise SystemError is raised.
    """
    if isinstance(arg1,numarray.NumArray) and isinstance(arg0,numarray.NumArray):
        # previously numarray.matrixmult was called and its result dropped
        return numarray.dot(arg0,arg1)
    if isinstance(arg1,escript.Data) and not isinstance(arg0,escript.Data):
        arg0=escript.Data(arg0,arg1.getFunctionSpace())
    elif isinstance(arg0,escript.Data) and not isinstance(arg1,escript.Data):
        arg1=escript.Data(arg1,arg0.getFunctionSpace())
    if arg0.getRank()==2 and arg1.getRank()==1:
        out=escript.Data(0,(arg0.getShape()[0],),arg0.getFunctionSpace())
        for i in range(arg0.getShape()[0]):
            for j in range(arg0.getShape()[1]):
                out[i]+=arg0[i,j]*arg1[j]
        return out
    raise SystemError("matrixmult is not fully implemented yet!")
1038 #=========================================================
1039 # reduction operations:
1040 #=========================================================
def sum(arg):
    """
    @brief reduces arg by delegating to its own sum() method

    @param arg object providing a sum() method (e.g. escript.Data)
    @return whatever arg.sum() returns
    """
    reducer=arg.sum
    return reducer()
1048
def sup(arg):
    """
    @brief returns the maximum value of arg

    @param arg an escript.Data object (reduced with its sup() method), a
               Python float or int (returned unchanged), or any other object
               providing a max() method
    """
    if isinstance(arg,escript.Data):
        return arg.sup()
    if isinstance(arg,(float,int)):
        return arg
    return arg.max()
1061
def inf(arg):
    """
    @brief returns the minimum value of arg

    @param arg an escript.Data object (reduced with its inf() method), a
               Python float or int (returned unchanged), or any other object
               providing a min() method
    """
    if isinstance(arg,escript.Data):
        return arg.inf()
    if isinstance(arg,(float,int)):
        return arg
    return arg.min()
1074
def L2(arg):
    """
    @brief returns the L2-norm of arg

    @param arg an escript.Data object (uses its L2() method), a Python float
               or int (its absolute value is returned), or an array-like
               object of any rank
    @return for array-like arguments the square root of the sum of the
            squares of all entries
    """
    if isinstance(arg,escript.Data):
        return arg.L2()
    elif isinstance(arg,float) or isinstance(arg,int):
        return abs(arg)
    else:
        # flatten first: for rank>1 arrays dot(arg,arg) would be a matrix
        # product rather than the sum of squares of all entries
        flat=numarray.ravel(arg)
        return numarray.sqrt(dot(flat,flat))
1087
def Lsup(arg):
    """
    @brief returns the maximum absolute value of arg

    @param arg an escript.Data object (uses its Lsup() method), a Python float
               or int (its absolute value is returned), or an array-like object
    @return for array-like arguments the largest absolute value over all entries
    """
    if isinstance(arg,escript.Data):
        return arg.Lsup()
    elif isinstance(arg,float) or isinstance(arg,int):
        return abs(arg)
    else:
        # use the array's own max() (as sup() does): the builtin max() only
        # iterates over the first axis and breaks for arrays of rank > 1
        return numarray.abs(arg).max()
1100
def dot(arg0,arg1):
    """
    @brief returns the dot product of arg0 and arg1

    If arg0 is an escript.Data object arg0.dot(arg1) is returned; otherwise,
    if arg1 is an escript.Data object, arg1.dot(arg0) is returned; in all
    other cases numarray.dot(arg0,arg1) is used.

    @param arg0 first factor
    @param arg1 second factor
    """
    if isinstance(arg0,escript.Data):
        return arg0.dot(arg1)
    if isinstance(arg1,escript.Data):
        # NOTE(review): operands are swapped here; this only matches
        # arg0.dot(arg1) for symmetric contractions -- confirm for rank > 1
        return arg1.dot(arg0)
    return numarray.dot(arg0,arg1)
1113
def kronecker(d):
    """
    @brief returns the identity matrix (Kronecker delta) of dimension d

    @param d an integer dimension, or an object with a getDim() method whose
             result is used as the dimension
    @return numarray identity matrix of shape (dim,dim)
    """
    dim=d
    if hasattr(d,"getDim"):
        dim=d.getDim()
    return numarray.identity(dim)
1119
def unit(i,d):
    """
    @brief returns a unit vector of length d whose only nonzero entry is 1.0 at index i

    @param i index of the nonzero entry
    @param d dimension (length) of the vector
    @return numarray array of type numarray.Float
    """
    vec=numarray.zeros((d,),numarray.Float)
    vec[i]=1.0
    return vec
1129
1130 #
1131 # ============================================
1132 # testing
1133 # ============================================
1134
1135 if __name__=="__main__":
1136 u=ScalarSymbol(dim=2,name="u")
1137 v=ScalarSymbol(dim=2,name="v")
1138 v=VectorSymbol(2,"v")
1139 u=VectorSymbol(2,"u")
1140
1141
1142 print u+5,(u+5).diff(u)
1143 print 5+u,(5+u).diff(u)
1144 print u+v,(u+v).diff(u)
1145 print v+u,(v+u).diff(u)
1146
1147 print u*5,(u*5).diff(u)
1148 print 5*u,(5*u).diff(u)
1149 print u*v,(u*v).diff(u)
1150 print v*u,(v*u).diff(u)
1151
1152 print u-5,(u-5).diff(u)
1153 print 5-u,(5-u).diff(u)
1154 print u-v,(u-v).diff(u)
1155 print v-u,(v-u).diff(u)
1156
1157 print u/5,(u/5).diff(u)
1158 print 5/u,(5/u).diff(u)
1159 print u/v,(u/v).diff(u)
1160 print v/u,(v/u).diff(u)
1161
1162 print u**5,(u**5).diff(u)
1163 print 5**u,(5**u).diff(u)
1164 print u**v,(u**v).diff(u)
1165 print v**u,(v**u).diff(u)
1166
1167 print exp(u),exp(u).diff(u)
1168 print sqrt(u),sqrt(u).diff(u)
1169 print log(u),log(u).diff(u)
1170 print sin(u),sin(u).diff(u)
1171 print cos(u),cos(u).diff(u)
1172 print tan(u),tan(u).diff(u)
1173 print sign(u),sign(u).diff(u)
1174 print abs(u),abs(u).diff(u)
1175 print wherePositive(u),wherePositive(u).diff(u)
1176 print whereNegative(u),whereNegative(u).diff(u)
1177 print maxval(u),maxval(u).diff(u)
1178 print minval(u),minval(u).diff(u)
1179
1180 g=grad(u)
1181 print diff(5*g,g)
1182 4*(g+transpose(g))/2+6*trace(g)*kronecker(3)
1183 #
1184 # $Log$
1185 # Revision 1.11 2005/07/08 04:07:35 jgs
1186 # Merge of development branch back to main trunk on 2005-07-08
1187 #
1188 # Revision 1.10 2005/06/09 05:37:59 jgs
1189 # Merge of development branch back to main trunk on 2005-06-09
1190 #
1191 # Revision 1.2.2.17 2005/07/07 07:28:58 gross
1192 # some stuff added to util.py to improve functionality
1193 #
1194 # Revision 1.2.2.16 2005/06/30 01:53:55 gross
1195 # a bug in coloring fixed
1196 #
1197 # Revision 1.2.2.15 2005/06/29 02:36:43 gross
1198 # Symbols have been introduced and some function clarified. needs much more work
1199 #
1200 # Revision 1.2.2.14 2005/05/20 04:05:23 gross
1201 # some work on a darcy flow started
1202 #
1203 # Revision 1.2.2.13 2005/03/16 05:17:58 matt
1204 # Implemented unit(idx, dim) to create cartesian unit basis vectors to
1205 # complement kronecker(dim) function.
1206 #
1207 # Revision 1.2.2.12 2005/03/10 08:14:37 matt
1208 # Added non-member Linf utility function to complement Data::Linf().
1209 #
1210 # Revision 1.2.2.11 2005/02/17 05:53:25 gross
1211 # some bug in saveDX fixed: in fact the bug was in
1212 # DataC/getDataPointShape
1213 #
1214 # Revision 1.2.2.10 2005/01/11 04:59:36 gross
1215 # automatic interpolation in integrate switched off
1216 #
1217 # Revision 1.2.2.9 2005/01/11 03:38:13 gross
1218 # Bug in Data.integrate() fixed for the case of rank 0. The problem is not totallly resolved as the method should return a scalar rather than a numarray object in the case of rank 0. This problem is fixed by the util.integrate wrapper.
1219 #
1220 # Revision 1.2.2.8 2005/01/05 04:21:41 gross
1221 # FunctionSpace checking/matchig in slicing added
1222 #
1223 # Revision 1.2.2.7 2004/12/29 05:29:59 gross
1224 # AdvectivePDE successfully tested for Peclet number 1000000. there is still a problem with setValue and Data()
1225 #
1226 # Revision 1.2.2.6 2004/12/24 06:05:41 gross
1227 # some changes in linearPDEs to add AdevectivePDE
1228 #
1229 # Revision 1.2.2.5 2004/12/17 00:06:53 gross
1230 # mk sets ESYS_ROOT is undefined
1231 #
1232 # Revision 1.2.2.4 2004/12/07 03:19:51 gross
1233 # options for GMRES and PRES20 added
1234 #
1235 # Revision 1.2.2.3 2004/12/06 04:55:18 gross
1236 # function wraper extended
1237 #
1238 # Revision 1.2.2.2 2004/11/22 05:44:07 gross
1239 # a few more unitary functions have been added but not implemented in Data yet
1240 #
1241 # Revision 1.2.2.1 2004/11/12 06:58:15 gross
1242 # a lot of changes to get the linearPDE class running: most important change is that there is no matrix format exposed to the user anymore. the format is chosen by the Domain according to the solver and symmetry
1243 #
1244 # Revision 1.2 2004/10/27 00:23:36 jgs
1245 # fixed minor syntax error
1246 #
1247 # Revision 1.1.1.1 2004/10/26 06:53:56 jgs
1248 # initial import of project esys2
1249 #
1250 # Revision 1.1.2.3 2004/10/26 06:43:48 jgs
1251 # committing Lutz's and Paul's changes to brach jgs
1252 #
1253 # Revision 1.1.4.1 2004/10/20 05:32:51 cochrane
1254 # Added incomplete Doxygen comments to files, or merely put the docstrings that already exist into Doxygen form.
1255 #
1256 # Revision 1.1 2004/08/05 03:58:27 gross
1257 # Bug in Assemble_NodeCoordinates fixed
1258 #
1259 #

Properties

Name Value
svn:eol-style native
svn:keywords Author Date Id Revision

  ViewVC Help
Powered by ViewVC 1.1.26