fix time feat concat

This commit is contained in:
Kashif Rasul
2022-03-30 19:41:11 +02:00
parent c653847fc5
commit d4d460100c
+19 -10
View File
@@ -426,15 +426,15 @@ class TFTModel(nn.Module):
):
# time feature
time_feat = (
past_time_feat[:, self._past_length - self.context_length :, ...]
if future_time_feat is None or future_target is None
else torch.cat(
torch.cat(
(
past_time_feat[:, self._past_length - self.context_length :, ...],
future_time_feat,
),
dim=1,
)
if future_time_feat is not None
else past_time_feat[:, self._past_length - self.context_length :, ...]
)
# calculate scale
@@ -535,11 +535,20 @@ class TFTModel(nn.Module):
if num_parallel_samples is None:
num_parallel_samples = self.num_parallel_samples
target, time_feat, scale, static_feat = self.create_network_inputs(
feat_static_cat,
feat_static_real,
past_time_feat,
past_target,
past_observed_values,
future_time_feat,
(
target,
time_feat,
scale,
embedded_cat,
static_feat,
) = self.create_network_inputs(
feat_static_cat=feat_static_cat,
feat_static_real=feat_static_real,
past_time_feat=past_time_feat,
past_target=past_target,
past_observed_values=past_observed_values,
future_time_feat=future_time_feat,
)
target_proj = self.target_proj(target)
time_feat_proj = self.dynamic_proj(time_feat)