Files
seq2seq-time/seq2seq_time/models/lstm.py
T
2020-10-19 20:51:23 +08:00

40 lines
1.5 KiB
Python

import torch
from torch import nn
from torch.nn import functional as F
class LSTM(nn.Module):
    """Sequence model that predicts a Normal distribution over future targets.

    The past window (covariates ``past_x`` concatenated with targets ``past_y``)
    and the future window (covariates ``future_x`` concatenated with a constant
    placeholder standing in for the unknown targets) are joined along the time
    axis and run through one stacked LSTM (``batch_first=True``). Only the LSTM
    outputs at the future time steps are projected to a mean and a scale.

    Args:
        input_size: Number of covariate features per time step.
        output_size: Number of target features per time step.
        hidden_size: LSTM hidden size.
        lstm_layers: Number of stacked LSTM layers.
        lstm_dropout: Dropout applied between stacked LSTM layers.
        _min_std: Lower bound on the predicted standard deviation
            (effective when ``0 <= _min_std < 1``).
        nan_value: Constant written into the target channels of the future
            window, where the true targets are not yet known.
    """

    def __init__(self, input_size, output_size, hidden_size=64, lstm_layers=3, lstm_dropout=0, _min_std=0.05, nan_value=0):
        super().__init__()
        self._min_std = _min_std
        self.nan_value = nan_value
        # Each step sees covariates and targets side by side.
        self.lstm = nn.LSTM(
            input_size=input_size + output_size,
            hidden_size=hidden_size,
            batch_first=True,
            num_layers=lstm_layers,
            dropout=lstm_dropout,
        )
        self.mean = nn.Linear(hidden_size, output_size)
        self.std = nn.Linear(hidden_size, output_size)

    def forward(self, past_x, past_y, future_x, future_y=None):
        """Predict the distribution of the targets over the future window.

        Args:
            past_x: Past covariates, ``(batch, past_steps, input_size)``.
            past_y: Past targets, ``(batch, past_steps, output_size)``.
            future_x: Future covariates, ``(batch, future_steps, input_size)``.
            future_y: Unused here (no teacher forcing); presumably kept so the
                signature matches other models — verify against callers.

        Returns:
            ``(y_dist, extras)``: ``y_dist`` is a ``torch.distributions.Normal``
            whose mean/scale have shape ``(batch, future_steps, output_size)``;
            ``extras`` is an empty dict.
        """
        # Create the placeholder directly in the model's dtype and on its device,
        # so it matches the LSTM weights (hard-coding float32 breaks half/bfloat16
        # models) and avoids a CPU allocation followed by a device copy.
        param = next(self.parameters())
        future_y_fake = torch.full(
            (past_y.shape[0], future_x.shape[1], past_y.shape[-1]),
            self.nan_value,
            dtype=param.dtype,
            device=param.device,
        )

        context = torch.cat([past_x, past_y], dim=-1)
        target = torch.cat([future_x, future_y_fake], dim=-1)
        # Detach the whole input sequence: no gradient flows back into the inputs.
        x = torch.cat([context, target], dim=1).detach()

        outputs, _ = self.lstm(x)
        # [B, past_steps + future_steps, H] -> [B, future_steps, H]
        outputs = outputs[:, past_y.shape[1]:, :]

        mean = self.mean(outputs)
        raw_std = self.std(outputs)
        # softplus keeps the scale positive; the affine map keeps it >= _min_std.
        sigma = self._min_std + (1 - self._min_std) * F.softplus(raw_std)
        y_dist = torch.distributions.Normal(mean, sigma)
        return y_dist, {}