6 #ifndef OPENGM_EXTERNAL_TRWS_HXX 7 #define OPENGM_EXTERNAL_TRWS_HXX 17 #include "MRFEnergy.h" 18 #include "instances.h" 19 #include "MRFEnergy.cpp" 20 #include "minimize.cpp" 21 #include "treeProbabilities.cpp" 22 #include "ordering.cpp" 55 template<
class _GM,
class _ACC>
83 : numberOfIterations_(p.numberOfIterations_),
84 useRandomStart_(p.useRandomStart_),
85 useZeroStart_(p.useZeroStart_),
88 tolerance_(p.tolerance_),
89 minDualChange_(p.minDualChange_)
94 else if(p.energyType_==1){
97 else if(p.energyType_==2){
100 else if(p.energyType_==3){
106 numberOfIterations_ = 1000;
107 useRandomStart_ =
false;
108 useZeroStart_ =
false;
112 minDualChange_ = 0.00001;
113 calculateMinMarginals_ =
false;
121 std::string
name()
const;
124 template<
class VISITOR>
129 typename GM::ValueType
bound()
const;
130 typename GM::ValueType
value()
const;
132 const GraphicalModelType& gm_;
144 TypeGeneral::REAL* minMarginals_;
145 size_t* minMarginalsOffsets_;
151 std::vector<LabelType> state_;
154 bool hasSameLabelNumber_;
155 void checkLabelNumber();
157 void generateMRFView();
158 void generateMRFTables();
159 void generateMRFTL1();
160 void generateMRFTL2();
166 bool truncatedAbsoluteDifferenceFactors()
const;
169 bool truncatedSquaredDifferenceFactors()
const;
171 template <
class ENERGYTYPE>
174 template<
class VISITOR,
class ENERGYTYPE>
178 template<
class GM,
class ENERGYTYPE>
180 static void*
create(
typename GM::IndexType numLabels);
185 static void*
create(
typename GM::IndexType numLabels);
190 static void*
create(
typename GM::IndexType numLabels);
195 static void*
create(
typename GM::IndexType numLabels);
200 static void*
create(
typename GM::IndexType numLabels);
203 template<
class GM,
class ENERGYTYPE>
233 : gm_(gm), parameter_(para), mrfView_(NULL), nodesView_(NULL), mrfGeneral_(NULL), nodesGeneral_(NULL),
234 mrfTL1_(NULL), nodesTL1_(NULL), mrfTL2_(NULL), nodesTL2_(NULL), minMarginals_(NULL), minMarginalsOffsets_(NULL),
235 numNodes_(gm_.numberOfVariables()), maxNumLabels_(gm_.numberOfLabels(0)) {
240 minMarginalsOffsets_ =
new size_t[gm_.numberOfVariables()];
241 for(
size_t i=0; i<gm_.numberOfVariables(); ++i){
242 minMarginalsOffsets_[i] = count;
243 count += gm_.numberOfLabels(i);
245 minMarginals_ =
new TypeGeneral::REAL[count];
259 if(!hasSameLabelNumber_) {
260 throw(
RuntimeError(
"TRWS TL1 only supports graphical models where each variable has the same number of states."));
266 if(!hasSameLabelNumber_) {
267 throw(
RuntimeError(
"TRWS TL2 only supports graphical models where each variable has the same number of states."));
299 delete[] nodesGeneral_;
316 delete[] minMarginals_;
317 delete[] minMarginalsOffsets_;
341 return this->
infer(visitor);
345 template<
class VISITOR>
353 return inferImpl(visitor, mrfView_);
357 return inferImpl(visitor, mrfGeneral_);
361 return inferImpl(visitor, mrfTL1_);
365 return inferImpl(visitor, mrfTL2_);
382 std::vector<LabelType>&
arg,
390 arg.resize(numNodes_);
393 for(
IndexType i = 0; i < numNodes_; i++) {
394 arg[i] = mrfView_->GetSolution(nodesView_[i]);
400 for(
IndexType i = 0; i < numNodes_; i++) {
401 arg[i] = mrfGeneral_->GetSolution(nodesGeneral_[i]);
407 for(
IndexType i = 0; i < numNodes_; i++) {
408 arg[i] = mrfTL1_->GetSolution(nodesTL1_[i]);
414 for(
IndexType i = 0; i < numNodes_; i++) {
415 arg[i] = mrfTL2_->GetSolution(nodesTL2_[i]);
440 const size_t variableIndex,
446 out.assign(gm_, &variableIndex, &variableIndex+1, 0);
447 for(
size_t i=0; i<gm_.numberOfLabels(variableIndex); ++i){
448 out(i) = minMarginals_[i+minMarginalsOffsets_[variableIndex]];
457 inline typename GM::ValueType
459 return lowerBound_+constTerm_;
462 inline typename GM::ValueType
464 return value_+constTerm_;
469 hasSameLabelNumber_ =
true;
470 for(
IndexType i = 1; i < gm_.numberOfVariables(); i++) {
471 if(gm_.numberOfLabels(i) != maxNumLabels_) {
472 hasSameLabelNumber_ =
false;
474 if(gm_.numberOfLabels(i) > maxNumLabels_) {
475 maxNumLabels_ = gm_.numberOfLabels(i);
486 for(
IndexType i = 0; i < numNodes_; i++) {
487 std::vector<typename GM::IndexType> factors;
488 for(
typename GM::ConstFactorIterator iter = gm_.factorsOfVariableBegin(i); iter != gm_.factorsOfVariableEnd(i); iter++) {
489 if(gm_[*iter].numberOfVariables() == 1) {
490 factors.push_back(*iter);
498 for(
IndexType i = 0; i < gm_.numberOfFactors(); i++) {
499 if(gm_[i].numberOfVariables() == 0){
501 constTerm_ += gm_[i](&l);
503 if(gm_[i].numberOfVariables() == 2) {
511 mrfView_->AddRandomMessages(1, 0.0, 1.0);
513 mrfView_->ZeroMessages();
520 typename TypeGeneral::REAL* D =
new typename TypeGeneral::REAL[maxNumLabels_];
521 addNodes(mrfGeneral_, nodesGeneral_, D);
527 for(
IndexType i = 0; i < gm_.numberOfFactors(); i++) {
528 if(gm_[i].numberOfVariables() == 0){
530 constTerm_ += gm_[i](&l);
532 if(gm_[i].numberOfVariables() == 2) {
535 IndexType numLabels_a = gm_.numberOfLabels(a);
536 IndexType numLabels_b = gm_.numberOfLabels(b);
537 typename TypeGeneral::REAL* V =
new typename TypeGeneral::REAL[numLabels_a * numLabels_b];
538 for(
size_t j = 0; j < numLabels_a; j++) {
539 for(
size_t k = 0; k < numLabels_b; k++) {
542 V[j + k * numLabels_a] = gm_[i](index);
545 mrfGeneral_->AddEdge(nodesGeneral_[a], nodesGeneral_[b], TypeGeneral::EdgeData(TypeGeneral::GENERAL, V));
552 mrfGeneral_->AddRandomMessages(1, 0.0, 1.0);
554 mrfGeneral_->ZeroMessages();
563 typename TypeTruncatedLinear::REAL* D =
new typename TypeTruncatedLinear::REAL[maxNumLabels_];
564 addNodes(mrfTL1_, nodesTL1_, D);
569 for(
IndexType i = 0; i < gm_.numberOfFactors(); i++) {
570 if(gm_[i].numberOfVariables() == 0){
572 constTerm_ += gm_[i](&l);
574 if(gm_[i].numberOfVariables() == 2) {
587 mrfTL1_->AddEdge(nodesTL1_[a], nodesTL1_[b], TypeTruncatedLinear::EdgeData(w, w * t));
593 mrfTL1_->AddRandomMessages(1, 0.0, 1.0);
595 mrfTL1_->ZeroMessages();
604 typename TypeTruncatedQuadratic::REAL* D =
new typename TypeTruncatedQuadratic::REAL[maxNumLabels_];
605 addNodes(mrfTL2_, nodesTL2_, D);
610 for(
IndexType i = 0; i < gm_.numberOfFactors(); i++) {
611 if(gm_[i].numberOfVariables() == 0){
613 constTerm_ += gm_[i](&l);
615 if(gm_[i].numberOfVariables() == 2) {
628 mrfTL2_->AddEdge(nodesTL2_[a], nodesTL2_[b], TypeTruncatedQuadratic::EdgeData(w, w * t));
636 mrfTL2_->AddRandomMessages(1, 0.0, 1.0);
638 mrfTL2_->ZeroMessages();
652 IndexType index0[] = {0, maxNumLabels_-1};
654 return gm_[factor](index0)/gm_[factor](index1);
659 for(
IndexType i = 0; i < gm_.numberOfFactors(); i++) {
660 if(gm_.numberOfVariables(i) == 2) {
661 if(gm_[i].isTruncatedAbsoluteDifference() ==
false) {
671 for(
IndexType i = 0; i < gm_.numberOfFactors(); i++) {
672 if(gm_.numberOfVariables(i) == 2) {
673 if(gm_[i].isTruncatedSquaredDifference() ==
false) {
682 template <
class ENERGYTYPE>
688 for(
IndexType i = 0; i < numNodes_; i++) {
689 for(
IndexType j = 0; j < gm_.numberOfLabels(i); j++) {
692 for(
typename GM::ConstFactorIterator iter = gm_.factorsOfVariableBegin(i); iter != gm_.factorsOfVariableEnd(i); iter++) {
693 if(gm_[*iter].numberOfVariables() == 1) {
694 for(
IndexType j = 0; j < gm_.numberOfLabels(i); j++) {
695 D[j] += gm_[*iter](&j);
703 template<
class GM,
class ENERGYTYPE>
729 template<
class GM,
class ENERGYTYPE>
742 return mrf->AddNode(
typename TypeGeneral::LocalSize(numLabels),
typename TypeGeneral::NodeData(D));
747 return mrf->AddNode(
typename TypeTruncatedLinear::LocalSize(),
typename TypeTruncatedLinear::NodeData(D));
752 return mrf->AddNode(
typename TypeTruncatedQuadratic::LocalSize(),
typename TypeTruncatedQuadratic::NodeData(D));
756 template<
class VISITOR,
class ENERGYTYPE>
759 options.m_iterMax = 1;
761 visitor.begin(*
this);
765 typename ENERGYTYPE::REAL v;
767 mrf->Minimize_BP(options, v, minMarginals_);
774 typename ENERGYTYPE::REAL v;
775 typename ENERGYTYPE::REAL b;
776 typename ENERGYTYPE::REAL d;
778 mrf->Minimize_TRW_S(options, b, v, minMarginals_);
785 if(fabs(value_ - lowerBound_) /
opengmMax(static_cast<double>(fabs(value_)), 1.0) < parameter_.
tolerance_) {
802 #endif // #ifndef OPENGM_EXTERNAL_TRWS_HXX
InferenceTermination marginal(const size_t variableIndex, IndependentFactorType &out) const
output a solution for a marginal for a specific variable
const GraphicalModelType & graphicalModel() const
GM::ValueType value() const
return the solution (value)
bool calculateMinMarginals_
Calculate MinMarginals.
visitors::EmptyVisitor< TRWS< GM > > EmptyVisitorType
bool useZeroStart_
zero starting message
T opengmMax(const T &x, const T &y)
InferenceTermination infer()
InferenceTermination arg(std::vector< LabelType > &, const size_t &=1) const
bool useRandomStart_
random starting message
void create(const hid_t &, const std::string &, ShapeIterator, ShapeIterator, CoordinateOrder)
Create and close an HDF5 dataset to store Marray data.
#define OPENGM_ASSERT(expression)
static MRFEnergy< ENERGYTYPE >::NodeId add(MRFEnergy< ENERGYTYPE > *mrf, typename GM::IndexType numLabels, typename ENERGYTYPE::REAL *D)
double minDualChange_
TRWS terminates if fabs(bound(t)-bound(t+1)) < minDualChange_.
GraphicalModelType::IndexType IndexType
EnergyType
possible energy types for TRWS
GraphicalModelType::ValueType ValueType
static T ineutral()
inverse neutral element (with return)
Inference algorithm interface.
GM::ValueType bound() const
return a bound on the solution
double tolerance_
TRWS terminates if fabs(value - bound) / max(fabs(value), 1) < tolerance_.
size_t numberOfIterations_
number of iterations
visitors::VerboseVisitor< TRWS< GM > > VerboseVisitorType
static T neutral()
neutral element (with return)
TRWS(const GraphicalModelType &gm, const Parameter para=Parameter())
Minimization as a unary accumulation.
static void * create(typename GM::IndexType numLabels)
message passing (BPS, TRWS): [?]
static const size_t ContinueInf
bool doBPS_
use normal LBP
Parameter(const P &p)
Constructor.
EnergyType energyType_
selected energy type
opengm::Minimizer AccumulationType
visitors::TimingVisitor< TRWS< GM > > TimingVisitorType
GraphicalModelType::IndependentFactorType IndependentFactorType