[tune/sgd] Document func_trainable and add checkpoint context (#9739)

Co-authored-by: krfricke <krfricke@users.noreply.github.com>
Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
This commit is contained in:
Richard Liaw
2020-07-30 09:46:37 -07:00
committed by GitHub
parent e540e425e4
commit 0c3b9ebeef
23 changed files with 619 additions and 452 deletions
+2
View File
@@ -18,6 +18,8 @@ For end to end examples leveraging RaySGD TorchTrainer, jump to :ref:`raysgd-tor
Setting up training
-------------------
.. tip:: If you want to leverage multi-node data parallel training with PyTorch while using RayTune *without* restructuring your code, check out the :ref:`Tune PyTorch user guide <tune-pytorch-cifar>` and Tune's :ref:`distributed pytorch integrations <tune-ddp-doc>`.
The ``TorchTrainer`` can be constructed with functions that wrap components of the training script. Specifically, it requires constructors for the Model, Data, Optimizer, Loss, and ``lr_scheduler`` to create replicated copies across different devices and machines.
Under the hood, ``TorchTrainer`` will create *replicas* of your model (controlled by ``num_workers``), each of which is managed by a Ray actor. One of the replicas will be on the main process, which can simplify the debugging and logging experience.
@@ -213,16 +213,7 @@ In GCP, you can use the following configuration modification:
scheduling:
- preemptible: true
Spot instances may be removed suddenly while trials are still running. Often times this may be difficult to deal with when using other distributed hyperparameter optimization frameworks. Tune allows users to mitigate the effects of this by preserving the progress of your model training through checkpointing.
The easiest way to do this is to subclass the pre-defined ``Trainable`` class and implement ``save_checkpoint``, and ``load_checkpoint`` abstract methods, as seen in the example below:
.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_trainable.py
:language: python
:start-after: __trainable_example_begin__
:end-before: __trainable_example_end__
This can then be used similarly to the Function API as before:
Spot instances may be removed suddenly while trials are still running. Often times this may be difficult to deal with when using other distributed hyperparameter optimization frameworks. Tune allows users to mitigate the effects of this by preserving the progress of your model training through :ref:`checkpointing <tune-function-checkpointing>`.
.. literalinclude:: /../../python/ray/tune/tests/tutorial.py
:language: python
@@ -25,6 +25,8 @@ need to
3. add checkpointing (optional),
4. and define the search space for the model tuning
Optionally, you can seamlessly leverage :ref:`DistributedDataParallel training <tune-torch-ddp>` for each individual Pytorch model within Tune.
.. note::
To run this example, you will need to install the following:
@@ -74,15 +76,16 @@ The train function
Now it gets interesting, because we introduce some changes to the example `from the PyTorch
documentation <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_.
We wrap the training script in a function ``train_cifar(config, checkpoint=None)``. As you
We wrap the training script in a function ``train_cifar(config, checkpoint_dir=None)``. As you
can guess, the ``config`` parameter will receive the hyperparameters we would like to
train with. The ``checkpoint`` parameter is used to restore checkpoints.
train with. The ``checkpoint_dir`` parameter is used to restore checkpoints.
.. code-block:: python
net = Net(config["l1"], config["l2"])
if checkpoint:
if checkpoint_dir:
checkpoint = os.path.join(checkpoint_dir, "checkpoint")
net.load_state_dict(torch.load(checkpoint))
The learning rate of the optimizer is made configurable, too:
@@ -97,6 +100,7 @@ with which we iterate through the training and test sets are configurable as wel
Adding (multi) GPU support with DataParallel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Image classification benefits largely from GPUs. Luckily, we can continue to use
PyTorch's abstractions in Ray Tune. Thus, we can wrap our model in ``nn.DataParallel``
to support data parallel training on multiple GPUs:
@@ -132,10 +136,9 @@ The most interesting part is the communication with Tune:
.. code-block:: python
checkpoint_dir = tune.make_checkpoint_dir(epoch)
path = os.path.join(checkpoint_dir, "checkpoint")
torch.save((net.state_dict(), optimizer.state_dict()), path)
tune.save_checkpoint(path)
with tune.checkpoint_dir(epoch) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
torch.save((net.state_dict(), optimizer.state_dict()), path)
tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
@@ -273,6 +276,75 @@ be confirmed on the test set.
So that's it! You can now tune the parameters of your PyTorch models.
.. _tune-torch-ddp:
Advanced: Distributed training with DistributedDataParallel
-----------------------------------------------------------
Some models require multiple nodes to train in a short amount of time. Ray Tune allows you to easily do distributed data parallel training in addition to distributed hyperparameter tuning.
You can wrap your model in ``torch.nn.parallel.DistributedDataParallel`` to support distributed data parallel training:
.. code-block:: python
from ray.util.sgd.torch import is_distributed_trainable
from torch.nn.parallel import DistributedDataParallel
def train_cifar(config, checkpoint_dir=None, data_dir=None):
net = Net(config["l1"], config["l2"])
device = "cpu"
#### Using distributed data parallel training
if is_distributed_trainable():
net = DistributedDataParallel(net)
if torch.cuda.is_available():
device = "cuda"
net.to(device)
If using checkpointing, be sure to use a :ref:`special checkpoint context manager <tune-ddp-doc>`, ``distributed_checkpoint_dir`` that avoids redundant checkpointing across multiple processes:
.. code-block:: python
from ray.util.sgd.torch import distributed_checkpoint_dir
#### Using distributed data parallel training
# Inside `def train_cifar(...)`,
# replace tune.checkpoint_dir() with the following
# Avoids redundant checkpointing on different processes.
with distributed_checkpoint_dir(step=epoch) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
torch.save((net.state_dict(), optimizer.state_dict()), path)
Finally, we need to tell Ray Tune to start multiple distributed processes at once by using ``ray.tune.integration.torch.DistributedTrainableCreator`` (:ref:`docs <tune-ddp-doc>`). This is essentially equivalent to running ``torch.distributed.launch`` for each hyperparameter trial:
.. code-block:: python
# You'll probably want to be running on a distributed Ray cluster.
# ray.init(address="auto")
from ray.util.sgd.integration.torch import DistributedTrainableCreator
distributed_train_cifar = DistributedTrainableCreator(
partial(train_cifar, data_dir=data_dir),
use_gpu=True,
num_workers=2, # number of parallel workers to use
num_cpus_per_worker=8
)
tune.run(
distributed_train_cifar,
resources_per_trial=None,
config=config,
num_samples=num_samples,
...
)
See an :doc:`end-to-end example here </tune/examples/ddp_mnist_torch>`.
If you consider switching to PyTorch Lightning to get rid of some of your boilerplate
training code, please know that we also have a walkthrough on :doc:`how to use Tune with
PyTorch Lightning models <tune-pytorch-lightning>`.
PyTorch Lightning models <tune-pytorch-lightning>`.
+30 -13
View File
@@ -17,6 +17,7 @@ For the sake of example, let's maximize this objective function:
Function API
------------
Here is a simple example of using the function API. You can report intermediate metrics by simply calling ``tune.report`` within the provided function.
.. code-block:: python
@@ -40,38 +41,41 @@ Here is a simple example of using the function API. You can report intermediate
Tune will run this function on a separate thread in a Ray actor process.
.. tip:: If you want to leverage multi-node data parallel training with PyTorch while using parallel hyperparameter tuning, check out our :ref:PyTorch user guide and Tune's :ref:distributed pytorch integrations.
.. _tune-function-checkpointing:
Function API Checkpointing
~~~~~~~~~~~~~~~~~~~~~~~~~~
Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance. To use Tune's checkpointing features, you must expose a ``checkpoint`` argument in the function signature, and call ``tune.make_checkpoint_dir`` and ``tune.save_checkpoint``:
Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance. To use Tune's checkpointing features, you must expose a ``checkpoint_dir`` argument in the function signature, and call ``tune.checkpoint_dir`` :
.. code-block:: python
import time
from ray import tune
def train_func(config, checkpoint=None):
def train_func(config, checkpoint_dir=None):
start = 0
if checkpoint:
with open(checkpoint) as f:
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
state = json.loads(f.read())
start = state["step"] + 1
for iter in range(start, 100):
time.sleep(1)
#
checkpoint_dir = tune.make_checkpoint_dir(step=step)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": start}))
tune.save_checkpoint(path)
with tune.checkpoint_dir(step=step):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": start}))
tune.report(hello="world", ray="tune")
tune.run(train_func)
.. note:: ``checkpoint_freq`` and ``checkpoint_at_end`` will not work with Function API checkpointing.
In this example, checkpoints will be saved by training iteration to ``local_dir/exp_name/trial_name/checkpoint_<step>``. You can restore a single trial checkpoint by using ``tune.run(restore=<checkpoint_dir>)``:
.. code-block:: python
@@ -263,9 +267,7 @@ tune.report / tune.checkpoint (Function API)
.. autofunction:: ray.tune.report
.. autofunction:: ray.tune.make_checkpoint_dir
.. autofunction:: ray.tune.save_checkpoint
.. autofunction:: ray.tune.checkpoint_dir
.. autofunction:: ray.tune.get_trial_dir
@@ -282,6 +284,21 @@ tune.Trainable (Class API)
:private-members:
:members:
.. _tune-ddp-doc:
Distributed Torch
-----------------
Ray also offers lightweight integrations to distribute your model training on Ray Tune.
.. autofunction:: ray.tune.integration.torch.DistributedTrainableCreator
.. autofunction:: ray.tune.integration.torch.distributed_checkpoint_dir
.. autofunction:: ray.tune.integration.torch.is_distributed_trainable
tune.DurableTrainable
---------------------
@@ -0,0 +1,6 @@
:orphan:
ddp_mnist_torch
~~~~~~~~~~~~~~~
.. literalinclude:: /../../python/ray/tune/examples/ddp_mnist_torch.py
+1
View File
@@ -41,6 +41,7 @@ PyTorch Examples
- :doc:`/tune/examples/mnist_pytorch`: Converts the PyTorch MNIST example to use Tune with the function-based API. Also shows how to easily convert something relying on argparse to use Tune.
- :doc:`/tune/examples/mnist_pytorch_trainable`: Converts the PyTorch MNIST example to use Tune with Trainable API. Also uses the HyperBandScheduler and checkpoints the model at the end.
- :doc:`/tune/examples/ddp_mnist_torch`: An example showing how to use DistributedDataParallel with Ray Tune. This enables both distributed training and distributed hyperparameter tuning.
XGBoost Example
+9 -9
View File
@@ -151,17 +151,18 @@ When running a hyperparameter search, Tune can automatically and periodically sa
Checkpointing assumes that the model state will be saved to disk on whichever node the Trainable is running on.
To use Tune's checkpointing features, you must expose a ``checkpoint`` argument in the function signature, and call ``tune.make_checkpoint_dir`` and ``tune.save_checkpoint``:
To use Tune's checkpointing features, you must expose a ``checkpoint_dir`` argument in the function signature, and call ``tune.checkpoint_dir``:
.. code-block:: python
import os
import time
from ray import tune
def train_func(config, checkpoint=None):
def train_func(config, checkpoint_dir=None):
start = 0
if checkpoint:
with open(checkpoint) as f:
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
state = json.loads(f.read())
start = state["step"] + 1
@@ -169,11 +170,10 @@ To use Tune's checkpointing features, you must expose a ``checkpoint`` argument
time.sleep(1)
# Obtain a checkpoint directory
checkpoint_dir = tune.make_checkpoint_dir(step=step)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": start}))
tune.save_checkpoint(path)
with tune.checkpoint_dir(step=step) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": start}))
tune.report(hello="world", ray="tune")
+8
View File
@@ -287,6 +287,14 @@ py_test(
args = ["--smoke-test"]
)
py_test(
name = "test_torch_trainable",
size = "medium",
srcs = ["tests/test_torch_trainable.py"],
tags = ["exclusive", "example", "pytorch"],
deps = [":tune_lib"],
)
py_test(
name = "ddp_mnist_torch",
size = "small",
+2 -2
View File
@@ -9,7 +9,7 @@ from ray.tune.durable_trainable import DurableTrainable
from ray.tune.suggest import grid_search
from ray.tune.session import (report, get_trial_dir, get_trial_name,
get_trial_id, make_checkpoint_dir,
save_checkpoint)
save_checkpoint, checkpoint_dir)
from ray.tune.progress_reporter import (ProgressReporter, CLIReporter,
JupyterNotebookReporter)
from ray.tune.sample import (function, sample_from, uniform, choice, randint,
@@ -22,5 +22,5 @@ __all__ = [
"uniform", "choice", "randint", "randn", "loguniform",
"ExperimentAnalysis", "Analysis", "CLIReporter", "JupyterNotebookReporter",
"ProgressReporter", "report", "get_trial_dir", "get_trial_name",
"get_trial_id", "make_checkpoint_dir", "save_checkpoint"
"get_trial_id", "make_checkpoint_dir", "save_checkpoint", "checkpoint_dir"
]
+10 -8
View File
@@ -58,7 +58,7 @@ class Net(nn.Module):
# __train_begin__
def train_cifar(config, checkpoint=None, data_dir=None):
def train_cifar(config, checkpoint_dir=None, data_dir=None):
net = Net(config["l1"], config["l2"])
device = "cpu"
@@ -71,8 +71,8 @@ def train_cifar(config, checkpoint=None, data_dir=None):
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
if checkpoint:
print("loading checkpoint {}".format(checkpoint))
if checkpoint_dir:
checkpoint = os.path.join(checkpoint_dir, "checkpoint")
model_state, optimizer_state = torch.load(checkpoint)
net.load_state_dict(model_state)
optimizer.load_state_dict(optimizer_state)
@@ -138,10 +138,10 @@ def train_cifar(config, checkpoint=None, data_dir=None):
val_loss += loss.cpu().numpy()
val_steps += 1
checkpoint_dir = tune.make_checkpoint_dir(epoch)
path = os.path.join(checkpoint_dir, "checkpoint")
torch.save((net.state_dict(), optimizer.state_dict()), path)
tune.save_checkpoint(path)
with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
torch.save(
(net.state_dict(), optimizer.state_dict()), path)
tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
print("Finished Training")
@@ -213,7 +213,9 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
best_trained_model = nn.DataParallel(best_trained_model)
best_trained_model.to(device)
model_state, optimizer_state = torch.load(best_trial.checkpoint.value)
checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint")
model_state, optimizer_state = torch.load(checkpoint_path)
best_trained_model.load_state_dict(model_state)
test_acc = test_accuracy(best_trained_model, device)
+8 -6
View File
@@ -2,6 +2,7 @@
# https://github.com/pytorch/examples/blob/master/mnist/main.py
import argparse
import logging
import os
import torch
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel
@@ -10,21 +11,21 @@ import ray
from ray import tune
from ray.tune.examples.mnist_pytorch import (train, test, get_data_loaders,
ConvNet)
from ray.util.sgd.torch.func_trainable import (DistributedTrainableCreator,
distributed_checkpoint)
from ray.tune.integration.torch import (DistributedTrainableCreator,
distributed_checkpoint_dir)
logger = logging.getLogger(__name__)
def train_mnist(config, checkpoint=False):
def train_mnist(config, checkpoint_dir=False):
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
train_loader, test_loader = get_data_loaders()
model = ConvNet().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.1)
if checkpoint:
with open(checkpoint) as f:
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
model_state, optimizer_state = torch.load(f)
model.load_state_dict(model_state)
@@ -37,7 +38,8 @@ def train_mnist(config, checkpoint=False):
acc = test(model, test_loader, device)
if epoch % 3 == 0:
with distributed_checkpoint(label=epoch) as path:
with distributed_checkpoint_dir(step=epoch) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
torch.save((model.state_dict(), optimizer.state_dict()), path)
tune.report(mean_accuracy=acc)
@@ -11,10 +11,10 @@ from ray import tune
from ray.tune.schedulers import HyperBandScheduler
def train(config, checkpoint=None):
def train(config, checkpoint_dir=None):
step = 0
if checkpoint:
with open(checkpoint) as f:
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
step = json.loads(f.read())["timestep"]
for timestep in range(step, 100):
@@ -22,11 +22,10 @@ def train(config, checkpoint=None):
v *= config.get("height", 1)
if timestep % 3 == 0:
checkpoint_dir = tune.make_checkpoint_dir(step=timestep)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"timestep": timestep}))
tune.save_checkpoint(path)
with tune.checkpoint_dir(step=timestep) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"timestep": timestep}))
# Here we use `episode_reward_mean`, but you can also report other
# objectives such as loss or accuracy.
@@ -158,16 +158,15 @@ def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
# __tune_checkpoint_callback_begin__
class CheckpointCallback(Callback):
def on_validation_end(self, trainer, pl_module):
path = tune.make_checkpoint_dir(trainer.global_step)
trainer.save_checkpoint(os.path.join(path, "checkpoint"))
tune.save_checkpoint(path)
with tune.checkpoint_dir(step=trainer.global_step) as checkpoint_dir:
trainer.save_checkpoint(os.path.join(checkpoint_dir, "checkpoint"))
# __tune_checkpoint_callback_end__
# __tune_train_checkpoint_begin__
def train_mnist_tune_checkpoint(
config,
checkpoint=None,
checkpoint_dir=None,
data_dir=None,
num_epochs=10,
num_gpus=0):
@@ -179,13 +178,13 @@ def train_mnist_tune_checkpoint(
progress_bar_refresh_rate=0,
callbacks=[CheckpointCallback(),
TuneReportCallback()])
if checkpoint:
if checkpoint_dir:
# Currently, this leads to errors:
# model = LightningMNISTClassifier.load_from_checkpoint(
# os.path.join(checkpoint, "checkpoint"))
# Workaround:
ckpt = pl_load(
os.path.join(checkpoint, "checkpoint"),
os.path.join(checkpoint_dir, "checkpoint"),
map_location=lambda storage, loc: storage)
model = LightningMNISTClassifier._load_model_state(ckpt, config=config)
trainer.current_epoch = ckpt["epoch"]
+7 -8
View File
@@ -11,7 +11,7 @@ from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
def pbt_function(config, checkpoint=None):
def pbt_function(config, checkpoint_dir=None):
"""Toy PBT problem for benchmarking adaptive learning rate.
The goal is to optimize this trainable's accuracy. The accuracy increases
@@ -35,8 +35,8 @@ def pbt_function(config, checkpoint=None):
lr = config["lr"]
accuracy = 0.0 # end = 1000
start = 0
if checkpoint:
with open(checkpoint) as f:
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
state = json.loads(f.read())
accuracy = state["acc"]
start = state["step"]
@@ -65,11 +65,10 @@ def pbt_function(config, checkpoint=None):
accuracy = max(0, accuracy)
if step % 3 == 0:
checkpoint_dir = tune.make_checkpoint_dir(step=step)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"acc": accuracy, "step": start}))
tune.save_checkpoint(path)
with tune.checkpoint_dir(step=step) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"acc": accuracy, "step": start}))
tune.report(
mean_accuracy=accuracy,
+22 -10
View File
@@ -350,35 +350,47 @@ class FunctionRunner(Trainable):
pass
def detect_checkpoint_function(train_func):
func_args = inspect.getfullargspec(train_func).args
use_checkpoint = "checkpoint" in func_args
return use_checkpoint
def detect_checkpoint_function(train_func, abort=False):
"""Use checkpointing if any arg has "checkpoint_dir" and args = 2"""
argspec = inspect.getfullargspec(train_func)
func_args = argspec.args
func_kwargs = argspec.kwonlyargs
validated = len(func_args) == 2 and any("checkpoint_dir" in arg
for arg in func_args)
validated = validated or (len(func_args) == 1) and any(
"checkpoint_dir" in arg for arg in func_kwargs)
if abort and not validated:
raise ValueError(
"Provided training function must have 2 args "
"in the signature, and the latter arg must "
"contain `checkpoint_dir`. For example: "
"`func(config, checkpoint_dir=None)`. Got {}".format(func_args))
return validated
def wrap_function(train_func):
class ImplicitFunc(FunctionRunner):
def _trainable_func(self, config, reporter, checkpoint):
def _trainable_func(self, config, reporter, checkpoint_dir):
func_args = inspect.getfullargspec(train_func).args
if len(func_args) > 1: # more arguments than just the config
if "reporter" not in func_args and (
"checkpoint" not in func_args):
not detect_checkpoint_function(train_func)):
raise ValueError(
"Unknown argument found in the Trainable function. "
"Arguments other than the 'config' arg must be one "
"of ['reporter', 'checkpoint']. Found: {}".format(
"of ['reporter', 'checkpoint_dir']. Found: {}".format(
func_args))
use_reporter = "reporter" in func_args
use_checkpoint = "checkpoint" in func_args
use_checkpoint = detect_checkpoint_function(train_func)
if not use_checkpoint and not use_reporter:
logger.warning(
"Function checkpointing is disabled. This may result in "
"unexpected behavior when using checkpointing features or "
"certain schedulers. To enable, set the train function "
"arguments to be `func(config, checkpoint)`.")
"arguments to be `func(config, checkpoint_dir=None)`.")
output = train_func(config)
elif use_checkpoint:
output = train_func(config, checkpoint=checkpoint)
output = train_func(config, checkpoint_dir=checkpoint_dir)
else:
output = train_func(config, reporter)
+290
View File
@@ -0,0 +1,290 @@
# Original Code here:
# https://github.com/pytorch/examples/blob/master/mnist/main.py
from contextlib import contextmanager
import os
import logging
import shutil
import tempfile
import torch
from datetime import timedelta
import ray
from ray import tune
from ray.tune.result import RESULT_DUPLICATE
from ray.tune.logger import NoopLogger
from ray.tune.function_runner import (wrap_function,
detect_checkpoint_function)
from ray.tune.resources import Resources
from ray.tune.trainable import TrainableUtil
from ray.util.sgd.torch.utils import setup_process_group, setup_address
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S
logger = logging.getLogger(__name__)
_distributed_enabled = False
def is_distributed_trainable():
"""Returns True if executing within a DistributedTrainable."""
return _distributed_enabled
def enable_distributed_trainable():
global _distributed_enabled
_distributed_enabled = True
def logger_creator(log_config, logdir, rank):
worker_dir = os.path.join(logdir, "worker_{}".format(rank))
os.makedirs(worker_dir, exist_ok=True)
return NoopLogger(log_config, worker_dir)
class _TorchTrainable(tune.Trainable):
"""Base class for distributed training on Tune.
A wrapper class is needed to actually create a working
version of this trainable.
"""
_function = None
_num_workers = None
_use_gpu = None
_num_cpus_per_worker = None
__slots__ = ["workers", "_finished"]
@classmethod
def default_process_group_parameters(self):
return dict(timeout=timedelta(NCCL_TIMEOUT_S), backend="gloo")
@classmethod
def get_remote_worker_options(self):
num_gpus = 1 if self._use_gpu else 0
num_cpus = int(self._num_cpus_per_worker or 1)
return dict(num_cpus=num_cpus, num_gpus=num_gpus)
def setup(self, config):
self._finished = False
num_workers = self._num_workers
logdir = self.logdir
assert self._function
func_trainable = wrap_function(self.__class__._function)
remote_trainable = ray.remote(func_trainable)
remote_trainable = remote_trainable.options(
**self.get_remote_worker_options())
address = setup_address()
self.workers = [
remote_trainable.remote(
config=config,
logger_creator=lambda cfg: logger_creator(cfg, logdir, rank))
for rank in range(num_workers)
]
pgroup_params = self.default_process_group_parameters()
from functools import partial
setup_on_worker = partial(
setup_process_group,
url=address,
world_size=num_workers,
**pgroup_params)
ray.get([
w.execute.remote(lambda _: setup_on_worker(world_rank=rank))
for rank, w in enumerate(self.workers)
])
ray.get([
w.execute.remote(lambda _: enable_distributed_trainable())
for rank, w in enumerate(self.workers)
])
def step(self):
if self._finished:
raise RuntimeError("Training has already finished.")
result = ray.get([w.step.remote() for w in self.workers])[0]
if RESULT_DUPLICATE in result:
self._finished = True
return result
def save_checkpoint(self, checkpoint_dir):
# TODO: optimize if colocated
save_obj = ray.get(self.workers[0].save_to_object.remote())
checkpoint_path = TrainableUtil.create_from_pickle(
save_obj, checkpoint_dir)
return checkpoint_path
def load_checkpoint(self, checkpoint_dir):
checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
return ray.get(
w.restore_from_object.remote(checkpoint_obj) for w in self.workers)
def stop(self):
ray.get([worker.stop.remote() for worker in self.workers])
def DistributedTrainableCreator(func,
use_gpu=False,
num_workers=1,
num_cpus_per_worker=1,
backend="gloo",
timeout_s=NCCL_TIMEOUT_S):
"""Creates a class that executes distributed training.
Similar to running `torch.distributed.launch`.
Note that you typically should not instantiate the object
created.
Args:
func (callable): This function is a Tune trainable function.
This function must have 2 args in the signature, and the
latter arg must contain `checkpoint_dir`. For example:
`func(config, checkpoint_dir=None)`.
use_gpu (bool): Sets resource allocation for workers to 1 GPU
if true. Also automatically sets CUDA_VISIBLE_DEVICES
for each training worker.
num_workers (int): Number of training workers to include in
world.
num_cpus_per_worker (int): Number of CPU resources to reserve
per training worker.
backend (str): One of "gloo", "nccl".
timeout_s (float): Seconds before the torch process group
times out. Useful when machines are unreliable. Defaults
to 60 seconds.
Returns:
A trainable class object that can be passed to Tune. Resources
are automatically set within the object, so users do
not need to set `resources_per_trainable`.
Example:
.. code-block::
trainable_cls = DistributedTrainableCreator(
train_func, num_workers=2)
analysis = tune.run(trainable_cls)
"""
detect_checkpoint_function(func, abort=True)
class WrappedDistributedTorchTrainable(_TorchTrainable):
_function = func
_num_workers = num_workers
_use_gpu = use_gpu
_num_cpus_per_worker = num_cpus_per_worker
@classmethod
def default_process_group_parameters(self):
return dict(timeout=timedelta(timeout_s), backend=backend)
@classmethod
def default_resource_request(cls, config):
num_workers_ = int(config.get("num_workers", num_workers))
num_cpus = int(
config.get("num_cpus_per_worker", num_cpus_per_worker))
use_gpu_ = config.get("use_gpu", use_gpu)
return Resources(
cpu=0,
gpu=0,
extra_cpu=num_cpus * num_workers_,
extra_gpu=num_workers_ if use_gpu_ else 0)
return WrappedDistributedTorchTrainable
@contextmanager
def distributed_checkpoint_dir(step, disable=False):
"""ContextManager for creating a distributed checkpoint.
Only checkpoints a file on the "main" training actor, avoiding
redundant work.
Args:
step (int): Used to label the checkpoint
disable (bool): Disable for prototyping.
Yields:
path (str): A path to a directory. This path will be used
again when invoking the training_function.
Example:
.. code-block::
def train_func(config, checkpoint_dir):
if checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
model_state_dict = torch.load(path)
if epoch % 3 == 0:
with distributed_checkpoint_dir(step=epoch) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
torch.save(model.state_dict(), path)
"""
if torch.distributed.get_rank() == 0 and not disable:
with tune.checkpoint_dir(step=step) as checkpoint_dir:
yield checkpoint_dir
else:
path = tempfile.mkdtemp()
yield path
shutil.rmtree(path)
def _train_check_global(config, checkpoint_dir=None):
"""For testing only. Putting this here because Ray has problems
serializing within the test file."""
assert is_distributed_trainable()
import time
time.sleep(0.1)
tune.report(is_distributed=True)
def _train_simple(config, checkpoint_dir=None):
"""For testing only. Putting this here because Ray has problems
serializing within the test file."""
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
import torch.optim as optim
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 8, 5, 5, 5
# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
loss_fn = nn.MSELoss()
# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.ReLU(),
torch.nn.Linear(H, D_out),
)
optimizer = optim.SGD(model.parameters(), lr=0.1)
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
model_state, optimizer_state = torch.load(f)
model.load_state_dict(model_state)
optimizer.load_state_dict(optimizer_state)
model = DistributedDataParallel(model)
for epoch in range(config.get("epochs", 10)):
optimizer.zero_grad()
output = model(x)
loss = loss_fn(output, y)
loss.backward()
optimizer.step()
if epoch % 3 == 0:
if config.get("enable_checkpoint", True):
with distributed_checkpoint_dir(step=epoch) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
torch.save((model.state_dict(), optimizer.state_dict()),
path)
tune.report(mean_loss=loss.item())
+40 -51
View File
@@ -1,3 +1,4 @@
from contextlib import contextmanager
import os
import logging
@@ -75,48 +76,34 @@ def report(**kwargs):
def make_checkpoint_dir(step=None):
"""Gets the next checkpoint dir.
.. code-block:: python
import time
from ray import tune
def func(config, checkpoint=None):
start = 0
if checkpoint:
with open(checkpoint) as f:
state = json.loads(f.read())
start = state["step"] + 1
for iter in range(start, 100):
time.sleep(1)
checkpoint_dir = tune.make_checkpoint_dir(step=step)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": start}))
tune.save_checkpoint(path)
tune.report(hello="world", ray="tune")
.. warning:: Do not call this function within the Trainable Class API.
Args:
step (int): Current training iteration - used for setting
an index to uniquely identify the checkpoint.
.. versionadded:: 0.8.6
.. deprecated:: 0.8.7
Use tune.checkpoint_dir instead.
"""
_session = get_session()
if _session:
return _session.make_checkpoint_dir(step=step)
else:
return os.path.abspath("./")
raise DeprecationWarning(
"Deprecated method. Use `tune.checkpoint_dir` instead.")
def save_checkpoint(checkpoint):
"""Register the given checkpoint.
.. versionadded:: 0.8.6
.. deprecated:: 0.8.7
Use tune.checkpoint_dir instead.
"""
raise DeprecationWarning(
"Deprecated method. Use `tune.checkpoint_dir` instead.")
@contextmanager
def checkpoint_dir(step=None):
"""Returns a checkpoint dir inside a context.
Store any files related to restoring state within the
provided checkpoint dir.
.. code-block:: python
import os
@@ -124,10 +111,10 @@ def save_checkpoint(checkpoint):
import time
from ray import tune
def func(config, checkpoint=None):
def func(config, checkpoint_dir=None):
start = 0
if checkpoint:
with open(checkpoint) as f:
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
state = json.loads(f.read())
accuracy = state["acc"]
start = state["step"] + 1
@@ -135,27 +122,29 @@ def save_checkpoint(checkpoint):
for iter in range(start, 10):
time.sleep(1)
checkpoint_dir = tune.make_checkpoint_dir(step=iter)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": start}))
tune.save_checkpoint(path)
with tune.checkpoint_dir(step=iter) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": start}))
tune.report(hello="world", ray="tune")
analysis = tune.run(run_me)
Yields:
checkpoint_dir (str): Directory for checkpointing.
.. warning:: Do not call this function within the Trainable Class API.
Args:
**kwargs: Any key value pair to be logged by Tune. Any of these
metrics can be used for early stopping or optimization.
.. versionadded:: 0.8.6
.. versionadded:: 0.8.7
"""
_session = get_session()
if _session:
return _session.save_checkpoint(checkpoint)
_checkpoint_dir = _session.make_checkpoint_dir(step=step)
else:
_checkpoint_dir = os.path.abspath("./")
yield _checkpoint_dir
if _session:
_session.save_checkpoint(_checkpoint_dir)
def get_trial_dir():
+63 -57
View File
@@ -19,7 +19,7 @@ class FunctionApiTest(unittest.TestCase):
_register_all() # re-register the evicted objects
def testFunctionNoCheckpointing(self):
def train(config, checkpoint=None):
def train(config, checkpoint_dir=None):
for i in range(10):
tune.report(test=i)
@@ -40,14 +40,13 @@ class FunctionApiTest(unittest.TestCase):
def testFunctionRecurringSave(self):
"""This tests that save and restore are commutative."""
def train(config, checkpoint=None):
def train(config, checkpoint_dir=None):
for step in range(10):
if step % 3 == 0:
checkpoint_dir = tune.make_checkpoint_dir(step=step)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": step}))
tune.save_checkpoint(path)
with tune.checkpoint_dir(step=step) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": step}))
tune.report(test=step)
wrapped = wrap_function(train)
@@ -65,49 +64,58 @@ class FunctionApiTest(unittest.TestCase):
new_trainable2.stop()
def testCheckpointFunctionAtEnd(self):
def train(config, checkpoint=False):
def train(config, checkpoint_dir=False):
for i in range(10):
tune.report(test=i)
checkpoint_dir = tune.make_checkpoint_dir(step=10)
checkpoint_path = os.path.join(checkpoint_dir, "hello")
with open(checkpoint_path, "w") as f:
f.write("hello")
tune.save_checkpoint(checkpoint_path)
[trial] = tune.run(train).trials
assert "hello" in trial.checkpoint.value
def testVariousCheckpointFunctionAtEnd(self):
def train(config, checkpoint=False):
for i in range(10):
checkpoint_dir = tune.make_checkpoint_dir()
checkpoint_path = os.path.join(checkpoint_dir, "hello")
with tune.checkpoint_dir(step=10) as checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
with open(checkpoint_path, "w") as f:
f.write("hello")
tune.save_checkpoint(checkpoint_path)
[trial] = tune.run(train).trials
assert os.path.exists(os.path.join(trial.checkpoint.value, "ckpt.log"))
def testCheckpointFunctionAtEndContext(self):
def train(config, checkpoint_dir=False):
for i in range(10):
tune.report(test=i)
checkpoint_dir = tune.make_checkpoint_dir()
checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
with open(checkpoint_path, "w") as f:
f.write("goodbye")
tune.save_checkpoint(checkpoint_path)
with tune.checkpoint_dir(step=10) as checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
with open(checkpoint_path, "w") as f:
f.write("hello")
[trial] = tune.run(train).trials
assert os.path.exists(os.path.join(trial.checkpoint.value, "ckpt.log"))
def testVariousCheckpointFunctionAtEnd(self):
def train(config, checkpoint_dir=False):
for i in range(10):
with tune.checkpoint_dir() as checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
with open(checkpoint_path, "w") as f:
f.write("hello")
tune.report(test=i)
with tune.checkpoint_dir() as checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log2")
with open(checkpoint_path, "w") as f:
f.write("goodbye")
[trial] = tune.run(train, keep_checkpoints_num=3).trials
assert "goodbye" in trial.checkpoint.value
assert os.path.exists(
os.path.join(trial.checkpoint.value, "ckpt.log2"))
def testReuseCheckpoint(self):
def train(config, checkpoint=False):
def train(config, checkpoint_dir=None):
itr = 0
if checkpoint:
with open(checkpoint, "r") as f:
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "ckpt.log"), "r") as f:
itr = int(f.read()) + 1
for i in range(itr, config["max_iter"]):
checkpoint_dir = tune.make_checkpoint_dir(step=i)
checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
with open(checkpoint_path, "w") as f:
f.write(str(i))
tune.save_checkpoint(checkpoint_path)
with tune.checkpoint_dir(step=i) as checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
with open(checkpoint_path, "w") as f:
f.write(str(i))
tune.report(test=i, training_iteration=i)
[trial] = tune.run(
@@ -117,51 +125,49 @@ class FunctionApiTest(unittest.TestCase):
},
).trials
last_ckpt = trial.checkpoint.value
assert "goodbye" in last_ckpt
assert os.path.exists(os.path.join(trial.checkpoint.value, "ckpt.log"))
analysis = tune.run(train, config={"max_iter": 10}, restore=last_ckpt)
trial_dfs = list(analysis.trial_dataframes.values())
assert len(trial_dfs[0]["training_iteration"]) == 5
def testRetry(self):
def train(config, checkpoint=None):
restored = bool(checkpoint)
def train(config, checkpoint_dir=None):
restored = bool(checkpoint_dir)
itr = 0
if checkpoint:
with open(checkpoint, "r") as f:
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "ckpt.log"), "r") as f:
itr = int(f.read()) + 1
for i in range(itr, 10):
if i == 5 and not restored:
raise Exception("try to fail me")
checkpoint_dir = tune.make_checkpoint_dir(step=i)
checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
with open(checkpoint_path, "w") as f:
f.write(str(i))
tune.save_checkpoint(checkpoint_path)
with tune.checkpoint_dir(step=i) as checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
with open(checkpoint_path, "w") as f:
f.write(str(i))
tune.report(test=i, training_iteration=i)
analysis = tune.run(train, max_failures=3)
last_ckpt = analysis.trials[0].checkpoint.value
assert "goodbye" in last_ckpt
assert os.path.exists(os.path.join(last_ckpt, "ckpt.log"))
trial_dfs = list(analysis.trial_dataframes.values())
assert len(trial_dfs[0]["training_iteration"]) == 10
def testBlankCheckpoint(self):
def train(config, checkpoint=None):
restored = bool(checkpoint)
def train(config, checkpoint_dir=None):
restored = bool(checkpoint_dir)
itr = 0
if checkpoint:
with open(checkpoint, "r") as f:
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "ckpt.log"), "r") as f:
itr = int(f.read()) + 1
for i in range(itr, 10):
if i == 5 and not restored:
raise Exception("try to fail me")
checkpoint_dir = tune.make_checkpoint_dir()
checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
with open(checkpoint_path, "w") as f:
f.write(str(i))
tune.save_checkpoint(checkpoint_path)
with tune.checkpoint_dir() as checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
with open(checkpoint_path, "w") as f:
f.write(str(i))
tune.report(test=i, training_iteration=i)
analysis = tune.run(train, max_failures=3)
@@ -6,8 +6,9 @@ import torch.distributed as dist
import ray
from ray import tune
from ray.util.sgd.torch.func_trainable import (
DistributedTrainableCreator, distributed_checkpoint, _train_simple)
from ray.tune.integration.torch import (DistributedTrainableCreator,
distributed_checkpoint_dir,
_train_simple, _train_check_global)
@pytest.fixture
@@ -47,12 +48,29 @@ def test_step_after_completion(ray_start_2_cpus): # noqa: F811
trainer.train()
def test_validation(ray_start_2_cpus): # noqa: F811
def bad_func(a, b, c):
return 1
with pytest.raises(ValueError):
DistributedTrainableCreator(bad_func, num_workers=2)
def test_set_global(ray_start_2_cpus): # noqa: F811
trainable_cls = DistributedTrainableCreator(
_train_check_global, num_workers=2)
trainable = trainable_cls()
result = trainable.train()
assert result["is_distributed"]
def test_save_checkpoint(ray_start_2_cpus): # noqa: F811
trainable_cls = DistributedTrainableCreator(_train_simple, num_workers=2)
trainer = trainable_cls(config={"epochs": 1})
trainer.train()
path = trainer.save()
model_state_dict, opt_state_dict = torch.load(path)
checkpoint_dir = trainer.save()
model_state_dict, opt_state_dict = torch.load(
os.path.join(checkpoint_dir, "checkpoint"))
trainer.stop()
@@ -72,11 +90,11 @@ def test_simple_tune(ray_start_4_cpus, enabled_checkpoint):
def test_checkpoint(ray_start_2_cpus, rank): # noqa: F811
with patch("torch.distributed.get_rank") as rank_method:
rank_method.return_value = rank
with distributed_checkpoint(label="test") as path:
with distributed_checkpoint_dir(step="test") as path:
if rank == 0:
assert path
else:
assert path == os.devnull
if rank != 0:
assert not os.path.exists(path)
if __name__ == "__main__":
-8
View File
@@ -26,14 +26,6 @@ py_test(
deps = [":sgd_lib"],
)
py_test(
name = "test_torch_trainable",
size = "small",
srcs = ["tests/test_torch_trainable.py"],
tags = ["exclusive", "pytorch"],
deps = [":sgd_lib"],
)
# --------------------------------------------------------------------
# Tests from the python/ray/util/sgd/tf/examples directory.
# Please keep these sorted alphabetically.
+1 -6
View File
@@ -12,13 +12,8 @@ try:
BaseTorchTrainable)
from ray.util.sgd.torch.training_operator import TrainingOperator
from ray.util.sgd.torch.func_trainable import (DistributedTrainableCreator,
distributed_checkpoint)
__all__ = [
"TorchTrainer", "BaseTorchTrainable", "TrainingOperator",
"distributed_checkpoint", "DistributedTrainableCreator"
]
__all__ = ["TorchTrainer", "BaseTorchTrainable", "TrainingOperator"]
except ImportError as e:
logger.warning(e)
logger.warning("PyTorch not found. TorchTrainer will not be available")
-235
View File
@@ -1,235 +0,0 @@
# Original Code here:
# https://github.com/pytorch/examples/blob/master/mnist/main.py
import os
import logging
import torch
from datetime import timedelta
import ray
from ray import tune
from ray.tune.result import RESULT_DUPLICATE
from ray.tune.logger import NoopLogger
from ray.tune.function_runner import wrap_function
from ray.tune.resources import Resources
from ray.tune.trainable import TrainableUtil
from ray.util.sgd.torch.utils import setup_process_group
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S
from ray.util.sgd.torch.utils import setup_address
logger = logging.getLogger(__name__)
def logger_creator(log_config, logdir, rank):
worker_dir = os.path.join(logdir, "worker_{}".format(rank))
os.makedirs(worker_dir, exist_ok=True)
return NoopLogger(log_config, worker_dir)
class _TorchTrainable(tune.Trainable):
"""Base class for distributed training on Tune.
A wrapper class is needed to actually create a working
version of this trainable.
"""
_function = None
_num_workers = None
_use_gpu = None
_num_cpus_per_worker = None
__slots__ = ["workers", "_finished"]
@classmethod
def default_process_group_parameters(self):
return dict(timeout=timedelta(NCCL_TIMEOUT_S), backend="gloo")
@classmethod
def get_remote_worker_options(self):
num_gpus = 1 if self._use_gpu else 0
num_cpus = int(self._num_cpus_per_worker or 1)
return dict(num_cpus=num_cpus, num_gpus=num_gpus)
def setup(self, config):
self._finished = False
num_workers = self._num_workers
logdir = self.logdir
assert self._function
func_trainable = wrap_function(self.__class__._function)
remote_trainable = ray.remote(func_trainable)
remote_trainable = remote_trainable.options(
**self.get_remote_worker_options())
address = setup_address()
self.workers = [
remote_trainable.remote(
config=config,
logger_creator=lambda cfg: logger_creator(cfg, logdir, rank))
for rank in range(num_workers)
]
pgroup_params = self.default_process_group_parameters()
from functools import partial
setup_on_worker = partial(
setup_process_group,
url=address,
world_size=num_workers,
**pgroup_params)
ray.get([
w.execute.remote(lambda _: setup_on_worker(world_rank=rank))
for rank, w in enumerate(self.workers)
])
def step(self):
if self._finished:
raise RuntimeError("Training has already finished.")
result = ray.get([w.step.remote() for w in self.workers])[0]
if RESULT_DUPLICATE in result:
self._finished = True
return result
def save_checkpoint(self, checkpoint_dir):
# TODO: optimize if colocated
save_obj = ray.get(self.workers[0].save_to_object.remote())
checkpoint_path = TrainableUtil.create_from_pickle(
save_obj, checkpoint_dir)
return checkpoint_path
def load_checkpoint(self, checkpoint_dir):
checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
return ray.get(
w.restore_from_object.remote(checkpoint_obj) for w in self.workers)
def stop(self):
ray.get([worker.stop.remote() for worker in self.workers])
def DistributedTrainableCreator(func,
use_gpu=False,
num_workers=1,
num_cpus_per_worker=1,
backend="gloo",
timeout_s=NCCL_TIMEOUT_S):
"""Creates a class that executes distributed training.
Note that you typically should not instantiate the object
created.
Example:
.. code-block::
trainable_cls = DistributedTrainableCreator(
train_func, num_workers=2)
analysis = tune.run(trainable_cls)
"""
class WrappedDistributedTorchTrainable(_TorchTrainable):
_function = func
_num_workers = num_workers
_use_gpu = use_gpu
_num_cpus_per_worker = num_cpus_per_worker
@classmethod
def default_process_group_parameters(self):
return dict(timeout=timedelta(timeout_s), backend=backend)
@classmethod
def default_resource_request(cls, config):
num_workers_ = int(config.get("num_workers", num_workers))
num_cpus = int(
config.get("num_cpus_per_worker", num_cpus_per_worker))
use_gpu_ = config.get("use_gpu", use_gpu)
return Resources(
cpu=0,
gpu=0,
extra_cpu=num_cpus * num_workers_,
extra_gpu=num_workers_ if use_gpu_ else 0)
return WrappedDistributedTorchTrainable
class distributed_checkpoint:
"""ContextManager for creating a distributed checkpoint.
Only checkpoints a file on the "main" training actor, avoiding
redundant work.
Args:
label (int | str): Used to label the checkpoint
disable (bool): Disable for prototyping.
Example:
.. code-block::
if epoch % 3 == 0:
with distributed_checkpoint(label=epoch) as path:
torch.save(model.state_dict(), path)
"""
def __init__(self, label, disable=False):
self.label = label
self.file = None
self.disable = disable
def __enter__(self):
if torch.distributed.get_rank() == 0 and not self.disable:
checkpoint_dir = tune.make_checkpoint_dir(step=self.label)
path = os.path.join(checkpoint_dir, "checkpoint")
else:
path = os.devnull
self.file = path
return path
def __exit__(self, type, value, traceback):
if torch.distributed.get_rank() == 0 and not self.disable:
tune.save_checkpoint(self.file)
def _train_simple(config, checkpoint=False):
"""For testing only. Putting this here because Ray has problems
serializing within the test file."""
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
import torch.optim as optim
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 8, 5, 5, 5
# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
loss_fn = nn.MSELoss()
# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.ReLU(),
torch.nn.Linear(H, D_out),
)
optimizer = optim.SGD(model.parameters(), lr=0.1)
if checkpoint:
with open(checkpoint) as f:
model_state, optimizer_state = torch.load(f)
model.load_state_dict(model_state)
optimizer.load_state_dict(optimizer_state)
model = DistributedDataParallel(model)
for epoch in range(config.get("epochs", 10)):
optimizer.zero_grad()
output = model(x)
loss = loss_fn(output, y)
loss.backward()
optimizer.step()
if epoch % 3 == 0:
if config.get("enable_checkpoint", True):
with distributed_checkpoint(label=epoch) as path:
torch.save((model.state_dict(), optimizer.state_dict()),
path)
tune.report(mean_loss=loss.item())
@@ -140,6 +140,8 @@ class TorchTrainer:
Defaults to True.
wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel
over each model. If False, you are expected to call it yourself.
timeout_s (float): Seconds before the torch process group
times out. Useful when machines are unreliable.
add_dist_sampler (bool): Whether to automatically add a
DistributedSampler to all created dataloaders. Only applicable
if num_workers > 1.