added autoformer predict

This commit is contained in:
Kashif Rasul
2022-04-06 13:15:58 +02:00
parent 68d33b7fd3
commit 6f64227d31
2 changed files with 110 additions and 61 deletions
+7 -2
View File
@@ -57,7 +57,12 @@ class AutoformerLightningModule(pl.LightningModule):
past_observed_values = batch["past_observed_values"]
future_observed_values = batch["future_observed_values"]
autoformer_inputs, scale, _ = self.model.create_network_inputs(
(
autoformer_inputs,
scale,
dynamic_features,
_,
) = self.model.create_network_inputs(
feat_static_cat,
feat_static_real,
past_time_feat,
@@ -66,7 +71,7 @@ class AutoformerLightningModule(pl.LightningModule):
future_time_feat,
future_target,
)
params = self.model.output_params(autoformer_inputs)
params = self.model.output_params(autoformer_inputs, dynamic_features)
distr = self.model.output_distribution(params, scale)
loss_values = self.loss(distr, future_target)
+103 -59
View File
@@ -11,6 +11,42 @@ from gluonts.torch.modules.feature import FeatureEmbedder
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
class TokenEmbedding(nn.Module):
    """Embed a value sequence into ``d_model`` channels with a 1-D convolution.

    The convolution uses a kernel of width 3 with one step of circular
    padding on each side, so the output has the same number of time steps
    as the input. The layer has no bias term.

    Args:
        c_in: Size of the last (feature) dimension of the input.
        d_model: Size of the last (feature) dimension of the output.
    """

    def __init__(self, c_in, d_model):
        super().__init__()
        self.tokenConv = nn.Conv1d(
            c_in,
            d_model,
            kernel_size=3,
            padding=1,
            padding_mode="circular",
            bias=False,
        )
        # The convolution is the only Conv1d in this module, so it is
        # re-initialised directly (Kaiming normal, fan-in, leaky-ReLU gain).
        nn.init.kaiming_normal_(
            self.tokenConv.weight, mode="fan_in", nonlinearity="leaky_relu"
        )

    def forward(self, x):
        """Map ``x`` of shape (batch, time, c_in) to (batch, time, d_model)."""
        # Conv1d expects channels before time, so swap the last two axes,
        # convolve, and swap them back.
        channels_first = x.permute(0, 2, 1)
        projected = self.tokenConv(channels_first)
        return projected.transpose(1, 2)
class DataEmbedding_wo_pos(nn.Module):
    """Combine a value embedding and a linear embedding of the marks.

    The values go through a ``TokenEmbedding`` and the accompanying mark
    features through a linear layer; the two results are summed and passed
    through dropout.

    Args:
        x_in: Feature size of the value input.
        x_mark_in: Feature size of the mark input.
        d_model: Feature size of the embedding produced by both branches.
        dropout: Dropout probability applied to the summed embedding.
    """

    def __init__(self, x_in, x_mark_in, d_model, dropout=0.1):
        super().__init__()
        # Creation order is kept (conv branch, then linear) so parameter
        # initialisation consumes the random generator in the same order.
        self.value_embedding = TokenEmbedding(c_in=x_in, d_model=d_model)
        self.temporal_embedding = nn.Linear(x_mark_in, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Return ``dropout(value_embedding(x) + temporal_embedding(x_mark))``."""
        value_part = self.value_embedding(x)
        mark_part = self.temporal_embedding(x_mark)
        return self.dropout(value_part + mark_part)
class my_Layernorm(nn.Module):
"""
Specially designed layernorm for the seasonal part
@@ -481,6 +517,11 @@ class AutoformerModel(nn.Module):
self.distr_output = distr_output
self.param_proj = distr_output.get_args_proj(d_model)
# embeddings
self.dec_embedding = DataEmbedding_wo_pos(
x_in=d_model, x_mark_in=self._number_of_features, d_model=d_model
)
# autoformer enc-decoder and mask initializer
self.encoder = Encoder(
[
@@ -665,7 +706,7 @@ class AutoformerModel(nn.Module):
-1, time_feat.shape[1], -1
)
features = torch.cat((expanded_static_feat, time_feat), dim=-1)
dynamic_features = torch.cat((expanded_static_feat, time_feat), dim=-1)
# self._check_shapes(prior_input, inputs, features)
@@ -680,13 +721,18 @@ class AutoformerModel(nn.Module):
lags_shape[0], lags_shape[1], -1
)
transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)
transformer_inputs = torch.cat(
(reshaped_lagged_sequence, dynamic_features), dim=-1
)
return transformer_inputs, scale, static_feat
return transformer_inputs, scale, dynamic_features, static_feat
def output_params(self, transformer_inputs):
def output_params(self, transformer_inputs, dynamic_features):
enc_input = transformer_inputs[:, : self.context_length, ...]
dec_input = transformer_inputs[:, self.context_length :, ...]
# dec_input = transformer_inputs[:, self.context_length :, ...]
dec_dynamic_feat = dynamic_features[
:, self.context_length - self.label_length :, ...
]
# decomp init
mean = (
@@ -695,7 +741,7 @@ class AutoformerModel(nn.Module):
.repeat(1, self.prediction_length, 1)
)
zeros = torch.zeros(
[dec_input.shape[0], self.prediction_length, dec_input.shape[2]],
[enc_input.shape[0], self.prediction_length, enc_input.shape[2]],
device=enc_input.device,
)
seasonal_init, trend_init = self.decomp(enc_input)
@@ -707,9 +753,10 @@ class AutoformerModel(nn.Module):
)
# enc
enc_out, attns = self.encoder(enc_input, attn_mask=None)
enc_out, _ = self.encoder(enc_input, attn_mask=None)
# dec
dec_input = self.dec_embedding(seasonal_init, dec_dynamic_feat)
seasonal_part, trend_part = self.decoder(
dec_input, enc_out, x_mask=None, cross_mask=None, trend=trend_init
)
@@ -743,7 +790,7 @@ class AutoformerModel(nn.Module):
if num_parallel_samples is None:
num_parallel_samples = self.num_parallel_samples
encoder_inputs, scale, static_feat = self.create_network_inputs(
enc_input, scale, dynamic_feat, static_feat = self.create_network_inputs(
feat_static_cat,
feat_static_real,
past_time_feat,
@@ -751,63 +798,60 @@ class AutoformerModel(nn.Module):
past_observed_values,
)
enc_out = self.transformer.encoder(encoder_inputs)
expanded_static_feat = static_feat.unsqueeze(1).expand(
-1, future_time_feat.shape[1], -1
)
features = torch.cat((expanded_static_feat, future_time_feat), dim=-1)
dec_dynamic_feat = torch.cat(
(dynamic_feat[:, -self.label_length :, :], features), dim=1
)
# decomp init
mean = (
torch.mean(enc_input, dim=1)
.unsqueeze(1)
.repeat(1, self.prediction_length, 1)
)
zeros = torch.zeros(
[enc_input.shape[0], self.prediction_length, enc_input.shape[2]],
device=enc_input.device,
)
seasonal_init, trend_init = self.decomp(enc_input)
# decoder input
trend_init = torch.cat([trend_init[:, -self.label_length :, :], mean], dim=1)
seasonal_init = torch.cat(
[seasonal_init[:, -self.label_length :, :], zeros], dim=1
)
# enc
enc_out, _ = self.encoder(enc_input, attn_mask=None)
# dec
dec_input = self.dec_embedding(seasonal_init, dec_dynamic_feat)
seasonal_part, trend_part = self.decoder(
dec_input, enc_out, x_mask=None, cross_mask=None, trend=trend_init
)
# output params
dec_out = trend_part + seasonal_part
params = self.param_proj(dec_out[:, -self.prediction_length :, :])
repeated_params = [
s.repeat_interleave(repeats=self.num_parallel_samples, dim=0)
for s in params
]
repeated_scale = scale.repeat_interleave(
repeats=self.num_parallel_samples, dim=0
)
repeated_past_target = (
past_target.repeat_interleave(repeats=self.num_parallel_samples, dim=0)
/ repeated_scale
)
distr = self.output_distribution(repeated_params, scale=repeated_scale)
expanded_static_feat = static_feat.unsqueeze(1).expand(
-1, future_time_feat.shape[1], -1
)
features = torch.cat((expanded_static_feat, future_time_feat), dim=-1)
repeated_features = features.repeat_interleave(
repeats=self.num_parallel_samples, dim=0
)
# Future samples
samples = distr.sample()
repeated_enc_out = enc_out.repeat_interleave(
repeats=self.num_parallel_samples, dim=0
)
future_samples = []
# greedy decoding
for k in range(self.prediction_length):
# self._check_shapes(repeated_past_target, next_sample, next_features)
# sequence = torch.cat((repeated_past_target, next_sample), dim=1)
lagged_sequence = self.get_lagged_subsequences(
sequence=repeated_past_target,
subsequences_length=1 + k,
shift=1,
)
lags_shape = lagged_sequence.shape
reshaped_lagged_sequence = lagged_sequence.reshape(
lags_shape[0], lags_shape[1], -1
)
decoder_input = torch.cat(
(reshaped_lagged_sequence, repeated_features[:, : k + 1]), dim=-1
)
output = self.transformer.decoder(decoder_input, repeated_enc_out)
params = self.param_proj(output[:, -1:])
distr = self.output_distribution(params, scale=repeated_scale)
next_sample = distr.sample()
repeated_past_target = torch.cat(
(repeated_past_target, next_sample / repeated_scale), dim=1
)
future_samples.append(next_sample)
concat_future_samples = torch.cat(future_samples, dim=1)
return concat_future_samples.reshape(
return samples.reshape(
(-1, self.num_parallel_samples, self.prediction_length) + self.target_shape,
)