1 |
# $Id$ |
2 |
|
3 |
## @file util.py |
4 |
|
5 |
""" |
6 |
Utility functions for escript |
7 |
|
8 |
@todo: |
9 |
|
10 |
- binary operations @ (@=+,-,*,/,**):: |
11 |
(a@b)[:,*]=a[:]@b[:,*] if rank(a)<rank(b) |
12 |
(a@b)[:]=a[:]@b[:] if rank(a)=rank(b) |
13 |
(a@b)[*,:]=a[*,:]@b[:] if rank(a)>rank(b) |
14 |
- implementation of outer:: |
15 |
outer(a,b)[:,*]=a[:]*b[*] |
16 |
- trace:: |
17 |
trace(arg,axis0=a0,axis1=a1)(:,&,*)=sum_i trace(:,i,&,i,*) (i are at index a0 and a1) |
18 |
""" |
19 |
|
20 |
import numarray |
21 |
import escript |
22 |
import symbols |
23 |
import os |
24 |
|
25 |
#========================================================= |
26 |
# some little helpers |
27 |
#========================================================= |
28 |
def _testForZero(arg):
    """
    Returns True if arg is considered to be zero.

    A Python C{int} or C{float} is zero only if it compares equal to zero
    (negative values are NOT zero). A C{numarray.NumArray} is zero if none of
    its entries is nonzero. Any other argument (e.g. a symbol or a L{Data}
    object) is never considered to be zero.

    @param arg: argument to test
    @return: True if arg is zero
    @rtype: C{bool}
    """
    if isinstance(arg,(int,float)):
        # compare for equality: the previous "not arg>0" test also matched negatives
        return arg==0
    elif isinstance(arg,numarray.NumArray):
        a=abs(arg)
        # collapse one axis per step with sometrue until a scalar remains
        while isinstance(a,numarray.NumArray): a=numarray.sometrue(a)
        return not a>0
    else:
        return False
42 |
|
43 |
#========================================================= |
44 |
def saveVTK(filename,domain=None,**data):
    """
    Writes L{Data} objects into a file using the VTK XML file format.

    Example::

        tmp=Scalar(..)
        v=Vector(..)
        saveVTK("solution.xml",temperature=tmp,velocity=v)

    tmp and v are written into "solution.xml" where tmp is named "temperature" and v is named "velocity".

    @param filename: file name of the output file
    @type filename: C{str}
    @param domain: domain of the L{Data} objects. If not specified, the domain of the given L{Data} objects is used.
    @type domain: L{escript.Domain}
    @keyword <name>: writes the assigned value to the VTK file using <name> as identifier.
    @type <name>: L{Data} object.
    @raise ValueError: if no domain is given and none can be taken from a non-empty L{Data} object.
    @note: The data objects have to be defined on the same domain. They may not be in the same L{FunctionSpace} but one cannot expect that all L{FunctionSpace} can be mixed. Typically, data on the boundary and data on the interior cannot be mixed.
    """
    if domain is None:
        # every non-empty object is visited; the last one visited supplies the domain
        for value in data.values():
            if not value.isEmpty(): domain=value.getFunctionSpace().getDomain()
    if domain is None:
        raise ValueError("no domain detected.")
    domain.saveVTK(filename,data)
71 |
|
72 |
#========================================================= |
73 |
def saveDX(filename,domain=None,**data):
    """
    Writes L{Data} objects into a file using the DX file format.

    Example::

        tmp=Scalar(..)
        v=Vector(..)
        saveDX("solution.dx",temperature=tmp,velocity=v)

    tmp and v are written into "solution.dx" where tmp is named "temperature" and v is named "velocity".

    @param filename: file name of the output file
    @type filename: C{str}
    @param domain: domain of the L{Data} objects. If not specified, the domain of the given L{Data} objects is used.
    @type domain: L{escript.Domain}
    @keyword <name>: writes the assigned value to the DX file using <name> as identifier. The identifier can be used to select the data set when data are imported into DX.
    @type <name>: L{Data} object.
    @raise ValueError: if no domain is given and none can be taken from a non-empty L{Data} object.
    @note: The data objects have to be defined on the same domain. They may not be in the same L{FunctionSpace} but one cannot expect that all L{FunctionSpace} can be mixed. Typically, data on the boundary and data on the interior cannot be mixed.
    """
    if domain is None:
        # every non-empty object is visited; the last one visited supplies the domain
        for value in data.values():
            if not value.isEmpty(): domain=value.getFunctionSpace().getDomain()
    if domain is None:
        raise ValueError("no domain detected.")
    domain.saveDX(filename,data)
100 |
|
101 |
#========================================================= |
102 |
|
103 |
def exp(arg):
    """
    Returns the exponential function applied to arg.

    A symbol is wrapped into a symbolic exponential. An object that provides
    its own C{exp} method evaluates it. Anything else is passed to
    C{numarray.exp}.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Exp_Symbol(arg)
    if hasattr(arg,"exp"):
        return arg.exp()
    return numarray.exp(arg)
115 |
|
116 |
def sqrt(arg):
    """
    Returns the square root of arg.

    A symbol is wrapped into a symbolic square root. An object that provides
    its own C{sqrt} method evaluates it. Anything else is passed to
    C{numarray.sqrt}.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Sqrt_Symbol(arg)
    if hasattr(arg,"sqrt"):
        return arg.sqrt()
    return numarray.sqrt(arg)
128 |
|
129 |
def log(arg):
    """
    Returns the logarithm of arg; the numarray fallback uses base 10.

    A symbol is wrapped into a symbolic logarithm. An object that provides
    its own C{log} method evaluates it. Anything else is passed to
    C{numarray.log10}.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Log_Symbol(arg)
    if hasattr(arg,"log"):
        return arg.log()
    return numarray.log10(arg)
141 |
|
142 |
def ln(arg):
    """
    Returns the natural logarithm of arg.

    A symbol is wrapped into a symbolic natural logarithm. An object that
    provides its own C{ln} method evaluates it. Anything else is passed to
    C{numarray.log}.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Ln_Symbol(arg)
    if hasattr(arg,"ln"):
        return arg.ln()
    return numarray.log(arg)
154 |
|
155 |
def sin(arg):
    """
    Returns the sine of arg.

    A symbol is wrapped into a symbolic sine. An object that provides its own
    C{sin} method evaluates it. Anything else is passed to C{numarray.sin}.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Sin_Symbol(arg)
    if hasattr(arg,"sin"):
        return arg.sin()
    return numarray.sin(arg)
167 |
|
168 |
def cos(arg):
    """
    Returns the cosine of arg.

    A symbol is wrapped into a symbolic cosine. An object that provides its
    own C{cos} method evaluates it. Anything else is passed to
    C{numarray.cos}.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Cos_Symbol(arg)
    if hasattr(arg,"cos"):
        return arg.cos()
    return numarray.cos(arg)
180 |
|
181 |
def tan(arg):
    """
    Returns the tangent of arg.

    A symbol is wrapped into a symbolic tangent. An object that provides its
    own C{tan} method evaluates it. Anything else is passed to
    C{numarray.tan}.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Tan_Symbol(arg)
    if hasattr(arg,"tan"):
        return arg.tan()
    return numarray.tan(arg)
193 |
|
194 |
def asin(arg):
    """
    Applies the asin (inverse sine) function to arg.

    A symbol is wrapped into a symbolic inverse sine. An object that provides
    its own C{asin} method evaluates it. Anything else is passed to
    C{numarray.arcsin}.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Asin_Symbol(arg)
    elif hasattr(arg,"asin"):
        return arg.asin()
    else:
        # numarray names the inverse sine ufunc arcsin (there is no numarray.asin)
        return numarray.arcsin(arg)
206 |
|
207 |
def acos(arg):
    """
    Applies the acos (inverse cosine) function to arg.

    A symbol is wrapped into a symbolic inverse cosine. An object that
    provides its own C{acos} method evaluates it. Anything else is passed to
    C{numarray.arccos}.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Acos_Symbol(arg)
    elif hasattr(arg,"acos"):
        return arg.acos()
    else:
        # numarray names the inverse cosine ufunc arccos (there is no numarray.acos)
        return numarray.arccos(arg)
219 |
|
220 |
def atan(arg):
    """
    Applies the atan (inverse tangent) function to arg.

    A symbol is wrapped into a symbolic inverse tangent. An object that
    provides its own C{atan} method evaluates it. Anything else is passed to
    C{numarray.arctan}.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Atan_Symbol(arg)
    elif hasattr(arg,"atan"):
        return arg.atan()
    else:
        # numarray names the inverse tangent ufunc arctan (there is no numarray.atan)
        return numarray.arctan(arg)
232 |
|
233 |
def sinh(arg):
    """
    Returns the hyperbolic sine of arg.

    A symbol is wrapped into a symbolic hyperbolic sine. An object that
    provides its own C{sinh} method evaluates it. Anything else is passed to
    C{numarray.sinh}.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Sinh_Symbol(arg)
    if hasattr(arg,"sinh"):
        return arg.sinh()
    return numarray.sinh(arg)
245 |
|
246 |
def cosh(arg):
    """
    Returns the hyperbolic cosine of arg.

    A symbol is wrapped into a symbolic hyperbolic cosine. An object that
    provides its own C{cosh} method evaluates it. Anything else is passed to
    C{numarray.cosh}.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Cosh_Symbol(arg)
    if hasattr(arg,"cosh"):
        return arg.cosh()
    return numarray.cosh(arg)
258 |
|
259 |
def tanh(arg):
    """
    Returns the hyperbolic tangent of arg.

    A symbol is wrapped into a symbolic hyperbolic tangent. An object that
    provides its own C{tanh} method evaluates it. Anything else is passed to
    C{numarray.tanh}.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Tanh_Symbol(arg)
    if hasattr(arg,"tanh"):
        return arg.tanh()
    return numarray.tanh(arg)
271 |
|
272 |
def asinh(arg):
    """
    Applies the asinh (inverse hyperbolic sine) function to arg.

    A symbol is wrapped into a symbolic inverse hyperbolic sine. An object
    that provides its own C{asinh} method evaluates it. Anything else is
    passed to C{numarray.arcsinh}.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Asinh_Symbol(arg)
    elif hasattr(arg,"asinh"):
        return arg.asinh()
    else:
        # numarray names this ufunc arcsinh (there is no numarray.asinh)
        return numarray.arcsinh(arg)
284 |
|
285 |
def acosh(arg):
    """
    Applies the acosh (inverse hyperbolic cosine) function to arg.

    A symbol is wrapped into a symbolic inverse hyperbolic cosine. An object
    that provides its own C{acosh} method evaluates it. Anything else is
    passed to C{numarray.arccosh}.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Acosh_Symbol(arg)
    elif hasattr(arg,"acosh"):
        return arg.acosh()
    else:
        # numarray names this ufunc arccosh (there is no numarray.acosh)
        return numarray.arccosh(arg)
297 |
|
298 |
def atanh(arg):
    """
    Applies the atanh (inverse hyperbolic tangent) function to arg.

    A symbol is wrapped into a symbolic inverse hyperbolic tangent. An object
    that provides its own C{atanh} method evaluates it. Anything else is
    passed to C{numarray.arctanh}.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Atanh_Symbol(arg)
    elif hasattr(arg,"atanh"):
        return arg.atanh()
    else:
        # numarray names this ufunc arctanh (there is no numarray.atanh)
        return numarray.arctanh(arg)
310 |
|
311 |
def sign(arg):
    """
    Applies the sign function to arg.

    The result is 1 where arg is positive, -1 where it is negative and 0
    where it is zero. A symbol is wrapped into a symbolic sign. An object
    that provides its own C{sign} method evaluates it. For a Python
    C{int}/C{float} the result is a C{float}. Array-like arguments are
    compared elementwise with numarray.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Sign_Symbol(arg)
    elif hasattr(arg,"sign"):
        return arg.sign()
    elif isinstance(arg,(int,float)):
        # plain Python scalars have no .shape, so handle them directly
        return float(arg>0)-float(arg<0)
    else:
        return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))- \
               numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
324 |
|
325 |
def maxval(arg):
    """
    Returns the maximum value of arg.

    A symbol yields a symbolic maximum. Otherwise the first available method
    among C{maxval} and C{max} is called. An argument with neither method
    is returned unchanged.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Max_Symbol(arg)
    for method_name in ("maxval","max"):
        if hasattr(arg,method_name):
            return getattr(arg,method_name)()
    return arg
339 |
|
340 |
def minval(arg):
    """
    Returns the minimum value of argument arg.

    A symbol yields a symbolic minimum. Otherwise the C{minval} method is
    called if present, else the C{min} method. An argument with neither
    method is returned unchanged.

    @param arg: argument
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Min_Symbol(arg)
    elif hasattr(arg,"minval"):
        # check for the method that is actually called (was "maxval")
        return arg.minval()
    elif hasattr(arg,"min"):
        return arg.min()
    else:
        return arg
354 |
|
355 |
def wherePositive(arg):
    """
    Returns an indicator of the positive values of arg.

    The result is 1 where arg is positive and 0 elsewhere. Zero arguments
    return 0, and a symbol yields a symbolic indicator. An object that
    provides its own C{wherePositive} method evaluates it. Array-like
    arguments (with a C{shape}) are compared elementwise with numarray;
    Python scalars give C{1.} or C{0.}.

    @param arg: argument
    """
    if _testForZero(arg):
        return 0
    elif isinstance(arg,symbols.Symbol):
        return symbols.WherePositive_Symbol(arg)
    elif hasattr(arg,"wherePositive"):
        return arg.wherePositive()
    elif hasattr(arg,"shape"):
        # elementwise comparison for array-like arguments
        return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))
    elif arg>0:
        return 1.
    else:
        return 0.
374 |
|
375 |
def whereNegative(arg):
    """
    Returns an indicator of the negative values of arg.

    The result is 1 where arg is negative and 0 elsewhere. Zero arguments
    return 0, and a symbol yields a symbolic indicator. An object that
    provides its own C{whereNegative} method evaluates it. Array-like
    arguments (with a C{shape}) are compared elementwise with numarray;
    Python scalars give C{1.} or C{0.}.

    @param arg: argument
    """
    if _testForZero(arg):
        return 0
    elif isinstance(arg,symbols.Symbol):
        return symbols.WhereNegative_Symbol(arg)
    elif hasattr(arg,"whereNegative"):
        return arg.whereNegative()
    elif hasattr(arg,"shape"):
        # elementwise comparison for array-like arguments; the result must be returned
        return numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
    elif arg<0:
        return 1.
    else:
        return 0.
394 |
|
395 |
def maximum(arg0,arg1):
    """
    Returns the pointwise maximum of arg0 and arg1.

    Where arg1 is bigger than arg0, arg1 is taken; otherwise arg0 is taken.

    @param arg0: first argument
    @param arg1: second argument
    """
    # mask is 1 where arg0<arg1 and 0 elsewhere
    mask=whereNegative(arg0-arg1)
    take1=mask*arg1
    take0=(1.-mask)*arg0
    return take1+take0
401 |
|
402 |
def minimum(arg0,arg1):
    """
    Returns the pointwise minimum of arg0 and arg1.

    Where arg1 is bigger than arg0, arg0 is taken; otherwise arg1 is taken.

    @param arg0: first argument
    @param arg1: second argument
    """
    # mask is 1 where arg0<arg1 and 0 elsewhere
    mask=whereNegative(arg0-arg1)
    take0=mask*arg0
    take1=(1.-mask)*arg1
    return take0+take1
408 |
|
409 |
def _identifyShape(arg):
    """
    Returns the shape of arg as a tuple.

    Objects with a C{getShape} method and objects with a C{shape} attribute
    report their own shape. Anything else (e.g. a Python C{int} or C{float})
    is treated as a scalar.

    @param arg: argument
    @return: shape of arg; C{()} for scalars
    @rtype: C{tuple}
    """
    if hasattr(arg,"getShape"):
        return tuple(arg.getShape())
    elif hasattr(arg,"shape"):
        return tuple(arg.shape)
    else:
        return ()

def outer(arg0,arg1):
    """
    Returns the outer product of arg0 and arg1.

    If either argument is zero, 0 is returned. If either is a symbol, a
    symbolic outer product is returned. If either is a scalar, the plain
    product is returned. Two numarray arrays are combined with
    C{numarray.outer}. Otherwise at least one argument is expected to be an
    L{escript.Data} object; currently only two rank-1 arguments are
    supported there.

    @param arg0: first argument
    @param arg1: second argument
    @raise ValueError: for L{escript.Data} arguments that are not both of rank 1
    """
    if _testForZero(arg0) or _testForZero(arg1):
        return 0
    elif isinstance(arg0,symbols.Symbol) or isinstance(arg1,symbols.Symbol):
        return symbols.Outer_Symbol(arg0,arg1)
    elif _identifyShape(arg0)==() or _identifyShape(arg1)==():
        return arg0*arg1
    elif isinstance(arg0,numarray.NumArray) and isinstance(arg1,numarray.NumArray):
        return numarray.outer(arg0,arg1)
    else:
        # lift a non-Data argument onto the FunctionSpace of the Data argument
        if not isinstance(arg0,escript.Data):
            arg0=escript.Data(arg0,arg1.getFunctionSpace())
        elif not isinstance(arg1,escript.Data):
            arg1=escript.Data(arg1,arg0.getFunctionSpace())
        if arg0.getRank()==1 and arg1.getRank()==1:
            out=escript.Data(0,(arg0.getShape()[0],arg1.getShape()[0]),arg1.getFunctionSpace())
            for i in range(arg0.getShape()[0]):
                for j in range(arg1.getShape()[0]):
                    out[i,j]=arg0[i]*arg1[j]
            return out
        else:
            raise ValueError("outer is not fully implemented yet.")
428 |
|
429 |
def interpolate(arg,where):
    """
    Interpolates arg into the FunctionSpace where.

    A zero argument yields 0 and a symbol yields a symbolic interpolation.
    Anything else is converted by C{escript.Data(arg,where)}.

    @param arg: interpolant
    @param where: FunctionSpace to interpolate to
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg,symbols.Symbol):
        return symbols.Interpolated_Symbol(arg,where)
    return escript.Data(arg,where)
442 |
|
443 |
def div(arg,where=None):
    """
    Returns the divergence of arg at where, computed as the trace of its gradient.

    @param arg: Data object representing the function whose divergence is
                to be calculated.
    @param where: FunctionSpace in which the divergence will be calculated.
                  If not present or C{None} an appropriate default is used.
    """
    gradient=grad(arg,where)
    return trace(gradient)
453 |
|
454 |
def jump(arg):
    """
    Returns the jump of arg across a discontinuity (contact).

    The value interpolated onto side zero of the contact is subtracted from
    the value interpolated onto side one.

    @param arg: Data object representing the function whose jump is to be
                calculated.
    """
    # the contact function spaces are constructed on the domain of arg
    d=arg.getFunctionSpace().getDomain()
    return arg.interpolate(escript.FunctionOnContactOne(d))-arg.interpolate(escript.FunctionOnContactZero(d))
463 |
|
464 |
|
465 |
def grad(arg,where=None):
    """
    Returns the spatial gradient of arg at where.

    A zero argument yields 0 and a symbol yields a symbolic gradient.
    Objects providing a C{grad} method compute it themselves. Anything else
    is treated as a constant and C{arg*0.} is returned.

    @param arg: Data object representing the function whose gradient is to
                be calculated.
    @param where: FunctionSpace in which the gradient will be calculated.
                  If not present or C{None} an appropriate default is used.
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg,symbols.Symbol):
        return symbols.Grad_Symbol(arg,where)
    if not hasattr(arg,"grad"):
        # constants have a vanishing gradient
        return arg*0.
    if where is None:
        return arg.grad()
    return arg.grad(where)
485 |
|
486 |
def integrate(arg,where=None):
    """
    Returns the integral of the function represented by arg over its domain.

    A zero argument yields 0 and a symbol yields a symbolic integral. If
    where is given, arg is first converted to L{escript.Data} on that
    FunctionSpace. For rank-0 arguments the first entry of the result of
    C{integrate()} is returned, otherwise the whole result.

    @param arg: Data object representing the function which is integrated.
    @param where: FunctionSpace in which the integral is calculated.
                  If not present or C{None} an appropriate default is used.
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg,symbols.Symbol):
        return symbols.Integral_Symbol(arg,where)
    if where is not None:
        arg=escript.Data(arg,where)
    is_scalar=arg.getRank()==0
    integral=arg.integrate()
    if is_scalar:
        return integral[0]
    return integral
505 |
|
506 |
#============================= |
507 |
# |
508 |
# wrapper for various functions: if the argument has attribute the function name |
509 |
# as an argument it calls the corresponding methods. Otherwise the corresponding |
510 |
# numarray function is called. |
511 |
|
512 |
# functions involving the underlying Domain: |
513 |
|
514 |
|
515 |
# functions returning Data objects: |
516 |
|
517 |
def transpose(arg,axis=None):
    """
    Returns the transpose of arg.

    axis defaults to half the rank of arg (rounded down). A symbol is passed
    to C{symbols.Transpose_Symbol} with this axis. An L{escript.Data} object
    must be of rank 2 and has its two indices swapped. A Python scalar is
    returned unchanged. For numarray arrays of rank r the axes are permuted
    to (axis,...,r-1,0,...,axis-1), which for a matrix with axis=1 swaps
    rows and columns.

    @param arg: argument
    @param axis: number of leading axes moved to the back; defaults to rank//2
    @raise ValueError: if arg is an L{escript.Data} object that is not of rank 2
    """
    if axis is None:
        r=0
        if hasattr(arg,"getRank"): r=arg.getRank()
        if hasattr(arg,"rank"): r=arg.rank
        # integer division under both Python 2 and 3
        axis=r//2
    if isinstance(arg,symbols.Symbol):
        return symbols.Transpose_Symbol(arg,axis=axis)
    elif isinstance(arg,escript.Data):
        # hack for transpose: only rank 2 Data objects are supported
        r=arg.getRank()
        if r!=2: raise ValueError("Tranpose only avalaible for rank 2 objects")
        s=arg.getShape()
        out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())
        for i in range(s[0]):
            for j in range(s[1]):
                out[j,i]=arg[i,j]
        return out
    elif isinstance(arg,(int,float)):
        # a scalar is its own transpose
        return arg
    else:
        # numarray.transpose expects an explicit permutation of all axes
        r=len(arg.shape)
        return numarray.transpose(arg,list(range(axis,r))+list(range(axis)))
544 |
|
545 |
def trace(arg,axis0=0,axis1=1):
    """
    Returns the trace of arg taken over the axes axis0 and axis1.

    A symbol yields a symbolic trace. An L{escript.Data} object must be of
    rank 2; the sum of its diagonal entries is returned as a scalar
    L{escript.Data} object on the same FunctionSpace. Other arguments are
    handed to C{numarray.trace}.

    @param arg: argument
    @param axis0: first axis of the contraction
    @param axis1: second axis of the contraction
    @raise ValueError: if an L{escript.Data} argument is not of rank 2 or
                       the two axes have different lengths
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Trace_Symbol(arg,axis0=axis0,axis1=axis1)
    elif isinstance(arg,escript.Data):
        # hack for trace: the loop below only handles rank 2 Data objects
        s=arg.getShape()
        if len(s)!=2:
            raise ValueError("trace only available for rank 2 objects")
        if s[axis0]!=s[axis1]:
            raise ValueError("illegal axis in trace")
        out=escript.Scalar(0.,arg.getFunctionSpace())
        for i in range(s[axis0]):
            out+=arg[i,i]
        return out
    else:
        # numarray.trace names its axes axis1/axis2 (after an offset argument)
        return numarray.trace(arg,offset=0,axis1=axis0,axis2=axis1)
567 |
|
568 |
def length(arg):
    """
    Returns the Euclidean length of arg.

    For a Python C{int} or C{float} this is its absolute value. An empty
    L{escript.Data} object yields an empty L{escript.Data} object. For an
    L{escript.Data} object of rank 0 to 4, the square root of the sum of the
    squares of all components is returned as a scalar L{escript.Data} object
    on the same FunctionSpace. Any other argument must support C{**2} and
    C{sum()}.

    @param arg: argument
    @raise SystemError: if arg is an L{escript.Data} object of rank greater than 4
    """
    if isinstance(arg,(int,float)):
        # ints have no .sum(), so treat all Python scalars here
        return abs(arg)
    elif isinstance(arg,escript.Data):
        if arg.isEmpty(): return escript.Data()
        rank=arg.getRank()
        shape=arg.getShape()
        if rank==0:
            return abs(arg)
        out=escript.Scalar(0,arg.getFunctionSpace())
        if rank==1:
            for i in range(shape[0]):
                out+=arg[i]**2
        elif rank==2:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    out+=arg[i,j]**2
        elif rank==3:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    for k in range(shape[2]):
                        out+=arg[i,j,k]**2
        elif rank==4:
            for i in range(shape[0]):
                for j in range(shape[1]):
                    for k in range(shape[2]):
                        for l in range(shape[3]):
                            out+=arg[i,j,k,l]**2
        else:
            raise SystemError("length is not been fully implemented yet")
        return sqrt(out)
    else:
        return sqrt((arg**2).sum())
609 |
|
610 |
def deviator(arg):
    """
    Returns the deviatoric part of the square rank-2 argument arg.

    This is arg - trace(arg)/d * I, where d is the number of rows and I is
    the d x d identity.

    @param arg: square rank-2 argument
    @raise ValueError: if arg is not of rank 2 or not square
    """
    shape=arg.getShape() if isinstance(arg,escript.Data) else arg.shape
    if len(shape)!=2:
        raise ValueError("Deviator requires rank 2 object")
    if shape[0]!=shape[1]:
        raise ValueError("Deviator requires a square matrix")
    dim=shape[0]
    return arg-1./(dim*1.)*trace(arg)*kronecker(dim)
623 |
|
624 |
def inner(arg0,arg1):
    """
    Returns the inner product of arg0 and arg1, i.e. the sum over all
    components of arg0*arg1.

    Two Python scalars are simply multiplied. If at least one argument is an
    L{escript.Data} object, its rank (0 to 4) determines the contraction,
    which is carried out component by component into a scalar
    L{escript.Data} object. Otherwise the product arg0*arg1 is formed and
    summed with its C{sum} method if it has one (numarray arrays do).

    @param arg0: first argument
    @param arg1: second argument, with the same shape as arg0
    @raise SystemError: if the L{escript.Data} argument has rank greater than 4
    """
    if isinstance(arg0,(int,float)) and isinstance(arg1,(int,float)):
        return arg0*arg1
    if isinstance(arg0,escript.Data):
        arg=arg0
    elif isinstance(arg1,escript.Data):
        arg=arg1
    else:
        # no Data involved: contract the elementwise product directly
        product=arg0*arg1
        if hasattr(product,"sum"):
            return product.sum()
        return product

    rank=arg.getRank()
    shape=arg.getShape()
    if rank==0:
        return arg0*arg1
    out=escript.Scalar(0,arg.getFunctionSpace())
    if rank==1:
        for i in range(shape[0]):
            out+=arg0[i]*arg1[i]
    elif rank==2:
        for i in range(shape[0]):
            for j in range(shape[1]):
                out+=arg0[i,j]*arg1[i,j]
    elif rank==3:
        for i in range(shape[0]):
            for j in range(shape[1]):
                for k in range(shape[2]):
                    out+=arg0[i,j,k]*arg1[i,j,k]
    elif rank==4:
        for i in range(shape[0]):
            for j in range(shape[1]):
                for k in range(shape[2]):
                    for l in range(shape[3]):
                        out+=arg0[i,j,k,l]*arg1[i,j,k,l]
    else:
        raise SystemError("inner is not been implemented yet")
    return out
662 |
|
663 |
def tensormult(arg0,arg1):
    """
    Tensor product of arg0 and arg1 -- not implemented yet.

    @param arg0: first argument (unused)
    @param arg1: second argument (unused)
    @raise SystemError: always
    """
    # TODO (from the original note): check LinearPDE before implementing
    raise SystemError("tensormult is not implemented yet!")
666 |
|
667 |
def matrixmult(arg0,arg1):
    """
    Returns the matrix product of arg0 and arg1.

    Two numarray arrays are multiplied with C{numarray.dot}. Otherwise a
    non-Data argument is first converted to L{escript.Data} on the
    FunctionSpace of the other argument. Then a rank-2 times rank-1 product
    (matrix times vector) or a rank-1 times rank-1 product (inner product)
    is computed.

    @param arg0: first argument
    @param arg1: second argument
    @raise SystemError: for any other combination of ranks
    """
    if isinstance(arg1,numarray.NumArray) and isinstance(arg0,numarray.NumArray):
        # the result was previously computed but not returned
        return numarray.dot(arg0,arg1)
    if isinstance(arg1,escript.Data) and not isinstance(arg0,escript.Data):
        arg0=escript.Data(arg0,arg1.getFunctionSpace())
    elif isinstance(arg0,escript.Data) and not isinstance(arg1,escript.Data):
        arg1=escript.Data(arg1,arg0.getFunctionSpace())
    if arg0.getRank()==2 and arg1.getRank()==1:
        out=escript.Data(0,(arg0.getShape()[0],),arg0.getFunctionSpace())
        for i in range(arg0.getShape()[0]):
            for j in range(arg0.getShape()[1]):
                # uses Data object slicing, plus Data * and += operators
                out[i]+=arg0[i,j]*arg1[j]
        return out
    elif arg0.getRank()==1 and arg1.getRank()==1:
        return inner(arg0,arg1)
    else:
        raise SystemError("matrixmult is not fully implemented yet!")
688 |
|
689 |
#========================================================= |
690 |
# reduction operations: |
691 |
#========================================================= |
692 |
def sum(arg):
    """
    Returns the sum of arg, as computed by its own C{sum} method.

    @param arg: argument providing a C{sum()} method
    """
    summation=arg.sum
    return summation()
697 |
|
698 |
def sup(arg):
    """
    Returns the supremum (maximum value) of arg.

    Python C{int}/C{float} values are returned unchanged. L{escript.Data}
    objects use C{sup()}; anything else uses C{max()}.

    @param arg: argument
    """
    if isinstance(arg,(float,int)):
        return arg
    if isinstance(arg,escript.Data):
        return arg.sup()
    return arg.max()
708 |
|
709 |
def inf(arg):
    """
    Returns the infimum (minimum value) of arg.

    Python C{int}/C{float} values are returned unchanged. L{escript.Data}
    objects use C{inf()}; anything else uses C{min()}.

    @param arg: argument
    """
    if isinstance(arg,(float,int)):
        return arg
    if isinstance(arg,escript.Data):
        return arg.inf()
    return arg.min()
719 |
|
720 |
def L2(arg):
    """
    Returns the L2-norm of the argument.

    For a Python C{int}/C{float} this is the absolute value, and
    L{escript.Data} objects use their C{L2()} method. For other arguments
    the square root of C{dot(arg,arg)} is returned, which is the Euclidean
    norm for rank-1 arrays.

    @param arg: argument
    """
    if isinstance(arg,float) or isinstance(arg,int):
        return abs(arg)
    elif isinstance(arg,escript.Data):
        return arg.L2()
    else:
        # module name was misspelled ("numarry"), raising NameError
        return numarray.sqrt(dot(arg,arg))
732 |
|
733 |
def Lsup(arg):
    """
    Returns the supremum norm (maximum absolute value) of arg.

    For a Python C{int}/C{float} this is the absolute value.
    L{escript.Data} objects use C{Lsup()}; anything else uses
    C{numarray.abs(arg).max()}.

    @param arg: argument
    """
    if isinstance(arg,(float,int)):
        return abs(arg)
    if isinstance(arg,escript.Data):
        return arg.Lsup()
    return numarray.abs(arg).max()
743 |
|
744 |
def dot(arg0,arg1):
    """
    Returns the dot product of arg0 and arg1.

    If arg0 is an L{escript.Data} object, C{arg0.dot(arg1)} is used.
    Otherwise, if arg1 is one, C{arg1.dot(arg0)} is used (note the swapped
    argument order). In all other cases C{numarray.dot} is called.

    @param arg0: first argument
    @param arg1: second argument
    """
    if isinstance(arg0,escript.Data):
        return arg0.dot(arg1)
    if isinstance(arg1,escript.Data):
        return arg1.dot(arg0)
    return numarray.dot(arg0,arg1)
755 |
|
756 |
def kronecker(d):
    """
    Returns the Kronecker delta, i.e. the identity matrix, as a float numarray array.

    @param d: dimension, either an integer or an object providing C{getDim()}
    """
    dim=d.getDim() if hasattr(d,"getDim") else d
    return numarray.identity(dim)*1.
761 |
|
762 |
def unit(i,d):
    """
    Returns the unit vector of dimension d whose only nonzero entry (1.0) is at index i.

    @param i: index of the nonzero entry
    @param d: dimension
    """
    vec=numarray.zeros((d,),numarray.Float)
    vec[i]=1.0
    return vec