/[escript]/branches/subworld2/escriptcore/src/DataFactory.cpp
ViewVC logotype

Contents of /branches/subworld2/escriptcore/src/DataFactory.cpp

Parent Directory | Revision Log


Revision 5504 - (show annotations)
Wed Mar 4 22:58:13 2015 UTC (4 years, 1 month ago) by jfenwick
File size: 14805 byte(s)
Again with a more up to date copy


1
2 /*****************************************************************************
3 *
4 * Copyright (c) 2003-2015 by University of Queensland
5 * http://www.uq.edu.au
6 *
7 * Primary Business: Queensland, Australia
8 * Licensed under the Open Software License version 3.0
9 * http://www.opensource.org/licenses/osl-3.0.php
10 *
11 * Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12 * Development 2012-2013 by School of Earth Sciences
13 * Development from 2014 by Centre for Geoscience Computing (GeoComp)
14 *
15 *****************************************************************************/
16
17 #define ESNEEDPYTHON
18 #include "esysUtils/first.h"
19
20
#include "DataFactory.h"
#include "esysUtils/Esys_MPI.h"

#include <boost/python/errors.hpp>
#include <boost/python/extract.hpp>
#include <boost/scoped_array.hpp>

#include <cstring>
#include <exception>
#include <iostream>
#include <sstream>
#ifdef USE_NETCDF
#include <netcdfcpp.h>
#endif
32
33 using namespace boost::python;
34
35 namespace escript {
36
37 Data
38 Scalar(double value,
39 const FunctionSpace& what,
40 bool expanded)
41 {
42 //
43 // an empty shape is a scalar
44 DataTypes::ShapeType shape;
45 return Data(value,shape,what,expanded);
46 }
47
48 Data
49 Vector(double value,
50 const FunctionSpace& what,
51 bool expanded)
52 {
53 DataTypes::ShapeType shape(1,what.getDomain()->getDim());
54 return Data(value,shape,what,expanded);
55 }
56
57 Data
58 VectorFromObj(boost::python::object o,
59 const FunctionSpace& what,
60 bool expanded)
61 {
62 double v;
63 try // first try to get a double and route it to the other method
64 {
65 v=boost::python::extract<double>(o);
66 return Vector(v,what,expanded);
67 }
68 catch(...)
69 {
70 PyErr_Clear();
71 }
72 DataTypes::ShapeType shape(1,what.getDomain()->getDim());
73 Data d(o,what,expanded);
74 if (d.getDataPointShape()!=shape)
75 {
76 throw DataException("VectorFromObj: Shape of vector passed to function does not match the dimension of the domain. ");
77 }
78 return d;
79 }
80
81 Data
82 Tensor(double value,
83 const FunctionSpace& what,
84 bool expanded)
85 {
86 DataTypes::ShapeType shape(2,what.getDomain()->getDim());
87 return Data(value,shape,what,expanded);
88 }
89
90
91 // We need to take some care here because this signature trumps the other one from boost's point of view
92 Data
93 TensorFromObj(boost::python::object o,
94 const FunctionSpace& what,
95 bool expanded)
96 {
97 double v;
98 try // first try to get a double and route it to the other method
99 {
100 v=boost::python::extract<double>(o);
101 return Tensor(v,what,expanded);
102 }
103 catch(...)
104 {
105 PyErr_Clear();
106 }
107 DataTypes::ShapeType shape(2,what.getDomain()->getDim());
108 Data d(o,what,expanded);
109 if (d.getDataPointShape()!=shape)
110 {
111 throw DataException("TensorFromObj: Shape of tensor passed to function does not match the dimension of the domain. ");
112 }
113 return d;
114 }
115
116 Data
117 Tensor3(double value,
118 const FunctionSpace& what,
119 bool expanded)
120 {
121 DataTypes::ShapeType shape(3,what.getDomain()->getDim());
122 return Data(value,shape,what,expanded);
123 }
124
125 Data
126 Tensor3FromObj(boost::python::object o,
127 const FunctionSpace& what,
128 bool expanded)
129 {
130 double v;
131 try // first try to get a double and route it to the other method
132 {
133 v=boost::python::extract<double>(o);
134 return Tensor3(v,what,expanded);
135 }
136 catch(...)
137 {
138 PyErr_Clear();
139 }
140 DataTypes::ShapeType shape(3,what.getDomain()->getDim());
141 Data d(o,what,expanded);
142 if (d.getDataPointShape()!=shape)
143 {
144 throw DataException("Tensor3FromObj: Shape of tensor passed to function does not match the dimension of the domain. ");
145 }
146 return d;
147 }
148
149 Data
150 Tensor4(double value,
151 const FunctionSpace& what,
152 bool expanded)
153 {
154 DataTypes::ShapeType shape(4,what.getDomain()->getDim());
155 return Data(value,shape,what,expanded);
156 }
157
158 Data
159 Tensor4FromObj(boost::python::object o,
160 const FunctionSpace& what,
161 bool expanded)
162 {
163 double v;
164 try // first try to get a double and route it to the other method
165 {
166 v=boost::python::extract<double>(o);
167 return Tensor4(v,what,expanded);
168 }
169 catch(...)
170 {
171 PyErr_Clear();
172 }
173 DataTypes::ShapeType shape(4,what.getDomain()->getDim());
174 Data d(o,what,expanded);
175 if (d.getDataPointShape()!=shape)
176 {
177 throw DataException("VectorFromObj: Shape of tensor passed to function does not match the dimension of the domain. ");
178 }
179 return d;
180 }
181
182
183 Data
184 load(const std::string fileName,
185 const AbstractDomain& domain)
186 {
187 #ifdef USE_NETCDF
188 NcAtt *type_att, *rank_att, *function_space_type_att;
189 // netCDF error handler
190 NcError err(NcError::silent_nonfatal);
191 int mpi_iam=0, mpi_num=1;
192 // Create the file.
193 #ifdef ESYS_MPI
194 MPI_Comm_rank(MPI_COMM_WORLD, &mpi_iam);
195 MPI_Comm_size(MPI_COMM_WORLD, &mpi_num);
196 #endif
197 const std::string newFileName(esysUtils::appendRankToFileName(fileName,
198 mpi_num, mpi_iam));
199 NcFile dataFile(newFileName.c_str(), NcFile::ReadOnly);
200 if (!dataFile.is_valid())
201 throw DataException("Error - load:: opening of netCDF file for input failed.");
202 /* recover function space */
203 if (! (function_space_type_att=dataFile.get_att("function_space_type")) )
204 throw DataException("Error - load:: cannot recover function_space_type attribute from escript netCDF file.");
205 int function_space_type = function_space_type_att->as_int(0);
206 delete function_space_type_att;
207 /* test if function space id is valid and create function space instance */
208 if (! domain.isValidFunctionSpaceType(function_space_type) )
209 throw DataException("Error - load:: function space type code in netCDF file is invalid for given domain.");
210 FunctionSpace function_space=FunctionSpace(domain.getPtr(), function_space_type);
211 /* recover rank */
212 if (! (rank_att=dataFile.get_att("rank")) )
213 throw DataException("Error - load:: cannot recover rank attribute from escript netCDF file.");
214 int rank = rank_att->as_int(0);
215 delete rank_att;
216 if (rank<0 || rank>DataTypes::maxRank)
217 throw DataException("Error - load:: rank in escript netCDF file is greater than maximum rank.");
218 /* recover type attribute */
219 int type=-1;
220 if ((type_att=dataFile.get_att("type")) ) {
221 boost::scoped_array<char> type_str(type_att->as_string(0));
222 if (strncmp(type_str.get(), "constant", strlen("constant")) == 0 ) {
223 type = 0;
224 } else if (strncmp(type_str.get(), "tagged", strlen("tagged")) == 0 ) {
225 type = 1;
226 } else if (strncmp(type_str.get(), "expanded", strlen("expanded")) == 0 ) {
227 type = 2;
228 }
229 } else {
230 if (! (type_att=dataFile.get_att("type_id")) )
231 throw DataException("Error - load:: cannot recover type attribute from escript netCDF file.");
232 type=type_att->as_int(0);
233 }
234 delete type_att;
235
236 /* recover dimension */
237 int ndims=dataFile.num_dims();
238 int ntags =0 , num_samples =0 , num_data_points_per_sample =0, d=0, len_data_point=1;
239 NcDim *d_dim, *tags_dim, *num_samples_dim, *num_data_points_per_sample_dim;
240 /* recover shape */
241 DataTypes::ShapeType shape;
242 long dims[DataTypes::maxRank+2];
243 if (rank>0) {
244 if (! (d_dim=dataFile.get_dim("d0")) )
245 throw DataException("Error - load:: unable to recover d0 from netCDF file.");
246 d=d_dim->size();
247 shape.push_back(d);
248 dims[0]=d;
249 len_data_point*=d;
250 }
251 if (rank>1) {
252 if (! (d_dim=dataFile.get_dim("d1")) )
253 throw DataException("Error - load:: unable to recover d1 from netCDF file.");
254 d=d_dim->size();
255 shape.push_back(d);
256 dims[1]=d;
257 len_data_point*=d;
258 }
259 if (rank>2) {
260 if (! (d_dim=dataFile.get_dim("d2")) )
261 throw DataException("Error - load:: unable to recover d2 from netCDF file.");
262 d=d_dim->size();
263 shape.push_back(d);
264 dims[2]=d;
265 len_data_point*=d;
266 }
267 if (rank>3) {
268 if (! (d_dim=dataFile.get_dim("d3")) )
269 throw DataException("Error - load:: unable to recover d3 from netCDF file.");
270 d=d_dim->size();
271 shape.push_back(d);
272 dims[3]=d;
273 len_data_point*=d;
274 }
275 /* recover stuff */
276 Data out;
277 NcVar *var, *ids_var, *tags_var;
278 if (type == 0) {
279 /* constant data */
280 if ( ! ( (ndims == rank && rank >0) || ( ndims ==1 && rank == 0 ) ) )
281 throw DataException("Error - load:: illegal number of dimensions for constant data in netCDF file.");
282 if (rank == 0) {
283 if (! (d_dim=dataFile.get_dim("l")) )
284 throw DataException("Error - load:: unable to recover d0 for scalar constant data in netCDF file.");
285 int d0 = d_dim->size();
286 if (d0 != 1)
287 throw DataException("Error - load:: d0 is expected to be one for scalar constant data in netCDF file.");
288 dims[0]=1;
289 }
290 out=Data(0,shape,function_space);
291 if (!(var = dataFile.get_var("data")))
292 throw DataException("Error - load:: unable to find data in netCDF file.");
293 if (! var->get(&(out.getDataAtOffsetRW(out.getDataOffset(0,0))), dims) )
294 throw DataException("Error - load:: unable to recover data from netCDF file.");
295 } else if (type == 1) {
296 /* tagged data */
297 if ( ! (ndims == rank + 1) )
298 throw DataException("Error - load:: illegal number of dimensions for tagged data in netCDF file.");
299 if (! (tags_dim=dataFile.get_dim("num_tags")) )
300 throw DataException("Error - load:: unable to recover number of tags from netCDF file.");
301 ntags=tags_dim->size();
302 dims[rank]=ntags;
303 std::vector<int> tags(ntags);
304 if (! ( tags_var = dataFile.get_var("tags")) )
305 {
306 throw DataException("Error - load:: unable to find tags in netCDF file.");
307 }
308 if (! tags_var->get(&tags[0], ntags) )
309 {
310 throw DataException("Error - load:: unable to recover tags from netCDF file.");
311 }
312
313 // Current Version
314 /* DataVector data(len_data_point * ntags, 0., len_data_point * ntags);
315 if (!(var = dataFile.get_var("data")))
316 {
317 esysUtils::free(tags);
318 throw DataException("Error - load:: unable to find data in netCDF file.");
319 }
320 if (! var->get(&(data[0]), dims) )
321 {
322 esysUtils::free(tags);
323 throw DataException("Error - load:: unable to recover data from netCDF file.");
324 }
325 out=Data(DataArrayView(data,shape,0),function_space);
326 for (int t=1; t<ntags; ++t) {
327 out.setTaggedValueFromCPP(tags[t],shape, data, t*len_data_point);
328 // out.setTaggedValueFromCPP(tags[t],DataArrayView(data,shape,t*len_data_point));
329 }*/
330 // End current version
331
332 // New version
333
334 // A) create a DataTagged dt
335 // B) Read data from file
336 // C) copy default value into dt
337 // D) copy tagged values into dt
338 // E) create a new Data based on dt
339
340 NcVar* var1;
341 DataVector data1(len_data_point * ntags, 0., len_data_point * ntags);
342 if (!(var1 = dataFile.get_var("data")))
343 {
344 throw DataException("Error - load:: unable to find data in netCDF file.");
345 }
346 if (! var1->get(&(data1[0]), dims) )
347 {
348 throw DataException("Error - load:: unable to recover data from netCDF file.");
349 }
350 DataTagged* dt=new DataTagged(function_space, shape, &tags[0], data1);
351 out=Data(dt);
352 } else if (type == 2) {
353 /* expanded data */
354 if ( ! (ndims == rank + 2) )
355 throw DataException("Error - load:: illegal number of dimensions for expanded data in netCDF file.");
356 if ( ! (num_samples_dim = dataFile.get_dim("num_samples") ) )
357 throw DataException("Error - load:: unable to recover number of samples from netCDF file.");
358 num_samples = num_samples_dim->size();
359 if ( ! (num_data_points_per_sample_dim = dataFile.get_dim("num_data_points_per_sample") ) )
360 throw DataException("Error - load:: unable to recover number of data points per sample from netCDF file.");
361 num_data_points_per_sample=num_data_points_per_sample_dim->size();
362 // check shape:
363 if ( ! (num_samples == function_space.getNumSamples() && num_data_points_per_sample == function_space.getNumDataPointsPerSample()) )
364 throw DataException("Error - load:: data sample layout of file does not match data layout of function space.");
365 if (num_samples==0) {
366 out = Data(0,shape,function_space,true);
367 }
368 else {
369 // get ids
370 if (! ( ids_var = dataFile.get_var("id")) )
371 throw DataException("Error - load:: unable to find reference ids in netCDF file.");
372 const dim_t* ids_p=function_space.borrowSampleReferenceIDs();
373 std::vector<dim_t> ids_of_nc(num_samples);
374 if (! ids_var->get(&ids_of_nc[0], (long) num_samples) )
375 {
376 throw DataException("Error - load:: unable to recover ids from netCDF file.");
377 }
378 // check order:
379 int failed=-1, local_failed=-1, i;
380 #pragma omp parallel private(local_failed)
381 {
382 local_failed=-1;
383 #pragma omp for private(i) schedule(static)
384 for (i=0;i < num_samples; ++i) {
385 if (ids_of_nc[i]!=ids_p[i]) local_failed=i;
386 }
387 #pragma omp critical
388 if (local_failed>=0) failed = local_failed;
389 }
390 /* if (failed>=0)
391 {
392 throw DataException("Error - load:: data ordering in netCDF file does not match ordering of FunctionSpace.");
393 } */
394 // get the data:
395 dims[rank]=num_data_points_per_sample;
396 dims[rank+1]=num_samples;
397 out=Data(0,shape,function_space,true);
398 if (!(var = dataFile.get_var("data")))
399 {
400 throw DataException("Error - load:: unable to find data in netCDF file.");
401 }
402 if (! var->get(&(out.getDataAtOffsetRW(out.getDataOffset(0,0))), dims) )
403 {
404 throw DataException("Error - load:: unable to recover data from netCDF file.");
405 }
406 if (failed>=0) {
407 try {
408 std::cout << "Information - load: start reordering data from netCDF file " << fileName << std::endl;
409 out.borrowData()->reorderByReferenceIDs(&ids_of_nc[0]);
410 }
411 catch (std::exception&) {
412 throw DataException("Error - load:: unable to reorder data in netCDF file.");
413 }
414 }
415 }
416 } else {
417 throw DataException("Error - load:: unknown escript data type in netCDF file.");
418 }
419 return out;
420 #else
421 throw DataException("Error - load:: is not compiled with netCDF. Please contact your installation manager.");
422 #endif
423 }
424
/// Reports whether this build can read netCDF files with load(), i.e.
/// whether it was compiled with USE_NETCDF defined.
bool
loadConfigured()
{
    bool haveNetCDF = false;
#ifdef USE_NETCDF
    haveNetCDF = true;
#endif
    return haveNetCDF;
}
434
435 Data
436 convertToData(const boost::python::object& value,
437 const FunctionSpace& what)
438 {
439 // first we try to extract a Data object from value
440 extract<Data> value_data(value);
441 if (value_data.check()) {
442 Data extracted_data=value_data();
443 if (extracted_data.isEmpty()) {
444 return extracted_data;
445 } else {
446 return Data(extracted_data,what);
447 }
448 } else {
449 return Data(value,what);
450 }
451 }
452
453 } // end of namespace

Properties

Name Value
svn:eol-style native
svn:keywords Author Date Id Revision

  ViewVC Help
Powered by ViewVC 1.1.26