diff --git a/transformer/module.py b/transformer/module.py index 3b20a03..da2d81d 100644 --- a/transformer/module.py +++ b/transformer/module.py @@ -164,15 +164,15 @@ class TransformerModel(nn.Module): ): # time feature time_feat = ( - past_time_feat[:, self._past_length - self.context_length :, ...] - if future_time_feat is None or future_target is None - else torch.cat( + torch.cat( ( past_time_feat[:, self._past_length - self.context_length :, ...], future_time_feat, ), dim=1, ) + if future_time_feat is not None + else past_time_feat[:, self._past_length - self.context_length :, ...] ) # target @@ -224,10 +224,7 @@ class TransformerModel(nn.Module): lags_shape[0], lags_shape[1], -1 ) - if features is None: - transformer_inputs = reshaped_lagged_sequence - else: - transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1) + transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1) return transformer_inputs, scale, static_feat