Source code for cam.sgnmt.predictors.parse

# -*- coding: utf-8 -*-
# Copyright 2019 The SGNMT Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import copy
import logging

import numpy as np

from cam.sgnmt import utils
from cam.sgnmt.predictors.core import Predictor


def load_external_ids(path):
    """Load a file of ids into a set.

    Args:
        path (string): path to a file containing one integer id per line
    """
    logging.info('Loading ids from file {}'.format(path))
    with open(path) as f:
        ids = [int(line.strip()) for line in f]
    return set(ids)

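# A minimal usage sketch (the file name is hypothetical): given a file
# "nonterminals.txt" containing one integer id per line, e.g. "101",
# "102", "103", then
#
#     nonterminal_set = load_external_ids("nonterminals.txt")
#
# would return the set {101, 102, 103}.
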

class InternalHypo(object):
    """Helper class for internal parse predictor beam search over
    non-terminals.
    """

    def __init__(self, score, token_score, predictor_state, word_to_consume):
        self.score = score
        self.predictor_state = predictor_state
        self.word_to_consume = word_to_consume
        self.norm_score = score
        self.token_score = token_score
        self.beam_len = 1

    def extend(self, score, predictor_state, word_to_consume):
        """Extend the hypothesis by one token, accumulating its score."""
        self.score += score
        self.predictor_state = predictor_state
        self.word_to_consume = word_to_consume
        self.beam_len += 1


class ParsePredictor(Predictor):
    """Predictor wrapper allowing internal beam search over a
    representation which contains some pre-defined 'non-terminal' ids,
    which should not appear in the output.
    """

    def __init__(self, slave_predictor, normalize_scores=True, beam_size=4,
                 max_internal_len=35, nonterminal_ids=None):
        """Create a new parse wrapper for a predictor.

        Args:
            slave_predictor: predictor to wrap with the parse wrapper
            normalize_scores (bool): whether to normalize posterior scores,
                                     e.g. after some tokens have been removed
            beam_size (int): beam size for internal beam search over
                             non-terminals
            max_internal_len (int): number of consecutive non-terminal tokens
                                    allowed in internal search before a path
                                    is ignored
            nonterminal_ids: path to a file containing non-terminal ids,
                             one per line
        """
        super(ParsePredictor, self).__init__()
        self.predictor = slave_predictor
        self.normalize_scores = normalize_scores
        self.beam_size = beam_size
        self.max_internal_len = max_internal_len
        self.nonterminals = load_external_ids(nonterminal_ids)
        self.nonterminals.discard(utils.EOS_ID)
        self.nonterminals.discard(utils.UNK_ID)
        self.tok_to_hypo = {}

    def get_unk_probability(self, posterior):
        """Return the UNK probability as determined by the slave predictor.

        Returns:
            float, UNK probability
        """
        return self.predictor.get_unk_probability(posterior)

    def are_best_terminal(self, posterior):
        """Return True if the most probable tokens in the posterior are
        all terminals (including EOS).
        """
        best_rule_ids = utils.argmax_n(posterior, self.beam_size)
        for tok in best_rule_ids:
            if tok in self.nonterminals:
                return False
        return True

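    # Illustrative sketch (all ids hypothetical): with beam_size=2,
    # nonterminals={7, 8} and posterior {5: -0.1, 7: -0.2, 9: -3.0},
    # the two most probable tokens are 5 and 7; since 7 is a
    # non-terminal, are_best_terminal returns False.
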
    def predict_next(self, predicting_internally=False):
        """Predict next tokens.

        Args:
            predicting_internally: will be True if called from the internal
                                   beam search; prevents an infinite loop
        """
        original_posterior = self.predictor.predict_next()
        all_keys = utils.common_viewkeys(original_posterior)
        scores = {rule_id: original_posterior[rule_id]
                  for rule_id in all_keys}
        scores = self.finalize_posterior(
            scores, use_weights=True, normalize_scores=self.normalize_scores)
        if not predicting_internally:
            scores = self.find_word_beam(scores)
        return scores

    def maybe_add_new_top_tokens(self, top_terminals, hypo, next_hypos):
        """Expand ``hypo`` by one step. Terminal continuations are
        collected in ``top_terminals`` (keeping the best hypothesis per
        token), non-terminal continuations in ``next_hypos``.
        """
        new_post = self.predict_next(predicting_internally=True)
        top_tokens = utils.argmax_n(new_post, self.beam_size)
        next_state = copy.deepcopy(self.predictor.get_state())
        for tok in top_tokens:
            score = hypo.score + new_post[tok]
            new_hypo = InternalHypo(score, new_post[tok], next_state, tok)
            if tok not in self.nonterminals:
                add_hypo = False
                found = False
                for t in top_terminals:
                    if t == tok:
                        found = True
                        if self.tok_to_hypo[tok].score < new_hypo.score:
                            add_hypo = True
                            top_terminals.remove(t)
                        break
                if not found:
                    add_hypo = True
                if add_hypo:
                    top_terminals.append(tok)
                    self.tok_to_hypo[tok] = new_hypo
            else:
                next_hypos.append(new_hypo)

    def initialize_internal_hypos(self, posterior):
        top_tokens = utils.argmax_n(posterior, self.beam_size)
        hypos = []
        top_terminals = []
        for tok in top_tokens:
            new_hypo = InternalHypo(posterior[tok],
                                    posterior[tok],
                                    copy.deepcopy(self.predictor.get_state()),
                                    tok)
            if tok not in self.nonterminals:
                self.tok_to_hypo[tok] = new_hypo
                top_terminals.append(tok)
            hypos.append(new_hypo)
        return hypos, top_terminals

    def find_word_beam(self, posterior):
        """Internal beam search over the posterior until a beam of
        terminals is found.
        """
        hypos, top_terminals = self.initialize_internal_hypos(posterior)
        min_score = utils.NEG_INF
        if top_terminals:
            top_terminals.sort(key=lambda h: -self.tok_to_hypo[h].score)
            min_score = self.tok_to_hypo[top_terminals[-1]].score
        hypos.sort(key=lambda h: -h.score)
        # Search until we have n terminal hypos with path scores better
        # than further internal search can give us.
        while hypos and hypos[0].score > min_score:
            next_hypos = []
            for hypo in hypos:
                if hypo.word_to_consume not in self.nonterminals:
                    continue
                self.predictor.set_state(copy.deepcopy(hypo.predictor_state))
                self.consume(hypo.word_to_consume, internal=True)
                self.maybe_add_new_top_tokens(top_terminals, hypo, next_hypos)
            next_hypos.sort(key=lambda h: -h.score)
            hypos = next_hypos[:self.beam_size]
            top_terminals.sort(key=lambda t: -self.tok_to_hypo[t].score)
            top_terminals = top_terminals[:self.beam_size]
            if top_terminals:
                min_score = self.tok_to_hypo[top_terminals[-1]].score
        token_scores = [self.tok_to_hypo[t].score for t in top_terminals]
        return {t: s for t, s in zip(top_terminals, token_scores)}

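    # Why the loop terminates: min_score is the worst score in the
    # current terminal beam, hypos[0].score is the best unexpanded
    # non-terminal path, and path scores are sums of log probabilities,
    # so extending a path can only lower its score. Once
    # hypos[0].score <= min_score, no expansion can improve the
    # terminal beam and the search stops.
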
[docs] def initialize(self, src_sentence): """Initializes slave predictor with source sentence Args: src_sentence (list) """ self.predictor.initialize(src_sentence)
    def consume(self, word, internal=False):
        """Consume a token. Unless consuming internally, restore the slave
        predictor state associated with the token first.
        """
        try:
            if not internal:
                self.predictor.set_state(
                    copy.deepcopy(self.tok_to_hypo[word].predictor_state))
        except KeyError:
            logging.info('Consuming {}, not in tok-to-hypo'.format(word))
        return self.predictor.consume(word)

    def get_state(self):
        """Return the current state."""
        return self.predictor.get_state(), self.tok_to_hypo

    def set_state(self, state):
        """Set the current state."""
        slave_state, tok_to_hypo = state
        self.tok_to_hypo = tok_to_hypo
        self.predictor.set_state(slave_state)

    def initialize_heuristic(self, src_sentence):
        """Heuristics are not supported by this predictor; this is a
        no-op.
        """
        pass

    def is_equal(self, state1, state2):
        """Return True if the two states are equal."""
        return state1 == state2

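# A minimal usage sketch (the file name and the slave predictor are
# hypothetical; any SGNMT predictor exposing initialize/predict_next/
# consume/get_state/set_state would do):
#
#     parse = ParsePredictor(slave,
#                            beam_size=4,
#                            nonterminal_ids="nonterminals.txt")
#     parse.initialize(src_sentence)
#     posterior = parse.predict_next()   # contains terminals only
#     parse.consume(utils.argmax(posterior))
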

class TokParsePredictor(ParsePredictor):
    """Unlike ParsePredictor, this wrapper predicts tokens according to
    a grammar. Use BpeParsePredictor if including rules to connect BPE
    units inside words.
    """

    def __init__(self, grammar_path, slave_predictor, word_out=True,
                 normalize_scores=True, norm_alpha=1.0, beam_size=1,
                 max_internal_len=35, allow_early_eos=False,
                 consume_out_of_class=False, nonterminal_ids=None):
        """Creates a new parse predictor wrapper.

        Args:
            grammar_path (string): Path to the grammar file
            slave_predictor: predictor to wrap
            word_out (bool): since this wrapper can be used for grammar
                             constraint, this bool determines whether we
                             also do internal beam search over
                             non-terminals
            normalize_scores (bool): true if normalizing scores, e.g. if
                                     some are removed from the posterior
            norm_alpha (float): may be used for path weight normalization
            beam_size (int): beam size for internal beam search
            max_internal_len (int): max number of consecutive non-terminals
                                    before a path is ignored by internal
                                    search
            allow_early_eos (bool): true if permitting EOS to be consumed
                                    even if it is not permitted by the
                                    grammar at that point
            consume_out_of_class (bool): true if permitting any tokens to
                                         be consumed even if not allowed by
                                         the grammar at that point
            nonterminal_ids: path to a file containing non-terminal ids,
                             one per line (passed through to
                             ParsePredictor)
        """
        super(TokParsePredictor, self).__init__(slave_predictor,
                                                normalize_scores,
                                                beam_size,
                                                max_internal_len,
                                                nonterminal_ids)
        self.grammar_path = grammar_path
        self.word_out = word_out
        self.stack = []
        self.norm_alpha = norm_alpha
        self.check_n_best_terminal = False
        self.current_lhs = None
        self.current_rhs = []
        self.allow_early_eos = allow_early_eos
        self.consume_ooc = consume_out_of_class
        self.prepare_grammar()
        self.tok_to_internal_state = {}

    def norm_hypo_score(self, hypo):
        """Update the hypothesis' normalized score in place."""
        hypo.norm_score = self.norm_score(hypo.score, hypo.beam_len)

    def norm_score(self, score, beam_len):
        """Length-normalize a path score with the GNMT-style penalty
        ((5 + beam_len) / 6) ** norm_alpha.
        """
        length_penalty = (5.0 + beam_len) / 6
        if self.norm_alpha != 1.0:
            length_penalty = pow(length_penalty, self.norm_alpha)
        return score / length_penalty

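    # Worked example: with norm_alpha=1.0 and beam_len=7 the penalty is
    # (5.0 + 7) / 6 = 2.0, so a path score of -4.2 normalizes to -2.1.
    # With norm_alpha=0.0 the penalty collapses to 1.0 and scores are
    # left unnormalized.
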
    def prepare_grammar(self):
        self.lhs_to_can_follow = {}
        with open(self.grammar_path) as f:
            for line in f:
                nt, rule = line.split(':')
                nt = int(nt.strip())
                self.lhs_to_can_follow[nt] = set(
                    [int(r) for r in rule.strip().split()])
        self.last_nt_in_rule = {nt: True for nt in self.lhs_to_can_follow}
        for nt, following in self.lhs_to_can_follow.items():
            if 0 in following:
                # A 0 in the follower set marks non-terminals which are
                # not the last in their rule.
                following.remove(0)
                self.last_nt_in_rule[nt] = False
            if self.allow_early_eos and utils.UNK_ID in following:
                self.lhs_to_can_follow[nt].add(utils.EOS_ID)
        self.lhs_to_can_follow[utils.EOS_ID].add(utils.UNK_ID)

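    # A sketch of the expected grammar file format (all ids
    # hypothetical): each line maps a non-terminal id to the
    # whitespace-separated ids that may follow it, e.g.
    #
    #     90: 91 92 0
    #     91: 5 6 7
    #
    # prepare_grammar would record that 91 and 92 may follow 90, strip
    # the 0 marker and set last_nt_in_rule[90] = False.
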
    def initialize(self, src_sentence):
        self.predictor.initialize(src_sentence)
        self.current_lhs = None
        self.current_rhs = []
        self.stack = [utils.EOS_ID]
        self.consume(utils.GO_ID)

    def replace_lhs(self):
        while self.current_rhs:
            self.stack.append(self.current_rhs.pop())
        if self.stack:
            self.current_lhs = self.stack.pop()
        else:
            self.current_lhs = utils.EOS_ID

    def get_current_allowed(self):
        if self.current_lhs:
            return self.lhs_to_can_follow[self.current_lhs]
        return set([utils.GO_ID])

    def predict_next(self, predicting_next_word=False):
        """Predict next tokens as permitted by the current stack and the
        grammar.
        """
        original_posterior = self.predictor.predict_next()
        outgoing_rules = self.lhs_to_can_follow[self.current_lhs]
        scores = {rule_id: original_posterior[rule_id]
                  for rule_id in outgoing_rules}
        scores = self.finalize_posterior(
            scores, use_weights=True, normalize_scores=self.normalize_scores)
        if self.word_out and not predicting_next_word:
            scores = self.find_word(scores)
        return scores

    def find_word_greedy(self, posterior):
        while not self.are_best_terminal(posterior):
            best_rule_id = utils.argmax(posterior)
            self.consume(best_rule_id)
            posterior = self.predict_next(predicting_next_word=True)
        return posterior

    def find_word_beam(self, posterior):
        """Do an internal beam search over non-terminal functions to find
        the next best n terminal tokens, as ranked by normalized path
        score.

        Returns:
            posterior containing up to n terminal tokens and their
            normalized path scores
        """
        top_tokens = utils.argmax_n(posterior, self.beam_size)
        hypos = [InternalHypo(posterior[tok], posterior[tok],
                              self.get_state(), tok)
                 for tok in top_tokens if tok in self.nonterminals]
        best_hypo = InternalHypo(utils.NEG_INF, utils.NEG_INF, None, None)
        best_posterior = None
        while hypos and hypos[0].norm_score > best_hypo.norm_score:
            next_hypos = []
            for hypo in hypos:
                self.set_state(copy.deepcopy(hypo.predictor_state))
                self.consume(hypo.word_to_consume)
                new_post = self.predict_next(predicting_next_word=True)
                top_tokens = utils.argmax_n(new_post, self.beam_size)
                next_state = copy.deepcopy(self.get_state())
                new_norm_score = self.norm_score(
                    new_post[top_tokens[0]] + hypo.score, hypo.beam_len + 1)
                if (self.are_best_terminal(new_post) and
                        new_norm_score > best_hypo.norm_score):
                    best_hypo = copy.deepcopy(hypo)
                    best_hypo.predictor_state = next_state
                    best_hypo.norm_score = new_norm_score
                    best_posterior = new_post
                else:
                    if hypo.beam_len == self.max_internal_len:
                        logging.info('cutting off internal hypo - too long')
                        continue
                    for tok in top_tokens:
                        if tok in self.nonterminals:
                            new_hypo = copy.deepcopy(hypo)
                            new_hypo.extend(new_post[tok], next_state, tok)
                            next_hypos.append(new_hypo)
            for next_hypo in next_hypos:
                self.norm_hypo_score(next_hypo)
            next_hypos.sort(key=lambda h: -h.norm_score)
            hypos = next_hypos[:self.beam_size]
        self.set_state(best_hypo.predictor_state)
        for tok in list(best_posterior.keys()):
            if tok in self.nonterminals:
                del best_posterior[tok]
            else:
                best_posterior[tok] = self.norm_score(
                    best_hypo.score + best_posterior[tok],
                    best_hypo.beam_len + 1)
        return best_posterior

    def find_word(self, posterior):
        """Check whether the best option in the posterior is a terminal.
        If it is, return the posterior for decoding; if not, follow the
        best path through non-terminals until a word is found, using
        either greedy 1-best search or beam search.
        """
        if self.beam_size <= 1:
            return self.find_word_greedy(posterior)
        if self.are_best_terminal(posterior):
            return posterior
        return self.find_word_beam(posterior)

    def consume(self, word):
        """Consume a token, replacing it with UNK if the grammar does not
        allow it at this point.

        Args:
            word (int): word token being consumed
        """
        change_to_unk = (
            word == utils.UNK_ID or
            (not self.consume_ooc and
             word not in self.get_current_allowed()))
        if change_to_unk:
            word = utils.UNK_ID
        self.update_stacks(word)
        return self.predictor.consume(word)

    def update_stacks(self, word):
        if word in self.nonterminals:
            self.current_rhs.append(word)
            if self.last_nt_in_rule[word]:
                self.replace_lhs()
        else:
            self.replace_lhs()

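    # Illustrative trace (ids hypothetical): with nonterminals={90, 91}
    # and last_nt_in_rule={90: False, 91: True}, consuming 90 only
    # buffers it in current_rhs; consuming 91 buffers it and, since 91
    # closes its rule, replace_lhs() pushes the buffered rhs onto the
    # stack and pops the next symbol as current_lhs. Consuming a
    # terminal just triggers replace_lhs() to move on.
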
    def get_state(self):
        """Return the current state, including the slave predictor
        state.
        """
        return (self.stack, self.current_lhs, self.current_rhs,
                self.predictor.get_state())

    def set_state(self, state):
        """Set the current state."""
        self.stack, self.current_lhs, self.current_rhs, slave_state = state
        self.predictor.set_state(slave_state)

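# A minimal usage sketch (file names and the slave predictor are
# hypothetical):
#
#     tok_parse = TokParsePredictor("grammar.txt", slave,
#                                   beam_size=1,  # greedy find_word
#                                   nonterminal_ids="nonterminals.txt")
#     tok_parse.initialize(src_sentence)
#     posterior = tok_parse.predict_next()  # grammar-constrained
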

class BpeParsePredictor(TokParsePredictor):
    """Predict over a BPE-based grammar with two possible grammar
    constraints: one between non-terminals and BPE start-of-word tokens,
    and one over the BPE tokens inside a word.
    """

    def __init__(self, grammar_path, bpe_rule_path, slave_predictor,
                 word_out=True, normalize_scores=True, norm_alpha=1.0,
                 beam_size=1, max_internal_len=35, allow_early_eos=False,
                 consume_out_of_class=False, eow_ids=None,
                 terminal_restrict=True, terminal_ids=None,
                 internal_only_restrict=False, nonterminal_ids=None):
        """Creates a new parse predictor wrapper which can be constrained
        to two grammars: one over non-terminals / terminals, and one used
        internally to constrain BPE units within a single word.

        Args:
            grammar_path (string): Path to the grammar file
            bpe_rule_path (string): Path to a file defining rules between
                                    BPEs
            slave_predictor: predictor to wrap
            word_out (bool): since this wrapper can be used for grammar
                             constraint, this bool determines whether we
                             also do internal beam search over
                             non-terminals
            normalize_scores (bool): true if normalizing scores, e.g. if
                                     some are removed from the posterior
            norm_alpha (float): may be used for path weight normalization
            beam_size (int): beam size for internal beam search
            max_internal_len (int): max number of consecutive non-terminals
                                    before a path is ignored by internal
                                    search
            allow_early_eos (bool): true if permitting EOS to be consumed
                                    even if it is not permitted by the
                                    grammar at that point
            consume_out_of_class (bool): true if permitting any tokens to
                                         be consumed even if not allowed by
                                         the grammar at that point
            eow_ids (string): path to a file containing the ids of BPEs
                              that mark the end of a word
            terminal_restrict (bool): true if applying the grammar
                                      constraint over non-terminals and
                                      terminals
            terminal_ids (string): path to a file containing all terminal
                                   ids
            internal_only_restrict (bool): true if applying the grammar
                                           constraint over BPE units inside
                                           words
            nonterminal_ids: path to a file containing non-terminal ids,
                             one per line
        """
        super(BpeParsePredictor, self).__init__(grammar_path,
                                                slave_predictor,
                                                word_out,
                                                normalize_scores,
                                                norm_alpha,
                                                beam_size,
                                                max_internal_len,
                                                allow_early_eos,
                                                consume_out_of_class,
                                                nonterminal_ids)
        self.internal_only_restrict = internal_only_restrict
        self.terminal_restrict = terminal_restrict
        self.eow_ids = self.get_eow_ids(eow_ids)
        self.all_terminals = self.get_all_terminals(terminal_ids)
        self.get_bpe_can_follow(bpe_rule_path)

    def get_eow_ids(self, eow_ids):
        eows = set()
        if eow_ids:
            with open(eow_ids) as f:
                for line in f:
                    eows.add(int(line.strip()))
        return eows

    def get_all_terminals(self, terminal_ids):
        all_terminals = set([utils.EOS_ID])
        if terminal_ids:
            with open(terminal_ids) as f:
                for line in f:
                    all_terminals.add(int(line.strip()))
            if not self.terminal_restrict:
                for terminal in all_terminals:
                    self.lhs_to_can_follow[terminal] = all_terminals
        return all_terminals

    def get_bpe_can_follow(self, rule_path):
        with open(rule_path) as f:
            for line in f:
                nt, following = line.split(' : ')
                nt_tuple = tuple(map(int, nt.split(',')))
                following = set([int(r) for r in following.strip().split()])
                self.lhs_to_can_follow[nt_tuple] = following

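    # A sketch of the expected BPE rule file format (ids hypothetical):
    # each line maps a comma-separated tuple of ids (a word-internal
    # lhs) to the BPE ids that may follow it, separated by " : ", e.g.
    #
    #     91,200 : 201 202
    #     91,200,201 : 203
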
    def update_stacks(self, word):
        if word in self.nonterminals:
            self.current_rhs.append(word)
            if self.last_nt_in_rule[word]:
                self.replace_lhs()
        else:
            if self.terminal_restrict:
                try:
                    internal_lhs = self.current_lhs + (word,)
                except TypeError:
                    internal_lhs = (self.current_lhs, word)
            else:
                internal_lhs = word
            if not self.terminal_restrict and word in self.eow_ids:
                self.replace_lhs()
            elif (not self.terminal_restrict or
                  internal_lhs in self.lhs_to_can_follow):
                self.current_lhs = internal_lhs
            else:
                self.replace_lhs()

    def is_nt(self, word):
        """Return True if the word is not a terminal."""
        return word not in self.all_terminals

    def predict_next(self, predicting_next_word=False):
        """Predict next tokens as permitted by the current stack and the
        BPE grammar.
        """
        original_posterior = self.predictor.predict_next()
        outgoing_rules = self.lhs_to_can_follow[self.current_lhs]
        scores = {rule_id: original_posterior[rule_id]
                  for rule_id in outgoing_rules}
        if self.internal_only_restrict and self.are_best_terminal(scores):
            outgoing_rules = self.all_terminals
            scores = {rule_id: original_posterior[rule_id]
                      for rule_id in outgoing_rules}
        scores = self.finalize_posterior(
            scores, use_weights=True, normalize_scores=self.normalize_scores)
        if self.word_out and not predicting_next_word:
            scores = self.find_word(scores)
        return scores

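# A minimal usage sketch (file names and the slave predictor are
# hypothetical):
#
#     bpe_parse = BpeParsePredictor("grammar.txt", "bpe_rules.txt", slave,
#                                   nonterminal_ids="nonterminals.txt",
#                                   eow_ids="eow_ids.txt",
#                                   terminal_ids="terminals.txt")
#     bpe_parse.initialize(src_sentence)
#     posterior = bpe_parse.predict_next()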