1 |
#!/usr/bin/python |
2 |
|
3 |
# $Id$ |
4 |
__copyright__=""" Copyright (c) 2006 by ACcESS MNRF |
5 |
http://www.access.edu.au |
6 |
Primary Business: Queensland, Australia""" |
7 |
__license__="""Licensed under the Open Software License version 3.0 |
8 |
http://www.opensource.org/licenses/osl-3.0.php""" |
9 |
|
10 |
import unittest |
11 |
from esys.escript.modelframe import Model,Link,Simulation,ParameterSet,parse |
12 |
import math |
13 |
from cStringIO import StringIO |
14 |
from xml.dom import minidom |
15 |
|
16 |
class XMLDocumentTestCase(unittest.TestCase):
    """Check the framing lines of the XML document written by
    Simulation.writeXML for a small two-model simulation."""

    def setUp(self):
        # Two ODETEST models whose forcing terms are linked to each other's
        # solution, plus a Messenger whose message is linked to the first.
        first = ODETEST(debug=False)
        first.u = 10
        second = ODETEST(debug=False)
        second.u = -10.
        first.f = Link(second, "u")
        second.f = Link(first, "u")
        messenger = Messenger()
        first.dt = 0.01
        messenger.message = Link(first)
        sim = Simulation([first, second, messenger], debug=False)
        sim.run()
        # Serialise, rewind, and keep the document as a list of lines
        # (each line keeps its trailing newline).
        buf = StringIO()
        sim.writeXML(buf)
        buf.seek(0)
        self.xmlList = buf.readlines()

    def testFirstLine(self):
        self.assertEqual('<?xml version="1.0" ?>\n', self.xmlList[0])

    def testEsysHeader(self):
        self.assertEqual('<ESys>\n', self.xmlList[1])

    def testEsysFooter(self):
        self.assertEqual('</ESys>\n', self.xmlList[-1])

    def testSimulationHeader(self):
        # TODO: not implemented yet; passes trivially.
        pass

    def testSimulationFooter(self):
        # TODO: not implemented yet; passes trivially.
        pass
54 |
|
55 |
class SimulationTestCase(unittest.TestCase):
    """Serialise a small coupled Simulation to XML and parse it back.

    Uses unittest assertion methods rather than bare ``assert`` so the
    checks still run when Python is started with ``-O``.
    """

    def setUp(self):
        o1 = ODETEST(debug=False)
        o1.u = 10
        o2 = ODETEST(debug=False)
        o2.u = -10.
        # Couple the models: each one's forcing term f reads the other's u.
        o1.f = Link(o2, "u")
        o2.f = Link(o1, "u")
        m = Messenger()
        o1.dt = 0.01
        # No attribute is named here; presumably the link resolves to
        # o1's attribute of the same name ("message") -- confirm in Link.
        m.message = Link(o1)
        self.s = Simulation([o1, o2, m], debug=False)
        self.s.run()
        output = StringIO()
        self.s.writeXML(output)
        self.xml = output.getvalue()

    def testSimulation(self):
        self.assertTrue("<Simulation" in self.xml, "I should see a Simulation")

    def testParseAndInstanceOfSimulation(self):
        # Parsing the written XML must give back a Simulation which can
        # itself be serialised again.
        newSim = parse(self.xml)
        self.assertTrue(isinstance(newSim, Simulation))
        newout = StringIO()
        newSim.writeXML(newout)
        xml = newout.getvalue()
        self.assertTrue('<Simulation' in xml,
                        "Missing a Simulation! It should be in this!")
85 |
|
86 |
|
87 |
|
88 |
|
89 |
class LinkTestCase(unittest.TestCase):
    """Tests for Link objects that connect one model's parameter to
    another model's attribute."""

    def setUp(self):
        self.o1 = ODETEST(debug=False)
        self.o2 = ODETEST(debug=False)
        self.o2.u = -10.
        self.o1.f = Link(self.o2, "u")
        self.o2.f = Link(self.o1, "u")
        # Nest o1 as a parameter of o2 so o2's XML contains a model with links.
        self.o2.declareParameter(child=self.o1)

    def testLinkCreation(self):
        self.o1.f = Link(self.o2, "u")
        self.assertTrue(self.o1.f)

    def testLinkValue(self):
        # Reading the linked parameter yields the target's current value.
        self.assertEqual(self.o1.f, -10)

    def testLinkTarget(self):
        # TODO: not implemented yet; passes trivially.
        pass

    def testLinkDefaultAttribute(self):
        # Only checks that a Link can be built without naming an attribute.
        Link(self.o2)

    def testLinkXML(self):
        s = StringIO()
        self.o2.writeXML(s)
        self.assertTrue('<Link' in s.getvalue())

    def testLinkTargetXML(self):
        # TODO: not implemented yet; passes trivially.
        pass
125 |
|
126 |
class ParamaterSetTestCase(unittest.TestCase):
    """Tests for ParameterSet declaration and its XML round trip.

    Uses unittest assertion methods rather than bare ``assert`` so the
    checks still run when Python is started with ``-O``.
    """

    def setUp(self):
        self.p = ParameterSet()
        self.p.declareParameter(gamma1=1., gamma2=2., gamma3=3.)

    def testParameterSetCreation(self):
        self.assertEqual(self.p.gamma1, 1.)

    def testParameterSetXMLCreation(self):
        xmlout = self._xml(self.p)
        for name in ("gamma1", "gamma2", "gamma3"):
            self.assertTrue(name in xmlout)
        parsable = parse(xmlout)
        self.assertTrue(isinstance(parsable, ParameterSet))
        self.assertTrue(self._dom(self.p).getElementsByTagName("ParameterSet"))

    def testParameterSetFromXML(self):
        doc = self._class(self.p)
        element = self._dom(self.p).getElementsByTagName("ParameterSet")[0]
        pset = ParameterSet.fromDom(element)
        self.assertTrue(isinstance(pset, ParameterSet))
        self.assertTrue(isinstance(doc, ParameterSet))
        self.assertEqual(self.p.gamma1, doc.gamma1)

    def testParameterSetWithChildrenFromXML(self):
        p2 = ParameterSet()
        p2.declareParameter(s="abc", f=3.)
        self.p.declareParameter(child=p2)
        doc = self._class(self.p)
        self.assertEqual(self.p.child.f, doc.child.f)

    def testParameterSetChild(self):
        p2 = ParameterSet()
        p2.declareParameter(s="abc", f=3.)
        self.p.declareParameter(child=p2)
        self.assertEqual(self.p.child.s, "abc")
        self.assertEqual(self.p.child.f, 3.)

    def _xml(self, obj):
        """Return the XML text written by ``obj.writeXML``."""
        s = StringIO()
        obj.writeXML(s)
        return s.getvalue()

    def _dom(self, obj):
        """Return a minidom document parsed from obj's XML."""
        return minidom.parseString(self._xml(obj))

    def _class(self, obj):
        """Return the modelframe object rebuilt by ``parse`` from obj's XML."""
        return parse(self._xml(obj))

    def testFromDomInt(self):
        p3 = ParameterSet()
        p3.declareParameter(inttest=1)
        doc = self._class(p3)
        # Exact type check on purpose: the value must come back as int.
        self.assertEqual(type(doc.inttest), int)

    def testFromDomNumarrayVector(self):
        import numarray
        p3 = ParameterSet()
        mynumarray = numarray.array([3., 4., 5.], type=numarray.Float64)
        p3.declareParameter(numtest=mynumarray)
        doc = self._class(p3)
        self.assertEqual(doc.numtest.type(), numarray.Float64)
        self.assertEqual(type(doc.numtest), numarray.NumArray)

    def testFromDomNumarrayMulti(self):
        import numarray
        p3 = ParameterSet()
        mynumarray = numarray.array([[1., 2., 3.], [3., 4., 5.]],
                                    type=numarray.Float64)
        p3.declareParameter(numtest=mynumarray)
        doc = self._class(p3)
        self.assertEqual(doc.numtest.type(), numarray.Float64)
        self.assertEqual(type(doc.numtest), numarray.NumArray)

    def testLists(self):
        p4 = ParameterSet()
        mylist = [True, False, False, True]
        p4.declareParameter(listest=mylist)
        doc = self._class(p4)
        self.assertEqual(type(doc.listest), list)
        self.assertEqual(mylist, doc.listest)
        # True == 1, so equality alone would accept ints; check every item's type.
        for item in doc.listest:
            self.assertEqual(type(item), bool)
219 |
|
220 |
|
221 |
class ModelXMLTestCase(unittest.TestCase):
    """Serialise a single ODETEST model and inspect the resulting XML.

    Renamed from ``ModeltoDomTestCase``: a later class in this module
    reused that name and rebound it, so unittest never collected these
    tests under the old name.
    """

    def setUp(self):
        self.o1 = ODETEST(debug=False)
        self.o1.message = 'blah'

    def _xml(self):
        """Write self.o1 as XML; return the text and keep it in self.xmlout."""
        s = StringIO()
        self.o1.writeXML(s)
        self.xmlout = s.getvalue()
        return self.xmlout

    def _class(self):
        # returns the modelframe object rebuilt from the xml
        return parse(self._xml())

    def _dom(self):
        # returns a minidom document parsed from the xml
        return minidom.parseString(self._xml())

    def testModelExists(self):
        modeldoc = self._class()
        # Previously `assert (isinstance, (modeldoc, Model))`, which tested a
        # non-empty tuple and therefore could never fail.
        self.assertTrue(isinstance(modeldoc, Model))
        self.assertTrue(self._dom().getElementsByTagName("Model"))

    def testModelhasID(self):
        model = self._dom().getElementsByTagName("Model")[0]
        # NOTE(review): presumably modelframe ids start at 100 -- confirm.
        self.assertTrue(int(model.getAttribute("id")) > 99)
252 |
|
253 |
class ModeltoDomTestCase(unittest.TestCase):
    """Check that ``parse`` honours the ``module`` attribute of a <Model>."""

    def _xml(self, modulename, modelname):
        """Return a canned Simulation XML document as text.

        The first <Model> element gets ``module=modulename`` and
        ``type=modelname``; the rest of the document is fixed.
        """
        return '''<?xml version="1.0" ?>
<ESys> <Simulation type="Simulation"> <Component rank="0">

<Model id="127" module="%s" type="%s">

<Parameter type="float"> <Name> a </Name> <Value> 0.9 </Value> </Parameter>
<Parameter type="Link"> <Name> f </Name> <Value> <Link> <Target> 128 </Target>
<Attribute> u </Attribute> </Link> </Value> </Parameter> <Parameter
type="float"> <Name> tend </Name> <Value>
1.0 </Value> </Parameter> <Parameter type="int"> <Name> u </Name> <Value> 10
</Value> </Parameter> <Parameter type="float"> <Name> tol </Name> <Value>
1e-08 </Value> </Parameter> <Parameter type="float"> <Name> dt </Name>
<Value>
0.01 </Value> </Parameter> <Parameter type="str"> <Name> message </Name>
<Value> current error = 9.516258e-01 </Value> </Parameter> </Model>
</Component> <Component rank="1"> <Model id="128" type="ODETEST"> <Parameter
type="float"> <Name> a </Name> <Value>
0.9 </Value> </Parameter> <Parameter type="Link"> <Name> f </Name> <Value>
<Link> <Target> 127 </Target> <Attribute> u </Attribute> </Link> </Value>
</Parameter> <Parameter type="float"> <Name> tend </Name> <Value>
1.0 </Value> </Parameter> <Parameter type="float"> <Name> u </Name> <Value>
-10.0 </Value> </Parameter> <Parameter type="float"> <Name> tol </Name>
<Value> 1e-08 </Value> </Parameter> <Parameter type="float"> <Name> dt
</Name> <Value>
0.1 </Value> </Parameter> <Parameter type="str"> <Name> message </Name> <Value>
current error = 1.904837e+01 </Value> </Parameter> </Model> </Component>
<Component rank="2"> <Model id="129" type="Messenger"> <Parameter
type="Link"> <Name> message </Name> <Value> <Link> <Target> 127 </Target>
<Attribute> message </Attribute> </Link> </Value> </Parameter> </Model>
</Component> </Simulation> <Model id="128" type="ODETEST"> <Parameter
type="float"> <Name> a </Name> <Value>
0.9 </Value> </Parameter> <Parameter type="Link"> <Name> f </Name> <Value>
<Link> <Target> 127 </Target> <Attribute> u </Attribute> </Link> </Value>
</Parameter> <Parameter type="float"> <Name> tend </Name> <Value>
1.0 </Value> </Parameter> <Parameter type="float"> <Name> u </Name> <Value>
-10.0 </Value> </Parameter> <Parameter type="float"> <Name> tol </Name>
<Value> 1e-08 </Value> </Parameter> <Parameter type="float"> <Name> dt
</Name> <Value>
0.1 </Value> </Parameter> <Parameter type="str"> <Name> message </Name> <Value>
current error = 1.904837e+01 </Value> </Parameter> </Model> <Model id="127"
type="ODETEST"> <Parameter type="float"> <Name> a </Name> <Value>
0.9 </Value> </Parameter> <Parameter type="Link"> <Name> f </Name> <Value>
<Link> <Target> 128 </Target> <Attribute> u </Attribute> </Link> </Value>
</Parameter> <Parameter type="float"> <Name> tend </Name> <Value>
1.0 </Value> </Parameter> <Parameter type="int"> <Name> u </Name> <Value> 10
</Value> </Parameter> <Parameter type="float"> <Name> tol </Name> <Value>
1e-08 </Value> </Parameter> <Parameter type="float"> <Name> dt </Name>
<Value>
0.01 </Value> </Parameter> <Parameter type="str"> <Name> message </Name>
<Value> current error = 9.516258e-01 </Value> </Parameter> </Model> <Model
id="127" type="ODETEST"> <Parameter type="float"> <Name> a </Name> <Value>
0.9 </Value> </Parameter> <Parameter type="Link"> <Name> f </Name> <Value>
<Link> <Target> 128 </Target> <Attribute> u </Attribute> </Link> </Value>
</Parameter> <Parameter type="float"> <Name> tend </Name> <Value>
1.0 </Value> </Parameter> <Parameter type="int"> <Name> u </Name> <Value> 10
</Value> </Parameter> <Parameter type="float"> <Name> tol </Name> <Value>
1e-08 </Value> </Parameter> <Parameter type="float"> <Name> dt </Name>
<Value>
0.01 </Value> </Parameter> <Parameter type="str"> <Name> message </Name>
<Value> current error = 9.516258e-01 </Value> </Parameter> </Model> </ESys>
''' % (modulename, modelname)

    def testModuleAttribute(self):
        # 'run_xml' is presumably this test module's own name, so ODETEST
        # can be imported from it; the test passes if parse() does not raise.
        parse(self._xml('run_xml', 'ODETEST'))

    def testModuleAttributeFails(self):
        # Module 'a' does not exist, so parse() must raise ImportError.
        # assertRaises (unlike `assert False`) still fails under python -O.
        self.assertRaises(ImportError, parse, self._xml('a', 'b'))
328 |
|
329 |
class Messenger(Model):
    """Minimal model exposing a single ``message`` parameter.

    It keeps a private running total of the step sizes it has been
    post-processed with; nothing else is computed.
    """

    def __init__(self, *args, **kwargs):
        Model.__init__(self, *args, **kwargs)
        self.declareParameter(message="none")

    def doInitialization(self):
        # Start the private clock at zero.
        self.__t = 0

    def doStepPostprocessing(self, dt):
        # Advance the private clock by the step just taken.
        self.__t = self.__t + dt

    def doFinalization(self):
        # Nothing to release.
        pass
345 |
|
346 |
|
347 |
|
348 |
class ODETEST(Model):
    """Solver for the scalar ODE

        du/dt = a*u + f(t)

    using an implicit Euler scheme:

        u_n - u_{n-1} = dt*a*u_n + dt*f(t_n)

    u_n is obtained with the fixed-point iteration

        u_{n,k} = u_{n-1} + dt*(a*u_{n,k-1} + f(t_n))

    which stops once two successive iterates differ by less than
    tol*|u_{n,k}|.

    Parameters: step size dt, end time tend, coefficient a, forcing f,
    initial value u and the relative stopping tolerance tol.  After each
    step, ``message`` reports the distance from the curve
    10*exp((a-1)*t).
    """

    def __init__(self, *args, **kwargs):
        Model.__init__(self, *args, **kwargs)
        self.declareParameter(tend=1., dt=0.1, a=0.9, u=10., f=0.,
                              message="", tol=1.e-8)

    def doInitialization(self):
        # elapsed time and iteration counter
        self.__tn = 0
        self.__iter = 0

    def doStepPreprocessing(self, dt):
        # Remember u_{n-1} and restart the iteration count for this step.
        self.__iter = 0
        self.__u_last = self.u

    def doStep(self, dt):
        # One fixed-point iteration: u_{n,k} from u_{n,k-1}.
        self.__iter += 1
        self.__u_old = self.u
        slope = self.a * self.__u_old + self.f
        self.u = self.__u_last + dt * slope

    def terminate(self):
        # Always perform at least one iteration before testing convergence.
        if self.__iter < 1:
            return False
        return abs(self.__u_old - self.u) < self.tol * abs(self.u)

    def doStepPostprocessing(self, dt):
        self.__tn += dt
        reference = 10. * math.exp((self.a - 1.) * self.__tn)
        self.message = "current error = %e" % abs(self.u - reference)

    def getSafeTimeStepSize(self, dt):
        # dt <= 1/(|a|+1) gives dt*|a| < 1, so the iteration is contractive.
        return min(self.dt, 1. / (abs(self.a) + 1.))

    def finalize(self):
        return self.__tn >= self.tend
400 |
|
401 |
|
402 |
|
403 |
if __name__ == '__main__':
    # Run every TestCase class bound at module level when executed directly.
    unittest.main()
405 |
|