2 #ifndef OPENGM_ASTAR_HXX 3 #define OPENGM_ASTAR_HXX 26 template<
class FactorType>
struct AStarNode {
27 typename std::vector<typename FactorType::LabelType> conf;
28 typename FactorType::ValueType value;
63 template<
class GM,
class ACC>
73 typedef typename std::vector<LabelType>
ConfVec ;
87 template<
class _GM,
class _ACC>
102 maxHeapSize_ = 3000000;
104 objectiveBound_ = AccumulationType::template neutral<ValueType>();
105 heuristic_ = Parameter::DEFAULTHEURISTIC;
110 : maxHeapSize_(p.maxHeapSize_),
111 numberOfOpt_(p.numberOfOpt_),
112 objectiveBound_(p.objectiveBound_),
113 nodeOrder_(p.nodeOrder_),
114 treeFactorIds_(p.treeFactorIds_){
120 { treeFactorIds_.push_back(
id); }
122 static const size_t DEFAULTHEURISTIC = 0;
124 static const size_t FASTHEURISTIC = 1;
126 static const size_t STANDARDHEURISTIC = 2;
145 virtual std::string
name()
const {
return "AStar";}
146 const GraphicalModelType& graphicalModel()
const;
148 virtual void reset();
163 std::vector<AStarNode<IndependentFactorType> > array_;
164 std::vector<size_t> numStates_;
166 std::vector<IndependentFactorType> treeFactor_;
167 std::vector<IndependentFactorType> optimizedFactor_;
168 std::vector<ConfVec > optConf_;
169 std::vector<bool> isTreeFactor_;
/// \brief Declaration of the expansion step, parameterized on the visitor type.
/// \param visitor visitor object supplied by the caller (taken by non-const reference)
template<class VisitorType>
void expand(VisitorType& visitor);
173 std::vector<ValueType> fastHeuristic(ConfVec conf);
174 inline static bool comp1(
const AStarNode<IndependentFactorType>& a,
const AStarNode<IndependentFactorType>& b)
175 {
return AccumulationType::ibop(a.value,b.value);};
176 inline static bool comp2(
const AStarNode<IndependentFactorType>& a,
const AStarNode<IndependentFactorType>& b)
177 {
return AccumulationType::bop(a.value,b.value);};
190 template<
class GM,
class ACC >
198 parameterInitial_=para;
200 if( parameter_.heuristic_ == Parameter::DEFAULTHEURISTIC) {
201 if(gm_.factorOrder()<=2)
202 parameter_.
heuristic_ = Parameter::FASTHEURISTIC;
204 parameter_.heuristic_ = Parameter::STANDARDHEURISTIC;
206 OPENGM_ASSERT(parameter_.heuristic_ == Parameter::FASTHEURISTIC || parameter_.heuristic_ == Parameter::STANDARDHEURISTIC);
207 ACC::ineutral(belowBound_);
208 ACC::neutral(aboveBound_);
210 isTreeFactor_.resize(gm_.numberOfFactors());
211 numStates_.resize(gm_.numberOfVariables());
212 numNodes_ = gm_.numberOfVariables();
213 for(
size_t i=0; i<numNodes_;++i)
214 numStates_[i] = gm_.numberOfLabels(i);
216 if(parameter_.nodeOrder_.size()==0) {
217 parameter_.nodeOrder_.resize(numNodes_);
218 std::vector<std::pair<IndexType,IndexType> > tmp(numNodes_,std::pair<IndexType,IndexType>());
219 for(
size_t i=0; i<numNodes_; ++i){
220 tmp[i].first = gm_.numberOfFactors(i);
223 std::sort(tmp.begin(),tmp.end());
224 for(
size_t i=0; i<numNodes_; ++i){
225 parameter_.nodeOrder_[i] = tmp[numNodes_-i-1].second;
229 if(parameter_.nodeOrder_.size()!=numNodes_)
230 throw RuntimeError(
"The node order does not fit to the model.");
231 OPENGM_ASSERT(std::set<size_t>(parameter_.nodeOrder_.begin(), parameter_.nodeOrder_.end()).size()==numNodes_);
232 for(
size_t i=0;i<numNodes_; ++i) {
237 if(parameter_.treeFactorIds_.size()==0) {
239 for(
size_t i=0; i<gm_.numberOfFactors(); ++i) {
240 if((gm_[i].numberOfVariables()==2) &&
241 (gm_[i].variableIndex(0)==parameter_.nodeOrder_.back() || gm_[i].variableIndex(1)==parameter_.nodeOrder_.back())
243 parameter_.addTreeFactorId(i);
246 for(
size_t i=0; i<parameter_.treeFactorIds_.size(); ++i)
247 OPENGM_ASSERT(gm_.numberOfFactors() > parameter_.treeFactorIds_[i]);
249 optimizedFactor_.resize(gm_.numberOfFactors());
250 for(
size_t i=0; i<gm_.numberOfFactors(); ++i) {
251 if(gm_[i].numberOfVariables()<=1)
continue;
252 std::vector<size_t> index(gm_[i].numberOfVariables());
253 gm_[i].variableIndices(index.begin());
254 optimizedFactor_[i].assign(gm_ ,index.end()-1, index.end());
255 opengm::accumulate<ACC>(gm[i],index.begin()+1,index.end(),optimizedFactor_[i]);
258 OPENGM_ASSERT(optimizedFactor_[i].variableIndex(0) == index[0]);
261 AStarNode<IndependentFactorType> a;
265 make_heap(array_.begin(), array_.end(), comp1);
267 if(parameter_.heuristic_ == parameter_.FASTHEURISTIC) {
268 for(
size_t i=0; i<parameter_.treeFactorIds_.size(); ++i) {
269 if(gm_[parameter_.treeFactorIds_[i]].numberOfVariables() > 2) {
270 throw RuntimeError(
"The heuristic includes factor of order > 2.");
276 for(
size_t i=0; i<gm_.numberOfFactors(); ++i)
277 isTreeFactor_[i] =
false;
278 for(
size_t i=0; i<parameter_.treeFactorIds_.size(); ++i) {
279 int factorId = parameter_.treeFactorIds_[i];
280 isTreeFactor_[factorId] =
true;
281 treeFactor_.push_back(gm_[factorId]);
291 template<
class GM,
class ACC >
298 template <
class GM,
class ACC>
308 template<
class GM,
class ACC>
309 template<
class VisitorType>
314 visitor.begin(*
this);
315 while(array_.size()>0 && exitFlag==0) {
316 if(parameter_.numberOfOpt_ == optConf_.size()) {
320 while(array_.front().conf.size() < numNodes_ && exitFlag==0) {
322 belowBound_ = array_.front().value;
323 exitFlag = visitor(*
this);
326 if(array_.front().conf.size()>=numNodes_){
329 std::vector<LabelType> conf(numNodes_);
330 for(
size_t n=0; n<numNodes_; ++n) {
331 conf[parameter_.nodeOrder_[n]] = array_.front().conf[n];
333 optConf_.push_back(conf);
335 if(ACC::bop(parameter_.objectiveBound_, value)) {
340 pop_heap(array_.begin(), array_.end(), comp1);
347 template<
class GM,
class ACC>
350 if(optConf_.size()>=1){
351 return gm_.evaluate(optConf_[0]);
354 return ACC::template neutral<ValueType>();
358 template<
class GM,
class ACC>
360 ::arg(ConfVec& conf,
const size_t n)
const 362 if(n>optConf_.size()) {
363 conf.resize(gm_.numberOfVariables(),0);
375 template<
class GM,
class ACC>
383 template<
class GM,
class ACC>
384 template<
class VisitorType>
388 if(array_.size()>parameter_.maxHeapSize_*0.99) {
389 partial_sort(array_.begin(), array_.begin()+(int)(parameter_.maxHeapSize_/2), array_.end(), comp2);
390 array_.resize((
int)(parameter_.maxHeapSize_/2));
393 AStarNode<IndependentFactorType> a = array_.front();
394 size_t subconfsize = a.conf.size();
397 pop_heap(array_.begin(), array_.end(), comp1);
399 if( parameter_.heuristic_ == parameter_.STANDARDHEURISTIC) {
408 std::vector<IndexType> varMap(gm_.numberOfVariables(),0);
409 std::vector<LabelType> fixVariableLabel(gm_.numberOfVariables(),0);
410 std::vector<bool> fixVariable(gm_.numberOfVariables(),
false);
411 for(
size_t i =0; i<subconfsize ; ++i) {
412 fixVariableLabel[parameter_.nodeOrder_[i]] = a.conf[i];
413 fixVariable[parameter_.nodeOrder_[i]] =
true;
416 for(
IndexType var=0; var<gm_.numberOfVariables();++var){
417 if(fixVariable[var]==
false){
418 varMap[var] = numberOfVariables++;
421 std::vector<LabelType> shape(numberOfVariables,0);
422 for(
IndexType var=0; var<gm_.numberOfVariables();++var){
423 if(fixVariable[var]==
false){
424 shape[varMap[var]] = gm_.numberOfLabels(var);
427 MSpaceType space(shape.begin(),shape.end());
430 std::vector<PositionAndLabel<IndexType,LabelType> > fixedVars;
431 std::vector<IndexType> MVars;
433 GM::OperatorType::neutral(constant);
435 for(
IndexType f=0; f<gm_.numberOfFactors();++f){
438 for(
IndexType i=0; i<gm_[f].numberOfVariables(); ++i){
439 const IndexType var = gm_[f].variableIndex(i);
440 if(fixVariable[var]){
441 fixedVars.push_back(PositionAndLabel<IndexType,LabelType>(i,fixVariableLabel[var]));
443 MVars.push_back(varMap[var]);
446 if(fixedVars.size()==gm_[f].numberOfVariables()){
447 std::vector<LabelType> fixedStates(gm_[f].numberOfVariables(),0);
448 for(
IndexType i=0; i<gm_[f].numberOfVariables(); ++i){
449 fixedStates[i]=fixVariableLabel[ gm_[f].variableIndex(i)];
451 GM::OperatorType::op(gm_[f](fixedStates.begin()),constant);
453 if(MVars.size()<2 || isTreeFactor_[f]){
454 const ViewFixVariablesFunction<GM> func(gm_[f], fixedVars);
455 mgm.addFactor(mgm.addFunction(func),MVars.begin(), MVars.end());
457 std::vector<IndexType> variablesIndices(optimizedFactor_[f].numberOfVariables());
458 for(
size_t i=0; i<variablesIndices.size(); ++i)
459 variablesIndices[i] = varMap[optimizedFactor_[f].variableIndex(i)];
460 LabelType numberOfLabels = optimizedFactor_[f].numberOfLabels(0);
462 for(
LabelType i=0; i<numberOfLabels; ++i)
463 func(i) = optimizedFactor_[f](i);
464 mgm.addFactor(mgm.addFunction(func),variablesIndices.begin(),variablesIndices.end() );
465 OPENGM_ASSERT(mgm[mgm.numberOfFactors()-1].numberOfVariables()==1);
472 mgm.addFactor(mgm.addFunction(func),MVars.begin(), MVars.begin());
486 ACC::op(bp.value(),aboveBound_,aboveBound_);
487 std::vector<LabelType> conf(mgm.numberOfVariables());
489 std::vector<IndexType> theVar(1, varMap[parameter_.nodeOrder_[subconfsize]]);
491 std::vector<LabelType> theLabel(1,0);
492 a.conf.resize(subconfsize+1);
493 for(
size_t i=0; i<numStates_[parameter_.nodeOrder_[subconfsize]]; ++i) {
494 a.conf[subconfsize] = i;
496 bp.constrainedOptimum(theVar,theLabel,conf);
497 a.value = mgm.evaluate(conf);
499 push_heap(array_.begin(), array_.end(), comp1);
502 if( parameter_.heuristic_ == parameter_.FASTHEURISTIC) {
503 std::vector<LabelType> conf(subconfsize);
504 for(
size_t i=0;i<subconfsize;++i)
506 std::vector<ValueType> bound = fastHeuristic(conf);
507 a.conf.resize(subconfsize+1);
508 for(
size_t i=0; i<numStates_[parameter_.nodeOrder_[subconfsize]]; ++i) {
509 a.conf[subconfsize] = i;
513 push_heap(array_.begin(), array_.end(), comp1);
519 template<
class GM,
class ACC>
520 std::vector<typename AStar<GM, ACC>::ValueType>
523 std::list<size_t> factorList;
524 std::vector<size_t> nodeDegree(numNodes_,0);
525 std::vector<int> nodeLabel(numNodes_,-1);
526 std::vector<std::vector<ValueType > > nodeEnergy(numNodes_);
527 size_t nextNode = parameter_.nodeOrder_[conf.size()];
528 for(
size_t i=0; i<numNodes_; ++i) {
529 nodeEnergy[i].resize(numStates_[i]);
530 for(
size_t j=0;j<numStates_[i];++j)
531 OperatorType::neutral(nodeEnergy[i][j]);
533 for(
size_t i=0;i<conf.size();++i) {
534 nodeLabel[parameter_.nodeOrder_[i]] = conf[i];
540 for(
size_t i=0; i<gm_.numberOfFactors(); ++i) {
542 size_t nvar = f.numberOfVariables();
545 int index = f.variableIndex(0);
546 if(nodeLabel[index]>=0) {
547 nodeEnergy[index].resize(1);
550 OperatorType::op(f(coordinates), nodeEnergy[index][0]);
553 OPENGM_ASSERT(numStates_[index] == nodeEnergy[index].size());
554 for(
size_t j=0;j<numStates_[index];++j) {
557 OperatorType::op(f(coordinates),nodeEnergy[index][j]);
562 size_t index1 = f.variableIndex(0);
563 size_t index2 = f.variableIndex(1);
564 if(nodeLabel[index1]>=0) {
565 if(nodeLabel[index2]>=0) {
566 nodeEnergy[index1].resize(1);
569 static_cast<LabelType>(nodeLabel[index1]),
570 static_cast<LabelType>(nodeLabel[index2])
572 OperatorType::op(f(coordinates),nodeEnergy[index1][0]);
575 OPENGM_ASSERT(numStates_[index2] == nodeEnergy[index2].size());
576 for(
size_t j=0;j<numStates_[index2];++j) {
579 static_cast<LabelType>(nodeLabel[index1]),
580 static_cast<LabelType>(j)
582 OperatorType::op(f(coordinates), nodeEnergy[index2][j]);
586 else if(nodeLabel[index2]>=0) {
587 OPENGM_ASSERT(numStates_[index1] == nodeEnergy[index1].size());
588 for(
size_t j=0;j<numStates_[index1];++j) {
592 static_cast<LabelType>(nodeLabel[index2])
594 OperatorType::op(f(coordinates),nodeEnergy[index1][j]);
597 else if(isTreeFactor_[i]) {
598 factorList.push_front(i);
599 ++nodeDegree[index1];
600 ++nodeDegree[index2];
604 for(
size_t j=0;j<numStates_[index1];++j) {
607 OperatorType::op(optimizedFactor_[i](coordinates), nodeEnergy[index1][j]);
613 std::vector<size_t> state(nvar);
614 for(
size_t j=0; j<nvar; ++j) {
615 if(nodeLabel[f.variableIndex(j)]<0) {
616 state[j] = nodeLabel[f.variableIndex(j)];
621 nodeEnergy[f.variableIndex(0)][0] = f(state.begin());
623 for(
size_t j=0;j<numStates_[f.variableIndex(0)];++j) {
626 OperatorType::op(optimizedFactor_[i](coordinates), nodeEnergy[f.variableIndex(0)][j]);
631 nodeDegree[nextNode] += numNodes_;
633 while(factorList.size()>0) {
634 size_t id = factorList.front();
635 factorList.pop_front();
637 size_t index1 = f.variableIndex(0);
638 size_t index2 = f.variableIndex(1);
639 typename FactorType::ValueType temp;
642 OPENGM_ASSERT(gm_.numberOfLabels(index1) == numStates_[index1]);
643 OPENGM_ASSERT(gm_.numberOfLabels(index2) == numStates_[index2]);
644 if(nodeDegree[index1]==1) {
645 typename FactorType::ValueType min;
646 OPENGM_ASSERT(numStates_[index2] == nodeEnergy[index2].size());
647 for(
size_t j2=0;j2<numStates_[index2];++j2) {
649 OPENGM_ASSERT(numStates_[index1] == nodeEnergy[index1].size());
650 for(
size_t j1=0;j1<numStates_[index1];++j1) {
652 OperatorType::op(f(coordinates),nodeEnergy[index1][j1],temp);
653 ACC::op(min,temp,min);
656 OperatorType::op(min,nodeEnergy[index2][j2]);
658 --nodeDegree[index1];
659 --nodeDegree[index2];
660 nodeEnergy[index1].resize(1);
661 OperatorType::neutral(nodeEnergy[index1][0]);
663 else if(nodeDegree[index2]==1) {
664 typename FactorType::ValueType min;
665 OPENGM_ASSERT(numStates_[index1] == nodeEnergy[index1].size());
666 for(
size_t j1=0;j1<numStates_[index1];++j1) {
668 OPENGM_ASSERT(numStates_[index2] == nodeEnergy[index2].size());
669 for(
size_t j2=0;j2<numStates_[index2];++j2) {
671 OperatorType::op(f(coordinates),nodeEnergy[index2][j2],temp);
672 ACC::op(min,temp,min);
676 OperatorType::op(min,nodeEnergy[index1][j1]);
678 --nodeDegree[index1];
679 --nodeDegree[index2];
680 nodeEnergy[index2].resize(1);
681 OperatorType::neutral(nodeEnergy[index2][0]);
684 factorList.push_back(
id);
690 OperatorType::neutral(result);
691 std::vector<ValueType > bound;
692 for(
size_t i=0;i<numNodes_;++i) {
693 if(i==nextNode)
continue;
695 for(
size_t j=0; j<nodeEnergy[i].size();++j)
696 ACC::op(min,nodeEnergy[i][j],min);
698 OperatorType::op(min,result);
700 bound.resize(nodeEnergy[nextNode].size());
701 for(
size_t j=0; j<nodeEnergy[nextNode].size();++j) {
703 OperatorType::op(nodeEnergy[nextNode][j],result,bound[j]);
708 template<
class GM,
class ACC>
717 #endif // #ifndef OPENGM_ASTAR_HXX
Update rules for the MessagePassing framework.
virtual InferenceTermination arg(std::vector< LabelType > &v, const size_t=1) const
output a solution
opengm::visitors::EmptyVisitor< AStar< GM, ACC > > EmptyVisitorType
opengm::visitors::VerboseVisitor< AStar< GM, ACC > > VerboseVisitorType
visitor
Discrete space in which variables can have differently many labels.
opengm::visitors::TimingVisitor< AStar< GM, ACC > > TimingVisitorType
std::vector< LabelType > ConfVec
configuration vector type
void infer(const typename INF::GraphicalModelType &gm, const typename INF::Parameter &param, std::vector< typename INF::LabelType > &conf)
virtual std::string name() const
opengm::Tribool isAcyclic_
A framework for message passing algorithms. Cf. F. R. Kschischang, B. J. Frey and H...
AStar(const GM &gm, Parameter para=Parameter())
constructor
size_t numberOfOpt_
number of N-best solutions that should be found
ValueType bound() const
return a bound on the solution
#define OPENGM_ASSERT(expression)
ValueType objectiveBound_
objective bound
ValueType value() const
return the solution (value)
void addTreeFactorId(size_t id)
add tree factor id
ACC AccumulationType
accumulation type
virtual InferenceTermination args(std::vector< std::vector< LabelType > > &v) const
args
virtual InferenceTermination factorMarginal(const size_t, IndependentFactorType &out) const
output a solution for a marginal for all variables connected to a factor
GraphicalModelType::IndexType IndexType
reference to a Factor of a GraphicalModel
GraphicalModelType::FactorType FactorType
std::vector< size_t > treeFactorIds_
size_t heuristic_
heuristic
GraphicalModelType::ValueType ValueType
ConfVec::iterator ConfVecIt
configuration iterator
GM GraphicalModelType
graphical model type
Inference algorithm interface.
size_t maxHeapSize_
maxHeapSize_ maximum size of the heap
virtual InferenceTermination marginal(const size_t, IndependentFactorType &out) const
output a solution for a marginal for a specific variable
const GraphicalModelType & graphicalModel() const
virtual void reset()
reset
virtual InferenceTermination infer()
Function that refers to a factor of another GraphicalModel in which some variables are fixed...
std::vector< IndexType > nodeOrder_
GraphicalModelType::LabelType LabelType
size_t maximumNumberOfSteps_
GraphicalModelType::IndependentFactorType IndependentFactorType