mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-06-27 16:31:46 +08:00
working
This commit is contained in:
@@ -38,13 +38,17 @@ class LatentEncoder(nn.Module):
|
||||
def forward(self, x, y):
|
||||
encoder_input = torch.cat([x, y], dim=-1)
|
||||
# Latent Encoder
|
||||
x = self.enc_emb(encoder_input)
|
||||
# Size([B, C, X]) -> Size([B, C, hidden_size])
|
||||
x = x.permute(1, 0, 2) # (B,C,hidden_size) -> (C,B,hidden_size)
|
||||
# requires (C, B, hidden_size)
|
||||
r = self.encoder(x)
|
||||
r = r.permute(1, 0, 2) # (C,B,hidden_size) -> (B,C,hidden_size)
|
||||
r = r.mean(1)
|
||||
x = self.enc_emb(encoder_input) # Size([B, S, X]) -> Size([B, S, hidden_size])
|
||||
x = x.permute(1, 0, 2) # (B,S,hidden_size) -> (S,B,hidden_size)
|
||||
|
||||
# autoregressive mask
|
||||
device = next(self.parameters()).device
|
||||
N = x.shape[0]
|
||||
mask = mask_upper_triangular(N, device)
|
||||
|
||||
r = self.encoder(x, mask=mask)
|
||||
r = r.permute(1, 0, 2) # (S,B,hidden_size) -> (B,S,hidden_size)
|
||||
r = r.mean(1) # (B,S,hidden_size) -> (B,hidden_size)
|
||||
mean = self.mean(r)
|
||||
log_sigma = self.log_var(r)
|
||||
sigma = self._min_std + (1 - self._min_std) * torch.sigmoid(log_sigma * 0.5)
|
||||
@@ -83,20 +87,22 @@ class Decoder(nn.Module):
|
||||
self._std = nn.Linear(hidden_size, y_size)
|
||||
self._min_std = min_std
|
||||
|
||||
def forward(self, z, future_x):
|
||||
# concatenate future_x and representation
|
||||
future_x = self.dec_emb(future_x)
|
||||
future_x = future_x.permute(1, 0, 2)
|
||||
def forward(self, z, x):
|
||||
|
||||
z = self.z_emb(z)
|
||||
z = z.permute(1, 0, 2)
|
||||
# requires (C, B, hidden_size)
|
||||
# (B, S, latent_size) -> (B, S, H) -> (S, B, H)
|
||||
x = self.dec_emb(x).permute(1, 0, 2)
|
||||
|
||||
# r = torch.cat([z, future_x], dim=-1)
|
||||
# (B, S, latent_size) -> (B, S, H) -> (S, B, H)
|
||||
z = self.z_emb(z).permute(1, 0, 2)
|
||||
|
||||
r = self._decoder(future_x, z)
|
||||
# autoregressive mask
|
||||
device = next(self.parameters()).device
|
||||
N = x.shape[0]
|
||||
mask = mask_upper_triangular(N, device)
|
||||
|
||||
r = self._decoder(x, z, tgt_mask=mask)
|
||||
|
||||
# [T, B, emb_dim] -> [B, T, emb_dim]
|
||||
# [S, B, H] -> [B, S, H]
|
||||
r = r.permute(1, 0, 2).contiguous()
|
||||
|
||||
# Get the mean and the variance
|
||||
@@ -112,7 +118,7 @@ class Decoder(nn.Module):
|
||||
class TransformerProcess(nn.Module):
|
||||
# WIP trying one that encodes a dist
|
||||
# TODO autoregressive mask
|
||||
def __init__(self, x_size, y_size, hidden_size=16, latent_dim=32, nhead=8, nlayers=2, attention_dropout=0, min_std=0.01):
|
||||
def __init__(self, x_size, y_size, hidden_size=64, latent_dim=32, nhead=8, nlayers=4, dropout=0, min_std=0.01):
|
||||
super().__init__()
|
||||
self._min_std = min_std
|
||||
|
||||
@@ -121,7 +127,7 @@ class TransformerProcess(nn.Module):
|
||||
hidden_size=hidden_size,
|
||||
latent_dim=latent_dim,
|
||||
num_layers=nlayers,
|
||||
dropout=attention_dropout,
|
||||
dropout=dropout,
|
||||
min_std=min_std,
|
||||
nhead=nhead,
|
||||
)
|
||||
@@ -131,7 +137,7 @@ class TransformerProcess(nn.Module):
|
||||
y_size,
|
||||
hidden_size=hidden_size,
|
||||
latent_dim=latent_dim,
|
||||
dropout=attention_dropout,
|
||||
dropout=dropout,
|
||||
min_std=min_std,
|
||||
num_layers=nlayers,
|
||||
nhead=nhead,
|
||||
@@ -143,9 +149,6 @@ class TransformerProcess(nn.Module):
|
||||
dist_prior = self._latent_encoder(past_x, past_y)
|
||||
|
||||
if (future_y is not None):
|
||||
# If future_y is provided, we can create an auxiliary loss
|
||||
# Making sure the encoded distribution from the past
|
||||
# Is as close as possible to the future
|
||||
x = torch.cat([past_x, future_x], 1)
|
||||
y = torch.cat([past_y, future_y], 1)
|
||||
dist_post = self._latent_encoder(x, y)
|
||||
@@ -156,19 +159,19 @@ class TransformerProcess(nn.Module):
|
||||
else:
|
||||
z = dist_prior.loc
|
||||
num_targets = future_x.size(1)
|
||||
z = z.unsqueeze(1).repeat(1, num_targets, 1) # [B, T_target, H]
|
||||
z = z.unsqueeze(1).repeat(1, num_targets, 1) # [B, S_target, H]
|
||||
|
||||
dist = self._decoder(z, future_x)
|
||||
dist_out = self._decoder(z, future_x)
|
||||
loss = None
|
||||
if future_y is not None:
|
||||
log_p = dist.log_prob(future_y).mean(-1)
|
||||
# Make sure output dist matches label
|
||||
log_p = dist_out.log_prob(future_y).mean(-1)
|
||||
|
||||
# Making sure the encoded distribution from the past is as close as possible to the future
|
||||
kl_loss = torch.distributions.kl_divergence(dist_post, dist_prior).mean(
|
||||
-1
|
||||
) # [B, R].mean(-1)
|
||||
kl_loss = kl_loss[:, None].expand(log_p.shape)
|
||||
mse_loss = F.mse_loss(dist.loc, future_y, reduction="none")[
|
||||
:, : past_x.size(1)
|
||||
].mean()
|
||||
loss = (kl_loss - log_p).mean()
|
||||
return dist, {'loss': loss}
|
||||
return dist_out, {'loss': loss}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user