20 package uk.ac.cam.eng.rule.retrieval;
22 import java.util.ArrayList;
23 import java.util.Collection;
24 import java.util.HashSet;
25 import java.util.List;
/**
 * Generates source-side rule candidates for a source sentence given as
 * integer word ids: contiguous source phrases, and instances of a set of
 * source side patterns, bounded by the length limits held in the fields
 * below.
 */
39 class PatternInstanceCreator {
// Maximum length (in words) of the contiguous source phrases enumerated by
// createSourcePatternInstances; read from params.rp.maxSourcePhrase.
41 private final int maxSourcePhrase;
// Maximum number of source-side elements in a pattern instance: each word
// matched by a "w" slot counts one, and every other pattern slot counts one.
43 private final int maxSourceElements;
// Maximum number of consecutive words a single "w" pattern slot may match.
45 private final int maxTerminalLength;
// Maximum number of source words a single non-"w" pattern slot (presumably
// a non-terminal) may span.
47 private final int maxNonTerminalSpan;
// Maximum total number of source words an instance may cover, counting the
// words of "w" slots and the spans of the other slots alike; read from
// params.hr_max_height.
49 private final int hrMaxHeight;
// Side patterns to instantiate; a private HashSet copy of the collection
// passed to the constructor, also used to filter candidate instances.
51 private final Set<SidePattern> sidePatterns;
/**
 * Copies the retrieval limits out of {@code params} and snapshots the side
 * patterns.
 *
 * @param params retriever parameters; the phrase, element, terminal-length
 *     and non-terminal-span limits are read from {@code params.rp}, the
 *     height limit from {@code params.hr_max_height}
 * @param sidePatterns source side patterns to instantiate; copied into a new
 *     HashSet, so later additions to or removals from the caller's
 *     collection are not seen by this instance
 */
53 public PatternInstanceCreator(
54 CLI.RuleRetrieverParameters params,
55 Collection<SidePattern> sidePatterns) {
56 maxSourcePhrase = params.rp.maxSourcePhrase;
57 maxSourceElements = params.rp.maxSourceElements;
58 maxTerminalLength = params.rp.maxTerminalLength;
59 maxNonTerminalSpan = params.rp.maxNonTerminalSpan;
60 hrMaxHeight = params.hr_max_height;
// Defensive copy, so this instance does not observe later changes to the
// caller's collection.
61 this.sidePatterns =
new HashSet<>(sidePatterns);
/**
 * Builds source-side rule candidates for one source sentence.
 *
 * <p>Two kinds of candidate are built: contiguous source phrases of at most
 * {@code maxSourcePhrase} words starting at each position, and instances of
 * the configured side patterns (from getPatternInstancesFromSourceSentence),
 * which are checked against {@code sidePatterns} before use. Every Rule
 * constructed on this path has an empty target side.
 *
 * @param line the source sentence as integer word ids separated by single
 *     spaces
 * @return the set of candidate rules
 * @throws NumberFormatException if a token is not an integer, including the
 *     empty tokens produced by leading or repeated spaces
 */
70 public Set<Rule> createSourcePatternInstances(String line) {
71 Set<Rule> res =
new HashSet<>();
// String.split(" ") yields an empty token for a leading or doubled space
// (trailing empty tokens are dropped), and Integer.parseInt("") throws.
73 parts = line.split(
" ");
// The whole sentence as word ids, handed to the pattern matcher below.
74 List<Integer> sourceSentence =
new ArrayList<Integer>();
75 for (
int i = 0; i < parts.length; i++) {
76 sourceSentence.add(Integer.parseInt(parts[i]));
// Grow a contiguous phrase starting at position i, capped at
// maxSourcePhrase words and at the end of the sentence.
77 List<Symbol> sourcePhrase =
new ArrayList<Symbol>();
78 for (
int j = 0; j < maxSourcePhrase && j < parts.length - i; j++) {
79 sourcePhrase.add(Symbol.deserialise(Integer.parseInt(parts[i + j])));
// NOTE(review): sourcePhrase may still be appended to after this Rule is
// built. If Rule keeps a reference to the list instead of copying it,
// later appends also change Rules already placed in a HashSet — confirm
// that Rule copies its source list.
81 Rule r =
new Rule(sourcePhrase,
new ArrayList<Symbol>());
// Instantiate every configured side pattern at every start position.
85 Set<Rule> sourcePatternInstances = getPatternInstancesFromSourceSentence(
86 sourceSentence, sidePatterns);
// Keep an instance only if its source side converts back to one of the
// configured patterns (presumably discarding partial matches — confirm).
87 for(Rule r : sourcePatternInstances){
88 if(sidePatterns.contains(r.source().toPattern())){
/**
 * Collects instances of every side pattern anchored at every position of
 * the sentence.
 *
 * <p>For each pattern and each start index, the recursive builder
 * getPatternInstancesFromSourceAndPattern2 is called starting at pattern
 * slot 0, with zero elements used and zero words covered so far.
 *
 * @param sourceSentence the sentence as word ids
 * @param sidePatterns the patterns to instantiate; this parameter shadows
 *     the field of the same name, and the visible caller passes that field
 * @return the union of the per-pattern, per-position results
 */
95 private Set<Rule> getPatternInstancesFromSourceSentence(
96 List<Integer> sourceSentence, Set<SidePattern> sidePatterns) {
97 Set<Rule> res =
new HashSet<Rule>();
98 for (SidePattern sidePattern : sidePatterns) {
// Try every start position in the sentence for this pattern.
99 for (
int i = 0; i < sourceSentence.size(); i++) {
100 res.addAll(getPatternInstancesFromSourceAndPattern2(
101 sourceSentence, sidePattern, i, 0, 0, 0));
/**
 * Prepends the source side of {@code partialLeft} to the source side of
 * each rule in {@code partialRight}.
 *
 * <p>If {@code partialRight} is empty, the result holds a single rule whose
 * source is {@code partialLeft}'s source. Otherwise, a fresh concatenated
 * list is built for each right-hand rule. Only source sides are read; every
 * returned rule has an empty target side.
 *
 * @param partialLeft the partial instance for the current pattern slot
 * @param partialRight instances built for the remaining pattern slots
 * @return the combined rules
 */
107 private Set<Rule> merge(Rule partialLeft, Set<Rule> partialRight) {
108 Set<Rule> res =
new HashSet<Rule>();
109 List<Symbol> sourceLeft = partialLeft.getSource();
110 if (partialRight.isEmpty()) {
// NOTE(review): sourceLeft is passed on without copying, yet
// getPatternInstancesFromSourceAndPattern2 keeps appending to the list
// behind partialLeft for "w" slots. Unless getSource() or the Rule
// constructor copies, this rule's source changes after it has been added
// to a HashSet — confirm.
111 res.add(
new Rule(sourceLeft,
new ArrayList<Symbol>()));
// Concatenate into a fresh list, so neither input list is modified.
114 for (Rule r : partialRight) {
115 List<Symbol> merged =
new ArrayList<Symbol>();
116 List<Symbol> sourceRight = r.getSource();
117 merged.addAll(sourceLeft);
118 merged.addAll(sourceRight);
119 res.add(
new Rule(merged,
new ArrayList<Symbol>()));
124 private Set<Rule> getPatternInstancesFromSourceAndPattern2(
125 List<Integer> sourceSentence, SidePattern sidePattern,
126 int startSentenceIndex,
int startPatternIndex,
int nbSrcElt,
127 int nbCoveredWords) {
128 Set<Rule> res =
new HashSet<Rule>();
129 if (startSentenceIndex >= sourceSentence.size()) {
132 if (startPatternIndex >= sidePattern.size()) {
137 if (sourceSentence.size() - startSentenceIndex < sidePattern.size()
138 - startPatternIndex) {
142 if (nbSrcElt >= maxSourceElements) {
146 if (nbCoveredWords >= hrMaxHeight) {
149 if (sourceSentence.size() - startSentenceIndex == sidePattern.size()
150 - startPatternIndex) {
151 if (nbSrcElt + sidePattern.size() - startPatternIndex > maxSourceElements) {
154 if (nbCoveredWords + sourceSentence.size() - startSentenceIndex > hrMaxHeight) {
157 List<Symbol> patternInstance =
new ArrayList<Symbol>();
158 for (
int i = 0; i < sourceSentence.size() - startSentenceIndex; i++) {
159 if (sidePattern.get(startPatternIndex + i).equals(
"w")) {
160 patternInstance.add(Symbol.deserialise(sourceSentence.get(startSentenceIndex
163 patternInstance.add(Symbol.deserialise(Integer.parseInt(sidePattern
164 .get(startPatternIndex + i))));
167 Rule r =
new Rule(patternInstance,
new ArrayList<Symbol>());
171 List<Symbol> partialPattern =
new ArrayList<Symbol>();
172 if (sidePattern.get(startPatternIndex).equals(
"w")) {
173 for (
int i = startSentenceIndex; i < sourceSentence.size()
174 - (sidePattern.size() - startPatternIndex - 1)
175 && i < startSentenceIndex + maxTerminalLength
176 && i < startSentenceIndex + maxSourceElements - nbSrcElt
177 && i < startSentenceIndex + hrMaxHeight - nbCoveredWords; i++) {
178 partialPattern.add(Symbol.deserialise(sourceSentence.get(i)));
179 Rule r =
new Rule(partialPattern,
new ArrayList<Symbol>());
180 Set<Rule> right = getPatternInstancesFromSourceAndPattern2(
181 sourceSentence, sidePattern, i + 1,
182 startPatternIndex + 1, nbSrcElt + i
183 - startSentenceIndex + 1, nbCoveredWords + i
184 - startSentenceIndex + 1);
185 Set<Rule> merged = merge(r, right);
189 partialPattern.add(Symbol.deserialise(Integer.parseInt(sidePattern
190 .get(startPatternIndex))));
191 Rule r =
new Rule(partialPattern,
new ArrayList<Symbol>());
192 for (
int i = startSentenceIndex; i < sourceSentence.size()
193 - (sidePattern.size() - startPatternIndex - 1)
194 && i < startSentenceIndex + maxNonTerminalSpan
195 && i < startSentenceIndex + hrMaxHeight - nbCoveredWords; i++) {
196 Set<Rule> merged = merge(
198 getPatternInstancesFromSourceAndPattern2(
199 sourceSentence, sidePattern, i + 1,
200 startPatternIndex + 1, nbSrcElt + 1,
201 nbCoveredWords + i - startSentenceIndex + 1));