Cambridge SMT System
RuleRetriever.java
Go to the documentation of this file.
1 /*******************************************************************************
2  * Licensed under the Apache License, Version 2.0 (the "License");
3  * you may not use these files except in compliance with the License.
4  * You may obtain a copy of the License at
5  *
6  * http://www.apache.org/licenses/LICENSE-2.0
7  *
8  * Unless required by applicable law or agreed to in writing, software
9  * distributed under the License is distributed on an "AS IS" BASIS,
10  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11  * See the License for the specific language governing permissions and
12  * limitations under the License.
13  *
14  * Copyright 2014 - Juan Pino, Aurelien Waite, William Byrne
15  *******************************************************************************/
16 package uk.ac.cam.eng.rule.retrieval;
17 
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.FilenameFilter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.zip.GZIPOutputStream;

import org.apache.commons.lang.time.StopWatch;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hbase.io.hfile.CacheConfig;
import org.apache.hadoop.hbase.io.hfile.HFile;
import org.apache.hadoop.hbase.util.BloomFilter;
import org.apache.hadoop.hbase.util.BloomFilterFactory;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.mapreduce.Partitioner;
import org.apache.hadoop.mapreduce.lib.partition.HashPartitioner;

import uk.ac.cam.eng.extraction.Rule;
import uk.ac.cam.eng.extraction.RuleString;
import uk.ac.cam.eng.extraction.Symbol;
import uk.ac.cam.eng.extraction.dr$;
import uk.ac.cam.eng.extraction.hadoop.util.Util;
import uk.ac.cam.eng.extraction.oov$;
// NOTE(review): the three imports below were dropped by the source
// extraction (Util, Feature, FeatureRegistry are used but were unimported);
// package paths follow the project layout — confirm against the repository.
import uk.ac.cam.eng.rule.features.Feature;
import uk.ac.cam.eng.rule.features.FeatureRegistry;
import uk.ac.cam.eng.util.CLI;

import com.beust.jcommander.ParameterException;
81 public class RuleRetriever {
82 
83  private BloomFilter[] bfs;
84 
85  private HFileRuleReader[] readers;
86 
87  private Partitioner<RuleString, NullWritable> partitioner = new HashPartitioner<>();
88 
89  RuleFilter filter;
90 
91  Set<Rule> passThroughRules = new HashSet<>();
92 
93  Set<Rule> foundPassThroughRules = new HashSet<>();
94 
95  Set<RuleString> testVocab = new HashSet<>();
96 
97  Set<RuleString> foundTestVocab = new HashSet<>();
98 
99  Map<RuleString, Set<Integer>> sourceToSentenceId = new HashMap<>();
100 
101  List<Set<Symbol>> targetSideVocab = new ArrayList<>();
102 
103  private int maxSourcePhrase;
104 
105  FeatureRegistry fReg;
106 
107  private int noOfFeatures;
108 
109  private String targetVocabFile;
110 
111  private void setup(String testFile, CLI.RuleRetrieverParameters params)
112  throws FileNotFoundException, IOException {
113  fReg = new FeatureRegistry(params.features.features,
114  params.rp.prov.provenance);
115  noOfFeatures = fReg.getFeatureIndices(fReg.getFeatures().toArray(
116  new Feature[0])).length;
117  filter = new RuleFilter(params.fp, new Configuration());
118  maxSourcePhrase = params.rp.maxSourcePhrase;
119  Set<RuleString> passThroughVocab = new HashSet<>();
120  Set<RuleString> fullTestVocab = getTestVocab(testFile);
121  for(Rule r : getPassThroughRules(params.passThroughRules)){
122  if(fullTestVocab.contains(r.source())){
123  passThroughVocab.add(r.source());
124  passThroughRules.add(r);
125  }
126  }
127  testVocab = new HashSet<>();
128  for (RuleString word : fullTestVocab) {
129  if (!passThroughVocab.contains(word)) {
130  testVocab.add(word);
131  }
132  }
133  targetVocabFile = params.vocab;
134  }
135 
136  private void loadDir(String dirString) throws IOException {
137  File dir = new File(dirString);
138  Configuration conf = new Configuration();
139  CacheConfig cacheConf = new CacheConfig(conf);
140  if (!dir.isDirectory()) {
141  throw new IOException(dirString + " is not a directory!");
142  }
143  File[] names = dir.listFiles(new FilenameFilter() {
144  @Override
145  public boolean accept(File dir, String name) {
146  return name.endsWith("hfile");
147  }
148  });
149  if (names.length == 0) {
150  throw new IOException("No hfiles in " + dirString);
151  }
152  bfs = new BloomFilter[names.length];
153  readers = new HFileRuleReader[names.length];
154  for (File file : names) {
155  String name = file.getName();
156  int i = Integer.parseInt(name.substring(7, 12));
157  HFile.Reader hfReader = HFile.createReader(
158  FileSystem.getLocal(conf), new Path(file.getPath()),
159  cacheConf);
160  bfs[i] = BloomFilterFactory.createFromMeta(
161  hfReader.getGeneralBloomFilterMetadata(), hfReader);
162  readers[i] = new HFileRuleReader(hfReader);
163  }
164  }
165 
166  private Set<Rule> getPassThroughRules(String passThroughRulesFileName) throws IOException {
167  Set<Rule> res = new HashSet<>();
168  try (BufferedReader br = new BufferedReader(new FileReader(
169  passThroughRulesFileName))) {
170  String line;
171  Pattern regex = Pattern.compile(".*: (.*) # (.*)");
172  Matcher matcher;
173  while ((line = br.readLine()) != null) {
174  matcher = regex.matcher(line);
175  if (matcher.matches()) {
176  String[] sourceString = matcher.group(1).split(" ");
177  String[] targetString = matcher.group(2).split(" ");
178  if (sourceString.length != targetString.length) {
179  throw new IOException("Malformed pass through rules file: "
180  + passThroughRulesFileName);
181  }
182  List<Symbol> source = new ArrayList<Symbol>();
183  List<Symbol> target = new ArrayList<Symbol>();
184  int i = 0;
185  while (i < sourceString.length) {
186  if (i % maxSourcePhrase == 0 && i > 0) {
187  Rule rule = new Rule(source, target);
188  res.add(rule);
189  source.clear();
190  target.clear();
191  }
192  source.add(Symbol.deserialise(Integer
193  .parseInt(sourceString[i])));
194  target.add(Symbol.deserialise(Integer
195  .parseInt(targetString[i])));
196  i++;
197  }
198  Rule rule = new Rule(source, target);
199  res.add(rule);
200  } else {
201  throw new IOException("Malformed pass through rules file: "
202  + passThroughRulesFileName);
203  }
204  }
205  }
206  return res;
207  }
208 
209 
210  private Set<RuleString> getTestVocab(String testFile)
211  throws FileNotFoundException, IOException {
212  Set<RuleString> res = new HashSet<>();
213  try (BufferedReader br = new BufferedReader(new FileReader(testFile))) {
214  String line;
215  while ((line = br.readLine()) != null) {
216  String[] parts = line.split("\\s+");
217  for (String part : parts) {
218  RuleString v = new RuleString();
219  v.add(Symbol.deserialise(part));
220  res.add(v);
221  }
222  }
223  }
224  return res;
225  }
226 
227  public void writeGlueRules(BufferedWriter out) {
228  writeRule(EnumRuleType.S, new Rule("S_D_X S_D_X"),
229  fReg.getDefaultDeleteGlueFeatures(), out);
230  writeRule(EnumRuleType.S, new Rule("S_X S_X"),
231  fReg.getDefaultGlueFeatures(), out);
232  writeRule(EnumRuleType.X, new Rule("V V"),
233  new TreeMap<Integer, Double>(), out);
234  writeRule(EnumRuleType.S, new Rule("1 1"),
236  writeRule(EnumRuleType.X, new Rule("2 2"),
238  }
239 
240  private List<Set<RuleString>> generateQueries(String testFileName,
241  CLI.RuleRetrieverParameters params) throws IOException {
242  PatternInstanceCreator patternInstanceCreator = new PatternInstanceCreator(
243  params, filter.getPermittedSourcePatterns());
244  List<Set<RuleString>> queries = new ArrayList<>(readers.length);
245  for (int i = 0; i < readers.length; ++i) {
246  queries.add(new HashSet<RuleString>());
247  }
248  targetSideVocab.add(Collections.emptySet());
249  try (BufferedReader reader = new BufferedReader(new FileReader(
250  testFileName))) {
251  int count = 1;
252  for (String line = reader.readLine(); line != null; line = reader
253  .readLine(), ++count) {
254  targetSideVocab.add(new HashSet<>());
255  StopWatch stopWatch = new StopWatch();
256  stopWatch.start();
257  Set<Rule> rules = patternInstanceCreator
258  .createSourcePatternInstances(line);
259  Collection<RuleString> sources = new ArrayList<>(rules.size());
260  for (Rule rule : rules) {
261  RuleString source = rule.source();
262  sources.add(source);
263  if (!sourceToSentenceId.containsKey(source)) {
264  sourceToSentenceId.put(source, new HashSet<>());
265  }
266  sourceToSentenceId.get(source).add(count);
267  }
268  for (RuleString source : sources) {
269  if (filter.filterSource(source)) {
270  continue;
271  }
272  int partition = partitioner.getPartition(source, null,
273  queries.size());
274  queries.get(partition).add(source);
275  }
276  System.out.println("Creating patterns for line " + count
277  + " took " + (double) stopWatch.getTime() / 1000d
278  + " seconds");
279  }
280  }
281  return queries;
282  }
283 
284  public void writeRule(EnumRuleType LHS, Rule rule,
285  Map<Integer, Double> processedFeatures, BufferedWriter out) {
286  StringBuilder res = new StringBuilder();
287  res.append(LHS.getLhs()).append(" ").append(rule);
288  for (int i = 0; i < noOfFeatures; ++i) {
289  // Features are 1-indexed
290  double featureValue = processedFeatures.containsKey(i + 1) ? -1
291  * processedFeatures.get(i + 1) : 0.0;
292  if (Math.floor(featureValue) == featureValue) {
293  res.append(String.format(" %d", (int) featureValue));
294  } else {
295  res.append(String.format(" %f", featureValue));
296  }
297  }
298  res.append("\n");
299  synchronized (out) {
300  try {
301  out.write(res.toString());
302  } catch (IOException e) {
303  e.printStackTrace();
304  System.exit(1);
305  }
306  }
307  }
308 
317  public static void main(String[] args) throws FileNotFoundException,
318  IOException, InterruptedException, IllegalArgumentException,
319  IllegalAccessException {
321  try {
322  Util.parseCommandLine(args, params);
323  } catch (ParameterException e) {
324  return;
325  }
326  RuleRetriever retriever = new RuleRetriever();
327  retriever.loadDir(params.hfile);
328  retriever.setup(params.testFile, params);
329  StopWatch stopWatch = new StopWatch();
330  stopWatch.start();
331  System.err.println("Generating query");
332  List<Set<RuleString>> queries = retriever.generateQueries(
333  params.testFile, params);
334  System.err.printf("Query took %d seconds to generate\n",
335  stopWatch.getTime() / 1000);
336  System.err.println("Executing queries");
337  try (BufferedWriter out = new BufferedWriter(new OutputStreamWriter(
338  new GZIPOutputStream(new FileOutputStream(params.rules))))) {
339  ExecutorService threadPool = Executors
340  .newFixedThreadPool(params.retrievalThreads);
341  for (int i = 0; i < queries.size(); ++i) {
342  HFileRuleQuery query = new HFileRuleQuery(retriever.readers[i],
343  retriever.bfs[i], out, queries.get(i), retriever,
344  params.sp);
345  threadPool.execute(query);
346  }
347  threadPool.shutdown();
348  threadPool.awaitTermination(1, TimeUnit.DAYS);
349  // Add pass through rule not already found in query
350  for (Rule passThroughRule : retriever.passThroughRules) {
351  if (!retriever.foundPassThroughRules.contains(passThroughRule)) {
352  retriever.writeRule(EnumRuleType.X, passThroughRule,
353  retriever.fReg.getDefaultPassThroughRuleFeatures(),
354  out);
355  }
356  }
357  // Add Deletion and OOV rules
358  Rule deletionRuleWritable = new Rule();
359  RuleString dr = new RuleString();
360  dr.add((Symbol)(dr$.MODULE$));
361  deletionRuleWritable.setTarget(dr);
362  Rule oovRuleWritable = new Rule();
363  RuleString oov = new RuleString();
364  oov.add((Symbol)(oov$.MODULE$));
365  oovRuleWritable.setTarget(oov);
366  for (RuleString source : retriever.testVocab) {
367  // If in the vocab then write deletion rule
368  if (retriever.foundTestVocab.contains(source)) {
369  deletionRuleWritable.setSource(source);
370  retriever.writeRule(EnumRuleType.D, deletionRuleWritable,
371  retriever.fReg.getDefaultDeletionFeatures(), out);
372  // Otherwise is an OOV
373  } else {
374  oovRuleWritable.setSource(source);
375  retriever.writeRule(EnumRuleType.X, oovRuleWritable,
376  retriever.fReg.getDefaultOOVFeatures(), out);
377  }
378  }
379  // Glue rules
380  retriever.writeGlueRules(out);
381  }
382  System.out.println(retriever.foundPassThroughRules);
383  System.out.println(retriever.foundTestVocab);
384  if (retriever.targetVocabFile != null) {
385  try (BufferedWriter out = new BufferedWriter(new FileWriter(
386  retriever.targetVocabFile))) {
387  for (Set<Symbol> words : retriever.targetSideVocab.subList(1,
388  retriever.targetSideVocab.size())) {
389  out.write("1 2"); // Include the start and end symbols
390  for(Symbol word : words){
391  out.write(" " + word);
392  }
393  out.write("\n");
394  }
395  }
396  }
397  }
398 
399 }
static JCommander parseCommandLine(String[] args, Object params)
Definition: Util.java:85
Map< Integer, Double > getDefaultPassThroughRuleFeatures()
Map< Integer, Double > getDefaultGlueStartOrEndFeatures()
Map< Integer, Double > getDefaultDeleteGlueFeatures()
Map< Integer, Double > getDefaultDeletionFeatures()
void writeRule(EnumRuleType LHS, Rule rule, Map< Integer, Double > processedFeatures, BufferedWriter out)