predict loop

This commit is contained in:
Kashif Rasul
2022-03-30 22:21:13 +02:00
parent aa2dcc2887
commit 9b2c861153
+54 -21
View File
@@ -378,11 +378,9 @@ class TFTModel(nn.Module):
dropout=dropout,
)
# TODO
# distribution output
self.param_proj = distr_output.get_args_proj(embed_dim)
# TODO
@property
def _past_length(self) -> int:
return self.context_length + max(self.lags_seq)
@@ -574,28 +572,63 @@ class TFTModel(nn.Module):
c_c = self.state_c(static_var)
states = [c_h.unsqueeze(0), c_c.unsqueeze(0)]
enc_out = self.temporal_encoder(past_selection, tgt_input=None, states=states)
dec_output = self.temporal_decoder(enc_out, static_enrichment, causal=False)
params = self.param_proj(dec_output)
repeated_scale = scale.repeat_interleave(
repeats=self.num_parallel_samples, dim=0
)
repeated_time_feat_proj = future_time_feat_proj.repeat_interleave(
repeats=self.num_parallel_samples, dim=0
)
distr = self.output_distribution(params, scale=scale, trailing_n=1)
repeated_past_target = (
past_target.repeat_interleave(repeats=self.num_parallel_samples, dim=0)
/ repeated_scale
)
repeated_past_selection = past_selection.repeat_interleave(
repeats=self.num_parallel_samples, dim=0
)
repeated_static_enrichment = static_enrichment.repeat_interleave(
repeats=self.num_parallel_samples, dim=0
)
repeated_states = [
s.repeat_interleave(repeats=self.num_parallel_samples, dim=1)
for s in states
]
next_sample = distr.sample()
future_samples = [next_sample]
import pdb
pdb.set_trace()
for k in range(1, self.prediction_length):
# TODO
# future_target_proj = pass
# future_time_feat_proj = pass
future_samples = []
for k in range(self.prediction_length):
lagged_sequence = self.get_lagged_subsequences(
sequence=repeated_past_target,
subsequences_length=1,
shift=1,
)
lags_shape = lagged_sequence.shape
reshaped_lagged_sequence = lagged_sequence.reshape(
lags_shape[0], lags_shape[1], -1
)
reshaped_lagged_sequence_proj = self.target_proj(reshaped_lagged_sequence)
next_time_feat_proj = repeated_time_feat_proj[:, k : k + 1]
future_selection, _ = self.future_selection(
[future_target_proj, future_time_feat_proj], static_selection
[reshaped_lagged_sequence_proj, next_time_feat_proj], static_selection
)
enc_out, states = self.temporal_encoder(
past_selection, future_selection, states
enc_out = self.temporal_encoder(
repeated_past_selection, future_selection, repeated_states
)
dec_output = self.temporal_decoder(enc_out, static_enrichment)
dec_output = self.temporal_decoder(
enc_out, repeated_static_enrichment, causal=False
)
params = self.param_proj(dec_output)
distr = self.output_distribution(params, scale=repeated_scale)
next_sample = distr.sample()
repeated_past_target = torch.cat(
(repeated_past_target, next_sample / repeated_scale), dim=1
)
future_samples.append(next_sample)
concat_future_samples = torch.cat(future_samples, dim=1)
return concat_future_samples.reshape(
(-1, self.num_parallel_samples, self.prediction_length) + self.target_shape,
)