1 |
// $Id$ |
2 |
/* |
3 |
************************************************************ |
4 |
* Copyright 2006 by ACcESS MNRF * |
5 |
* * |
6 |
* http://www.access.edu.au * |
7 |
* Primary Business: Queensland, Australia * |
8 |
* Licensed under the Open Software License version 3.0 * |
9 |
* http://www.opensource.org/licenses/osl-3.0.php * |
10 |
* * |
11 |
************************************************************ |
12 |
*/ |
13 |
|
14 |
#include "DataFactory.h"

#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include <boost/python/extract.hpp>
#ifdef USE_NETCDF
#include <netcdfcpp.h>
#endif
21 |
|
22 |
using namespace boost::python; |
23 |
|
24 |
namespace escript { |
25 |
|
26 |
Data
Scalar(double value,
       const FunctionSpace& what,
       bool expanded)
{
  // A scalar is rank 0: its shape has no extents at all, so a
  // default-constructed (empty) shape describes it.
  DataArrayView::ShapeType rankZeroShape;
  return Data(value, rankZeroShape, what, expanded);
}
36 |
|
37 |
Data
Vector(double value,
       const FunctionSpace& what,
       bool expanded)
{
  // Rank-1 shape: a single extent equal to the dimension reported by the
  // domain of the target function space.
  const int spatialDim = what.getDomain().getDim();
  DataArrayView::ShapeType vectorShape(1, spatialDim);
  return Data(value, vectorShape, what, expanded);
}
45 |
|
46 |
Data
Tensor(double value,
       const FunctionSpace& what,
       bool expanded)
{
  // Rank-2 shape: both extents equal the domain dimension.
  const int spatialDim = what.getDomain().getDim();
  DataArrayView::ShapeType tensorShape;
  tensorShape.assign(2, spatialDim);
  return Data(value, tensorShape, what, expanded);
}
54 |
|
55 |
Data
Tensor3(double value,
        const FunctionSpace& what,
        bool expanded)
{
  // Rank-3 shape: every one of the three extents equals the domain dimension.
  const int spatialDim = what.getDomain().getDim();
  DataArrayView::ShapeType rank3Shape(3, spatialDim);
  return Data(value, rank3Shape, what, expanded);
}
63 |
|
64 |
Data
Tensor4(double value,
        const FunctionSpace& what,
        bool expanded)
{
  // Rank-4 shape: four extents, each equal to the domain dimension.
  const int spatialDim = what.getDomain().getDim();
  DataArrayView::ShapeType rank4Shape;
  for (int axis = 0; axis < 4; ++axis)
    rank4Shape.push_back(spatialDim);
  return Data(value, rank4Shape, what, expanded);
}
72 |
|
73 |
/*
 * Reconstructs a Data object from an escript netCDF file.
 *
 * The file must provide:
 *   - global attributes "function_space_type", "rank" and "type"
 *     ("constant", "tagged" or "expanded", matched as a prefix);
 *   - dimensions "d0" .. "d<rank-1>" giving the data point shape;
 *   - a variable "data" holding the values, plus
 *       dimension "l" (must be 1) for rank-0 constant data,
 *       dimension "num_tags" and variable "tags" for tagged data,
 *       dimensions "num_samples", "num_data_points_per_sample" and
 *       variable "id" for expanded data.
 *
 * fileName - name of the netCDF file to read
 * domain   - domain the recovered function space is built on; the function
 *            space type stored in the file must be valid for this domain
 *
 * Returns the reconstructed Data object.
 * Throws DataException when the file cannot be opened, when a required
 * attribute/dimension/variable is missing or inconsistent, when expanded
 * data does not match the function space's sample layout or ordering,
 * when built with PASO_MPI, or when built without USE_NETCDF.
 */
Data
load(const std::string fileName,
     const AbstractDomain& domain)
{
#ifdef PASO_MPI
  throw DataException("Error - load :: not implemented for MPI yet.");
#endif

#ifdef USE_NETCDF
  // Report netCDF errors without aborting; failures are detected through
  // the return values checked below.
  NcError err(NcError::verbose_nonfatal);
  // Open the file read-only.
  NcFile dataFile(fileName.c_str(), NcFile::ReadOnly);
  if (!dataFile.is_valid())
    throw DataException("Error - load:: opening of netCDF file for input failed.");

  /* recover function space */
  NcAtt* function_space_type_att = dataFile.get_att("function_space_type");
  if (!function_space_type_att)
    throw DataException("Error - load:: cannot recover function_space_type attribute from escript netCDF file.");
  const int function_space_type = function_space_type_att->as_int(0);
  delete function_space_type_att;
  /* test if function space id is valid and create function space instance */
  if (!domain.isValidFunctionSpaceType(function_space_type))
    throw DataException("Error - load:: function space type code in netCDF file is invalid for given domain.");
  FunctionSpace function_space = FunctionSpace(domain, function_space_type);

  /* recover rank */
  NcAtt* rank_att = dataFile.get_att("rank");
  if (!rank_att)
    throw DataException("Error - load:: cannot recover rank attribute from escript netCDF file.");
  const int rank = rank_att->as_int(0);
  delete rank_att;
  if (rank < 0 || rank > DataArrayView::maxRank)
    throw DataException("Error - load:: rank in escript netCDF file is greater than maximum rank.");

  /* recover type attribute */
  NcAtt* type_att = dataFile.get_att("type");
  if (!type_att)
    throw DataException("Error - load:: cannot recover type attribute from escript netCDF file.");
  char* type_str = type_att->as_string(0);
  int type = -1;
  if (std::strncmp(type_str, "constant", std::strlen("constant")) == 0) {
    type = 0;
  } else if (std::strncmp(type_str, "tagged", std::strlen("tagged")) == 0) {
    type = 1;
  } else if (std::strncmp(type_str, "expanded", std::strlen("expanded")) == 0) {
    type = 2;
  }
  delete type_att;
  // NcAtt::as_string() returns a buffer allocated with new[], so it must be
  // released with delete[] (plain delete is undefined behaviour).
  delete[] type_str;

  /* recover the data point shape from dimensions "d0" .. "d<rank-1>" */
  const int ndims = dataFile.num_dims();
  DataArrayView::ShapeType shape;
  // Counts passed to NcVar::get(): the shape extents, followed by up to two
  // layout extents (number of tags, or data points per sample and samples).
  long dims[DataArrayView::maxRank+2];
  int len_data_point = 1;
  for (int r = 0; r < rank; ++r) {
    std::ostringstream dim_name;
    dim_name << "d" << r;
    NcDim* d_dim = dataFile.get_dim(dim_name.str().c_str());
    if (!d_dim) {
      const std::string msg = "Error - load:: unable to recover " + dim_name.str() + " from netCDF file.";
      throw DataException(msg.c_str());
    }
    const int d = d_dim->size();
    shape.push_back(d);
    dims[r] = d;
    len_data_point *= d;
  }

  /* recover the values */
  Data out;
  if (type == 0) {
    /* constant data */
    if ( ! ( (ndims == rank && rank > 0) || (ndims == 1 && rank == 0) ) )
      throw DataException("Error - load:: illegal number of dimensions for constant data in netCDF file.");
    if (rank == 0) {
      // a scalar constant is stored as a one-element array along dimension "l"
      NcDim* l_dim = dataFile.get_dim("l");
      if (!l_dim)
        throw DataException("Error - load:: unable to recover d0 for scalar constant data in netCDF file.");
      if (l_dim->size() != 1)
        throw DataException("Error - load:: d0 is expected to be one for scalar constant data in netCDF file.");
      dims[0] = 1;
    }
    out = Data(0, shape, function_space);
    NcVar* var = dataFile.get_var("data");
    if (!var)
      throw DataException("Error - load:: unable to find data in netCDF file.");
    if (!var->get(&(out.getDataPoint(0,0).getData()[0]), dims))
      throw DataException("Error - load:: unable to recover data from netCDF file.");
  } else if (type == 1) {
    /* tagged data */
    if ( ! (ndims == rank + 1) )
      throw DataException("Error - load:: illegal number of dimensions for tagged data in netCDF file.");
    NcDim* tags_dim = dataFile.get_dim("num_tags");
    if (!tags_dim)
      throw DataException("Error - load:: unable to recover number of tags from netCDF file.");
    const int ntags = tags_dim->size();
    // Entry 0 is used as the default value below, so at least one entry is
    // required; otherwise the reads below would index empty buffers.
    if (ntags < 1)
      throw DataException("Error - load:: number of tags in netCDF file must be at least one.");
    dims[rank] = ntags;

    // std::vector releases the buffer on every exit path, including
    // exceptions thrown by the Data calls below.
    std::vector<int> tags(ntags);
    NcVar* tags_var = dataFile.get_var("tags");
    if (!tags_var)
      throw DataException("Error - load:: unable to find tags in netCDF file.");
    if (!tags_var->get(&tags[0], ntags))
      throw DataException("Error - load:: unable to recover tags from netCDF file.");

    DataVector data(len_data_point*ntags, 0., len_data_point*ntags);
    NcVar* var = dataFile.get_var("data");
    if (!var)
      throw DataException("Error - load:: unable to find data in netCDF file.");
    if (!var->get(&(data[0]), dims))
      throw DataException("Error - load:: unable to recover data from netCDF file.");
    // entry 0 becomes the default value, entries 1.. are attached to tags[t]
    out = Data(DataArrayView(data,shape,0), function_space);
    for (int t = 1; t < ntags; ++t) {
      out.setTaggedValueFromCPP(tags[t], DataArrayView(data,shape,t*len_data_point));
    }
  } else if (type == 2) {
    /* expanded data */
    if ( ! (ndims == rank + 2) )
      throw DataException("Error - load:: illegal number of dimensions for expanded data in netCDF file.");
    NcDim* num_samples_dim = dataFile.get_dim("num_samples");
    if (!num_samples_dim)
      throw DataException("Error - load:: unable to recover number of samples from netCDF file.");
    const int num_samples = num_samples_dim->size();
    NcDim* num_data_points_per_sample_dim = dataFile.get_dim("num_data_points_per_sample");
    if (!num_data_points_per_sample_dim)
      throw DataException("Error - load:: unable to recover number of data points per sample from netCDF file.");
    const int num_data_points_per_sample = num_data_points_per_sample_dim->size();
    // check shape:
    if ( ! (num_samples == function_space.getNumSamples() && num_data_points_per_sample == function_space.getNumDataPointsPerSample()) )
      throw DataException("Error - load:: data sample layout of file does not match data layout of function space.");

    // get the sample reference ids stored in the file
    NcVar* ids_var = dataFile.get_var("id");
    if (!ids_var)
      throw DataException("Error - load:: unable to find reference ids in netCDF file.");
    const int* ids_p = function_space.borrowSampleReferenceIDs();
    std::vector<int> ids_of_nc(num_samples);
    if (num_samples > 0 && !ids_var->get(&ids_of_nc[0], static_cast<long>(num_samples)))
      throw DataException("Error - load:: unable to recover ids from netCDF file.");

    // check order: every stored id must match the function space's id at
    // the same position
    int failed = -1;
    int local_failed = -1;
    int i;
    #pragma omp parallel private(local_failed)
    {
      local_failed = -1;
      #pragma omp for private(i) schedule(static)
      for (i = 0; i < num_samples; ++i) {
        if (ids_of_nc[i] != ids_p[i]) local_failed = i;
      }
      #pragma omp critical
      if (local_failed >= 0) failed = local_failed;
    }
    if (failed >= 0)
      throw DataException("Error - load:: data ordering in netCDF file does not match ordering of FunctionSpace.");

    // get the data:
    dims[rank] = num_data_points_per_sample;
    dims[rank+1] = num_samples;
    out = Data(0, shape, function_space, true);
    NcVar* var = dataFile.get_var("data");
    if (!var)
      throw DataException("Error - load:: unable to find data in netCDF file.");
    if (!var->get(&(out.getDataPoint(0,0).getData()[0]), dims))
      throw DataException("Error - load:: unable to recover data from netCDF file.");
    // NOTE: samples are required to already be in the function space's
    // reference-id order (checked above); reordering by ids_of_nc is not
    // implemented.
  } else {
    throw DataException("Error - load:: unknown escript data type in netCDF file.");
  }
  return out;
#else
  throw DataException("Error - load:: is not compiled with netCDF. Please contact your installation manager.");
#endif
}
280 |
|
281 |
bool
loadConfigured()
{
  // load() only does real work when escript was built with USE_NETCDF;
  // report whether that is the case.
  bool netcdfAvailable = false;
#ifdef USE_NETCDF
  netcdfAvailable = true;
#endif
  return netcdfAvailable;
}
290 |
|
291 |
Data
convertToData(const boost::python::object& value,
              const FunctionSpace& what)
{
  // Anything that is not already a Data object is handed straight to the
  // Data(object, FunctionSpace) constructor.
  extract<Data> asData(value);
  if (!asData.check()) {
    return Data(value, what);
  }

  // An empty Data object is passed through unchanged; any other Data is
  // rebuilt via the Data(Data, FunctionSpace) constructor for `what`.
  Data candidate = asData();
  if (candidate.isEmpty())
    return candidate;
  return Data(candidate, what);
}
308 |
|
309 |
} // end of namespace |