[tune] Added PyTorch Lightning callbacks to integrations (#10220)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
krfricke
2020-08-31 23:30:48 +01:00
committed by GitHub
parent d8e7b144e4
commit f3f698816d
11 changed files with 537 additions and 48 deletions
+8
View File
@@ -101,6 +101,14 @@ py_test(
tags = ["exclusive"],
)
py_test(
name = "test_integration_pytorch_lightning",
size = "small",
srcs = ["tests/test_integration_pytorch_lightning.py"],
deps = [":tune_lib"],
tags = ["exclusive"],
)
py_test(
name = "test_integration_wandb",
size = "small",
@@ -15,12 +15,13 @@ import os
import shutil
from functools import partial
from tempfile import mkdtemp
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.cloud_io import load as pl_load
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
from ray.tune.integration.pytorch_lightning import TuneReportCallback, \
TuneReportCheckpointCallback
# __import_tune_end__
@@ -93,8 +94,8 @@ class LightningMNISTClassifier(pl.LightningModule):
logs = {"ptl/val_loss": avg_loss, "ptl/val_accuracy": avg_acc}
return {
"avg_val_loss": avg_loss,
"avg_val_accuracy": avg_acc,
"val_loss": avg_loss,
"val_accuracy": avg_acc,
"log": logs
}
@@ -131,15 +132,6 @@ def train_mnist(config):
# __lightning_end__
# __tune_callback_begin__
class TuneReportCallback(Callback):
def on_validation_end(self, trainer, pl_module):
tune.report(
loss=trainer.callback_metrics["avg_val_loss"].item(),
mean_accuracy=trainer.callback_metrics["avg_val_accuracy"].item())
# __tune_callback_end__
# __tune_train_begin__
def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
model = LightningMNISTClassifier(config, data_dir)
@@ -149,35 +141,40 @@ def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
logger=TensorBoardLogger(
save_dir=tune.get_trial_dir(), name="", version="."),
progress_bar_refresh_rate=0,
callbacks=[TuneReportCallback()])
callbacks=[
TuneReportCallback(
{
"loss": "val_loss",
"mean_accuracy": "val_accuracy"
},
on="validation_end")
])
trainer.fit(model)
# __tune_train_end__
# __tune_checkpoint_callback_begin__
class CheckpointCallback(Callback):
def on_validation_end(self, trainer, pl_module):
with tune.checkpoint_dir(step=trainer.global_step) as checkpoint_dir:
trainer.save_checkpoint(os.path.join(checkpoint_dir, "checkpoint"))
# __tune_checkpoint_callback_end__
# __tune_train_checkpoint_begin__
def train_mnist_tune_checkpoint(
config,
checkpoint_dir=None,
data_dir=None,
num_epochs=10,
num_gpus=0):
def train_mnist_tune_checkpoint(config,
checkpoint_dir=None,
data_dir=None,
num_epochs=10,
num_gpus=0):
trainer = pl.Trainer(
max_epochs=num_epochs,
gpus=num_gpus,
logger=TensorBoardLogger(
save_dir=tune.get_trial_dir(), name="", version="."),
progress_bar_refresh_rate=0,
callbacks=[CheckpointCallback(),
TuneReportCallback()])
callbacks=[
TuneReportCheckpointCallback(
metrics={
"loss": "val_loss",
"mean_accuracy": "val_accuracy"
},
filename="checkpoint",
on="validation_end")
])
if checkpoint_dir:
# Currently, this leads to errors:
# model = LightningMNISTClassifier.load_from_checkpoint(
@@ -189,8 +186,7 @@ def train_mnist_tune_checkpoint(
model = LightningMNISTClassifier._load_model_state(ckpt, config=config)
trainer.current_epoch = ckpt["epoch"]
else:
model = LightningMNISTClassifier(
config=config, data_dir=data_dir)
model = LightningMNISTClassifier(config=config, data_dir=data_dir)
trainer.fit(model)
# __tune_train_checkpoint_end__
@@ -225,7 +221,10 @@ def tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0):
data_dir=data_dir,
num_epochs=num_epochs,
num_gpus=gpus_per_trial),
resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
resources_per_trial={
"cpu": 1,
"gpu": gpus_per_trial
},
config=config,
num_samples=num_samples,
scheduler=scheduler,
@@ -268,7 +267,10 @@ def tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0):
data_dir=data_dir,
num_epochs=num_epochs,
num_gpus=gpus_per_trial),
resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
resources_per_trial={
"cpu": 1,
"gpu": gpus_per_trial
},
config=config,
num_samples=num_samples,
scheduler=scheduler,
+1 -2
View File
@@ -17,8 +17,7 @@ def NamespacedKubernetesSyncer(namespace, use_rsync=False):
installed in the Kubernetes pods for this to work.
If False, ``tar`` will need to be installed instead.
Returns: A ``KubernetesSyncer`` class to be passed to
``tune.run(sync_to_driver)``.
Returns: A ``KubernetesSyncer`` class to be passed to ``tune.run()``.
Example:
@@ -0,0 +1,260 @@
from typing import Dict, List, Optional, Union
from pytorch_lightning import Callback, Trainer, LightningModule
from ray import tune
import os
class TuneCallback(Callback):
"""Base class for Tune's PyTorch Lightning callbacks."""
_allowed = [
"init_start", "init_end", "fit_start", "fit_end", "sanity_check_start",
"sanity_check_end", "epoch_start", "epoch_end", "batch_start",
"validation_batch_start", "validation_batch_end", "test_batch_start",
"test_batch_end", "batch_end", "train_start", "train_end",
"validation_start", "validation_end", "test_start", "test_end",
"keyboard_interrupt"
]
def __init__(self, on: Union[str, List[str]] = "validation_end"):
if not isinstance(on, list):
on = [on]
if any(w not in self._allowed for w in on):
raise ValueError(
"Invalid trigger time selected: {}. Must be one of {}".format(
on, self._allowed))
self._on = on
def _handle(self, trainer: Trainer, pl_module: Optional[LightningModule]):
raise NotImplementedError
def on_init_start(self, trainer: Trainer):
if "init_start" in self._on:
self._handle(trainer, None)
def on_init_end(self, trainer: Trainer):
if "init_end" in self._on:
self._handle(trainer, None)
def on_fit_start(self,
trainer: Trainer,
pl_module: Optional[LightningModule] = None):
if "fit_start" in self._on:
self._handle(trainer, None)
def on_fit_end(self,
trainer: Trainer,
pl_module: Optional[LightningModule] = None):
if "fit_end" in self._on:
self._handle(trainer, None)
def on_sanity_check_start(self, trainer: Trainer,
pl_module: LightningModule):
if "sanity_check_start" in self._on:
self._handle(trainer, pl_module)
def on_sanity_check_end(self, trainer: Trainer,
pl_module: LightningModule):
if "sanity_check_end" in self._on:
self._handle(trainer, pl_module)
def on_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
if "epoch_start" in self._on:
self._handle(trainer, pl_module)
def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
if "epoch_end" in self._on:
self._handle(trainer, pl_module)
def on_batch_start(self, trainer: Trainer, pl_module: LightningModule):
if "batch_start" in self._on:
self._handle(trainer, pl_module)
def on_validation_batch_start(self, trainer: Trainer,
pl_module: LightningModule, batch, batch_idx,
dataloader_idx):
if "validation_batch_start" in self._on:
self._handle(trainer, pl_module)
def on_validation_batch_end(self, trainer: Trainer,
pl_module: LightningModule, batch, batch_idx,
dataloader_idx):
if "validation_batch_end" in self._on:
self._handle(trainer, pl_module)
def on_test_batch_start(self, trainer: Trainer, pl_module: LightningModule,
batch, batch_idx, dataloader_idx):
if "test_batch_start" in self._on:
self._handle(trainer, pl_module)
def on_test_batch_end(self, trainer: Trainer, pl_module: LightningModule,
batch, batch_idx, dataloader_idx):
if "test_batch_end" in self._on:
self._handle(trainer, pl_module)
def on_batch_end(self, trainer: Trainer, pl_module: LightningModule):
if "batch_end" in self._on:
self._handle(trainer, pl_module)
def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
if "train_start" in self._on:
self._handle(trainer, pl_module)
def on_train_end(self, trainer: Trainer, pl_module: LightningModule):
if "train_end" in self._on:
self._handle(trainer, pl_module)
def on_validation_start(self, trainer: Trainer,
pl_module: LightningModule):
if "validation_start" in self._on:
self._handle(trainer, pl_module)
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
if "validation_end" in self._on:
self._handle(trainer, pl_module)
def on_test_start(self, trainer: Trainer, pl_module: LightningModule):
if "test_start" in self._on:
self._handle(trainer, pl_module)
def on_test_end(self, trainer: Trainer, pl_module: LightningModule):
if "test_end" in self._on:
self._handle(trainer, pl_module)
def on_keyboard_interrupt(self, trainer: Trainer,
pl_module: LightningModule):
if "keyboard_interrupt" in self._on:
self._handle(trainer, pl_module)
class TuneReportCallback(TuneCallback):
"""PyTorch Lightning to Ray Tune reporting callback
Reports metrics to Ray Tune.
Args:
metrics (str|list|dict): Metrics to report to Tune. If this is a list,
each item describes the metric key reported to PyTorch Lightning,
and it will reported under the same name to Tune. If this is a
dict, each key will be the name reported to Tune and the respective
value will be the metric key reported to PyTorch Lightning.
on (str|list): When to trigger checkpoint creations. Must be one of
the PyTorch Lightning event hooks (less the ``on_``), e.g.
"batch_start", or "train_end". Defaults to "validation_end".
Example:
.. code-block:: python
import pytorch_lightning as pl
from ray.tune.integration.pytorch_lightning import TuneReportCallback
# Report loss and accuracy to Tune after each validation epoch:
trainer = pl.Trainer(callbacks=[TuneReportCallback(
["val_loss", "val_acc"], on="validation_end")])
# Same as above, but report as `loss` and `mean_accuracy`:
trainer = pl.Trainer(callbacks=[TuneReportCallback(
{"loss": "val_loss", "mean_accuracy": "val_acc"},
on="validation_end")])
"""
def __init__(self,
metrics: Union[str, List[str], Dict[str, str]],
on: Union[str, List[str]] = "validation_end"):
super(TuneReportCallback, self).__init__(on)
if isinstance(metrics, str):
metrics = [metrics]
self._metrics = metrics
def _handle(self, trainer: Trainer, pl_module: LightningModule):
report_dict = {}
for key in self._metrics:
if isinstance(self._metrics, dict):
metric = self._metrics[key]
else:
metric = key
report_dict[key] = trainer.callback_metrics[metric].item()
tune.report(**report_dict)
class _TuneCheckpointCallback(TuneCallback):
"""PyTorch Lightning checkpoint callback
Saves checkpoints after each validation step.
Checkpoint are currently not registered if no ``tune.report()`` call
is made afterwards. Consider using ``TuneReportCheckpointCallback``
instead.
Args:
filename (str): Filename of the checkpoint within the checkpoint
directory. Defaults to "checkpoint".
on (str|list): When to trigger checkpoint creations. Must be one of
the PyTorch Lightning event hooks (less the ``on_``), e.g.
"batch_start", or "train_end". Defaults to "validation_end".
"""
def __init__(self,
filename: str = "checkpoint",
on: Union[str, List[str]] = "validation_end"):
super(_TuneCheckpointCallback, self).__init__(on)
self._filename = filename
def _handle(self, trainer: Trainer, pl_module: LightningModule):
with tune.checkpoint_dir(step=trainer.global_step) as checkpoint_dir:
trainer.save_checkpoint(
os.path.join(checkpoint_dir, self._filename))
class TuneReportCheckpointCallback(TuneCallback):
"""PyTorch Lightning report and checkpoint callback
Saves checkpoints after each validation step. Also reports metrics to Tune,
which is needed for checkpoint registration.
Args:
metrics (str|list|dict): Metrics to report to Tune. If this is a list,
each item describes the metric key reported to PyTorch Lightning,
and it will reported under the same name to Tune. If this is a
dict, each key will be the name reported to Tune and the respective
value will be the metric key reported to PyTorch Lightning.
filename (str): Filename of the checkpoint within the checkpoint
directory. Defaults to "checkpoint".
on (str|list): When to trigger checkpoint creations. Must be one of
the PyTorch Lightning event hooks (less the ``on_``), e.g.
"batch_start", or "train_end". Defaults to "validation_end".
Example:
.. code-block:: python
import pytorch_lightning as pl
from ray.tune.integration.pytorch_lightning import \
TuneReportCheckpointCallback
# Save checkpoint after each training batch and after each
# validation epoch.
trainer = pl.Trainer(callbacks=[TuneReportCheckpointCallback(
metrics={"loss": "val_loss", "mean_accuracy": "val_acc"},
filename="trainer.ckpt", on="validation_end")])
"""
def __init__(self,
metrics: Union[str, List[str], Dict[str, str]],
filename: str = "checkpoint",
on: Union[str, List[str]] = "validation_end"):
super(TuneReportCheckpointCallback, self).__init__(on)
self._checkpoint = _TuneCheckpointCallback(filename, on)
self._report = TuneReportCallback(metrics, on)
def _handle(self, trainer: Trainer, pl_module: LightningModule):
self._checkpoint._handle(trainer, pl_module)
self._report._handle(trainer, pl_module)
@@ -0,0 +1,144 @@
import os
import shutil
import tempfile
import unittest
import pytorch_lightning as pl
import torch
from ray.tune.result import TRAINING_ITERATION
from torch.utils.data import DataLoader, Dataset
from ray import tune
from ray.tune.integration.pytorch_lightning import TuneReportCallback, \
TuneReportCheckpointCallback, _TuneCheckpointCallback
class _MockDataset(Dataset):
def __init__(self, values):
self.values = values
def __getitem__(self, index):
return self.values[index]
def __len__(self):
return len(self.values)
class _MockModule(pl.LightningModule):
def __init__(self, loss, acc):
super().__init__()
self.loss = torch.tensor(loss)
self.acc = torch.tensor(acc)
def forward(self, *args, **kwargs):
return self.loss
def backward(self, trainer, loss, optimizer, optimizer_idx):
return None
def training_step(self, train_batch, batch_idx):
return {"loss": self.loss, "acc": self.acc}
def validation_step(self, val_batch, batch_idx):
return {"val_loss": self.loss * 1.1, "val_acc": self.acc * 0.9}
def validation_epoch_end(self, outputs):
avg_val_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
avg_val_acc = torch.stack([x["val_acc"] for x in outputs]).mean()
return {"avg_val_loss": avg_val_loss, "avg_val_acc": avg_val_acc}
def configure_optimizers(self):
return None
def train_dataloader(self):
return DataLoader(_MockDataset(list(range(10))), batch_size=1)
def val_dataloader(self):
return DataLoader(_MockDataset(list(range(10))), batch_size=1)
class PyTorchLightningIntegrationTest(unittest.TestCase):
def setUp(self):
pass
def tearDown(self):
pass
def testReportCallback(self):
def train(config):
module = _MockModule(10, 20)
trainer = pl.Trainer(
max_epochs=1,
callbacks=[
TuneReportCallback(
{
"tune_loss": "avg_val_loss"
}, on="validation_end")
])
trainer.fit(module)
analysis = tune.run(train, stop={TRAINING_ITERATION: 1})
self.assertEqual(analysis.trials[0].last_result["tune_loss"], 10 * 1.1)
def testCheckpointCallback(self):
tmpdir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmpdir))
def train(config):
module = _MockModule(10, 20)
trainer = pl.Trainer(
max_epochs=1,
callbacks=[
_TuneCheckpointCallback(
"trainer.ckpt", on=["batch_end", "train_end"])
])
trainer.fit(module)
analysis = tune.run(
train,
stop={TRAINING_ITERATION: 10},
keep_checkpoints_num=100,
local_dir=tmpdir)
checkpoints = [
dir for dir in os.listdir(analysis.trials[0].logdir)
if dir.startswith("checkpoint")
]
# 10 checkpoints after each batch, 1 checkpoint at end
self.assertEqual(len(checkpoints), 11)
def testReportCheckpointCallback(self):
tmpdir = tempfile.mkdtemp()
self.addCleanup(lambda: shutil.rmtree(tmpdir))
def train(config):
module = _MockModule(10, 20)
trainer = pl.Trainer(
max_epochs=1,
callbacks=[
TuneReportCheckpointCallback(
["avg_val_loss"], "trainer.ckpt", on="validation_end")
])
trainer.fit(module)
analysis = tune.run(
train,
stop={TRAINING_ITERATION: 10},
keep_checkpoints_num=100,
local_dir=tmpdir)
checkpoints = [
dir for dir in os.listdir(analysis.trials[0].logdir)
if dir.startswith("checkpoint")
]
# 1 checkpoint after the validation step
self.assertEqual(len(checkpoints), 1)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))