package uk.ac.cam.eng.rule.retrieval;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.FilenameFilter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.zip.GZIPOutputStream;

import org.apache.commons.lang.time.StopWatch;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hbase.io.hfile.CacheConfig;
import org.apache.hadoop.hbase.io.hfile.HFile;
import org.apache.hadoop.hbase.util.BloomFilter;
import org.apache.hadoop.hbase.util.BloomFilterFactory;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.mapreduce.Partitioner;
import org.apache.hadoop.mapreduce.lib.partition.HashPartitioner;

// (elided) project-local imports: Rule, RuleString, Symbol, EnumRuleType,
// RuleFilter, PatternInstanceCreator, HFileRuleQuery, HFileRuleReader,
// dr$, oov$, ...

import com.beust.jcommander.ParameterException;
/**
 * Retrieves rules from a directory of HFile shards for the source phrases
 * of a test set, and writes them out with their features.
 */
public class RuleRetriever {

	private BloomFilter[] bfs;
	// Per-shard HFile readers (declaration elided; type assumed from its use in main()).
	private HFileRuleReader[] readers;
	// Source-side rule filter (declaration elided; assigned in setup()).
	private RuleFilter filter;
	private Partitioner<RuleString, NullWritable> partitioner = new HashPartitioner<>();

	Set<Rule> passThroughRules = new HashSet<>();
	Set<Rule> foundPassThroughRules = new HashSet<>();
	Set<RuleString> testVocab = new HashSet<>();
	Set<RuleString> foundTestVocab = new HashSet<>();
	Map<RuleString, Set<Integer>> sourceToSentenceId = new HashMap<>();
	List<Set<Symbol>> targetSideVocab = new ArrayList<>();

	private int maxSourcePhrase;
	private int noOfFeatures;
	private String targetVocabFile;
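
	// The found* sets are filled in by the HFileRuleQuery worker threads
	// started in main() (each query is handed this retriever), so after
	// the pool drains they record which pass-through rules and test
	// words were actually retrieved from the HFiles.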

	private void setup(String testFile, RuleRetrieverParameters params)
			throws FileNotFoundException, IOException {
		// (elided) feature setup; the surviving fragment references
		// params.rp.prov.provenance and should initialise noOfFeatures.
		filter = new RuleFilter(params.fp, new Configuration());
		maxSourcePhrase = params.rp.maxSourcePhrase;
		Set<RuleString> passThroughVocab = new HashSet<>();
		Set<RuleString> fullTestVocab = getTestVocab(testFile);
		for (Rule r : getPassThroughRules(params.passThroughRules)) {
			if (fullTestVocab.contains(r.source())) {
				passThroughVocab.add(r.source());
				passThroughRules.add(r);
			}
		}
		// Keep only the test words that no pass-through rule covers.
		testVocab = new HashSet<>();
		for (RuleString word : fullTestVocab) {
			if (!passThroughVocab.contains(word)) {
				testVocab.add(word);
			}
		}
		targetVocabFile = params.vocab;
	}

	private void loadDir(String dirString) throws IOException {
		File dir = new File(dirString);
		Configuration conf = new Configuration();
		CacheConfig cacheConf = new CacheConfig(conf);
		if (!dir.isDirectory()) {
			throw new IOException(dirString + " is not a directory!");
		}
		File[] names = dir.listFiles(new FilenameFilter() {
			public boolean accept(File dir, String name) {
				return name.endsWith("hfile");
			}
		});
		if (names.length == 0) {
			throw new IOException("No hfiles in " + dirString);
		}
		bfs = new BloomFilter[names.length];
		readers = new HFileRuleReader[names.length]; // (restored from usage in main())
		for (File file : names) {
			String name = file.getName();
			// The zero-padded shard index sits at offsets 7-11 of the
			// file name (e.g. "part-r-00005.hfile" -> 5).
			int i = Integer.parseInt(name.substring(7, 12));
			HFile.Reader hfReader = HFile.createReader(
					FileSystem.getLocal(conf), new Path(file.getPath()),
					cacheConf);
			bfs[i] = BloomFilterFactory.createFromMeta(
					hfReader.getGeneralBloomFilterMetadata(), hfReader);
			readers[i] = new HFileRuleReader(hfReader); // (elided line; wrapper assumed)
		}
	}
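
	// Illustration only (helper not in the original class): how loadDir()
	// recovers a shard index, assuming Hadoop-style part-file names.
	private static int shardIndexExample() {
		String name = "part-r-00005.hfile"; // hypothetical file name
		return Integer.parseInt(name.substring(7, 12)); // "00005" -> 5
	}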

	private Set<Rule> getPassThroughRules(String passThroughRulesFileName)
			throws IOException {
		Set<Rule> res = new HashSet<>();
		try (BufferedReader br = new BufferedReader(new FileReader(
				passThroughRulesFileName))) {
			String line;
			Matcher matcher;
			Pattern regex = Pattern.compile(".*: (.*) # (.*)");
			while ((line = br.readLine()) != null) {
				matcher = regex.matcher(line);
				if (matcher.matches()) {
					String[] sourceString = matcher.group(1).split(" ");
					String[] targetString = matcher.group(2).split(" ");
					if (sourceString.length != targetString.length) {
						throw new IOException("Malformed pass through rules file: "
								+ passThroughRulesFileName);
					}
					List<Symbol> source = new ArrayList<Symbol>();
					List<Symbol> target = new ArrayList<Symbol>();
					int i = 0;
					while (i < sourceString.length) {
						// Emit a rule every maxSourcePhrase words so no
						// pass-through rule exceeds the phrase limit.
						if (i % maxSourcePhrase == 0 && i > 0) {
							Rule rule = new Rule(source, target);
							res.add(rule);
							source.clear();
							target.clear();
						}
						source.add(Symbol.deserialise(Integer
								.parseInt(sourceString[i])));
						target.add(Symbol.deserialise(Integer
								.parseInt(targetString[i])));
						++i;
					}
					Rule rule = new Rule(source, target);
					res.add(rule);
				} else {
					throw new IOException("Malformed pass through rules file: "
							+ passThroughRulesFileName);
				}
			}
		}
		return res;
	}

	private Set<RuleString> getTestVocab(String testFile)
			throws FileNotFoundException, IOException {
		Set<RuleString> res = new HashSet<>();
		try (BufferedReader br = new BufferedReader(new FileReader(testFile))) {
			String line;
			while ((line = br.readLine()) != null) {
				String[] parts = line.split("\\s+");
				for (String part : parts) {
					RuleString v = new RuleString();
					v.add(Symbol.deserialise(part));
					res.add(v);
				}
			}
		}
		return res;
	}

	// (elided) a rule-writing helper survives only as a fragment that
	// passes an empty feature map to writeRule:
	//     writeRule(..., new TreeMap<Integer, Double>(), out);
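
	// Illustration: getTestVocab() splits each line on whitespace, so a
	// line such as "12 7 99" (invented ids) yields three single-symbol
	// RuleStrings, one per word id.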

	private List<Set<RuleString>> generateQueries(String testFileName,
			RuleRetrieverParameters params) throws IOException {
		PatternInstanceCreator patternInstanceCreator = new PatternInstanceCreator(
				params, filter.getPermittedSourcePatterns());
		List<Set<RuleString>> queries = new ArrayList<>(readers.length);
		for (int i = 0; i < readers.length; ++i) {
			queries.add(new HashSet<RuleString>());
		}
		// Index 0 is a dummy entry so sentence ids can start at 1.
		targetSideVocab.add(Collections.emptySet());
		try (BufferedReader reader = new BufferedReader(new FileReader(
				testFileName))) {
			int count = 1;
			for (String line = reader.readLine(); line != null; line = reader
					.readLine(), ++count) {
				targetSideVocab.add(new HashSet<>());
				StopWatch stopWatch = new StopWatch();
				stopWatch.start();
				Set<Rule> rules = patternInstanceCreator
						.createSourcePatternInstances(line);
				Collection<RuleString> sources = new ArrayList<>(rules.size());
				for (Rule rule : rules) {
					RuleString source = rule.source();
					sources.add(source);
					if (!sourceToSentenceId.containsKey(source)) {
						sourceToSentenceId.put(source, new HashSet<>());
					}
					sourceToSentenceId.get(source).add(count);
				}
				for (RuleString source : sources) {
					if (filter.filterSource(source)) {
						continue;
					}
					int partition = partitioner.getPartition(source, null,
							queries.size());
					queries.get(partition).add(source);
				}
				System.out.println("Creating patterns for line " + count
						+ " took " + (double) stopWatch.getTime() / 1000d
						+ " seconds");
			}
		}
		return queries;
	}
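
	// Sources are bucketed with the same HashPartitioner that built the
	// HFiles, so query set i can be answered entirely by shard i
	// (readers[i] and bfs[i] in main() below).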

	void writeRule(EnumRuleType LHS, Rule rule,
			Map<Integer, Double> processedFeatures, BufferedWriter out) {
		StringBuilder res = new StringBuilder();
		res.append(LHS.getLhs()).append(" ").append(rule);
		for (int i = 0; i < noOfFeatures; ++i) {
			// Feature i is looked up at key i + 1 and negated on output;
			// absent features print as 0.
			double featureValue = processedFeatures.containsKey(i + 1) ? -1
					* processedFeatures.get(i + 1) : 0.0;
			if (Math.floor(featureValue) == featureValue) {
				// Whole numbers are printed without a decimal point.
				res.append(String.format(" %d", (int) featureValue));
			} else {
				res.append(String.format(" %f", featureValue));
			}
		}
		res.append("\n");
		try {
			out.write(res.toString());
		} catch (IOException e) {
			// (elided) error handling
			e.printStackTrace();
		}
	}
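
	// Illustration only (helper not in the original class): the same
	// integer-vs-float formatting rule writeRule() applies per feature.
	private static String formatFeatureExample(double featureValue) {
		return Math.floor(featureValue) == featureValue
				? String.format(" %d", (int) featureValue) // e.g. 2.0 -> " 2"
				: String.format(" %f", featureValue);      // e.g. 1.5 -> " 1.500000"
	}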

	public static void main(String[] args) throws FileNotFoundException,
			IOException, InterruptedException, IllegalArgumentException,
			IllegalAccessException {
		// "RuleRetrieverParameters" is the assumed name of the JCommander
		// parameter object (its declaration and the parsing lines are elided).
		RuleRetrieverParameters params = new RuleRetrieverParameters();
		try {
			// (elided) parse args into params
		} catch (ParameterException e) {
			// (elided) report the parameter error and exit
			System.err.println(e.getMessage());
			System.exit(1);
		}
		RuleRetriever retriever = new RuleRetriever();
		retriever.loadDir(params.hfile);
		retriever.setup(params.testFile, params);
		StopWatch stopWatch = new StopWatch();
		stopWatch.start();
		System.err.println("Generating query");
		List<Set<RuleString>> queries = retriever.generateQueries(
				params.testFile, params);
		System.err.printf("Query took %d seconds to generate\n",
				stopWatch.getTime() / 1000);
		System.err.println("Executing queries");
		try (BufferedWriter out = new BufferedWriter(new OutputStreamWriter(
				new GZIPOutputStream(new FileOutputStream(params.rules))))) {
			ExecutorService threadPool = Executors
					.newFixedThreadPool(params.retrievalThreads);
			for (int i = 0; i < queries.size(); ++i) {
				// Each query runs against its own shard's reader and
				// Bloom filter.
				HFileRuleQuery query = new HFileRuleQuery(retriever.readers[i],
						retriever.bfs[i], out, queries.get(i), retriever,
						params); // (final argument elided; params assumed)
				threadPool.execute(query);
			}
			threadPool.shutdown();
			threadPool.awaitTermination(1, TimeUnit.DAYS);
			// Pass-through rules the HFiles did not contain still need
			// to be written out.
			for (Rule passThroughRule : retriever.passThroughRules) {
				if (!retriever.foundPassThroughRules.contains(passThroughRule)) {
					// (elided) write passThroughRule with default features
				}
			}
			// Deletion and OOV rules: dr$ and oov$ are Scala singleton
			// symbols marking the two special target sides.
			Rule deletionRuleWritable = new Rule();
			RuleString dr = new RuleString();
			dr.add((Symbol) (dr$.MODULE$));
			deletionRuleWritable.setTarget(dr);
			Rule oovRuleWritable = new Rule();
			RuleString oov = new RuleString();
			oov.add((Symbol) (oov$.MODULE$));
			oovRuleWritable.setTarget(oov);
			for (RuleString source : retriever.testVocab) {
				if (retriever.foundTestVocab.contains(source)) {
					deletionRuleWritable.setSource(source);
					// (elided) write the deletion rule
				} else {
					oovRuleWritable.setSource(source);
					// (elided) write the OOV rule
				}
			}
		}
		System.out.println(retriever.foundPassThroughRules);
		System.out.println(retriever.foundTestVocab);
		if (retriever.targetVocabFile != null) {
			try (BufferedWriter out = new BufferedWriter(new FileWriter(
					retriever.targetVocabFile))) {
				for (Set<Symbol> words : retriever.targetSideVocab.subList(1,
						retriever.targetSideVocab.size())) {
					// One line per test sentence, listing its
					// target-side vocabulary.
					for (Symbol word : words) {
						out.write(" " + word);
					}
					out.write("\n");
				}
			}
		}
	}

	// Members whose bodies are not shown above (signatures from the
	// class's member index):
	//   Map<Integer, Double> getDefaultPassThroughRuleFeatures()
	//   int[] getFeatureIndices(Feature... features)
	//   List<Feature> getFeatures()
	//   Map<Integer, Double> getDefaultGlueStartOrEndFeatures()
	//   Map<Integer, Double> getDefaultOOVFeatures()
	//   Map<Integer, Double> getDefaultGlueFeatures()
	//   Map<Integer, Double> getDefaultDeleteGlueFeatures()
	//   Map<Integer, Double> getDefaultDeletionFeatures()
	//   void writeGlueRules(BufferedWriter out)
}