/[escript]/branches/complex/escriptcore/src/WrappedArray.cpp
ViewVC logotype

Contents of /branches/complex/escriptcore/src/WrappedArray.cpp

Parent Directory | Revision Log


Revision 5870 - (show annotations)
Thu Jan 14 06:50:47 2016 UTC (3 years, 6 months ago) by jfenwick
File size: 10322 byte(s)
Made space in wrapped array for complex values as well.
Note: This code passes existing unit tests (for double values) but the complex functionality is 
_completely_ untested.


1
2 /*****************************************************************************
3 *
4 * Copyright (c) 2003-2016 by The University of Queensland
5 * http://www.uq.edu.au
6 *
7 * Primary Business: Queensland, Australia
8 * Licensed under the Open Software License version 3.0
9 * http://www.opensource.org/licenses/osl-3.0.php
10 *
11 * Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12 * Development 2012-2013 by School of Earth Sciences
13 * Development from 2014 by Centre for Geoscience Computing (GeoComp)
14 *
15 *****************************************************************************/
16
17 #include <boost/python/tuple.hpp>
18 #include "WrappedArray.h"
19 #include "DataException.h"
20 #if HAVE_NUMPY_H
21 #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
22 #include <numpy/ndarrayobject.h>
23 #endif
24
25 #include <iostream>
26
27 using namespace escript;
28 using namespace boost::python;
29
30 namespace
31 {
32
33 void checkFeatures(const boost::python::object& obj)
34 {
35 using namespace std;
36 boost::python::object o2;
37 try
38 {
39 /*int len=*/ extract<int>(obj.attr("__len__")());
40 }
41 catch (...)
42 {
43 PyErr_Clear();
44 throw DataException("Object passed to WrappedArray must support __len__");
45 }
46 try
47 {
48 o2=obj.attr("__getitem__");
49 }
50 catch (...)
51 {
52 PyErr_Clear();
53 throw DataException("Object passed to WrappedArray must support __getitem__");
54 }
55 }
56
57 void getObjShape(const boost::python::object& obj, DataTypes::ShapeType& s)
58 {
59 int len=0;
60 try
61 {
62 len=extract<int>(obj.attr("__len__")());
63 }
64 catch(...)
65 {
66 PyErr_Clear(); // tell python the error isn't there anymore
67 return;
68 }
69 if (len<1)
70 {
71 throw DataException("Array filter - no empty components in arrays please.");
72 }
73 s.push_back(len);
74
75 if (s.size()>ESCRIPT_MAX_DATA_RANK)
76 {
77 throw DataException("Array filter - Maximum rank exceeded in array");
78 }
79 getObjShape(obj[0],s);
80 }
81
82 }
83
84 WrappedArray::WrappedArray(const boost::python::object& obj_in)
85 :obj(obj_in),iscomplex(false),scalar_r(nan("")),scalar_c(nan(""))
86 {
87 dat_r=0;
88 dat_c=0;
89 // First we check for scalars
90 try
91 {
92 extract<complextype> ec(obj_in);
93 extract<double> er(obj_in);
94 if (er.check()) // check for double first because complex will fail this
95 {
96 scalar_r=er();
97 }
98 else
99 {
100 scalar_c=ec();
101 iscomplex=true;
102
103 }
104 rank=0;
105 return;
106 }
107 catch (...)
108 { // so we clear the failure
109 PyErr_Clear();
110 }
111 try
112 {
113 const boost::python::object obj_in_t=obj_in[make_tuple()];
114 extract<complextype> ec(obj_in_t);
115 extract<double> er(obj_in_t);
116 if (er.check())
117 {
118 scalar_r=er();
119
120 }
121 else
122 {
123 scalar_c=ec();
124 iscomplex=true;
125 }
126 rank=0;
127 return;
128 }
129 catch (...)
130 { // so we clear the failure
131 PyErr_Clear();
132 }
133
134
135 scalar_c=0;
136 scalar_r=0;
137 checkFeatures(obj_in);
138 getObjShape(obj,shape);
139 rank=shape.size();
140
141 #if HAVE_NUMPY_H
142 // if obj is a numpy array it is much faster to copy the array through the
143 // __array_struct__ interface instead of extracting single values from the
144 // components via getElt(). For this to work we check below that
145 // (1) this is a valid PyArrayInterface instance
146 // (2) the data is stored as a contiguous C array
147 // (3) the data type is suitable (correct type and byte size)
148 try
149 {
150 object o = (extract<object>(obj.attr("__array_struct__")));
151 if (PyCObject_Check(o.ptr()))
152 {
153 PyObject* cobj=(PyObject*)o.ptr();
154 PyArrayInterface* arr=(PyArrayInterface*)PyCObject_AsVoidPtr(cobj);
155 #ifndef NPY_1_7_API_VERSION
156 #define NPY_ARRAY_IN_ARRAY NPY_IN_ARRAY
157 #define NPY_ARRAY_NOTSWAPPED NPY_NOTSWAPPED
158 #endif
159 if (arr->two==2 && arr->flags&NPY_ARRAY_IN_ARRAY && arr->flags&NPY_ARRAY_NOTSWAPPED)
160 {
161 std::vector<int> strides;
162 // convert #bytes to #elements
163 for (int i=0; i<arr->nd; i++)
164 {
165 strides.push_back(arr->strides[i]/arr->itemsize);
166 }
167
168 if (arr->typekind == 'f')
169 {
170 if (arr->itemsize==sizeof(double))
171 {
172 convertNumpyArray<double>((const double*)arr->data, strides);
173 }
174 else if (arr->itemsize==sizeof(float))
175 {
176 convertNumpyArray<float>((const float*)arr->data, strides);
177 }
178 }
179 else if (arr->typekind == 'i')
180 {
181 if (arr->itemsize==sizeof(int))
182 {
183 convertNumpyArray<int>((const int*)arr->data, strides);
184 }
185 else if (arr->itemsize==sizeof(long))
186 {
187 convertNumpyArray<long>((const long*)arr->data, strides);
188 }
189 }
190 else if (arr->typekind == 'u')
191 {
192 if (arr->itemsize==sizeof(unsigned))
193 {
194 convertNumpyArray<unsigned>((const unsigned*)arr->data, strides);
195 }
196 else if (arr->itemsize==sizeof(unsigned long))
197 {
198 convertNumpyArray<unsigned long>((const unsigned long*)arr->data, strides);
199 }
200 }
201 else if (arr->typekind == 'c')
202 {
203 if (arr->itemsize==sizeof(complextype))
204 {
205 convertNumpyArrayC<complextype>((const complextype*)arr->data, strides);
206 iscomplex=true;
207 }
208 // not accomodating other types of complex values
209 }
210 }
211 }
212 } catch (...)
213 {
214 PyErr_Clear();
215 }
216 #endif
217 }
218
219
220 template<typename T>
221 void WrappedArray::convertNumpyArrayC(const T* array, const std::vector<int>& strides) const
222 {
223 // this method is only called by the constructor above which does the
224 // necessary checks and initialisations
225 int size=DataTypes::noValues(shape);
226 dat_c=new complextype[size];
227 switch (rank)
228 {
229 case 1:
230 #pragma omp parallel for
231 for (int i=0;i<shape[0];i++)
232 {
233 dat_c[i]=array[i*strides[0]];
234 }
235 break;
236 case 2:
237 #pragma omp parallel for
238 for (int i=0;i<shape[0];i++)
239 {
240 for (int j=0;j<shape[1];j++)
241 {
242 dat_c[DataTypes::getRelIndex(shape,i,j)]=array[i*strides[0]+j*strides[1]];
243 }
244 }
245 break;
246 case 3:
247 #pragma omp parallel for
248 for (int i=0;i<shape[0];i++)
249 {
250 for (int j=0;j<shape[1];j++)
251 {
252 for (int k=0;k<shape[2];k++)
253 {
254 dat_c[DataTypes::getRelIndex(shape,i,j,k)]=array[i*strides[0]+j*strides[1]+k*strides[2]];
255 }
256 }
257 }
258 break;
259 case 4:
260 #pragma omp parallel for
261 for (int i=0;i<shape[0];i++)
262 {
263 for (int j=0;j<shape[1];j++)
264 {
265 for (int k=0;k<shape[2];k++)
266 {
267 for (int m=0;m<shape[3];m++)
268 {
269 dat_c[DataTypes::getRelIndex(shape,i,j,k,m)]=array[i*strides[0]+j*strides[1]+k*strides[2]+m*strides[3]];
270 }
271 }
272 }
273 }
274 break;
275 }
276 }
277
278
279 template<typename T>
280 void WrappedArray::convertNumpyArray(const T* array, const std::vector<int>& strides) const
281 {
282 // this method is only called by the constructor above which does the
283 // necessary checks and initialisations
284 int size=DataTypes::noValues(shape);
285 dat_r=new double[size];
286 switch (rank)
287 {
288 case 1:
289 #pragma omp parallel for
290 for (int i=0;i<shape[0];i++)
291 {
292 dat_r[i]=array[i*strides[0]];
293 }
294 break;
295 case 2:
296 #pragma omp parallel for
297 for (int i=0;i<shape[0];i++)
298 {
299 for (int j=0;j<shape[1];j++)
300 {
301 dat_r[DataTypes::getRelIndex(shape,i,j)]=array[i*strides[0]+j*strides[1]];
302 }
303 }
304 break;
305 case 3:
306 #pragma omp parallel for
307 for (int i=0;i<shape[0];i++)
308 {
309 for (int j=0;j<shape[1];j++)
310 {
311 for (int k=0;k<shape[2];k++)
312 {
313 dat_r[DataTypes::getRelIndex(shape,i,j,k)]=array[i*strides[0]+j*strides[1]+k*strides[2]];
314 }
315 }
316 }
317 break;
318 case 4:
319 #pragma omp parallel for
320 for (int i=0;i<shape[0];i++)
321 {
322 for (int j=0;j<shape[1];j++)
323 {
324 for (int k=0;k<shape[2];k++)
325 {
326 for (int m=0;m<shape[3];m++)
327 {
328 dat_r[DataTypes::getRelIndex(shape,i,j,k,m)]=array[i*strides[0]+j*strides[1]+k*strides[2]+m*strides[3]];
329 }
330 }
331 }
332 }
333 break;
334 }
335 }
336
337
338 void WrappedArray::convertArrayR() const
339 {
340 if ((converted) || (rank<=0) || (rank>4)) // checking illegal rank here to avoid memory issues later
341 { // yes the failure is silent here but not doing the copy
342 return; // will just cause an error to be raised later
343 }
344 int size=DataTypes::noValues(shape);
345 double* tdat=new double[size];
346 switch (rank)
347 {
348 case 1: for (int i=0;i<shape[0];i++)
349 {
350 tdat[i]=getElt(i);
351 }
352 break;
353 case 2: for (int i=0;i<shape[0];i++)
354 {
355 for (int j=0;j<shape[1];j++)
356 {
357 tdat[DataTypes::getRelIndex(shape,i,j)]=getElt(i,j);
358 }
359 }
360 break;
361 case 3: for (int i=0;i<shape[0];i++)
362 {
363 for (int j=0;j<shape[1];j++)
364 {
365 for (int k=0;k<shape[2];k++)
366 {
367 tdat[DataTypes::getRelIndex(shape,i,j,k)]=getElt(i,j,k);
368 }
369 }
370 }
371 break;
372 case 4: for (int i=0;i<shape[0];i++)
373 {
374 for (int j=0;j<shape[1];j++)
375 {
376 for (int k=0;k<shape[2];k++)
377 {
378 for (int m=0;m<shape[3];m++)
379 {
380 tdat[DataTypes::getRelIndex(shape,i,j,k,m)]=getElt(i,j,k,m);
381 }
382 }
383 }
384 }
385 break;
386 default:
387 ; // do nothing
388 // can't happen. We've already checked the bounds above
389 }
390 dat_r=tdat;
391 converted=true;
392 }
393
394
395 void WrappedArray::convertArrayC() const
396 {
397 if ((converted) || (rank<=0) || (rank>4)) // checking illegal rank here to avoid memory issues later
398 { // yes the failure is silent here but not doing the copy
399 return; // will just cause an error to be raised later
400 }
401 int size=DataTypes::noValues(shape);
402 complextype* tdat=new complextype[size];
403 switch (rank)
404 {
405 case 1: for (int i=0;i<shape[0];i++)
406 {
407 tdat[i]=getElt(i);
408 }
409 break;
410 case 2: for (int i=0;i<shape[0];i++)
411 {
412 for (int j=0;j<shape[1];j++)
413 {
414 tdat[DataTypes::getRelIndex(shape,i,j)]=getElt(i,j);
415 }
416 }
417 break;
418 case 3: for (int i=0;i<shape[0];i++)
419 {
420 for (int j=0;j<shape[1];j++)
421 {
422 for (int k=0;k<shape[2];k++)
423 {
424 tdat[DataTypes::getRelIndex(shape,i,j,k)]=getElt(i,j,k);
425 }
426 }
427 }
428 break;
429 case 4: for (int i=0;i<shape[0];i++)
430 {
431 for (int j=0;j<shape[1];j++)
432 {
433 for (int k=0;k<shape[2];k++)
434 {
435 for (int m=0;m<shape[3];m++)
436 {
437 tdat[DataTypes::getRelIndex(shape,i,j,k,m)]=getElt(i,j,k,m);
438 }
439 }
440 }
441 }
442 break;
443 default:
444 ; // do nothing
445 // can't happen. We've already checked the bounds above
446 }
447 dat_c=tdat;
448 converted=true;
449 }
450
451
452 void WrappedArray::convertArray() const
453 {
454 if (iscomplex)
455 {
456 convertArrayC();
457 }
458 else
459 {
460 convertArrayR();
461 }
462 }
463
464 WrappedArray::~WrappedArray()
465 {
466 if (dat_r!=0)
467 {
468 delete[] dat_r;
469 }
470 if (dat_c!=0)
471 {
472 delete[] dat_c;
473 }
474 }
475
476

  ViewVC Help
Powered by ViewVC 1.1.26