diff --git a/seq2seq_time/models/transformer_process.py b/seq2seq_time/models/transformer_process.py index ca760e5..ed1f2ca 100644 --- a/seq2seq_time/models/transformer_process.py +++ b/seq2seq_time/models/transformer_process.py @@ -38,13 +38,17 @@ class LatentEncoder(nn.Module): def forward(self, x, y): encoder_input = torch.cat([x, y], dim=-1) # Latent Encoder - x = self.enc_emb(encoder_input) - # Size([B, C, X]) -> Size([B, C, hidden_size]) - x = x.permute(1, 0, 2) # (B,C,hidden_size) -> (C,B,hidden_size) - # requires (C, B, hidden_size) - r = self.encoder(x) - r = r.permute(1, 0, 2) # (C,B,hidden_size) -> (B,C,hidden_size) - r = r.mean(1) + x = self.enc_emb(encoder_input) # Size([B, S, X]) -> Size([B, S, hidden_size]) + x = x.permute(1, 0, 2) # (B,S,hidden_size) -> (S,B,hidden_size) + + # autoregressive mask + device = next(self.parameters()).device + N = x.shape[0] + mask = mask_upper_triangular(N, device) + + r = self.encoder(x, mask=mask) + r = r.permute(1, 0, 2) # (S,B,hidden_size) -> (B,S,hidden_size) + r = r.mean(1) # (B,S,hidden_size) -> (B,hidden_size) mean = self.mean(r) log_sigma = self.log_var(r) sigma = self._min_std + (1 - self._min_std) * torch.sigmoid(log_sigma * 0.5) @@ -83,20 +87,22 @@ class Decoder(nn.Module): self._std = nn.Linear(hidden_size, y_size) self._min_std = min_std - def forward(self, z, future_x): - # concatenate future_x and representation - future_x = self.dec_emb(future_x) - future_x = future_x.permute(1, 0, 2) + def forward(self, z, x): - z = self.z_emb(z) - z = z.permute(1, 0, 2) - # requires (C, B, hidden_size) + # (B, S, latent_size) -> (B, S, H) -> (S, B, H) + x = self.dec_emb(x).permute(1, 0, 2) - # r = torch.cat([z, future_x], dim=-1) + # (B, S, latent_size) -> (B, S, H) -> (S, B, H) + z = self.z_emb(z).permute(1, 0, 2) - r = self._decoder(future_x, z) + # autoregressive mask + device = next(self.parameters()).device + N = x.shape[0] + mask = mask_upper_triangular(N, device) + + r = self._decoder(x, z, tgt_mask=mask) - # [T, B, emb_dim] -> [B, T, emb_dim] + # [S, B, H] -> [B, S, H] r = r.permute(1, 0, 2).contiguous() # Get the mean and the variance @@ -112,7 +118,7 @@ class Decoder(nn.Module): class TransformerProcess(nn.Module): # WIP trying one that encodes a dist # TODO autoregressive mask - def __init__(self, x_size, y_size, hidden_size=16, latent_dim=32, nhead=8, nlayers=2, attention_dropout=0, min_std=0.01): + def __init__(self, x_size, y_size, hidden_size=64, latent_dim=32, nhead=8, nlayers=4, dropout=0, min_std=0.01): super().__init__() self._min_std = min_std @@ -121,7 +127,7 @@ class TransformerProcess(nn.Module): hidden_size=hidden_size, latent_dim=latent_dim, num_layers=nlayers, - dropout=attention_dropout, + dropout=dropout, min_std=min_std, nhead=nhead, ) @@ -131,7 +137,7 @@ class TransformerProcess(nn.Module): y_size, hidden_size=hidden_size, latent_dim=latent_dim, - dropout=attention_dropout, + dropout=dropout, min_std=min_std, num_layers=nlayers, nhead=nhead, @@ -143,9 +149,6 @@ class TransformerProcess(nn.Module): dist_prior = self._latent_encoder(past_x, past_y) if (future_y is not None): - # If future_y is provided, we can create an auxilary loss - # Making sure the encoded distribition from the past - # Is as close as possible to the future x = torch.cat([past_x, future_x], 1) y = torch.cat([past_y, future_y], 1) dist_post = self._latent_encoder(x, y) @@ -156,19 +159,19 @@ class TransformerProcess(nn.Module): else: z = dist_prior.loc num_targets = future_x.size(1) - z = z.unsqueeze(1).repeat(1, num_targets, 1) # [B, T_target, H] + z = z.unsqueeze(1).repeat(1, num_targets, 1) # [B, S_target, H] - dist = self._decoder(z, future_x) + dist_out = self._decoder(z, future_x) loss = None if future_y is not None: - log_p = dist.log_prob(future_y).mean(-1) + # Make sure output dist matches label + log_p = dist_out.log_prob(future_y).mean(-1) + + # Making sure the encoded distribition from the past is as close as possible to the future kl_loss = torch.distributions.kl_divergence(dist_post, dist_prior).mean( -1 ) # [B, R].mean(-1) kl_loss = kl_loss[:, None].expand(log_p.shape) - mse_loss = F.mse_loss(dist.loc, future_y, reduction="none")[ - :, : past_x.size(1) - ].mean() loss = (kl_loss - log_p).mean() - return dist, {'loss': loss} + return dist_out, {'loss': loss}