/[escript]/trunk/escript/py_src/modelframe.py
ViewVC logotype

Contents of /trunk/escript/py_src/modelframe.py

Parent Directory | Revision Log


Revision 1042 - (show annotations)
Mon Mar 19 03:50:34 2007 UTC (12 years, 8 months ago) by gross
File MIME type: text/x-python
File size: 46240 byte(s)
a small fix which deals with the case that a Model class
does not take the argument debug.



1 # $Id$
2
3 """
4 Environment for implementing models in escript
5
6 @var __author__: name of author
7 @var __copyright__: copyrights
8 @var __license__: licence agreement
9 @var __url__: url entry point on documentation
10 @var __version__: version
11 @var __date__: date of the version
12 """
13
14 __author__="Lutz Gross, l.gross@uq.edu.au"
15 __copyright__=""" Copyright (c) 2006 by ACcESS MNRF
16 http://www.access.edu.au
17 Primary Business: Queensland, Australia"""
18 __license__="""Licensed under the Open Software License version 3.0
19 http://www.opensource.org/licenses/osl-3.0.php"""
20 __url__="http://www.iservo.edu.au/esys"
21 __version__="$Revision$"
22 __date__="$Date$"
23
24
25 from types import StringType,IntType,FloatType,BooleanType,ListType,DictType
26 from sys import stdout
27 import numarray
28 import operator
29 import itertools
30 import time
31 import os
32
33 # import the 'set' module if it's not defined (python2.3/2.4 difference)
34 try:
35 set
36 except NameError:
37 from sets import Set as set
38
39 from xml.dom import minidom
40
41
def all(seq):
    """
    Returns True if every element of C{seq} is true, False otherwise.

    Stand-in for the builtin C{all}, which is missing before Python 2.5.
    Consumption of C{seq} stops at the first false element; an empty
    C{seq} gives True.
    """
    for element in seq:
        if element:
            continue
        return False
    return True
47
def any(seq):
    """
    Returns True if at least one element of C{seq} is true, False otherwise.

    Stand-in for the builtin C{any}, which is missing before Python 2.5.
    Consumption of C{seq} stops at the first true element; an empty C{seq}
    gives False.
    """
    for element in seq:
        if not element:
            continue
        return True
    return False
53
def importName(modulename, name):
    """
    Imports the object C{name} from the module C{modulename} and returns it.

    The import is performed in the context of this function, so
    C{modulename} must be a fully qualified module path (e.g. C{"os.path"}).

    Based on http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/52241

    @param modulename: fully qualified name of the module
    @param name: name of the object defined in that module
    @return: the object bound to C{name} in the module's namespace
    @raise ImportError: if the module cannot be imported or does not define
                        C{name} (None is never returned)
    """
    # a non-empty fromlist makes __import__ return the named (sub)module
    # itself rather than the top-level package of a dotted path
    namespace = vars(__import__(modulename, globals(), locals(), [name]))
    if name in namespace:
        return namespace[name]
    raise ImportError("Could not import %s from %s" % (name, modulename))
67
class ESySXMLParser(object):
    """
    parser for ESysXML file
    """
    # tags of the elements that can be turned into objects
    _COMPONENT_TAGS = ('Simulation', 'Model', 'ParameterSet')

    def __init__(self, xml, debug=False):
        """
        Parses the XML string C{xml}. Its first C{ESys} element is the root
        whose children are read by L{parse}.

        @param xml: the ESysXML document as a string
        @param debug: debug flag, available to C{fromDom} methods as the
                      C{debug} attribute of this parser
        """
        self.__dom = minidom.parseString(xml)
        self.__linkable_object_registry = {}
        self.__link_registry = []
        self.__esys = self.__dom.getElementsByTagName('ESys')[0]
        self.debug = debug

    def getClassPath(self, node):
        """
        Returns the class named by the C{type} attribute of C{node}, imported
        from the module given by its C{module} attribute (C{__main__} if that
        attribute is missing or empty).
        """
        class_name = node.getAttribute("type")
        module_name = node.getAttribute("module")
        if not module_name:
            module_name = "__main__"
        return importName(module_name, class_name)

    def setLinks(self):
        """
        Sets the target of every registered link to the object registered
        under the link's target id.

        @raise KeyError: if no object was registered under a target id
        """
        for obj_id, link in self.__link_registry:
            link.target = self.__linkable_object_registry[obj_id]

    def _componentFromDom(self, node):
        """
        Creates the object described by C{node}, an element whose tag is one
        of C{_COMPONENT_TAGS}. C{Simulation} elements are read by
        C{Simulation.fromDom}; for the other tags the class is taken from the
        element's C{type}/C{module} attributes.
        """
        if node.tagName == 'Simulation':
            return Simulation.fromDom(self, node)
        return self.getClassPath(node).fromDom(self, node)

    def parse(self):
        """
        parser method for EsysXML: returns the list of objects created from
        the children of the C{ESys} element, with all links resolved.

        @raise ValueError: if a child element is not a C{Simulation},
                           C{Model} or C{ParameterSet} element
        """
        found = []
        for node in self.__esys.childNodes:
            if not isinstance(node, minidom.Element):
                continue
            if node.tagName not in self._COMPONENT_TAGS:
                raise ValueError("Invalid type, %r" % node.getAttribute("type"))
            found.append(self._componentFromDom(node))
        self.setLinks()
        return found

    def registerLink(self, obj_id, link):
        """
        Records that C{link} has to point to the object with id C{obj_id};
        the target is set by L{setLinks}.
        """
        self.__link_registry.append((int(obj_id), link))

    def registerLinkableObject(self, obj, node):
        """
        Registers C{obj} under the integer in the C{id} attribute of C{node}.
        Nothing is registered if the attribute is missing or blank.

        @raise ValueError: if an object with that id is already registered
        """
        id_str = node.getAttribute('id').strip()
        if id_str:
            obj_id = int(id_str)
            if obj_id in self.__linkable_object_registry:
                raise ValueError("Object id %s already exists." % obj_id)
            self.__linkable_object_registry[obj_id] = obj

    def getComponent(self, node):
        """
        returns a single component + rank from a simulation: the tuple
        (rank, object) built from the C{rank} attribute of C{node} and its
        first child element.

        @raise ValueError: if the first child element has an illegal tag or
                           C{node} has no child element
        """
        rank = int(node.getAttribute("rank"))
        for child in node.childNodes:
            if isinstance(child, minidom.Element):
                if child.tagName not in self._COMPONENT_TAGS:
                    raise ValueError("illegal component type %s" % child.tagName)
                return (rank, self._componentFromDom(child))
        raise ValueError("cannot resolve Component")
138
class ESySXMLCreator(object):
    """
    creates an XML Dom representation
    """
    def __init__(self):
        self.__dom = minidom.Document()
        self.__esys = self.__dom.createElement('ESys')
        self.__dom.appendChild(self.__esys)
        # object id -> registered object
        self.__linkable_object_registry = {}
        # id(obj) -> object id. The registry above keeps every registered
        # object alive, so its id() cannot be reused by another object.
        self.__object_ids = {}
        self.__next_object_id = 100

    def getRoot(self):
        """
        returns the C{ESys} root element
        """
        return self.__esys

    def createElement(self, name):
        return self.__dom.createElement(name)

    def createTextNode(self, name):
        return self.__dom.createTextNode(name)

    def getElementById(self, name):
        return self.__dom.getElementById(name)

    def createDataNode(self, tagName, data):
        """
        C{createDataNode}s are the building blocks of the xml documents constructed in
        this module.

        @param tagName: the associated xml tag
        @param data: the values in the tag (written as C{str(data)})
        """
        n = self.createElement(tagName)
        n.appendChild(self.createTextNode(str(data)))
        return n

    def getLinkableObjectId(self, obj):
        """
        Returns the object id of C{obj}. On first use C{obj} is registered
        and gets the next free id, starting at 100. Objects are matched by
        identity, so distinct objects that compare equal get distinct ids.
        """
        key = id(obj)
        if key in self.__object_ids:
            return self.__object_ids[key]
        obj_id = self.__next_object_id
        self.__next_object_id += 1
        self.__object_ids[key] = obj_id
        self.__linkable_object_registry[obj_id] = obj
        return obj_id

    def registerLinkableObject(self, obj, node):
        """
        Sets the C{id} attribute of C{node} to the object id of C{obj} and
        declares it as the element's ID attribute.
        """
        obj_id = self.getLinkableObjectId(obj)
        node.setAttribute('id', str(obj_id))
        node.setIdAttribute("id")

    def includeTargets(self):
        """
        Writes (via their C{toDom} method, into the root element) all
        objects referenced by a C{Target} element that are not part of the
        document yet. This is repeated until nothing new is written, because
        a written object may add further C{Target} elements.
        """
        target_written = True
        while target_written:
            target_written = False
            for element in self.__dom.getElementsByTagName('Target'):
                target_id = int(element.firstChild.nodeValue.strip())
                if self.getElementById(str(target_id)):
                    continue
                self.__linkable_object_registry[target_id].toDom(self, self.__esys)
                target_written = True

    def toprettyxml(self):
        """
        returns the document, including all link targets, as indented XML
        """
        self.includeTargets()
        return self.__dom.toprettyxml()
198
class Link:
    """
    A Link makes an attribute of an object callable::

          class Obj: pass
          o=Obj()
          o.a=8
          l=Link(o,"a")
          assert l()==8
    """

    def __init__(self, target, attribute=None):
        """
        Creates a link to the object target. If attribute is given, the link
        is established to this attribute of the target. Otherwise the
        attribute is undefined.

        @raise AttributeError: see L{setAttributeName}
        """
        self.target = target
        self.attribute = None
        self.setAttributeName(attribute)

    def getTarget(self):
        """
        returns the target
        """
        return self.target

    def getAttributeName(self):
        """
        returns the name of the attribute the link is pointing to
        """
        return self.attribute

    def setAttributeName(self, attribute):
        """
        Sets a new attribute name to be collected from the target object.

        If both the name and the target are set, the target must have an
        attribute of that name. A L{LinkableObject} target is checked with
        its C{hasAttribute} method, any other target with C{hasattr}.

        @raise AttributeError: if the target lacks the attribute
        """
        # compare with None instead of testing the target's truth value:
        # e.g. a Simulation without models has len() == 0 and is false
        if attribute and self.target is not None:
            if isinstance(self.target, LinkableObject):
                found = self.target.hasAttribute(attribute)
            else:
                found = hasattr(self.target, attribute)
            if not found:
                raise AttributeError("%s: target %s has no attribute %s." % (self, self.target, attribute))
        self.attribute = attribute

    def hasDefinedAttributeName(self):
        """
        Returns true if an attribute name is set.
        """
        return self.attribute is not None

    def __repr__(self):
        """
        Returns a string representation of the link.
        """
        if self.hasDefinedAttributeName():
            return "<Link to attribute %s of %s>" % (self.attribute, self.target)
        else:
            return "<Link to target %s>" % self.target

    def __call__(self, name=None):
        """
        Returns the value of the attribute C{name} of the target object, or
        of the linked attribute if C{name} is not given. If that value is
        callable, the return value of calling it is returned instead.
        """
        if name:
            out = getattr(self.target, name)
        else:
            out = getattr(self.target, self.attribute)

        if callable(out):
            return out()
        else:
            return out

    def toDom(self, esysxml, node):
        """
        C{toDom} method of Link. Creates a Link node and appends it to the
        current XML esysxml. The target is referenced by its object id
        (C{esysxml.getLinkableObjectId}).

        @raise ValueError: if the link has no target or no attribute name
        """
        # validate first, so that nothing is registered or written for an
        # incomplete link (explicit checks also survive python -O)
        if self.target is None:
            raise ValueError("Target was none, name was %r" % self.attribute)
        if not self.attribute:
            raise ValueError("You can't xmlify a Link without a target attribute")
        link = esysxml.createElement('Link')
        link.appendChild(esysxml.createDataNode('Target', esysxml.getLinkableObjectId(self.target)))
        link.appendChild(esysxml.createDataNode('Attribute', self.attribute))
        node.appendChild(link)

    def fromDom(cls, esysxml, node):
        """
        Creates a Link from a C{Link} element. The target is not resolved
        here: the link is created without target and registered with
        C{esysxml} under the target's object id, to be set later.
        """
        targetid = int(node.getElementsByTagName("Target")[0].firstChild.nodeValue.strip())
        attribute = str(node.getElementsByTagName("Attribute")[0].firstChild.nodeValue.strip())
        link = cls(None, attribute)
        esysxml.registerLink(targetid, link)
        return link

    fromDom = classmethod(fromDom)
296
class LinkableObject(object):
    """
    An object that allows to link its attributes to attributes of other objects
    via a Link object. For instance::

           p = LinkableObject()
           p.x = Link(o,"name")
           print p.x

    links attribute C{x} of C{p} to the attribute name of object C{o}.

    C{p.x} will contain the current value of attribute C{name} of object
    C{o}.

    If the value of C{getattr(o, "name")} is callable, C{p.x} will return
    the return value of the call.
    """

    def __init__(self, id = None, debug=False):
        """
        Initializes LinkableObject so that we can operate on Links

        @param id: not used (kept for backwards compatibility)
        @param debug: if True, L{trace} prints its messages
        """
        self.debug = debug
        self.__linked_attributes = {}

    def _linkedAttributes(self):
        """
        Returns the dictionary of linked attributes (name -> Link), creating
        an empty one if it does not exist yet.

        It is read through C{__dict__} so that it never triggers
        L{__getattr__}; otherwise looking up any missing attribute on an
        instance whose C{__init__} has not run would recurse forever.
        """
        return self.__dict__.setdefault("_LinkableObject__linked_attributes", {})

    def trace(self, msg):
        """
        If debugging is on, print the message, otherwise do nothing
        """
        if self.debug:
            print("%s: %s" % (str(self), msg))

    def __getattr__(self, name):
        """
        Returns the value of attribute name. If the value is a Link object the
        object is called and the return value is returned.
        """
        out = self.getAttributeObject(name)
        if isinstance(out, Link):
            return out()
        else:
            return out

    def getAttributeObject(self, name):
        """
        Return the object stored for attribute name: a plain instance value,
        else a Link, else the entry of the first class in the method
        resolution order defining name.

        @raise AttributeError: if name is not found in any of these places
        """
        if name in self.__dict__:
            return self.__dict__[name]

        linked = self._linkedAttributes()
        if name in linked:
            return linked[name]

        for klass in self.__class__.__mro__:
            if name in klass.__dict__:
                return klass.__dict__[name]

        raise AttributeError("No attribute %s." % name)

    def hasAttribute(self, name):
        """
        Returns True if self has attribute name, as an instance value, a link
        or an attribute of its class or one of its base classes.
        """
        if name in self.__dict__ or name in self._linkedAttributes():
            return True
        for klass in self.__class__.__mro__:
            if name in klass.__dict__:
                return True
        return False

    def __setattr__(self, name, value):
        """
        Sets the value for attribute name. If value is a Link the target
        attribute is set to name if no attribute has been specified.
        """
        if name in self.__dict__:
            del self.__dict__[name]

        if isinstance(value, Link):
            if not value.hasDefinedAttributeName():
                value.setAttributeName(name)
            self._linkedAttributes()[name] = value

            self.trace("attribute %s is now linked by %s." % (name, value))
        else:
            self.__dict__[name] = value

    def __delattr__(self, name):
        """
        Removes the attribute name, both a link and a plain value stored
        under that name.

        @raise AttributeError: if there is neither
        """
        linked = self._linkedAttributes()
        found = False
        if name in linked:
            del linked[name]
            found = True
        # a plain value assigned after a link leaves the (shadowed) link in
        # place, so both may exist and both have to go
        if name in self.__dict__:
            del self.__dict__[name]
            found = True
        if not found:
            raise AttributeError("No attribute %s." % name)
392
class _ParameterIterator:
    """
    Iterator over the parameters of a ParameterSet. It yields
    C{(name, attribute object)} pairs.
    """
    def __init__(self, parameterset):
        """
        @param parameterset: object with a C{parameters} collection of names
                             and a C{getAttributeObject(name)} method
        """
        self.__set = parameterset
        names = iter(parameterset.parameters)
        # the iterator's advance method is __next__ on Python 3 and next on
        # Python 2; bind whichever exists once
        try:
            self.__advance = names.__next__
        except AttributeError:
            self.__advance = names.next

    def next(self):
        """
        Returns the next C{(name, attribute object)} pair.

        @raise StopIteration: when all parameters have been visited
        """
        name = self.__advance()
        return (name, self.__set.getAttributeObject(name))

    # Python 3 name of the iterator protocol method
    __next__ = next

    def __iter__(self):
        return self
405
class ParameterSet(LinkableObject):
    """
    A class which allows to emphasize attributes to be written to and read
    from XML.

    Leaves of an ESySParameters object can be:

     - a real number
     - a integer number
     - a string
     - a boolean value
     - a ParameterSet object
     - a Simulation object
     - a Model object
     - a numarray object
     - a list (a non-empty list holding only booleans, integers and floats
       is written as a numarray object)
     - a dictionary (written to XML but not read back by L{fromDom})
     - any other object (not considered by writeESySXML and writeXML)

    Example how to create an ESySParameters object::

        p11=ParameterSet({"gamma1":1., "gamma2":2., "gamma3":3.})
        p1=ParameterSet({"dim":2, "tol_v":0.001, "output_file":"/tmp/u.%3.3d.dx",
                         "runFlag":True, "parm11":p11})
        parm=ParameterSet({"parm1":p1,
                           "parm2":ParameterSet({"alpha":Link(p11,"gamma1")})})

    This can be accessed as::

        parm.parm1.dim               # 2
        parm.parm1.tol_v             # 0.001
        parm.parm1.output_file       # "/tmp/u.%3.3d.dx"
        parm.parm1.runFlag           # True
        parm.parm1.parm11.gamma1     # 1.
        parm.parm1.parm11.gamma2     # 2.
        parm.parm1.parm11.gamma3     # 3.
        parm.parm2.alpha             # 1. (value of parm.parm1.parm11.gamma1)
    """
    def __init__(self, parameters=None, **kwargs):
        """
        Creates a ParameterSet with parameters parameters.

        @param parameters: initial parameters, anything accepted by
                           L{declareParameters}; None declares nothing
        @param kwargs: passed to C{LinkableObject.__init__} (e.g. C{debug})
        """
        LinkableObject.__init__(self, **kwargs)
        self.parameters = set()
        if parameters is None:
            parameters = []
        self.declareParameters(parameters)

    def __repr__(self):
        return "<%s %d>" % (self.__class__.__name__, id(self))

    def declareParameter(self, **parameters):
        """
        Declares a new parameter(s) and its (their) initial value.
        """
        self.declareParameters(parameters)

    def declareParameters(self, parameters):
        """
        Declares a set of parameters. C{parameters} can be a list of names
        (each initialised to None), a dictionary mapping names to initial
        values, or any other iterable of (name, value) pairs such as a
        ParameterSet.
        """
        if isinstance(parameters, list):
            parameters = zip(parameters, itertools.repeat(None))
        elif isinstance(parameters, dict):
            parameters = parameters.items()

        for prm, value in parameters:
            setattr(self, prm, value)
            self.parameters.add(prm)

    def isParameter(self, name):
        """
        Returns True if name is a declared parameter of this set.
        """
        return name in self.parameters

    def releaseParameters(self, name):
        """
        Removes parameter name from the parameters; the attribute itself is
        kept. Does nothing if name is not a parameter.
        """
        if self.isParameter(name):
            self.parameters.remove(name)
            self.trace("parameter %s has been removed." % name)

    def checkLinkTargets(self, models, hash):
        """
        returns a set of tuples ("<self>(<name>)", <target model>) if the parameter <name> is linked to model <target model>
        but <target model> is not in the list models. If a parameter is linked to another parameter set which is not in the hash list
        the parameter set is checked for its models. hash gives the call history.
        """
        out = set()
        for name, value in self:
            if isinstance(value, Link):
                m = value.getTarget()
                # Model is a ParameterSet, so it has to be tested first
                if isinstance(m, Model):
                    if m not in models:
                        out.add((str(self) + "(" + name + ")", m))
                elif isinstance(m, ParameterSet) and m not in hash:
                    prefix = str(self) + "(" + name + ")."
                    for path, target in m.checkLinkTargets(models, hash + [self]):
                        out.add((prefix + path, target))
        return out

    def __iter__(self):
        """
        Creates an iterator over the parameter and their values.
        """
        return _ParameterIterator(self)

    def showParameters(self):
        """
        Returns a description of the parameters.
        """
        out = "{"
        notfirst = False
        for i, v in self:
            if notfirst: out = out + ","
            notfirst = True
            if isinstance(v, ParameterSet):
                out = "%s\"%s\" : %s" % (out, i, v.showParameters())
            else:
                out = "%s\"%s\" : %s" % (out, i, v)
        return out + "}"

    def __delattr__(self, name):
        """
        Removes the attribute name and, if it is a parameter, releases it.
        """
        LinkableObject.__delattr__(self, name)
        self.releaseParameters(name)

    def toDom(self, esysxml, node):
        """
        C{toDom} method of ParameterSet class: appends a C{ParameterSet}
        element holding the parameters to node.
        """
        pset = esysxml.createElement('ParameterSet')
        pset.setAttribute('type', self.__class__.__name__)
        pset.setAttribute('module', self.__class__.__module__)
        esysxml.registerLinkableObject(self, pset)
        self._parametersToDom(esysxml, pset)
        node.appendChild(pset)

    def _parametersToDom(self, esysxml, node):
        """
        Appends a C{Parameter} element for each parameter to node. Its
        C{type} attribute is the class name of the value, which L{fromDom}
        uses to read the value back.

        @raise ValueError: if a value cannot be serialized
        """
        for name, value in self:
            # a non-empty list holding only booleans, integers and floats is
            # converted to a numarray of the widest of these types; any other
            # list (e.g. one containing strings) is written as a list
            if isinstance(value, list):
                elem_type = -1
                for item in value:
                    if isinstance(item, bool):
                        item_type = 0
                    elif isinstance(item, int):
                        item_type = 1
                    elif isinstance(item, float):
                        item_type = 2
                    else:
                        elem_type = -1
                        break
                    elem_type = max(elem_type, item_type)
                if elem_type == 0:
                    value = numarray.array(value, numarray.Bool)
                elif elem_type == 1:
                    value = numarray.array(value, numarray.Int)
                elif elem_type == 2:
                    value = numarray.array(value, numarray.Float)

            param = esysxml.createElement('Parameter')
            param.setAttribute('type', value.__class__.__name__)

            param.appendChild(esysxml.createDataNode('Name', name))

            val = esysxml.createElement('Value')
            if isinstance(value, (ParameterSet, Link, DataSource)):
                value.toDom(esysxml, val)
                param.appendChild(val)
            elif isinstance(value, numarray.NumArray):
                shape = value.getshape()
                if isinstance(shape, tuple):
                    # initial value 1 so that an empty shape has size 1
                    size = reduce(operator.mul, shape, 1)
                    shape = ' '.join(map(str, shape))
                else:
                    size = shape
                    shape = str(shape)

                arraytype = value.type()
                if isinstance(arraytype, numarray.BooleanType):
                    arraytype_str = "Bool"
                elif isinstance(arraytype, numarray.IntegralType):
                    arraytype_str = "Int"
                elif isinstance(arraytype, numarray.FloatingType):
                    arraytype_str = "Float"
                elif isinstance(arraytype, numarray.ComplexType):
                    arraytype_str = "Complex"
                else:
                    arraytype_str = str(arraytype)
                numarrayElement = esysxml.createElement('NumArray')
                numarrayElement.appendChild(esysxml.createDataNode('ArrayType', arraytype_str))
                numarrayElement.appendChild(esysxml.createDataNode('Shape', shape))
                numarrayElement.appendChild(esysxml.createDataNode('Data', ' '.join(
                    [str(x) for x in numarray.reshape(value, size)])))
                val.appendChild(numarrayElement)
                param.appendChild(val)
            elif isinstance(value, list):
                param.appendChild(esysxml.createDataNode('Value', ' '.join([str(x) for x in value])))
            elif isinstance(value, (str, bool, int, float, type(None))):
                param.appendChild(esysxml.createDataNode('Value', str(value)))
            elif isinstance(value, dict):
                dic = esysxml.createElement('dictionary')
                keys = list(value.keys())
                if keys:
                    dic.setAttribute('key_type', keys[0].__class__.__name__)
                    dic.setAttribute('value_type', value[keys[0]].__class__.__name__)
                for k, v in value.items():
                    i = esysxml.createElement('item')
                    i.appendChild(esysxml.createDataNode('key', k))
                    i.appendChild(esysxml.createDataNode('value', v))
                    dic.appendChild(i)
                param.appendChild(dic)
            else:
                raise ValueError("cannot serialize %s type to XML." % str(value.__class__))

            node.appendChild(param)

    def fromDom(cls, esysxml, node):
        """
        Creates an instance of cls from the element node written by
        L{toDom}, declares the parameters read from its C{Parameter} children
        and registers the instance with esysxml.

        @param esysxml: the parser; its C{debug} attribute is passed to cls
        @param node: the element to read
        @raise KeyError: if a parameter has a type that cannot be read
        @raise ValueError: if a parameter other than a string or a list has
                           an empty value
        @raise TypeError: if cls.__init__ does not accept the keyword
                          argument C{debug}
        """
        import sys

        # Define a host of helper functions to assist us.
        def _children(node):
            """
            Remove the empty nodes from the children of this node.
            """
            ret = []
            for x in node.childNodes:
                if isinstance(x, minidom.Text):
                    if x.nodeValue.strip():
                        ret.append(x)
                else:
                    ret.append(x)
            return ret

        def _floatfromValue(esysxml, node):
            return float(node.nodeValue.strip())

        def _stringfromValue(esysxml, node):
            return str(node.nodeValue.strip())

        def _intfromValue(esysxml, node):
            return int(node.nodeValue.strip())

        def _boolfromValue(esysxml, node):
            return _boolfromstring(node.nodeValue.strip())

        def _nonefromValue(esysxml, node):
            return None

        def _arrayitemfromstring(s):
            # also accept True/False, in case elements of a Bool array were
            # written that way (float() would reject them)
            if s == 'True':
                return 1.
            if s == 'False':
                return 0.
            return float(s)

        def _numarrayfromValue(esysxml, node):
            for child in _children(node):
                if child.tagName == 'ArrayType':
                    arraytype = child.firstChild.nodeValue.strip()
                if child.tagName == 'Shape':
                    shape = child.firstChild.nodeValue.strip()
                    shape = [int(x) for x in shape.split()]
                if child.tagName == 'Data':
                    data = child.firstChild.nodeValue.strip()
                    data = [_arrayitemfromstring(x) for x in data.split()]
            return numarray.reshape(numarray.array(data, type=getattr(numarray, arraytype)),
                                    shape)

        def _listfromValue(esysxml, node):
            return node.nodeValue.split()

        def _boolfromstring(s):
            return s == 'True'

        # Mapping from text types in the xml to methods used to process trees of that type
        ptypemap = {"Simulation": Simulation.fromDom,
                    "Model": Model.fromDom,
                    "ParameterSet": ParameterSet.fromDom,
                    "Link": Link.fromDom,
                    "DataSource": DataSource.fromDom,
                    "float": _floatfromValue,
                    "int": _intfromValue,
                    "str": _stringfromValue,
                    "bool": _boolfromValue,
                    "list": _listfromValue,
                    "NumArray": _numarrayfromValue,
                    "NoneType": _nonefromValue,
                    }

        parameters = {}
        for n in _children(node):
            ptype = n.getAttribute("type")
            if ptype not in ptypemap:
                raise KeyError("cannot handle parameter type %s." % ptype)

            pname = pvalue = None
            for childnode in _children(n):
                if childnode.tagName == "Name":
                    pname = childnode.firstChild.nodeValue.strip()

                if childnode.tagName == "Value":
                    nodes = _children(childnode)
                    if nodes:
                        pvalue = ptypemap[ptype](esysxml, nodes[0])
                    elif ptype == "str":
                        # an empty string is written as an empty text node,
                        # which _children drops
                        pvalue = ""
                    elif ptype == "list":
                        # likewise for an empty list
                        pvalue = []
                    else:
                        raise ValueError("parameter %s of type %s has no value." % (pname, ptype))

            parameters[pname] = pvalue

        # Create the instance of ParameterSet
        try:
            o = cls(debug=esysxml.debug)
        except TypeError:
            # substring test: newer Pythons prefix the message with the
            # class name ("Cls.__init__() got ...")
            message = str(sys.exc_info()[1])
            if "__init__() got an unexpected keyword argument 'debug'" in message:
                raise TypeError("The Model class %s __init__ needs to have argument 'debug'." % cls.__name__)
            raise
        o.declareParameters(parameters)
        esysxml.registerLinkableObject(o, node)
        return o

    fromDom = classmethod(fromDom)

    def writeXML(self, ostream=stdout):
        """
        Writes the object as an XML object into an output stream.
        """
        esysxml = ESySXMLCreator()
        self.toDom(esysxml, esysxml.getRoot())
        ostream.write(esysxml.toprettyxml())
715
class Model(ParameterSet):
    """
    A Model object represents a process marching over time until a
    finalizing condition is fulfilled. At each time step an iterative
    process can be performed and the time step size can be controlled. A
    Model has the following work flow::

          setUp()
          doInitialization()
          while not terminateInitialIteration(): doInitialStep()
          doInitialPostprocessing()
          while not finalize():
               dt=getSafeTimeStepSize(dt)
               doStepPreprocessing(dt)
               while not terminateIteration(): doStep(dt)
               doStepPostprocessing(dt)
          doFinalization()

    where C{setUp}, C{doInitialization}, C{doInitialStep},
    C{terminateInitialIteration}, C{doInitialPostprocessing}, C{finalize},
    C{getSafeTimeStepSize}, C{doStepPreprocessing}, C{doStep},
    C{terminateIteration}, C{doStepPostprocessing} and C{doFinalization} are
    methods of the particular instance of a Model. The default
    implementations of these methods are placeholders which have to be
    overridden by the subclass implementing a Model.
    """

    # step size returned by the default getSafeTimeStepSize, i.e. by a model
    # which does not restrict the time step size
    UNDEF_DT=1.e300

    def __init__(self,parameters=[],**kwargs):
        """
        Creates a model.

        Just calls the parent constructor.
        """
        ParameterSet.__init__(self, parameters=parameters,**kwargs)

    def __str__(self):
        return "<%s %d>"%(self.__class__.__name__,id(self))


    def setUp(self):
        """
        Sets up the model.

        This function may be overwritten.
        """
        pass

    def doInitialization(self):
        """
        Initializes the time stepping scheme. This method is not called in case of a restart.

        This function may be overwritten.
        """
        pass
    def doInitialStep(self):
        """
        performs an iteration step in the initialization phase. This method is not called in case of a restart.

        This function may be overwritten.
        """
        pass

    def terminateInitialIteration(self):
        """
        Returns True if iteration at the initial phase is terminated. The
        default returns True, i.e. one initialization step.
        """
        return True

    def doInitialPostprocessing(self):
        """
        finalises the initialization iteration process. This method is not called in case of a restart.

        This function may be overwritten.
        """
        pass

    def getSafeTimeStepSize(self,dt):
        """
        Returns a time step size which can safely be used.

        C{dt} gives the previously used step size. The default returns
        C{UNDEF_DT}, i.e. no restriction.

        This function may be overwritten.
        """
        return self.UNDEF_DT

    def finalize(self):
        """
        Returns True if the time stepping is finalized. The default returns
        False, i.e. the model never ends the time stepping by itself.

        This function may be overwritten.
        """
        return False

    def doFinalization(self):
        """
        Finalizes the time stepping.

        This function may be overwritten.
        """
        pass

    def doStepPreprocessing(self,dt):
        """
        Sets up a time step of step size dt.

        This function may be overwritten.
        """
        pass

    def doStep(self,dt):
        """
        Executes an iteration step at a time step.

        C{dt} is the currently used time step size.

        This function may be overwritten.
        """
        pass

    def terminateIteration(self):
        """
        Returns True if iteration on a time step is terminated. The default
        returns True, i.e. one iteration step per time step.
        """
        return True


    def doStepPostprocessing(self,dt):
        """
        finalises the time step.

        dt is the currently used time step size.

        This function may be overwritten.
        """
        pass

    def toDom(self, esysxml, node):
        """
        C{toDom} method of Model class: appends a C{Model} element holding
        the parameters to node.
        """
        pset = esysxml.createElement('Model')
        pset.setAttribute('type', self.__class__.__name__)
        pset.setAttribute('module', self.__class__.__module__)
        esysxml.registerLinkableObject(self, pset)
        node.appendChild(pset)
        self._parametersToDom(esysxml, pset)
862
863 class Simulation(Model):
864 """
865 A Simulation object is special Model which runs a sequence of Models.
866
867 The methods C{doInitialization}, C{finalize}, C{getSafeTimeStepSize},
868 C{doStepPreprocessing}, C{terminateIteration}, C{doStepPostprocessing},
869 C{doFinalization} are executing the corresponding methods of the models in
870 the simulation.
871 """
872
873 FAILED_TIME_STEPS_MAX=20
874 MAX_ITER_STEPS=50
875 MAX_CHANGE_OF_DT=2.
876
def __init__(self, models=[], **kwargs):
    """
    Initiates a simulation from a list of models.

    Declares the parameters C{time}, C{time_step} and C{dt}. Every entry of
    C{models} is validated before any model is stored.

    @param models: the models run by the simulation, in order
    @param kwargs: passed on to the parent constructor
    @raise TypeError: if an entry of C{models} is not a Model
    """
    super(Simulation, self).__init__(**kwargs)
    self.declareParameter(time=0., time_step=0, dt=self.UNDEF_DT)
    for candidate in models:
        if isinstance(candidate, Model):
            continue
        raise TypeError("%s is not a subclass of Model." % candidate)
    self.__models = []
    for position in range(len(models)):
        self.__setitem__(position, models[position])
891
892
def __repr__(self):
    """
    Returns a representation listing the models of the Simulation.
    """
    models = self.__models
    return "<Simulation %r>" % (models,)
898
def __str__(self):
    """
    Returns a short string identifying this Simulation by its id().
    """
    return "<Simulation " + ("%d" % id(self)) + ">"
904
def iterModels(self):
    """
    Returns the models of the simulation.

    This is the internal list itself, not a copy and not an iterator object,
    so callers must not modify it.
    """
    models = self.__models
    return models
910
def __getitem__(self, i):
    """
    Returns the model at position i (list indexing rules apply).
    """
    models = self.__models
    return models[i]
916
def __setitem__(self, i, value):
    """
    Sets the model at position i. If i lies beyond the current end, the
    model list is first padded with None.

    @raise ValueError: if value is not a Model
    """
    if not isinstance(value, Model):
        raise ValueError("assigned value is not a Model but instance of %s" % (value.__class__.__name__,))
    missing = i - len(self.__models) + 1
    if missing > 0:
        self.__models.extend([None] * missing)
    self.__models[i] = value
926
def __len__(self):
    """
    Returns the number of model slots (padding None entries included).
    """
    count = len(self.__models)
    return count
932
def getAllModels(self):
    """
    Returns a list of all models used in the Simulation. Nested Simulations
    are replaced, recursively, by their own models. Duplicates are removed,
    and the order of the result is not defined.
    """
    collected = []
    for model in self.iterModels():
        if not isinstance(model, Simulation):
            collected.append(model)
        else:
            collected.extend(model.getAllModels())
    return list(set(collected))
944
def checkModels(self, models, hash):
    """
    Returns a set of (path, target model) tuples: one for each link of this
    Simulation or of one of its models (nested Simulations included) whose
    target model is not in the list models. Every path is prefixed with
    str(self) and a dot. hash lists the objects already visited.
    """
    found = self.checkLinkTargets(models, hash + [self])
    for model in self.iterModels():
        if not isinstance(model, Simulation):
            found |= model.checkLinkTargets(models, hash + [self])
        else:
            found |= model.checkModels(models, hash)
    return set([(str(self) + "." + entry[0], entry[1]) for entry in found])
957
958
def getSafeTimeStepSize(self, dt):
    """
    Returns a time step size which can safely be used by all models: the
    minimum of the sizes the models return for the previous step size dt.
    """
    sizes = [model.getSafeTimeStepSize(dt) for model in self.iterModels()]
    return min(sizes)
967
def setUp(self):
    """
    Performs the setup of all models, in order.
    """
    models = self.iterModels()
    for model in models:
        model.setUp()
974
def doInitialization(self):
    """
    Initializes all models, in order.
    """
    models = self.iterModels()
    for model in models:
        model.doInitialization()
def doInitialStep(self):
    """
    Repeats initialization steps of all models until
    terminateInitialIteration() returns True.

    @raise IterationDivergenceError: if more than MAX_ITER_STEPS steps are
                                     performed
    """
    steps = 0
    while not self.terminateInitialIteration():
        if steps == 0:
            self.trace("iteration for initialization starts")
        steps = steps + 1
        self.trace("iteration step %d" % (steps))
        for model in self.iterModels():
            model.doInitialStep()
        if steps > self.MAX_ITER_STEPS:
            raise IterationDivergenceError("initial iteration did not converge after %s steps." % steps)
    self.trace("Initialization finalized after %s iteration steps." % steps)
995
def doInitialPostprocessing(self):
    """
    Finalises the initialization iteration process for all models, in order.
    """
    models = self.iterModels()
    for model in models:
        model.doInitialPostprocessing()
def finalize(self):
    """
    Returns True if any of the models is to be finalized. finalize() is
    called on every model; there is no short-circuit.
    """
    flags = [model.finalize() for model in self.iterModels()]
    return any(flags)
1007
def doFinalization(self):
    """
    Finalises the time stepping for all models, then traces the end of the
    time integration.
    """
    models = self.iterModels()
    for model in models:
        model.doFinalization()
    self.trace("end of time integation.")
1014
def doStepPreprocessing(self, dt):
    """
    Initializes the time step of size dt for all models, in order.
    """
    models = self.iterModels()
    for model in models:
        model.doStepPreprocessing(dt)
1021
def terminateIteration(self):
    """
    Returns True if the iterations of all models are terminated.
    terminateIteration() is called on every model; there is no
    short-circuit.
    """
    flags = [model.terminateIteration() for model in self.iterModels()]
    return all(flags)
1028
def terminateInitialIteration(self):
    """
    Returns True if the initial iterations of all models are terminated.
    terminateInitialIteration() is called on every model; there is no
    short-circuit.
    """
    flags = [model.terminateInitialIteration() for model in self.iterModels()]
    return all(flags)
1035
def doStepPostprocessing(self, dt):
    """
    Finalises the time step for all models, then advances the simulation
    counters: time_step by one and time by dt. dt is stored as the last
    used step size.
    """
    models = self.iterModels()
    for model in models:
        model.doStepPostprocessing(dt)
    self.time_step = self.time_step + 1
    self.time = self.time + dt
    self.dt = dt
1045
def doStep(self, dt):
    """
    Executes the iteration at a time step. doStep(dt) is called on every
    model until terminateIteration() returns True. The number of iteration
    steps performed is kept in self.iter.

    Pre- and postprocessing of the time step are not done here; see
    doStepPreprocessing and doStepPostprocessing.
    """
    self.iter = 0
    while not self.terminateIteration():
        if self.iter == 0:
            self.trace("iteration at %d-th time step %e starts" % (self.time_step + 1, self.time + dt))
        self.iter = self.iter + 1
        self.trace("iteration step %d" % (self.iter))
        for model in self.iterModels():
            model.doStep(dt)
    if self.iter > 0:
        self.trace("iteration at %d-th time step %e finalized." % (self.time_step + 1, self.time + dt))
1064
def run(self,check_pointing=None):
    """
    Runs the simulation by performing essentially::

        self.setUp()
        if self.time_step < 1:
            self.doInitialization()
            self.doInitialStep()
            self.doInitialPostprocessing()
        while not self.finalize():
            dt_new=self.getSafeTimeStepSize(self.dt)  # clamped, see below
            self.doStepPreprocessing(dt_new)
            self.doStep(dt_new)
            self.doStepPostprocessing(dt_new)
            # optional check point
        self.doFinalization()

    Before anything is executed, all models (including those of
    sub-simulations) are checked for missing link targets.

    Except for the very first time step, the new step size is restricted
    to the interval [self.dt/MAX_CHANGE_OF_DT, self.dt*MAX_CHANGE_OF_DT].

    If one of the models throws a C{FailedTimeStepError} exception, a
    new time step size is computed through getSafeTimeStepSize() and the
    time step is repeated.

    If one of the models throws a C{IterationDivergenceError} exception,
    the time step size is halved and the time step is repeated.

    In both cases the time integration is given up after
    C{Simulation.FAILED_TIME_STEPS_MAX} attempts.

    @param check_pointing: optional check point manager. If it is not
                           C{None}, its C{doDump()} method is called after
                           every completed time step, and C{self.writeXML()}
                           is called whenever C{doDump()} returns true.
    @raise MissingLink: if link targets are missing in the simulation.
    @raise NonPositiveStepSizeError: if a step size is not positive.
    @raise SimulationBreakDownError: if a time step could not be completed
                                     within C{FAILED_TIME_STEPS_MAX} attempts.
    """
    # check the completeness of the models:
    # first a list of all the models involved in the simulation including
    # sub-simulations is collected.
    missing=self.checkModels(self.getAllModels(), [])
    if len(missing)>0:
        msg=""
        for l in missing:
            # str() on both parts so an unusual location never masks MissingLink
            msg+="\n\t"+str(l[1])+" at "+str(l[0])
        raise MissingLink("link targets missing in the Simulation: %s"%msg)
    #==============================
    self.setUp()
    # the initial phase is only executed if no time step has been done yet
    # (presumably skipped when the simulation was restored from a restart)
    if self.time_step < 1:
        self.doInitialization()
        self.doInitialStep()
        self.doInitialPostprocessing()
    while not self.finalize():
        step_fail_counter=0
        iteration_fail_counter=0
        if self.time_step==0:
            dt_new=self.getSafeTimeStepSize(self.dt)
        else:
            # limit the change of the step size relative to the last step
            dt_new=min(max(self.getSafeTimeStepSize(self.dt),self.dt/self.MAX_CHANGE_OF_DT),self.dt*self.MAX_CHANGE_OF_DT)
        self.trace("%d. time step %e (step size %e.)" % (self.time_step+1,self.time+dt_new,dt_new))
        end_of_step=False
        while not end_of_step:
            end_of_step=True
            if not dt_new>0:
                raise NonPositiveStepSizeError("non-positive step size in step %d"%(self.time_step+1))
            try:
                self.doStepPreprocessing(dt_new)
                self.doStep(dt_new)
                self.doStepPostprocessing(dt_new)
            except IterationDivergenceError:
                dt_new*=0.5
                end_of_step=False
                iteration_fail_counter+=1
                if iteration_fail_counter>self.FAILED_TIME_STEPS_MAX:
                    raise SimulationBreakDownError("reduction of time step to achieve convergence failed after %s steps."%self.FAILED_TIME_STEPS_MAX)
                self.trace("Iteration failed. Time step is repeated with new step size %s."%dt_new)
            except FailedTimeStepError:
                dt_new=self.getSafeTimeStepSize(self.dt)
                end_of_step=False
                step_fail_counter+=1
                self.trace("Time step is repeated with new time step size %s."%dt_new)
                if step_fail_counter>self.FAILED_TIME_STEPS_MAX:
                    raise SimulationBreakDownError("Time integration is given up after %d attempts."%step_fail_counter)
        # the time step is completed: create a check point if requested
        if check_pointing is not None and check_pointing.doDump():
            self.trace("check point is created.")
            self.writeXML()
    self.doFinalization()
1142
1143
def toDom(self, esysxml, node):
    """
    C{toDom} method of Simulation class.

    Creates a C{Simulation} element, registers it for this object, adds
    one C{Component} element per model (carrying the position of the model
    as attribute C{rank}, starting at 0) into which the model writes itself,
    and appends the C{Simulation} element to C{node}.
    """
    sim_node = esysxml.createElement('Simulation')
    esysxml.registerLinkableObject(self, sim_node)
    rank = 0
    for model in self.iterModels():
        comp_node = esysxml.createElement('Component')
        comp_node.setAttribute('rank', str(rank))
        model.toDom(esysxml, comp_node)
        sim_node.appendChild(comp_node)
        rank += 1
    node.appendChild(sim_node)
1156
1157
def fromDom(cls, esysxml, node):
    """
    Creates a simulation from the C{Simulation} element C{node}.

    Each non-text child of C{node} is turned into an entry by
    C{esysxml.getComponent}. The entries are sorted with C{_comp}, and
    their second items (presumably the restored models, with the rank as
    first item) are passed to C{cls} together with C{esysxml.debug}. The
    new simulation is registered with C{esysxml} and returned.
    """
    ranked_models = []
    for child in node.childNodes:
        if not isinstance(child, minidom.Text):
            ranked_models.append(esysxml.getComponent(child))
    ranked_models.sort(_comp)
    models = [entry[1] for entry in ranked_models]
    sim = cls(models, debug=esysxml.debug)
    esysxml.registerLinkableObject(sim, node)
    return sim

fromDom = classmethod(fromDom)
1170
1171 def _comp(a,b):
1172 if a[0]<a[1]:
1173 return 1
1174 elif a[0]>a[1]:
1175 return -1
1176 else:
1177 return 0
1178
class IterationDivergenceError(Exception):
    """
    Raised when the iteration process at a time step does not converge.

    Repeating the time step with a smaller step size may still lead to
    convergence.
    """
1187
class FailedTimeStepError(Exception):
    """
    Raised when a time step fails because the chosen step size was too
    large.
    """
1194
class SimulationBreakDownError(Exception):
    """
    Raised when the simulation does not manage to progress in time.
    """
1201
class NonPositiveStepSizeError(Exception):
    """
    Raised when a time step size is zero or negative.
    """
1207
class MissingLink(Exception):
    """
    Raised when the target of a link is missing.
    """
1213
class DataSource(object):
    """
    Class for handling data sources, including local and remote files. This class is under development.
    """

    def __init__(self, uri="file.ext", fileformat="unknown"):
        """
        Initializes the data source.

        @param uri: location of the data
        @param fileformat: name of the format of the data
        """
        self.uri = uri
        self.fileformat = fileformat

    def toDom(self, esysxml, node):
        """
        C{toDom} method of DataSource. Creates a C{DataSource} element holding
        C{URI} and C{FileFormat} data nodes and appends it to C{node}.
        """
        ds = esysxml.createElement('DataSource')
        ds.appendChild(esysxml.createDataNode('URI', self.uri))
        ds.appendChild(esysxml.createDataNode('FileFormat', self.fileformat))
        node.appendChild(ds)

    def _getText(node, tag):
        """
        Returns the whitespace-stripped text of the first C{tag} element below
        C{node}. An element without text (e.g. C{<URI/>}, which is what an
        empty string is written as) gives an empty string instead of failing.
        """
        element = node.getElementsByTagName(tag)[0]
        # join all text/CDATA children rather than relying on firstChild,
        # which is None for an empty element
        text = "".join([child.data for child in element.childNodes
                        if isinstance(child, minidom.Text)])
        return str(text.strip())

    _getText = staticmethod(_getText)

    def fromDom(cls, esysxml, node):
        """
        Creates a data source from the C{DataSource} element C{node} by
        reading the texts of its C{URI} and C{FileFormat} elements.
        """
        uri = cls._getText(node, "URI")
        fileformat = cls._getText(node, "FileFormat")
        return cls(uri, fileformat)

    fromDom = classmethod(fromDom)

    def getLocalFileName(self):
        """
        Returns the name of the local file holding the data. Currently this
        is the URI itself.
        """
        return self.uri
1243
class RestartManager(object):
    """
    A restart manager which does two things: it decides when restart files
    have to be created (when C{doDump} returns true) and it manages the
    directories for restart files. The method C{getNewDumper} creates a new
    directory and returns its name.

    This restart manager decides to dump restart files every C{dump_step}
    calls of C{doDump} or if at least C{dump_time} seconds have elapsed since
    the last dump. The restart manager controls two directories for dumping
    restart data, namely for the current and the previous dump. This way the
    previous dump can be used for restart in the case the current dump failed.

    @cvar SEC: unit of seconds, for instance for 5*RestartManager.SEC to define 5 seconds.
    @cvar MIN: unit of minutes, for instance for 5*RestartManager.MIN to define 5 minutes.
    @cvar H: unit of hours, for instance for 5*RestartManager.H to define 5 hours.
    @cvar D: unit of days, for instance for 5*RestartManager.D to define 5 days.
    """
    SEC=1.
    MIN=60.
    H=3600.
    D=86400.

    def __init__(self,dump_time=1080., dump_step=None, dumper=None):
        """
        Initializes the RestartManager.

        @param dump_time: defines the minimum time interval in SEC between two
                          dumps (default 1080 seconds). If None, time is not
                          used as criterion.
        @param dump_step: defines the number of calls of doDump between two
                          dump events. If None, the call counter is not used
                          as criterion.
        @param dumper: defines the directory for dumping restart files.
                       Additionally the directories dumper+"_bkp" and
                       dumper+"_bkp2" are used. The directory is created by
                       C{getNewDumper} if it does not exist. If dumper is not
                       present, a directory name unique to the current process
                       within the current working directory is used.
        """
        self.__dump_time=dump_time
        self.__dump_step=dump_step
        self.__counter=0
        self.__saveMarker()
        if dumper is None:
            self.__dumper="restart"+str(os.getpid())
        else:
            self.__dumper=dumper
        self.__dumper_bkp=self.__dumper+"_bkp"
        self.__dumper_bkp2=self.__dumper+"_bkp2"
        self.__current_dumper=None

    def __saveMarker(self):
        """
        Records the current wall-clock time and call counter as the reference
        point for the next dump decision.
        """
        self.__last_restart_time=time.time()
        self.__last_restart_counter=self.__counter

    def getCurrentDumper(self):
        """
        Returns the name of the currently used dumper: the directory created
        by the last successful call of C{getNewDumper}, or after a failed call
        the directory to be used for restart. None before the first call.
        """
        return self.__current_dumper

    def doDump(self):
        """
        Returns True if the restart should be dumped. Use C{getNewDumper} to
        get the directory name to be used.

        With C{dump_step} set, this returns True on every C{dump_step}-th call
        (counted from the last dump). With C{dump_time} set, it returns True
        once at least C{dump_time} seconds have passed since the last dump
        (or since construction). If both are None, it always returns False.
        """
        self.__counter+=1
        out=False
        if self.__dump_step is not None:
            out = (self.__last_restart_counter + self.__dump_step) <= self.__counter
        if not out and self.__dump_time is not None:
            out = (self.__last_restart_time + self.__dump_time) <= time.time()
        if out: self.__saveMarker()
        return out

    def getNewDumper(self):
        """
        Creates a new directory to be used for dumping and returns its name.

        The directory of the previous dump is moved to dumper+"_bkp". The dump
        before that is parked in dumper+"_bkp2" while the directories are
        rotated, and is removed (with its content) once the new directory has
        been created.

        @raise RuntimeError: if a left-over back-up directory is in the way,
                             or if a directory cannot be renamed or created.
                             In the latter cases C{getCurrentDumper} names the
                             directory to be used for restart.
        """
        import shutil
        if os.access(self.__dumper_bkp,os.F_OK):
            if os.access(self.__dumper_bkp2, os.F_OK):
                raise RuntimeError("please remove %s."%self.__dumper_bkp2)
            try:
                os.rename(self.__dumper_bkp, self.__dumper_bkp2)
            except OSError:
                self.__current_dumper=self.__dumper
                raise RuntimeError("renaming back-up directory %s failed. Use %s for restart."%(self.__dumper_bkp,self.__dumper))
        if os.access(self.__dumper,os.F_OK):
            if os.access(self.__dumper_bkp, os.F_OK):
                raise RuntimeError("please remove %s."%self.__dumper_bkp)
            try:
                os.rename(self.__dumper, self.__dumper_bkp)
            except OSError:
                self.__current_dumper=self.__dumper_bkp2
                raise RuntimeError("moving directory %s to back-up failed. Use %s for restart."%(self.__dumper,self.__dumper_bkp2))
        try:
            os.mkdir(self.__dumper)
        except OSError:
            self.__current_dumper=self.__dumper_bkp
            raise RuntimeError("creating a new restart directory %s failed. Use %s for restart."%(self.__dumper,self.__dumper_bkp))
        # the oldest dump holds restart files, so os.rmdir would fail here
        if os.access(self.__dumper_bkp2, os.F_OK): shutil.rmtree(self.__dumper_bkp2)
        self.__current_dumper=self.__dumper
        return self.getCurrentDumper()
1335
1336
1337 # vim: expandtab shiftwidth=4:

Properties

Name Value
svn:eol-style native
svn:keywords Author Date Id Revision

  ViewVC Help
Powered by ViewVC 1.1.26