mirror of
https://github.com/wassname/pytorch-transformer-ts.git
synced 2026-06-27 18:06:14 +08:00
prediction
This commit is contained in:
+17
-6
@@ -158,7 +158,7 @@ class TemporalFusionEncoder(nn.Module):
|
||||
skip = self.skip_proj(skip)
|
||||
encodings = self.gate(encodings)
|
||||
encodings = self.lnorm(skip + encodings)
|
||||
return encodings, states
|
||||
return encodings
|
||||
|
||||
|
||||
class TemporalFusionDecoder(nn.Module):
|
||||
@@ -509,7 +509,7 @@ class TFTModel(nn.Module):
|
||||
c_c = self.state_c(static_var)
|
||||
states = [c_h.unsqueeze(0), c_c.unsqueeze(0)]
|
||||
|
||||
enc_out, _ = self.temporal_encoder(past_selection, future_selection, states)
|
||||
enc_out = self.temporal_encoder(past_selection, future_selection, states)
|
||||
|
||||
dec_output = self.temporal_decoder(enc_out, static_enrichment)
|
||||
|
||||
@@ -569,16 +569,27 @@ class TFTModel(nn.Module):
|
||||
c_c = self.state_c(static_var)
|
||||
states = [c_h.unsqueeze(0), c_c.unsqueeze(0)]
|
||||
|
||||
enc_out, states = self.temporal_encoder(
|
||||
enc_out = self.temporal_encoder(
|
||||
past_selection, tgt_input=None, states=states
|
||||
)
|
||||
dec_output = self.temporal_decoder(enc_out, static_enrichment)
|
||||
params = self.param_proj(dec_output)
|
||||
|
||||
for k in range(self.prediction_length):
|
||||
import pdb
|
||||
distr = self.output_distribution(params, scale=scale, trailing_n=1)
|
||||
|
||||
pdb.set_trace()
|
||||
next_sample = distr.sample()
|
||||
future_samples = [next_sample]
|
||||
import pdb
|
||||
pdb.set_trace()
|
||||
|
||||
for k in range(1, self.prediction_length):
|
||||
# TODO
|
||||
future_target_proj = pass
|
||||
future_time_feat_proj = pass
|
||||
|
||||
future_selection, _ = self.future_selection(
|
||||
[future_target_proj, future_time_feat_proj], static_selection
|
||||
)
|
||||
enc_out, states = self.temporal_encoder(
|
||||
past_selection, future_selection, states
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user