mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 19:32:05 +08:00
fix time_feat
This commit is contained in:
@@ -164,15 +164,15 @@ class TransformerModel(nn.Module):
|
||||
):
|
||||
# time feature
|
||||
time_feat = (
|
||||
past_time_feat[:, self._past_length - self.context_length :, ...]
|
||||
if future_time_feat is None or future_target is None
|
||||
else torch.cat(
|
||||
torch.cat(
|
||||
(
|
||||
past_time_feat[:, self._past_length - self.context_length :, ...],
|
||||
future_time_feat,
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
if future_time_feat is not None
|
||||
else past_time_feat[:, self._past_length - self.context_length :, ...]
|
||||
)
|
||||
|
||||
# target
|
||||
@@ -224,10 +224,7 @@ class TransformerModel(nn.Module):
|
||||
lags_shape[0], lags_shape[1], -1
|
||||
)
|
||||
|
||||
if features is None:
|
||||
transformer_inputs = reshaped_lagged_sequence
|
||||
else:
|
||||
transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)
|
||||
transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)
|
||||
|
||||
return transformer_inputs, scale, static_feat
|
||||
|
||||
|
||||
Reference in New Issue
Block a user