diff --git a/tft/module.py b/tft/module.py index 494208f..7420b7c 100644 --- a/tft/module.py +++ b/tft/module.py @@ -378,11 +378,9 @@ class TFTModel(nn.Module): dropout=dropout, ) - # TODO + # distribution output self.param_proj = distr_output.get_args_proj(embed_dim) - # TODO - @property def _past_length(self) -> int: return self.context_length + max(self.lags_seq) @@ -574,28 +572,63 @@ class TFTModel(nn.Module): c_c = self.state_c(static_var) states = [c_h.unsqueeze(0), c_c.unsqueeze(0)] - enc_out = self.temporal_encoder(past_selection, tgt_input=None, states=states) - dec_output = self.temporal_decoder(enc_out, static_enrichment, causal=False) - params = self.param_proj(dec_output) + repeated_scale = scale.repeat_interleave( + repeats=self.num_parallel_samples, dim=0 + ) + repeated_time_feat_proj = future_time_feat_proj.repeat_interleave( + repeats=self.num_parallel_samples, dim=0 + ) - distr = self.output_distribution(params, scale=scale, trailing_n=1) + repeated_past_target = ( + past_target.repeat_interleave(repeats=self.num_parallel_samples, dim=0) + / repeated_scale + ) + repeated_past_selection = past_selection.repeat_interleave( + repeats=self.num_parallel_samples, dim=0 + ) + repeated_static_enrichment = static_enrichment.repeat_interleave( + repeats=self.num_parallel_samples, dim=0 + ) + repeated_states = [ + s.repeat_interleave(repeats=self.num_parallel_samples, dim=1) + for s in states + ] - next_sample = distr.sample() - future_samples = [next_sample] - import pdb - - pdb.set_trace() - - for k in range(1, self.prediction_length): - # TODO - # future_target_proj = pass - # future_time_feat_proj = pass + future_samples = [] + for k in range(self.prediction_length): + lagged_sequence = self.get_lagged_subsequences( + sequence=repeated_past_target, + subsequences_length=1, + shift=1, + ) + lags_shape = lagged_sequence.shape + reshaped_lagged_sequence = lagged_sequence.reshape( + lags_shape[0], lags_shape[1], -1 + ) + reshaped_lagged_sequence_proj = self.target_proj(reshaped_lagged_sequence) + next_time_feat_proj = repeated_time_feat_proj[:, k : k + 1] future_selection, _ = self.future_selection( - [future_target_proj, future_time_feat_proj], static_selection + [reshaped_lagged_sequence_proj, next_time_feat_proj], static_selection ) - enc_out, states = self.temporal_encoder( - past_selection, future_selection, states + enc_out = self.temporal_encoder( + repeated_past_selection, future_selection, repeated_states ) - dec_output = self.temporal_decoder(enc_out, static_enrichment) + dec_output = self.temporal_decoder( + enc_out, repeated_static_enrichment, causal=False + ) + + params = self.param_proj(dec_output) + distr = self.output_distribution(params, scale=repeated_scale) + + next_sample = distr.sample() + repeated_past_target = torch.cat( + (repeated_past_target, next_sample / repeated_scale), dim=1 + ) + future_samples.append(next_sample) + + concat_future_samples = torch.cat(future_samples, dim=1) + return concat_future_samples.reshape( + (-1, self.num_parallel_samples, self.prediction_length) + self.target_shape, + )