# -*- coding: utf-8 -*-
# coding=utf-8
# Copyright 2019 The SGNMT Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of a beam search which uses an FST for synchronization."""
from cam.sgnmt import utils
from cam.sgnmt.decoding.beam import BeamDecoder
from cam.sgnmt.decoding.core import PartialHypothesis
import logging
from cam.sgnmt.utils import load_fst
try:
import pywrapfst as fst
except ImportError:
try:
import openfst_python as fst
except ImportError:
pass # Deal with it in decode.py
class FSTBeamDecoder(BeamDecoder):
    """This beam search variant synchronizes hypotheses if they share
    the same node in an FST. This is similar to the syncbeam strategy,
    but rather than using a dedicated synchronization symbol, we keep
    track of the state ID of each hypothesis in an FST. Hypotheses are
    expanded until all of them arrive at the same state id, and are
    then compared with each other to select the set of active
    hypotheses in the next time step.
    """

    def __init__(self, decoder_args):
        """Creates a new beam decoder instance with FST-based
        synchronization. In addition to the constructor of
        `BeamDecoder`, the following values are fetched from
        `decoder_args`:

            max_word_len (int): Maximum number of expansion rounds
                                spent on a single synchronization step
                                (see ``_expand_hypo``).
            fst_path (string): Path to the FST. It is resolved per
                               sentence in ``_get_initial_hypos``.

        Args:
            decoder_args (object): Decoder configuration object.
        """
        super(FSTBeamDecoder, self).__init__(decoder_args)
        self.fst_path = decoder_args.fst_path
        self.max_word_len = decoder_args.max_word_len

    def _register_sub_score(self, score):
        """Updates ``sub_best_scores`` and ``sub_min_score``, the
        pruning statistics used inside a single synchronization step.

        Does nothing if ``maintain_best_scores`` (expected to be set by
        the parent class) is false.

        Args:
            score (float): Combined score of a newly kept hypothesis.
        """
        if not self.maintain_best_scores:
            return
        self.sub_best_scores.append(score)
        self.sub_best_scores.sort(reverse=True)
        if len(self.sub_best_scores) >= self.beam_size:
            # Keep only the beam_size best scores; the worst of them
            # becomes the new pruning threshold.
            self.sub_best_scores = self.sub_best_scores[:self.beam_size]
            self.sub_min_score = self.sub_best_scores[-1]

    def _get_label2node(self, root_node):
        """Maps the output label of each arc leaving ``root_node`` to
        the state the arc leads to.

        If several outgoing arcs share an output label, the arc
        visited last wins.

        Args:
            root_node (int): State ID in the current FST.

        Returns:
            dict. Output label -> next state ID.
        """
        return {arc.olabel: arc.nextstate
                for arc in self.cur_fst.arcs(root_node)}

    def _find_start_node(self):
        """Finds the state reached from the FST start state via the
        arc with output label ``utils.GO_ID``.

        Returns:
            int. ID of the state following the start symbol, or
            ``None`` (after logging an error) if the start state has
            no such arc.
        """
        for arc in self.cur_fst.arcs(self.cur_fst.start()):
            if arc.olabel == utils.GO_ID:
                return arc.nextstate
        logging.error("Start symbol %d not found in fstbeam FST!",
                      utils.GO_ID)
        return None

    def _get_initial_hypos(self):
        """Get the list of initial ``PartialHypothesis``.

        Loads the FST for the current sentence (the path is resolved
        with ``current_sen_id + 1``) and attaches the FST state after
        the start symbol to the single initial hypothesis.

        Returns:
            list. List with one ``PartialHypothesis``.
        """
        self.cur_fst = load_fst(utils.get_path(self.fst_path,
                                               self.current_sen_id+1))
        init_hypo = PartialHypothesis(self.get_predictor_states())
        init_hypo.fst_node = self._find_start_node()
        return [init_hypo]

    def _expand_hypo(self, hypo):
        """Expands ``hypo`` until its continuations are synchronized
        on a common FST state.

        ``hypo`` is first expanded once with the parent class. The
        largest FST state ID among these direct continuations is the
        synchronization point ("deepest node"). Continuations already
        at that state are closed. All others stay open and are
        expanded further, for at most ``max_word_len`` rounds, until
        they reach the deepest node. Continuations which skip past the
        deepest node, and hypotheses which are still open when the
        round limit is hit, are dropped.

        Args:
            hypo (PartialHypothesis): Hypothesis to expand

        Return:
            list. List of expanded hypotheses which are all at the
            same FST state.
        """
        # Get initial expansions
        l2n = self._get_label2node(hypo.fst_node)
        deepest_node = -1
        next_hypos = []
        next_scores = []
        for next_hypo in super(FSTBeamDecoder, self)._expand_hypo(hypo):
            node_id = l2n[next_hypo.trgt_sentence[-1]]
            deepest_node = max(node_id, deepest_node)
            next_hypo.fst_node = node_id
            next_hypos.append(next_hypo)
            next_scores.append(self._get_combined_score(next_hypo))
        # Expand until all hypos are at deepest_node.
        # This assumes that the FST is topologically sorted
        open_hypos = []
        open_hypos_scores = []
        closed_hypos = []
        for next_hypo, next_score in zip(next_hypos, next_scores):
            if next_hypo.fst_node == deepest_node:
                closed_hypos.append(next_hypo)
            else:
                open_hypos.append(next_hypo)
                open_hypos_scores.append(next_score)
        # _get_next_hypos is inherited; presumably it selects the best
        # hypotheses for the next round (verify against BeamDecoder).
        open_hypos = self._get_next_hypos(open_hypos, open_hypos_scores)
        it = 1
        while open_hypos:
            if it > self.max_word_len:  # prevent infinite loops
                logging.debug("Maximum word length reached.")
                break
            it = it + 1
            next_hypos = []
            next_scores = []
            # Pruning statistics are restarted in every round.
            self.sub_min_score = self.min_score
            self.sub_best_scores = []
            for h in open_hypos:
                if not h.score > self.sub_min_score:
                    continue
                l2n = self._get_label2node(h.fst_node)
                for next_hypo in super(FSTBeamDecoder,
                                       self)._expand_hypo(h):
                    next_score = self._get_combined_score(next_hypo)
                    if not next_score > self.sub_min_score:
                        continue
                    next_hypo.fst_node = l2n[next_hypo.trgt_sentence[-1]]
                    if next_hypo.fst_node < deepest_node:  # Keep
                        next_hypos.append(next_hypo)
                        next_scores.append(next_score)
                        self._register_sub_score(next_score)
                    elif next_hypo.fst_node == deepest_node:  # Add to closed
                        closed_hypos.append(next_hypo)
                    elif next_hypo.fst_node > deepest_node:  # Log
                        logging.debug("FSTBeam: Deepest node exceeded")
            open_hypos = self._get_next_hypos(next_hypos, next_scores)
        # Lazy logging arguments: nothing is formatted unless debug
        # output is actually enabled.
        logging.debug("Expand %f: %s (%d)",
                      hypo.score,
                      hypo.trgt_sentence,
                      hypo.fst_node)
        for h in closed_hypos:
            logging.debug("-> %f: %s (%d)",
                          h.score,
                          h.trgt_sentence,
                          h.fst_node)
        return closed_hypos