
I am new to PyTorch and trying to implement a VAE for MNIST data. When I train my model, it appears to force mu and logvar to zero (or something very close to zero), independent of the input. It looks as if the MSE part of the loss function is being ignored, but I don't understand why.

Here's the complete code I am using:

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn, optim
import torch.nn.functional as F

batch_size = 32

loc_data = 'MNIST'
transformations = transforms.ToTensor()
mnist_train = datasets.MNIST(loc_data, train=True, download=True, transform=transformations)
mnist_test = datasets.MNIST(loc_data, train=False, download=True, transform=transformations)

train_loader = DataLoader(mnist_train, batch_size=batch_size, drop_last=True, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, drop_last=True, shuffle=True)

class Encoder(nn.Module):
    def __init__(self, latent_dim=10):
        super(Encoder, self).__init__()
        self.latent_dim = latent_dim
        self._encoder = nn.Sequential(
            nn.Linear(in_features=28*28, out_features=512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=2*latent_dim)
        )

    def forward(self, x):
        x = torch.reshape(self._encoder.forward(x), (-1, 2, self.latent_dim))
        mu, logvar = x[:, 0, :], x[:, 1, :]
        return mu, logvar

class Decoder(nn.Module):
    def __init__(self, latent_dim=10):
        super(Decoder, self).__init__()
        self.latent_dim = latent_dim
        self._decoder = nn.Sequential(
            nn.Linear(in_features=latent_dim, out_features=512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=28*28),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self._decoder.forward(x)

def sample(mu, logvar):
    z = torch.randn_like(mu)
    return mu + torch.mul(torch.exp(0.5*logvar), z)

def vae_loss(x, x_hat, mu, logvar):
    mse = (x - x_hat).pow(2).sum() / (x.shape[0]*1.0)
    KL_loss = 0.5*torch.sum(-1 + torch.pow(mu, 2) - logvar + torch.exp(logvar))
    return torch.add(mse, KL_loss)
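For reference, the KL term above is meant to be the standard closed form for KL(N(mu, sigma^2) || N(0, 1)). A quick sanity check of that term against torch.distributions (just for illustration here, not part of the training script) looks like this:

# Sanity check of the hand-written KL term against torch.distributions,
# using random mu/logvar tensors rather than real encoder outputs.
from torch.distributions import Normal, kl_divergence

_mu = torch.randn(4, 10)
_logvar = torch.randn(4, 10)
kl_manual = 0.5*torch.sum(-1 + torch.pow(_mu, 2) - _logvar + torch.exp(_logvar))
kl_ref = kl_divergence(Normal(_mu, torch.exp(0.5*_logvar)), Normal(0.0, 1.0)).sum()
print(kl_manual, kl_ref)  # the two values agree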

def train(encoder, decoder, train_loader, optimizer, num_epochs=10):
    encoder.train()
    decoder.train()
    for ii in range(num_epochs):
        print("Epoch {}".format(ii))
        for jj, (x, y) in enumerate(train_loader):
            x = torch.reshape(x, (-1, 28*28))
            x.to(device)
            _mu, _logvar = encoder.forward(x)
            _z = sample(_mu, _logvar)
            x_hat = decoder.forward(_z)  # .reshape((-1,28,28))
            optimizer.zero_grad()
            loss = vae_loss(x, x_hat, _mu, _logvar)
            loss.backward()
            optimizer.step()
            if jj % 100 == 0:
                print(loss)
    return loss

latent_dim = 20
encoder = Encoder(latent_dim)
decoder = Decoder(latent_dim)
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = optim.Adam(params, lr=1e-2)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train(encoder, decoder, train_loader, optimizer, num_epochs = 1)

When I probe mu or logvar for some test data after training, the result is almost identically zero for every input.
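The kind of probing I mean looks roughly like this (a minimal sketch; the exact test batch does not matter, the pattern is the same for any input):

# Inspect the encoder outputs on a batch of test images.
encoder.eval()
with torch.no_grad():
    x_test, _ = next(iter(test_loader))
    x_test = torch.reshape(x_test, (-1, 28*28))
    mu_test, logvar_test = encoder(x_test)
    print(mu_test.abs().max(), logvar_test.abs().max())  # both come out essentially zero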

