Add a linear projector for static real features in TFTModel

This commit is contained in:
Kashif Rasul
2022-03-30 16:50:51 +02:00
parent 85cd9bd629
commit 7e25ade272
+20 -16
View File
@@ -268,7 +268,6 @@ class TFTModel(nn.Module):
dropout: float,
# univariate input
input_size: int = 1,
embedding_dimension: Optional[List[int]] = None,
distr_output: DistributionOutput = StudentTOutput(),
lags_seq: Optional[List[int]] = None,
scaling: bool = True,
@@ -282,17 +281,14 @@ class TFTModel(nn.Module):
self.num_feat_dynamic_real = num_feat_dynamic_real
self.num_feat_static_cat = num_feat_static_cat
self.num_feat_static_real = num_feat_static_real
self.embedding_dimension = (
embedding_dimension
if embedding_dimension is not None or cardinality is None
else [min(50, (cat + 1) // 2) for cat in cardinality]
)
self.lags_seq = lags_seq or get_lags_for_frequency(freq_str=freq)
self.num_parallel_samples = num_parallel_samples
self.history_length = context_length + max(self.lags_seq)
self.embedder = FeatureEmbedder(
cardinalities=cardinality,
embedding_dims=self.embedding_dimension,
embedding_dims=[variable_dim] * num_feat_static_cat,
)
if scaling:
self.scaler = MeanScaler(dim=1, keepdim=True)
@@ -312,26 +308,28 @@ class TFTModel(nn.Module):
in_features=num_feat_dynamic_real, out_features=variable_dim
)
self.static_feat_proj = nn.Linear(
in_features=num_feat_static_real + input_size, out_features=variable_dim
)
# variable selection networks
self.past_selection = VariableSelectionNetwork(
d_hidden=variable_dim,
n_vars=input_size * len(self.lags_seq) + num_feat_dynamic_real,
n_vars=2, # target and time features
dropout=dropout,
add_static=True,
)
self.future_selection = VariableSelectionNetwork(
d_hidden=variable_dim,
n_vars=input_size * len(self.lags_seq) + num_feat_dynamic_real,
n_vars=2, # target and time features
dropout=dropout,
add_static=True,
)
self.static_selection = VariableSelectionNetwork(
d_hidden=variable_dim,
n_vars=sum(self.embedding_dimension)
+ self.num_feat_static_real
+ input_size,
n_vars=2, # cat, static_feat
dropout=dropout,
)
@@ -466,17 +464,21 @@ class TFTModel(nn.Module):
# embeddings
embedded_cat = self.embedder(feat_static_cat)
static_feat = embedded_cat + [feat_static_real, scale.log()]
static_feat = torch.cat(
(feat_static_real, scale.log()),
dim=1,
)
# return the network inputs
return (
reshaped_lagged_target, # target
time_feat, # dynamic real covariates
scale, # scale
static_feat, # static covariates
embedded_cat, # static covariates
static_feat,
)
def output_params(self, target, time_feat, static_feat):
def output_params(self, target, time_feat, embedded_cat, static_feat):
target_proj = self.target_proj(target)
past_target_proj = target_proj[:, : self.context_length, ...]
@@ -486,7 +488,9 @@ class TFTModel(nn.Module):
past_time_feat_proj = time_feat_proj[:, : self.context_length, ...]
future_time_feat_proj = time_feat_proj[:, self.context_length :, ...]
static_var, _ = self.static_selection(static_feat)
static_feat_proj = self.static_feat_proj(static_feat)
static_var, _ = self.static_selection(embedded_cat + [static_feat_proj])
static_selection = self.selection(static_var).unsqueeze(1)
static_enrichment = self.enrichment(static_var).unsqueeze(1)