fix time_feat

This commit is contained in:
Kashif Rasul
2022-03-30 19:46:19 +02:00
parent d4d460100c
commit 2f6ce6300a
+4 -7
View File
@@ -164,15 +164,15 @@ class TransformerModel(nn.Module):
):
# time feature
time_feat = (
past_time_feat[:, self._past_length - self.context_length :, ...]
if future_time_feat is None or future_target is None
else torch.cat(
torch.cat(
(
past_time_feat[:, self._past_length - self.context_length :, ...],
future_time_feat,
),
dim=1,
)
if future_time_feat is not None
else past_time_feat[:, self._past_length - self.context_length :, ...]
)
# target
@@ -224,10 +224,7 @@ class TransformerModel(nn.Module):
lags_shape[0], lags_shape[1], -1
)
if features is None:
transformer_inputs = reshaped_lagged_sequence
else:
transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)
transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)
return transformer_inputs, scale, static_feat