diff --git a/doc/source/conf.py b/doc/source/conf.py index cdcb293f8..43bd34876 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -22,10 +22,19 @@ from custom_directives import CustomGalleryItemDirective # These lines added to enable Sphinx to work without installing Ray. import mock + + +class ChildClassMock(mock.MagicMock): + @classmethod + def __getattr__(cls, name): + return mock.Mock + + MOCK_MODULES = [ "blist", "gym", "gym.spaces", + "kubernetes", "psutil", "ray._raylet", "ray.core.generated", @@ -55,6 +64,7 @@ MOCK_MODULES = [ "torch.nn.parallel", "torch.utils.data", "torch.utils.data.distributed", + "wandb", "zoopt", ] import scipy.stats @@ -65,7 +75,7 @@ for mod_name in MOCK_MODULES: # ray.rllib.models.action_dist.py and # ray.rllib.models.lstm.py will use tf.VERSION sys.modules["tensorflow"].VERSION = "9.9.9" - +sys.modules["pytorch_lightning"] = ChildClassMock() # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. diff --git a/doc/source/tune/_tutorials/tune-pytorch-lightning.rst b/doc/source/tune/_tutorials/tune-pytorch-lightning.rst index 27984e504..c2c2b4178 100644 --- a/doc/source/tune/_tutorials/tune-pytorch-lightning.rst +++ b/doc/source/tune/_tutorials/tune-pytorch-lightning.rst @@ -95,14 +95,20 @@ that can be used to plug custom functions into the training loop. This way the o ``LightningModule`` does not have to be altered at all. Also, we could use the same callback for multiple modules. -The callback just reports some metrics back to Tune after each validation epoch: +Ray Tune comes with ready-to-use PyTorch Lightning callbacks. To report metrics +back to Tune after each validation epoch, we will use the ``TuneReportCallback``: -.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py - :language: python - :start-after: __tune_callback_begin__ - :end-before: __tune_callback_end__ +.. code-block:: python -Note that we have to explicitly convert the metrics from a tensor to a Python value. + from ray.tune.integration.pytorch_lightning import TuneReportCallback + callback = TuneReportCallback({ + "loss": "avg_val_loss", + "mean_accuracy": "avg_val_accuracy" + }, on="validation_end") + +This callback will take the ``avg_val_loss`` and ``avg_val_accuracy`` values +from the PyTorch Lightning trainer and report them to Tune as the ``loss`` +and ``mean_accuracy``, respectively. Adding the Tune training function ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -282,12 +288,20 @@ Adding checkpoints to the PyTorch Lightning module ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ First, we need to introduce -another callback to save model checkpoints: +another callback to save model checkpoints. Since Tune requires a call to +``tune.report()`` after creating a new checkpoint to register it, we will use +a combined reporting and checkpointing callback: -.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py - :language: python - :start-after: __tune_checkpoint_callback_begin__ - :end-before: __tune_checkpoint_callback_end__ +.. code-block:: python + + from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback + callback = TuneReportCheckpointCallback( + metrics={"loss": "val_loss", "mean_accuracy": "val_accuracy"}, + filename="checkpoint", + on="validation_end") + +The ``checkpoint`` value is the name of the checkpoint file within the +checkpoint directory. We also include checkpoint loading in our training function: @@ -338,5 +352,5 @@ In some runs, the parameters have been perturbed. And the best configuration eve mean validation accuracy of ``0.987062``! In summary, PyTorch Lightning Modules are easy to extend to use with Tune. It just took -us writing one or two callbacks and a small wrapper function to get great performing +us importing one or two callbacks and a small wrapper function to get great performing parameter configurations. diff --git a/doc/source/tune/_tutorials/tune-wandb.rst b/doc/source/tune/_tutorials/tune-wandb.rst index 315aaaaf5..8cb851f58 100644 --- a/doc/source/tune/_tutorials/tune-wandb.rst +++ b/doc/source/tune/_tutorials/tune-wandb.rst @@ -29,7 +29,9 @@ Please :doc:`see here for a full example `. .. _tune-wandb-logger: .. autoclass:: ray.tune.integration.wandb.WandbLogger + :noindex: .. _tune-wandb-mixin: .. autofunction:: ray.tune.integration.wandb.wandb_mixin + :noindex: diff --git a/doc/source/tune/api_docs/integration.rst b/doc/source/tune/api_docs/integration.rst new file mode 100644 index 000000000..7ccb62040 --- /dev/null +++ b/doc/source/tune/api_docs/integration.rst @@ -0,0 +1,46 @@ +.. _tune-integration: + +External library integrations (tune.integration) +================================================ + +.. contents:: + :local: + :depth: 1 + +.. _tune-integration-kubernetes: + +Kubernetes (tune.integration.kubernetes) +---------------------------------------- + +.. autofunction:: ray.tune.integration.kubernetes.NamespacedKubernetesSyncer + +.. _tune-integration-pytorch-lightning: + +PyTorch Lightning (tune.integration.pytorch_lightning) +------------------------------------------------------ + +.. autoclass:: ray.tune.integration.pytorch_lightning.TuneReportCallback + +.. autoclass:: ray.tune.integration.pytorch_lightning.TuneReportCheckpointCallback + +.. _tune-integration-torch: + +Torch (tune.integration.torch) +------------------------------ + +.. autofunction:: ray.tune.integration.torch.DistributedTrainableCreator + +.. autofunction:: ray.tune.integration.torch.distributed_checkpoint_dir + +.. autofunction:: ray.tune.integration.torch.is_distributed_trainable + +.. _tune-integration-wandb: + +Weights and Biases (tune.integration.wandb) +------------------------------------------- + +:ref:`See also here `. + +.. autoclass:: ray.tune.integration.wandb.WandbLogger + +.. autofunction:: ray.tune.integration.wandb.wandb_mixin \ No newline at end of file diff --git a/doc/source/tune/api_docs/overview.rst b/doc/source/tune/api_docs/overview.rst index 2a56dbdc4..cc68b3bbb 100644 --- a/doc/source/tune/api_docs/overview.rst +++ b/doc/source/tune/api_docs/overview.rst @@ -22,6 +22,7 @@ on `Github`_. schedulers.rst sklearn.rst logging.rst + integration.rst internals.rst client.rst cli.rst diff --git a/doc/source/tune/api_docs/trainable.rst b/doc/source/tune/api_docs/trainable.rst index 6d0935657..6d6308e40 100644 --- a/doc/source/tune/api_docs/trainable.rst +++ b/doc/source/tune/api_docs/trainable.rst @@ -294,10 +294,13 @@ Ray also offers lightweight integrations to distribute your model training on Ra .. autofunction:: ray.tune.integration.torch.DistributedTrainableCreator + :noindex: .. autofunction:: ray.tune.integration.torch.distributed_checkpoint_dir + :noindex: .. autofunction:: ray.tune.integration.torch.is_distributed_trainable + :noindex: tune.DurableTrainable --------------------- diff --git a/python/ray/tune/BUILD b/python/ray/tune/BUILD index 2c4617da8..47c483837 100644 --- a/python/ray/tune/BUILD +++ b/python/ray/tune/BUILD @@ -101,6 +101,14 @@ py_test( tags = ["exclusive"], ) +py_test( + name = "test_integration_pytorch_lightning", + size = "small", + srcs = ["tests/test_integration_pytorch_lightning.py"], + deps = [":tune_lib"], + tags = ["exclusive"], +) + py_test( name = "test_integration_wandb", size = "small", diff --git a/python/ray/tune/examples/mnist_pytorch_lightning.py b/python/ray/tune/examples/mnist_pytorch_lightning.py index 95ac2ee60..670ca1cf6 100644 --- a/python/ray/tune/examples/mnist_pytorch_lightning.py +++ b/python/ray/tune/examples/mnist_pytorch_lightning.py @@ -15,12 +15,13 @@ import os import shutil from functools import partial from tempfile import mkdtemp -from pytorch_lightning.callbacks import Callback from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.utilities.cloud_io import load as pl_load from ray import tune from ray.tune import CLIReporter from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining +from ray.tune.integration.pytorch_lightning import TuneReportCallback, \ + TuneReportCheckpointCallback # __import_tune_end__ @@ -93,8 +94,8 @@ class LightningMNISTClassifier(pl.LightningModule): logs = {"ptl/val_loss": avg_loss, "ptl/val_accuracy": avg_acc} return { - "avg_val_loss": avg_loss, - "avg_val_accuracy": avg_acc, + "val_loss": avg_loss, + "val_accuracy": avg_acc, "log": logs } @@ -131,15 +132,6 @@ def train_mnist(config): # __lightning_end__ -# __tune_callback_begin__ -class TuneReportCallback(Callback): - def on_validation_end(self, trainer, pl_module): - tune.report( - loss=trainer.callback_metrics["avg_val_loss"].item(), - mean_accuracy=trainer.callback_metrics["avg_val_accuracy"].item()) -# __tune_callback_end__ - - # __tune_train_begin__ def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0): model = LightningMNISTClassifier(config, data_dir) @@ -149,35 +141,40 @@ def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0): logger=TensorBoardLogger( save_dir=tune.get_trial_dir(), name="", version="."), progress_bar_refresh_rate=0, - callbacks=[TuneReportCallback()]) + callbacks=[ + TuneReportCallback( + { + "loss": "val_loss", + "mean_accuracy": "val_accuracy" + }, + on="validation_end") + ]) trainer.fit(model) # __tune_train_end__ -# __tune_checkpoint_callback_begin__ -class CheckpointCallback(Callback): - def on_validation_end(self, trainer, pl_module): - with tune.checkpoint_dir(step=trainer.global_step) as checkpoint_dir: - trainer.save_checkpoint(os.path.join(checkpoint_dir, "checkpoint")) -# __tune_checkpoint_callback_end__ - - # __tune_train_checkpoint_begin__ -def train_mnist_tune_checkpoint( - config, - checkpoint_dir=None, - data_dir=None, - num_epochs=10, - num_gpus=0): +def train_mnist_tune_checkpoint(config, + checkpoint_dir=None, + data_dir=None, + num_epochs=10, + num_gpus=0): trainer = pl.Trainer( max_epochs=num_epochs, gpus=num_gpus, logger=TensorBoardLogger( save_dir=tune.get_trial_dir(), name="", version="."), progress_bar_refresh_rate=0, - callbacks=[CheckpointCallback(), - TuneReportCallback()]) + callbacks=[ + TuneReportCheckpointCallback( + metrics={ + "loss": "val_loss", + "mean_accuracy": "val_accuracy" + }, + filename="checkpoint", + on="validation_end") + ]) if checkpoint_dir: # Currently, this leads to errors: # model = LightningMNISTClassifier.load_from_checkpoint( @@ -189,8 +186,7 @@ def train_mnist_tune_checkpoint( model = LightningMNISTClassifier._load_model_state(ckpt, config=config) trainer.current_epoch = ckpt["epoch"] else: - model = LightningMNISTClassifier( - config=config, data_dir=data_dir) + model = LightningMNISTClassifier(config=config, data_dir=data_dir) trainer.fit(model) # __tune_train_checkpoint_end__ @@ -225,7 +221,10 @@ def tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0): data_dir=data_dir, num_epochs=num_epochs, num_gpus=gpus_per_trial), - resources_per_trial={"cpu": 1, "gpu": gpus_per_trial}, + resources_per_trial={ + "cpu": 1, + "gpu": gpus_per_trial + }, config=config, num_samples=num_samples, scheduler=scheduler, @@ -268,7 +267,10 @@ def tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0): data_dir=data_dir, num_epochs=num_epochs, num_gpus=gpus_per_trial), - resources_per_trial={"cpu": 1, "gpu": gpus_per_trial}, + resources_per_trial={ + "cpu": 1, + "gpu": gpus_per_trial + }, config=config, num_samples=num_samples, scheduler=scheduler, diff --git a/python/ray/tune/integration/kubernetes.py b/python/ray/tune/integration/kubernetes.py index c83e20ac3..78f24ef35 100644 --- a/python/ray/tune/integration/kubernetes.py +++ b/python/ray/tune/integration/kubernetes.py @@ -17,8 +17,7 @@ def NamespacedKubernetesSyncer(namespace, use_rsync=False): installed in the Kubernetes pods for this to work. If False, ``tar`` will need to be installed instead. - Returns: A ``KubernetesSyncer`` class to be passed to - ``tune.run(sync_to_driver)``. + Returns: A ``KubernetesSyncer`` class to be passed to ``tune.run()``. Example: diff --git a/python/ray/tune/integration/pytorch_lightning.py b/python/ray/tune/integration/pytorch_lightning.py new file mode 100644 index 000000000..230929ff1 --- /dev/null +++ b/python/ray/tune/integration/pytorch_lightning.py @@ -0,0 +1,260 @@ +from typing import Dict, List, Optional, Union + +from pytorch_lightning import Callback, Trainer, LightningModule +from ray import tune + +import os + + +class TuneCallback(Callback): + """Base class for Tune's PyTorch Lightning callbacks.""" + _allowed = [ + "init_start", "init_end", "fit_start", "fit_end", "sanity_check_start", + "sanity_check_end", "epoch_start", "epoch_end", "batch_start", + "validation_batch_start", "validation_batch_end", "test_batch_start", + "test_batch_end", "batch_end", "train_start", "train_end", + "validation_start", "validation_end", "test_start", "test_end", + "keyboard_interrupt" + ] + + def __init__(self, on: Union[str, List[str]] = "validation_end"): + if not isinstance(on, list): + on = [on] + if any(w not in self._allowed for w in on): + raise ValueError( + "Invalid trigger time selected: {}. Must be one of {}".format( + on, self._allowed)) + self._on = on + + def _handle(self, trainer: Trainer, pl_module: Optional[LightningModule]): + raise NotImplementedError + + def on_init_start(self, trainer: Trainer): + if "init_start" in self._on: + self._handle(trainer, None) + + def on_init_end(self, trainer: Trainer): + if "init_end" in self._on: + self._handle(trainer, None) + + def on_fit_start(self, + trainer: Trainer, + pl_module: Optional[LightningModule] = None): + if "fit_start" in self._on: + self._handle(trainer, None) + + def on_fit_end(self, + trainer: Trainer, + pl_module: Optional[LightningModule] = None): + if "fit_end" in self._on: + self._handle(trainer, None) + + def on_sanity_check_start(self, trainer: Trainer, + pl_module: LightningModule): + if "sanity_check_start" in self._on: + self._handle(trainer, pl_module) + + def on_sanity_check_end(self, trainer: Trainer, + pl_module: LightningModule): + if "sanity_check_end" in self._on: + self._handle(trainer, pl_module) + + def on_epoch_start(self, trainer: Trainer, pl_module: LightningModule): + if "epoch_start" in self._on: + self._handle(trainer, pl_module) + + def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule): + if "epoch_end" in self._on: + self._handle(trainer, pl_module) + + def on_batch_start(self, trainer: Trainer, pl_module: LightningModule): + if "batch_start" in self._on: + self._handle(trainer, pl_module) + + def on_validation_batch_start(self, trainer: Trainer, + pl_module: LightningModule, batch, batch_idx, + dataloader_idx): + if "validation_batch_start" in self._on: + self._handle(trainer, pl_module) + + def on_validation_batch_end(self, trainer: Trainer, + pl_module: LightningModule, batch, batch_idx, + dataloader_idx): + if "validation_batch_end" in self._on: + self._handle(trainer, pl_module) + + def on_test_batch_start(self, trainer: Trainer, pl_module: LightningModule, + batch, batch_idx, dataloader_idx): + if "test_batch_start" in self._on: + self._handle(trainer, pl_module) + + def on_test_batch_end(self, trainer: Trainer, pl_module: LightningModule, + batch, batch_idx, dataloader_idx): + if "test_batch_end" in self._on: + self._handle(trainer, pl_module) + + def on_batch_end(self, trainer: Trainer, pl_module: LightningModule): + if "batch_end" in self._on: + self._handle(trainer, pl_module) + + def on_train_start(self, trainer: Trainer, pl_module: LightningModule): + if "train_start" in self._on: + self._handle(trainer, pl_module) + + def on_train_end(self, trainer: Trainer, pl_module: LightningModule): + if "train_end" in self._on: + self._handle(trainer, pl_module) + + def on_validation_start(self, trainer: Trainer, + pl_module: LightningModule): + if "validation_start" in self._on: + self._handle(trainer, pl_module) + + def on_validation_end(self, trainer: Trainer, pl_module: LightningModule): + if "validation_end" in self._on: + self._handle(trainer, pl_module) + + def on_test_start(self, trainer: Trainer, pl_module: LightningModule): + if "test_start" in self._on: + self._handle(trainer, pl_module) + + def on_test_end(self, trainer: Trainer, pl_module: LightningModule): + if "test_end" in self._on: + self._handle(trainer, pl_module) + + def on_keyboard_interrupt(self, trainer: Trainer, + pl_module: LightningModule): + if "keyboard_interrupt" in self._on: + self._handle(trainer, pl_module) + + +class TuneReportCallback(TuneCallback): + """PyTorch Lightning to Ray Tune reporting callback + + Reports metrics to Ray Tune. + + Args: + metrics (str|list|dict): Metrics to report to Tune. If this is a list, + each item describes the metric key reported to PyTorch Lightning, + and it will reported under the same name to Tune. If this is a + dict, each key will be the name reported to Tune and the respective + value will be the metric key reported to PyTorch Lightning. + on (str|list): When to trigger checkpoint creations. Must be one of + the PyTorch Lightning event hooks (less the ``on_``), e.g. + "batch_start", or "train_end". Defaults to "validation_end". + + Example: + + .. code-block:: python + + import pytorch_lightning as pl + from ray.tune.integration.pytorch_lightning import TuneReportCallback + + # Report loss and accuracy to Tune after each validation epoch: + trainer = pl.Trainer(callbacks=[TuneReportCallback( + ["val_loss", "val_acc"], on="validation_end")]) + + # Same as above, but report as `loss` and `mean_accuracy`: + trainer = pl.Trainer(callbacks=[TuneReportCallback( + {"loss": "val_loss", "mean_accuracy": "val_acc"}, + on="validation_end")]) + + """ + + def __init__(self, + metrics: Union[str, List[str], Dict[str, str]], + on: Union[str, List[str]] = "validation_end"): + super(TuneReportCallback, self).__init__(on) + if isinstance(metrics, str): + metrics = [metrics] + self._metrics = metrics + + def _handle(self, trainer: Trainer, pl_module: LightningModule): + report_dict = {} + for key in self._metrics: + if isinstance(self._metrics, dict): + metric = self._metrics[key] + else: + metric = key + report_dict[key] = trainer.callback_metrics[metric].item() + tune.report(**report_dict) + + +class _TuneCheckpointCallback(TuneCallback): + """PyTorch Lightning checkpoint callback + + Saves checkpoints after each validation step. + + Checkpoint are currently not registered if no ``tune.report()`` call + is made afterwards. Consider using ``TuneReportCheckpointCallback`` + instead. + + Args: + filename (str): Filename of the checkpoint within the checkpoint + directory. Defaults to "checkpoint". + on (str|list): When to trigger checkpoint creations. Must be one of + the PyTorch Lightning event hooks (less the ``on_``), e.g. + "batch_start", or "train_end". Defaults to "validation_end". + + + """ + + def __init__(self, + filename: str = "checkpoint", + on: Union[str, List[str]] = "validation_end"): + super(_TuneCheckpointCallback, self).__init__(on) + self._filename = filename + + def _handle(self, trainer: Trainer, pl_module: LightningModule): + with tune.checkpoint_dir(step=trainer.global_step) as checkpoint_dir: + trainer.save_checkpoint( + os.path.join(checkpoint_dir, self._filename)) + + +class TuneReportCheckpointCallback(TuneCallback): + """PyTorch Lightning report and checkpoint callback + + Saves checkpoints after each validation step. Also reports metrics to Tune, + which is needed for checkpoint registration. + + Args: + metrics (str|list|dict): Metrics to report to Tune. If this is a list, + each item describes the metric key reported to PyTorch Lightning, + and it will reported under the same name to Tune. If this is a + dict, each key will be the name reported to Tune and the respective + value will be the metric key reported to PyTorch Lightning. + filename (str): Filename of the checkpoint within the checkpoint + directory. Defaults to "checkpoint". + on (str|list): When to trigger checkpoint creations. Must be one of + the PyTorch Lightning event hooks (less the ``on_``), e.g. + "batch_start", or "train_end". Defaults to "validation_end". + + + Example: + + .. code-block:: python + + import pytorch_lightning as pl + from ray.tune.integration.pytorch_lightning import \ + TuneReportCheckpointCallback + + # Save checkpoint after each training batch and after each + # validation epoch. + trainer = pl.Trainer(callbacks=[TuneReportCheckpointCallback( + metrics={"loss": "val_loss", "mean_accuracy": "val_acc"}, + filename="trainer.ckpt", on="validation_end")]) + + + """ + + def __init__(self, + metrics: Union[str, List[str], Dict[str, str]], + filename: str = "checkpoint", + on: Union[str, List[str]] = "validation_end"): + super(TuneReportCheckpointCallback, self).__init__(on) + self._checkpoint = _TuneCheckpointCallback(filename, on) + self._report = TuneReportCallback(metrics, on) + + def _handle(self, trainer: Trainer, pl_module: LightningModule): + self._checkpoint._handle(trainer, pl_module) + self._report._handle(trainer, pl_module) diff --git a/python/ray/tune/tests/test_integration_pytorch_lightning.py b/python/ray/tune/tests/test_integration_pytorch_lightning.py new file mode 100644 index 000000000..ba773d5de --- /dev/null +++ b/python/ray/tune/tests/test_integration_pytorch_lightning.py @@ -0,0 +1,144 @@ +import os +import shutil +import tempfile +import unittest +import pytorch_lightning as pl +import torch +from ray.tune.result import TRAINING_ITERATION + +from torch.utils.data import DataLoader, Dataset + +from ray import tune +from ray.tune.integration.pytorch_lightning import TuneReportCallback, \ + TuneReportCheckpointCallback, _TuneCheckpointCallback + + +class _MockDataset(Dataset): + def __init__(self, values): + self.values = values + + def __getitem__(self, index): + return self.values[index] + + def __len__(self): + return len(self.values) + + +class _MockModule(pl.LightningModule): + def __init__(self, loss, acc): + super().__init__() + + self.loss = torch.tensor(loss) + self.acc = torch.tensor(acc) + + def forward(self, *args, **kwargs): + return self.loss + + def backward(self, trainer, loss, optimizer, optimizer_idx): + return None + + def training_step(self, train_batch, batch_idx): + return {"loss": self.loss, "acc": self.acc} + + def validation_step(self, val_batch, batch_idx): + return {"val_loss": self.loss * 1.1, "val_acc": self.acc * 0.9} + + def validation_epoch_end(self, outputs): + avg_val_loss = torch.stack([x["val_loss"] for x in outputs]).mean() + avg_val_acc = torch.stack([x["val_acc"] for x in outputs]).mean() + + return {"avg_val_loss": avg_val_loss, "avg_val_acc": avg_val_acc} + + def configure_optimizers(self): + return None + + def train_dataloader(self): + return DataLoader(_MockDataset(list(range(10))), batch_size=1) + + def val_dataloader(self): + return DataLoader(_MockDataset(list(range(10))), batch_size=1) + + +class PyTorchLightningIntegrationTest(unittest.TestCase): + def setUp(self): + pass + + def tearDown(self): + pass + + def testReportCallback(self): + def train(config): + module = _MockModule(10, 20) + trainer = pl.Trainer( + max_epochs=1, + callbacks=[ + TuneReportCallback( + { + "tune_loss": "avg_val_loss" + }, on="validation_end") + ]) + trainer.fit(module) + + analysis = tune.run(train, stop={TRAINING_ITERATION: 1}) + + self.assertEqual(analysis.trials[0].last_result["tune_loss"], 10 * 1.1) + + def testCheckpointCallback(self): + tmpdir = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(tmpdir)) + + def train(config): + module = _MockModule(10, 20) + trainer = pl.Trainer( + max_epochs=1, + callbacks=[ + _TuneCheckpointCallback( + "trainer.ckpt", on=["batch_end", "train_end"]) + ]) + trainer.fit(module) + + analysis = tune.run( + train, + stop={TRAINING_ITERATION: 10}, + keep_checkpoints_num=100, + local_dir=tmpdir) + + checkpoints = [ + dir for dir in os.listdir(analysis.trials[0].logdir) + if dir.startswith("checkpoint") + ] + # 10 checkpoints after each batch, 1 checkpoint at end + self.assertEqual(len(checkpoints), 11) + + def testReportCheckpointCallback(self): + tmpdir = tempfile.mkdtemp() + self.addCleanup(lambda: shutil.rmtree(tmpdir)) + + def train(config): + module = _MockModule(10, 20) + trainer = pl.Trainer( + max_epochs=1, + callbacks=[ + TuneReportCheckpointCallback( + ["avg_val_loss"], "trainer.ckpt", on="validation_end") + ]) + trainer.fit(module) + + analysis = tune.run( + train, + stop={TRAINING_ITERATION: 10}, + keep_checkpoints_num=100, + local_dir=tmpdir) + + checkpoints = [ + dir for dir in os.listdir(analysis.trials[0].logdir) + if dir.startswith("checkpoint") + ] + # 1 checkpoint after the validation step + self.assertEqual(len(checkpoints), 1) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__]))