diff --git a/hopfield/lightning_module.py b/hopfield/lightning_module.py index b1f6af5..97dcec0 100644 --- a/hopfield/lightning_module.py +++ b/hopfield/lightning_module.py @@ -74,6 +74,6 @@ class HopfieldLightningModule(pl.LightningModule): if len(self.model.target_shape) == 0: loss_weights = future_observed_values else: - loss_weights = future_observed_values.min(dim=-1, keepdim=False) + loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False) return weighted_average(loss_values, weights=loss_weights) diff --git a/hopfield/module.py b/hopfield/module.py index 3a43abd..ccd80b2 100644 --- a/hopfield/module.py +++ b/hopfield/module.py @@ -115,7 +115,7 @@ class HopfieldModel(nn.Module): sum(self.embedding_dimension) + self.num_feat_dynamic_real + self.num_feat_static_real - + 1 # the log(scale) + + self.input_size # the log(scale) ) @property @@ -224,8 +224,9 @@ class HopfieldModel(nn.Module): # embeddings embedded_cat = self.embedder(feat_static_cat) + log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log() static_feat = torch.cat( - (embedded_cat, feat_static_real, scale.log()), + (embedded_cat, feat_static_real, log_scale), dim=1, ) expanded_static_feat = static_feat.unsqueeze(1).expand(