1 |
#!/usr/bin/python |
2 |
|
3 |
# $Id$ |
4 |
__copyright__=""" Copyright (c) 2006 by ACcESS MNRF |
5 |
http://www.access.edu.au |
6 |
Primary Business: Queensland, Australia""" |
7 |
__license__="""Licensed under the Open Software License version 3.0 |
8 |
http://www.opensource.org/licenses/osl-3.0.php""" |
9 |
|
10 |
import unittest |
11 |
from esys.escript.modelframe import Model,Link,Simulation,ParameterSet,ESySXMLParser,DataSource |
12 |
import math |
13 |
from cStringIO import StringIO |
14 |
from xml.dom import minidom |
15 |
import numarray |
16 |
|
17 |
class XMLDocumentTestCase(unittest.TestCase):
    """Checks the line layout of the XML written by ``Simulation.writeXML``.

    setUp builds two ODETEST models coupled through Links plus a Messenger,
    runs the simulation, serialises it and stores the output as a list of
    lines (trailing newlines included) in ``self.xmlList``.
    """

    def setUp(self):
        o1 = ODETEST(debug=False)
        o1.u = 10
        o2 = ODETEST(debug=False)
        o2.u = -10.
        # Each model's forcing term f is linked to the other model's u.
        o1.f = Link(o2, "u")
        o2.f = Link(o1, "u")
        m = Messenger()
        o1.dt = 0.01
        m.message = Link(o1)
        s = Simulation([o1, o2, m], debug=False)
        s.run()
        output = StringIO()
        s.writeXML(output)
        output.reset()
        self.xmlList = output.readlines()

    def _indexOf(self, text):
        """Return the index of the first output line containing ``text``.

        Returns None when no line contains it.
        """
        for index, line in enumerate(self.xmlList):
            if text in line:
                return index
        return None

    def testFirstLine(self):
        self.assertEqual('<?xml version="1.0" ?>\n', self.xmlList[0])

    def testEsysHeader(self):
        self.assertEqual('<ESys>\n', self.xmlList[1])

    def testEsysFooter(self):
        self.assertEqual('</ESys>\n', self.xmlList[-1])

    def testSimulationHeader(self):
        # Was an empty placeholder that always passed.
        assert self._indexOf('<Simulation') is not None, \
            "no <Simulation> start tag in the XML output"

    def testSimulationFooter(self):
        # Was an empty placeholder that always passed.
        start = self._indexOf('<Simulation')
        end = self._indexOf('</Simulation>')
        assert end is not None, "no </Simulation> end tag in the XML output"
        assert start is not None and start <= end, \
            "</Simulation> appears before the <Simulation> start tag"
55 |
|
56 |
class SimulationTestCase(unittest.TestCase):
    """Round-trips a small coupled Simulation through writeXML and the parser.

    The serialised XML of the executed simulation is kept in ``self.xml``.
    """

    def setUp(self):
        first = ODETEST(debug=False)
        first.u = 10
        second = ODETEST(debug=False)
        second.u = -10.
        first.f = Link(second, "u")
        second.f = Link(first, "u")
        messenger = Messenger()
        first.dt = 0.01
        messenger.message = Link(first)
        self.s = Simulation([first, second, messenger], debug=False)
        self.s.run()
        self.xml = self._serialise(self.s)

    def _serialise(self, obj):
        """Return the XML text produced by ``obj.writeXML``."""
        buf = StringIO()
        obj.writeXML(buf)
        buf.reset()
        return buf.read()

    def testSimulation(self):
        assert "<Simulation" in self.xml, "I should see a Simulation"

    def testParseAndInstanceOfSimulation(self):
        parsed = ESySXMLParser(self.xml).parse()
        newSim = parsed[0]
        assert isinstance(newSim, Simulation)
        rewritten = self._serialise(newSim)
        assert '<Simulation' in rewritten, "Missing a Simulation! It should be in this!"
87 |
|
88 |
|
89 |
|
90 |
|
91 |
class LinkTestCase(unittest.TestCase):
    """Exercises Link objects connecting two ODETEST models.

    o1 and o2 are cross-linked through their ``u`` parameters, and o1 is
    additionally declared as the ``child`` parameter of o2.
    """

    def setUp(self):
        self.o1 = ODETEST(debug=False)
        self.o2 = ODETEST(debug=False)
        self.o2.u = -10.
        self.o1.f = Link(self.o2, "u")
        self.o2.f = Link(self.o1, "u")
        self.o2.declareParameter(child=self.o1)

    def testLinkCreation(self):
        self.o1.f = Link(self.o2, "u")
        assert self.o1.f

    def testLinkValue(self):
        # Reading the linked parameter is expected to give o2.u.
        self.assertEqual(self.o1.f, -10)

    def testLinkTarget(self):
        # TODO: placeholder, no checks yet.
        pass

    def testLinkDefaultAttribute(self):
        # Only checks that a Link can be built without an attribute name.
        Link(self.o2)

    def testLinkXML(self):
        out = StringIO()
        self.o2.writeXML(out)
        out.reset()
        assert '<Link' in out.read()

    def testLinkTargetXML(self):
        # TODO: placeholder, no checks yet.
        pass
127 |
|
128 |
class ParamaterSetTestCase(unittest.TestCase):
    """Tests ParameterSet declaration and its XML round trip.

    The (misspelt) class name is kept unchanged so existing references to
    ``ParamaterSetTestCase`` keep working.
    """

    def setUp(self):
        self.p = ParameterSet()
        self.p.declareParameter(gamma1=1., gamma2=2., gamma3=3.)

    # -- helpers -----------------------------------------------------------

    def _xml(self, obj):
        """Return the XML text written by ``obj.writeXML``."""
        s = StringIO()
        obj.writeXML(s)
        s.reset()
        return s.read()

    def _dom(self, obj):
        """Return a minidom Document parsed from ``obj``'s XML."""
        return minidom.parseString(self._xml(obj))

    def _class(self, obj):
        """Return the first object ESySXMLParser rebuilds from ``obj``'s XML."""
        return ESySXMLParser(self._xml(obj)).parse()[0]

    def _assertSameItems(self, expected, actual):
        """Assert ``actual`` has the length and the elements of ``expected``.

        Compares element by element so a failure reports the offending pair.
        """
        self.assertEqual(len(actual), len(expected))
        for want, got in zip(expected, actual):
            self.assertEqual(want, got)

    # -- tests -------------------------------------------------------------

    def testParameterSetCreation(self):
        self.assertEqual(self.p.gamma1, 1.)

    def testParameterSetXMLCreation(self):
        xmlout = self._xml(self.p)
        assert "gamma1" in xmlout
        assert "gamma2" in xmlout
        assert "gamma3" in xmlout
        parsable = ESySXMLParser(xmlout).parse()[0]
        assert isinstance(parsable, ParameterSet)
        assert self._dom(self.p).getElementsByTagName("ParameterSet")

    def testParameterSetFromXML(self):
        esysxml = ESySXMLParser(self._xml(self.p))
        doc = self._class(self.p)
        element = self._dom(self.p).getElementsByTagName("ParameterSet")[0]
        pset = ParameterSet.fromDom(esysxml, element)
        assert isinstance(pset, ParameterSet)
        assert isinstance(doc, ParameterSet)
        self.assertEqual(self.p.gamma1, doc.gamma1)

    def testParameterSetWithChildrenFromXML(self):
        p2 = ParameterSet()
        p2.declareParameter(s="abc", f=3.)
        self.p.declareParameter(child=p2)
        doc = self._class(self.p)
        self.assertEqual(self.p.child.f, doc.child.f)

    def testParameterSetChild(self):
        p2 = ParameterSet()
        p2.declareParameter(s="abc", f=3.)
        self.p.declareParameter(child=p2)
        self.assertEqual(self.p.child.s, "abc")
        self.assertEqual(self.p.child.f, 3.)

    def testFromDomInt(self):
        p3 = ParameterSet()
        p3.declareParameter(inttest=1)
        doc = self._class(p3)
        # Exact type comparison: the value must come back as a plain int.
        assert type(doc.inttest) == int

    def testFromDomNumarrayVector(self):
        p3 = ParameterSet()
        mynumarray = numarray.array([3., 4., 5.], type=numarray.Float64)
        p3.declareParameter(numtest=mynumarray)
        doc = self._class(p3)
        assert doc.numtest.type() == numarray.Float64
        assert type(doc.numtest) == numarray.NumArray

    def testFromDomNumarrayMulti(self):
        p3 = ParameterSet()
        mynumarray = numarray.array([[1., 2., 3.], [3., 4., 5.]], type=numarray.Float64)
        p3.declareParameter(numtest=mynumarray)
        doc = self._class(p3)
        assert doc.numtest.type() == numarray.Float64
        assert type(doc.numtest) == numarray.NumArray

    def testBoolLists(self):
        p4 = ParameterSet()
        mylist = [True, False, False, True]
        p4.declareParameter(listest=mylist)
        doc = self._class(p4)
        assert type(doc.listest) == numarray.NumArray
        assert doc.listest.type() == numarray.Bool
        self._assertSameItems(mylist, doc.listest)

    def testIntLists(self):
        p4 = ParameterSet()
        mylist = [1, 2, 4]
        p4.declareParameter(listest=mylist)
        doc = self._class(p4)
        assert type(doc.listest) == numarray.NumArray
        assert doc.listest.type() == numarray.Int
        self._assertSameItems(mylist, doc.listest)

    def testFloatLists(self):
        p4 = ParameterSet()
        mylist = [1., 2., 4.]
        p4.declareParameter(listest=mylist)
        doc = self._class(p4)
        assert type(doc.listest) == numarray.NumArray
        assert doc.listest.type() == numarray.Float
        self._assertSameItems(mylist, doc.listest)

    def testStringLists(self):
        p4 = ParameterSet()
        mylist = ["a", "b", "c"]
        p4.declareParameter(listest=mylist)
        doc = self._class(p4)
        # Strings are not packed into a numarray; they stay a plain list.
        assert type(doc.listest) == list
        self._assertSameItems(mylist, doc.listest)

    def testDatasource(self):
        p5 = ParameterSet()
        myURI = DataSource("somelocalfile.txt", "text")
        p5.declareParameter(uritest=myURI)
        doc = self._class(p5)
        self.assertEqual(myURI.uri, doc.uritest.uri)
        self.assertEqual(myURI.fileformat, doc.uritest.fileformat)
        assert type(doc.uritest) == DataSource
265 |
|
266 |
|
267 |
|
268 |
class ModelSerialisationTestCase(unittest.TestCase):
    """Serialises a single ODETEST model and inspects the result.

    This class used to share the name ``ModeltoDomTestCase`` with a later
    class in this module, which replaced it at import time so none of these
    tests ever ran; it has its own name now.
    """

    def setUp(self):
        self.o1 = ODETEST(debug=False)
        self.o1.message = 'blah'

    def _xml(self):
        """Serialise self.o1, store the text in self.xmlout and return it."""
        s = StringIO()
        self.o1.writeXML(s)
        s.reset()
        self.xmlout = s.read()
        return self.xmlout

    def _class(self):
        """Return the first modelframe object rebuilt from self.o1's XML."""
        # Previously referenced an undefined local ``xmlout`` (NameError).
        return ESySXMLParser(self._xml()).parse()[0]

    def _dom(self):
        """Return a minidom Document parsed from self.o1's XML."""
        return minidom.parseString(self._xml())

    def testModelExists(self):
        modeldoc = self._class()
        # isinstance must actually be called: asserting the tuple
        # "(isinstance, (modeldoc, Model))" is always true.
        assert isinstance(modeldoc, Model)
        assert self._dom().getElementsByTagName("Model")

    def testModelhasID(self):
        model = self._dom().getElementsByTagName("Model")[0]
        assert int(model.getAttribute("id")) > 99
300 |
|
301 |
class ModeltoDomTestCase(unittest.TestCase):
    """Checks how ESySXMLParser resolves the ``module`` attribute of a Model."""

    def _xml(self, modulename, modelname):
        """Return a fixed Simulation XML document.

        Only the first Model element (id 127) varies: its ``module`` and
        ``type`` attributes are set to ``modulename`` and ``modelname``.
        """
        return '''<?xml version="1.0" ?>
<ESys> <Simulation type="Simulation"> <Component rank="0">

<Model id="127" module="%s" type="%s">

<Parameter type="float"> <Name> a </Name> <Value> 0.9 </Value> </Parameter>
<Parameter type="Link"> <Name> f </Name> <Value> <Link> <Target> 128 </Target>
<Attribute> u </Attribute> </Link> </Value> </Parameter> <Parameter
type="float"> <Name> tend </Name> <Value>
1.0 </Value> </Parameter> <Parameter type="int"> <Name> u </Name> <Value> 10
</Value> </Parameter> <Parameter type="float"> <Name> tol </Name> <Value>
1e-08 </Value> </Parameter> <Parameter type="float"> <Name> dt </Name>
<Value>
0.01 </Value> </Parameter> <Parameter type="str"> <Name> message </Name>
<Value> current error = 9.516258e-01 </Value> </Parameter> </Model>
</Component> <Component rank="1"> <Model id="128" type="ODETEST"> <Parameter
type="float"> <Name> a </Name> <Value>
0.9 </Value> </Parameter> <Parameter type="Link"> <Name> f </Name> <Value>
<Link> <Target> 127 </Target> <Attribute> u </Attribute> </Link> </Value>
</Parameter> <Parameter type="float"> <Name> tend </Name> <Value>
1.0 </Value> </Parameter> <Parameter type="float"> <Name> u </Name> <Value>
-10.0 </Value> </Parameter> <Parameter type="float"> <Name> tol </Name>
<Value> 1e-08 </Value> </Parameter> <Parameter type="float"> <Name> dt
</Name> <Value>
0.1 </Value> </Parameter> <Parameter type="str"> <Name> message </Name> <Value>
current error = 1.904837e+01 </Value> </Parameter> </Model> </Component>
<Component rank="2"> <Model id="129" type="Messenger"> <Parameter
type="Link"> <Name> message </Name> <Value> <Link> <Target> 127 </Target>
<Attribute> message </Attribute> </Link> </Value> </Parameter> </Model>
</Component> </Simulation> <Model id="150" type="ODETEST"> <Parameter
type="float"> <Name> a </Name> <Value>
0.9 </Value> </Parameter> <Parameter type="Link"> <Name> f </Name> <Value>
<Link> <Target> 127 </Target> <Attribute> u </Attribute> </Link> </Value>
</Parameter> <Parameter type="float"> <Name> tend </Name> <Value>
1.0 </Value> </Parameter> <Parameter type="float"> <Name> u </Name> <Value>
-10.0 </Value> </Parameter> <Parameter type="float"> <Name> tol </Name>
<Value> 1e-08 </Value> </Parameter> <Parameter type="float"> <Name> dt
</Name> <Value>
0.1 </Value> </Parameter> <Parameter type="str"> <Name> message </Name> <Value>
current error = 1.904837e+01 </Value> </Parameter> </Model> <Model id="130"
type="ODETEST"> <Parameter type="float"> <Name> a </Name> <Value>
0.9 </Value> </Parameter> <Parameter type="Link"> <Name> f </Name> <Value>
<Link> <Target> 128 </Target> <Attribute> u </Attribute> </Link> </Value>
</Parameter> <Parameter type="float"> <Name> tend </Name> <Value>
1.0 </Value> </Parameter> <Parameter type="int"> <Name> u </Name> <Value> 10
</Value> </Parameter> <Parameter type="float"> <Name> tol </Name> <Value>
1e-08 </Value> </Parameter> <Parameter type="float"> <Name> dt </Name>
<Value>
0.01 </Value> </Parameter> <Parameter type="str"> <Name> message </Name>
<Value> current error = 9.516258e-01 </Value> </Parameter> </Model> <Model
id="170" type="ODETEST"> <Parameter type="float"> <Name> a </Name> <Value>
0.9 </Value> </Parameter> <Parameter type="Link"> <Name> f </Name> <Value>
<Link> <Target> 128 </Target> <Attribute> u </Attribute> </Link> </Value>
</Parameter> <Parameter type="float"> <Name> tend </Name> <Value>
1.0 </Value> </Parameter> <Parameter type="int"> <Name> u </Name> <Value> 10
</Value> </Parameter> <Parameter type="float"> <Name> tol </Name> <Value>
1e-08 </Value> </Parameter> <Parameter type="float"> <Name> dt </Name>
<Value>
0.01 </Value> </Parameter> <Parameter type="str"> <Name> message </Name>
<Value> current error = 9.516258e-01 </Value> </Parameter> </Model> </ESys>
''' % (modulename, modelname)

    def testModuleAttribute(self):
        esysxml = ESySXMLParser(self._xml('run_xml', 'ODETEST'))
        modeldoc = esysxml.parse()[0]
        # The document's first top-level element is a Simulation; previously
        # the parse result was never checked.
        assert isinstance(modeldoc, Simulation)

    def testModuleAttributeFails(self):
        # An unimportable module name must surface as ImportError, whether it
        # is raised while constructing the parser or while parsing.
        def parseUnknownModule():
            ESySXMLParser(self._xml('a', 'b')).parse()[0]
        self.assertRaises(ImportError, parseUnknownModule)
378 |
|
379 |
class Messenger(Model):
    """Minimal Model holding a single ``message`` parameter (default "none").

    It accumulates the step sizes it is given into a private elapsed-time
    counter; it produces no output.
    """

    def __init__(self, *args, **kwargs):
        Model.__init__(self, *args, **kwargs)
        self.declareParameter(message="none")

    def doInitialization(self):
        self.__t = 0

    def doStepPostprocessing(self, dt):
        self.__t = self.__t + dt

    def doFinalization(self):
        pass
395 |
|
396 |
|
397 |
|
398 |
class ODETEST(Model):
    """Implements a solver for the ODE

        du/dt = a*u + f(t)

    using an implicit Euler scheme:

        u_n - u_{n-1} = dt*a*u_n + dt*f(t_n)

    To get u_n we run the fixed-point iteration

        u_{n,k} = u_{n-1} + dt*(a*u_{n,k-1} + f(t_n))

    until two successive iterates agree to the relative tolerance ``tol``.

    Inputs are the step size ``dt``, the end time ``tend``, the coefficient
    ``a``, the forcing ``f``, the initial value ``u`` and the stopping
    tolerance ``tol``. After each step ``message`` reports the deviation of
    u from 10*exp((a-1)*t).
    """

    def __init__(self, *args, **kwargs):
        Model.__init__(self, *args, **kwargs)
        self.declareParameter(tend=1., dt=0.1, a=0.9, u=10., f=0.,
                              message="", tol=1.e-8)

    def doInitialization(self):
        # Elapsed time and the iteration counter of the current step.
        self.__tn = 0
        self.__iter = 0

    def doStepPreprocessing(self, dt):
        # New step: reset the counter and remember u_{n-1}.
        self.__iter = 0
        self.__u_last = self.u

    def doStep(self, dt):
        # One fixed-point iteration: u_{n,k} = u_{n-1} + dt*(a*u_{n,k-1} + f).
        self.__iter += 1
        self.__u_old = self.u
        self.u = self.__u_last + dt * (self.a * self.__u_old + self.f)

    def terminate(self):
        # Always perform at least one iteration per step.
        if self.__iter < 1:
            return False
        # "<=" rather than "<": with "<", an iterate that has converged to
        # exactly 0 (0 < tol*0 is False) would never be reported converged.
        return abs(self.__u_old - self.u) <= self.tol * abs(self.u)

    def doStepPostprocessing(self, dt):
        self.__tn += dt
        self.message = "current error = %e" % abs(self.u - 10. * math.exp((self.a - 1.) * self.__tn))

    def getSafeTimeStepSize(self, dt):
        # The dt argument is not used: the step is bounded by the dt
        # parameter and by 1/(|a|+1).
        return min(self.dt, 1. / (abs(self.a) + 1.))

    def finalize(self):
        # Done once the elapsed time has reached tend.
        return self.__tn >= self.tend
450 |
|
451 |
|
452 |
|
453 |
if __name__ == '__main__':
    # Run every TestCase in this module when executed as a script.
    unittest.main()
455 |
|