OpenGM  2.3.x
Discrete Graphical Model Library
trws.hxx
Go to the documentation of this file.
1 
5 #pragma once
6 #ifndef OPENGM_EXTERNAL_TRWS_HXX
7 #define OPENGM_EXTERNAL_TRWS_HXX
8 
15 
16 #include "typeView.h"
17 #include "MRFEnergy.h"
18 #include "instances.h"
19 #include "MRFEnergy.cpp"
20 #include "minimize.cpp"
21 #include "treeProbabilities.cpp"
22 #include "ordering.cpp"
23 
24 namespace opengm {
25  namespace external {
38  template<class GM>
39  class TRWS : public Inference<GM, opengm::Minimizer> {
40  public:
41  typedef GM GraphicalModelType;
47  typedef size_t VariableIndex;
48 
49 
50  template<class _GM>
51  struct RebindGm{
52  typedef TRWS<_GM> type;
53  };
54 
55  template<class _GM,class _ACC>
57  typedef TRWS<_GM> type;
58  };
59 
61  struct Parameter {
63  enum EnergyType {VIEW=0, TABLES=1, TL1=2, TL2=3/*, WEIGHTEDTABLE*/};
71  bool doBPS_;
75  double tolerance_;
81  template<class P>
82  Parameter(const P & p)
83  : numberOfIterations_(p.numberOfIterations_),
84  useRandomStart_(p.useRandomStart_),
85  useZeroStart_(p.useZeroStart_),
86  doBPS_(p.doBPS_),
87  energyType_(),
88  tolerance_(p.tolerance_),
89  minDualChange_(p.minDualChange_)
90  {
91  if(p.energyType_==0){
92  energyType_ =VIEW;
93  }
94  else if(p.energyType_==1){
95  energyType_ =TABLES;
96  }
97  else if(p.energyType_==2){
98  energyType_ =TL1;
99  }
100  else if(p.energyType_==3){
101  energyType_ =TL2;
102  }
103  };
104 
106  numberOfIterations_ = 1000;
107  useRandomStart_ = false;
108  useZeroStart_ = false;
109  doBPS_ = false;
110  energyType_ = VIEW;
111  tolerance_ = 0.0;
112  minDualChange_ = 0.00001;
113  calculateMinMarginals_ = false;
114  };
115  };
116  // construction
117  TRWS(const GraphicalModelType& gm, const Parameter para = Parameter());
118  // destruction
119  ~TRWS();
120  // query
121  std::string name() const;
122  const GraphicalModelType& graphicalModel() const;
123  // inference
124  template<class VISITOR>
125  InferenceTermination infer(VISITOR & visitor);
127  InferenceTermination arg(std::vector<LabelType>&, const size_t& = 1) const;
128  InferenceTermination marginal(const size_t variableIndex, IndependentFactorType& out) const;
129  typename GM::ValueType bound() const;
130  typename GM::ValueType value() const;
131  private:
132  const GraphicalModelType& gm_;
133  Parameter parameter_;
134  ValueType constTerm_;
135 
136  MRFEnergy<TypeView<GM> >* mrfView_;
137  typename MRFEnergy<TypeView<GM> >::NodeId* nodesView_;
138  MRFEnergy<TypeGeneral>* mrfGeneral_;
139  MRFEnergy<TypeGeneral>::NodeId* nodesGeneral_;
144  TypeGeneral::REAL* minMarginals_;
145  size_t* minMarginalsOffsets_;
146 
147 
148  double runTime_;
149  ValueType lowerBound_;
150  ValueType value_;
151  std::vector<LabelType> state_;
152  const IndexType numNodes_;
153  IndexType maxNumLabels_;
154  bool hasSameLabelNumber_;
155  void checkLabelNumber();
156 
157  void generateMRFView();
158  void generateMRFTables();
159  void generateMRFTL1();
160  void generateMRFTL2();
161  //void generateMRFWeightedTable();
162 
163  ValueType getT(IndexType factor) const;
164 
165  // required for energy type tl1
166  bool truncatedAbsoluteDifferenceFactors() const;
167 
168  // required for energy type tl2
169  bool truncatedSquaredDifferenceFactors() const;
170 
171  template <class ENERGYTYPE>
172  void addNodes(MRFEnergy<ENERGYTYPE>*& mrf, typename MRFEnergy<ENERGYTYPE>::NodeId*& nodes, typename ENERGYTYPE::REAL* D);
173 
174  template<class VISITOR, class ENERGYTYPE>
175  InferenceTermination inferImpl(VISITOR & visitor, MRFEnergy<ENERGYTYPE>* mrf);
176  };
177 
178  template<class GM, class ENERGYTYPE>
180  static void* create(typename GM::IndexType numLabels);
181  };
182 
183  template<class GM>
184  struct createMRFEnergy<GM, TypeView<GM> >{
185  static void* create(typename GM::IndexType numLabels);
186  };
187 
188  template<class GM>
189  struct createMRFEnergy<GM, TypeGeneral>{
190  static void* create(typename GM::IndexType numLabels);
191  };
192 
193  template<class GM>
194  struct createMRFEnergy<GM, TypeTruncatedLinear>{
195  static void* create(typename GM::IndexType numLabels);
196  };
197 
198  template<class GM>
199  struct createMRFEnergy<GM, TypeTruncatedQuadratic>{
200  static void* create(typename GM::IndexType numLabels);
201  };
202 
203  template<class GM, class ENERGYTYPE>
204  struct addMRFNode{
205  static typename MRFEnergy<ENERGYTYPE>::NodeId add(MRFEnergy<ENERGYTYPE>* mrf, typename GM::IndexType numLabels, typename ENERGYTYPE::REAL* D);
206  };
207 
208  template<class GM>
209  struct addMRFNode<GM, TypeView<GM> >{
210  static typename MRFEnergy<TypeView<GM> >::NodeId add(MRFEnergy<TypeView<GM> >* mrf, typename GM::IndexType numLabels, typename TypeView<GM>::REAL* D);
211  };
212 
213  template<class GM>
214  struct addMRFNode<GM, TypeGeneral>{
215  static typename MRFEnergy<TypeGeneral>::NodeId add(MRFEnergy<TypeGeneral>* mrf, typename GM::IndexType numLabels, typename TypeGeneral::REAL* D);
216  };
217 
218  template<class GM>
219  struct addMRFNode<GM, TypeTruncatedLinear>{
220  static typename MRFEnergy<TypeTruncatedLinear>::NodeId add(MRFEnergy<TypeTruncatedLinear>* mrf, typename GM::IndexType numLabels, typename TypeTruncatedLinear::REAL* D);
221  };
222 
223  template<class GM>
224  struct addMRFNode<GM, TypeTruncatedQuadratic>{
225  static typename MRFEnergy<TypeTruncatedQuadratic>::NodeId add(MRFEnergy<TypeTruncatedQuadratic>* mrf, typename GM::IndexType numLabels, typename TypeTruncatedQuadratic::REAL* D);
226  };
227 
228  template<class GM>
230  const typename TRWS::GraphicalModelType& gm,
231  const Parameter para
232  )
233  : gm_(gm), parameter_(para), mrfView_(NULL), nodesView_(NULL), mrfGeneral_(NULL), nodesGeneral_(NULL),
234  mrfTL1_(NULL), nodesTL1_(NULL), mrfTL2_(NULL), nodesTL2_(NULL), minMarginals_(NULL), minMarginalsOffsets_(NULL),
235  numNodes_(gm_.numberOfVariables()), maxNumLabels_(gm_.numberOfLabels(0)) {
236  // check label number
237  checkLabelNumber();
238  if(parameter_.calculateMinMarginals_){
239  size_t count = 0;
240  minMarginalsOffsets_ = new size_t[gm_.numberOfVariables()];
241  for(size_t i=0; i<gm_.numberOfVariables(); ++i){
242  minMarginalsOffsets_[i] = count;
243  count += gm_.numberOfLabels(i);
244  }
245  minMarginals_ = new TypeGeneral::REAL[count];
246  }
247 
248  // generate mrf model
249  switch(parameter_.energyType_) {
250  case Parameter::VIEW: {
251  generateMRFView();
252  break;
253  }
254  case Parameter::TABLES: {
255  generateMRFTables();
256  break;
257  }
258  case Parameter::TL1: {
259  if(!hasSameLabelNumber_) {
260  throw(RuntimeError("TRWS TL1 only supports graphical models where each variable has the same number of states."));
261  }
262  generateMRFTL1();
263  break;
264  }
265  case Parameter::TL2: {
266  if(!hasSameLabelNumber_) {
267  throw(RuntimeError("TRWS TL2 only supports graphical models where each variable has the same number of states."));
268  }
269  generateMRFTL2();
270  break;
271  }
272  /*case Parameter::WEIGHTEDTABLE: {
273  generateMRFWeightedTable();
274  break;
275  }*/
276  default: {
277  throw(RuntimeError("Unknown energy type."));
278  }
279  }
280 
281  // set initial value and lower bound
283  AccumulationType::ineutral(lowerBound_);
284  }
285 
286  template<class GM>
288  if(mrfView_) {
289  delete mrfView_;
290  }
291  if(nodesView_) {
292  delete[] nodesView_;
293  }
294 
295  if(mrfGeneral_) {
296  delete mrfGeneral_;
297  }
298  if(nodesGeneral_) {
299  delete[] nodesGeneral_;
300  }
301 
302  if(mrfTL1_) {
303  delete mrfTL1_;
304  }
305  if(nodesTL1_) {
306  delete[] nodesTL1_;
307  }
308 
309  if(mrfTL2_) {
310  delete mrfTL2_;
311  }
312  if(nodesTL2_) {
313  delete[] nodesTL2_;
314  }
315  if(minMarginals_) {
316  delete[] minMarginals_;
317  delete[] minMarginalsOffsets_;
318  }
319  }
320 
321  template<class GM>
322  inline std::string
323  TRWS<GM>
324  ::name() const {
325  return "TRWS";
326  }
327 
328  template<class GM>
329  inline const typename TRWS<GM>::GraphicalModelType&
330  TRWS<GM>
332  return gm_;
333  }
334 
335  template<class GM>
336  inline InferenceTermination
338  (
339  ) {
340  EmptyVisitorType visitor;
341  return this->infer(visitor);
342  }
343 
344  template<class GM>
345  template<class VISITOR>
346  inline InferenceTermination
348  (
349  VISITOR & visitor
350  ) {
351  switch(parameter_.energyType_) {
352  case Parameter::VIEW: {
353  return inferImpl(visitor, mrfView_);
354  break;
355  }
356  case Parameter::TABLES: {
357  return inferImpl(visitor, mrfGeneral_);
358  break;
359  }
360  case Parameter::TL1: {
361  return inferImpl(visitor, mrfTL1_);
362  break;
363  }
364  case Parameter::TL2: {
365  return inferImpl(visitor, mrfTL2_);
366  break;
367  }
368 /* case Parameter::WEIGHTEDTABLE: {
369  return inferImpl(visitor, mrf);
370  break;
371  }*/
372  default: {
373  throw(RuntimeError("Unknown energy type."));
374  }
375  }
376  }
377 
378  template<class GM>
379  inline InferenceTermination
380  TRWS<GM>
382  std::vector<LabelType>& arg,
383  const size_t& n
384  ) const {
385 
386  if(n > 1) {
387  return UNKNOWN;
388  }
389  else {
390  arg.resize(numNodes_);
391  switch(parameter_.energyType_) {
392  case Parameter::VIEW: {
393  for(IndexType i = 0; i < numNodes_; i++) {
394  arg[i] = mrfView_->GetSolution(nodesView_[i]);
395  }
396  return NORMAL;
397  break;
398  }
399  case Parameter::TABLES: {
400  for(IndexType i = 0; i < numNodes_; i++) {
401  arg[i] = mrfGeneral_->GetSolution(nodesGeneral_[i]);
402  }
403  return NORMAL;
404  break;
405  }
406  case Parameter::TL1: {
407  for(IndexType i = 0; i < numNodes_; i++) {
408  arg[i] = mrfTL1_->GetSolution(nodesTL1_[i]);
409  }
410  return NORMAL;
411  break;
412  }
413  case Parameter::TL2: {
414  for(IndexType i = 0; i < numNodes_; i++) {
415  arg[i] = mrfTL2_->GetSolution(nodesTL2_[i]);
416  }
417  return NORMAL;
418  break;
419  }
420 /* case Parameter::WEIGHTEDTABLE: {
421  for(IndexType i = 0; i < numNodes_; i++) {
422  arg[i] = mrfGeneral_->GetSolution(nodesGeneral_[i]);
423  }
424  return NORMAL;
425  break;
426  }*/
427  default: {
428  throw(RuntimeError("Unknown energy type."));
429  }
430  }
431  }
432  }
433 
437  template<class GM>
438  inline InferenceTermination
440  const size_t variableIndex,
442  ) const
443  {
444 
445  if(parameter_.calculateMinMarginals_){
446  out.assign(gm_, &variableIndex, &variableIndex+1, 0);
447  for(size_t i=0; i<gm_.numberOfLabels(variableIndex); ++i){
448  out(i) = minMarginals_[i+minMarginalsOffsets_[variableIndex]];
449  }
450  return NORMAL;
451  }else{
452  return UNKNOWN;
453  }
454  }
455 
456  template<class GM>
457  inline typename GM::ValueType
458  TRWS<GM>::bound() const {
459  return lowerBound_+constTerm_;
460  }
461  template<class GM>
462  inline typename GM::ValueType
463  TRWS<GM>::value() const {
464  return value_+constTerm_;
465  }
466 
467  template<class GM>
468  inline void TRWS<GM>::checkLabelNumber() {
469  hasSameLabelNumber_ = true;
470  for(IndexType i = 1; i < gm_.numberOfVariables(); i++) {
471  if(gm_.numberOfLabels(i) != maxNumLabels_) {
472  hasSameLabelNumber_ = false;
473  }
474  if(gm_.numberOfLabels(i) > maxNumLabels_) {
475  maxNumLabels_ = gm_.numberOfLabels(i);
476  }
477  }
478  }
479 
480  template<class GM>
481  inline void TRWS<GM>::generateMRFView() {
482  mrfView_ = new MRFEnergy<TypeView<GM> >(typename TypeView<GM>::GlobalSize());
483  nodesView_ = new typename MRFEnergy<TypeView<GM> >::NodeId[numNodes_];
484 
485  // add nodes
486  for(IndexType i = 0; i < numNodes_; i++) {
487  std::vector<typename GM::IndexType> factors;
488  for(typename GM::ConstFactorIterator iter = gm_.factorsOfVariableBegin(i); iter != gm_.factorsOfVariableEnd(i); iter++) {
489  if(gm_[*iter].numberOfVariables() == 1) {
490  factors.push_back(*iter);
491  }
492  }
493  nodesView_[i] = mrfView_->AddNode(typename TypeView<GM>::LocalSize(gm_.numberOfLabels(i)), typename TypeView<GM>::NodeData(gm_, factors));
494  }
495 
496  // add edges
497  constTerm_ = 0;
498  for(IndexType i = 0; i < gm_.numberOfFactors(); i++) {
499  if(gm_[i].numberOfVariables() == 0){
500  LabelType l = 0;
501  constTerm_ += gm_[i](&l);
502  }
503  if(gm_[i].numberOfVariables() == 2) {
504  IndexType a = gm_[i].variableIndex(0);
505  IndexType b = gm_[i].variableIndex(1);
506  mrfView_->AddEdge(nodesView_[a], nodesView_[b], typename TypeView<GM>::EdgeData(gm_, i));
507  }
508  }
509  // set random start message
510  if(parameter_.useRandomStart_) {
511  mrfView_->AddRandomMessages(1, 0.0, 1.0);
512  } else if(parameter_.useZeroStart_) {
513  mrfView_->ZeroMessages();
514  }
515  }
516 
517  template<class GM>
518  inline void TRWS<GM>::generateMRFTables() {
519  // add nodes
520  typename TypeGeneral::REAL* D = new typename TypeGeneral::REAL[maxNumLabels_];
521  addNodes(mrfGeneral_, nodesGeneral_, D);
522  delete[] D;
523 
524  // add edges
525  IndexType index[2];
526  constTerm_ = 0;
527  for(IndexType i = 0; i < gm_.numberOfFactors(); i++) {
528  if(gm_[i].numberOfVariables() == 0){
529  LabelType l = 0;
530  constTerm_ += gm_[i](&l);
531  }
532  if(gm_[i].numberOfVariables() == 2) {
533  IndexType a = gm_[i].variableIndex(0);
534  IndexType b = gm_[i].variableIndex(1);
535  IndexType numLabels_a = gm_.numberOfLabels(a);
536  IndexType numLabels_b = gm_.numberOfLabels(b);
537  typename TypeGeneral::REAL* V = new typename TypeGeneral::REAL[numLabels_a * numLabels_b];
538  for(size_t j = 0; j < numLabels_a; j++) {
539  for(size_t k = 0; k < numLabels_b; k++) {
540  index[0] = j;
541  index[1] = k;
542  V[j + k * numLabels_a] = gm_[i](index);
543  }
544  }
545  mrfGeneral_->AddEdge(nodesGeneral_[a], nodesGeneral_[b], TypeGeneral::EdgeData(TypeGeneral::GENERAL, V));
546  delete[] V;
547  }
548  }
549 
550  // set random start message
551  if(parameter_.useRandomStart_) {
552  mrfGeneral_->AddRandomMessages(1, 0.0, 1.0);
553  } else if(parameter_.useZeroStart_) {
554  mrfGeneral_->ZeroMessages();
555  }
556  }
557 
558  template<class GM>
559  inline void TRWS<GM>::generateMRFTL1() {
560  OPENGM_ASSERT(truncatedAbsoluteDifferenceFactors());
561 
562  // add nodes
563  typename TypeTruncatedLinear::REAL* D = new typename TypeTruncatedLinear::REAL[maxNumLabels_];
564  addNodes(mrfTL1_, nodesTL1_, D);
565  delete[] D;
566 
567  // add edges
568  constTerm_=0;
569  for(IndexType i = 0; i < gm_.numberOfFactors(); i++) {
570  if(gm_[i].numberOfVariables() == 0){
571  LabelType l = 0;
572  constTerm_ += gm_[i](&l);
573  }
574  if(gm_[i].numberOfVariables() == 2) {
575  // truncation
576  ValueType t = getT(i);
577  //std::cout << "t: " << t << std::endl;
578 
579  // weight
580  IndexType index[] = {0, 1};
581  ValueType w = gm_[i](index);
582  //std::cout << "w: " << w << std::endl;
583 
584  // corresponding node IDs
585  IndexType a = gm_[i].variableIndex(0);
586  IndexType b = gm_[i].variableIndex(1);
587  mrfTL1_->AddEdge(nodesTL1_[a], nodesTL1_[b], TypeTruncatedLinear::EdgeData(w, w * t));
588  }
589  }
590 
591  // set random start message
592  if(parameter_.useRandomStart_) {
593  mrfTL1_->AddRandomMessages(1, 0.0, 1.0);
594  } else if(parameter_.useZeroStart_) {
595  mrfTL1_->ZeroMessages();
596  }
597  }
598 
599  template<class GM>
600  inline void TRWS<GM>::generateMRFTL2() {
601  OPENGM_ASSERT(truncatedSquaredDifferenceFactors());
602 
603  // add nodes
604  typename TypeTruncatedQuadratic::REAL* D = new typename TypeTruncatedQuadratic::REAL[maxNumLabels_];
605  addNodes(mrfTL2_, nodesTL2_, D);
606  delete[] D;
607 
608  // add edges
609  constTerm_=0;
610  for(IndexType i = 0; i < gm_.numberOfFactors(); i++) {
611  if(gm_[i].numberOfVariables() == 0){
612  LabelType l = 0;
613  constTerm_ += gm_[i](&l);
614  }
615  if(gm_[i].numberOfVariables() == 2) {
616  // truncation
617  ValueType t = getT(i);
618  //std::cout << "t: " << t << std::endl;
619 
620  // weight
621  IndexType index[] = {0, 1};
622  ValueType w = gm_[i](index);
623  //std::cout << "w: " << w << std::endl;
624 
625  // corresponding node IDs
626  IndexType a = gm_[i].variableIndex(0);
627  IndexType b = gm_[i].variableIndex(1);
628  mrfTL2_->AddEdge(nodesTL2_[a], nodesTL2_[b], TypeTruncatedQuadratic::EdgeData(w, w * t));
629  }
630  }
631 
632  //mrfTL2_->SetAutomaticOrdering();
633 
634  // set random start message
635  if(parameter_.useRandomStart_) {
636  mrfTL2_->AddRandomMessages(1, 0.0, 1.0);
637  } else if(parameter_.useZeroStart_) {
638  mrfTL2_->ZeroMessages();
639  }
640  }
641 
642 /* template<class GM>
643  inline void TRWS<GM>::generateMRFWeightedTable() {
644 
645  }*/
646 
647  template<class GM>
648  inline typename GM::ValueType TRWS<GM>::getT(IndexType factor) const {
649  OPENGM_ASSERT(gm_.numberOfVariables(factor) == 2);
650 
651  IndexType index1[] = {0, 1};
652  IndexType index0[] = {0, maxNumLabels_-1};
653 
654  return gm_[factor](index0)/gm_[factor](index1);
655  }
656 
657  template<class GM>
659  for(IndexType i = 0; i < gm_.numberOfFactors(); i++) {
660  if(gm_.numberOfVariables(i) == 2) {
661  if(gm_[i].isTruncatedAbsoluteDifference() == false) {
662  return false;
663  }
664  }
665  }
666  return true;
667  }
668 
669  template<class GM>
670  inline bool TRWS<GM>::truncatedSquaredDifferenceFactors() const {
671  for(IndexType i = 0; i < gm_.numberOfFactors(); i++) {
672  if(gm_.numberOfVariables(i) == 2) {
673  if(gm_[i].isTruncatedSquaredDifference() == false) {
674  return false;
675  }
676  }
677  }
678  return true;
679  }
680 
681  template<class GM>
682  template <class ENERGYTYPE>
683  inline void TRWS<GM>::addNodes(MRFEnergy<ENERGYTYPE>*& mrf, typename MRFEnergy<ENERGYTYPE>::NodeId*& nodes, typename ENERGYTYPE::REAL* D) {
684 
685  mrf = reinterpret_cast<MRFEnergy<ENERGYTYPE>*>(createMRFEnergy<GM, ENERGYTYPE>::create(maxNumLabels_));
686 
687  nodes = new typename MRFEnergy<ENERGYTYPE>::NodeId[numNodes_];
688  for(IndexType i = 0; i < numNodes_; i++) {
689  for(IndexType j = 0; j < gm_.numberOfLabels(i); j++) {
690  D[j] = 0.0;
691  }
692  for(typename GM::ConstFactorIterator iter = gm_.factorsOfVariableBegin(i); iter != gm_.factorsOfVariableEnd(i); iter++) {
693  if(gm_[*iter].numberOfVariables() == 1) {
694  for(IndexType j = 0; j < gm_.numberOfLabels(i); j++) {
695  D[j] += gm_[*iter](&j);
696  }
697  }
698  }
699  nodes[i] = addMRFNode<GM, ENERGYTYPE>::add(mrf, gm_.numberOfLabels(i), D);
700  }
701  }
702 
703  template<class GM, class ENERGYTYPE>
704  inline void* createMRFEnergy<GM, ENERGYTYPE>::create(typename GM::IndexType numLabels) {
705  RuntimeError("Unsupported Energy Type!");
706  return NULL;
707  }
708 
709  template<class GM>
710  inline void* createMRFEnergy<GM, TypeView<GM> >::create(typename GM::IndexType numLabels) {
711  return reinterpret_cast<void*>(new MRFEnergy<TypeView<GM> >(typename TypeView<GM>::GlobalSize()));
712  }
713 
714  template<class GM>
715  inline void* createMRFEnergy<GM, TypeGeneral>::create(typename GM::IndexType numLabels) {
716  return reinterpret_cast<void*>(new MRFEnergy<TypeGeneral>(typename TypeGeneral::GlobalSize()));
717  }
718 
719  template<class GM>
720  inline void* createMRFEnergy<GM, TypeTruncatedLinear>::create(typename GM::IndexType numLabels) {
721  return reinterpret_cast<void*>(new MRFEnergy<TypeTruncatedLinear>(typename TypeTruncatedLinear::GlobalSize(numLabels)));
722  }
723 
724  template<class GM>
725  inline void* createMRFEnergy<GM, TypeTruncatedQuadratic>::create(typename GM::IndexType numLabels) {
726  return reinterpret_cast<void*>(new MRFEnergy<TypeTruncatedQuadratic>(typename TypeTruncatedQuadratic::GlobalSize(numLabels)));
727  }
728 
729  template<class GM, class ENERGYTYPE>
730  inline typename MRFEnergy<ENERGYTYPE>::NodeId addMRFNode<GM, ENERGYTYPE>::add(MRFEnergy<ENERGYTYPE>* mrf, typename GM::IndexType numLabels, typename ENERGYTYPE::REAL* D) {
731  RuntimeError("Unsupported Energy Type!");
732  return NULL;
733  }
734 
735  template<class GM>
736  inline typename MRFEnergy<TypeView<GM> >::NodeId addMRFNode<GM, TypeView<GM> >::add(MRFEnergy<TypeView<GM> >* mrf, typename GM::IndexType numLabels, typename TypeView<GM>::REAL* D) {
737  return mrf->AddNode(typename TypeView<GM>::LocalSize(numLabels), typename TypeView<GM>::NodeData(D));
738  }
739 
740  template<class GM>
741  inline typename MRFEnergy<TypeGeneral>::NodeId addMRFNode<GM, TypeGeneral>::add(MRFEnergy<TypeGeneral>* mrf, typename GM::IndexType numLabels, typename TypeGeneral::REAL* D) {
742  return mrf->AddNode(typename TypeGeneral::LocalSize(numLabels), typename TypeGeneral::NodeData(D));
743  }
744 
745  template<class GM>
746  inline typename MRFEnergy<TypeTruncatedLinear>::NodeId addMRFNode<GM, TypeTruncatedLinear>::add(MRFEnergy<TypeTruncatedLinear>* mrf, typename GM::IndexType numLabels, typename TypeTruncatedLinear::REAL* D) {
747  return mrf->AddNode(typename TypeTruncatedLinear::LocalSize(), typename TypeTruncatedLinear::NodeData(D));
748  }
749 
750  template<class GM>
751  inline typename MRFEnergy<TypeTruncatedQuadratic>::NodeId addMRFNode<GM, TypeTruncatedQuadratic>::add(MRFEnergy<TypeTruncatedQuadratic>* mrf, typename GM::IndexType numLabels, typename TypeTruncatedQuadratic::REAL* D) {
752  return mrf->AddNode(typename TypeTruncatedQuadratic::LocalSize(), typename TypeTruncatedQuadratic::NodeData(D));
753  }
754 
755  template<class GM>
756  template<class VISITOR, class ENERGYTYPE>
757  inline InferenceTermination TRWS<GM>::inferImpl(VISITOR & visitor, MRFEnergy<ENERGYTYPE>* mrf) {
758  typename MRFEnergy<ENERGYTYPE>::Options options;
759  options.m_iterMax = 1; // maximum number of iterations
760  options.m_printIter = 2 * parameter_.numberOfIterations_;
761  visitor.begin(*this);
762 
763 
764  if(parameter_.doBPS_) {
765  typename ENERGYTYPE::REAL v;
766  for(size_t i = 0; i < parameter_.numberOfIterations_; ++i) {
767  mrf->Minimize_BP(options, v, minMarginals_);
768  value_ = v;
769  if( visitor(*this) != visitors::VisitorReturnFlag::ContinueInf ) {
770  break;
771  }
772  }
773  } else {
774  typename ENERGYTYPE::REAL v;
775  typename ENERGYTYPE::REAL b;
776  typename ENERGYTYPE::REAL d;
777  for(size_t i = 0; i < parameter_.numberOfIterations_; ++i) {
778  mrf->Minimize_TRW_S(options, b, v, minMarginals_);
779  d = b-lowerBound_;
780  lowerBound_ = b;
781  value_ = v;
782  if( visitor(*this) != visitors::VisitorReturnFlag::ContinueInf ) {
783  break;
784  }
785  if(fabs(value_ - lowerBound_) / opengmMax(static_cast<double>(fabs(value_)), 1.0) < parameter_.tolerance_) {
786  break;
787  }
788  if(d<parameter_.minDualChange_){
789  break;
790  }
791  }
792  }
793  //Copy MinMarginals
794 
795  visitor.end(*this);
796  return NORMAL;
797  }
798 
799  } // namespace external
800 } // namespace opengm
801 
802 #endif // #ifndef OPENGM_EXTERNAL_TRWS_HXX
The OpenGM namespace.
Definition: config.hxx:43
InferenceTermination marginal(const size_t variableIndex, IndependentFactorType &out) const
output the min-marginal of a specific variable (only available when calculateMinMarginals_ is set; otherwise returns UNKNOWN)
Definition: trws.hxx:439
const GraphicalModelType & graphicalModel() const
Definition: trws.hxx:331
GM::ValueType value() const
return the solution (value)
Definition: trws.hxx:463
bool calculateMinMarginals_
Calculate MinMarginals.
Definition: trws.hxx:79
visitors::EmptyVisitor< TRWS< GM > > EmptyVisitorType
Definition: trws.hxx:45
bool useZeroStart_
zero starting message
Definition: trws.hxx:69
size_t VariableIndex
Definition: trws.hxx:47
T opengmMax(const T &x, const T &y)
Definition: opengm.hxx:116
InferenceTermination infer()
Definition: trws.hxx:338
InferenceTermination arg(std::vector< LabelType > &, const size_t &=1) const
Definition: trws.hxx:381
bool useRandomStart_
random starting message
Definition: trws.hxx:67
void create(const hid_t &, const std::string &, ShapeIterator, ShapeIterator, CoordinateOrder)
Create and close an HDF5 dataset to store Marray data.
#define OPENGM_ASSERT(expression)
Definition: opengm.hxx:77
std::string name() const
Definition: trws.hxx:324
static MRFEnergy< ENERGYTYPE >::NodeId add(MRFEnergy< ENERGYTYPE > *mrf, typename GM::IndexType numLabels, typename ENERGYTYPE::REAL *D)
Definition: trws.hxx:730
double minDualChange_
TRWS terminates if the lower bound increases by less than minDualChange_ in one iteration, i.e. bound(t+1) - bound(t) < minDualChange_.
Definition: trws.hxx:77
GraphicalModelType::IndexType IndexType
Definition: inference.hxx:49
EnergyType
possible energy types for TRWS
Definition: trws.hxx:63
GraphicalModelType::ValueType ValueType
Definition: inference.hxx:50
static T ineutral()
inverse neutral element (with return)
Definition: minimizer.hxx:25
Inference algorithm interface.
Definition: inference.hxx:43
GM::ValueType bound() const
return a bound on the solution
Definition: trws.hxx:458
double tolerance_
TRWS terminates if fabs(value - bound) / max(fabs(value), 1) < tolerance_.
Definition: trws.hxx:75
size_t numberOfIterations_
number of iterations
Definition: trws.hxx:65
visitors::VerboseVisitor< TRWS< GM > > VerboseVisitorType
Definition: trws.hxx:44
static T neutral()
neutral element (with return)
Definition: minimizer.hxx:16
TRWS(const GraphicalModelType &gm, const Parameter para=Parameter())
Definition: trws.hxx:229
double REAL
Definition: typeView.h:22
Minimization as a unary accumulation.
Definition: minimizer.hxx:12
static void * create(typename GM::IndexType numLabels)
Definition: trws.hxx:704
message passing (BPS, TRWS): [?]
Definition: trws.hxx:39
bool doBPS_
use normal LBP
Definition: trws.hxx:71
Parameter(const P &p)
Constructor.
Definition: trws.hxx:82
EnergyType energyType_
selected energy type
Definition: trws.hxx:73
opengm::Minimizer AccumulationType
Definition: trws.hxx:42
OpenGM runtime error.
Definition: opengm.hxx:100
visitors::TimingVisitor< TRWS< GM > > TimingVisitorType
Definition: trws.hxx:46
InferenceTermination
Definition: inference.hxx:24
GraphicalModelType::IndependentFactorType IndependentFactorType
Definition: inference.hxx:53