Source code for cam.sgnmt.decoding.beam

# -*- coding: utf-8 -*-
# Copyright 2019 The SGNMT Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of the beam search strategy """

import copy
import logging

from cam.sgnmt import utils
from cam.sgnmt.decoding.core import Decoder, PartialHypothesis
import numpy as np


class BeamDecoder(Decoder):
    """This decoder implements standard beam search and several
    variants of it, such as diversity promoting beam search and beam
    search with heuristic future cost estimates. This implementation
    supports risk-free pruning and hypothesis recombination.
    """

    def __init__(self, decoder_args):
        """Creates a new beam decoder instance. The following values
        are fetched from `decoder_args`:

            hypo_recombination (bool): Activates hypo recombination
            beam (int): Absolute beam size. A beam of 12 means that we
                keep track of 12 active hypotheses
            sub_beam (int): Number of children per hypothesis. Set to
                beam size if zero.
            pure_heuristic_scores (bool): Hypotheses to keep in the
                beam are normally selected according to the sum of
                partial hypo score and future cost estimates. If set
                to true, partial hypo scores are ignored.
            diversity_factor (float): If this is set to a positive
                value, we add diversity promoting penalization terms
                to the partial hypothesis scores following Li and
                Jurafsky, 2016
            early_stopping (bool): If true, we stop when the best
                scoring hypothesis ends with </S>. If false, we stop
                when all hypotheses end with </S>. Enable if you are
                only interested in the single best decoding result.
                Disable if you want to create full n-best lists.

        Args:
            decoder_args (object): Decoder configuration passed
                through from the configuration API.
        """
        super(BeamDecoder, self).__init__(decoder_args)
        self.diversity_factor = decoder_args.decoder_diversity_factor
        self.diverse_decoding = (self.diversity_factor > 0.0)
        if self.diversity_factor > 0.0:
            logging.fatal("Diversity promoting beam search is not "
                          "implemented yet")
        self.beam_size = decoder_args.beam
        self.sub_beam_size = decoder_args.sub_beam
        if self.sub_beam_size <= 0:
            self.sub_beam_size = decoder_args.beam
        self.hypo_recombination = decoder_args.hypo_recombination
        self.maintain_best_scores = False
        if decoder_args.early_stopping:
            self.stop_criterion = self._best_eos
            if not self.hypo_recombination:
                # Track the best scores seen in the current iteration
                # so that low-scoring hypotheses can be pruned
                # risk-free in _expand_hypo().
                self.maintain_best_scores = True
                logging.debug("Risk-free beam-search pruning enabled")
        else:
            self.stop_criterion = self._all_eos
        self.pure_heuristic_scores = decoder_args.pure_heuristic_scores

    def _get_combined_score(self, hypo):
        """Combines the partial hypo score with future cost estimates."""
        est_score = -self.estimate_future_cost(hypo)
        if not self.pure_heuristic_scores:
            return est_score + hypo.score
        return est_score

    def _best_eos(self, hypos):
        """Stop criterion for early stopping: returns true (i.e. keep
        decoding) as long as the best scoring hypothesis does not end
        with </S>."""
        return hypos[0].get_last_word() != utils.EOS_ID

    def _all_eos(self, hypos):
        """Stop criterion without early stopping: returns true (i.e.
        keep decoding) as long as at least one hypothesis does not end
        with </S>."""
        for hypo in hypos:
            if hypo.get_last_word() != utils.EOS_ID:
                return True
        return False

    def _expand_hypo(self, hypo):
        """Get the best ``sub_beam_size`` expansions of ``hypo``.

        Args:
            hypo (PartialHypothesis): Hypothesis to expand

        Returns:
            list. List of child hypotheses
        """
        if hypo.score <= self.min_score:
            return []
        self.set_predictor_states(copy.deepcopy(hypo.predictor_states))
        if hypo.word_to_consume is not None:  # Consume if cheap expand
            self.consume(hypo.word_to_consume)
            hypo.word_to_consume = None
        posterior, score_breakdown = self.apply_predictors(self.sub_beam_size)
        hypo.predictor_states = self.get_predictor_states()
        return [hypo.cheap_expand(trgt_word,
                                  posterior[trgt_word],
                                  score_breakdown[trgt_word])
                for trgt_word in posterior]

    def _filter_equal_hypos(self, hypos, scores):
        """Apply hypo recombination to the hypotheses in ``hypos``.

        Args:
            hypos (list): List of hypotheses
            scores (list): Hypo scores with heuristic estimates

        Returns:
            list. List with the hypotheses in ``hypos`` remaining
            after hypothesis recombination.
        """
        new_hypos = []
        for idx in reversed(np.argsort(scores)):
            candidate = hypos[idx]
            self.set_predictor_states(
                copy.deepcopy(candidate.predictor_states))
            if candidate.word_to_consume is not None:
                self.consume(candidate.word_to_consume)
                candidate.word_to_consume = None
            candidate.predictor_states = self.get_predictor_states()
            valid = True
            for hypo in new_hypos:
                if self.are_equal_predictor_states(
                        hypo.predictor_states,
                        candidate.predictor_states):
                    logging.debug("Hypo recombination: %s > %s" % (
                        hypo.trgt_sentence, candidate.trgt_sentence))
                    valid = False
                    break
            if valid:
                new_hypos.append(candidate)
                if len(new_hypos) >= self.beam_size:
                    break
        return new_hypos

    def _get_next_hypos(self, all_hypos, all_scores):
        """Selects the ``beam_size`` highest scoring hypotheses for
        the next iteration, in descending score order."""
        hypos = [all_hypos[idx]
                 for idx in np.argsort(all_scores)[-self.beam_size:]]
        hypos.reverse()
        return hypos

    def _register_score(self, score):
        """Updates ``best_scores`` and ``min_score`` for risk-free
        pruning."""
        if not self.maintain_best_scores:
            return
        self.best_scores.append(score)
        self.best_scores.sort(reverse=True)
        if len(self.best_scores) >= self.beam_size:
            self.best_scores = self.best_scores[:self.beam_size]
            self.min_score = self.best_scores[-1]

    def _get_initial_hypos(self):
        """Get the list of initial ``PartialHypothesis``."""
        return [PartialHypothesis(self.get_predictor_states())]
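    # A small worked illustration of the pruning bookkeeping in
    # _register_score (hypothetical numbers, beam size 2): after
    # registering the scores -1.0, -3.0 and -2.0, best_scores holds
    # [-1.0, -2.0] and min_score is -2.0, so _expand_hypo() drops any
    # partial hypothesis scoring <= -2.0. Assuming the usual case of
    # non-positive word log-probabilities, this pruning is risk-free:
    # expanding such a hypothesis could only lower its score further,
    # so none of its children could enter the next beam.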
    def decode(self, src_sentence):
        """Decodes a single source sentence using beam search."""
        self.initialize_predictors(src_sentence)
        hypos = self._get_initial_hypos()
        it = 0
        while self.stop_criterion(hypos):
            if it > self.max_len:  # prevent infinite loops
                break
            it = it + 1
            next_hypos = []
            next_scores = []
            self.min_score = utils.NEG_INF
            self.best_scores = []
            for hypo in hypos:
                if hypo.get_last_word() == utils.EOS_ID:
                    # Hypothesis is already complete: carry it over
                    # unchanged to the next iteration.
                    next_hypos.append(hypo)
                    next_scores.append(self._get_combined_score(hypo))
                    continue
                for next_hypo in self._expand_hypo(hypo):
                    next_score = self._get_combined_score(next_hypo)
                    if next_score > self.min_score:
                        next_hypos.append(next_hypo)
                        next_scores.append(next_score)
                        self._register_score(next_score)
            if self.hypo_recombination:
                hypos = self._filter_equal_hypos(next_hypos, next_scores)
            else:
                hypos = self._get_next_hypos(next_hypos, next_scores)
        for hypo in hypos:
            if hypo.get_last_word() == utils.EOS_ID:
                self.add_full_hypo(hypo.generate_full_hypothesis())
        if not self.full_hypos:
            logging.warning("No complete hypotheses found for %s"
                            % src_sentence)
            for hypo in hypos:
                self.add_full_hypo(hypo.generate_full_hypothesis())
        return self.get_full_hypos_sorted()
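
# A minimal usage sketch (illustrative only): the attribute names below
# mirror those read in BeamDecoder.__init__ above, but the base Decoder
# constructor reads further configuration fields not shown here, and the
# SimpleNamespace stand-in plus the predictor setup are assumptions, not
# part of the SGNMT configuration API.
#
#     from types import SimpleNamespace
#
#     args = SimpleNamespace(beam=12,
#                            sub_beam=0,                  # 0: use beam size
#                            hypo_recombination=False,
#                            early_stopping=True,         # single-best search
#                            pure_heuristic_scores=False,
#                            decoder_diversity_factor=-1.0)
#     decoder = BeamDecoder(args)
#     # ... register predictors on the decoder before decoding ...
#     hypos = decoder.decode(src_sentence)   # src_sentence: list of word IDs
#     best_translation = hypos[0].trgt_sentence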