# -*- coding: utf-8 -*-
# Copyright 2019 The SGNMT Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of beam search which applies combination_sheme at
each time step.
"""
import copy
import logging

import numpy as np

from cam.sgnmt import utils
from cam.sgnmt.decoding import combination
from cam.sgnmt.decoding.beam import BeamDecoder
from cam.sgnmt.decoding.core import PartialHypothesis
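
# Minimal usage sketch (hypothetical ``args`` namespace; SGNMT normally
# builds ``decoder_args`` from its command-line flags):
#
#     args.decoder = 'combibeam'
#     args.combination_scheme = 'bayesian'
#     decoder = CombiBeamDecoder(args)
#     hypos = decoder.decode(src_sentence)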


class CombiStatePartialHypo(PartialHypothesis):
"""Identical to PartialHypothesis, but tracks the
last-score-but-one for score combination
"""
def __init__(self, initial_states=None):
super(CombiStatePartialHypo, self).__init__(initial_states)
self.score_minus_last = 0 # score not counting last step

    def _new_partial_hypo(self, states, word, score, score_breakdown):
new_hypo = CombiStatePartialHypo(states)
new_hypo.score_minus_last = self.score
new_hypo.score = self.score + score
new_hypo.score_breakdown = copy.copy(self.score_breakdown)
new_hypo.trgt_sentence = self.trgt_sentence + [word]
new_hypo.score_breakdown.append(score_breakdown)
return new_hypo


class CombiBeamDecoder(BeamDecoder):
"""This beam search implementation is a modification to the hypo
expansion strategy. Rather than selecting hypotheses based on
the sum of the previous hypo scores and the current one, we
apply combination_scheme in each time step. This makes it possible
to use schemes like Bayesian combination on the word rather than
the full sentence level.
"""
def __init__(self, decoder_args):
"""Creates a new beam decoder instance. In addition to the
constructor of `BeamDecoder`, the following values are fetched
from `decoder_args`:
combination_scheme (string): breakdown2score strategy
"""
super(CombiBeamDecoder, self).__init__(decoder_args)
        # Extra keyword arguments passed to the breakdown2score function
self.breakdown2score_kwargs = {}
        if decoder_args.combination_scheme == 'length_norm':
            self.breakdown2score = combination.breakdown2score_length_norm
        elif decoder_args.combination_scheme == 'bayesian_loglin':
            self.breakdown2score = combination.breakdown2score_bayesian_loglin
        elif decoder_args.combination_scheme == 'bayesian_state_dependent':
            self.breakdown2score_kwargs['lambdas'] = self.get_domain_task_weights(
                decoder_args.bayesian_domain_task_weights)
            self.breakdown2score = \
                combination.breakdown2score_bayesian_state_dependent
        elif decoder_args.combination_scheme == 'bayesian':
            self.breakdown2score = combination.breakdown2score_bayesian
        elif decoder_args.combination_scheme == 'sum':
            self.breakdown2score = combination.breakdown2score_sum
        if decoder_args.combination_scheme in ['sum', 'length_norm']:
            logging.warning("Using the %s combination strategy has no effect "
                            "under the combibeam decoder."
                            % decoder_args.combination_scheme)
        else:
            # Placeholder; set per hypothesis in _expand_hypo
            self.breakdown2score_kwargs['prev_score'] = None
        self.maintain_best_scores = False

    @staticmethod
    def get_domain_task_weights(w):
"""Get array of domain-task weights from string w
Returns None if w is None or contains non-square number
of weights (currently invalid)
or 2D numpy float array of weights otherwise
"""
if w is None:
logging.critical(
'Need bayesian_domain_task_weights for state-dependent BI')
else:
domain_weights = utils.split_comma(w, float)
num_domains = int(len(domain_weights) ** 0.5)
if len(domain_weights) == num_domains ** 2:
weights_array = np.reshape(domain_weights,
(num_domains, num_domains))
logging.info('Using {} for Bayesian Interpolation'.format(
weights_array))
return weights_array
else:
logging.critical(
'Need square number of domain-task weights, have {}'.format(
len(domain_weights)))

    def _get_initial_hypos(self):
        """Get the list containing the initial ``CombiStatePartialHypo``."""
return [CombiStatePartialHypo(self.get_predictor_states())]

    def _expand_hypo(self, hypo):
        """Get the best ``beam_size`` expansions of ``hypo``.

        Args:
            hypo (PartialHypothesis): Hypothesis to expand

        Returns:
            list. List of child hypotheses
        """
self.set_predictor_states(copy.deepcopy(hypo.predictor_states))
        if hypo.word_to_consume is not None:  # Consume if cheap expand
self.consume(hypo.word_to_consume)
hypo.word_to_consume = None
posterior, score_breakdown = self.apply_predictors()
hypo.predictor_states = self.get_predictor_states()
expanded_hypos = [hypo.cheap_expand(w, s, score_breakdown[w])
for w, s in utils.common_iterable(posterior)]
for expanded_hypo in expanded_hypos:
if 'prev_score' in self.breakdown2score_kwargs:
self.breakdown2score_kwargs['prev_score'] = expanded_hypo.score_minus_last
expanded_hypo.score = self.breakdown2score(
expanded_hypo.score,
expanded_hypo.score_breakdown,
**self.breakdown2score_kwargs)
expanded_hypos.sort(key=lambda x: -x.score)
return expanded_hypos[:self.beam_size]