Files
pytorch-transformer-ts/switch/module.py
T
2022-06-06 14:11:51 +02:00

587 lines
21 KiB
Python

from typing import List, Optional, Union, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.transformer import _get_activation_fn, _get_clones
from gluonts.core.component import validated
from gluonts.time_feature import get_lags_for_frequency
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
class SwitchFeedForward(nn.Module):
    """
    ## Routing among multiple FFNs

    Top-1 ("switch") routing: a linear router scores every token, a softmax
    turns the scores into probabilities, and each token is processed only by
    the expert with the highest probability. Every expert is an independent
    (deep) copy of the `expert` module passed to the constructor.
    """

    def __init__(
        self,
        *,
        capacity_factor: float,
        drop_tokens: bool,
        is_scale_prob: bool,
        n_experts: int,
        expert: nn.Module,
        d_model: int,
        dim_feedforward: int,
    ):
        """
        * `capacity_factor` is the capacity of each expert as a factor relative to ideally balanced load
        * `drop_tokens` specifies whether to drop tokens if more tokens are routed to an expert than the capacity
        * `is_scale_prob` specifies whether to multiply the expert output by the routing probability
        * `n_experts` is the number of experts
        * `expert` is the expert layer; it is copied `n_experts` times and is expected
          to map `d_model` input features to `dim_feedforward` output features
        * `d_model` is the number of features in a token embedding
        * `dim_feedforward` is the number of output features of each expert
        """
        super().__init__()
        self.capacity_factor = capacity_factor
        self.is_scale_prob = is_scale_prob
        self.n_experts = n_experts
        self.drop_tokens = drop_tokens
        self.dim_feedforward = dim_feedforward
        # make copies of the FFNs
        self.experts = _get_clones(expert, n_experts)
        # Routing layer and softmax
        self.switch = nn.Linear(d_model, n_experts)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor):
        """
        Route every token of `x` to a single expert.

        * `x` is the input to the switching module with shape `[batch_size, seq_len, d_model]`

        Returns a tuple of

        * the output, shape `[batch_size, seq_len, dim_feedforward]`
        * counts: number of tokens routed to each expert (before dropping), shape `[n_experts]`
        * sum of routing probabilities over all tokens for each expert, shape `[n_experts]`
        * number of tokens dropped (a Python `int`)
        * routing probability of the selected expert for every token, shape `[batch_size * seq_len]`

        The statistics are intended for a load-balancing loss and for logging.
        """
        # Capture the shape to change shapes later
        batch_size, seq_len, d_model = x.shape
        # Flatten the sequence and batch dimensions; `reshape` (unlike `view`)
        # also accepts non-contiguous inputs.
        x = x.reshape(-1, d_model)
        # Routing probabilities for each token:
        # $$p_i(x) = \frac{e^{h(x)_i}}{\sum^N_j e^{h(x)_j}}$$
        # where $N$ is `n_experts` and $h(\cdot)$ is the router `self.switch`.
        route_prob = self.softmax(self.switch(x))
        # Top-1 routing: each token goes to its highest-probability expert.
        route_prob_max, routes = torch.max(route_prob, dim=-1)
        # Indexes of the tokens assigned to each expert
        indexes_list = [
            torch.eq(routes, i).nonzero(as_tuple=True)[0]
            for i in range(self.n_experts)
        ]
        # Rows of tokens that no expert processes remain zero.
        final_output = x.new_zeros((batch_size * seq_len, self.dim_feedforward))
        # Capacity of each expert.
        # $$\mathrm{expert\;capacity} =
        # \frac{\mathrm{tokens\;per\;batch}}{\mathrm{number\;of\;experts}}
        # \times \mathrm{capacity\;factor}$$
        capacity = int(self.capacity_factor * len(x) / self.n_experts)
        # Number of tokens routed to each expert (counted before dropping).
        counts = x.new_tensor([len(indexes) for indexes in indexes_list])

        dropped = []
        # Only drop tokens if `drop_tokens` is `True`.
        if self.drop_tokens:
            for i in range(self.n_experts):
                indexes = indexes_list[i]
                # Ignore if the expert is not over capacity
                if len(indexes) <= capacity:
                    continue
                # Shuffle so that a random subset (not the earliest tokens) is dropped.
                indexes = indexes[
                    torch.randperm(len(indexes), device=indexes.device)
                ]
                dropped.append(indexes[capacity:])
                # Keep only the tokens up to the capacity of the expert
                indexes_list[i] = indexes[:capacity]

        # Run every expert on its tokens and scatter the results back.
        for expert, indexes in zip(self.experts, indexes_list):
            final_output[indexes, :] = expert(x[indexes, :])

        n_dropped = 0
        if dropped:
            dropped_indexes = torch.cat(dropped)
            n_dropped = len(dropped_indexes)
            # Dropped tokens bypass the experts. Copying the input through is
            # only shape-compatible when the expert output width equals
            # `d_model`; otherwise the dropped rows are left at zero, so any
            # residual connection around this module carries those tokens.
            if self.dim_feedforward == d_model:
                final_output[dropped_indexes, :] = x[dropped_indexes, :]

        if self.is_scale_prob:
            # Multiply the expert outputs by the probabilities $y = p_i(x) E_i(x)$
            final_output = final_output * route_prob_max.view(-1, 1)
        else:
            # Don't scale the values but multiply by $\frac{p}{\hat{p}} = 1$ so
            # that gradients still flow into the router.
            final_output = final_output * (
                route_prob_max / route_prob_max.detach()
            ).view(-1, 1)

        # Restore `[batch_size, seq_len, dim_feedforward]`
        final_output = final_output.view(batch_size, seq_len, -1)
        return final_output, counts, route_prob.sum(0), n_dropped, route_prob_max
class SwitchTransformerEncoderLayer(nn.Module):
    """
    Transformer encoder layer whose first feed-forward projection is a
    `SwitchFeedForward` (top-1 routing over `n_experts` copies of a
    `d_model -> dim_feedforward` linear layer).

    The layer consists of a self-attention sub-block and a feed-forward
    sub-block (switch projection -> activation -> dropout -> linear). Each
    sub-block has dropout, a residual connection and layer normalization,
    applied before (`norm_first=True`) or after (`norm_first=False`) it.
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        capacity_factor: float,
        drop_tokens: bool,
        is_scale_prob: bool,
        n_experts: int = 1,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[torch.Tensor], torch.Tensor]] = F.relu,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = True,
        norm_first: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(SwitchTransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs
        )
        # Feed-forward: the first projection is a mixture of linear experts.
        linear = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.linear1 = SwitchFeedForward(
            capacity_factor=capacity_factor,
            drop_tokens=drop_tokens,
            is_scale_prob=is_scale_prob,
            n_experts=n_experts,
            expert=linear,
            d_model=d_model,
            dim_feedforward=dim_feedforward,
        )
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def __setstate__(self, state):
        # State pickled without an `activation` entry falls back to ReLU.
        if "activation" not in state:
            state["activation"] = F.relu
        super(SwitchTransformerEncoderLayer, self).__setstate__(state)

    def forward(
        self,
        src: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Apply the self-attention and feed-forward sub-blocks to `src`.

        * `src_mask` is the attention mask passed to self-attention
        * `src_key_padding_mask` marks padded key positions
        * `is_causal` is accepted because `nn.TransformerEncoder` (PyTorch 2.x)
          passes it to every layer. It is only a hint that `src_mask` is a
          causal mask; `src_mask` is always applied as given, so the hint
          does not change the result. As in PyTorch, the hint requires a mask.
        """
        if is_causal and src_mask is None:
            raise RuntimeError(
                "Need src_mask if specifying the is_causal hint."
            )
        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._ff_block(x))
        return x

    # self-attention block
    def _sa_block(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        # Only the routed output is used; the routing statistics returned by
        # the switch layer are discarded here.
        x, _, _, _, _ = self.linear1(x)
        x = self.linear2(self.dropout(self.activation(x)))
        return self.dropout2(x)
class SwitchTransformerModel(nn.Module):
    """
    Encoder-decoder transformer for probabilistic time-series forecasting.

    The encoder is a stack of `num_encoder_layers` `SwitchTransformerEncoderLayer`
    modules followed by a `LayerNorm`. The decoder is the standard
    `nn.Transformer` decoder. Decoder outputs are projected to the parameters of
    `distr_output`.

    For every time step, the network input is the concatenation of the lagged,
    scaled target values, the static features (categorical embeddings, static
    reals and `log(scale)`) and the dynamic time features.
    """

    @validated()
    def __init__(
        self,
        freq: str,
        context_length: int,
        prediction_length: int,
        num_feat_dynamic_real: int,
        num_feat_static_real: int,
        num_feat_static_cat: int,
        cardinality: List[int],
        # switch transformer arguments
        nhead: int,
        num_encoder_layers: int,
        num_decoder_layers: int,
        dim_feedforward: int,
        capacity_factor: float,
        activation: str = "gelu",
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-5,
        drop_tokens: bool = False,
        is_scale_prob: bool = True,
        n_experts: int = 1,
        # univariate input
        input_size: int = 1,
        embedding_dimension: Optional[List[int]] = None,
        distr_output: DistributionOutput = StudentTOutput(),
        lags_seq: Optional[List[int]] = None,
        scaling: bool = True,
        num_parallel_samples: int = 100,
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.target_shape = distr_output.event_shape
        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        # Default embedding size per categorical feature: min(50, (cardinality + 1) // 2)
        self.embedding_dimension = (
            embedding_dimension
            if embedding_dimension is not None or cardinality is None
            else [min(50, (cat + 1) // 2) for cat in cardinality]
        )
        self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
        self.num_parallel_samples = num_parallel_samples
        self.history_length = context_length + max(self.lags_seq)
        self.embedder = FeatureEmbedder(
            cardinalities=cardinality,
            embedding_dims=self.embedding_dimension,
        )
        if scaling:
            self.scaler = MeanScaler(dim=1, keepdim=True)
        else:
            self.scaler = NOPScaler(dim=1, keepdim=True)

        # total feature size: lagged targets + static/dynamic features
        d_model = self.input_size * len(self.lags_seq) + self._number_of_features

        self.context_length = context_length
        self.prediction_length = prediction_length

        self.distr_output = distr_output
        self.param_proj = distr_output.get_args_proj(d_model)

        # switch-transformer encoder
        switch_encoder_layer = SwitchTransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            capacity_factor=capacity_factor,
            drop_tokens=drop_tokens,
            is_scale_prob=is_scale_prob,
            n_experts=n_experts,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
        )
        switch_encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
        switch_encoder = nn.TransformerEncoder(
            switch_encoder_layer, num_encoder_layers, switch_encoder_norm
        )

        # vanilla decoder; the switch encoder is plugged in as custom encoder
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            custom_encoder=switch_encoder,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            batch_first=True,
        )

        # causal decoder tgt mask (used in training and, sliced, in prediction)
        self.register_buffer(
            "tgt_mask",
            nn.Transformer.generate_square_subsequent_mask(prediction_length),
        )

    @property
    def _number_of_features(self) -> int:
        """Width of the non-lag features appended to every time step."""
        return (
            sum(self.embedding_dimension)
            + self.num_feat_dynamic_real
            + self.num_feat_static_real
            + 1  # the log(scale)
        )

    @property
    def _past_length(self) -> int:
        """Length of past target required: context plus the largest lag."""
        return self.context_length + max(self.lags_seq)

    def get_lagged_subsequences(
        self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
    ) -> torch.Tensor:
        """
        Returns lagged subsequences of a given sequence.

        Parameters
        ----------
        sequence : Tensor
            the sequence from which lagged subsequences should be extracted.
            Shape: (N, T, C).
        subsequences_length : int
            length of the subsequences to be extracted.
        shift: int
            shift the lags by this amount back.

        Returns
        --------
        lagged : Tensor
            a tensor of shape (N, S, C, I), where S = subsequences_length and
            I = len(indices), containing lagged subsequences. Specifically,
            lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
        """
        sequence_length = sequence.shape[1]
        indices = [lag - shift for lag in self.lags_seq]

        assert max(indices) + subsequences_length <= sequence_length, (
            f"lags cannot go further than history length, found lag {max(indices)} "
            f"while history length is only {sequence_length}"
        )

        lagged_values = []
        for lag_index in indices:
            begin_index = -lag_index - subsequences_length
            # A slice ending at -0 would be empty, so use None for lag 0.
            end_index = -lag_index if lag_index > 0 else None
            lagged_values.append(sequence[:, begin_index:end_index, ...])
        return torch.stack(lagged_values, dim=-1)

    def _check_shapes(
        self,
        prior_input: torch.Tensor,
        inputs: torch.Tensor,
        features: Optional[torch.Tensor],
    ) -> None:
        """Assert that inputs and features have the sizes this model expects."""
        assert len(prior_input.shape) == len(inputs.shape)
        assert (
            len(prior_input.shape) == 2 and self.input_size == 1
        ) or prior_input.shape[2] == self.input_size
        assert (len(inputs.shape) == 2 and self.input_size == 1) or inputs.shape[
            -1
        ] == self.input_size
        assert (
            features is None or features.shape[2] == self._number_of_features
        ), f"{features.shape[2]}, expected {self._number_of_features}"

    def create_network_inputs(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: Optional[torch.Tensor] = None,
        future_target: Optional[torch.Tensor] = None,
    ):
        """
        Build the per-time-step transformer inputs.

        Without `future_target` the inputs cover the `context_length` steps;
        with it (training) they cover `context_length + prediction_length`
        steps, and `future_time_feat` must be given as well.

        Returns `(transformer_inputs, scale, static_feat)`, where `scale` is
        computed from the observed part of the context window and the targets
        inside `transformer_inputs` are divided by it.
        """
        # time feature
        time_feat = (
            torch.cat(
                (
                    past_time_feat[:, self._past_length - self.context_length :, ...],
                    future_time_feat,
                ),
                dim=1,
            )
            if future_target is not None
            else past_time_feat[:, self._past_length - self.context_length :, ...]
        )

        # target
        context = past_target[:, -self.context_length :]
        observed_context = past_observed_values[:, -self.context_length :]
        _, scale = self.scaler(context, observed_context)

        inputs = (
            torch.cat((past_target, future_target), dim=1) / scale
            if future_target is not None
            else past_target / scale
        )

        inputs_length = (
            self._past_length + self.prediction_length
            if future_target is not None
            else self._past_length
        )
        assert inputs.shape[1] == inputs_length

        subsequences_length = (
            self.context_length + self.prediction_length
            if future_target is not None
            else self.context_length
        )

        # embeddings
        embedded_cat = self.embedder(feat_static_cat)
        static_feat = torch.cat(
            (embedded_cat, feat_static_real, scale.log()),
            dim=1,
        )
        expanded_static_feat = static_feat.unsqueeze(1).expand(
            -1, time_feat.shape[1], -1
        )

        features = torch.cat((expanded_static_feat, time_feat), dim=-1)

        lagged_sequence = self.get_lagged_subsequences(
            sequence=inputs,
            subsequences_length=subsequences_length,
        )

        lags_shape = lagged_sequence.shape
        reshaped_lagged_sequence = lagged_sequence.reshape(
            lags_shape[0], lags_shape[1], -1
        )

        transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)

        return transformer_inputs, scale, static_feat

    def output_params(self, transformer_inputs):
        """
        Teacher-forced pass: encode the first `context_length` steps, decode
        the remaining steps under the causal `tgt_mask`, and project the
        decoder output to distribution parameters.
        """
        enc_input = transformer_inputs[:, : self.context_length, ...]
        dec_input = transformer_inputs[:, self.context_length :, ...]

        enc_out = self.transformer.encoder(enc_input)

        dec_output = self.transformer.decoder(
            dec_input, enc_out, tgt_mask=self.tgt_mask
        )

        return self.param_proj(dec_output)

    @torch.jit.ignore
    def output_distribution(
        self, params, scale=None, trailing_n=None
    ) -> torch.distributions.Distribution:
        """Build the output distribution, optionally from the last `trailing_n` steps only."""
        sliced_params = params
        if trailing_n is not None:
            sliced_params = [p[:, -trailing_n:] for p in params]
        return self.distr_output.distribution(sliced_params, scale=scale)

    # for prediction
    def forward(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: torch.Tensor,
        num_parallel_samples: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Draw sample paths for the prediction range by autoregressive sampling.

        `num_parallel_samples` overrides the value given at construction.
        Returns a tensor of shape
        `(batch, num_parallel_samples, prediction_length) + target_shape`.
        """
        if num_parallel_samples is None:
            num_parallel_samples = self.num_parallel_samples

        encoder_inputs, scale, static_feat = self.create_network_inputs(
            feat_static_cat,
            feat_static_real,
            past_time_feat,
            past_target,
            past_observed_values,
        )

        enc_out = self.transformer.encoder(encoder_inputs)

        # Replicate every series `num_parallel_samples` times along dim 0;
        # repeat_interleave keeps the copies of one series adjacent, which the
        # final reshape relies on.
        repeated_scale = scale.repeat_interleave(repeats=num_parallel_samples, dim=0)

        repeated_past_target = (
            past_target.repeat_interleave(repeats=num_parallel_samples, dim=0)
            / repeated_scale
        )

        expanded_static_feat = static_feat.unsqueeze(1).expand(
            -1, future_time_feat.shape[1], -1
        )
        features = torch.cat((expanded_static_feat, future_time_feat), dim=-1)
        repeated_features = features.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        )

        repeated_enc_out = enc_out.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        )

        future_samples = []

        # Autoregressive sampling: step k decodes the first k + 1 future positions
        # and samples the value for position k.
        for k in range(self.prediction_length):
            lagged_sequence = self.get_lagged_subsequences(
                sequence=repeated_past_target,
                subsequences_length=1 + k,
                shift=1,
            )

            lags_shape = lagged_sequence.shape
            reshaped_lagged_sequence = lagged_sequence.reshape(
                lags_shape[0], lags_shape[1], -1
            )

            decoder_input = torch.cat(
                (reshaped_lagged_sequence, repeated_features[:, : k + 1]), dim=-1
            )

            # Apply the same causal mask as in training (`output_params`),
            # restricted to the k + 1 decoded positions. Without it, earlier
            # positions attend to later ones, and with several decoder layers
            # the last position's output differs from what was trained.
            output = self.transformer.decoder(
                decoder_input,
                repeated_enc_out,
                tgt_mask=self.tgt_mask[: k + 1, : k + 1],
            )

            params = self.param_proj(output[:, -1:])
            distr = self.output_distribution(params, scale=repeated_scale)
            next_sample = distr.sample()

            repeated_past_target = torch.cat(
                (repeated_past_target, next_sample / repeated_scale), dim=1
            )
            future_samples.append(next_sample)

        concat_future_samples = torch.cat(future_samples, dim=1)
        return concat_future_samples.reshape(
            (-1, num_parallel_samples, self.prediction_length) + self.target_shape,
        )