I am training a classification model with 3 classes using a deep neural network.
The classes have been resampled and balanced.
I have around 600,000 samples, equally distributed across the three classes.
The train/test/validation splits are balanced in the same way.
After training, the overall accuracy is ~67%, but performance differs sharply between classes.
Classes 0 and 1 have very high recall (0.99), but class 2 has very low recall (0.03): almost all class 2 samples are predicted as class 0 or class 1. How can I fix this?
Model defined:
import tensorflow as tf

model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(20, 4)),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1024, activation='relu'),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(1024, activation='relu'),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(3, activation='softmax'),
])
hist = evaluate(
    model=model,
    train_data=(X_train, Y_train),
    val_data=(X_test, Y_test),
    epochs=100,
    batch_size=256,
    verbose=True,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
A learning-rate scheduler and early stopping are also used.
Classification report:
              precision    recall  f1-score   support

           0       0.66      0.99      0.79    196323
           1       0.68      0.99      0.80    196323
           2       0.57      0.03      0.06    196323

    accuracy                           0.67    588969
   macro avg       0.64      0.67      0.55    588969
weighted avg       0.64      0.67      0.55    588969
Confusion matrix (rows are true classes, columns are predicted classes):
[[193909    177   2237]
 [   298 193918   2107]
 [ 98184  92287   5852]]
accuracy_score: 0.6684205790117986
roc_auc_score: 0.7927754415211714
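As a sanity check, the per-class numbers in the report can be recomputed directly from the confusion matrix. This is only a minimal sketch, assuming NumPy and the usual scikit-learn convention of rows as true classes and columns as predicted classes. The variable names are mine, not part of the training code.

```python
import numpy as np

# Confusion matrix copied from above: rows = true class, columns = predicted class.
cm = np.array([
    [193909,    177,   2237],
    [   298, 193918,   2107],
    [ 98184,  92287,   5852],
])

true_pos = np.diag(cm)

# Recall: correct predictions divided by the number of true samples per class (row sums).
recall = true_pos / cm.sum(axis=1)

# Precision: correct predictions divided by the number of predictions per class (column sums).
precision = true_pos / cm.sum(axis=0)

# Overall accuracy: trace divided by the total number of samples.
accuracy = true_pos.sum() / cm.sum()

# Where do the true class 2 samples end up? Fractions of row 2 per predicted class.
class2_destinations = cm[2] / cm[2].sum()

for k in range(3):
    print(f"class {k}: precision={precision[k]:.2f} recall={recall[k]:.2f}")
print(f"accuracy={accuracy:.4f}")
print("class 2 predicted as [0, 1, 2]:", np.round(class2_destinations, 2))
```

The last line makes the failure mode explicit: true class 2 samples are split almost evenly between predicted classes 0 and 1, and only about 3% are predicted as class 2.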
UPDATE
I tried the one-vs-all (OVA) approach, and the results are below.
It turns out the high accuracy and ROC AUC scores came from oversampling the data, which artificially boosted them.
In practice, with the real class distribution (class 1 occurs about 1 time in 11), the results for class 1 were:
              precision    recall  f1-score   support

         0.0       1.00      0.73      0.84    125570
         1.0       0.25      0.98      0.40     11977

    accuracy                           0.75    137547
   macro avg       0.63      0.85      0.62    137547
weighted avg       0.93      0.75      0.80    137547
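To illustrate why the same classifier looks much better on oversampled data, here is a rough sketch of how precision depends on class prevalence when the true-positive and false-positive rates stay fixed. It assumes those two rates do not change between the balanced and the real test set, and it uses the rounded recall values from the report above, so the result is approximate. The function name `precision_at_prevalence` is hypothetical.

```python
def precision_at_prevalence(tpr: float, fpr: float, prevalence: float) -> float:
    """Precision implied by a fixed TPR and FPR at a given positive-class prevalence."""
    true_pos = tpr * prevalence            # share of all samples that are true positives
    false_pos = fpr * (1.0 - prevalence)   # share of all samples that are false positives
    return true_pos / (true_pos + false_pos)


tpr = 0.98        # recall of class 1.0 in the report above
fpr = 1.0 - 0.73  # 1 - recall of class 0.0 in the report above

# Prevalence of class 1.0 in the real data, taken from the support column.
real_prevalence = 11977 / 137547

balanced_precision = precision_at_prevalence(tpr, fpr, 0.5)
real_precision = precision_at_prevalence(tpr, fpr, real_prevalence)

print(f"prevalence 0.50: precision ~ {balanced_precision:.2f}")
print(f"prevalence {real_prevalence:.2f}: precision ~ {real_precision:.2f}")
```

With these rates, precision is around 0.78 on a 50/50 test set but only about a quarter at the real prevalence, which is in line with the 0.25 reported above.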
I think this extreme imbalance was also the cause of the original problem.
How do I fix this?