/[escript]/trunk/downunder/test/python/run_minimizers.py
ViewVC logotype

Contents of /trunk/downunder/test/python/run_minimizers.py

Parent Directory | Revision Log


Revision 4154 - (show annotations)
Tue Jan 22 09:30:23 2013 UTC (7 years, 6 months ago) by jfenwick
File MIME type: text/x-python
File size: 6089 byte(s)
Round 1 of copyright fixes
1
2 ##############################################################################
3 #
4 # Copyright (c) 2012-2013 by University of Queensland
5 # http://www.uq.edu.au
6 #
7 # Primary Business: Queensland, Australia
8 # Licensed under the Open Software License version 3.0
9 # http://www.opensource.org/licenses/osl-3.0.php
10 #
11 # Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12 # Development since 2012 by School of Earth Sciences
13 #
14 ##############################################################################
15
16 __copyright__="""Copyright (c) 2012-2013 by University of Queensland
17 http://www.uq.edu.au
18 Primary Business: Queensland, Australia"""
19 __license__="""Licensed under the Open Software License version 3.0
20 http://www.opensource.org/licenses/osl-3.0.php"""
21 __url__="https://launchpad.net/escript-finley"
22
23 import logging
24 import numpy as np
25 import unittest
26 import sys
27 from esys.downunder.minimizers import *
28 from esys.downunder.costfunctions import CostFunction
29
# Dimension of the test function, i.e. the length of the vectors that the
# test cases below build with it.
N = 10

# Give the 'inv' logger an INFO-level stream handler; this is mainly to
# avoid warning messages.
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
logger = logging.getLogger('inv')
logger.setLevel(logging.INFO)
logger.addHandler(handler)
39
40 # Rosenbrock test function to be minimized. The minimum is 0 and lies at
41 # [1,1,...,1].
class RosenFunc(CostFunction):
    """Rosenbrock test function to be minimized.

    For a 1-D numpy array ``x`` the value is the sum over consecutive
    entries of ``100*(x[i+1]-x[i]**2)**2 + (1-x[i])**2``. The minimum is 0
    and lies at ``[1,1,...,1]``.
    """

    def __init__(self):
        super(RosenFunc, self).__init__()

    def getDualProduct(self, f0, f1):
        """Return the dot product of ``f0`` and ``f1``."""
        return np.dot(f0, f1)

    def getNorm(self, x):
        """Return the maximum norm of ``x`` (its largest absolute entry).

        The absolute value is taken element-wise *before* the maximum so
        that negative entries are accounted for; ``abs(x.max())`` would
        report 1 for ``[-5, 1]`` instead of 5.
        """
        return np.abs(x).max()

    def getGradient(self, x, *args):
        """Return the gradient of the Rosenbrock function at ``x``.

        Any additional positional arguments are accepted and ignored.
        """
        # interior entries and their left/right neighbours
        xm = x[1:-1]
        xm_m1 = x[:-2]
        xm_p1 = x[2:]
        der = np.zeros_like(x)
        der[1:-1] = 200*(xm-xm_m1**2) - 400*(xm_p1 - xm**2)*xm - 2*(1-xm)
        # the first and last entries each appear in only one pair of the sum
        der[0] = -400*x[0]*(x[1]-x[0]**2) - 2*(1-x[0])
        der[-1] = 200*(x[-1]-x[-2]**2)
        return der

    def getValue(self, x, *args):
        """Return the value of the Rosenbrock function at ``x``.

        Any additional positional arguments are accepted and ignored.
        """
        return np.sum(100.0*(x[1:]-x[:-1]**2.)**2. + (1-x[:-1])**2.)
60
class TestMinimizerLBFGS(unittest.TestCase):
    """Tests of MinimizerLBFGS applied to the RosenFunc cost function."""

    def setUp(self):
        # every test starts from [2,...,2]; the expected solution is [1,...,1]
        self.xstar = np.ones(N)
        self.x0 = 2. * np.ones(N)
        self.f = RosenFunc()
        self.minimizer = MinimizerLBFGS(self.f)

    def test_max_iterations(self):
        self.minimizer.setTolerance(1e-10)
        self.minimizer.setMaxIterations(1)
        # with only one iteration allowed run() is expected to raise
        self.assertRaises(MinimizerMaxIterReached,
                          self.minimizer.run, self.x0)

    def test_solution(self):
        self.minimizer.setTolerance(1e-8)
        self.minimizer.setMaxIterations(100)
        status = self.minimizer.run(self.x0)
        solution = self.minimizer.getResult()
        # We should be able to get a solution in under 100 iterations
        self.assertEqual(status, MinimizerLBFGS.TOLERANCE_REACHED)
        max_error = np.amax(abs(solution - self.xstar))
        self.assertAlmostEqual(max_error, 0.)

    def test_callback(self):
        # one entry is recorded per callback invocation
        calls = []

        def callback(k, x, fg, gf):
            calls.append(k)

        self.minimizer.setCallback(callback)
        self.minimizer.setTolerance(1e-8)
        self.minimizer.setMaxIterations(10)
        try:
            self.minimizer.run(self.x0)
        except MinimizerMaxIterReached:
            # reaching the iteration limit is fine here, only the number
            # of callback invocations is checked
            pass
        # callback should be called once for each iteration (including 0th)
        self.assertEqual(len(calls), 11)
95
class TestMinimizerBFGS(unittest.TestCase):
    """Tests of MinimizerBFGS applied to the RosenFunc cost function."""

    def setUp(self):
        # every test starts from [2,...,2]; the expected solution is [1,...,1]
        self.xstar = np.ones(N)
        self.x0 = 2. * np.ones(N)
        self.f = RosenFunc()
        self.minimizer = MinimizerBFGS(self.f)

    def test_max_iterations(self):
        self.minimizer.setTolerance(1e-10)
        self.minimizer.setMaxIterations(1)
        # with only one iteration allowed run() is expected to report
        # that the iteration limit was reached
        status = self.minimizer.run(self.x0)
        self.assertEqual(status, MinimizerBFGS.MAX_ITERATIONS_REACHED)

    def test_solution(self):
        self.minimizer.setTolerance(1e-6)
        self.minimizer.setMaxIterations(100)
        self.minimizer.setOptions(initialHessian=1e-3)
        status = self.minimizer.run(self.x0)
        solution = self.minimizer.getResult()
        # We should be able to get a solution in under 100 iterations
        self.assertEqual(status, MinimizerBFGS.TOLERANCE_REACHED)
        max_error = np.amax(abs(solution - self.xstar))
        self.assertAlmostEqual(max_error, 0.)

    def test_callback(self):
        # one entry is recorded per callback invocation
        calls = []

        def callback(k, x, fg, gf):
            calls.append(k)

        self.minimizer.setCallback(callback)
        self.minimizer.setTolerance(1e-10)
        self.minimizer.setMaxIterations(10)
        self.minimizer.run(self.x0)
        # callback should be called once for each iteration (including 0th)
        self.assertEqual(len(calls), 11)
129
class TestMinimizerNLCG(unittest.TestCase):
    """Tests of MinimizerNLCG applied to the RosenFunc cost function."""

    def setUp(self):
        # every test starts from [2,...,2]; the expected solution is [1,...,1]
        self.xstar = np.ones(N)
        self.x0 = 2. * np.ones(N)
        self.f = RosenFunc()
        self.minimizer = MinimizerNLCG(self.f)

    def test_max_iterations(self):
        self.minimizer.setTolerance(1e-10)
        self.minimizer.setMaxIterations(1)
        # with only one iteration allowed run() is expected to report
        # that the iteration limit was reached
        status = self.minimizer.run(self.x0)
        self.assertEqual(status, MinimizerNLCG.MAX_ITERATIONS_REACHED)

    def test_solution(self):
        self.minimizer.setTolerance(1e-4)
        self.minimizer.setMaxIterations(400)
        status = self.minimizer.run(self.x0)
        solution = self.minimizer.getResult()
        # We should be able to reach the set tolerance within the
        # iteration limit
        self.assertEqual(status, MinimizerNLCG.TOLERANCE_REACHED)
        max_error = np.amax(abs(solution - self.xstar))
        self.assertAlmostEqual(max_error, 0., places=3)

    def test_callback(self):
        # one entry is recorded per callback invocation
        calls = []

        def callback(k, x, fg, gf):
            calls.append(k)

        self.minimizer.setCallback(callback)
        self.minimizer.setTolerance(1e-10)
        self.minimizer.setMaxIterations(10)
        self.minimizer.run(self.x0)
        # callback should be called once for each iteration (including 0th)
        self.assertEqual(len(calls), 11)
162
163
if __name__ == "__main__":
    # Collect the tests through a TestLoader: unittest.makeSuite() is
    # deprecated since Python 3.11 and was removed in Python 3.13.
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for case in (TestMinimizerLBFGS, TestMinimizerBFGS, TestMinimizerNLCG):
        suite.addTest(loader.loadTestsFromTestCase(case))
    s = unittest.TextTestRunner(verbosity=2).run(suite)
    # signal failure to the calling process with a non-zero exit status
    if not s.wasSuccessful():
        sys.exit(1)
171

  ViewVC Help
Powered by ViewVC 1.1.26