Source code for cam.sgnmt.predictors.tf_nizza

# -*- coding: utf-8 -*-
# coding=utf-8
# Copyright 2019 The SGNMT Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module integrates Nizza alignment models.

https://github.com/fstahlberg/nizza
"""

import logging
import os

import numpy as np
try:
    from scipy.special import logsumexp
except ImportError:  # Older SciPy releases only provide it in scipy.misc
    from scipy.misc import logsumexp

from cam.sgnmt import utils
from cam.sgnmt.predictors.core import Predictor

try:
    import tensorflow as tf
    from tensorflow.python.training import saver
    from tensorflow.python.training import training
    # Requires nizza
    from nizza import registry
    from nizza.utils import common_utils
except ImportError:
    pass # Deal with it in decode.py


class BaseNizzaPredictor(Predictor):
    """Common functionality for Nizza based predictors.

    This includes loading checkpoints, creating sessions, and creating
    computation graphs.
    """

    def __init__(self, src_vocab_size, trg_vocab_size, model_name,
                 hparams_set_name, checkpoint_dir, single_cpu_thread,
                 nizza_unk_id=None):
        """Initializes a nizza predictor.

        Args:
            src_vocab_size (int): Source vocabulary size (called
                                  inputs_vocab_size in nizza)
            trg_vocab_size (int): Target vocabulary size (called
                                  targets_vocab_size in nizza)
            model_name (string): Name of the nizza model
            hparams_set_name (string): Name of the nizza hyper-parameter set
            checkpoint_dir (string): Path to the Nizza checkpoint directory.
                                     It must contain a file called
                                     ``checkpoint``; the latest checkpoint
                                     of that directory is loaded.
            single_cpu_thread (bool): If true, prevent tensorflow from
                                      doing multithreading.
            nizza_unk_id (int): If set, use this as UNK id. Otherwise, the
                                nizza is assumed to have no UNKs

        Raises:
            IOError if checkpoint file not found.
        """
        super(BaseNizzaPredictor, self).__init__()
        if not os.path.isfile("%s/checkpoint" % checkpoint_dir):
            message = "Checkpoint file %s/checkpoint not found!" % checkpoint_dir
            logging.fatal(message)
            # Attach the message so the failure is understandable even when
            # logging output is not visible to the caller.
            raise IOError(message)
        self._single_cpu_thread = single_cpu_thread
        self._checkpoint_dir = checkpoint_dir
        self._nizza_unk_id = nizza_unk_id
        predictor_graph = tf.Graph()
        with predictor_graph.as_default():
            hparams = registry.get_registered_hparams_set(hparams_set_name)
            hparams.add_hparam("inputs_vocab_size", src_vocab_size)
            hparams.add_hparam("targets_vocab_size", trg_vocab_size)
            run_config = tf.contrib.learn.RunConfig()
            run_config = run_config.replace(model_dir=checkpoint_dir)
            model = registry.get_registered_model(
                model_name, hparams, run_config)
            self._inputs_var = tf.placeholder(
                dtype=tf.int32, shape=[None], name="sgnmt_inputs")
            self._targets_var = tf.placeholder(
                dtype=tf.int32, shape=[None], name="sgnmt_targets")
            # The placeholders hold a single sentence each; a leading
            # dimension of size 1 is added before handing them to the model.
            features = {"inputs": tf.expand_dims(self._inputs_var, 0),
                        "targets": tf.expand_dims(self._targets_var, 0)}
            mode = tf.estimator.ModeKeys.PREDICT
            self.precomputed = model.precompute(features, mode, hparams)
            # Drop the leading dimension again for the next-word scores.
            self.log_probs = tf.squeeze(
                model.predict_next_word(features, hparams, self.precomputed),
                0)
            # Created while predictor_graph is the default graph so that
            # the session is bound to the graph built above.
            self.mon_sess = self.create_session(self._checkpoint_dir)

    def _session_config(self):
        """Creates the ``tf.ConfigProto`` used for this predictor's sessions.

        With ``single_cpu_thread`` set, both intra-op and inter-op
        parallelism are limited to one thread. Otherwise GPU options with
        a per-process memory fraction of 0.95 are used.

        Returns:
            tf.ConfigProto. The session configuration.
        """
        graph_options = tf.GraphOptions(optimizer_options=tf.OptimizerOptions(
            opt_level=tf.OptimizerOptions.L1, do_function_inlining=False))
        if self._single_cpu_thread:
            return tf.ConfigProto(
                intra_op_parallelism_threads=1,
                inter_op_parallelism_threads=1,
                allow_soft_placement=True,
                graph_options=graph_options,
                log_device_placement=False)
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95)
        return tf.ConfigProto(
            allow_soft_placement=True,
            graph_options=graph_options,
            gpu_options=gpu_options,
            log_device_placement=False)

    def create_session(self, checkpoint_dir):
        """Creates a MonitoredSession for this predictor.

        Args:
            checkpoint_dir (string): Directory whose latest checkpoint is
                                     restored into the session.

        Returns:
            MonitoredSession configured via ``_session_config()``.
        """
        checkpoint_path = saver.latest_checkpoint(checkpoint_dir)
        return training.MonitoredSession(
            session_creator=training.ChiefSessionCreator(
                checkpoint_filename_with_path=checkpoint_path,
                config=self._session_config()))

    def get_unk_probability(self, posterior):
        """Fetch posterior[nizza_unk_id] or return NEG_INF if it is None."""
        if self._nizza_unk_id is None:
            return utils.NEG_INF
        return posterior[self._nizza_unk_id]
class NizzaPredictor(BaseNizzaPredictor):
    """Predictor which scores the next target token with a Nizza
    alignment model.

    The posterior over the target vocabulary is obtained from the
    ``log_probs`` tensor of the base class, i.e. from the model's
    predict_next_word() implementation.
    """

    def predict_next(self):
        """Run the Nizza model in ``self.mon_sess`` for the next position.

        The target history is extended by one padding symbol before it is
        fed to the model. The padding entry of the result is masked out.

        Returns:
            The scores fetched from ``self.log_probs`` with the PAD entry
            set to ``utils.NEG_INF``.
        """
        feed = {
            self._inputs_var: self.src_sentence,
            self._targets_var: self.consumed + [common_utils.PAD_ID],
        }
        posterior = self.mon_sess.run(self.log_probs, feed)
        # Padding must never be selected as the next word.
        posterior[common_utils.PAD_ID] = utils.NEG_INF
        return posterior

    def initialize(self, src_sentence):
        """Start a new sentence.

        Clears the target history and stores the source sentence with an
        end-of-sentence symbol appended.

        Args:
            src_sentence (list): Source word IDs without EOS.
        """
        self.consumed = []
        eos_suffix = [utils.EOS_ID]
        self.src_sentence = src_sentence + eos_suffix

    def consume(self, word):
        """Extend the target history by ``word``."""
        history = self.consumed
        history.append(word)

    def get_state(self):
        """Return the predictor state, which is the full target history."""
        return self.consumed

    def set_state(self, state):
        """Replace the target history with ``state``."""
        self.consumed = state

    def is_equal(self, state1, state2):
        """Two states are equal if their histories compare equal."""
        return state1 == state2
class LexNizzaPredictor(BaseNizzaPredictor):
    """This predictor is only compatible to Model1-like Nizza models
    which return lexical translation probabilities in precompute(). The
    predictor keeps a list of the same length as the source sentence and
    initializes it with zeros. At each timestep it updates this list by
    the lexical scores Model1 assigned to the last consumed token. The
    predictor score aims to bring up all entries in the list, and thus
    serves as a coverage mechanism over the source sentence.
    """

    def __init__(self, src_vocab_size, trg_vocab_size, model_name,
                 hparams_set_name, checkpoint_dir, single_cpu_thread,
                 alpha, beta, shortlist_strategies,
                 trg2src_model_name="", trg2src_hparams_set_name="",
                 trg2src_checkpoint_dir="",
                 max_shortlist_length=0,
                 min_id=0,
                 nizza_unk_id=None):
        """Initializes a nizza predictor.

        Args:
            src_vocab_size (int): Source vocabulary size (called
                                  inputs_vocab_size in nizza)
            trg_vocab_size (int): Target vocabulary size (called
                                  targets_vocab_size in nizza)
            model_name (string): Name of the nizza model
            hparams_set_name (string): Name of the nizza hyper-parameter set
            checkpoint_dir (string): Path to the Nizza checkpoint directory.
                                     The latest checkpoint in that
                                     directory is loaded.
            single_cpu_thread (bool): If true, prevent tensorflow from
                                      doing multithreading.
            alpha (float): Score for each matching word
            beta (float): Penalty for each uncovered word at the end
            shortlist_strategies (string): Comma-separated list of shortlist
                                           strategies ("top<N>" or
                                           "prob<P>").
            trg2src_model_name (string): Name of the target2source nizza
                                         model
            trg2src_hparams_set_name (string): Name of the nizza
                                         hyper-parameter set for the
                                         target2source model
            trg2src_checkpoint_dir (string): Path to the Nizza checkpoint
                                         directory for the target2source
                                         model. If empty, no
                                         target2source model is used.
            max_shortlist_length (int): If a shortlist exceeds this limit,
                                        initialize the initial coverage
                                        with 1 at this position. If zero,
                                        do not apply any limit
            min_id (int): Do not use IDs below this threshold (filters out
                          most frequent words).
            nizza_unk_id (int): If set, use this as UNK id. Otherwise, the
                                nizza is assumed to have no UNKs

        Raises:
            IOError if checkpoint file not found.
        """
        super(LexNizzaPredictor, self).__init__(
            src_vocab_size, trg_vocab_size, model_name, hparams_set_name,
            checkpoint_dir, single_cpu_thread, nizza_unk_id=nizza_unk_id)
        self.alpha = alpha
        self.alpha_is_zero = alpha == 0.0
        self.beta = beta
        self.shortlist_strategies = utils.split_comma(shortlist_strategies)
        self.max_shortlist_length = max_shortlist_length
        self.min_id = min_id
        if trg2src_checkpoint_dir:
            self.use_trg2src = True
            predictor_graph = tf.Graph()
            with predictor_graph.as_default():
                hparams = registry.get_registered_hparams_set(
                    trg2src_hparams_set_name)
                # Vocabulary sizes are swapped for the reverse direction.
                hparams.add_hparam("inputs_vocab_size", trg_vocab_size)
                hparams.add_hparam("targets_vocab_size", src_vocab_size)
                run_config = tf.contrib.learn.RunConfig()
                run_config = run_config.replace(
                    model_dir=trg2src_checkpoint_dir)
                model = registry.get_registered_model(
                    trg2src_model_name, hparams, run_config)
                # Feed the complete target vocabulary as one "sentence".
                features = {
                    "inputs": tf.expand_dims(tf.range(trg_vocab_size), 0)}
                mode = tf.estimator.ModeKeys.PREDICT
                trg2src_lex_logits = model.precompute(features, mode, hparams)
                # Precompute trg2src partitions
                partitions = tf.reduce_logsumexp(trg2src_lex_logits, axis=-1)
                self._trg2src_src_words_var = tf.placeholder(
                    dtype=tf.int32, shape=[None],
                    name="sgnmt_trg2src_src_words")
                # trg2src_lex_logits has shape
                # [1, trg_vocab_size, src_vocab_size]
                self.trg2src_logits = tf.gather(
                    tf.transpose(trg2src_lex_logits[0, :, :]),
                    self._trg2src_src_words_var)
                # trg2src_logits has shape [len(src_words), trg_vocab_size]
                self.trg2src_mon_sess = self.create_session(
                    trg2src_checkpoint_dir)
                logging.debug("Precomputing lexnizza trg2src partitions...")
                self.trg2src_partitions = self.trg2src_mon_sess.run(
                    partitions)
        else:
            self.use_trg2src = False
            logging.warning(
                "No target-to-source model specified for lexnizza.")

    def get_unk_probability(self, posterior):
        """Return the score for words which are not in ``posterior``.

        With ``alpha == 0`` this is 0.0. Otherwise it is the posterior
        entry of the nizza UNK id, or ``utils.NEG_INF`` if no UNK id is
        set.
        """
        if self.alpha_is_zero:
            return 0.0
        if self._nizza_unk_id is None:
            return utils.NEG_INF
        return posterior[self._nizza_unk_id]

    def predict_next(self):
        """Score the next target token based on the current coverage.

        Returns:
            If ``alpha`` is zero, a dict which only contains the EOS entry
            with the score ``-n_uncovered * beta``. Otherwise an array
            over the target vocabulary holding the element-wise maximum of
            the alpha score vectors of all uncovered source positions,
            with the EOS entry set to ``-n_uncovered * beta``. If every
            position is covered, an all-zero array is returned.
        """
        if self.alpha_is_zero:
            n_uncovered = self.coverage.count("0")
            return {utils.EOS_ID: -float(n_uncovered) * self.beta}
        uncovered_scores = [
            self.short_list_scores[src_pos]
            for src_pos, is_covered in enumerate(self.coverage)
            if is_covered == "0"]
        if not uncovered_scores:
            return np.zeros(self.trg_vocab_size)
        scores = np.max(uncovered_scores, axis=0)
        scores[utils.EOS_ID] = -len(uncovered_scores) * self.beta
        return scores

    def initialize(self, src_sentence):
        """Compute short lists and the initial coverage for a sentence.

        Source words with IDs below ``min_id`` are dropped. For each
        remaining source position a short list of target words is built
        from the lexical scores. Positions whose short list is longer than
        ``max_shortlist_length`` (if non-zero) start as covered and get an
        empty short list.

        Args:
            src_sentence (list): Source word IDs.
        """
        self.filt_src_sentence = [w for w in src_sentence if w >= self.min_id]
        scores = self.mon_sess.run(
            self.precomputed, {self._inputs_var: self.filt_src_sentence})
        scores = scores[0, :, :]
        # scores has shape [src_sentence_len, trg_vocab_size]
        self.trg_vocab_size = scores.shape[1]
        if self.use_trg2src:
            trg2src_logits = self.trg2src_mon_sess.run(
                self.trg2src_logits,
                {self._trg2src_src_words_var: self.filt_src_sentence})
            src2trg_logits = scores
            src2trg_partitions = logsumexp(
                src2trg_logits, axis=1, keepdims=True)
            # Normalize both directions and combine them in log space.
            trg2src_logprobs = trg2src_logits - self.trg2src_partitions
            src2trg_logprobs = src2trg_logits - src2trg_partitions
            scores = src2trg_logprobs + trg2src_logprobs
        src_len = len(self.filt_src_sentence)
        is_covered = []
        self.short_lists = []
        self.short_list_scores = []
        for src_pos in range(src_len):
            shortlist = self._create_short_list(scores[src_pos, :])
            if (self.max_shortlist_length > 0
                    and len(shortlist) > self.max_shortlist_length):
                # Too unspecific: treat the position as already covered.
                is_covered.append("1")
                shortlist = set()
            else:
                is_covered.append("0")
            self.short_lists.append(shortlist)
            if not self.alpha_is_zero:
                alpha_scores = np.zeros(self.trg_vocab_size)
                for w in shortlist:
                    alpha_scores[w] = self.alpha
                self.short_list_scores.append(alpha_scores)
        self.coverage = "".join(is_covered)
        logging.debug("Short list sizes: %s",
                      ", ".join(str(len(l)) for l in self.short_lists))
        logging.debug("Initial coverage: %s", self.coverage)

    def consume(self, word):
        """Update coverage.

        Every uncovered source position whose short list contains ``word``
        becomes covered.
        """
        self.coverage = "".join(
            "1" if is_covered == "0" and word in self.short_lists[src_pos]
            else is_covered
            for src_pos, is_covered in enumerate(self.coverage))

    def _create_short_list(self, logits):
        """Creates a set of tokens which are likely translations.

        Args:
            logits: Scores over the target vocabulary for one source
                    position. Entries below ``min_id`` are ignored.

        Returns:
            set. Union of the words selected by all shortlist strategies,
            without EOS.

        Raises:
            AttributeError if a shortlist strategy is unknown.
        """
        words = set()
        filt_logits = logits[self.min_id:]
        for strat in self.shortlist_strategies:
            if strat.startswith("top"):
                n = int(strat[3:])
                words.update(utils.argmax_n(filt_logits, n))
            elif strat.startswith("prob"):
                p = float(strat[4:])
                unnorm_probs = np.exp(filt_logits)
                threshold = np.sum(unnorm_probs) * p
                acc = 0.0
                # Add words in order of descending score until the
                # requested fraction of the total mass is reached.
                for word in np.argsort(filt_logits)[::-1]:
                    acc += unnorm_probs[word]
                    words.add(word)
                    if acc >= threshold:
                        break
            else:
                raise AttributeError("Unknown shortlist strategy '%s'" % strat)
        if self.min_id:
            # Indices refer to the filtered logits; shift them back.
            words = set(w + self.min_id for w in words)
        words.discard(utils.EOS_ID)
        return words

    def estimate_future_cost(self, hypo):
        """We use the number of uncovered words times beta as heuristic
        estimate.

        Args:
            hypo: Hypothesis with a ``trgt_sentence`` list of word IDs.

        Returns:
            float. 0.0 if the hypothesis ends with EOS, otherwise
            ``-n_uncovered * beta * 0.1`` where ``n_uncovered`` is the
            number of short lists that share no word with the hypothesis.
        """
        # Check the LAST token: a finished hypothesis has no future cost.
        if hypo.trgt_sentence[-1:] == [utils.EOS_ID]:
            return 0.0
        produced = set(hypo.trgt_sentence)
        # NOTE(review): empty short lists (e.g. positions marked as covered
        # in initialize()) are counted as uncovered here - confirm intent.
        n_uncovered = sum(
            1 for short_list in self.short_lists
            if short_list.isdisjoint(produced))
        return -float(n_uncovered) * self.beta * 0.1

    def get_state(self):
        """The predictor state is the coverage vector."""
        return self.coverage

    def set_state(self, state):
        """The predictor state is the coverage vector."""
        self.coverage = state