diff --git a/seq2seq_time/models/tcn.py b/seq2seq_time/models/tcn.py index e843c52..baceed3 100644 --- a/seq2seq_time/models/tcn.py +++ b/seq2seq_time/models/tcn.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from torch.nn.utils import weight_norm - +from torch.nn import functional as F class Chomp1d(nn.Module): def __init__(self, chomp_size): @@ -114,7 +114,6 @@ class TemporalConvNet(nn.Module): self, num_inputs, num_channels, - num_embeddings=0, kernel_size=2, dropout=0.2, embedding_dim=2, @@ -144,3 +143,47 @@ class TemporalConvNet(nn.Module): for l in self.network: out = l(out) return out + + +class TCNSeq2Seq(nn.Module): + """ + See: + - https://arxiv.org/pdf/1803.01271.pdf + - https://github.com/locuslab/TCN + """ + def __init__( + self, + x_dim, + y_dim, + hidden_size=32, + nlayers=6, + kernel_size=2, + dropout=0.2, + embedding_dim=2, + ): + super().__init__() + self.tcn = TemporalConvNet( + num_inputs=x_dim+y_dim, + num_channels=[hidden_size] * nlayers, + dropout=dropout) + self._min_std = 0.01 + self.mean = nn.Linear(hidden_size, y_dim) + self.std = nn.Linear(hidden_size, y_dim) + + def forward(self, past_x, past_y, future_x, future_y=None): + device = next(self.parameters()).device + B, S, _ = future_x.shape + future_y_fake = past_y[:, -1:, :].repeat(1, S, 1).to(device) + context = torch.cat([past_x, past_y], -1) + target = torch.cat([future_x, future_y_fake], -1) + x = torch.cat([context, target * 1], 1).detach() + + out = self.tcn(x.permute(0, 2, 1)).permute(0, 2, 1) + + # Seems to help a little, especially with extrapolating out of bounds + steps = past_y.shape[1] + mean = self.mean(out)[:, steps:, :] + log_sigma = self.std(out)[:, steps:, :] + + sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma) + return torch.distributions.Normal(mean, sigma), {}