Source code for cam.sgnmt.decoding.bigramgreedy

# -*- coding: utf-8 -*-
# Copyright 2019 The SGNMT Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of the bigram greedy search strategy """

import copy
import logging
import operator

from cam.sgnmt import utils
from cam.sgnmt.decoding.core import Decoder, PartialHypothesis
from cam.sgnmt.misc.trie import SimpleTrie


class BigramGreedyDecoder(Decoder):
    """The bigram greedy decoder collects bigram statistics at each
    node expansion. After each decoding pass, it constructs a new
    hypothesis to rescore by greedily selecting bigrams and gluing
    them together. Afterwards, the new hypothesis is rescored and new
    bigram statistics are collected.
    
    Note that this decoder does not support the ``max_length``
    parameter as it is designed for fixed length decoding problems.
    
    Also note that this decoder works only for bag-of-words problems.
    Do not use the bow predictor in combination with this decoder as
    it will hide the EOS scores which are important to estimate bigram
    scores.
    """
    
    def __init__(self, decoder_args):
        """Creates a new bigram greedy decoder. Do not use this
        decoder in combination with the bow predictor as it inherently
        already satisfies the bag-of-words constraints.
        
        The following values are fetched from ``decoder_args``:
        
            trg_test (string): Path to a plain text file which
                               defines the bag of words
            max_node_expansions (int): Maximum number of node
                                       expansions for inadmissible
                                       pruning.
            early_stopping (boolean): Activates admissible pruning
        
        Args:
            decoder_args (object): Decoder configuration passed
                                   through from the configuration API.
        """
        super(BigramGreedyDecoder, self).__init__(decoder_args)
        self.max_expansions_param = decoder_args.max_node_expansions
        self.early_stopping = decoder_args.early_stopping
        with open(decoder_args.trg_test) as f:
            self.lines = f.read().splitlines()
    
    def _greedy_decode(self):
        """Performs greedy decoding from the start node. Used to
        obtain initial bigram statistics.
        """
        hypo = PartialHypothesis()
        hypos = []
        posteriors = []
        score_breakdowns = []
        bag = dict(self.full_bag)
        while bag:
            posterior, score_breakdown = self.apply_predictors()
            hypo.predictor_states = copy.deepcopy(
                self.get_predictor_states())
            bag_posterior = {w: posterior[w]
                             for w in self.full_bag_with_eos}
            bag_breakdown = {w: score_breakdown[w]
                             for w in self.full_bag_with_eos}
            posteriors.append(bag_posterior)
            score_breakdowns.append(bag_breakdown)
            hypos.append(hypo)
            best_word = utils.argmax({w: bag_posterior[w] for w in bag})
            bag[best_word] -= 1
            if bag[best_word] < 1:
                del bag[best_word]
            self.consume(best_word)
            hypo = hypo.expand(best_word,
                               None,
                               bag_posterior[best_word],
                               score_breakdown[best_word])
        posterior, score_breakdown = self.apply_predictors()
        hypo.predictor_states = copy.deepcopy(self.get_predictor_states())
        bag_posterior = {w: posterior[w] for w in self.full_bag_with_eos}
        bag_breakdown = {w: score_breakdown[w]
                         for w in self.full_bag_with_eos}
        posteriors.append(bag_posterior)
        score_breakdowns.append(bag_breakdown)
        hypos.append(hypo)
        hypo = hypo.cheap_expand(utils.EOS_ID,
                                 bag_posterior[utils.EOS_ID],
                                 score_breakdown[utils.EOS_ID])
        logging.debug("Greedy hypo (%f): %s" % (
            hypo.score,
            ' '.join([str(w) for w in hypo.trgt_sentence])))
        self._process_new_hypos(hypos, posteriors, score_breakdowns, hypo)
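
    # Illustrative trace (hypothetical word IDs, not part of the
    # original source): with full_bag = {4: 2, 7: 1}, the greedy pass
    # calls apply_predictors() three times, each time consuming the
    # highest-scoring word still in the bag (e.g. 4, then 7, then 4),
    # and finally closes the hypothesis with EOS. Every visited node
    # contributes one posterior to the bigram statistics via
    # _process_new_hypos().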

    def _process_new_hypos(self,
                           hypos,
                           posteriors,
                           score_breakdowns,
                           complete_hypo=None):
        """This method is called after a decoding pass. It updates
        bigram statistics, stores partial hypotheses for restarting
        from them later, and creates a full hypothesis if a hypothesis
        ends with EOS.
        """
        if complete_hypo:
            self.best_score = max(self.best_score, complete_hypo.score)
            self.add_full_hypo(complete_hypo.generate_full_hypothesis())
        for idx, hypo in enumerate(hypos):
            posterior = posteriors[idx]
            prefix = hypo.trgt_sentence
            self._register_bigram_scores(
                prefix[-1] if prefix else utils.GO_ID, posterior)
            self.posteriors.add(prefix, posterior)
            self.score_breakdowns.add(prefix, score_breakdowns[idx])
            self.hypos.add(prefix, hypo)
        self._sort_bigram_scores()
    
    def _get_next_sentence(self):
        """Get the next sentence to rescore."""
        bag0 = dict(self.full_bag)
        bag0[utils.GO_ID] = 1
        bag1 = dict(self.full_bag_with_eos)
        return self._get_next_sentence_recursive([], bag0, bag1)
    
    def _get_next_sentence_recursive(self,
                                     bigrams,
                                     remaining_bag0,
                                     remaining_bag1):
        """Recursive helper function for ``_get_next_sentence``.
        
        Args:
            bigrams (list): List of already selected bigrams
            remaining_bag0 (dict): Remaining words in the bag for the
                                   first word in the bigram
            remaining_bag1 (dict): Remaining words in the bag for the
                                   second word in the bigram
        
        Returns:
            Tuple. ``(hypo, sen)`` where ``sen`` is an unexplored
            sentence and ``hypo`` corresponds to the largest explored
            prefix of ``sen``. Returns None if no consistent sentence
            was found.
        """
        if len(bigrams) == self.num_words + 1:
            # Collected enough bigrams
            sens = self._get_sentences_from_bigrams(bigrams)
            if not sens:  # Bigrams are not consistent
                return None
            for sen in sens:
                hypo = self._get_largest_prefix_hypo(sen)
                if hypo and hypo.score > self.best_score:
                    return hypo, sen
            return None
        for bigram in self.sorted_bigrams:
            if (remaining_bag0[bigram[0]] > 0
                    and remaining_bag1[bigram[1]] > 0):
                remaining_bag0[bigram[0]] -= 1
                remaining_bag1[bigram[1]] -= 1
                ret = self._get_next_sentence_recursive(
                    bigrams + [bigram], remaining_bag0, remaining_bag1)
                if ret:
                    return ret
                remaining_bag0[bigram[0]] += 1
                remaining_bag1[bigram[1]] += 1
        return None
    
    def _get_largest_prefix_hypo(self, sen):
        """Get the explored hypothesis with the largest common prefix
        with ``sen``.
        """
        prefix = self.hypos.get_prefix(sen)
        if len(prefix) == len(sen):  # hypo is already fully explored
            return None
        hypo = self.hypos.get(prefix)
        posterior = self.posteriors.get(prefix)
        score_breakdown = self.score_breakdowns.get(prefix)
        next_word = sen[len(prefix)]
        return hypo.cheap_expand(next_word,
                                 posterior[next_word],
                                 score_breakdown[next_word])
    
    def _get_sentences_from_bigrams(self, bigrams):
        """Constructs all full consistent sentences from a list of
        bigrams. The search is implemented as BFS.
        """
        candidates = [([utils.GO_ID], bigrams)]
        for _ in range(len(bigrams)):
            next_candidates = []
            for candidate in candidates:
                # Select the next consistent bigram
                cand_sen, cand_bigrams = candidate
                last_word = cand_sen[-1]
                for idx, bigram in enumerate(cand_bigrams):
                    if bigram[0] == last_word:  # Consistent
                        new_bigrams = list(cand_bigrams)
                        del new_bigrams[idx]
                        next_candidates.append(
                            (cand_sen + [bigram[1]], new_bigrams))
            candidates = next_candidates
            if not candidates:
                break
        return [candidate[0][1:] for candidate in candidates]
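
    # Worked example (hypothetical word IDs, not part of the original
    # source): for the bag {4: 1, 7: 1}, a selection must contain
    # num_words + 1 = 3 bigrams. The selection
    #     (GO_ID, 4), (4, 7), (7, EOS_ID)
    # glues into the single sentence [4, 7, EOS_ID], whereas
    #     (GO_ID, 4), (4, EOS_ID), (7, 7)
    # satisfies the bag counts but cannot be chained, so the BFS above
    # returns no sentence and the recursive search backtracks.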

    def _forced_decode(self, start_hypo, sen):
        """Performs forced decoding of ``sen`` from a node in the
        search tree.
        
        Args:
            start_hypo (PartialHypothesis): This is a partial
                                            hypothesis for a prefix of
                                            ``sen`` from which we
                                            start decoding
            sen (list): Sentence to rescore
        """
        logging.debug("best=%f prefix=%s prefix_score=%f sen=%s" % (
            self.best_score,
            start_hypo.trgt_sentence,
            start_hypo.score,
            sen))
        self.set_predictor_states(
            copy.deepcopy(start_hypo.predictor_states))
        if start_hypo.word_to_consume is not None:
            # Consume if cheap expand
            self.consume(start_hypo.word_to_consume)
        hypos = []
        posteriors = []
        score_breakdowns = []
        hypo = start_hypo
        cancelled = False
        for forced_w in sen[len(start_hypo.trgt_sentence):]:
            posterior, score_breakdown = self.apply_predictors()
            hypo.predictor_states = copy.deepcopy(
                self.get_predictor_states())
            bag_posterior = {w: posterior[w]
                             for w in self.full_bag_with_eos}
            bag_breakdown = {w: score_breakdown[w]
                             for w in self.full_bag_with_eos}
            posteriors.append(bag_posterior)
            score_breakdowns.append(bag_breakdown)
            hypos.append(hypo)
            hypo = hypo.expand(forced_w,
                               None,
                               bag_posterior[forced_w],
                               score_breakdown[forced_w])
            if self.early_stopping and hypo.score < self.best_score:
                cancelled = True
                break
            self.consume(forced_w)
        self._process_new_hypos(hypos,
                                posteriors,
                                score_breakdowns,
                                hypo if not cancelled else None)
    
    def _load_bag(self):
        """Load the current bag of words."""
        self.full_bag = {}
        for w in self.lines[self.current_sen_id].strip().split():
            int_w = int(w)
            self.full_bag[int_w] = self.full_bag.get(int_w, 0) + 1
        self.num_words = sum(self.full_bag.values())
        self.full_bag_with_eos = dict(self.full_bag)
        self.full_bag_with_eos[utils.EOS_ID] = 1
    
    def _register_bigram_scores(self, last_word, posterior):
        for w, score in utils.common_iterable(posterior):
            self.bigram_scores[last_word][w] = min(
                self.bigram_scores[last_word][w], score)
    
    def _sort_bigram_scores(self):
        self.sorted_bigrams = []
        for w1, scores in self.bigram_scores.items():
            self.sorted_bigrams.extend(
                [(w1, w2, score) for w2, score in scores.items()])
        self.sorted_bigrams.sort(key=operator.itemgetter(2),
                                 reverse=True)
    
    def _initialize_bigram_scores(self):
        default_scores = {w: 0.0 for w in self.full_bag_with_eos}
        self.bigram_scores = {w: dict(default_scores)
                              for w in self.full_bag}
        self.bigram_scores[utils.GO_ID] = default_scores
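
    # Bookkeeping sketch (hypothetical numbers, not part of the
    # original source): bigram scores start at the log-probability
    # upper bound 0.0 and are only ever lowered. If word 7 received
    # the scores -0.4 and -1.2 after contexts ending in word 4,
    # _register_bigram_scores() keeps
    #     bigram_scores[4][7] = min(0.0, -0.4, -1.2) = -1.2,
    # a pessimistic estimate. _sort_bigram_scores() then flattens the
    # table into sorted_bigrams = [(w1, w2, score), ...] in descending
    # order, so the recursive search tries promising bigrams first.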

    def decode(self, src_sentence):
        """Decodes a single source sentence with the bigram greedy
        decoder.
        """
        self.initialize_predictors(src_sentence)
        self.max_expansions = self.get_max_expansions(
            self.max_expansions_param, src_sentence)
        self._load_bag()
        self.hypos = SimpleTrie()
        self.posteriors = SimpleTrie()
        self.score_breakdowns = SimpleTrie()
        self.best_score = self.get_lower_score_bound()
        self._initialize_bigram_scores()
        self._greedy_decode()
        while self.max_expansions > self.apply_predictors_count:
            ret = self._get_next_sentence()
            if not ret:
                break
            self._forced_decode(ret[0], ret[1])
        return self.get_full_hypos_sorted()
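

# Minimal usage sketch (assumptions: the predictor `my_predictor` and
# the exact set of fields the Decoder base class reads from
# `decoder_args` are hypothetical; this module itself only uses
# trg_test, max_node_expansions and early_stopping):
#
#     from argparse import Namespace
#     from cam.sgnmt.decoding.bigramgreedy import BigramGreedyDecoder
#
#     args = Namespace(trg_test="bags.txt",  # one bag of word IDs per line
#                      max_node_expansions=200,
#                      early_stopping=True)
#     decoder = BigramGreedyDecoder(args)
#     decoder.add_predictor("nmt", my_predictor)
#     hypos = decoder.decode(src_word_ids)  # sorted full hypotheses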