diff --git a/tft/module.py b/tft/module.py index 5fd624d..0626093 100644 --- a/tft/module.py +++ b/tft/module.py @@ -268,7 +268,6 @@ class TFTModel(nn.Module): dropout: float, # univariate input input_size: int = 1, - embedding_dimension: Optional[List[int]] = None, distr_output: DistributionOutput = StudentTOutput(), lags_seq: Optional[List[int]] = None, scaling: bool = True, @@ -282,17 +281,14 @@ class TFTModel(nn.Module): self.num_feat_dynamic_real = num_feat_dynamic_real self.num_feat_static_cat = num_feat_static_cat self.num_feat_static_real = num_feat_static_real - self.embedding_dimension = ( - embedding_dimension - if embedding_dimension is not None or cardinality is None - else [min(50, (cat + 1) // 2) for cat in cardinality] - ) + self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq) self.num_parallel_samples = num_parallel_samples self.history_length = context_length + max(self.lags_seq) + self.embedder = FeatureEmbedder( cardinalities=cardinality, - embedding_dims=self.embedding_dimension, + embedding_dims=[variable_dim] * num_feat_static_cat, ) if scaling: self.scaler = MeanScaler(dim=1, keepdim=True) @@ -312,26 +308,28 @@ class TFTModel(nn.Module): in_features=num_feat_dynamic_real, out_features=variable_dim ) + self.static_feat_proj = nn.Linear( + in_features=num_feat_static_real + input_size, out_features=variable_dim + ) + # variable selection networks self.past_selection = VariableSelectionNetwork( d_hidden=variable_dim, - n_vars=input_size * len(self.lags_seq) + num_feat_dynamic_real, + n_vars=2, # target and time features dropout=dropout, add_static=True, ) self.future_selection = VariableSelectionNetwork( d_hidden=variable_dim, - n_vars=input_size * len(self.lags_seq) + num_feat_dynamic_real, + n_vars=2, # target and time features dropout=dropout, add_static=True, ) self.static_selection = VariableSelectionNetwork( d_hidden=variable_dim, - n_vars=sum(self.embedding_dimension) - + self.num_feat_static_real - + input_size, + n_vars=2, # cat, static_feat dropout=dropout, ) @@ -466,17 +464,21 @@ class TFTModel(nn.Module): # embeddings embedded_cat = self.embedder(feat_static_cat) - static_feat = embedded_cat + [feat_static_real, scale.log()] + static_feat = torch.cat( + (feat_static_real, scale.log()), + dim=1, + ) # return the network inputs return ( reshaped_lagged_target, # target time_feat, # dynamic real covariates scale, # scale - static_feat, # static covariates + embedded_cat, # static covariates + static_feat, ) - def output_params(self, target, time_feat, static_feat): + def output_params(self, target, time_feat, embedded_cat, static_feat): target_proj = self.target_proj(target) past_target_proj = target_proj[:, : self.context_length, ...] @@ -486,7 +488,9 @@ class TFTModel(nn.Module): past_time_feat_proj = time_feat_proj[:, : self.context_length, ...] future_time_feat_proj = time_feat_proj[:, self.context_length :, ...] - static_var, _ = self.static_selection(static_feat) + static_feat_proj = self.static_feat_proj(static_feat) + + static_var, _ = self.static_selection(embedded_cat + [static_feat_proj]) static_selection = self.selection(static_var).unsqueeze(1) static_enrichment = self.enrichment(static_var).unsqueeze(1)