diff --git a/src/models/lightning_anp.py b/src/models/lightning_anp.py index 059f78a..2f25c67 100644 --- a/src/models/lightning_anp.py +++ b/src/models/lightning_anp.py @@ -25,7 +25,7 @@ class LatentModelPL(pl.LightningModule): def training_step(self, batch, batch_idx): assert all(torch.isfinite(d).all() for d in batch) context_x, context_y, target_x, target_y = batch - y_pred, losses, extra = = self.forward(context_x, context_y, target_x, target_y) + y_pred, losses, extra = self.forward(context_x, context_y, target_x, target_y) y_std = extra['dist'].scale tensorboard_logs = {