mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 18:06:14 +08:00
add static real projector
This commit is contained in:
+20
-16
@@ -268,7 +268,6 @@ class TFTModel(nn.Module):
|
||||
dropout: float,
|
||||
# univariate input
|
||||
input_size: int = 1,
|
||||
embedding_dimension: Optional[List[int]] = None,
|
||||
distr_output: DistributionOutput = StudentTOutput(),
|
||||
lags_seq: Optional[List[int]] = None,
|
||||
scaling: bool = True,
|
||||
@@ -282,17 +281,14 @@ class TFTModel(nn.Module):
|
||||
self.num_feat_dynamic_real = num_feat_dynamic_real
|
||||
self.num_feat_static_cat = num_feat_static_cat
|
||||
self.num_feat_static_real = num_feat_static_real
|
||||
self.embedding_dimension = (
|
||||
embedding_dimension
|
||||
if embedding_dimension is not None or cardinality is None
|
||||
else [min(50, (cat + 1) // 2) for cat in cardinality]
|
||||
)
|
||||
|
||||
self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
|
||||
self.num_parallel_samples = num_parallel_samples
|
||||
self.history_length = context_length + max(self.lags_seq)
|
||||
|
||||
self.embedder = FeatureEmbedder(
|
||||
cardinalities=cardinality,
|
||||
embedding_dims=self.embedding_dimension,
|
||||
embedding_dims=[variable_dim] * num_feat_static_cat,
|
||||
)
|
||||
if scaling:
|
||||
self.scaler = MeanScaler(dim=1, keepdim=True)
|
||||
@@ -312,26 +308,28 @@ class TFTModel(nn.Module):
|
||||
in_features=num_feat_dynamic_real, out_features=variable_dim
|
||||
)
|
||||
|
||||
self.static_feat_proj = nn.Linear(
|
||||
in_features=num_feat_static_real + input_size, out_features=variable_dim
|
||||
)
|
||||
|
||||
# variable selection networks
|
||||
self.past_selection = VariableSelectionNetwork(
|
||||
d_hidden=variable_dim,
|
||||
n_vars=input_size * len(self.lags_seq) + num_feat_dynamic_real,
|
||||
n_vars=2, # target and time features
|
||||
dropout=dropout,
|
||||
add_static=True,
|
||||
)
|
||||
|
||||
self.future_selection = VariableSelectionNetwork(
|
||||
d_hidden=variable_dim,
|
||||
n_vars=input_size * len(self.lags_seq) + num_feat_dynamic_real,
|
||||
n_vars=2, # target and time features
|
||||
dropout=dropout,
|
||||
add_static=True,
|
||||
)
|
||||
|
||||
self.static_selection = VariableSelectionNetwork(
|
||||
d_hidden=variable_dim,
|
||||
n_vars=sum(self.embedding_dimension)
|
||||
+ self.num_feat_static_real
|
||||
+ input_size,
|
||||
n_vars=2, # cat, static_feat
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
@@ -466,17 +464,21 @@ class TFTModel(nn.Module):
|
||||
|
||||
# embeddings
|
||||
embedded_cat = self.embedder(feat_static_cat)
|
||||
static_feat = embedded_cat + [feat_static_real, scale.log()]
|
||||
static_feat = torch.cat(
|
||||
(feat_static_real, scale.log()),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# return the network inputs
|
||||
return (
|
||||
reshaped_lagged_target, # target
|
||||
time_feat, # dynamic real covariates
|
||||
scale, # scale
|
||||
static_feat, # static covariates
|
||||
embedded_cat, # static covariates
|
||||
static_feat,
|
||||
)
|
||||
|
||||
def output_params(self, target, time_feat, static_feat):
|
||||
def output_params(self, target, time_feat, embedded_cat, static_feat):
|
||||
target_proj = self.target_proj(target)
|
||||
|
||||
past_target_proj = target_proj[:, : self.context_length, ...]
|
||||
@@ -486,7 +488,9 @@ class TFTModel(nn.Module):
|
||||
past_time_feat_proj = time_feat_proj[:, : self.context_length, ...]
|
||||
future_time_feat_proj = time_feat_proj[:, self.context_length :, ...]
|
||||
|
||||
static_var, _ = self.static_selection(static_feat)
|
||||
static_feat_proj = self.static_feat_proj(static_feat)
|
||||
|
||||
static_var, _ = self.static_selection(embedded_cat + [static_feat_proj])
|
||||
static_selection = self.selection(static_var).unsqueeze(1)
|
||||
static_enrichment = self.enrichment(static_var).unsqueeze(1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user