mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 18:44:00 +08:00
do not project static
This commit is contained in:
+1
-10
@@ -312,13 +312,6 @@ class TFTModel(nn.Module):
|
||||
in_features=num_feat_dynamic_real, out_features=variable_dim
|
||||
)
|
||||
|
||||
self.static_proj = nn.Linear(
|
||||
in_features=sum(self.embedding_dimension)
|
||||
+ self.num_feat_static_real
|
||||
+ input_size,
|
||||
out_features=variable_dim,
|
||||
)
|
||||
|
||||
# variable selection networks
|
||||
self.past_selection = VariableSelectionNetwork(
|
||||
d_hidden=variable_dim,
|
||||
@@ -493,9 +486,7 @@ class TFTModel(nn.Module):
|
||||
past_time_feat_proj = time_feat_proj[:, : self.context_length, ...]
|
||||
future_time_feat_proj = time_feat_proj[:, self.context_length :, ...]
|
||||
|
||||
static_feat_proj = self.static_proj(static_feat)
|
||||
|
||||
static_var, _ = self.static_selection([static_feat_proj])
|
||||
static_var, _ = self.static_selection(static_feat)
|
||||
static_selection = self.selection(static_var).unsqueeze(1)
|
||||
static_enrichment = self.enrichment(static_var).unsqueeze(1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user