Parent Directory
|
Revision Log
and a bit more tests
#!/usr/bin/python
# $Id$

"""
Generator script: produces parts of util.py and of the test_util.py script.
"""

# Source text written ahead of the generated test methods.
# NOTE(review): the leading whitespace inside these literals looks collapsed
# (one blank where Python indentation would be expected) - confirm against
# the repository copy before relying on the generated output.
test_header="".join([
    "import unittest\n",
    "import numarray\n",
    "from esys.escript import *\n",
    "from esys.finley import Rectangle\n",
    "class Test_util2(unittest.TestCase):\n",
    " RES_TOL=1.e-7\n",
    " def setUp(self):\n",
    " self.__dom =Rectangle(11,11,2)\n",
    " self.functionspace = FunctionOnBoundary(self.__dom)\n",
])

# Source text written after the generated test methods.
test_tail="".join([
    "suite = unittest.TestSuite()\n",
    "suite.addTest(unittest.makeSuite(Test_util2))\n",
    "unittest.TextTestRunner(verbosity=2).run(suite)\n",
])

# Kinds of argument and argument shapes the generated tests iterate over.
case_set=["float","array","constData","taggedData","expandedData","Symbol"]
shape_set=[ (),(2,), (4,5), (6,2,2),(3,2,3,4)]

# Accumulators for generated program text; all start out empty.
t_prog=""
t_prog_with_tags=""
t_prog_failing=""
u_prog=""
29 | jgs | 154 | |
def wherepos(arg):
    """Return 1. if arg is strictly positive and 0. otherwise."""
    if not arg>0.:
        return 0.
    return 1.
35 | |||
36 | gross | 313 | |
class OPERATOR:
    """
    Plain record describing one operator handled by the generator.

    The constructor only stores its arguments as attributes of the same name.
    """
    def __init__(self,nickname,rng=None,test_expr="",math_expr=None,
                 numarray_expr="",symbol_expr=None,diff=None,name=""):
        """
        @param nickname: stored as C{self.nickname}
        @param rng: two-element range; a fresh C{[-1000.,1000]} is used when
                    omitted so that instances never share one list object
        @param test_expr: expression template using the placeholder C{%a1%}
        @param math_expr: stored as given; when C{None} it is derived from
                          C{test_expr} by replacing C{%a1%} with C{arg}
        @param numarray_expr: stored as C{self.numarray_expr}
        @param symbol_expr: stored as C{self.symbol_expr}
        @param diff: stored as C{self.diff}
        @param name: stored as C{self.name}
        """
        if rng is None:
            rng=[-1000.,1000]
        self.nickname=nickname
        self.rng=rng
        self.test_expr=test_expr
        if math_expr is None:
            self.math_expr=test_expr.replace("%a1%","arg")
        else:
            self.math_expr=math_expr
        self.numarray_expr=numarray_expr
        self.symbol_expr=symbol_expr
        self.diff=diff
        self.name=name
51 | |||
52 | jgs | 154 | import random |
53 | import numarray | ||
54 | import math | ||
55 | finc=1.e-6 | ||
56 | |||
# Argument kinds ordered from weakest to strongest: when two kinds meet in a
# binary operation the result takes the kind that appears later in this list.
_RESULT_CASE_PRECEDENCE=["float","array","constData","taggedData","expandedData","Symbol"]

def getResultCaseForBin(case0,case1):
    """
    Return the kind of the result of a binary operation.

    @param case0: kind of the first argument
    @param case1: kind of the second argument
    @return: whichever of the two kinds comes later in the order
             float, array, constData, taggedData, expandedData, Symbol
    @raise ValueError: if a kind is unknown; C{case0} is checked first
    """
    if case0 not in _RESULT_CASE_PRECEDENCE:
        raise ValueError("unknown case0=%s"%case0)
    if case1 not in _RESULT_CASE_PRECEDENCE:
        raise ValueError("unknown case1=%s"%case1)
    strongest=max(_RESULT_CASE_PRECEDENCE.index(case0),
                  _RESULT_CASE_PRECEDENCE.index(case1))
    return _RESULT_CASE_PRECEDENCE[strongest]
151 | jgs | 154 | |
152 | gross | 157 | |
def makeArray(shape,rng):
    """
    Create random test data of the given shape.

    @param shape: tuple of extents; at most four entries
    @param rng: pair (lower, upper); each value is computed as
                C{(upper-lower)*random.random()+lower}
    @return: a single number when C{shape} is empty, otherwise a
             C{numarray} Float64 array of that shape
    @raise SystemError: if C{shape} has more than four entries
    """
    def indices(extents):
        # yields all index tuples, last axis running fastest
        total=1
        for n in extents:
            total*=n
        for flat in range(total):
            index=[]
            rest=flat
            for n in extents[::-1]:
                index.insert(0,rest%n)
                rest//=n
            yield tuple(index)

    l=rng[1]-rng[0]
    out=numarray.zeros(shape,numarray.Float64)
    if len(shape)>4:
        raise SystemError("rank is restricted to 4")
    if len(shape)==0:
        # an empty shape yields a plain number rather than an array
        return l*random.random()+rng[0]
    for index in indices(shape):
        out[index]=l*random.random()+rng[0]
    return out
179 | |||
180 | |||
def makeResult(val,test_expr):
    """
    Apply an expression to a value or to every entry of an array.

    @param val: a float or an object with a C{shape} attribute (rank <= 4)
    @param test_expr: Python expression in which C{%a1%} stands for the
                      entry being processed; it is passed to C{eval}, so it
                      must come from a trusted source
    @return: the evaluated expression for a float or rank-0 value, otherwise
             a C{numarray} Float64 array of the shape of C{val}
    @raise SystemError: if the rank of C{val} exceeds four
    """
    def indices(extents):
        # yields all index tuples, last axis running fastest
        total=1
        for n in extents:
            total*=n
        for flat in range(total):
            index=[]
            rest=flat
            for n in extents[::-1]:
                index.insert(0,rest%n)
                rest//=n
            yield tuple(index)

    if isinstance(val,float):
        rank=0
    else:
        rank=len(val.shape)
    if rank>4:
        raise SystemError("rank is restricted to 4")
    # the placeholder is bound to the local name 'item'; compiling once
    # avoids rebuilding the expression text for every entry
    expr=compile(test_expr.replace("%a1%","item"),"<test_expr>","eval")
    if rank==0:
        item=val
        return eval(expr)
    out=numarray.zeros(val.shape,numarray.Float64)
    for index in indices(val.shape):
        item=val[index]
        out[index]=eval(expr)
    return out
211 | |||
def makeResult2(val0,val1,test_expr):
    """
    Apply a binary expression entry by entry to two values.

    The argument of lower rank is indexed by the leading indices of the
    argument of higher rank: with ranks r0 and r1 the entry at index C{i}
    of the result is computed from C{val0[i[:r0]]} and C{val1[i[:r1]]}.
    The result has the shape of C{val0}, extended by C{val1.shape[r0:]}
    when C{val1} has the higher rank.

    @param val0: a float or an object with a C{shape} attribute (rank <= 4)
    @param val1: a float or an object with a C{shape} attribute (rank <= 4)
    @param test_expr: Python expression in which C{%a1%} and C{%a2%} stand
                      for the entries of C{val0} and C{val1}; it is passed
                      to C{eval}, so it must come from a trusted source
    @return: the evaluated expression if both arguments are floats or of
             rank 0, otherwise a C{numarray} Float64 array
    @raise SystemError: if the rank of either argument exceeds four
    """
    def indices(extents):
        # yields all index tuples, last axis running fastest
        total=1
        for n in extents:
            total*=n
        for flat in range(total):
            index=[]
            rest=flat
            for n in extents[::-1]:
                index.insert(0,rest%n)
                rest//=n
            yield tuple(index)

    if isinstance(val0,float):
        rank0=0
    else:
        rank0=len(val0.shape)
    if rank0>4:
        raise SystemError("rank is restricted to 4")
    if isinstance(val1,float):
        rank1=0
    else:
        rank1=len(val1.shape)
    if rank1>4:
        raise SystemError("rank of val1 is restricted to 4")

    # the placeholders are bound to the local names 'lhs' and 'rhs'
    expr=compile(test_expr.replace("%a1%","lhs").replace("%a2%","rhs"),
                 "<test_expr>","eval")
    if rank0==0 and rank1==0:
        lhs,rhs=val0,val1
        return eval(expr)

    if rank1<=rank0:
        out_shape=val0.shape
    elif rank0==0:
        out_shape=val1.shape
    else:
        out_shape=val0.shape+val1.shape[rank0:]
    out=numarray.zeros(out_shape,numarray.Float64)
    for index in indices(out_shape):
        if rank0==0:
            lhs=val0
        else:
            lhs=val0[index[:rank0]]
        if rank1==0:
            rhs=val1
        else:
            rhs=val1[index[:rank1]]
        out[index]=eval(expr)
    return out
428 | jgs | 154 | |
429 | gross | 157 | |
def mkText(case,name,a,a1=None,use_tagging_for_expanded_data=False):
    """
    Return the source lines which define the variable C{name} in a test.

    @param case: kind of object to create: "float", "array", "constData",
                 "taggedData", "expandedData" or "Symbol"; any other value
                 gives an empty string
    @param name: name of the variable in the generated code
    @param a: a float or an object providing C{rank}, C{shape} and
              C{tolist()}; its value is written into the generated code
    @param a1: second value; written as the value for tag 1 for tagged data
               and as the value used where the mask is zero for expanded data
    @param use_tagging_for_expanded_data: if set, expanded data is produced
               from tagged data followed by C{expand()} instead of a mask
    @return: the generated source text
    """
    def is_scalar():
        # evaluated on demand so that a.rank is only read for known cases
        return isinstance(a,float) or a.rank==0

    def tagged_source():
        if is_scalar():
            head=" %s=Data(%s,self.functionspace)\n"%(name,a)
            return head+" %s.setTaggedValue(1,%s)\n"%(name,a1)
        head=" %s=Data(numarray.array(%s),self.functionspace)\n"%(name,a.tolist())
        return head+" %s.setTaggedValue(1,numarray.array(%s))\n"%(name,a1.tolist())

    if case=="float":
        if is_scalar():
            return " %s=%s\n"%(name,a)
        return " %s=numarray.array(%s)\n"%(name,a.tolist())
    if case=="array":
        if is_scalar():
            return " %s=numarray.array(%s)\n"%(name,a)
        return " %s=numarray.array(%s)\n"%(name,a.tolist())
    if case=="constData":
        if is_scalar():
            return " %s=Data(%s,self.functionspace)\n"%(name,a)
        return " %s=Data(numarray.array(%s),self.functionspace)\n"%(name,a.tolist())
    if case=="taggedData":
        return tagged_source()
    if case=="expandedData":
        if use_tagging_for_expanded_data:
            return tagged_source()+" %s.expand()\n"%name
        mask=" msk_%s=whereNegative(self.functionspace.getX()[0]-0.5)\n"%name
        # unlike the other cases a float and a rank-0 array differ here
        if isinstance(a,float):
            return mask+" %s=msk_%s*(%s)+(1.-msk_%s)*(%s)\n"%(name,name,a,name,a1)
        if a.rank==0:
            return mask+" %s=msk_%s*numarray.array(%s)+(1.-msk_%s)*numarray.array(%s)\n"%(name,name,a,name,a1)
        return mask+" %s=msk_%s*numarray.array(%s)+(1.-msk_%s)*numarray.array(%s)\n"%(name,name,a.tolist(),name,a1.tolist())
    if case=="Symbol":
        if is_scalar():
            return " %s=Symbol(shape=())\n"%(name)
        return " %s=Symbol(shape=%s)\n"%(name,str(a.shape))
    return ""
492 | |||
def mkTypeAndShapeTest(case,sh,argstr):
    """
    Return the source lines which check type and shape of a result.

    @param case: expected kind of the result; "float" gets a type check
                 only, "array", the three Data kinds and "Symbol" get a type
                 and a shape check, anything else gives an empty string
    @param sh: expected shape, written via C{str(sh)}
    @param argstr: name of the variable to be checked in the generated code
    @return: the generated source text
    """
    if case=="float":
        return " self.failUnless(isinstance(%s,float),\"wrong type of result.\")\n"%argstr
    if case=="array":
        type_check=" self.failUnless(isinstance(%s,numarray.NumArray),\"wrong type of result.\")\n"
        shape_check=" self.failUnlessEqual(%s.shape,%s,\"wrong shape of result.\")\n"
    elif case in ["constData","taggedData","expandedData"]:
        type_check=" self.failUnless(isinstance(%s,Data),\"wrong type of result.\")\n"
        shape_check=" self.failUnlessEqual(%s.getShape(),%s,\"wrong shape of result.\")\n"
    elif case=="Symbol":
        type_check=" self.failUnless(isinstance(%s,Symbol),\"wrong type of result.\")\n"
        shape_check=" self.failUnlessEqual(%s.getShape(),%s,\"wrong shape of result.\")\n"
    else:
        return ""
    return type_check%argstr+shape_check%(argstr,str(sh))
507 | jgs | 154 | |
def mkCode(txt,args=(),intend=""):
    """
    Turn an expression or a code fragment into indented source lines.

    A single-line C{txt} becomes a C{return} statement; a multi-line C{txt}
    is copied line by line. Afterwards the placeholders C{%a1%}, C{%a2%},
    ... are replaced by the entries of C{args} in that order.

    @param txt: expression (one line) or code fragment (several lines)
    @param args: replacement strings for the numbered placeholders
    @param intend: prefix put in front of every generated line
    @return: the generated source text
    """
    lines=txt.split("\n")
    if len(lines)>1:
        out="".join([intend+line+"\n" for line in lines])
    else:
        out="%sreturn %s\n"%(intend,txt)
    # placeholders are numbered from 1; the counter has to advance with
    # every argument, otherwise only %a1% would ever be replaced
    for position,replacement in enumerate(args):
        out=out.replace("%%a%s%%"%(position+1),replacement)
    return out
520 | jgs | 154 | |
def innerTEST(arg0,arg1):
    """
    Reference value for the inner product.

    @return: C{arg0*arg1} wrapped by C{numarray.array} if C{arg0} is a
             float, otherwise the sum over all entries of C{arg0*arg1}
    """
    product=arg0*arg1
    if isinstance(arg0,float):
        return numarray.array(product)
    return product.sum()
527 | |||
def outerTEST(arg0,arg1):
    """
    Reference value for the outer product.

    @return: C{arg0*arg1} wrapped by C{numarray.array} if either argument
             is a float, otherwise C{numarray.outerproduct(arg0,arg1)}
             resized to the shape C{arg0.shape+arg1.shape}
    """
    if isinstance(arg0,float) or isinstance(arg1,float):
        return numarray.array(arg0*arg1)
    return numarray.outerproduct(arg0,arg1).resize(arg0.shape+arg1.shape)
536 | |||
def tensorProductTest(arg0,arg1,sh_s):
    """
    Reference value for the general tensor product.

    @param arg0: first argument; a float or an array whose trailing
                 C{len(sh_s)} axes are the ones summed over
    @param arg1: second argument; a float or an array whose leading
                 C{len(sh_s)} axes are the ones summed over
    @param sh_s: shape of the axes summed over
    @return: C{arg0*arg1} as array if one argument is a float, the outer
             product if C{sh_s} is empty, otherwise the outer product
             summed over the paired axes
    """
    if isinstance(arg0,float) or isinstance(arg1,float):
        return numarray.array(arg0*arg1)
    offset=len(sh_s)
    if offset==0:
        return numarray.outerproduct(arg0,arg1).resize(arg0.shape+arg1.shape)

    def volume(extents):
        # number of entries of a block with the given extents
        count=1
        for n in extents:
            count*=n
        return count

    lead_shape=arg0.shape[:arg0.rank-offset]
    trail_shape=arg1.shape[offset:]
    n_sum=volume(sh_s)
    n_lead=volume(lead_shape)
    n_trail=volume(trail_shape)
    # view the outer product as (lead, summed, summed, trail) and add up the
    # entries where both summed indices coincide
    paired=numarray.outerproduct(arg0,arg1).resize((n_lead,n_sum,n_sum,n_trail))
    summed=numarray.zeros((n_lead,n_trail),numarray.Float)
    for row in range(n_lead):
        for col in range(n_trail):
            for k in range(n_sum):
                summed[row,col]+=paired[row,k,k,col]
    return summed.resize(lead_shape+trail_shape)
559 | |||
def testMatrixMult(arg0,arg1,sh_s):
    """
    Reference value for matrixmult: C{numarray.matrixmultiply(arg0,arg1)}.

    C{sh_s} is accepted so that the signature matches the other reference
    functions; it is not used.
    """
    product=numarray.matrixmultiply(arg0,arg1)
    return product
562 | |||
563 | |||
def testTensorMult(arg0,arg1,sh_s):
    """
    Reference value for tensormult.

    @param arg0: array of rank 2 or rank 4
    @param arg1: array; for a rank-4 C{arg0} its rank must be 2, 3 or 4
    @param sh_s: not used; kept so that the signature matches the other
                 reference functions
    @return: C{numarray.matrixmultiply(arg0,arg1)} for a rank-2 C{arg0},
             otherwise the sum over the last two axes of C{arg0} paired
             with the first two axes of C{arg1}
    @raise SystemError: if C{arg0} is not of rank 2 and the rank of C{arg1}
                        is not 2, 3 or 4
    """
    # the decision has to be made on the number of axes; len(arg0) would
    # only give the extent of the first axis
    if arg0.rank==2:
        return numarray.matrixmultiply(arg0,arg1)
    if arg1.rank not in (2,3,4):
        raise SystemError("rank of arg1 is restricted to 2, 3 or 4")
    if arg1.rank==4:
        out=numarray.zeros((arg0.shape[0],arg0.shape[1],arg1.shape[2],arg1.shape[3]),numarray.Float)
        for i0 in range(arg0.shape[0]):
            for i1 in range(arg0.shape[1]):
                for i2 in range(arg0.shape[2]):
                    for i3 in range(arg0.shape[3]):
                        for j2 in range(arg1.shape[2]):
                            for j3 in range(arg1.shape[3]):
                                out[i0,i1,j2,j3]+=arg0[i0,i1,i2,i3]*arg1[i2,i3,j2,j3]
    elif arg1.rank==3:
        out=numarray.zeros((arg0.shape[0],arg0.shape[1],arg1.shape[2]),numarray.Float)
        for i0 in range(arg0.shape[0]):
            for i1 in range(arg0.shape[1]):
                for i2 in range(arg0.shape[2]):
                    for i3 in range(arg0.shape[3]):
                        for j2 in range(arg1.shape[2]):
                            out[i0,i1,j2]+=arg0[i0,i1,i2,i3]*arg1[i2,i3,j2]
    else:
        out=numarray.zeros((arg0.shape[0],arg0.shape[1]),numarray.Float)
        for i0 in range(arg0.shape[0]):
            for i1 in range(arg0.shape[1]):
                for i2 in range(arg0.shape[2]):
                    for i3 in range(arg0.shape[3]):
                        out[i0,i1]+=arg0[i0,i1,i2,i3]*arg1[i2,i3]
    return out
593 | gross | 313 | |
594 | def testReduce(arg0,init_val,test_expr,post_expr): | ||
595 | out=init_val | ||
596 | if isinstance(arg0,float): | ||
597 | out=eval(test_expr.replace("%a1%","arg0")) | ||
598 | elif arg0.rank==0: | ||
599 | out=eval(test_expr.replace("%a1%","arg0")) | ||
600 | elif arg0.rank==1: | ||
601 | for i0 in range(arg0.shape[0]): | ||
602 | out=eval(test_expr.replace("%a1%","arg0[i0]")) | ||
603 | elif arg0.rank==2: | ||
604 | for i0 in range(arg0.shape[0]): | ||
605 | for i1 in range(arg0.shape[1]): | ||
606 | out=eval(test_expr.replace("%a1%","arg0[i0,i1]")) | ||
607 | elif arg0.rank==3: | ||
608 | for i0 in range(arg0.shape[0]): | ||
609 | for i1 in range(arg0.shape[1]): | ||
610 | for i2 in range(arg0.shape[2]): | ||
611 | out=eval(test_expr.replace("%a1%","arg0[i0,i1,i2]")) | ||
612 | elif arg0.rank==4: | ||
613 | for i0 in range(arg0.shape[0]): | ||
614 | for i1 in range(arg0.shape[1]): | ||
615 | for i2 in range(arg0.shape[2]): | ||
616 | for i3 in range(arg0.shape[3]): | ||
617 | out=eval(test_expr.replace("%a1%","arg0[i0,i1,i2,i3]")) | ||
618 | return eval(post_expr) | ||
#=======================================================================================================
# local reduction
#=======================================================================================================
# Every entry is [function under test, start value of the accumulator,
# update expression per entry, expression giving the final value]; the last
# three are handed to testReduce to compute the reference values.
for oper in [["length",0.,"out+%a1%**2","math.sqrt(out)"],
             ["maxval",-1.e99,"max(out,%a1%)","out"],
             ["minval",1.e99,"min(out,%a1%)","out"] ]:
    for case in case_set:
        for sh in shape_set:
            # a float argument is only generated for the empty shape
            if case=="float" and len(sh)!=0:
                continue
            tname="def test_%s_%s_rank%s"%(oper[0],case,len(sh))
            text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
            text+=" %s(self):\n"%tname
            # two independent random inputs together with their references
            a=makeArray(sh,[-1.,1.])
            a1=makeArray(sh,[-1.,1.])
            r1=testReduce(a1,oper[1],oper[2],oper[3])
            r=testReduce(a,oper[1],oper[2],oper[3])

            text+=mkText(case,"arg",a,a1)
            text+=" res=%s(arg)\n"%oper[0]
            if case!="Symbol":
                text+=mkText(case,"ref",r,r1)
                res="res"
            else:
                # a symbolic result is compared after substituting an array
                text+=mkText("array","s",a,a1)
                text+=" sub=res.substitute({arg:s})\n"
                text+=mkText("array","ref",r,r1)
                res="sub"
            # for maxval and minval a float or array argument is expected to
            # give a float; in all other situations the kind of the argument
            if oper[0]!="length" and (case=="float" or case=="array"):
                text+=mkTypeAndShapeTest("float",(),"res")
            else:
                text+=mkTypeAndShapeTest(case,(),"res")
            text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
            if case != "taggedData":
                t_prog+=text
            else:
                t_prog_with_tags+=text
# only the tests using tagged data are written out; t_prog is collected
# but not printed
print(test_header)
print(t_prog_with_tags)
print(test_tail)
# NOTE(review): the division by zero aborts the script right here, so no
# statement below this line is ever executed - presumably a deliberate way
# of selecting which section of the generator is active; confirm.
1/0
664 | |||
#=======================================================================================================
# tensor multiply
#=======================================================================================================
# NOTE(review): the local-reduction section further up ends with "1/0", so
# as the file stands this section is not reached - confirm before editing.
# The same loop was apparently also used with
# ["generalTensorProduct",tensorProductTest] and ["matrixmult",testMatrixMult].
oper=["tensormult",testTensorMult]

for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
    for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
        for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:
            for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
                for sh_s in [ (),(3,), (2,3), (2,4,3),(4,2,3,2)]:
                    rank0=len(sh0+sh_s)
                    rank1=len(sh1+sh_s)
                    offset=len(sh_s)
                    # floats only exist with rank 0 and no rank may exceed 4
                    if (rank0!=0 and case0=="float") or (rank1!=0 and case1=="float"):
                        continue
                    if len(sh0+sh1)>=5 or rank0>=5 or rank1>=5:
                        continue
                    # rank combinations accepted for tensormult
                    matrix_like=offset==1 and rank0==2 and (rank1==2 or rank1==1)
                    tensor_like=offset==2 and rank0==4 and (rank1==2 or rank1==3 or rank1==4)
                    if not (matrix_like or tensor_like):
                        continue
                    case=getResultCaseForBin(case0,case1)
                    use_tagging_for_expanded_data= case0=="taggedData" or case1=="taggedData"
                    tname="test_tensormult_%s_rank%s_%s_rank%s"%(case0,rank0,case1,rank1)
                    text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
                    text+=" def %s(self):\n"%tname
                    # a second random value is only drawn where mkText makes
                    # use of it (tagged and expanded data)
                    a_0=makeArray(sh0+sh_s,[-1.,1])
                    if case0 in ["taggedData", "expandedData"]:
                        a1_0=makeArray(sh0+sh_s,[-1.,1])
                    else:
                        a1_0=a_0

                    a_1=makeArray(sh_s+sh1,[-1.,1])
                    if case1 in ["taggedData", "expandedData"]:
                        a1_1=makeArray(sh_s+sh1,[-1.,1])
                    else:
                        a1_1=a_1
                    r=oper[1](a_0,a_1,sh_s)
                    r1=oper[1](a1_0,a1_1,sh_s)
                    text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
                    text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
                    text+=" res=tensormult(arg0,arg1)\n"
                    if case=="Symbol":
                        # every symbolic argument is replaced by an array
                        # before the result is compared with the reference
                        c0_res,c1_res=case0,case1
                        replaced=[]
                        if case0=="Symbol":
                            text+=mkText("array","s0",a_0,a1_0)
                            replaced.append("arg0:s0")
                            c0_res="array"
                        if case1=="Symbol":
                            text+=mkText("array","s1",a_1,a1_1)
                            replaced.append("arg1:s1")
                            c1_res="array"
                        subs="{"+",".join(replaced)+"}"
                        text+=" sub=res.substitute(%s)\n"%subs
                        res="sub"
                        text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)
                    else:
                        res="res"
                        text+=mkText(case,"ref",r,r1)
                    text+=mkTypeAndShapeTest(case,sh0+sh1,"res")
                    text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
                    if case0 == "taggedData" or case1 == "taggedData":
                        t_prog_with_tags+=text
                    else:
                        t_prog+=text
# only the tests using tagged data are written out; t_prog is collected
# but not printed
print(test_header)
print(t_prog_with_tags)
print(test_tail)
# NOTE(review): aborts the script here, so nothing below is executed -
# presumably deliberate; confirm.
1/0
#=======================================================================================================
# outer/inner
#=======================================================================================================
# Emit one unittest method (as source text) per admissible pairing of argument
# kinds and shapes.  oper[0] is the name of the function called in the
# generated test, oper[1] the reference implementation called as oper[1](a, b).
# NOTE(review): the leading blanks inside the emitted-code literals appear to
# be collapsed to a single space in this listing; they are reproduced as
# found -- confirm the intended indentation of the generated methods.
oper = ["inner", innerTEST]
# alternative target: oper=["outer",outerTEST]
for case0 in ["float", "array", "Symbol", "constData", "taggedData", "expandedData"]:
    for sh1 in [(), (2,), (4, 5), (6, 2, 2), (3, 2, 3, 4)]:
        for case1 in ["float", "array", "Symbol", "constData", "taggedData", "expandedData"]:
            for sh0 in [(), (2,), (4, 5), (6, 2, 2), (3, 2, 3, 4)]:
                # a float can only stand in for a rank-0 argument
                if case0 == "float" and len(sh0) != 0:
                    continue
                if case1 == "float" and len(sh1) != 0:
                    continue
                # combined rank is limited to 4
                if len(sh0 + sh1) >= 5:
                    continue

                use_tagging_for_expanded_data = "taggedData" in (case0, case1)

                tname = "test_%s_%s_rank%s_%s_rank%s" % (oper[0], case0, len(sh0), case1, len(sh1))
                text = " #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
                text += " def %s(self):\n" % tname

                # tagged/expanded data get a second, independent value set
                a_0 = makeArray(sh0, [-1., 1])
                if case0 in ["taggedData", "expandedData"]:
                    a1_0 = makeArray(sh0, [-1., 1])
                else:
                    a1_0 = a_0

                a_1 = makeArray(sh1, [-1., 1])
                if case1 in ["taggedData", "expandedData"]:
                    a1_1 = makeArray(sh1, [-1., 1])
                else:
                    a1_1 = a_1

                r = oper[1](a_0, a_1)
                r1 = oper[1](a1_0, a1_1)

                text += mkText(case0, "arg0", a_0, a1_0, use_tagging_for_expanded_data)
                text += mkText(case1, "arg1", a_1, a1_1, use_tagging_for_expanded_data)
                text += " res=%s(arg0,arg1)\n" % oper[0]

                case = getResultCaseForBin(case0, case1)
                if case == "Symbol":
                    # symbolic result: substitute arrays for the symbols and
                    # compare the substituted value with the reference
                    c0_res, c1_res = case0, case1
                    pairs = []
                    if case0 == "Symbol":
                        text += mkText("array", "s0", a_0, a1_0)
                        pairs.append("arg0:s0")
                        c0_res = "array"
                    if case1 == "Symbol":
                        text += mkText("array", "s1", a_1, a1_1)
                        pairs.append("arg1:s1")
                        c1_res = "array"
                    subs = "{" + ",".join(pairs) + "}"
                    text += " sub=res.substitute(%s)\n" % subs
                    res = "sub"
                    text += mkText(getResultCaseForBin(c0_res, c1_res), "ref", r, r1)
                else:
                    res = "res"
                    text += mkText(case, "ref", r, r1)

                text += mkTypeAndShapeTest(case, sh0 + sh1, "res")
                text += " self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n" % res

                # tests touching tagged data are collected separately
                if use_tagging_for_expanded_data:
                    t_prog_with_tags += text
                else:
                    t_prog += text

# only the tagged-data tests are written out here (t_prog is not printed)
print(test_header)
print(t_prog_with_tags)
print(test_tail)
# deliberate ZeroDivisionError: stops the script before the sections below
1/0
#=======================================================================================================
# basic binary operation overloading (tests only!)
#=======================================================================================================
# Each entry of the operator table is [name used in the test name, operator
# symbol, value range handed to makeArray].  One unittest method is emitted
# (as source text) per admissible pairing of argument kinds and shapes.
# NOTE(review): the leading blanks inside the emitted-code literals appear to
# be collapsed to a single space in this listing; they are reproduced as
# found -- confirm the intended indentation of the generated methods.
oper_range = [-5., 5.]
for oper in [["add", "+", [-5., 5.]],
             ["sub", "-", [-5., 5.]],
             ["mult", "*", [-5., 5.]],
             ["div", "/", [-5., 5.]],
             ["pow", "**", [0.01, 5.]]]:
    for case0 in case_set:
        for sh0 in shape_set:
            for case1 in case_set:
                for sh1 in shape_set:
                    # an array is never used as the left operand here
                    if case0 == "array":
                        continue
                    # a float can only stand in for a rank-0 argument
                    if case0 == "float" and len(sh0) != 0:
                        continue
                    if case1 == "float" and len(sh1) != 0:
                        continue
                    # shapes must agree unless one operand is a scalar
                    if not (sh0 == () or sh1 == () or sh1 == sh0):
                        continue
                    # pure float/array pairings are not generated
                    if case0 in ["float", "array"] and case1 in ["float", "array"]:
                        continue

                    use_tagging_for_expanded_data = "taggedData" in (case0, case1)

                    tname = "test_%s_overloaded_%s_rank%s_%s_rank%s" % (oper[0], case0, len(sh0), case1, len(sh1))
                    text = " #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
                    text += " def %s(self):\n" % tname

                    # tagged/expanded data get a second, independent value set
                    a_0 = makeArray(sh0, oper[2])
                    if case0 in ["taggedData", "expandedData"]:
                        a1_0 = makeArray(sh0, oper[2])
                    else:
                        a1_0 = a_0

                    a_1 = makeArray(sh1, oper[2])
                    if case1 in ["taggedData", "expandedData"]:
                        a1_1 = makeArray(sh1, oper[2])
                    else:
                        a1_1 = a_1

                    # expression template evaluated by makeResult2
                    expr = "%a1%" + oper[1] + "%a2%"
                    r1 = makeResult2(a1_0, a1_1, expr)
                    r = makeResult2(a_0, a_1, expr)

                    text += mkText(case0, "arg0", a_0, a1_0, use_tagging_for_expanded_data)
                    text += mkText(case1, "arg1", a_1, a1_1, use_tagging_for_expanded_data)
                    text += " res=arg0%sarg1\n" % oper[1]

                    case = getResultCaseForBin(case0, case1)
                    if case == "Symbol":
                        # symbolic result: substitute arrays for the symbols
                        # and compare the substituted value with the reference
                        c0_res, c1_res = case0, case1
                        pairs = []
                        if case0 == "Symbol":
                            text += mkText("array", "s0", a_0, a1_0)
                            pairs.append("arg0:s0")
                            c0_res = "array"
                        if case1 == "Symbol":
                            text += mkText("array", "s1", a_1, a1_1)
                            pairs.append("arg1:s1")
                            c1_res = "array"
                        subs = "{" + ",".join(pairs) + "}"
                        text += " sub=res.substitute(%s)\n" % subs
                        res = "sub"
                        text += mkText(getResultCaseForBin(c0_res, c1_res), "ref", r, r1)
                    else:
                        res = "res"
                        text += mkText(case, "ref", r, r1)

                    # a float reference has no .shape attribute
                    if isinstance(r, float):
                        text += mkTypeAndShapeTest(case, (), "res")
                    else:
                        text += mkTypeAndShapeTest(case, r.shape, "res")
                    text += " self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n" % res

                    # Data <op> Symbol tests go to the "failing" collection,
                    # the rest is split by whether tagged data is involved
                    if case0 in ["constData", "taggedData", "expandedData"] and case1 == "Symbol":
                        t_prog_failing += text
                    elif use_tagging_for_expanded_data:
                        t_prog_with_tags += text
                    else:
                        t_prog += text


# only the "failing" collection is written out here
print(test_header)
print(t_prog_failing)
print(test_tail)
# deliberate ZeroDivisionError: stops the script before the sections below
1/0
#=======================================================================================================
# basic binary operations (tests only!)
#=======================================================================================================
# Each entry of the operator table is [function name called in the generated
# test, operator symbol used for the reference value, value range handed to
# makeArray].  For every base shape sh and extension sh_p the extension is
# attached to argument 1 (sh_d>0) or to argument 0 (sh_d<0).
# NOTE(review): the leading blanks inside the emitted-code literals appear to
# be collapsed to a single space in this listing; they are reproduced as
# found -- confirm the intended indentation of the generated methods.
oper_range = [-5., 5.]
for oper in [["add", "+", [-5., 5.]],
             ["mult", "*", [-5., 5.]],
             ["quotient", "/", [-5., 5.]],
             ["power", "**", [0.01, 5.]]]:
    for case0 in case_set:
        for case1 in case_set:
            for sh in shape_set:
                for sh_p in shape_set:
                    # with an empty extension both orientations yield the
                    # same shapes, so only one of them is generated
                    if len(sh_p) == 0:
                        resource = [1]
                    else:
                        resource = [-1, 1]
                    for sh_d in resource:
                        if sh_d > 0:
                            sh0, sh1 = sh, sh + sh_p
                        else:
                            sh0, sh1 = sh + sh_p, sh

                        # a float can only stand in for a rank-0 argument
                        if case0 == "float" and len(sh0) != 0:
                            continue
                        if case1 == "float" and len(sh1) != 0:
                            continue
                        # argument ranks are limited to 4
                        if len(sh0) >= 5 or len(sh1) >= 5:
                            continue

                        use_tagging_for_expanded_data = "taggedData" in (case0, case1)

                        tname = "test_%s_%s_rank%s_%s_rank%s" % (oper[0], case0, len(sh0), case1, len(sh1))
                        text = " #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
                        text += " def %s(self):\n" % tname

                        # tagged/expanded data get a second, independent value set
                        a_0 = makeArray(sh0, oper[2])
                        if case0 in ["taggedData", "expandedData"]:
                            a1_0 = makeArray(sh0, oper[2])
                        else:
                            a1_0 = a_0

                        a_1 = makeArray(sh1, oper[2])
                        if case1 in ["taggedData", "expandedData"]:
                            a1_1 = makeArray(sh1, oper[2])
                        else:
                            a1_1 = a_1

                        # expression template evaluated by makeResult2
                        expr = "%a1%" + oper[1] + "%a2%"
                        r1 = makeResult2(a1_0, a1_1, expr)
                        r = makeResult2(a_0, a_1, expr)

                        text += mkText(case0, "arg0", a_0, a1_0, use_tagging_for_expanded_data)
                        text += mkText(case1, "arg1", a_1, a1_1, use_tagging_for_expanded_data)
                        text += " res=%s(arg0,arg1)\n" % oper[0]

                        case = getResultCaseForBin(case0, case1)
                        if case == "Symbol":
                            # symbolic result: substitute arrays for the
                            # symbols and compare with the reference
                            c0_res, c1_res = case0, case1
                            pairs = []
                            if case0 == "Symbol":
                                text += mkText("array", "s0", a_0, a1_0)
                                pairs.append("arg0:s0")
                                c0_res = "array"
                            if case1 == "Symbol":
                                text += mkText("array", "s1", a_1, a1_1)
                                pairs.append("arg1:s1")
                                c1_res = "array"
                            subs = "{" + ",".join(pairs) + "}"
                            text += " sub=res.substitute(%s)\n" % subs
                            res = "sub"
                            text += mkText(getResultCaseForBin(c0_res, c1_res), "ref", r, r1)
                        else:
                            res = "res"
                            text += mkText(case, "ref", r, r1)

                        # a float reference has no .shape attribute
                        if isinstance(r, float):
                            text += mkTypeAndShapeTest(case, (), "res")
                        else:
                            text += mkTypeAndShapeTest(case, r.shape, "res")
                        text += " self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n" % res

                        # tests touching tagged data are collected separately
                        if use_tagging_for_expanded_data:
                            t_prog_with_tags += text
                        else:
                            t_prog += text

# only the tagged-data tests are written out here (t_prog is not printed)
print(test_header)
print(t_prog_with_tags)
print(test_tail)
# deliberate ZeroDivisionError: stops the script before the sections below
1/0
964 | |||
# Binary operator tests (tests only!): each entry of the operator table is
# [name used in the test name, operator symbol, value range handed to
# makeArray].  One unittest method is emitted (as source text) per admissible
# pairing of argument kinds and shapes; mkText is called without the tagging
# flag in this section.
# NOTE(review): the leading blanks inside the emitted-code literals appear to
# be collapsed to a single space in this listing; they are reproduced as
# found -- confirm the intended indentation of the generated methods.
#
# The listing had "oper_range=[-5.,5.]" fused onto a commented-out print; it is
# restored as the assignment that opens the parallel sections of this script.
oper_range = [-5., 5.]
for oper in [["add", "+", [-5., 5.]],
             ["sub", "-", [-5., 5.]],
             ["mult", "*", [-5., 5.]],
             ["div", "/", [-5., 5.]],
             ["pow", "**", [0.01, 5.]]]:
    for case0 in case_set:
        for sh0 in shape_set:
            for case1 in case_set:
                for sh1 in shape_set:
                    # a float can only stand in for a rank-0 argument
                    if case0 == "float" and len(sh0) != 0:
                        continue
                    if case1 == "float" and len(sh1) != 0:
                        continue
                    # shapes must agree unless one operand is a scalar
                    if not (sh0 == () or sh1 == () or sh1 == sh0):
                        continue
                    # pure float/array pairings are not generated
                    if case0 in ["float", "array"] and case1 in ["float", "array"]:
                        continue

                    tname = "test_%s_%s_rank%s_%s_rank%s" % (oper[0], case0, len(sh0), case1, len(sh1))
                    text = " #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
                    text += " def %s(self):\n" % tname

                    # tagged/expanded data get a second, independent value set
                    a_0 = makeArray(sh0, oper[2])
                    if case0 in ["taggedData", "expandedData"]:
                        a1_0 = makeArray(sh0, oper[2])
                    else:
                        a1_0 = a_0

                    a_1 = makeArray(sh1, oper[2])
                    if case1 in ["taggedData", "expandedData"]:
                        a1_1 = makeArray(sh1, oper[2])
                    else:
                        a1_1 = a_1

                    # expression template evaluated by makeResult2
                    expr = "%a1%" + oper[1] + "%a2%"
                    r1 = makeResult2(a1_0, a1_1, expr)
                    r = makeResult2(a_0, a_1, expr)

                    text += mkText(case0, "arg0", a_0, a1_0)
                    text += mkText(case1, "arg1", a_1, a1_1)
                    text += " res=arg0%sarg1\n" % oper[1]

                    case = getResultCaseForBin(case0, case1)
                    if case == "Symbol":
                        # symbolic result: substitute arrays for the symbols
                        # and compare the substituted value with the reference
                        c0_res, c1_res = case0, case1
                        subs = "{"
                        if case0 == "Symbol":
                            text += mkText("array", "s0", a_0, a1_0)
                            subs += "arg0:s0"
                            c0_res = "array"
                        if case1 == "Symbol":
                            text += mkText("array", "s1", a_1, a1_1)
                            if not subs.endswith("{"):
                                subs += ","
                            subs += "arg1:s1"
                            c1_res = "array"
                        subs += "}"
                        text += " sub=res.substitute(%s)\n" % subs
                        res = "sub"
                        text += mkText(getResultCaseForBin(c0_res, c1_res), "ref", r, r1)
                    else:
                        res = "res"
                        text += mkText(case, "ref", r, r1)

                    # a float reference has no .shape attribute
                    if isinstance(r, float):
                        text += mkTypeAndShapeTest(case, (), "res")
                    else:
                        text += mkTypeAndShapeTest(case, r.shape, "res")
                    text += " self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n" % res

                    # Data <op> Symbol tests go to the "failing" collection,
                    # the rest is split by whether tagged data is involved
                    if case0 in ["constData", "taggedData", "expandedData"] and case1 == "Symbol":
                        t_prog_failing += text
                    elif case0 == "taggedData" or case1 == "taggedData":
                        t_prog_with_tags += text
                    else:
                        t_prog += text


# Write the untagged tests.  test_tail is printed exactly once: printing it a
# second time would make the generated file build and run the suite twice.
print(test_header)
print(t_prog)
print(test_tail)
1042 | jgs | 154 | |
#=======================================================================================================
# unary operations:
#=======================================================================================================
# Table of the unary functions to generate.  Per entry:
#   nickname      name of the generated function and of the tests
#   rng           value range handed to makeArray for the test arguments
#   test_expr     expression used to compute the reference result in the tests
#   math_expr     code emitted for float/int arguments (defaults to test_expr)
#   numarray_expr code emitted for numarray arguments
#   symbol_expr   code emitted for Symbol arguments; when absent a dedicated
#                 <Nickname>_Symbol class is generated instead
#   diff          derivative of the function with respect to its argument,
#                 used in the generated <Nickname>_Symbol.diff method
#   name          human readable name used in the generated docstrings
# "%a1%" is the placeholder for the function argument.
# NOTE(review): whereZero/whereNonZero refer to a variable `tol` while the
# generated wrapper is declared as "def <nickname>(arg)" -- confirm where tol
# is expected to come from.
# NOTE(review): blanks inside the multi-line math_expr literals appear to be
# collapsed to a single space in this listing; they are reproduced as found.
func = [
    OPERATOR(nickname="log10",
             rng=[1.e-3, 100.],
             test_expr="math.log10(%a1%)",
             math_expr="math.log10(%a1%)",
             numarray_expr="numarray.log10(%a1%)",
             symbol_expr="log(%a1%)/log(10.)",
             name="base-10 logarithm"),
    OPERATOR(nickname="wherePositive",
             rng=[-100., 100.],
             test_expr="wherepos(%a1%)",
             math_expr="if arg>0:\n return 1.\nelse:\n return 0.",
             numarray_expr="numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))",
             name="mask of positive values"),
    OPERATOR(nickname="whereNegative",
             rng=[-100., 100.],
             test_expr="wherepos(-%a1%)",
             math_expr="if arg<0:\n return 1.\nelse:\n return 0.",
             numarray_expr="numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))",
             # was "mask of positive values" (copied from wherePositive)
             name="mask of negative values"),
    OPERATOR(nickname="whereNonNegative",
             rng=[-100., 100.],
             test_expr="1-wherepos(-%a1%)",
             math_expr="if arg<0:\n return 0.\nelse:\n return 1.",
             numarray_expr="numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float))",
             # arg>=0 is the complement of arg<0 (matches test_expr/math_expr)
             symbol_expr="1-whereNegative(%a1%)",
             name="mask of non-negative values"),
    OPERATOR(nickname="whereNonPositive",
             rng=[-100., 100.],
             test_expr="1-wherepos(%a1%)",
             math_expr="if arg>0:\n return 0.\nelse:\n return 1.",
             numarray_expr="numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float))",
             # arg<=0 is the complement of arg>0 (matches test_expr/math_expr)
             symbol_expr="1-wherePositive(%a1%)",
             name="mask of non-positive values"),
    OPERATOR(nickname="whereZero",
             rng=[-100., 100.],
             test_expr="1-wherepos(%a1%)-wherepos(-%a1%)",
             math_expr="if abs(%a1%)<=tol:\n return 1.\nelse:\n return 0.",
             numarray_expr="numarray.less_equal(abs(%a1%)-tol,numarray.zeros(arg.shape,numarray.Float))",
             name="mask of zero entries"),
    OPERATOR(nickname="whereNonZero",
             rng=[-100., 100.],
             test_expr="wherepos(%a1%)+wherepos(-%a1%)",
             math_expr="if abs(%a1%)>tol:\n return 1.\nelse:\n return 0.",
             numarray_expr="numarray.greater(abs(%a1%)-tol,numarray.zeros(arg.shape,numarray.Float))",
             symbol_expr="1-whereZero(arg,tol)",
             name="mask of values different from zero"),
    OPERATOR(nickname="sin",
             rng=[-100., 100.],
             test_expr="math.sin(%a1%)",
             numarray_expr="numarray.sin(%a1%)",
             diff="cos(%a1%)",
             name="sine"),
    OPERATOR(nickname="cos",
             rng=[-100., 100.],
             test_expr="math.cos(%a1%)",
             numarray_expr="numarray.cos(%a1%)",
             diff="-sin(%a1%)",
             name="cosine"),
    OPERATOR(nickname="tan",
             rng=[-100., 100.],
             test_expr="math.tan(%a1%)",
             numarray_expr="numarray.tan(%a1%)",
             diff="1./cos(%a1%)**2",
             name="tangent"),
    OPERATOR(nickname="asin",
             rng=[-0.99, 0.99],
             test_expr="math.asin(%a1%)",
             numarray_expr="numarray.arcsin(%a1%)",
             diff="1./sqrt(1.-%a1%**2)",
             name="inverse sine"),
    OPERATOR(nickname="acos",
             rng=[-0.99, 0.99],
             test_expr="math.acos(%a1%)",
             numarray_expr="numarray.arccos(%a1%)",
             diff="-1./sqrt(1.-%a1%**2)",
             name="inverse cosine"),
    OPERATOR(nickname="atan",
             rng=[-100., 100.],
             test_expr="math.atan(%a1%)",
             numarray_expr="numarray.arctan(%a1%)",
             diff="1./(1+%a1%**2)",
             name="inverse tangent"),
    OPERATOR(nickname="sinh",
             rng=[-5, 5],
             test_expr="math.sinh(%a1%)",
             numarray_expr="numarray.sinh(%a1%)",
             diff="cosh(%a1%)",
             name="hyperbolic sine"),
    OPERATOR(nickname="cosh",
             rng=[-5., 5.],
             test_expr="math.cosh(%a1%)",
             numarray_expr="numarray.cosh(%a1%)",
             diff="sinh(%a1%)",
             name="hyperbolic cosine"),
    OPERATOR(nickname="tanh",
             rng=[-5., 5.],
             test_expr="math.tanh(%a1%)",
             numarray_expr="numarray.tanh(%a1%)",
             diff="1./cosh(%a1%)**2",
             name="hyperbolic tangent"),
    OPERATOR(nickname="asinh",
             rng=[-100., 100.],
             test_expr="numarray.arcsinh(%a1%)",
             math_expr="numarray.arcsinh(%a1%)",
             numarray_expr="numarray.arcsinh(%a1%)",
             diff="1./sqrt(%a1%**2+1)",
             name="inverse hyperbolic sine"),
    OPERATOR(nickname="acosh",
             rng=[1.001, 100.],
             test_expr="numarray.arccosh(%a1%)",
             math_expr="numarray.arccosh(%a1%)",
             numarray_expr="numarray.arccosh(%a1%)",
             diff="1./sqrt(%a1%**2-1)",
             name="inverse hyperbolic cosine"),
    OPERATOR(nickname="atanh",
             rng=[-0.99, 0.99],
             test_expr="numarray.arctanh(%a1%)",
             math_expr="numarray.arctanh(%a1%)",
             numarray_expr="numarray.arctanh(%a1%)",
             diff="1./(1.-%a1%**2)",
             name="inverse hyperbolic tangent"),
    OPERATOR(nickname="exp",
             rng=[-5., 5.],
             test_expr="math.exp(%a1%)",
             numarray_expr="numarray.exp(%a1%)",
             diff="self",
             name="exponential"),
    OPERATOR(nickname="sqrt",
             rng=[1.e-3, 100.],
             test_expr="math.sqrt(%a1%)",
             numarray_expr="numarray.sqrt(%a1%)",
             diff="0.5/self",
             name="square root"),
    OPERATOR(nickname="log",
             rng=[1.e-3, 100.],
             test_expr="math.log(%a1%)",
             numarray_expr="numarray.log(%a1%)",
             # must use the placeholder: inside the generated diff(self,arg)
             # the name `arg` is the variable differentiated with respect to,
             # not the argument of the logarithm
             diff="1./%a1%",
             name="natural logarithm"),
    OPERATOR(nickname="sign",
             rng=[-100., 100.],
             math_expr="if %a1%>0:\n return 1.\nelif %a1%<0:\n return -1.\nelse:\n return 0.",
             test_expr="wherepos(%a1%)-wherepos(-%a1%)",
             numarray_expr="numarray.sign(%a1%)",
             symbol_expr="wherePositive(%a1%)-whereNegative(%a1%)",
             name="sign"),
    OPERATOR(nickname="abs",
             rng=[-100., 100.],
             math_expr="if %a1%>0:\n return %a1% \nelif %a1%<0:\n return -(%a1%)\nelse:\n return 0.",
             test_expr="wherepos(%a1%)*(%a1%)-wherepos(-%a1%)*(%a1%)",
             numarray_expr="abs(%a1%)",
             diff="sign(%a1%)",
             name="absolute value"),
]
# For every OPERATOR f in func this loop appends generated source text to
#   u_prog : a wrapper function "def <nickname>(arg)" that dispatches on the
#            type of arg (numarray.NumArray, escript.Data, float, int, Symbol)
#            and otherwise raises TypeError; when f.symbol_expr is None also a
#            class <Nickname>_Symbol(DependendSymbol) with __init__, getMyCode,
#            substitute and -- if f.diff is given -- diff;
#   t_prog / t_prog_with_tags : one test method per (case, shape) combination,
#            where tests for taggedData go to t_prog_with_tags.
# The wrapper function is not emitted for the nickname "abs".
# NOTE(review): indentation is lost in this listing, so it cannot be told from
# here whether the class-generation and test-generation parts are nested
# inside the `f.nickname!="abs"` guard or follow it -- confirm against the
# repository copy before restructuring this loop.
# NOTE(review): blanks inside the emitted-code literals (and in the indent
# argument 2*" " passed to mkCode) appear collapsed to a single space here.
1202 | for f in func: | ||
1203 | symbol_name=f.nickname[0].upper()+f.nickname[1:] | ||
1204 | if f.nickname!="abs": | ||
1205 | u_prog+="def %s(arg):\n"%f.nickname | ||
1206 | u_prog+=" \"\"\"\n" | ||
1207 | u_prog+=" returns %s of argument arg\n\n"%f.name | ||
1208 | u_prog+=" @param arg: argument\n" | ||
1209 | u_prog+=" @type arg: C{float}, L{escript.Data}, L{Symbol}, L{numarray.NumArray}.\n" | ||
1210 | u_prog+=" @rtype:C{float}, L{escript.Data}, L{Symbol}, L{numarray.NumArray} depending on the type of arg.\n" | ||
1211 | u_prog+=" @raises TypeError: if the type of the argument is not expected.\n" | ||
1212 | u_prog+=" \"\"\"\n" | ||
1213 | u_prog+=" if isinstance(arg,numarray.NumArray):\n" | ||
1214 | u_prog+=mkCode(f.numarray_expr,["arg"],2*" ") | ||
1215 | u_prog+=" elif isinstance(arg,escript.Data):\n" | ||
1216 | u_prog+=mkCode("arg._%s()"%f.nickname,[],2*" ") | ||
1217 | u_prog+=" elif isinstance(arg,float):\n" | ||
1218 | u_prog+=mkCode(f.math_expr,["arg"],2*" ") | ||
1219 | u_prog+=" elif isinstance(arg,int):\n" | ||
1220 | u_prog+=mkCode(f.math_expr,["float(arg)"],2*" ") | ||
1221 | u_prog+=" elif isinstance(arg,Symbol):\n" | ||
1222 | if f.symbol_expr==None: | ||
1223 | u_prog+=mkCode("%s_Symbol(arg)"%symbol_name,[],2*" ") | ||
1224 | else: | ||
1225 | u_prog+=mkCode(f.symbol_expr,["arg"],2*" ") | ||
1226 | u_prog+=" else:\n" | ||
1227 | u_prog+=" raise TypeError,\"%s: Unknown argument type.\"\n\n"%f.nickname | ||
# Without a symbol_expr the Symbol case above refers to <Nickname>_Symbol, whose
# class text is emitted here.
1228 | if f.symbol_expr==None: | ||
1229 | u_prog+="class %s_Symbol(DependendSymbol):\n"%symbol_name | ||
1230 | u_prog+=" \"\"\"\n" | ||
1231 | u_prog+=" L{Symbol} representing the result of the %s function\n"%f.name | ||
1232 | u_prog+=" \"\"\"\n" | ||
1233 | u_prog+=" def __init__(self,arg):\n" | ||
1234 | u_prog+=" \"\"\"\n" | ||
1235 | u_prog+=" initialization of %s L{Symbol} with argument arg\n"%f.nickname | ||
1236 | u_prog+=" @param arg: argument of function\n" | ||
1237 | u_prog+=" @type arg: typically L{Symbol}.\n" | ||
1238 | u_prog+=" \"\"\"\n" | ||
1239 | u_prog+=" DependendSymbol.__init__(self,args=[arg],shape=arg.getShape(),dim=arg.getDim())\n" | ||
1240 | u_prog+="\n" | ||
1241 | |||
1242 | u_prog+=" def getMyCode(self,argstrs,format=\"escript\"):\n" | ||
1243 | u_prog+=" \"\"\"\n" | ||
1244 | u_prog+=" returns a program code that can be used to evaluate the symbol.\n\n" | ||
1245 | jgs | 154 | |
1246 | gross | 157 | u_prog+=" @param argstrs: gives for each argument a string representing the argument for the evaluation.\n" |
1247 | u_prog+=" @type argstrs: C{str} or a C{list} of length 1 of C{str}.\n" | ||
1248 | u_prog+=" @param format: specifies the format to be used. At the moment only \"escript\" ,\"text\" and \"str\" are supported.\n" | ||
1249 | u_prog+=" @type format: C{str}\n" | ||
1250 | u_prog+=" @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.\n" | ||
1251 | u_prog+=" @rtype: C{str}\n" | ||
1252 | u_prog+=" @raise: NotImplementedError: if the requested format is not available\n" | ||
1253 | u_prog+=" \"\"\"\n" | ||
1254 | u_prog+=" if isinstance(argstrs,list):\n" | ||
1255 | u_prog+=" argstrs=argstrs[0]\n" | ||
1256 | u_prog+=" if format==\"escript\" or format==\"str\" or format==\"text\":\n" | ||
1257 | u_prog+=" return \"%s(%%s)\"%%argstrs\n"%f.nickname | ||
1258 | u_prog+=" else:\n" | ||
1259 | u_prog+=" raise NotImplementedError,\"%s_Symbol does not provide program code for format %%s.\"%%format\n"%symbol_name | ||
1260 | u_prog+="\n" | ||
1261 | jgs | 154 | |
1262 | gross | 157 | u_prog+=" def substitute(self,argvals):\n" |
1263 | u_prog+=" \"\"\"\n" | ||
1264 | u_prog+=" assigns new values to symbols in the definition of the symbol.\n" | ||
1265 | u_prog+=" The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.\n" | ||
1266 | u_prog+="\n" | ||
1267 | u_prog+=" @param argvals: new values assigned to symbols\n" | ||
1268 | u_prog+=" @type argvals: C{dict} with keywords of type L{Symbol}.\n" | ||
1269 | u_prog+=" @return: result of the substitution process. Operations are executed as much as possible.\n" | ||
1270 | u_prog+=" @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution\n" | ||
1271 | u_prog+=" @raise TypeError: if a value for a L{Symbol} cannot be substituted.\n" | ||
1272 | u_prog+=" \"\"\"\n" | ||
1273 | u_prog+=" if argvals.has_key(self):\n" | ||
1274 | u_prog+=" arg=argvals[self]\n" | ||
1275 | u_prog+=" if self.isAppropriateValue(arg):\n" | ||
1276 | u_prog+=" return arg\n" | ||
1277 | u_prog+=" else:\n" | ||
1278 | u_prog+=" raise TypeError,\"%s: new value is not appropriate.\"%str(self)\n" | ||
1279 | u_prog+=" else:\n" | ||
1280 | u_prog+=" arg=self.getSubstitutedArguments(argvals)[0]\n" | ||
1281 | u_prog+=" return %s(arg)\n\n"%f.nickname | ||
# The diff method is only emitted when a derivative expression is given; in
# f.diff the placeholder %a1% is replaced by the generated local name myarg.
1282 | if not f.diff==None: | ||
1283 | u_prog+=" def diff(self,arg):\n" | ||
1284 | u_prog+=" \"\"\"\n" | ||
1285 | u_prog+=" differential of this object\n" | ||
1286 | u_prog+="\n" | ||
1287 | u_prog+=" @param arg: the derivative is calculated with respect to arg\n" | ||
1288 | u_prog+=" @type arg: L{escript.Symbol}\n" | ||
1289 | u_prog+=" @return: derivative with respect to C{arg}\n" | ||
1290 | u_prog+=" @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray} are possible.\n" | ||
1291 | u_prog+=" \"\"\"\n" | ||
1292 | u_prog+=" if arg==self:\n" | ||
1293 | u_prog+=" return identity(self.getShape())\n" | ||
1294 | u_prog+=" else:\n" | ||
1295 | u_prog+=" myarg=self.getArgument()[0]\n" | ||
1296 | u_prog+=" val=matchShape(%s,self.getDifferentiatedArguments(arg)[0])\n"%f.diff.replace("%a1%","myarg") | ||
1297 | u_prog+=" return val[0]*val[1]\n\n" | ||
1298 | jgs | 154 | |
# Test generation: one method per argument kind and shape; a float argument is
# only combined with the rank-0 shape.  Two value sets (a, a1) are drawn from
# f.rng and their reference results are computed from f.test_expr.
1299 | gross | 157 | for case in case_set: |
1300 | for sh in shape_set: | ||
1301 | if not case=="float" or len(sh)==0: | ||
1302 | text="" | ||
1303 | text+=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n" | ||
1304 | tname="def test_%s_%s_rank%s"%(f.nickname,case,len(sh)) | ||
1305 | text+=" %s(self):\n"%tname | ||
1306 | a=makeArray(sh,f.rng) | ||
1307 | a1=makeArray(sh,f.rng) | ||
1308 | r1=makeResult(a1,f.test_expr) | ||
1309 | r=makeResult(a,f.test_expr) | ||
1310 | |||
1311 | text+=mkText(case,"arg",a,a1) | ||
1312 | text+=" res=%s(arg)\n"%f.nickname | ||
1313 | if case=="Symbol": | ||
1314 | text+=mkText("array","s",a,a1) | ||
1315 | text+=" sub=res.substitute({arg:s})\n" | ||
1316 | text+=mkText("array","ref",r,r1) | ||
1317 | res="sub" | ||
1318 | else: | ||
1319 | text+=mkText(case,"ref",r,r1) | ||
1320 | res="res" | ||
1321 | text+=mkTypeAndShapeTest(case,sh,"res") | ||
1322 | text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res | ||
1323 | if case == "taggedData": | ||
1324 | t_prog_with_tags+=text | ||
1325 | else: | ||
1326 | t_prog+=text | ||
1327 | |||
1328 | #=========== END OF GOOD CODE +++++++++++++++++++++++++++ | ||
1329 | jgs | 154 | |
# Deliberate ZeroDivisionError: stops the script here so nothing below is run.
1330 | gross | 157 | 1/0 |
1331 | jgs | 154 | |
1332 | gross | 157 | def X(): |
1333 | if args=="float": | ||
1334 | a=makeArray(sh,f[RANGE]) | ||
1335 | jgs | 154 | r=makeResult(a,f) |
1336 | gross | 157 | t_prog+=" arg=%s\n"%a[0] |
1337 | t_prog+=" ref=%s\n"%r[0] | ||
1338 | t_prog+=" res=%s(%a1%)\n"%f.nickname | ||
1339 | t_prog+=" self.failUnless(isinstance(res,float),\"wrong type of result.\")\n" | ||
1340 | t_prog+=" self.failUnless(Lsup(res-ref)<=self.tol*Lsup(ref),\"wrong result\")\n" | ||
1341 | elif args == "array": | ||
1342 | a=makeArray(sh,f[RANGE]) | ||
1343 | r=makeResult(a,f) | ||
1344 | jgs | 154 | if len(sh)==0: |
1345 | gross | 157 | t_prog+=" arg=numarray.array(%s)\n"%a[0] |
1346 | t_prog+=" ref=numarray.array(%s)\n"%r[0] | ||
1347 | jgs | 154 | else: |
1348 | t_prog+=" arg=numarray.array(%s)\n"%a.tolist() | ||
1349 | t_prog+=" ref=numarray.array(%s)\n"%r.tolist() | ||
1350 | gross | 157 | t_prog+=" res=%s(%a1%)\n"%f.nickname |
1351 | t_prog+=" self.failUnlessEqual(res.shape,%s,\"wrong shape of result.\")\n"%str(sh) | ||
1352 | t_prog+=" self.failUnless(Lsup(res-ref)<=self.tol*Lsup(ref),\"wrong result\")\n" | ||
1353 | jgs | 154 | elif args== "constData": |
1354 | gross | 157 | a=makeArray(sh,f[RANGE]) |
1355 | jgs | 154 | r=makeResult(a,f) |
1356 | if len(sh)==0: | ||
1357 | t_prog+=" arg=Data(%s,self.functionspace)\n"%(a) | ||
1358 | t_prog+=" ref=%s\n"%r | ||
1359 | else: | ||
1360 | t_prog+=" arg=Data(numarray.array(%s),self.functionspace)\n"%(a.tolist()) | ||
1361 | t_prog+=" ref=numarray.array(%s)\n"%r.tolist() | ||
1362 | gross | 157 | t_prog+=" res=%s(%a1%)\n"%f.nickname |
1363 | jgs | 154 | t_prog+=" self.failUnlessEqual(res.getShape(),%s,\"wrong shape of result.\")\n"%str(sh) |
1364 | t_prog+=" self.failUnless(Lsup(res-ref)<=self.tol*Lsup(ref),\"wrong result\")\n" | ||
1365 | elif args in [ "taggedData","expandedData"]: | ||
1366 | gross | 157 | a=makeArray(sh,f[RANGE]) |
1367 | jgs | 154 | r=makeResult(a,f) |
1368 | gross | 157 | a1=makeArray(sh,f[RANGE]) |
1369 | jgs | 154 | r1=makeResult(a1,f) |
1370 | if len(sh)==0: | ||
1371 | if args=="expandedData": | ||
1372 | t_prog+=" arg=Data(%s,self.functionspace,True)\n"%(a) | ||
1373 | t_prog+=" ref=Data(%s,self.functionspace,True)\n"%(r) | ||
1374 | else: | ||
1375 | t_prog+=" arg=Data(%s,self.functionspace)\n"%(a) | ||
1376 | t_prog+=" ref=Data(%s,self.functionspace)\n"%(r) | ||
1377 | t_prog+=" arg.setTaggedValue(1,%s)\n"%a | ||
1378 | t_prog+=" ref.setTaggedValue(1,%s)\n"%r1 | ||
1379 | else: | ||
1380 | if args=="expandedData": | ||
1381 | t_prog+=" arg=Data(numarray.array(%s),self.functionspace,True)\n"%(a.tolist()) | ||
1382 | t_prog+=" ref=Data(numarray.array(%s),self.functionspace,True)\n"%(r.tolist()) | ||
1383 | else: | ||
1384 | t_prog+=" arg=Data(numarray.array(%s),self.functionspace)\n"%(a.tolist()) | ||
1385 | t_prog+=" ref=Data(numarray.array(%s),self.functionspace)\n"%(r.tolist()) | ||
1386 | t_prog+=" arg.setTaggedValue(1,%s)\n"%a1.tolist() | ||
1387 | t_prog+=" ref.setTaggedValue(1,%s)\n"%r1.tolist() | ||
1388 | gross | 157 | t_prog+=" res=%s(%a1%)\n"%f.nickname |
1389 | jgs | 154 | t_prog+=" self.failUnlessEqual(res.getShape(),%s,\"wrong shape of result.\")\n"%str(sh) |
1390 | t_prog+=" self.failUnless(Lsup(res-ref)<=self.tol*Lsup(ref),\"wrong result\")\n" | ||
1391 | elif args=="Symbol": | ||
1392 | t_prog+=" arg=Symbol(shape=%s)\n"%str(sh) | ||
1393 | gross | 157 | t_prog+=" v=%s(%a1%)\n"%f.nickname |
1394 | jgs | 154 | t_prog+=" self.failUnlessRaises(ValueError,v.substitute,Symbol(shape=(1,1)),\"illegal shape of substitute not identified.\")\n" |
1395 | gross | 157 | a=makeArray(sh,f[RANGE]) |
1396 | jgs | 154 | r=makeResult(a,f) |
1397 | if len(sh)==0: | ||
1398 | t_prog+=" res=v.substitute({arg : %s})\n"%a | ||
1399 | t_prog+=" ref=%s\n"%r | ||
1400 | t_prog+=" self.failUnless(isinstance(res,float),\"wrong type of result.\")\n" | ||
1401 | else: | ||
1402 | t_prog+=" res=v.substitute({arg : numarray.array(%s)})\n"%a.tolist() | ||
1403 | t_prog+=" ref=numarray.array(%s)\n"%r.tolist() | ||
1404 | t_prog+=" self.failUnlessEqual(res.getShape(),%s,\"wrong shape of substitution result.\")\n"%str(sh) | ||
1405 | t_prog+=" self.failUnless(Lsup(res-ref)<=self.tol*Lsup(ref),\"wrong result\")\n" | ||
1406 | |||
1407 | if len(sh)==0: | ||
1408 | t_prog+=" # test derivative with respect to itself:\n" | ||
1409 | t_prog+=" dvdv=v.diff(v)\n" | ||
1410 | t_prog+=" self.failUnlessEqual(dvdv,1.,\"derivative with respect to self is not 1.\")\n" | ||
1411 | elif len(sh)==1: | ||
1412 | t_prog+=" # test derivative with respect to itself:\n" | ||
1413 | t_prog+=" dvdv=v.diff(v)\n" | ||
1414 | t_prog+=" self.failUnlessEqual(dvdv.shape,%s,\"shape of derivative with respect is wrong\")\n"%str(sh+sh) | ||
1415 | for i0_l in range(sh[0]): | ||
1416 | for i0_r in range(sh[0]): | ||
1417 | if i0_l == i0_r: | ||
1418 | v=1. | ||
1419 | else: | ||
1420 | v=0. | ||
1421 | t_prog+=" self.failUnlessEqual(dvdv[%s,%s],%s,\"derivative with respect to self: [%s,%s] is not %s\")\n"%(i0_l,i0_r,v,i0_l,i0_r,v) | ||
1422 | elif len(sh)==2: | ||
1423 | t_prog+=" # test derivative with respect to itself:\n" | ||
1424 | t_prog+=" dvdv=v.diff(v)\n" | ||
1425 | t_prog+=" self.failUnlessEqual(dvdv.shape,%s,\"shape of derivative with respect is wrong\")\n"%str(sh+sh) | ||
1426 | for i0_l in range(sh[0]): | ||
1427 | for i0_r in range(sh[0]): | ||
1428 | for i1_l in range(sh[1]): | ||
1429 | for i1_r in range(sh[1]): | ||
1430 | if i0_l == i0_r and i1_l == i1_r: | ||
1431 | v=1. | ||
1432 | else: | ||
1433 | v=0. | ||
1434 | t_prog+=" self.failUnlessEqual(dvdv[%s,%s,%s,%s],%s,\"derivative with respect to self: [%s,%s,%s,%s] is not %s\")\n"%(i0_l,i1_l,i0_r,i1_r,v,i0_l,i1_l,i0_r,i1_r,v) | ||
1435 | |||
1436 | for sh_in in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]: | ||
1437 | if len(sh_in)+len(sh)<=4: | ||
1438 | |||
1439 | t_prog+=" # test derivative with shape %s as argument\n"%str(sh_in) | ||
1440 | trafo=makeArray(sh+sh_in,[0,1.]) | ||
1441 | gross | 157 | a_in=makeArray(sh_in,f[RANGE]) |
1442 | jgs | 154 | t_prog+=" arg_in=Symbol(shape=%s)\n"%str(sh_in) |
1443 | t_prog+=" arg2=Symbol(shape=%s)\n"%str(sh) | ||
1444 | |||
1445 | if len(sh)==0: | ||
1446 | t_prog+=" arg2=" | ||
1447 | if len(sh_in)==0: | ||
1448 | t_prog+="%s*arg_in\n"%trafo | ||
1449 | elif len(sh_in)==1: | ||
1450 | for i0 in range(sh_in[0]): | ||
1451 | if i0>0: t_prog+="+" | ||
1452 | t_prog+="%s*arg_in[%s]"%(trafo[i0],i0) | ||
1453 | t_prog+="\n" | ||
1454 | elif len(sh_in)==2: | ||
1455 | for i0 in range(sh_in[0]): | ||
1456 | for i1 in range(sh_in[1]): | ||
1457 | if i0+i1>0: t_prog+="+" | ||
1458 | t_prog+="%s*arg_in[%s,%s]"%(trafo[i0,i1],i0,i1) | ||
1459 | |||
1460 | elif len(sh_in)==3: | ||
1461 | for i0 in range(sh_in[0]): | ||
1462 | for i1 in range(sh_in[1]): | ||
1463 | for i2 in range(sh_in[2]): | ||
1464 | if i0+i1+i2>0: t_prog+="+" | ||
1465 | t_prog+="%s*arg_in[%s,%s,%s]"%(trafo[i0,i1,i2],i0,i1,i2) | ||
1466 | elif len(sh_in)==4: | ||
1467 | for i0 in range(sh_in[0]): | ||
1468 | for i1 in range(sh_in[1]): | ||
1469 | for i2 in range(sh_in[2]): | ||
1470 | for i3 in range(sh_in[3]): | ||
1471 | if i0+i1+i2+i3>0: t_prog+="+" | ||
1472 | t_prog+="%s*arg_in[%s,%s,%s,%s]"%(trafo[i0,i1,i2,i3],i0,i1,i2,i3) | ||
1473 | t_prog+="\n" | ||
1474 | elif len(sh)==1: | ||
1475 | for j0 in range(sh[0]): | ||
1476 | t_prog+=" arg2[%s]="%j0 | ||
1477 | if len(sh_in)==0: | ||
1478 | t_prog+="%s*arg_in"%trafo[j0] | ||
1479 | elif len(sh_in)==1: | ||
1480 | for i0 in range(sh_in[0]): | ||
1481 | if i0>0: t_prog+="+" | ||
1482 | t_prog+="%s*arg_in[%s]"%(trafo[j0,i0],i0) | ||
1483 | elif len(sh_in)==2: | ||
1484 | for i0 in range(sh_in[0]): | ||
1485 | for i1 in range(sh_in[1]): | ||
1486 | if i0+i1>0: t_prog+="+" | ||
1487 | t_prog+="%s*arg_in[%s,%s]"%(trafo[j0,i0,i1],i0,i1) | ||
1488 | elif len(sh_in)==3: | ||
1489 | for i0 in range(sh_in[0]): | ||
1490 | for i1 in range(sh_in[1]): | ||
1491 | for i2 in range(sh_in[2]): | ||
1492 | if i0+i1+i2>0: t_prog+="+" | ||
1493 | t_prog+="%s*arg_in[%s,%s,%s]"%(trafo[j0,i0,i1,i2],i0,i1,i2) | ||
1494 | t_prog+="\n" | ||
1495 | elif len(sh)==2: | ||
1496 | for j0 in range(sh[0]): | ||
1497 | for j1 in range(sh[1]): | ||
1498 | t_prog+=" arg2[%s,%s]="%(j0,j1) | ||
1499 | if len(sh_in)==0: | ||
1500 | t_prog+="%s*arg_in"%trafo[j0,j1] | ||
1501 | elif len(sh_in)==1: | ||
1502 | for i0 in range(sh_in[0]): | ||
1503 | if i0>0: t_prog+="+" | ||
1504 | t_prog+="%s*arg_in[%s]"%(trafo[j0,j1,i0],i0) | ||
1505 | elif len(sh_in)==2: | ||
1506 | for i0 in range(sh_in[0]): | ||
1507 | for i1 in range(sh_in[1]): | ||
1508 | if i0+i1>0: t_prog+="+" | ||
1509 | t_prog+="%s*arg_in[%s,%s]"%(trafo[j0,j1,i0,i1],i0,i1) | ||
1510 | t_prog+="\n" | ||
1511 | elif len(sh)==3: | ||
1512 | for j0 in range(sh[0]): | ||
1513 | for j1 in range(sh[1]): | ||
1514 | for j2 in range(sh[2]): | ||
1515 | t_prog+=" arg2[%s,%s,%s]="%(j0,j1,j2) | ||
1516 | if len(sh_in)==0: | ||
1517 | t_prog+="%s*arg_in"%trafo[j0,j1,j2] | ||
1518 | elif len(sh_in)==1: | ||
1519 | for i0 in range(sh_in[0]): | ||
1520 | if i0>0: t_prog+="+" | ||
1521 | t_prog+="%s*arg_in[%s]"%(trafo[j0,j1,j2,i0],i0) | ||
1522 | t_prog+="\n" | ||
1523 | elif len(sh)==4: | ||
1524 | for j0 in range(sh[0]): | ||
1525 | for j1 in range(sh[1]): | ||
1526 | for j2 in range(sh[2]): | ||
1527 | for j3 in range(sh[3]): | ||
1528 | t_prog+=" arg2[%s,%s,%s,%s]="%(j0,j1,j2,j3) | ||
1529 | if len(sh_in)==0: | ||
1530 | t_prog+="%s*arg_in"%trafo[j0,j1,j2,j3] | ||
1531 | t_prog+="\n" | ||
1532 | t_prog+=" dvdin=v.substitute({arg : arg2}).diff(arg_in)\n" | ||
1533 | if len(sh_in)==0: | ||
1534 | t_prog+=" res_in=dvdin.substitute({arg_in : %s})\n"%a_in | ||
1535 | else: | ||
1536 | t_prog+=" res_in=dvdin.substitute({arg : numarray.array(%s)})\n"%a_in.tolist() | ||
1537 | |||
1538 | if len(sh)==0: | ||
1539 | if len(sh_in)==0: | ||
1540 | ref_diff=(makeResult(trafo*a_in+finc,f)-makeResult(trafo*a_in,f))/finc | ||
1541 | t_prog+=" self.failUnlessAlmostEqual(dvdin,%s,self.places,\"%s-derivative: wrong derivative\")\n"%(ref_diff,str(sh_in)) | ||
1542 | elif len(sh_in)==1: | ||
1543 | s=0 | ||
1544 | for k0 in range(sh_in[0]): | ||
1545 | s+=trafo[k0]*a_in[k0] | ||
1546 | for i0 in range(sh_in[0]): | ||
1547 | ref_diff=(makeResult(s+trafo[i0]*finc,f)-makeResult(s,f))/finc | ||
1548 | t_prog+=" self.failUnlessAlmostEqual(dvdin[%s],%s,self.places,\"%s-derivative: wrong derivative with respect of %s\")\n"%(i0,ref_diff,str(sh_in),str(i0)) | ||
1549 | elif len(sh_in)==2: | ||
1550 | s=0 | ||
1551 | for k0 in range(sh_in[0]): | ||
1552 | for k1 in range(sh_in[1]): | ||
1553 | s+=trafo[k0,k1]*a_in[k0,k1] | ||
1554 | for i0 in range(sh_in[0]): | ||
1555 | for i1 in range(sh_in[1]): | ||
1556 | ref_diff=(makeResult(s+trafo[i0,i1]*finc,f)-makeResult(s,f))/finc | ||
1557 | t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s],%s,self.places,\"%s-derivative: wrong derivative with respect of %s\")\n"%(i0,i1,ref_diff,str(sh_in),str((i0,i1))) | ||
1558 | |||
1559 | elif len(sh_in)==3: | ||
1560 | s=0 | ||
1561 | for k0 in range(sh_in[0]): | ||
1562 | for k1 in range(sh_in[1]): | ||
1563 | for k2 in range(sh_in[2]): | ||
1564 | s+=trafo[k0,k1,k2]*a_in[k0,k1,k2] | ||
1565 | for i0 in range(sh_in[0]): | ||
1566 | for i1 in range(sh_in[1]): | ||
1567 | for i2 in range(sh_in[2]): | ||
1568 | ref_diff=(makeResult(s+trafo[i0,i1,i2]*finc,f)-makeResult(s,f))/finc | ||
1569 | t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative with respect of %s\")\n"%(i0,i1,i2,ref_diff,str(sh_in),str((i0,i1,i2))) | ||
1570 | elif len(sh_in)==4: | ||
1571 | s=0 | ||
1572 | for k0 in range(sh_in[0]): | ||
1573 | for k1 in range(sh_in[1]): | ||
1574 | for k2 in range(sh_in[2]): | ||
1575 | for k3 in range(sh_in[3]): | ||
1576 | s+=trafo[k0,k1,k2,k3]*a_in[k0,k1,k2,k3] | ||
1577 | for i0 in range(sh_in[0]): | ||
1578 | for i1 in range(sh_in[1]): | ||
1579 | for i2 in range(sh_in[2]): | ||
1580 | for i3 in range(sh_in[3]): | ||
1581 | ref_diff=(makeResult(s+trafo[i0,i1,i2,i3]*finc,f)-makeResult(s,f))/finc | ||
1582 | t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative with respect of %s\")\n"%(i0,i1,i2,i3,ref_diff,str(sh_in),str((i0,i1,i2,i3))) | ||
1583 | elif len(sh)==1: | ||
1584 | for j0 in range(sh[0]): | ||
1585 | if len(sh_in)==0: | ||
1586 | ref_diff=(makeResult(trafo[j0]*a_in+finc,f)-makeResult(trafo[j0]*a_in,f))/finc | ||
1587 | t_prog+=" self.failUnlessAlmostEqual(dvdin[%s],%s,self.places,\"%s-derivative: wrong derivative of %s\")\n"%(j0,ref_diff,str(sh_in),j0) | ||
1588 | elif len(sh_in)==1: | ||
1589 | s=0 | ||
1590 | for k0 in range(sh_in[0]): | ||
1591 | s+=trafo[j0,k0]*a_in[k0] | ||
1592 | for i0 in range(sh_in[0]): | ||
1593 | ref_diff=(makeResult(s+trafo[j0,i0]*finc,f)-makeResult(s,f))/finc | ||
1594 | t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s with respect of %s\")\n"%(j0,i0,ref_diff,str(sh_in),str(j0),str(i0)) | ||
1595 | elif len(sh_in)==2: | ||
1596 | s=0 | ||
1597 | for k0 in range(sh_in[0]): | ||
1598 | for k1 in range(sh_in[1]): | ||
1599 | s+=trafo[j0,k0,k1]*a_in[k0,k1] | ||
1600 | for i0 in range(sh_in[0]): | ||
1601 | for i1 in range(sh_in[1]): | ||
1602 | ref_diff=(makeResult(s+trafo[j0,i0,i1]*finc,f)-makeResult(s,f))/finc | ||
1603 | t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s with respect of %s\")\n"%(j0,i0,i1,ref_diff,str(sh_in),str(j0),str((i0,i1))) | ||
1604 | |||
1605 | elif len(sh_in)==3: | ||
1606 | |||
1607 | s=0 | ||
1608 | for k0 in range(sh_in[0]): | ||
1609 | for k1 in range(sh_in[1]): | ||
1610 | for k2 in range(sh_in[2]): | ||
1611 | s+=trafo[j0,k0,k1,k2]*a_in[k0,k1,k2] | ||
1612 | |||
1613 | for i0 in range(sh_in[0]): | ||
1614 | for i1 in range(sh_in[1]): | ||
1615 | for i2 in range(sh_in[2]): | ||
1616 | ref_diff=(makeResult(s+trafo[j0,i0,i1,i2]*finc,f)-makeResult(s,f))/finc | ||
1617 | t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s with respect of %s\")\n"%(j0,i0,i1,i2,ref_diff,str(sh_in),str(j0),str((i0,i1,i2))) | ||
1618 | elif len(sh)==2: | ||
1619 | for j0 in range(sh[0]): | ||
1620 | for j1 in range(sh[1]): | ||
1621 | if len(sh_in)==0: | ||
1622 | ref_diff=(makeResult(trafo[j0,j1]*a_in+finc,f)-makeResult(trafo[j0,j1]*a_in,f))/finc | ||
1623 | t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s\")\n"%(j0,j1,ref_diff,str(sh_in),str((j0,j1))) | ||
1624 | elif len(sh_in)==1: | ||
1625 | s=0 | ||
1626 | for k0 in range(sh_in[0]): | ||
1627 | s+=trafo[j0,j1,k0]*a_in[k0] | ||
1628 | for i0 in range(sh_in[0]): | ||
1629 | ref_diff=(makeResult(s+trafo[j0,j1,i0]*finc,f)-makeResult(s,f))/finc | ||
1630 | t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s with respect of %s\")\n"%(j0,j1,i0,ref_diff,str(sh_in),str((j0,j1)),str(i0)) | ||
1631 | |||
1632 | elif len(sh_in)==2: | ||
1633 | s=0 | ||
1634 | for k0 in range(sh_in[0]): | ||
1635 | for k1 in range(sh_in[1]): | ||
1636 | s+=trafo[j0,j1,k0,k1]*a_in[k0,k1] | ||
1637 | for i0 in range(sh_in[0]): | ||
1638 | for i1 in range(sh_in[1]): | ||
1639 | ref_diff=(makeResult(s+trafo[j0,j1,i0,i1]*finc,f)-makeResult(s,f))/finc | ||
1640 | t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s with respect of %s\")\n"%(j0,j1,i0,i1,ref_diff,str(sh_in),str((j0,j1)),str((i0,i1))) | ||
1641 | elif len(sh)==3: | ||
1642 | for j0 in range(sh[0]): | ||
1643 | for j1 in range(sh[1]): | ||
1644 | for j2 in range(sh[2]): | ||
1645 | if len(sh_in)==0: | ||
1646 | ref_diff=(makeResult(trafo[j0,j1,j2]*a_in+finc,f)-makeResult(trafo[j0,j1,j2]*a_in,f))/finc | ||
1647 | t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s\")\n"%(j0,j1,j2,ref_diff,str(sh_in),str((j0,j1,j2))) | ||
1648 | elif len(sh_in)==1: | ||
1649 | s=0 | ||
1650 | for k0 in range(sh_in[0]): | ||
1651 | s+=trafo[j0,j1,j2,k0]*a_in[k0] | ||
1652 | for i0 in range(sh_in[0]): | ||
1653 | ref_diff=(makeResult(s+trafo[j0,j1,j2,i0]*finc,f)-makeResult(s,f))/finc | ||
1654 | t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s with respect of %s\")\n"%(j0,j1,j2,i0,ref_diff,str(sh_in),str((j0,j1,j2)),i0) | ||
1655 | elif len(sh)==4: | ||
1656 | for j0 in range(sh[0]): | ||
1657 | for j1 in range(sh[1]): | ||
1658 | for j2 in range(sh[2]): | ||
1659 | for j3 in range(sh[3]): | ||
1660 | if len(sh_in)==0: | ||
1661 | ref_diff=(makeResult(trafo[j0,j1,j2,j3]*a_in+finc,f)-makeResult(trafo[j0,j1,j2,j3]*a_in,f))/finc | ||
1662 | t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s\")\n"%(j0,j1,j2,j3,ref_diff,str(sh_in),str((j0,j1,j2,j3))) | ||
1663 | gross | 157 | |
1664 | gross | 291 | # |
1665 | gross | 157 | |
1666 | #================== | ||
# Constructor tests for the rank-specific symbol classes (ScalarSymbol,
# VectorSymbol, ...).  The position of a name in this list is the rank the
# generated test expects from the corresponding symbol.
cases=["Scalar","Vector","Tensor", "Tensor3","Tensor4"]

# NOTE(review): the leading whitespace inside the emitted source lines is
# reproduced exactly as it appears in this copy of the file -- confirm it
# matches the indentation the generated test class needs.
for rank, flavour in enumerate(cases):
    for d in [ None , "d", 1, 2 , 3]:
        # dim=None is only exercised for the rank-0 (Scalar) symbol.
        if d is None and flavour != "Scalar":
            continue
        t_prog+=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
        t_prog+=" def test_Symbol_%s_d%s(self):\n"%(flavour,d)
        if d=="d":
            # Spatial dimension taken from the function space at test run
            # time; the expected shape is spelled with the generated local
            # variable "d".
            t_prog+=" s=%sSymbol(dim=self.functionspace)\n"%flavour
            t_prog+=" d=self.functionspace.getDim()\n"
            if rank==1:
                # "(d)" would be a plain int in the generated code; a
                # one-element tuple needs the trailing comma.
                sh="(d,)"
            else:
                sh="("+",".join(["d"]*rank)+")"
        else:
            # Literal dimension: the expected shape is d repeated rank times.
            t_prog+=" s=%sSymbol(dim=%s)\n"%(flavour,d)
            sh=str((d,)*rank)
        t_prog+=" self.failUnlessEqual(s.getRank(),%s,\"wrong rank.\")\n"%rank
        t_prog+=" self.failUnlessEqual(s.getShape(),%s,\"wrong shape.\")\n"%sh
        t_prog+=" self.failUnlessEqual(s.getDim(),%s,\"wrong spatial dimension.\")\n"%d
        t_prog+=" self.failUnlessEqual(s.getArgument(),[],\"wrong arguments.\")\n"

# Dump everything generated so far.
print(t_prog)
# NOTE(review): presumably a deliberate debugging stop -- the division raises
# ZeroDivisionError, so nothing after this statement is executed.
1/0
# Generate one test method per combination of symbol shape (sh), spatial
# dimension (d) and argument list (args).  Each generated method checks the
# basic accessors of Symbol, substitution, differentiation, argument
# handling and isAppropriateValue.
#
# NOTE(review): the leading whitespace inside the emitted source lines is
# reproduced exactly as it appears in this copy of the file -- confirm it
# matches the indentation the generated test class needs.
for sh in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
    for d in [ None , "domain", 1, 2 , 3]:
        for args in [ [], ["s2"], [1,-1.] ]:
            t_prog+=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
            t_prog+=" def test_Symbol_rank%s_d%s_nargs%s(self):\n"%(len(sh),d,len(args))
            t_prog+=" s2=ScalarSymbol()\n"

            # "s2" must appear as a name (not a quoted string) in the
            # generated code, so it refers to the symbol created above.
            if args==["s2"]:
                a="[s2]"
            else:
                a=str(args)

            # d2 is the source text that evaluates to the spatial dimension
            # inside the generated test.
            if d=="domain":
                d2="self.functionspace.getDim()"
                t_prog+=" s=Symbol(shape=%s,dim=self.functionspace.getDim(),args=%s)\n"%(str(sh),a)
            else:
                d2=str(d)
                t_prog+=" s=Symbol(shape=%s,dim=%s,args=%s)\n"%(sh,d,a)

            # --- basic accessors and substitution -------------------------
            t_prog+=" self.failUnlessEqual(s.getRank(),%s,\"wrong rank.\")\n"%len(sh)
            t_prog+=" self.failUnlessEqual(s.getShape(),%s,\"wrong shape.\")\n"%str(sh)
            t_prog+=" self.failUnlessEqual(s.getDim(),%s,\"wrong spatial dimension.\")\n"%d2
            t_prog+=" self.failUnlessEqual(s.getArgument(),%s,\"wrong arguments.\")\n\n"%a
            t_prog+=" ss=s.substitute({s:numarray.zeros(%s)})\n"%str(sh)
            t_prog+=" self.failUnless(isinstance(ss,numarray.NumArray),\"value after substitution is not numarray.\")\n"
            t_prog+=" self.failUnlessEqual(ss.shape,%s,\"value after substitution has not expected shape\")\n"%str(sh)

            # The generated test must call self.fail(); a bare fail() is an
            # undefined name inside a TestCase method.
            t_prog+=" try:\n s.substitute({s:numarray.zeros((5,))})\n self.fail(\"illegal substition was successful\")\n"
            t_prog+=" except TypeError:\n pass\n\n"

            # --- derivative with respect to an unrelated symbol -----------
            for sh2 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
                if len(sh+sh2)<5:
                    t_prog+=" dsdarg=s.diff(Symbol(shape=%s))\n"%str(sh2)
                    if len(sh+sh2)==0:
                        t_prog+=" self.failUnless(isinstance(dsdarg,float),\"ds/ds() has wrong type.\")\n"
                    else:
                        t_prog+=" self.failUnless(isinstance(dsdarg,numarray.NumArray),\"ds/ds%s has wrong type.\")\n"%str(sh2)
                        # A float result has no .shape attribute, so the
                        # shape check is emitted for array results only.
                        t_prog+=" self.failUnlessEqual(dsdarg.shape,%s,\"ds/ds%s has wrong shape.\")\n"%(str(sh+sh2),str(sh2))
                    t_prog+=" self.failIf(Lsup(dsdarg)>0.,\"ds/ds%s has wrong value.\")\n"%str(sh2)

            # --- derivative with respect to itself (rank < 3 only) --------
            if len(sh)<3:
                t_prog+="\n dsds=s.diff(s)\n"
                if len(sh)==0:
                    t_prog+=" self.failUnless(isinstance(dsds,float),\"ds/ds has wrong type.\")\n"
                    t_prog+=" self.failUnlessEqual(dsds,1.,\"ds/ds has wrong value.\")\n"
                else:
                    t_prog+=" self.failUnless(isinstance(dsds,numarray.NumArray),\"ds/ds has wrong type.\")\n"
                    t_prog+=" self.failUnlessEqual(dsds.shape,%s,\"ds/ds has wrong shape.\")\n"%str(sh+sh)
                    # Expected value is 1 where the left and right index
                    # groups coincide and 0 elsewhere.
                    if len(sh)==1:
                        for i0 in range(sh[0]):
                            for i2 in range(sh[0]):
                                if i0==i2:
                                    v=1.
                                else:
                                    v=0.
                                t_prog+=" self.failUnlessEqual(dsds[%s,%s],%s,\"ds/ds has wrong value at (%s,%s).\")\n"%(i0,i2,v,i0,i2)
                    else:
                        for i0 in range(sh[0]):
                            for i1 in range(sh[1]):
                                for i2 in range(sh[0]):
                                    for i3 in range(sh[1]):
                                        if i0==i2 and i1==i3:
                                            v=1.
                                        else:
                                            v=0.
                                        t_prog+=" self.failUnlessEqual(dsds[%s,%s,%s,%s],%s,\"ds/ds has wrong value at (%s,%s,%s,%s).\")\n"%(i0,i1,i2,i3,v,i0,i1,i2,i3)

            # --- arguments and substituted arguments ----------------------
            t_prog+="\n"
            for i in range(len(args)):
                t_prog+=" self.failUnlessEqual(s.getArgument(%s),%s,\"wrong argument %s.\")\n"%(i,str(args[i]),i)
            t_prog+=" sa=s.getSubstitutedArguments({s2:-10})\n"
            t_prog+=" self.failUnlessEqual(len(sa),%s,\"wrong number of substituted arguments\")\n"%len(args)
            if args==["s2"]:
                # s2 is replaced by the value given in the substitution map.
                t_prog+=" self.failUnlessEqual(sa[0],-10,\"wrongly substituted argument 0.\")\n"
            else:
                for i in range(len(args)):
                    t_prog+=" self.failUnlessEqual(sa[%s],%s,\"wrongly substituted argument %s.\")\n"%(i,str(args[i]),i)

            # --- isAppropriateValue ---------------------------------------
            t_prog+="\n"
            for arg in ["10.", "10", "SymbolMatch", "SymbolMisMatch", \
                        "DataMatch","DataMisMatch", "NumArrayMatch", "NumArrayMisMatch"]:
                if arg in ["10.", "10"]:
                    # A plain number is accepted by a rank-0 symbol only.
                    a=str(arg)
                    if len(sh)==0:
                        t_prog+=" self.failUnless(s.isAppropriateValue(%s),\"%s is appropriate substitute\")\n"%(a,arg)
                    else:
                        t_prog+=" self.failIf(s.isAppropriateValue(%s),\" %s is not appropriate substitute\")\n"%(a,arg)
                elif arg in ["SymbolMatch", "SymbolMisMatch"]:
                    if arg=="SymbolMatch":
                        # Use d2 rather than d: for d=="domain" the raw value
                        # would emit the undefined name "domain".
                        t_prog+=" self.failUnless(s.isAppropriateValue(Symbol(shape=%s,dim=%s)),\"Symbol is appropriate substitute\")\n"%(str(sh),d2)
                    else:
                        if isinstance(d,int):
                            t_prog+=" self.failIf(s.isAppropriateValue(Symbol(shape=%s,dim=%s)),\"Symbol is not appropriate substitute (dim)\")\n"%(str(sh),d+1)
                        else:
                            # str() is required: "%s"%((5,)) unpacks the
                            # tuple and would emit "5" instead of "(5,)".
                            t_prog+=" self.failIf(s.isAppropriateValue(Symbol(shape=%s)),\"Symbol is not appropriate substitute (shape)\")\n"%str((5,))
                elif arg in ["DataMatch","DataMisMatch"]:
                    # A matching Data object is only emitted when the symbol's
                    # dimension comes from the function space.
                    if arg=="DataMatch" and d=="domain":
                        t_prog+=" self.failUnless(s.isAppropriateValue(escript.Data(0.,%s,self.functionspace)),\"Data is appropriate substitute\")\n"%str(sh)
                    elif arg=="DataMisMatch":
                        t_prog+=" self.failIf(s.isAppropriateValue(escript.Data(0.,%s,self.functionspace)),\"Data is not appropriate substitute (shape)\")\n"%(str((5,)))
                else:
                    if arg=="NumArrayMatch":
                        t_prog+=" self.failUnless(s.isAppropriateValue(numarray.zeros(%s)),\"NumArray is appropriate substitute\")\n"%str(sh)
                    else:
                        t_prog+=" self.failIf(s.isAppropriateValue(numarray.zeros(%s)),\"NumArray is not appropriate substitute (shape)\")\n"%(str((5,)))

# Dump everything generated so far.
print(t_prog)
# NOTE(review): presumably a deliberate debugging stop -- the division raises
# ZeroDivisionError, so nothing after this statement is executed.
1/0
1805 | |||
1806 | |||
1807 | for case in ["Lsup", "sup", "inf"]: | ||
1808 | for args in ["float","array","constData","taggedData","expandedData"]: | ||
1809 | for sh in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]: | ||
1810 | if not args=="float" or len(sh)==0: | ||
1811 | t_prog+=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n" | ||
1812 | tname="def test_%s_%s_rank%s"%(case,args,len(sh)) | ||
1813 | t_prog+=" %s(self):\n"%tname | ||
1814 | if args in ["float","array" ]: | ||
1815 | a=makeArray(sh,[-1,1]) | ||
1816 | r=makeResult2(a,case) | ||
1817 | if len(sh)==0: | ||
1818 | t_prog+=" arg=%s\n"%a | ||
1819 | else: | ||
1820 | t_prog+=" arg=numarray.array(%s)\n"%a.tolist() | ||
1821 | t_prog+=" ref=%s\n"%r | ||
1822 | t_prog+=" res=%s(%a1%)\n"%case | ||
1823 | t_prog+=" self.failUnless(isinstance(res,float),\"wrong type of result.\")\n&q |