mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 19:16:11 +08:00
857 lines
28 KiB
Python
857 lines
28 KiB
Python
import math
|
|
from typing import List, Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from gluonts.core.component import validated
|
|
from gluonts.time_feature import get_lags_for_frequency
|
|
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
|
from gluonts.torch.modules.feature import FeatureEmbedder
|
|
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
|
|
|
|
|
|
class TokenEmbedding(nn.Module):
    """
    Embed each time step's ``c_in`` values into ``d_model`` channels using a
    width-3 circular 1-D convolution (no bias) over the time axis.

    ``forward`` expects input laid out as (batch, time, c_in) and returns
    (batch, time, d_model).
    """

    def __init__(self, c_in, d_model):
        super().__init__()
        self.tokenConv = nn.Conv1d(
            c_in,
            d_model,
            3,
            padding=1,
            padding_mode="circular",
            bias=False,
        )
        # Kaiming-normal init for every Conv1d owned by this module.
        conv_modules = (mod for mod in self.modules() if isinstance(mod, nn.Conv1d))
        for conv in conv_modules:
            nn.init.kaiming_normal_(conv.weight, mode="fan_in", nonlinearity="leaky_relu")

    def forward(self, x):
        # Conv1d wants channels before time, so swap in and back out.
        channels_first = x.permute(0, 2, 1)
        embedded = self.tokenConv(channels_first)
        return embedded.transpose(1, 2)
|
|
|
|
|
|
class DataEmbedding_wo_pos(nn.Module):
    """
    Embedding without positional encoding: the convolutional value embedding
    of ``x`` plus a linear projection of the covariates ``x_mark``, followed by
    dropout.
    """

    def __init__(self, x_in, x_mark_in, d_model, dropout=0.1):
        super().__init__()
        self.value_embedding = TokenEmbedding(c_in=x_in, d_model=d_model)
        self.temporal_embedding = nn.Linear(x_mark_in, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        value_part = self.value_embedding(x)
        covariate_part = self.temporal_embedding(x_mark)
        return self.dropout(value_part + covariate_part)
|
|
|
|
|
|
class my_Layernorm(nn.Module):
    """
    Special designed layernorm for the seasonal part.

    Applies ``nn.LayerNorm(channels)`` over the last dimension, then subtracts
    the mean taken over dim 1 so the normalised output is centred along that
    axis.
    """

    @validated()
    def __init__(self, channels):
        super(my_Layernorm, self).__init__()
        self.layernorm = nn.LayerNorm(channels)

    def forward(self, x):
        normed = self.layernorm(x)
        # keepdim + broadcasting replaces an explicit repeat along dim 1.
        return normed - normed.mean(dim=1, keepdim=True)
|
|
|
|
|
|
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series.

    The input is edge-padded along dim 1 by repeating its first and last step,
    then averaged with ``nn.AvgPool1d``. With ``stride=1`` the output has the
    same length along dim 1 as the input, for odd and even ``kernel_size``.
    """

    @validated()
    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        # Pooling runs over dim 1 (the permutes move it to the last axis, which
        # AvgPool1d reduces); presumably (batch, time, channels) — TODO confirm.
        # A stride-1 pool of width k needs k - 1 padding steps in total to keep
        # the length. Give the front (k - 1) // 2 and the end the remainder.
        # This is symmetric for odd k. For even k the extra step goes to the end
        # instead of being dropped, which would otherwise shorten the output by
        # one and break `x - moving_mean` in series_decomp.
        pad_front = (self.kernel_size - 1) // 2
        pad_end = self.kernel_size - 1 - pad_front
        front = x[:, 0:1, :].repeat(1, pad_front, 1)
        end = x[:, -1:, :].repeat(1, pad_end, 1)
        x = torch.cat([front, x, end], dim=1)
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x
|
|
|
|
|
|
class series_decomp(nn.Module):
    """
    Series decomposition block.

    Splits ``x`` into ``(residual, trend)``, where ``trend`` is a stride-1
    moving average of width ``kernel_size`` and ``residual = x - trend``.
    """

    @validated()
    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        trend = self.moving_avg(x)
        seasonal = x - trend
        return seasonal, trend
|
|
|
|
|
|
class EncoderLayer(nn.Module):
    """
    Autoformer encoder layer with the progressive decomposition architecture.

    Self-attention (via ``attention``) with a residual connection, then series
    decomposition; then a position-wise feed-forward made of two kernel-size-1
    convolutions, another residual and a second decomposition. Only the
    seasonal (residual) parts are kept; the trends are discarded.
    ``d_ff`` defaults to ``4 * d_model``. ``activation == "relu"`` selects ReLU,
    and any other value selects GELU.
    """

    @validated()
    def __init__(
        self,
        attention,
        d_model,
        d_ff=None,
        moving_avg=25,
        dropout=0.1,
        activation="relu",
    ):
        super(EncoderLayer, self).__init__()
        hidden = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(
            in_channels=d_model, out_channels=hidden, kernel_size=1, bias=False
        )
        self.conv2 = nn.Conv1d(
            in_channels=hidden, out_channels=d_model, kernel_size=1, bias=False
        )
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        attended, attn = self.attention(x, x, x, attn_mask=attn_mask)
        x, _ = self.decomp1(x + self.dropout(attended))

        # Feed-forward over channels: Conv1d wants channels on dim 1.
        ff = self.conv1(x.transpose(-1, 1))
        ff = self.dropout(self.activation(ff))
        ff = self.dropout(self.conv2(ff).transpose(-1, 1))

        seasonal, _ = self.decomp2(x + ff)
        return seasonal, attn
|
|
|
|
|
|
class DecoderLayer(nn.Module):
    """
    Autoformer decoder layer with the progressive decomposition architecture.

    Runs self-attention, cross-attention against ``cross`` and a kernel-size-1
    convolutional feed-forward. Each sub-step has a residual connection and is
    followed by a series decomposition. Returns the seasonal output and the sum
    of the three extracted trends, projected from ``d_model`` to ``c_out``
    channels by a width-3 circular convolution.
    """

    @validated()
    def __init__(
        self,
        self_attention,
        cross_attention,
        d_model,
        c_out,
        d_ff=None,
        moving_avg=25,
        dropout=0.1,
        activation="relu",
    ):
        super(DecoderLayer, self).__init__()
        hidden = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.conv1 = nn.Conv1d(
            in_channels=d_model, out_channels=hidden, kernel_size=1, bias=False
        )
        self.conv2 = nn.Conv1d(
            in_channels=hidden, out_channels=d_model, kernel_size=1, bias=False
        )
        self.decomp1 = series_decomp(moving_avg)
        self.decomp2 = series_decomp(moving_avg)
        self.decomp3 = series_decomp(moving_avg)
        self.dropout = nn.Dropout(dropout)
        self.projection = nn.Conv1d(
            in_channels=d_model,
            out_channels=c_out,
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode="circular",
            bias=False,
        )
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        self_out = self.self_attention(x, x, x, attn_mask=x_mask)[0]
        x, trend1 = self.decomp1(x + self.dropout(self_out))

        cross_out = self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]
        x, trend2 = self.decomp2(x + self.dropout(cross_out))

        ff = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        ff = self.dropout(self.conv2(ff).transpose(-1, 1))
        x, trend3 = self.decomp3(x + ff)

        summed_trend = trend1 + trend2 + trend3
        # Conv1d needs channels on dim 1; swap in and back out.
        residual_trend = self.projection(summed_trend.permute(0, 2, 1)).transpose(1, 2)
        return x, residual_trend
|
|
|
|
|
|
class Encoder(nn.Module):
    """
    Autoformer encoder.

    Applies a stack of attention layers, optionally interleaved with
    convolution layers, then an optional normalisation layer. Returns the
    encoded tensor and the list of attention outputs, one per attention layer
    that ran.
    """

    @validated()
    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = (
            nn.ModuleList(conv_layers) if conv_layers is not None else None
        )
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        attns = []
        if self.conv_layers is not None:
            # Each conv layer follows the attention layer it is zipped with, and
            # the last attention layer runs once more after the final conv.
            # This presumably expects len(conv_layers) == len(attn_layers) - 1.
            for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
                x, attn = attn_layer(x, attn_mask=attn_mask)
                x = conv_layer(x)
                attns.append(attn)
            # The trailing attention layer receives the same mask as all others.
            x, attn = self.attn_layers[-1](x, attn_mask=attn_mask)
            attns.append(attn)
        else:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask)
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns
|
|
|
|
|
|
class Decoder(nn.Module):
    """
    Autoformer decoder.

    Runs each decoder layer on ``x`` (attending to ``cross``) and accumulates
    the residual trend every layer returns onto ``trend``. It then applies the
    optional norm and projection to the seasonal output. Returns
    ``(seasonal, trend)``.
    """

    @validated()
    def __init__(self, layers, norm_layer=None, projection=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer
        self.projection = projection

    def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None):
        for layer in self.layers:
            x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask)
            # With the default trend=None there is no initial trend to add to,
            # so the first layer's residual starts the running sum.
            trend = residual_trend if trend is None else trend + residual_trend

        if self.norm is not None:
            x = self.norm(x)

        if self.projection is not None:
            x = self.projection(x)
        return x, trend
|
|
|
|
|
|
class AutoCorrelation(nn.Module):
    """
    AutoCorrelation Mechanism with the following two phases:
    (1) period-based dependencies discovery
    (2) time delay aggregation
    This block can replace the self-attention family mechanism seamlessly.

    ``forward`` takes queries/keys/values laid out as (batch, time, heads,
    features). ``mask_flag``, ``scale``, the dropout module and the
    ``attn_mask`` argument are stored or accepted for interface compatibility
    with attention layers, but this implementation does not use them.
    """

    @validated()
    def __init__(
        self,
        mask_flag=True,
        factor=1,
        scale=None,
        attention_dropout=0.1,
        output_attention=False,
    ):
        super(AutoCorrelation, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def _top_k(self, length):
        """
        Number of delays to aggregate for a series of ``length`` steps.

        ``int(factor * log(length))`` is 0 for short series, which would
        aggregate no delays at all. For large ``factor`` it can exceed
        ``length``, which ``torch.topk`` rejects. The result is therefore
        clamped to ``[1, length]``.
        """
        return max(1, min(length, int(self.factor * math.log(length))))

    def time_delay_agg_training(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.

        One set of ``top_k`` delays is chosen from the correlation averaged over
        dims 1 and 2 and over the whole batch. Each sample then mixes circularly
        rolled copies of ``values`` using its own softmax weights for those
        delays.

        values: (batch, head, channel, length). corr: its last axis indexes the
        delay; dims 1 and 2 are averaged out.
        """
        length = values.shape[3]
        # find top k
        top_k = self._top_k(length)
        # (batch, delay): correlation averaged over dims 1 and 2
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        _, index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)
        # (batch, top_k): each sample's correlation at the shared delays
        weights = mean_value[:, index]
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            pattern = torch.roll(values, -int(index[i]), -1)
            delays_agg = delays_agg + pattern * tmp_corr[:, i].reshape(-1, 1, 1, 1)
        return delays_agg

    def time_delay_agg_inference(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the inference phase.

        Unlike the training path, each sample picks its own ``top_k`` delays
        from its correlation averaged over dims 1 and 2.

        values: (batch, head, channel, length).
        """
        batch, head, channel, length = values.shape
        # index init: positions 0..length-1 broadcast to the full values shape
        init_index = (
            torch.arange(length, device=values.device)
            .reshape(1, 1, 1, length)
            .repeat(batch, head, channel, 1)
        )
        # find top k
        top_k = self._top_k(length)
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        weights, delay = torch.topk(mean_value, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation: two copies along time, so gathering position t + delay
        # (always < 2 * length) wraps around like a circular roll
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[:, i].reshape(-1, 1, 1, 1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * tmp_corr[:, i].reshape(-1, 1, 1, 1)
        return delays_agg

    def time_delay_agg_full(self, values, corr):
        """
        Standard version of Autocorrelation

        Delays and weights are chosen separately for every
        (batch, head, channel) series, directly from ``corr``.
        """
        batch, head, channel, length = values.shape
        # index init
        init_index = (
            torch.arange(length, device=values.device)
            .reshape(1, 1, 1, length)
            .repeat(batch, head, channel, 1)
        )
        # find top k
        top_k = self._top_k(length)
        weights, delay = torch.topk(corr, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[..., i].unsqueeze(-1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1))
        return delays_agg

    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        if L > S:
            # Zero-pad keys and values along time up to the query length. Each
            # pad uses its own feature width, so differing key/value widths work.
            key_pad = keys.new_zeros((B, L - S, H, keys.shape[-1]))
            value_pad = values.new_zeros((B, L - S, H, D))
            values = torch.cat([values, value_pad], dim=1)
            keys = torch.cat([keys, key_pad], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        # period-based dependencies: circular cross-correlation over time via FFT
        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
        res = q_fft * torch.conj(k_fft)
        # n=L keeps one correlation value per delay. Without it irfft returns
        # 2 * (L // 2) samples, i.e. a wrong, shorter signal for odd L.
        corr = torch.fft.irfft(res, n=L, dim=-1)

        # time delay agg
        if self.training:
            V = self.time_delay_agg_training(
                values.permute(0, 2, 3, 1).contiguous(), corr
            ).permute(0, 3, 1, 2)
        else:
            V = self.time_delay_agg_inference(
                values.permute(0, 2, 3, 1).contiguous(), corr
            ).permute(0, 3, 1, 2)

        if self.output_attention:
            return (V.contiguous(), corr.permute(0, 3, 1, 2))
        else:
            return (V.contiguous(), None)
|
|
|
|
|
|
class AutoCorrelationLayer(nn.Module):
    """
    Multi-head wrapper around an inner correlation module.

    Linearly projects queries/keys/values from ``d_model`` into ``n_heads``
    heads and reshapes them to (batch, time, heads, -1) for the inner module.
    It then merges the heads and projects back to ``d_model``. ``d_keys`` and
    ``d_values`` default to ``d_model // n_heads``.
    """

    @validated()
    def __init__(self, correlation, d_model, n_heads, d_keys=None, d_values=None):
        super(AutoCorrelationLayer, self).__init__()

        key_dim = d_keys or (d_model // n_heads)
        value_dim = d_values or (d_model // n_heads)

        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, key_dim * n_heads)
        self.key_projection = nn.Linear(d_model, key_dim * n_heads)
        self.value_projection = nn.Linear(d_model, value_dim * n_heads)
        self.out_projection = nn.Linear(value_dim * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        batch, q_len, _ = queries.shape
        _, kv_len, _ = keys.shape
        heads = self.n_heads

        q = self.query_projection(queries).view(batch, q_len, heads, -1)
        k = self.key_projection(keys).view(batch, kv_len, heads, -1)
        v = self.value_projection(values).view(batch, kv_len, heads, -1)

        merged, attn = self.inner_correlation(q, k, v, attn_mask)
        merged = merged.view(batch, q_len, -1)
        return self.out_projection(merged), attn
|
|
|
|
|
|
class AutoformerModel(nn.Module):
    """
    Autoformer forecasting network with lagged-target inputs and a GluonTS
    distribution head.

    The encoder consumes the context window of lagged, scaled targets
    concatenated with covariates. The decoder starts from the trend/seasonal
    decomposition of the last ``label_length`` context steps, followed by
    ``prediction_length`` placeholder steps: the context mean for the trend
    and zeros for the seasonal part. The decoder output over the horizon is
    projected to the arguments of ``distr_output``.
    """

    @validated()
    def __init__(
        self,
        freq: str,
        context_length: int,
        prediction_length: int,
        num_feat_dynamic_real: int,
        num_feat_static_real: int,
        num_feat_static_cat: int,
        cardinality: List[int],
        # autoformer arguments
        n_heads: int,
        num_encoder_layers: int,
        num_decoder_layers: int,
        dim_feedforward: int,
        activation: str = "gelu",
        dropout: float = 0.1,
        factor: int = 1,
        moving_avg: int = 25,
        # univariate input
        input_size: int = 1,
        embedding_dimension: Optional[List[int]] = None,
        distr_output: DistributionOutput = StudentTOutput(),
        lags_seq: Optional[List[int]] = None,
        scaling: bool = True,
        num_parallel_samples: int = 100,
    ) -> None:
        super().__init__()

        self.input_size = input_size

        self.target_shape = distr_output.event_shape
        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        self.embedding_dimension = (
            embedding_dimension
            if embedding_dimension is not None or cardinality is None
            else [min(50, (cat + 1) // 2) for cat in cardinality]
        )
        self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
        self.num_parallel_samples = num_parallel_samples
        self.history_length = context_length + max(self.lags_seq)
        self.embedder = FeatureEmbedder(
            cardinalities=cardinality,
            embedding_dims=self.embedding_dimension,
        )
        if scaling:
            self.scaler = MeanScaler(dim=1, keepdim=True)
        else:
            self.scaler = NOPScaler(dim=1, keepdim=True)

        # total feature size: every lag of every target dim plus the covariates
        d_model = self.input_size * len(self.lags_seq) + self._number_of_features

        self.context_length = context_length
        self.prediction_length = prediction_length
        self.label_length = context_length // 2

        # Input decomposition
        self.decomp = series_decomp(kernel_size=moving_avg)

        # output projection
        self.distr_output = distr_output
        self.param_proj = distr_output.get_args_proj(d_model)

        # embeddings
        self.dec_embedding = DataEmbedding_wo_pos(
            x_in=d_model, x_mark_in=self._number_of_features, d_model=d_model
        )

        # autoformer enc-decoder and mask initializer
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AutoCorrelationLayer(
                        AutoCorrelation(
                            False,
                            factor,
                            attention_dropout=dropout,
                            output_attention=False,
                        ),
                        d_model,
                        n_heads,
                    ),
                    d_model,
                    dim_feedforward,
                    moving_avg=moving_avg,
                    dropout=dropout,
                    activation=activation,
                )
                for _ in range(num_encoder_layers)
            ],
            norm_layer=my_Layernorm(d_model),
        )

        self.decoder = Decoder(
            [
                DecoderLayer(
                    AutoCorrelationLayer(
                        AutoCorrelation(
                            True,
                            factor,
                            attention_dropout=dropout,
                            output_attention=False,
                        ),
                        d_model,
                        n_heads,
                    ),
                    AutoCorrelationLayer(
                        AutoCorrelation(
                            False,
                            factor,
                            attention_dropout=dropout,
                            output_attention=False,
                        ),
                        d_model,
                        n_heads,
                    ),
                    d_model,
                    d_model,
                    dim_feedforward,
                    moving_avg=moving_avg,
                    dropout=dropout,
                    activation=activation,
                )
                for _ in range(num_decoder_layers)
            ],
            norm_layer=my_Layernorm(d_model),
            projection=None,
        )

    @property
    def _number_of_features(self) -> int:
        """Width of the covariate block: static embeddings/reals, dynamic reals and log(scale)."""
        return (
            sum(self.embedding_dimension)
            + self.num_feat_dynamic_real
            + self.num_feat_static_real
            + self.input_size  # the log(scale)
        )

    @property
    def _past_length(self) -> int:
        """Number of past steps required: the context plus the largest lag."""
        return self.context_length + max(self.lags_seq)

    def get_lagged_subsequences(
        self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
    ) -> torch.Tensor:
        """
        Returns lagged subsequences of a given sequence.
        Parameters
        ----------
        sequence : Tensor
            the sequence from which lagged subsequences should be extracted.
            Shape: (N, T, C).
        subsequences_length : int
            length of the subsequences to be extracted.
        shift: int
            shift the lags by this amount back.
        Returns
        --------
        lagged : Tensor
            a tensor of shape (N, S, C, I), where S = subsequences_length and
            I = len(indices), containing lagged subsequences. Specifically,
            lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
        """
        sequence_length = sequence.shape[1]
        indices = [lag - shift for lag in self.lags_seq]

        assert max(indices) + subsequences_length <= sequence_length, (
            f"lags cannot go further than history length, found lag {max(indices)} "
            f"while history length is only {sequence_length}"
        )

        lagged_values = []
        for lag_index in indices:
            begin_index = -lag_index - subsequences_length
            end_index = -lag_index if lag_index > 0 else None
            lagged_values.append(sequence[:, begin_index:end_index, ...])
        return torch.stack(lagged_values, dim=-1)

    def _check_shapes(
        self,
        prior_input: torch.Tensor,
        inputs: torch.Tensor,
        features: Optional[torch.Tensor],
    ) -> None:
        """Assert that input/feature tensors agree with ``input_size`` and the feature width."""
        assert len(prior_input.shape) == len(inputs.shape)
        assert (
            len(prior_input.shape) == 2 and self.input_size == 1
        ) or prior_input.shape[2] == self.input_size
        assert (len(inputs.shape) == 2 and self.input_size == 1) or inputs.shape[
            -1
        ] == self.input_size
        assert (
            features is None or features.shape[2] == self._number_of_features
        ), f"{features.shape[2]}, expected {self._number_of_features}"

    def create_network_inputs(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: Optional[torch.Tensor] = None,
        future_target: Optional[torch.Tensor] = None,
    ):
        """
        Build the network inputs.

        Targets are scaled by a scale computed on the observed context window.
        Covariates are the static features (categorical embeddings, static
        reals, log(scale)) broadcast over time and concatenated with the time
        features. The inputs span ``context_length`` steps, plus
        ``prediction_length`` when ``future_target`` is given.

        Returns ``(transformer_inputs, scale, dynamic_features, static_feat)``,
        where ``transformer_inputs`` is the flattened lagged targets
        concatenated with ``dynamic_features`` along the last dimension.
        """
        # time feature
        time_feat = (
            torch.cat(
                (
                    past_time_feat[:, self._past_length - self.context_length :, ...],
                    future_time_feat,
                ),
                dim=1,
            )
            if future_target is not None
            else past_time_feat[:, self._past_length - self.context_length :, ...]
        )

        # target
        context = past_target[:, -self.context_length :]
        observed_context = past_observed_values[:, -self.context_length :]
        _, scale = self.scaler(context, observed_context)

        inputs = (
            torch.cat((past_target, future_target), dim=1) / scale
            if future_target is not None
            else past_target / scale
        )

        inputs_length = (
            self._past_length + self.prediction_length
            if future_target is not None
            else self._past_length
        )
        assert inputs.shape[1] == inputs_length

        subsequences_length = (
            self.context_length + self.prediction_length
            if future_target is not None
            else self.context_length
        )

        # embeddings
        embedded_cat = self.embedder(feat_static_cat)
        log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()
        static_feat = torch.cat(
            (embedded_cat, feat_static_real, log_scale),
            dim=1,
        )
        expanded_static_feat = static_feat.unsqueeze(1).expand(
            -1, time_feat.shape[1], -1
        )

        dynamic_features = torch.cat((expanded_static_feat, time_feat), dim=-1)

        lagged_sequence = self.get_lagged_subsequences(
            sequence=inputs,
            subsequences_length=subsequences_length,
        )

        lags_shape = lagged_sequence.shape
        reshaped_lagged_sequence = lagged_sequence.reshape(
            lags_shape[0], lags_shape[1], -1
        )

        transformer_inputs = torch.cat(
            (reshaped_lagged_sequence, dynamic_features), dim=-1
        )

        return transformer_inputs, scale, dynamic_features, static_feat

    def output_params(self, transformer_inputs, dynamic_features):
        """
        Run the encoder/decoder and project the forecast horizon to
        distribution arguments.

        transformer_inputs: only its first ``context_length`` steps (dim 1) are
            fed to the encoder.
        dynamic_features: covariates covering the context followed by the
            ``prediction_length`` future steps along dim 1. The decoder uses the
            last ``label_length`` context steps plus the future ones.
        """
        enc_input = transformer_inputs[:, : self.context_length, ...]
        # A positive start index makes label_length == 0 select nothing; a
        # `-label_length:` slice would select the whole context instead.
        label_start = self.context_length - self.label_length
        dec_dynamic_feat = dynamic_features[:, label_start:, ...]

        # decomp init
        mean = (
            torch.mean(enc_input, dim=1)
            .unsqueeze(1)
            .repeat(1, self.prediction_length, 1)
        )
        zeros = torch.zeros(
            [enc_input.shape[0], self.prediction_length, enc_input.shape[2]],
            device=enc_input.device,
        )
        seasonal_init, trend_init = self.decomp(enc_input)

        # decoder input: label window of the decomposition + horizon placeholders
        trend_init = torch.cat([trend_init[:, label_start:, :], mean], dim=1)
        seasonal_init = torch.cat([seasonal_init[:, label_start:, :], zeros], dim=1)

        # enc
        enc_out, _ = self.encoder(enc_input, attn_mask=None)

        # dec
        dec_input = self.dec_embedding(seasonal_init, dec_dynamic_feat)
        seasonal_part, trend_part = self.decoder(
            dec_input, enc_out, x_mask=None, cross_mask=None, trend=trend_init
        )

        # final
        dec_out = trend_part + seasonal_part

        return self.param_proj(dec_out[:, -self.prediction_length :, :])

    @torch.jit.ignore
    def output_distribution(
        self, params, scale=None, trailing_n=None
    ) -> torch.distributions.Distribution:
        """Build the output distribution, optionally keeping only the last ``trailing_n`` steps of each arg."""
        sliced_params = params
        if trailing_n is not None:
            sliced_params = [p[:, -trailing_n:] for p in params]
        return self.distr_output.distribution(sliced_params, scale=scale)

    # for prediction
    def forward(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: torch.Tensor,
        num_parallel_samples: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Draw ``num_parallel_samples`` forecast sample paths per series (falls
        back to the value given at construction). Returns a tensor reshaped to
        ``(-1, num_parallel_samples, prediction_length) + target_shape``.
        """
        if num_parallel_samples is None:
            num_parallel_samples = self.num_parallel_samples

        transformer_inputs, scale, dynamic_feat, static_feat = self.create_network_inputs(
            feat_static_cat,
            feat_static_real,
            past_time_feat,
            past_target,
            past_observed_values,
        )

        # Future covariates are built the same way create_network_inputs builds
        # the past ones, then appended so output_params sees context + horizon.
        expanded_static_feat = static_feat.unsqueeze(1).expand(
            -1, future_time_feat.shape[1], -1
        )
        future_feat = torch.cat((expanded_static_feat, future_time_feat), dim=-1)
        dynamic_features = torch.cat((dynamic_feat, future_feat), dim=1)

        params = self.output_params(transformer_inputs, dynamic_features)

        # Use the resolved local, so a caller-supplied sample count is honoured.
        repeated_params = [
            s.repeat_interleave(repeats=num_parallel_samples, dim=0) for s in params
        ]
        repeated_scale = scale.repeat_interleave(repeats=num_parallel_samples, dim=0)

        distr = self.output_distribution(repeated_params, scale=repeated_scale)

        # Future samples
        samples = distr.sample()

        return samples.reshape(
            (-1, num_parallel_samples, self.prediction_length) + self.target_shape,
        )
|