/[escript]/branches/trilinos_from_5897/trilinoswrap/src/TrilinosMatrixAdapter.cpp
ViewVC logotype

Contents of /branches/trilinos_from_5897/trilinoswrap/src/TrilinosMatrixAdapter.cpp

Parent Directory Parent Directory | Revision Log Revision Log


Revision 6104 - (show annotations)
Wed Mar 30 06:01:20 2016 UTC (3 years ago) by caltinay
File size: 15492 byte(s)
Factored out and templatized preconditioner, solver and direct solver creation.
The SystemMatrix constructor now takes an optional arg 'isComplex'.
Some complex operations are commented out as we need the complex getSampleData*
methods from trunk for them to work.

It looks like we have to modify the Abstract class in escript eventually as
there is a single method that takes a `double` argument (nullifyRowsAndCols).


1
2 /*****************************************************************************
3 *
4 * Copyright (c) 2016 by The University of Queensland
5 * http://www.uq.edu.au
6 *
7 * Primary Business: Queensland, Australia
8 * Licensed under the Open Software License version 3.0
9 * http://www.opensource.org/licenses/osl-3.0.php
10 *
11 * Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12 * Development 2012-2013 by School of Earth Sciences
13 * Development from 2014 by Centre for Geoscience Computing (GeoComp)
14 *
15 *****************************************************************************/
16
17 #include "TrilinosMatrixAdapter.h"
18 #include "Amesos2Wrapper.h"
19 #include "BelosWrapper.h"
20 #include "PreconditionerFactory.h"
21 #include "TrilinosAdapterException.h"
22
23 #include <escript/index.h>
24 #include <escript/Data.h>
25 #include <escript/FunctionSpaceFactory.h>
26 #include <escript/SolverOptions.h>
27
28 #include <BelosTpetraAdapter.hpp>
29 #include <BelosTypes.hpp>
30 #include <Kokkos_DefaultNode.hpp>
31 #include <MatrixMarket_Tpetra.hpp>
32 #include <Tpetra_DefaultPlatform.hpp>
33 #include <Tpetra_Vector.hpp>
34
35 namespace bp = boost::python;
36 using Teuchos::RCP;
37 using Teuchos::rcp;
38 using Teuchos::rcpFromRef;
39
40 namespace esys_trilinos {
41
42 template<typename ST>
43 void ypAxWorker(RCP<MatrixType<ST> > A, const Teuchos::ArrayView<ST>& y,
44 const Teuchos::ArrayView<const ST>& x)
45 {
46 RCP<VectorType<ST> > X = rcp(new VectorType<ST>(
47 A->getRowMap(), x, x.size(), 1));
48 RCP<VectorType<ST> > Y = rcp(new VectorType<ST>(
49 A->getRowMap(), y, y.size(), 1));
50
51 const ST alpha = Teuchos::ScalarTraits<ST>::one();
52 const ST beta = Teuchos::ScalarTraits<ST>::one();
53
54 // Y = beta*Y + alpha*A*X
55 A->apply(*X, *Y, Teuchos::NO_TRANS, alpha, beta);
56 Y->get1dCopy(y, y.size());
57 }
58
59 template<typename ST>
60 void solveWorker(RCP<MatrixType<ST> > A, const Teuchos::ArrayView<ST>& x,
61 const Teuchos::ArrayView<const ST>& b,
62 const escript::SolverBuddy& sb)
63 {
64 RCP<VectorType<ST> > X = rcp(new VectorType<ST>(A->getDomainMap(), 1));
65 RCP<VectorType<ST> > B = rcp(new VectorType<ST>(A->getRangeMap(), b,
66 b.size(), 1));
67
68 if (sb.getSolverMethod() == escript::SO_METHOD_DIRECT) {
69 RCP<DirectSolverType<ST> > solver = createDirectSolver<ST>(sb, A, X, B);
70 if (sb.isVerbose()) {
71 std::cout << solver->description() << std::endl;
72 std::cout << "Performing symbolic factorization..." << std::flush;
73 }
74 solver->symbolicFactorization();
75 if (sb.isVerbose()) {
76 std::cout << "done\nPerforming numeric factorization..." << std::flush;
77 }
78 solver->numericFactorization();
79 if (sb.isVerbose()) {
80 std::cout << "done\nSolving system..." << std::flush;
81 }
82 solver->solve();
83 if (sb.isVerbose()) {
84 std::cout << "done" << std::endl;
85 RCP<Teuchos::FancyOStream> fos(Teuchos::fancyOStream(Teuchos::rcpFromRef(std::cout)));
86 solver->printTiming(*fos, Teuchos::VERB_HIGH);
87 }
88
89 } else { // iterative solver
90 RCP<SolverType<ST> > solver = createSolver<ST>(sb);
91 RCP<OpType<ST> > prec = createPreconditioner<ST>(A, sb);
92 RCP<ProblemType<ST> > problem = rcp(new ProblemType<ST>(A, X, B));
93
94 if (!prec.is_null()) {
95 // Trilinos BiCGStab does not currently support left preconditioners
96 if (sb.getSolverMethod() == escript::SO_METHOD_BICGSTAB)
97 problem->setRightPrec(prec);
98 else
99 problem->setLeftPrec(prec);
100 }
101 problem->setProblem();
102 solver->setProblem(problem);
103 Belos::ReturnType result = solver->solve();
104 if (sb.isVerbose()) {
105 const int numIters = solver->getNumIters();
106 if (result == Belos::Converged) {
107 std::cout << "The solver took " << numIters
108 << " iteration(s) to reach a relative residual tolerance of "
109 << sb.getTolerance() << "." << std::endl;
110 } else {
111 std::cout << "The solver took " << numIters
112 << " iteration(s), but did not reach a relative residual "
113 "tolerance of " << sb.getTolerance() << "." << std::endl;
114 }
115 }
116 }
117 X->get1dCopy(x, x.size());
118 }
119
120 TrilinosMatrixAdapter::TrilinosMatrixAdapter(escript::JMPI mpiInfo,
121 int blocksize, const escript::FunctionSpace& fs,
122 const_TrilinosGraph_ptr graph, bool isComplex) :
123 AbstractSystemMatrix(blocksize, fs, blocksize, fs),
124 m_mpiInfo(mpiInfo),
125 m_isComplex(isComplex)
126 {
127 if (blocksize != 1) {
128 throw escript::ValueError("Trilinos matrices only support blocksize 1 "
129 "at the moment!");
130 }
131 importer = rcp(new ImportType(graph->getRowMap(), graph->getColMap()));
132 if (isComplex) {
133 cmat = rcp(new ComplexMatrix(graph));
134 cmat->fillComplete();
135 std::cout << "Matrix has " << cmat->getGlobalNumEntries()
136 << " entries." << std::endl;
137 } else {
138 mat = rcp(new RealMatrix(graph));
139 mat->fillComplete();
140 std::cout << "Matrix has " << mat->getGlobalNumEntries()
141 << " entries." << std::endl;
142 }
143
144 }
145
146 void TrilinosMatrixAdapter::fillComplete(bool localOnly)
147 {
148 RCP<Teuchos::ParameterList> params = Teuchos::parameterList();
149 params->set("No Nonlocal Changes", localOnly);
150 if (m_isComplex)
151 cmat->fillComplete(cmat->getDomainMap(), cmat->getRangeMap(), params);
152 else
153 mat->fillComplete(mat->getDomainMap(), mat->getRangeMap(), params);
154 }
155
156 template<>
157 void TrilinosMatrixAdapter::add<real_t>(const std::vector<LO>& rowIdx,
158 const std::vector<real_t>& array)
159 {
160 if (m_isComplex) {
161 throw escript::ValueError("Please use complex array to add to complex "
162 "matrix!");
163 } else {
164 addImpl<real_t>(mat, rowIdx, array);
165 }
166 }
167
168 template<>
169 void TrilinosMatrixAdapter::add<cplx_t>(const std::vector<LO>& rowIdx,
170 const std::vector<cplx_t>& array)
171 {
172 if (m_isComplex) {
173 addImpl<cplx_t>(cmat, rowIdx, array);
174 } else {
175 throw escript::ValueError("Please use real-valued array to add to "
176 "real-valued matrix!");
177 }
178 }
179
180 template<typename ST>
181 void TrilinosMatrixAdapter::addImpl(RCP<MatrixType<ST> > A,
182 const std::vector<LO>& rowIdx,
183 const std::vector<ST>& array)
184 {
185 const int blockSize = getBlockSize();
186 const size_t emSize = rowIdx.size();
187 const LO myLast = A->getRowMap()->getMaxLocalIndex();
188 std::vector<LO> cols(emSize*blockSize);
189 std::vector<ST> vals(emSize*blockSize);
190 for (size_t i = 0; i < emSize; i++) {
191 for (int k = 0; k < blockSize; k++) {
192 const LO row = rowIdx[i]*blockSize + k;
193 if (row <= myLast) {
194 cols.clear();
195 vals.clear();
196 for (int j = 0; j < emSize; j++) {
197 for (int m = 0; m < blockSize; m++) {
198 const LO col = rowIdx[j]*blockSize + m;
199 cols.push_back(col);
200 const size_t srcIdx =
201 INDEX4(k, m, i, j, blockSize, blockSize, emSize);
202 vals.push_back(array[srcIdx]);
203 }
204 }
205 A->sumIntoLocalValues(row, cols, vals);
206 }
207 }
208 }
209 }
210
211 void TrilinosMatrixAdapter::ypAx(escript::Data& y, escript::Data& x) const
212 {
213 if (x.getDataPointSize() != getBlockSize()) {
214 throw TrilinosAdapterException("matrix vector product: block size "
215 "does not match the number of components in input.");
216 } else if (y.getDataPointSize() != getBlockSize()) {
217 throw TrilinosAdapterException("matrix vector product: block size "
218 "does not match the number of components in output.");
219 } else if (x.getFunctionSpace() != getColumnFunctionSpace()) {
220 throw TrilinosAdapterException("matrix vector product: matrix "
221 "function space and function space of input don't match.");
222 } else if (y.getFunctionSpace() != getRowFunctionSpace()) {
223 throw TrilinosAdapterException("matrix vector product: matrix "
224 "function space and function space of output don't match.");
225 } else if (y.isComplex() != m_isComplex || x.isComplex() != m_isComplex) {
226 throw escript::ValueError("matrix vector product: matrix complexity "
227 "must match vector complexity!");
228 }
229
230 // expand data object if necessary to be able to grab the whole data
231 x.expand();
232 y.expand();
233 y.requireWrite();
234
235 if (m_isComplex) {
236 throw escript::NotImplementedError("complex ypAx not implemented!");
237 // TODO: need complex version of getSampleDataRO/RW for this:
238 //const Teuchos::ArrayView<const cplx_t> xView(x.getSampleDataRO(0),
239 // x.getNumDataPoints());
240 //const Teuchos::ArrayView<cplx_t> yView(y.getSampleDataRW(0),
241 // y.getNumDataPoints());
242 //ypAxWorker<cplx_t>(cmat, yView, xView);
243 } else {
244 const Teuchos::ArrayView<const real_t> xView(x.getSampleDataRO(0),
245 x.getNumDataPoints());
246 const Teuchos::ArrayView<real_t> yView(y.getSampleDataRW(0),
247 y.getNumDataPoints());
248 ypAxWorker<real_t>(mat, yView, xView);
249 }
250 }
251
252 void TrilinosMatrixAdapter::setToSolution(escript::Data& out, escript::Data& in,
253 bp::object& options) const
254 {
255 if (out.getDataPointSize() != getBlockSize()) {
256 throw TrilinosAdapterException("solve: block size does not match the number of components of solution.");
257 } else if (in.getDataPointSize() != getBlockSize()) {
258 throw TrilinosAdapterException("solve: block size does not match the number of components of right hand side.");
259 } else if (out.getFunctionSpace() != getColumnFunctionSpace()) {
260 throw TrilinosAdapterException("solve: matrix function space and function space of solution don't match.");
261 } else if (in.getFunctionSpace() != getRowFunctionSpace()) {
262 throw TrilinosAdapterException("solve: matrix function space and function space of right hand side don't match.");
263 } else if (in.isComplex() != m_isComplex || out.isComplex() != m_isComplex) {
264 throw escript::ValueError("solve: matrix complexity must match vector "
265 "complexity!");
266 }
267
268 options.attr("resetDiagnostics")();
269 escript::SolverBuddy sb = bp::extract<escript::SolverBuddy>(options);
270 out.expand();
271 in.expand();
272
273 if (m_isComplex) {
274 throw escript::NotImplementedError("complex solve not implemented!");
275 // TODO: need complex version of getSampleDataRO/RW for this:
276 //const Teuchos::ArrayView<const cplx_t> bView(in.getSampleDataRO(0), in.getNumDataPoints());
277 //const Teuchos::ArrayView<cplx_t> outView(out.getSampleDataRW(0),
278 // out.getNumDataPoints());
279 //solveWorker<cplx_t>(cmat, outView, bView, sb);
280
281 } else {
282 const Teuchos::ArrayView<const real_t> bView(in.getSampleDataRO(0),
283 in.getNumDataPoints());
284 const Teuchos::ArrayView<real_t> outView(out.getSampleDataRW(0),
285 out.getNumDataPoints());
286 solveWorker<real_t>(mat, outView, bView, sb);
287 }
288
289 }
290
291 void TrilinosMatrixAdapter::nullifyRowsAndCols(escript::Data& row_q,
292 escript::Data& col_q,
293 double mdv)
294 {
295 if (col_q.getDataPointSize() != getColumnBlockSize()) {
296 throw TrilinosAdapterException("nullifyRowsAndCols: column block size does not match the number of components of column mask.");
297 } else if (row_q.getDataPointSize() != getRowBlockSize()) {
298 throw TrilinosAdapterException("nullifyRowsAndCols: row block size does not match the number of components of row mask.");
299 } else if (col_q.getFunctionSpace() != getColumnFunctionSpace()) {
300 throw TrilinosAdapterException("nullifyRowsAndCols: column function space and function space of column mask don't match.");
301 } else if (row_q.getFunctionSpace() != getRowFunctionSpace()) {
302 throw TrilinosAdapterException("nullifyRowsAndCols: row function space and function space of row mask don't match.");
303 }
304
305 col_q.expand();
306 row_q.expand();
307 const Teuchos::ArrayView<const real_t> rowMask(row_q.getSampleDataRO(0),
308 row_q.getNumDataPoints());
309 // we need remote values for col_q
310 const Teuchos::ArrayView<const real_t> colView(col_q.getSampleDataRO(0),
311 col_q.getNumDataPoints());
312
313 // TODO:
314 if (m_isComplex)
315 throw escript::NotImplementedError("nullifyRowsAndCols: complex "
316 "version not implemented");
317
318 RCP<RealVector> lclCol = rcp(new RealVector(mat->getRowMap(), colView, colView.size(), 1));
319 RCP<RealVector> gblCol = rcp(new RealVector(mat->getColMap(), 1));
320
321 gblCol->doImport(*lclCol, *importer, Tpetra::INSERT);
322 Teuchos::ArrayRCP<const real_t> colMask(gblCol->getData(0));
323
324 resumeFill();
325 // Can't use OpenMP here as replaceLocalValues() is not thread-safe.
326 //#pragma omp parallel for
327 for (LO lclrow = 0; lclrow < mat->getNodeNumRows(); lclrow++) {
328 Teuchos::ArrayView<const LO> indices;
329 Teuchos::ArrayView<const real_t> values;
330 std::vector<GO> cols;
331 std::vector<real_t> vals;
332 mat->getLocalRowView(lclrow, indices, values);
333 GO row = mat->getRowMap()->getGlobalElement(lclrow);
334 for (size_t c = 0; c < indices.size(); c++) {
335 const LO lclcol = indices[c];
336 const GO col = mat->getColMap()->getGlobalElement(lclcol);
337 if (rowMask[lclrow] > 0. || colMask[lclcol] > 0.) {
338 cols.push_back(lclcol);
339 vals.push_back(row==col ? (real_t)mdv : (real_t)0);
340 }
341 }
342 if (cols.size() > 0)
343 mat->replaceLocalValues(lclrow, cols, vals);
344 }
345 fillComplete(true);
346 }
347
348 void TrilinosMatrixAdapter::saveMM(const std::string& filename) const
349 {
350 if (m_isComplex) {
351 Tpetra::MatrixMarket::Writer<RealMatrix>::writeSparseFile(filename, mat);
352 } else {
353 Tpetra::MatrixMarket::Writer<ComplexMatrix>::writeSparseFile(filename, cmat);
354 }
355 }
356
357 void TrilinosMatrixAdapter::saveHB(const std::string& filename) const
358 {
359 throw escript::NotImplementedError("Harwell-Boeing interface not available.");
360 }
361
362 void TrilinosMatrixAdapter::resetValues()
363 {
364 resumeFill();
365 if (m_isComplex) {
366 cmat->setAllToScalar(static_cast<const cplx_t>(0.));
367 } else {
368 mat->setAllToScalar(static_cast<const real_t>(0.));
369 }
370 fillComplete(true);
371 }
372
373 } // end of namespace
374

  ViewVC Help
Powered by ViewVC 1.1.26