/[escript]/trunk/downunder/py_src/inversions.py
ViewVC logotype

Diff of /trunk/downunder/py_src/inversions.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 4121 by gross, Wed Dec 19 00:24:50 2012 UTC revision 4122 by gross, Thu Dec 20 05:42:35 2012 UTC
# Line 22  __license__="""Licensed under the Open S Line 22  __license__="""Licensed under the Open S
22  http://www.opensource.org/licenses/osl-3.0.php"""  http://www.opensource.org/licenses/osl-3.0.php"""
23  __url__="https://launchpad.net/escript-finley"  __url__="https://launchpad.net/escript-finley"
24    
25  __all__ = ['InversionBase','SingleParameterInversionBase','GravityInversion','MagneticInversion']  __all__ = ['InversionBase', 'GravityInversion','MagneticInversion']
26    
27  import logging  import logging
28  from esys.escript import *  from esys.escript import *
29  import esys.escript.unitsSI as U  import esys.escript.unitsSI as U
30  from esys.weipa import createDataset  from esys.weipa import createDataset
31    
32  from .inversioncostfunctions import SimpleInversionCostFunction  from .inversioncostfunctions import InversionCostFunction
33  from .forwardmodels import GravityModel, MagneticModel  from .forwardmodels import GravityModel, MagneticModel
34  from .mappings import *  from .mappings import *
35  from .minimizers import *  from .minimizers import *
# Line 38  from .datasources import DataSource Line 38  from .datasources import DataSource
38    
39  class InversionBase(object):  class InversionBase(object):
40      """      """
41      Base class for inversions.      Base class for running an inversion
42      """      """
43      def __init__(self):      def __init__(self):
44          self.logger=logging.getLogger('inv.%s'%self.__class__.__name__)          self.logger=logging.getLogger('inv.%s'%self.__class__.__name__)
45          self._solver_callback = None          self.__costfunction = None
46            self.setSolverMaxIterations()
47            self.setSolverTolerance()
48            self.setSolverCallback()
49            self.setSolverClass()
50          self._solver_opts = {}          self._solver_opts = {}
51          self._solver_xtol = 1e-9          self.initial_value = None
         self._solver_maxiter = 200  
52          # set default solver          # set default solver
         self.setSolverClass()  
         self.__domain=None  
         self.__regularization=None  
         self.__forwardmodel=None  
         self.__mapping=None  
   
     def setDomain(self, domain):  
         """  
         sets the domain of the inversion  
   
         :param domain: domain of the inversion  
         :type domain: `Domain`  
         """  
         self.__domain=domain  
   
     def getDomain(self):  
         """  
         returns the domain of the inversion  
   
         :rtype: `Domain`  
         """  
         return  self.__domain  
   
     def setMapping(self, mapping):  
         """  
         Sets the mapping object to map between model parameters and the data.  
   
         :param mapping: Parameter mapping object  
         :type mapping: `Mapping`  
         """  
         self.__mapping=mapping  
   
     def getMapping(self):  
         """  
         return the mapping(s) used in the inversion  
   
         :rtype: `Mapping`  
         """  
         return self.__mapping  
   
   
     def setRegularization(self, regularization):  
         """  
         Sets the regularization for the inversion.  
   
         :param regularization: regularization  
         :type regularization: `Regularization`  
         """  
         self.__regularization=regularization  
53    
     def getRegularization(self):  
         """  
         returns the regularization method(s)  
   
         :rtype: `Regularization`  
         """  
         return self.__regularization  
54    
55      def setForwardModel(self, forwardmodel):      def setCostFunction(self, costfunction):
56          """          """
57          Sets the forward model(s) for the inversion.          sets the cost function of the inversion. This function needs to be called
58            before the inversion iteration can be started.
59    
60          :param forwardmodel: forward model          :param costfunction: domain of the inversion
61          :type forwardmodel: `ForwardModel`          :type costfunction: 'InversionCostFunction'
62          """          """
63          self.__forwardmodel=forwardmodel          self.__costfunction=costfunction
64    
65      def getForwardModel(self):      def getCostFunction(self):
66          """          """
67          returns the forward model          returns the domain of the inversion
68    
69          :rtype: `ForwardModel`          :rtype: 'InversionCostFunction'
70          """          """
71          return self.__forwardmodel          if self.isSetUp():
72               return  self.__costfunction
73            else:
74           raise RuntimeError("Inversion is not set up.")
75        
76      def setSolverClass(self, solverclass=None):      def setSolverClass(self, solverclass=None):
77          """          """
78          The solver to be used in the inversion process. See the minimizers          The solver to be used in the inversion process. See the minimizers
# Line 132  class InversionBase(object): Line 83  class InversionBase(object):
83              self.solverclass=MinimizerLBFGS              self.solverclass=MinimizerLBFGS
84          else:          else:
85              self.solverclass=solverclass              self.solverclass=solverclass
86                
87        def isSetUp(self):
88            """
89            returns True if the inversion is set up and is ready to run.
90            
91            :rtype: `bool`
92            """
93            if  self.__costfunction:
94            return True
95        else:
96            return False
97            
98        def getDomain(self):
99            """
100            returns the domain of the inversion
101    
102      def setSolverCallback(self, callback):          :rtype: `Domain`
103            """
104            self.getCostFunction().getDomain()
105            
106        def setSolverCallback(self, callback=None):
107          """          """
108          Sets the callback function which is called after every solver          Sets the callback function which is called after every solver
109          iteration.          iteration.
110          """          """
111          self._solver_callback=callback          self._solver_callback=callback
112            
113      def setSolverMaxIterations(self, maxiter):      def setSolverMaxIterations(self, maxiter=None):
114          """          """
115          Sets the maximum number of solver iterations to run.          Sets the maximum number of solver iterations to run.
116          """          """
117            if maxiter == None: maxiter = 200
118          if maxiter>0:          if maxiter>0:
119              self._solver_maxiter=maxiter              self._solver_maxiter=maxiter
120          else:          else:
# Line 156  class InversionBase(object): Line 127  class InversionBase(object):
127          """          """
128          self._solver_opts.update(**opts)          self._solver_opts.update(**opts)
129    
130      def setSolverTolerance(self, tol):      def setSolverTolerance(self, tol=None, atol=None):
131          """      """
132          Sets the error tolerance for the solver. An acceptable solution is      Sets the error tolerance for the solver. An acceptable solution is
133          considered to be found once the tolerance is reached.      considered to be found once the tolerance is reached.
134          """      
135          if tol>0:      :param tol: tolerance for changes to level set function. If `None` changes to the
136              self._solver_xtol=tol              level set function are not checked for convergence during iteration.
137          else:      :param atol: tolerance for changes to cost function. If `None` changes to the
138              raise ValueError("tolerance must be positive.")              cost function are not checked for convergence during iteration.
139        :type tol: `float` or `None`
140      def isSetup(self):      :type atol: `float` or `None`
141          """      :note: if both arguments are equal to `None` the default setting tol=1e-4, atol=None is used.
142          returns True if the inversion is set up and ready to run.  
143          """      """
144          raise NotImplementedError      if tol == None and atol==None:
145            tol=1e-4
146      def run(self, *args):  
147          """      self._solver_xtol=tol
148          This method starts the inversion process and must be overwritten.      self._solver_atol=atol
149          Preferably, users should be able to set a starting point so  
150          inversions can be continued after stopping.  
151          """      def setInitialGuess(self, *args):
152          raise NotImplementedError          """
153            set the initial guess for the inversion iteration. By default zero is used.
154            
155  class SingleParameterInversionBase(InversionBase):          """
156      """          self.initial_value=self.getCostFunction().createLevelSetFunction(*args)
157      Base class for inversions with a single parameter to be found.          
158      """      def run(self):
159      def __init__(self):          """
160          super(SingleParameterInversionBase,self).__init__()          this function runs the inversion.
161          self.setTradeOffFactors()          
162            
     def setTradeOffFactors(self, mu_reg=None, mu_model=1.):  
         """  
         Sets the weighting factors for the cost function.  
         """  
         self.logger.debug("Setting weighting factors...")  
         self.logger.debug("mu_reg = %s"%mu_reg)  
         self.logger.debug("mu_model = %s"%mu_model)  
         self._mu_reg=mu_reg  
         self._mu_model=mu_model  
   
     def siloWriterCallback(self, k, x, fx, gfx):  
163          """          """
164          callback function that can be used to track the solution          if not self.isSetUp():
165                raise RuntimeError("Inversion is not setup.")          
         :param k: iteration count  
         :param x: current m approximation  
         :param fx: value of cost function  
         :param gfx: gradient of f at x  
         """  
         fn='inv.%d'%k  
         ds=createDataset(rho=self.getMapping().getValue(x))  
         ds.setCycleAndTime(k,k)  
         ds.saveSilo(fn)  
         self.logger.debug("f(m) = %e"%fx)  
   
   
     def isSetup(self):  
         if self.getRegularization() and self.getMapping() \  
                 and self.getForwardModel() and self.getDomain():  
             return True  
         else:  
             return False  
   
     def run(self, initial_value=0.):  
         if not self.isSetup():  
             raise RuntimeError("Inversion is not setup properly.")            
         f=SimpleInversionCostFunction(self.getRegularization(), self.getMapping(), self.getForwardModel())  
         f.setTradeOffFactors(mu_reg=self._mu_reg, mu_model=self._mu_model)  
166    
167          solver=self.solverclass(f)          if self.initial_value == None: self.setInitialGuess()
168            solver=self.solverclass(self.getCostFunction())
169          solver.setCallback(self._solver_callback)          solver.setCallback(self._solver_callback)
170          solver.setMaxIterations(self._solver_maxiter)          solver.setMaxIterations(self._solver_maxiter)
171          solver.setOptions(**self._solver_opts)          solver.setOptions(**self._solver_opts)
172          solver.setTolerance(x_tol=self._solver_xtol)          solver.setTolerance(x_tol=self._solver_xtol)
         if not isinstance(initial_value, Data):  
             initial_value=Scalar(initial_value, ContinuousFunction(self.getDomain()))  
         m_init=self.getMapping().getInverse(initial_value)  
   
173          self.logger.info("Starting solver...")          self.logger.info("Starting solver...")
174          try:          try:
175              solver.run(m_init)              solver.run(self.initial_value)
176              self.m=solver.getResult()              self.m=solver.getResult()
177              self.p=self.getMapping().getValue(self.m)              self.p=self.getCostFunction().getProperties(self.m)
178          except MinimizerException as e:          except MinimizerException as e:
179              self.m=solver.getResult()              self.m=solver.getResult()
180              self.p=self.getMapping().getValue(self.m)              self.p=self.getCostFunction().getProperties(self.m)
181                self.logger.info("iteration failed.")
182              raise e              raise e
183          self.logger.info("result* = %s"%self.p)          self.logger.info("iteration completed.")
184          solver.logSummary()          solver.logSummary()
185          return self.p          return self.p
186    
187        def setup(self, *args, **k_args):
188            """
189            returns True if the inversion is set up and ready to run.
190            """
191            pass
192    
193  class GravityInversion(SingleParameterInversionBase):  class GravityInversion(InversionBase):
194      """      """
195      Inversion of Gravity (Bouguer) anomaly data.      Inversion of Gravity (Bouguer) anomaly data.
196      """      """
# Line 267  class GravityInversion(SingleParameterIn Line 206  class GravityInversion(SingleParameterIn
206          :type rho0: ``float`` or `Scalar`          :type rho0: ``float`` or `Scalar`
207          """          """
208          self.logger.info('Retrieving domain...')          self.logger.info('Retrieving domain...')
209          self.setDomain(domainbuilder.getDomain())          dom=domainbuilder.getDomain()
210          DIM=self.getDomain().getDim()          DIM=dom.getDim()
211          #========================          #========================
212          self.logger.info('Creating mapping...')          self.logger.info('Creating mapping...')
213          self.setMapping(DensityMapping(self.getDomain(), rho0=rho0, drho=drho, z0=z0, beta=beta))          rho_mapping=DensityMapping(dom, rho0=rho0, drho=drho, z0=z0, beta=beta)
214          scale_mapping=self.getMapping().getTypicalDerivative()          scale_mapping=rho_mapping.getTypicalDerivative()
215          print " scale_mapping = ",scale_mapping          print " scale_mapping = ",scale_mapping
216          #========================          #========================
217          self.logger.info("Setting up regularization...")          self.logger.info("Setting up regularization...")
218          if w1 is None:          if w1 is None:
219              w1=[1.]*DIM              w1=[1.]*DIM
220          rho_mask = domainbuilder.getSetDensityMask()          rho_mask = domainbuilder.getSetDensityMask()
221          self.setRegularization(Regularization(self.getDomain(), numLevelSets=1,\          regularization=Regularization(dom, numLevelSets=1,\
222                                 w0=w0, w1=w1, location_of_set_m=rho_mask))                                 w0=w0, w1=w1, location_of_set_m=rho_mask)
223          #====================================================================          #====================================================================
224          self.logger.info("Retrieving gravity surveys...")          self.logger.info("Retrieving gravity surveys...")
225          surveys=domainbuilder.getGravitySurveys()          surveys=domainbuilder.getGravitySurveys()
# Line 301  class GravityInversion(SingleParameterIn Line 240  class GravityInversion(SingleParameterIn
240          #====================================================================          #====================================================================
241    
242          self.logger.info("Setting up model...")          self.logger.info("Setting up model...")
243          self.setForwardModel(GravityModel(self.getDomain(), w, g))          forward_model=GravityModel(dom, w, g)
244          self.getForwardModel().rescaleWeights(rho_scale=scale_mapping)          forward_model.rescaleWeights(rho_scale=scale_mapping)
         # this is switched off for now:  
         if self._mu_reg is None and False:  
             x=self.getDomain().getX()  
             l=0.  
             for i in range(DIM-1):  
                 l=max(l, sup(x[i])-inf(x[i]))  
             G=U.Gravitational_Constant  
             mu_reg=0.5*(l*l*G)**2  
             self.setTradeOffFactors(mu_reg=mu_reg)  
245    
246    
247            #====================================================================
248            self.logger.info("Setting cost function...")
249            self.setCostFunction(InversionCostFunction(regularization, rho_mapping, forward_model))
250            
251        def setInitialGuess(self, rho=None):
252            """
253            set the initial guess *rho* for density the inversion iteration. If no *rho* present
254            then an appropriate initial guess is chosen.
255            
256            :param rho: initial value for the density anomaly.
257            :type rho: `Scalar`
258            """
259            if rho:
260            super(GravityInversion,self).setInitialGuess(rho)
261        else:
262            super(GravityInversion,self).setInitialGuess()
263            
264        def siloWriterCallback(self, k, m, Jm, g_Jm):
265            """
266            callback function that can be used to track the solution
267    
268  class MagneticInversion(SingleParameterInversionBase):          :param k: iteration count
269            :param m: current m approximation
270            :param Jm: value of cost function
271            :param g_Jm: gradient of f at x
272            """
273            fn='inv.%d'%k
274            ds=createDataset(rho=self.getCostFunction().mappings[0].getValue(m))
275            ds.setCycleAndTime(k,k)
276            ds.saveSilo(fn)
277            self.logger.debug("J(m) = %e"%Jm)
278    
279    class MagneticInversion(InversionBase):
280      """      """
281      Inversion of magnetic data.      Inversion of magnetic data.
282      """      """
# Line 329  class MagneticInversion(SingleParameterI Line 291  class MagneticInversion(SingleParameterI
291    
292          """          """
293          self.logger.info('Retrieving domain...')          self.logger.info('Retrieving domain...')
294          self.setDomain(domainbuilder.getDomain())          dom=domainbuilder.getDomain()
295          DIM=self.getDomain().getDim()          DIM=dom.getDim()
296    
297          #========================          #========================
298          self.logger.info('Creating mapping ...')          self.logger.info('Creating mapping ...')
299          self.setMapping(SusceptibilityMapping(self.getDomain(), k0=k0, dk=dk, z0=z0, beta=beta))          susc_mapping=SusceptibilityMapping(dom, k0=k0, dk=dk, z0=z0, beta=beta)
300          scale_mapping=self.getMapping().getTypicalDerivative()          scale_mapping=susc_mapping.getTypicalDerivative()
301          print " scale_mapping = ",scale_mapping          print " scale_mapping = ",scale_mapping
302          #========================          #========================
303          self.logger.info("Setting up regularization...")          self.logger.info("Setting up regularization...")
304          if w1 is None:          if w1 is None:
305              w1=[1.]*DIM              w1=[1.]*DIM
306          k_mask = domainbuilder.getSetSusceptibilityMask()          k_mask = domainbuilder.getSetSusceptibilityMask()
307          self.setRegularization(Regularization(self.getDomain(), numLevelSets=1,\          regularization=Regularization(dom, numLevelSets=1,w0=w0, w1=w1, location_of_set_m=k_mask)
                                w0=w0, w1=w1, location_of_set_m=k_mask))  
308    
309          #====================================================================          #====================================================================
310          self.logger.info("Retrieving magnetic field surveys...")          self.logger.info("Retrieving magnetic field surveys...")
# Line 364  class MagneticInversion(SingleParameterI Line 325  class MagneticInversion(SingleParameterI
325              self.logger.debug("w = %s"%w_i)              self.logger.debug("w = %s"%w_i)
326          #====================================================================          #====================================================================
327          self.logger.info("Setting up model...")          self.logger.info("Setting up model...")
328          self.setForwardModel(MagneticModel(self.getDomain(), w, B, domainbuilder.getBackgroundMagneticFluxDensity()))          forward_model=MagneticModel(dom, w, B, domainbuilder.getBackgroundMagneticFluxDensity())
329          self.getForwardModel().rescaleWeights(k_scale=scale_mapping)          forward_model.rescaleWeights(k_scale=scale_mapping)
330          # this is switched off for now:  
331          if self._mu_reg is None and False:          #====================================================================
332              x=self.getDomain().getX()          self.logger.info("Setting cost function...")
333              l=0.          self.setCostFunction(InversionCostFunction(regularization, susc_mapping, forward_model))
334              for i in range(DIM-1):          
335                  l=max(l, sup(x[i])-inf(x[i]))      def setInitialGuess(self, k=None):
336              mu_reg=0.5*l**2          """
337              self.setTradeOffFactors(mu_reg=mu_reg)          set the initial guess *k* for susceptibility for the inversion iteration. If no *k* present
338            then an appropriate initial guess is chosen.
339            
340            :param k: initial value for the susceptibility anomaly.
341            :type k: `Scalar`
342            """
343            if k:
344            super(MagneticInversion,self).setInitialGuess(k)
345        else:
346            super(MagneticInversion,self).setInitialGuess()
347            
348        def siloWriterCallback(self, k, m, Jm, g_Jm):
349            """
350            callback function that can be used to track the solution
351    
352            :param k: iteration count
353            :param m: current m approximation
354            :param Jm: value of cost function
355            :param g_Jm: gradient of f at x
356            """
357            fn='inv.%d'%k
358            ds=createDataset(susceptibility=self.getCostFunction().mappings[0].getValue(m))
359            ds.setCycleAndTime(k,k)
360            ds.saveSilo(fn)
361            self.logger.debug("J(m) = %e"%Jm)
362            

Legend:
Removed from v.4121  
changed lines
  Added in v.4122

  ViewVC Help
Powered by ViewVC 1.1.26