diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000..a686755
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,33 @@
+from torch.utils.data.dataset import Dataset
+from torch.utils.data.dataloader import DataLoader
+from rollouts import load_rollouts
+
+
+class VAEDataset(Dataset):
+    """
+    Dataset over the observations of a saved rollout file.
+    """
+
+    def __init__(self, rollout_path):
+        """
+        Loads the rollouts stored at `rollout_path` and keeps their
+        'observations' entry as the dataset items.
+        """
+        super(VAEDataset, self).__init__()
+        rollouts = load_rollouts(rollout_path)
+        self.images = rollouts['observations']
+
+    def __getitem__(self, index):
+        """
+        Returns the observation stored at `index`.
+        """
+        # NOTE(review): the observation is returned exactly as stored; any
+        # conversion to the layout the VAE expects still has to happen elsewhere.
+        return self.images[index]
+
+    def __len__(self):
+        """
+        Returns the number of stored observations.
+        """
+        return len(self.images)
+
diff --git a/rnn.py b/rnn.py
index 946ce0e..eeac48a 100644
--- a/rnn.py
+++ b/rnn.py
@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 from torch.nn import functional as F
 from torch import normal, multinomial
+from torch.autograd import Variable
 
 
 class MDNRNN(nn.Module):
@@ -86,17 +87,17 @@ class MDNRNN(nn.Module):
         return z, hidden_state
 
 
-if __name__ == '__main__':
-    from torch.autograd import Variable
-    z_dim, action_dim, hidden_size, n_mixture, temp = 32, 2, 256, 5, 0.0
-    batch_size = 1
-    seq_len = 1
-    mdnrnn = MDNRNN(z_dim, action_dim, hidden_size, n_mixture, temp)
-    mdnrnn.cuda()
-    prev_z = Variable(torch.randn(batch_size, seq_len, z_dim)).cuda()
-    action = Variable(torch.randn(batch_size, seq_len, action_dim)).cuda()
-
-    new_z, new_hidden_state = mdnrnn.sample(prev_z, action)
-    print(new_z)
-    pi, mean, sigma, hidden_state = mdnrnn.forward(prev_z, action)
-    print(sigma)
\ No newline at end of file
+# if __name__ == '__main__':
+#     from torch.autograd import Variable
+#     z_dim, action_dim, hidden_size, n_mixture, temp = 32, 2, 256, 5, 0.0
+#     batch_size = 1
+#     seq_len = 1
+#     mdnrnn = MDNRNN(z_dim, action_dim, hidden_size, n_mixture, temp)
+#     mdnrnn.cuda()
+#     prev_z = Variable(torch.randn(batch_size, seq_len, z_dim)).cuda()
+#     action = Variable(torch.randn(batch_size, seq_len, action_dim)).cuda()
+#
+#     new_z, new_hidden_state = mdnrnn.sample(prev_z, action)
+#     print(new_z)
+#     pi, mean, sigma, hidden_state = mdnrnn.forward(prev_z, action)
+#     print(sigma)
\ No newline at end of file
diff --git a/rollouts.py b/rollouts.py
new file mode 100644
index 0000000..04a45c7
--- /dev/null
+++ b/rollouts.py
@@ -0,0 +1,54 @@
+import gym
+import torch
+
+
+def random_rollouts(env, num_rollouts, render=False):
+    """
+    Collects transitions from `env` by taking `num_rollouts` random steps.
+
+    Actions are drawn with `env.action_space.sample()`. Whenever a step
+    reports `done`, the environment is reset and the fresh observation is
+    recorded with the placeholder action '0' (the initial observation gets
+    the same placeholder), so both lists always have equal length.
+
+    Returns a dict with the keys 'observations' and 'actions'.
+    """
+    obs = env.reset()
+    num_obs = 0
+    rollouts = {'observations': [obs], 'actions': ['0']}
+    # `<` rather than `!=`: the loop also terminates when `num_rollouts`
+    # is negative or not a whole number.
+    while num_obs < num_rollouts:
+        if render:
+            env.render()
+        num_obs += 1
+        action = env.action_space.sample()
+        obs, reward, done, _ = env.step(action)
+        rollouts['observations'].append(obs)
+        rollouts['actions'].append(action)
+
+        if done:
+            obs = env.reset()
+            rollouts['observations'].append(obs)
+            rollouts['actions'].append('0')
+    return rollouts
+
+
+def save_rollouts(rollouts, fname='rollouts.data'):
+    """
+    Saves `rollouts` to `fname` with `torch.save`.
+
+    `fname` defaults to 'rollouts.data'.
+    """
+    torch.save(rollouts, fname)
+
+
+def load_rollouts(fname):
+    """
+    Loads and returns the object stored in `fname` with `torch.load`.
+    """
+    # NOTE(review): torch.load unpickles the file, which can run arbitrary
+    # code; only load rollout files from trusted sources.
+    return torch.load(fname)
+
+
diff --git a/vae.py b/vae.py
index 759637d..4a3cbdc 100644
--- a/vae.py
+++ b/vae.py
@@ -1,12 +1,142 @@
 import torch
 from torch.nn import functional as F
+from torch.autograd import Variable
 import torch.nn as nn
 
 
+def make_conv_relu(inpt_kernel, output_kernel, kernel_size=4):
+    """
+    Returns a stride-2 `nn.Conv2d` followed by an in-place ReLU.
+    """
+    return nn.Sequential(
+        nn.Conv2d(in_channels=inpt_kernel, out_channels=output_kernel, kernel_size=kernel_size, stride=2),
+        nn.ReLU(inplace=True)
+    )
+
+
+def make_deconv_relu(inpt_kernel, output_kernel, kernel_size, use_activation=True):
+    """
+    Returns a stride-2 `nn.ConvTranspose2d`, wrapped in an `nn.Sequential`
+    with an in-place ReLU unless `use_activation` is False.
+    """
+    deconv = nn.ConvTranspose2d(in_channels=inpt_kernel, out_channels=output_kernel,
+                                kernel_size=kernel_size, stride=2)
+    if not use_activation:
+        return deconv
+    return nn.Sequential(deconv, nn.ReLU(inplace=True))
+
+
+# Reconstruction + KL divergence losses
+# https://github.com/pytorch/examples/blob/master/vae/main.py
+def loss_function(recon_x, x, mu, logvar):
+    """
+    Returns the VAE loss: the per-sample L2 reconstruction distance
+    averaged over the batch, plus the KL divergence summed over every
+    latent element and the batch.
+    """
+    n = recon_x.size(0)
+    recon_x = recon_x.view(n, -1)
+    x = x.view(n, -1)
+    # L2 distance
+    l2_dist = torch.mean(torch.sqrt(torch.sum(torch.pow(recon_x - x, 2), 1)))
+    # see Appendix B from VAE paper:
+    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
+    # https://arxiv.org/abs/1312.6114
+    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
+    # NOTE(review): this term is summed over the batch while the L2 term is
+    # a batch mean, so its relative weight grows with the batch size.
+    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
+    return l2_dist + KLD
+
+
 class VAE(nn.Module):
-    def __init__(self):
-        # TODO: Variational Auto Encoder
+    """
+    Convolutional variational auto-encoder.
+
+    The 1024-feature linear layers match the encoder output for
+    3x64x64 inputs (spatial size 64 -> 31 -> 14 -> 6 -> 2, with
+    256 channels), and the decoder maps a 1024x1x1 tensor back to
+    3x64x64 (spatial size 1 -> 5 -> 13 -> 30 -> 64).
+    """
+
+    def __init__(self, latent_vector_dim=32):
         super(VAE, self).__init__()
+        # encoder part
+        self.conv1 = make_conv_relu(3, 32)
+        self.conv2 = make_conv_relu(32, 64)
+        self.conv3 = make_conv_relu(64, 128)
+        self.conv4 = make_conv_relu(128, 256)
+
+        self.mu = nn.Linear(1024, latent_vector_dim)
+        self.logvar = nn.Linear(1024, latent_vector_dim)
+
+        self.z = nn.Linear(latent_vector_dim, 1024)
+
+        # decoder part
+        self.deconv1 = make_deconv_relu(1024, 128, 5)
+        self.deconv2 = make_deconv_relu(128, 64, 5)
+        self.deconv3 = make_deconv_relu(64, 32, 6)
+        # No ReLU on the last layer: a ReLU in front of the sigmoid would
+        # confine every output pixel to [0.5, 1).
+        self.deconv4 = make_deconv_relu(32, 3, 6, use_activation=False)
+
+        self.sigmoid = nn.Sigmoid()
+
+    def encode(self, x):
+        """
+        Returns mean and log variance
+        """
+        x = self.conv4(self.conv3(self.conv2(self.conv1(x))))
+        x = x.view(x.size()[0], -1)
+        return self.mu(x), self.logvar(x)
+
+    def sample(self, mu, logvar):
+        """
+        Returns a latent vector: a reparameterized draw from
+        N(mu, exp(logvar)) in training mode, `mu` itself otherwise.
+        """
+        if not self.training:
+            return mu
+        # logvar is log(sigma^2), so sigma = exp(0.5 * logvar)
+        std = logvar.mul(0.5).exp()
+        eps = Variable(std.data.new(std.size()).normal_())
+        return mu + eps * std
+
+    def decode(self, z):
+        """
+        Returns the image reconstructed from latent vector `z`
+        """
+        z = self.z(z)
+        n, d = z.size()
+        z = z.view(n, d, 1, 1)
+        reconstruction = self.deconv4(self.deconv3(self.deconv2(self.deconv1(z))))
+        reconstruction = self.sigmoid(reconstruction)
+        return reconstruction
 
     def forward(self, x):
-        pass
+        """
+        Returns reconstructed image, mean, and log variance
+        """
+        mu, logvar = self.encode(x)
+        z = self.sample(mu, logvar)
+        x = self.decode(z)
+        return x, mu, logvar
+
+
+# if __name__ == '__main__':
+#     import numpy as np
+#     import cv2
+#
+#     img = np.random.randn(64, 64, 3)
+#     gpu_img = Variable(torch.from_numpy(img[np.newaxis].transpose(0, 3, 1, 2))).float().cuda()
+#
+#     vae = VAE()
+#     vae.cuda()
+#     x, mu, logvar = vae.forward(gpu_img)
+#     print(x.size())
+#     print(loss_function(x, gpu_img, mu, logvar))
+#     x = x.data.cpu().numpy()[0].transpose(1, 2, 0)
+#
+#     cv2.imshow('original', img)
+#     cv2.imshow('reconstructed', x)
+#     cv2.waitKey()
\ No newline at end of file