mirror of
https://github.com/wassname/seq2seq-time.git
synced 2026-06-27 19:00:55 +08:00
docstring, xattn working
This commit is contained in:
@@ -121,8 +121,12 @@ class Decoder(nn.Module):
|
||||
return dist
|
||||
|
||||
class TransformerProcess(nn.Module):
|
||||
# WIP trying one that encodes a dist
|
||||
# TODO autoregressive mask
|
||||
"""
|
||||
An attempted simplification of an attentive neural process
|
||||
|
||||
Works on sequential data, has no deterministic encoder. Uses full transformer layer instead of custom attention. Has an autoregressive mask on the encoder and decoder.
|
||||
|
||||
"""
|
||||
def __init__(self, x_size, y_size, hidden_size=64, latent_dim=32, nhead=8, nlayers=4, dropout=0, min_std=0.01):
|
||||
super().__init__()
|
||||
self._min_std = min_std
|
||||
|
||||
@@ -6,7 +6,7 @@ from ..util import mask_upper_triangular
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
"""
|
||||
A single transformer, masking nan or 0
|
||||
A single transformer, using cross attention, like in the deterministic encoder in attentive neural processes.
|
||||
"""
|
||||
def __init__(self, x_dim, y_dim, attention_dropout=0, nhead=8, nlayers=8, hidden_size=32, nan_value=0, min_std=0.01):
|
||||
super().__init__()
|
||||
@@ -14,7 +14,22 @@ class CrossAttention(nn.Module):
|
||||
self.nan_value = nan_value
|
||||
enc_x_dim = x_dim + y_dim
|
||||
|
||||
self.enc_emb = nn.Linear(enc_x_dim, hidden_size)
|
||||
self.v_encoder = nn.Linear(enc_x_dim, hidden_size)
|
||||
self.k_encoder = nn.Linear(x_dim, hidden_size)
|
||||
self.q_encoder = nn.Linear(x_dim, hidden_size)
|
||||
self.self_attn_k = torch.nn.MultiheadAttention(
|
||||
hidden_size, nhead, bias=False, dropout=attention_dropout
|
||||
)
|
||||
self.self_attn_q = torch.nn.MultiheadAttention(
|
||||
hidden_size, nhead, bias=False, dropout=attention_dropout
|
||||
)
|
||||
self.self_attn_v = torch.nn.MultiheadAttention(
|
||||
hidden_size, nhead, bias=False, dropout=attention_dropout
|
||||
)
|
||||
self.cross_attn = torch.nn.MultiheadAttention(
|
||||
hidden_size, nhead, bias=False, dropout=attention_dropout
|
||||
)
|
||||
|
||||
encoder_norm = nn.LayerNorm(hidden_size)
|
||||
layer_enc = nn.TransformerEncoderLayer(
|
||||
d_model=hidden_size,
|
||||
@@ -23,7 +38,7 @@ class CrossAttention(nn.Module):
|
||||
nhead=nhead,
|
||||
# activation
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
self.transformer = nn.TransformerEncoder(
|
||||
layer_enc, num_layers=nlayers, norm=encoder_norm
|
||||
)
|
||||
self.mean = nn.Linear(hidden_size, y_dim)
|
||||
@@ -31,35 +46,34 @@ class CrossAttention(nn.Module):
|
||||
|
||||
def forward(self, past_x, past_y, future_x, future_y=None):
|
||||
device = next(self.parameters()).device
|
||||
B, S, _ = future_x.shape
|
||||
future_y_fake = past_y[:, -1:, :].repeat(1, S, 1).to(device)
|
||||
# future_y_fake = (
|
||||
# torch.ones(past_y.shape[0], future_x.shape[1], past_y.shape[2]).float().to(device) * past_y[:, -1].repeat(B, S, 1)
|
||||
# )
|
||||
context = torch.cat([past_x, past_y], -1).detach()
|
||||
target = torch.cat([future_x, future_y_fake], -1).detach()
|
||||
x = torch.cat([context, target * 1], 1).detach()
|
||||
|
||||
# Masks
|
||||
x_mask = torch.isfinite(x) & (x != self.nan_value)
|
||||
x[~x_mask] = 0
|
||||
x = x.detach()
|
||||
x_key_padding_mask = ~x_mask.any(-1)
|
||||
B, C, _ = past_x.shape
|
||||
past_causal_mask = mask_upper_triangular(C, device)
|
||||
B, T, _ = future_x.shape
|
||||
future_causal_mask = mask_upper_triangular(T, device)
|
||||
|
||||
x = self.enc_emb(x).permute(1, 0, 2)
|
||||
# Change feature size
|
||||
k = self.k_encoder(past_x).permute(1, 0, 2)
|
||||
q = self.q_encoder(future_x).permute(1, 0, 2)
|
||||
v = self.v_encoder(context).permute(1, 0, 2)
|
||||
|
||||
S, B, _ = x.shape
|
||||
mask = mask_upper_triangular(S, device)
|
||||
|
||||
outputs = self.encoder(x, mask=mask#, src_key_padding_mask=x_key_padding_mask
|
||||
).permute(
|
||||
1, 0, 2
|
||||
)
|
||||
# Self attention with causal mask
|
||||
v = self.self_attn_v(v, v, v, attn_mask=past_causal_mask)[0]
|
||||
q = self.self_attn_q(q, q, q, attn_mask=future_causal_mask)[0]
|
||||
k = self.self_attn_k(k, k, k, attn_mask=past_causal_mask)[0]
|
||||
|
||||
# Seems to help a little, especially with extrapolating out of bounds
|
||||
steps = past_y.shape[1]
|
||||
mean = self.mean(outputs)[:, steps:, :]
|
||||
log_sigma = self.std(outputs)[:, steps:, :]
|
||||
# Cross attention
|
||||
h = self.cross_attn(query=q, key=k, value=v)[0]
|
||||
|
||||
# Transformer
|
||||
outputs = self.transformer(h, mask=future_causal_mask)
|
||||
outputs = outputs.permute(1, 0, 2)
|
||||
|
||||
# Head
|
||||
mean = self.mean(outputs)
|
||||
log_sigma = self.std(outputs)
|
||||
|
||||
sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
|
||||
return torch.distributions.Normal(mean, sigma), {}
|
||||
|
||||
Reference in New Issue
Block a user