/[escript]/trunk/escript/src/DataMaths.cpp
ViewVC logotype

Contents of /trunk/escript/src/DataMaths.cpp

Parent Directory | Revision Log


Revision 4286 - (show annotations)
Thu Mar 7 04:28:11 2013 UTC (6 years, 5 months ago) by caltinay
File size: 10221 byte(s)
Assorted spelling fixes.

1
2 /*****************************************************************************
3 *
4 * Copyright (c) 2003-2013 by University of Queensland
5 * http://www.uq.edu.au
6 *
7 * Primary Business: Queensland, Australia
8 * Licensed under the Open Software License version 3.0
9 * http://www.opensource.org/licenses/osl-3.0.php
10 *
11 * Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12 * Development since 2012 by School of Earth Sciences
13 *
14 *****************************************************************************/
15
16
17 #include "DataTypes.h"
18 #include "DataMaths.h"
19 #include <sstream>
20
namespace
{
// Status codes returned by matrix_inverse() and translated into exceptions
// by matrixInverseError(). The numeric values are fixed; do not renumber.
enum
{
    SUCCESS       = 0,  // no error
    BADRANK       = 1,  // input or output is not rank 2
    NOTSQUARE     = 2,  // input matrix is not square
    SHAPEMISMATCH = 3,  // input and output shapes differ
    NOINVERSE     = 4,  // matrix is singular (zero determinant / zero value)
    NEEDLAPACK    = 5,  // matrices larger than 3x3 need a LAPACK build
    ERRFACTORISE  = 6,  // not invertible (factorise stage)
    ERRINVERT     = 7   // not invertible (inverse stage)
};
}
32
33 namespace escript
34 {
35 namespace DataMaths
36 {
37
38 void
39 matMult(const DataTypes::ValueType& left,
40 const DataTypes::ShapeType& leftShape,
41 DataTypes::ValueType::size_type leftOffset,
42 const DataTypes::ValueType& right,
43 const DataTypes::ShapeType& rightShape,
44 DataTypes::ValueType::size_type rightOffset,
45 DataTypes::ValueType& result,
46 const DataTypes::ShapeType& resultShape)
47 {
48 using namespace escript::DataTypes;
49 using namespace std;
50
51 int leftRank=getRank(leftShape);
52 int rightRank=getRank(rightShape);
53 int resultRank=getRank(resultShape);
54 if (leftRank==0 || rightRank==0) {
55 stringstream temp;
56 temp << "Error - (matMult) Invalid for rank 0 objects.";
57 throw DataException(temp.str());
58 }
59
60 if (leftShape[leftRank-1] != rightShape[0]) {
61 stringstream temp;
62 temp << "Error - (matMult) Dimension: " << leftRank
63 << ", size: " << leftShape[leftRank-1]
64 << " of LHS and dimension: 1, size: " << rightShape[0]
65 << " of RHS don't match.";
66 throw DataException(temp.str());
67 }
68
69 int outputRank = leftRank+rightRank-2;
70
71 if (outputRank < 0) {
72 stringstream temp;
73 temp << "Error - (matMult) LHS and RHS cannot be multiplied "
74 << "as they have incompatible rank.";
75 throw DataException(temp.str());
76 }
77
78 if (outputRank != resultRank) {
79 stringstream temp;
80 temp << "Error - (matMult) Rank of result array is: "
81 << resultRank
82 << " it must be: " << outputRank;
83 throw DataException(temp.str());
84 }
85
86 for (int i=0; i<(leftRank-1); i++) {
87 if (leftShape[i] != resultShape[i]) {
88 stringstream temp;
89 temp << "Error - (matMult) Dimension: " << i
90 << " of LHS and result array don't match.";
91 throw DataException(temp.str());
92 }
93 }
94
95 for (int i=1; i<rightRank; i++) {
96 if (rightShape[i] != resultShape[i+leftRank-2]) {
97 stringstream temp;
98 temp << "Error - (matMult) Dimension: " << i
99 << ", size: " << rightShape[i]
100 << " of RHS and dimension: " << i+leftRank-1
101 << ", size: " << resultShape[i+leftRank-1]
102 << " of result array don't match.";
103 throw DataException(temp.str());
104 }
105 }
106
107 switch (leftRank) {
108
109 case 1:
110 switch (rightRank) {
111 case 1:
112 result[0]=0;
113 for (int i=0;i<leftShape[0];i++) {
114 result[0]+=left[i+leftOffset]*right[i+rightOffset];
115 }
116 break;
117 case 2:
118 for (int i=0;i<resultShape[0];i++) {
119 result[i]=0;
120 for (int j=0;j<rightShape[0];j++) {
121 result[i]+=left[j+leftOffset]*right[getRelIndex(rightShape,j,i)+rightOffset];
122 }
123 }
124 break;
125 default:
126 stringstream temp; temp << "Error - (matMult) Invalid rank. Programming error.";
127 throw DataException(temp.str());
128 break;
129 }
130 break;
131
132 case 2:
133 switch (rightRank) {
134 case 1:
135 result[0]=0;
136 for (int i=0;i<leftShape[0];i++) {
137 result[i]=0;
138 for (int j=0;j<leftShape[1];j++) {
139 result[i]+=left[leftOffset+getRelIndex(leftShape,i,j)]*right[i+rightOffset];
140 }
141 }
142 break;
143 case 2:
144 for (int i=0;i<resultShape[0];i++) {
145 for (int j=0;j<resultShape[1];j++) {
146 result[getRelIndex(resultShape,i,j)]=0;
147 for (int jR=0;jR<rightShape[0];jR++) {
148 result[getRelIndex(resultShape,i,j)]+=left[leftOffset+getRelIndex(leftShape,i,jR)]*right[rightOffset+getRelIndex(rightShape,jR,j)];
149 }
150 }
151 }
152 break;
153 default:
154 stringstream temp; temp << "Error - (matMult) Invalid rank. Programming error.";
155 throw DataException(temp.str());
156 break;
157 }
158 break;
159
160 default:
161 stringstream temp; temp << "Error - (matMult) Not supported for rank: " << leftRank;
162 throw DataException(temp.str());
163 break;
164 }
165
166 }
167
168
169 DataTypes::ShapeType
170 determineResultShape(const DataTypes::ShapeType& left,
171 const DataTypes::ShapeType& right)
172 {
173 DataTypes::ShapeType result;
174 for (int i=0; i<(DataTypes::getRank(left)-1); i++) {
175 result.push_back(left[i]);
176 }
177 for (int i=1; i<DataTypes::getRank(right); i++) {
178 result.push_back(right[i]);
179 }
180 return result;
181 }
182
183
184
185
186 void matrixInverseError(int err)
187 {
188 switch (err)
189 {
190 case 0: break; // not an error
191 case BADRANK: throw DataException("matrix_inverse: input and output must be rank 2.");
192 case NOTSQUARE: throw DataException("matrix_inverse: matrix must be square.");
193 case SHAPEMISMATCH: throw DataException("matrix_inverse: programmer error input and output must be the same shape.");
194 case NOINVERSE: throw DataException("matrix_inverse: argument not invertible.");
195 case NEEDLAPACK:throw DataException("matrix_inverse: matrices larger than 3x3 require lapack support.");
196 case ERRFACTORISE: throw DataException("matrix_inverse: argument not invertible (factorise stage).");
197 case ERRINVERT: throw DataException("matrix_inverse: argument not invertible (inverse stage).");
198 default:
199 throw DataException("matrix_inverse: unknown error.");
200 }
201 }
202
203
204
205 // Copied from the python version in util.py
206 int
207 matrix_inverse(const DataTypes::ValueType& in,
208 const DataTypes::ShapeType& inShape,
209 DataTypes::ValueType::size_type inOffset,
210 DataTypes::ValueType& out,
211 const DataTypes::ShapeType& outShape,
212 DataTypes::ValueType::size_type outOffset,
213 int count,
214 LapackInverseHelper& helper)
215 {
216 using namespace DataTypes;
217 using namespace std;
218 int inRank=getRank(inShape);
219 int outRank=getRank(outShape);
220 int size=DataTypes::noValues(inShape);
221 if ((inRank!=2) || (outRank!=2))
222 {
223 return BADRANK;
224 }
225 if (inShape[0]!=inShape[1])
226 {
227 return NOTSQUARE;
228 }
229 if (inShape!=outShape)
230 {
231 return SHAPEMISMATCH;
232 }
233 if (inShape[0]==1)
234 {
235 for (int i=0;i<count;++i)
236 {
237 if (in[inOffset+i]!=0)
238 {
239 out[outOffset+i]=1/in[inOffset+i];
240 }
241 else
242 {
243 return NOINVERSE;
244 }
245 }
246 }
247 else if (inShape[0]==2)
248 {
249 int step=0;
250 for (int i=0;i<count;++i)
251 {
252 double A11=in[inOffset+step+getRelIndex(inShape,0,0)];
253 double A12=in[inOffset+step+getRelIndex(inShape,0,1)];
254 double A21=in[inOffset+step+getRelIndex(inShape,1,0)];
255 double A22=in[inOffset+step+getRelIndex(inShape,1,1)];
256 double D = A11*A22-A12*A21;
257 if (D!=0)
258 {
259 D=1/D;
260 out[outOffset+step+getRelIndex(inShape,0,0)]= A22*D;
261 out[outOffset+step+getRelIndex(inShape,1,0)]=-A21*D;
262 out[outOffset+step+getRelIndex(inShape,0,1)]=-A12*D;
263 out[outOffset+step+getRelIndex(inShape,1,1)]= A11*D;
264 }
265 else
266 {
267 return NOINVERSE;
268 }
269 step+=size;
270 }
271 }
272 else if (inShape[0]==3)
273 {
274 int step=0;
275 for (int i=0;i<count;++i)
276 {
277 double A11=in[inOffset+step+getRelIndex(inShape,0,0)];
278 double A21=in[inOffset+step+getRelIndex(inShape,1,0)];
279 double A31=in[inOffset+step+getRelIndex(inShape,2,0)];
280 double A12=in[inOffset+step+getRelIndex(inShape,0,1)];
281 double A22=in[inOffset+step+getRelIndex(inShape,1,1)];
282 double A32=in[inOffset+step+getRelIndex(inShape,2,1)];
283 double A13=in[inOffset+step+getRelIndex(inShape,0,2)];
284 double A23=in[inOffset+step+getRelIndex(inShape,1,2)];
285 double A33=in[inOffset+step+getRelIndex(inShape,2,2)];
286 double D = A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22);
287 if (D!=0)
288 {
289 D=1/D;
290 out[outOffset+step+getRelIndex(inShape,0,0)]=(A22*A33-A23*A32)*D;
291 out[outOffset+step+getRelIndex(inShape,1,0)]=(A31*A23-A21*A33)*D;
292 out[outOffset+step+getRelIndex(inShape,2,0)]=(A21*A32-A31*A22)*D;
293 out[outOffset+step+getRelIndex(inShape,0,1)]=(A13*A32-A12*A33)*D;
294 out[outOffset+step+getRelIndex(inShape,1,1)]=(A11*A33-A31*A13)*D;
295 out[outOffset+step+getRelIndex(inShape,2,1)]=(A12*A31-A11*A32)*D;
296 out[outOffset+step+getRelIndex(inShape,0,2)]=(A12*A23-A13*A22)*D;
297 out[outOffset+step+getRelIndex(inShape,1,2)]=(A13*A21-A11*A23)*D;
298 out[outOffset+step+getRelIndex(inShape,2,2)]=(A11*A22-A12*A21)*D;
299 }
300 else
301 {
302 return NOINVERSE;
303 }
304 step+=size;
305 }
306 }
307 else // inShape[0] >3 (or negative but that can hopefully never happen)
308 {
309 #ifndef USE_LAPACK
310 return NEEDLAPACK;
311 #else
312 int step=0;
313
314
315 for (int i=0;i<count;++i)
316 {
317 // need to make a copy since blas overwrites its input
318 for (int j=0;j<size;++j)
319 {
320 out[outOffset+step+j]=in[inOffset+step+j];
321 }
322 double* arr=&(out[outOffset+step]);
323 int res=helper.invert(arr);
324 if (res!=0)
325 {
326 return res;
327 }
328 step+=size;
329 }
330 #endif
331 }
332 return SUCCESS;
333 }
334
335 } // end namespace
336 } // end namespace
337

  ViewVC Help
Powered by ViewVC 1.1.26