/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Contents of /trunk/escript/py_src/util.py

Parent Directory | Revision Log


Revision 153 - (show annotations)
Tue Oct 25 01:51:20 2005 UTC (14 years ago) by jgs
Original Path: trunk/esys2/escript/py_src/util.py
File MIME type: text/x-python
File size: 21502 byte(s)
Merge of development branch dev-02 back to main trunk on 2005-10-25

1 # $Id$
2
3 ## @file util.py
4
5 """
6 Utility functions for escript
7
8 @todo:
9
10 - binary operations @ (@=+,-,*,/,**)::
11 (a@b)[:,*]=a[:]@b[:,*] if rank(a)<rank(b)
12 (a@b)[:]=a[:]@b[:] if rank(a)=rank(b)
13 (a@b)[*,:]=a[*,:]@b[:] if rank(a)>rank(b)
14 - implementation of outer::
15 outer(a,b)[:,*]=a[:]*b[*]
16 - trace::
17 trace(arg,axis0=a0,axis1=a1)(:,&,*)=sum_i trace(:,i,&,i,*) (i are at index a0 and a1)
18 """
19
20 import numarray
21 import escript
22 import symbols
23 import os
24
25 #=========================================================
26 # some little helpers
27 #=========================================================
28 def _testForZero(arg):
29 """
30 Returns True is arg is considered to be zero.
31 """
32 if isinstance(arg,int):
33 return not arg>0
34 elif isinstance(arg,float):
35 return not arg>0.
36 elif isinstance(arg,numarray.NumArray):
37 a=abs(arg)
38 while isinstance(a,numarray.NumArray): a=numarray.sometrue(a)
39 return not a>0
40 else:
41 return False
42
43 #=========================================================
def saveVTK(filename,domain=None,**data):
    """
    Writes L{Data} objects into a file using the VTK XML file format.

    Example::

        tmp=Scalar(..)
        v=Vector(..)
        saveVTK("solution.xml",temperature=tmp,velocity=v)

    tmp and v are written into "solution.xml" where tmp is named
    "temperature" and v is named "velocity".

    @param filename: file name of the output file
    @type filename: C{str}
    @param domain: domain of the L{Data} objects. If not specified, the domain
                   of the given non-empty L{Data} objects is used.
    @type domain: L{escript.Domain}
    @keyword <name>: writes the assigned value to the VTK file using <name> as identifier.
    @type <name>: L{Data} object.
    @raise ValueError: if no domain is given and none can be taken from the data.
    @note: The data objects have to be defined on the same domain. They may not be in the same L{FunctionSpace} but one cannot expect that all L{FunctionSpace} can be mixed. Typically, data on the boundary and data on the interior cannot be mixed.
    """
    if domain is None:
        # take the domain from the non-empty data objects (the last one wins)
        for value in data.values():
            if not value.isEmpty():
                domain=value.getFunctionSpace().getDomain()
    if domain is None:
        raise ValueError("no domain detected.")
    domain.saveVTK(filename,data)
71 #=========================================================
def saveDX(filename,domain=None,**data):
    """
    Writes L{Data} objects into a file using the DX file format.

    Example::

        tmp=Scalar(..)
        v=Vector(..)
        saveDX("solution.dx",temperature=tmp,velocity=v)

    tmp and v are written into "solution.dx" where tmp is named
    "temperature" and v is named "velocity".

    @param filename: file name of the output file
    @type filename: C{str}
    @param domain: domain of the L{Data} objects. If not specified, the domain
                   of the given non-empty L{Data} objects is used.
    @type domain: L{escript.Domain}
    @keyword <name>: writes the assigned value to the DX file using <name> as identifier. The identifier can be used to select the data set when data are imported into DX.
    @type <name>: L{Data} object.
    @raise ValueError: if no domain is given and none can be taken from the data.
    @note: The data objects have to be defined on the same domain. They may not be in the same L{FunctionSpace} but one cannot expect that all L{FunctionSpace} can be mixed. Typically, data on the boundary and data on the interior cannot be mixed.
    """
    if domain is None:
        # take the domain from the non-empty data objects (the last one wins)
        for value in data.values():
            if not value.isEmpty():
                domain=value.getFunctionSpace().getDomain()
    if domain is None:
        raise ValueError("no domain detected.")
    domain.saveDX(filename,data)
99 #=========================================================
100
def exp(arg):
    """
    Returns the exponential of arg.

    A symbol is wrapped in a symbolic node, an object providing an exp()
    method evaluates itself, anything else is handed to numarray.exp.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Exp_Symbol(arg)
    if hasattr(arg,"exp"):
        return arg.exp()
    return numarray.exp(arg)
113
def sqrt(arg):
    """
    Returns the square root of arg.

    Dispatch: symbol -> symbols.Sqrt_Symbol, object with a sqrt() method ->
    arg.sqrt(), anything else -> numarray.sqrt.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        result=symbols.Sqrt_Symbol(arg)
    elif hasattr(arg,"sqrt"):
        result=arg.sqrt()
    else:
        result=numarray.sqrt(arg)
    return result
126
def log(arg):
    """
    Returns the logarithm to base 10 of arg.

    A symbol is wrapped in symbols.Log_Symbol, an object providing a log()
    method evaluates itself, anything else goes through numarray.log10.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Log_Symbol(arg)
    if hasattr(arg,"log"):
        return arg.log()
    return numarray.log10(arg)
139
def ln(arg):
    """
    Returns the natural logarithm of arg.

    Dispatch: symbol -> symbols.Ln_Symbol, object with an ln() method ->
    arg.ln(), anything else -> numarray.log.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        result=symbols.Ln_Symbol(arg)
    elif hasattr(arg,"ln"):
        result=arg.ln()
    else:
        result=numarray.log(arg)
    return result
152
def sin(arg):
    """
    Returns the sine of arg.

    A symbol is wrapped in symbols.Sin_Symbol, an object providing a sin()
    method evaluates itself, anything else goes through numarray.sin.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Sin_Symbol(arg)
    if hasattr(arg,"sin"):
        return arg.sin()
    return numarray.sin(arg)
165
def cos(arg):
    """
    Returns the cosine of arg.

    Dispatch: symbol -> symbols.Cos_Symbol, object with a cos() method ->
    arg.cos(), anything else -> numarray.cos.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        result=symbols.Cos_Symbol(arg)
    elif hasattr(arg,"cos"):
        result=arg.cos()
    else:
        result=numarray.cos(arg)
    return result
178
def tan(arg):
    """
    Returns the tangent of arg.

    A symbol is wrapped in symbols.Tan_Symbol, an object providing a tan()
    method evaluates itself, anything else goes through numarray.tan.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Tan_Symbol(arg)
    if hasattr(arg,"tan"):
        return arg.tan()
    return numarray.tan(arg)
191
def asin(arg):
    """
    Applies the arc sine function to arg.

    A symbol is wrapped in symbols.Asin_Symbol, an object providing an
    asin() method evaluates itself, anything else goes through
    numarray.arcsin.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Asin_Symbol(arg)
    elif hasattr(arg,"asin"):
        return arg.asin()
    else:
        # numarray names its inverse trigonometric ufuncs arcsin, arccos, ...
        return numarray.arcsin(arg)
204
def acos(arg):
    """
    Applies the arc cosine function to arg.

    A symbol is wrapped in symbols.Acos_Symbol, an object providing an
    acos() method evaluates itself, anything else goes through
    numarray.arccos.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Acos_Symbol(arg)
    elif hasattr(arg,"acos"):
        return arg.acos()
    else:
        # numarray names its inverse trigonometric ufuncs arcsin, arccos, ...
        return numarray.arccos(arg)
217
def atan(arg):
    """
    Applies the arc tangent function to arg.

    A symbol is wrapped in symbols.Atan_Symbol, an object providing an
    atan() method evaluates itself, anything else goes through
    numarray.arctan.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Atan_Symbol(arg)
    elif hasattr(arg,"atan"):
        return arg.atan()
    else:
        # numarray names its inverse trigonometric ufuncs arcsin, arccos, arctan
        return numarray.arctan(arg)
230
def sinh(arg):
    """
    Returns the hyperbolic sine of arg.

    Dispatch: symbol -> symbols.Sinh_Symbol, object with a sinh() method ->
    arg.sinh(), anything else -> numarray.sinh.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        result=symbols.Sinh_Symbol(arg)
    elif hasattr(arg,"sinh"):
        result=arg.sinh()
    else:
        result=numarray.sinh(arg)
    return result
243
def cosh(arg):
    """
    Returns the hyperbolic cosine of arg.

    A symbol is wrapped in symbols.Cosh_Symbol, an object providing a cosh()
    method evaluates itself, anything else goes through numarray.cosh.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Cosh_Symbol(arg)
    if hasattr(arg,"cosh"):
        return arg.cosh()
    return numarray.cosh(arg)
256
def tanh(arg):
    """
    Returns the hyperbolic tangent of arg.

    Dispatch: symbol -> symbols.Tanh_Symbol, object with a tanh() method ->
    arg.tanh(), anything else -> numarray.tanh.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        result=symbols.Tanh_Symbol(arg)
    elif hasattr(arg,"tanh"):
        result=arg.tanh()
    else:
        result=numarray.tanh(arg)
    return result
269
def asinh(arg):
    """
    Applies the inverse hyperbolic sine function to arg.

    A symbol is wrapped in symbols.Asinh_Symbol, an object providing an
    asinh() method evaluates itself, anything else goes through
    numarray.arcsinh.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Asinh_Symbol(arg)
    elif hasattr(arg,"asinh"):
        return arg.asinh()
    else:
        # numarray names its inverse hyperbolic ufuncs arcsinh, arccosh, arctanh
        return numarray.arcsinh(arg)
282
def acosh(arg):
    """
    Applies the inverse hyperbolic cosine function to arg.

    A symbol is wrapped in symbols.Acosh_Symbol, an object providing an
    acosh() method evaluates itself, anything else goes through
    numarray.arccosh.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Acosh_Symbol(arg)
    elif hasattr(arg,"acosh"):
        return arg.acosh()
    else:
        # numarray names its inverse hyperbolic ufuncs arcsinh, arccosh, arctanh
        return numarray.arccosh(arg)
295
def atanh(arg):
    """
    Applies the inverse hyperbolic tangent function to arg.

    A symbol is wrapped in symbols.Atanh_Symbol, an object providing an
    atanh() method evaluates itself, anything else goes through
    numarray.arctanh.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Atanh_Symbol(arg)
    elif hasattr(arg,"atanh"):
        return arg.atanh()
    else:
        # numarray names its inverse hyperbolic ufuncs arcsinh, arccosh, arctanh
        return numarray.arctanh(arg)
308
def sign(arg):
    """
    Applies the sign function to arg.

    - a Python int or float gives 1. if positive, -1. if negative, 0. if zero;
    - a symbol is wrapped in symbols.Sign_Symbol;
    - an object providing a sign() method evaluates itself;
    - anything else is treated as an array: the result is 1 where arg is
      positive, -1 where it is negative and 0 elsewhere.

    @param arg: argument
    """
    if isinstance(arg,(int,float)):
        # plain Python numbers have neither a sign() method nor a shape,
        # so the array fallback below would fail for them
        if arg>0: return 1.
        if arg<0: return -1.
        return 0.
    elif isinstance(arg,symbols.Symbol):
        return symbols.Sign_Symbol(arg)
    elif hasattr(arg,"sign"):
        return arg.sign()
    else:
        return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))- \
               numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
322
def maxval(arg):
    """
    Returns the maximum value of arg.

    A symbol is wrapped in symbols.Max_Symbol. Otherwise the first available
    method among maxval() and max() is called. An object with neither is
    returned unchanged.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Max_Symbol(arg)
    for method_name in ("maxval","max"):
        if hasattr(arg,method_name):
            return getattr(arg,method_name)()
    return arg
337
def minval(arg):
    """
    Returns the minimum value of argument arg.

    - a symbol is wrapped in symbols.Min_Symbol;
    - an object providing minval() is asked for it;
    - otherwise an object providing min() is asked for it;
    - anything else is returned unchanged.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Min_Symbol(arg)
    elif hasattr(arg,"minval"):
        # test for the method that is actually called (was "maxval")
        return arg.minval()
    elif hasattr(arg,"min"):
        return arg.min()
    else:
        return arg
352
def wherePositive(arg):
    """
    Returns a mask that is 1 where arg is positive and 0 elsewhere.

    - a Python int or float gives 1. if positive, 0. otherwise;
    - a zero numarray.NumArray gives 0;
    - a symbol is wrapped in symbols.WherePositive_Symbol;
    - an object providing wherePositive() evaluates itself;
    - an object with a shape gives the numarray comparison arg>0.

    @param arg: argument
    """
    if isinstance(arg,(int,float)):
        # compare scalars directly instead of relying on the zero short-cut
        if arg>0:
            return 1.
        return 0.
    elif _testForZero(arg):
        return 0
    elif isinstance(arg,symbols.Symbol):
        return symbols.WherePositive_Symbol(arg)
    elif hasattr(arg,"wherePositive"):
        return arg.wherePositive()
    elif hasattr(arg,"shape"):
        return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))
    else:
        if arg>0:
            return 1.
        else:
            return 0.
372
def whereNegative(arg):
    """
    Returns a mask that is 1 where arg is negative and 0 elsewhere.

    - a Python int or float gives 1. if negative, 0. otherwise;
    - a zero numarray.NumArray gives 0;
    - a symbol is wrapped in symbols.WhereNegative_Symbol;
    - an object providing whereNegative() evaluates itself;
    - an object with a shape gives the numarray comparison arg<0.

    @param arg: argument
    """
    if isinstance(arg,(int,float)):
        # compare scalars directly so a negative scalar can never be
        # mistaken for zero by the zero short-cut below
        if arg<0:
            return 1.
        return 0.
    elif _testForZero(arg):
        return 0
    elif isinstance(arg,symbols.Symbol):
        return symbols.WhereNegative_Symbol(arg)
    elif hasattr(arg,"whereNegative"):
        return arg.whereNegative()
    elif hasattr(arg,"shape"):
        # the comparison result was previously computed but not returned
        return numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
    else:
        if arg<0:
            return 1.
        else:
            return 0.
392
def maximum(arg0,arg1):
    """
    Returns arg1 where arg1 is bigger than arg0, and arg0 otherwise.

    The selection is done arithmetically with the mask
    whereNegative(arg0-arg1), so it also works for symbols and arrays.
    """
    arg1_bigger=whereNegative(arg0-arg1)
    return arg1_bigger*arg1+(1.-arg1_bigger)*arg0
399
def minimum(arg0,arg1):
    """
    Returns arg0 where arg1 is bigger than arg0, and arg1 otherwise.

    The selection is done arithmetically with the mask
    whereNegative(arg0-arg1), so it also works for symbols and arrays.
    """
    arg0_smaller=whereNegative(arg0-arg1)
    return arg0_smaller*arg0+(1.-arg0_smaller)*arg1
406
def outer(arg0,arg1):
    """
    Returns the outer product of arg0 and arg1.

    - if either argument is zero (see L{_testForZero}), 0 is returned;
    - if either argument is a symbol, a symbolic outer product is returned;
    - if either argument has rank 0, the plain product arg0*arg1 is returned;
    - two numarray.NumArray objects are combined with numarray.outer;
    - otherwise only two rank 1 arguments are supported. The result
      out[i,j]=arg0[i]*arg1[j] is an L{escript.Data} object on the function
      space of arg1, or of arg0 if arg1 is not an L{escript.Data} object.

    @param arg0: first argument
    @param arg1: second argument
    @raise ValueError: for argument ranks that are not supported yet
    """
    def shapeOf(a):
        # shape of a Data object, of an array, or () for plain numbers
        if hasattr(a,"getShape"): return tuple(a.getShape())
        if hasattr(a,"shape"): return tuple(a.shape)
        return ()

    if _testForZero(arg0) or _testForZero(arg1):
        return 0
    if isinstance(arg0,symbols.Symbol) or isinstance(arg1,symbols.Symbol):
        return symbols.Outer_Symbol(arg0,arg1)
    # the shape helper previously called here (_identifyShape) does not exist
    s0=shapeOf(arg0)
    s1=shapeOf(arg1)
    if s0==() or s1==():
        return arg0*arg1
    if isinstance(arg0,numarray.NumArray) and isinstance(arg1,numarray.NumArray):
        return numarray.outer(arg0,arg1)
    if len(s0)==1 and len(s1)==1:
        if isinstance(arg1,escript.Data):
            fs=arg1.getFunctionSpace()
        else:
            fs=arg0.getFunctionSpace()
        out=escript.Data(0,(s0[0],s1[0]),fs)
        for i in range(s0[0]):
            for j in range(s1[0]):
                out[i,j]=arg0[i]*arg1[j]
        return out
    raise ValueError("outer is not fully implemented yet.")
426
def interpolate(arg,where):
    """
    Interpolates arg onto the FunctionSpace where.

    A zero argument stays 0, a symbol becomes a symbolic interpolation node,
    and everything else is converted with escript.Data(arg,where).

    @param arg: interpolant
    @param where: FunctionSpace to interpolate to
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg,symbols.Symbol):
        return symbols.Interpolated_Symbol(arg,where)
    return escript.Data(arg,where)
440
def div(arg,where=None):
    """
    Returns the divergence of arg, computed as the trace of its gradient.

    @param arg: Data object representing the function whose divergence is
                to be calculated.
    @param where: FunctionSpace in which the gradient is calculated.
                  If not present or C{None} an appropriate default is used.
    """
    gradient=grad(arg,where)
    return trace(gradient)
451
def jump(arg):
    """
    Returns the jump of arg across a contact: arg interpolated onto
    FunctionOnContactOne minus arg interpolated onto FunctionOnContactZero.

    @param arg: Data object representing the function whose jump is to be
                calculated.
    """
    d=arg.getDomain()
    # the contact function spaces are constructed on the domain of arg
    # (d was previously fetched but never used)
    return arg.interpolate(escript.FunctionOnContactOne(d))-arg.interpolate(escript.FunctionOnContactZero(d))
461
462
def grad(arg,where=None):
    """
    Returns the spatial gradient of arg at where.

    A zero argument gives 0, a symbol gives a symbolic gradient, an object
    providing grad() evaluates itself (with where if given), and anything
    else is multiplied by 0. since it is treated as constant in space.

    @param arg: Data object representing the function whose gradient is
                to be calculated.
    @param where: FunctionSpace in which the gradient will be calculated.
                  If not present or C{None} an appropriate default is used.
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg,symbols.Symbol):
        return symbols.Grad_Symbol(arg,where)
    if not hasattr(arg,"grad"):
        return arg*0.
    if where==None:
        return arg.grad()
    return arg.grad(where)
483
def integrate(arg,where=None):
    """
    Returns the integral of the function represented by arg over its domain.

    A zero argument gives 0 and a symbol gives a symbolic integral. Otherwise
    arg is optionally converted to where, and its integrate() result is
    returned; for rank 0 arguments only the first entry is returned.

    @param arg: Data object representing the function which is integrated.
    @param where: FunctionSpace in which the integral is calculated.
                  If not present or C{None} an appropriate default is used.
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg,symbols.Symbol):
        return symbols.Integral_Symbol(arg,where)
    if not where==None:
        arg=escript.Data(arg,where)
    is_scalar=(arg.getRank()==0)
    integral=arg.integrate()
    if is_scalar:
        return integral[0]
    return integral
503
504 #=============================
505 #
506 # wrapper for various functions: if the argument has attribute the function name
507 # as an argument it calls the corresponding methods. Otherwise the corresponding
508 # numarray function is called.
509
510 # functions involving the underlying Domain:
511
512
513 # functions returning Data objects:
514
def transpose(arg,axis=None):
    """
    Returns the transpose of arg.

    - a symbol gives a symbolic transpose with the given axis;
    - an L{escript.Data} object must have rank 2 and is transposed entry by
      entry (axis is ignored in this case);
    - anything else is passed to numarray.transpose.

    @param arg: argument
    @param axis: axis parameter of the transpose. If C{None}, half the rank
                 of arg (rounded down) is used.
    @raise ValueError: for L{escript.Data} objects that are not of rank 2
    """
    if axis is None:
        r=0
        if hasattr(arg,"getRank"): r=arg.getRank()
        if hasattr(arg,"rank"): r=arg.rank
        axis=r//2
    if isinstance(arg,symbols.Symbol):
        # pass the requested/default axis; the rank was passed before, and
        # it was undefined whenever axis was given by the caller
        return symbols.Transpose_Symbol(arg,axis=axis)
    if isinstance(arg,escript.Data):
        # hack for transpose: only rank 2 objects, element by element
        r=arg.getRank()
        if r!=2: raise ValueError("Tranpose only avalaible for rank 2 objects")
        s=arg.getShape()
        out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())
        for i in range(s[0]):
            for j in range(s[1]):
                out[j,i]=arg[i,j]
        return out
    return numarray.transpose(arg,axis=axis)
542
def trace(arg,axis0=0,axis1=1):
    """
    Returns the trace of arg with respect to the axes axis0 and axis1.

    - a symbol gives a symbolic trace;
    - for an L{escript.Data} object the diagonal entries arg[i,i] are summed
      into a scalar L{escript.Data} object. axis0 and axis1 are only used to
      check that both axes have the same length, and the indexing arg[i,i]
      makes this meaningful for rank 2 objects only;
    - anything else is passed to numarray.trace.

    @param arg: argument
    @param axis0: first axis of the trace
    @param axis1: second axis of the trace
    @raise ValueError: if the two axes of an L{escript.Data} object differ in length
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Trace_Symbol(arg,axis0=axis0,axis1=axis1)
    elif isinstance(arg,escript.Data):
        # hack for trace
        s=arg.getShape()
        if s[axis0]!=s[axis1]:
            raise ValueError("illegal axis in trace")
        out=escript.Scalar(0.,arg.getFunctionSpace())
        for i in range(s[axis0]):
            out+=arg[i,i]
        return out
    else:
        return numarray.trace(arg,axis0=axis0,axis1=axis1)
565
def length(arg):
    """
    Returns the Euclidean length of arg, i.e. the square root of the sum of
    the squares of all its entries.

    - a Python int or float gives abs(arg);
    - an L{escript.Data} object: an empty object gives an empty
      L{escript.Data} object, rank 0 gives abs(arg), and ranks 1 to 4 give a
      scalar L{escript.Data} object;
    - anything else gives sqrt((arg**2).sum()).

    @param arg: argument
    @raise SystemError: for L{escript.Data} objects of rank larger than 4
    """
    if isinstance(arg,(int,float)):
        # ints have no sum() method, so handle all plain numbers here
        return abs(arg)
    elif isinstance(arg,escript.Data):
        if arg.isEmpty(): return escript.Data()
        rank=arg.getRank()
        if rank==0:
            return abs(arg)
        if rank>4:
            raise SystemError("length is not been fully implemented yet")
        shape=arg.getShape()
        out=escript.Scalar(0,arg.getFunctionSpace())
        if rank==1:
            for i in range(shape[0]):
                out+=arg[i]**2
        elif rank==2:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    out+=arg[i,j]**2
        elif rank==3:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    for k in range(shape[2]):
                        out+=arg[i,j,k]**2
        else:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    for k in range(shape[2]):
                        for l in range(shape[3]):
                            out+=arg[i,j,k,l]**2
        return sqrt(out)
    else:
        return sqrt((arg**2).sum())
608
def deviator(arg):
    """
    Returns the deviatoric part of the square rank 2 object arg, i.e.
    arg minus trace(arg)/n times the n x n identity.

    @param arg: square rank 2 L{escript.Data} object or array
    @raise ValueError: if arg is not of rank 2 or not square
    """
    if isinstance(arg,escript.Data):
        dims=arg.getShape()
    else:
        dims=arg.shape
    if not len(dims)==2:
        raise ValueError("Deviator requires rank 2 object")
    n=dims[0]
    if not n==dims[1]:
        raise ValueError("Deviator requires a square matrix")
    return arg-1./(n*1.)*trace(arg)*kronecker(n)
622
def inner(arg0,arg1):
    """
    Returns the inner product of arg0 and arg1, i.e. the sum of the products
    of corresponding entries.

    If neither argument is an L{escript.Data} object, the entrywise product
    arg0*arg1 is summed with its sum() method (plain numbers are returned as
    they are). Otherwise the rank of the L{escript.Data} argument (arg0 if
    both are L{escript.Data}) selects the summation; ranks 0 to 4 are
    supported.

    @param arg0: first argument
    @param arg1: second argument
    @raise SystemError: for L{escript.Data} arguments of rank larger than 4
    """
    if not isinstance(arg0,escript.Data) and not isinstance(arg1,escript.Data):
        # no Data object involved: there is no function space to build on
        prod=arg0*arg1
        if hasattr(prod,"sum"):
            return prod.sum()
        return prod
    if isinstance(arg0,escript.Data):
        arg=arg0
    else:
        arg=arg1
    rank=arg.getRank()
    if rank==0:
        return arg0*arg1
    if rank>4:
        raise SystemError("inner is not been implemented yet")
    shape=arg.getShape()
    out=escript.Scalar(0,arg.getFunctionSpace())
    if rank==1:
        for i in range(shape[0]):
            out+=arg0[i]*arg1[i]
    elif rank==2:
        for i in range(shape[0]):
            for j in range(shape[1]):
                out+=arg0[i,j]*arg1[i,j]
    elif rank==3:
        for i in range(shape[0]):
            for j in range(shape[1]):
                for k in range(shape[2]):
                    out+=arg0[i,j,k]*arg1[i,j,k]
    else:
        for i in range(shape[0]):
            for j in range(shape[1]):
                for k in range(shape[2]):
                    for l in range(shape[3]):
                        out+=arg0[i,j,k,l]*arg1[i,j,k,l]
    return out
661
def tensormult(arg0,arg1):
    """
    Placeholder for a general tensor product; not implemented.

    @raise SystemError: always
    """
    # TODO: check LinearPDE before implementing this
    message="tensormult is not implemented yet!"
    raise SystemError(message)
665
def matrixmult(arg0,arg1):
    """
    Returns the matrix product of arg0 and arg1.

    - two numarray.NumArray objects are multiplied with numarray.dot;
    - if only one argument is an L{escript.Data} object, the other one is
      converted to L{escript.Data} on the same function space. Then a
      rank 2 times rank 1 product (matrix-vector) and a rank 1 times rank 1
      product (see L{inner}) are supported.

    @param arg0: left factor
    @param arg1: right factor
    @raise SystemError: for other combinations of ranks
    """
    if isinstance(arg1,numarray.NumArray) and isinstance(arg0,numarray.NumArray):
        # numarray provides the matrix product as dot (there is no
        # numarray.matrixmult); the result must also be returned
        return numarray.dot(arg0,arg1)
    if isinstance(arg1,escript.Data) and not isinstance(arg0,escript.Data):
        arg0=escript.Data(arg0,arg1.getFunctionSpace())
    elif isinstance(arg0,escript.Data) and not isinstance(arg1,escript.Data):
        arg1=escript.Data(arg1,arg0.getFunctionSpace())
    if arg0.getRank()==2 and arg1.getRank()==1:
        out=escript.Data(0,(arg0.getShape()[0],),arg0.getFunctionSpace())
        for i in range(arg0.getShape()[0]):
            for j in range(arg0.getShape()[1]):
                # uses Data object slicing, plus Data * and += operators
                out[i]+=arg0[i,j]*arg1[j]
        return out
    elif arg0.getRank()==1 and arg1.getRank()==1:
        return inner(arg0,arg1)
    else:
        raise SystemError("matrixmult is not fully implemented yet!")
687
688 #=========================================================
689 # reduction operations:
690 #=========================================================
def sum(arg):
    """
    Returns the sum of arg by delegating to its sum() method.

    @param arg: object providing a sum() method
    """
    summation=arg.sum
    return summation()
696
def sup(arg):
    """
    Returns the supremum (largest value) of arg.

    Plain Python numbers are their own supremum. L{escript.Data} objects use
    their sup() method, and anything else uses its max() method.

    @param arg: argument
    """
    if isinstance(arg,(float,int)):
        return arg
    if isinstance(arg,escript.Data):
        return arg.sup()
    return arg.max()
707
def inf(arg):
    """
    Returns the infimum (smallest value) of arg.

    Plain Python numbers are their own infimum. L{escript.Data} objects use
    their inf() method, and anything else uses its min() method.

    @param arg: argument
    """
    if isinstance(arg,(float,int)):
        return arg
    if isinstance(arg,escript.Data):
        return arg.inf()
    return arg.min()
718
def L2(arg):
    """
    Returns the L2-norm of the argument.

    - a Python int or float gives abs(arg);
    - an L{escript.Data} object uses its L2() method;
    - anything else gives numarray.sqrt(dot(arg,arg)).

    @param arg: argument
    """
    if isinstance(arg,float) or isinstance(arg,int):
        return abs(arg)
    elif isinstance(arg,escript.Data):
        return arg.L2()
    else:
        # module name fixed (was misspelled "numarry", a NameError)
        return numarray.sqrt(dot(arg,arg))
731
def Lsup(arg):
    """
    Returns the supremum norm (largest absolute value) of arg.

    Plain Python numbers give abs(arg). L{escript.Data} objects use their
    Lsup() method, and anything else gives numarray.abs(arg).max().

    @param arg: argument
    """
    if isinstance(arg,(float,int)):
        return abs(arg)
    if isinstance(arg,escript.Data):
        return arg.Lsup()
    magnitudes=numarray.abs(arg)
    return magnitudes.max()
742
def dot(arg0,arg1):
    """
    Returns the dot product of arg0 and arg1.

    If either argument is an L{escript.Data} object, its dot() method is
    called with the other argument (arg0 is preferred). Otherwise
    numarray.dot is used.

    @param arg0: first argument
    @param arg1: second argument
    """
    for first,second in ((arg0,arg1),(arg1,arg0)):
        if isinstance(first,escript.Data):
            return first.dot(second)
    return numarray.dot(arg0,arg1)
754
def kronecker(d):
    """
    Returns the d x d identity matrix (Kronecker delta) as a float array.

    @param d: dimension, or an object whose getDim() method gives it
    """
    dim=d
    if hasattr(d,"getDim"):
        dim=d.getDim()
    return numarray.identity(dim)*1.
760
def unit(i,d):
    """
    Returns the unit vector of length d whose only nonzero entry (1.0) is
    at index i, as a numarray of type Float.

    @param i: index of the nonzero entry
    @param d: dimension
    """
    vec=numarray.zeros((d,),numarray.Float)
    vec[i]=1.0
    return vec

Properties

Name Value
svn:eol-style native
svn:keywords Author Date Id Revision

  ViewVC Help
Powered by ViewVC 1.1.26