mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 19:32:05 +08:00
fix time feat concat
This commit is contained in:
+19
-10
@@ -426,15 +426,15 @@ class TFTModel(nn.Module):
|
||||
):
|
||||
# time feature
|
||||
time_feat = (
|
||||
past_time_feat[:, self._past_length - self.context_length :, ...]
|
||||
if future_time_feat is None or future_target is None
|
||||
else torch.cat(
|
||||
torch.cat(
|
||||
(
|
||||
past_time_feat[:, self._past_length - self.context_length :, ...],
|
||||
future_time_feat,
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
if future_time_feat is not None
|
||||
else past_time_feat[:, self._past_length - self.context_length :, ...]
|
||||
)
|
||||
|
||||
# calculate scale
|
||||
@@ -535,11 +535,20 @@ class TFTModel(nn.Module):
|
||||
if num_parallel_samples is None:
|
||||
num_parallel_samples = self.num_parallel_samples
|
||||
|
||||
target, time_feat, scale, static_feat = self.create_network_inputs(
|
||||
feat_static_cat,
|
||||
feat_static_real,
|
||||
past_time_feat,
|
||||
past_target,
|
||||
past_observed_values,
|
||||
future_time_feat,
|
||||
(
|
||||
target,
|
||||
time_feat,
|
||||
scale,
|
||||
embedded_cat,
|
||||
static_feat,
|
||||
) = self.create_network_inputs(
|
||||
feat_static_cat=feat_static_cat,
|
||||
feat_static_real=feat_static_real,
|
||||
past_time_feat=past_time_feat,
|
||||
past_target=past_target,
|
||||
past_observed_values=past_observed_values,
|
||||
future_time_feat=future_time_feat,
|
||||
)
|
||||
|
||||
target_proj = self.target_proj(target)
|
||||
time_feat_proj = self.dynamic_proj(time_feat)
|
||||
|
||||
Reference in New Issue
Block a user