Tutorial: Adding new components

We welcome contributions to SGNMT. This tutorial will help you get started with writing new components. This tutorial assumes that you have read the Tutorial: Basics (T2T) page and are already familiar with the basic concepts of predictors and decoders. You should also be comfortable with coding in Python.

Writing decoders

This section explains how to write decoders (search strategies) that are decoupled from predictors, i.e. that work with any predictor constellation. This tutorial walks you through implementing a new simplebeam decoder, which will be a simplified version of the standard beam search implementation in SGNMT.

The first step is to make SGNMT aware of the new decoder. Find the create_decoder() factory method in cam/sgnmt/decode_utils.py. This method instantiates decoders based on the --decoder argument. Add the following code to the if-structure:

elif args.decoder == "simplebeam":
    decoder = SimpleBeamDecoder(args)

Import SimpleBeamDecoder at the beginning of decode_utils.py:

from cam.sgnmt.decoding.simplebeam import SimpleBeamDecoder

Finally, you need to update the specification of --decoder in the command-line argument parser. Open cam/sgnmt/ui.py and find the following snippet in the get_parser() method:

group.add_argument("--decoder", default="beam",

Add simplebeam to the list of choices and save the file.
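Stand-alone, the edited argument specification might look roughly like the sketch below. Note that the exact set of existing choices depends on your SGNMT version, and that in the real ui.py the call is made on an argument group inside get_parser(); the list here is illustrative only.

```python
import argparse

# Stand-alone sketch; in SGNMT this call lives inside get_parser() in
# cam/sgnmt/ui.py on an argument group. The choices below are illustrative;
# your SGNMT version lists many more decoders.
parser = argparse.ArgumentParser()
parser.add_argument("--decoder", default="beam",
                    choices=["greedy", "beam", "dfs", "simplebeam"],
                    help="Search strategy used for decoding.")

args = parser.parse_args(["--decoder", "simplebeam"])
print(args.decoder)  # simplebeam
```

With simplebeam added to choices, argparse accepts the new value and rejects unknown decoder names automatically.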

Create a new file cam/sgnmt/decoding/simplebeam.py and open it in your favorite text editor. This is where we will implement our new decoder. Copy the following Decoder stub implementation to simplebeam.py:

import copy
from cam.sgnmt import utils
from cam.sgnmt.decoding.core import Decoder, PartialHypothesis

class SimpleBeamDecoder(Decoder):

    def decode(self, src_sentence):
        print("SimpleBeam source sentence: %s" % src_sentence)

This implementation does not do much yet. In fact, it will result in an error since decode() is expected to return translation hypotheses. However, you can still test if your code is reached by activating the simplebeam decoder in the Tutorial: Basics (T2T) environment:

$ python $SGNMT/decode.py --decoder simplebeam --predictors nmt --config_file tut.ini
SimpleBeam source sentence: [1543, 7, 1491, 1359, 1532, 692, 9, 6173]
2017-10-14 12:57:03,171 ERROR: An unexpected <type 'exceptions.TypeError'> error has occurred at sentence id 1: 'NoneType' object is not iterable, Stack trace: Traceback (most recent call last):
  File "/home/mifs/fs439/bin/sgnmt-tut/cam/sgnmt/decode_utils.py", line 651, in do_decode
    for hypo in decoder.decode(utils.apply_src_wmap(src))
TypeError: 'NoneType' object is not iterable

As expected, SGNMT crashes because of the lack of a return value, but the print statement in your decode() method is executed. Next we would like to access certain command-line arguments, for example the --beam option for configuring the beam size. All options can be accessed via decoder_args in the constructor. Add the following constructor to SimpleBeamDecoder:

def __init__(self, decoder_args):
    super(SimpleBeamDecoder, self).__init__(decoder_args)
    self.beam_size = decoder_args.beam

We will now focus on implementing the decode() method in our SimpleBeamDecoder. The general idea is to maintain a list of beam_size (partial) hypotheses hypos, sorted from best to worst. We initialize hypos with the empty translation prefix and the initial predictor states. We keep doing beam search iterations until the best hypothesis in hypos ends with the end-of-sentence symbol:

def decode(self, src_sentence):
    self.initialize_predictors(src_sentence)
    hypos = [PartialHypothesis(self.get_predictor_states())]
    while hypos[0].get_last_word() != utils.EOS_ID:
        next_hypos = []
        for hypo in hypos:
            next_hypos.extend(self._expand_hypo(hypo))
        next_hypos.sort(key=lambda h: -h.score)
        hypos = next_hypos[:self.beam_size]
    return [hypos[0].generate_full_hypothesis()]

The core of the algorithm is the _expand_hypo() method which still needs to be implemented. Expanding a partial hypothesis hypo involves the following steps:

  • Load the predictor states from hypo.
  • Consume the last word in hypo if necessary.
  • Call apply_predictors() to get the scores for the next time step.
  • Expand hypo with the best beam_size scores and the updated predictor states.

def _expand_hypo(self, hypo):
    if hypo.get_last_word() == utils.EOS_ID:
        return [hypo]
    self.set_predictor_states(copy.deepcopy(hypo.predictor_states))
    if hypo.word_to_consume is not None: # Consume if cheap expand
        self.consume(hypo.word_to_consume)
        hypo.word_to_consume = None
    posterior, score_breakdown = self.apply_predictors()
    hypo.predictor_states = self.get_predictor_states()
    top = utils.argmax_n(posterior, self.beam_size)
    return [hypo.cheap_expand(
                        trgt_word,
                        posterior[trgt_word],
                        score_breakdown[trgt_word]) for trgt_word in top]

Verify that your new decoder works as expected:

$ python $SGNMT/decode.py --decoder simplebeam --predictors nmt --config_file tut.ini
2017-10-14 13:40:00,782 INFO: Decoded (ID: 1): 1511 7 1422 894 30 8 10453
2017-10-14 13:40:00,782 INFO: Stats (ID: 1): score=-3.700894 num_expansions=85 time=28.04

You should obtain the same translation and the same score as using the standard beam search implementation with --decoder beam.
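The loop structure of the decoder does not depend on SGNMT internals. As a sanity check, here is a self-contained sketch of the same algorithm run against a toy scoring function; all names and scores here are made up for illustration, with token 0 standing in for utils.EOS_ID.

```python
import math

def toy_scores(prefix):
    """Hypothetical per-step log-probabilities over a vocabulary {0, 1, 2},
    where 0 plays the role of the end-of-sentence symbol. Prefixes of
    length 3 strongly prefer ending the sentence, so the search terminates."""
    if len(prefix) >= 3:
        return {0: 0.0, 1: -10.0, 2: -10.0}
    return {0: -5.0, 1: math.log(0.6), 2: math.log(0.4)}

def simple_beam(beam_size=2):
    hypos = [([], 0.0)]  # (prefix, score) pairs, sorted from best to worst
    while hypos[0][0][-1:] != [0]:  # until the best hypo ends with EOS
        next_hypos = []
        for prefix, score in hypos:
            if prefix[-1:] == [0]:  # finished hypos survive unchanged
                next_hypos.append((prefix, score))
                continue
            for word, word_score in toy_scores(prefix).items():
                next_hypos.append((prefix + [word], score + word_score))
        next_hypos.sort(key=lambda h: -h[1])
        hypos = next_hypos[:beam_size]
    return hypos[0]

best_prefix, best_score = simple_beam()
print(best_prefix)  # [1, 1, 1, 0]
```

The search greedily keeps the beam_size highest-scoring prefixes per step and stops as soon as the best hypothesis is finished, exactly as in SimpleBeamDecoder.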

Writing predictors

The first decision to make when implementing a new predictor is about the structure of its state. Consider the following two code examples:

state = predictor.get_state()
arbitrarySequenceOfPredictNextAndConsumeCalls()
predictor.set_state(state)
predictor.predict_next()

state = predictor.get_state()
predictor.set_state(state)
predictor.predict_next()

The predictor state must contain enough information such that the last lines in the above examples always exhibit the same behavior, independently of how many (if any) function calls are made in arbitrarySequenceOfPredictNextAndConsumeCalls(). The next central questions are:
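This contract can be illustrated with a toy predictor whose state is simply a counter of consumed tokens. The class below is a sketch, not SGNMT's actual Predictor base class, and the threshold of 3 tokens is an arbitrary assumption:

```python
class CountingPredictor:
    """Toy predictor: the state is the number of consumed tokens."""
    def __init__(self):
        self.n_consumed = 0
    def predict_next(self):
        # The posterior depends only on the state, as the contract requires.
        return {"eos": 0.0 if self.n_consumed >= 3 else float("-inf")}
    def consume(self, word):
        self.n_consumed += 1
    def get_state(self):
        return self.n_consumed
    def set_state(self, state):
        self.n_consumed = state

predictor = CountingPredictor()
state = predictor.get_state()
before = predictor.predict_next()
# An arbitrary sequence of predict_next() and consume() calls...
predictor.consume("a")
predictor.predict_next()
predictor.consume("b")
# ...must not affect the result once the saved state is restored:
predictor.set_state(state)
assert predictor.predict_next() == before
```

Because everything predict_next() looks at is captured by get_state() and restored by set_state(), the final assertion holds no matter what happens in between.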

  • How to initialize the predictor state given the source sentence (initialize())?
  • How to compute the posterior for the next target position given the state (predict_next())?
  • How to update the predictor state when consuming a target token (consume())?

Let’s say we want to write a predictor which enforces that the target sentence length is always equal to the source sentence length. Open cam/sgnmt/decode_utils.py, find the add_predictors() method and add the following piece of code to the main if-structure:

elif pred == "eqlen":
    p = EqualLengthPredictor()

Import the new EqualLengthPredictor at the beginning of decode_utils.py:

from cam.sgnmt.predictors.equal_length import EqualLengthPredictor

We will implement the new predictor in a new file called cam/sgnmt/predictors/equal_length.py. Here is the complete implementation:

from cam.sgnmt.predictors.core import Predictor
from cam.sgnmt import utils

class EqualLengthPredictor(Predictor):

    def get_unk_probability(self, posterior):
        if self.n_consumed == self.src_length:
            return utils.NEG_INF
        return 0.0

    def predict_next(self):
        if self.n_consumed == self.src_length:
            return {utils.EOS_ID : 0.0}
        return {utils.EOS_ID : utils.NEG_INF}

    def initialize(self, src_sentence):
        self.src_length = len(src_sentence)
        self.n_consumed = 0

    def consume(self, word):
        self.n_consumed += 1

    def get_state(self):
        return self.n_consumed

    def set_state(self, state):
        self.n_consumed = state

The state of the predictor is the number of target tokens consumed so far (self.n_consumed). predict_next() sets the score of the end-of-sentence symbol to 0.0 (the optimal score) if n_consumed matches the required length self.src_length, and to negative infinity otherwise, which prevents the hypothesis from ending too early. Scores for all tokens which are not in the dictionary returned by predict_next() are defined via the get_unk_probability() method: we assign negative infinity to all tokens besides end-of-sentence once the hypothesis has reached the required length, and an optimal score of 0.0 before that.
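To see this scoring pattern in isolation, the sketch below replays the predictor's logic on a hypothetical 3-token source sentence, using stand-in NEG_INF and EOS_ID constants instead of SGNMT's utils module (the token ids are made up):

```python
NEG_INF = float("-inf")
EOS_ID = 2  # stand-in for utils.EOS_ID

class EqualLengthSketch:
    """Dependency-free replay of EqualLengthPredictor's scoring logic."""
    def initialize(self, src_sentence):
        self.src_length = len(src_sentence)
        self.n_consumed = 0
    def predict_next(self):
        if self.n_consumed == self.src_length:
            return {EOS_ID: 0.0}
        return {EOS_ID: NEG_INF}
    def get_unk_probability(self, posterior):
        return NEG_INF if self.n_consumed == self.src_length else 0.0
    def consume(self, word):
        self.n_consumed += 1

p = EqualLengthSketch()
p.initialize([10, 11, 12])        # 3 source tokens
print(p.predict_next())           # end-of-sentence forbidden: {2: -inf}
print(p.get_unk_probability({}))  # all other tokens allowed: 0.0
for w in [20, 21, 22]:            # consume 3 target tokens
    p.consume(w)
print(p.predict_next())           # required length reached: {2: 0.0}
print(p.get_unk_probability({}))  # everything else forbidden: -inf
```

Before the required length is reached, only non-EOS tokens score 0.0; once it is reached, only EOS does, which is exactly what forces equal source and target lengths.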

Verify that your new predictor works by trying the following command:

$ python $SGNMT/decode.py --predictors eqlen,nmt --config_file tut.ini
2017-10-15 14:13:31,049 INFO: Next sentence (ID: 1): 1543 7 1491 1359 1532 692 9 6173
2017-10-15 14:13:57,378 INFO: Decoded (ID: 1): 1511 7 1422 0 894 30 8 10453
2017-10-15 14:13:57,378 INFO: Stats (ID: 1): score=-4.779791 num_expansions=88 time=26.33
2017-10-15 14:13:57,378 INFO: Decoding finished. Time: 26.33

As required, both the source and the target sentence consist of 8 tokens.