mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:53:14 +08:00
[Tune] Changes for Pytorch Lightning 1.0 (#11375)
This commit is contained in:
@@ -78,8 +78,8 @@ class TuneCallback(Callback):
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
def on_validation_batch_end(self, trainer: Trainer,
|
||||
pl_module: LightningModule, batch, batch_idx,
|
||||
dataloader_idx):
|
||||
pl_module: LightningModule, outputs, batch,
|
||||
batch_idx, dataloader_idx):
|
||||
if "validation_batch_end" in self._on:
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
@@ -89,7 +89,7 @@ class TuneCallback(Callback):
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
def on_test_batch_end(self, trainer: Trainer, pl_module: LightningModule,
|
||||
batch, batch_idx, dataloader_idx):
|
||||
outputs, batch, batch_idx, dataloader_idx):
|
||||
if "test_batch_end" in self._on:
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ class _MockModule(pl.LightningModule):
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.loss
|
||||
|
||||
def backward(self, trainer, loss, optimizer, optimizer_idx):
|
||||
def backward(self, loss, optimizer, optimizer_idx):
|
||||
return None
|
||||
|
||||
def training_step(self, train_batch, batch_idx):
|
||||
|
||||
Reference in New Issue
Block a user