Files
seq2seq-time/seq2seq_time/models/transformer_process.py
T
2020-11-13 15:56:16 +08:00

188 lines
5.9 KiB
Python

import torch
from torch import nn
from torch.nn import functional as F
from ..util import mask_upper_triangular
class LatentEncoder(nn.Module):
    """Transformer encoder that summarises a sequence into a Gaussian latent.

    At each step the ``x`` and ``y`` features are concatenated and linearly
    embedded. The result goes through an ``nn.TransformerEncoder`` with a
    mask from ``mask_upper_triangular`` (intended as an autoregressive mask).
    The encoder output is then pooled over the sequence axis in three ways:
    mean, last step and max. The pooled vector is projected to the location
    and a pre-activation scale of a ``torch.distributions.Normal``.

    Args:
        input_dim: feature size of ``torch.cat([x, y], dim=-1)``.
        hidden_size: transformer width (``d_model``); PyTorch requires it to
            be divisible by ``nhead``.
        latent_dim: size of the latent variable.
        min_std: lower bound added to the returned standard deviation.
        dropout: dropout probability inside the transformer layers.
        nhead: number of attention heads.
        num_layers: number of stacked encoder layers.
    """

    def __init__(
        self,
        input_dim,
        hidden_size=32,
        latent_dim=32,
        min_std=0.01,
        dropout=0,
        nhead=8,
        num_layers=2,
    ):
        super().__init__()
        self.enc_emb = nn.Linear(input_dim, hidden_size)
        final_norm = nn.LayerNorm(hidden_size)
        block = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=nhead,
            dim_feedforward=8 * hidden_size,
            dropout=dropout,
        )
        self.encoder = nn.TransformerEncoder(
            block, num_layers=num_layers, norm=final_norm
        )
        # mean, last and max summaries are concatenated before projection
        pooled_size = 3 * hidden_size
        self.mean = nn.Linear(pooled_size, latent_dim)
        # name kept for state_dict compatibility; it feeds a sigmoid-bounded scale
        self.log_var = nn.Linear(pooled_size, latent_dim)
        self._min_std = min_std

    def forward(self, x, y):
        """Encode ``(x, y)`` into ``Normal(loc, scale)`` over the latent.

        Args:
            x, y: tensors concatenated along their last axis. The permutes
                below assume they are batch-first ``(B, S, .)``.

        Returns:
            ``torch.distributions.Normal`` whose loc/scale have shape
            ``(B, latent_dim)``.
        """
        # (B, S, X+Y) -> (B, S, H) -> (S, B, H): transformer layers are sequence-first here
        tokens = self.enc_emb(torch.cat([x, y], dim=-1)).permute(1, 0, 2)
        seq_len = tokens.shape[0]
        causal = mask_upper_triangular(seq_len, next(self.parameters()).device)
        # (S, B, H) -> (B, S, H)
        hidden = self.encoder(tokens, mask=causal).permute(1, 0, 2)

        # Pool over the sequence axis: (B, S, H) -> (B, 3H)
        summary = torch.cat(
            [hidden.mean(dim=1), hidden[:, -1, :], hidden.max(dim=1).values],
            dim=-1,
        )

        loc = self.mean(summary)
        raw_scale = self.log_var(summary)
        # sigmoid squashes into (0, 1), so scale lies in (min_std, 1) for min_std < 1
        scale = self._min_std + (1 - self._min_std) * torch.sigmoid(0.5 * raw_scale)
        return torch.distributions.Normal(loc, scale)
class Decoder(nn.Module):
    """Transformer decoder that predicts a per-step Gaussian over targets.

    The target inputs ``x`` are embedded and used as the decoder's target
    sequence. Self-attention uses a mask from ``mask_upper_triangular``
    (intended as an autoregressive mask). The latent ``z`` is embedded and
    used as the cross-attention memory.

    Args:
        x_size: feature size of the target inputs ``x``.
        y_size: feature size of the predicted outputs.
        hidden_size: transformer width (``d_model``); PyTorch requires it to
            be divisible by ``nhead``.
        latent_dim: feature size of ``z``.
        num_layers: number of stacked decoder layers.
        use_deterministic_path: accepted for interface compatibility only;
            this class does not use it.
        min_std: lower bound added to the predicted standard deviation.
        nhead: number of attention heads.
        dropout: dropout probability inside the transformer layers.
    """

    def __init__(
        self,
        x_size,
        y_size,
        hidden_size=32,
        latent_dim=32,
        num_layers=3,
        use_deterministic_path=True,
        min_std=0.01,
        nhead=8,
        dropout=0,
    ):
        super().__init__()
        self.dec_emb = nn.Linear(x_size, hidden_size)
        self.z_emb = nn.Linear(latent_dim, hidden_size)
        layer_dec = nn.TransformerDecoderLayer(
            d_model=hidden_size,
            dim_feedforward=hidden_size * 8,
            dropout=dropout,
            nhead=nhead,
        )
        decoder_norm = nn.LayerNorm(hidden_size)
        self._decoder = nn.TransformerDecoder(
            layer_dec, num_layers=num_layers, norm=decoder_norm
        )
        self._mean = nn.Linear(hidden_size, y_size)
        self._std = nn.Linear(hidden_size, y_size)
        self._min_std = min_std

    def forward(self, z, x):
        """Decode target inputs ``x`` conditioned on latent ``z``.

        Args:
            z: latent features, assumed ``(B, S, latent_dim)``. In this file
                the caller broadcasts one latent to every target step.
            x: target inputs, assumed ``(B, S, x_size)``.

        Returns:
            ``torch.distributions.Normal`` with loc/scale of shape
            ``(B, S, y_size)``.
        """
        # (B, S, x_size) -> (B, S, H) -> (S, B, H): PyTorch transformer layers
        # default to batch_first=False (sequence-first)
        tgt = self.dec_emb(x).permute(1, 0, 2)
        # (B, S, latent_dim) -> (B, S, H) -> (S, B, H)
        memory = self.z_emb(z).permute(1, 0, 2)
        # autoregressive mask over the target sequence
        tgt_mask = mask_upper_triangular(tgt.shape[0], tgt.device)
        r = self._decoder(tgt, memory, tgt_mask=tgt_mask)
        # (S, B, H) -> (B, S, H)
        r = r.permute(1, 0, 2).contiguous()

        mean = self._mean(r)
        # softplus keeps the scale positive; min_std keeps it away from zero
        sigma = self._min_std + (1 - self._min_std) * F.softplus(self._std(r))
        return torch.distributions.Normal(mean, sigma)
class TransformerProcess(nn.Module):
    """An attempted simplification of an attentive neural process.

    Works on sequential data and has no deterministic encoder. It uses full
    transformer layers instead of custom attention, with an autoregressive
    mask in both the latent encoder and the decoder.

    Args:
        x_size: feature size of the inputs ``x``.
        y_size: feature size of the outputs ``y``.
        hidden_size: transformer width shared by encoder and decoder.
        latent_dim: size of the latent variable.
        nhead: number of attention heads.
        nlayers: number of layers in both the encoder and the decoder.
        dropout: dropout probability inside the transformer layers.
        min_std: lower bound on predicted standard deviations.
    """

    def __init__(self, x_size, y_size, hidden_size=64, latent_dim=32, nhead=8, nlayers=4, dropout=0, min_std=0.01):
        super().__init__()
        self._min_std = min_std
        self._latent_encoder = LatentEncoder(
            x_size + y_size,
            hidden_size=hidden_size,
            latent_dim=latent_dim,
            num_layers=nlayers,
            dropout=dropout,
            min_std=min_std,
            nhead=nhead,
        )
        self._decoder = Decoder(
            x_size,
            y_size,
            hidden_size=hidden_size,
            latent_dim=latent_dim,
            dropout=dropout,
            min_std=min_std,
            num_layers=nlayers,
            nhead=nhead,
        )

    def forward(self, past_x, past_y, future_x, future_y=None):
        """Predict a distribution over the future outputs.

        Args:
            past_x, past_y: context window. It is concatenated with the future
                along dim 1, assumed batch-first ``(B, S, .)``.
            future_x: target inputs.
            future_y: optional target outputs. When given, the posterior latent
                is computed and a loss is returned.

        Returns:
            ``(dist_out, {'loss': loss})``. ``dist_out`` is the decoder's
            ``Normal``. ``loss`` is ``None`` unless ``future_y`` is given.
        """
        dist_prior = self._latent_encoder(past_x, past_y)
        dist_post = None
        if future_y is not None:
            # the posterior sees the whole window: past followed by future
            x = torch.cat([past_x, future_x], 1)
            y = torch.cat([past_y, future_y], 1)
            dist_post = self._latent_encoder(x, y)

        if self.training and dist_post is not None:
            # Use a reparameterised posterior sample during training, if possible
            z = dist_post.rsample()
        else:
            # Otherwise use the prior's mean, its most probable value
            z = dist_prior.loc

        num_targets = future_x.size(1)
        # Broadcast the one latent per batch item to every target step:
        # (B, latent_dim) -> (B, S_target, latent_dim). expand gives a view
        # with the same values as repeat, without copying.
        z = z.unsqueeze(1).expand(-1, num_targets, -1)
        dist_out = self._decoder(z, future_x)

        loss = None
        if future_y is not None:
            # Likelihood of the true targets, averaged over output features: (B, S_target)
            log_p = dist_out.log_prob(future_y).mean(-1)
            # Keep the latent encoded from the past close to the one encoded
            # with the future, averaged over latent dims: (B,)
            kl_loss = torch.distributions.kl_divergence(dist_post, dist_prior).mean(-1)
            kl_loss = kl_loss[:, None].expand(log_p.shape)
            loss = (kl_loss - log_p).mean()
        return dist_out, {'loss': loss}