2 #ifndef OPENGM_MESSAGE_PASSING_HXX 3 #define OPENGM_MESSAGE_PASSING_HXX 29 static typename M::ValueType
30 op(
const M& in1,
const M& in2)
32 typedef typename M::ValueType ValueType;
33 ValueType v1,v2,d1,d2;
36 for(
size_t n=0; n<in1.size(); ++n) {
49 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST=opengm::MaxDistance>
74 template<
class _GM,
class _ACC>
86 const size_t maximumNumberOfSteps = 100,
87 const ValueType bound = static_cast<ValueType> (0.000000),
88 const ValueType damping = static_cast<ValueType> (0),
89 const SpecialParameterType & specialParameter =SpecialParameterType(),
92 : maximumNumberOfSteps_(maximumNumberOfSteps),
95 inferSequential_(
false),
96 useNormalization_(
true),
97 specialParameter_(specialParameter),
106 : maximumNumberOfSteps_(p.maximumNumberOfSteps_),
108 damping_(p.damping_),
109 inferSequential_(p.inferSequential_),
110 useNormalization_(p.useNormalization_),
111 specialParameter_(p.specialParameter_),
112 isAcyclic_(p.isAcyclic_)
131 internalMessageId_(-1)
133 Message(
const size_t nodeId,
const size_t & internalMessageId)
135 internalMessageId_(internalMessageId)
139 size_t internalMessageId_;
144 std::string name()
const;
145 const GraphicalModelType& graphicalModel()
const;
151 virtual void reset();
153 template<
class VisitorType>
157 void setMaxSteps(
size_t maxSteps) {parameter_.maximumNumberOfSteps_ = maxSteps;}
163 void inferParallel();
164 void inferSequential();
165 template<
class VisitorType>
166 void inferParallel(VisitorType&);
167 template<
class VisitorType>
168 void inferAcyclic(VisitorType&);
169 template<
class VisitorType>
170 void inferSequential(VisitorType&);
172 const GraphicalModelType& gm_;
174 std::vector<FactorHullType> factorHulls_;
175 std::vector<VariableHullType> variableHulls_;
178 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
181 const GraphicalModelType& gm,
185 parameter_(parameter)
187 if(parameter_.sortedNodeList_.size() == 0) {
188 parameter_.sortedNodeList_.resize(gm.numberOfVariables());
189 for (
size_t i = 0; i < gm.numberOfVariables(); ++i)
190 parameter_.sortedNodeList_[i] = i;
192 OPENGM_ASSERT(parameter_.sortedNodeList_.size() == gm.numberOfVariables());
194 UPDATE_RULES::initializeSpecialParameter(gm_,this->parameter_);
197 variableHulls_.resize(gm.numberOfVariables(), VariableHullType ());
198 for (
size_t i = 0; i < gm.numberOfVariables(); ++i) {
199 variableHulls_[i].assign(gm, i, ¶meter_.specialParameter_);
201 factorHulls_.resize(gm.numberOfFactors(), FactorHullType ());
202 for (
size_t i = 0; i < gm.numberOfFactors(); i++) {
203 factorHulls_[i].assign(gm, i, variableHulls_, ¶meter_.specialParameter_);
207 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
211 if(parameter_.sortedNodeList_.size() == 0) {
212 parameter_.sortedNodeList_.resize(gm_.numberOfVariables());
213 for (
size_t i = 0; i < gm_.numberOfVariables(); ++i)
214 parameter_.sortedNodeList_[i] = i;
216 OPENGM_ASSERT(parameter_.sortedNodeList_.size() == gm_.numberOfVariables());
217 UPDATE_RULES::initializeSpecialParameter(gm_,this->parameter_);
220 variableHulls_.resize(gm_.numberOfVariables(), VariableHullType ());
221 for (
size_t i = 0; i < gm_.numberOfVariables(); ++i) {
222 variableHulls_[i].assign(gm_, i, ¶meter_.specialParameter_);
224 factorHulls_.resize(gm_.numberOfFactors(), FactorHullType ());
225 for (
size_t i = 0; i < gm_.numberOfFactors(); i++) {
226 factorHulls_[i].assign(gm_, i, variableHulls_, ¶meter_.specialParameter_);
230 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
236 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
242 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
249 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
250 template<
class VisitorType>
258 parameter_.useNormalization_=
false;
259 inferAcyclic(visitor);
261 if (parameter_.inferSequential_) {
262 inferSequential(visitor);
264 inferParallel(visitor);
267 if (gm_.isAcyclic()) {
270 parameter_.useNormalization_=
false;
271 inferAcyclic(visitor);
274 if (parameter_.inferSequential_) {
275 inferSequential(visitor);
277 inferParallel(visitor);
289 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
293 return inferAcyclic(v);
303 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
304 template<
class VisitorType>
312 visitor.begin(*
this);
313 size_t numberOfVariables = gm_.numberOfVariables();
314 size_t numberOfFactors = gm_.numberOfFactors();
317 std::vector<std::vector<size_t> > counterVar2FacMessage(numberOfVariables);
318 std::vector<std::vector<size_t> > counterFac2VarMessage(numberOfFactors);
320 std::vector<Message> ready2SendVar2FacMessage;
321 std::vector<Message> ready2SendFac2VarMessage;
322 ready2SendVar2FacMessage.reserve(100);
323 ready2SendFac2VarMessage.reserve(100);
324 for (
size_t fac = 0; fac < numberOfFactors; ++fac) {
325 counterFac2VarMessage[fac].resize(gm_[fac].numberOfVariables(), gm_[fac].numberOfVariables() - 1);
327 for (
size_t var = 0; var < numberOfVariables; ++var) {
328 counterVar2FacMessage[var].resize(gm_.numberOfFactors(var));
329 for (
size_t i = 0; i < gm_.numberOfFactors(var); ++i) {
330 counterVar2FacMessage[var][i] = gm_.numberOfFactors(var) - 1;
334 for (
size_t var = 0; var < numberOfVariables; ++var) {
335 for (
size_t i = 0; i < counterVar2FacMessage[var].size(); ++i) {
336 if (counterVar2FacMessage[var][i] == 0) {
337 --counterVar2FacMessage[var][i];
338 ready2SendVar2FacMessage.push_back(Message(var, i));
342 for (
size_t fac = 0; fac < numberOfFactors; ++fac) {
343 for (
size_t i = 0; i < counterFac2VarMessage[fac].size(); ++i) {
344 if (counterFac2VarMessage[fac][i] == 0) {
345 --counterFac2VarMessage[fac][i];
346 ready2SendFac2VarMessage.push_back(Message(fac, i));
351 while (ready2SendVar2FacMessage.size() > 0 || ready2SendFac2VarMessage.size() > 0) {
352 while (ready2SendVar2FacMessage.size() > 0) {
353 Message m = ready2SendVar2FacMessage.back();
354 size_t nodeId = m.nodeId_;
355 size_t factorId = gm_.factorOfVariable(nodeId,m.internalMessageId_);
357 variableHulls_[nodeId].propagate(gm_, m.internalMessageId_, 0,
false);
358 ready2SendVar2FacMessage.pop_back();
360 for (
size_t i = 0; i < gm_[factorId].numberOfVariables(); ++i) {
361 if (gm_[factorId].variableIndex(i) != nodeId) {
362 if (--counterFac2VarMessage[factorId][i] == 0) {
363 ready2SendFac2VarMessage.push_back(Message(factorId, i));
368 while (ready2SendFac2VarMessage.size() > 0) {
369 Message m = ready2SendFac2VarMessage.back();
370 size_t factorId = m.nodeId_;
371 size_t nodeId = gm_[factorId].variableIndex(m.internalMessageId_);
373 factorHulls_[factorId].propagate(m.internalMessageId_, 0, parameter_.useNormalization_);
374 ready2SendFac2VarMessage.pop_back();
376 for (
size_t i = 0; i < gm_.numberOfFactors(nodeId); ++i) {
377 if (gm_.factorOfVariable(nodeId,i) != factorId) {
378 if (--counterVar2FacMessage[nodeId][i] == 0) {
379 ready2SendVar2FacMessage.push_back(Message(nodeId, i));
384 if(visitor(*
this)!=0)
392 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
397 for (
size_t i = 0; i < variableHulls_.size(); ++i) {
398 variableHulls_[i].propagateAll(damping,
false);
400 for (
size_t i = 0; i < factorHulls_.size(); ++i) {
401 factorHulls_[i].propagateAll(damping, parameter_.useNormalization_);
406 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
409 return inferParallel(v);
414 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
415 template<
class VisitorType>
423 visitor.begin(*
this);
426 for (
size_t i = 0; i < factorHulls_.size(); ++i) {
427 if (factorHulls_[i].numberOfBuffers() < 2) {
428 factorHulls_[i].propagateAll(0, parameter_.useNormalization_);
429 factorHulls_[i].propagateAll(0, parameter_.useNormalization_);
432 for (
unsigned long n = 0; n < parameter_.maximumNumberOfSteps_; ++n) {
433 for (
size_t i = 0; i < variableHulls_.size(); ++i) {
434 variableHulls_[i].propagateAll(gm_, damping,
false);
436 for (
size_t i = 0; i < factorHulls_.size(); ++i) {
437 if (factorHulls_[i].numberOfBuffers() >= 2)
438 factorHulls_[i].propagateAll(damping, parameter_.useNormalization_);
440 if(visitor(*
this)!=0)
443 if (c < parameter_.bound_) {
459 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
462 return inferSequential(v);
475 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
476 template<
class VisitorType>
481 OPENGM_ASSERT(parameter_.sortedNodeList_.size() == gm_.numberOfVariables());
482 visitor.begin(*
this);
486 std::vector<size_t> nodeOrder(gm_.numberOfVariables());
487 for (
size_t o = 0; o < gm_.numberOfVariables(); ++o) {
488 nodeOrder[parameter_.sortedNodeList_[o]] = o;
492 for (
size_t f = 0; f < factorHulls_.size(); ++f) {
493 if (factorHulls_[f].numberOfBuffers() < 2) {
494 factorHulls_[f].propagateAll(0, parameter_.useNormalization_);
495 factorHulls_[f].propagateAll(0, parameter_.useNormalization_);
500 std::vector<std::vector<size_t> > inversePositions(gm_.numberOfVariables());
501 for(
size_t var=0; var<gm_.numberOfVariables();++var) {
502 for(
size_t i=0; i<gm_.numberOfFactors(var); ++i) {
503 size_t factorId = gm_.factorOfVariable(var,i);
504 for(
size_t j=0; j<gm_.numberOfVariables(factorId);++j) {
505 if(gm_.variableOfFactor(factorId,j)==var) {
506 inversePositions[var].push_back(j);
515 for (
unsigned long itteration = 0; itteration < parameter_.maximumNumberOfSteps_; ++itteration) {
516 if(itteration%2==0) {
518 for (
size_t o = 0; o < gm_.numberOfVariables(); ++o) {
519 size_t variableId = parameter_.sortedNodeList_[o];
521 for(
size_t i=0; i<gm_.numberOfFactors(variableId); ++i) {
522 size_t factorId = gm_.factorOfVariable(variableId,i);
523 factorHulls_[factorId].propagate(inversePositions[variableId][i], damping, parameter_.useNormalization_);
527 variableHulls_[variableId].propagateAll(gm_, damping,
false);
532 for (
size_t o = 0; o < gm_.numberOfVariables(); ++o) {
533 size_t variableId = parameter_.sortedNodeList_[gm_.numberOfVariables() - 1 - o];
535 for(
size_t i=0; i<gm_.numberOfFactors(variableId); ++i) {
536 size_t factorId = gm_.factorOfVariable(variableId,i);
537 factorHulls_[factorId].propagate(inversePositions[variableId][i], damping, parameter_.useNormalization_);
540 variableHulls_[variableId].propagateAll(gm_, damping,
false);
543 if(visitor(*
this)!=0)
546 if (c < parameter_.bound_) {
554 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
558 const size_t variableIndex,
562 variableHulls_[variableIndex].marginal(gm_, variableIndex, out, parameter_.useNormalization_);
566 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
570 const size_t factorIndex,
573 typedef typename GM::OperatorType OP;
575 out.assign(gm_, gm_[factorIndex].variableIndicesBegin(), gm_[factorIndex].variableIndicesEnd(), OP::template neutral<ValueType>());
576 factorHulls_[factorIndex].marginal(out, parameter_.useNormalization_);
581 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
585 for (
size_t j = 0; j < factorHulls_.size(); ++j) {
586 for (
size_t i = 0; i < factorHulls_[j].numberOfBuffers(); ++i) {
587 ValueType d = factorHulls_[j].template distance<DIST > (i);
597 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
601 for (
size_t j = 0; j < variableHulls_.size(); ++j) {
602 for (
size_t i = 0; i < variableHulls_[j].numberOfBuffers(); ++i) {
603 ValueType d = variableHulls_[j].template distance<DIST > (i);
613 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST>
616 return convergenceXF();
619 template<
class GM,
class ACC,
class UPDATE_RULES,
class DIST >
623 std::vector<LabelType>& conf,
627 throw RuntimeError(
"This implementation of message passing cannot return the k-th optimal configuration.");
631 return this->modeFromFactorMarginal(conf);
634 return this->modeFromFactorMarginal(conf);
642 #endif // #ifndef OPENGM_BELIEFPROPAGATION_HXX MessagePassing< _GM, _ACC, UR, DIST > type
InferenceTermination factorMarginal(const size_t, IndependentFactorType &out) const
output a solution for a marginal for all variables connected to a factor
UPDATE_RULES::SpecialParameterType SpecialParameterType
std::vector< size_t > sortedNodeList_
UPDATE_RULES::FactorHullType FactorHullType
void infer(const typename INF::GraphicalModelType &gm, const typename INF::Parameter ¶m, std::vector< typename INF::LabelType > &conf)
opengm::Tribool isAcyclic_
A framework for message passing algorithms Cf. F. R. Kschischang, B. J. Frey and H...
SpecialParameterType specialParameter_
visitors::TimingVisitor< MessagePassing< GM, ACC, UPDATE_RULES, DIST > > TimingVisitorType
Visitor.
#define OPENGM_ASSERT(expression)
ValueType convergence() const
cumulative distance between all pairs of messages (between the previous and the current interation) ...
opengm::Tribool useNormalization_
visitors::VerboseVisitor< MessagePassing< GM, ACC, UPDATE_RULES, DIST > > VerboseVisitorType
Visitor.
InferenceTermination marginal(const size_t, IndependentFactorType &out) const
output a solution for a marginal for a specific variable
static T neutral()
neutral element (with return)
visitors::EmptyVisitor< MessagePassing< GM, ACC, UPDATE_RULES, DIST > > EmptyVisitorType
Visitor.
UPDATE_RULES::template RebindGm< _GM >::type UR
UPDATE_RULES::VariableHullType VariableHullType
MessagePassing< _GM, ACC, UR, DIST > type
UPDATE_RULES::template RebindGmAndAcc< _GM, _ACC >::type UR
GraphicalModelType::ValueType ValueType
const GraphicalModelType & graphicalModel() const
Inference algorithm interface.
void propagate(const ValueType &=0)
invoke one iteration of message passing
ValueType convergenceFX() const
cumulative distance between all pairs of messages from factors to variables (between the previous and...
static M::ValueType op(const M &in1, const M &in2)
operation
InferenceTermination arg(std::vector< LabelType > &, const size_t=1) const
output a solution
Variable with three values (true=1, false=0, maybe=-1 )
ValueType convergenceXF() const
cumulative distance between all pairs of messages from variables to factors (between the previous and...
InferenceTermination infer()
static void op(const T1 &in1, T2 &out)
operation (in-place)
void setMaxSteps(size_t maxSteps)
size_t maximumNumberOfSteps_
MessagePassing(const GraphicalModelType &, const Parameter &=Parameter())
GraphicalModelType::IndependentFactorType IndependentFactorType