This commit is contained in:
Kashif Rasul
2022-03-30 22:25:19 +02:00
parent 9b2c861153
commit 623b69a219
+5 -1
View File
@@ -586,6 +586,9 @@ class TFTModel(nn.Module):
repeated_past_selection = past_selection.repeat_interleave(
repeats=self.num_parallel_samples, dim=0
)
repeated_static_selection = static_selection.repeat_interleave(
repeats=self.num_parallel_samples, dim=0
)
repeated_static_enrichment = static_enrichment.repeat_interleave(
repeats=self.num_parallel_samples, dim=0
)
@@ -609,7 +612,8 @@ class TFTModel(nn.Module):
next_time_feat_proj = repeated_time_feat_proj[:, k : k + 1]
future_selection, _ = self.future_selection(
[reshaped_lagged_sequence_proj, next_time_feat_proj], static_selection
[reshaped_lagged_sequence_proj, next_time_feat_proj],
repeated_static_selection,
)
enc_out = self.temporal_encoder(
repeated_past_selection, future_selection, repeated_states