mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 18:06:14 +08:00
fix transformer
This commit is contained in:
@@ -171,7 +171,7 @@ class TransformerModel(nn.Module):
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
if future_time_feat is not None
|
||||
if future_target is not None
|
||||
else past_time_feat[:, self._past_length - self.context_length :, ...]
|
||||
)
|
||||
|
||||
@@ -194,9 +194,9 @@ class TransformerModel(nn.Module):
|
||||
assert inputs.shape[1] == inputs_length
|
||||
|
||||
subsequences_length = (
|
||||
self.context_length
|
||||
if future_time_feat is None or future_target is None
|
||||
else self.context_length + self.prediction_length
|
||||
self.context_length + self.prediction_length
|
||||
if future_target is not None
|
||||
else self.context_length
|
||||
)
|
||||
|
||||
# embeddings
|
||||
|
||||
Reference in New Issue
Block a user