Cambridge SMT System
fstutils.applylmonthefly.hpp
Go to the documentation of this file.
1 // Licensed under the Apache License, Version 2.0 (the "License");
2 // you may not use these files except in compliance with the License.
3 // You may obtain a copy of the License at
4 //
5 // http://www.apache.org/licenses/LICENSE-2.0
6 //
7 // Unless required by applicable law or agreed to in writing, software
8 // distributed under the License is distributed on an "AS IS" BASIS,
9 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 // See the License for the specific language governing permissions and
11 // limitations under the License.
12 
13 // Copyright 2012 - Gonzalo Iglesias, Adrià de Gispert, William Byrne
14 
15 #ifndef FSTUTILS_APPLYLMONTHEFLY_HPP
16 #define FSTUTILS_APPLYLMONTHEFLY_HPP
17 
25 #include <idbridge.hpp>
26 #include <lm/wrappers/nplm.hh>
27 namespace fst {
28 
29 
// Uniform access to the context length of a language model state.
// Generic version: for state types that expose a `length` data member.
// The length travels with each state, so the handler keeps no data of its own.
template <class StateT>
struct StateHandler {
  // No-op: nothing is stored in the generic handler.
  inline void setLength(unsigned /*length*/) {}
  // Returns state.length. Const so that it can be called through a const
  // handler, matching the signature of the lm::np::State specialization.
  inline unsigned getLength(StateT const &state) const { return state.length; }
};
35 
36 // specialization to handle nplm
37 template<>
38 struct StateHandler<lm::np::State> {
39  unsigned length_;
40  inline void setLength(unsigned length) { length_ = length;}
41  inline unsigned getLength(lm::np::State const &state) const { return length_; }
42 };
43 
// Holds the factor applied to every language model score.
// Generic version: the factor is stored exactly as given.
template<class StateT>
struct Scale {
  float scale_;
  // Not explicit on purpose: existing code may rely on float -> Scale conversion.
  Scale(float scale): scale_(scale) {}
  // Returns the stored factor. Const: reading the factor does not modify it.
  float operator()() const {
    return scale_;
  }
};
52 
53 // nplm in kenlm is providing by default
54 // natural logs, so correct this
55 template<>
56 struct Scale<lm::np::State> {
57  float scale_;
58  Scale(float scale): scale_(scale / ::log(10)) {}
59  float operator()(){
60  return scale_;
61  }
62 };
63 
// Small hack that keeps scores srilm-compatible, i.e. log10 scores should
// always match. It assumes that the internal ids 1 and 2 are <s> and </s>.
// If this does not matter, implement a functor that does nothing and use it
// instead of this one.
//
// For labels 0, 1 and 2 the word penalty wp is reset to 0; for label 1 the
// score w is reset to 0 as well. Any other label leaves both untouched.
// The state argument is not used by the generic version.
template<class StateT>
struct HackScore {
  void operator()(float &w, float &wp, unsigned label, StateT&) {
    if (label > 2) return;
    wp = 0;
    if (label == 1) {
      w = 0;  // get same result as srilm
    }
  }
};
78 
79 // But nplm does not do exactly the same thing because
80 // the start state is actually different
81 template<>
82 struct HackScore<lm::np::State> {
83  void operator()(float &w, float &wp, unsigned label, lm::np::State &ns) {
84  if ( label <= 2 ) {
85  wp = 0;
86  if (label==1) {
87  w = 0;
88  // set up correctly the next state
89  for (int k = 0; k < NPLM_MAX_ORDER - 1; ++k) {
90  ns.words[k]=1;
91  }
92  }
93  }
94  }
95 };
96 
97 template<class StateT>
99  void operator()(float &w, float &wp, unsigned label, StateT&) {
100  LERROR("Use only for nplm in bilingual models!");
101  exit(EXIT_FAILURE);
102  }
103 };
104 
105 
106 template<>
107 struct HackScoreBilingual<lm::np::State> {
108  void operator()(float &w, float &wp, unsigned label, lm::np::State &ns) {
109  if ( label <= 2 ) {
110  wp = 0;
111  if (label==1) {
112  for (int k = 0; k < NPLM_MAX_ORDER - 1; ++k) {
113  ns.words[k]=1;
114  }
115  }
116  }
117  }
118 };
119 
120 // nplm states have size NPLM_MAX_ORDER - 1
121 template<class StateT>
122 std::string printDebug(StateT const &state) {
123  std::string o =ucam::util::toString<unsigned>(state.words[0]);
124  for (unsigned k = 1; k < NPLM_MAX_ORDER - 1; ++k) {
125  o += "," + ucam::util::toString<unsigned>(state.words[k]);
126  }
127  return o;
128 };
129 
130 template<class StateT
131  , class KenLMModelT
132  , class IdBridgeT
133  , template<class> class HackScoreT
134  >
135 struct Scorer {
136  IdBridgeT const& idbridge_;
137  KenLMModelT& lmmodel_;
138  HackScoreT<StateT> hs_;
140  explicit Scorer(KenLMModelT &lmmodel
141  , IdBridgeT const &idbridge
142  , Scale<StateT> &nl
143  , unsigned
144  , std::vector<std::vector<unsigned> > const &
145  )
146  : idbridge_(idbridge)
147  , lmmodel_(lmmodel)
148  , natlog10_(nl)
149  {
150  // LERROR("Bilingual model scorer only works with nplm models");
151  // exit(EXIT_FAILURE);
152  }
153 
154  void operator()(StateT const &current, float &w, float &wp, int ilabel, int olabel, StateT& next) {
155  w = lmmodel_.Score ( current, idbridge_.map(olabel), next ) * natlog10_();
156  hs_(w, wp, olabel, next); //hack to make it srilm/nplm compliant
157  }
158 };
159 
160 // partial template specialization class for nplm model
161 // TODO: code reorganization/simplification needed
162 template< class IdBridgeT
163  , template<class> class HackScoreT
164  >
165 struct Scorer<lm::np::State, lm::np::Model
166  , IdBridgeT, HackScoreT> {
167  IdBridgeT const& idbridge_;
168  typedef lm::np::Model KenLMModelT;
169  typedef lm::np::State StateT;
170  KenLMModelT & lmmodel_;
171  HackScoreT<StateT> hs_;
173  unsigned srcSize_;
174 
175  std::vector< std::vector<unsigned> > srcWindows_;
176 
177  Scorer(KenLMModelT &lmmodel
178  , IdBridgeT const &idbridge
179  , Scale<StateT> &nl
180  , unsigned srcSize
181  , std::vector<std::vector<unsigned> > const &srcWindows
182  )
183  : idbridge_(idbridge)
184  , lmmodel_(lmmodel)
185  , natlog10_(nl)
186  , srcSize_(srcSize)
187  {
188  srcWindows_.clear();
189  // need to map source labels into nplm internal labels.
190  srcWindows_.resize(srcWindows.size());
191  for (unsigned k = 0; k < srcWindows.size(); ++k) {
192  srcWindows_[k].resize(srcWindows[k].size());
193  for (unsigned j = 0; j < srcWindows[k].size(); ++j) {
194  srcWindows_[k][j] = idbridge_.map(srcWindows[k][j]);
195  }
196  }
197 #ifdef PRINTDEBUG1
198  // For debugging purposes, lets print mapped vectors here:
199  for (unsigned k = 0; k < srcWindows_.size(); ++k) {
200  std::stringstream ss; ss << srcWindows_[k][0];
201  for (unsigned j = 1; j < srcWindows_[k].size(); ++j ){
202  ss << " " << srcWindows_[k][j];
203  }
204  LDEBUG("*** MAPPED src words=" << ss.str());
205  }
206 #endif
207  };
208 
209  // For bilingual model
210  // This assumes that ilabels are indices or pointers to the source sentence
211  // olabels are actual target words.
212  void operator()(StateT const &current, float &w, float &wp, int ilabel
213  , int olabel, StateT& next) {
214  // make a copy of current state and add source window.
215  StateT c2 = current;
216  LDEBUG("Current model state=" << printDebug(c2));
217  if (srcWindows_.size()) {
218  --ilabel;
219  if (ilabel >= srcWindows_.size() || ilabel < 0) {
220  LERROR("Wrong input label! Input labels should be links/affiliations "
221  << "pointing to words in the source sentence:" << ilabel);
222  exit(EXIT_FAILURE);
223  }
224 #ifdef PRINTDEBUG1
225  std::stringstream ss;
226 #endif
227  for (int k = 0; k < srcSize_; ++k) {
228 #ifdef PRINTDEBUG1
229  ss << ilabel << "," << k << ": " << srcWindows_[ilabel][k] << "\t";
230 #endif
231  c2.words[k] = srcWindows_[ilabel][k];
232  }
233  LDEBUG(ss.str() << "\nilabel="
234  << ilabel << ", Current model state after adding source="
235  << printDebug(c2));
236  }
237 
238  unsigned ol = idbridge_.mapOutput(olabel);
239  w = lmmodel_.Score ( c2, ol, next ) * natlog10_();
240  LDEBUG("Mapped olabel=" << olabel << " to "
241  << ol
242  << ", score=" << w);
243 
244  LDEBUG("Next state is (kenlm) =" << printDebug(next));
245  hs_(w, wp, olabel, next); //hack to make it srilm/nplm compliant
246  LDEBUG("Next state is (after hack) =" << printDebug(next));
247  // finally, update nextstate by taking current state,
248  // adding olabel and sliding the window
249  // (only target model context)
250  // this could go into the hackscore class?
251 
252  // kenlm has shifted the state, but now i take out
253  // the source words in the input
254  for (int k = 0; k < srcSize_; ++k) next.words[k] = 0;
255  unsigned oli = idbridge_.map(olabel);
256  bool failure = false;
257  int offset = (int) lmmodel_.Order() - 2;
258  if (next.words[offset] != ol) {
259  LERROR("Problem:" << offset << "=>" << next.words[offset]);
260  failure = true;
261  }
262  // ... and map the last word with the input vocabulary (instead of output)
263  // input and output vocabularies are not necessarily the same
264  // and this is particularly true in a bilingual model!
265  next.words[offset] = oli;
266  LDEBUG("*** " << printDebug(c2)
267  << "\t" << printDebug(revertMappings(c2))
268  << "\t" << w << "\t" << printDebug(next)
269  << "\t" << printDebug(revertMappings(next))
270  << "\t" << olabel << "\t i=" << oli << ", o=" << ol
271  );
272  if (failure) exit(EXIT_FAILURE);
273  }
274 
275  StateT revertMappings(StateT const &s) {
276  StateT x;
277  for (int k = 0; k < NPLM_MAX_ORDER; ++k) {
278  x.words[k] = idbridge_.rmap(s.words[k]);
279  }
280  return x;
281  }
282 };
283 
// Abstract interface: three pure virtual run() overloads, each taking an
// input VectorFst and returning a pointer to a VectorFst.
// NOTE(review): the line declaring the class itself (and possibly other
// members) is not present in this listing.
289 template<class ArcT>
// Overload taking only the input fst.
291  virtual VectorFst<ArcT> *run(VectorFst<ArcT> const& fst) = 0;
// Overload that additionally receives a set of labels (`epsilons`).
292  virtual VectorFst<ArcT> *run(VectorFst<ArcT> const& fst
293  , unordered_set<typename ArcT::Label> const &epsilons) = 0;
// Overload that additionally receives a source size and source windows.
294  virtual VectorFst<ArcT> *run(const VectorFst<ArcT>& fst
295  , unsigned srcSize
296  , std::vector< std::vector<unsigned> > &srcWindows) =0;
298 };
299 
// Template header and data members of the class that applies a language
// model to a lattice on the fly.
// NOTE(review): the class declaration line and some member declarations
// (e.g. natlog10_ and sh_, both used by the methods below) are not present
// in this listing.
307 template <class Arc
308  , class MakeWeightT = MakeWeight<Arc>
309  , class KenLMModelT = lm::ngram::Model
310  , class IdBridgeT = ucam::fsttools::IdBridge
311  , template<class> class HackScoreT = HackScore >
313  private:
314  typedef typename Arc::StateId StateId;
315  typedef typename Arc::Label Label;
316  typedef typename Arc::Weight Weight;
317  typedef unsigned long long ull;
318 
// Maps a compound id (lattice state * sid + lm state id, see add()) to the
// id of the corresponding state in the composed fst.
319  unordered_map< ull, StateId > stateexistence_;
320 
// Multiplier used by add() to combine a lattice state id with a lm state id.
321  static const ull sid = 1000000000;
// Maps a composed state id to its (lattice state, lm state) pair.
323  unordered_map<uint64_t
324  , pair<StateId, typename KenLMModelT::State > > statemap_;
326 
// Maps a lm history (as filled in by getIdx()) to the lm state id used in
// compound ids.
// NOTE(review): one template argument line appears to be missing here.
327  unordered_map<basic_string<unsigned>
328  , StateId
330  , ucam::util::hasheqvecuint> seenlmstates_;
331 
// Composed states waiting to be expanded by doComposition().
333  queue<StateId> qc_;
334 
// Output labels in this set are not scored by the language model.
336  unordered_set<Label> epsilons_;
337 
338  KenLMModelT& lmmodel_;
// Initialized from model.GetVocabulary() in the constructors.
339  const typename KenLMModelT::Vocabulary& vocab_;
340 
341  // float natlog10_;
343 
// Functor that turns a float into an arc weight.
345  MakeWeightT mw_;
346 
// Scratch storage: getIdx() copies the words of a lm state into `history`
// through `buffer`; `buffersize` is the number of bytes copied (set in init()).
347  basic_string<unsigned> history;
348  unsigned *buffer;
349  unsigned buffersize;
350 
351  //Word Penalty.
352  float wp_;
353 
// Id mapper handed to the scorer.
354  const IdBridgeT& idbridge_;
355 
356  // transparent state handling (kenlm vs nplm)
358  // transparent score quirks handling for srilm/nplm compliance
359  HackScoreT<typename KenLMModelT::State> hs_;
360 
362  public:
363 
365  inline void setMakeWeight ( const MakeWeightT& mw ) {
366  mw_ = mw;
367  };
368 
379  ApplyLanguageModelOnTheFly ( KenLMModelT& model
380  , unordered_set<Label>& epsilons
381  , bool natlog
382  , float lmscale
383  , float lmwp
384  , const IdBridgeT& idbridge
385  , MakeWeightT &mw
386  )
387  : natlog10_ ( natlog ? -lmscale* ::log ( 10.0 ) : -lmscale )
388  , lmmodel_ ( model )
389  , vocab_ ( model.GetVocabulary() )
390  , wp_ ( lmwp )
391  , epsilons_ ( epsilons )
392  , history ( model.Order(), 0)
393  , idbridge_ (idbridge)
394  , mw_(mw)
395  {
396  init();
397  };
398 
399  ApplyLanguageModelOnTheFly ( KenLMModelT& model
400  , bool natlog
401  , float lmscale
402  , float lmwp
403  , const IdBridgeT& idbridge
404  , MakeWeightT &mw
405  )
406  : natlog10_ ( natlog ? -lmscale* ::log ( 10.0 ) : -lmscale )
407  , lmmodel_ ( model )
408  , vocab_ ( model.GetVocabulary() )
409  , wp_ ( lmwp )
410  , history ( model.Order(), 0)
411  , idbridge_ (idbridge)
412  , mw_(mw)
413  {
414  init();
415  };
416 
417  void init() {
418  LDEBUG("Model order=" << (int) lmmodel_.Order());
419  sh_.setLength(lmmodel_.Order() );
420  buffersize = (lmmodel_.Order() - 1 ) * sizeof (unsigned);
421  buffer = const_cast<unsigned *> ( history.c_str() );
422  }
423 
426 
427  VectorFst<Arc> *run(const VectorFst<Arc>& fst) {
428  return this->operator()(fst);
429  }
430  VectorFst<Arc> *run(const VectorFst<Arc>& fst
431  , unordered_set<Label> const &epsilons) {
432  epsilons_ = epsilons; // this may be necessary e.g. for pdts
433  return this->run(fst);
434  }
435 
436  VectorFst<Arc> *run(const VectorFst<Arc>& fst
437  , unsigned size
438  , std::vector<std::vector<unsigned> > &srcWindows) {
439  // epsilons_ = epsilons; // this may be necessary e.g. for pdts
440  return this->operator()(fst,size, srcWindows);
441  };
442 
443 
// Applies the language model to fst without source windows: builds a scorer
// with a source size of 0 and an empty window list, then runs doComposition().
// NOTE(review): the line declaring the type of `sc` is not present in this
// listing.
445  VectorFst<Arc> * operator() (const VectorFst<Arc>& fst) {
446  unsigned ign = 0;
447  std::vector<std::vector<unsigned> > empty;
449  sc(lmmodel_, idbridge_, natlog10_, ign, empty );
450  return doComposition(fst, sc);
451  }
452 
453  // Composition with bilm
454  // note that input labels are indices to source windows
// srcSize and srcw are handed to the scorer unchanged; the result comes from
// doComposition().
// NOTE(review): the line declaring the type of `sc` is not present in this
// listing.
455  VectorFst<Arc> * operator() (const VectorFst<Arc>& fst, unsigned srcSize
456  , std::vector<std::vector<unsigned> > &srcw) {
457 
458  // the scorer will help compute the correct score regardless of whether it is a
459  // bilingual model or not, etc.
461  sc(lmmodel_, idbridge_, natlog10_ , srcSize, srcw);
462  return doComposition(fst, sc);
463  }
464 
465  private:
466 
// Builds the composed fst by expanding (lattice state, lm state) pairs taken
// from the queue qc_, starting at the lattice start state paired with the
// model's null context state.
// Returns NULL for a lattice without states; otherwise a VectorFst allocated
// with new, which is handed to the caller.
// NOTE(review): the rest of the parameter list (the scorer `sc` used below)
// is not present in this listing.
471  VectorFst<Arc> * doComposition(const VectorFst<Arc>& fst
473 
474  if (!fst.NumStates() ) {
475  LWARN ("Empty lattice. ... Skipping LM application!");
476  return NULL;
477  }
478  VectorFst<Arc> *composed = new VectorFst<Arc>;
// Seed the search: start state of the lattice + null lm context.
480  typename KenLMModelT::State bs = lmmodel_.NullContextState();
481  pair<StateId, bool> nextp = add ( composed, bs, fst.Start(), fst.Final ( fst.Start() ) );
482  qc_.push ( nextp.first );
483  composed->SetStart ( nextp.first );
484  while ( qc_.size() ) {
485  LDEBUG("queue size=" << qc_.size());
486  StateId s = qc_.front();
487  qc_.pop();
// s is a composed state; s1 is its lattice state and s2 its lm state.
488  pair<StateId, const typename KenLMModelT::State> p = get ( s );
489  StateId& s1 = p.first;
490  const typename KenLMModelT::State s2 = p.second;
491 
492  for ( ArcIterator< VectorFst<Arc> > arc1 ( fst, s1 ); !arc1.Done();
493  arc1.Next() ) {
494  const Arc& a1 = arc1.Value();
// w receives the lm score; wp starts as the configured word penalty.
495  float w = 0;
496  float wp = wp_;
497  typename KenLMModelT::State nextlmstate;
498  if ( epsilons_.find ( a1.olabel ) == epsilons_.end() ) {
499  sc(s2, w, wp, a1.ilabel, a1.olabel, nextlmstate);
500  } else {
501  // ignore epsilons completely, even if we have alignments here.
502  nextlmstate = s2;
503  wp = 0; //We don't count epsilon labels
504  }
505  pair<StateId, bool> nextp = add ( composed, nextlmstate
506  , a1.nextstate
507  , fst.Final ( a1.nextstate ) );
508  StateId& newstate = nextp.first;
509  bool visited = nextp.second;
// The new arc keeps the labels and combines the original weight with the
// weights built from the lm score and from the word penalty.
510  composed->AddArc ( s, Arc( a1.ilabel, a1.olabel
511  , Times ( a1.weight, Times (mw_ ( w ) , mw_ (wp) ) )
512  , newstate ) );
513  //Finally, only add newstate to the queue if it hasn't been visited previously
514  if ( !visited ) {
515  qc_.push ( newstate );
516  }
517  }
518  }
519  LINFO ( "Done! Number of states=" << composed->NumStates() );
// Reset the per-run bookkeeping so that the object can be reused.
520  stateexistence_.clear();
521  statemap_.clear();
522  seenlmstates_.clear();
523  history.resize( lmmodel_.Order(), 0);
524  return composed;
525  };
526 
531  inline pair <StateId, bool> add ( fst::VectorFst<Arc> *composed, typename KenLMModelT::State& m2nextstate,
532  StateId m1nextstate, Weight m1stateweight ) {
533  static StateId lm = 0;
534  getIdx ( m2nextstate );
536  if ( seenlmstates_.find ( history ) == seenlmstates_.end() ) {
537  seenlmstates_[history] = ++lm;
538  }
539  uint64_t compound = m1nextstate * sid + seenlmstates_[history];
540  LDEBUG ( "compound id=" << compound );
541  if ( stateexistence_.find ( compound ) == stateexistence_.end() ) {
542  LDEBUG ( "New State!" );
543  statemap_[composed->NumStates()] =
544  pair<StateId, const typename KenLMModelT::State > ( m1nextstate, m2nextstate );
545  composed->AddState();
546  LDEBUG("Added..." << composed->NumStates() << "," << m1nextstate << "," << printDebug(m2nextstate));
547  if ( m1stateweight != mw_ ( ZPosInfinity() ) ) composed->SetFinal (
548  composed->NumStates() - 1, m1stateweight );
549  stateexistence_[compound] = composed->NumStates() - 1;
550  return pair<StateId, bool> ( composed->NumStates() - 1, false );
551  }
552  return pair<StateId, bool> ( stateexistence_[compound], true );
553  };
554 
559  inline void getIdx ( const typename KenLMModelT::State& state,
560  uint order = 4 ) {
561  LDEBUG("getting Idx");
562  memcpy ( buffer, state.words, buffersize );
563  // for ( uint k = state.length; k < history.size(); ++k ) history[k] = 0;
564  for ( unsigned k = sh_.getLength(state); k < history.size(); ++k ) history[k] = 0;
565 
566  };
567 
569  inline pair<StateId, typename KenLMModelT::State > get ( StateId state ) {
570  LDEBUG("get");
571  return statemap_[state];
572  };
573 
574 };
575 
576 } // end namespaces
577 
578 #endif
579 
Scorer(KenLMModelT &lmmodel, IdBridgeT const &idbridge, Scale< StateT > &nl, unsigned, std::vector< std::vector< unsigned > > const &)
Scorer(KenLMModelT &lmmodel, IdBridgeT const &idbridge, Scale< StateT > &nl, unsigned srcSize, std::vector< std::vector< unsigned > > const &srcWindows)
void run(ucam::util::RegistryPO const &rg)
void operator()(StateT const &current, float &w, float &wp, int ilabel, int olabel, StateT &next)
VectorFst< Arc > * run(const VectorFst< Arc > &fst, unsigned size, std::vector< std::vector< unsigned > > &srcWindows)
unsigned getLength(StateT const &state)
#define LINFO(msg)
Definition: fstio.hpp:27
ApplyLanguageModelOnTheFly(KenLMModelT &model, unordered_set< Label > &epsilons, bool natlog, float lmscale, float lmwp, const IdBridgeT &idbridge, MakeWeightT &mw)
#define LDEBUG(msg)
ApplyLanguageModelOnTheFly(KenLMModelT &model, bool natlog, float lmscale, float lmwp, const IdBridgeT &idbridge, MakeWeightT &mw)
void operator()(float &w, float &wp, unsigned label, lm::np::State &ns)
void setMakeWeight(const MakeWeightT &mw)
Public methods.
void operator()(StateT const &current, float &w, float &wp, int ilabel, int olabel, StateT &next)
unsigned getLength(lm::np::State const &state) const
Scale< StateT > & natlog10_
void operator()(float &w, float &wp, unsigned label, StateT &)
HashFVec< std::basic_string< unsigned > > hashfvecuint
maps between grammar targets ids and lm ids
Templated functor that creates a weight given a float.
Class that applies language model on the fly using kenlm.
#define LWARN(msg)
Interface for language model application Provides different run methods to do composition with a (bil...
HackScoreT< StateT > hs_
void setLength(unsigned length)
float ZPosInfinity()
Just a wrapper to maintain compatibility with OpenFST 1.3.1, last version using kPosInfinity constant...
Definition: fstutils.hpp:27
void operator()(float &w, float &wp, unsigned label, lm::np::State &ns)
VectorFst< Arc > * run(const VectorFst< Arc > &fst)
TropicalSparseTupleWeight< T > Times(const TropicalSparseTupleWeight< T > &w1, const TropicalSparseTupleWeight< T > &w2)
IdBridgeT const & idbridge_
#define LERROR(msg)
std::string printDebug(StateT const &state)
void operator()(float &w, float &wp, unsigned label, StateT &)
VectorFst< Arc > * run(const VectorFst< Arc > &fst, unordered_set< Label > const &epsilons)