From 2f6ce6300a404e7193165fab9887600eba1bd221 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Wed, 30 Mar 2022 19:46:19 +0200 Subject: [PATCH] fix time_feat --- transformer/module.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/transformer/module.py b/transformer/module.py index 3b20a03..da2d81d 100644 --- a/transformer/module.py +++ b/transformer/module.py @@ -164,15 +164,15 @@ class TransformerModel(nn.Module): ): # time feature time_feat = ( - past_time_feat[:, self._past_length - self.context_length :, ...] - if future_time_feat is None or future_target is None - else torch.cat( + torch.cat( ( past_time_feat[:, self._past_length - self.context_length :, ...], future_time_feat, ), dim=1, ) + if future_time_feat is not None + else past_time_feat[:, self._past_length - self.context_length :, ...] ) # target @@ -224,10 +224,7 @@ class TransformerModel(nn.Module): lags_shape[0], lags_shape[1], -1 ) - if features is None: - transformer_inputs = reshaped_lagged_sequence - else: - transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1) + transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1) return transformer_inputs, scale, static_feat