mirror of
https://github.com/wassname/world-models-pytorch.git
synced 2026-06-27 18:27:21 +08:00
finished VAE and VAE loss function, working on dataset
This commit is contained in:
@@ -1,12 +1,110 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from torch.autograd import Variable
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def make_conv_relu(inpt_kernel, output_kernel, kernel_size=4):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(in_channels=inpt_kernel, out_channels=output_kernel, kernel_size=kernel_size, stride=2),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
|
||||
def make_deconv_relu(inpt_kernel, output_kernel, kernel_size, use_activation=True):
|
||||
return nn.Sequential(
|
||||
nn.ConvTranspose2d(in_channels=inpt_kernel, out_channels=output_kernel, kernel_size=kernel_size, stride=2),
|
||||
nn.ReLU(inplace=True)
|
||||
) if use_activation \
|
||||
else nn.ConvTranspose2d(in_channels=inpt_kernel, out_channels=output_kernel, kernel_size=kernel_size, stride=2)
|
||||
|
||||
|
||||
# Reconstruction + KL divergence losses summed over all elements and batch
|
||||
# https://github.com/pytorch/examples/blob/master/vae/main.py
|
||||
def loss_function(recon_x, x, mu, logvar):
|
||||
n, c, h, w = recon_x.size()
|
||||
recon_x = recon_x.view(n, -1)
|
||||
x = x.view(n, -1)
|
||||
# L2 distance
|
||||
l2_dist = torch.mean(torch.sqrt(torch.sum(torch.pow(recon_x - x, 2), 1)))
|
||||
# see Appendix B from VAE paper:
|
||||
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
|
||||
# https://arxiv.org/abs/1312.6114
|
||||
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
||||
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||
return l2_dist + KLD
|
||||
|
||||
|
||||
class VAE(nn.Module):
|
||||
def __init__(self):
|
||||
# TODO: Variational Auto Encoder
|
||||
def __init__(self, latent_vector_dim=32):
|
||||
super(VAE, self).__init__()
|
||||
# encoder part
|
||||
self.conv1 = make_conv_relu(3, 32)
|
||||
self.conv2 = make_conv_relu(32, 64)
|
||||
self.conv3 = make_conv_relu(64, 128)
|
||||
self.conv4 = make_conv_relu(128, 256)
|
||||
|
||||
self.mu = nn.Linear(1024, latent_vector_dim)
|
||||
self.logvar = nn.Linear(1024, latent_vector_dim)
|
||||
|
||||
self.z = nn.Linear(latent_vector_dim, 1024)
|
||||
|
||||
# decoder part
|
||||
self.deconv1 = make_deconv_relu(1024, 128, 5)
|
||||
self.deconv2 = make_deconv_relu(128, 64, 5)
|
||||
self.deconv3 = make_deconv_relu(64, 32, 6)
|
||||
self.deconv4 = make_deconv_relu(32, 3, 6)
|
||||
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def encode(self, x):
|
||||
"""
|
||||
Returns mean and log variance
|
||||
"""
|
||||
x = self.conv4(self.conv3(self.conv2(self.conv1(x))))
|
||||
x = x.view(x.size()[0], -1)
|
||||
return self.mu(x), self.logvar(x)
|
||||
|
||||
def sample(self, mu, logvar):
|
||||
if self.training:
|
||||
std = logvar.exp()
|
||||
std = std * Variable(std.data.new(std.size()).normal_())
|
||||
return mu + std
|
||||
else:
|
||||
return mu
|
||||
|
||||
def decode(self, z):
|
||||
z = self.z(z)
|
||||
n, d = z.size()
|
||||
z = z.view(n, d, 1, 1)
|
||||
reconstruction = self.deconv4(self.deconv3(self.deconv2(self.deconv1(z))))
|
||||
reconstruction = self.sigmoid(reconstruction)
|
||||
return reconstruction
|
||||
|
||||
def forward(self, x):
|
||||
pass
|
||||
"""
|
||||
Returns reconstructed image, mean, and log variance
|
||||
"""
|
||||
mu, logvar = self.encode(x)
|
||||
z = self.sample(mu, logvar)
|
||||
x = self.decode(z)
|
||||
return x, mu, logvar
|
||||
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# import numpy as np
|
||||
# import cv2
|
||||
#
|
||||
# img = np.random.randn(64, 64, 3)
|
||||
# gpu_img = Variable(torch.from_numpy(img[np.newaxis].transpose(0, 3, 1, 2))).float().cuda()
|
||||
#
|
||||
# vae = VAE()
|
||||
# vae.cuda()
|
||||
# x, mu, logvar = vae.forward(gpu_img)
|
||||
# print(x.size())
|
||||
# print(loss_function(x, gpu_img, mu, logvar))
|
||||
# x = x.data.cpu().numpy()[0].transpose(1, 2, 0)
|
||||
#
|
||||
# cv2.imshow('original', img)
|
||||
# cv2.imshow('reconstructed', x)
|
||||
# cv2.waitKey()
|
||||
Reference in New Issue
Block a user