diff --git a/tft/module.py b/tft/module.py index 073bc82..9d050c1 100644 --- a/tft/module.py +++ b/tft/module.py @@ -426,15 +426,15 @@ class TFTModel(nn.Module): ): # time feature time_feat = ( - past_time_feat[:, self._past_length - self.context_length :, ...] - if future_time_feat is None or future_target is None - else torch.cat( + torch.cat( ( past_time_feat[:, self._past_length - self.context_length :, ...], future_time_feat, ), dim=1, ) + if future_time_feat is not None + else past_time_feat[:, self._past_length - self.context_length :, ...] ) # calculate scale @@ -535,11 +535,20 @@ class TFTModel(nn.Module): if num_parallel_samples is None: num_parallel_samples = self.num_parallel_samples - target, time_feat, scale, static_feat = self.create_network_inputs( - feat_static_cat, - feat_static_real, - past_time_feat, - past_target, - past_observed_values, - future_time_feat, + ( + target, + time_feat, + scale, + embedded_cat, + static_feat, + ) = self.create_network_inputs( + feat_static_cat=feat_static_cat, + feat_static_real=feat_static_real, + past_time_feat=past_time_feat, + past_target=past_target, + past_observed_values=past_observed_values, + future_time_feat=future_time_feat, ) + + target_proj = self.target_proj(target) + time_feat_proj = self.dynamic_proj(time_feat)