diff --git a/ci/travis/install-dependencies.sh b/ci/travis/install-dependencies.sh index f3b4079fd..64cd380d8 100755 --- a/ci/travis/install-dependencies.sh +++ b/ci/travis/install-dependencies.sh @@ -304,6 +304,9 @@ install_dependencies() { # Additional RaySGD test dependencies. if [ "${SGD_TESTING-}" = 1 ]; then pip install -r "${WORKSPACE_DIR}"/python/requirements_tune.txt + # TODO: eventually have a separate requirements file for Ray SGD. + # Fix PTL version to 0.10 for now. + pip install -U pytorch-lightning==0.10.0 fi # Additional Doc test dependencies. diff --git a/python/ray/tune/integration/pytorch_lightning.py b/python/ray/tune/integration/pytorch_lightning.py index 720d20a2f..21b9a60ee 100644 --- a/python/ray/tune/integration/pytorch_lightning.py +++ b/python/ray/tune/integration/pytorch_lightning.py @@ -78,8 +78,8 @@ class TuneCallback(Callback): self._handle(trainer, pl_module) def on_validation_batch_end(self, trainer: Trainer, - pl_module: LightningModule, batch, batch_idx, - dataloader_idx): + pl_module: LightningModule, outputs, batch, + batch_idx, dataloader_idx): if "validation_batch_end" in self._on: self._handle(trainer, pl_module) @@ -89,7 +89,7 @@ class TuneCallback(Callback): self._handle(trainer, pl_module) def on_test_batch_end(self, trainer: Trainer, pl_module: LightningModule, - batch, batch_idx, dataloader_idx): + outputs, batch, batch_idx, dataloader_idx): if "test_batch_end" in self._on: self._handle(trainer, pl_module) diff --git a/python/ray/tune/tests/test_integration_pytorch_lightning.py b/python/ray/tune/tests/test_integration_pytorch_lightning.py index c138ef75c..14d88c30a 100644 --- a/python/ray/tune/tests/test_integration_pytorch_lightning.py +++ b/python/ray/tune/tests/test_integration_pytorch_lightning.py @@ -34,7 +34,7 @@ class _MockModule(pl.LightningModule): def forward(self, *args, **kwargs): return self.loss - def backward(self, trainer, loss, optimizer, optimizer_idx): + def backward(self, loss, optimizer, optimizer_idx): return None def training_step(self, train_batch, batch_idx):