diff --git a/pts/model/n_beats/n_beats_network.py b/pts/model/n_beats/n_beats_network.py index 78b7cab..42f928f 100644 --- a/pts/model/n_beats/n_beats_network.py +++ b/pts/model/n_beats/n_beats_network.py @@ -238,7 +238,7 @@ class NBEATSNetwork(nn.Module): ) self.net_blocks.append(net_block) - def forward(self, past_target: torch.Tensor, future_target: torch.Tensor): + def forward(self, past_target: torch.Tensor): if len(self.net_blocks) == 1: _, forecast = self.net_blocks[0](past_target) return forecast @@ -315,7 +315,7 @@ class NBEATSTrainingNetwork(NBEATSNetwork): def forward( self, past_target: torch.Tensor, future_target: torch.Tensor ) -> torch.Tensor: - forecast = super().forward(past_target=past_target, future_target=future_target) + forecast = super().forward(past_target=past_target) if self.loss_function == "sMAPE": loss = self.smape_loss(forecast, future_target) @@ -340,7 +340,7 @@ class NBEATSPredictionNetwork(NBEATSNetwork): def forward( self, past_target: torch.Tensor, future_target: torch.Tensor = None ) -> torch.Tensor: - forecasts = super().forward(past_target=past_target, future_target=past_target) + forecasts = super().forward(past_target=past_target) return forecasts.unsqueeze(1)