I've been working on a convolutional neural network, but its accuracy is stuck at 62% and is not improving, and I'm afraid I have a rather severe overfitting problem. I've been experimenting with the weight decay and learning rate values, but I'd like some insight into how to fix this properly. The code is attached below; thanks in advance!
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class BaseModel(nn.Module):
    """Small two-convolution CNN producing 10 output logits.

    The layer sizes (``fc1`` takes ``32 * 8 * 8`` features after two 2x2
    max-pools) imply 3-channel inputs of 32x32 pixels.
    """

    def __init__(self):
        # The constructor must be spelled ``__init__`` (not ``init``);
        # otherwise nn.Module never runs it and no layers are registered.
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.25)

    def init_weights(self):
        """Re-initialise every Conv2d/Linear weight with Kaiming-normal.

        Biases are left at their PyTorch defaults.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        """Return the (batch, 10) logits for a batch of images ``x``."""
        x = self.pool(F.elu(self.conv1(x)))
        x = self.pool(F.elu(self.conv2(x)))
        # Flatten per sample. Pinning the batch dimension (instead of
        # ``view(-1, ...)``) makes a wrong input size fail loudly in fc1
        # rather than silently changing the number of rows.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
class EnhancedModel(BaseModel):
    """Deeper five-convolution variant of ``BaseModel`` with batch norm.

    Reuses ``conv1``, ``conv2``, ``pool`` and ``dropout`` from the parent
    and adds three more convolution blocks plus a new classifier head
    (``fc3`` -> ``fc4``). The inherited ``fc1``/``fc2`` and the
    ``dropout_fc1`` layer stay registered but are not used by ``forward``.
    """

    def __init__(self):
        # Must be ``__init__`` (not ``init``) so the layers are actually
        # created when the model is instantiated.
        super().__init__()
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.dropout3 = nn.Dropout(0.25)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.dropout4 = nn.Dropout(0.25)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.dropout5 = nn.Dropout(0.25)
        self.fc3 = nn.Linear(256 * 4 * 4, 512)
        self.fc4 = nn.Linear(512, 10)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.25)
        self.dropout_fc1 = nn.Dropout(0.5)
        self.bn1 = nn.BatchNorm2d(16)
        self.bn2 = nn.BatchNorm2d(32)

    def forward(self, x):
        """Return the (batch, 10) logits for a batch of images ``x``.

        The previous version flattened after ``conv2``, ran ``fc1``/``fc2``
        and then passed the resulting (batch, 10) logits into ``conv3``,
        which needs a 4-D feature map. The convolution blocks are now
        chained directly and flattened once, before ``fc3``.
        """
        # Block 1: 3 -> 16 channels, spatial size halved.
        x = F.elu(self.conv1(x))
        x = self.bn1(x)
        x = self.dropout1(x)
        x = self.pool(x)

        # Block 2: 16 -> 32 channels, spatial size halved.
        x = F.elu(self.conv2(x))
        x = self.bn2(x)
        x = self.dropout2(x)
        x = self.pool(x)

        # Block 3: 32 -> 64 channels, spatial size halved.
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.dropout3(x)

        # Blocks 4 and 5 are not pooled: ``fc3`` expects 256 * 4 * 4
        # features, i.e. a 4x4 map, which a 32x32 input reaches after the
        # three pools above. Pooling twice more would shrink it to 1x1.
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.dropout4(x)
        x = F.relu(self.bn5(self.conv5(x)))
        x = self.dropout5(x)

        # Flatten per sample, keeping the batch dimension fixed.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc3(x))
        x = self.dropout(x)
        x = self.fc4(x)
        return x
class Classifier(nn.Module):
    """Wrapper that selects and initialises the underlying network.

    Args:
        use_enhanced_model: build ``EnhancedModel`` when true, otherwise
            ``BaseModel`` (the default).
    """

    def __init__(self, use_enhanced_model=False):
        # Must be ``__init__`` (not ``init``); otherwise ``self.model`` is
        # never created and ``forward`` fails with an AttributeError.
        super().__init__()
        self.model = EnhancedModel() if use_enhanced_model else BaseModel()
        self.model.init_weights()

    def init_weights(self):
        """Re-run the wrapped model's weight initialisation."""
        self.model.init_weights()

    def forward(self, x):
        """Delegate to the wrapped model and return its output."""
        return self.model(x)
class Params:
    """Container for the run configuration, grouped into nested sections."""

    class LRScheduler:
        """Learning-rate scheduler settings."""

        def __init__(self):
            # Was spelled ``init``, so these attributes were never set on
            # ``Params().lr_scheduler``.
            self.type = 'StepLR'
            self.step_size = 30
            self.gamma = 0.1

    class Optimization:
        """Optimizer settings."""

        def __init__(self):  # play with these numbers
            self.type = 'adam'
            self.lr = 5e-3
            # NOTE(review): presumably only read for momentum-based
            # optimizers; verify against the code that builds the optimizer.
            self.momentum = 0.9
            self.eps = 1e-8
            self.weight_decay = 0.01  # tune in to find the right number
            # remove layers

    class Training:
        """Training-loop settings."""

        def __init__(self):
            self.probs_data = ''
            self.batch_size = 128
            self.n_workers = 4
            self.n_epochs = 200

    class Validation:
        """Validation settings."""

        def __init__(self):
            self.batch_size = 16
            self.n_workers = 1
            self.gap = 1
            self.ratio = 0.2
            self.vis = '0'  # padded to 3 characters by Params.process()
            self.tb_samples = 10

    class Testing:
        """Test-time settings."""

        def __init__(self):
            self.enable = 0
            self.batch_size = 24
            self.n_workers = 1
            self.vis = '0'  # padded to 2 characters by Params.process()

    class Checkpoint:
        """Checkpoint loading/saving settings."""

        def __init__(self):
            self.load = 1
            self.path = './checkpoints/model.pt'
            self.save_criteria = [
                'train-acc',
                'train-loss',
                'val-acc',
                'val-loss',
            ]

    def __init__(self):
        self.use_gpu = 1
        self.train = Params.Training()
        self.val = Params.Validation()
        self.test = Params.Testing()
        self.ckpt = Params.Checkpoint()
        self.optim = Params.Optimization()
        self.lr_scheduler = Params.LRScheduler()

    def process(self):
        """Right-pad the ``vis`` strings with '0' to their fixed widths.

        ``val.vis`` is padded to 3 characters and ``test.vis`` to 2.
        """
        self.val.vis = self.val.vis.ljust(3, '0')
        self.test.vis = self.test.vis.ljust(2, '0')
I've added layers, used dropout, and played around with the weight decay and the learning rate, but I don't know whether I'm doing the right thing or what values would be ideal for the weight decay and learning rate. I'm really new to all this and have practically no knowledge or insight, so if you could lend a helping hand, it would mean a lot!