Cambridge SMT System
TTableServer.java
Go to the documentation of this file.
1 /*******************************************************************************
2  * Licensed under the Apache License, Version 2.0 (the "License");
3  * you may not use these files except in compliance with the License.
4  * You may obtain a copy of the License at
5  *
6  * http://www.apache.org/licenses/LICENSE-2.0
7  *
8  * Unless required by applicable law or agreed to in writing, software
9  * distributed under the License is distributed on an "AS IS" BASIS,
10  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11  * See the License for the specific language governing permissions and
12  * limitations under the License.
13  *
14  * Copyright 2014 - Juan Pino, Aurelien Waite, William Byrne
15  *******************************************************************************/
16 package uk.ac.cam.eng.extraction.hadoop.features.lexical;
17 
18 import java.io.BufferedInputStream;
19 import java.io.BufferedReader;
20 import java.io.ByteArrayOutputStream;
21 import java.io.Closeable;
22 import java.io.DataInputStream;
23 import java.io.DataOutputStream;
24 import java.io.EOFException;
25 import java.io.FileInputStream;
26 import java.io.FileNotFoundException;
27 import java.io.IOException;
28 import java.io.InputStreamReader;
29 import java.io.OutputStream;
30 import java.net.ServerSocket;
31 import java.net.Socket;
32 import java.net.SocketException;
33 import java.util.HashMap;
34 import java.util.Map;
35 import java.util.concurrent.ExecutorService;
36 import java.util.concurrent.Executors;
37 import java.util.concurrent.TimeUnit;
38 import java.util.zip.GZIPInputStream;
39 
40 import org.apache.commons.lang.time.StopWatch;
41 import org.apache.hadoop.util.StringUtils;
42 
44 import uk.ac.cam.eng.util.CLI;
45 
53 public class TTableServer implements Closeable {
54 
55  final static int BUFFER_SIZE = 65536;
56 
57  private static final String GENRE = "$GENRE";
58 
59  private static final String DIRECTION = "$DIRECTION";
60 
61  private ExecutorService threadPool = Executors.newFixedThreadPool(6);
62 
63  private class LoadTask implements Runnable {
64 
65  private final String fileName;
66  private final byte prov;
67 
68  private LoadTask(String fileName, byte prov) {
69  this.fileName = fileName;
70  this.prov = prov;
71  }
72 
73  @Override
74  public void run() {
75  try {
76  loadModel(fileName, prov);
77  } catch (IOException e) {
78  e.printStackTrace();
79  System.exit(1);
80  }
81 
82  }
83 
84  }
85 
86  private class QueryRunnable implements Runnable {
87 
88  private Socket querySocket;
89 
90  private ByteArrayOutputStream byteBuffer = new ByteArrayOutputStream(
91  BUFFER_SIZE);
92 
93  private DataOutputStream probWriter = new DataOutputStream(byteBuffer);
94 
95  private long queryTime = 0;
96 
97  private long totalKeys = 0;
98 
99  private int noOfQueries = 0;
100 
101  private QueryRunnable(Socket querySocket) {
102  this.querySocket = querySocket;
103  }
104 
105  @Override
106  public void run() {
107  try {
108  runWithExceptions();
109  } catch (IOException e) {
110  throw new RuntimeException(e);
111  }
112  }
113 
114  private void runWithExceptions() throws IOException {
115  try (DataInputStream queryReader = new DataInputStream(
116  new BufferedInputStream(querySocket.getInputStream()))) {
117  try (OutputStream out = querySocket.getOutputStream()) {
118  StopWatch stopWatch = new StopWatch();
119  // A bit nasty, but will block on the readInt.
120  // It's not really polling. Honest!
121  try {
122  int querySize = queryReader.readInt();
123  totalKeys += querySize;
124  stopWatch.start();
125  for (int i = 0; i < querySize; ++i) {
126  int provInt = queryReader.readInt();
127  byte prov = (byte) provInt;
128  int source = queryReader.readInt();
129  int target = queryReader.readInt();
130  if (model.containsKey(prov)
131  && model.get(prov).containsKey(source)
132  && model.get(prov).get(source)
133  .containsKey(target)) {
134  probWriter.writeDouble(model.get(prov)
135  .get(source).get(target));
136  } else {
137  probWriter.writeDouble(Double.MAX_VALUE);
138  }
139  }
140  byteBuffer.writeTo(out);
141  byteBuffer.reset();
142  stopWatch.stop();
143  queryTime += stopWatch.getTime();
144  if (++noOfQueries == 1000) {
145  System.out.println("Time per key = "
146  + (double) queryTime / (double) totalKeys);
147  noOfQueries = 0;
148  queryTime = totalKeys = 0;
149  }
150  } catch (EOFException e) {
151  System.out.println("Connection from mapper closed");
152  }
153  }
154  }
155  querySocket.close();
156  }
157  }
158 
159  private ServerSocket serverSocket;
160 
161  private Map<Byte, Map<Integer, Map<Integer, Double>>> model = new HashMap<>();
162 
163  private double minLexProb = 0;
164 
165  private Runnable server = new Runnable() {
166 
167  @Override
168  public void run() {
169  while (true) {
170  try {
171  Socket querySocket = serverSocket.accept();
172  threadPool.execute(new QueryRunnable(querySocket));
173  } catch (SocketException e) {
174  e.printStackTrace();
175  } catch (IOException e) {
176  e.printStackTrace();
177  }
178  }
179 
180  }
181  };
182 
183  public void startServer() {
184  Thread serverThread = new Thread(server);
185  serverThread.setDaemon(true);
186  serverThread.start();
187  }
188 
189  private void loadModel(String modelFile, byte prov)
190  throws FileNotFoundException, IOException {
191  try (BufferedReader br = new BufferedReader(new InputStreamReader(
192  new GZIPInputStream(new FileInputStream(modelFile))))) {
193  String line;
194  int count = 1;
195  while ((line = br.readLine()) != null) {
196  if (count % 1000000 == 0) {
197  System.err.println("Processed " + count + " lines");
198  }
199  count++;
200  line = line.replace("NULL", "0");
201  String[] parts = StringUtils.split(line, '\\', ' ');
202  try {
203  int sourceWord = Integer.parseInt(parts[0]);
204  int targetWord = Integer.parseInt(parts[1]);
205  double model1Probability = Double.parseDouble(parts[2]);
206  if (model1Probability < minLexProb) {
207  continue;
208  }
209  if (!model.get(prov).containsKey(sourceWord)) {
210  model.get(prov).put(sourceWord,
211  new HashMap<Integer, Double>());
212  }
213  model.get(prov).get(sourceWord)
214  .put(targetWord, model1Probability);
215  } catch (NumberFormatException e) {
216  System.out.println("Unable to parse line: "
217  + e.getMessage() + "\n" + line);
218  }
219  }
220  }
221  }
222 
223  public void setup(CLI.TTableServerParameters params) throws IOException,
224  InterruptedException {
225  boolean source2Target;
226  if (params.ttableDirection.equals("s2t")) {
227  source2Target = true;
228  } else if (params.ttableDirection.equals("t2s")) {
229  source2Target = false;
230  } else {
231  throw new RuntimeException("Unknown direction: "
232  + params.ttableDirection);
233  }
234  int serverPort;
235  if (source2Target) {
236  serverPort = params.sp.ttableS2TServerPort;
237  } else {
238  serverPort = params.sp.ttableT2SServerPort;
239  ;
240  }
241  minLexProb = params.minLexProb;
242  serverSocket = new ServerSocket(serverPort);
243  String lexTemplate = params.ttableServerTemplate;
244  String allString = lexTemplate.replace(GENRE, "ALL").replace(DIRECTION,
245  params.ttableLanguagePair);
246  System.out.println("Loading " + allString);
247  String[] provenances = params.prov.provenance.split(",");
248  ExecutorService loaderThreadPool = Executors.newFixedThreadPool(4);
249  model.put((byte) 0, new HashMap<Integer, Map<Integer, Double>>());
250  loaderThreadPool.execute(new LoadTask(allString, (byte) 0));
251  for (int i = 0; i < provenances.length; ++i) {
252  String provString = lexTemplate.replace(GENRE, provenances[i])
253  .replace(DIRECTION, params.ttableLanguagePair);
254  System.out.println("Loading " + provString);
255  byte prov = (byte) (i + 1);
256  model.put(prov, new HashMap<Integer, Map<Integer, Double>>());
257  loaderThreadPool.execute(new LoadTask(provString, prov));
258  }
259  loaderThreadPool.shutdown();
260  loaderThreadPool.awaitTermination(3, TimeUnit.HOURS);
261  System.gc();
262  }
263 
264  @Override
265  public void close() throws IOException {
266  threadPool.shutdown();
267  }
268 
269  public static void main(String[] args) throws IllegalArgumentException,
270  IllegalAccessException, IOException, InterruptedException {
272  Util.parseCommandLine(args, params);
273 
274  try (TTableServer server = new TTableServer()) {
275  server.setup(params);
276  server.startServer();
277  System.err.println("TTable server ready on port: "
278  + server.serverSocket.getLocalPort());
279  Thread.sleep(24 * 60 * 60 * 1000); // Sleep for 24 hours
280  }
281  }
282 
283 }
void run(ucam::util::RegistryPO const &rg)
static JCommander parseCommandLine(String[] args, Object params)
Definition: Util.java:85