50 #ifndef _ZOLTAN2_TPETRAROWMATRIXADAPTER_HPP_
51 #define _ZOLTAN2_TPETRAROWMATRIXADAPTER_HPP_
57 #include <Tpetra_RowMatrix.hpp>
// Adapter class template; the class-head line itself is not present in this excerpt.
//   User      - the user's matrix type wrapped by the adapter (presumably a
//               Tpetra::RowMatrix, per the include and the doMigration error text; confirm).
//   UserCoord - coordinate type; defaults to User.
75 template <
typename User,
typename UserCoord=User>
79 #ifndef DOXYGEN_SHOULD_SKIP_THIS
// Publishes the coordinate type chosen by the UserCoord template parameter.
87 typedef UserCoord userCoord_t;
// Tail of the constructor declaration: nWeightsPerRow defaults to 0 (no row weights).
100 int nWeightsPerRow=0);
// Local (this-process) row count, forwarded directly to the wrapped matrix.
152 return matrix_->getNodeNumRows();
// Local column count, forwarded directly to the wrapped matrix.
156 return matrix_->getNodeNumCols();
// Local nonzero-entry count, forwarded directly to the wrapped matrix.
160 return matrix_->getNodeNumEntries();
// Row-ID view: returns a raw pointer into the row map's local list of global
// indices. No copy is made; the pointer presumably stays valid only while
// rowMap_ is alive (confirm against Tpetra::Map::getNodeElementList).
167 ArrayView<const gno_t> rowView = rowMap_->getNodeElementList();
168 rowIds = rowView.getRawPtr();
// CRS structure view: offsets and global column IDs cached by the constructor.
173 offsets = offset_.getRawPtr();
174 colIds = columnIds_.getRawPtr();
// CRS view with values: the same cached arrays plus the cached nonzero values.
180 offsets = offset_.getRawPtr();
181 colIds = columnIds_.getRawPtr();
182 values = values_.getRawPtr();
// Reject weight indices outside [0, nWeightsPerRow_) before touching rowWeights_.
191 if(idx<0 || idx >= nWeightsPerRow_)
193 std::ostringstream emsg;
194 emsg << __FILE__ <<
":" << __LINE__
195 <<
" Invalid row weight index " << idx << std::endl;
196 throw std::runtime_error(emsg.str());
// Hand back the strided weight array registered for this index via
// getStridedList (length, pointer and stride are output arguments).
201 rowWeights_[idx].getStridedList(length,
weights, stride);
// Member-template declarations parameterized on an Adapter type. The declared
// signatures are not visible here; presumably these are the two
// applyPartitioningSolution overloads defined later in this file (confirm).
206 template <
typename Adapter>
210 template <
typename Adapter>
// The user's matrix, held through an RCP to const.
216 RCP<const User> matrix_;
// Row and column maps; assigned from matrix_ in the constructor.
217 RCP<const Tpetra::Map<lno_t, gno_t, node_t> > rowMap_;
218 RCP<const Tpetra::Map<lno_t, gno_t, node_t> > colMap_;
// Local CRS copy built in the constructor:
//   offset_    - nrows+1 row offsets (offset_[0] == 0),
//   columnIds_ - global column IDs (translated through colMap_),
//   values_    - the matching nonzero values.
219 ArrayRCP<offset_t> offset_;
220 ArrayRCP<gno_t> columnIds_;
221 ArrayRCP<scalar_t> values_;
// One strided weight array per row-weight index; allocated in the constructor
// only when nWeightsPerRow_ > 0 and filled by setRowWeights.
224 ArrayRCP<StridedData<lno_t, scalar_t> > rowWeights_;
// Per-weight-index flag; set to true by setRowWeightIsNumberOfNonZeros.
225 ArrayRCP<bool> numNzWeight_;
// Initialized to true in the constructor.
227 bool mayHaveDiagonalEntries;
// Builds a redistributed copy of `from` whose local rows are myNewRows
// (numLocalRows entries). The implementation only supports Tpetra::CrsMatrix.
229 RCP<User> doMigration(
const User &from,
size_t numLocalRows,
230 const gno_t *myNewRows)
const;
// Constructor: wraps inmatrix and caches a local CRS copy of it (row offsets,
// global column IDs, values). When nWeightsPerRow > 0 it also allocates the
// per-index row-weight slots and clears their number-of-nonzeros flags.
//   inmatrix       - the user's matrix (shared, read-only).
//   nWeightsPerRow - number of weights per row that callers may set.
237 template <
typename User,
typename UserCoord>
239 const RCP<const User> &inmatrix,
int nWeightsPerRow):
240 matrix_(inmatrix), rowMap_(), colMap_(),
241 offset_(), columnIds_(),
242 nWeightsPerRow_(nWeightsPerRow), rowWeights_(), numNzWeight_(),
243 mayHaveDiagonalEntries(true)
247 rowMap_ = matrix_->getRowMap();
248 colMap_ = matrix_->getColMap();
250 size_t nrows = matrix_->getNodeNumRows();
// Total local nonzeros; used only to size the CRS arrays below.
251 size_t nnz = matrix_->getNodeNumEntries();
252 size_t maxnumentries =
253 matrix_->getNodeMaxNumRowEntries();
255 offset_.resize(nrows+1, 0);
256 columnIds_.resize(nnz);
// NOTE(review): no values_ resize is visible in this excerpt, yet the loop
// below writes values_[next]. values_ must be sized to nnz beforehand; confirm
// the resize exists in the full file.
// Scratch buffers sized for the longest local row, reused for every row.
258 ArrayRCP<lno_t> indices(maxnumentries);
259 ArrayRCP<scalar_t> nzs(maxnumentries);
// NOTE(review): `row` and `next` are used below but are not declared in the
// visible lines.
261 for (
size_t i=0; i < nrows; i++){
// getLocalRowCopy overwrites nnz with this row's entry count. The local total
// is no longer needed once the arrays above are sized.
263 matrix_->getLocalRowCopy(row, indices(), nzs(), nnz);
264 for (
size_t j=0; j < nnz; j++){
265 values_[next] = nzs[j];
// Translate each local column index to its global ID via the column map.
268 columnIds_[next++] = colMap_->getGlobalElement(indices[j]);
// Prefix sum: row i+1 starts where row i ends.
270 offset_[i+1] = offset_[i] + nnz;
273 if (nWeightsPerRow_ > 0){
// Owned arrays (last arcp argument true) holding one weight slot and one flag
// per weight index.
274 rowWeights_ = arcp(
new input_t [nWeightsPerRow_], 0, nWeightsPerRow_,
true);
275 numNzWeight_ = arcp(
new bool [nWeightsPerRow_], 0, nWeightsPerRow_,
true);
276 for (
int i=0; i < nWeightsPerRow_; i++)
277 numNzWeight_[i] =
false;
// Sets weights for the primary entity type. Only MATRIX_ROW is supported, and
// that case forwards to setRowWeights. Any other entity type throws
// std::runtime_error.
//   weightVal - caller-owned weight array.
//   stride    - distance between consecutive weights in weightVal.
//   idx       - which weight index to set.
282 template <
typename User,
typename UserCoord>
284 const scalar_t *weightVal,
int stride,
int idx)
286 if (this->getPrimaryEntityType() ==
MATRIX_ROW)
287 setRowWeights(weightVal, stride, idx);
// NOTE(review): this message uses "," between file and line, while the
// index-check messages elsewhere use ":".
290 std::ostringstream emsg;
291 emsg << __FILE__ <<
"," << __LINE__
292 <<
" error: setWeights not yet supported for"
293 <<
" columns or nonzeros."
295 throw std::runtime_error(emsg.str());
// Registers a strided array of row weights for weight index idx.
// Throws std::runtime_error if idx is outside [0, nWeightsPerRow_).
300 template <
typename User,
typename UserCoord>
302 const scalar_t *weightVal,
int stride,
int idx)
305 if(idx<0 || idx >= nWeightsPerRow_)
307 std::ostringstream emsg;
308 emsg << __FILE__ <<
":" << __LINE__
309 <<
" Invalid row weight index " << idx << std::endl;
310 throw std::runtime_error(emsg.str());
313 size_t nvtx = getLocalNumRows();
// Non-owning view (last argument false) of nvtx*stride entries. The adapter
// does not copy weightVal, so the caller must keep it alive as long as the
// adapter uses these weights.
314 ArrayRCP<const scalar_t> weightV(weightVal, 0, nvtx*stride,
false);
315 rowWeights_[idx] = input_t(weightV, stride);
// Marks weight index idx as "number of nonzeros" for the primary entity type.
// Only MATRIX_ROW is supported (forwards to setRowWeightIsNumberOfNonZeros).
// Any other entity type throws std::runtime_error.
319 template <
typename User,
typename UserCoord>
323 if (this->getPrimaryEntityType() ==
MATRIX_ROW)
324 setRowWeightIsNumberOfNonZeros(idx);
327 std::ostringstream emsg;
328 emsg << __FILE__ <<
"," << __LINE__
329 <<
" error: setWeightIsNumberOfNonZeros not yet supported for"
330 <<
" columns" << std::endl;
331 throw std::runtime_error(emsg.str());
// Flags row-weight index idx as "use the row's nonzero count".
// Throws std::runtime_error if idx is outside [0, nWeightsPerRow_).
336 template <
typename User,
typename UserCoord>
340 if(idx<0 || idx >= nWeightsPerRow_)
342 std::ostringstream emsg;
343 emsg << __FILE__ <<
":" << __LINE__
344 <<
" Invalid row weight index " << idx << std::endl;
345 throw std::runtime_error(emsg.str());
349 numNzWeight_[idx] =
true;
// Applies a partitioning solution to `in`, producing a migrated matrix through
// the raw-pointer output `out`.
353 template <
typename User,
typename UserCoord>
354 template <
typename Adapter>
356 const User &in, User *&out,
// Global IDs of the rows this process should own after migration. The list is
// filled by a helper call whose name is not visible in this excerpt.
361 ArrayRCP<gno_t> importList;
365 (solution,
this, importList);
// NOTE(review): numNewRows is not declared in the visible lines; presumably
// the size of importList (confirm). How outPtr is handed to `out` is also not
// visible here.
370 RCP<User> outPtr = doMigration(in, numNewRows, importList.getRawPtr());
// Applies a partitioning solution to `in`; the migrated matrix is returned
// through the RCP output `out`.
376 template <
typename User,
typename UserCoord>
377 template <
typename Adapter>
379 const User &in, RCP<User> &out,
// Global IDs of the rows this process should own after migration. The list is
// filled by a helper call whose name is not visible in this excerpt.
384 ArrayRCP<gno_t> importList;
388 (solution,
this, importList);
// NOTE(review): numNewRows is not declared in the visible lines; presumably
// the size of importList (confirm).
393 out = doMigration(in, numNewRows, importList.getRawPtr());
// Builds a new matrix holding `from` redistributed so that this process owns
// the numLocalRows global rows listed in myNewRows. Only Tpetra::CrsMatrix
// inputs are supported; other RowMatrix types throw std::logic_error.
398 template <
typename User,
typename UserCoord>
402 const gno_t *myNewRows
405 typedef Tpetra::Map<lno_t, gno_t, node_t> map_t;
406 typedef Tpetra::CrsMatrix<scalar_t, lno_t, gno_t, node_t> tcrsmatrix_t;
// Migration is implemented only for CrsMatrix; detect it with a checked downcast.
417 const tcrsmatrix_t *pCrsMatrix = dynamic_cast<const tcrsmatrix_t *>(&from);
420 throw std::logic_error(
"TpetraRowMatrixAdapter cannot migrate data for "
421 "your RowMatrix; it can migrate data only for "
422 "Tpetra::CrsMatrix. "
423 "You can inherit from TpetraRowMatrixAdapter and "
424 "implement migration for your RowMatrix.");
// Source map (current row distribution). The target map keeps the same global
// row count and takes its index base from the source map's minimum global
// index.
428 const RCP<const map_t> &smap = from.getRowMap();
429 gno_t numGlobalRows = smap->getGlobalNumElements();
430 gno_t base = smap->getMinAllGlobalIndex();
// Target map: this process owns exactly the rows in myNewRows. Presumably the
// lists across all processes must together cover numGlobalRows rows (confirm).
433 ArrayView<const gno_t> rowList(myNewRows, numLocalRows);
434 const RCP<const Teuchos::Comm<int> > &comm = from.getComm();
435 RCP<const map_t> tmap = rcp(
new map_t(numGlobalRows, rowList, base, comm));
// One importer (source -> target) is reused for both the row counts and the
// matrix data.
438 Tpetra::Import<lno_t, gno_t, node_t> importer(smap, tmap);
// NOTE(review): size_t -> int narrowing of the element counts.
458 int oldNumElts = smap->getNodeNumElements();
459 int newNumElts = numLocalRows;
// Ship per-row entry counts to the new owners so the target matrix can be
// preallocated. Counts travel as scalar_t, which is exact only if scalar_t
// represents these integers exactly (NOTE(review)).
462 typedef Tpetra::Vector<scalar_t, lno_t, gno_t, node_t> vector_t;
463 vector_t numOld(smap);
464 vector_t numNew(tmap);
465 for (
int lid=0; lid < oldNumElts; lid++){
466 numOld.replaceGlobalValue(smap->getGlobalElement(lid),
467 scalar_t(from.getNumEntriesInLocalRow(lid)));
469 numNew.doImport(numOld, importer, Tpetra::INSERT);
// Convert the imported counts back to per-row allocation sizes.
472 ArrayRCP<size_t> nnz(newNumElts);
474 ArrayRCP<scalar_t> ptr = numNew.getDataNonConst(0);
475 for (
int lid=0; lid < newNumElts; lid++){
476 nnz[lid] = static_cast<size_t>(ptr[lid]);
// Preallocate the target matrix with the imported counts (StaticProfile),
// then import the matrix entries through the same importer.
480 RCP<tcrsmatrix_t> M = rcp(
new tcrsmatrix_t(tmap, nnz,
481 Tpetra::StaticProfile));
482 M->doImport(from, importer, Tpetra::INSERT);
// Return the new CrsMatrix as the user's type.
486 return Teuchos::rcp_dynamic_cast<User>(M);