
I have barely been able to find an implementation of the Transformer that is neither bloated nor confusing, and the one I've used as a reference is the PyTorch implementation. However, the PyTorch implementation requires you to pass both the input (src) and the target (tgt) tensors at every step, rather than encoding the input once and then iterating for n steps to generate the full output. Am I missing something here?

My first guesses were that the Transformer isn't technically a seq2seq model, that I have not understood how I'm supposed to implement it, or that I've just been implementing seq2seq models incorrectly for the last few years :)

skevelis

1 Answer


The Transformer is a seq2seq model.

At training time, you pass both the source and the target tokens to the Transformer, just as you would with LSTMs or GRUs trained with teacher forcing, which is the default way of training them. Note that the Transformer decoder needs a causal mask so that the prediction at each position cannot depend on the current or future target tokens.
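To make the training setup concrete, here is a minimal sketch using PyTorch's built-in nn.Transformer (not the Annotated Transformer code). The vocabulary and model sizes, the BOS id and the helper name training_loss are hypothetical, and positional encodings are omitted for brevity. The whole target sequence goes through the model in a single forward pass; the mask from generate_square_subsequent_mask is what keeps each position from seeing later target tokens.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical toy sizes, for illustration only
VOCAB, D_MODEL, BOS = 11, 32, 1

torch.manual_seed(0)
embed = nn.Embedding(VOCAB, D_MODEL)
# Positional encodings are omitted to keep the sketch short
model = nn.Transformer(d_model=D_MODEL, nhead=4, num_encoder_layers=2,
                       num_decoder_layers=2, dim_feedforward=64,
                       dropout=0.0, batch_first=True)
proj = nn.Linear(D_MODEL, VOCAB)


def training_loss(src, tgt):
    """Teacher-forced cross-entropy for token ids of shape (batch, seq)."""
    # Decoder input is the target shifted right: [BOS, y1, ..., y_{n-1}]
    bos = torch.full((tgt.size(0), 1), BOS, dtype=tgt.dtype)
    dec_in = torch.cat([bos, tgt[:, :-1]], dim=1)
    # Causal mask: -inf above the diagonal, so position i only sees positions <= i
    tgt_mask = model.generate_square_subsequent_mask(dec_in.size(1))
    # One forward pass over the whole target sequence (no step-by-step loop)
    out = model(embed(src), embed(dec_in), tgt_mask=tgt_mask)
    logits = proj(out)  # (batch, seq, vocab)
    return F.cross_entropy(logits.reshape(-1, VOCAB), tgt.reshape(-1))


src = torch.randint(2, VOCAB, (3, 7))
tgt = torch.randint(2, VOCAB, (3, 5))
loss = training_loss(src, tgt)
loss.backward()
```

Because of the mask, changing a later target token leaves the decoder outputs at earlier positions unchanged, which is exactly the property teacher forcing relies on.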

At inference time, we don't have the target tokens (because that is what we are trying to predict). In this case, the decoder input at the first step is just the sequence [<s>], where <s> is the start-of-sequence token, and we predict the first token. Then we build the input for the next timestep by appending the prediction to the previous timestep's input (i.e. [<s>, <token1>]) and obtain the prediction for the second token, and so on. The encoder only needs to run once; it is the decoder that is re-run at every timestep. Note that, done this way, each timestep repeats the computations for the past positions; real implementations cache these states instead of recomputing them at every timestep.
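With PyTorch's nn.Transformer you don't have to call the full model at every step: its encoder and decoder submodules can be called separately, so the source is encoded once. The sketch below assumes a batch_first model; the toy sizes, the BOS/EOS ids and the function signature are hypothetical, the weights are random (untrained), and no key/value caching is done.

```python
import torch
import torch.nn as nn

# Hypothetical toy setup (random, untrained weights), for illustration only
VOCAB, D_MODEL, BOS, EOS = 11, 32, 1, 2
torch.manual_seed(0)
embed = nn.Embedding(VOCAB, D_MODEL)
model = nn.Transformer(d_model=D_MODEL, nhead=4, num_encoder_layers=2,
                       num_decoder_layers=2, dim_feedforward=64,
                       dropout=0.0, batch_first=True)
proj = nn.Linear(D_MODEL, VOCAB)


@torch.no_grad()
def greedy_decode(model, embed, proj, src, bos_id, eos_id, max_len):
    """Greedy decoding that runs the encoder once and the decoder per step."""
    model.eval()
    # Encode the source a single time; `memory` is reused at every step
    memory = model.encoder(embed(src))
    # Decoder input starts as [BOS] for every sequence in the batch
    ys = torch.full((src.size(0), 1), bos_id, dtype=torch.long)
    finished = torch.zeros(src.size(0), dtype=torch.bool)
    for _ in range(max_len - 1):
        tgt_mask = model.generate_square_subsequent_mask(ys.size(1))
        # Only the decoder is re-run; it recomputes all past positions (no cache)
        out = model.decoder(embed(ys), memory, tgt_mask=tgt_mask)
        # Greedy choice from the last position's output
        next_tok = proj(out[:, -1]).argmax(dim=-1, keepdim=True)
        ys = torch.cat([ys, next_tok], dim=1)
        # Stop once every sequence has emitted EOS (tokens after EOS are ignored)
        finished |= next_tok.squeeze(1) == eos_id
        if finished.all():
            break
    return ys


src = torch.randint(3, VOCAB, (2, 6))
ys = greedy_decode(model, embed, proj, src, BOS, EOS, max_len=8)
print(ys.shape)
```

Calling model(src, tgt) on the growing prefix would give the same tokens; splitting encoder and decoder just avoids re-encoding the source on every step.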

For Python code illustrating how the Transformer works, I suggest The Annotated Transformer, which is a nice guide through a real implementation. You may be most interested in the run_epoch function for training and the greedy_decode function for inference:

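import torch

def greedy_decode(model, src, src_mask, max_len, start_symbol):
    # Encode the source once; `memory` is reused at every decoding step
    memory = model.encode(src, src_mask)
    # The decoder input starts as just the start symbol
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src)
    for i in range(max_len - 1):
        # subsequent_mask (defined in The Annotated Transformer) is the causal mask
        out = model.decode(memory, src_mask, ys,
                           subsequent_mask(ys.size(1)).type_as(src))
        # Project the last position's output to vocabulary (log-)probabilities
        prob = model.generator(out[:, -1])
        # Greedy choice: take the most likely next token
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()
        # Append the prediction to form the next step's decoder input
        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src).fill_(next_word)], dim=1)
    return ys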

In greedy_decode you can see how the predictions of the current timestep are concatenated to the input to create the input for the following timestep.
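The code above calls subsequent_mask, which The Annotated Transformer defines separately. A minimal equivalent sketch (my reconstruction, not copied from that guide) returns a boolean mask of shape (1, size, size) where True marks the positions a query may attend to:

```python
import torch


def subsequent_mask(size):
    """Return a (1, size, size) bool mask; True where query i may attend to key j (j <= i)."""
    attn_shape = (1, size, size)
    # Ones strictly above the diagonal mark "future" positions
    future = torch.triu(torch.ones(attn_shape), diagonal=1)
    # Invert so that allowed (past and current) positions are True
    return future == 0


print(subsequent_mask(3)[0].int())
```

Note that this True-means-allowed convention is the reverse of the boolean attn_mask accepted by torch.nn.MultiheadAttention, where True marks positions that are not allowed to attend.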

noe