prediction

This commit is contained in:
Kashif Rasul
2022-03-30 21:38:08 +02:00
parent 7655f63bd7
commit 8efc4a7b63
+17 -6
View File
@@ -158,7 +158,7 @@ class TemporalFusionEncoder(nn.Module):
skip = self.skip_proj(skip)
encodings = self.gate(encodings)
encodings = self.lnorm(skip + encodings)
return encodings, states
return encodings
class TemporalFusionDecoder(nn.Module):
@@ -509,7 +509,7 @@ class TFTModel(nn.Module):
c_c = self.state_c(static_var)
states = [c_h.unsqueeze(0), c_c.unsqueeze(0)]
enc_out, _ = self.temporal_encoder(past_selection, future_selection, states)
enc_out = self.temporal_encoder(past_selection, future_selection, states)
dec_output = self.temporal_decoder(enc_out, static_enrichment)
@@ -569,16 +569,27 @@ class TFTModel(nn.Module):
c_c = self.state_c(static_var)
states = [c_h.unsqueeze(0), c_c.unsqueeze(0)]
enc_out, states = self.temporal_encoder(
enc_out = self.temporal_encoder(
past_selection, tgt_input=None, states=states
)
dec_output = self.temporal_decoder(enc_out, static_enrichment)
params = self.param_proj(dec_output)
for k in range(self.prediction_length):
import pdb
distr = self.output_distribution(params, scale=scale, trailing_n=1)
pdb.set_trace()
next_sample = distr.sample()
future_samples = [next_sample]
import pdb
pdb.set_trace()
for k in range(1, self.prediction_length):
# TODO
future_target_proj = pass
future_time_feat_proj = pass
future_selection, _ = self.future_selection(
[future_target_proj, future_time_feat_proj], static_selection
)
enc_out, states = self.temporal_encoder(
past_selection, future_selection, states
)