mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 01:23:10 +08:00
[sgd] fault tolerance for pytorch + revamp documentation (#6465)
This commit is contained in:
@@ -88,11 +88,18 @@ class DistributedPyTorchRunner(PyTorchRunner):
|
||||
|
||||
def get_state(self):
    """Returns the state of the runner.

    Returns:
        dict: ``epoch``, ``models`` (one state dict per model, with every
            value passed through ``.cpu()``), ``optimizers`` (one state
            dict per optimizer) and ``stats``.
    """
    # Copy each weight to the CPU individually rather than calling
    # ``model.module.cpu()``, which would move the model itself off the GPU.
    # Keeping the model in place lets training resume while intermediate
    # checkpoints are being saved.
    cpu_state_dicts = []
    for model in self.models:
        state_dict = model.module.state_dict()
        # Snapshot the items so values can be reassigned while looping.
        for key, value in list(state_dict.items()):
            state_dict[key] = value.cpu()
        cpu_state_dicts.append(state_dict)
    return {
        "epoch": self.epoch,
        "models": cpu_state_dicts,
        "optimizers": [opt.state_dict() for opt in self.optimizers],
        "stats": self.stats()
    }
|
||||
|
||||
@@ -45,6 +45,7 @@ def data_creator(batch_size, config):
|
||||
]))
|
||||
|
||||
# Create the dataloader
|
||||
train_sampler = None
|
||||
if distributed.is_initialized():
|
||||
train_sampler = DistributedSampler(dataset)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
@@ -238,8 +239,8 @@ def train_example(num_replicas=1, use_gpu=False, test_mode=False):
|
||||
use_gpu=use_gpu,
|
||||
batch_size=16 if test_mode else 512,
|
||||
backend="nccl" if use_gpu else "gloo")
|
||||
for i in range(5):
|
||||
stats = trainer.train()
|
||||
for i in range(10):
|
||||
stats = trainer.train(max_retries=3)
|
||||
print(stats)
|
||||
|
||||
return trainer
|
||||
|
||||
@@ -4,6 +4,8 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import logging
|
||||
import numbers
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import ray
|
||||
|
||||
@@ -15,6 +17,7 @@ from ray.experimental.sgd import utils
|
||||
from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
RESIZE_COOLDOWN_S = 10
|
||||
|
||||
|
||||
class PyTorchTrainer:
|
||||
@@ -74,8 +77,12 @@ class PyTorchTrainer:
|
||||
"https://github.com/pytorch/examples/issues/467."))
|
||||
|
||||
self.model_creator = model_creator
|
||||
self.data_creator = data_creator
|
||||
self.train_function = train_function
|
||||
self.optimizer_creator = optimizer_creator
|
||||
self.loss_creator = loss_creator
|
||||
self.validation_function = validation_function
|
||||
self.initialization_hook = initialization_hook
|
||||
self.config = {} if config is None else config
|
||||
self.optimizer_timer = utils.TimerStat(window_size=1)
|
||||
|
||||
@@ -83,58 +90,69 @@ class PyTorchTrainer:
|
||||
backend = "nccl" if use_gpu else "gloo"
|
||||
|
||||
logger.info("Using {} as backend.".format(backend))
|
||||
self.backend = backend
|
||||
self.use_gpu = use_gpu
|
||||
self.batch_size = batch_size
|
||||
self.max_replicas = num_replicas
|
||||
self.temp_dir = tempfile.mkdtemp(prefix="raysgd")
|
||||
self._num_failures = 0
|
||||
self._last_resize = float("-inf")
|
||||
self._start_workers(self.max_replicas)
|
||||
|
||||
def _start_workers(self, num_replicas):
|
||||
logger.info(f"start_workers: Setting %d replicas." % num_replicas)
|
||||
if num_replicas == 1:
|
||||
# Generate actor class
|
||||
Runner = ray.remote(
|
||||
num_cpus=1, num_gpus=int(use_gpu))(PyTorchRunner)
|
||||
num_cpus=1, num_gpus=int(self.use_gpu))(PyTorchRunner)
|
||||
# Start workers
|
||||
self.workers = [
|
||||
Runner.remote(
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
loss_creator,
|
||||
train_function=train_function,
|
||||
validation_function=validation_function,
|
||||
self.model_creator,
|
||||
self.data_creator,
|
||||
self.optimizer_creator,
|
||||
self.loss_creator,
|
||||
train_function=self.train_function,
|
||||
validation_function=self.validation_function,
|
||||
config=self.config,
|
||||
batch_size=batch_size)
|
||||
batch_size=self.batch_size)
|
||||
]
|
||||
if initialization_hook:
|
||||
self.apply_all_workers(initialization_hook)
|
||||
if self.initialization_hook:
|
||||
self.apply_all_workers(self.initialization_hook)
|
||||
# Get setup tasks in order to throw errors on failure
|
||||
ray.get(self.workers[0].setup.remote())
|
||||
else:
|
||||
# Generate actor class
|
||||
Runner = ray.remote(
|
||||
num_cpus=1, num_gpus=int(use_gpu))(DistributedPyTorchRunner)
|
||||
num_cpus=1,
|
||||
num_gpus=int(self.use_gpu))(DistributedPyTorchRunner)
|
||||
# Compute batch size per replica
|
||||
batch_size_per_replica = batch_size // num_replicas
|
||||
if batch_size % num_replicas > 0:
|
||||
batch_size_per_replica = self.batch_size // num_replicas
|
||||
if self.batch_size % num_replicas > 0:
|
||||
new_batch_size = batch_size_per_replica * num_replicas
|
||||
logger.warning(
|
||||
("Changing batch size from {old_batch_size} to "
|
||||
"{new_batch_size} to evenly distribute batches across "
|
||||
"{num_replicas} replicas.").format(
|
||||
old_batch_size=batch_size,
|
||||
old_batch_size=self.batch_size,
|
||||
new_batch_size=new_batch_size,
|
||||
num_replicas=num_replicas))
|
||||
# Start workers
|
||||
self.workers = [
|
||||
Runner.remote(
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
loss_creator,
|
||||
backend=backend,
|
||||
train_function=train_function,
|
||||
validation_function=validation_function,
|
||||
self.model_creator,
|
||||
self.data_creator,
|
||||
self.optimizer_creator,
|
||||
self.loss_creator,
|
||||
backend=self.backend,
|
||||
train_function=self.train_function,
|
||||
validation_function=self.validation_function,
|
||||
config=self.config,
|
||||
batch_size=batch_size_per_replica)
|
||||
for i in range(num_replicas)
|
||||
]
|
||||
if initialization_hook:
|
||||
self.apply_all_workers(initialization_hook)
|
||||
if self.initialization_hook:
|
||||
self.apply_all_workers(self.initialization_hook)
|
||||
|
||||
# Compute URL for initializing distributed PyTorch
|
||||
ip = ray.get(self.workers[0].get_node_ip.remote())
|
||||
@@ -146,13 +164,51 @@ class PyTorchTrainer:
|
||||
for i, worker in enumerate(self.workers)
|
||||
])
|
||||
|
||||
def train(self):
|
||||
def train(self, max_retries=10, checkpoint="auto"):
|
||||
"""Runs a training epoch.
|
||||
|
||||
Runs an average over all values returned from workers.
|
||||
Runs an average over all values returned from workers. Set
|
||||
`max_retries` to enable fault handling in case of instance preemption.
|
||||
|
||||
Args:
|
||||
max_retries (int): Must be non-negative. If set to N, will
|
||||
kill all current workers, query the Ray global state for
|
||||
total available resources, and re-launch up to the
|
||||
available resources. Behavior is not well-defined
|
||||
in case of shared cluster usage.
|
||||
checkpoint (str): Path to checkpoint to restore from if retrying.
|
||||
If max_retries is set and checkpoint == "auto", PyTorchTrainer
|
||||
will save a checkpoint before starting to train.
|
||||
"""
|
||||
assert max_retries >= 0, "`max_retries` must be non-negative."
|
||||
if max_retries:
|
||||
if checkpoint == "auto":
|
||||
logger.debug("Retrying detected. Automatically checkpointing.")
|
||||
checkpoint = self.save(
|
||||
os.path.join(self.temp_dir, "tmp_checkpoint"))
|
||||
elif not checkpoint:
|
||||
raise ValueError("Cannot retry from empty checkpoint.")
|
||||
|
||||
if checkpoint and self._should_resize():
|
||||
logger.info("Resize opportunity detected. Attempting to scale up.")
|
||||
self._resize_workers(checkpoint=checkpoint)
|
||||
|
||||
with self.optimizer_timer:
|
||||
worker_stats = ray.get([w.step.remote() for w in self.workers])
|
||||
success, worker_stats = self._train_step()
|
||||
# Fault handling
|
||||
for i in range(max_retries):
|
||||
if success:
|
||||
break
|
||||
else:
|
||||
self._num_failures += 1
|
||||
self._resize_workers(checkpoint=checkpoint)
|
||||
logger.info("Retrying training step with %d workers." % len(
|
||||
self.workers))
|
||||
success, worker_stats = self._train_step()
|
||||
if not success:
|
||||
raise RuntimeError("Training run failed.")
|
||||
|
||||
worker_stats = ray.get(worker_stats)
|
||||
|
||||
train_stats = {}
|
||||
for stat_key in worker_stats[0]:
|
||||
@@ -163,6 +219,11 @@ class PyTorchTrainer:
|
||||
train_stats[stat_key] = worker_stats[0][stat_key]
|
||||
return train_stats
|
||||
|
||||
def _train_step(self):
    """Launches ``step`` on every worker and checks the calls for failure.

    Returns:
        tuple: ``(success, worker_stats)`` where ``success`` is the result
            of ``utils.check_for_failure`` and ``worker_stats`` is the list
            of values returned by each worker's ``step.remote()``.
    """
    pending = [worker.step.remote() for worker in self.workers]
    return utils.check_for_failure(pending), pending
|
||||
|
||||
def apply_all_workers(self, fn):
    """Calls ``apply_fn`` with ``fn`` on every worker and collects results.

    Args:
        fn: Passed unchanged to each worker's ``apply_fn``.

    Returns:
        The value of ``ray.get`` over the per-worker calls, in worker order.
    """
    pending = [worker.apply_fn.remote(fn) for worker in self.workers]
    return ray.get(pending)
|
||||
|
||||
@@ -211,11 +272,54 @@ class PyTorchTrainer:
|
||||
state_id = ray.put(state)
|
||||
ray.get([worker.set_state.remote(state_id) for worker in self.workers])
|
||||
|
||||
def shutdown(self, force=False):
    """Shuts down workers and releases resources.

    Args:
        force (bool): If False, each worker is asked to run its own
            ``shutdown`` and is then terminated. If True, each worker is
            killed with ``__ray_kill__`` instead and a warning is logged.
    """
    for worker in self.workers:
        if not force:
            worker.shutdown.remote()
            worker.__ray_terminate__.remote()
        else:
            logger.warning("Killing worker {}.".format(worker))
            worker.__ray_kill__()

    # Drop the handles in both cases so no stale worker is reused.
    self.workers = []
|
||||
|
||||
def _resize_workers(self, checkpoint, max_retries=10):
|
||||
# check available resources
|
||||
self.shutdown(force=True)
|
||||
assert checkpoint, "Cannot restore without checkpoint."
|
||||
|
||||
time.sleep(1)
|
||||
for i in range(max_retries):
|
||||
resources = ray.available_resources()
|
||||
new_workers = min(resources.get("CPU", 0), self.max_replicas)
|
||||
if self.use_gpu:
|
||||
new_workers = min(resources.get("GPU", 0), new_workers)
|
||||
if new_workers:
|
||||
self._last_resize = time.time()
|
||||
self._start_workers(int(new_workers))
|
||||
self.restore(checkpoint)
|
||||
return
|
||||
else:
|
||||
delay = 2**i
|
||||
logger.info("Resources: {}".format(resources))
|
||||
logger.warning(
|
||||
"No new workers found. Retrying in %d sec." % delay)
|
||||
time.sleep(delay)
|
||||
raise RuntimeError("Exceeded max_retries for relaunching workers.")
|
||||
|
||||
def _should_resize(self):
|
||||
"""Returns True if past cooldown and exists resources to scale up."""
|
||||
worker_gap = self.max_replicas - len(self.workers)
|
||||
past_cooldown = (time.time() - self._last_resize) > RESIZE_COOLDOWN_S
|
||||
if past_cooldown and worker_gap:
|
||||
resources = ray.available_resources()
|
||||
potential_workers = min(resources.get("CPU", 0), self.max_replicas)
|
||||
if self.use_gpu:
|
||||
potential_workers = min(
|
||||
resources.get("GPU", 0), potential_workers)
|
||||
return potential_workers > 0
|
||||
return False
|
||||
|
||||
|
||||
class PyTorchTrainable(Trainable):
|
||||
|
||||
@@ -1,21 +1,26 @@
|
||||
import os
|
||||
import pytest
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tests.conftest import ray_start_2_cpus # noqa: F401
|
||||
from ray.experimental.sgd.pytorch import PyTorchTrainer, PyTorchTrainable
|
||||
from ray.experimental.sgd.pytorch.utils import train
|
||||
from ray.experimental.sgd.utils import check_for_failure
|
||||
|
||||
from ray.experimental.sgd.examples.train_example import (
|
||||
model_creator, optimizer_creator, data_creator)
|
||||
model_creator, optimizer_creator, data_creator, LinearDataset)
|
||||
|
||||
|
||||
@pytest.mark.parametrize( # noqa: F811
|
||||
"num_replicas", [1, 2] if dist.is_available() else [1])
|
||||
@pytest.mark.parametrize("num_replicas", [1, 2]
|
||||
if dist.is_available() else [1])
|
||||
def test_train(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
trainer = PyTorchTrainer(
|
||||
model_creator,
|
||||
@@ -36,8 +41,8 @@ def test_train(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
assert validation_loss2 <= validation_loss1
|
||||
|
||||
|
||||
@pytest.mark.parametrize( # noqa: F811
|
||||
"num_replicas", [1, 2] if dist.is_available() else [1])
|
||||
@pytest.mark.parametrize("num_replicas", [1, 2]
|
||||
if dist.is_available() else [1])
|
||||
def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
def custom_train(models, dataloader, criterion, optimizers, config):
|
||||
result = {}
|
||||
@@ -94,15 +99,15 @@ def test_multi_model(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
assert torch.equal(model1_state_dict[k], model2_state_dict[k])
|
||||
|
||||
|
||||
@pytest.mark.parametrize( # noqa: F811
|
||||
"num_replicas", [1, 2] if dist.is_available() else [1])
|
||||
@pytest.mark.parametrize("num_replicas", [1, 2]
|
||||
if dist.is_available() else [1])
|
||||
def test_tune_train(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
|
||||
config = {
|
||||
"model_creator": tune.function(model_creator),
|
||||
"data_creator": tune.function(data_creator),
|
||||
"optimizer_creator": tune.function(optimizer_creator),
|
||||
"loss_creator": tune.function(lambda config: nn.MSELoss()),
|
||||
"model_creator": model_creator,
|
||||
"data_creator": data_creator,
|
||||
"optimizer_creator": optimizer_creator,
|
||||
"loss_creator": lambda config: nn.MSELoss(),
|
||||
"num_replicas": num_replicas,
|
||||
"use_gpu": False,
|
||||
"batch_size": 512,
|
||||
@@ -127,8 +132,8 @@ def test_tune_train(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
assert validation_loss2 <= validation_loss1
|
||||
|
||||
|
||||
@pytest.mark.parametrize( # noqa: F811
|
||||
"num_replicas", [1, 2] if dist.is_available() else [1])
|
||||
@pytest.mark.parametrize("num_replicas", [1, 2]
|
||||
if dist.is_available() else [1])
|
||||
def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
trainer1 = PyTorchTrainer(
|
||||
model_creator,
|
||||
@@ -164,3 +169,101 @@ def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
|
||||
for k in model1_state_dict:
|
||||
assert torch.equal(model1_state_dict[k], model2_state_dict[k])
|
||||
|
||||
|
||||
def test_fail_with_recover(ray_start_2_cpus):  # noqa: F811
    """Training raises RuntimeError when failures outnumber the retries."""
    if not dist.is_available():
        # Report a skip instead of a silent pass.
        pytest.skip("torch.distributed is not available.")

    def single_loader(batch_size, config):
        train_dataset = LinearDataset(2, 5, size=1000000)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size)
        return train_loader

    def step_with_fail(self):
        worker_stats = [w.step.remote() for w in self.workers]
        if self._num_failures < 3:
            # NOTE(review): presumably lets the step start so that the kill
            # interrupts it mid-batch -- confirm.
            time.sleep(1)
            self.workers[0].__ray_kill__()
        success = check_for_failure(worker_stats)
        return success, worker_stats

    with patch.object(PyTorchTrainer, "_train_step", step_with_fail):
        trainer1 = PyTorchTrainer(
            model_creator,
            single_loader,
            optimizer_creator,
            batch_size=100000,
            loss_creator=lambda config: nn.MSELoss(),
            num_replicas=2)

        # The patched step keeps killing a worker until three failures are
        # recorded, so a single retry cannot succeed.
        with pytest.raises(RuntimeError):
            trainer1.train(max_retries=1)
|
||||
|
||||
|
||||
def test_resize(ray_start_2_cpus):  # noqa: F811
    """After one failure the trainer is relaunched with a single worker."""
    if not dist.is_available():
        # Report a skip instead of a silent pass.
        pytest.skip("torch.distributed is not available.")

    def single_loader(batch_size, config):
        train_dataset = LinearDataset(2, 5, size=1000000)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size)
        return train_loader

    def step_with_fail(self):
        worker_stats = [w.step.remote() for w in self.workers]
        if self._num_failures < 1:
            # NOTE(review): presumably lets the step start so that the kill
            # interrupts it mid-batch -- confirm.
            time.sleep(1)
            self.workers[0].__ray_kill__()
        success = check_for_failure(worker_stats)
        return success, worker_stats

    with patch.object(PyTorchTrainer, "_train_step", step_with_fail):
        trainer1 = PyTorchTrainer(
            model_creator,
            single_loader,
            optimizer_creator,
            batch_size=100000,
            loss_creator=lambda config: nn.MSELoss(),
            num_replicas=2)

        # NOTE(review): this long-running task presumably occupies one of
        # the two CPUs so that only one worker can be relaunched -- confirm.
        @ray.remote
        def try_test():
            import time
            time.sleep(100)

        try_test.remote()
        trainer1.train(max_retries=1)
        assert len(trainer1.workers) == 1
|
||||
|
||||
|
||||
def test_fail_twice(ray_start_2_cpus):  # noqa: F811
    """Training completes when two failures are covered by two retries."""
    if not dist.is_available():
        # Report a skip instead of a silent pass.
        pytest.skip("torch.distributed is not available.")

    def single_loader(batch_size, config):
        train_dataset = LinearDataset(2, 5, size=1000000)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size)
        return train_loader

    def step_with_fail(self):
        worker_stats = [w.step.remote() for w in self.workers]
        if self._num_failures < 2:
            # NOTE(review): presumably lets the step start so that the kill
            # interrupts it mid-batch -- confirm.
            time.sleep(1)
            self.workers[0].__ray_kill__()
        success = check_for_failure(worker_stats)
        return success, worker_stats

    with patch.object(PyTorchTrainer, "_train_step", step_with_fail):
        trainer1 = PyTorchTrainer(
            model_creator,
            single_loader,
            optimizer_creator,
            batch_size=100000,
            loss_creator=lambda config: nn.MSELoss(),
            num_replicas=2)

        # No exception is expected: the patched step stops killing workers
        # once two failures have been recorded.
        trainer1.train(max_retries=2)
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
from contextlib import closing
|
||||
import logging
|
||||
import numpy as np
|
||||
import socket
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.exceptions import RayActorError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TimerStat:
|
||||
"""A running stat for conveniently logging the duration of a code block.
|
||||
@@ -121,3 +127,25 @@ class AverageMeter:
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
def check_for_failure(remote_values):
    """Checks remote values for any that returned and failed.

    Args:
        remote_values (list): List of object IDs representing functions
            that may fail in the middle of execution. For example, running
            a SGD training loop in multiple parallel actor calls. Any
            iterable is accepted; it is copied and never mutated.

    Returns:
        Bool for success in executing given remote tasks. True when every
        value was fetched (or the input was empty), False when fetching
        raised ``RayActorError``.
    """
    unfinished = list(remote_values)
    try:
        while unfinished:
            finished, unfinished = ray.wait(unfinished)
            # Fetch only to surface a failure; the values are not needed.
            ray.get(finished)
    except RayActorError as exc:
        logger.exception(str(exc))
        return False
    return True
|
||||
|
||||
@@ -279,7 +279,7 @@ if __name__ == "__main__":
|
||||
MemNNModel,
|
||||
name="pbt_babi_memnn",
|
||||
scheduler=pbt,
|
||||
stop={"training_iteration": 20 if args.smoke_test else 100},
|
||||
stop={"training_iteration": 10 if args.smoke_test else 100},
|
||||
num_samples=4,
|
||||
config={
|
||||
"batch_size": 32,
|
||||
|
||||
Reference in New Issue
Block a user