1 |
// $Id$ |
2 |
/* |
3 |
************************************************************ |
4 |
* Copyright 2006 by ACcESS MNRF * |
5 |
* * |
6 |
* http://www.access.edu.au * |
7 |
* Primary Business: Queensland, Australia * |
8 |
* Licensed under the Open Software License version 3.0 * |
9 |
* http://www.opensource.org/licenses/osl-3.0.php * |
10 |
* * |
11 |
************************************************************ |
12 |
*/ |
13 |
|
14 |
#include "DataFactory.h"

#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include <boost/python/extract.hpp>
#include <netcdfcpp.h>
19 |
|
20 |
using namespace boost::python; |
21 |
|
22 |
namespace escript { |
23 |
|
24 |
/*
 * Builds a scalar (rank-0) Data object on the given function space.
 *
 * value    - value assigned to every data point.
 * what     - function space the data lives on.
 * expanded - forwarded unchanged to the Data constructor.
 */
Data
Scalar(double value,
       const FunctionSpace& what,
       bool expanded)
{
  // A shape with no entries describes a rank-0 (scalar) data point.
  DataArrayView::ShapeType scalarShape;
  return Data(value, scalarShape, what, expanded);
}
34 |
|
35 |
/*
 * Builds a rank-1 Data object on the given function space.
 * The single shape entry is the value returned by what.getDomain().getDim().
 *
 * value    - value assigned to every component.
 * what     - function space the data lives on.
 * expanded - forwarded unchanged to the Data constructor.
 */
Data
Vector(double value,
       const FunctionSpace& what,
       bool expanded)
{
  DataArrayView::ShapeType vectorShape;
  vectorShape.push_back(what.getDomain().getDim());
  return Data(value, vectorShape, what, expanded);
}
43 |
|
44 |
/*
 * Builds a rank-2 Data object on the given function space.
 * Both shape entries equal what.getDomain().getDim().
 *
 * value    - value assigned to every component.
 * what     - function space the data lives on.
 * expanded - forwarded unchanged to the Data constructor.
 */
Data
Tensor(double value,
       const FunctionSpace& what,
       bool expanded)
{
  const int domainDim = what.getDomain().getDim();
  DataArrayView::ShapeType tensorShape;
  tensorShape.assign(2, domainDim);
  return Data(value, tensorShape, what, expanded);
}
52 |
|
53 |
/*
 * Builds a rank-3 Data object on the given function space.
 * Every shape entry equals what.getDomain().getDim().
 *
 * value    - value assigned to every component.
 * what     - function space the data lives on.
 * expanded - forwarded unchanged to the Data constructor.
 */
Data
Tensor3(double value,
        const FunctionSpace& what,
        bool expanded)
{
  const int extent = what.getDomain().getDim();
  DataArrayView::ShapeType rank3Shape;
  for (int axis = 0; axis < 3; ++axis) {
    rank3Shape.push_back(extent);
  }
  return Data(value, rank3Shape, what, expanded);
}
61 |
|
62 |
/*
 * Builds a rank-4 Data object on the given function space.
 * Every shape entry equals what.getDomain().getDim().
 *
 * value    - value assigned to every component.
 * what     - function space the data lives on.
 * expanded - forwarded unchanged to the Data constructor.
 */
Data
Tensor4(double value,
        const FunctionSpace& what,
        bool expanded)
{
  const int extent = what.getDomain().getDim();
  DataArrayView::ShapeType rank4Shape;
  rank4Shape.resize(4, extent);
  return Data(value, rank4Shape, what, expanded);
}
70 |
|
71 |
/*
 * Reads an escript Data object from a netCDF file.
 *
 * The file must provide the global attributes "function_space_type", "rank"
 * and "type" (one of "constant", "tagged", "expanded"), one dimension "d<k>"
 * for each rank index k, and a variable "data". Expanded data additionally
 * needs the dimensions "num_samples" and "num_data_points_per_sample" and a
 * variable "id" holding the sample reference ids.
 *
 * fileName - path of the netCDF file to read.
 * domain   - domain used to build the function space recorded in the file.
 * Returns the recovered Data object.
 * Throws DataException when the file cannot be opened, when a required
 * attribute/dimension/variable is missing or inconsistent with the domain,
 * when the type is unknown, and unconditionally in PASO_MPI builds.
 */
Data
load(const std::string fileName,
     const AbstractDomain& domain)
{
#ifdef PASO_MPI
  throw DataException("Error - load:: loading is not implemented for MPI yet.");
#endif
  // netCDF error handler: report errors without aborting the process.
  NcError err(NcError::verbose_nonfatal);
  // Open the file for reading.
  NcFile dataFile(fileName.c_str(), NcFile::ReadOnly);
  if (!dataFile.is_valid())
    throw DataException("Error - load:: opening of netCDF file for input failed.");

  /* recover function space */
  NcAtt* function_space_type_att = dataFile.get_att("function_space_type");
  if (!function_space_type_att)
    throw DataException("Error - load:: cannot recover function_space_type attribute from escript netCDF file.");
  const int function_space_type = function_space_type_att->as_int(0);
  delete function_space_type_att;
  /* test if function space id is valid and create function space instance */
  if (!domain.isValidFunctionSpaceType(function_space_type))
    throw DataException("Error - load:: function space type code in netCDF file is invalid for given domain.");
  FunctionSpace function_space = FunctionSpace(domain, function_space_type);

  /* recover rank */
  NcAtt* rank_att = dataFile.get_att("rank");
  if (!rank_att)
    throw DataException("Error - load:: cannot recover rank attribute from escript netCDF file.");
  const int rank = rank_att->as_int(0);
  delete rank_att;
  if (rank < 0 || rank > DataArrayView::maxRank)
    throw DataException("Error - load:: rank in escript netCDF file is greater than maximum rank.");

  /* recover type attribute */
  NcAtt* type_att = dataFile.get_att("type");
  if (!type_att)
    throw DataException("Error - load:: cannot recover type attribute from escript netCDF file.");
  // as_string() hands back a buffer allocated with new[]; release it with delete[].
  char* type_str = type_att->as_string(0);
  int type = -1;
  if (strncmp(type_str, "constant", strlen("constant")) == 0) {
    type = 0;
  } else if (strncmp(type_str, "tagged", strlen("tagged")) == 0) {
    type = 1;
  } else if (strncmp(type_str, "expanded", strlen("expanded")) == 0) {
    type = 2;
  }
  delete type_att;
  delete[] type_str;

  /* recover dimension */
  const int ndims = dataFile.num_dims();
  NcDim* d_dim = 0;

  /* recover shape: dimension "d<k>" holds the extent of rank index k */
  DataArrayView::ShapeType shape;
  // Counts passed to NcVar::get(): entries [0, rank) are the shape; the two
  // extra slots receive the data-point and sample counts of expanded data.
  long dims[DataArrayView::maxRank + 2];
  for (int k = 0; k < rank; ++k) {
    std::ostringstream dim_name;
    dim_name << "d" << k;
    if (!(d_dim = dataFile.get_dim(dim_name.str().c_str()))) {
      const std::string msg = "Error - load:: unable to recover " + dim_name.str() + " from netCDF file.";
      throw DataException(msg.c_str());
    }
    const int d = d_dim->size();
    shape.push_back(d);
    dims[k] = d;
  }

  /* recover stuff */
  Data out;
  NcVar* var = 0;
  if (type == 0) {
    /* constant data */
    if (!((ndims == rank && rank > 0) || (ndims == 1 && rank == 0)))
      throw DataException("Error - load:: illegal number of dimensions for constant data in netCDF file.");
    if (rank == 0) {
      if (!(d_dim = dataFile.get_dim("l")))
        throw DataException("Error - load:: unable to recover d0 for scalar constant data in netCDF file.");
      const int d0 = d_dim->size();
      // Scalar constant data is stored as a one-entry array.
      if (d0 != 1)
        throw DataException("Error - load:: d0 is expected to be one for scalar constant data in netCDF file.");
      dims[0] = 1;
    }
    out = Data(0, shape, function_space);
    if (!(var = dataFile.get_var("data")))
      throw DataException("Error - load:: unable to find data in netCDF file.");
    if (!var->get(&(out.getDataPoint(0, 0).getData()[0]), dims))
      throw DataException("Error - load:: unable to recover data from netCDF file.");
  } else if (type == 1) {
    /* tagged data */
    if (ndims != rank + 1)
      throw DataException("Error - load:: illegal number of dimensions for tagged data in netCDF file.");
    NcDim* tags_dim = dataFile.get_dim("tags");
    if (!tags_dim)
      throw DataException("Error - load:: unable to recover number of tags from netCDF file.");
    // TODO(review): tags and tagged values are not read yet; only the "tags"
    // dimension is checked and a zero-valued Data object is returned.
    out = Data(0, shape, function_space);
  } else if (type == 2) {
    /* expanded data */
    if (ndims != rank + 2)
      throw DataException("Error - load:: illegal number of dimensions for expanded data in netCDF file.");
    NcDim* num_samples_dim = dataFile.get_dim("num_samples");
    if (!num_samples_dim)
      throw DataException("Error - load:: unable to recover number of samples from netCDF file.");
    const int num_samples = num_samples_dim->size();
    NcDim* num_data_points_per_sample_dim = dataFile.get_dim("num_data_points_per_sample");
    if (!num_data_points_per_sample_dim)
      throw DataException("Error - load:: unable to recover number of data points per sample from netCDF file.");
    const int num_data_points_per_sample = num_data_points_per_sample_dim->size();
    // check shape:
    if (!(num_samples == function_space.getNumSamples() &&
          num_data_points_per_sample == function_space.getNumDataPointsPerSample()))
      throw DataException("Error - load:: data sample layout of file does not match data layout of function space.");
    // get ids
    NcVar* ids_var = dataFile.get_var("id");
    if (!ids_var)
      throw DataException("Error - load:: unable to find reference ids in netCDF file.");
    const int* ids_p = function_space.borrowSampleReferenceIDs();
    // Heap buffer: a variable-length array is not standard C++ and could
    // overflow the stack for large sample counts.
    std::vector<int> ids_of_nc(num_samples);
    if (num_samples > 0 && !ids_var->get(&ids_of_nc[0], static_cast<long>(num_samples)))
      throw DataException("Error - load:: unable to recover ids from netCDF file.");
    // check order: record the index of some sample whose id differs.
    int failed = -1;
#pragma omp parallel
    {
      // Declared inside the region so each thread owns an initialised copy.
      int local_failed = -1;
#pragma omp for schedule(static)
      for (int i = 0; i < num_samples; ++i)
        if (ids_of_nc[i] != ids_p[i]) local_failed = i;
#pragma omp critical
      if (local_failed >= 0) failed = local_failed;
    }
    if (failed >= 0)
      throw DataException("Error - load:: data ordering in netCDF file does not match ordering of FunctionSpace.");
    // get the data:
    dims[rank] = num_data_points_per_sample;
    dims[rank + 1] = num_samples;
    out = Data(0, shape, function_space, true);
    if (!(var = dataFile.get_var("data")))
      throw DataException("Error - load:: unable to find data in netCDF file.");
    if (!var->get(&(out.getDataPoint(0, 0).getData()[0]), dims))
      throw DataException("Error - load:: unable to recover data from netCDF file.");
    // if (failed==-1)
    //    out->m_data.reorderByReferenceIDs(ids_of_nc)
  } else {
    throw DataException("Error - load:: unknown escript data type in netCDF file.");
  }
  return out;
}
227 |
|
228 |
/*
 * Converts a Python object into a Data object on the function space `what`.
 *
 * - If `value` does not hold a Data object, the result is Data(value, what).
 * - If it holds an empty Data object, that object is returned as is.
 * - Otherwise the result is Data(extracted, what).
 */
Data
convertToData(const boost::python::object& value,
              const FunctionSpace& what)
{
  extract<Data> asData(value);
  if (!asData.check()) {
    return Data(value, what);
  }
  Data candidate = asData();
  if (candidate.isEmpty()) {
    return candidate;
  }
  return Data(candidate, what);
}
245 |
|
246 |
} // end of namespace |