[sgd] fault tolerance for pytorch + revamp documentation (#6465)

This commit is contained in:
Richard Liaw
2020-01-16 18:38:27 -08:00
committed by GitHub
parent e5ad4e6f8d
commit 232be5a058
12 changed files with 365 additions and 60 deletions
@@ -88,11 +88,18 @@ class DistributedPyTorchRunner(PyTorchRunner):
def get_state(self):
    """Returns the state of the runner.

    Returns:
        dict: A dictionary with the keys ``"epoch"``, ``"models"`` (one
            state dict per model, with every entry passed through
            ``.cpu()``), ``"optimizers"`` (one state dict per optimizer)
            and ``"stats"`` (the result of ``self.stats()``).
    """
    # This is so that we create a duplicate of weights into CPU rather than
    # move the model weights entirely out of the GPU, so that we can
    # resume training while saving intermediate checkpoints.
    # NOTE: `model.module.cpu()` must not be called here; that would move
    # the wrapped module itself instead of copying its tensors.
    cpu_state_dicts = []
    for model in self.models:
        # Re-assigning existing keys while iterating is safe (no resize),
        # and keeps the mapping object returned by `state_dict()` intact.
        state_dict = model.module.state_dict()
        for name, tensor in state_dict.items():
            state_dict[name] = tensor.cpu()
        cpu_state_dicts.append(state_dict)

    return {
        "epoch": self.epoch,
        "models": cpu_state_dicts,
        "optimizers": [opt.state_dict() for opt in self.optimizers],
        "stats": self.stats()
    }
@@ -45,6 +45,7 @@ def data_creator(batch_size, config):
]))
# Create the dataloader
train_sampler = None
if distributed.is_initialized():
train_sampler = DistributedSampler(dataset)
dataloader = torch.utils.data.DataLoader(
@@ -238,8 +239,8 @@ def train_example(num_replicas=1, use_gpu=False, test_mode=False):
use_gpu=use_gpu,
batch_size=16 if test_mode else 512,
backend="nccl" if use_gpu else "gloo")
for i in range(5):
stats = trainer.train()
for i in range(10):
stats = trainer.train(max_retries=3)
print(stats)
return trainer
@@ -4,6 +4,8 @@ import torch
import torch.distributed as dist
import logging
import numbers
import tempfile
import time
import ray
@@ -15,6 +17,7 @@ from ray.experimental.sgd import utils
from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner
logger = logging.getLogger(__name__)
RESIZE_COOLDOWN_S = 10
class PyTorchTrainer:
@@ -74,8 +77,12 @@ class PyTorchTrainer:
"https://github.com/pytorch/examples/issues/467."))
self.model_creator = model_creator
self.data_creator = data_creator
self.train_function = train_function
self.optimizer_creator = optimizer_creator
self.loss_creator = loss_creator
self.validation_function = validation_function
self.initialization_hook = initialization_hook
self.config = {} if config is None else config
self.optimizer_timer = utils.TimerStat(window_size=1)
@@ -83,58 +90,69 @@ class PyTorchTrainer:
backend = "nccl" if use_gpu else "gloo"
logger.info("Using {} as backend.".format(backend))
self.backend = backend
self.use_gpu = use_gpu
self.batch_size = batch_size
self.max_replicas = num_replicas
self.temp_dir = tempfile.mkdtemp(prefix="raysgd")
self._num_failures = 0
self._last_resize = float("-inf")
self._start_workers(self.max_replicas)
def _start_workers(self, num_replicas):
logger.info(f"start_workers: Setting %d replicas." % num_replicas)
if num_replicas == 1:
# Generate actor class
Runner = ray.remote(
num_cpus=1, num_gpus=int(use_gpu))(PyTorchRunner)
num_cpus=1, num_gpus=int(self.use_gpu))(PyTorchRunner)
# Start workers
self.workers = [
Runner.remote(
model_creator,
data_creator,
optimizer_creator,
loss_creator,
train_function=train_function,
validation_function=validation_function,
self.model_creator,
self.data_creator,
self.optimizer_creator,
self.loss_creator,
train_function=self.train_function,
validation_function=self.validation_function,
config=self.config,
batch_size=batch_size)
batch_size=self.batch_size)
]
if initialization_hook:
self.apply_all_workers(initialization_hook)
if self.initialization_hook:
self.apply_all_workers(self.initialization_hook)
# Get setup tasks in order to throw errors on failure
ray.get(self.workers[0].setup.remote())
else:
# Generate actor class
Runner = ray.remote(
num_cpus=1, num_gpus=int(use_gpu))(DistributedPyTorchRunner)
num_cpus=1,
num_gpus=int(self.use_gpu))(DistributedPyTorchRunner)
# Compute batch size per replica
batch_size_per_replica = batch_size // num_replicas
if batch_size % num_replicas > 0:
batch_size_per_replica = self.batch_size // num_replicas
if self.batch_size % num_replicas > 0:
new_batch_size = batch_size_per_replica * num_replicas
logger.warning(
("Changing batch size from {old_batch_size} to "
"{new_batch_size} to evenly distribute batches across "
"{num_replicas} replicas.").format(
old_batch_size=batch_size,
old_batch_size=self.batch_size,
new_batch_size=new_batch_size,
num_replicas=num_replicas))
# Start workers
self.workers = [
Runner.remote(
model_creator,
data_creator,
optimizer_creator,
loss_creator,
backend=backend,
train_function=train_function,
validation_function=validation_function,
self.model_creator,
self.data_creator,
self.optimizer_creator,
self.loss_creator,
backend=self.backend,
train_function=self.train_function,
validation_function=self.validation_function,
config=self.config,
batch_size=batch_size_per_replica)
for i in range(num_replicas)
]
if initialization_hook:
self.apply_all_workers(initialization_hook)
if self.initialization_hook:
self.apply_all_workers(self.initialization_hook)
# Compute URL for initializing distributed PyTorch
ip = ray.get(self.workers[0].get_node_ip.remote())
@@ -146,13 +164,51 @@ class PyTorchTrainer:
for i, worker in enumerate(self.workers)
])
def train(self):
def train(self, max_retries=10, checkpoint="auto"):
"""Runs a training epoch.
Runs an average over all values returned from workers.
Runs an average over all values returned from workers. Set
`max_retries` to enable fault handling in case of instance preemption.
Args:
max_retries (int): Must be non-negative. If set to N, will
kill all current workers, query the Ray global state for
total available resources, and re-launch up to the
available resources. Behavior is not well-defined
in case of shared cluster usage.
checkpoint (str): Path to checkpoint to restore from if retrying.
If max_retries is set and checkpoint == "auto", PyTorchTrainer
will save a checkpoint before starting to train.
"""
assert max_retries >= 0, "`max_retries` must be non-negative."
if max_retries:
if checkpoint == "auto":
logger.debug("Retrying detected. Automatically checkpointing.")
checkpoint = self.save(
os.path.join(self.temp_dir, "tmp_checkpoint"))
elif not checkpoint:
raise ValueError("Cannot retry from empty checkpoint.")
if checkpoint and self._should_resize():
logger.info("Resize opportunity detected. Attempting to scale up.")
self._resize_workers(checkpoint=checkpoint)
with self.optimizer_timer:
worker_stats = ray.get([w.step.remote() for w in self.workers])
success, worker_stats = self._train_step()
# Fault handling
for i in range(max_retries):
if success:
break
else:
self._num_failures += 1
self._resize_workers(checkpoint=checkpoint)
logger.info("Retrying training step with %d workers." % len(
self.workers))
success, worker_stats = self._train_step()
if not success:
raise RuntimeError("Training run failed.")
worker_stats = ray.get(worker_stats)
train_stats = {}
for stat_key in worker_stats[0]:
@@ -163,6 +219,11 @@ class PyTorchTrainer:
train_stats[stat_key] = worker_stats[0][stat_key]
return train_stats
def _train_step(self):
    """Launches one `step` call per worker and waits for the outcome.

    Returns:
        tuple: ``(success, pending)`` where ``success`` is the result of
            ``utils.check_for_failure`` on the launched calls and
            ``pending`` is the list of values returned by
            ``step.remote()``, in worker order.
    """
    pending = []
    for worker in self.workers:
        pending.append(worker.step.remote())
    return utils.check_for_failure(pending), pending
def apply_all_workers(self, fn):
    """Runs `fn` on every worker via `apply_fn` and collects the results.

    Args:
        fn: Callable forwarded to each worker's `apply_fn`.

    Returns:
        The value of `ray.get` over all workers' calls, in worker order.
    """
    pending = []
    for worker in self.workers:
        pending.append(worker.apply_fn.remote(fn))
    return ray.get(pending)
@@ -211,11 +272,54 @@ class PyTorchTrainer:
state_id = ray.put(state)
ray.get([worker.set_state.remote(state_id) for worker in self.workers])
def shutdown(self, force=False):
    """Shuts down workers and releases resources.

    Args:
        force (bool): If False (default), each worker is asked to shut
            down and terminate via remote calls. If True, each worker is
            killed with `__ray_kill__` after logging a warning.
    """
    for worker in self.workers:
        if not force:
            worker.shutdown.remote()
            worker.__ray_terminate__.remote()
        else:
            logger.warning("Killing worker {}.".format(worker))
            worker.__ray_kill__()
    # Drop the handles so later calls do not address dead workers.
    self.workers = []
def _resize_workers(self, checkpoint, max_retries=10):
    """Kills all workers and relaunches as many as resources allow.

    Args:
        checkpoint (str): Checkpoint handed to `self.restore` once the
            new workers have been started. Must be non-empty.
        max_retries (int): Number of times available resources are
            polled. After the i-th unsuccessful poll this sleeps
            ``2**i`` seconds.

    Raises:
        AssertionError: If `checkpoint` is empty.
        RuntimeError: If no worker could be started within
            `max_retries` polls.
    """
    # Validate before tearing anything down: without a checkpoint the
    # killed workers' state could not be recovered.
    assert checkpoint, "Cannot restore without checkpoint."
    self.shutdown(force=True)
    # NOTE(review): presumably gives Ray time to release the killed
    # workers' resources before they are queried -- confirm.
    time.sleep(1)
    for i in range(max_retries):
        resources = ray.available_resources()
        new_workers = min(resources.get("CPU", 0), self.max_replicas)
        if self.use_gpu:
            new_workers = min(resources.get("GPU", 0), new_workers)
        # Resource counts may be fractional; floor first so that e.g.
        # 0.5 free CPUs is not mistaken for room for a worker (which
        # would otherwise start zero replicas).
        new_workers = int(new_workers)
        if new_workers:
            self._last_resize = time.time()
            self._start_workers(new_workers)
            self.restore(checkpoint)
            return
        delay = 2**i
        logger.info("Resources: {}".format(resources))
        logger.warning("No new workers found. Retrying in %d sec." % delay)
        time.sleep(delay)
    raise RuntimeError("Exceeded max_retries for relaunching workers.")
def _should_resize(self):
    """Returns True if past cooldown and resources exist to scale up.

    Returns:
        bool: True only when fewer than `self.max_replicas` workers are
            running, more than `RESIZE_COOLDOWN_S` seconds have passed
            since the last resize, and the currently available CPUs
            (and GPUs, when `self.use_gpu` is set) leave room for at
            least one whole additional worker.
    """
    worker_gap = self.max_replicas - len(self.workers)
    if not worker_gap:
        return False
    past_cooldown = (time.time() - self._last_resize) > RESIZE_COOLDOWN_S
    if not past_cooldown:
        return False

    resources = ray.available_resources()
    potential_workers = min(resources.get("CPU", 0), self.max_replicas)
    if self.use_gpu:
        potential_workers = min(resources.get("GPU", 0), potential_workers)
    # Floor before comparing: a fractional leftover (e.g. 0.5 CPU) cannot
    # host a worker, so it must not trigger a kill-and-relaunch cycle.
    return int(potential_workers) > 0
class PyTorchTrainable(Trainable):
+117 -14
View File
@@ -1,21 +1,26 @@
import os
import pytest
import tempfile
from unittest.mock import patch
import pytest
import time
import torch
import torch.nn as nn
import torch.distributed as dist
import ray
from ray import tune
from ray.tests.conftest import ray_start_2_cpus # noqa: F401
from ray.experimental.sgd.pytorch import PyTorchTrainer, PyTorchTrainable
from ray.experimental.sgd.pytorch.utils import train
from ray.experimental.sgd.utils import check_for_failure
from ray.experimental.sgd.examples.train_example import (
model_creator, optimizer_creator, data_creator)
model_creator, optimizer_creator, data_creator, LinearDataset)
@pytest.mark.parametrize( # noqa: F811
"num_replicas", [1, 2] if dist.is_available() else [1])
@pytest.mark.parametrize("num_replicas", [1, 2]
if dist.is_available() else [1])
def test_train(ray_start_2_cpus, num_replicas): # noqa: F811
trainer = PyTorchTrainer(
model_creator,
@@ -36,8 +41,8 @@ def test_train(ray_start_2_cpus, num_replicas): # noqa: F811
assert validation_loss2 <= validation_loss1
@pytest.mark.parametrize( # noqa: F811
"num_replicas", [1, 2] if dist.is_available() else [1])
@pytest.mark.parametrize("num_replicas", [1, 2]
if dist.is_available() else [1])
def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811
def custom_train(models, dataloader, criterion, optimizers, config):
result = {}
@@ -94,15 +99,15 @@ def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811
assert torch.equal(model1_state_dict[k], model2_state_dict[k])
@pytest.mark.parametrize( # noqa: F811
"num_replicas", [1, 2] if dist.is_available() else [1])
@pytest.mark.parametrize("num_replicas", [1, 2]
if dist.is_available() else [1])
def test_tune_train(ray_start_2_cpus, num_replicas): # noqa: F811
config = {
"model_creator": tune.function(model_creator),
"data_creator": tune.function(data_creator),
"optimizer_creator": tune.function(optimizer_creator),
"loss_creator": tune.function(lambda config: nn.MSELoss()),
"model_creator": model_creator,
"data_creator": data_creator,
"optimizer_creator": optimizer_creator,
"loss_creator": lambda config: nn.MSELoss(),
"num_replicas": num_replicas,
"use_gpu": False,
"batch_size": 512,
@@ -127,8 +132,8 @@ def test_tune_train(ray_start_2_cpus, num_replicas): # noqa: F811
assert validation_loss2 <= validation_loss1
@pytest.mark.parametrize( # noqa: F811
"num_replicas", [1, 2] if dist.is_available() else [1])
@pytest.mark.parametrize("num_replicas", [1, 2]
if dist.is_available() else [1])
def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811
trainer1 = PyTorchTrainer(
model_creator,
@@ -164,3 +169,101 @@ def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811
for k in model1_state_dict:
assert torch.equal(model1_state_dict[k], model2_state_dict[k])
def test_fail_with_recover(ray_start_2_cpus):  # noqa: F811
    """Training raises when worker failures outnumber the allowed retries.

    The patched `_train_step` kills worker 0 for as long as fewer than
    three failures were recorded, while `train` is only given one retry,
    so a `RuntimeError` is expected.
    """
    if not dist.is_available():
        # Report the test as skipped instead of silently passing.
        pytest.skip("torch.distributed is not available.")

    def single_loader(batch_size, config):
        train_dataset = LinearDataset(2, 5, size=1000000)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size)
        return train_loader

    def step_with_fail(self):
        worker_stats = [w.step.remote() for w in self.workers]
        if self._num_failures < 3:
            # Wait so the step is in flight before the worker is killed.
            time.sleep(1)
            self.workers[0].__ray_kill__()
        success = check_for_failure(worker_stats)
        return success, worker_stats

    with patch.object(PyTorchTrainer, "_train_step", step_with_fail):
        trainer1 = PyTorchTrainer(
            model_creator,
            single_loader,
            optimizer_creator,
            batch_size=100000,
            loss_creator=lambda config: nn.MSELoss(),
            num_replicas=2)

        with pytest.raises(RuntimeError):
            trainer1.train(max_retries=1)
def test_resize(ray_start_2_cpus):  # noqa: F811
    """Training recovers from one failure with a single worker.

    The patched `_train_step` kills worker 0 only while no failure has
    been recorded yet. A long-sleeping remote task is launched before
    training -- presumably to occupy one of the two CPUs so that only
    one worker can be relaunched (TODO confirm).
    """
    if not dist.is_available():
        # Report the test as skipped instead of silently passing.
        pytest.skip("torch.distributed is not available.")

    def single_loader(batch_size, config):
        train_dataset = LinearDataset(2, 5, size=1000000)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size)
        return train_loader

    def step_with_fail(self):
        worker_stats = [w.step.remote() for w in self.workers]
        if self._num_failures < 1:
            # Wait so the step is in flight before the worker is killed.
            time.sleep(1)
            self.workers[0].__ray_kill__()
        success = check_for_failure(worker_stats)
        return success, worker_stats

    with patch.object(PyTorchTrainer, "_train_step", step_with_fail):
        trainer1 = PyTorchTrainer(
            model_creator,
            single_loader,
            optimizer_creator,
            batch_size=100000,
            loss_creator=lambda config: nn.MSELoss(),
            num_replicas=2)

        @ray.remote
        def try_test():
            import time
            time.sleep(100)

        try_test.remote()
        trainer1.train(max_retries=1)
        assert len(trainer1.workers) == 1
def test_fail_twice(ray_start_2_cpus):  # noqa: F811
    """Training survives two consecutive worker failures.

    The patched `_train_step` kills worker 0 while fewer than two
    failures were recorded; with two retries allowed, `train` is
    expected to finish without raising.
    """
    if not dist.is_available():
        # Report the test as skipped instead of silently passing.
        pytest.skip("torch.distributed is not available.")

    def single_loader(batch_size, config):
        train_dataset = LinearDataset(2, 5, size=1000000)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size)
        return train_loader

    def step_with_fail(self):
        worker_stats = [w.step.remote() for w in self.workers]
        if self._num_failures < 2:
            # Wait so the step is in flight before the worker is killed.
            time.sleep(1)
            self.workers[0].__ray_kill__()
        success = check_for_failure(worker_stats)
        return success, worker_stats

    with patch.object(PyTorchTrainer, "_train_step", step_with_fail):
        trainer1 = PyTorchTrainer(
            model_creator,
            single_loader,
            optimizer_creator,
            batch_size=100000,
            loss_creator=lambda config: nn.MSELoss(),
            num_replicas=2)

        trainer1.train(max_retries=2)
+28
View File
@@ -1,8 +1,14 @@
from contextlib import closing
import logging
import numpy as np
import socket
import time
import ray
from ray.exceptions import RayActorError
logger = logging.getLogger(__name__)
class TimerStat:
"""A running stat for conveniently logging the duration of a code block.
@@ -121,3 +127,25 @@ class AverageMeter:
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def check_for_failure(remote_values):
    """Checks remote values for any that returned and failed.

    Args:
        remote_values (list): List of object IDs representing functions
            that may fail in the middle of execution. For example, running
            a SGD training loop in multiple parallel actor calls.

    Returns:
        Bool for success in executing given remote tasks. An empty list
        counts as success. Only `RayActorError` is treated as a failure
        (logged, then False is returned); other exceptions propagate.
    """
    unfinished = remote_values
    try:
        while len(unfinished) > 0:
            finished, unfinished = ray.wait(unfinished)
            # Fetch as soon as something is ready so that a failed call
            # raises right away instead of after all calls are done.
            ray.get(finished)
        return True
    except RayActorError as exc:
        logger.exception(str(exc))
    return False
@@ -279,7 +279,7 @@ if __name__ == "__main__":
MemNNModel,
name="pbt_babi_memnn",
scheduler=pbt,
stop={"training_iteration": 20 if args.smoke_test else 100},
stop={"training_iteration": 10 if args.smoke_test else 100},
num_samples=4,
config={
"batch_size": 32,