/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Annotation of /trunk/escript/py_src/util.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 155 - (hide annotations)
Wed Nov 9 02:02:19 2005 UTC (14 years, 1 month ago) by jgs
File MIME type: text/x-python
File size: 21503 byte(s)
move all directories from trunk/esys2 into trunk and remove esys2

1 jgs 82 # $Id$
2    
3     ## @file util.py
4    
5     """
6 jgs 149 Utility functions for escript
7 jgs 123
8 jgs 149 @todo:
9 jgs 123
10 jgs 149 - binary operations @ (@=+,-,*,/,**)::
11     (a@b)[:,*]=a[:]@b[:,*] if rank(a)<rank(b)
12     (a@b)[:]=a[:]@b[:] if rank(a)=rank(b)
13     (a@b)[*,:]=a[*,:]@b[:] if rank(a)>rank(b)
14     - implementation of outer::
15     outer(a,b)[:,*]=a[:]*b[*]
16     - trace::
    trace(arg,axis0=a0,axis1=a1)(:,&,*)=sum_i arg(:,i,&,i,*) (i are at index a0 and a1)
18 jgs 82 """
19    
20     import numarray
21 jgs 102 import escript
22 jgs 150 import symbols
23     import os
24 jgs 124
25 jgs 123 #=========================================================
26     # some little helpers
27     #=========================================================
28     def _testForZero(arg):
29 jgs 149 """
30     Returns True is arg is considered to be zero.
31     """
32 jgs 123 if isinstance(arg,int):
33     return not arg>0
34     elif isinstance(arg,float):
35     return not arg>0.
36 jgs 150 elif isinstance(arg,numarray.NumArray):
37 jgs 123 a=abs(arg)
38     while isinstance(a,numarray.NumArray): a=numarray.sometrue(a)
39     return not a>0
40     else:
41     return False
42    
43 jgs 150 #=========================================================
def saveVTK(filename,domain=None,**data):
    """
    Writes L{Data} objects into a file using the VTK XML file format.

    Example::

        tmp=Scalar(..)
        v=Vector(..)
        saveVTK("solution.xml",temperature=tmp,velocity=v)

    tmp and v are written into "solution.xml" where tmp is named "temperature"
    and v is named "velocity".

    @param filename: file name of the output file
    @type filename: C{str}
    @param domain: domain of the L{Data} objects. If not specified, the domain
                   of the given L{Data} objects is used.
    @type domain: L{escript.Domain}
    @keyword <name>: writes the assigned value to the VTK file using <name> as identifier.
    @type <name>: L{Data} object.
    @raise ValueError: if no domain is given and none of the L{Data} objects
                       is non-empty.
    @note: The data objects have to be defined on the same domain. They may not
           be in the same L{FunctionSpace} but one cannot expect that all
           L{FunctionSpace} can be mixed. Typically, data on the boundary and
           data on the interior cannot be mixed.
    """
    if domain is None:
        # take the domain from the non-empty Data objects; whichever is
        # visited last wins (they are required to share one domain anyway)
        for value in data.values():
            if not value.isEmpty(): domain=value.getFunctionSpace().getDomain()
    if domain is None:
        raise ValueError("no domain detected.")
    domain.saveVTK(filename,data)
71 jgs 154
72 jgs 150 #=========================================================
def saveDX(filename,domain=None,**data):
    """
    Writes L{Data} objects into a file using the DX file format.

    Example::

        tmp=Scalar(..)
        v=Vector(..)
        saveDX("solution.dx",temperature=tmp,velocity=v)

    tmp and v are written into "solution.dx" where tmp is named "temperature"
    and v is named "velocity".

    @param filename: file name of the output file
    @type filename: C{str}
    @param domain: domain of the L{Data} objects. If not specified, the domain
                   of the given L{Data} objects is used.
    @type domain: L{escript.Domain}
    @keyword <name>: writes the assigned value to the DX file using <name> as
                     identifier. The identifier can be used to select the data
                     set when data are imported into DX.
    @type <name>: L{Data} object.
    @raise ValueError: if no domain is given and none of the L{Data} objects
                       is non-empty.
    @note: The data objects have to be defined on the same domain. They may not
           be in the same L{FunctionSpace} but one cannot expect that all
           L{FunctionSpace} can be mixed. Typically, data on the boundary and
           data on the interior cannot be mixed.
    """
    if domain is None:
        # take the domain from the non-empty Data objects; whichever is
        # visited last wins (they are required to share one domain anyway)
        for value in data.values():
            if not value.isEmpty(): domain=value.getFunctionSpace().getDomain()
    if domain is None:
        raise ValueError("no domain detected.")
    domain.saveDX(filename,data)
100 jgs 154
101 jgs 123 #=========================================================
102    
def exp(arg):
    """
    Returns the exponential function applied to arg.

    Symbols are wrapped in an Exp_Symbol, objects with their own exp() method
    evaluate it themselves, and anything else is handed to numarray.exp.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Exp_Symbol(arg)
    if hasattr(arg,"exp"):
        return arg.exp()
    return numarray.exp(arg)
115 jgs 82
def sqrt(arg):
    """
    Returns the square root of arg.

    Symbols are wrapped in a Sqrt_Symbol, objects with their own sqrt()
    method evaluate it themselves, and anything else is handed to
    numarray.sqrt.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Sqrt_Symbol(arg)
    if hasattr(arg,"sqrt"):
        return arg.sqrt()
    return numarray.sqrt(arg)
128 jgs 82
def log(arg):
    """
    Returns the base-10 logarithm of arg.

    Symbols are wrapped in a Log_Symbol, objects with their own log() method
    evaluate it themselves, and anything else is handed to numarray.log10.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Log_Symbol(arg)
    if hasattr(arg,"log"):
        return arg.log()
    return numarray.log10(arg)
141 jgs 82
def ln(arg):
    """
    Returns the natural logarithm of arg.

    Symbols are wrapped in an Ln_Symbol, objects with their own ln() method
    evaluate it themselves, and anything else is handed to numarray.log.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Ln_Symbol(arg)
    if hasattr(arg,"ln"):
        return arg.ln()
    return numarray.log(arg)
154    
def sin(arg):
    """
    Returns the sine of arg.

    Symbols are wrapped in a Sin_Symbol, objects with their own sin() method
    evaluate it themselves, and anything else is handed to numarray.sin.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Sin_Symbol(arg)
    if hasattr(arg,"sin"):
        return arg.sin()
    return numarray.sin(arg)
167 jgs 82
def cos(arg):
    """
    Returns the cosine of arg.

    Symbols are wrapped in a Cos_Symbol, objects with their own cos() method
    evaluate it themselves, and anything else is handed to numarray.cos.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Cos_Symbol(arg)
    if hasattr(arg,"cos"):
        return arg.cos()
    return numarray.cos(arg)
180 jgs 82
def tan(arg):
    """
    Returns the tangent of arg.

    Symbols are wrapped in a Tan_Symbol, objects with their own tan() method
    evaluate it themselves, and anything else is handed to numarray.tan.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Tan_Symbol(arg)
    if hasattr(arg,"tan"):
        return arg.tan()
    return numarray.tan(arg)
193 jgs 82
def asin(arg):
    """
    Applies the asin (inverse sine) function to arg.

    Symbols are wrapped in an Asin_Symbol, objects with their own asin()
    method evaluate it themselves, and anything else is handed to
    numarray.arcsin.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Asin_Symbol(arg)
    elif hasattr(arg,"asin"):
        return arg.asin()
    else:
        # numarray (like Numeric) names the inverse trigonometric ufuncs
        # arcsin, arccos, ...; there is no numarray.asin
        return numarray.arcsin(arg)
206    
def acos(arg):
    """
    Applies the acos (inverse cosine) function to arg.

    Symbols are wrapped in an Acos_Symbol, objects with their own acos()
    method evaluate it themselves, and anything else is handed to
    numarray.arccos.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Acos_Symbol(arg)
    elif hasattr(arg,"acos"):
        return arg.acos()
    else:
        # numarray names this ufunc arccos; there is no numarray.acos
        return numarray.arccos(arg)
219    
def atan(arg):
    """
    Applies the atan (inverse tangent) function to arg.

    Symbols are wrapped in an Atan_Symbol, objects with their own atan()
    method evaluate it themselves, and anything else is handed to
    numarray.arctan.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Atan_Symbol(arg)
    elif hasattr(arg,"atan"):
        return arg.atan()
    else:
        # numarray names this ufunc arctan; there is no numarray.atan
        return numarray.arctan(arg)
232    
def sinh(arg):
    """
    Returns the hyperbolic sine of arg.

    Symbols are wrapped in a Sinh_Symbol, objects with their own sinh()
    method evaluate it themselves, and anything else is handed to
    numarray.sinh.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Sinh_Symbol(arg)
    if hasattr(arg,"sinh"):
        return arg.sinh()
    return numarray.sinh(arg)
245    
def cosh(arg):
    """
    Returns the hyperbolic cosine of arg.

    Symbols are wrapped in a Cosh_Symbol, objects with their own cosh()
    method evaluate it themselves, and anything else is handed to
    numarray.cosh.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Cosh_Symbol(arg)
    if hasattr(arg,"cosh"):
        return arg.cosh()
    return numarray.cosh(arg)
258    
def tanh(arg):
    """
    Returns the hyperbolic tangent of arg.

    Symbols are wrapped in a Tanh_Symbol, objects with their own tanh()
    method evaluate it themselves, and anything else is handed to
    numarray.tanh.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Tanh_Symbol(arg)
    if hasattr(arg,"tanh"):
        return arg.tanh()
    return numarray.tanh(arg)
271    
def asinh(arg):
    """
    Applies the asinh (inverse hyperbolic sine) function to arg.

    Symbols are wrapped in an Asinh_Symbol, objects with their own asinh()
    method evaluate it themselves, and anything else is handed to
    numarray.arcsinh.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Asinh_Symbol(arg)
    elif hasattr(arg,"asinh"):
        return arg.asinh()
    else:
        # numarray names this ufunc arcsinh; there is no numarray.asinh
        return numarray.arcsinh(arg)
284    
def acosh(arg):
    """
    Applies the acosh (inverse hyperbolic cosine) function to arg.

    Symbols are wrapped in an Acosh_Symbol, objects with their own acosh()
    method evaluate it themselves, and anything else is handed to
    numarray.arccosh.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Acosh_Symbol(arg)
    elif hasattr(arg,"acosh"):
        return arg.acosh()
    else:
        # numarray names this ufunc arccosh; there is no numarray.acosh
        return numarray.arccosh(arg)
297    
def atanh(arg):
    """
    Applies the atanh (inverse hyperbolic tangent) function to arg.

    Symbols are wrapped in an Atanh_Symbol, objects with their own atanh()
    method evaluate it themselves, and anything else is handed to
    numarray.arctanh.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Atanh_Symbol(arg)
    elif hasattr(arg,"atanh"):
        return arg.atanh()
    else:
        # numarray names this ufunc arctanh; there is no numarray.atanh
        return numarray.arctanh(arg)
310    
def sign(arg):
    """
    Applies the sign function to arg.

    For plain Python numbers 1., -1. or 0. is returned. Symbols are wrapped
    in a Sign_Symbol, objects with their own sign() method evaluate it
    themselves, and arrays get numarray.greater(arg,0)-numarray.less(arg,0)
    computed against a zero array of the same shape.

    @param arg: argument
    """
    if isinstance(arg,(int,float)):
        # plain numbers have no .shape; they used to raise AttributeError below
        if arg>0:
            return 1.
        elif arg<0:
            return -1.
        else:
            return 0.
    elif isinstance(arg,symbols.Symbol):
        return symbols.Sign_Symbol(arg)
    elif hasattr(arg,"sign"):
        return arg.sign()
    else:
        zero=numarray.zeros(arg.shape,numarray.Float)
        return numarray.greater(arg,zero)-numarray.less(arg,zero)
324 jgs 82
def maxval(arg):
    """
    Returns the maximum value of argument arg.

    Symbols are wrapped in a Max_Symbol. Otherwise the object's maxval()
    method is preferred, then its max() method; an object with neither is
    returned unchanged.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Max_Symbol(arg)
    if hasattr(arg,"maxval"):
        return arg.maxval()
    if hasattr(arg,"max"):
        return arg.max()
    return arg
339 jgs 82
def minval(arg):
    """
    Returns the minimum value of argument arg.

    Symbols are wrapped in a Min_Symbol. Otherwise the object's minval()
    method is preferred, then its min() method; an object with neither is
    returned unchanged.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Min_Symbol(arg)
    elif hasattr(arg,"minval"):
        # previously tested for 'maxval' before calling minval()
        return arg.minval()
    elif hasattr(arg,"min"):
        return arg.min()
    else:
        return arg
354 jgs 82
def wherePositive(arg):
    """
    Returns a mask of the positive values of argument arg.

    For plain numbers the result is 1. if arg is positive and 0. otherwise.
    Arguments considered to be zero give 0. Symbols are wrapped in a
    WherePositive_Symbol, and objects with their own wherePositive() method
    evaluate it themselves. For arrays the result of numarray.greater against
    a zero array is returned.

    @param arg: argument
    """
    if _testForZero(arg):
        return 0
    elif isinstance(arg,symbols.Symbol):
        return symbols.WherePositive_Symbol(arg)
    elif hasattr(arg,"wherePositive"):
        # previously called arg.minval() by mistake
        return arg.wherePositive()
    elif hasattr(arg,"shape"):
        # previously this branch repeated the 'wherePositive' test (so it was
        # unreachable) and dropped the result (missing return)
        return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))
    else:
        if arg>0:
            return 1.
        else:
            return 0.
374 jgs 82
def whereNegative(arg):
    """
    Returns a mask of the negative values of argument arg.

    For plain numbers the result is 1. if arg is negative and 0. otherwise.
    Arguments considered to be zero give 0. Symbols are wrapped in a
    WhereNegative_Symbol, and objects with their own whereNegative() method
    evaluate it themselves. For arrays the result of numarray.less against a
    zero array is returned.

    @param arg: argument
    """
    if _testForZero(arg):
        return 0
    elif isinstance(arg,symbols.Symbol):
        return symbols.WhereNegative_Symbol(arg)
    elif hasattr(arg,"whereNegative"):
        return arg.whereNegative()
    elif hasattr(arg,"shape"):
        # the result used to be dropped (missing return), yielding None
        return numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
    else:
        if arg<0:
            return 1.
        else:
            return 0.
394 jgs 82
def maximum(arg0,arg1):
    """
    Returns arg1 where arg1 is bigger than arg0, otherwise arg0.

    The selection is done arithmetically with the mask whereNegative(arg0-arg1).

    @param arg0: first argument
    @param arg1: second argument
    """
    take_second=whereNegative(arg0-arg1)
    from_second=take_second*arg1
    from_first=(1.-take_second)*arg0
    return from_second+from_first
401    
def minimum(arg0,arg1):
    """
    Returns arg0 where arg1 is bigger than arg0, otherwise arg1.

    The selection is done arithmetically with the mask whereNegative(arg0-arg1).

    @param arg0: first argument
    @param arg1: second argument
    """
    take_first=whereNegative(arg0-arg1)
    from_first=take_first*arg0
    from_second=(1.-take_first)*arg1
    return from_first+from_second
408    
def _identifyShape(arg):
    """
    Returns the shape of arg as a tuple: getShape() if available, else the
    .shape attribute, else () (e.g. for plain Python numbers).

    @param arg: argument
    """
    if hasattr(arg,"getShape"):
        return tuple(arg.getShape())
    elif hasattr(arg,"shape"):
        return tuple(arg.shape)
    else:
        return ()

def outer(arg0,arg1):
    """
    Returns the outer product of arg0 and arg1.

    If either argument is considered to be zero, 0 is returned. Symbols give
    an Outer_Symbol. If either argument is a scalar, the plain product is
    returned. Two numarray arrays use numarray.outer. For other objects only
    two rank-1 arguments are supported; the result is a L{escript.Data} object
    on the L{FunctionSpace} of arg1.

    @param arg0: first argument
    @param arg1: second argument
    @raise ValueError: for rank combinations that are not implemented
    """
    if _testForZero(arg0) or _testForZero(arg1):
        return 0
    elif isinstance(arg0,symbols.Symbol) or isinstance(arg1,symbols.Symbol):
        return symbols.Outer_Symbol(arg0,arg1)
    # _identifyShape was referenced here but not defined anywhere in this
    # module; it is now defined above
    elif _identifyShape(arg0)==() or _identifyShape(arg1)==():
        return arg0*arg1
    elif isinstance(arg0,numarray.NumArray) and isinstance(arg1,numarray.NumArray):
        return numarray.outer(arg0,arg1)
    elif arg0.getRank()==1 and arg1.getRank()==1:
        out=escript.Data(0,(arg0.getShape()[0],arg1.getShape()[0]),arg1.getFunctionSpace())
        for i in range(arg0.getShape()[0]):
            for j in range(arg1.getShape()[0]):
                out[i,j]=arg0[i]*arg1[j]
        return out
    else:
        raise ValueError("outer is not fully implemented yet.")
428    
def interpolate(arg,where):
    """
    Interpolates arg into the FunctionSpace where.

    Arguments considered to be zero give 0, symbols give an
    Interpolated_Symbol, and anything else is converted with
    escript.Data(arg,where).

    @param arg: interpolant
    @param where: FunctionSpace to interpolate to
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg,symbols.Symbol):
        return symbols.Interpolated_Symbol(arg,where)
    return escript.Data(arg,where)
442 jgs 82
def div(arg,where=None):
    """
    Returns the divergence of arg at where, computed as the trace of the
    gradient of arg.

    @param arg: Data object representing the function whose divergence is
                to be calculated.
    @param where: FunctionSpace in which the divergence will be calculated.
                  If not present or C{None} an appropriate default is used.
    """
    gradient=grad(arg,where)
    return trace(gradient)
453    
def jump(arg):
    """
    Returns the jump of arg across a discontinuity.

    The jump is arg interpolated to the FunctionOnContactOne function space
    of its domain minus arg interpolated to the FunctionOnContactZero
    function space of the same domain.

    @param arg: Data object representing the function whose jump is to be
                calculated.
    """
    d=arg.getDomain()
    # the contact function spaces are built on the domain of arg; d used to
    # be computed but never passed to the constructors
    return arg.interpolate(escript.FunctionOnContactOne(d))-arg.interpolate(escript.FunctionOnContactZero(d))
463    
464    
def grad(arg,where=None):
    """
    Returns the spatial gradient of arg at where.

    Arguments considered to be zero give 0 and symbols give a Grad_Symbol.
    Objects with a grad() method evaluate it (with where, if given). Any
    other object yields arg*0.

    @param arg: Data object representing the function whose gradient is to
                be calculated.
    @param where: FunctionSpace in which the gradient will be calculated.
                  If not present or C{None} an appropriate default is used.
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg,symbols.Symbol):
        return symbols.Grad_Symbol(arg,where)
    if not hasattr(arg,"grad"):
        return arg*0.
    if where==None:
        return arg.grad()
    return arg.grad(where)
485 jgs 102
def integrate(arg,where=None):
    """
    Returns the integral of the function represented by Data object arg over
    its domain.

    Arguments considered to be zero give 0 and symbols give an
    Integral_Symbol. For rank-0 arguments the first entry of arg.integrate()
    is returned, otherwise the full result.

    @param arg: Data object representing the function which is integrated.
    @param where: FunctionSpace in which the integral is calculated.
                  If not present or C{None} an appropriate default is used.
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg,symbols.Symbol):
        return symbols.Integral_Symbol(arg,where)
    if not where==None:
        arg=escript.Data(arg,where)
    is_scalar=arg.getRank()==0
    if is_scalar:
        return arg.integrate()[0]
    return arg.integrate()
505 jgs 82
506 jgs 123 #=============================
507     #
# wrapper for various functions: if the argument has an attribute with the
# function's name, the corresponding method is called. Otherwise the
# corresponding numarray function is called.
511    
512 jgs 123 # functions involving the underlying Domain:
513    
514    
515     # functions returning Data objects:
516    
def transpose(arg,axis=None):
    """
    Returns the transpose of arg.

    @param arg: argument: a symbol, a L{escript.Data} object of rank 2, or
                anything numarray.transpose accepts
    @param axis: axis argument forwarded to the transpose. If not present or
                 C{None} it defaults to half the rank of arg (rounded down).
    @raise ValueError: if arg is a L{escript.Data} object whose rank is not 2
    """
    if axis==None:
        r=0
        if hasattr(arg,"getRank"): r=arg.getRank()
        if hasattr(arg,"rank"): r=arg.rank
        # '//' keeps the integer division that '/' performed on ints in Python 2
        axis=r//2
    if isinstance(arg,symbols.Symbol):
        # previously passed axis=r: the rank instead of the axis, and a
        # NameError whenever axis was given explicitly
        return symbols.Transpose_Symbol(arg,axis=axis)
    if isinstance(arg,escript.Data):
        # hack for transpose: only rank 2 is supported, copied component-wise
        r=arg.getRank()
        if r!=2: raise ValueError("Tranpose only avalaible for rank 2 objects")
        s=arg.getShape()
        out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())
        for i in range(s[0]):
            for j in range(s[1]):
                out[j,i]=arg[i,j]
        return out
        # end hack for transpose
    # NOTE(review): numarray.transpose may expect an 'axes' permutation rather
    # than an integer 'axis' keyword -- confirm against the numarray API
    return numarray.transpose(arg,axis=axis)
544 jgs 82
def trace(arg,axis0=0,axis1=1):
    """
    Returns the trace of arg with respect to the axes axis0 and axis1.

    @param arg: argument
    @param axis0: first axis of the trace
    @param axis1: second axis of the trace
    @raise ValueError: if arg is a L{escript.Data} object that is not of
                       rank 2, or whose two axes have different lengths
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Trace_Symbol(arg,axis0=axis0,axis1=axis1)
    elif isinstance(arg,escript.Data):
        # hack for trace: only the rank-2 case is implemented. For other ranks
        # arg[i,i] below is not the diagonal entry, so the old code silently
        # returned a wrong result or failed in an obscure way.
        s=arg.getShape()
        if len(s)!=2:
            raise ValueError("trace is only implemented for rank 2 Data objects")
        if s[axis0]!=s[axis1]:
            raise ValueError("illegal axis in trace")
        out=escript.Scalar(0.,arg.getFunctionSpace())
        for i in range(s[axis0]):
            out+=arg[i,i]
        return out
        # end hack for trace
    else:
        # numarray.trace follows the Numeric signature (a, offset, axis1, axis2);
        # it has no axis0/axis1 keywords, so pass offset 0 and the axes positionally
        return numarray.trace(arg,0,axis0,axis1)
567 jgs 82
def length(arg):
    """
    Returns the length of arg: the square root of the sum of the squares of
    all its components, or the absolute value for scalars.

    @param arg: a Python number, a L{escript.Data} object (rank 0 to 4), or
                an array providing ** and sum()
    @raise SystemError: if arg is a L{escript.Data} object of rank greater than 4
    """
    if isinstance(arg,(int,float)):
        # ints used to fall through to (arg**2).sum() and raise AttributeError
        return abs(arg)
    elif isinstance(arg,escript.Data):
        if arg.isEmpty(): return escript.Data()
        r=arg.getRank()
        s=arg.getShape()
        if r==0:
            return abs(arg)
        if r>4:
            raise SystemError("length is not been fully implemented yet")
        out=escript.Scalar(0,arg.getFunctionSpace())
        if r==1:
            for i in range(s[0]):
                out+=arg[i]**2
        elif r==2:
            for i in range(s[0]):
                for j in range(s[1]):
                    out+=arg[i,j]**2
        elif r==3:
            for i in range(s[0]):
                for j in range(s[1]):
                    for k in range(s[2]):
                        out+=arg[i,j,k]**2
        else:
            for i in range(s[0]):
                for j in range(s[1]):
                    for k in range(s[2]):
                        for l in range(s[3]):
                            out+=arg[i,j,k,l]**2
        return sqrt(out)
    else:
        return sqrt((arg**2).sum())
609    
def deviator(arg):
    """
    Returns the deviatoric part of the square rank-2 argument arg: arg minus
    trace(arg)/d times kronecker(d), where d is the size of arg.

    @param arg: square rank-2 argument (L{escript.Data} or numarray array)
    @raise ValueError: if arg is not of rank 2 or not square
    """
    shape=arg.getShape() if isinstance(arg,escript.Data) else arg.shape
    if len(shape)!=2:
        raise ValueError("Deviator requires rank 2 object")
    size=shape[0]
    if size!=shape[1]:
        raise ValueError("Deviator requires a square matrix")
    return arg-1./(size*1.)*trace(arg)*kronecker(size)
623    
def inner(arg0,arg1):
    """
    Returns the inner product of arg0 and arg1: the sum over all components
    of the component-wise product.

    @param arg0: first argument
    @param arg1: second argument; expected to have the same shape as arg0
    @raise SystemError: if a L{escript.Data} argument has rank greater than 4
    """
    if isinstance(arg0,(int,float)) and isinstance(arg1,(int,float)):
        return arg0*arg1
    if isinstance(arg0,escript.Data):
        arg=arg0
    elif isinstance(arg1,escript.Data):
        arg=arg1
    else:
        # neither argument is a Data object (e.g. two numarray arrays);
        # previously this failed calling getFunctionSpace() on arg1
        return (arg0*arg1).sum()
    r=arg.getRank()
    s=arg.getShape()
    if r==0:
        return arg0*arg1
    if r>4:
        raise SystemError("inner is not been implemented yet")
    out=escript.Scalar(0,arg.getFunctionSpace())
    if r==1:
        for i in range(s[0]):
            out+=arg0[i]*arg1[i]
    elif r==2:
        for i in range(s[0]):
            for j in range(s[1]):
                out+=arg0[i,j]*arg1[i,j]
    elif r==3:
        for i in range(s[0]):
            for j in range(s[1]):
                for k in range(s[2]):
                    out+=arg0[i,j,k]*arg1[i,j,k]
    else:
        for i in range(s[0]):
            for j in range(s[1]):
                for k in range(s[2]):
                    for l in range(s[3]):
                        out+=arg0[i,j,k,l]*arg1[i,j,k,l]
    return out
662 jgs 113
def tensormult(arg0,arg1):
    """
    Tensor product of arg0 and arg1. Not implemented yet.

    @param arg0: first argument (unused)
    @param arg1: second argument (unused)
    @raise SystemError: always
    """
    # TODO: check LinearPDE for the intended semantics
    raise SystemError("tensormult is not implemented yet!")
666    
def matrixmult(arg0,arg1):
    """
    Returns the matrix product of arg0 and arg1.

    Two numarray arrays are multiplied with numarray.dot. Otherwise a
    non-Data argument is converted to a L{escript.Data} object on the
    L{FunctionSpace} of the other argument. The product is then formed for a
    rank-2 arg0 and a rank-1 arg1 (matrix-vector product) or for two rank-1
    arguments (inner product).

    @param arg0: left factor
    @param arg1: right factor
    @raise SystemError: for rank combinations that are not implemented
    """
    if isinstance(arg1,numarray.NumArray) and isinstance(arg0,numarray.NumArray):
        # the result used to be discarded (missing return) and the call went
        # to numarray.matrixmult; numarray.dot (as used by dot() in this
        # module) forms the matrix product
        return numarray.dot(arg0,arg1)
    if isinstance(arg1,escript.Data) and not isinstance(arg0,escript.Data):
        arg0=escript.Data(arg0,arg1.getFunctionSpace())
    elif isinstance(arg0,escript.Data) and not isinstance(arg1,escript.Data):
        arg1=escript.Data(arg1,arg0.getFunctionSpace())
    if arg0.getRank()==2 and arg1.getRank()==1:
        out=escript.Data(0,(arg0.getShape()[0],),arg0.getFunctionSpace())
        for i in range(arg0.getShape()[0]):
            for j in range(arg0.getShape()[1]):
                # uses Data object slicing, plus Data * and += operators
                out[i]+=arg0[i,j]*arg1[j]
        return out
    elif arg0.getRank()==1 and arg1.getRank()==1:
        return inner(arg0,arg1)
    else:
        raise SystemError("matrixmult is not fully implemented yet!")
688 jgs 124
689 jgs 123 #=========================================================
690 jgs 102 # reduction operations:
691 jgs 123 #=========================================================
def sum(arg):
    """
    Returns the sum of arg as computed by its own sum() method.

    @param arg: object providing a sum() method
    """
    summation=arg.sum
    return summation()
697    
def sup(arg):
    """
    Returns the supremum of arg.

    Python numbers are returned unchanged, L{escript.Data} objects use their
    sup() method, and anything else its max() method.

    @param arg: argument
    """
    if isinstance(arg,(float,int)):
        return arg
    if isinstance(arg,escript.Data):
        return arg.sup()
    return arg.max()
708 jgs 82
def inf(arg):
    """
    Returns the infimum of arg.

    Python numbers are returned unchanged, L{escript.Data} objects use their
    inf() method, and anything else its min() method.

    @param arg: argument
    """
    if isinstance(arg,(float,int)):
        return arg
    if isinstance(arg,escript.Data):
        return arg.inf()
    return arg.min()
719 jgs 82
def L2(arg):
    """
    Returns the L2-norm of the argument.

    Python numbers give their absolute value and L{escript.Data} objects use
    their L2() method. Anything else gives numarray.sqrt(dot(arg,arg)).

    @param arg: argument
    """
    if isinstance(arg,(float,int)):
        return abs(arg)
    elif isinstance(arg,escript.Data):
        return arg.L2()
    else:
        # was 'numarry.sqrt' (typo), which raised NameError
        # NOTE(review): dot(arg,arg) is the sum of squares only for rank-1
        # arrays -- confirm the intended behavior for higher ranks
        return numarray.sqrt(dot(arg,arg))
732 jgs 82
def Lsup(arg):
    """
    Returns the supremum of the absolute value of arg.

    Python numbers give abs(arg) and L{escript.Data} objects use their Lsup()
    method. Anything else gives numarray.abs(arg).max().

    @param arg: argument
    """
    if isinstance(arg,(float,int)):
        return abs(arg)
    if isinstance(arg,escript.Data):
        return arg.Lsup()
    return numarray.abs(arg).max()
743 jgs 82
def dot(arg0,arg1):
    """
    Returns the dot product of arg0 and arg1.

    If arg0 is a L{escript.Data} object, arg0.dot(arg1) is returned. Else, if
    arg1 is one, arg1.dot(arg0) is returned. Otherwise numarray.dot is used.

    @param arg0: first argument
    @param arg1: second argument
    """
    # NOTE(review): when only arg1 is a Data object the operands are swapped;
    # this matches the plain dot product only if Data.dot is symmetric -- confirm
    for first,second in ((arg0,arg1),(arg1,arg0)):
        if isinstance(first,escript.Data):
            return first.dot(second)
    return numarray.dot(arg0,arg1)
755 jgs 113
def kronecker(d):
    """
    Returns the d x d identity matrix (Kronecker delta) as a numarray array.
    It is multiplied by 1. to obtain floating point entries.

    @param d: dimension, or an object providing getDim()
    """
    dimension=d.getDim() if hasattr(d,"getDim") else d
    return numarray.identity(dimension)*1.
761 jgs 117
def unit(i,d):
    """
    Returns a unit vector of dimension d whose only nonzero entry (1.0) is at
    index i.

    @param i: index of the nonzero entry
    @param d: dimension
    """
    vector=numarray.zeros((d,),numarray.Float)
    vector[i]=1.0
    return vector

Properties

Name Value
svn:eol-style native
svn:keywords Author Date Id Revision

  ViewVC Help
Powered by ViewVC 1.1.26