[Tune] Changes for Pytorch Lightning 1.0 (#11375)

This commit is contained in:
Amog Kamsetty
2020-10-13 15:50:11 -07:00
committed by GitHub
parent a6a94d3206
commit 933cf6675c
3 changed files with 7 additions and 4 deletions
@@ -78,8 +78,8 @@ class TuneCallback(Callback):
self._handle(trainer, pl_module)
def on_validation_batch_end(self, trainer: Trainer,
pl_module: LightningModule, batch, batch_idx,
dataloader_idx):
pl_module: LightningModule, outputs, batch,
batch_idx, dataloader_idx):
if "validation_batch_end" in self._on:
self._handle(trainer, pl_module)
@@ -89,7 +89,7 @@ class TuneCallback(Callback):
self._handle(trainer, pl_module)
def on_test_batch_end(self, trainer: Trainer, pl_module: LightningModule,
batch, batch_idx, dataloader_idx):
outputs, batch, batch_idx, dataloader_idx):
if "test_batch_end" in self._on:
self._handle(trainer, pl_module)
@@ -34,7 +34,7 @@ class _MockModule(pl.LightningModule):
def forward(self, *args, **kwargs):
return self.loss
def backward(self, trainer, loss, optimizer, optimizer_idx):
def backward(self, loss, optimizer, optimizer_idx):
return None
def training_step(self, train_batch, batch_idx):