From 1af9b1122d5148a4623327e12ca37f617808a68a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 17 Oct 2022 11:47:07 +0200 Subject: [PATCH] fix input_size > 1 --- hopfield/lightning_module.py | 2 +- hopfield/module.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/hopfield/lightning_module.py b/hopfield/lightning_module.py index b1f6af5..97dcec0 100644 --- a/hopfield/lightning_module.py +++ b/hopfield/lightning_module.py @@ -74,6 +74,6 @@ class HopfieldLightningModule(pl.LightningModule): if len(self.model.target_shape) == 0: loss_weights = future_observed_values else: - loss_weights = future_observed_values.min(dim=-1, keepdim=False) + loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False) return weighted_average(loss_values, weights=loss_weights) diff --git a/hopfield/module.py b/hopfield/module.py index 3a43abd..ccd80b2 100644 --- a/hopfield/module.py +++ b/hopfield/module.py @@ -115,7 +115,7 @@ class HopfieldModel(nn.Module): sum(self.embedding_dimension) + self.num_feat_dynamic_real + self.num_feat_static_real - + 1 # the log(scale) + + self.input_size # the log(scale) ) @property @@ -224,8 +224,9 @@ class HopfieldModel(nn.Module): # embeddings embedded_cat = self.embedder(feat_static_cat) + log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log() static_feat = torch.cat( - (embedded_cat, feat_static_real, scale.log()), + (embedded_cat, feat_static_real, log_scale), dim=1, ) expanded_static_feat = static_feat.unsqueeze(1).expand(