diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5cc6066 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +# Created by .ignore support plugin (hsz.mobi) +.gitignore +.idea/ diff --git a/controller.py b/controller.py new file mode 100644 index 0000000..e69de29 diff --git a/model.py b/model.py new file mode 100644 index 0000000..e69de29 diff --git a/rnn.py b/rnn.py new file mode 100644 index 0000000..8c64f62 --- /dev/null +++ b/rnn.py @@ -0,0 +1,53 @@ +import torch.nn as nn +from torch import normal, multinomial + + + +class MDNRNN(nn.Module): + def __init__(self, z_dim, action_dim, hidden_size, n_mixture, temperature): + """ + :param z_dim: the dimension of VAE latent variable + :param hidden_size: hidden size of RNN + :param n_mixture: number of Gaussian Mixture Models to be used + :param temperature: controls the randomness of the model + + MDNRNN stands for Mixture Density Network - RNN. + The output of this model is [mean, sigma^2, K], + where mean and sigma^2 have z_dim * n_mixture elements and + K has n_mixture elements. + """ + super(MDNRNN, self).__init__() + # define rnn + self.inpt_size = z_dim + action_dim + self.rnn = nn.LSTM(input_size=self.inpt_size, hidden_size=hidden_size, batch_first=True) + + # define MDN as fully connected layer + self.mdn = nn.Linear(hidden_size, n_mixture * z_dim * 2 + n_mixture) + self.tau = temperature + + def forward(self, inpt, hidden_state, action): + """ + :param inpt: a tensor of size (batch_size, seq_len, D) + :param hidden_state: two tensors of size (1, batch_size, hidden_size) + :param action: a tensor of (batch_size, seq_len, D*) + :return: pi, mean, sigma, hidden_state + """ + # TODO: Add forward function + return + + def sample(self, inpt, hidden_state, action): + """ + parameters same as forward + :return: + """ + # forward and get pi, mean. sigma, hidden_state + pi, mean, sigma, hidden_state = self.forward(inpt, hidden_state, action) + + # randomly draw a mixture model + k = multinomial(pi, 1) + selected_mean = mean[..., k] + selected_sigma = sigma[..., k] + + # sample from normal dist + z = normal(selected_mean, selected_sigma) + return z, hidden_state \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..e69de29 diff --git a/vae.py b/vae.py new file mode 100644 index 0000000..e69de29