From 933cf6675c4ee24b7154651e93ad9d289ff1301f Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Tue, 13 Oct 2020 15:50:11 -0700 Subject: [PATCH] [Tune] Changes for Pytorch Lightning 1.0 (#11375) --- ci/travis/install-dependencies.sh | 3 +++ python/ray/tune/integration/pytorch_lightning.py | 6 +++--- python/ray/tune/tests/test_integration_pytorch_lightning.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/ci/travis/install-dependencies.sh b/ci/travis/install-dependencies.sh index f3b4079fd..64cd380d8 100755 --- a/ci/travis/install-dependencies.sh +++ b/ci/travis/install-dependencies.sh @@ -304,6 +304,9 @@ install_dependencies() { # Additional RaySGD test dependencies. if [ "${SGD_TESTING-}" = 1 ]; then pip install -r "${WORKSPACE_DIR}"/python/requirements_tune.txt + # TODO: eventually have a separate requirements file for Ray SGD. + # Fix PTL version to 0.10 for now. + pip install -U pytorch-lightning==0.10.0 fi # Additional Doc test dependencies. diff --git a/python/ray/tune/integration/pytorch_lightning.py b/python/ray/tune/integration/pytorch_lightning.py index 720d20a2f..21b9a60ee 100644 --- a/python/ray/tune/integration/pytorch_lightning.py +++ b/python/ray/tune/integration/pytorch_lightning.py @@ -78,8 +78,8 @@ class TuneCallback(Callback): self._handle(trainer, pl_module) def on_validation_batch_end(self, trainer: Trainer, - pl_module: LightningModule, batch, batch_idx, - dataloader_idx): + pl_module: LightningModule, outputs, batch, + batch_idx, dataloader_idx): if "validation_batch_end" in self._on: self._handle(trainer, pl_module) @@ -89,7 +89,7 @@ class TuneCallback(Callback): self._handle(trainer, pl_module) def on_test_batch_end(self, trainer: Trainer, pl_module: LightningModule, - batch, batch_idx, dataloader_idx): + outputs, batch, batch_idx, dataloader_idx): if "test_batch_end" in self._on: self._handle(trainer, pl_module) diff --git a/python/ray/tune/tests/test_integration_pytorch_lightning.py b/python/ray/tune/tests/test_integration_pytorch_lightning.py index c138ef75c..14d88c30a 100644 --- a/python/ray/tune/tests/test_integration_pytorch_lightning.py +++ b/python/ray/tune/tests/test_integration_pytorch_lightning.py @@ -34,7 +34,7 @@ class _MockModule(pl.LightningModule): def forward(self, *args, **kwargs): return self.loss - def backward(self, trainer, loss, optimizer, optimizer_idx): + def backward(self, loss, optimizer, optimizer_idx): return None def training_step(self, train_batch, batch_idx):