mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 18:06:14 +08:00
do not concat static features
This commit is contained in:
+7
-6
@@ -313,7 +313,9 @@ class TFTModel(nn.Module):
|
||||
)
|
||||
|
||||
self.static_proj = nn.Linear(
|
||||
in_features=sum(self.embedding_dimension) + self.num_feat_static_real + 1,
|
||||
in_features=sum(self.embedding_dimension)
|
||||
+ self.num_feat_static_real
|
||||
+ input_size,
|
||||
out_features=variable_dim,
|
||||
)
|
||||
|
||||
@@ -334,7 +336,9 @@ class TFTModel(nn.Module):
|
||||
|
||||
self.static_selection = VariableSelectionNetwork(
|
||||
d_hidden=variable_dim,
|
||||
n_vars=sum(self.embedding_dimension) + self.num_feat_static_real + 1,
|
||||
n_vars=sum(self.embedding_dimension)
|
||||
+ self.num_feat_static_real
|
||||
+ input_size,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
@@ -469,10 +473,7 @@ class TFTModel(nn.Module):
|
||||
|
||||
# embeddings
|
||||
embedded_cat = self.embedder(feat_static_cat)
|
||||
static_feat = torch.cat(
|
||||
(embedded_cat, feat_static_real, scale.log()),
|
||||
dim=1,
|
||||
)
|
||||
static_feat = embedded_cat + [feat_static_real, scale.log()]
|
||||
|
||||
# return the network inputs
|
||||
return (
|
||||
|
||||
Reference in New Issue
Block a user