fix transformer

This commit is contained in:
Kashif Rasul
2022-03-30 20:22:50 +02:00
parent 2f6ce6300a
commit f6403dd4b4
+4 -4
View File
@@ -171,7 +171,7 @@ class TransformerModel(nn.Module):
),
dim=1,
)
if future_time_feat is not None
if future_target is not None
else past_time_feat[:, self._past_length - self.context_length :, ...]
)
@@ -194,9 +194,9 @@ class TransformerModel(nn.Module):
assert inputs.shape[1] == inputs_length
subsequences_length = (
self.context_length
if future_time_feat is None or future_target is None
else self.context_length + self.prediction_length
self.context_length + self.prediction_length
if future_target is not None
else self.context_length
)
# embeddings