mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 18:06:14 +08:00
predict loop
This commit is contained in:
+54
-21
@@ -378,11 +378,9 @@ class TFTModel(nn.Module):
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
# TODO
|
||||
# distribution output
|
||||
self.param_proj = distr_output.get_args_proj(embed_dim)
|
||||
|
||||
# TODO
|
||||
|
||||
@property
|
||||
def _past_length(self) -> int:
|
||||
return self.context_length + max(self.lags_seq)
|
||||
@@ -574,28 +572,63 @@ class TFTModel(nn.Module):
|
||||
c_c = self.state_c(static_var)
|
||||
states = [c_h.unsqueeze(0), c_c.unsqueeze(0)]
|
||||
|
||||
enc_out = self.temporal_encoder(past_selection, tgt_input=None, states=states)
|
||||
dec_output = self.temporal_decoder(enc_out, static_enrichment, causal=False)
|
||||
params = self.param_proj(dec_output)
|
||||
repeated_scale = scale.repeat_interleave(
|
||||
repeats=self.num_parallel_samples, dim=0
|
||||
)
|
||||
repeated_time_feat_proj = future_time_feat_proj.repeat_interleave(
|
||||
repeats=self.num_parallel_samples, dim=0
|
||||
)
|
||||
|
||||
distr = self.output_distribution(params, scale=scale, trailing_n=1)
|
||||
repeated_past_target = (
|
||||
past_target.repeat_interleave(repeats=self.num_parallel_samples, dim=0)
|
||||
/ repeated_scale
|
||||
)
|
||||
repeated_past_selection = past_selection.repeat_interleave(
|
||||
repeats=self.num_parallel_samples, dim=0
|
||||
)
|
||||
repeated_static_enrichment = static_enrichment.repeat_interleave(
|
||||
repeats=self.num_parallel_samples, dim=0
|
||||
)
|
||||
repeated_states = [
|
||||
s.repeat_interleave(repeats=self.num_parallel_samples, dim=1)
|
||||
for s in states
|
||||
]
|
||||
|
||||
next_sample = distr.sample()
|
||||
future_samples = [next_sample]
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
|
||||
for k in range(1, self.prediction_length):
|
||||
# TODO
|
||||
# future_target_proj = pass
|
||||
# future_time_feat_proj = pass
|
||||
future_samples = []
|
||||
for k in range(self.prediction_length):
|
||||
lagged_sequence = self.get_lagged_subsequences(
|
||||
sequence=repeated_past_target,
|
||||
subsequences_length=1,
|
||||
shift=1,
|
||||
)
|
||||
lags_shape = lagged_sequence.shape
|
||||
reshaped_lagged_sequence = lagged_sequence.reshape(
|
||||
lags_shape[0], lags_shape[1], -1
|
||||
)
|
||||
reshaped_lagged_sequence_proj = self.target_proj(reshaped_lagged_sequence)
|
||||
|
||||
next_time_feat_proj = repeated_time_feat_proj[:, k : k + 1]
|
||||
future_selection, _ = self.future_selection(
|
||||
[future_target_proj, future_time_feat_proj], static_selection
|
||||
[reshaped_lagged_sequence_proj, next_time_feat_proj], static_selection
|
||||
)
|
||||
enc_out, states = self.temporal_encoder(
|
||||
past_selection, future_selection, states
|
||||
enc_out = self.temporal_encoder(
|
||||
repeated_past_selection, future_selection, repeated_states
|
||||
)
|
||||
|
||||
dec_output = self.temporal_decoder(enc_out, static_enrichment)
|
||||
dec_output = self.temporal_decoder(
|
||||
enc_out, repeated_static_enrichment, causal=False
|
||||
)
|
||||
|
||||
params = self.param_proj(dec_output)
|
||||
distr = self.output_distribution(params, scale=repeated_scale)
|
||||
|
||||
next_sample = distr.sample()
|
||||
repeated_past_target = torch.cat(
|
||||
(repeated_past_target, next_sample / repeated_scale), dim=1
|
||||
)
|
||||
future_samples.append(next_sample)
|
||||
|
||||
concat_future_samples = torch.cat(future_samples, dim=1)
|
||||
return concat_future_samples.reshape(
|
||||
(-1, self.num_parallel_samples, self.prediction_length) + self.target_shape,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user