
Contents of /trunk/downunder/py_src/costfunctions.py



Revision 4007
Tue Oct 2 02:12:01 2012 UTC by caltinay
File MIME type: text/x-python
File size: 8334 byte(s)
Doco updates.
Data source paddings now take x,y not xy,z as Z padding can be done by setting
vertical extents.

##############################################################################
#
# Copyright (c) 2003-2012 by University of Queensland
# http://www.uq.edu.au
#
# Primary Business: Queensland, Australia
# Licensed under the Open Software License version 3.0
# http://www.opensource.org/licenses/osl-3.0.php
#
# Development until 2012 by Earth Systems Science Computational Center (ESSCC)
# Development since 2012 by School of Earth Sciences
#
##############################################################################

"""Collection of cost functions for minimization"""

__copyright__="""Copyright (c) 2003-2012 by University of Queensland
http://www.uq.edu.au
Primary Business: Queensland, Australia"""
__license__="""Licensed under the Open Software License version 3.0
http://www.opensource.org/licenses/osl-3.0.php"""
__url__="https://launchpad.net/escript-finley"

try:
    # only needed for getDirectionalDerivative(), so ignore import errors
    from esys.escript import grad
except ImportError:
    pass

class CostFunction(object):
    """
    A function *f(x)* that can be minimized (base class).

    Example of usage::

        cf=DerivedCostFunction()
        # ... calculate x ...
        args=cf.getArguments(x) # this could be potentially expensive!
        f=cf.getValue(x, *args)
        # ... x may need to be updated without evaluating the gradient ...
        # ... but then ...
        gf=cf.getGradient(x, *args)

    Each of these function calls updates the statistical counters.
    The actual work is done by the methods of the same name with a leading
    underscore, which need to be overridden by a particular cost function
    implementation.
    """

    def __init__(self):
        """
        the base constructor initializes the counters, so subclasses should
        ensure the superclass constructor is called.
        """
        self.resetCounters()

    def resetCounters(self):
        """
        resets all statistical counters
        """
        self.Inner_calls=0
        self.Value_calls=0
        self.Gradient_calls=0
        self.DirectionalDerivative_calls=0
        self.Arguments_calls=0

    def getInner(self, f0, f1):
        """
        returns the inner product of ``f0`` and ``f1``
        """
        self.Inner_calls+=1
        return self._getInner(f0, f1)

    def getValue(self, x, *args):
        """
        returns the value *f(x)* using the precalculated values for *x*.
        """
        self.Value_calls+=1
        return self._getValue(x, *args)

    def __call__(self, x, *args):
        """
        short for ``getValue(x, *args)``.
        """
        return self.getValue(x, *args)

    def getGradient(self, x, *args):
        """
        returns the gradient of *f* at *x* using the precalculated values
        for *x*.
        """
        self.Gradient_calls+=1
        return self._getGradient(x, *args)

    def getDirectionalDerivative(self, x, d, *args):
        """
        returns ``inner(grad f(x), d)`` using the precalculated values for *x*.
        """
        self.DirectionalDerivative_calls+=1
        return self._getDirectionalDerivative(x, d, *args)

    def getArguments(self, x):
        """
        returns precalculated values that are shared in the calculation of
        *f(x)* and *grad f(x)*.
        """
        self.Arguments_calls+=1
        return self._getArguments(x)

    def _getInner(self, f0, f1):
        """
        Worker for `getInner()`, needs to be overridden.
        """
        raise NotImplementedError

    def _getValue(self, x, *args):
        """
        Worker for `getValue()`, needs to be overridden.
        """
        raise NotImplementedError

    def _getGradient(self, x, *args):
        """
        Worker for `getGradient()`, needs to be overridden.
        """
        raise NotImplementedError

    def _getDirectionalDerivative(self, x, d, *args):
        """
        returns ``getInner(grad f(x), d)`` using the precalculated values
        for *x*.

        This method may be overridden if the return value can be computed
        more efficiently than through a ``self.getGradient()`` call.
        """
        return self.getInner(self.getGradient(x, *args), d)

    def _getArguments(self, x):
        """
        can be overridden to return precalculated values that are shared in
        the calculation of *f(x)* and *grad f(x)*. By default returns an
        empty tuple.
        """
        return ()

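# ----------------------------------------------------------------------------
# A minimal illustrative subclass (a sketch, not used by the inversion code):
# it shows how the worker methods above are overridden for a concrete cost
# function, here the scalar quadratic f(x) = 0.5*(x-c)**2 with gradient x-c
# and the plain product as inner product. The default
# _getDirectionalDerivative() then returns (x-c)*d without further work.

class ExampleQuadraticCostFunction(CostFunction):
    """
    illustrative cost function *f(x)=0.5*(x-c)**2* for a scalar *x*
    """
    def __init__(self, c=0.):
        super(ExampleQuadraticCostFunction, self).__init__()
        self.c=c

    def _getInner(self, f0, f1):
        # for scalars the inner product is just the product
        return f0*f1

    def _getValue(self, x, *args):
        return 0.5*(x-self.c)**2

    def _getGradient(self, x, *args):
        return x-self.c
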
class SimpleCostFunction(CostFunction):
    """
    This is a simple cost function with a single continuous (mapped)
    variable. It is the sum of two weighted terms, a single forward model
    and a single regularization term,

        *f(m) = mu_model * F(rho(m)) + mu_reg * R(m)*,

    where *rho* is the mapped variable. This cost function is used in the
    gravity inversion.
    """
    def __init__(self, regularization, mapping, forwardmodel):
        """
        constructor stores the supplied object references and sets default
        weights.

        :param regularization: The regularization part of the cost function
        :param mapping: Parametrization object
        :param forwardmodel: The forward model part of the cost function
        """
        super(SimpleCostFunction, self).__init__()
        self.forwardmodel=forwardmodel
        self.regularization=regularization
        self.mapping=mapping
        self.setWeights()

    def setWeights(self, mu_model=1., mu_reg=1.):
        """
        sets the weighting factors for the forward model and regularization
        terms.

        :param mu_model: Weighting factor for the forward model (default=1.)
        :type mu_model: non-negative `float`
        :param mu_reg: Weighting factor for the regularization (default=1.)
        :type mu_reg: non-negative `float`
        """
        if mu_model<0. or mu_reg<0.:
            raise ValueError("weighting factors must be non-negative.")
        self.mu_model=mu_model
        self.mu_reg=mu_reg

    def _getInner(self, f0, f1):
        """
        returns ``regularization.getInner(f0,f1)``

        :rtype: `float`
        """
        # if more than one regularization term were involved, their
        # contributions would need to be added up
        return self.regularization.getInner(f0, f1)

    def _getArguments(self, m):
        """
        returns precalculated values that are shared in the calculation of
        *f(x)* and *grad f(x)*. This implementation returns a tuple with the
        mapped value of ``m``, the arguments from the forward model, and the
        arguments from the regularization.

        :rtype: `tuple`
        """
        rho=self.mapping(m)
        return rho, self.forwardmodel.getArguments(rho), self.regularization.getArguments(m)

    def _getValue(self, m, *args):
        """
        returns the function value at *m*.
        If the precalculated values are not supplied `getArguments()` is
        called.

        :rtype: `float`
        """
        # if there were more than one forward model and/or regularization,
        # their contributions would need to be added up, but this
        # implementation allows only one of each. Note that the precalculated
        # regularization arguments (args[2]) are not used here.
        if len(args)==0:
            args=self.getArguments(m)
        return self.mu_model * self.forwardmodel.getValue(args[0],*args[1]) \
               + self.mu_reg * self.regularization.getValue(m)

    def _getGradient(self, m, *args):
        """
        returns the gradient of *f* at *m*.
        If the precalculated values are not supplied `getArguments()` is
        called.

        :rtype: `esys.escript.Data`
        """
        drhodm = self.mapping.getDerivative(m)
        if len(args)==0:
            args = self.getArguments(m)
        Y0 = self.forwardmodel.getGradient(args[0],*args[1])
        Y1, X1 = self.regularization.getGradient(m)
        # chain rule: the forward model gradient is taken w.r.t. rho, so it
        # is mapped back to the variable m by multiplying with drho/dm
        return self.regularization.project(Y=self.mu_reg*Y1 + self.mu_model*Y0*drhodm, X=self.mu_reg*X1)

    def _getDirectionalDerivative(self, m, d, *args):
        """
        returns the directional derivative at *m* in direction *d*.
        Note that, unlike the other workers, this method requires the
        precalculated values to be supplied and relies on the optional
        `esys.escript.grad` import.

        :rtype: `float`
        """
        drhodm = self.mapping.getDerivative(m)
        Y0 = self.forwardmodel.getGradient(args[0],*args[1])
        Y1, X1 = self.regularization.getGradient(m)
        return self.regularization.getInner(d, self.mu_reg*Y1 + self.mu_model*Y0*drhodm) \
               + self.mu_reg*self.regularization.getInner(grad(d), X1)

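# ----------------------------------------------------------------------------
# Illustrative usage sketch. The `regularization`, `mapping`, `forwardmodel`
# objects and the variable `m` are hypothetical placeholders (the actual
# classes live elsewhere in esys.downunder); all method calls shown are the
# public interface defined in this module:
#
#   cf = SimpleCostFunction(regularization, mapping, forwardmodel)
#   cf.setWeights(mu_model=10., mu_reg=1.)
#   args = cf.getArguments(m)    # precompute values shared by f and grad f
#   J = cf.getValue(m, *args)    # mu_model*J_forward + mu_reg*J_reg
#   g = cf.getGradient(m, *args)
#   print(cf.Value_calls, cf.Gradient_calls)   # statistics counters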
