1 |
|
2 |
######################################################## |
3 |
# |
4 |
# Copyright (c) 2003-2012 by University of Queensland |
5 |
# Earth Systems Science Computational Center (ESSCC) |
6 |
# http://www.uq.edu.au/esscc |
7 |
# |
8 |
# Primary Business: Queensland, Australia |
9 |
# Licensed under the Open Software License version 3.0 |
10 |
# http://www.opensource.org/licenses/osl-3.0.php |
11 |
# |
12 |
######################################################## |
13 |
|
14 |
__copyright__="""Copyright (c) 2003-2012 by University of Queensland |
15 |
Earth Systems Science Computational Center (ESSCC) |
16 |
http://www.uq.edu.au/esscc |
17 |
Primary Business: Queensland, Australia""" |
18 |
__license__="""Licensed under the Open Software License version 3.0 |
19 |
http://www.opensource.org/licenses/osl-3.0.php""" |
20 |
__url__="https://launchpad.net/escript-finley" |
21 |
|
22 |
import logging
import os

from esys.escript import *
from esys.weipa import createDataset, saveSilo

from costfunctions import SimpleCostFunction
from forwardmodels import GravityModel
from mappings import *
from minimizers import *
from regularizations import Regularization
32 |
|
33 |
class InversionBase(object):
    """
    Base class for inversions.

    Stores the configuration shared by all inversion drivers: the output
    directory, the solver class together with its tolerance, iteration
    limit and extra options, the data source, and the parameter mapping.
    Concrete subclasses must implement `run`.
    """

    def __init__(self):
        # One logger per concrete class, below the 'inv' logger hierarchy.
        cls_name = type(self).__name__
        self.logger = logging.getLogger('inv.' + cls_name)
        # Solver settings; subclasses read these when building the solver.
        self._solver_opts = dict()
        self._solver_tol = 1e-9
        self._solver_maxiter = 200
        self._output_dir = '.'
        # Default mapping: ScalingMapping(1), i.e. the identity mapping.
        self.mapping = ScalingMapping(1)
        self.source = None
        self.solverclass = MinimizerLBFGS

    def setOutputDirectory(self, outdir):
        """
        Sets the directory used for output files.

        :param outdir: directory path, stored as ``_output_dir``
        """
        self._output_dir = outdir

    def setSolverMaxIterations(self, maxiter):
        """
        Sets the maximum number of solver iterations.

        :param maxiter: iteration limit, stored as ``_solver_maxiter``
        """
        self._solver_maxiter = maxiter

    def setSolverOptions(self, **opts):
        """
        Merges keyword options into the solver options.

        Options given in later calls override earlier values for the same
        key; keys not mentioned are kept.
        """
        self._solver_opts.update(opts)

    def setSolverTolerance(self, tol):
        """
        Sets the solver tolerance.

        :param tol: tolerance, stored as ``_solver_tol``
        """
        self._solver_tol = tol

    def setSolverClass(self, solverclass):
        """
        Sets the class used to construct the minimizer.

        :param solverclass: callable/class stored as ``solverclass``
        """
        self.solverclass = solverclass

    def setDataSource(self, source):
        """
        Sets the data source object the inversion reads its input from.
        """
        self.source = source

    def setMapping(self, mapping):
        """
        Sets the mapping between the model function and the physical
        parameter.
        """
        self.mapping = mapping

    def run(self):
        """
        Runs the inversion. Must be overridden by subclasses.

        :raise NotImplementedError: always, in this base class
        """
        raise NotImplementedError
71 |
|
72 |
|
73 |
class GravityInversion(InversionBase):
    """
    Gravity inversion for density.

    Reads the domain, a density mask and gravity data with standard
    deviations from the data source, builds a `SimpleCostFunction` from a
    `Regularization` and a `GravityModel`, and minimizes it with the
    configured solver class. Each iterate is written to the output
    directory by `solverCallback`.
    """

    def __init__(self):
        super(GravityInversion, self).__init__()
        self.setWeights()

    def setWeights(self, mu_reg=1., mu_model=1.):
        """
        Sets the weighting factors forwarded to the cost function.

        :param mu_reg: passed as ``mu_reg`` to ``SimpleCostFunction.setWeights``
                       (presumably the regularization weight)
        :param mu_model: passed as ``mu_model`` to
                         ``SimpleCostFunction.setWeights`` (presumably the
                         data-misfit weight)
        """
        self._mu_reg = mu_reg
        self._mu_model = mu_model

    def solverCallback(self, k, x, fx, gfx):
        """
        Callback registered with the solver in `run`.

        Writes the density ``mapping.getValue(x)`` to the Silo file
        ``inv.<k>`` in the output directory, using ``k`` as cycle and time.
        At debug level it also logs the regularization value and ``fx``.

        :param k: iteration counter supplied by the solver
        :param x: current model function
        :param fx: cost function value at ``x``
        :param gfx: gradient at ``x`` (not used here)
        """
        fn = os.path.join(self._output_dir, 'inv.%d' % k)
        ds = createDataset(rho=self.mapping.getValue(x))
        ds.setCycleAndTime(k, k)
        ds.saveSilo(fn)
        # Evaluating the regularization term is not free; only do it when
        # the debug message will actually be emitted.
        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug("Jreg(m) = %e", self.regularization.getValue(x))
            self.logger.debug("f(m) = %e", fx)

    def run(self):
        """
        Runs the gravity inversion.

        Also writes the reference fields (mask, gravity, weighting, sigma and,
        if available, the reference density) to ``ref`` in the output
        directory before starting the solver.

        :return: the recovered density ``mapping.getValue(m*)``
        :raise ValueError: if no data source has been set
        """
        if self.source is None:
            raise ValueError("No data source set!")

        self.logger.info('Retrieving domain...')
        domain = self.source.getDomain()
        DIM = domain.getDim()
        self.logger.info("Retrieving density mask...")
        rho_mask = self.source.getDensityMask()
        self.logger.info("Retrieving gravity and standard deviation data...")
        g, sigma = self.source.getGravityAndStdDev()
        # Data weighting 1/sigma^2; safeDiv guards against sigma == 0.
        chi = safeDiv(1., sigma*sigma)
        chi = interpolate(chi, Function(domain))
        g = interpolate(g, Function(domain))
        # Lazy %-args: the (potentially large) Data objects are only
        # converted to strings when debug output is enabled.
        self.logger.debug("g = %s", g)
        self.logger.debug("sigma = %s", sigma)
        self.logger.debug("chi = %s", chi)
        # Keep the weighting only in the last coordinate direction
        # (presumably the vertical component of gravity).
        chi = chi*kronecker(DIM)[DIM-1]

        m_ref = self.mapping.getInverse(0.)
        self.regularization = Regularization(domain, m_ref=m_ref, w0=0,
                                             w=[1]*DIM,
                                             location_of_set_m=rho_mask)
        self.forwardmodel = GravityModel(domain, chi, g)
        self.f = SimpleCostFunction(self.regularization, self.mapping,
                                    self.forwardmodel)
        self.f.setWeights(mu_reg=self._mu_reg, mu_model=self._mu_model)

        solver = self.solverclass(self.f)
        solver.setTolerance(self._solver_tol)
        solver.setMaxIterations(self._solver_maxiter)
        solver.setOptions(**self._solver_opts)
        self.logger.info("Starting solver...")
        # Initial guess: zero density, expressed in model space.
        rho_init = Scalar(0, ContinuousFunction(domain))
        m_init = self.mapping.getInverse(rho_init)

        solver.setCallback(self.solverCallback)
        args = {'rho_mask': rho_mask, 'g': g[DIM-1], 'chi': chi[DIM-1],
                'sigma': sigma}
        try:
            args['rho_ref'] = self.source.getReferenceDensity()
        except Exception as e:
            # The reference density is optional output (presumably not
            # every data source provides one); continue without it, but
            # leave a trace instead of silently swallowing the error.
            self.logger.debug("No reference density available: %s", e)
        saveSilo(os.path.join(self._output_dir, 'ref'), **args)

        solver.run(m_init)
        m_star = solver.getResult()
        self.logger.info("m* = %s", m_star)
        rho_star = self.mapping.getValue(m_star)
        self.logger.info("rho* = %s", rho_star)
        solver.logSummary()
        return rho_star
137 |
|
138 |
|
139 |
if __name__ == "__main__":
    # Example driver: inverts synthetic gravity data generated from two
    # smooth density anomalies.
    from esys.escript import unitsSI as U
    from datasources import *

    p = {
        'PADDING_L': 5,
        'PADDING_H': 0.2,
        'TOLERANCE': 1e-9,
        'MAX_ITER': 200,
        'VERBOSITY': 5,
        'OUTPUT_DIR': '.',
        'LOGFILE': '',
        'SOURCE': SyntheticDataSource,
        'SOLVER_OPTS': {
            'initialHessian': 100,
        },
        'ARGS': {
            'DIM': 2,
            'NE': 40,
            'l': 500*U.km,
            'h': 60*U.km,
            'features': [
                SmoothAnomaly(lx=50*U.km, ly=20*U.km, lz=40*U.km,
                              x=100*U.km, y=3*U.km, depth=25*U.km,
                              rho_inner=200., rho_outer=1e-6),
                SmoothAnomaly(lx=50*U.km, ly=20*U.km, lz=40*U.km,
                              x=400*U.km, y=1*U.km, depth=40*U.km,
                              rho_inner=-200, rho_outer=1e-6),
            ],
        },
    }

    # Map VERBOSITY 1..5 (clamped) to logging levels 50..10.
    loglevel = 60 - 10*max(1, min(5, p['VERBOSITY']))
    logger = logging.getLogger('inv')
    logger.setLevel(loglevel)
    # Console handler; messages are wrapped in ANSI escape sequences.
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter('[%(name)s] \033[1;30m%(message)s\033[0m'))
    handler.setLevel(loglevel)
    logger.addHandler(handler)
    # Optional timestamped log file inside the output directory.
    if p['LOGFILE'].strip():
        handler = logging.FileHandler(
            os.path.join(p['OUTPUT_DIR'], p['LOGFILE']))
        handler.setFormatter(
            logging.Formatter('%(asctime)s - [%(name)s] %(message)s'))
        handler.setLevel(loglevel)
        logger.addHandler(handler)

    if logger.isEnabledFor(logging.DEBUG):
        for k in sorted(p):
            logger.debug("%s = %s", k, p[k])

    source = p['SOURCE'](**p['ARGS'])
    source.setPadding(p['PADDING_L'], p['PADDING_H'])
    inv = GravityInversion()
    inv.setDataSource(source)
    inv.setOutputDirectory(p['OUTPUT_DIR'])
    inv.setSolverTolerance(p['TOLERANCE'])
    inv.setSolverMaxIterations(p['MAX_ITER'])
    inv.setSolverOptions(**p['SOLVER_OPTS'])

    if 'MU' in p:
        mu = p['MU']
    else:
        # Default regularization weight derived from the domain extent.
        # NOTE(review): only x[0] and x[1] are considered; for DIM=3 the
        # extent along x[2] is ignored -- confirm this is intended.
        logger.info('Generating domain...')
        x = source.getDomain().getX()
        l0 = sup(x[0]) - inf(x[0])
        l1 = sup(x[1]) - inf(x[1])
        l = max(l0, l1)
        G = 6.6742e-11  # Newton's gravitational constant [m^3 kg^-1 s^-2]
        mu = 0.5*(l**2*G)**2
    logger.debug("MU = %s", mu)

    inv.setWeights(mu_reg=mu)
    # To restrict the recovered density to a range, set a bounded mapping
    # before running, e.g.:
    #   inv.setMapping(BoundedRangeMapping(-200, 200))
    rho_new = inv.run()
209 |
|