mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-06-27 20:06:52 +08:00
190 lines
4.9 KiB
Python
190 lines
4.9 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from torch.nn.utils import weight_norm
|
|
from torch.nn import functional as F
|
|
|
|
class Chomp1d(nn.Module):
|
|
def __init__(self, chomp_size):
|
|
super(Chomp1d, self).__init__()
|
|
self.chomp_size = chomp_size
|
|
|
|
def forward(self, x):
|
|
return x[:, :, : -self.chomp_size].contiguous()
|
|
|
|
|
|
class Conv(nn.Module):
|
|
"""Causal convolution layer."""
|
|
|
|
def __init__(
|
|
self,
|
|
n_inputs,
|
|
n_outputs,
|
|
kernel_size,
|
|
stride,
|
|
dilation,
|
|
padding,
|
|
causal=True,
|
|
):
|
|
super().__init__()
|
|
self.conv = nn.Conv1d(
|
|
n_inputs,
|
|
n_outputs,
|
|
kernel_size,
|
|
stride=stride,
|
|
padding=padding,
|
|
dilation=dilation,
|
|
)
|
|
self.chomp = Chomp1d(padding)
|
|
self.causal = causal
|
|
|
|
def forward(self, x):
|
|
out = self.conv(x)
|
|
if self.causal:
|
|
out = self.chomp(out)
|
|
return out
|
|
|
|
|
|
class TemporalBlock(nn.Module):
|
|
def __init__(
|
|
self,
|
|
n_inputs,
|
|
n_outputs,
|
|
kernel_size,
|
|
stride,
|
|
dilation,
|
|
padding,
|
|
dropout=0.2,
|
|
):
|
|
super(TemporalBlock, self).__init__()
|
|
self.conv1 = Conv(
|
|
n_inputs,
|
|
n_outputs,
|
|
kernel_size,
|
|
stride=stride,
|
|
padding=padding,
|
|
dilation=dilation,
|
|
)
|
|
self.relu1 = nn.ReLU()
|
|
self.dropout1 = nn.Dropout(dropout)
|
|
|
|
self.conv2 = Conv(
|
|
n_outputs,
|
|
n_outputs,
|
|
kernel_size,
|
|
stride=stride,
|
|
padding=padding,
|
|
dilation=dilation,
|
|
)
|
|
self.relu2 = nn.ReLU()
|
|
self.dropout2 = nn.Dropout(dropout)
|
|
|
|
self.net = nn.Sequential(
|
|
self.conv1, self.relu1, self.dropout1, self.conv2, self.relu2, self.dropout2
|
|
)
|
|
self.downsample = (
|
|
Conv(
|
|
n_inputs,
|
|
n_outputs,
|
|
1,
|
|
stride=1,
|
|
padding=0,
|
|
dilation=1,
|
|
causal=False,
|
|
)
|
|
if n_inputs != n_outputs
|
|
else None
|
|
)
|
|
self.relu = nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
out = x
|
|
for i, l in enumerate(self.net):
|
|
out = l(out)
|
|
res = x if self.downsample is None else self.downsample(x)
|
|
return self.relu(out + res)
|
|
|
|
|
|
class TemporalConvNet(nn.Module):
|
|
"""
|
|
See:
|
|
- https://arxiv.org/pdf/1803.01271.pdf
|
|
- https://github.com/locuslab/TCN
|
|
"""
|
|
def __init__(
|
|
self,
|
|
num_inputs,
|
|
num_channels,
|
|
kernel_size=2,
|
|
dropout=0.2,
|
|
embedding_dim=2,
|
|
):
|
|
super(TemporalConvNet, self).__init__()
|
|
layers = []
|
|
num_levels = len(num_channels)
|
|
for i in range(num_levels):
|
|
dilation_size = 2 ** i
|
|
in_channels = num_inputs if i == 0 else num_channels[i - 1]
|
|
out_channels = num_channels[i]
|
|
layers += [
|
|
TemporalBlock(
|
|
in_channels,
|
|
out_channels,
|
|
kernel_size,
|
|
stride=1,
|
|
dilation=dilation_size,
|
|
padding=(kernel_size - 1) * dilation_size,
|
|
dropout=dropout,
|
|
)
|
|
]
|
|
self.network = nn.Sequential(*layers)
|
|
|
|
def forward(self, x):
|
|
out = x
|
|
for l in self.network:
|
|
out = l(out)
|
|
return out
|
|
|
|
|
|
class TCNSeq(nn.Module):
|
|
"""
|
|
See:
|
|
- https://arxiv.org/pdf/1803.01271.pdf
|
|
- https://github.com/locuslab/TCN
|
|
"""
|
|
def __init__(
|
|
self,
|
|
x_dim,
|
|
y_dim,
|
|
hidden_size=32,
|
|
nlayers=6,
|
|
kernel_size=2,
|
|
dropout=0.2,
|
|
):
|
|
super().__init__()
|
|
self.tcn = TemporalConvNet(
|
|
num_inputs=x_dim + y_dim,
|
|
kernel_size=kernel_size,
|
|
num_channels=[hidden_size] * nlayers,
|
|
dropout=dropout)
|
|
self._min_std = 0.01
|
|
self.mean = nn.Linear(hidden_size, y_dim)
|
|
self.std = nn.Linear(hidden_size, y_dim)
|
|
|
|
def forward(self, past_x, past_y, future_x, future_y=None):
|
|
device = next(self.parameters()).device
|
|
B, S, _ = future_x.shape
|
|
future_y_fake = past_y[:, -1:, :].repeat(1, S, 1).to(device)
|
|
context = torch.cat([past_x, past_y], -1)
|
|
target = torch.cat([future_x, future_y_fake], -1)
|
|
x = torch.cat([context, target * 1], 1).detach()
|
|
|
|
out = self.tcn(x.permute(0, 2, 1)).permute(0, 2, 1)
|
|
|
|
# Seems to help a little, especially with extrapolating out of bounds
|
|
steps = past_y.shape[1]
|
|
mean = self.mean(out)[:, steps:, :]
|
|
log_sigma = self.std(out)[:, steps:, :]
|
|
|
|
sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
|
|
return torch.distributions.Normal(mean, sigma), {}
|