mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 18:06:14 +08:00
scale in distrb space
This commit is contained in:
@@ -269,7 +269,6 @@ class TransformerModel(nn.Module):
|
||||
past_time_feat,
|
||||
past_target,
|
||||
past_observed_values,
|
||||
future_time_feat,
|
||||
)
|
||||
|
||||
enc_out = self.transformer.encoder(encoder_inputs)
|
||||
@@ -300,7 +299,6 @@ class TransformerModel(nn.Module):
|
||||
)
|
||||
|
||||
# self._check_shapes(repeated_past_target, next_sample, next_features)
|
||||
|
||||
# sequence = torch.cat((repeated_past_target, next_sample), dim=1)
|
||||
|
||||
lagged_sequence = self.get_lagged_subsequences(
|
||||
@@ -319,13 +317,15 @@ class TransformerModel(nn.Module):
|
||||
output = self.transformer.decoder(decoder_input, repeated_enc_out)
|
||||
|
||||
params = self.param_proj(output)
|
||||
distr = self.output_distribution(params)
|
||||
distr = self.output_distribution(params, scale=repeated_scale)
|
||||
next_sample = distr.sample()
|
||||
|
||||
repeated_past_target = torch.cat((repeated_past_target, next_sample), dim=1)
|
||||
repeated_past_target = torch.cat(
|
||||
(repeated_past_target, next_sample / repeated_scale), dim=1
|
||||
)
|
||||
future_samples.append(next_sample)
|
||||
|
||||
unscaled_future_samples = torch.cat(future_samples, dim=1) * repeated_scale
|
||||
return unscaled_future_samples.reshape(
|
||||
concat_future_samples = torch.cat(future_samples, dim=1)
|
||||
return concat_future_samples.reshape(
|
||||
(-1, self.num_parallel_samples, self.prediction_length) + self.target_shape,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user