mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 15:40:00 +08:00
[tune] Added PyTorch Lightning callbacks to integrations (#10220)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -101,6 +101,14 @@ py_test(
|
||||
tags = ["exclusive"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_integration_pytorch_lightning",
|
||||
size = "small",
|
||||
srcs = ["tests/test_integration_pytorch_lightning.py"],
|
||||
deps = [":tune_lib"],
|
||||
tags = ["exclusive"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_integration_wandb",
|
||||
size = "small",
|
||||
|
||||
@@ -15,12 +15,13 @@ import os
|
||||
import shutil
|
||||
from functools import partial
|
||||
from tempfile import mkdtemp
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||
from ray import tune
|
||||
from ray.tune import CLIReporter
|
||||
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
|
||||
from ray.tune.integration.pytorch_lightning import TuneReportCallback, \
|
||||
TuneReportCheckpointCallback
|
||||
# __import_tune_end__
|
||||
|
||||
|
||||
@@ -93,8 +94,8 @@ class LightningMNISTClassifier(pl.LightningModule):
|
||||
logs = {"ptl/val_loss": avg_loss, "ptl/val_accuracy": avg_acc}
|
||||
|
||||
return {
|
||||
"avg_val_loss": avg_loss,
|
||||
"avg_val_accuracy": avg_acc,
|
||||
"val_loss": avg_loss,
|
||||
"val_accuracy": avg_acc,
|
||||
"log": logs
|
||||
}
|
||||
|
||||
@@ -131,15 +132,6 @@ def train_mnist(config):
|
||||
# __lightning_end__
|
||||
|
||||
|
||||
# __tune_callback_begin__
|
||||
class TuneReportCallback(Callback):
|
||||
def on_validation_end(self, trainer, pl_module):
|
||||
tune.report(
|
||||
loss=trainer.callback_metrics["avg_val_loss"].item(),
|
||||
mean_accuracy=trainer.callback_metrics["avg_val_accuracy"].item())
|
||||
# __tune_callback_end__
|
||||
|
||||
|
||||
# __tune_train_begin__
|
||||
def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
|
||||
model = LightningMNISTClassifier(config, data_dir)
|
||||
@@ -149,35 +141,40 @@ def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
|
||||
logger=TensorBoardLogger(
|
||||
save_dir=tune.get_trial_dir(), name="", version="."),
|
||||
progress_bar_refresh_rate=0,
|
||||
callbacks=[TuneReportCallback()])
|
||||
callbacks=[
|
||||
TuneReportCallback(
|
||||
{
|
||||
"loss": "val_loss",
|
||||
"mean_accuracy": "val_accuracy"
|
||||
},
|
||||
on="validation_end")
|
||||
])
|
||||
|
||||
trainer.fit(model)
|
||||
# __tune_train_end__
|
||||
|
||||
|
||||
# __tune_checkpoint_callback_begin__
|
||||
class CheckpointCallback(Callback):
|
||||
def on_validation_end(self, trainer, pl_module):
|
||||
with tune.checkpoint_dir(step=trainer.global_step) as checkpoint_dir:
|
||||
trainer.save_checkpoint(os.path.join(checkpoint_dir, "checkpoint"))
|
||||
# __tune_checkpoint_callback_end__
|
||||
|
||||
|
||||
# __tune_train_checkpoint_begin__
|
||||
def train_mnist_tune_checkpoint(
|
||||
config,
|
||||
checkpoint_dir=None,
|
||||
data_dir=None,
|
||||
num_epochs=10,
|
||||
num_gpus=0):
|
||||
def train_mnist_tune_checkpoint(config,
|
||||
checkpoint_dir=None,
|
||||
data_dir=None,
|
||||
num_epochs=10,
|
||||
num_gpus=0):
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=num_epochs,
|
||||
gpus=num_gpus,
|
||||
logger=TensorBoardLogger(
|
||||
save_dir=tune.get_trial_dir(), name="", version="."),
|
||||
progress_bar_refresh_rate=0,
|
||||
callbacks=[CheckpointCallback(),
|
||||
TuneReportCallback()])
|
||||
callbacks=[
|
||||
TuneReportCheckpointCallback(
|
||||
metrics={
|
||||
"loss": "val_loss",
|
||||
"mean_accuracy": "val_accuracy"
|
||||
},
|
||||
filename="checkpoint",
|
||||
on="validation_end")
|
||||
])
|
||||
if checkpoint_dir:
|
||||
# Currently, this leads to errors:
|
||||
# model = LightningMNISTClassifier.load_from_checkpoint(
|
||||
@@ -189,8 +186,7 @@ def train_mnist_tune_checkpoint(
|
||||
model = LightningMNISTClassifier._load_model_state(ckpt, config=config)
|
||||
trainer.current_epoch = ckpt["epoch"]
|
||||
else:
|
||||
model = LightningMNISTClassifier(
|
||||
config=config, data_dir=data_dir)
|
||||
model = LightningMNISTClassifier(config=config, data_dir=data_dir)
|
||||
|
||||
trainer.fit(model)
|
||||
# __tune_train_checkpoint_end__
|
||||
@@ -225,7 +221,10 @@ def tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0):
|
||||
data_dir=data_dir,
|
||||
num_epochs=num_epochs,
|
||||
num_gpus=gpus_per_trial),
|
||||
resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
|
||||
resources_per_trial={
|
||||
"cpu": 1,
|
||||
"gpu": gpus_per_trial
|
||||
},
|
||||
config=config,
|
||||
num_samples=num_samples,
|
||||
scheduler=scheduler,
|
||||
@@ -268,7 +267,10 @@ def tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0):
|
||||
data_dir=data_dir,
|
||||
num_epochs=num_epochs,
|
||||
num_gpus=gpus_per_trial),
|
||||
resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
|
||||
resources_per_trial={
|
||||
"cpu": 1,
|
||||
"gpu": gpus_per_trial
|
||||
},
|
||||
config=config,
|
||||
num_samples=num_samples,
|
||||
scheduler=scheduler,
|
||||
|
||||
@@ -17,8 +17,7 @@ def NamespacedKubernetesSyncer(namespace, use_rsync=False):
|
||||
installed in the Kubernetes pods for this to work.
|
||||
If False, ``tar`` will need to be installed instead.
|
||||
|
||||
Returns: A ``KubernetesSyncer`` class to be passed to
|
||||
``tune.run(sync_to_driver)``.
|
||||
Returns: A ``KubernetesSyncer`` class to be passed to ``tune.run()``.
|
||||
|
||||
Example:
|
||||
|
||||
|
||||
@@ -0,0 +1,260 @@
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from pytorch_lightning import Callback, Trainer, LightningModule
|
||||
from ray import tune
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class TuneCallback(Callback):
|
||||
"""Base class for Tune's PyTorch Lightning callbacks."""
|
||||
_allowed = [
|
||||
"init_start", "init_end", "fit_start", "fit_end", "sanity_check_start",
|
||||
"sanity_check_end", "epoch_start", "epoch_end", "batch_start",
|
||||
"validation_batch_start", "validation_batch_end", "test_batch_start",
|
||||
"test_batch_end", "batch_end", "train_start", "train_end",
|
||||
"validation_start", "validation_end", "test_start", "test_end",
|
||||
"keyboard_interrupt"
|
||||
]
|
||||
|
||||
def __init__(self, on: Union[str, List[str]] = "validation_end"):
|
||||
if not isinstance(on, list):
|
||||
on = [on]
|
||||
if any(w not in self._allowed for w in on):
|
||||
raise ValueError(
|
||||
"Invalid trigger time selected: {}. Must be one of {}".format(
|
||||
on, self._allowed))
|
||||
self._on = on
|
||||
|
||||
def _handle(self, trainer: Trainer, pl_module: Optional[LightningModule]):
|
||||
raise NotImplementedError
|
||||
|
||||
def on_init_start(self, trainer: Trainer):
|
||||
if "init_start" in self._on:
|
||||
self._handle(trainer, None)
|
||||
|
||||
def on_init_end(self, trainer: Trainer):
|
||||
if "init_end" in self._on:
|
||||
self._handle(trainer, None)
|
||||
|
||||
def on_fit_start(self,
|
||||
trainer: Trainer,
|
||||
pl_module: Optional[LightningModule] = None):
|
||||
if "fit_start" in self._on:
|
||||
self._handle(trainer, None)
|
||||
|
||||
def on_fit_end(self,
|
||||
trainer: Trainer,
|
||||
pl_module: Optional[LightningModule] = None):
|
||||
if "fit_end" in self._on:
|
||||
self._handle(trainer, None)
|
||||
|
||||
def on_sanity_check_start(self, trainer: Trainer,
|
||||
pl_module: LightningModule):
|
||||
if "sanity_check_start" in self._on:
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
def on_sanity_check_end(self, trainer: Trainer,
|
||||
pl_module: LightningModule):
|
||||
if "sanity_check_end" in self._on:
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
def on_epoch_start(self, trainer: Trainer, pl_module: LightningModule):
|
||||
if "epoch_start" in self._on:
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
|
||||
if "epoch_end" in self._on:
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
def on_batch_start(self, trainer: Trainer, pl_module: LightningModule):
|
||||
if "batch_start" in self._on:
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
def on_validation_batch_start(self, trainer: Trainer,
|
||||
pl_module: LightningModule, batch, batch_idx,
|
||||
dataloader_idx):
|
||||
if "validation_batch_start" in self._on:
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
def on_validation_batch_end(self, trainer: Trainer,
|
||||
pl_module: LightningModule, batch, batch_idx,
|
||||
dataloader_idx):
|
||||
if "validation_batch_end" in self._on:
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
def on_test_batch_start(self, trainer: Trainer, pl_module: LightningModule,
|
||||
batch, batch_idx, dataloader_idx):
|
||||
if "test_batch_start" in self._on:
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
def on_test_batch_end(self, trainer: Trainer, pl_module: LightningModule,
|
||||
batch, batch_idx, dataloader_idx):
|
||||
if "test_batch_end" in self._on:
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
def on_batch_end(self, trainer: Trainer, pl_module: LightningModule):
|
||||
if "batch_end" in self._on:
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
|
||||
if "train_start" in self._on:
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
def on_train_end(self, trainer: Trainer, pl_module: LightningModule):
|
||||
if "train_end" in self._on:
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
def on_validation_start(self, trainer: Trainer,
|
||||
pl_module: LightningModule):
|
||||
if "validation_start" in self._on:
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
|
||||
if "validation_end" in self._on:
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
def on_test_start(self, trainer: Trainer, pl_module: LightningModule):
|
||||
if "test_start" in self._on:
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
def on_test_end(self, trainer: Trainer, pl_module: LightningModule):
|
||||
if "test_end" in self._on:
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
def on_keyboard_interrupt(self, trainer: Trainer,
|
||||
pl_module: LightningModule):
|
||||
if "keyboard_interrupt" in self._on:
|
||||
self._handle(trainer, pl_module)
|
||||
|
||||
|
||||
class TuneReportCallback(TuneCallback):
|
||||
"""PyTorch Lightning to Ray Tune reporting callback
|
||||
|
||||
Reports metrics to Ray Tune.
|
||||
|
||||
Args:
|
||||
metrics (str|list|dict): Metrics to report to Tune. If this is a list,
|
||||
each item describes the metric key reported to PyTorch Lightning,
|
||||
and it will reported under the same name to Tune. If this is a
|
||||
dict, each key will be the name reported to Tune and the respective
|
||||
value will be the metric key reported to PyTorch Lightning.
|
||||
on (str|list): When to trigger checkpoint creations. Must be one of
|
||||
the PyTorch Lightning event hooks (less the ``on_``), e.g.
|
||||
"batch_start", or "train_end". Defaults to "validation_end".
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from ray.tune.integration.pytorch_lightning import TuneReportCallback
|
||||
|
||||
# Report loss and accuracy to Tune after each validation epoch:
|
||||
trainer = pl.Trainer(callbacks=[TuneReportCallback(
|
||||
["val_loss", "val_acc"], on="validation_end")])
|
||||
|
||||
# Same as above, but report as `loss` and `mean_accuracy`:
|
||||
trainer = pl.Trainer(callbacks=[TuneReportCallback(
|
||||
{"loss": "val_loss", "mean_accuracy": "val_acc"},
|
||||
on="validation_end")])
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
metrics: Union[str, List[str], Dict[str, str]],
|
||||
on: Union[str, List[str]] = "validation_end"):
|
||||
super(TuneReportCallback, self).__init__(on)
|
||||
if isinstance(metrics, str):
|
||||
metrics = [metrics]
|
||||
self._metrics = metrics
|
||||
|
||||
def _handle(self, trainer: Trainer, pl_module: LightningModule):
|
||||
report_dict = {}
|
||||
for key in self._metrics:
|
||||
if isinstance(self._metrics, dict):
|
||||
metric = self._metrics[key]
|
||||
else:
|
||||
metric = key
|
||||
report_dict[key] = trainer.callback_metrics[metric].item()
|
||||
tune.report(**report_dict)
|
||||
|
||||
|
||||
class _TuneCheckpointCallback(TuneCallback):
|
||||
"""PyTorch Lightning checkpoint callback
|
||||
|
||||
Saves checkpoints after each validation step.
|
||||
|
||||
Checkpoint are currently not registered if no ``tune.report()`` call
|
||||
is made afterwards. Consider using ``TuneReportCheckpointCallback``
|
||||
instead.
|
||||
|
||||
Args:
|
||||
filename (str): Filename of the checkpoint within the checkpoint
|
||||
directory. Defaults to "checkpoint".
|
||||
on (str|list): When to trigger checkpoint creations. Must be one of
|
||||
the PyTorch Lightning event hooks (less the ``on_``), e.g.
|
||||
"batch_start", or "train_end". Defaults to "validation_end".
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
filename: str = "checkpoint",
|
||||
on: Union[str, List[str]] = "validation_end"):
|
||||
super(_TuneCheckpointCallback, self).__init__(on)
|
||||
self._filename = filename
|
||||
|
||||
def _handle(self, trainer: Trainer, pl_module: LightningModule):
|
||||
with tune.checkpoint_dir(step=trainer.global_step) as checkpoint_dir:
|
||||
trainer.save_checkpoint(
|
||||
os.path.join(checkpoint_dir, self._filename))
|
||||
|
||||
|
||||
class TuneReportCheckpointCallback(TuneCallback):
|
||||
"""PyTorch Lightning report and checkpoint callback
|
||||
|
||||
Saves checkpoints after each validation step. Also reports metrics to Tune,
|
||||
which is needed for checkpoint registration.
|
||||
|
||||
Args:
|
||||
metrics (str|list|dict): Metrics to report to Tune. If this is a list,
|
||||
each item describes the metric key reported to PyTorch Lightning,
|
||||
and it will reported under the same name to Tune. If this is a
|
||||
dict, each key will be the name reported to Tune and the respective
|
||||
value will be the metric key reported to PyTorch Lightning.
|
||||
filename (str): Filename of the checkpoint within the checkpoint
|
||||
directory. Defaults to "checkpoint".
|
||||
on (str|list): When to trigger checkpoint creations. Must be one of
|
||||
the PyTorch Lightning event hooks (less the ``on_``), e.g.
|
||||
"batch_start", or "train_end". Defaults to "validation_end".
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from ray.tune.integration.pytorch_lightning import \
|
||||
TuneReportCheckpointCallback
|
||||
|
||||
# Save checkpoint after each training batch and after each
|
||||
# validation epoch.
|
||||
trainer = pl.Trainer(callbacks=[TuneReportCheckpointCallback(
|
||||
metrics={"loss": "val_loss", "mean_accuracy": "val_acc"},
|
||||
filename="trainer.ckpt", on="validation_end")])
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
metrics: Union[str, List[str], Dict[str, str]],
|
||||
filename: str = "checkpoint",
|
||||
on: Union[str, List[str]] = "validation_end"):
|
||||
super(TuneReportCheckpointCallback, self).__init__(on)
|
||||
self._checkpoint = _TuneCheckpointCallback(filename, on)
|
||||
self._report = TuneReportCallback(metrics, on)
|
||||
|
||||
def _handle(self, trainer: Trainer, pl_module: LightningModule):
|
||||
self._checkpoint._handle(trainer, pl_module)
|
||||
self._report._handle(trainer, pl_module)
|
||||
@@ -0,0 +1,144 @@
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from ray.tune.result import TRAINING_ITERATION
|
||||
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from ray import tune
|
||||
from ray.tune.integration.pytorch_lightning import TuneReportCallback, \
|
||||
TuneReportCheckpointCallback, _TuneCheckpointCallback
|
||||
|
||||
|
||||
class _MockDataset(Dataset):
|
||||
def __init__(self, values):
|
||||
self.values = values
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.values[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.values)
|
||||
|
||||
|
||||
class _MockModule(pl.LightningModule):
|
||||
def __init__(self, loss, acc):
|
||||
super().__init__()
|
||||
|
||||
self.loss = torch.tensor(loss)
|
||||
self.acc = torch.tensor(acc)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.loss
|
||||
|
||||
def backward(self, trainer, loss, optimizer, optimizer_idx):
|
||||
return None
|
||||
|
||||
def training_step(self, train_batch, batch_idx):
|
||||
return {"loss": self.loss, "acc": self.acc}
|
||||
|
||||
def validation_step(self, val_batch, batch_idx):
|
||||
return {"val_loss": self.loss * 1.1, "val_acc": self.acc * 0.9}
|
||||
|
||||
def validation_epoch_end(self, outputs):
|
||||
avg_val_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
|
||||
avg_val_acc = torch.stack([x["val_acc"] for x in outputs]).mean()
|
||||
|
||||
return {"avg_val_loss": avg_val_loss, "avg_val_acc": avg_val_acc}
|
||||
|
||||
def configure_optimizers(self):
|
||||
return None
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(_MockDataset(list(range(10))), batch_size=1)
|
||||
|
||||
def val_dataloader(self):
|
||||
return DataLoader(_MockDataset(list(range(10))), batch_size=1)
|
||||
|
||||
|
||||
class PyTorchLightningIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def testReportCallback(self):
|
||||
def train(config):
|
||||
module = _MockModule(10, 20)
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=1,
|
||||
callbacks=[
|
||||
TuneReportCallback(
|
||||
{
|
||||
"tune_loss": "avg_val_loss"
|
||||
}, on="validation_end")
|
||||
])
|
||||
trainer.fit(module)
|
||||
|
||||
analysis = tune.run(train, stop={TRAINING_ITERATION: 1})
|
||||
|
||||
self.assertEqual(analysis.trials[0].last_result["tune_loss"], 10 * 1.1)
|
||||
|
||||
def testCheckpointCallback(self):
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
self.addCleanup(lambda: shutil.rmtree(tmpdir))
|
||||
|
||||
def train(config):
|
||||
module = _MockModule(10, 20)
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=1,
|
||||
callbacks=[
|
||||
_TuneCheckpointCallback(
|
||||
"trainer.ckpt", on=["batch_end", "train_end"])
|
||||
])
|
||||
trainer.fit(module)
|
||||
|
||||
analysis = tune.run(
|
||||
train,
|
||||
stop={TRAINING_ITERATION: 10},
|
||||
keep_checkpoints_num=100,
|
||||
local_dir=tmpdir)
|
||||
|
||||
checkpoints = [
|
||||
dir for dir in os.listdir(analysis.trials[0].logdir)
|
||||
if dir.startswith("checkpoint")
|
||||
]
|
||||
# 10 checkpoints after each batch, 1 checkpoint at end
|
||||
self.assertEqual(len(checkpoints), 11)
|
||||
|
||||
def testReportCheckpointCallback(self):
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
self.addCleanup(lambda: shutil.rmtree(tmpdir))
|
||||
|
||||
def train(config):
|
||||
module = _MockModule(10, 20)
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=1,
|
||||
callbacks=[
|
||||
TuneReportCheckpointCallback(
|
||||
["avg_val_loss"], "trainer.ckpt", on="validation_end")
|
||||
])
|
||||
trainer.fit(module)
|
||||
|
||||
analysis = tune.run(
|
||||
train,
|
||||
stop={TRAINING_ITERATION: 10},
|
||||
keep_checkpoints_num=100,
|
||||
local_dir=tmpdir)
|
||||
|
||||
checkpoints = [
|
||||
dir for dir in os.listdir(analysis.trials[0].logdir)
|
||||
if dir.startswith("checkpoint")
|
||||
]
|
||||
# 1 checkpoint after the validation step
|
||||
self.assertEqual(len(checkpoints), 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
Reference in New Issue
Block a user