I am trying to build a binary text classification model by using the encoder part of the Transformer and feeding its output into an LSTM network. However, I get 92% accuracy on the training set but only 72% on the validation set. Is my approach correct? Is there a better way to design the model and improve accuracy?
1 Answer
Your model is overfitting: it performs much better on the training set than on the validation set. You should try the standard methods people use to prevent overfitting:
- Larger dropout (up to 0.5); in low-resource setups, word dropout (i.e., randomly masking input tokens) also sometimes helps (0.1-0.3 are reasonable values).
- If you have many output classes, label smoothing can help.
- You can try a smaller model dimension.
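Two of these tricks are easy to show concretely. Here is a minimal stdlib-only sketch of word dropout and label smoothing; the function names, the `mask_id` value, and the default rates are my own illustrative choices, not anything fixed by a library:

```python
import random

MASK_ID = 0  # hypothetical id of the [MASK]/unknown token in your vocabulary


def word_dropout(token_ids, p=0.2, mask_id=MASK_ID, rng=None):
    """Word dropout: randomly replace each input token id with mask_id
    with probability p, so the model cannot rely on any single token."""
    rng = rng or random.Random()
    return [mask_id if rng.random() < p else t for t in token_ids]


def smooth_labels(label, num_classes=2, eps=0.1):
    """Label smoothing: the gold class gets probability 1 - eps and the
    remaining eps mass is spread uniformly over the other classes."""
    return [1.0 - eps if c == label else eps / (num_classes - 1)
            for c in range(num_classes)]
```

In practice you would apply `word_dropout` to each training batch (training time only, never at evaluation) and train against the smoothed distribution with a cross-entropy-style loss instead of hard 0/1 targets.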
If you use a pre-trained Transformer (such as BERT), you of course cannot change the model dimension. In that case, you can set a much smaller learning rate for fine-tuning BERT than you use for training the actual classifier.
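The usual way to get two learning rates in one optimizer is parameter groups. A minimal sketch, assuming the hypothetical helper name and the learning-rate values below (the returned list has the shape that PyTorch optimizers such as `torch.optim.Adam` accept):

```python
def make_param_groups(encoder_params, head_params,
                      encoder_lr=2e-5, head_lr=1e-3):
    """Build optimizer parameter groups: a tiny learning rate for the
    pre-trained encoder (e.g. BERT) and a larger one for the freshly
    initialized classifier head on top of it."""
    return [
        {"params": list(encoder_params), "lr": encoder_lr},
        {"params": list(head_params), "lr": head_lr},
    ]
```

With PyTorch this would be used as `torch.optim.Adam(make_param_groups(bert.parameters(), classifier.parameters()))`, so the encoder moves slowly while the classifier trains at a normal rate.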
Jindřich