Cambridge SMT System
MarginalReducer.java
Go to the documentation of this file.
1 /*******************************************************************************
2  * Licensed under the Apache License, Version 2.0 (the "License");
3  * you may not use these files except in compliance with the License.
4  * You may obtain a copy of the License at
5  *
6  * http://www.apache.org/licenses/LICENSE-2.0
7  *
8  * Unless required by applicable law or agreed to in writing, software
9  * distributed under the License is distributed on an "AS IS" BASIS,
10  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11  * See the License for the specific language governing permissions and
12  * limitations under the License.
13  *
14  * Copyright 2014 - Juan Pino, Aurelien Waite, William Byrne
15  *******************************************************************************/
16 package uk.ac.cam.eng.extraction.hadoop.features.phrase;
17 
18 import java.io.ByteArrayOutputStream;
19 import java.io.DataOutputStream;
20 import java.io.IOException;
21 import java.util.ArrayList;
22 import java.util.Iterator;
23 import java.util.List;
24 import java.util.Map.Entry;
25 
26 import org.apache.hadoop.io.ByteWritable;
27 import org.apache.hadoop.io.DataInputBuffer;
28 import org.apache.hadoop.io.IntWritable;
29 import org.apache.hadoop.io.WritableComparable;
30 import org.apache.hadoop.io.WritableComparator;
31 import org.apache.hadoop.io.WritableUtils;
32 import org.apache.hadoop.mapreduce.Reducer;
33 
34 import uk.ac.cam.eng.extraction.Rule;
35 import uk.ac.cam.eng.extraction.Symbol;
40 
50 class MarginalReducer extends
51  Reducer<Rule, ProvenanceCountMap, Rule, FeatureMap> {
52 
60  private static class RuleWritableSplit {
61  int sourceStart;
62  int sourceLength;
63  int targetStart;
64  int targetLength;
65  }
66 
75  private static class MRComparatorState {
76 
77  DataInputBuffer inBytes = new DataInputBuffer();
78 
79  RuleWritableSplit split1 = new RuleWritableSplit();
80 
81  RuleWritableSplit split2 = new RuleWritableSplit();
82 
83  }
84 
92  public static abstract class MRComparator extends WritableComparator {
93 
94  private ThreadLocal<MRComparatorState> threadLocalState = new ThreadLocal<>();
95 
96  public MRComparator() {
97  super(Rule.class);
98  }
99 
100  protected abstract boolean isSource2Target();
101 
102  private MRComparatorState getState() {
103  MRComparatorState state = threadLocalState.get();
104  if (state == null) {
105  state = new MRComparatorState();
106  threadLocalState.set(state);
107  }
108  return state;
109  }
110 
111  @SuppressWarnings("rawtypes")
112  @Override
113  public int compare(WritableComparable a, WritableComparable b) {
114  ByteArrayOutputStream bytes = new ByteArrayOutputStream();
115  DataOutputStream out = new DataOutputStream(bytes);
116  ((Rule)a).write(out);
117  byte[] bytesA = bytes.toByteArray();
118  bytes.reset();
119  ((Rule)b).write(out);
120  byte[] bytesB = bytes.toByteArray();
121  return compare(bytesA, 0, bytesA.length, bytesB, 0, bytesB.length);
122  }
123 
124  private void findSplits(byte[] b, int s, int l,
125  RuleWritableSplit split, DataInputBuffer inBytes) {
126 
127  try {
128  split.sourceStart = WritableUtils.decodeVIntSize(b[s])
129  + s;
130  inBytes.reset(b, s, l);
131  int pointer = split.sourceStart;
132  int len = WritableUtils.readVInt(inBytes);
133  for(int i=0; i< len; ++i){
134  pointer += WritableUtils.decodeVIntSize(b[pointer]);
135  }
136  split.sourceLength = pointer - split.sourceStart;
137  int targetN = pointer;
138  split.targetStart = WritableUtils.decodeVIntSize(b[targetN])
139  + targetN;
140  split.targetLength = l - (split.targetStart - s);
141  } catch (IOException e) {
142  throw new RuntimeException(
143  "MRComparator should not throw this exception", e);
144  }
145  }
146 
147  private int sourceCompare(byte[] b1, RuleWritableSplit s1, byte[] b2,
148  RuleWritableSplit s2) {
149  return compareBytes(b1, s1.sourceStart, s1.sourceLength, b2,
150  s2.sourceStart, s2.sourceLength);
151  }
152 
153  private int targetCompare(byte[] b1, RuleWritableSplit s1, byte[] b2,
154  RuleWritableSplit s2) {
155  return compareBytes(b1, s1.targetStart, s1.targetLength, b2,
156  s2.targetStart, s2.targetLength);
157  }
158 
159  @Override
160  public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
161  MRComparatorState state = getState();
162  RuleWritableSplit split1 = state.split1;
163  RuleWritableSplit split2 = state.split2;
164  findSplits(b1, s1, l1, split1, state.inBytes);
165  findSplits(b2, s2, l2, split2, state.inBytes);
166 
167  int firstCompare;
168  int secondCompare;
169  if (isSource2Target()) {
170  firstCompare = sourceCompare(b1, split1, b2, split2);
171  } else {
172  firstCompare = targetCompare(b1, split1, b2, split2);
173  }
174  if (firstCompare != 0) {
175  return firstCompare;
176  } else {
177  if (isSource2Target()) {
178  secondCompare = targetCompare(b1, split1, b2, split2);
179  } else {
180  secondCompare = sourceCompare(b1, split1, b2, split2);
181  }
182  }
183  return secondCompare;
184 
185  }
186  }
187 
188  private static class RuleCount {
189 
190  final Rule rule;
191  final ProvenanceCountMap counts;
192 
193  public RuleCount(Rule rule, ProvenanceCountMap counts) {
194  this.rule = rule;
195  this.counts = counts;
196  }
197 
198  }
199 
200  public static final String SOURCE_TO_TARGET = "rulextract.source2target";
201 
202  private ProvenanceCountMap totals = new ProvenanceCountMap();
203 
204  private List<RuleCount> ruleCounts = new ArrayList<>();
205 
206  private List<Symbol> marginal = new ArrayList<>();
207 
208  private boolean source2Target = true;
209 
210  private ProvenanceProbMap provProbs = new ProvenanceProbMap();
211 
212  private ProvenanceProbMap globalProb = new ProvenanceProbMap();
213 
214  private Feature globalF;
215 
216  private Feature provF;
217 
218  private FeatureMap features = new FeatureMap();
219 
220  private List<Symbol> getMarginal(Rule rule) {
221  if (source2Target) {
222  return rule.getSource();
223  } else {
224  return rule.getTarget();
225  }
226  }
227 
228  @Override
229  protected void setup(Context context) throws IOException,
230  InterruptedException {
231  super.setup(context);
232  String s2tString = context.getConfiguration().get(SOURCE_TO_TARGET);
233  if (s2tString == null) {
234  throw new RuntimeException("Need to set configuration value "
235  + SOURCE_TO_TARGET);
236  }
237  source2Target = Boolean.valueOf(s2tString);
238  if (source2Target) {
241  } else {
244  }
245  }
246 
247  private void marginalReduce(Iterable<RuleCount> rules,
248  ProvenanceCountMap totals, Context context) throws IOException,
249  InterruptedException {
250  for (RuleCount rc : rules) {
251  provProbs.clear();
252  globalProb.clear();
253  features.clear();
254  for (Entry<ByteWritable, IntWritable> entry : rc.counts.entrySet()) {
255  double probability = (double) entry.getValue().get()
256  / (double) totals.get(entry.getKey()).get();
257  int key = (int) entry.getKey().get();
258  if(key==0){
259  globalProb.put(key, Math.log(probability));
260  }
261  else{
262  provProbs.put(key, Math.log(probability));
263  }
264  }
265  features.put(globalF, globalProb);
266  features.put(provF, provProbs);
267  Rule outKey = rc.rule;
268  if (!source2Target) {
269  if (outKey.isSwapping()) {
270  outKey = outKey.invertNonTerminals();
271  }
272  }
273  context.write(outKey, features);
274  }
275 
276  }
277 
278  @Override
279  public void run(Context context) throws IOException, InterruptedException {
280  setup(context);
281 
282  while (context.nextKey()) {
283  Rule key = context.getCurrentKey();
284  // First Key!
285  if (marginal.size() == 0) {
286  marginal.addAll(getMarginal(key));
287  }
288  Iterator<ProvenanceCountMap> it = context.getValues().iterator();
289  ProvenanceCountMap counts = it.next();
290 
291  if (it.hasNext()) {
292  throw new RuntimeException("Non-unique rule! " + key);
293  }
294  if (!marginal.equals(getMarginal(key))) {
295  marginalReduce(ruleCounts, totals, context);
296  totals.clear();
297  ruleCounts.clear();
298  }
299  marginal.clear();
300  marginal.addAll(getMarginal(key));
301  ruleCounts.add(new RuleCount(new Rule(key),
302  new ProvenanceCountMap(counts)));
303  totals.increment(counts);
304  }
305  marginalReduce(ruleCounts, totals, context);
306  cleanup(context);
307  }
308 }
DoubleWritable put(IntWritable key, DoubleWritable value)
void run(ucam::util::RegistryPO const &rg)