20 package uk.ac.cam.eng.rule.filtering;
24 import java.io.BufferedReader;
25 import java.io.FileNotFoundException;
26 import java.io.IOException;
27 import java.io.InputStreamReader;
28 import java.util.Collection;
29 import java.util.Collections;
30 import java.util.Comparator;
31 import java.util.HashMap;
32 import java.util.HashSet;
33 import java.util.LinkedList;
34 import java.util.List;
37 import java.util.function.Consumer;
39 import org.apache.hadoop.conf.Configuration;
40 import org.apache.hadoop.fs.FileSystem;
41 import org.apache.hadoop.fs.Path;
42 import org.apache.hadoop.io.ByteWritable;
43 import org.apache.hadoop.io.IntWritable;
// Comparator over (Rule, RuleData) pairs that orders them by the provenance
// count stored under one provenance index, largest count first.
64 private static class RuleCountComparator
implements 65 Comparator<Pair<Rule, RuleData>> {
// Key looked up in RuleData.getProvCounts().
67 private final ByteWritable countIndex;
// @param countIndex provenance index to compare on. It is narrowed with a
//        (byte) cast, so values outside [-128, 127] wrap around.
//        NOTE(review): confirm callers never pass indices that large.
69 public RuleCountComparator(
int countIndex) {
70 this.countIndex =
new ByteWritable((byte) countIndex);
// a's count under countIndex when that key is present; the fallback value
// for a missing key is on a line not included in this excerpt.
75 int aValue = a.
getSecond().getProvCounts().containsKey(countIndex) ? a
76 .
getSecond().getProvCounts().get(countIndex).get()
// Same lookup for b.
78 int bValue = b.
getSecond().getProvCounts().containsKey(countIndex) ? b
79 .
getSecond().getProvCounts().get(countIndex).get()
// Descending order: -1 when a has the larger count, 1 when b does, 0 on a
// tie (equivalent to Integer.compare(bValue, aValue)). Any tie-breaking
// that follows is on lines not visible here.
82 int countDiff = bValue < aValue ? -1 : (bValue == aValue ? 0 : 1);
// Per-source-pattern constraint parsed from one line of the source-patterns
// config: nOcc comes from the "# of occurrences" field and nTrans from the
// "# of translations" field. filterRule compares a rule's occurrence count
// against nOcc (with <), and filterRulesBySource compares nTrans against its
// running translation counts.
91 private static class SourcePhraseConstraint {
// @param nOcc   decimal string for the occurrence threshold
// @param nTrans decimal string for the translation-count threshold
// @throws NumberFormatException if either string is not a valid int
97 public SourcePhraseConstraint(String nOcc, String nTrans) {
98 this.nOcc = Integer.parseInt(nOcc);
99 this.nTrans = Integer.parseInt(nTrans);
// Minimum probability thresholds from the CLI parameters, stored as natural
// logs by the constructor. filterRule compares rule feature values against
// them with <=. The *Phrase* pair is presumably used for phrase source
// patterns and the *Rule* pair otherwise; the branch that selects between
// them is not visible in this excerpt.
103 final private double minSource2TargetPhraseLog;
104 final private double minTarget2SourcePhraseLog;
105 final private double minSource2TargetRuleLog;
106 final private double minTarget2SourceRuleLog;
// Rule patterns permitted for non-phrase source patterns; filterRule skips
// this check when the source pattern is a phrase.
108 private Set<RulePattern> allowedPatterns =
new HashSet<RulePattern>();
// Constraints keyed by source pattern, loaded from params.sourcePatterns.
// Its key set is what getPermittedSourcePatterns() returns.
110 private Map<SidePattern, SourcePhraseConstraint> sourcePatternConstraints =
new HashMap<>();
// Copied from params.provenanceUnion; filter() branches on it when building
// the set of provenance indices to filter over.
112 boolean provenanceUnion;
// Constructor tail; the RuleFilter(CLI.FilterParams, Configuration) header
// line is outside this excerpt.
// Converts the four minimum-probability parameters to natural logs so they
// can be compared directly with rule feature values. Note that Math.log(0)
// is -Infinity and Math.log of a negative value is NaN.
// @throws FileNotFoundException / IOException if a config file cannot be read
115 throws FileNotFoundException, IOException {
116 provenanceUnion = params.provenanceUnion;
117 minSource2TargetPhraseLog = Math.log(params.minSource2TargetPhrase);
118 minTarget2SourcePhraseLog = Math.log(params.minTarget2SourcePhrase);
119 minSource2TargetRuleLog = Math.log(params.minSource2TargetRule);
120 minTarget2SourceRuleLog = Math.log(params.minTarget2SourceRule);
// Allowed rule patterns file; its per-line handler is on lines not included
// in this excerpt.
121 loadConfig(params.allowedPatternsFile, conf,
// Source-pattern constraints file.
124 params.sourcePatterns,
// Each line must be exactly three fields separated by single spaces:
// source pattern, occurrence threshold, translation threshold.
// split(" ") does not collapse runs of spaces, so repeated spaces produce
// empty fields and trigger the exception below.
127 String[] parts = line.split(
" ");
128 if (parts.length != 3) {
129 throw new RuntimeException(
// NOTE(review): "occurances" in this message misspells "occurrences".
130 "line should have 3 fields (source pattern, # of occurances, # of translations): " 133 sourcePatternConstraints.put(
135 new SourcePhraseConstraint(parts[1], parts[2]));
// Opens fileName on the Hadoop FileSystem that conf resolves for that path
// and reads it line by line. The reader is closed by try-with-resources.
// Lines are presumably handed to block; the hand-off is on lines not
// included in this excerpt.
// Lines starting with "#" and empty lines are singled out, presumably to be
// skipped. The test runs on the raw line, so indented comments
// ("  # x") and whitespace-only lines do not match it.
// @param fileName path string, wrapped in new Path(fileName)
// @param conf     Hadoop configuration used to resolve the FileSystem
// @param block    per-line callback
// @throws IOException if the FileSystem or the file cannot be opened or read
139 private void loadConfig(String fileName, Configuration conf,
140 Consumer<String> block)
throws FileNotFoundException, IOException {
141 Path path =
new Path(fileName);
142 FileSystem fs = path.getFileSystem(conf);
143 try (BufferedReader br =
new BufferedReader(
new InputStreamReader(
145 for (String line = br.readLine(); line != null; line = br
147 if (line.startsWith(
"#") || line.isEmpty()) {
159 }
// Taken when sourcePattern has an entry in sourcePatternConstraints. The
// enclosing method (presumably filterSource) lies mostly outside this
// excerpt.
else if (sourcePatternConstraints.containsKey(sourcePattern)) {
// For one provenance, keeps the leading translations of a single source
// pattern, subject to that pattern's nTrans constraint. The remaining
// parameters are outside this excerpt. filter() calls it as
// filterRulesBySource(sourcePattern, rules, i), where rules holds entries
// that passed filterRule, taken in the order of toFilter after toFilter was
// sorted by RuleCountComparator(i), i.e. by descending count.
165 private Map<Rule, RuleData> filterRulesBySource(
SidePattern sourcePattern,
167 Map<Rule, RuleData> results =
new HashMap<>();
// Puts every rule into results; the condition guarding this path and the
// value argument are on lines not visible here.
170 rules.forEach(entry -> results.put(entry.getFirst(),
// Running counts of translations kept so far: overall, and separately for
// the monotone and inverted variants.
174 int numberTranslations = 0;
175 int numberTranslationsMonotone = 0;
176 int numberTranslationsInvert = 0;
// NOTE(review): get(...) returns null, and so .nTrans throws
// NullPointerException, if sourcePattern has no constraint. Presumably this
// is guarded by a check that is not visible here.
178 double nTransConstraint = sourcePatternConstraints.get(sourcePattern).nTrans;
// Provenance key; the int is narrowed to a byte, as in RuleCountComparator.
179 ByteWritable countIndex =
new ByteWritable((byte) provMapping);
// NOTE(review): if this rule has no count for countIndex, get(...) returns
// null and the dereference that follows presumably throws.
185 int count = (int) entry.getSecond().getProvCounts().get(countIndex)
// True when this rule's count differs from the previous rule's count.
// Presumably used so that rules tied at the cutoff are kept together.
187 boolean notTied = count != prevCount;
// Limit-reached test: when moreThan1NT holds (presumably "rule has more than
// one nonterminal"), both the monotone and inverted counts must have reached
// nTrans; otherwise the overall count must have. The start of this condition
// and the action taken when it holds are not visible.
191 && nTransConstraint <= numberTranslationsMonotone && nTransConstraint <= numberTranslationsInvert) || (!moreThan1NT && nTransConstraint <= numberTranslations))) {
194 results.put(entry.getFirst(), entry.getSecond());
197 ++numberTranslationsInvert;
199 ++numberTranslationsMonotone;
202 numberTranslations++;
// Fragment of filterRule(s2t, t2s, sourcePattern, rule, data, provMapping).
// filter() only proceeds with a rule when this method returns false. The
// return statements for each check below are outside this excerpt.
// Check: the rule has no source-to-target feature value for provIW.
219 if (!data.
getFeatures().get(s2t).containsKey(provIW)) {
// Check: non-phrase source patterns must match one of the allowed rule
// patterns.
223 if (!(sourcePattern.
isPhrase() || allowedPatterns.contains(rulePattern))) {
// Feature values in each direction for this provenance. They are compared
// below against the log-space thresholds, so they are presumably
// log-probabilities.
226 double source2targetProbability = data.
getFeatures().get(s2t)
228 double target2sourceProbability = data.
getFeatures().get(t2s)
231 .
get(
new ByteWritable((byte) provMapping)).get();
// Probability thresholds: the phrase pair and the rule pair. The branch that
// selects between them is not visible.
235 if (source2targetProbability <= minSource2TargetPhraseLog) {
239 if (target2sourceProbability <= minTarget2SourcePhraseLog) {
244 if (source2targetProbability <= minSource2TargetRuleLog) {
248 if (target2sourceProbability <= minTarget2SourceRuleLog) {
// Occurrence threshold for this source pattern.
// NOTE(review): throws NullPointerException if sourcePattern has no entry in
// sourcePatternConstraints.
252 if (numberOfOccurrences < sourcePatternConstraints
253 .
get(sourcePattern).nOcc) {
// Body of filter(sourcePattern, toFilter): filters the candidate rules for
// one source pattern separately for each provenance index, then merges the
// per-provenance results.
// Provenance indices to process. The !provenanceUnion path is on lines not
// visible here. The loop below adds every provenance index found in a rule's
// provenance counts.
263 Set<Integer> provenances =
new HashSet<>();
264 if (!provenanceUnion) {
268 for (ByteWritable prov : entry.getSecond().getProvCounts()
270 provenances.add((
int) prov.get());
// Merged result. NOTE(review): with putAll, a rule kept for several
// provenances ends up with the RuleData from whichever provenance is
// processed last, and HashSet iteration order is unspecified.
274 Map<Rule, RuleData> filtered =
new HashMap<>();
275 for (
int i : provenances) {
280 List<Pair<Rule, RuleData>> rules =
new LinkedList<>();
// NOTE(review): sorts the caller's toFilter list in place (descending count
// for provenance i). This is a visible side effect, and it throws
// UnsupportedOperationException if the list is unmodifiable.
281 Collections.sort(toFilter,
new RuleCountComparator(i));
283 Rule rule = entry.getFirst();
// Only rules for which filterRule returns false continue. What happens to
// them next is on lines not visible here.
285 if (!filterRule(s2t, t2s, sourcePattern, rule, data, i)) {
289 filtered.putAll(filterRulesBySource(sourcePattern, rules, i));
// Source patterns that have a configured constraint. This is the map's live
// keySet view, so removing an element from it also removes that constraint.
295 return sourcePatternConstraints.keySet();
static< F, S > Pair< F, S > createPair(F first, S second)
SOURCE2TARGET_PROBABILITY
Map< Rule, RuleData > filter(SidePattern sourcePattern, List< Pair< Rule, RuleData >> toFilter)
static RulePattern getPattern(Rule rule)
PROVENANCE_SOURCE2TARGET_PROBABILITY
boolean filterSource(RuleString source)
TARGET2SOURCE_PROBABILITY
static RulePattern parsePattern(String patternString)
Collection< SidePattern > getPermittedSourcePatterns()
static SidePattern parsePattern(String patternString)
RuleFilter(CLI.FilterParams params, Configuration conf)
PROVENANCE_TARGET2SOURCE_PROBABILITY