mirror of
https://github.com/wassname/world-models-pytorch.git
synced 2026-06-27 16:18:00 +08:00
finished VAE and VAE loss function, working on dataset
This commit is contained in:
+17
@@ -0,0 +1,17 @@
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from rollouts import load_rollouts
|
||||
|
||||
|
||||
class VAEDataset(Dataset):
|
||||
def __init__(self, rollout_path):
|
||||
super(VAEDataset, self).__init__()
|
||||
rollouts = load_rollouts(rollout_path)
|
||||
self.images = rollouts['observations']
|
||||
|
||||
def __getitem__(self, index):
|
||||
pass
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
@@ -2,6 +2,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torch import normal, multinomial
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
class MDNRNN(nn.Module):
|
||||
@@ -86,17 +87,17 @@ class MDNRNN(nn.Module):
|
||||
return z, hidden_state
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from torch.autograd import Variable
|
||||
z_dim, action_dim, hidden_size, n_mixture, temp = 32, 2, 256, 5, 0.0
|
||||
batch_size = 1
|
||||
seq_len = 1
|
||||
mdnrnn = MDNRNN(z_dim, action_dim, hidden_size, n_mixture, temp)
|
||||
mdnrnn.cuda()
|
||||
prev_z = Variable(torch.randn(batch_size, seq_len, z_dim)).cuda()
|
||||
action = Variable(torch.randn(batch_size, seq_len, action_dim)).cuda()
|
||||
|
||||
new_z, new_hidden_state = mdnrnn.sample(prev_z, action)
|
||||
print(new_z)
|
||||
pi, mean, sigma, hidden_state = mdnrnn.forward(prev_z, action)
|
||||
print(sigma)
|
||||
# if __name__ == '__main__':
|
||||
# from torch.autograd import Variable
|
||||
# z_dim, action_dim, hidden_size, n_mixture, temp = 32, 2, 256, 5, 0.0
|
||||
# batch_size = 1
|
||||
# seq_len = 1
|
||||
# mdnrnn = MDNRNN(z_dim, action_dim, hidden_size, n_mixture, temp)
|
||||
# mdnrnn.cuda()
|
||||
# prev_z = Variable(torch.randn(batch_size, seq_len, z_dim)).cuda()
|
||||
# action = Variable(torch.randn(batch_size, seq_len, action_dim)).cuda()
|
||||
#
|
||||
# new_z, new_hidden_state = mdnrnn.sample(prev_z, action)
|
||||
# print(new_z)
|
||||
# pi, mean, sigma, hidden_state = mdnrnn.forward(prev_z, action)
|
||||
# print(sigma)
|
||||
+35
@@ -0,0 +1,35 @@
|
||||
import gym
|
||||
import torch
|
||||
|
||||
|
||||
def random_rollouts(env, num_rollouts, render=False):
|
||||
"""
|
||||
This function collects random rollouts from a given environment.
|
||||
"""
|
||||
obs = env.reset()
|
||||
num_obs = 0
|
||||
rollouts = {'observations': [obs], 'actions': ['0']}
|
||||
while num_obs != num_rollouts:
|
||||
if render:
|
||||
env.render()
|
||||
num_obs += 1
|
||||
action = env.action_space.sample()
|
||||
obs, reward, done, _ = env.step(action)
|
||||
rollouts['observations'].append(obs)
|
||||
rollouts['actions'].append(action)
|
||||
|
||||
if done:
|
||||
obs = env.reset()
|
||||
rollouts['observations'].append(obs)
|
||||
rollouts['actions'].append('0')
|
||||
return rollouts
|
||||
|
||||
|
||||
def save_rollouts(rollouts):
|
||||
torch.save(rollouts, 'rollouts.data')
|
||||
|
||||
|
||||
def load_rollouts(fname):
|
||||
return torch.load(fname)
|
||||
|
||||
|
||||
@@ -1,12 +1,110 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from torch.autograd import Variable
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def make_conv_relu(inpt_kernel, output_kernel, kernel_size=4):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(in_channels=inpt_kernel, out_channels=output_kernel, kernel_size=kernel_size, stride=2),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
|
||||
def make_deconv_relu(inpt_kernel, output_kernel, kernel_size, use_activation=True):
|
||||
return nn.Sequential(
|
||||
nn.ConvTranspose2d(in_channels=inpt_kernel, out_channels=output_kernel, kernel_size=kernel_size, stride=2),
|
||||
nn.ReLU(inplace=True)
|
||||
) if use_activation \
|
||||
else nn.ConvTranspose2d(in_channels=inpt_kernel, out_channels=output_kernel, kernel_size=kernel_size, stride=2)
|
||||
|
||||
|
||||
# Reconstruction + KL divergence losses summed over all elements and batch
|
||||
# https://github.com/pytorch/examples/blob/master/vae/main.py
|
||||
def loss_function(recon_x, x, mu, logvar):
|
||||
n, c, h, w = recon_x.size()
|
||||
recon_x = recon_x.view(n, -1)
|
||||
x = x.view(n, -1)
|
||||
# L2 distance
|
||||
l2_dist = torch.mean(torch.sqrt(torch.sum(torch.pow(recon_x - x, 2), 1)))
|
||||
# see Appendix B from VAE paper:
|
||||
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
|
||||
# https://arxiv.org/abs/1312.6114
|
||||
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
||||
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||
return l2_dist + KLD
|
||||
|
||||
|
||||
class VAE(nn.Module):
|
||||
def __init__(self):
|
||||
# TODO: Variational Auto Encoder
|
||||
def __init__(self, latent_vector_dim=32):
|
||||
super(VAE, self).__init__()
|
||||
# encoder part
|
||||
self.conv1 = make_conv_relu(3, 32)
|
||||
self.conv2 = make_conv_relu(32, 64)
|
||||
self.conv3 = make_conv_relu(64, 128)
|
||||
self.conv4 = make_conv_relu(128, 256)
|
||||
|
||||
self.mu = nn.Linear(1024, latent_vector_dim)
|
||||
self.logvar = nn.Linear(1024, latent_vector_dim)
|
||||
|
||||
self.z = nn.Linear(latent_vector_dim, 1024)
|
||||
|
||||
# decoder part
|
||||
self.deconv1 = make_deconv_relu(1024, 128, 5)
|
||||
self.deconv2 = make_deconv_relu(128, 64, 5)
|
||||
self.deconv3 = make_deconv_relu(64, 32, 6)
|
||||
self.deconv4 = make_deconv_relu(32, 3, 6)
|
||||
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def encode(self, x):
|
||||
"""
|
||||
Returns mean and log variance
|
||||
"""
|
||||
x = self.conv4(self.conv3(self.conv2(self.conv1(x))))
|
||||
x = x.view(x.size()[0], -1)
|
||||
return self.mu(x), self.logvar(x)
|
||||
|
||||
def sample(self, mu, logvar):
|
||||
if self.training:
|
||||
std = logvar.exp()
|
||||
std = std * Variable(std.data.new(std.size()).normal_())
|
||||
return mu + std
|
||||
else:
|
||||
return mu
|
||||
|
||||
def decode(self, z):
|
||||
z = self.z(z)
|
||||
n, d = z.size()
|
||||
z = z.view(n, d, 1, 1)
|
||||
reconstruction = self.deconv4(self.deconv3(self.deconv2(self.deconv1(z))))
|
||||
reconstruction = self.sigmoid(reconstruction)
|
||||
return reconstruction
|
||||
|
||||
def forward(self, x):
|
||||
pass
|
||||
"""
|
||||
Returns reconstructed image, mean, and log variance
|
||||
"""
|
||||
mu, logvar = self.encode(x)
|
||||
z = self.sample(mu, logvar)
|
||||
x = self.decode(z)
|
||||
return x, mu, logvar
|
||||
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# import numpy as np
|
||||
# import cv2
|
||||
#
|
||||
# img = np.random.randn(64, 64, 3)
|
||||
# gpu_img = Variable(torch.from_numpy(img[np.newaxis].transpose(0, 3, 1, 2))).float().cuda()
|
||||
#
|
||||
# vae = VAE()
|
||||
# vae.cuda()
|
||||
# x, mu, logvar = vae.forward(gpu_img)
|
||||
# print(x.size())
|
||||
# print(loss_function(x, gpu_img, mu, logvar))
|
||||
# x = x.data.cpu().numpy()[0].transpose(1, 2, 0)
|
||||
#
|
||||
# cv2.imshow('original', img)
|
||||
# cv2.imshow('reconstructed', x)
|
||||
# cv2.waitKey()
|
||||
Reference in New Issue
Block a user