diff --git a/autoformer/lightning_module.py b/autoformer/lightning_module.py index b29bcdc..1c6e0b2 100644 --- a/autoformer/lightning_module.py +++ b/autoformer/lightning_module.py @@ -79,6 +79,6 @@ class AutoformerLightningModule(pl.LightningModule): if len(self.model.target_shape) == 0: loss_weights = future_observed_values else: - loss_weights = future_observed_values.min(dim=-1, keepdim=False) + loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False) return weighted_average(loss_values, weights=loss_weights) diff --git a/autoformer/module.py b/autoformer/module.py index 4191920..9ec6c44 100644 --- a/autoformer/module.py +++ b/autoformer/module.py @@ -587,7 +587,7 @@ class AutoformerModel(nn.Module): sum(self.embedding_dimension) + self.num_feat_dynamic_real + self.num_feat_static_real - + 1 # the log(scale) + + self.input_size # the log(scale) ) @property @@ -696,8 +696,9 @@ class AutoformerModel(nn.Module): # embeddings embedded_cat = self.embedder(feat_static_cat) + log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log() static_feat = torch.cat( - (embedded_cat, feat_static_real, scale.log()), + (embedded_cat, feat_static_real, log_scale), dim=1, ) expanded_static_feat = static_feat.unsqueeze(1).expand(