16 package uk.ac.cam.eng.extraction.hadoop.features.lexical;
18 import java.io.BufferedInputStream;
19 import java.io.BufferedReader;
20 import java.io.ByteArrayOutputStream;
21 import java.io.Closeable;
22 import java.io.DataInputStream;
23 import java.io.DataOutputStream;
24 import java.io.EOFException;
25 import java.io.FileInputStream;
26 import java.io.FileNotFoundException;
27 import java.io.IOException;
28 import java.io.InputStreamReader;
29 import java.io.OutputStream;
30 import java.net.ServerSocket;
31 import java.net.Socket;
32 import java.net.SocketException;
33 import java.util.HashMap;
35 import java.util.concurrent.ExecutorService;
36 import java.util.concurrent.Executors;
37 import java.util.concurrent.TimeUnit;
38 import java.util.zip.GZIPInputStream;
40 import org.apache.commons.lang.time.StopWatch;
41 import org.apache.hadoop.util.StringUtils;
// NOTE(review): the use of BUFFER_SIZE is elided from this chunk; presumably it is the initial
// capacity passed to QueryRunnable's ByteArrayOutputStream (argument not visible) - confirm.
55 final static int BUFFER_SIZE = 65536;
// Placeholder in params.ttableServerTemplate that setup replaces with "ALL" or a provenance name.
57 private static final String GENRE =
"$GENRE";
// Placeholder in params.ttableServerTemplate that setup replaces with params.ttableLanguagePair.
59 private static final String DIRECTION =
"$DIRECTION";
// Pool that runs one QueryRunnable per accepted client connection. With 6 fixed threads at most
// 6 connections are served concurrently; further accepted connections wait in the pool's queue.
61 private ExecutorService threadPool = Executors.newFixedThreadPool(6);
/**
 * Loads one gzip-compressed t-table file into {@code model.get(prov)} by calling
 * {@code loadModel(fileName, prov)}; setup submits one of these per table to a loader pool.
 * Non-static inner class: it calls the enclosing instance's loadModel.
 * NOTE(review): the assignment of {@code prov}, the {@code run()} header and the catch body
 * are elided from this chunk.
 */
63 private class LoadTask
implements Runnable {
// Path of the table file to read.
65 private final String fileName;
// Provenance id under which the table is stored in model (setup uses 0 for the "ALL" table).
66 private final byte prov;
68 private LoadTask(String fileName, byte prov) {
69 this.fileName = fileName;
76 loadModel(fileName, prov);
77 }
// NOTE(review): handler body elided. setup only waits for the loader pool to terminate, so if
// this handler merely logs, the server starts with this provenance's table partly or not loaded.
catch (IOException e) {
/**
 * Serves one client connection accepted by the server loop (one instance per socket).
 * Wire protocol as read/written here: the client sends an int N followed by N triples of
 * ints (provenance, source word id, target word id). For each triple one double is written
 * to byteBuffer: the stored probability, or Double.MAX_VALUE when the entry is absent.
 * The buffer is then copied to the socket.
 * NOTE(review): several lines of this class are elided from this chunk: run(), the loop over
 * successive queries, stopWatch start/reset, and any byteBuffer reset.
 */
86 private class QueryRunnable
implements Runnable {
// Connection being served.
88 private Socket querySocket;
// Accumulates the reply doubles for a query before they are written to the socket in one call.
90 private ByteArrayOutputStream byteBuffer =
new ByteArrayOutputStream(
93 private DataOutputStream probWriter =
new DataOutputStream(byteBuffer);
// Per-connection timing statistics, printed when noOfQueries reaches 1000.
95 private long queryTime = 0;
97 private long totalKeys = 0;
99 private int noOfQueries = 0;
101 private QueryRunnable(Socket querySocket) {
102 this.querySocket = querySocket;
109 }
// NOTE(review): enclosing run() is elided; IOExceptions are rethrown unchecked and propagate out
// of this pool task.
catch (IOException e) {
110 throw new RuntimeException(e);
/**
 * Reads queries from the socket and answers them. The loop over successive queries is elided
 * from this chunk. An EOFException from readInt (the client closed its end) is treated as a
 * normal end of the connection.
 *
 * @throws IOException on other socket read/write failures
 */
114 private void runWithExceptions()
throws IOException {
// Both streams are closed by try-with-resources; closing either one also closes the socket.
115 try (DataInputStream queryReader =
new DataInputStream(
116 new BufferedInputStream(querySocket.getInputStream()))) {
117 try (OutputStream out = querySocket.getOutputStream()) {
118 StopWatch stopWatch =
new StopWatch();
// Number of (provenance, source, target) keys in this query.
122 int querySize = queryReader.readInt();
123 totalKeys += querySize;
125 for (
int i = 0; i < querySize; ++i) {
126 int provInt = queryReader.readInt();
// The cast keeps only the low 8 bits, matching model's Byte keys.
127 byte prov = (byte) provInt;
128 int source = queryReader.readInt();
129 int target = queryReader.readInt();
// Look up model[prov][source][target]. Each level is fetched repeatedly; a single get() per
// level with null checks would avoid the redundant hashing.
130 if (model.containsKey(prov)
131 && model.get(prov).containsKey(source)
132 && model.get(prov).get(source)
133 .containsKey(target)) {
134 probWriter.writeDouble(model.get(prov)
135 .get(source).get(target));
// Absent entry (else-branch header elided): Double.MAX_VALUE is the "not found" sentinel.
137 probWriter.writeDouble(Double.MAX_VALUE);
// Copies the whole buffer to the socket. NOTE(review): no byteBuffer.reset() is visible in this
// chunk; confirm one runs per query, otherwise each reply resends all earlier doubles.
140 byteBuffer.writeTo(out);
// StopWatch.getTime() is in milliseconds (the start/stop calls are elided).
143 queryTime += stopWatch.getTime();
144 if (++noOfQueries == 1000) {
// Average StopWatch time (ms) per key accumulated since the counters were last reset.
145 System.out.println(
"Time per key = " 146 + (
double) queryTime / (
double) totalKeys);
// NOTE(review): noOfQueries is not reset on this line; if the elided lines do not reset it,
// this report is printed only once per connection.
148 queryTime = totalKeys = 0;
150 }
// readInt throws EOFException once the client closes its end of the connection.
catch (EOFException e) {
151 System.out.println(
"Connection from mapper closed");
// Listening socket; opened in setup on the direction-specific port and accepted on by server.
159 private ServerSocket serverSocket;
// Lookup table: provenance id -> source word id -> target word id -> probability.
// Id 0 holds the "ALL" table; ids 1..n follow the order of params.prov.provenance (see setup).
// NOTE(review): plain HashMaps. Loader threads write them during setup and query threads read
// them afterwards. setup also mutates the outer map while loader tasks are already reading it.
161 private Map<Byte, Map<Integer, Map<Integer, Double>>> model =
new HashMap<>();
// Threshold compared against each loaded probability in loadModel (branch body elided;
// presumably such entries are skipped). Set from params.minLexProb in setup.
163 private double minLexProb = 0;
// Accept loop (run() header and loop elided): each accepted connection is handed to threadPool
// as a new QueryRunnable.
165 private Runnable server =
new Runnable() {
171 Socket querySocket = serverSocket.accept();
172 threadPool.execute(
new QueryRunnable(querySocket));
173 }
// accept() throws SocketException when serverSocket is closed. Handler body elided; confirm it
// exits the loop rather than retrying accept() on a closed socket.
catch (SocketException e) {
175 }
// NOTE(review): handler body elided; confirm other accept failures are reported, not swallowed.
catch (IOException e) {
// Runs the accept loop (server) on its own thread (method header elided). As a daemon thread it
// does not by itself keep the JVM alive; main separately sleeps after starting the server.
184 Thread serverThread =
new Thread(server);
185 serverThread.setDaemon(
true);
186 serverThread.start();
/**
 * Reads a gzip-compressed, space-separated table file into {@code model.get(prov)}.
 * Each line is parsed as: source word id, target word id, probability.
 * Progress is printed to stderr; lines whose numbers fail to parse are reported on stdout.
 * NOTE(review): the declarations of {@code line} and {@code count}, the count increment and the
 * body of the minLexProb branch are elided from this chunk.
 *
 * @param modelFile path of the gzipped table file
 * @param prov provenance id; {@code model.get(prov)} must already exist (setup creates it),
 *     otherwise the first insert throws NullPointerException
 * @throws FileNotFoundException if modelFile cannot be opened
 * @throws IOException on read or decompression errors
 */
189 private void loadModel(String modelFile, byte prov)
190 throws FileNotFoundException, IOException {
// NOTE(review): InputStreamReader without a charset uses the platform default encoding.
191 try (BufferedReader br =
new BufferedReader(
new InputStreamReader(
192 new GZIPInputStream(
new FileInputStream(modelFile))))) {
195 while ((line = br.readLine()) != null) {
196 if (count % 1000000 == 0) {
197 System.err.println(
"Processed " + count +
" lines");
// Replaces every occurrence of the substring NULL with 0, presumably mapping the NULL word to
// id 0. Note this also rewrites NULL inside any other token.
200 line = line.replace(
"NULL",
"0");
// Hadoop StringUtils.split: separator ' ', escape character '\'.
201 String[] parts = StringUtils.split(line,
'\\',
' ');
// NOTE(review): a line with fewer than 3 fields throws ArrayIndexOutOfBoundsException, which
// the NumberFormatException handler below does not catch and which aborts the whole load.
203 int sourceWord = Integer.parseInt(parts[0]);
204 int targetWord = Integer.parseInt(parts[1]);
205 double model1Probability = Double.parseDouble(parts[2]);
206 if (model1Probability < minLexProb) {
// Create the inner map on first sight of sourceWord (computeIfAbsent would express this in one
// call), then store the probability for targetWord.
209 if (!model.get(prov).containsKey(sourceWord)) {
210 model.get(prov).put(sourceWord,
211 new HashMap<Integer, Double>());
213 model.get(prov).get(sourceWord)
214 .put(targetWord, model1Probability);
215 }
catch (NumberFormatException e) {
216 System.out.println(
"Unable to parse line: " 217 + e.getMessage() +
"\n" + line);
224 InterruptedException {
// Tail of setup(params) (signature elided; main calls server.setup(params)). It:
// - selects the server port for the table direction and opens serverSocket;
// - loads the "ALL" table (provenance 0) plus one table per comma-separated provenance
//   (ids 1..n) on a 4-thread loader pool, waiting up to 3 hours for loading to finish.
// Only "s2t" and "t2s" are accepted; anything else throws.
// (IllegalArgumentException would describe that case more precisely than RuntimeException.)
225 boolean source2Target;
226 if (params.ttableDirection.equals(
"s2t")) {
227 source2Target =
true;
228 }
else if (params.ttableDirection.equals(
"t2s")) {
229 source2Target =
false;
231 throw new RuntimeException(
"Unknown direction: " 232 + params.ttableDirection);
// Presumably chosen by if (source2Target) ... else ...; the branching lines are elided.
236 serverPort = params.sp.ttableS2TServerPort;
238 serverPort = params.sp.ttableT2SServerPort;
241 minLexProb = params.minLexProb;
242 serverSocket =
new ServerSocket(serverPort);
243 String lexTemplate = params.ttableServerTemplate;
244 String allString = lexTemplate.replace(GENRE,
"ALL").replace(DIRECTION,
245 params.ttableLanguagePair);
246 System.out.println(
"Loading " + allString);
// Note: an empty provenance string yields one empty provenance name, since "".split(",")
// returns [""].
247 String[] provenances = params.prov.provenance.split(
",");
248 ExecutorService loaderThreadPool = Executors.newFixedThreadPool(4);
249 model.put((byte) 0,
new HashMap<Integer, Map<Integer, Double>>());
250 loaderThreadPool.execute(
new LoadTask(allString, (byte) 0));
251 for (
int i = 0; i < provenances.length; ++i) {
252 String provString = lexTemplate.replace(GENRE, provenances[i])
253 .replace(DIRECTION, params.ttableLanguagePair);
254 System.out.println(
"Loading " + provString);
// Byte ids wrap for more than 127 provenances: 128 becomes -128, and 256 becomes 0, colliding
// with the "ALL" table.
255 byte prov = (byte) (i + 1);
// NOTE(review): data race. model is a plain HashMap, and this put() runs while LoadTasks
// submitted earlier may already be calling model.get(...) on loader threads. Creating every
// inner map before submitting any task (or using ConcurrentHashMap) would avoid it.
256 model.put(prov,
new HashMap<Integer, Map<Integer, Double>>());
257 loaderThreadPool.execute(
new LoadTask(provString, prov));
259 loaderThreadPool.shutdown();
// NOTE(review): awaitTermination's boolean result is ignored. After a 3-hour timeout, setup
// returns with tables possibly still loading.
260 loaderThreadPool.awaitTermination(3, TimeUnit.HOURS);
/**
 * Shuts down the query pool. No new connections are dispatched to it, but connections already
 * submitted (running or queued) are still served.
 * NOTE(review): the rest of this method is elided; confirm serverSocket is also closed so that the
 * accept loop terminates.
 */
265 public void close() throws IOException {
266 threadPool.shutdown();
/**
 * Starts a t-table server. It loads the tables (setup), starts the accept thread, prints the
 * bound port to stderr, and then sleeps for 24 hours.
 * NOTE(review): creation of {@code server} and parsing of {@code params} from args are elided.
 * The 24-hour lifetime is hard-coded.
 */
269 public static void main(String[] args)
throws IllegalArgumentException,
270 IllegalAccessException, IOException, InterruptedException {
275 server.setup(params);
276 server.startServer();
277 System.err.println(
"TTable server ready on port: " 278 + server.serverSocket.getLocalPort());
// 24 * 60 * 60 * 1000 = 86,400,000 ms, which fits in an int.
279 Thread.sleep(24 * 60 * 60 * 1000);
void run(ucam::util::RegistryPO const &rg)