From b6f333b5fda2fc551156007024113d50b90f50d2 Mon Sep 17 00:00:00 2001 From: "Dr. Kashif Rasul" Date: Wed, 11 Mar 2020 14:30:19 +0100 Subject: [PATCH] base block does not need the future target --- pts/model/n_beats/n_beats_network.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pts/model/n_beats/n_beats_network.py b/pts/model/n_beats/n_beats_network.py index 78b7cab..42f928f 100644 --- a/pts/model/n_beats/n_beats_network.py +++ b/pts/model/n_beats/n_beats_network.py @@ -238,7 +238,7 @@ class NBEATSNetwork(nn.Module): ) self.net_blocks.append(net_block) - def forward(self, past_target: torch.Tensor, future_target: torch.Tensor): + def forward(self, past_target: torch.Tensor): if len(self.net_blocks) == 1: _, forecast = self.net_blocks[0](past_target) return forecast @@ -315,7 +315,7 @@ class NBEATSTrainingNetwork(NBEATSNetwork): def forward( self, past_target: torch.Tensor, future_target: torch.Tensor ) -> torch.Tensor: - forecast = super().forward(past_target=past_target, future_target=future_target) + forecast = super().forward(past_target=past_target) if self.loss_function == "sMAPE": loss = self.smape_loss(forecast, future_target) @@ -340,7 +340,7 @@ class NBEATSPredictionNetwork(NBEATSNetwork): def forward( self, past_target: torch.Tensor, future_target: torch.Tensor = None ) -> torch.Tensor: - forecasts = super().forward(past_target=past_target, future_target=past_target) + forecasts = super().forward(past_target=past_target) return forecasts.unsqueeze(1)