16 package uk.ac.cam.eng.extraction.hadoop.features.phrase;
18 import java.io.ByteArrayOutputStream;
19 import java.io.DataOutputStream;
20 import java.io.IOException;
21 import java.util.ArrayList;
22 import java.util.Iterator;
23 import java.util.List;
24 import java.util.Map.Entry;
26 import org.apache.hadoop.io.ByteWritable;
27 import org.apache.hadoop.io.DataInputBuffer;
28 import org.apache.hadoop.io.IntWritable;
29 import org.apache.hadoop.io.WritableComparable;
30 import org.apache.hadoop.io.WritableComparator;
31 import org.apache.hadoop.io.WritableUtils;
32 import org.apache.hadoop.mapreduce.Reducer;
// Reducer whose input keys are Rule objects with ProvenanceCountMap values and
// whose output is Rule keys paired with FeatureMap values (see the type
// parameters below).
// NOTE(review): this listing carries the original file's line numbers inline
// and omits many lines; comments here describe only the visible statements.
50 class MarginalReducer
extends 51 Reducer<Rule, ProvenanceCountMap, Rule, FeatureMap> {
// Holder for byte offsets/lengths within a serialized rule. findSplits assigns
// its sourceStart, sourceLength, targetStart and targetLength members; the
// member declarations themselves are not visible in this listing.
60 private static class RuleWritableSplit {
// Scratch objects reused by MRComparator: one DataInputBuffer and one
// RuleWritableSplit for each of the two records being compared. Instances are
// stored in MRComparator's ThreadLocal, so each thread gets its own copy.
75 private static class MRComparatorState {
// Input buffer that findSplits resets onto the raw record bytes.
77 DataInputBuffer inBytes =
new DataInputBuffer();
// Split offsets computed for the first record (b1).
79 RuleWritableSplit split1 =
new RuleWritableSplit();
// Split offsets computed for the second record (b2).
81 RuleWritableSplit split2 =
new RuleWritableSplit();
// Abstract WritableComparator that compares serialized rules by their source
// and target byte ranges. Subclasses choose which side is compared first by
// implementing isSource2Target().
92 public static abstract class MRComparator
extends WritableComparator {
// Per-thread scratch state; populated on demand by getState().
94 private ThreadLocal<MRComparatorState> threadLocalState =
new ThreadLocal<>();
// Constructor; its body is not visible in this listing.
96 public MRComparator() {
// When this returns true the byte-level compare orders by source first and
// target second; otherwise by target first and source second.
100 protected abstract boolean isSource2Target();
// Returns the calling thread's MRComparatorState from threadLocalState.
// The visible lines create a new state and store it in the ThreadLocal;
// presumably this happens only when get() returned null (the guarding
// condition on original line 104 is not visible - confirm).
102 private MRComparatorState getState() {
103 MRComparatorState state = threadLocalState.get();
105 state =
new MRComparatorState();
106 threadLocalState.set(state);
// Object-level comparison: serializes both arguments (cast to Rule) through a
// DataOutputStream backed by a ByteArrayOutputStream and delegates to the
// byte-array compare overload.
111 @SuppressWarnings(
"rawtypes")
113 public int compare(WritableComparable a, WritableComparable b) {
114 ByteArrayOutputStream bytes =
new ByteArrayOutputStream();
115 DataOutputStream out =
new DataOutputStream(bytes);
// Serialize a and snapshot the bytes written so far.
116 ((Rule)a).write(out);
117 byte[] bytesA = bytes.toByteArray();
// NOTE(review): original line 118 is not visible here. Both rules are written
// to the same stream, so unless that line resets `bytes`, bytesB will hold
// a's bytes followed by b's - confirm against the full file.
119 ((Rule)b).write(out);
120 byte[] bytesB = bytes.toByteArray();
121 return compare(bytesA, 0, bytesA.length, bytesB, 0, bytesB.length);
// Locates the source and target byte ranges of a serialized rule and records
// them in `split`.
//   b, s, l  - backing array, start offset and length of the record
//   split    - receives sourceStart/sourceLength/targetStart/targetLength
//   inBytes  - reusable buffer, reset onto (b, s, l) to read the leading vint
// An IOException raised while reading is rethrown wrapped in a
// RuntimeException.
// NOTE(review): several original lines (e.g. 126-127, 129, 135, 139) are not
// visible, so the two decodeVIntSize assignments below are only partially
// shown.
124 private void findSplits(byte[] b,
int s,
int l,
125 RuleWritableSplit split, DataInputBuffer inBytes) {
128 split.sourceStart = WritableUtils.decodeVIntSize(b[s])
130 inBytes.reset(b, s, l);
131 int pointer = split.sourceStart;
// len is the first vint of the record; the loop below steps `pointer` over
// `len` consecutive vints by adding each one's encoded size.
132 int len = WritableUtils.readVInt(inBytes);
133 for(
int i=0; i< len; ++i){
134 pointer += WritableUtils.decodeVIntSize(b[pointer]);
// Bytes covered by the stepped-over vints form the source range.
136 split.sourceLength = pointer - split.sourceStart;
137 int targetN = pointer;
138 split.targetStart = WritableUtils.decodeVIntSize(b[targetN])
// Target range is taken to run from targetStart to the end of the record.
140 split.targetLength = l - (split.targetStart - s);
141 }
catch (IOException e) {
142 throw new RuntimeException(
143 "MRComparator should not throw this exception", e);
// Compares the source byte ranges of two records with compareBytes, using the
// sourceStart/sourceLength values held in s1 and s2.
147 private int sourceCompare(byte[] b1, RuleWritableSplit s1, byte[] b2,
148 RuleWritableSplit s2) {
149 return compareBytes(b1, s1.sourceStart, s1.sourceLength, b2,
150 s2.sourceStart, s2.sourceLength);
// Compares the target byte ranges of two records with compareBytes, using the
// targetStart/targetLength values held in s1 and s2.
153 private int targetCompare(byte[] b1, RuleWritableSplit s1, byte[] b2,
154 RuleWritableSplit s2) {
155 return compareBytes(b1, s1.targetStart, s1.targetLength, b2,
156 s2.targetStart, s2.targetLength);
// Raw-bytes comparison of two serialized rules.
// Computes the source/target splits of both records using the calling
// thread's scratch state, then compares one side first and the other second:
// source then target when isSource2Target() is true, target then source
// otherwise. Returns the second comparison's result on the visible return
// path.
// NOTE(review): the declarations of firstCompare/secondCompare, the else
// keywords, and the action taken when firstCompare != 0 (presumably returning
// it) fall on lines not visible in this listing - confirm.
160 public int compare(byte[] b1,
int s1,
int l1, byte[] b2,
int s2,
int l2) {
161 MRComparatorState state = getState();
162 RuleWritableSplit split1 = state.split1;
163 RuleWritableSplit split2 = state.split2;
// The same inBytes buffer is passed to both calls; each call resets it.
164 findSplits(b1, s1, l1, split1, state.inBytes);
165 findSplits(b2, s2, l2, split2, state.inBytes);
169 if (isSource2Target()) {
170 firstCompare = sourceCompare(b1, split1, b2, split2);
172 firstCompare = targetCompare(b1, split1, b2, split2);
174 if (firstCompare != 0) {
// Second comparison uses the side that was not compared first.
177 if (isSource2Target()) {
178 secondCompare = targetCompare(b1, split1, b2, split2);
180 secondCompare = sourceCompare(b1, split1, b2, split2);
183 return secondCompare;
// Pairs a rule with its counts: marginalReduce reads the `rule` and `counts`
// members of each instance. Only the assignment of `counts` is visible in this
// listing; the member declarations and the rest of the constructor are not.
188 private static class RuleCount {
195 this.counts = counts;
// Configuration key read in setup(); its value is parsed as a boolean into
// source2Target.
200 public static final String SOURCE_TO_TARGET =
"rulextract.source2target";
// Rules buffered by run() until they are passed to marginalReduce.
204 private List<RuleCount> ruleCounts =
new ArrayList<>();
// Symbols returned by getMarginal for the group currently being accumulated
// in run().
206 private List<Symbol> marginal =
new ArrayList<>();
// Direction flag; initialised to true here and overwritten in setup() from
// the SOURCE_TO_TARGET configuration value.
208 private boolean source2Target =
true;
// Returns either the rule's source or its target symbols.
// NOTE(review): the condition choosing between the two returns is on lines
// not visible here; presumably it is the source2Target flag - confirm.
220 private List<Symbol> getMarginal(Rule rule) {
222 return rule.getSource();
224 return rule.getTarget();
// Reducer setup hook: after delegating to super.setup, reads the
// SOURCE_TO_TARGET configuration value, throws a RuntimeException when it is
// absent, and otherwise parses it with Boolean.valueOf into source2Target.
// NOTE(review): the remainder of the exception message (original lines
// 235-236) is not visible in this listing.
229 protected void setup(Context context)
throws IOException,
230 InterruptedException {
231 super.setup(context);
232 String s2tString = context.getConfiguration().get(SOURCE_TO_TARGET);
233 if (s2tString == null) {
234 throw new RuntimeException(
"Need to set configuration value " 237 source2Target = Boolean.valueOf(s2tString);
// Emits one output record per buffered rule.
// For every (key, count) entry of a rule's counts, computes
// count / totals.get(key) as a double and stores its natural log under the
// integer form of the key in globalProb and/or provProbs. Both maps are then
// put into `features` under globalF and provF, and (outKey, features) is
// written to the context.
// NOTE(review): the remaining parameters (original line 248), the creation of
// features/globalProb/provProbs/globalF/provF, and the conditions deciding
// which map receives a given key (lines 258, 260-261, 263-264) are not visible
// in this listing.
247 private void marginalReduce(Iterable<RuleCount> rules,
249 InterruptedException {
250 for (RuleCount rc : rules) {
254 for (Entry<ByteWritable, IntWritable> entry : rc.counts.entrySet()) {
// Relative frequency of this rule's count against the total for the same key.
255 double probability = (double) entry.getValue().get()
256 / (double) totals.
get(entry.getKey()).
get();
257 int key = (int) entry.getKey().get();
259 globalProb.
put(key, Math.log(probability));
262 provProbs.
put(key, Math.log(probability));
265 features.put(globalF, globalProb);
266 features.put(provF, provProbs);
267 Rule outKey = rc.rule;
// When not running source-to-target, a rule reporting isSwapping() is replaced
// by the result of invertNonTerminals() before being written.
268 if (!source2Target) {
269 if (outKey.isSwapping()) {
270 outKey = outKey.invertNonTerminals();
273 context.
write(outKey, features);
// Custom reducer driver loop. Iterates over keys via context.nextKey(),
// buffering rules in ruleCounts while their marginal (see getMarginal) matches
// the current one; when a key with a different marginal arrives, the buffered
// rules are passed to marginalReduce. A final marginalReduce call handles the
// rules still buffered after the last key.
// NOTE(review): the creation/updating of `totals`, the condition guarding the
// "Non-unique rule!" exception, and the statements on original lines 296-299
// (presumably clearing the buffers before line 300 refills `marginal`) are not
// visible in this listing - confirm.
279 public void run(Context context)
throws IOException, InterruptedException {
282 while (context.nextKey()) {
283 Rule key = context.getCurrentKey();
// First key seen: initialise the current marginal.
285 if (marginal.size() == 0) {
286 marginal.addAll(getMarginal(key));
288 Iterator<ProvenanceCountMap> it = context.getValues().iterator();
292 throw new RuntimeException(
"Non-unique rule! " + key);
// Marginal changed: emit everything buffered for the previous marginal.
294 if (!marginal.equals(getMarginal(key))) {
295 marginalReduce(ruleCounts, totals, context);
300 marginal.addAll(getMarginal(key));
// A copy of the key is buffered (new Rule(key)) rather than the key object
// itself.
301 ruleCounts.add(
new RuleCount(
new Rule(key),
// Emit the last buffered group.
305 marginalReduce(ruleCounts, totals, context);
void run(ucam::util::RegistryPO const &rg)
SOURCE2TARGET_PROBABILITY
PROVENANCE_SOURCE2TARGET_PROBABILITY
TARGET2SOURCE_PROBABILITY
PROVENANCE_TARGET2SOURCE_PROBABILITY