# Source code for cam.sgnmt.predictors.bow

# -*- coding: utf-8 -*-
# coding=utf-8
# Copyright 2019 The SGNMT Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module contains predictors for bag-of-words experiments: the
standard bow predictor, and the bowsearch predictor, which first does
an unrestricted search to construct a skeleton and then restricts the
order of words by that skeleton (in addition to the bag restriction).
"""

import logging

from cam.sgnmt import utils
from cam.sgnmt.decoding.beam import BeamDecoder
from cam.sgnmt.decoding.core import CLOSED_VOCAB_SCORE_NORM_NONE
from cam.sgnmt.misc.trie import SimpleTrie
from cam.sgnmt.misc.unigram import FileUnigramTable, \
    BestStatsUnigramTable, FullStatsUnigramTable, AllStatsUnigramTable
from cam.sgnmt.predictors.core import Predictor
from cam.sgnmt.utils import INF, NEG_INF, MESSAGE_TYPE_FULL_HYPO, \
    MESSAGE_TYPE_DEFAULT


class BagOfWordsPredictor(Predictor):
    """This predictor is similar to the forced predictor, but it does not
    enforce the word order in the reference. Therefore, it assigns 1 to
    all hypotheses which have the words in the reference in any order,
    and -inf to all other hypos.
    """

    def __init__(self,
                 trg_test_file,
                 accept_subsets=False,
                 accept_duplicates=False,
                 heuristic_scores_file="",
                 collect_stats_strategy='best',
                 heuristic_add_consumed=False,
                 heuristic_add_remaining=True,
                 diversity_heuristic_factor=-1.0,
                 equivalence_vocab=-1):
        """Creates a new bag-of-words predictor.

        Args:
            trg_test_file (string): Path to the plain text file with
                the target sentences. Must have the same number of
                lines as the number of source sentences to decode.
                Each line is a whitespace-separated list of integer
                word IDs. The word order in the target sentences is
                not relevant for this predictor.
            accept_subsets (bool): If true, this predictor permits
                EOS even if the bag is not fully consumed yet
            accept_duplicates (bool): If true, counts are not updated
                when a word is consumed. This means that we allow a
                word in a bag to appear multiple times
            heuristic_scores_file (string): Path to the unigram scores
                which are used if this predictor estimates future
                costs
            collect_stats_strategy (string): best, full, or all.
                Defines how unigram estimates are collected for the
                heuristic. Unknown values are logged as an error and
                fall back to 'best'.
            heuristic_add_consumed (bool): Set to true to add the
                difference between actual partial score and unigram
                estimates of consumed words to the predictor heuristic
            heuristic_add_remaining (bool): Set to true to add the sum
                of unigram scores of words remaining in the bag to the
                predictor heuristic
            diversity_heuristic_factor (float): Factor for diversity
                heuristic which penalizes hypotheses with the same bag
                as full hypos. The heuristic is only enabled if this
                is positive.
            equivalence_vocab (int): If positive, predictor states are
                considered equal if the remaining words within that
                vocab and OOVs regarding this vocab are the same. Only
                relevant when using hypothesis recombination
        """
        super(BagOfWordsPredictor, self).__init__()
        with open(trg_test_file) as f:
            self.lines = f.read().splitlines()
        if heuristic_scores_file:
            self.estimates = FileUnigramTable(heuristic_scores_file)
        elif collect_stats_strategy == 'best':
            self.estimates = BestStatsUnigramTable()
        elif collect_stats_strategy == 'full':
            self.estimates = FullStatsUnigramTable()
        elif collect_stats_strategy == 'all':
            self.estimates = AllStatsUnigramTable()
        else:
            logging.error("Unknown statistics collection strategy")
            # Fall back to the default strategy so that self.estimates is
            # always defined; notify() and initialize_heuristic() would
            # otherwise fail later with an AttributeError.
            self.estimates = BestStatsUnigramTable()
        self.accept_subsets = accept_subsets
        self.accept_duplicates = accept_duplicates
        self.heuristic_add_consumed = heuristic_add_consumed
        self.heuristic_add_remaining = heuristic_add_remaining
        self.equivalence_vocab = equivalence_vocab
        if accept_duplicates and not accept_subsets:
            # With duplicates the bag never shrinks, so it never becomes
            # empty and EOS is never offered by predict_next().
            logging.error("You enabled bow_accept_duplicates but not bow_"
                          "accept_subsets. Therefore, the bow predictor will "
                          "never accept end-of-sentence and could cause "
                          "an infinite loop in the search strategy.")
        self.diversity_heuristic_factor = diversity_heuristic_factor
        self.diverse_heuristic = (diversity_heuristic_factor > 0.0)

    def get_unk_probability(self, posterior):
        """Returns negative infinity unconditionally: Words which are not
        in the target sentence have assigned probability 0 by this
        predictor.
        """
        return NEG_INF

    def predict_next(self):
        """If the bag is empty, the only allowed symbol is EOS.
        Otherwise, return the list of keys in the bag (plus EOS if
        ``accept_subsets`` is set), each with score 0.0.
        """
        if not self.bag:  # Empty bag
            return {utils.EOS_ID: 0.0}
        ret = {w: 0.0 for w in self.bag}
        if self.accept_subsets:
            ret[utils.EOS_ID] = 0.0
        return ret

    def initialize(self, src_sentence):
        """Creates a new bag for the current target sentence.

        Args:
            src_sentence (list): Not used
        """
        self.best_hypo_score = NEG_INF
        self.bag = {}
        for token in self.lines[self.current_sen_id].split():
            word = int(token)
            self.bag[word] = self.bag.get(word, 0) + 1
        # Keep an untouched copy of the reference bag for the heuristics.
        self.full_bag = dict(self.bag)

    def consume(self, word):
        """Updates the bag by deleting the consumed word.

        EOS empties the bag. If ``accept_duplicates`` is set, the bag is
        left unchanged so that a word may be consumed any number of
        times (the documented semantics of ``accept_duplicates``).

        Args:
            word (int): Next word to consume
        """
        if word == utils.EOS_ID:
            self.bag = {}
            return
        if word not in self.bag:
            logging.warning("Consuming word which is not in bag-of-words!")
            return
        if self.accept_duplicates:
            # Counts are not updated: the word stays available.
            return
        cnt = self.bag.pop(word)
        if cnt > 1:
            self.bag[word] = cnt - 1

    def get_state(self):
        """State of this predictor is the current bag """
        return self.bag

    def set_state(self, state):
        """State of this predictor is the current bag """
        self.bag = state

    def initialize_heuristic(self, src_sentence):
        """Calls ``reset`` of the used unigram table with estimates
        ``self.estimates`` to clear all statistics from the previous
        sentence, and resets the explored bags for the diversity
        heuristic if it is enabled.

        Args:
            src_sentence (list): Not used
        """
        self.estimates.reset()
        if self.diverse_heuristic:
            self.explored_bags = SimpleTrie()

    def notify(self, message, message_type=MESSAGE_TYPE_DEFAULT):
        """This gets called if this predictor observes the decoder. It
        updates unigram heuristic estimates via passing through this
        message to the unigram table ``self.estimates``. Full hypotheses
        additionally update the explored bags if the diversity heuristic
        is enabled.
        """
        self.estimates.notify(message, message_type)
        if self.diverse_heuristic and message_type == MESSAGE_TYPE_FULL_HYPO:
            self._update_explored_bags(message)

    def _update_explored_bags(self, hypo):
        """This is called if diversity heuristic is enabled. It updates
        ``self.explored_bags`` by incrementing the counter of every
        sorted prefix ``trgt_sentence[:l]`` for ``l < len(trgt_sentence)``
        (the last token, presumably EOS, is not included).
        """
        sen = hypo.trgt_sentence
        for length in range(len(sen)):
            # sorted() turns the prefix into an order-independent bag key
            # and also works if trgt_sentence is a tuple.
            key = sorted(sen[:length])
            cnt = self.explored_bags.get(key) or 0.0
            self.explored_bags.add(key, cnt + 1.0)

    def estimate_future_cost(self, hypo):
        """The bow predictor comes with its own heuristic function. We use
        the sum of scores of the remaining words as future cost estimator.

        Words in ``hypo`` that are not (or no longer) in the reference
        bag do not reduce any count, so remaining counts never become
        negative (relevant with ``accept_duplicates``).
        """
        acc = 0.0
        if self.heuristic_add_remaining:
            remaining = dict(self.full_bag)
            remaining[utils.EOS_ID] = 1
            for w in hypo.trgt_sentence:
                if remaining.get(w, 0) > 0:
                    remaining[w] -= 1
            acc -= sum(cnt * self.estimates.estimate(w)
                       for w, cnt in remaining.items() if cnt > 0)
        if self.diverse_heuristic:
            key = sorted(hypo.trgt_sentence)
            cnt = self.explored_bags.get(key)
            if cnt:
                acc += cnt * self.diversity_heuristic_factor
        if self.heuristic_add_consumed:
            acc -= hypo.score - sum(self.estimates.estimate(w, -1000.0)
                                    for w in hypo.trgt_sentence)
        return acc

    def _get_unk_bag(self, org_bag):
        """Maps all words with ID >= ``equivalence_vocab`` to UNK and
        merges their counts. Returns ``org_bag`` itself if
        ``equivalence_vocab`` is not positive.
        """
        if self.equivalence_vocab <= 0:
            return org_bag
        unk_bag = {}
        for word, cnt in org_bag.items():
            idx = word if word < self.equivalence_vocab else utils.UNK_ID
            unk_bag[idx] = unk_bag.get(idx, 0) + cnt
        return unk_bag

    def is_equal(self, state1, state2):
        """Returns true if the bag is the same (modulo the
        ``equivalence_vocab`` mapping to UNK).
        """
        return self._get_unk_bag(state1) == self._get_unk_bag(state2)
class BagOfWordsSearchPredictor(BagOfWordsPredictor):
    """Combines the bag-of-words predictor with a proxy decoding pass
    which creates a skeleton translation.
    """

    def __init__(self,
                 main_decoder,
                 hypo_recombination,
                 trg_test_file,
                 accept_subsets=False,
                 accept_duplicates=False,
                 heuristic_scores_file="",
                 collect_stats_strategy='best',
                 heuristic_add_consumed=False,
                 heuristic_add_remaining=True,
                 diversity_heuristic_factor=-1.0,
                 equivalence_vocab=-1):
        """Creates a new bag-of-words predictor with pre search

        Args:
            main_decoder (Decoder): Reference to the main decoder
                instance, used to fetch the predictors
            hypo_recombination (bool): Activates hypo recombination
                for the pre decoder
            trg_test_file (string): Path to the plain text file with
                the target sentences. Must have the same number of
                lines as the number of source sentences to decode.
                The word order in the target sentences is not
                relevant for this predictor.
            accept_subsets (bool): If true, this predictor permits
                EOS even if the bag is not fully consumed yet
            accept_duplicates (bool): If true, counts are not updated
                when a word is consumed. This means that we allow a
                word in a bag to appear multiple times
            heuristic_scores_file (string): Path to the unigram scores
                which are used if this predictor estimates future
                costs
            collect_stats_strategy (string): best, full, or all.
                Defines how unigram estimates are collected for the
                heuristic
            heuristic_add_consumed (bool): Set to true to add the
                difference between actual partial score and unigram
                estimates of consumed words to the predictor heuristic
            heuristic_add_remaining (bool): Set to true to add the sum
                of unigram scores of words remaining in the bag to the
                predictor heuristic
            diversity_heuristic_factor (float): Factor for diversity
                heuristic which penalizes hypotheses with the same bag
                as full hypos. The heuristic is only enabled if this
                is positive.
            equivalence_vocab (int): If positive, predictor states are
                considered equal if the remaining words within that
                vocab and OOVs regarding this vocab are the same. Only
                relevant when using hypothesis recombination
        """
        self.main_decoder = main_decoder
        # Beam size 10 for the unrestricted skeleton search.
        self.pre_decoder = BeamDecoder(CLOSED_VOCAB_SCORE_NORM_NONE,
                                       main_decoder.max_len_factor,
                                       hypo_recombination,
                                       10)
        self.pre_decoder.combine_posteriors = main_decoder.combine_posteriors
        super(BagOfWordsSearchPredictor, self).__init__(
            trg_test_file,
            accept_subsets,
            accept_duplicates,
            heuristic_scores_file,
            collect_stats_strategy,
            heuristic_add_consumed,
            heuristic_add_remaining,
            diversity_heuristic_factor,
            equivalence_vocab)
        self.pre_mode = False

    def predict_next(self):
        """If in ``pre_mode``, pass through to super class. Otherwise,
        allow the missing words, the next skeleton word, and EOS if the
        bag is empty or ``accept_subsets`` is set.
        """
        if self.pre_mode:
            return super(BagOfWordsSearchPredictor, self).predict_next()
        if not self.bag:  # Empty bag
            return {utils.EOS_ID: 0.0}
        ret = {w: 0.0 for w in self.missing}
        if self.accept_subsets:
            ret[utils.EOS_ID] = 0.0
        if self.skeleton_pos < len(self.skeleton):
            ret[self.skeleton[self.skeleton_pos]] = 0.0
        return ret

    def initialize(self, src_sentence):
        """If in ``pre_mode``, pass through to super class. Otherwise,
        run the pre search, initialize the skeleton, and re-initialize
        the predictors of the main decoder for the actual search.
        """
        if self.pre_mode:
            return super(BagOfWordsSearchPredictor, self).initialize(
                src_sentence)
        self.pre_mode = True
        try:
            score = self._run_pre_search(src_sentence)
            self._set_up_full_mode()
            logging.debug("BOW Skeleton (score=%f missing=%d): %s",
                          score, sum(self.missing.values()), self.skeleton)
            # Re-initialize all predictors (they were used by the pre
            # decoder). This predictor is still in pre_mode here, so its
            # call to initialize() only resets the bag.
            self.main_decoder.current_sen_id -= 1
            self.main_decoder.initialize_predictors(src_sentence)
        finally:
            # Never leave the predictor stuck in pre_mode, even if the
            # pre search raised.
            self.pre_mode = False

    def _run_pre_search(self, src_sentence):
        """Runs the unrestricted pre decoder and stores its best
        hypothesis (without trailing EOS) in ``self.skeleton``.

        Returns:
            float. Total score of the skeleton, or INF if the pre decoder
            found no hypothesis.
        """
        old_accept_subsets = self.accept_subsets
        old_accept_duplicates = self.accept_duplicates
        # The pre search may use any bag word any number of times and
        # may stop at any point.
        self.accept_subsets = True
        self.accept_duplicates = True
        try:
            self.pre_decoder.predictors = self.main_decoder.predictors
            self.pre_decoder.current_sen_id = \
                self.main_decoder.current_sen_id - 1
            hypos = self.pre_decoder.decode(src_sentence)
        finally:
            self.accept_subsets = old_accept_subsets
            self.accept_duplicates = old_accept_duplicates
        score = INF
        if not hypos:
            logging.warning("No hypothesis found by the pre decoder. "
                            "Effectively reducing bowsearch predictor to "
                            "bow predictor.")
            self.skeleton = []
        else:
            self.skeleton = list(hypos[0].trgt_sentence)
            score = hypos[0].total_score
        if self.skeleton and self.skeleton[-1] == utils.EOS_ID:
            self.skeleton = self.skeleton[:-1]  # Remove EOS
        self.skeleton_pos = 0
        return score

    def _set_up_full_mode(self):
        """This method initializes ``missing`` by using ``self.skeleton``
        and ``self.full_bag`` and removes duplicates from
        ``self.skeleton``, i.e. occurrences exceeding the word count in
        the bag, as well as words which are not in the bag at all.
        """
        self.bag = dict(self.full_bag)
        missing = dict(self.full_bag)
        skeleton_no_duplicates = []
        for word in self.skeleton:
            if missing.get(word, 0) > 0:
                missing[word] -= 1
                skeleton_no_duplicates.append(word)
        self.skeleton = skeleton_no_duplicates
        self.missing = {w: cnt for w, cnt in missing.items() if cnt > 0}

    def consume(self, word):
        """Calls super class ``consume``. If not in ``pre_mode``,
        update skeleton info.

        Args:
            word (int): Next word to consume
        """
        super(BagOfWordsSearchPredictor, self).consume(word)
        if self.pre_mode:
            return
        if (self.skeleton_pos < len(self.skeleton)
                and word == self.skeleton[self.skeleton_pos]):
            self.skeleton_pos += 1
        elif word in self.missing:
            self.missing[word] -= 1
            if self.missing[word] <= 0:
                del self.missing[word]

    def get_state(self):
        """If in pre_mode, state of this predictor is the current bag.
        Otherwise, it is the tuple ``(bag, skeleton_pos, missing)``.
        """
        if self.pre_mode:
            return super(BagOfWordsSearchPredictor, self).get_state()
        return self.bag, self.skeleton_pos, self.missing

    def set_state(self, state):
        """If in pre_mode, state of this predictor is the current bag.
        Otherwise, it is the tuple ``(bag, skeleton_pos, missing)``.
        """
        if self.pre_mode:
            return super(BagOfWordsSearchPredictor, self).set_state(state)
        self.bag, self.skeleton_pos, self.missing = state

    def is_equal(self, state1, state2):
        """Returns true if the bag and the skeleton states are the same.

        In full mode, this compares the skeleton position exactly and
        the bag and ``missing`` dicts modulo the ``equivalence_vocab``
        mapping to UNK.
        """
        if self.pre_mode:
            return super(BagOfWordsSearchPredictor, self).is_equal(state1,
                                                                   state2)
        bag1, pos1, missing1 = state1
        bag2, pos2, missing2 = state2
        return (pos1 == pos2
                and self._get_unk_bag(bag1) == self._get_unk_bag(bag2)
                and self._get_unk_bag(missing1)
                == self._get_unk_bag(missing2))