/[escript]/trunk/downunder/test/python/run_datasources.py
ViewVC logotype

Diff of /trunk/downunder/test/python/run_datasources.py

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 4016 by caltinay, Tue Oct 9 03:50:27 2012 UTC revision 4362 by caltinay, Tue Apr 16 04:37:13 2013 UTC
# Line 1  Line 1 
1    
2  ##############################################################################  ##############################################################################
3  #  #
4  # Copyright (c) 2003-2012 by University of Queensland  # Copyright (c) 2003-2013 by University of Queensland
5  # http://www.uq.edu.au  # http://www.uq.edu.au
6  #  #
7  # Primary Business: Queensland, Australia  # Primary Business: Queensland, Australia
# Line 13  Line 13 
13  #  #
14  ##############################################################################  ##############################################################################
15    
16  __copyright__="""Copyright (c) 2003-2012 by University of Queensland  __copyright__="""Copyright (c) 2003-2013 by University of Queensland
17  http://www.uq.edu.au  http://www.uq.edu.au
18  Primary Business: Queensland, Australia"""  Primary Business: Queensland, Australia"""
19  __license__="""Licensed under the Open Software License version 3.0  __license__="""Licensed under the Open Software License version 3.0
# Line 27  import sys Line 27  import sys
27  import unittest  import unittest
28  from esys.escript import inf,sup,saveDataCSV,getMPISizeWorld  from esys.escript import inf,sup,saveDataCSV,getMPISizeWorld
29  from esys.downunder.datasources import *  from esys.downunder.datasources import *
30    from esys.downunder.domainbuilder import DomainBuilder
31    
32  # this is mainly to avoid warning messages  # this is mainly to avoid warning messages
33  logger=logging.getLogger('inv')  logging.basicConfig(format='%(name)s: %(message)s', level=logging.INFO)
 logger.setLevel(logging.INFO)  
 handler=logging.StreamHandler()  
 handler.setLevel(logging.INFO)  
 logger.addHandler(handler)  
34    
35  try:  try:
36      TEST_DATA_ROOT=os.environ['DOWNUNDER_TEST_DATA_ROOT']      TEST_DATA_ROOT=os.environ['DOWNUNDER_TEST_DATA_ROOT']
# Line 50  ERS_DATA = os.path.join(TEST_DATA_ROOT, Line 47  ERS_DATA = os.path.join(TEST_DATA_ROOT,
47  ERS_REF = os.path.join(TEST_DATA_ROOT, 'ermapper_test.csv')  ERS_REF = os.path.join(TEST_DATA_ROOT, 'ermapper_test.csv')
48  ERS_NULL = -99999 * 1e-6  ERS_NULL = -99999 * 1e-6
49  ERS_SIZE = [20,15]  ERS_SIZE = [20,15]
50  ERS_ORIGIN = [309097.0, 6319002.0]  ERS_ORIGIN = [309241.0, 6318655.0]
51  NC_DATA = os.path.join(TEST_DATA_ROOT, 'netcdf_test.nc')  NC_DATA = os.path.join(TEST_DATA_ROOT, 'netcdf_test.nc')
52  NC_REF = os.path.join(TEST_DATA_ROOT, 'netcdf_test.csv')  NC_REF = os.path.join(TEST_DATA_ROOT, 'netcdf_test.csv')
53  NC_NULL = 0.  NC_NULL = 0.
54  NC_SIZE = [20,15]  NC_SIZE = [20,15]
55  NC_ORIGIN = [403320.91466610413, 6414860.942530109]  NC_ORIGIN = [403320.91466610413, 6414860.942530109]
56    NUMPY_NULL = -123.4
57  VMIN=-10000.  VMIN=-10000.
58  VMAX=10000  VMAX=10000.
59  NE_V=15  NE_V=15
60  ALT=0.  ALT=0.
61  PAD_X=3  PAD_X=3
62  PAD_Y=2  PAD_Y=2
63    
64  class TestERSDataSource(unittest.TestCase):  class TestNumpyData(unittest.TestCase):
65        def test_numpy_argument_check(self):
66            # invalid data type
67            self.assertRaises(ValueError, NumpyData, '_mydatatype_', [1,2])
68            # invalid shape of data
69            self.assertRaises(ValueError, NumpyData, DataSource.GRAVITY, 42)
70            # invalid shape of data
71            self.assertRaises(ValueError, NumpyData, DataSource.GRAVITY, np.zeros((2,2,2,2)))
72            # invalid shape of error
73            self.assertRaises(ValueError, NumpyData, DataSource.GRAVITY, [1,2], [1,2,3])
74            # invalid shape of length
75            self.assertRaises(ValueError, NumpyData, DataSource.GRAVITY, [1,2], [1,2], [2,3,2])
76    
77        def test_numpy_data_1d(self):
78            DIM=1
79            testdata = np.arange(20)
80            error = 1.*np.ones(testdata.shape)
81            source = NumpyData(DataSource.GRAVITY, testdata, null_value=NUMPY_NULL)
82            X0,NP,DX=source.getDataExtents()
83            for i in range(DIM):
84                self.assertAlmostEqual(X0[i], 0., msg="Data origin wrong")
85                self.assertEqual(NP[i], testdata.shape[DIM-i-1], msg="Wrong number of data points")
86                self.assertAlmostEqual(DX[i], 1000./testdata.shape[DIM-i-1], msg="Wrong cell size")
87    
88            domainbuilder=DomainBuilder(dim=2)
89            domainbuilder.addSource(source)
90            domainbuilder.setVerticalExtents(depth=-VMIN, air_layer=VMAX, num_cells=NE_V)
91            domainbuilder.setElementPadding(PAD_X)
92            dom=domainbuilder.getDomain()
93            g,s=domainbuilder.getGravitySurveys()[0]
94    
95            outfn=os.path.join(WORKDIR, '_npdata1d.csv')
96            saveDataCSV(outfn, g=g, s=s)
97    
98            DV=(VMAX-VMIN)/NE_V
99    
100            # check data
101            nx=NP[0]+2*PAD_X
102            nz=NE_V
103            z_data=int(np.round((ALT-VMIN)/DV)-1)
104    
105            out=np.genfromtxt(outfn, delimiter=',', skip_header=1, dtype=float)
106            # recompute nz since ripley might have adjusted number of elements
107            nz=len(out)/nx
108            g_out=out[:,0].reshape(nz,nx)
109            s_out=out[:,1].reshape(nz,nx)
110            self.assertAlmostEqual(np.abs(
111                g_out[z_data, PAD_X:PAD_X+NP[0]]-testdata).max(),
112                0., msg="Difference in gravity data area")
113    
114            self.assertAlmostEqual(np.abs(
115                s_out[z_data, PAD_X:PAD_X+NP[0]]-error).max(),
116                0., msg="Difference in error data area")
117    
118            # overwrite data -> should only be padding value left
119            g_out[z_data, PAD_X:PAD_X+NP[0]]=NUMPY_NULL
120            self.assertAlmostEqual(np.abs(g_out-NUMPY_NULL).max(), 0.,
121                    msg="Wrong values in padding area")
122    
123        def test_numpy_data_2d(self):
124            DIM=2
125            testdata = np.arange(20*21).reshape(20,21)
126            error = 1.*np.ones(testdata.shape)
127            source = NumpyData(DataSource.GRAVITY, testdata, null_value=NUMPY_NULL)
128            X0,NP,DX=source.getDataExtents()
129            for i in range(DIM):
130                self.assertAlmostEqual(X0[i], 0., msg="Data origin wrong")
131                self.assertEqual(NP[i], testdata.shape[DIM-i-1], msg="Wrong number of data points")
132                self.assertAlmostEqual(DX[i], 1000./testdata.shape[DIM-i-1], msg="Wrong cell size")
133    
134            domainbuilder=DomainBuilder(dim=3)
135            domainbuilder.addSource(source)
136            domainbuilder.setVerticalExtents(depth=-VMIN, air_layer=VMAX, num_cells=NE_V)
137            domainbuilder.setElementPadding(PAD_X, PAD_Y)
138            dom=domainbuilder.getDomain()
139            g,s=domainbuilder.getGravitySurveys()[0]
140    
141            outfn=os.path.join(WORKDIR, '_npdata2d.csv')
142            saveDataCSV(outfn, g=g, s=s)
143    
144            DV=(VMAX-VMIN)/NE_V
145    
146            # check data
147            nx=NP[0]+2*PAD_X
148            ny=NP[1]+2*PAD_Y
149            nz=NE_V
150            z_data=int(np.round((ALT-VMIN)/DV)-1)
151    
152            out=np.genfromtxt(outfn, delimiter=',', skip_header=1, dtype=float)
153            # recompute nz since ripley might have adjusted number of elements
154            nz=len(out)/(nx*ny)
155            g_out=out[:,0].reshape(nz,ny,nx)
156            s_out=out[:,1].reshape(nz,ny,nx)
157            self.assertAlmostEqual(np.abs(
158                g_out[z_data, PAD_Y:PAD_Y+NP[1], PAD_X:PAD_X+NP[0]]-testdata).max(),
159                0., msg="Difference in gravity data area")
160    
161            self.assertAlmostEqual(np.abs(
162                s_out[z_data, PAD_Y:PAD_Y+NP[1], PAD_X:PAD_X+NP[0]]-error).max(),
163                0., msg="Difference in error data area")
164    
165            # overwrite data -> should only be padding value left
166            g_out[z_data, PAD_Y:PAD_Y+NP[1], PAD_X:PAD_X+NP[0]]=NUMPY_NULL
167            self.assertAlmostEqual(np.abs(g_out-NUMPY_NULL).max(), 0.,
168                    msg="Wrong values in padding area")
169    
170    
171    class TestErMapperData(unittest.TestCase):
172      def test_ers_with_padding(self):      def test_ers_with_padding(self):
173          source = ERSDataSource(headerfile=ERS_DATA, vertical_extents=(VMIN,VMAX,NE_V), alt_of_data=ALT)          source = ErMapperData(DataSource.GRAVITY, headerfile=ERS_DATA,
174          source.setPadding(PAD_X,PAD_Y)                                altitude=ALT)
175          dom=source.getDomain()          domainbuilder=DomainBuilder()
176          g,s=source.getGravityAndStdDev()          domainbuilder.addSource(source)
177            domainbuilder.setVerticalExtents(depth=-VMIN, air_layer=VMAX, num_cells=NE_V)
178            domainbuilder.setElementPadding(PAD_X,PAD_Y)
179            dom=domainbuilder.getDomain()
180            g,s=domainbuilder.getGravitySurveys()[0]
181    
182          outfn=os.path.join(WORKDIR, '_ersdata.csv')          outfn=os.path.join(WORKDIR, '_ersdata.csv')
183          saveDataCSV(outfn, g=g[2], s=s)          saveDataCSV(outfn, g=g, s=s)
184    
185          X0,NP,DX=source.getDataExtents()          X0,NP,DX=source.getDataExtents()
186          V0,NV,DV=source.getVerticalExtents()          DV=(VMAX-VMIN)/NE_V
187    
188          # check metadata          # check metadata
189          self.assertEqual(NP, ERS_SIZE, msg="Wrong number of data points")          self.assertEqual(NP, ERS_SIZE, msg="Wrong number of data points")
190          # this test only works if gdal is available          # this test only works if gdal is available
191          try:          try:
192              import osgeo.osr              import osgeo.osr
193              for i in xrange(len(ERS_ORIGIN)):              for i in range(len(ERS_ORIGIN)):
194                  self.assertAlmostEqual(X0[i], ERS_ORIGIN[i], msg="Data origin wrong")                  self.assertAlmostEqual(X0[i], ERS_ORIGIN[i], msg="Data origin wrong")
195          except ImportError:          except ImportError:
196              print("Skipping test of data origin since gdal is not installed.")              print("Skipping test of data origin since gdal is not installed.")
# Line 90  class TestERSDataSource(unittest.TestCas Line 199  class TestERSDataSource(unittest.TestCas
199          nx=NP[0]+2*PAD_X          nx=NP[0]+2*PAD_X
200          ny=NP[1]+2*PAD_Y          ny=NP[1]+2*PAD_Y
201          nz=NE_V          nz=NE_V
202          z_data=int(np.round((ALT-V0)/DV)-1)          z_data=int(np.round((ALT-VMIN)/DV)-1)
203    
204          ref=np.genfromtxt(ERS_REF, delimiter=',', dtype=float)          ref=np.genfromtxt(ERS_REF, delimiter=',', dtype=float)
205          g_ref=ref[:,0].reshape((NP[1],NP[0]))          g_ref=ref[:,0].reshape((NP[1],NP[0]))
# Line 114  class TestERSDataSource(unittest.TestCas Line 223  class TestERSDataSource(unittest.TestCas
223          self.assertAlmostEqual(np.abs(g_out-ERS_NULL).max(), 0.,          self.assertAlmostEqual(np.abs(g_out-ERS_NULL).max(), 0.,
224                  msg="Wrong values in padding area")                  msg="Wrong values in padding area")
225    
226  class TestNetCDFDataSource(unittest.TestCase):  class TestNetCdfData(unittest.TestCase):
227      def test_cdf_with_padding(self):      def test_cdf_with_padding(self):
228          source = NetCDFDataSource(gravfile=NC_DATA, vertical_extents=(VMIN,VMAX,NE_V), alt_of_data=ALT)          source = NetCdfData(DataSource.GRAVITY, NC_DATA, ALT)
229          source.setPadding(PAD_X,PAD_Y)          domainbuilder=DomainBuilder()
230          dom=source.getDomain()          domainbuilder.addSource(source)
231          g,s=source.getGravityAndStdDev()          domainbuilder.setVerticalExtents(depth=-VMIN, air_layer=VMAX, num_cells=NE_V)
232            domainbuilder.setElementPadding(PAD_X,PAD_Y)
233            dom=domainbuilder.getDomain()
234            g,s=domainbuilder.getGravitySurveys()[0]
235    
236          outfn=os.path.join(WORKDIR, '_ncdata.csv')          outfn=os.path.join(WORKDIR, '_ncdata.csv')
237          saveDataCSV(outfn, g=g[2], s=s)          saveDataCSV(outfn, g=g, s=s)
238    
239          X0,NP,DX=source.getDataExtents()          X0,NP,DX=source.getDataExtents()
240          V0,NV,DV=source.getVerticalExtents()          DV=(VMAX-VMIN)/NE_V
241    
242          # check metadata          # check metadata
243          self.assertEqual(NP, NC_SIZE, msg="Wrong number of data points")          self.assertEqual(NP, NC_SIZE, msg="Wrong number of data points")
244          # this only works if gdal is available          # this only works if gdal is available
245          try:          try:
246              import osgeo.osr              import osgeo.osr
247              for i in xrange(len(NC_ORIGIN)):              for i in range(len(NC_ORIGIN)):
248                  self.assertAlmostEqual(X0[i], NC_ORIGIN[i], msg="Data origin wrong")                  self.assertAlmostEqual(X0[i], NC_ORIGIN[i], msg="Data origin wrong")
249          except ImportError:          except ImportError:
250              print("Skipping test of data origin since gdal is not installed.")              print("Skipping test of data origin since gdal is not installed.")
# Line 141  class TestNetCDFDataSource(unittest.Test Line 253  class TestNetCDFDataSource(unittest.Test
253          nx=NP[0]+2*PAD_X          nx=NP[0]+2*PAD_X
254          ny=NP[1]+2*PAD_Y          ny=NP[1]+2*PAD_Y
255          nz=NE_V          nz=NE_V
256          z_data=int(np.round((ALT-V0)/DV)-1)          z_data=int(np.round((ALT-VMIN)/DV)-1)
257    
258          ref=np.genfromtxt(NC_REF, delimiter=',', dtype=float)          ref=np.genfromtxt(NC_REF, delimiter=',', dtype=float)
259          g_ref=ref[:,0].reshape((NP[1],NP[0]))          g_ref=ref[:,0].reshape((NP[1],NP[0]))
# Line 169  class TestNetCDFDataSource(unittest.Test Line 281  class TestNetCDFDataSource(unittest.Test
281  if __name__ == "__main__":  if __name__ == "__main__":
282      suite = unittest.TestSuite()      suite = unittest.TestSuite()
283      if getMPISizeWorld()==1:      if getMPISizeWorld()==1:
284          suite.addTest(unittest.makeSuite(TestERSDataSource))          suite.addTest(unittest.makeSuite(TestNumpyData))
285          if 'NetCDFDataSource' in dir():          try:
286              suite.addTest(unittest.makeSuite(TestNetCDFDataSource))            import pyproj
287              haveproj=True
288            except ImportError:
289              haveproj=False
290            if haveproj:
291              suite.addTest(unittest.makeSuite(TestErMapperData))
292            if 'NetCdfData' in dir():
293                suite.addTest(unittest.makeSuite(TestNetCdfData))
294          else:          else:
295              print("Skipping netCDF data source test since netCDF is not installed")              print("Skipping netCDF data source test since netCDF is not installed")
296      else:      else:

Legend:
Removed from v.4016  
changed lines
  Added in v.4362

  ViewVC Help
Powered by ViewVC 1.1.26