mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 18:06:14 +08:00
added autoformer predict
This commit is contained in:
@@ -57,7 +57,12 @@ class AutoformerLightningModule(pl.LightningModule):
|
||||
past_observed_values = batch["past_observed_values"]
|
||||
future_observed_values = batch["future_observed_values"]
|
||||
|
||||
autoformer_inputs, scale, _ = self.model.create_network_inputs(
|
||||
(
|
||||
autoformer_inputs,
|
||||
scale,
|
||||
dynamic_features,
|
||||
_,
|
||||
) = self.model.create_network_inputs(
|
||||
feat_static_cat,
|
||||
feat_static_real,
|
||||
past_time_feat,
|
||||
@@ -66,7 +71,7 @@ class AutoformerLightningModule(pl.LightningModule):
|
||||
future_time_feat,
|
||||
future_target,
|
||||
)
|
||||
params = self.model.output_params(autoformer_inputs)
|
||||
params = self.model.output_params(autoformer_inputs, dynamic_features)
|
||||
distr = self.model.output_distribution(params, scale)
|
||||
|
||||
loss_values = self.loss(distr, future_target)
|
||||
|
||||
+103
-59
@@ -11,6 +11,42 @@ from gluonts.torch.modules.feature import FeatureEmbedder
|
||||
from gluonts.torch.modules.scaler import MeanScaler, NOPScaler
|
||||
|
||||
|
||||
class TokenEmbedding(nn.Module):
    """Project per-step input channels to ``d_model`` with a circular 1-D conv.

    The convolution uses ``kernel_size=3`` with one step of circular padding
    on each side, so the sequence length is preserved and the first/last
    steps see each other as neighbours. The layer has no bias term.

    Args:
        c_in: Number of input channels (size of the last axis of ``x``).
        d_model: Number of output channels produced for every time step.
    """

    def __init__(self, c_in, d_model):
        super().__init__()
        self.tokenConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=d_model,
            kernel_size=3,
            padding=1,
            padding_mode="circular",
            bias=False,
        )
        # tokenConv is the only convolution in this module, so initialise it
        # directly rather than scanning self.modules() for Conv1d instances.
        nn.init.kaiming_normal_(
            self.tokenConv.weight, mode="fan_in", nonlinearity="leaky_relu"
        )

    def forward(self, x):
        """Embed ``x``.

        Args:
            x: Rank-3 tensor whose last axis has ``c_in`` entries; axes 1 and 2
                are swapped before the convolution, so axis 1 is the one the
                kernel slides over.

        Returns:
            Tensor with the same first two axes as ``x`` and ``d_model``
            entries on the last axis.
        """
        # Conv1d wants channels on axis 1; move them there and back again.
        channels_first = x.permute(0, 2, 1)
        return self.tokenConv(channels_first).transpose(1, 2)
|
||||
|
||||
|
||||
class DataEmbedding_wo_pos(nn.Module):
    """Sum a value embedding and a linear feature embedding, then apply dropout.

    No positional embedding is added (hence ``wo_pos``).

    Args:
        x_in: Channel count of the value input handed to ``TokenEmbedding``.
        x_mark_in: Size of the last axis of the accompanying feature tensor.
        d_model: Output width shared by both embeddings.
        dropout: Dropout probability applied to the summed embedding.
    """

    def __init__(self, x_in, x_mark_in, d_model, dropout=0.1):
        super().__init__()

        self.value_embedding = TokenEmbedding(c_in=x_in, d_model=d_model)
        self.temporal_embedding = nn.Linear(x_mark_in, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed ``x`` and ``x_mark`` and return their sum after dropout.

        Args:
            x: Value tensor fed to ``value_embedding``.
            x_mark: Feature tensor fed to ``temporal_embedding``; its embedded
                result must be addable to the embedded ``x``.

        Returns:
            The element-wise sum of both embeddings, passed through dropout.
        """
        embedded = self.value_embedding(x) + self.temporal_embedding(x_mark)
        return self.dropout(embedded)
|
||||
|
||||
|
||||
class my_Layernorm(nn.Module):
|
||||
"""
|
||||
Special designed layernorm for the seasonal part
|
||||
@@ -481,6 +517,11 @@ class AutoformerModel(nn.Module):
|
||||
self.distr_output = distr_output
|
||||
self.param_proj = distr_output.get_args_proj(d_model)
|
||||
|
||||
# embeddings
|
||||
self.dec_embedding = DataEmbedding_wo_pos(
|
||||
x_in=d_model, x_mark_in=self._number_of_features, d_model=d_model
|
||||
)
|
||||
|
||||
# autoformer enc-decoder and mask initializer
|
||||
self.encoder = Encoder(
|
||||
[
|
||||
@@ -665,7 +706,7 @@ class AutoformerModel(nn.Module):
|
||||
-1, time_feat.shape[1], -1
|
||||
)
|
||||
|
||||
features = torch.cat((expanded_static_feat, time_feat), dim=-1)
|
||||
dynamic_features = torch.cat((expanded_static_feat, time_feat), dim=-1)
|
||||
|
||||
# self._check_shapes(prior_input, inputs, features)
|
||||
|
||||
@@ -680,13 +721,18 @@ class AutoformerModel(nn.Module):
|
||||
lags_shape[0], lags_shape[1], -1
|
||||
)
|
||||
|
||||
transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)
|
||||
transformer_inputs = torch.cat(
|
||||
(reshaped_lagged_sequence, dynamic_features), dim=-1
|
||||
)
|
||||
|
||||
return transformer_inputs, scale, static_feat
|
||||
return transformer_inputs, scale, dynamic_features, static_feat
|
||||
|
||||
def output_params(self, transformer_inputs):
|
||||
def output_params(self, transformer_inputs, dynamic_features):
|
||||
enc_input = transformer_inputs[:, : self.context_length, ...]
|
||||
dec_input = transformer_inputs[:, self.context_length :, ...]
|
||||
# dec_input = transformer_inputs[:, self.context_length :, ...]
|
||||
dec_dynamic_feat = dynamic_features[
|
||||
:, self.context_length - self.label_length :, ...
|
||||
]
|
||||
|
||||
# decomp init
|
||||
mean = (
|
||||
@@ -695,7 +741,7 @@ class AutoformerModel(nn.Module):
|
||||
.repeat(1, self.prediction_length, 1)
|
||||
)
|
||||
zeros = torch.zeros(
|
||||
[dec_input.shape[0], self.prediction_length, dec_input.shape[2]],
|
||||
[enc_input.shape[0], self.prediction_length, enc_input.shape[2]],
|
||||
device=enc_input.device,
|
||||
)
|
||||
seasonal_init, trend_init = self.decomp(enc_input)
|
||||
@@ -707,9 +753,10 @@ class AutoformerModel(nn.Module):
|
||||
)
|
||||
|
||||
# enc
|
||||
enc_out, attns = self.encoder(enc_input, attn_mask=None)
|
||||
enc_out, _ = self.encoder(enc_input, attn_mask=None)
|
||||
|
||||
# dec
|
||||
dec_input = self.dec_embedding(seasonal_init, dec_dynamic_feat)
|
||||
seasonal_part, trend_part = self.decoder(
|
||||
dec_input, enc_out, x_mask=None, cross_mask=None, trend=trend_init
|
||||
)
|
||||
@@ -743,7 +790,7 @@ class AutoformerModel(nn.Module):
|
||||
if num_parallel_samples is None:
|
||||
num_parallel_samples = self.num_parallel_samples
|
||||
|
||||
encoder_inputs, scale, static_feat = self.create_network_inputs(
|
||||
enc_input, scale, dynamic_feat, static_feat = self.create_network_inputs(
|
||||
feat_static_cat,
|
||||
feat_static_real,
|
||||
past_time_feat,
|
||||
@@ -751,63 +798,60 @@ class AutoformerModel(nn.Module):
|
||||
past_observed_values,
|
||||
)
|
||||
|
||||
enc_out = self.transformer.encoder(encoder_inputs)
|
||||
expanded_static_feat = static_feat.unsqueeze(1).expand(
|
||||
-1, future_time_feat.shape[1], -1
|
||||
)
|
||||
features = torch.cat((expanded_static_feat, future_time_feat), dim=-1)
|
||||
|
||||
dec_dynamic_feat = torch.cat(
|
||||
(dynamic_feat[:, -self.label_length :, :], features), dim=1
|
||||
)
|
||||
|
||||
# decomp init
|
||||
mean = (
|
||||
torch.mean(enc_input, dim=1)
|
||||
.unsqueeze(1)
|
||||
.repeat(1, self.prediction_length, 1)
|
||||
)
|
||||
zeros = torch.zeros(
|
||||
[enc_input.shape[0], self.prediction_length, enc_input.shape[2]],
|
||||
device=enc_input.device,
|
||||
)
|
||||
seasonal_init, trend_init = self.decomp(enc_input)
|
||||
|
||||
# decoder input
|
||||
trend_init = torch.cat([trend_init[:, -self.label_length :, :], mean], dim=1)
|
||||
seasonal_init = torch.cat(
|
||||
[seasonal_init[:, -self.label_length :, :], zeros], dim=1
|
||||
)
|
||||
|
||||
# enc
|
||||
enc_out, _ = self.encoder(enc_input, attn_mask=None)
|
||||
|
||||
# dec
|
||||
dec_input = self.dec_embedding(seasonal_init, dec_dynamic_feat)
|
||||
seasonal_part, trend_part = self.decoder(
|
||||
dec_input, enc_out, x_mask=None, cross_mask=None, trend=trend_init
|
||||
)
|
||||
|
||||
# output params
|
||||
dec_out = trend_part + seasonal_part
|
||||
params = self.param_proj(dec_out[:, -self.prediction_length :, :])
|
||||
|
||||
repeated_params = [
|
||||
s.repeat_interleave(repeats=self.num_parallel_samples, dim=0)
|
||||
for s in params
|
||||
]
|
||||
|
||||
repeated_scale = scale.repeat_interleave(
|
||||
repeats=self.num_parallel_samples, dim=0
|
||||
)
|
||||
|
||||
repeated_past_target = (
|
||||
past_target.repeat_interleave(repeats=self.num_parallel_samples, dim=0)
|
||||
/ repeated_scale
|
||||
)
|
||||
distr = self.output_distribution(repeated_params, scale=repeated_scale)
|
||||
|
||||
expanded_static_feat = static_feat.unsqueeze(1).expand(
|
||||
-1, future_time_feat.shape[1], -1
|
||||
)
|
||||
features = torch.cat((expanded_static_feat, future_time_feat), dim=-1)
|
||||
repeated_features = features.repeat_interleave(
|
||||
repeats=self.num_parallel_samples, dim=0
|
||||
)
|
||||
# Future samples
|
||||
samples = distr.sample()
|
||||
|
||||
repeated_enc_out = enc_out.repeat_interleave(
|
||||
repeats=self.num_parallel_samples, dim=0
|
||||
)
|
||||
|
||||
future_samples = []
|
||||
|
||||
# greedy decoding
|
||||
for k in range(self.prediction_length):
|
||||
# self._check_shapes(repeated_past_target, next_sample, next_features)
|
||||
# sequence = torch.cat((repeated_past_target, next_sample), dim=1)
|
||||
|
||||
lagged_sequence = self.get_lagged_subsequences(
|
||||
sequence=repeated_past_target,
|
||||
subsequences_length=1 + k,
|
||||
shift=1,
|
||||
)
|
||||
|
||||
lags_shape = lagged_sequence.shape
|
||||
reshaped_lagged_sequence = lagged_sequence.reshape(
|
||||
lags_shape[0], lags_shape[1], -1
|
||||
)
|
||||
|
||||
decoder_input = torch.cat(
|
||||
(reshaped_lagged_sequence, repeated_features[:, : k + 1]), dim=-1
|
||||
)
|
||||
|
||||
output = self.transformer.decoder(decoder_input, repeated_enc_out)
|
||||
|
||||
params = self.param_proj(output[:, -1:])
|
||||
distr = self.output_distribution(params, scale=repeated_scale)
|
||||
next_sample = distr.sample()
|
||||
|
||||
repeated_past_target = torch.cat(
|
||||
(repeated_past_target, next_sample / repeated_scale), dim=1
|
||||
)
|
||||
future_samples.append(next_sample)
|
||||
|
||||
concat_future_samples = torch.cat(future_samples, dim=1)
|
||||
return concat_future_samples.reshape(
|
||||
return samples.reshape(
|
||||
(-1, self.num_parallel_samples, self.prediction_length) + self.target_shape,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user