/[escript]/trunk/escript/py_src/generateutil
ViewVC logotype

Annotation of /trunk/escript/py_src/generateutil

Parent Directory | Revision Log


Revision 429 - (hide annotations)
Wed Jan 11 05:53:40 2006 UTC (13 years, 9 months ago) by gross
File size: 101320 byte(s)
new implementation and testing for trace function
#!/usr/bin/python
# $Id$

"""
Generate parts of util.py and of the test_util.py unit-test script.
"""

# Source text emitted before the generated test methods: imports plus the
# opening of the Test_util2 test case, whose fixture builds a finley
# Rectangle and uses FunctionOnBoundary on it as the function space.
# NOTE(review): in this copy the class-body lines and the setUp body carry
# the same single leading space, so the emitted module would not be indented
# correctly; the whitespace looks collapsed - confirm against the repository.
test_header = "".join([
    "import unittest\n",
    "import numarray\n",
    "from esys.escript import *\n",
    "from esys.finley import Rectangle\n",
    "class Test_util2(unittest.TestCase):\n",
    " RES_TOL=1.e-7\n",
    " def setUp(self):\n",
    " self.__dom =Rectangle(11,11,2)\n",
    " self.functionspace = FunctionOnBoundary(self.__dom)\n",
])

# Source text emitted after the generated test methods: builds and runs the suite.
test_tail = "".join([
    "suite = unittest.TestSuite()\n",
    "suite.addTest(unittest.makeSuite(Test_util2))\n",
    "unittest.TextTestRunner(verbosity=2).run(suite)\n",
])

# Argument kinds exercised by the generated tests and the argument shapes
# (rank 0 up to rank 4) used for them.
case_set = ["float", "array", "constData", "taggedData", "expandedData", "Symbol"]
shape_set = [(), (2,), (4, 5), (6, 2, 2), (3, 2, 3, 4)]

# Accumulators for generated source text (taggedData tests are collected in
# t_prog_with_tags, the other tests in t_prog).
t_prog = ""
t_prog_with_tags = ""
t_prog_failing = ""
u_prog = ""
def wherepos(arg):
    """Return 1. if *arg* is strictly greater than zero, else 0. (also for 0.)."""
    if not arg > 0.:
        return 0.
    return 1.
35    
36 gross 313
class OPERATOR:
    """Record describing an operator for which tests are generated.

    In the expression strings the placeholder ``%a1%`` stands for the
    operator's argument.
    """
    def __init__(self,nickname,rng=None,test_expr="",math_expr=None,
                 numarray_expr="",symbol_expr=None,diff=None,name=""):
        """
        :param nickname: short identifier of the operator.
        :param rng: two-element value range; defaults to a new
            ``[-1000., 1000]`` list per instance (presumably the range used to
            draw test arguments - confirm with the callers).
        :param test_expr: Python expression with ``%a1%`` for the argument.
        :param math_expr: readable expression; if None it is derived from
            ``test_expr`` by replacing ``%a1%`` with ``arg``.
        :param numarray_expr: stored as given.
        :param symbol_expr: stored as given.
        :param diff: stored as given.
        :param name: stored as given.
        """
        if rng is None:
            # build a fresh list each time: a list default value would be one
            # object shared by every OPERATOR instance
            rng = [-1000., 1000]
        self.nickname = nickname
        self.rng = rng
        self.test_expr = test_expr
        if math_expr is None:
            self.math_expr = test_expr.replace("%a1%", "arg")
        else:
            self.math_expr = math_expr
        self.numarray_expr = numarray_expr
        self.symbol_expr = symbol_expr
        self.diff = diff
        self.name = name
51    
52 jgs 154 import random
53     import numarray
54     import math
# Small increment constant (1e-6); presumably a finite-difference step for
# derivative tests - it is not referenced in the code shown here, confirm.
finc = 1e-6
56    
def getResultCaseForBin(case0,case1):
    """Return the argument kind of the result of a binary operation.

    The kinds are ordered float < array < constData < taggedData <
    expandedData < Symbol and the result is the larger of the two.

    :raise ValueError: ``"unknown case0=..."`` if case0 is not a known kind
        (checked first), else ``"unknown case1=..."`` if case1 is unknown.
    """
    precedence = ("float", "array", "constData", "taggedData", "expandedData", "Symbol")
    if case0 not in precedence:
        raise ValueError("unknown case0=%s" % case0)
    if case1 not in precedence:
        raise ValueError("unknown case1=%s" % case1)
    rank0 = precedence.index(case0)
    rank1 = precedence.index(case1)
    return precedence[max(rank0, rank1)]
151 jgs 154
152 gross 157
def makeArray(shape,rng):
    """Return random test data of the given shape.

    Every value is ``(rng[1]-rng[0])*random.random()+rng[0]``, so it lies
    between ``rng[0]`` and ``rng[1]``.  Values are drawn in row-major order,
    which keeps results reproducible under ``random.seed``.

    :param shape: tuple (or list) of axis extents; any rank is supported.
    :param rng: ``[low, high]`` range of the values.
    :return: a Python float for ``shape == ()``, otherwise a Float64
        numarray array of that shape.
    """
    low = rng[0]
    width = rng[1] - rng[0]
    if len(shape) == 0:
        # scalar data is represented by a plain float; no array is needed
        return width * random.random() + low

    def indices(sh):
        # every index tuple of shape sh, in row-major order
        if len(sh) == 0:
            yield ()
        else:
            for i in range(sh[0]):
                for rest in indices(sh[1:]):
                    yield (i,) + rest

    out = numarray.zeros(shape, numarray.Float64)
    for idx in indices(tuple(shape)):
        out[idx] = width * random.random() + low
    return out
179    
180    
def makeResult(val,test_expr):
    """Apply the scalar expression *test_expr* element-wise to *val*.

    ``test_expr`` is Python source in which ``%a1%`` stands for one element
    of ``val``.  It is evaluated with ``eval`` in this function's scope, so
    module-level names (e.g. ``math``) are available.  Only generator-internal
    expressions are expected here - never pass untrusted text.

    :param val: a float or a numarray array of any rank.
    :return: the ``eval`` result for a float or rank-0 ``val``, otherwise a
        Float64 numarray array with the shape of ``val``.
    """
    if isinstance(val,float) or len(val.shape)==0:
        return eval(test_expr.replace("%a1%","val"))

    def indices(sh):
        # every index tuple of shape sh, in row-major order
        if len(sh) == 0:
            yield ()
        else:
            for i in range(sh[0]):
                for rest in indices(sh[1:]):
                    yield (i,) + rest

    # substitute and compile once rather than once per element
    code = compile(test_expr.replace("%a1%","elem"), "<test_expr>", "eval")
    out = numarray.zeros(val.shape, numarray.Float64)
    for idx in indices(tuple(val.shape)):
        elem = val[idx]
        out[idx] = eval(code)
    return out
211    
def makeResult2(val0,val1,test_expr):
    """Apply the binary scalar expression *test_expr* to *val0* and *val1*.

    In ``test_expr``, ``%a1%`` stands for an element of ``val0`` and ``%a2%``
    for an element of ``val1``.  Floats and rank-0 arrays are used as they
    are.  Otherwise, with ``r0``/``r1`` the ranks of the arguments, the
    result has shape ``val0.shape + val1.shape[r0:]``, and its element at
    index ``idx`` is computed from ``val0[idx[:r0]]`` and ``val1[idx[:r1]]``.
    That is, the lower-rank argument is paired with the leading indices of
    the other.  Any rank is supported.

    The expression is evaluated with ``eval`` in this function's scope;
    only generator-internal expressions are expected here.

    :return: the ``eval`` result if both arguments are scalar, otherwise a
        Float64 numarray array of the shape described above.
    """
    if isinstance(val0,float):
        shape0 = ()
    else:
        shape0 = tuple(val0.shape)
    if isinstance(val1,float):
        shape1 = ()
    else:
        shape1 = tuple(val1.shape)
    r0 = len(shape0)
    r1 = len(shape1)
    out_shape = shape0 + shape1[r0:]
    if len(out_shape) == 0:
        return eval(test_expr.replace("%a1%","val0").replace("%a2%","val1"))

    def indices(sh):
        # every index tuple of shape sh, in row-major order
        if len(sh) == 0:
            yield ()
        else:
            for i in range(sh[0]):
                for rest in indices(sh[1:]):
                    yield (i,) + rest

    # substitute and compile once rather than once per element
    code = compile(test_expr.replace("%a1%","elem0").replace("%a2%","elem1"),
                   "<test_expr>", "eval")
    out = numarray.zeros(out_shape, numarray.Float64)
    for idx in indices(out_shape):
        if r0 == 0:
            elem0 = val0
        else:
            elem0 = val0[idx[:r0]]
        if r1 == 0:
            elem1 = val1
        else:
            elem1 = val1[idx[:r1]]
        out[idx] = eval(code)
    return out
428 jgs 154
429 gross 157
def mkText(case,name,a,a1=None,use_tagging_for_expanded_data=False):
    """Return generated test source that assigns test data to variable *name*.

    :param case: kind of object created by the generated code: "float",
        "array", "constData", "taggedData", "expandedData" or "Symbol";
        any other value gives an empty string.
    :param name: variable name used in the generated code.
    :param a: the value, a float or a numarray array (rank-0 arrays are
        written like floats, other arrays through ``tolist()``).
    :param a1: second value, used as the tag-1 value for "taggedData" and as
        the second value for "expandedData"; ignored by the other kinds.
    :param use_tagging_for_expanded_data: for "expandedData", build tagged
        data followed by ``expand()`` instead of mixing ``a`` and ``a1``
        through a mask on the first coordinate.
    :return: generated source lines, each terminated by a newline.
    """
    if case not in ("float","array","constData","taggedData","expandedData","Symbol"):
        return ""
    a_is_float=isinstance(a,float)
    scalar=a_is_float or a.rank==0
    if case=="Symbol":
        # a Symbol only needs the shape of the data, not its values
        if scalar:
            return " %s=Symbol(shape=())\n"%(name)
        return " %s=Symbol(shape=%s)\n"%(name,str(a.shape))
    if scalar:
        value=a
    else:
        value=a.tolist()
    if case=="float":
        if scalar:
            return " %s=%s\n"%(name,value)
        return " %s=numarray.array(%s)\n"%(name,value)
    if case=="array":
        return " %s=numarray.array(%s)\n"%(name,value)
    if case=="constData":
        if scalar:
            return " %s=Data(%s,self.functionspace)\n"%(name,value)
        return " %s=Data(numarray.array(%s),self.functionspace)\n"%(name,value)
    if case=="taggedData" or use_tagging_for_expanded_data:
        if scalar:
            text=" %s=Data(%s,self.functionspace)\n"%(name,value)
            text+=" %s.setTaggedValue(1,%s)\n"%(name,a1)
        else:
            text=" %s=Data(numarray.array(%s),self.functionspace)\n"%(name,value)
            text+=" %s.setTaggedValue(1,numarray.array(%s))\n"%(name,a1.tolist())
        if case=="expandedData":
            text+=" %s.expand()\n"%name
        return text
    # expandedData via a mask: msk_<name> is presumably 1 where the first
    # coordinate is below 0.5, so a is used there and a1 elsewhere
    text=" msk_%s=whereNegative(self.functionspace.getX()[0]-0.5)\n"%name
    if a_is_float:
        text+=" %s=msk_%s*(%s)+(1.-msk_%s)*(%s)\n"%(name,name,a,name,a1)
    elif scalar:
        text+=" %s=msk_%s*numarray.array(%s)+(1.-msk_%s)*numarray.array(%s)\n"%(name,name,a,name,a1)
    else:
        text+=" %s=msk_%s*numarray.array(%s)+(1.-msk_%s)*numarray.array(%s)\n"%(name,name,value,name,a1.tolist())
    return text
492    
def mkTypeAndShapeTest(case,sh,argstr):
    """Return generated assertions on the type and shape of a result.

    :param case: result kind ("float", "array", "constData", "taggedData",
        "expandedData" or "Symbol"); other values give an empty string.
    :param sh: expected shape (not checked for "float").
    :param argstr: source text of the expression under test.
    """
    if case=="float":
        return " self.failUnless(isinstance(%s,float),\"wrong type of result.\")\n"%argstr
    if case=="array":
        type_line=" self.failUnless(isinstance(%s,numarray.NumArray),\"wrong type of result.\")\n"%argstr
        shape_line=" self.failUnlessEqual(%s.shape,%s,\"wrong shape of result.\")\n"%(argstr,str(sh))
        return type_line+shape_line
    if case in ("constData","taggedData","expandedData"):
        type_name="Data"
    elif case=="Symbol":
        type_name="Symbol"
    else:
        return ""
    # Data and Symbol objects both report their shape through getShape()
    type_line=" self.failUnless(isinstance(%s,%s),\"wrong type of result.\")\n"%(argstr,type_name)
    shape_line=" self.failUnlessEqual(%s.getShape(),%s,\"wrong shape of result.\")\n"%(argstr,str(sh))
    return type_line+shape_line
507 jgs 154
def mkCode(txt,args=(),intend=""):
    """Turn code/expression text into indented source with placeholders filled.

    Multi-line *txt* is copied line by line, each line prefixed with
    *intend* (the indentation string).  A single-line *txt* becomes
    ``<intend>return <txt>``.  Afterwards ``%a1%``, ``%a2%``, ... are replaced
    by ``args[0]``, ``args[1]``, ...

    :param txt: source text containing ``%aN%`` placeholders.
    :param args: replacement strings, in placeholder order.
    :param intend: indentation prefix for the produced lines.
    """
    lines = txt.split("\n")
    if len(lines) > 1:
        out = "".join([intend + line + "\n" for line in lines])
    else:
        out = "%sreturn %s\n" % (intend, txt)
    # placeholders are numbered from 1: args[k] fills %a<k+1>%
    for position, replacement in enumerate(args):
        out = out.replace("%%a%s%%" % (position + 1), replacement)
    return out
520 jgs 154
def innerTEST(arg0,arg1):
    """Reference result for inner(arg0,arg1).

    For a float ``arg0`` the product ``arg0*arg1`` is returned wrapped in
    ``numarray.array``; otherwise the element-wise product is summed over
    all elements.
    """
    if not isinstance(arg0,float):
        return (arg0*arg1).sum()
    return numarray.array(arg0*arg1)
527    
def outerTEST(arg0,arg1):
    """Reference result for outer(arg0,arg1).

    If either argument is a float, the product ``arg0*arg1`` is returned
    wrapped in ``numarray.array``.  Otherwise the outer product is reshaped
    to ``arg0.shape + arg1.shape``.
    """
    if isinstance(arg0,float) or isinstance(arg1,float):
        return numarray.array(arg0*arg1)
    full_shape=arg0.shape+arg1.shape
    flat_outer=numarray.outerproduct(arg0,arg1)
    return flat_outer.resize(full_shape)
536    
def tensorProductTest(arg0,arg1,sh_s):
    """Reference generalized tensor product contracting over the shape *sh_s*.

    The trailing ``len(sh_s)`` axes of ``arg0`` are contracted with the
    leading ``len(sh_s)`` axes of ``arg1``; the result has shape
    ``arg0.shape[:rank0-len(sh_s)] + arg1.shape[len(sh_s):]``.  With an
    empty ``sh_s`` this is the outer product.  If either argument is a float,
    the plain product is returned wrapped in ``numarray.array``.
    """
    if isinstance(arg0,float) or isinstance(arg1,float):
        return numarray.array(arg0*arg1)
    if len(sh_s)==0:
        return numarray.outerproduct(arg0,arg1).resize(arg0.shape+arg1.shape)
    n_s=len(sh_s)
    lead_shape=arg0.shape[:arg0.rank-n_s]
    tail_shape=arg1.shape[n_s:]
    size_s=1
    for extent in sh_s:
        size_s*=extent
    size_lead=1
    for extent in lead_shape:
        size_lead*=extent
    size_tail=1
    for extent in tail_shape:
        size_tail*=extent
    # the outer product viewed as (lead, contracted, contracted, tail);
    # summing its diagonal in the two middle axes performs the contraction
    blocks=numarray.outerproduct(arg0,arg1).resize((size_lead,size_s,size_s,size_tail))
    acc=numarray.zeros((size_lead,size_tail),numarray.Float)
    for p in range(size_lead):
        for q in range(size_tail):
            for k in range(size_s):
                acc[p,q]+=blocks[p,k,k,q]
    return acc.resize(lead_shape+tail_shape)
559    
def testMatrixMult(arg0,arg1,sh_s):
    """Reference result for matrix_mult: ``numarray.matrixmultiply(arg0,arg1)``.

    ``sh_s`` is not used; it keeps the signature parallel to
    tensorProductTest and testTensorMult.
    """
    product=numarray.matrixmultiply(arg0,arg1)
    return product
562    
563    
def testTensorMult(arg0,arg1,sh_s):
    """Reference result for tensor_mult(arg0,arg1).

    A rank-2 ``arg0`` is multiplied with ``numarray.matrixmultiply``.  For a
    rank-4 ``arg0``, its last two axes are contracted with the first two
    axes of ``arg1`` (rank 2, 3 or 4), giving shape
    ``arg0.shape[:2] + arg1.shape[2:]``.  ``sh_s`` is not used.

    :raise SystemError: for rank combinations other than those above.
    """
    if arg0.rank==2:
        # dispatch on the rank: len(arg0) is only the extent of the first axis
        return numarray.matrixmultiply(arg0,arg1)
    if arg0.rank!=4:
        raise SystemError("rank of arg0 must be 2 or 4")
    n0,n1,n2,n3=arg0.shape
    if arg1.rank==4:
        out=numarray.zeros((n0,n1,arg1.shape[2],arg1.shape[3]),numarray.Float)
        for i0 in range(n0):
            for i1 in range(n1):
                for i2 in range(n2):
                    for i3 in range(n3):
                        for j2 in range(arg1.shape[2]):
                            for j3 in range(arg1.shape[3]):
                                out[i0,i1,j2,j3]+=arg0[i0,i1,i2,i3]*arg1[i2,i3,j2,j3]
    elif arg1.rank==3:
        out=numarray.zeros((n0,n1,arg1.shape[2]),numarray.Float)
        for i0 in range(n0):
            for i1 in range(n1):
                for i2 in range(n2):
                    for i3 in range(n3):
                        for j2 in range(arg1.shape[2]):
                            out[i0,i1,j2]+=arg0[i0,i1,i2,i3]*arg1[i2,i3,j2]
    elif arg1.rank==2:
        out=numarray.zeros((n0,n1),numarray.Float)
        for i0 in range(n0):
            for i1 in range(n1):
                for i2 in range(n2):
                    for i3 in range(n3):
                        out[i0,i1]+=arg0[i0,i1,i2,i3]*arg1[i2,i3]
    else:
        # previously this fell through to an unbound `out`
        raise SystemError("rank of arg1 must be 2, 3 or 4")
    return out
593 gross 313
594     def testReduce(arg0,init_val,test_expr,post_expr):
595     out=init_val
596     if isinstance(arg0,float):
597     out=eval(test_expr.replace("%a1%","arg0"))
598     elif arg0.rank==0:
599     out=eval(test_expr.replace("%a1%","arg0"))
600     elif arg0.rank==1:
601     for i0 in range(arg0.shape[0]):
602     out=eval(test_expr.replace("%a1%","arg0[i0]"))
603     elif arg0.rank==2:
604     for i0 in range(arg0.shape[0]):
605     for i1 in range(arg0.shape[1]):
606     out=eval(test_expr.replace("%a1%","arg0[i0,i1]"))
607     elif arg0.rank==3:
608     for i0 in range(arg0.shape[0]):
609     for i1 in range(arg0.shape[1]):
610     for i2 in range(arg0.shape[2]):
611     out=eval(test_expr.replace("%a1%","arg0[i0,i1,i2]"))
612     elif arg0.rank==4:
613     for i0 in range(arg0.shape[0]):
614     for i1 in range(arg0.shape[1]):
615     for i2 in range(arg0.shape[2]):
616     for i3 in range(arg0.shape[3]):
617     out=eval(test_expr.replace("%a1%","arg0[i0,i1,i2,i3]"))
618     return eval(post_expr)
619 gross 396
def clipTEST(arg0,mn,mx):
    """Reference result for clip(arg0,mn,mx): every value limited to [mn, mx].

    :param arg0: a float or an array of any rank.
    :return: the clipped value for a float or rank-0 argument, otherwise a
        Float64 numarray array with the shape of ``arg0``.
    """
    if isinstance(arg0,float) or len(arg0.shape)==0:
        # rank-0 arrays are clipped like floats (they used to come back as zeros)
        return max(min(arg0,mx),mn)

    def indices(sh):
        # every index tuple of shape sh, in row-major order
        if len(sh) == 0:
            yield ()
        else:
            for i in range(sh[0]):
                for rest in indices(sh[1:]):
                    yield (i,) + rest

    out=numarray.zeros(arg0.shape,numarray.Float64)
    for idx in indices(tuple(arg0.shape)):
        out[idx]=max(min(arg0[idx],mx),mn)
    return out
def minimumTEST(arg0,arg1):
    """Reference element-wise minimum of *arg0* and *arg1*.

    Two floats give a float.  A float paired with an array is broadcast to
    that array's shape first (``numarray.ones(shape)*value``).  Rank-0
    arrays are compared directly; other arrays give a Float64 numarray
    array of ``arg0``'s shape.  On ties the value of ``arg0`` is returned.
    Any rank is supported (ranks above 4 used to yield zeros).
    """
    if isinstance(arg0,float):
        if isinstance(arg1,float):
            if arg0>arg1:
                return arg1
            return arg0
        arg0=numarray.ones(arg1.shape)*arg0
    elif isinstance(arg1,float):
        arg1=numarray.ones(arg0.shape)*arg1
    if len(arg0.shape)==0:
        if arg0>arg1:
            return arg1
        return arg0

    def indices(sh):
        # every index tuple of shape sh, in row-major order
        if len(sh) == 0:
            yield ()
        else:
            for i in range(sh[0]):
                for rest in indices(sh[1:]):
                    yield (i,) + rest

    out=numarray.zeros(arg0.shape,numarray.Float64)
    for idx in indices(tuple(arg0.shape)):
        if arg0[idx]>arg1[idx]:
            out[idx]=arg1[idx]
        else:
            out[idx]=arg0[idx]
    return out
#=======================================================================================================
# trace
#=======================================================================================================
def traceTest(r,offset):
    """Reference trace of *r* over the axes ``offset`` and ``offset+1``.

    The two traced axes are assumed to have equal extent.  The result has
    ``r``'s shape with those two axes removed.
    """
    shape=r.shape
    n_before=1
    for extent in shape[:offset]:
        n_before*=extent
    n_after=1
    for extent in shape[offset+2:]:
        n_after*=extent
    n_diag=shape[offset]
    # view r as (before, diag, diag, after) and sum the diagonal
    blocks=numarray.reshape(r,(n_before,n_diag,n_diag,n_after))
    acc=numarray.zeros([n_before,n_after],numarray.Float)
    for p in range(n_before):
        for q in range(n_after):
            for j in range(n_diag):
                acc[p,q]+=blocks[p,j,j,q]
    return acc.resize(shape[:offset]+shape[offset+2:])
# Generate the tests for trace(arg,offset).  For every argument kind, shape
# and offset, axis offset+1 of the test data is given the extent of axis
# offset; the reference result comes from traceTest.
name,tt="trace",traceTest
for case0 in ["array","Symbol","constData","taggedData","expandedData"]:
    for sh0 in [(4,5),(6,2,2),(3,2,3,4)]:
        for offset in range(len(sh0)-1):
            text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
            tname="test_%s_%s_rank%s_offset%s"%(name,case0,len(sh0),offset)
            text+=" def %s(self):\n"%tname
            # make the two traced axes equally long
            sh_t=list(sh0)
            sh_t[offset+1]=sh_t[offset]
            sh_t=tuple(sh_t)
            # expected result shape: the two traced axes removed
            sh_r=tuple(list(sh0[:offset])+list(sh0[offset+2:]))
            a_0=makeArray(sh_t,[-1.,1])
            if case0 in ["taggedData","expandedData"]:
                a1_0=makeArray(sh_t,[-1.,1])
            else:
                a1_0=a_0
            r=tt(a_0,offset)
            r1=tt(a1_0,offset)
            text+=mkText(case0,"arg",a_0,a1_0)
            text+=" res=%s(arg,%s)\n"%(name,offset)
            if case0=="Symbol":
                # evaluate the symbolic result by substituting array data
                text+=mkText("array","s",a_0,a1_0)
                text+=" sub=res.substitute({arg:s})\n"
                res="sub"
                text+=mkText("array","ref",r,r1)
            else:
                res="res"
                text+=mkText(case0,"ref",r,r1)
            text+=mkTypeAndShapeTest(case0,sh_r,"res")
            text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
            if case0=="taggedData":
                t_prog_with_tags+=text
            else:
                t_prog+=text

# Only the tagged tests are printed here; t_prog is collected but not emitted.
print(test_header)
print(t_prog_with_tags)
print(test_tail)
# Stop after the trace tests.  An explicit exit replaces the former 1/0,
# which ended the run with a ZeroDivisionError traceback and a failure status.
raise SystemExit(0)
751 gross 396
#=======================================================================================================
# clip
#=======================================================================================================
# Generate the tests for clip(arg,-0.3,0.5) for every argument kind and
# shape; plain floats are only used for rank 0.
oper_L=[["clip",clipTEST]]
for oper in oper_L:
    for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
        for sh0 in [(),(2,),(4,5),(6,2,2),(3,2,3,4)]:
            if case0=="float" and len(sh0)>0:
                continue
            text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
            tname="test_%s_%s_rank%s"%(oper[0],case0,len(sh0))
            text+=" def %s(self):\n"%tname
            a_0=makeArray(sh0,[-1.,1])
            if case0 in ["taggedData","expandedData"]:
                a1_0=makeArray(sh0,[-1.,1])
            else:
                a1_0=a_0
            r=oper[1](a_0,-0.3,0.5)
            r1=oper[1](a1_0,-0.3,0.5)
            text+=mkText(case0,"arg",a_0,a1_0)
            text+=" res=%s(arg,-0.3,0.5)\n"%oper[0]
            if case0=="Symbol":
                # evaluate the symbolic result by substituting array data
                text+=mkText("array","s",a_0,a1_0)
                text+=" sub=res.substitute({arg:s})\n"
                res="sub"
                text+=mkText("array","ref",r,r1)
            else:
                res="res"
                text+=mkText(case0,"ref",r,r1)
            text+=mkTypeAndShapeTest(case0,sh0,"res")
            text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
            if case0=="taggedData":
                t_prog_with_tags+=text
            else:
                t_prog+=text

# Only the tagged tests are printed here; t_prog is collected but not emitted.
print(test_header)
print(t_prog_with_tags)
print(test_tail)
# Stop after the clip tests.  An explicit exit replaces the former 1/0,
# which ended the run with a ZeroDivisionError traceback and a failure status.
raise SystemExit(0)
794 gross 429
#=======================================================================================================
# maximum, minimum
#=======================================================================================================
# One test per combination of (type, rank) of both arguments. A "float"
# argument is only used at rank 0, and the two shapes must agree unless one
# of them is rank 0.
oper_L=[ ["maximum",maximumTEST],
         ["minimum",minimumTEST]]
for oper_name, oper_func in oper_L:
    for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
        for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
            for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:
                for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
                    if (case0=="float" and len(sh0)>0) or (case1=="float" and len(sh1)>0):
                        continue
                    if len(sh0)>0 and len(sh1)>0 and sh0!=sh1:
                        continue
                    use_tagging_for_expanded_data="taggedData" in (case0,case1)
                    tname="test_%s_%s_rank%s_%s_rank%s"%(oper_name,case0,len(sh0),case1,len(sh1))
                    text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
                    text+=" def %s(self):\n"%tname
                    a_0=makeArray(sh0,[-1.,1])
                    a1_0=makeArray(sh0,[-1.,1]) if case0 in ("taggedData","expandedData") else a_0
                    a_1=makeArray(sh1,[-1.,1])
                    a1_1=makeArray(sh1,[-1.,1]) if case1 in ("taggedData","expandedData") else a_1
                    r=oper_func(a_0,a_1)
                    r1=oper_func(a1_0,a1_1)
                    text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
                    text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
                    text+=" res=%s(arg0,arg1)\n"%oper_name
                    case=getResultCaseForBin(case0,case1)
                    if case=="Symbol":
                        # substitute a concrete array for every symbolic argument
                        c0_res,c1_res=case0,case1
                        subs_items=[]
                        if case0=="Symbol":
                            text+=mkText("array","s0",a_0,a1_0)
                            subs_items.append("arg0:s0")
                            c0_res="array"
                        if case1=="Symbol":
                            text+=mkText("array","s1",a_1,a1_1)
                            subs_items.append("arg1:s1")
                            c1_res="array"
                        text+=" sub=res.substitute(%s)\n"%("{"+",".join(subs_items)+"}")
                        res="sub"
                        text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)
                    else:
                        res="res"
                        text+=mkText(case,"ref",r,r1)
                    # the result takes the shape of the higher-rank argument
                    if len(sh0)>len(sh1):
                        sh_res=sh0
                    else:
                        sh_res=sh1
                    text+=mkTypeAndShapeTest(case,sh_res,"res")
                    text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
                    if "taggedData" in (case0,case1):
                        t_prog_with_tags+=text
                    else:
                        t_prog+=text

print(test_header)
# print(t_prog)
print(t_prog_with_tags)
print(test_tail)
1/0
864    
865    
#=======================================================================================================
# outer inner
#=======================================================================================================
# Tests for the outer product: the result shape is sh0+sh1, and only
# combinations with a total rank below 5 are generated.
oper=["outer",outerTEST]
# oper=["inner",innerTEST]
oper_name,oper_func=oper
for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
    for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
        for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:
            for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
                if (case0=="float" and len(sh0)>0) or (case1=="float" and len(sh1)>0):
                    continue
                if len(sh0+sh1)>=5:
                    continue
                use_tagging_for_expanded_data="taggedData" in (case0,case1)
                tname="test_%s_%s_rank%s_%s_rank%s"%(oper_name,case0,len(sh0),case1,len(sh1))
                text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
                text+=" def %s(self):\n"%tname
                a_0=makeArray(sh0,[-1.,1])
                a1_0=makeArray(sh0,[-1.,1]) if case0 in ("taggedData","expandedData") else a_0
                a_1=makeArray(sh1,[-1.,1])
                a1_1=makeArray(sh1,[-1.,1]) if case1 in ("taggedData","expandedData") else a_1
                r=oper_func(a_0,a_1)
                r1=oper_func(a1_0,a1_1)
                text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
                text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
                text+=" res=%s(arg0,arg1)\n"%oper_name
                case=getResultCaseForBin(case0,case1)
                if case=="Symbol":
                    # substitute a concrete array for every symbolic argument
                    c0_res,c1_res=case0,case1
                    subs_items=[]
                    if case0=="Symbol":
                        text+=mkText("array","s0",a_0,a1_0)
                        subs_items.append("arg0:s0")
                        c0_res="array"
                    if case1=="Symbol":
                        text+=mkText("array","s1",a_1,a1_1)
                        subs_items.append("arg1:s1")
                        c1_res="array"
                    text+=" sub=res.substitute(%s)\n"%("{"+",".join(subs_items)+"}")
                    res="sub"
                    text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)
                else:
                    res="res"
                    text+=mkText(case,"ref",r,r1)
                text+=mkTypeAndShapeTest(case,sh0+sh1,"res")
                text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
                if "taggedData" in (case0,case1):
                    t_prog_with_tags+=text
                else:
                    t_prog+=text

print(test_header)
# print(t_prog)
print(t_prog_with_tags)
print(test_tail)
1/0
931    
#=======================================================================================================
# local reduction
#=======================================================================================================
# Each reduction is described by (name, initial value, update expression in
# "out" and %a1%, final expression in "out"); testReduce computes the
# reference value from these.
for red_name, red_init, red_update, red_final in [
        ["length",0.,"out+%a1%**2","math.sqrt(out)"],
        ["maxval",-1.e99,"max(out,%a1%)","out"],
        ["minval",1.e99,"min(out,%a1%)","out"] ]:
    for case in case_set:
        for sh in shape_set:
            if case=="float" and len(sh)>0:
                continue
            tname="def test_%s_%s_rank%s"%(red_name,case,len(sh))
            text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
            text+=" %s(self):\n"%tname
            a=makeArray(sh,[-1.,1.])
            a1=makeArray(sh,[-1.,1.])
            r1=testReduce(a1,red_init,red_update,red_final)
            r=testReduce(a,red_init,red_update,red_final)
            text+=mkText(case,"arg",a,a1)
            text+=" res=%s(arg)\n"%red_name
            if case=="Symbol":
                text+=mkText("array","s",a,a1)
                text+=" sub=res.substitute({arg:s})\n"
                text+=mkText("array","ref",r,r1)
                res="sub"
            else:
                text+=mkText(case,"ref",r,r1)
                res="res"
            # maxval/minval of a float or array argument are expected to be a float;
            # every other combination keeps the argument type (always rank 0)
            if red_name!="length" and case in ("float","array"):
                text+=mkTypeAndShapeTest("float",(),"res")
            else:
                text+=mkTypeAndShapeTest(case,(),"res")
            text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
            if case=="taggedData":
                t_prog_with_tags+=text
            else:
                t_prog+=text
print(test_header)
# print(t_prog)
print(t_prog_with_tags)
print(test_tail)
1/0
977    
#=======================================================================================================
# tensor multiply
#=======================================================================================================
# arg0 has shape sh0+sh_s and arg1 has shape sh_s+sh1; the contraction over
# sh_s gives a result of shape sh0+sh1. The reference comes from oper's TEST
# function, which receives sh_s.
# oper=["generalTensorProduct",tensorProductTest]
# oper=["matrixmult",testMatrixMult]
oper=["tensormult",testTensorMult]
oper_name,oper_func=oper

for case0 in ["float","array","Symbol","constData","taggedData","expandedData"]:
    for sh0 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
        for case1 in ["float","array","Symbol","constData","taggedData","expandedData"]:
            for sh1 in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
                for sh_s in [ (),(3,), (2,3), (2,4,3),(4,2,3,2)]:
                    rk0=len(sh0+sh_s)    # rank of arg0
                    rk1=len(sh_s+sh1)    # rank of arg1
                    n_contract=len(sh_s) # number of contracted indices
                    if (case0=="float" and rk0>0) or (case1=="float" and rk1>0):
                        continue
                    if len(sh0+sh1)>=5 or rk0>=5 or rk1>=5:
                        continue
                    # rank combinations accepted by tensormult
                    # (matrixmult would be: n_contract==1 and rk0==2 and rk1 in (1,2))
                    if not ((n_contract==1 and rk0==2 and rk1 in (1,2)) or
                            (n_contract==2 and rk0==4 and rk1 in (2,3,4))):
                        continue
                    case=getResultCaseForBin(case0,case1)
                    use_tagging_for_expanded_data="taggedData" in (case0,case1)
                    tname="test_%s_%s_rank%s_%s_rank%s"%(oper_name,case0,rk0,case1,rk1)
                    text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
                    text+=" def %s(self):\n"%tname
                    a_0=makeArray(sh0+sh_s,[-1.,1])
                    a1_0=makeArray(sh0+sh_s,[-1.,1]) if case0 in ("taggedData","expandedData") else a_0
                    a_1=makeArray(sh_s+sh1,[-1.,1])
                    a1_1=makeArray(sh_s+sh1,[-1.,1]) if case1 in ("taggedData","expandedData") else a_1
                    r=oper_func(a_0,a_1,sh_s)
                    r1=oper_func(a1_0,a1_1,sh_s)
                    text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
                    text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
                    # generalTensorProduct would additionally need offset=n_contract
                    text+=" res=%s(arg0,arg1)\n"%oper_name
                    if case=="Symbol":
                        # substitute a concrete array for every symbolic argument
                        c0_res,c1_res=case0,case1
                        subs_items=[]
                        if case0=="Symbol":
                            text+=mkText("array","s0",a_0,a1_0)
                            subs_items.append("arg0:s0")
                            c0_res="array"
                        if case1=="Symbol":
                            text+=mkText("array","s1",a_1,a1_1)
                            subs_items.append("arg1:s1")
                            c1_res="array"
                        text+=" sub=res.substitute(%s)\n"%("{"+",".join(subs_items)+"}")
                        res="sub"
                        text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)
                    else:
                        res="res"
                        text+=mkText(case,"ref",r,r1)
                    text+=mkTypeAndShapeTest(case,sh0+sh1,"res")
                    text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
                    if "taggedData" in (case0,case1):
                        t_prog_with_tags+=text
                    else:
                        t_prog+=text
print(test_header)
# print(t_prog)
print(t_prog_with_tags)
print(test_tail)
1/0
#=======================================================================================================
# basic binary operation overloading (tests only!)
#=======================================================================================================
# Tests for the overloaded Python operators arg0<op>arg1. Tests with a
# Data left operand and a Symbol right operand go to t_prog_failing,
# which is the only test group printed at the end of this section.
oper_range=[-5.,5.]
for oper_name, oper_sym, oper_rng in [["add" ,"+",[-5.,5.]],
                                      ["sub" ,"-",[-5.,5.]],
                                      ["mult","*",[-5.,5.]],
                                      ["div" ,"/",[-5.,5.]],
                                      ["pow" ,"**",[0.01,5.]]]:
    expr="%a1%"+oper_sym+"%a2%"
    for case0 in case_set:
        for sh0 in shape_set:
            for case1 in case_set:
                for sh1 in shape_set:
                    if case0=="array":
                        continue
                    if (case0=="float" and len(sh0)>0) or (case1=="float" and len(sh1)>0):
                        continue
                    if not (sh0==() or sh1==() or sh1==sh0):
                        continue
                    if case0 in ("float","array") and case1 in ("float","array"):
                        continue
                    use_tagging_for_expanded_data="taggedData" in (case0,case1)
                    tname="test_%s_overloaded_%s_rank%s_%s_rank%s"%(oper_name,case0,len(sh0),case1,len(sh1))
                    text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
                    text+=" def %s(self):\n"%tname
                    a_0=makeArray(sh0,oper_rng)
                    a1_0=makeArray(sh0,oper_rng) if case0 in ("taggedData","expandedData") else a_0
                    a_1=makeArray(sh1,oper_rng)
                    a1_1=makeArray(sh1,oper_rng) if case1 in ("taggedData","expandedData") else a_1
                    r1=makeResult2(a1_0,a1_1,expr)
                    r=makeResult2(a_0,a_1,expr)
                    text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
                    text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
                    text+=" res=arg0%sarg1\n"%oper_sym
                    case=getResultCaseForBin(case0,case1)
                    if case=="Symbol":
                        # substitute a concrete array for every symbolic argument
                        c0_res,c1_res=case0,case1
                        subs_items=[]
                        if case0=="Symbol":
                            text+=mkText("array","s0",a_0,a1_0)
                            subs_items.append("arg0:s0")
                            c0_res="array"
                        if case1=="Symbol":
                            text+=mkText("array","s1",a_1,a1_1)
                            subs_items.append("arg1:s1")
                            c1_res="array"
                        text+=" sub=res.substitute(%s)\n"%("{"+",".join(subs_items)+"}")
                        res="sub"
                        text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)
                    else:
                        res="res"
                        text+=mkText(case,"ref",r,r1)
                    if isinstance(r,float):
                        sh_res=()
                    else:
                        sh_res=r.shape
                    text+=mkTypeAndShapeTest(case,sh_res,"res")
                    text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
                    if case0 in ("constData","taggedData","expandedData") and case1=="Symbol":
                        t_prog_failing+=text
                    elif "taggedData" in (case0,case1):
                        t_prog_with_tags+=text
                    else:
                        t_prog+=text

print(test_header)
# print(t_prog)
# print(t_prog_with_tags)
print(t_prog_failing)
print(test_tail)
1/0
#=======================================================================================================
# basic binary operations (tests only!)
#=======================================================================================================
# Tests for the function forms add/mult/quotient/power. One argument has
# shape sh and the other has sh extended by sh_p. When sh_p is non-empty,
# both assignments are generated (the extension on arg1 first, then on arg0).
oper_range=[-5.,5.]
for oper_name, oper_sym, oper_rng in [["add" ,"+",[-5.,5.]],
                                      ["mult","*",[-5.,5.]],
                                      ["quotient" ,"/",[-5.,5.]],
                                      ["power" ,"**",[0.01,5.]]]:
    expr="%a1%"+oper_sym+"%a2%"
    for case0 in case_set:
        for case1 in case_set:
            for sh in shape_set:
                for sh_p in shape_set:
                    for sh_d in ([-1,1] if len(sh_p)>0 else [1]):
                        if sh_d>0:
                            sh0,sh1=sh,sh+sh_p
                        else:
                            sh0,sh1=sh+sh_p,sh
                        if (case0=="float" and len(sh0)>0) or (case1=="float" and len(sh1)>0):
                            continue
                        if len(sh0)>=5 or len(sh1)>=5:
                            continue
                        use_tagging_for_expanded_data="taggedData" in (case0,case1)
                        tname="test_%s_%s_rank%s_%s_rank%s"%(oper_name,case0,len(sh0),case1,len(sh1))
                        text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
                        text+=" def %s(self):\n"%tname
                        a_0=makeArray(sh0,oper_rng)
                        a1_0=makeArray(sh0,oper_rng) if case0 in ("taggedData","expandedData") else a_0
                        a_1=makeArray(sh1,oper_rng)
                        a1_1=makeArray(sh1,oper_rng) if case1 in ("taggedData","expandedData") else a_1
                        r1=makeResult2(a1_0,a1_1,expr)
                        r=makeResult2(a_0,a_1,expr)
                        text+=mkText(case0,"arg0",a_0,a1_0,use_tagging_for_expanded_data)
                        text+=mkText(case1,"arg1",a_1,a1_1,use_tagging_for_expanded_data)
                        text+=" res=%s(arg0,arg1)\n"%oper_name
                        case=getResultCaseForBin(case0,case1)
                        if case=="Symbol":
                            # substitute a concrete array for every symbolic argument
                            c0_res,c1_res=case0,case1
                            subs_items=[]
                            if case0=="Symbol":
                                text+=mkText("array","s0",a_0,a1_0)
                                subs_items.append("arg0:s0")
                                c0_res="array"
                            if case1=="Symbol":
                                text+=mkText("array","s1",a_1,a1_1)
                                subs_items.append("arg1:s1")
                                c1_res="array"
                            text+=" sub=res.substitute(%s)\n"%("{"+",".join(subs_items)+"}")
                            res="sub"
                            text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)
                        else:
                            res="res"
                            text+=mkText(case,"ref",r,r1)
                        if isinstance(r,float):
                            sh_res=()
                        else:
                            sh_res=r.shape
                        text+=mkTypeAndShapeTest(case,sh_res,"res")
                        text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
                        if "taggedData" in (case0,case1):
                            t_prog_with_tags+=text
                        else:
                            t_prog+=text
print(test_header)
# print(t_prog)
print(t_prog_with_tags)
print(test_tail)
1/0
1212    
# print t_prog_with_tags
# Older generator for the overloaded binary operators (+,-,*,/,**). The value
# range for each operator is taken from oper[2]. Unlike the "overloaded"
# section, mkText is called here without the tagging flag.
for oper in [["add" ,"+",[-5.,5.]],
             ["sub" ,"-",[-5.,5.]],
             ["mult","*",[-5.,5.]],
             ["div" ,"/",[-5.,5.]],
             ["pow" ,"**",[0.01,5.]]]:
    for case0 in case_set:
        for sh0 in shape_set:
            for case1 in case_set:
                for sh1 in shape_set:
                    if (not case0=="float" or len(sh0)==0) and (not case1=="float" or len(sh1)==0) and \
                       (sh0==() or sh1==() or sh1==sh0) and \
                       not (case0 in ["float","array"] and case1 in ["float","array"]):
                        text=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
                        tname="test_%s_%s_rank%s_%s_rank%s"%(oper[0],case0,len(sh0),case1,len(sh1))
                        text+=" def %s(self):\n"%tname
                        a_0=makeArray(sh0,oper[2])
                        if case0 in ["taggedData", "expandedData"]:
                            a1_0=makeArray(sh0,oper[2])
                        else:
                            a1_0=a_0

                        a_1=makeArray(sh1,oper[2])
                        if case1 in ["taggedData", "expandedData"]:
                            a1_1=makeArray(sh1,oper[2])
                        else:
                            a1_1=a_1
                        r1=makeResult2(a1_0,a1_1,"%a1%"+oper[1]+"%a2%")
                        r=makeResult2(a_0,a_1,"%a1%"+oper[1]+"%a2%")
                        text+=mkText(case0,"arg0",a_0,a1_0)
                        text+=mkText(case1,"arg1",a_1,a1_1)
                        text+=" res=arg0%sarg1\n"%oper[1]

                        case=getResultCaseForBin(case0,case1)
                        if case=="Symbol":
                            # substitute a concrete array for every symbolic argument
                            c0_res,c1_res=case0,case1
                            subs="{"
                            if case0=="Symbol":
                                text+=mkText("array","s0",a_0,a1_0)
                                subs+="arg0:s0"
                                c0_res="array"
                            if case1=="Symbol":
                                text+=mkText("array","s1",a_1,a1_1)
                                if not subs.endswith("{"): subs+=","
                                subs+="arg1:s1"
                                c1_res="array"
                            subs+="}"
                            text+=" sub=res.substitute(%s)\n"%subs
                            res="sub"
                            text+=mkText(getResultCaseForBin(c0_res,c1_res),"ref",r,r1)
                        else:
                            res="res"
                            text+=mkText(case,"ref",r,r1)
                        if isinstance(r,float):
                            text+=mkTypeAndShapeTest(case,(),"res")
                        else:
                            text+=mkTypeAndShapeTest(case,r.shape,"res")
                        text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res

                        if case0 in [ "constData","taggedData","expandedData"] and case1 == "Symbol":
                            t_prog_failing+=text
                        else:
                            if case0 == "taggedData" or case1 == "taggedData":
                                t_prog_with_tags+=text
                            else:
                                t_prog+=text


# print(u_prog)
# 1/0
print(test_header)
print(t_prog)
# print(t_prog_with_tags)
# print(t_prog_failing)
# The tail creates the suite and runs it, so it must be emitted only once;
# it used to be printed twice, which made the generated script run every
# test twice.
print(test_tail)
1290 jgs 154
#=======================================================================================================
# unary operations:
#=======================================================================================================
# Table of the element-wise unary functions to generate. %a1% is the
# placeholder for the argument. For each OPERATOR entry:
#   rng           -- interval handed to makeArray when drawing test arguments
#   test_expr     -- expression handed to makeResult to compute reference values
#   math_expr     -- code for float/int arguments (OPERATOR defaults it to test_expr)
#   numarray_expr -- code for numarray.NumArray arguments
#   symbol_expr   -- code for Symbol arguments; if None a <Name>_Symbol class is generated
#   diff          -- derivative with respect to the function argument (%a1% is replaced
#                    by the argument in the generated diff method)
#   name          -- human-readable name used in the generated docstrings
# NOTE(review): the whereZero/whereNonZero expressions refer to `tol`, but the
# generated wrapper signature is `def <name>(arg)` -- confirm where tol comes from.
func= [
    OPERATOR(nickname="log10",
             rng=[1.e-3,100.],
             test_expr="math.log10(%a1%)",
             math_expr="math.log10(%a1%)",
             numarray_expr="numarray.log10(%a1%)",
             symbol_expr="log(%a1%)/log(10.)",
             name="base-10 logarithm"),
    OPERATOR(nickname="wherePositive",
             rng=[-100.,100.],
             test_expr="wherepos(%a1%)",
             math_expr="if arg>0:\n return 1.\nelse:\n return 0.",
             numarray_expr="numarray.greater(arg,numarray.zeros(arg.shape,numarray.Float))",
             name="mask of positive values"),
    OPERATOR(nickname="whereNegative",
             rng=[-100.,100.],
             test_expr="wherepos(-%a1%)",
             math_expr="if arg<0:\n return 1.\nelse:\n return 0.",
             numarray_expr="numarray.less(arg,numarray.zeros(arg.shape,numarray.Float))",
             name="mask of negative values"),
    OPERATOR(nickname="whereNonNegative",
             rng=[-100.,100.],
             test_expr="1-wherepos(-%a1%)",
             math_expr="if arg<0:\n return 0.\nelse:\n return 1.",
             numarray_expr="numarray.greater_equal(arg,numarray.zeros(arg.shape,numarray.Float))",
             # [x>=0] == 1-[x<0]; 1-wherePositive would give [x<=0]
             symbol_expr="1-whereNegative(%a1%)",
             name="mask of non-negative values"),
    OPERATOR(nickname="whereNonPositive",
             rng=[-100.,100.],
             test_expr="1-wherepos(%a1%)",
             math_expr="if arg>0:\n return 0.\nelse:\n return 1.",
             numarray_expr="numarray.less_equal(arg,numarray.zeros(arg.shape,numarray.Float))",
             # [x<=0] == 1-[x>0]; 1-whereNegative would give [x>=0]
             symbol_expr="1-wherePositive(%a1%)",
             name="mask of non-positive values"),
    OPERATOR(nickname="whereZero",
             rng=[-100.,100.],
             test_expr="1-wherepos(%a1%)-wherepos(-%a1%)",
             math_expr="if abs(%a1%)<=tol:\n return 1.\nelse:\n return 0.",
             numarray_expr="numarray.less_equal(abs(%a1%)-tol,numarray.zeros(arg.shape,numarray.Float))",
             name="mask of zero entries"),
    OPERATOR(nickname="whereNonZero",
             rng=[-100.,100.],
             test_expr="wherepos(%a1%)+wherepos(-%a1%)",
             math_expr="if abs(%a1%)>tol:\n return 1.\nelse:\n return 0.",
             numarray_expr="numarray.greater(abs(%a1%)-tol,numarray.zeros(arg.shape,numarray.Float))",
             symbol_expr="1-whereZero(arg,tol)",
             name="mask of values different from zero"),
    OPERATOR(nickname="sin",
             rng=[-100.,100.],
             test_expr="math.sin(%a1%)",
             numarray_expr="numarray.sin(%a1%)",
             diff="cos(%a1%)",
             name="sine"),
    OPERATOR(nickname="cos",
             rng=[-100.,100.],
             test_expr="math.cos(%a1%)",
             numarray_expr="numarray.cos(%a1%)",
             diff="-sin(%a1%)",
             name="cosine"),
    OPERATOR(nickname="tan",
             rng=[-100.,100.],
             test_expr="math.tan(%a1%)",
             numarray_expr="numarray.tan(%a1%)",
             diff="1./cos(%a1%)**2",
             name="tangent"),
    OPERATOR(nickname="asin",
             rng=[-0.99,0.99],
             test_expr="math.asin(%a1%)",
             numarray_expr="numarray.arcsin(%a1%)",
             diff="1./sqrt(1.-%a1%**2)",
             name="inverse sine"),
    OPERATOR(nickname="acos",
             rng=[-0.99,0.99],
             test_expr="math.acos(%a1%)",
             numarray_expr="numarray.arccos(%a1%)",
             diff="-1./sqrt(1.-%a1%**2)",
             name="inverse cosine"),
    OPERATOR(nickname="atan",
             rng=[-100.,100.],
             test_expr="math.atan(%a1%)",
             numarray_expr="numarray.arctan(%a1%)",
             diff="1./(1+%a1%**2)",
             name="inverse tangent"),
    OPERATOR(nickname="sinh",
             rng=[-5,5],
             test_expr="math.sinh(%a1%)",
             numarray_expr="numarray.sinh(%a1%)",
             diff="cosh(%a1%)",
             name="hyperbolic sine"),
    OPERATOR(nickname="cosh",
             rng=[-5.,5.],
             test_expr="math.cosh(%a1%)",
             numarray_expr="numarray.cosh(%a1%)",
             diff="sinh(%a1%)",
             name="hyperbolic cosine"),
    OPERATOR(nickname="tanh",
             rng=[-5.,5.],
             test_expr="math.tanh(%a1%)",
             numarray_expr="numarray.tanh(%a1%)",
             diff="1./cosh(%a1%)**2",
             name="hyperbolic tangent"),
    OPERATOR(nickname="asinh",
             rng=[-100.,100.],
             test_expr="numarray.arcsinh(%a1%)",
             math_expr="numarray.arcsinh(%a1%)",
             numarray_expr="numarray.arcsinh(%a1%)",
             diff="1./sqrt(%a1%**2+1)",
             name="inverse hyperbolic sine"),
    OPERATOR(nickname="acosh",
             rng=[1.001,100.],
             test_expr="numarray.arccosh(%a1%)",
             math_expr="numarray.arccosh(%a1%)",
             numarray_expr="numarray.arccosh(%a1%)",
             diff="1./sqrt(%a1%**2-1)",
             name="inverse hyperbolic cosine"),
    OPERATOR(nickname="atanh",
             rng=[-0.99,0.99],
             test_expr="numarray.arctanh(%a1%)",
             math_expr="numarray.arctanh(%a1%)",
             numarray_expr="numarray.arctanh(%a1%)",
             diff="1./(1.-%a1%**2)",
             name="inverse hyperbolic tangent"),
    OPERATOR(nickname="exp",
             rng=[-5.,5.],
             test_expr="math.exp(%a1%)",
             numarray_expr="numarray.exp(%a1%)",
             diff="self",
             name="exponential"),
    OPERATOR(nickname="sqrt",
             rng=[1.e-3,100.],
             test_expr="math.sqrt(%a1%)",
             numarray_expr="numarray.sqrt(%a1%)",
             diff="0.5/self",
             name="square root"),
    OPERATOR(nickname="log",
             rng=[1.e-3,100.],
             test_expr="math.log(%a1%)",
             numarray_expr="numarray.log(%a1%)",
             # d/dx log(x) = 1/x, where x is the function argument. The
             # generated diff(self,arg) names its differentiation variable
             # `arg`, so the literal "1./arg" differentiated the wrong quantity.
             diff="1./%a1%",
             name="natural logarithm"),
    OPERATOR(nickname="sign",
             rng=[-100.,100.],
             math_expr="if %a1%>0:\n return 1.\nelif %a1%<0:\n return -1.\nelse:\n return 0.",
             test_expr="wherepos(%a1%)-wherepos(-%a1%)",
             numarray_expr="numarray.sign(%a1%)",
             symbol_expr="wherePositive(%a1%)-whereNegative(%a1%)",
             name="sign"),
    OPERATOR(nickname="abs",
             rng=[-100.,100.],
             math_expr="if %a1%>0:\n return %a1% \nelif %a1%<0:\n return -(%a1%)\nelse:\n return 0.",
             test_expr="wherepos(%a1%)*(%a1%)-wherepos(-%a1%)*(%a1%)",
             numarray_expr="abs(%a1%)",
             diff="sign(%a1%)",
             name="absolute value")

       ]
1450     for f in func:
1451     symbol_name=f.nickname[0].upper()+f.nickname[1:]
1452     if f.nickname!="abs":
1453     u_prog+="def %s(arg):\n"%f.nickname
1454     u_prog+=" \"\"\"\n"
1455     u_prog+=" returns %s of argument arg\n\n"%f.name
1456     u_prog+=" @param arg: argument\n"
1457     u_prog+=" @type arg: C{float}, L{escript.Data}, L{Symbol}, L{numarray.NumArray}.\n"
1458     u_prog+=" @rtype:C{float}, L{escript.Data}, L{Symbol}, L{numarray.NumArray} depending on the type of arg.\n"
1459     u_prog+=" @raises TypeError: if the type of the argument is not expected.\n"
1460     u_prog+=" \"\"\"\n"
1461     u_prog+=" if isinstance(arg,numarray.NumArray):\n"
1462     u_prog+=mkCode(f.numarray_expr,["arg"],2*" ")
1463     u_prog+=" elif isinstance(arg,escript.Data):\n"
1464     u_prog+=mkCode("arg._%s()"%f.nickname,[],2*" ")
1465     u_prog+=" elif isinstance(arg,float):\n"
1466     u_prog+=mkCode(f.math_expr,["arg"],2*" ")
1467     u_prog+=" elif isinstance(arg,int):\n"
1468     u_prog+=mkCode(f.math_expr,["float(arg)"],2*" ")
1469     u_prog+=" elif isinstance(arg,Symbol):\n"
1470     if f.symbol_expr==None:
1471     u_prog+=mkCode("%s_Symbol(arg)"%symbol_name,[],2*" ")
1472     else:
1473     u_prog+=mkCode(f.symbol_expr,["arg"],2*" ")
1474     u_prog+=" else:\n"
1475     u_prog+=" raise TypeError,\"%s: Unknown argument type.\"\n\n"%f.nickname
1476     if f.symbol_expr==None:
1477     u_prog+="class %s_Symbol(DependendSymbol):\n"%symbol_name
1478     u_prog+=" \"\"\"\n"
1479     u_prog+=" L{Symbol} representing the result of the %s function\n"%f.name
1480     u_prog+=" \"\"\"\n"
1481     u_prog+=" def __init__(self,arg):\n"
1482     u_prog+=" \"\"\"\n"
1483     u_prog+=" initialization of %s L{Symbol} with argument arg\n"%f.nickname
1484     u_prog+=" @param arg: argument of function\n"
1485     u_prog+=" @type arg: typically L{Symbol}.\n"
1486     u_prog+=" \"\"\"\n"
1487     u_prog+=" DependendSymbol.__init__(self,args=[arg],shape=arg.getShape(),dim=arg.getDim())\n"
1488     u_prog+="\n"
1489    
1490     u_prog+=" def getMyCode(self,argstrs,format=\"escript\"):\n"
1491     u_prog+=" \"\"\"\n"
1492     u_prog+=" returns a program code that can be used to evaluate the symbol.\n\n"
1493 jgs 154
1494 gross 157 u_prog+=" @param argstrs: gives for each argument a string representing the argument for the evaluation.\n"
1495     u_prog+=" @type argstrs: C{str} or a C{list} of length 1 of C{str}.\n"
1496     u_prog+=" @param format: specifies the format to be used. At the moment only \"escript\" ,\"text\" and \"str\" are supported.\n"
1497     u_prog+=" @type format: C{str}\n"
1498     u_prog+=" @return: a piece of program code which can be used to evaluate the expression assuming the values for the arguments are available.\n"
1499     u_prog+=" @rtype: C{str}\n"
1500     u_prog+=" @raise: NotImplementedError: if the requested format is not available\n"
1501     u_prog+=" \"\"\"\n"
1502     u_prog+=" if isinstance(argstrs,list):\n"
1503     u_prog+=" argstrs=argstrs[0]\n"
1504     u_prog+=" if format==\"escript\" or format==\"str\" or format==\"text\":\n"
1505     u_prog+=" return \"%s(%%s)\"%%argstrs\n"%f.nickname
1506     u_prog+=" else:\n"
1507     u_prog+=" raise NotImplementedError,\"%s_Symbol does not provide program code for format %%s.\"%%format\n"%symbol_name
1508     u_prog+="\n"
1509 jgs 154
1510 gross 157 u_prog+=" def substitute(self,argvals):\n"
1511     u_prog+=" \"\"\"\n"
1512     u_prog+=" assigns new values to symbols in the definition of the symbol.\n"
1513     u_prog+=" The method replaces the L{Symbol} u by argvals[u] in the expression defining this object.\n"
1514     u_prog+="\n"
1515     u_prog+=" @param argvals: new values assigned to symbols\n"
1516     u_prog+=" @type argvals: C{dict} with keywords of type L{Symbol}.\n"
1517     u_prog+=" @return: result of the substitution process. Operations are executed as much as possible.\n"
1518     u_prog+=" @rtype: L{escript.Symbol}, C{float}, L{escript.Data}, L{numarray.NumArray} depending on the degree of substitution\n"
1519     u_prog+=" @raise TypeError: if a value for a L{Symbol} cannot be substituted.\n"
1520     u_prog+=" \"\"\"\n"
1521     u_prog+=" if argvals.has_key(self):\n"
1522     u_prog+=" arg=argvals[self]\n"
1523     u_prog+=" if self.isAppropriateValue(arg):\n"
1524     u_prog+=" return arg\n"
1525     u_prog+=" else:\n"
1526     u_prog+=" raise TypeError,\"%s: new value is not appropriate.\"%str(self)\n"
1527     u_prog+=" else:\n"
1528     u_prog+=" arg=self.getSubstitutedArguments(argvals)[0]\n"
1529     u_prog+=" return %s(arg)\n\n"%f.nickname
# Emit a `diff(self,arg)` method only for operators that carry a derivative
# expression (f.diff is not None). f.diff appears to be the operator's
# derivative written in terms of the placeholder %a1% (TODO confirm against the
# OPERATOR definitions). The placeholder is replaced by `myarg` below.
# The emitted method:
#   * returns identity(self.getShape()) when differentiating with respect to
#     the symbol itself;
#   * otherwise evaluates f.diff at the symbol's first argument (myarg),
#     shape-matches it with d(myarg)/d(arg) from getDifferentiatedArguments via
#     matchShape, and returns the product val[0]*val[1] (chain rule).
1530     if not f.diff==None:
1531     u_prog+=" def diff(self,arg):\n"
1532     u_prog+=" \"\"\"\n"
1533     u_prog+=" differential of this object\n"
1534     u_prog+="\n"
1535     u_prog+=" @param arg: the derivative is calculated with respect to arg\n"
1536     u_prog+=" @type arg: L{escript.Symbol}\n"
1537     u_prog+=" @return: derivative with respect to C{arg}\n"
1538     u_prog+=" @rtype: typically L{Symbol} but other types such as C{float}, L{escript.Data}, L{numarray.NumArray} are possible.\n"
1539     u_prog+=" \"\"\"\n"
1540     u_prog+=" if arg==self:\n"
1541     u_prog+=" return identity(self.getShape())\n"
1542     u_prog+=" else:\n"
1543     u_prog+=" myarg=self.getArgument()[0]\n"
# substitute the derivative expression at generation time (%a1% -> myarg)
1544     u_prog+=" val=matchShape(%s,self.getDifferentiatedArguments(arg)[0])\n"%f.diff.replace("%a1%","myarg")
1545     u_prog+=" return val[0]*val[1]\n\n"
1546 jgs 154
# Generate one unittest method "test_<nickname>_<case>_rank<r>" for operator f
# for every combination of argument kind (case_set) and shape (shape_set).
# The "float" case is generated only for rank 0 (a Python float has no shape).
# Every other case is generated for every shape.
1547 gross 157 for case in case_set:
1548     for sh in shape_set:
1549     if not case=="float" or len(sh)==0:
1550     text=""
1551     text+=" #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
1552     tname="def test_%s_%s_rank%s"%(f.nickname,case,len(sh))
1553     text+=" %s(self):\n"%tname
# Two input arrays of shape sh drawn with f.rng. makeArray presumably fills
# values within that range (TODO confirm). a1 is passed to mkText as well and is
# presumably the second (tagged) value for taggedData (TODO confirm).
# r / r1 are the expected results from f.test_expr via makeResult.
1554     a=makeArray(sh,f.rng)
1555     a1=makeArray(sh,f.rng)
1556     r1=makeResult(a1,f.test_expr)
1557     r=makeResult(a,f.test_expr)
1558    
1559     text+=mkText(case,"arg",a,a1)
1560     text+=" res=%s(arg)\n"%f.nickname
1561     if case=="Symbol":
# A Symbol result is symbolic, so the generated test substitutes a concrete
# array s for arg and compares the substitution result against an array ref.
1562     text+=mkText("array","s",a,a1)
1563     text+=" sub=res.substitute({arg:s})\n"
1564     text+=mkText("array","ref",r,r1)
# NOTE: this Python variable `res` holds the NAME of the generated variable
# to compare against ref ("sub" or "res"). It is used in the failUnless line below.
1565     res="sub"
1566     else:
1567     text+=mkText(case,"ref",r,r1)
1568     res="res"
# the type/shape check always targets the generated `res` (for Symbol: the unsubstituted symbol)
1569     text+=mkTypeAndShapeTest(case,sh,"res")
1570     text+=" self.failUnless(Lsup(%s-ref)<=self.RES_TOL*Lsup(ref),\"wrong result\")\n"%res
# taggedData tests are collected in a separate buffer from all other cases
1571     if case == "taggedData":
1572     t_prog_with_tags+=text
1573     else:
1574     t_prog+=text
1575    
1576     #=========== END OF GOOD CODE +++++++++++++++++++++++++++
1577 jgs 154
# NOTE(review): `1/0` raises ZeroDivisionError, so execution of this script
# stops here and no statement after it runs. Given the marker above, this looks
# like a deliberate halt that keeps the older generator code that follows from
# ever executing. An explicit `raise SystemExit("...")` with a message would state
# the intent more clearly. TODO confirm before changing.
1578 gross 157 1/0
1579 jgs 154
1580 gross 157 def X():
1581     if args=="float":
1582     a=makeArray(sh,f[RANGE])
1583 jgs 154 r=makeResult(a,f)
1584 gross 157 t_prog+=" arg=%s\n"%a[0]
1585     t_prog+=" ref=%s\n"%r[0]
1586     t_prog+=" res=%s(%a1%)\n"%f.nickname
1587     t_prog+=" self.failUnless(isinstance(res,float),\"wrong type of result.\")\n"
1588     t_prog+=" self.failUnless(Lsup(res-ref)<=self.tol*Lsup(ref),\"wrong result\")\n"
1589     elif args == "array":
1590     a=makeArray(sh,f[RANGE])
1591     r=makeResult(a,f)
1592 jgs 154 if len(sh)==0:
1593 gross 157 t_prog+=" arg=numarray.array(%s)\n"%a[0]
1594     t_prog+=" ref=numarray.array(%s)\n"%r[0]
1595 jgs 154 else:
1596     t_prog+=" arg=numarray.array(%s)\n"%a.tolist()
1597     t_prog+=" ref=numarray.array(%s)\n"%r.tolist()
1598 gross 157 t_prog+=" res=%s(%a1%)\n"%f.nickname
1599     t_prog+=" self.failUnlessEqual(res.shape,%s,\"wrong shape of result.\")\n"%str(sh)
1600     t_prog+=" self.failUnless(Lsup(res-ref)<=self.tol*Lsup(ref),\"wrong result\")\n"
1601 jgs 154 elif args== "constData":
1602 gross 157 a=makeArray(sh,f[RANGE])
1603 jgs 154 r=makeResult(a,f)
1604     if len(sh)==0:
1605     t_prog+=" arg=Data(%s,self.functionspace)\n"%(a)
1606     t_prog+=" ref=%s\n"%r
1607     else:
1608     t_prog+=" arg=Data(numarray.array(%s),self.functionspace)\n"%(a.tolist())
1609     t_prog+=" ref=numarray.array(%s)\n"%r.tolist()
1610 gross 157 t_prog+=" res=%s(%a1%)\n"%f.nickname
1611 jgs 154 t_prog+=" self.failUnlessEqual(res.getShape(),%s,\"wrong shape of result.\")\n"%str(sh)
1612     t_prog+=" self.failUnless(Lsup(res-ref)<=self.tol*Lsup(ref),\"wrong result\")\n"
1613     elif args in [ "taggedData","expandedData"]:
1614 gross 157 a=makeArray(sh,f[RANGE])
1615 jgs 154 r=makeResult(a,f)
1616 gross 157 a1=makeArray(sh,f[RANGE])
1617 jgs 154 r1=makeResult(a1,f)
1618     if len(sh)==0:
1619     if args=="expandedData":
1620     t_prog+=" arg=Data(%s,self.functionspace,True)\n"%(a)
1621     t_prog+=" ref=Data(%s,self.functionspace,True)\n"%(r)
1622     else:
1623     t_prog+=" arg=Data(%s,self.functionspace)\n"%(a)
1624     t_prog+=" ref=Data(%s,self.functionspace)\n"%(r)
1625     t_prog+=" arg.setTaggedValue(1,%s)\n"%a
1626     t_prog+=" ref.setTaggedValue(1,%s)\n"%r1
1627     else:
1628     if args=="expandedData":
1629     t_prog+=" arg=Data(numarray.array(%s),self.functionspace,True)\n"%(a.tolist())
1630     t_prog+=" ref=Data(numarray.array(%s),self.functionspace,True)\n"%(r.tolist())
1631     else:
1632     t_prog+=" arg=Data(numarray.array(%s),self.functionspace)\n"%(a.tolist())
1633     t_prog+=" ref=Data(numarray.array(%s),self.functionspace)\n"%(r.tolist())
1634     t_prog+=" arg.setTaggedValue(1,%s)\n"%a1.tolist()
1635     t_prog+=" ref.setTaggedValue(1,%s)\n"%r1.tolist()
1636 gross 157 t_prog+=" res=%s(%a1%)\n"%f.nickname
1637 jgs 154 t_prog+=" self.failUnlessEqual(res.getShape(),%s,\"wrong shape of result.\")\n"%str(sh)
1638     t_prog+=" self.failUnless(Lsup(res-ref)<=self.tol*Lsup(ref),\"wrong result\")\n"
1639     elif args=="Symbol":
1640     t_prog+=" arg=Symbol(shape=%s)\n"%str(sh)
1641 gross 157 t_prog+=" v=%s(%a1%)\n"%f.nickname
1642 jgs 154 t_prog+=" self.failUnlessRaises(ValueError,v.substitute,Symbol(shape=(1,1)),\"illegal shape of substitute not identified.\")\n"
1643 gross 157 a=makeArray(sh,f[RANGE])
1644 jgs 154 r=makeResult(a,f)
1645     if len(sh)==0:
1646     t_prog+=" res=v.substitute({arg : %s})\n"%a
1647     t_prog+=" ref=%s\n"%r
1648     t_prog+=" self.failUnless(isinstance(res,float),\"wrong type of result.\")\n"
1649     else:
1650     t_prog+=" res=v.substitute({arg : numarray.array(%s)})\n"%a.tolist()
1651     t_prog+=" ref=numarray.array(%s)\n"%r.tolist()
1652     t_prog+=" self.failUnlessEqual(res.getShape(),%s,\"wrong shape of substitution result.\")\n"%str(sh)
1653     t_prog+=" self.failUnless(Lsup(res-ref)<=self.tol*Lsup(ref),\"wrong result\")\n"
1654    
1655     if len(sh)==0:
1656     t_prog+=" # test derivative with respect to itself:\n"
1657     t_prog+=" dvdv=v.diff(v)\n"
1658     t_prog+=" self.failUnlessEqual(dvdv,1.,\"derivative with respect to self is not 1.\")\n"
1659     elif len(sh)==1:
1660     t_prog+=" # test derivative with respect to itself:\n"
1661     t_prog+=" dvdv=v.diff(v)\n"
1662     t_prog+=" self.failUnlessEqual(dvdv.shape,%s,\"shape of derivative with respect is wrong\")\n"%str(sh+sh)
1663     for i0_l in range(sh[0]):
1664     for i0_r in range(sh[0]):
1665     if i0_l == i0_r:
1666     v=1.
1667     else:
1668     v=0.
1669     t_prog+=" self.failUnlessEqual(dvdv[%s,%s],%s,\"derivative with respect to self: [%s,%s] is not %s\")\n"%(i0_l,i0_r,v,i0_l,i0_r,v)
1670     elif len(sh)==2:
1671     t_prog+=" # test derivative with respect to itself:\n"
1672     t_prog+=" dvdv=v.diff(v)\n"
1673     t_prog+=" self.failUnlessEqual(dvdv.shape,%s,\"shape of derivative with respect is wrong\")\n"%str(sh+sh)
1674     for i0_l in range(sh[0]):
1675     for i0_r in range(sh[0]):
1676     for i1_l in range(sh[1]):
1677     for i1_r in range(sh[1]):
1678     if i0_l == i0_r and i1_l == i1_r:
1679     v=1.
1680     else:
1681     v=0.
1682     t_prog+=" self.failUnlessEqual(dvdv[%s,%s,%s,%s],%s,\"derivative with respect to self: [%s,%s,%s,%s] is not %s\")\n"%(i0_l,i1_l,i0_r,i1_r,v,i0_l,i1_l,i0_r,i1_r,v)
1683    
1684     for sh_in in [ (),(2,), (4,5), (6,2,2),(3,2,3,4)]:
1685     if len(sh_in)+len(sh)<=4:
1686    
1687     t_prog+=" # test derivative with shape %s as argument\n"%str(sh_in)
1688     trafo=makeArray(sh+sh_in,[0,1.])
1689 gross 157 a_in=makeArray(sh_in,f[RANGE])
1690 jgs 154 t_prog+=" arg_in=Symbol(shape=%s)\n"%str(sh_in)
1691     t_prog+=" arg2=Symbol(shape=%s)\n"%str(sh)
1692    
1693     if len(sh)==0:
1694     t_prog+=" arg2="
1695     if len(sh_in)==0:
1696     t_prog+="%s*arg_in\n"%trafo
1697     elif len(sh_in)==1:
1698     for i0 in range(sh_in[0]):
1699     if i0>0: t_prog+="+"
1700     t_prog+="%s*arg_in[%s]"%(trafo[i0],i0)
1701     t_prog+="\n"
1702     elif len(sh_in)==2:
1703     for i0 in range(sh_in[0]):
1704     for i1 in range(sh_in[1]):
1705     if i0+i1>0: t_prog+="+"
1706     t_prog+="%s*arg_in[%s,%s]"%(trafo[i0,i1],i0,i1)
1707    
1708     elif len(sh_in)==3:
1709     for i0 in range(sh_in[0]):
1710     for i1 in range(sh_in[1]):
1711     for i2 in range(sh_in[2]):
1712     if i0+i1+i2>0: t_prog+="+"
1713     t_prog+="%s*arg_in[%s,%s,%s]"%(trafo[i0,i1,i2],i0,i1,i2)
1714     elif len(sh_in)==4:
1715     for i0 in range(sh_in[0]):
1716     for i1 in range(sh_in[1]):
1717     for i2 in range(sh_in[2]):
1718     for i3 in range(sh_in[3]):
1719     if i0+i1+i2+i3>0: t_prog+="+"
1720     t_prog+="%s*arg_in[%s,%s,%s,%s]"%(trafo[i0,i1,i2,i3],i0,i1,i2,i3)
1721     t_prog+="\n"
1722     elif len(sh)==1:
1723     for j0 in range(sh[0]):
1724     t_prog+=" arg2[%s]="%j0
1725     if len(sh_in)==0:
1726     t_prog+="%s*arg_in"%trafo[j0]
1727     elif len(sh_in)==1:
1728     for i0 in range(sh_in[0]):
1729     if i0>0: t_prog+="+"
1730     t_prog+="%s*arg_in[%s]"%(trafo[j0,i0],i0)
1731     elif len(sh_in)==2:
1732     for i0 in range(sh_in[0]):
1733     for i1 in range(sh_in[1]):
1734     if i0+i1>0: t_prog+="+"
1735     t_prog+="%s*arg_in[%s,%s]"%(trafo[j0,i0,i1],i0,i1)
1736     elif len(sh_in)==3:
1737     for i0 in range(sh_in[0]):
1738     for i1 in range(sh_in[1]):
1739     for i2 in range(sh_in[2]):
1740     if i0+i1+i2>0: t_prog+="+"
1741     t_prog+="%s*arg_in[%s,%s,%s]"%(trafo[j0,i0,i1,i2],i0,i1,i2)
1742     t_prog+="\n"
1743     elif len(sh)==2:
1744     for j0 in range(sh[0]):
1745     for j1 in range(sh[1]):
1746     t_prog+=" arg2[%s,%s]="%(j0,j1)
1747     if len(sh_in)==0:
1748     t_prog+="%s*arg_in"%trafo[j0,j1]
1749     elif len(sh_in)==1:
1750     for i0 in range(sh_in[0]):
1751     if i0>0: t_prog+="+"
1752     t_prog+="%s*arg_in[%s]"%(trafo[j0,j1,i0],i0)
1753     elif len(sh_in)==2:
1754     for i0 in range(sh_in[0]):
1755     for i1 in range(sh_in[1]):
1756     if i0+i1>0: t_prog+="+"
1757     t_prog+="%s*arg_in[%s,%s]"%(trafo[j0,j1,i0,i1],i0,i1)
1758     t_prog+="\n"
1759     elif len(sh)==3:
1760     for j0 in range(sh[0]):
1761     for j1 in range(sh[1]):
1762     for j2 in range(sh[2]):
1763     t_prog+=" arg2[%s,%s,%s]="%(j0,j1,j2)
1764     if len(sh_in)==0:
1765     t_prog+="%s*arg_in"%trafo[j0,j1,j2]
1766     elif len(sh_in)==1:
1767     for i0 in range(sh_in[0]):
1768     if i0>0: t_prog+="+"
1769     t_prog+="%s*arg_in[%s]"%(trafo[j0,j1,j2,i0],i0)
1770     t_prog+="\n"
1771     elif len(sh)==4:
1772     for j0 in range(sh[0]):
1773     for j1 in range(sh[1]):
1774     for j2 in range(sh[2]):
1775     for j3 in range(sh[3]):
1776     t_prog+=" arg2[%s,%s,%s,%s]="%(j0,j1,j2,j3)
1777     if len(sh_in)==0:
1778     t_prog+="%s*arg_in"%trafo[j0,j1,j2,j3]
1779     t_prog+="\n"
1780     t_prog+=" dvdin=v.substitute({arg : arg2}).diff(arg_in)\n"
1781     if len(sh_in)==0:
1782     t_prog+=" res_in=dvdin.substitute({arg_in : %s})\n"%a_in
1783     else:
1784     t_prog+=" res_in=dvdin.substitute({arg : numarray.array(%s)})\n"%a_in.tolist()
1785    
1786     if len(sh)==0:
1787     if len(sh_in)==0:
1788     ref_diff=(makeResult(trafo*a_in+finc,f)-makeResult(trafo*a_in,f))/finc
1789     t_prog+=" self.failUnlessAlmostEqual(dvdin,%s,self.places,\"%s-derivative: wrong derivative\")\n"%(ref_diff,str(sh_in))
1790     elif len(sh_in)==1:
1791     s=0
1792     for k0 in range(sh_in[0]):
1793     s+=trafo[k0]*a_in[k0]
1794     for i0 in range(sh_in[0]):
1795     ref_diff=(makeResult(s+trafo[i0]*finc,f)-makeResult(s,f))/finc
1796     t_prog+=" self.failUnlessAlmostEqual(dvdin[%s],%s,self.places,\"%s-derivative: wrong derivative with respect of %s\")\n"%(i0,ref_diff,str(sh_in),str(i0))
1797     elif len(sh_in)==2:
1798     s=0
1799     for k0 in range(sh_in[0]):
1800     for k1 in range(sh_in[1]):
1801     s+=trafo[k0,k1]*a_in[k0,k1]
1802     for i0 in range(sh_in[0]):
1803     for i1 in range(sh_in[1]):
1804     ref_diff=(makeResult(s+trafo[i0,i1]*finc,f)-makeResult(s,f))/finc
1805     t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s],%s,self.places,\"%s-derivative: wrong derivative with respect of %s\")\n"%(i0,i1,ref_diff,str(sh_in),str((i0,i1)))
1806    
1807     elif len(sh_in)==3:
1808     s=0
1809     for k0 in range(sh_in[0]):
1810     for k1 in range(sh_in[1]):
1811     for k2 in range(sh_in[2]):
1812     s+=trafo[k0,k1,k2]*a_in[k0,k1,k2]
1813     for i0 in range(sh_in[0]):
1814     for i1 in range(sh_in[1]):
1815     for i2 in range(sh_in[2]):
1816     ref_diff=(makeResult(s+trafo[i0,i1,i2]*finc,f)-makeResult(s,f))/finc
1817     t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative with respect of %s\")\n"%(i0,i1,i2,ref_diff,str(sh_in),str((i0,i1,i2)))
1818     elif len(sh_in)==4:
1819     s=0
1820     for k0 in range(sh_in[0]):
1821     for k1 in range(sh_in[1]):
1822     for k2 in range(sh_in[2]):
1823     for k3 in range(sh_in[3]):
1824     s+=trafo[k0,k1,k2,k3]*a_in[k0,k1,k2,k3]
1825     for i0 in range(sh_in[0]):
1826     for i1 in range(sh_in[1]):
1827     for i2 in range(sh_in[2]):
1828     for i3 in range(sh_in[3]):
1829     ref_diff=(makeResult(s+trafo[i0,i1,i2,i3]*finc,f)-makeResult(s,f))/finc
1830     t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative with respect of %s\")\n"%(i0,i1,i2,i3,ref_diff,str(sh_in),str((i0,i1,i2,i3)))
1831     elif len(sh)==1:
1832     for j0 in range(sh[0]):
1833     if len(sh_in)==0:
1834     ref_diff=(makeResult(trafo[j0]*a_in+finc,f)-makeResult(trafo[j0]*a_in,f))/finc
1835     t_prog+=" self.failUnlessAlmostEqual(dvdin[%s],%s,self.places,\"%s-derivative: wrong derivative of %s\")\n"%(j0,ref_diff,str(sh_in),j0)
1836     elif len(sh_in)==1:
1837     s=0
1838     for k0 in range(sh_in[0]):
1839     s+=trafo[j0,k0]*a_in[k0]
1840     for i0 in range(sh_in[0]):
1841     ref_diff=(makeResult(s+trafo[j0,i0]*finc,f)-makeResult(s,f))/finc
1842     t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s with respect of %s\")\n"%(j0,i0,ref_diff,str(sh_in),str(j0),str(i0))
1843     elif len(sh_in)==2:
1844     s=0
1845     for k0 in range(sh_in[0]):
1846     for k1 in range(sh_in[1]):
1847     s+=trafo[j0,k0,k1]*a_in[k0,k1]
1848     for i0 in range(sh_in[0]):
1849     for i1 in range(sh_in[1]):
1850     ref_diff=(makeResult(s+trafo[j0,i0,i1]*finc,f)-makeResult(s,f))/finc
1851     t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s with respect of %s\")\n"%(j0,i0,i1,ref_diff,str(sh_in),str(j0),str((i0,i1)))
1852    
1853     elif len(sh_in)==3:
1854    
1855     s=0
1856     for k0 in range(sh_in[0]):
1857     for k1 in range(sh_in[1]):
1858     for k2 in range(sh_in[2]):
1859     s+=trafo[j0,k0,k1,k2]*a_in[k0,k1,k2]
1860    
1861     for i0 in range(sh_in[0]):
1862     for i1 in range(sh_in[1]):
1863     for i2 in range(sh_in[2]):
1864     ref_diff=(makeResult(s+trafo[j0,i0,i1,i2]*finc,f)-makeResult(s,f))/finc
1865     t_prog+=" self.failUnlessAlmostEqual(dvdin[%s,%s,%s,%s],%s,self.places,\"%s-derivative: wrong derivative of %s with respect of %s\")\n"%(j0,i0,i1,i2,ref_diff,str(sh_in),str(j0),str((i0,i1,i2)))
1866     elif len(sh)==2:
1867