do not project static

This commit is contained in:
Kashif Rasul
2022-03-30 15:28:06 +02:00
parent bc45028959
commit 85cd9bd629
+1 -10
View File
@@ -312,13 +312,6 @@ class TFTModel(nn.Module):
in_features=num_feat_dynamic_real, out_features=variable_dim
)
self.static_proj = nn.Linear(
in_features=sum(self.embedding_dimension)
+ self.num_feat_static_real
+ input_size,
out_features=variable_dim,
)
# variable selection networks
self.past_selection = VariableSelectionNetwork(
d_hidden=variable_dim,
@@ -493,9 +486,7 @@ class TFTModel(nn.Module):
past_time_feat_proj = time_feat_proj[:, : self.context_length, ...]
future_time_feat_proj = time_feat_proj[:, self.context_length :, ...]
static_feat_proj = self.static_proj(static_feat)
static_var, _ = self.static_selection([static_feat_proj])
static_var, _ = self.static_selection(static_feat)
static_selection = self.selection(static_var).unsqueeze(1)
static_enrichment = self.enrichment(static_var).unsqueeze(1)