Files
pytorch-transformer-ts/autoformer/module.py
T
2022-05-28 18:49:50 +02:00

856 lines
28 KiB
Python

import math
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from gluonts.core.component import validated
from gluonts.time_feature import get_lags_for_frequency
from gluonts.torch.modules.distribution_output import DistributionOutput, StudentTOutput
from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
class TokenEmbedding(nn.Module):
    """Map ``c_in`` input channels per time step to ``d_model`` features.

    A bias-free, kernel-size-3 ``Conv1d`` with circular padding of 1 slides
    over the time axis, so the sequence length is preserved.
    """

    def __init__(self, c_in, d_model):
        super().__init__()
        self.tokenConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=d_model,
            kernel_size=3,
            padding=1,
            padding_mode="circular",
            bias=False,
        )
        # Re-initialise the (only) convolution with Kaiming-normal weights.
        nn.init.kaiming_normal_(
            self.tokenConv.weight, mode="fan_in", nonlinearity="leaky_relu"
        )

    def forward(self, x):
        # Channels sit in the last dim of ``x``; Conv1d wants them in dim 1.
        channels_first = x.permute(0, 2, 1)
        return self.tokenConv(channels_first).transpose(1, 2)
class DataEmbedding_wo_pos(nn.Module):
    """Sum of a value embedding and a linear time-feature embedding.

    No positional encoding is added. ``x`` goes through ``TokenEmbedding``,
    ``x_mark`` through a ``Linear`` layer; the sum is passed through dropout.
    """

    def __init__(self, x_in, x_mark_in, d_model, dropout=0.1):
        super().__init__()
        self.value_embedding = TokenEmbedding(c_in=x_in, d_model=d_model)
        self.temporal_embedding = nn.Linear(x_mark_in, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        embedded = self.value_embedding(x)
        embedded = embedded + self.temporal_embedding(x_mark)
        return self.dropout(embedded)
class my_Layernorm(nn.Module):
    """
    LayerNorm variant designed for the seasonal part.

    After a standard ``LayerNorm`` over the last dimension, the mean along
    dim 1 is subtracted so the normalized sequence is centered on that axis.
    """

    @validated()
    def __init__(self, channels):
        super().__init__()
        self.layernorm = nn.LayerNorm(channels)

    def forward(self, x):
        normed = self.layernorm(x)
        # Broadcasting the keepdim mean is equivalent to repeating it along dim 1.
        return normed - normed.mean(dim=1, keepdim=True)
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series.

    The series is edge-padded by repeating its first and last time steps
    (time is dim 1) before average pooling. The ``kernel_size - 1`` padding
    steps are split between both ends, so with ``stride=1`` the output has
    the same length as the input for odd AND even ``kernel_size``.
    """

    @validated()
    def __init__(self, kernel_size, stride):
        super().__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # For an even kernel the extra padding step goes to the end. A
        # symmetric (k - 1) // 2 split would lose one time step and break
        # ``x - moving_mean`` in series_decomp.
        pad_front = (self.kernel_size - 1) // 2
        pad_end = self.kernel_size - 1 - pad_front
        front = x[:, 0:1, :].repeat(1, pad_front, 1)
        end = x[:, -1:, :].repeat(1, pad_end, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last dim, so move time there and back.
        x = self.avg(x.permute(0, 2, 1))
        return x.permute(0, 2, 1)
class series_decomp(nn.Module):
    """
    Series decomposition block.

    ``forward(x)`` returns ``(x - trend, trend)``, where ``trend`` is the
    stride-1 moving average of ``x``.
    """

    @validated()
    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        trend = self.moving_avg(x)
        return x - trend, trend
class EncoderLayer(nn.Module):
    """
    Autoformer encoder layer with the progressive decomposition architecture.

    ``forward`` applies attention with a residual connection, a series
    decomposition, a position-wise feed-forward block (two kernel-size-1
    convolutions) with a residual connection, and a second decomposition.
    Only the first output of each decomposition is kept; trends are dropped.
    """

    @validated()
    def __init__(
        self,
        attention,
        d_model,
        d_ff=None,
        moving_avg=25,
        dropout=0.1,
        activation="relu",
    ):
        super().__init__()
        hidden_dim = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(
            in_channels=d_model, out_channels=hidden_dim, kernel_size=1, bias=False
        )
        self.conv2 = nn.Conv1d(
            in_channels=hidden_dim, out_channels=d_model, kernel_size=1, bias=False
        )
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        # Any activation name other than "relu" selects GELU.
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, attn_mask=None):
        attended, attn = self.attention(x, x, x, attn_mask=attn_mask)
        seasonal, _ = self.decomp1(x + self.dropout(attended))
        # Conv1d expects channels in dim 1, so swap the last axis with dim 1
        # around the feed-forward convolutions.
        hidden = self.activation(self.conv1(seasonal.transpose(-1, 1)))
        hidden = self.dropout(hidden)
        ff_out = self.dropout(self.conv2(hidden).transpose(-1, 1))
        out, _ = self.decomp2(seasonal + ff_out)
        return out, attn
class DecoderLayer(nn.Module):
    """
    Autoformer decoder layer with the progressive decomposition architecture.

    ``forward`` returns ``(x, residual_trend)``. ``x`` is the seasonal output
    after self-attention, cross-attention and the feed-forward block, each
    followed by a decomposition. ``residual_trend`` is the sum of the three
    extracted trends, projected to ``c_out`` channels by a circular
    kernel-size-3 convolution.
    """

    @validated()
    def __init__(
        self,
        self_attention,
        cross_attention,
        d_model,
        c_out,
        d_ff=None,
        moving_avg=25,
        dropout=0.1,
        activation="relu",
    ):
        super().__init__()
        hidden_dim = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(
            in_channels=d_model, out_channels=hidden_dim, kernel_size=1, bias=False
        )
        self.conv2 = nn.Conv1d(
            in_channels=hidden_dim, out_channels=d_model, kernel_size=1, bias=False
        )
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.decomp3 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.projection = nn.Conv1d(
            in_channels=d_model,
            out_channels=c_out,
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode="circular",
            bias=False,
        )
        # Any activation name other than "relu" selects GELU.
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        self_out = self.self_attention(x, x, x, attn_mask=x_mask)[0]
        x, trend1 = self.decomp1(x + self.dropout(self_out))

        cross_out = self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]
        x, trend2 = self.decomp2(x + self.dropout(cross_out))

        # Feed-forward block; Conv1d expects channels in dim 1.
        hidden = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        ff_out = self.dropout(self.conv2(hidden).transpose(-1, 1))
        x, trend3 = self.decomp3(x + ff_out)

        summed_trend = (trend1 + trend2 + trend3).permute(0, 2, 1)
        residual_trend = self.projection(summed_trend).transpose(1, 2)
        return x, residual_trend
class Encoder(nn.Module):
    """
    Autoformer encoder.

    Runs ``attn_layers`` in sequence, optionally interleaving
    ``conv_layers``, and applies ``norm_layer`` to the output if given.
    Returns ``(x, attns)``, where ``attns`` collects each attention layer's
    second output.
    """

    @validated()
    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super().__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = (
            nn.ModuleList(conv_layers) if conv_layers is not None else None
        )
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        attns = []
        if self.conv_layers is not None:
            # Each conv layer follows its paired attention layer, then the last
            # attention layer runs once more at the end. Presumably
            # len(attn_layers) == len(conv_layers) + 1 -- TODO confirm.
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                attns.append(attn)
            # Pass the mask here too so every attention call sees the same mask.
            x, attn = self.attn_layers[-1](x, attn_mask=attn_mask)
            attns.append(attn)
        else:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask)
                attns.append(attn)
        if self.norm is not None:
            x = self.norm(x)
        return x, attns
class Decoder(nn.Module):
    """
    Autoformer decoder.

    Runs ``layers`` in sequence, accumulating the residual trend returned by
    each layer, then applies the optional ``norm_layer`` and ``projection``
    to the seasonal output. Returns ``(x, trend)``.
    """

    @validated()
    def __init__(self, layers, norm_layer=None, projection=None):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
        for layer in self.layers:
            x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
            # With no initial trend, start accumulating from the first layer's
            # residual trend instead of failing on ``None + tensor``.
            trend = residual_trend if trend is None else trend + residual_trend
        if self.norm is not None:
            x = self.norm(x)
        if self.projection is not None:
            x = self.projection(x)
        return x, trend
class AutoCorrelation(nn.Module):
    """
    AutoCorrelation Mechanism with the following two phases:
    (1) period-based dependencies discovery
    (2) time delay aggregation
    This block can replace the self-attention family mechanism seamlessly.

    NOTE(review): ``mask_flag``, ``scale`` and the ``dropout`` module are stored
    but not used by any method here, and ``forward`` ignores ``attn_mask``.
    """

    @validated()
    def __init__(
        self,
        mask_flag=True,
        factor=1,
        scale=None,
        attention_dropout=0.1,
        output_attention=False,
    ):
        super().__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def _num_delays(self, length, available):
        """Number of delays to aggregate: ``int(factor * log(length))``.

        The value is clamped to ``[1, available]``. This avoids an all-zero
        output when the formula gives 0 (length <= 2 with factor 1), and a
        ``topk`` error when it exceeds the number of candidate delays.
        """
        top_k = int(self.factor * math.log(length))
        return max(1, min(top_k, available))

    def time_delay_agg_training(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.

        The same ``top_k`` delays, chosen from the correlation averaged over
        dims 1, 2 and the batch, are used for every sample. Each sample gets
        its own softmax weights over those delays.
        """
        length = values.shape[3]
        top_k = self._num_delays(length, corr.shape[-1])
        # Average over dims 1 and 2 -> (batch, delays)
        mean_value = corr.mean(dim=1).mean(dim=1)
        _, index = torch.topk(mean_value.mean(dim=0), top_k, dim=-1)
        weights = mean_value[:, index]
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation: weighted sum of circularly shifted copies of ``values``
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            pattern = torch.roll(values, -int(index[i]), -1)
            delays_agg = delays_agg + pattern * tmp_corr[:, i].reshape(-1, 1, 1, 1)
        return delays_agg

    def time_delay_agg_inference(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the inference phase.

        Unlike training, each sample selects its own ``top_k`` delays.
        """
        batch, head, channel, length = values.shape
        init_index = torch.arange(length, device=values.device).view(1, 1, 1, length)
        top_k = self._num_delays(length, corr.shape[-1])
        mean_value = corr.mean(dim=1).mean(dim=1)
        weights, delay = torch.topk(mean_value, top_k, dim=-1)
        tmp_corr = torch.softmax(weights, dim=-1)
        # Duplicating values along time lets ``index + delay`` (< 2 * length)
        # express a circular shift through ``gather``.
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = (init_index + delay[:, i].reshape(-1, 1, 1, 1)).expand(
                batch, head, channel, length
            )
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * tmp_corr[:, i].reshape(-1, 1, 1, 1)
        return delays_agg

    def time_delay_agg_full(self, values, corr):
        """
        Standard version of Autocorrelation

        Delays and weights are selected independently for every position of
        the leading three dims of ``corr``.
        """
        batch, head, channel, length = values.shape
        init_index = (
            torch.arange(length, device=values.device)
            .view(1, 1, 1, length)
            .expand(batch, head, channel, length)
        )
        top_k = self._num_delays(length, corr.shape[-1])
        weights, delay = torch.topk(corr, top_k, dim=-1)
        tmp_corr = torch.softmax(weights, dim=-1)
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[..., i].unsqueeze(-1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * tmp_corr[..., i].unsqueeze(-1)
        return delays_agg

    def forward(self, queries, keys, values, attn_mask):
        L = queries.shape[1]
        S = values.shape[1]
        # Align keys/values to the query length: zero-pad or truncate dim 1.
        if L > S:
            zeros = torch.zeros_like(queries[:, : (L - S), :]).float()
            values = torch.cat([values, zeros], dim=1)
            keys = torch.cat([keys, zeros], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]
        # period-based dependencies: circular cross-correlation via the FFT
        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
        res = q_fft * torch.conj(k_fft)
        # ``n=L`` is required: by default irfft returns 2 * (L // 2) samples,
        # i.e. L - 1 for odd L, which is not the length-L correlation.
        corr = torch.fft.irfft(res, n=L, dim=-1)
        # time delay agg
        if self.training:
            V = self.time_delay_agg_training(
                values.permute(0, 2, 3, 1).contiguous(), corr
            ).permute(0, 3, 1, 2)
        else:
            V = self.time_delay_agg_inference(
                values.permute(0, 2, 3, 1).contiguous(), corr
            ).permute(0, 3, 1, 2)
        if self.output_attention:
            return (V.contiguous(), corr.permute(0, 3, 1, 2))
        return (V.contiguous(), None)
class AutoCorrelationLayer(nn.Module):
    """Multi-head wrapper around an inner correlation module.

    Queries, keys and values are linearly projected and split into
    ``n_heads`` heads (``view(batch, length, heads, -1)``). The heads are fed
    to ``correlation``, then merged back and projected to ``d_model``.
    Returns ``(output, attn)``, where ``attn`` is whatever the inner module
    returns second.
    """

    @validated()
    def __init__(self, correlation, d_model, n_heads, d_keys=None, d_values=None):
        super().__init__()
        key_dim = d_keys or (d_model // n_heads)
        value_dim = d_values or (d_model // n_heads)
        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, key_dim * n_heads)
        self.key_projection = nn.Linear(d_model, key_dim * n_heads)
        self.value_projection = nn.Linear(d_model, value_dim * n_heads)
        self.out_projection = nn.Linear(value_dim * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        batch, q_len, _ = queries.shape
        _, kv_len, _ = keys.shape
        heads = self.n_heads
        q = self.query_projection(queries).view(batch, q_len, heads, -1)
        k = self.key_projection(keys).view(batch, kv_len, heads, -1)
        v = self.value_projection(values).view(batch, kv_len, heads, -1)
        merged, attn = self.inner_correlation(q, k, v, attn_mask)
        return self.out_projection(merged.view(batch, q_len, -1)), attn
class AutoformerModel(nn.Module):
    """
    Autoformer network producing parameters of ``distr_output``.

    The encoder input is the lagged, scaled target concatenated with static
    and time features. The decoder is initialised from a seasonal/trend
    decomposition of the encoder input: its last ``label_length`` steps plus
    ``prediction_length`` placeholder steps (zeros for the seasonal part,
    the encoder-input mean for the trend part).
    """

    @validated()
    def __init__(
        self,
        freq: str,
        context_length: int,
        prediction_length: int,
        num_feat_dynamic_real: int,
        num_feat_static_real: int,
        num_feat_static_cat: int,
        cardinality: List[int],
        # autoformer arguments
        n_heads: int,
        num_encoder_layers: int,
        num_decoder_layers: int,
        dim_feedforward: int,
        activation: str = "gelu",
        dropout: float = 0.1,
        factor: int = 1,
        moving_avg: int = 25,
        # univariate input
        input_size: int = 1,
        embedding_dimension: Optional[List[int]] = None,
        distr_output: DistributionOutput = StudentTOutput(),
        lags_seq: Optional[List[int]] = None,
        scaling: bool = True,
        num_parallel_samples: int = 100,
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.target_shape = distr_output.event_shape
        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        self.embedding_dimension = (
            embedding_dimension
            if embedding_dimension is not None or cardinality is None
            else [min(50, (cat + 1) // 2) for cat in cardinality]
        )
        self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
        self.num_parallel_samples = num_parallel_samples
        self.history_length = context_length + max(self.lags_seq)
        self.embedder = FeatureEmbedder(
            cardinalities=cardinality,
            embedding_dims=self.embedding_dimension,
        )
        if scaling:
            self.scaler = MeanScaler(dim=1, keepdim=True)
        else:
            self.scaler = NOPScaler(dim=1, keepdim=True)

        # total feature size: lagged targets plus static/dynamic features
        d_model = self.input_size * len(self.lags_seq) + self._number_of_features
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.label_length = context_length // 2

        # Input decomposition
        self.decomp = series_decomp(kernel_size=moving_avg)

        # output projection
        self.distr_output = distr_output
        self.param_proj = distr_output.get_args_proj(d_model)

        # embeddings
        self.dec_embedding = DataEmbedding_wo_pos(
            x_in=d_model, x_mark_in=self._number_of_features, d_model=d_model
        )

        def _correlation_layer(mask_flag):
            # One multi-head AutoCorrelation layer with this model's settings.
            return AutoCorrelationLayer(
                AutoCorrelation(
                    mask_flag,
                    factor,
                    attention_dropout=dropout,
                    output_attention=False,
                ),
                d_model,
                n_heads,
            )

        # autoformer enc-decoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    _correlation_layer(False),
                    d_model,
                    dim_feedforward,
                    moving_avg=moving_avg,
                    dropout=dropout,
                    activation=activation,
                )
                for _ in range(num_encoder_layers)
            ],
            norm_layer=my_Layernorm(d_model),
        )
        self.decoder = Decoder(
            [
                DecoderLayer(
                    _correlation_layer(True),
                    _correlation_layer(False),
                    d_model,
                    d_model,
                    dim_feedforward,
                    moving_avg=moving_avg,
                    dropout=dropout,
                    activation=activation,
                )
                for _ in range(num_decoder_layers)
            ],
            norm_layer=my_Layernorm(d_model),
            projection=None,
        )

    @property
    def _number_of_features(self) -> int:
        return (
            sum(self.embedding_dimension)
            + self.num_feat_dynamic_real
            + self.num_feat_static_real
            + 1  # the log(scale)
        )

    @property
    def _past_length(self) -> int:
        return self.context_length + max(self.lags_seq)

    def get_lagged_subsequences(
        self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
    ) -> torch.Tensor:
        """
        Returns lagged subsequences of a given sequence.

        Parameters
        ----------
        sequence : Tensor
            the sequence from which lagged subsequences should be extracted.
            Shape: (N, T, C) or (N, T).
        subsequences_length : int
            length of the subsequences to be extracted.
        shift: int
            shift the lags by this amount back.

        Returns
        --------
        lagged : Tensor
            a tensor of shape (N, S, C, I) (or (N, S, I) for a 2-D input),
            where S = subsequences_length and I = len(indices), containing
            lagged subsequences. Specifically,
            lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
        """
        sequence_length = sequence.shape[1]
        indices = [lag - shift for lag in self.lags_seq]
        assert max(indices) + subsequences_length <= sequence_length, (
            f"lags cannot go further than history length, found lag {max(indices)} "
            f"while history length is only {sequence_length}"
        )
        lagged_values = []
        for lag_index in indices:
            begin_index = -lag_index - subsequences_length
            end_index = -lag_index if lag_index > 0 else None
            lagged_values.append(sequence[:, begin_index:end_index, ...])
        return torch.stack(lagged_values, dim=-1)

    def _check_shapes(
        self,
        prior_input: torch.Tensor,
        inputs: torch.Tensor,
        features: Optional[torch.Tensor],
    ) -> None:
        assert len(prior_input.shape) == len(inputs.shape)
        assert (
            len(prior_input.shape) == 2 and self.input_size == 1
        ) or prior_input.shape[2] == self.input_size
        assert (len(inputs.shape) == 2 and self.input_size == 1) or inputs.shape[
            -1
        ] == self.input_size
        assert (
            features is None or features.shape[2] == self._number_of_features
        ), f"{features.shape[2]}, expected {self._number_of_features}"

    def create_network_inputs(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: Optional[torch.Tensor] = None,
        future_target: Optional[torch.Tensor] = None,
    ):
        """Build ``(transformer_inputs, scale, dynamic_features, static_feat)``.

        When ``future_target`` is given, the future window is appended, so
        sequences cover ``context_length + prediction_length`` steps;
        otherwise they cover ``context_length`` steps.
        """
        # time feature
        past_time_window = past_time_feat[
            :, self._past_length - self.context_length :, ...
        ]
        time_feat = (
            torch.cat((past_time_window, future_time_feat), dim=1)
            if future_target is not None
            else past_time_window
        )

        # target, scaled by statistics of the observed context window
        context = past_target[:, -self.context_length :]
        observed_context = past_observed_values[:, -self.context_length :]
        _, scale = self.scaler(context, observed_context)

        inputs = (
            torch.cat((past_target, future_target), dim=1) / scale
            if future_target is not None
            else past_target / scale
        )
        inputs_length = (
            self._past_length + self.prediction_length
            if future_target is not None
            else self._past_length
        )
        assert inputs.shape[1] == inputs_length
        subsequences_length = (
            self.context_length + self.prediction_length
            if future_target is not None
            else self.context_length
        )

        # embeddings
        embedded_cat = self.embedder(feat_static_cat)
        static_feat = torch.cat(
            (embedded_cat, feat_static_real, scale.log()),
            dim=1,
        )
        expanded_static_feat = static_feat.unsqueeze(1).expand(
            -1, time_feat.shape[1], -1
        )
        dynamic_features = torch.cat((expanded_static_feat, time_feat), dim=-1)

        lagged_sequence = self.get_lagged_subsequences(
            sequence=inputs,
            subsequences_length=subsequences_length,
        )
        lags_shape = lagged_sequence.shape
        reshaped_lagged_sequence = lagged_sequence.reshape(
            lags_shape[0], lags_shape[1], -1
        )
        transformer_inputs = torch.cat(
            (reshaped_lagged_sequence, dynamic_features), dim=-1
        )
        return transformer_inputs, scale, dynamic_features, static_feat

    def _decoder_params(self, enc_input, dec_dynamic_feat):
        """Encode ``enc_input``, decode, and project to distribution params.

        Returns the ``param_proj`` output for the last ``prediction_length``
        decoder steps. ``dec_dynamic_feat`` must cover
        ``label_length + prediction_length`` steps.
        """
        batch, enc_len, num_channels = enc_input.shape
        # Explicit start index: ``[-label_length:]`` would return the whole
        # sequence (not an empty one) when label_length == 0.
        label_start = enc_len - self.label_length

        # decomp init
        mean = enc_input.mean(dim=1, keepdim=True).repeat(1, self.prediction_length, 1)
        zeros = torch.zeros(
            [batch, self.prediction_length, num_channels],
            device=enc_input.device,
        )
        seasonal_init, trend_init = self.decomp(enc_input)
        # decoder input
        trend_init = torch.cat([trend_init[:, label_start:, :], mean], dim=1)
        seasonal_init = torch.cat([seasonal_init[:, label_start:, :], zeros], dim=1)

        # enc
        enc_out, _ = self.encoder(enc_input, attn_mask=None)
        # dec
        dec_input = self.dec_embedding(seasonal_init, dec_dynamic_feat)
        seasonal_part, trend_part = self.decoder(
            dec_input, enc_out, x_mask=None, cross_mask=None, trend=trend_init
        )
        dec_out = trend_part + seasonal_part
        return self.param_proj(dec_out[:, -self.prediction_length :, :])

    def output_params(self, transformer_inputs, dynamic_features):
        """Distribution params for the prediction window (training path).

        ``transformer_inputs`` / ``dynamic_features`` are expected to come from
        ``create_network_inputs`` with ``future_target`` given.
        """
        enc_input = transformer_inputs[:, : self.context_length, ...]
        dec_dynamic_feat = dynamic_features[
            :, self.context_length - self.label_length :, ...
        ]
        return self._decoder_params(enc_input, dec_dynamic_feat)

    @torch.jit.ignore
    def output_distribution(
        self, params, scale=None, trailing_n=None
    ) -> torch.distributions.Distribution:
        sliced_params = params
        if trailing_n is not None:
            sliced_params = [p[:, -trailing_n:] for p in params]
        return self.distr_output.distribution(sliced_params, scale=scale)

    # for prediction
    def forward(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: torch.Tensor,
        num_parallel_samples: Optional[int] = None,
    ) -> torch.Tensor:
        """Draw ``num_parallel_samples`` sample paths per series.

        Returns a tensor reshaped to
        ``(-1, num_parallel_samples, prediction_length) + target_shape``.
        """
        if num_parallel_samples is None:
            num_parallel_samples = self.num_parallel_samples

        enc_input, scale, dynamic_feat, static_feat = self.create_network_inputs(
            feat_static_cat,
            feat_static_real,
            past_time_feat,
            past_target,
            past_observed_values,
        )

        expanded_static_feat = static_feat.unsqueeze(1).expand(
            -1, future_time_feat.shape[1], -1
        )
        future_feat = torch.cat((expanded_static_feat, future_time_feat), dim=-1)
        label_start = dynamic_feat.shape[1] - self.label_length
        dec_dynamic_feat = torch.cat(
            (dynamic_feat[:, label_start:, :], future_feat), dim=1
        )

        params = self._decoder_params(enc_input, dec_dynamic_feat)

        # Use the (possibly overridden) argument, not the attribute.
        repeated_params = [
            p.repeat_interleave(repeats=num_parallel_samples, dim=0) for p in params
        ]
        repeated_scale = scale.repeat_interleave(repeats=num_parallel_samples, dim=0)
        distr = self.output_distribution(repeated_params, scale=repeated_scale)

        # Future samples
        samples = distr.sample()
        return samples.reshape(
            (-1, num_parallel_samples, self.prediction_length) + self.target_shape,
        )