27 #ifndef MERT_FST_LIB_SHORTEST_DISTANCE_H__ 28 #define MERT_FST_LIB_SHORTEST_DISTANCE_H__ 33 #include <fst/arcfilter.h> 34 #include <fst/cache.h> 35 #include <fst/queue.h> 36 #include <fst/reverse.h> 37 #include <fst/test-properties.h> 41 template <
class Arc,
class Queue,
class ArcFilter>
60 StateId src = fst::kNoStateId,
61 float d = fst::kDelta)
62 : state_queue (q), arc_filter (filt), source (src), delta (d),
73 template<
class Arc,
class Queue,
class ArcFilter>
80 const fst::Fst<Arc>&
fst,
81 std::vector<Weight> *distance,
84 : fst_ (fst), distance_ (distance), state_queue_ (opts.
state_queue),
95 const fst::Fst<Arc>& fst_;
96 std::vector<Weight> *distance_;
98 ArcFilter arc_filter_;
103 std::vector<Weight> rdistance_;
104 std::vector<bool> enqueued_;
105 std::vector<StateId> sources_;
111 template <
class Arc,
class Queue,
class ArcFilter>
114 if (fst_.Start() == fst::kNoStateId)
116 if (! (Weight::Properties() & fst::kRightSemiring) )
117 LOG (FATAL) <<
"ShortestDistance: Weight needs to be right distributive: " 119 if (first_path_ && ! (Weight::Properties() & fst::kPath) )
120 LOG (FATAL) <<
"ShortestDistance: first_path option disallowed when " 121 <<
"Weight does not have the path property: " 123 state_queue_->Clear();
129 if (source == fst::kNoStateId)
130 source = fst_.Start();
131 while (distance_->size() <=
source) {
132 distance_->push_back (Weight::Zero() );
133 rdistance_.push_back (Weight::Zero() );
134 enqueued_.push_back (
false);
137 while (sources_.size() <=
source)
138 sources_.push_back (fst::kNoStateId);
141 (*distance_) [
source] = Weight::One();
142 rdistance_[
source] = Weight::One();
144 state_queue_->Enqueue (source);
145 while (!state_queue_->Empty() ) {
146 StateId s = state_queue_->Head();
147 state_queue_->Dequeue();
148 while (distance_->size() <= s) {
149 distance_->push_back (Weight::Zero() );
150 rdistance_.push_back (Weight::Zero() );
151 enqueued_.push_back (
false);
153 if (first_path_ && (fst_.Final (s) != Weight::Zero() ) )
155 enqueued_[s] =
false;
157 rdistance_[s] = Weight::Zero();
158 for (fst::ArcIterator< fst::Fst<Arc> > aiter (fst_, s);
161 const Arc& arc = aiter.Value();
162 if (!arc_filter_ (arc) || arc.weight == Weight::Zero() )
164 while (distance_->size() <= arc.nextstate) {
165 distance_->push_back (Weight::Zero() );
166 rdistance_.push_back (Weight::Zero() );
167 enqueued_.push_back (
false);
170 while (sources_.size() <= arc.nextstate)
171 sources_.push_back (fst::kNoStateId);
172 if (sources_[arc.nextstate] != source) {
173 (*distance_) [arc.nextstate] = Weight::Zero();
174 rdistance_[arc.nextstate] = Weight::Zero();
175 enqueued_[arc.nextstate] =
false;
176 sources_[arc.nextstate] =
source;
179 Weight& nd = (*distance_) [arc.nextstate];
180 Weight& nr = rdistance_[arc.nextstate];
186 if (!enqueued_[arc.nextstate]) {
187 state_queue_->Enqueue (arc.nextstate);
188 enqueued_[arc.nextstate] =
true;
190 state_queue_->Update (arc.nextstate);
215 template<
class Arc,
class Queue,
class ArcFilter>
217 const fst::Fst<Arc>&
fst,
218 std::vector<typename Arc::Weight> *distance,
221 sd_state (fst, distance, opts,
false);
247 std::vector<typename Arc::Weight> *distance,
248 bool reverse =
false,
249 float delta = fst::kDelta) {
250 typedef typename Arc::StateId
StateId;
251 typedef typename Arc::Weight Weight;
254 fst::AutoQueue<StateId>
state_queue (fst, distance, arc_filter);
256 opts (&state_queue, arc_filter);
260 typedef fst::ReverseArc<Arc> ReverseArc;
261 typedef typename ReverseArc::Weight ReverseWeight;
262 fst::AnyArcFilter<ReverseArc> rarc_filter;
263 fst::VectorFst<ReverseArc> rfst;
264 Reverse (fst, &rfst);
265 std::vector<ReverseWeight> rdistance;
266 fst::AutoQueue<StateId>
state_queue (rfst, &rdistance, rarc_filter);
268 fst::AnyArcFilter<ReverseArc> >
269 ropts (&state_queue, rarc_filter);
273 while (distance->size() < rdistance.size() - 1)
274 distance->push_back (rdistance[distance->size() + 1].Reverse() );
282 typedef typename Arc::Weight Weight;
283 typedef typename Arc::StateId
StateId;
284 std::vector<Weight> distance;
285 if (Weight::Properties() & fst::kRightSemiring) {
287 Weight sum = Weight::Zero();
288 for (StateId s = 0; s < distance.size(); ++s)
289 sum =
Plus (sum,
Times (distance[s], fst.Final (s) ) );
293 StateId s = fst.Start();
294 return s != fst::kNoStateId && s < distance.size() ?
295 distance[s] : Weight::Zero();
301 #endif // MERT_FST_LIB_SHORTEST_DISTANCE_H__
ShortestDistanceOptions(Queue *q, ArcFilter filt, StateId src=fst::kNoStateId, float d=fst::kDelta)
bool ApproxEqual(const TropicalSparseTupleWeight< T > &vw1, const TropicalSparseTupleWeight< T > &vw2, float delta=kDelta)
void ShortestDistance(StateId source)
TropicalSparseTupleWeight< T > Plus(const TropicalSparseTupleWeight< T > &vw1, const TropicalSparseTupleWeight< T > &vw2)
ShortestDistanceState(const fst::Fst< Arc > &fst, std::vector< Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts, bool retain)
TropicalSparseTupleWeight< T > Times(const TropicalSparseTupleWeight< T > &w1, const TropicalSparseTupleWeight< T > &w2)
void ShortestDistance(const fst::Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)