# -*- coding: utf-8 -*-
# coding=utf-8
# Copyright 2019 The SGNMT Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module implements constraints which assure that highly structured
output is well-formatted. For example, the bracket predictor checks for
balanced bracket expressions, and the OSM predictor prevents any sequence
of operations which cannot be compiled to a string.
"""
import logging
from cam.sgnmt import utils
from cam.sgnmt.predictors.core import Predictor, UnboundedVocabularyPredictor
# Default operation IDs
OSM_EOP_ID = 4
OSM_SRC_POP_ID = 4
OSM_SET_MARKER_ID = 5
OSM_JUMP_FWD_ID = 6
OSM_JUMP_BWD_ID = 7
OSM_SRC_POP2_ID = 8
OSM_COPY_ID = 8
OSM_SRC_UNPOP_ID = 9
[docs]def load_external_lengths(path):
"""Loads a length distribution from a plain text file. The file
must contain blank separated <length>:<score> pairs in each line.
Args:
path (string): Path to the length file.
Returns:
list of dicts mapping a length to its scores, one dict for each
sentence.
"""
lengths = []
with open(path) as f:
for line in f:
scores = {}
for pair in line.strip().split():
if ':' in pair:
length, score = pair.split(':')
scores[int(length)] = float(score)
else:
scores[int(pair)] = 0.0
lengths.append(scores)
return lengths
[docs]def update_trg_osm_ids(wmap_path):
"""Update the OSM_*_ID variables using a target word map.
Args:
wmap_path (string): Path to the wmap file.
"""
global OSM_SRC_POP_ID, OSM_SET_MARKER_ID, OSM_JUMP_FWD_ID, \
OSM_JUMP_BWD_ID, OSM_SRC_POP2_ID, OSM_COPY_ID, \
OSM_SRC_UNPOP_ID
if not wmap_path:
return
with open(wmap_path) as f:
for line in f:
word, word_id = line.strip().split()
if word == "<SRC_POP>":
OSM_SRC_POP_ID = int(word_id)
logging.debug("OSM SRC_POP = %d" % OSM_SRC_POP_ID)
elif word == "<SET_MARKER>":
OSM_SET_MARKER_ID = int(word_id)
logging.debug("OSM SET_MARKER = %d" % OSM_SET_MARKER_ID)
elif word == "<JUMP_FWD>":
OSM_JUMP_FWD_ID = int(word_id)
logging.debug("OSM JUMP_FWD = %d" % OSM_JUMP_FWD_ID)
elif word == "<JUMP_BWD>":
OSM_JUMP_BWD_ID = int(word_id)
logging.debug("OSM JUMP_BWD = %d" % OSM_JUMP_BWD_ID)
elif word == "<SRC_POP2>":
OSM_SRC_POP2_ID = int(word_id)
logging.debug("OSM SRC_POP2 = %d" % OSM_SRC_POP2_ID)
elif word == "<COPY>":
OSM_COPY_ID = int(word_id)
logging.debug("OSM COPY = %d" % OSM_COPY_ID)
elif word == "<SRC_UNPOP>":
OSM_SRC_UNPOP_ID = int(word_id)
logging.debug("SRC_UNPOP = %d" % OSM_SRC_UNPOP_ID)
[docs]def update_src_osm_ids(wmap_path):
"""Update the OSM_*_ID variables using a source word map.
Args:
wmap_path (string): Path to the wmap file.
"""
global OSM_EOP_ID
if not wmap_path:
return
with open(wmap_path) as f:
for line in f:
word, word_id = line.strip().split()
if word == "<EOP>":
OSM_EOP_ID = int(word_id)
logging.debug("OSM EOP = %d" % OSM_EOP_ID)
[docs]class OSMPredictor(Predictor):
"""This predictor applies the following constraints to an OSM output:
- The number of POP tokens must be equal to the number of source
tokens
- JUMP_FWD and JUMP_BWD tokens are constraint to avoid jumping out
of bounds.
The predictor supports the original OSNMT operation set (default)
plus a number of variations that are set by the use_* arguments in
the constructor.
"""
def __init__(self, src_wmap, trg_wmap, use_jumps=True, use_auto_pop=False,
use_unpop=False, use_pop2=False, use_src_eop=False,
use_copy=False):
"""Creates a new osm predictor.
Args:
src_wmap (string): Path to the source wmap. Used to grap
EOP id.
trg_wmap (string): Path to the target wmap. Used to update
IDs of operations.
use_jumps (bool): If true, use SET_MARKER, JUMP_FWD and
JUMP_BWD operations
use_auto_pop (bool): If true, each word insertion
automatically moves read head
use_unpop (bool): If true, use SRC_UNPOP to move read head
to the left.
use_pop2 (bool): If true, use two read heads to align
phrases
use_src_eop (bool): If true, expect EOP tokens in the src
sentence
use_copy (bool): If true, move read head at COPY operations
"""
super(OSMPredictor, self).__init__()
update_trg_osm_ids(trg_wmap)
self.use_jumps = use_jumps
self.use_auto_pop = use_auto_pop
self.use_unpop = use_unpop
self.use_src_eop = use_src_eop
if use_src_eop:
update_src_osm_ids(src_wmap)
self.pop_ids = set([OSM_SRC_POP_ID])
if use_pop2:
self.pop_ids.add(OSM_SRC_POP2_ID)
if use_copy:
self.pop_ids.add(OSM_COPY_ID)
self.illegal_sequences = []
if use_jumps:
self.illegal_sequences.extend([
#[OSM_JUMP_FWD_ID, OSM_JUMP_BWD_ID],
#[OSM_JUMP_BWD_ID, OSM_JUMP_FWD_ID],
#[OSM_JUMP_FWD_ID, OSM_SET_MARKER_ID, OSM_JUMP_FWD_ID],
#[OSM_JUMP_FWD_ID, OSM_SET_MARKER_ID, OSM_JUMP_BWD_ID],
#[OSM_JUMP_BWD_ID, OSM_SET_MARKER_ID, OSM_JUMP_FWD_ID],
#[OSM_JUMP_BWD_ID, OSM_SET_MARKER_ID, OSM_JUMP_BWD_ID],
[OSM_SET_MARKER_ID, OSM_SET_MARKER_ID]
])
if use_auto_pop:
self.no_auto_pop = set()
if use_jumps:
self.no_auto_pop.add(OSM_JUMP_FWD_ID)
self.no_auto_pop.add(OSM_JUMP_BWD_ID)
self.no_auto_pop.add(OSM_SET_MARKER_ID)
if use_unpop:
self.no_auto_pop.add(OSM_SRC_UNPOP_ID)
def _is_pop(self, token):
if token in self.pop_ids:
return True
return self.use_auto_pop and token not in self.no_auto_pop
[docs] def initialize(self, src_sentence):
"""Sets the number of source tokens.
Args:
src_sentence (list): Not used
"""
if self.use_src_eop:
self.src_len = src_sentence.count(OSM_EOP_ID) + 1
else:
self.src_len = len(src_sentence)
self.n_holes = 0
self.head = 0
self.n_pop = 0
self.history = []
[docs] def predict_next(self):
"""Apply OSM constraints.
Returns:
dict.
"""
ret = {}
if self.n_pop >= self.src_len:
return {utils.EOS_ID: 0.0} # Force EOS
else:
ret[utils.EOS_ID] = utils.NEG_INF
if self.use_unpop and self.n_pop <= 0:
ret[OSM_SRC_UNPOP_ID] = utils.NEG_INF
if self.use_jumps:
if self.head <= 0:
ret[OSM_JUMP_BWD_ID] = utils.NEG_INF
if self.head >= self.n_holes:
ret[OSM_JUMP_FWD_ID] = utils.NEG_INF
for seq in self.illegal_sequences:
hist = seq[:-1]
if self.history[-len(hist):] == hist:
ret[seq[-1]] = utils.NEG_INF
return ret
[docs] def get_unk_probability(self, posterior):
if self.n_pop >= self.src_len: # Force EOS
return utils.NEG_INF
return 0.0
[docs] def consume(self, word):
"""Updates the number of holes, EOPs, and the head position."""
if not self._is_pop(word):
if self.use_unpop and word == OSM_SRC_UNPOP_ID:
self.n_pop -= 1
else:
self.history.append(word)
else:
self.n_pop += 1
if self.use_jumps:
if word == OSM_SET_MARKER_ID:
self.n_holes += 1
self.head += 1
elif word == OSM_JUMP_FWD_ID:
self.head += 1
elif word == OSM_JUMP_BWD_ID:
self.head -= 1
[docs] def get_state(self):
return self.n_holes, self.head, self.n_pop
[docs] def set_state(self, state):
self.n_holes, self.head, self.n_pop = state
[docs] def is_equal(self, state1, state2):
"""Trivial implementation"""
return state1 == state2
[docs]class ForcedOSMPredictor(Predictor):
"""This predictor allows forced decoding with an OSM output, which
essentially means running the OSM in alignment mode. This predictor
assumes well-formed operation sequences. Please combine this
predictor with the osm constraint predictor to satisfy this
requirement. The state of this predictor is the compiled version of
the current history. It allows terminal symbols which are
consistent with the reference. The end-of-sentence symbol is
supressed until all words in the reference have been consumed.
"""
def __init__(self, trg_wmap, trg_test_file):
"""Creates a new forcedosm predictor.
Args:
trg_wmap (string): Path to the target wmap file. Used to
grap OSM operation IDs.
trg_test_file (string): Path to the plain text file with
the target sentences. Must have the
same number of lines as the number
of source sentences to decode
"""
super(ForcedOSMPredictor, self).__init__()
update_trg_osm_ids(trg_wmap)
self.trg_sentences = []
with open(trg_test_file) as f:
for line in f:
self.trg_sentences.append([int(w)
for w in line.strip().split()])
[docs] def initialize(self, src_sentence):
"""Resets compiled and head.
Args:
src_sentence (list): Not used
"""
self.compiled = ["X"]
self.head = 0
self.cur_trg_sentence = self.trg_sentences[self.current_sen_id]
def _is_complete(self):
"""Returns true if the compiled sentence contains the right
number of terminals.
"""
n_terminals = len([s for s in self.compiled if s != "X"])
return n_terminals == len(self.cur_trg_sentence)
def _generate_alignments(self, align_stub=[], compiled_start_pos=0,
sentence_start_pos=0):
for pos in range(compiled_start_pos, len(self.compiled)):
if self.compiled[pos] != 'X':
word = int(self.compiled[pos])
for sen_pos in range(sentence_start_pos,
len(self.cur_trg_sentence)):
if self.cur_trg_sentence[sen_pos] == word:
self._generate_alignments(
align_stub + [(pos, sen_pos)],
pos+1,
sen_pos+1)
return
self.alignments.append(align_stub)
def _align(self):
possible_words = [set() for _ in range(len(self.compiled))]
self.alignments = []
self._generate_alignments(align_stub=[])
for alignment in self.alignments:
alignment.append((len(self.compiled), len(self.cur_trg_sentence)))
prev_compiled_pos = -1
prev_sentence_pos = -1
for compiled_pos, sentence_pos in alignment:
section_words = set(
self.cur_trg_sentence[prev_sentence_pos+1:sentence_pos])
if section_words:
seen_gap = False
for section_pos in range(prev_compiled_pos+1,
compiled_pos):
if self.compiled[section_pos] == "X":
if seen_gap:
possible_words[section_pos] |= section_words
else:
possible_words[section_pos].add(
self.cur_trg_sentence[prev_sentence_pos
+ section_pos
- prev_compiled_pos])
seen_gap = True
prev_compiled_pos = compiled_pos
prev_sentence_pos = sentence_pos
return possible_words
[docs] def predict_next(self):
"""Apply word reference constraints.
Returns:
dict.
"""
ret = {OSM_SRC_POP_ID: 0.0}
possible_words = self._align()
if possible_words[self.head]:
ret[OSM_SET_MARKER_ID] = 0.0
if any(possible_words[:self.head]):
ret[OSM_JUMP_BWD_ID] = 0.0
if any(possible_words[self.head+1:]):
ret[OSM_JUMP_FWD_ID] = 0.0
if self._is_complete():
ret[utils.EOS_ID] = 0.0
for word in possible_words[self.head]:
ret[word] = 0.0
return ret
[docs] def get_unk_probability(self, posterior):
"""Always returns -inf."""
return utils.NEG_INF
def _jump_op(self, step):
self.head += step
while self.compiled[self.head] != "X":
self.head += step
def _insert_op(self, op):
self.compiled = self.compiled[:self.head] + [op] + \
self.compiled[self.head:]
self.head += 1
[docs] def consume(self, word):
"""Updates the compiled string and the head position."""
if word == OSM_SET_MARKER_ID:
self._insert_op("X")
elif word == OSM_JUMP_FWD_ID:
self._jump_op(1)
elif word == OSM_JUMP_BWD_ID:
self._jump_op(-1)
elif word != OSM_SRC_POP_ID:
self._insert_op(str(word))
[docs] def get_state(self):
return self.compiled, self.head
[docs] def set_state(self, state):
self.compiled, self.head = state
[docs] def is_equal(self, state1, state2):
"""Trivial implementation"""
return state1 == state2
[docs]class BracketPredictor(UnboundedVocabularyPredictor):
"""This predictor constrains the output to well-formed bracket
expressions. It also allows to specify the number of terminals with
an external length distribution file.
"""
def __init__(self, max_terminal_id, closing_bracket_id, max_depth=-1,
extlength_path=""):
"""Creates a new bracket predictor.
Args:
max_terminal_id (int): All IDs greater than this are
brackets
closing_bracket_id (string): All brackets except these ones are
opening. Comma-separated list of integers.
max_depth (int): If positive, restrict the maximum depth
extlength_path (string): If this is set, restrict the
number of terminals to the distribution specified in
the referenced file. Terminals can be implicit: We
count a single terminal between each adjacent opening
and closing bracket.
"""
super(BracketPredictor, self).__init__()
self.max_terminal_id = max_terminal_id
try:
self.closing_bracket_ids = utils.split_comma(closing_bracket_id, int)
except:
self.closing_bracket_ids = [int(closing_bracket_id)]
self.max_depth = max_depth if max_depth >= 0 else 1000000
if extlength_path:
self.length_scores = load_external_lengths(extlength_path)
else:
self.length_scores = None
self.max_length = 1000000
[docs] def initialize(self, src_sentence):
"""Sets the current depth to 0.
Args:
src_sentence (list): Not used
"""
self.cur_depth = 0
self.ends_with_opening = True
self.n_terminals = 0
if self.length_scores:
self.cur_length_scores = self.length_scores[self.current_sen_id]
self.max_length = max(self.cur_length_scores)
def _no_closing_bracket(self):
return {i: utils.NEG_INF for i in self.closing_bracket_ids}
[docs] def predict_next(self, words):
"""If the maximum depth is reached, exclude all opening
brackets. If history is not balanced, exclude EOS. If the
current depth is zero, exclude closing brackets.
Args:
words (list): Set of words to score
Returns:
dict.
"""
if self.cur_depth == 0:
# Balanced: Score EOS with extlengths, supress closing bracket
if self.ends_with_opening: # Initial predict next call
ret = self._no_closing_bracket()
ret[utils.EOS_ID] = utils.NEG_INF
return ret
return {utils.EOS_ID: self.cur_length_scores.get(
self.n_terminals, utils.NEG_INF)
if self.length_scores else 0.0}
# Unbalanced: do not allow EOS
ret = {utils.EOS_ID: utils.NEG_INF}
if (self.cur_depth >= self.max_depth
or self.n_terminals >= self.max_length):
# Do not allow opening brackets
ret.update({w: utils.NEG_INF for w in words
if (w > self.max_terminal_id
and not w in self.closing_bracket_ids)})
if (self.length_scores
and self.cur_depth == 1
and self.n_terminals > 0
and not self.n_terminals in self.cur_length_scores):
# Do not allow to go back to depth 0 with wrong number of terminals
ret.update(self._no_closing_bracket())
return ret
[docs] def get_unk_probability(self, posterior):
"""Always returns 0.0"""
if self.cur_depth == 0 and not self.ends_with_opening:
return utils.NEG_INF
return 0.0
[docs] def consume(self, word):
"""Updates current depth and the number of consumed terminals."""
if word in self.closing_bracket_ids:
if self.ends_with_opening:
self.n_terminals += 1
self.cur_depth -= 1
self.ends_with_opening = False
elif word > self.max_terminal_id:
self.cur_depth += 1
self.ends_with_opening = True
[docs] def get_state(self):
"""Returns the current depth and number of consumed terminals"""
return self.cur_depth, self.n_terminals, self.ends_with_opening
[docs] def set_state(self, state):
"""Sets the current depth and number of consumed terminals"""
self.cur_depth, self.n_terminals, self.ends_with_opening = state
[docs] def is_equal(self, state1, state2):
"""Trivial implementation"""
return state1 == state2