fix input_size > 1

This commit is contained in:
Kashif Rasul
2022-10-17 11:47:07 +02:00
parent fd274b3217
commit 1af9b1122d
2 changed files with 4 additions and 3 deletions
+1 -1
View File
@@ -74,6 +74,6 @@ class HopfieldLightningModule(pl.LightningModule):
if len(self.model.target_shape) == 0:
loss_weights = future_observed_values
else:
loss_weights = future_observed_values.min(dim=-1, keepdim=False)
loss_weights, _ = future_observed_values.min(dim=-1, keepdim=False)
return weighted_average(loss_values, weights=loss_weights)
+3 -2
View File
@@ -115,7 +115,7 @@ class HopfieldModel(nn.Module):
sum(self.embedding_dimension)
+ self.num_feat_dynamic_real
+ self.num_feat_static_real
+ 1 # the log(scale)
+ self.input_size # the log(scale)
)
@property
@@ -224,8 +224,9 @@ class HopfieldModel(nn.Module):
# embeddings
embedded_cat = self.embedder(feat_static_cat)
log_scale = scale.log() if self.input_size == 1 else scale.squeeze(1).log()
static_feat = torch.cat(
(embedded_cat, feat_static_real, scale.log()),
(embedded_cat, feat_static_real, log_scale),
dim=1,
)
expanded_static_feat = static_feat.unsqueeze(1).expand(