
I am trying to build a binary text classification model by using the encoder part of the Transformer and feeding its output into an LSTM network. However, I cannot get good accuracy: I reach 92% on the training set but only 72% on the validation set. Is my approach correct? Is there a better way to design the model and improve accuracy?

Khobaib Alam

1 Answer


Your model is overfitting. You should try the standard methods people use to prevent overfitting:

  • Larger dropout (up to 0.5); in low-resource setups, word dropout (i.e., randomly masking input tokens) also sometimes helps (0.1–0.3 might be reasonable values).
  • If you have many output classes, label smoothing can help.
  • You can try a smaller model dimension.
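Word dropout from the first bullet is simple to implement yourself: each input token id is independently replaced by a mask/UNK id with probability p. A minimal framework-free sketch (the `mask_id` of 0 is a placeholder; the actual id depends on your vocabulary):

```python
import random

def word_dropout(token_ids, p=0.2, mask_id=0, rng=None):
    """Randomly replace each token id with mask_id with probability p.

    Apply this only during training, never at evaluation time
    (analogous to how regular dropout is switched off in eval mode).
    """
    rng = rng or random.Random()
    return [mask_id if rng.random() < p else t for t in token_ids]

# Example: drop roughly 30% of the tokens of one training sequence.
tokens = [101, 2054, 2003, 2023, 102]
print(word_dropout(tokens, p=0.3, rng=random.Random(0)))
```

The same idea works on batched tensors by sampling a Bernoulli mask of the batch shape and using it to overwrite token ids in place.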

If you use a pre-trained Transformer (such as BERT), you of course cannot change the model dimension. In that case, you can try setting a much smaller learning rate for fine-tuning BERT than for training the actual classifier.
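In PyTorch, the two learning rates can be set via optimizer parameter groups. A sketch (the `nn.Linear` layers are placeholders standing in for the pre-trained encoder and the classifier head, and the learning-rate values are just commonly used starting points, not tuned recommendations):

```python
import torch
import torch.nn as nn

class SketchModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)     # placeholder for pre-trained BERT
        self.classifier = nn.Linear(8, 2)  # placeholder for the LSTM + output layer

model = SketchModel()

# One parameter group per learning rate: a much smaller rate for
# fine-tuning the pre-trained weights than for the freshly
# initialised classifier.
optimizer = torch.optim.AdamW([
    {"params": model.encoder.parameters(), "lr": 2e-5},
    {"params": model.classifier.parameters(), "lr": 1e-3},
])
```

During training you call `optimizer.step()` as usual; each group is updated with its own learning rate.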

Jindřich