do not concat static features

This commit is contained in:
Kashif Rasul
2022-03-30 15:23:39 +02:00
parent 1c9df12594
commit bc45028959
+7 -6
View File
@@ -313,7 +313,9 @@ class TFTModel(nn.Module):
)
self.static_proj = nn.Linear(
in_features=sum(self.embedding_dimension) + self.num_feat_static_real + 1,
in_features=sum(self.embedding_dimension)
+ self.num_feat_static_real
+ input_size,
out_features=variable_dim,
)
@@ -334,7 +336,9 @@ class TFTModel(nn.Module):
self.static_selection = VariableSelectionNetwork(
d_hidden=variable_dim,
n_vars=sum(self.embedding_dimension) + self.num_feat_static_real + 1,
n_vars=sum(self.embedding_dimension)
+ self.num_feat_static_real
+ input_size,
dropout=dropout,
)
@@ -469,10 +473,7 @@ class TFTModel(nn.Module):
# embeddings
embedded_cat = self.embedder(feat_static_cat)
static_feat = torch.cat(
(embedded_cat, feat_static_real, scale.log()),
dim=1,
)
static_feat = embedded_cat + [feat_static_real, scale.log()]
# return the network inputs
return (