/[escript]/trunk/downunder/py_src/inversions.py
ViewVC logotype

Contents of /trunk/downunder/py_src/inversions.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3948 - (show annotations)
Fri Aug 24 01:01:34 2012 UTC (6 years, 8 months ago) by caltinay
File MIME type: text/x-python
File size: 6653 byte(s)
Some interface changes to make the inversion more flexible.

1
2 ########################################################
3 #
4 # Copyright (c) 2003-2012 by University of Queensland
5 # Earth Systems Science Computational Center (ESSCC)
6 # http://www.uq.edu.au/esscc
7 #
8 # Primary Business: Queensland, Australia
9 # Licensed under the Open Software License version 3.0
10 # http://www.opensource.org/licenses/osl-3.0.php
11 #
12 ########################################################
13
14 __copyright__="""Copyright (c) 2003-2012 by University of Queensland
15 Earth Systems Science Computational Center (ESSCC)
16 http://www.uq.edu.au/esscc
17 Primary Business: Queensland, Australia"""
18 __license__="""Licensed under the Open Software License version 3.0
19 http://www.opensource.org/licenses/osl-3.0.php"""
20 __url__="https://launchpad.net/escript-finley"
21
22 import logging
23
24 from esys.escript import *
25 from esys.weipa import createDataset
26
27 from costfunctions import SimpleCostFunction
28 from forwardmodels import GravityModel
29 from mappings import *
30 from minimizers import *
31 from regularizations import Regularization
32
class InversionBase(object):
    """
    Common foundation for inversion drivers.

    Stores the solver configuration (callback, extra options, tolerance,
    iteration limit and solver class) together with the mapping and the data
    source. Concrete inversions must implement `setup` and `run`.
    """
    def __init__(self):
        # solver related settings
        self.solverclass = MinimizerLBFGS
        self._solver_maxiter = 200
        self._solver_tol = 1e-9
        self._solver_opts = {}
        self._solver_callback = None
        # no data source until setDataSource() is called
        self.source = None
        # identity mapping (scaling by 1) unless replaced via setMapping()
        self.mapping = ScalingMapping(1)
        # one logger per concrete class, named 'inv.<ClassName>'
        self.logger = logging.getLogger('inv.' + self.__class__.__name__)

    def setSolverCallback(self, callback):
        """
        Registers the function that is handed to the solver as its callback
        (invoked after every solver iteration).
        """
        self._solver_callback = callback

    def setSolverMaxIterations(self, maxiter):
        """
        Limits the number of iterations the solver is allowed to perform.
        """
        self._solver_maxiter = maxiter

    def setSolverOptions(self, **opts):
        """
        Adds further solver options given as keyword arguments. Options set
        by earlier calls are kept unless overridden. Which options are valid
        depends on the solver in use.
        """
        self._solver_opts.update(opts)

    def setSolverTolerance(self, tol):
        """
        Sets the solver's error tolerance. The solution is accepted as soon
        as this tolerance is reached.
        """
        self._solver_tol = tol

    def setSolverClass(self, solverclass):
        """
        Selects the solver class used for the inversion. Available solvers
        are found in the minimizers module; the L-BFGS minimizer is the
        default.
        """
        self.solverclass = solverclass

    def setDataSource(self, source):
        """
        Sets the data source that supplies the survey data to be inverted.
        """
        self.source = source

    def setMapping(self, mapping):
        """
        Sets the mapping between model parameters and the data. Without a
        call to this method an identity mapping (ScalingMapping with constant
        1) is used.
        """
        self.mapping = mapping

    def setup(self):
        """
        Prepares everything the solver needs in order to run. Subclasses
        must override this method and create the objects relevant for the
        inversion (e.g. forward model, regularization etc.).
        """
        raise NotImplementedError

    def run(self, *args):
        """
        Starts the inversion process. Subclasses must override this method;
        preferably they allow a starting point to be passed in so that a
        stopped inversion can be continued.
        """
        raise NotImplementedError
111
112
class GravityInversion(InversionBase):
    """
    Inversion driver that recovers a density distribution from gravity data.

    The survey data, the domain and the density mask are obtained from the
    data source set via `setDataSource`. `setup` combines a `Regularization`,
    the mapping and a `GravityModel` into a `SimpleCostFunction`, and `run`
    minimizes that cost function with the configured solver class.
    """
    def __init__(self):
        super(GravityInversion,self).__init__()
        self.__is_setup=False
        self.setWeights()

    def setWeights(self, mu_reg=None, mu_model=1.):
        """
        Sets the weighting factors passed on to the cost function.

        :param mu_reg: weight of the regularization term. If ``None`` a value
                       is derived in `setup` from the largest extent of the
                       domain over its first DIM-1 coordinate directions.
        :param mu_model: weight of the forward model term
        """
        self._mu_reg=mu_reg
        self._mu_model=mu_model
        # The weights reach the cost function only inside setup(), so a
        # change made after setup() must trigger a new setup on the next
        # run(); otherwise it would be silently ignored.
        self.__is_setup=False

    def siloWriterCallback(self, k, x, fx, gfx):
        """
        Solver callback that writes the current density to a silo file and
        logs the regularization and cost function values at debug level.
        Requires `setup` to have been run (uses ``self.regularization``).

        :param k: iteration counter; used for the file name ``inv.<k>`` and
                  as both cycle and time of the dataset
        :param x: current model parameter, converted to a density through
                  the mapping
        :param fx: value of the cost function at ``x`` (only logged)
        :param gfx: not used by this callback
        """
        fn='inv.%d'%k
        ds=createDataset(rho=self.mapping.getValue(x))
        ds.setCycleAndTime(k,k)
        ds.saveSilo(fn)
        # arguments are handed to the logger so the message is only
        # formatted when debug output is actually enabled
        self.logger.debug("Jreg(m) = %e", self.regularization.getValue(x))
        self.logger.debug("f(m) = %e", fx)

    def setup(self):
        """
        Creates the regularization, the forward model and the cost function
        from the data source and applies the weights.

        :raise ValueError: if no data source has been set
        """
        if self.source is None:
            raise ValueError("No data source set!")

        self.logger.info('Retrieving domain...')
        domain=self.source.getDomain()
        DIM=domain.getDim()
        self.logger.info("Retrieving density mask...")
        rho_mask = self.source.getDensityMask()
        self.logger.info("Retrieving gravity and standard deviation data...")
        g, sigma=self.source.getGravityAndStdDev()
        # data weighting 1/sigma^2; safeDiv is used instead of a plain
        # division (presumably to cope with sigma==0 -- confirm in escript)
        chi=safeDiv(1., sigma*sigma)
        chi=interpolate(chi, Function(domain))
        g=interpolate(g, Function(domain))
        self.logger.debug("g = %s", g)
        self.logger.debug("sigma = %s", sigma)
        self.logger.debug("chi = %s", chi)
        # scale by the last row of kronecker(DIM) so that only the last
        # coordinate direction carries the weighting
        chi=chi*kronecker(DIM)[DIM-1]
        # reference model parameter is the one that maps to zero density
        m_ref=self.mapping.getInverse(0.)
        self.regularization=Regularization(domain, m_ref=m_ref, w0=0, w=[1]*DIM, location_of_set_m=rho_mask)
        self.forwardmodel=GravityModel(domain, chi, g)
        self.f=SimpleCostFunction(self.regularization, self.mapping, self.forwardmodel)
        if self._mu_reg is None:
            # no regularization weight given: derive one from the largest
            # extent of the domain over the first DIM-1 coordinates
            x=domain.getX()
            l=0
            for i in range(DIM-1):
                l=max(l, sup(x[i])-inf(x[i]))
            # presumably the gravitational constant in SI units
            G=6.6742e-11
            self._mu_reg=0.5*(l*l*G)**2
        self.logger.debug("mu_reg = %s", self._mu_reg)
        self.logger.debug("mu_model = %s", self._mu_model)
        self.f.setWeights(mu_reg=self._mu_reg, mu_model=self._mu_model)
        self.__is_setup=True

    def run(self, rho_init=0.):
        """
        Runs the inversion and returns the resulting density. `setup` is
        called first if it has not been run yet or if the weights were
        changed since.

        :param rho_init: initial density. Anything that is not an escript
                         ``Data`` object is turned into a ``Scalar`` on the
                         continuous function space of the source domain.
        :return: the density obtained by applying the mapping to the
                 solver result
        :raise ValueError: if no data source has been set
        """
        if not self.__is_setup:
            self.setup()
        solver=self.solverclass(self.f)
        solver.setCallback(self._solver_callback)
        solver.setMaxIterations(self._solver_maxiter)
        solver.setOptions(**self._solver_opts)
        solver.setTolerance(self._solver_tol)
        if not isinstance(rho_init, Data):
            rho_init=Scalar(rho_init, ContinuousFunction(self.source.getDomain()))
        # the solver works on the model parameter, not on the density
        m_init=self.mapping.getInverse(rho_init)

        self.logger.info("Starting solver...")
        solver.run(m_init)
        m_star=solver.getResult()
        self.logger.info("m* = %s", m_star)
        rho_star=self.mapping.getValue(m_star)
        self.logger.info("rho* = %s", rho_star)
        solver.logSummary()
        return rho_star
194

  ViewVC Help
Powered by ViewVC 1.1.26