/[escript]/trunk/escript/py_src/generateutil
ViewVC logotype

Annotation of /trunk/escript/py_src/generateutil

Parent Directory | Revision Log


Revision 433 - (hide annotations)
Tue Jan 17 23:54:38 2006 UTC (13 years, 9 months ago) by gross
File size: 102994 byte(s)
new function inverse and tests added
#!/usr/bin/python
# $Id$

"""
Generate parts of util.py and of the test_util.py script.

Generated test methods are accumulated in the module-level strings defined
below and printed together with test_header/test_tail.
"""

# Fixed text opening the generated test module: imports, the test class and
# its setUp method.
test_header="".join([
    "import unittest\n",
    "import numarray\n",
    "from esys.escript import *\n",
    "from esys.finley import Rectangle\n",
    "class Test_util2(unittest.TestCase):\n",
    " RES_TOL=1.e-7\n",
    " def setUp(self):\n",
    " self.__dom =Rectangle(11,11,2)\n",
    " self.functionspace = FunctionOnBoundary(self.__dom)\n",
])

# Fixed text closing the generated test module: builds and runs the suite.
test_tail="".join([
    "suite = unittest.TestSuite()\n",
    "suite.addTest(unittest.makeSuite(Test_util2))\n",
    "unittest.TextTestRunner(verbosity=2).run(suite)\n",
])

# argument kinds and argument shapes the generators iterate over
case_set=["float","array","constData","taggedData","expandedData","Symbol"]
shape_set=[(),(2,),(4,5),(6,2,2),(3,2,3,4)]

# accumulators for generated source text
t_prog=""            # generated tests whose argument is not taggedData
t_prog_with_tags=""  # generated tests whose argument is taggedData
t_prog_failing=""    # not used in the part of the script shown here
u_prog=""            # presumably generated util.py code -- TODO confirm
29 jgs 154
def wherepos(arg):
    """
    Return 1. if arg is strictly positive and 0. otherwise (0. included).
    """
    if not arg>0.:
        return 0.
    return 1.
35    
36 gross 313
class OPERATOR:
    """
    Description of one operator handled by the generator.

    Constructor arguments are stored as attributes of the same name, except
    that math_expr defaults to test_expr with "%a1%" replaced by "arg" and
    rng defaults to a new [-1000.,1000] list.
    """
    def __init__(self,nickname,rng=None,test_expr="",math_expr=None,
                 numarray_expr="",symbol_expr=None,diff=None,name=""):
        """
        :param nickname: short identifier of the operator
        :param rng: two-element range list; presumably the interval test
                    arguments are drawn from (TODO confirm at call sites).
                    Defaults to a fresh [-1000.,1000] for every instance: a
                    list default would be one object shared by all
                    instances, so mutating one instance's range would leak.
        :param test_expr: expression template, "%a1%" marks the argument
        :param math_expr: expression text; derived from test_expr if None
        :param numarray_expr: stored as given
        :param symbol_expr: stored as given
        :param diff: stored as given
        :param name: stored as given
        """
        if rng is None:
            rng=[-1000.,1000]
        self.nickname=nickname
        self.rng=rng
        self.test_expr=test_expr
        if math_expr is None:
            self.math_expr=test_expr.replace("%a1%","arg")
        else:
            self.math_expr=math_expr
        self.numarray_expr=numarray_expr
        self.symbol_expr=symbol_expr
        self.diff=diff
        self.name=name
51    
52 jgs 154 import random
53     import numarray
54     import math
# Small increment constant; presumably the step used for finite-difference
# checks by code outside this view (TODO confirm).
finc=1.0e-6
56    
def getResultCaseForBin(case0,case1):
    """
    Return the argument kind of the result of a binary operation.

    The kinds are ranked float < array < constData < taggedData <
    expandedData < Symbol and the result is the higher-ranked of the two.

    :raise ValueError: if case0 (checked first) or case1 is not a known kind
    """
    ranking=["float","array","constData","taggedData","expandedData","Symbol"]
    if case0 not in ranking:
        raise ValueError("unknown case0=%s"%case0)
    if case1 not in ranking:
        raise ValueError("unknown case1=%s"%case1)
    return ranking[max(ranking.index(case0),ranking.index(case1))]
151 jgs 154
152 gross 157
def makeArray(shape,rng):
    """
    Return random test data with values drawn uniformly from [rng[0], rng[1]).

    For shape () a Python float is returned; otherwise a numarray.Float64
    array of the given shape, filled with the last index running fastest.

    :raise SystemError: if len(shape) > 4
    """
    width=rng[1]-rng[0]
    low=rng[0]
    out=numarray.zeros(shape,numarray.Float64)
    rank=len(shape)
    if rank==0:
        return width*random.random()+low
    if rank>4:
        raise SystemError("rank is restricted to 4")

    def fill(index):
        # depth-first walk; visiting order equals that of nested for-loops,
        # so the sequence of random draws is unchanged
        depth=len(index)
        if depth==rank:
            out[index]=width*random.random()+low
            return
        for k in range(shape[depth]):
            fill(index+(k,))

    fill(())
    return out
179    
180    
def makeResult(val,test_expr):
    """
    Apply the scalar reference expression test_expr element-wise to val.

    test_expr is Python source in which "%a1%" marks the argument, e.g.
    "abs(%a1%)"; it is evaluated with this module's globals, so module-level
    helpers and imported modules are visible to it.

    :param val: a Python float or an array of any rank
    :return: the value of the expression if val is a float or a rank-0 array,
             otherwise a new numarray.Float64 array of val's shape
    """
    # parse the expression once rather than rebuilding and re-parsing the
    # source string for every element
    code=compile(test_expr.replace("%a1%","val"),"<test_expr>","eval")
    env=globals()
    if isinstance(val,float) or len(val.shape)==0:
        return eval(code,env,{"val":val})

    def indices(shape):
        # all index tuples of `shape`, last index running fastest
        if len(shape)==0:
            yield ()
            return
        for head in range(shape[0]):
            for tail in indices(shape[1:]):
                yield (head,)+tail

    out=numarray.zeros(val.shape,numarray.Float64)
    for idx in indices(tuple(val.shape)):
        out[idx]=eval(code,env,{"val":val[idx]})
    return out
211    
def makeResult2(val0,val1,test_expr):
    """
    Apply the binary reference expression test_expr element-wise to val0, val1.

    test_expr is Python source in which "%a1%" and "%a2%" mark the two
    arguments; it is evaluated with this module's globals. Each argument is a
    Python float or an array. The lower-rank argument is indexed by the
    leading indices of the result, e.g. for rank(val0)=1 and rank(val1)=2
    entry [i,j] is computed from val0[i] and val1[i,j].

    :return: the value of the expression if both arguments are floats or
             rank-0 arrays; otherwise a new numarray.Float64 array of shape
             val0.shape+val1.shape[rank(val0):] when val1 has the higher rank
             and val0.shape otherwise
    """
    # parse once instead of once per element
    code=compile(test_expr.replace("%a1%","val0").replace("%a2%","val1"),
                 "<test_expr>","eval")
    env=globals()

    def shape_of(v):
        # Python floats act as rank-0 arguments
        if isinstance(v,float):
            return ()
        return tuple(v.shape)

    def entry(v,idx,rank):
        # rank-0 arguments are used as they are, others by leading indices
        if rank==0:
            return v
        return v[idx[:rank]]

    def indices(shape):
        # all index tuples of `shape`, last index running fastest
        if len(shape)==0:
            yield ()
            return
        for head in range(shape[0]):
            for tail in indices(shape[1:]):
                yield (head,)+tail

    shape0=shape_of(val0)
    shape1=shape_of(val1)
    r0=len(shape0)
    r1=len(shape1)
    if r1>r0:
        out_shape=shape0+shape1[r0:]
    else:
        out_shape=shape0
    if len(out_shape)==0:
        return eval(code,env,{"val0":val0,"val1":val1})
    out=numarray.zeros(out_shape,numarray.Float64)
    for idx in indices(out_shape):
        out[idx]=eval(code,env,{"val0":entry(val0,idx,r0),
                                "val1":entry(val1,idx,r1)})
    return out
428 jgs 154
429 gross 157
def mkText(case,name,a,a1=None,use_tagging_for_expanded_data=False):
    """
    Return test source that defines the variable `name` for argument kind `case`.

    a is the value (a Python float or an array); a1 is the second value used
    by the tagged and expanded kinds. For expandedData the values are either
    set through a tag and expanded (use_tagging_for_expanded_data) or mixed
    through a mask on the x-coordinate. Unknown kinds yield an empty string.
    """
    def scalar():
        # evaluated lazily so unknown kinds never touch a.rank
        return isinstance(a,float) or a.rank==0

    def tagged():
        # Data object holding a, with the value a1 set for tag 1
        if scalar():
            return (" %s=Data(%s,self.functionspace)\n"%(name,a)
                    +" %s.setTaggedValue(1,%s)\n"%(name,a1))
        return (" %s=Data(numarray.array(%s),self.functionspace)\n"%(name,a.tolist())
                +" %s.setTaggedValue(1,numarray.array(%s))\n"%(name,a1.tolist()))

    if case=="float":
        if scalar():
            return " %s=%s\n"%(name,a)
        return " %s=numarray.array(%s)\n"%(name,a.tolist())
    if case=="array":
        if scalar():
            return " %s=numarray.array(%s)\n"%(name,a)
        return " %s=numarray.array(%s)\n"%(name,a.tolist())
    if case=="constData":
        if scalar():
            return " %s=Data(%s,self.functionspace)\n"%(name,a)
        return " %s=Data(numarray.array(%s),self.functionspace)\n"%(name,a.tolist())
    if case=="taggedData":
        return tagged()
    if case=="expandedData":
        if use_tagging_for_expanded_data:
            return tagged()+" %s.expand()\n"%name
        mask=" msk_%s=whereNegative(self.functionspace.getX()[0]-0.5)\n"%name
        if isinstance(a,float):
            return mask+" %s=msk_%s*(%s)+(1.-msk_%s)*(%s)\n"%(name,name,a,name,a1)
        if a.rank==0:
            return mask+" %s=msk_%s*numarray.array(%s)+(1.-msk_%s)*numarray.array(%s)\n"%(name,name,a,name,a1)
        return mask+" %s=msk_%s*numarray.array(%s)+(1.-msk_%s)*numarray.array(%s)\n"%(name,name,a.tolist(),name,a1.tolist())
    if case=="Symbol":
        if scalar():
            return " %s=Symbol(shape=())\n"%(name)
        return " %s=Symbol(shape=%s)\n"%(name,str(a.shape))
    return ""
492    
493 gross 157 def mkTypeAndShapeTest(case,sh,argstr):
494     text=""
495     if case=="float":
496     text+=" self.failUnless(isinstance(%s,float),\"wrong type of result.\")\n"%argstr
497     elif case=="array":
498     text+=" self.failUnless(isinstance(%s,numarray.NumArray),\"wrong type of result.\")\n"%argstr
499     text+=" self.failUnlessEqual(%s.shape,%s,\"wrong shape of result.\")\n"%(argstr,str(sh))
500     elif case in ["constData","taggedData","expandedData"]:
501     text+=" self.failUnless(isinstance(%s,Data),\"wrong type of result.\")\n"%argstr
502     text+=" self.failUnlessEqual(%s.getShape(),%s,\"wrong shape of result.\")\n"%(argstr,str(sh))
503     elif case=="Symbol":
504     text+=" self.failUnless(isinstance(%s,Symbol),\"wrong type of result.\")\n"%argstr
505     text+=" self.failUnlessEqual(%s.getShape(),%s,\"wrong shape of result.\")\n"%(argstr,str(sh))
506     return text
507 jgs 154
def mkCode(txt,args=(),intend=""):
    """
    Turn an expression or a multi-line code template into indented code.

    A single-line txt becomes "<intend>return <txt>\\n"; a multi-line txt is
    copied line by line with intend prepended to each line. Afterwards the
    placeholders "%a1%", "%a2%", ... are replaced by args[0], args[1], ...

    :param txt: expression or code template
    :param args: replacement strings for the placeholders (default: none)
    :param intend: indentation prefix (name kept for existing callers)
    :return: the resulting source text, each line terminated by "\\n"
    """
    lines=txt.split("\n")
    if len(lines)>1:
        out="".join([intend+line+"\n" for line in lines])
    else:
        out="%sreturn %s\n"%(intend,txt)
    # placeholder number c receives the c-th argument; the counter must
    # advance, otherwise every argument targets "%a1%" only
    c=1
    for r in args:
        out=out.replace("%%a%s%%"%c,r)
        c+=1
    return out
520 jgs 154
def innerTEST(arg0,arg1):
    """
    Reference inner product: the sum over the element-wise product of arg0
    and arg1; when arg0 is a Python float, numarray.array(arg0*arg1).
    """
    if not isinstance(arg0,float):
        return (arg0*arg1).sum()
    return numarray.array(arg0*arg1)
527    
def outerTEST(arg0,arg1):
    """
    Reference outer product of arg0 and arg1.

    If either argument is a Python float, returns numarray.array(arg0*arg1);
    otherwise numarray.outerproduct reshaped (via resize) to
    arg0.shape+arg1.shape. Assumes numarray's resize returns the array.
    """
    has_scalar=isinstance(arg0,float) or isinstance(arg1,float)
    if has_scalar:
        return numarray.array(arg0*arg1)
    return numarray.outerproduct(arg0,arg1).resize(arg0.shape+arg1.shape)
536    
def tensorProductTest(arg0,arg1,sh_s):
    """
    Reference generalized tensor product.

    Contracts the trailing axes of arg0 (of shape sh_s) with the leading
    axes of arg1 (of shape sh_s). For empty sh_s this is the outer product;
    if either argument is a Python float, numarray.array(arg0*arg1).
    """
    if isinstance(arg0,float) or isinstance(arg1,float):
        return numarray.array(arg0*arg1)
    if len(sh_s)==0:
        return numarray.outerproduct(arg0,arg1).resize(arg0.shape+arg1.shape)

    def size(dims):
        # number of entries of an array with shape `dims`
        total=1
        for d in dims:
            total*=d
        return total

    n_s=len(sh_s)
    sh0=arg0.shape[:arg0.rank-n_s]   # free axes of arg0
    sh1=arg1.shape[n_s:]             # free axes of arg1
    ls=size(sh_s)
    l0=size(sh0)
    l1=size(sh1)
    # full outer product viewed as (free0, contracted, contracted, free1);
    # the contraction is the sum over its diagonal in the middle two axes
    prod=numarray.outerproduct(arg0,arg1).resize((l0,ls,ls,l1))
    acc=numarray.zeros((l0,l1),numarray.Float)
    for i0 in range(l0):
        for i1 in range(l1):
            for k in range(ls):
                acc[i0,i1]+=prod[i0,k,k,i1]
    return acc.resize(sh0+sh1)
559    
def testMatrixMult(arg0,arg1,sh_s):
    """
    Reference matrix product of arg0 and arg1 via numarray.matrixmultiply.

    sh_s is not used; presumably it is accepted so that all reference
    product functions share one call signature.
    """
    product=numarray.matrixmultiply(arg0,arg1)
    return product
562    
563    
def testTensorMult(arg0,arg1,sh_s):
    """
    Reference tensor product of arg0 and arg1; sh_s is not used.

    A rank-2 arg0 is multiplied as a matrix (numarray.matrixmultiply). A
    rank-4 arg0 is contracted over its last two axes with the first two axes
    of arg1 (rank 2, 3 or 4), e.g. for rank-4 arg1
    out[i0,i1,j2,j3] = sum over i2,i3 of arg0[i0,i1,i2,i3]*arg1[i2,i3,j2,j3].

    :raise ValueError: for any other combination of ranks
    """
    # dispatch on the rank: len(arg0) is the length of the first axis, so a
    # len(arg0)==2 test would misroute e.g. a (2,2,3,3) tensor and crash on
    # a (3,3) matrix
    if arg0.rank==2:
        return numarray.matrixmultiply(arg0,arg1)
    if arg0.rank!=4 or arg1.rank not in (2,3,4):
        raise ValueError("unsupported ranks %s and %s"%(arg0.rank,arg1.rank))
    n0,n1,n2,n3=arg0.shape
    if arg1.rank==4:
        out=numarray.zeros((n0,n1,arg1.shape[2],arg1.shape[3]),numarray.Float)
        for i0 in range(n0):
            for i1 in range(n1):
                for i2 in range(n2):
                    for i3 in range(n3):
                        for j2 in range(arg1.shape[2]):
                            for j3 in range(arg1.shape[3]):
                                out[i0,i1,j2,j3]+=arg0[i0,i1,i2,i3]*arg1[i2,i3,j2,j3]
    elif arg1.rank==3:
        out=numarray.zeros((n0,n1,arg1.shape[2]),numarray.Float)
        for i0 in range(n0):
            for i1 in range(n1):
                for i2 in range(n2):
                    for i3 in range(n3):
                        for j2 in range(arg1.shape[2]):
                            out[i0,i1,j2]+=arg0[i0,i1,i2,i3]*arg1[i2,i3,j2]
    else:
        out=numarray.zeros((n0,n1),numarray.Float)
        for i0 in range(n0):
            for i1 in range(n1):
                for i2 in range(n2):
                    for i3 in range(n3):
                        out[i0,i1]+=arg0[i0,i1,i2,i3]*arg1[i2,i3]
    return out
593 gross 313
594     def testReduce(arg0,init_val,test_expr,post_expr):
595     out=init_val
596     if isinstance(arg0,float):
597     out=eval(test_expr.replace("%a1%","arg0"))
598     elif arg0.rank==0:
599     out=eval(test_expr.replace("%a1%","arg0"))
600     elif arg0.rank==1:
601     for i0 in range(arg0.shape[0]):
602     out=eval(test_expr.replace("%a1%","arg0[i0]"))
603     elif arg0.rank==2:
604     for i0 in range(arg0.shape[0]):
605     for i1 in range(arg0.shape[1]):
606     out=eval(test_expr.replace("%a1%","arg0[i0,i1]"))
607     elif arg0.rank==3:
608     for i0 in range(arg0.shape[0]):
609     for i1 in range(arg0.shape[1]):
610     for i2 in range(arg0.shape[2]):
611     out=eval(test_expr.replace("%a1%","arg0[i0,i1,i2]"))
612     elif arg0.rank==4:
613     for i0 in range(arg0.shape[0]):
614     for i1 in range(arg0.shape[1]):
615     for i2 in range(arg0.shape[2]):
616     for i3 in range(arg0.shape[3]):
617     out=eval(test_expr.replace("%a1%","arg0[i0,i1,i2,i3]"))
618     return eval(post_expr)
619 gross 396
def clipTEST(arg0,mn,mx):
    """
    Reference clipping: every entry of arg0 limited to the interval [mn, mx].

    :param arg0: a Python float or an array of rank 0 to 4
    :return: the clipped value for a float or rank-0 arg0, otherwise a new
             numarray.Float64 array of arg0's shape
    :raise SystemError: if the rank of arg0 exceeds 4
    """
    # rank-0 arrays are clipped like floats; previously they matched no
    # branch below and came back as an untouched zero array
    if isinstance(arg0,float) or arg0.rank==0:
        return max(min(arg0,mx),mn)
    if arg0.rank>4:
        # previously an all-zero array was returned silently
        raise SystemError("rank is restricted to 4")
    out=numarray.zeros(arg0.shape,numarray.Float64)
    if arg0.rank==1:
        for i0 in range(arg0.shape[0]):
            out[i0]=max(min(arg0[i0],mx),mn)
    elif arg0.rank==2:
        for i0 in range(arg0.shape[0]):
            for i1 in range(arg0.shape[1]):
                out[i0,i1]=max(min(arg0[i0,i1],mx),mn)
    elif arg0.rank==3:
        for i0 in range(arg0.shape[0]):
            for i1 in range(arg0.shape[1]):
                for i2 in range(arg0.shape[2]):
                    out[i0,i1,i2]=max(min(arg0[i0,i1,i2],mx),mn)
    else:
        for i0 in range(arg0.shape[0]):
            for i1 in range(arg0.shape[1]):
                for i2 in range(arg0.shape[2]):
                    for i3 in range(arg0.shape[3]):
                        out[i0,i1,i2,i3]=max(min(arg0[i0,i1,i2,i3],mx),mn)
    return out
def minimumTEST(arg0,arg1):
    """
    Reference element-wise minimum of arg0 and arg1.

    A Python float is broadcast to the shape of the other argument. Where
    arg0's entry is not greater than arg1's (ties, NaN) arg0's entry is
    taken. For array results a new numarray.Float64 array is returned; ranks
    above 4 yield an all-zero array.
    """
    def smaller(x,y):
        if x>y:
            return y
        return x

    first_is_float=isinstance(arg0,float)
    if first_is_float and isinstance(arg1,float):
        return smaller(arg0,arg1)
    if first_is_float:
        arg0=numarray.ones(arg1.shape)*arg0
    elif isinstance(arg1,float):
        arg1=numarray.ones(arg0.shape)*arg1

    out=numarray.zeros(arg0.shape,numarray.Float64)
    rank=arg0.rank
    if rank==0:
        return smaller(arg0,arg1)
    if rank>4:
        return out

    def visit(index):
        # depth-first walk over all index tuples of arg0
        if len(index)==rank:
            out[index]=smaller(arg0[index],arg1[index])
            return
        for k in range(arg0.shape[len(index)]):
            visit(index+(k,))

    visit(())
    return out
#=======================================================================================================
# inverse
#=======================================================================================================
# For every argument kind and matrix size, generate a test checking that
# matrixmult(inverse(arg),arg) is the identity. The diagonal is shifted by
# +2 (+3 for the tagged value) to keep the random matrices away from
# singularity.
name="inverse"
for case0 in ["array","Symbol","constData","taggedData","expandedData"]:
    for sh0 in [ (1,1), (2,2), (3,3)]:
        text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
        tname="test_%s_%s_dim%s"%(name,case0,sh0[0])
        text+=" def %s(self):\n"%tname
        a_0=makeArray(sh0,[-1.,1])
        for i in range(sh0[0]): a_0[i,i]+=2
        if case0 in ["taggedData", "expandedData"]:
            a1_0=makeArray(sh0,[-1.,1])
            for i in range(sh0[0]): a1_0[i,i]+=3
        else:
            a1_0=a_0

        text+=mkText(case0,"arg",a_0,a1_0)
        text+=" res=%s(arg)\n"%name
        if case0=="Symbol":
            text+=mkText("array","s",a_0,a1_0)
            text+=" sub=res.substitute({arg:s})\n"
            res="sub"
            ref="s"
        else:
            ref="arg"
            res="res"
        text+=mkTypeAndShapeTest(case0,sh0,"res")
        text+=" self.failUnless(Lsup(matrixmult(%s,%s)-kronecker(%s))<=self.RES_TOL,\"wrong result\")\n"%(res,ref,sh0[0])

        if case0 == "taggedData" :
            t_prog_with_tags+=text
        else:
            t_prog+=text

print(test_header)
# print(t_prog)   # debugging toggle: only the tagged tests are emitted now
print(t_prog_with_tags)
print(test_tail)
# Deliberate stop after this section (was `1/0`, which aborted with a
# ZeroDivisionError traceback); exit cleanly instead.
raise SystemExit(0)
732    
#=======================================================================================================
# trace
#=======================================================================================================
def traceTest(r,offset):
    """
    Reference trace of r over the axis pair (offset, offset+1).

    Assumes both axes have length r.shape[offset]; the result has the shape
    of r with those two axes removed (reshaped via numarray's resize).
    """
    sh=r.shape
    n=sh[offset]
    before=1
    for d in sh[:offset]:
        before*=d
    after=1
    for d in sh[offset+2:]:
        after*=d
    # view r as (leading, n, n, trailing) and sum the diagonal of the middle pair
    blocks=numarray.reshape(r,(before,n,n,after))
    s=numarray.zeros([before,after],numarray.Float)
    for i1 in range(before):
        for i2 in range(after):
            for j in range(n):
                s[i1,i2]+=blocks[i1,j,j,i2]
    return s.resize(sh[:offset]+sh[offset+2:])
# For every argument kind, rank and admissible offset, generate a test that
# compares trace(arg,offset) with the traceTest reference values.
name,tt="trace",traceTest
for case0 in ["array","Symbol","constData","taggedData","expandedData"]:
    for sh0 in [ (4,5), (6,2,2),(3,2,3,4)]:
        for offset in range(len(sh0)-1):
            text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
            tname="test_%s_%s_rank%s_offset%s"%(name,case0,len(sh0),offset)
            text+=" def %s(self):\n"%tname
            # make the traced axis pair square
            sh_t=list(sh0)
            sh_t[offset+1]=sh_t[offset]
            sh_t=tuple(sh_t)
            # expected result shape: sh0 without the traced pair
            sh_r=[]
            for i in range(offset): sh_r.append(sh0[i])
            for i in range(offset+2,len(sh0)): sh_r.append(sh0[i])
            sh_r=tuple(sh_r)
            a_0=makeArray(sh_t,[-1.,1])
            if case0 in ["taggedData", "expandedData"]:
                a1_0=makeArray(sh_t,[-1.,1])
            else:
                a1_0=a_0
            r=tt(a_0,offset)
            r1=tt(a1_0,offset)
            text+=mkText(case0,"arg",a_0,a1_0)
            text+=" res=%s(arg,%s)\n"%(name,offset)
            if case0=="Symbol":
                text+=mkText("array","s",a_0,a1_0)
                text+=" sub=res.substitute({arg:s})\n"
                res="sub"
                text+=mkText("array","ref",r,r1)
            else:
                res="res"
                text+=mkText(case0,"ref",r,r1)
            text+=mkTypeAndShapeTest(case0,sh_r,"res")
            text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res

            if case0 == "taggedData" :
                t_prog_with_tags+=text
            else:
                t_prog+=text

print(test_header)
# print(t_prog)   # debugging toggle: only the tagged tests are emitted now
print(t_prog_with_tags)
print(test_tail)
# Deliberate stop after this section (was `1/0`, which aborted with a
# ZeroDivisionError traceback); exit cleanly instead.
raise SystemExit(0)
792 gross 396
#=======================================================================================================
# clip: tests for clipping to the interval [-0.3,0.5]
#=======================================================================================================
oper_L=[["clip",clipTEST]]
for oper in oper_L:
    op_name,op_reference=oper
    for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
        for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
            # a float argument only exists for rank 0
            if case0=="float" and len(sh0)>0:
                continue
            a_0=makeArray(sh0,[-1.,1])
            a1_0=a_0
            if case0 in ["taggedData", "expandedData"]:
                # tagged/expanded data carry a second, independent set of values
                a1_0=makeArray(sh0,[-1.,1])
            r=op_reference(a_0,-0.3,0.5)
            r1=op_reference(a1_0,-0.3,0.5)

            tname="test_%s_%s_rank%s"%(op_name,case0,len(sh0))
            text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"+" def %s(self):\n"%tname
            text+=mkText(case0,"arg",a_0,a1_0)
            text+=" res=%s(arg,-0.3,0.5)\n"%op_name
            if case0!="Symbol":
                res="res"
                text+=mkText(case0,"ref",r,r1)
            else:
                # a symbolic result is checked after substituting a concrete array
                text+=mkText("array","s",a_0,a1_0)
                text+=" sub=res.substitute({arg:s})\n"
                res="sub"
                text+=mkText("array","ref",r,r1)
            text+=mkTypeAndShapeTest(case0,sh0,"res")
            text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res

            if case0 == "taggedData":
                t_prog_with_tags+=text
            else:
                t_prog+=text

print(test_header)
# print(t_prog)
print(t_prog_with_tags)
print(test_tail)
# deliberate halt: the ZeroDivisionError ends the script here
1/0

#=======================================================================================================
# maximum, minimum (clipping is generated in its own section)
#=======================================================================================================
oper_L=[ ["maximum",maximumTEST],
         ["minimum",minimumTEST]]
for oper in oper_L:
    for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
        for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
            for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:
                for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
                    # float arguments only exist for rank 0
                    if case0=="float" and len(sh0)>0: continue
                    if case1=="float" and len(sh1)>0: continue
                    # both shapes must agree unless one argument is a scalar
                    if not (sh0==sh1 or len(sh0)==0 or len(sh1)==0): continue

                    use_tagging_for_expanded_data= case0=="taggedData" or case1=="taggedData"
                    tname="test_%s_%s_rank%s_%s_rank%s"%(oper[0],case0,len(sh0),case1,len(sh1))
                    a_0=makeArray(sh0,[-1.,1])
                    a1_0=a_0
                    if case0 in ["taggedData", "expandedData"]:
                        a1_0=makeArray(sh0,[-1.,1])
                    a_1=makeArray(sh1,[-1.,1])
                    a1_1=a_1
                    if case1 in ["taggedData", "expandedData"]:
                        a1_1=makeArray(sh1,[-1.,1])
                    r=oper[1](a_0,a_1)
                    r1=oper[1](a1_0,a1_1)

                    text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
                    text+=" def %s(self):\n"%tname
                    text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
                    text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
                    text+=" res=%s(arg0,arg1)\n"%oper[0]
                    case=getResultCaseForBin(case0,case1)
                    if case!="Symbol":
                        res="res"
                        text+=mkText(case,"ref",r,r1)
                    else:
                        # substitute concrete arrays for every symbolic argument
                        c0_res,c1_res=case0,case1
                        subs_items=[]
                        if case0=="Symbol":
                            text+=mkText("array","s0",a_0,a1_0)
                            subs_items.append("arg0:s0")
                            c0_res="array"
                        if case1=="Symbol":
                            text+=mkText("array","s1",a_1,a1_1)
                            subs_items.append("arg1:s1")
                            c1_res="array"
                        subs="{"+",".join(subs_items)+"}"
                        text+=" sub=res.substitute(%s)\n"%subs
                        res="sub"
                        text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)
                    # the result takes the shape of the higher-rank argument
                    if len(sh0)>len(sh1):
                        sh_res=sh0
                    else:
                        sh_res=sh1
                    text+=mkTypeAndShapeTest(case,sh_res,"res")
                    text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res

                    if case0 == "taggedData" or case1 == "taggedData":
                        t_prog_with_tags+=text
                    else:
                        t_prog+=text

print(test_header)
# print(t_prog)
print(t_prog_with_tags)
print(test_tail)
# deliberate halt: the ZeroDivisionError ends the script here
1/0


#=======================================================================================================
# outer inner
#=======================================================================================================
oper=["outer",outerTEST]
# oper=["inner",innerTEST]
for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
    for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
        for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:
            for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
                # float arguments only exist for rank 0
                if case0=="float" and len(sh0)>0: continue
                if case1=="float" and len(sh1)>0: continue
                # the result rank len(sh0)+len(sh1) is limited to 4
                if len(sh0+sh1)>=5: continue

                use_tagging_for_expanded_data= case0=="taggedData" or case1=="taggedData"
                tname="test_%s_%s_rank%s_%s_rank%s"%(oper[0],case0,len(sh0),case1,len(sh1))
                a_0=makeArray(sh0,[-1.,1])
                a1_0=a_0
                if case0 in ["taggedData", "expandedData"]:
                    a1_0=makeArray(sh0,[-1.,1])
                a_1=makeArray(sh1,[-1.,1])
                a1_1=a_1
                if case1 in ["taggedData", "expandedData"]:
                    a1_1=makeArray(sh1,[-1.,1])
                r=oper[1](a_0,a_1)
                r1=oper[1](a1_0,a1_1)

                pieces=[" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n",
                        " def %s(self):\n"%tname,
                        mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data),
                        mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data),
                        " res=%s(arg0,arg1)\n"%oper[0]]
                case=getResultCaseForBin(case0,case1)
                if case=="Symbol":
                    # substitute concrete arrays for every symbolic argument
                    c0_res,c1_res=case0,case1
                    subs_items=[]
                    if case0=="Symbol":
                        pieces.append(mkText("array","s0",a_0,a1_0))
                        subs_items.append("arg0:s0")
                        c0_res="array"
                    if case1=="Symbol":
                        pieces.append(mkText("array","s1",a_1,a1_1))
                        subs_items.append("arg1:s1")
                        c1_res="array"
                    subs="{"+",".join(subs_items)+"}"
                    pieces.append(" sub=res.substitute(%s)\n"%subs)
                    res="sub"
                    pieces.append(mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1))
                else:
                    res="res"
                    pieces.append(mkText(case,"ref",r,r1))
                pieces.append(mkTypeAndShapeTest(case,sh0+sh1,"res"))
                pieces.append(" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res)
                text="".join(pieces)

                if case0 == "taggedData" or case1 == "taggedData":
                    t_prog_with_tags+=text
                else:
                    t_prog+=text

print(test_header)
# print(t_prog)
print(t_prog_with_tags)
print(test_tail)
# deliberate halt: the ZeroDivisionError ends the script here
1/0

#=======================================================================================================
# local reduction
#
# each entry: [function name, start value, update expression, final expression];
# testReduce folds the update expression over all entries to build the reference.
#=======================================================================================================
for oper in [["length",0.,"out+%a1%**2","math.sqrt(out)"],
             ["maxval",-1.e99,"max(out,%a1%)","out"],
             ["minval",1.e99,"min(out,%a1%)","out"] ]:
    red_name,red_start,red_step,red_final=oper
    for case in case_set:
        for sh in shape_set:
            # a float argument only exists for rank 0
            if case=="float" and len(sh)>0:
                continue
            a=makeArray(sh,[-1.,1.])
            a1=makeArray(sh,[-1.,1.])
            r1=testReduce(a1,red_start,red_step,red_final)
            r=testReduce(a,red_start,red_step,red_final)

            tname="test_%s_%s_rank%s"%(red_name,case,len(sh))
            text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
            text+=" def %s(self):\n"%tname
            text+=mkText(case,"arg",a,a1)
            text+=" res=%s(arg)\n"%red_name
            if case!="Symbol":
                text+=mkText(case,"ref",r,r1)
                res="res"
            else:
                # a symbolic result is checked after substituting a concrete array
                text+=mkText("array","s",a,a1)
                text+=" sub=res.substitute({arg:s})\n"
                text+=mkText("array","ref",r,r1)
                res="sub"
            # maxval/minval of a float or array argument are checked as float;
            # everything else keeps the argument type
            check_case=case
            if red_name!="length" and case in ["float","array"]:
                check_case="float"
            text+=mkTypeAndShapeTest(check_case,(),"res")
            text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
            if case == "taggedData":
                t_prog_with_tags+=text
            else:
                t_prog+=text
print(test_header)
# print(t_prog)
print(t_prog_with_tags)
print(test_tail)
# deliberate halt: the ZeroDivisionError ends the script here
1/0

#=======================================================================================================
# tensor multiply
#
# arg0 has shape sh0+sh_s and arg1 has shape sh_s+sh1; the sh_s axes are contracted.
# Alternatives used earlier (the filter and the emitted call below are tensormult specific):
#   oper=["generalTensorProduct",tensorProductTest]
#   oper=["matrixmult",testMatrixMult]
#=======================================================================================================
oper=["tensormult",testTensorMult]

for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
    for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
        for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:
            for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
                for sh_s in [ (),(3,), (2,3), (2,4,3),(4,2,3,2)]:
                    rank_a=len(sh0)+len(sh_s)
                    rank_b=len(sh1)+len(sh_s)
                    # float arguments only exist for rank 0; all ranks are limited to 4
                    if case0=="float" and rank_a>0: continue
                    if case1=="float" and rank_b>0: continue
                    if len(sh0)+len(sh1)>=5 or rank_a>=5 or rank_b>=5: continue
                    # combinations accepted by tensormult: matrix times vector/matrix,
                    # or a rank-4 tensor times a rank 2,3 or 4 tensor
                    # (matrixmult would only use the len(sh_s)==1 case)
                    if len(sh_s)==1:
                        selected= rank_a==2 and rank_b in [1,2]
                    elif len(sh_s)==2:
                        selected= rank_a==4 and rank_b in [2,3,4]
                    else:
                        selected=False
                    if not selected: continue

                    case=getResultCaseForBin(case0,case1)
                    use_tagging_for_expanded_data= case0=="taggedData" or case1=="taggedData"
                    tname="test_tensormult_%s_rank%s_%s_rank%s"%(case0,rank_a,case1,rank_b)
                    a_0=makeArray(sh0+sh_s,[-1.,1])
                    a1_0=a_0
                    if case0 in ["taggedData", "expandedData"]:
                        a1_0=makeArray(sh0+sh_s,[-1.,1])
                    a_1=makeArray(sh_s+sh1,[-1.,1])
                    a1_1=a_1
                    if case1 in ["taggedData", "expandedData"]:
                        a1_1=makeArray(sh_s+sh1,[-1.,1])
                    r=oper[1](a_0,a_1,sh_s)
                    r1=oper[1](a1_0,a1_1,sh_s)

                    text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
                    text+=" def %s(self):\n"%tname
                    text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
                    text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
                    text+=" res=tensormult(arg0,arg1)\n"
                    if case!="Symbol":
                        res="res"
                        text+=mkText(case,"ref",r,r1)
                    else:
                        # substitute concrete arrays for every symbolic argument
                        c0_res,c1_res=case0,case1
                        subs_items=[]
                        if case0=="Symbol":
                            text+=mkText("array","s0",a_0,a1_0)
                            subs_items.append("arg0:s0")
                            c0_res="array"
                        if case1=="Symbol":
                            text+=mkText("array","s1",a_1,a1_1)
                            subs_items.append("arg1:s1")
                            c1_res="array"
                        subs="{"+",".join(subs_items)+"}"
                        text+=" sub=res.substitute(%s)\n"%subs
                        res="sub"
                        text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)
                    text+=mkTypeAndShapeTest(case,sh0+sh1,"res")
                    text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
                    if case0 == "taggedData" or case1 == "taggedData":
                        t_prog_with_tags+=text
                    else:
                        t_prog+=text
print(test_header)
# print(t_prog)
print(t_prog_with_tags)
print(test_tail)
# deliberate halt: the ZeroDivisionError ends the script here
1/0
#=======================================================================================================
# basic binary operation overloading (tests only!)
#
# each entry: [name, python operator, value range of the arguments]
#=======================================================================================================
oper_range=[-5.,5.]
for oper in [["add" ,"+",[-5.,5.]],
             ["sub" ,"-",[-5.,5.]],
             ["mult","*",[-5.,5.]],
             ["div" ,"/",[-5.,5.]],
             ["pow" ,"**",[0.01,5.]]]:
    op_name,op_symbol,op_range=oper
    for case0 in case_set:
        for sh0 in shape_set:
            for case1 in case_set:
                for sh1 in shape_set:
                    # an array is not used as left operand
                    if case0=="array": continue
                    # float arguments only exist for rank 0
                    if case0=="float" and len(sh0)>0: continue
                    if case1=="float" and len(sh1)>0: continue
                    # shapes must agree unless one operand is a scalar
                    if not (sh0==() or sh1==() or sh1==sh0): continue
                    # at least one operand must be an escript object
                    if case0 in ["float","array"] and case1 in ["float","array"]: continue

                    use_tagging_for_expanded_data= case0=="taggedData" or case1=="taggedData"
                    tname="test_%s_overloaded_%s_rank%s_%s_rank%s"%(op_name,case0,len(sh0),case1,len(sh1))
                    a_0=makeArray(sh0,op_range)
                    a1_0=a_0
                    if case0 in ["taggedData", "expandedData"]:
                        a1_0=makeArray(sh0,op_range)
                    a_1=makeArray(sh1,op_range)
                    a1_1=a_1
                    if case1 in ["taggedData", "expandedData"]:
                        a1_1=makeArray(sh1,op_range)
                    expr="%a1%"+op_symbol+"%a2%"
                    r1=makeResult2(a1_0,a1_1,expr)
                    r=makeResult2(a_0,a_1,expr)

                    text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
                    text+=" def %s(self):\n"%tname
                    text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
                    text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
                    text+=" res=arg0%sarg1\n"%op_symbol

                    case=getResultCaseForBin(case0,case1)
                    if case!="Symbol":
                        res="res"
                        text+=mkText(case,"ref",r,r1)
                    else:
                        # substitute concrete arrays for every symbolic argument
                        c0_res,c1_res=case0,case1
                        subs_items=[]
                        if case0=="Symbol":
                            text+=mkText("array","s0",a_0,a1_0)
                            subs_items.append("arg0:s0")
                            c0_res="array"
                        if case1=="Symbol":
                            text+=mkText("array","s1",a_1,a1_1)
                            subs_items.append("arg1:s1")
                            c1_res="array"
                        subs="{"+",".join(subs_items)+"}"
                        text+=" sub=res.substitute(%s)\n"%subs
                        res="sub"
                        text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)
                    if isinstance(r,float):
                        sh_res=()
                    else:
                        sh_res=r.shape
                    text+=mkTypeAndShapeTest(case,sh_res,"res")
                    text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res

                    # Data op Symbol combinations are collected separately
                    if case0 in [ "constData","taggedData","expandedData"] and case1 == "Symbol":
                        t_prog_failing+=text
                    elif case0 == "taggedData" or case1 == "taggedData":
                        t_prog_with_tags+=text
                    else:
                        t_prog+=text


print(test_header)
# print(t_prog)
# print(t_prog_with_tags)
print(t_prog_failing)
print(test_tail)
# deliberate halt: the ZeroDivisionError ends the script here
1/0
#=======================================================================================================
# basic binary operations (tests only!)
#
# for every shape sh and extension sh_p the arguments get the shapes
#   sh0=sh+sh_p, sh1=sh   and   sh0=sh, sh1=sh+sh_p
# (only the second combination when sh_p is empty)
#=======================================================================================================
oper_range=[-5.,5.]
for oper in [["add" ,"+",[-5.,5.]],
             ["mult","*",[-5.,5.]],
             ["quotient" ,"/",[-5.,5.]],
             ["power" ,"**",[0.01,5.]]]:
    op_name,op_symbol,op_range=oper
    for case0 in case_set:
        for case1 in case_set:
            for sh in shape_set:
                for sh_p in shape_set:
                    resource=[1]
                    if len(sh_p)>0:
                        resource=[-1,1]
                    for sh_d in resource:
                        if sh_d>0:
                            sh0,sh1=sh,sh+sh_p
                        else:
                            sh0,sh1=sh+sh_p,sh
                        # float arguments only exist for rank 0; ranks are limited to 4
                        if case0=="float" and len(sh0)>0: continue
                        if case1=="float" and len(sh1)>0: continue
                        if len(sh0)>=5 or len(sh1)>=5: continue

                        use_tagging_for_expanded_data= case0=="taggedData" or case1=="taggedData"
                        tname="test_%s_%s_rank%s_%s_rank%s"%(op_name,case0,len(sh0),case1,len(sh1))
                        a_0=makeArray(sh0,op_range)
                        a1_0=a_0
                        if case0 in ["taggedData", "expandedData"]:
                            a1_0=makeArray(sh0,op_range)
                        a_1=makeArray(sh1,op_range)
                        a1_1=a_1
                        if case1 in ["taggedData", "expandedData"]:
                            a1_1=makeArray(sh1,op_range)
                        expr="%a1%"+op_symbol+"%a2%"
                        r1=makeResult2(a1_0,a1_1,expr)
                        r=makeResult2(a_0,a_1,expr)

                        text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
                        text+=" def %s(self):\n"%tname
                        text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
                        text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
                        text+=" res=%s(arg0,arg1)\n"%op_name

                        case=getResultCaseForBin(case0,case1)
                        if case!="Symbol":
                            res="res"
                            text+=mkText(case,"ref",r,r1)
                        else:
                            # substitute concrete arrays for every symbolic argument
                            c0_res,c1_res=case0,case1
                            subs_items=[]
                            if case0=="Symbol":
                                text+=mkText("array","s0",a_0,a1_0)
                                subs_items.append("arg0:s0")
                                c0_res="array"
                            if case1=="Symbol":
                                text+=mkText("array","s1",a_1,a1_1)
                                subs_items.append("arg1:s1")
                                c1_res="array"
                            subs="{"+",".join(subs_items)+"}"
                            text+=" sub=res.substitute(%s)\n"%subs
                            res="sub"
                            text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)
                        if isinstance(r,float):
                            sh_res=()
                        else:
                            sh_res=r.shape
                        text+=mkTypeAndShapeTest(case,sh_res,"res")
                        text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res

                        if case0 == "taggedData" or case1 == "taggedData":
                            t_prog_with_tags+=text
                        else:
                            t_prog+=text
print(test_header)
# print(t_prog)
print(t_prog_with_tags)
print(test_tail)
# deliberate halt: the ZeroDivisionError ends the script here
1/0

#=======================================================================================================
# basic binary operation overloading (older variant)
#
# each entry: [name, python operator, value range of the arguments]
# NOTE(review): unlike the newer overloading section, mkText is called here without
# use_tagging_for_expanded_data -- confirm whether that is intended.
#=======================================================================================================
for oper in [["add" ,"+",[-5.,5.]],
             ["sub" ,"-",[-5.,5.]],
             ["mult","*",[-5.,5.]],
             ["div" ,"/",[-5.,5.]],
             ["pow" ,"**",[0.01,5.]]]:
    for case0 in case_set:
        for sh0 in shape_set:
            for case1 in case_set:
                for sh1 in shape_set:
                    # floats only for rank 0, shapes must agree unless one operand is
                    # a scalar, and at least one operand must be an escript object
                    if (not case0=="float" or len(sh0)==0) and (not case1=="float" or len(sh1)==0) and \
                       (sh0==() or sh1==() or sh1==sh0) and \
                       not (case0 in ["float","array"] and case1 in ["float","array"]):
                        text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
                        tname="test_%s_%s_rank%s_%s_rank%s"%(oper[0],case0,len(sh0),case1,len(sh1))
                        text+=" def %s(self):\n"%tname
                        a_0=makeArray(sh0,oper[2])
                        if case0 in ["taggedData", "expandedData"]:
                            a1_0=makeArray(sh0,oper[2])
                        else:
                            a1_0=a_0

                        a_1=makeArray(sh1,oper[2])
                        if case1 in ["taggedData", "expandedData"]:
                            a1_1=makeArray(sh1,oper[2])
                        else:
                            a1_1=a_1
                        r1=makeResult2(a1_0,a1_1,"%a1%"+oper[1]+"%a2%")
                        r=makeResult2(a_0,a_1,"%a1%"+oper[1]+"%a2%")
                        text+=mkText(case0,"arg0",a_0,a1_0)
                        text+=mkText(case1,"arg1",a_1,a1_1)
                        text+=" res=arg0%sarg1\n"%oper[1]

                        case=getResultCaseForBin(case0,case1)
                        if case=="Symbol":
                            # substitute concrete arrays for every symbolic argument
                            c0_res,c1_res=case0,case1
                            subs="{"
                            if case0=="Symbol":
                                text+=mkText("array","s0",a_0,a1_0)
                                subs+="arg0:s0"
                                c0_res="array"
                            if case1=="Symbol":
                                text+=mkText("array","s1",a_1,a1_1)
                                if not subs.endswith("{"): subs+=","
                                subs+="arg1:s1"
                                c1_res="array"
                            subs+="}"
                            text+=" sub=res.substitute(%s)\n"%subs
                            res="sub"
                            text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)
                        else:
                            res="res"
                            text+=mkText(case,"ref",r,r1)
                        if isinstance(r,float):
                            text+=mkTypeAndShapeTest(case,(),"res")
                        else:
                            text+=mkTypeAndShapeTest(case,r.shape,"res")
                        text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res

                        if case0 in [ "constData","taggedData","expandedData"] and case1 == "Symbol":
                            t_prog_failing+=text
                        else:
                            if case0 == "taggedData" or case1 == "taggedData":
                                t_prog_with_tags+=text
                            else:
                                t_prog+=text


# print(u_prog)
# 1/0
print(test_header)
print(t_prog)
# print(t_prog_with_tags)
# print(t_prog_failing)
# test_tail creates the suite and runs it: emit it exactly once, otherwise
# the generated script builds and runs every test twice
print(test_tail)

#=======================================================================================================
# unary operations:
#
# each OPERATOR describes one elementwise function; %a1% is the placeholder for its argument.
#   rng           - range of argument values (presumably used when generating test data)
#   test_expr     - expression used to compute reference values for the tests
#   math_expr     - code for float/int arguments (defaults to test_expr with %a1% -> arg)
#   numarray_expr - code for numarray.NumArray arguments
#   symbol_expr   - code for Symbol arguments; if None a <Name>_Symbol class is generated
#   diff          - derivative with respect to the argument %a1%; inside the generated
#                   diff method "self" is the symbol itself
#   name          - text used in the generated docstring
# NOTE(review): whereZero/whereNonZero reference tol, which the generated
# "def <nickname>(arg)" wrapper does not define -- confirm they are handled elsewhere.
#=======================================================================================================
func= [
       OPERATOR(nickname="log10",
                rng=[1.e-3,100.],
                test_expr="math.log10(%a1%)",
                math_expr="math.log10(%a1%)",
                numarray_expr="numarray.log10(%a1%)",
                symbol_expr="log(%a1%)/log(10.)",
                name="base-10 logarithm"),
       OPERATOR(nickname="wherePositive",
                rng=[-100.,100.],
                test_expr="wherepos(%a1%)",
                math_expr="if arg>0:\n return 1.\nelse:\n return 0.",
                numarray_expr="numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))",
                name="mask of positive values"),
       OPERATOR(nickname="whereNegative",
                rng=[-100.,100.],
                test_expr="wherepos(-%a1%)",
                math_expr="if arg<0:\n return 1.\nelse:\n return 0.",
                numarray_expr="numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))",
                name="mask of negative values"),
       # non-negative = 1 - negative (1-wherePositive would give the non-positive mask)
       OPERATOR(nickname="whereNonNegative",
                rng=[-100.,100.],
                test_expr="1-wherepos(-%a1%)",
                math_expr="if arg<0:\n return 0.\nelse:\n return 1.",
                numarray_expr="numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float))",
                symbol_expr="1-whereNegative(%a1%)",
                name="mask of non-negative values"),
       # non-positive = 1 - positive
       OPERATOR(nickname="whereNonPositive",
                rng=[-100.,100.],
                test_expr="1-wherepos(%a1%)",
                math_expr="if arg>0:\n return 0.\nelse:\n return 1.",
                numarray_expr="numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float))",
                symbol_expr="1-wherePositive(%a1%)",
                name="mask of non-positive values"),
       OPERATOR(nickname="whereZero",
                rng=[-100.,100.],
                test_expr="1-wherepos(%a1%)-wherepos(-%a1%)",
                math_expr="if abs(%a1%)<=tol:\n return 1.\nelse:\n return 0.",
                numarray_expr="numarray.less_equal(abs(%a1%)-tol,numarray.zeros(arg.shape,numarray.Float))",
                name="mask of zero entries"),
       OPERATOR(nickname="whereNonZero",
                rng=[-100.,100.],
                test_expr="wherepos(%a1%)+wherepos(-%a1%)",
                math_expr="if abs(%a1%)>tol:\n return 1.\nelse:\n return 0.",
                numarray_expr="numarray.greater(abs(%a1%)-tol,numarray.zeros(arg.shape,numarray.Float))",
                symbol_expr="1-whereZero(arg,tol)",
                name="mask of values different from zero"),
       OPERATOR(nickname="sin",
                rng=[-100.,100.],
                test_expr="math.sin(%a1%)",
                numarray_expr="numarray.sin(%a1%)",
                diff="cos(%a1%)",
                name="sine"),
       OPERATOR(nickname="cos",
                rng=[-100.,100.],
                test_expr="math.cos(%a1%)",
                numarray_expr="numarray.cos(%a1%)",
                diff="-sin(%a1%)",
                name="cosine"),
       OPERATOR(nickname="tan",
                rng=[-100.,100.],
                test_expr="math.tan(%a1%)",
                numarray_expr="numarray.tan(%a1%)",
                diff="1./cos(%a1%)**2",
                name="tangent"),
       OPERATOR(nickname="asin",
                rng=[-0.99,0.99],
                test_expr="math.asin(%a1%)",
                numarray_expr="numarray.arcsin(%a1%)",
                diff="1./sqrt(1.-%a1%**2)",
                name="inverse sine"),
       OPERATOR(nickname="acos",
                rng=[-0.99,0.99],
                test_expr="math.acos(%a1%)",
                numarray_expr="numarray.arccos(%a1%)",
                diff="-1./sqrt(1.-%a1%**2)",
                name="inverse cosine"),
       OPERATOR(nickname="atan",
                rng=[-100.,100.],
                test_expr="math.atan(%a1%)",
                numarray_expr="numarray.arctan(%a1%)",
                diff="1./(1+%a1%**2)",
                name="inverse tangent"),
       OPERATOR(nickname="sinh",
                rng=[-5,5],
                test_expr="math.sinh(%a1%)",
                numarray_expr="numarray.sinh(%a1%)",
                diff="cosh(%a1%)",
                name="hyperbolic sine"),
       OPERATOR(nickname="cosh",
                rng=[-5.,5.],
                test_expr="math.cosh(%a1%)",
                numarray_expr="numarray.cosh(%a1%)",
                diff="sinh(%a1%)",
                name="hyperbolic cosine"),
       OPERATOR(nickname="tanh",
                rng=[-5.,5.],
                test_expr="math.tanh(%a1%)",
                numarray_expr="numarray.tanh(%a1%)",
                diff="1./cosh(%a1%)**2",
                name="hyperbolic tangent"),
       OPERATOR(nickname="asinh",
                rng=[-100.,100.],
                test_expr="numarray.arcsinh(%a1%)",
                math_expr="numarray.arcsinh(%a1%)",
                numarray_expr="numarray.arcsinh(%a1%)",
                diff="1./sqrt(%a1%**2+1)",
                name="inverse hyperbolic sine"),
       OPERATOR(nickname="acosh",
                rng=[1.001,100.],
                test_expr="numarray.arccosh(%a1%)",
                math_expr="numarray.arccosh(%a1%)",
                numarray_expr="numarray.arccosh(%a1%)",
                diff="1./sqrt(%a1%**2-1)",
                name="inverse hyperbolic cosine"),
       OPERATOR(nickname="atanh",
                rng=[-0.99,0.99],
                test_expr="numarray.arctanh(%a1%)",
                math_expr="numarray.arctanh(%a1%)",
                numarray_expr="numarray.arctanh(%a1%)",
                diff="1./(1.-%a1%**2)",
                name="inverse hyperbolic tangent"),
       OPERATOR(nickname="exp",
                rng=[-5.,5.],
                test_expr="math.exp(%a1%)",
                numarray_expr="numarray.exp(%a1%)",
                diff="self",
                name="exponential"),
       OPERATOR(nickname="sqrt",
                rng=[1.e-3,100.],
                test_expr="math.sqrt(%a1%)",
                numarray_expr="numarray.sqrt(%a1%)",
                diff="0.5/self",
                name="square root"),
       # d log(u) = 1/u du: the derivative must use the function argument (%a1%),
       # not the variable the derivative is taken with respect to
       OPERATOR(nickname="log",
                rng=[1.e-3,100.],
                test_expr="math.log(%a1%)",
                numarray_expr="numarray.log(%a1%)",
                diff="1./%a1%",
                name="natural logarithm"),
       OPERATOR(nickname="sign",
                rng=[-100.,100.],
                math_expr="if %a1%>0:\n return 1.\nelif %a1%<0:\n return -1.\nelse:\n return 0.",
                test_expr="wherepos(%a1%)-wherepos(-%a1%)",
                numarray_expr="numarray.sign(%a1%)",
                symbol_expr="wherePositive(%a1%)-whereNegative(%a1%)",
                name="sign"),
       OPERATOR(nickname="abs",
                rng=[-100.,100.],
                math_expr="if %a1%>0:\n return %a1% \nelif %a1%<0:\n return -(%a1%)\nelse:\n return 0.",
                test_expr="wherepos(%a1%)*(%a1%)-wherepos(-%a1%)*(%a1%)",
                numarray_expr="abs(%a1%)",
                diff="sign(%a1%)",
                name="absolute value")

      ]
1491     for f in func:
1492     symbol_name=f.nickname[0].upper()+f.nickname[1:]
1493     if f.nickname!="abs":
1494     u_prog+="def %s(arg):\n"%f.nickname
1495     u_prog+=" \"\"\"\n"
1496     u_prog+=" returns %s of argument arg\n\n"%f.name
1497     u_prog+=" @param arg: argument\n"
1498     u_prog+=" @type arg: C{float}, L{escript.Data}, L{Symbol}, L{numarray.NumArray}.\n"
1499     u_prog+=" @rtype:C{float}, L{escript.Data}, L{Symbol}, L{numarray.NumArray} depending on the type of arg.\n"
1500     u_prog+=" @raises TypeError: if the type of the argument is not expected.\n"
1501     u_prog+=" \"\"\"\n"
1502     u_prog+=" if isinstance(arg,numarray.NumArray):\n"
1503     u_prog+=mkCode(f.numarray_expr,["arg"],2*" ")
1504     u_prog+=" elif isinstance(arg,escript.Data):\n"
1505     u_prog+=mkCode("arg._%s()"%f.nickname,[],2*" ")
1506     u_prog+=" elif isinstance(arg,float):\n"
1507     u_prog+=mkCode(f.math_expr,["arg"],2*" ")
1508     u_prog+=" elif isinstance(arg,int):\n"
1509     u_prog+=mkCode(f.math_expr,["float(arg)"],2*" ")
1510     u_prog+=" elif isinstance(arg,Symbol):\n"
1511     if f.symbol_expr==None:
1512     u_prog+=mkCode("%s_Symbol(arg)"%symbol_name,[],2*" ")
1513     else:
1514     u_prog+=mkCode(f.symbol_expr,["arg"],2*" ")
1515     u_prog+=" else:\n"
1516     u_prog+=" raise TypeError,\"%s: Unknown argument type.\"\n\n"%f.nickname
1517     if f.symbol_expr==None:
1518     u_prog+="class %s_Symbol(DependendSymbol):\n"%symbol_name
1519     u_prog+=" \"\"\"\n"
1520     u_prog+=" L{Symbol} representing the result of the %s function\n"%f.name
1521     u_prog+=" \"\"\"\n"
1522     u_prog+=" def __init__(self,arg):\n"
1523     u_prog+=" \"\"\"\n"
1524     u_prog+=" initialization of %s L{Symbol} with argument arg\n"%f.nickname
1525     u_prog+=" @param arg: argument of function\n"
1526     u_prog+=" @type arg: typically L{Symbol}.\n"
1527     u_prog+=" \"\"\"\n"
1528     u_prog+=" DependendSymbol.__init__(self,args=[arg],shape=arg.getShape(),dim=arg.getDim())\n"
1529     u_prog+="\n"
1530    
1531     u_prog+=" def getMyCode(self,argstrs,format=\"escript\"):\n"
1532     u_prog+=" \"\"\"\n"
1533     u_prog+=" returns a program code that can be used to evaluate the symbol.\n\n"
1534 jgs 154
1535 gross 157 u_prog+=" @param argstrs: gives for each argument a string representing the argument for the evaluation.\n"
1536     u_prog+=" @type argstrs: C{str} or a C{list} of length 1 of C{str}.\n"
1537     u_prog+=" @param format: specifies the format to be used. At the moment only \"escript\" ,\"text\" and \"str\" are supported.\n"
1538     u_prog+=" @type format: C{str}\n"
1539     u_prog+=" @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.\n"
1540     u_prog+=" @rtype: C{str}\n"
1541     u_prog+=" @raise: NotImplementedError: if the requested format is not available\n"
1542     u_prog+=" \"\"\"\n"
1543     u_prog+=" if isinstance(argstrs,list):\n"
1544     u_prog+=" argstrs=argstrs[0]\n"
1545     u_prog+=" if format==\"escript\" or format==\"str\" or format==\"text\":\n"
1546     u_prog+=" return \"%s(%%s)\"%%argstrs\n"%f.nickname
1547     u_prog+=" else:\n"
1548     u_prog+=" raise NotImplementedError,\"%s_Symbol does not provide program code for format %%s.\"%%format\n"%symbol_name
1549     u_prog+="\n"
1550 jgs 154
1551 gross 157 u_prog+=" def substitute(self,argvals):\n"
1552     u_prog+=" \"\"\"\n"
1553     u_prog+=" assigns new values to symbols in the definition of the symbol.\n"
1554     u_prog+=" The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.\n"
1555     u_prog+="\n"
1556     u_prog+=" @param argvals: new values assigned to symbols\n"
1557     u_prog+=" @type argvals: C{dict} with keywords of type L{Symbol}.\n"
1558     u_prog+=" @return: result of the substitution process. Operations are executed as much as possible.\n"
1559     u_prog+=" @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution\n"
1560     u_prog+=" @raise TypeError: if a value for a L{Symbol} cannot be substituted.\n"
1561     u_prog+=" \"\"\"\n"
1562     u_prog+=" if argvals.has_key(self):\n"
1563     u_prog+=" arg=argvals[self]\n"
1564     u_prog+=" if self.isAppropriateValue(arg):\n"
1565     u_prog+=" return arg\n"
1566     u_prog+=" else:\n"
1567     u_prog+=" raise TypeError,\"%s: new value is not appropriate.\"%str(self)\n"
1568     u_prog+=" else:\n"
1569     u_prog+=" arg=self.getSubstitutedArguments(argvals)[0]\n"
1570     u_prog+=" return %s(arg)\n\n"%f.nickname
1571     if not f.diff==None:
1572     u_prog+=" def diff(self,arg):\n"
1573     u_prog+=" \"\"\"\n"
1574     u_prog+=" differential of this object\n"
1575     u_prog+="\n"
1576     u_prog+=" @param arg: the derivative is calculated with respect to arg\n"
1577     u_prog+=" @type arg: L{escript.Symbol}\n"
1578     u_prog+=" @return: derivative with respect to C{arg}\n"
1579     u_prog+=" @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray} are possible.\n"
1580     u_prog+=" \"\"\"\n"
1581     u_prog+=" if arg==self:\n"
1582     u_prog+=" return identity(self.getShape())\n"
1583     u_prog+=" else:\n"
1584     u_prog+=" myarg=self.getArgument()[0]\n"
1585     u_prog+=" val=matchShape(%s,self.getDifferentiatedArguments(arg)[0])\n"%f.diff.replace("%a1%","myarg")
1586     u_prog+=" return val[0]*val[1]\n\n"
1587 jgs 154
1588 gross 157 for case in case_set:
1589     for sh in shape_set:
1590     if not case=="float" or len(sh)==0:
1591     text=""
1592     text+=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
1593     tname="def test_%s_%s_rank%s"%(f.nickname,case,len(sh))
1594     text+=" %s(self):\n"%tname
1595     a=makeArray(sh,f.rng)
1596     a1=makeArray(sh,f.rng)
1597     r1=makeResult(a1,f.test_expr)
1598     r=makeResult(a,f.test_expr)
1599    
1600     text+=mkText(case,"arg",a,a1)
1601     text+=" res=%s(arg)\n"%f.nickname
1602     if case=="Symbol":
1603     text+=mkText("array","s",a,a1)
1604     text+=" sub=res.substitute({arg:s})\n"
1605     text+=mkText("array","ref",r,r1)
1606     res="sub"
1607     else:
1608     text+=mkText(case,"ref",r,r1)
1609     res="res"
1610     text+=mkTypeAndShapeTest(case,sh,"res")
1611     text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
1612     if case == "taggedData":
1613     t_prog_with_tags+=text
1614     else:
1615     t_prog+=text
1616    
1617     #=========== END OF GOOD CODE +++++++++++++++++++++++++++
1618 jgs 154
# NOTE(review): `1/0` always raises ZeroDivisionError. Right after the
# "END OF GOOD CODE" marker, this appears to be a deliberate hard stop, so that
# the legacy `def X():` code following it is never reached. That legacy code
# references names such as `args`, `f[RANGE]` and `finc`, which are not
# visible here. Presumably this statement sits at module level, but the
# indentation was lost in this copy; confirm. If a clean stop is intended,
# `raise SystemExit(...)` would make that explicit.
1619 gross 157 1/0
1620 jgs 154
1621 gross 157 def X():
1622     if args=="float":
1623     a=makeArray(sh,f[RANGE])
1624 jgs 154 r=makeResult(a,f)
1625 gross 157 t_prog+=" arg=%s\n"%a[0]
1626     t_prog+=" ref=%s\n"%r[0]
1627     t_prog+=" res=%s(%a1%)\n"%f.nickname
1628     t_prog+=" self.failUnless(isinstance(res,float),\"wrong type of result.\")\n"
1629     t_prog+=" self.failUnless(Lsup(res-ref)<=self.tol*Lsup(ref),\"wrong result\")\n"
1630     elif args == "array":
1631     a=makeArray(sh,f[RANGE])
1632     r=makeResult(a,f)
1633 jgs 154 if len(sh)==0:
1634 gross 157 t_prog+=" arg=numarray.array(%s)\n"%a[0]
1635     t_prog+=" ref=numarray.array(%s)\n"%r[0]
1636 jgs 154 else:
1637     t_prog+=" arg=numarray.array(%s)\n"%a.tolist()
1638     t_prog+=" ref=numarray.array(%s)\n"%r.tolist()
1639 gross 157 t_prog+=" res=%s(%a1%)\n"%f.nickname
1640     t_prog+=" self.failUnlessEqual(res.shape,%s,\"wrong shape of result.\")\n"%str(sh)
1641     t_prog+=" self.failUnless(Lsup(res-ref)<=self.tol*Lsup(ref),\"wrong result\")\n"
1642 jgs 154 elif args== "constData":
1643 gross 157 a=makeArray(sh,f[RANGE])
1644 jgs 154 r=makeResult(a,f)
1645     if len(sh)==0:
1646     t_prog+=" arg=Data(%s,self.functionspace)\n"%(a)
1647     t_prog+=" ref=%s\n"%r
1648     else:
1649     t_prog+=" arg=Data(numarray.array(%s),self.functionspace)\n"%(a.tolist())
1650     t_prog+=" ref=numarray.array(%s)\n"%r.tolist()
1651 gross 157 t_prog+=" res=%s(%a1%)\n"%f.nickname
1652 jgs 154 t_prog+=" self.failUnlessEqual(res.getShape(),%s,\"wrong shape of result.\")\n"%str(sh)
1653     t_prog+=" self.failUnless(Lsup(res-ref)<=self.tol*Lsup(ref),\"wrong result\")\n"
1654     elif args in [ "taggedData","expandedData"]:
1655 gross 157 a=makeArray(sh,f[RANGE])
1656 jgs 154 r=makeResult(a,f)
1657 gross 157 a1=makeArray(sh,f[RANGE])
1658 jgs 154 r1=makeResult(a1,f)
1659     if len(sh)==0:
1660     if args=="expandedData":
1661     t_prog+=" arg=Data(%s,self.functionspace,True)\n"%(a)
1662     t_prog+=" ref=Data(%s,self.functionspace,True)\n"%(r)
1663     else:
1664     t_prog+=" arg=Data(%s,self.functionspace)\n"%(a)
1665     t_prog+=" ref=Data(%s,self.functionspace)\n"%(r)
1666     t_prog+=" arg.setTaggedValue(1,%s)\n"%a
1667     t_prog+=" ref.setTaggedValue(1,%s)\n"%r1
1668     else:
1669     if args=="expandedData":
1670     t_prog+=" arg=Data(numarray.array(%s),self.functionspace,True)\n"%(a.tolist())
1671     t_prog+=" ref=Data(numarray.array(%s),self.functionspace,True)\n"%(r.tolist())
1672     else:
1673     t_prog+=" arg=Data(numarray.array(%s),self.functionspace)\n"%(a.tolist())
1674     t_prog+=" ref=Data(numarray.array(%s),self.functionspace)\n"%(r.tolist())
1675     t_prog+=" arg.setTaggedValue(1,%s)\n"%a1.tolist()
1676     t_prog+=" ref.setTaggedValue(1,%s)\n"%r1.tolist()
1677 gross 157 t_prog+=" res=%s(%a1%)\n"%f.nickname
1678 jgs 154 t_prog+=" self.failUnlessEqual(res.getShape(),%s,\"wrong shape of result.\")\n"%str(sh)
1679     t_prog+=" self.failUnless(Lsup(res-ref)<=self.tol*Lsup(ref),\"wrong result\")\n"
1680     elif args=="Symbol":
1681     t_prog+=" arg=Symbol(shape=%s)\n"%str(sh)
1682 gross 157 t_prog+=" v=%s(%a1%)\n"%f.nickname
1683 jgs 154 t_prog+=" self.failUnlessRaises(ValueError,v.substitute,Symbol(shape=(1,1)),\"illegal shape of substitute not identified.\")\n"
1684 gross 157 a=makeArray(sh,f[RANGE])
1685 jgs 154 r=makeResult(a,f)
1686     if len(sh)==0:
1687     t_prog+=" res=v.substitute({arg : %s})\n"%a
1688     t_prog+=" ref=%s\n"%r
1689     t_prog+=" self.failUnless(isinstance(res,float),\"wrong type of result.\")\n"
1690     else:
1691     t_prog+=" res=v.substitute({arg : numarray.array(%s)})\n"%a.tolist()
1692     t_prog+=" ref=numarray.array(%s)\n"%r.tolist()
1693     t_prog+=" self.failUnlessEqual(res.getShape(),%s,\"wrong shape of substitution result.\")\n"%str(sh)
1694     t_prog+=" self.failUnless(Lsup(res-ref)<=self.tol*Lsup(ref),\"wrong result\")\n"
1695    
1696     if len(sh)==0:
1697     t_prog+=" # test derivative with respect to itself:\n"
1698     t_prog+=" dvdv=v.diff(v)\n"
1699     t_prog+=" self.failUnlessEqual(dvdv,1.,\"derivative with respect to self is not 1.\")\n"
1700     elif len(sh)==1:
1701     t_prog+=" # test derivative with respect to itself:\n"
1702     t_prog+=" dvdv=v.diff(v)\n"
1703     t_prog+=" self.failUnlessEqual(dvdv.shape,%s,\"shape of derivative with respect is wrong\")\n"%str(sh+sh)
1704     for i0_l in range(sh[0]):
1705     for i0_r in range(sh[0]):
1706     if i0_l == i0_r:
1707     v=1.
1708     else:
1709     v=0.
1710     t_prog+=" self.failUnlessEqual(dvdv[%s,%s],%s,\"derivative with respect to self: [%s,%s] is not %s\")\n"%(i0_l,i0_r,v,i0_l,i0_r,v)
1711     elif len(sh)==2:
1712     t_prog+=" # test derivative with respect to itself:\n"
1713     t_prog+=" dvdv=v.diff(v)\n"
1714     t_prog+=" self.failUnlessEqual(dvdv.shape,%s,\"shape of derivative with respect is wrong\")\n"%str(sh+sh)
1715     for i0_l in range(sh[0]):
1716     for i0_r in range(sh[0]):
1717     for i1_l in range(sh[1]):
1718     for i1_r in range(sh[1]):
1719     if i0_l == i0_r and i1_l == i1_r:
1720     v=1.
1721     else:
1722     v=0.
1723     t_prog+=" self.failUnlessEqual(dvdv[%s,%s,%s,%s],%s,\"derivative with respect to self: [%s,%s,%s,%s] is not %s\")\n"%(i0_l,i1_l,i0_r,i1_r,v,i0_l,i1_l,i0_r,i1_r,v)
1724    
1725     for sh_in in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
1726     if len(sh_in)+len(sh)<=4:
1727    
1728     t_prog+=" # test derivative with shape %s as argument\n"%str(sh_in)
1729     trafo=makeArray(sh+sh_in,[0,1.])
1730 gross 157 a_in=makeArray(sh_in,f[RANGE])
1731 jgs 154 t_prog+=" arg_in=Symbol(shape=%s)\n"%str(sh_in)
1732     t_prog+=" arg2=Symbol(shape=%s)\n"%str(sh)
1733    
1734     if len(sh)==0:
1735     t_prog+=" arg2="
1736     if len(sh_in)==0:
1737     t_prog+="%s*arg_in\n"%trafo
1738     elif len(sh_in)==1:
1739     for i0 in range(sh_in[0]):
1740     if i0>0: t_prog+="+"
1741     t_prog+="%s*arg_in[%s]"%(trafo[i0],i0)
1742     t_prog+="\n"
1743     elif len(sh_in)==2:
1744     for i0 in range(sh_in[0]):
1745     for i1 in range(sh_in[1]):
1746     if i0+i1>0: t_prog+="+"
1747     t_prog+="%s*arg_in[%s,%s]"%(trafo[i0,i1],i0,i1)
1748    
1749     elif len(sh_in)==3:
1750     for i0 in range(sh_in[0]):
1751     for i1 in range(sh_in[1]):
1752     for i2 in range(sh_in[2]):
1753     if i0+i1+i2>0: t_prog+="+"
1754     t_prog+="%s*arg_in[%s,%s,%s]"%(trafo[i0,i1,i2],i0,i1,i2)
1755     elif len(sh_in)==4:
1756     for i0 in range(sh_in[0]):
1757     for i1 in range(sh_in[1]):
1758     for i2 in range(sh_in[2]):
1759     for i3 in range(sh_in[3]):
1760     if i0+i1+i2+i3>0: t_prog+="+"
1761     t_prog+="%s*arg_in[%s,%s,%s,%s]"%(trafo[i0,i1,i2,i3],i0,i1,i2,i3)
1762     t_prog+="\n"
1763     elif len(sh)==1:
1764     for j0 in range(sh[0]):
1765     t_prog+=" arg2[%s]="%j0
1766     if len(sh_in)==0:
1767     t_prog+="%s*arg_in"%trafo[j0]
1768     elif len(sh_in)==1:
1769     for i0 in range(sh_in[0]):
1770     if i0>0: t_prog+="+"
1771     t_prog+="%s*arg_in[%s]"%(trafo[j0,i0],i0)
1772     elif len(sh_in)==2:
1773     for i0 in range(sh_in[0]):
1774     for i1 in range(sh_in[1]):
1775     if i0+i1>0: t_prog+="+"
1776     t_prog+="%s*arg_in[%s,%s]"%(trafo[j0,i0,i1],i0,i1)
1777     elif len(sh_in)==3:
1778     for i0 in range(sh_in[0]):
1779     for i1 in range(sh_in[1]):
1780     for i2 in range(sh_in[2]):
1781     if i0+i1+i2>0: t_prog+="+"
1782     t_prog+="%s*arg_in[%s,%s,%s]"%(trafo[j0,i0,i1,i2],i0,i1,i2)
1783     t_prog+="\n"
1784     elif len(sh)==2:
1785     for j0 in range(sh[0]):
1786     for j1 in range(sh[1]):
1787     t_prog+=" arg2[%s,%s]="%(j0,j1)
1788     if len(sh_in)==0:
1789     t_prog+="%s*arg_in"%trafo[j0,j1]
1790     elif len(sh_in)==1:
1791     for i0 in range(sh_in[0]):
1792     if i0>0: t_prog+="+"
1793     t_prog+="%s*arg_in[%s]"%(trafo[j0,j1,i0],i0)
1794     elif len(sh_in)==2:
1795     for i0 in range(sh_in[0]):
1796     for i1 in range(sh_in[1]):
1797     if i0+i1>0: t_prog+="+"
1798     t_prog+="%s*arg_in[%s,%s]"%(trafo[j0,j1,i0,i1],i0,i1)
1799     t_prog+="\n"
1800     elif len(sh)==3:
1801     for j0 in range(sh[0]):
1802     for j1 in range(sh[1]):
1803     for j2 in range(sh[2]):
1804     t_prog+=" arg2[%s,%s,%s]="%(j0,j1,j2)
1805     if len(sh_in)==0:
1806     t_prog+="%s*arg_in"%trafo[j0,j1,j2]
1807     elif len(sh_in)==1:
1808     for i0 in range(sh_in[0]):
1809     if i0>0: t_prog+="+"
1810     t_prog+="%s*arg_in[%s]"%(trafo[j0,j1,j2,i0],i0)
1811     t_prog+="\n"
1812     elif len(sh)==4:
1813     for j0 in range(sh[0]):
1814     for j1 in range(sh[1]):
1815     for j2 in range(sh[2]):
1816     for j3 in range(sh[3]):
1817     t_prog+=" arg2[%s,%s,%s,%s]="%(j0,j1,j2,j3)
1818     if len(sh_in)==0:
1819     t_prog+="%s*arg_in"%trafo[j0,j1,j2,j3]
1820     t_prog+="\n"
1821     t_prog+=" dvdin=v.substitute({arg : arg2}).diff(arg_in)\n"
1822     if len(sh_in)==0:
1823     t_prog+=" res_in=dvdin.substitute({arg_in : %s})\n"%a_in
1824     else:
1825     t_prog+=" res_in=dvdin.substitute({arg : numarray.array(%s)})\n"%a_in.tolist()
1826    
1827     if len(sh)==0:
1828     if len(sh_in)==0:
1829     ref_diff=(makeResult(trafo*a_in+finc,f)-makeResult(trafo*a_in,f))/finc
1830     t_prog+=" self.failUnlessAlmostEqual(dvdin,%s,self.places,\"%s-derivative: wrong derivative\")\n"%(ref_diff,str(sh_in))
1831     elif len(sh_in)==1:
1832     s=0
1833     for k0 in range(sh_in[0]):
1834     s+=trafo[k0]*a_in[k0]
1835     for i0 in range(sh_in[0]):
1836     ref_diff=(makeResult(s+trafo[i0]*finc,f)-makeResult(s,f))/finc
1837     t_prog+=" self.failUnlessAlmostEqual(dvdin[%s],%s,self.places,\"%s-derivative: wrong derivative with respect of %s\")\n"%(i0,ref_diff,str(sh_in),str(i0))
1838     elif len(sh_in)==2:
1839     s=0
1840     for k0 in range(sh_in[0]):
1841     for k1 in range(sh_in[1]):
1842     s+=trafo[k0,k1]*a_in[k0,k1]
1843     for i0 in range(sh_in[0]):
1844     for i1 in range(sh_in[1]):
1845     ref_diff=(makeResult(s+trafo[i0,i1]*finc,f)-makeResult(s,f))/finc
1846     t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s],%s,self.places,\"%s-derivative: wrong derivative with respect of %s\")\n"%(i0,i1,ref_diff,str(sh_in),str((i0,i1)))
1847    
1848     elif len(sh_in)==3:
1849     s=0
1850     for k0 in range(sh_in[0]):
1851     for k1 in range(sh_in[1]):
1852     for k2 in range(sh_in[2]):
1853     s+=trafo[k0,k1,k2]*a_in[k0,k1,k2]
1854     for i0 in range(sh_in[0]):
1855     for i1 in range(sh_in[1]):
1856     for i2 in range(sh_in[2]):
1857     ref_diff=(makeResult(s+trafo[i0,i1,i2]*finc,f)-makeResult(s,f))/finc
1858     t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative with respect of %s\")\n"%(i0,i1,i2,ref_diff,str(sh_in),str((i0,i1,i2)))
1859     elif len(sh_in)==4:
1860     s=0
1861     for k0 in range(sh_in[0]):
1862     for k1 in range(sh_in[1]):
1863     for k2 in range(sh_in[2]):
1864     for k3 in range(sh_in[3]):
1865     s+=trafo[k0,k1,k2,k3]*a_in[k0,k1,k2,k3]
1866     for i0 in range(sh_in[0]):
1867     for i1 in range(sh_in[1]):
1868     for i2 in range(sh_in[2]):
1869     for i3 in range(sh_in[3]):
1870     ref_diff=(makeResult(s+trafo[i0,i1,i2,i3]*finc,f)-makeResult(s,f))/finc
1871     t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative with respect of %s\")\n"%(i0,i1,i2,i3,ref_diff,str(sh_in),str((i0,i1,i2,i3)))
1872     elif len(sh)==1:
1873     for j0 in range(sh[0]):
1874     if len(sh_in)==0:
1875     ref_diff=(makeResult(trafo[j0]*a_in+finc,f)-makeResult(trafo[j0]*a_in,f))/finc
1876     t_prog+=" self.failUnlessAlmostEqual(dvdin[%s],%s,self.places,\"%s-derivative: wrong derivative of %s\")\n"%(j0,ref_diff,str(sh_in),j0)
1877     elif len(sh_in)==1:
1878     s=0
1879     for k0 in range(sh_in[0]):
1880     s+=trafo[j0,k0]*a_in[k0]
1881     for i0 in range(sh_in[0]):
1882     ref_diff=(makeResult(