42 #ifndef THYRA_TPETRA_MULTIVECTOR_HPP
43 #define THYRA_TPETRA_MULTIVECTOR_HPP
45 #include "Thyra_TpetraMultiVector_decl.hpp"
46 #include "Thyra_TpetraVectorSpace.hpp"
47 #include "Thyra_TpetraVector.hpp"
48 #include "Teuchos_Assert.hpp"
57 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
62 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
66 const RCP<Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> > &tpetraMultiVector
69 initializeImpl(tpetraVectorSpace, domainSpace, tpetraMultiVector);
73 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
77 const RCP<
const Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> > &tpetraMultiVector
80 initializeImpl(tpetraVectorSpace, domainSpace, tpetraMultiVector);
84 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
88 return tpetraMultiVector_.getNonconstObj();
92 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
96 return tpetraMultiVector_;
103 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
114 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
118 tpetraMultiVector_.getNonconstObj()->putScalar(alpha);
122 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
126 auto tmv = this->getConstTpetraMultiVector(Teuchos::rcpFromRef(mv));
131 tpetraMultiVector_.getNonconstObj()->assign(*tmv);
134 tpetraMultiVector_.getNonconstObj()->template sync<Kokkos::HostSpace>();
135 tpetraMultiVector_.getNonconstObj()->template modify<Kokkos::HostSpace>();
141 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
145 tpetraMultiVector_.getNonconstObj()->scale(alpha);
149 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
155 auto tmv = this->getConstTpetraMultiVector(Teuchos::rcpFromRef(mv));
161 tpetraMultiVector_.getNonconstObj()->update(alpha, *tmv, ST::one());
164 tpetraMultiVector_.getNonconstObj()->template sync<Kokkos::HostSpace>();
165 tpetraMultiVector_.getNonconstObj()->template modify<Kokkos::HostSpace>();
171 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
183 typedef Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> TMV;
186 bool allCastsSuccessful =
true;
188 auto mvIter = mv.begin();
189 auto tmvIter = tmvs.begin();
190 for (; mvIter != mv.end(); ++mvIter, ++tmvIter) {
191 tmv = this->getConstTpetraMultiVector(Teuchos::rcpFromPtr(*mvIter));
195 allCastsSuccessful =
false;
203 auto len = tmvs.
size();
205 tpetraMultiVector_.getNonconstObj()->scale(beta);
206 }
else if (len == 1 && allCastsSuccessful) {
207 tpetraMultiVector_.getNonconstObj()->update(alpha[0], *tmvs[0], beta);
208 }
else if (len == 2 && allCastsSuccessful) {
209 tpetraMultiVector_.getNonconstObj()->update(alpha[0], *tmvs[0], alpha[1], *tmvs[1], beta);
210 }
else if (allCastsSuccessful) {
212 auto tmvIter = tmvs.begin();
213 auto alphaIter = alpha.
begin();
218 for (; tmvIter != tmvs.end(); ++tmvIter) {
219 if (tmvIter->getRawPtr() == tpetraMultiVector_.getConstObj().getRawPtr()) {
221 tmv = Teuchos::rcp(
new TMV(*tpetraMultiVector_.getConstObj(), Teuchos::Copy));
226 tmvIter = tmvs.
begin();
230 if ((tmvs.size() % 2) == 0) {
231 tpetraMultiVector_.getNonconstObj()->scale(beta);
233 tpetraMultiVector_.getNonconstObj()->update(*alphaIter, *(*tmvIter), beta);
237 for (; tmvIter != tmvs.end(); tmvIter+=2, alphaIter+=2) {
238 tpetraMultiVector_.getNonconstObj()->update(
239 *alphaIter, *(*tmvIter), *(alphaIter+1), *(*(tmvIter+1)), ST::one());
243 tpetraMultiVector_.getNonconstObj()->template sync<Kokkos::HostSpace>();
244 tpetraMultiVector_.getNonconstObj()->template modify<Kokkos::HostSpace>();
250 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
256 auto tmv = this->getConstTpetraMultiVector(Teuchos::rcpFromRef(mv));
261 tpetraMultiVector_.getConstObj()->dot(*tmv, prods);
264 tpetraMultiVector_.getNonconstObj()->template sync<Kokkos::HostSpace>();
265 tpetraMultiVector_.getNonconstObj()->template modify<Kokkos::HostSpace>();
271 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
276 tpetraMultiVector_.getConstObj()->norm1(norms);
280 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
285 tpetraMultiVector_.getConstObj()->norm2(norms);
289 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
294 tpetraMultiVector_.getConstObj()->normInf(norms);
298 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
305 return constTpetraVector<Scalar>(
307 tpetraMultiVector_->getVector(j)
312 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
319 return tpetraVector<Scalar>(
321 tpetraMultiVector_.getNonconstObj()->getVectorNonConst(j)
326 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
332 #ifdef THYRA_DEFAULT_SPMD_MULTI_VECTOR_VERBOSE_TO_ERROR_OUT
333 std::cerr <<
"\nTpetraMultiVector::subView(Range1D) const called!\n";
335 const Range1D colRng = this->validateColRange(col_rng_in);
338 this->getConstTpetraMultiVector()->subView(colRng);
341 tpetraVectorSpace<Scalar>(
342 Tpetra::createLocalMapWithNode<LocalOrdinal,GlobalOrdinal>(
343 tpetraView->getNumVectors(),
344 tpetraView->getMap()->getComm(),
345 tpetraView->getMap()->getNode()
349 return constTpetraMultiVector(
357 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
363 #ifdef THYRA_DEFAULT_SPMD_MULTI_VECTOR_VERBOSE_TO_ERROR_OUT
364 std::cerr <<
"\nTpetraMultiVector::subView(Range1D) called!\n";
366 const Range1D colRng = this->validateColRange(col_rng_in);
369 this->getTpetraMultiVector()->subViewNonConst(colRng);
372 tpetraVectorSpace<Scalar>(
373 Tpetra::createLocalMapWithNode<LocalOrdinal,GlobalOrdinal>(
374 tpetraView->getNumVectors(),
375 tpetraView->getMap()->getComm(),
376 tpetraView->getMap()->getNode()
380 return tpetraMultiVector(
388 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
394 #ifdef THYRA_DEFAULT_SPMD_MULTI_VECTOR_VERBOSE_TO_ERROR_OUT
395 std::cerr <<
"\nTpetraMultiVector::subView(ArrayView) const called!\n";
400 cols[i] = static_cast<std::size_t>(cols_in[i]);
403 this->getConstTpetraMultiVector()->subView(cols());
406 tpetraVectorSpace<Scalar>(
407 Tpetra::createLocalMapWithNode<LocalOrdinal,GlobalOrdinal>(
408 tpetraView->getNumVectors(),
409 tpetraView->getMap()->getComm(),
410 tpetraView->getMap()->getNode()
414 return constTpetraMultiVector(
422 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
428 #ifdef THYRA_DEFAULT_SPMD_MULTI_VECTOR_VERBOSE_TO_ERROR_OUT
429 std::cerr <<
"\nTpetraMultiVector::subView(ArrayView) called!\n";
434 cols[i] = static_cast<std::size_t>(cols_in[i]);
437 this->getTpetraMultiVector()->subViewNonConst(cols());
440 tpetraVectorSpace<Scalar>(
441 Tpetra::createLocalMapWithNode<LocalOrdinal,GlobalOrdinal>(
442 tpetraView->getNumVectors(),
443 tpetraView->getMap()->getComm(),
444 tpetraView->getMap()->getNode()
448 return tpetraMultiVector(
456 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
463 const Ordinal primary_global_offset
469 for (
auto itr = multi_vecs.begin(); itr != multi_vecs.end(); ++itr) {
472 Teuchos::rcp_const_cast<Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> >(
473 tmv->getConstTpetraMultiVector())->
template sync<Kokkos::HostSpace>();
478 for (
auto itr = targ_multi_vecs.begin(); itr != targ_multi_vecs.end(); ++itr) {
479 Ptr<TMV> tmv = Teuchos::ptr_dynamic_cast<TMV>(*itr);
481 tmv->getTpetraMultiVector()->template sync<Kokkos::HostSpace>();
482 tmv->getTpetraMultiVector()->template modify<Kokkos::HostSpace>();
487 primary_op, multi_vecs, targ_multi_vecs, reduct_objs, primary_global_offset);
491 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
500 typedef typename Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> TMV;
501 Teuchos::rcp_const_cast<TMV>(
502 tpetraMultiVector_.getConstObj())->
template sync<Kokkos::HostSpace>();
509 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
518 tpetraMultiVector_.getNonconstObj()->template sync<Kokkos::HostSpace>();
519 tpetraMultiVector_.getNonconstObj()->template modify<Kokkos::HostSpace>();
526 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
536 typedef typename Tpetra::MultiVector<
537 Scalar,LocalOrdinal,GlobalOrdinal,Node>::execution_space execution_space;
538 tpetraMultiVector_.getNonconstObj()->template sync<execution_space>();
589 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
593 return tpetraVectorSpace_;
597 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
602 *localValues = tpetraMultiVector_.getNonconstObj()->get1dViewNonConst();
603 *leadingDim = tpetraMultiVector_->getStride();
607 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
612 *localValues = tpetraMultiVector_->get1dView();
613 *leadingDim = tpetraMultiVector_->getStride();
617 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
627 typedef Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> TMV;
635 typedef typename TMV::execution_space execution_space;
636 Teuchos::rcp_const_cast<TMV>(X_tpetra)->template sync<execution_space>();
637 Y_tpetra->template sync<execution_space>();
638 Teuchos::rcp_const_cast<TMV>(
639 tpetraMultiVector_.getConstObj())->
template sync<execution_space>();
644 "Error, conjugation without transposition is not allowed for complex scalar types!");
649 trans = Teuchos::NO_TRANS;
652 trans = Teuchos::NO_TRANS;
658 trans = Teuchos::CONJ_TRANS;
662 Y_tpetra->template modify<execution_space>();
663 Y_tpetra->multiply(trans, Teuchos::NO_TRANS, alpha, *tpetraMultiVector_.getConstObj(), *X_tpetra, beta);
665 Teuchos::rcp_const_cast<TMV>(
666 tpetraMultiVector_.getConstObj())->
template sync<Kokkos::HostSpace>();
675 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
676 template<
class TpetraMultiVector_t>
690 tpetraVectorSpace_ = tpetraVectorSpace;
691 domainSpace_ = domainSpace;
692 tpetraMultiVector_.initialize(tpetraMultiVector);
693 this->updateSpmdSpace();
697 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
698 RCP<Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> >
702 using Teuchos::rcp_dynamic_cast;
706 RCP<TMV> tmv = rcp_dynamic_cast<TMV>(mv);
708 return tmv->getTpetraMultiVector();
711 RCP<TV> tv = rcp_dynamic_cast<TV>(mv);
713 return tv->getTpetraVector();
716 return Teuchos::null;
719 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
720 RCP<const Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> >
724 using Teuchos::rcp_dynamic_cast;
728 RCP<const TMV> tmv = rcp_dynamic_cast<const TMV>(mv);
730 return tmv->getConstTpetraMultiVector();
733 RCP<const TV> tv = rcp_dynamic_cast<const TV>(mv);
735 return tv->getConstTpetraVector();
738 return Teuchos::null;
745 #endif // THYRA_TPETRA_MULTIVECTOR_HPP