mirror of
https://github.com/wassname/world-models-pytorch.git
synced 2026-06-27 14:47:32 +08:00
110 lines
3.5 KiB
Python
110 lines
3.5 KiB
Python
import torch
|
|
from torch.nn import functional as F
|
|
from torch.autograd import Variable
|
|
import torch.nn as nn
|
|
|
|
|
|
def make_conv_relu(inpt_kernel, output_kernel, kernel_size=4):
|
|
return nn.Sequential(
|
|
nn.Conv2d(in_channels=inpt_kernel, out_channels=output_kernel, kernel_size=kernel_size, stride=2),
|
|
nn.ReLU(inplace=True)
|
|
)
|
|
|
|
|
|
def make_deconv_relu(inpt_kernel, output_kernel, kernel_size, use_activation=True):
|
|
return nn.Sequential(
|
|
nn.ConvTranspose2d(in_channels=inpt_kernel, out_channels=output_kernel, kernel_size=kernel_size, stride=2),
|
|
nn.ReLU(inplace=True)
|
|
) if use_activation \
|
|
else nn.ConvTranspose2d(in_channels=inpt_kernel, out_channels=output_kernel, kernel_size=kernel_size, stride=2)
|
|
|
|
|
|
# Reconstruction + KL divergence losses summed over all elements and batch
|
|
# https://github.com/pytorch/examples/blob/master/vae/main.py
|
|
def loss_function(recon_x, x, mu, logvar):
|
|
n, c, h, w = recon_x.size()
|
|
recon_x = recon_x.view(n, -1)
|
|
x = x.view(n, -1)
|
|
# L2 distance
|
|
l2_dist = torch.mean(torch.sqrt(torch.sum(torch.pow(recon_x - x, 2), 1)))
|
|
# see Appendix B from VAE paper:
|
|
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
|
|
# https://arxiv.org/abs/1312.6114
|
|
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
|
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
|
return l2_dist + KLD
|
|
|
|
|
|
class VAE(nn.Module):
|
|
def __init__(self, latent_vector_dim=32):
|
|
super(VAE, self).__init__()
|
|
# encoder part
|
|
self.conv1 = make_conv_relu(3, 32)
|
|
self.conv2 = make_conv_relu(32, 64)
|
|
self.conv3 = make_conv_relu(64, 128)
|
|
self.conv4 = make_conv_relu(128, 256)
|
|
|
|
self.mu = nn.Linear(1024, latent_vector_dim)
|
|
self.logvar = nn.Linear(1024, latent_vector_dim)
|
|
|
|
self.z = nn.Linear(latent_vector_dim, 1024)
|
|
|
|
# decoder part
|
|
self.deconv1 = make_deconv_relu(1024, 128, 5)
|
|
self.deconv2 = make_deconv_relu(128, 64, 5)
|
|
self.deconv3 = make_deconv_relu(64, 32, 6)
|
|
self.deconv4 = make_deconv_relu(32, 3, 6)
|
|
|
|
self.sigmoid = nn.Sigmoid()
|
|
|
|
def encode(self, x):
|
|
"""
|
|
Returns mean and log variance
|
|
"""
|
|
x = self.conv4(self.conv3(self.conv2(self.conv1(x))))
|
|
x = x.view(x.size()[0], -1)
|
|
return self.mu(x), self.logvar(x)
|
|
|
|
def sample(self, mu, logvar):
|
|
if self.training:
|
|
std = logvar.exp()
|
|
std = std * Variable(std.data.new(std.size()).normal_())
|
|
return mu + std
|
|
else:
|
|
return mu
|
|
|
|
def decode(self, z):
|
|
z = self.z(z)
|
|
n, d = z.size()
|
|
z = z.view(n, d, 1, 1)
|
|
reconstruction = self.deconv4(self.deconv3(self.deconv2(self.deconv1(z))))
|
|
reconstruction = self.sigmoid(reconstruction)
|
|
return reconstruction
|
|
|
|
def forward(self, x):
|
|
"""
|
|
Returns reconstructed image, mean, and log variance
|
|
"""
|
|
mu, logvar = self.encode(x)
|
|
z = self.sample(mu, logvar)
|
|
x = self.decode(z)
|
|
return x, mu, logvar
|
|
|
|
|
|
# if __name__ == '__main__':
|
|
# import numpy as np
|
|
# import cv2
|
|
#
|
|
# img = np.random.randn(64, 64, 3)
|
|
# gpu_img = Variable(torch.from_numpy(img[np.newaxis].transpose(0, 3, 1, 2))).float().cuda()
|
|
#
|
|
# vae = VAE()
|
|
# vae.cuda()
|
|
# x, mu, logvar = vae.forward(gpu_img)
|
|
# print(x.size())
|
|
# print(loss_function(x, gpu_img, mu, logvar))
|
|
# x = x.data.cpu().numpy()[0].transpose(1, 2, 0)
|
|
#
|
|
# cv2.imshow('original', img)
|
|
# cv2.imshow('reconstructed', x)
|
|
# cv2.waitKey() |