diff --git a/doc/source/raysgd/raysgd_pytorch.rst b/doc/source/raysgd/raysgd_pytorch.rst index 98adaffb8..532bf1612 100644 --- a/doc/source/raysgd/raysgd_pytorch.rst +++ b/doc/source/raysgd/raysgd_pytorch.rst @@ -18,6 +18,8 @@ For end to end examples leveraging RaySGD TorchTrainer, jump to :ref:`raysgd-tor Setting up training ------------------- +.. tip:: If you want to leverage multi-node data parallel training with PyTorch while using RayTune *without* restructuring your code, check out the :ref:`Tune PyTorch user guide ` and Tune's :ref:`distributed pytorch integrations `. + The ``TorchTrainer`` can be constructed with functions that wrap components of the training script. Specifically, it requires constructors for the Model, Data, Optimizer, Loss, and ``lr_scheduler`` to create replicated copies across different devices and machines. Under the hood, ``TorchTrainer`` will create *replicas* of your model (controlled by ``num_workers``), each of which is managed by a Ray actor. One of the replicas will be on the main process, which can simplify the debugging and logging experience. diff --git a/doc/source/tune/_tutorials/tune-distributed.rst b/doc/source/tune/_tutorials/tune-distributed.rst index e69490116..82ed206c7 100644 --- a/doc/source/tune/_tutorials/tune-distributed.rst +++ b/doc/source/tune/_tutorials/tune-distributed.rst @@ -213,16 +213,7 @@ In GCP, you can use the following configuration modification: scheduling: - preemptible: true -Spot instances may be removed suddenly while trials are still running. Often times this may be difficult to deal with when using other distributed hyperparameter optimization frameworks. Tune allows users to mitigate the effects of this by preserving the progress of your model training through checkpointing. - -The easiest way to do this is to subclass the pre-defined ``Trainable`` class and implement ``save_checkpoint``, and ``load_checkpoint`` abstract methods, as seen in the example below: - -.. 
literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_trainable.py - :language: python - :start-after: __trainable_example_begin__ - :end-before: __trainable_example_end__ - -This can then be used similarly to the Function API as before: +Spot instances may be removed suddenly while trials are still running. Often times this may be difficult to deal with when using other distributed hyperparameter optimization frameworks. Tune allows users to mitigate the effects of this by preserving the progress of your model training through :ref:`checkpointing `. .. literalinclude:: /../../python/ray/tune/tests/tutorial.py :language: python diff --git a/doc/source/tune/_tutorials/tune-pytorch-cifar.rst b/doc/source/tune/_tutorials/tune-pytorch-cifar.rst index 6d1b0363c..743f44463 100644 --- a/doc/source/tune/_tutorials/tune-pytorch-cifar.rst +++ b/doc/source/tune/_tutorials/tune-pytorch-cifar.rst @@ -25,6 +25,8 @@ need to 3. add checkpointing (optional), 4. and define the search space for the model tuning +Optionally, you can seamlessly leverage :ref:`DistributedDataParallel training ` for each individual Pytorch model within Tune. + .. note:: To run this example, you will need to install the following: @@ -74,15 +76,16 @@ The train function Now it gets interesting, because we introduce some changes to the example `from the PyTorch documentation `_. -We wrap the training script in a function ``train_cifar(config, checkpoint=None)``. As you +We wrap the training script in a function ``train_cifar(config, checkpoint_dir=None)``. As you can guess, the ``config`` parameter will receive the hyperparameters we would like to -train with. The ``checkpoint`` parameter is used to restore checkpoints. +train with. The ``checkpoint_dir`` parameter is used to restore checkpoints. .. 
code-block:: python net = Net(config["l1"], config["l2"]) - if checkpoint: + if checkpoint_dir: + checkpoint = os.path.join(checkpoint_dir, "checkpoint") net.load_state_dict(torch.load(checkpoint)) The learning rate of the optimizer is made configurable, too: @@ -97,6 +100,7 @@ with which we iterate through the training and test sets are configurable as wel Adding (multi) GPU support with DataParallel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + Image classification benefits largely from GPUs. Luckily, we can continue to use PyTorch's abstractions in Ray Tune. Thus, we can wrap our model in ``nn.DataParallel`` to support data parallel training on multiple GPUs: @@ -132,10 +136,9 @@ The most interesting part is the communication with Tune: .. code-block:: python - checkpoint_dir = tune.make_checkpoint_dir(epoch) - path = os.path.join(checkpoint_dir, "checkpoint") - torch.save((net.state_dict(), optimizer.state_dict()), path) - tune.save_checkpoint(path) + with tune.checkpoint_dir(epoch) as checkpoint_dir: + path = os.path.join(checkpoint_dir, "checkpoint") + torch.save((net.state_dict(), optimizer.state_dict()), path) tune.report(loss=(val_loss / val_steps), accuracy=correct / total) @@ -273,6 +276,75 @@ be confirmed on the test set. So that's it! You can now tune the parameters of your PyTorch models. +.. _tune-torch-ddp: + +Advanced: Distributed training with DistributedDataParallel +----------------------------------------------------------- + +Some models require multiple nodes to train in a short amount of time. Ray Tune allows you to easily do distributed data parallel training in addition to distributed hyperparameter tuning. + +You can wrap your model in ``torch.nn.parallel.DistributedDataParallel`` to support distributed data parallel training: + +.. 
code-block:: python + + from ray.tune.integration.torch import is_distributed_trainable + from torch.nn.parallel import DistributedDataParallel + + def train_cifar(config, checkpoint_dir=None, data_dir=None): + net = Net(config["l1"], config["l2"]) + + device = "cpu" + + #### Using distributed data parallel training + if is_distributed_trainable(): + net = DistributedDataParallel(net) + + if torch.cuda.is_available(): + device = "cuda" + + net.to(device) + + +If using checkpointing, be sure to use a :ref:`special checkpoint context manager `, ``distributed_checkpoint_dir``, which avoids redundant checkpointing across multiple processes: + +.. code-block:: python + + from ray.tune.integration.torch import distributed_checkpoint_dir + + #### Using distributed data parallel training + # Inside `def train_cifar(...)`, + # replace tune.checkpoint_dir() with the following + # Avoids redundant checkpointing on different processes. + with distributed_checkpoint_dir(step=epoch) as checkpoint_dir: + path = os.path.join(checkpoint_dir, "checkpoint") + torch.save((net.state_dict(), optimizer.state_dict()), path) + + +Finally, we need to tell Ray Tune to start multiple distributed processes at once by using ``ray.tune.integration.torch.DistributedTrainableCreator`` (:ref:`docs `). This is essentially equivalent to running ``torch.distributed.launch`` for each hyperparameter trial: + +.. code-block:: python + + # You'll probably want to be running on a distributed Ray cluster. + # ray.init(address="auto") + + from ray.tune.integration.torch import DistributedTrainableCreator + + distributed_train_cifar = DistributedTrainableCreator( + partial(train_cifar, data_dir=data_dir), + use_gpu=True, + num_workers=2, # number of parallel workers to use + num_cpus_per_worker=8 + ) + tune.run( + distributed_train_cifar, + resources_per_trial=None, + config=config, + num_samples=num_samples, + ... + ) + +See an :doc:`end-to-end example here `. 
+ If you consider switching to PyTorch Lightning to get rid of some of your boilerplate training code, please know that we also have a walkthrough on :doc:`how to use Tune with -PyTorch Lightning models `. \ No newline at end of file +PyTorch Lightning models `. diff --git a/doc/source/tune/api_docs/trainable.rst b/doc/source/tune/api_docs/trainable.rst index e9ee28874..6d0935657 100644 --- a/doc/source/tune/api_docs/trainable.rst +++ b/doc/source/tune/api_docs/trainable.rst @@ -17,6 +17,7 @@ For the sake of example, let's maximize this objective function: Function API ------------ + Here is a simple example of using the function API. You can report intermediate metrics by simply calling ``tune.report`` within the provided function. .. code-block:: python @@ -40,38 +41,41 @@ Here is a simple example of using the function API. You can report intermediate Tune will run this function on a separate thread in a Ray actor process. +.. tip:: If you want to leverage multi-node data parallel training with PyTorch while using parallel hyperparameter tuning, check out our :ref:PyTorch user guide and Tune's :ref:distributed pytorch integrations. + +.. _tune-function-checkpointing: Function API Checkpointing ~~~~~~~~~~~~~~~~~~~~~~~~~~ -Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance. To use Tune's checkpointing features, you must expose a ``checkpoint`` argument in the function signature, and call ``tune.make_checkpoint_dir`` and ``tune.save_checkpoint``: +Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance. To use Tune's checkpointing features, you must expose a ``checkpoint_dir`` argument in the function signature, and call ``tune.checkpoint_dir`` : .. 
code-block:: python import time from ray import tune - def train_func(config, checkpoint=None): + def train_func(config, checkpoint_dir=None): start = 0 - if checkpoint: - with open(checkpoint) as f: + if checkpoint_dir: + with open(os.path.join(checkpoint_dir, "checkpoint")) as f: state = json.loads(f.read()) start = state["step"] + 1 for iter in range(start, 100): time.sleep(1) - # - checkpoint_dir = tune.make_checkpoint_dir(step=step) - path = os.path.join(checkpoint_dir, "checkpoint") - with open(path, "w") as f: - f.write(json.dumps({"step": start})) - tune.save_checkpoint(path) + with tune.checkpoint_dir(step=step) as checkpoint_dir: + path = os.path.join(checkpoint_dir, "checkpoint") + with open(path, "w") as f: + f.write(json.dumps({"step": start})) tune.report(hello="world", ray="tune") tune.run(train_func) +.. note:: ``checkpoint_freq`` and ``checkpoint_at_end`` will not work with Function API checkpointing. + In this example, checkpoints will be saved by training iteration to ``local_dir/exp_name/trial_name/checkpoint_``. You can restore a single trial checkpoint by using ``tune.run(restore=)``: .. code-block:: python @@ -263,9 +267,7 @@ tune.report / tune.checkpoint (Function API) .. autofunction:: ray.tune.report -.. autofunction:: ray.tune.make_checkpoint_dir - -.. autofunction:: ray.tune.save_checkpoint +.. autofunction:: ray.tune.checkpoint_dir .. autofunction:: ray.tune.get_trial_dir @@ -282,6 +284,21 @@ tune.Trainable (Class API) :private-members: :members: + +.. _tune-ddp-doc: + +Distributed Torch +----------------- + +Ray also offers lightweight integrations to distribute your model training on Ray Tune. + + +.. autofunction:: ray.tune.integration.torch.DistributedTrainableCreator + +.. autofunction:: ray.tune.integration.torch.distributed_checkpoint_dir + +.. 
autofunction:: ray.tune.integration.torch.is_distributed_trainable + tune.DurableTrainable --------------------- diff --git a/doc/source/tune/examples/ddp_mnist_torch.rst b/doc/source/tune/examples/ddp_mnist_torch.rst new file mode 100644 index 000000000..8ac04656e --- /dev/null +++ b/doc/source/tune/examples/ddp_mnist_torch.rst @@ -0,0 +1,6 @@ +:orphan: + +ddp_mnist_torch +~~~~~~~~~~~~~~~ + +.. literalinclude:: /../../python/ray/tune/examples/ddp_mnist_torch.py diff --git a/doc/source/tune/examples/index.rst b/doc/source/tune/examples/index.rst index b83873a3d..e257f3ff7 100644 --- a/doc/source/tune/examples/index.rst +++ b/doc/source/tune/examples/index.rst @@ -41,6 +41,7 @@ PyTorch Examples - :doc:`/tune/examples/mnist_pytorch`: Converts the PyTorch MNIST example to use Tune with the function-based API. Also shows how to easily convert something relying on argparse to use Tune. - :doc:`/tune/examples/mnist_pytorch_trainable`: Converts the PyTorch MNIST example to use Tune with Trainable API. Also uses the HyperBandScheduler and checkpoints the model at the end. +- :doc:`/tune/examples/ddp_mnist_torch`: An example showing how to use DistributedDataParallel with Ray Tune. This enables both distributed training and distributed hyperparameter tuning. XGBoost Example diff --git a/doc/source/tune/user-guide.rst b/doc/source/tune/user-guide.rst index 694f315c0..d3140642c 100644 --- a/doc/source/tune/user-guide.rst +++ b/doc/source/tune/user-guide.rst @@ -151,17 +151,18 @@ When running a hyperparameter search, Tune can automatically and periodically sa Checkpointing assumes that the model state will be saved to disk on whichever node the Trainable is running on. 
-To use Tune's checkpointing features, you must expose a ``checkpoint`` argument in the function signature, and call ``tune.make_checkpoint_dir`` and ``tune.save_checkpoint``: +To use Tune's checkpointing features, you must expose a ``checkpoint_dir`` argument in the function signature, and call ``tune.checkpoint_dir``: .. code-block:: python + import os import time from ray import tune - def train_func(config, checkpoint=None): + def train_func(config, checkpoint_dir=None): start = 0 - if checkpoint: - with open(checkpoint) as f: + if checkpoint_dir: + with open(os.path.join(checkpoint_dir, "checkpoint")) as f: state = json.loads(f.read()) start = state["step"] + 1 @@ -169,11 +170,10 @@ To use Tune's checkpointing features, you must expose a ``checkpoint`` argument time.sleep(1) # Obtain a checkpoint directory - checkpoint_dir = tune.make_checkpoint_dir(step=step) - path = os.path.join(checkpoint_dir, "checkpoint") - with open(path, "w") as f: - f.write(json.dumps({"step": start})) - tune.save_checkpoint(path) + with tune.checkpoint_dir(step=step) as checkpoint_dir: + path = os.path.join(checkpoint_dir, "checkpoint") + with open(path, "w") as f: + f.write(json.dumps({"step": start})) tune.report(hello="world", ray="tune") diff --git a/python/ray/tune/BUILD b/python/ray/tune/BUILD index d5334546a..354da4aec 100644 --- a/python/ray/tune/BUILD +++ b/python/ray/tune/BUILD @@ -287,6 +287,14 @@ py_test( args = ["--smoke-test"] ) +py_test( + name = "test_torch_trainable", + size = "medium", + srcs = ["tests/test_torch_trainable.py"], + tags = ["exclusive", "example", "pytorch"], + deps = [":tune_lib"], +) + py_test( name = "ddp_mnist_torch", size = "small", diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py index 7fd38c142..43a65c430 100644 --- a/python/ray/tune/__init__.py +++ b/python/ray/tune/__init__.py @@ -9,7 +9,7 @@ from ray.tune.durable_trainable import DurableTrainable from ray.tune.suggest import grid_search from ray.tune.session import 
(report, get_trial_dir, get_trial_name, get_trial_id, make_checkpoint_dir, - save_checkpoint) + save_checkpoint, checkpoint_dir) from ray.tune.progress_reporter import (ProgressReporter, CLIReporter, JupyterNotebookReporter) from ray.tune.sample import (function, sample_from, uniform, choice, randint, @@ -22,5 +22,5 @@ __all__ = [ "uniform", "choice", "randint", "randn", "loguniform", "ExperimentAnalysis", "Analysis", "CLIReporter", "JupyterNotebookReporter", "ProgressReporter", "report", "get_trial_dir", "get_trial_name", - "get_trial_id", "make_checkpoint_dir", "save_checkpoint" + "get_trial_id", "make_checkpoint_dir", "save_checkpoint", "checkpoint_dir" ] diff --git a/python/ray/tune/examples/cifar10_pytorch.py b/python/ray/tune/examples/cifar10_pytorch.py index 90b39f57f..5089ef364 100644 --- a/python/ray/tune/examples/cifar10_pytorch.py +++ b/python/ray/tune/examples/cifar10_pytorch.py @@ -58,7 +58,7 @@ class Net(nn.Module): # __train_begin__ -def train_cifar(config, checkpoint=None, data_dir=None): +def train_cifar(config, checkpoint_dir=None, data_dir=None): net = Net(config["l1"], config["l2"]) device = "cpu" @@ -71,8 +71,8 @@ def train_cifar(config, checkpoint=None, data_dir=None): criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9) - if checkpoint: - print("loading checkpoint {}".format(checkpoint)) + if checkpoint_dir: + checkpoint = os.path.join(checkpoint_dir, "checkpoint") model_state, optimizer_state = torch.load(checkpoint) net.load_state_dict(model_state) optimizer.load_state_dict(optimizer_state) @@ -138,10 +138,10 @@ def train_cifar(config, checkpoint=None, data_dir=None): val_loss += loss.cpu().numpy() val_steps += 1 - checkpoint_dir = tune.make_checkpoint_dir(epoch) - path = os.path.join(checkpoint_dir, "checkpoint") - torch.save((net.state_dict(), optimizer.state_dict()), path) - tune.save_checkpoint(path) + with tune.checkpoint_dir(step=epoch) as checkpoint_dir: + path = 
os.path.join(checkpoint_dir, "checkpoint") + torch.save( + (net.state_dict(), optimizer.state_dict()), path) tune.report(loss=(val_loss / val_steps), accuracy=correct / total) print("Finished Training") @@ -213,7 +213,9 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2): best_trained_model = nn.DataParallel(best_trained_model) best_trained_model.to(device) - model_state, optimizer_state = torch.load(best_trial.checkpoint.value) + checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint") + + model_state, optimizer_state = torch.load(checkpoint_path) best_trained_model.load_state_dict(model_state) test_acc = test_accuracy(best_trained_model, device) diff --git a/python/ray/tune/examples/ddp_mnist_torch.py b/python/ray/tune/examples/ddp_mnist_torch.py index 2641bd662..edb0694f6 100644 --- a/python/ray/tune/examples/ddp_mnist_torch.py +++ b/python/ray/tune/examples/ddp_mnist_torch.py @@ -2,6 +2,7 @@ # https://github.com/pytorch/examples/blob/master/mnist/main.py import argparse import logging +import os import torch import torch.optim as optim from torch.nn.parallel import DistributedDataParallel @@ -10,21 +11,21 @@ import ray from ray import tune from ray.tune.examples.mnist_pytorch import (train, test, get_data_loaders, ConvNet) -from ray.util.sgd.torch.func_trainable import (DistributedTrainableCreator, - distributed_checkpoint) +from ray.tune.integration.torch import (DistributedTrainableCreator, + distributed_checkpoint_dir) logger = logging.getLogger(__name__) -def train_mnist(config, checkpoint=False): +def train_mnist(config, checkpoint_dir=False): use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") train_loader, test_loader = get_data_loaders() model = ConvNet().to(device) optimizer = optim.SGD(model.parameters(), lr=0.1) - if checkpoint: - with open(checkpoint) as f: + if checkpoint_dir: + with open(os.path.join(checkpoint_dir, "checkpoint"), "rb") as f: model_state, optimizer_state = torch.load(f) 
model.load_state_dict(model_state) @@ -37,7 +38,8 @@ def train_mnist(config, checkpoint=False): acc = test(model, test_loader, device) if epoch % 3 == 0: - with distributed_checkpoint(label=epoch) as path: + with distributed_checkpoint_dir(step=epoch) as checkpoint_dir: + path = os.path.join(checkpoint_dir, "checkpoint") torch.save((model.state_dict(), optimizer.state_dict()), path) tune.report(mean_accuracy=acc) diff --git a/python/ray/tune/examples/hyperband_function_example.py b/python/ray/tune/examples/hyperband_function_example.py index 6d4d3dc4e..2492229fd 100644 --- a/python/ray/tune/examples/hyperband_function_example.py +++ b/python/ray/tune/examples/hyperband_function_example.py @@ -11,10 +11,10 @@ from ray import tune from ray.tune.schedulers import HyperBandScheduler -def train(config, checkpoint=None): +def train(config, checkpoint_dir=None): step = 0 - if checkpoint: - with open(checkpoint) as f: + if checkpoint_dir: + with open(os.path.join(checkpoint_dir, "checkpoint")) as f: step = json.loads(f.read())["timestep"] for timestep in range(step, 100): @@ -22,11 +22,10 @@ def train(config, checkpoint=None): v *= config.get("height", 1) if timestep % 3 == 0: - checkpoint_dir = tune.make_checkpoint_dir(step=timestep) - path = os.path.join(checkpoint_dir, "checkpoint") - with open(path, "w") as f: - f.write(json.dumps({"timestep": timestep})) - tune.save_checkpoint(path) + with tune.checkpoint_dir(step=timestep) as checkpoint_dir: + path = os.path.join(checkpoint_dir, "checkpoint") + with open(path, "w") as f: + f.write(json.dumps({"timestep": timestep})) # Here we use `episode_reward_mean`, but you can also report other # objectives such as loss or accuracy. 
diff --git a/python/ray/tune/examples/mnist_pytorch_lightning.py b/python/ray/tune/examples/mnist_pytorch_lightning.py index 0074bf3c7..95ac2ee60 100644 --- a/python/ray/tune/examples/mnist_pytorch_lightning.py +++ b/python/ray/tune/examples/mnist_pytorch_lightning.py @@ -158,16 +158,15 @@ def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0): # __tune_checkpoint_callback_begin__ class CheckpointCallback(Callback): def on_validation_end(self, trainer, pl_module): - path = tune.make_checkpoint_dir(trainer.global_step) - trainer.save_checkpoint(os.path.join(path, "checkpoint")) - tune.save_checkpoint(path) + with tune.checkpoint_dir(step=trainer.global_step) as checkpoint_dir: + trainer.save_checkpoint(os.path.join(checkpoint_dir, "checkpoint")) # __tune_checkpoint_callback_end__ # __tune_train_checkpoint_begin__ def train_mnist_tune_checkpoint( config, - checkpoint=None, + checkpoint_dir=None, data_dir=None, num_epochs=10, num_gpus=0): @@ -179,13 +178,13 @@ def train_mnist_tune_checkpoint( progress_bar_refresh_rate=0, callbacks=[CheckpointCallback(), TuneReportCallback()]) - if checkpoint: + if checkpoint_dir: # Currently, this leads to errors: # model = LightningMNISTClassifier.load_from_checkpoint( # os.path.join(checkpoint, "checkpoint")) # Workaround: ckpt = pl_load( - os.path.join(checkpoint, "checkpoint"), + os.path.join(checkpoint_dir, "checkpoint"), map_location=lambda storage, loc: storage) model = LightningMNISTClassifier._load_model_state(ckpt, config=config) trainer.current_epoch = ckpt["epoch"] diff --git a/python/ray/tune/examples/pbt_function.py b/python/ray/tune/examples/pbt_function.py index 8e75240ce..69c8d2570 100644 --- a/python/ray/tune/examples/pbt_function.py +++ b/python/ray/tune/examples/pbt_function.py @@ -11,7 +11,7 @@ from ray import tune from ray.tune.schedulers import PopulationBasedTraining -def pbt_function(config, checkpoint=None): +def pbt_function(config, checkpoint_dir=None): """Toy PBT problem for benchmarking 
adaptive learning rate. The goal is to optimize this trainable's accuracy. The accuracy increases @@ -35,8 +35,8 @@ def pbt_function(config, checkpoint=None): lr = config["lr"] accuracy = 0.0 # end = 1000 start = 0 - if checkpoint: - with open(checkpoint) as f: + if checkpoint_dir: + with open(os.path.join(checkpoint_dir, "checkpoint")) as f: state = json.loads(f.read()) accuracy = state["acc"] start = state["step"] @@ -65,11 +65,10 @@ def pbt_function(config, checkpoint=None): accuracy = max(0, accuracy) if step % 3 == 0: - checkpoint_dir = tune.make_checkpoint_dir(step=step) - path = os.path.join(checkpoint_dir, "checkpoint") - with open(path, "w") as f: - f.write(json.dumps({"acc": accuracy, "step": start})) - tune.save_checkpoint(path) + with tune.checkpoint_dir(step=step) as checkpoint_dir: + path = os.path.join(checkpoint_dir, "checkpoint") + with open(path, "w") as f: + f.write(json.dumps({"acc": accuracy, "step": start})) tune.report( mean_accuracy=accuracy, diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index dede70789..8d81a1132 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -350,35 +350,47 @@ class FunctionRunner(Trainable): pass -def detect_checkpoint_function(train_func): - func_args = inspect.getfullargspec(train_func).args - use_checkpoint = "checkpoint" in func_args - return use_checkpoint +def detect_checkpoint_function(train_func, abort=False): + """Use checkpointing if any arg has "checkpoint_dir" and args = 2""" + argspec = inspect.getfullargspec(train_func) + func_args = argspec.args + func_kwargs = argspec.kwonlyargs + validated = len(func_args) == 2 and any("checkpoint_dir" in arg + for arg in func_args) + validated = validated or (len(func_args) == 1) and any( + "checkpoint_dir" in arg for arg in func_kwargs) + if abort and not validated: + raise ValueError( + "Provided training function must have 2 args " + "in the signature, and the latter arg must " + 
"contain `checkpoint_dir`. For example: " + "`func(config, checkpoint_dir=None)`. Got {}".format(func_args)) + return validated def wrap_function(train_func): class ImplicitFunc(FunctionRunner): - def _trainable_func(self, config, reporter, checkpoint): + def _trainable_func(self, config, reporter, checkpoint_dir): func_args = inspect.getfullargspec(train_func).args if len(func_args) > 1: # more arguments than just the config if "reporter" not in func_args and ( - "checkpoint" not in func_args): + not detect_checkpoint_function(train_func)): raise ValueError( "Unknown argument found in the Trainable function. " "Arguments other than the 'config' arg must be one " - "of ['reporter', 'checkpoint']. Found: {}".format( + "of ['reporter', 'checkpoint_dir']. Found: {}".format( func_args)) use_reporter = "reporter" in func_args - use_checkpoint = "checkpoint" in func_args + use_checkpoint = detect_checkpoint_function(train_func) if not use_checkpoint and not use_reporter: logger.warning( "Function checkpointing is disabled. This may result in " "unexpected behavior when using checkpointing features or " "certain schedulers. 
To enable, set the train function " - "arguments to be `func(config, checkpoint)`.") + "arguments to be `func(config, checkpoint_dir=None)`.") output = train_func(config) elif use_checkpoint: - output = train_func(config, checkpoint=checkpoint) + output = train_func(config, checkpoint_dir=checkpoint_dir) else: output = train_func(config, reporter) diff --git a/python/ray/tune/integration/torch.py b/python/ray/tune/integration/torch.py new file mode 100644 index 000000000..d5257931b --- /dev/null +++ b/python/ray/tune/integration/torch.py @@ -0,0 +1,290 @@ +# Original Code here: +# https://github.com/pytorch/examples/blob/master/mnist/main.py +from contextlib import contextmanager +import os +import logging +import shutil +import tempfile +import torch +from datetime import timedelta + +import ray +from ray import tune +from ray.tune.result import RESULT_DUPLICATE +from ray.tune.logger import NoopLogger +from ray.tune.function_runner import (wrap_function, + detect_checkpoint_function) +from ray.tune.resources import Resources +from ray.tune.trainable import TrainableUtil +from ray.util.sgd.torch.utils import setup_process_group, setup_address +from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S + +logger = logging.getLogger(__name__) + +_distributed_enabled = False + + +def is_distributed_trainable(): + """Returns True if executing within a DistributedTrainable.""" + return _distributed_enabled + + +def enable_distributed_trainable(): + global _distributed_enabled + _distributed_enabled = True + + +def logger_creator(log_config, logdir, rank): + worker_dir = os.path.join(logdir, "worker_{}".format(rank)) + os.makedirs(worker_dir, exist_ok=True) + return NoopLogger(log_config, worker_dir) + + +class _TorchTrainable(tune.Trainable): + """Base class for distributed training on Tune. + + A wrapper class is needed to actually create a working + version of this trainable. 
+ """ + _function = None + _num_workers = None + _use_gpu = None + _num_cpus_per_worker = None + + __slots__ = ["workers", "_finished"] + + @classmethod + def default_process_group_parameters(self): + return dict(timeout=timedelta(seconds=NCCL_TIMEOUT_S), backend="gloo") + + @classmethod + def get_remote_worker_options(self): + num_gpus = 1 if self._use_gpu else 0 + num_cpus = int(self._num_cpus_per_worker or 1) + return dict(num_cpus=num_cpus, num_gpus=num_gpus) + + def setup(self, config): + self._finished = False + num_workers = self._num_workers + logdir = self.logdir + assert self._function + + func_trainable = wrap_function(self.__class__._function) + + remote_trainable = ray.remote(func_trainable) + remote_trainable = remote_trainable.options( + **self.get_remote_worker_options()) + + address = setup_address() + self.workers = [ + remote_trainable.remote( + config=config, + logger_creator=lambda cfg: logger_creator(cfg, logdir, rank)) + for rank in range(num_workers) + ] + + pgroup_params = self.default_process_group_parameters() + from functools import partial + setup_on_worker = partial( + setup_process_group, + url=address, + world_size=num_workers, + **pgroup_params) + ray.get([ + w.execute.remote(lambda _: setup_on_worker(world_rank=rank)) + for rank, w in enumerate(self.workers) + ]) + + ray.get([ + w.execute.remote(lambda _: enable_distributed_trainable()) + for rank, w in enumerate(self.workers) + ]) + + def step(self): + if self._finished: + raise RuntimeError("Training has already finished.") + result = ray.get([w.step.remote() for w in self.workers])[0] + if RESULT_DUPLICATE in result: + self._finished = True + return result + + def save_checkpoint(self, checkpoint_dir): + # TODO: optimize if colocated + save_obj = ray.get(self.workers[0].save_to_object.remote()) + checkpoint_path = TrainableUtil.create_from_pickle( + save_obj, checkpoint_dir) + return checkpoint_path + + def load_checkpoint(self, checkpoint_dir): + checkpoint_obj = 
TrainableUtil.checkpoint_to_object(checkpoint_dir) + return ray.get([ + w.restore_from_object.remote(checkpoint_obj) for w in self.workers]) + + def stop(self): + ray.get([worker.stop.remote() for worker in self.workers]) + + +def DistributedTrainableCreator(func, + use_gpu=False, + num_workers=1, + num_cpus_per_worker=1, + backend="gloo", + timeout_s=NCCL_TIMEOUT_S): + """Creates a class that executes distributed training. + + Similar to running `torch.distributed.launch`. + + Note that you typically should not instantiate the object + created. + + Args: + func (callable): This function is a Tune trainable function. + This function must have 2 args in the signature, and the + latter arg must contain `checkpoint_dir`. For example: + `func(config, checkpoint_dir=None)`. + use_gpu (bool): Sets resource allocation for workers to 1 GPU + if true. Also automatically sets CUDA_VISIBLE_DEVICES + for each training worker. + num_workers (int): Number of training workers to include in + world. + num_cpus_per_worker (int): Number of CPU resources to reserve + per training worker. + backend (str): One of "gloo", "nccl". + timeout_s (float): Seconds before the torch process group + times out. Useful when machines are unreliable. Defaults + to 60 seconds. + + Returns: + A trainable class object that can be passed to Tune. Resources + are automatically set within the object, so users do + not need to set `resources_per_trial`. + + Example: + + .. 
code-block:: + + trainable_cls = DistributedTrainableCreator( + train_func, num_workers=2) + analysis = tune.run(trainable_cls) + """ + detect_checkpoint_function(func, abort=True) + + class WrappedDistributedTorchTrainable(_TorchTrainable): + _function = func + _num_workers = num_workers + _use_gpu = use_gpu + _num_cpus_per_worker = num_cpus_per_worker + + @classmethod + def default_process_group_parameters(self): + return dict(timeout=timedelta(seconds=timeout_s), backend=backend) + + @classmethod + def default_resource_request(cls, config): + num_workers_ = int(config.get("num_workers", num_workers)) + num_cpus = int( + config.get("num_cpus_per_worker", num_cpus_per_worker)) + use_gpu_ = config.get("use_gpu", use_gpu) + + return Resources( + cpu=0, + gpu=0, + extra_cpu=num_cpus * num_workers_, + extra_gpu=num_workers_ if use_gpu_ else 0) + + return WrappedDistributedTorchTrainable + + +@contextmanager +def distributed_checkpoint_dir(step, disable=False): + """ContextManager for creating a distributed checkpoint. + + Only checkpoints a file on the "main" training actor, avoiding + redundant work. + + Args: + step (int): Used to label the checkpoint + disable (bool): Disable for prototyping. + + Yields: + path (str): A path to a directory. This path will be used + again when invoking the training_function. + Example: + + .. code-block:: + + def train_func(config, checkpoint_dir): + if checkpoint_dir: + path = os.path.join(checkpoint_dir, "checkpoint") + model_state_dict = torch.load(path) + + if epoch % 3 == 0: + with distributed_checkpoint_dir(step=epoch) as checkpoint_dir: + path = os.path.join(checkpoint_dir, "checkpoint") + torch.save(model.state_dict(), path) + """ + + if torch.distributed.get_rank() == 0 and not disable: + with tune.checkpoint_dir(step=step) as checkpoint_dir: + yield checkpoint_dir + else: + path = tempfile.mkdtemp() + yield path + shutil.rmtree(path) + + +def _train_check_global(config, checkpoint_dir=None): + """For testing only. 
Putting this here because Ray has problems + serializing within the test file.""" + assert is_distributed_trainable() + import time + time.sleep(0.1) + tune.report(is_distributed=True) + + +def _train_simple(config, checkpoint_dir=None): + """For testing only. Putting this here because Ray has problems + serializing within the test file.""" + import torch.nn as nn + from torch.nn.parallel import DistributedDataParallel + import torch.optim as optim + # N is batch size; D_in is input dimension; + # H is hidden dimension; D_out is output dimension. + N, D_in, H, D_out = 8, 5, 5, 5 + + # Create random Tensors to hold inputs and outputs + x = torch.randn(N, D_in) + y = torch.randn(N, D_out) + loss_fn = nn.MSELoss() + + # Use the nn package to define our model and loss function. + model = torch.nn.Sequential( + torch.nn.Linear(D_in, H), + torch.nn.ReLU(), + torch.nn.Linear(H, D_out), + ) + optimizer = optim.SGD(model.parameters(), lr=0.1) + + if checkpoint_dir: + with open(os.path.join(checkpoint_dir, "checkpoint"), "rb") as f: + model_state, optimizer_state = torch.load(f) + + model.load_state_dict(model_state) + optimizer.load_state_dict(optimizer_state) + + model = DistributedDataParallel(model) + + for epoch in range(config.get("epochs", 10)): + optimizer.zero_grad() + output = model(x) + loss = loss_fn(output, y) + loss.backward() + optimizer.step() + + if epoch % 3 == 0: + if config.get("enable_checkpoint", True): + with distributed_checkpoint_dir(step=epoch) as checkpoint_dir: + path = os.path.join(checkpoint_dir, "checkpoint") + torch.save((model.state_dict(), optimizer.state_dict()), + path) + tune.report(mean_loss=loss.item()) diff --git a/python/ray/tune/session.py b/python/ray/tune/session.py index 09baf7c08..db1ff426d 100644 --- a/python/ray/tune/session.py +++ b/python/ray/tune/session.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager import os import logging @@ -75,48 +76,34 @@ def report(**kwargs): def make_checkpoint_dir(step=None): """Gets the next 
checkpoint dir. - .. code-block:: python - - import time - from ray import tune - - def func(config, checkpoint=None): - start = 0 - if checkpoint: - with open(checkpoint) as f: - state = json.loads(f.read()) - start = state["step"] + 1 - - for iter in range(start, 100): - time.sleep(1) - - checkpoint_dir = tune.make_checkpoint_dir(step=step) - path = os.path.join(checkpoint_dir, "checkpoint") - with open(path, "w") as f: - f.write(json.dumps({"step": start})) - tune.save_checkpoint(path) - - tune.report(hello="world", ray="tune") - - .. warning:: Do not call this function within the Trainable Class API. - - Args: - step (int): Current training iteration - used for setting - an index to uniquely identify the checkpoint. - .. versionadded:: 0.8.6 + .. deprecated:: 0.8.7 + Use tune.checkpoint_dir instead. """ - _session = get_session() - if _session: - return _session.make_checkpoint_dir(step=step) - else: - return os.path.abspath("./") + raise DeprecationWarning( + "Deprecated method. Use `tune.checkpoint_dir` instead.") def save_checkpoint(checkpoint): """Register the given checkpoint. + .. versionadded:: 0.8.6 + + .. deprecated:: 0.8.7 + Use tune.checkpoint_dir instead. + """ + raise DeprecationWarning( + "Deprecated method. Use `tune.checkpoint_dir` instead.") + + +@contextmanager +def checkpoint_dir(step=None): + """Returns a checkpoint dir inside a context. + + Store any files related to restoring state within the + provided checkpoint dir. + .. 
code-block:: python import os @@ -124,10 +111,10 @@ def save_checkpoint(checkpoint): import time from ray import tune - def func(config, checkpoint=None): + def func(config, checkpoint_dir=None): start = 0 - if checkpoint: - with open(checkpoint) as f: + if checkpoint_dir: + with open(os.path.join(checkpoint_dir, "checkpoint")) as f: state = json.loads(f.read()) accuracy = state["acc"] start = state["step"] + 1 @@ -135,27 +122,29 @@ def save_checkpoint(checkpoint): for iter in range(start, 10): time.sleep(1) - checkpoint_dir = tune.make_checkpoint_dir(step=iter) - path = os.path.join(checkpoint_dir, "checkpoint") - with open(path, "w") as f: - f.write(json.dumps({"step": start})) - tune.save_checkpoint(path) + with tune.checkpoint_dir(step=iter) as checkpoint_dir: + path = os.path.join(checkpoint_dir, "checkpoint") + with open(path, "w") as f: + f.write(json.dumps({"step": iter})) tune.report(hello="world", ray="tune") - analysis = tune.run(run_me) + Yields: + checkpoint_dir (str): Directory for checkpointing. - .. warning:: Do not call this function within the Trainable Class API. - - Args: - **kwargs: Any key value pair to be logged by Tune. Any of these - metrics can be used for early stopping or optimization. - - .. versionadded:: 0.8.6 + .. 
versionadded:: 0.8.7 """ _session = get_session() + if _session: - return _session.save_checkpoint(checkpoint) + _checkpoint_dir = _session.make_checkpoint_dir(step=step) + else: + _checkpoint_dir = os.path.abspath("./") + + yield _checkpoint_dir + + if _session: + _session.save_checkpoint(_checkpoint_dir) def get_trial_dir(): diff --git a/python/ray/tune/tests/test_function_api.py b/python/ray/tune/tests/test_function_api.py index 8c390799f..0ff5a7cf9 100644 --- a/python/ray/tune/tests/test_function_api.py +++ b/python/ray/tune/tests/test_function_api.py @@ -19,7 +19,7 @@ class FunctionApiTest(unittest.TestCase): _register_all() # re-register the evicted objects def testFunctionNoCheckpointing(self): - def train(config, checkpoint=None): + def train(config, checkpoint_dir=None): for i in range(10): tune.report(test=i) @@ -40,14 +40,13 @@ class FunctionApiTest(unittest.TestCase): def testFunctionRecurringSave(self): """This tests that save and restore are commutative.""" - def train(config, checkpoint=None): + def train(config, checkpoint_dir=None): for step in range(10): if step % 3 == 0: - checkpoint_dir = tune.make_checkpoint_dir(step=step) - path = os.path.join(checkpoint_dir, "checkpoint") - with open(path, "w") as f: - f.write(json.dumps({"step": step})) - tune.save_checkpoint(path) + with tune.checkpoint_dir(step=step) as checkpoint_dir: + path = os.path.join(checkpoint_dir, "checkpoint") + with open(path, "w") as f: + f.write(json.dumps({"step": step})) tune.report(test=step) wrapped = wrap_function(train) @@ -65,49 +64,58 @@ class FunctionApiTest(unittest.TestCase): new_trainable2.stop() def testCheckpointFunctionAtEnd(self): - def train(config, checkpoint=False): + def train(config, checkpoint_dir=False): for i in range(10): tune.report(test=i) - checkpoint_dir = tune.make_checkpoint_dir(step=10) - checkpoint_path = os.path.join(checkpoint_dir, "hello") - with open(checkpoint_path, "w") as f: - f.write("hello") - tune.save_checkpoint(checkpoint_path) - - 
[trial] = tune.run(train).trials - assert "hello" in trial.checkpoint.value - - def testVariousCheckpointFunctionAtEnd(self): - def train(config, checkpoint=False): - for i in range(10): - checkpoint_dir = tune.make_checkpoint_dir() - checkpoint_path = os.path.join(checkpoint_dir, "hello") + with tune.checkpoint_dir(step=10) as checkpoint_dir: + checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log") with open(checkpoint_path, "w") as f: f.write("hello") - tune.save_checkpoint(checkpoint_path) + + [trial] = tune.run(train).trials + assert os.path.exists(os.path.join(trial.checkpoint.value, "ckpt.log")) + + def testCheckpointFunctionAtEndContext(self): + def train(config, checkpoint_dir=False): + for i in range(10): tune.report(test=i) - checkpoint_dir = tune.make_checkpoint_dir() - checkpoint_path = os.path.join(checkpoint_dir, "goodbye") - with open(checkpoint_path, "w") as f: - f.write("goodbye") - tune.save_checkpoint(checkpoint_path) + with tune.checkpoint_dir(step=10) as checkpoint_dir: + checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log") + with open(checkpoint_path, "w") as f: + f.write("hello") + + [trial] = tune.run(train).trials + assert os.path.exists(os.path.join(trial.checkpoint.value, "ckpt.log")) + + def testVariousCheckpointFunctionAtEnd(self): + def train(config, checkpoint_dir=False): + for i in range(10): + with tune.checkpoint_dir() as checkpoint_dir: + checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log") + with open(checkpoint_path, "w") as f: + f.write("hello") + tune.report(test=i) + with tune.checkpoint_dir() as checkpoint_dir: + checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log2") + with open(checkpoint_path, "w") as f: + f.write("goodbye") [trial] = tune.run(train, keep_checkpoints_num=3).trials - assert "goodbye" in trial.checkpoint.value + assert os.path.exists( + os.path.join(trial.checkpoint.value, "ckpt.log2")) def testReuseCheckpoint(self): - def train(config, checkpoint=False): + def train(config, 
checkpoint_dir=None): itr = 0 - if checkpoint: - with open(checkpoint, "r") as f: + if checkpoint_dir: + with open(os.path.join(checkpoint_dir, "ckpt.log"), "r") as f: itr = int(f.read()) + 1 for i in range(itr, config["max_iter"]): - checkpoint_dir = tune.make_checkpoint_dir(step=i) - checkpoint_path = os.path.join(checkpoint_dir, "goodbye") - with open(checkpoint_path, "w") as f: - f.write(str(i)) - tune.save_checkpoint(checkpoint_path) + with tune.checkpoint_dir(step=i) as checkpoint_dir: + checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log") + with open(checkpoint_path, "w") as f: + f.write(str(i)) tune.report(test=i, training_iteration=i) [trial] = tune.run( @@ -117,51 +125,49 @@ class FunctionApiTest(unittest.TestCase): }, ).trials last_ckpt = trial.checkpoint.value - assert "goodbye" in last_ckpt + assert os.path.exists(os.path.join(trial.checkpoint.value, "ckpt.log")) analysis = tune.run(train, config={"max_iter": 10}, restore=last_ckpt) trial_dfs = list(analysis.trial_dataframes.values()) assert len(trial_dfs[0]["training_iteration"]) == 5 def testRetry(self): - def train(config, checkpoint=None): - restored = bool(checkpoint) + def train(config, checkpoint_dir=None): + restored = bool(checkpoint_dir) itr = 0 - if checkpoint: - with open(checkpoint, "r") as f: + if checkpoint_dir: + with open(os.path.join(checkpoint_dir, "ckpt.log"), "r") as f: itr = int(f.read()) + 1 for i in range(itr, 10): if i == 5 and not restored: raise Exception("try to fail me") - checkpoint_dir = tune.make_checkpoint_dir(step=i) - checkpoint_path = os.path.join(checkpoint_dir, "goodbye") - with open(checkpoint_path, "w") as f: - f.write(str(i)) - tune.save_checkpoint(checkpoint_path) + with tune.checkpoint_dir(step=i) as checkpoint_dir: + checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log") + with open(checkpoint_path, "w") as f: + f.write(str(i)) tune.report(test=i, training_iteration=i) analysis = tune.run(train, max_failures=3) last_ckpt = 
analysis.trials[0].checkpoint.value - assert "goodbye" in last_ckpt + assert os.path.exists(os.path.join(last_ckpt, "ckpt.log")) trial_dfs = list(analysis.trial_dataframes.values()) assert len(trial_dfs[0]["training_iteration"]) == 10 def testBlankCheckpoint(self): - def train(config, checkpoint=None): - restored = bool(checkpoint) + def train(config, checkpoint_dir=None): + restored = bool(checkpoint_dir) itr = 0 - if checkpoint: - with open(checkpoint, "r") as f: + if checkpoint_dir: + with open(os.path.join(checkpoint_dir, "ckpt.log"), "r") as f: itr = int(f.read()) + 1 for i in range(itr, 10): if i == 5 and not restored: raise Exception("try to fail me") - checkpoint_dir = tune.make_checkpoint_dir() - checkpoint_path = os.path.join(checkpoint_dir, "goodbye") - with open(checkpoint_path, "w") as f: - f.write(str(i)) - tune.save_checkpoint(checkpoint_path) + with tune.checkpoint_dir() as checkpoint_dir: + checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log") + with open(checkpoint_path, "w") as f: + f.write(str(i)) tune.report(test=i, training_iteration=i) analysis = tune.run(train, max_failures=3) diff --git a/python/ray/util/sgd/tests/test_torch_trainable.py b/python/ray/tune/tests/test_torch_trainable.py similarity index 71% rename from python/ray/util/sgd/tests/test_torch_trainable.py rename to python/ray/tune/tests/test_torch_trainable.py index 570827284..0ab43e54a 100644 --- a/python/ray/util/sgd/tests/test_torch_trainable.py +++ b/python/ray/tune/tests/test_torch_trainable.py @@ -6,8 +6,9 @@ import torch.distributed as dist import ray from ray import tune -from ray.util.sgd.torch.func_trainable import ( - DistributedTrainableCreator, distributed_checkpoint, _train_simple) +from ray.tune.integration.torch import (DistributedTrainableCreator, + distributed_checkpoint_dir, + _train_simple, _train_check_global) @pytest.fixture @@ -47,12 +48,29 @@ def test_step_after_completion(ray_start_2_cpus): # noqa: F811 trainer.train() +def 
test_validation(ray_start_2_cpus): # noqa: F811 + def bad_func(a, b, c): + return 1 + + with pytest.raises(ValueError): + DistributedTrainableCreator(bad_func, num_workers=2) + + +def test_set_global(ray_start_2_cpus): # noqa: F811 + trainable_cls = DistributedTrainableCreator( + _train_check_global, num_workers=2) + trainable = trainable_cls() + result = trainable.train() + assert result["is_distributed"] + + def test_save_checkpoint(ray_start_2_cpus): # noqa: F811 trainable_cls = DistributedTrainableCreator(_train_simple, num_workers=2) trainer = trainable_cls(config={"epochs": 1}) trainer.train() - path = trainer.save() - model_state_dict, opt_state_dict = torch.load(path) + checkpoint_dir = trainer.save() + model_state_dict, opt_state_dict = torch.load( + os.path.join(checkpoint_dir, "checkpoint")) trainer.stop() @@ -72,11 +90,11 @@ def test_simple_tune(ray_start_4_cpus, enabled_checkpoint): def test_checkpoint(ray_start_2_cpus, rank): # noqa: F811 with patch("torch.distributed.get_rank") as rank_method: rank_method.return_value = rank - with distributed_checkpoint(label="test") as path: + with distributed_checkpoint_dir(step="test") as path: if rank == 0: assert path - else: - assert path == os.devnull + if rank != 0: + assert not os.path.exists(path) if __name__ == "__main__": diff --git a/python/ray/util/sgd/BUILD b/python/ray/util/sgd/BUILD index 15ece2006..784016de9 100644 --- a/python/ray/util/sgd/BUILD +++ b/python/ray/util/sgd/BUILD @@ -26,14 +26,6 @@ py_test( deps = [":sgd_lib"], ) -py_test( - name = "test_torch_trainable", - size = "small", - srcs = ["tests/test_torch_trainable.py"], - tags = ["exclusive", "pytorch"], - deps = [":sgd_lib"], -) - # -------------------------------------------------------------------- # Tests from the python/ray/util/sgd/tf/examples directory. # Please keep these sorted alphabetically. 
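Review note on the `test_checkpoint` change above: the test asserts that rank 0 receives a usable directory, while any other rank receives a temporary directory that no longer exists once the context exits. A dependency-free sketch of that rank-gated pattern may help when reading the test. The names `rank_gated_checkpoint_dir`, `write_checkpoint`, and the `root` argument are hypothetical stand-ins, not Tune APIs, and the `try`/`finally` cleanup is an addition that is not part of this diff.

```python
import os
import shutil
import tempfile
from contextlib import contextmanager


@contextmanager
def rank_gated_checkpoint_dir(rank, root):
    """Yield a persistent directory on rank 0, a throwaway one elsewhere.

    Hypothetical stand-in for the rank check in the diff; `rank` is passed
    explicitly here instead of being read from torch.distributed.
    """
    if rank == 0:
        path = os.path.join(root, "checkpoint_rank0")
        os.makedirs(path, exist_ok=True)
        yield path
    else:
        path = tempfile.mkdtemp()
        try:
            yield path
        finally:
            # Non-zero ranks discard whatever they wrote.
            shutil.rmtree(path, ignore_errors=True)


def write_checkpoint(rank, root):
    """Write a dummy checkpoint file and return the directory that was used."""
    with rank_gated_checkpoint_dir(rank, root) as ckpt_dir:
        with open(os.path.join(ckpt_dir, "checkpoint"), "w") as f:
            f.write("state")
    return ckpt_dir
```

Every rank can run the same save code unconditionally; only the directory handed to rank 0 survives, which is what keeps the training function free of explicit rank checks.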
diff --git a/python/ray/util/sgd/torch/__init__.py b/python/ray/util/sgd/torch/__init__.py index d37c27f6b..45e0ff00a 100644 --- a/python/ray/util/sgd/torch/__init__.py +++ b/python/ray/util/sgd/torch/__init__.py @@ -12,13 +12,8 @@ try: BaseTorchTrainable) from ray.util.sgd.torch.training_operator import TrainingOperator - from ray.util.sgd.torch.func_trainable import (DistributedTrainableCreator, - distributed_checkpoint) - __all__ = [ - "TorchTrainer", "BaseTorchTrainable", "TrainingOperator", - "distributed_checkpoint", "DistributedTrainableCreator" - ] + __all__ = ["TorchTrainer", "BaseTorchTrainable", "TrainingOperator"] except ImportError as e: logger.warning(e) logger.warning("PyTorch not found. TorchTrainer will not be available") diff --git a/python/ray/util/sgd/torch/func_trainable.py b/python/ray/util/sgd/torch/func_trainable.py index 3dc4ac443..e69de29bb 100644 --- a/python/ray/util/sgd/torch/func_trainable.py +++ b/python/ray/util/sgd/torch/func_trainable.py @@ -1,235 +0,0 @@ -# Original Code here: -# https://github.com/pytorch/examples/blob/master/mnist/main.py -import os -import logging -import torch -from datetime import timedelta - -import ray -from ray import tune -from ray.tune.result import RESULT_DUPLICATE -from ray.tune.logger import NoopLogger -from ray.tune.function_runner import wrap_function -from ray.tune.resources import Resources -from ray.tune.trainable import TrainableUtil -from ray.util.sgd.torch.utils import setup_process_group -from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S -from ray.util.sgd.torch.utils import setup_address - -logger = logging.getLogger(__name__) - - -def logger_creator(log_config, logdir, rank): - worker_dir = os.path.join(logdir, "worker_{}".format(rank)) - os.makedirs(worker_dir, exist_ok=True) - return NoopLogger(log_config, worker_dir) - - -class _TorchTrainable(tune.Trainable): - """Base class for distributed training on Tune. 
- - A wrapper class is needed to actually create a working - version of this trainable. - """ - _function = None - _num_workers = None - _use_gpu = None - _num_cpus_per_worker = None - - __slots__ = ["workers", "_finished"] - - @classmethod - def default_process_group_parameters(self): - return dict(timeout=timedelta(NCCL_TIMEOUT_S), backend="gloo") - - @classmethod - def get_remote_worker_options(self): - num_gpus = 1 if self._use_gpu else 0 - num_cpus = int(self._num_cpus_per_worker or 1) - return dict(num_cpus=num_cpus, num_gpus=num_gpus) - - def setup(self, config): - self._finished = False - num_workers = self._num_workers - logdir = self.logdir - assert self._function - - func_trainable = wrap_function(self.__class__._function) - - remote_trainable = ray.remote(func_trainable) - remote_trainable = remote_trainable.options( - **self.get_remote_worker_options()) - - address = setup_address() - self.workers = [ - remote_trainable.remote( - config=config, - logger_creator=lambda cfg: logger_creator(cfg, logdir, rank)) - for rank in range(num_workers) - ] - - pgroup_params = self.default_process_group_parameters() - from functools import partial - setup_on_worker = partial( - setup_process_group, - url=address, - world_size=num_workers, - **pgroup_params) - ray.get([ - w.execute.remote(lambda _: setup_on_worker(world_rank=rank)) - for rank, w in enumerate(self.workers) - ]) - - def step(self): - if self._finished: - raise RuntimeError("Training has already finished.") - result = ray.get([w.step.remote() for w in self.workers])[0] - if RESULT_DUPLICATE in result: - self._finished = True - return result - - def save_checkpoint(self, checkpoint_dir): - # TODO: optimize if colocated - save_obj = ray.get(self.workers[0].save_to_object.remote()) - checkpoint_path = TrainableUtil.create_from_pickle( - save_obj, checkpoint_dir) - return checkpoint_path - - def load_checkpoint(self, checkpoint_dir): - checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir) - 
return ray.get( - w.restore_from_object.remote(checkpoint_obj) for w in self.workers) - - def stop(self): - ray.get([worker.stop.remote() for worker in self.workers]) - - -def DistributedTrainableCreator(func, - use_gpu=False, - num_workers=1, - num_cpus_per_worker=1, - backend="gloo", - timeout_s=NCCL_TIMEOUT_S): - """Creates a class that executes distributed training. - - Note that you typically should not instantiate the object - created. - - Example: - - .. code-block:: - - trainable_cls = DistributedTrainableCreator( - train_func, num_workers=2) - analysis = tune.run(trainable_cls) - """ - - class WrappedDistributedTorchTrainable(_TorchTrainable): - _function = func - _num_workers = num_workers - _use_gpu = use_gpu - _num_cpus_per_worker = num_cpus_per_worker - - @classmethod - def default_process_group_parameters(self): - return dict(timeout=timedelta(timeout_s), backend=backend) - - @classmethod - def default_resource_request(cls, config): - num_workers_ = int(config.get("num_workers", num_workers)) - num_cpus = int( - config.get("num_cpus_per_worker", num_cpus_per_worker)) - use_gpu_ = config.get("use_gpu", use_gpu) - - return Resources( - cpu=0, - gpu=0, - extra_cpu=num_cpus * num_workers_, - extra_gpu=num_workers_ if use_gpu_ else 0) - - return WrappedDistributedTorchTrainable - - -class distributed_checkpoint: - """ContextManager for creating a distributed checkpoint. - - Only checkpoints a file on the "main" training actor, avoiding - redundant work. - - Args: - label (int | str): Used to label the checkpoint - disable (bool): Disable for prototyping. - - Example: - - .. 
code-block:: - - if epoch % 3 == 0: - with distributed_checkpoint(label=epoch) as path: - torch.save(model.state_dict(), path) - """ - - def __init__(self, label, disable=False): - self.label = label - self.file = None - self.disable = disable - - def __enter__(self): - if torch.distributed.get_rank() == 0 and not self.disable: - checkpoint_dir = tune.make_checkpoint_dir(step=self.label) - path = os.path.join(checkpoint_dir, "checkpoint") - else: - path = os.devnull - self.file = path - return path - - def __exit__(self, type, value, traceback): - if torch.distributed.get_rank() == 0 and not self.disable: - tune.save_checkpoint(self.file) - - -def _train_simple(config, checkpoint=False): - """For testing only. Putting this here because Ray has problems - serializing within the test file.""" - import torch.nn as nn - from torch.nn.parallel import DistributedDataParallel - import torch.optim as optim - # N is batch size; D_in is input dimension; - # H is hidden dimension; D_out is output dimension. - N, D_in, H, D_out = 8, 5, 5, 5 - - # Create random Tensors to hold inputs and outputs - x = torch.randn(N, D_in) - y = torch.randn(N, D_out) - loss_fn = nn.MSELoss() - - # Use the nn package to define our model and loss function. 
- model = torch.nn.Sequential( - torch.nn.Linear(D_in, H), - torch.nn.ReLU(), - torch.nn.Linear(H, D_out), - ) - optimizer = optim.SGD(model.parameters(), lr=0.1) - - if checkpoint: - with open(checkpoint) as f: - model_state, optimizer_state = torch.load(f) - - model.load_state_dict(model_state) - optimizer.load_state_dict(optimizer_state) - - model = DistributedDataParallel(model) - - for epoch in range(config.get("epochs", 10)): - optimizer.zero_grad() - output = model(x) - loss = loss_fn(output, y) - loss.backward() - optimizer.step() - - if epoch % 3 == 0: - if config.get("enable_checkpoint", True): - with distributed_checkpoint(label=epoch) as path: - torch.save((model.state_dict(), optimizer.state_dict()), - path) - tune.report(mean_loss=loss.item()) diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index 019e104cd..ac0eb2cf7 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -140,6 +140,8 @@ class TorchTrainer: Defaults to True. wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel over each model. If False, you are expected to call it yourself. + timeout_s (float): Seconds before the torch process group + times out. Useful when machines are unreliable. add_dist_sampler (bool): Whether to automatically add a DistributedSampler to all created dataloaders. Only applicable if num_workers > 1.
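Review note on `timeout_s`: the parameter is documented in seconds, while `datetime.timedelta` interprets its first positional argument as days. A small sketch of the difference, using a hypothetical value of 60 for illustration only:

```python
from datetime import timedelta

timeout_s = 60.0  # hypothetical value, chosen only for illustration

# The first positional argument of timedelta is `days`.
as_days = timedelta(timeout_s)

# The keyword form expresses the intended unit explicitly.
as_seconds = timedelta(seconds=timeout_s)

print(as_days.total_seconds())     # 5184000.0 (60 days)
print(as_seconds.total_seconds())  # 60.0
```

Passing the value with the `seconds=` keyword keeps the process-group timeout consistent with the documented unit.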