mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 18:06:14 +08:00
723 lines
24 KiB
Python
723 lines
24 KiB
Python
from math import sqrt
|
|
from typing import List, Optional
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from gluonts.core.component import validated
|
|
from gluonts.time_feature import get_lags_for_frequency
|
|
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
|
|
from gluonts.torch.modules.feature import FeatureEmbedder
|
|
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
|
|
|
|
|
|
class TriangularCausalMask:
    """Boolean mask marking strictly-upper-triangular (future) positions.

    The stored tensor has shape ``[B, 1, L, L]``; element ``[b, 0, i, j]`` is
    ``True`` exactly when ``j > i``.
    """

    def __init__(self, B, L, device="cpu"):
        # No gradient tracking is needed for a constant mask.
        with torch.no_grad():
            all_true = torch.ones([B, 1, L, L], dtype=torch.bool)
            above_diagonal = torch.triu(all_true, diagonal=1)
            self._mask = above_diagonal.to(device)

    @property
    def mask(self):
        """Return the boolean mask tensor built in the constructor."""
        return self._mask
|
|
|
|
|
|
class ProbMask:
    """Causal mask restricted to a selected subset of query rows.

    A ``[L, scores.shape[-1]]`` strictly-upper-triangular boolean mask is
    broadcast over batch and heads, then only the rows listed in ``index``
    are kept, and the result is viewed with the shape of ``scores``.
    """

    def __init__(self, B, H, L, index, scores, device="cpu"):
        n_keys = scores.shape[-1]
        causal = torch.ones(L, n_keys, dtype=torch.bool).to(device).triu(1)
        causal_per_head = causal[None, None, :].expand(B, H, L, n_keys)

        # Broadcastable batch/head indices so each (b, h) picks its own rows.
        batch_idx = torch.arange(B)[:, None, None]
        head_idx = torch.arange(H)[None, :, None]
        selected_rows = causal_per_head[batch_idx, head_idx, index, :].to(device)

        self._mask = selected_rows.view(scores.shape).to(device)

    @property
    def mask(self):
        """Return the boolean mask aligned with the ``scores`` tensor."""
        return self._mask
|
|
|
|
|
|
class FullAttention(nn.Module):
    """Scaled dot-product attention over ``[B, L, H, E]`` shaped inputs.

    Parameters
    ----------
    mask_flag
        When true, masked score entries are set to ``-inf`` before the
        softmax. If no mask is passed to ``forward`` a
        ``TriangularCausalMask`` is built.
    factor
        Accepted for signature parity with ``ProbAttention``; not used here.
    scale
        Multiplier applied to the scores; a falsy value selects
        ``1 / sqrt(E)``.
    attention_dropout
        Dropout probability applied to the attention weights.
    output_attention
        When true, ``forward`` also returns the attention weights.
    """

    def __init__(
        self,
        mask_flag=True,
        factor=5,
        scale=None,
        attention_dropout=0.1,
        output_attention=False,
    ):
        super().__init__()
        self.mask_flag = mask_flag
        self.scale = scale
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask):
        """Return ``(context, weights_or_None)``.

        ``attn_mask`` is any object exposing a boolean ``.mask`` tensor
        broadcastable to the ``[B, H, L, S]`` score tensor, or ``None``.
        """
        batch, q_len, _, head_dim = queries.shape
        scaling = self.scale or 1.0 / sqrt(head_dim)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            mask_obj = attn_mask
            if mask_obj is None:
                mask_obj = TriangularCausalMask(batch, q_len, device=queries.device)
            scores.masked_fill_(mask_obj.mask, -np.inf)

        weights = self.dropout(torch.softmax(scaling * scores, dim=-1))
        context = torch.einsum("bhls,bshd->blhd", weights, values).contiguous()

        return (context, weights if self.output_attention else None)
|
|
|
|
|
|
class ProbAttention(nn.Module):
    """Sparse attention that fully attends only a top-scoring subset of queries.

    Queries are ranked by a sparsity measure computed on randomly sampled
    keys; the top ``n_top`` get a real softmax-attention output, the rest
    keep an initial context (mean of V, or cumulative sum of V when masking).

    Parameters
    ----------
    mask_flag
        When true, a ``ProbMask`` is applied to the selected scores and the
        initial context is a cumulative sum (self-attention only).
    factor
        Multiplier ``c`` in the ``c * ceil(log1p(L))`` sample/top-k sizes.
    scale
        Multiplier applied to the scores; a falsy value selects
        ``1 / sqrt(D)``.
    attention_dropout
        Used to build ``self.dropout``; the forward path shown here does not
        apply it.
    output_attention
        When true, ``forward`` also returns an attention matrix.
    """

    def __init__(
        self,
        mask_flag=True,
        factor=5,
        scale=None,
        attention_dropout=0.1,
        output_attention=False,
    ):
        super().__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def _prob_QK(self, Q, K, sample_k, n_top):  # n_top: c*ln(L_q)
        """Return ``(scores of the top queries against all keys, top indices)``.

        ``Q`` and ``K`` are laid out as ``[B, H, L, D]``.
        """
        B, H, L_K, E = K.shape
        L_Q = Q.shape[2]

        # Score every query against `sample_k` randomly drawn keys.
        keys_per_query = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        sampled_key_idx = torch.randint(L_K, (L_Q, sample_k))
        query_rows = torch.arange(L_Q).unsqueeze(1)
        K_sample = keys_per_query[:, :, query_rows, sampled_key_idx, :]
        sampled_scores = torch.matmul(
            Q.unsqueeze(-2), K_sample.transpose(-2, -1)
        ).squeeze(-2)

        # Sparsity measurement: max sampled score minus (sum / L_K).
        sparsity = sampled_scores.max(-1).values - torch.div(
            sampled_scores.sum(-1), L_K
        )
        top_idx = sparsity.topk(n_top, sorted=False).indices

        # Full scores only for the selected queries.
        batch_idx = torch.arange(B)[:, None, None]
        head_idx = torch.arange(H)[None, :, None]
        Q_reduce = Q[batch_idx, head_idx, top_idx, :]
        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1))

        return Q_K, top_idx

    def _get_initial_context(self, V, L_Q):
        """Build the default per-query context used for non-selected queries."""
        B, H, L_V, D = V.shape
        if self.mask_flag:
            # requires that L_Q == L_V, i.e. for self-attention only
            assert L_Q == L_V
            return V.cumsum(dim=-2)

        mean_v = V.mean(dim=-2)
        return mean_v.unsqueeze(-2).expand(B, H, L_Q, mean_v.shape[-1]).clone()

    def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
        """Overwrite the rows of ``context_in`` listed in ``index``.

        The ``attn_mask`` argument is replaced by a freshly built
        ``ProbMask`` when ``mask_flag`` is set and is otherwise unused.
        """
        B, H, L_V, D = V.shape
        batch_idx = torch.arange(B)[:, None, None]
        head_idx = torch.arange(H)[None, :, None]

        if self.mask_flag:
            attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)

        attn = torch.softmax(scores, dim=-1)

        context_in[batch_idx, head_idx, index, :] = torch.matmul(attn, V).type_as(
            context_in
        )

        if not self.output_attention:
            return (context_in, None)

        # Non-selected rows report a uniform 1 / L_V distribution.
        attns = (torch.ones([B, H, L_V, L_V]) / L_V).type_as(attn).to(attn.device)
        attns[batch_idx, head_idx, index, :] = attn
        return (context_in, attns)

    def forward(self, queries, keys, values, attn_mask):
        """Return ``(context [B, L_Q, H, D], attention_or_None)``.

        Inputs arrive as ``[B, L, H, D]`` and are transposed to
        ``[B, H, L, D]`` internally.
        """
        B, L_Q, H, D = queries.shape
        L_K = keys.shape[1]

        queries = queries.transpose(2, 1)
        keys = keys.transpose(2, 1)
        values = values.transpose(2, 1)

        # c*ln(L_k) sampled keys and c*ln(L_q) selected queries, capped at
        # the sequence lengths.
        U_part = self.factor * np.ceil(np.log1p(L_K)).astype("int").item()
        u = self.factor * np.ceil(np.log1p(L_Q)).astype("int").item()
        U_part = min(U_part, L_K)
        u = min(u, L_Q)

        scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)

        # `self.scale or default` never yields None, so scaling always runs.
        scores_top = scores_top * (self.scale or 1.0 / sqrt(D))

        context = self._get_initial_context(values, L_Q)
        context, attn = self._update_context(
            context, values, scores_top, index, L_Q, attn_mask
        )

        return context.transpose(2, 1).contiguous(), attn
|
|
|
|
|
|
class AttentionLayer(nn.Module):
    """Multi-head wrapper: linear projections around an inner attention.

    Parameters
    ----------
    attention
        Callable invoked as ``attention(queries, keys, values, attn_mask)``
        with ``[B, L, H, d]`` tensors, returning ``(output, attention)``.
    d_model
        Width of the input and output features.
    n_heads
        Number of heads the projected features are split into.
    d_keys, d_values
        Per-head sizes; a falsy value selects ``d_model // n_heads``.
    mix
        When true, the head and length axes of the inner output are swapped
        before it is flattened.
    """

    def __init__(
        self, attention, d_model, n_heads, d_keys=None, d_values=None, mix=False
    ):
        super().__init__()

        default_head_dim = d_model // n_heads
        d_keys = d_keys or default_head_dim
        d_values = d_values or default_head_dim

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads
        self.mix = mix

    def forward(self, queries, keys, values, attn_mask):
        """Project, attend, merge heads and project back to ``d_model``."""
        batch, q_len, _ = queries.shape
        _, kv_len, _ = keys.shape
        heads = self.n_heads

        q = self.query_projection(queries).view(batch, q_len, heads, -1)
        k = self.key_projection(keys).view(batch, kv_len, heads, -1)
        v = self.value_projection(values).view(batch, kv_len, heads, -1)

        merged, attn = self.inner_attention(q, k, v, attn_mask)
        if self.mix:
            merged = merged.transpose(2, 1).contiguous()
        merged = merged.view(batch, q_len, -1)

        return self.out_projection(merged), attn
|
|
|
|
|
|
class ConvLayer(nn.Module):
    """Conv1d -> BatchNorm1d -> ELU -> MaxPool1d block applied along time.

    The stride-2 max-pool shortens the sequence axis; the channel count
    ``c_in`` is preserved.
    """

    def __init__(self, c_in):
        super().__init__()
        self.downConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=c_in,
            kernel_size=3,
            padding=1,
            padding_mode="circular",
        )
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Map ``[B, L, C]`` input to a shorter ``[B, L', C]`` output."""
        # Conv1d/BatchNorm1d expect channels on axis 1.
        channels_first = x.permute(0, 2, 1)
        pooled = self.maxPool(
            self.activation(self.norm(self.downConv(channels_first)))
        )
        return pooled.transpose(1, 2)
|
|
|
|
|
|
class EncoderLayer(nn.Module):
    """Self-attention plus a position-wise feed-forward block, post-norm.

    Parameters
    ----------
    attention
        Callable invoked as ``attention(x, x, x, attn_mask=...)`` returning
        ``(output, attention)``.
    d_model
        Feature width.
    d_ff
        Hidden width of the feed-forward block; a falsy value selects
        ``4 * d_model``.
    dropout
        Dropout probability used after attention and inside the
        feed-forward block.
    activation
        ``"relu"`` selects ``F.relu``; any other value selects ``F.gelu``.
    """

    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super().__init__()
        hidden = d_ff or 4 * d_model
        self.attention = attention
        # Kernel-size-1 convolutions act as per-position linear layers.
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=hidden, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=hidden, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, attn_mask=None):
        """Return ``(encoded x, attention)`` for input ``x`` of shape [B, L, D]."""
        attended, attn = self.attention(x, x, x, attn_mask=attn_mask)
        residual = self.norm1(x + self.dropout(attended))

        # Feed-forward runs channels-first, then is moved back.
        hidden = self.conv1(residual.transpose(-1, 1))
        hidden = self.dropout(self.activation(hidden))
        ff_out = self.dropout(self.conv2(hidden).transpose(-1, 1))

        return self.norm2(residual + ff_out), attn
|
|
|
|
|
|
class Encoder(nn.Module):
    """Stack of attention layers, optionally interleaved with conv layers.

    When ``conv_layers`` is given, each (attention, conv) pair is applied in
    turn and the last attention layer is applied afterwards on its own.
    """

    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super().__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        if conv_layers is not None:
            self.conv_layers = nn.ModuleList(conv_layers)
        else:
            self.conv_layers = None
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        """Return ``(encoded x, list of per-layer attentions)``; x is [B, L, D]."""
        collected = []

        if self.conv_layers is None:
            for layer in self.attn_layers:
                x, attn = layer(x, attn_mask=attn_mask)
                collected.append(attn)
        else:
            for layer, conv in zip(self.attn_layers, self.conv_layers):
                x, attn = layer(x, attn_mask=attn_mask)
                x = conv(x)
                collected.append(attn)
            # The final attention layer has no conv layer after it.
            x, attn = self.attn_layers[-1](x, attn_mask=attn_mask)
            collected.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, collected
|
|
|
|
|
|
class DecoderLayer(nn.Module):
    """Self-attention, cross-attention and feed-forward block, post-norm.

    Parameters
    ----------
    self_attention, cross_attention
        Callables invoked as ``fn(q, k, v, attn_mask=...)`` whose first
        return value is the attended output.
    d_model
        Feature width.
    d_ff
        Hidden width of the feed-forward block; a falsy value selects
        ``4 * d_model``.
    dropout
        Dropout probability applied after each sub-block.
    activation
        ``"relu"`` selects ``F.relu``; any other value selects ``F.gelu``.
    """

    def __init__(
        self,
        self_attention,
        cross_attention,
        d_model,
        d_ff=None,
        dropout=0.1,
        activation="relu",
    ):
        super().__init__()
        hidden = d_ff or 4 * d_model
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        # Kernel-size-1 convolutions act as per-position linear layers.
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=hidden, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=hidden, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        if activation == "relu":
            self.activation = F.relu
        else:
            self.activation = F.gelu

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        """Decode ``x`` while attending to the encoder output ``cross``."""
        self_out = self.self_attention(x, x, x, attn_mask=x_mask)[0]
        x = self.norm1(x + self.dropout(self_out))

        cross_out = self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0]
        x = self.norm2(x + self.dropout(cross_out))

        # Feed-forward runs channels-first, then is moved back.
        hidden = self.dropout(self.activation(self.conv1(x.transpose(-1, 1))))
        ff_out = self.dropout(self.conv2(hidden).transpose(-1, 1))

        return self.norm3(x + ff_out)
|
|
|
|
|
|
class Decoder(nn.Module):
    """Sequential stack of decoder layers with an optional final norm."""

    def __init__(self, layers, norm_layer=None):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer

    def forward(self, x, cross, x_mask=None, cross_mask=None):
        """Run ``x`` through every layer, each attending to ``cross``."""
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(
                hidden, cross, x_mask=x_mask, cross_mask=cross_mask
            )

        if self.norm is None:
            return hidden
        return self.norm(hidden)
|
|
|
|
|
|
class InformerModel(nn.Module):
    """Probabilistic forecaster built on the Informer encoder/decoder.

    Network inputs are lagged (scaled) target values concatenated with
    embedded static categorical features, static real features, the log of
    the scale and the time features. The decoder output is projected to the
    parameters of ``distr_output``.

    Parameters
    ----------
    freq
        Frequency string; used to derive ``lags_seq`` when none is given.
    context_length
        Number of time steps fed to the encoder.
    prediction_length
        Number of time steps to forecast.
    num_feat_dynamic_real, num_feat_static_real, num_feat_static_cat
        Feature counts; the first two contribute to the model width.
    cardinality
        Cardinalities of the static categorical features.
    nhead, num_encoder_layers, num_decoder_layers, dim_feedforward
        Sizes of the attention stacks.
    activation, dropout
        Forwarded to every encoder/decoder layer.
    attn
        ``"prob"`` selects ``ProbAttention``; anything else ``FullAttention``.
    factor
        Sampling factor forwarded to the attention modules.
    distil
        When true, ``num_encoder_layers - 1`` ``ConvLayer`` blocks are
        interleaved in the encoder.
    input_size
        Number of target dimensions (1 for univariate input).
    embedding_dimension
        Embedding sizes; derived from ``cardinality`` when omitted.
    distr_output
        Output distribution family.
    lags_seq
        Lag indices; derived from ``freq`` when omitted or empty.
    scaling
        Selects ``MeanScaler`` (true) or ``NOPScaler`` (false).
    num_parallel_samples
        Default number of sample paths drawn per series in ``forward``.
    """

    @validated()
    def __init__(
        self,
        freq: str,
        context_length: int,
        prediction_length: int,
        num_feat_dynamic_real: int,
        num_feat_static_real: int,
        num_feat_static_cat: int,
        cardinality: List[int],
        # Informer arguments
        nhead: int,
        num_encoder_layers: int,
        num_decoder_layers: int,
        dim_feedforward: int,
        activation: str = "gelu",
        dropout: float = 0.1,
        attn: str = "prob",
        factor: int = 5,
        distil: bool = True,
        # univariate input
        input_size: int = 1,
        embedding_dimension: Optional[List[int]] = None,
        distr_output: DistributionOutput = StudentTOutput(),
        lags_seq: Optional[List[int]] = None,
        scaling: bool = True,
        num_parallel_samples: int = 100,
    ) -> None:
        super().__init__()

        self.input_size = input_size

        self.target_shape = distr_output.event_shape
        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        self.embedding_dimension = (
            embedding_dimension
            if embedding_dimension is not None or cardinality is None
            else [min(50, (cat + 1) // 2) for cat in cardinality]
        )
        self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
        self.num_parallel_samples = num_parallel_samples
        self.history_length = context_length + max(self.lags_seq)
        self.embedder = FeatureEmbedder(
            cardinalities=cardinality,
            embedding_dims=self.embedding_dimension,
        )
        if scaling:
            self.scaler = MeanScaler(dim=1, keepdim=True)
        else:
            self.scaler = NOPScaler(dim=1, keepdim=True)

        # total feature size
        d_model = self.input_size * len(self.lags_seq) + self._number_of_features

        self.context_length = context_length
        self.prediction_length = prediction_length
        self.distr_output = distr_output
        self.param_proj = distr_output.get_args_proj(d_model)

        # Informer enc-decoder
        Attn = ProbAttention if attn == "prob" else FullAttention
        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        Attn(
                            mask_flag=False,
                            factor=factor,
                            attention_dropout=dropout,
                            output_attention=False,
                        ),
                        d_model,
                        nhead,
                        mix=False,
                    ),
                    d_model,
                    d_ff=dim_feedforward,
                    dropout=dropout,
                    activation=activation,
                )
                for _ in range(num_encoder_layers)
            ],
            [ConvLayer(d_model) for _ in range(num_encoder_layers - 1)]
            if distil
            else None,
            norm_layer=torch.nn.LayerNorm(d_model),
        )

        # Masked Decoder
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AttentionLayer(
                        Attn(
                            mask_flag=True,
                            factor=factor,
                            attention_dropout=dropout,
                            output_attention=False,
                        ),
                        d_model,
                        nhead,
                        mix=True,
                    ),
                    AttentionLayer(
                        FullAttention(
                            mask_flag=False,
                            factor=factor,
                            attention_dropout=dropout,
                            output_attention=False,
                        ),
                        d_model,
                        nhead,
                        mix=False,
                    ),
                    d_model,
                    d_ff=dim_feedforward,
                    dropout=dropout,
                    activation=activation,
                )
                for _ in range(num_decoder_layers)
            ],
            norm_layer=torch.nn.LayerNorm(d_model),
        )

    @property
    def _number_of_features(self) -> int:
        """Width of the non-lag part of the model input."""
        return (
            sum(self.embedding_dimension)
            + self.num_feat_dynamic_real
            + self.num_feat_static_real
            + self.input_size  # the log(scale)
        )

    @property
    def _past_length(self) -> int:
        """Context length plus the largest lag."""
        return self.context_length + max(self.lags_seq)

    def get_lagged_subsequences(
        self, sequence: torch.Tensor, subsequences_length: int, shift: int = 0
    ) -> torch.Tensor:
        """
        Returns lagged subsequences of a given sequence.

        Parameters
        ----------
        sequence : Tensor
            the sequence from which lagged subsequences should be extracted.
            Shape: (N, T, C).
        subsequences_length : int
            length of the subsequences to be extracted.
        shift: int
            shift the lags by this amount back.

        Returns
        --------
        lagged : Tensor
            a tensor of shape (N, S, C, I), where S = subsequences_length and
            I = len(indices), containing lagged subsequences. Specifically,
            lagged[i, j, :, k] = sequence[i, -indices[k]-S+j, :].
        """
        sequence_length = sequence.shape[1]
        indices = [lag - shift for lag in self.lags_seq]

        assert max(indices) + subsequences_length <= sequence_length, (
            f"lags cannot go further than history length, found lag {max(indices)} "
            f"while history length is only {sequence_length}"
        )

        lagged_values = []
        for lag_index in indices:
            begin_index = -lag_index - subsequences_length
            # A lag of 0 must slice to the end; `-0` would give an empty slice.
            end_index = -lag_index if lag_index > 0 else None
            lagged_values.append(sequence[:, begin_index:end_index, ...])
        return torch.stack(lagged_values, dim=-1)

    def _check_shapes(
        self,
        prior_input: torch.Tensor,
        inputs: torch.Tensor,
        features: Optional[torch.Tensor],
    ) -> None:
        """Assert that the inputs agree with ``input_size`` and feature width."""
        assert len(prior_input.shape) == len(inputs.shape)
        assert (
            len(prior_input.shape) == 2 and self.input_size == 1
        ) or prior_input.shape[2] == self.input_size
        assert (len(inputs.shape) == 2 and self.input_size == 1) or inputs.shape[
            -1
        ] == self.input_size
        assert (
            features is None or features.shape[2] == self._number_of_features
        ), f"{features.shape[2]}, expected {self._number_of_features}"

    def create_network_inputs(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: Optional[torch.Tensor] = None,
        future_target: Optional[torch.Tensor] = None,
    ):
        """Assemble the encoder/decoder input tensor.

        When ``future_target`` is given (training), the result covers
        ``context_length + prediction_length`` steps; otherwise only the
        context window. ``future_time_feat`` is only read when
        ``future_target`` is not ``None``.

        Returns
        -------
        tuple
            ``(transformer_inputs, scale, static_feat)``.
        """
        # time feature: drop the part of the history that only serves lags
        past_context_time_feat = past_time_feat[
            :, self._past_length - self.context_length :, ...
        ]
        time_feat = (
            torch.cat((past_context_time_feat, future_time_feat), dim=1)
            if future_target is not None
            else past_context_time_feat
        )

        # target: the scale is computed on the context window only
        context = past_target[:, -self.context_length :]
        observed_context = past_observed_values[:, -self.context_length :]
        _, scale = self.scaler(context, observed_context)

        inputs = (
            torch.cat((past_target, future_target), dim=1) / scale
            if future_target is not None
            else past_target / scale
        )

        inputs_length = (
            self._past_length + self.prediction_length
            if future_target is not None
            else self._past_length
        )
        assert inputs.shape[1] == inputs_length

        subsequences_length = (
            self.context_length + self.prediction_length
            if future_target is not None
            else self.context_length
        )

        # embeddings
        embedded_cat = self.embedder(feat_static_cat)
        log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()
        static_feat = torch.cat(
            (embedded_cat, feat_static_real, log_scale),
            dim=1,
        )
        expanded_static_feat = static_feat.unsqueeze(1).expand(
            -1, time_feat.shape[1], -1
        )

        features = torch.cat((expanded_static_feat, time_feat), dim=-1)

        lagged_sequence = self.get_lagged_subsequences(
            sequence=inputs,
            subsequences_length=subsequences_length,
        )

        # Flatten the trailing (target-dim, lag) axes into one feature axis.
        lags_shape = lagged_sequence.shape
        reshaped_lagged_sequence = lagged_sequence.reshape(
            lags_shape[0], lags_shape[1], -1
        )

        transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)

        return transformer_inputs, scale, static_feat

    def output_params(self, transformer_inputs):
        """Split inputs at ``context_length``, encode/decode, project to params."""
        enc_input = transformer_inputs[:, : self.context_length, ...]
        dec_input = transformer_inputs[:, self.context_length :, ...]

        enc_out, _ = self.encoder(enc_input)
        dec_output = self.decoder(dec_input, enc_out)

        return self.param_proj(dec_output)

    @torch.jit.ignore
    def output_distribution(
        self, params, scale=None, trailing_n=None
    ) -> torch.distributions.Distribution:
        """Build the output distribution, optionally from the last ``trailing_n`` steps."""
        sliced_params = params
        if trailing_n is not None:
            sliced_params = [p[:, -trailing_n:] for p in params]
        return self.distr_output.distribution(sliced_params, scale=scale)

    # for prediction
    def forward(
        self,
        feat_static_cat: torch.Tensor,
        feat_static_real: torch.Tensor,
        past_time_feat: torch.Tensor,
        past_target: torch.Tensor,
        past_observed_values: torch.Tensor,
        future_time_feat: torch.Tensor,
        num_parallel_samples: Optional[int] = None,
    ) -> torch.Tensor:
        """Draw sample paths for the next ``prediction_length`` steps.

        Parameters
        ----------
        num_parallel_samples
            Number of sample paths per series; ``None`` falls back to
            ``self.num_parallel_samples``.

        Returns
        -------
        Tensor
            Samples reshaped to
            ``(-1, num_parallel_samples, prediction_length) + target_shape``.
        """
        if num_parallel_samples is None:
            num_parallel_samples = self.num_parallel_samples

        encoder_inputs, scale, static_feat = self.create_network_inputs(
            feat_static_cat,
            feat_static_real,
            past_time_feat,
            past_target,
            past_observed_values,
        )

        enc_out, _ = self.encoder(encoder_inputs)

        # Every tensor below is repeated with the *resolved* sample count so
        # that a per-call override is honoured consistently.
        repeated_scale = scale.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        )

        repeated_past_target = (
            past_target.repeat_interleave(repeats=num_parallel_samples, dim=0)
            / repeated_scale
        )

        expanded_static_feat = static_feat.unsqueeze(1).expand(
            -1, future_time_feat.shape[1], -1
        )
        features = torch.cat((expanded_static_feat, future_time_feat), dim=-1)
        repeated_features = features.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        )

        repeated_enc_out = enc_out.repeat_interleave(
            repeats=num_parallel_samples, dim=0
        )

        future_samples = []

        # Autoregressive decoding: each step re-runs the decoder on the
        # k + 1 steps generated so far and samples the next value.
        for k in range(self.prediction_length):
            lagged_sequence = self.get_lagged_subsequences(
                sequence=repeated_past_target,
                subsequences_length=1 + k,
                shift=1,
            )

            lags_shape = lagged_sequence.shape
            reshaped_lagged_sequence = lagged_sequence.reshape(
                lags_shape[0], lags_shape[1], -1
            )

            decoder_input = torch.cat(
                (reshaped_lagged_sequence, repeated_features[:, : k + 1]), dim=-1
            )

            output = self.decoder(decoder_input, repeated_enc_out)

            # Only the last decoder position predicts the new step.
            params = self.param_proj(output[:, -1:])
            distr = self.output_distribution(params, scale=repeated_scale)
            next_sample = distr.sample()

            # Feed the sample back in scaled units, matching the history.
            repeated_past_target = torch.cat(
                (repeated_past_target, next_sample / repeated_scale), dim=1
            )
            future_samples.append(next_sample)

        concat_future_samples = torch.cat(future_samples, dim=1)
        return concat_future_samples.reshape(
            (-1, num_parallel_samples, self.prediction_length) + self.target_shape,
        )
|