[Ray SGD] use_local flag + Worker group abstraction (#10539)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Amog Kamsetty
2020-09-15 11:58:57 -07:00
committed by GitHub
parent 0865d68466
commit d5a7c53908
15 changed files with 1025 additions and 426 deletions
+4 -1
View File
@@ -75,7 +75,6 @@ class _TorchTrainable(tune.Trainable):
remote_trainable = remote_trainable.options(
**self.get_remote_worker_options())
address = setup_address()
self.workers = [
remote_trainable.remote(
config=config,
@@ -83,6 +82,10 @@ class _TorchTrainable(tune.Trainable):
for rank in range(num_workers)
]
# Address has to be IP of rank 0 worker's node.
address = ray.get(
self.workers[0].execute.remote(lambda _: setup_address()))
pgroup_params = self.default_process_group_parameters()
from functools import partial
setup_on_worker = partial(
+8
View File
@@ -18,6 +18,14 @@ py_test(
deps = [":sgd_lib"],
)
py_test(
name = "test_torch_failure",
size = "large",
srcs = ["tests/test_torch_failure.py"],
tags = ["exclusive", "pytorch"],
deps = [":sgd_lib"],
)
py_test(
name = "test_torch_runner",
size = "small",
+2 -2
View File
@@ -84,8 +84,8 @@ class Dataset():
Returns a single, iterable shard.
"""
assert i < self.iter.num_shards(), \
"Trying to get shard {} but there are only {} shards." + \
"Are you sure you called set_num_shards already?".format(
"Trying to get shard {} but there are only {} shards. Are you " \
"sure you called set_num_shards already?".format(
i, self.iter.num_shards()
)
+86 -154
View File
@@ -1,8 +1,6 @@
from unittest.mock import patch
import numpy as np
import os
import pytest
import time
import torch
import torch.nn as nn
import torch.distributed as dist
@@ -14,8 +12,7 @@ from ray.util.sgd.torch import TorchTrainer
from ray.util.sgd.torch.training_operator import (
get_test_operator, get_test_metrics_operator, TrainingOperator)
from ray.util.sgd.torch.constants import SCHEDULER_STEP
from ray.util.sgd.utils import (check_for_failure, NUM_SAMPLES, BATCH_COUNT,
BATCH_SIZE)
from ray.util.sgd.utils import (NUM_SAMPLES, BATCH_COUNT, BATCH_SIZE)
from ray.util.sgd.data.examples import mlp_identity
from ray.util.sgd.torch.examples.train_example import (
@@ -48,8 +45,13 @@ Operator = TrainingOperator.from_creators(
model_creator, optimizer_creator, data_creator, loss_creator=nn.MSELoss)
def test_single_step(ray_start_2_cpus): # noqa: F811
trainer = TorchTrainer(training_operator_cls=Operator, num_workers=1)
@pytest.mark.parametrize("use_local", [True, False])
def test_single_step(ray_start_2_cpus, use_local): # noqa: F811
trainer = TorchTrainer(
training_operator_cls=Operator,
num_workers=1,
use_local=use_local,
use_gpu=False)
metrics = trainer.train(num_steps=1)
assert metrics[BATCH_COUNT] == 1
@@ -58,9 +60,14 @@ def test_single_step(ray_start_2_cpus): # noqa: F811
trainer.shutdown()
def test_dead_trainer(ray_start_2_cpus): # noqa: F811
@pytest.mark.parametrize("use_local", [True, False])
def test_dead_trainer(ray_start_2_cpus, use_local): # noqa: F811
TestOperator = get_test_operator(Operator)
trainer = TorchTrainer(training_operator_cls=TestOperator, num_workers=2)
trainer = TorchTrainer(
training_operator_cls=TestOperator,
num_workers=2,
use_local=use_local,
use_gpu=False)
trainer.train(num_steps=1)
trainer.shutdown()
with pytest.raises(RuntimeError):
@@ -68,9 +75,13 @@ def test_dead_trainer(ray_start_2_cpus): # noqa: F811
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
def test_train(ray_start_2_cpus, num_workers): # noqa: F811
@pytest.mark.parametrize("use_local", [True, False])
def test_train(ray_start_2_cpus, num_workers, use_local): # noqa: F811
trainer = TorchTrainer(
training_operator_cls=Operator, num_workers=num_workers)
training_operator_cls=Operator,
num_workers=num_workers,
use_local=use_local,
use_gpu=False)
for i in range(3):
train_loss1 = trainer.train()["train_loss"]
validation_loss1 = trainer.validate()["val_loss"]
@@ -86,7 +97,8 @@ def test_train(ray_start_2_cpus, num_workers): # noqa: F811
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
def test_multi_model(ray_start_2_cpus, num_workers):
@pytest.mark.parametrize("use_local", [True, False])
def test_multi_model(ray_start_2_cpus, num_workers, use_local):
def train(*, model=None, criterion=None, optimizer=None, iterator=None):
model.train()
train_loss = 0
@@ -140,7 +152,10 @@ def test_multi_model(ray_start_2_cpus, num_workers):
trainer1 = TorchTrainer(
config={"custom_func": train_epoch},
training_operator_cls=TestOperator,
num_workers=num_workers)
num_workers=num_workers,
use_local=use_local,
use_gpu=False,
)
trainer1.train()
state = trainer1.state_dict()
@@ -151,7 +166,10 @@ def test_multi_model(ray_start_2_cpus, num_workers):
trainer2 = TorchTrainer(
config={"custom_func": train_epoch},
training_operator_cls=TestOperator,
num_workers=num_workers)
num_workers=num_workers,
use_local=use_local,
use_gpu=False,
)
trainer2.load_state_dict(state)
models2 = trainer2.get_model()
@@ -170,7 +188,9 @@ def test_multi_model(ray_start_2_cpus, num_workers):
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
def test_multi_model_matrix(ray_start_2_cpus, num_workers): # noqa: F811
@pytest.mark.parametrize("use_local", [True, False])
def test_multi_model_matrix(ray_start_2_cpus, num_workers, use_local): #
# noqa: F811
def train_epoch(self, iterator, info):
if self.config.get("models", 1) > 1:
assert len(self.models) == self.config["models"], self.config
@@ -231,6 +251,7 @@ def test_multi_model_matrix(ray_start_2_cpus, num_workers): # noqa: F811
scheduler_step_freq="epoch",
training_operator_cls=TestOperator,
num_workers=num_workers,
use_local=use_local,
config={
"models": model_count,
"optimizers": optimizer_count,
@@ -242,7 +263,8 @@ def test_multi_model_matrix(ray_start_2_cpus, num_workers): # noqa: F811
@pytest.mark.parametrize("scheduler_freq", ["epoch", "batch", "manual", None])
def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811
def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa:
# F811
def train_epoch(self, iterator, info):
assert info[SCHEDULER_STEP] == scheduler_freq
return {"done": 1}
@@ -271,20 +293,23 @@ def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811
trainer = TorchTrainer(
config={"custom_func": train_epoch},
training_operator_cls=TestTrainingOperator,
scheduler_step_freq=scheduler_freq)
scheduler_step_freq=scheduler_freq,
)
else:
trainer = TorchTrainer(
config={"custom_func": train_epoch},
training_operator_cls=TestTrainingOperator,
scheduler_step_freq=scheduler_freq)
scheduler_step_freq=scheduler_freq,
)
for i in range(3):
trainer.train()
trainer.shutdown()
def test_profiling(ray_start_2_cpus): # noqa: F811
trainer = TorchTrainer(training_operator_cls=Operator)
@pytest.mark.parametrize("use_local", [True, False])
def test_profiling(ray_start_2_cpus, use_local): # noqa: F811
trainer = TorchTrainer(training_operator_cls=Operator, use_local=use_local)
stats = trainer.train(profile=True)
assert "profile" in stats
@@ -293,7 +318,8 @@ def test_profiling(ray_start_2_cpus): # noqa: F811
trainer.shutdown()
def test_dataset(ray_start_4_cpus):
@pytest.mark.parametrize("use_local", [True, False])
def test_dataset(ray_start_4_cpus, use_local):
"""
This test tries training the mlp_identity example. We check the accuracy of
the model as an all inclusive way of ensuring that we are properly sharding
@@ -312,6 +338,7 @@ def test_dataset(ray_start_4_cpus):
trainer = TorchTrainer(
training_operator_cls=DatasetOperator,
use_local=use_local,
num_workers=2,
)
@@ -319,13 +346,14 @@ def test_dataset(ray_start_4_cpus):
for i in range(5):
trainer.train(dataset=dataset, num_steps=100)
input = mlp_identity.to_mat(0.5)
prediction = float(trainer.get_model()(input)[0][0])
x = mlp_identity.to_mat(0.5)
prediction = float(trainer.get_model()(x)[0][0])
assert 0.4 <= prediction <= 0.6
trainer.shutdown()
def test_split_batch(ray_start_2_cpus):
@pytest.mark.parametrize("use_local", [True, False])
def test_split_batch(ray_start_2_cpus, use_local):
if not dist.is_available():
return
@@ -347,6 +375,7 @@ def test_split_batch(ray_start_2_cpus):
trainer = TorchTrainer(
training_operator_cls=TestOperator,
num_workers=2,
use_local=use_local,
config={
BATCH_SIZE: batch_size,
"data_size": data_size,
@@ -358,7 +387,8 @@ def test_split_batch(ray_start_2_cpus):
trainer.shutdown()
def test_reduce_result(ray_start_2_cpus):
@pytest.mark.parametrize("use_local", [True, False])
def test_reduce_result(ray_start_2_cpus, use_local):
if not dist.is_available():
return
@@ -380,6 +410,7 @@ def test_reduce_result(ray_start_2_cpus):
trainer = TorchTrainer(
training_operator_cls=TestOperator,
num_workers=2,
use_local=use_local,
config={"data_size": data_size})
list_stats = trainer.train(reduce_results=False, profile=True)
assert len(list_stats) == 2
@@ -393,7 +424,8 @@ def test_reduce_result(ray_start_2_cpus):
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
def test_metrics(ray_start_2_cpus, num_workers):
@pytest.mark.parametrize("use_local", [True, False])
def test_metrics(ray_start_2_cpus, num_workers, use_local):
data_size, val_size = 600, 500
batch_size = 4
@@ -407,6 +439,7 @@ def test_metrics(ray_start_2_cpus, num_workers):
trainer = TorchTrainer(
training_operator_cls=TestOperator,
num_workers=num_workers,
use_local=use_local,
config={
"scores": train_scores,
"val_scores": val_scores,
@@ -440,7 +473,8 @@ def test_metrics(ray_start_2_cpus, num_workers):
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
def test_metrics_nan(ray_start_2_cpus, num_workers):
@pytest.mark.parametrize("use_local", [True, False])
def test_metrics_nan(ray_start_2_cpus, num_workers, use_local):
data_size, val_size = 100, 100
batch_size = 10
@@ -453,6 +487,7 @@ def test_metrics_nan(ray_start_2_cpus, num_workers):
trainer = TorchTrainer(
training_operator_cls=TestOperator,
num_workers=num_workers,
use_local=use_local,
config={
"scores": train_scores,
"val_scores": val_scores,
@@ -474,7 +509,8 @@ def test_metrics_nan(ray_start_2_cpus, num_workers):
trainer.shutdown()
def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
@pytest.mark.parametrize("use_local", [True, False])
def test_scheduler_validate(ray_start_2_cpus, use_local): # noqa: F811
from torch.optim.lr_scheduler import ReduceLROnPlateau
TestOperator = TrainingOperator.from_creators(
@@ -485,7 +521,9 @@ def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
loss_creator=lambda config: nn.MSELoss())
TestOperator = get_test_operator(TestOperator)
trainer = TorchTrainer(
scheduler_step_freq="manual", training_operator_cls=TestOperator)
scheduler_step_freq="manual",
training_operator_cls=TestOperator,
use_local=use_local)
trainer.update_scheduler(0.5)
trainer.update_scheduler(0.5)
assert all(
@@ -494,14 +532,16 @@ def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
trainer.shutdown()
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
def test_tune_train(ray_start_2_cpus, num_workers): # noqa: F811
@pytest.mark.parametrize("num_workers", [2] if dist.is_available() else [1])
@pytest.mark.parametrize("use_local", [True, False])
def test_tune_train(ray_start_4_cpus, num_workers, use_local): # noqa: F811
TorchTrainable = TorchTrainer.as_trainable(
**{
"training_operator_cls": Operator,
"num_workers": num_workers,
"use_gpu": False,
"backend": "gloo",
"use_local": use_local,
"config": {
"batch_size": 512,
"lr": 0.001
@@ -526,10 +566,13 @@ def test_tune_train(ray_start_2_cpus, num_workers): # noqa: F811
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
def test_save_and_restore(ray_start_2_cpus, num_workers,
@pytest.mark.parametrize("use_local", [True, False])
def test_save_and_restore(ray_start_2_cpus, num_workers, use_local,
tmp_path): # noqa: F811
trainer1 = TorchTrainer(
training_operator_cls=Operator, num_workers=num_workers)
training_operator_cls=Operator,
num_workers=num_workers,
use_local=use_local)
trainer1.train()
checkpoint_path = os.path.join(tmp_path, "checkpoint")
trainer1.save(checkpoint_path)
@@ -539,7 +582,9 @@ def test_save_and_restore(ray_start_2_cpus, num_workers,
trainer1.shutdown()
trainer2 = TorchTrainer(
training_operator_cls=Operator, num_workers=num_workers)
training_operator_cls=Operator,
num_workers=num_workers,
use_local=use_local)
trainer2.load(checkpoint_path)
model2 = trainer2.get_model()
@@ -558,14 +603,15 @@ def test_wrap_ddp(ray_start_2_cpus, tmp_path): # noqa: F811
if not dist.is_available():
return
trainer1 = TorchTrainer(
training_operator_cls=Operator, wrap_ddp=False, num_workers=2)
training_operator_cls=Operator,
wrap_ddp=False,
num_workers=2,
use_local=True)
trainer1.train()
checkpoint_path = os.path.join(tmp_path, "checkpoint")
trainer1.save(checkpoint_path)
model1 = trainer1.get_model()
assert not hasattr(trainer1.local_worker.training_operator.model, "module")
assert hasattr(trainer1.local_worker.training_operator, "device_ids")
trainer1.shutdown()
trainer2 = TorchTrainer(
@@ -584,123 +630,8 @@ def test_wrap_ddp(ray_start_2_cpus, tmp_path): # noqa: F811
trainer2.shutdown()
def gen_step_with_fail(num_fails):
def step_with_fail(self,
num_steps=None,
profile=False,
info=None,
dataset=None):
params = dict(num_steps=num_steps, profile=profile, info=info)
remote_worker_stats = [
w.train_epoch.remote(**params) for w in self.remote_workers
]
if self._num_failures < num_fails:
time.sleep(1) # Make the batch will fail correctly.
ray.kill(self.remote_workers[0])
try:
local_worker_stats = self.local_worker.train_epoch(**params)
except RuntimeError:
return False, None
success = check_for_failure(remote_worker_stats)
if success:
return success, [local_worker_stats] + ray.get(remote_worker_stats)
return success, None
return step_with_fail
def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
if not dist.is_available():
return
def single_loader(config):
dataset = LinearDataset(2, 5, size=1000000)
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
step_with_fail = gen_step_with_fail(3)
TestOperator = TrainingOperator.from_creators(
model_creator,
optimizer_creator,
single_loader,
loss_creator=lambda config: nn.MSELoss())
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
trainer1 = TorchTrainer(
training_operator_cls=TestOperator,
config={"batch_size": 100000},
num_workers=2)
with pytest.raises(RuntimeError):
trainer1.train(max_retries=1)
trainer1.shutdown(force=True)
def test_resize(ray_start_2_cpus): # noqa: F811
if not dist.is_available():
return
def single_loader(config):
dataset = LinearDataset(2, 5, size=1000000)
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
step_with_fail = gen_step_with_fail(1)
TestOperator = TrainingOperator.from_creators(
model_creator,
optimizer_creator,
single_loader,
loss_creator=lambda config: nn.MSELoss())
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
trainer1 = TorchTrainer(
training_operator_cls=TestOperator,
config={"batch_size": 100000},
num_workers=2)
@ray.remote
def try_test():
import time
time.sleep(100)
try_test.remote()
trainer1.train(max_retries=1)
assert len(trainer1.remote_workers) == 1
trainer1.shutdown()
def test_fail_twice(ray_start_2_cpus): # noqa: F811
if not dist.is_available():
return
def single_loader(config):
dataset = LinearDataset(2, 5, size=1000000)
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
step_with_fail = gen_step_with_fail(2)
TestOperator = TrainingOperator.from_creators(
model_creator,
optimizer_creator,
single_loader,
loss_creator=lambda config: nn.MSELoss())
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
trainer1 = TorchTrainer(
training_operator_cls=TestOperator,
config={"batch_size": 100000},
num_workers=2)
# MAX RETRIES SHOULD BE ON BY DEFAULT
trainer1.train()
trainer1.shutdown()
def test_multi_input_model(ray_start_2_cpus):
@pytest.mark.parametrize("use_local", [True, False])
def test_multi_input_model(ray_start_2_cpus, use_local):
def model_creator(config):
class MultiInputModel(nn.Module):
def __init__(self):
@@ -742,7 +673,8 @@ def test_multi_input_model(ray_start_2_cpus):
data_creator,
loss_creator=lambda config: nn.MSELoss())
trainer = TorchTrainer(training_operator_cls=Operator, num_workers=1)
trainer = TorchTrainer(
training_operator_cls=Operator, num_workers=1, use_local=use_local)
metrics = trainer.train(num_steps=1)
assert metrics[BATCH_COUNT] == 1
@@ -0,0 +1,175 @@
from unittest.mock import patch
import pytest
import time
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader
import ray
from ray.util.sgd.torch import TorchTrainer
from ray.util.sgd.torch.worker_group import RemoteWorkerGroup
from ray.util.sgd.torch.training_operator import TrainingOperator
from ray.util.sgd.torch.examples.train_example import (
model_creator, optimizer_creator, data_creator, LinearDataset)
Operator = TrainingOperator.from_creators(
model_creator, optimizer_creator, data_creator, loss_creator=nn.MSELoss)
@pytest.fixture
def ray_start_2_cpus():
address_info = ray.init(num_cpus=2)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()
# Ensure that tests don't ALL fail
if dist.is_initialized():
dist.destroy_process_group()
@pytest.fixture
def ray_start_4_cpus():
address_info = ray.init(num_cpus=4)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()
# Ensure that tests don't ALL fail
if dist.is_initialized():
dist.destroy_process_group()
def remote_worker_train_with_fail(self, num_steps, profile, info,
dataset=None):
remote_worker_stats = []
for i, w in enumerate(self.remote_workers):
params = dict(num_steps=num_steps, profile=profile, info=info)
if dataset:
params["iterator"] = dataset.get_shard(i)
stats = w.train_epoch.remote(**params)
remote_worker_stats.append(stats)
if i == 0 and hasattr(self, "should_fail") and self.should_fail:
time.sleep(1)
ray.kill(self.remote_workers[i])
return remote_worker_stats
start_workers = TorchTrainer._start_workers
def gen_start_with_fail(num_fails):
def start_with_fail(self, *args, **kwargs):
start_workers(self, *args, **kwargs)
fail = self._num_failures < num_fails
if self.use_local:
self.worker_group.remote_worker_group.should_fail = fail
else:
self.worker_group.should_fail = fail
return start_with_fail
@pytest.mark.parametrize("use_local", [False, True])
@patch.object(RemoteWorkerGroup, "_train", remote_worker_train_with_fail)
def test_resize(ray_start_2_cpus, use_local): # noqa: F811
if not dist.is_available():
return
def single_loader(config):
dataset = LinearDataset(2, 5, size=1000000)
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
start_with_fail = gen_start_with_fail(1)
TestOperator = TrainingOperator.from_creators(
model_creator,
optimizer_creator,
single_loader,
loss_creator=lambda config: nn.MSELoss())
with patch.object(TorchTrainer, "_start_workers", start_with_fail):
trainer1 = TorchTrainer(
training_operator_cls=TestOperator,
config={"batch_size": 100000},
use_local=use_local,
num_workers=2)
@ray.remote
def try_test():
import time
time.sleep(100)
try_test.remote()
trainer1.train(max_retries=1)
assert trainer1.worker_group.num_workers == 1
trainer1.shutdown(force=True)
@pytest.mark.parametrize("use_local", [False, True])
@patch.object(RemoteWorkerGroup, "_train", remote_worker_train_with_fail)
def test_fail_twice(ray_start_2_cpus, use_local): # noqa: F811
if not dist.is_available():
return
def single_loader(config):
dataset = LinearDataset(2, 5, size=1000000)
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
TestOperator = TrainingOperator.from_creators(
model_creator,
optimizer_creator,
single_loader,
loss_creator=lambda config: nn.MSELoss())
start_with_fail = gen_start_with_fail(2)
with patch.object(TorchTrainer, "_start_workers", start_with_fail):
trainer1 = TorchTrainer(
training_operator_cls=TestOperator,
config={"batch_size": 100000},
use_local=use_local,
num_workers=2)
# MAX RETRIES SHOULD BE ON BY DEFAULT
trainer1.train()
trainer1.shutdown(force=True)
@pytest.mark.parametrize("use_local", [False, True])
@patch.object(RemoteWorkerGroup, "_train", remote_worker_train_with_fail)
def test_fail_with_recover(ray_start_2_cpus, use_local): # noqa: F811
print(locals())
if not dist.is_available():
return
def single_loader(config):
dataset = LinearDataset(2, 5, size=1000000)
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
TestOperator = TrainingOperator.from_creators(
model_creator,
optimizer_creator,
single_loader,
loss_creator=lambda config: nn.MSELoss())
start_with_fail = gen_start_with_fail(3)
with patch.object(TorchTrainer, "_start_workers", start_with_fail):
trainer1 = TorchTrainer(
training_operator_cls=TestOperator,
config={"batch_size": 100000},
timeout_s=5,
use_local=use_local,
num_workers=2)
with pytest.raises(RuntimeError):
trainer1.train(max_retries=1)
trainer1.shutdown(force=True)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", "-x", __file__]))
@@ -1,5 +1,4 @@
import logging
import io
import os
import torch
@@ -11,6 +10,8 @@ from ray.util.sgd.torch.utils import setup_process_group
import ray
from ray.util.sgd.torch.torch_runner import TorchRunner
from ray.util.sgd.torch.utils import setup_address
logger = logging.getLogger(__name__)
@@ -42,6 +43,9 @@ class DistributedTorchRunner(TorchRunner):
self.add_dist_sampler = add_dist_sampler
self.world_rank = None
def setup_address(self):
return setup_address()
def setup_process_group(self, url, world_rank, world_size, timeout):
"""Connects the distributed PyTorch backend.
@@ -52,6 +56,7 @@ class DistributedTorchRunner(TorchRunner):
timeout (timedelta): Seconds for process group
operations to timeout.
"""
logger.info(f"Setting up process group for: {url} [rank={world_rank}]")
self.world_rank = world_rank
setup_process_group(
url, world_rank, world_size, timeout, backend=self.backend)
@@ -83,23 +88,6 @@ class DistributedTorchRunner(TorchRunner):
"""Needed for SyncBatchNorm, which needs 1 GPU per process."""
return [0]
def load_state_stream(self, byte_obj):
"""Loads a bytes object the training state dict.
This is needed because we don't want to deserialize the tensor
onto the same device (which is from the driver process). We want to
map it onto the actor's specific device.
From: github.com/pytorch/pytorch/issues/10622#issuecomment-474733769
"""
_buffer = io.BytesIO(byte_obj)
to_gpu = self.use_gpu and torch.cuda.is_available()
state_dict = torch.load(
_buffer,
map_location=("cpu" if not to_gpu else
lambda storage, loc: storage.cuda()))
return self.load_state_dict(state_dict)
def _wrap_dataloaders(self):
def with_sampler(loader):
# Automatically set the DistributedSampler
@@ -346,10 +334,3 @@ class LocalDistributedRunner(DistributedTorchRunner):
def is_actor(self):
actor_id = ray.worker.global_worker.actor_id
return actor_id != actor_id.nil()
class DeactivatedRunner:
def __getattr__(self, *args, **kwargs):
raise RuntimeError(
"This TorchTrainer is not active (it is likely shutdown already). "
"Create a new TorchTrainer.")
@@ -138,7 +138,7 @@ if __name__ == "__main__":
use_gpu=args.use_gpu,
scheduler_step_freq="epoch",
use_fp16=args.fp16,
use_tqdm=True)
use_tqdm=False)
pbar = trange(args.num_epochs, unit="epoch")
for i in pbar:
info = {"num_steps": 1} if args.smoke_test else {}
@@ -144,7 +144,14 @@ def scheduler_creator(optimizer, config):
# __torch_scheduler_end__
# __backwards_compat__start
# __torch_ray_start__
import ray
ray.init()
# or ray.init(address="auto") to connect to a running cluster.
# __torch_ray_end__
# __backwards_compat_start__
from ray.util.sgd import TorchTrainer
MyTrainingOperator = TrainingOperator.from_creators(
@@ -157,14 +164,9 @@ trainer = TorchTrainer(
scheduler_step_freq="epoch", # if scheduler_creator is passed in
config={"lr": 0.001, "batch_size": 64})
# __backwards_compat_end
# __backwards_compat_end__
# __torch_ray_start__
import ray
ray.init()
# or ray.init(address="auto") to connect to a running cluster.
# __torch_ray_end__
trainer.shutdown()
# __torch_trainer_start__
from ray.util.sgd import TorchTrainer
@@ -175,3 +177,5 @@ trainer = TorchTrainer(
config={"lr": 0.001, "batch_size": 64})
# __torch_trainer_end__
trainer.shutdown()
@@ -52,7 +52,7 @@ def tune_example(operator_cls, num_workers=1, use_gpu=False):
stop={"training_iteration": 2},
verbose=1)
return analysis.get_best_config(metric="validation_loss", mode="min")
return analysis.get_best_config(metric="val_loss", mode="min")
# __end_torch_tune_example__
+24 -13
View File
@@ -3,7 +3,6 @@ import io
import itertools
import torch
import ray
from ray.util.sgd.torch.constants import USE_FP16, NUM_STEPS
from ray.util.sgd import utils
@@ -57,14 +56,6 @@ class TorchRunner:
apex_args=self.apex_args,
scheduler_step_freq=self.scheduler_step_freq)
def get_node_ip(self):
"""Returns the IP address of the current node."""
return ray.services.get_node_ip_address()
def find_free_port(self):
"""Finds a free port on the current node."""
return utils.find_free_port()
def train_epoch(self,
num_steps=None,
profile=False,
@@ -132,11 +123,15 @@ class TorchRunner:
def state_dict(self):
"""Returns the state of the runner."""
model_states = [model.state_dict() for model in self.models]
optimizer_states = [
optimizer.state_dict() for optimizer in self.optimizers
]
state = {
"epoch": self.epochs,
"operator": self.training_operator.state_dict(),
"models": [model.state_dict() for model in self.models],
"optimizers": [opt.state_dict() for opt in self.optimizers]
"models": model_states,
"optimizers": optimizer_states
}
schedulers = self.schedulers
if schedulers:
@@ -148,6 +143,7 @@ class TorchRunner:
# Check if fp16 is True and if NVIDIA Apex is imported.
if self.use_fp16 and self.training_operator._amp:
state.update({"amp": self.training_operator._amp.state_dict()})
return state
def load_state_dict(self, state):
@@ -176,9 +172,20 @@ class TorchRunner:
return _buffer.getvalue()
def load_state_stream(self, byte_obj):
"""Loads a bytes object the training state dict."""
"""Loads a bytes object the training state dict.
This is needed because we don't want to deserialize the tensor
onto the same device (which is from the driver process). We want to
map it onto the actor's specific device.
From: github.com/pytorch/pytorch/issues/10622#issuecomment-474733769
"""
_buffer = io.BytesIO(byte_obj)
state_dict = torch.load(_buffer)
to_gpu = self.use_gpu and torch.cuda.is_available()
state_dict = torch.load(
_buffer,
map_location=("cpu" if not to_gpu else
lambda storage, loc: storage.cuda()))
return self.load_state_dict(state_dict)
def apply(self, fn):
@@ -193,6 +200,10 @@ class TorchRunner:
if torch.cuda.is_available():
torch.cuda.empty_cache()
def get_models(self):
"""Getter method. Needed for remote actor calls."""
return self.models
@property
def models(self):
if not hasattr(self.training_operator, "_original_models"):
+124 -215
View File
@@ -1,29 +1,25 @@
from datetime import timedelta
import time
import numpy as np
import logging
import os
import numbers
import tempfile
import time
import torch
import torch.distributed as dist
import ray
from ray.exceptions import RayActorError
from ray.tune import Trainable
from ray.tune.resources import Resources
from ray.tune.utils.util import merge_dicts
from ray.util import log_once
from ray.util.sgd.torch.distributed_torch_runner import (
DistributedTorchRunner, LocalDistributedRunner, DeactivatedRunner)
from ray.util.sgd.utils import check_for_failure, NUM_SAMPLES, BATCH_SIZE
from ray.util.sgd.torch.torch_runner import TorchRunner
from ray.util.sgd.torch.worker_group import LocalWorkerGroup, \
RemoteWorkerGroup, DeactivatedWorkerGroup
from ray.util.sgd.utils import NUM_SAMPLES, BATCH_SIZE
from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP, NCCL_TIMEOUT_S
from ray.util.sgd.torch.utils import setup_address
from ray.util.sgd.data import Dataset
logger = logging.getLogger(__name__)
RESIZE_COOLDOWN_S = 10
def _validate_scheduler_step_freq(scheduler_step_freq):
@@ -86,11 +82,17 @@ class TorchTrainer:
that subclasses the TrainingOperator class. This class
will be copied onto all remote workers and used to specify
training components and custom training and validation operations.
initialization_hook (function): A function to call on all training
workers when they are first initialized. This could be useful to
set environment variables for all the worker processes.
config (dict): Custom configuration value to be passed to
all operator constructors.
num_workers (int): the number of workers used in distributed
training. If 1, the worker will not be wrapped with
DistributedDataParallel.
DistributedDataParallel. TorchTrainer will scale down the number
of workers if enough resources are not available, and will scale
back up once they are. The total number of
workers will never exceed `num_workers` amount.
num_cpus_per_worker (int): Sets the cpu requirement for each worker.
use_gpu (bool): Sets resource allocation for workers to 1 GPU
if true, and automatically moves both the model and optimizer
@@ -119,9 +121,13 @@ class TorchTrainer:
``step`` will be called after every optimizer step. If "epoch",
``step`` will be called after one pass of the DataLoader. If
"manual", the scheduler will not be incremented automatically -
you are expected to call ``trainer.update_schedulers`` manually.
you are expected to call ``trainer.update_scheduler`` manually.
If a scheduler is passed in, this value is expected to not be None.
use_local (bool): If True, 1 worker will be a local worker running
on the driver process, and all other workers will be remote. If
False, all workers will be remote. Set this to True for easy
debugging of worker on driver process, but could also
lead to issues with Cuda devices. Defaults to False.
"""
# TODO: Implement autoscaling. If num_workers=-1, the trainer will use as
@@ -146,6 +152,7 @@ class TorchTrainer:
apex_args=None,
add_dist_sampler=True,
scheduler_step_freq=None,
use_local=False,
# Deprecated Args.
num_replicas=None,
batch_size=None,
@@ -168,6 +175,13 @@ class TorchTrainer:
"model_creator, ...) and pass in CustomOperator into "
"TorchTrainer.")
if use_local and log_once("use_local"):
logger.warning("use_local is set to True. This could lead to "
"issues with Cuda devices. If you are seeing this "
"issue, try setting use_local to False. For more "
"information, see "
"https://github.com/ray-project/ray/issues/9202.")
if num_workers > 1 and not dist.is_available():
raise ValueError(
("Distributed PyTorch is not supported on macOS. "
@@ -225,6 +239,7 @@ class TorchTrainer:
self.use_fp16 = use_fp16
self.use_tqdm = use_tqdm
self.add_dist_sampler = add_dist_sampler
self.use_local = use_local
if apex_args and not isinstance(apex_args, dict):
raise ValueError("apex_args needs to be a dict object.")
@@ -234,9 +249,6 @@ class TorchTrainer:
self._num_failures = 0
self._last_resize = float("-inf")
self.local_worker = DeactivatedRunner()
self.remote_workers = []
if scheduler_step_freq:
_validate_scheduler_step_freq(scheduler_step_freq)
@@ -270,12 +282,10 @@ class TorchTrainer:
return batch_size_per_worker
def _start_workers(self, num_workers):
logger.debug(f"start_workers: Setting %d workers." % num_workers)
worker_config = self.config.copy()
batch_size_per_worker = self._configure_and_split_batch(num_workers)
if batch_size_per_worker:
worker_config[BATCH_SIZE] = batch_size_per_worker
params = dict(
training_operator_cls=self.training_operator_cls,
config=worker_config,
@@ -286,57 +296,68 @@ class TorchTrainer:
apex_args=self.apex_args,
scheduler_step_freq=self.scheduler_step_freq)
if num_workers == 1:
# Start local worker
self.local_worker = TorchRunner(**params)
if self.initialization_hook:
self.apply_all_workers(self.initialization_hook)
self.local_worker.setup_operator()
dist_params = dict(
backend=self.backend,
add_dist_sampler=self.add_dist_sampler,
wrap_ddp=self.wrap_ddp)
worker_args = {
"max_workers": num_workers,
"params": params,
"dist_params": dist_params,
"initialization_hook": self.initialization_hook,
"num_cpus_per_worker": self.num_cpus_per_worker,
"use_gpu": self.use_gpu,
"timeout_s": self.timeout_s
}
if self.use_local:
self.worker_group = LocalWorkerGroup(**worker_args)
else:
params.update(
backend=self.backend,
add_dist_sampler=self.add_dist_sampler,
wrap_ddp=self.wrap_ddp)
self.worker_group = RemoteWorkerGroup(**worker_args)
# Start local worker
self.local_worker = LocalDistributedRunner(
num_cpus=self.num_cpus_per_worker,
num_gpus=int(self.use_gpu),
**params)
# TODO(amogkam): If not enough resources are available to create
# num_workers workers, this command will hang. Instead,
# start_workers should take into account available resources when
# determining how many workers to create.
self.worker_group.start_workers(num_workers)
# Generate actor class
RemoteRunner = ray.remote(
num_cpus=self.num_cpus_per_worker,
num_gpus=int(self.use_gpu))(DistributedTorchRunner)
# Start workers
self.remote_workers = [
RemoteRunner.remote(**params) for i in range(num_workers - 1)
]
if self.initialization_hook:
self.apply_all_workers(self.initialization_hook)
def _resize_worker_group(self, max_retries=10):
"""Resizes the number of remote workers based on available resources.
Total number of workers will never exceed `num_workers` amount.
# Compute URL for initializing distributed PyTorch
address = setup_address()
Args:
max_retries (int): How many times to attempt to resize workers
before failing.
"""
state_dict = self.state_dict()
old_workers = self.worker_group.num_workers
self.worker_group.reset()
# Setup the process group among all workers.
remote_pgroup_setups = [
worker.setup_process_group.remote(address, i + 1, num_workers,
timedelta(self.timeout_s))
for i, worker in enumerate(self.remote_workers)
]
self.local_worker.setup_process_group(address, 0, num_workers,
timedelta(self.timeout_s))
# Get setup tasks in order to throw errors on failure
ray.get(remote_pgroup_setups)
# Runs code that requires all creator functions to have run.
remote_operator_setups = [
worker.setup_operator.remote()
for worker in self.remote_workers
]
self.local_worker.setup_operator()
# Get setup tasks in order to throw errors on failure
ray.get(remote_operator_setups)
time.sleep(1)
for i in range(max_retries):
new_workers = self.worker_group.new_workers_size()
if new_workers:
self._last_resize = time.time()
self._start_workers(int(new_workers))
self.load_state_dict(state_dict, blocking=True)
if self.use_local and new_workers == 1 and old_workers > 1:
# Major hack. If we go from LocalDistributedRunner to a
# standard TorchRunner we have to manually reset the
# dummy actor handle global vars.
# TODO(amog): Refactor LocalDistributedTorchRunner to
# not use global variables for resource reservation.
ray.util.sgd.torch.distributed_torch_runner\
._dummy_cuda_actor = None
ray.util.sgd.torch.distributed_torch_runner\
._dummy_cpu_actor = None
return
else:
delay = 2**i
logger.warning(
"No new workers found. Retrying in %d sec." % delay)
time.sleep(delay)
raise RuntimeError("Exceeded max_retries for relaunching workers.")
def train(self,
num_steps=None,
@@ -384,10 +405,10 @@ class TorchTrainer:
assert isinstance(dataset, Dataset) is not None \
or self.data_creator, \
"Must specify either a data creator or a dataset"
if self._should_resize():
if self.worker_group.should_scale_up():
logger.info("Resize opportunity detected. Attempting to scale up.")
self._resize_workers()
success, worker_stats = self._train_epoch(
self._resize_worker_group()
success, worker_stats = self.worker_group.train(
num_steps=num_steps, profile=profile, info=info, dataset=dataset)
# Fault handling
for i in range(max_retries):
@@ -395,10 +416,10 @@ class TorchTrainer:
break
else:
self._num_failures += 1
self._resize_workers()
self._resize_worker_group()
logger.info("Retrying training step with %d workers." %
(len(self.remote_workers) + 1))
success, worker_stats = self._train_epoch(
self.worker_group.num_workers)
success, worker_stats = self.worker_group.train(
num_steps=num_steps,
profile=profile,
info=info,
@@ -425,43 +446,6 @@ class TorchTrainer:
stats[stat_key] = worker_stats[0][stat_key]
return stats
def _train_epoch(self,
num_steps=None,
profile=False,
info=None,
dataset=None):
params = dict(num_steps=num_steps, profile=profile, info=info)
remote_worker_stats = []
if dataset:
dataset.set_num_shards(self.max_replicas)
for i, w in enumerate(self.remote_workers):
params = dict(num_steps=num_steps, profile=profile, info=info)
if dataset:
params["iterator"] = dataset.get_shard(i)
stats = w.train_epoch.remote(**params)
remote_worker_stats.append(stats)
try:
if dataset:
params["iterator"] = dataset.get_shard(
len(self.remote_workers))
local_worker_stats = self.local_worker.train_epoch(**params)
except RuntimeError as err:
if "gloo" in err.args[0] and "Timed out" in err.args[0]:
logger.warning(err)
return False, None
if "NCCL" in err.args[0]: # there is no specific error message
logger.warning(err)
return False, None
raise err
success = check_for_failure(remote_worker_stats)
if success:
return success, [local_worker_stats] + ray.get(remote_worker_stats)
return success, None
def apply_all_workers(self, fn):
"""Run a function on all operators on the workers.
@@ -472,9 +456,7 @@ class TorchTrainer:
A list of objects returned by ``fn`` on each worker.
"""
remote_calls = [w.apply.remote(fn) for w in self.remote_workers]
local_call = self.local_worker.apply(fn)
return [local_call] + ray.get(remote_calls)
return self.worker_group.apply_all_workers(fn)
def apply_all_operators(self, fn):
"""Run a function on all operators on the workers.
@@ -487,11 +469,7 @@ class TorchTrainer:
A list of objects returned by ``fn`` on each operator.
"""
remote_calls = [
w.apply_operator.remote(fn) for w in self.remote_workers
]
local_call = self.local_worker.apply_operator(fn)
return [local_call] + ray.get(remote_calls)
return self.worker_group.apply_all_operators(fn)
def validate(self,
num_steps=None,
@@ -517,13 +495,8 @@ class TorchTrainer:
You can provide custom metrics by passing in a custom
``training_operator_cls``.
"""
params = dict(num_steps=num_steps, profile=profile, info=info)
remote_worker_stats = [
w.validate.remote(**params) for w in self.remote_workers
]
local_worker_stats = self.local_worker.validate(**params)
worker_stats = [local_worker_stats] + ray.get(remote_worker_stats)
worker_stats = self.worker_group.validate(
num_steps=num_steps, profile=profile, info=info)
if reduce_results:
return self._process_stats(worker_stats)
@@ -535,13 +508,14 @@ class TorchTrainer:
This is useful for lr_schedulers such as ``ReduceLROnPlateau``.
"""
self.apply_all_operators(
self.worker_group.apply_all_operators(
lambda op: [sched.step(metric) for sched in op._schedulers])
def get_model(self):
"""Returns the learned model(s)."""
unwrapped = []
for model in self.local_worker.models:
models = self.worker_group.get_model()
for model in models:
unwrapped += [model.module if hasattr(model, "module") else model]
if len(unwrapped) == 1:
return unwrapped[0]
@@ -556,23 +530,13 @@ class TorchTrainer:
Returns:
TrainingOperator: The local TrainingOperator object.
"""
return self.local_worker.training_operator
return self.worker_group.get_local_operator()
def state_dict(self):
return self.local_worker.state_dict()
return self.worker_group.state_dict()
def load_state_dict(self, state_dict, blocking=False):
# This is not the most efficient because you have to wait for
# the local worker to save then dump to buffer.
self.local_worker.load_state_dict(state_dict)
state_id = ray.put(self.local_worker.state_stream())
remote_calls = [
worker.load_state_stream.remote(state_id)
for worker in self.remote_workers
]
if blocking:
ray.get(remote_calls)
self.worker_group.load_state_dict(state_dict, blocking=blocking)
def save(self, checkpoint):
"""Saves the Trainer state to the provided checkpoint path.
@@ -596,81 +560,16 @@ class TorchTrainer:
raise DeprecationWarning("Use `TorchTrainer.load()` instead.")
def shutdown(self, force=False):
"""Shuts down workers and releases resources."""
if not force:
cleanup = [
worker.shutdown.remote() for worker in self.remote_workers
]
self.local_worker.shutdown()
try:
ray.get(cleanup)
[
worker.__ray_terminate__.remote()
for worker in self.remote_workers
]
except RayActorError:
logger.warning(
"Failed to shutdown gracefully, forcing a shutdown.")
"""Shuts down workers and releases resources.
for worker in self.remote_workers:
logger.warning(f"Killing worker {worker}.")
ray.kill(worker)
else:
self.local_worker.shutdown()
for worker in self.remote_workers:
logger.debug(f"Killing worker {worker}.")
ray.kill(worker)
Args:
force (bool): If True, forcefully kill all workers. If False,
attempt a graceful shutdown first, and then forcefully kill if
unsuccessful.
self.local_worker = DeactivatedRunner()
self.remote_workers = []
def _reset(self):
"""Terminates models without giving up local resource reservation."""
self.local_worker.shutdown(cleanup=False)
for worker in self.remote_workers:
logger.debug(f"Killing worker {worker}.")
ray.kill(worker)
self.local_worker = DeactivatedRunner()
self.remote_workers = []
def _check_potential_remote_workers_size(self):
# ASSUME 1 GPU + 1 CPU is already reserved for the local worker
remote_resources = ray.available_resources()
max_remote_workers = self.max_replicas - 1
new_remote_workers = min(
remote_resources.get("CPU", 0), max_remote_workers)
if self.use_gpu:
new_remote_workers = min(
remote_resources.get("GPU", 0), new_remote_workers)
return new_remote_workers
def _resize_workers(self, max_retries=10):
self._reset()
time.sleep(1)
for i in range(max_retries):
new_remote_workers = self._check_potential_remote_workers_size()
if new_remote_workers:
self._last_resize = time.time()
self._start_workers(int(new_remote_workers) + 1)
self.load_state_dict(self.state_dict())
return
else:
delay = 2**i
logger.warning(
"No new workers found. Retrying in %d sec." % delay)
time.sleep(delay)
raise RuntimeError("Exceeded max_retries for relaunching workers.")
def _should_resize(self):
"""Returns True if past cooldown and exists resources to scale up."""
worker_gap = self.max_replicas - 1 - len(self.remote_workers)
past_cooldown = (time.time() - self._last_resize) > RESIZE_COOLDOWN_S
if past_cooldown and worker_gap:
# Assume 1 resource is already reserved for local worker.
potential_remote_size = self._check_potential_remote_workers_size()
return potential_remote_size > 0
return False
"""
self.worker_group.shutdown(force=force)
self.worker_group = DeactivatedWorkerGroup()
@classmethod
def as_trainable(cls, *args, **kwargs):
@@ -698,16 +597,26 @@ class TorchTrainer:
def default_resource_request(cls, config):
num_workers = config.get("num_workers",
kwargs.get("num_workers", 1))
num_cpus = config.get("num_cpus_per_worker",
kwargs.get("num_cpus_per_worker", 1))
num_cpus_per_worker = config.get(
"num_cpus_per_worker", kwargs.get("num_cpus_per_worker",
1))
use_gpu = config.get("use_gpu", kwargs.get("use_gpu"))
use_local = config.get("use_local",
kwargs.get("use_local", False))
remote_worker_count = num_workers - 1
if use_local:
remote_worker_count = num_workers - 1
local_cpus = 1
local_gpus = int(use_gpu)
else:
remote_worker_count = num_workers
local_cpus = 0
local_gpus = 0
return Resources(
cpu=num_cpus,
gpu=int(use_gpu),
extra_cpu=int(remote_worker_count),
cpu=int(local_cpus * num_cpus_per_worker),
gpu=int(local_gpus),
extra_cpu=int(remote_worker_count * num_cpus_per_worker),
extra_gpu=int(int(use_gpu) * remote_worker_count))
def _create_trainer(self, tune_config):
@@ -255,8 +255,8 @@ class TrainingOperator:
else:
self._criterion = None
logger.debug("Setting up Apex.")
if self.use_fp16 and amp:
logger.debug("Setting up Apex.")
self._models, self._optimizers = amp.initialize(
self._models, self._optimizers, **self._apex_args)
self._amp = amp
@@ -649,7 +649,9 @@ class TrainingOperator:
"""Override this to return a representation of the operator state.
Any argument passed into self.register and self.register_data will
automatically be saved.
Use this method to save any additional state.
Use this method to save any additional state. If your TorchTrainer
is on a CPU-only machine, make sure this method converts all state
to be CPU-compatible.
Returns:
dict: The state dict of the operator."""
+1 -1
View File
@@ -1,8 +1,8 @@
import os
import logging
import torch.distributed as dist
import ray
import torch.distributed as dist
from ray.util.sgd.utils import find_free_port
logger = logging.getLogger(__name__)
+574
View File
@@ -0,0 +1,574 @@
import io
import logging
import time
from datetime import timedelta
import ray
import torch
from ray.exceptions import RayActorError
from ray.util.sgd.torch.distributed_torch_runner import \
LocalDistributedRunner, DistributedTorchRunner
from ray.util.sgd.torch.torch_runner import TorchRunner
from ray.util.sgd.torch.utils import setup_address
from ray.util.sgd.utils import check_for_failure
RESIZE_COOLDOWN_S = 10
logger = logging.getLogger(__name__)
class WorkerGroupInterface:
"""Manages a group of TorchRunner workers."""
def start_workers(self, num_workers):
"""Start workers for training.
This method has 4 steps.
1. Creates `num_workers` TorchRunner objects, either all as remote
actors or 1 locally and all but one remote.
2. If necessary, applies an initialization hook to all the
workers.
3. Sets up the process group for all workers if running in
distributed setting.
4. Instantiates the user provided TrainingOperator on all
workers to set up training state.
"""
raise NotImplementedError
def apply_all_operators(self, fn):
"""See TorchTrainer.apply_all_operators."""
raise NotImplementedError
def apply_all_workers(self, fn):
"""See TorchTrainer.apply_all_workers."""
raise NotImplementedError
def get_local_operator(self):
"""See TorchTrainer.get_local_operator."""
raise NotImplementedError
def get_model(self):
"""See TorchTrainer.get_model."""
raise NotImplementedError
def load_state_dict(self, state_dict, blocking=False):
"""See TorchTrainer.load_state_dict."""
raise NotImplementedError
def new_workers_size(self):
"""Returns number of workers to create based on available resources.
Total number of workers will never exceed `max_workers` amount.
"""
raise NotImplementedError
def reset(self):
"""Resets worker group."""
raise NotImplementedError
def should_scale_up(self):
"""Returns whether to scale up the number of remote workers.
This method returns True if current number of workers is less than
max_workers provided during startup, enough resources are
available to scale up, and a sufficient cooldown period has passed.
"""
raise NotImplementedError
def shutdown(self, force=False):
"""See TorchTrainer.shutdown."""
raise NotImplementedError
def state_dict(self):
"""See TorchTrainer.state_dict."""
raise NotImplementedError
def train(self, num_steps=None, profile=False, info=None, dataset=None):
"""Runs one epoch of training on all workers.
Args:
See TorchTrainer.train.
Returns:
Tuple of (bool, list). First value is True if training was
successful and False otherwise. Second value is list of
training results from all workers if training was successful,
and None otherwise.
"""
raise NotImplementedError
def validate(self, num_steps=None, profile=False, info=None):
"""Runs validation for all workers.
Args:
See TorchTrainer.validate.
Return:
List of validation results for each worker.
"""
raise NotImplementedError
class RemoteWorkerGroup(WorkerGroupInterface):
"""A group of TorchRunner workers that are all remote Ray actors.
Args:
max_workers (int): Maximum number of workers to use.
params (dict): Parameters to pass into a TorchRunner worker.
dist_params (dict): Additional parameters for distributed training
to pass into a DistributedTorchRunner worker.
initialization_hook (Callable): See TorchTrainer.__init__.
timeout_s (float): See TorchTrainer.__init__.
num_cpus_per_worker (int): See TorchTrainer.__init__.
use_gpu (bool): See TorchTrainer.__init__.
"""
def __init__(self, max_workers, params, dist_params, initialization_hook,
timeout_s, num_cpus_per_worker, use_gpu):
# Invariant: These variables should never change state!
self._max_workers = max_workers
self._params = params
self._dist_params = dist_params
self._initialization_hook = initialization_hook
self._timeout_s = timeout_s
self._num_cpus_per_worker = num_cpus_per_worker
self._use_gpu = use_gpu
self.remote_workers = []
# The last time when this worker group was resized.
self._last_resize = float("-inf")
def _init_dist_workers(self, num_workers):
"""Create `num_workers` remote workers."""
# Generate actor class
RemoteRunner = ray.remote(
num_cpus=self._num_cpus_per_worker,
num_gpus=int(self._use_gpu))(DistributedTorchRunner)
# Start workers
self.remote_workers = [
RemoteRunner.remote(**{
**self._params,
**self._dist_params
}) for _ in range(num_workers)
]
def _setup_process_group(self, address, world_size, starting_rank=0):
"""Sets up process group for all workers.
Args:
address (str): Address to use for TCP process group setup. The
provided address must use the IP address of the node that the
rank 0 worker is located on.
world_size (int): Total number of training workers in the
process group. This may differ from self.num_workers if
there are additional workers outside this worker group class.
starting_rank (int): The rank to use for the first worker.
Worker ranks will be in [starting_rank,
len(self.remote_workers)+starting_rank). This is useful if
you want to use worker outside of this group as the rank 0
worker.
Returns:
List of process group set up promises.
"""
# Setup the process group among all workers.
remote_pgroup_setups = [
worker.setup_process_group.remote(
url=address,
world_rank=i + starting_rank,
world_size=world_size,
timeout=timedelta(self._timeout_s))
for i, worker in enumerate(self.remote_workers)
]
return remote_pgroup_setups
def _setup_operator(self):
"""Instantiates user provided TrainingOperator on all workers.
Returns:
List of operator setup promises.
"""
# Runs code that requires all creator functions to have run.
remote_operator_setups = [
worker.setup_operator.remote() for worker in self.remote_workers
]
return remote_operator_setups
def start_workers(self, num_workers):
logger.debug(f"start_workers: Setting %d workers." % num_workers)
if num_workers == 1:
RemoteRunner = ray.remote(
num_cpus=self._num_cpus_per_worker,
num_gpus=int(self._use_gpu))(TorchRunner)
self.remote_workers = [RemoteRunner.remote(**self._params)]
ray.get(self.remote_workers[0].setup_operator.remote())
else:
self._init_dist_workers(num_workers)
if self._initialization_hook:
self.apply_all_workers(self._initialization_hook)
# Make sure to get the IP address of the rank 0 worker node.
address = ray.get(self.remote_workers[0].setup_address.remote())
ray.get(
self._setup_process_group(
address=address, world_size=num_workers))
ray.get(self._setup_operator())
def _apply_all_operators(self, fn):
remote_calls = [
w.apply_operator.remote(fn) for w in self.remote_workers
]
return remote_calls
def apply_all_operators(self, fn):
return ray.get(self._apply_all_operators(fn))
def _apply_all_workers(self, fn):
return [w.apply.remote(fn) for w in self.remote_workers]
def apply_all_workers(self, fn):
return ray.get(self._apply_all_workers(fn))
def get_local_operator(self):
raise NotImplementedError(
"Cannot return a local operators if all "
"workers are remote. Set use_local to True in"
"TorchTrainer to access a local operator.")
def get_model(self):
ready, _ = ray.wait(
[r.get_models.remote() for r in self.remote_workers])
models = ray.get(ready[0])
return models
def _load_state_id(self, state_id):
"""Loads the object with id `state_id` to all workers."""
remote_calls = [
worker.load_state_stream.remote(state_id)
for worker in self.remote_workers
]
return remote_calls
def load_state_dict(self, state_dict, blocking=False):
_buffer = io.BytesIO()
torch.save(state_dict, _buffer)
stream = _buffer.getvalue()
state_id = ray.put(stream)
remote_calls = self._load_state_id(state_id)
if blocking:
ray.get(remote_calls)
def state_dict(self):
# This is needed to handle calling ray.get on a dead actor.
buffer_object = None
futures = {r.state_stream.remote() for r in self.remote_workers}
while len(futures) > 0:
ready, _ = ray.wait(list(futures), num_returns=1)
object_ref = ready[0]
try:
buffer_object = ray.get(object_ref)
except RayActorError:
futures.remove(object_ref)
else:
break
if buffer_object is None:
raise RuntimeError("Obtaining state_dict from remote workers is "
"unsuccessful since all workers have died.")
to_gpu = self._use_gpu and torch.cuda.is_available()
_buffer = io.BytesIO(buffer_object)
state_dict = torch.load(
_buffer,
map_location=("cpu" if not to_gpu else
lambda storage, loc: storage.cuda()))
return state_dict
def _train(self, num_steps, profile, info, dataset=None):
"""Runs 1 epoch of training on all workers.
Returns training result for all workers as promises.
"""
remote_worker_stats = []
for i, w in enumerate(self.remote_workers):
params = dict(num_steps=num_steps, profile=profile, info=info)
if dataset:
params["iterator"] = dataset.get_shard(i)
stats = w.train_epoch.remote(**params)
remote_worker_stats.append(stats)
return remote_worker_stats
def train(self, num_steps=None, profile=False, info=None, dataset=None):
"""Runs 1 epoch of training on all workers.
Has additional logic to check for worker failure.
"""
if dataset:
dataset.set_num_shards(self.num_workers)
remote_worker_stats = self._train(num_steps, profile, info, dataset)
# Check if each worker has failed before calling ray.get.
success = check_for_failure(remote_worker_stats)
if success:
return success, ray.get(remote_worker_stats)
return success, None
def _validate(self, params):
"""Runs validation for each worker. Returns results as promises."""
remote_worker_stats = [
w.validate.remote(**params) for w in self.remote_workers
]
return remote_worker_stats
def validate(self, num_steps=None, profile=False, info=None):
params = dict(num_steps=num_steps, profile=profile, info=info)
remote_worker_stats = self._validate(params)
return ray.get(remote_worker_stats)
def _shutdown_remote_workers(self):
"""Shuts down each worker and returns a list of cleanup promises."""
cleanup = [worker.shutdown.remote() for worker in self.remote_workers]
return cleanup
def _terminate_remote_workers(self, cleanup):
"""Blocks on worker shutdown and then terminates each worker actor.
If graceful shutdown fails, forcefully kills all actors.
"""
try:
ray.get(cleanup)
[
worker.__ray_terminate__.remote()
for worker in self.remote_workers
]
except RayActorError:
logger.warning("Failed to shutdown gracefully, forcing a "
"shutdown.")
self.reset()
def shutdown(self, force=False):
if not force:
cleanup = [
worker.shutdown.remote() for worker in self.remote_workers
]
self._terminate_remote_workers(cleanup)
else:
self.reset()
self.remote_workers = []
def reset(self):
for worker in self.remote_workers:
logger.debug(f"Killing worker {worker}.")
ray.kill(worker)
self.remote_workers = []
def should_scale_up(self):
worker_gap = self._max_workers - self.num_workers
past_cooldown = (time.time() - self._last_resize) > RESIZE_COOLDOWN_S
if past_cooldown and worker_gap:
# Assume 1 resource is already reserved for local worker.
potential_remote_size = self._check_potential_remote_workers_size()
return potential_remote_size > 0
return False
def new_workers_size(self):
"""Returns number of workers to create based on available resources."""
remote_resources = ray.available_resources()
max_remote_workers = self._max_workers
new_remote_workers = min(
remote_resources.get("CPU", 0), max_remote_workers)
if self._use_gpu:
new_remote_workers = min(
remote_resources.get("GPU", 0), new_remote_workers)
return new_remote_workers
@property
def num_workers(self):
"""Current number of workers being used for training.
This may differ from self.num_workers if self.resize_workers has
been called.
"""
return len(self.remote_workers)
class LocalWorkerGroup(WorkerGroupInterface):
"""A group of TorchRunner workers.
1 worker runs locally, and all the other workers are remote actors.
Args:
Same as RemoteWorkerGroup.
"""
def __init__(self, max_workers, params, dist_params, initialization_hook,
timeout_s, num_cpus_per_worker, use_gpu):
# Invariant: These variables should never change state!
self._max_workers = max_workers
self._params = params
self._dist_params = dist_params
self._initialization_hook = initialization_hook
self._timeout_s = timeout_s
self._num_cpus_per_worker = num_cpus_per_worker
self._use_gpu = use_gpu
self.local_worker = None
self.remote_worker_group = RemoteWorkerGroup(
max_workers=max_workers - 1,
params=params,
dist_params=dist_params,
initialization_hook=initialization_hook,
timeout_s=timeout_s,
num_cpus_per_worker=num_cpus_per_worker,
use_gpu=use_gpu)
def start_workers(self, num_workers):
logger.debug(f"start_workers: Setting %d workers." % num_workers)
if num_workers == 1:
self.local_worker = TorchRunner(**self._params)
if self._initialization_hook:
self.apply_all_workers(self._initialization_hook)
self.local_worker.setup_operator()
else:
# Start local worker
self.local_worker = LocalDistributedRunner(
num_cpus=self._num_cpus_per_worker,
num_gpus=int(self._use_gpu),
**{
**self._params,
**self._dist_params
})
self.remote_worker_group._init_dist_workers(num_workers - 1)
if self._initialization_hook:
self.apply_all_workers(self._initialization_hook)
# Compute URL for initializing distributed PyTorch.
address = setup_address()
remote_pgs = self.remote_worker_group._setup_process_group(
address=address, world_size=num_workers, starting_rank=1)
# Use the local worker as rank 0. This will help with debugging.
self.local_worker.setup_process_group(
url=address,
world_rank=0,
world_size=num_workers,
timeout=timedelta(self._timeout_s))
ray.get(remote_pgs)
remote_operators = self.remote_worker_group._setup_operator()
self.local_worker.setup_operator()
ray.get(remote_operators)
def apply_all_operators(self, fn):
remote_calls = self.remote_worker_group._apply_all_operators(fn)
local_call = self.local_worker.apply_operator(fn)
return [local_call] + ray.get(remote_calls)
def apply_all_workers(self, fn):
remote_calls = self.remote_worker_group.apply_all_workers(fn)
local_call = self.local_worker.apply(fn)
return [local_call] + ray.get(remote_calls)
def get_local_operator(self):
return self.local_worker.training_operator
def get_model(self):
return self.local_worker.models
def load_state_dict(self, state_dict, blocking=False):
# This is not the most efficient because you have to wait for
# the local worker to save then dump to buffer.
self.local_worker.load_state_dict(state_dict)
state_id = ray.put(self.local_worker.state_stream())
remote_calls = self.remote_worker_group._load_state_id(state_id)
if blocking:
ray.get(remote_calls)
def state_dict(self):
return self.local_worker.state_dict()
def should_scale_up(self):
return self.remote_worker_group.should_scale_up()
def reset(self):
"""Terminates models without giving up local resource reservation."""
self.local_worker.shutdown(cleanup=False)
self.remote_worker_group.reset()
self.local_worker = None
self.remote_worker_group = RemoteWorkerGroup(
max_workers=self._max_workers - 1,
params=self._params,
dist_params=self._dist_params,
initialization_hook=self._initialization_hook,
num_cpus_per_worker=self._num_cpus_per_worker,
use_gpu=self._use_gpu,
timeout_s=self._timeout_s)
def new_workers_size(self):
return self.remote_worker_group.new_workers_size() + 1
def train(self, num_steps=None, profile=False, info=None, dataset=None):
params = dict(num_steps=num_steps, profile=profile, info=info)
if dataset:
dataset.set_num_shards(self.num_workers)
remote_worker_stats = self.remote_worker_group._train(
num_steps, profile, info, dataset)
try:
if dataset:
params["iterator"] = dataset.get_shard(self.num_workers - 1)
local_worker_stats = self.local_worker.train_epoch(**params)
except RuntimeError as err:
if "gloo" in err.args[0] and "Timed out" in err.args[0]:
logger.warning(err)
return False, None
if "NCCL" in err.args[0]: # there is no specific error message
logger.warning(err)
return False, None
if "Connection closed by peer" in err.args[0]:
logger.warning(err)
return False, None
raise err
success = check_for_failure(remote_worker_stats)
if success:
return success, [local_worker_stats] + ray.get(remote_worker_stats)
return success, None
def validate(self, num_steps=None, profile=False, info=None):
params = dict(num_steps=num_steps, profile=profile, info=info)
remote_worker_stats = self.remote_worker_group._validate(params)
local_worker_stats = self.local_worker.validate(**params)
worker_stats = [local_worker_stats] + ray.get(remote_worker_stats)
return worker_stats
def shutdown(self, force=False):
if not force:
cleanup = self.remote_worker_group._shutdown_remote_workers()
self.local_worker.shutdown()
self.remote_worker_group._terminate_remote_workers(cleanup)
else:
self.local_worker.shutdown()
self.remote_worker_group.reset()
self.local_worker = None
self.remote_worker_group = DeactivatedWorkerGroup()
@property
def num_workers(self):
return self.remote_worker_group.num_workers + 1
@property
def remote_workers(self):
return self.remote_worker_group.remote_workers
class DeactivatedWorkerGroup:
def __getattr__(self, *args, **kwargs):
raise RuntimeError(
"This TorchTrainer is not active (it is likely shutdown already). "
"Create a new TorchTrainer.")