Cambridge SMT System
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use these files except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 //
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
14 // Copyright 2012 - Gonzalo Iglesias, Adrià de Gispert, William Byrne
16 #ifndef TASK_HIFST_HPP
17 #define TASK_HIFST_HPP
29 #include "task.hifst.rtn.hpp"
30 #include "task.hifst.optimize.hpp"
33 namespace ucam {
34 namespace hifst {
// HiFSTTask: task that converts the CYK parse stored in the Data object into
// a translation lattice. run() builds an RTN cell by cell (buildRTN), replaces
// it into an FSA or a PDT, optionally applies language models and pruning, and
// stores a pointer to the resulting lattice in d.fsts[outputkey_].
// The template parameters choose the arc type and the helper classes used for
// lattice optimisation, CYK data, rule union, expanded-state counting,
// FST-by-arc replacement and RTN storage.
40 template <class Data ,
41  class Arc = fst::LexStdArc ,
42  class OptimizeT = OptimizeMachine<Arc> ,
43  class CYKdataT = CYKdata ,
44  // class MultiUnionT = fst::MultiUnionRational<Arc> ,
45  class MultiUnionT = fst::MultiUnionReplace<Arc> ,
46  class ExpandedNumStatesRTNT = ExpandedNumStatesRTN<Arc> ,
47  class ReplaceFstByArcT = ManualReplaceFstByArc<Arc> ,
48  class RTNT = RTN<Arc>
49  >
50 class HiFSTTask: public ucam::util::TaskInterface<Data> {
51  typedef typename Arc::Label Label;
52  typedef typename Arc::Weight Weight;
54  //Private variables are shown here. Private methods go after public methods
55  private:
58  // unsigned sc_;
// Number of times local pruning ran for the current sentence; reset in run()
// and incremented in localPruning().
59  unsigned piscount_;
// Cell labels that buildRTN has already added to pairlabelfsts_.
60  std::set<Label> hieroindexexistence_;
// Functor applied to cell lattices and to the top-level lattice.
62  OptimizeT optimize;
// When true the RTN is replaced into a PDT (with parentheses) instead of an FSA.
67  bool hipdtmode_;
// When true, ReplaceUtil optimisations are applied before the final Replace.
70  bool rtnopt_;
// Data object of the sentence being processed; set at the start of run().
73  Data *d_;
// Allocated and deleted inside run(); queried by buildRTN for every cell.
76  ReplaceFstByArcT *rfba_;
// Configuration forwarded to the ReplaceFstByArcT constructor in run().
77  unordered_set<std::string> replacefstbyarc_;
78  unordered_set<std::string> replacefstbyarcexceptions_;
79  unsigned replacefstbynumstates_;
// Points to d.cykdata for the current sentence.
82  CYKdataT *cykdata_;
// Alignment mode flag; also forwarded to optimize and to the Replace calls.
85  bool aligner_;
// Parenthesis pairs filled in by Replace when hipdtmode_ is set.
96  std::vector<std::pair<Label, Label> > pdtparens_;
// Per-sentence store of cell FSTs; allocated and deleted inside run().
100  RTNT *rtn_;
// Only allocated when localprune_ is true.
102  ExpandedNumStatesRTNT *rtnnumstates_;
// (label, FST) pairs handed to Replace / ReplaceUtil.
105  std::vector< std::pair< Label, const fst::Fst<Arc> * > > pairlabelfsts_;
// Output lattice; run() stores its address in d.fsts[outputkey_].
108  fst::VectorFst<Arc> cykfstresult_;
// Enables local (per-cell) pruning.
111  bool localprune_;
// Number of entries configured under the local LM key.
113  unsigned numlocallm_;
// Flat list of local pruning settings, consumed four at a time by
// initLocalConditions().
116  std::vector<std::string> lpctuples_;
// Weight threshold for the final pruning / PDT expansion in run().
125  float pruneweight_;
// NOTE(review): the member this comment refers to (rtnfiles_, used by
// writeRTN) is declared on a line that is not present in this text.
127  //where to store rtn files
130  //Input/output keys
131  const std::string lmkey_;
132  const std::string locallmkey_;
133  const std::string outputkey_;
134  const std::string fullreferencelatticekey_;
136  //To avoid multiple warnings on logs per sentence
137  bool warned_;
139  // If in translation after cell pruning the number of states of a lattice is bigger than numstatethresholdafterpruning_,
140  // then the lattice will not be determinized/minimized.
141  unsigned numstatesthreshold_;
143  //If false, no determinization/minimization will be applied anywhere to any of the components of the RTN, expanded or not.
144  bool optimize_;
// Registry the task was configured from; reused when building LM handlers.
145  const ucam::util::RegistryPO& rg_;
146  // const int localLmPos_;
// Kind of alignment information written on input labels in aligner mode.
147  enum AlignmentType {RULES, AFFILIATION};
148  AlignmentType at_;
149  public:
// Remaining constructor parameters, initializer list and constructor body.
// NOTE(review): the line that opens the constructor (its name and the first
// parameter, apparently a registry object named `rg`) is not present in this
// text, and several initializers below are cut off mid-expression.
153  const std::string& outputkey = HifstConstants::kHifstLatticeStore,
154  const std::string& locallmkey = HifstConstants::kHifstLocalpruneLmLoad,
155  const std::string& fullreferencelatticekey =
157  const std::string& lmkey = HifstConstants::kLmLoad
158  ) :
// The initializers are written in a different order from the member
// declarations; C++ executes them in declaration order regardless.
159  optimize_ (rg.getBool (HifstConstants::kHifstOptimizecells) ),
160  numlocallm_ (rg.getVectorString (locallmkey).size() ),
161  warned_ (false),
162  rtnfiles_ (rg.get<std::string> (HifstConstants::kHifstWritertn) ),
163  fullreferencelatticekey_ ( fullreferencelatticekey ),
164  lmkey_ ( lmkey ),
165  locallmkey_ ( locallmkey ),
166  outputkey_ ( outputkey ),
167  piscount_ ( 0 ),
168  aligner_ ( rg.getBool ( HifstConstants::kHifstAlilatsmode ) ),
169  // cellredm_ ( rg.getBool ( "hifst.cellredm" ) ),
170  // finalredm_ ( rg.getBool ( "hifst.finalredm" ) ),
171  hipdtmode_ (rg.getBool (HifstConstants::kHifstUsepdt) ),
172  rtnopt_ (rg.getBool (HifstConstants::kHifstRtnopt) ),
173  replacefstbyarc_ ( rg.getSetString (
175  replacefstbyarcexceptions_ ( rg.getSetString (
177  replacefstbynumstates_ ( rg.get<unsigned>
179  localprune_ ( rg.getBool ( HifstConstants::kHifstLocalpruneEnable ) ),
180  pruneweight_ ( rg.get<float> ( HifstConstants::kHifstPrune ) ),
181  numstatesthreshold_ ( rg.get<unsigned>
183  lpctuples_ ( rg.getVectorString (
185  mw_(rg),
186  at_(RULES),
187  rg_(rg)
188  // localLmPos_(rg.getVectorString(HifstConstants::kLmFeatureweights).size() + 1 + 1)
189  {
191  LINFO ("Number of local language models=" << numlocallm_);
192  LINFO ("aligner mode=" << aligner_);
193  LINFO ("localprune mode=" << localprune_);
194  LINFO("reference filtering with: " << rg_.get<std::string> (HifstConstants::kReferencefilterLoad));
// NOTE(review): this validates lpc_.size(), whereas the 4-element tuples are
// held in lpctuples_ (see initLocalConditions) -- confirm which container was
// meant to be checked here.
195  USER_CHECK ( ! ( lpc_.size() % 4 ),
196  "local pruning conditions are defined by tuples of 4 elements: category,x,y,Number-of-states. Category is a string and x,y are int. Number of states is unsigned" );
// Local pruning is accepted with at least one local LM, or without one when a
// reference filter is configured; disabling local pruning always passes.
197  USER_CHECK ( (localprune_ && numlocallm_ )
198  || ( localprune_ && !numlocallm_ && rg_.get<std::string> (HifstConstants::kReferencefilterLoad) != "" )
199  || (!localprune_) ,
200  "If you want to do cell pruning in translation, you should normally use a language model for local pruning. Check --hifst.localprune.lm.load and --hifst.localprune.enable.\n");
201  optimize.setAlignMode (aligner_);
205  if (hipdtmode_) {
206  LINFO ("Hipdt mode enabled!");
207  }
208  if (!rtnopt_) {
209  LINFO ("RTN openfst optimizations will not be applied");
210  }
// at_ was initialised to RULES above; switch to AFFILIATION on request.
212  if (rg.get<std::string>(HifstConstants::kHifstAlilatsmodeLinks) == "affiliation") {
213  at_ = AFFILIATION;
214  }
216  LDEBUG ( "Hifst constructor done!" );
217  };
// Builds the translation lattice for the sentence described by `d`.
//
// Steps: reset per-sentence state, build the RTN from the CYK grid
// (buildRTN), optimise the top-level FST, replace the RTN into an FSA (or a
// PDT when hipdtmode_ is set), optionally compose with the full reference
// lattice, apply the language model(s), prune/expand, remove epsilons, and
// publish &cykfstresult_ in d.fsts[outputkey_].
//
// Returns true only when no CYK data is available; returns false both when
// the CYK parse failed and after a complete run.
// NOTE(review): the meaning of the return value is defined by TaskInterface,
// which is not visible here.
224  bool run ( Data& d ) {
225  cykfstresult_.DeleteStates();
226  this->d_ = &d;
227  hieroindexexistence_.clear();
228  LINFO ( "Running HiFST" );
229  //Reset one-time warnings for inexistent language models.
230  warned_ = false;
231  pdtparens_.clear();
232  cykdata_ = d.cykdata;
// NOTE(review): if USER_CHECK returns (rather than aborting) with cykdata_
// null, resetExternalData below dereferences that null pointer.
233  if ( !USER_CHECK ( cykdata_, "cyk parse has not been executed previously?" ) ) {
234  resetExternalData (d);
235  return true;
236  }
// Parse failure: publish the lattice that was emptied at the top of run().
237  if ( d.cykdata->success == CYK_RETURN_FAILURE ) {
// NOTE(review): `aux` is never used.
239  fst::VectorFst<Arc> aux;
240  d.fsts[outputkey_] = &cykfstresult_;
241  d.vcat = cykdata_->vcat;
242  resetExternalData (d);
243  return false;
244  }
246  initLocalConditions();
// Per-sentence helpers; released by hand near the end of this function.
247  rtn_ = new RTNT;
248  if ( localprune_ )
249  rtnnumstates_ = new ExpandedNumStatesRTNT;
250  rfba_ = new ReplaceFstByArcT ( cykdata_->vcat, replacefstbyarc_,
251  replacefstbyarcexceptions_, aligner_, replacefstbynumstates_ );
252  piscount_ = 0; //reset pruning-in-search count to 0
253  LINFO ( "Second Pass: FST-building!" );
254  d.stats->setTimeStart ( "lattice-construction" );
255  //Owned by rtn_;
256  fst::Fst<Arc> *sfst = buildRTN ( cykdata_->categories["S"], 0,
257  cykdata_->sentence.size() - 1 ).ptr_;
258  d.stats->setTimeEnd ( "lattice-construction" );
// Work on a copy of the top cell so that it can be optimised independently.
259  cykfstresult_ = (*sfst);
260  LINFO ( "Final - RTN head optimizations !" );
261  optimize ( &cykfstresult_ ,
262  std::numeric_limits<unsigned>::max() ,
263  !hipdtmode_ && optimize_
264  );
265  FORCELINFO ("Stats for Sentence " << d.sidx <<
266  ": local pruning, number of times=" << piscount_);
267  d.stats->lpcount = piscount_; //store local pruning counts in stats
268  LINFO ("RTN expansion starts now!");
269  //Expand...
270  {
// Label of the top cell, using the same formula as buildRTN with category 1
// and span [0, sentence length - 1].
// NOTE(review): buildRTN above was called with categories["S"]; the literal 1
// here presumably relies on "S" always having index 1 -- confirm.
272  Label hieroindex = APBASETAG + 1 * APCCTAG + 0 * APXTAG +
273  ( cykdata_->sentence.size() - 1 ) * APYTAG;
// Add the optimised top lattice as root unless buildRTN already registered
// this label.
274  if ( hieroindexexistence_.find ( hieroindex ) == hieroindexexistence_.end() )
275  pairlabelfsts_.push_back ( pair< Label, const fst::Fst<Arc> * > ( hieroindex,
276  &cykfstresult_ ) );
// The ReplaceUtil construction differs between OpenFst versions.
279 #if OPENFSTVERSION>=1005000
280  fst::ReplaceUtilOptions ruopt(hieroindex, !aligner_);
281  fst::ReplaceUtil<Arc> replace_util (pairlabelfsts_, ruopt);
282 #elif OPENFSTVERSION>=1004000
283  fst::ReplaceUtilOptions<Arc> ruopt(hieroindex, !aligner_);
284  fst::ReplaceUtil<Arc> replace_util (pairlabelfsts_, ruopt);
285 #else
286  fst::ReplaceUtil<Arc> replace_util (pairlabelfsts_, hieroindex
287  , !aligner_); //has ownership of modified rtn fsts
288 #endif
289  if (rtnopt_) {
290  LINFO ("rtn optimizations...");
291  d_->stats->setTimeStart ("replace-opts");
292  replace_util.ReplaceTrivial();
293  replace_util.ReplaceUnique();
294  replace_util.Connect();
// From here on pairlabelfsts_ holds the pairs returned by replace_util.
295  pairlabelfsts_.clear();
296  replace_util.GetFstPairs (&pairlabelfsts_);
297  d_->stats->setTimeEnd ("replace-opts");
298  }
299  //After optimizations, we can write RTN if required by user
300  writeRTN();
301  boost::scoped_ptr< fst::VectorFst<Arc> > efst (new fst::VectorFst<Arc>);
302  if (!hipdtmode_ ) {
303  LINFO ("Final Replace (RTN->FSA), main index=" << hieroindex);
304  d_->stats->setTimeStart ("replace-rtn-final");
305  Replace (pairlabelfsts_, &*efst, hieroindex, !aligner_);
306  d_->stats->setTimeEnd ("replace-rtn-final");
307  } else {
308  LINFO ("Final Replace (RTN->PDA)");
309  d_->stats->setTimeStart ("replace-pdt-final");
310  Replace (pairlabelfsts_, &*efst, &pdtparens_, hieroindex);
311  d_->stats->setTimeEnd ("replace-pdt-final");
312  LINFO ("Number of pdtparens=" << pdtparens_.size() );
313  }
314  LDBG_EXECUTE ( efst->Write ( "fsts/FINAL-e.fst" ) );
315  // Currently no need to call this applyFilters: it will do the same
316  // and it is more efficient to compose with the normal lattice
317  // rather than the substringed lattice.
318  // LINFO ("Removing Epsilons...");
319  // fst::RmEpsilon<Arc> ( &*efst );
320  // LINFO ("Done! NS=" << efst->NumStates() );
321  // applyFilters ( &*efst );
322  //Compose with full reference lattice to ensure that final lattice is correct.
323  if ( d.fsts.find ( fullreferencelatticekey_ ) != d.fsts.end() ) {
324  if ( static_cast< fst::VectorFst<Arc> * >
325  (d.fsts[fullreferencelatticekey_])->NumStates() > 0 ) {
326  LINFO ( "Composing with full reference lattice, NS=" <<
327  static_cast< fst::VectorFst<Arc> * >
328  (d.fsts[fullreferencelatticekey_])->NumStates() );
329  fst::Compose<Arc> ( *efst,
330  * ( static_cast<fst::VectorFst<Arc> * > (d.fsts[fullreferencelatticekey_]) ),
331  &*efst );
332  LINFO ( "After composition: NS=" << efst->NumStates() );
333  } else {
334  LINFO ( "No composition with full ref lattice" );
335  };
336  } else {
337  LINFO ( "No composition with full ref lattice" );
338  };
339  LDBG_EXECUTE ( efst->Write ( "fsts/FINAL-ef.fst" ) );
340  //Apply language model
341  fst::VectorFst<Arc> *res = NULL;
342  if (efst->NumStates() )
343  res = applyLanguageModel ( *efst );
344  else {
345  LWARN ("Empty lattice -- skipping LM application");
346  }
// res is NULL when the lattice was empty or no language model was applied.
347  if ( res != NULL ) {
348  boost::shared_ptr<fst::VectorFst<Arc> >latlm ( res );
349  if ( latlm.get() == efst.get() ) {
350  LWARN ( "Yikes! Unexpected situation! Will it crash? (muhahaha) " );
351  }
352  LDBG_EXECUTE ( latlm->Write ( "fsts/FINAL-efc.fst" ) );
353  //Todo: union with shortest path...
// A pruneweight_ equal to the float maximum means "do not prune".
354  if ( pruneweight_ < std::numeric_limits<float>::max() ) {
355  if (!hipdtmode_ || pdtparens_.empty() ) {
356  LINFO ("Pruning, weight=" << pruneweight_);
357  fst::Prune<Arc> (*latlm, &cykfstresult_, mw_ ( pruneweight_ ) );
358  } else {
359  LINFO ("Expanding, weight=" << pruneweight_);
360  fst::ExpandOptions<Arc> eopts (true, false, mw_ ( pruneweight_ ) );
361  Expand ( *latlm, pdtparens_, &cykfstresult_, eopts);
362  pdtparens_.clear();
363  }
364  } else {
365  LINFO ("Copying through full lattice with lm scores");
366  cykfstresult_ = *latlm;
367  }
368  } else {
369  LINFO ("Copying through full lattice (no lm)");
370  cykfstresult_ = *efst;
371  }
// NOTE(review): when rtnopt_ is set, pairlabelfsts_ was rebuilt through
// GetFstPairs above, so this pop_back does not necessarily remove the entry
// pushed before the Replace; the vector is cleared right below in any case.
372  if ( hieroindexexistence_.find ( hieroindex ) == hieroindexexistence_.end() )
373  pairlabelfsts_.pop_back();
374  }
375  pairlabelfsts_.clear();
376  LDBG_EXECUTE ( cykfstresult_.Write ( "fsts/FINAL-efcp.fst" ) );
377  LINFO ( "Reps" );
378  fst::RmEpsilon ( &cykfstresult_ );
379  LDBG_EXECUTE ( cykfstresult_.Write ( "fsts/FINAL-efcpr.fst" ) );
380  LINFO ( "NS=" << cykfstresult_.NumStates() );
381  //This should delete all pertinent fsas...
382  LINFO ( "deleting data stuff..." );
// Release the helpers allocated earlier in this call.
383  delete rtn_;
384  if ( localprune_ )
385  delete rtnnumstates_;
386  delete rfba_;
387  d.vcat = cykdata_->vcat;
388  resetExternalData (d);
// cykfstresult_ is a member: the published pointer refers to storage that the
// next call to run() clears (DeleteStates at the top).
389  d.fsts[outputkey_] = &cykfstresult_;
390  if (hipdtmode_ && pdtparens_.size() )
391  d.fsts[outputkey_ + ".parens" ] = &pdtparens_;
392  LINFO ( "done..." );
393  FORCELINFO ( "End Sentence ******************************************************" );
394  d.stats->setTimeEnd ( "sent-dec" );
395  d.stats->message += "[" + ucam::util::getTimestamp() + "] End Sentence\n";
396  return false;
397  };
399  private:
402  inline void resetExternalData (Data& d) {
403  cykdata_->freeMemory();
404  d.tvcb.clear();
405  d.filters.clear();
406  }
409  void writeRTN() {
410  //Dump to disk all the FSAs for this RTN.
411  if (rtnfiles_() != "") {
412  std::string filenamepattern = rtnfiles_ (d_->sidx);
413  FORCELINFO ("Writing rtn files..." << filenamepattern);
414  for (unsigned k = 0; k < pairlabelfsts_.size(); ++k) {
415  std::string filename = filenamepattern;
416  ucam::util::find_and_replace (filename, "%%rtn_label%%"
417  , ucam::util::toString<Label> (pairlabelfsts_[k].first) );
418  fst::FstWrite (static_cast< fst::VectorFst<Arc> const& > (*
419  (pairlabelfsts_[k].second) ), filename);
420  }
421  }
422  };
424  struct FSAPlusInfo {
425  fst::Fst<Arc>* ptr_;
426  unsigned cc_;
427  unsigned x_;
428  unsigned y_;
429  explicit FSAPlusInfo(fst::Fst<Arc>* p
430  , unsigned cc
431  , unsigned x
432  , unsigned y
433  )
434  : ptr_(p)
435  , cc_(cc)
436  , x_(x)
437  , y_(y)
438  {}
440  explicit FSAPlusInfo()
441  : ptr_(NULL)
442  , cc_(0)
443  , x_(0)
444  , y_(0)
445  {}
446  };
455  // inline void mapfsts ( unsigned int rule_idx,
456  // std::vector < fst::Fst < Arc > * >& fsts ) {
457  inline void mapfsts ( unsigned int rule_idx,
458  std::vector < FSAPlusInfo >& fsts ) {
459  unordered_map<unsigned int, unsigned int > mappings;
460  d_->ssgd->getMappings ( rule_idx, &mappings );
461  USER_CHECK ( fsts.size() == mappings.size(),
462  "Mismatch between mappings and lower-level fsts" );
463  LDEBUG ( "mappings size=" << mappings.size() );
464  // std::vector<fst::Fst<Arc>* > newfsts ( fsts.size(), NULL );
465  std::vector<FSAPlusInfo> newfsts(fsts.size());
466  for ( unsigned int k = 0; k < fsts.size(); k++ ) {
467  newfsts[mappings[k]] = fsts[k];
468  // mmmm lack of foresight here, copy original y_:
469  newfsts[mappings[k]].x_ = fsts[mappings[k]].x_;
470  newfsts[mappings[k]].y_ = fsts[mappings[k]].y_;
471  newfsts[mappings[k]].cc_ = fsts[mappings[k]].cc_;
472  }
473  fsts = newfsts;
474  };
487  // fst::Fst<Arc>*
// Recursively builds the FST for grid cell (cc, x, y) and returns it together
// with the cell coordinates.
//
// A cell that rtn_ already holds is returned directly. Otherwise every rule
// instance recorded for the cell is turned into an FST (addRule), after
// building the lower cells its non-terminals depend on; the rule FSTs are
// unioned, optimised, possibly pruned (localPruning) and stored in rtn_.
// The returned pointer is whatever rtn_ yields for (cc, x, y).
488  FSAPlusInfo buildRTN ( unsigned int cc, unsigned int x, unsigned int y ) {
489  FSAPlusInfo fpi( ( *rtn_ ) ( cc, x, y ), cc, x, y);
490  // fst::Fst<Arc> *ptr = ( *rtn_ ) ( cc, x, y );
// Already built: reuse it.
491  if ( fpi.ptr_ != NULL ) return fpi;
// `o` exists only in PRINTDEBUG builds; it is used below inside LDBG_EXECUTE.
// NOTE(review): presumably LDBG_EXECUTE expands to nothing otherwise -- confirm.
492 #ifdef PRINTDEBUG
493  std::ostringstream o;
494  o << cc << "." << x << "." << y;
495 #endif
496  unsigned& nnt = cykdata_->nnt;
497  grammar_inversecategories_t& vcat = cykdata_->vcat;
// NOTE(review): `filename` is computed here but not used again in this function.
498  std::stringstream ostr;
499  ostr << vcat[cc] << "." << x << "." << y;
500  std::string filename;
501  ostr >> filename;
502  SentenceSpecificGrammarData& g = *d_->ssgd;
503  MultiUnionT mur;
// Label that identifies this cell inside the RTN.
504  Label hieroindex = APBASETAG + cc * APCCTAG + x * APXTAG + y * APYTAG;
505  LDEBUG ( "bp> " << cc << "," << x << "," << y << ":" <<
506  ( unsigned ) cykdata_->bp ( cc, x, y ).size() );
// One iteration per rule instance stored for this cell.
507  for ( unsigned i = 0; i < cykdata_->bp ( cc, x, y ).size(); i++ ) {
508  unsigned idx = cykdata_->cykgrid ( cc, x, y, i );
510  std::vector<FSAPlusInfo> requiredfsts;
// Phrase rules are built without lower-level FSTs.
511  if ( g.isPhrase ( idx ) ) {
512  mur.Add ( addRule ( idx, requiredfsts, x + 1) ) ;
513  LDEBUG ( "AT " << cc << "," << x << "," << y <<
514  ":adding phrase-based rule index " << idx );
515  continue;
516  }
// Back-pointers are read in groups of three values, passed to the recursive
// call as (cc, x, y); groups whose first value exceeds nnt are skipped.
517  const cykparser_ruledependencies_t& mybp = cykdata_->bp ( cc, x, y );
518  for ( unsigned j = 0; j < mybp[i].size(); j += 3 ) {
519  if ( mybp[i][j] > nnt ) {
520  continue;
521  }
522  requiredfsts.push_back ( buildRTN ( mybp[i][j], mybp[i][j + 1],
523  mybp[i][j + 2] ) );
524  LDEBUG ( "back to bp> " << cc << "," << x << "," << y << ":" <<
525  ( unsigned ) cykdata_->bp ( cc, x, y ).size() );
526  }
// Reorder the lower FSTs according to the rule's non-terminal mappings.
527  mapfsts ( idx, requiredfsts );
528  LDEBUG ( "AT " << cc << "," << x << "," << y << ": adding hiero rule index " <<
529  idx );
530  mur.Add ( addRule ( idx, requiredfsts , x + 1) );
531  }
// Union of all rule FSTs of this cell.
532  boost::shared_ptr< fst::VectorFst<Arc> > mdfst ( mur() );
533  LDBG_EXECUTE ( mdfst->Write ( "fsts/" + o.str() + ".fst" ) );
534  //Optimize
535  optimize ( &*mdfst ,
536  std::numeric_limits<unsigned>::max(),
537  optimize_ );
538  LDEBUG ( "AT " << cc << "," << x << "," << y << ": FST built!" );
539  LDBG_EXECUTE ( mdfst->Write ( "fsts/" + o.str() + "redm.fst" ) );
// Statistics are keyed by cc * 1000000 + y * 1000 + x.
540  d_->stats->numstates[ cc * 1000000 + y * 1000 + x ] =
541  mdfst->NumStates(); //Just store the number of states of the not-expanded FSA.
542  //Calculate expanded number of states of the partial rtn.
543  if ( localprune_ )
544  rtnnumstates_->update ( cc, x, y, &*mdfst );
// localPruning returns NULL when the cell is not pruned (including whenever
// localprune_ is false).
545  boost::scoped_ptr< fst::VectorFst<Arc> > pruned ( localPruning ( *mdfst, cc, x, y ) );
546  //We now might have a pruned lattice!
547  if ( pruned.get() != NULL ) {
548  LDBG_EXECUTE ( pruned->Write ( "fsts/" + o.str() + "redmp.fst" ) );
549  optimize (&*pruned , numstatesthreshold_ , !hipdtmode_ && optimize_ );
550  LDBG_EXECUTE ( pruned->Write ( "fsts/" + o.str() + "redmpo.fst" ) );
551  *mdfst = *pruned;
552  //Only if we prune, we add to stats total number of states of full and pruned lattice
553  d_->stats->numstates[ cc * 1000000 + y * 1000 + x ] = ( *rtnnumstates_ ) ( cc,
554  x,
555  y );
556  rtnnumstates_->update ( cc, x, y, &*mdfst ); //Update rtnnumstates again...
557  d_->stats->numprunedstates[ cc * 1000000 + y * 1000 + x ]
558  = ( *rtnnumstates_ ) (cc, x, y );
559  } else {
560  LDEBUG ( "AT " << cc << "," << x << "," << y << ":No pruning" );
561  }
// rfba_ yields a non-null FST when this cell is to be replaced by an arc; the
// cell FST is then registered in pairlabelfsts_ under hieroindex.
562  boost::shared_ptr< fst::VectorFst<Arc> > outfst ( ( *rfba_ ) ( *mdfst,
563  hieroindex ) );
564  if ( outfst.get() != NULL ) {
565  LDEBUG ( "AT " << cc << "," << x << "," << y << ": replacefstbyarcfor cat= " <<
566  vcat[cc] << ",NS=" << mdfst->NumStates() );
567  rtn_->Add ( cc, x, y, outfst , mdfst );
568  hieroindexexistence_.insert ( hieroindex );
569  pairlabelfsts_.push_back ( pair< Label, const fst::Fst<Arc> * > ( hieroindex,
570  &*mdfst ) );
571  } else {
572  rtn_->Add ( cc, x, y, mdfst , outfst );
573  LDEBUG ( "AT: " << cc << "," << x << "," << y << ":" <<
574  "Delaying not applied. Stored, NS=" << ( unsigned ) mdfst->NumStates() );
575  }
576  // return ( *rtn_ ) ( cc, x, y );
577  FSAPlusInfo fpi2( ( *rtn_ ) ( cc, x, y ), cc, x, y);
578  return fpi2;
579  };
588  // fst::VectorFst<Arc> *addRule ( unsigned rule_idx,
589  // std::vector<fst::Fst<Arc>*>& lowerfsts ) {
590  fst::VectorFst<Arc> *addRule ( unsigned rule_idx,
591  std::vector<FSAPlusInfo>& lowerfsts
592  , unsigned offset ) {
593  SentenceSpecificGrammarData& gd = *d_->ssgd;
594  std::vector<std::string> translation = gd.getRHSSplitTranslation ( rule_idx );
595  if ( !translation.size() ) {
596  LERROR ( gd.getRule ( rule_idx ) );
597  translation.push_back ( 0 );
598  }
599  for (unsigned k = 0; k < translation.size(); ++k) {
600  if ( translation[k] == "<s>" ) {
601  translation[k] = "1";
602  } else if ( translation[k] == "</s>" ) translation[k] = "2";
603  else if ( translation[k] == "<dr>" ) {
604  std::stringstream dr;
605  dr << DR;
606  translation[k] = dr.str();
607  LDEBUG ( "Deletion rule: " << gd.getRule ( rule_idx ) << "," <<
608  translation[k] );
609  } else if ( translation[k] == "<oov>" ) {
610  std::stringstream oov;
611  oov << OOV;
612  translation[k] = oov.str();
613  LDEBUG ( "oov rule: " << gd.getRule ( rule_idx ) << "," << translation[k] );
614  } else if ( translation[k] == "<sep>" ) {
615  std::stringstream sep;
616  sep << SEP;
617  translation[k] = sep.str();
618  LDEBUG ( "separator rule: " << gd.getRule ( rule_idx ) << "," <<
619  translation[k] );
620  }
621  }
622  LDEBUG ( "Starting to build!" );
623  fst::VectorFst<Arc> *rulefst = new fst::VectorFst<Arc>;
624  rulefst->AddState();
625  rulefst->SetStart ( 0 );
626  rulefst->AddState();
627  Label iw2 = (at_ == RULES)?gd.getIdx ( rule_idx ) + 1: 0;
628  Label iw;
629  if ( !aligner_ ) iw = 0;
630  else iw = iw2;
631  LDEBUG ("Building FST for rule " << rule_idx << ":" << gd.getRule ( rule_idx ) << ", original id=" << gd.getIdx(rule_idx)
632  << ", translation size=" << translation.size() );
633  unsigned kmax = translation.size();
634  unsigned nonterminal = 0;
635  std::vector< pair< Label, const fst::Fst<Arc> * > > pairlabelfsts;
637  std::vector<unsigned> links(translation.size(), NORULE);
638  if (at_ == AFFILIATION) {
639  LDEBUG("Getting affiliation...");
640  gd.getLinks(rule_idx, links);
641  }
642  for ( unsigned k = 0; k < kmax; ++k ) {
643  //if non-terminal... just place special arc and expand later...
644  Label ow;
645  bool isnonterminal = !isTerminal ( translation[k] );
646  if ( isnonterminal) {
647  ow = APRULETAG + nonterminal;
648  USER_CHECK ( lowerfsts.size() > nonterminal,
649  "Missing fsts to build the rule..." );
650  offset +=lowerfsts[nonterminal].y_ + 1; // add span of the non-terminal
651  pairlabelfsts.push_back ( pair< Label, const fst::Fst<Arc> * >
652  ( ow, lowerfsts[nonterminal++].ptr_ ) );
654  } else {
655  std::istringstream buffer ( translation[k] );
656  buffer >> ow;
657  }
658  rulefst->AddState();
659  Label iw;
660  if ( !aligner_ ) iw = ow;
661  else {
662  if (isnonterminal) iw = NORULE; // ignore nts for affiliation
663  else {
664  iw = links[k];
665  if (at_ == AFFILIATION) {
666  iw += offset - nonterminal; // links are counting nts as 1, so we need to discount them (silly)
667  }
668  }
669  }
670  LDEBUG("Adding arc iw=" << iw << ",ow=" << ow);
671  rulefst->AddArc ( k, Arc ( iw, ow, Weight::One(), k + 1 ) );
672  }
673  float w = gd.getWeight ( rule_idx );
674  Weight weight = mw_ ( w , iw2 );
675  rulefst->AddArc ( kmax, Arc ( iw, 0, weight, kmax + 1 ) );
676  rulefst->SetFinal ( kmax + 1, Weight::One() );
677  fst::VectorFst<Arc>* auxi;
678  if ( nonterminal > 0 ) {
679  pairlabelfsts.push_back ( pair< Label, const fst::Fst<Arc> * >
680  ( APRULETAG + nonterminal, rulefst ) );
681  fst::VectorFst<Arc> *aux = new fst::VectorFst<Arc>;
682  Replace (pairlabelfsts, aux, APRULETAG + nonterminal, !aligner_);
683  delete rulefst;
684  rulefst = aux;
685  }
686  fst::RmEpsilon<Arc> ( rulefst );
687  return rulefst;
688  }
690  //Note that local conditions can be sentence-specific
691  void initLocalConditions() {
692  if ( !localprune_ ) return;
693  if ( !lpctuples_.size() ) return;
694  lpc_.clear();
695  LINFO ( "Set up conditions for local cell pruning" );
696  for ( unsigned k = 0; k < lpctuples_.size(); k += 4 ) {
697  int y = ucam::util::toNumber<int> ( lpctuples_[k + 1] );
698  if ( y < 0 ) y = cykdata_->getNumberWordsSentence() + y + 1;
699  LINFO ( "cell pruning conditions (cat,span,numstates,weight): "
700  << cykdata_->categories[lpctuples_[k]]
701  << "," << y << ","
702  << ucam::util::toNumber<unsigned> ( lpctuples_[k + 2] ) << ","
703  << ucam::util::toNumber<unsigned> ( lpctuples_[k + 3] ) );
704  conditions c ( cykdata_->categories[lpctuples_[k]]
705  , y
706  , ucam::util::toNumber<unsigned> ( lpctuples_[k + 2] )
707  , ucam::util::toNumber<unsigned> ( lpctuples_[k + 3] ) );
708  lpc_.add ( c );
709  }
710  LINFO ( "We have: " << lpc_.size() << " conditions" );
711  };
// Composes `fst` in place with every filter in d_->filters, in order.
// The lattice is first arc-sorted on output labels. After each composition
// it is trimmed with Connect; the loop stops as soon as the lattice is empty.
// When hipdtmode_ is set and parentheses exist, the composition is done with
// PDT compose options instead of plain FST composition.
719  inline void applyFilters ( fst::VectorFst<Arc> *fst ) {
720  fst::ArcSort<Arc> ( fst, fst::OLabelCompare<Arc>() );
723  LINFO ( "Apply " << d_->filters.size() << " filters to the search space!" );
724  for ( unsigned k = 0; k < d_->filters.size(); ++k ) {
725  LDBG_EXECUTE ( fst::FstWrite ( * (d_->filters[k]), "fsts/filter.fst.gz" ) );
726  LDBG_EXECUTE ( fst::FstWrite ( *fst, "fsts/before-composition.fst.gz" ) );
728  if (!hipdtmode_ || pdtparens_.empty() ) {
729  LINFO ("FST composition with filter");
// The ComposeFst result is assigned back to the lattice it reads from.
730  *fst = (fst::ComposeFst<Arc> (*fst, *d_->filters[k]) );
731  } else {
732  LINFO ("PDT composition");
// The options class has a different name depending on the OpenFst version.
733 #if OPENFSTVERSION>=1003003
734  fst::PdtComposeFstOptions<Arc> opts (*fst, pdtparens_, *d_->filters[k]);
735 #else
736  fst::PdtComposeOptions<Arc> opts (*fst, pdtparens_, *d_->filters[k]);
737 #endif
738  opts.gc_limit = 0;
739  *fst = (fst::ComposeFst<Arc> (*fst, *d_->filters[k], opts) );
740  }
741  LINFO ( "After filter " << k << ", NS=" << fst->NumStates() );
742  Connect ( fst );
743  LDBG_EXECUTE ( fst::FstWrite ( *fst, "fsts/after-composition.fst.gz" ) );
// Nothing left to filter.
744  if ( !fst->NumStates() ) break;
745  }
746  };
// Shared-pointer handle to a language model application object.
// NOTE(review): ApplyLanguageModelOnTheFlyInterfaceType is defined on a line
// that is not present in this text.
749  typedef boost::shared_ptr<ApplyLanguageModelOnTheFlyInterfaceType> ApplyLanguageModelOnTheFlyInterfacePtrType;
// Handlers for the local-pruning language models (key locallmkey_); filled on
// first use by initializeLanguageModelHandlers.
750  std::vector<ApplyLanguageModelOnTheFlyInterfacePtrType> almotfLocal_;
// Handlers for the full language models (key lmkey_); filled on first use.
751  std::vector<ApplyLanguageModelOnTheFlyInterfacePtrType> almotf_;
753  // Prepares language model application handlers for each kenlm type.
754  // i.e. an array of templated instances of ApplyLanguageModelOnTheFly
755  // Note: possibly can be refactored/merged with method initializeLanguageModelHandlers
756  // in task.applylm.hpp
757  template< template<class> class MakeWeightT>
758  void initializeLanguageModelHandlers(const std::string& lmkey
759  , MakeWeightT<Arc> &mw
760  , std::vector<ApplyLanguageModelOnTheFlyInterfacePtrType> &almotf) {
761  if (almotf.size()) {
762  LINFO("Skipping!");
763  return; // already done
764  }
765  almotf.resize(d_->klm[lmkey].size());
766  unordered_set<Label> epsilons;
767  for ( unsigned k = 0; k < d_->klm[lmkey].size(); ++k ) {
768  USER_CHECK ( d_->klm[lmkey][k]->model != NULL,
769  "Language model " << k << " not available!" );
770  almotf[k].reset(fsttools::assignKenLmHandler<Arc, MakeWeightT >(rg_, lmkey, epsilons
771  , *(d_->klm[lmkey][k])
772  , mw, true,k));
773  mw.update();
774  }
775  LINFO("Initialized " << d_->klm[lmkey].size() << " language model handlers");
776  }
778  // \todo Merge/refactor this code with task.applylm.hpp.
779  template< template<class> class MakeWeightT>
780  inline fst::VectorFst<Arc> *applyLanguageModel ( const fst::Fst<Arc>& localfst
781  , const std::string& lmkey
782  , MakeWeightT<Arc> &mw
783  , std::vector<ApplyLanguageModelOnTheFlyInterfacePtrType> &almo
784  ) {
785  if ( d_->klm.find ( lmkey ) == d_->klm.end() ) {
786  if (!warned_) {
787  FORCELINFO ( "No Language models for key=" << lmkey
788  << " available! Skipping language model application. " );
789  }
790  warned_ = true;
791  return NULL;
792  }
794  fst::VectorFst<Arc> *output
795  = new fst::VectorFst<Arc> (* (const_cast<fst::Fst<Arc> *> ( &localfst ) ) );
797  // unfortunately they can be lattice-specific (pdt parentheses)
798  unordered_set<Label> epsilons;
799  epsilons.insert ( DR );
800  epsilons.insert ( OOV );
801  epsilons.insert ( EPSILON );
802  epsilons.insert ( SEP );
803  // If it is a pdt, add all parentheses so they get treated as epsilons too
804  // for this particular lattice
805  for (unsigned j = 0; j < pdtparens_.size(); ++j) {
806  epsilons.insert (pdtparens_[j].first);
807  epsilons.insert (pdtparens_[j].second);
808  }
810  for ( unsigned k = 0; k < d_->klm[lmkey].size(); ++k ) {
811  LINFO ( "Composing with " << k << "-th language model" );
812  d_->stats->setTimeStart ( "on-the-fly-composition "
813  + ucam::util::toString ( k ) );
814  fst::VectorFst<Arc> *aux = almo[k]->run(*output, epsilons);
815  if ( !aux ) {
816  LERROR ("Something very wrong happened in composition with the lm...");
817  exit (EXIT_FAILURE);
818  }
819  delete output; output = aux;
820  d_->stats->setTimeEnd ( "on-the-fly-composition "
821  + ucam::util::toString ( k ) );
822  LDEBUG ( "After applying language model, NS=" << output->NumStates() );
823  }
824  LINFO ( "Connect!" );
825  Connect (output);
826  LINFO ( "Done! NS=" << output->NumStates() );
827  return output;
828  }
// Applies either the local-pruning language models (local == true, key
// locallmkey_) or the full language models (default, key lmkey_) to
// `localfst`. Handlers are created on first use. Returns NULL when no handler
// exists for the chosen key, otherwise the result of the four-argument
// overload.
// NOTE(review): `mw` is used in both branches but its declarations are on
// lines that are not present in this text.
834  inline fst::VectorFst<Arc> *applyLanguageModel ( const fst::Fst<Arc>& localfst
835  , bool local = false ) {
836  if ( local ) {
838  initializeLanguageModelHandlers(locallmkey_, mw, almotfLocal_);
839  if (!almotfLocal_.size()) return NULL;
840  LINFO ( "Composing with local lm for inadmissible pruning (unless on top cell)" );
841  return applyLanguageModel (localfst, locallmkey_, mw, almotfLocal_);
842  } else {
844  initializeLanguageModelHandlers(lmkey_, mw, almotf_);
845  if (!almotf_.size()) return NULL;
846  LINFO ( "Composing with full lm for admissible pruning" );
847  return applyLanguageModel (localfst, lmkey_, mw, almotf_);
848  }
849  };
851  inline fst::VectorFst<Arc> *expand ( const fst::VectorFst<Arc>& localfst,
852  unsigned cc, unsigned x, unsigned y ) {
853  Label hieroindex = APBASETAG + cc * APCCTAG + x * APXTAG + y * APYTAG;
854  USER_CHECK ( localfst.NumStates() > 0, "Empty lattice?" );
856  if ( hieroindexexistence_.find ( hieroindex ) == hieroindexexistence_.end() )
857  pairlabelfsts_.push_back ( pair< Label, const fst::Fst<Arc> * > ( hieroindex,
858  &localfst ) );
859  fst::VectorFst<Arc> *aux = new fst::VectorFst<Arc>;
860  if (!hipdtmode_ ) {
861  LINFO ("Replace (RTN->FSA)");
862  d_->stats->setTimeStart ("replace-rtn");
863  Replace (pairlabelfsts_, aux, hieroindex, !aligner_);
864  d_->stats->setTimeEnd ("replace-rtn");
865  } else {
866  LINFO ("Replace (RTN->PDA)");
867  d_->stats->setTimeStart ("replace-pdt");
868  Replace (pairlabelfsts_, aux, &pdtparens_, hieroindex);
869  d_->stats->setTimeEnd ("replace-pdt");
870  LINFO ("Number of pdtparens=" << pdtparens_.size() );
871  }
872  //if it doesn't exist, then leave the pair list as it was!
873  if ( hieroindexexistence_.find ( hieroindex ) == hieroindexexistence_.end() )
874  pairlabelfsts_.pop_back();
875  return aux;
876  }
// Decides whether cell (cc, x, y) qualifies for local pruning and, if so,
// returns a newly allocated pruned lattice for it.
//
// Returns NULL when local pruning is disabled or the cell does not meet the
// conditions in lpc_. Otherwise the cell is expanded, epsilon-removed and
// filtered; if a local language model could be applied the lattice is pruned
// (or PDT-expanded) with the condition's weight and returned, else the
// expanded and filtered lattice is returned as is.
890  fst::VectorFst<Arc> *localPruning ( const fst::VectorFst<Arc>& fst, unsigned cc,
891  unsigned x, unsigned y ) {
// NOTE(review): `o` is declared for PRINTDEBUG builds but is not used in the
// visible part of this function.
892 #ifdef PRINTDEBUG
893  std::ostringstream o;
894  o << cc << "." << x << "." << y;
895 #endif
896  if ( !localprune_ ) return NULL;
// NOTE(review): `weight` is not initialised here; presumably lpc_ sets it
// when a condition matches -- confirm. `referenceminstates` is never used.
897  float weight;
898  unsigned referenceminstates;
899  LDEBUG ( "AT " << cc << "," << x << "," << y <<
900  ": Testing conditions; expected lattice size=" << ( *rtnnumstates_ ) ( cc, x,
901  y ) );
// Conditions are tested with y + 1 and the expanded number of states.
902  if ( lpc_ ( cc, y + 1, ( *rtnnumstates_ ) ( cc, x, y ), weight ) ) {
903  LINFO ( "AT " << cc << "," << x << "," << y <<
904  ": Qualifies for local pruning. Making it so!" );
905  LDEBUG ( "AT " << cc << "," << x << "," << y << ": expanding RTN/RmEpsilon" );
906  fst::VectorFst<Arc> *efst = expand ( fst, cc, x, y );
907  fst::RmEpsilon<Arc> ( efst );
908  LINFO ( "AT " << cc << "," << x << "," << y << ": NS=" << efst->NumStates() );
909  ++piscount_;
910  LINFO("Apply filtering");
911  applyFilters ( efst );
912  LINFO ( "Apply LM" );
913  fst::VectorFst<Arc> * latlm = applyLanguageModel ( *efst , true );
// latlm is NULL when no local language model was applied.
915  if ( latlm != NULL ) {
// latlm replaces efst from here on.
916  delete efst;
917  //\todo Include union with shortest path...
918  if (!hipdtmode_ || pdtparens_.empty() ) {
919  LINFO ( "Prune with weight=" << weight );
920  fst::Prune<Arc> ( latlm, mw_ ( weight ) );
921  } else {
922  LINFO ( "PDT expanding with weight=" << weight );
923  fst::ExpandOptions<Arc> eopts (true, false, mw_ ( weight ) );
924  fst::VectorFst<Arc> latlmaux;
925  Expand ( *latlm, pdtparens_, &latlmaux, eopts);
926  *latlm = latlmaux;
927  pdtparens_.clear();
928  }
929  LINFO ( "Delete LM scores" );
930  //Deletes LM scores if using lexstdarc or tuplearc
931  // fst::MakeWeight2<Arc> mwcopy;
932  MakeWeightHifstLocalLm<Arc > mwcopy(rg_);
// NOTE(review): the remaining arguments of this fst::Map call are on a line
// that is not present in this text.
933  fst::Map<Arc> ( latlm,
935  LINFO ( "AT " << cc << "," << x << "," << y << ": pruned with weight=" << weight
936  << ",NS=" << latlm->NumStates() );
937  return latlm;
938  }
939  LINFO ( "AT " << cc << "," << x << "," << y <<
940  "Local LM not applied, filtered with " << d_->filters.size() <<
941  " filter(s) ,NS=" << efst->NumStates() );
942  return efst;
943  }
944  LINFO ( "AT " << cc << "," << x << "," << y <<
945  ": Does not qualify for local pruning. " );
946  return NULL;
947  };
951 };
953 }
954 } // end namespaces
956 #endif
