From 789cde9dc0a62291f5cc9fbfb4715060e64fb730 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 30 Mar 2022 23:25:32 +0200 Subject: [PATCH] scale in distrb space --- transformer/module.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer/module.py b/transformer/module.py index 710a00f..10742bb 100644 --- a/transformer/module.py +++ b/transformer/module.py @@ -269,7 +269,6 @@ class TransformerModel(nn.Module): past_time_feat, past_target, past_observed_values, - future_time_feat, ) enc_out = self.transformer.encoder(encoder_inputs) @@ -300,7 +299,6 @@ class TransformerModel(nn.Module): ) # self._check_shapes(repeated_past_target, next_sample, next_features) - # sequence = torch.cat((repeated_past_target, next_sample), dim=1) lagged_sequence = self.get_lagged_subsequences( @@ -319,13 +317,15 @@ class TransformerModel(nn.Module): output = self.transformer.decoder(decoder_input, repeated_enc_out) params = self.param_proj(output) - distr = self.output_distribution(params) + distr = self.output_distribution(params, scale=repeated_scale) next_sample = distr.sample() - repeated_past_target = torch.cat((repeated_past_target, next_sample), dim=1) + repeated_past_target = torch.cat( + (repeated_past_target, next_sample / repeated_scale), dim=1 + ) future_samples.append(next_sample) - unscaled_future_samples = torch.cat(future_samples, dim=1) * repeated_scale - return unscaled_future_samples.reshape( + concat_future_samples = torch.cat(future_samples, dim=1) + return concat_future_samples.reshape( (-1, self.num_parallel_samples, self.prediction_length) + self.target_shape, )