/[escript]/branches/lapack2681/escript/src/DataMaths.cpp
ViewVC logotype

Contents of /branches/lapack2681/escript/src/DataMaths.cpp

Parent Directory | Revision Log


Revision 2740 - (show annotations)
Tue Nov 10 06:48:24 2009 UTC (9 years, 7 months ago) by jfenwick
File size: 10301 byte(s)


1
2 /*******************************************************
3 *
4 * Copyright (c) 2003-2009 by University of Queensland
5 * Earth Systems Science Computational Center (ESSCC)
6 * http://www.uq.edu.au/esscc
7 *
8 * Primary Business: Queensland, Australia
9 * Licensed under the Open Software License version 3.0
10 * http://www.opensource.org/licenses/osl-3.0.php
11 *
12 *******************************************************/
13
14
15 #include "DataTypes.h"
16 #include "DataMaths.h"
17 #include <sstream>
18
19 #ifdef USE_LAPACK
20 extern "C"
21 {
22 #include <clapack.h>
23 }
24 #endif
25
namespace
{
    // Status codes returned by matrix_inverse() and turned into exceptions by
    // matrixInverseError().  SUCCESS must stay 0: matrixInverseError() treats
    // 0 as "not an error".
    enum
    {
        SUCCESS       = 0,
        BADRANK       = 1,  // input or output is not rank 2
        NOTSQUARE     = 2,  // input matrix is not square
        SHAPEMISMATCH = 3,  // input and output shapes differ
        NOINVERSE     = 4,  // closed-form (1x1/2x2/3x3) case found a zero determinant
        NEEDLAPACK    = 5,  // matrix is not 1x1/2x2/3x3 and lapack is not compiled in
        ERRFACTORISE  = 6,  // clapack_dgetrf reported a failure
        ERRINVERT     = 7   // clapack_dgetri reported a failure
    };
}
37
38 namespace escript
39 {
40 namespace DataMaths
41 {
42
43 void
44 matMult(const DataTypes::ValueType& left,
45 const DataTypes::ShapeType& leftShape,
46 DataTypes::ValueType::size_type leftOffset,
47 const DataTypes::ValueType& right,
48 const DataTypes::ShapeType& rightShape,
49 DataTypes::ValueType::size_type rightOffset,
50 DataTypes::ValueType& result,
51 const DataTypes::ShapeType& resultShape)
52 {
53 using namespace escript::DataTypes;
54 using namespace std;
55
56 int leftRank=getRank(leftShape);
57 int rightRank=getRank(rightShape);
58 int resultRank=getRank(resultShape);
59 if (leftRank==0 || rightRank==0) {
60 stringstream temp;
61 temp << "Error - (matMult) Invalid for rank 0 objects.";
62 throw DataException(temp.str());
63 }
64
65 if (leftShape[leftRank-1] != rightShape[0]) {
66 stringstream temp;
67 temp << "Error - (matMult) Dimension: " << leftRank
68 << ", size: " << leftShape[leftRank-1]
69 << " of LHS and dimension: 1, size: " << rightShape[0]
70 << " of RHS don't match.";
71 throw DataException(temp.str());
72 }
73
74 int outputRank = leftRank+rightRank-2;
75
76 if (outputRank < 0) {
77 stringstream temp;
78 temp << "Error - (matMult) LHS and RHS cannot be multiplied "
79 << "as they have incompatable rank.";
80 throw DataException(temp.str());
81 }
82
83 if (outputRank != resultRank) {
84 stringstream temp;
85 temp << "Error - (matMult) Rank of result array is: "
86 << resultRank
87 << " it must be: " << outputRank;
88 throw DataException(temp.str());
89 }
90
91 for (int i=0; i<(leftRank-1); i++) {
92 if (leftShape[i] != resultShape[i]) {
93 stringstream temp;
94 temp << "Error - (matMult) Dimension: " << i
95 << " of LHS and result array don't match.";
96 throw DataException(temp.str());
97 }
98 }
99
100 for (int i=1; i<rightRank; i++) {
101 if (rightShape[i] != resultShape[i+leftRank-2]) {
102 stringstream temp;
103 temp << "Error - (matMult) Dimension: " << i
104 << ", size: " << rightShape[i]
105 << " of RHS and dimension: " << i+leftRank-1
106 << ", size: " << resultShape[i+leftRank-1]
107 << " of result array don't match.";
108 throw DataException(temp.str());
109 }
110 }
111
112 switch (leftRank) {
113
114 case 1:
115 switch (rightRank) {
116 case 1:
117 result[0]=0;
118 for (int i=0;i<leftShape[0];i++) {
119 result[0]+=left[i+leftOffset]*right[i+rightOffset];
120 }
121 break;
122 case 2:
123 for (int i=0;i<resultShape[0];i++) {
124 result[i]=0;
125 for (int j=0;j<rightShape[0];j++) {
126 result[i]+=left[j+leftOffset]*right[getRelIndex(rightShape,j,i)+rightOffset];
127 }
128 }
129 break;
130 default:
131 stringstream temp; temp << "Error - (matMult) Invalid rank. Programming error.";
132 throw DataException(temp.str());
133 break;
134 }
135 break;
136
137 case 2:
138 switch (rightRank) {
139 case 1:
140 result[0]=0;
141 for (int i=0;i<leftShape[0];i++) {
142 result[i]=0;
143 for (int j=0;j<leftShape[1];j++) {
144 result[i]+=left[leftOffset+getRelIndex(leftShape,i,j)]*right[i+rightOffset];
145 }
146 }
147 break;
148 case 2:
149 for (int i=0;i<resultShape[0];i++) {
150 for (int j=0;j<resultShape[1];j++) {
151 result[getRelIndex(resultShape,i,j)]=0;
152 for (int jR=0;jR<rightShape[0];jR++) {
153 result[getRelIndex(resultShape,i,j)]+=left[leftOffset+getRelIndex(leftShape,i,jR)]*right[rightOffset+getRelIndex(rightShape,jR,j)];
154 }
155 }
156 }
157 break;
158 default:
159 stringstream temp; temp << "Error - (matMult) Invalid rank. Programming error.";
160 throw DataException(temp.str());
161 break;
162 }
163 break;
164
165 default:
166 stringstream temp; temp << "Error - (matMult) Not supported for rank: " << leftRank;
167 throw DataException(temp.str());
168 break;
169 }
170
171 }
172
173
174 DataTypes::ShapeType
175 determineResultShape(const DataTypes::ShapeType& left,
176 const DataTypes::ShapeType& right)
177 {
178 DataTypes::ShapeType result;
179 for (int i=0; i<(DataTypes::getRank(left)-1); i++) {
180 result.push_back(left[i]);
181 }
182 for (int i=1; i<DataTypes::getRank(right); i++) {
183 result.push_back(right[i]);
184 }
185 return result;
186 }
187
188
189
190
191 void matrixInverseError(int err)
192 {
193 switch (err)
194 {
195 case 0: break; // not an error
196 case BADRANK: throw DataException("matrix_inverse: input and output must be rank 2.");
197 case NOTSQUARE: throw DataException("matrix_inverse: matrix must be square.");
198 case SHAPEMISMATCH: throw DataException("matrix_inverse: programmer error input and output must be the same shape.");
199 case NOINVERSE: throw DataException("matrix_inverse: argument not invertible.");
200 case NEEDLAPACK:throw DataException("matrix_inverse: matrices larger than 3x3 require lapack support.");
201 case ERRFACTORISE: throw DataException("matrix_inverse: argument not invertible (factorise stage).");
202 case ERRINVERT: throw DataException("matrix_inverse: argument not invertible (inverse stage).");
203 default:
204 throw DataException("matrix_inverse: unknown error.");
205 }
206 }
207
208
209
210 // Copied from the python version in util.py
211 int
212 matrix_inverse(const DataTypes::ValueType& in,
213 const DataTypes::ShapeType& inShape,
214 DataTypes::ValueType::size_type inOffset,
215 DataTypes::ValueType& out,
216 const DataTypes::ShapeType& outShape,
217 DataTypes::ValueType::size_type outOffset,
218 int count,
219 int* piv)
220 {
221 using namespace DataTypes;
222 using namespace std;
223 int inRank=getRank(inShape);
224 int outRank=getRank(outShape);
225 int size=DataTypes::noValues(inShape);
226 if ((inRank!=2) || (outRank!=2))
227 {
228 return BADRANK;
229 }
230 if (inShape[0]!=inShape[1])
231 {
232 return NOTSQUARE;
233 }
234 if (inShape!=outShape)
235 {
236 return SHAPEMISMATCH;
237 }
238 if (inShape[0]==1)
239 {
240 for (int i=0;i<count;++i)
241 {
242 if (in[inOffset+i]!=0)
243 {
244 out[outOffset+i]=1/in[inOffset+i];
245 }
246 else
247 {
248 return NOINVERSE;
249 }
250 }
251 }
252 else if (inShape[0]==2)
253 {
254 int step=0;
255 for (int i=0;i<count;++i)
256 {
257 double A11=in[inOffset+step+getRelIndex(inShape,0,0)];
258 double A12=in[inOffset+step+getRelIndex(inShape,0,1)];
259 double A21=in[inOffset+step+getRelIndex(inShape,1,0)];
260 double A22=in[inOffset+step+getRelIndex(inShape,1,1)];
261 double D = A11*A22-A12*A21;
262 if (D!=0)
263 {
264 D=1/D;
265 out[outOffset+step+getRelIndex(inShape,0,0)]= A22*D;
266 out[outOffset+step+getRelIndex(inShape,1,0)]=-A21*D;
267 out[outOffset+step+getRelIndex(inShape,0,1)]=-A12*D;
268 out[outOffset+step+getRelIndex(inShape,1,1)]= A11*D;
269 }
270 else
271 {
272 return NOINVERSE;
273 }
274 step+=size;
275 }
276 }
277 else if (inShape[0]==3)
278 {
279 int step=0;
280 for (int i=0;i<count;++i)
281 {
282 double A11=in[inOffset+step+getRelIndex(inShape,0,0)];
283 double A21=in[inOffset+step+getRelIndex(inShape,1,0)];
284 double A31=in[inOffset+step+getRelIndex(inShape,2,0)];
285 double A12=in[inOffset+step+getRelIndex(inShape,0,1)];
286 double A22=in[inOffset+step+getRelIndex(inShape,1,1)];
287 double A32=in[inOffset+step+getRelIndex(inShape,2,1)];
288 double A13=in[inOffset+step+getRelIndex(inShape,0,2)];
289 double A23=in[inOffset+step+getRelIndex(inShape,1,2)];
290 double A33=in[inOffset+step+getRelIndex(inShape,2,2)];
291 double D = A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22);
292 if (D!=0)
293 {
294 D=1/D;
295 out[outOffset+step+getRelIndex(inShape,0,0)]=(A22*A33-A23*A32)*D;
296 out[outOffset+step+getRelIndex(inShape,1,0)]=(A31*A23-A21*A33)*D;
297 out[outOffset+step+getRelIndex(inShape,2,0)]=(A21*A32-A31*A22)*D;
298 out[outOffset+step+getRelIndex(inShape,0,1)]=(A13*A32-A12*A33)*D;
299 out[outOffset+step+getRelIndex(inShape,1,1)]=(A11*A33-A31*A13)*D;
300 out[outOffset+step+getRelIndex(inShape,2,1)]=(A12*A31-A11*A32)*D;
301 out[outOffset+step+getRelIndex(inShape,0,2)]=(A12*A23-A13*A22)*D;
302 out[outOffset+step+getRelIndex(inShape,1,2)]=(A13*A21-A11*A23)*D;
303 out[outOffset+step+getRelIndex(inShape,2,2)]=(A11*A22-A12*A21)*D;
304 }
305 else
306 {
307 return NOINVERSE;
308 }
309 step+=size;
310 }
311 }
312 else // inShape[0] >3 (or negative but that can hopefully never happen)
313 {
314 #ifndef USE_LAPACK
315 return NEEDLAPACK;
316 #else
317 int step=0;
318
319
320 for (int i=0;i<count;++i)
321 {
322 // need to make a copy since blas overwrites its input
323 for (int j=0;j<size;++j)
324 {
325 out[outOffset+step+j]=in[inOffset+step+j];
326 }
327 double* arr=&(out[outOffset+step]);
328 int res1=clapack_dgetrf(CblasColMajor, inShape[0], inShape[0], arr, inShape[0], piv);
329 if (res1) {return ERRFACTORISE;}
330 res1=clapack_dgetri(CblasColMajor, inShape[0], arr, inShape[0],piv);
331 if (res1) {return ERRINVERT;}
332 step+=size;
333 }
334 #endif
335 }
336 return SUCCESS;
337 }
338
339 } // end namespace
340 } // end namespace
341

  ViewVC Help
Powered by ViewVC 1.1.26