scale in distrb space

This commit is contained in:
Kashif Rasul
2022-03-30 23:25:32 +02:00
parent 7fe6485cdd
commit 789cde9dc0
+6 -6
View File
@@ -269,7 +269,6 @@ class TransformerModel(nn.Module):
past_time_feat,
past_target,
past_observed_values,
future_time_feat,
)
enc_out = self.transformer.encoder(encoder_inputs)
@@ -300,7 +299,6 @@ class TransformerModel(nn.Module):
)
# self._check_shapes(repeated_past_target, next_sample, next_features)
# sequence = torch.cat((repeated_past_target, next_sample), dim=1)
lagged_sequence = self.get_lagged_subsequences(
@@ -319,13 +317,15 @@ class TransformerModel(nn.Module):
output = self.transformer.decoder(decoder_input, repeated_enc_out)
params = self.param_proj(output)
distr = self.output_distribution(params)
distr = self.output_distribution(params, scale=repeated_scale)
next_sample = distr.sample()
repeated_past_target = torch.cat((repeated_past_target, next_sample), dim=1)
repeated_past_target = torch.cat(
(repeated_past_target, next_sample / repeated_scale), dim=1
)
future_samples.append(next_sample)
unscaled_future_samples = torch.cat(future_samples, dim=1) * repeated_scale
return unscaled_future_samples.reshape(
concat_future_samples = torch.cat(future_samples, dim=1)
return concat_future_samples.reshape(
(-1, self.num_parallel_samples, self.prediction_length) + self.target_shape,
)