/[escript]/trunk/escript/py_src/modelframe.py
ViewVC logotype

Diff of /trunk/escript/py_src/modelframe.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

trunk/esys2/escript/py_src/modelframe.py revision 149 by jgs, Thu Sep 1 03:31:39 2005 UTC trunk/escript/py_src/modelframe.py revision 917 by gross, Tue Jan 2 02:46:53 2007 UTC
# Line 1  Line 1 
1  # $Id$  # $Id$
2    
3    """
4    Environment for implementing models in escript
5    
6    @var __author__: name of author
7    @var __copyright__: copyrights
8    @var __license__: licence agreement
9    @var __url__: url entry point on documentation
10    @var __version__: version
11    @var __date__: date of the version
12    """
13    
14    __author__="Lutz Gross, l.gross@uq.edu.au"
15    __copyright__="""  Copyright (c) 2006 by ACcESS MNRF
16                        http://www.access.edu.au
17                    Primary Business: Queensland, Australia"""
18    __license__="""Licensed under the Open Software License version 3.0
19                 http://www.opensource.org/licenses/osl-3.0.php"""
20    __url__="http://www.iservo.edu.au/esys"
21    __version__="$Revision$"
22    __date__="$Date$"
23    
24    
25  from types import StringType,IntType,FloatType,BooleanType,ListType,DictType  from types import StringType,IntType,FloatType,BooleanType,ListType,DictType
26  from sys import stdout  from sys import stdout
27    import numarray
28    import operator
29  import itertools  import itertools
30  # import modellib  temporarily removed!!!  # import modellib  temporarily removed!!!
31    
# Line 69  def parse(xml): Line 93  def parse(xml):
93      doc = minidom.parseString(xml)      doc = minidom.parseString(xml)
94      sim = getComponent(doc.firstChild)      sim = getComponent(doc.firstChild)
95      for obj_id, link in LinkRegistry:      for obj_id, link in LinkRegistry:
96            print obj_id.__class__
97          link.target = LinkableObjectRegistry[obj_id]          link.target = LinkableObjectRegistry[obj_id]
98    
99      return sim      return sim
100    
101    def importName(modulename, name):
102        """ Import a named object from a module in the context of this function,
103            which means you should use fully qualified module paths.
104            
105            Return None on failure.
106    
107            This function from: http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/52241
108        """
109        module = __import__(modulename, globals(), locals(), [name])
110            
111        try:
112            return vars(module)[name]
113        except KeyError:
114            raise ImportError("Could not import %s from %s" % (name, modulename))
115    
116  def getComponent(doc):  def getComponent(doc):
117      """      """
118      Used to get components of Simualtions, Models.      Used to get components of Simualtions, Models.
# Line 84  def getComponent(doc): Line 124  def getComponent(doc):
124                  if node.getAttribute("type") == 'Simulation':                  if node.getAttribute("type") == 'Simulation':
125                      return Simulation.fromDom(node)                      return Simulation.fromDom(node)
126              if node.tagName == 'Model':              if node.tagName == 'Model':
127                  model_type = node.getAttribute("type")                  if (node.getAttribute("module")):
128                  model_subclasses = Model.__subclasses__()                      model_module = node.getAttribute("module")
129                  for model in model_subclasses:                      model_type = node.getAttribute("type")
130                      if model_type == model.__name__:                      return importName(model_module, model_type).fromDom(node)
131                          return Model.fromDom(node)                  else:
132                        model_type = node.getAttribute("type")
133                        model_subclasses = Model.__subclasses__()
134                        for model in model_subclasses:
135                            if model_type == model.__name__:
136                                return Model.fromDom(node)
137              if node.tagName == 'ParameterSet':              if node.tagName == 'ParameterSet':
138                  parameter_type = node.getAttribute("type")                  parameter_type = node.getAttribute("type")
139                  return ParameterSet.fromDom(node)                  return ParameterSet.fromDom(node)
# Line 117  class Link: Line 162  class Link:
162          self.target = target          self.target = target
163          self.attribute = None          self.attribute = None
164          self.setAttributeName(attribute)          self.setAttributeName(attribute)
165    
166        def getAttributeName(self):
167            """
168            returns the name of the attribute the link is pointing to
169            """
170            return self.attribute
171            
172      def setAttributeName(self,attribute):      def setAttributeName(self,attribute):
173          """          """
# Line 322  class ParameterSet(LinkableObject): Line 373  class ParameterSet(LinkableObject):
373       - a ParameterSet object       - a ParameterSet object
374       - a Simulation object       - a Simulation object
375       - a Model object       - a Model object
376       - any other object (not considered by writeESySXML and writeXML)       - a numarray object
377             - a list of booleans
378            - any other object (not considered by writeESySXML and writeXML)
379            
380      Example how to create an ESySParameters object::      Example how to create an ESySParameters object::
381            
# Line 351  class ParameterSet(LinkableObject): Line 404  class ParameterSet(LinkableObject):
404          self.declareParameters(parameters)          self.declareParameters(parameters)
405    
406      def __repr__(self):      def __repr__(self):
407          return "<%s %r>" % (self.__class__.__name__,          return "<%s %d>"%(self.__class__.__name__,id(self))
                             [(p, getattr(self, p, None)) for p in self.parameters])  
408            
409      def declareParameter(self,**parameters):      def declareParameter(self,**parameters):
410          """          """
# Line 424  class ParameterSet(LinkableObject): Line 476  class ParameterSet(LinkableObject):
476          self._parametersToDom(document, pset)          self._parametersToDom(document, pset)
477    
478      def _parametersToDom(self, document, node):      def _parametersToDom(self, document, node):
479          node.setAttribute ('id', str(self.id))          node.setAttribute('id', str(self.id))
480            node.setIdAttribute("id")
481          for name,value in self:          for name,value in self:
482              param = document.createElement('Parameter')              param = document.createElement('Parameter')
483              param.setAttribute('type', value.__class__.__name__)              param.setAttribute('type', value.__class__.__name__)
# Line 433  class ParameterSet(LinkableObject): Line 486  class ParameterSet(LinkableObject):
486    
487              val = document.createElement('Value')              val = document.createElement('Value')
488    
489              if isinstance(value,ParameterSet):              if isinstance(value,(ParameterSet,Link,DataSource)):
490                  value.toDom(document, val)                  value.toDom(document, val)
491                  param.appendChild(val)                  param.appendChild(val)
492              elif isinstance(value, Link):              elif isinstance(value, numarray.NumArray):
493                  value.toDom(document, val)                  shape = value.getshape()
494                    if isinstance(shape, tuple):
495                        size = reduce(operator.mul, shape)
496                        shape = ' '.join(map(str, shape))
497                    else:
498                        size = shape
499                        shape = str(shape)
500    
501                    arraytype = value.type()
502                    numarrayElement = document.createElement('NumArray')
503                    numarrayElement.appendChild(dataNode(document, 'ArrayType', str(arraytype)))
504                    numarrayElement.appendChild(dataNode(document, 'Shape', shape))
505                    numarrayElement.appendChild(dataNode(document, 'Data', ' '.join(
506                        [str(x) for x in numarray.reshape(value, size)])))
507                    val.appendChild(numarrayElement)
508                  param.appendChild(val)                  param.appendChild(val)
509              elif isinstance(value,StringType):              elif isinstance (value, list):
510                  param.appendChild(dataNode(document, 'Value', value))                  param.appendChild(dataNode(document, 'Value', ' '.join(
511                        [str(x) for x in value])
512                    ))
513              else:              else:
514                  param.appendChild(dataNode(document, 'Value', str(value)))                  param.appendChild(dataNode(document, 'Value', str(value)))
515    
# Line 453  class ParameterSet(LinkableObject): Line 522  class ParameterSet(LinkableObject):
522              """              """
523              Remove the empty nodes from the children of this node.              Remove the empty nodes from the children of this node.
524              """              """
525              return [x for x in node.childNodes              ret = []
526                      if not isinstance(x, minidom.Text) or x.nodeValue.strip()]              for x in node.childNodes:
527                    if isinstance(x, minidom.Text):
528                        if x.nodeValue.strip():
529                            ret.append(x)
530                    else:
531                        ret.append(x)
532                return ret
533    
534          def _floatfromValue(doc):          def _floatfromValue(doc):
535              return float(doc.nodeValue.strip())              return float(doc.nodeValue.strip())
# Line 466  class ParameterSet(LinkableObject): Line 541  class ParameterSet(LinkableObject):
541              return int(doc.nodeValue.strip())              return int(doc.nodeValue.strip())
542    
543          def _boolfromValue(doc):          def _boolfromValue(doc):
544              return bool(doc.nodeValue.strip())              return _boolfromstring(doc.nodeValue.strip())
545          
546            def _nonefromValue(doc):
547                return None
548    
549            def _numarrayfromValue(doc):
550                for node in _children(doc):
551                    if node.tagName == 'ArrayType':
552                        arraytype = node.firstChild.nodeValue.strip()
553                    if node.tagName == 'Shape':
554                        shape = node.firstChild.nodeValue.strip()
555                        shape = [int(x) for x in shape.split()]
556                    if node.tagName == 'Data':
557                        data = node.firstChild.nodeValue.strip()
558                        data = [float(x) for x in data.split()]
559                return numarray.reshape(numarray.array(data, type=getattr(numarray, arraytype)),
560                                        shape)
561          
562            def _listfromValue(doc):
563                return [_boolfromstring(x) for x in doc.nodeValue.split()]
564    
565    
566            def _boolfromstring(s):
567                if s == 'True':
568                    return True
569                else:
570                    return False
571          # Mapping from text types in the xml to methods used to process trees of that type          # Mapping from text types in the xml to methods used to process trees of that type
572          ptypemap = {"Simulation": Simulation.fromDom,          ptypemap = {"Simulation": Simulation.fromDom,
573                      "Model":Model.fromDom,                      "Model":Model.fromDom,
574                      "ParameterSet":ParameterSet.fromDom,                      "ParameterSet":ParameterSet.fromDom,
575                      "Link":Link.fromDom,                      "Link":Link.fromDom,
576                        "DataSource":DataSource.fromDom,
577                      "float":_floatfromValue,                      "float":_floatfromValue,
578                      "int":_intfromValue,                      "int":_intfromValue,
579                      "str":_stringfromValue,                      "str":_stringfromValue,
580                      "bool":_boolfromValue                      "bool":_boolfromValue,
581                        "list":_listfromValue,
582                        "NumArray":_numarrayfromValue,
583                        "NoneType":_nonefromValue,
584                      }                      }
585    
586  #        print doc.toxml()  #        print doc.toxml()
# Line 493  class ParameterSet(LinkableObject): Line 597  class ParameterSet(LinkableObject):
597    
598                  if childnode.tagName == "Value":                  if childnode.tagName == "Value":
599                      nodes = _children(childnode)                      nodes = _children(childnode)
600                    #    if ptype == 'NumArray':
601                     #       pvalue = _numarrayfromValue(nodes)
602                     #   else:
603                      pvalue = ptypemap[ptype](nodes[0])                      pvalue = ptypemap[ptype](nodes[0])
604    
605              parameters[pname] = pvalue              parameters[pname] = pvalue
# Line 520  class Model(ParameterSet): Line 627  class Model(ParameterSet):
627      finalizing condition is fullfilled. At each time step an iterative      finalizing condition is fullfilled. At each time step an iterative
628      process can be performed and the time step size can be controlled. A      process can be performed and the time step size can be controlled. A
629      Model has the following work flow::      Model has the following work flow::
630              
631            doInitialization()            doInitialization()
632              while not terminateInitialIteration(): doInitializationiStep()
633              doInitialPostprocessing()
634            while not finalize():            while not finalize():
635                 dt=getSafeTimeStepSize(dt)                 dt=getSafeTimeStepSize(dt)
636                 doStepPreprocessing(dt)                 doStepPreprocessing(dt)
# Line 547  class Model(ParameterSet): Line 656  class Model(ParameterSet):
656          ParameterSet.__init__(self, parameters=parameters,**kwarg)          ParameterSet.__init__(self, parameters=parameters,**kwarg)
657    
658      def __str__(self):      def __str__(self):
659         return "<%s %d>"%(self.__class__,id(self))         return "<%s %d>"%(self.__class__.__name__,id(self))
660    
661      def toDom(self, document, node):      def toDom(self, document, node):
662          """          """
# Line 555  class Model(ParameterSet): Line 664  class Model(ParameterSet):
664      """      """
665          pset = document.createElement('Model')          pset = document.createElement('Model')
666          pset.setAttribute('type', self.__class__.__name__)          pset.setAttribute('type', self.__class__.__name__)
667            if not self.__class__.__module__.startswith('esys.escript'):
668                pset.setAttribute('module', self.__class__.__module__)
669          node.appendChild(pset)          node.appendChild(pset)
670          self._parametersToDom(document, pset)          self._parametersToDom(document, pset)
671    
# Line 565  class Model(ParameterSet): Line 676  class Model(ParameterSet):
676      This function may be overwritten.      This function may be overwritten.
677      """      """
678          pass          pass
679        def doInitialStep(self):
680            """
681        performs an iteration step in the initialization phase
682    
683        This function may be overwritten.
684        """
685            pass
686    
687        def terminateInitialIteration(self):
688            """
689        Returns True if iteration at the inital phase is terminated.
690        """
691            return True
692    
693        def doInitialPostprocessing(self):
694            """
695        finalises the initialization iteration process
696    
697        This function may be overwritten.
698        """
699            pass
700            
701      def getSafeTimeStepSize(self,dt):      def getSafeTimeStepSize(self,dt):
702          """          """
# Line 615  class Model(ParameterSet): Line 747  class Model(ParameterSet):
747      Returns True if iteration on a time step is terminated.      Returns True if iteration on a time step is terminated.
748      """      """
749          return True          return True
750    
751                
752      def doStepPostprocessing(self,dt):      def doStepPostprocessing(self,dt):
753          """          """
754      Finalalizes the time step.      finalises the time step.
755    
756          dt is the currently used time step size.          dt is the currently used time step size.
757    
# Line 644  class Simulation(Model): Line 777  class Simulation(Model):
777            
778      FAILED_TIME_STEPS_MAX=20      FAILED_TIME_STEPS_MAX=20
779      MAX_ITER_STEPS=50      MAX_ITER_STEPS=50
780        MAX_CHANGE_OF_DT=2.
781            
782      def __init__(self, models=[], **kwargs):      def __init__(self, models=[], **kwargs):
783          """          """
# Line 685  class Simulation(Model): Line 819  class Simulation(Model):
819      Sets the i-th model.      Sets the i-th model.
820      """      """
821          if not isinstance(value,Model):          if not isinstance(value,Model):
822              raise ValueError("assigned value is not a Model")              raise ValueError,"assigned value is not a Model but instance of %s"%(value.__class__.__name__,)
823          for j in range(max(i-len(self.__models)+1,0)):          for j in range(max(i-len(self.__models)+1,0)):
824              self.__models.append(None)              self.__models.append(None)
825          self.__models[i]=value          self.__models[i]=value
# Line 720  class Simulation(Model): Line 854  class Simulation(Model):
854          document, rootnode = esysDoc()          document, rootnode = esysDoc()
855          self.toDom(document, rootnode)          self.toDom(document, rootnode)
856          targetsList = document.getElementsByTagName('Target')          targetsList = document.getElementsByTagName('Target')
857          for i in targetsList:          
858              targetId = int(i.firstChild.nodeValue.strip())          for element in targetsList:
859                targetId = int(element.firstChild.nodeValue.strip())
860                if document.getElementById(str(targetId)):
861                    continue
862              targetObj = LinkableObjectRegistry[targetId]              targetObj = LinkableObjectRegistry[targetId]
863              targetObj.toDom(document, rootnode)              targetObj.toDom(document, rootnode)
864          ostream.write(document.toprettyxml())          ostream.write(document.toprettyxml())
# Line 733  class Simulation(Model): Line 870  class Simulation(Model):
870          This is the minimum over the time step sizes of all models.          This is the minimum over the time step sizes of all models.
871      """      """
872          out=min([o.getSafeTimeStepSize(dt) for o in self.iterModels()])          out=min([o.getSafeTimeStepSize(dt) for o in self.iterModels()])
         print "%s: safe step size is %e."%(str(self),out)  
873          return out          return out
874            
875      def doInitialization(self):      def doInitialization(self):
# Line 743  class Simulation(Model): Line 879  class Simulation(Model):
879          self.n=0          self.n=0
880          self.tn=0.          self.tn=0.
881          for o in self.iterModels():          for o in self.iterModels():
882              o.doInitialization()               o.doInitialization()
883            def doInitialStep(self):
884            """
885        performs an iteration step in the initialization step for all models
886        """
887            iter=0
888            while not self.terminateInitialIteration():
889                if iter==0: self.trace("iteration for initialization starts")
890                iter+=1
891                self.trace("iteration step %d"%(iter))
892                for o in self.iterModels():
893                     o.doInitialStep()
894                if iter>self.MAX_ITER_STEPS:
895                     raise IterationDivergenceError("initial iteration did not converge after %s steps."%iter)
896            self.trace("Initialization finalized after %s iteration steps."%iter)
897    
898        def doInitialPostprocessing(self):
899            """
900        finalises the initialization iteration process for all models.
901        """
902            for o in self.iterModels():
903                o.doInitialPostprocessing()
904      def finalize(self):      def finalize(self):
905          """          """
906      Returns True if any of the models is to be finalized.      Returns True if any of the models is to be finalized.
# Line 753  class Simulation(Model): Line 909  class Simulation(Model):
909                
910      def doFinalization(self):      def doFinalization(self):
911          """          """
912      Finalalizes the time stepping for all models.      finalises the time stepping for all models.
913      """      """
914          for i in self.iterModels(): i.doFinalization()          for i in self.iterModels(): i.doFinalization()
915          self.trace("end of time integation.")          self.trace("end of time integation.")
# Line 771  class Simulation(Model): Line 927  class Simulation(Model):
927      """      """
928          out=all([o.terminateIteration() for o in self.iterModels()])          out=all([o.terminateIteration() for o in self.iterModels()])
929          return out          return out
930    
931        def terminateInitialIteration(self):
932            """
933        Returns True if all initial iterations for all models are terminated.
934        """
935            out=all([o.terminateInitialIteration() for o in self.iterModels()])
936            return out
937                
938      def doStepPostprocessing(self,dt):      def doStepPostprocessing(self,dt):
939          """          """
940      Finalalizes the iteration process for all models.      finalises the iteration process for all models.
941      """      """
942          for o in self.iterModels():          for o in self.iterModels():
943              o.doStepPostprocessing(dt)              o.doStepPostprocessing(dt)
# Line 805  class Simulation(Model): Line 968  class Simulation(Model):
968      Run the simulation by performing essentially::      Run the simulation by performing essentially::
969            
970          self.doInitialization()          self.doInitialization()
971                while not self.terminateInitialIteration(): self.doInitialStep()
972                self.doInitialPostprocessing()
973          while not self.finalize():          while not self.finalize():
974              dt=self.getSafeTimeStepSize()              dt=self.getSafeTimeStepSize()
975              self.doStep(dt)              self.doStep(dt)
# Line 822  class Simulation(Model): Line 987  class Simulation(Model):
987          In both cases the time integration is given up after          In both cases the time integration is given up after
988      C{Simulation.FAILED_TIME_STEPS_MAX} attempts.      C{Simulation.FAILED_TIME_STEPS_MAX} attempts.
989          """          """
         dt=self.UNDEF_DT  
990          self.doInitialization()          self.doInitialization()
991            self.doInitialStep()
992            self.doInitialPostprocessing()
993            dt=self.UNDEF_DT
994          while not self.finalize():          while not self.finalize():
995              step_fail_counter=0              step_fail_counter=0
996              iteration_fail_counter=0              iteration_fail_counter=0
997              dt_new=self.getSafeTimeStepSize(dt)              if self.n==0:
998                    dt_new=self.getSafeTimeStepSize(dt)
999                else:
1000                    dt_new=min(max(self.getSafeTimeStepSize(dt),dt/self.MAX_CHANGE_OF_DT),dt*self.MAX_CHANGE_OF_DT)
1001              self.trace("%d. time step %e (step size %e.)" % (self.n+1,self.tn+dt_new,dt_new))              self.trace("%d. time step %e (step size %e.)" % (self.n+1,self.tn+dt_new,dt_new))
1002              end_of_step=False              end_of_step=False
1003              while not end_of_step:              while not end_of_step:
1004                 end_of_step=True                 end_of_step=True
1005                 if not dt_new>0:                 if not dt_new>0:
1006                    raise NonPositiveStepSizeError("non-positive step size in step %d",self.n+1)                    raise NonPositiveStepSizeError("non-positive step size in step %d"%(self.n+1))
1007                 try:                 try:
1008                    self.doStepPreprocessing(dt_new)                    self.doStepPreprocessing(dt_new)
1009                    self.doStep(dt_new)                    self.doStep(dt_new)
# Line 843  class Simulation(Model): Line 1013  class Simulation(Model):
1013                    end_of_step=False                    end_of_step=False
1014                    iteration_fail_counter+=1                    iteration_fail_counter+=1
1015                    if iteration_fail_counter>self.FAILED_TIME_STEPS_MAX:                    if iteration_fail_counter>self.FAILED_TIME_STEPS_MAX:
1016                             raise SimulationBreakDownError("reduction of time step to achieve convergence failed.")                             raise SimulationBreakDownError("reduction of time step to achieve convergence failed after %s steps."%self.FAILED_TIME_STEPS_MAX)
1017                    self.trace("iteration fails. time step is repeated with new step size.")                    self.trace("Iteration failed. Time step is repeated with new step size %s."%dt_new)
1018                 except FailedTimeStepError:                 except FailedTimeStepError:
1019                    dt_new=self.getSafeTimeStepSize(dt)                    dt_new=self.getSafeTimeStepSize(dt)
1020                    end_of_step=False                    end_of_step=False
1021                    step_fail_counter+=1                    step_fail_counter+=1
1022                    self.trace("time step is repeated.")                    self.trace("Time step is repeated with new time step size %s."%dt_new)
1023                    if step_fail_counter>self.FAILED_TIME_STEPS_MAX:                    if step_fail_counter>self.FAILED_TIME_STEPS_MAX:
1024                          raise SimulationBreakDownError("time integration is given up after %d attempts."%step_fail_counter)                          raise SimulationBreakDownError("Time integration is given up after %d attempts."%step_fail_counter)
1025              dt=dt_new              dt=dt_new
1026              if not check_point==None:              if not check_point==None:
1027                  if n%check_point==0:                  if n%check_point==0:
# Line 901  class NonPositiveStepSizeError(Exception Line 1071  class NonPositiveStepSizeError(Exception
1071      """      """
1072      pass      pass
1073    
1074    class DataSource(object):
1075        """
1076        Class for handling data sources, including local and remote files. This class is under development.
1077        """
1078    
1079        def __init__(self, uri="file.ext", fileformat="unknown"):
1080            self.uri = uri
1081            self.fileformat = fileformat
1082    
1083        def toDom(self, document, node):
1084            """
1085            C{toDom} method of DataSource. Creates a DataSource node and appends it to the
1086        current XML document.
1087            """
1088            ds = document.createElement('DataSource')
1089            ds.appendChild(dataNode(document, 'URI', self.uri))
1090            ds.appendChild(dataNode(document, 'FileFormat', self.fileformat))
1091            node.appendChild(ds)
1092    
1093        def fromDom(cls, doc):
1094            uri= doc.getElementsByTagName("URI")[0].firstChild.nodeValue.strip()
1095            fileformat= doc.getElementsByTagName("FileFormat")[0].firstChild.nodeValue.strip()
1096            ds = cls(uri, fileformat)
1097            return ds
1098    
1099        def getLocalFileName(self):
1100            return self.uri
1101    
1102        fromDom = classmethod(fromDom)
1103        
1104  # vim: expandtab shiftwidth=4:  # vim: expandtab shiftwidth=4:

Legend:
Removed from v.149  
changed lines
  Added in v.917

  ViewVC Help
Powered by ViewVC 1.1.26