Cambridge SMT System
fstutils.hpp
Go to the documentation of this file.
1 // Licensed under the Apache License, Version 2.0 (the "License");
2 // you may not use these files except in compliance with the License.
3 // You may obtain a copy of the License at
4 //
5 // http://www.apache.org/licenses/LICENSE-2.0
6 //
7 // Unless required by applicable law or agreed to in writing, software
8 // distributed under the License is distributed on an "AS IS" BASIS,
9 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 // See the License for the specific language governing permissions and
11 // limitations under the License.
12 
13 // Copyright 2012 - Gonzalo Iglesias, Adrià de Gispert, William Byrne
14 
21 #ifndef FSTUTILS_HPP
22 #define FSTUTILS_HPP
23 
24 namespace fst {
25 
27 inline float ZPosInfinity() {
28 #if OPENFSTVERSION>=1003002 //1.3.2
29  return fst::FloatLimits<float>::PosInfinity();
30 #else
31  return fst::FloatLimits<float>::kPosInfinity; //Up to OpenFST 1.3.1
32 #endif
33 }
34 
41 template<class Arc >
42 inline void extractSourceVocabulary ( const fst::VectorFst<Arc>& myfst,
43  unordered_set<std::string> *vcb ) {
44  USER_CHECK ( vcb, "NULL pointer not accepted" );
45  using fst::StateIterator;
46  using fst::VectorFst;
47  using fst::ArcIterator;
48  typedef typename Arc::StateId StateId;
49  for ( StateIterator< VectorFst<Arc> > si ( myfst ); !si.Done(); si.Next() ) {
50  StateId state_id = si.Value();
51  for ( ArcIterator< VectorFst<Arc> > ai ( myfst, si.Value() ); !ai.Done();
52  ai.Next() ) {
53  Arc arc = ai.Value();
54  vcb->insert ( toString ( arc.ilabel ) );
55  }
56  }
57 };
58 
66 template<class Arc >
67 inline void extractSourceVocabulary ( const fst::VectorFst<Arc>& myfst,
68  unordered_set<unsigned> *vcb,
69  unsigned offset = 0) {
70  USER_CHECK ( vcb, "NULL pointer not accepted" );
71  using fst::StateIterator;
72  using fst::VectorFst;
73  using fst::ArcIterator;
74  typedef typename Arc::StateId StateId;
75  for ( StateIterator< VectorFst<Arc> > si ( myfst ); !si.Done(); si.Next() ) {
76  StateId state_id = si.Value();
77  for ( ArcIterator< VectorFst<Arc> > ai ( myfst, si.Value() ); !ai.Done();
78  ai.Next() ) {
79  vcb->insert ( ai.Value().ilabel + offset );
80  }
81  }
82 };
83 
90 template<class Arc>
91 void extractTargetVocabulary ( const fst::VectorFst<Arc>& myfst,
92  unordered_set<std::string> *vcb ) {
93  USER_CHECK ( vcb, "NULL pointer not accepted" );
94  typedef typename Arc::StateId StateId;
95  using fst::StateIterator;
96  using fst::VectorFst;
97  using fst::ArcIterator;
98  for ( StateIterator< VectorFst<Arc> > si ( myfst ); !si.Done(); si.Next() ) {
99  StateId state_id = si.Value();
100  for ( ArcIterator< VectorFst<Arc> > ai ( myfst, si.Value() ); !ai.Done();
101  ai.Next() ) {
102  Arc arc = ai.Value();
103  vcb->insert ( ucam::util::toString ( arc.olabel ) );
104  }
105  }
106 };
107 
117 template<class Arc>
118 void buildSubstringTransducer ( fst::VectorFst<Arc> *myfst ) {
119  USER_CHECK ( myfst, "NULL pointer not accepted" );
120  USER_CHECK ( myfst->NumStates(), "Number of states is zero!" );
121  typedef typename Arc::StateId StateId;
122  fst::TopSort ( myfst );
123  fst::Map ( myfst, fst::RmWeightMapper<Arc>() );
124  for ( fst::StateIterator< fst::VectorFst<Arc> > si ( *myfst ); !si.Done();
125  si.Next() ) {
126  StateId state_id = si.Value();
127  if ( state_id ) {
128  myfst->AddArc ( 0, Arc ( 0, 0, Arc::Weight::One(), si.Value() ) );
129  myfst->SetFinal ( state_id, Arc::Weight::One() );
130  }
131  }
132  fst::RmEpsilon ( myfst );
133 };
134 
148 template <class Arc>
149 inline fst::VectorFst<Arc> *EncodeDeterminizeMinimizeDecode (
150  fst::VectorFst<Arc> *myfst ) {
151  fst::EncodeMapper<Arc> em ( fst::kEncodeLabels, fst::ENCODE );
152  fst::Encode ( myfst, &em ); //note that this modifies fst1
153  LDBG_EXECUTE ( myfst->Write ( "fsts/encoded.fst" ) );
154  fst::VectorFst<Arc> *fst2 = new fst::VectorFst<Arc>;
155  fst::Determinize ( *myfst, fst2 );
156  fst::Minimize ( fst2 );
157  fst::EncodeMapper<Arc> em2 ( em,
158  fst::DECODE ); //create em2 by copying em but now with decode option.
159  fst::Decode ( fst2, em2 );
160  return fst2;
161 };
162 
163 template <class Arc>
164 void EncodeDeterminizeMinimizeDecode ( fst::Fst<Arc> const& myfst ,
165  fst::VectorFst<Arc> *out ) {
166  out->DeleteStates();
167  *out = (myfst);
168  fst::EncodeMapper<Arc> em ( fst::kEncodeLabels, fst::ENCODE );
169  fst::Encode ( out, &em );
170  LDBG_EXECUTE ( out->Write ( "fsts/encoded.fst" ) );
171  fst::VectorFst<Arc> fst2;
172  fst::Determinize ( *out , &fst2 );
173  fst::Minimize ( &fst2 );
174  fst::EncodeMapper<Arc> em2 ( em,
175  fst::DECODE ); //create em2 by copying em but now with decode option.
176  fst::Decode ( &fst2, em2 );
177  out->DeleteStates();
178  *out = (fst2);
179 };
180 
187 template<class Arc,
188  class CharTypeT,
189  class StringTypeT
190  >
191 inline std::basic_string<CharTypeT>
192 FstGetBestHypothesis(const fst::VectorFst<Arc> &latfst) {
193  using namespace fst;
194  using namespace std;
195  VectorFst<Arc> hypfst;
196  ShortestPath(latfst, &hypfst);
197  Project(&hypfst, PROJECT_INPUT);
198  RmEpsilon(&hypfst);
199  TopSort(&hypfst);
200  basic_string<CharTypeT> hypstr;
201  for (StateIterator< VectorFst<Arc> > si(hypfst); !si.Done();
202  si.Next()) {
203  for (ArcIterator< VectorFst<Arc> > ai(hypfst, si.Value());
204  !ai.Done(); ai.Next()) {
205  stringstream ss;
206  ss << ai.Value().ilabel;
207  StringTypeT value; ss >> value;
208  hypstr += value;
209  }
210  }
211  return hypstr;
212 };
213 
214 
215 //basic_string to vector helper:
216 template<class Arc,
217  class CharTypeT>
218 void FstGetBestHypothesis(const fst::VectorFst<Arc> &latfst
219  , std::vector<CharTypeT> &hyp) {
220 
221  std::basic_string<CharTypeT> aux = FstGetBestHypothesis<Arc,CharTypeT, CharTypeT>(latfst);
222  hyp.clear();
223  hyp.resize(aux.size());
224  std::copy(aux.begin(), aux.end(), hyp.begin());
225 }
226 
227 //helper with std::string (spaces between numbers)
228 template<class Arc>
229 void FstGetBestStringHypothesis(const fst::VectorFst<Arc> &latfst
230  , std::string &hyp) {
231  std::basic_string<unsigned> aux = FstGetBestHypothesis<Arc,unsigned, unsigned>(latfst);
232  hyp.clear();
233  for (unsigned k =0; k < aux.size(); ++k){
234  std::stringstream ss; ss << aux[k];
235  hyp += ss.str() + " ";
236  }
237 }
238 
239 
252 template<class Arc>
253 inline void printstrings ( const fst::VectorFst<Arc>& pcostslat,
254  std::ostream *hyps, unsigned s = 0 ) {
255  static std::basic_string<unsigned> ihyp;
256  static std::basic_string<unsigned> ohyp;
257  static std::basic_string<typename Arc::Weight> cost;
258  if ( !USER_CHECK ( s < pcostslat.NumStates(),
259  "The state most surely doesn't exist! Topsort this lattice and make sure you use a valid s" ) )
260  return;
261  typename Arc::Weight fvalue = Arc::Weight::One();
262  fst::ArcIterator< fst::VectorFst<Arc> > i ( pcostslat, s );
263  if ( !i.Done() ) {
264  for ( ; !i.Done(); i.Next() ) {
265  Arc a = i.Value();
266  ihyp.push_back ( a.ilabel );
267  ohyp.push_back ( a.olabel );
268  cost.push_back ( a.weight );
269  printstrings ( pcostslat, hyps, a.nextstate ); // recursive call to next state.
270  ihyp.resize ( ihyp.size() - 1 );
271  ohyp.resize ( ohyp.size() - 1 );
272  cost.resize ( cost.size() - 1 );
273  };
274  fvalue = pcostslat.Final ( s );
275  if ( fvalue == Arc::Weight::Zero() ) return;
276  }
277  for ( unsigned k = 0; k < ihyp.size();
278  ++k ) if ( ihyp[k] != 0 ) *hyps << ihyp[k] << " ";
279  *hyps << "|| ";
280  for ( unsigned k = 0; k < ihyp.size();
281  ++k ) if ( ohyp[k] != 0 ) *hyps << ohyp[k] << " ";
282  *hyps << "|| ";
283  typename Arc::Weight c = fvalue;
284  for ( unsigned k = 0; k < cost.size(); ++k ) {
285  c = Times ( c, cost[k] );
286  }
287  *hyps << c << endl;
288 };
289 
/// Struct for priority queue comparison: a hypothesis string and its cost.
/// Member order is kept as-is, so aggregate initialisation elsewhere keeps working.
struct hypcost {
  std::string hyp;  // hypothesis as a string of labels
  float cost;       // accumulated cost; NOT initialised by default
};

/// Class used by priority queue to compare two hypotheses and decide which one wins.
/// Returns true when h1 costs more than h2. With std::priority_queue this puts
/// the cheapest hypothesis on top.
class CompareHyp {
 public:
  // Takes const references and is itself const, so it can be used on const
  // hypotheses and through const comparator objects (the old non-const
  // signature rejected both).
  bool operator() ( const hypcost& h1, const hypcost& h2 ) const {
    return h1.cost > h2.cost;
  }
};
305 
307 template<class Arc>
308 inline void printstrings ( const fst::VectorFst<Arc>& fst
309  , unordered_map<std::string, float>& finalhyps
310  , bool input = true ) {
311  fst::VectorFst<Arc> ofst ( fst );
312  fst::TopSort ( &ofst ); //topological order is needed for the algorithm to work.
313  std::vector <std::vector<struct hypcost> > partialhyps;
314  partialhyps.clear();
315  finalhyps.clear();
316  struct hypcost hcempty;
317  hcempty.cost = 0;
318  std::vector<struct hypcost> vhcempty;
319  partialhyps.resize ( ofst.NumStates() );
320  partialhyps[0].push_back ( hcempty );
321  for ( fst::StateIterator< fst::MutableFst<Arc> > si ( ofst ); !si.Done();
322  si.Next() ) {
323  typename Arc::Weight value = ofst.Final ( si.Value() );
324  std::vector<struct hypcost>& hypc = partialhyps[ ( unsigned ) si.Value()];
325  for ( unsigned k = 0; k < hypc.size(); ++k ) {
326  if ( value != Arc::Weight::Zero() ) {
327  finalhyps[partialhyps[ ( unsigned ) si.Value()][k].hyp] =
328  partialhyps[ ( unsigned ) si.Value()][k].cost + ( float ) value.Value();
329  }
330  for ( fst::MutableArcIterator< fst::MutableFst<Arc> > ai ( &ofst, si.Value() );
331  !ai.Done(); ai.Next() ) {
332  struct hypcost hcnew = hypc[k];
333  Arc arc = ai.Value();
334  if ( input
335  && arc.ilabel != 0 ) hcnew.hyp += " " + ucam::util::toString ( arc.ilabel );
336  else if ( !input
337  && arc.olabel != 0 ) hcnew.hyp += " " + ucam::util::toString ( arc.olabel );
338  hcnew.cost += ( float ) arc.weight.Value();
339  partialhyps[arc.nextstate].push_back ( hcnew );
340  }
341  }
342  partialhyps[ ( unsigned ) si.Value()] =
343  vhcempty; //won't need these partial hyps any more
344  }
345 };
346 
347 inline unsigned ShortestPathLength (const fst::VectorFst<fst::StdArc>* fst) {
348  using fst::StdArc;
349  using fst::VectorFst;
350  VectorFst<StdArc> tmp;
351  fst::ShortestPath (*fst, &tmp);
352  fst::RmEpsilon (&tmp);
353  unsigned n = 0;
354  for (fst::StateIterator< VectorFst<StdArc> > si (tmp); !si.Done(); si.Next() ) {
355  for (fst::ArcIterator< VectorFst<StdArc> > ai (tmp, si.Value() ); !ai.Done();
356  ai.Next() ) {
357  n++;
358  }
359  }
360  return n;
361 }
362 
363 template <class Arc>
364 inline fst::VectorFst<Arc>* PushWeightsToFinal (const fst::VectorFst<Arc>*
365  fst) {
366  fst::VectorFst<Arc>* tmp = new fst::VectorFst<Arc>;
367  fst::Push<Arc, fst::REWEIGHT_TO_FINAL> (*fst, tmp, fst::kPushWeights);
368  return tmp;
369 }
370 
371 inline fst::VectorFst<fst::LogArc>* StdToLog (const fst::VectorFst<fst::StdArc>*
372  fst) {
373  fst::VectorFst<fst::LogArc>* tmp = new fst::VectorFst<fst::LogArc>;
374  fst::Map (*fst, tmp, fst::StdToLogMapper() );
375  return tmp;
376 }
377 
378 inline fst::VectorFst<fst::StdArc>* LogToStd (const fst::VectorFst<fst::LogArc>*
379  fst) {
380  using fst::StdArc;
381  using fst::VectorFst;
382  VectorFst<StdArc>* tmp = new VectorFst<StdArc>;
383  fst::Map (*fst, tmp, fst::LogToStdMapper() );
384  return tmp;
385 }
386 
387 inline fst::VectorFst<fst::StdArc>* FstScaleWeights (fst::VectorFst<fst::StdArc>*
388  fst
389  , const double scale) {
390  using fst::StdArc;
391  using fst::VectorFst;
392  VectorFst<StdArc>* fstscaled = fst->Copy();
393  for (fst::StateIterator< VectorFst<StdArc> > si (*fstscaled); !si.Done();
394  si.Next() ) {
395  for (fst::MutableArcIterator< VectorFst<StdArc> > ai (fstscaled, si.Value() );
396  !ai.Done(); ai.Next() ) {
397  StdArc arc = ai.Value();
398  arc.weight = static_cast<StdArc::Weight> (arc.weight.Value() * scale);
399  ai.SetValue (arc);
400  }
401  StdArc::Weight final = fstscaled->Final (si.Value() );
402  if (final != ZPosInfinity() ) {
403  fstscaled->SetFinal (si.Value(),
404  static_cast<StdArc::Weight> (final.Value() * scale) );
405  }
406  }
407  return fstscaled;
408 }
409 
410 #define UNIT_COST_POSITIVE 1
411 #define UNIT_COST_NEGATIVE -1
412 
413 template<class Arc>
414 inline void GetMinAndMaxHypothesisLength (const fst::VectorFst<Arc>* fst,
415  unsigned& jMin, unsigned& jMax) {
416  using fst::Map;
417  using fst::RmWeightMapper;
418  using fst::TimesMapper;
419  fst::VectorFst<Arc>* tmp = fst->Copy();
420  Map (tmp, RmWeightMapper<Arc, Arc>() );
421  Map (tmp, TimesMapper<Arc> ( UNIT_COST_POSITIVE ) );
422  jMin = ShortestPathLength (tmp);
423  Map (tmp, RmWeightMapper<Arc, Arc>() );
424  Map (tmp, TimesMapper<Arc> ( UNIT_COST_NEGATIVE ) );
425  jMax = ShortestPathLength (tmp);
426  delete tmp;
427 }
428 
429 inline fst::StdArc::Label GetFirstUnusedLabelId (const
430  fst::VectorFst<fst::StdArc>* fst) {
431  using fst::StdArc;
432  fst::StdArc::Label x = 0;
433  for (fst::StateIterator< fst::VectorFst<StdArc> > si (*fst); !si.Done();
434  si.Next() ) {
435  for (fst::ArcIterator< fst::VectorFst<StdArc> > ai (*fst, si.Value() );
436  !ai.Done(); ai.Next() ) {
437  if (ai.Value().ilabel > x) {
438  x = ai.Value().ilabel;
439  }
440  }
441  }
442  return x + 1;
443 }
444 
445 inline void SetFinalStateCost (fst::MutableFst<fst::StdArc>* fst,
446  const fst::StdArc::Weight cost) {
447  for (fst::StateIterator< fst::MutableFst<fst::StdArc> > si (*fst); !si.Done();
448  si.Next() ) {
449  if (fst->Final (si.Value() ) != ZPosInfinity() ) {
450  fst->SetFinal (si.Value(), cost);
451  }
452  }
453 }
454 
462 template<class Arc>
463 inline void string2fst (const std::string& sidxwords,
464  fst::VectorFst<Arc> *fst,
465  const std::string& tidxwords = "",
466  typename Arc::Weight finalweight = Arc::Weight::One()
467  ) {
468  assert (sidxwords != "");
469  std::vector<std::string> swords;
470  boost::algorithm::split (swords, sidxwords,
471  boost::algorithm::is_any_of ( " " ) );
472  fst->AddState();
473  fst->SetStart (0);
474  for (unsigned k = 0; k < swords.size(); ++k) {
475  typename Arc::Label swidx = ucam::util::toNumber<unsigned> (swords[k]);
476  fst->AddState();
477  fst->AddArc (k, Arc (swidx, 0, Arc::Weight::One(), k + 1) );
478  }
479  if (tidxwords == "") {
480  fst->SetFinal (swords.size(), Arc::Weight::One() );
481  fst::Project<Arc> (fst,
482  fst::PROJECT_INPUT); //if only source provided, we assume an automaton.
483  fst->SetFinal (swords.size(), finalweight);
484  return;
485  }
486  std::vector<std::string> twords;
487  boost::algorithm::split (twords, tidxwords,
488  boost::algorithm::is_any_of ( " " ) );
489  for (unsigned k = swords.size(); k < swords.size() + twords.size(); ++k) {
490  typename Arc::Label twidx = ucam::util::toNumber<unsigned>
491  (twords[k - swords.size()]);
492  fst->AddState();
493  fst->AddArc (k, Arc (0, twidx, Arc::Weight::One(), k + 1) );
494  }
495  fst->SetFinal (swords.size() + twords.size(), finalweight);
496  fst::Determinize (fst::RmEpsilonFst<Arc> (*fst), fst);
497  fst::Minimize (fst);
498  fst::RmEpsilon (fst);
499 };
500 
502 template<class Arc>
503 class RelabelUtil {
504 
505  private:
506  std::vector<pair <typename Arc::Label, typename Arc::Label> > ipairs;
507  std::vector<pair <typename Arc::Label, typename Arc::Label> > opairs;
508  public:
510 
511  inline RelabelUtil& addIPL (typename Arc::Label labelfind,
512  typename Arc::Label labelreplace) {
513  ipairs.push_back (pair <typename Arc::Label, typename Arc::Label> (labelfind,
514  labelreplace) );
515  return *this;
516  };
517  inline RelabelUtil& addOPL (typename Arc::Label labelfind,
518  typename Arc::Label labelreplace) {
519  opairs.push_back (pair <typename Arc::Label, typename Arc::Label> (labelfind,
520  labelreplace) );
521  return *this;
522  };
523  inline RelabelUtil& operator() (fst::VectorFst<Arc> *hypfst) {
524  fst::Relabel (hypfst, ipairs, opairs);
525  return *this;
526  }
527  inline fst::VectorFst<Arc>& operator() (fst::VectorFst<Arc>& hypfst) {
528  fst::Relabel (&hypfst, ipairs, opairs);
529  return hypfst;
530  }
531 
532 };
533 
534 } // end namespace
535 
536 #endif
fst::VectorFst< fst::LogArc > * StdToLog(const fst::VectorFst< fst::StdArc > *fst)
Definition: fstutils.hpp:371
std::string toString(const T &x, uint pr=2)
Converts an arbitrary type to string Converts to string integers, floats, doubles Quits execution if ...
void FstGetBestStringHypothesis(const fst::VectorFst< Arc > &latfst, std::string &hyp)
Definition: fstutils.hpp:229
fst::VectorFst< fst::StdArc > * LogToStd(const fst::VectorFst< fst::LogArc > *fst)
Definition: fstutils.hpp:378
Definition: fstio.hpp:27
#define UNIT_COST_POSITIVE
Definition: fstutils.hpp:410
#define LDBG_EXECUTE(order)
RelabelUtil & addIPL(typename Arc::Label labelfind, typename Arc::Label labelreplace)
Definition: fstutils.hpp:511
void SetFinalStateCost(fst::MutableFst< fst::StdArc > *fst, const fst::StdArc::Weight cost)
Definition: fstutils.hpp:445
fst::TropicalWeightTpl< F > Map(double)
unsigned ShortestPathLength(const fst::VectorFst< fst::StdArc > *fst)
Definition: fstutils.hpp:347
fst::VectorFst< Arc > * EncodeDeterminizeMinimizeDecode(fst::VectorFst< Arc > *myfst)
Encodes, determinizes, minimizes and decodes an fst.
Definition: fstutils.hpp:149
void string2fst(const std::string &sidxwords, fst::VectorFst< Arc > *fst, const std::string &tidxwords="", typename Arc::Weight finalweight=Arc::Weight::One())
Convenience method that creates an fsa/fst from one/two string(s) of numbers.
Definition: fstutils.hpp:463
void printstrings(const fst::VectorFst< Arc > &pcostslat, std::ostream *hyps, unsigned s=0)
Trivial function that outputs all the hypotheses in the lattice with their costs.
Definition: fstutils.hpp:253
std::string hyp
Definition: fstutils.hpp:292
void GetMinAndMaxHypothesisLength(const fst::VectorFst< Arc > *fst, unsigned &jMin, unsigned &jMax)
Definition: fstutils.hpp:414
void buildSubstringTransducer(fst::VectorFst< Arc > *myfst)
Builds substring version of an fst. This is a destructive implementation.
Definition: fstutils.hpp:118
Utility functor for relabeling one or more lattices. Note that you can chain commands. See Unit test in fstutils.gtest.cpp for an example.
Definition: fstutils.hpp:503
Struct for priority queue comparison.
Definition: fstutils.hpp:291
#define UNIT_COST_NEGATIVE
Definition: fstutils.hpp:411
void extractTargetVocabulary(const fst::VectorFst< Arc > &myfst, unordered_set< std::string > *vcb)
Extract target (right-side) vocabulary from an fst.
Definition: fstutils.hpp:91
void extractSourceVocabulary(const fst::VectorFst< Arc > &myfst, unordered_set< std::string > *vcb)
Extract source (left-side) vocabulary from an fst.
Definition: fstutils.hpp:42
RelabelUtil & addOPL(typename Arc::Label labelfind, typename Arc::Label labelreplace)
Definition: fstutils.hpp:517
float ZPosInfinity()
Just a wrapper to maintain compatibility with OpenFST 1.3.1, last version using kPosInfinity constant...
Definition: fstutils.hpp:27
Class used by priority queue to compare two hypotheses and decide which one wins. ...
Definition: fstutils.hpp:298
fst::VectorFst< fst::StdArc > * FstScaleWeights(fst::VectorFst< fst::StdArc > *fst, const double scale)
Definition: fstutils.hpp:387
#define USER_CHECK(exp, comment)
Tests whether exp is true. If not, comment is printed and program ends.
TropicalSparseTupleWeight< T > Times(const TropicalSparseTupleWeight< T > &w1, const TropicalSparseTupleWeight< T > &w2)
fst::VectorFst< Arc > * PushWeightsToFinal(const fst::VectorFst< Arc > *fst)
Definition: fstutils.hpp:364
fst::StdArc::Label GetFirstUnusedLabelId(const fst::VectorFst< fst::StdArc > *fst)
Definition: fstutils.hpp:429
std::basic_string< CharTypeT > FstGetBestHypothesis(const fst::VectorFst< Arc > &latfst)
Takes the 1-best of an fst and converts to string.
Definition: fstutils.hpp:192