2 #ifndef OPENGM_EXTERNAL_AD3_HXX 3 #define OPENGM_EXTERNAL_AD3_HXX 11 #include "ad3/FactorGraph.h" 22 template<
class GM,
class ACC>
38 template<
class _GM,
class _ACC>
52 const double eta = 0.1,
53 const bool adaptEta =
true,
55 const double residualThreshold = 1e-6,
56 const int verbosity = 0
58 solverType_(solverType),
62 residualThreshold_(residualThreshold),
71 solverType_(p.solverType_),
73 adaptEta_(p.adaptEta_),
75 residualThreshold_(p.residualThreshold_),
76 verbosity_(p.verbosity_)
94 std::string
name()
const;
98 template<
class VisitorType>
103 return gm_.evaluate(arg_);
107 if(inferenceDone_ && parameter_.solverType_==
AD3_ILP ){
117 if( meta::Compare<OperatorType,Adder>::value && meta::Compare<AccumulationType,Minimizer>::value){
120 else if( meta::Compare<OperatorType,Adder>::value && meta::Compare<AccumulationType,Maximizer>::value){
126 if( meta::Compare<OperatorType,Adder>::value && meta::Compare<AccumulationType,Minimizer>::value){
129 else if( meta::Compare<OperatorType,Adder>::value && meta::Compare<AccumulationType,Maximizer>::value){
137 template<
class N_LABELS_ITER>
144 template<
class VI_ITERATOR,
class FUNCTION>
145 void addFactor(VI_ITERATOR viBegin,VI_ITERATOR viEnd,
const FUNCTION &
function);
153 return additional_posteriors_;
158 const GraphicalModelType& gm_;
163 AD3::FactorGraph factor_graph_;
164 std::vector<AD3::MultiVariable*> multi_variables_;
166 std::vector<double> posteriors_;
167 std::vector<double> additional_posteriors_;
170 std::vector<LabelType> arg_;
174 std::vector<LabelType> space_;
182 template<
class GM,
class ACC>
190 numVar_(gm.numberOfVariables()),
192 multi_variables_(gm.numberOfVariables()),
194 additional_posteriors_(),
196 arg_(gm.numberOfVariables(),
static_cast<LabelType>(0)),
197 inferenceDone_(
false),
201 if(meta::Compare<OperatorType,Adder>::value==
false){
202 throw RuntimeError(
"AD3 does not only support opengm::Adder as Operator");
205 if(meta::Compare<AccumulationType,Minimizer>::value==
false and meta::Compare<AccumulationType,Maximizer>::value==
false ){
206 throw RuntimeError(
"AD3 does not only support opengm::Minimizer and opengm::Maximizer as Accumulatpr");
210 bound_ = ACC::template ineutral<ValueType>();
214 factor_graph_.SetVerbosity(parameter_.verbosity_);
216 for(
IndexType fi=0;fi<gm_.numberOfFactors();++fi){
217 maxFactorSize=std::max(static_cast<UInt64Type>(gm_[fi].size()),maxFactorSize);
226 for(
IndexType vi=0;vi<gm_.numberOfVariables();++vi){
227 multi_variables_[vi] = factor_graph_.CreateMultiVariable(gm_.numberOfLabels(vi));
228 for(
LabelType l=0;l<gm_.numberOfLabels(vi);++l){
229 multi_variables_[vi]->SetLogPotential(l,0.0);
236 for(
IndexType fi=0;fi<gm_.numberOfFactors();++fi){
238 gm_[fi].copyValuesSwitchedOrder(facVal);
239 const IndexType nVar=gm_[fi].numberOfVariables();
242 const IndexType vi0 = gm_[fi].variableIndex(0);
243 const IndexType nl0 = gm_.numberOfLabels(vi0);
246 const ValueType logP = multi_variables_[vi0]->GetLogPotential(l);
248 multi_variables_[vi0]->SetLogPotential(l,logP+val);
254 std::vector<double> additional_log_potentials(gm_[fi].size());
260 std::vector<AD3::MultiVariable*> multi_variables_local(nVar);
262 multi_variables_local[v]=multi_variables_[gm_[fi].variableIndex(v)];
266 factor_graph_.CreateFactorDense(multi_variables_local,additional_log_potentials);
269 OPENGM_CHECK(
false,
"const factors are not yet implemented");
278 template<
class GM,
class ACC>
279 template<
class N_LABELS_ITER>
281 N_LABELS_ITER nLabelsBegin,
282 N_LABELS_ITER nLabelsEnd,
287 numVar_(
std::distance(nLabelsBegin,nLabelsEnd)),
289 multi_variables_(
std::distance(nLabelsBegin,nLabelsEnd)),
291 additional_posteriors_(),
293 arg_(
std::distance(nLabelsBegin,nLabelsEnd),static_cast<
LabelType>(0)),
294 space_(nLabelsBegin,nLabelsEnd)
297 if(meta::Compare<OperatorType,Adder>::value==
false){
298 throw RuntimeError(
"AD3 does not only support opengm::Adder as Operator");
300 if(meta::Compare<AccumulationType,Minimizer>::value==
false and meta::Compare<AccumulationType,Maximizer>::value==
false ){
301 throw RuntimeError(
"AD3 does not only support opengm::Minimizer and opengm::Maximizer as Accumulatpr");
303 bound_ = ACC::template ineutral<ValueType>();
304 factor_graph_.SetVerbosity(parameter_.
verbosity_);
308 multi_variables_[vi] = factor_graph_.CreateMultiVariable(space_[vi]);
310 multi_variables_[vi]->SetLogPotential(l,0.0);
315 template<
class GM,
class ACC>
326 multi_variables_(nVar),
328 additional_posteriors_(),
334 if(meta::Compare<OperatorType,Adder>::value==
false){
335 throw RuntimeError(
"AD3 does not only support opengm::Adder as Operator");
337 if(meta::Compare<AccumulationType,Minimizer>::value==
false and meta::Compare<AccumulationType,Maximizer>::value==
false ){
338 throw RuntimeError(
"AD3 does not only support opengm::Minimizer and opengm::Maximizer as Accumulatpr");
340 bound_ = ACC::template ineutral<ValueType>();
341 factor_graph_.SetVerbosity(parameter_.
verbosity_);
343 multi_variables_[vi] = factor_graph_.CreateMultiVariable(space_[vi]);
345 multi_variables_[vi]->SetLogPotential(l,0.0);
351 template<
class GM,
class ACC>
352 template<
class VI_ITERATOR,
class FUNCTION>
355 VI_ITERATOR visBegin,
357 const FUNCTION &
function 359 const IndexType nVis = std::distance(visBegin,visEnd);
360 OPENGM_CHECK_OP(nVis,==,
function.dimension(),
"functions dimension does not match number of variabole indices");
363 OPENGM_CHECK_OP(space_[visBegin[v]],==,
function.shape(v),
"functions shape does not match space");
369 for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0]){
370 const ValueType logP = multi_variables_[visBegin[0]]->GetLogPotential(l[0]);
372 multi_variables_[visBegin[0]]->SetLogPotential(l[0],logP+val);
379 std::vector<AD3::MultiVariable*> multi_variables_local(nVis);
381 multi_variables_local[v]=multi_variables_[visBegin[v]];
385 std::vector<double> additional_log_potentials(
function.size());
392 for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
393 for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1]){
394 additional_log_potentials[c]=this->
valueToMaxSum(
function(l));
401 for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
402 for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
403 for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2]){
404 additional_log_potentials[c]=this->
valueToMaxSum(
function(l));
411 for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
412 for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
413 for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
414 for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3]){
415 additional_log_potentials[c]=this->
valueToMaxSum(
function(l));
422 for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
423 for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
424 for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
425 for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3])
426 for(l[4]=0; l[4]<space_[visBegin[4]]; ++l[4]){
427 additional_log_potentials[c]=this->
valueToMaxSum(
function(l));
434 for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
435 for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
436 for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
437 for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3])
438 for(l[4]=0; l[4]<space_[visBegin[4]]; ++l[4])
439 for(l[5]=0; l[5]<space_[visBegin[5]]; ++l[5]){
440 additional_log_potentials[c]=this->
valueToMaxSum(
function(l));
447 for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
448 for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
449 for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
450 for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3])
451 for(l[4]=0; l[4]<space_[visBegin[4]]; ++l[4])
452 for(l[5]=0; l[5]<space_[visBegin[5]]; ++l[5])
453 for(l[6]=0; l[6]<space_[visBegin[6]]; ++l[6]){
454 additional_log_potentials[c]=this->
valueToMaxSum(
function(l));
461 for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
462 for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
463 for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
464 for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3])
465 for(l[4]=0; l[4]<space_[visBegin[4]]; ++l[4])
466 for(l[5]=0; l[5]<space_[visBegin[5]]; ++l[5])
467 for(l[6]=0; l[6]<space_[visBegin[6]]; ++l[6])
468 for(l[7]=0; l[7]<space_[visBegin[7]]; ++l[7])
470 additional_log_potentials[c]=this->
valueToMaxSum(
function(l));
477 for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
478 for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
479 for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
480 for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3])
481 for(l[4]=0; l[4]<space_[visBegin[4]]; ++l[4])
482 for(l[5]=0; l[5]<space_[visBegin[5]]; ++l[5])
483 for(l[6]=0; l[6]<space_[visBegin[6]]; ++l[6])
484 for(l[7]=0; l[7]<space_[visBegin[7]]; ++l[7])
485 for(l[8]=0; l[8]<space_[visBegin[8]]; ++l[8])
487 additional_log_potentials[c]=this->
valueToMaxSum(
function(l));
494 for(l[0]=0; l[0]<space_[visBegin[0]]; ++l[0])
495 for(l[1]=0; l[1]<space_[visBegin[1]]; ++l[1])
496 for(l[2]=0; l[2]<space_[visBegin[2]]; ++l[2])
497 for(l[3]=0; l[3]<space_[visBegin[3]]; ++l[3])
498 for(l[4]=0; l[4]<space_[visBegin[4]]; ++l[4])
499 for(l[5]=0; l[5]<space_[visBegin[5]]; ++l[5])
500 for(l[6]=0; l[6]<space_[visBegin[6]]; ++l[6])
501 for(l[7]=0; l[7]<space_[visBegin[7]]; ++l[7])
502 for(l[8]=0; l[8]<space_[visBegin[8]]; ++l[8])
503 for(l[9]=0; l[9]<space_[visBegin[9]]; ++l[9])
505 additional_log_potentials[c]=this->
valueToMaxSum(
function(l));
510 throw RuntimeError(
"order must be <=10 for inplace building of Ad3Inf (call us if you need higher order)");
518 factor_graph_.CreateFactorDense(multi_variables_local,additional_log_potentials);
525 template<
class GM,
class ACC>
531 template<
class GM,
class ACC>
538 template<
class GM,
class ACC>
545 template<
class GM,
class ACC>
553 template<
class GM,
class ACC>
554 template<
class VisitorType>
558 visitor.begin(*
this);
562 factor_graph_.SetEtaAD3(parameter_.
eta_);
563 factor_graph_.AdaptEtaAD3(parameter_.
adaptEta_);
564 factor_graph_.SetMaxIterationsAD3(parameter_.
steps_);
568 factor_graph_.SetEtaPSDD(parameter_.
eta_);
569 factor_graph_.SetMaxIterationsPSDD(parameter_.
steps_);
577 factor_graph_.SolveLPMAPWithAD3(&posteriors_, &additional_posteriors_, &value, &bound_);
581 factor_graph_.SolveExactMAPWithAD3(&posteriors_, &additional_posteriors_, &value, &bound_);
585 factor_graph_.SolveExactMAPWithAD3(&posteriors_, &additional_posteriors_, &value, &bound_);
593 for(
IndexType vi = 0; vi < numVar_; ++vi) {
595 double bestVal = -100000;
596 const LabelType nLabels = (space_.size()==0 ? gm_.numberOfLabels(vi) : space_[vi] );
598 const double val = posteriors_[c];
601 if(bestVal<0 || val>bestVal){
616 template<
class GM,
class ACC>
619 ::arg(std::vector<LabelType>&
arg,
const size_t& n)
const {
625 std::copy(arg_.begin(),arg_.end(),arg.begin());
634 #endif // #ifndef OPENGM_EXTERNAL_AD3Inf_HXX
AD3Inf(const GraphicalModelType &gm, const Parameter para=Parameter())
void addFactor(VI_ITERATOR viBegin, VI_ITERATOR viEnd, const FUNCTION &function)
ValueType valueToMaxSum(const ValueType val) const
const std::vector< double > & higherOrderPosteriors() const
ValueType value() const
return the solution (value)
Parameter(const SolverType solverType=AD3_ILP, const double eta=0.1, const bool adaptEta=true, UInt64Type steps=1000, const double residualThreshold=1e-6, const int verbosity=0)
const GraphicalModelType & graphicalModel() const
detail_types::UInt64Type UInt64Type
uint64
ValueType bound() const
return a bound on the solution
double residualThreshold_
const std::vector< double > & posteriors() const
GraphicalModelType::IndexType IndexType
visitors::EmptyVisitor< AD3Inf< GM, ACC > > EmptyVisitorType
GraphicalModelType::ValueType ValueType
Inference algorithm interface.
ValueType valueFromMaxSum(const ValueType val) const
InferenceTermination arg(std::vector< LabelType > &, const size_t &=1) const
#define OPENGM_CHECK_OP(A, OP, B, TXT)
visitors::VerboseVisitor< AD3Inf< GM, ACC > > VerboseVisitorType
InferenceTermination infer()
#define OPENGM_CHECK(B, TXT)
visitors::TimingVisitor< AD3Inf< GM, ACC > > TimingVisitorType
GraphicalModelType::LabelType LabelType