mirror of
https://github.com/wassname/world-models-pytorch.git
synced 2026-06-27 16:18:00 +08:00
53 lines
1.9 KiB
Python
53 lines
1.9 KiB
Python
import torch
import torch.nn as nn
|
|
from torch import normal, multinomial
|
|
|
|
|
|
|
|
class MDNRNN(nn.Module):
    def __init__(self, z_dim, action_dim, hidden_size, n_mixture, temperature):
        """
        :param z_dim: the dimension of VAE latent variable
        :param action_dim: the dimension of the action vector that is
            concatenated to z at every time step
        :param hidden_size: hidden size of RNN
        :param n_mixture: number of Gaussian components in the mixture
        :param temperature: sampling temperature used by ``sample`` (must be
            > 0); values above 1 flatten the mixture weights and widen each
            Gaussian, values below 1 sharpen them, 1.0 leaves them unchanged

        MDNRNN stands for Mixture Density Network - RNN.

        An LSTM reads the concatenated (z, action) sequence and a linear MDN
        head maps every hidden state to the parameters of a Gaussian mixture
        over a z_dim vector: n_mixture mixture weights (pi), plus a mean and
        a standard deviation of size z_dim for each of the n_mixture
        components.
        """
        super(MDNRNN, self).__init__()

        self.z_dim = z_dim
        self.n_mixture = n_mixture

        # define rnn
        self.inpt_size = z_dim + action_dim
        self.rnn = nn.LSTM(input_size=self.inpt_size, hidden_size=hidden_size, batch_first=True)

        # define MDN as fully connected layer. Output layout per step:
        # [n_mixture pi logits | n_mixture * z_dim means | n_mixture * z_dim log-stds]
        self.mdn = nn.Linear(hidden_size, n_mixture * z_dim * 2 + n_mixture)
        self.tau = temperature

    def forward(self, inpt, hidden_state, action):
        """
        :param inpt: a tensor of size (batch_size, seq_len, z_dim)
        :param hidden_state: tuple (h, c) of two tensors of size
            (1, batch_size, hidden_size), or None to start from zeros
        :param action: a tensor of size (batch_size, seq_len, action_dim)
        :return: pi, mean, sigma, hidden_state where
            pi is (batch_size, seq_len, n_mixture) mixture weights summing to 1,
            mean is (batch_size, seq_len, n_mixture, z_dim),
            sigma is (batch_size, seq_len, n_mixture, z_dim) positive standard
            deviations, and hidden_state is the LSTM's new (h, c) tuple.

        Temperature is NOT applied here, so the returned distribution is the
        model's raw prediction (suitable for computing a training likelihood).
        """
        x = torch.cat([inpt, action], dim=-1)
        out, hidden_state = self.rnn(x, hidden_state)
        params = self.mdn(out)

        n_k, n_z = self.n_mixture, self.z_dim
        logit_pi, mean, log_sigma = params.split([n_k, n_k * n_z, n_k * n_z], dim=-1)

        pi = logit_pi.softmax(dim=-1)
        # One row per mixture component so a component can be picked with gather.
        lead = mean.shape[:-1]
        mean = mean.reshape(*lead, n_k, n_z)
        # exp keeps the standard deviation strictly positive.
        sigma = log_sigma.exp().reshape(*lead, n_k, n_z)
        return pi, mean, sigma, hidden_state

    def sample(self, inpt, hidden_state, action):
        """
        Draw one z per (batch, time) step from the predicted mixture.

        parameters same as forward
        :return: z, hidden_state where z is (batch_size, seq_len, z_dim) and
            hidden_state is the LSTM's new (h, c) tuple.

        The temperature ``self.tau`` is applied before drawing. The mixture
        logits are divided by tau and each standard deviation is scaled by
        sqrt(tau).
        """
        # forward and get pi, mean, sigma, hidden_state
        pi, mean, sigma, hidden_state = self.forward(inpt, hidden_state, action)

        tau = self.tau
        if tau != 1:
            # softmax(log(pi) / tau) == softmax(logits / tau)
            pi = (pi.log() / tau).softmax(dim=-1)
            sigma = sigma * tau ** 0.5

        # randomly draw a mixture component per step; multinomial only accepts
        # 1-D/2-D input, so flatten every leading dimension into rows.
        lead = pi.shape[:-1]
        n_k, n_z = pi.shape[-1], mean.shape[-1]
        k = multinomial(pi.reshape(-1, n_k), 1)
        # Expand the chosen index across z so gather picks a whole component row.
        k = k.view(*lead, 1, 1).expand(*lead, 1, n_z)
        selected_mean = mean.gather(-2, k).squeeze(-2)
        selected_sigma = sigma.gather(-2, k).squeeze(-2)

        # sample from normal dist
        z = normal(selected_mean, selected_sigma)
        return z, hidden_state