/[escript]/trunk/escript/py_src/modelframe.py
ViewVC logotype

Diff of /trunk/escript/py_src/modelframe.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 836 by gross, Mon Sep 4 22:37:25 2006 UTC revision 1312 by ksteube, Mon Sep 24 06:18:44 2007 UTC
# Line 1  Line 1 
1    #
2  # $Id$  # $Id$
3    #
4    #######################################################
5    #
6    #           Copyright 2003-2007 by ACceSS MNRF
7    #       Copyright 2007 by University of Queensland
8    #
9    #                http://esscc.uq.edu.au
10    #        Primary Business: Queensland, Australia
11    #  Licensed under the Open Software License version 3.0
12    #     http://www.opensource.org/licenses/osl-3.0.php
13    #
14    #######################################################
15    #
16    
17  """  """
18  Environment for implementing models in escript  Environment for implementing models in escript
# Line 24  __date__="$Date$" Line 38  __date__="$Date$"
38    
39  from types import StringType,IntType,FloatType,BooleanType,ListType,DictType  from types import StringType,IntType,FloatType,BooleanType,ListType,DictType
40  from sys import stdout  from sys import stdout
41    import numarray
42    import operator
43  import itertools  import itertools
44  # import modellib  temporarily removed!!!  import time
45    import os
46    
47  # import the 'set' module if it's not defined (python2.3/2.4 difference)  # import the 'set' module if it's not defined (python2.3/2.4 difference)
48  try:  try:
# Line 35  except NameError: Line 52  except NameError:
52    
53  from xml.dom import minidom  from xml.dom import minidom
54    
 def dataNode(document, tagName, data):  
     """  
     C{dataNode}s are the building blocks of the xml documents constructed in  
     this module.    
       
     @param document: the current xml document  
     @param tagName: the associated xml tag  
     @param data: the values in the tag  
     """  
     t = document.createTextNode(str(data))  
     n = document.createElement(tagName)  
     n.appendChild(t)  
     return n  
   
 def esysDoc():  
     """  
     Global method for creating an instance of an EsysXML document.  
     """  
     doc = minidom.Document()  
     esys = doc.createElement('ESys')  
     doc.appendChild(esys)  
     return doc, esys  
55    
56  def all(seq):  def all(seq):
57      for x in seq:      for x in seq:
# Line 70  def any(seq): Line 65  def any(seq):
65              return True              return True
66      return False      return False
67    
 LinkableObjectRegistry = {}  
   
 def registerLinkableObject(obj_id, o):  
     LinkableObjectRegistry[obj_id] = o  
   
 LinkRegistry = []  
   
 def registerLink(obj_id, l):  
     LinkRegistry.append((obj_id,l))  
   
 def parse(xml):  
     """  
     Generic parse method for EsysXML.  Without this, Links don't work.  
     """  
     global LinkRegistry, LinkableObjectRegistry  
     LinkRegistry = []  
     LinkableObjectRegistry = {}  
   
     doc = minidom.parseString(xml)  
     sim = getComponent(doc.firstChild)  
     for obj_id, link in LinkRegistry:  
         link.target = LinkableObjectRegistry[obj_id]  
   
     return sim  
   
68  def importName(modulename, name):  def importName(modulename, name):
69      """ Import a named object from a module in the context of this function,      """ Import a named object from a module in the context of this function,
70          which means you should use fully qualified module paths.          which means you should use fully qualified module paths.
           
71          Return None on failure.          Return None on failure.
72    
73          This function from: http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/52241          This function from: http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/52241
# Line 110  def importName(modulename, name): Line 79  def importName(modulename, name):
79      except KeyError:      except KeyError:
80          raise ImportError("Could not import %s from %s" % (name, modulename))          raise ImportError("Could not import %s from %s" % (name, modulename))
81    
82  def getComponent(doc):  class ESySXMLParser(object):
     """  
     Used to get components of Simualtions, Models.  
83      """      """
84      for node in doc.childNodes:      parser for ESysXML file
85                """
86          if isinstance(node, minidom.Element):      def __init__(self,xml, debug=False):
87              if node.tagName == 'Simulation':        self.__dom = minidom.parseString(xml)
88                  if node.getAttribute("type") == 'Simulation':        self.__linkable_object_registry= {}
89                      return Simulation.fromDom(node)        self.__link_registry=  []
90              if node.tagName == 'Model':        self.__esys=self.__dom.getElementsByTagName('ESys')[0]
91                  if (node.getAttribute("module")):        self.debug=debug
92                      model_module = node.getAttribute("module")    
93                      model_type = node.getAttribute("type")      def getClassPath(self, node):
94                      return importName(model_module, model_type).fromDom(node)          type = node.getAttribute("type")
95                  else:          if (node.getAttribute("module")):
96                      model_type = node.getAttribute("type")              module = node.getAttribute("module")
97                      model_subclasses = Model.__subclasses__()              return importName(module, type)
98                      for model in model_subclasses:          else:
99                          if model_type == model.__name__:              return importName("__main__", type)
                             return Model.fromDom(node)  
             if node.tagName == 'ParameterSet':  
                 parameter_type = node.getAttribute("type")  
                 return ParameterSet.fromDom(node)  
             raise "Invalid simulation type, %r" % node.getAttribute("type")  
           
100    
101      raise ValueError("No Simulation Found")      def setLinks(self):
102                        for obj_id, link in self.__link_registry:
103                link.target = self.__linkable_object_registry[obj_id]
104    
105        def parse(self):
106           """
107           parser method for EsysXML and returns the list of generating ParameterSets
108           """
109           found=[]
110           for node in self.__esys.childNodes:
111               if isinstance(node, minidom.Element):
112                   if node.tagName == 'Simulation':
113                            found.append(Simulation.fromDom(self, node))
114                   elif node.tagName == 'Model':
115                            found.append(self.getClassPath(node).fromDom(self, node))
116                   elif node.tagName == 'ParameterSet':
117                            found.append(self.getClassPath(node).fromDom(self, node))
118                   else:
119                      raise "Invalid type, %r" % node.getAttribute("type")
120           self.setLinks()
121           return found
122    
123        def registerLink(self,obj_id, link):
124            self.__link_registry.append((int(obj_id),link))
125    
126        def registerLinkableObject(self,obj, node):
127            id_str=node.getAttribute('id').strip()
128            if len(id_str)>0:
129               id=int(id_str)
130               if self.__linkable_object_registry.has_key(id):
131                   raise ValueError("Object id %s already exists."%id)
132               else:
133                   self.__linkable_object_registry[id]=obj
134    
135        def getComponent(self, node):
136           """
137           returns a single component + rank from a simulation
138           parser method for EsysXML and returns the list of generating ParameterSets
139           """
140           rank  = int(node.getAttribute("rank"))
141           for n in node.childNodes:
142               if isinstance(n, minidom.Element):
143                   if n.tagName == 'Simulation':
144                            return (rank, Simulation.fromDom(self, n))
145                   elif n.tagName == 'Model':
146                            return (rank, self.getClassPath(n).fromDom(self, n))
147                   elif n.tagName == 'ParameterSet':
148                            return (rank, self.getClassPath(n).fromDom(self, n))
149                   else:
150                     raise ValueError("illegal component type %s"%n.tagName)
151           raise ValueError("cannot resolve Component")
152    
153    class ESySXMLCreator(object):
154        """
155        creates an XML Dom representation
156        """
157        def __init__(self):
158           self.__dom=minidom.Document()
159           self.__esys =self.__dom.createElement('ESys')
160           self.__dom.appendChild(self.__esys)
161           self.__linkable_object_registry={}
162           self.__number_sequence = itertools.count(100)
163        def getRoot(self):
164           return self.__esys
165        def createElement(self,name):
166          return self.__dom.createElement(name)
167        def createTextNode(self,name):
168          return self.__dom.createTextNode(name)
169        def getElementById(self,name):
170          return self.__dom.getElementById(name)
171        def createDataNode(self, tagName, data):
172              """
173              C{createDataNode}s are the building blocks of the xml documents constructed in
174              this module.  
175        
176              @param tagName: the associated xml tag
177              @param data: the values in the tag
178              """
179              n = self.createElement(tagName)
180              n.appendChild(self.createTextNode(str(data)))
181              return n
182        def getLinkableObjectId(self, obj):
183            for id, o in self.__linkable_object_registry.items():
184                if o == obj: return id
185            id =self.__number_sequence.next()
186            self.__linkable_object_registry[id]=obj
187            return id
188            
189        def registerLinkableObject(self, obj, node):
190            """
191            returns a unique object id for object obj
192            """
193            id=self.getLinkableObjectId(obj)
194            node.setAttribute('id',str(id))
195            node.setIdAttribute("id")
196    
197        def includeTargets(self):
198            target_written=True
199            while target_written:
200                targetsList =self.__dom.getElementsByTagName('Target')
201                target_written=False
202                for element in targetsList:
203                   targetId = int(element.firstChild.nodeValue.strip())
204                   if self.getElementById(str(targetId)): continue
205                   targetObj = self.__linkable_object_registry[targetId]
206                   targetObj.toDom(self, self.__esys)
207                   target_written=True
208    
209        def toprettyxml(self):
210            self.includeTargets()
211            return self.__dom.toprettyxml()
212    
213  class Link:  class Link:
214      """      """
215      A Link makes an attribute of an object callable::      A Link makes an attribute of an object callable::
# Line 159  class Link: Line 229  class Link:
229          self.target = target          self.target = target
230          self.attribute = None          self.attribute = None
231          self.setAttributeName(attribute)          self.setAttributeName(attribute)
232    
233        def getTarget(self):
234            """
235            returns the target
236            """
237            return self.target
238        def getAttributeName(self):
239            """
240            returns the name of the attribute the link is pointing to
241            """
242            return self.attribute
243            
244      def setAttributeName(self,attribute):      def setAttributeName(self,attribute):
245          """          """
# Line 204  class Link: Line 285  class Link:
285          else:          else:
286              return out              return out
287    
288      def toDom(self, document, node):      def toDom(self, esysxml, node):
289          """          """
290          C{toDom} method of Link. Creates a Link node and appends it to the          C{toDom} method of Link. Creates a Link node and appends it to the
291      current XML document.      current XML esysxml.
292          """          """
293          link = document.createElement('Link')          link = esysxml.createElement('Link')
294          assert (self.target != None), ("Target was none, name was %r" % self.attribute)          assert (self.target != None), ("Target was none, name was %r" % self.attribute)
295          link.appendChild(dataNode(document, 'Target', self.target.id))          link.appendChild(esysxml.createDataNode('Target', esysxml.getLinkableObjectId(self.target)))
296          # this use of id will not work for purposes of being able to retrieve the intended          # this use of id will not work for purposes of being able to retrieve the intended
297          # target from the xml later. I need a better unique identifier.          # target from the xml later. I need a better unique identifier.
298          assert self.attribute, "You can't xmlify a Link without a target attribute"          assert self.attribute, "You can't xmlify a Link without a target attribute"
299          link.appendChild(dataNode(document, 'Attribute', self.attribute))          link.appendChild(esysxml.createDataNode('Attribute', self.attribute))
300          node.appendChild(link)          node.appendChild(link)
301    
302      def fromDom(cls, doc):      def fromDom(cls, esysxml, node):
303          targetid = doc.getElementsByTagName("Target")[0].firstChild.nodeValue.strip()          targetid = int(node.getElementsByTagName("Target")[0].firstChild.nodeValue.strip())
304          attribute = doc.getElementsByTagName("Attribute")[0].firstChild.nodeValue.strip()          attribute =str(node.getElementsByTagName("Attribute")[0].firstChild.nodeValue.strip())
305          l = cls(None, attribute)          l = cls(None, attribute)
306          registerLink(targetid, l)          esysxml.registerLink(targetid, l)
307          return l          return l
308    
309      fromDom = classmethod(fromDom)      fromDom = classmethod(fromDom)
310            
     def writeXML(self,ostream=stdout):  
         """  
         Writes an XML representation of self to the output stream ostream.  
         If ostream is nor present the standart output stream is used.  If  
         esysheader==True the esys XML header is written  
         """  
         print 'I got to the Link writeXML method'  
         document, rootnode = esysDoc()  
         self.toDom(document, rootnode)  
   
         ostream.write(document.toprettyxml())  
   
311  class LinkableObject(object):  class LinkableObject(object):
312      """      """
313      An object that allows to link its attributes to attributes of other objects      An object that allows to link its attributes to attributes of other objects
# Line 257  class LinkableObject(object): Line 326  class LinkableObject(object):
326      the return value of the call.      the return value of the call.
327      """      """
328        
     number_sequence = itertools.count(100)  
329            
330      def __init__(self, debug=False):      def __init__(self, id = None, debug=False):
331          """          """
332      Initializes LinkableObject so that we can operate on Links      Initializes LinkableObject so that we can operate on Links
333      """      """
334          self.debug = debug          self.debug = debug
335          self.__linked_attributes={}          self.__linked_attributes={}
336          self.id = self.number_sequence.next()          
         registerLinkableObject(self.id, self)  
   
337      def trace(self, msg):      def trace(self, msg):
338          """          """
339      If debugging is on, print the message, otherwise do nothing      If debugging is on, print the message, otherwise do nothing
# Line 364  class ParameterSet(LinkableObject): Line 430  class ParameterSet(LinkableObject):
430       - a ParameterSet object       - a ParameterSet object
431       - a Simulation object       - a Simulation object
432       - a Model object       - a Model object
433       - any other object (not considered by writeESySXML and writeXML)       - a numarray object
434             - a list of booleans
435            - any other object (not considered by writeESySXML and writeXML)
436            
437      Example how to create an ESySParameters object::      Example how to create an ESySParameters object::
438            
# Line 393  class ParameterSet(LinkableObject): Line 461  class ParameterSet(LinkableObject):
461          self.declareParameters(parameters)          self.declareParameters(parameters)
462    
463      def __repr__(self):      def __repr__(self):
464          return "<%s %r>" % (self.__class__.__name__,          return "<%s %d>"%(self.__class__.__name__,id(self))
                             [(p, getattr(self, p, None)) for p in self.parameters])  
465            
466      def declareParameter(self,**parameters):      def declareParameter(self,**parameters):
467          """          """
# Line 416  class ParameterSet(LinkableObject): Line 483  class ParameterSet(LinkableObject):
483              setattr(self,prm,value)              setattr(self,prm,value)
484              self.parameters.add(prm)              self.parameters.add(prm)
485    
             self.trace("parameter %s has been declared."%prm)  
   
486      def releaseParameters(self,name):      def releaseParameters(self,name):
487          """          """
488      Removes parameter name from the paramameters.      Removes parameter name from the paramameters.
# Line 425  class ParameterSet(LinkableObject): Line 490  class ParameterSet(LinkableObject):
490          if self.isParameter(name):          if self.isParameter(name):
491              self.parameters.remove(name)              self.parameters.remove(name)
492              self.trace("parameter %s has been removed."%name)              self.trace("parameter %s has been removed."%name)
493    
494        def checkLinkTargets(self, models, hash):
495            """
496            returns a set of tuples ("<self>(<name>)", <target model>) if the parameter <name> is linked to model <target model>
497            but <target model> is not in the list models. If the a parameter is linked to another parameter set which is not in the hash list
498            the parameter set is checked for its models. hash gives the call history.
499            """
500            out=set()
501            for name, value in self:
502                if isinstance(value, Link):
503                   m=value.getTarget()
504                   if isinstance(m, Model):
505                       if not m in models: out.add( (str(self)+"("+name+")",m) )
506                   elif isinstance(m, ParameterSet) and not m in hash:
507                         out|=set( [ (str(self)+"("+name+")."+f[0],f[1]) for f in m.checkLinkTargets(models, hash+[ self ] ) ] )
508            return out
509            
510      def __iter__(self):      def __iter__(self):
511          """          """
# Line 457  class ParameterSet(LinkableObject): Line 538  class ParameterSet(LinkableObject):
538          except:          except:
539              pass              pass
540    
541      def toDom(self, document, node):      def toDom(self, esysxml, node):
542          """          """
543      C{toDom} method of ParameterSet class.      C{toDom} method of Model class
544      """      """
545          pset = document.createElement('ParameterSet')          pset = esysxml.createElement('ParameterSet')
546            pset.setAttribute('type', self.__class__.__name__)
547            pset.setAttribute('module', self.__class__.__module__)
548            esysxml.registerLinkableObject(self, pset)
549            self._parametersToDom(esysxml, pset)
550          node.appendChild(pset)          node.appendChild(pset)
         self._parametersToDom(document, pset)  
551    
552      def _parametersToDom(self, document, node):      def _parametersToDom(self, esysxml, node):
         node.setAttribute('id', str(self.id))  
         node.setIdAttribute("id")  
553          for name,value in self:          for name,value in self:
554              param = document.createElement('Parameter')              # convert list to numarray when possible:
555              param.setAttribute('type', value.__class__.__name__)              if isinstance (value, list):
556                    elem_type=-1
557                    for i in value:
558                        if isinstance(i, bool):
559                            elem_type = max(elem_type,0)
560                        elif isinstance(i, int):
561                            elem_type = max(elem_type,1)
562                        elif isinstance(i, float):
563                            elem_type = max(elem_type,2)
564                    if elem_type == 0: value = numarray.array(value,numarray.Bool)
565                    if elem_type == 1: value = numarray.array(value,numarray.Int)
566                    if elem_type == 2: value = numarray.array(value,numarray.Float)
567    
568              param.appendChild(dataNode(document, 'Name', name))              param = esysxml.createElement('Parameter')
569                param.setAttribute('type', value.__class__.__name__)
570    
571              val = document.createElement('Value')              param.appendChild(esysxml.createDataNode('Name', name))
572    
573              if isinstance(value,ParameterSet):              val = esysxml.createElement('Value')
574                  value.toDom(document, val)              if isinstance(value,(ParameterSet,Link,DataSource)):
575                    value.toDom(esysxml, val)
576                  param.appendChild(val)                  param.appendChild(val)
577              elif isinstance(value, Link):              elif isinstance(value, numarray.NumArray):
578                  value.toDom(document, val)                  shape = value.getshape()
579                    if isinstance(shape, tuple):
580                        size = reduce(operator.mul, shape)
581                        shape = ' '.join(map(str, shape))
582                    else:
583                        size = shape
584                        shape = str(shape)
585    
586                    arraytype = value.type()
587                    if isinstance(arraytype, numarray.BooleanType):
588                          arraytype_str="Bool"
589                    elif isinstance(arraytype, numarray.IntegralType):
590                          arraytype_str="Int"
591                    elif isinstance(arraytype, numarray.FloatingType):
592                          arraytype_str="Float"
593                    elif isinstance(arraytype, numarray.ComplexType):
594                          arraytype_str="Complex"
595                    else:
596                          arraytype_str=str(arraytype)
597                    numarrayElement = esysxml.createElement('NumArray')
598                    numarrayElement.appendChild(esysxml.createDataNode('ArrayType', arraytype_str))
599                    numarrayElement.appendChild(esysxml.createDataNode('Shape', shape))
600                    numarrayElement.appendChild(esysxml.createDataNode('Data', ' '.join(
601                        [str(x) for x in numarray.reshape(value, size)])))
602                    val.appendChild(numarrayElement)
603                  param.appendChild(val)                  param.appendChild(val)
604              elif isinstance(value,StringType):              elif isinstance(value, list):
605                  param.appendChild(dataNode(document, 'Value', value))                  param.appendChild(esysxml.createDataNode('Value', ' '.join([str(x) for x in value]) ))
606                elif isinstance(value, (str, bool, int, float, type(None))):
607                    param.appendChild(esysxml.createDataNode('Value', str(value)))
608                elif isinstance(value, dict):
609                     dic = esysxml.createElement('dictionary')
610                     if len(value.keys())>0:
611                         dic.setAttribute('key_type', value.keys()[0].__class__.__name__)
612                         dic.setAttribute('value_type', value[value.keys()[0]].__class__.__name__)
613                     for k,v in value.items():
614                        i=esysxml.createElement('item')
615                        i.appendChild(esysxml.createDataNode('key', k))
616                        i.appendChild(esysxml.createDataNode('value', v))
617                        dic.appendChild(i)
618                     param.appendChild(dic)
619              else:              else:
620                  param.appendChild(dataNode(document, 'Value', str(value)))                  raise ValueError("cannot serialize %s type to XML."%str(value.__class__))
621    
622              node.appendChild(param)              node.appendChild(param)
623    
624      def fromDom(cls, doc):      def fromDom(cls, esysxml, node):
   
625          # Define a host of helper functions to assist us.          # Define a host of helper functions to assist us.
626          def _children(node):          def _children(node):
627              """              """
628              Remove the empty nodes from the children of this node.              Remove the empty nodes from the children of this node.
629              """              """
630              return [x for x in node.childNodes              ret = []
631                      if not isinstance(x, minidom.Text) or x.nodeValue.strip()]              for x in node.childNodes:
632                    if isinstance(x, minidom.Text):
633                        if x.nodeValue.strip():
634                            ret.append(x)
635                    else:
636                        ret.append(x)
637                return ret
638    
639          def _floatfromValue(doc):          def _floatfromValue(esysxml, node):
640              return float(doc.nodeValue.strip())              return float(node.nodeValue.strip())
641    
642          def _stringfromValue(doc):          def _stringfromValue(esysxml, node):
643              return str(doc.nodeValue.strip())              return str(node.nodeValue.strip())
644                
645          def _intfromValue(doc):          def _intfromValue(esysxml, node):
646              return int(doc.nodeValue.strip())              return int(node.nodeValue.strip())
647    
648          def _boolfromValue(doc):          def _boolfromValue(esysxml, node):
649              return bool(doc.nodeValue.strip())              return _boolfromstring(node.nodeValue.strip())
650    
651          def _nonefromValue(doc):          def _nonefromValue(esysxml, node):
652              return None              return None
653          
654            def _numarrayfromValue(esysxml, node):
655                for node in _children(node):
656                    if node.tagName == 'ArrayType':
657                        arraytype = node.firstChild.nodeValue.strip()
658                    if node.tagName == 'Shape':
659                        shape = node.firstChild.nodeValue.strip()
660                        shape = [int(x) for x in shape.split()]
661                    if node.tagName == 'Data':
662                        data = node.firstChild.nodeValue.strip()
663                        data = [float(x) for x in data.split()]
664                return numarray.reshape(numarray.array(data, type=getattr(numarray, arraytype)),
665                                        shape)
666          
667            def _listfromValue(esysxml, node):
668                return [x for x in node.nodeValue.split()]
669    
670            def _boolfromstring(s):
671                if s == 'True':
672                    return True
673                else:
674                    return False
675          # Mapping from text types in the xml to methods used to process trees of that type          # Mapping from text types in the xml to methods used to process trees of that type
676          ptypemap = {"Simulation": Simulation.fromDom,          ptypemap = {"Simulation": Simulation.fromDom,
677                      "Model":Model.fromDom,                      "Model":Model.fromDom,
678                      "ParameterSet":ParameterSet.fromDom,                      "ParameterSet":ParameterSet.fromDom,
679                      "Link":Link.fromDom,                      "Link":Link.fromDom,
680                        "DataSource":DataSource.fromDom,
681                      "float":_floatfromValue,                      "float":_floatfromValue,
682                      "int":_intfromValue,                      "int":_intfromValue,
683                      "str":_stringfromValue,                      "str":_stringfromValue,
684                      "bool":_boolfromValue,                      "bool":_boolfromValue,
685                      "NoneType":_nonefromValue                      "list":_listfromValue,
686                        "NumArray":_numarrayfromValue,
687                        "NoneType":_nonefromValue,
688                      }                      }
689    
 #        print doc.toxml()  
   
690          parameters = {}          parameters = {}
691          for node in _children(doc):          for n in _children(node):
692              ptype = node.getAttribute("type")              ptype = n.getAttribute("type")
693                if not ptypemap.has_key(ptype):
694                   raise KeyError("cannot handle parameter type %s."%ptype)
695    
696              pname = pvalue = None              pname = pvalue = None
697              for childnode in _children(node):              for childnode in _children(n):
   
698                  if childnode.tagName == "Name":                  if childnode.tagName == "Name":
699                      pname = childnode.firstChild.nodeValue.strip()                      pname = childnode.firstChild.nodeValue.strip()
700    
701                  if childnode.tagName == "Value":                  if childnode.tagName == "Value":
702                      nodes = _children(childnode)                      nodes = _children(childnode)
703                      pvalue = ptypemap[ptype](nodes[0])                      pvalue = ptypemap[ptype](esysxml, nodes[0])
704    
705              parameters[pname] = pvalue              parameters[pname] = pvalue
706    
707          # Create the instance of ParameterSet          # Create the instance of ParameterSet
708          o = cls()          try:
709               o = cls(debug=esysxml.debug)
710            except TypeError, inst:
711               print inst.args[0]
712               if inst.args[0]=="__init__() got an unexpected keyword argument 'debug'":
713                  raise TypeError("The Model class %s __init__ needs to have argument 'debug'.")
714               else:
715                  raise inst
716          o.declareParameters(parameters)          o.declareParameters(parameters)
717          registerLinkableObject(doc.getAttribute("id"), o)          esysxml.registerLinkableObject(o, node)
718          return o          return o
719            
720      fromDom = classmethod(fromDom)      fromDom = classmethod(fromDom)
721        
722      def writeXML(self,ostream=stdout):      def writeXML(self,ostream=stdout):
723          """          """
724      Writes the object as an XML object into an output stream.      Writes the object as an XML object into an output stream.
725      """      """
726          # ParameterSet(d) with d[Name]=Value          esysxml=ESySXMLCreator()
727          document, node = esysDoc()          self.toDom(esysxml, esysxml.getRoot())
728          self.toDom(document, node)          ostream.write(esysxml.toprettyxml())
729          ostream.write(document.toprettyxml())      
   
730  class Model(ParameterSet):  class Model(ParameterSet):
731      """      """
732      A Model object represents a processess marching over time until a      A Model object represents a processess marching over time until a
733      finalizing condition is fullfilled. At each time step an iterative      finalizing condition is fullfilled. At each time step an iterative
734      process can be performed and the time step size can be controlled. A      process can be performed and the time step size can be controlled. A
735      Model has the following work flow::      Model has the following work flow::
736              
737            doInitialization()            doInitialization()
738              while not terminateInitialIteration(): doInitializationiStep()
739              doInitialPostprocessing()
740            while not finalize():            while not finalize():
741                 dt=getSafeTimeStepSize(dt)                 dt=getSafeTimeStepSize(dt)
742                 doStepPreprocessing(dt)                 doStepPreprocessing(dt)
# Line 585  class Model(ParameterSet): Line 753  class Model(ParameterSet):
753    
754      UNDEF_DT=1.e300      UNDEF_DT=1.e300
755    
756      def __init__(self,parameters=[],**kwarg):      def __init__(self,parameters=[],**kwargs):
757          """          """
758      Creates a model.      Creates a model.
759    
760          Just calls the parent constructor.          Just calls the parent constructor.
761          """          """
762          ParameterSet.__init__(self, parameters=parameters,**kwarg)          ParameterSet.__init__(self, parameters=parameters,**kwargs)
763    
764      def __str__(self):      def __str__(self):
765         return "<%s %d>"%(self.__class__,id(self))         return "<%s %d>"%(self.__class__.__name__,id(self))
766    
767      def toDom(self, document, node):  
768        def setUp(self):
769          """          """
770      C{toDom} method of Model class          Sets up the model.
771      """  
772          pset = document.createElement('Model')          This function may be overwritten.
773          pset.setAttribute('type', self.__class__.__name__)          """
774          if not self.__class__.__module__.startswith('esys.escript'):          pass
             pset.setAttribute('module', self.__class__.__module__)  
         node.appendChild(pset)  
         self._parametersToDom(document, pset)  
775    
776      def doInitialization(self):      def doInitialization(self):
777          """          """
778      Initializes the time stepping scheme.        Initializes the time stepping scheme. This method is not called in case of a restart.
779            
780      This function may be overwritten.      This function may be overwritten.
781      """      """
782          pass          pass
783        def doInitialStep(self):
784            """
785        performs an iteration step in the initialization phase. This method is not called in case of a restart.
786    
787        This function may be overwritten.
788        """
789            pass
790    
791        def terminateInitialIteration(self):
792            """
793        Returns True if iteration at the inital phase is terminated.
794        """
795            return True
796    
797        def doInitialPostprocessing(self):
798            """
799        finalises the initialization iteration process. This method is not called in case of a restart.
800    
801        This function may be overwritten.
802        """
803            pass
804            
805      def getSafeTimeStepSize(self,dt):      def getSafeTimeStepSize(self,dt):
806          """          """
# Line 664  class Model(ParameterSet): Line 851  class Model(ParameterSet):
851      Returns True if iteration on a time step is terminated.      Returns True if iteration on a time step is terminated.
852      """      """
853          return True          return True
854    
855                
856      def doStepPostprocessing(self,dt):      def doStepPostprocessing(self,dt):
857          """          """
858      Finalalizes the time step.      finalises the time step.
859    
860          dt is the currently used time step size.          dt is the currently used time step size.
861    
862          This function may be overwritten.          This function may be overwritten.
863      """      """
864          pass          pass
       
     def writeXML(self, ostream=stdout):  
         document, node = esysDoc()  
         self.toDom(document, node)  
         ostream.write(document.toprettyxml())  
       
865    
866        def toDom(self, esysxml, node):
867            """
868        C{toDom} method of Model class
869        """
870            pset = esysxml.createElement('Model')
871            pset.setAttribute('type', self.__class__.__name__)
872            pset.setAttribute('module', self.__class__.__module__)
873            esysxml.registerLinkableObject(self, pset)
874            node.appendChild(pset)
875            self._parametersToDom(esysxml, pset)
876        
877  class Simulation(Model):  class Simulation(Model):
878      """      """
879      A Simulation object is special Model which runs a sequence of Models.      A Simulation object is special Model which runs a sequence of Models.
# Line 699  class Simulation(Model): Line 892  class Simulation(Model):
892          """          """
893      Initiates a simulation from a list of models.      Initiates a simulation from a list of models.
894      """      """
895          Model.__init__(self, **kwargs)          super(Simulation, self).__init__(**kwargs)
896            self.declareParameter(time=0.,
897                                  time_step=0,
898                                  dt = self.UNDEF_DT)
899            for m in models:
900                if not isinstance(m, Model):
901                     raise TypeError("%s is not a subclass of Model."%m)
902          self.__models=[]          self.__models=[]
           
903          for i in range(len(models)):          for i in range(len(models)):
904              self[i] = models[i]              self[i] = models[i]
905                            
# Line 746  class Simulation(Model): Line 944  class Simulation(Model):
944      """      """
945          return len(self.__models)          return len(self.__models)
946    
947      def toDom(self, document, node):      def getAllModels(self):
948          """          """
949      C{toDom} method of Simulation class.          returns a list of all models used in the Simulation including subsimulations
950      """          """
951          simulation = document.createElement('Simulation')          out=[]
952          simulation.setAttribute('type', self.__class__.__name__)          for m in self.iterModels():
953                if isinstance(m, Simulation):
954          for rank, sim in enumerate(self.iterModels()):                 out+=m.getAllModels()
955              component = document.createElement('Component')              else:
956              component.setAttribute('rank', str(rank))                 out.append(m)
957            return list(set(out))
             sim.toDom(document, component)  
   
             simulation.appendChild(component)  
   
         node.appendChild(simulation)  
958    
959      def writeXML(self,ostream=stdout):      def checkModels(self, models, hash):
960          """          """
961      Writes the object as an XML object into an output stream.          returns a list of (model,parameter, target model ) if the the parameter of model
962      """          is linking to the target_model which is not in list of models.
963          document, rootnode = esysDoc()          """
964          self.toDom(document, rootnode)          out=self.checkLinkTargets(models, hash + [self])
965          targetsList = document.getElementsByTagName('Target')          for m in self.iterModels():
966                        if isinstance(m, Simulation):
967          for element in targetsList:                   out|=m.checkModels(models, hash)
968              targetId = int(element.firstChild.nodeValue.strip())              else:
969              if document.getElementById(str(targetId)):                   out|=m.checkLinkTargets(models, hash + [self])
970                  continue          return set( [ (str(self)+"."+f[0],f[1]) for f in out ] )
971              targetObj = LinkableObjectRegistry[targetId]  
             targetObj.toDom(document, rootnode)  
         ostream.write(document.toprettyxml())  
972            
973      def getSafeTimeStepSize(self,dt):      def getSafeTimeStepSize(self,dt):
974          """          """
# Line 786  class Simulation(Model): Line 977  class Simulation(Model):
977          This is the minimum over the time step sizes of all models.          This is the minimum over the time step sizes of all models.
978      """      """
979          out=min([o.getSafeTimeStepSize(dt) for o in self.iterModels()])          out=min([o.getSafeTimeStepSize(dt) for o in self.iterModels()])
         #print "%s: safe step size is %e."%(str(self),out)  
980          return out          return out
981    
982        def setUp(self):
983            """
984            performs the setup for all models
985            """
986            for o in self.iterModels():
987                 o.setUp()
988            
989      def doInitialization(self):      def doInitialization(self):
990          """          """
991      Initializes all models.      Initializes all models.
992      """      """
         self.n=0  
         self.tn=0.  
993          for o in self.iterModels():          for o in self.iterModels():
994              o.doInitialization()               o.doInitialization()
995            def doInitialStep(self):
996            """
997        performs an iteration step in the initialization step for all models
998        """
999            iter=0
1000            while not self.terminateInitialIteration():
1001                if iter==0: self.trace("iteration for initialization starts")
1002                iter+=1
1003                self.trace("iteration step %d"%(iter))
1004                for o in self.iterModels():
1005                     o.doInitialStep()
1006                if iter>self.MAX_ITER_STEPS:
1007                     raise IterationDivergenceError("initial iteration did not converge after %s steps."%iter)
1008            self.trace("Initialization finalized after %s iteration steps."%iter)
1009    
1010        def doInitialPostprocessing(self):
1011            """
1012        finalises the initialization iteration process for all models.
1013        """
1014            for o in self.iterModels():
1015                o.doInitialPostprocessing()
1016      def finalize(self):      def finalize(self):
1017          """          """
1018      Returns True if any of the models is to be finalized.      Returns True if any of the models is to be finalized.
# Line 806  class Simulation(Model): Line 1021  class Simulation(Model):
1021                
1022      def doFinalization(self):      def doFinalization(self):
1023          """          """
1024      Finalalizes the time stepping for all models.      finalises the time stepping for all models.
1025      """      """
1026          for i in self.iterModels(): i.doFinalization()          for i in self.iterModels(): i.doFinalization()
1027          self.trace("end of time integation.")          self.trace("end of time integation.")
# Line 824  class Simulation(Model): Line 1039  class Simulation(Model):
1039      """      """
1040          out=all([o.terminateIteration() for o in self.iterModels()])          out=all([o.terminateIteration() for o in self.iterModels()])
1041          return out          return out
1042    
1043        def terminateInitialIteration(self):
1044            """
1045        Returns True if all initial iterations for all models are terminated.
1046        """
1047            out=all([o.terminateInitialIteration() for o in self.iterModels()])
1048            return out
1049                
1050      def doStepPostprocessing(self,dt):      def doStepPostprocessing(self,dt):
1051          """          """
1052      Finalalizes the iteration process for all models.      finalises the iteration process for all models.
1053      """      """
1054          for o in self.iterModels():          for o in self.iterModels():
1055              o.doStepPostprocessing(dt)              o.doStepPostprocessing(dt)
1056          self.n+=1          self.time_step+=1
1057          self.tn+=dt          self.time+=dt
1058            self.dt=dt
1059            
1060      def doStep(self,dt):      def doStep(self,dt):
1061          """          """
# Line 846  class Simulation(Model): Line 1069  class Simulation(Model):
1069          """          """
1070          self.iter=0          self.iter=0
1071          while not self.terminateIteration():          while not self.terminateIteration():
1072              if self.iter==0: self.trace("iteration at %d-th time step %e starts"%(self.n+1,self.tn+dt))              if self.iter==0: self.trace("iteration at %d-th time step %e starts"%(self.time_step+1,self.time+dt))
1073              self.iter+=1              self.iter+=1
1074              self.trace("iteration step %d"%(self.iter))              self.trace("iteration step %d"%(self.iter))
1075              for o in self.iterModels():              for o in self.iterModels():
1076                    o.doStep(dt)                    o.doStep(dt)
1077          if self.iter>0: self.trace("iteration at %d-th time step %e finalized."%(self.n+1,self.tn+dt))          if self.iter>0: self.trace("iteration at %d-th time step %e finalized."%(self.time_step+1,self.time+dt))
1078    
1079      def run(self,check_point=None):      def run(self,check_pointing=None):
1080          """          """
1081      Run the simulation by performing essentially::      Run the simulation by performing essentially::
1082            
1083          self.doInitialization()              self.setUp()
1084                if not restart:
1085                self.doInitialization()
1086                    while not self.terminateInitialIteration(): self.doInitialStep()
1087                    self.doInitialPostprocessing()
1088          while not self.finalize():          while not self.finalize():
1089              dt=self.getSafeTimeStepSize()              dt=self.getSafeTimeStepSize()
1090              self.doStep(dt)                  self.doStepPreprocessing(dt_new)
1091              if n%check_point==0:                  self.doStep(dt_new)
1092              self.writeXML()                  self.doStepPostprocessing(dt_new)
1093          self.doFinalization()          self.doFinalization()
1094    
1095          If one of the models in throws a C{FailedTimeStepError} exception a          If one of the models in throws a C{FailedTimeStepError} exception a
# Line 875  class Simulation(Model): Line 1102  class Simulation(Model):
1102          In both cases the time integration is given up after          In both cases the time integration is given up after
1103      C{Simulation.FAILED_TIME_STEPS_MAX} attempts.      C{Simulation.FAILED_TIME_STEPS_MAX} attempts.
1104          """          """
1105          dt=self.UNDEF_DT          # check the completness of the models:
1106          self.doInitialization()          # first a list of all the models involved in the simulation including subsimulations:
1107            #
1108            missing=self.checkModels(self.getAllModels(), [])
1109            if len(missing)>0:
1110                msg=""
1111                for l in missing:
1112                     msg+="\n\t"+str(l[1])+" at "+l[0]
1113                raise MissingLink("link targets missing in the Simulation: %s"%msg)
1114            #==============================
1115            self.setUp()
1116            if self.time_step < 1:
1117               self.doInitialization()
1118               self.doInitialStep()
1119               self.doInitialPostprocessing()
1120          while not self.finalize():          while not self.finalize():
1121              step_fail_counter=0              step_fail_counter=0
1122              iteration_fail_counter=0              iteration_fail_counter=0
1123              if self.n==0:              if self.time_step==0:
1124                  dt_new=self.getSafeTimeStepSize(dt)                  dt_new=self.getSafeTimeStepSize(self.dt)
1125              else:              else:
1126                  dt_new=min(max(self.getSafeTimeStepSize(dt),dt*self.MAX_CHANGE_OF_DT),dt*self.MAX_CHANGE_OF_DT)                  dt_new=min(max(self.getSafeTimeStepSize(self.dt),self.dt/self.MAX_CHANGE_OF_DT),self.dt*self.MAX_CHANGE_OF_DT)
1127              self.trace("%d. time step %e (step size %e.)" % (self.n+1,self.tn+dt_new,dt_new))              self.trace("%d. time step %e (step size %e.)" % (self.time_step+1,self.time+dt_new,dt_new))
1128              end_of_step=False              end_of_step=False
1129              while not end_of_step:              while not end_of_step:
1130                 end_of_step=True                 end_of_step=True
1131                 if not dt_new>0:                 if not dt_new>0:
1132                    raise NonPositiveStepSizeError("non-positive step size in step %d"%(self.n+1))                    raise NonPositiveStepSizeError("non-positive step size in step %d"%(self.time_step+1))
1133                 try:                 try:
1134                    self.doStepPreprocessing(dt_new)                    self.doStepPreprocessing(dt_new)
1135                    self.doStep(dt_new)                    self.doStep(dt_new)
# Line 902  class Simulation(Model): Line 1142  class Simulation(Model):
1142                             raise SimulationBreakDownError("reduction of time step to achieve convergence failed after %s steps."%self.FAILED_TIME_STEPS_MAX)                             raise SimulationBreakDownError("reduction of time step to achieve convergence failed after %s steps."%self.FAILED_TIME_STEPS_MAX)
1143                    self.trace("Iteration failed. Time step is repeated with new step size %s."%dt_new)                    self.trace("Iteration failed. Time step is repeated with new step size %s."%dt_new)
1144                 except FailedTimeStepError:                 except FailedTimeStepError:
1145                    dt_new=self.getSafeTimeStepSize(dt)                    dt_new=self.getSafeTimeStepSize(self.dt)
1146                    end_of_step=False                    end_of_step=False
1147                    step_fail_counter+=1                    step_fail_counter+=1
1148                    self.trace("Time step is repeated with new time step size %s."%dt_new)                    self.trace("Time step is repeated with new time step size %s."%dt_new)
1149                    if step_fail_counter>self.FAILED_TIME_STEPS_MAX:                    if step_fail_counter>self.FAILED_TIME_STEPS_MAX:
1150                          raise SimulationBreakDownError("Time integration is given up after %d attempts."%step_fail_counter)                          raise SimulationBreakDownError("Time integration is given up after %d attempts."%step_fail_counter)
1151              dt=dt_new              if not check_pointing==None:
1152              if not check_point==None:                 if check_pointing.doDump():
                 if n%check_point==0:  
1153                      self.trace("check point is created.")                      self.trace("check point is created.")
1154                      self.writeXML()                      self.writeXML()
1155          self.doFinalization()          self.doFinalization()
1156    
     def fromDom(cls, doc):  
         sims = []  
         for node in doc.childNodes:  
             if isinstance(node, minidom.Text):  
                 continue  
1157    
1158              sims.append(getComponent(node))      def toDom(self, esysxml, node):
1159            """
1160        C{toDom} method of Simulation class.
1161        """
1162            simulation = esysxml.createElement('Simulation')
1163            esysxml.registerLinkableObject(self, simulation)
1164            for rank, sim in enumerate(self.iterModels()):
1165                component = esysxml.createElement('Component')
1166                component.setAttribute('rank', str(rank))
1167                sim.toDom(esysxml, component)
1168                simulation.appendChild(component)
1169            node.appendChild(simulation)
1170    
1171    
1172          return cls(sims)      def fromDom(cls, esysxml, node):
1173            sims = []
1174            for n in node.childNodes:
1175                if isinstance(n, minidom.Text):
1176                    continue
1177                sims.append(esysxml.getComponent(n))
1178            sims.sort(_comp)
1179            sim=cls([s[1] for s in sims], debug=esysxml.debug)
1180            esysxml.registerLinkableObject(sim, node)
1181            return sim
1182    
1183      fromDom = classmethod(fromDom)      fromDom = classmethod(fromDom)
1184    
1185    def _comp(a,b):
1186        if a[0]<a[1]:
1187          return 1
1188        elif a[0]>a[1]:
1189          return -1
1190        else:
1191          return 0
1192    
1193  class IterationDivergenceError(Exception):  class IterationDivergenceError(Exception):
1194      """      """
# Line 957  class NonPositiveStepSizeError(Exception Line 1219  class NonPositiveStepSizeError(Exception
1219      """      """
1220      pass      pass
1221    
1222    class MissingLink(Exception):
1223        """
1224        Exception thrown when a link is missing
1225        """
1226        pass
1227    
1228    class DataSource(object):
1229        """
1230        Class for handling data sources, including local and remote files. This class is under development.
1231        """
1232    
1233        def __init__(self, uri="file.ext", fileformat="unknown"):
1234            self.uri = uri
1235            self.fileformat = fileformat
1236    
1237        def toDom(self, esysxml, node):
1238            """
1239            C{toDom} method of DataSource. Creates a DataSource node and appends it to the
1240        current XML esysxml.
1241            """
1242            ds = esysxml.createElement('DataSource')
1243            ds.appendChild(esysxml.createDataNode('URI', self.uri))
1244            ds.appendChild(esysxml.createDataNode('FileFormat', self.fileformat))
1245            node.appendChild(ds)
1246    
1247        def fromDom(cls, esysxml, node):
1248            uri= str(node.getElementsByTagName("URI")[0].firstChild.nodeValue.strip())
1249            fileformat= str(node.getElementsByTagName("FileFormat")[0].firstChild.nodeValue.strip())
1250            ds = cls(uri, fileformat)
1251            return ds
1252    
1253        def getLocalFileName(self):
1254            return self.uri
1255    
1256        fromDom = classmethod(fromDom)
1257    
1258    class RestartManager(object):
1259         """
1260         A restart manager which does two things: it decides when restart files have created (when doDump returns true) and
1261         manages directories for restart files. The method getNewDumper creates a new directory and returns its name.
1262        
1263         This restart manager will decide to dump restart files if every dump_step calls of doDump or
1264         if more than dump_time since the last dump has elapsed. The restart manager controls two directories for dumping restart data, namely
1265         for the current and previous dump. This way the previous dump can be used for restart in the case the current dump failed.
1266    
1267         @cvar SEC: unit of seconds, for instance for 5*RestartManager.SEC to define 5 seconds.
1268         @cvar MIN: unit of minutes, for instance for 5*RestartManager.MIN to define 5 minutes.
1269         @cvar H: unit of hours, for instance for 5*RestartManager.H to define 5 hours.
1270         @cvar D: unit of days, for instance for 5*RestartManager.D to define 5 days.
1271         """
1272         SEC=1.
1273         MIN=60.
1274         H=360.
1275         D=8640.
1276         def __init__(self,dump_time=1080., dump_step=None, dumper=None):
1277             """
1278             initializes the RestartManager.
1279    
1280             @param dump_time: defines the minimum time interval in SEC between to dumps. If None, time is not used as criterion.
1281             @param dump_step: defines the number of calls of doDump between to dump events. If None, the call counter is not used as criterion.
1282             @param dumper: defines the directory for dumping restart files. Additionally the directories dumper+"_bkp" and dumper+"_bkp2" are used.
1283                            if the directory does not exist it is created. If dumper is not present a unique directory within the current
1284                            working directory is used.
1285             """
1286             self.__dump_step=dump_time
1287             self.__dump_time=dump_step
1288             self.__counter=0
1289             self.__saveMarker()
1290             if dumper == None:
1291                self.__dumper="restart"+str(os.getpid())
1292             else:
1293                self.__dumper=dumper
1294             self.__dumper_bkp=self.__dumper+"_bkp"
1295             self.__dumper_bkp2=self.__dumper+"_bkp2"
1296             self.__current_dumper=None
1297         def __saveMarker(self):
1298             self.__last_restart_time=time.time()
1299             self.__last_restart_counter=self.__counter
1300         def getCurrentDumper(self):
1301             """
1302             returns the name of the currently used dumper
1303             """
1304             return self.__current_dumper
1305         def doDump(self):
1306            """
1307            returns true the restart should be dumped. use C{getNewDumper} to get the directory name to be used.
1308            """
1309            if self.__dump_step == None:
1310               if self.__dump_step == None:
1311                  out = False
1312               else:
1313                  out = (self.__dump_step + self.__last_restart_counter) <= self.__counter
1314            else:
1315               if dump_step == None:
1316                  out = (self.__last_restart_time + self.__dump_time) <= time.time()
1317               else:
1318                  out =    ( (self.__dump_step + self.__last_restart_counter) <= self.__counter)  \
1319                        or ( (self.__last_restart_time + self.__dump_time) <= time.time() )
1320            if out: self.__saveMarker()
1321            self__counter+=1
1322         def getNewDumper(self):
1323           """
1324           creates a new directory to be used for dumping and returns its name.
1325           """
1326           if os.access(self.__dumper_bkp,os.F_OK):
1327              if os.access(self.__dumper_bkp2, os.F_OK):
1328                 raise RunTimeError("please remove %s."%self.__dumper_bkp2)
1329              try:
1330                 os.rename(self.__dumper_bkp, self.__dumper_bkp2)
1331              except:
1332                 self.__current_dumper=self.__dumper
1333                 raise RunTimeError("renaming back-up directory %s failed. Use %s for restart."%(self.__dumper_bkp,self.__dumper))
1334           if os.access(self.__dumper,os.F_OK):
1335              if os.access(self.__dumper_bkp, os.F_OK):
1336                 raise RunTimeError("please remove %s."%self.__dumper_bkp)
1337              try:
1338                 os.rename(self.__dumper, self.__dumper_bkp)
1339              except:
1340                 self.__current_dumper=self.__dumper_bkp2
1341                 raise RunTimeError("moving directory %s to back-up failed. Use %s for restart."%(self.__dumper,self.__dumper_bkp2))
1342           try:
1343              os.mkdir(self.__dumper)
1344           except:
1345              self.__current_dumper=self.__dumper_bkp
1346              raise RunTimeError("creating a new restart directory %s failed. Use %s for restart."%(self.__dumper,self.__dumper_bkp))
1347           if os.access(self.__dumper_bkp2, os.F_OK): os.rmdir(self.__dumper_bkp2)
1348           return self.getCurrentDumper()
1349            
1350        
1351  # vim: expandtab shiftwidth=4:  # vim: expandtab shiftwidth=4:

Legend:
Removed from v.836  
changed lines
  Added in v.1312

  ViewVC Help
Powered by ViewVC 1.1.26