diff --git a/seq2seq_time/models/transformer_process.py b/seq2seq_time/models/transformer_process.py index f0a53b0..0614016 100644 --- a/seq2seq_time/models/transformer_process.py +++ b/seq2seq_time/models/transformer_process.py @@ -121,8 +121,12 @@ class Decoder(nn.Module): return dist class TransformerProcess(nn.Module): - # WIP trying one that encodes a dist - # TODO autoregressive mask + """ + An attempted simplification of an attentive neural process + + Works on sequential data, has no deterministic encoder. Uses full transformer layer instead of custom attention. Has an autoregressive mask on the encoder and decoder. + + """ def __init__(self, x_size, y_size, hidden_size=64, latent_dim=32, nhead=8, nlayers=4, dropout=0, min_std=0.01): super().__init__() self._min_std = min_std diff --git a/seq2seq_time/models/xattention.py b/seq2seq_time/models/xattention.py index 01fbee0..2f6663c 100644 --- a/seq2seq_time/models/xattention.py +++ b/seq2seq_time/models/xattention.py @@ -6,7 +6,7 @@ from ..util import mask_upper_triangular class CrossAttention(nn.Module): """ - A single transformer, masking nan or 0 + A single transformer, using cross attention, like in the deterministic encoder in attentive neural processes. 
""" def __init__(self, x_dim, y_dim, attention_dropout=0, nhead=8, nlayers=8, hidden_size=32, nan_value=0, min_std=0.01): super().__init__() @@ -14,7 +14,22 @@ class CrossAttention(nn.Module): self.nan_value = nan_value enc_x_dim = x_dim + y_dim - self.enc_emb = nn.Linear(enc_x_dim, hidden_size) + self.v_encoder = nn.Linear(enc_x_dim, hidden_size) + self.k_encoder = nn.Linear(x_dim, hidden_size) + self.q_encoder = nn.Linear(x_dim, hidden_size) + self.self_attn_k = torch.nn.MultiheadAttention( + hidden_size, nhead, bias=False, dropout=attention_dropout + ) + self.self_attn_q = torch.nn.MultiheadAttention( + hidden_size, nhead, bias=False, dropout=attention_dropout + ) + self.self_attn_v = torch.nn.MultiheadAttention( + hidden_size, nhead, bias=False, dropout=attention_dropout + ) + self.cross_attn = torch.nn.MultiheadAttention( + hidden_size, nhead, bias=False, dropout=attention_dropout + ) + encoder_norm = nn.LayerNorm(hidden_size) layer_enc = nn.TransformerEncoderLayer( d_model=hidden_size, @@ -23,7 +38,7 @@ class CrossAttention(nn.Module): nhead=nhead, # activation ) - self.encoder = nn.TransformerEncoder( + self.transformer = nn.TransformerEncoder( layer_enc, num_layers=nlayers, norm=encoder_norm ) self.mean = nn.Linear(hidden_size, y_dim) @@ -31,35 +46,34 @@ class CrossAttention(nn.Module): def forward(self, past_x, past_y, future_x, future_y=None): device = next(self.parameters()).device - B, S, _ = future_x.shape - future_y_fake = past_y[:, -1:, :].repeat(1, S, 1).to(device) - # future_y_fake = ( - # torch.ones(past_y.shape[0], future_x.shape[1], past_y.shape[2]).float().to(device) * past_y[:, -1].repeat(B, S, 1) - # ) context = torch.cat([past_x, past_y], -1).detach() - target = torch.cat([future_x, future_y_fake], -1).detach() - x = torch.cat([context, target * 1], 1).detach() # Masks - x_mask = torch.isfinite(x) & (x != self.nan_value) - x[~x_mask] = 0 - x = x.detach() - x_key_padding_mask = ~x_mask.any(-1) + B, C, _ = past_x.shape + past_causal_mask = 
mask_upper_triangular(C, device) + B, T, _ = future_x.shape + future_causal_mask = mask_upper_triangular(T, device) - x = self.enc_emb(x).permute(1, 0, 2) + # Change feature size + k = self.k_encoder(past_x).permute(1, 0, 2) + q = self.q_encoder(future_x).permute(1, 0, 2) + v = self.v_encoder(context).permute(1, 0, 2) - S, B, _ = x.shape - mask = mask_upper_triangular(S, device) - - outputs = self.encoder(x, mask=mask#, src_key_padding_mask=x_key_padding_mask - ).permute( - 1, 0, 2 - ) + # Self attention with causal mask + v = self.self_attn_v(v, v, v, attn_mask=past_causal_mask)[0] + q = self.self_attn_q(q, q, q, attn_mask=future_causal_mask)[0] + k = self.self_attn_k(k, k, k, attn_mask=past_causal_mask)[0] - # Seems to help a little, especially with extrapolating out of bounds - steps = past_y.shape[1] - mean = self.mean(outputs)[:, steps:, :] - log_sigma = self.std(outputs)[:, steps:, :] + # Cross attention + h = self.cross_attn(query=q, key=k, value=v)[0] + + # Transformer + outputs = self.transformer(h, mask=future_causal_mask) + outputs = outputs.permute(1, 0, 2) + + # Head + mean = self.mean(outputs) + log_sigma = self.std(outputs) sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma) return torch.distributions.Normal(mean, sigma), {} diff --git a/seq2seq_time/silence.py b/seq2seq_time/silence.py new file mode 100644 index 0000000..9efe612 --- /dev/null +++ b/seq2seq_time/silence.py @@ -0,0 +1,3 @@ +import warnings +warnings.filterwarnings('ignore', 'Consider increasing the value of the `num_workers` argument', UserWarning) +warnings.filterwarnings('ignore', 'Your val_dataloader has `shuffle=True`', UserWarning)