Source code for cam.sgnmt.decoding.lenbeam

# -*- coding: utf-8 -*-
# coding=utf-8
# Copyright 2019 The SGNMT Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of the lenbeam (length beam) search strategy."""

import copy
from cam.sgnmt import utils
from cam.sgnmt.decoding.core import Decoder, PartialHypothesis


class LengthBeamDecoder(Decoder):
    """Beam decoder variant that collects hypotheses of every length.

    This beam decoder variant finds hypotheses for all lengths up to the
    maximum hypo length. At each time step, all EOS extensions are added
    to the results set.
    """

    def __init__(self, decoder_args):
        """Creates a new beam decoder instance.

        The following values are fetched from `decoder_args`:

            beam (int): Absolute beam size. A beam of 12 means that we
                        keep track of 12 active hypotheses

        Args:
            decoder_args (object): Decoder configuration. It is also
                                   forwarded unchanged to the parent
                                   ``Decoder`` constructor.
        """
        super(LengthBeamDecoder, self).__init__(decoder_args)
        self.beam_size = decoder_args.beam

    def _expand_hypo(self, hypo):
        """Expands a single partial hypothesis by one time step.

        The predictors are reset to a deep copy of the states stored on
        ``hypo``, the pending word (if any) is consumed, and the
        predictors are applied. The EOS continuation of ``hypo`` is
        always registered as a full hypothesis via ``add_full_hypo``,
        whether or not EOS is among the selected words.

        Args:
            hypo (PartialHypothesis): Hypothesis to expand. It is
                modified in place: ``word_to_consume`` is reset to
                ``None`` after being consumed and ``predictor_states``
                is overwritten with the states after scoring.

        Returns:
            list. One ``cheap_expand`` continuation of ``hypo`` for each
            word selected by ``utils.argmax_n(posterior, self.beam_size)``
            except ``utils.EOS_ID``.
        """
        # Work on a copy so the states stored on the hypothesis are not
        # touched while the predictors advance.
        self.set_predictor_states(copy.deepcopy(hypo.predictor_states))
        if hypo.word_to_consume is not None:  # Consume if cheap expand
            self.consume(hypo.word_to_consume)
            hypo.word_to_consume = None
        posterior, score_breakdown = self.apply_predictors()
        hypo.predictor_states = self.get_predictor_states()
        top = utils.argmax_n(posterior, self.beam_size)
        # EOS hypo: recorded at every step so the result set contains a
        # finished hypothesis for each length reached by the beam.
        eos_hypo = hypo.cheap_expand(utils.EOS_ID,
                                     posterior[utils.EOS_ID],
                                     score_breakdown[utils.EOS_ID])
        self.add_full_hypo(eos_hypo.generate_full_hypothesis())
        # All other hypos
        return [hypo.cheap_expand(trgt_word,
                                  posterior[trgt_word],
                                  score_breakdown[trgt_word])
                for trgt_word in top
                if trgt_word != utils.EOS_ID]

    def decode(self, src_sentence):
        """Decodes a single source sentence using beam search.

        Runs at most ``self.max_len`` expansion steps (``max_len`` is not
        set in this class; presumably the parent ``Decoder`` provides
        it). After each step only the ``beam_size`` highest scoring
        non-EOS continuations are kept as active hypotheses.

        Args:
            src_sentence: Source sentence, passed on to
                          ``initialize_predictors``.

        Returns:
            The value of ``get_full_hypos_sorted()``, i.e. the full
            hypotheses registered during the search.
        """
        self.initialize_predictors(src_sentence)
        hypos = [PartialHypothesis(self.get_predictor_states())]
        for _ in range(self.max_len):
            if not hypos:
                # Every expansion produced only EOS: nothing is left to
                # extend, so the remaining iterations would be no-ops.
                break
            next_hypos = []
            for hypo in hypos:
                next_hypos.extend(self._expand_hypo(hypo))
            # Best score first; the stable sort keeps expansion order
            # for ties.
            next_hypos.sort(key=lambda h: -h.score)
            hypos = next_hypos[:self.beam_size]
        return self.get_full_hypos_sorted()