docstring, xattn working

This commit is contained in:
wassname
2020-11-02 11:47:04 +08:00
parent 632fa859a0
commit cd8810f32d
3 changed files with 49 additions and 28 deletions
+6 -2
View File
@@ -121,8 +121,12 @@ class Decoder(nn.Module):
return dist
class TransformerProcess(nn.Module):
# WIP trying one that encodes a dist
# TODO autoregressive mask
"""
A attempted simplification of an attentive neural process
Works on sequential data, has no deterministic encoder. Uses full transformer layer instead of custom attention. Has an autoregressive mask on the encoder and decoder.
"""
def __init__(self, x_size, y_size, hidden_size=64, latent_dim=32, nhead=8, nlayers=4, dropout=0, min_std=0.01):
super().__init__()
self._min_std = min_std
+40 -26
View File
@@ -6,7 +6,7 @@ from ..util import mask_upper_triangular
class CrossAttention(nn.Module):
"""
A single transformer, masking nan or 0
A single transformer, using cross attention, like in the determistic encoder in attentive neural processes.
"""
def __init__(self, x_dim, y_dim, attention_dropout=0, nhead=8, nlayers=8, hidden_size=32, nan_value=0, min_std=0.01):
super().__init__()
@@ -14,7 +14,22 @@ class CrossAttention(nn.Module):
self.nan_value = nan_value
enc_x_dim = x_dim + y_dim
self.enc_emb = nn.Linear(enc_x_dim, hidden_size)
self.v_encoder = nn.Linear(enc_x_dim, hidden_size)
self.k_encoder = nn.Linear(x_dim, hidden_size)
self.q_encoder = nn.Linear(x_dim, hidden_size)
self.self_attn_k = torch.nn.MultiheadAttention(
hidden_size, nhead, bias=False, dropout=attention_dropout
)
self.self_attn_q = torch.nn.MultiheadAttention(
hidden_size, nhead, bias=False, dropout=attention_dropout
)
self.self_attn_v = torch.nn.MultiheadAttention(
hidden_size, nhead, bias=False, dropout=attention_dropout
)
self.cross_attn = torch.nn.MultiheadAttention(
hidden_size, nhead, bias=False, dropout=attention_dropout
)
encoder_norm = nn.LayerNorm(hidden_size)
layer_enc = nn.TransformerEncoderLayer(
d_model=hidden_size,
@@ -23,7 +38,7 @@ class CrossAttention(nn.Module):
nhead=nhead,
# activation
)
self.encoder = nn.TransformerEncoder(
self.transformer = nn.TransformerEncoder(
layer_enc, num_layers=nlayers, norm=encoder_norm
)
self.mean = nn.Linear(hidden_size, y_dim)
@@ -31,35 +46,34 @@ class CrossAttention(nn.Module):
def forward(self, past_x, past_y, future_x, future_y=None):
device = next(self.parameters()).device
B, S, _ = future_x.shape
future_y_fake = past_y[:, -1:, :].repeat(1, S, 1).to(device)
# future_y_fake = (
# torch.ones(past_y.shape[0], future_x.shape[1], past_y.shape[2]).float().to(device) * past_y[:, -1].repeat(B, S, 1)
# )
context = torch.cat([past_x, past_y], -1).detach()
target = torch.cat([future_x, future_y_fake], -1).detach()
x = torch.cat([context, target * 1], 1).detach()
# Masks
x_mask = torch.isfinite(x) & (x != self.nan_value)
x[~x_mask] = 0
x = x.detach()
x_key_padding_mask = ~x_mask.any(-1)
B, C, _ = past_x.shape
past_causal_mask = mask_upper_triangular(C, device)
B, T, _ = future_x.shape
future_causal_mask = mask_upper_triangular(T, device)
x = self.enc_emb(x).permute(1, 0, 2)
# Change feature size
k = self.k_encoder(past_x).permute(1, 0, 2)
q = self.q_encoder(future_x).permute(1, 0, 2)
v = self.v_encoder(context).permute(1, 0, 2)
S, B, _ = x.shape
mask = mask_upper_triangular(S, device)
outputs = self.encoder(x, mask=mask#, src_key_padding_mask=x_key_padding_mask
).permute(
1, 0, 2
)
# Self attention with causal mask
v = self.self_attn_v(v, v, v, attn_mask=past_causal_mask)[0]
q = self.self_attn_q(q, q, q, attn_mask=future_causal_mask)[0]
k = self.self_attn_k(k, k, k, attn_mask=past_causal_mask)[0]
# Seems to help a little, especially with extrapolating out of bounds
steps = past_y.shape[1]
mean = self.mean(outputs)[:, steps:, :]
log_sigma = self.std(outputs)[:, steps:, :]
# Cross attention
h = self.cross_attn(query=q, key=k, value=v)[0]
# Transformer
outputs = self.transformer(h, mask=future_causal_mask)
outputs = outputs.permute(1, 0, 2)
# Head
mean = self.mean(outputs)
log_sigma = self.std(outputs)
sigma = self._min_std + (1 - self._min_std) * F.softplus(log_sigma)
return torch.distributions.Normal(mean, sigma), {}
+3
View File
@@ -0,0 +1,3 @@
import warnings
warnings.filterwarnings('ignore', 'Consider increasing the value of the `num_workers` argument', UserWarning)
warnings.filterwarnings('ignore', 'Your val_dataloader has `shuffle=True`', UserWarning)