1 |
# $Id$ |
2 |
|
3 |
## @file util.py |
4 |
|
5 |
""" |
6 |
Utility functions for escript |
7 |
|
8 |
@todo: |
9 |
|
10 |
- binary operations @ (@=+,-,*,/,**):: |
11 |
(a@b)[:,*]=a[:]@b[:,*] if rank(a)<rank(b) |
12 |
(a@b)[:]=a[:]@b[:] if rank(a)=rank(b) |
13 |
(a@b)[*,:]=a[*,:]@b[:] if rank(a)>rank(b) |
14 |
- implementation of outer:: |
15 |
outer(a,b)[:,*]=a[:]*b[*] |
16 |
- trace:: |
17 |
trace(arg,axis0=a0,axis1=a1)(:,&,*)=sum_i arg(:,i,&,i,*) (i is at index a0 and a1)
18 |
""" |
19 |
|
20 |
import numarray |
21 |
import escript |
22 |
import symbols |
23 |
import os |
24 |
|
25 |
#========================================================= |
26 |
# some little helpers |
27 |
#========================================================= |
28 |
def _testForZero(arg): |
29 |
""" |
30 |
Returns True is arg is considered to be zero. |
31 |
""" |
32 |
if isinstance(arg,int): |
33 |
return not arg>0 |
34 |
elif isinstance(arg,float): |
35 |
return not arg>0. |
36 |
elif isinstance(arg,numarray.NumArray): |
37 |
a=abs(arg) |
38 |
while isinstance(a,numarray.NumArray): a=numarray.sometrue(a) |
39 |
return not a>0 |
40 |
else: |
41 |
return False |
42 |
|
43 |
#========================================================= |
44 |
def saveVTK(filename,domain=None,**data):
    """
    Writes L{Data} objects into a file using the VTK XML file format.

    Example::

        tmp=Scalar(..)
        v=Vector(..)
        saveVTK("solution.xml",temperature=tmp,velocity=v)

    tmp and v are written into "solution.xml" where tmp is named "temperature"
    and v is named "velocity".

    @param filename: file name of the output file
    @type filename: C{str}
    @param domain: domain of the L{Data} objects. If not specified, the domain
                   of the given (non-empty) L{Data} objects is used.
    @type domain: L{escript.Domain}
    @keyword <name>: writes the assigned value to the VTK file using <name> as identifier.
    @type <name>: L{Data} object.
    @raise ValueError: if no domain is given and none can be taken from the
                       non-empty L{Data} objects.
    @note: The data objects have to be defined on the same domain. They may not be in the same L{FunctionSpace} but one cannot expect that all L{FunctionSpace} can be mixed. Typically, data on the boundary and data on the interior cannot be mixed.
    """
    if domain is None:
        # empty Data objects are skipped; the last non-empty one supplies the domain
        for name in data.keys():
            if not data[name].isEmpty(): domain=data[name].getFunctionSpace().getDomain()
    if domain is None:
        raise ValueError("no domain detected.")
    domain.saveVTK(filename,data)
71 |
#========================================================= |
72 |
def saveDX(filename,domain=None,**data):
    """
    Writes L{Data} objects into a file using the DX file format.

    Example::

        tmp=Scalar(..)
        v=Vector(..)
        saveDX("solution.dx",temperature=tmp,velocity=v)

    tmp and v are written into "solution.dx" where tmp is named "temperature"
    and v is named "velocity".

    @param filename: file name of the output file
    @type filename: C{str}
    @param domain: domain of the L{Data} objects. If not specified, the domain
                   of the given (non-empty) L{Data} objects is used.
    @type domain: L{escript.Domain}
    @keyword <name>: writes the assigned value to the DX file using <name> as identifier. The identifier can be used to select the data set when data are imported into DX.
    @type <name>: L{Data} object.
    @raise ValueError: if no domain is given and none can be taken from the
                       non-empty L{Data} objects.
    @note: The data objects have to be defined on the same domain. They may not be in the same L{FunctionSpace} but one cannot expect that all L{FunctionSpace} can be mixed. Typically, data on the boundary and data on the interior cannot be mixed.
    """
    if domain is None:
        # empty Data objects are skipped; the last non-empty one supplies the domain
        for name in data.keys():
            if not data[name].isEmpty(): domain=data[name].getFunctionSpace().getDomain()
    if domain is None:
        raise ValueError("no domain detected.")
    domain.saveDX(filename,data)
99 |
#========================================================= |
100 |
|
101 |
def exp(arg):
    """
    Returns the exponential of arg.

    @param arg: argument. A L{symbols.Symbol} gives a symbolic expression, an
                object with an C{exp} method has that method called, anything
                else is passed to C{numarray.exp}.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Exp_Symbol(arg)
    if hasattr(arg,"exp"):
        return arg.exp()
    return numarray.exp(arg)
113 |
|
114 |
def sqrt(arg):
    """
    Returns the square root of arg.

    @param arg: argument. A L{symbols.Symbol} gives a symbolic expression, an
                object with a C{sqrt} method has that method called, anything
                else is passed to C{numarray.sqrt}.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Sqrt_Symbol(arg)
    if hasattr(arg,"sqrt"):
        return arg.sqrt()
    return numarray.sqrt(arg)
126 |
|
127 |
def log(arg):
    """
    Returns the logarithm to base 10 of arg.

    @param arg: argument. A L{symbols.Symbol} gives a symbolic expression, an
                object with a C{log} method has that method called, anything
                else is passed to C{numarray.log10}.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Log_Symbol(arg)
    if hasattr(arg,"log"):
        return arg.log()
    return numarray.log10(arg)
139 |
|
140 |
def ln(arg):
    """
    Returns the natural logarithm of arg.

    @param arg: argument. A L{symbols.Symbol} gives a symbolic expression, an
                object with an C{ln} method has that method called, anything
                else is passed to C{numarray.log}.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Ln_Symbol(arg)
    if hasattr(arg,"ln"):
        return arg.ln()
    return numarray.log(arg)
152 |
|
153 |
def sin(arg):
    """
    Returns the sine of arg.

    @param arg: argument. A L{symbols.Symbol} gives a symbolic expression, an
                object with a C{sin} method has that method called, anything
                else is passed to C{numarray.sin}.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Sin_Symbol(arg)
    if hasattr(arg,"sin"):
        return arg.sin()
    return numarray.sin(arg)
165 |
|
166 |
def cos(arg):
    """
    Returns the cosine of arg.

    @param arg: argument. A L{symbols.Symbol} gives a symbolic expression, an
                object with a C{cos} method has that method called, anything
                else is passed to C{numarray.cos}.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Cos_Symbol(arg)
    if hasattr(arg,"cos"):
        return arg.cos()
    return numarray.cos(arg)
178 |
|
179 |
def tan(arg):
    """
    Returns the tangent of arg.

    @param arg: argument. A L{symbols.Symbol} gives a symbolic expression, an
                object with a C{tan} method has that method called, anything
                else is passed to C{numarray.tan}.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Tan_Symbol(arg)
    if hasattr(arg,"tan"):
        return arg.tan()
    return numarray.tan(arg)
191 |
|
192 |
def asin(arg):
    """
    Returns the inverse sine (arcsine) of arg.

    @param arg: argument. A L{symbols.Symbol} gives a symbolic expression, an
                object with an C{asin} method has that method called, anything
                else is passed to C{numarray.arcsin}.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Asin_Symbol(arg)
    if hasattr(arg,"asin"):
        return arg.asin()
    # numarray follows the Numeric naming for this ufunc: arcsin, not asin
    return numarray.arcsin(arg)
204 |
|
205 |
def acos(arg):
    """
    Returns the inverse cosine (arccosine) of arg.

    @param arg: argument. A L{symbols.Symbol} gives a symbolic expression, an
                object with an C{acos} method has that method called, anything
                else is passed to C{numarray.arccos}.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Acos_Symbol(arg)
    if hasattr(arg,"acos"):
        return arg.acos()
    # numarray follows the Numeric naming for this ufunc: arccos, not acos
    return numarray.arccos(arg)
217 |
|
218 |
def atan(arg):
    """
    Returns the inverse tangent (arctangent) of arg.

    @param arg: argument. A L{symbols.Symbol} gives a symbolic expression, an
                object with an C{atan} method has that method called, anything
                else is passed to C{numarray.arctan}.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Atan_Symbol(arg)
    if hasattr(arg,"atan"):
        return arg.atan()
    # numarray follows the Numeric naming for this ufunc: arctan, not atan
    return numarray.arctan(arg)
230 |
|
231 |
def sinh(arg):
    """
    Returns the hyperbolic sine of arg.

    @param arg: argument. A L{symbols.Symbol} gives a symbolic expression, an
                object with a C{sinh} method has that method called, anything
                else is passed to C{numarray.sinh}.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Sinh_Symbol(arg)
    if hasattr(arg,"sinh"):
        return arg.sinh()
    return numarray.sinh(arg)
243 |
|
244 |
def cosh(arg):
    """
    Returns the hyperbolic cosine of arg.

    @param arg: argument. A L{symbols.Symbol} gives a symbolic expression, an
                object with a C{cosh} method has that method called, anything
                else is passed to C{numarray.cosh}.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Cosh_Symbol(arg)
    if hasattr(arg,"cosh"):
        return arg.cosh()
    return numarray.cosh(arg)
256 |
|
257 |
def tanh(arg):
    """
    Returns the hyperbolic tangent of arg.

    @param arg: argument. A L{symbols.Symbol} gives a symbolic expression, an
                object with a C{tanh} method has that method called, anything
                else is passed to C{numarray.tanh}.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Tanh_Symbol(arg)
    if hasattr(arg,"tanh"):
        return arg.tanh()
    return numarray.tanh(arg)
269 |
|
270 |
def asinh(arg):
    """
    Returns the inverse hyperbolic sine of arg.

    @param arg: argument. A L{symbols.Symbol} gives a symbolic expression, an
                object with an C{asinh} method has that method called, anything
                else is passed to C{numarray.arcsinh}.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Asinh_Symbol(arg)
    if hasattr(arg,"asinh"):
        return arg.asinh()
    # numarray follows the Numeric naming for this ufunc: arcsinh, not asinh
    return numarray.arcsinh(arg)
282 |
|
283 |
def acosh(arg):
    """
    Returns the inverse hyperbolic cosine of arg.

    @param arg: argument. A L{symbols.Symbol} gives a symbolic expression, an
                object with an C{acosh} method has that method called, anything
                else is passed to C{numarray.arccosh}.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Acosh_Symbol(arg)
    if hasattr(arg,"acosh"):
        return arg.acosh()
    # numarray follows the Numeric naming for this ufunc: arccosh, not acosh
    return numarray.arccosh(arg)
295 |
|
296 |
def atanh(arg):
    """
    Returns the inverse hyperbolic tangent of arg.

    @param arg: argument. A L{symbols.Symbol} gives a symbolic expression, an
                object with an C{atanh} method has that method called, anything
                else is passed to C{numarray.arctanh}.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Atanh_Symbol(arg)
    if hasattr(arg,"atanh"):
        return arg.atanh()
    # numarray follows the Numeric naming for this ufunc: arctanh, not atanh
    return numarray.arctanh(arg)
308 |
|
309 |
def sign(arg):
    """
    Applies the sign function to arg.

    @param arg: argument
    @return: for a Python int or float the float -1., 0. or 1.; for a
             L{symbols.Symbol} a symbolic expression; for an object with a
             C{sign} method the result of that method; otherwise the
             elementwise difference greater(arg,0)-less(arg,0).
    """
    if isinstance(arg,(int,float)):
        # plain numbers have no shape attribute, so decide them directly
        return float((arg>0)-(arg<0))
    if isinstance(arg,symbols.Symbol):
        return symbols.Sign_Symbol(arg)
    elif hasattr(arg,"sign"):
        return arg.sign()
    else:
        zero=numarray.zeros(arg.shape,numarray.Float)
        return numarray.greater(arg,zero)-numarray.less(arg,zero)
322 |
|
323 |
def maxval(arg):
    """
    Returns the maximum value of arg.

    @param arg: argument. A L{symbols.Symbol} gives a symbolic maximum. An
                object with a C{maxval} method, or failing that a C{max}
                method, has it called. Anything else is returned unchanged.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Max_Symbol(arg)
    for method_name in ("maxval","max"):
        if hasattr(arg,method_name):
            return getattr(arg,method_name)()
    return arg
337 |
|
338 |
def minval(arg):
    """
    Returns the minimum value of arg.

    @param arg: argument. A L{symbols.Symbol} gives a symbolic minimum. An
                object with a C{minval} method, or failing that a C{min}
                method, has it called. Anything else is returned unchanged.
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Min_Symbol(arg)
    elif hasattr(arg,"minval"):
        # test for the method that is actually called
        return arg.minval()
    elif hasattr(arg,"min"):
        return arg.min()
    else:
        return arg
352 |
|
353 |
def wherePositive(arg):
    """
    Returns a mask which is 1 where arg is positive and 0 elsewhere.

    @param arg: argument
    @return: for a Python int or float the float 1. or 0.; for a zero
             numarray array 0; for a L{symbols.Symbol} a symbolic expression;
             for an object with a C{wherePositive} method the result of that
             method; for other objects with a C{shape} the elementwise
             comparison arg>0.
    """
    if isinstance(arg,(int,float)):
        if arg>0:
            return 1.
        else:
            return 0.
    if _testForZero(arg):
        return 0
    elif isinstance(arg,symbols.Symbol):
        return symbols.WherePositive_Symbol(arg)
    elif hasattr(arg,"wherePositive"):
        return arg.wherePositive()
    elif hasattr(arg,"shape"):
        return numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))
    else:
        if arg>0:
            return 1.
        else:
            return 0.
372 |
|
373 |
def whereNegative(arg):
    """
    Returns a mask which is 1 where arg is negative and 0 elsewhere.

    @param arg: argument
    @return: for a Python int or float the float 1. or 0.; for a zero
             numarray array 0; for a L{symbols.Symbol} a symbolic expression;
             for an object with a C{whereNegative} method the result of that
             method; for other objects with a C{shape} the elementwise
             comparison arg<0.
    """
    if isinstance(arg,(int,float)):
        # decide plain numbers directly so that negative values are never
        # mistaken for zero
        if arg<0:
            return 1.
        else:
            return 0.
    if _testForZero(arg):
        return 0
    elif isinstance(arg,symbols.Symbol):
        return symbols.WhereNegative_Symbol(arg)
    elif hasattr(arg,"whereNegative"):
        return arg.whereNegative()
    elif hasattr(arg,"shape"):
        return numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))
    else:
        if arg<0:
            return 1.
        else:
            return 0.
392 |
|
393 |
def maximum(arg0,arg1):
    """
    Returns arg1 where arg1 is bigger than arg0, and arg0 otherwise.

    The choice is made arithmetically through the mask
    whereNegative(arg0-arg1), so it is applied elementwise wherever that mask
    and the arithmetic operators are.

    @param arg0: first argument
    @param arg1: second argument
    """
    difference=arg0-arg1
    arg1_bigger=whereNegative(difference)
    return arg1_bigger*arg1+(1.-arg1_bigger)*arg0
399 |
|
400 |
def minimum(arg0,arg1):
    """
    Returns arg0 where arg1 is bigger than arg0, and arg1 otherwise.

    The choice is made arithmetically through the mask
    whereNegative(arg0-arg1), so it is applied elementwise wherever that mask
    and the arithmetic operators are.

    @param arg0: first argument
    @param arg1: second argument
    """
    difference=arg0-arg1
    arg1_bigger=whereNegative(difference)
    return arg1_bigger*arg0+(1.-arg1_bigger)*arg1
406 |
|
407 |
def outer(arg0,arg1):
    """
    Returns the outer product of arg0 and arg1.

    If one argument is zero, 0 is returned. If one argument has shape C{()},
    the plain product is returned. Two numarray arrays are handled by
    C{numarray.outer}. Otherwise two rank-1 arguments give the rank-2
    L{escript.Data} object out with C{out[i,j]=arg0[i]*arg1[j]}.

    @param arg0: first argument
    @param arg1: second argument
    @raise ValueError: if the L{escript.Data} path is reached with arguments
                       that are not both of rank 1
    """
    if _testForZero(arg0) or _testForZero(arg1):
        return 0
    if isinstance(arg0,symbols.Symbol) or isinstance(arg1,symbols.Symbol):
        return symbols.Outer_Symbol(arg0,arg1)
    # NOTE(review): _identifyShape is not defined in the visible part of this
    # module -- confirm it is provided elsewhere.
    if _identifyShape(arg0)==() or _identifyShape(arg1)==():
        return arg0*arg1
    if isinstance(arg0,numarray.NumArray) and isinstance(arg1,numarray.NumArray):
        return numarray.outer(arg0,arg1)
    # mixed arguments: promote the non-Data argument onto the function space
    # of the Data argument so that getRank/getShape/indexing work on both
    if isinstance(arg1,escript.Data) and not isinstance(arg0,escript.Data):
        arg0=escript.Data(arg0,arg1.getFunctionSpace())
    elif isinstance(arg0,escript.Data) and not isinstance(arg1,escript.Data):
        arg1=escript.Data(arg1,arg0.getFunctionSpace())
    if arg0.getRank()==1 and arg1.getRank()==1:
        n0=arg0.getShape()[0]
        n1=arg1.getShape()[0]
        out=escript.Data(0,(n0,n1),arg1.getFunctionSpace())
        for i in range(n0):
            for j in range(n1):
                out[i,j]=arg0[i]*arg1[j]
        return out
    raise ValueError("outer is not fully implemented yet.")
426 |
|
427 |
def interpolate(arg,where):
    """
    Interpolates arg onto the FunctionSpace where.

    @param arg: interpolant
    @param where: FunctionSpace to interpolate to
    @return: 0 if arg is zero, a symbolic interpolation if arg is a
             L{symbols.Symbol}, otherwise C{escript.Data(arg,where)}
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg,symbols.Symbol):
        return symbols.Interpolated_Symbol(arg,where)
    return escript.Data(arg,where)
440 |
|
441 |
def div(arg,where=None):
    """
    Returns the divergence of arg at where, computed as the trace of the
    gradient of arg.

    @param arg: Data object representing the function whose divergence is
                to be calculated.
    @param where: FunctionSpace in which the gradient will be calculated.
                  If not present or C{None} an appropriate default is used.
    """
    gradient=grad(arg,where)
    return trace(gradient)
451 |
|
452 |
def jump(arg):
    """
    Returns the jump of arg across a contact, i.e. arg interpolated to
    contact side one minus arg interpolated to contact side zero.

    @param arg: Data object representing the function whose jump is to be
                calculated.
    """
    d=arg.getDomain()
    # the contact function spaces are constructed on the domain of arg
    return arg.interpolate(escript.FunctionOnContactOne(d))-arg.interpolate(escript.FunctionOnContactZero(d))
461 |
|
462 |
|
463 |
def grad(arg,where=None):
    """
    Returns the spatial gradient of arg at where.

    @param arg: Data object representing the function whose gradient is to
                be calculated.
    @param where: FunctionSpace in which the gradient will be calculated.
                  If not present or C{None} an appropriate default is used.
    @return: 0 for a zero argument, a symbolic gradient for a
             L{symbols.Symbol}, C{arg.grad(...)} if available, and C{arg*0.}
             for anything without a C{grad} method
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg,symbols.Symbol):
        return symbols.Grad_Symbol(arg,where)
    if not hasattr(arg,"grad"):
        # objects without a grad method are treated as constants
        return arg*0.
    if where is None:
        return arg.grad()
    return arg.grad(where)
483 |
|
484 |
def integrate(arg,where=None):
    """
    Returns the integral of the function represented by arg over its domain.

    @param arg: Data object representing the function which is integrated.
    @param where: FunctionSpace in which the integral is calculated.
                  If not present or C{None} an appropriate default is used.
    @return: 0 for a zero argument, a symbolic integral for a
             L{symbols.Symbol}, the first entry of C{arg.integrate()} for a
             rank-0 argument, and C{arg.integrate()} otherwise
    """
    if _testForZero(arg):
        return 0
    if isinstance(arg,symbols.Symbol):
        return symbols.Integral_Symbol(arg,where)
    if where is not None:
        arg=escript.Data(arg,where)
    is_scalar=arg.getRank()==0
    integral=arg.integrate()
    if is_scalar:
        return integral[0]
    return integral
503 |
|
504 |
#============================= |
505 |
# |
506 |
# wrappers for various functions: if the argument has a method with the
# function's name, that method is called; otherwise the corresponding
# numarray function is called.
509 |
|
510 |
# functions involving the underlying Domain: |
511 |
|
512 |
|
513 |
# functions returning Data objects: |
514 |
|
515 |
def transpose(arg,axis=None):
    """
    Returns the transpose of arg.

    @param arg: argument: a L{symbols.Symbol}, a rank-2 L{escript.Data}
                object, or anything accepted by C{numarray.transpose}
    @param axis: axis parameter passed on to the symbolic or numarray
                 transpose. If C{None}, it defaults to half the rank of arg
                 (integer division). It is not used for L{escript.Data}
                 arguments, whose two indices are simply swapped.
    @raise ValueError: if arg is an L{escript.Data} object whose rank is not 2
    """
    if axis is None:
        r=0
        if hasattr(arg,"getRank"): r=arg.getRank()
        if hasattr(arg,"rank"): r=arg.rank
        # integer division: same as r/2 for ints under Python 2
        axis=r//2
    if isinstance(arg,symbols.Symbol):
        # pass the requested axis (not the rank, which is only known when
        # axis was defaulted)
        return symbols.Transpose_Symbol(arg,axis=axis)
    if isinstance(arg,escript.Data):
        # hack for transpose: only rank 2 is supported, indices are swapped
        if arg.getRank()!=2: raise ValueError("Tranpose only avalaible for rank 2 objects")
        s=arg.getShape()
        out=escript.Data(0.,(s[1],s[0]),arg.getFunctionSpace())
        for i in range(s[0]):
            for j in range(s[1]):
                out[j,i]=arg[i,j]
        return out
    # NOTE(review): numarray.transpose presumably expects a permutation
    # ("axes") rather than an "axis" keyword -- confirm this call.
    return numarray.transpose(arg,axis=axis)
542 |
|
543 |
def trace(arg,axis0=0,axis1=1):
    """
    Returns the trace of arg with respect to the axes axis0 and axis1, i.e.
    the sum over the entries whose indices at axis0 and axis1 coincide.

    @param arg: argument: a L{symbols.Symbol}, a rank-2 L{escript.Data}
                object, or a numarray array
    @param axis0: first axis of the trace
    @param axis1: second axis of the trace
    @raise ValueError: if arg is an L{escript.Data} object whose rank is not 2
                       or whose dimensions at axis0 and axis1 differ
    """
    if isinstance(arg,symbols.Symbol):
        return symbols.Trace_Symbol(arg,axis0=axis0,axis1=axis1)
    elif isinstance(arg,escript.Data):
        # hack for trace: the summation arg[i,i] below only covers rank 2
        if arg.getRank()!=2:
            raise ValueError("trace is only available for rank 2 Data objects")
        s=arg.getShape()
        if s[axis0]!=s[axis1]:
            raise ValueError("illegal axis in trace")
        out=escript.Scalar(0.,arg.getFunctionSpace())
        for i in range(s[axis0]):
            out+=arg[i,i]
        return out
    else:
        # numarray.trace takes (a, offset, axis1, axis2) and has no axis0
        # keyword, so offset 0 and the two axes are passed positionally
        return numarray.trace(arg,0,axis0,axis1)
565 |
|
566 |
def length(arg):
    """
    Returns the length of arg: the square root of the sum of the squares of
    all its components.

    @param arg: a Python int or float, an L{escript.Data} object of rank 0 to
                4, or an object supporting C{**} and C{sum()} (e.g. a
                numarray array)
    @return: a number for Python numbers and arrays; for L{escript.Data} a
             scalar L{escript.Data} object (an empty one if arg is empty)
    @raise SystemError: if arg is an L{escript.Data} object of rank > 4
    """
    if isinstance(arg,(int,float)):
        # plain numbers have no sum() method
        return abs(arg)
    elif isinstance(arg,escript.Data):
        if arg.isEmpty(): return escript.Data()
        r=arg.getRank()
        if r==0:
            return abs(arg)
        if r>4:
            raise SystemError("length is not been fully implemented yet")
        s=arg.getShape()
        out=escript.Scalar(0,arg.getFunctionSpace())
        if r==1:
            for i in range(s[0]):
                out+=arg[i]**2
        elif r==2:
            for i in range(s[0]):
                for j in range(s[1]):
                    out+=arg[i,j]**2
        elif r==3:
            for i in range(s[0]):
                for j in range(s[1]):
                    for k in range(s[2]):
                        out+=arg[i,j,k]**2
        else:
            for i in range(s[0]):
                for j in range(s[1]):
                    for k in range(s[2]):
                        for l in range(s[3]):
                            out+=arg[i,j,k,l]**2
        return sqrt(out)
    else:
        return sqrt((arg**2).sum())
608 |
|
609 |
def deviator(arg):
    """
    Returns the deviatoric part of the square rank-2 argument arg:
    arg minus trace(arg)/d times the d x d Kronecker matrix.

    @param arg: square rank-2 L{escript.Data} object, or an object with a
                C{shape} attribute
    @raise ValueError: if arg is not of rank 2 or is not square
    """
    if isinstance(arg,escript.Data):
        shape=arg.getShape()
    else:
        shape=arg.shape
    if not len(shape)==2:
        raise ValueError("Deviator requires rank 2 object")
    dim=shape[0]
    if not dim==shape[1]:
        raise ValueError("Deviator requires a square matrix")
    return arg-1./(dim*1.)*trace(arg)*kronecker(dim)
622 |
|
623 |
def inner(arg0,arg1):
    """
    Returns the inner product of arg0 and arg1: the sum of
    arg0[idx]*arg1[idx] over all indices idx.

    @param arg0: first argument
    @param arg1: second argument, expected to have the same shape as arg0
    @return: a scalar L{escript.Data} object if one argument is an
             L{escript.Data} object of rank 1 to 4; arg0*arg1 for rank 0;
             otherwise the summed elementwise product
    @raise SystemError: if the L{escript.Data} argument has rank > 4
    """
    if isinstance(arg0,escript.Data):
        arg=arg0
    elif isinstance(arg1,escript.Data):
        arg=arg1
    else:
        # no Data object involved, so there is no function space for the
        # result: contract the elementwise product directly
        prod=arg0*arg1
        if hasattr(prod,"sum"):
            return prod.sum()
        return prod
    r=arg.getRank()
    if r==0:
        return arg0*arg1
    if r>4:
        raise SystemError("inner is not been implemented yet")
    s=arg.getShape()
    out=escript.Scalar(0,arg.getFunctionSpace())
    if r==1:
        for i in range(s[0]):
            out+=arg0[i]*arg1[i]
    elif r==2:
        for i in range(s[0]):
            for j in range(s[1]):
                out+=arg0[i,j]*arg1[i,j]
    elif r==3:
        for i in range(s[0]):
            for j in range(s[1]):
                for k in range(s[2]):
                    out+=arg0[i,j,k]*arg1[i,j,k]
    else:
        for i in range(s[0]):
            for j in range(s[1]):
                for k in range(s[2]):
                    for l in range(s[3]):
                        out+=arg0[i,j,k,l]*arg1[i,j,k,l]
    return out
661 |
|
662 |
def tensormult(arg0,arg1):
    """
    Tensor product of arg0 and arg1 -- not implemented yet.

    @param arg0: first argument (unused)
    @param arg1: second argument (unused)
    @raise SystemError: always
    """
    # TODO: check LinearPDE for the intended semantics
    message="tensormult is not implemented yet!"
    raise SystemError(message)
665 |
|
666 |
def matrixmult(arg0,arg1):
    """
    Returns the matrix product of arg0 and arg1.

    Two numarray arrays are multiplied with C{numarray.dot}. Otherwise a
    non-Data argument is first converted to L{escript.Data} on the function
    space of the other argument. A rank-2 times a rank-1 argument then gives
    the matrix-vector product, and two rank-1 arguments give their inner
    product.

    @param arg0: left factor
    @param arg1: right factor
    @raise SystemError: for other rank combinations
    """
    if isinstance(arg1,numarray.NumArray) and isinstance(arg0,numarray.NumArray):
        # the product must be returned; numarray.dot is the matrix product
        # for rank-2 arrays
        return numarray.dot(arg0,arg1)
    if isinstance(arg1,escript.Data) and not isinstance(arg0,escript.Data):
        arg0=escript.Data(arg0,arg1.getFunctionSpace())
    elif isinstance(arg0,escript.Data) and not isinstance(arg1,escript.Data):
        arg1=escript.Data(arg1,arg0.getFunctionSpace())
    if arg0.getRank()==2 and arg1.getRank()==1:
        n0=arg0.getShape()[0]
        n1=arg0.getShape()[1]
        out=escript.Data(0,(n0,),arg0.getFunctionSpace())
        for i in range(n0):
            for j in range(n1):
                # uses Data object slicing, plus Data * and += operators
                out[i]+=arg0[i,j]*arg1[j]
        return out
    elif arg0.getRank()==1 and arg1.getRank()==1:
        return inner(arg0,arg1)
    else:
        raise SystemError("matrixmult is not fully implemented yet!")
687 |
|
688 |
#========================================================= |
689 |
# reduction operations: |
690 |
#========================================================= |
691 |
def sum(arg):
    """
    Returns the result of the C{sum} method of arg.

    @param arg: object providing a C{sum()} method
    """
    total=arg.sum()
    return total
696 |
|
697 |
def sup(arg):
    """
    Returns the supremum of arg.

    @param arg: a Python int or float (returned as is), an L{escript.Data}
                object (its C{sup} method is used), or an object with a
                C{max} method
    """
    if isinstance(arg,(float,int)):
        return arg
    if isinstance(arg,escript.Data):
        return arg.sup()
    return arg.max()
707 |
|
708 |
def inf(arg):
    """
    Returns the infimum of arg.

    @param arg: a Python int or float (returned as is), an L{escript.Data}
                object (its C{inf} method is used), or an object with a
                C{min} method
    """
    if isinstance(arg,(float,int)):
        return arg
    if isinstance(arg,escript.Data):
        return arg.inf()
    return arg.min()
718 |
|
719 |
def L2(arg):
    """
    Returns the L2-norm of arg.

    @param arg: a Python int or float, an L{escript.Data} object (its C{L2}
                method is used), or an array supporting elementwise C{*} and
                C{sum()}
    @return: for numbers the absolute value; for arrays the square root of
             the sum of the squares of all entries
    """
    if isinstance(arg,(float,int)):
        return abs(arg)
    if isinstance(arg,escript.Data):
        return arg.L2()
    # sum over all entries so that arrays of any rank give a scalar norm
    return numarray.sqrt((arg*arg).sum())
731 |
|
732 |
def Lsup(arg):
    """
    Returns the supremum norm (largest absolute value) of arg.

    @param arg: a Python int or float, an L{escript.Data} object (its
                C{Lsup} method is used), or anything accepted by
                C{numarray.abs}
    """
    if isinstance(arg,(float,int)):
        return abs(arg)
    if isinstance(arg,escript.Data):
        return arg.Lsup()
    magnitudes=numarray.abs(arg)
    return magnitudes.max()
742 |
|
743 |
def dot(arg0,arg1):
    """
    Returns the dot product of arg0 and arg1.

    If arg0 is an L{escript.Data} object, C{arg0.dot(arg1)} is returned. Else,
    if arg1 is one, C{arg1.dot(arg0)} is returned. Otherwise C{numarray.dot}
    is applied.

    @param arg0: first argument
    @param arg1: second argument
    """
    if isinstance(arg0,escript.Data):
        return arg0.dot(arg1)
    if isinstance(arg1,escript.Data):
        # NOTE(review): operands are swapped here; this matches
        # arg0.dot(arg1) only if Data.dot is symmetric -- confirm for
        # matrix arguments.
        return arg1.dot(arg0)
    return numarray.dot(arg0,arg1)
754 |
|
755 |
def kronecker(d):
    """
    Returns the Kronecker delta as a float identity matrix.

    @param d: dimension, or an object with a C{getDim} method whose result
              is used as the dimension
    """
    if hasattr(d,"getDim"):
        dim=d.getDim()
    else:
        dim=d
    return numarray.identity(dim)*1.
760 |
|
761 |
def unit(i,d):
    """
    Returns the unit vector of dimension d whose only nonzero entry, 1.0,
    is at index i.

    @param i: index of the nonzero entry
    @param d: dimension
    """
    vec=numarray.zeros((d,),numarray.Float)
    vec[i]=1.0
    return vec