/[escript]/trunk/escript/py_src/generatediff
ViewVC logotype

Contents of /trunk/escript/py_src/generatediff

Parent Directory | Revision Log


Revision 433 - (show annotations)
Tue Jan 17 23:54:38 2006 UTC (13 years, 10 months ago) by gross
Original Path: trunk/escript/py_src/generateutil
File size: 102994 byte(s)
new function inverse and tests added
#!/usr/bin/python
# $Id$

"""
Generates parts of the util.py and test_util.py scripts.

Generated unittest methods are collected as source text in the string
buffers defined below and printed between test_header and test_tail.
"""

# Preamble of the generated test module: imports and the Test_util2 fixture,
# whose setUp builds Rectangle(11,11,2) and uses FunctionOnBoundary on it.
test_header="".join([
    "import unittest\n",
    "import numarray\n",
    "from esys.escript import *\n",
    "from esys.finley import Rectangle\n",
    "class Test_util2(unittest.TestCase):\n",
    " RES_TOL=1.e-7\n",
    " def setUp(self):\n",
    " self.__dom =Rectangle(11,11,2)\n",
    " self.functionspace = FunctionOnBoundary(self.__dom)\n",
])

# Epilogue of the generated test module: build and run the Test_util2 suite.
test_tail="".join([
    "suite = unittest.TestSuite()\n",
    "suite.addTest(unittest.makeSuite(Test_util2))\n",
    "unittest.TextTestRunner(verbosity=2).run(suite)\n",
])

# all argument kinds and the test shapes of rank 0 to 4
case_set=["float","array","constData","taggedData","expandedData","Symbol"]
shape_set=[(),(2,),(4,5),(6,2,2),(3,2,3,4)]

# output buffers for generated source text (all start empty)
t_prog=t_prog_with_tags=t_prog_failing=u_prog=""
29
def wherepos(arg):
    """Scalar step function: return 1. if arg is strictly positive, else 0."""
    if not arg>0.:
        return 0.
    return 1.
35
36
class OPERATOR:
    """Record describing a unary operator for the test generator.

    nickname, name -- identifiers of the operator
    rng            -- two-element [low, high] range; defaults to a fresh
                      [-1000., 1000] per instance.  NOTE(review): presumably
                      the sampling range for test data -- confirm with callers.
    test_expr      -- expression template with the placeholder %a1%
    math_expr      -- defaults to test_expr with %a1% replaced by "arg"
    numarray_expr, symbol_expr, diff -- stored unchanged
    """
    def __init__(self,nickname,rng=None,test_expr="",math_expr=None,
                 numarray_expr="",symbol_expr=None,diff=None,name=""):
        if rng is None:
            # a new list per instance: a shared list default would let a
            # mutation of one instance's range leak into every other instance
            rng=[-1000.,1000]
        self.nickname=nickname
        self.rng=rng
        self.test_expr=test_expr
        if math_expr is None:
            self.math_expr=test_expr.replace("%a1%","arg")
        else:
            self.math_expr=math_expr
        self.numarray_expr=numarray_expr
        self.symbol_expr=symbol_expr
        self.diff=diff
        self.name=name
51
52 import random
53 import numarray
54 import math
# NOTE(review): finc is not referenced in the visible part of this script;
# presumably a small increment for finite-difference checks -- confirm.
finc=1.0e-6
56
def getResultCaseForBin(case0,case1):
    """Return the argument case produced by a binary operation.

    The cases are promoted along
        float < array < constData < taggedData < expandedData < Symbol
    and the result is whichever of case0 and case1 comes later in that order.

    Raises ValueError for an unknown case name; case0 is checked first.
    """
    promotion=["float","array","constData","taggedData","expandedData","Symbol"]
    if case0 not in promotion:
        raise ValueError("unknown case0=%s"%case0)
    if case1 not in promotion:
        raise ValueError("unknown case1=%s"%case1)
    return promotion[max(promotion.index(case0),promotion.index(case1))]
151
152
def makeArray(shape,rng):
    """Return random test data of the given shape.

    Every entry is (rng[1]-rng[0])*random.random()+rng[0].  For shape () a
    Python float is returned; otherwise a numarray Float64 array whose entries
    are drawn in row-major order.  Ranks above 4 raise SystemError.
    """
    width=rng[1]-rng[0]
    rank=len(shape)
    if rank>4:
        raise SystemError("rank is restricted to 4")
    if rank==0:
        return width*random.random()+rng[0]
    count=1
    for extent in shape:
        count*=extent
    flat=numarray.zeros((count,),numarray.Float64)
    for k in range(count):
        flat[k]=width*random.random()+rng[0]
    # filling a flat vector and reshaping draws the random numbers in the
    # same (row-major) order as nested loops over i0, i1, ...
    return numarray.reshape(flat,shape)
179
180
# makeResult: apply the unary expression template test_expr entry by entry.
#   val       -- Python float or numarray array of rank 0..4
#   test_expr -- Python expression in which %a1% stands for one entry
# A float or rank-0 val gives the value of the expression itself; ranks 1..4
# give a new Float64 array of val.shape.  Ranks above 4 raise SystemError.
# The template is eval()'d in this function's scope after %a1% is replaced by
# "val" or "val[i0,...]", so the local names val and i0..i3 are part of the
# contract -- do not rename them.
181 def makeResult(val,test_expr):
182 if isinstance(val,float):
183 out=eval(test_expr.replace("%a1%","val"))
184 elif len(val.shape)==0:
185 out=eval(test_expr.replace("%a1%","val"))
# ranks 1..4: evaluate the template once per entry, in row-major order
186 elif len(val.shape)==1:
187 out=numarray.zeros(val.shape,numarray.Float64)
188 for i0 in range(val.shape[0]):
189 out[i0]=eval(test_expr.replace("%a1%","val[i0]"))
190 elif len(val.shape)==2:
191 out=numarray.zeros(val.shape,numarray.Float64)
192 for i0 in range(val.shape[0]):
193 for i1 in range(val.shape[1]):
194 out[i0,i1]=eval(test_expr.replace("%a1%","val[i0,i1]"))
195 elif len(val.shape)==3:
196 out=numarray.zeros(val.shape,numarray.Float64)
197 for i0 in range(val.shape[0]):
198 for i1 in range(val.shape[1]):
199 for i2 in range(val.shape[2]):
200 out[i0,i1,i2]=eval(test_expr.replace("%a1%","val[i0,i1,i2]"))
201 elif len(val.shape)==4:
202 out=numarray.zeros(val.shape,numarray.Float64)
203 for i0 in range(val.shape[0]):
204 for i1 in range(val.shape[1]):
205 for i2 in range(val.shape[2]):
206 for i3 in range(val.shape[3]):
207 out[i0,i1,i2,i3]=eval(test_expr.replace("%a1%","val[i0,i1,i2,i3]"))
208 else:
209 raise SystemError,"rank is restricted to 4"
210 return out
211
def _indexTuples(shape):
    """Yield every index tuple of an array with the given shape, in row-major order."""
    if len(shape)==0:
        yield ()
    else:
        for first in range(shape[0]):
            for rest in _indexTuples(shape[1:]):
                yield (first,)+rest

def makeResult2(val0,val1,test_expr):
    """Evaluate a binary expression template entry by entry.

    test_expr is a Python expression in which %a1% stands for an entry of
    val0 and %a2% for an entry of val1.  Each argument is a Python float or
    an array of rank 0..4.  With r0=rank(val0), r1=rank(val1), r=max(r0,r1),
    entry [k0,...,k(r-1)] of the result is test_expr applied to
    val0[k0,...,k(r0-1)] and val1[k0,...,k(r1-1)], i.e. the lower-rank
    argument is reused along the trailing axes of the other one.  The
    result has the shape of the higher-rank argument (val0.shape when the
    ranks are equal); two scalar arguments give the value of test_expr itself.

    The template is evaluated with eval(); it may use val0, val1, the index
    variables i0, i1, ... and the module globals (math, numarray, ...).

    Raises SystemError if val0 (message "rank is restricted to 4") or val1
    (message "rank of val1 is restricted to 4") has rank larger than 4.
    """
    if isinstance(val0,float):
        rank0=0
    else:
        rank0=len(val0.shape)
    if rank0>4:
        raise SystemError("rank is restricted to 4")
    if isinstance(val1,float):
        rank1=0
    else:
        rank1=len(val1.shape)
    if rank1>4:
        raise SystemError("rank of val1 is restricted to 4")

    names=["i%d"%k for k in range(max(rank0,rank1))]
    arg0_text="val0"
    if rank0>0:
        arg0_text="val0[%s]"%",".join(names[:rank0])
    arg1_text="val1"
    if rank1>0:
        arg1_text="val1[%s]"%",".join(names[:rank1])
    expr=test_expr.replace("%a1%",arg0_text).replace("%a2%",arg1_text)

    if not names:
        # both arguments are scalars: the expression value is the result
        return eval(expr)

    if rank1>rank0:
        shape=tuple(val1.shape[rank0:])
        if rank0>0:
            shape=tuple(val0.shape)+shape
    else:
        shape=tuple(val0.shape)

    code=compile(expr,"<makeResult2>","eval")
    env={"val0":val0,"val1":val1}
    out=numarray.zeros(shape,numarray.Float64)
    for idx in _indexTuples(shape):
        for k in range(len(idx)):
            env[names[k]]=idx[k]
        value=eval(code,globals(),env)
        if len(idx)==1:
            out[idx[0]]=value
        else:
            out[idx]=value
    return out
428
429
def mkText(case,name,a,a1=None,use_tagging_for_expanded_data=False):
    """Return the source lines that define variable `name` in a generated test.

    case -- "float", "array", "constData", "taggedData", "expandedData" or
            "Symbol"; any other value gives ""
    a    -- the value: a Python float or a numarray array.  Floats and rank-0
            arrays are inserted with %s, other arrays via a.tolist()
    a1   -- second value: the tag-1 value (setTaggedValue) for taggedData and
            tagged expandedData, or the term weighted by (1.-msk_<name>) for
            mask-built expandedData; ignored by the other cases
    use_tagging_for_expanded_data -- build expandedData from tagged Data plus
            expand() instead of blending a and a1 with a whereNegative mask
    Symbol arguments only record the shape of a.
    """
    if case not in ("float","array","constData","taggedData","expandedData","Symbol"):
        return ""
    is_float=isinstance(a,float)
    scalar=is_float or a.rank==0

    if case=="Symbol":
        if scalar:
            return " %s=Symbol(shape=())\n"%(name)
        return " %s=Symbol(shape=%s)\n"%(name,str(a.shape))
    if case=="float":
        if scalar:
            return " %s=%s\n"%(name,a)
        return " %s=numarray.array(%s)\n"%(name,a.tolist())
    if case=="array":
        if scalar:
            return " %s=numarray.array(%s)\n"%(name,a)
        return " %s=numarray.array(%s)\n"%(name,a.tolist())
    if case=="constData":
        if scalar:
            return " %s=Data(%s,self.functionspace)\n"%(name,a)
        return " %s=Data(numarray.array(%s),self.functionspace)\n"%(name,a.tolist())

    if case=="taggedData" or use_tagging_for_expanded_data:
        # taggedData, or expandedData built from tagged data and expanded
        if scalar:
            lines=[" %s=Data(%s,self.functionspace)\n"%(name,a),
                   " %s.setTaggedValue(1,%s)\n"%(name,a1)]
        else:
            lines=[" %s=Data(numarray.array(%s),self.functionspace)\n"%(name,a.tolist()),
                   " %s.setTaggedValue(1,numarray.array(%s))\n"%(name,a1.tolist())]
        if case=="expandedData":
            lines.append(" %s.expand()\n"%name)
        return "".join(lines)

    # expandedData: blend a and a1 with a mask on the first coordinate
    lines=[" msk_%s=whereNegative(self.functionspace.getX()[0]-0.5)\n"%name]
    if is_float:
        lines.append(" %s=msk_%s*(%s)+(1.-msk_%s)*(%s)\n"%(name,name,a,name,a1))
    else:
        if scalar:
            first,second=a,a1
        else:
            first,second=a.tolist(),a1.tolist()
        lines.append(" %s=msk_%s*numarray.array(%s)+(1.-msk_%s)*numarray.array(%s)\n"
                     %(name,name,first,name,second))
    return "".join(lines)
492
def mkTypeAndShapeTest(case,sh,argstr):
    """Return unittest source lines asserting the type and shape of `argstr`.

    A "float" result only gets an isinstance(...,float) check.  "array",
    Data ("constData", "taggedData", "expandedData") and "Symbol" results
    also get a shape check against sh, via .shape for arrays and
    .getShape() otherwise.  Any other case gives "".
    """
    if case=="float":
        return " self.failUnless(isinstance(%s,float),\"wrong type of result.\")\n"%argstr
    if case=="array":
        type_name,shape_expr="numarray.NumArray",argstr+".shape"
    elif case in ["constData","taggedData","expandedData"]:
        type_name,shape_expr="Data",argstr+".getShape()"
    elif case=="Symbol":
        type_name,shape_expr="Symbol",argstr+".getShape()"
    else:
        return ""
    type_line=" self.failUnless(isinstance(%s,%s),\"wrong type of result.\")\n"%(argstr,type_name)
    shape_line=" self.failUnlessEqual(%s,%s,\"wrong shape of result.\")\n"%(shape_expr,str(sh))
    return type_line+shape_line
507
def mkCode(txt,args=(),intend=""):
    """Turn an expression or a code snippet into indented function-body source.

    txt    -- single-line text becomes "return <txt>"; multi-line text is
              copied line by line, each line (including the last) followed
              by a newline
    args   -- replacement strings: the k-th entry (counting from 1) replaces
              every occurrence of the placeholder %ak% (%a1%, %a2%, ...)
    intend -- indentation prefix put in front of every emitted line
    """
    lines=txt.split("\n")
    if len(lines)>1:
        out="".join([intend+l+"\n" for l in lines])
    else:
        out="%sreturn %s\n"%(intend,txt)
    # the placeholder number must advance with each argument, otherwise only
    # %a1% is ever substituted and the remaining arguments are ignored
    c=1
    for r in args:
        out=out.replace("%%a%s%%"%c,r)
        c+=1
    return out
520
def innerTEST(arg0,arg1):
    """Reference inner product.

    For a float arg0 the product is wrapped as numarray.array(arg0*arg1);
    otherwise the entries of the elementwise product arg0*arg1 are summed.
    """
    if not isinstance(arg0,float):
        return (arg0*arg1).sum()
    return numarray.array(arg0*arg1)
527
def outerTEST(arg0,arg1):
    """Reference outer product.

    If either argument is a float the plain product is returned as
    numarray.array(arg0*arg1); otherwise numarray.outerproduct is resized to
    arg0.shape+arg1.shape.
    """
    if isinstance(arg0,float) or isinstance(arg1,float):
        return numarray.array(arg0*arg1)
    full_shape=arg0.shape+arg1.shape
    return numarray.outerproduct(arg0,arg1).resize(full_shape)
536
def tensorProductTest(arg0,arg1,sh_s):
    """Reference generalized tensor product.

    Contracts the trailing len(sh_s) axes of arg0 with the leading len(sh_s)
    axes of arg1; sh_s is the shape of those contracted axes.  If either
    argument is a float, numarray.array(arg0*arg1) is returned.  With an
    empty sh_s the outer product, resized to arg0.shape+arg1.shape, is
    returned.
    """
    if isinstance(arg0,float) or isinstance(arg1,float):
        return numarray.array(arg0*arg1)
    if len(sh_s)==0:
        return numarray.outerproduct(arg0,arg1).resize(arg0.shape+arg1.shape)

    n_sum=len(sh_s)
    lead_shape=arg0.shape[:arg0.rank-n_sum]
    tail_shape=arg1.shape[n_sum:]
    size_sum=size_lead=size_tail=1
    for d in sh_s:
        size_sum*=d
    for d in lead_shape:
        size_lead*=d
    for d in tail_shape:
        size_tail*=d
    # view the outer product as (lead, sum, sum, tail); the contraction is
    # the sum over the diagonal of the two middle axes
    outer=numarray.outerproduct(arg0,arg1).resize((size_lead,size_sum,size_sum,size_tail))
    acc=numarray.zeros((size_lead,size_tail),numarray.Float)
    for p in range(size_lead):
        for q in range(size_tail):
            for k in range(size_sum):
                acc[p,q]+=outer[p,k,k,q]
    return acc.resize(lead_shape+tail_shape)
559
def testMatrixMult(arg0,arg1,sh_s):
    """Reference matrix product numarray.matrixmultiply(arg0,arg1).

    sh_s is not used.  NOTE(review): presumably kept so the signature matches
    the other product references such as tensorProductTest.
    """
    product=numarray.matrixmultiply(arg0,arg1)
    return product
562
563
564 def testTensorMult(arg0,arg1,sh_s):
565 if len(arg0)==2:
566 return numarray.matrixmultiply(arg0,arg1)
567 else:
568 if arg1.rank==4:
569 out=numarray.zeros((arg0.shape[0],arg0.shape[1],arg1.shape[2],arg1.shape[3]),numarray.Float)
570 for i0 in range(arg0.shape[0]):
571 for i1 in range(arg0.shape[1]):
572 for i2 in range(arg0.shape[2]):
573 for i3 in range(arg0.shape[3]):
574 for j2 in range(arg1.shape[2]):
575 for j3 in range(arg1.shape[3]):
576 out[i0,i1,j2,j3]+=arg0[i0,i1,i2,i3]*arg1[i2,i3,j2,j3]
577 elif arg1.rank==3:
578 out=numarray.zeros((arg0.shape[0],arg0.shape[1],arg1.shape[2]),numarray.Float)
579 for i0 in range(arg0.shape[0]):
580 for i1 in range(arg0.shape[1]):
581 for i2 in range(arg0.shape[2]):
582 for i3 in range(arg0.shape[3]):
583 for j2 in range(arg1.shape[2]):
584 out[i0,i1,j2]+=arg0[i0,i1,i2,i3]*arg1[i2,i3,j2]
585 elif arg1.rank==2:
586 out=numarray.zeros((arg0.shape[0],arg0.shape[1]),numarray.Float)
587 for i0 in range(arg0.shape[0]):
588 for i1 in range(arg0.shape[1]):
589 for i2 in range(arg0.shape[2]):
590 for i3 in range(arg0.shape[3]):
591 out[i0,i1]+=arg0[i0,i1,i2,i3]*arg1[i2,i3]
592 return out
593
# testReduce: reference reduction of arg0 over all of its entries.
#   arg0      -- Python float or numarray array of rank 0..4
#   init_val  -- starting value of the accumulator `out`
#   test_expr -- template eval()'d once per entry, with %a1% replaced by
#                "arg0" (float / rank-0) or "arg0[i0,...]"; its value becomes
#                the new `out`.  NOTE(review): presumably the template refers
#                to `out` itself (e.g. max(out,%a1%)) -- the templates come
#                from callers not visible here.
#   post_expr -- eval()'d after the loop; its value is returned
# Both templates are eval()'d in this function's scope, so they may use the
# local names out, arg0, init_val and i0..i3 -- do not rename those locals.
# For rank > 4 no branch matches and post_expr sees out==init_val.
594 def testReduce(arg0,init_val,test_expr,post_expr):
595 out=init_val
596 if isinstance(arg0,float):
597 out=eval(test_expr.replace("%a1%","arg0"))
598 elif arg0.rank==0:
599 out=eval(test_expr.replace("%a1%","arg0"))
600 elif arg0.rank==1:
601 for i0 in range(arg0.shape[0]):
602 out=eval(test_expr.replace("%a1%","arg0[i0]"))
603 elif arg0.rank==2:
604 for i0 in range(arg0.shape[0]):
605 for i1 in range(arg0.shape[1]):
606 out=eval(test_expr.replace("%a1%","arg0[i0,i1]"))
607 elif arg0.rank==3:
608 for i0 in range(arg0.shape[0]):
609 for i1 in range(arg0.shape[1]):
610 for i2 in range(arg0.shape[2]):
611 out=eval(test_expr.replace("%a1%","arg0[i0,i1,i2]"))
612 elif arg0.rank==4:
613 for i0 in range(arg0.shape[0]):
614 for i1 in range(arg0.shape[1]):
615 for i2 in range(arg0.shape[2]):
616 for i3 in range(arg0.shape[3]):
617 out=eval(test_expr.replace("%a1%","arg0[i0,i1,i2,i3]"))
618 return eval(post_expr)
619
def clipTEST(arg0,mn,mx):
    """Reference clip: every entry v of arg0 becomes max(min(v,mx),mn).

    arg0   -- Python float or numarray array of rank 0..4
    mn, mx -- lower and upper bound
    Returns a float for a float arg0, a rank-0 array for a rank-0 array and a
    new Float64 array of arg0.shape for ranks 1..4.
    Raises SystemError for rank > 4.
    """
    if isinstance(arg0,float):
        return max(min(arg0,mx),mn)
    if arg0.rank>4:
        # ranks above 4 are not supported (same limit as makeArray)
        raise SystemError("rank is restricted to 4")
    if arg0.rank==0:
        # no axis to loop over: clip the single value (relies on comparing
        # a rank-0 array with a float)
        return numarray.array(max(min(arg0,mx),mn))
    out=numarray.zeros(arg0.shape,numarray.Float64)
    if arg0.rank==1:
        for i0 in range(arg0.shape[0]):
            out[i0]=max(min(arg0[i0],mx),mn)
    elif arg0.rank==2:
        for i0 in range(arg0.shape[0]):
            for i1 in range(arg0.shape[1]):
                out[i0,i1]=max(min(arg0[i0,i1],mx),mn)
    elif arg0.rank==3:
        for i0 in range(arg0.shape[0]):
            for i1 in range(arg0.shape[1]):
                for i2 in range(arg0.shape[2]):
                    out[i0,i1,i2]=max(min(arg0[i0,i1,i2],mx),mn)
    else:
        for i0 in range(arg0.shape[0]):
            for i1 in range(arg0.shape[1]):
                for i2 in range(arg0.shape[2]):
                    for i3 in range(arg0.shape[3]):
                        out[i0,i1,i2,i3]=max(min(arg0[i0,i1,i2,i3],mx),mn)
    return out
def _pickTEST(arg0,arg1,take_arg1):
    """Reference for an entrywise choice between arg0 and arg1.

    take_arg1(a,b) is called for each pair of corresponding entries a of arg0
    and b of arg1; when it is true the entry of arg1 is used, otherwise the
    entry of arg0.
      * two floats give a float;
      * a float paired with an array is first broadcast to the array's shape;
      * two rank-0 arrays give the selected argument itself;
      * arrays of rank 1..4 (arg1 with the shape of arg0) give a new Float64
        array of that shape.
    Raises SystemError for rank > 4.
    """
    if isinstance(arg0,float):
        if isinstance(arg1,float):
            if take_arg1(arg0,arg1):
                return arg1
            return arg0
        arg0=numarray.ones(arg1.shape)*arg0
    elif isinstance(arg1,float):
        arg1=numarray.ones(arg0.shape)*arg1
    if arg0.rank>4:
        raise SystemError("rank is restricted to 4")
    if arg0.rank==0:
        if take_arg1(arg0,arg1):
            return arg1
        return arg0
    size=1
    for extent in arg0.shape:
        size*=extent
    # walk both arguments as flat row-major vectors; this visits the same
    # entry pairs as nested loops over every axis
    flat0=numarray.reshape(arg0,(size,))
    flat1=numarray.reshape(arg1,(size,))
    out=numarray.zeros((size,),numarray.Float64)
    for k in range(size):
        if take_arg1(flat0[k],flat1[k]):
            out[k]=flat1[k]
        else:
            out[k]=flat0[k]
    return numarray.reshape(out,arg0.shape)

def minimumTEST(arg0,arg1):
    """Reference entrywise minimum of arg0 and arg1 (ties keep the arg0 entry)."""
    return _pickTEST(arg0,arg1,lambda a,b: a>b)

def maximumTEST(arg0,arg1):
    """Reference entrywise maximum of arg0 and arg1 (ties keep the arg0 entry)."""
    return _pickTEST(arg0,arg1,lambda a,b: a<b)
692 #=======================================================================================================
693 # inverse
694 #=======================================================================================================
# Generates test_inverse_<case>_dim<n> methods for n = 1, 2, 3 (no float case).
# Each test inverts a random n x n matrix with entries from [-1,1) plus 2 on
# the diagonal (plus 3 for the second value a1_0 used by taggedData and
# expandedData) and checks, with an absolute tolerance,
#   Lsup(matrixmult(inverse(A), A) - kronecker(n)) <= RES_TOL
# NOTE(review): the diagonal shift presumably keeps the matrices invertible;
# it does not guarantee diagonal dominance for n=3.
# The order of makeArray calls fixes the random values that get generated.
695 name="inverse"
696 for case0 in ["array","Symbol","constData","taggedData","expandedData"]:
697 for sh0 in [ (1,1), (2,2), (3,3)]:
698 text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
699 tname="test_%s_%s_dim%s"%(name,case0,sh0[0])
700 text+=" def %s(self):\n"%tname
701 a_0=makeArray(sh0,[-1.,1])
702 for i in range(sh0[0]): a_0[i,i]+=2
703 if case0 in ["taggedData", "expandedData"]:
704 a1_0=makeArray(sh0,[-1.,1])
705 for i in range(sh0[0]): a1_0[i,i]+=3
706 else:
707 a1_0=a_0
708
709 text+=mkText(case0,"arg",a_0,a1_0)
710 text+=" res=%s(arg)\n"%name
# Symbol arguments are substituted by the array s; the product is then
# formed with s instead of arg
711 if case0=="Symbol":
712 text+=mkText("array","s",a_0,a1_0)
713 text+=" sub=res.substitute({arg:s})\n"
714 res="sub"
715 ref="s"
716 else:
717 ref="arg"
718 res="res"
719 text+=mkTypeAndShapeTest(case0,sh0,"res")
720 text+=" self.failUnless(Lsup(matrixmult(%s,%s)-kronecker(%s))<=self.RES_TOL,\"wrong result\")\n"%(res,ref,sh0[0])
721
# tests with taggedData go to t_prog_with_tags, all others to t_prog
722 if case0 == "taggedData" :
723 t_prog_with_tags+=text
724 else:
725 t_prog+=text
726
# Only the tagged tests are printed (printing t_prog is commented out).
# 1/0 then raises ZeroDivisionError, which stops the script here, so the
# generator sections after this point are not executed.
727 print test_header
728 # print t_prog
729 print t_prog_with_tags
730 print test_tail
731 1/0
732
#=======================================================================================================
# trace
#=======================================================================================================
def traceTest(r,offset):
    """Reference trace of array r over the axis pair (offset, offset+1).

    The code assumes r.shape[offset+1]==r.shape[offset].  r is viewed as a
    (before, n, n, after) array and its diagonal in the middle two axes is
    summed.  The result has shape r.shape[:offset]+r.shape[offset+2:].
    """
    shape=r.shape
    n=shape[offset]
    before=1
    for d in shape[:offset]:
        before*=d
    after=1
    for d in shape[offset+2:]:
        after*=d
    blocks=numarray.reshape(r,(before,n,n,after))
    total=numarray.zeros([before,after],numarray.Float)
    for p in range(before):
        for q in range(after):
            for k in range(n):
                total[p,q]+=blocks[p,k,k,q]
    return total.resize(shape[:offset]+shape[offset+2:])
# Generates test_trace_<case>_rank<r>_offset<k> methods for every adjacent
# axis pair (k, k+1) of the shapes below.  sh_t copies the length of axis k
# onto axis k+1 so the traced pair is square; sh_r is the expected result
# shape with both axes removed.  References come from traceTest (bound to
# tt) and are compared relatively: Lsup(res-ref)<=RES_TOL*Lsup(ref).
748 name,tt="trace",traceTest
749 for case0 in ["array","Symbol","constData","taggedData","expandedData"]:
750 for sh0 in [ (4,5), (6,2,2),(3,2,3,4)]:
751 for offset in range(len(sh0)-1):
752 text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
753 tname="test_%s_%s_rank%s_offset%s"%(name,case0,len(sh0),offset)
754 text+=" def %s(self):\n"%tname
755 sh_t=list(sh0)
756 sh_t[offset+1]=sh_t[offset]
757 sh_t=tuple(sh_t)
758 sh_r=[]
759 for i in range(offset): sh_r.append(sh0[i])
760 for i in range(offset+2,len(sh0)): sh_r.append(sh0[i])
761 sh_r=tuple(sh_r)
762 a_0=makeArray(sh_t,[-1.,1])
763 if case0 in ["taggedData", "expandedData"]:
764 a1_0=makeArray(sh_t,[-1.,1])
765 else:
766 a1_0=a_0
767 r=tt(a_0,offset)
768 r1=tt(a1_0,offset)
769 text+=mkText(case0,"arg",a_0,a1_0)
770 text+=" res=%s(arg,%s)\n"%(name,offset)
# for Symbol arguments the result is substituted and compared with an array
771 if case0=="Symbol":
772 text+=mkText("array","s",a_0,a1_0)
773 text+=" sub=res.substitute({arg:s})\n"
774 res="sub"
775 text+=mkText("array","ref",r,r1)
776 else:
777 res="res"
778 text+=mkText(case0,"ref",r,r1)
779 text+=mkTypeAndShapeTest(case0,sh_r,"res")
780 text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
781
782 if case0 == "taggedData" :
783 t_prog_with_tags+=text
784 else:
785 t_prog+=text
786
# Prints the tagged tests only, then 1/0 raises ZeroDivisionError to stop
# the script (if execution ever reaches this point).
787 print test_header
788 # print t_prog
789 print t_prog_with_tags
790 print test_tail
791 1/0
792
793 #=======================================================================================================
794 # clip
795 #=======================================================================================================
# Generates test_clip_<case>_rank<r> methods: clip(arg,-0.3,0.5) is compared
# relatively with clipTEST(arg,-0.3,0.5).  Floats are only generated for
# rank 0 (the len(sh0)==0 condition).
796 oper_L=[["clip",clipTEST]]
797 for oper in oper_L:
798 for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
799 for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
800 if len(sh0)==0 or not case0=="float":
801 text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
802 tname="test_%s_%s_rank%s"%(oper[0],case0,len(sh0))
803 text+=" def %s(self):\n"%tname
804 a_0=makeArray(sh0,[-1.,1])
805 if case0 in ["taggedData", "expandedData"]:
806 a1_0=makeArray(sh0,[-1.,1])
807 else:
808 a1_0=a_0
809
810 r=oper[1](a_0,-0.3,0.5)
811 r1=oper[1](a1_0,-0.3,0.5)
812 text+=mkText(case0,"arg",a_0,a1_0)
813 text+=" res=%s(arg,-0.3,0.5)\n"%oper[0]
814 if case0=="Symbol":
815 text+=mkText("array","s",a_0,a1_0)
816 text+=" sub=res.substitute({arg:s})\n"
817 res="sub"
818 text+=mkText("array","ref",r,r1)
819 else:
820 res="res"
821 text+=mkText(case0,"ref",r,r1)
822 text+=mkTypeAndShapeTest(case0,sh0,"res")
823 text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
824
825 if case0 == "taggedData" :
826 t_prog_with_tags+=text
827 else:
828 t_prog+=text
829
# Prints the tagged tests only, then 1/0 raises ZeroDivisionError to stop
# the script (if execution ever reaches this point).
830 print test_header
831 # print t_prog
832 print t_prog_with_tags
833 print test_tail
834 1/0
835
836 #=======================================================================================================
837 # maximum, minimum, clipping
838 #=======================================================================================================
# Generates test_<op>_<case0>_rank<r0>_<case1>_rank<r1> methods for maximum
# and minimum.  Argument pairs are limited to equal shapes or a rank-0
# partner, and floats occur only with rank 0.  When either argument is
# taggedData, expandedData arguments are also built by tagging
# (use_tagging_for_expanded_data).  For a Symbol result every Symbol argument
# is substituted by its array value before the relative comparison.
# NOTE(review): this section needs a maximumTEST reference defined earlier
# in the file -- confirm it exists.
839 oper_L=[ ["maximum",maximumTEST],
840 ["minimum",minimumTEST]]
841 for oper in oper_L:
842 for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
843 for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
844 for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:
845 for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
846 if (len(sh0)==0 or not case0=="float") and (len(sh1)==0 or not case1=="float") \
847 and (sh0==sh1 or len(sh0)==0 or len(sh1)==0) :
848 use_tagging_for_expanded_data= case0=="taggedData" or case1=="taggedData"
849
850 text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
851 tname="test_%s_%s_rank%s_%s_rank%s"%(oper[0],case0,len(sh0),case1,len(sh1))
852 text+=" def %s(self):\n"%tname
853 a_0=makeArray(sh0,[-1.,1])
854 if case0 in ["taggedData", "expandedData"]:
855 a1_0=makeArray(sh0,[-1.,1])
856 else:
857 a1_0=a_0
858
859 a_1=makeArray(sh1,[-1.,1])
860 if case1 in ["taggedData", "expandedData"]:
861 a1_1=makeArray(sh1,[-1.,1])
862 else:
863 a1_1=a_1
864 r=oper[1](a_0,a_1)
865 r1=oper[1](a1_0,a1_1)
866 text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
867 text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
868 text+=" res=%s(arg0,arg1)\n"%oper[0]
869 case=getResultCaseForBin(case0,case1)
# Symbol result: build the substitution dict {arg0:s0,arg1:s1} for the Symbol
# arguments and derive the reference case with those arguments as arrays
870 if case=="Symbol":
871 c0_res,c1_res=case0,case1
872 subs="{"
873 if case0=="Symbol":
874 text+=mkText("array","s0",a_0,a1_0)
875 subs+="arg0:s0"
876 c0_res="array"
877 if case1=="Symbol":
878 text+=mkText("array","s1",a_1,a1_1)
879 if not subs.endswith("{"): subs+=","
880 subs+="arg1:s1"
881 c1_res="array"
882 subs+="}"
883 text+=" sub=res.substitute(%s)\n"%subs
884 res="sub"
885 text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)
886 else:
887 res="res"
888 text+=mkText(case,"ref",r,r1)
# the expected shape is that of the higher-rank argument
889 if len(sh0)>len(sh1):
890 text+=mkTypeAndShapeTest(case,sh0,"res")
891 else:
892 text+=mkTypeAndShapeTest(case,sh1,"res")
893 text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
894
895 if case0 == "taggedData" or case1 == "taggedData":
896 t_prog_with_tags+=text
897 else:
898 t_prog+=text
899
# Prints the tagged tests only, then 1/0 raises ZeroDivisionError to stop
# the script (if execution ever reaches this point).
900 print test_header
901 # print t_prog
902 print t_prog_with_tags
903 print test_tail
904 1/0
905
906
#=======================================================================================================
# outer inner
#=======================================================================================================
# Generates tests for outer(arg0,arg1). Swap the comment on the two "oper"
# lines to generate tests for inner instead. The shape test expects sh0+sh1.
# Combinations with len(sh0)+len(sh1)>4 are skipped, as are non-scalar floats.
oper=["outer",outerTEST]
# oper=["inner",innerTEST]
outer_cases=["float","array","Symbol","constData","taggedData","expandedData"]
outer_shapes=[ (),(2,), (4,5), (6,2,2),(3,2,3,4)]
# flattened form of the nested loops case0 / sh1 / case1 / sh0 (sh0 varies fastest)
combinations=[(c0,s1,c1,s0) for c0 in outer_cases for s1 in outer_shapes
                            for c1 in outer_cases for s0 in outer_shapes]
separator=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
for case0,sh1,case1,sh0 in combinations:
    if (case0=="float" and len(sh0)>0) or (case1=="float" and len(sh1)>0) or len(sh0)+len(sh1)>=5:
        continue
    use_tagging_for_expanded_data = "taggedData" in (case0,case1)

    tname="test_%s_%s_rank%s_%s_rank%s"%(oper[0],case0,len(sh0),case1,len(sh1))
    text=separator+" def %s(self):\n"%tname

    # second value set only for tagged/expanded data; makeArray call order is kept
    a_0=makeArray(sh0,[-1.,1])
    if case0 in ("taggedData","expandedData"):
        a1_0=makeArray(sh0,[-1.,1])
    else:
        a1_0=a_0
    a_1=makeArray(sh1,[-1.,1])
    if case1 in ("taggedData","expandedData"):
        a1_1=makeArray(sh1,[-1.,1])
    else:
        a1_1=a_1
    r=oper[1](a_0,a_1)
    r1=oper[1](a1_0,a1_1)

    text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
    text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
    text+=" res=%s(arg0,arg1)\n"%oper[0]

    case=getResultCaseForBin(case0,case1)
    if case!="Symbol":
        res="res"
        text+=mkText(case,"ref",r,r1)
    else:
        # replace each Symbol argument by an array before comparing
        substituted=[]
        ref_case0,ref_case1=case0,case1
        if case0=="Symbol":
            text+=mkText("array","s0",a_0,a1_0)
            substituted.append("arg0:s0")
            ref_case0="array"
        if case1=="Symbol":
            text+=mkText("array","s1",a_1,a1_1)
            substituted.append("arg1:s1")
            ref_case1="array"
        text+=" sub=res.substitute({%s})\n"%",".join(substituted)
        res="sub"
        text+=mkText(getResultCaseForBin(ref_case0,ref_case1),"ref",r,r1)
    text+=mkTypeAndShapeTest(case,sh0+sh1,"res")
    text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res

    if "taggedData" in (case0,case1):
        t_prog_with_tags+=text
    else:
        t_prog+=text

print(test_header)
# print(t_prog)
print(t_prog_with_tags)
print(test_tail)
1/0
972
#=======================================================================================================
# local reduction
#=======================================================================================================
# Generates tests for the reductions length, maxval and minval. Each entry gives
# the function name plus the initial value, update expression and final
# expression passed to testReduce to compute the reference value.
reductions=[["length",0.,"out+%a1%**2","math.sqrt(out)"],
            ["maxval",-1.e99,"max(out,%a1%)","out"],
            ["minval",1.e99,"min(out,%a1%)","out"] ]
separator=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
for red_name,red_init,red_update,red_final in reductions:
    for case in case_set:
        for sh in shape_set:
            if case=="float" and len(sh)>0:
                continue      # a float argument is a scalar only
            a=makeArray(sh,[-1.,1.])
            a1=makeArray(sh,[-1.,1.])
            r1=testReduce(a1,red_init,red_update,red_final)
            r=testReduce(a,red_init,red_update,red_final)

            text=separator+" def test_%s_%s_rank%s(self):\n"%(red_name,case,len(sh))
            text+=mkText(case,"arg",a,a1)
            text+=" res=%s(arg)\n"%red_name
            if case!="Symbol":
                text+=mkText(case,"ref",r,r1)
                res="res"
            else:
                text+=mkText("array","s",a,a1)
                text+=" sub=res.substitute({arg:s})\n"
                text+=mkText("array","ref",r,r1)
                res="sub"
            # the result is scalar; maxval/minval of a float or numarray is checked as a float
            if red_name!="length" and case in ("float","array"):
                text+=mkTypeAndShapeTest("float",(),"res")
            else:
                text+=mkTypeAndShapeTest(case,(),"res")
            text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
            if case=="taggedData":
                t_prog_with_tags+=text
            else:
                t_prog+=text

print(test_header)
# print(t_prog)
print(t_prog_with_tags)
print(test_tail)
1/0
1018
#=======================================================================================================
# tensor multiply
#=======================================================================================================
# Generates tests for tensormult(arg0,arg1). arg0 has shape sh0+sh_s and arg1
# has shape sh_s+sh1; the shape test expects sh0+sh1. To switch to
# generalTensorProduct or matrixmult, change the operator, rank filter, test
# name and generated call together (the alternatives are kept as comments).
# oper=["generalTensorProduct",tensorProductTest]
# oper=["matrixmult",testMatrixMult]
oper=["tensormult",testTensorMult]

separator=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
    for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
        for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:
            for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
                for sh_s in [ (),(3,), (2,3), (2,4,3),(4,2,3,2)]:
                    rank0=len(sh0)+len(sh_s)     # rank of arg0
                    rank1=len(sh_s)+len(sh1)     # rank of arg1
                    offset=len(sh_s)             # number of shared indices
                    # a float argument must be a scalar; all ranks involved stay below 5
                    if case0=="float" and rank0>0:
                        continue
                    if case1=="float" and rank1>0:
                        continue
                    if len(sh0)+len(sh1)>=5 or rank0>=5 or rank1>=5:
                        continue
                    # rank filter for tensormult
                    # (matrixmult: offset==1 and rank0==2 and rank1 in (1,2))
                    if not ((offset==1 and rank0==2 and rank1 in (1,2)) or
                            (offset==2 and rank0==4 and rank1 in (2,3,4))):
                        continue

                    case=getResultCaseForBin(case0,case1)
                    use_tagging_for_expanded_data = "taggedData" in (case0,case1)
                    # tname="test_generalTensorProduct_%s_rank%s_%s_rank%s_offset%s"%(case0,rank0,case1,rank1,offset)
                    # tname="test_matrixmult_%s_rank%s_%s_rank%s"%(case0,rank0,case1,rank1)
                    tname="test_tensormult_%s_rank%s_%s_rank%s"%(case0,rank0,case1,rank1)
                    text=separator+" def %s(self):\n"%tname

                    # second value set only for tagged/expanded data; makeArray call order is kept
                    a_0=makeArray(sh0+sh_s,[-1.,1])
                    if case0 in ("taggedData","expandedData"):
                        a1_0=makeArray(sh0+sh_s,[-1.,1])
                    else:
                        a1_0=a_0
                    a_1=makeArray(sh_s+sh1,[-1.,1])
                    if case1 in ("taggedData","expandedData"):
                        a1_1=makeArray(sh_s+sh1,[-1.,1])
                    else:
                        a1_1=a_1
                    r=oper[1](a_0,a_1,sh_s)
                    r1=oper[1](a1_0,a1_1,sh_s)

                    text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
                    text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
                    # text+=" res=matrixmult(arg0,arg1)\n"
                    # text+=" res=generalTensorProduct(arg0,arg1,offset=%s)\n"%offset
                    text+=" res=tensormult(arg0,arg1)\n"

                    if case!="Symbol":
                        res="res"
                        text+=mkText(case,"ref",r,r1)
                    else:
                        # substitute arrays for the Symbol arguments before comparing
                        c0_res,c1_res=case0,case1
                        entries=[]
                        if case0=="Symbol":
                            text+=mkText("array","s0",a_0,a1_0)
                            entries.append("arg0:s0")
                            c0_res="array"
                        if case1=="Symbol":
                            text+=mkText("array","s1",a_1,a1_1)
                            entries.append("arg1:s1")
                            c1_res="array"
                        text+=" sub=res.substitute({%s})\n"%",".join(entries)
                        res="sub"
                        text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)
                    text+=mkTypeAndShapeTest(case,sh0+sh1,"res")
                    text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
                    if "taggedData" in (case0,case1):
                        t_prog_with_tags+=text
                    else:
                        t_prog+=text

print(test_header)
# print(t_prog)
print(t_prog_with_tags)
print(test_tail)
1/0
#=======================================================================================================
# basic binary operation overloading (tests only!)
#=======================================================================================================
# Tests for the overloaded operators +,-,*,/,** written as arg0<op>arg1.
# Excluded: a numarray left operand, non-scalar floats, incompatible shapes
# (unequal and both non-scalar) and pairs made only of floats and numarrays.
# Tests with a Data left operand and a Symbol right operand are collected in
# t_prog_failing; only that collection is printed below.
oper_range=[-5.,5.]
separator=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
for oper in [["add" ,"+",[-5.,5.]],
             ["sub" ,"-",[-5.,5.]],
             ["mult","*",[-5.,5.]],
             ["div" ,"/",[-5.,5.]],
             ["pow" ,"**",[0.01,5.]]]:
    formula="%a1%"+oper[1]+"%a2%"
    for case0 in case_set:
        for sh0 in shape_set:
            for case1 in case_set:
                for sh1 in shape_set:
                    if case0=="array":
                        continue
                    if case0=="float" and len(sh0)>0:
                        continue
                    if case1=="float" and len(sh1)>0:
                        continue
                    if sh0!=() and sh1!=() and sh1!=sh0:
                        continue
                    if case0 in ("float","array") and case1 in ("float","array"):
                        continue
                    use_tagging_for_expanded_data = "taggedData" in (case0,case1)

                    tname="test_%s_overloaded_%s_rank%s_%s_rank%s"%(oper[0],case0,len(sh0),case1,len(sh1))
                    text=separator+" def %s(self):\n"%tname

                    # second value set only for tagged/expanded data; makeArray call order is kept
                    a_0=makeArray(sh0,oper[2])
                    if case0 in ("taggedData","expandedData"):
                        a1_0=makeArray(sh0,oper[2])
                    else:
                        a1_0=a_0
                    a_1=makeArray(sh1,oper[2])
                    if case1 in ("taggedData","expandedData"):
                        a1_1=makeArray(sh1,oper[2])
                    else:
                        a1_1=a_1
                    r1=makeResult2(a1_0,a1_1,formula)
                    r=makeResult2(a_0,a_1,formula)

                    text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
                    text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
                    text+=" res=arg0%sarg1\n"%oper[1]

                    case=getResultCaseForBin(case0,case1)
                    if case!="Symbol":
                        res="res"
                        text+=mkText(case,"ref",r,r1)
                    else:
                        # substitute arrays for the Symbol operands before comparing
                        c0_res,c1_res=case0,case1
                        entries=[]
                        if case0=="Symbol":
                            text+=mkText("array","s0",a_0,a1_0)
                            entries.append("arg0:s0")
                            c0_res="array"
                        if case1=="Symbol":
                            text+=mkText("array","s1",a_1,a1_1)
                            entries.append("arg1:s1")
                            c1_res="array"
                        text+=" sub=res.substitute({%s})\n"%",".join(entries)
                        res="sub"
                        text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)

                    if isinstance(r,float):
                        result_shape=()
                    else:
                        result_shape=r.shape
                    text+=mkTypeAndShapeTest(case,result_shape,"res")
                    text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res

                    if case1=="Symbol" and case0 in ("constData","taggedData","expandedData"):
                        t_prog_failing+=text
                    elif "taggedData" in (case0,case1):
                        t_prog_with_tags+=text
                    else:
                        t_prog+=text


print(test_header)
# print(t_prog)
# print(t_prog_with_tags)
print(t_prog_failing)
print(test_tail)
1/0
#=======================================================================================================
# basic binary operations (tests only!)
#=======================================================================================================
# Tests for the functions add, mult, quotient and power. For each base shape sh
# and extension sh_p, one argument has shape sh and the other sh+sh_p; both
# orders are generated (a single pair when sh_p is empty).
oper_range=[-5.,5.]
separator=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
for oper in [["add" ,"+",[-5.,5.]],
             ["mult","*",[-5.,5.]],
             ["quotient" ,"/",[-5.,5.]],
             ["power" ,"**",[0.01,5.]]]:
    formula="%a1%"+oper[1]+"%a2%"
    for case0 in case_set:
        for case1 in case_set:
            for sh in shape_set:
                for sh_p in shape_set:
                    # (sh0,sh1) pairs: extended shape first, then the base shape first
                    if len(sh_p)>0:
                        shape_pairs=[(sh+sh_p,sh),(sh,sh+sh_p)]
                    else:
                        shape_pairs=[(sh,sh)]
                    for sh0,sh1 in shape_pairs:
                        if case0=="float" and len(sh0)>0:
                            continue
                        if case1=="float" and len(sh1)>0:
                            continue
                        if len(sh0)>=5 or len(sh1)>=5:
                            continue
                        use_tagging_for_expanded_data = "taggedData" in (case0,case1)

                        tname="test_%s_%s_rank%s_%s_rank%s"%(oper[0],case0,len(sh0),case1,len(sh1))
                        text=separator+" def %s(self):\n"%tname

                        # second value set only for tagged/expanded data; makeArray call order is kept
                        a_0=makeArray(sh0,oper[2])
                        if case0 in ("taggedData","expandedData"):
                            a1_0=makeArray(sh0,oper[2])
                        else:
                            a1_0=a_0
                        a_1=makeArray(sh1,oper[2])
                        if case1 in ("taggedData","expandedData"):
                            a1_1=makeArray(sh1,oper[2])
                        else:
                            a1_1=a_1
                        r1=makeResult2(a1_0,a1_1,formula)
                        r=makeResult2(a_0,a_1,formula)

                        text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
                        text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
                        text+=" res=%s(arg0,arg1)\n"%oper[0]

                        case=getResultCaseForBin(case0,case1)
                        if case!="Symbol":
                            res="res"
                            text+=mkText(case,"ref",r,r1)
                        else:
                            # substitute arrays for the Symbol arguments before comparing
                            c0_res,c1_res=case0,case1
                            entries=[]
                            if case0=="Symbol":
                                text+=mkText("array","s0",a_0,a1_0)
                                entries.append("arg0:s0")
                                c0_res="array"
                            if case1=="Symbol":
                                text+=mkText("array","s1",a_1,a1_1)
                                entries.append("arg1:s1")
                                c1_res="array"
                            text+=" sub=res.substitute({%s})\n"%",".join(entries)
                            res="sub"
                            text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)

                        if isinstance(r,float):
                            result_shape=()
                        else:
                            result_shape=r.shape
                        text+=mkTypeAndShapeTest(case,result_shape,"res")
                        text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res

                        if "taggedData" in (case0,case1):
                            t_prog_with_tags+=text
                        else:
                            t_prog+=text

print(test_header)
# print(t_prog)
print(t_prog_with_tags)
print(test_tail)
1/0
1253
# print t_prog_with_tags
oper_range=[-5.,5.]
# Second variant of the operator-overloading tests (test names without
# "_overloaded"). Unlike the section above, a numarray left operand is not
# excluded. Tests with a Data left operand and a Symbol right operand go to
# t_prog_failing.
for oper in [["add" ,"+",[-5.,5.]],
             ["sub" ,"-",[-5.,5.]],
             ["mult","*",[-5.,5.]],
             ["div" ,"/",[-5.,5.]],
             ["pow" ,"**",[0.01,5.]]]:
    for case0 in case_set:
        for sh0 in shape_set:
            for case1 in case_set:
                for sh1 in shape_set:
                    if (not case0=="float" or len(sh0)==0) and (not case1=="float" or len(sh1)==0) and \
                       (sh0==() or sh1==() or sh1==sh0) and \
                       not (case0 in ["float","array"] and case1 in ["float","array"]):
                        # Pass use_tagging_for_expanded_data to mkText like every other
                        # binary-operation section. Without it, a taggedData/expandedData pair
                        # is set up differently from those sections.
                        # NOTE(review): the flag is interpreted by mkText, defined elsewhere.
                        use_tagging_for_expanded_data= case0=="taggedData" or case1=="taggedData"
                        text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
                        tname="test_%s_%s_rank%s_%s_rank%s"%(oper[0],case0,len(sh0),case1,len(sh1))
                        text+=" def %s(self):\n"%tname
                        a_0=makeArray(sh0,oper[2])
                        if case0 in ["taggedData", "expandedData"]:
                            a1_0=makeArray(sh0,oper[2])
                        else:
                            a1_0=a_0

                        a_1=makeArray(sh1,oper[2])
                        if case1 in ["taggedData", "expandedData"]:
                            a1_1=makeArray(sh1,oper[2])
                        else:
                            a1_1=a_1
                        r1=makeResult2(a1_0,a1_1,"%a1%"+oper[1]+"%a2%")
                        r=makeResult2(a_0,a_1,"%a1%"+oper[1]+"%a2%")
                        text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
                        text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
                        text+=" res=arg0%sarg1\n"%oper[1]

                        case=getResultCaseForBin(case0,case1)
                        if case=="Symbol":
                            # substitute arrays for the Symbol operands before comparing
                            c0_res,c1_res=case0,case1
                            subs="{"
                            if case0=="Symbol":
                                text+=mkText("array","s0",a_0,a1_0)
                                subs+="arg0:s0"
                                c0_res="array"
                            if case1=="Symbol":
                                text+=mkText("array","s1",a_1,a1_1)
                                if not subs.endswith("{"): subs+=","
                                subs+="arg1:s1"
                                c1_res="array"
                            subs+="}"
                            text+=" sub=res.substitute(%s)\n"%subs
                            res="sub"
                            text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)
                        else:
                            res="res"
                            text+=mkText(case,"ref",r,r1)
                        if isinstance(r,float):
                            text+=mkTypeAndShapeTest(case,(),"res")
                        else:
                            text+=mkTypeAndShapeTest(case,r.shape,"res")
                        text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res

                        if case0 in [ "constData","taggedData","expandedData"] and case1 == "Symbol":
                            t_prog_failing+=text
                        else:
                            if case0 == "taggedData" or case1 == "taggedData":
                                t_prog_with_tags+=text
                            else:
                                t_prog+=text


# print u_prog
# 1/0
print(test_header)
print(t_prog)
# print(t_prog_with_tags)
# print(t_prog_failing)
# test_tail creates the suite and runs it, so it must be emitted exactly once
# (it used to be printed twice, which ran every generated test twice).
print(test_tail)
1331
#=======================================================================================================
# unary operations:
#=======================================================================================================
# Table of element-wise unary functions processed by the generator loop below.
# OPERATOR fields:
#   nickname      - name of the generated function (capitalised for its _Symbol class)
#   rng           - range passed to makeArray for the test arguments
#   test_expr     - expression in %a1% used to compute the tests' reference values
#   math_expr     - code emitted for float/int arguments (defaults to test_expr with %a1% -> arg)
#   numarray_expr - code emitted for numarray arguments
#   symbol_expr   - code emitted for Symbol arguments; if None a <Name>_Symbol class is generated
#   diff          - derivative in terms of %a1% ("self" denotes the symbol itself);
#                   if None no diff method is generated
#   name          - description used in the generated docstrings
#
# NOTE(review): whereZero/whereNonZero use a variable tol. tol is not a parameter of
# the generated "def <nickname>(arg)" function, and whereNonZero's symbol_expr calls
# whereZero with two arguments. Confirm that tol is defined where the generated code
# runs, or add it to the generated signature.
func= [
    OPERATOR(nickname="log10",
             rng=[1.e-3,100.],
             test_expr="math.log10(%a1%)",
             math_expr="math.log10(%a1%)",
             numarray_expr="numarray.log10(%a1%)",
             symbol_expr="log(%a1%)/log(10.)",
             name="base-10 logarithm"),
    OPERATOR(nickname="wherePositive",
             rng=[-100.,100.],
             test_expr="wherepos(%a1%)",
             math_expr="if arg>0:\n return 1.\nelse:\n return 0.",
             numarray_expr="numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))",
             name="mask of positive values"),
    OPERATOR(nickname="whereNegative",
             rng=[-100.,100.],
             test_expr="wherepos(-%a1%)",
             math_expr="if arg<0:\n return 1.\nelse:\n return 0.",
             numarray_expr="numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))",
             name="mask of negative values"),       # was "mask of positive values"
    OPERATOR(nickname="whereNonNegative",
             rng=[-100.,100.],
             test_expr="1-wherepos(-%a1%)",
             math_expr="if arg<0:\n return 0.\nelse:\n return 1.",
             numarray_expr="numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float))",
             # x>=0 <=> not x<0; was 1-wherePositive, which is the non-positive mask
             symbol_expr="1-whereNegative(%a1%)",
             name="mask of non-negative values"),
    OPERATOR(nickname="whereNonPositive",
             rng=[-100.,100.],
             test_expr="1-wherepos(%a1%)",
             math_expr="if arg>0:\n return 0.\nelse:\n return 1.",
             numarray_expr="numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float))",
             # x<=0 <=> not x>0; was 1-whereNegative, which is the non-negative mask
             symbol_expr="1-wherePositive(%a1%)",
             name="mask of non-positive values"),
    OPERATOR(nickname="whereZero",
             rng=[-100.,100.],
             test_expr="1-wherepos(%a1%)-wherepos(-%a1%)",
             math_expr="if abs(%a1%)<=tol:\n return 1.\nelse:\n return 0.",
             numarray_expr="numarray.less_equal(abs(%a1%)-tol,numarray.zeros(arg.shape,numarray.Float))",
             name="mask of zero entries"),
    OPERATOR(nickname="whereNonZero",
             rng=[-100.,100.],
             test_expr="wherepos(%a1%)+wherepos(-%a1%)",
             math_expr="if abs(%a1%)>tol:\n return 1.\nelse:\n return 0.",
             numarray_expr="numarray.greater(abs(%a1%)-tol,numarray.zeros(arg.shape,numarray.Float))",
             symbol_expr="1-whereZero(arg,tol)",
             name="mask of values different from zero"),
    OPERATOR(nickname="sin",
             rng=[-100.,100.],
             test_expr="math.sin(%a1%)",
             numarray_expr="numarray.sin(%a1%)",
             diff="cos(%a1%)",
             name="sine"),
    OPERATOR(nickname="cos",
             rng=[-100.,100.],
             test_expr="math.cos(%a1%)",
             numarray_expr="numarray.cos(%a1%)",
             diff="-sin(%a1%)",
             name="cosine"),
    OPERATOR(nickname="tan",
             rng=[-100.,100.],
             test_expr="math.tan(%a1%)",
             numarray_expr="numarray.tan(%a1%)",
             diff="1./cos(%a1%)**2",
             name="tangent"),
    OPERATOR(nickname="asin",
             rng=[-0.99,0.99],
             test_expr="math.asin(%a1%)",
             numarray_expr="numarray.arcsin(%a1%)",
             diff="1./sqrt(1.-%a1%**2)",
             name="inverse sine"),
    OPERATOR(nickname="acos",
             rng=[-0.99,0.99],
             test_expr="math.acos(%a1%)",
             numarray_expr="numarray.arccos(%a1%)",
             diff="-1./sqrt(1.-%a1%**2)",
             name="inverse cosine"),
    OPERATOR(nickname="atan",
             rng=[-100.,100.],
             test_expr="math.atan(%a1%)",
             numarray_expr="numarray.arctan(%a1%)",
             diff="1./(1+%a1%**2)",
             name="inverse tangent"),
    OPERATOR(nickname="sinh",
             rng=[-5,5],
             test_expr="math.sinh(%a1%)",
             numarray_expr="numarray.sinh(%a1%)",
             diff="cosh(%a1%)",
             name="hyperbolic sine"),
    OPERATOR(nickname="cosh",
             rng=[-5.,5.],
             test_expr="math.cosh(%a1%)",
             numarray_expr="numarray.cosh(%a1%)",
             diff="sinh(%a1%)",
             name="hyperbolic cosine"),
    OPERATOR(nickname="tanh",
             rng=[-5.,5.],
             test_expr="math.tanh(%a1%)",
             numarray_expr="numarray.tanh(%a1%)",
             diff="1./cosh(%a1%)**2",
             name="hyperbolic tangent"),
    OPERATOR(nickname="asinh",
             rng=[-100.,100.],
             test_expr="numarray.arcsinh(%a1%)",
             math_expr="numarray.arcsinh(%a1%)",
             numarray_expr="numarray.arcsinh(%a1%)",
             diff="1./sqrt(%a1%**2+1)",
             name="inverse hyperbolic sine"),
    OPERATOR(nickname="acosh",
             rng=[1.001,100.],
             test_expr="numarray.arccosh(%a1%)",
             math_expr="numarray.arccosh(%a1%)",
             numarray_expr="numarray.arccosh(%a1%)",
             diff="1./sqrt(%a1%**2-1)",
             name="inverse hyperbolic cosine"),     # was "hyperolic"
    OPERATOR(nickname="atanh",
             rng=[-0.99,0.99],
             test_expr="numarray.arctanh(%a1%)",
             math_expr="numarray.arctanh(%a1%)",
             numarray_expr="numarray.arctanh(%a1%)",
             diff="1./(1.-%a1%**2)",
             name="inverse hyperbolic tangent"),
    OPERATOR(nickname="exp",
             rng=[-5.,5.],
             test_expr="math.exp(%a1%)",
             numarray_expr="numarray.exp(%a1%)",
             diff="self",
             name="exponential"),
    OPERATOR(nickname="sqrt",
             rng=[1.e-3,100.],
             test_expr="math.sqrt(%a1%)",
             numarray_expr="numarray.sqrt(%a1%)",
             diff="0.5/self",
             name="square root"),
    OPERATOR(nickname="log",
             rng=[1.e-3,100.],
             test_expr="math.log(%a1%)",
             numarray_expr="numarray.log(%a1%)",
             # d/dx log(x)=1/x in terms of the function argument. This was "1./arg", but
             # in the generated diff(self,arg) method arg is the differentiation variable.
             diff="1./%a1%",
             name="natural logarithm"),
    OPERATOR(nickname="sign",
             rng=[-100.,100.],
             math_expr="if %a1%>0:\n return 1.\nelif %a1%<0:\n return -1.\nelse:\n return 0.",
             test_expr="wherepos(%a1%)-wherepos(-%a1%)",
             numarray_expr="numarray.sign(%a1%)",
             symbol_expr="wherePositive(%a1%)-whereNegative(%a1%)",
             name="sign"),
    OPERATOR(nickname="abs",
             rng=[-100.,100.],
             math_expr="if %a1%>0:\n return %a1% \nelif %a1%<0:\n return -(%a1%)\nelse:\n return 0.",
             test_expr="wherepos(%a1%)*(%a1%)-wherepos(-%a1%)*(%a1%)",
             numarray_expr="abs(%a1%)",
             diff="sign(%a1%)",
             name="absolute value")

]
# For every unary operator f in func, append to u_prog (code for util.py):
#  * a type-dispatching function <nickname>(arg) (not for abs),
#  * a <Name>_Symbol class when f.symbol_expr is None, with a diff method when
#    f.diff is given,
# and append to t_prog / t_prog_with_tags one unit test per argument kind and shape.
for f in func:
    # class-name prefix: nickname with its first letter capitalised ("sin" -> "Sin")
    symbol_name=f.nickname[0].upper()+f.nickname[1:]
    # No function is generated for abs, presumably so Python's builtin abs is not
    # shadowed (TODO confirm). NOTE(review): indentation was lost in the copy this
    # was reconstructed from; the scope of this if is taken to be the function
    # definition only — confirm against the repository.
    if f.nickname!="abs":
        u_prog+="def %s(arg):\n"%f.nickname
        u_prog+=" \"\"\"\n"
        u_prog+=" returns %s of argument arg\n\n"%f.name
        u_prog+=" @param arg: argument\n"
        u_prog+=" @type arg: C{float}, L{escript.Data}, L{Symbol}, L{numarray.NumArray}.\n"
        u_prog+=" @rtype:C{float}, L{escript.Data}, L{Symbol}, L{numarray.NumArray} depending on the type of arg.\n"
        u_prog+=" @raises TypeError: if the type of the argument is not expected.\n"
        u_prog+=" \"\"\"\n"
        # dispatch on the argument type: numarray, escript.Data, float, int, Symbol
        u_prog+=" if isinstance(arg,numarray.NumArray):\n"
        u_prog+=mkCode(f.numarray_expr,["arg"],2*" ")
        u_prog+=" elif isinstance(arg,escript.Data):\n"
        u_prog+=mkCode("arg._%s()"%f.nickname,[],2*" ")
        u_prog+=" elif isinstance(arg,float):\n"
        u_prog+=mkCode(f.math_expr,["arg"],2*" ")
        u_prog+=" elif isinstance(arg,int):\n"
        u_prog+=mkCode(f.math_expr,["float(arg)"],2*" ")
        u_prog+=" elif isinstance(arg,Symbol):\n"
        if f.symbol_expr==None:
            # no symbolic expression given: wrap the argument in the generated _Symbol class
            u_prog+=mkCode("%s_Symbol(arg)"%symbol_name,[],2*" ")
        else:
            u_prog+=mkCode(f.symbol_expr,["arg"],2*" ")
        u_prog+=" else:\n"
        u_prog+=" raise TypeError,\"%s: Unknown argument type.\"\n\n"%f.nickname
    if f.symbol_expr==None:
        # DependendSymbol subclass with the argument's shape and dimension
        u_prog+="class %s_Symbol(DependendSymbol):\n"%symbol_name
        u_prog+=" \"\"\"\n"
        u_prog+=" L{Symbol} representing the result of the %s function\n"%f.name
        u_prog+=" \"\"\"\n"
        u_prog+=" def __init__(self,arg):\n"
        u_prog+=" \"\"\"\n"
        u_prog+=" initialization of %s L{Symbol} with argument arg\n"%f.nickname
        u_prog+=" @param arg: argument of function\n"
        u_prog+=" @type arg: typically L{Symbol}.\n"
        u_prog+=" \"\"\"\n"
        u_prog+=" DependendSymbol.__init__(self,args=[arg],shape=arg.getShape(),dim=arg.getDim())\n"
        u_prog+="\n"

        # getMyCode: returns "<nickname>(<argstr>)" for the escript, str and text formats
        u_prog+=" def getMyCode(self,argstrs,format=\"escript\"):\n"
        u_prog+=" \"\"\"\n"
        u_prog+=" returns a program code that can be used to evaluate the symbol.\n\n"

        u_prog+=" @param argstrs: gives for each argument a string representing the argument for the evaluation.\n"
        u_prog+=" @type argstrs: C{str} or a C{list} of length 1 of C{str}.\n"
        u_prog+=" @param format: specifies the format to be used. At the moment only \"escript\" ,\"text\" and \"str\" are supported.\n"
        u_prog+=" @type format: C{str}\n"
        u_prog+=" @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.\n"
        u_prog+=" @rtype: C{str}\n"
        # NOTE(review): "@raise: NotImplementedError:" is not valid epydoc field syntax
        # ("@raise NotImplementedError:" is).
        u_prog+=" @raise: NotImplementedError: if the requested format is not available\n"
        u_prog+=" \"\"\"\n"
        u_prog+=" if isinstance(argstrs,list):\n"
        u_prog+=" argstrs=argstrs[0]\n"
        u_prog+=" if format==\"escript\" or format==\"str\" or format==\"text\":\n"
        u_prog+=" return \"%s(%%s)\"%%argstrs\n"%f.nickname
        u_prog+=" else:\n"
        u_prog+=" raise NotImplementedError,\"%s_Symbol does not provide program code for format %%s.\"%%format\n"%symbol_name
        u_prog+="\n"

        # substitute: returns argvals[self] if present (after isAppropriateValue check),
        # otherwise applies <nickname> to the substituted argument
        u_prog+=" def substitute(self,argvals):\n"
        u_prog+=" \"\"\"\n"
        u_prog+=" assigns new values to symbols in the definition of the symbol.\n"
        u_prog+=" The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.\n"
        u_prog+="\n"
        u_prog+=" @param argvals: new values assigned to symbols\n"
        u_prog+=" @type argvals: C{dict} with keywords of type L{Symbol}.\n"
        u_prog+=" @return: result of the substitution process. Operations are executed as much as possible.\n"
        u_prog+=" @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution\n"
        u_prog+=" @raise TypeError: if a value for a L{Symbol} cannot be substituted.\n"
        u_prog+=" \"\"\"\n"
        u_prog+=" if argvals.has_key(self):\n"
        u_prog+=" arg=argvals[self]\n"
        u_prog+=" if self.isAppropriateValue(arg):\n"
        u_prog+=" return arg\n"
        u_prog+=" else:\n"
        u_prog+=" raise TypeError,\"%s: new value is not appropriate.\"%str(self)\n"
        u_prog+=" else:\n"
        u_prog+=" arg=self.getSubstitutedArguments(argvals)[0]\n"
        u_prog+=" return %s(arg)\n\n"%f.nickname
        if not f.diff==None:
            # diff: chain rule -- f.diff with %a1% replaced by the symbol's argument (myarg),
            # multiplied by the derivative of that argument (shapes aligned by matchShape)
            u_prog+=" def diff(self,arg):\n"
            u_prog+=" \"\"\"\n"
            u_prog+=" differential of this object\n"
            u_prog+="\n"
            u_prog+=" @param arg: the derivative is calculated with respect to arg\n"
            u_prog+=" @type arg: L{escript.Symbol}\n"
            u_prog+=" @return: derivative with respect to C{arg}\n"
            u_prog+=" @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray} are possible.\n"
            u_prog+=" \"\"\"\n"
            u_prog+=" if arg==self:\n"
            u_prog+=" return identity(self.getShape())\n"
            u_prog+=" else:\n"
            u_prog+=" myarg=self.getArgument()[0]\n"
            u_prog+=" val=matchShape(%s,self.getDifferentiatedArguments(arg)[0])\n"%f.diff.replace("%a1%","myarg")
            u_prog+=" return val[0]*val[1]\n\n"

    # one test per argument kind and shape; a float argument is scalar only
    for case in case_set:
        for sh in shape_set:
            if not case=="float" or len(sh)==0:
                text=""
                text+=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
                tname="def test_%s_%s_rank%s"%(f.nickname,case,len(sh))
                text+=" %s(self):\n"%tname
                # a1/r1: second value set and its reference (used for tagged/expanded data)
                a=makeArray(sh,f.rng)
                a1=makeArray(sh,f.rng)
                r1=makeResult(a1,f.test_expr)
                r=makeResult(a,f.test_expr)

                text+=mkText(case,"arg",a,a1)
                text+=" res=%s(arg)\n"%f.nickname
                if case=="Symbol":
                    # the symbolic result is checked after substituting a concrete array
                    text+=mkText("array","s",a,a1)
                    text+=" sub=res.substitute({arg:s})\n"
                    text+=mkText("array","ref",r,r1)
                    res="sub"
                else:
                    text+=mkText(case,"ref",r,r1)
                    res="res"
                text+=mkTypeAndShapeTest(case,sh,"res")
                text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
                if case == "taggedData":
                    t_prog_with_tags+=text
                else:
                    t_prog+=text
1616
#=========== END OF GOOD CODE +++++++++++++++++++++++++++

# Deliberate stop: 1/0 raises ZeroDivisionError, so the script ends here and
# nothing after this line runs.
1/0
1620
1621 def X():
1622 if args=="float":
1623 a=makeArray(sh,f[RANGE])
1624 r=makeResult(a,f)
1625 t_prog+=" arg=%s\n"%a[0]
1626 t_prog+=" ref=%s\n"%r[0]
1627 t_prog+=" res=%s(%a1%)\n"%f.nickname
1628 t_prog+=" self.failUnless(isinstance(res,float),\"wrong type of result.\")\n"
1629 t_prog+=" self.failUnless(Lsup(res-ref)<=self.tol*Lsup(ref),\"wrong result\")\n"
1630 elif args == "array":
1631 a=makeArray(sh,f[RANGE])
1632 r=makeResult(a,f)
1633 if len(sh)==0:
1634 t_prog+=" arg=numarray.array(%s)\n"%a[0]
1635 t_prog+=" ref=numarray.array(%s)\n"%r[0]
1636 else:
1637 t_prog+=" arg=numarray.array(%s)\n"%a.tolist()
1638 t_prog+=" ref=numarray.array(%s)\n"%r.tolist()
1639 t_prog+=" res=%s(%a1%)\n"%f.nickname
1640 t_prog+=" self.failUnlessEqual(res.shape,%s,\"wrong shape of result.\")\n"%str(sh)
1641 t_prog+=" self.failUnless(Lsup(res-ref)<=self.tol*Lsup(ref),\"wrong result\")\n"
1642 elif args== "constData":
1643 a=makeArray(sh,f[RANGE])
1644 r=makeResult(a,f)
1645 if len(sh)==0:
1646 t_prog+=" arg=Data(%s,self.functionspace)\n"%(a)
1647 t_prog+=" ref=%s\n"%r
1648 else:
1649 t_prog+=" arg=Data(numarray.array(%s),self.functionspace)\n"%(a.tolist())
1650 t_prog+=" ref=numarray.array(%s)\n"%r.tolist()
1651 t_prog+=" res=%s(%a1%)\n"%f.nickname
1652 t_prog+=" self.failUnlessEqual(res.getShape(),%s,\"wrong shape of result.\")\n"%str(sh)
1653 t_prog+=" self.failUnless(Lsup(res-ref)<=self.tol*Lsup(ref),\"wrong result\")\n"
1654 elif args in [ "taggedData","expandedData"]:
1655 a=makeArray(sh,f[RANGE])
1656 r=makeResult(a,f)
1657 a1=makeArray(sh,f[RANGE])
1658 r1=makeResult(a1,f)
1659 if len(sh)==0:
1660 if args=="expandedData":
1661 t_prog+=" arg=Data(%s,self.functionspace,True)\n"%(a)
1662 t_prog+=" ref=Data(%s,self.functionspace,True)\n"%(r)
1663 else:
1664 t_prog+=" arg=Data(%s,self.functionspace)\n"%(a)
1665 t_prog+=" ref=Data(%s,self.functionspace)\n"%(r)
1666 t_prog+=" arg.setTaggedValue(1,%s)\n"%a
1667 t_prog+=" ref.setTaggedValue(1,%s)\n"%r1
1668 else:
1669 if args=="expandedData":
1670 t_prog+=" arg=Data(numarray.array(%s),self.functionspace,True)\n"%(a.tolist())
1671 t_prog+=" ref=Data(numarray.array(%s),self.functionspace,True)\n"%(r.tolist())
1672 else:
1673 t_prog+=" arg=Data(numarray.array(%s),self.functionspace)\n"%(a.tolist())
1674 t_prog+=" ref=Data(numarray.array(%s),self.functionspace)\n"%(r.tolist())
1675 t_prog+=" arg.setTaggedValue(1,%s)\n"%a1.tolist()
1676 t_prog+=" ref.setTaggedValue(1,%s)\n"%r1.tolist()
1677 t_prog+=" res=%s(%a1%)\n"%f.nickname
1678 t_prog+=" self.failUnlessEqual(res.getShape(),%s,\"wrong shape of result.\")\n"%str(sh)
1679 t_prog+=" self.failUnless(Lsup(res-ref)<=self.tol*Lsup(ref),\"wrong result\")\n"
1680 elif args=="Symbol":
1681 t_prog+=" arg=Symbol(shape=%s)\n"%str(sh)
1682 t_prog+=" v=%s(%a1%)\n"%f.nickname
1683 t_prog+=" self.failUnlessRaises(ValueError,v.substitute,Symbol(shape=(1,1)),\"illegal shape of substitute not identified.\")\n"
1684 a=makeArray(sh,f[RANGE])
1685 r=makeResult(a,f)
1686 if len(sh)==0:
1687 t_prog+=" res=v.substitute({arg : %s})\n"%a
1688 t_prog+=" ref=%s\n"%r
1689 t_prog+=" self.failUnless(isinstance(res,float),\"wrong type of result.\")\n"
1690 else:
1691 t_prog+=" res=v.substitute({arg : numarray.array(%s)})\n"%a.tolist()
1692 t_prog+=" ref=numarray.array(%s)\n"%r.tolist()
1693 t_prog+=" self.failUnlessEqual(res.getShape(),%s,\"wrong shape of substitution result.\")\n"%str(sh)
1694 t_prog+=" self.failUnless(Lsup(res-ref)<=self.tol*Lsup(ref),\"wrong result\")\n"
1695
1696 if len(sh)==0:
1697 t_prog+=" # test derivative with respect to itself:\n"
1698 t_prog+=" dvdv=v.diff(v)\n"
1699 t_prog+=" self.failUnlessEqual(dvdv,1.,\"derivative with respect to self is not 1.\")\n"
1700 elif len(sh)==1:
1701 t_prog+=" # test derivative with respect to itself:\n"
1702 t_prog+=" dvdv=v.diff(v)\n"
1703 t_prog+=" self.failUnlessEqual(dvdv.shape,%s,\"shape of derivative with respect is wrong\")\n"%str(sh+sh)
1704 for i0_l in range(sh[0]):
1705 for i0_r in range(sh[0]):
1706 if i0_l == i0_r:
1707 v=1.
1708 else:
1709 v=0.
1710 t_prog+=" self.failUnlessEqual(dvdv[%s,%s],%s,\"derivative with respect to self: [%s,%s] is not %s\")\n"%(i0_l,i0_r,v,i0_l,i0_r,v)
1711 elif len(sh)==2:
1712 t_prog+=" # test derivative with respect to itself:\n"
1713 t_prog+=" dvdv=v.diff(v)\n"
1714 t_prog+=" self.failUnlessEqual(dvdv.shape,%s,\"shape of derivative with respect is wrong\")\n"%str(sh+sh)
1715 for i0_l in range(sh[0]):
1716 for i0_r in range(sh[0]):
1717 for i1_l in range(sh[1]):
1718 for i1_r in range(sh[1]):
1719 if i0_l == i0_r and i1_l == i1_r:
1720 v=1.
1721 else:
1722 v=0.
1723 t_prog+=" self.failUnlessEqual(dvdv[%s,%s,%s,%s],%s,\"derivative with respect to self: [%s,%s,%s,%s] is not %s\")\n"%(i0_l,i1_l,i0_r,i1_r,v,i0_l,i1_l,i0_r,i1_r,v)
1724
1725 for sh_in in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
1726 if len(sh_in)+len(sh)<=4:
1727
1728 t_prog+=" # test derivative with shape %s as argument\n"%str(sh_in)
1729 trafo=makeArray(sh+sh_in,[0,1.])
1730 a_in=makeArray(sh_in,f[RANGE])
1731 t_prog+=" arg_in=Symbol(shape=%s)\n"%str(sh_in)
1732 t_prog+=" arg2=Symbol(shape=%s)\n"%str(sh)
1733
1734 if len(sh)==0:
1735 t_prog+=" arg2="
1736 if len(sh_in)==0:
1737 t_prog+="%s*arg_in\n"%trafo
1738 elif len(sh_in)==1:
1739 for i0 in range(sh_in[0]):
1740 if i0>0: t_prog+="+"
1741 t_prog+="%s*arg_in[%s]"%(trafo[i0],i0)
1742 t_prog+="\n"
1743 elif len(sh_in)==2:
1744 for i0 in range(sh_in[0]):
1745 for i1 in range(sh_in[1]):
1746 if i0+i1>0: t_prog+="+"
1747 t_prog+="%s*arg_in[%s,%s]"%(trafo[i0,i1],i0,i1)
1748
1749 elif len(sh_in)==3:
1750 for i0 in range(sh_in[0]):
1751 for i1 in range(sh_in[1]):
1752 for i2 in range(sh_in[2]):
1753 if i0+i1+i2>0: t_prog+="+"
1754 t_prog+="%s*arg_in[%s,%s,%s]"%(trafo[i0,i1,i2],i0,i1,i2)
1755 elif len(sh_in)==4:
1756 for i0 in range(sh_in[0]):
1757 for i1 in range(sh_in[1]):
1758 for i2 in range(sh_in[2]):
1759 for i3 in range(sh_in[3]):
1760 if i0+i1+i2+i3>0: t_prog+="+"
1761 t_prog+="%s*arg_in[%s,%s,%s,%s]"%(trafo[i0,i1,i2,i3],i0,i1,i2,i3)
1762 t_prog+="\n"
1763 elif len(sh)==1:
1764 for j0 in range(sh[0]):
1765 t_prog+=" arg2[%s]="%j0
1766 if len(sh_in)==0:
1767 t_prog+="%s*arg_in"%trafo[j0]
1768 elif len(sh_in)==1:
1769 for i0 in range(sh_in[0]):
1770 if i0>0: t_prog+="+"
1771 t_prog+="%s*arg_in[%s]"%(trafo[j0,i0],i0)
1772 elif len(sh_in)==2:
1773 for i0 in range(sh_in[0]):
1774 for i1 in range(sh_in[1]):
1775 if i0+i1>0: t_prog+="+"
1776 t_prog+="%s*arg_in[%s,%s]"%(trafo[j0,i0,i1],i0,i1)
1777 elif len(sh_in)==3:
1778 for i0 in range(sh_in[0]):
1779 for i1 in range(sh_in[1]):
1780 for i2 in range(sh_in[2]):
1781 if i0+i1+i2>0: t_prog+="+"
1782 t_prog+="%s*arg_in[%s,%s,%s]"%(trafo[j0,i0,i1,i2],i0,i1,i2)
1783 t_prog+="\n"
1784 elif len(sh)==2:
1785 for j0 in range(sh[0]):
1786 for j1 in range(sh[1]):
1787 t_prog+=" arg2[%s,%s]="%(j0,j1)
1788 if len(sh_in)==0:
1789 t_prog+="%s*arg_in"%trafo[j0,j1]
1790 elif len(sh_in)==1:
1791 for i0 in range(sh_in[0]):
1792 if i0>0: t_prog+="+"
1793 t_prog+="%s*arg_in[%s]"%(trafo[j0,j1,i0],i0)
1794 elif len(sh_in)==2:
1795 for i0 in range(sh_in[0]):
1796 for i1 in range(sh_in[1]):
1797 if i0+i1>0: t_prog+="+"
1798 t_prog+="%s*arg_in[%s,%s]"%(trafo[j0,j1,i0,i1],i0,i1)
1799 t_prog+="\n"
1800 elif len(sh)==3:
1801 for j0 in range(sh[0]):
1802 for j1 in range(sh[1]):
1803 for j2 in range(sh[2]):
1804 t_prog+=" arg2[%s,%s,%s]="%(j0,j1,j2)
1805 if len(sh_in)==0:
1806 t_prog+="%s*arg_in"%trafo[j0,j1,j2]
1807 elif len(sh_in)==1:
1808 for i0 in range(sh_in[0]):
1809 if i0>0: t_prog+="+"
1810 t_prog+="%s*arg_in[%s]"%(trafo[j0,j1,j2,i0],i0)
1811 t_prog+="\n"
1812 elif len(sh)==4:
1813 for j0 in range(sh[0]):
1814 for j1 in range(sh[1]):
1815 for j2 in range(sh[2]):
1816 for j3 in range(sh[3]):
1817 t_prog+=" arg2[%s,%s,%s,%s]="%(j0,j1,j2,j3)
1818 if len(sh_in)==0:
1819 t_prog+="%s*arg_in"%trafo[j0,j1,j2,j3]
1820 t_prog+="\n"
1821 t_prog+=" dvdin=v.substitute({arg : arg2}).diff(arg_in)\n"
1822 if len(sh_in)==0:
1823 t_prog+=" res_in=dvdin.substitute({arg_in : %s})\n"%a_in
1824 else:
1825 t_prog+=" res_in=dvdin.substitute({arg : numarray.array(%s)})\n"%a_in.tolist()
1826
1827 if len(sh)==0:
1828 if len(sh_in)==0:
1829 ref_diff=(makeResult(trafo*a_in+finc,f)-makeResult(trafo*a_in,f))/finc
1830 t_prog+=" self.failUnlessAlmostEqual(dvdin,%s,self.places,\"%s-derivative: wrong derivative\")\n"%(ref_diff,str(sh_in))
1831 elif len(sh_in)==1:
1832 s=0
1833 for k0 in range(sh_in[0]):
1834 s+=trafo[k0]*a_in[k0]
1835 for i0 in range(sh_in[0]):
1836 ref_diff=(makeResult(s+trafo[i0]*finc,f)-makeResult(s,f))/finc
1837 t_prog+=" self.failUnlessAlmostEqual(dvdin[%s],%s,self.places,\"%s-derivative: wrong derivative with respect of %s\")\n"%(i0,ref_diff,str(sh_in),str(i0))
1838 elif len(sh_in)==2:
1839 s=0
1840 for k0 in range(sh_in[0]):
1841 for k1 in range(sh_in[1]):
1842 s+=trafo[k0,k1]*a_in[k0,k1]
1843 for i0 in range(sh_in[0]):
1844 for i1 in range(sh_in[1]):
1845 ref_diff=(makeResult(s+trafo[i0,i1]*finc,f)-makeResult(s,f))/finc
1846 t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s],%s,self.places,\"%s-derivative: wrong derivative with respect of %s\")\n"%(i0,i1,ref_diff,str(sh_in),str((i0,i1)))
1847
1848 elif len(sh_in)==3:
1849 s=0
1850 for k0 in range(sh_in[0]):
1851 for k1 in range(sh_in[1]):
1852 for k2 in range(sh_in[2]):
1853 s+=trafo[k0,k1,k2]*a_in[k0,k1,k2]
1854 for i0 in range(sh_in[0]):
1855 for i1 in range(sh_in[1]):
1856 for i2 in range(sh_in[2]):
1857 ref_diff=(makeResult(s+trafo[i0,i1,i2]*finc,f)-makeResult(s,f))/finc
1858 t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative with respect of %s\")\n"%(i0,i1,i2,ref_diff,str(sh_in),str((i0,i1,i2)))
1859 elif len(sh_in)==4:
1860 s=0
1861 for k0 in range(sh_in[0]):
1862 for k1 in range(sh_in[1]):
1863 for k2 in range(sh_in[2]):
1864 for k3 in range(sh_in[3]):
1865 s+=trafo[k0,k1,k2,k3]*a_in[k0,k1,k2,k3]
1866 for i0 in range(sh_in[0]):
1867 for i1 in range(sh_in[1]):
1868 for i2 in range(sh_in[2]):
1869 for i3 in range(sh_in[3]):
1870 ref_diff=(makeResult(s+trafo[i0,i1,i2,i3]*finc,f)-makeResult(s,f))/finc
1871 t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative with respect of %s\")\n"%(i0,i1,i2,i3,ref_diff,str(sh_in),str((i0,i1,i2,i3)))
1872 elif len(sh)==1:
1873 for j0 in range(sh[0]):
1874 if len(sh_in)==0:
1875 ref_diff=(makeResult(trafo[j0]*a_in+finc,f)-makeResult(trafo[j0]*a_in,f))/finc
1876 t_prog+=" self.failUnlessAlmostEqual(dvdin[%s],%s,self.places,\"%s-derivative: wrong derivative of %s\")\n"%(j0,ref_diff,str(sh_in),j0)
1877 elif len(sh_in)==1:
1878 s=0
1879 for k0 in range(sh_in[0]):
1880 s+=trafo[j0,k0]*a_in[k0]
1881 for i0 in range(sh_in[0]):
1882 ref_diff=(makeResult(s+trafo[j0,i0]*finc,f)-makeResult(s,f))/finc
1883 t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s with respect of %s\")\n"%(j0,i0,ref_diff,str(sh_in),str(j0),str(i0))
1884 elif len(sh_in)==2:
1885 s=0
1886 for k0 in range(sh_in[0]):
1887 for k1 in range(sh_in[1]):
1888 s+=trafo[j0,k0,k1]*a_in[k0,k1]
1889 for i0 in range(sh_in[0]):
1890 for i1 in range(sh_in[1]):
1891 ref_diff=(makeResult(s+trafo[j0,i0,i1]*finc,f)-makeResult(s,f))/finc
1892 t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s with respect of %s\")\n"%(j0,i0,i1,ref_diff,str(sh_in),str(j0),str((i0,i1)))
1893
1894 elif len(sh_in)==3:
1895
1896 s=0
1897 for k0 in range(sh_in[0]):
1898 for k1 in range(sh_in[1]):
1899 for k2 in range(sh_in[2]):
1900 s+=trafo[j0,k0,k1,k2]*a_in[k0,k1,k2]
1901
1902 for i0 in range(sh_in[0]):
1903 for i1 in range(sh_in[1]):
1904 for i2 in range(sh_in[2]):
1905 ref_diff=(makeResult(s+trafo[j0,i0,i1,i2]*finc,f)-makeResult(s,f))/finc
1906 t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s with respect of %s\")\n"%(j0,i0,i1,i2,ref_diff,str(sh_in),str(j0),str((i0,i1,i2)))
1907 elif len(sh)==2:
1908 for j0 in range(sh[0]):
1909 for j1 in range(sh[1]):
1910 if len(sh_in)==0:
1911 ref_diff=(makeResult(trafo[j0,j1]*a_in+finc,f)-makeResult(trafo[j0,j1]*a_in,f))/finc
1912 t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s\")\n"%(j0,j1,ref_diff,str(sh_in),str((j0,j1)))
1913 elif len(sh_in)==1:
1914 s=0
1915 for k0 in range(sh_in[0]):
1916 s+=trafo[j0,j1,k0]*a_in[k0]
1917 for i0 in range(sh_in[0]):
1918 ref_diff=(makeResult(s+trafo[j0,j1,i0]*finc,f)-makeResult(s,f))/finc
1919 t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s with respect of %s\")\n"%(j0,j1,i0,ref_diff,str(sh_in),str((j0,j1)),str(i0))
1920
1921 elif len(sh_in)==2:
1922 s=0
1923 for k0 in range(sh_in[0]):
1924 for k1 in range(sh_in[1]):
1925 s+=trafo[j0,j1,k0,k1]*a_in[k0,k1]
1926 for i0 in range(sh_in[0]):
1927 for i1 in range(sh_in[1]):
1928 ref_diff=(makeResult(s+trafo[j0,j1,i0,i1]*finc,f)-makeResult(s,f))/finc
1929 t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s with respect of %s\")\n"%(j0,j1,i0,i1,ref_diff,str(sh_in),str((j0,j1)),str((i0,i1)))
1930 elif len(sh)==3:
1931 for j0 in range(sh[0]):
1932 for j1 in range(sh[1]):
1933 for j2 in range(sh[2]):
1934 if len(sh_in)==0:
1935 ref_diff=(makeResult(trafo[j0,j1,j2]*a_in+finc,f)-makeResult(trafo[j0,j1,j2]*a_in,f))/finc
1936 t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s\")\n"%(j0,j1,j2,ref_diff,str(sh_in),str((j0,j1,j2)))
1937 elif len(sh_in)==1:
1938 s=0
1939 for k0 in range(sh_in[0]):
1940 s+=trafo[j0,j1,j2,k0]*a_in[k0]
1941 for i0 in range(sh_in[0]):
1942 ref_diff=(makeResult(s+trafo[j0,j1,j2,i0]*finc,f)-makeResult(s,f))/finc
1943 t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s with respect of %s\")\n"%(j0,j1,j2,i0,ref_diff,str(sh_in),str((j0,j1,j2)),i0)
1944 elif len(sh)==4:
1945 for j0 in range(sh[0]):
1946 for j1 in range(sh[1]):
1947 for j2 in range(sh[2]):
1948 for j3 in range(sh[3]):
1949 if len(sh_in)==0:
1950 ref_diff=(makeResult(trafo[j0,j1,j2,j3]*a_in+finc,f)-makeResult(trafo[j0,j1,j2,j3]*a_in,f))/finc
1951 t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s\")\n"%(j0,j1,j2,j3,ref_diff,str(sh_in),str((j0,j1,j2,j3)))
1952
1953 #
1954
#==================
# Generate tests for the rank-specific symbol factories ScalarSymbol,
# VectorSymbol, TensorSymbol, Tensor3Symbol and Tensor4Symbol.
# The index of a name in `cases` is the rank used for the expected shape.
cases=["Scalar","Vector","Tensor", "Tensor3","Tensor4"]

for case, factory in enumerate(cases):
    for d in [None, "d", 1, 2, 3]:
        # A symbol without spatial dimension (d=None) is only generated for Scalar.
        if d is None and factory!="Scalar":
            continue
        t_prog+=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
        tname="test_Symbol_%s_d%s"%(factory,d)
        t_prog+=" def %s(self):\n"%tname
        if d=="d":
            # The dimension is taken from self.functionspace; the generated test
            # stores it in a local variable that is literally named d.
            t_prog+=" s=%sSymbol(dim=self.functionspace)\n"%(factory)
            t_prog+=" d=self.functionspace.getDim()\n"
            # Source text of the expected shape tuple: "()", "(d,)", "(d,d,)", ...
            # Every entry carries a trailing comma so that rank 1 yields the tuple
            # "(d,)" and not the parenthesised scalar "(d)".
            sh="("+(d+",")*case+")"
        else:
            t_prog+=" s=%sSymbol(dim=%s)\n"%(factory,d)
            sh=(d,)*case
        t_prog+=" self.failUnlessEqual(s.getRank(),%s,\"wrong rank.\")\n"%case
        t_prog+=" self.failUnlessEqual(s.getShape(),%s,\"wrong shape.\")\n"%str(sh)
        t_prog+=" self.failUnlessEqual(s.getDim(),%s,\"wrong spatial dimension.\")\n"%d
        t_prog+=" self.failUnlessEqual(s.getArgument(),[],\"wrong arguments.\")\n"

print(t_prog)
# Stop the generator here on purpose: nothing after this point in the file runs.
# The exit status stays non-zero, but the reason is stated explicitly.
raise SystemExit("generator stopped after the Symbol factory tests")
# Generate tests for the generic Symbol class. Each test covers:
#   - the rank/shape/dim/argument accessors;
#   - substitution;
#   - derivatives;
#   - isAppropriateValue.
# A test is generated for every shape in the list, each way of giving the
# spatial dimension (d) and each argument list (args).
for sh in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
    for d in [ None , "domain", 1, 2 , 3]:
        for args in [ [], ["s2"], [1,-1.] ]:
            t_prog+=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
            tname="def test_Symbol_rank%s_d%s_nargs%s"%(len(sh),d,len(args))
            t_prog+=" %s(self):\n"%tname
            t_prog+=" s2=ScalarSymbol()\n"
            # Source text of the argument list; "s2" refers to the symbol created above.
            if args==["s2"]:
                a="[s2]"
            else:
                a=str(args)
            # d2 is the source text of the expected spatial dimension in the generated test.
            if d=="domain":
                t_prog+=" s=Symbol(shape=%s,dim=self.functionspace.getDim(),args=%s)\n"%(str(sh),a)
                d2="self.functionspace.getDim()"
            else:
                t_prog+=" s=Symbol(shape=%s,dim=%s,args=%s)\n"%(sh,d,a)
                d2=str(d)

            # accessors
            t_prog+=" self.failUnlessEqual(s.getRank(),%s,\"wrong rank.\")\n"%len(sh)
            t_prog+=" self.failUnlessEqual(s.getShape(),%s,\"wrong shape.\")\n"%str(sh)
            t_prog+=" self.failUnlessEqual(s.getDim(),%s,\"wrong spatial dimension.\")\n"%d2
            t_prog+=" self.failUnlessEqual(s.getArgument(),%s,\"wrong arguments.\")\n\n"%a

            # substitution by a value of matching shape, then by one of a different shape
            t_prog+=" ss=s.substitute({s:numarray.zeros(%s)})\n"%str(sh)
            t_prog+=" self.failUnless(isinstance(ss,numarray.NumArray),\"value after substitution is not numarray.\")\n"
            t_prog+=" self.failUnlessEqual(ss.shape,%s,\"value after substitution has not expected shape\")\n"%str(sh)
            # Inside a TestCase method the failure hook is the method self.fail.
            t_prog+=" try:\n s.substitute({s:numarray.zeros((5,))})\n self.fail(\"illegal substition was successful\")\n"
            t_prog+=" except TypeError:\n pass\n\n"

            # The generated test expects the derivative with respect to another symbol to be zero.
            for sh2 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
                if len(sh+sh2)>=5:
                    continue
                t_prog+=" dsdarg=s.diff(Symbol(shape=%s))\n"%str(sh2)
                if len(sh+sh2)==0:
                    t_prog+=" self.failUnless(isinstance(dsdarg,float),\"ds/ds() has wrong type.\")\n"
                else:
                    t_prog+=" self.failUnless(isinstance(dsdarg,numarray.NumArray),\"ds/ds%s has wrong type.\")\n"%str(sh2)
                    # only the array result has a shape; the float result does not
                    t_prog+=" self.failUnlessEqual(dsdarg.shape,%s,\"ds/ds%s has wrong shape.\")\n"%(str(sh+sh2),str(sh2))
                t_prog+=" self.failIf(Lsup(dsdarg)>0.,\"ds/ds%s has wrong value.\")\n"%str(sh2)

            # Derivative with respect to itself: 1. for rank 0, otherwise the identity
            # (1. where left and right indices agree, 0. elsewhere).
            if len(sh)<3:
                t_prog+="\n dsds=s.diff(s)\n"
                if len(sh)==0:
                    t_prog+=" self.failUnless(isinstance(dsds,float),\"ds/ds has wrong type.\")\n"
                    t_prog+=" self.failUnlessEqual(dsds,1.,\"ds/ds has wrong value.\")\n"
                else:
                    t_prog+=" self.failUnless(isinstance(dsds,numarray.NumArray),\"ds/ds has wrong type.\")\n"
                    t_prog+=" self.failUnlessEqual(dsds.shape,%s,\"ds/ds has wrong shape.\")\n"%str(sh+sh)
                    if len(sh)==1:
                        for i0 in range(sh[0]):
                            for i2 in range(sh[0]):
                                if i0==i2:
                                    v=1.
                                else:
                                    v=0.
                                t_prog+=" self.failUnlessEqual(dsds[%s,%s],%s,\"ds/ds has wrong value at (%s,%s).\")\n"%(i0,i2,v,i0,i2)
                    else:
                        for i0 in range(sh[0]):
                            for i1 in range(sh[1]):
                                for i2 in range(sh[0]):
                                    for i3 in range(sh[1]):
                                        if i0==i2 and i1==i3:
                                            v=1.
                                        else:
                                            v=0.
                                        t_prog+=" self.failUnlessEqual(dsds[%s,%s,%s,%s],%s,\"ds/ds has wrong value at (%s,%s,%s,%s).\")\n"%(i0,i1,i2,i3,v,i0,i1,i2,i3)

            # arguments and their substitution (s2 -> -10)
            t_prog+="\n"
            for i, value in enumerate(args):
                t_prog+=" self.failUnlessEqual(s.getArgument(%s),%s,\"wrong argument %s.\")\n"%(i,str(value),i)
            t_prog+=" sa=s.getSubstitutedArguments({s2:-10})\n"
            t_prog+=" self.failUnlessEqual(len(sa),%s,\"wrong number of substituted arguments\")\n"%len(args)
            if args==["s2"]:
                t_prog+=" self.failUnlessEqual(sa[0],-10,\"wrongly substituted argument 0.\")\n"
            else:
                for i, value in enumerate(args):
                    t_prog+=" self.failUnlessEqual(sa[%s],%s,\"wrongly substituted argument %s.\")\n"%(i,str(value),i)

            # isAppropriateValue for matching and mismatching candidate substitutes
            t_prog+="\n"
            for arg in ["10.", "10", "SymbolMatch", "SymbolMisMatch", \
                        "DataMatch","DataMisMatch", "NumArrayMatch", "NumArrayMisMatch"]:
                if arg in ["10.", "10"]:
                    a=str(arg)
                    if len(sh)==0:
                        t_prog+=" self.failUnless(s.isAppropriateValue(%s),\"%s is appropriate substitute\")\n"%(a,arg)
                    else:
                        t_prog+=" self.failIf(s.isAppropriateValue(%s),\" %s is not appropriate substitute\")\n"%(a,arg)
                elif arg=="SymbolMatch":
                    # d2, not d: for d=="domain" the generated test has no variable named domain
                    t_prog+=" self.failUnless(s.isAppropriateValue(Symbol(shape=%s,dim=%s)),\"Symbol is appropriate substitute\")\n"%(str(sh),d2)
                elif arg=="SymbolMisMatch":
                    if isinstance(d,int):
                        t_prog+=" self.failIf(s.isAppropriateValue(Symbol(shape=%s,dim=%s)),\"Symbol is not appropriate substitute (dim)\")\n"%(str(sh),d+1)
                    else:
                        # str() is required: "%s"%(5,) would format the bare element 5
                        t_prog+=" self.failIf(s.isAppropriateValue(Symbol(shape=%s)),\"Symbol is not appropriate substitute (shape)\")\n"%str((5,))
                elif arg=="DataMatch":
                    if d=="domain":
                        t_prog+=" self.failUnless(s.isAppropriateValue(escript.Data(0.,%s,self.functionspace)),\"Data is appropriate substitute\")\n"%str(sh)
                elif arg=="DataMisMatch":
                    t_prog+=" self.failIf(s.isAppropriateValue(escript.Data(0.,%s,self.functionspace)),\"Data is not appropriate substitute (shape)\")\n"%(str((5,)))
                elif arg=="NumArrayMatch":
                    t_prog+=" self.failUnless(s.isAppropriateValue(numarray.zeros(%s)),\"NumArray is appropriate substitute\")\n"%str(sh)
                else:
                    t_prog+=" self.failIf(s.isAppropriateValue(numarray.zeros(%s)),\"NumArray is not appropriate substitute (shape)\")\n"%(str((5,)))

print(t_prog)
# Stop the generator here on purpose: nothing after this point in the file runs.
# The exit status stays non-zero, but the reason is stated explicitly.
raise SystemExit("generator stopped after the generic Symbol tests")
2094
2095
# Generate tests for the reductions Lsup, sup and inf. The argument kinds are:
#   - a float;
#   - a numarray array;
#   - constant, tagged and expanded Data objects.
# Tests are generated for every rank in the shape list; floats are rank 0 only.
for case in ["Lsup", "sup", "inf"]:
    for args in ["float","array","constData","taggedData","expandedData"]:
        for sh in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
            # a plain float argument exists only for rank 0
            if args=="float" and len(sh)>0:
                continue
            t_prog+=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
            tname="def test_%s_%s_rank%s"%(case,args,len(sh))
            t_prog+=" %s(self):\n"%tname
            a=makeArray(sh,[-1,1])
            r=makeResult2(a,case)
            if args in ["float","array","constData"]:
                if args=="constData":
                    if len(sh)==0:
                        t_prog+=" arg=Data(%s,self.functionspace)\n"%(a)
                    else:
                        t_prog+=" arg=Data(numarray.array(%s),self.functionspace)\n"%(a.tolist())
                elif len(sh)==0:
                    t_prog+=" arg=%s\n"%a
                else:
                    t_prog+=" arg=numarray.array(%s)\n"%a.tolist()
                t_prog+=" ref=%s\n"%r
                # Emit the argument name directly. A literal "%a1%" in a %-format
                # string is parsed as a conversion and makes the formatting raise.
                t_prog+=" res=%s(arg)\n"%case
            else:
                # taggedData / expandedData: value a by default and value a1 for
                # tag 1, so the reference combines the reductions of both values.
                a1=makeArray(sh,[-1,1])
                r1=makeResult2(a1,case)
                if case in ["Lsup","sup"]:
                    r=max(r,r1)
                else:
                    r=min(r,r1)
                if len(sh)==0:
                    if args=="expandedData":
                        t_prog+=" arg=Data(%s,self.functionspace,True)\n"%(a)
                    else:
                        t_prog+=" arg=Data(%s,self.functionspace)\n"%(a)
                    # the tag-1 value must be a1 (not a) for ref above to be right
                    t_prog+=" arg.setTaggedValue(1,%s)\n"%a1
                else:
                    if args=="expandedData":
                        t_prog+=" arg=Data(numarray.array(%s),self.functionspace,True)\n"%(a.tolist())
                    else:
                        t_prog+=" arg=Data(numarray.array(%s),self.functionspace)\n"%(a.tolist())
                    t_prog+=" arg.setTaggedValue(1,%s)\n"%a1.tolist()
                t_prog+=" res=%s(arg)\n"%case
                t_prog+=" ref=%s\n"%r
            t_prog+=" self.failUnless(isinstance(res,float),\"wrong type of result.\")\n"
            t_prog+=" self.failUnless(abs(res-ref)<=self.tol*abs(ref),\"wrong result\")\n"

print(t_prog)
# Stop the generator here on purpose: nothing after this point in the file runs.
# The exit status stays non-zero, but the reason is stated explicitly.
raise SystemExit("generator stopped after the Lsup/sup/inf tests")
2154
2155
2156 for case in ["Lsup", "sup", "inf"]:
2157 for args in ["float","array","constData","taggedData","expandedData"]:
2158 for sh in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
2159 if not args=="float" or len(sh)==0:
2160 t_prog+=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
2161 tname="def test_%s_%s_rank%s"%(case,args,len(sh))
2162 t_prog+=" %s(self):\n"%tname
2163 if args in ["float","array" ]:
2164 a=makeArray(sh,[-1,1])
2165 r=makeResult2(a,case)
2166 if len(sh)==0:
2167 t_prog+=" arg=%s\n"%a
2168 else:
2169 t_prog+=" arg=numarray.array(%s)\n"%a.tolist()
2170 t_prog+=" ref=%s\n"%r
2171 t_prog+=" res=%s(%a1%)\n"%case
2172 t_prog+=" self.failUnless(isinstance(res,float),\"wrong type of result.\")\n"
2173 t_prog+=" self.failUnless(abs(res-ref)<=self.tol*abs(ref),\"wrong result\")\n"
2174 elif args== "constData":
2175 a=makeArray(sh,[-1,1])
2176 r=makeResult2(a,case)
2177 if len(sh)==0:
2178 t_prog+=" arg=Data(%s,self.functionspace)\n"%(a)
2179 else:
2180 t_prog+=" arg=Data(numarray.array(%s),self.functionspace)\n"%(a.tolist())
2181 t_prog+=" ref=%s\n"%r
2182 t_prog+=" res=%s(%a1%)\n"%case
2183 t_prog+=" self.failUnless(isinstance(res,float),\"wrong type of result.\")\n"
2184 t_prog+=" self.failUnless(abs(res-ref)<=self.tol*abs(ref),\"wrong result\")\n"
2185 elif args in [ "taggedData","expandedData"]:
2186 a=makeArray(sh,[-1,1])
2187 r=makeResult2(a,case)
2188 a1=makeArray(sh,[-1,1])
2189 r1=makeResult2(a1,case)
2190 if case in ["Lsup","sup"]:
2191 r=max(r,r1)
2192 else:
2193 r=min(r,r1)
2194 if len(sh)==0:
2195 if args=="expandedData":
2196 t_prog+=" arg=Data(%s,self.functionspace,True)\n"%(a)
2197 else:
2198 t_prog+=" arg=Data(%s,self.functionspace)\n"%(a)
2199 t_prog+=" arg.setTaggedValue(1,%s)\n"%a
2200 else:
2201 if args=="expandedData":
2202 t_prog+=" arg=Da