OpenGM  2.3.x
Discrete Graphical Model Library
messagepassing.hxx
Go to the documentation of this file.
1 #pragma once
2 #ifndef OPENGM_MESSAGE_PASSING_HXX
3 #define OPENGM_MESSAGE_PASSING_HXX
4 
5 #include <vector>
6 #include <map>
7 #include <list>
8 #include <set>
9 
10 #include "opengm/opengm.hxx"
19 
20 namespace opengm {
21 
24 struct MaxDistance {
28  template<class M>
29  static typename M::ValueType
30  op(const M& in1, const M& in2)
31  {
32  typedef typename M::ValueType ValueType;
33  ValueType v1,v2,d1,d2;
34  Maximizer::neutral(v1);
36  for(size_t n=0; n<in1.size(); ++n) {
37  d1=in1(n)-in2(n);
38  d2=-d1;
39  Maximizer::op(d1,v1);
40  Maximizer::op(d2,v2);
41  }
42  Maximizer::op(v2,v1);
43  return v1;
44  }
45 };
46 
/// Generic message-passing inference on a factor graph.
/// UPDATE_RULES supplies the factor/variable hull types that store and compute
/// the messages (and a SpecialParameterType); DIST measures how much messages
/// changed between iterations for the convergence test (default MaxDistance).
/// NOTE(review): this listing omits several original lines (e.g. 60-65, 71, 75,
/// 77, 116-121, 124, 141, 152); ValueType, LabelType, IndependentFactorType and
/// the visitor typedefs are presumably declared there — confirm against the source.
49 template<class GM, class ACC, class UPDATE_RULES, class DIST=opengm::MaxDistance>
50 class MessagePassing : public Inference<GM, ACC> {
51 public:
52  typedef GM GraphicalModelType;
53  typedef ACC Accumulation;
54  typedef ACC AccumulatorType;
56  typedef DIST Distance;
57  typedef typename UPDATE_RULES::FactorHullType FactorHullType;
58  typedef typename UPDATE_RULES::VariableHullType VariableHullType;
59 
66 
67 
// Rebinds this inference type to another graphical model type (the update
// rules are rebound accordingly).
68  template<class _GM>
69  struct RebindGm{
70  typedef typename UPDATE_RULES:: template RebindGm<_GM>::type UR;
72  };
73 
// Rebinds this inference type to another graphical model and accumulator type.
74  template<class _GM,class _ACC>
76  typedef typename UPDATE_RULES:: template RebindGmAndAcc<_GM,_ACC>::type UR;
78  };
79 
80 
81 
// Options for MessagePassing.
// - maximumNumberOfSteps_: upper bound on iterations of the cyclic schedules.
// - bound_: the cyclic schedules stop once convergence() < bound_. convergence()
//   is never negative, so with the default bound of 0 that test never succeeds
//   and the loops run all steps unless the visitor stops them.
// - damping_: damping passed to the hull propagate calls.
// - inferSequential_ (default false): sequential instead of parallel schedule.
// - useNormalization_ (default true): forwarded to factor hulls and marginals;
//   infer() compares it with Tribool::Maybe, so it is presumably a Tribool.
// - isAcyclic_: True/False selects the schedule, Maybe lets infer() detect it.
82  struct Parameter {
83  typedef typename UPDATE_RULES::SpecialParameterType SpecialParameterType;
84  Parameter
85  (
86  const size_t maximumNumberOfSteps = 100,
87  const ValueType bound = static_cast<ValueType> (0.000000),
88  const ValueType damping = static_cast<ValueType> (0),
89  const SpecialParameterType & specialParameter =SpecialParameterType(),
90  const opengm::Tribool isAcyclic = opengm::Tribool::Maybe
91  )
92  : maximumNumberOfSteps_(maximumNumberOfSteps),
93  bound_(bound),
94  damping_(damping),
95  inferSequential_(false),
96  useNormalization_(true),
97  specialParameter_(specialParameter),
98  isAcyclic_(isAcyclic)
99  {}
100 
// Converting constructor: copies the common fields from another parameter type P
// (e.g. the Parameter of a rebound MessagePassing).
101  template<class P>
102  Parameter
103  (
104  const P & p
105  )
106  : maximumNumberOfSteps_(p.maximumNumberOfSteps_),
107  bound_(p.bound_),
108  damping_(p.damping_),
109  inferSequential_(p.inferSequential_),
110  useNormalization_(p.useNormalization_),
111  specialParameter_(p.specialParameter_),
112  isAcyclic_(p.isAcyclic_)
113  {}
114 
115 
// Variable visiting order for the sequential schedule; when left empty the
// constructor / reset() fill it with 0..numberOfVariables-1. It must then be a
// permutation of all variable indices.
120  std::vector<size_t> sortedNodeList_;
122  //bool useNormalization_;
123  SpecialParameterType specialParameter_;
125  };
126 
// Handle for one message in the acyclic schedule.
// For variable-to-factor messages, nodeId_ is a variable index and
// internalMessageId_ is the position in that variable's factor list.
// For factor-to-variable messages, nodeId_ is a factor index and
// internalMessageId_ is the position in that factor's variable list.
// The default constructor stores -1, which converts to the largest size_t,
// as an "invalid" marker.
128  struct Message {
129  Message()
130  : nodeId_(-1),
131  internalMessageId_(-1)
132  {}
133  Message(const size_t nodeId, const size_t & internalMessageId)
134  : nodeId_(nodeId),
135  internalMessageId_(internalMessageId)
136  {}
137 
138  size_t nodeId_;
139  size_t internalMessageId_;
140  };
142 
// The model is held by reference and must outlive this object.
143  MessagePassing(const GraphicalModelType&, const Parameter& = Parameter());
144  std::string name() const;
145  const GraphicalModelType& graphicalModel() const;
146  InferenceTermination marginal(const size_t, IndependentFactorType& out) const;
147  InferenceTermination factorMarginal(const size_t, IndependentFactorType & out) const;
148  ValueType convergenceXF() const;
149  ValueType convergenceFX() const;
150  ValueType convergence() const;
151  virtual void reset();
153  template<class VisitorType>
154  InferenceTermination infer(VisitorType&);
// One parallel sweep over all variable hulls and then all factor hulls.
155  void propagate(const ValueType& = 0);
// Only N == 1 is supported; other values throw RuntimeError.
156  InferenceTermination arg(std::vector<LabelType>&, const size_t = 1) const;
157  void setMaxSteps(size_t maxSteps) {parameter_.maximumNumberOfSteps_ = maxSteps;}
158  //InferenceTermination bound(ValueType&) const;
159  //ValueType bound() const;
160 
161 private:
162  void inferAcyclic();
163  void inferParallel();
164  void inferSequential();
165  template<class VisitorType>
166  void inferParallel(VisitorType&);
167  template<class VisitorType>
168  void inferAcyclic(VisitorType&);
169  template<class VisitorType>
170  void inferSequential(VisitorType&);
171 private:
172  const GraphicalModelType& gm_;
173  Parameter parameter_;
// One hull per factor / per variable, indexed like the model's factors / variables.
174  std::vector<FactorHullType> factorHulls_;
175  std::vector<VariableHullType> variableHulls_;
176 };
177 
/// Constructor.
/// Steps:
/// - Stores a reference to gm (it must outlive this object) and copies parameter.
/// - Fills sortedNodeList_ with the identity order if it is empty.
/// - Lets UPDATE_RULES initialize the special parameter.
/// - Creates and assigns one hull per variable and one per factor.
/// NOTE(review): original lines 179 (qualified name) and 182 (the `parameter`
/// argument) are missing from this listing.
/// NOTE(review): the body duplicates reset(); a shared private helper would
/// keep the two in sync.
178 template<class GM, class ACC, class UPDATE_RULES, class DIST>
180 (
181  const GraphicalModelType& gm,
183 )
184 : gm_(gm),
185  parameter_(parameter)
186 {
187  if(parameter_.sortedNodeList_.size() == 0) {
188  parameter_.sortedNodeList_.resize(gm.numberOfVariables());
189  for (size_t i = 0; i < gm.numberOfVariables(); ++i)
190  parameter_.sortedNodeList_[i] = i;
191  }
192  OPENGM_ASSERT(parameter_.sortedNodeList_.size() == gm.numberOfVariables());
193 
194  UPDATE_RULES::initializeSpecialParameter(gm_,this->parameter_);
195 
196  // set hulls
// Hulls receive a pointer into this object's parameter_. If they keep it,
// copying a MessagePassing would leave the copy's hulls pointing at the
// original's parameter — confirm against the hull implementation.
197  variableHulls_.resize(gm.numberOfVariables(), VariableHullType ());
198  for (size_t i = 0; i < gm.numberOfVariables(); ++i) {
199  variableHulls_[i].assign(gm, i, &parameter_.specialParameter_);
200  }
// Factor hulls are given the (already fully sized) variable hull vector.
201  factorHulls_.resize(gm.numberOfFactors(), FactorHullType ());
202  for (size_t i = 0; i < gm.numberOfFactors(); i++) {
203  factorHulls_[i].assign(gm, i, variableHulls_, &parameter_.specialParameter_);
204  }
205 }
206 
/// Reinitializes all hulls, using the same steps as the constructor, so that
/// inference can be rerun from scratch.
/// A non-empty sortedNodeList_ is kept as is. The hull vectors keep their size;
/// each hull is re-initialized through assign().
/// NOTE(review): original line 209 (the qualified name `reset()`) is missing
/// from this listing.
207 template<class GM, class ACC, class UPDATE_RULES, class DIST>
208 void
210 {
211  if(parameter_.sortedNodeList_.size() == 0) {
212  parameter_.sortedNodeList_.resize(gm_.numberOfVariables());
213  for (size_t i = 0; i < gm_.numberOfVariables(); ++i)
214  parameter_.sortedNodeList_[i] = i;
215  }
216  OPENGM_ASSERT(parameter_.sortedNodeList_.size() == gm_.numberOfVariables());
217  UPDATE_RULES::initializeSpecialParameter(gm_,this->parameter_);
218 
219  // set hulls
220  variableHulls_.resize(gm_.numberOfVariables(), VariableHullType ());
221  for (size_t i = 0; i < gm_.numberOfVariables(); ++i) {
222  variableHulls_[i].assign(gm_, i, &parameter_.specialParameter_);
223  }
224  factorHulls_.resize(gm_.numberOfFactors(), FactorHullType ());
225  for (size_t i = 0; i < gm_.numberOfFactors(); i++) {
226  factorHulls_[i].assign(gm_, i, variableHulls_, &parameter_.specialParameter_);
227  }
228 }
229 
/// Returns the algorithm's short name, "MP".
/// NOTE(review): original line 232 (the qualified name) is missing from this listing.
230 template<class GM, class ACC, class UPDATE_RULES, class DIST>
231 inline std::string
233  return "MP";
234 }
235 
/// Returns the graphical model this instance was constructed with (held by reference).
/// NOTE(review): original lines 237-238 (return type and qualified name) are
/// missing from this listing.
236 template<class GM, class ACC, class UPDATE_RULES, class DIST>
239  return gm_;
240 }
241 
/// Runs inference with a no-op visitor; see infer(VisitorType&).
/// NOTE(review): original lines 243-244 (return type and qualified name) are
/// missing from this listing.
242 template<class GM, class ACC, class UPDATE_RULES, class DIST>
245  EmptyVisitorType v;
246  return infer(v);
247 }
248 
/// Runs message passing, choosing the schedule from parameter_.isAcyclic_:
/// - True: inferAcyclic.
/// - False: inferSequential or inferParallel, depending on inferSequential_.
/// - Maybe: asks gm_.isAcyclic(), stores the answer in parameter_.isAcyclic_
///   (so later calls skip the check), then dispatches as above.
/// On the acyclic path, a useNormalization_ of Maybe is resolved to false.
/// Always returns NORMAL.
/// NOTE(review): original lines 251-252 (return type and qualified name) are
/// missing from this listing.
249 template<class GM, class ACC, class UPDATE_RULES, class DIST>
250 template<class VisitorType>
253 (
254  VisitorType& visitor
255 ) {
256  if (parameter_.isAcyclic_ == opengm::Tribool::True) {
257  if(parameter_.useNormalization_==opengm::Tribool::Maybe)
258  parameter_.useNormalization_=false;
259  inferAcyclic(visitor);
260  } else if (parameter_.isAcyclic_ == opengm::Tribool::False) {
261  if (parameter_.inferSequential_) {
262  inferSequential(visitor);
263  } else {
264  inferParallel(visitor);
265  }
266  } else { // isAcyclic_ == Tribool::Maybe: detect it from the model
267  if (gm_.isAcyclic()) {
268  parameter_.isAcyclic_ = opengm::Tribool::True;
269  if(parameter_.useNormalization_==opengm::Tribool::Maybe)
270  parameter_.useNormalization_=false;
271  inferAcyclic(visitor);
272  } else {
273  parameter_.isAcyclic_ = opengm::Tribool::False;
274  if (parameter_.inferSequential_) {
275  inferSequential(visitor);
276  } else {
277  inferParallel(visitor);
278  }
279  }
280  }
281  return NORMAL;
282 }
283 
/// Runs the acyclic schedule with a no-op visitor; see inferAcyclic(VisitorType&).
/// NOTE(review): original line 291 (the qualified name) is missing from this listing.
289 template<class GM, class ACC, class UPDATE_RULES, class DIST>
290 inline void
292  EmptyVisitorType v;
293  return inferAcyclic(v);
294 }
295 
297 //
/// Exact message passing on an acyclic model: each message is sent once, as soon
/// as all messages it depends on have arrived.
///
/// Each directed edge has a counter of still-missing incoming messages:
/// - factor -> variable: (number of variables of the factor) - 1;
/// - variable -> factor: (number of factors of the variable) - 1.
/// A message is ready when its counter reaches 0. Messages are sent without
/// damping. Ready messages are processed from the back of the lists (LIFO).
/// The visitor is called after each round of draining both ready lists.
/// NOTE(review): original line 306 (the qualified name) is missing from this listing.
303 template<class GM, class ACC, class UPDATE_RULES, class DIST>
304 template<class VisitorType>
305 void
307 (
308  VisitorType& visitor
309 )
310 {
311  OPENGM_ASSERT(gm_.isAcyclic());
312  visitor.begin(*this);
313  size_t numberOfVariables = gm_.numberOfVariables();
314  size_t numberOfFactors = gm_.numberOfFactors();
315  // number of messages which have not yet been received
316  // but are required for sending
317  std::vector<std::vector<size_t> > counterVar2FacMessage(numberOfVariables);
318  std::vector<std::vector<size_t> > counterFac2VarMessage(numberOfFactors);
319  // list of messages which are ready to send
320  std::vector<Message> ready2SendVar2FacMessage;
321  std::vector<Message> ready2SendFac2VarMessage;
322  ready2SendVar2FacMessage.reserve(100);
323  ready2SendFac2VarMessage.reserve(100);
324  for (size_t fac = 0; fac < numberOfFactors; ++fac) {
325  counterFac2VarMessage[fac].resize(gm_[fac].numberOfVariables(), gm_[fac].numberOfVariables() - 1);
326  }
327  for (size_t var = 0; var < numberOfVariables; ++var) {
328  counterVar2FacMessage[var].resize(gm_.numberOfFactors(var));
329  for (size_t i = 0; i < gm_.numberOfFactors(var); ++i) {
330  counterVar2FacMessage[var][i] = gm_.numberOfFactors(var) - 1;
331  }
332  }
333  // find all messages which are ready for sending
// Leaf variables and unary factors start at 0. Decrementing a zero counter
// wraps the unsigned size_t to its maximum, which marks the message as already
// scheduled so that later decrements cannot bring it back to 0.
334  for (size_t var = 0; var < numberOfVariables; ++var) {
335  for (size_t i = 0; i < counterVar2FacMessage[var].size(); ++i) {
336  if (counterVar2FacMessage[var][i] == 0) {
337  --counterVar2FacMessage[var][i];
338  ready2SendVar2FacMessage.push_back(Message(var, i));
339  }
340  }
341  }
342  for (size_t fac = 0; fac < numberOfFactors; ++fac) {
343  for (size_t i = 0; i < counterFac2VarMessage[fac].size(); ++i) {
344  if (counterFac2VarMessage[fac][i] == 0) {
345  --counterFac2VarMessage[fac][i];
346  ready2SendFac2VarMessage.push_back(Message(fac, i));
347  }
348  }
349  }
350  // send messages
351  while (ready2SendVar2FacMessage.size() > 0 || ready2SendFac2VarMessage.size() > 0) {
352  while (ready2SendVar2FacMessage.size() > 0) {
353  Message m = ready2SendVar2FacMessage.back();
354  size_t nodeId = m.nodeId_;
355  size_t factorId = gm_.factorOfVariable(nodeId,m.internalMessageId_);
356  // send message
357  variableHulls_[nodeId].propagate(gm_, m.internalMessageId_, 0, false);
358  ready2SendVar2FacMessage.pop_back();
359  //check if new messages can be sent
// The factor's outgoing messages to every *other* variable have one fewer
// missing input now.
360  for (size_t i = 0; i < gm_[factorId].numberOfVariables(); ++i) {
361  if (gm_[factorId].variableIndex(i) != nodeId) {
362  if (--counterFac2VarMessage[factorId][i] == 0) {
363  ready2SendFac2VarMessage.push_back(Message(factorId, i));
364  }
365  }
366  }
367  }
368  while (ready2SendFac2VarMessage.size() > 0) {
369  Message m = ready2SendFac2VarMessage.back();
370  size_t factorId = m.nodeId_;
371  size_t nodeId = gm_[factorId].variableIndex(m.internalMessageId_);
372  // send message
373  factorHulls_[factorId].propagate(m.internalMessageId_, 0, parameter_.useNormalization_);
374  ready2SendFac2VarMessage.pop_back();
375  // check if new messages can be sent
// The variable's outgoing messages to every *other* factor have one fewer
// missing input now.
376  for (size_t i = 0; i < gm_.numberOfFactors(nodeId); ++i) {
377  if (gm_.factorOfVariable(nodeId,i) != factorId) {
378  if (--counterVar2FacMessage[nodeId][i] == 0) {
379  ready2SendVar2FacMessage.push_back(Message(nodeId, i));
380  }
381  }
382  }
383  }
// A non-zero visitor return value aborts the schedule early.
384  if(visitor(*this)!=0)
385  break;
386  }
387  visitor.end(*this);
388 
389 }
390 
/// One parallel sweep: every variable hull sends all of its messages, then
/// every factor hull does, using the given damping. Factor hulls use
/// parameter_.useNormalization_.
/// NOTE(review): original line 393 (return type and qualified name) is missing
/// from this listing.
/// NOTE(review): every other call site in this file uses the three-argument
/// form variableHulls_[i].propagateAll(gm_, damping, false). This two-argument
/// call only compiles if VariableHullType has a matching overload — confirm, or
/// pass gm_ here too.
392 template<class GM, class ACC, class UPDATE_RULES, class DIST>
394 (
395  const ValueType& damping
396 ) {
397  for (size_t i = 0; i < variableHulls_.size(); ++i) {
398  variableHulls_[i].propagateAll(damping, false);
399  }
400  for (size_t i = 0; i < factorHulls_.size(); ++i) {
401  factorHulls_[i].propagateAll(damping, parameter_.useNormalization_);
402  }
403 }
404 
/// Runs the parallel schedule with a no-op visitor; see inferParallel(VisitorType&).
/// NOTE(review): original line 407 (return type and qualified name) is missing
/// from this listing.
406 template<class GM, class ACC, class UPDATE_RULES, class DIST>
408  EmptyVisitorType v;
409  return inferParallel(v);
410 }
411 
/// Loopy message passing with a parallel (flooding) schedule.
///
/// Steps:
/// - Hulls with fewer than 2 buffers are propagated once up front; their
///   messages never change afterwards.
/// - Each iteration, all variable hulls send, then all factor hulls with at
///   least 2 buffers send, both with parameter_.damping_.
/// - The loop ends after maximumNumberOfSteps_ iterations, when the visitor
///   returns non-zero, or when convergence() < bound_.
/// NOTE(review): original line 416 (return type and qualified name) is missing
/// from this listing.
414 template<class GM, class ACC, class UPDATE_RULES, class DIST>
415 template<class VisitorType>
417 (
418  VisitorType& visitor
419 )
420 {
421  ValueType c = 0;
422  ValueType damping = parameter_.damping_;
423  visitor.begin(*this);
424 
425  // let all factors of order lower than 2 send their messages
426  for (size_t i = 0; i < factorHulls_.size(); ++i) {
427  if (factorHulls_[i].numberOfBuffers() < 2) {
428  factorHulls_[i].propagateAll(0, parameter_.useNormalization_);
429  factorHulls_[i].propagateAll(0, parameter_.useNormalization_); // 2 times to fill both buffers
430  }
431  }
432  for (unsigned long n = 0; n < parameter_.maximumNumberOfSteps_; ++n) {
433  for (size_t i = 0; i < variableHulls_.size(); ++i) {
434  variableHulls_[i].propagateAll(gm_, damping, false);
435  }
436  for (size_t i = 0; i < factorHulls_.size(); ++i) {
437  if (factorHulls_[i].numberOfBuffers() >= 2)// messages from factors of order <2 do not change
438  factorHulls_[i].propagateAll(damping, parameter_.useNormalization_);
439  }
440  if(visitor(*this)!=0)
441  break;
// convergence() is never negative, so with bound_ == 0 this never triggers.
442  c = convergence();
443  if (c < parameter_.bound_) {
444  break;
445  }
446  }
447  visitor.end(*this);
448 
449 }
450 
/// Runs the sequential schedule with a no-op visitor; see inferSequential(VisitorType&).
/// NOTE(review): original line 460 (return type and qualified name) is missing
/// from this listing.
459 template<class GM, class ACC, class UPDATE_RULES, class DIST>
461  EmptyVisitorType v;
462  return inferSequential(v);
463 }
464 
/// Loopy message passing with a sequential schedule over parameter_.sortedNodeList_.
///
/// Even iterations visit the variables in list order, odd iterations in
/// reverse order. For each visited variable:
/// - every adjacent factor first sends its message to that variable;
/// - then the variable sends all of its outgoing messages.
/// The loop ends after maximumNumberOfSteps_ iterations, when the visitor
/// returns non-zero, or when convergence() < bound_.
/// NOTE(review): original line 477 (return type and qualified name) is missing
/// from this listing.
475 template<class GM, class ACC, class UPDATE_RULES, class DIST>
476 template<class VisitorType>
478 (
479  VisitorType& visitor
480 ) {
481  OPENGM_ASSERT(parameter_.sortedNodeList_.size() == gm_.numberOfVariables());
482  visitor.begin(*this);
483  ValueType damping = parameter_.damping_;
484 
485  // set nodeOrder
// NOTE(review): nodeOrder (position of each variable in sortedNodeList_) is
// never read below. The write also requires sortedNodeList_ to be a
// permutation of 0..numberOfVariables-1, otherwise it indexes out of range.
486  std::vector<size_t> nodeOrder(gm_.numberOfVariables());
487  for (size_t o = 0; o < gm_.numberOfVariables(); ++o) {
488  nodeOrder[parameter_.sortedNodeList_[o]] = o;
489  }
490 
491  // let all factors of order lower than 2 send their messages
492  for (size_t f = 0; f < factorHulls_.size(); ++f) {
493  if (factorHulls_[f].numberOfBuffers() < 2) {
494  factorHulls_[f].propagateAll(0, parameter_.useNormalization_);
495  factorHulls_[f].propagateAll(0, parameter_.useNormalization_); //2 times to fill both buffers
496  }
497  }
498 
499  // calculate inverse positions
// inversePositions[var][i] = position of `var` inside the variable list of its
// i-th factor, i.e. which outgoing message of that factor targets `var`.
500  std::vector<std::vector<size_t> > inversePositions(gm_.numberOfVariables());
501  for(size_t var=0; var<gm_.numberOfVariables();++var) {
502  for(size_t i=0; i<gm_.numberOfFactors(var); ++i) {
503  size_t factorId = gm_.factorOfVariable(var,i);
504  for(size_t j=0; j<gm_.numberOfVariables(factorId);++j) {
505  if(gm_.variableOfFactor(factorId,j)==var) {
506  inversePositions[var].push_back(j);
507  break;
508  }
509  }
510  }
511  }
512 
513 
514  // the following code is not optimized and may be too slow for small factors
// NOTE(review): unlike inferParallel, factors with fewer than 2 buffers are not
// skipped here; they are propagated again with damping on every visit.
515  for (unsigned long itteration = 0; itteration < parameter_.maximumNumberOfSteps_; ++itteration) {
516  if(itteration%2==0) {
517  // in increasing ordering
518  for (size_t o = 0; o < gm_.numberOfVariables(); ++o) {
519  size_t variableId = parameter_.sortedNodeList_[o];
520  // update messages to the variable node
521  for(size_t i=0; i<gm_.numberOfFactors(variableId); ++i) {
522  size_t factorId = gm_.factorOfVariable(variableId,i);
523  factorHulls_[factorId].propagate(inversePositions[variableId][i], damping, parameter_.useNormalization_);
524  }
525 
526  // update messages from the variable node
527  variableHulls_[variableId].propagateAll(gm_, damping, false);
528  }
529  }
530  else{
531  // in decreasing ordering
532  for (size_t o = 0; o < gm_.numberOfVariables(); ++o) {
533  size_t variableId = parameter_.sortedNodeList_[gm_.numberOfVariables() - 1 - o];
534  // update messages to the variable node
535  for(size_t i=0; i<gm_.numberOfFactors(variableId); ++i) {
536  size_t factorId = gm_.factorOfVariable(variableId,i);
537  factorHulls_[factorId].propagate(inversePositions[variableId][i], damping, parameter_.useNormalization_);
538  }
539  // update messages from Variable
540  variableHulls_[variableId].propagateAll(gm_, damping, false);
541  }
542  }
543  if(visitor(*this)!=0)
544  break;
545  ValueType c = convergence();
546  if (c < parameter_.bound_) {
547  break;
548  }
549 
550  }
551  visitor.end(*this);
552 }
553 
/// Writes the marginal of variable `variableIndex` into `out` by delegating to
/// that variable's hull, honoring parameter_.useNormalization_. Returns NORMAL.
/// The index is only checked by OPENGM_ASSERT.
/// NOTE(review): original lines 555-556 and 559 (return type, qualified name and
/// the `IndependentFactorType& out` parameter) are missing from this listing.
554 template<class GM, class ACC, class UPDATE_RULES, class DIST>
557 (
558  const size_t variableIndex,
560 ) const {
561  OPENGM_ASSERT(variableIndex < variableHulls_.size());
562  variableHulls_[variableIndex].marginal(gm_, variableIndex, out, parameter_.useNormalization_);
563  return NORMAL;
564 }
565 
/// Writes the marginal over all variables of factor `factorIndex` into `out`.
/// `out` is first re-assigned over the factor's variables and filled with the
/// neutral element of the model's operator; the factor hull then fills it in.
/// Returns NORMAL.
/// NOTE(review): original lines 567-568 and 571 (return type, qualified name and
/// the `IndependentFactorType& out` parameter) are missing from this listing.
566 template<class GM, class ACC, class UPDATE_RULES, class DIST>
569 (
570  const size_t factorIndex,
572 ) const {
573  typedef typename GM::OperatorType OP;
574  OPENGM_ASSERT(factorIndex < factorHulls_.size());
575  out.assign(gm_, gm_[factorIndex].variableIndicesBegin(), gm_[factorIndex].variableIndicesEnd(), OP::template neutral<ValueType>());
576  factorHulls_[factorIndex].marginal(out, parameter_.useNormalization_);
577  return NORMAL;
578 }
579 
/// Largest DIST distance over every buffer of every factor hull, comparing each
/// buffer's previous and current message. This is a maximum, not a sum. The
/// result starts at 0, so it is never negative.
/// "XF" presumably means variable->factor messages, i.e. factor hulls buffer
/// their incoming messages — confirm against UPDATE_RULES.
/// NOTE(review): original lines 582-583 (return type and qualified name) are
/// missing from this listing.
581 template<class GM, class ACC, class UPDATE_RULES, class DIST>
584  ValueType result = 0;
585  for (size_t j = 0; j < factorHulls_.size(); ++j) {
586  for (size_t i = 0; i < factorHulls_[j].numberOfBuffers(); ++i) {
587  ValueType d = factorHulls_[j].template distance<DIST > (i);
588  if (d > result) {
589  result = d;
590  }
591  }
592  }
593  return result;
594 }
595 
/// Largest DIST distance over every buffer of every variable hull. This is a
/// maximum, not a sum. The result starts at 0, so it is never negative.
/// "FX" presumably means factor->variable messages — confirm against UPDATE_RULES.
/// NOTE(review): original lines 598-599 (return type and qualified name) are
/// missing from this listing.
597 template<class GM, class ACC, class UPDATE_RULES, class DIST>
600  ValueType result = 0;
601  for (size_t j = 0; j < variableHulls_.size(); ++j) {
602  for (size_t i = 0; i < variableHulls_[j].numberOfBuffers(); ++i) {
603  ValueType d = variableHulls_[j].template distance<DIST > (i);
604  if (d > result) {
605  result = d;
606  }
607  }
608  }
609  return result;
610 }
611 
/// Convergence measure used by the cyclic schedules: currently just
/// convergenceXF(). convergenceFX() is not consulted.
/// NOTE(review): original lines 614-615 (return type and qualified name) are
/// missing from this listing.
613 template<class GM, class ACC, class UPDATE_RULES, class DIST>
616  return convergenceXF();
617 }
618 
/// Writes the labeling into `conf`; only N == 1 is supported, other values throw
/// RuntimeError. Both the acyclic and the cyclic branch currently call
/// modeFromFactorMarginal (presumably inherited from Inference — confirm); the
/// modeFromMarginal alternative is commented out.
/// NOTE(review): original lines 620-621 (return type and qualified name) are
/// missing from this listing.
619 template<class GM, class ACC,class UPDATE_RULES, class DIST >
622 (
623  std::vector<LabelType>& conf,
624  const size_t N
625 ) const {
626  if (N != 1) {
627  throw RuntimeError("This implementation of message passing cannot return the k-th optimal configuration.");
628  }
629  else {
630  if (parameter_.isAcyclic_ == opengm::Tribool::True) {
631  return this->modeFromFactorMarginal(conf);
632  }
633  else {
634  return this->modeFromFactorMarginal(conf);
635  //return modeFromMarginal(conf);
636  }
637  }
638 }
639 
640 } // namespace opengm
641 
642 #endif // #ifndef OPENGM_MESSAGE_PASSING_HXX
MessagePassing< _GM, _ACC, UR, DIST > type
InferenceTermination factorMarginal(const size_t, IndependentFactorType &out) const
output a solution for a marginal for all variables connected to a factor
UPDATE_RULES::SpecialParameterType SpecialParameterType
The OpenGM namespace.
Definition: config.hxx:43
std::vector< size_t > sortedNodeList_
UPDATE_RULES::FactorHullType FactorHullType
void infer(const typename INF::GraphicalModelType &gm, const typename INF::Parameter &param, std::vector< typename INF::LabelType > &conf)
Definition: inference.hxx:34
A framework for message passing algorithms Cf. F. R. Kschischang, B. J. Frey and H...
SpecialParameterType specialParameter_
visitors::TimingVisitor< MessagePassing< GM, ACC, UPDATE_RULES, DIST > > TimingVisitorType
Visitor.
#define OPENGM_ASSERT(expression)
Definition: opengm.hxx:77
ValueType convergence() const
maximum distance between all pairs of messages (between the previous and the current iteration) ...
visitors::VerboseVisitor< MessagePassing< GM, ACC, UPDATE_RULES, DIST > > VerboseVisitorType
Visitor.
InferenceTermination marginal(const size_t, IndependentFactorType &out) const
output a solution for a marginal for a specific variable
static T neutral()
neutral element (with return)
Definition: maximizer.hxx:14
visitors::EmptyVisitor< MessagePassing< GM, ACC, UPDATE_RULES, DIST > > EmptyVisitorType
Visitor.
UPDATE_RULES::template RebindGm< _GM >::type UR
UPDATE_RULES::VariableHullType VariableHullType
MessagePassing< _GM, ACC, UR, DIST > type
UPDATE_RULES::template RebindGmAndAcc< _GM, _ACC >::type UR
GraphicalModelType::ValueType ValueType
Definition: inference.hxx:50
const GraphicalModelType & graphicalModel() const
Inference algorithm interface.
Definition: inference.hxx:43
void propagate(const ValueType &=0)
invoke one iteration of message passing
ValueType convergenceFX() const
maximum distance between all pairs of messages from factors to variables (between the previous and...
std::string name() const
static M::ValueType op(const M &in1, const M &in2)
operation
InferenceTermination arg(std::vector< LabelType > &, const size_t=1) const
output a solution
Variable with three values (true=1, false=0, maybe=-1 )
Definition: tribool.hxx:8
ValueType convergenceXF() const
maximum distance between all pairs of messages from variables to factors (between the previous and...
InferenceTermination infer()
static void op(const T1 &in1, T2 &out)
operation (in-place)
Definition: maximizer.hxx:34
void setMaxSteps(size_t maxSteps)
OpenGM runtime error.
Definition: opengm.hxx:100
MessagePassing(const GraphicalModelType &, const Parameter &=Parameter())
InferenceTermination
Definition: inference.hxx:24
GraphicalModelType::IndependentFactorType IndependentFactorType
Definition: inference.hxx:53