# -*- coding: utf-8 -*-
# Copyright 2019 The SGNMT Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of beam search for predictors with multiple
tokenizations.
"""
from abc import abstractmethod
import copy
import heapq
import logging
import codecs
from cam.sgnmt import utils
from cam.sgnmt import io
from cam.sgnmt.decoding.core import Decoder, PartialHypothesis
from cam.sgnmt.predictors.automata import EPS_ID
def is_key_complete(key):
"""Returns true if the key is complete. Complete keys are marked
with a blank symbol at the end of the string. A complete key
corresponds to a full word, incomplete keys cannot be mapped to
word IDs.
Args:
key (string): The key
Returns:
bool. Return true if the last character in ``key`` is blank.
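Example (illustrative)::

    is_key_complete("the ")   # True: the trailing blank marks a full word
    is_key_complete("th")     # False: still an incomplete prefix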
"""
return key and key[-1] == ' '
class WordMapper(object):
"""This class is responsible for the mapping between keys and word
IDs. The multiseg beam search can produce words which are not in
the original word map. This mapper adds these words to
``io.trg_wmap``.
This class implements the GoF singleton design pattern.
"""
singleton = None
"""Singleton instance. Access via ``get_singleton()``. """
@staticmethod
def get_singleton():
"""Get singleton instance of the word mapper. This method
implements lazy initialization.
Returns:
WordMapper. Singleton ``WordMapper`` instance.
"""
if not WordMapper.singleton:
WordMapper.singleton = WordMapper()
return WordMapper.singleton
def __init__(self):
"""Creates a new mapper instance and synchronizes it with
``io.trg_wmap``.
"""
self.max_word_id = 3
self.wmap_len = 0
self.key2id = {}
self.synchronize()
self.reserved_keys = {'<unk> ': utils.UNK_ID,
'<eps> ': utils.UNK_ID,
'<epsilon> ': utils.UNK_ID,
'<s> ': utils.GO_ID,
'</s> ': utils.EOS_ID}
def synchronize(self):
"""Synchronizes the internal state of this mapper with
``io.trg_wmap``. This includes updating the reverse lookup
table and finding the lowest free word ID which can be assigned
to new words.
"""
if self.wmap_len == len(io.trg_wmap):
return
self.key2id = {}
self.max_word_id = 3
for word_id, key in io.trg_wmap.items():
self.max_word_id = max(self.max_word_id, word_id)
self.key2id["%s " % key] = word_id
self.wmap_len = len(io.trg_wmap)
def get_word_id(self, key):
"""Finds a word ID for the given key. If no such key is in the
current word map, create a new entry in ``io.trg_wmap``.
Args:
key (string): key to look up
Returns:
int. Word ID corresponding to ``key``. A new entry is added to
``io.trg_wmap`` if the key cannot be found there.
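Example (illustrative sketch; "house " is a hypothetical key and
``io.trg_wmap`` is assumed to be loaded)::

    mapper = WordMapper.get_singleton()
    mapper.get_word_id("house ")  # existing word ID, or a newly added one
    mapper.get_word_id("</s> ")   # reserved key -> utils.EOS_ID
    mapper.get_word_id("")        # empty key -> utils.UNK_ID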
"""
if not key:
return utils.UNK_ID
if key in self.reserved_keys:
return self.reserved_keys[key]
self.synchronize()
if key in self.key2id:
return self.key2id[key]
self.max_word_id += 1
io.trg_wmap[self.max_word_id] = key[:-1]
self.key2id[key] = self.max_word_id
self.wmap_len += 1
return self.max_word_id
class Tokenizer(object):
"""A tokenizer translates between token sequences and string keys.
It is mainly responsible for matching token sequences from
different predictors together.
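As an illustration of the intended contract,
``tokens2key(key2tokens("house "))`` should reproduce ``"house "``
whenever the key can be represented exactly with the tokenizer's
vocabulary; otherwise UNK is used as a fallback.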
"""
@abstractmethod
def tokens2key(self, tokens):
"""Convert a token sequence to a string key.
Args:
tokens (list): List of token IDs
Returns:
String. The key for the token sequence
"""
raise NotImplementedError
@abstractmethod
def key2tokens(self, key):
"""Convert a key to a sequence of tokens. If this mapping is
ambiguous, return one of the shortest mappings. Use UNK to
match any (sub)string without token correspondence.
Args:
key (string): key to look up
Returns:
list. List of token IDs
"""
raise NotImplementedError
@abstractmethod
def is_word_begin_token(self, token):
"""Returns true if ``token`` is only allowed at word begins. """
raise NotImplementedError
class WordTokenizer(Tokenizer):
"""This tokenizer implements a purly word-level tokenization.
Keys are generated according a standard word map.
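The word map file is expected to contain one entry per line with the
key first and the word ID last. The path may be prefixed with a
maximum word ID to restrict the vocabulary, e.g. (illustrative)::

    WordTokenizer("50003:wmap.en")  # only load entries with ID < 50003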
"""
def __init__(self, path):
self.id2key = {}
self.key2id = {}
try:
split = path.split(":", 1)
max_id = int(split[0])
path = split[1]
except (IndexError, ValueError):  # no "<max_id>:" prefix given
max_id = utils.INF
with codecs.open(path, encoding='utf-8') as f:
for line in f:
entry = line.strip().split()
key = entry[0]
word_id = int(entry[-1])
if word_id < max_id and word_id != utils.UNK_ID:
self.id2key[word_id] = "%s " % key
self.key2id["%s " % key] = word_id
def key2tokens(self, key):
return [self.key2id.get(key, utils.UNK_ID)]
def tokens2key(self, tokens):
if len(tokens) != 1:
return ""
return self.id2key.get(tokens[0], "")
def is_word_begin_token(self, token):
return True
class EOWTokenizer(Tokenizer):
"""This tokenizer reads word maps with explicit </w> endings. This
can be used for subword unit based tokenizers.
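Example (illustrative IDs): a word map entry ``ing</w> 1234`` is
stored under the key ``"ing "`` so that token 1234 closes a word,
whereas an entry ``ing 1235`` is stored under ``"ing"`` and
represents a word-internal piece.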
"""
def __init__(self, path):
self.id2key = {}
self.key2id = {}
with codecs.open(path, encoding='utf-8') as f:
for line in f:
key, word_id = line.strip().split()
if word_id == str(utils.UNK_ID):
continue
if key[-4:] == "</w>":
key = "%s " % key[:-4]
elif key in ['<s>', '</s>']:
key = "%s " % key
self.id2key[int(word_id)] = key
self.key2id[key] = int(word_id)
def key2tokens(self, key):
tokens = self._key2tokens_recursive(key)
return tokens if tokens else [utils.UNK_ID]
def _key2tokens_recursive(self, key, max_len = 100):
if not key:
return []
if max_len <= 0:
return None
if key in self.key2id: # Match of the full key
return [self.key2id[key]]
if max_len <= 1:
return None
best_tokens = None
for l in range(len(key)-1, 0, -1):
if key[:l] in self.key2id:
rest = self._key2tokens_recursive(key[l:], max_len-1)
if rest is not None and len(rest) < max_len:
best_tokens = [self.key2id[key[:l]]] + rest
max_len = len(best_tokens) - 1
return best_tokens
def tokens2key(self, tokens):
return ''.join([self.id2key.get(t, "") for t in tokens])
def is_word_begin_token(self, token):
return token in [utils.GO_ID, utils.EOS_ID]
class MixedTokenizer(Tokenizer):
"""This tokenizer allows to mix word- and character-level
tokenizations like proposed by Wu et al. (2016). Words with
<b>, <m>, and <e> prefixes are treated as character-level
tokens, all others are completed word-level tokens
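Example (illustrative): a rare word such as "foo" may only be
representable by the character-level entries ``<b>f``, ``<m>o`` and
``<e>o``, in which case the key ``"foo "`` maps to those three token
IDs, while a word with its own word-level entry maps to a single
token ID.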
"""
def __init__(self, path):
self.word_key2id = {}
self.b_key2id = {}
self.m_key2id = {}
self.e_key2id = {}
self.id2key = {}
self.mid_tokens = {}
try:
split = path.split(":", 1)
max_id = int(split[0])
path = split[1]
except (IndexError, ValueError):  # no "<max_id>:" prefix given
max_id = utils.INF
with codecs.open(path, encoding='utf-8') as f:
for line in f:
key, token_id_ = line.strip().split()
token_id = int(token_id_)
if token_id == utils.UNK_ID or token_id >= max_id:
continue
if key[:3] == "<b>":
key = key[3:]
self.b_key2id[key] = token_id
elif key[:3] == "<m>":
key = key[3:]
self.m_key2id[key] = token_id
self.mid_tokens[token_id] = True
elif key[:3] == "<e>":
key = "%s " % key[3:]
self.e_key2id[key[:-1]] = token_id
self.mid_tokens[token_id] = True
else:
key = "%s " % key
self.word_key2id[key] = token_id
self.id2key[token_id] = key
def key2tokens(self, key):
if not key:
return []
if key in self.word_key2id:
return [self.word_key2id[key]]
maps = [self.m_key2id] * len(key)
if is_key_complete(key):
maps = maps[:-2] + [self.e_key2id]
maps[0] = self.b_key2id
return [maps[idx].get(key[idx], utils.UNK_ID)
for idx in range(len(maps))]
def tokens2key(self, tokens):
return ''.join([self.id2key.get(t, "") for t in tokens])
def is_word_begin_token(self, token):
return not self.mid_tokens.get(token, False)
class PredictorStub(object):
"""A predictor stub models the state of a predictor given a
continuation.
"""
def __init__(self, tokens, pred_state):
"""Creates a new stub for a certain predictor.
Args:
tokens (list): List of token IDs which correspond to the
key
pred_state (object): Predictor state before consuming
the last token in ``tokens``
"""
self.tokens = tokens
self.pred_state = pred_state
self.score = 0.0
self.score_pos = 0
def has_full_score(self):
"""Returns true if the full token sequence has been scored with
the predictor, i.e. ``self.score`` is the final predictor
score.
"""
return self.score_pos == len(self.tokens)
def score_next(self, token_score):
"""Can be called when the continuation is expanded and the
score of the next token is available
Args:
token_score (float): Predictor score of
self.tokens[self.score_pos]
"""
self.score += token_score
self.score_pos += 1
def expand(self, token, token_score, pred_state):
"""Creates a new predictor stub by adding a (scored) token.
Args:
token (int): Token ID to add
token_score (float): Token score of the added token
pred_state (object): predictor state before consuming
the added token
"""
new_stub = PredictorStub(self.tokens + [token], pred_state)
new_stub.score_pos = self.score_pos + 1
new_stub.score = self.score + token_score
return new_stub
class Continuation(object):
"""A continuation is a partial hypothesis plus the next word. A
continuation can be incomplete if predictors use finer grained
tokenization and the score is not final yet.
"""
def __init__(self, parent_hypo, pred_stubs, key = ''):
"""Create a new continuation.
Args:
parent_hypo (PartialHypothesis): hypo object encoding the
state at the last word
boundary
pred_stubs (list): List of ``PredictorStub`` objects, one
for each predictor
key (string): The lead key for this continuation. All stubs
must be consistent with this key
"""
self.parent_hypo = parent_hypo
self.pred_stubs = pred_stubs
self.key = key
self.score = 0.0
def is_complete(self):
"""Returns true if all predictor stubs are completed, i.e.
the continuation can be mapped unambiguously to a word and the
score is final.
"""
return all([s and s.has_full_score() for s in self.pred_stubs])
def calculate_score(self, pred_weights, defaults = []):
"""Calculates the full word score for this continuation using
the predictor stub scores.
Args:
pred_weights (list): Predictor weights. Length of this list
must match the number of stubs
defaults (list): Score which should be used if a predictor
stub is set to None
Returns:
float. Full score of this continuation, or an optimistic
estimate if the continuation is not complete.
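Example (illustrative): with ``pred_weights = [0.5, 0.5]`` and stub
scores ``-1.0`` and ``-3.0``, the returned score is
``0.5*(-1.0) + 0.5*(-3.0) = -2.0``.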
"""
return sum(map(lambda x: x[0]*x[1],
zip(pred_weights,
[s.score if s else defaults[pidx]
for pidx, s in enumerate(self.pred_stubs)])))
def generate_expanded_hypo(self, decoder):
"""This can be used to create a new ``PartialHypothesis`` which
reflects the state after this continuation. This involves
expanding the history by the new word, updating score and
score_breakdown, and consuming the last tokens in the stubs to save
the final predictor states. If the continuation is complete,
this will result in a new word level hypothesis. If not, the
generated hypo will indicate an incomplete word at the last
position by using the word ID -1.
"""
score_breakdown = []
pred_weights = []
for idx,(p, w) in enumerate(decoder.predictors):
p.set_state(copy.deepcopy(self.pred_stubs[idx].pred_state))
p.consume(self.pred_stubs[idx].tokens[-1])
score_breakdown.append((self.pred_stubs[idx].score, w))
pred_weights.append(w)
word_id = WordMapper.get_singleton().get_word_id(self.key)
return self.parent_hypo.expand(word_id,
decoder.get_predictor_states(),
self.calculate_score(pred_weights),
score_breakdown)
def expand(self, decoder):
for pidx,(p, _) in enumerate(decoder.predictors):
stub = self.pred_stubs[pidx]
if not stub.has_full_score():
p.set_state(copy.deepcopy(stub.pred_state))
p.consume(stub.tokens[stub.score_pos-1])
posterior = p.predict_next()
stub.score_next(utils.common_get(
posterior,
stub.tokens[stub.score_pos],
p.get_unk_probability(posterior)))
stub.pred_state = p.get_state()
class MultisegBeamDecoder(Decoder):
"""This is a version of beam search which can handle predictors
with differing tokenizations. We assume that all tokenizations are
consistent with words, i.e. no token crosses word boundaries. The
search simulates beam search on the word level. At each time step,
we keep the n best hypotheses on the word level. Predictor scores
on finer-grained tokens are collapsed into a single score s.t. they
can be combined with scores from other predictors. This decoder can
produce words without entry in the word map. In this case, words
are added to ``io.trg_wmap``. Consider using the ``output_chars``
option to avoid dealing with the updated word map in the output.
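Example (illustrative file names): with two predictors, passing
``--multiseg_tokenizations word:wmap.en,eow:bpe_wmap.en`` assigns a
word-level tokenizer to the first predictor and a subword tokenizer
with explicit </w> markers to the second; the ``mixed:`` prefix
selects the Wu et al. (2016) style mixed word/character tokenizer.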
"""
def __init__(self,
decoder_args,
hypo_recombination,
beam_size,
tokenizations,
early_stopping = True,
max_word_len = 25):
"""Creates a new beam decoder instance for predictors with
multiple tokenizations.
Args:
decoder_args (object): Decoder configuration passed through
from the configuration API.
hypo_recombination (bool): Activates hypo recombination
beam_size (int): Absolute beam size. A beam of 12 means
that we keep track of 12 active hypotheses
tokenizations (string): Comma separated strings describing
the predictor tokenizations
early_stopping (bool): If true, we stop when the best
scoring hypothesis ends with </S>.
If false, we stop when all hypotheses
end with </S>. Enable if you are
only interested in the single best
decoding result. If you want to
create full 12-best lists, disable this option.
max_word_len (int): Maximum length of a single word
"""
super(MultisegBeamDecoder, self).__init__(decoder_args)
self.hypo_recombination = hypo_recombination
self.beam_size = beam_size
self.stop_criterion = self._best_eos if early_stopping else self._all_eos
self.toks = []
self.max_word_len = max_word_len
if not tokenizations:
logging.fatal("Specify --multiseg_tokenizations!")
for tok_config in tokenizations.split(","):
if tok_config[:6] == "mixed:":
tok = MixedTokenizer(tok_config[6:])
elif tok_config[:4] == "eow:":
tok = EOWTokenizer(tok_config[4:])
else:
if tok_config[:5] == "word:":
tok_config = tok_config[5:]
tok = WordTokenizer(tok_config)
self.toks.append(tok)
def _best_eos(self, hypos):
"""Returns true if the best hypothesis ends with </S>"""
return hypos[0].get_last_word() != utils.EOS_ID
def _all_eos(self, hypos):
"""Returns true if the all hypotheses end with </S>"""
for hypo in hypos[:self.beam_size]:
if hypo.get_last_word() != utils.EOS_ID:
return True
return False
def _rebuild_hypo_list(self, hypos, new_hypo):
"""Add new_hypo to the list of n best complete hypos.
Implements hypothesis recombination.
Returns:
list. Sorted list of n best hypos in hypos + new_hypo
"""
if not self.hypo_recombination:
hypos.append(new_hypo)
else:
combined = False
for idx,hypo in list(enumerate(hypos)):
if hypo.predictor_states and self.are_equal_predictor_states(
hypo.predictor_states,
new_hypo.predictor_states):
if hypo.score >= new_hypo.score: # Keep old one
hypo1 = hypo
hypo2 = new_hypo
else: # Discard old one
hypo1 = new_hypo
hypo2 = hypo
hypos[idx] = new_hypo
logging.debug("Hypo recombination: %s > %s" % (
hypo1.trgt_sentence, hypo2.trgt_sentence))
combined = True
break
if not combined:
hypos.append(new_hypo)
hypos.sort(key=lambda h: h.score, reverse=True)
return hypos[:self.beam_size]
def _get_word_initial_posteriors(self, hypo):
"""Call ``predict_next`` on all predictors to get the
distributions over the first tokens of the next word.
Args:
hypo (PartialHypothesis): Defines the predictor states
Returns:
list. List of posterior vectors for each predictor. The
UNK scores are added to the vectors.
"""
self.apply_predictors_count += 1
self.set_predictor_states(hypo.predictor_states)
posteriors = []
for p, _ in self.predictors:
posterior = p.predict_next()
posterior[utils.UNK_ID] = p.get_unk_probability(posterior)
posteriors.append(posterior)
return posteriors
def _get_initial_stubs(self, predictor, start_posterior, min_score):
"""Get the initial predictor stubs for full word search with a
single predictor.
"""
stubs = []
pred_state = predictor.get_state()
for t, s in utils.common_iterable(start_posterior):
stub = PredictorStub([t], pred_state)
stub.score_next(s)
if stub.score >= min_score:
stubs.append(stub)
stubs.sort(key=lambda s: s.score, reverse=True)
return stubs
def _best_keys_complete(self, stubs, tok):
"""Stopping criterion for single predictor full word search.
We stop full word search if the n best stubs are complete.
"""
return all([is_key_complete(tok.tokens2key(s.tokens))
for s in stubs[:self.beam_size]])
def _search_full_words(self, predictor, start_posterior, tok, min_score):
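"""Runs a token-level beam search with a single predictor until the
best stubs correspond to complete words. At each iteration only the
``beam_size`` best stubs are considered: stubs with empty keys or
keys longer than ``max_word_len`` are dropped, stubs with complete
keys are kept as they are, and all other stubs are expanded with
non-UNK tokens that do not start a new word, pruning expansions
whose score falls below ``min_score``. Returns the stubs sorted by
descending score.
"""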
stubs = self._get_initial_stubs(predictor, start_posterior, min_score)
while not self._best_keys_complete(stubs, tok):
next_stubs = []
for stub in stubs[:self.beam_size]:
key = tok.tokens2key(stub.tokens)
if (not key) or len(key) > self.max_word_len:
continue
if is_key_complete(key):
next_stubs.append(stub)
continue
predictor.set_state(copy.deepcopy(stub.pred_state))
predictor.consume(stub.tokens[-1])
posterior = predictor.predict_next()
pred_state = predictor.get_state()
for t, s in utils.common_iterable(posterior):
if t != utils.UNK_ID and not tok.is_word_begin_token(t):
child_stub = stub.expand(t, s, pred_state)
if child_stub.score >= min_score:
next_stubs.append(child_stub)
stubs = next_stubs
stubs.sort(key=lambda s: s.score, reverse=True)
return stubs
def _get_complete_continuations(self, hypo, min_hypo_score):
"""This is a generator which yields the complete continuations
of ``hypo`` in descending order of score
"""
min_score = min_hypo_score - hypo.score
if min_score > 0.0:
return
pred_weights = list(map(lambda el: el[1], self.predictors))
# Get initial continuations by searching with predictors separately
start_posteriors = self._get_word_initial_posteriors(hypo)
pred_states = self.get_predictor_states()
keys = {}
for pidx, (p,w) in enumerate(self.predictors):
stubs = self._search_full_words(p,
start_posteriors[pidx],
self.toks[pidx],
min_score / w)
n_added = 0
for stub in stubs:
key = self.toks[pidx].tokens2key(stub.tokens)
if is_key_complete(key):
if key in keys: # Add to existing continuation
prev_stub = keys[key].pred_stubs[pidx]
if prev_stub is None or prev_stub.score < stub.score:
keys[key].pred_stubs[pidx] = stub
elif n_added < self.beam_size: # Create new continuation
n_added += 1
cont_stubs = [None] * len(self.predictors)
cont_stubs[pidx] = stub
keys[key] = Continuation(hypo, cont_stubs, key)
# Fill in stubs which are set to None
for cont in keys.values():
for pidx in range(len(self.predictors)):
if cont.pred_stubs[pidx] is None:
stub = PredictorStub(self.toks[pidx].key2tokens(cont.key),
pred_states[pidx])
stub.score_next(utils.common_get(
start_posteriors[pidx],
stub.tokens[0],
start_posteriors[pidx][utils.UNK_ID]))
cont.pred_stubs[pidx] = stub
conts = [(-c.calculate_score(pred_weights), c) for c in keys.values()]
heapq.heapify(conts)
# Iterate through conts, expand if necessary, yield if complete
while conts:
s,cont = heapq.heappop(conts)
if cont.is_complete():
yield -s,cont
else: # Need to rescore with sec predictors
cont.expand(self)
heapq.heappush(conts, (-cont.calculate_score(pred_weights), cont))
def decode(self, src_sentence):
"""Decodes a single source sentence using beam search. """
self.initialize_predictors(src_sentence)
hypos = [PartialHypothesis(self.get_predictor_states())]
guard_hypo = PartialHypothesis(None)
guard_hypo.score = utils.NEG_INF
it = 0
while self.stop_criterion(hypos):
if it > self.max_len: # prevent infinite loops
break
it = it + 1
next_hypos = [guard_hypo]
for hypo in hypos:
if hypo.get_last_word() == utils.EOS_ID:
next_hypos = self._rebuild_hypo_list(next_hypos, hypo)
for s, cont in self._get_complete_continuations(
hypo,
next_hypos[-1].score):
if hypo.score + s < next_hypos[-1].score:
break
next_hypos = self._rebuild_hypo_list(
next_hypos,
cont.generate_expanded_hypo(self))
hypos = [h for h in next_hypos if h.score > utils.NEG_INF]
for hypo in hypos:
if hypo.get_last_word() == utils.EOS_ID:
self.add_full_hypo(hypo.generate_full_hypothesis())
if not self.full_hypos:
logging.warning("No complete hypotheses found for %s" % src_sentence)
for hypo in hypos:
self.add_full_hypo(hypo.generate_full_hypothesis())
return self.get_full_hypos_sorted()