/[escript]/trunk/escript/py_src/util.py
ViewVC logotype

Annotation of /trunk/escript/py_src/util.py

Parent Directory | Revision Log


Revision 153 - (hide annotations)
Tue Oct 25 01:51:20 2005 UTC (14 years ago) by jgs
Original Path: trunk/esys2/escript/py_src/util.py
File MIME type: text/x-python
File size: 21502 byte(s)
Merge of development branch dev-02 back to main trunk on 2005-10-25

1 jgs 82 # $Id$
2    
3     ## @file util.py
4    
5     """
6 jgs 149 Utility functions for escript
7 jgs 123
8 jgs 149 @todo:
9 jgs 123
10 jgs 149 - binary operations @ (@=+,-,*,/,**)::
11     (a@b)[:,*]=a[:]@b[:,*] if rank(a)<rank(b)
12     (a@b)[:]=a[:]@b[:] if rank(a)=rank(b)
13     (a@b)[*,:]=a[*,:]@b[:] if rank(a)>rank(b)
14     - implementation of outer::
15     outer(a,b)[:,*]=a[:]*b[*]
16     - trace::
17     trace(arg,axis0=a0,axis1=a1)(:,&,*)=sum_i trace(:,i,&,i,*) (i are at index a0 and a1)
18 jgs 82 """
19    
20     import numarray
21 jgs 102 import escript
22 jgs 150 import symbols
23     import os
24 jgs 124
25 jgs 123 #=========================================================
26     # some little helpers
27     #=========================================================
28     def _testForZero(arg):
29 jgs 149 """
30     Returns True is arg is considered to be zero.
31     """
32 jgs 123 if isinstance(arg,int):
33     return not arg>0
34     elif isinstance(arg,float):
35     return not arg>0.
36 jgs 150 elif isinstance(arg,numarray.NumArray):
37 jgs 123 a=abs(arg)
38     while isinstance(a,numarray.NumArray): a=numarray.sometrue(a)
39     return not a>0
40     else:
41     return False
42    
43 jgs 150 #=========================================================
def saveVTK(filename, domain=None, **data):
    """
    Writes L{Data} objects into a file using the VTK XML file format.

    Example::

        tmp=Scalar(..)
        v=Vector(..)
        saveVTK("solution.xml",temperature=tmp,velocity=v)

    tmp and v are written into "solution.xml" where tmp is named "temperature"
    and v is named "velocity".

    @param filename: file name of the output file
    @type filename: C{str}
    @param domain: domain of the L{Data} objects. If not specified, the domain
                   of the first non-empty given L{Data} object is used.
    @type domain: L{escript.Domain}
    @keyword <name>: writes the assigned value to the VTK file using <name> as identifier.
    @type <name>: L{Data} object.
    @raise ValueError: if no domain is given and none can be detected from the data.
    @note: The data objects have to be defined on the same domain. They may not
           be in the same L{FunctionSpace} but one cannot expect that all
           L{FunctionSpace} can be mixed. Typically, data on the boundary and
           data on the interior cannot be mixed.
    """
    if domain is None:
        # all data objects must live on the same domain, so the first
        # non-empty one is sufficient to determine it
        for value in data.values():
            if not value.isEmpty():
                domain = value.getFunctionSpace().getDomain()
                break
    if domain is None:
        raise ValueError("no domain detected.")
    domain.saveVTK(filename, data)
71 jgs 150 #=========================================================
def saveDX(filename, domain=None, **data):
    """
    Writes L{Data} objects into a file using the DX file format.

    Example::

        tmp=Scalar(..)
        v=Vector(..)
        saveDX("solution.dx",temperature=tmp,velocity=v)

    tmp and v are written into "solution.dx" where tmp is named "temperature"
    and v is named "velocity".

    @param filename: file name of the output file
    @type filename: C{str}
    @param domain: domain of the L{Data} objects. If not specified, the domain
                   of the first non-empty given L{Data} object is used.
    @type domain: L{escript.Domain}
    @keyword <name>: writes the assigned value to the DX file using <name> as
                     identifier. The identifier can be used to select the data
                     set when data are imported into DX.
    @type <name>: L{Data} object.
    @raise ValueError: if no domain is given and none can be detected from the data.
    @note: The data objects have to be defined on the same domain. They may not
           be in the same L{FunctionSpace} but one cannot expect that all
           L{FunctionSpace} can be mixed. Typically, data on the boundary and
           data on the interior cannot be mixed.
    """
    if domain is None:
        # all data objects must live on the same domain, so the first
        # non-empty one is sufficient to determine it
        for value in data.values():
            if not value.isEmpty():
                domain = value.getFunctionSpace().getDomain()
                break
    if domain is None:
        raise ValueError("no domain detected.")
    domain.saveDX(filename, data)
99 jgs 123 #=========================================================
100    
def _applyElementWise(arg, symbolName, methodName, numarrayName):
    """
    Applies an element-wise function to arg, dispatching on the type of arg.

    @param arg: argument
    @param symbolName: name of the class in the L{symbols} module representing
                       the function applied to a symbol
    @param methodName: name of the method of arg that is called if arg has such
                       a method (e.g. an L{escript.Data} object)
    @param numarrayName: name of the numarray function used for all other arguments
    @return: C{symbols.<symbolName>(arg)} if arg is a symbol, C{arg.<methodName>()}
             if arg has that method and C{numarray.<numarrayName>(arg)} otherwise
    """
    # names are looked up lazily so that only the branch actually taken
    # needs its target to exist
    if isinstance(arg, symbols.Symbol):
        return getattr(symbols, symbolName)(arg)
    elif hasattr(arg, methodName):
        return getattr(arg, methodName)()
    else:
        return getattr(numarray, numarrayName)(arg)

def exp(arg):
    """
    Applies the exponential function to arg.

    @param arg: argument
    """
    return _applyElementWise(arg, "Exp_Symbol", "exp", "exp")

def sqrt(arg):
    """
    Applies the square root function to arg.

    @param arg: argument
    """
    return _applyElementWise(arg, "Sqrt_Symbol", "sqrt", "sqrt")

def log(arg):
    """
    Applies the logarithmic function base 10 to arg.

    @param arg: argument
    @note: for objects providing a C{log} method that method is called.
    """
    return _applyElementWise(arg, "Log_Symbol", "log", "log10")

def ln(arg):
    """
    Applies the natural logarithmic function to arg.

    @param arg: argument
    """
    return _applyElementWise(arg, "Ln_Symbol", "ln", "log")

def sin(arg):
    """
    Applies the sin function to arg.

    @param arg: argument
    """
    return _applyElementWise(arg, "Sin_Symbol", "sin", "sin")

def cos(arg):
    """
    Applies the cos function to arg.

    @param arg: argument
    """
    return _applyElementWise(arg, "Cos_Symbol", "cos", "cos")

def tan(arg):
    """
    Applies the tan function to arg.

    @param arg: argument
    """
    return _applyElementWise(arg, "Tan_Symbol", "tan", "tan")

def asin(arg):
    """
    Applies the inverse sine (asin) function to arg.

    @param arg: argument
    """
    # numarray names the inverse trigonometric functions arc*
    return _applyElementWise(arg, "Asin_Symbol", "asin", "arcsin")

def acos(arg):
    """
    Applies the inverse cosine (acos) function to arg.

    @param arg: argument
    """
    return _applyElementWise(arg, "Acos_Symbol", "acos", "arccos")

def atan(arg):
    """
    Applies the inverse tangent (atan) function to arg.

    @param arg: argument
    """
    return _applyElementWise(arg, "Atan_Symbol", "atan", "arctan")

def sinh(arg):
    """
    Applies the sinh function to arg.

    @param arg: argument
    """
    return _applyElementWise(arg, "Sinh_Symbol", "sinh", "sinh")

def cosh(arg):
    """
    Applies the cosh function to arg.

    @param arg: argument
    """
    return _applyElementWise(arg, "Cosh_Symbol", "cosh", "cosh")

def tanh(arg):
    """
    Applies the tanh function to arg.

    @param arg: argument
    """
    return _applyElementWise(arg, "Tanh_Symbol", "tanh", "tanh")

def asinh(arg):
    """
    Applies the inverse hyperbolic sine (asinh) function to arg.

    @param arg: argument
    """
    return _applyElementWise(arg, "Asinh_Symbol", "asinh", "arcsinh")

def acosh(arg):
    """
    Applies the inverse hyperbolic cosine (acosh) function to arg.

    @param arg: argument
    """
    return _applyElementWise(arg, "Acosh_Symbol", "acosh", "arccosh")

def atanh(arg):
    """
    Applies the inverse hyperbolic tangent (atanh) function to arg.

    @param arg: argument
    """
    return _applyElementWise(arg, "Atanh_Symbol", "atanh", "arctanh")
308    
def sign(arg):
    """
    Applies the sign function to arg.

    @param arg: argument
    @return: for a Python int or float: 1. if arg is positive, -1. if arg is
             negative and 0. otherwise. For symbols a L{symbols.Sign_Symbol},
             for objects with a C{sign} method the result of that method and
             for numarray arrays the element-wise sign.
    """
    if isinstance(arg, (int, float)):
        # plain numbers have no shape attribute, so the numarray fallback
        # below cannot be used for them
        if arg > 0:
            return 1.
        elif arg < 0:
            return -1.
        else:
            return 0.
    elif isinstance(arg, symbols.Symbol):
        return symbols.Sign_Symbol(arg)
    elif hasattr(arg, "sign"):
        return arg.sign()
    else:
        zero = numarray.zeros(arg.shape, numarray.Float)
        return numarray.greater(arg, zero) - numarray.less(arg, zero)
322 jgs 82
def maxval(arg):
    """
    Returns the maximum value of argument arg.

    Symbols are wrapped into a L{symbols.Max_Symbol}. Otherwise the C{maxval}
    method of arg is used if present, else its C{max} method; an argument
    with neither method (e.g. a Python number) is returned unchanged.

    @param arg: argument
    """
    if isinstance(arg, symbols.Symbol):
        return symbols.Max_Symbol(arg)
    for methodName in ("maxval", "max"):
        if hasattr(arg, methodName):
            return getattr(arg, methodName)()
    return arg
337 jgs 82
def minval(arg):
    """
    Returns the minimum value of argument arg.

    Symbols are wrapped into a L{symbols.Min_Symbol}. Otherwise the C{minval}
    method of arg is used if present, else its C{min} method; an argument
    with neither method (e.g. a Python number) is returned unchanged.

    @param arg: argument
    """
    if isinstance(arg, symbols.Symbol):
        return symbols.Min_Symbol(arg)
    elif hasattr(arg, "minval"):
        # test for the method that is actually called
        return arg.minval()
    elif hasattr(arg, "min"):
        return arg.min()
    else:
        return arg
352 jgs 82
def wherePositive(arg):
    """
    Returns a mask of the positive values of argument arg.

    @param arg: argument
    @return: for a Python int or float 1. if arg is positive and 0. otherwise.
             For symbols a L{symbols.WherePositive_Symbol}, for objects with a
             C{wherePositive} method the result of that method and for numarray
             arrays the element-wise comparison with zero.
    """
    if isinstance(arg, (int, float)):
        if arg > 0:
            return 1.
        else:
            return 0.
    elif _testForZero(arg):
        return 0
    elif isinstance(arg, symbols.Symbol):
        return symbols.WherePositive_Symbol(arg)
    elif hasattr(arg, "wherePositive"):
        return arg.wherePositive()
    elif hasattr(arg, "shape"):
        return numarray.greater(arg, numarray.zeros(arg.shape, numarray.Float))
    else:
        if arg > 0:
            return 1.
        else:
            return 0.
372 jgs 82
def whereNegative(arg):
    """
    Returns a mask of the negative values of argument arg.

    @param arg: argument
    @return: for a Python int or float 1. if arg is negative and 0. otherwise.
             For symbols a L{symbols.WhereNegative_Symbol}, for objects with a
             C{whereNegative} method the result of that method and for numarray
             arrays the element-wise comparison with zero.
    """
    if isinstance(arg, (int, float)):
        # plain numbers are classified directly so that negative values are
        # never short-cut to 0 by the zero test below
        if arg < 0:
            return 1.
        else:
            return 0.
    elif _testForZero(arg):
        return 0
    elif isinstance(arg, symbols.Symbol):
        return symbols.WhereNegative_Symbol(arg)
    elif hasattr(arg, "whereNegative"):
        return arg.whereNegative()
    elif hasattr(arg, "shape"):
        return numarray.less(arg, numarray.zeros(arg.shape, numarray.Float))
    else:
        if arg < 0:
            return 1.
        else:
            return 0.
392 jgs 82
def maximum(arg0, arg1):
    """
    Returns arg1 where arg1 is bigger than arg0, otherwise arg0.

    The selection is done arithmetically through the mask returned by
    L{whereNegative}, so it also works point-wise for array-like arguments.

    @param arg0: first argument
    @param arg1: second argument
    """
    arg1IsBigger = whereNegative(arg0 - arg1)
    return arg1IsBigger * arg1 + (1. - arg1IsBigger) * arg0
399    
def minimum(arg0, arg1):
    """
    Returns arg0 where arg1 is bigger than arg0, otherwise arg1.

    The selection is done arithmetically through the mask returned by
    L{whereNegative}, so it also works point-wise for array-like arguments.

    @param arg0: first argument
    @param arg1: second argument
    """
    arg0IsSmaller = whereNegative(arg0 - arg1)
    return arg0IsSmaller * arg0 + (1. - arg0IsSmaller) * arg1
406    
def _identifyShape(arg):
    """
    Returns the shape of arg as a tuple.

    @param arg: argument
    @return: the result of C{arg.getShape()} if available, else C{arg.shape}
             if available, else C{()} (e.g. for Python ints and floats)
    """
    if hasattr(arg, "getShape"):
        return tuple(arg.getShape())
    elif hasattr(arg, "shape"):
        return tuple(arg.shape)
    else:
        return ()

def outer(arg0, arg1):
    """
    Returns the outer product of arg0 and arg1.

    @param arg0: first argument
    @param arg1: second argument
    @return: 0 if one argument is zero, a L{symbols.Outer_Symbol} if one
             argument is a symbol, the plain product if one argument has
             shape C{()}, C{numarray.outer} for two numarray arrays and an
             L{escript.Data} object for two rank 1 Data objects
    @raise ValueError: for other combinations of L{escript.Data} ranks
    """
    if _testForZero(arg0) or _testForZero(arg1):
        return 0
    elif isinstance(arg0, symbols.Symbol) or isinstance(arg1, symbols.Symbol):
        return symbols.Outer_Symbol(arg0, arg1)
    elif _identifyShape(arg0) == () or _identifyShape(arg1) == ():
        # the outer product with a scalar is plain multiplication
        return arg0 * arg1
    elif isinstance(arg0, numarray.NumArray) and isinstance(arg1, numarray.NumArray):
        return numarray.outer(arg0, arg1)
    elif arg0.getRank() == 1 and arg1.getRank() == 1:
        n0 = arg0.getShape()[0]
        n1 = arg1.getShape()[0]
        out = escript.Data(0, (n0, n1), arg1.getFunctionSpace())
        for i in range(n0):
            for j in range(n1):
                out[i, j] = arg0[i] * arg1[j]
        return out
    else:
        raise ValueError("outer is not fully implemented yet.")
426    
def interpolate(arg, where):
    """
    Interpolates arg into the FunctionSpace where.

    @param arg: interpolant
    @param where: FunctionSpace to interpolate to
    @return: 0 if arg is zero, a L{symbols.Interpolated_Symbol} if arg is a
             symbol, otherwise C{escript.Data(arg,where)}
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg, symbols.Symbol):
        return symbols.Interpolated_Symbol(arg, where)
    return escript.Data(arg, where)
440 jgs 82
def div(arg, where=None):
    """
    Returns the divergence of arg at where.

    The divergence is computed as the trace of the spatial gradient of arg.

    @param arg: Data object representing the function whose divergence is to
                be calculated.
    @param where: FunctionSpace in which the gradient will be calculated.
                  If not present or C{None} an appropriate default is used.
    """
    gradient = grad(arg, where)
    return trace(gradient)
451    
def jump(arg, domain=None):
    """
    Returns the jump of arg across a discontinuity (contact).

    The jump is the value of arg interpolated onto
    L{escript.FunctionOnContactOne} minus its value interpolated onto
    L{escript.FunctionOnContactZero}.

    @param arg: Data object representing the function whose jump is to be
                calculated.
    @param domain: domain on which the contact function spaces are built.
                   If not present or C{None} the domain of arg is used.
    """
    if domain is None:
        domain = arg.getDomain()
    # the contact function spaces have to be created on a specific domain
    return arg.interpolate(escript.FunctionOnContactOne(domain)) - \
           arg.interpolate(escript.FunctionOnContactZero(domain))
461    
462    
def grad(arg, where=None):
    """
    Returns the spatial gradient of arg at where.

    @param arg: Data object representing the function whose gradient is to
                be calculated.
    @param where: FunctionSpace in which the gradient will be calculated.
                  If not present or C{None} an appropriate default is used.
    @return: 0 if arg is zero, a L{symbols.Grad_Symbol} for symbols, the result
             of C{arg.grad} if available and C{arg*0.} otherwise
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg, symbols.Symbol):
        return symbols.Grad_Symbol(arg, where)
    if not hasattr(arg, "grad"):
        # an argument without a grad method is treated as a constant
        return arg * 0.
    if where is None:
        return arg.grad()
    return arg.grad(where)
483 jgs 102
def integrate(arg, where=None):
    """
    Returns the integral of the function represented by Data object arg over
    its domain.

    @param arg: Data object representing the function which is integrated.
    @param where: FunctionSpace in which the integral is calculated.
                  If not present or C{None} an appropriate default is used.
    @return: 0 if arg is zero, a L{symbols.Integral_Symbol} for symbols, the
             single component of C{arg.integrate()} for rank 0 arguments and
             C{arg.integrate()} otherwise
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg, symbols.Symbol):
        return symbols.Integral_Symbol(arg, where)
    if where is not None:
        arg = escript.Data(arg, where)
    isScalar = arg.getRank() == 0
    integral = arg.integrate()
    if isScalar:
        return integral[0]
    return integral
503 jgs 82
#=============================
#
# Wrappers for various functions: if the argument has a method with the same
# name as the function, that method is called. Otherwise the corresponding
# numarray function is called.

# functions involving the underlying Domain:


# functions returning Data objects:
514    
def transpose(arg, axis=None):
    """
    Returns the transpose of arg.

    @param arg: argument. L{escript.Data} objects are only supported for rank 2.
    @param axis: axis argument passed on to C{numarray.transpose}. If not
                 present or C{None}, half the rank of arg is used.
    @raise ValueError: if arg is an L{escript.Data} object whose rank is not 2.
    """
    # the rank is needed by the symbol branch even if axis is given
    r = 0
    if hasattr(arg, "getRank"):
        r = arg.getRank()
    if hasattr(arg, "rank"):
        r = arg.rank
    if axis is None:
        axis = r // 2
    if isinstance(arg, symbols.Symbol):
        # NOTE(review): the rank (not axis) is passed as before; confirm this
        # against symbols.Transpose_Symbol.
        return symbols.Transpose_Symbol(arg, axis=r)
    if isinstance(arg, escript.Data):
        # hack: the transpose of a rank 2 Data object is assembled component
        # by component
        if arg.getRank() != 2:
            raise ValueError("Tranpose only avalaible for rank 2 objects")
        s = arg.getShape()
        out = escript.Data(0., (s[1], s[0]), arg.getFunctionSpace())
        for i in range(s[0]):
            for j in range(s[1]):
                out[j, i] = arg[i, j]
        return out
    else:
        return numarray.transpose(arg, axis=axis)
542 jgs 82
def trace(arg, axis0=0, axis1=1):
    """
    Returns the trace of arg with respect to the axes axis0 and axis1.

    @param arg: argument
    @param axis0: first axis of the trace
    @param axis1: second axis of the trace
    @raise ValueError: if arg is an L{escript.Data} object and its extents
                       along axis0 and axis1 differ.
    @note: for L{escript.Data} objects the trace is computed as
           C{sum_i arg[i,i]}, so only rank 2 objects are handled as intended.
    """
    if isinstance(arg, symbols.Symbol):
        return symbols.Trace_Symbol(arg, axis0=axis0, axis1=axis1)
    elif isinstance(arg, escript.Data):
        # hack for trace
        s = arg.getShape()
        if s[axis0] != s[axis1]:
            raise ValueError("illegal axis in trace")
        out = escript.Scalar(0., arg.getFunctionSpace())
        for i in range(s[axis0]):
            out += arg[i, i]
        return out
        # end hack for trace
    else:
        return numarray.trace(arg, axis0=axis0, axis1=axis1)
565 jgs 82
def length(arg):
    """
    Returns the Euclidean length of arg, i.e. the square root of the sum of
    the squares of all its components.

    @param arg: a Python int or float, an L{escript.Data} object of rank at
                most 4 or an object supporting C{(arg**2).sum()} (e.g. a
                numarray array)
    @return: C{abs(arg)} for Python numbers and rank 0 Data objects, an empty
             L{escript.Data} object for empty Data objects
    @raise SystemError: if arg is an L{escript.Data} object of rank greater than 4.
    """
    if isinstance(arg, (int, float)):
        # plain numbers (ints included) have no sum method
        return abs(arg)
    elif isinstance(arg, escript.Data):
        if arg.isEmpty():
            return escript.Data()
        rank = arg.getRank()
        if rank == 0:
            return abs(arg)
        if rank > 4:
            raise SystemError("length is not been fully implemented yet")
        shape = arg.getShape()
        out = escript.Scalar(0, arg.getFunctionSpace())
        if rank == 1:
            for i in range(shape[0]):
                out += arg[i]**2
        elif rank == 2:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    out += arg[i, j]**2
        elif rank == 3:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    for k in range(shape[2]):
                        out += arg[i, j, k]**2
        else:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    for k in range(shape[2]):
                        for l in range(shape[3]):
                            out += arg[i, j, k, l]**2
        return sqrt(out)
    else:
        return sqrt((arg**2).sum())
608    
def deviator(arg):
    """
    Returns the deviatoric part of the square rank 2 object arg, i.e. arg minus
    trace(arg)/d times the d x d identity matrix, where d is the extent of arg.

    @param arg: rank 2 argument with equal extents
    @raise ValueError: if arg is not of rank 2 or not square.
    """
    shape = arg.getShape() if isinstance(arg, escript.Data) else arg.shape
    if len(shape) != 2:
        raise ValueError("Deviator requires rank 2 object")
    dim = shape[0]
    if dim != shape[1]:
        raise ValueError("Deviator requires a square matrix")
    return arg - 1. / (dim * 1.) * trace(arg) * kronecker(dim)
622    
def inner(arg0, arg1):
    """
    Returns the inner product of arg0 and arg1, i.e. the sum over all
    components of the component-wise product.

    @param arg0: first argument
    @param arg1: second argument of the same shape as arg0
    @return: C{arg0*arg1} for Python numbers and rank 0 arguments, an
             L{escript.Scalar} if one argument is an L{escript.Data} object and
             C{(arg0*arg1).sum()} otherwise (e.g. for numarray arrays)
    @raise SystemError: if the L{escript.Data} argument has rank greater than 4.
    """
    if isinstance(arg0, (int, float)) and isinstance(arg1, (int, float)):
        return arg0 * arg1
    elif isinstance(arg0, escript.Data):
        arg = arg0
    elif isinstance(arg1, escript.Data):
        arg = arg1
    else:
        # neither argument is a Data object (e.g. two numarray arrays)
        return (arg0 * arg1).sum()

    rank = arg.getRank()
    if rank == 0:
        return arg0 * arg1
    if rank > 4:
        raise SystemError("inner is not been implemented yet")
    shape = arg.getShape()
    out = escript.Scalar(0, arg.getFunctionSpace())
    if rank == 1:
        for i in range(shape[0]):
            out += arg0[i] * arg1[i]
    elif rank == 2:
        for i in range(shape[0]):
            for j in range(shape[1]):
                out += arg0[i, j] * arg1[i, j]
    elif rank == 3:
        for i in range(shape[0]):
            for j in range(shape[1]):
                for k in range(shape[2]):
                    out += arg0[i, j, k] * arg1[i, j, k]
    else:
        for i in range(shape[0]):
            for j in range(shape[1]):
                for k in range(shape[2]):
                    for l in range(shape[3]):
                        out += arg0[i, j, k, l] * arg1[i, j, k, l]
    return out
661 jgs 113
def tensormult(arg0, arg1):
    """
    Placeholder for the tensor multiplication of arg0 and arg1.

    @param arg0: first argument
    @param arg1: second argument
    @raise SystemError: always, as the function is not implemented yet.
    """
    # TODO: check LinearPDE
    message = "tensormult is not implemented yet!"
    raise SystemError(message)
665    
def matrixmult(arg0, arg1):
    """
    Returns the matrix product of arg0 and arg1.

    For two numarray arrays numarray.dot is used. Otherwise a non-Data argument
    is converted into an L{escript.Data} object on the function space of the
    other argument. Supported are a rank 2 times a rank 1 argument and two
    rank 1 arguments (the latter returns their inner product).

    @param arg0: first argument
    @param arg1: second argument
    @raise SystemError: for other combinations of ranks.
    """
    if isinstance(arg1, numarray.NumArray) and isinstance(arg0, numarray.NumArray):
        # return the product (it used to be computed and discarded)
        return numarray.dot(arg0, arg1)
    if isinstance(arg1, escript.Data) and not isinstance(arg0, escript.Data):
        arg0 = escript.Data(arg0, arg1.getFunctionSpace())
    elif isinstance(arg0, escript.Data) and not isinstance(arg1, escript.Data):
        arg1 = escript.Data(arg1, arg0.getFunctionSpace())
    if arg0.getRank() == 2 and arg1.getRank() == 1:
        out = escript.Data(0, (arg0.getShape()[0],), arg0.getFunctionSpace())
        for i in range(arg0.getShape()[0]):
            for j in range(arg0.getShape()[1]):
                # uses Data object slicing, plus Data * and += operators
                out[i] += arg0[i, j] * arg1[j]
        return out
    elif arg0.getRank() == 1 and arg1.getRank() == 1:
        return inner(arg0, arg1)
    else:
        raise SystemError("matrixmult is not fully implemented yet!")
687 jgs 124
688 jgs 123 #=========================================================
689 jgs 102 # reduction operations:
690 jgs 123 #=========================================================
def sum(arg):
    """
    Returns the result of the C{sum} method of arg.

    @param arg: argument providing a C{sum} method (e.g. an L{escript.Data}
                object or a numarray array)
    """
    summation = getattr(arg, "sum")
    return summation()
696    
def sup(arg):
    """
    Returns the supremum (maximum value) of arg.

    @param arg: a Python int or float (returned unchanged), an L{escript.Data}
                object (its C{sup} method is used) or an object providing a
                C{max} method (e.g. a numarray array)
    """
    if isinstance(arg, (float, int)):
        return arg
    if isinstance(arg, escript.Data):
        return arg.sup()
    return arg.max()
707 jgs 82
def inf(arg):
    """
    Returns the infimum (minimum value) of arg.

    @param arg: a Python int or float (returned unchanged), an L{escript.Data}
                object (its C{inf} method is used) or an object providing a
                C{min} method (e.g. a numarray array)
    """
    if isinstance(arg, (float, int)):
        return arg
    if isinstance(arg, escript.Data):
        return arg.inf()
    return arg.min()
718 jgs 82
def L2(arg):
    """
    Returns the L2-norm of the argument.

    @param arg: a Python int or float, an L{escript.Data} object (its C{L2}
                method is used) or an array accepted by L{dot}
    @return: C{abs(arg)} for Python numbers, C{arg.L2()} for Data objects and
             C{numarray.sqrt(dot(arg,arg))} otherwise
    """
    if isinstance(arg, (float, int)):
        return abs(arg)
    elif isinstance(arg, escript.Data):
        return arg.L2()
    else:
        return numarray.sqrt(dot(arg, arg))
731 jgs 82
def Lsup(arg):
    """
    Returns the Lsup-norm (maximum absolute value) of arg.

    @param arg: a Python int or float, an L{escript.Data} object (its C{Lsup}
                method is used) or a numarray array
    """
    if isinstance(arg, (float, int)):
        return abs(arg)
    if isinstance(arg, escript.Data):
        return arg.Lsup()
    absolute = numarray.abs(arg)
    return absolute.max()
742 jgs 82
def dot(arg0, arg1):
    """
    Returns the dot product of arg0 and arg1.

    If arg0 is an L{escript.Data} object C{arg0.dot(arg1)} is returned, else
    if arg1 is one C{arg1.dot(arg0)} is returned; otherwise numarray.dot is
    used.

    @param arg0: first argument
    @param arg1: second argument
    """
    for first, second in ((arg0, arg1), (arg1, arg0)):
        if isinstance(first, escript.Data):
            return first.dot(second)
    return numarray.dot(arg0, arg1)
754 jgs 113
def kronecker(d):
    """
    Returns the d x d identity matrix (Kronecker delta) as a numarray array.

    The identity is multiplied by 1. so that the entries are floating point.

    @param d: dimension, either an integer or an object providing C{getDim()}
              (e.g. a domain)
    """
    dim = d.getDim() if hasattr(d, "getDim") else d
    return numarray.identity(dim) * 1.
760 jgs 117
def unit(i, d):
    """
    Returns a unit vector of dimension d whose component i is one.

    @param i: index of the nonzero component
    @param d: dimension
    @return: numarray array of type C{numarray.Float} and shape C{(d,)}
    """
    vector = numarray.zeros((d,), numarray.Float)
    vector[i] = 1.0
    return vector

Properties

Name Value
svn:eol-style native
svn:keywords Author Date Id Revision

  ViewVC Help
Powered by ViewVC 1.1.26