This commit is contained in:
wassname
2020-03-14 08:49:58 +08:00
parent 1258613694
commit a43975bbce
+1 -1
View File
@@ -25,7 +25,7 @@ class LatentModelPL(pl.LightningModule):
def training_step(self, batch, batch_idx):
assert all(torch.isfinite(d).all() for d in batch)
context_x, context_y, target_x, target_y = batch
y_pred, losses, extra = = self.forward(context_x, context_y, target_x, target_y)
y_pred, losses, extra = self.forward(context_x, context_y, target_x, target_y)
y_std = extra['dist'].scale
tensorboard_logs = {