/[escript]/trunk/downunder/test/python/run_minimizers.py
ViewVC logotype

Contents of /trunk/downunder/test/python/run_minimizers.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 4213 - (show annotations)
Tue Feb 19 01:16:29 2013 UTC (7 years, 5 months ago) by caltinay
File MIME type: text/x-python
File size: 6009 byte(s)
Some cleanup and more consistent logging.

1
2 ##############################################################################
3 #
4 # Copyright (c) 2012-2013 by University of Queensland
5 # http://www.uq.edu.au
6 #
7 # Primary Business: Queensland, Australia
8 # Licensed under the Open Software License version 3.0
9 # http://www.opensource.org/licenses/osl-3.0.php
10 #
11 # Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12 # Development since 2012 by School of Earth Sciences
13 #
14 ##############################################################################
15
16 __copyright__="""Copyright (c) 2012-2013 by University of Queensland
17 http://www.uq.edu.au
18 Primary Business: Queensland, Australia"""
19 __license__="""Licensed under the Open Software License version 3.0
20 http://www.opensource.org/licenses/osl-3.0.php"""
21 __url__="https://launchpad.net/escript-finley"
22
23 import logging
24 import numpy as np
25 import unittest
26 import sys
27 from esys.downunder.minimizers import *
28 from esys.downunder.costfunctions import CostFunction
29
# dimension of the test function's argument vector
N = 10

# set up logging before any test runs; this is mainly to avoid warning
# messages
logging.basicConfig(level=logging.INFO, format='%(name)s: %(message)s')
35
36 # Rosenbrock test function to be minimized. The minimum is 0 and lies at
37 # [1,1,...,1].
class RosenFunc(CostFunction):
    """
    Rosenbrock test function to be minimized.

    The value is ``sum(100*(x[i+1]-x[i]**2)**2 + (1-x[i])**2)`` taken over
    all consecutive pairs of entries of ``x``. The minimum is 0 and lies at
    ``[1,1,...,1]``.
    """
    def __init__(self):
        super(RosenFunc, self).__init__()

    def getDualProduct(self, f0, f1):
        """
        Returns the dot product of ``f0`` and ``f1``.
        """
        return np.dot(f0, f1)

    def getNorm(self, x):
        """
        Returns the maximum norm of ``x``, i.e. its largest absolute entry.
        """
        # The absolute value has to be taken *before* the maximum:
        # abs(x.max()) ignores negative entries of larger magnitude
        # (e.g. [-5, 1] would give 1 instead of 5) and is not a norm.
        return np.abs(x).max()

    def getGradient(self, x, *args):
        """
        Returns the gradient of the test function at ``x`` as an array of
        the same shape as ``x``. Additional arguments are ignored.
        """
        xm = x[1:-1]
        xm_m1 = x[:-2]
        xm_p1 = x[2:]
        der = np.zeros_like(x)
        # interior entries appear in two consecutive terms of the sum
        der[1:-1] = 200*(xm-xm_m1**2) - 400*(xm_p1 - xm**2)*xm - 2*(1-xm)
        # first and last entry appear in a single term only
        der[0] = -400*x[0]*(x[1]-x[0]**2) - 2*(1-x[0])
        der[-1] = 200*(x[-1]-x[-2]**2)
        return der

    def getValue(self, x, *args):
        """
        Returns the value of the test function at ``x``. Additional
        arguments are ignored.
        """
        return np.sum(100.0*(x[1:]-x[:-1]**2.)**2. + (1-x[:-1])**2.)
56
class TestMinimizerLBFGS(unittest.TestCase):
    """
    Runs MinimizerLBFGS on the RosenFunc cost function.
    """
    def setUp(self):
        self.f = RosenFunc()
        self.minimizer = MinimizerLBFGS(self.f)
        # initial guess [2,...,2] and expected solution [1,...,1]
        self.x0 = 2. * np.ones(N)
        self.xstar = np.ones(N)

    def test_max_iterations(self):
        self.minimizer.setTolerance(1e-10)
        self.minimizer.setMaxIterations(1)
        # with a single iteration allowed run() is expected to raise
        with self.assertRaises(MinimizerMaxIterReached):
            self.minimizer.run(self.x0)

    def test_solution(self):
        self.minimizer.setTolerance(1e-8)
        self.minimizer.setMaxIterations(100)
        reason = self.minimizer.run(self.x0)
        result = self.minimizer.getResult()
        # We should be able to get a solution in under 100 iterations
        self.assertEqual(reason, MinimizerLBFGS.TOLERANCE_REACHED)
        max_error = np.amax(abs(result - self.xstar))
        self.assertAlmostEqual(max_error, 0.)

    def test_callback(self):
        # one entry is recorded per callback invocation
        calls = []

        def callback(k, x, fg, gf):
            calls.append(k)

        self.minimizer.setCallback(callback)
        self.minimizer.setTolerance(1e-8)
        self.minimizer.setMaxIterations(10)
        try:
            self.minimizer.run(self.x0)
        except MinimizerMaxIterReached:
            # running out of iterations is fine here, only the number of
            # callback invocations is checked
            pass
        # callback should be called once for each iteration (including 0th)
        self.assertEqual(len(calls), 11)
91
class TestMinimizerBFGS(unittest.TestCase):
    """
    Runs MinimizerBFGS on the RosenFunc cost function.
    """
    def setUp(self):
        self.f = RosenFunc()
        self.minimizer = MinimizerBFGS(self.f)
        # initial guess [2,...,2] and expected solution [1,...,1]
        self.x0 = 2. * np.ones(N)
        self.xstar = np.ones(N)

    def test_max_iterations(self):
        self.minimizer.setTolerance(1e-10)
        self.minimizer.setMaxIterations(1)
        # with a single iteration allowed run() is expected to report that
        # the iteration limit was hit
        self.assertEqual(self.minimizer.run(self.x0),
                         MinimizerBFGS.MAX_ITERATIONS_REACHED)

    def test_solution(self):
        self.minimizer.setTolerance(1e-6)
        self.minimizer.setMaxIterations(100)
        self.minimizer.setOptions(initialHessian=1e-3)
        reason = self.minimizer.run(self.x0)
        result = self.minimizer.getResult()
        # We should be able to get a solution in under 100 iterations
        self.assertEqual(reason, MinimizerBFGS.TOLERANCE_REACHED)
        max_error = np.amax(abs(result - self.xstar))
        self.assertAlmostEqual(max_error, 0.)

    def test_callback(self):
        # one entry is recorded per callback invocation
        calls = []

        def callback(k, x, fg, gf):
            calls.append(k)

        self.minimizer.setCallback(callback)
        self.minimizer.setTolerance(1e-10)
        self.minimizer.setMaxIterations(10)
        self.minimizer.run(self.x0)
        # callback should be called once for each iteration (including 0th)
        self.assertEqual(len(calls), 11)
125
class TestMinimizerNLCG(unittest.TestCase):
    """
    Runs MinimizerNLCG on the RosenFunc cost function.
    """
    def setUp(self):
        self.f = RosenFunc()
        self.minimizer = MinimizerNLCG(self.f)
        # initial guess [2,...,2] and expected solution [1,...,1]
        self.x0 = 2. * np.ones(N)
        self.xstar = np.ones(N)

    def test_max_iterations(self):
        self.minimizer.setTolerance(1e-10)
        self.minimizer.setMaxIterations(1)
        # with a single iteration allowed run() is expected to report that
        # the iteration limit was hit
        self.assertEqual(self.minimizer.run(self.x0),
                         MinimizerNLCG.MAX_ITERATIONS_REACHED)

    def test_solution(self):
        self.minimizer.setTolerance(1e-4)
        self.minimizer.setMaxIterations(400)
        reason = self.minimizer.run(self.x0)
        result = self.minimizer.getResult()
        # We should be able to get a solution to set tolerance in #iterations
        self.assertEqual(reason, MinimizerNLCG.TOLERANCE_REACHED)
        max_error = np.amax(abs(result - self.xstar))
        # the looser tolerance only warrants a comparison to 3 places
        self.assertAlmostEqual(max_error, 0., places=3)

    def test_callback(self):
        # one entry is recorded per callback invocation
        calls = []

        def callback(k, x, fg, gf):
            calls.append(k)

        self.minimizer.setCallback(callback)
        self.minimizer.setTolerance(1e-10)
        self.minimizer.setMaxIterations(10)
        self.minimizer.run(self.x0)
        # callback should be called once for each iteration (including 0th)
        self.assertEqual(len(calls), 11)
158
159
if __name__ == "__main__":
    # unittest.makeSuite() is deprecated since Python 3.11 and was removed in
    # Python 3.13; TestLoader.loadTestsFromTestCase() is the equivalent that
    # works on old and new interpreters alike.
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for test_case in (TestMinimizerLBFGS, TestMinimizerBFGS, TestMinimizerNLCG):
        suite.addTest(loader.loadTestsFromTestCase(test_case))
    s = unittest.TextTestRunner(verbosity=2).run(suite)
    # report failure to the calling process through the exit status
    if not s.wasSuccessful():
        sys.exit(1)
167

  ViewVC Help
Powered by ViewVC 1.1.26