Cambridge SMT System
RuleFilter.java
Go to the documentation of this file.
1 /*******************************************************************************
2  * Licensed under the Apache License, Version 2.0 (the "License");
3  * you may not use these files except in compliance with the License.
4  * You may obtain a copy of the License at
5  *
6  * http://www.apache.org/licenses/LICENSE-2.0
7  *
8  * Unless required by applicable law or agreed to in writing, software
9  * distributed under the License is distributed on an "AS IS" BASIS,
10  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11  * See the License for the specific language governing permissions and
12  * limitations under the License.
13  *
14  * Copyright 2014 - Juan Pino, Aurelien Waite, William Byrne
15  *******************************************************************************/
20 package uk.ac.cam.eng.rule.filtering;
21 
22 // TODO remove hard coded indices
23 
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.ByteWritable;
import org.apache.hadoop.io.IntWritable;

import uk.ac.cam.eng.extraction.Rule;
import uk.ac.cam.eng.extraction.RuleString;
import uk.ac.cam.eng.util.CLI;
import uk.ac.cam.eng.util.Pair;
54 
62 public class RuleFilter {
63 
64  private static class RuleCountComparator implements
65  Comparator<Pair<Rule, RuleData>> {
66 
67  private final ByteWritable countIndex;
68 
69  public RuleCountComparator(int countIndex) {
70  this.countIndex = new ByteWritable((byte) countIndex);
71  }
72 
73  @Override
74  public int compare(Pair<Rule, RuleData> a, Pair<Rule, RuleData> b) {
75  int aValue = a.getSecond().getProvCounts().containsKey(countIndex) ? a
76  .getSecond().getProvCounts().get(countIndex).get()
77  : 0;
78  int bValue = b.getSecond().getProvCounts().containsKey(countIndex) ? b
79  .getSecond().getProvCounts().get(countIndex).get()
80  : 0;
81  // We want descending order!
82  int countDiff = bValue < aValue ? -1 : (bValue == aValue ? 0 : 1);
83  if (countDiff != 0) {
84  return countDiff;
85  } else {
86  return (a.getFirst().compareTo(b.getFirst()));
87  }
88  }
89  }
90 
91  private static class SourcePhraseConstraint {
92  // Number of Occurrences
93  final int nOcc;
94  // Number of Translations
95  final int nTrans;
96 
97  public SourcePhraseConstraint(String nOcc, String nTrans) {
98  this.nOcc = Integer.parseInt(nOcc);
99  this.nTrans = Integer.parseInt(nTrans);
100  }
101  }
102 
103  final private double minSource2TargetPhraseLog;
104  final private double minTarget2SourcePhraseLog;
105  final private double minSource2TargetRuleLog;
106  final private double minTarget2SourceRuleLog;
107  // allowed patterns
108  private Set<RulePattern> allowedPatterns = new HashSet<RulePattern>();
109 
110  private Map<SidePattern, SourcePhraseConstraint> sourcePatternConstraints = new HashMap<>();
111 
112  boolean provenanceUnion;
113 
114  public RuleFilter(CLI.FilterParams params, Configuration conf)
115  throws FileNotFoundException, IOException {
116  provenanceUnion = params.provenanceUnion;
117  minSource2TargetPhraseLog = Math.log(params.minSource2TargetPhrase);
118  minTarget2SourcePhraseLog = Math.log(params.minTarget2SourcePhrase);
119  minSource2TargetRuleLog = Math.log(params.minSource2TargetRule);
120  minTarget2SourceRuleLog = Math.log(params.minTarget2SourceRule);
121  loadConfig(params.allowedPatternsFile, conf,
122  line -> allowedPatterns.add(RulePattern.parsePattern(line)));
123  loadConfig(
124  params.sourcePatterns,
125  conf,
126  line -> {
127  String[] parts = line.split(" ");
128  if (parts.length != 3) {
129  throw new RuntimeException(
130  "line should have 3 fields (source pattern, # of occurances, # of translations): "
131  + line);
132  }
133  sourcePatternConstraints.put(
134  SidePattern.parsePattern(parts[0]),
135  new SourcePhraseConstraint(parts[1], parts[2]));
136  });
137  }
138 
139  private void loadConfig(String fileName, Configuration conf,
140  Consumer<String> block) throws FileNotFoundException, IOException {
141  Path path = new Path(fileName);
142  FileSystem fs = path.getFileSystem(conf);
143  try (BufferedReader br = new BufferedReader(new InputStreamReader(
144  fs.open(path)))) {
145  for (String line = br.readLine(); line != null; line = br
146  .readLine()) {
147  if (line.startsWith("#") || line.isEmpty()) {
148  continue;
149  }
150  block.accept(line);
151  }
152  }
153  }
154 
155  public boolean filterSource(RuleString source) {
156  SidePattern sourcePattern = source.toPattern();
157  if (sourcePattern.isPhrase()) {
158  return false;
159  } else if (sourcePatternConstraints.containsKey(sourcePattern)) {
160  return false;
161  }
162  return true;
163  }
164 
165  private Map<Rule, RuleData> filterRulesBySource(SidePattern sourcePattern,
166  List<Pair<Rule, RuleData>> rules, int provMapping) {
167  Map<Rule, RuleData> results = new HashMap<>();
168  // If the source side is a phrase, then we want everything
169  if (sourcePattern.isPhrase()) {
170  rules.forEach(entry -> results.put(entry.getFirst(),
171  entry.getSecond()));
172  return results;
173  }
174  int numberTranslations = 0;
175  int numberTranslationsMonotone = 0; // case with more than 1 NT
176  int numberTranslationsInvert = 0;
177  int prevCount = -1;
178  double nTransConstraint = sourcePatternConstraints.get(sourcePattern).nTrans;
179  ByteWritable countIndex = new ByteWritable((byte) provMapping);
180  for (Pair<Rule, RuleData> entry : rules) {
181  // number of translations per source threshold
182  // in case of ties we either keep or don't keep the ties
183  // depending on the config
184  RulePattern rulePattern = RulePattern.getPattern(entry.getFirst());
185  int count = (int) entry.getSecond().getProvCounts().get(countIndex)
186  .get();
187  boolean notTied = count != prevCount;
188  boolean moreThan1NT = sourcePattern.hasMoreThan1NT();
189  if (notTied
190  && ((moreThan1NT
191  && nTransConstraint <= numberTranslationsMonotone && nTransConstraint <= numberTranslationsInvert) || (!moreThan1NT && nTransConstraint <= numberTranslations))) {
192  break;
193  }
194  results.put(entry.getFirst(), entry.getSecond());
195  if (moreThan1NT) {
196  if (rulePattern.isSwappingNT()) {
197  ++numberTranslationsInvert;
198  } else {
199  ++numberTranslationsMonotone;
200  }
201  }
202  numberTranslations++;
203  prevCount = count;
204  }
205  return results;
206  }
207 
214  private boolean filterRule(Feature s2t, Feature t2s,
215  SidePattern sourcePattern, Rule rule, RuleData data, int provMapping) {
216  IntWritable provIW = IntWritableCache.createIntWritable(provMapping);
217  // Immediately filter if there is data for this rule under this
218  // provenance
219  if (!data.getFeatures().get(s2t).containsKey(provIW)) {
220  return true;
221  }
222  RulePattern rulePattern = RulePattern.getPattern(rule);
223  if (!(sourcePattern.isPhrase() || allowedPatterns.contains(rulePattern))) {
224  return true;
225  }
226  double source2targetProbability = data.getFeatures().get(s2t)
227  .get(provIW).get();
228  double target2sourceProbability = data.getFeatures().get(t2s)
229  .get(provIW).get();
230  int numberOfOccurrences = (int) data.getProvCounts()
231  .get(new ByteWritable((byte) provMapping)).get();
232 
233  if (sourcePattern.isPhrase()) {
234  // source-to-target threshold
235  if (source2targetProbability <= minSource2TargetPhraseLog) {
236  return true;
237  }
238  // target-to-source threshold
239  if (target2sourceProbability <= minTarget2SourcePhraseLog) {
240  return true;
241  }
242  } else {
243  // source-to-target threshold
244  if (source2targetProbability <= minSource2TargetRuleLog) {
245  return true;
246  }
247  // target-to-source threshold
248  if (target2sourceProbability <= minTarget2SourceRuleLog) {
249  return true;
250  }
251  // minimum number of occurrence threshold
252  if (numberOfOccurrences < sourcePatternConstraints
253  .get(sourcePattern).nOcc) {
254  return true;
255  }
256  }
257  return false;
258  }
259 
260  public Map<Rule, RuleData> filter(SidePattern sourcePattern,
261  List<Pair<Rule, RuleData>> toFilter) {
262  // Establish the provenances used in these rules
263  Set<Integer> provenances = new HashSet<>();
264  if (!provenanceUnion) {
265  provenances.add(0);
266  } else {
267  for (Pair<?, RuleData> entry : toFilter) {
268  for (ByteWritable prov : entry.getSecond().getProvCounts()
269  .keySet()) {
270  provenances.add((int) prov.get());
271  }
272  }
273  }
274  Map<Rule, RuleData> filtered = new HashMap<>();
275  for (int i : provenances) {
280  List<Pair<Rule, RuleData>> rules = new LinkedList<>();
281  Collections.sort(toFilter, new RuleCountComparator(i));
282  for (Pair<Rule, RuleData> entry : toFilter) {
283  Rule rule = entry.getFirst();
284  RuleData data = entry.getSecond();
285  if (!filterRule(s2t, t2s, sourcePattern, rule, data, i)) {
286  rules.add(Pair.createPair(new Rule(rule), data));
287  }
288  }
289  filtered.putAll(filterRulesBySource(sourcePattern, rules, i));
290  }
291  return filtered;
292  }
293 
294  public Collection<SidePattern> getPermittedSourcePatterns() {
295  return sourcePatternConstraints.keySet();
296  }
297 }
static< F, S > Pair< F, S > createPair(F first, S second)
Definition: Pair.java:46
Map< Rule, RuleData > filter(SidePattern sourcePattern, List< Pair< Rule, RuleData >> toFilter)
static RulePattern getPattern(Rule rule)
boolean filterSource(RuleString source)
static RulePattern parsePattern(String patternString)
Collection< SidePattern > getPermittedSourcePatterns()
static SidePattern parsePattern(String patternString)
RuleFilter(CLI.FilterParams params, Configuration conf)