From 7b27ce2b236e29d82e33c5e07645f845d433973c Mon Sep 17 00:00:00 2001 From: Maksim Smolin Date: Fri, 27 Mar 2020 20:19:15 -0700 Subject: [PATCH] [RaySGD] Convert the head worker to a local model (#7746) Why are these changes needed? Running a worker on head (locally, not as a Ray actor) allows for easier handling of stateful stuff like logging and for easier debugging. --- python/ray/util/sgd/tests/test_torch.py | 81 +++-- python/ray/util/sgd/torch/constants.py | 3 +- .../sgd/torch/distributed_torch_runner.py | 75 ++++- .../torch/examples/cifar_pytorch_example.py | 4 +- python/ray/util/sgd/torch/examples/dcgan.py | 2 +- .../sgd/torch/examples/sgd-development.yaml | 16 +- python/ray/util/sgd/torch/torch_runner.py | 11 +- python/ray/util/sgd/torch/torch_trainer.py | 297 ++++++++++-------- python/ray/util/sgd/torch/tqdm_handler.py | 116 ------- .../ray/util/sgd/torch/training_operator.py | 37 ++- 10 files changed, 341 insertions(+), 301 deletions(-) delete mode 100644 python/ray/util/sgd/torch/tqdm_handler.py diff --git a/python/ray/util/sgd/tests/test_torch.py b/python/ray/util/sgd/tests/test_torch.py index 2f0674aff..d113849a1 100644 --- a/python/ray/util/sgd/tests/test_torch.py +++ b/python/ray/util/sgd/tests/test_torch.py @@ -41,6 +41,7 @@ def test_single_step(ray_start_2_cpus): # noqa: F811 val_metrics = trainer.validate(num_steps=1) assert val_metrics[BATCH_COUNT] == 1 + trainer.shutdown() @pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) @@ -62,6 +63,7 @@ def test_train(ray_start_2_cpus, num_workers): # noqa: F811 assert train_loss2 <= train_loss1, (train_loss2, train_loss1) assert validation_loss2 <= validation_loss1, (validation_loss2, validation_loss1) + trainer.shutdown() @pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) @@ -278,6 +280,7 @@ def test_split_batch(ray_start_2_cpus): assert trainer.config[BATCH_SIZE] == (batch_size - 1) assert stats[NUM_SAMPLES] == 600 assert stats[BATCH_COUNT] == (data_size // 20) + trainer.shutdown() def test_reduce_result(ray_start_2_cpus): @@ -302,6 +305,7 @@ def test_reduce_result(ray_start_2_cpus): assert len(list_stats) == 2 assert [stats[NUM_SAMPLES] == data_size for stats in list_stats] assert [stats[BATCH_COUNT] == (data_size // 2) for stats in list_stats] + trainer.shutdown() @pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) @@ -388,6 +392,7 @@ def test_metrics_nan(ray_start_2_cpus, num_workers): assert "mean_score" in stats assert stats["last_score"] == 0 assert np.isnan(stats["mean_score"]) + trainer.shutdown() def test_scheduler_validate(ray_start_2_cpus): # noqa: F811 @@ -479,6 +484,7 @@ def test_save_and_restore(ray_start_2_cpus, num_workers): # noqa: F811 for k in model1_state_dict: assert torch.equal(model1_state_dict[k], model2_state_dict[k]) + trainer2.shutdown() def test_fail_with_recover(ray_start_2_cpus): # noqa: F811 @@ -490,15 +496,25 @@ def test_fail_with_recover(ray_start_2_cpus): # noqa: F811 return torch.utils.data.DataLoader( dataset, batch_size=config.get("batch_size", 32)) - def step_with_fail(self, *args, **kwargs): - worker_stats = [ - w.train_epoch.remote(*args, **kwargs) for w in self.workers + def step_with_fail(self, **params): + remote_worker_stats = [ + w.train_epoch.remote(**params) for w in self.remote_workers ] + if self._num_failures < 3: time.sleep(1) # Make the batch will fail correctly. - self.workers[0].__ray_kill__() - success = check_for_failure(worker_stats) - return success, worker_stats + ray.kill(self.remote_workers[0]) + + try: + local_worker_stats = self.local_worker.train_epoch(**params) + except RuntimeError: + return False, None + + success = check_for_failure(remote_worker_stats) + if success: + return success, [local_worker_stats] + ray.get(remote_worker_stats) + + return success, None with patch.object(TorchTrainer, "_train_epoch", step_with_fail): trainer1 = TorchTrainer( @@ -512,6 +528,8 @@ def test_fail_with_recover(ray_start_2_cpus): # noqa: F811 with pytest.raises(RuntimeError): trainer1.train(max_retries=1) + trainer1.shutdown(force=True) + def test_resize(ray_start_2_cpus): # noqa: F811 if not dist.is_available(): @@ -522,15 +540,25 @@ def test_resize(ray_start_2_cpus): # noqa: F811 return torch.utils.data.DataLoader( dataset, batch_size=config.get("batch_size", 32)) - def step_with_fail(self, *args, **kwargs): - worker_stats = [ - w.train_epoch.remote(*args, **kwargs) for w in self.workers + def step_with_fail(self, **params): + remote_worker_stats = [ + w.train_epoch.remote(**params) for w in self.remote_workers ] + if self._num_failures < 1: time.sleep(1) # Make the batch will fail correctly. - self.workers[0].__ray_kill__() - success = check_for_failure(worker_stats) - return success, worker_stats + self.remote_workers[0].__ray_kill__() + + try: + local_worker_stats = self.local_worker.train_epoch(**params) + except RuntimeError: + return False, None + + success = check_for_failure(remote_worker_stats) + if success: + return success, [local_worker_stats] + ray.get(remote_worker_stats) + + return success, None with patch.object(TorchTrainer, "_train_epoch", step_with_fail): trainer1 = TorchTrainer( @@ -548,7 +576,9 @@ def test_resize(ray_start_2_cpus): # noqa: F811 try_test.remote() trainer1.train(max_retries=1) - assert len(trainer1.workers) == 1 + assert len(trainer1.remote_workers) == 1 + + trainer1.shutdown() def test_fail_twice(ray_start_2_cpus): # noqa: F811 @@ -560,15 +590,25 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811 return torch.utils.data.DataLoader( dataset, batch_size=config.get("batch_size", 32)) - def step_with_fail(self, *args, **kwargs): - worker_stats = [ - w.train_epoch.remote(*args, **kwargs) for w in self.workers + def step_with_fail(self, **params): + remote_worker_stats = [ + w.train_epoch.remote(**params) for w in self.remote_workers ] + if self._num_failures < 2: - time.sleep(1) - self.workers[0].__ray_kill__() - success = check_for_failure(worker_stats) - return success, worker_stats + time.sleep(1) # Make the batch will fail correctly. + self.remote_workers[0].__ray_kill__() + + try: + local_worker_stats = self.local_worker.train_epoch(**params) + except RuntimeError: + return False, None + + success = check_for_failure(remote_worker_stats) + if success: + return success, [local_worker_stats] + ray.get(remote_worker_stats) + + return success, None with patch.object(TorchTrainer, "_train_epoch", step_with_fail): trainer1 = TorchTrainer( @@ -580,6 +620,7 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811 num_workers=2) trainer1.train(max_retries=2) + trainer1.shutdown() if __name__ == "__main__": diff --git a/python/ray/util/sgd/torch/constants.py b/python/ray/util/sgd/torch/constants.py index 2f8e8fcc8..e3673c8d8 100644 --- a/python/ray/util/sgd/torch/constants.py +++ b/python/ray/util/sgd/torch/constants.py @@ -1,7 +1,8 @@ USE_FP16 = "__use_fp16__" +NUM_STEPS = "__num_steps__" SCHEDULER_STEP = "scheduler_step" SCHEDULER_STEP_BATCH = "batch" SCHEDULER_STEP_EPOCH = "epoch" -BATCH_LOGS_RATE_LIMIT = .2 +NCCL_TIMEOUT_IN_SECONDS = 10 VALID_SCHEDULER_STEP = {SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH} diff --git a/python/ray/util/sgd/torch/distributed_torch_runner.py b/python/ray/util/sgd/torch/distributed_torch_runner.py index 6f051647c..73ac6b412 100644 --- a/python/ray/util/sgd/torch/distributed_torch_runner.py +++ b/python/ray/util/sgd/torch/distributed_torch_runner.py @@ -1,12 +1,17 @@ +from datetime import timedelta import collections import logging +import os + import torch import torch.nn as nn import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from ray.util.sgd.torch.constants import NCCL_TIMEOUT_IN_SECONDS +import ray from ray.util.sgd.torch.torch_runner import TorchRunner logger = logging.getLogger(__name__) @@ -45,11 +50,20 @@ class DistributedTorchRunner(TorchRunner): logger.debug("Connecting to {} world_rank: {} world_size: {}".format( url, world_rank, world_size)) logger.debug("using {}".format(self.backend)) + + if self.backend == "nccl" and "NCCL_BLOCKING_WAIT" not in os.environ: + logger.debug( + "Setting NCCL_BLOCKING_WAIT for detecting node failure. " + "To override this behavior, you can set NCCL_BLOCKING_WAIT=0.") + os.environ["NCCL_BLOCKING_WAIT"] = "1" + + timeout = timedelta(seconds=NCCL_TIMEOUT_IN_SECONDS) dist.init_process_group( backend=self.backend, init_method=url, rank=world_rank, - world_size=world_size) + world_size=world_size, + timeout=timeout) def _setup_training(self): logger.debug("Creating model") @@ -84,7 +98,8 @@ class DistributedTorchRunner(TorchRunner): validation_loader=self.validation_loader, world_rank=self.world_rank, schedulers=self.schedulers, - use_fp16=self.use_fp16) + use_fp16=self.use_fp16, + use_tqdm=self.use_tqdm) def _initialize_dataloaders(self): super(DistributedTorchRunner, self)._initialize_dataloaders() @@ -140,12 +155,60 @@ class DistributedTorchRunner(TorchRunner): for model, model_state_dict in zip(self.models, model_state_dicts): model.module.load_state_dict(model_state_dict) - # def shutdown(self): + def shutdown(self): """Attempts to shut down the worker.""" - # super(DistributedTorchRunner, self).shutdown() - # TODO: Temporarily removing since it causes hangs on MacOSX. # However, it seems to be harmless to remove permanently # since the processes are shutdown anyways. This comment can be # removed in a future release if it is still not documented # the stable Pytorch docs. - # dist.destroy_process_group() + dist.destroy_process_group() + super(DistributedTorchRunner, self).shutdown() + + +class _DummyActor: + def cuda_devices(self): + return os.environ["CUDA_VISIBLE_DEVICES"] + + +# This is a bit of a hack. It prevents the reassignment of CUDA_VISIBLE_DEVICES +# during a trainer resize. We won't need this if we don't shutdown +# all the actors. +_dummy_actor = None + + +class LocalDistributedRunner(DistributedTorchRunner): + """A wrapper for running a distributed Runner on the driver. + + A dummy actor is used to reserve resources on the driver node, + as specified by `num_cpus` and `num_gpus`. If the Trainer is already + in an actor, we will ignore this resource request. + """ + + def __init__(self, *args, num_cpus=None, num_gpus=None, **kwargs): + ip = ray.services.get_node_ip_address() + + # Reserve a local GPU or CPU for the local worker + # TODO: we should make sure this NEVER dies. + + global _dummy_actor + if not self.is_actor() and _dummy_actor is None: + _dummy_actor = ray.remote( + num_cpus=num_cpus, + num_gpus=num_gpus, + resources={"node:" + ip: 0.1})(_DummyActor).remote() + + head_cuda = ray.get(_dummy_actor.cuda_devices.remote()) + os.environ["CUDA_VISIBLE_DEVICES"] = head_cuda + super(LocalDistributedRunner, self).__init__(*args, **kwargs) + + def shutdown(self, cleanup=True): + super(LocalDistributedRunner, self).shutdown() + global _dummy_actor + if cleanup and _dummy_actor: + assert not self.is_actor(), "Actor shouldn't have a dummy actor." + ray.kill(_dummy_actor) + _dummy_actor = None + + def is_actor(self): + actor_id = ray.worker.global_worker.actor_id + return actor_id != actor_id.nil() diff --git a/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py b/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py index 642155a9e..f23aa361f 100644 --- a/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py +++ b/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py @@ -85,14 +85,14 @@ def train_example(num_workers=1, backend="nccl" if use_gpu else "gloo", scheduler_step_freq="epoch", use_fp16=use_fp16, - tqdm=True) + use_tqdm=True) pbar = trange(num_epochs, unit="epoch") for i in pbar: info = {"num_steps": 1} if test_mode else {} info["epoch_idx"] = i info["num_epochs"] = num_epochs # Increase `max_retries` to turn on fault tolerance. - stats = trainer1.train(max_retries=0, info=info) + stats = trainer1.train(max_retries=1, info=info) pbar.set_postfix(dict(loss=stats["mean_train_loss"])) print(trainer1.validate()) diff --git a/python/ray/util/sgd/torch/examples/dcgan.py b/python/ray/util/sgd/torch/examples/dcgan.py index 6fcfda809..6cb07c3a6 100644 --- a/python/ray/util/sgd/torch/examples/dcgan.py +++ b/python/ray/util/sgd/torch/examples/dcgan.py @@ -243,7 +243,7 @@ def train_example(num_workers=1, use_gpu=False, test_mode=False): config=config, use_gpu=use_gpu, backend="nccl" if use_gpu else "gloo", - tqdm=True) + use_tqdm=True) from tabulate import tabulate pbar = trange(5, unit="epoch") diff --git a/python/ray/util/sgd/torch/examples/sgd-development.yaml b/python/ray/util/sgd/torch/examples/sgd-development.yaml index e6697a272..bd96b6a22 100644 --- a/python/ray/util/sgd/torch/examples/sgd-development.yaml +++ b/python/ray/util/sgd/torch/examples/sgd-development.yaml @@ -3,9 +3,9 @@ cluster_name: sgd-pytorch # The maximum number of workers nodes to launch in addition to the head # node. This takes precedence over min_workers. min_workers default to 0. -min_workers: 0 -initial_workers: 0 -max_workers: 0 +min_workers: 2 +initial_workers: 2 +max_workers: 2 target_utilization_fraction: 0.9 @@ -27,11 +27,13 @@ auth: # ssh_private_key: ... head_node: - InstanceType: p3dn.24xlarge + InstanceType: p3.2xlarge ImageId: ami-0698bcaf8bd9ef56d # KeyName: ... InstanceMarketOptions: MarketType: spot + SpotOptions: + BlockDurationMinutes: 360 BlockDeviceMappings: - DeviceName: /dev/sda1 Ebs: @@ -41,11 +43,13 @@ head_node: worker_nodes: - InstanceType: p3.16xlarge + InstanceType: p3.8xlarge ImageId: ami-0698bcaf8bd9ef56d # KeyName: ... InstanceMarketOptions: MarketType: spot + SpotOptions: + BlockDurationMinutes: 360 BlockDeviceMappings: - DeviceName: /dev/sda1 Ebs: @@ -65,7 +69,7 @@ setup_commands: # Installing this without -U to make sure we don't replace the existing Ray installation - pip install ray[rllib] - - pip install -U ipdb torch torchvision + - pip install -U ipdb torch torchvision tqdm # Install Apex - rm -rf apex || true - git clone https://github.com/NVIDIA/apex && cd apex && pip install -v --no-cache-dir ./ || true diff --git a/python/ray/util/sgd/torch/torch_runner.py b/python/ray/util/sgd/torch/torch_runner.py index 0a784748e..1d5b595f2 100644 --- a/python/ray/util/sgd/torch/torch_runner.py +++ b/python/ray/util/sgd/torch/torch_runner.py @@ -8,7 +8,7 @@ import tempfile import torch import ray -from ray.util.sgd.torch.constants import USE_FP16, SCHEDULER_STEP +from ray.util.sgd.torch.constants import USE_FP16, SCHEDULER_STEP, NUM_STEPS from ray.util.sgd.torch.training_operator import TrainingOperator from ray.util.sgd import utils @@ -49,6 +49,7 @@ class TorchRunner: training_operator_cls=None, config=None, use_fp16=False, + use_tqdm=False, apex_args=None, scheduler_step_freq="batch"): self.model_creator = model_creator @@ -68,6 +69,7 @@ class TorchRunner: self.train_loader = None self.validation_loader = None self.use_fp16 = use_fp16 + self.use_tqdm = use_tqdm self.apex_args = apex_args or {} if use_fp16 and not amp: raise ImportError( @@ -133,9 +135,6 @@ class TorchRunner: self.models, self.optimizers = amp.initialize( self.models, self.optimizers, **self.apex_args) - def set_reporters(self, reporters): - return self.training_operator.set_reporters(reporters) - def setup(self): """Initializes the model.""" logger.debug("Creating model") @@ -163,7 +162,8 @@ class TorchRunner: validation_loader=self.validation_loader, world_rank=0, schedulers=self.schedulers, - use_fp16=self.use_fp16) + use_fp16=self.use_fp16, + use_tqdm=self.use_tqdm) def get_node_ip(self): """Returns the IP address of the current node.""" @@ -180,6 +180,7 @@ class TorchRunner: self._toggle_profiling(profile=profile) info.update({ + NUM_STEPS: num_steps, USE_FP16: self.use_fp16, SCHEDULER_STEP: self.scheduler_step_freq }) diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index 0dc6c0c88..7f698994f 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -4,22 +4,18 @@ import logging import numbers import tempfile import time -import asyncio import torch import torch.distributed as dist import ray - from ray.exceptions import RayActorError from ray.tune import Trainable from ray.tune.trial import Resources from ray.util.sgd.torch.distributed_torch_runner import ( - DistributedTorchRunner) + DistributedTorchRunner, LocalDistributedRunner) from ray.util.sgd.utils import check_for_failure, NUM_SAMPLES, BATCH_SIZE from ray.util.sgd.torch.torch_runner import TorchRunner -from ray.util.sgd.torch.constants import (VALID_SCHEDULER_STEP, - BATCH_LOGS_RATE_LIMIT) -from ray.util.sgd.torch.tqdm_handler import TqdmHandler +from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP logger = logging.getLogger(__name__) RESIZE_COOLDOWN_S = 10 @@ -149,7 +145,7 @@ class TorchTrainer: use_gpu=False, backend="auto", use_fp16=False, - tqdm=False, + use_tqdm=False, apex_args=None, scheduler_step_freq="batch", num_replicas=None, @@ -212,6 +208,7 @@ class TorchTrainer: self.max_replicas = num_workers self.use_fp16 = use_fp16 + self.use_tqdm = use_tqdm if apex_args and not isinstance(apex_args, dict): raise ValueError("apex_args needs to be a dict object.") @@ -221,10 +218,6 @@ class TorchTrainer: self._num_failures = 0 self._last_resize = float("-inf") - self.handlers = [] - if tqdm: - self.handlers.append(TqdmHandler()) - _validate_scheduler_step_freq(scheduler_step_freq) self.scheduler_step_freq = scheduler_step_freq @@ -256,68 +249,71 @@ class TorchTrainer: batch_size_per_worker = self._configure_and_split_batch(num_workers) if batch_size_per_worker: worker_config[BATCH_SIZE] = batch_size_per_worker + + self.local_worker = None + self.remote_workers = [] + if num_workers == 1: - # Generate actor class - Runner = ray.remote( - num_cpus=1, num_gpus=int(self.use_gpu))(TorchRunner) - # Start workers - self.workers = [ - Runner.remote( - model_creator=self.model_creator, - data_creator=self.data_creator, - optimizer_creator=self.optimizer_creator, - loss_creator=self.loss_creator, - scheduler_creator=self.scheduler_creator, - training_operator_cls=self.training_operator_cls, - config=worker_config, - use_fp16=self.use_fp16, - apex_args=self.apex_args, - scheduler_step_freq=self.scheduler_step_freq, - ) - ] + # Start local worker + self.local_worker = TorchRunner( + model_creator=self.model_creator, + data_creator=self.data_creator, + optimizer_creator=self.optimizer_creator, + loss_creator=self.loss_creator, + scheduler_creator=self.scheduler_creator, + training_operator_cls=self.training_operator_cls, + config=worker_config, + use_fp16=self.use_fp16, + use_tqdm=self.use_tqdm, + apex_args=self.apex_args, + scheduler_step_freq=self.scheduler_step_freq) + if self.initialization_hook: self.apply_all_workers(self.initialization_hook) - # Get setup tasks in order to throw errors on failure - ray.get(self.workers[0].setup.remote()) - ray.get(self.workers[0].set_reporters.remote( - [h.create_reporter() for h in self.handlers])) + + self.local_worker.setup() else: + params = dict( + model_creator=self.model_creator, + data_creator=self.data_creator, + optimizer_creator=self.optimizer_creator, + loss_creator=self.loss_creator, + scheduler_creator=self.scheduler_creator, + backend=self.backend, + training_operator_cls=self.training_operator_cls, + config=worker_config, + use_fp16=self.use_fp16, + use_tqdm=self.use_tqdm, + apex_args=self.apex_args, + scheduler_step_freq=self.scheduler_step_freq) + + # Start local worker + self.local_worker = LocalDistributedRunner( + num_cpus=1, num_gpus=int(self.use_gpu), **params) + # Generate actor class - Runner = ray.remote( + RemoteRunner = ray.remote( num_cpus=1, num_gpus=int(self.use_gpu))(DistributedTorchRunner) # Start workers - self.workers = [ - Runner.remote( - model_creator=self.model_creator, - data_creator=self.data_creator, - optimizer_creator=self.optimizer_creator, - loss_creator=self.loss_creator, - scheduler_creator=self.scheduler_creator, - backend=self.backend, - training_operator_cls=self.training_operator_cls, - config=worker_config, - use_fp16=self.use_fp16, - apex_args=self.apex_args, - scheduler_step_freq=self.scheduler_step_freq) - for i in range(num_workers) + self.remote_workers = [ + RemoteRunner.remote(**params) for i in range(num_workers - 1) ] if self.initialization_hook: self.apply_all_workers(self.initialization_hook) # Compute URL for initializing distributed PyTorch - ip = ray.get(self.workers[0].get_node_ip.remote()) - port = ray.get(self.workers[0].find_free_port.remote()) + ip = ray.services.get_node_ip_address() + port = self.local_worker.find_free_port() + address = "tcp://{ip}:{port}".format(ip=ip, port=port) + + remote_setups = [ + worker.setup.remote(address, i + 1, num_workers) + for i, worker in enumerate(self.remote_workers) + ] + self.local_worker.setup(address, 0, num_workers) # Get setup tasks in order to throw errors on failure - ray.get([ - worker.setup.remote(address, i, len(self.workers)) - for i, worker in enumerate(self.workers) - ]) - ray.get([ - w.set_reporters.remote( - [h.create_reporter() for h in self.handlers]) - for w in self.workers - ]) + ray.get(remote_setups) def train(self, num_steps=None, @@ -374,9 +370,6 @@ class TorchTrainer: logger.info("Resize opportunity detected. Attempting to scale up.") self._resize_workers(checkpoint=checkpoint) - for h in self.handlers: - h.record_train_info(info, num_steps) - success, worker_stats = self._train_epoch( num_steps=num_steps, profile=profile, info=info) # Fault handling @@ -386,14 +379,13 @@ class TorchTrainer: else: self._num_failures += 1 self._resize_workers(checkpoint=checkpoint) - logger.info( - "Retrying training step with %d workers." % len(self.workers)) + logger.info("Retrying training step with %d workers." % + (len(self.remote_workers) + 1)) success, worker_stats = self._train_epoch( num_steps=num_steps, profile=profile, info=info) if not success: raise RuntimeError("Training run failed.") - worker_stats = ray.get(worker_stats) if reduce_results: return self._process_stats(worker_stats) else: @@ -413,42 +405,30 @@ class TorchTrainer: stats[stat_key] = worker_stats[0][stat_key] return stats - def _train_epoch(self, - num_steps=None, - profile=False, - info=None, - batch_logs_handler=None): - worker_trains = [ - w.train_epoch.remote( - num_steps=num_steps, profile=profile, info=info) - for w in self.workers + def _train_epoch(self, num_steps=None, profile=False, info=None): + params = dict(num_steps=num_steps, profile=profile, info=info) + + remote_worker_stats = [ + w.train_epoch.remote(**params) for w in self.remote_workers ] - if not self.handlers: - success = check_for_failure(worker_trains) - return success, worker_trains - - unfinished = worker_trains try: - while len(unfinished) > 0: - finished, unfinished = ray.wait( - unfinished, timeout=BATCH_LOGS_RATE_LIMIT) + local_worker_stats = self.local_worker.train_epoch(**params) + except RuntimeError as err: + if "gloo" in err.args[0] and "Timed out" in err.args[0]: + logger.warning(err) + return False, None + if "NCCL" in err.args[0]: # there is no specific error message + logger.warning(err) + return False, None - # throw errors on agent failure - finished = ray.get(finished) + raise err - futures = [h.update() for h in self.handlers] - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(asyncio.wait(futures)) - loop.close() + success = check_for_failure(remote_worker_stats) + if success: + return success, [local_worker_stats] + ray.get(remote_worker_stats) - return True, worker_trains - except RayActorError as exc: - logger.exception(str(exc)) - return False, worker_trains + return success, None def apply_all_workers(self, fn): """Run a function on all operators on the workers. @@ -460,7 +440,9 @@ class TorchTrainer: A list of objects returned by ``fn`` on each worker. """ - return ray.get([w.apply.remote(fn) for w in self.workers]) + remote_calls = [w.apply.remote(fn) for w in self.remote_workers] + local_call = self.local_worker.apply(fn) + return [local_call] + ray.get(remote_calls) def apply_all_operators(self, fn): """Run a function on all operators on the workers. @@ -473,7 +455,11 @@ class TorchTrainer: A list of objects returned by ``fn`` on each operator. """ - return ray.get([w.apply_operator.remote(fn) for w in self.workers]) + remote_calls = [ + w.apply_operator.remote(fn) for w in self.remote_workers + ] + local_call = self.local_worker.apply_operator(fn) + return [local_call] + ray.get(remote_calls) def validate(self, num_steps=None, profile=False, info=None): """Evaluates the model on the validation data set. @@ -491,12 +477,15 @@ class TorchTrainer: You can provide custom metrics by passing in a custom ``training_operator_cls``. """ - worker_stats = ray.get([ - w.validate.remote(num_steps=num_steps, profile=profile, info=info) - for w in self.workers - ]) + params = dict(num_steps=num_steps, profile=profile, info=info) - return self._process_stats(worker_stats) + remote_worker_stats = [ + w.validate.remote(**params) for w in self.remote_workers + ] + local_worker_stats = self.local_worker.validate(**params) + + return self._process_stats([local_worker_stats] + + ray.get(remote_worker_stats)) def update_scheduler(self, metric): """Calls ``scheduler.step(metric)`` on all schedulers. @@ -509,7 +498,7 @@ class TorchTrainer: def get_model(self): """Returns the learned model(s).""" models = self.model_creator(self.config) - state = ray.get(self.workers[0].get_state.remote()) + state = self.local_worker.get_state() if len(state["models"]) == 1: models.load_state_dict(state["models"][0]) else: @@ -517,6 +506,18 @@ class TorchTrainer: model.load_state_dict(state_dict) return models + def state_dict(self): + return self.local_worker.get_state() + + def load_state_dict(self, state): + state_id = ray.put(state) + + remote_calls = [ + worker.set_state.remote(state_id) for worker in self.remote_workers + ] + self.local_worker.set_state(state) + ray.get(remote_calls) + def save(self, checkpoint): """Saves the model(s) to the provided checkpoint. @@ -526,8 +527,7 @@ class TorchTrainer: Returns: checkpoint (str): Path to target checkpoint file. """ - state = ray.get(self.workers[0].get_state.remote()) - torch.save(state, checkpoint) + torch.save(self.state_dict(), checkpoint) return checkpoint def restore(self, checkpoint): @@ -537,36 +537,67 @@ class TorchTrainer: checkpoint (str): Path to target checkpoint file. """ state = torch.load(checkpoint) - state_id = ray.put(state) - ray.get([worker.set_state.remote(state_id) for worker in self.workers]) + self.load_state_dict(state) def shutdown(self, force=False): """Shuts down workers and releases resources.""" if not force: - cleanup = [worker.shutdown.remote() for worker in self.workers] - ray.get(cleanup) - [worker.__ray_terminate__.remote() for worker in self.workers] - else: - for worker in self.workers: - logger.warning("Killing worker {}.".format(worker)) - worker.__ray_kill__() + cleanup = [ + worker.shutdown.remote() for worker in self.remote_workers + ] + self.local_worker.shutdown() + try: + ray.get(cleanup) + [ + worker.__ray_terminate__.remote() + for worker in self.remote_workers + ] + except RayActorError: + logger.warning( + "Failed to shutdown gracefully, forcing a shutdown.") - self.workers = [] + for worker in self.remote_workers: + logger.warning("Killing worker {}.".format(worker)) + ray.kill(worker) + else: + self.local_worker.shutdown() + for worker in self.remote_workers: + logger.warning("Killing worker {}.".format(worker)) + ray.kill(worker) + + self.local_worker = None + self.remote_workers = [] + + def _reset(self): + """Terminates models without giving up local resource reservation.""" + self.local_worker.shutdown(cleanup=False) + for worker in self.remote_workers: + logger.warning("Killing worker {}.".format(worker)) + ray.kill(worker) + self.local_worker = None + self.remote_workers = [] + + def _check_potential_remote_workers_size(self): + # ASSUME 1 GPU + 1 CPU is already reserved for the local worker + remote_resources = ray.available_resources() + max_remote_workers = self.max_replicas - 1 + new_remote_workers = min( + remote_resources.get("CPU", 0), max_remote_workers) + if self.use_gpu: + new_remote_workers = min( + remote_resources.get("GPU", 0), new_remote_workers) + return new_remote_workers def _resize_workers(self, checkpoint, max_retries=10): - # check available resources - self.shutdown(force=True) + self._reset() assert checkpoint, "Cannot restore without checkpoint." time.sleep(1) for i in range(max_retries): - resources = ray.available_resources() - new_workers = min(resources.get("CPU", 0), self.max_replicas) - if self.use_gpu: - new_workers = min(resources.get("GPU", 0), new_workers) - if new_workers: + new_remote_workers = self._check_potential_remote_workers_size() + if new_remote_workers: self._last_resize = time.time() - self._start_workers(int(new_workers)) + self._start_workers(int(new_remote_workers) + 1) self.restore(checkpoint) return else: @@ -578,26 +609,24 @@ class TorchTrainer: def _should_resize(self): """Returns True if past cooldown and exists resources to scale up.""" - worker_gap = self.max_replicas - len(self.workers) + worker_gap = self.max_replicas - 1 - len(self.remote_workers) past_cooldown = (time.time() - self._last_resize) > RESIZE_COOLDOWN_S if past_cooldown and worker_gap: - resources = ray.available_resources() - potential_workers = min(resources.get("CPU", 0), self.max_replicas) - if self.use_gpu: - potential_workers = min( - resources.get("GPU", 0), potential_workers) - return potential_workers > 0 + # Assume 1 resource is already reserved for local worker. + potential_remote_size = self._check_potential_remote_workers_size() + return potential_remote_size > 0 return False class TorchTrainable(Trainable): @classmethod def default_resource_request(cls, config): + remote_worker_count = config["num_workers"] - 1 return Resources( - cpu=0, - gpu=0, - extra_cpu=config["num_workers"], - extra_gpu=int(config["use_gpu"]) * config["num_workers"]) + cpu=1, + gpu=int(config["use_gpu"]), + extra_cpu=int(remote_worker_count), + extra_gpu=int(int(config["use_gpu"]) * remote_worker_count)) def _setup(self, config): self._trainer = TorchTrainer(**config) diff --git a/python/ray/util/sgd/torch/tqdm_handler.py b/python/ray/util/sgd/torch/tqdm_handler.py deleted file mode 100644 index 7e25cac4f..000000000 --- a/python/ray/util/sgd/torch/tqdm_handler.py +++ /dev/null @@ -1,116 +0,0 @@ -import asyncio -import time - -from tqdm import tqdm - -import ray -from ray.util.sgd.torch.constants import BATCH_LOGS_RATE_LIMIT - - -@ray.remote(num_cpus=0) -class _ReporterActor: - def __init__(self): - # we need the new_data field to allow sending back None as the legs - self._logs = {"new_data": False, "data": None} - self._setup = {"new_data": False, "data": None} - - def _send_setup(self, data): - self._setup = {"new_data": True, "data": data} - - def _send_logs(self, data): - self._logs = {"new_data": True, "data": data} - - def _read_logs(self): - res = self._logs - - self._logs = {"new_data": False, "data": None} - - return res - - def _read_setup(self): - res = self._setup - - self._setup = {"new_data": False, "data": None} - - return res - - -class TqdmReporter: - def __init__(self, actor): - self.actor = actor - - self.last_packet_time = 0 - - def _send_setup(self, packet): - ray.get(self.actor._send_setup.remote(packet)) - - def _send_logs(self, packet): - cur_time = time.monotonic() - if cur_time - self.last_packet_time < BATCH_LOGS_RATE_LIMIT: - return - - self.last_packet_time = cur_time - ray.get(self.actor._send_logs.remote(packet)) - - def on_epoch_begin(self, info, training_op): - if training_op.world_rank != 0: - return - - self.last_packet_time = 0 - - self._send_setup({"loader_len": len(training_op.train_loader)}) - - def on_batch_end(self, batch_info, metrics, training_op): - if training_op.world_rank != 0: - return - - pbar_metrics = {} - if "train_loss" in metrics: - pbar_metrics["loss"] = metrics["train_loss"] - - self._send_logs({ - "batch_idx": batch_info["batch_idx"], - "pbar_metrics": pbar_metrics - }) - - -class TqdmHandler: - def __init__(self): - self.batch_pbar = None - self.reporter_actor = _ReporterActor.remote() - - def create_reporter(self): - return TqdmReporter(self.reporter_actor) - - def handle_setup_packet(self, packet): - n = self.num_steps - if n is None: - n = packet["loader_len"] - - desc = "" - if self.train_info is not None and "epoch_idx" in self.train_info: - if "num_epochs" in self.train_info: - desc = "{}/{}e".format(self.train_info["epoch_idx"] + 1, - self.train_info["num_epochs"]) - else: - desc = "{}e".format(self.train_info["epoch_idx"] + 1) - - self.batch_pbar = tqdm(total=n, desc=desc, unit="batch", leave=False) - - def handle_logs_packet(self, packet): - self.batch_pbar.n = packet["batch_idx"] + 1 - self.batch_pbar.set_postfix(packet["pbar_metrics"]) - - def record_train_info(self, info, num_steps): - self.train_info = info - self.num_steps = num_steps - - async def update(self): - setup_read, logs_read = await asyncio.gather( - self.reporter_actor._read_setup.remote(), - self.reporter_actor._read_logs.remote()) - - if setup_read["new_data"]: - self.handle_setup_packet(setup_read["data"]) - if logs_read["new_data"]: - self.handle_logs_packet(logs_read["data"]) diff --git a/python/ray/util/sgd/torch/training_operator.py b/python/ray/util/sgd/torch/training_operator.py index c641be579..ecd72e213 100644 --- a/python/ray/util/sgd/torch/training_operator.py +++ b/python/ray/util/sgd/torch/training_operator.py @@ -1,10 +1,11 @@ import collections +from tqdm import tqdm import torch from ray.util.sgd.utils import (TimerCollection, AverageMeterCollection, NUM_SAMPLES) -from ray.util.sgd.torch.constants import (SCHEDULER_STEP_EPOCH, +from ray.util.sgd.torch.constants import (SCHEDULER_STEP_EPOCH, NUM_STEPS, SCHEDULER_STEP_BATCH, SCHEDULER_STEP) amp = None @@ -55,7 +56,8 @@ class TrainingOperator: world_rank, criterion=None, schedulers=None, - use_fp16=False): + use_fp16=False, + use_tqdm=False): # You are not expected to override this method. self._models = models # List of models assert isinstance(models, collections.Iterable), ( @@ -74,6 +76,7 @@ class TrainingOperator: type(schedulers))) self._config = config self._use_fp16 = use_fp16 + self._use_tqdm = use_tqdm self.global_step = 0 if type(self) is TrainingOperator: @@ -84,12 +87,8 @@ class TrainingOperator: "TrainingOperator if using multi-scheduler, " "multi-model or multi-optimizer training/validation.") self.timers = TimerCollection() - self.reporters = [] self.setup(config) - def set_reporters(self, reporters): - self.reporters = reporters - def _set_timers(self, timers): """Passes in the timers from the Runner.""" self.timers = timers @@ -142,8 +141,19 @@ class TrainingOperator: Returns: A dict of metrics from training. """ - for r in self.reporters: - r.on_epoch_begin(info, self) + if self.use_tqdm and self.world_rank == 0: + desc = "" + if info is not None and "epoch_idx" in info: + if "num_epochs" in info: + desc = "{}/{}e".format(info["epoch_idx"] + 1, + info["num_epochs"]) + else: + desc = "{}e".format(info["epoch_idx"] + 1) + _progress_bar = tqdm( + total=info[NUM_STEPS] or len(self.train_loader), + desc=desc, + unit="batch", + leave=False) metric_meters = AverageMeterCollection() @@ -156,8 +166,10 @@ class TrainingOperator: batch_info.update(info) metrics = self.train_batch(batch, batch_info=batch_info) - for r in self.reporters: - r.on_batch_end(batch_info, metrics, self) + if self.use_tqdm and self.world_rank == 0: + _progress_bar.n = batch_idx + 1 + if "train_loss" in metrics: + _progress_bar.set_postfix({"loss": metrics["train_loss"]}) if self.scheduler and batch_info.get( SCHEDULER_STEP) == SCHEDULER_STEP_BATCH: @@ -376,6 +388,11 @@ class TrainingOperator: """Whether the model and optimizer have been FP16 enabled.""" return self._use_fp16 + @property + def use_tqdm(self): + """Whether tqdm progress bars are enabled.""" + return self._use_tqdm + class _TestingOperator(TrainingOperator): def train_epoch(self, iterator, info):