mirror of
https://github.com/wassname/attentive-neural-processes.git
synced 2026-06-27 18:03:39 +08:00
145 lines
4.8 KiB
Python
145 lines
4.8 KiB
Python
|
|
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
from .modules import BatchMLP
|
|
|
|
|
|
def batch_first_attention(module: nn.MultiheadAttention, k, v, q, **kwargs):
    """
    Call an ``nn.MultiheadAttention`` with batch-first tensors.

    The inputs are given as [batch, seq, hidden]; they are permuted to
    [seq, batch, hidden] before the call and the attention output is
    permuted back afterwards.

    see https://pytorch.org/docs/stable/nn.html#torch.nn.MultiheadAttention

    Args:
        module: the ``nn.MultiheadAttention`` instance to run.
        k: keys, [batch, seq_k, hidden].
        v: values, [batch, seq_k, hidden].
        q: queries, [batch, seq_q, hidden].
        **kwargs: forwarded unchanged to ``module``.

    Returns:
        Tuple ``(output, weights)``: ``output`` is a contiguous batch-first
        tensor; ``weights`` is returned exactly as the module produced it.
    """
    assert isinstance(
        module, nn.MultiheadAttention
    ), f"should be nn.MultiheadAttention not {type(module)}"

    # Swap the first two axes of each input: batch-first -> sequence-first.
    seq_first_q, seq_first_k, seq_first_v = [
        tensor.permute(1, 0, 2) for tensor in (q, k, v)
    ]

    result = module(query=seq_first_q, key=seq_first_k, value=seq_first_v, **kwargs)
    seq_first_out, weights = result

    # Back to batch-first; permute returns a view, so force a dense layout.
    batch_first_out = seq_first_out.permute(1, 0, 2).contiguous()
    return batch_first_out, weights
|
|
|
|
class AttnLinear(nn.Module):
    """Bias-free linear projection used by the attention heads.

    The weight is drawn from a zero-mean normal distribution with standard
    deviation ``in_channels ** -0.5``.
    """

    def __init__(self, in_channels, out_channels):
        """Create the projection from ``in_channels`` to ``out_channels``."""
        super().__init__()
        projection = nn.Linear(in_channels, out_channels, bias=False)
        self.linear = projection
        # Overwrite the default nn.Linear initialisation with a scaled normal.
        init_std = in_channels ** -0.5
        nn.init.normal_(projection.weight, std=init_std)

    def forward(self, x):
        """Apply the linear projection to the last dimension of ``x``."""
        return self.linear(x)
|
|
|
|
|
|
class Attention(nn.Module):
    """Aggregate context values ``v`` for each query ``q`` using keys ``k``.

    Keys and queries are optionally passed through a learned representation
    (``rep``) and then combined by one of several attention mechanisms
    (``attention_type``).

    Args:
        hidden_dim: width of the key/query representation and of the
            projections used by the multihead variants.
        attention_type: one of ``"uniform"``, ``"laplace"``, ``"dot"``,
            ``"multihead"`` or ``"ptmultihead"``.
        attention_layers: layer count passed to the representation network.
        n_heads: number of heads for ``"multihead"`` / ``"ptmultihead"``.
        x_dim: input feature size of the raw keys/queries.
        rep: ``"mlp"`` (separate BatchMLP for keys and queries), ``"lstm"``
            (one shared LSTMBlock), or any other value to use keys and
            queries as given.
        dropout: forwarded to the representation network and to
            ``nn.MultiheadAttention``.
        batchnorm: forwarded to the representation network.

    Raises:
        NotImplementedError: if ``attention_type`` is not recognised.
    """

    def __init__(
        self,
        hidden_dim,
        attention_type,
        attention_layers=2,
        n_heads=8,
        x_dim=1,
        rep="mlp",
        dropout=0,
        batchnorm=False,
    ):
        super().__init__()
        self._rep = rep

        if self._rep == "mlp":
            self.batch_mlp_k = BatchMLP(
                x_dim,
                hidden_dim,
                attention_layers,
                dropout=dropout,
                batchnorm=batchnorm,
            )
            self.batch_mlp_q = BatchMLP(
                x_dim,
                hidden_dim,
                attention_layers,
                dropout=dropout,
                batchnorm=batchnorm,
            )
        elif self._rep == "lstm":
            # The module header only imports BatchMLP, so LSTMBlock was an
            # undefined name here. Import it locally.
            # NOTE(review): assumes LSTMBlock lives in .modules next to
            # BatchMLP - confirm.
            from .modules import LSTMBlock

            self._lstm = LSTMBlock(
                x_dim,
                hidden_dim,
                batchnorm=batchnorm,
                dropout=dropout,
                num_layers=attention_layers,
            )

        if attention_type == "uniform":
            self._attention_func = self._uniform_attention
        elif attention_type == "laplace":
            self._attention_func = self._laplace_attention
        elif attention_type == "dot":
            self._attention_func = self._dot_attention
        elif attention_type == "multihead":
            # One full-width projection per head for keys, values and queries.
            self._W_k = nn.ModuleList(
                [AttnLinear(hidden_dim, hidden_dim) for _ in range(n_heads)]
            )
            self._W_v = nn.ModuleList(
                [AttnLinear(hidden_dim, hidden_dim) for _ in range(n_heads)]
            )
            self._W_q = nn.ModuleList(
                [AttnLinear(hidden_dim, hidden_dim) for _ in range(n_heads)]
            )
            # Merges the concatenated head outputs back to hidden_dim.
            self._W = AttnLinear(n_heads * hidden_dim, hidden_dim)
            self._attention_func = self._multihead_attention
            self.n_heads = n_heads
        elif attention_type == "ptmultihead":
            self._W = torch.nn.MultiheadAttention(
                hidden_dim, n_heads, bias=False, dropout=dropout
            )
            self._attention_func = self._pytorch_multihead_attention
        else:
            raise NotImplementedError(
                f"unknown attention_type: {attention_type!r}"
            )

    def forward(self, k, v, q):
        """Encode keys/queries according to ``rep`` and apply the attention.

        Args:
            k: keys, expected as [batch, n_context, x_dim].
            v: values, expected as [batch, n_context, value_dim].
            q: queries, expected as [batch, n_target, x_dim].

        Returns:
            The aggregated representation, one row per query.
        """
        if self._rep == "mlp":
            k = self.batch_mlp_k(k)
            q = self.batch_mlp_q(q)
        elif self._rep == "lstm":
            # The module is stored as ``_lstm`` in __init__; the previous
            # ``self.batch_lstm`` lookup raised AttributeError.
            k = self._lstm(k)
            q = self._lstm(q)
        return self._attention_func(k, v, q)

    def _uniform_attention(self, k, v, q):
        """Give every query the plain mean of the values (k is ignored)."""
        total_points = q.shape[1]
        rep = torch.mean(v, dim=1, keepdim=True)
        rep = rep.repeat(1, total_points, 1)
        return rep

    def _laplace_attention(self, k, v, q, scale=0.5):
        """Laplace-kernel attention: weights decay with L1 key-query distance.

        ``scale`` multiplies the distance, so larger values give sharper
        attention.
        """
        k_ = k.unsqueeze(1)  # [batch, 1, n_context, dim]
        q_ = q.unsqueeze(2)  # [batch, n_target, 1, dim]
        # Negative distance so that closer keys get larger weights. The
        # distance must be taken against the queries, not the values.
        unnorm_weights = -torch.abs((k_ - q_) * scale)
        unnorm_weights = unnorm_weights.sum(dim=-1)  # [batch, n_target, n_context]
        weights = torch.softmax(unnorm_weights, dim=-1)
        rep = torch.einsum("bik,bkj->bij", weights, v)
        return rep

    def _dot_attention(self, k, v, q):
        """Scaled dot-product attention over the context points."""
        scale = q.shape[-1] ** 0.5
        unnorm_weights = torch.einsum("bjk,bik->bij", k, q) / scale
        weights = torch.softmax(unnorm_weights, dim=-1)

        rep = torch.einsum("bik,bkj->bij", weights, v)
        return rep

    def _multihead_attention(self, k, v, q):
        """Run dot attention once per head and merge the heads linearly."""
        outs = []
        for i in range(self.n_heads):
            k_ = self._W_k[i](k)
            v_ = self._W_v[i](v)
            q_ = self._W_q[i](q)
            out = self._dot_attention(k_, v_, q_)
            outs.append(out)
        # Stack heads on a new last axis, then flatten it into the features.
        outs = torch.stack(outs, dim=-1)
        outs = outs.view(outs.shape[0], outs.shape[1], -1)
        rep = self._W(outs)
        return rep

    def _pytorch_multihead_attention(self, k, v, q):
        """Delegate to ``nn.MultiheadAttention`` and return only its output."""
        # PyTorch multihead attention takes its inputs in a different order
        # and axis layout; batch_first_attention handles the permutation.
        return batch_first_attention(self._W, q=q, k=k, v=v)[0]
|