15 #ifndef FSTUTILS_APPLYLMONTHEFLY_HPP 16 #define FSTUTILS_APPLYLMONTHEFLY_HPP 26 #include <lm/wrappers/nplm.hh> 30 template <
class StateT>
33 inline unsigned getLength(StateT
const &state) {
return state.length;}
40 inline void setLength(
unsigned length) { length_ = length;}
41 inline unsigned getLength(lm::np::State
const &state)
const {
return length_; }
44 template<
class StateT>
47 Scale(
float scale): scale_(scale) {}
58 Scale(
float scale): scale_(scale / ::log(10)) {}
69 template<
class StateT>
71 void operator()(
float &w,
float &wp,
unsigned label, StateT&) {
74 if (label == 1 ) w = 0;
83 void operator()(
float &w,
float &wp,
unsigned label, lm::np::State &ns) {
89 for (
int k = 0; k < NPLM_MAX_ORDER - 1; ++k) {
97 template<
class StateT>
99 void operator()(
float &w,
float &wp,
unsigned label, StateT&) {
100 LERROR(
"Use only for nplm in bilingual models!");
108 void operator()(
float &w,
float &wp,
unsigned label, lm::np::State &ns) {
112 for (
int k = 0; k < NPLM_MAX_ORDER - 1; ++k) {
121 template<
class StateT>
123 std::string o =ucam::util::toString<unsigned>(state.words[0]);
124 for (
unsigned k = 1; k < NPLM_MAX_ORDER - 1; ++k) {
125 o +=
"," + ucam::util::toString<unsigned>(state.words[k]);
130 template<
class StateT
133 ,
template<
class>
class HackScoreT
141 , IdBridgeT
const &idbridge
144 , std::vector<std::vector<unsigned> >
const &
146 : idbridge_(idbridge)
154 void operator()(StateT
const ¤t,
float &w,
float &wp,
int ilabel,
int olabel, StateT& next) {
155 w = lmmodel_.Score ( current, idbridge_.map(olabel), next ) * natlog10_();
156 hs_(w, wp, olabel, next);
162 template<
class IdBridgeT
163 ,
template<
class>
class HackScoreT
166 , IdBridgeT, HackScoreT> {
178 , IdBridgeT
const &idbridge
181 , std::vector<std::vector<unsigned> >
const &srcWindows
183 : idbridge_(idbridge)
190 srcWindows_.resize(srcWindows.size());
191 for (
unsigned k = 0; k < srcWindows.size(); ++k) {
192 srcWindows_[k].resize(srcWindows[k].size());
193 for (
unsigned j = 0; j < srcWindows[k].size(); ++j) {
194 srcWindows_[k][j] = idbridge_.map(srcWindows[k][j]);
199 for (
unsigned k = 0; k < srcWindows_.size(); ++k) {
200 std::stringstream ss; ss << srcWindows_[k][0];
201 for (
unsigned j = 1; j < srcWindows_[k].size(); ++j ){
202 ss <<
" " << srcWindows_[k][j];
204 LDEBUG(
"*** MAPPED src words=" << ss.str());
212 void operator()(StateT
const ¤t,
float &w,
float &wp,
int ilabel
213 ,
int olabel, StateT& next) {
217 if (srcWindows_.size()) {
219 if (ilabel >= srcWindows_.size() || ilabel < 0) {
220 LERROR(
"Wrong input label! Input labels should be links/affiliations " 221 <<
"pointing to words in the source sentence:" << ilabel);
225 std::stringstream ss;
227 for (
int k = 0; k < srcSize_; ++k) {
229 ss << ilabel <<
"," << k <<
": " << srcWindows_[ilabel][k] <<
"\t";
231 c2.words[k] = srcWindows_[ilabel][k];
233 LDEBUG(ss.str() <<
"\nilabel=" 234 << ilabel <<
", Current model state after adding source=" 238 unsigned ol = idbridge_.mapOutput(olabel);
239 w = lmmodel_.Score ( c2, ol, next ) * natlog10_();
240 LDEBUG(
"Mapped olabel=" << olabel <<
" to " 245 hs_(w, wp, olabel, next);
254 for (
int k = 0; k < srcSize_; ++k) next.words[k] = 0;
255 unsigned oli = idbridge_.map(olabel);
256 bool failure =
false;
257 int offset = (int) lmmodel_.Order() - 2;
258 if (next.words[offset] != ol) {
259 LERROR(
"Problem:" << offset <<
"=>" << next.words[offset]);
265 next.words[offset] = oli;
270 <<
"\t" << olabel <<
"\t i=" << oli <<
", o=" << ol
272 if (failure) exit(EXIT_FAILURE);
277 for (
int k = 0; k < NPLM_MAX_ORDER; ++k) {
278 x.words[k] = idbridge_.rmap(s.words[k]);
291 virtual VectorFst<ArcT> *
run(VectorFst<ArcT>
const&
fst) = 0;
292 virtual VectorFst<ArcT> *
run(VectorFst<ArcT>
const& fst
293 , unordered_set<typename ArcT::Label>
const &epsilons) = 0;
294 virtual VectorFst<ArcT> *
run(
const VectorFst<ArcT>& fst
296 , std::vector< std::vector<unsigned> > &srcWindows) =0;
309 ,
class KenLMModelT = lm::ngram::Model
311 ,
template<
class>
class HackScoreT =
HackScore >
314 typedef typename Arc::StateId StateId;
315 typedef typename Arc::Label Label;
316 typedef typename Arc::Weight Weight;
317 typedef unsigned long long ull;
319 unordered_map< ull, StateId > stateexistence_;
321 static const ull sid = 1000000000;
323 unordered_map<uint64_t
324 , pair<StateId, typename KenLMModelT::State > > statemap_;
327 unordered_map<basic_string<unsigned>
336 unordered_set<Label> epsilons_;
338 KenLMModelT& lmmodel_;
339 const typename KenLMModelT::Vocabulary& vocab_;
347 basic_string<unsigned> history;
354 const IdBridgeT& idbridge_;
359 HackScoreT<typename KenLMModelT::State> hs_;
380 , unordered_set<Label>& epsilons
384 ,
const IdBridgeT& idbridge
387 : natlog10_ ( natlog ? -lmscale* ::log ( 10.0 ) : -lmscale )
389 , vocab_ ( model.GetVocabulary() )
391 , epsilons_ ( epsilons )
392 , history ( model.Order(), 0)
393 , idbridge_ (idbridge)
403 ,
const IdBridgeT& idbridge
406 : natlog10_ ( natlog ? -lmscale* ::log ( 10.0 ) : -lmscale )
408 , vocab_ ( model.GetVocabulary() )
410 , history ( model.Order(), 0)
411 , idbridge_ (idbridge)
418 LDEBUG(
"Model order=" << (
int) lmmodel_.Order());
420 buffersize = (lmmodel_.Order() - 1 ) *
sizeof (
unsigned);
421 buffer =
const_cast<unsigned *
> ( history.c_str() );
427 VectorFst<Arc> *
run(
const VectorFst<Arc>&
fst) {
428 return this->operator()(fst);
430 VectorFst<Arc> *
run(
const VectorFst<Arc>&
fst 431 , unordered_set<Label>
const &epsilons) {
432 epsilons_ = epsilons;
433 return this->
run(fst);
436 VectorFst<Arc> *
run(
const VectorFst<Arc>&
fst 438 , std::vector<std::vector<unsigned> > &srcWindows) {
440 return this->operator()(fst,size, srcWindows);
445 VectorFst<Arc> * operator() (
const VectorFst<Arc>&
fst) {
447 std::vector<std::vector<unsigned> > empty;
449 sc(lmmodel_, idbridge_, natlog10_, ign, empty );
450 return doComposition(fst, sc);
455 VectorFst<Arc> * operator() (
const VectorFst<Arc>&
fst,
unsigned srcSize
456 , std::vector<std::vector<unsigned> > &srcw) {
461 sc(lmmodel_, idbridge_, natlog10_ , srcSize, srcw);
462 return doComposition(fst, sc);
471 VectorFst<Arc> * doComposition(
const VectorFst<Arc>&
fst 474 if (!fst.NumStates() ) {
475 LWARN (
"Empty lattice. ... Skipping LM application!");
478 VectorFst<Arc> *composed =
new VectorFst<Arc>;
480 typename KenLMModelT::State bs = lmmodel_.NullContextState();
481 pair<StateId, bool> nextp = add ( composed, bs, fst.Start(), fst.Final ( fst.Start() ) );
482 qc_.push ( nextp.first );
483 composed->SetStart ( nextp.first );
484 while ( qc_.size() ) {
485 LDEBUG(
"queue size=" << qc_.size());
486 StateId s = qc_.front();
488 pair<StateId, const typename KenLMModelT::State> p =
get ( s );
489 StateId& s1 = p.first;
490 const typename KenLMModelT::State s2 = p.second;
492 for ( ArcIterator< VectorFst<Arc> > arc1 ( fst, s1 ); !arc1.Done();
494 const Arc& a1 = arc1.Value();
497 typename KenLMModelT::State nextlmstate;
498 if ( epsilons_.find ( a1.olabel ) == epsilons_.end() ) {
499 sc(s2, w, wp, a1.ilabel, a1.olabel, nextlmstate);
505 pair<StateId, bool> nextp = add ( composed, nextlmstate
507 , fst.Final ( a1.nextstate ) );
508 StateId& newstate = nextp.first;
509 bool visited = nextp.second;
510 composed->AddArc ( s, Arc( a1.ilabel, a1.olabel
511 ,
Times ( a1.weight,
Times (mw_ ( w ) , mw_ (wp) ) )
515 qc_.push ( newstate );
519 LINFO (
"Done! Number of states=" << composed->NumStates() );
520 stateexistence_.clear();
522 seenlmstates_.clear();
523 history.resize( lmmodel_.Order(), 0);
531 inline pair <StateId, bool> add ( fst::VectorFst<Arc> *composed,
typename KenLMModelT::State& m2nextstate,
532 StateId m1nextstate, Weight m1stateweight ) {
533 static StateId
lm = 0;
534 getIdx ( m2nextstate );
536 if ( seenlmstates_.find ( history ) == seenlmstates_.end() ) {
537 seenlmstates_[history] = ++lm;
539 uint64_t compound = m1nextstate * sid + seenlmstates_[history];
540 LDEBUG (
"compound id=" << compound );
541 if ( stateexistence_.find ( compound ) == stateexistence_.end() ) {
543 statemap_[composed->NumStates()] =
544 pair<StateId, const typename KenLMModelT::State > ( m1nextstate, m2nextstate );
545 composed->AddState();
546 LDEBUG(
"Added..." << composed->NumStates() <<
"," << m1nextstate <<
"," <<
printDebug(m2nextstate));
547 if ( m1stateweight != mw_ (
ZPosInfinity() ) ) composed->SetFinal (
548 composed->NumStates() - 1, m1stateweight );
549 stateexistence_[compound] = composed->NumStates() - 1;
550 return pair<StateId, bool> ( composed->NumStates() - 1, false );
552 return pair<StateId, bool> ( stateexistence_[compound], true );
559 inline void getIdx (
const typename KenLMModelT::State& state,
562 memcpy ( buffer, state.words, buffersize );
564 for (
unsigned k = sh_.
getLength(state); k < history.size(); ++k ) history[k] = 0;
569 inline pair<StateId, typename KenLMModelT::State >
get ( StateId state ) {
571 return statemap_[state];
lm::np::Model KenLMModelT
Scorer(KenLMModelT &lmmodel, IdBridgeT const &idbridge, Scale< StateT > &nl, unsigned, std::vector< std::vector< unsigned > > const &)
Scorer(KenLMModelT &lmmodel, IdBridgeT const &idbridge, Scale< StateT > &nl, unsigned srcSize, std::vector< std::vector< unsigned > > const &srcWindows)
void run(ucam::util::RegistryPO const &rg)
~ApplyLanguageModelOnTheFly()
Destructor.
void operator()(StateT const ¤t, float &w, float &wp, int ilabel, int olabel, StateT &next)
VectorFst< Arc > * run(const VectorFst< Arc > &fst, unsigned size, std::vector< std::vector< unsigned > > &srcWindows)
unsigned getLength(StateT const &state)
ApplyLanguageModelOnTheFly(KenLMModelT &model, unordered_set< Label > &epsilons, bool natlog, float lmscale, float lmwp, const IdBridgeT &idbridge, MakeWeightT &mw)
void setLength(unsigned length)
ApplyLanguageModelOnTheFly(KenLMModelT &model, bool natlog, float lmscale, float lmwp, const IdBridgeT &idbridge, MakeWeightT &mw)
StateT revertMappings(StateT const &s)
void operator()(float &w, float &wp, unsigned label, lm::np::State &ns)
void setMakeWeight(const MakeWeightT &mw)
Public methods.
Scale< StateT > & natlog10_
void operator()(StateT const ¤t, float &w, float &wp, int ilabel, int olabel, StateT &next)
IdBridgeT const & idbridge_
unsigned getLength(lm::np::State const &state) const
Scale< StateT > & natlog10_
void operator()(float &w, float &wp, unsigned label, StateT &)
HashFVec< std::basic_string< unsigned > > hashfvecuint
maps between grammar targets ids and lm ids
Templated functor that creates a weight given a float.
Class that applies language model on the fly using kenlm.
Interface for language model application Provides different run methods to do composition with a (bil...
void setLength(unsigned length)
float ZPosInfinity()
Just a wrapper to maintain compatibility with OpenFST 1.3.1, last version using kPosInfinity constant...
virtual ~ApplyLanguageModelOnTheFlyInterface()
void operator()(float &w, float &wp, unsigned label, lm::np::State &ns)
VectorFst< Arc > * run(const VectorFst< Arc > &fst)
TropicalSparseTupleWeight< T > Times(const TropicalSparseTupleWeight< T > &w1, const TropicalSparseTupleWeight< T > &w2)
IdBridgeT const & idbridge_
std::string printDebug(StateT const &state)
void operator()(float &w, float &wp, unsigned label, StateT &)
VectorFst< Arc > * run(const VectorFst< Arc > &fst, unordered_set< Label > const &epsilons)
std::vector< std::vector< unsigned > > srcWindows_