/[escript]/trunk/trilinoswrap/src/CrsMatrixWrapper.cpp
ViewVC logotype

Contents of /trunk/trilinoswrap/src/CrsMatrixWrapper.cpp

Parent Directory | Revision Log


Revision 6408 - (show annotations)
Thu Oct 27 03:29:59 2016 UTC (19 months, 3 weeks ago) by gross
File size: 10146 byte(s)
Workaround for Bug #389
1
2 /*****************************************************************************
3 *
4 * Copyright (c) 2016 by The University of Queensland
5 * http://www.uq.edu.au
6 *
7 * Primary Business: Queensland, Australia
8 * Licensed under the Apache License, version 2.0
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Development until 2012 by Earth Systems Science Computational Center (ESSCC)
12 * Development 2012-2013 by School of Earth Sciences
13 * Development from 2014 by Centre for Geoscience Computing (GeoComp)
14 *
15 *****************************************************************************/
16
17 #include "CrsMatrixWrapper.h"
18 #include "Amesos2Wrapper.h"
19 #include "BelosWrapper.h"
20 #include "PreconditionerFactory.h"
21 #include "TrilinosAdapterException.h"
22 #include "util.h"
23
24 #include <escript/SolverOptions.h>
25
26 #include <Kokkos_DefaultNode.hpp>
27 #include <MatrixMarket_Tpetra.hpp>
28 #include <MueLu_CreateTpetraPreconditioner.hpp>
29
30 #include <Tpetra_DefaultPlatform.hpp>
31 #include <Tpetra_Vector.hpp>
32
33 using Teuchos::RCP;
34 using Teuchos::rcp;
35 using Teuchos::rcpFromRef;
36
37 namespace esys_trilinos {
38
39 template<typename ST>
40 CrsMatrixWrapper<ST>::CrsMatrixWrapper(const_TrilinosGraph_ptr graph) :
41 mat(graph),
42 m_resetCalled(false)
43 {
44 mat.fillComplete();
45 maxLocalRow = graph->getRowMap()->getMaxLocalIndex();
46 }
47
48 template<typename ST>
49 void CrsMatrixWrapper<ST>::fillComplete(bool localOnly)
50 {
51 RCP<Teuchos::ParameterList> params = Teuchos::parameterList();
52 params->set("No Nonlocal Changes", localOnly);
53 mat.fillComplete(params);
54 }
55
56 template<typename ST>
57 void CrsMatrixWrapper<ST>::add(const std::vector<LO>& rowIdx,
58 const std::vector<ST>& array)
59 {
60 const size_t emSize = rowIdx.size();
61 std::vector<LO> cols(emSize);
62 std::vector<ST> vals(emSize);
63 for (size_t i = 0; i < emSize; i++) {
64 const LO row = rowIdx[i];
65 if (row <= maxLocalRow) {
66 for (int j = 0; j < emSize; j++) {
67 const LO col = rowIdx[j];
68 cols[j] = col;
69 const size_t srcIdx = j * emSize + i;
70 vals[j] = array[srcIdx];
71 }
72 mat.sumIntoLocalValues(row, cols, vals);
73 }
74 }
75 }
76
77 template<typename ST>
78 void CrsMatrixWrapper<ST>::ypAx(const Teuchos::ArrayView<ST>& y,
79 const Teuchos::ArrayView<const ST>& x) const
80 {
81 RCP<VectorType<ST> > X = rcp(new VectorType<ST>(mat.getRowMap(), x, x.size(), 1));
82 RCP<VectorType<ST> > Y = rcp(new VectorType<ST>(mat.getRowMap(), y, y.size(), 1));
83
84 const ST alpha = Teuchos::ScalarTraits<ST>::one();
85 const ST beta = Teuchos::ScalarTraits<ST>::one();
86
87 // Y = beta*Y + alpha*A*X
88 mat.apply(*X, *Y, Teuchos::NO_TRANS, alpha, beta);
89 Y->get1dCopy(y, y.size());
90 }
91
92 template<typename ST>
93 void CrsMatrixWrapper<ST>::solve(const Teuchos::ArrayView<ST>& x,
94 const Teuchos::ArrayView<const ST>& b,
95 escript::SolverBuddy& sb) const
96 {
97 typedef VectorType<ST> Vector;
98
99 RCP<Vector> X = rcp(new Vector(mat.getDomainMap(), 1));
100 RCP<Vector> B = rcp(new Vector(mat.getRangeMap(), b, b.size(), 1));
101 RCP<const Matrix> A = rcpFromRef(mat);
102
103 if (escript::isDirectSolver(sb.getSolverMethod())) {
104 RCP<DirectSolverType<Matrix,Vector> > solver(m_direct);
105 if (solver.is_null()) {
106 solver = createDirectSolver<Matrix,Vector>(sb, A, X, B);
107 m_direct = solver;
108 if (sb.isVerbose()) {
109 std::cout << "Using " << solver->description() << std::endl;
110 std::cout << "Performing symbolic factorization..." << std::flush;
111 }
112 solver->symbolicFactorization();
113 if (sb.isVerbose()) {
114 std::cout << "done\nPerforming numeric factorization..." << std::flush;
115 }
116 solver->numericFactorization();
117 if (sb.isVerbose()) {
118 std::cout << "done\n" << std::flush;
119 }
120 } else {
121 if (sb.isVerbose()) {
122 std::cout << "Using " << solver->description() << std::endl;
123 }
124 if (m_resetCalled) {
125 // matrix structure never changes
126 solver->setA(A, Amesos2::SYMBFACT);
127 m_resetCalled = false;
128 }
129 solver->setX(X);
130 solver->setB(B);
131 }
132 if (sb.isVerbose()) {
133 std::cout << "Solving system..." << std::flush;
134 }
135 solver->solve();
136 if (sb.isVerbose()) {
137 std::cout << "done" << std::endl;
138 RCP<Teuchos::FancyOStream> fos(Teuchos::fancyOStream(Teuchos::rcpFromRef(std::cout)));
139 solver->printTiming(*fos, Teuchos::VERB_HIGH);
140 }
141
142 } else { // iterative solver
143 double t0 = Teuchos::Time::wallTime();
144 RCP<ProblemType<ST> > problem(m_solver);
145 if (problem.is_null()) {
146 problem = rcp(new ProblemType<ST>(A, X, B));
147 m_solver = problem;
148 RCP<OpType<ST> > prec = createPreconditioner<ST>(A, sb);
149 m_preconditioner = prec;
150 if (!prec.is_null()) {
151 // Trilinos BiCGStab does not support left preconditioners
152 if (sb.getSolverMethod() == escript::SO_METHOD_BICGSTAB)
153 problem->setRightPrec(prec);
154 else
155 problem->setLeftPrec(prec);
156 }
157 problem->setHermitian(sb.isSymmetric());
158 problem->setProblem();
159 } else {
160 for (auto t: problem->getTimers()) {
161 t->reset();
162 }
163 if (m_resetCalled) {
164 // special case for MueLu preconditioner - call Reuse...
165 // which honours the "reuse: type" parameter.
166 RCP<MueLu::TpetraOperator<ST,LO,GO,NT> > mlOp =
167 Teuchos::rcp_dynamic_cast<MueLu::TpetraOperator<ST,LO,GO,NT> >(m_preconditioner);
168 if (mlOp.get()) {
169 RCP<Matrix> A_(Teuchos::rcp_const_cast<Matrix>(A));
170 MueLu::ReuseTpetraPreconditioner(A_, *mlOp);
171 }
172 }
173 problem->setProblem(X, B);
174 }
175
176 double t1 = Teuchos::Time::wallTime();
177 RCP<SolverType<ST> > solver = createSolver<ST>(sb);
178 if (sb.isVerbose()) {
179 std::cout << "Using " << solver->description() << std::endl;
180 }
181 solver->setProblem(problem);
182 Belos::ReturnType result = solver->solve();
183 double t2 = Teuchos::Time::wallTime();
184 const int numIters = solver->getNumIters();
185 double tol = sb.getTolerance();
186 try {
187 tol = solver->achievedTol();
188 } catch (...) {
189 }
190 if (sb.isVerbose()) {
191 if (result == Belos::Converged) {
192 sb.updateDiagnostics("converged", true);
193 std::cout << "The solver took " << numIters
194 << " iteration(s) to reach a residual tolerance of "
195 << tol << "." << std::endl;
196 } else {
197 std::cout << "The solver took " << numIters
198 << " iteration(s), but did not reach a relative residual "
199 "tolerance of " << sb.getTolerance() << "." << std::endl;
200 }
201 }
202 double solverTime = 0.;
203 for (auto t: problem->getTimers()) {
204 solverTime += t->totalElapsedTime();
205 }
206 sb.updateDiagnostics("set_up_time", t1-t0);
207 sb.updateDiagnostics("net_time", solverTime);
208 sb.updateDiagnostics("time", t2-t0);
209 sb.updateDiagnostics("num_iter", numIters);
210 sb.updateDiagnostics("residual_norm", tol);
211 }
212 X->get1dCopy(x, x.size());
213 }
214
215 template<typename ST>
216 void CrsMatrixWrapper<ST>::nullifyRowsAndCols(
217 const Teuchos::ArrayView<const real_t>& rowMask,
218 const Teuchos::ArrayView<const real_t>& colView,
219 ST mdv)
220 {
221 const_TrilinosMap_ptr rowMap(mat.getRowMap());
222 RCP<VectorType<real_t> > lclCol = rcp(new VectorType<real_t>(rowMap,
223 colView, colView.size(), 1));
224 RCP<VectorType<real_t> > gblCol = rcp(new VectorType<real_t>(
225 mat.getColMap(), 1));
226
227 const ImportType importer(rowMap, mat.getColMap());
228 gblCol->doImport(*lclCol, importer, Tpetra::INSERT);
229 Teuchos::ArrayRCP<const real_t> colMask(gblCol->getData(0));
230 const ST zero = Teuchos::ScalarTraits<ST>::zero();
231
232 resumeFill();
233 // Can't use OpenMP here as replaceLocalValues() is not thread-safe.
234 //#pragma omp parallel for
235 for (LO lclrow = 0; lclrow < mat.getNodeNumRows(); lclrow++) {
236 Teuchos::ArrayView<const LO> indices;
237 Teuchos::ArrayView<const ST> values;
238 std::vector<GO> cols;
239 std::vector<ST> vals;
240 mat.getLocalRowView(lclrow, indices, values);
241 GO row = rowMap->getGlobalElement(lclrow);
242 for (size_t c = 0; c < indices.size(); c++) {
243 const LO lclcol = indices[c];
244 const GO col = mat.getColMap()->getGlobalElement(lclcol);
245 if (rowMask[lclrow] > 0. || colMask[lclcol] > 0.) {
246 cols.push_back(lclcol);
247 vals.push_back(row==col ? mdv : zero);
248 }
249 }
250 if (cols.size() > 0)
251 mat.replaceLocalValues(lclrow, cols, vals);
252 }
253 fillComplete(true);
254 }
255
256 template<typename ST>
257 void CrsMatrixWrapper<ST>::saveMM(const std::string& filename) const
258 {
259 Tpetra::MatrixMarket::Writer<Matrix>::writeSparseFile(filename, rcpFromRef(mat));
260 }
261
262 template<typename ST>
263 void CrsMatrixWrapper<ST>::resetValues(bool preserveSolverData)
264 {
265 resumeFill();
266 mat.setAllToScalar(static_cast<ST>(0.));
267 fillComplete(true);
268 if (!preserveSolverData) {
269 m_solver.reset();
270 m_preconditioner.reset();
271 }
272 m_resetCalled = true;
273 }
274
275
276 // instantiate the supported variants
277 template class CrsMatrixWrapper<real_t>;
278 template class CrsMatrixWrapper<cplx_t>;
279
280 } // end of namespace
281

  ViewVC Help
Powered by ViewVC 1.1.26