/[escript]/trunk/escript/src/DataFactory.cpp
ViewVC logotype

Contents of /trunk/escript/src/DataFactory.cpp

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3911 - (show annotations)
Thu Jun 14 01:01:03 2012 UTC (7 years, 4 months ago) by jfenwick
File size: 14831 byte(s)
Copyright changes
1
2 /*******************************************************
3 *
4 * Copyright (c) 2003-2012 by University of Queensland
5 * Earth Systems Science Computational Center (ESSCC)
6 * http://www.uq.edu.au/esscc
7 *
8 * Primary Business: Queensland, Australia
9 * Licensed under the Open Software License version 3.0
10 * http://www.opensource.org/licenses/osl-3.0.php
11 *
12 *******************************************************/
13
14
15 #include "DataFactory.h"
16 #include "esysUtils/esys_malloc.h"
17 #include "esysUtils/Esys_MPI.h"
18
19 #include <boost/python/extract.hpp>
20 #include <iostream>
21 #include <exception>
22 #ifdef USE_NETCDF
23 #include <netcdfcpp.h>
24 #endif
25
26 using namespace boost::python;
27
28 namespace escript {
29
30 Data
31 Scalar(double value,
32 const FunctionSpace& what,
33 bool expanded)
34 {
35 //
36 // an empty shape is a scalar
37 DataTypes::ShapeType shape;
38 return Data(value,shape,what,expanded);
39 }
40
41 Data
42 Vector(double value,
43 const FunctionSpace& what,
44 bool expanded)
45 {
46 DataTypes::ShapeType shape(1,what.getDomain()->getDim());
47 return Data(value,shape,what,expanded);
48 }
49
50 Data
51 VectorFromObj(boost::python::object o,
52 const FunctionSpace& what,
53 bool expanded)
54 {
55 double v;
56 try // first try to get a double and route it to the other method
57 {
58 v=boost::python::extract<double>(o);
59 return Vector(v,what,expanded);
60 }
61 catch(...)
62 {
63 PyErr_Clear();
64 }
65 DataTypes::ShapeType shape(1,what.getDomain()->getDim());
66 Data d(o,what,expanded);
67 if (d.getDataPointShape()!=shape)
68 {
69 throw DataException("VectorFromObj: Shape of vector passed to function does not match the dimension of the domain. ");
70 }
71 return d;
72 }
73
74 Data
75 Tensor(double value,
76 const FunctionSpace& what,
77 bool expanded)
78 {
79 DataTypes::ShapeType shape(2,what.getDomain()->getDim());
80 return Data(value,shape,what,expanded);
81 }
82
83
84 // We need to take some care here because this signature trumps the other one from boost's point of view
85 Data
86 TensorFromObj(boost::python::object o,
87 const FunctionSpace& what,
88 bool expanded)
89 {
90 double v;
91 try // first try to get a double and route it to the other method
92 {
93 v=boost::python::extract<double>(o);
94 return Tensor(v,what,expanded);
95 }
96 catch(...)
97 {
98 PyErr_Clear();
99 }
100 DataTypes::ShapeType shape(2,what.getDomain()->getDim());
101 Data d(o,what,expanded);
102 if (d.getDataPointShape()!=shape)
103 {
104 throw DataException("TensorFromObj: Shape of tensor passed to function does not match the dimension of the domain. ");
105 }
106 return d;
107 }
108
109 Data
110 Tensor3(double value,
111 const FunctionSpace& what,
112 bool expanded)
113 {
114 DataTypes::ShapeType shape(3,what.getDomain()->getDim());
115 return Data(value,shape,what,expanded);
116 }
117
118 Data
119 Tensor3FromObj(boost::python::object o,
120 const FunctionSpace& what,
121 bool expanded)
122 {
123 double v;
124 try // first try to get a double and route it to the other method
125 {
126 v=boost::python::extract<double>(o);
127 return Tensor3(v,what,expanded);
128 }
129 catch(...)
130 {
131 PyErr_Clear();
132 }
133 DataTypes::ShapeType shape(3,what.getDomain()->getDim());
134 Data d(o,what,expanded);
135 if (d.getDataPointShape()!=shape)
136 {
137 throw DataException("Tensor3FromObj: Shape of tensor passed to function does not match the dimension of the domain. ");
138 }
139 return d;
140 }
141
142 Data
143 Tensor4(double value,
144 const FunctionSpace& what,
145 bool expanded)
146 {
147 DataTypes::ShapeType shape(4,what.getDomain()->getDim());
148 return Data(value,shape,what,expanded);
149 }
150
151 Data
152 Tensor4FromObj(boost::python::object o,
153 const FunctionSpace& what,
154 bool expanded)
155 {
156 double v;
157 try // first try to get a double and route it to the other method
158 {
159 v=boost::python::extract<double>(o);
160 return Tensor4(v,what,expanded);
161 }
162 catch(...)
163 {
164 PyErr_Clear();
165 }
166 DataTypes::ShapeType shape(4,what.getDomain()->getDim());
167 Data d(o,what,expanded);
168 if (d.getDataPointShape()!=shape)
169 {
170 throw DataException("VectorFromObj: Shape of tensor passed to function does not match the dimension of the domain. ");
171 }
172 return d;
173 }
174
175
176 Data
177 load(const std::string fileName,
178 const AbstractDomain& domain)
179 {
180 #ifdef USE_NETCDF
181 NcAtt *type_att, *rank_att, *function_space_type_att;
182 // netCDF error handler
183 NcError err(NcError::silent_nonfatal);
184 int mpi_iam=0, mpi_num=1;
185 // Create the file.
186 #ifdef ESYS_MPI
187 MPI_Comm_rank(MPI_COMM_WORLD, &mpi_iam);
188 MPI_Comm_size(MPI_COMM_WORLD, &mpi_num);
189 #endif
190 char *newFileName = Escript_MPI_appendRankToFileName(fileName.c_str(), mpi_num, mpi_iam);
191 NcFile dataFile(newFileName, NcFile::ReadOnly);
192 if (!dataFile.is_valid())
193 throw DataException("Error - load:: opening of netCDF file for input failed.");
194 /* recover function space */
195 if (! (function_space_type_att=dataFile.get_att("function_space_type")) )
196 throw DataException("Error - load:: cannot recover function_space_type attribute from escript netCDF file.");
197 int function_space_type = function_space_type_att->as_int(0);
198 delete function_space_type_att;
199 /* test if function space id is valid and create function space instance */
200 if (! domain.isValidFunctionSpaceType(function_space_type) )
201 throw DataException("Error - load:: function space type code in netCDF file is invalid for given domain.");
202 FunctionSpace function_space=FunctionSpace(domain.getPtr(), function_space_type);
203 /* recover rank */
204 if (! (rank_att=dataFile.get_att("rank")) )
205 throw DataException("Error - load:: cannot recover rank attribute from escript netCDF file.");
206 int rank = rank_att->as_int(0);
207 delete rank_att;
208 if (rank<0 || rank>DataTypes::maxRank)
209 throw DataException("Error - load:: rank in escript netCDF file is greater than maximum rank.");
210 /* recover type attribute */
211 int type=-1;
212 if ((type_att=dataFile.get_att("type")) ) {
213 char* type_str = type_att->as_string(0);
214 if (strncmp(type_str, "constant", strlen("constant")) == 0 ) {
215 type =0;
216 } else if (strncmp(type_str, "tagged", strlen("tagged")) == 0 ) {
217 type =1;
218 } else if (strncmp(type_str, "expanded", strlen("expanded")) == 0 ) {
219 type =2;
220 }
221 esysUtils::free(type_str);
222 } else {
223 if (! (type_att=dataFile.get_att("type_id")) )
224 throw DataException("Error - load:: cannot recover type attribute from escript netCDF file.");
225 type=type_att->as_int(0);
226 }
227 delete type_att;
228
229 /* recover dimension */
230 int ndims=dataFile.num_dims();
231 int ntags =0 , num_samples =0 , num_data_points_per_sample =0, d=0, len_data_point=1;
232 NcDim *d_dim, *tags_dim, *num_samples_dim, *num_data_points_per_sample_dim;
233 /* recover shape */
234 DataTypes::ShapeType shape;
235 long dims[DataTypes::maxRank+2];
236 if (rank>0) {
237 if (! (d_dim=dataFile.get_dim("d0")) )
238 throw DataException("Error - load:: unable to recover d0 from netCDF file.");
239 d=d_dim->size();
240 shape.push_back(d);
241 dims[0]=d;
242 len_data_point*=d;
243 }
244 if (rank>1) {
245 if (! (d_dim=dataFile.get_dim("d1")) )
246 throw DataException("Error - load:: unable to recover d1 from netCDF file.");
247 d=d_dim->size();
248 shape.push_back(d);
249 dims[1]=d;
250 len_data_point*=d;
251 }
252 if (rank>2) {
253 if (! (d_dim=dataFile.get_dim("d2")) )
254 throw DataException("Error - load:: unable to recover d2 from netCDF file.");
255 d=d_dim->size();
256 shape.push_back(d);
257 dims[2]=d;
258 len_data_point*=d;
259 }
260 if (rank>3) {
261 if (! (d_dim=dataFile.get_dim("d3")) )
262 throw DataException("Error - load:: unable to recover d3 from netCDF file.");
263 d=d_dim->size();
264 shape.push_back(d);
265 dims[3]=d;
266 len_data_point*=d;
267 }
268 /* recover stuff */
269 Data out;
270 NcVar *var, *ids_var, *tags_var;
271 if (type == 0) {
272 /* constant data */
273 if ( ! ( (ndims == rank && rank >0) || ( ndims ==1 && rank == 0 ) ) )
274 throw DataException("Error - load:: illegal number of dimensions for constant data in netCDF file.");
275 if (rank == 0) {
276 if (! (d_dim=dataFile.get_dim("l")) )
277 throw DataException("Error - load:: unable to recover d0 for scalar constant data in netCDF file.");
278 int d0 = d_dim->size();
279 if (! d0 == 1)
280 throw DataException("Error - load:: d0 is expected to be one for scalar constant data in netCDF file.");
281 dims[0]=1;
282 }
283 out=Data(0,shape,function_space);
284 if (!(var = dataFile.get_var("data")))
285 throw DataException("Error - load:: unable to find data in netCDF file.");
286 if (! var->get(&(out.getDataAtOffsetRW(out.getDataOffset(0,0))), dims) )
287 throw DataException("Error - load:: unable to recover data from netCDF file.");
288 } else if (type == 1) {
289 /* tagged data */
290 if ( ! (ndims == rank + 1) )
291 throw DataException("Error - load:: illegal number of dimensions for tagged data in netCDF file.");
292 if (! (tags_dim=dataFile.get_dim("num_tags")) )
293 throw DataException("Error - load:: unable to recover number of tags from netCDF file.");
294 ntags=tags_dim->size();
295 dims[rank]=ntags;
296 int *tags = (int *) esysUtils::malloc(ntags*sizeof(int));
297 if (! ( tags_var = dataFile.get_var("tags")) )
298 {
299 esysUtils::free(tags);
300 throw DataException("Error - load:: unable to find tags in netCDF file.");
301 }
302 if (! tags_var->get(tags, ntags) )
303 {
304 esysUtils::free(tags);
305 throw DataException("Error - load:: unable to recover tags from netCDF file.");
306 }
307
308 // Current Version
309 /* DataVector data(len_data_point * ntags, 0., len_data_point * ntags);
310 if (!(var = dataFile.get_var("data")))
311 {
312 esysUtils::free(tags);
313 throw DataException("Error - load:: unable to find data in netCDF file.");
314 }
315 if (! var->get(&(data[0]), dims) )
316 {
317 esysUtils::free(tags);
318 throw DataException("Error - load:: unable to recover data from netCDF file.");
319 }
320 out=Data(DataArrayView(data,shape,0),function_space);
321 for (int t=1; t<ntags; ++t) {
322 out.setTaggedValueFromCPP(tags[t],shape, data, t*len_data_point);
323 // out.setTaggedValueFromCPP(tags[t],DataArrayView(data,shape,t*len_data_point));
324 }*/
325 // End current version
326
327 // New version
328
329 // A) create a DataTagged dt
330 // B) Read data from file
331 // C) copy default value into dt
332 // D) copy tagged values into dt
333 // E) create a new Data based on dt
334
335 NcVar* var1;
336 DataVector data1(len_data_point * ntags, 0., len_data_point * ntags);
337 if (!(var1 = dataFile.get_var("data")))
338 {
339 esysUtils::free(tags);
340 throw DataException("Error - load:: unable to find data in netCDF file.");
341 }
342 if (! var1->get(&(data1[0]), dims) )
343 {
344 esysUtils::free(tags);
345 throw DataException("Error - load:: unable to recover data from netCDF file.");
346 }
347 DataTagged* dt=new DataTagged(function_space, shape, tags,data1);
348 out=Data(dt);
349 esysUtils::free(tags);
350 } else if (type == 2) {
351 /* expanded data */
352 if ( ! (ndims == rank + 2) )
353 throw DataException("Error - load:: illegal number of dimensions for exanded data in netCDF file.");
354 if ( ! (num_samples_dim = dataFile.get_dim("num_samples") ) )
355 throw DataException("Error - load:: unable to recover number of samples from netCDF file.");
356 num_samples = num_samples_dim->size();
357 if ( ! (num_data_points_per_sample_dim = dataFile.get_dim("num_data_points_per_sample") ) )
358 throw DataException("Error - load:: unable to recover number of data points per sample from netCDF file.");
359 num_data_points_per_sample=num_data_points_per_sample_dim->size();
360 // check shape:
361 if ( ! (num_samples == function_space.getNumSamples() && num_data_points_per_sample == function_space.getNumDataPointsPerSample()) )
362 throw DataException("Error - load:: data sample layout of file does not match data layout of function space.");
363 if (num_samples==0) {
364 out = Data(0,shape,function_space,true);
365 }
366 else {
367 // get ids
368 if (! ( ids_var = dataFile.get_var("id")) )
369 throw DataException("Error - load:: unable to find reference ids in netCDF file.");
370 const int* ids_p=function_space.borrowSampleReferenceIDs();
371 int *ids_of_nc = (int *)esysUtils::malloc(num_samples*sizeof(int));
372 if (! ids_var->get(ids_of_nc, (long) num_samples) )
373 {
374 esysUtils::free(ids_of_nc);
375 throw DataException("Error - load:: unable to recover ids from netCDF file.");
376 }
377 // check order:
378 int failed=-1, local_failed=-1, i;
379 #pragma omp parallel private(local_failed)
380 {
381 local_failed=-1;
382 #pragma omp for private(i) schedule(static)
383 for (i=0;i < num_samples; ++i) {
384 if (ids_of_nc[i]!=ids_p[i]) local_failed=i;
385 }
386 #pragma omp critical
387 if (local_failed>=0) failed = local_failed;
388 }
389 /* if (failed>=0)
390 {
391 esysUtils::free(ids_of_nc);
392 throw DataException("Error - load:: data ordering in netCDF file does not match ordering of FunctionSpace.");
393 } */
394 // get the data:
395 dims[rank]=num_data_points_per_sample;
396 dims[rank+1]=num_samples;
397 out=Data(0,shape,function_space,true);
398 if (!(var = dataFile.get_var("data")))
399 {
400 esysUtils::free(ids_of_nc);
401 throw DataException("Error - load:: unable to find data in netCDF file.");
402 }
403 if (! var->get(&(out.getDataAtOffsetRW(out.getDataOffset(0,0))), dims) )
404 {
405 esysUtils::free(ids_of_nc);
406 throw DataException("Error - load:: unable to recover data from netCDF file.");
407 }
408 if (failed>=0) {
409 try {
410 std::cout << "Information - load: start reordering data from netCDF file " << fileName << std::endl;
411 out.borrowData()->reorderByReferenceIDs(ids_of_nc);
412 }
413 catch (std::exception&) {
414 esysUtils::free(ids_of_nc);
415 throw DataException("Error - load:: unable to reorder data in netCDF file.");
416 }
417 }
418 }
419 } else {
420 throw DataException("Error - load:: unknown escript data type in netCDF file.");
421 }
422 return out;
423 #else
424 throw DataException("Error - load:: is not compiled with netCFD. Please contact your insatllation manager.");
425 #endif
426 }
427
/*
 * Report whether this build was compiled with netCDF support (USE_NETCDF),
 * i.e. whether load() can actually read files.
 */
bool
loadConfigured()
{
    bool haveNetCDF = false;
#ifdef USE_NETCDF
    haveNetCDF = true;
#endif
    return haveNetCDF;
}
437
438 Data
439 convertToData(const boost::python::object& value,
440 const FunctionSpace& what)
441 {
442 // first we try to extract a Data object from value
443 extract<Data> value_data(value);
444 if (value_data.check()) {
445 Data extracted_data=value_data();
446 if (extracted_data.isEmpty()) {
447 return extracted_data;
448 } else {
449 return Data(extracted_data,what);
450 }
451 } else {
452 return Data(value,what);
453 }
454 }
455
456 } // end of namespace

Properties

Name Value
svn:eol-style native
svn:keywords Author Date Id Revision

  ViewVC Help
Powered by ViewVC 1.1.26