import torch from torch import nn from torch.nn import functional as F from ..util import mask_upper_triangular class LatentEncoder(nn.Module): def __init__( self, input_dim, hidden_size=32, latent_dim=32, min_std=0.01, dropout=0, nhead=8, num_layers=2, ): super().__init__() self.enc_emb = nn.Linear(input_dim, hidden_size) encoder_norm = nn.LayerNorm(hidden_size) layer_enc = nn.TransformerEncoderLayer( d_model=hidden_size, dim_feedforward=hidden_size*8, dropout=dropout, nhead=nhead, # activation ) self.encoder = nn.TransformerEncoder( layer_enc, num_layers=num_layers, norm=encoder_norm ) self.mean = nn.Linear(hidden_size*3, latent_dim) self.log_var = nn.Linear(hidden_size*3, latent_dim) self._min_std = min_std def forward(self, x, y): encoder_input = torch.cat([x, y], dim=-1) # Latent Encoder x = self.enc_emb(encoder_input) # Size([B, S, X]) -> Size([B, S, hidden_size]) x = x.permute(1, 0, 2) # (B,S,hidden_size) -> (S,B,hidden_size) # autoregressive mask device = next(self.parameters()).device N = x.shape[0] mask = mask_upper_triangular(N, device) r = self.encoder(x, mask=mask) r = r.permute(1, 0, 2) # (S,B,hidden_size) -> (B,S,hidden_size) # Aggregation (max/mean/last) r_mean = r.mean(1) # (B,S,hidden_size) -> (B,hidden_size) r_last = r[:, -1, :] r_max = r.max(1)[0] r = torch.cat([r_mean, r_last, r_max], -1) mean = self.mean(r) log_sigma = self.log_var(r) sigma = self._min_std + (1 - self._min_std) * torch.sigmoid(log_sigma * 0.5) dist = torch.distributions.Normal(mean, sigma) return dist class Decoder(nn.Module): def __init__( self, x_size, y_size, hidden_size=32, latent_dim=32, num_layers=3, use_deterministic_path=True, min_std=0.01, nhead=8, dropout=0, ): super(Decoder, self).__init__() self.dec_emb = nn.Linear(x_size, hidden_size) self.z_emb = nn.Linear(latent_dim, hidden_size) layer_dec = nn.TransformerDecoderLayer( d_model=hidden_size, dim_feedforward=hidden_size*8, dropout=dropout, nhead=nhead, ) decoder_norm = nn.LayerNorm(hidden_size) self._decoder = nn.TransformerDecoder( layer_dec, num_layers=num_layers, norm=decoder_norm ) self._mean = nn.Linear(hidden_size, y_size) self._std = nn.Linear(hidden_size, y_size) self._min_std = min_std def forward(self, z, x): # (B, S, latent_size) -> (B, S, H) -> (S, B, H) x = self.dec_emb(x).permute(1, 0, 2) # (B, S, latent_size) -> (B, S, H) -> (S, B, H) z = self.z_emb(z).permute(1, 0, 2) # autoregressive mask device = next(self.parameters()).device N = x.shape[0] mask = mask_upper_triangular(N, device) r = self._decoder(x, z, tgt_mask=mask) # [S, B, H] -> [B, S, H] r = r.permute(1, 0, 2).contiguous() # Get the mean and the variance mean = self._mean(r) log_sigma = self._std(r) # Bound or clamp the variance sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma) dist = torch.distributions.Normal(mean, sigma) return dist class TransformerProcess(nn.Module): """ A attempted simplification of an attentive neural process Works on sequential data, has no deterministic encoder. Uses full transformer layer instead of custom attention. Has an autoregressive mask on the encoder and decoder. """ def __init__(self, x_size, y_size, hidden_size=64, latent_dim=32, nhead=8, nlayers=4, dropout=0, min_std=0.01): super().__init__() self._min_std = min_std self._latent_encoder = LatentEncoder( x_size + y_size, hidden_size=hidden_size, latent_dim=latent_dim, num_layers=nlayers, dropout=dropout, min_std=min_std, nhead=nhead, ) self._decoder = Decoder( x_size, y_size, hidden_size=hidden_size, latent_dim=latent_dim, dropout=dropout, min_std=min_std, num_layers=nlayers, nhead=nhead, ) def forward(self, past_x, past_y, future_x, future_y=None): device = next(self.parameters()).device dist_prior = self._latent_encoder(past_x, past_y) if (future_y is not None): x = torch.cat([past_x, future_x], 1) y = torch.cat([past_y, future_y], 1) dist_post = self._latent_encoder(x, y) if self.training and (future_y is not None): # USe posterior during training, is possible z = dist_post.rsample() else: # During eval use the prior, also take the most probable z = dist_prior.loc num_targets = future_x.size(1) z = z.unsqueeze(1).repeat(1, num_targets, 1) # [B, S_target, H] dist_out = self._decoder(z, future_x) loss = None if future_y is not None: # Make sure output dist matches label log_p = dist_out.log_prob(future_y).mean(-1) # Making sure the encoded distribition from the past is as close as possible to the future kl_loss = torch.distributions.kl_divergence(dist_post, dist_prior).mean( -1 ) # [B, R].mean(-1) kl_loss = kl_loss[:, None].expand(log_p.shape) loss = (kl_loss - log_p).mean() return dist_out, {'loss': loss}