OpenGM  2.3.x
Discrete Graphical Model Library
alphabetaswap.hxx
Go to the documentation of this file.
1 #pragma once
2 #ifndef OPENGM_ALPHABETASWAP_HXX
3 #define OPENGM_ALPHABETASWAP_HXX
4 
5 #include <vector>
6 
9 
10 namespace opengm {
11 
14 template<class GM, class INF>
15 class AlphaBetaSwap : public Inference<GM, typename INF::AccumulationType> {
16 public:
17  typedef GM GraphicalModelType;
18  typedef INF InferenceType;
19  typedef typename INF::AccumulationType AccumulationType;
24 
25 
26  template<class _GM>
27  struct RebindGm{
28  typedef typename INF:: template RebindGm<_GM>::type RebindedInf;
30  };
31 
32  template<class _GM,class _ACC>
34  typedef typename INF:: template RebindGmAndAcc<_GM,_ACC>::type RebindedInf;
36  };
37 
38 
39  struct Parameter {
41  maxNumberOfIterations_ = 1000;
42  }
43  template<class P>
44  Parameter(const P & p)
45  : parameter_(p.parameter_),
46  maxNumberOfIterations_(maxNumberOfIterations_){
47  }
48 
49  typename InferenceType::Parameter parameter_;
51  };
52 
53  AlphaBetaSwap(const GraphicalModelType&, Parameter = Parameter());
54  std::string name() const;
55  const GraphicalModelType& graphicalModel() const;
57  template<class VISITOR>
58  InferenceTermination infer(VISITOR & );
59  void reset();
60  void setStartingPoint(typename std::vector<LabelType>::const_iterator);
61  InferenceTermination arg(std::vector<LabelType>&, const size_t = 1) const;
62 
63 private:
64  const GraphicalModelType& gm_;
65  Parameter parameter_;
66  std::vector<LabelType> label_;
67  size_t alpha_;
68  size_t beta_;
69  size_t maxState_;
70  void increment();
71  void addUnary(INF&, const size_t var, const ValueType v0, const ValueType v1);
72  void addPairwise(INF&, const size_t var1, const size_t var2, const ValueType v0, const ValueType v1, const ValueType v2, const ValueType v3);
73 };
74 
75 // reset assumes that the structure of the graphical model has not changed
76 template<class GM, class INF>
77 inline void
79  alpha_ = 0;
80  beta_ = 0;
81  std::fill(label_.begin(),label_.end(),0);
82 }
83 
84 template<class GM, class INF>
85 inline void
87  if (++beta_ >= maxState_) {
88  if (++alpha_ >= maxState_ - 1) {
89  alpha_ = 0;
90  }
91  beta_ = alpha_ + 1;
92  }
93  OPENGM_ASSERT(alpha_ < maxState_);
94  OPENGM_ASSERT(beta_ < maxState_);
95  OPENGM_ASSERT(alpha_ < beta_);
96 }
97 
98 template<class GM, class INF>
99 inline std::string
101  return "Alpha-Beta-Swap";
102 }
103 
104 template<class GM, class INF>
105 inline const typename AlphaBetaSwap<GM, INF>::GraphicalModelType&
107  return gm_;
108 }
109 
110 template<class GM, class INF>
112 (
113  const GraphicalModelType& gm,
114  Parameter para
115 )
116 : gm_(gm)
117 {
118  parameter_ = para;
119  label_.resize(gm_.numberOfVariables(), 0);
120  alpha_ = 0;
121  beta_ = 0;
122  for (size_t j = 0; j < gm_.numberOfFactors(); ++j) {
123  if (gm_[j].numberOfVariables() > 2) {
124  throw RuntimeError("This implementation of Alpha-Beta-Swap supports only factors of order <= 2.");
125  }
126  }
127  maxState_ = 0;
128  for (size_t i = 0; i < gm_.numberOfVariables(); ++i) {
129  size_t numSt = gm_.numberOfLabels(i);
130  if (numSt > maxState_)
131  maxState_ = numSt;
132  }
133 }
134 
135 template<class GM, class INF>
136 inline void
138 (
139  typename std::vector<typename AlphaBetaSwap<GM,INF>::LabelType>::const_iterator begin
140 ) {
141  try{
142  label_.assign(begin, begin+gm_.numberOfVariables());
143  }
144  catch(...) {
145  throw RuntimeError("unsuitable starting point");
146  }
147 }
148 
149 template<class GM, class INF>
150 inline void
152 (
153  INF& inf,
154  const size_t var1,
155  const ValueType v0,
156  const ValueType v1
157 ) {
158  const size_t shape[] = {2};
159  const size_t vars[] = {var1};
160  opengm::IndependentFactor<ValueType,IndexType,LabelType> fac(vars, vars + 1, shape, shape + 1);
161  fac(0) = v0;
162  fac(1) = v1;
163  inf.addFactor(fac);
164 }
165 
166 template<class GM, class INF>
167 inline void
169 (
170  INF& inf,
171  const size_t var1,
172  const size_t var2,
173  const ValueType v0,
174  const ValueType v1,
175  const ValueType v2,
176  const ValueType v3
177 ) {
178  const size_t shape[] = {2, 2};
179  const size_t vars[] = {var1, var2};
180  opengm::IndependentFactor<ValueType,IndexType,LabelType> fac(vars, vars + 2, shape, shape + 2);
181  fac(0, 0) = v0;
182  fac(0, 1) = v1;
183  fac(1, 0) = v2;
184  fac(1, 1) = v3;
185  OPENGM_ASSERT(v1 + v2 - v0 - v3 >= 0);
186  inf.addFactor(fac);
187 }
188 template<class GM, class INF>
191  EmptyVisitorType v;
192  return infer(v);
193 }
194 
195 template<class GM, class INF>
196 template<class VISITOR>
199 (
200  VISITOR & visitor
201 ) {
202  bool exitInf=false;
203  visitor.begin(*this);
204  size_t it = 0;
205  size_t countUnchanged = 0;
206  size_t numberOfVariables = gm_.numberOfVariables();
207  std::vector<size_t> variable2Node(numberOfVariables, 0);
208  ValueType energy = gm_.evaluate(label_);
209  size_t vecA[1];
210  size_t vecB[1];
211  size_t vecAA[2];
212  size_t vecAB[2];
213  size_t vecBA[2];
214  size_t vecBB[2];
215  size_t vecAX[2];
216  size_t vecBX[2];
217  size_t vecXA[2];
218  size_t vecXB[2];
219  size_t numberOfLabelPairs = maxState_*(maxState_ - 1)/2;
220  while (it++ < parameter_.maxNumberOfIterations_ && countUnchanged < numberOfLabelPairs && exitInf == false) {
221  increment();
222  size_t counter = 0;
223  std::vector<size_t> numFacDim(4, 0);
224  for (size_t i = 0; i < numberOfVariables; ++i) {
225  if (label_[i] == alpha_ || label_[i] == beta_) {
226  variable2Node[i] = counter++;
227  }
228  }
229  if (counter == 0) {
230  continue;
231  }
232  INF inf(counter, numFacDim);
233  vecA[0] = alpha_;
234  vecB[0] = beta_;
235  vecAA[0] = alpha_;
236  vecAA[1] = alpha_;
237  vecBB[0] = beta_;
238  vecBB[1] = beta_;
239  vecBA[0] = beta_;
240  vecBA[1] = alpha_;
241  vecAB[0] = alpha_;
242  vecAB[1] = beta_;
243  vecAX[0] = alpha_;
244  vecBX[0] = beta_;
245  vecXA[1] = alpha_;
246  vecXB[1] = beta_;
247  for (size_t k = 0; k < gm_.numberOfFactors(); ++k) {
248  const FactorType& factor = gm_[k];
249  if (factor.numberOfVariables() == 1) {
250  size_t var = factor.variableIndex(0);
251  size_t node = variable2Node[var];
252  if (label_[var] == alpha_ || label_[var] == beta_) {
253  OPENGM_ASSERT(alpha_ < gm_.numberOfLabels(var));
254  OPENGM_ASSERT(beta_ < gm_.numberOfLabels(var));
255  addUnary(inf, node, factor(vecA), factor(vecB));
256  //inf.addUnary(node, factor(vecA), factor(vecB));
257  }
258  } else if (factor.numberOfVariables() == 2) {
259  size_t var1 = factor.variableIndex(0);
260  size_t var2 = factor.variableIndex(1);
261  size_t node1 = variable2Node[var1];
262  size_t node2 = variable2Node[var2];
263 
264  if ((label_[var1] == alpha_ || label_[var1] == beta_) && (label_[var2] == alpha_ || label_[var2] == beta_)) {
265  addPairwise(inf, node1, node2, factor(vecAA), factor(vecAB), factor(vecBA), factor(vecBB));
266  //inf.addPairwise(node1, node2, factor(vecAA), factor(vecAB), factor(vecBA), factor(vecBB));
267  } else if ((label_[var1] == alpha_ || label_[var1] == beta_) && (label_[var2] != alpha_ && label_[var2] != beta_)) {
268  vecAX[1] = vecBX[1] = label_[var2];
269  addUnary(inf, node1, factor(vecAX), factor(vecBX));
270  //inf.addUnary(node1, factor(vecAX), factor(vecBX));
271  } else if ((label_[var2] == alpha_ || label_[var2] == beta_) && (label_[var1] != alpha_ && label_[var1] != beta_)) {
272  vecXA[0] = vecXB[0] = label_[var1];
273  addUnary(inf, node2, factor(vecXA), factor(vecXB));
274  //inf.addUnary(node2, factor(vecXA), factor(vecXB));
275  }
276  }
277  }
278  std::vector<LabelType> state; //(counter);
279  inf.infer();
280  inf.arg(state);
281  OPENGM_ASSERT(state.size() == counter);
282  for (size_t var = 0; var < numberOfVariables; ++var) {
283  if (label_[var] == alpha_ || label_[var] == beta_) {
284  if (state[variable2Node[var]] == 0)
285  label_[var] = alpha_;
286  else
287  label_[var] = beta_;
288  } else {
289  //do nothing
290  }
291  }
292  ValueType energy2 = gm_.evaluate(label_);
293  if( visitor(*this) != visitors::VisitorReturnFlag::ContinueInf ){
294  exitInf=true;
295  }
296  OPENGM_ASSERT(!AccumulationType::ibop(energy2, energy));
297  if (AccumulationType::bop(energy2, energy)) {
298  energy = energy2;
299  } else {
300  ++countUnchanged;
301  }
302  }
303  visitor.end(*this);
304  return NORMAL;
305 }
306 
307 template<class GM, class INF>
309 AlphaBetaSwap<GM, INF>::arg(std::vector<LabelType>& arg, const size_t n) const {
310  if (n > 1) {
311  return UNKNOWN;
312  } else {
313  OPENGM_ASSERT(label_.size() == gm_.numberOfVariables());
314  arg.resize(label_.size());
315  for (size_t i = 0; i < label_.size(); ++i)
316  arg[i] = label_[i];
317  return NORMAL;
318  }
319 }
320 
321 } // namespace opengm
322 
323 #endif // OPENGM_ALPHABETASWAP_HXX
const GraphicalModelType & graphicalModel() const
The OpenGM namespace.
Definition: config.hxx:43
Factor (with corresponding function and variable indices), independent of a GraphicalModel.
std::string name() const
AlphaBetaSwap< _GM, RebindedInf > type
opengm::visitors::EmptyVisitor< AlphaBetaSwap< GM, INF > > EmptyVisitorType
AlphaBetaSwap< _GM, RebindedInf > type
opengm::visitors::VerboseVisitor< AlphaBetaSwap< GM, INF > > VerboseVisitorType
#define OPENGM_ASSERT(expression)
Definition: opengm.hxx:77
AlphaBetaSwap(const GraphicalModelType &, Parameter=Parameter())
void setStartingPoint(typename std::vector< LabelType >::const_iterator)
INF::template RebindGmAndAcc< _GM, _ACC >::type RebindedInf
InferenceType::Parameter parameter_
opengm::visitors::TimingVisitor< AlphaBetaSwap< GM, INF > > TimingVisitorType
GraphicalModelType::FactorType FactorType
Definition: inference.hxx:52
Alpha-Beta-Swap Algorithm.
INF::AccumulationType AccumulationType
InferenceTermination arg(std::vector< LabelType > &, const size_t=1) const
GraphicalModelType::ValueType ValueType
Definition: inference.hxx:50
Inference algorithm interface.
Definition: inference.hxx:43
GraphicalModelType::LabelType LabelType
Definition: inference.hxx:48
InferenceTermination infer()
INF::template RebindGm< _GM >::type RebindedInf
OpenGM runtime error.
Definition: opengm.hxx:100
InferenceTermination
Definition: inference.hxx:24