
I have a question about the transformer decoder's forward pass during training.

Let's pick an example: the input is "i love the sun" and the Italian translation I want to predict is "io amo il sole".

Now I feed the encoder with the input "i love the sun" and get the hidden states. Then I have to do multiple forward passes on the decoder with the input "BOS io amo il", where BOS is a token that stands for beginning of sentence. So I assume the forward passes are:

  • [BOS, IO, AMO, IL] -> decoder -> IO
  • [BOS, IO, AMO, IL] -> decoder -> AMO
  • [BOS, IO, AMO, IL] -> decoder -> IL
  • [BOS, IO, AMO, IL] -> decoder -> SOLE

I think this is the correct way. And what I think should differentiate the training passes is the masked attention mechanism, maybe(?). Is it right to assume that the masking will be

[1 0 0 0
 0 0 0 0
 0 0 0 0
 0 0 0 0]   for the first forward pass

[1 0 0 0
 1 1 0 0
 0 0 0 0
 0 0 0 0]   for the second forward pass

[1 0 0 0
 1 1 0 0
 1 1 1 0
 0 0 0 0]   for the third forward pass

[1 0 0 0
 1 1 0 0
 1 1 1 0
 1 1 1 1]   for the fourth forward pass

Is this the correct way, or what should be different? If you could also provide a Python implementation, that would be useful. Thanks in advance.

erre4

1 Answer


There are some problems with your description:

  • During training, the decoder receives the whole target sequence shifted right, with the BOS token prepended. You removed sole: the actual input would be [<bos>, io, amo, il, sole]. Note that the expected output at the position of sole would be the end-of-sequence token <eos>.

  • During training, there is a single forward pass (not one per token), and all the output tokens are predicted at once. Therefore, only the last of your attention masks is used (a minimal PyTorch sketch of such a training step is given after this list).

  • During inference, we don't have the target tokens (because that is what we are trying to predict). In this case, we have one pass per generated token, starting with <bos>. This way, the decoder input at the first step is just the sequence [<bos>], and we predict the first token: io. Then we prepare the input for the next timestep as [<bos>, io] and obtain the prediction for the second token, and so on. Note that, at each timestep, we are repeating the computations for the past positions; in real implementations, these states are cached instead of re-computed at each timestep (a simplified greedy-decoding sketch is given at the end of this answer).
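
To make the training step concrete, here is a minimal, self-contained PyTorch sketch (not your exact setup: the toy token ids, model size, and hyperparameters are made up for illustration, and the model is untrained). It builds the shifted decoder input and the labels, constructs the single causal mask, and runs one forward pass that predicts all output positions at once.

```python
import torch
import torch.nn as nn

# Toy token ids, made up for illustration.
PAD, BOS, EOS = 0, 1, 2
src_ids = torch.tensor([[3, 4, 5, 6]])   # "i love the sun"
tgt_ids = torch.tensor([[3, 4, 5, 6]])   # "io amo il sole" (separate target vocabulary)

# Teacher forcing: the decoder input is the target shifted right with <bos>
# prepended, and the labels are the target with <eos> appended.
bos = torch.full((1, 1), BOS)
eos = torch.full((1, 1), EOS)
decoder_input = torch.cat([bos, tgt_ids], dim=1)   # [<bos>, io, amo, il, sole]
labels = torch.cat([tgt_ids, eos], dim=1)          # [io, amo, il, sole, <eos>]

# Single causal mask: position t may attend only to positions <= t.
# nn.Transformer expects 0.0 for allowed entries and -inf for masked ones.
T = decoder_input.size(1)
causal_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

# Tiny untrained model, just to show the shapes of a single training pass
# (positional encodings are omitted here for brevity).
d_model, vocab_size = 32, 16
src_embed = nn.Embedding(vocab_size, d_model)
tgt_embed = nn.Embedding(vocab_size, d_model)
transformer = nn.Transformer(d_model=d_model, nhead=4,
                             num_encoder_layers=2, num_decoder_layers=2,
                             batch_first=True)
to_logits = nn.Linear(d_model, vocab_size)

# One forward pass for the whole sequence: every output position is predicted
# at once, and the causal mask prevents attending to future positions.
hidden = transformer(src_embed(src_ids), tgt_embed(decoder_input), tgt_mask=causal_mask)
logits = to_logits(hidden)                                        # (1, 5, vocab_size)
loss = nn.functional.cross_entropy(logits.reshape(-1, vocab_size), labels.reshape(-1))
loss.backward()
```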

As for Python code illustrating how the Transformer works, I suggest The Annotated Transformer, which is a nice guide through a real implementation. You may be most interested in the function run_epoch for training and the function greedy_decode for inference.
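
For completeness, here is a much simplified sketch of greedy decoding (it is not the linked greedy_decode, and it omits the caching mentioned above: the whole model is re-run at every step). It rebuilds the same kind of untrained toy model as in the training sketch, so the generated ids are meaningless until the model is actually trained.

```python
import torch
import torch.nn as nn

# Same kind of untrained toy model as in the training sketch above; with a
# trained model, the loop below would produce [<bos>, io, amo, il, sole, <eos>].
PAD, BOS, EOS = 0, 1, 2
d_model, vocab_size, max_len = 32, 16, 10
src_embed = nn.Embedding(vocab_size, d_model)
tgt_embed = nn.Embedding(vocab_size, d_model)
transformer = nn.Transformer(d_model=d_model, nhead=4,
                             num_encoder_layers=2, num_decoder_layers=2,
                             batch_first=True)
to_logits = nn.Linear(d_model, vocab_size)
transformer.eval()   # disable dropout at inference time

src_ids = torch.tensor([[3, 4, 5, 6]])   # "i love the sun"

# Greedy decoding: start from <bos> and append one predicted token per step.
generated = torch.full((1, 1), BOS)
with torch.no_grad():
    for _ in range(max_len):
        T = generated.size(1)
        causal_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)
        # For simplicity, the encoder and all past decoder positions are
        # re-computed at every step; real implementations cache these states.
        hidden = transformer(src_embed(src_ids), tgt_embed(generated), tgt_mask=causal_mask)
        next_id = to_logits(hidden[:, -1]).argmax(dim=-1, keepdim=True)  # last position only
        generated = torch.cat([generated, next_id], dim=1)
        if next_id.item() == EOS:
            break

print(generated)   # sequence of generated token ids, starting with <bos>
```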

noe