Files
attentive-neural-processes/neural_processes/modules/attention.py
T
wassname a1c26dfbb7 logger
2020-04-26 12:36:19 +08:00

145 lines
4.8 KiB
Python

import torch
from torch import nn
import torch.nn.functional as F
from .modules import BatchMLP
def batch_first_attention(module: nn.MultiheadAttention, k, v, q, **kwargs):
    """
    Run ``nn.MultiheadAttention`` on batch-first tensors.

    ``k``, ``v`` and ``q`` are given as [batch, seq, hidden]; they are permuted
    to the [seq, batch, hidden] layout before calling ``module``, and the
    attention output is permuted back to [batch, seq, hidden].
    see https://pytorch.org/docs/stable/nn.html#torch.nn.MultiheadAttention

    Extra ``kwargs`` are forwarded unchanged to ``module``.

    Returns a tuple ``(attn_output, attn_output_weights)``; the weights are
    returned exactly as ``module`` produced them.
    """
    assert isinstance(
        module, nn.MultiheadAttention
    ), f"should be nn.MultiheadAttention not {type(module)}"

    # Swap the first two axes of every input: batch-first -> sequence-first.
    q, k, v = (tensor.permute(1, 0, 2) for tensor in (q, k, v))

    result = module(query=q, key=k, value=v, **kwargs)
    seq_first_output, weights = result

    # Back to batch-first; contiguous() so callers can safely .view() it.
    batch_first_output = seq_first_output.permute(1, 0, 2).contiguous()
    return batch_first_output, weights
class AttnLinear(nn.Module):
    """
    Bias-free linear projection used by the attention heads.

    The weight is re-initialised from a zero-mean normal distribution with
    standard deviation ``in_channels ** -0.5``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        projection = nn.Linear(in_channels, out_channels, bias=False)
        self.linear = projection
        # Overwrite the default nn.Linear init with N(0, 1 / in_channels).
        init_std = in_channels ** -0.5
        torch.nn.init.normal_(projection.weight, std=init_std)

    def forward(self, x):
        """Apply the projection to the last dimension of ``x``."""
        return self.linear(x)
class Attention(nn.Module):
    """
    Attention block mapping keys, values and queries to one representation
    per query.

    Tensor layout (fixed by the einsum patterns used below):
        k: [batch, n_keys, d_k]
        v: [batch, n_keys, d_v]
        q: [batch, n_queries, d_k]
        returns: [batch, n_queries, d_v] (``hidden_dim`` wide for the
        multihead variants).

    Args:
        hidden_dim: width of the key/query embedding and of the multihead
            projections.
        attention_type: one of "uniform", "laplace", "dot", "multihead",
            "ptmultihead". Anything else raises ``NotImplementedError``.
        attention_layers: number of layers passed to the key/query encoder.
        n_heads: number of heads for "multihead" / "ptmultihead".
        x_dim: feature size of the raw keys/queries fed to the encoder.
        rep: "mlp" encodes k and q with two separate ``BatchMLP``s, "lstm"
            encodes both with one shared ``LSTMBlock``; any other value
            leaves k and q untouched.
        dropout: dropout passed to the encoders and to
            ``torch.nn.MultiheadAttention``.
        batchnorm: batchnorm flag passed to the encoders.
    """

    def __init__(
        self,
        hidden_dim,
        attention_type,
        attention_layers=2,
        n_heads=8,
        x_dim=1,
        rep="mlp",
        dropout=0,
        batchnorm=False,
    ):
        super().__init__()
        self._rep = rep

        if self._rep == "mlp":
            self.batch_mlp_k = BatchMLP(
                x_dim,
                hidden_dim,
                attention_layers,
                dropout=dropout,
                batchnorm=batchnorm,
            )
            self.batch_mlp_q = BatchMLP(
                x_dim,
                hidden_dim,
                attention_layers,
                dropout=dropout,
                batchnorm=batchnorm,
            )
        elif self._rep == "lstm":
            # LSTMBlock is not imported at the top of this file, so the name
            # was unresolved here. Presumably it lives next to BatchMLP in
            # .modules -- TODO confirm. Imported locally so a wrong guess only
            # affects the "lstm" path.
            from .modules import LSTMBlock

            self._lstm = LSTMBlock(
                x_dim,
                hidden_dim,
                batchnorm=batchnorm,
                dropout=dropout,
                num_layers=attention_layers,
            )

        if attention_type == "uniform":
            self._attention_func = self._uniform_attention
        elif attention_type == "laplace":
            self._attention_func = self._laplace_attention
        elif attention_type == "dot":
            self._attention_func = self._dot_attention
        elif attention_type == "multihead":
            # One independent key/value/query projection per head, plus a
            # final projection over the concatenated head outputs.
            self._W_k = nn.ModuleList(
                [AttnLinear(hidden_dim, hidden_dim) for _ in range(n_heads)]
            )
            self._W_v = nn.ModuleList(
                [AttnLinear(hidden_dim, hidden_dim) for _ in range(n_heads)]
            )
            self._W_q = nn.ModuleList(
                [AttnLinear(hidden_dim, hidden_dim) for _ in range(n_heads)]
            )
            self._W = AttnLinear(n_heads * hidden_dim, hidden_dim)
            self._attention_func = self._multihead_attention
            self.n_heads = n_heads
        elif attention_type == "ptmultihead":
            self._W = torch.nn.MultiheadAttention(
                hidden_dim, n_heads, bias=False, dropout=dropout
            )
            self._attention_func = self._pytorch_multihead_attention
        else:
            raise NotImplementedError(
                f"unknown attention_type {attention_type!r}"
            )

    def forward(self, k, v, q):
        """Encode k and q according to ``rep`` and apply the chosen attention."""
        if self._rep == "mlp":
            k = self.batch_mlp_k(k)
            q = self.batch_mlp_q(q)
        elif self._rep == "lstm":
            # __init__ registers the encoder as ``_lstm``; the previous
            # ``self.batch_lstm`` attribute never existed.
            k = self._lstm(k)
            q = self._lstm(q)
        rep = self._attention_func(k, v, q)
        return rep

    def _uniform_attention(self, k, v, q):
        """Give every query the mean of all values (k is unused)."""
        total_points = q.shape[1]
        rep = torch.mean(v, dim=1, keepdim=True)  # [B, 1, d_v]
        rep = rep.repeat(1, total_points, 1)  # [B, m, d_v]
        return rep

    def _laplace_attention(self, k, v, q, scale=0.5):
        """
        Laplace-kernel attention: weights decay with the L1 distance between
        each query and each key.

        ``scale`` multiplies the distance before the softmax, so larger values
        give sharper attention.
        """
        k_ = k.unsqueeze(1)  # [B, 1, n, d_k]
        q_ = q.unsqueeze(2)  # [B, m, 1, d_k]
        # Negated so that closer keys receive larger weights after softmax.
        unnorm_weights = -torch.abs((k_ - q_) * scale)  # [B, m, n, d_k]
        unnorm_weights = unnorm_weights.sum(dim=-1)  # [B, m, n]
        weights = torch.softmax(unnorm_weights, dim=-1)
        rep = torch.einsum("bik,bkj->bij", weights, v)  # [B, m, d_v]
        return rep

    def _dot_attention(self, k, v, q):
        """Scaled dot-product attention."""
        scale = q.shape[-1] ** 0.5
        unnorm_weights = torch.einsum("bjk,bik->bij", k, q) / scale  # [B, m, n]
        weights = torch.softmax(unnorm_weights, dim=-1)
        rep = torch.einsum("bik,bkj->bij", weights, v)  # [B, m, d_v]
        return rep

    def _multihead_attention(self, k, v, q):
        """Dot attention per head on projected inputs, then a joint projection."""
        outs = []
        for i in range(self.n_heads):
            k_ = self._W_k[i](k)
            v_ = self._W_v[i](v)
            q_ = self._W_q[i](q)
            out = self._dot_attention(k_, v_, q_)
            outs.append(out)
        outs = torch.stack(outs, dim=-1)  # [B, m, hidden, heads]
        outs = outs.view(outs.shape[0], outs.shape[1], -1)  # [B, m, hidden * heads]
        rep = self._W(outs)
        return rep

    def _pytorch_multihead_attention(self, k, v, q):
        # PyTorch multihead attention takes its inputs in a different order
        # and layout (sequence-first); batch_first_attention handles both.
        return batch_first_attention(self._W, q=q, k=k, v=v)[0]