base block does not need the future target

This commit is contained in:
Dr. Kashif Rasul
2020-03-11 14:30:19 +01:00
parent 830a5df2aa
commit b6f333b5fd
+3 -3
View File
@@ -238,7 +238,7 @@ class NBEATSNetwork(nn.Module):
)
self.net_blocks.append(net_block)
def forward(self, past_target: torch.Tensor, future_target: torch.Tensor):
def forward(self, past_target: torch.Tensor):
if len(self.net_blocks) == 1:
_, forecast = self.net_blocks[0](past_target)
return forecast
@@ -315,7 +315,7 @@ class NBEATSTrainingNetwork(NBEATSNetwork):
def forward(
self, past_target: torch.Tensor, future_target: torch.Tensor
) -> torch.Tensor:
forecast = super().forward(past_target=past_target, future_target=future_target)
forecast = super().forward(past_target=past_target)
if self.loss_function == "sMAPE":
loss = self.smape_loss(forecast, future_target)
@@ -340,7 +340,7 @@ class NBEATSPredictionNetwork(NBEATSNetwork):
def forward(
self, past_target: torch.Tensor, future_target: torch.Tensor = None
) -> torch.Tensor:
forecasts = super().forward(past_target=past_target, future_target=past_target)
forecasts = super().forward(past_target=past_target)
return forecasts.unsqueeze(1)