From d4d460100c7727a62768a5e29efd855f149fd7bf Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 30 Mar 2022 19:41:11 +0200 Subject: [PATCH] fix time feat concat --- tft/module.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/tft/module.py b/tft/module.py index 073bc82..9d050c1 100644 --- a/tft/module.py +++ b/tft/module.py @@ -426,15 +426,15 @@ class TFTModel(nn.Module): ): # time feature time_feat = ( - past_time_feat[:, self._past_length - self.context_length :, ...] - if future_time_feat is None or future_target is None - else torch.cat( + torch.cat( ( past_time_feat[:, self._past_length - self.context_length :, ...], future_time_feat, ), dim=1, ) + if future_time_feat is not None + else past_time_feat[:, self._past_length - self.context_length :, ...] ) # calculate scale @@ -535,11 +535,20 @@ class TFTModel(nn.Module): if num_parallel_samples is None: num_parallel_samples = self.num_parallel_samples - target, time_feat, scale, static_feat = self.create_network_inputs( - feat_static_cat, - feat_static_real, - past_time_feat, - past_target, - past_observed_values, - future_time_feat, + ( + target, + time_feat, + scale, + embedded_cat, + static_feat, + ) = self.create_network_inputs( + feat_static_cat=feat_static_cat, + feat_static_real=feat_static_real, + past_time_feat=past_time_feat, + past_target=past_target, + past_observed_values=past_observed_values, + future_time_feat=future_time_feat, ) + + target_proj = self.target_proj(target) + time_feat_proj = self.dynamic_proj(time_feat)