/[escript]/trunk/downunder/py_src/minimizers.py

Revision 4213 - Tue Feb 19 01:16:29 2013 UTC by caltinay
File MIME type: text/x-python
File size: 22372 byte(s)
Some cleanup and more consistent logging.

##############################################################################
#
# Copyright (c) 2003-2013 by University of Queensland
# http://www.uq.edu.au
#
# Primary Business: Queensland, Australia
# Licensed under the Open Software License version 3.0
# http://www.opensource.org/licenses/osl-3.0.php
#
# Development until 2012 by Earth Systems Science Computational Center (ESSCC)
# Development since 2012 by School of Earth Sciences
#
##############################################################################

"""Generic minimization algorithms"""

__copyright__="""Copyright (c) 2003-2013 by University of Queensland
http://www.uq.edu.au
Primary Business: Queensland, Australia"""
__license__="""Licensed under the Open Software License version 3.0
http://www.opensource.org/licenses/osl-3.0.php"""
__url__="https://launchpad.net/escript-finley"

__all__ = ['MinimizerException', 'MinimizerIterationIncurableBreakDown', 'MinimizerMaxIterReached', 'AbstractMinimizer', 'MinimizerLBFGS', 'MinimizerBFGS', 'MinimizerNLCG']

import logging
import numpy as np

try:
    from esys.escript import Lsup, sqrt, EPSILON
except ImportError:
    # fall back to numpy equivalents when escript is not available
    Lsup=lambda x: np.amax(abs(x))
    sqrt=np.sqrt
    EPSILON=1e-18

lslogger=logging.getLogger('inv.minimizer.linesearch')
zoomlogger=logging.getLogger('inv.minimizer.linesearch.zoom')

def _zoom(phi, gradphi, phiargs, alpha_lo, alpha_hi, phi_lo, phi_hi, c1, c2, phi0, gphi0, IMAX=25):
    """
    Helper function for `line_search` below which tries to tighten the range
    alpha_lo...alpha_hi. See Chapter 3 of 'Numerical Optimization' by
    J. Nocedal for an explanation.
    """
    i=0
    while True:
        alpha=alpha_lo+.5*(alpha_hi-alpha_lo) # should use interpolation...
        args_a=phiargs(alpha)
        phi_a=phi(alpha, *args_a)
        zoomlogger.debug("iteration %d, alpha=%e, phi(alpha)=%e"%(i,alpha,phi_a))
        if phi_a > phi0+c1*alpha*gphi0 or phi_a >= phi_lo:
            alpha_hi=alpha
        else:
            gphi_a=gradphi(alpha, *args_a)
            zoomlogger.debug("\tgrad(phi(alpha))=%e"%(gphi_a))
            if np.abs(gphi_a) <= -c2*gphi0:
                break
            if gphi_a*(alpha_hi-alpha_lo) >= 0:
                alpha_hi = alpha_lo
            alpha_lo=alpha
            phi_lo=phi_a
        i+=1
        if i>IMAX:
            gphi_a=None
            break
    return alpha, phi_a, gphi_a

def line_search(f, x, p, g_Jx, Jx, alpha_max=50.0, c1=1e-4, c2=0.9, IMAX=15):
    """
    Line search method that satisfies the strong Wolfe conditions.
    See Chapter 3 of 'Numerical Optimization' by J. Nocedal for an explanation.

    :param f: callable objective function f(x)
    :param x: start value for the line search
    :param p: search direction
    :param g_Jx: value for the gradient of f at x
    :param Jx: value of f(x)
    :param alpha_max: algorithm terminates if alpha reaches this value
    :param c1: value for Armijo condition (see reference)
    :param c2: value for curvature condition (see reference)
    :param IMAX: maximum number of iterations to perform
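
    Example (an illustrative sketch only; assumes ``f`` follows the
    `CostFunction` interface used in this module, i.e. it is callable
    and provides ``getGradient`` and ``getDualProduct``)::

        g = f.getGradient(x)                     # gradient at current x
        p = -g                                   # a descent direction
        alpha, Jx_new, g_new = line_search(f, x, p, g, f(x))
        x_new = x + alpha*p                      # accepted step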
    """
    # this stores the latest gradf(x+a*p) which is returned
    g_Jx_new=[g_Jx]

    def phi(a, *args):
        """ phi(a):=f(x+a*p) """
        return f(x+a*p, *args)
    def gradphi(a, *args):
        g_Jx_new[0]=f.getGradient(x+a*p, *args)
        return f.getDualProduct(p, g_Jx_new[0])
    def phiargs(a):
        try:
            args=f.getArguments(x+a*p)
        except Exception:
            args=()
        return args

    old_alpha=0.
    # we assume g_Jx is properly scaled so alpha=1 is a reasonable starting value
    alpha=1.
    if Jx is None:
        args0=phiargs(0.)
        phi0=phi(0., *args0)
    else:
        phi0=Jx
    lslogger.debug("phi(0)=%e"%(phi0))
    gphi0=f.getDualProduct(p, g_Jx) #gradphi(0., *args0)
    lslogger.debug("grad phi(0)=%e"%(gphi0))
    old_phi_a=phi0
    i=1

    while i<IMAX and alpha>0. and alpha<alpha_max:
        args_a=phiargs(alpha)
        phi_a=phi(alpha, *args_a)
        lslogger.debug("iteration %d, alpha=%e, phi(alpha)=%e"%(i,alpha,phi_a))
        if (phi_a > phi0+c1*alpha*gphi0) or ((phi_a>=old_phi_a) and (i>1)):
            alpha, phi_a, gphi_a = _zoom(phi, gradphi, phiargs, old_alpha, alpha, old_phi_a, phi_a, c1, c2, phi0, gphi0)
            break

        gphi_a=gradphi(alpha, *args_a)
        if np.abs(gphi_a) <= -c2*gphi0:
            break
        if gphi_a >= 0:
            alpha, phi_a, gphi_a = _zoom(phi, gradphi, phiargs, alpha, old_alpha, phi_a, old_phi_a, c1, c2, phi0, gphi0)
            break

        old_alpha=alpha
        # the factor is arbitrary as long as there is sufficient increase
        alpha=2.*alpha
        old_phi_a=phi_a
        i+=1

    return alpha, phi_a, g_Jx_new[0]

class MinimizerException(Exception):
    """
    This is a generic exception thrown by a minimizer.
    """
    pass

class MinimizerMaxIterReached(MinimizerException):
    """
    Exception thrown if the maximum number of iteration steps is reached.
    """
    pass

class MinimizerIterationIncurableBreakDown(MinimizerException):
    """
    Exception thrown if the iteration scheme encountered an incurable
    breakdown.
    """
    pass


##############################################################################
class AbstractMinimizer(object):
    """
    Base class for function minimization methods.
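
    Typical usage (a sketch; ``MyCostFunction`` stands for any concrete
    `CostFunction` implementation and is not defined in this module)::

        J = MyCostFunction()
        m = MinimizerLBFGS(J)
        m.setTolerance(m_tol=1e-5)
        m.setMaxIterations(200)
        m.run(x0)
        x_star = m.getResult()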
    """

    TOLERANCE_REACHED, MAX_ITERATIONS_REACHED, INCURABLE_BREAKDOWN = list(range(3))

    def __init__(self, J=None, m_tol=1e-5, J_tol=None, imax=300):
        """
        Initializes a new minimizer for a given cost function.

        :param J: the cost function to be minimized
        :type J: `CostFunction`
        :param m_tol: tolerance for the solution-based stopping criterion
        :param J_tol: tolerance for the cost function-based stopping criterion
        :param imax: maximum number of iterations
        """
        self.setCostFunction(J)
        self._imax = imax
        self._result = None
        self._callback = None
        self.logger = logging.getLogger('inv.%s'%self.__class__.__name__)
        # pass the constructor arguments on instead of calling setTolerance()
        # without arguments, which would silently override m_tol and J_tol
        # with the defaults
        self.setTolerance(m_tol=m_tol, J_tol=J_tol)

    def setCostFunction(self, J):
        """
        Sets the cost function to be minimized.

        :param J: the cost function to be minimized
        :type J: `CostFunction`
        """
        self.__J=J

    def getCostFunction(self):
        """
        Returns the cost function to be minimized.

        :rtype: `CostFunction`
        """
        return self.__J

    def setTolerance(self, m_tol=1e-4, J_tol=None):
        """
        Sets the tolerances for the stopping criterion. The minimizer stops
        when an appropriate norm is less than `m_tol`: the relative change
        of the solution for `MinimizerLBFGS`, the gradient norm for
        `MinimizerBFGS` and `MinimizerNLCG`. If `J_tol` is set, the relative
        change of the cost function is tested as well (L-BFGS only).
        """
        self._m_tol = m_tol
        self._J_tol = J_tol

    def setMaxIterations(self, imax):
        """
        Sets the maximum number of iterations before the minimizer terminates.
        """
        self._imax = imax

    def setCallback(self, callback):
        """
        Sets a callback function to be called after every iteration.
        The arguments to the function are: (k, x, Jx, g_Jx), where
        k is the current iteration, x is the current estimate, Jx=f(x) and
        g_Jx=grad f(x).
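
        Example (sketch of a simple progress logger; ``minimizer`` is any
        `AbstractMinimizer` instance)::

            def progress(k, x, Jx, g_Jx):
                print("iteration %d: J(x)=%e"%(k, Jx))

            minimizer.setCallback(progress)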
        """
        if callback is not None and not callable(callback):
            raise TypeError("Callback function not callable.")
        self._callback = callback

    def _doCallback(self, *args):
        if self._callback is not None:
            self._callback(*args)

    def getResult(self):
        """
        Returns the result of the minimization.
        """
        return self._result

    def getOptions(self):
        """
        Returns a dictionary of minimizer-specific options.
        """
        return {}

    def setOptions(self, **opts):
        """
        Sets minimizer-specific options. For a list of possible options see
        `getOptions()`.
        """
        raise NotImplementedError

    def run(self, x0):
        """
        Executes the minimization algorithm for *f* starting with the initial
        guess ``x0``.

        :return: `TOLERANCE_REACHED` or `MAX_ITERATIONS_REACHED`
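
        Note that `MinimizerLBFGS` also signals failure by raising, so a
        caller may want to guard the call (sketch; ``m`` is a minimizer
        instance)::

            try:
                m.run(x0)
            except MinimizerMaxIterReached as e:
                print("no convergence: %s"%e)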
        """
        raise NotImplementedError

    def logSummary(self):
        """
        Outputs a summary of the completed minimization process to the logger.
        """
        if hasattr(self.getCostFunction(), "Value_calls"):
            self.logger.info("Number of cost function evaluations: %d"%self.getCostFunction().Value_calls)
            self.logger.info("Number of gradient evaluations: %d"%self.getCostFunction().Gradient_calls)
            self.logger.info("Number of inner product evaluations: %d"%self.getCostFunction().DualProduct_calls)
            self.logger.info("Number of argument evaluations: %d"%self.getCostFunction().Arguments_calls)
            self.logger.info("Number of norm evaluations: %d"%self.getCostFunction().Norm_calls)

##############################################################################
class MinimizerLBFGS(AbstractMinimizer):
    """
    Minimizer that uses the limited-memory Broyden-Fletcher-Goldfarb-Shanno
    method.
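
    Example (sketch; ``J`` is a `CostFunction` instance and the option
    values are arbitrary)::

        m = MinimizerLBFGS(J, imax=200)
        m.setOptions(truncation=40, restart=80)
        m.run(x0)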
    """

    # History size
    _truncation = 30

    # Initial Hessian multiplier
    _initial_H = 1

    # Restart after this many iteration steps
    _restart = 60

    def getOptions(self):
        return {'truncation':self._truncation, 'initialHessian':self._initial_H, 'restart':self._restart}

    def setOptions(self, **opts):
        for o in opts:
            # 'historySize' is accepted as an alias for 'truncation'
            if o=='historySize' or o=='truncation':
                self._truncation=opts[o]
            elif o=='initialHessian':
                self._initial_H=opts[o]
            elif o=='restart':
                self._restart=opts[o]
            else:
                raise KeyError("Invalid option '%s'"%o)

    def run(self, x):
        if self.getCostFunction().provides_inverse_Hessian_approximation:
            self.getCostFunction().updateHessian()
            invH_scale = None
        else:
            invH_scale = self._initial_H

        # start the iteration:
        n_iter = 0
        n_last_break_down=-1
        non_curable_break_down = False
        converged = False
        args=self.getCostFunction().getArguments(x)
        g_Jx=self.getCostFunction().getGradient(x, *args)
        Jx=self.getCostFunction()(x, *args)
        Jx_0=Jx

        while not converged and not non_curable_break_down and n_iter < self._imax:
            k=0
            break_down = False
            s_and_y=[]
            self._doCallback(n_iter, x, Jx, g_Jx)

            while not converged and not break_down and k < self._restart and n_iter < self._imax:
                if n_iter%10==0:
                    self.logger.info("********** iteration %3d **********"%n_iter)
                else:
                    self.logger.debug("********** iteration %3d **********"%n_iter)
                # determine search direction
                self.logger.debug("\tJ(x) = %s"%Jx)
                self.logger.debug("\tgrad f(x) = %s"%g_Jx)
                if invH_scale: self.logger.debug("\tH = %s"%invH_scale)

                p = -self._twoLoop(invH_scale, g_Jx, s_and_y, x, *args)
                # determine step length. line_search() returns a scaling
                # alpha for the search direction as well as the cost
                # function value and gradient for the new solution
                # approximation x_new=x+alpha*p
                alpha, Jx_new, g_Jx_new = line_search(self.getCostFunction(), x, p, g_Jx, Jx)
                self.logger.debug("\tSearch direction scaling alpha=%e"%alpha)

                # execute the step
                delta_x = alpha*p
                x_new = x + delta_x
                self.logger.debug("\tJ(x_new) = %s"%Jx_new)

                converged = True
                if self._J_tol:
                    flag=abs(Jx_new-Jx) <= self._J_tol * abs(Jx_new-Jx_0)
                    if flag:
                        self.logger.debug("Cost function has converged: dJ, J*J_tol = %e, %e"%(Jx-Jx_new,abs(Jx_new-Jx_0)*self._J_tol))
                    else:
                        self.logger.debug("Cost function checked: dJ, J*J_tol = %e, %e"%(Jx-Jx_new,abs(Jx_new-Jx_0)*self._J_tol))
                    converged = converged and flag
                if self._m_tol:
                    norm_x=self.getCostFunction().getNorm(x_new)
                    norm_dx=self.getCostFunction().getNorm(delta_x)
                    flag = norm_dx <= self._m_tol * norm_x
                    if flag:
                        self.logger.debug("Solution has converged: dx, x*m_tol = %e, %e"%(norm_dx,norm_x*self._m_tol))
                    else:
                        self.logger.debug("Solution checked: dx, x*m_tol = %e, %e"%(norm_dx,norm_x*self._m_tol))
                    converged = converged and flag

                x=x_new
                if converged:
                    break

                # unfortunately there is more work to do!
                if g_Jx_new is None:
                    args=self.getCostFunction().getArguments(x_new)
                    g_Jx_new=self.getCostFunction().getGradient(x_new, *args)
                delta_g=g_Jx_new-g_Jx

                rho=self.getCostFunction().getDualProduct(delta_x, delta_g)
                if abs(rho)>0:
                    s_and_y.append((delta_x, delta_g, rho))
                else:
                    break_down=True

                self.getCostFunction().updateHessian()
                g_Jx=g_Jx_new
                Jx=Jx_new

                k+=1
                n_iter+=1
                self._doCallback(k, x, Jx, g_Jx)

                # delete oldest vector pair
                if k>self._truncation: s_and_y.pop(0)

                if not self.getCostFunction().provides_inverse_Hessian_approximation and not break_down:
                    # set the new scaling factor (approximation of inverse Hessian)
                    denom=self.getCostFunction().getDualProduct(delta_g, delta_g)
                    if denom > 0:
                        invH_scale=self.getCostFunction().getDualProduct(delta_x,delta_g)/denom
                    else:
                        invH_scale=self._initial_H
                        self.logger.debug("** Break down in H update. Resetting to initial value %s."%self._initial_H)
            # case handling for the inner iteration:
            if break_down:
                if n_iter == n_last_break_down+1:
                    non_curable_break_down = True
                    self.logger.debug("** Incurable break down detected in step %d."%n_iter)
                else:
                    n_last_break_down = n_iter
                    self.logger.debug("** Break down detected in step %d. Iteration is restarted."%n_iter)
            if not k < self._restart:
                self.logger.debug("Iteration is restarted after %d steps."%(n_iter,))

        # case handling for the outer iteration:
        self._result=x
        if n_iter >= self._imax:
            self.return_status=self.MAX_ITERATIONS_REACHED
            self.logger.debug(">>>>>>>>>> Maximum number of iterations reached! <<<<<<<<<<")
            raise MinimizerMaxIterReached("Gave up after %d steps."%(n_iter,))
        elif non_curable_break_down:
            self.return_status=self.INCURABLE_BREAKDOWN
            self.logger.debug(">>>>>>>>>> Incurable breakdown! <<<<<<<<<<")
            raise MinimizerIterationIncurableBreakDown("Gave up after %d steps."%(n_iter,))
        else:
            self.return_status=self.TOLERANCE_REACHED
            self.logger.debug("Success after %d iterations!"%n_iter)

        return self.return_status

    def _twoLoop(self, invH_scale, g_Jx, s_and_y, x, *args):
        """
        Helper for the L-BFGS method.
        See 'Numerical Optimization' by J. Nocedal for an explanation.
        """
        # first loop: walk backwards through the stored (s, y) pairs
        q=g_Jx
        alpha=[]
        for s,y,rho in reversed(s_and_y):
            a=self.getCostFunction().getDualProduct(s, q)/rho
            alpha.append(a)
            q=q-a*y

        # apply the initial inverse Hessian approximation
        if self.getCostFunction().provides_inverse_Hessian_approximation:
            r = self.getCostFunction().getInverseHessianApproximation(x, q, *args)
        else:
            r = invH_scale * q

        # second loop: walk forwards and apply the correction terms
        for s,y,rho in s_and_y:
            beta = self.getCostFunction().getDualProduct(r, y)/rho
            a = alpha.pop()
            r = r + s * (a-beta)
        return r

##############################################################################
class MinimizerBFGS(AbstractMinimizer):
    """
    Minimizer that uses the Broyden-Fletcher-Goldfarb-Shanno method.
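
    The inverse Hessian approximation H is updated with the standard
    rank-2 BFGS formula implemented in `run` below:

        H <- (I - rho*s*y^T) H (I - rho*y*s^T) + rho*s*s^T

    where s is the solution update, y the gradient update and
    rho = 1/(y^T s).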
    """

    # Initial Hessian multiplier
    _initial_H = 1

    def getOptions(self):
        return {'initialHessian':self._initial_H}

    def setOptions(self, **opts):
        for o in opts:
            if o=='initialHessian':
                self._initial_H=opts[o]
            else:
                raise KeyError("Invalid option '%s'"%o)

    def run(self, x):
        args=self.getCostFunction().getArguments(x)
        g_Jx=self.getCostFunction().getGradient(x, *args)
        Jx=self.getCostFunction()(x, *args)
        k=0
        try:
            n=len(x)
        except TypeError:
            n=x.getNumberOfDataPoints()
        I=np.eye(n)
        H=self._initial_H*I
        gnorm=Lsup(g_Jx)
        self._doCallback(k, x, Jx, g_Jx)

        while gnorm > self._m_tol and k < self._imax:
            self.logger.debug("iteration %d, gnorm=%e"%(k,gnorm))

            # determine search direction
            d=-self.getCostFunction().getDualProduct(H, g_Jx)

            self.logger.debug("H = %s"%H)
            self.logger.debug("grad f(x) = %s"%g_Jx)
            self.logger.debug("d = %s"%d)
            self.logger.debug("x = %s"%x)

            # determine step length
            alpha, Jx, g_Jx_new = line_search(self.getCostFunction(), x, d, g_Jx, Jx)
            self.logger.debug("alpha=%e"%alpha)
            # execute the step
            x_new=x+alpha*d
            delta_x=x_new-x
            x=x_new
            if g_Jx_new is None:
                g_Jx_new=self.getCostFunction().getGradient(x_new)
            delta_g=g_Jx_new-g_Jx
            g_Jx=g_Jx_new
            k+=1
            self._doCallback(k, x, Jx, g_Jx)
            gnorm=Lsup(g_Jx)
            if gnorm<=self._m_tol: break

            # update the inverse Hessian approximation
            denom=self.getCostFunction().getDualProduct(delta_x, delta_g)
            if denom < EPSILON * gnorm:
                denom=1e-5
                self.logger.debug("Break down in H update. Resetting.")
            rho=1./denom
            self.logger.debug("rho=%e"%rho)
            A=I-rho*delta_x[:,None]*delta_g[None,:]
            AT=I-rho*delta_g[:,None]*delta_x[None,:]
            H=self.getCostFunction().getDualProduct(A, self.getCostFunction().getDualProduct(H,AT)) + rho*delta_x[:,None]*delta_x[None,:]

        if k >= self._imax:
            reason=self.MAX_ITERATIONS_REACHED
            self.logger.debug("Maximum number of iterations reached!")
        else:
            reason=self.TOLERANCE_REACHED
            self.logger.debug("Success after %d iterations! Final gnorm=%e"%(k,gnorm))

        self._result=x
        return reason

##############################################################################
class MinimizerNLCG(AbstractMinimizer):
    """
    Minimizer that uses the nonlinear conjugate gradient method
    (Fletcher-Reeves variant).
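
    With r = -grad f(x), the search direction is updated as d <- r + beta*d
    using the Fletcher-Reeves coefficient
    beta = <r_new, r_new>/<r_old, r_old>, as implemented in `run` below.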
    """

    def run(self, x):
        i=0
        k=0
        args=self.getCostFunction().getArguments(x)
        r=-self.getCostFunction().getGradient(x, *args)
        Jx=self.getCostFunction()(x, *args)
        d=r
        delta=self.getCostFunction().getDualProduct(r,r)
        delta0=delta
        self._doCallback(i, x, Jx, -r)

        while i<self._imax and Lsup(r)>self._m_tol:
            self.logger.debug("iteration %d"%i)
            self.logger.debug("grad f(x) = %s"%(-r))
            self.logger.debug("d = %s"%d)
            self.logger.debug("x = %s"%x)

            alpha, Jx, g_Jx_new = line_search(self.getCostFunction(), x, d, -r, Jx, c2=0.4)
            self.logger.debug("alpha=%e"%(alpha))
            x=x+alpha*d
            r=-self.getCostFunction().getGradient(x) if g_Jx_new is None else -g_Jx_new
            delta_o=delta
            delta=self.getCostFunction().getDualProduct(r,r)
            beta=delta/delta_o
            d=r+beta*d
            k=k+1
            try:
                lenx=len(x)
            except TypeError:
                lenx=x.getNumberOfDataPoints()
            # restart with steepest descent if d is not a descent direction
            if k == lenx or self.getCostFunction().getDualProduct(r,d) <= 0:
                d=r
                k=0
            i+=1
            self._doCallback(i, x, Jx, g_Jx_new)

        if i >= self._imax:
            reason=self.MAX_ITERATIONS_REACHED
            self.logger.debug("Maximum number of iterations reached!")
        else:
            reason=self.TOLERANCE_REACHED
            self.logger.debug("Success after %d iterations! Final delta=%e"%(i,delta))

        self._result=x
        return reason


if __name__=="__main__":
    # Example usage with function 'rosen' (minimum=[1,1,...,1]):
    from scipy.optimize import rosen, rosen_der
    from esys.downunder import MeteredCostFunction
    import sys
    N=10
    x0=np.array([4.]*N) # initial guess

    class RosenFunc(MeteredCostFunction):
        def __init__(self):
            super(RosenFunc, self).__init__()
            self.provides_inverse_Hessian_approximation=False
        def _getDualProduct(self, f0, f1):
            return np.dot(f0, f1)
        def _getValue(self, x, *args):
            return rosen(x)
        def _getGradient(self, x, *args):
            return rosen_der(x)
        def _getNorm(self, x):
            return Lsup(x)

    f=RosenFunc()
    m=None
    if len(sys.argv)>1:
        method=sys.argv[1].lower()
        if method=='nlcg':
            m=MinimizerNLCG(f)
        elif method=='bfgs':
            m=MinimizerBFGS(f)

    if m is None:
        # default
        m=MinimizerLBFGS(f)
        #m.setOptions(historySize=10000)

    logging.basicConfig(format='[%(funcName)s] \033[1;30m%(message)s\033[0m', level=logging.DEBUG)
    m.setTolerance(m_tol=1e-5)
    m.setMaxIterations(600)
    m.run(x0)
    m.logSummary()
    print("\tLsup(result)=%.8f"%np.amax(abs(m.getResult())))

    #from scipy.optimize import fmin_cg
    #print("scipy ref=%.8f"%np.amax(abs(fmin_cg(rosen, x0, rosen_der, maxiter=10000))))
