/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Contents of /trunk/escript/py_src/util.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 155 - (show annotations)
Wed Nov 9 02:02:19 2005 UTC (14 years, 1 month ago) by jgs
File MIME type: text/x-python
File size: 21503 byte(s)
move all directories from trunk/esys2 into trunk and remove esys2

1 # $Id$
2
3 ## @file util.py
4
5 """
6 Utility functions for escript
7
8 @todo:
9
10 - binary operations @ (@=+,-,*,/,**)::
11 (a@b)[:,*]=a[:]@b[:,*] if rank(a)<rank(b)
12 (a@b)[:]=a[:]@b[:] if rank(a)=rank(b)
13 (a@b)[*,:]=a[*,:]@b[:] if rank(a)>rank(b)
14 - implementation of outer::
15 outer(a,b)[:,*]=a[:]*b[*]
16 - trace::
17 trace(arg,axis0=a0,axis1=a1)(:,&,*)=sum_i trace(:,i,&,i,*) (i are at index a0 and a1)
18 """
19
20 import numarray
21 import escript
22 import symbols
23 import os
24
25 #=========================================================
26 # some little helpers
27 #=========================================================
28 def _testForZero(arg):
29 """
30 Returns True is arg is considered to be zero.
31 """
32 if isinstance(arg,int):
33 return not arg>0
34 elif isinstance(arg,float):
35 return not arg>0.
36 elif isinstance(arg,numarray.NumArray):
37 a=abs(arg)
38 while isinstance(a,numarray.NumArray): a=numarray.sometrue(a)
39 return not a>0
40 else:
41 return False
42
43 #=========================================================
def saveVTK(filename,domain=None,**data):
    """
    Writes L{Data} objects into a file using the VTK XML file format.

    Example::

        tmp=Scalar(..)
        v=Vector(..)
        saveVTK("solution.xml",temperature=tmp,velovity=v)

    tmp and v are written into "solution.xml" where tmp is named "temperature" and v is named "velovity".

    @param filename: file name of the output file
    @type filename: C{str}
    @param domain: domain of the L{Data} object. If not specified, the domain of the given L{Data} objects is used.
    @type domain: L{escript.Domain}
    @keyword <name>: writes the assigned value to the VTK file using <name> as identifier.
    @type <name>: L{Data} object.
    @raise ValueError: if no domain is given and no non-empty data object provides one.
    @note: The data objects have to be defined on the same domain. They may not be in the same L{FunctionSpace} but one cannot expect that all L{FunctionSpace} can be mixed. Typically, data on the boundary and data on the interior cannot be mixed.
    """
    # identity test: "==" could be dispatched to a comparison operator of the domain object
    if domain is None:
        # every non-empty data object is inspected; the last one found supplies the domain
        for value in data.values():
            if not value.isEmpty(): domain=value.getFunctionSpace().getDomain()
    if domain is None:
        raise ValueError("no domain detected.")
    domain.saveVTK(filename,data)
71
72 #=========================================================
def saveDX(filename,domain=None,**data):
    """
    Writes L{Data} objects into a file using the DX file format.

    Example::

        tmp=Scalar(..)
        v=Vector(..)
        saveDX("solution.dx",temperature=tmp,velovity=v)

    tmp and v are written into "solution.dx" where tmp is named "temperature" and v is named "velovity".

    @param filename: file name of the output file
    @type filename: C{str}
    @param domain: domain of the L{Data} object. If not specified, the domain of the given L{Data} objects is used.
    @type domain: L{escript.Domain}
    @keyword <name>: writes the assigned value to the DX file using <name> as identifier. The identifier can be used to select the data set when data are imported into DX.
    @type <name>: L{Data} object.
    @raise ValueError: if no domain is given and no non-empty data object provides one.
    @note: The data objects have to be defined on the same domain. They may not be in the same L{FunctionSpace} but one cannot expect that all L{FunctionSpace} can be mixed. Typically, data on the boundary and data on the interior cannot be mixed.
    """
    # identity test: "==" could be dispatched to a comparison operator of the domain object
    if domain is None:
        # every non-empty data object is inspected; the last one found supplies the domain
        for value in data.values():
            if not value.isEmpty(): domain=value.getFunctionSpace().getDomain()
    if domain is None:
        raise ValueError("no domain detected.")
    domain.saveDX(filename,data)
100
101 #=========================================================
102
def exp(arg):
    """
    Returns the exponential of arg.

    @param arg: argument; a L{symbols.Symbol}, an object providing an
                C{exp} method, or a value accepted by C{numarray.exp}.
    @return: C{symbols.Exp_Symbol(arg)} for symbols, C{arg.exp()} if arg
             has such a method, C{numarray.exp(arg)} otherwise.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Exp_Symbol(arg)
    if hasattr(arg,"exp"):
        return arg.exp()
    return numarray.exp(arg)
115
def sqrt(arg):
    """
    Returns the square root of arg.

    @param arg: argument; a L{symbols.Symbol}, an object providing a
                C{sqrt} method, or a value accepted by C{numarray.sqrt}.
    @return: C{symbols.Sqrt_Symbol(arg)} for symbols, C{arg.sqrt()} if arg
             has such a method, C{numarray.sqrt(arg)} otherwise.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Sqrt_Symbol(arg)
    if hasattr(arg,"sqrt"):
        return arg.sqrt()
    return numarray.sqrt(arg)
128
def log(arg):
    """
    Returns the logarithm to base 10 of arg.

    @param arg: argument; a L{symbols.Symbol}, an object providing a
                C{log} method, or a value accepted by C{numarray.log10}.
    @return: C{symbols.Log_Symbol(arg)} for symbols, C{arg.log()} if arg
             has such a method, C{numarray.log10(arg)} otherwise.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Log_Symbol(arg)
    if hasattr(arg,"log"):
        return arg.log()
    return numarray.log10(arg)
141
def ln(arg):
    """
    Returns the natural logarithm of arg.

    @param arg: argument; a L{symbols.Symbol}, an object providing an
                C{ln} method, or a value accepted by C{numarray.log}.
    @return: C{symbols.Ln_Symbol(arg)} for symbols, C{arg.ln()} if arg
             has such a method, C{numarray.log(arg)} otherwise.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Ln_Symbol(arg)
    if hasattr(arg,"ln"):
        return arg.ln()
    return numarray.log(arg)
154
def sin(arg):
    """
    Returns the sine of arg.

    @param arg: argument; a L{symbols.Symbol}, an object providing a
                C{sin} method, or a value accepted by C{numarray.sin}.
    @return: C{symbols.Sin_Symbol(arg)} for symbols, C{arg.sin()} if arg
             has such a method, C{numarray.sin(arg)} otherwise.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Sin_Symbol(arg)
    if hasattr(arg,"sin"):
        return arg.sin()
    return numarray.sin(arg)
167
def cos(arg):
    """
    Returns the cosine of arg.

    @param arg: argument; a L{symbols.Symbol}, an object providing a
                C{cos} method, or a value accepted by C{numarray.cos}.
    @return: C{symbols.Cos_Symbol(arg)} for symbols, C{arg.cos()} if arg
             has such a method, C{numarray.cos(arg)} otherwise.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Cos_Symbol(arg)
    if hasattr(arg,"cos"):
        return arg.cos()
    return numarray.cos(arg)
180
def tan(arg):
    """
    Returns the tangent of arg.

    @param arg: argument; a L{symbols.Symbol}, an object providing a
                C{tan} method, or a value accepted by C{numarray.tan}.
    @return: C{symbols.Tan_Symbol(arg)} for symbols, C{arg.tan()} if arg
             has such a method, C{numarray.tan(arg)} otherwise.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Tan_Symbol(arg)
    if hasattr(arg,"tan"):
        return arg.tan()
    return numarray.tan(arg)
193
def asin(arg):
    """
    Returns the inverse sine of arg.

    @param arg: argument; a L{symbols.Symbol}, an object providing an
                C{asin} method, or a value accepted by C{numarray.arcsin}.
    @return: C{symbols.Asin_Symbol(arg)} for symbols, C{arg.asin()} if arg
             has such a method, C{numarray.arcsin(arg)} otherwise.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Asin_Symbol(arg)
    elif hasattr(arg,"asin"):
        return arg.asin()
    else:
        # numarray follows the Numeric naming: the ufunc is called arcsin, not asin
        return numarray.arcsin(arg)
206
def acos(arg):
    """
    Returns the inverse cosine of arg.

    @param arg: argument; a L{symbols.Symbol}, an object providing an
                C{acos} method, or a value accepted by C{numarray.arccos}.
    @return: C{symbols.Acos_Symbol(arg)} for symbols, C{arg.acos()} if arg
             has such a method, C{numarray.arccos(arg)} otherwise.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Acos_Symbol(arg)
    elif hasattr(arg,"acos"):
        return arg.acos()
    else:
        # numarray follows the Numeric naming: the ufunc is called arccos, not acos
        return numarray.arccos(arg)
219
def atan(arg):
    """
    Returns the inverse tangent of arg.

    @param arg: argument; a L{symbols.Symbol}, an object providing an
                C{atan} method, or a value accepted by C{numarray.arctan}.
    @return: C{symbols.Atan_Symbol(arg)} for symbols, C{arg.atan()} if arg
             has such a method, C{numarray.arctan(arg)} otherwise.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Atan_Symbol(arg)
    elif hasattr(arg,"atan"):
        return arg.atan()
    else:
        # numarray follows the Numeric naming: the ufunc is called arctan, not atan
        return numarray.arctan(arg)
232
def sinh(arg):
    """
    Returns the hyperbolic sine of arg.

    @param arg: argument; a L{symbols.Symbol}, an object providing a
                C{sinh} method, or a value accepted by C{numarray.sinh}.
    @return: C{symbols.Sinh_Symbol(arg)} for symbols, C{arg.sinh()} if arg
             has such a method, C{numarray.sinh(arg)} otherwise.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Sinh_Symbol(arg)
    if hasattr(arg,"sinh"):
        return arg.sinh()
    return numarray.sinh(arg)
245
def cosh(arg):
    """
    Returns the hyperbolic cosine of arg.

    @param arg: argument; a L{symbols.Symbol}, an object providing a
                C{cosh} method, or a value accepted by C{numarray.cosh}.
    @return: C{symbols.Cosh_Symbol(arg)} for symbols, C{arg.cosh()} if arg
             has such a method, C{numarray.cosh(arg)} otherwise.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Cosh_Symbol(arg)
    if hasattr(arg,"cosh"):
        return arg.cosh()
    return numarray.cosh(arg)
258
def tanh(arg):
    """
    Returns the hyperbolic tangent of arg.

    @param arg: argument; a L{symbols.Symbol}, an object providing a
                C{tanh} method, or a value accepted by C{numarray.tanh}.
    @return: C{symbols.Tanh_Symbol(arg)} for symbols, C{arg.tanh()} if arg
             has such a method, C{numarray.tanh(arg)} otherwise.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Tanh_Symbol(arg)
    if hasattr(arg,"tanh"):
        return arg.tanh()
    return numarray.tanh(arg)
271
def asinh(arg):
    """
    Returns the inverse hyperbolic sine of arg.

    @param arg: argument; a L{symbols.Symbol}, an object providing an
                C{asinh} method, or a value accepted by C{numarray.arcsinh}.
    @return: C{symbols.Asinh_Symbol(arg)} for symbols, C{arg.asinh()} if arg
             has such a method, C{numarray.arcsinh(arg)} otherwise.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Asinh_Symbol(arg)
    elif hasattr(arg,"asinh"):
        return arg.asinh()
    else:
        # numarray follows the Numeric naming: the ufunc is called arcsinh, not asinh
        return numarray.arcsinh(arg)
284
def acosh(arg):
    """
    Returns the inverse hyperbolic cosine of arg.

    @param arg: argument; a L{symbols.Symbol}, an object providing an
                C{acosh} method, or a value accepted by C{numarray.arccosh}.
    @return: C{symbols.Acosh_Symbol(arg)} for symbols, C{arg.acosh()} if arg
             has such a method, C{numarray.arccosh(arg)} otherwise.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Acosh_Symbol(arg)
    elif hasattr(arg,"acosh"):
        return arg.acosh()
    else:
        # numarray follows the Numeric naming: the ufunc is called arccosh, not acosh
        return numarray.arccosh(arg)
297
def atanh(arg):
    """
    Returns the inverse hyperbolic tangent of arg.

    @param arg: argument; a L{symbols.Symbol}, an object providing an
                C{atanh} method, or a value accepted by C{numarray.arctanh}.
    @return: C{symbols.Atanh_Symbol(arg)} for symbols, C{arg.atanh()} if arg
             has such a method, C{numarray.arctanh(arg)} otherwise.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Atanh_Symbol(arg)
    elif hasattr(arg,"atanh"):
        return arg.atanh()
    else:
        # numarray follows the Numeric naming: the ufunc is called arctanh, not atanh
        return numarray.arctanh(arg)
310
def sign(arg):
    """
    Returns the sign of arg.

    @param arg: argument; a L{symbols.Symbol}, an object providing a
                C{sign} method, or an object with a C{shape} attribute
                that can be compared entry-wise by numarray.
    @return: C{symbols.Sign_Symbol(arg)} for symbols, C{arg.sign()} if arg
             has such a method; otherwise the difference of the numarray
             masks C{greater(arg,0)} and C{less(arg,0)}.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Sign_Symbol(arg)
    if hasattr(arg,"sign"):
        return arg.sign()
    # mask of positive entries minus mask of negative entries
    positive=numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))
    negative=numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
    return positive-negative
324
def maxval(arg):
    """
    Returns the maximum value of argument arg.

    @param arg: argument
    @return: C{symbols.Max_Symbol(arg)} for symbols, C{arg.maxval()} or
             C{arg.max()} if arg provides one of these methods (checked in
             this order), arg itself otherwise.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Max_Symbol(arg)
    if hasattr(arg,"maxval"):
        return arg.maxval()
    if hasattr(arg,"max"):
        return arg.max()
    # nothing to reduce: arg is returned unchanged
    return arg
339
def minval(arg):
    """
    Returns the minimum value of argument arg.

    @param arg: argument
    @return: C{symbols.Min_Symbol(arg)} for symbols, C{arg.minval()} or
             C{arg.min()} if arg provides one of these methods (checked in
             this order), arg itself otherwise.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Min_Symbol(arg)
    # test for the method that is actually called (was: "maxval")
    elif hasattr(arg,"minval"):
        return arg.minval()
    elif hasattr(arg,"min"):
        return arg.min()
    else:
        return arg
354
def wherePositive(arg):
    """
    Returns a mask marking where argument arg is positive.

    @param arg: argument
    @return: 0 if arg is identified as zero by L{_testForZero},
             C{symbols.WherePositive_Symbol(arg)} for symbols,
             C{arg.wherePositive()} if arg provides this method, the result
             of C{numarray.greater(arg,0)} if arg has a C{shape} attribute,
             and 1. or 0. for any other argument depending on C{arg>0}.
    """
    if _testForZero(arg):
        return 0
    elif isinstance(arg,symbols.Symbol):
        return symbols.WherePositive_Symbol(arg)
    elif hasattr(arg,"wherePositive"):
        # was arg.minval(), which returned a minimum instead of the mask
        return arg.wherePositive()
    elif hasattr(arg,"shape"):
        # array branch: was unreachable (duplicated condition) and dropped its result
        return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))
    else:
        if arg>0:
            return 1.
        else:
            return 0.
374
def whereNegative(arg):
    """
    Returns a mask marking where argument arg is negative.

    @param arg: argument
    @return: 0 if arg is identified as zero by L{_testForZero},
             C{symbols.WhereNegative_Symbol(arg)} for symbols,
             C{arg.whereNegative()} if arg provides this method, the result
             of C{numarray.less(arg,0)} if arg has a C{shape} attribute,
             and 1. or 0. for any other argument depending on C{arg<0}.
    """
    if _testForZero(arg):
        return 0
    elif isinstance(arg,symbols.Symbol):
        return symbols.WhereNegative_Symbol(arg)
    elif hasattr(arg,"whereNegative"):
        return arg.whereNegative()
    elif hasattr(arg,"shape"):
        # the mask was computed but not returned before
        return numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
    else:
        if arg<0:
            return 1.
        else:
            return 0.
394
def maximum(arg0,arg1):
    """
    Returns arg1 where arg1 is bigger than arg0, otherwise arg0.

    @param arg0: first argument
    @param arg1: second argument
    @return: blend of both arguments selected by the mask
             C{whereNegative(arg0-arg1)}.
    """
    # mask is one where arg0<arg1, i.e. where arg1 has to be picked
    mask=whereNegative(arg0-arg1)
    return mask*arg1+(1.-mask)*arg0
401
def minimum(arg0,arg1):
    """
    Returns arg0 where arg1 is bigger than arg0, otherwise arg1.

    @param arg0: first argument
    @param arg1: second argument
    @return: blend of both arguments selected by the mask
             C{whereNegative(arg0-arg1)}.
    """
    # mask is one where arg0<arg1, i.e. where arg0 has to be picked
    mask=whereNegative(arg0-arg1)
    return mask*arg0+(1.-mask)*arg1
408
def outer(arg0,arg1):
    """
    Returns the outer product of arg0 and arg1.

    @param arg0: first factor
    @param arg1: second factor
    @return: 0 if one factor is identified as zero, a
             C{symbols.Outer_Symbol} if one factor is a symbol, the plain
             product if one factor has shape C{()}, C{numarray.outer} for
             two numarrays, and otherwise an object built entry by entry
             from two rank 1 arguments.
    @raise ValueError: if none of the supported cases applies.
    """
    if _testForZero(arg0) or _testForZero(arg1):
        return 0
    if isinstance(arg0,symbols.Symbol) or isinstance(arg1,symbols.Symbol):
        return symbols.Outer_Symbol(arg0,arg1)
    # NOTE(review): _identifyShape is not defined in the visible part of this
    # module - confirm that it is available at run time.
    if _identifyShape(arg0)==() or _identifyShape(arg1)==():
        return arg0*arg1
    if isinstance(arg0,numarray.NumArray) and isinstance(arg1,numarray.NumArray):
        return numarray.outer(arg0,arg1)
    if not (arg0.getRank()==1 and arg1.getRank()==1):
        raise ValueError("outer is not fully implemented yet.")
    # rank 1 times rank 1: fill the result entry by entry
    result=escript.Data(0,(arg0.getShape()[0],arg1.getShape()[0]),arg1.getFunctionSpace())
    for row in range(arg0.getShape()[0]):
        for col in range(arg1.getShape()[0]):
            result[row,col]=arg0[row]*arg1[col]
    return result
428
def interpolate(arg,where):
    """
    Interpolates arg into the FunctionSpace where.

    @param arg: interpolant
    @param where: FunctionSpace to interpolate to
    @return: 0 if arg is identified as zero,
             C{symbols.Interpolated_Symbol(arg,where)} for symbols,
             C{escript.Data(arg,where)} otherwise.
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg,symbols.Symbol):
        return symbols.Interpolated_Symbol(arg,where)
    return escript.Data(arg,where)
442
def div(arg,where=None):
    """
    Returns the divergence of arg at where, computed as the trace of the
    gradient of arg.

    @param arg: object representing the function whose divergence is to be
                calculated.
    @param where: FunctionSpace in which the gradient is calculated.
                  If not present or C{None} an appropriate default is used.
    @return: C{trace(grad(arg,where))}
    """
    gradient=grad(arg,where)
    return trace(gradient)
453
def jump(arg):
    """
    Returns the jump of arg across a discontinuity, i.e. the difference of
    arg interpolated to the two sides of the contact.

    @param arg: Data object representing the function whose jump is to be
                calculated. It has to provide C{getDomain} and
                C{interpolate}.
    @return: C{arg} on contact side one minus C{arg} on contact side zero.
    """
    d=arg.getDomain()
    # the contact function spaces have to be built on the domain of arg;
    # the domain was fetched but never handed over before.
    return arg.interpolate(escript.FunctionOnContactOne(d))-arg.interpolate(escript.FunctionOnContactZero(d))
463
464
def grad(arg,where=None):
    """
    Returns the spatial gradient of arg at where.

    @param arg: object representing the function whose gradient is to be
                calculated.
    @param where: FunctionSpace in which the gradient is calculated.
                  If not present or C{None}, C{arg.grad()} is called
                  without argument.
    @return: 0 if arg is identified as zero, C{symbols.Grad_Symbol(arg,where)}
             for symbols, the result of C{arg.grad} if arg provides this
             method, and C{arg*0.} otherwise.
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg,symbols.Symbol):
        return symbols.Grad_Symbol(arg,where)
    if not hasattr(arg,"grad"):
        # no grad method available: the gradient is taken to be zero
        return arg*0.
    if where==None:
        return arg.grad()
    return arg.grad(where)
485
def integrate(arg,where=None):
    """
    Returns the integral of the function represented by arg over its domain.

    @param arg: object representing the function which is integrated.
    @param where: FunctionSpace in which the integral is calculated. If
                  given, arg is first converted by C{escript.Data(arg,where)}.
    @return: 0 if arg is identified as zero,
             C{symbols.Integral_Symbol(arg,where)} for symbols; otherwise
             the result of C{arg.integrate()}, reduced to its first entry
             for rank 0 arguments.
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg,symbols.Symbol):
        return symbols.Integral_Symbol(arg,where)
    if not where==None:
        arg=escript.Data(arg,where)
    if arg.getRank()==0:
        # for rank 0 only the first entry of the result is handed back
        return arg.integrate()[0]
    return arg.integrate()
505
506 #=============================
507 #
# wrappers for various functions: if the argument has an attribute with the
# name of the function, the corresponding method is called. Otherwise the
# corresponding numarray function is called.
511
512 # functions involving the underlying Domain:
513
514
515 # functions returning Data objects:
516
def transpose(arg,axis=None):
    """
    Returns the transpose of arg.

    @param arg: object to be transposed: a L{symbols.Symbol}, an
                L{escript.Data} object of rank 2, or an object accepted by
                C{numarray.transpose}.
    @param axis: axis handed to C{symbols.Transpose_Symbol} and
                 C{numarray.transpose}. If C{None}, half of the rank of
                 arg is used (0 if the rank cannot be determined). It is
                 ignored for L{escript.Data} objects.
    @return: the transposed object.
    @raise ValueError: if arg is an L{escript.Data} object of rank other than 2.
    """
    if axis is None:
        r=0
        if hasattr(arg,"getRank"): r=arg.getRank()
        if hasattr(arg,"rank"): r=arg.rank
        # floor division keeps the axis an integer
        axis=r//2
    if isinstance(arg,symbols.Symbol):
        # hand over the axis: the rank r was passed before, and r is not even
        # defined when the caller supplies axis explicitly.
        return symbols.Transpose_Symbol(arg,axis=axis)
    if isinstance(arg,escript.Data):
        # hack for transpose: entry-wise copy, restricted to rank 2
        r=arg.getRank()
        if r!=2: raise ValueError("Tranpose only avalaible for rank 2 objects")
        s=arg.getShape()
        out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())
        for i in range(s[0]):
            for j in range(s[1]):
                out[j,i]=arg[i,j]
        return out
        # end hack for transpose
    else:
        return numarray.transpose(arg,axis=axis)
544
def trace(arg,axis0=0,axis1=1):
    """
    Returns the trace of arg.

    @param arg: a L{symbols.Symbol}, an L{escript.Data} object or an object
                accepted by C{numarray.trace}.
    @param axis0: first axis of the contraction
    @param axis1: second axis of the contraction
    @return: C{symbols.Trace_Symbol} for symbols; for L{escript.Data}
             objects the sum of the entries C{arg[i,i]}; otherwise the
             result of C{numarray.trace}.
    @raise ValueError: if arg is an L{escript.Data} object whose extents
                       along axis0 and axis1 differ.
    """
    if isinstance(arg,symbols.Symbol):
        # shape with both contracted axes removed (computed but not used below)
        dims=list(arg.getShape())
        dims=tuple(dims[0:axis0]+dims[axis0+1:axis1]+dims[axis1+1:])
        return symbols.Trace_Symbol(arg,axis0=axis0,axis1=axis1)
    if isinstance(arg,escript.Data):
        # hack for trace: only the entries arg[i,i] are summed up
        shape=arg.getShape()
        if shape[axis0]!=shape[axis1]:
            raise ValueError("illegal axis in trace")
        total=escript.Scalar(0.,arg.getFunctionSpace())
        for i in range(shape[axis0]):
            total+=arg[i,i]
        return total
    return numarray.trace(arg,axis0=axis0,axis1=axis1)
567
def length(arg):
    """
    Returns the length (square root of the sum of all squared entries) of arg.

    @param arg: an L{escript.Data} object of rank 0 to 4, a C{float}, an
                C{int}, or an object supporting C{(arg**2).sum()}.
    @return: an empty L{escript.Data} object if arg is an empty Data
             object, the absolute value for rank 0 Data objects and plain
             numbers, and the square root of the sum of squares otherwise.
    @raise SystemError: if arg is an L{escript.Data} object of rank above 4.
    """
    if isinstance(arg,escript.Data):
        if arg.isEmpty(): return escript.Data()
        rank=arg.getRank()
        if rank==0:
            return abs(arg)
        if rank>4:
            raise SystemError("length is not been fully implemented yet")
        # accumulate the squares of all entries, one explicit loop nest per rank
        shape=arg.getShape()
        out=escript.Scalar(0,arg.getFunctionSpace())
        if rank==1:
            for i in range(shape[0]):
                out+=arg[i]**2
        elif rank==2:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    out+=arg[i,j]**2
        elif rank==3:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    for k in range(shape[2]):
                        out+=arg[i,j,k]**2
        else:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    for k in range(shape[2]):
                        for l in range(shape[3]):
                            out+=arg[i,j,k,l]**2
        return sqrt(out)
    elif isinstance(arg,(float,int)):
        # plain integers have no sum() method, so they are handled like floats
        return abs(arg)
    else:
        return sqrt((arg**2).sum())
609
def deviator(arg):
    """
    Returns the deviator of arg, i.e. arg minus its trace, divided by the
    dimension, times the identity matrix.

    @param arg: square rank 2 object; either an L{escript.Data} object or
                an object with a C{shape} attribute.
    @return: C{arg-trace(arg)/n*kronecker(n)} where n is the extent of arg.
    @raise ValueError: if arg is not of rank 2 or not square.
    """
    if isinstance(arg,escript.Data):
        dims=arg.getShape()
    else:
        dims=arg.shape
    if len(dims)!=2:
        raise ValueError("Deviator requires rank 2 object")
    if dims[0]!=dims[1]:
        raise ValueError("Deviator requires a square matrix")
    scale=1./(dims[0]*1.)
    return arg-scale*trace(arg)*kronecker(dims[0])
623
def inner(arg0,arg1):
    """
    Returns the inner product of arg0 and arg1, i.e. the sum over the
    products of all corresponding entries.

    The rank, shape and function space are taken from arg0 if it is an
    L{escript.Data} object and from arg1 otherwise.

    @param arg0: first factor
    @param arg1: second factor
    @return: C{arg0*arg1} for rank 0, otherwise a scalar holding the sum
             of the entry-wise products.
    @raise SystemError: if the rank is above 4.
    """
    if isinstance(arg0,escript.Data):
        arg=arg0
    else:
        arg=arg1

    rank=arg.getRank()
    if rank==0:
        return arg0*arg1
    if rank>4:
        raise SystemError("inner is not been implemented yet")
    # the accumulator is created once and only when it is needed
    shape=arg.getShape()
    out=escript.Scalar(0,arg.getFunctionSpace())
    if rank==1:
        for i in range(shape[0]):
            out+=arg0[i]*arg1[i]
    elif rank==2:
        for i in range(shape[0]):
            for j in range(shape[1]):
                out+=arg0[i,j]*arg1[i,j]
    elif rank==3:
        for i in range(shape[0]):
            for j in range(shape[1]):
                for k in range(shape[2]):
                    out+=arg0[i,j,k]*arg1[i,j,k]
    else:
        for i in range(shape[0]):
            for j in range(shape[1]):
                for k in range(shape[2]):
                    for l in range(shape[3]):
                        out+=arg0[i,j,k,l]*arg1[i,j,k,l]
    return out
662
def tensormult(arg0,arg1):
    """
    Placeholder for the tensor product of arg0 and arg1.

    @param arg0: first factor (not used)
    @param arg1: second factor (not used)
    @raise SystemError: always, as the function is not implemented.
    """
    # check LinearPDE!!!!
    raise SystemError("tensormult is not implemented yet!")
666
def matrixmult(arg0,arg1):
    """
    Returns the matrix product of arg0 and arg1.

    @param arg0: first factor
    @param arg1: second factor
    @return: C{numarray.matrixmultiply(arg0,arg1)} if both factors are
             numarrays. Otherwise a factor which is not an L{escript.Data}
             object is converted using the function space of the other
             factor; a rank 2 times rank 1 product is assembled entry by
             entry and a rank 1 times rank 1 product is delegated to
             L{inner}.
    @raise SystemError: for any other combination of ranks.
    """
    if isinstance(arg1,numarray.NumArray) and isinstance(arg0,numarray.NumArray):
        # the product was computed but not returned before; numarray
        # provides the matrix product under the name matrixmultiply.
        return numarray.matrixmultiply(arg0,arg1)
    else:
        # escript.matmult(arg0,arg1)
        if isinstance(arg1,escript.Data) and not isinstance(arg0,escript.Data):
            arg0=escript.Data(arg0,arg1.getFunctionSpace())
        elif isinstance(arg0,escript.Data) and not isinstance(arg1,escript.Data):
            arg1=escript.Data(arg1,arg0.getFunctionSpace())
        if arg0.getRank()==2 and arg1.getRank()==1:
            out=escript.Data(0,(arg0.getShape()[0],),arg0.getFunctionSpace())
            for i in range(arg0.getShape()[0]):
                for j in range(arg0.getShape()[1]):
                    # uses Data object slicing, plus Data * and += operators
                    out[i]+=arg0[i,j]*arg1[j]
            return out
        elif arg0.getRank()==1 and arg1.getRank()==1:
            return inner(arg0,arg1)
        else:
            raise SystemError("matrixmult is not fully implemented yet!")
688
689 #=========================================================
690 # reduction operations:
691 #=========================================================
def sum(arg):
    """
    Returns the sum of arg as computed by its own C{sum} method.

    @param arg: object providing a C{sum} method
    @return: C{arg.sum()}
    """
    total=arg.sum()
    return total
697
def sup(arg):
    """
    Returns the supremum of arg.

    @param arg: an L{escript.Data} object, a C{float}, an C{int}, or an
                object providing a C{max} method.
    @return: C{arg.sup()} for Data objects, arg itself for plain numbers,
             C{arg.max()} otherwise.
    """
    if isinstance(arg,escript.Data):
        return arg.sup()
    if isinstance(arg,(float,int)):
        return arg
    return arg.max()
708
def inf(arg):
    """
    Returns the infimum of arg.

    @param arg: an L{escript.Data} object, a C{float}, an C{int}, or an
                object providing a C{min} method.
    @return: C{arg.inf()} for Data objects, arg itself for plain numbers,
             C{arg.min()} otherwise.
    """
    if isinstance(arg,escript.Data):
        return arg.inf()
    if isinstance(arg,(float,int)):
        return arg
    return arg.min()
719
def L2(arg):
    """
    Returns the L2-norm of the argument.

    @param arg: an L{escript.Data} object, a C{float}, an C{int}, or an
                object accepted by L{dot}.
    @return: C{arg.L2()} for Data objects, the absolute value for plain
             numbers, C{numarray.sqrt(dot(arg,arg))} otherwise.
    """
    if isinstance(arg,escript.Data):
        return arg.L2()
    elif isinstance(arg,float) or isinstance(arg,int):
        return abs(arg)
    else:
        # module name was misspelled ("numarry"), which raised a NameError
        return numarray.sqrt(dot(arg,arg))
732
def Lsup(arg):
    """
    Returns the supremum norm (largest absolute value) of arg.

    @param arg: an L{escript.Data} object, a C{float}, an C{int}, or an
                object accepted by C{numarray.abs}.
    @return: C{arg.Lsup()} for Data objects, the absolute value for plain
             numbers, C{numarray.abs(arg).max()} otherwise.
    """
    if isinstance(arg,escript.Data):
        return arg.Lsup()
    if isinstance(arg,(float,int)):
        return abs(arg)
    return numarray.abs(arg).max()
743
def dot(arg0,arg1):
    """
    Returns the dot product of arg0 and arg1.

    @param arg0: first factor
    @param arg1: second factor
    @return: C{arg0.dot(arg1)} if arg0 is an L{escript.Data} object,
             C{arg1.dot(arg0)} if only arg1 is one, and
             C{numarray.dot(arg0,arg1)} otherwise.
    """
    if isinstance(arg0,escript.Data):
        return arg0.dot(arg1)
    if isinstance(arg1,escript.Data):
        return arg1.dot(arg0)
    return numarray.dot(arg0,arg1)
755
def kronecker(d):
    """
    Returns the identity matrix (Kronecker symbol) as a numarray with
    floating point entries.

    @param d: the dimension, or an object whose C{getDim} method supplies
              the dimension.
    @return: C{numarray.identity(n)*1.} where n is the dimension.
    """
    if hasattr(d,"getDim"):
        dim=d.getDim()
    else:
        dim=d
    # multiplication by 1. turns the integer identity into floating point
    return numarray.identity(dim)*1.
761
def unit(i,d):
    """
    Returns a unit vector of dimension d whose only nonzero entry is at
    index i.

    @param i: index of the entry set to one
    @param d: dimension
    @return: numarray of shape C{(d,)} and type C{numarray.Float}
    """
    vec=numarray.zeros((d,),numarray.Float)
    vec[i]=1.0
    return vec

Properties

Name Value
svn:eol-style native
svn:keywords Author Date Id Revision

  ViewVC Help
Powered by ViewVC 1.1.26