# Source code for cam.sgnmt.decoding.flip

# -*- coding: utf-8 -*-
# coding=utf-8
# Copyright 2019 The SGNMT Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of the flip search strategy """

import copy
from heapq import heappush, heappop
import logging

from cam.sgnmt import utils
from cam.sgnmt.decoding.core import Decoder, PartialHypothesis
from cam.sgnmt.misc.trie import SimpleTrie


class FlipCandidate(object):
    """Helper class for ``FlipDecoder``. Represents a full but yet
    unscored hypothesis which differs from an explored hypo by one flip
    or move operation.
    """

    def __init__(self, trgt_sentence, scores, bigram_scores, max_score):
        """Creates a new candidate hypothesis.

        Args:
            trgt_sentence (list): Full target sentence
            scores (list): Word level scores. Same length as
                           ``trgt_sentence``
            bigram_scores (dict): Bigram scores collected along the
                                  parent hypothesis
            max_score (float): Maximum possible score this candidate
                               can achieve
        """
        self.trgt_sentence = trgt_sentence
        self.scores = scores
        self.bigram_scores = bigram_scores
        self.max_score = max_score
        # Optimistic estimate used as heap priority by the decoder.
        self.expected_score = sum(scores)

    def __lt__(self, other):
        """Order candidates by (higher) expected score.

        Candidates are pushed onto a heap as ``(-expected_score, cand)``
        tuples. When two priorities tie, tuple comparison falls through
        to comparing the candidates themselves, which raises TypeError
        on Python 3 without this method. Prefer the higher expected
        score on ties.
        """
        return self.expected_score > other.expected_score
class FlipDecoder(Decoder):
    """The flip decoder explores the search space by permutating
    already explored hypotheses with a single permutation operation.
    We support two operations: 'flip' flips the position of two target
    tokens. 'move' moves one target token to another location in the
    sentence.

    Note that this decoder does not support the ``max_length``
    parameter as it is designed for fixed length decoding problems.

    Also note that this decoder works only for bag-of-words problems.
    Do not use the bow predictor in combination with this decoder as it
    will hide the EOS scores which are important to estimate bigram
    scores.
    """

    def __init__(self, decoder_args, always_greedy=False):
        """Creates a new flip decoder. Do not use this decoder in
        combination with the bow predictor as it inherently already
        satisfies the bag-of-word constrains.

        The following values are fetched from `decoder_args`:

            trg_test (string): Path to a plain text file which defines
                               the bag of words
            max_node_expansions (int): Maximum number of node expansions
                                       for inadmissible pruning.
            early_stopping (boolean): Activates admissible pruning

        Args:
            decoder_args (object): Decoder configuration passed through
                                   from the configuration API.
            always_greedy (boolean): Per default, the flip decoder does
                                     forced decoding along the complete
                                     candidate sentence. Set to True to
                                     do greedy decoding from the
                                     backtraced node instead
        """
        super(FlipDecoder, self).__init__(decoder_args)
        self.max_expansions_param = decoder_args.max_node_expansions
        self.early_stopping = decoder_args.early_stopping
        self.always_greedy = always_greedy
        # One line per sentence; each line is the bag of words for that
        # sentence (whitespace separated token ids).
        with open(decoder_args.trg_test) as f:
            self.lines = f.read().splitlines()
        # Bind the permutation strategy once; both methods share the
        # (candidate, max_pos, bigram_scores) signature.
        if decoder_args.flip_strategy == 'flip':
            self._extract_candidates = self._extract_candidates_flip
        elif decoder_args.flip_strategy == 'move':
            self._extract_candidates = self._extract_candidates_move
        else:
            # NOTE(review): logging.fatal only logs — it does not raise,
            # so _extract_candidates stays unset and decoding would fail
            # later with AttributeError. Consider raising here; confirm
            # against how other decoders in this package handle it.
            logging.fatal("Unknown flip strategy!")

    def _greedy_decode(self):
        """Performs greedy decoding from the start node. Used to obtain
        the initial hypothesis which seeds the candidate heap.
        """
        hypo = PartialHypothesis()
        hypos = []              # partial hypos at each time step
        posteriors = []         # predictor posterior at each time step
        score_breakdowns = []   # per-predictor breakdown at each step
        scores = []             # word-level score of the chosen word
        bag = dict(self.full_bag)  # working copy; counts are decremented
        while bag:
            posterior, score_breakdown = self.apply_predictors()
            hypo.predictor_states = copy.deepcopy(
                self.get_predictor_states())
            hypos.append(hypo)
            posteriors.append(posterior)
            score_breakdowns.append(score_breakdown)
            # Greedily pick the best-scoring word still in the bag.
            best_word = utils.argmax({w: posterior[w] for w in bag})
            bag[best_word] -= 1
            if bag[best_word] < 1:
                del bag[best_word]
            self.consume(best_word)
            hypo = hypo.expand(best_word,
                               None,
                               posterior[best_word],
                               score_breakdown[best_word])
            scores.append(posterior[best_word])
        # Bag exhausted: close the hypothesis with EOS.
        posterior, score_breakdown = self.apply_predictors()
        hypo.predictor_states = self.get_predictor_states()
        hypos.append(hypo)
        posteriors.append(posterior)
        score_breakdowns.append(score_breakdown)
        hypo = hypo.expand(utils.EOS_ID,
                           None,
                           posterior[utils.EOS_ID],
                           score_breakdown[utils.EOS_ID])
        logging.debug("Greedy hypo (%f): %s" % (
            hypo.score,
            ' '.join([str(w) for w in hypo.trgt_sentence])))
        scores.append(posterior[utils.EOS_ID])
        self.best_score = hypo.score
        self.add_full_hypo(hypo.generate_full_hypothesis())
        # Seed the search: the greedy hypo becomes the first candidate
        # with optimistic (all-zero) bigram scores.
        self._process_new_hypos(
            FlipCandidate(hypo.trgt_sentence,
                          scores,
                          self._create_dummy_bigrams(),
                          hypo.score),
            len(hypo.trgt_sentence),
            hypos,
            posteriors,
            score_breakdowns)

    def _create_dummy_bigrams(self):
        """Creates a dictionary of optimistic bigram scores which is to
        be updated as the decoder steps through the target sentence.

        Returns:
            dict: Nested dict ``{prev_word: {word: 0.0}}`` over GO, EOS
            and every word in the bag. 0.0 is an upper bound on log
            probabilities, hence "optimistic".
        """
        bigram_scores = {}
        words = [utils.GO_ID, utils.EOS_ID] + [w for w in self.full_bag]
        for w in words:
            bigram_scores[w] = {w2: 0.0 for w2 in words}
        return bigram_scores

    def _is_explored(self, trgt_sentence):
        """Returns true if this target sentence has been explored
        already (i.e. some recorded explored prefix covers it).
        """
        prefix = self.explored.get_prefix(trgt_sentence)
        return self.explored.get(prefix)

    def _process_new_hypos(self,
                           candidate,
                           max_pos,
                           explored_hypos,
                           explored_posteriors,
                           explored_score_breakdowns):
        """This method is called after a new candidate has been
        explored. It derives new candidates by updating the bigram
        scores and delegating to ``_extract_candidates`` for
        permutating the parent target sentence.

        Args:
            candidate (FlipCandidate): The candidate that was explored
            max_pos (int): Positions up to here count as explored
            explored_hypos (list): Partial hypotheses from exploration
            explored_posteriors (list): Posterior per partial hypo
            explored_score_breakdowns (list): Breakdown per partial hypo
        """
        full_hypo_len = len(candidate.trgt_sentence)
        max_pos = min(full_hypo_len-1,  # -1 because we don't permit moving EOS
                      max_pos)
        self.explored.add(candidate.trgt_sentence[0:max_pos+1], True)
        # If exploration reached EOS, register the full hypothesis.
        if (explored_hypos
                and explored_hypos[-1].trgt_sentence[-1] == utils.EOS_ID):
            full_hypo = explored_hypos[-1].generate_full_hypothesis()
            self.best_score = max(self.best_score, full_hypo.total_score)
            self.add_full_hypo(full_hypo)
        # Update data structures
        bigram_scores = copy.deepcopy(candidate.bigram_scores)
        for idx, hypo in enumerate(explored_hypos):
            prev_word = hypo.trgt_sentence[-1] if hypo.trgt_sentence \
                                               else utils.GO_ID
            scores = candidate.scores[0:len(hypo.trgt_sentence)]
            for w in self.full_bag:
                score = explored_posteriors[idx][w]
                # Update bigram scores
                # not quite correct if we have two occurrences of the same word
                bigram_scores[prev_word][w] = score
                # Update tries: cache the expanded hypo so later
                # candidates can backtrace to it without re-decoding.
                exp_hypo = hypo.cheap_expand(w,
                                             score,
                                             explored_score_breakdowns[idx][w])
                exp_hypo.scores = scores + [score]
                self.hypos.add(exp_hypo.trgt_sentence, exp_hypo)
        # Extract candidates
        self._extract_candidates(candidate, max_pos, bigram_scores)

    def _extract_candidates_move(self, candidate, max_pos, bigram_scores):
        """Implements the traversal strategy 'move': remove the token at
        ``from_pos`` and re-insert it at ``to_pos``.
        """
        full_hypo_len = len(candidate.trgt_sentence)
        for from_pos in range(full_hypo_len-1):
            from_word = candidate.trgt_sentence[from_pos]
            stub = candidate.trgt_sentence[0:from_pos] \
                   + candidate.trgt_sentence[from_pos+1:]
            for to_pos in range(full_hypo_len-1):
                # First position whose word changed; everything before
                # it keeps its old score.
                change_pos = min(from_pos, to_pos)
                if change_pos >= max_pos:
                    continue
                trgt_sentence = list(stub)
                trgt_sentence.insert(to_pos, from_word)
                if self._is_explored(trgt_sentence):
                    continue
                # Re-estimate scores from change_pos on via bigrams.
                scores = list(candidate.scores)
                prev_word = trgt_sentence[change_pos-1] if change_pos > 0 \
                                                        else utils.GO_ID
                for idx in range(change_pos, full_hypo_len):
                    word = trgt_sentence[idx]
                    scores[idx] = bigram_scores[prev_word][word]
                    prev_word = word
                self._add_candidate(FlipCandidate(
                    trgt_sentence,
                    scores,
                    bigram_scores,
                    sum(scores[0:change_pos+1])))

    def _extract_candidates_flip(self, candidate, max_pos, bigram_scores):
        """Implements the traversal strategy 'flip': swap the tokens at
        ``from_pos`` and ``to_pos``.
        """
        full_hypo_len = len(candidate.trgt_sentence)
        for from_pos in range(max_pos):
            # Upper bound: unchanged prefix score before from_pos.
            max_score = sum(candidate.scores[0:from_pos])
            from_word = candidate.trgt_sentence[from_pos]
            for to_pos in range(from_pos+1, full_hypo_len-1):
                trgt_sentence = list(candidate.trgt_sentence)
                trgt_sentence[from_pos] = trgt_sentence[to_pos]
                trgt_sentence[to_pos] = from_word
                if self._is_explored(trgt_sentence):
                    continue
                # Re-estimate scores from from_pos on via bigrams.
                scores = list(candidate.scores)
                prev_word = trgt_sentence[from_pos-1] if from_pos > 0 \
                                                      else utils.GO_ID
                for idx in range(from_pos, full_hypo_len):
                    word = trgt_sentence[idx]
                    scores[idx] = bigram_scores[prev_word][word]
                    prev_word = word
                self._add_candidate(FlipCandidate(trgt_sentence,
                                                  scores,
                                                  bigram_scores,
                                                  max_score+scores[from_pos]))

    def _add_candidate(self, candidate):
        """Add a candidate to the heap, keyed by negated expected score
        (heapq is a min-heap). Candidates that cannot beat the current
        best score are dropped (admissible pruning).
        """
        if candidate.max_score > self.best_score:
            heappush(self.open_candidates,
                     (-candidate.expected_score, candidate))

    def _explore_candidate(self, candidate):
        """Explores a candidate, adds it to the list of full hypotheses
        (if not pruned by early stopping) and calls
        ``_process_new_hypos`` to derive new candidates.
        """
        # Backtrace to the longest cached prefix of this sentence.
        prefix = self.hypos.get_prefix(candidate.trgt_sentence)
        hypo = self.hypos.get(prefix)
        self.set_predictor_states(copy.deepcopy(hypo.predictor_states))
        for pos, score in enumerate(hypo.scores):
            # Update candidate scores
            candidate.scores[pos] = score
        self.consume(hypo.word_to_consume)
        hypos = []
        posteriors = []
        score_breakdowns = []
        for pos in range(len(prefix), len(candidate.trgt_sentence)):
            if self.early_stopping and hypo.score <= self.best_score:
                break  # admissible pruning
            posterior, score_breakdown = self.apply_predictors()
            hypo.predictor_states = copy.deepcopy(
                self.get_predictor_states())
            hypos.append(hypo)
            posteriors.append(posterior)
            score_breakdowns.append(score_breakdown)
            logging.debug("Explored: %s (%f)" % (
                ' '.join([str(w) for w in hypo.trgt_sentence]),
                hypo.score))
            word = candidate.trgt_sentence[pos]
            score = posterior[word]
            if self.always_greedy:
                # change sentence st best word is at this pos: scan the
                # remaining words (EOS excluded) for a better-scoring,
                # not-yet-explored continuation and move it here.
                best_pos = pos
                stub = candidate.trgt_sentence[0:pos]
                for p in range(pos+1, len(candidate.trgt_sentence)-1):
                    if (posterior[candidate.trgt_sentence[p]] > score
                            and not self._is_explored(
                                stub + [candidate.trgt_sentence[p]])):
                        best_pos = p
                        word = candidate.trgt_sentence[p]
                        score = posterior[word]
                if best_pos != pos:
                    # Mutates the candidate in place.
                    del candidate.trgt_sentence[best_pos]
                    del candidate.scores[best_pos]
                    candidate.trgt_sentence.insert(pos, word)
                    candidate.scores.insert(pos, score)
            self.consume(word)
            hypo = hypo.expand(word, None, score, score_breakdown[word])
            candidate.scores[pos] = posterior[word]
        if hypo.trgt_sentence[-1] == utils.EOS_ID:
            self.best_score = max(self.best_score, hypo.score)
            self.add_full_hypo(hypo.generate_full_hypothesis())
        # Find the last position whose accumulated score still beats
        # best_score; permutations are only allowed up to there.
        # NOTE(review): max_pos would be unbound if candidate.scores were
        # empty — presumably candidates always have >= 1 token; verify.
        acc = 0.0
        for max_pos, score in enumerate(candidate.scores):
            acc += score
            if acc <= self.best_score:
                break
        self._process_new_hypos(candidate,
                                max_pos,
                                hypos,
                                posteriors,
                                score_breakdowns)

    def _load_bag(self):
        """Load the current bag of words from the line of the bag file
        matching ``current_sen_id``. Builds ``self.full_bag`` mapping
        token id -> count.
        """
        self.full_bag = {}
        for w in self.lines[self.current_sen_id].strip().split():
            int_w = int(w)
            self.full_bag[int_w] = self.full_bag.get(int_w, 0) + 1

    def decode(self, src_sentence):
        """Decodes a single source sentence with the flip decoder.

        Args:
            src_sentence (list): Source sentence (list of token ids)

        Returns:
            list: Full hypotheses found, sorted by score.
        """
        self.initialize_predictors(src_sentence)
        self.max_expansions = self.get_max_expansions(
            self.max_expansions_param, src_sentence)
        self._load_bag()
        self.hypos = SimpleTrie()     # trgt prefix -> cached partial hypo
        self.explored = SimpleTrie()  # trgt prefix -> explored flag
        self.open_candidates = []     # heap of (-expected_score, cand)
        self.best_score = self.get_lower_score_bound()
        self._greedy_decode()
        # Best-first search over permutation candidates until the heap
        # is empty or the expansion budget is spent.
        while (self.open_candidates
               and self.max_expansions > self.apply_predictors_count):
            _, candidate = heappop(self.open_candidates)
            if candidate.max_score <= self.best_score:
                continue  # cannot beat current best (admissible pruning)
            if self._is_explored(candidate.trgt_sentence):
                # Already explored
                continue
            logging.debug(
                "Best: %f Expected: %f Expansions: %d Open: %d Explore %s" % (
                    self.best_score,
                    candidate.expected_score,
                    self.apply_predictors_count,
                    len(self.open_candidates),
                    ' '.join([str(w) for w in candidate.trgt_sentence])))
            self._explore_candidate(candidate)
        return self.get_full_hypos_sorted()