From 498f80fab021803db96c7b5c971d23703e6600be Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 17 Oct 2022 11:41:18 +0200 Subject: [PATCH] fix for input_size>1 --- autoformer/lightning_module.py | 2 +- autoformer/module.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/autoformer/lightning_module.py b/autoformer/lightning_module.py index b29bcdc..1c6e0b2 100644 --- a/autoformer/lightning_module.py +++ b/autoformer/lightning_module.py @@ -79,6 +79,6 @@ class AutoformerLightningModule(pl.LightningModule): if len(self.model.target_shape) == 0: loss_weights = future_observed_values else: - loss_weights = future_observed_values.min(dim=-1, keepdim=False) + loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False) return weighted_average(loss_values, weights=loss_weights) diff --git a/autoformer/module.py b/autoformer/module.py index 4191920..9ec6c44 100644 --- a/autoformer/module.py +++ b/autoformer/module.py @@ -587,7 +587,7 @@ class AutoformerModel(nn.Module): sum(self.embedding_dimension) + self.num_feat_dynamic_real + self.num_feat_static_real - + 1 # the log(scale) + + self.input_size # the log(scale) ) @property @@ -696,8 +696,9 @@ class AutoformerModel(nn.Module): # embeddings embedded_cat = self.embedder(feat_static_cat) + log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log() static_feat = torch.cat( - (embedded_cat, feat_static_real, scale.log()), + (embedded_cat, feat_static_real, log_scale), dim=1, ) expanded_static_feat = static_feat.unsqueeze(1).expand(