/[escript]/trunk/escript/py_src/modelframe.py
ViewVC logotype

Diff of /trunk/escript/py_src/modelframe.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

trunk/esys2/escript/py_src/modelframe.py revision 149 by jgs, Thu Sep 1 03:31:39 2005 UTC trunk/escript/py_src/modelframe.py revision 912 by gross, Wed Dec 6 03:29:49 2006 UTC
# Line 1  Line 1 
1  # $Id$  # $Id$
2    
3    """
4    Environment for implementing models in escript
5    
6    @var __author__: name of author
7    @var __copyright__: copyrights
8    @var __license__: licence agreement
9    @var __url__: url entry point on documentation
10    @var __version__: version
11    @var __date__: date of the version
12    """
13    
14    __author__="Lutz Gross, l.gross@uq.edu.au"
15    __copyright__="""  Copyright (c) 2006 by ACcESS MNRF
16                        http://www.access.edu.au
17                    Primary Business: Queensland, Australia"""
18    __license__="""Licensed under the Open Software License version 3.0
19                 http://www.opensource.org/licenses/osl-3.0.php"""
20    __url__="http://www.iservo.edu.au/esys"
21    __version__="$Revision$"
22    __date__="$Date$"
23    
24    
25  from types import StringType,IntType,FloatType,BooleanType,ListType,DictType  from types import StringType,IntType,FloatType,BooleanType,ListType,DictType
26  from sys import stdout  from sys import stdout
27    import numarray
28    import operator
29  import itertools  import itertools
30  # import modellib  temporarily removed!!!  # import modellib  temporarily removed!!!
31    
# Line 73  def parse(xml): Line 97  def parse(xml):
97    
98      return sim      return sim
99    
100    def importName(modulename, name):
101        """ Import a named object from a module in the context of this function,
102            which means you should use fully qualified module paths.
103            
104            Return None on failure.
105    
106            This function from: http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/52241
107        """
108        module = __import__(modulename, globals(), locals(), [name])
109            
110        try:
111            return vars(module)[name]
112        except KeyError:
113            raise ImportError("Could not import %s from %s" % (name, modulename))
114    
115  def getComponent(doc):  def getComponent(doc):
116      """      """
117      Used to get components of Simualtions, Models.      Used to get components of Simualtions, Models.
# Line 84  def getComponent(doc): Line 123  def getComponent(doc):
123                  if node.getAttribute("type") == 'Simulation':                  if node.getAttribute("type") == 'Simulation':
124                      return Simulation.fromDom(node)                      return Simulation.fromDom(node)
125              if node.tagName == 'Model':              if node.tagName == 'Model':
126                  model_type = node.getAttribute("type")                  if (node.getAttribute("module")):
127                  model_subclasses = Model.__subclasses__()                      model_module = node.getAttribute("module")
128                  for model in model_subclasses:                      model_type = node.getAttribute("type")
129                      if model_type == model.__name__:                      return importName(model_module, model_type).fromDom(node)
130                          return Model.fromDom(node)                  else:
131                        model_type = node.getAttribute("type")
132                        model_subclasses = Model.__subclasses__()
133                        for model in model_subclasses:
134                            if model_type == model.__name__:
135                                return Model.fromDom(node)
136              if node.tagName == 'ParameterSet':              if node.tagName == 'ParameterSet':
137                  parameter_type = node.getAttribute("type")                  parameter_type = node.getAttribute("type")
138                  return ParameterSet.fromDom(node)                  return ParameterSet.fromDom(node)
# Line 117  class Link: Line 161  class Link:
161          self.target = target          self.target = target
162          self.attribute = None          self.attribute = None
163          self.setAttributeName(attribute)          self.setAttributeName(attribute)
164    
165        def getAttributeName(self):
166            """
167            returns the name of the attribute the link is pointing to
168            """
169            return self.attribute
170            
171      def setAttributeName(self,attribute):      def setAttributeName(self,attribute):
172          """          """
# Line 322  class ParameterSet(LinkableObject): Line 372  class ParameterSet(LinkableObject):
372       - a ParameterSet object       - a ParameterSet object
373       - a Simulation object       - a Simulation object
374       - a Model object       - a Model object
375       - any other object (not considered by writeESySXML and writeXML)       - a numarray object
376             - a list of booleans
377            - any other object (not considered by writeESySXML and writeXML)
378            
379      Example how to create an ESySParameters object::      Example how to create an ESySParameters object::
380            
# Line 351  class ParameterSet(LinkableObject): Line 403  class ParameterSet(LinkableObject):
403          self.declareParameters(parameters)          self.declareParameters(parameters)
404    
405      def __repr__(self):      def __repr__(self):
406          return "<%s %r>" % (self.__class__.__name__,          return "<%s %d>"%(self.__class__.__name__,id(self))
                             [(p, getattr(self, p, None)) for p in self.parameters])  
407            
408      def declareParameter(self,**parameters):      def declareParameter(self,**parameters):
409          """          """
# Line 424  class ParameterSet(LinkableObject): Line 475  class ParameterSet(LinkableObject):
475          self._parametersToDom(document, pset)          self._parametersToDom(document, pset)
476    
477      def _parametersToDom(self, document, node):      def _parametersToDom(self, document, node):
478          node.setAttribute ('id', str(self.id))          node.setAttribute('id', str(self.id))
479            node.setIdAttribute("id")
480          for name,value in self:          for name,value in self:
481              param = document.createElement('Parameter')              param = document.createElement('Parameter')
482              param.setAttribute('type', value.__class__.__name__)              param.setAttribute('type', value.__class__.__name__)
# Line 433  class ParameterSet(LinkableObject): Line 485  class ParameterSet(LinkableObject):
485    
486              val = document.createElement('Value')              val = document.createElement('Value')
487    
488              if isinstance(value,ParameterSet):              if isinstance(value,(ParameterSet,Link,DataSource)):
489                  value.toDom(document, val)                  value.toDom(document, val)
490                  param.appendChild(val)                  param.appendChild(val)
491              elif isinstance(value, Link):              elif isinstance(value, numarray.NumArray):
492                  value.toDom(document, val)                  shape = value.getshape()
493                    if isinstance(shape, tuple):
494                        size = reduce(operator.mul, shape)
495                        shape = ' '.join(map(str, shape))
496                    else:
497                        size = shape
498                        shape = str(shape)
499    
500                    arraytype = value.type()
501                    numarrayElement = document.createElement('NumArray')
502                    numarrayElement.appendChild(dataNode(document, 'ArrayType', str(arraytype)))
503                    numarrayElement.appendChild(dataNode(document, 'Shape', shape))
504                    numarrayElement.appendChild(dataNode(document, 'Data', ' '.join(
505                        [str(x) for x in numarray.reshape(value, size)])))
506                    val.appendChild(numarrayElement)
507                  param.appendChild(val)                  param.appendChild(val)
508              elif isinstance(value,StringType):              elif isinstance (value, list):
509                  param.appendChild(dataNode(document, 'Value', value))                  param.appendChild(dataNode(document, 'Value', ' '.join(
510                        [str(x) for x in value])
511                    ))
512              else:              else:
513                  param.appendChild(dataNode(document, 'Value', str(value)))                  param.appendChild(dataNode(document, 'Value', str(value)))
514    
# Line 453  class ParameterSet(LinkableObject): Line 521  class ParameterSet(LinkableObject):
521              """              """
522              Remove the empty nodes from the children of this node.              Remove the empty nodes from the children of this node.
523              """              """
524              return [x for x in node.childNodes              ret = []
525                      if not isinstance(x, minidom.Text) or x.nodeValue.strip()]              for x in node.childNodes:
526                    if isinstance(x, minidom.Text):
527                        if x.nodeValue.strip():
528                            ret.append(x)
529                    else:
530                        ret.append(x)
531                return ret
532    
533          def _floatfromValue(doc):          def _floatfromValue(doc):
534              return float(doc.nodeValue.strip())              return float(doc.nodeValue.strip())
# Line 466  class ParameterSet(LinkableObject): Line 540  class ParameterSet(LinkableObject):
540              return int(doc.nodeValue.strip())              return int(doc.nodeValue.strip())
541    
542          def _boolfromValue(doc):          def _boolfromValue(doc):
543              return bool(doc.nodeValue.strip())              return _boolfromstring(doc.nodeValue.strip())
544          
545            def _nonefromValue(doc):
546                return None
547    
548            def _numarrayfromValue(doc):
549                for node in _children(doc):
550                    if node.tagName == 'ArrayType':
551                        arraytype = node.firstChild.nodeValue.strip()
552                    if node.tagName == 'Shape':
553                        shape = node.firstChild.nodeValue.strip()
554                        shape = [int(x) for x in shape.split()]
555                    if node.tagName == 'Data':
556                        data = node.firstChild.nodeValue.strip()
557                        data = [float(x) for x in data.split()]
558                return numarray.reshape(numarray.array(data, type=getattr(numarray, arraytype)),
559                                        shape)
560          
561            def _listfromValue(doc):
562                return [_boolfromstring(x) for x in doc.nodeValue.split()]
563    
564    
565            def _boolfromstring(s):
566                if s == 'True':
567                    return True
568                else:
569                    return False
570          # Mapping from text types in the xml to methods used to process trees of that type          # Mapping from text types in the xml to methods used to process trees of that type
571          ptypemap = {"Simulation": Simulation.fromDom,          ptypemap = {"Simulation": Simulation.fromDom,
572                      "Model":Model.fromDom,                      "Model":Model.fromDom,
573                      "ParameterSet":ParameterSet.fromDom,                      "ParameterSet":ParameterSet.fromDom,
574                      "Link":Link.fromDom,                      "Link":Link.fromDom,
575                        "DataSource":DataSource.fromDom,
576                      "float":_floatfromValue,                      "float":_floatfromValue,
577                      "int":_intfromValue,                      "int":_intfromValue,
578                      "str":_stringfromValue,                      "str":_stringfromValue,
579                      "bool":_boolfromValue                      "bool":_boolfromValue,
580                        "list":_listfromValue,
581                        "NumArray":_numarrayfromValue,
582                        "NoneType":_nonefromValue,
583                      }                      }
584    
585  #        print doc.toxml()  #        print doc.toxml()
# Line 493  class ParameterSet(LinkableObject): Line 596  class ParameterSet(LinkableObject):
596    
597                  if childnode.tagName == "Value":                  if childnode.tagName == "Value":
598                      nodes = _children(childnode)                      nodes = _children(childnode)
599                    #    if ptype == 'NumArray':
600                     #       pvalue = _numarrayfromValue(nodes)
601                     #   else:
602                      pvalue = ptypemap[ptype](nodes[0])                      pvalue = ptypemap[ptype](nodes[0])
603    
604              parameters[pname] = pvalue              parameters[pname] = pvalue
# Line 520  class Model(ParameterSet): Line 626  class Model(ParameterSet):
626      finalizing condition is fullfilled. At each time step an iterative      finalizing condition is fullfilled. At each time step an iterative
627      process can be performed and the time step size can be controlled. A      process can be performed and the time step size can be controlled. A
628      Model has the following work flow::      Model has the following work flow::
629              
630            doInitialization()            doInitialization()
631              while not terminateInitialIteration(): doInitializationiStep()
632              doInitialPostprocessing()
633            while not finalize():            while not finalize():
634                 dt=getSafeTimeStepSize(dt)                 dt=getSafeTimeStepSize(dt)
635                 doStepPreprocessing(dt)                 doStepPreprocessing(dt)
# Line 547  class Model(ParameterSet): Line 655  class Model(ParameterSet):
655          ParameterSet.__init__(self, parameters=parameters,**kwarg)          ParameterSet.__init__(self, parameters=parameters,**kwarg)
656    
657      def __str__(self):      def __str__(self):
658         return "<%s %d>"%(self.__class__,id(self))         return "<%s %d>"%(self.__class__.__name__,id(self))
659    
660      def toDom(self, document, node):      def toDom(self, document, node):
661          """          """
# Line 555  class Model(ParameterSet): Line 663  class Model(ParameterSet):
663      """      """
664          pset = document.createElement('Model')          pset = document.createElement('Model')
665          pset.setAttribute('type', self.__class__.__name__)          pset.setAttribute('type', self.__class__.__name__)
666            if not self.__class__.__module__.startswith('esys.escript'):
667                pset.setAttribute('module', self.__class__.__module__)
668          node.appendChild(pset)          node.appendChild(pset)
669          self._parametersToDom(document, pset)          self._parametersToDom(document, pset)
670    
# Line 565  class Model(ParameterSet): Line 675  class Model(ParameterSet):
675      This function may be overwritten.      This function may be overwritten.
676      """      """
677          pass          pass
678        def doInitialStep(self):
679            """
680        performs an iteration step in the initialization phase
681    
682        This function may be overwritten.
683        """
684            pass
685    
686        def terminateInitialIteration(self):
687            """
688        Returns True if iteration at the inital phase is terminated.
689        """
690            return True
691    
692        def doInitialPostprocessing(self):
693            """
694        finalises the initialization iteration process
695    
696        This function may be overwritten.
697        """
698            pass
699            
700      def getSafeTimeStepSize(self,dt):      def getSafeTimeStepSize(self,dt):
701          """          """
# Line 615  class Model(ParameterSet): Line 746  class Model(ParameterSet):
746      Returns True if iteration on a time step is terminated.      Returns True if iteration on a time step is terminated.
747      """      """
748          return True          return True
749    
750                
751      def doStepPostprocessing(self,dt):      def doStepPostprocessing(self,dt):
752          """          """
753      Finalalizes the time step.      finalises the time step.
754    
755          dt is the currently used time step size.          dt is the currently used time step size.
756    
# Line 644  class Simulation(Model): Line 776  class Simulation(Model):
776            
777      FAILED_TIME_STEPS_MAX=20      FAILED_TIME_STEPS_MAX=20
778      MAX_ITER_STEPS=50      MAX_ITER_STEPS=50
779        MAX_CHANGE_OF_DT=2.
780            
781      def __init__(self, models=[], **kwargs):      def __init__(self, models=[], **kwargs):
782          """          """
# Line 685  class Simulation(Model): Line 818  class Simulation(Model):
818      Sets the i-th model.      Sets the i-th model.
819      """      """
820          if not isinstance(value,Model):          if not isinstance(value,Model):
821              raise ValueError("assigned value is not a Model")              raise ValueError,"assigned value is not a Model but instance of %s"%(value.__class__.__name__,)
822          for j in range(max(i-len(self.__models)+1,0)):          for j in range(max(i-len(self.__models)+1,0)):
823              self.__models.append(None)              self.__models.append(None)
824          self.__models[i]=value          self.__models[i]=value
# Line 720  class Simulation(Model): Line 853  class Simulation(Model):
853          document, rootnode = esysDoc()          document, rootnode = esysDoc()
854          self.toDom(document, rootnode)          self.toDom(document, rootnode)
855          targetsList = document.getElementsByTagName('Target')          targetsList = document.getElementsByTagName('Target')
856          for i in targetsList:          
857              targetId = int(i.firstChild.nodeValue.strip())          for element in targetsList:
858                targetId = int(element.firstChild.nodeValue.strip())
859                if document.getElementById(str(targetId)):
860                    continue
861              targetObj = LinkableObjectRegistry[targetId]              targetObj = LinkableObjectRegistry[targetId]
862              targetObj.toDom(document, rootnode)              targetObj.toDom(document, rootnode)
863          ostream.write(document.toprettyxml())          ostream.write(document.toprettyxml())
# Line 733  class Simulation(Model): Line 869  class Simulation(Model):
869          This is the minimum over the time step sizes of all models.          This is the minimum over the time step sizes of all models.
870      """      """
871          out=min([o.getSafeTimeStepSize(dt) for o in self.iterModels()])          out=min([o.getSafeTimeStepSize(dt) for o in self.iterModels()])
         print "%s: safe step size is %e."%(str(self),out)  
872          return out          return out
873            
874      def doInitialization(self):      def doInitialization(self):
# Line 743  class Simulation(Model): Line 878  class Simulation(Model):
878          self.n=0          self.n=0
879          self.tn=0.          self.tn=0.
880          for o in self.iterModels():          for o in self.iterModels():
881              o.doInitialization()               o.doInitialization()
882            def doInitialStep(self):
883            """
884        performs an iteration step in the initialization step for all models
885        """
886            iter=0
887            while not self.terminateInitialIteration():
888                if iter==0: self.trace("iteration for initialization starts")
889                iter+=1
890                self.trace("iteration step %d"%(iter))
891                for o in self.iterModels():
892                     o.doInitialStep()
893                if iter>self.MAX_ITER_STEPS:
894                     raise IterationDivergenceError("initial iteration did not converge after %s steps."%iter)
895            self.trace("Initialization finalized after %s iteration steps."%iter)
896    
897        def doInitialPostprocessing(self):
898            """
899        finalises the initialization iteration process for all models.
900        """
901            for o in self.iterModels():
902                o.doInitialPostprocessing()
903      def finalize(self):      def finalize(self):
904          """          """
905      Returns True if any of the models is to be finalized.      Returns True if any of the models is to be finalized.
# Line 753  class Simulation(Model): Line 908  class Simulation(Model):
908                
909      def doFinalization(self):      def doFinalization(self):
910          """          """
911      Finalalizes the time stepping for all models.      finalises the time stepping for all models.
912      """      """
913          for i in self.iterModels(): i.doFinalization()          for i in self.iterModels(): i.doFinalization()
914          self.trace("end of time integation.")          self.trace("end of time integation.")
# Line 771  class Simulation(Model): Line 926  class Simulation(Model):
926      """      """
927          out=all([o.terminateIteration() for o in self.iterModels()])          out=all([o.terminateIteration() for o in self.iterModels()])
928          return out          return out
929    
930        def terminateInitialIteration(self):
931            """
932        Returns True if all initial iterations for all models are terminated.
933        """
934            out=all([o.terminateInitialIteration() for o in self.iterModels()])
935            return out
936                
937      def doStepPostprocessing(self,dt):      def doStepPostprocessing(self,dt):
938          """          """
939      Finalalizes the iteration process for all models.      finalises the iteration process for all models.
940      """      """
941          for o in self.iterModels():          for o in self.iterModels():
942              o.doStepPostprocessing(dt)              o.doStepPostprocessing(dt)
# Line 805  class Simulation(Model): Line 967  class Simulation(Model):
967      Run the simulation by performing essentially::      Run the simulation by performing essentially::
968            
969          self.doInitialization()          self.doInitialization()
970                while not self.terminateInitialIteration(): self.doInitialStep()
971                self.doInitialPostprocessing()
972          while not self.finalize():          while not self.finalize():
973              dt=self.getSafeTimeStepSize()              dt=self.getSafeTimeStepSize()
974              self.doStep(dt)              self.doStep(dt)
# Line 822  class Simulation(Model): Line 986  class Simulation(Model):
986          In both cases the time integration is given up after          In both cases the time integration is given up after
987      C{Simulation.FAILED_TIME_STEPS_MAX} attempts.      C{Simulation.FAILED_TIME_STEPS_MAX} attempts.
988          """          """
         dt=self.UNDEF_DT  
989          self.doInitialization()          self.doInitialization()
990            self.doInitialStep()
991            self.doInitialPostprocessing()
992            dt=self.UNDEF_DT
993          while not self.finalize():          while not self.finalize():
994              step_fail_counter=0              step_fail_counter=0
995              iteration_fail_counter=0              iteration_fail_counter=0
996              dt_new=self.getSafeTimeStepSize(dt)              if self.n==0:
997                    dt_new=self.getSafeTimeStepSize(dt)
998                else:
999                    dt_new=min(max(self.getSafeTimeStepSize(dt),dt/self.MAX_CHANGE_OF_DT),dt*self.MAX_CHANGE_OF_DT)
1000              self.trace("%d. time step %e (step size %e.)" % (self.n+1,self.tn+dt_new,dt_new))              self.trace("%d. time step %e (step size %e.)" % (self.n+1,self.tn+dt_new,dt_new))
1001              end_of_step=False              end_of_step=False
1002              while not end_of_step:              while not end_of_step:
1003                 end_of_step=True                 end_of_step=True
1004                 if not dt_new>0:                 if not dt_new>0:
1005                    raise NonPositiveStepSizeError("non-positive step size in step %d",self.n+1)                    raise NonPositiveStepSizeError("non-positive step size in step %d"%(self.n+1))
1006                 try:                 try:
1007                    self.doStepPreprocessing(dt_new)                    self.doStepPreprocessing(dt_new)
1008                    self.doStep(dt_new)                    self.doStep(dt_new)
# Line 843  class Simulation(Model): Line 1012  class Simulation(Model):
1012                    end_of_step=False                    end_of_step=False
1013                    iteration_fail_counter+=1                    iteration_fail_counter+=1
1014                    if iteration_fail_counter>self.FAILED_TIME_STEPS_MAX:                    if iteration_fail_counter>self.FAILED_TIME_STEPS_MAX:
1015                             raise SimulationBreakDownError("reduction of time step to achieve convergence failed.")                             raise SimulationBreakDownError("reduction of time step to achieve convergence failed after %s steps."%self.FAILED_TIME_STEPS_MAX)
1016                    self.trace("iteration fails. time step is repeated with new step size.")                    self.trace("Iteration failed. Time step is repeated with new step size %s."%dt_new)
1017                 except FailedTimeStepError:                 except FailedTimeStepError:
1018                    dt_new=self.getSafeTimeStepSize(dt)                    dt_new=self.getSafeTimeStepSize(dt)
1019                    end_of_step=False                    end_of_step=False
1020                    step_fail_counter+=1                    step_fail_counter+=1
1021                    self.trace("time step is repeated.")                    self.trace("Time step is repeated with new time step size %s."%dt_new)
1022                    if step_fail_counter>self.FAILED_TIME_STEPS_MAX:                    if step_fail_counter>self.FAILED_TIME_STEPS_MAX:
1023                          raise SimulationBreakDownError("time integration is given up after %d attempts."%step_fail_counter)                          raise SimulationBreakDownError("Time integration is given up after %d attempts."%step_fail_counter)
1024              dt=dt_new              dt=dt_new
1025              if not check_point==None:              if not check_point==None:
1026                  if n%check_point==0:                  if n%check_point==0:
# Line 901  class NonPositiveStepSizeError(Exception Line 1070  class NonPositiveStepSizeError(Exception
1070      """      """
1071      pass      pass
1072    
1073    class DataSource(object):
1074        """
1075        Class for handling data sources, including local and remote files. This class is under development.
1076        """
1077    
1078        def __init__(self, uri="file.ext", fileformat="unknown"):
1079            self.uri = uri
1080            self.fileformat = fileformat
1081    
1082        def toDom(self, document, node):
1083            """
1084            C{toDom} method of DataSource. Creates a DataSource node and appends it to the
1085        current XML document.
1086            """
1087            ds = document.createElement('DataSource')
1088            ds.appendChild(dataNode(document, 'URI', self.uri))
1089            ds.appendChild(dataNode(document, 'FileFormat', self.fileformat))
1090            node.appendChild(ds)
1091    
1092        def fromDom(cls, doc):
1093            uri= doc.getElementsByTagName("URI")[0].firstChild.nodeValue.strip()
1094            fileformat= doc.getElementsByTagName("FileFormat")[0].firstChild.nodeValue.strip()
1095            ds = cls(uri, fileformat)
1096            return ds
1097    
1098        def getLocalFileName(self):
1099            return self.uri
1100    
1101        fromDom = classmethod(fromDom)
1102        
1103  # vim: expandtab shiftwidth=4:  # vim: expandtab shiftwidth=4:

Legend:
Removed from v.149  
changed lines
  Added in v.912

  ViewVC Help
Powered by ViewVC 1.1.26