mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-06-27 19:32:35 +08:00
79 lines
3.0 KiB
Python
79 lines
3.0 KiB
Python
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
|
|
from ..util import mask_upper_triangular
|
|
|
|
class TransformerSeq2Seq(nn.Module):
    """Transformer encoder/decoder that emits a Normal distribution per future step.

    The past context (``past_x`` concatenated with ``past_y``) is embedded and
    encoded, max-pooled over the context axis into a single vector per batch
    element, and that vector is tiled along the future axis to serve as the
    decoder memory. The decoder consumes the embedded ``future_x`` and two
    linear heads produce the mean and standard deviation of the prediction.
    """

    def __init__(self, x_size, y_size, hidden_size=16, nhead=8, nlayers=2, attention_dropout=0, min_std=0.01, nan_value=0):
        """Build the embeddings, encoder/decoder stacks and output heads.

        Args:
            x_size: Size of the last dimension of ``past_x`` / ``future_x``.
            y_size: Size of the last dimension of ``past_y`` and of the output.
            hidden_size: Model width (``d_model``) of both transformer stacks.
            nhead: Number of attention heads in every layer.
            nlayers: Number of layers in the encoder and in the decoder.
            attention_dropout: Passed as the ``dropout`` argument of every
                encoder and decoder layer.
            min_std: Lower bound added to the predicted standard deviation.
            nan_value: Value treated as "missing" when building padding masks.
        """
        super().__init__()
        self._min_std = min_std
        self.nan_value = nan_value

        # The encoder sees features and targets; the decoder only sees features.
        self.enc_emb = nn.Linear(x_size + y_size, hidden_size)
        self.dec_emb = nn.Linear(x_size, hidden_size)

        feedforward_size = hidden_size * 8

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            dim_feedforward=feedforward_size,
            dropout=attention_dropout,
            nhead=nhead,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=nlayers,
            norm=nn.LayerNorm(hidden_size),
        )

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_size,
            dim_feedforward=feedforward_size,
            dropout=attention_dropout,
            nhead=nhead,
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=nlayers,
            norm=nn.LayerNorm(hidden_size),
        )

        # Output heads: one for the mean, one for the (pre-softplus) std.
        self.mean = nn.Linear(hidden_size, y_size)
        self.std = nn.Linear(hidden_size, y_size)

    def forward(self, past_x, past_y, future_x, future_y=None):
        """Predict a distribution over the future targets.

        Args:
            past_x: Context features, ``[B, C, x_size]``.
            past_y: Context targets, ``[B, C, y_size]``.
            future_x: Future features, ``[B, T, x_size]``.
            future_y: Unused; accepted so callers can pass it uniformly.

        Returns:
            A tuple ``(dist, extras)`` where ``dist`` is a
            ``torch.distributions.Normal`` with mean and scale of shape
            ``[B, T, y_size]`` and ``extras`` is an empty dict.
        """
        context = torch.cat([past_x, past_y], -1)

        # Padding masks: a time step counts as padding when none of its
        # features is both finite and different from ``nan_value``.
        future_valid = torch.isfinite(future_x) & (future_x != self.nan_value)
        tgt_key_padding_mask = ~future_valid.any(-1)

        context_valid = torch.isfinite(context) & (context != self.nan_value)
        src_key_padding_mask = ~context_valid.any(-1)

        # Embed, then switch to the sequence-first layout the (non
        # batch_first) transformer layers expect: [B, L, H] -> [L, B, H].
        context = self.enc_emb(context).permute(1, 0, 2)
        target = self.dec_emb(future_x).permute(1, 0, 2)

        memory = self.encoder(context, src_key_padding_mask=src_key_padding_mask)

        # Permutation-invariant aggregation of the context, tiled so the same
        # summary is available at every decoding step:
        # (C, B, H) -> (1, B, H) -> (T, B, H).
        # The tiling must run along dim 0 (time); tiling along dim 1 would
        # yield (1, B*T, H), whose batch size no longer matches ``target``.
        # NOTE(review): padded context steps still take part in the max;
        # confirm whether they should be masked out before pooling.
        n_future = target.shape[0]
        pooled = memory.max(dim=0, keepdim=True)[0]
        memory = pooled.repeat(n_future, 1, 1)

        outputs = self.decoder(target, memory, tgt_key_padding_mask=tgt_key_padding_mask)

        # [T, B, H] -> [B, T, H]
        outputs = outputs.permute(1, 0, 2).contiguous()

        mean = self.mean(outputs)
        raw_sigma = self.std(outputs)
        # Softplus keeps the scale positive; ``min_std`` bounds it from below.
        sigma = self._min_std + (1 - self._min_std) * F.softplus(raw_sigma)
        return torch.distributions.Normal(mean, sigma), {}
|
|
|