mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 18:06:14 +08:00
typo
This commit is contained in:
+5
-1
@@ -586,6 +586,9 @@ class TFTModel(nn.Module):
|
||||
repeated_past_selection = past_selection.repeat_interleave(
|
||||
repeats=self.num_parallel_samples, dim=0
|
||||
)
|
||||
repeated_static_selection = static_selection.repeat_interleave(
|
||||
repeats=self.num_parallel_samples, dim=0
|
||||
)
|
||||
repeated_static_enrichment = static_enrichment.repeat_interleave(
|
||||
repeats=self.num_parallel_samples, dim=0
|
||||
)
|
||||
@@ -609,7 +612,8 @@ class TFTModel(nn.Module):
|
||||
|
||||
next_time_feat_proj = repeated_time_feat_proj[:, k : k + 1]
|
||||
future_selection, _ = self.future_selection(
|
||||
[reshaped_lagged_sequence_proj, next_time_feat_proj], static_selection
|
||||
[reshaped_lagged_sequence_proj, next_time_feat_proj],
|
||||
repeated_static_selection,
|
||||
)
|
||||
enc_out = self.temporal_encoder(
|
||||
repeated_past_selection, future_selection, repeated_states
|
||||
|
||||
Reference in New Issue
Block a user