Cambridge SMT System
fast-shortest-distance.h
Go to the documentation of this file.
1 // shortest-distance.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Author: allauzen@google.com (Cyril Allauzen)
16 //
17 // \file
18 // Functions and classes to find shortest distance in an FST.
19 //
20 // This file has been modified by Aurelien Waite from the University of
21 // Cambridge.
22 //
23 // The original algorithm performed the same plus operation twice for
24 // identical weights. The plus operation is expensive for tropical monomial
25 // weights, so its result is now stored in a temporary variable for speed.
26 
27 #ifndef MERT_FST_LIB_SHORTEST_DISTANCE_H__
28 #define MERT_FST_LIB_SHORTEST_DISTANCE_H__
29 
30 #include <deque>
31 #include <vector>
32 
33 #include <fst/arcfilter.h>
34 #include <fst/cache.h>
35 #include <fst/queue.h>
36 #include <fst/reverse.h>
37 #include <fst/test-properties.h>
38 
39 namespace mertfst {
40 
41 template <class Arc, class Queue, class ArcFilter>
43  typedef typename Arc::StateId StateId;
44 
45  Queue *state_queue; // Queue discipline used; owned by caller
46  ArcFilter arc_filter; // Arc filter (e.g., limit to only epsilon graph)
47  StateId source; // If kNoStateId, use the Fst's initial state
48  float delta; // Determines the degree of convergence required
49  bool first_path; // For a semiring with the path property (o.w.
50  // undefined), compute the shortest-distances along
51  // along the first path to a final state found
52  // by the algorithm. That path is the shortest-path
53  // only if the FST has a unique final state (or all
54  // the final states have the same final weight), the
55  // queue discipline is shortest-first and all the
56  // weights in the FST are between One() and Zero()
57  // according to NaturalLess.
58 
59  ShortestDistanceOptions (Queue *q, ArcFilter filt,
60  StateId src = fst::kNoStateId,
61  float d = fst::kDelta)
62  : state_queue (q), arc_filter (filt), source (src), delta (d),
63  first_path (false) {}
64 };
65 
66 // Computation state of the shortest-distance algorithm. Reusable
67 // information is maintained across calls to member function
68 // ShortestDistance(source) when 'retain' is true for improved
69 // efficiency when calling multiple times from different source states
70 // (e.g., in epsilon removal). Contrary to usual conventions, 'fst'
71 // may not be freed before this class. Vector 'distance' should not be
72 // modified by the user between these calls.
73 template<class Arc, class Queue, class ArcFilter>
75  public:
76  typedef typename Arc::StateId StateId;
77  typedef typename Arc::Weight Weight;
78 
80  const fst::Fst<Arc>& fst,
81  std::vector<Weight> *distance,
83  bool retain)
84  : fst_ (fst), distance_ (distance), state_queue_ (opts.state_queue),
85  arc_filter_ (opts.arc_filter),
86  delta_ (opts.delta), first_path_ (opts.first_path), retain_ (retain) {
87  distance_->clear();
88  }
89 
91 
92  void ShortestDistance (StateId source);
93 
94  private:
95  const fst::Fst<Arc>& fst_;
96  std::vector<Weight> *distance_;
97  Queue *state_queue_;
98  ArcFilter arc_filter_;
99  float delta_;
100  bool first_path_;
101  bool retain_; // Retain and reuse information across calls
102 
103  std::vector<Weight> rdistance_; // Relaxation distance.
104  std::vector<bool> enqueued_; // Is state enqueued?
105  std::vector<StateId> sources_; // Source state for ith state in 'distance_',
106  // 'rdistance_', and 'enqueued_' if retained.
107 };
108 
109 // Compute the shortest distance. If 'source' is kNoStateId, use
110 // the initial state of the Fst.
111 template <class Arc, class Queue, class ArcFilter>
113  StateId source) {
114  if (fst_.Start() == fst::kNoStateId)
115  return;
116  if (! (Weight::Properties() & fst::kRightSemiring) )
117  LOG (FATAL) << "ShortestDistance: Weight needs to be right distributive: "
118  << Weight::Type();
119  if (first_path_ && ! (Weight::Properties() & fst::kPath) )
120  LOG (FATAL) << "ShortestDistance: first_path option disallowed when "
121  << "Weight does not have the path property: "
122  << Weight::Type();
123  state_queue_->Clear();
124  if (!retain_) {
125  distance_->clear();
126  rdistance_.clear();
127  enqueued_.clear();
128  }
129  if (source == fst::kNoStateId)
130  source = fst_.Start();
131  while (distance_->size() <= source) {
132  distance_->push_back (Weight::Zero() );
133  rdistance_.push_back (Weight::Zero() );
134  enqueued_.push_back (false);
135  }
136  if (retain_) {
137  while (sources_.size() <= source)
138  sources_.push_back (fst::kNoStateId);
139  sources_[source] = source;
140  }
141  (*distance_) [source] = Weight::One();
142  rdistance_[source] = Weight::One();
143  enqueued_[source] = true;
144  state_queue_->Enqueue (source);
145  while (!state_queue_->Empty() ) {
146  StateId s = state_queue_->Head();
147  state_queue_->Dequeue();
148  while (distance_->size() <= s) {
149  distance_->push_back (Weight::Zero() );
150  rdistance_.push_back (Weight::Zero() );
151  enqueued_.push_back (false);
152  }
153  if (first_path_ && (fst_.Final (s) != Weight::Zero() ) )
154  break;
155  enqueued_[s] = false;
156  Weight r = rdistance_[s];
157  rdistance_[s] = Weight::Zero();
158  for (fst::ArcIterator< fst::Fst<Arc> > aiter (fst_, s);
159  !aiter.Done();
160  aiter.Next() ) {
161  const Arc& arc = aiter.Value();
162  if (!arc_filter_ (arc) || arc.weight == Weight::Zero() )
163  continue;
164  while (distance_->size() <= arc.nextstate) {
165  distance_->push_back (Weight::Zero() );
166  rdistance_.push_back (Weight::Zero() );
167  enqueued_.push_back (false);
168  }
169  if (retain_) {
170  while (sources_.size() <= arc.nextstate)
171  sources_.push_back (fst::kNoStateId);
172  if (sources_[arc.nextstate] != source) {
173  (*distance_) [arc.nextstate] = Weight::Zero();
174  rdistance_[arc.nextstate] = Weight::Zero();
175  enqueued_[arc.nextstate] = false;
176  sources_[arc.nextstate] = source;
177  }
178  }
179  Weight& nd = (*distance_) [arc.nextstate];
180  Weight& nr = rdistance_[arc.nextstate];
181  Weight w = Times (r, arc.weight);
182  const Weight& nd_temp = Plus (nd, w);
183  if (!ApproxEqual (nd, nd_temp, delta_) ) {
184  nd = nd_temp;
185  nr = Plus (nr, w);
186  if (!enqueued_[arc.nextstate]) {
187  state_queue_->Enqueue (arc.nextstate);
188  enqueued_[arc.nextstate] = true;
189  } else {
190  state_queue_->Update (arc.nextstate);
191  }
192  }
193  }
194  }
195 }
196 
197 // Shortest-distance algorithm: this version allows fine control
198 // via the options argument. See below for a simpler interface.
199 //
200 // This computes the shortest distance from the 'opts.source' state to
201 // each visited state S and stores the value in the 'distance' vector.
202 // An unvisited state S has distance Zero(), which will be stored in
203 // the 'distance' vector if S is less than the maximum visited state.
204 // The state queue discipline, arc filter, and convergence delta are
205 // taken in the options argument.
206 
207 // The weights must must be right distributive and k-closed (i.e., 1 +
208 // x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
209 //
210 // The algorithm is from Mohri, "Semiring Framweork and Algorithms for
211 // Shortest-Distance Problems", Journal of Automata, Languages and
212 // Combinatorics 7(3):321-350, 2002. The complexity of algorithm
213 // depends on the properties of the semiring and the queue discipline
214 // used. Refer to the paper for more details.
215 template<class Arc, class Queue, class ArcFilter>
217  const fst::Fst<Arc>& fst,
218  std::vector<typename Arc::Weight> *distance,
221  sd_state (fst, distance, opts, false);
222  sd_state.ShortestDistance (opts.source);
223 }
224 
225 // Shortest-distance algorithm: simplified interface. See above for a
226 // version that allows finer control.
227 //
228 // If 'reverse' is false, this computes the shortest distance from the
229 // initial state to each state S and stores the value in the
230 // 'distance' vector. If 'reverse' is true, this computes the shortest
231 // distance from each state to the final states. An unvisited state S
232 // has distance Zero(), which will be stored in the 'distance' vector
233 // if S is less than the maximum visited state. The state queue
234 // discipline is automatically-selected.
235 //
236 // The weights must must be right (left) distributive if reverse is
237 // false (true) and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 +
238 // x + x^2 + ... + x^k).
239 //
240 // The algorithm is from Mohri, "Semiring Framweork and Algorithms for
241 // Shortest-Distance Problems", Journal of Automata, Languages and
242 // Combinatorics 7(3):321-350, 2002. The complexity of algorithm
243 // depends on the properties of the semiring and the queue discipline
244 // used. Refer to the paper for more details.
245 template<class Arc>
246 void ShortestDistance (const fst::Fst<Arc>& fst,
247  std::vector<typename Arc::Weight> *distance,
248  bool reverse = false,
249  float delta = fst::kDelta) {
250  typedef typename Arc::StateId StateId;
251  typedef typename Arc::Weight Weight;
252  if (!reverse) {
253  fst::AnyArcFilter<Arc> arc_filter;
254  fst::AutoQueue<StateId> state_queue (fst, distance, arc_filter);
256  opts (&state_queue, arc_filter);
257  opts.delta = delta;
258  ShortestDistance (fst, distance, opts);
259  } else {
260  typedef fst::ReverseArc<Arc> ReverseArc;
261  typedef typename ReverseArc::Weight ReverseWeight;
262  fst::AnyArcFilter<ReverseArc> rarc_filter;
263  fst::VectorFst<ReverseArc> rfst;
264  Reverse (fst, &rfst);
265  std::vector<ReverseWeight> rdistance;
266  fst::AutoQueue<StateId> state_queue (rfst, &rdistance, rarc_filter);
268  fst::AnyArcFilter<ReverseArc> >
269  ropts (&state_queue, rarc_filter);
270  ropts.delta = delta;
271  ShortestDistance (rfst, &rdistance, ropts);
272  distance->clear();
273  while (distance->size() < rdistance.size() - 1)
274  distance->push_back (rdistance[distance->size() + 1].Reverse() );
275  }
276 }
277 
278 // Return the sum of the weight of all successful paths in an FST, i.e.,
279 // the shortest-distance from the initial state to the final states.
280 template <class Arc>
281 typename Arc::Weight ShortestDistance (const fst::Fst<Arc>& fst) {
282  typedef typename Arc::Weight Weight;
283  typedef typename Arc::StateId StateId;
284  std::vector<Weight> distance;
285  if (Weight::Properties() & fst::kRightSemiring) {
286  mertfst::ShortestDistance (fst, &distance, false);
287  Weight sum = Weight::Zero();
288  for (StateId s = 0; s < distance.size(); ++s)
289  sum = Plus (sum, Times (distance[s], fst.Final (s) ) );
290  return sum;
291  } else {
292  mertfst::ShortestDistance (fst, &distance, true);
293  StateId s = fst.Start();
294  return s != fst::kNoStateId && s < distance.size() ?
295  distance[s] : Weight::Zero();
296  }
297 }
298 
299 } // namespace fst
300 
301 #endif // MERT_FST_LIB_SHORTEST_DISTANCE_H__
MertOpt opts
Definition: MertCommon.cpp:14
ShortestDistanceOptions(Queue *q, ArcFilter filt, StateId src=fst::kNoStateId, float d=fst::kDelta)
bool ApproxEqual(const TropicalSparseTupleWeight< T > &vw1, const TropicalSparseTupleWeight< T > &vw2, float delta=kDelta)
Definition: fstio.hpp:27
TropicalSparseTupleWeight< T > Plus(const TropicalSparseTupleWeight< T > &vw1, const TropicalSparseTupleWeight< T > &vw2)
ShortestDistanceState(const fst::Fst< Arc > &fst, std::vector< Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts, bool retain)
TropicalSparseTupleWeight< T > Times(const TropicalSparseTupleWeight< T > &w1, const TropicalSparseTupleWeight< T > &w2)
void ShortestDistance(const fst::Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, const ShortestDistanceOptions< Arc, Queue, ArcFilter > &opts)