diff --git a/doc/source/raysgd/raysgd_pytorch.rst b/doc/source/raysgd/raysgd_pytorch.rst index 9c7c48864..e354b7c6b 100644 --- a/doc/source/raysgd/raysgd_pytorch.rst +++ b/doc/source/raysgd/raysgd_pytorch.rst @@ -333,6 +333,8 @@ and ``trainer.load``, which wraps the relevant ``torch.save`` and ``torch.load`` checkpoint_path = os.path.join(tempfile.mkdtemp(), "checkpoint") trainer_1.save(checkpoint_path) + # You can only have 1 trainer alive at a time + trainer_1.shutdown() trainer_2 = TorchTrainer( model_creator=model_creator, @@ -340,7 +342,7 @@ and ``trainer.load``, which wraps the relevant ``torch.save`` and ``torch.load`` optimizer_creator=optimizer_creator, loss_creator=nn.MSELoss, num_workers=num_workers) - trainer_2.restore(checkpoint_path) + trainer_2.load(checkpoint_path) Retrieving the model @@ -442,19 +444,12 @@ During each ``train`` method, each parallel worker iterates through the iterable 5. If there are no available resources, the Trainer will apply an exponential backoff before retrying to create workers. 6. If there are available resources and the Trainer has fewer workers than initially specified, then it will scale up its worker pool until it reaches the initially specified ``num_workers``. -Note that we assume the Trainer itself is not on a pre-emptible node. It is currently not possible to recover from a Trainer node failure. - -Users can set ``checkpoint="auto"`` to always checkpoint the current model before executing a pass over the training iterable. - -.. code-block:: python - - trainer.train(max_retries=N, checkpoint="auto") - +Note that we assume the Trainer itself is not on a pre-emptible node. To allow the entire Trainer to recover from failure, you must use Tune to execute the training. Advanced: Hyperparameter Tuning ------------------------------- -``TorchTrainer`` naturally integrates with Tune via the ``TorchTrainable`` interface. The same arguments to ``TorchTrainer`` should be passed into the ``tune.run(config=...)`` as shown below. +``TorchTrainer`` naturally integrates with Tune via the ``BaseTorchTrainable`` interface. Without changing any arguments, you can call ``TorchTrainer.as_trainable(model_creator...)`` to create a Tune-compatible class. See the documentation (:ref:`BaseTorchTrainable-doc`). .. literalinclude:: ../../../python/ray/util/sgd/torch/examples/tune_example.py :language: python diff --git a/doc/source/raysgd/raysgd_ref.rst b/doc/source/raysgd/raysgd_ref.rst index 8babd4d8e..1bc532046 100644 --- a/doc/source/raysgd/raysgd_ref.rst +++ b/doc/source/raysgd/raysgd_ref.rst @@ -1,5 +1,5 @@ -Package Reference -================= +RaySGD API Documentation +======================== .. _ref-torch-trainer: @@ -19,12 +19,14 @@ PyTorch TrainingOperator .. autoclass:: ray.util.sgd.torch.TrainingOperator :members: +.. _BaseTorchTrainable-doc: -TorchTrainable --------------- +BaseTorchTrainable +------------------ -.. autoclass:: ray.util.sgd.torch.TorchTrainable +.. autoclass:: ray.util.sgd.torch.BaseTorchTrainable :members: + :private-members: TFTrainer --------- diff --git a/python/ray/util/sgd/tests/test_torch.py b/python/ray/util/sgd/tests/test_torch.py index d113849a1..1f3a3e06e 100644 --- a/python/ray/util/sgd/tests/test_torch.py +++ b/python/ray/util/sgd/tests/test_torch.py @@ -1,7 +1,6 @@ -import os -import tempfile from unittest.mock import patch import numpy as np +import os import pytest import time import torch @@ -10,7 +9,7 @@ import torch.distributed as dist import ray from ray import tune -from ray.util.sgd.torch import TorchTrainer, TorchTrainable +from ray.util.sgd.torch import TorchTrainer from ray.util.sgd.torch.training_operator import (_TestingOperator, _TestMetricsOperator) from ray.util.sgd.torch.constants import SCHEDULER_STEP @@ -27,6 +26,9 @@ def ray_start_2_cpus(): yield address_info # The code after the yield will run as teardown code. ray.shutdown() + # Ensure that tests don't ALL fail + if dist.is_initialized(): + dist.destroy_process_group() def test_single_step(ray_start_2_cpus): # noqa: F811 @@ -44,6 +46,19 @@ def test_single_step(ray_start_2_cpus): # noqa: F811 trainer.shutdown() +def test_dead_trainer(ray_start_2_cpus): # noqa: F811 + trainer = TorchTrainer( + model_creator=model_creator, + data_creator=data_creator, + optimizer_creator=optimizer_creator, + loss_creator=lambda config: nn.MSELoss(), + num_workers=2) + trainer.train(num_steps=1) + trainer.shutdown() + with pytest.raises(RuntimeError): + trainer.train() + + @pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) def test_train(ray_start_2_cpus, num_workers): # noqa: F811 trainer = TorchTrainer( @@ -53,12 +68,12 @@ def test_train(ray_start_2_cpus, num_workers): # noqa: F811 loss_creator=lambda config: nn.MSELoss(), num_workers=num_workers) for i in range(3): - train_loss1 = trainer.train()["mean_train_loss"] - validation_loss1 = trainer.validate()["mean_val_loss"] + train_loss1 = trainer.train()["train_loss"] + validation_loss1 = trainer.validate()["val_loss"] for i in range(3): - train_loss2 = trainer.train()["mean_train_loss"] - validation_loss2 = trainer.validate()["mean_val_loss"] + train_loss2 = trainer.train()["train_loss"] + validation_loss2 = trainer.validate()["val_loss"] assert train_loss2 <= train_loss1, (train_loss2, train_loss1) assert validation_loss2 <= validation_loss1, (validation_loss2, @@ -118,9 +133,7 @@ def test_multi_model(ray_start_2_cpus, num_workers): training_operator_cls=_TestingOperator, num_workers=num_workers) trainer1.train() - - filename = os.path.join(tempfile.mkdtemp(), "checkpoint") - trainer1.save(filename) + state = trainer1.state_dict() models1 = trainer1.get_model() @@ -134,9 +147,7 @@ def test_multi_model(ray_start_2_cpus, num_workers): config={"custom_func": train_epoch}, training_operator_cls=_TestingOperator, num_workers=num_workers) - trainer2.restore(filename) - - os.remove(filename) + trainer2.load_state_dict(state) models2 = trainer2.get_model() @@ -336,20 +347,20 @@ def test_metrics(ray_start_2_cpus, num_workers): stats = trainer.train(num_steps=num_train_steps) # Test that we output mean and last of custom metrics in an epoch - assert "mean_score" in stats + assert "score" in stats assert stats["last_score"] == 0 assert stats[NUM_SAMPLES] == num_train_steps * batch_size expected_score = num_workers * (sum(train_scores) / (num_train_steps * batch_size)) - assert np.allclose(stats["mean_score"], expected_score) + assert np.allclose(stats["score"], expected_score) val_stats = trainer.validate() # Test that we output mean and last of custom metrics in validation assert val_stats["last_score"] == 0 expected_score = (sum(val_scores) / (num_val_steps * batch_size)) * num_workers - assert np.allclose(val_stats["mean_score"], expected_score) + assert np.allclose(val_stats["score"], expected_score) assert val_stats[BATCH_COUNT] == np.ceil(num_val_steps / num_workers) assert val_stats[NUM_SAMPLES] == num_val_steps * batch_size assert val_stats[NUM_SAMPLES] == val_size @@ -384,14 +395,14 @@ def test_metrics_nan(ray_start_2_cpus, num_workers): training_operator_cls=_TestMetricsOperator) stats = trainer.train(num_steps=num_train_steps) - assert "mean_score" in stats + assert "score" in stats assert stats["last_score"] == 0 - assert np.isnan(stats["mean_score"]) + assert np.isnan(stats["score"]) stats = trainer.validate() - assert "mean_score" in stats + assert "score" in stats assert stats["last_score"] == 0 - assert np.isnan(stats["mean_score"]) + assert np.isnan(stats["score"]) trainer.shutdown() @@ -415,41 +426,41 @@ def test_scheduler_validate(ray_start_2_cpus): # noqa: F811 @pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) def test_tune_train(ray_start_2_cpus, num_workers): # noqa: F811 - - config = { - "model_creator": model_creator, - "data_creator": data_creator, - "optimizer_creator": optimizer_creator, - "loss_creator": lambda config: nn.MSELoss(), - "num_workers": num_workers, - "use_gpu": False, - "backend": "gloo", - "config": { - "batch_size": 512, - "lr": 0.001 - } - } + TorchTrainable = TorchTrainer.as_trainable( + **{ + "model_creator": model_creator, + "data_creator": data_creator, + "optimizer_creator": optimizer_creator, + "loss_creator": lambda config: nn.MSELoss(), + "num_workers": num_workers, + "use_gpu": False, + "backend": "gloo", + "config": { + "batch_size": 512, + "lr": 0.001 + } + }) analysis = tune.run( TorchTrainable, num_samples=2, - config=config, stop={"training_iteration": 2}, verbose=1) # checks loss decreasing for every trials for path, df in analysis.trial_dataframes.items(): - mean_train_loss1 = df.loc[0, "mean_train_loss"] - mean_train_loss2 = df.loc[1, "mean_train_loss"] - mean_val_loss1 = df.loc[0, "mean_val_loss"] - mean_val_loss2 = df.loc[1, "mean_val_loss"] + mean_train_loss1 = df.loc[0, "train_loss"] + mean_train_loss2 = df.loc[1, "train_loss"] + mean_val_loss1 = df.loc[0, "val_loss"] + mean_val_loss2 = df.loc[1, "val_loss"] assert mean_train_loss2 <= mean_train_loss1 assert mean_val_loss2 <= mean_val_loss1 @pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) -def test_save_and_restore(ray_start_2_cpus, num_workers): # noqa: F811 +def test_save_and_restore(ray_start_2_cpus, num_workers, + tmp_path): # noqa: F811 trainer1 = TorchTrainer( model_creator=model_creator, data_creator=data_creator, @@ -457,9 +468,8 @@ def test_save_and_restore(ray_start_2_cpus, num_workers): # noqa: F811 loss_creator=lambda config: nn.MSELoss(), num_workers=num_workers) trainer1.train() - - filename = os.path.join(tempfile.mkdtemp(), "checkpoint") - trainer1.save(filename) + checkpoint_path = os.path.join(tmp_path, "checkpoint") + trainer1.save(checkpoint_path) model1 = trainer1.get_model() @@ -471,9 +481,7 @@ def test_save_and_restore(ray_start_2_cpus, num_workers): # noqa: F811 optimizer_creator=optimizer_creator, loss_creator=lambda config: nn.MSELoss(), num_workers=num_workers) - trainer2.restore(filename) - - os.remove(filename) + trainer2.load(checkpoint_path) model2 = trainer2.get_model() @@ -619,7 +627,8 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811 loss_creator=lambda config: nn.MSELoss(), num_workers=2) - trainer1.train(max_retries=2) + # MAX RETRIES SHOULD BE ON BY DEFAULT + trainer1.train() trainer1.shutdown() diff --git a/python/ray/util/sgd/torch/__init__.py b/python/ray/util/sgd/torch/__init__.py index 64121a239..c3772de18 100644 --- a/python/ray/util/sgd/torch/__init__.py +++ b/python/ray/util/sgd/torch/__init__.py @@ -2,16 +2,17 @@ import logging logger = logging.getLogger(__name__) TorchTrainer = None -TorchTrainable = None TrainingOperator = None +BaseTorchTrainable = None try: import torch # noqa: F401 - from ray.util.sgd.torch.torch_trainer import (TorchTrainer, TorchTrainable) + from ray.util.sgd.torch.torch_trainer import (TorchTrainer, + BaseTorchTrainable) from ray.util.sgd.torch.training_operator import TrainingOperator - __all__ = ["TorchTrainer", "TorchTrainable", "TrainingOperator"] + __all__ = ["TorchTrainer", "BaseTorchTrainable", "TrainingOperator"] except ImportError: logger.warning("PyTorch not found. TorchTrainer will not be available") diff --git a/python/ray/util/sgd/torch/distributed_torch_runner.py b/python/ray/util/sgd/torch/distributed_torch_runner.py index 73ac6b412..8abc3ff47 100644 --- a/python/ray/util/sgd/torch/distributed_torch_runner.py +++ b/python/ray/util/sgd/torch/distributed_torch_runner.py @@ -142,14 +142,7 @@ class DistributedTorchRunner(TorchRunner): This is needed for PyTorch DistributedDataParallel models. """ - cpu_state_dicts = [] - for model in self.models: - state_dict = model.module.state_dict() - # This is so that we create a duplicate of weights into CPU rather - # than move the model weights out of the GPU so that we can - # resume training while saving intermediate checkpoints. - cpu_state_dicts += [{k: v.cpu() for k, v in state_dict.items()}] - return cpu_state_dicts + return [model.module.state_dict() for model in self.models] def _set_model_state_dicts(self, model_state_dicts): for model, model_state_dict in zip(self.models, model_state_dicts): @@ -212,3 +205,10 @@ class LocalDistributedRunner(DistributedTorchRunner): def is_actor(self): actor_id = ray.worker.global_worker.actor_id return actor_id != actor_id.nil() + + +class DeactivatedRunner: + def __getattr__(self, *args, **kwargs): + raise RuntimeError( + "This TorchTrainer is not active (it is likely shutdown already). " + "Create a new TorchTrainer.") diff --git a/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py b/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py index f23aa361f..a3fb03690 100644 --- a/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py +++ b/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py @@ -1,8 +1,10 @@ +import numpy as np import os import torch import torch.nn as nn import argparse from ray import tune +from ray.tune.schedulers import PopulationBasedTraining from torch.utils.data import DataLoader, Subset import torchvision import torchvision.transforms as transforms @@ -10,17 +12,20 @@ import torchvision.transforms as transforms from tqdm import trange import ray -from ray.util.sgd.torch import (TorchTrainer, TorchTrainable) +from ray.tune import CLIReporter +from ray.util.sgd.torch import TorchTrainer from ray.util.sgd.torch.resnet import ResNet18 from ray.util.sgd.utils import BATCH_SIZE def initialization_hook(): - print("NCCL DEBUG SET") - # Need this for avoiding a connection restart issue + # Need this for avoiding a connection restart issue on AWS. os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo" os.environ["NCCL_LL_THRESHOLD"] = "0" - os.environ["NCCL_DEBUG"] = "INFO" + + # set the below if needed + # print("NCCL DEBUG SET") + # os.environ["NCCL_DEBUG"] = "INFO" def cifar_creator(config): @@ -55,7 +60,10 @@ def cifar_creator(config): def optimizer_creator(model, config): """Returns optimizer""" - return torch.optim.SGD(model.parameters(), lr=config.get("lr", 0.1)) + return torch.optim.SGD( + model.parameters(), + lr=config.get("lr", 0.1), + momentum=config.get("momentum", 0.9)) def scheduler_creator(optimizer, config): @@ -77,12 +85,11 @@ def train_example(num_workers=1, initialization_hook=initialization_hook, num_workers=num_workers, config={ - "lr": 0.01, - "test_mode": test_mode, - BATCH_SIZE: 128, + "lr": 0.1, + "test_mode": test_mode, # user-defined param to subset the data + BATCH_SIZE: 128 * num_workers # this will be split across workers. }, use_gpu=use_gpu, - backend="nccl" if use_gpu else "gloo", scheduler_step_freq="epoch", use_fp16=use_fp16, use_tqdm=True) @@ -92,39 +99,65 @@ def train_example(num_workers=1, info["epoch_idx"] = i info["num_epochs"] = num_epochs # Increase `max_retries` to turn on fault tolerance. - stats = trainer1.train(max_retries=1, info=info) - pbar.set_postfix(dict(loss=stats["mean_train_loss"])) + trainer1.train(max_retries=1, info=info) + val_stats = trainer1.validate() + pbar.set_postfix(dict(acc=val_stats["val_accuracy"])) print(trainer1.validate()) trainer1.shutdown() print("success!") -def tune_example(num_workers=1, use_gpu=False, test_mode=False): - config = { - "model_creator": ResNet18, - "data_creator": cifar_creator, - "optimizer_creator": optimizer_creator, - "loss_creator": nn.CrossEntropyLoss, - "num_workers": num_workers, - "initialization_hook": initialization_hook, - "use_gpu": use_gpu, - "config": { - "lr": tune.choice([1e-4, 1e-3]), - BATCH_SIZE: 128, - "test_mode": test_mode +def tune_example(num_workers=1, use_gpu=False, use_fp16=False, + test_mode=False): + TorchTrainable = TorchTrainer.as_trainable( + model_creator=ResNet18, + data_creator=cifar_creator, + optimizer_creator=optimizer_creator, + loss_creator=nn.CrossEntropyLoss, + scheduler_creator=scheduler_creator, + initialization_hook=initialization_hook, + num_workers=num_workers, + config={ + "test_mode": test_mode, # user-defined param to subset the data + BATCH_SIZE: 128 * num_workers, }, - "backend": "nccl" if use_gpu else "gloo" - } + use_gpu=use_gpu, + scheduler_step_freq="epoch", + use_fp16=use_fp16) + + pbt_scheduler = PopulationBasedTraining( + time_attr="training_iteration", + metric="val_loss", + mode="min", + perturbation_interval=1, + hyperparam_mutations={ + # distribution for resampling + "lr": lambda: np.random.uniform(0.001, 1), + # allow perturbations within this set of categorical values + "momentum": [0.8, 0.9, 0.99], + }) + + reporter = CLIReporter() + reporter.add_metric_column("val_loss", "loss") + reporter.add_metric_column("val_accuracy", "acc") analysis = tune.run( TorchTrainable, - num_samples=2, - config=config, - stop={"training_iteration": 2}, - verbose=2) + num_samples=4, + config={ + "lr": tune.choice([0.001, 0.01, 0.1]), + "momentum": 0.8 + }, + stop={"training_iteration": 2 if test_mode else 100}, + max_failures=3, # used for fault tolerance + checkpoint_freq=3, # used for fault tolerance + keep_checkpoints_num=1, # used for fault tolerance + verbose=2, + progress_reporter=reporter, + scheduler=pbt_scheduler) - return analysis.get_best_config(metric="mean_accuracy", mode="max") + return analysis.get_best_config(metric="val_loss", mode="min") if __name__ == "__main__": diff --git a/python/ray/util/sgd/torch/examples/dcgan.py b/python/ray/util/sgd/torch/examples/dcgan.py index 6cb07c3a6..bb21ce429 100644 --- a/python/ray/util/sgd/torch/examples/dcgan.py +++ b/python/ray/util/sgd/torch/examples/dcgan.py @@ -242,16 +242,13 @@ def train_example(num_workers=1, use_gpu=False, test_mode=False): num_workers=num_workers, config=config, use_gpu=use_gpu, - backend="nccl" if use_gpu else "gloo", use_tqdm=True) from tabulate import tabulate pbar = trange(5, unit="epoch") for itr in pbar: stats = trainer.train(info=dict(epoch_idx=itr, num_epochs=5)) - pbar.set_postfix( - dict(loss_g=stats["mean_loss_g"], loss_d=stats["mean_loss_d"])) - + pbar.set_postfix(dict(loss_g=stats["loss_g"], loss_d=stats["loss_d"])) formatted = tabulate([stats], headers="keys") if itr > 0: # Get the last line of the stats. formatted = formatted.split("\n")[-1] diff --git a/python/ray/util/sgd/torch/examples/tune_example.py b/python/ray/util/sgd/torch/examples/tune_example.py index a4bcc5362..fcf62d9ca 100644 --- a/python/ray/util/sgd/torch/examples/tune_example.py +++ b/python/ray/util/sgd/torch/examples/tune_example.py @@ -14,7 +14,8 @@ import torch.nn as nn import ray from ray import tune -from ray.util.sgd.torch.torch_trainer import TorchTrainable +from ray.util.sgd.torch import TorchTrainer +from ray.util.sgd.utils import BATCH_SIZE class LinearDataset(torch.utils.data.Dataset): @@ -48,30 +49,29 @@ def data_creator(config): val_dataset = LinearDataset(2, 5, size=400) train_loader = torch.utils.data.DataLoader( train_dataset, - batch_size=config["batch_size"], + batch_size=config[BATCH_SIZE], ) validation_loader = torch.utils.data.DataLoader( val_dataset, - batch_size=config["batch_size"]) + batch_size=config[BATCH_SIZE]) return train_loader, validation_loader def tune_example(num_workers=1, use_gpu=False): - config = { - "model_creator": model_creator, - "data_creator": data_creator, - "optimizer_creator": optimizer_creator, - "loss_creator": nn.MSELoss, - "num_workers": num_workers, - "use_gpu": use_gpu, - "config": {"batch_size": 512 // num_workers}, - "backend": "gloo" - } + TorchTrainable = TorchTrainer.as_trainable( + model_creator=model_creator, + data_creator=data_creator, + optimizer_creator=optimizer_creator, + loss_creator=nn.MSELoss, + num_workers=num_workers, + use_gpu=use_gpu, + config={BATCH_SIZE: 128} + ) analysis = tune.run( TorchTrainable, - num_samples=12, - config=config, + num_samples=3, + config={"lr": tune.grid_search([1e-4, 1e-3])}, stop={"training_iteration": 2}, verbose=1) diff --git a/python/ray/util/sgd/torch/torch_runner.py b/python/ray/util/sgd/torch/torch_runner.py index 1d5b595f2..56a888ecc 100644 --- a/python/ray/util/sgd/torch/torch_runner.py +++ b/python/ray/util/sgd/torch/torch_runner.py @@ -2,6 +2,7 @@ import collections from filelock import FileLock import logging import inspect +import io import itertools import os import tempfile @@ -225,22 +226,14 @@ class TorchRunner: self.training_operator._set_timers(self.timers) def _get_model_state_dicts(self): - # This is so that we create a duplicate of weights into CPU rather than - # move the model weights entirely out of the GPU, so that we can - # resume training while saving intermediate checkpoints. - cpu_state_dicts = [] - for model in self.models: - state_dict = model.state_dict() - cpu_state_dicts += [{k: v.cpu() for k, v in state_dict.items()}] - return cpu_state_dicts + return [model.state_dict() for model in self.models] def _set_model_state_dicts(self, models_state_dicts): for model, state_dict in zip(self.models, models_state_dicts): model.load_state_dict(state_dict) - def get_state(self): + def state_dict(self): """Returns the state of the runner.""" - state = { "epoch": self.epochs, "operator": self.training_operator.state_dict(), @@ -258,9 +251,8 @@ class TorchRunner: state.update({"amp": amp.state_dict()}) return state - def set_state(self, state): + def load_state_dict(self, state): """Sets the state of the model.""" - # TODO: restore timer stats self._set_model_state_dicts(state["models"]) for optimizer, state_dict in zip(self.optimizers, state["optimizers"]): optimizer.load_state_dict(state_dict) @@ -274,6 +266,19 @@ class TorchRunner: self.epochs = state["epoch"] self.training_operator.load_state_dict(state_dict) + def state_stream(self): + """Returns a bytes object for the state dict.""" + state_dict = self.state_dict() + _buffer = io.BytesIO() + torch.save(state_dict, _buffer) + return _buffer.getvalue() + + def load_state_stream(self, byte_obj): + """Loads a bytes object the training state dict.""" + _buffer = io.BytesIO(byte_obj) + state_dict = torch.load(_buffer) + return self.load_state_dict(state_dict) + def apply(self, fn): return fn() diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index 7f698994f..79c951a85 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -1,6 +1,6 @@ import numpy as np -import os import logging +import os import numbers import tempfile import time @@ -10,9 +10,10 @@ import torch.distributed as dist import ray from ray.exceptions import RayActorError from ray.tune import Trainable -from ray.tune.trial import Resources +from ray.tune.resources import Resources +from ray.tune.utils.util import merge_dicts from ray.util.sgd.torch.distributed_torch_runner import ( - DistributedTorchRunner, LocalDistributedRunner) + DistributedTorchRunner, LocalDistributedRunner, DeactivatedRunner) from ray.util.sgd.utils import check_for_failure, NUM_SAMPLES, BATCH_SIZE from ray.util.sgd.torch.torch_runner import TorchRunner from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP @@ -130,6 +131,11 @@ class TorchTrainer: """ + # TODO: Implement autoscaling. If num_workers=-1, the trainer will use as + # many resources as available. Upon each train call, TorchTrainer will + # query the Ray global state for total available resources and resize + # its remote workers to consume all available resources. + def __init__( self, *, @@ -218,6 +224,9 @@ class TorchTrainer: self._num_failures = 0 self._last_resize = float("-inf") + self.local_worker = DeactivatedRunner() + self.remote_workers = [] + _validate_scheduler_step_freq(scheduler_step_freq) self.scheduler_step_freq = scheduler_step_freq @@ -250,9 +259,6 @@ class TorchTrainer: if batch_size_per_worker: worker_config[BATCH_SIZE] = batch_size_per_worker - self.local_worker = None - self.remote_workers = [] - if num_workers == 1: # Start local worker self.local_worker = TorchRunner( @@ -319,8 +325,7 @@ class TorchTrainer: num_steps=None, profile=False, reduce_results=True, - max_retries=0, - checkpoint="auto", + max_retries=3, info=None): """Runs a training epoch. @@ -339,14 +344,12 @@ class TorchTrainer: all workers into one dict. If a metric is a non-numerical value (or nested dictionaries), one value will be randomly selected among the workers. If False, returns a list of dicts. - max_retries (int): Must be non-negative. If set to N, will - kill all current workers, query the Ray global state for - total available resources, and re-launch up to the - available resources. Behavior is not well-defined - in case of shared cluster usage. - checkpoint (str): Path to checkpoint to restore from if retrying. - If max_retries is set and ``checkpoint == "auto"``, - TorchTrainer will save a checkpoint before starting to train. + max_retries (int): Must be non-negative. If set to N, TorchTrainer + will detect and recover from training failure. The recovery + process will kill all current workers, query the Ray + global state for total available resources, and re-launch up to + the available resources. Behavior is not well-defined + in case of shared cluster usage. Defaults to 3. info (dict): Optional dictionary passed to the training operator for ``train_epoch`` and ``train_batch``. @@ -358,18 +361,9 @@ class TorchTrainer: length will be equal to ``num_workers``. """ assert max_retries >= 0, "`max_retries` must be non-negative." - if max_retries: - if checkpoint == "auto": - logger.debug("Retrying detected. Automatically checkpointing.") - checkpoint = self.save( - os.path.join(self.temp_dir, "tmp_checkpoint")) - elif not checkpoint: - raise ValueError("Cannot retry from empty checkpoint.") - - if checkpoint and self._should_resize(): + if self._should_resize(): logger.info("Resize opportunity detected. Attempting to scale up.") - self._resize_workers(checkpoint=checkpoint) - + self._resize_workers() success, worker_stats = self._train_epoch( num_steps=num_steps, profile=profile, info=info) # Fault handling @@ -378,7 +372,7 @@ class TorchTrainer: break else: self._num_failures += 1 - self._resize_workers(checkpoint=checkpoint) + self._resize_workers() logger.info("Retrying training step with %d workers." % (len(self.remote_workers) + 1)) success, worker_stats = self._train_epoch( @@ -483,7 +477,6 @@ class TorchTrainer: w.validate.remote(**params) for w in self.remote_workers ] local_worker_stats = self.local_worker.validate(**params) - return self._process_stats([local_worker_stats] + ray.get(remote_worker_stats)) @@ -497,47 +490,49 @@ class TorchTrainer: def get_model(self): """Returns the learned model(s).""" - models = self.model_creator(self.config) - state = self.local_worker.get_state() - if len(state["models"]) == 1: - models.load_state_dict(state["models"][0]) - else: - for model, state_dict in zip(models, state["models"]): - model.load_state_dict(state_dict) - return models + unwrapped = [] + for model in self.local_worker.models: + unwrapped += [model.module if hasattr(model, "module") else model] + if len(unwrapped) == 1: + return unwrapped[0] + return unwrapped def state_dict(self): - return self.local_worker.get_state() + return self.local_worker.state_dict() - def load_state_dict(self, state): - state_id = ray.put(state) + def load_state_dict(self, state_dict, blocking=False): + # This is not the most efficient because you have to wait for + # the local worker to save then dump to buffer. + self.local_worker.load_state_dict(state_dict) + state_id = ray.put(self.local_worker.state_stream()) remote_calls = [ - worker.set_state.remote(state_id) for worker in self.remote_workers + worker.load_state_stream.remote(state_id) + for worker in self.remote_workers ] - self.local_worker.set_state(state) - ray.get(remote_calls) + if blocking: + ray.get(remote_calls) def save(self, checkpoint): - """Saves the model(s) to the provided checkpoint. + """Saves the Trainer state to the provided checkpoint path. Args: checkpoint (str): Path to target checkpoint file. - - Returns: - checkpoint (str): Path to target checkpoint file. """ torch.save(self.state_dict(), checkpoint) return checkpoint - def restore(self, checkpoint): - """Restores the Trainer and all workers from the provided checkpoint. + def load(self, checkpoint): + """Loads the Trainer and all workers from the provided checkpoint. Args: checkpoint (str): Path to target checkpoint file. """ - state = torch.load(checkpoint) - self.load_state_dict(state) + state_dict = torch.load(checkpoint) + self.load_state_dict(state_dict) + + def restore(self, *args): + raise DeprecationWarning("Use `TorchTrainer.load()` instead.") def shutdown(self, force=False): """Shuts down workers and releases resources.""" @@ -562,19 +557,19 @@ class TorchTrainer: else: self.local_worker.shutdown() for worker in self.remote_workers: - logger.warning("Killing worker {}.".format(worker)) + logger.debug("Killing worker {}.".format(worker)) ray.kill(worker) - self.local_worker = None + self.local_worker = DeactivatedRunner() self.remote_workers = [] def _reset(self): """Terminates models without giving up local resource reservation.""" self.local_worker.shutdown(cleanup=False) for worker in self.remote_workers: - logger.warning("Killing worker {}.".format(worker)) + logger.debug("Killing worker {}.".format(worker)) ray.kill(worker) - self.local_worker = None + self.local_worker = DeactivatedRunner() self.remote_workers = [] def _check_potential_remote_workers_size(self): @@ -588,9 +583,8 @@ class TorchTrainer: remote_resources.get("GPU", 0), new_remote_workers) return new_remote_workers - def _resize_workers(self, checkpoint, max_retries=10): + def _resize_workers(self, max_retries=10): self._reset() - assert checkpoint, "Cannot restore without checkpoint." time.sleep(1) for i in range(max_retries): @@ -598,7 +592,7 @@ class TorchTrainer: if new_remote_workers: self._last_resize = time.time() self._start_workers(int(new_remote_workers) + 1) - self.restore(checkpoint) + self.load_state_dict(self.state_dict()) return else: delay = 2**i @@ -617,32 +611,128 @@ class TorchTrainer: return potential_remote_size > 0 return False - -class TorchTrainable(Trainable): @classmethod - def default_resource_request(cls, config): - remote_worker_count = config["num_workers"] - 1 - return Resources( - cpu=1, - gpu=int(config["use_gpu"]), - extra_cpu=int(remote_worker_count), - extra_gpu=int(int(config["use_gpu"]) * remote_worker_count)) + def as_trainable(cls, *args, **kwargs): + """Creates a BaseTorchTrainable class compatible with Tune. + + Any configuration parameters will be overriden by the Tune + Trial configuration. You can also subclass the provided Trainable + to implement your own iterative optimization routine. + + .. code-block:: python + + TorchTrainable = TorchTrainer.as_trainable( + model_creator=ResNet18, + data_creator=cifar_creator, + optimizer_creator=optimizer_creator, + loss_creator=nn.CrossEntropyLoss, + num_gpus=2 + ) + analysis = tune.run( + TorchTrainable, + config={"lr": tune.grid_search([0.01, 0.1])} + ) + + """ + + class TorchTrainable(BaseTorchTrainable): + @classmethod + def default_resource_request(cls, config): + num_workers = config.get("num_workers", + kwargs.get("num_workers", 1)) + use_gpu = config.get("use_gpu", kwargs.get("use_gpu")) + + remote_worker_count = num_workers - 1 + + return Resources( + cpu=1, + gpu=int(use_gpu), + extra_cpu=int(remote_worker_count), + extra_gpu=int(int(use_gpu) * remote_worker_count)) + + def _create_trainer(self, tune_config): + """Overrides the provided config with Tune config.""" + provided_config = kwargs.get("config", {}).copy() + provided_config.update(tune_config) + kwargs["config"] = provided_config + trainer = TorchTrainer(*args, **kwargs) + return trainer + + return TorchTrainable + + +class BaseTorchTrainable(Trainable): + """Base class for converting TorchTrainer to a Trainable class. + + This class is produced when you call ``TorchTrainer.as_trainable(...)``. + + You can override the produced Trainable to implement custom iterative + training procedures: + + .. code-block:: python + + TorchTrainable = TorchTrainer.as_trainable( + model_creator=ResNet18, + data_creator=cifar_creator, + optimizer_creator=optimizer_creator, + loss_creator=nn.CrossEntropyLoss, + num_gpus=2 + ) + # TorchTrainable is subclass of BaseTorchTrainable. + + class CustomTrainable(TorchTrainable): + def _train(self): + for i in range(5): + train_stats = self.trainer.train() + validation_stats = self.trainer.validate() + train_stats.update(validation_stats) + return train_stats + + analysis = tune.run( + CustomTrainable, + config={"lr": tune.grid_search([0.01, 0.1])} + ) + + """ def _setup(self, config): - self._trainer = TorchTrainer(**config) + """Constructs a TorchTrainer object as `self.trainer`.""" + self._trainer = self._create_trainer(config) def _train(self): - train_stats = self._trainer.train() - validation_stats = self._trainer.validate() + """Calls `self.trainer.train()` and `self.trainer.validate()` once. - train_stats.update(validation_stats) - return train_stats + You may want to override this if using a custom LR scheduler. + """ + train_stats = self.trainer.train(max_retries=10, profile=True) + validation_stats = self.trainer.validate(profile=True) + stats = merge_dicts(train_stats, validation_stats) + return stats def _save(self, checkpoint_dir): - return self._trainer.save(os.path.join(checkpoint_dir, "model.pth")) + """Returns a path containing the trainer state.""" + checkpoint_path = os.path.join(checkpoint_dir, "trainer.checkpoint") + self.trainer.save(checkpoint_path) + return checkpoint_path def _restore(self, checkpoint_path): - return self._trainer.restore(checkpoint_path) + """Restores the trainer state. + + Override this if you have state external to the Trainer object. + """ + return self.trainer.load(checkpoint_path) def _stop(self): - self._trainer.shutdown() + """Shuts down the trainer.""" + self.trainer.shutdown() + + def _create_trainer(self, config): + raise NotImplementedError + + @property + def trainer(self): + """An instantiated TorchTrainer object. + + Use this when specifying custom training procedures for Tune. + """ + return self._trainer diff --git a/python/ray/util/sgd/torch/training_operator.py b/python/ray/util/sgd/torch/training_operator.py index ecd72e213..a35a4886a 100644 --- a/python/ray/util/sgd/torch/training_operator.py +++ b/python/ray/util/sgd/torch/training_operator.py @@ -168,8 +168,10 @@ class TrainingOperator: if self.use_tqdm and self.world_rank == 0: _progress_bar.n = batch_idx + 1 + postfix = {} if "train_loss" in metrics: - _progress_bar.set_postfix({"loss": metrics["train_loss"]}) + postfix.update(loss=metrics["train_loss"]) + _progress_bar.set_postfix(postfix) if self.scheduler and batch_info.get( SCHEDULER_STEP) == SCHEDULER_STEP_BATCH: @@ -259,7 +261,7 @@ class TrainingOperator: Returns: A dict of metrics from the evaluation. - By default, returns "mean_accuracy" and "mean_val_loss" + By default, returns "val_accuracy" and "val_loss" which is computed by aggregating "loss" and "correct" values from ``validate_batch`` and dividing it by the sum of ``num_samples`` from all calls to ``self.validate_batch``. diff --git a/python/ray/util/sgd/utils.py b/python/ray/util/sgd/utils.py index e99b01a52..5d2b1e44b 100644 --- a/python/ray/util/sgd/utils.py +++ b/python/ray/util/sgd/utils.py @@ -192,7 +192,7 @@ class AverageMeterCollection: """Returns a dict of average and most recent values for each metric.""" stats = {BATCH_COUNT: self._batch_count, NUM_SAMPLES: self.n} for metric, meter in self._meters.items(): - stats["mean_" + str(metric)] = meter.avg + stats[str(metric)] = meter.avg stats["last_" + str(metric)] = meter.val return stats