mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-06-27 17:50:09 +08:00
tcn
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from torch.nn import functional as F
|
||||
|
||||
class Chomp1d(nn.Module):
|
||||
def __init__(self, chomp_size):
|
||||
@@ -114,7 +114,6 @@ class TemporalConvNet(nn.Module):
|
||||
self,
|
||||
num_inputs,
|
||||
num_channels,
|
||||
num_embeddings=0,
|
||||
kernel_size=2,
|
||||
dropout=0.2,
|
||||
embedding_dim=2,
|
||||
@@ -144,3 +143,47 @@ class TemporalConvNet(nn.Module):
|
||||
for l in self.network:
|
||||
out = l(out)
|
||||
return out
|
||||
|
||||
|
||||
class TCNSeq2Seq(nn.Module):
|
||||
"""
|
||||
See:
|
||||
- https://arxiv.org/pdf/1803.01271.pdf
|
||||
- https://github.com/locuslab/TCN
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
x_dim,
|
||||
y_dim,
|
||||
hidden_size=32,
|
||||
nlayers=6,
|
||||
kernel_size=2,
|
||||
dropout=0.2,
|
||||
embedding_dim=2,
|
||||
):
|
||||
super().__init__()
|
||||
self.tcn = TemporalConvNet(
|
||||
num_inputs=x_dim+y_dim,
|
||||
num_channels=[hidden_size] * nlayers,
|
||||
dropout=dropout)
|
||||
self._min_std = 0.01
|
||||
self.mean = nn.Linear(hidden_size, y_dim)
|
||||
self.std = nn.Linear(hidden_size, y_dim)
|
||||
|
||||
def forward(self, past_x, past_y, future_x, future_y=None):
|
||||
device = next(self.parameters()).device
|
||||
B, S, _ = future_x.shape
|
||||
future_y_fake = past_y[:, -1:, :].repeat(1, S, 1).to(device)
|
||||
context = torch.cat([past_x, past_y], -1)
|
||||
target = torch.cat([future_x, future_y_fake], -1)
|
||||
x = torch.cat([context, target * 1], 1).detach()
|
||||
|
||||
out = self.tcn(x.permute(0, 2, 1)).permute(0, 2, 1)
|
||||
|
||||
# Seems to help a little, especially with extrapolating out of bounds
|
||||
steps = past_y.shape[1]
|
||||
mean = self.mean(out)[:, steps:, :]
|
||||
log_sigma = self.std(out)[:, steps:, :]
|
||||
|
||||
sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
|
||||
return torch.distributions.Normal(mean, sigma), {}
|
||||
|
||||
Reference in New Issue
Block a user