diff --git a/python/ray/tune/examples/mnist_pytorch_lightning.py b/python/ray/tune/examples/mnist_pytorch_lightning.py index 670ca1cf6..86f30ed29 100644 --- a/python/ray/tune/examples/mnist_pytorch_lightning.py +++ b/python/ray/tune/examples/mnist_pytorch_lightning.py @@ -77,27 +77,23 @@ class LightningMNISTClassifier(pl.LightningModule): loss = self.cross_entropy_loss(logits, y) accuracy = self.accuracy(logits, y) - logs = {"ptl/train_loss": loss, "ptl/train_accuracy": accuracy} - return {"loss": loss, "log": logs} + self.log("ptl/train_loss", loss) + self.log("ptl/train_accuracy", accuracy) + return loss def validation_step(self, val_batch, batch_idx): x, y = val_batch logits = self.forward(x) loss = self.cross_entropy_loss(logits, y) accuracy = self.accuracy(logits, y) - return {"val_loss": loss, "val_accuracy": accuracy} def validation_epoch_end(self, outputs): avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean() avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean() - logs = {"ptl/val_loss": avg_loss, "ptl/val_accuracy": avg_acc} + self.log("ptl/val_loss", avg_loss) + self.log("ptl/val_accuracy", avg_acc) - return { - "val_loss": avg_loss, - "val_accuracy": avg_acc, - "log": logs - } @staticmethod def download_data(data_dir): @@ -144,12 +140,11 @@ def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0): callbacks=[ TuneReportCallback( { - "loss": "val_loss", - "mean_accuracy": "val_accuracy" + "loss": "ptl/val_loss", + "mean_accuracy": "ptl/val_accuracy" }, on="validation_end") ]) - trainer.fit(model) # __tune_train_end__ @@ -169,8 +164,8 @@ def train_mnist_tune_checkpoint(config, callbacks=[ TuneReportCheckpointCallback( metrics={ - "loss": "val_loss", - "mean_accuracy": "val_accuracy" + "loss": "ptl/val_loss", + "mean_accuracy": "ptl/val_accuracy" }, filename="checkpoint", on="validation_end") diff --git a/python/ray/tune/integration/pytorch_lightning.py b/python/ray/tune/integration/pytorch_lightning.py index d9d233a04..720d20a2f 100644 --- a/python/ray/tune/integration/pytorch_lightning.py +++ b/python/ray/tune/integration/pytorch_lightning.py @@ -170,6 +170,9 @@ class TuneReportCallback(TuneCallback): self._metrics = metrics def _handle(self, trainer: Trainer, pl_module: LightningModule): + # Don't report if just doing initial validation sanity checks. + if trainer.running_sanity_check: + return report_dict = {} for key in self._metrics: if isinstance(self._metrics, dict): @@ -206,6 +209,8 @@ class _TuneCheckpointCallback(TuneCallback): self._filename = filename def _handle(self, trainer: Trainer, pl_module: LightningModule): + if trainer.running_sanity_check: + return with tune.checkpoint_dir(step=trainer.global_step) as checkpoint_dir: trainer.save_checkpoint( os.path.join(checkpoint_dir, self._filename)) diff --git a/python/ray/tune/tests/test_integration_pytorch_lightning.py b/python/ray/tune/tests/test_integration_pytorch_lightning.py index ba773d5de..c138ef75c 100644 --- a/python/ray/tune/tests/test_integration_pytorch_lightning.py +++ b/python/ray/tune/tests/test_integration_pytorch_lightning.py @@ -46,8 +46,8 @@ class _MockModule(pl.LightningModule): def validation_epoch_end(self, outputs): avg_val_loss = torch.stack([x["val_loss"] for x in outputs]).mean() avg_val_acc = torch.stack([x["val_acc"] for x in outputs]).mean() - - return {"avg_val_loss": avg_val_loss, "avg_val_acc": avg_val_acc} + self.log("avg_val_loss", avg_val_loss) + self.log("avg_val_acc", avg_val_acc) def configure_optimizers(self): return None diff --git a/python/ray/util/sgd/torch/ptl_operator.py b/python/ray/util/sgd/torch/ptl_operator.py index 8ce3829ff..a484de7f7 100644 --- a/python/ray/util/sgd/torch/ptl_operator.py +++ b/python/ray/util/sgd/torch/ptl_operator.py @@ -5,6 +5,7 @@ import torch from pytorch_lightning.core.step_result import Result from pytorch_lightning.overrides.data_parallel import \ LightningDistributedDataParallel +from pytorch_lightning.utilities.model_utils import is_overridden from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin import pytorch_lightning as ptl @@ -39,8 +40,8 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin, assert len(models) == 1 model = models[0] assert isinstance(model, ptl.LightningModule) - # This will default to LightningDistributedDataParallel. - model = model.configure_ddp(model=model, device_ids=device_ids) + model = LightningDistributedDataParallel( + model, device_ids=device_ids, find_unused_parameters=True) return [model] @property @@ -110,7 +111,7 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin, # Call model.setup. ptl_module.setup("fit") - if not self.is_overridden("configure_optimizers", ptl_module): + if not is_overridden("configure_optimizers", ptl_module): raise MisconfigurationException( "No `configure_optimizers()` method defined.") @@ -232,7 +233,7 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin, break processed_outputs = None - if self.is_overridden("training_epoch_end", model): + if is_overridden("training_epoch_end", model): raw_outputs = [eo["raw_output"] for eo in epoch_outputs] processed_outputs = model.training_epoch_end(raw_outputs) @@ -316,7 +317,7 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin, # allow any mode to define training_step_end # do something will all the dp outputs (like softmax) - if self.is_overridden("training_step_end", model): + if is_overridden("training_step_end", model): output = model.training_step_end(output) # Extract loss from output if dictionary. @@ -397,7 +398,7 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin, val_outputs.append(batch_output) processed_outputs = None - if self.is_overridden("validation_epoch_end", model): + if is_overridden("validation_epoch_end", model): raw_outputs = [vo["raw_output"] for vo in val_outputs] processed_outputs = model.training_epoch_end(raw_outputs) @@ -440,7 +441,7 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin, def validate_batch(self, batch, batch_info): model = self.get_model() batch_idx = batch_info["batch_idx"] - if self.is_overridden("on_validation_batch_start", model): + if is_overridden("on_validation_batch_start", model): model.on_validation_batch_start( batch=batch, batch_idx=batch_idx, dataloader_idx=0) args = [batch, batch_idx] @@ -462,7 +463,7 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin, raise ValueError("EvalResult objects are not supported. Please " "return a dictionary instead.") - if self.is_overridden("on_validation_step_end", model): + if is_overridden("on_validation_step_end", model): output = model.validation_step_end(output) if self.is_function_implemented("on_validation_batch_end", model):