/[escript]/trunk/esys2/escript/py_src/util.py
ViewVC logotype

Contents of /trunk/esys2/escript/py_src/util.py

Parent Directory | Revision Log


Revision 148 - (show annotations)
Tue Aug 23 01:24:31 2005 UTC (14 years, 1 month ago) by jgs
File MIME type: text/x-python
File size: 41409 byte(s)
Merge of development branch dev-02 back to main trunk on 2005-08-23

1 # $Id$
2
3 ## @file util.py
4
5 """
6 @brief Utility functions for escript
7
8 TODO for Data:
9
10 * binary operations @: (a@b)[:,*]=a[:]@b[:,*] if rank(a)<rank(b)
11 @=+,-,*,/,** (a@b)[:]=a[:]@b[:] if rank(a)=rank(b)
12 (a@b)[*,:]=a[*,:]@b[:] if rank(a)>rank(b)
13
14 * implementation of outer outer(a,b)[:,*]=a[:]*b[*]
15 * trace: trace(arg,axis0=a0,axis1=a1)(:,&,*)=sum_i arg(:,i,&,i,*) (the summation index i is placed at positions a0 and a1)
16 """
17
18 import numarray
19 import escript
20
21 #===========================================================
22 # a simple toolbox for handling differentials of functions
23 #===========================================================
24
class Symbol:
    """symbol class

    A Symbol is a named node of a small expression tree.  It has a shape, a
    spatial dimension and a list of arguments (symbols or other objects).
    The arithmetic operators combine symbols into composite symbols
    (Add_Symbol, Mult_Symbol, Div_Symbol, Power_Symbol, Abs_Symbol); eval
    substitutes values for symbols and diff builds derivatives.
    """

    def __init__(self, name="symbol", shape=(), dim=3, args=None):
        """creates an instance of a symbol of shape shape and spatial dimension dim

        The symbol may depend on a list of arguments args, which may be
        symbols or other objects (default: no arguments).  name gives the
        name of the symbol.  If dim provides getDim() its result is used.
        """
        if args is None:
            # a fresh list per instance instead of a shared mutable default
            args = []
        self.__args = args
        self.__name = name
        self.__shape = shape
        if hasattr(dim, "getDim"):
            self.__dim = dim.getDim()
        else:
            self.__dim = dim
        # cache of the most recent getEvaluatedArguments call
        self.__cache_val = None
        self.__cache_argval = None

    def getArgument(self, i):
        """returns the i-th argument"""
        return self.__args[i]

    def getDim(self):
        """returns the spatial dimension of the symbol"""
        return self.__dim

    def getRank(self):
        """returns the rank of the symbol"""
        return len(self.getShape())

    def getShape(self):
        """returns the shape of the symbol"""
        return self.__shape

    def getEvaluatedArguments(self, argval):
        """returns the list of evaluated arguments by substituting symbol u by argval[u].

        The result for the last argval is cached.  The cache keeps a snapshot
        of argval (same keys, identical value objects), so changing the
        dictionary between calls invalidates it.
        """
        if self.__isCachedFor(argval):
            return self.__cache_val
        out = []
        for a in self.__args:
            if isinstance(a, Symbol):
                out.append(a.eval(argval))
            else:
                out.append(a)
        # snapshot instead of a reference: the caller may mutate argval later
        self.__cache_argval = dict(argval)
        self.__cache_val = out
        return out

    def __isCachedFor(self, argval):
        """returns True if argval holds exactly the entries of the cached snapshot"""
        cached = self.__cache_argval
        if cached is None or len(cached) != len(argval):
            return False
        for key in argval:
            # identity test: avoids ambiguous truth values of array comparisons
            if key not in cached or cached[key] is not argval[key]:
                return False
        return True

    def getDifferentiatedArguments(self, arg):
        """returns the list of the arguments differentiated by arg

        Arguments which are not symbols are constants and yield 0.
        """
        out = []
        for a in self.__args:
            if isinstance(a, Symbol):
                out.append(a.diff(arg))
            else:
                out.append(0)
        return out

    def diff(self, arg):
        """returns the differentiation of self by arg.

        If arg is self, the identity is returned: 1. for a scalar symbol,
        otherwise an array of shape shape+shape which is 1. where both index
        halves coincide and 0. elsewhere.  Any rank is supported.
        """
        if self is not arg:
            return self._diff(arg)
        shape = self.getShape()
        if len(shape) == 0:
            return 1.
        out = numarray.zeros(tuple(2 * list(shape)), numarray.Float)
        # enumerate all multi-indices of shape in row-major order
        indices = [()]
        for n in shape:
            indices = [idx + (i,) for idx in indices for i in range(n)]
        for idx in indices:
            out[idx + idx] = 1.
        return out

    def _diff(self, arg):
        """returns the derivative of self with respect to arg (arg is not self).

        This method is overwritten by a particular symbol; the base version
        returns 0.
        """
        return 0

    def eval(self, argval):
        """substitutes symbol u in self by argval[u] and returns the result.

        If self is not a key of argval then self is returned.
        """
        if self in argval:
            return argval[self]
        return self

    def __str__(self):
        """returns a string representation of the symbol (its name)"""
        return self.__name

    def __add__(self, other):
        """adds other to symbol self. if _testForZero(other) self is returned."""
        if _testForZero(other):
            return self
        a = _matchShape([self, other])
        return Add_Symbol(a[0], a[1])

    def __radd__(self, other):
        """adds symbol self to other. if _testForZero(other) self is returned."""
        return self + other

    def __neg__(self):
        """returns -self."""
        return self * (-1.)

    def __pos__(self):
        """returns +self."""
        return self

    def __abs__(self):
        """returns absolute value"""
        return Abs_Symbol(self)

    def __sub__(self, other):
        """subtracts other from symbol self. if _testForZero(other) self is returned."""
        if _testForZero(other):
            return self
        return self + (-other)

    def __rsub__(self, other):
        """subtracts symbol self from other."""
        return -self + other

    def __div__(self, other):
        """divides symbol self by other."""
        if isinstance(other, Symbol):
            a = _matchShape([self, other])
            return Div_Symbol(a[0], a[1])
        return self * (1. / other)

    def __rdiv__(self, other):
        """divides other by symbol self. if _testForZero(other) 0 is returned."""
        if _testForZero(other):
            return 0
        a = _matchShape([self, other])
        # a[0] corresponds to self and a[1] to other: the quotient is other/self
        return Div_Symbol(a[1], a[0])

    # true division (from __future__ import division, Python 3) behaves the same
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(self, other):
        """raises symbol self to the power of other"""
        a = _matchShape([self, other])
        return Power_Symbol(a[0], a[1])

    def __rpow__(self, other):
        """raises other to the symbol self"""
        a = _matchShape([self, other])
        return Power_Symbol(a[1], a[0])

    def __mul__(self, other):
        """multiplies other by symbol self. if _testForZero(other) 0 is returned."""
        if _testForZero(other):
            return 0
        a = _matchShape([self, other])
        return Mult_Symbol(a[0], a[1])

    def __rmul__(self, other):
        """multiplies other by symbol self. if _testForZero(other) 0 is returned."""
        return self * other

    def __getitem__(self, sl):
        """slicing of symbols is not supported yet.

        Raising here (rather than returning None) also stops Python's legacy
        sequence protocol from treating a Symbol as an endless sequence.
        """
        raise NotImplementedError("slicing of symbol %s is not supported" % str(self))
206
class Float_Symbol(Symbol):
    """a symbol with spatial dimension 0 (no spatial dependence)"""

    def __init__(self, name="symbol", shape=(), args=None):
        """creates a symbol of shape shape depending on args, with spatial dimension 0

        The parameters are passed on to Symbol; previously they were ignored
        and the class was accidentally declared as a function.
        """
        if args is None:
            args = []
        Symbol.__init__(self, dim=0, name=name, shape=shape, args=args)
210
def _resolveDim(dim):
    """returns dim.getDim() if dim provides a getDim method, otherwise dim itself"""
    if hasattr(dim, "getDim"):
        return dim.getDim()
    return dim


class ScalarSymbol(Symbol):
    """a scalar symbol"""

    def __init__(self, dim=3, name="scalar"):
        """creates a scalar symbol of spatial dimension dim"""
        Symbol.__init__(self, shape=(), dim=_resolveDim(dim), name=name)


class VectorSymbol(Symbol):
    """a vector symbol"""

    def __init__(self, dim=3, name="vector"):
        """creates a vector symbol of spatial dimension dim (shape (dim,))"""
        d = _resolveDim(dim)
        Symbol.__init__(self, shape=(d,), dim=d, name=name)


class TensorSymbol(Symbol):
    """a tensor symbol"""

    def __init__(self, dim=3, name="tensor"):
        """creates a tensor symbol of spatial dimension dim (shape (dim,dim))"""
        d = _resolveDim(dim)
        Symbol.__init__(self, shape=(d, d), dim=d, name=name)


class Tensor3Symbol(Symbol):
    """a tensor order 3 symbol"""

    def __init__(self, dim=3, name="tensor3"):
        """creates a tensor order 3 symbol of spatial dimension dim"""
        d = _resolveDim(dim)
        Symbol.__init__(self, shape=(d, d, d), dim=d, name=name)


class Tensor4Symbol(Symbol):
    """a tensor order 4 symbol"""

    def __init__(self, dim=3, name="tensor4"):
        """creates a tensor order 4 symbol of spatial dimension dim"""
        d = _resolveDim(dim)
        Symbol.__init__(self, shape=(d, d, d, d), dim=d, name=name)
260
class Add_Symbol(Symbol):
    """symbol representing the sum of two arguments"""

    def __init__(self, arg0, arg1):
        operands = [arg0, arg1]
        Symbol.__init__(self, dim=_extractDim(operands),
                        shape=_extractShape(operands), args=operands)

    def __str__(self):
        return "(%s+%s)" % (str(self.getArgument(0)), str(self.getArgument(1)))

    def eval(self, argval):
        left, right = self.getEvaluatedArguments(argval)
        return left + right

    def _diff(self, arg):
        # the derivative of a sum is the sum of the derivatives
        dleft, dright = self.getDifferentiatedArguments(arg)
        return dleft + dright
274
class Mult_Symbol(Symbol):
    """symbol representing the product of two arguments"""

    def __init__(self, arg0, arg1):
        operands = [arg0, arg1]
        Symbol.__init__(self, dim=_extractDim(operands),
                        shape=_extractShape(operands), args=operands)

    def __str__(self):
        return "(%s*%s)" % (str(self.getArgument(0)), str(self.getArgument(1)))

    def eval(self, argval):
        left, right = self.getEvaluatedArguments(argval)
        return left * right

    def _diff(self, arg):
        # product rule: (u*v)' = v*u' + u*v'
        du, dv = self.getDifferentiatedArguments(arg)
        return self.getArgument(1) * du + self.getArgument(0) * dv
288
class Div_Symbol(Symbol):
    """symbol representing the quotient of two arguments"""

    def __init__(self, arg0, arg1):
        operands = [arg0, arg1]
        Symbol.__init__(self, dim=_extractDim(operands),
                        shape=_extractShape(operands), args=operands)

    def __str__(self):
        return "(%s/%s)" % (str(self.getArgument(0)), str(self.getArgument(1)))

    def eval(self, argval):
        numerator, denominator = self.getEvaluatedArguments(argval)
        return numerator / denominator

    def _diff(self, arg):
        # quotient rule: (u/v)' = (u'*v - u*v') / (v*v)
        du, dv = self.getDifferentiatedArguments(arg)
        u = self.getArgument(0)
        v = self.getArgument(1)
        return (du * v - u * dv) / (v * v)
303
class Power_Symbol(Symbol):
    """symbol representing the first argument raised to the power of the second argument"""

    def __init__(self, arg0, arg1):
        operands = [arg0, arg1]
        Symbol.__init__(self, dim=_extractDim(operands),
                        shape=_extractShape(operands), args=operands)

    def __str__(self):
        return "(%s**%s)" % (str(self.getArgument(0)), str(self.getArgument(1)))

    def eval(self, argval):
        base, exponent = self.getEvaluatedArguments(argval)
        return base ** exponent

    def _diff(self, arg):
        # (u**w)' = u**w * (w'*log(u) + w/u*u')
        dbase, dexponent = self.getDifferentiatedArguments(arg)
        base = self.getArgument(0)
        exponent = self.getArgument(1)
        return self * (dexponent * log(base) + exponent / base * dbase)
317
class Abs_Symbol(Symbol):
    """symbol representing the absolute value of its argument"""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "abs(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return abs(value)

    def _diff(self, arg):
        # d|u| = sign(u) * du
        direction = sign(self.getArgument(0))
        return direction * self.getDifferentiatedArguments(arg)[0]
328
329 #=========================================================
330 # some little helpers
331 #=========================================================
332 def _testForZero(arg):
333 """returns True is arg is considered of being zero"""
334 if isinstance(arg,int):
335 return not arg>0
336 elif isinstance(arg,float):
337 return not arg>0.
338 elif isinstance(arg,numarray.NumArray):
339 a=abs(arg)
340 while isinstance(a,numarray.NumArray): a=numarray.sometrue(a)
341 return not a>0
342 else:
343 return False
344
345 def _extractDim(args):
346 dim=None
347 for a in args:
348 if hasattr(a,"getDim"):
349 d=a.getDim()
350 if dim==None:
351 dim=d
352 else:
353 if dim!=d: raise ValueError,"inconsistent spatial dimension of arguments"
354 if dim==None:
355 raise ValueError,"cannot recover spatial dimension"
356 return dim
357
358 def _identifyShape(arg):
359 """identifies the shape of arg."""
360 if hasattr(arg,"getShape"):
361 arg_shape=arg.getShape()
362 elif hasattr(arg,"shape"):
363 s=arg.shape
364 if callable(s):
365 arg_shape=s()
366 else:
367 arg_shape=s
368 else:
369 arg_shape=()
370 return arg_shape
371
def _extractShape(args):
    """extracts the common shape of the list of arguments args

    @raise ValueError: if the arguments have different shapes or if args is empty
    """
    shape = None
    for a in args:
        current = _identifyShape(a)
        if shape is None:
            shape = current
        elif shape != current:
            raise ValueError("inconsistent shape")
    if shape is None:
        raise ValueError("cannot recover shape")
    return shape
382
def _matchShape(args, shape=None):
    """returns the list of arguments args as objects which all have the shape shape.

    If shape is not given, the maximum of the argument shapes (in tuple
    ordering) is used.  Arguments of shape () are expanded to shape via
    outer(arg, numarray.ones(shape)); any other mismatch raises ValueError.
    """
    argShapes = [_identifyShape(a) for a in args]
    if shape is None:
        shape = max(argShapes)
    out = []
    for a, s in zip(args, argShapes):
        if s == shape:
            out.append(a)
        elif len(shape) > 0 and len(s) == 0:
            # promote a scalar to the target shape
            out.append(outer(a, numarray.ones(shape)))
        else:
            raise ValueError("cannot adopt shape of %s to %s" % (str(a), str(shape)))
    return out
405
406 #=========================================================
407 # wrappers for various mathematical functions:
408 #=========================================================
def diff(arg, dep):
    """returns the derivative of arg with respect to dep.

    A Symbol is differentiated symbolically.  An object with a shape (an
    attribute or a method) yields a zero array of that shape; anything else
    yields 0.
    """
    if isinstance(arg, Symbol):
        return arg.diff(dep)
    if not hasattr(arg, "shape"):
        return 0
    s = arg.shape
    if callable(s):
        s = s()
    return numarray.zeros(s, numarray.Float)
421
def exp(arg):
    """
    @brief applies the exponential function to arg
    @param arg (input): argument; a Symbol yields an Exp_Symbol, an object
                        with an exp() method computes it itself, anything
                        else is passed to numarray.exp
    """
    if isinstance(arg, Symbol):
        return Exp_Symbol(arg)
    if hasattr(arg, "exp"):
        return arg.exp()
    return numarray.exp(arg)
433
class Exp_Symbol(Symbol):
    """symbol representing the exponential function of its argument"""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "exp(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return exp(value)

    def _diff(self, arg):
        # d exp(u) = exp(u) * du
        du = self.getDifferentiatedArguments(arg)[0]
        return self * du
444
def sqrt(arg):
    """
    @brief applies the square root function to arg
    @param arg (input): argument; a Symbol yields a Sqrt_Symbol, an object
                        with a sqrt() method computes it itself, anything
                        else is passed to numarray.sqrt
    """
    if isinstance(arg, Symbol):
        return Sqrt_Symbol(arg)
    if hasattr(arg, "sqrt"):
        return arg.sqrt()
    return numarray.sqrt(arg)
456
class Sqrt_Symbol(Symbol):
    """symbol representing square root of argument"""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "sqrt(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        return sqrt(self.getEvaluatedArguments(argval)[0])

    def _diff(self, arg):
        # d sqrt(u) = 0.5/sqrt(u) * du  (positive factor; self is sqrt(u)).
        # 0.5/self is evaluated through Symbol.__rdiv__.
        return 0.5 / self * self.getDifferentiatedArguments(arg)[0]
467
def log(arg):
    """
    @brief applies the logarithmic function with base exp(1.) to arg
    @param arg (input): argument; a Symbol yields a Log_Symbol, an object
                        with a log() method computes it itself, anything
                        else is passed to numarray.log
    """
    if isinstance(arg, Symbol):
        return Log_Symbol(arg)
    if hasattr(arg, "log"):
        return arg.log()
    return numarray.log(arg)
479
class Log_Symbol(Symbol):
    """symbol representing the logarithm of its argument"""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "log(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return log(value)

    def _diff(self, arg):
        # d log(u) = du / u
        du = self.getDifferentiatedArguments(arg)[0]
        return du / self.getArgument(0)
490
def ln(arg):
    """
    @brief applies the natural logarithmic function to arg
    @param arg (input): argument; an object with an ln() method computes it
                        itself, a Symbol yields an Ln_Symbol, anything else
                        is passed to numarray.log
    """
    if hasattr(arg, "ln"):
        # call the method that was tested for (previously arg.log() was called)
        return arg.ln()
    elif isinstance(arg, Symbol):
        return Ln_Symbol(arg)
    else:
        return numarray.log(arg)
502
class Ln_Symbol(Symbol):
    """symbol representing the natural logarithm of its argument"""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "ln(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return ln(value)

    def _diff(self, arg):
        # d ln(u) = du / u
        du = self.getDifferentiatedArguments(arg)[0]
        return du / self.getArgument(0)
513
def sin(arg):
    """
    @brief applies the sine function to arg
    @param arg (input): argument; a Symbol yields a Sin_Symbol, an object
                        with a sin() method computes it itself, anything
                        else is passed to numarray.sin
    """
    if isinstance(arg, Symbol):
        return Sin_Symbol(arg)
    if hasattr(arg, "sin"):
        return arg.sin()
    return numarray.sin(arg)
525
class Sin_Symbol(Symbol):
    """symbol representing the sine of its argument"""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "sin(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return sin(value)

    def _diff(self, arg):
        # d sin(u) = cos(u) * du
        slope = cos(self.getArgument(0))
        return slope * self.getDifferentiatedArguments(arg)[0]
536
def cos(arg):
    """
    @brief applies the cosine function to arg
    @param arg (input): argument; a Symbol yields a Cos_Symbol, an object
                        with a cos() method computes it itself, anything
                        else is passed to numarray.cos
    """
    if isinstance(arg, Symbol):
        return Cos_Symbol(arg)
    if hasattr(arg, "cos"):
        return arg.cos()
    return numarray.cos(arg)
548
class Cos_Symbol(Symbol):
    """symbol representing the cosine of its argument"""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "cos(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return cos(value)

    def _diff(self, arg):
        # d cos(u) = -sin(u) * du
        slope = -sin(self.getArgument(0))
        return slope * self.getDifferentiatedArguments(arg)[0]
559
def tan(arg):
    """
    @brief applies the tangent function to arg
    @param arg (input): argument; a Symbol yields a Tan_Symbol, an object
                        with a tan() method computes it itself, anything
                        else is passed to numarray.tan
    """
    if isinstance(arg, Symbol):
        return Tan_Symbol(arg)
    if hasattr(arg, "tan"):
        return arg.tan()
    return numarray.tan(arg)
571
class Tan_Symbol(Symbol):
    """symbol representing the tangent of its argument"""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "tan(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return tan(value)

    def _diff(self, arg):
        # d tan(u) = du / cos(u)**2
        c = cos(self.getArgument(0))
        return 1. / (c * c) * self.getDifferentiatedArguments(arg)[0]
583
def sign(arg):
    """
    @brief applies the sign function to arg
    @param arg (input): argument
    @return: 1., -1. or 0. for a Python int/float; a Sign_Symbol for a
             Symbol; arg.sign() for objects providing it; otherwise
             greater(arg,0)-less(arg,0) computed with numarray
    """
    if isinstance(arg, (int, float)):
        # plain numbers have no shape attribute, which the array branch needs
        if arg > 0:
            return 1.
        elif arg < 0:
            return -1.
        else:
            return 0.
    elif isinstance(arg, Symbol):
        return Sign_Symbol(arg)
    elif hasattr(arg, "sign"):
        return arg.sign()
    else:
        zero = numarray.zeros(arg.shape, numarray.Float)
        return numarray.greater(arg, zero) - numarray.less(arg, zero)
596
class Sign_Symbol(Symbol):
    """symbol representing the sign of its argument

    No _diff is defined, so the derivative is the base-class value 0.
    """

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "sign(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return sign(value)
605
def maxval(arg):
    """
    @brief returns the maximum value of argument arg
    @param arg (input): argument; a Symbol yields a Max_Symbol, otherwise the
                        first available of arg.maxval() and arg.max() is used
                        and arg itself is returned if neither exists
    """
    if isinstance(arg, Symbol):
        return Max_Symbol(arg)
    for method in ("maxval", "max"):
        if hasattr(arg, method):
            return getattr(arg, method)()
    return arg
619
class Max_Symbol(Symbol):
    """symbol representing the maximum value of the argument

    The maximum is a single value (maxval falls back to numarray's max(),
    which reduces over all entries), so the symbol has shape () whatever the
    shape of its argument.
    """

    def __init__(self, arg):
        Symbol.__init__(self, shape=(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "maxval(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        return maxval(self.getEvaluatedArguments(argval)[0])
628
def minval(arg):
    """
    @brief returns the minimum value of argument arg
    @param arg (input): argument; arg.minval() or arg.min() is used if
                        available, a Symbol yields a Min_Symbol, anything
                        else is returned unchanged
    """
    if hasattr(arg, "minval"):
        # test for the method that is actually called (was "maxval")
        return arg.minval()
    elif hasattr(arg, "min"):
        return arg.min()
    elif isinstance(arg, Symbol):
        return Min_Symbol(arg)
    else:
        return arg
642
class Min_Symbol(Symbol):
    """symbol representing the minimum value of the argument

    The minimum is a single value (minval falls back to numarray's min(),
    which reduces over all entries), so the symbol has shape () whatever the
    shape of its argument.
    """

    def __init__(self, arg):
        Symbol.__init__(self, shape=(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "minval(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        return minval(self.getEvaluatedArguments(argval)[0])
651
def wherePositive(arg):
    """
    @brief returns a mask which is 1 where argument arg is positive and 0 elsewhere
    @param arg (input): number, Symbol, object with a wherePositive method,
                        or numarray array
    """
    if isinstance(arg, (int, float)):
        # plain numbers are decided directly
        if arg > 0:
            return 1.
        else:
            return 0.
    elif _testForZero(arg):
        return 0
    elif isinstance(arg, Symbol):
        return WherePositive_Symbol(arg)
    elif hasattr(arg, "wherePositive"):
        # call the tested method (previously minval() was called here)
        return arg.wherePositive()
    elif hasattr(arg, "shape"):
        return numarray.greater(arg, numarray.zeros(arg.shape, numarray.Float))
    else:
        if arg > 0:
            return 1.
        else:
            return 0.
670
class WherePositive_Symbol(Symbol):
    """symbol representing the wherePositive mask of its argument"""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "wherePositive(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return wherePositive(value)
679
def whereNegative(arg):
    """
    @brief returns a mask which is 1 where argument arg is negative and 0 elsewhere
    @param arg (input): number, Symbol, object with a whereNegative method,
                        or numarray array
    """
    if isinstance(arg, (int, float)):
        # plain numbers are decided directly
        if arg < 0:
            return 1.
        else:
            return 0.
    elif _testForZero(arg):
        return 0
    elif isinstance(arg, Symbol):
        return WhereNegative_Symbol(arg)
    elif hasattr(arg, "whereNegative"):
        return arg.whereNegative()
    elif hasattr(arg, "shape"):
        # the element-wise mask must be returned, not just computed
        return numarray.less(arg, numarray.zeros(arg.shape, numarray.Float))
    else:
        if arg < 0:
            return 1.
        else:
            return 0.
698
class WhereNegative_Symbol(Symbol):
    """symbol representing the whereNegative mask of its argument"""

    def __init__(self, arg):
        Symbol.__init__(self, shape=arg.getShape(), dim=arg.getDim(), args=[arg])

    def __str__(self):
        return "whereNegative(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        value = self.getEvaluatedArguments(argval)[0]
        return whereNegative(value)
707
def maximum(arg0, arg1):
    """returns arg1 where arg1 is bigger than arg0, otherwise arg0"""
    # mask is 1 where arg0 < arg1
    mask = whereNegative(arg0 - arg1)
    return mask * arg1 + (1. - mask) * arg0
712
def minimum(arg0, arg1):
    """returns arg0 where arg1 is bigger than arg0, otherwise arg1"""
    # mask is 1 where arg0 < arg1
    mask = whereNegative(arg0 - arg1)
    return mask * arg0 + (1. - mask) * arg1
717
def outer(arg0, arg1):
    """returns the outer product of arg0 and arg1.

    Returns 0 if either argument is zero, an Outer_Symbol if either is a
    Symbol, the plain product if either has shape (), numarray.outer for two
    numarray arrays.  Otherwise the arguments are treated as escript Data
    objects, for which only the rank 1 x rank 1 case is implemented.
    """
    if _testForZero(arg0) or _testForZero(arg1):
        return 0
    if isinstance(arg0, Symbol) or isinstance(arg1, Symbol):
        return Outer_Symbol(arg0, arg1)
    if _identifyShape(arg0) == () or _identifyShape(arg1) == ():
        return arg0 * arg1
    if isinstance(arg0, numarray.NumArray) and isinstance(arg1, numarray.NumArray):
        return numarray.outer(arg0, arg1)
    if arg0.getRank() != 1 or arg1.getRank() != 1:
        raise ValueError("outer is not fully implemented yet.")
    n0 = arg0.getShape()[0]
    n1 = arg1.getShape()[0]
    out = escript.Data(0, (n0, n1), arg1.getFunctionSpace())
    for i in range(n0):
        for j in range(n1):
            out[i, j] = arg0[i] * arg1[j]
    return out
737
class Outer_Symbol(Symbol):
    """symbol representing the outer product of its two arguments

    Its shape is the shape of arg0 followed by the shape of arg1.
    """

    def __init__(self, arg0, arg1):
        operands = [arg0, arg1]
        combined = tuple(list(_identifyShape(arg0)) + list(_identifyShape(arg1)))
        Symbol.__init__(self, shape=combined, dim=_extractDim(operands), args=operands)

    def __str__(self):
        return "outer(%s,%s)" % (str(self.getArgument(0)), str(self.getArgument(1)))

    def eval(self, argval):
        left, right = self.getEvaluatedArguments(argval)
        return outer(left, right)

    def _diff(self, arg):
        # product rule applied to both factors
        dleft, dright = self.getDifferentiatedArguments(arg)
        return outer(dleft, self.getArgument(1)) + outer(self.getArgument(0), dright)
752
def interpolate(arg, where):
    """
    @brief interpolates the function into the FunctionSpace where.

    Returns 0 for a zero argument, an Interpolated_Symbol for a Symbol and
    escript.Data(arg,where) otherwise.

    @param arg interpolant
    @param where FunctionSpace to interpolate to
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg, Symbol):
        return Interpolated_Symbol(arg, where)
    return escript.Data(arg, where)
766
class Interpolated_Symbol(Symbol):
    """symbol representing the interpolation of the argument into a FunctionSpace

    The arguments are [arg, where]; the symbol has the shape and spatial
    dimension of arg.
    """

    def __init__(self, arg, where):
        # _identifyShape handles a single argument (_extractShape expects a list)
        Symbol.__init__(self, shape=_identifyShape(arg), dim=_extractDim([arg]), args=[arg, where])

    def __str__(self):
        return "interpolated(%s)" % (str(self.getArgument(0)))

    def eval(self, argval):
        a = self.getEvaluatedArguments(argval)
        # interpolate (not integrate) the evaluated argument
        return interpolate(a[0], where=self.getArgument(1))

    def _diff(self, arg):
        # the derivative of the interpolation is the interpolated derivative
        a = self.getDifferentiatedArguments(arg)
        return interpolate(a[0], where=self.getArgument(1))
779
def div(arg, where=None):
    """
    @brief returns the divergence of arg at where, i.e. trace(grad(arg,where)).

    @param arg: Data object representing the function whose divergence is calculated.
    @param where: FunctionSpace in which the gradient will be calculated. If not present or
                  None an appropriate default is used.
    """
    gradient = grad(arg, where)
    return trace(gradient)
789
def grad(arg, where=None):
    """
    @brief returns the spatial gradient of arg at where.

    Returns 0 for a zero argument, a Grad_Symbol for a Symbol, arg.grad(...)
    for objects providing grad, and arg*0. for anything else.

    @param arg: Data object representing the function whose gradient is calculated.
    @param where: FunctionSpace in which the gradient will be calculated. If not present or
                  None an appropriate default is used.
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg, Symbol):
        return Grad_Symbol(arg, where)
    if not hasattr(arg, "grad"):
        return arg * 0.
    if where == None:
        return arg.grad()
    return arg.grad(where)
809
class Grad_Symbol(Symbol):
    """symbol representing the gradient of the argument

    The arguments are [arg, where].  The shape is the shape of arg extended
    by one axis of length dim, the spatial dimension of arg.
    """

    def __init__(self, arg, where=None):
        d = _extractDim([arg])
        # list.append returns None, so build the extended shape by concatenation
        s = tuple(list(_identifyShape(arg)) + [d])
        Symbol.__init__(self, shape=s, dim=d, args=[arg, where])

    def __str__(self):
        return "grad(%s)" % (str(self.getArgument(0)))

    def eval(self, argval):
        a = self.getEvaluatedArguments(argval)
        return grad(a[0], where=self.getArgument(1))

    def _diff(self, arg):
        a = self.getDifferentiatedArguments(arg)
        return grad(a[0], where=self.getArgument(1))
824
def integrate(arg, where=None):
    """
    @brief returns the integral of the function represented by Data object arg over its domain.

    Returns 0 for a zero argument and an Integral_Symbol for a Symbol.  For a
    rank 0 argument the single entry of arg.integrate() is returned.

    @param arg: Data object representing the function which is integrated.
    @param where: FunctionSpace in which the integral is calculated. If not present or
                  None an appropriate default is used.
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg, Symbol):
        return Integral_Symbol(arg, where)
    if not where == None:
        arg = escript.Data(arg, where)
    if arg.getRank() == 0:
        return arg.integrate()[0]
    return arg.integrate()
843
class Integral_Symbol(Symbol):
    """symbol representing the integral of the argument

    The arguments are [arg, where]; the shape is the shape of arg.  As with
    Float_Symbol, the spatial dimension is 0.
    NOTE(review): _extractDim rejects mixing dimension 0 with spatial symbols
    (e.g. u*integrate(u)) -- confirm that this is intended.
    """

    def __init__(self, arg, where=None):
        # _identifyShape must be given arg itself, not a list containing it
        Symbol.__init__(self, shape=_identifyShape(arg), dim=0, args=[arg, where])

    def __str__(self):
        return "integral(%s)" % (str(self.getArgument(0)))

    def eval(self, argval):
        a = self.getEvaluatedArguments(argval)
        return integrate(a[0], where=self.getArgument(1))

    def _diff(self, arg):
        a = self.getDifferentiatedArguments(arg)
        return integrate(a[0], where=self.getArgument(1))
856
857 #=============================
858 #
859 # wrappers for various functions: if the argument has a method with the same name
860 # as the function, that method is called. Otherwise the corresponding
861 # numarray function is called.
862
863 # functions involving the underlying Domain:
864
865
866 # functions returning Data objects:
867
def transpose(arg, axis=None):
    """
    @brief returns the transpose of arg.

    @param arg: Symbol, escript Data object (rank 2 only) or numarray array
    @param axis: passed on to the transpose; defaults to half the rank of arg
                 (integer division)
    """
    if axis is None:
        r = 0
        if hasattr(arg, "getRank"):
            r = arg.getRank()
        if hasattr(arg, "rank"):
            r = arg.rank
        axis = r // 2
    if isinstance(arg, Symbol):
        # pass the requested axis (r is unbound when axis was given)
        # NOTE(review): Transpose_Symbol is not defined in the visible part of the file
        return Transpose_Symbol(arg, axis=axis)
    if isinstance(arg, escript.Data):
        # hack for transpose: only rank 2 Data objects are supported
        r = arg.getRank()
        if r != 2:
            raise ValueError("Tranpose only avalaible for rank 2 objects")
        s = arg.getShape()
        out = escript.Data(0., (s[1], s[0]), arg.getFunctionSpace())
        for i in range(s[0]):
            for j in range(s[1]):
                out[j, i] = arg[i, j]
        return out
    return numarray.transpose(arg, axis=axis)
895
def trace(arg, axis0=0, axis1=1):
    """
    @brief returns the trace of arg over the axes axis0 and axis1

    @param arg: Symbol, escript Data object (rank 2 only) or numarray array
    @param axis0: first axis of the trace
    @param axis1: second axis of the trace
    @raise ValueError: for Data objects of rank other than 2 or with axes of
                       different length
    """
    if isinstance(arg, Symbol):
        return Trace_Symbol(arg, axis0=axis0, axis1=axis1)
    elif isinstance(arg, escript.Data):
        # hack for trace: out += arg[i,i] is only meaningful for rank 2
        s = arg.getShape()
        if len(s) != 2:
            raise ValueError("trace is only implemented for rank 2 Data objects")
        if s[axis0] != s[axis1]:
            raise ValueError("illegal axis in trace")
        out = escript.Scalar(0., arg.getFunctionSpace())
        for i in range(s[axis0]):
            out += arg[i, i]
        return out
    else:
        # positional (array, offset, first axis, second axis): numarray names
        # these axis parameters axis1/axis2, so axis0= is not accepted
        return numarray.trace(arg, 0, axis0, axis1)
918
class Trace_Symbol(Symbol):
    """symbol representing the trace of its argument over the axes axis0 and axis1

    The arguments are [arg, axis0, axis1].  The two traced axes are removed
    from the shape of arg (assumes axis0 < axis1, as in trace()).
    """

    def __init__(self, arg, axis0=0, axis1=1):
        s = list(arg.getShape())
        shape = tuple(s[0:axis0] + s[axis0 + 1:axis1] + s[axis1 + 1:])
        Symbol.__init__(self, shape=shape, dim=arg.getDim(), args=[arg, axis0, axis1])

    def __str__(self):
        return "trace(%s)" % str(self.getArgument(0))

    def eval(self, argval):
        a = self.getEvaluatedArguments(argval)
        return trace(a[0], axis0=a[1], axis1=a[2])

    def _diff(self, arg):
        d = self.getDifferentiatedArguments(arg)[0]
        if _testForZero(d):
            return 0
        # the derivative keeps the axes of the argument in front, so the
        # same axes are traced
        return trace(d, axis0=self.getArgument(1), axis1=self.getArgument(2))
921
def length(arg):
    """
    @brief returns the length (Euclidean norm) of arg

    @param arg: a Python number, an escript Data object, or an object
                supporting arg**2 and .sum() (e.g. a numarray array)
    @return: abs(arg) for numbers and rank 0 Data; escript.Data() for empty
             Data; for Data of rank 1 to 4 the sqrt of the sum of squares of
             all components, accumulated in a Scalar on arg's FunctionSpace;
             sqrt((arg**2).sum()) otherwise
    @raise SystemError: for Data objects of rank larger than 4
    """
    if isinstance(arg, (int, float)):
        # ints as well as floats: neither provides .sum()
        return abs(arg)
    if not isinstance(arg, escript.Data):
        return sqrt((arg ** 2).sum())
    if arg.isEmpty():
        return escript.Data()
    rank = arg.getRank()
    if rank == 0:
        return abs(arg)
    if rank > 4:
        raise SystemError("length is not been fully implemented yet")
    # all component indices in row-major order
    indices = [()]
    for n in arg.getShape():
        indices = [idx + (i,) for idx in indices for i in range(n)]
    out = escript.Scalar(0, arg.getFunctionSpace())
    for idx in indices:
        if rank == 1:
            component = arg[idx[0]]
        else:
            component = arg[idx]
        out += component ** 2
    return sqrt(out)
965
def deviator(arg):
    """
    @brief returns arg - trace(arg)/n * kronecker(n) for a square rank 2 object arg with n rows

    @param arg: escript Data object or object with a shape attribute
    @raise ValueError: if arg is not of rank 2 or not square
    """
    if isinstance(arg, escript.Data):
        shape = arg.getShape()
    else:
        shape = arg.shape
    if len(shape) != 2:
        raise ValueError("Deviator requires rank 2 object")
    if shape[0] != shape[1]:
        raise ValueError("Deviator requires a square matrix")
    n = shape[0]
    # kronecker is not defined in this part of the file; presumably the n x n identity
    return arg - 1. / (n * 1.) * trace(arg) * kronecker(n)
981
def inner(arg0, arg1):
    """
    @brief returns the inner product of arg0 and arg1, i.e. the sum over all
           components of the component-wise product

    @param arg0, arg1: arguments of equal shape.  If at least one is an escript
                       Data object the result is a Scalar on its FunctionSpace
                       (ranks up to 4); otherwise (arg0*arg1).sum() is returned
                       when available, else arg0*arg1.
    @raise SystemError: for Data objects of rank larger than 4
    """
    if isinstance(arg0, escript.Data):
        arg = arg0
    elif isinstance(arg1, escript.Data):
        arg = arg1
    else:
        # no Data object involved (e.g. numbers or numarray arrays)
        product = arg0 * arg1
        if hasattr(product, "sum"):
            return product.sum()
        return product
    rank = arg.getRank()
    if rank == 0:
        return arg0 * arg1
    if rank > 4:
        raise SystemError("inner is not been implemented yet")
    # all component indices in row-major order
    indices = [()]
    for n in arg.getShape():
        indices = [idx + (i,) for idx in indices for i in range(n)]
    out = escript.Scalar(0, arg.getFunctionSpace())
    for idx in indices:
        if rank == 1:
            key = idx[0]
        else:
            key = idx
        out += arg0[key] * arg1[key]
    return out
1021
def matrixmult(arg0, arg1):
    """
    @brief returns the matrix product of arg0 and arg1

    Two numarray arrays are multiplied with numarray.dot.  Otherwise a
    non-Data argument is converted to an escript Data object on the
    FunctionSpace of the other argument; matrix x vector and vector x vector
    (inner product) are supported.

    @raise SystemError: for other rank combinations
    """
    if isinstance(arg1, numarray.NumArray) and isinstance(arg0, numarray.NumArray):
        # the product has to be returned; numarray.dot is the matrix product
        return numarray.dot(arg0, arg1)
    if isinstance(arg1, escript.Data) and not isinstance(arg0, escript.Data):
        arg0 = escript.Data(arg0, arg1.getFunctionSpace())
    elif isinstance(arg0, escript.Data) and not isinstance(arg1, escript.Data):
        arg1 = escript.Data(arg1, arg0.getFunctionSpace())
    if arg0.getRank() == 2 and arg1.getRank() == 1:
        out = escript.Data(0, (arg0.getShape()[0],), arg0.getFunctionSpace())
        for i in range(arg0.getShape()[0]):
            for j in range(arg0.getShape()[1]):
                # uses Data object slicing, plus Data * and += operators
                out[i] += arg0[i, j] * arg1[j]
        return out
    elif arg0.getRank() == 1 and arg1.getRank() == 1:
        return inner(arg0, arg1)
    else:
        raise SystemError("matrixmult is not fully implemented yet!")
1043
1044 #=========================================================
1045 # reduction operations:
1046 #=========================================================
def sum(arg):
    """
    @brief reduces arg by calling its own sum() method

    No type checking is done here. arg must provide a sum() method;
    otherwise an AttributeError is raised.

    @param arg object providing a sum() method
    @return the value returned by arg.sum()
    """
    reducer=arg.sum
    return reducer()
1054
def sup(arg):
    """
    @brief returns the supremum of arg

    @param arg escript.Data object, Python float/int, or an array-like
               object with a max() method
    @return arg.sup() for escript.Data, arg itself for a float/int,
            otherwise arg.max()
    """
    if isinstance(arg,escript.Data):
        return arg.sup()
    if isinstance(arg,(float,int)):
        # a plain number is its own supremum
        return arg
    return arg.max()
1067
def inf(arg):
    """
    @brief returns the infimum of arg

    @param arg escript.Data object, Python float/int, or an array-like
               object with a min() method
    @return arg.inf() for escript.Data, arg itself for a float/int,
            otherwise arg.min()
    """
    if isinstance(arg,escript.Data):
        return arg.inf()
    if isinstance(arg,(float,int)):
        # a plain number is its own infimum
        return arg
    return arg.min()
1080
def L2(arg):
    """
    @brief returns the L2-norm of arg

    For an escript.Data object this is arg.L2(). A Python float or int
    gives its absolute value. Any other argument (e.g. a numarray vector)
    gives sqrt(dot(arg,arg)), which is the Euclidean norm when arg is a
    vector.

    @param arg escript.Data object, Python float/int, or array
    @return the L2-norm of arg
    """
    if isinstance(arg,escript.Data):
        return arg.L2()
    elif isinstance(arg,float) or isinstance(arg,int):
        return abs(arg)
    else:
        # module name was misspelled ("numarry"), raising NameError here
        return numarray.sqrt(dot(arg,arg))
1093
def Lsup(arg):
    """
    @brief returns the L-sup norm of arg

    @param arg escript.Data object, Python float/int, or array
    @return arg.Lsup() for escript.Data, abs(arg) for a float/int, and
            otherwise the largest absolute value over all components of arg
    """
    if isinstance(arg,escript.Data):
        return arg.Lsup()
    elif isinstance(arg,float) or isinstance(arg,int):
        return abs(arg)
    else:
        # Use the array's own max() (as sup() does) instead of the builtin
        # max(): the builtin iterates over the first axis only, so for arrays
        # of rank >= 2 it would compare whole rows rather than components.
        return numarray.abs(arg).max()
1106
def dot(arg0,arg1):
    """
    @brief returns the dot product of arg0 and arg1

    If either argument is an escript.Data object (arg0 is checked first),
    the work is delegated to that object's dot() method with the other
    argument. Otherwise numarray.dot(arg0,arg1) is used.

    @param arg0 first operand
    @param arg1 second operand
    @return the dot product
    """
    if isinstance(arg0,escript.Data):
        owner,other=arg0,arg1
    elif isinstance(arg1,escript.Data):
        # NOTE(review): the operands are swapped here (arg1.dot(arg0)); this
        # only equals arg0.dot(arg1) if Data.dot is symmetric for the ranks
        # involved -- confirm for e.g. matrix-vector products.
        owner,other=arg1,arg0
    else:
        return numarray.dot(arg0,arg1)
    return owner.dot(other)
1119
def kronecker(d):
    """
    @brief returns the d x d identity matrix (Kronecker delta)

    @param d the dimension, or an object providing getDim() (e.g. a Symbol)
             whose dimension is used
    @return numarray.identity of that dimension multiplied by 1. (presumably
            to promote the result to a floating point type)
    """
    if hasattr(d,"getDim"):
        d=d.getDim()
    return numarray.identity(d)*1.
1125
def unit(i,d):
    """
    @brief returns the unit vector of dimension d whose i-th component is 1
    @param i index of the nonzero component
    @param d dimension (length of the vector)
    @return numarray Float vector of length d, zero except for entry i = 1.0
    """
    basis_vector=numarray.zeros((d,),numarray.Float)
    basis_vector[i]=1.0
    return basis_vector
1135
1136 # ============================================
1137 # testing
1138 # ============================================
1139
if __name__=="__main__":
    # Ad-hoc smoke test of the symbol toolbox: each print shows an
    # expression together with the result of its .diff(u) call (presumably
    # the derivative with respect to u). Uses Python 2 print statements.
    u=ScalarSymbol(dim=2,name="u")
    v=ScalarSymbol(dim=2,name="v")
    # NOTE(review): the two scalar symbols above are immediately replaced by
    # the vector symbols below, so only VectorSymbol instances are exercised.
    v=VectorSymbol(2,"v")
    u=VectorSymbol(2,"u")

    # addition with a constant and with another symbol, both operand orders
    print u+5,(u+5).diff(u)
    print 5+u,(5+u).diff(u)
    print u+v,(u+v).diff(u)
    print v+u,(v+u).diff(u)

    # multiplication
    print u*5,(u*5).diff(u)
    print 5*u,(5*u).diff(u)
    print u*v,(u*v).diff(u)
    print v*u,(v*u).diff(u)

    # subtraction
    print u-5,(u-5).diff(u)
    print 5-u,(5-u).diff(u)
    print u-v,(u-v).diff(u)
    print v-u,(v-u).diff(u)

    # division
    print u/5,(u/5).diff(u)
    print 5/u,(5/u).diff(u)
    print u/v,(u/v).diff(u)
    print v/u,(v/u).diff(u)

    # power
    print u**5,(u**5).diff(u)
    print 5**u,(5**u).diff(u)
    print u**v,(u**v).diff(u)
    print v**u,(v**u).diff(u)

    # unary functions of the module applied to a symbol
    print exp(u),exp(u).diff(u)
    print sqrt(u),sqrt(u).diff(u)
    print log(u),log(u).diff(u)
    print sin(u),sin(u).diff(u)
    print cos(u),cos(u).diff(u)
    print tan(u),tan(u).diff(u)
    print sign(u),sign(u).diff(u)
    print abs(u),abs(u).diff(u)
    print wherePositive(u),wherePositive(u).diff(u)
    print whereNegative(u),whereNegative(u).diff(u)
    print maxval(u),maxval(u).diff(u)
    print minval(u),minval(u).diff(u)

    g=grad(u)
    print diff(5*g,g)
    # NOTE(review): the value of the next expression is neither printed nor
    # assigned, so it only checks that the expression can be built. Also,
    # kronecker(3) is 3x3 while the symbols were created with dimension 2 --
    # confirm the intended dimension.
    4*(g+transpose(g))/2+6*trace(g)*kronecker(3)
1187
1188 #
1189 # $Log$
1190 # Revision 1.16 2005/08/23 01:24:28 jgs
1191 # Merge of development branch dev-02 back to main trunk on 2005-08-23
1192 #
1193 # Revision 1.15 2005/08/12 01:45:36 jgs
# Merge of development branch dev-02 back to main trunk on 2005-08-12
1195 #
1196 # Revision 1.14.2.4 2005/08/18 04:39:32 gross
# the constants have been removed from util.py as they are not needed anymore. PDE-related constants are now accessed through LinearPDE attributes
1198 #
1199 # Revision 1.14.2.3 2005/08/03 09:55:33 gross
1200 # ContactTest is passing now./mk install!
1201 #
1202 # Revision 1.14.2.2 2005/08/02 03:15:14 gross
# bug in trace fixed!
1204 #
1205 # Revision 1.14.2.1 2005/07/29 07:10:28 gross
1206 # new functions in util and a new pde type in linearPDEs
1207 #
1208 # Revision 1.2.2.21 2005/07/28 04:19:23 gross
1209 # new functions maximum and minimum introduced.
1210 #
1211 # Revision 1.2.2.20 2005/07/25 01:26:27 gross
1212 # bug in inner fixed
1213 #
1214 # Revision 1.2.2.19 2005/07/21 04:01:28 jgs
1215 # minor comment fixes
1216 #
1217 # Revision 1.2.2.18 2005/07/21 01:02:43 jgs
1218 # commit ln() updates to development branch version
1219 #
1220 # Revision 1.12 2005/07/20 06:14:58 jgs
1221 # added ln(data) style wrapper for data.ln() - also added corresponding
1222 # implementation of Ln_Symbol class (not sure if this is right though)
1223 #
1224 # Revision 1.11 2005/07/08 04:07:35 jgs
1225 # Merge of development branch back to main trunk on 2005-07-08
1226 #
1227 # Revision 1.10 2005/06/09 05:37:59 jgs
1228 # Merge of development branch back to main trunk on 2005-06-09
1229 #
1230 # Revision 1.2.2.17 2005/07/07 07:28:58 gross
1231 # some stuff added to util.py to improve functionality
1232 #
1233 # Revision 1.2.2.16 2005/06/30 01:53:55 gross
1234 # a bug in coloring fixed
1235 #
1236 # Revision 1.2.2.15 2005/06/29 02:36:43 gross
1237 # Symbols have been introduced and some function clarified. needs much more work
1238 #
1239 # Revision 1.2.2.14 2005/05/20 04:05:23 gross
1240 # some work on a darcy flow started
1241 #
1242 # Revision 1.2.2.13 2005/03/16 05:17:58 matt
1243 # Implemented unit(idx, dim) to create cartesian unit basis vectors to
1244 # complement kronecker(dim) function.
1245 #
1246 # Revision 1.2.2.12 2005/03/10 08:14:37 matt
1247 # Added non-member Linf utility function to complement Data::Linf().
1248 #
1249 # Revision 1.2.2.11 2005/02/17 05:53:25 gross
1250 # some bug in saveDX fixed: in fact the bug was in
1251 # DataC/getDataPointShape
1252 #
1253 # Revision 1.2.2.10 2005/01/11 04:59:36 gross
1254 # automatic interpolation in integrate switched off
1255 #
1256 # Revision 1.2.2.9 2005/01/11 03:38:13 gross
# Bug in Data.integrate() fixed for the case of rank 0. The problem is not totally resolved as the method should return a scalar rather than a numarray object in the case of rank 0. This problem is fixed by the util.integrate wrapper.
1258 #
1259 # Revision 1.2.2.8 2005/01/05 04:21:41 gross
# FunctionSpace checking/matching in slicing added
1261 #
1262 # Revision 1.2.2.7 2004/12/29 05:29:59 gross
1263 # AdvectivePDE successfully tested for Peclet number 1000000. there is still a problem with setValue and Data()
1264 #
1265 # Revision 1.2.2.6 2004/12/24 06:05:41 gross
# some changes in linearPDEs to add AdvectivePDE
1267 #
1268 # Revision 1.2.2.5 2004/12/17 00:06:53 gross
1269 # mk sets ESYS_ROOT is undefined
1270 #
1271 # Revision 1.2.2.4 2004/12/07 03:19:51 gross
1272 # options for GMRES and PRES20 added
1273 #
1274 # Revision 1.2.2.3 2004/12/06 04:55:18 gross
# function wrapper extended
1276 #
1277 # Revision 1.2.2.2 2004/11/22 05:44:07 gross
1278 # a few more unitary functions have been added but not implemented in Data yet
1279 #
1280 # Revision 1.2.2.1 2004/11/12 06:58:15 gross
1281 # a lot of changes to get the linearPDE class running: most important change is that there is no matrix format exposed to the user anymore. the format is chosen by the Domain according to the solver and symmetry
1282 #
1283 # Revision 1.2 2004/10/27 00:23:36 jgs
1284 # fixed minor syntax error
1285 #
1286 # Revision 1.1.1.1 2004/10/26 06:53:56 jgs
1287 # initial import of project esys2
1288 #
1289 # Revision 1.1.2.3 2004/10/26 06:43:48 jgs
# committing Lutz's and Paul's changes to branch jgs
1291 #
1292 # Revision 1.1.4.1 2004/10/20 05:32:51 cochrane
1293 # Added incomplete Doxygen comments to files, or merely put the docstrings that already exist into Doxygen form.
1294 #
1295 # Revision 1.1 2004/08/05 03:58:27 gross
1296 # Bug in Assemble_NodeCoordinates fixed
1297 #
1298 #

Properties

Name Value
svn:eol-style native
svn:keywords Author Date Id Revision

  ViewVC Help
Powered by ViewVC 1.1.26