/[escript]/trunk/downunder/py_src/inversions.py
ViewVC logotype

Contents of /trunk/downunder/py_src/inversions.py

Parent Directory | Revision Log


Revision 3947 - (show annotations)
Wed Aug 22 23:19:10 2012 UTC (7 years ago) by caltinay
File MIME type: text/x-python
File size: 7031 byte(s)
Compiling and installing downunder module now. Adjusted import statements
accordingly. Added a gravity test run.

1
2 ########################################################
3 #
4 # Copyright (c) 2003-2012 by University of Queensland
5 # Earth Systems Science Computational Center (ESSCC)
6 # http://www.uq.edu.au/esscc
7 #
8 # Primary Business: Queensland, Australia
9 # Licensed under the Open Software License version 3.0
10 # http://www.opensource.org/licenses/osl-3.0.php
11 #
12 ########################################################
13
14 __copyright__="""Copyright (c) 2003-2012 by University of Queensland
15 Earth Systems Science Computational Center (ESSCC)
16 http://www.uq.edu.au/esscc
17 Primary Business: Queensland, Australia"""
18 __license__="""Licensed under the Open Software License version 3.0
19 http://www.opensource.org/licenses/osl-3.0.php"""
20 __url__="https://launchpad.net/escript-finley"
21
import logging
import os

from esys.escript import *
from esys.weipa import createDataset, saveSilo

from costfunctions import SimpleCostFunction
from forwardmodels import GravityModel
from mappings import *
from minimizers import *
from regularizations import Regularization
32
class InversionBase(object):
    """
    Base class for inversions.

    Stores the configuration shared by concrete inversions (solver class,
    solver tolerance, iteration limit and extra solver options, output
    directory, data source and mapping). Subclasses must implement `run()`.
    """
    def __init__(self):
        # The logger is named 'inv.<ClassName>', i.e. a child of the 'inv'
        # logger, so handlers attached to 'inv' also receive these records.
        self.logger = logging.getLogger('inv.%s' % self.__class__.__name__)
        self._solver_opts = {}
        self._solver_tol = 1e-9
        self._solver_maxiter = 200
        self._output_dir = '.'
        # use identity mapping by default
        self.mapping = ScalingMapping(1)
        self.source = None
        self.solverclass = MinimizerLBFGS

    def setOutputDirectory(self, outdir):
        """
        Sets the directory that output files are written to.
        """
        self._output_dir = outdir

    def setSolverMaxIterations(self, maxiter):
        """
        Sets the maximum number of iterations handed to the solver.
        """
        self._solver_maxiter = maxiter

    def setSolverOptions(self, **opts):
        """
        Merges the given keyword arguments into the stored solver options.

        Options set by earlier calls are kept unless overwritten by a key of
        the same name.
        """
        self._solver_opts.update(**opts)

    def setSolverTolerance(self, tol):
        """
        Sets the tolerance handed to the solver.
        """
        self._solver_tol = tol

    def setSolverClass(self, solverclass):
        """
        Sets the class that is instantiated as the solver.
        """
        self.solverclass = solverclass

    def setDataSource(self, source):
        """
        Sets the data source object the inversion reads its input from.
        """
        self.source = source

    def setMapping(self, mapping):
        """
        Sets the mapping object, replacing the default `ScalingMapping(1)`.
        """
        self.mapping = mapping

    def run(self):
        """
        Runs the inversion. Must be overridden by subclasses.

        :raise NotImplementedError: always, in this base class
        """
        raise NotImplementedError
71
72
class GravityInversion(InversionBase):
    """
    Inversion that recovers a density distribution from gravity data.

    `run()` combines a `Regularization`, the configured mapping and a
    `GravityModel` into a `SimpleCostFunction` and minimizes it with the
    configured solver class.
    """
    def __init__(self):
        super(GravityInversion, self).__init__()
        # start with the default weights (both 1.)
        self.setWeights()

    def setWeights(self, mu_reg=1., mu_model=1.):
        """
        Sets the weights that `run()` passes on to the cost function.

        :param mu_reg: forwarded as ``mu_reg`` to the cost function's
                       `setWeights`
        :param mu_model: forwarded as ``mu_model`` to the cost function's
                         `setWeights`
        """
        self._mu_reg = mu_reg
        self._mu_model = mu_model

    def solverCallback(self, k, x, fx, gfx):
        """
        Callback registered with the solver in `run()`.

        Saves the mapped value of `x` as ``rho`` to a Silo dataset named
        ``inv.<k>`` in the output directory and logs the regularization value
        and `fx` at debug level. Uses `self.regularization`, which is only
        set by `run()`.

        :param k: used as file name suffix and as cycle/time of the dataset
        :param x: current solution, passed through `self.mapping.getValue`
        :param fx: value logged as ``f(m)``
        :param gfx: not used by this callback
        """
        fn = os.path.join(self._output_dir, 'inv.%d' % k)
        ds = createDataset(rho=self.mapping.getValue(x))
        ds.setCycleAndTime(k, k)
        ds.saveSilo(fn)
        self.logger.debug("Jreg(m) = %e" % self.regularization.getValue(x))
        self.logger.debug("f(m) = %e" % fx)

    def run(self):
        """
        Runs the gravity inversion.

        :return: the mapped solver result (``rho*``)
        :raise ValueError: if no data source has been set
        """
        if self.source is None:
            raise ValueError("No data source set!")

        self.logger.info('Retrieving domain...')
        domain = self.source.getDomain()
        DIM = domain.getDim()
        self.logger.info("Retrieving density mask...")
        rho_mask = self.source.getDensityMask()
        self.logger.info("Retrieving gravity and standard deviation data...")
        g, sigma = self.source.getGravityAndStdDev()
        # chi = 1/sigma^2; safeDiv is used for the division
        # (presumably to cope with sigma == 0 -- confirm against escript)
        chi = safeDiv(1., sigma*sigma)
        chi = interpolate(chi, Function(domain))
        g = interpolate(g, Function(domain))
        self.logger.debug("g = %s" % g)
        self.logger.debug("sigma = %s" % sigma)
        self.logger.debug("chi = %s" % chi)
        # NOTE(review): looks like this restricts chi to the last (vertical)
        # component -- confirm
        chi = chi*kronecker(DIM)[DIM-1]
        m_ref = self.mapping.getInverse(0.)
        self.regularization = Regularization(domain, m_ref=m_ref, w0=0,
                                             w=[1]*DIM,
                                             location_of_set_m=rho_mask)
        self.forwardmodel = GravityModel(domain, chi, g)
        self.f = SimpleCostFunction(self.regularization, self.mapping,
                                    self.forwardmodel)
        self.f.setWeights(mu_reg=self._mu_reg, mu_model=self._mu_model)
        solver = self.solverclass(self.f)
        solver.setTolerance(self._solver_tol)
        solver.setMaxIterations(self._solver_maxiter)
        solver.setOptions(**self._solver_opts)
        self.logger.info("Starting solver...")
        # initial guess: zero density, transformed by the inverse mapping
        rho_init = Scalar(0, ContinuousFunction(domain))
        m_init = self.mapping.getInverse(rho_init)

        solver.setCallback(self.solverCallback)
        args = {'rho_mask': rho_mask, 'g': g[DIM-1], 'chi': chi[DIM-1],
                'sigma': sigma}
        # The reference density is optional: if the source cannot provide it,
        # it is left out of the output file. Only ordinary errors are
        # tolerated so that KeyboardInterrupt/SystemExit still propagate.
        try:
            args['rho_ref'] = self.source.getReferenceDensity()
        except Exception as e:
            self.logger.debug("No reference density available: %s" % e)
        saveSilo(os.path.join(self._output_dir, 'ref'), **args)
        solver.run(m_init)
        m_star = solver.getResult()
        self.logger.info("m* = %s" % m_star)
        rho_star = self.mapping.getValue(m_star)
        self.logger.info("rho* = %s" % rho_star)
        solver.logSummary()
        return rho_star
137
138
# Example driver: runs a gravity inversion on a synthetic data source using
# the parameters collected in the dictionary `p`.
if __name__ == "__main__":
    from esys.escript import unitsSI as U
    from datasources import *

    # Run parameters. An optional 'MU' entry overrides the regularization
    # weight that is otherwise computed from the domain extent below.
    p = {
        'PADDING_L': 5,
        'PADDING_H': 0.2,
        'TOLERANCE': 1e-9,
        'MAX_ITER': 200,
        'VERBOSITY': 5,
        'OUTPUT_DIR': '.',
        'LOGFILE': '',
        'SOURCE': SyntheticDataSource,
        'SOLVER_OPTS': {
            'initialHessian': 100
        },
        'ARGS': {
            'DIM': 2,
            'NE': 40,
            'l': 500*U.km,
            'h': 60*U.km,
            'features': [
                SmoothAnomaly(lx=50*U.km, ly=20*U.km, lz=40*U.km, x=100*U.km, y=3*U.km, depth=25*U.km, rho_inner=200., rho_outer=1e-6),
                SmoothAnomaly(lx=50*U.km, ly=20*U.km, lz=40*U.km, x=400*U.km, y=1*U.km, depth=40*U.km, rho_inner=-200, rho_outer=1e-6)
            ]
        }
    }

    # 1..5 -> 50..10
    # (verbosity is clamped to 1..5 and converted to a logging level)
    loglevel = 60-10*max(1, min(5, p['VERBOSITY']))
    formatter = logging.Formatter('[%(name)s] \033[1;30m%(message)s\033[0m')
    logger = logging.getLogger('inv')
    logger.setLevel(loglevel)
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    handler.setLevel(loglevel)
    logger.addHandler(handler)
    # additionally log to a file if a non-blank file name was given
    if p['LOGFILE'].strip():
        handler = logging.FileHandler(os.path.join(p['OUTPUT_DIR'], p['LOGFILE']))
        formatter = logging.Formatter('%(asctime)s - [%(name)s] %(message)s')
        handler.setFormatter(formatter)
        handler.setLevel(loglevel)
        logger.addHandler(handler)

    # dump all parameters when debugging
    if logger.isEnabledFor(logging.DEBUG):
        for k in sorted(p):
            logger.debug("%s = %s" % (k, p[k]))
    source = p['SOURCE'](**p['ARGS'])
    source.setPadding(p['PADDING_L'], p['PADDING_H'])
    inv = GravityInversion()
    inv.setDataSource(source)
    inv.setOutputDirectory(p['OUTPUT_DIR'])
    inv.setSolverTolerance(p['TOLERANCE'])
    inv.setSolverMaxIterations(p['MAX_ITER'])
    inv.setSolverOptions(**p['SOLVER_OPTS'])
    # 'in' instead of dict.has_key(), which no longer exists in Python 3
    if 'MU' in p:
        mu = p['MU']
    else:
        # derive the weight from the larger of the first two domain extents
        logger.info('Generating domain...')
        x = source.getDomain().getX()
        l0 = sup(x[0])-inf(x[0])
        l1 = sup(x[1])-inf(x[1])
        l = max(l0, l1)
        G = 6.6742e-11
        mu = 0.5*(l**2*G)**2
        logger.debug("MU = %s" % mu)

    inv.setWeights(mu_reg=mu)
    #mapping=BoundedRangeMapping(-200, 200)
    #inv.setMapping(mapping)
    rho_new = inv.run()
209

  ViewVC Help
Powered by ViewVC 1.1.26