Source code for cam.sgnmt.decoding.bucket

# -*- coding: utf-8 -*-
# coding=utf-8
# Copyright 2019 The SGNMT Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of the bucket search strategy """

import copy
import logging
import operator

from cam.sgnmt import utils
from cam.sgnmt.decoding.core import Decoder, PartialHypothesis
from cam.sgnmt.utils import INF, NEG_INF
import numpy as np


[docs]class BucketDecoder(Decoder): """The bucket decoder maintains separate buckets for each sentence length. The buckets contain partial hypotheses. In each iteration, the decoder selects a bucket, and expands the best hypothesis in this bucket by one token. The core of the bucket decoder is the bucket selection strategy. The following strategies are available: * 'iter': Puts all buckets in a big loop and iterates through it. With this strategy, the number of hypothesis expansions is equally distributed over the buckets * 'random': (with stochastic=true and bucket_selecto!=difference) Randomly select a non-empty bucket * 'difference': Similar to the heuristic used by the restarting decoder. Select the bucket in which the difference between best and second best hypothesis is minimal * 'maxdiff': Like 'iter', but filters buckets in which the difference between first and second hypo is larger than epsilon. If no such buckets exist, increase epsilon """ def __init__(self, decoder_args, hypo_recombination, max_expansions=0, low_memory_mode = True, beam=1, pure_heuristic_scores = False, diversity_factor = -1.0, early_stopping=True, stochastic=False, bucket_selector='maxscore', bucket_score_strategy='difference', collect_stats_strategy='best'): """Create a new bucket decoder Args: decoder_args (object): Decoder configuration passed through from the configuration API. hypo_recombination (boolean): Activates hypothesis recombination. Hypos are tested only within a bucket max_expansions (int): Maximum number of node expansions for inadmissible pruning. low_memory_mode (bool): Switch on low memory mode at cost of some computational overhead. Limits the number of hypotheses in each bucket to the number of remaining node expansions beam (int): Number of hypotheses which get expanded at once after selecting a bucket. pure_heuristic_scores (bool): If false, hypos are scored with partial score plus future cost estimates. If true, only the future cost estimates are used diversity_factor (float): If this is set to a positive value, we reorder hypos in a bucket by adding a term which counts how many hypos with the same parent have been expanded already early_stopping (boolean): Admissible pruning (works only if scores are non-positive) stochastic (boolean): Stochastic bucket selection. If the bucket selector is not 'difference', this results in random bucket selection. If ``bucket_selector`` is set to 'difference', buckets are randomly selected with probability according the bucket score bucket_selector (string): Bucket selection strategy. 'iter', 'maxscore'. 'score'. See the class docstring for more info bucket_score_strategy (string): Determines the way the buckets are scored. 'difference' between current best word score and best hypo in bucket, 'absolute' hypo score, 'heap' score of top scoring hypo in bucket , 'constant' score of 0.0. collect_stats_strategy (string): best, full, or all. Defines how unigram estimates are collected for heuristic """ super(BucketDecoder, self).__init__(decoder_args) self.max_expansions_param = max_expansions self.low_memory_mode = low_memory_mode self.beam = beam self.pure_heuristic_scores = pure_heuristic_scores self.diversity_factor = diversity_factor self.diverse_decoding = (diversity_factor > 0.0) self.early_stopping = early_stopping if stochastic: if bucket_selector == 'score': self.get_bucket = self._get_bucket_stochastic else: self.get_bucket = self._get_bucket_random else: if 'iter' in bucket_selector: self.get_bucket = self._get_bucket_iter self.max_iter = 1000000 if '-' in bucket_selector: _,n = bucket_selector.split("-") self.max_iter = int(n) elif bucket_selector == 'maxscore': self.get_bucket = self._get_bucket_maxscore elif bucket_selector == 'score-end': self.get_bucket = self._get_bucket_score_end elif bucket_selector == 'score': self.get_bucket = self._get_bucket_score else: logging.fatal("Unknown bucket selector") if bucket_score_strategy == 'difference': self.get_bucketscore = self._get_bucketscore_difference elif bucket_score_strategy == 'heap': self.get_bucketscore = self._get_bucketscore_heap elif bucket_score_strategy == 'absolute': self.get_bucketscore = self._get_bucketscore_absolute elif bucket_score_strategy == 'constant': self.get_bucketscore = self._get_bucketscore_constant else: logging.fatal("Unknown bucket score strategy") self.collect_stats_from_partial = (collect_stats_strategy == 'all') if collect_stats_strategy == 'best': self.collect_stats = self._collect_stats_best elif collect_stats_strategy == 'full': self.collect_stats = self._collect_stats_full self.hypo_recombination = hypo_recombination def _get_bucketscore_difference(self, length): return self.best_word_scores[length] - self.buckets[length][0][1].score def _get_bucketscore_heap(self, length): return self.buckets[length][0][0] def _get_bucketscore_absolute(self, length): return -self.buckets[length][0][1].score def _get_bucketscore_constant(self, length): return 0.0 def _get_bucket_iter(self): """Implements the bucket selector 'iter' """ if self.cur_iter > self.max_iter: if self.guaranteed_optimality: logging.info("max iter. Optimality not guaranteed for ID %d" % (self.current_sen_id + 1)) self.guaranteed_optimality = False return -1 last_length = self.last_bucket for length in range(last_length+1, self.max_len): if self.buckets[length]: self.last_bucket = length return length # Restart with first bucket self.cur_iter += 1 for length in range(last_length+1): if self.buckets[length]: self.last_bucket = length return length return -1 def _get_bucket_maxscore(self): """Implements the bucket selector 'maxscore' """ for max_score in range(0, 500, 5): length = self._get_bucket_maxscore_helper(max_score) if length >= 0: return length return -1 def _get_bucket_maxscore_helper(self, max_score): """Helper method for maxscore """ last_length = self.last_bucket for length in range(last_length+1, self.max_len): if self.buckets[length]: score = self.get_bucketscore(length) if score < max_score: self.last_bucket = length return length for length in range(last_length+1): if self.buckets[length]: score = self.get_bucketscore(length) if score < max_score: self.last_bucket = length return length return -1 def _get_bucket_score(self): """Implements the bucket selector 'score' """ best_score = INF best_length = -1 for length in range(self.max_len): if self.buckets[length]: score = self.get_bucketscore(length) if score <= best_score: best_score = score best_length = length return best_length def _get_bucket_score_end(self): """Implements the bucket selector 'score-end' """ last_length = self.last_bucket for length in range(last_length+1, self.max_len): if self.buckets[length]: self.last_bucket = length return length # Restart with best bucket best_score = INF best_length = -1 for length in range(self.max_len): if self.buckets[length]: score = self.get_bucketscore(length) if score <= best_score: best_score = score best_length = length self.last_bucket = best_length return best_length def _get_bucket_random(self): """Implements random bucket selection """ lengths = [l for l in range(self.max_len) if self.buckets[l]] return np.random.choice(lengths) def _get_bucket_stochastic(self): """Implements the stochastic bucket selector 'difference' """ lengths = [] scores = [] for length in range(self.max_len): if self.buckets[length]: score = self.get_bucketscore(length) if score == NEG_INF: return self._get_bucket_difference() lengths.append(length) scores.append(score) if not lengths: return -1 exps = np.exp([-d for d in scores]) total = sum(exps) return np.random.choice(lengths, p=[e/total for e in exps]) def _initialize_decoding(self, src_sentence): """Helper function for ``decode`` to which initializes all the class attributes """ self.initialize_predictors(src_sentence) self.max_expansions = self.get_max_expansions(self.max_expansions_param, src_sentence) init_hypo = PartialHypothesis() init_hypo.predictor_states = self.get_predictor_states() init_hypo.scores = [] init_hypo.parent_hypo_array_idx = 0 # point to guardian self.buckets = [[] for _ in range(self.max_len+1)] self.expanded_hypos = [[] for _ in range(self.max_len+1)] self.buckets[0].append((0.0, init_hypo)) self.expand_counts = [0.0] # with guardian self.expand_backpointers = [0] # with guardian self.last_bucket = 0 self.best_score = self.get_lower_score_bound() self.best_word_scores = [NEG_INF] * (self.max_len+1) self.compressed = [True] * (self.max_len+1) self.guaranteed_optimality = True self.cur_iter = 0 def _activate_hypo(self, hypo, length, heap_score): """Prepares the decoder for expanding the given hypothesis. This may include updating global word scores, loading the predictor states and consume the last word of the hypothesis if necessary. After this method, ``apply_predictors`` computes the next posterior vector """ if (self.collect_stats_from_partial and hypo.score > self.best_word_scores[length]): self.best_word_scores[length] = hypo.score self._update_heap_scores() logging.debug("Expand (best_glob=%f diff=%f (%f-%f) heap=%f exp=%d): %s" % (self.best_score, self.best_word_scores[length] - hypo.score, self.best_word_scores[length], hypo.score, heap_score, self.apply_predictors_count, ' '.join([str(w) for w in hypo.trgt_sentence]))) self.set_predictor_states(copy.deepcopy(hypo.predictor_states)) if not hypo.word_to_consume is None: # Consume if cheap expand self.consume(hypo.word_to_consume) hypo.word_to_consume = None def _get_next_hypo_array_idx(self, hypo): """Used to assign a new hypothesis array index used in ``expand_counts`` and ``expand_backpointers``. This also updates the count and backpointer array. Only called if ``diverse_decoding`` is true. """ idx = hypo.parent_hypo_array_idx while idx != 0: self.expand_counts[idx] += 1.0 idx = self.expand_backpointers[idx] hypo_array_idx = len(self.expand_counts) self.expand_counts.append(0.0) self.expand_backpointers.append(hypo.parent_hypo_array_idx) return hypo_array_idx def _collect_stats_best(self, hypo): if hypo.score > self.best_score: self.best_score = hypo.score self.best_word_scores[:len(hypo.scores) ] = hypo.scores self._update_heap_scores() def _collect_stats_full(self, hypo): update = False for idx, score in enumerate(hypo.scores): if score > self.best_word_scores[idx]: self.best_word_scores[idx] = score update = True if update: self._update_heap_scores() def _register_full_hypo(self, hypo): """Update all class attributes according a new full hypothesis. ``hypo`` is a ``PartialHypothesis`` ending with EOS. """ self.add_full_hypo(hypo.generate_full_hypothesis()) if hypo.score > self.best_score: logging.info("New_best (ID: %d): score=%f exp=%d hypo=%s" % (self.current_sen_id + 1, hypo.score, self.apply_predictors_count, ' '.join([str(w) for w in hypo.trgt_sentence]))) if not self.collect_stats_from_partial: self.collect_stats(hypo) if hypo.score > self.best_score: self.best_score = hypo.score def _update_heap_scores(self): """Called when ``best_word_scores`` has changed and heap scores need to be updated """ for length in range(len(self.buckets)): self._update_heap_score(length) def _update_heap_score(self, length): """``_update_heap_score`` for a single bucket """ new_bucket = [(-self._get_combined_score(h), h) for _,h in self.buckets[length]] self.buckets[length] = new_bucket self.buckets[length].sort(key=operator.itemgetter(0)) def _get_max_bucket_size(self): if not self.low_memory_mode: return 10000000 max_size = max(1, 1 + self.max_expansions - self.apply_predictors_count) if self.hypo_recombination: max_size = 5 * max_size return max_size def _add_new_hypos_to_bucket(self, length, new_hypos): max_size = self._get_max_bucket_size() if self.diverse_decoding: self._update_heap_score(length) if (not self.hypo_recombination or not self.guaranteed_optimality or max_size >= len(new_hypos) + len(self.buckets[length])): self.buckets[length].extend(new_hypos) self.buckets[length].sort(key=operator.itemgetter(0)) self.buckets[length] = self.buckets[length][:max_size] self.compressed[length] = False elif self.compressed[length]: # Equivalence check only for new logging.debug("Add %d hypos to compressed bucket of size %d" % ( len(new_hypos), len(self.buckets[length]))) new_hypos.sort(key=operator.itemgetter(0)) new_bucket = [] oidx = 0 nidx = 0 olen = len(self.buckets[length]) nlen = len(new_hypos) while len(new_bucket) < max_size: oscore = INF if oidx >= olen else self.buckets[length][oidx][0] nscore = INF if nidx >= nlen else new_hypos[nidx][0] if oscore == INF and nscore == INF: break if oscore < nscore: # Add hypos from old bucket without checks new_bucket.append(self.buckets[length][oidx]) oidx += 1 else: # Check equivalence hypo = new_hypos[nidx][1] self.set_predictor_states(copy.deepcopy( hypo.predictor_states)) if not hypo.word_to_consume is None: self.consume(hypo.word_to_consume) hypo.word_to_consume = None hypo.predictor_states = self.get_predictor_states() valid = True for other_hypo in [b for _,b in new_bucket]: if other_hypo.score >= hypo.score and self.are_equal_predictor_states( hypo.predictor_states, other_hypo.predictor_states): valid = False logging.debug("Hypo recombination: %s > %s (compress)" % (other_hypo.trgt_sentence, hypo.trgt_sentence)) break if valid: new_bucket.append((nscore, hypo)) nidx += 1 self.buckets[length] = new_bucket self.compressed[length] = True else: # Compress from scratch hypos = self.buckets[length] + new_hypos logging.debug("Compress bucket of size %d" % len(hypos)) new_hypos.sort(key=operator.itemgetter(0)) new_bucket = [] idx = 0 while len(new_bucket) < max_size and idx < len(hypos): hypo = hypos[idx][1] self.set_predictor_states(copy.deepcopy( hypo.predictor_states)) if not hypo.word_to_consume is None: self.consume(hypo.word_to_consume) hypo.word_to_consume = None hypo.predictor_states = self.get_predictor_states() valid = True for other_hypo in [b for _,b in new_bucket]: if other_hypo.score >= hypo.score and self.are_equal_predictor_states( hypo.predictor_states, other_hypo.predictor_states): valid = False logging.debug("Hypo recombination: %s > %s" % (other_hypo.trgt_sentence, hypo.trgt_sentence)) break if valid: new_bucket.append((hypos[idx][0], hypo)) idx += 1 self.buckets[length] = new_bucket self.compressed[length] = True if (self.hypo_recombination and self.guaranteed_optimality and len(self.buckets[length]) >= max_size): logging.info("Shrunk bucket. Optimality not guaranteed for ID %d" % (self.current_sen_id + 1)) self.guaranteed_optimality = False def _get_combined_score(self, hypo): est_score = -self.estimate_future_cost(hypo) if self.diverse_decoding: cnt = 0.0 idx = hypo.parent_hypo_array_idx while idx != 0: cnt += self.expand_counts[idx] idx = self.expand_backpointers[idx] est_score -= self.diversity_factor * cnt if not self.pure_heuristic_scores: est_score += hypo.score if self.best_score != NEG_INF: est_score -= self.best_score return est_score def _get_min_bucket_score(self, length): max_bucket_size = self._get_max_bucket_size() if len(self.buckets[length]) >= max_bucket_size: return -self.buckets[length][max_bucket_size-1][0] return NEG_INF def _get_hypo(self, length): hypo = None while self.buckets[length] and hypo is None: s,hypo = self.buckets[length].pop(0) if self.early_stopping and hypo.score <= self.best_score: hypo = None else: self._activate_hypo(hypo, length, s) if self.hypo_recombination: hypo.predictor_states = self.get_predictor_states() for other_hypo in self.expanded_hypos[length]: if other_hypo.score >= hypo.score and self.are_equal_predictor_states( hypo.predictor_states, other_hypo.predictor_states): logging.debug("Hypo recombination: %s > %s (activate)" % (other_hypo.trgt_sentence, hypo.trgt_sentence)) hypo = None break if not hypo is None: self.expanded_hypos[length].append(hypo) return hypo
[docs] def decode(self, src_sentence): """Decodes a single source sentence. """ self._initialize_decoding(src_sentence) while self.max_expansions > self.apply_predictors_count: length = self.get_bucket() if length < 0: # No more full buckets break min_next_bucket_score = self._get_min_bucket_score(length+1) hypos_to_add = [] for _ in range(self.beam): # Expand beam_size hypos in this bucket hypo = self._get_hypo(length) if hypo is None: break posterior,score_breakdown = self.apply_predictors() hypo.predictor_states = self.get_predictor_states() if self.diverse_decoding: hypo_array_idx = self._get_next_hypo_array_idx(hypo) for w,score in posterior.items(): exp_hypo = hypo.cheap_expand(w, score, score_breakdown[w]) exp_hypo.scores = hypo.scores + [hypo.score] if self.diverse_decoding: exp_hypo.parent_hypo_array_idx = hypo_array_idx combi_score = self._get_combined_score(exp_hypo) if w == utils.EOS_ID: self._register_full_hypo(exp_hypo) elif (combi_score >= min_next_bucket_score and (exp_hypo.score > self.best_score or not self.early_stopping)): hypos_to_add.append((-combi_score, exp_hypo)) self._add_new_hypos_to_bucket(length+1, hypos_to_add) if self.guaranteed_optimality and self.max_expansions <= self.apply_predictors_count: logging.info("Reached max_node_expansions. Optimality not guaranteed for ID %d" % (self.current_sen_id + 1)) if not self.full_hypos: # Add incomplete longest hypos if no complete logging.warn("No complete hypotheses found for %s" % src_sentence) for hypo in self.buckets[self.max_len]: self.add_full_hypo(hypo.generate_full_hypothesis()) return self.get_full_hypos_sorted()