Cambridge SMT System
PatternInstanceCreator.java
Go to the documentation of this file.
1 /*******************************************************************************
2  * Licensed under the Apache License, Version 2.0 (the "License");
3  * you may not use these files except in compliance with the License.
4  * You may obtain a copy of the License at
5  *
6  * http://www.apache.org/licenses/LICENSE-2.0
7  *
8  * Unless required by applicable law or agreed to in writing, software
9  * distributed under the License is distributed on an "AS IS" BASIS,
10  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11  * See the License for the specific language governing permissions and
12  * limitations under the License.
13  *
14  * Copyright 2014 - Juan Pino, Aurelien Waite, William Byrne
15  *******************************************************************************/
20 package uk.ac.cam.eng.rule.retrieval;
21 
22 import java.util.ArrayList;
23 import java.util.Collection;
24 import java.util.HashSet;
25 import java.util.List;
26 import java.util.Set;
27 
28 import uk.ac.cam.eng.extraction.Rule;
29 import uk.ac.cam.eng.extraction.Symbol;
30 import uk.ac.cam.eng.util.CLI;
31 
39 class PatternInstanceCreator {
40 
41  private final int maxSourcePhrase;
42 
43  private final int maxSourceElements;
44 
45  private final int maxTerminalLength;
46 
47  private final int maxNonTerminalSpan;
48 
49  private final int hrMaxHeight;
50 
51  private final Set<SidePattern> sidePatterns;
52 
53  public PatternInstanceCreator(
54  CLI.RuleRetrieverParameters params,
55  Collection<SidePattern> sidePatterns) {
56  maxSourcePhrase = params.rp.maxSourcePhrase;
57  maxSourceElements = params.rp.maxSourceElements;
58  maxTerminalLength = params.rp.maxTerminalLength;
59  maxNonTerminalSpan = params.rp.maxNonTerminalSpan;
60  hrMaxHeight = params.hr_max_height;
61  this.sidePatterns = new HashSet<>(sidePatterns);
62  }
63 
70  public Set<Rule> createSourcePatternInstances(String line) {
71  Set<Rule> res = new HashSet<>();
72  String[] parts;
73  parts = line.split(" ");
74  List<Integer> sourceSentence = new ArrayList<Integer>();
75  for (int i = 0; i < parts.length; i++) {
76  sourceSentence.add(Integer.parseInt(parts[i]));
77  List<Symbol> sourcePhrase = new ArrayList<Symbol>();
78  for (int j = 0; j < maxSourcePhrase && j < parts.length - i; j++) {
79  sourcePhrase.add(Symbol.deserialise(Integer.parseInt(parts[i + j])));
80  // add source phrase
81  Rule r = new Rule(sourcePhrase, new ArrayList<Symbol>());
82  res.add(r);
83  }
84  }
85  Set<Rule> sourcePatternInstances = getPatternInstancesFromSourceSentence(
86  sourceSentence, sidePatterns);
87  for(Rule r : sourcePatternInstances){
88  if(sidePatterns.contains(r.source().toPattern())){
89  res.add(r);
90  }
91  }
92  return res;
93  }
94 
95  private Set<Rule> getPatternInstancesFromSourceSentence(
96  List<Integer> sourceSentence, Set<SidePattern> sidePatterns) {
97  Set<Rule> res = new HashSet<Rule>();
98  for (SidePattern sidePattern : sidePatterns) {
99  for (int i = 0; i < sourceSentence.size(); i++) {
100  res.addAll(getPatternInstancesFromSourceAndPattern2(
101  sourceSentence, sidePattern, i, 0, 0, 0));
102  }
103  }
104  return res;
105  }
106 
107  private Set<Rule> merge(Rule partialLeft, Set<Rule> partialRight) {
108  Set<Rule> res = new HashSet<Rule>();
109  List<Symbol> sourceLeft = partialLeft.getSource();
110  if (partialRight.isEmpty()) {
111  res.add(new Rule(sourceLeft, new ArrayList<Symbol>()));
112  return res;
113  }
114  for (Rule r : partialRight) {
115  List<Symbol> merged = new ArrayList<Symbol>();
116  List<Symbol> sourceRight = r.getSource();
117  merged.addAll(sourceLeft);
118  merged.addAll(sourceRight);
119  res.add(new Rule(merged, new ArrayList<Symbol>()));
120  }
121  return res;
122  }
123 
124  private Set<Rule> getPatternInstancesFromSourceAndPattern2(
125  List<Integer> sourceSentence, SidePattern sidePattern,
126  int startSentenceIndex, int startPatternIndex, int nbSrcElt,
127  int nbCoveredWords) {
128  Set<Rule> res = new HashSet<Rule>();
129  if (startSentenceIndex >= sourceSentence.size()) {
130  return res;
131  }
132  if (startPatternIndex >= sidePattern.size()) {
133  return res;
134  }
135  // pattern is too big for the (rest of the) sentence, e.g. pattern wXw
136  // for the phrase 2_3
137  if (sourceSentence.size() - startSentenceIndex < sidePattern.size()
138  - startPatternIndex) {
139  return res;
140  }
141  // we already have max source elements
142  if (nbSrcElt >= maxSourceElements) {
143  return res;
144  }
145  // we already cover hr max height
146  if (nbCoveredWords >= hrMaxHeight) {
147  return res;
148  }
149  if (sourceSentence.size() - startSentenceIndex == sidePattern.size()
150  - startPatternIndex) {
151  if (nbSrcElt + sidePattern.size() - startPatternIndex > maxSourceElements) {
152  return res;
153  }
154  if (nbCoveredWords + sourceSentence.size() - startSentenceIndex > hrMaxHeight) {
155  return res;
156  }
157  List<Symbol> patternInstance = new ArrayList<Symbol>();
158  for (int i = 0; i < sourceSentence.size() - startSentenceIndex; i++) {
159  if (sidePattern.get(startPatternIndex + i).equals("w")) {
160  patternInstance.add(Symbol.deserialise(sourceSentence.get(startSentenceIndex
161  + i)));
162  } else {
163  patternInstance.add(Symbol.deserialise(Integer.parseInt(sidePattern
164  .get(startPatternIndex + i))));
165  }
166  }
167  Rule r = new Rule(patternInstance, new ArrayList<Symbol>());
168  res.add(r);
169  return res;
170  }
171  List<Symbol> partialPattern = new ArrayList<Symbol>();
172  if (sidePattern.get(startPatternIndex).equals("w")) {
173  for (int i = startSentenceIndex; i < sourceSentence.size()
174  - (sidePattern.size() - startPatternIndex - 1)
175  && i < startSentenceIndex + maxTerminalLength
176  && i < startSentenceIndex + maxSourceElements - nbSrcElt
177  && i < startSentenceIndex + hrMaxHeight - nbCoveredWords; i++) {
178  partialPattern.add(Symbol.deserialise(sourceSentence.get(i)));
179  Rule r = new Rule(partialPattern, new ArrayList<Symbol>());
180  Set<Rule> right = getPatternInstancesFromSourceAndPattern2(
181  sourceSentence, sidePattern, i + 1,
182  startPatternIndex + 1, nbSrcElt + i
183  - startSentenceIndex + 1, nbCoveredWords + i
184  - startSentenceIndex + 1);
185  Set<Rule> merged = merge(r, right);
186  res.addAll(merged);
187  }
188  } else {
189  partialPattern.add(Symbol.deserialise(Integer.parseInt(sidePattern
190  .get(startPatternIndex))));
191  Rule r = new Rule(partialPattern, new ArrayList<Symbol>());
192  for (int i = startSentenceIndex; i < sourceSentence.size()
193  - (sidePattern.size() - startPatternIndex - 1)
194  && i < startSentenceIndex + maxNonTerminalSpan
195  && i < startSentenceIndex + hrMaxHeight - nbCoveredWords; i++) {
196  Set<Rule> merged = merge(
197  r,
198  getPatternInstancesFromSourceAndPattern2(
199  sourceSentence, sidePattern, i + 1,
200  startPatternIndex + 1, nbSrcElt + 1,
201  nbCoveredWords + i - startSentenceIndex + 1));
202  res.addAll(merged);
203  }
204  }
205  return res;
206  }
207 }