Source code for cam.sgnmt.predictors.tokenization

# -*- coding: utf-8 -*-
# coding=utf-8
# Copyright 2019 The SGNMT Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module contains wrapper predictors which support decoding with
diverse tokenization. The ``Word2charPredictor`` can be used if the 
decoder operates on fine-grained tokens such as characters, but the
tokenization of a predictor is coarse-grained (e.g. words or subwords).

The ``word2char`` predictor treats all characters that do not occur in
its word-to-character mapping as word boundary characters, and applies
consume and predict_next to the slave predictor whenever such a word
boundary character is consumed.

The ``fsttok`` predictor also masks coarse-grained predictors when SGNMT
uses fine-grained tokens such as characters. This wrapper loads an FST
which transduces character sequences to predictor-unit sequences.
"""

import copy
import logging

from cam.sgnmt import utils
from cam.sgnmt.misc.trie import SimpleTrie
from cam.sgnmt.predictors.core import UnboundedVocabularyPredictor, Predictor
from cam.sgnmt.utils import NEG_INF, common_get


EPS_ID = 0
"""OpenFST's reserved ID for epsilon arcs. """


class CombinedState(object):
    """Combines an FST state with a predictor state. Used by the fsttok
    predictor.
    """

    def __init__(self,
                 fst_node,
                 pred_state,
                 posterior,
                 unconsumed=[],
                 pending_score=0.0):
        self.fst_node = fst_node
        self.pred_state = pred_state
        self.posterior = posterior
        self.unconsumed = list(unconsumed)
        self.pending_score = pending_score
    def traverse_fst(self, trans_fst, char):
        """Returns a list of ``CombinedState``s with the same predictor
        state and posterior, but an ``fst_node`` which is reachable via
        the input label ``char``. If the output tape contains symbols,
        they are added to ``unconsumed``.

        Args:
            trans_fst (Fst): FST to traverse
            char (int): Index of character

        Returns:
            list. List of combined states reachable via ``char``
        """
        ret = []
        self._dfs(trans_fst, ret, self.fst_node, char, self.unconsumed)
        return ret
    def _dfs(self, trans_fst, acc, root_node, char, cur_unconsumed):
        """Helper method for ``traverse_fst`` for traversing the FST
        along ``char`` and epsilon arcs with DFS.

        Args:
            trans_fst (Fst): FST to traverse
            acc (list): Accumulator list
            root_node (int): State in the FST to start
            char (int): Index of character
            cur_unconsumed (list): Unconsumed predictor tokens so far
        """
        for arc in trans_fst.arcs(root_node):
            next_unconsumed = list(cur_unconsumed)
            if arc.olabel != EPS_ID:
                next_unconsumed.append(arc.olabel)
            if arc.ilabel == EPS_ID:
                self._dfs(trans_fst, acc, arc.nextstate, char, next_unconsumed)
            elif arc.ilabel == char:
                acc.append(CombinedState(arc.nextstate,
                                         self.pred_state,
                                         self.posterior,
                                         next_unconsumed,
                                         self.pending_score))
    def score(self, token, predictor):
        """Returns a score which can be added if ``token`` is consumed
        next. This is not necessarily the full score but an upper bound
        on it: continuations will have a score lower than or equal to
        this. We only use the current posterior vector and do not
        consume tokens with the wrapped predictor.
        """
        if token and self.unconsumed:
            self.consume_all(predictor)
        s = self.pending_score
        if token:
            s += self._get_token_score(token, predictor)
        return s
    def consume_all(self, predictor):
        """Consume all unconsumed tokens and update pred_state,
        pending_score, and posterior accordingly.

        Args:
            predictor (Predictor): Predictor instance
        """
        if not self.unconsumed:
            return
        if self.posterior is None:
            self.update_posterior(predictor)
        predictor.set_state(copy.deepcopy(self.pred_state))
        for token in self.unconsumed:
            self.pending_score += self._get_token_score(token, predictor)
            predictor.consume(token)
        self.posterior = predictor.predict_next()
        self.pred_state = copy.deepcopy(predictor.get_state())
        self.unconsumed = []

    def consume_single(self, predictor):
        """Consume a single token in ``self.unconsumed``.

        Args:
            predictor (Predictor): Predictor instance
        """
        if not self.unconsumed:
            return
        if self.posterior is not None:
            self.pending_score += self._get_token_score(self.unconsumed[0],
                                                        predictor)
            self.posterior = None

    def _get_token_score(self, token, predictor):
        """Look up ``token`` in ``self.posterior``."""
        return utils.common_get(self.posterior,
                                token,
                                predictor.get_unk_probability(self.posterior))

    def update_posterior(self, predictor):
        """If ``self.posterior`` is None, call ``predict_next`` to be
        able to score the next tokens.
        """
        if self.posterior is not None:
            return
        predictor.set_state(copy.deepcopy(self.pred_state))
        predictor.consume(self.unconsumed[0])
        self.posterior = predictor.predict_next()
        self.pred_state = copy.deepcopy(predictor.get_state())
        self.unconsumed = self.unconsumed[1:]
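

# Illustrative sketch (not part of the original module): ``CombinedState``
# only requires the transducer to expose ``arcs(state)`` yielding objects
# with ``ilabel``, ``olabel`` and ``nextstate`` attributes, so a toy
# stand-in is enough to show how ``traverse_fst`` skips epsilon input arcs
# while collecting their output labels in ``unconsumed``. All state and
# label IDs below are hypothetical.
def _traverse_fst_sketch():
    from collections import namedtuple
    Arc = namedtuple("Arc", ["ilabel", "olabel", "nextstate"])

    class _ToyFst(object):
        """State 0 has an epsilon input arc emitting token 7; state 1
        consumes character 3 and emits token 8."""
        _arcs = {0: [Arc(EPS_ID, 7, 1)], 1: [Arc(3, 8, 2)]}

        def arcs(self, state):
            return self._arcs.get(state, [])

    start = CombinedState(fst_node=0, pred_state=None, posterior={})
    # Traversing with character 3 yields a single state at FST node 2
    # whose unconsumed predictor tokens are [7, 8].
    return start.traverse_fst(_ToyFst(), 3)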
class FSTTokPredictor(Predictor):
    """This wrapper can be used if the SGNMT decoder operates on the
    character level, but a predictor uses a more coarse-grained
    tokenization. The mapping is defined by an FST which transduces
    character sequences to predictor unit sequences. This wrapper
    maintains a list of ``CombinedState`` objects, which are tuples of
    an FST node and a predictor state for which the following holds:

    - The input labels on the path to the node are consistent with the
      consumed characters
    - The output labels on the path to the node are consistent with the
      predictor states
    """

    def __init__(self, path, fst_unk_id, max_pending_score, slave_predictor):
        """Constructor for the fsttok wrapper

        Args:
            path (string): Path to an FST which transduces characters
                           to predictor tokens
            fst_unk_id (int): ID used to represent UNK in the FSTs
                              (usually 999999998)
            max_pending_score (float): Maximum pending score in a
                                       ``CombinedState`` instance.
            slave_predictor (Predictor): Wrapped predictor
        """
        super(FSTTokPredictor, self).__init__()
        self.max_pending_score = max_pending_score
        self.fst_unk_id = fst_unk_id
        self.slave_predictor = slave_predictor
        if isinstance(slave_predictor, UnboundedVocabularyPredictor):
            logging.fatal("fsttok cannot wrap an unbounded "
                          "vocabulary predictor.")
        self.trans_fst = utils.load_fst(path)
    def initialize(self, src_sentence):
        """Pass through to slave predictor. The source sentence is not
        modified. ``states`` is updated to the initial FST node and
        predictor posterior and state.
        """
        self.slave_predictor.initialize(src_sentence)
        posterior = self.slave_predictor.predict_next()
        self.states = [CombinedState(self.trans_fst.start(),
                                     self.slave_predictor.get_state(),
                                     posterior)]
        self.last_prediction = {}

    def initialize_heuristic(self, src_sentence):
        """Pass through to slave predictor. The source sentence is not
        modified.
        """
        logging.warning("fsttok does not support predictor heuristics")
        self.slave_predictor.initialize_heuristic(src_sentence)

    def predict_next(self):
        self.last_prediction = {}
        for state in self.states:
            self._collect_chars(state, state.fst_node, None)
        return self.last_prediction

    def _collect_chars(self, state, root_node, first_olabel):
        """Recursively builds up ``last_prediction`` by traversing
        epsilon arcs in the FST from ``root_node``.
        """
        for arc in self.trans_fst.arcs(root_node):
            arc_first_olabel = first_olabel if first_olabel else arc.olabel
            if arc.ilabel == EPS_ID:
                self._collect_chars(state, arc.nextstate, arc_first_olabel)
            else:
                score = state.score(arc_first_olabel, self.slave_predictor)
                if arc.ilabel in self.last_prediction:
                    self.last_prediction[arc.ilabel] = max(
                        self.last_prediction[arc.ilabel], score)
                else:
                    self.last_prediction[arc.ilabel] = score

    def get_unk_probability(self, posterior):
        """Always returns negative infinity. Handling UNKs needs to be
        realized by the FST.
        """
        return utils.NEG_INF
    def _choose_better(self, s1, s2):
        """``consume`` merges states if they have the same ``fst_node``.
        This method defines which one to keep. We prefer states that

        1) have fewer unconsumed UNK tokens
        2) have a higher pending_score
        """
        if s1 is None:
            return s2
        n_unk1 = len([1 for t in s1.unconsumed if t == self.fst_unk_id])
        n_unk2 = len([1 for t in s2.unconsumed if t == self.fst_unk_id])
        if n_unk1 > n_unk2:
            return s2
        if n_unk1 < n_unk2:
            return s1
        if s1.pending_score < s2.pending_score:
            return s2
        return s1

    def consume(self, word):
        """Updates ``self.states`` to be consistent with ``word`` and
        consumes all the predictor tokens.
        """
        next_states = []
        for state in self.states:
            next_states.extend(state.traverse_fst(self.trans_fst, word))
        consumed_score = self.last_prediction.get(word, 0.0)
        for state in next_states:
            state.pending_score -= consumed_score
            state.consume_single(self.slave_predictor)
        # If two states share the same fst_node, keep only the better one.
        # Also remove states whose pending_score has dropped too far.
        uniq_states = {}
        for state in next_states:
            if state.pending_score < -self.max_pending_score:
                continue
            n = state.fst_node
            uniq_states[n] = self._choose_better(uniq_states.get(n, None),
                                                 state)
        self.states = list(uniq_states.values())
    def get_state(self):
        return self.states, self.last_prediction

    def set_state(self, state):
        self.states, self.last_prediction = state

    def estimate_future_cost(self, hypo):
        """Not implemented yet."""
        return 0.0
    def set_current_sen_id(self, cur_sen_id):
        """We need to override this method to propagate the current
        sentence id to the slave predictor.
        """
        super(FSTTokPredictor, self).set_current_sen_id(cur_sen_id)
        self.slave_predictor.set_current_sen_id(cur_sen_id)
    def is_equal(self, state1, state2):
        """Not implemented yet."""
        return False
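

# Illustrative sketch (not part of the original module) of how the fsttok
# wrapper is typically assembled around a word- or subword-level predictor.
# The FST path, the slave predictor and the numeric values are hypothetical
# placeholders; in SGNMT this wiring is normally driven by the decoder
# configuration rather than done by hand.
def _fsttok_usage_sketch(slave_predictor):
    wrapper = FSTTokPredictor(path="char2tok.fst",    # hypothetical FST file
                              fst_unk_id=999999998,   # UNK ID inside the FST
                              max_pending_score=5.0,  # hypothetical pruning bound
                              slave_predictor=slave_predictor)
    wrapper.initialize([4, 5, 6])       # hypothetical source word IDs
    posterior = wrapper.predict_next()  # scores over character IDs
    best_char = max(posterior, key=posterior.get)
    wrapper.consume(best_char)
    return wrapper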
class Word2charPredictor(UnboundedVocabularyPredictor):
    """This predictor wraps word-level predictors when SGNMT is running
    on the character level. The mapping between word ID and character
    ID sequence is loaded from the file system. All characters which do
    not appear in that mapping are treated as word boundary markers.
    The wrapper blocks consume and predict_next calls until a word
    boundary marker is consumed, and updates the slave predictor
    according to the word between the last two word boundaries. The
    mapping is done only on the target side, and the source sentences
    are passed through as they are. To use alternative tokenization on
    the source side, see the altsrc predictor wrapper. The word2char
    wrapper is always an ``UnboundedVocabularyPredictor``.
    """

    def __init__(self, map_path, slave_predictor):
        """Creates a new word2char wrapper predictor. The map_path file
        has to be a plain text file, each line containing the mapping
        from a word index to the character index sequence (format:
        word char1 char2... charn).

        Args:
            map_path (string): Path to the mapping file
            slave_predictor (Predictor): Instance of the predictor with
                                         a different wmap than SGNMT
        """
        super(Word2charPredictor, self).__init__()
        self.slave_predictor = slave_predictor
        self.words = SimpleTrie()
        self.word_chars = {}
        with open(map_path) as f:
            for line in f:
                l = [int(x) for x in line.strip().split()]
                word = l[0]
                chars = l[1:]
                self.words.add(chars, word)
                for c in chars:
                    self.word_chars[c] = True
        if isinstance(slave_predictor, UnboundedVocabularyPredictor):
            self._get_stub_prob = self._get_stub_prob_unbounded
            self._start_new_word = self._start_new_word_unbounded
        else:
            self._get_stub_prob = self._get_stub_prob_bounded
            self._start_new_word = self._start_new_word_bounded
    def initialize(self, src_sentence):
        """Pass through to slave predictor. The source sentence is not
        modified.
        """
        self.slave_predictor.initialize(src_sentence)
        self._start_new_word()

    def initialize_heuristic(self, src_sentence):
        """Pass through to slave predictor. The source sentence is not
        modified.
        """
        self.slave_predictor.initialize_heuristic(src_sentence)

    def _update_slave_vars(self, posterior):
        self.slave_unk = self.slave_predictor.get_unk_probability(posterior)
        self.slave_go = common_get(posterior, utils.GO_ID, self.slave_unk)
        self.slave_eos = common_get(posterior, utils.EOS_ID, self.slave_unk)

    def _start_new_word_unbounded(self):
        """start_new_word implementation for unbounded vocabulary slave
        predictors. Needs to set slave_go, slave_eos, and slave_unk.
        """
        self.word_stub = []
        posterior = self.slave_predictor.predict_next([utils.UNK_ID,
                                                       utils.GO_ID,
                                                       utils.EOS_ID])
        self._update_slave_vars(posterior)

    def _start_new_word_bounded(self):
        """start_new_word implementation for bounded vocabulary slave
        predictors. Needs to set slave_go, slave_eos, slave_unk, and
        slave_posterior.
        """
        self.word_stub = []
        self.slave_posterior = self.slave_predictor.predict_next()
        self._update_slave_vars(self.slave_posterior)

    def _get_stub_prob_unbounded(self):
        """get_stub_prob implementation for unbounded vocabulary slave
        predictors.
        """
        word = self.words.get(self.word_stub)
        if word:
            posterior = self.slave_predictor.predict_next([word])
            return common_get(posterior, word, self.slave_unk)
        return self.slave_unk

    def _get_stub_prob_bounded(self):
        """get_stub_prob implementation for bounded vocabulary slave
        predictors.
        """
        word = self.words.get(self.word_stub)
        return common_get(self.slave_posterior,
                          word if word else utils.UNK_ID,
                          self.slave_unk)

    def predict_next(self, trgt_words):
        posterior = {}
        stub_prob = False
        for ch in trgt_words:
            if ch in self.word_chars:
                posterior[ch] = 0.0
            else:  # Word boundary marker
                if stub_prob is False:
                    stub_prob = self._get_stub_prob() if self.word_stub else 0.0
                posterior[ch] = stub_prob
        if utils.GO_ID in posterior:
            posterior[utils.GO_ID] += self.slave_go
        if utils.EOS_ID in posterior:
            posterior[utils.EOS_ID] += self.slave_eos
        return posterior
    def get_unk_probability(self, posterior):
        """This is about the unknown character, not the unknown word.
        Since the word-level slave predictor has no notion of the
        unknown character, we return NEG_INF unconditionally.
        """
        return NEG_INF
    def consume(self, word):
        """If ``word`` is a word boundary marker, truncate ``word_stub``
        and let the slave predictor consume word_stub. Otherwise, extend
        ``word_stub`` by the character.
        """
        if word in self.word_chars:
            self.word_stub.append(word)
        elif self.word_stub:
            word = self.words.get(self.word_stub)
            self.slave_predictor.consume(word if word else utils.UNK_ID)
            self._start_new_word()

    def get_state(self):
        """Pass through to slave predictor."""
        return self.word_stub, self.slave_predictor.get_state()

    def set_state(self, state):
        """Pass through to slave predictor."""
        self.word_stub, slave_state = state
        self.slave_predictor.set_state(slave_state)

    def estimate_future_cost(self, hypo):
        """Not supported."""
        logging.warn("Cannot use future cost estimates of predictors "
                     "wrapped by word2char")
        return 0.0
    def set_current_sen_id(self, cur_sen_id):
        """We need to override this method to propagate the current
        sentence id to the slave predictor.
        """
        super(Word2charPredictor, self).set_current_sen_id(cur_sen_id)
        self.slave_predictor.set_current_sen_id(cur_sen_id)
    def is_equal(self, state1, state2):
        """Pass through to slave predictor."""
        stub1, slave_state1 = state1
        stub2, slave_state2 = state2
        return (stub1 == stub2
                and self.slave_predictor.is_equal(slave_state1, slave_state2))
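

# Illustrative sketch (not part of the original module) of the mapping file
# format consumed by ``Word2charPredictor`` ("word char1 char2 ... charn"
# per line) and of the resulting character-level interface. All IDs, the
# file name and the slave predictor are hypothetical placeholders.
def _word2char_usage_sketch(slave_predictor, map_path="word2char.map"):
    with open(map_path, "w") as f:
        f.write("42 7 12 9\n")   # word 42 is spelled by characters 7 12 9
        f.write("43 7 12\n")     # word 43 is spelled by characters 7 12
    wrapper = Word2charPredictor(map_path, slave_predictor)
    wrapper.initialize([4, 5, 6])   # hypothetical source word IDs
    for char in [7, 12, 9]:         # characters of the current word stub
        wrapper.consume(char)
    # Character 30 does not occur in the mapping, so it acts as a word
    # boundary marker: its score reflects the slave score of word 42.
    return wrapper.predict_next([7, 12, 9, 30])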