finished MDN-RNN and need to implement VAE

This commit is contained in:
junhong
2018-04-04 22:43:48 -04:00
parent 076e45cac0
commit 21d1d5dd8d
2 changed files with 72 additions and 11 deletions
+60 -11
View File
@@ -1,8 +1,9 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import normal, multinomial
class MDNRNN(nn.Module):
def __init__(self, z_dim, action_dim, hidden_size, n_mixture, temperature):
"""
@@ -19,35 +20,83 @@ class MDNRNN(nn.Module):
super(MDNRNN, self).__init__()
# define rnn
self.inpt_size = z_dim + action_dim
self.hidden_size = hidden_size
self.n_mixture = n_mixture
self.z_dim = z_dim
self.rnn = nn.LSTM(input_size=self.inpt_size, hidden_size=hidden_size, batch_first=True)
# define MDN as fully connected layer
self.mdn = nn.Linear(hidden_size, n_mixture * z_dim * 2 + n_mixture)
self.tau = temperature
def forward(self, inpt, hidden_state, action):
def forward(self, inpt, action, hidden_state=None):
"""
:param inpt: a tensor of size (batch_size, seq_len, D)
:param hidden_state: two tensors of size (1, batch_size, hidden_size)
:param action: a tensor of (batch_size, seq_len, D*)
:param action: a tensor of (batch_size, seq_len, action_dim)
:return: pi, mean, sigma, hidden_state
"""
# TODO: Add forward function
return
batch_size, seq_len, _ = inpt.size()
if hidden_state is None:
# use new so that we do not need to know the tensor type explicitly.
hidden_state = (Variable(inpt.data.new(1, batch_size, self.hidden_size)),
Variable(inpt.data.new(1, batch_size, self.hidden_size)))
def sample(self, inpt, hidden_state, action):
# concatenate input and action, maybe we can use an extra fc layer to project action to a space same
# as inpt?
concat = torch.cat((inpt, action), dim=-1)
output, hidden_state = self.rnn(concat, hidden_state)
output = output.contiguous()
output = output.view(-1, self.hidden_size)
# N, seq_len, n_mixture * z_dim * 2 + n_mixture
mixture = self.mdn(output)
mixture = mixture.view(batch_size, seq_len, -1)
# N * seq_len, n_mixture * z_dim
mean = mixture[..., :self.n_mixture * self.z_dim]
sigma = mixture[..., self.n_mixture * self.z_dim: self.n_mixture * self.z_dim*2]
sigma = torch.exp(sigma)
# N * seq_len, n_mixture
pi = mixture[..., -self.n_mixture:]
pi = F.softmax(pi, -1)
# add temperature
if self.tau > 0:
pi /= self.tau
sigma *= self.tau ** 0.5
return pi, mean, sigma, hidden_state
def sample(self, inpt, action, hidden_state=None):
"""
Sample from a mixture of Gaussians. This function is only in testing, so batch_size=seq_len=1 for now.
parameters same as forward
:return:
"""
# forward and get pi, mean. sigma, hidden_state
pi, mean, sigma, hidden_state = self.forward(inpt, hidden_state, action)
pi, mean, sigma, hidden_state = self.forward(inpt, action, hidden_state)
batch_size, seq_len, _ = inpt.size()
pi, mean, sigma = pi.contiguous().view(-1), mean.contiguous().view(-1), sigma.contiguous().view(-1)
# randomly draw a mixture model
k = multinomial(pi, 1)
selected_mean = mean[..., k]
selected_sigma = sigma[..., k]
selected_mean = mean[int(self.z_dim * k): int(self.z_dim * (k+1))]
selected_sigma = sigma[int(self.z_dim * k): int(self.z_dim * (k+1))]
# sample from normal dist
z = normal(selected_mean, selected_sigma)
return z, hidden_state
return z, hidden_state
if __name__ == '__main__':
from torch.autograd import Variable
z_dim, action_dim, hidden_size, n_mixture, temp = 32, 2, 256, 5, 0.0
batch_size = 1
seq_len = 1
mdnrnn = MDNRNN(z_dim, action_dim, hidden_size, n_mixture, temp)
mdnrnn.cuda()
prev_z = Variable(torch.randn(batch_size, seq_len, z_dim)).cuda()
action = Variable(torch.randn(batch_size, seq_len, action_dim)).cuda()
new_z, new_hidden_state = mdnrnn.sample(prev_z, action)
print(new_z)
pi, mean, sigma, hidden_state = mdnrnn.forward(prev_z, action)
print(sigma)
+12
View File
@@ -0,0 +1,12 @@
import torch
from torch.nn import functional as F
import torch.nn as nn
class VAE(nn.Module):
def __init__(self):
# TODO: Variational Auto Encoder
super(VAE, self).__init__()
def forward(self, x):
pass