/[escript]/branches/subworld2/escriptcore/src/DataMaths.cpp
ViewVC logotype

Contents of /branches/subworld2/escriptcore/src/DataMaths.cpp

Parent Directory | Revision Log


Revision 5504 - (show annotations)
Wed Mar 4 22:58:13 2015 UTC (4 years, 1 month ago) by jfenwick
File size: 10339 byte(s)
Again with a more up to date copy


1
2 /*****************************************************************************
3 *
4 * Copyright (c) 2003-2015 by University of Queensland
5 * http://www.uq.edu.au
6 *
7 * Primary Business: Queensland, Australia
8 * Licensed under the Open Software License version 3.0
9 * http://www.opensource.org/licenses/osl-3.0.php
10 *
11 * Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12 * Development 2012-2013 by School of Earth Sciences
13 * Development from 2014 by Centre for Geoscience Computing (GeoComp)
14 *
15 *****************************************************************************/
16
17 #define ESNEEDPYTHON
18 #include "esysUtils/first.h"
19
20 #include "DataTypes.h"
21 #include "DataMaths.h"
22 #include <sstream>
23
namespace
{
// Status codes returned by matrix_inverse() and translated into
// exceptions by matrixInverseError().
enum
{
    SUCCESS       = 0, // inversion completed
    BADRANK       = 1, // input or output is not rank 2
    NOTSQUARE     = 2, // input matrix is not square
    SHAPEMISMATCH = 3, // input and output shapes differ
    NOINVERSE     = 4, // zero determinant (closed-form 1x1..3x3 path)
    NEEDLAPACK    = 5, // order > 3 but built without USE_LAPACK
    ERRFACTORISE  = 6, // reported as a failure in the "factorise stage"
    ERRINVERT     = 7  // reported as a failure in the "inverse stage"
};
}
35
36 namespace escript
37 {
38 namespace DataMaths
39 {
40
41 void
42 matMult(const DataTypes::ValueType& left,
43 const DataTypes::ShapeType& leftShape,
44 DataTypes::ValueType::size_type leftOffset,
45 const DataTypes::ValueType& right,
46 const DataTypes::ShapeType& rightShape,
47 DataTypes::ValueType::size_type rightOffset,
48 DataTypes::ValueType& result,
49 const DataTypes::ShapeType& resultShape)
50 {
51 using namespace escript::DataTypes;
52 using namespace std;
53
54 int leftRank=getRank(leftShape);
55 int rightRank=getRank(rightShape);
56 int resultRank=getRank(resultShape);
57 if (leftRank==0 || rightRank==0) {
58 stringstream temp;
59 temp << "Error - (matMult) Invalid for rank 0 objects.";
60 throw DataException(temp.str());
61 }
62
63 if (leftShape[leftRank-1] != rightShape[0]) {
64 stringstream temp;
65 temp << "Error - (matMult) Dimension: " << leftRank
66 << ", size: " << leftShape[leftRank-1]
67 << " of LHS and dimension: 1, size: " << rightShape[0]
68 << " of RHS don't match.";
69 throw DataException(temp.str());
70 }
71
72 int outputRank = leftRank+rightRank-2;
73
74 if (outputRank < 0) {
75 stringstream temp;
76 temp << "Error - (matMult) LHS and RHS cannot be multiplied "
77 << "as they have incompatible rank.";
78 throw DataException(temp.str());
79 }
80
81 if (outputRank != resultRank) {
82 stringstream temp;
83 temp << "Error - (matMult) Rank of result array is: "
84 << resultRank
85 << " it must be: " << outputRank;
86 throw DataException(temp.str());
87 }
88
89 for (int i=0; i<(leftRank-1); i++) {
90 if (leftShape[i] != resultShape[i]) {
91 stringstream temp;
92 temp << "Error - (matMult) Dimension: " << i
93 << " of LHS and result array don't match.";
94 throw DataException(temp.str());
95 }
96 }
97
98 for (int i=1; i<rightRank; i++) {
99 if (rightShape[i] != resultShape[i+leftRank-2]) {
100 stringstream temp;
101 temp << "Error - (matMult) Dimension: " << i
102 << ", size: " << rightShape[i]
103 << " of RHS and dimension: " << i+leftRank-1
104 << ", size: " << resultShape[i+leftRank-1]
105 << " of result array don't match.";
106 throw DataException(temp.str());
107 }
108 }
109
110 switch (leftRank) {
111
112 case 1:
113 switch (rightRank) {
114 case 1:
115 result[0]=0;
116 for (int i=0;i<leftShape[0];i++) {
117 result[0]+=left[i+leftOffset]*right[i+rightOffset];
118 }
119 break;
120 case 2:
121 for (int i=0;i<resultShape[0];i++) {
122 result[i]=0;
123 for (int j=0;j<rightShape[0];j++) {
124 result[i]+=left[j+leftOffset]*right[getRelIndex(rightShape,j,i)+rightOffset];
125 }
126 }
127 break;
128 default:
129 stringstream temp; temp << "Error - (matMult) Invalid rank. Programming error.";
130 throw DataException(temp.str());
131 break;
132 }
133 break;
134
135 case 2:
136 switch (rightRank) {
137 case 1:
138 result[0]=0;
139 for (int i=0;i<leftShape[0];i++) {
140 result[i]=0;
141 for (int j=0;j<leftShape[1];j++) {
142 result[i]+=left[leftOffset+getRelIndex(leftShape,i,j)]*right[i+rightOffset];
143 }
144 }
145 break;
146 case 2:
147 for (int i=0;i<resultShape[0];i++) {
148 for (int j=0;j<resultShape[1];j++) {
149 result[getRelIndex(resultShape,i,j)]=0;
150 for (int jR=0;jR<rightShape[0];jR++) {
151 result[getRelIndex(resultShape,i,j)]+=left[leftOffset+getRelIndex(leftShape,i,jR)]*right[rightOffset+getRelIndex(rightShape,jR,j)];
152 }
153 }
154 }
155 break;
156 default:
157 stringstream temp; temp << "Error - (matMult) Invalid rank. Programming error.";
158 throw DataException(temp.str());
159 break;
160 }
161 break;
162
163 default:
164 stringstream temp; temp << "Error - (matMult) Not supported for rank: " << leftRank;
165 throw DataException(temp.str());
166 break;
167 }
168
169 }
170
171
172 DataTypes::ShapeType
173 determineResultShape(const DataTypes::ShapeType& left,
174 const DataTypes::ShapeType& right)
175 {
176 DataTypes::ShapeType result;
177 for (int i=0; i<(DataTypes::getRank(left)-1); i++) {
178 result.push_back(left[i]);
179 }
180 for (int i=1; i<DataTypes::getRank(right); i++) {
181 result.push_back(right[i]);
182 }
183 return result;
184 }
185
186
187
188
189 void matrixInverseError(int err)
190 {
191 switch (err)
192 {
193 case 0: break; // not an error
194 case BADRANK: throw DataException("matrix_inverse: input and output must be rank 2.");
195 case NOTSQUARE: throw DataException("matrix_inverse: matrix must be square.");
196 case SHAPEMISMATCH: throw DataException("matrix_inverse: programmer error input and output must be the same shape.");
197 case NOINVERSE: throw DataException("matrix_inverse: argument not invertible.");
198 case NEEDLAPACK:throw DataException("matrix_inverse: matrices larger than 3x3 require lapack support.");
199 case ERRFACTORISE: throw DataException("matrix_inverse: argument not invertible (factorise stage).");
200 case ERRINVERT: throw DataException("matrix_inverse: argument not invertible (inverse stage).");
201 default:
202 throw DataException("matrix_inverse: unknown error.");
203 }
204 }
205
206
207
208 // Copied from the python version in util.py
209 int
210 matrix_inverse(const DataTypes::ValueType& in,
211 const DataTypes::ShapeType& inShape,
212 DataTypes::ValueType::size_type inOffset,
213 DataTypes::ValueType& out,
214 const DataTypes::ShapeType& outShape,
215 DataTypes::ValueType::size_type outOffset,
216 int count,
217 LapackInverseHelper& helper)
218 {
219 using namespace DataTypes;
220 using namespace std;
221 int inRank=getRank(inShape);
222 int outRank=getRank(outShape);
223 int size=DataTypes::noValues(inShape);
224 if ((inRank!=2) || (outRank!=2))
225 {
226 return BADRANK;
227 }
228 if (inShape[0]!=inShape[1])
229 {
230 return NOTSQUARE;
231 }
232 if (inShape!=outShape)
233 {
234 return SHAPEMISMATCH;
235 }
236 if (inShape[0]==1)
237 {
238 for (int i=0;i<count;++i)
239 {
240 if (in[inOffset+i]!=0)
241 {
242 out[outOffset+i]=1/in[inOffset+i];
243 }
244 else
245 {
246 return NOINVERSE;
247 }
248 }
249 }
250 else if (inShape[0]==2)
251 {
252 int step=0;
253 for (int i=0;i<count;++i)
254 {
255 double A11=in[inOffset+step+getRelIndex(inShape,0,0)];
256 double A12=in[inOffset+step+getRelIndex(inShape,0,1)];
257 double A21=in[inOffset+step+getRelIndex(inShape,1,0)];
258 double A22=in[inOffset+step+getRelIndex(inShape,1,1)];
259 double D = A11*A22-A12*A21;
260 if (D!=0)
261 {
262 D=1/D;
263 out[outOffset+step+getRelIndex(inShape,0,0)]= A22*D;
264 out[outOffset+step+getRelIndex(inShape,1,0)]=-A21*D;
265 out[outOffset+step+getRelIndex(inShape,0,1)]=-A12*D;
266 out[outOffset+step+getRelIndex(inShape,1,1)]= A11*D;
267 }
268 else
269 {
270 return NOINVERSE;
271 }
272 step+=size;
273 }
274 }
275 else if (inShape[0]==3)
276 {
277 int step=0;
278 for (int i=0;i<count;++i)
279 {
280 double A11=in[inOffset+step+getRelIndex(inShape,0,0)];
281 double A21=in[inOffset+step+getRelIndex(inShape,1,0)];
282 double A31=in[inOffset+step+getRelIndex(inShape,2,0)];
283 double A12=in[inOffset+step+getRelIndex(inShape,0,1)];
284 double A22=in[inOffset+step+getRelIndex(inShape,1,1)];
285 double A32=in[inOffset+step+getRelIndex(inShape,2,1)];
286 double A13=in[inOffset+step+getRelIndex(inShape,0,2)];
287 double A23=in[inOffset+step+getRelIndex(inShape,1,2)];
288 double A33=in[inOffset+step+getRelIndex(inShape,2,2)];
289 double D = A11*(A22*A33-A23*A32)+ A12*(A31*A23-A21*A33)+A13*(A21*A32-A31*A22);
290 if (D!=0)
291 {
292 D=1/D;
293 out[outOffset+step+getRelIndex(inShape,0,0)]=(A22*A33-A23*A32)*D;
294 out[outOffset+step+getRelIndex(inShape,1,0)]=(A31*A23-A21*A33)*D;
295 out[outOffset+step+getRelIndex(inShape,2,0)]=(A21*A32-A31*A22)*D;
296 out[outOffset+step+getRelIndex(inShape,0,1)]=(A13*A32-A12*A33)*D;
297 out[outOffset+step+getRelIndex(inShape,1,1)]=(A11*A33-A31*A13)*D;
298 out[outOffset+step+getRelIndex(inShape,2,1)]=(A12*A31-A11*A32)*D;
299 out[outOffset+step+getRelIndex(inShape,0,2)]=(A12*A23-A13*A22)*D;
300 out[outOffset+step+getRelIndex(inShape,1,2)]=(A13*A21-A11*A23)*D;
301 out[outOffset+step+getRelIndex(inShape,2,2)]=(A11*A22-A12*A21)*D;
302 }
303 else
304 {
305 return NOINVERSE;
306 }
307 step+=size;
308 }
309 }
310 else // inShape[0] >3 (or negative but that can hopefully never happen)
311 {
312 #ifndef USE_LAPACK
313 return NEEDLAPACK;
314 #else
315 int step=0;
316
317
318 for (int i=0;i<count;++i)
319 {
320 // need to make a copy since blas overwrites its input
321 for (int j=0;j<size;++j)
322 {
323 out[outOffset+step+j]=in[inOffset+step+j];
324 }
325 double* arr=&(out[outOffset+step]);
326 int res=helper.invert(arr);
327 if (res!=0)
328 {
329 return res;
330 }
331 step+=size;
332 }
333 #endif
334 }
335 return SUCCESS;
336 }
337
338 } // end namespace
339 } // end namespace
340

  ViewVC Help
Powered by ViewVC 1.1.26