Source code for cam.sgnmt.decoding.syncbeam

# -*- coding: utf-8 -*-
# coding=utf-8
# Copyright 2019 The SGNMT Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of beam search with explicit synchronization symbol"""

from cam.sgnmt import utils
from cam.sgnmt.decoding.beam import BeamDecoder
import logging


[docs]class SyncBeamDecoder(BeamDecoder): """This beam search implementation is a two level approach. Hypotheses are not compared after each iteration, but after consuming an explicit synchronization symbol. This is useful when SGNMT runs on the character level, but it makes more sense to compare hypos with same lengths in terms of number of words and not characters. The end-of-word symbol </w> can be used as synchronization symbol. """ def __init__(self, decoder_args): """Creates a new beam decoder instance with explicit synchronization symbol. In addition to the constructor of `BeamDecoder`, the following values are fetched from `decoder_args`: sync_symb (int): Synchronization symbol. If negative, consider a hypothesis as closed when it ends with a terminal symbol (see syntax_min_terminal_id and syntax_max_terminal_id) max_word_len (int): Maximum length of a single word """ super(SyncBeamDecoder, self).__init__(decoder_args) if decoder_args.sync_symbol >= 0: self.sync_min = decoder_args.sync_symbol self.sync_max = decoder_args.sync_symbol else: self.sync_min = decoder_args.syntax_min_terminal_id self.sync_max = decoder_args.syntax_max_terminal_id self.max_word_len = decoder_args.max_word_len def _is_closed(self, hypo): """Returns true if hypo ends with </S> or </w>""" w = hypo.get_last_word() return w == utils.EOS_ID or (w >= self.sync_min and w <= self.sync_max) def _all_eos_or_eow(self, hypos): """Returns true if the all hypotheses end with </S> or </W>""" for hypo in hypos: if not self._is_closed(hypo): return True return False def _register_sub_score(self, score): """Updates best_scores and min_score. """ if not self.maintain_best_scores: return self.sub_best_scores.append(score) self.sub_best_scores.sort(reverse=True) if len(self.sub_best_scores) >= self.beam_size: self.sub_best_scores = self.sub_best_scores[:self.beam_size] self.sub_min_score = self.sub_best_scores[-1] def _expand_hypo(self, hypo): """Expand hypo until all of the beam size best hypotheses end with ``sync_symb`` or EOS. Args: hypo (PartialHypothesis): Hypothesis to expand Return: list. List of expanded hypotheses. """ # Get initial expansions next_hypos = [] next_scores = [] for next_hypo in super(SyncBeamDecoder, self)._expand_hypo(hypo): next_hypos.append(next_hypo) next_scores.append(self._get_combined_score(next_hypo)) hypos = self._get_next_hypos(next_hypos, next_scores) # Expand until all hypos are closed it = 1 while self._all_eos_or_eow(hypos): if it > self.max_word_len: # prevent infinite loops logging.debug("Maximum word length reached.") return [] it = it + 1 next_hypos = [] next_scores = [] self.sub_min_score = self.min_score self.sub_best_scores = [] for h in hypos: if self._is_closed(h): next_hypos.append(h) next_scores.append(self._get_combined_score(h)) continue if h.score > self.sub_min_score: for next_hypo in super(SyncBeamDecoder, self)._expand_hypo(h): next_score = self._get_combined_score(next_hypo) if next_score > self.sub_min_score: next_hypos.append(next_hypo) next_scores.append(next_score) self._register_sub_score(next_score) hypos = self._get_next_hypos(next_hypos, next_scores) ret = [h for h in hypos if self._is_closed(h)] logging.debug("Expand %f: %s (%d)" % (hypo.score, hypo.trgt_sentence, len(hypo.trgt_sentence))) for h in ret: logging.debug("-> %f: %s (%d)" % (h.score, h.trgt_sentence, len(h.trgt_sentence))) return ret