/[escript]/branches/subworld2/downunder/py_src/inversioncostfunctions.py
ViewVC logotype

Contents of /branches/subworld2/downunder/py_src/inversioncostfunctions.py

Parent Directory | Revision Log


Revision 5530 - (show annotations)
Wed Mar 11 23:39:58 2015 UTC (3 years, 11 months ago) by jfenwick
File MIME type: text/x-python
File size: 23223 byte(s)
Finally making the modded script available.   Some other tweaks
##############################################################################
#
# Copyright (c) 2003-2015 by University of Queensland
# http://www.uq.edu.au
#
# Primary Business: Queensland, Australia
# Licensed under the Open Software License version 3.0
# http://www.opensource.org/licenses/osl-3.0.php
#
# Development until 2012 by Earth Systems Science Computational Center (ESSCC)
# Development 2012-2013 by School of Earth Sciences
# Development from 2014 by Centre for Geoscience Computing (GeoComp)
#
##############################################################################

"""Cost functions for inversions with one or more forward models"""

__copyright__="""Copyright (c) 2003-2015 by University of Queensland
http://www.uq.edu.au
Primary Business: Queensland, Australia"""
__license__="""Licensed under the Open Software License version 3.0
http://www.opensource.org/licenses/osl-3.0.php"""
__url__="https://launchpad.net/escript-finley"

__all__ = [ 'InversionCostFunction']

from .costfunctions import MeteredCostFunction
from .mappings import Mapping
from .forwardmodels import ForwardModel
from esys.escript.pdetools import ArithmeticTuple
from esys.escript import Data, inner
from esys.escript import SplitWorld, addJob, FunctionJob
import numpy as np

36
class InversionCostFunction(MeteredCostFunction):
    """
    Class to define cost function *J(m)* for inversion with one or more
    forward models based on a multi-valued level set function *m*:

    *J(m) = J_reg(m) + sum_f mu_f * J_f(p)*

    where *J_reg(m)* is the regularization and cross gradient component of the
    cost function applied to a level set function *m*, *J_f(p)* are the data
    defect cost functions involving a physical forward model using the
    physical parameter(s) *p* and *mu_f* is the trade-off factor for model f.

    A forward model depends on a set of physical parameters *p* which are
    constructed from components of the level set function *m* via mappings.

    This variant is being adapted to run inside a `SplitWorld`. In split world
    mode the regularization is not accessible from outside the cost function,
    so `getDomain`, `getRegularization` and `setTradeOffFactorsRegularization`
    raise ``RuntimeError``.

    Example 1 (single forward model):
        m=Mapping()
        f=ForwardModel()
        J=InversionCostFunction(Regularization(), m, f)

    Example 2 (two forward models on a single valued level set)
        m0=Mapping()
        m1=Mapping()
        f0=ForwardModel()
        f1=ForwardModel()

        J=InversionCostFunction(Regularization(), mappings=[m0, m1], forward_models=[(f0, 0), (f1,1)])

    Example 3 (two forward models on 2-valued level set)
        m0=Mapping()
        m1=Mapping()
        f0=ForwardModel()
        f1=ForwardModel()

        J=InversionCostFunction(Regularization(numLevelSets=2), mappings=[(m0,0), (m1,1)], forward_models=[(f0, 0), (f1,1)])

    :cvar provides_inverse_Hessian_approximation: if true the class provides an
          approximative inverse of the Hessian operator.
    """
    provides_inverse_Hessian_approximation=True

    @staticmethod
    def processMapsAndModels(mappings, forward_models, numLevelSets):
        """
        Normalizes the ``mappings`` and ``forward_models`` arguments (as
        described in the constructor) into lists of 2-tuples.

        :param mappings: a `Mapping` or a list whose items are `Mapping`
                         objects, 1-tuples ``(Mapping,)`` or 2-tuples
                         ``(Mapping, i)`` where *i* is a level set index or a
                         sequence of level set indices.
        :param forward_models: a `ForwardModel` or a list whose items are
                               `ForwardModel` objects, 1-tuples or 2-tuples
                               ``(ForwardModel, ii)`` where *ii* is a mapping
                               index or a sequence of mapping indices.
        :param numLevelSets: number of level set components.
        :type numLevelSets: ``int``
        :return: ``(mappings, forward_models)``. Each mapping entry is
                 ``(Mapping, idx)`` where ``idx`` is a list of level set
                 indices, or ``None`` if the whole level set function is fed
                 into the mapping. Each model entry is ``(ForwardModel, idx)``
                 where ``idx`` is a list of indices into the mappings list.
        :raise ValueError: if a level set index or a mapping index is out of
                           range.
        """
        # accept a single Mapping just like a single ForwardModel below
        if isinstance(mappings, Mapping):
            mappings = [ mappings ]

        outmappings=[]
        for mm in mappings:
            if isinstance(mm, Mapping) or len(mm) == 1:
                m = mm if isinstance(mm, Mapping) else mm[0]
                # no index given: the mapping sees all level set components
                if numLevelSets > 1:
                    idx = list(range(numLevelSets))
                else:
                    idx = None
            else:
                m = mm[0]
                if isinstance(mm[1], int):
                    idx = [mm[1]]
                else:
                    idx = list(mm[1])
                if numLevelSets > 1:
                    for k in idx:
                        if k < 0 or k > numLevelSets-1:
                            raise ValueError("level set index %s is out of range."%(k,))
                else:
                    # with a single level set the only valid index is 0
                    if idx[0] != 0:
                        raise ValueError("Level set index %s is out of range."%(idx[0],))
                    idx = None
            outmappings.append((m, idx))

        numMappings = len(outmappings)
        if isinstance(forward_models, ForwardModel):
            forward_models = [ forward_models ]
        temp_fwdmod=[]
        for i, f in enumerate(forward_models):
            if isinstance(f, ForwardModel):
                fm, idx = f, [0]
            elif len(f) == 1:
                fm, idx = f[0], [0]
            else:
                fm = f[0]
                if isinstance(f[1], int):
                    idx = [f[1]]
                else:
                    idx = list(f[1])
                for k in idx:
                    # valid mapping indices are 0 .. numMappings-1
                    if k < 0 or k >= numMappings:
                        raise ValueError("mapping index %s in model %s is out of range."%(k,i))
            temp_fwdmod.append((fm, idx))
        return outmappings, temp_fwdmod

    # This is a test at hacking splitworld functionality into downunder.
    # In production code, we may need a new subclass for this.
    def __init__(self, regularization, mappings, forward_models, splitw=None, worldsinit_fn=None, numLevelSets=None,
                 numModels=None, numMappings=None):
        """
        constructor for the cost function.
        Stores the supplied object references and sets default weights.

        :param regularization: the regularization part of the cost function.
                               If ``None``, ``numLevelSets`` is used and the
                               number of trade-off factors is left unset.
        :type regularization: `Regularization` or ``None``
        :param mappings: the mappings to calculate physical parameters from the
                         regularization. This is a list of 2-tuples *(map, i)*
                         where the first component map defines a `Mapping` and
                         the second component *i* defines the index of the
                         component of level set function to be used to
                         calculate the mapping. Items in the list may also be
                         just `Mapping` objects in which case the entire level
                         set function is fed into the `Mapping` (typically used
                         for a single-component level set function).
        :type mappings: `Mapping` or ``list``
        :param forward_models: the forward models involved in the calculation
                               of the cost function. This is a list of 2-tuples
                               *(f, ii)* where the first component f defines a
                               `ForwardModel` and the second component *ii* a
                               list of indexes referring to the physical
                               parameters in the `mappings` list. The 2-tuple
                               can be replaced by a `ForwardModel` if the
                               `mappings` list has a single entry.
        :type forward_models: `ForwardModel` or ``list``
        :param splitw: the split world jobs are run in. If ``None`` a
                       ``SplitWorld(1)`` is created.
        :type splitw: `SplitWorld`
        :param worldsinit_fn: if given, one `FunctionJob` running this function
                              is added to ``splitw`` per world, with keyword
                              arguments ``rangestart`` (index of the first
                              model for that world), ``rangelen`` (number of
                              models for that world) and ``numLevelSets``; the
                              jobs are then run.
        :param numLevelSets: number of level set components; only used when
                             ``regularization`` is ``None``.
        :type numLevelSets: ``int``
        :param numModels: number of forward models. If ``None`` it is taken
                          from ``forward_models``.
        :type numModels: ``int``
        :param numMappings: number of mappings. If ``None`` it is taken from
                            ``mappings``.
        :type numMappings: ``int``
        :raise ValueError: if ``numModels`` or ``numMappings`` cannot be
                           determined or is less than one.
        """
        super(InversionCostFunction, self).__init__()
        if regularization is not None:
            self.regularization=regularization
            self.numLevelSets = self.regularization.getNumLevelSets()
        else:
            self.numLevelSets = numLevelSets

        if isinstance(mappings, Mapping):
            mappings = [ mappings ]
        if isinstance(forward_models, ForwardModel):
            forward_models = [ forward_models ]

        # infer the counts from the supplied lists when not given explicitly
        if numModels is None and forward_models is not None:
            numModels = len(forward_models)
        if numMappings is None and mappings is not None:
            numMappings = len(mappings)
        # `None < 1` raises TypeError on Python 3, so test for None explicitly
        if numModels is None or numModels < 1:
            raise ValueError("numModels must be at least one")
        if numMappings is None or numMappings < 1:
            raise ValueError("numMappings must be at least one")

        if splitw is None:
            splitw=SplitWorld(1)

        # run worldsinit_fn once per subworld, handing each a contiguous,
        # balanced range of model indices; the ranges cover 0..numModels-1
        if worldsinit_fn is not None:
            numWorlds = splitw.getNumWorlds()
            for i in range(numWorlds):
                start = i*numModels//numWorlds
                howmany = (i+1)*numModels//numWorlds - start
                addJob(splitw, FunctionJob, worldsinit_fn, rangestart=start, rangelen=howmany, numLevelSets=self.numLevelSets)
            splitw.runJobs()

        self.numMappings=numMappings
        self.numModels=numModels
        # NOTE(review): self.mappings and self.forward_models are not assigned
        # here, yet getForwardModel, createLevelSetFunction, getProperties,
        # _getArguments and _getGradient read them. Presumably they must be set
        # elsewhere (e.g. from processMapsAndModels) before those are used.

        # Need to make sure this is updated later
        if regularization is not None:
            self.__num_tradeoff_factors = self.regularization.getNumTradeOffFactors() + self.numModels
        else:
            self.__num_tradeoff_factors=None

        self.setTradeOffFactorsModels()

    def getDomain(self):
        """
        returns the domain of the cost function

        :rtype: `Domain`
        :raise RuntimeError: always; the regularization (which provides the
                             domain) is not accessible in split world mode.
        """
        raise RuntimeError("External access to regularization not permitted in split world mode.")

    def getNumTradeOffFactors(self):
        """
        returns the number of trade-off factors being used including the
        trade-off factors used in the regularization component.

        :rtype: ``int`` (``None`` if no regularization was given)
        """
        return self.__num_tradeoff_factors

    def getForwardModel(self, idx=None):
        """
        returns the *idx*-th forward model.

        :param idx: model index. If cost function contains one model only `idx`
                    can be omitted.
        :type idx: ``int``
        """
        if idx is None:
            idx=0
        return self.forward_models[idx][0]

    def getRegularization(self):
        """
        returns the regularization

        :rtype: `Regularization`
        :raise RuntimeError: always; the regularization is not accessible in
                             split world mode.
        """
        raise RuntimeError("External access to regularization not permitted in split world mode.")

    def setTradeOffFactorsModels(self, mu=None):
        """
        sets the trade-off factors for the forward model components.

        :param mu: the trade-off factors. If not present ones are used.
        :type mu: ``float`` in case of a single model or a ``list`` (or numpy
                  array) of ``float`` with the length of the number of models.
        :raise ValueError: if the number of factors does not match the number
                           of models or a factor is not positive.
        """
        # `is None` (not `== None`): mu may be a numpy array, for which `==`
        # is element-wise and cannot be used as a condition
        if mu is None:
            self.mu_model=np.ones((self.numModels, ))
            return
        # treat scalars, lists, tuples and arrays (e.g. the slice passed by
        # setTradeOffFactors) uniformly
        mu=np.asarray(mu, dtype=float).ravel()
        if mu.size != self.numModels:
            raise ValueError("Expected %d trade-off factor(s) for the models but got %d."%(self.numModels, mu.size))
        if self.numModels > 1:
            if mu.min() > 0:
                self.mu_model= mu
            else:
                raise ValueError("All values for trade-off factor mu must be positive.")
        else:
            if mu[0] > 0:
                self.mu_model= [float(mu[0]), ]
            else:
                raise ValueError("Trade-off factor must be positive.")

    def getTradeOffFactorsModels(self):
        """
        returns the trade-off factors for the forward models

        :rtype: ``float`` or ``list`` of ``float``
        """
        if self.numModels>1:
            return self.mu_model
        else:
            return self.mu_model[0]

    def setTradeOffFactorsRegularization(self, mu=None, mu_c=None):
        """
        sets the trade-off factors for the regularization component of the
        cost function, see `Regularization` for details.

        :param mu: trade-off factors for the level-set variation part
        :param mu_c: trade-off factors for the cross gradient variation part
        :raise RuntimeError: always; the regularization is not accessible in
                             split world mode.
        """
        raise RuntimeError("External access to regularization not permitted in split world mode.")

    def setTradeOffFactors(self, mu=None):
        """
        sets the trade-off factors for the forward model and regularization
        terms.

        :param mu: list of trade-off factors: first one per forward model,
                   then those of the regularization. If not present ones are
                   used.
        :type mu: ``list`` of ``float``
        """
        if mu is None:
            mu=np.ones((self.__num_tradeoff_factors,))
        self.setTradeOffFactorsModels(mu[:self.numModels])
        self.regularization.setTradeOffFactors(mu[self.numModels:])

    def getTradeOffFactors(self, mu=None):
        """
        returns a list of the trade-off factors: first the forward model
        factors, then those of the regularization.

        :param mu: ignored; kept for backward compatibility.
        :rtype: ``list`` of ``float``
        """
        # use mu_model directly: getTradeOffFactorsModels returns a scalar
        # for a single model, which cannot be iterated
        mu1=list(self.mu_model)
        mu2=self.regularization.getTradeOffFactors()
        return mu1 + [ m for m in mu2]

    def createLevelSetFunction(self, *props):
        """
        returns an instance of an object used to represent a level set function
        initialized with zeros. Components can be overwritten by physical
        properties `props`. If present entries must correspond to the
        `mappings` arguments in the constructor. Use ``None`` for properties
        for which no value is given; missing trailing entries are treated as
        ``None`` and entries beyond the number of mappings are ignored.
        """
        m=self.regularization.getPDE().createSolution()
        for i, prop in enumerate(props[:self.numMappings]):
            if prop is None:
                continue
            mp, idx=self.mappings[i]
            m2=mp.getInverse(prop)
            if idx:
                if len(idx) == 1:
                    m[idx[0]]=m2
                else:
                    for k, comp in enumerate(idx):
                        m[comp]=m2[k]
            else:
                m=m2
        return m

    def getProperties(self, m, return_list=False):
        """
        returns a list of the physical properties from a given level set
        function *m* using the mappings of the cost function.

        :param m: level set function
        :type m: `Data`
        :param return_list: if ``True`` a list is returned.
        :type return_list: ``bool``
        :rtype: ``list`` of `Data` (a single `Data` object if there is only
                one mapping and ``return_list`` is ``False``)
        """
        props=[]
        for i in range(self.numMappings):
            mp, idx=self.mappings[i]
            if idx:
                if len(idx)==1:
                    p=mp.getValue(m[idx[0]])
                else:
                    # gather the selected level set components into one object
                    m2=Data(0.,(len(idx),),m.getFunctionSpace())
                    for k in range(len(idx)): m2[k]=m[idx[k]]
                    p=mp.getValue(m2)
            else:
                p=mp.getValue(m)
            props.append(p)
        if self.numMappings > 1 or return_list:
            return props
        else:
            return props[0]

    def _getDualProduct(self, x, r):
        """
        Returns the dual product, see `Regularization.getDualProduct`

        :type x: `Data`
        :type r: `ArithmeticTuple`
        :rtype: ``float``
        """
        return self.regularization.getDualProduct(x, r)

    def _getArguments(self, m):
        """
        returns pre-computed values that are shared in the calculation of
        *J(m)* and *grad J(m)*. In this implementation returns a tuple with the
        mapped value of ``m``, the arguments from the forward model and the
        arguments from the regularization.

        :param m: current approximation of the level set function
        :type m: `Data`
        :return: tuple of of values of the parameters, pre-computed values
                 for the forward model and pre-computed values for the
                 regularization
        :rtype: ``tuple``
        """
        args_reg=self.regularization.getArguments(m)
        # cache for physical parameters:
        props=self.getProperties(m, return_list=True)
        args_f=[]
        for i in range(self.numModels):
            f, idx=self.forward_models[i]
            pp=tuple( [ props[k] for k in idx] )
            aa=f.getArguments(*pp)
            args_f.append(aa)

        return props, args_f, args_reg

    def _getValue(self, m, *args):
        """
        Returns the value *J(m)* of the cost function at *m*.
        If the pre-computed values are not supplied `getArguments()` is called.

        :param m: current approximation of the level set function
        :type m: `Data`
        :param args: tuple of values of the parameters, pre-computed values
                     for the forward model and pre-computed values for the
                     regularization
        :rtype: ``float``
        """
        if len(args)==0:
            args=self.getArguments(m)

        props=args[0]
        args_f=args[1]
        args_reg=args[2]

        J = self.regularization.getValue(m, *args_reg)
        self.logger.debug("J_R (incl. trade-offs) = %e"%J)

        for i in range(self.numModels):
            f, idx=self.forward_models[i]
            fargs=tuple( [ props[k] for k in idx] + list( args_f[i] ) )
            J_f = f.getDefect(*fargs)
            self.logger.debug("J_f[%d] = %e, mu_model[%d] = %e"%(i, J_f, i, self.mu_model[i]))
            J += self.mu_model[i] * J_f

        return J

    def getComponentValues(self, m, *args):
        """
        returns the values of the individual cost functions, see
        `_getComponentValues`.
        """
        return self._getComponentValues(m, *args)

    def _getComponentValues(self, m, *args):
        """
        returns the values of the individual cost functions that make up
        *J(m)*: first the regularization value, then the (unweighted) defect
        of each forward model.
        If the pre-computed values are not supplied `getArguments()` is called.

        :param m: current approximation of the level set function
        :type m: `Data`
        :param args: tuple of values of the parameters, pre-computed values
                     for the forward model and pre-computed values for the
                     regularization
        :rtype: ``list<<float>>``
        """
        if len(args)==0:
            args=self.getArguments(m)

        props=args[0]
        args_f=args[1]
        args_reg=args[2]

        J_reg = self.regularization.getValue(m, *args_reg)
        result = [J_reg]

        for i in range(self.numModels):
            f, idx=self.forward_models[i]
            fargs=tuple( [ props[k] for k in idx] + list( args_f[i] ) )
            # same forward-model term as used in _getValue
            J_f = f.getDefect(*fargs)
            self.logger.debug("J_f[%d] = %e, mu_model[%d] = %e"%(i, J_f, i, self.mu_model[i]))

            result += [J_f] # self.mu_model[i] * ??

        return result

    def _getGradient(self, m, *args):
        """
        returns the gradient of the cost function at *m*.
        If the pre-computed values are not supplied `getArguments()` is called.

        :param m: current approximation of the level set function
        :type m: `Data`
        :param args: tuple of values of the parameters, pre-computed values
                     for the forward model and pre-computed values for the
                     regularization

        :rtype: `ArithmeticTuple`

        :note: returns (Y^,X) where Y^ is the gradient from regularization plus
               gradients of fwd models. X is the gradient of the regularization
               w.r.t. gradient of m.
        """
        if len(args)==0:
            args = self.getArguments(m)

        props=args[0]
        args_f=args[1]
        args_reg=args[2]

        g_J = self.regularization.getGradient(m, *args_reg)

        # derivative of each physical property w.r.t. its level set components
        p_diffs=[]
        for i in range(self.numMappings):
            mm, idx=self.mappings[i]
            if idx and self.numLevelSets > 1:
                if len(idx)>1:
                    m2=Data(0,(len(idx),),m.getFunctionSpace())
                    for k in range(len(idx)): m2[k]=m[idx[k]]
                    dpdm = mm.getDerivative(m2)
                else:
                    dpdm = mm.getDerivative(m[idx[0]])
            else:
                dpdm = mm.getDerivative(m)
            p_diffs.append(dpdm)

        # g_J==(Y,X) with Y_k=dKer/dm_k; Y is updated in place so the
        # forward model contributions end up in g_J
        Y=g_J[0]
        for i in range(self.numModels):
            mu=self.mu_model[i]
            f, idx_f=self.forward_models[i]
            fargs=tuple( [ props[k] for k in idx_f] + list( args_f[i] ) )
            Ys = f.getGradient(*fargs) # this is d Jf/d props
            if Ys.getRank() == 0:
                # f depends on one property only, but that property can
                # still depend on several level set components
                idx_m=self.mappings[idx_f[0]][1]
                # tmp[k] = dJ_f/d_prop * d prop/d m[idx_m[k]]
                tmp=Ys * p_diffs[idx_f[0]] * mu
                if idx_m:
                    if tmp.getRank()== 0:
                        for k in range(len(idx_m)):
                            Y[idx_m[k]]+=tmp # dJ_f /d m[idx_m[k]] = tmp
                    else:
                        for k in range(len(idx_m)):
                            Y[idx_m[k]]+=tmp[k] # dJ_f /d m[idx_m[k]] = tmp[k]
                else:
                    Y+=tmp
            else:
                # s is the offset of the current property's entries in Ys
                s=0
                # run through all props j forward model f is depending on:
                for j in range(len(idx_f)):
                    dpdm=p_diffs[idx_f[j]]
                    # level set components property j depends on; j is a
                    # position within idx_f, so look the mapping up via idx_f[j]
                    idx_m=self.mappings[idx_f[j]][1]
                    if dpdm.getRank() == 0 :
                        # scalar derivative: same update as the rank-0 Ys case
                        tmp=Ys[s] * dpdm * mu
                        if idx_m:
                            for k in range(len(idx_m)):
                                Y[idx_m[k]]+=tmp
                        else:
                            Y+=tmp
                        s+=1
                    elif dpdm.getRank() == 1 :
                        l=dpdm.getShape()[0]
                        # tmp=sum_j dJ_f/d_prop[j] * d prop[j]/d m
                        # NOTE(review): the full tmp (not tmp[k]) is added to
                        # every component in idx_m; kept as in the original
                        tmp=inner(Ys[s:s+l], dpdm) * mu
                        if idx_m:
                            for k in range(len(idx_m)):
                                Y[idx_m[k]]+=tmp
                        else:
                            Y+=tmp
                        s+=l
                    else: # rank 2 case
                        l=dpdm.getShape()[0]
                        Yss=Ys[s:s+l]
                        if idx_m:
                            for k in range(len(idx_m)):
                                # dJ_f /d m[idx_m[k]], weighted by mu like
                                # every other forward model contribution
                                Y[idx_m[k]]+=inner(Yss, dpdm[:,k]) * mu
                        else:
                            Y+=inner(Yss, dpdm) * mu
                        s+=l
        return g_J

    def _getInverseHessianApproximation(self, m, r, *args):
        """
        returns an approximative evaluation *p* of the inverse of the Hessian
        operator of the cost function for a given gradient type *r* at a
        given location *m*: *H(m) p = r*
        If the pre-computed values are not supplied `getArguments()` is called.

        :param m: level set approximation where to calculate Hessian inverse
        :type m: `Data`
        :param r: a given gradient
        :type r: `ArithmeticTuple`
        :param args: tuple of values of the parameters, pre-computed values
                     for the forward model and pre-computed values for the
                     regularization
        :rtype: `Data`
        :note: in the current implementation only the regularization term is
               considered in the inverse Hessian approximation.
        """
        if len(args)==0:
            args=self.getArguments(m)
        m=self.regularization.getInverseHessianApproximation(m, r, *args[2])
        return m

    def updateHessian(self):
        """
        notifies the class that the Hessian operator needs to be updated.
        """
        self.regularization.updateHessian()

    def _getNorm(self, m):
        """
        returns the norm of `m`

        :param m: level set function
        :type m: `Data`
        :rtype: ``float``
        """
        return self.regularization.getNorm(m)
606

  ViewVC Help
Powered by ViewVC 1.1.26