finished VAE and VAE loss function, working on dataset

This commit is contained in:
junhong
2018-04-12 23:22:50 -04:00
parent 2d57a02c19
commit d53f74fc40
4 changed files with 168 additions and 17 deletions
+17
View File
@@ -0,0 +1,17 @@
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from rollouts import load_rollouts
class VAEDataset(Dataset):
def __init__(self, rollout_path):
super(VAEDataset, self).__init__()
rollouts = load_rollouts(rollout_path)
self.images = rollouts['observations']
def __getitem__(self, index):
pass
def __len__(self):
return len(self.images)
+15 -14
View File
@@ -2,6 +2,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from torch import normal, multinomial from torch import normal, multinomial
from torch.autograd import Variable
class MDNRNN(nn.Module): class MDNRNN(nn.Module):
@@ -86,17 +87,17 @@ class MDNRNN(nn.Module):
return z, hidden_state return z, hidden_state
if __name__ == '__main__': # if __name__ == '__main__':
from torch.autograd import Variable # from torch.autograd import Variable
z_dim, action_dim, hidden_size, n_mixture, temp = 32, 2, 256, 5, 0.0 # z_dim, action_dim, hidden_size, n_mixture, temp = 32, 2, 256, 5, 0.0
batch_size = 1 # batch_size = 1
seq_len = 1 # seq_len = 1
mdnrnn = MDNRNN(z_dim, action_dim, hidden_size, n_mixture, temp) # mdnrnn = MDNRNN(z_dim, action_dim, hidden_size, n_mixture, temp)
mdnrnn.cuda() # mdnrnn.cuda()
prev_z = Variable(torch.randn(batch_size, seq_len, z_dim)).cuda() # prev_z = Variable(torch.randn(batch_size, seq_len, z_dim)).cuda()
action = Variable(torch.randn(batch_size, seq_len, action_dim)).cuda() # action = Variable(torch.randn(batch_size, seq_len, action_dim)).cuda()
#
new_z, new_hidden_state = mdnrnn.sample(prev_z, action) # new_z, new_hidden_state = mdnrnn.sample(prev_z, action)
print(new_z) # print(new_z)
pi, mean, sigma, hidden_state = mdnrnn.forward(prev_z, action) # pi, mean, sigma, hidden_state = mdnrnn.forward(prev_z, action)
print(sigma) # print(sigma)
+35
View File
@@ -0,0 +1,35 @@
import gym
import torch
def random_rollouts(env, num_rollouts, render=False):
"""
This function collects random rollouts from a given environment.
"""
obs = env.reset()
num_obs = 0
rollouts = {'observations': [obs], 'actions': ['0']}
while num_obs != num_rollouts:
if render:
env.render()
num_obs += 1
action = env.action_space.sample()
obs, reward, done, _ = env.step(action)
rollouts['observations'].append(obs)
rollouts['actions'].append(action)
if done:
obs = env.reset()
rollouts['observations'].append(obs)
rollouts['actions'].append('0')
return rollouts
def save_rollouts(rollouts):
torch.save(rollouts, 'rollouts.data')
def load_rollouts(fname):
return torch.load(fname)
+101 -3
View File
@@ -1,12 +1,110 @@
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
from torch.autograd import Variable
import torch.nn as nn import torch.nn as nn
def make_conv_relu(inpt_kernel, output_kernel, kernel_size=4):
return nn.Sequential(
nn.Conv2d(in_channels=inpt_kernel, out_channels=output_kernel, kernel_size=kernel_size, stride=2),
nn.ReLU(inplace=True)
)
def make_deconv_relu(inpt_kernel, output_kernel, kernel_size, use_activation=True):
return nn.Sequential(
nn.ConvTranspose2d(in_channels=inpt_kernel, out_channels=output_kernel, kernel_size=kernel_size, stride=2),
nn.ReLU(inplace=True)
) if use_activation \
else nn.ConvTranspose2d(in_channels=inpt_kernel, out_channels=output_kernel, kernel_size=kernel_size, stride=2)
# Reconstruction + KL divergence losses summed over all elements and batch
# https://github.com/pytorch/examples/blob/master/vae/main.py
def loss_function(recon_x, x, mu, logvar):
n, c, h, w = recon_x.size()
recon_x = recon_x.view(n, -1)
x = x.view(n, -1)
# L2 distance
l2_dist = torch.mean(torch.sqrt(torch.sum(torch.pow(recon_x - x, 2), 1)))
# see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return l2_dist + KLD
class VAE(nn.Module): class VAE(nn.Module):
def __init__(self): def __init__(self, latent_vector_dim=32):
# TODO: Variational Auto Encoder
super(VAE, self).__init__() super(VAE, self).__init__()
# encoder part
self.conv1 = make_conv_relu(3, 32)
self.conv2 = make_conv_relu(32, 64)
self.conv3 = make_conv_relu(64, 128)
self.conv4 = make_conv_relu(128, 256)
self.mu = nn.Linear(1024, latent_vector_dim)
self.logvar = nn.Linear(1024, latent_vector_dim)
self.z = nn.Linear(latent_vector_dim, 1024)
# decoder part
self.deconv1 = make_deconv_relu(1024, 128, 5)
self.deconv2 = make_deconv_relu(128, 64, 5)
self.deconv3 = make_deconv_relu(64, 32, 6)
self.deconv4 = make_deconv_relu(32, 3, 6)
self.sigmoid = nn.Sigmoid()
def encode(self, x):
"""
Returns mean and log variance
"""
x = self.conv4(self.conv3(self.conv2(self.conv1(x))))
x = x.view(x.size()[0], -1)
return self.mu(x), self.logvar(x)
def sample(self, mu, logvar):
if self.training:
std = logvar.exp()
std = std * Variable(std.data.new(std.size()).normal_())
return mu + std
else:
return mu
def decode(self, z):
z = self.z(z)
n, d = z.size()
z = z.view(n, d, 1, 1)
reconstruction = self.deconv4(self.deconv3(self.deconv2(self.deconv1(z))))
reconstruction = self.sigmoid(reconstruction)
return reconstruction
def forward(self, x): def forward(self, x):
pass """
Returns reconstructed image, mean, and log variance
"""
mu, logvar = self.encode(x)
z = self.sample(mu, logvar)
x = self.decode(z)
return x, mu, logvar
# if __name__ == '__main__':
# import numpy as np
# import cv2
#
# img = np.random.randn(64, 64, 3)
# gpu_img = Variable(torch.from_numpy(img[np.newaxis].transpose(0, 3, 1, 2))).float().cuda()
#
# vae = VAE()
# vae.cuda()
# x, mu, logvar = vae.forward(gpu_img)
# print(x.size())
# print(loss_function(x, gpu_img, mu, logvar))
# x = x.data.cpu().numpy()[0].transpose(1, 2, 0)
#
# cv2.imshow('original', img)
# cv2.imshow('reconstructed', x)
# cv2.waitKey()