mirror of
https://github.com/wassname/world-models-pytorch.git
synced 2026-06-27 17:33:07 +08:00
finished VAE and VAE loss function, working on dataset
This commit is contained in:
+17
@@ -0,0 +1,17 @@
|
|||||||
|
from torch.utils.data.dataset import Dataset
|
||||||
|
from torch.utils.data.dataloader import DataLoader
|
||||||
|
from rollouts import load_rollouts
|
||||||
|
|
||||||
|
|
||||||
|
class VAEDataset(Dataset):
|
||||||
|
def __init__(self, rollout_path):
|
||||||
|
super(VAEDataset, self).__init__()
|
||||||
|
rollouts = load_rollouts(rollout_path)
|
||||||
|
self.images = rollouts['observations']
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.images)
|
||||||
|
|
||||||
@@ -2,6 +2,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from torch import normal, multinomial
|
from torch import normal, multinomial
|
||||||
|
from torch.autograd import Variable
|
||||||
|
|
||||||
|
|
||||||
class MDNRNN(nn.Module):
|
class MDNRNN(nn.Module):
|
||||||
@@ -86,17 +87,17 @@ class MDNRNN(nn.Module):
|
|||||||
return z, hidden_state
|
return z, hidden_state
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
# if __name__ == '__main__':
|
||||||
from torch.autograd import Variable
|
# from torch.autograd import Variable
|
||||||
z_dim, action_dim, hidden_size, n_mixture, temp = 32, 2, 256, 5, 0.0
|
# z_dim, action_dim, hidden_size, n_mixture, temp = 32, 2, 256, 5, 0.0
|
||||||
batch_size = 1
|
# batch_size = 1
|
||||||
seq_len = 1
|
# seq_len = 1
|
||||||
mdnrnn = MDNRNN(z_dim, action_dim, hidden_size, n_mixture, temp)
|
# mdnrnn = MDNRNN(z_dim, action_dim, hidden_size, n_mixture, temp)
|
||||||
mdnrnn.cuda()
|
# mdnrnn.cuda()
|
||||||
prev_z = Variable(torch.randn(batch_size, seq_len, z_dim)).cuda()
|
# prev_z = Variable(torch.randn(batch_size, seq_len, z_dim)).cuda()
|
||||||
action = Variable(torch.randn(batch_size, seq_len, action_dim)).cuda()
|
# action = Variable(torch.randn(batch_size, seq_len, action_dim)).cuda()
|
||||||
|
#
|
||||||
new_z, new_hidden_state = mdnrnn.sample(prev_z, action)
|
# new_z, new_hidden_state = mdnrnn.sample(prev_z, action)
|
||||||
print(new_z)
|
# print(new_z)
|
||||||
pi, mean, sigma, hidden_state = mdnrnn.forward(prev_z, action)
|
# pi, mean, sigma, hidden_state = mdnrnn.forward(prev_z, action)
|
||||||
print(sigma)
|
# print(sigma)
|
||||||
+35
@@ -0,0 +1,35 @@
|
|||||||
|
import gym
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def random_rollouts(env, num_rollouts, render=False):
|
||||||
|
"""
|
||||||
|
This function collects random rollouts from a given environment.
|
||||||
|
"""
|
||||||
|
obs = env.reset()
|
||||||
|
num_obs = 0
|
||||||
|
rollouts = {'observations': [obs], 'actions': ['0']}
|
||||||
|
while num_obs != num_rollouts:
|
||||||
|
if render:
|
||||||
|
env.render()
|
||||||
|
num_obs += 1
|
||||||
|
action = env.action_space.sample()
|
||||||
|
obs, reward, done, _ = env.step(action)
|
||||||
|
rollouts['observations'].append(obs)
|
||||||
|
rollouts['actions'].append(action)
|
||||||
|
|
||||||
|
if done:
|
||||||
|
obs = env.reset()
|
||||||
|
rollouts['observations'].append(obs)
|
||||||
|
rollouts['actions'].append('0')
|
||||||
|
return rollouts
|
||||||
|
|
||||||
|
|
||||||
|
def save_rollouts(rollouts):
|
||||||
|
torch.save(rollouts, 'rollouts.data')
|
||||||
|
|
||||||
|
|
||||||
|
def load_rollouts(fname):
|
||||||
|
return torch.load(fname)
|
||||||
|
|
||||||
|
|
||||||
@@ -1,12 +1,110 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
from torch.autograd import Variable
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def make_conv_relu(inpt_kernel, output_kernel, kernel_size=4):
|
||||||
|
return nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels=inpt_kernel, out_channels=output_kernel, kernel_size=kernel_size, stride=2),
|
||||||
|
nn.ReLU(inplace=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_deconv_relu(inpt_kernel, output_kernel, kernel_size, use_activation=True):
|
||||||
|
return nn.Sequential(
|
||||||
|
nn.ConvTranspose2d(in_channels=inpt_kernel, out_channels=output_kernel, kernel_size=kernel_size, stride=2),
|
||||||
|
nn.ReLU(inplace=True)
|
||||||
|
) if use_activation \
|
||||||
|
else nn.ConvTranspose2d(in_channels=inpt_kernel, out_channels=output_kernel, kernel_size=kernel_size, stride=2)
|
||||||
|
|
||||||
|
|
||||||
|
# Reconstruction + KL divergence losses summed over all elements and batch
|
||||||
|
# https://github.com/pytorch/examples/blob/master/vae/main.py
|
||||||
|
def loss_function(recon_x, x, mu, logvar):
|
||||||
|
n, c, h, w = recon_x.size()
|
||||||
|
recon_x = recon_x.view(n, -1)
|
||||||
|
x = x.view(n, -1)
|
||||||
|
# L2 distance
|
||||||
|
l2_dist = torch.mean(torch.sqrt(torch.sum(torch.pow(recon_x - x, 2), 1)))
|
||||||
|
# see Appendix B from VAE paper:
|
||||||
|
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
|
||||||
|
# https://arxiv.org/abs/1312.6114
|
||||||
|
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
||||||
|
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||||
|
return l2_dist + KLD
|
||||||
|
|
||||||
|
|
||||||
class VAE(nn.Module):
|
class VAE(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self, latent_vector_dim=32):
|
||||||
# TODO: Variational Auto Encoder
|
|
||||||
super(VAE, self).__init__()
|
super(VAE, self).__init__()
|
||||||
|
# encoder part
|
||||||
|
self.conv1 = make_conv_relu(3, 32)
|
||||||
|
self.conv2 = make_conv_relu(32, 64)
|
||||||
|
self.conv3 = make_conv_relu(64, 128)
|
||||||
|
self.conv4 = make_conv_relu(128, 256)
|
||||||
|
|
||||||
|
self.mu = nn.Linear(1024, latent_vector_dim)
|
||||||
|
self.logvar = nn.Linear(1024, latent_vector_dim)
|
||||||
|
|
||||||
|
self.z = nn.Linear(latent_vector_dim, 1024)
|
||||||
|
|
||||||
|
# decoder part
|
||||||
|
self.deconv1 = make_deconv_relu(1024, 128, 5)
|
||||||
|
self.deconv2 = make_deconv_relu(128, 64, 5)
|
||||||
|
self.deconv3 = make_deconv_relu(64, 32, 6)
|
||||||
|
self.deconv4 = make_deconv_relu(32, 3, 6)
|
||||||
|
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
"""
|
||||||
|
Returns mean and log variance
|
||||||
|
"""
|
||||||
|
x = self.conv4(self.conv3(self.conv2(self.conv1(x))))
|
||||||
|
x = x.view(x.size()[0], -1)
|
||||||
|
return self.mu(x), self.logvar(x)
|
||||||
|
|
||||||
|
def sample(self, mu, logvar):
|
||||||
|
if self.training:
|
||||||
|
std = logvar.exp()
|
||||||
|
std = std * Variable(std.data.new(std.size()).normal_())
|
||||||
|
return mu + std
|
||||||
|
else:
|
||||||
|
return mu
|
||||||
|
|
||||||
|
def decode(self, z):
|
||||||
|
z = self.z(z)
|
||||||
|
n, d = z.size()
|
||||||
|
z = z.view(n, d, 1, 1)
|
||||||
|
reconstruction = self.deconv4(self.deconv3(self.deconv2(self.deconv1(z))))
|
||||||
|
reconstruction = self.sigmoid(reconstruction)
|
||||||
|
return reconstruction
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
pass
|
"""
|
||||||
|
Returns reconstructed image, mean, and log variance
|
||||||
|
"""
|
||||||
|
mu, logvar = self.encode(x)
|
||||||
|
z = self.sample(mu, logvar)
|
||||||
|
x = self.decode(z)
|
||||||
|
return x, mu, logvar
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == '__main__':
|
||||||
|
# import numpy as np
|
||||||
|
# import cv2
|
||||||
|
#
|
||||||
|
# img = np.random.randn(64, 64, 3)
|
||||||
|
# gpu_img = Variable(torch.from_numpy(img[np.newaxis].transpose(0, 3, 1, 2))).float().cuda()
|
||||||
|
#
|
||||||
|
# vae = VAE()
|
||||||
|
# vae.cuda()
|
||||||
|
# x, mu, logvar = vae.forward(gpu_img)
|
||||||
|
# print(x.size())
|
||||||
|
# print(loss_function(x, gpu_img, mu, logvar))
|
||||||
|
# x = x.data.cpu().numpy()[0].transpose(1, 2, 0)
|
||||||
|
#
|
||||||
|
# cv2.imshow('original', img)
|
||||||
|
# cv2.imshow('reconstructed', x)
|
||||||
|
# cv2.waitKey()
|
||||||
Reference in New Issue
Block a user