/[escript]/trunk/downunder/test/python/run_datasources.py
ViewVC logotype

Diff of /trunk/downunder/test/python/run_datasources.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3985 by caltinay, Fri Sep 21 06:44:17 2012 UTC revision 4362 by caltinay, Tue Apr 16 04:37:13 2013 UTC
# Line 1  Line 1 
1    
2  ##############################################################################  ##############################################################################
3  #  #
4  # Copyright (c) 2003-2012 by University of Queensland  # Copyright (c) 2003-2013 by University of Queensland
5  # http://www.uq.edu.au  # http://www.uq.edu.au
6  #  #
7  # Primary Business: Queensland, Australia  # Primary Business: Queensland, Australia
# Line 13  Line 13 
13  #  #
14  ##############################################################################  ##############################################################################
15    
16  __copyright__="""Copyright (c) 2003-2012 by University of Queensland  __copyright__="""Copyright (c) 2003-2013 by University of Queensland
17  http://www.uq.edu.au  http://www.uq.edu.au
18  Primary Business: Queensland, Australia"""  Primary Business: Queensland, Australia"""
19  __license__="""Licensed under the Open Software License version 3.0  __license__="""Licensed under the Open Software License version 3.0
# Line 25  import numpy as np Line 25  import numpy as np
25  import os  import os
26  import sys  import sys
27  import unittest  import unittest
28  from esys.escript import inf,sup,saveDataCSV  from esys.escript import inf,sup,saveDataCSV,getMPISizeWorld
 from esys import ripley, finley, dudley  
29  from esys.downunder.datasources import *  from esys.downunder.datasources import *
30    from esys.downunder.domainbuilder import DomainBuilder
31    
32  # this is mainly to avoid warning messages  # this is mainly to avoid warning messages
33  logger=logging.getLogger('inv')  logging.basicConfig(format='%(name)s: %(message)s', level=logging.INFO)
 logger.setLevel(logging.INFO)  
 handler=logging.StreamHandler()  
 handler.setLevel(logging.INFO)  
 logger.addHandler(handler)  
34    
35  try:  try:
36      TEST_DATA_ROOT=os.environ['DOWNUNDER_TEST_DATA_ROOT']      TEST_DATA_ROOT=os.environ['DOWNUNDER_TEST_DATA_ROOT']
37  except KeyError:  except KeyError:
38      TEST_DATA_ROOT='.'      TEST_DATA_ROOT='ref_data'
39    
40  try:  try:
41      WORKDIR=os.environ['DOWNUNDER_WORKDIR']      WORKDIR=os.environ['DOWNUNDER_WORKDIR']
# Line 50  except KeyError: Line 46  except KeyError:
46  ERS_DATA = os.path.join(TEST_DATA_ROOT, 'ermapper_test.ers')  ERS_DATA = os.path.join(TEST_DATA_ROOT, 'ermapper_test.ers')
47  ERS_REF = os.path.join(TEST_DATA_ROOT, 'ermapper_test.csv')  ERS_REF = os.path.join(TEST_DATA_ROOT, 'ermapper_test.csv')
48  ERS_NULL = -99999 * 1e-6  ERS_NULL = -99999 * 1e-6
49  ERS_SIZE = [10,10]  ERS_SIZE = [20,15]
50  ERS_ORIGIN = [309097.0, 6321502.0]  ERS_ORIGIN = [309241.0, 6318655.0]
51    NC_DATA = os.path.join(TEST_DATA_ROOT, 'netcdf_test.nc')
52    NC_REF = os.path.join(TEST_DATA_ROOT, 'netcdf_test.csv')
53    NC_NULL = 0.
54    NC_SIZE = [20,15]
55    NC_ORIGIN = [403320.91466610413, 6414860.942530109]
56    NUMPY_NULL = -123.4
57  VMIN=-10000.  VMIN=-10000.
58  VMAX=10000  VMAX=10000.
59  NE_V=15  NE_V=15
60  ALT=0.  ALT=0.
61  PAD_XY=7  PAD_X=3
62  PAD_Z=3  PAD_Y=2
63    
64  class TestERSDataSource(unittest.TestCase):  class TestNumpyData(unittest.TestCase):
65        def test_numpy_argument_check(self):
66            # invalid data type
67            self.assertRaises(ValueError, NumpyData, '_mydatatype_', [1,2])
68            # invalid shape of data
69            self.assertRaises(ValueError, NumpyData, DataSource.GRAVITY, 42)
70            # invalid shape of data
71            self.assertRaises(ValueError, NumpyData, DataSource.GRAVITY, np.zeros((2,2,2,2)))
72            # invalid shape of error
73            self.assertRaises(ValueError, NumpyData, DataSource.GRAVITY, [1,2], [1,2,3])
74            # invalid shape of length
75            self.assertRaises(ValueError, NumpyData, DataSource.GRAVITY, [1,2], [1,2], [2,3,2])
76    
77        def test_numpy_data_1d(self):
78            DIM=1
79            testdata = np.arange(20)
80            error = 1.*np.ones(testdata.shape)
81            source = NumpyData(DataSource.GRAVITY, testdata, null_value=NUMPY_NULL)
82            X0,NP,DX=source.getDataExtents()
83            for i in range(DIM):
84                self.assertAlmostEqual(X0[i], 0., msg="Data origin wrong")
85                self.assertEqual(NP[i], testdata.shape[DIM-i-1], msg="Wrong number of data points")
86                self.assertAlmostEqual(DX[i], 1000./testdata.shape[DIM-i-1], msg="Wrong cell size")
87    
88            domainbuilder=DomainBuilder(dim=2)
89            domainbuilder.addSource(source)
90            domainbuilder.setVerticalExtents(depth=-VMIN, air_layer=VMAX, num_cells=NE_V)
91            domainbuilder.setElementPadding(PAD_X)
92            dom=domainbuilder.getDomain()
93            g,s=domainbuilder.getGravitySurveys()[0]
94    
95            outfn=os.path.join(WORKDIR, '_npdata1d.csv')
96            saveDataCSV(outfn, g=g, s=s)
97    
98            DV=(VMAX-VMIN)/NE_V
99    
100            # check data
101            nx=NP[0]+2*PAD_X
102            nz=NE_V
103            z_data=int(np.round((ALT-VMIN)/DV)-1)
104    
105            out=np.genfromtxt(outfn, delimiter=',', skip_header=1, dtype=float)
106            # recompute nz since ripley might have adjusted number of elements
107            nz=len(out)/nx
108            g_out=out[:,0].reshape(nz,nx)
109            s_out=out[:,1].reshape(nz,nx)
110            self.assertAlmostEqual(np.abs(
111                g_out[z_data, PAD_X:PAD_X+NP[0]]-testdata).max(),
112                0., msg="Difference in gravity data area")
113    
114            self.assertAlmostEqual(np.abs(
115                s_out[z_data, PAD_X:PAD_X+NP[0]]-error).max(),
116                0., msg="Difference in error data area")
117    
118            # overwrite data -> should only be padding value left
119            g_out[z_data, PAD_X:PAD_X+NP[0]]=NUMPY_NULL
120            self.assertAlmostEqual(np.abs(g_out-NUMPY_NULL).max(), 0.,
121                    msg="Wrong values in padding area")
122    
123        def test_numpy_data_2d(self):
124            DIM=2
125            testdata = np.arange(20*21).reshape(20,21)
126            error = 1.*np.ones(testdata.shape)
127            source = NumpyData(DataSource.GRAVITY, testdata, null_value=NUMPY_NULL)
128            X0,NP,DX=source.getDataExtents()
129            for i in range(DIM):
130                self.assertAlmostEqual(X0[i], 0., msg="Data origin wrong")
131                self.assertEqual(NP[i], testdata.shape[DIM-i-1], msg="Wrong number of data points")
132                self.assertAlmostEqual(DX[i], 1000./testdata.shape[DIM-i-1], msg="Wrong cell size")
133    
134            domainbuilder=DomainBuilder(dim=3)
135            domainbuilder.addSource(source)
136            domainbuilder.setVerticalExtents(depth=-VMIN, air_layer=VMAX, num_cells=NE_V)
137            domainbuilder.setElementPadding(PAD_X, PAD_Y)
138            dom=domainbuilder.getDomain()
139            g,s=domainbuilder.getGravitySurveys()[0]
140    
141            outfn=os.path.join(WORKDIR, '_npdata2d.csv')
142            saveDataCSV(outfn, g=g, s=s)
143    
144            DV=(VMAX-VMIN)/NE_V
145    
146            # check data
147            nx=NP[0]+2*PAD_X
148            ny=NP[1]+2*PAD_Y
149            nz=NE_V
150            z_data=int(np.round((ALT-VMIN)/DV)-1)
151    
152            out=np.genfromtxt(outfn, delimiter=',', skip_header=1, dtype=float)
153            # recompute nz since ripley might have adjusted number of elements
154            nz=len(out)/(nx*ny)
155            g_out=out[:,0].reshape(nz,ny,nx)
156            s_out=out[:,1].reshape(nz,ny,nx)
157            self.assertAlmostEqual(np.abs(
158                g_out[z_data, PAD_Y:PAD_Y+NP[1], PAD_X:PAD_X+NP[0]]-testdata).max(),
159                0., msg="Difference in gravity data area")
160    
161            self.assertAlmostEqual(np.abs(
162                s_out[z_data, PAD_Y:PAD_Y+NP[1], PAD_X:PAD_X+NP[0]]-error).max(),
163                0., msg="Difference in error data area")
164    
165            # overwrite data -> should only be padding value left
166            g_out[z_data, PAD_Y:PAD_Y+NP[1], PAD_X:PAD_X+NP[0]]=NUMPY_NULL
167            self.assertAlmostEqual(np.abs(g_out-NUMPY_NULL).max(), 0.,
168                    msg="Wrong values in padding area")
169    
170    
171    class TestErMapperData(unittest.TestCase):
172      def test_ers_with_padding(self):      def test_ers_with_padding(self):
173          source = ERSDataSource(headerfile=ERS_DATA, vertical_extents=(VMIN,VMAX,NE_V), alt_of_data=ALT)          source = ErMapperData(DataSource.GRAVITY, headerfile=ERS_DATA,
174          source.setPadding(PAD_XY,PAD_Z)                                altitude=ALT)
175          dom=source.getDomain()          domainbuilder=DomainBuilder()
176          g,s=source.getGravityAndStdDev()          domainbuilder.addSource(source)
177            domainbuilder.setVerticalExtents(depth=-VMIN, air_layer=VMAX, num_cells=NE_V)
178            domainbuilder.setElementPadding(PAD_X,PAD_Y)
179            dom=domainbuilder.getDomain()
180            g,s=domainbuilder.getGravitySurveys()[0]
181    
182          outfn=os.path.join(WORKDIR, '_ersdata.csv')          outfn=os.path.join(WORKDIR, '_ersdata.csv')
183          saveDataCSV(outfn, g=g[2], s=s)          saveDataCSV(outfn, g=g, s=s)
184    
185          X0,NP,DX=source.getDataExtents()          X0,NP,DX=source.getDataExtents()
186          V0,NV,DV=source.getVerticalExtents()          DV=(VMAX-VMIN)/NE_V
187    
188          # check metadata          # check metadata
189          self.assertEqual(NP, ERS_SIZE, msg="Wrong number of data points")          self.assertEqual(NP, ERS_SIZE, msg="Wrong number of data points")
190          # this only works if gdal is available          # this test only works if gdal is available
191          #self.assertAlmostEqual(X0, ERS_ORIGIN, msg="Data origin wrong")          try:
192                import osgeo.osr
193                for i in range(len(ERS_ORIGIN)):
194                    self.assertAlmostEqual(X0[i], ERS_ORIGIN[i], msg="Data origin wrong")
195            except ImportError:
196                print("Skipping test of data origin since gdal is not installed.")
197    
198          # check data          # check data
199          nx=NP[0]+2*PAD_XY          nx=NP[0]+2*PAD_X
200          ny=NP[1]+2*PAD_XY          ny=NP[1]+2*PAD_Y
201          nz=NE_V+2*PAD_Z          nz=NE_V
202          z_data=int(np.round((ALT-V0)/DV)-1+PAD_Z)          z_data=int(np.round((ALT-VMIN)/DV)-1)
203    
204          ref=np.genfromtxt(ERS_REF, delimiter=',', dtype=float)          ref=np.genfromtxt(ERS_REF, delimiter=',', dtype=float)
205          g_ref=ref[:,0].reshape(NP)          g_ref=ref[:,0].reshape((NP[1],NP[0]))
206          s_ref=ref[:,1].reshape(NP)          s_ref=ref[:,1].reshape((NP[1],NP[0]))
207    
208          out=np.genfromtxt(outfn, delimiter=',', skip_header=1, dtype=float)          out=np.genfromtxt(outfn, delimiter=',', skip_header=1, dtype=float)
209            # recompute nz since ripley might have adjusted number of elements
210            nz=len(out)/(nx*ny)
211          g_out=out[:,0].reshape(nz,ny,nx)          g_out=out[:,0].reshape(nz,ny,nx)
212          s_out=out[:,1].reshape(nz,ny,nx)          s_out=out[:,1].reshape(nz,ny,nx)
213            self.assertAlmostEqual(np.abs(
214                g_out[z_data, PAD_Y:PAD_Y+NP[1], PAD_X:PAD_X+NP[0]]-g_ref).max(),
215                0., msg="Difference in gravity data area")
216    
217          self.assertAlmostEqual(np.abs(          self.assertAlmostEqual(np.abs(
218              g_out[z_data,PAD_XY:PAD_XY+NP[1],PAD_XY:PAD_XY+NP[0]]-g_ref).max(),              s_out[z_data, PAD_Y:PAD_Y+NP[1], PAD_X:PAD_X+NP[0]]-s_ref).max(),
219              0., msg="Difference in data area")              0., msg="Difference in error data area")
220    
221          # overwrite data -> should only be padding value left          # overwrite data -> should only be padding value left
222          g_out[z_data, PAD_XY:PAD_XY+NP[0], PAD_XY:PAD_XY+NP[0]]=ERS_NULL          g_out[z_data, PAD_Y:PAD_Y+NP[1], PAD_X:PAD_X+NP[0]]=ERS_NULL
223          self.assertAlmostEqual(np.abs(g_out-ERS_NULL).max(), 0.,          self.assertAlmostEqual(np.abs(g_out-ERS_NULL).max(), 0.,
224                  msg="Wrong values in padding area")                  msg="Wrong values in padding area")
225    
226    class TestNetCdfData(unittest.TestCase):
227        def test_cdf_with_padding(self):
228            source = NetCdfData(DataSource.GRAVITY, NC_DATA, ALT)
229            domainbuilder=DomainBuilder()
230            domainbuilder.addSource(source)
231            domainbuilder.setVerticalExtents(depth=-VMIN, air_layer=VMAX, num_cells=NE_V)
232            domainbuilder.setElementPadding(PAD_X,PAD_Y)
233            dom=domainbuilder.getDomain()
234            g,s=domainbuilder.getGravitySurveys()[0]
235    
236            outfn=os.path.join(WORKDIR, '_ncdata.csv')
237            saveDataCSV(outfn, g=g, s=s)
238    
239            X0,NP,DX=source.getDataExtents()
240            DV=(VMAX-VMIN)/NE_V
241    
242            # check metadata
243            self.assertEqual(NP, NC_SIZE, msg="Wrong number of data points")
244            # this only works if gdal is available
245            try:
246                import osgeo.osr
247                for i in range(len(NC_ORIGIN)):
248                    self.assertAlmostEqual(X0[i], NC_ORIGIN[i], msg="Data origin wrong")
249            except ImportError:
250                print("Skipping test of data origin since gdal is not installed.")
251    
252            # check data
253            nx=NP[0]+2*PAD_X
254            ny=NP[1]+2*PAD_Y
255            nz=NE_V
256            z_data=int(np.round((ALT-VMIN)/DV)-1)
257    
258            ref=np.genfromtxt(NC_REF, delimiter=',', dtype=float)
259            g_ref=ref[:,0].reshape((NP[1],NP[0]))
260            s_ref=ref[:,1].reshape((NP[1],NP[0]))
261    
262            out=np.genfromtxt(outfn, delimiter=',', skip_header=1, dtype=float)
263            # recompute nz since ripley might have adjusted number of elements
264            nz=len(out)/(nx*ny)
265            g_out=out[:,0].reshape(nz,ny,nx)
266            s_out=out[:,1].reshape(nz,ny,nx)
267    
268            self.assertAlmostEqual(np.abs(
269                g_out[z_data, PAD_Y:PAD_Y+NP[1], PAD_X:PAD_X+NP[0]]-g_ref).max(),
270                0., msg="Difference in gravity data area")
271    
272            self.assertAlmostEqual(np.abs(
273                s_out[z_data, PAD_Y:PAD_Y+NP[1], PAD_X:PAD_X+NP[0]]-s_ref).max(),
274                0., msg="Difference in error data area")
275    
276            # overwrite data -> should only be padding value left
277            g_out[z_data, PAD_Y:PAD_Y+NP[1], PAD_X:PAD_X+NP[0]]=NC_NULL
278            self.assertAlmostEqual(np.abs(g_out-NC_NULL).max(), 0.,
279                    msg="Wrong values in padding area")
280    
281  if __name__ == "__main__":  if __name__ == "__main__":
282      suite = unittest.TestSuite()      suite = unittest.TestSuite()
283      suite.addTest(unittest.makeSuite(TestERSDataSource))      if getMPISizeWorld()==1:
284            suite.addTest(unittest.makeSuite(TestNumpyData))
285            try:
286              import pyproj
287              haveproj=True
288            except ImportError:
289              haveproj=False
290            if haveproj:
291              suite.addTest(unittest.makeSuite(TestErMapperData))
292            if 'NetCdfData' in dir():
293                suite.addTest(unittest.makeSuite(TestNetCdfData))
294            else:
295                print("Skipping netCDF data source test since netCDF is not installed")
296        else:
297            print("Skipping data source tests since MPI size > 1")
298      s=unittest.TextTestRunner(verbosity=2).run(suite)      s=unittest.TextTestRunner(verbosity=2).run(suite)
299      if not s.wasSuccessful(): sys.exit(1)      if not s.wasSuccessful(): sys.exit(1)
300    

Legend:
Removed from v.3985  
changed lines
  Added in v.4362

  ViewVC Help
Powered by ViewVC 1.1.26