Files
2020-11-01 15:49:57 +08:00

79 lines
3.0 KiB
Python

import torch
from torch import nn
from torch.nn import functional as F
from ..util import mask_upper_triangular
class TransformerSeq2Seq(nn.Module):
    """Transformer encoder/decoder that predicts a Normal distribution per target step.

    The context (past features concatenated with past targets) is embedded and
    run through a transformer encoder. The encoder output is max-pooled over
    the context axis into a single summary vector per batch element, which the
    decoder attends to while processing the embedded future features. Two
    linear heads turn the decoder output into a mean and a standard deviation.
    """

    def __init__(self, x_size, y_size, hidden_size=16, nhead=8, nlayers=2,
                 attention_dropout=0, min_std=0.01, nan_value=0):
        """Build the embeddings, encoder, decoder and output heads.

        Args:
            x_size: Number of input features per time step.
            y_size: Number of target values per time step.
            hidden_size: Transformer model width (``d_model``). PyTorch's
                multi-head attention requires it to be divisible by ``nhead``.
            nhead: Number of attention heads in every layer.
            nlayers: Number of layers in both the encoder and the decoder.
            attention_dropout: Dropout probability passed to the transformer
                layers.
            min_std: Lower bound added to the predicted standard deviation.
            nan_value: Sentinel value that marks a missing input; time steps
                whose features are all non-finite or equal to this value are
                masked out as padding.
        """
        super().__init__()
        self._min_std = min_std
        self.nan_value = nan_value

        # The encoder sees features and targets; the decoder sees features only.
        self.enc_emb = nn.Linear(x_size + y_size, hidden_size)
        self.dec_emb = nn.Linear(x_size, hidden_size)

        encoder_norm = nn.LayerNorm(hidden_size)
        layer_enc = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            dim_feedforward=hidden_size * 8,
            dropout=attention_dropout,
            nhead=nhead,
        )
        self.encoder = nn.TransformerEncoder(
            layer_enc, num_layers=nlayers, norm=encoder_norm
        )

        layer_dec = nn.TransformerDecoderLayer(
            d_model=hidden_size,
            dim_feedforward=hidden_size * 8,
            dropout=attention_dropout,
            nhead=nhead,
        )
        decoder_norm = nn.LayerNorm(hidden_size)
        self.decoder = nn.TransformerDecoder(
            layer_dec, num_layers=nlayers, norm=decoder_norm
        )

        # Output heads for the parameters of the predictive distribution.
        self.mean = nn.Linear(hidden_size, y_size)
        self.std = nn.Linear(hidden_size, y_size)

    def forward(self, past_x, past_y, future_x, future_y=None):
        """Predict a distribution over the targets at the future time steps.

        Args:
            past_x: Context features, expected shape ``(B, C, x_size)``.
            past_y: Context targets, expected shape ``(B, C, y_size)``.
            future_x: Features at the steps to predict, expected shape
                ``(B, T, x_size)``.
            future_y: Unused; accepted so callers can pass the targets.

        Returns:
            A tuple ``(dist, extra)`` where ``dist`` is a
            ``torch.distributions.Normal`` whose parameters have shape
            ``(B, T, y_size)`` and ``extra`` is an empty dict.
        """
        context = torch.cat([past_x, past_y], -1)

        # Padding masks, shape (B, steps): True where a step has no usable
        # feature, i.e. every feature is non-finite or equals ``nan_value``.
        future_mask = torch.isfinite(future_x) & (future_x != self.nan_value)
        tgt_key_padding_mask = ~future_mask.any(-1)
        past_mask = torch.isfinite(context) & (context != self.nan_value)
        src_key_padding_mask = ~past_mask.any(-1)

        # Embed, then switch to the sequence-first layout the transformer
        # layers are configured for: (B, steps, H) -> (steps, B, H).
        context = self.enc_emb(context).permute(1, 0, 2)
        target = self.dec_emb(future_x).permute(1, 0, 2)

        memory = self.encoder(context, src_key_padding_mask=src_key_padding_mask)

        # Collapse the context with a permutation-invariant max over the
        # context axis, then tile that summary once per target step:
        # (C, B, H) -> (1, B, H) -> (T, B, H).
        # The tiling must run along the sequence axis (dim 0). Tiling along
        # the batch axis yields (1, B * T, H), whose batch size no longer
        # matches the target's and is rejected by current PyTorch attention.
        # NOTE(review): the max also runs over padded context steps — confirm
        # that this is intended.
        num_targets = target.shape[0]
        memory = memory.max(dim=0, keepdim=True)[0].repeat(num_targets, 1, 1)

        outputs = self.decoder(
            target, memory, tgt_key_padding_mask=tgt_key_padding_mask
        )

        # (T, B, H) -> (B, T, H)
        outputs = outputs.permute(1, 0, 2).contiguous()

        mean = self.mean(outputs)
        log_sigma = self.std(outputs)
        # Softplus keeps the scale non-negative; min_std bounds it away from 0.
        sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
        return torch.distributions.Normal(mean, sigma), {}