/[escript]/branches/trilinos_from_5897/escriptcore/src/LocalOps.h
ViewVC logotype

Diff of /branches/trilinos_from_5897/escriptcore/src/LocalOps.h

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 5962 by caltinay, Fri Feb 5 04:12:40 2016 UTC revision 5963 by caltinay, Mon Feb 22 06:59:27 2016 UTC
# Line 18  Line 18 
18  #if !defined escript_LocalOps_H  #if !defined escript_LocalOps_H
19  #define escript_LocalOps_H  #define escript_LocalOps_H
20  #include <cmath>  #include <cmath>
21    #include <complex>
22    #include "UnaryFuncs.h"
23    #include "DataTypes.h"
24    #include "DataException.h"
25  #ifndef M_PI  #ifndef M_PI
26  #   define M_PI           3.14159265358979323846  /* pi */  #   define M_PI           3.14159265358979323846  /* pi */
27  #endif  #endif
# Line 33  For operations on DataVector see DataMat Line 37  For operations on DataVector see DataMat
37    
38  namespace escript {  namespace escript {
39    
40    
41    typedef enum
42    {
43    SINF,
44    COSF,
45    TANF,
46    ASINF,
47    ACOSF,
48    ATANF,
49    SINHF,
50    COSHF,
51    TANHF,
52    ERFF,
53    ASINHF,
54    ACOSHF,
55    ATANHF,
56    LOG10F,
57    LOGF,
58    SIGNF,
59    ABSF,
60    EXPF,
61    SQRTF,
62    POWF,
63    PLUSF,
64    MINUSF,
65    MULTIPLIESF,
66    DIVIDESF,
67    LESSF,
68    GREATERF,
69    GREATER_EQUALF,
70    LESS_EQUALF,
71    EQZEROF,
72    NEQZEROF,
73    GTZEROF,
74    GEZEROF,
75    LTZEROF,
76    LEZEROF,
77    CONJF,
78    REALF,
79    IMAGF,
80    INVF
81    } ESFunction;
82    
83    bool always_real(ESFunction operation);
84    
85  /**  /**
86  \brief acts as a wrapper to isnan.  \brief acts as a wrapper to isnan.
87  \warning if compiler does not support FP_NAN this function will always return false.  \warning if compiler does not support FP_NAN this function will always return false.
# Line 40  namespace escript { Line 89  namespace escript {
89  inline  inline
90  bool nancheck(double d)  bool nancheck(double d)
91  {  {
92          // Q: so why not just test d!=d?                  // Q: so why not just test d!=d?
93          // A: Coz it doesn't always work [I've checked].                  // A: Coz it doesn't always work [I've checked].
94          // One theory is that the optimizer skips the test.                  // One theory is that the optimizer skips the test.
95      return std::isnan(d);   // isNan should be a function in C++ land      return std::isnan(d);       // isNan should be a function in C++ land
96  }  }
97    
98  /**  /**
# Line 88  void eigenvalues1(const double A00,doubl Line 137  void eigenvalues1(const double A00,doubl
137  inline  inline
138  void eigenvalues2(const double A00,const double A01,const double A11,  void eigenvalues2(const double A00,const double A01,const double A11,
139                   double* ev0, double* ev1) {                   double* ev0, double* ev1) {
140        const register double trA=(A00+A11)/2.;        const double trA=(A00+A11)/2.;
141        const register double A_00=A00-trA;        const double A_00=A00-trA;
142        const register double A_11=A11-trA;        const double A_11=A11-trA;
143        const register double s=sqrt(A01*A01-A_00*A_11);        const double s=sqrt(A01*A01-A_00*A_11);
144        *ev0=trA-s;        *ev0=trA-s;
145        *ev1=trA+s;        *ev1=trA+s;
146  }  }
# Line 115  void eigenvalues3(const double A00, cons Line 164  void eigenvalues3(const double A00, cons
164                                                       const double A22,                                                       const double A22,
165                   double* ev0, double* ev1,double* ev2) {                   double* ev0, double* ev1,double* ev2) {
166    
167        const register double trA=(A00+A11+A22)/3.;        const double trA=(A00+A11+A22)/3.;
168        const register double A_00=A00-trA;        const double A_00=A00-trA;
169        const register double A_11=A11-trA;        const double A_11=A11-trA;
170        const register double A_22=A22-trA;        const double A_22=A22-trA;
171        const register double A01_2=A01*A01;        const double A01_2=A01*A01;
172        const register double A02_2=A02*A02;        const double A02_2=A02*A02;
173        const register double A12_2=A12*A12;        const double A12_2=A12*A12;
174        const register double p=A02_2+A12_2+A01_2+(A_00*A_00+A_11*A_11+A_22*A_22)/2.;        const double p=A02_2+A12_2+A01_2+(A_00*A_00+A_11*A_11+A_22*A_22)/2.;
175        if (p<=0.) {        if (p<=0.) {
176           *ev2=trA;           *ev2=trA;
177           *ev1=trA;           *ev1=trA;
178           *ev0=trA;           *ev0=trA;
179    
180        } else {        } else {
181           const register double q=(A02_2*A_11+A12_2*A_00+A01_2*A_22)-(A_00*A_11*A_22+2*A01*A12*A02);           const double q=(A02_2*A_11+A12_2*A_00+A01_2*A_22)-(A_00*A_11*A_22+2*A01*A12*A02);
182           const register double sq_p=sqrt(p/3.);           const double sq_p=sqrt(p/3.);
183           register double z=-q/(2*pow(sq_p,3));           double z=-q/(2*pow(sq_p,3));
184           if (z<-1.) {           if (z<-1.) {
185              z=-1.;              z=-1.;
186           } else if (z>1.) {           } else if (z>1.) {
187              z=1.;              z=1.;
188           }           }
189           const register double alpha_3=acos(z)/3.;           const double alpha_3=acos(z)/3.;
190           *ev2=trA+2.*sq_p*cos(alpha_3);           *ev2=trA+2.*sq_p*cos(alpha_3);
191           *ev1=trA-2.*sq_p*cos(alpha_3+M_PI/3.);           *ev1=trA-2.*sq_p*cos(alpha_3+M_PI/3.);
192           *ev0=trA-2.*sq_p*cos(alpha_3-M_PI/3.);           *ev0=trA-2.*sq_p*cos(alpha_3-M_PI/3.);
# Line 174  inline Line 223  inline
223  void  vectorInKernel2(const double A00,const double A10,const double A01,const double A11,  void  vectorInKernel2(const double A00,const double A10,const double A01,const double A11,
224                        double* V0, double*V1)                        double* V0, double*V1)
225  {  {
226        register double absA00=fabs(A00);        double absA00=fabs(A00);
227        register double absA10=fabs(A10);        double absA10=fabs(A10);
228        register double absA01=fabs(A01);        double absA01=fabs(A01);
229        register double absA11=fabs(A11);        double absA11=fabs(A11);
230        register double m=absA11>absA10 ? absA11 : absA10;        double m=absA11>absA10 ? absA11 : absA10;
231        if (absA00>m || absA01>m) {        if (absA00>m || absA01>m) {
232           *V0=-A01;           *V0=-A01;
233           *V1=A00;           *V1=A00;
# Line 217  void  vectorInKernel3__nonZeroA00(const Line 266  void  vectorInKernel3__nonZeroA00(const
266                                  double* V0,double* V1,double* V2)                                  double* V0,double* V1,double* V2)
267  {  {
268      double TEMP0,TEMP1;      double TEMP0,TEMP1;
269      register const double I00=1./A00;      const double I00=1./A00;
270      register const double IA10=I00*A10;      const double IA10=I00*A10;
271      register const double IA20=I00*A20;      const double IA20=I00*A20;
272      vectorInKernel2(A11-IA10*A01,A12-IA10*A02,      vectorInKernel2(A11-IA10*A01,A12-IA10*A02,
273                      A21-IA20*A01,A22-IA20*A02,&TEMP0,&TEMP1);                      A21-IA20*A01,A22-IA20*A02,&TEMP0,&TEMP1);
274      *V0=-(A10*TEMP0+A20*TEMP1);      *V0=-(A10*TEMP0+A20*TEMP1);
# Line 252  void  eigenvalues_and_eigenvectors2(cons Line 301  void  eigenvalues_and_eigenvectors2(cons
301  {  {
302       double TEMP0,TEMP1;       double TEMP0,TEMP1;
303       eigenvalues2(A00,A01,A11,ev0,ev1);       eigenvalues2(A00,A01,A11,ev0,ev1);
304       const register double absev0=fabs(*ev0);       const double absev0=fabs(*ev0);
305       const register double absev1=fabs(*ev1);       const double absev1=fabs(*ev1);
306       register double max_ev=absev0>absev1 ? absev0 : absev1;       double max_ev=absev0>absev1 ? absev0 : absev1;
307       if (fabs((*ev0)-(*ev1))<tol*max_ev) {       if (fabs((*ev0)-(*ev1))<tol*max_ev) {
308          *V00=1.;          *V00=1.;
309          *V10=0.;          *V10=0.;
# Line 262  void  eigenvalues_and_eigenvectors2(cons Line 311  void  eigenvalues_and_eigenvectors2(cons
311          *V11=1.;          *V11=1.;
312       } else {       } else {
313          vectorInKernel2(A00-(*ev0),A01,A01,A11-(*ev0),&TEMP0,&TEMP1);          vectorInKernel2(A00-(*ev0),A01,A01,A11-(*ev0),&TEMP0,&TEMP1);
314          const register double scale=1./sqrt(TEMP0*TEMP0+TEMP1*TEMP1);          const double scale=1./sqrt(TEMP0*TEMP0+TEMP1*TEMP1);
315          if (TEMP0<0.) {          if (TEMP0<0.) {
316              *V00=-TEMP0*scale;              *V00=-TEMP0*scale;
317              *V10=-TEMP1*scale;              *V10=-TEMP1*scale;
# Line 302  void  eigenvalues_and_eigenvectors2(cons Line 351  void  eigenvalues_and_eigenvectors2(cons
351  inline  inline
352  void  normalizeVector3(double* V0,double* V1,double* V2)  void  normalizeVector3(double* V0,double* V1,double* V2)
353  {  {
354      register double s;      double s;
355      if (*V0>0) {      if (*V0>0) {
356          s=1./sqrt((*V0)*(*V0)+(*V1)*(*V1)+(*V2)*(*V2));          s=1./sqrt((*V0)*(*V0)+(*V1)*(*V1)+(*V2)*(*V2));
357          *V0*=s;          *V0*=s;
# Line 362  void  eigenvalues_and_eigenvectors3(cons Line 411  void  eigenvalues_and_eigenvectors3(cons
411                                      double* V02, double* V12, double* V22,                                      double* V02, double* V12, double* V22,
412                                      const double tol)                                      const double tol)
413  {  {
414        register const double absA01=fabs(A01);        const double absA01=fabs(A01);
415        register const double absA02=fabs(A02);        const double absA02=fabs(A02);
416        register const double m=absA01>absA02 ? absA01 : absA02;        const double m=absA01>absA02 ? absA01 : absA02;
417        if (m<=0) {        if (m<=0) {
418          double TEMP_V00,TEMP_V10,TEMP_V01,TEMP_V11,TEMP_EV0,TEMP_EV1;          double TEMP_V00,TEMP_V10,TEMP_V01,TEMP_V11,TEMP_EV0,TEMP_EV1;
419          eigenvalues_and_eigenvectors2(A11,A12,A22,          eigenvalues_and_eigenvectors2(A11,A12,A22,
# Line 473  void matrix_matrix_product(const int SL, Line 522  void matrix_matrix_product(const int SL,
522        for (int j=0; j<SR; j++) {        for (int j=0; j<SR; j++) {
523          double sum = 0.0;          double sum = 0.0;
524          for (int l=0; l<SM; l++) {          for (int l=0; l<SM; l++) {
525        sum += A[i+SL*l] * B[l+SM*j];            sum += A[i+SL*l] * B[l+SM*j];
526          }          }
527          C[i+SL*j] = sum;          C[i+SL*j] = sum;
528        }        }
# Line 484  void matrix_matrix_product(const int SL, Line 533  void matrix_matrix_product(const int SL,
533        for (int j=0; j<SR; j++) {        for (int j=0; j<SR; j++) {
534          double sum = 0.0;          double sum = 0.0;
535          for (int l=0; l<SM; l++) {          for (int l=0; l<SM; l++) {
536        sum += A[i*SM+l] * B[l+SM*j];            sum += A[i*SM+l] * B[l+SM*j];
537          }          }
538          C[i+SL*j] = sum;          C[i+SL*j] = sum;
539        }        }
# Line 495  void matrix_matrix_product(const int SL, Line 544  void matrix_matrix_product(const int SL,
544        for (int j=0; j<SR; j++) {        for (int j=0; j<SR; j++) {
545          double sum = 0.0;          double sum = 0.0;
546          for (int l=0; l<SM; l++) {          for (int l=0; l<SM; l++) {
547        sum += A[i+SL*l] * B[l*SR+j];            sum += A[i+SL*l] * B[l*SR+j];
548          }          }
549          C[i+SL*j] = sum;          C[i+SL*j] = sum;
550        }        }
# Line 505  void matrix_matrix_product(const int SL, Line 554  void matrix_matrix_product(const int SL,
554    
555  template <typename UnaryFunction>  template <typename UnaryFunction>
556  inline void tensor_unary_operation(const int size,  inline void tensor_unary_operation(const int size,
557                   const double *arg1,                               const double *arg1,
558                   double * argRes,                               double * argRes,
559                   UnaryFunction operation)                               UnaryFunction operation)
560  {  {
561    for (int i = 0; i < size; ++i) {    for (int i = 0; i < size; ++i) {
562      argRes[i] = operation(arg1[i]);      argRes[i] = operation(arg1[i]);
# Line 515  inline void tensor_unary_operation(const Line 564  inline void tensor_unary_operation(const
564    return;    return;
565  }  }
566    
567  template <typename BinaryFunction>  // ----------------------
568    
569    
570    // -------------------------------------
571    
572    template <typename BinaryFunction, typename T, typename U, typename V>
573  inline void tensor_binary_operation(const int size,  inline void tensor_binary_operation(const int size,
574                   const double *arg1,                               const T *arg1,
575                   const double *arg2,                               const U *arg2,
576                   double * argRes,                               V * argRes,
577                   BinaryFunction operation)                               BinaryFunction operation)
578  {  {
579    for (int i = 0; i < size; ++i) {    for (int i = 0; i < size; ++i) {
580      argRes[i] = operation(arg1[i], arg2[i]);      argRes[i] = operation(arg1[i], arg2[i]);
# Line 528  inline void tensor_binary_operation(cons Line 582  inline void tensor_binary_operation(cons
582    return;    return;
583  }  }
584    
585  template <typename BinaryFunction>  template <typename BinaryFunction, typename T, typename U, typename V>
586  inline void tensor_binary_operation(const int size,  inline void tensor_binary_operation(const int size,
587                   double arg1,                               T arg1,
588                   const double *arg2,                               const U *arg2,
589                   double *argRes,                               V *argRes,
590                   BinaryFunction operation)                               BinaryFunction operation)
591  {  {
592    for (int i = 0; i < size; ++i) {    for (int i = 0; i < size; ++i) {
593      argRes[i] = operation(arg1, arg2[i]);      argRes[i] = operation(arg1, arg2[i]);
# Line 541  inline void tensor_binary_operation(cons Line 595  inline void tensor_binary_operation(cons
595    return;    return;
596  }  }
597    
598  template <typename BinaryFunction>  template <typename BinaryFunction, typename T, typename U, typename V>
599  inline void tensor_binary_operation(const int size,  inline void tensor_binary_operation(const int size,
600                   const double *arg1,                               const T *arg1,
601                   double arg2,                               U arg2,
602                   double *argRes,                               V *argRes,
603                   BinaryFunction operation)                               BinaryFunction operation)
604  {  {
605    for (int i = 0; i < size; ++i) {    for (int i = 0; i < size; ++i) {
606      argRes[i] = operation(arg1[i], arg2);      argRes[i] = operation(arg1[i], arg2);
# Line 554  inline void tensor_binary_operation(cons Line 608  inline void tensor_binary_operation(cons
608    return;    return;
609  }  }
610    
611    // following the form of negate from <functional>
612    template <typename T>
613    struct sin_func
614    {
615        T operator() (const T& x) const {return sin(x);}
616        typedef T argument_type;
617        typedef T result_type;
618    };
619    
620    template <typename T>
621    struct cos_func
622    {
623        T operator() (const T& x) const {return cos(x);}
624        typedef T argument_type;
625        typedef T result_type;
626    };
627    
628    template <typename T>
629    struct tan_func
630    {
631        T operator() (const T& x) const {return tan(x);}
632        typedef T argument_type;
633        typedef T result_type;
634    };
635    
636    template <typename T>
637    struct asin_func
638    {
639        T operator() (const T& x) const {return asin(x);}
640        typedef T argument_type;
641        typedef T result_type;
642    };
643    
644    template <typename T>
645    struct acos_func
646    {
647        T operator() (const T& x) const {return acos(x);}
648        typedef T argument_type;
649        typedef T result_type;
650    };
651    
652    template <typename T>
653    struct atan_func
654    {
655        T operator() (const T& x) const {return atan(x);}
656        typedef T argument_type;
657        typedef T result_type;
658    };
659    
660    template <typename T>
661    struct sinh_func
662    {
663        T operator() (const T& x) const {return sinh(x);}
664        typedef T argument_type;
665        typedef T result_type;
666    };
667    
668    template <typename T>
669    struct cosh_func
670    {
671        T operator() (const T& x) const {return cosh(x);}
672        typedef T argument_type;
673        typedef T result_type;
674    };
675    
676    
677    template <typename T>
678    struct tanh_func
679    {
680        T operator() (const T& x) const {return tanh(x);}
681        typedef T argument_type;
682        typedef T result_type;
683    };
684    
685    #if defined (_WIN32) && !defined(__INTEL_COMPILER)
686    #else
687    template <typename T>
688    struct erf_func
689    {
690        T operator() (const T& x) const {return ::erf(x);}
691        typedef T argument_type;
692        typedef T result_type;
693    };
694    
695    template <>
696    struct erf_func<escript::DataTypes::cplx_t>             // dummy instantiation
697    {
698        DataTypes::cplx_t operator() (const DataTypes::cplx_t& x) const {return makeNaN();}
699        typedef DataTypes::cplx_t argument_type;
700        typedef DataTypes::cplx_t result_type;
701    };
702    
703    #endif
704        
705    template <typename T>
706    struct asinh_func
707    {
708        T operator() (const T& x) const
709        {
710    #if defined (_WIN32) && !defined(__INTEL_COMPILER)
711        return escript::asinh_substitute(x);
712    #else
713        return asinh(x);
714    #endif      
715        }
716        typedef T argument_type;
717        typedef T result_type;
718    };
719    
720    template <typename T>
721    struct acosh_func
722    {
723        T operator() (const T& x) const
724        {
725    #if defined (_WIN32) && !defined(__INTEL_COMPILER)
726        return escript::acosh_substitute(x);
727    #else
728        return acosh(x);
729    #endif
730        }
731        typedef T argument_type;
732        typedef T result_type;
733    };
734    
735    template <typename T>
736    struct atanh_func
737    {
738        T operator() (const T& x) const
739        {
740    #if defined (_WIN32) && !defined(__INTEL_COMPILER)
741        return escript::atanh_substitute(x);
742    #else
743        return atanh(x);
744    #endif
745        }    
746        typedef T argument_type;
747        typedef T result_type;
748    };
749    
750    template <typename T>
751    struct log10_func
752    {
753        T operator() (const T& x) const {return log10(x);}
754        typedef T argument_type;
755        typedef T result_type;
756    };
757    
758    template <typename T>
759    struct log_func
760    {
761        T operator() (const T& x) const {return log(x);}
762        typedef T argument_type;
763        typedef T result_type;
764    };
765    
766    template <typename T>
767    struct sign_func
768    {
769        T operator() (const T& x) const {return escript::fsign(x);}
770        typedef T argument_type;
771        typedef T result_type;
772    };
773    
774    template <>
775    struct sign_func<DataTypes::cplx_t>     // dummy instantiation
776    {
777        DataTypes::cplx_t operator() (const DataTypes::cplx_t& x) const {return makeNaN();}
778        typedef DataTypes::cplx_t argument_type;
779        typedef DataTypes::cplx_t result_type;
780    };
781    
782    
783    
784    template <typename T>
785    struct abs_func
786    {
787        T operator() (const T& x) const {return fabs(x);}
788        typedef T argument_type;
789        typedef T result_type;
790    };
791    
792    template <typename T>
793    struct exp_func
794    {
795        T operator() (const T& x) const {return exp(x);}
796        typedef T argument_type;
797        typedef T result_type;
798    };
799    
800    template <typename T>
801    struct sqrt_func
802    {
803        T operator() (const T& x) const {return sqrt(x);}
804        typedef T argument_type;
805        typedef T result_type;
806    };
807        
808    // following the form of plus from <functional>
809    template <typename T, typename U, typename V>
810    struct pow_func
811    {
812        V operator() (const T& x, const U& y) const {return pow(static_cast<V>(x),static_cast<V>(y));}
813        typedef T first_argument_type;
814        typedef U second_argument_type;
815        typedef V result_type;
816    };
817    
818    // following the form of plus from <functional>
819    template <typename T, typename U, typename V>
820    struct plus_func
821    {
822        V operator() (const T& x, const U& y) const {return x+y;}
823        typedef T first_argument_type;
824        typedef U second_argument_type;
825        typedef V result_type;
826    };
827    
828    template <typename T, typename U, typename V>
829    struct minus_func
830    {
831        V operator() (const T& x, const U& y) const {return x-y;}
832        typedef T first_argument_type;
833        typedef U second_argument_type;
834        typedef V result_type;
835    };
836    
837    template <typename T, typename U, typename V>
838    struct multiplies_func
839    {
840        V operator() (const T& x, const U& y) const {return x*y;}
841        typedef T first_argument_type;
842        typedef U second_argument_type;
843        typedef V result_type;
844    };
845    
846    template <typename T, typename U, typename V>
847    struct divides_func
848    {
849        V operator() (const T& x, const U& y) const {return x/y;}
850        typedef T first_argument_type;
851        typedef U second_argument_type;
852        typedef V result_type;
853    };
854    
855    
856    // using this instead of ::less because that returns bool and we need a result type of T
857    template <typename T>
858    struct less_func
859    {
860        T operator() (const T& x, const T& y) const {return x<y;}
861        typedef T first_argument_type;
862        typedef T second_argument_type;
863        typedef T result_type;
864    };
865    
866    // using this instead of ::less because that returns bool and we need a result type of T
867    template <typename T>
868    struct greater_func
869    {
870        T operator() (const T& x, const T& y) const {return x>y;}
871        typedef T first_argument_type;
872        typedef T second_argument_type;
873        typedef T result_type;
874    };
875    
876    template <typename T>
877    struct greater_equal_func
878    {
879        T operator() (const T& x, const T& y) const {return x>=y;}
880        typedef T first_argument_type;
881        typedef T second_argument_type;
882        typedef T result_type;
883    };
884    
885    template <typename T>
886    struct less_equal_func
887    {
888        T operator() (const T& x, const T& y) const {return x<=y;}
889        typedef T first_argument_type;
890        typedef T second_argument_type;
891        typedef T result_type;
892    };
893    
894    template <typename T>
895    struct gtzero_func
896    {
897        T operator() (const T& x) const {return x>0;}
898        typedef T first_argument_type;
899        typedef T result_type;
900    };
901    
902    template <>
903    struct gtzero_func<DataTypes::cplx_t>           // to keep the templater happy
904    {
905        DataTypes::cplx_t operator() (const DataTypes::cplx_t& x) const {return makeNaN();}
906        typedef DataTypes::cplx_t first_argument_type;
907        typedef DataTypes::cplx_t result_type;
908    };
909    
910    
911    
912    template <typename T>
913    struct gezero_func
914    {
915        T operator() (const T& x) const {return x>=0;}
916        typedef T first_argument_type;
917        typedef T result_type;
918    };
919    
920    template <>
921    struct gezero_func<DataTypes::cplx_t>           // to keep the templater happy
922    {
923        DataTypes::cplx_t operator() (const DataTypes::cplx_t& x) const {return makeNaN();}
924        typedef DataTypes::cplx_t first_argument_type;
925        typedef DataTypes::cplx_t result_type;
926    };
927    
928    
929    template <typename T>
930    struct ltzero_func
931    {
932        T operator() (const T& x) const {return x<0;}
933        typedef T first_argument_type;
934        typedef T result_type;
935    };
936    
937    template <>
938    struct ltzero_func<DataTypes::cplx_t>           // to keep the templater happy
939    {
940        DataTypes::cplx_t operator() (const DataTypes::cplx_t& x) const {return makeNaN();}
941        typedef DataTypes::cplx_t first_argument_type;
942        typedef DataTypes::cplx_t result_type;
943    };
944    
945    
946    
947    template <typename T>
948    struct lezero_func
949    {
950        T operator() (const T& x) const {return x<=0;}
951        typedef T first_argument_type;
952        typedef T result_type;
953    };
954    
955    template <>
956    struct lezero_func<DataTypes::cplx_t>           // to keep the templater happy
957    {
958        DataTypes::cplx_t operator() (const DataTypes::cplx_t& x) const {return makeNaN();}
959        typedef DataTypes::cplx_t first_argument_type;
960        typedef DataTypes::cplx_t result_type;
961    };
962    
963    
964    template <class IN, typename OUT, class UnaryFunction>
965    inline void tensor_unary_operation_helper(const size_t size,
966                                 const IN *arg1,
967                                 OUT * argRes,
968                                 UnaryFunction operation)
969    {
970    
971      for (int i = 0; i < size; ++i) {
972        argRes[i] = operation(arg1[i]);
973      }
974    }
975    
976    
977    // deals with unary operations which return real, regardless of
978    // their input type
979    template <class IN>
980    inline void tensor_unary_array_operation_real(const size_t size,
981                                 const IN *arg1,
982                                 DataTypes::real_t * argRes,
983                                 escript::ESFunction operation,
984                                 DataTypes::real_t tol=0)
985    {
986       switch (operation)
987       {
988         case REALF:
989              for (int i = 0; i < size; ++i) {
990                  argRes[i] = std::real(arg1[i]);
991              }
992              break;          
993         case IMAGF:
994              for (int i = 0; i < size; ++i) {
995                  argRes[i] = std::imag(arg1[i]);
996              }
997              break;  
998        case EQZEROF:  
999              for (size_t i = 0; i < size; ++i) {
1000                  argRes[i] = (fabs(arg1[i])<=tol);
1001              }
1002              break;
1003        case NEQZEROF:
1004              for (size_t i = 0; i < size; ++i) {
1005                  argRes[i] = (fabs(arg1[i])>tol);
1006              }
1007              break;          
1008         default:
1009              throw DataException("Unsupported unary operation");      
1010       }  
1011    }
1012    
1013    
1014    
1015    // In most cases, IN and OUT will be the same
1016    // but not ruling out putting Re() and Im()
1017    // through this
1018    template <class IN, typename OUT>
1019    inline void tensor_unary_array_operation(const size_t size,
1020                                 const IN *arg1,
1021                                 OUT * argRes,
1022                                 escript::ESFunction operation,
1023                                 DataTypes::real_t tol=0)
1024    {
1025      switch (operation)
1026      {
1027        case SINF: tensor_unary_operation_helper(size, arg1, argRes, sin_func<IN>()); break;
1028        case COSF: tensor_unary_operation_helper(size, arg1, argRes, cos_func<IN>()); break;
1029        case TANF: tensor_unary_operation_helper(size, arg1, argRes, tan_func<IN>()); break;
1030        case ASINF: tensor_unary_operation_helper(size, arg1, argRes, asin_func<IN>()); break;
1031        case ACOSF: tensor_unary_operation_helper(size, arg1, argRes, acos_func<IN>()); break;
1032        case ATANF: tensor_unary_operation_helper(size, arg1, argRes, atan_func<IN>()); break;
1033        case SINHF: tensor_unary_operation_helper(size, arg1, argRes, sinh_func<IN>()); break;
1034        case COSHF: tensor_unary_operation_helper(size, arg1, argRes, cosh_func<IN>()); break;
1035        case TANHF: tensor_unary_operation_helper(size, arg1, argRes, tanh_func<IN>()); break;
1036        case ERFF: tensor_unary_operation_helper(size, arg1, argRes, erf_func<IN>()); break;
1037        case ASINHF: tensor_unary_operation_helper(size, arg1, argRes, asinh_func<IN>()); break;
1038        case ACOSHF: tensor_unary_operation_helper(size, arg1, argRes, acosh_func<IN>()); break;
1039        case ATANHF: tensor_unary_operation_helper(size, arg1, argRes, atanh_func<IN>()); break;
1040        case LOG10F: tensor_unary_operation_helper(size, arg1, argRes, log10_func<IN>()); break;
1041        case LOGF: tensor_unary_operation_helper(size, arg1, argRes, log_func<IN>()); break;
1042        case SIGNF: tensor_unary_operation_helper(size, arg1, argRes, sign_func<IN>()); break;
1043        case ABSF: tensor_unary_operation_helper(size, arg1, argRes, abs_func<IN>()); break;
1044        case EXPF: tensor_unary_operation_helper(size, arg1, argRes, exp_func<IN>()); break;
1045        case SQRTF: tensor_unary_operation_helper(size, arg1, argRes, sqrt_func<IN>()); break;
1046    
1047        case GTZEROF: tensor_unary_operation_helper(size, arg1, argRes, gtzero_func<IN>()); break;
1048        case GEZEROF: tensor_unary_operation_helper(size, arg1, argRes, gezero_func<IN>()); break;
1049        case LTZEROF: tensor_unary_operation_helper(size, arg1, argRes, ltzero_func<IN>()); break;
1050        case LEZEROF: tensor_unary_operation_helper(size, arg1, argRes, lezero_func<IN>()); break;  
1051        case CONJF:
1052              for (size_t i = 0; i < size; ++i) {
1053                  argRes[i] = static_cast<OUT>(std::conj(arg1[i]));
1054              }
1055              break;
1056        case INVF:
1057              for (size_t i = 0; i < size; ++i) {
1058                  argRes[i] = 1.0/arg1[i];
1059              }
1060              break;
1061              
1062        default:
1063          throw DataException("Unsupported unary operation");
1064      }
1065      return;
1066    }
1067    
1068    bool supports_cplx(escript::ESFunction operation);
1069    
1070    
1071  } // end of namespace  } // end of namespace
1072  #endif  #endif

Legend:
Removed from v.5962  
changed lines
  Added in v.5963

  ViewVC Help
Powered by ViewVC 1.1.26