/[escript]/trunk/escript/src/DataMaths.cpp
ViewVC logotype

Contents of /trunk/escript/src/DataMaths.cpp

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3911 - (show annotations)
Thu Jun 14 01:01:03 2012 UTC (7 years, 4 months ago) by jfenwick
File size: 10102 byte(s)
Copyright changes
1
2 /*******************************************************
3 *
4 * Copyright (c) 2003-2012 by University of Queensland
5 * Earth Systems Science Computational Center (ESSCC)
6 * http://www.uq.edu.au/esscc
7 *
8 * Primary Business: Queensland, Australia
9 * Licensed under the Open Software License version 3.0
10 * http://www.opensource.org/licenses/osl-3.0.php
11 *
12 *******************************************************/
13
14
15 #include "DataTypes.h"
16 #include "DataMaths.h"
17 #include <sstream>
18
namespace
{
// Status codes returned by matrix_inverse() and turned into exceptions by
// matrixInverseError() later in this file.  Zero means success; every other
// value identifies one failure reason.
const int SUCCESS       = 0;  // every matrix was inverted
const int BADRANK       = 1;  // input or output is not rank 2
const int NOTSQUARE     = 2;  // input matrix is not square
const int SHAPEMISMATCH = 3;  // input and output shapes differ
const int NOINVERSE     = 4;  // a matrix (or 1x1 entry) is singular
const int NEEDLAPACK    = 5;  // larger than 3x3 but built without LAPACK
const int ERRFACTORISE  = 6;  // inversion failed at the factorisation stage
const int ERRINVERT     = 7;  // inversion failed at the inversion stage
}
30
31 namespace escript
32 {
33 namespace DataMaths
34 {
35
36 void
37 matMult(const DataTypes::ValueType& left,
38 const DataTypes::ShapeType& leftShape,
39 DataTypes::ValueType::size_type leftOffset,
40 const DataTypes::ValueType& right,
41 const DataTypes::ShapeType& rightShape,
42 DataTypes::ValueType::size_type rightOffset,
43 DataTypes::ValueType& result,
44 const DataTypes::ShapeType& resultShape)
45 {
46 using namespace escript::DataTypes;
47 using namespace std;
48
49 int leftRank=getRank(leftShape);
50 int rightRank=getRank(rightShape);
51 int resultRank=getRank(resultShape);
52 if (leftRank==0 || rightRank==0) {
53 stringstream temp;
54 temp << "Error - (matMult) Invalid for rank 0 objects.";
55 throw DataException(temp.str());
56 }
57
58 if (leftShape[leftRank-1] != rightShape[0]) {
59 stringstream temp;
60 temp << "Error - (matMult) Dimension: " << leftRank
61 << ", size: " << leftShape[leftRank-1]
62 << " of LHS and dimension: 1, size: " << rightShape[0]
63 << " of RHS don't match.";
64 throw DataException(temp.str());
65 }
66
67 int outputRank = leftRank+rightRank-2;
68
69 if (outputRank < 0) {
70 stringstream temp;
71 temp << "Error - (matMult) LHS and RHS cannot be multiplied "
72 << "as they have incompatable rank.";
73 throw DataException(temp.str());
74 }
75
76 if (outputRank != resultRank) {
77 stringstream temp;
78 temp << "Error - (matMult) Rank of result array is: "
79 << resultRank
80 << " it must be: " << outputRank;
81 throw DataException(temp.str());
82 }
83
84 for (int i=0; i<(leftRank-1); i++) {
85 if (leftShape[i] != resultShape[i]) {
86 stringstream temp;
87 temp << "Error - (matMult) Dimension: " << i
88 << " of LHS and result array don't match.";
89 throw DataException(temp.str());
90 }
91 }
92
93 for (int i=1; i<rightRank; i++) {
94 if (rightShape[i] != resultShape[i+leftRank-2]) {
95 stringstream temp;
96 temp << "Error - (matMult) Dimension: " << i
97 << ", size: " << rightShape[i]
98 << " of RHS and dimension: " << i+leftRank-1
99 << ", size: " << resultShape[i+leftRank-1]
100 << " of result array don't match.";
101 throw DataException(temp.str());
102 }
103 }
104
105 switch (leftRank) {
106
107 case 1:
108 switch (rightRank) {
109 case 1:
110 result[0]=0;
111 for (int i=0;i<leftShape[0];i++) {
112 result[0]+=left[i+leftOffset]*right[i+rightOffset];
113 }
114 break;
115 case 2:
116 for (int i=0;i<resultShape[0];i++) {
117 result[i]=0;
118 for (int j=0;j<rightShape[0];j++) {
119 result[i]+=left[j+leftOffset]*right[getRelIndex(rightShape,j,i)+rightOffset];
120 }
121 }
122 break;
123 default:
124 stringstream temp; temp << "Error - (matMult) Invalid rank. Programming error.";
125 throw DataException(temp.str());
126 break;
127 }
128 break;
129
130 case 2:
131 switch (rightRank) {
132 case 1:
133 result[0]=0;
134 for (int i=0;i<leftShape[0];i++) {
135 result[i]=0;
136 for (int j=0;j<leftShape[1];j++) {
137 result[i]+=left[leftOffset+getRelIndex(leftShape,i,j)]*right[i+rightOffset];
138 }
139 }
140 break;
141 case 2:
142 for (int i=0;i<resultShape[0];i++) {
143 for (int j=0;j<resultShape[1];j++) {
144 result[getRelIndex(resultShape,i,j)]=0;
145 for (int jR=0;jR<rightShape[0];jR++) {
146 result[getRelIndex(resultShape,i,j)]+=left[leftOffset+getRelIndex(leftShape,i,jR)]*right[rightOffset+getRelIndex(rightShape,jR,j)];
147 }
148 }
149 }
150 break;
151 default:
152 stringstream temp; temp << "Error - (matMult) Invalid rank. Programming error.";
153 throw DataException(temp.str());
154 break;
155 }
156 break;
157
158 default:
159 stringstream temp; temp << "Error - (matMult) Not supported for rank: " << leftRank;
160 throw DataException(temp.str());
161 break;
162 }
163
164 }
165
166
167 DataTypes::ShapeType
168 determineResultShape(const DataTypes::ShapeType& left,
169 const DataTypes::ShapeType& right)
170 {
171 DataTypes::ShapeType result;
172 for (int i=0; i<(DataTypes::getRank(left)-1); i++) {
173 result.push_back(left[i]);
174 }
175 for (int i=1; i<DataTypes::getRank(right); i++) {
176 result.push_back(right[i]);
177 }
178 return result;
179 }
180
181
182
183
184 void matrixInverseError(int err)
185 {
186 switch (err)
187 {
188 case 0: break; // not an error
189 case BADRANK: throw DataException("matrix_inverse: input and output must be rank 2.");
190 case NOTSQUARE: throw DataException("matrix_inverse: matrix must be square.");
191 case SHAPEMISMATCH: throw DataException("matrix_inverse: programmer error input and output must be the same shape.");
192 case NOINVERSE: throw DataException("matrix_inverse: argument not invertible.");
193 case NEEDLAPACK:throw DataException("matrix_inverse: matrices larger than 3x3 require lapack support.");
194 case ERRFACTORISE: throw DataException("matrix_inverse: argument not invertible (factorise stage).");
195 case ERRINVERT: throw DataException("matrix_inverse: argument not invertible (inverse stage).");
196 default:
197 throw DataException("matrix_inverse: unknown error.");
198 }
199 }
200
201
202
203 // Copied from the python version in util.py
204 int
205 matrix_inverse(const DataTypes::ValueType& in,
206 const DataTypes::ShapeType& inShape,
207 DataTypes::ValueType::size_type inOffset,
208 DataTypes::ValueType& out,
209 const DataTypes::ShapeType& outShape,
210 DataTypes::ValueType::size_type outOffset,
211 int count,
212 LapackInverseHelper& helper)
213 {
214 using namespace DataTypes;
215 using namespace std;
216 int inRank=getRank(inShape);
217 int outRank=getRank(outShape);
218 int size=DataTypes::noValues(inShape);
219 if ((inRank!=2) || (outRank!=2))
220 {
221 return BADRANK;
222 }
223 if (inShape[0]!=inShape[1])
224 {
225 return NOTSQUARE;
226 }
227 if (inShape!=outShape)
228 {
229 return SHAPEMISMATCH;
230 }
231 if (inShape[0]==1)
232 {
233 for (int i=0;i<count;++i)
234 {
235 if (in[inOffset+i]!=0)
236 {
237 out[outOffset+i]=1/in[inOffset+i];
238 }
239 else
240 {
241 return NOINVERSE;
242 }
243 }
244 }
245 else if (inShape[0]==2)
246 {
247 int step=0;
248 for (int i=0;i<count;++i)
249 {
250 double A11=in[inOffset+step+getRelIndex(inShape,0,0)];
251 double A12=in[inOffset+step+getRelIndex(inShape,0,1)];
252 double A21=in[inOffset+step+getRelIndex(inShape,1,0)];
253 double A22=in[inOffset+step+getRelIndex(inShape,1,1)];
254 double D = A11*A22-A12*A21;
255 if (D!=0)
256 {
257 D=1/D;
258 out[outOffset+step+getRelIndex(inShape,0,0)]= A22*D;
259 out[outOffset+step+getRelIndex(inShape,1,0)]=-A21*D;
260 out[outOffset+step+getRelIndex(inShape,0,1)]=-A12*D;
261 out[outOffset+step+getRelIndex(inShape,1,1)]= A11*D;
262 }
263 else
264 {
265 return NOINVERSE;
266 }
267 step+=size;
268 }
269 }
270 else if (inShape[0]==3)
271 {
272 int step=0;
273 for (int i=0;i<count;++i)
274 {
275 double A11=in[inOffset+step+getRelIndex(inShape,0,0)];
276 double A21=in[inOffset+step+getRelIndex(inShape,1,0)];
277 double A31=in[inOffset+step+getRelIndex(inShape,2,0)];
278 double A12=in[inOffset+step+getRelIndex(inShape,0,1)];
279 double A22=in[inOffset+step+getRelIndex(inShape,1,1)];
280 double A32=in[inOffset+step+getRelIndex(inShape,2,1)];
281 double A13=in[inOffset+step+getRelIndex(inShape,0,2)];
282 double A23=in[inOffset+step+getRelIndex(inShape,1,2)];
283 double A33=in[inOffset+step+getRelIndex(inShape,2,2)];
284 double D = A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22);
285 if (D!=0)
286 {
287 D=1/D;
288 out[outOffset+step+getRelIndex(inShape,0,0)]=(A22*A33-A23*A32)*D;
289 out[outOffset+step+getRelIndex(inShape,1,0)]=(A31*A23-A21*A33)*D;
290 out[outOffset+step+getRelIndex(inShape,2,0)]=(A21*A32-A31*A22)*D;
291 out[outOffset+step+getRelIndex(inShape,0,1)]=(A13*A32-A12*A33)*D;
292 out[outOffset+step+getRelIndex(inShape,1,1)]=(A11*A33-A31*A13)*D;
293 out[outOffset+step+getRelIndex(inShape,2,1)]=(A12*A31-A11*A32)*D;
294 out[outOffset+step+getRelIndex(inShape,0,2)]=(A12*A23-A13*A22)*D;
295 out[outOffset+step+getRelIndex(inShape,1,2)]=(A13*A21-A11*A23)*D;
296 out[outOffset+step+getRelIndex(inShape,2,2)]=(A11*A22-A12*A21)*D;
297 }
298 else
299 {
300 return NOINVERSE;
301 }
302 step+=size;
303 }
304 }
305 else // inShape[0] >3 (or negative but that can hopefully never happen)
306 {
307 #ifndef USE_LAPACK
308 return NEEDLAPACK;
309 #else
310 int step=0;
311
312
313 for (int i=0;i<count;++i)
314 {
315 // need to make a copy since blas overwrites its input
316 for (int j=0;j<size;++j)
317 {
318 out[outOffset+step+j]=in[inOffset+step+j];
319 }
320 double* arr=&(out[outOffset+step]);
321 int res=helper.invert(arr);
322 if (res!=0)
323 {
324 return res;
325 }
326 step+=size;
327 }
328 #endif
329 }
330 return SUCCESS;
331 }
332
333 } // end namespace
334 } // end namespace
335

  ViewVC Help
Powered by ViewVC 1.1.26