Cambridge SMT System
printstrings.main.cpp
Go to the documentation of this file.
1 
8 #include <main.custom_assert.hpp>
9 #include <main.logger.hpp>
10 #include <main.printstrings.hpp>
11 
// Integer word id -> surface word, as read from the --labelmap file in main().
using labelmap_t = std::unordered_map<std::size_t, std::string>;
using labelmap_iterator_t = labelmap_t::iterator;

// Path of the label map file; forwarded to the BLEU scorer in run().
std::string vmapfile;

// Output options. They are process-wide and are set from the command line
// in main() and run() before any hypothesis is printed.
bool printweight = false;       // append the hypothesis cost (--weight)
bool sparseformat = false;      // sparse layout for tuple weights
bool nohyps = false;            // suppress hypothesis output
bool liblinrankformat = false;  // liblinear ranking layout
bool dotProduct = false;        // print tuple weights as a single dot product
unsigned myPrecision = 6;       // stream precision for printed weights
22 
23 
24 using fst::Hyp;
25 
26 // \brief Load intersection space which will be applied prior to printing
27 // If it is an fst, then load it and we are done.
28 // Otherwise, assume it is a list of integer-mapped hypotheses
29 // and build an unweighted FSA.
30 template<class ArcT>
31 VectorFst<ArcT> *createIntersectionSpace(std::string const &filename) {
32  if (filename == "") return NULL;
33  if (fst::DetectFstFile(filename))
34  return fst::VectorFstRead<ArcT>(filename);
35 
36  ucam::util::iszfstream f (filename);
37  VectorFst<ArcT> *i = new VectorFst<ArcT>;
38  i->AddState();
39  i->SetStart(0);
40  std::string sentence;
41 
42  while (getline(f,sentence)) {
43  std::vector<std::string> wrds;
44  boost::split(wrds, sentence, boost::is_any_of(" "));
45  unsigned p = 0; // starts at state 0;
46  for (unsigned k = 0; k < wrds.size(); ++k) {
47  unsigned n =i->AddState();
48  unsigned label = ucam::util::toNumber<unsigned>(wrds[k]);
49  i->AddArc(p,ArcT(label, label, ArcT::Weight::One(), n));
50  p=n;
51  }
52  i->SetFinal(p, ArcT::Weight::One());
53  }
54  f.close();
55  *i = DeterminizeFst<ArcT>(*i);
56  Minimize(i);
57  return i;
58 }
59 
63 template<class Arc>
64 struct HypW: public Hyp<Arc> {
65 
66  HypW (std::basic_string<unsigned> const& h
67  , std::basic_string<unsigned> const& oh
68  , typename Arc::Weight const& c)
69  : Hyp<Arc> (h, oh, c) {
70  }
71  HypW (HypW<Arc> const& h)
72  : Hyp<Arc> (h) {
73  }
74 };
75 
82 template <class Arc>
83 void printWeight (typename Arc::Weight const& weight
84  , std::ostream& os
85  , unsigned precision = myPrecision
86  ) {
87  os << std::setprecision(precision) << weight;
88 }
89 
98 template <>
99 void printWeight<TupleArc32> (const TupleW32& weight
100  , std::ostream& os
101  , unsigned precision
102  ) {
103  std::map<int,float> costs;
104  std::string separator (",");
105 
106  for (fst::SparseTupleWeightIterator<fst::TropicalWeight, int> it (weight);
107  !it.Done(); it.Next() ) {
108  costs[it.Value().first] += it.Value().second.Value();
109  }
110 
111  if (liblinrankformat) {
112  for (std::map<int,float>::const_iterator itx=costs.begin();
113  itx != costs.end(); ++itx) {
114  os << " " << itx->first << ":" << itx->second;
115  }
116  return;
117  }
118 
119  if (sparseformat) {
120  os << "0" << separator << costs.size();
121  for (std::map<int,float>::const_iterator itx=costs.begin()
122  ; itx != costs.end()
123  ; ++itx) {
124  os << separator << itx->first << separator << std::setprecision(precision) << itx->second;
125  }
126  return;
127  }
128  if (dotProduct) {
129  float w =0;
130  std::vector<float> const &fws = TupleW32::Params();
131  for (std::map<int,float>::const_iterator itx=costs.begin()
132  ; itx != costs.end()
133  ; ++itx) {
134  if (itx->first < 1) continue;
135 
136  float fw = fws[itx->first - 1];
137  w = w + fw * itx->second;
138  }
139  os << std::setprecision(precision) << w;
140  return;
141  }
142  std::size_t nonSparseSize = TupleW32::Params().size();
143  std::size_t counter = 1;
144  separator = "";
145  for (std::map<int,float>::const_iterator itx=costs.begin()
146  ; itx != costs.end()
147  ; ++itx) {
148  if (itx->first < 1 ) continue;
149  std::size_t featureIndex = itx->first;
150  for (std::size_t featureMissingIndex = counter;
151  featureMissingIndex < featureIndex; ++featureMissingIndex) {
152  os << separator << "0";
153  separator = ","; // @todo should be possible to avoid resetting every time.
154  }
155  os << separator << itx->second;
156  counter = itx->first + 1;
157  separator = ",";
158  }
159  for (; counter <= nonSparseSize; ++counter) {
160  os << separator << "0";
161  separator = ",";
162  }
163 }
164 
168  for (int k=0; k<h.size(); k++) {
169  if (h[k] == OOV) continue;
170  if (h[k] == DR) continue;
171  if (h[k] == EPSILON) continue;
172  if (h[k] == SEP) continue;
173  if (h[k] == 1) continue;
174  if (h[k] == 2) continue;
175  x.push_back(h[k]);
176  }
177  return x;
178 }
179 
183 template<class Arc>
184 std::ostream& operator<< (std::ostream& os, const Hyp<Arc>& obj) {
185  for (unsigned k = 0; k < obj.hyp.size(); ++k) {
186  if (obj.hyp[k] == OOV) continue;
187  if (obj.hyp[k] == DR) continue;
188  if (obj.hyp[k] == EPSILON) continue;
189  if (obj.hyp[k] == SEP) continue;
190  os << obj.hyp[k] << " ";
191  }
192  if (printweight) {
193  os << "\t";
194  printWeight<Arc> (obj.cost, os, myPrecision);
195  };
196  return os;
197 }
198 
203 template<class Arc>
204 std::ostream& operator<< (std::ostream& os, const HypW<Arc>& obj) {
205  for (unsigned k = 0; k < obj.hyp.size(); ++k) {
206  if (obj.hyp[k] == OOV) continue;
207  if (obj.hyp[k] == DR) continue;
208  if (obj.hyp[k] == EPSILON) continue;
209  if (obj.hyp[k] == SEP) continue;
210  labelmap_iterator_t itx = vmap.find (obj.hyp[k]);
211  if (itx != vmap.end() )
212  os << vmap[ obj.hyp[k] ] << " ";
213  else {
214  os << "[" << obj.hyp[k] << "] ";
215  std::cerr << "\nWARNING: word map does not contain word " << obj.hyp[k] <<
216  std::endl;
217  }
218  }
219  if (printweight) {
220  os << "\t";
221  printWeight<Arc> (obj.cost, os, myPrecision);
222  };
223  return os;
224 }
225 
226 template <class Arc, class HypT>
227 int run ( ucam::util::RegistryPO const& rg) {
228  using namespace ucam::util;
229  using namespace HifstConstants;
230  PatternAddress<unsigned> input (rg.get<std::string>
231  (kInput.c_str() ) );
232  PatternAddress<unsigned> output (rg.get<std::string>
233  (kOutput.c_str() ) );
234  PatternAddress<unsigned> intersectionLattice (rg.get<std::string>
235  (kIntersectionWithHypothesesLoad.c_str() ) );
236  unsigned n = rg.get<unsigned> (kNbest.c_str() );
237  boost::scoped_ptr<oszfstream> out;
238  bool unique = rg.exists (kUnique.c_str() );
239  bool printOutputLabels = rg.exists(kPrintOutputLabels.c_str());
240  bool printInputOutputLabels = rg.exists(kPrintInputOutputLabels.c_str());
241 
242  if (printInputOutputLabels)
243  FORCELINFO("Printing input and output labels...");
244  std::string old;
245  std::string refFiles;
246  bool intRefs = false;
247  bool dobleu = false;
248  bool sentbleu = false;
249 
251  refFiles = rg.getString(HifstConstants::kWordRefs);
252  dobleu = true;
253  }
255  refFiles = rg.getString(HifstConstants::kIntRefs);
256  intRefs = true;
257  dobleu = true;
258  }
259  if (rg.exists (HifstConstants::kSentBleu) ) {
260  if (!dobleu) {
261  LERROR("Must provide references to compute sentence level bleu");
262  exit(EXIT_FAILURE);
263  }
264  sentbleu = true;
265  }
266 
267  if (rg.exists (HifstConstants::kWeight) ) {
268  printweight = true;
269  }
271  sparseformat = true;
272  }
274  if (sparseformat == true) {
275  LERROR("Sparse format and dot product are not available at the same time.");
276  exit(EXIT_FAILURE);
277  }
278  dotProduct = true;
279  }
280 
281  if (rg.exists (HifstConstants::kSuppress) ) {
282  nohyps = true;
283  }
284 
285  // n.b. this should be last, to override any other settings
287  liblinrankformat = true;
288  if (!dobleu) {
289  LERROR("Must provide references to compute features for liblinear rankings");
290  exit(EXIT_FAILURE);
291  }
292  sentbleu = true;
293  }
294 
295  std::string extTok(rg.exists(HifstConstants::kExternalTokenizer) ?
298  ucam::fsttools::BleuScorer *bleuScorer;
299  if (dobleu)
300  bleuScorer = new ucam::fsttools::BleuScorer(refFiles, extTok, 1, intRefs, vmapfile);
301 
302  int nlines=0;
303  for ( IntRangePtr ir (IntRangeFactory ( rg,
304  kRangeOne ) );
305  !ir->done();
306  ir->next() ) {
307  nlines++;
308  boost::scoped_ptr<fst::VectorFst<Arc> > ifst (fst::VectorFstRead<Arc> (input (
309  ir->get() ) ) );
310  Connect (&*ifst);
311  if (old != output (ir->get() ) ) {
312  out.reset (new oszfstream (output (ir->get() ) ) );
313  old = output (ir->get() );
314  }
315  if (!ifst->NumStates() ) {
316  *out << "[EMPTY]" << std::endl;
317  continue;
318  }
319  // Projecting allows unique to work for all cases.
320  if (printOutputLabels)
321  fst::Project(&*ifst, PROJECT_OUTPUT);
322  else if (!printInputOutputLabels) // what a mess
323  fst::Project(&*ifst, PROJECT_INPUT);
324 
325  fst::VectorFst<Arc> nfst;
326  // find 1-best and compute bleu stats
327  if (dobleu) {
328  ShortestPath (*ifst, &nfst, 1, unique);
329  std::vector<HypT> hyps1;
330  fst::printStrings<Arc> (nfst, &hyps1);
331  ucam::fsttools::SentenceIdx h(hyps1[0].hyp.begin(), hyps1[0].hyp.end());
332  // bleuscorer indexes references from 0; ir counts from 1
333  if (h.size() > 0)
334  bStats = bStats + bleuScorer->SentenceBleuStats(ir->get()-1, RemoveUnprintable(h));
335  }
336 
337  boost::scoped_ptr< VectorFst<Arc> > intersection
338  (createIntersectionSpace<Arc>( intersectionLattice( ir->get() ) ));
339 
340 
341  if (intersection.get()) {
342  VectorFst<Arc> cmps;
343  *ifst = ComposeFst<Arc>(*intersection, *ifst);
344  }
345 
346  if (!ifst->NumStates() ) {
347  *out << "[EMPTY]" << std::endl;
348  continue;
349  }
350 
351  // Otherwise determinization runs (both determinizefst or
352  // inside shortestpath) doesn't produce the expected result:
353  // epsilons are being treated as symbols
354  if (unique) {
355  fst::RmEpsilon<Arc>(&*ifst);
356  }
357  // find nbest, compute stats, print
358  ShortestPath (*ifst, &nfst, n, unique );
359 
360  std::vector<HypT> hyps;
361  fst::printStrings<Arc> (nfst, &hyps);
362  for (unsigned k = 0; k < hyps.size(); ++k) {
363  ucam::fsttools::SentenceIdx h(hyps[k].hyp.begin(), hyps[k].hyp.end());
365  double sbleu;
366  if (sentbleu) {
367  // bleuscorer indexes references from 0; ir counts from 1
368  sbStats = bleuScorer->SentenceBleuStats(ir->get()-1, RemoveUnprintable(h));
369  sbleu = bleuScorer->ComputeSBleu(sbStats).m_bleu;
370  }
371  // output
372  if (liblinrankformat) {
373  *out->getStream() << sbleu << " qid:" << ir->get();
374  printWeight<Arc>(hyps[k].cost, *out->getStream());
375  *out->getStream() << std::endl;
376  }
377  if (nohyps == false) {
378  if (printInputOutputLabels) { // add the output labels.
379  for (unsigned j = 0; j < hyps[k].hyp.size(); ++j)
380  if (hyps[k].hyp[j] != 0)
381  *out->getStream() << hyps[k].hyp[j] << " ";
382  *out->getStream() << "\t";
383  for (unsigned j = 0; j < hyps[k].ohyp.size(); ++j)
384  if (hyps[k].ohyp[j] != 0)
385  *out->getStream() << hyps[k].ohyp[j] << " ";
386  if (printweight)
387  *out->getStream() << "\t" << std::setprecision(myPrecision) << hyps[k].cost;
388  } else {
389  *out->getStream() << hyps[k];
390  }
391  if (sentbleu)
392  *out->getStream() << "\t" << sbStats << "\t" << sbleu;
393  *out->getStream() << std::endl;
394  }
395  }
396  }
397  if (dobleu)
398  FORCELINFO("BLEU STATS:" << bStats << "; BLEU: " << bleuScorer->ComputeBleu(bStats));
399  FORCELINFO("Processed " << nlines << " files");
400 };
401 
402 /*
403  * \brief Main function.
404  * \param argc Number of command-line program options.
405  * \param argv Actual program options.
406  * \remarks
407  */
408 int main ( int argc, const char* argv[] ) {
409  using namespace HifstConstants;
410  using namespace ucam::util;
411  initLogger ( argc, argv );
412  FORCELINFO ( argv[0] << " starts!" );
413  RegistryPO rg ( argc, argv );
414  FORCELINFO ( rg.dump ( "CONFIG parameters:\n=====================",
415  "=====================" ) );
416  if (rg.exists (kWeight) ) {
417  printweight = true;
418  myPrecision = rg.get<unsigned>(kWeightPrecision.c_str());
419  LINFO("Setting float precision=" << myPrecision);
420  }
421  if (rg.exists (kSparseFormat) ) {
422  sparseformat = true;
423  }
424  if (rg.exists (kSparseDotProduct) ) {
425  if (sparseformat == true) {
426  LERROR("Sparse format and dot product are not available at the same time.");
427  exit(EXIT_FAILURE);
428  }
429  dotProduct = true;
430  }
431 
432  // check that tuplearc weights are set for the tuplearc semiring
433  if (rg.get<std::string> (kHifstSemiring.c_str() ) ==
435  const std::string& tuplearcWeights = rg.exists (
437  ? rg.get<std::string> (kTupleArcWeights.c_str() ) : "";
438  if (tuplearcWeights.empty() ) {
439  LERROR ("The tuplearc.weights option needs to be specified "
440  "for the tropical sparse tuple weight semiring "
441  "(--semiring=tuplearc)");
442  exit (EXIT_FAILURE);
443  }
444  TupleW32::Params() = ParseParamString<float> (tuplearcWeights);
445  }
446  std::string const& semiring = rg.get<std::string>
447  (kHifstSemiring);
448  if (!vmap.size() && rg.get<std::string> (kLabelMap) != "" ) {
449  FORCELINFO ("Loading symbol map file...");
450  vmapfile = rg.get<std::string> (HifstConstants::kLabelMap);
451  iszfstream f (rg.get<std::string> (kLabelMap) );
452  unsigned id;
453  std::string word;
454  while (f >> word >> id) {
455  vmap[id] = word;
456  }
457  FORCELINFO ("Loaded " << vmap.size() << " symbols");
458  if (semiring == kHifstSemiringStdArc) {
459  run<fst::StdArc, HypW<fst::StdArc> > (rg);
460  } else if (semiring == kHifstSemiringLexStdArc) {
461  run<fst::LexStdArc, HypW<fst::LexStdArc> > (rg);
462  } else if (semiring == kHifstSemiringTupleArc) {
463  run<TupleArc32, HypW<TupleArc32> > (rg);
464  } else {
465  LERROR ("Sorry, semiring option not correctly defined");
466  }
467  FORCELINFO ( argv[0] << " finished!" );
468  exit (EXIT_SUCCESS);
469  }
470  if (semiring == kHifstSemiringStdArc) {
471  run<fst::StdArc, Hyp<fst::StdArc> > (rg);
472  } else if (semiring == kHifstSemiringLexStdArc) {
473  run<fst::LexStdArc, Hyp<fst::LexStdArc> > (rg);
474  } else if (semiring == kHifstSemiringTupleArc) {
475  run<TupleArc32, Hyp<TupleArc32> > (rg);
476  } else {
477  LERROR ("Sorry, semiring option not correctly defined");
478  }
479  FORCELINFO ( argv[0] << " finished!" );
480 }
Wrapper stream class that writes to pipes, text files or gzipped files.
Definition: szfstream.hpp:200
std::string const kHifstSemiring
HypW(std::basic_string< unsigned > const &h, std::basic_string< unsigned > const &oh, typename Arc::Weight const &c)
class that expands a wildcard into its actual value. This is useful e.g. for filenames ranging severa...
std::string const kPrintOutputLabels
std::vector< Wid > SentenceIdx
Definition: bleu.hpp:22
bool sparseformat
bool DetectFstFile(std::string const &filename, std::string const &extname="fst")
Trivially detects, from the file extension, whether the file is an FST.
Definition: fstio.hpp:47
VectorFst< ArcT > * createIntersectionSpace(std::string const &filename)
#define LINFO(msg)
#define SEP
int run(ucam::util::RegistryPO const &rg)
std::basic_string< unsigned > ohyp
Definition: fstio.hpp:143
std::string const kSentBleu
T get(const std::string &key) const
Returns parsed value associated to key.
Definition: registrypo.hpp:194
void initLogger(int argc, const char *argv[])
Initializes the logger; parses program options, checking for --logger.verbose.
std::string const kInput
std::string const kExternalTokenizer
#define FORCELINFO(msg)
bool liblinrankformat
boost::scoped_ptr< NumberRangeInterface< unsigned > > IntRangePtr
Definition: range.hpp:214
std::string const kIntRefs
std::string const kWeight
#define IntRangeFactory
Definition: range.hpp:213
unsigned myPrecision
#define DR
std::string const kNbest
std::string const kOutput
BleuStats SentenceBleuStats(const Sid sid, const SentenceIdx &hypIdx)
Definition: bleu.hpp:296
int main(int argc, const char *argv[])
bool dotProduct
std::string const kWordRefs
Struct template that represents a hypothesis in a lattice.
Definition: fstio.hpp:142
Implements Tropical Sparse tuple weight semiring, extending from openfst SparsePowerWeight class...
bool nohyps
labelmap_t::iterator labelmap_iterator_t
std::string const kPrintInputOutputLabels
Same as Hyp, but printing converts integer ids to words.
std::string const kLabelMap
iszfstream & getline(iszfstream &izs, std::string &line)
Definition: szfstream.hpp:178
std::string const kHifstSemiringLexStdArc
Bleu ComputeSBleu(const BleuStats &bs)
Definition: bleu.hpp:338
std::string const kHifstSemiringStdArc
Static variables for the logger. Include only once, from the main file.
std::string getString(const std::string &key) const
Performs get<string> and checks whether the actual value should be loaded from a file (--param=file://...).
Definition: registrypo.hpp:205
std::string const kHifstSemiringTupleArc
ucam::fsttools::SentenceIdx RemoveUnprintable(const ucam::fsttools::SentenceIdx &h)
void printWeight< TupleArc32 >(const TupleW32 &weight, std::ostream &os, unsigned precision)
Template specialization of printWeight for a tropical sparse tuple weight. Uses the global var sparse...
bool exists(const std::string &key) const
Determines whether a program option (key) has been defined by the user.
Definition: registrypo.hpp:235
#define EPSILON
std::basic_string< unsigned > hyp
Definition: fstio.hpp:143
labelmap_t vmap
std::unordered_map< std::size_t, std::string > labelmap_t
void printWeight(typename Arc::Weight const &weight, std::ostream &os, unsigned precision=myPrecision)
Templated method that prints an arc weight. By default, reuses the operator<< already defined for eac...
std::string const kLibLinRankFormat
std::string dump(const std::string &decorator_start="", const std::string &decorator_end="")
Dumps all configuration parameters into a string with a reasonably pretty format. ...
Definition: registrypo.hpp:108
std::string const kWeightPrecision
std::string const kTupleArcWeights
#define OOV
#define LERROR(msg)
HypW(HypW< Arc > const &h)
Bleu ComputeBleu(const BleuStats &bs)
Definition: bleu.hpp:326
std::string const kSuppress
bool printweight
const std::string kRangeOne
Definition: range.hpp:26
std::string const kSparseDotProduct
std::string const kSparseFormat
Wrapper stream class that reads pipes, text files or gzipped files.
Definition: szfstream.hpp:34
std::string const kIntersectionWithHypothesesLoad
std::string const kUnique
Static variable for custom_assert. Include only once from main file.
string vmapfile