Files
wassname 27d4cde5bd tidy
2020-11-01 15:36:32 +08:00

61 lines
2.1 KiB
Python

import torch
from torch import nn
from torch.nn import functional as F
from ..util import mask_upper_triangular
class Transformer(nn.Module):
    """Encoder-only transformer that forecasts a Normal distribution over future y.

    The context (``past_x`` next to ``past_y``) and the target (``future_x``
    next to a placeholder for the unknown future y) are concatenated along the
    time axis, embedded, and run through an ``nn.TransformerEncoder`` with a
    mask from ``mask_upper_triangular``. Only the outputs at the target
    positions are mapped to a mean and standard deviation.

    NaN entries in the encoder input are replaced with ``nan_value`` before
    embedding.

    Args:
        x_dim: Number of features in ``past_x`` / ``future_x``.
        y_dim: Number of features in ``past_y`` and in the predicted distribution.
        attention_dropout: Passed as ``dropout`` to each ``nn.TransformerEncoderLayer``.
        nhead: Number of attention heads.
        nlayers: Number of stacked encoder layers.
        hidden_size: Model width (``d_model``); the feed-forward width is 8x this.
        nan_value: Value substituted for NaN entries in the encoder input.
        min_std: Lower bound on the predicted standard deviation.
    """

    def __init__(self, x_dim, y_dim, attention_dropout=0, nhead=8, nlayers=8, hidden_size=32, nan_value=0, min_std=0.01):
        super().__init__()
        self._min_std = min_std
        self.nan_value = nan_value
        # Each encoder token holds x features and y features side by side.
        self.enc_emb = nn.Linear(x_dim + y_dim, hidden_size)
        encoder_norm = nn.LayerNorm(hidden_size)
        layer_enc = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            dim_feedforward=hidden_size * 8,
            dropout=attention_dropout,
            nhead=nhead,
        )
        self.encoder = nn.TransformerEncoder(
            layer_enc, num_layers=nlayers, norm=encoder_norm
        )
        self.mean = nn.Linear(hidden_size, y_dim)
        self.std = nn.Linear(hidden_size, y_dim)

    def _build_inputs(self, past_x, past_y, future_x, device):
        """Build the encoder input of shape (batch, past_len + future_len, x_dim + y_dim).

        Future y is unknown at prediction time, so its slot is filled by
        repeating the last observed y along the forecast horizon. All inputs
        are moved to ``device`` so the concatenations never mix devices, and
        NaNs are replaced with ``self.nan_value``.
        """
        past_x, past_y, future_x = (t.to(device) for t in (past_x, past_y, future_x))
        future_len = future_x.shape[1]
        future_y_fake = past_y[:, -1:, :].repeat(1, future_len, 1)
        context = torch.cat([past_x, past_y], -1)
        target = torch.cat([future_x, future_y_fake], -1)
        x = torch.cat([context, target], 1).detach()
        # Without this a single NaN would spread through the embedding and
        # attention into the predictions.
        return x.masked_fill(torch.isnan(x), self.nan_value)

    def _to_distribution(self, outputs):
        """Map encoder outputs (batch, length, hidden_size) to a Normal over y."""
        # Author's note on this output head: "Seems to help a little,
        # especially with extrapolating out of bounds".
        mean = self.mean(outputs)
        log_sigma = self.std(outputs)
        # softplus keeps sigma positive; the affine rescale floors it at min_std.
        sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
        return torch.distributions.Normal(mean, sigma)

    def forward(self, past_x, past_y, future_x, future_y=None):
        """Predict a Normal distribution over the future y values.

        Args:
            past_x: (batch, past_len, x_dim) observed covariates.
            past_y: (batch, past_len, y_dim) observed targets.
            future_x: (batch, future_len, x_dim) covariates for the forecast horizon.
            future_y: Unused; the placeholder built from ``past_y`` is fed instead.

        Returns:
            A tuple ``(dist, {})``. ``dist`` is a ``torch.distributions.Normal``
            with batch shape (batch, future_len, y_dim).
        """
        device = next(self.parameters()).device
        x = self._build_inputs(past_x, past_y, future_x, device)
        # The encoder layers are built without batch_first, so they take
        # (seq, batch, feature).
        h = self.enc_emb(x).permute(1, 0, 2)
        # NOTE(review): presumably a causal mask hiding later positions;
        # confirm against mask_upper_triangular in ..util.
        mask = mask_upper_triangular(h.shape[0], device)
        outputs = self.encoder(h, mask=mask).permute(1, 0, 2)
        # Keep only the target (future) positions before the output heads.
        past_len = past_y.shape[1]
        return self._to_distribution(outputs[:, past_len:, :]), {}