Contents of /trunk/escript/py_src/minimize.py

Revision 1811 - Thu Sep 25 23:11:13 2008 UTC by ksteube
File MIME type: text/x-python
File size: 19851 byte(s)
Copyright updated in all files


########################################################
#
# Copyright (c) 2003-2008 by University of Queensland
# Earth Systems Science Computational Center (ESSCC)
# http://www.uq.edu.au/esscc
#
# Primary Business: Queensland, Australia
# Licensed under the Open Software License version 3.0
# http://www.opensource.org/licenses/osl-3.0.php
#
########################################################

__copyright__="""Copyright (c) 2003-2008 by University of Queensland
Earth Systems Science Computational Center (ESSCC)
http://www.uq.edu.au/esscc
Primary Business: Queensland, Australia"""
__license__="""Licensed under the Open Software License version 3.0
http://www.opensource.org/licenses/osl-3.0.php"""
__url__="http://www.uq.edu.au/esscc/escript-finley"

# ******NOTICE***************
# optimize.py module by Travis E. Oliphant
#
# You may copy and use this module as you see fit with no
# guarantee implied provided you keep this notice in all copies.
# *****END NOTICE************

# A collection of optimization algorithms. Version 0.3.1

# Minimization routines
32 """optimize.py
33
34 A collection of general-purpose optimization routines using numarrayeric
35
36 fmin --- Nelder-Mead Simplex algorithm (uses only function calls)
37 fminBFGS --- Quasi-Newton method (uses function and gradient)
38 fminNCG --- Line-search Newton Conjugate Gradient (uses function, gradient and hessian (if it's provided))
39
40 """
import numarray
__version__="0.3.1"
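
# A minimal usage sketch (illustrative, mirroring the self-test at the bottom
# of this file; `rosen` and `rosen_der` are defined below):
#
#   x0 = [0.8, 1.2, 0.7]
#   xopt = fmin(rosen, x0)                        # derivative-free simplex
#   xopt = fminBFGS(rosen, x0, fprime=rosen_der)  # quasi-Newton with gradient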

def para(x):
    # Simple 1-d test problem: shifted parabola (x[0]-2)**2, minimum at x=2.
    return (x[0]-2.)**2

def para_der(x):
    # Gradient of para.
    return 2*(x-2.)

def para_hess(x):
    # Hessian of para (a constant 1x1 matrix).
    return numarray.ones((1,1))*2.

def para_hess_p(x,p):
    # Product of the Hessian of para with a vector p.
    return 2*p

def rosen(x): # The Rosenbrock function
    return numarray.sum(100.0*(x[1:]-x[:-1]**2.0)**2.0 + (1-x[:-1])**2.0)

def rosen_der(x):
    # Gradient of the Rosenbrock function.
    xm = x[1:-1]
    xm_m1 = x[:-2]
    xm_p1 = x[2:]
    der = numarray.zeros(x.shape,x.typecode())
    der[1:-1] = 200*(xm-xm_m1**2) - 400*(xm_p1 - xm**2)*xm - 2*(1-xm)
    der[0] = -400*x[0]*(x[1]-x[0]**2) - 2*(1-x[0])
    der[-1] = 200*(x[-1]-x[-2]**2)
    return der

def rosen3_hess_p(x,p):
    # Product of the 3x3 Rosenbrock Hessian at x with a vector p.
    assert(len(x)==3)
    assert(len(p)==3)
    hessp = numarray.zeros((3,),x.typecode())
    hessp[0] = (2 + 800*x[0]**2 - 400*(-x[0]**2 + x[1])) * p[0] \
               - 400*x[0]*p[1]
    hessp[1] = - 400*x[0]*p[0] \
               + (202 + 800*x[1]**2 - 400*(-x[1]**2 + x[2]))*p[1] \
               - 400*x[1] * p[2]
    hessp[2] = - 400*x[1] * p[1] \
               + 200 * p[2]
    return hessp

def rosen3_hess(x):
    # Full 3x3 Hessian of the Rosenbrock function at x.
    assert(len(x)==3)
    hess = numarray.zeros((3,3),x.typecode())
    hess[0,:] = [2 + 800*x[0]**2 - 400*(-x[0]**2 + x[1]), -400*x[0], 0]
    hess[1,:] = [-400*x[0], 202 + 800*x[1]**2 - 400*(-x[1]**2 + x[2]), -400*x[1]]
    hess[2,:] = [0, -400*x[1], 200]
    return hess
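
# Consistency sketch (illustrative, not part of the module): the two Hessian
# callbacks should agree up to round-off, e.g. with
#   x = numarray.array([0.8, 1.2, 0.7]); p = numarray.array([1., 0., 0.])
# numarray.matrixmultiply(rosen3_hess(x), p) should equal rosen3_hess_p(x, p).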


def fmin(func, x0, args=(), xtol=1e-4, ftol=1e-4, maxiter=None, maxfun=None, fulloutput=0, printmessg=1):
    """xopt,{fval,warnflag} = fmin(func, x0, args=(), xtol=1e-4, ftol=1e-4,
                                   maxiter=200*len(x0), maxfun=200*len(x0), fulloutput=0, printmessg=1)

    Uses a Nelder-Mead Simplex algorithm to find the minimum of a function
    of one or more variables.
    """
    x0 = numarray.asarray(x0)
    assert (len(x0.shape)==1)
    N = len(x0)
    if maxiter is None:
        maxiter = N * 200
    if maxfun is None:
        maxfun = N * 200

    rho = 1; chi = 2; psi = 0.5; sigma = 0.5
    one2np1 = range(1,N+1)

    sim = numarray.zeros((N+1,N),x0.typecode())
    fsim = numarray.zeros((N+1,),'d')
    sim[0] = x0
    fsim[0] = apply(func,(x0,)+args)
    nonzdelt = 0.05
    zdelt = 0.00025
    for k in range(0,N):
        y = numarray.array(x0,copy=1)
        if y[k] != 0:
            y[k] = (1+nonzdelt)*y[k]
        else:
            y[k] = zdelt

        sim[k+1] = y
        f = apply(func,(y,)+args)
        fsim[k+1] = f

    ind = numarray.argsort(fsim)
    fsim = numarray.take(fsim,ind) # sort so sim[0,:] has the lowest function value
    sim = numarray.take(sim,ind,0)

    iterations = 1
    funcalls = N+1

    while (funcalls < maxfun and iterations < maxiter):
        if (max(numarray.ravel(numarray.absolute(sim[1:]-sim[0]))) <= xtol \
            and max(numarray.absolute(fsim[0]-fsim[1:])) <= ftol):
            break

        xbar = numarray.add.reduce(sim[:-1],0) / N
        xr = (1+rho)*xbar - rho*sim[-1]             # reflection
        fxr = apply(func,(xr,)+args)
        funcalls = funcalls + 1
        doshrink = 0

        if fxr < fsim[0]:
            xe = (1+rho*chi)*xbar - rho*chi*sim[-1] # expansion
            fxe = apply(func,(xe,)+args)
            funcalls = funcalls + 1

            if fxe < fxr:
                sim[-1] = xe
                fsim[-1] = fxe
            else:
                sim[-1] = xr
                fsim[-1] = fxr
        else: # fsim[0] <= fxr
            if fxr < fsim[-2]:
                sim[-1] = xr
                fsim[-1] = fxr
            else: # fxr >= fsim[-2]
                # Perform contraction
                if fxr < fsim[-1]:
                    xc = (1+psi*rho)*xbar - psi*rho*sim[-1]
                    fxc = apply(func,(xc,)+args)
                    funcalls = funcalls + 1

                    if fxc <= fxr:
                        sim[-1] = xc
                        fsim[-1] = fxc
                    else:
                        doshrink = 1
                else:
                    # Perform an inside contraction
                    xcc = (1-psi)*xbar + psi*sim[-1]
                    fxcc = apply(func,(xcc,)+args)
                    funcalls = funcalls + 1

                    if fxcc < fsim[-1]:
                        sim[-1] = xcc
                        fsim[-1] = fxcc
                    else:
                        doshrink = 1

        if doshrink:
            for j in one2np1:
                sim[j] = sim[0] + sigma*(sim[j] - sim[0])
                fsim[j] = apply(func,(sim[j],)+args)
            funcalls = funcalls + N

        ind = numarray.argsort(fsim)
        sim = numarray.take(sim,ind,0)
        fsim = numarray.take(fsim,ind)
        iterations = iterations + 1

    x = sim[0]
    fval = min(fsim)
    warnflag = 0

    if funcalls >= maxfun:
        warnflag = 1
        if printmessg:
            print "Warning: Maximum number of function evaluations has been exceeded."
    elif iterations >= maxiter:
        warnflag = 2
        if printmessg:
            print "Warning: Maximum number of iterations has been exceeded"
    else:
        if printmessg:
            print "Optimization terminated successfully."
            print "         Current function value: %f" % fval
            print "         Iterations: %d" % iterations
            print "         Function evaluations: %d" % funcalls

    if fulloutput:
        return x, fval, warnflag
    else:
        return x
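
# Usage sketch (illustrative; see also the self-test at the bottom of the file):
#   x, fval, warnflag = fmin(rosen, [0.8, 1.2, 0.7], fulloutput=1, printmessg=0)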


def zoom(a_lo, a_hi):
    # Placeholder: the interval "zoom" phase of the Wolfe line search
    # (Wright and Nocedal, pg. 59-60) is not implemented, so line_search
    # returns None whenever it reaches this routine.
    pass



def line_search(f, fprime, xk, pk, gfk, args=(), c1=1e-4, c2=0.9, amax=50):
    """alpha, fc, gc = line_search(f, fprime, xk, pk, gfk, args=(), c1=1e-4, c2=0.9, amax=50)

    minimize the function f(xk+alpha pk) using the line search algorithm of
    Wright and Nocedal in 'Numerical Optimization', 1999, pg. 59-60
    """
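    # With phi(alpha) = f(xk + alpha*pk), the strong Wolfe conditions tested
    # below are (cf. Wright and Nocedal, eqs. 3.7a-3.7b):
    #   phi(alpha)    <= phi(0) + c1*alpha*phi'(0)   (sufficient decrease)
    #   |phi'(alpha)| <= c2*|phi'(0)|                (curvature)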

    fc = 0
    gc = 0
    alpha0 = 1.0
    phi0 = apply(f,(xk,)+args)
    phi_a0 = apply(f,(xk+alpha0*pk,)+args)
    fc = fc + 2
    derphi0 = numarray.dot(gfk,pk)
    derphi_a0 = numarray.dot(apply(fprime,(xk+alpha0*pk,)+args),pk)
    gc = gc + 1

    # check to see if alpha0 = 1 satisfies the strong Wolfe conditions.
    if (phi_a0 <= phi0 + c1*alpha0*derphi0) \
       and (numarray.absolute(derphi_a0) <= c2*numarray.absolute(derphi0)):
        return alpha0, fc, gc

    alpha0 = 0
    alpha1 = 1
    phi_a1 = phi_a0
    phi_a0 = phi0

    i = 1
    while 1:
        if (phi_a1 > phi0 + c1*alpha1*derphi0) or \
           ((phi_a1 >= phi_a0) and (i > 1)):
            return zoom(alpha0, alpha1)

        derphi_a1 = numarray.dot(apply(fprime,(xk+alpha1*pk,)+args),pk)
        gc = gc + 1
        if (numarray.absolute(derphi_a1) <= -c2*derphi0):
            return alpha1, fc, gc

        if (derphi_a1 >= 0):
            return zoom(alpha1, alpha0)

        alpha2 = (amax-alpha1)*0.25 + alpha1
        i = i + 1
        alpha0 = alpha1
        alpha1 = alpha2
        phi_a0 = phi_a1
        phi_a1 = apply(f,(xk+alpha1*pk,)+args)



def line_search_BFGS(f, xk, pk, gfk, args=(), c1=1e-4, alpha0=1):
    """alpha, fc, gc = line_search_BFGS(f, xk, pk, gfk, args=(), c1=1e-4, alpha0=1)

    minimize over alpha, the function f(xk+alpha pk) using the interpolation
    algorithm (Armijo backtracking) as suggested by
    Wright and Nocedal in 'Numerical Optimization', 1999, pg. 56-57
    """

    fc = 0
    phi0 = apply(f,(xk,)+args)               # compute f(xk)
    phi_a0 = apply(f,(xk+alpha0*pk,)+args)   # compute f(xk+alpha0*pk)
    fc = fc + 2
    derphi0 = numarray.dot(gfk,pk)

    if (phi_a0 <= phi0 + c1*alpha0*derphi0):
        return alpha0, fc, 0

    # Otherwise compute the minimizer of a quadratic interpolant:

    alpha1 = -(derphi0) * alpha0**2 / 2.0 / (phi_a0 - phi0 - derphi0 * alpha0)
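    # alpha1 minimizes the quadratic q(a) = phi0 + derphi0*a + c*a**2 with
    # c = (phi_a0 - phi0 - derphi0*alpha0)/alpha0**2, which interpolates
    # phi(0), phi'(0) and phi(alpha0); setting q'(a) = 0 gives the formula above.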
    phi_a1 = apply(f,(xk+alpha1*pk,)+args)
    fc = fc + 1

    if (phi_a1 <= phi0 + c1*alpha1*derphi0):
        return alpha1, fc, 0

    # Otherwise loop with cubic interpolation until we find an alpha which
    # satisfies the first Wolfe condition (since we are backtracking, we will
    # assume that the value of alpha is not too small and satisfies the second
    # condition).

    while 1: # we are assuming pk is a descent direction
        factor = alpha0**2 * alpha1**2 * (alpha1-alpha0)
        a = alpha0**2 * (phi_a1 - phi0 - derphi0*alpha1) - \
            alpha1**2 * (phi_a0 - phi0 - derphi0*alpha0)
        a = a / factor
        b = -alpha0**3 * (phi_a1 - phi0 - derphi0*alpha1) + \
            alpha1**3 * (phi_a0 - phi0 - derphi0*alpha0)
        b = b / factor

        alpha2 = (-b + numarray.sqrt(numarray.absolute(b**2 - 3 * a * derphi0))) / (3.0*a)
        phi_a2 = apply(f,(xk+alpha2*pk,)+args)
        fc = fc + 1

        if (phi_a2 <= phi0 + c1*alpha2*derphi0):
            return alpha2, fc, 0

        if (alpha1 - alpha2) > alpha1 / 2.0 or (1 - alpha2/alpha1) < 0.96:
            alpha2 = alpha1 / 2.0

        alpha0 = alpha1
        alpha1 = alpha2
        phi_a0 = phi_a1
        phi_a1 = phi_a2

epsilon = 1e-8   # finite-difference step used by the helpers below

def approx_fprime(xk,f,*args):
    # Approximate the gradient of f at xk by forward differences.
    f0 = apply(f,(xk,)+args)
    grad = numarray.zeros((len(xk),),'d')
    ei = numarray.zeros((len(xk),),'d')
    for k in range(len(xk)):
        ei[k] = 1.0
        grad[k] = (apply(f,(xk+epsilon*ei,)+args) - f0)/epsilon
        ei[k] = 0.0
    return grad

def approx_fhess_p(x0,p,fprime,*args):
    # Approximate the Hessian-vector product H(x0)*p by differencing fprime.
    f2 = apply(fprime,(x0+epsilon*p,)+args)
    f1 = apply(fprime,(x0,)+args)
    return (f2 - f1)/epsilon
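
# Both helpers are first-order accurate in epsilon:
#   approx_fprime:  grad_k f(x) ~ (f(x + eps*e_k) - f(x)) / eps
#   approx_fhess_p: H(x) p      ~ (grad f(x + eps*p) - grad f(x)) / eps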


def fminBFGS(f, x0, fprime=None, args=(), avegtol=1e-5, maxiter=None, fulloutput=0, printmessg=1):
    """xopt = fminBFGS(f, x0, fprime=None, args=(), avegtol=1e-5, maxiter=None, fulloutput=0, printmessg=1)

    Optimize the function, f, whose gradient is given by fprime using the
    quasi-Newton method of Broyden, Fletcher, Goldfarb, and Shanno (BFGS).
    See Wright and Nocedal, 'Numerical Optimization', 1999, pg. 198.
    """

    app_fprime = 0
    if fprime is None:
        app_fprime = 1

    x0 = numarray.asarray(x0)
    if maxiter is None:
        maxiter = len(x0)*200
    func_calls = 0
    grad_calls = 0
    k = 0
    N = len(x0)
    gtol = N*avegtol
    I = numarray.identity(N)
    Hk = I

    if app_fprime:
        gfk = apply(approx_fprime,(x0,f)+args)
        func_calls = func_calls + len(x0) + 1
    else:
        gfk = apply(fprime,(x0,)+args)
        grad_calls = grad_calls + 1
    xk = x0
    sk = [2*gtol]
    while (numarray.add.reduce(numarray.absolute(gfk)) > gtol) and (k < maxiter):
        pk = -numarray.dot(Hk,gfk)
        alpha_k, fc, gc = line_search_BFGS(f,xk,pk,gfk,args)
        func_calls = func_calls + fc
        xkp1 = xk + alpha_k * pk
        sk = xkp1 - xk
        xk = xkp1
        if app_fprime:
            gfkp1 = apply(approx_fprime,(xkp1,f)+args)
            func_calls = func_calls + gc + len(x0) + 1
        else:
            gfkp1 = apply(fprime,(xkp1,)+args)
            grad_calls = grad_calls + gc + 1

        yk = gfkp1 - gfk
        k = k + 1

        rhok = 1 / numarray.dot(yk,sk)
        A1 = I - sk[:,numarray.NewAxis] * yk[numarray.NewAxis,:] * rhok
        A2 = I - yk[:,numarray.NewAxis] * sk[numarray.NewAxis,:] * rhok
        Hk = numarray.dot(A1,numarray.dot(Hk,A2)) + rhok * sk[:,numarray.NewAxis] * sk[numarray.NewAxis,:]
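        # The update above is the standard BFGS inverse-Hessian update
        # (Wright and Nocedal, 1999):
        #   H_{k+1} = (I - rho*s*y^T) H_k (I - rho*y*s^T) + rho*s*s^T
        # with s = x_{k+1}-x_k, y = gf_{k+1}-gf_k and rho = 1/(y^T s).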
        gfk = gfkp1


    if printmessg or fulloutput:
        fval = apply(f,(xk,)+args)
        if k >= maxiter:
            warnflag = 1
            if printmessg:
                print "Warning: Maximum number of iterations has been exceeded"
                print "         Current function value: %f" % fval
                print "         Iterations: %d" % k
                print "         Function evaluations: %d" % func_calls
                print "         Gradient evaluations: %d" % grad_calls
        else:
            warnflag = 0
            if printmessg:
                print "Optimization terminated successfully."
                print "         Current function value: %f" % fval
                print "         Iterations: %d" % k
                print "         Function evaluations: %d" % func_calls
                print "         Gradient evaluations: %d" % grad_calls

    if fulloutput:
        return xk, fval, func_calls, grad_calls, warnflag
    else:
        return xk


def fminNCG(f, x0, fprime, fhess_p=None, fhess=None, args=(), avextol=1e-5, maxiter=None, fulloutput=0, printmessg=1):
    """xopt = fminNCG(f, x0, fprime, fhess_p=None, fhess=None, args=(), avextol=1e-5, maxiter=None, fulloutput=0, printmessg=1)

    Optimize the function, f, whose gradient is given by fprime using the
    Newton-CG method. fhess_p must compute the Hessian times an arbitrary
    vector. If it is not given, finite differences on fprime are used to
    compute it. See Wright and Nocedal, 'Numerical Optimization', 1999,
    pg. 140.
    """

    x0 = numarray.asarray(x0)
    if maxiter is None:
        maxiter = len(x0)*200  # same default as fminBFGS; k < maxiter cannot hold with maxiter=None
    fcalls = 0
    gcalls = 0
    hcalls = 0
    approx_hessp = 0
    if fhess_p is None and fhess is None: # Define hessian product
        approx_hessp = 1

    xtol = len(x0)*avextol
    update = [2*xtol]
    xk = x0
    k = 0
    while (numarray.add.reduce(numarray.absolute(update)) > xtol) and (k < maxiter):
        # Compute a search direction pk by applying the CG method to
        #   del2 f(xk) p = - grad f(xk)
        # starting from 0.
        b = -apply(fprime,(xk,)+args)
        gcalls = gcalls + 1
        maggrad = numarray.add.reduce(numarray.absolute(b))
        eta = min([0.5,numarray.sqrt(maggrad)])
        termcond = eta * maggrad
        xsupi = numarray.zeros((len(x0),))
        ri = -b
        psupi = -ri
        i = 0
        dri0 = numarray.dot(ri,ri)

        if fhess is not None: # you want to compute hessian once.
            A = apply(fhess,(xk,)+args)
            hcalls = hcalls + 1

        while numarray.add.reduce(numarray.absolute(ri)) > termcond:
            if fhess is None:
                if approx_hessp:
                    Ap = apply(approx_fhess_p,(xk,psupi,fprime)+args)
                    gcalls = gcalls + 2
                else:
                    Ap = apply(fhess_p,(xk,psupi)+args)
                    hcalls = hcalls + 1
            else:
                # Ap = numarray.dot(A,psupi)
                Ap = numarray.matrixmultiply(A,psupi)
            # check curvature: stop if the Hessian is not positive definite
            # along psupi
            curv = numarray.dot(psupi,Ap)
            if (curv <= 0):
                if (i > 0):
                    break
                else:
                    xsupi = xsupi + dri0/curv * psupi
                    break
            alphai = dri0 / curv
            xsupi = xsupi + alphai * psupi
            ri = ri + alphai * Ap
            dri1 = numarray.dot(ri,ri)
            betai = dri1 / dri0
            psupi = -ri + betai * psupi
            i = i + 1
            dri0 = dri1  # update numarray.dot(ri,ri) for next time.
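        # The CG iteration above is truncated once the l1 norm of the residual
        # drops below eta*||grad f(xk)||_1, eta = min(0.5, sqrt(||grad f(xk)||_1)),
        # the inexact-Newton stopping rule of Wright and Nocedal, pg. 140.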

        pk = xsupi  # search direction is solution to system.
        gfk = -b    # gradient at xk
        alphak, fc, gc = line_search_BFGS(f,xk,pk,gfk,args)
        fcalls = fcalls + fc
        gcalls = gcalls + gc

        update = alphak * pk
        xk = xk + update
        k = k + 1

    if printmessg or fulloutput:
        fval = apply(f,(xk,)+args)
        if k >= maxiter:
            warnflag = 1
            if printmessg:
                print "Warning: Maximum number of iterations has been exceeded"
                print "         Current function value: %f" % fval
                print "         Iterations: %d" % k
                print "         Function evaluations: %d" % fcalls
                print "         Gradient evaluations: %d" % gcalls
                print "         Hessian evaluations: %d" % hcalls
        else:
            warnflag = 0
            if printmessg:
                print "Optimization terminated successfully."
                print "         Current function value: %f" % fval
                print "         Iterations: %d" % k
                print "         Function evaluations: %d" % fcalls
                print "         Gradient evaluations: %d" % gcalls
                print "         Hessian evaluations: %d" % hcalls

    if fulloutput:
        return xk, fval, fcalls, gcalls, hcalls, warnflag
    else:
        return xk



if __name__ == "__main__":
    import time

    times = []
    algor = []
    x0 = [0.8,1.2,0.7]
    start = time.time()
    x = fmin(rosen,x0)
    print x
    times.append(time.time() - start)
    algor.append('Nelder-Mead Simplex\t')

    start = time.time()
    x = fminBFGS(rosen, x0, fprime=rosen_der, maxiter=80)
    print x
    times.append(time.time() - start)
    algor.append('BFGS Quasi-Newton\t')

    start = time.time()
    x = fminBFGS(rosen, x0, avegtol=1e-4, maxiter=100)
    print x
    times.append(time.time() - start)
    algor.append('BFGS without gradient\t')

    start = time.time()
    x = fminNCG(rosen, x0, rosen_der, fhess_p=rosen3_hess_p, maxiter=80)
    print x
    times.append(time.time() - start)
    algor.append('Newton-CG with hessian product')

    start = time.time()
    x = fminNCG(rosen, x0, rosen_der, fhess=rosen3_hess, maxiter=80)
    print x
    times.append(time.time() - start)
    algor.append('Newton-CG with full hessian')

    print "\nMinimizing the Rosenbrock function of order 3\n"
    print " Algorithm \t\t\t Seconds"
    print "===========\t\t\t ========="
    for k in range(len(algor)):
        print algor[k], "\t -- ", times[k]

    times = []
    algor = []
    x0 = [1.,]
    start = time.time()
    x = fmin(para,x0)
    print x
    times.append(time.time() - start)
    algor.append('Nelder-Mead Simplex\t')

    start = time.time()
    x = fminBFGS(para, x0, fprime=para_der, maxiter=80)
    print x
    times.append(time.time() - start)
    algor.append('BFGS Quasi-Newton\t')

    start = time.time()
    x = fminBFGS(para, x0, avegtol=1e-4, maxiter=100)
    print x
    times.append(time.time() - start)
    algor.append('BFGS without gradient\t')

    start = time.time()
    x = fminNCG(para, x0, para_der, fhess_p=para_hess_p, maxiter=80)
    print x
    times.append(time.time() - start)
    algor.append('Newton-CG with hessian product')

    start = time.time()
    x = fminNCG(para, x0, para_der, fhess=para_hess, maxiter=80)
    print x
    times.append(time.time() - start)
    algor.append('Newton-CG with full hessian')

    print "\nMinimizing the shifted parabola (x-2)**2\n"
    print " Algorithm \t\t\t Seconds"
    print "===========\t\t\t ========="
    for k in range(len(algor)):
        print algor[k], "\t -- ", times[k]
