/[escript]/trunk/downunder/test/python/run_datasources.py
ViewVC logotype

Diff of /trunk/downunder/test/python/run_datasources.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 4014 by caltinay, Thu Oct 4 03:28:35 2012 UTC revision 4362 by caltinay, Tue Apr 16 04:37:13 2013 UTC
# Line 1  Line 1 
1    
2  ##############################################################################  ##############################################################################
3  #  #
4  # Copyright (c) 2003-2012 by University of Queensland  # Copyright (c) 2003-2013 by University of Queensland
5  # http://www.uq.edu.au  # http://www.uq.edu.au
6  #  #
7  # Primary Business: Queensland, Australia  # Primary Business: Queensland, Australia
# Line 13  Line 13 
13  #  #
14  ##############################################################################  ##############################################################################
15    
16  __copyright__="""Copyright (c) 2003-2012 by University of Queensland  __copyright__="""Copyright (c) 2003-2013 by University of Queensland
17  http://www.uq.edu.au  http://www.uq.edu.au
18  Primary Business: Queensland, Australia"""  Primary Business: Queensland, Australia"""
19  __license__="""Licensed under the Open Software License version 3.0  __license__="""Licensed under the Open Software License version 3.0
# Line 25  import numpy as np Line 25  import numpy as np
25  import os  import os
26  import sys  import sys
27  import unittest  import unittest
28  from esys.escript import inf,sup,saveDataCSV  from esys.escript import inf,sup,saveDataCSV,getMPISizeWorld
29  from esys.downunder.datasources import *  from esys.downunder.datasources import *
30    from esys.downunder.domainbuilder import DomainBuilder
31    
32  # this is mainly to avoid warning messages  # this is mainly to avoid warning messages
33  logger=logging.getLogger('inv')  logging.basicConfig(format='%(name)s: %(message)s', level=logging.INFO)
 logger.setLevel(logging.INFO)  
 handler=logging.StreamHandler()  
 handler.setLevel(logging.INFO)  
 logger.addHandler(handler)  
34    
35  try:  try:
36      TEST_DATA_ROOT=os.environ['DOWNUNDER_TEST_DATA_ROOT']      TEST_DATA_ROOT=os.environ['DOWNUNDER_TEST_DATA_ROOT']
37  except KeyError:  except KeyError:
38      TEST_DATA_ROOT='.'      TEST_DATA_ROOT='ref_data'
39    
40  try:  try:
41      WORKDIR=os.environ['DOWNUNDER_WORKDIR']      WORKDIR=os.environ['DOWNUNDER_WORKDIR']
# Line 50  ERS_DATA = os.path.join(TEST_DATA_ROOT, Line 47  ERS_DATA = os.path.join(TEST_DATA_ROOT,
47  ERS_REF = os.path.join(TEST_DATA_ROOT, 'ermapper_test.csv')  ERS_REF = os.path.join(TEST_DATA_ROOT, 'ermapper_test.csv')
48  ERS_NULL = -99999 * 1e-6  ERS_NULL = -99999 * 1e-6
49  ERS_SIZE = [20,15]  ERS_SIZE = [20,15]
50  ERS_ORIGIN = [309097.0, 6319002.0]  ERS_ORIGIN = [309241.0, 6318655.0]
51  NC_DATA = os.path.join(TEST_DATA_ROOT, 'netcdf_test.nc')  NC_DATA = os.path.join(TEST_DATA_ROOT, 'netcdf_test.nc')
52  NC_REF = os.path.join(TEST_DATA_ROOT, 'netcdf_test.csv')  NC_REF = os.path.join(TEST_DATA_ROOT, 'netcdf_test.csv')
53  NC_NULL = 0.  NC_NULL = 0.
54  NC_SIZE = [20,15]  NC_SIZE = [20,15]
55  NC_ORIGIN = [309097.0, 6319002.0]  NC_ORIGIN = [403320.91466610413, 6414860.942530109]
56    NUMPY_NULL = -123.4
57  VMIN=-10000.  VMIN=-10000.
58  VMAX=10000  VMAX=10000.
59  NE_V=15  NE_V=15
60  ALT=0.  ALT=0.
61  PAD_X=7  PAD_X=3
62  PAD_Y=9  PAD_Y=2
63    
64  class TestERSDataSource(unittest.TestCase):  class TestNumpyData(unittest.TestCase):
65        def test_numpy_argument_check(self):
66            # invalid data type
67            self.assertRaises(ValueError, NumpyData, '_mydatatype_', [1,2])
68            # invalid shape of data
69            self.assertRaises(ValueError, NumpyData, DataSource.GRAVITY, 42)
70            # invalid shape of data
71            self.assertRaises(ValueError, NumpyData, DataSource.GRAVITY, np.zeros((2,2,2,2)))
72            # invalid shape of error
73            self.assertRaises(ValueError, NumpyData, DataSource.GRAVITY, [1,2], [1,2,3])
74            # invalid shape of length
75            self.assertRaises(ValueError, NumpyData, DataSource.GRAVITY, [1,2], [1,2], [2,3,2])
76    
77        def test_numpy_data_1d(self):
78            DIM=1
79            testdata = np.arange(20)
80            error = 1.*np.ones(testdata.shape)
81            source = NumpyData(DataSource.GRAVITY, testdata, null_value=NUMPY_NULL)
82            X0,NP,DX=source.getDataExtents()
83            for i in range(DIM):
84                self.assertAlmostEqual(X0[i], 0., msg="Data origin wrong")
85                self.assertEqual(NP[i], testdata.shape[DIM-i-1], msg="Wrong number of data points")
86                self.assertAlmostEqual(DX[i], 1000./testdata.shape[DIM-i-1], msg="Wrong cell size")
87    
88            domainbuilder=DomainBuilder(dim=2)
89            domainbuilder.addSource(source)
90            domainbuilder.setVerticalExtents(depth=-VMIN, air_layer=VMAX, num_cells=NE_V)
91            domainbuilder.setElementPadding(PAD_X)
92            dom=domainbuilder.getDomain()
93            g,s=domainbuilder.getGravitySurveys()[0]
94    
95            outfn=os.path.join(WORKDIR, '_npdata1d.csv')
96            saveDataCSV(outfn, g=g, s=s)
97    
98            DV=(VMAX-VMIN)/NE_V
99    
100            # check data
101            nx=NP[0]+2*PAD_X
102            nz=NE_V
103            z_data=int(np.round((ALT-VMIN)/DV)-1)
104    
105            out=np.genfromtxt(outfn, delimiter=',', skip_header=1, dtype=float)
106            # recompute nz since ripley might have adjusted number of elements
107            nz=len(out)/nx
108            g_out=out[:,0].reshape(nz,nx)
109            s_out=out[:,1].reshape(nz,nx)
110            self.assertAlmostEqual(np.abs(
111                g_out[z_data, PAD_X:PAD_X+NP[0]]-testdata).max(),
112                0., msg="Difference in gravity data area")
113    
114            self.assertAlmostEqual(np.abs(
115                s_out[z_data, PAD_X:PAD_X+NP[0]]-error).max(),
116                0., msg="Difference in error data area")
117    
118            # overwrite data -> should only be padding value left
119            g_out[z_data, PAD_X:PAD_X+NP[0]]=NUMPY_NULL
120            self.assertAlmostEqual(np.abs(g_out-NUMPY_NULL).max(), 0.,
121                    msg="Wrong values in padding area")
122    
123        def test_numpy_data_2d(self):
124            DIM=2
125            testdata = np.arange(20*21).reshape(20,21)
126            error = 1.*np.ones(testdata.shape)
127            source = NumpyData(DataSource.GRAVITY, testdata, null_value=NUMPY_NULL)
128            X0,NP,DX=source.getDataExtents()
129            for i in range(DIM):
130                self.assertAlmostEqual(X0[i], 0., msg="Data origin wrong")
131                self.assertEqual(NP[i], testdata.shape[DIM-i-1], msg="Wrong number of data points")
132                self.assertAlmostEqual(DX[i], 1000./testdata.shape[DIM-i-1], msg="Wrong cell size")
133    
134            domainbuilder=DomainBuilder(dim=3)
135            domainbuilder.addSource(source)
136            domainbuilder.setVerticalExtents(depth=-VMIN, air_layer=VMAX, num_cells=NE_V)
137            domainbuilder.setElementPadding(PAD_X, PAD_Y)
138            dom=domainbuilder.getDomain()
139            g,s=domainbuilder.getGravitySurveys()[0]
140    
141            outfn=os.path.join(WORKDIR, '_npdata2d.csv')
142            saveDataCSV(outfn, g=g, s=s)
143    
144            DV=(VMAX-VMIN)/NE_V
145    
146            # check data
147            nx=NP[0]+2*PAD_X
148            ny=NP[1]+2*PAD_Y
149            nz=NE_V
150            z_data=int(np.round((ALT-VMIN)/DV)-1)
151    
152            out=np.genfromtxt(outfn, delimiter=',', skip_header=1, dtype=float)
153            # recompute nz since ripley might have adjusted number of elements
154            nz=len(out)/(nx*ny)
155            g_out=out[:,0].reshape(nz,ny,nx)
156            s_out=out[:,1].reshape(nz,ny,nx)
157            self.assertAlmostEqual(np.abs(
158                g_out[z_data, PAD_Y:PAD_Y+NP[1], PAD_X:PAD_X+NP[0]]-testdata).max(),
159                0., msg="Difference in gravity data area")
160    
161            self.assertAlmostEqual(np.abs(
162                s_out[z_data, PAD_Y:PAD_Y+NP[1], PAD_X:PAD_X+NP[0]]-error).max(),
163                0., msg="Difference in error data area")
164    
165            # overwrite data -> should only be padding value left
166            g_out[z_data, PAD_Y:PAD_Y+NP[1], PAD_X:PAD_X+NP[0]]=NUMPY_NULL
167            self.assertAlmostEqual(np.abs(g_out-NUMPY_NULL).max(), 0.,
168                    msg="Wrong values in padding area")
169    
170    
171    class TestErMapperData(unittest.TestCase):
172      def test_ers_with_padding(self):      def test_ers_with_padding(self):
173          source = ERSDataSource(headerfile=ERS_DATA, vertical_extents=(VMIN,VMAX,NE_V), alt_of_data=ALT)          source = ErMapperData(DataSource.GRAVITY, headerfile=ERS_DATA,
174          source.setPadding(PAD_X,PAD_Y)                                altitude=ALT)
175          dom=source.getDomain()          domainbuilder=DomainBuilder()
176          g,s=source.getGravityAndStdDev()          domainbuilder.addSource(source)
177            domainbuilder.setVerticalExtents(depth=-VMIN, air_layer=VMAX, num_cells=NE_V)
178            domainbuilder.setElementPadding(PAD_X,PAD_Y)
179            dom=domainbuilder.getDomain()
180            g,s=domainbuilder.getGravitySurveys()[0]
181    
182          outfn=os.path.join(WORKDIR, '_ersdata.csv')          outfn=os.path.join(WORKDIR, '_ersdata.csv')
183          saveDataCSV(outfn, g=g[2], s=s)          saveDataCSV(outfn, g=g, s=s)
184    
185          X0,NP,DX=source.getDataExtents()          X0,NP,DX=source.getDataExtents()
186          V0,NV,DV=source.getVerticalExtents()          DV=(VMAX-VMIN)/NE_V
187    
188          # check metadata          # check metadata
189          self.assertEqual(NP, ERS_SIZE, msg="Wrong number of data points")          self.assertEqual(NP, ERS_SIZE, msg="Wrong number of data points")
190          # this test only works if gdal is available          # this test only works if gdal is available
191          try:          try:
192              import osgeo.osr              import osgeo.osr
193              for i in xrange(len(ERS_ORIGIN)):              for i in range(len(ERS_ORIGIN)):
194                  self.assertAlmostEqual(X0[i], ERS_ORIGIN[i], msg="Data origin wrong")                  self.assertAlmostEqual(X0[i], ERS_ORIGIN[i], msg="Data origin wrong")
195          except ImportError:          except ImportError:
196              print("Skipping test of data origin since gdal is not installed.")              print("Skipping test of data origin since gdal is not installed.")
# Line 90  class TestERSDataSource(unittest.TestCas Line 199  class TestERSDataSource(unittest.TestCas
199          nx=NP[0]+2*PAD_X          nx=NP[0]+2*PAD_X
200          ny=NP[1]+2*PAD_Y          ny=NP[1]+2*PAD_Y
201          nz=NE_V          nz=NE_V
202          z_data=int(np.round((ALT-V0)/DV)-1)          z_data=int(np.round((ALT-VMIN)/DV)-1)
203    
204          ref=np.genfromtxt(ERS_REF, delimiter=',', dtype=float)          ref=np.genfromtxt(ERS_REF, delimiter=',', dtype=float)
205          g_ref=ref[:,0].reshape((NP[1],NP[0]))          g_ref=ref[:,0].reshape((NP[1],NP[0]))
# Line 114  class TestERSDataSource(unittest.TestCas Line 223  class TestERSDataSource(unittest.TestCas
223          self.assertAlmostEqual(np.abs(g_out-ERS_NULL).max(), 0.,          self.assertAlmostEqual(np.abs(g_out-ERS_NULL).max(), 0.,
224                  msg="Wrong values in padding area")                  msg="Wrong values in padding area")
225    
226  class TestNetCDFDataSource(unittest.TestCase):  class TestNetCdfData(unittest.TestCase):
227      def test_cdf_with_padding(self):      def test_cdf_with_padding(self):
228          source = NetCDFDataSource(gravfile=NC_DATA, vertical_extents=(VMIN,VMAX,NE_V), alt_of_data=ALT)          source = NetCdfData(DataSource.GRAVITY, NC_DATA, ALT)
229          source.setPadding(PAD_X,PAD_Y)          domainbuilder=DomainBuilder()
230          dom=source.getDomain()          domainbuilder.addSource(source)
231          g,s=source.getGravityAndStdDev()          domainbuilder.setVerticalExtents(depth=-VMIN, air_layer=VMAX, num_cells=NE_V)
232            domainbuilder.setElementPadding(PAD_X,PAD_Y)
233            dom=domainbuilder.getDomain()
234            g,s=domainbuilder.getGravitySurveys()[0]
235    
236          outfn=os.path.join(WORKDIR, '_ncdata.csv')          outfn=os.path.join(WORKDIR, '_ncdata.csv')
237          saveDataCSV(outfn, g=g[2], s=s)          saveDataCSV(outfn, g=g, s=s)
238    
239          X0,NP,DX=source.getDataExtents()          X0,NP,DX=source.getDataExtents()
240          V0,NV,DV=source.getVerticalExtents()          DV=(VMAX-VMIN)/NE_V
241    
242          # check metadata          # check metadata
243          self.assertEqual(NP, NC_SIZE, msg="Wrong number of data points")          self.assertEqual(NP, NC_SIZE, msg="Wrong number of data points")
244          # this only works if gdal is available          # this only works if gdal is available
245          #self.assertAlmostEqual(X0, NC_ORIGIN, msg="Data origin wrong")          try:
246                import osgeo.osr
247                for i in range(len(NC_ORIGIN)):
248                    self.assertAlmostEqual(X0[i], NC_ORIGIN[i], msg="Data origin wrong")
249            except ImportError:
250                print("Skipping test of data origin since gdal is not installed.")
251    
252          # check data          # check data
253          nx=NP[0]+2*PAD_X          nx=NP[0]+2*PAD_X
254          ny=NP[1]+2*PAD_Y          ny=NP[1]+2*PAD_Y
255          nz=NE_V          nz=NE_V
256          z_data=int(np.round((ALT-V0)/DV)-1)          z_data=int(np.round((ALT-VMIN)/DV)-1)
257    
258          ref=np.genfromtxt(NC_REF, delimiter=',', dtype=float)          ref=np.genfromtxt(NC_REF, delimiter=',', dtype=float)
259          g_ref=ref[:,0].reshape((NP[1],NP[0]))          g_ref=ref[:,0].reshape((NP[1],NP[0]))
# Line 163  class TestNetCDFDataSource(unittest.Test Line 280  class TestNetCDFDataSource(unittest.Test
280    
281  if __name__ == "__main__":  if __name__ == "__main__":
282      suite = unittest.TestSuite()      suite = unittest.TestSuite()
283      suite.addTest(unittest.makeSuite(TestERSDataSource))      if getMPISizeWorld()==1:
284      if 'NetCDFDataSource' in dir():          suite.addTest(unittest.makeSuite(TestNumpyData))
285          suite.addTest(unittest.makeSuite(TestNetCDFDataSource))          try:
286              import pyproj
287              haveproj=True
288            except ImportError:
289              haveproj=False
290            if haveproj:
291              suite.addTest(unittest.makeSuite(TestErMapperData))
292            if 'NetCdfData' in dir():
293                suite.addTest(unittest.makeSuite(TestNetCdfData))
294            else:
295                print("Skipping netCDF data source test since netCDF is not installed")
296      else:      else:
297          print("Skipping netCDF data source test since netCDF is not installed")          print("Skipping data source tests since MPI size > 1")
298      s=unittest.TextTestRunner(verbosity=2).run(suite)      s=unittest.TextTestRunner(verbosity=2).run(suite)
299      if not s.wasSuccessful(): sys.exit(1)      if not s.wasSuccessful(): sys.exit(1)
300    

Legend:
Removed from v.4014  
changed lines
  Added in v.4362

  ViewVC Help
Powered by ViewVC 1.1.26