mirror of
https://github.com/wassname/pytorch-ts.git
synced 2026-06-27 18:06:19 +08:00
base block does not need the future target
This commit is contained in:
@@ -238,7 +238,7 @@ class NBEATSNetwork(nn.Module):
|
||||
)
|
||||
self.net_blocks.append(net_block)
|
||||
|
||||
def forward(self, past_target: torch.Tensor, future_target: torch.Tensor):
|
||||
def forward(self, past_target: torch.Tensor):
|
||||
if len(self.net_blocks) == 1:
|
||||
_, forecast = self.net_blocks[0](past_target)
|
||||
return forecast
|
||||
@@ -315,7 +315,7 @@ class NBEATSTrainingNetwork(NBEATSNetwork):
|
||||
def forward(
|
||||
self, past_target: torch.Tensor, future_target: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
forecast = super().forward(past_target=past_target, future_target=future_target)
|
||||
forecast = super().forward(past_target=past_target)
|
||||
|
||||
if self.loss_function == "sMAPE":
|
||||
loss = self.smape_loss(forecast, future_target)
|
||||
@@ -340,7 +340,7 @@ class NBEATSPredictionNetwork(NBEATSNetwork):
|
||||
def forward(
|
||||
self, past_target: torch.Tensor, future_target: torch.Tensor = None
|
||||
) -> torch.Tensor:
|
||||
forecasts = super().forward(past_target=past_target, future_target=past_target)
|
||||
forecasts = super().forward(past_target=past_target)
|
||||
|
||||
return forecasts.unsqueeze(1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user