/[escript]/trunk/downunder/py_src/minimizers.py
ViewVC logotype

Diff of /trunk/downunder/py_src/minimizers.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 4212 by jfenwick, Tue Jan 22 09:30:23 2013 UTC revision 4213 by caltinay, Tue Feb 19 01:16:29 2013 UTC
# Line 48  def _zoom(phi, gradphi, phiargs, alpha_l Line 48  def _zoom(phi, gradphi, phiargs, alpha_l
48          alpha=alpha_lo+.5*(alpha_hi-alpha_lo) # should use interpolation...          alpha=alpha_lo+.5*(alpha_hi-alpha_lo) # should use interpolation...
49          args_a=phiargs(alpha)          args_a=phiargs(alpha)
50          phi_a=phi(alpha, *args_a)          phi_a=phi(alpha, *args_a)
51          zoomlogger.debug("Zoom.iteration %d, alpha=%e, phi(alpha)=%e"%(i,alpha,phi_a))          zoomlogger.debug("iteration %d, alpha=%e, phi(alpha)=%e"%(i,alpha,phi_a))
52          if phi_a > phi0+c1*alpha*gphi0 or phi_a >= phi_lo:          if phi_a > phi0+c1*alpha*gphi0 or phi_a >= phi_lo:
53              alpha_hi=alpha              alpha_hi=alpha
54          else:          else:
55              gphi_a=gradphi(alpha, *args_a)              gphi_a=gradphi(alpha, *args_a)
56              zoomlogger.debug("Zoom.grad(phi(alpha))=%e"%(gphi_a))              zoomlogger.debug("\tgrad(phi(alpha))=%e"%(gphi_a))
57              if np.abs(gphi_a) <= -c2*gphi0:              if np.abs(gphi_a) <= -c2*gphi0:
58                  break                  break
59              if gphi_a*(alpha_hi-alpha_lo) >= 0:              if gphi_a*(alpha_hi-alpha_lo) >= 0:
# Line 90  def line_search(f, x, p, g_Jx, Jx, alpha Line 90  def line_search(f, x, p, g_Jx, Jx, alpha
90      def gradphi(a, *args):      def gradphi(a, *args):
91          g_Jx_new[0]=f.getGradient(x+a*p, *args)          g_Jx_new[0]=f.getGradient(x+a*p, *args)
92          return f.getDualProduct(p, g_Jx_new[0])          return f.getDualProduct(p, g_Jx_new[0])
   
93      def phiargs(a):      def phiargs(a):
94          try:          try:
95              args=f.getArguments(x+a*p)              args=f.getArguments(x+a*p)
# Line 113  def line_search(f, x, p, g_Jx, Jx, alpha Line 112  def line_search(f, x, p, g_Jx, Jx, alpha
112      i=1      i=1
113    
114      while i<IMAX and alpha>0. and alpha<alpha_truncationax:      while i<IMAX and alpha>0. and alpha<alpha_truncationax:
   
115          args_a=phiargs(alpha)          args_a=phiargs(alpha)
116          phi_a=phi(alpha, *args_a)          phi_a=phi(alpha, *args_a)
117          lslogger.debug("Line Search.iteration %d, alpha=%e, phi(alpha)=%e"%(i,alpha,phi_a))          lslogger.debug("iteration %d, alpha=%e, phi(alpha)=%e"%(i,alpha,phi_a))
118          if (phi_a > phi0+c1*alpha*gphi0) or ((phi_a>=old_phi_a) and (i>1)):          if (phi_a > phi0+c1*alpha*gphi0) or ((phi_a>=old_phi_a) and (i>1)):
119              alpha, phi_a, gphi_a = _zoom(phi, gradphi, phiargs, old_alpha, alpha, old_phi_a, phi_a, c1, c2, phi0, gphi0)              alpha, phi_a, gphi_a = _zoom(phi, gradphi, phiargs, old_alpha, alpha, old_phi_a, phi_a, c1, c2, phi0, gphi0)
120              break              break
# Line 137  def line_search(f, x, p, g_Jx, Jx, alpha Line 135  def line_search(f, x, p, g_Jx, Jx, alpha
135      return alpha, phi_a, g_Jx_new[0]      return alpha, phi_a, g_Jx_new[0]
136    
137  class MinimizerException(Exception):  class MinimizerException(Exception):
138     """      """
139     This is a generic exception thrown by a minimizer.      This is a generic exception thrown by a minimizer.
140     """      """
141     pass      pass
142    
143  class MinimizerMaxIterReached(MinimizerException):  class MinimizerMaxIterReached(MinimizerException):
144     """      """
145     Exception thrown if the maximum number of iteration steps is reached.      Exception thrown if the maximum number of iteration steps is reached.
146     """      """
147     pass      pass
148    
149  class MinimizerIterationIncurableBreakDown(MinimizerException):  class MinimizerIterationIncurableBreakDown(MinimizerException):
150     """      """
151     Exception thrown if the iteration scheme encountered an incurable breakdown.      Exception thrown if the iteration scheme encountered an incurable
152     """      breakdown.
153     pass          """
154        pass
155    
156    
157  ##############################################################################  ##############################################################################
# Line 178  class AbstractMinimizer(object): Line 177  class AbstractMinimizer(object):
177          self._callback = None          self._callback = None
178          self.logger = logging.getLogger('inv.%s'%self.__class__.__name__)          self.logger = logging.getLogger('inv.%s'%self.__class__.__name__)
179          self.setTolerance()          self.setTolerance()
           
180    
181      def setCostFunction(self, J):      def setCostFunction(self, J):
182          """          """
183          set the cost function to be minimized          set the cost function to be minimized
184            
185          :param J: the cost function to be minimized          :param J: the cost function to be minimized
186          :type J: `CostFunction`                  :type J: `CostFunction`
187          """          """
188          self.__J=J          self.__J=J
189        
190      def getCostFunction(self):      def getCostFunction(self):
191          """          """
192          return the cost function to be minimized          return the cost function to be minimized
193            
194          :rtype: `CostFunction`                  :rtype: `CostFunction`
195          """          """
196          return self.__J          return self.__J
197            
198      def setTolerance(self, m_tol=1e-4, J_tol=None):      def setTolerance(self, m_tol=1e-4, J_tol=None):
199          """          """
200          Sets the tolerance for the stopping criterion. The minimizer stops when          Sets the tolerance for the stopping criterion. The minimizer stops
201          an appropriate norm is less than `m_tol`.          when an appropriate norm is less than `m_tol`.
202          """          """
203          self._m_tol = m_tol          self._m_tol = m_tol
204          self._J_tol = J_tol          self._J_tol = J_tol
# Line 259  class AbstractMinimizer(object): Line 257  class AbstractMinimizer(object):
257          Outputs a summary of the completed minimization process to the logger.          Outputs a summary of the completed minimization process to the logger.
258          """          """
259          if hasattr(self.getCostFunction(), "Value_calls"):          if hasattr(self.getCostFunction(), "Value_calls"):
260              self.logger.warning("Number of cost function evaluations: %d"%self.getCostFunction().Value_calls)              self.logger.info("Number of cost function evaluations: %d"%self.getCostFunction().Value_calls)
261              self.logger.warning("Number of gradient evaluations: %d"%self.getCostFunction().Gradient_calls)              self.logger.info("Number of gradient evaluations: %d"%self.getCostFunction().Gradient_calls)
262              self.logger.warning("Number of inner product evaluations: %d"%self.getCostFunction().DualProduct_calls)              self.logger.info("Number of inner product evaluations: %d"%self.getCostFunction().DualProduct_calls)
263              self.logger.warning("Number of argument evaluations: %d"%self.getCostFunction().Arguments_calls)              self.logger.info("Number of argument evaluations: %d"%self.getCostFunction().Arguments_calls)
264              self.logger.warning("Number of norm evaluations: %d"%self.getCostFunction().Norm_calls)              self.logger.info("Number of norm evaluations: %d"%self.getCostFunction().Norm_calls)
265    
266  ##############################################################################  ##############################################################################
267  class MinimizerLBFGS(AbstractMinimizer):  class MinimizerLBFGS(AbstractMinimizer):
# Line 277  class MinimizerLBFGS(AbstractMinimizer): Line 275  class MinimizerLBFGS(AbstractMinimizer):
275    
276      # Initial Hessian multiplier      # Initial Hessian multiplier
277      _initial_H = 1      _initial_H = 1
278        
279      # restart      # Restart after this many iteration steps
280      _restart= 60      _restart = 60
281    
282      def getOptions(self):      def getOptions(self):
283          return {'truncation':self._truncation,'initialHessian':self._initial_H, 'restart':self._restart}          return {'truncation':self._truncation,'initialHessian':self._initial_H, 'restart':self._restart}
# Line 296  class MinimizerLBFGS(AbstractMinimizer): Line 294  class MinimizerLBFGS(AbstractMinimizer):
294                  raise KeyError("Invalid option '%s'"%o)                  raise KeyError("Invalid option '%s'"%o)
295    
296      def run(self, x):      def run(self, x):
297            if self.getCostFunction().provides_inverse_Hessian_approximation:
298                self.getCostFunction().updateHessian()
299                invH_scale = None
300            else:
301                invH_scale = self._initial_H
302    
303      if self.getCostFunction().provides_inverse_Hessian_approximation:          # start the iteration:
304          self.getCostFunction().updateHessian()          n_iter = 0
305          invH_scale = None          n_last_break_down=-1
306      else:          non_curable_break_down = False
307          invH_scale = self._initial_H          converged = False
308                    args=self.getCostFunction().getArguments(x)
309      # start the iteration:          g_Jx=self.getCostFunction().getGradient(x, *args)
310      n_iter = 0          Jx=self.getCostFunction()(x, *args)
311      n_last_break_down=-1          Jx_0=Jx
312      non_curable_break_down = False  
313      converged = False          while not converged and not non_curable_break_down and n_iter < self._imax:
314      args=self.getCostFunction().getArguments(x)            k=0
315      g_Jx=self.getCostFunction().getGradient(x, *args)            break_down = False
316      Jx=self.getCostFunction()(x, *args)            s_and_y=[]
317      Jx_0=Jx            self._doCallback(n_iter, x, Jx, g_Jx)
318        
319      while not converged and not non_curable_break_down and n_iter < self._imax:            while not converged and not break_down and k < self._restart and n_iter < self._imax:
320                            #self.logger.info("\033[1;31miteration %d\033[1;30m"%n_iter)
321        k=0                    if n_iter%10==0:
322        break_down = False                      self.logger.info("********** iteration %3d **********"%n_iter)
323        s_and_y=[]                  else:
324        self._doCallback(n_iter, x, Jx, g_Jx)                      self.logger.debug("********** iteration %3d **********"%n_iter)
325                    # determine search direction
326        while not converged and not break_down and k < self._restart and n_iter < self._imax:                  self.logger.debug("\tJ(x) = %s"%Jx)
327          #self.logger.info("\033[1;31miteration %d\033[1;30m, error=%e"%(k,error))                  self.logger.debug("\tgrad f(x) = %s"%g_Jx)
328          self.logger.debug("LBFGS.iteration %d .......... "%n_iter)                  if invH_scale: self.logger.debug("\tH = %s"%invH_scale)
329          # determine search direction  
330          self.logger.debug("LBFGS.J(x) = %s"%Jx)                  p = -self._twoLoop(invH_scale, g_Jx, s_and_y, x, *args)
331          self.logger.debug("LBFGS.grad f(x) = %s"%g_Jx)                  # determine step length
332              if invH_scale: self.logger.debug("LBFGS.H = %s"%invH_scale)                  alpha, Jx_new, g_Jx_new = line_search(self.getCostFunction(), x, p, g_Jx, Jx)
333                    # this function returns a scaling alpha for the search
334          p = -self._twoLoop(invH_scale, g_Jx, s_and_y, x, *args)                  # direction as well as the cost function evaluation and
335          # determine step length                  # gradient for the new solution approximation x_new=x+alpha*p
336          alpha, Jx_new, g_Jx_new = line_search(self.getCostFunction(), x, p, g_Jx, Jx)                  self.logger.debug("\tSearch direction scaling alpha=%e"%alpha)
337          # this function returns the a scaling alpha for the serch direction  
338          # as well the cost function evaluation and gradient for the new solution                  # execute the step
339          # approximation x_new = x + alpha*p                  delta_x = alpha*p
340                            x_new = x + delta_x
341          self.logger.debug("LBFGS.search direction scaling alpha=%e"%(alpha))                  self.logger.debug("\tJ(x) = %s"%Jx_new)
342          # execute the step  
343          delta_x = alpha*p                  converged = True
344          x_new = x + delta_x                  if self._J_tol:
345          self.logger.debug("LBFGS.J(x) = %s"%Jx_new)                      flag=abs(Jx_new-Jx) <= self._J_tol * abs(Jx_new-Jx_0)
346                        if flag:
347          converged = True                          self.logger.debug("Cost function has converged: dJ, J*J_tol = %e, %e"%(Jx-Jx_new,abs(Jx_new-Jx_0)*self._J_tol))
348          if self._J_tol:                      else:
349          flag= abs(Jx_new-Jx) <= self._J_tol * abs(Jx_new-Jx_0)                          self.logger.debug("Cost function checked: dJ, J*J_tol = %e, %e"%(Jx-Jx_new,abs(Jx_new)*self._J_tol))
350          if flag:  
351              self.logger.debug("LBFGS: cost function has converged: dJ, J*J_tol = %e, %e"%(Jx-Jx_new,abs(Jx_new-Jx_0)*self._J_tol))                      converged = converged and flag
352          else:                      if self._m_tol:
353              self.logger.debug("LBFGS: cost function checked: dJ, J * J_tol = %e, %e"%(Jx-Jx_new,abs(Jx_new)*self._J_tol))                      norm_x=self.getCostFunction().getNorm(x_new)
354                        norm_dx=self.getCostFunction().getNorm(delta_x)
355          converged = converged and flag                      flag= norm_dx <= self._m_tol * norm_x
356          if self._m_tol:                            if flag:
357            norm_x=self.getCostFunction().getNorm(x_new)                          self.logger.debug("Solution has converged: dx, x*m_tol = %e, %e"%(norm_dx,norm_x*self._m_tol))
358            norm_dx=self.getCostFunction().getNorm(delta_x)                      else:
359            flag= norm_dx <= self._m_tol * norm_x                          self.logger.debug("Solution checked: dx, x*m_tol = %e, %e"%(norm_dx,norm_x*self._m_tol))
360            if flag:                      converged = converged and flag
361            self.logger.debug("LBFGS: solution has converged: dx, x*m_tol = %e, %e"%(norm_dx,norm_x*self._m_tol))  
362            else:                  x=x_new
363            self.logger.debug("LBFGS: solution checked: dx, x*m_tol = %e, %e"%(norm_dx,norm_x*self._m_tol))                  if converged:
364            converged = converged and flag                      break
365            
366          x=x_new                  # unfortunately there is more work to do!
367          if converged:                  if g_Jx_new is None:
368            break                      args=self.getCostFunction().getArguments(x_new)
369                                g_Jx_new=self.getCostFunction().getGradient(x_new, args)
370          # unfortunatly there is more work to do!                  delta_g=g_Jx_new-g_Jx
371          if g_Jx_new is None:  
372          args=self.getCostFunction().getArguments(x_new)                  rho=self.getCostFunction().getDualProduct(delta_x, delta_g)
373          g_Jx_new=self.getCostFunction().getGradient(x_new, args)                  if abs(rho)>0:
374          delta_g=g_Jx_new-g_Jx                      s_and_y.append((delta_x,delta_g, rho ))
375                            else:
376          rho=self.getCostFunction().getDualProduct(delta_x, delta_g)                      break_down=True
377          if abs(rho)>0 :  
378          s_and_y.append((delta_x,delta_g, rho ))                  self.getCostFunction().updateHessian()
379          else:                  g_Jx=g_Jx_new
380              break_down=True                  Jx=Jx_new
381    
382          self.getCostFunction().updateHessian()                        k+=1
383          g_Jx=g_Jx_new                  n_iter+=1
384          Jx=Jx_new                  self._doCallback(k, x, Jx, g_Jx)
385            
386          k+=1                  # delete oldest vector pair
387          n_iter+=1                  if k>self._truncation: s_and_y.pop(0)
388          self._doCallback(k, x, Jx, g_Jx)  
389                    if not self.getCostFunction().provides_inverse_Hessian_approximation and not break_down:
390          # delete oldest vector pair                      # set the new scaling factor (approximation of inverse Hessian)
391          if k>self._truncation: s_and_y.pop(0)                      denom=self.getCostFunction().getDualProduct(delta_g, delta_g)
392                        if denom > 0:
393          if not self.getCostFunction().provides_inverse_Hessian_approximation and not break_down :                          invH_scale=self.getCostFunction().getDualProduct(delta_x,delta_g)/denom
394            # set the new scaling factor (approximation of inverse Hessian)                      else:
395            denom=self.getCostFunction().getDualProduct(delta_g, delta_g)                          invH_scale=self._initial_H
396            if denom > 0:                          self.logger.debug("** Break down in H update. Resetting to initial value %s."%self._initial_H)
397                invH_scale=self.getCostFunction().getDualProduct(delta_x,delta_g)/denom            # case handling for inner iteration:
           else:  
               invH_scale=self._initial_H  
               self.logger.debug("LBFGS.Break down in H update. Resetting to initial value %s."%self._initial_H)  
           # case handeling for inner iteration:  
398            if break_down:            if break_down:
399           if n_iter == n_last_break_down +1 :                if n_iter == n_last_break_down+1:
400              non_curable_break_down = True                    non_curable_break_down = True
401              self.logger.debug("LBFGS. Incurable break down detected in step %d."%(n_iter,))                    self.logger.debug("** Incurable break down detected in step %d."%n_iter)
402           else:                else:
403              n_last_break_down = n_iter                    n_last_break_down = n_iter
404              self.logger.debug("LBFGS.Break down detected in step %d. Iteration is restarted."%(n_iter,))                    self.logger.debug("** Break down detected in step %d. Iteration is restarted."%n_iter)
405            if not k < self._restart:            if not k < self._restart:
406              self.logger.debug("LBFGS.Iteration is restarted after %d steps."%(n_iter,))                self.logger.debug("Iteration is restarted after %d steps."%(n_iter,))
407    
408            # case handling for inner iteration:
409            self._result=x
410            if n_iter >= self._imax:
411                self.return_status=self.MAX_ITERATIONS_REACHED
412                self.logger.debug(">>>>>>>>>> Maximum number of iterations reached! <<<<<<<<<<")
413                raise MinimizerMaxIterReached("Gave up after %d steps."%(n_iter,))
414            elif non_curable_break_down:
415                self.return_status=self.INCURABLE_BREAKDOWN
416                self.logger.debug(">>>>>>>>>> Incurable breakdown! <<<<<<<<<<")
417                raise MinimizerIterationIncurableBreakDown("Gave up after %d steps."%(n_iter,))
418            else:
419                self.return_status=self.TOLERANCE_REACHED
420                self.logger.debug("Success after %d iterations!"%k)
421    
422      # case handeling for inner iteration:                  return self.return_status
     self._result=x      
     if n_iter >= self._imax:  
         self.return_status=self.MAX_ITERATIONS_REACHED  
         raise MinimizerMaxIterReached("Gave up after %d steps."%(n_iter,))  
         self.logger.debug("LBFGS.>>>>>> Maximum number of iterations reached!")  
     elif non_curable_break_down:  
         self.return_status=self.INCURABLE_BREAKDOWN  
         self.logger.debug("LBFGS.>>>>>> Uncurable breakdown!")  
         raise MinimizerIterationIncurableBreakDown("Gave up after %d steps."%(n_iter,))  
     else:  
         self.return_status=self.TOLERANCE_REACHED  
         self.logger.debug("LBFGS.Success after %d iterations!"%k)  
423    
     return self.return_status  
       
424      def _twoLoop(self, invH_scale, g_Jx, s_and_y, x, *args):      def _twoLoop(self, invH_scale, g_Jx, s_and_y, x, *args):
425          """          """
426          Helper for the L-BFGS method.          Helper for the L-BFGS method.
# Line 588  if __name__=="__main__": Line 587  if __name__=="__main__":
587      x0=np.array([4.]*N) # initial guess      x0=np.array([4.]*N) # initial guess
588    
589      class RosenFunc(MeteredCostFunction):      class RosenFunc(MeteredCostFunction):
590      def __init__(self):          def __init__(self):
591        super(RosenFunc, self).__init__()            super(RosenFunc, self).__init__()
592        self.provides_inverse_Hessian_approximation=False            self.provides_inverse_Hessian_approximation=False
593      def _getDualProduct(self, f0, f1):          def _getDualProduct(self, f0, f1):
594          return np.dot(f0, f1)              return np.dot(f0, f1)
595      def _getValue(self, x, *args):          def _getValue(self, x, *args):
596          return rosen(x)              return rosen(x)
597      def _getGradient(self, x, *args):          def _getGradient(self, x, *args):
598          return rosen_der(x)              return rosen_der(x)
599      def _getNorm(self,x):          def _getNorm(self,x):
600          return Lsup(x)              return Lsup(x)
601    
602      f=RosenFunc()      f=RosenFunc()
603      m=None      m=None

Legend:
Removed from v.4212  
changed lines
  Added in v.4213

  ViewVC Help
Powered by ViewVC 1.1.26