/[escript]/trunk/escript/src/WrappedArray.cpp
ViewVC logotype

Contents of /trunk/escript/src/WrappedArray.cpp

Parent Directory | Revision Log


Revision 3981 - (show annotations)
Fri Sep 21 02:47:54 2012 UTC (6 years, 11 months ago) by jfenwick
File size: 6585 byte(s)
First pass of updating copyright notices
1
2 /*****************************************************************************
3 *
4 * Copyright (c) 2003-2012 by University of Queensland
5 * http://www.uq.edu.au
6 *
7 * Primary Business: Queensland, Australia
8 * Licensed under the Open Software License version 3.0
9 * http://www.opensource.org/licenses/osl-3.0.php
10 *
11 * Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12 * Development since 2012 by School of Earth Sciences
13 *
14 *****************************************************************************/
15
16 #include <boost/python/tuple.hpp>
17 #include "WrappedArray.h"
18 #include "DataException.h"
19 #if HAVE_NUMPY_H
20 #include <numpy/ndarrayobject.h>
21 #endif
22
23 #include <iostream>
24
25 using namespace escript;
26 using namespace boost::python;
27
28 namespace
29 {
30
31 void checkFeatures(const boost::python::object& obj)
32 {
33 using namespace std;
34 boost::python::object o2;
35 try
36 {
37 /*int len=*/ extract<int>(obj.attr("__len__")());
38 }
39 catch (...)
40 {
41 PyErr_Clear();
42 throw DataException("Object passed to WrappedArray must support __len__");
43 }
44 try
45 {
46 o2=obj.attr("__getitem__");
47 }
48 catch (...)
49 {
50 PyErr_Clear();
51 throw DataException("Object passed to WrappedArray must support __getitem__");
52 }
53 }
54
55 void getObjShape(const boost::python::object& obj, DataTypes::ShapeType& s)
56 {
57 int len=0;
58 try
59 {
60 len=extract<int>(obj.attr("__len__")());
61 }
62 catch(...)
63 {
64 PyErr_Clear(); // tell python the error isn't there anymore
65 return;
66 }
67 if (len<1)
68 {
69 throw DataException("Array filter - no empty components in arrays please.");
70 }
71 s.push_back(len);
72
73 if (s.size()>ESCRIPT_MAX_DATA_RANK)
74 {
75 throw DataException("Array filter - Maximum rank exceeded in array");
76 }
77 getObjShape(obj[0],s);
78 }
79
80 }
81
82 WrappedArray::WrappedArray(const boost::python::object& obj_in)
83 :obj(obj_in)
84 {
85 dat=0;
86 // First we check for scalars
87 try
88 {
89 double v=extract<double>(obj_in);
90 m_scalar=v;
91 rank=0;
92 return;
93 }
94 catch (...)
95 { // so we clear the failure
96 PyErr_Clear();
97 }
98 try
99 {
100 double v=extract<double>(obj_in[make_tuple()]);
101 m_scalar=v;
102 rank=0;
103 return;
104 }
105 catch (...)
106 { // so we clear the failure
107 PyErr_Clear();
108 }
109
110
111 m_scalar=0;
112 checkFeatures(obj_in);
113 getObjShape(obj,shape);
114 rank=shape.size();
115
116 #if HAVE_NUMPY_H
117 // if obj is a numpy array it is much faster to copy the array through the
118 // __array_struct__ interface instead of extracting single values from the
119 // components via getElt(). For this to work we check below that
120 // (1) this is a valid PyArrayInterface instance
121 // (2) the data is stored as a contiguous C array
122 // (3) the data type is suitable (correct type and byte size)
123 try
124 {
125 object o = (extract<object>(obj.attr("__array_struct__")));
126 if (PyCObject_Check(o.ptr()))
127 {
128 PyObject* cobj=(PyObject*)o.ptr();
129 PyArrayInterface* arr=(PyArrayInterface*)PyCObject_AsVoidPtr(cobj);
130 if (arr->two==2 && arr->flags&NPY_IN_ARRAY && arr->flags&NPY_NOTSWAPPED)
131 {
132 std::vector<int> strides;
133 // convert #bytes to #elements
134 for (int i=0; i<arr->nd; i++)
135 {
136 strides.push_back(arr->strides[i]/arr->itemsize);
137 }
138
139 if (arr->typekind == 'f')
140 {
141 if (arr->itemsize==sizeof(double))
142 {
143 convertNumpyArray<double>((const double*)arr->data, strides);
144 }
145 else if (arr->itemsize==sizeof(float))
146 {
147 convertNumpyArray<float>((const float*)arr->data, strides);
148 }
149 }
150 else if (arr->typekind == 'i')
151 {
152 if (arr->itemsize==sizeof(int))
153 {
154 convertNumpyArray<int>((const int*)arr->data, strides);
155 }
156 else if (arr->itemsize==sizeof(long))
157 {
158 convertNumpyArray<long>((const long*)arr->data, strides);
159 }
160 }
161 else if (arr->typekind == 'u')
162 {
163 if (arr->itemsize==sizeof(unsigned))
164 {
165 convertNumpyArray<unsigned>((const unsigned*)arr->data, strides);
166 }
167 else if (arr->itemsize==sizeof(unsigned long))
168 {
169 convertNumpyArray<unsigned long>((const unsigned long*)arr->data, strides);
170 }
171 }
172 }
173 }
174 } catch (...)
175 {
176 PyErr_Clear();
177 }
178 #endif
179 }
180
181 template<typename T>
182 void WrappedArray::convertNumpyArray(const T* array, const std::vector<int>& strides) const
183 {
184 // this method is only called by the constructor above which does the
185 // necessary checks and initialisations
186 int size=DataTypes::noValues(shape);
187 dat=new double[size];
188 switch (rank)
189 {
190 case 1:
191 #pragma omp parallel for
192 for (int i=0;i<shape[0];i++)
193 {
194 dat[i]=array[i*strides[0]];
195 }
196 break;
197 case 2:
198 #pragma omp parallel for
199 for (int i=0;i<shape[0];i++)
200 {
201 for (int j=0;j<shape[1];j++)
202 {
203 dat[DataTypes::getRelIndex(shape,i,j)]=array[i*strides[0]+j*strides[1]];
204 }
205 }
206 break;
207 case 3:
208 #pragma omp parallel for
209 for (int i=0;i<shape[0];i++)
210 {
211 for (int j=0;j<shape[1];j++)
212 {
213 for (int k=0;k<shape[2];k++)
214 {
215 dat[DataTypes::getRelIndex(shape,i,j,k)]=array[i*strides[0]+j*strides[1]+k*strides[2]];
216 }
217 }
218 }
219 break;
220 case 4:
221 #pragma omp parallel for
222 for (int i=0;i<shape[0];i++)
223 {
224 for (int j=0;j<shape[1];j++)
225 {
226 for (int k=0;k<shape[2];k++)
227 {
228 for (int m=0;m<shape[3];m++)
229 {
230 dat[DataTypes::getRelIndex(shape,i,j,k,m)]=array[i*strides[0]+j*strides[1]+k*strides[2]+m*strides[3]];
231 }
232 }
233 }
234 }
235 break;
236 }
237 }
238
239 void WrappedArray::convertArray() const
240 {
241 if ((dat!=0) || (rank<=0) || (rank>4)) // checking illegal rank here to avoid memory issues later
242 { // yes the failure is silent here but not doing the copy
243 return; // will just cause an error to be raised later
244 }
245 int size=DataTypes::noValues(shape);
246 double* tdat=new double[size];
247 switch (rank)
248 {
249 case 1: for (int i=0;i<shape[0];i++)
250 {
251 tdat[i]=getElt(i);
252 }
253 break;
254 case 2: for (int i=0;i<shape[0];i++)
255 {
256 for (int j=0;j<shape[1];j++)
257 {
258 tdat[DataTypes::getRelIndex(shape,i,j)]=getElt(i,j);
259 }
260 }
261 break;
262 case 3: for (int i=0;i<shape[0];i++)
263 {
264 for (int j=0;j<shape[1];j++)
265 {
266 for (int k=0;k<shape[2];k++)
267 {
268 tdat[DataTypes::getRelIndex(shape,i,j,k)]=getElt(i,j,k);
269 }
270 }
271 }
272 break;
273 case 4: for (int i=0;i<shape[0];i++)
274 {
275 for (int j=0;j<shape[1];j++)
276 {
277 for (int k=0;k<shape[2];k++)
278 {
279 for (int m=0;m<shape[3];m++)
280 {
281 tdat[DataTypes::getRelIndex(shape,i,j,k,m)]=getElt(i,j,k,m);
282 }
283 }
284 }
285 }
286 break;
287 default:
288 ; // do nothing
289 // can't happen. We've already checked the bounds above
290 }
291 dat=tdat;
292 }
293
294 WrappedArray::~WrappedArray()
295 {
296 if (dat!=0)
297 {
298 delete dat;
299 }
300 }
301
302

  ViewVC Help
Powered by ViewVC 1.1.26