/[escript]/branches/complex/escriptcore/src/WrappedArray.cpp
ViewVC logotype

Diff of /branches/complex/escriptcore/src/WrappedArray.cpp

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 5869 by jfenwick, Wed Jan 13 03:13:08 2016 UTC revision 5870 by jfenwick, Thu Jan 14 06:50:47 2016 UTC
# Line 82  void getObjShape(const boost::python::ob Line 82  void getObjShape(const boost::python::ob
82  }  }
83    
84  WrappedArray::WrappedArray(const boost::python::object& obj_in)  WrappedArray::WrappedArray(const boost::python::object& obj_in)
85  :obj(obj_in)  :obj(obj_in),iscomplex(false),scalar_r(nan("")),scalar_c(nan(""))
86  {  {
87      dat=0;      dat_r=0;
88        dat_c=0;
89      // First we check for scalars      // First we check for scalars
90      try      try
91      {      {
92         double v=extract<double>(obj_in);         extract<complextype> ec(obj_in);
93         m_scalar=v;         extract<double> er(obj_in);
94           if (er.check())      // check for double first because complex will fail this
95           {
96              scalar_r=er();
97           }
98           else
99           {
100              scalar_c=ec();
101              iscomplex=true;
102            
103           }
104         rank=0;         rank=0;
105         return;         return;
106      }      }
# Line 99  WrappedArray::WrappedArray(const boost:: Line 110  WrappedArray::WrappedArray(const boost::
110      }      }
111      try      try
112      {      {
113         double v=extract<double>(obj_in[make_tuple()]);         const boost::python::object obj_in_t=obj_in[make_tuple()];
114         m_scalar=v;         extract<complextype> ec(obj_in_t);
115           extract<double> er(obj_in_t);
116           if (er.check())
117           {        
118              scalar_r=er();
119            
120           }
121           else
122           {
123              scalar_c=ec();
124              iscomplex=true;
125           }      
126         rank=0;         rank=0;
127         return;         return;
128      }      }
# Line 110  WrappedArray::WrappedArray(const boost:: Line 132  WrappedArray::WrappedArray(const boost::
132      }      }
133    
134    
135      m_scalar=0;      scalar_c=0;
136        scalar_r=0;
137      checkFeatures(obj_in);      checkFeatures(obj_in);
138      getObjShape(obj,shape);      getObjShape(obj,shape);
139      rank=shape.size();      rank=shape.size();
# Line 175  WrappedArray::WrappedArray(const boost:: Line 198  WrappedArray::WrappedArray(const boost::
198                          convertNumpyArray<unsigned long>((const unsigned long*)arr->data, strides);                          convertNumpyArray<unsigned long>((const unsigned long*)arr->data, strides);
199                      }                      }
200                  }                  }
201                    else if (arr->typekind == 'c')
202                    {
203                        if (arr->itemsize==sizeof(complextype))
204                        {
205                            convertNumpyArrayC<complextype>((const complextype*)arr->data, strides);
206                            iscomplex=true;
207                        }
208                        // not accomodating other types of complex values
209                    }              
210              }              }
211          }          }
212      } catch (...)      } catch (...)
# Line 184  WrappedArray::WrappedArray(const boost:: Line 216  WrappedArray::WrappedArray(const boost::
216  #endif  #endif
217  }  }
218    
219    
220    template<typename T>
221    void WrappedArray::convertNumpyArrayC(const T* array, const std::vector<int>& strides) const
222    {
223        // this method is only called by the constructor above which does the
224        // necessary checks and initialisations
225        int size=DataTypes::noValues(shape);
226        dat_c=new complextype[size];
227        switch (rank)
228        {
229            case 1:
230    #pragma omp parallel for
231                for (int i=0;i<shape[0];i++)
232                {
233                    dat_c[i]=array[i*strides[0]];
234                }
235            break;
236            case 2:
237    #pragma omp parallel for
238                for (int i=0;i<shape[0];i++)
239                {
240                    for (int j=0;j<shape[1];j++)
241                    {
242                        dat_c[DataTypes::getRelIndex(shape,i,j)]=array[i*strides[0]+j*strides[1]];
243                    }
244                }
245            break;
246            case 3:
247    #pragma omp parallel for
248                for (int i=0;i<shape[0];i++)
249                {
250                    for (int j=0;j<shape[1];j++)
251                    {
252                        for (int k=0;k<shape[2];k++)
253                        {
254                            dat_c[DataTypes::getRelIndex(shape,i,j,k)]=array[i*strides[0]+j*strides[1]+k*strides[2]];
255                        }
256                    }
257                }
258            break;
259            case 4:
260    #pragma omp parallel for
261                for (int i=0;i<shape[0];i++)
262                {
263                    for (int j=0;j<shape[1];j++)
264                    {
265                        for (int k=0;k<shape[2];k++)
266                        {
267                            for (int m=0;m<shape[3];m++)
268                            {
269                                dat_c[DataTypes::getRelIndex(shape,i,j,k,m)]=array[i*strides[0]+j*strides[1]+k*strides[2]+m*strides[3]];
270                            }
271                        }
272                    }
273                }
274            break;
275        }
276    }
277    
278    
279  template<typename T>  template<typename T>
280  void WrappedArray::convertNumpyArray(const T* array, const std::vector<int>& strides) const  void WrappedArray::convertNumpyArray(const T* array, const std::vector<int>& strides) const
281  {  {
282      // this method is only called by the constructor above which does the      // this method is only called by the constructor above which does the
283      // necessary checks and initialisations      // necessary checks and initialisations
284      int size=DataTypes::noValues(shape);      int size=DataTypes::noValues(shape);
285      dat=new double[size];      dat_r=new double[size];
286      switch (rank)      switch (rank)
287      {      {
288          case 1:          case 1:
289  #pragma omp parallel for  #pragma omp parallel for
290              for (int i=0;i<shape[0];i++)              for (int i=0;i<shape[0];i++)
291              {              {
292                  dat[i]=array[i*strides[0]];                  dat_r[i]=array[i*strides[0]];
293              }              }
294          break;          break;
295          case 2:          case 2:
# Line 206  void WrappedArray::convertNumpyArray(con Line 298  void WrappedArray::convertNumpyArray(con
298              {              {
299                  for (int j=0;j<shape[1];j++)                  for (int j=0;j<shape[1];j++)
300                  {                  {
301                      dat[DataTypes::getRelIndex(shape,i,j)]=array[i*strides[0]+j*strides[1]];                      dat_r[DataTypes::getRelIndex(shape,i,j)]=array[i*strides[0]+j*strides[1]];
302                  }                  }
303              }              }
304          break;          break;
# Line 218  void WrappedArray::convertNumpyArray(con Line 310  void WrappedArray::convertNumpyArray(con
310                  {                  {
311                      for (int k=0;k<shape[2];k++)                      for (int k=0;k<shape[2];k++)
312                      {                      {
313                          dat[DataTypes::getRelIndex(shape,i,j,k)]=array[i*strides[0]+j*strides[1]+k*strides[2]];                          dat_r[DataTypes::getRelIndex(shape,i,j,k)]=array[i*strides[0]+j*strides[1]+k*strides[2]];
314                      }                      }
315                  }                  }
316              }              }
# Line 233  void WrappedArray::convertNumpyArray(con Line 325  void WrappedArray::convertNumpyArray(con
325                      {                      {
326                          for (int m=0;m<shape[3];m++)                          for (int m=0;m<shape[3];m++)
327                          {                          {
328                              dat[DataTypes::getRelIndex(shape,i,j,k,m)]=array[i*strides[0]+j*strides[1]+k*strides[2]+m*strides[3]];                              dat_r[DataTypes::getRelIndex(shape,i,j,k,m)]=array[i*strides[0]+j*strides[1]+k*strides[2]+m*strides[3]];
329                          }                          }
330                      }                      }
331                  }                  }
# Line 242  void WrappedArray::convertNumpyArray(con Line 334  void WrappedArray::convertNumpyArray(con
334      }      }
335  }  }
336    
337  void WrappedArray::convertArray() const  
338    void WrappedArray::convertArrayR() const
339  {  {
340      if ((dat!=0) || (rank<=0) || (rank>4))  // checking illegal rank here to avoid memory issues later      if ((converted) || (rank<=0) || (rank>4))   // checking illegal rank here to avoid memory issues later
341      {                   // yes the failure is silent here but not doing the copy      {                   // yes the failure is silent here but not doing the copy
342          return;             // will just cause an error to be raised later          return;             // will just cause an error to be raised later
343      }      }
# Line 294  void WrappedArray::convertArray() const Line 387  void WrappedArray::convertArray() const
387          ;  // do nothing          ;  // do nothing
388          // can't happen. We've already checked the bounds above          // can't happen. We've already checked the bounds above
389      }      }
390      dat=tdat;      dat_r=tdat;    
391  }      converted=true;
392    }  
393    
394  WrappedArray::~WrappedArray()  
395    void WrappedArray::convertArrayC() const
396  {  {
397      if (dat!=0)      if ((converted) || (rank<=0) || (rank>4))   // checking illegal rank here to avoid memory issues later
398        {                   // yes the failure is silent here but not doing the copy
399            return;             // will just cause an error to be raised later
400        }
401        int size=DataTypes::noValues(shape);
402        complextype* tdat=new complextype[size];
403        switch (rank)
404      {      {
405          delete[] dat;      case 1: for (int i=0;i<shape[0];i++)
406            {
407                tdat[i]=getElt(i);
408            }
409            break;
410        case 2: for (int i=0;i<shape[0];i++)
411            {
412                for (int j=0;j<shape[1];j++)
413                {
414                tdat[DataTypes::getRelIndex(shape,i,j)]=getElt(i,j);
415                }
416            }
417            break;
418        case 3: for (int i=0;i<shape[0];i++)
419            {
420                for (int j=0;j<shape[1];j++)
421                {
422                for (int k=0;k<shape[2];k++)
423                {
424                    tdat[DataTypes::getRelIndex(shape,i,j,k)]=getElt(i,j,k);
425                }
426                }
427            }
428            break;
429        case 4: for (int i=0;i<shape[0];i++)
430            {
431                for (int j=0;j<shape[1];j++)
432                {
433                for (int k=0;k<shape[2];k++)
434                {
435                    for (int m=0;m<shape[3];m++)
436                    {
437                        tdat[DataTypes::getRelIndex(shape,i,j,k,m)]=getElt(i,j,k,m);
438                    }
439                }
440                }
441            }
442            break;
443        default:
444            ;  // do nothing
445            // can't happen. We've already checked the bounds above
446      }      }
447        dat_c=tdat;    
448        converted=true;
449    }  
450    
451    
452    void WrappedArray::convertArray() const
453    {
454        if (iscomplex)
455        {
456        convertArrayC();
457        }
458        else
459        {
460        convertArrayR();
461        }
462    }
463    
464    WrappedArray::~WrappedArray()
465    {
466        if (dat_r!=0)
467        {
468        delete[] dat_r;
469        }
470        if (dat_c!=0)
471        {
472        delete[] dat_c;
473        }
474  }  }
475    
476    

Legend:
Removed from v.5869  
changed lines
  Added in v.5870

  ViewVC Help
Powered by ViewVC 1.1.26