1 |
# $Id$ |
2 |
|
3 |
""" |
4 |
symbolic tool box for escript |
5 |
""" |
6 |
|
7 |
import numarray |
8 |
|
9 |
#=========================================================== |
10 |
# a simple toolbox to deal with differentials of functions
11 |
#=========================================================== |
12 |
|
13 |
class Symbol: |
14 |
""" |
15 |
Symbol class. |
16 |
""" |
17 |
def __init__(self,name="symbol",shape=(),dim=3,args=[]): |
18 |
""" |
19 |
Creates an instance of a symbol of shape shape and spatial dimension |
20 |
dim. |
21 |
|
22 |
The symbol may depending on a list of arguments args which may be |
23 |
symbols or other objects. name gives the name of the symbol. |
24 |
""" |
25 |
|
26 |
self.__args=args |
27 |
self.__name=name |
28 |
self.__shape=shape |
29 |
if hasattr(dim,"getDim"): |
30 |
self.__dim=dim.getDim() |
31 |
else: |
32 |
self.__dim=dim |
33 |
# |
34 |
self.__cache_val=None |
35 |
self.__cache_argval=None |
36 |
|
37 |
def getArgument(self,i): |
38 |
""" |
39 |
Returns the i-th argument. |
40 |
""" |
41 |
return self.__args[i] |
42 |
|
43 |
def getDim(self): |
44 |
""" |
45 |
Returns the spatial dimension of the symbol. |
46 |
""" |
47 |
return self.__dim |
48 |
|
49 |
def getRank(self): |
50 |
""" |
51 |
Returns the rank of the symbol. |
52 |
""" |
53 |
return len(self.getShape()) |
54 |
|
55 |
def getShape(self): |
56 |
""" |
57 |
Returns the shape of the symbol. |
58 |
""" |
59 |
return self.__shape |
60 |
|
61 |
def getEvaluatedArguments(self,argval): |
62 |
""" |
63 |
Returns the list of evaluated arguments by subsituting symbol u by |
64 |
argval[u]. |
65 |
""" |
66 |
if argval==self.__cache_argval: |
67 |
print "%s: cached value used"%self |
68 |
return self.__cache_val |
69 |
else: |
70 |
out=[] |
71 |
for a in self.__args: |
72 |
if isinstance(a,Symbol): |
73 |
out.append(a.eval(argval)) |
74 |
else: |
75 |
out.append(a) |
76 |
self.__cache_argval=argval |
77 |
self.__cache_val=out |
78 |
return out |
79 |
|
80 |
def getDifferentiatedArguments(self,arg): |
81 |
""" |
82 |
Returns the list of the arguments _differentiated by arg. |
83 |
""" |
84 |
out=[] |
85 |
for a in self.__args: |
86 |
if isinstance(a,Symbol): |
87 |
out.append(a.diff(arg)) |
88 |
else: |
89 |
out.append(0) |
90 |
return out |
91 |
|
92 |
def diff(self,arg): |
93 |
""" |
94 |
Returns the _differention of self by arg. |
95 |
""" |
96 |
if self==arg: |
97 |
out=numarray.zeros(tuple(2*list(self.getShape())),numarray.Float) |
98 |
if self.getRank()==0: |
99 |
out=1. |
100 |
elif self.getRank()==1: |
101 |
for i0 in range(self.getShape()[0]): |
102 |
out[i0,i0]=1. |
103 |
elif self.getRank()==2: |
104 |
for i0 in range(self.getShape()[0]): |
105 |
for i1 in range(self.getShape()[1]): |
106 |
out[i0,i1,i0,i1]=1. |
107 |
elif self.getRank()==3: |
108 |
for i0 in range(self.getShape()[0]): |
109 |
for i1 in range(self.getShape()[1]): |
110 |
for i2 in range(self.getShape()[2]): |
111 |
out[i0,i1,i2,i0,i1,i2]=1. |
112 |
elif self.getRank()==4: |
113 |
for i0 in range(self.getShape()[0]): |
114 |
for i1 in range(self.getShape()[1]): |
115 |
for i2 in range(self.getShape()[2]): |
116 |
for i3 in range(self.getShape()[3]): |
117 |
out[i0,i1,i2,i3,i0,i1,i2,i3]=1. |
118 |
else: |
119 |
raise ValueError,"differential support rank<5 only." |
120 |
return out |
121 |
else: |
122 |
return self._diff(arg) |
123 |
|
124 |
def _diff(self,arg): |
125 |
""" |
126 |
Return derivate of self with respect to arg (!=self). |
127 |
|
128 |
This method is overwritten by a particular symbol. |
129 |
""" |
130 |
return 0 |
131 |
|
132 |
def eval(self,argval): |
133 |
""" |
134 |
Subsitutes symbol u in self by argval[u] and returns the result. If |
135 |
self is not a key of argval then self is returned. |
136 |
""" |
137 |
if argval.has_key(self): |
138 |
return argval[self] |
139 |
else: |
140 |
return self |
141 |
|
142 |
def __str__(self): |
143 |
""" |
144 |
Returns a string representation of the symbol. |
145 |
""" |
146 |
return self.__name |
147 |
|
148 |
def __add__(self,other): |
149 |
""" |
150 |
Adds other to symbol self. if _testForZero(other) self is returned. |
151 |
""" |
152 |
if _testForZero(other): |
153 |
return self |
154 |
else: |
155 |
a=_matchShape([self,other]) |
156 |
return Add_Symbol(a[0],a[1]) |
157 |
|
158 |
def __radd__(self,other): |
159 |
""" |
160 |
Adds other to symbol self. if _testForZero(other) self is returned. |
161 |
""" |
162 |
return self+other |
163 |
|
164 |
def __neg__(self): |
165 |
""" |
166 |
Returns -self. |
167 |
""" |
168 |
return self*(-1.) |
169 |
|
170 |
def __pos__(self): |
171 |
""" |
172 |
Returns +self. |
173 |
""" |
174 |
return self |
175 |
|
176 |
def __abs__(self): |
177 |
""" |
178 |
Returns absolute value. |
179 |
""" |
180 |
return Abs_Symbol(self) |
181 |
|
182 |
def __sub__(self,other): |
183 |
""" |
184 |
Subtracts other from symbol self. |
185 |
|
186 |
If _testForZero(other) self is returned. |
187 |
""" |
188 |
if _testForZero(other): |
189 |
return self |
190 |
else: |
191 |
return self+(-other) |
192 |
|
193 |
def __rsub__(self,other): |
194 |
""" |
195 |
Subtracts symbol self from other. |
196 |
""" |
197 |
return -self+other |
198 |
|
199 |
def __div__(self,other): |
200 |
""" |
201 |
Divides symbol self by other. |
202 |
""" |
203 |
if isinstance(other,Symbol): |
204 |
a=_matchShape([self,other]) |
205 |
return Div_Symbol(a[0],a[1]) |
206 |
else: |
207 |
return self*(1./other) |
208 |
|
209 |
def __rdiv__(self,other): |
210 |
""" |
211 |
Dived other by symbol self. if _testForZero(other) 0 is returned. |
212 |
""" |
213 |
if _testForZero(other): |
214 |
return 0 |
215 |
else: |
216 |
a=_matchShape([self,other]) |
217 |
return Div_Symbol(a[0],a[1]) |
218 |
|
219 |
def __pow__(self,other): |
220 |
""" |
221 |
Raises symbol self to the power of other. |
222 |
""" |
223 |
a=_matchShape([self,other]) |
224 |
return Power_Symbol(a[0],a[1]) |
225 |
|
226 |
def __rpow__(self,other): |
227 |
""" |
228 |
Raises other to the symbol self. |
229 |
""" |
230 |
a=_matchShape([self,other]) |
231 |
return Power_Symbol(a[1],a[0]) |
232 |
|
233 |
def __mul__(self,other): |
234 |
""" |
235 |
Multiplies other by symbol self. if _testForZero(other) 0 is returned. |
236 |
""" |
237 |
if _testForZero(other): |
238 |
return 0 |
239 |
else: |
240 |
a=_matchShape([self,other]) |
241 |
return Mult_Symbol(a[0],a[1]) |
242 |
|
243 |
def __rmul__(self,other): |
244 |
""" |
245 |
Multiplies other by symbol self. if _testSForZero(other) 0 is returned. |
246 |
""" |
247 |
return self*other |
248 |
|
249 |
def __getitem__(self,sl): |
250 |
print sl |
251 |
|
252 |
class Float_Symbol(Symbol):
    """
    A symbol with spatial dimension 0.
    """
    def __init__(self,name="symbol",shape=(),args=None):
        """
        Creates a symbol of shape shape with spatial dimension 0.

        name gives the name of the symbol and args the list of arguments the
        symbol depends on. All three are passed on to Symbol (previously they
        were silently replaced by their defaults).
        """
        if args is None:
            args=[]
        Symbol.__init__(self,dim=0,name=name,shape=shape,args=args)
255 |
|
256 |
class ScalarSymbol(Symbol):
    """
    A scalar symbol, i.e. a symbol of shape ().
    """
    def __init__(self,dim=3,name="scalar"):
        """
        Creates a scalar symbol of spatial dimension dim.

        dim is an integer or an object whose getDim() result is used.
        """
        if hasattr(dim,"getDim"):
            dim=dim.getDim()
        Symbol.__init__(self,name=name,shape=(),dim=dim)
269 |
|
270 |
class VectorSymbol(Symbol):
    """
    A vector symbol, i.e. a symbol of shape (dim,).
    """
    def __init__(self,dim=3,name="vector"):
        """
        Creates a vector symbol of spatial dimension dim.

        dim is an integer or an object whose getDim() result is used.
        """
        if hasattr(dim,"getDim"):
            dim=dim.getDim()
        Symbol.__init__(self,name=name,shape=(dim,),dim=dim)
283 |
|
284 |
class TensorSymbol(Symbol):
    """
    A tensor symbol of order 2, i.e. a symbol of shape (dim,dim).
    """
    def __init__(self,dim=3,name="tensor"):
        """
        Creates a tensor symbol of spatial dimension dim.

        dim is an integer or an object whose getDim() result is used.
        """
        if hasattr(dim,"getDim"):
            dim=dim.getDim()
        Symbol.__init__(self,name=name,shape=(dim,)*2,dim=dim)
297 |
|
298 |
class Tensor3Symbol(Symbol):
    """
    A tensor symbol of order 3, i.e. a symbol of shape (dim,dim,dim).
    """
    def __init__(self,dim=3,name="tensor3"):
        """
        Creates a tensor order 3 symbol of spatial dimension dim.

        dim is an integer or an object whose getDim() result is used.
        """
        if hasattr(dim,"getDim"):
            dim=dim.getDim()
        Symbol.__init__(self,name=name,shape=(dim,)*3,dim=dim)
311 |
|
312 |
class Tensor4Symbol(Symbol):
    """
    A tensor symbol of order 4, i.e. a symbol of shape (dim,dim,dim,dim).
    """
    def __init__(self,dim=3,name="tensor4"):
        """
        Creates a tensor order 4 symbol of spatial dimension dim.

        dim is an integer or an object whose getDim() result is used.
        """
        if hasattr(dim,"getDim"):
            dim=dim.getDim()
        Symbol.__init__(self,name=name,shape=(dim,)*4,dim=dim)
325 |
|
326 |
class Add_Symbol(Symbol):
    """
    Symbol representing the sum arg0+arg1 of two arguments.
    """
    def __init__(self,arg0,arg1):
        # both arguments must share one shape and one spatial dimension
        operands=[arg0,arg1]
        Symbol.__init__(self,dim=_extractDim(operands),
                        shape=_extractShape(operands),args=operands)
    def __str__(self):
        left=str(self.getArgument(0))
        right=str(self.getArgument(1))
        return "(%s+%s)"%(left,right)
    def eval(self,argval):
        x,y=self.getEvaluatedArguments(argval)
        return x+y
    def _diff(self,arg):
        # the derivative of a sum is the sum of the derivatives
        dx,dy=self.getDifferentiatedArguments(arg)
        return dx+dy
341 |
|
342 |
class Mult_Symbol(Symbol):
    """
    Symbol representing the product arg0*arg1 of two arguments.
    """
    def __init__(self,arg0,arg1):
        operands=[arg0,arg1]
        Symbol.__init__(self,dim=_extractDim(operands),
                        shape=_extractShape(operands),args=operands)
    def __str__(self):
        left=str(self.getArgument(0))
        right=str(self.getArgument(1))
        return "(%s*%s)"%(left,right)
    def eval(self,argval):
        x,y=self.getEvaluatedArguments(argval)
        return x*y
    def _diff(self,arg):
        # product rule: (x*y)' = y*x' + x*y'
        dx,dy=self.getDifferentiatedArguments(arg)
        return self.getArgument(1)*dx+self.getArgument(0)*dy
357 |
|
358 |
class Div_Symbol(Symbol):
    """
    Symbol representing the quotient arg0/arg1 of two arguments.
    """
    def __init__(self,arg0,arg1):
        operands=[arg0,arg1]
        Symbol.__init__(self,dim=_extractDim(operands),
                        shape=_extractShape(operands),args=operands)
    def __str__(self):
        left=str(self.getArgument(0))
        right=str(self.getArgument(1))
        return "(%s/%s)"%(left,right)
    def eval(self,argval):
        x,y=self.getEvaluatedArguments(argval)
        return x/y
    def _diff(self,arg):
        # quotient rule: (x/y)' = (x'*y - x*y')/(y*y)
        dx,dy=self.getDifferentiatedArguments(arg)
        numerator=dx*self.getArgument(1)-self.getArgument(0)*dy
        return numerator/(self.getArgument(1)*self.getArgument(1))
374 |
|
375 |
class Power_Symbol(Symbol):
    """
    Symbol representing the first argument raised to the power of the
    second argument.
    """
    def __init__(self,arg0,arg1):
        operands=[arg0,arg1]
        Symbol.__init__(self,dim=_extractDim(operands),
                        shape=_extractShape(operands),args=operands)
    def __str__(self):
        base=str(self.getArgument(0))
        exponent=str(self.getArgument(1))
        return "(%s**%s)"%(base,exponent)
    def eval(self,argval):
        x,y=self.getEvaluatedArguments(argval)
        return x**y
    def _diff(self,arg):
        # (x**y)' = x**y * (y'*log(x) + y/x*x')
        dx,dy=self.getDifferentiatedArguments(arg)
        base=self.getArgument(0)
        exponent=self.getArgument(1)
        return self*(dy*log(base)+exponent/base*dx)
391 |
|
392 |
class Abs_Symbol(Symbol):
    """
    Symbol representing the absolute value of its argument.
    """
    def __init__(self,arg):
        # same shape and spatial dimension as the argument
        shape=arg.getShape()
        Symbol.__init__(self,shape=shape,dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "abs("+str(self.getArgument(0))+")"
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return abs(value)
    def _diff(self,arg):
        # |x|' = sign(x)*x'
        factor=sign(self.getArgument(0))
        return factor*self.getDifferentiatedArguments(arg)[0]
404 |
|
405 |
#========================================================= |
406 |
# some little helpers |
407 |
#========================================================= |
408 |
def _testForZero(arg): |
409 |
""" |
410 |
Returns True is arg is considered to be zero. |
411 |
""" |
412 |
if isinstance(arg,int): |
413 |
return not arg>0 |
414 |
elif isinstance(arg,float): |
415 |
return not arg>0. |
416 |
elif isinstance(arg,numarray.NumArray): |
417 |
a=abs(arg) |
418 |
while isinstance(a,numarray.NumArray): a=numarray.sometrue(a) |
419 |
return not a>0 |
420 |
else: |
421 |
return False |
422 |
|
423 |
def _extractDim(args): |
424 |
dim=None |
425 |
for a in args: |
426 |
if hasattr(a,"getDim"): |
427 |
d=a.getDim() |
428 |
if dim==None: |
429 |
dim=d |
430 |
else: |
431 |
if dim!=d: raise ValueError,"inconsistent spatial dimension of arguments" |
432 |
if dim==None: |
433 |
raise ValueError,"cannot recover spatial dimension" |
434 |
return dim |
435 |
|
436 |
def _identifyShape(arg): |
437 |
""" |
438 |
Identifies the shape of arg. |
439 |
""" |
440 |
if hasattr(arg,"getShape"): |
441 |
arg_shape=arg.getShape() |
442 |
elif hasattr(arg,"shape"): |
443 |
s=arg.shape |
444 |
if callable(s): |
445 |
arg_shape=s() |
446 |
else: |
447 |
arg_shape=s |
448 |
else: |
449 |
arg_shape=() |
450 |
return arg_shape |
451 |
|
452 |
def _extractShape(args):
    """
    Extracts the common shape of the list of arguments args.

    Raises ValueError if two arguments have different shapes or if no shape
    can be recovered (e.g. args is empty).
    """
    shape=None
    for a in args:
        s=_identifyShape(a)
        if shape is None:
            shape=s
        elif shape!=s:
            raise ValueError("inconsistent shape")
    if shape is None:
        raise ValueError("cannot recover shape")
    return shape
464 |
|
465 |
def _matchShape(args,shape=None):
    """
    Returns the list of arguments args as objects which all have the
    specified shape.

    If shape is not given, the maximum of the argument shapes is used.
    Tuples compare lexicographically, so when scalars are mixed with a
    non-scalar this is the non-scalar shape. Only scalar arguments (shape ())
    are adapted, via outer(arg,numarray.ones(shape)); any other mismatch
    raises ValueError.
    """
    shapes=[_identifyShape(a) for a in args]
    if shape is None:
        shape=max(shapes)
    matched=[]
    for a,s in zip(args,shapes):
        if s==shape:
            matched.append(a)
        elif len(shape)>0 and len(s)==0:
            # broadcast a scalar argument to the target shape
            matched.append(outer(a,numarray.ones(shape)))
        else:
            raise ValueError("cannot adopt shape of %s to %s"%(str(a),str(shape)))
    return matched
491 |
|
492 |
class Exp_Symbol(Symbol):
    """
    Symbol representing the exponential function of its argument.
    """
    def __init__(self,arg):
        # same shape and spatial dimension as the argument
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "exp(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return exp(self.getEvaluatedArguments(argval)[0])
    def _diff(self,arg):
        # exp(x)' = exp(x)*x'
        return self*self.getDifferentiatedArguments(arg)[0]
505 |
|
506 |
class Sqrt_Symbol(Symbol):
    """
    Symbol representing square root of argument.
    """
    def __init__(self,arg):
        Symbol.__init__(self,shape=arg.getShape(),dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "sqrt(%s)"%str(self.getArgument(0))
    def eval(self,argval):
        return sqrt(self.getEvaluatedArguments(argval)[0])
    def _diff(self,arg):
        # sqrt(x)' = x'/(2*sqrt(x)). The factor was previously -0.5 (wrong
        # sign). It is written as a power of self, which only uses the
        # forward ** and * operators of Symbol.
        return 0.5*self**(-1.)*self.getDifferentiatedArguments(arg)[0]
518 |
|
519 |
class Log_Symbol(Symbol):
    """
    Symbol representing the logarithm of its argument.
    """
    def __init__(self,arg):
        shape=arg.getShape()
        Symbol.__init__(self,shape=shape,dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "log("+str(self.getArgument(0))+")"
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return log(value)
    def _diff(self,arg):
        # derivative is x'/x
        dx=self.getDifferentiatedArguments(arg)[0]
        return dx/self.getArgument(0)
531 |
|
532 |
class Ln_Symbol(Symbol):
    """
    Symbol representing the natural logarithm of its argument.
    """
    def __init__(self,arg):
        shape=arg.getShape()
        Symbol.__init__(self,shape=shape,dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "ln("+str(self.getArgument(0))+")"
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return ln(value)
    def _diff(self,arg):
        # ln(x)' = x'/x
        dx=self.getDifferentiatedArguments(arg)[0]
        return dx/self.getArgument(0)
544 |
|
545 |
class Sin_Symbol(Symbol):
    """
    Symbol representing the sine of its argument.
    """
    def __init__(self,arg):
        shape=arg.getShape()
        Symbol.__init__(self,shape=shape,dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "sin("+str(self.getArgument(0))+")"
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return sin(value)
    def _diff(self,arg):
        # sin(x)' = cos(x)*x'
        factor=cos(self.getArgument(0))
        return factor*self.getDifferentiatedArguments(arg)[0]
557 |
|
558 |
class Cos_Symbol(Symbol):
    """
    Symbol representing the cosine of its argument.
    """
    def __init__(self,arg):
        shape=arg.getShape()
        Symbol.__init__(self,shape=shape,dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "cos("+str(self.getArgument(0))+")"
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return cos(value)
    def _diff(self,arg):
        # cos(x)' = -sin(x)*x'
        factor=-sin(self.getArgument(0))
        return factor*self.getDifferentiatedArguments(arg)[0]
570 |
|
571 |
class Tan_Symbol(Symbol):
    """
    Symbol representing the tangent of its argument.
    """
    def __init__(self,arg):
        shape=arg.getShape()
        Symbol.__init__(self,shape=shape,dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "tan("+str(self.getArgument(0))+")"
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return tan(value)
    def _diff(self,arg):
        # tan(x)' = x'/cos(x)**2
        c=cos(self.getArgument(0))
        factor=1./(c*c)
        return factor*self.getDifferentiatedArguments(arg)[0]
584 |
|
585 |
class Sign_Symbol(Symbol):
    """
    Symbol representing the sign of its argument.

    No _diff is defined here, so differentiation falls back to
    Symbol._diff, which returns 0.
    """
    def __init__(self,arg):
        shape=arg.getShape()
        Symbol.__init__(self,shape=shape,dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "sign("+str(self.getArgument(0))+")"
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return sign(value)
595 |
|
596 |
class Max_Symbol(Symbol):
    """
    Symbol representing the maximum value of its argument.

    No _diff is defined here, so differentiation falls back to
    Symbol._diff, which returns 0.
    """
    # NOTE(review): the shape is copied from the argument; if maxval reduces
    # over components the result shape may rather be () -- confirm. A zero
    # derivative is also not generally correct for a maximum.
    def __init__(self,arg):
        shape=arg.getShape()
        Symbol.__init__(self,shape=shape,dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "maxval("+str(self.getArgument(0))+")"
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return maxval(value)
606 |
|
607 |
class Min_Symbol(Symbol):
    """
    Symbol representing the minimum value of its argument.

    No _diff is defined here, so differentiation falls back to
    Symbol._diff, which returns 0.
    """
    # NOTE(review): the shape is copied from the argument; if minval reduces
    # over components the result shape may rather be () -- confirm. A zero
    # derivative is also not generally correct for a minimum.
    def __init__(self,arg):
        shape=arg.getShape()
        Symbol.__init__(self,shape=shape,dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "minval("+str(self.getArgument(0))+")"
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return minval(value)
617 |
|
618 |
class WherePositive_Symbol(Symbol):
    """
    Symbol representing the wherePositive function of its argument.

    No _diff is defined here, so differentiation falls back to
    Symbol._diff, which returns 0.
    """
    def __init__(self,arg):
        shape=arg.getShape()
        Symbol.__init__(self,shape=shape,dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "wherePositive("+str(self.getArgument(0))+")"
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return wherePositive(value)
628 |
|
629 |
class WhereNegative_Symbol(Symbol):
    """
    Symbol representing the whereNegative function of its argument.

    No _diff is defined here, so differentiation falls back to
    Symbol._diff, which returns 0.
    """
    def __init__(self,arg):
        shape=arg.getShape()
        Symbol.__init__(self,shape=shape,dim=arg.getDim(),args=[arg])
    def __str__(self):
        return "whereNegative("+str(self.getArgument(0))+")"
    def eval(self,argval):
        value=self.getEvaluatedArguments(argval)[0]
        return whereNegative(value)
639 |
|
640 |
class Outer_Symbol(Symbol):
    """
    Symbol representing the outer product of its two arguments.

    Its shape is the concatenation of the shapes of arg0 and arg1.
    """
    def __init__(self,arg0,arg1):
        operands=[arg0,arg1]
        shape=tuple(_identifyShape(arg0))+tuple(_identifyShape(arg1))
        Symbol.__init__(self,shape=shape,dim=_extractDim(operands),args=operands)
    def __str__(self):
        left=str(self.getArgument(0))
        right=str(self.getArgument(1))
        return "outer(%s,%s)"%(left,right)
    def eval(self,argval):
        x,y=self.getEvaluatedArguments(argval)
        return outer(x,y)
    def _diff(self,arg):
        # product rule applied to the outer product
        dx,dy=self.getDifferentiatedArguments(arg)
        return outer(dx,self.getArgument(1))+outer(self.getArgument(0),dy)
656 |
|
657 |
class Interpolated_Symbol(Symbol):
    """
    Symbol representing the interpolation of the argument onto where.

    The symbol keeps the shape and the spatial dimension of the argument.
    """
    def __init__(self,arg,where):
        # _identifyShape(arg) instead of _extractShape(arg): the latter
        # iterates over its argument, which for a single symbol is wrong
        Symbol.__init__(self,shape=_identifyShape(arg),dim=_extractDim([arg]),args=[arg,where])
    def __str__(self):
        return "interpolated(%s)"%(str(self.getArgument(0)))
    def eval(self,argval):
        a=self.getEvaluatedArguments(argval)
        # previously called integrate, contradicting the class name and
        # __str__; where is passed positionally
        return interpolate(a[0],self.getArgument(1))
    def _diff(self,arg):
        # interpolation is linear, so it commutes with differentiation
        a=self.getDifferentiatedArguments(arg)
        return interpolate(a[0],self.getArgument(1))
671 |
|
672 |
class Grad_Symbol(Symbol):
    """
    Symbol representing the gradient of the argument.

    The shape is the shape of the argument extended by the spatial
    dimension d, i.e. shape(arg)+(d,).
    """
    def __init__(self,arg,where=None):
        d=_extractDim([arg])
        # list.append returns None, so the former
        # tuple(list(...).append(d)) always raised TypeError; also the shape
        # of arg itself is needed, not that of the list [arg]
        s=tuple(list(_identifyShape(arg))+[d])
        Symbol.__init__(self,shape=s,dim=d,args=[arg,where])
    def __str__(self):
        return "grad(%s)"%(str(self.getArgument(0)))
    def eval(self,argval):
        a=self.getEvaluatedArguments(argval)
        return grad(a[0],where=self.getArgument(1))
    def _diff(self,arg):
        # the gradient is linear, so it commutes with differentiation
        a=self.getDifferentiatedArguments(arg)
        return grad(a[0],where=self.getArgument(1))
688 |
|
689 |
class Integral_Symbol(Float_Symbol):
    """
    Symbol representing the integral of the argument.

    The symbol has the shape of its argument and spatial dimension 0.
    """
    def __init__(self,arg,where=None):
        # initialise directly as a dim=0 symbol holding [arg,where]; the
        # shape is that of arg (previously that of the list [arg], i.e. ())
        Symbol.__init__(self,dim=0,shape=_identifyShape(arg),args=[arg,where])
    def __str__(self):
        return "integral(%s)"%(str(self.getArgument(0)))
    def eval(self,argval):
        a=self.getEvaluatedArguments(argval)
        return integrate(a[0],where=self.getArgument(1))
    def _diff(self,arg):
        # integration is linear, so it commutes with differentiation
        a=self.getDifferentiatedArguments(arg)
        return integrate(a[0],where=self.getArgument(1))
703 |
|
704 |
# ============================================ |
705 |
# testing |
706 |
# ============================================ |
707 |
|
708 |
if __name__=="__main__":
    # Smoke test: prints a number of expressions together with their
    # derivatives.
    # NOTE(review): exp, sqrt, log, sin, cos, tan, sign, wherePositive,
    # whereNegative, maxval, minval, grad, transpose, trace and kronecker are
    # neither defined nor imported in this module (presumably they come from
    # escript's util module) -- confirm they are in scope before running.

    def _report(expr,arg):
        """
        Prints expr followed by its derivative with respect to arg.
        """
        print("%s %s"%(str(expr),str(expr.diff(arg))))

    # replace by ScalarSymbol(dim=2,name=...) to exercise the scalar case
    v=VectorSymbol(2,"v")
    u=VectorSymbol(2,"u")

    for expr in [u+5,5+u,u+v,v+u,
                 u*5,5*u,u*v,v*u,
                 u-5,5-u,u-v,v-u,
                 u/5,5/u,u/v,v/u,
                 u**5,5**u,u**v,v**u]:
        _report(expr,u)

    for expr in [exp(u),sqrt(u),log(u),sin(u),cos(u),tan(u),sign(u),abs(u),
                 wherePositive(u),whereNegative(u),maxval(u),minval(u)]:
        _report(expr,u)

    g=grad(u)
    # diff is not defined in this module; use the Symbol.diff method
    print(str((5*g).diff(g)))
    # g is built from a 2-D vector, so the Kronecker delta has to match its
    # spatial dimension (was kronecker(3) and the result was discarded)
    print(str(4*(g+transpose(g))/2+6*trace(g)*kronecker(g.getDim())))
755 |
|
756 |
# |
757 |
# $Log$ |
758 |
# Revision 1.2 2005/09/15 03:44:19 jgs |
759 |
# Merge of development branch dev-02 back to main trunk on 2005-09-15 |
760 |
# |
761 |
# Revision 1.1.2.1 2005/09/07 10:32:05 gross |
762 |
# Symbols removed from util and put into symmbols.py. |
763 |
# |
764 |
# |