mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 12:02:04 +08:00
[Ray SGD] use_local flag + Worker group abstraction (#10539)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -75,7 +75,6 @@ class _TorchTrainable(tune.Trainable):
|
||||
remote_trainable = remote_trainable.options(
|
||||
**self.get_remote_worker_options())
|
||||
|
||||
address = setup_address()
|
||||
self.workers = [
|
||||
remote_trainable.remote(
|
||||
config=config,
|
||||
@@ -83,6 +82,10 @@ class _TorchTrainable(tune.Trainable):
|
||||
for rank in range(num_workers)
|
||||
]
|
||||
|
||||
# Address has to be IP of rank 0 worker's node.
|
||||
address = ray.get(
|
||||
self.workers[0].execute.remote(lambda _: setup_address()))
|
||||
|
||||
pgroup_params = self.default_process_group_parameters()
|
||||
from functools import partial
|
||||
setup_on_worker = partial(
|
||||
|
||||
@@ -18,6 +18,14 @@ py_test(
|
||||
deps = [":sgd_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_torch_failure",
|
||||
size = "large",
|
||||
srcs = ["tests/test_torch_failure.py"],
|
||||
tags = ["exclusive", "pytorch"],
|
||||
deps = [":sgd_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_torch_runner",
|
||||
size = "small",
|
||||
|
||||
@@ -84,8 +84,8 @@ class Dataset():
|
||||
Returns a single, iterable shard.
|
||||
"""
|
||||
assert i < self.iter.num_shards(), \
|
||||
"Trying to get shard {} but there are only {} shards." + \
|
||||
"Are you sure you called set_num_shards already?".format(
|
||||
"Trying to get shard {} but there are only {} shards. Are you " \
|
||||
"sure you called set_num_shards already?".format(
|
||||
i, self.iter.num_shards()
|
||||
)
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from unittest.mock import patch
|
||||
import numpy as np
|
||||
import os
|
||||
import pytest
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
@@ -14,8 +12,7 @@ from ray.util.sgd.torch import TorchTrainer
|
||||
from ray.util.sgd.torch.training_operator import (
|
||||
get_test_operator, get_test_metrics_operator, TrainingOperator)
|
||||
from ray.util.sgd.torch.constants import SCHEDULER_STEP
|
||||
from ray.util.sgd.utils import (check_for_failure, NUM_SAMPLES, BATCH_COUNT,
|
||||
BATCH_SIZE)
|
||||
from ray.util.sgd.utils import (NUM_SAMPLES, BATCH_COUNT, BATCH_SIZE)
|
||||
|
||||
from ray.util.sgd.data.examples import mlp_identity
|
||||
from ray.util.sgd.torch.examples.train_example import (
|
||||
@@ -48,8 +45,13 @@ Operator = TrainingOperator.from_creators(
|
||||
model_creator, optimizer_creator, data_creator, loss_creator=nn.MSELoss)
|
||||
|
||||
|
||||
def test_single_step(ray_start_2_cpus): # noqa: F811
|
||||
trainer = TorchTrainer(training_operator_cls=Operator, num_workers=1)
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_single_step(ray_start_2_cpus, use_local): # noqa: F811
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=Operator,
|
||||
num_workers=1,
|
||||
use_local=use_local,
|
||||
use_gpu=False)
|
||||
metrics = trainer.train(num_steps=1)
|
||||
assert metrics[BATCH_COUNT] == 1
|
||||
|
||||
@@ -58,9 +60,14 @@ def test_single_step(ray_start_2_cpus): # noqa: F811
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
def test_dead_trainer(ray_start_2_cpus): # noqa: F811
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_dead_trainer(ray_start_2_cpus, use_local): # noqa: F811
|
||||
TestOperator = get_test_operator(Operator)
|
||||
trainer = TorchTrainer(training_operator_cls=TestOperator, num_workers=2)
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=TestOperator,
|
||||
num_workers=2,
|
||||
use_local=use_local,
|
||||
use_gpu=False)
|
||||
trainer.train(num_steps=1)
|
||||
trainer.shutdown()
|
||||
with pytest.raises(RuntimeError):
|
||||
@@ -68,9 +75,13 @@ def test_dead_trainer(ray_start_2_cpus): # noqa: F811
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
def test_train(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_train(ray_start_2_cpus, num_workers, use_local): # noqa: F811
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=Operator, num_workers=num_workers)
|
||||
training_operator_cls=Operator,
|
||||
num_workers=num_workers,
|
||||
use_local=use_local,
|
||||
use_gpu=False)
|
||||
for i in range(3):
|
||||
train_loss1 = trainer.train()["train_loss"]
|
||||
validation_loss1 = trainer.validate()["val_loss"]
|
||||
@@ -86,7 +97,8 @@ def test_train(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
def test_multi_model(ray_start_2_cpus, num_workers):
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_multi_model(ray_start_2_cpus, num_workers, use_local):
|
||||
def train(*, model=None, criterion=None, optimizer=None, iterator=None):
|
||||
model.train()
|
||||
train_loss = 0
|
||||
@@ -140,7 +152,10 @@ def test_multi_model(ray_start_2_cpus, num_workers):
|
||||
trainer1 = TorchTrainer(
|
||||
config={"custom_func": train_epoch},
|
||||
training_operator_cls=TestOperator,
|
||||
num_workers=num_workers)
|
||||
num_workers=num_workers,
|
||||
use_local=use_local,
|
||||
use_gpu=False,
|
||||
)
|
||||
trainer1.train()
|
||||
state = trainer1.state_dict()
|
||||
|
||||
@@ -151,7 +166,10 @@ def test_multi_model(ray_start_2_cpus, num_workers):
|
||||
trainer2 = TorchTrainer(
|
||||
config={"custom_func": train_epoch},
|
||||
training_operator_cls=TestOperator,
|
||||
num_workers=num_workers)
|
||||
num_workers=num_workers,
|
||||
use_local=use_local,
|
||||
use_gpu=False,
|
||||
)
|
||||
trainer2.load_state_dict(state)
|
||||
|
||||
models2 = trainer2.get_model()
|
||||
@@ -170,7 +188,9 @@ def test_multi_model(ray_start_2_cpus, num_workers):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
def test_multi_model_matrix(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_multi_model_matrix(ray_start_2_cpus, num_workers, use_local): #
|
||||
# noqa: F811
|
||||
def train_epoch(self, iterator, info):
|
||||
if self.config.get("models", 1) > 1:
|
||||
assert len(self.models) == self.config["models"], self.config
|
||||
@@ -231,6 +251,7 @@ def test_multi_model_matrix(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
scheduler_step_freq="epoch",
|
||||
training_operator_cls=TestOperator,
|
||||
num_workers=num_workers,
|
||||
use_local=use_local,
|
||||
config={
|
||||
"models": model_count,
|
||||
"optimizers": optimizer_count,
|
||||
@@ -242,7 +263,8 @@ def test_multi_model_matrix(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
|
||||
|
||||
@pytest.mark.parametrize("scheduler_freq", ["epoch", "batch", "manual", None])
|
||||
def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811
|
||||
def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa:
|
||||
# F811
|
||||
def train_epoch(self, iterator, info):
|
||||
assert info[SCHEDULER_STEP] == scheduler_freq
|
||||
return {"done": 1}
|
||||
@@ -271,20 +293,23 @@ def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811
|
||||
trainer = TorchTrainer(
|
||||
config={"custom_func": train_epoch},
|
||||
training_operator_cls=TestTrainingOperator,
|
||||
scheduler_step_freq=scheduler_freq)
|
||||
scheduler_step_freq=scheduler_freq,
|
||||
)
|
||||
else:
|
||||
trainer = TorchTrainer(
|
||||
config={"custom_func": train_epoch},
|
||||
training_operator_cls=TestTrainingOperator,
|
||||
scheduler_step_freq=scheduler_freq)
|
||||
scheduler_step_freq=scheduler_freq,
|
||||
)
|
||||
|
||||
for i in range(3):
|
||||
trainer.train()
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
def test_profiling(ray_start_2_cpus): # noqa: F811
|
||||
trainer = TorchTrainer(training_operator_cls=Operator)
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_profiling(ray_start_2_cpus, use_local): # noqa: F811
|
||||
trainer = TorchTrainer(training_operator_cls=Operator, use_local=use_local)
|
||||
|
||||
stats = trainer.train(profile=True)
|
||||
assert "profile" in stats
|
||||
@@ -293,7 +318,8 @@ def test_profiling(ray_start_2_cpus): # noqa: F811
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
def test_dataset(ray_start_4_cpus):
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_dataset(ray_start_4_cpus, use_local):
|
||||
"""
|
||||
This test tries training the mlp_identity example. We check the accuracy of
|
||||
the model as an all inclusive way of ensuring that we are properly sharding
|
||||
@@ -312,6 +338,7 @@ def test_dataset(ray_start_4_cpus):
|
||||
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=DatasetOperator,
|
||||
use_local=use_local,
|
||||
num_workers=2,
|
||||
)
|
||||
|
||||
@@ -319,13 +346,14 @@ def test_dataset(ray_start_4_cpus):
|
||||
for i in range(5):
|
||||
trainer.train(dataset=dataset, num_steps=100)
|
||||
|
||||
input = mlp_identity.to_mat(0.5)
|
||||
prediction = float(trainer.get_model()(input)[0][0])
|
||||
x = mlp_identity.to_mat(0.5)
|
||||
prediction = float(trainer.get_model()(x)[0][0])
|
||||
assert 0.4 <= prediction <= 0.6
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
def test_split_batch(ray_start_2_cpus):
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_split_batch(ray_start_2_cpus, use_local):
|
||||
if not dist.is_available():
|
||||
return
|
||||
|
||||
@@ -347,6 +375,7 @@ def test_split_batch(ray_start_2_cpus):
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=TestOperator,
|
||||
num_workers=2,
|
||||
use_local=use_local,
|
||||
config={
|
||||
BATCH_SIZE: batch_size,
|
||||
"data_size": data_size,
|
||||
@@ -358,7 +387,8 @@ def test_split_batch(ray_start_2_cpus):
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
def test_reduce_result(ray_start_2_cpus):
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_reduce_result(ray_start_2_cpus, use_local):
|
||||
if not dist.is_available():
|
||||
return
|
||||
|
||||
@@ -380,6 +410,7 @@ def test_reduce_result(ray_start_2_cpus):
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=TestOperator,
|
||||
num_workers=2,
|
||||
use_local=use_local,
|
||||
config={"data_size": data_size})
|
||||
list_stats = trainer.train(reduce_results=False, profile=True)
|
||||
assert len(list_stats) == 2
|
||||
@@ -393,7 +424,8 @@ def test_reduce_result(ray_start_2_cpus):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
def test_metrics(ray_start_2_cpus, num_workers):
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_metrics(ray_start_2_cpus, num_workers, use_local):
|
||||
data_size, val_size = 600, 500
|
||||
batch_size = 4
|
||||
|
||||
@@ -407,6 +439,7 @@ def test_metrics(ray_start_2_cpus, num_workers):
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=TestOperator,
|
||||
num_workers=num_workers,
|
||||
use_local=use_local,
|
||||
config={
|
||||
"scores": train_scores,
|
||||
"val_scores": val_scores,
|
||||
@@ -440,7 +473,8 @@ def test_metrics(ray_start_2_cpus, num_workers):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
def test_metrics_nan(ray_start_2_cpus, num_workers):
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_metrics_nan(ray_start_2_cpus, num_workers, use_local):
|
||||
data_size, val_size = 100, 100
|
||||
batch_size = 10
|
||||
|
||||
@@ -453,6 +487,7 @@ def test_metrics_nan(ray_start_2_cpus, num_workers):
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=TestOperator,
|
||||
num_workers=num_workers,
|
||||
use_local=use_local,
|
||||
config={
|
||||
"scores": train_scores,
|
||||
"val_scores": val_scores,
|
||||
@@ -474,7 +509,8 @@ def test_metrics_nan(ray_start_2_cpus, num_workers):
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_scheduler_validate(ray_start_2_cpus, use_local): # noqa: F811
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
|
||||
TestOperator = TrainingOperator.from_creators(
|
||||
@@ -485,7 +521,9 @@ def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
|
||||
loss_creator=lambda config: nn.MSELoss())
|
||||
TestOperator = get_test_operator(TestOperator)
|
||||
trainer = TorchTrainer(
|
||||
scheduler_step_freq="manual", training_operator_cls=TestOperator)
|
||||
scheduler_step_freq="manual",
|
||||
training_operator_cls=TestOperator,
|
||||
use_local=use_local)
|
||||
trainer.update_scheduler(0.5)
|
||||
trainer.update_scheduler(0.5)
|
||||
assert all(
|
||||
@@ -494,14 +532,16 @@ def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
def test_tune_train(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
@pytest.mark.parametrize("num_workers", [2] if dist.is_available() else [1])
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_tune_train(ray_start_4_cpus, num_workers, use_local): # noqa: F811
|
||||
TorchTrainable = TorchTrainer.as_trainable(
|
||||
**{
|
||||
"training_operator_cls": Operator,
|
||||
"num_workers": num_workers,
|
||||
"use_gpu": False,
|
||||
"backend": "gloo",
|
||||
"use_local": use_local,
|
||||
"config": {
|
||||
"batch_size": 512,
|
||||
"lr": 0.001
|
||||
@@ -526,10 +566,13 @@ def test_tune_train(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
def test_save_and_restore(ray_start_2_cpus, num_workers,
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_save_and_restore(ray_start_2_cpus, num_workers, use_local,
|
||||
tmp_path): # noqa: F811
|
||||
trainer1 = TorchTrainer(
|
||||
training_operator_cls=Operator, num_workers=num_workers)
|
||||
training_operator_cls=Operator,
|
||||
num_workers=num_workers,
|
||||
use_local=use_local)
|
||||
trainer1.train()
|
||||
checkpoint_path = os.path.join(tmp_path, "checkpoint")
|
||||
trainer1.save(checkpoint_path)
|
||||
@@ -539,7 +582,9 @@ def test_save_and_restore(ray_start_2_cpus, num_workers,
|
||||
trainer1.shutdown()
|
||||
|
||||
trainer2 = TorchTrainer(
|
||||
training_operator_cls=Operator, num_workers=num_workers)
|
||||
training_operator_cls=Operator,
|
||||
num_workers=num_workers,
|
||||
use_local=use_local)
|
||||
trainer2.load(checkpoint_path)
|
||||
|
||||
model2 = trainer2.get_model()
|
||||
@@ -558,14 +603,15 @@ def test_wrap_ddp(ray_start_2_cpus, tmp_path): # noqa: F811
|
||||
if not dist.is_available():
|
||||
return
|
||||
trainer1 = TorchTrainer(
|
||||
training_operator_cls=Operator, wrap_ddp=False, num_workers=2)
|
||||
training_operator_cls=Operator,
|
||||
wrap_ddp=False,
|
||||
num_workers=2,
|
||||
use_local=True)
|
||||
trainer1.train()
|
||||
checkpoint_path = os.path.join(tmp_path, "checkpoint")
|
||||
trainer1.save(checkpoint_path)
|
||||
|
||||
model1 = trainer1.get_model()
|
||||
assert not hasattr(trainer1.local_worker.training_operator.model, "module")
|
||||
assert hasattr(trainer1.local_worker.training_operator, "device_ids")
|
||||
trainer1.shutdown()
|
||||
|
||||
trainer2 = TorchTrainer(
|
||||
@@ -584,123 +630,8 @@ def test_wrap_ddp(ray_start_2_cpus, tmp_path): # noqa: F811
|
||||
trainer2.shutdown()
|
||||
|
||||
|
||||
def gen_step_with_fail(num_fails):
|
||||
def step_with_fail(self,
|
||||
num_steps=None,
|
||||
profile=False,
|
||||
info=None,
|
||||
dataset=None):
|
||||
params = dict(num_steps=num_steps, profile=profile, info=info)
|
||||
remote_worker_stats = [
|
||||
w.train_epoch.remote(**params) for w in self.remote_workers
|
||||
]
|
||||
|
||||
if self._num_failures < num_fails:
|
||||
time.sleep(1) # Make the batch will fail correctly.
|
||||
ray.kill(self.remote_workers[0])
|
||||
|
||||
try:
|
||||
local_worker_stats = self.local_worker.train_epoch(**params)
|
||||
except RuntimeError:
|
||||
return False, None
|
||||
|
||||
success = check_for_failure(remote_worker_stats)
|
||||
if success:
|
||||
return success, [local_worker_stats] + ray.get(remote_worker_stats)
|
||||
|
||||
return success, None
|
||||
|
||||
return step_with_fail
|
||||
|
||||
|
||||
def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
|
||||
if not dist.is_available():
|
||||
return
|
||||
|
||||
def single_loader(config):
|
||||
dataset = LinearDataset(2, 5, size=1000000)
|
||||
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
|
||||
|
||||
step_with_fail = gen_step_with_fail(3)
|
||||
|
||||
TestOperator = TrainingOperator.from_creators(
|
||||
model_creator,
|
||||
optimizer_creator,
|
||||
single_loader,
|
||||
loss_creator=lambda config: nn.MSELoss())
|
||||
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
|
||||
trainer1 = TorchTrainer(
|
||||
training_operator_cls=TestOperator,
|
||||
config={"batch_size": 100000},
|
||||
num_workers=2)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
trainer1.train(max_retries=1)
|
||||
|
||||
trainer1.shutdown(force=True)
|
||||
|
||||
|
||||
def test_resize(ray_start_2_cpus): # noqa: F811
|
||||
if not dist.is_available():
|
||||
return
|
||||
|
||||
def single_loader(config):
|
||||
dataset = LinearDataset(2, 5, size=1000000)
|
||||
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
|
||||
|
||||
step_with_fail = gen_step_with_fail(1)
|
||||
|
||||
TestOperator = TrainingOperator.from_creators(
|
||||
model_creator,
|
||||
optimizer_creator,
|
||||
single_loader,
|
||||
loss_creator=lambda config: nn.MSELoss())
|
||||
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
|
||||
trainer1 = TorchTrainer(
|
||||
training_operator_cls=TestOperator,
|
||||
config={"batch_size": 100000},
|
||||
num_workers=2)
|
||||
|
||||
@ray.remote
|
||||
def try_test():
|
||||
import time
|
||||
time.sleep(100)
|
||||
|
||||
try_test.remote()
|
||||
trainer1.train(max_retries=1)
|
||||
assert len(trainer1.remote_workers) == 1
|
||||
|
||||
trainer1.shutdown()
|
||||
|
||||
|
||||
def test_fail_twice(ray_start_2_cpus): # noqa: F811
|
||||
if not dist.is_available():
|
||||
return
|
||||
|
||||
def single_loader(config):
|
||||
dataset = LinearDataset(2, 5, size=1000000)
|
||||
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
|
||||
|
||||
step_with_fail = gen_step_with_fail(2)
|
||||
|
||||
TestOperator = TrainingOperator.from_creators(
|
||||
model_creator,
|
||||
optimizer_creator,
|
||||
single_loader,
|
||||
loss_creator=lambda config: nn.MSELoss())
|
||||
|
||||
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
|
||||
trainer1 = TorchTrainer(
|
||||
training_operator_cls=TestOperator,
|
||||
config={"batch_size": 100000},
|
||||
num_workers=2)
|
||||
|
||||
# MAX RETRIES SHOULD BE ON BY DEFAULT
|
||||
trainer1.train()
|
||||
trainer1.shutdown()
|
||||
|
||||
|
||||
def test_multi_input_model(ray_start_2_cpus):
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_multi_input_model(ray_start_2_cpus, use_local):
|
||||
def model_creator(config):
|
||||
class MultiInputModel(nn.Module):
|
||||
def __init__(self):
|
||||
@@ -742,7 +673,8 @@ def test_multi_input_model(ray_start_2_cpus):
|
||||
data_creator,
|
||||
loss_creator=lambda config: nn.MSELoss())
|
||||
|
||||
trainer = TorchTrainer(training_operator_cls=Operator, num_workers=1)
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=Operator, num_workers=1, use_local=use_local)
|
||||
|
||||
metrics = trainer.train(num_steps=1)
|
||||
assert metrics[BATCH_COUNT] == 1
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
from unittest.mock import patch
|
||||
import pytest
|
||||
import time
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import ray
|
||||
from ray.util.sgd.torch import TorchTrainer
|
||||
from ray.util.sgd.torch.worker_group import RemoteWorkerGroup
|
||||
from ray.util.sgd.torch.training_operator import TrainingOperator
|
||||
|
||||
from ray.util.sgd.torch.examples.train_example import (
|
||||
model_creator, optimizer_creator, data_creator, LinearDataset)
|
||||
|
||||
Operator = TrainingOperator.from_creators(
|
||||
model_creator, optimizer_creator, data_creator, loss_creator=nn.MSELoss)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ray_start_2_cpus():
|
||||
address_info = ray.init(num_cpus=2)
|
||||
yield address_info
|
||||
# The code after the yield will run as teardown code.
|
||||
ray.shutdown()
|
||||
# Ensure that tests don't ALL fail
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ray_start_4_cpus():
|
||||
address_info = ray.init(num_cpus=4)
|
||||
yield address_info
|
||||
# The code after the yield will run as teardown code.
|
||||
ray.shutdown()
|
||||
# Ensure that tests don't ALL fail
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def remote_worker_train_with_fail(self, num_steps, profile, info,
|
||||
dataset=None):
|
||||
remote_worker_stats = []
|
||||
for i, w in enumerate(self.remote_workers):
|
||||
params = dict(num_steps=num_steps, profile=profile, info=info)
|
||||
if dataset:
|
||||
params["iterator"] = dataset.get_shard(i)
|
||||
stats = w.train_epoch.remote(**params)
|
||||
remote_worker_stats.append(stats)
|
||||
if i == 0 and hasattr(self, "should_fail") and self.should_fail:
|
||||
time.sleep(1)
|
||||
ray.kill(self.remote_workers[i])
|
||||
return remote_worker_stats
|
||||
|
||||
|
||||
start_workers = TorchTrainer._start_workers
|
||||
|
||||
|
||||
def gen_start_with_fail(num_fails):
|
||||
def start_with_fail(self, *args, **kwargs):
|
||||
start_workers(self, *args, **kwargs)
|
||||
fail = self._num_failures < num_fails
|
||||
if self.use_local:
|
||||
self.worker_group.remote_worker_group.should_fail = fail
|
||||
else:
|
||||
self.worker_group.should_fail = fail
|
||||
|
||||
return start_with_fail
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_local", [False, True])
|
||||
@patch.object(RemoteWorkerGroup, "_train", remote_worker_train_with_fail)
|
||||
def test_resize(ray_start_2_cpus, use_local): # noqa: F811
|
||||
if not dist.is_available():
|
||||
return
|
||||
|
||||
def single_loader(config):
|
||||
dataset = LinearDataset(2, 5, size=1000000)
|
||||
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
|
||||
|
||||
start_with_fail = gen_start_with_fail(1)
|
||||
|
||||
TestOperator = TrainingOperator.from_creators(
|
||||
model_creator,
|
||||
optimizer_creator,
|
||||
single_loader,
|
||||
loss_creator=lambda config: nn.MSELoss())
|
||||
with patch.object(TorchTrainer, "_start_workers", start_with_fail):
|
||||
trainer1 = TorchTrainer(
|
||||
training_operator_cls=TestOperator,
|
||||
config={"batch_size": 100000},
|
||||
use_local=use_local,
|
||||
num_workers=2)
|
||||
|
||||
@ray.remote
|
||||
def try_test():
|
||||
import time
|
||||
time.sleep(100)
|
||||
|
||||
try_test.remote()
|
||||
trainer1.train(max_retries=1)
|
||||
assert trainer1.worker_group.num_workers == 1
|
||||
|
||||
trainer1.shutdown(force=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_local", [False, True])
|
||||
@patch.object(RemoteWorkerGroup, "_train", remote_worker_train_with_fail)
|
||||
def test_fail_twice(ray_start_2_cpus, use_local): # noqa: F811
|
||||
if not dist.is_available():
|
||||
return
|
||||
|
||||
def single_loader(config):
|
||||
dataset = LinearDataset(2, 5, size=1000000)
|
||||
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
|
||||
|
||||
TestOperator = TrainingOperator.from_creators(
|
||||
model_creator,
|
||||
optimizer_creator,
|
||||
single_loader,
|
||||
loss_creator=lambda config: nn.MSELoss())
|
||||
|
||||
start_with_fail = gen_start_with_fail(2)
|
||||
|
||||
with patch.object(TorchTrainer, "_start_workers", start_with_fail):
|
||||
trainer1 = TorchTrainer(
|
||||
training_operator_cls=TestOperator,
|
||||
config={"batch_size": 100000},
|
||||
use_local=use_local,
|
||||
num_workers=2)
|
||||
|
||||
# MAX RETRIES SHOULD BE ON BY DEFAULT
|
||||
trainer1.train()
|
||||
trainer1.shutdown(force=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_local", [False, True])
|
||||
@patch.object(RemoteWorkerGroup, "_train", remote_worker_train_with_fail)
|
||||
def test_fail_with_recover(ray_start_2_cpus, use_local): # noqa: F811
|
||||
print(locals())
|
||||
if not dist.is_available():
|
||||
return
|
||||
|
||||
def single_loader(config):
|
||||
dataset = LinearDataset(2, 5, size=1000000)
|
||||
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
|
||||
|
||||
TestOperator = TrainingOperator.from_creators(
|
||||
model_creator,
|
||||
optimizer_creator,
|
||||
single_loader,
|
||||
loss_creator=lambda config: nn.MSELoss())
|
||||
|
||||
start_with_fail = gen_start_with_fail(3)
|
||||
|
||||
with patch.object(TorchTrainer, "_start_workers", start_with_fail):
|
||||
trainer1 = TorchTrainer(
|
||||
training_operator_cls=TestOperator,
|
||||
config={"batch_size": 100000},
|
||||
timeout_s=5,
|
||||
use_local=use_local,
|
||||
num_workers=2)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
trainer1.train(max_retries=1)
|
||||
|
||||
trainer1.shutdown(force=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
sys.exit(pytest.main(["-v", "-x", __file__]))
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import io
|
||||
import os
|
||||
|
||||
import torch
|
||||
@@ -11,6 +10,8 @@ from ray.util.sgd.torch.utils import setup_process_group
|
||||
import ray
|
||||
from ray.util.sgd.torch.torch_runner import TorchRunner
|
||||
|
||||
from ray.util.sgd.torch.utils import setup_address
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -42,6 +43,9 @@ class DistributedTorchRunner(TorchRunner):
|
||||
self.add_dist_sampler = add_dist_sampler
|
||||
self.world_rank = None
|
||||
|
||||
def setup_address(self):
|
||||
return setup_address()
|
||||
|
||||
def setup_process_group(self, url, world_rank, world_size, timeout):
|
||||
"""Connects the distributed PyTorch backend.
|
||||
|
||||
@@ -52,6 +56,7 @@ class DistributedTorchRunner(TorchRunner):
|
||||
timeout (timedelta): Seconds for process group
|
||||
operations to timeout.
|
||||
"""
|
||||
logger.info(f"Setting up process group for: {url} [rank={world_rank}]")
|
||||
self.world_rank = world_rank
|
||||
setup_process_group(
|
||||
url, world_rank, world_size, timeout, backend=self.backend)
|
||||
@@ -83,23 +88,6 @@ class DistributedTorchRunner(TorchRunner):
|
||||
"""Needed for SyncBatchNorm, which needs 1 GPU per process."""
|
||||
return [0]
|
||||
|
||||
def load_state_stream(self, byte_obj):
|
||||
"""Loads a bytes object the training state dict.
|
||||
|
||||
This is needed because we don't want to deserialize the tensor
|
||||
onto the same device (which is from the driver process). We want to
|
||||
map it onto the actor's specific device.
|
||||
|
||||
From: github.com/pytorch/pytorch/issues/10622#issuecomment-474733769
|
||||
"""
|
||||
_buffer = io.BytesIO(byte_obj)
|
||||
to_gpu = self.use_gpu and torch.cuda.is_available()
|
||||
state_dict = torch.load(
|
||||
_buffer,
|
||||
map_location=("cpu" if not to_gpu else
|
||||
lambda storage, loc: storage.cuda()))
|
||||
return self.load_state_dict(state_dict)
|
||||
|
||||
def _wrap_dataloaders(self):
|
||||
def with_sampler(loader):
|
||||
# Automatically set the DistributedSampler
|
||||
@@ -346,10 +334,3 @@ class LocalDistributedRunner(DistributedTorchRunner):
|
||||
def is_actor(self):
|
||||
actor_id = ray.worker.global_worker.actor_id
|
||||
return actor_id != actor_id.nil()
|
||||
|
||||
|
||||
class DeactivatedRunner:
|
||||
def __getattr__(self, *args, **kwargs):
|
||||
raise RuntimeError(
|
||||
"This TorchTrainer is not active (it is likely shutdown already). "
|
||||
"Create a new TorchTrainer.")
|
||||
|
||||
@@ -138,7 +138,7 @@ if __name__ == "__main__":
|
||||
use_gpu=args.use_gpu,
|
||||
scheduler_step_freq="epoch",
|
||||
use_fp16=args.fp16,
|
||||
use_tqdm=True)
|
||||
use_tqdm=False)
|
||||
pbar = trange(args.num_epochs, unit="epoch")
|
||||
for i in pbar:
|
||||
info = {"num_steps": 1} if args.smoke_test else {}
|
||||
|
||||
@@ -144,7 +144,14 @@ def scheduler_creator(optimizer, config):
|
||||
|
||||
# __torch_scheduler_end__
|
||||
|
||||
# __backwards_compat__start
|
||||
# __torch_ray_start__
|
||||
import ray
|
||||
|
||||
ray.init()
|
||||
# or ray.init(address="auto") to connect to a running cluster.
|
||||
# __torch_ray_end__
|
||||
|
||||
# __backwards_compat_start__
|
||||
from ray.util.sgd import TorchTrainer
|
||||
|
||||
MyTrainingOperator = TrainingOperator.from_creators(
|
||||
@@ -157,14 +164,9 @@ trainer = TorchTrainer(
|
||||
scheduler_step_freq="epoch", # if scheduler_creator is passed in
|
||||
config={"lr": 0.001, "batch_size": 64})
|
||||
|
||||
# __backwards_compat_end
|
||||
# __backwards_compat_end__
|
||||
|
||||
# __torch_ray_start__
|
||||
import ray
|
||||
|
||||
ray.init()
|
||||
# or ray.init(address="auto") to connect to a running cluster.
|
||||
# __torch_ray_end__
|
||||
trainer.shutdown()
|
||||
|
||||
# __torch_trainer_start__
|
||||
from ray.util.sgd import TorchTrainer
|
||||
@@ -175,3 +177,5 @@ trainer = TorchTrainer(
|
||||
config={"lr": 0.001, "batch_size": 64})
|
||||
|
||||
# __torch_trainer_end__
|
||||
|
||||
trainer.shutdown()
|
||||
|
||||
@@ -52,7 +52,7 @@ def tune_example(operator_cls, num_workers=1, use_gpu=False):
|
||||
stop={"training_iteration": 2},
|
||||
verbose=1)
|
||||
|
||||
return analysis.get_best_config(metric="validation_loss", mode="min")
|
||||
return analysis.get_best_config(metric="val_loss", mode="min")
|
||||
# __end_torch_tune_example__
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import io
|
||||
import itertools
|
||||
import torch
|
||||
|
||||
import ray
|
||||
from ray.util.sgd.torch.constants import USE_FP16, NUM_STEPS
|
||||
from ray.util.sgd import utils
|
||||
|
||||
@@ -57,14 +56,6 @@ class TorchRunner:
|
||||
apex_args=self.apex_args,
|
||||
scheduler_step_freq=self.scheduler_step_freq)
|
||||
|
||||
def get_node_ip(self):
|
||||
"""Returns the IP address of the current node."""
|
||||
return ray.services.get_node_ip_address()
|
||||
|
||||
def find_free_port(self):
|
||||
"""Finds a free port on the current node."""
|
||||
return utils.find_free_port()
|
||||
|
||||
def train_epoch(self,
|
||||
num_steps=None,
|
||||
profile=False,
|
||||
@@ -132,11 +123,15 @@ class TorchRunner:
|
||||
|
||||
def state_dict(self):
|
||||
"""Returns the state of the runner."""
|
||||
model_states = [model.state_dict() for model in self.models]
|
||||
optimizer_states = [
|
||||
optimizer.state_dict() for optimizer in self.optimizers
|
||||
]
|
||||
state = {
|
||||
"epoch": self.epochs,
|
||||
"operator": self.training_operator.state_dict(),
|
||||
"models": [model.state_dict() for model in self.models],
|
||||
"optimizers": [opt.state_dict() for opt in self.optimizers]
|
||||
"models": model_states,
|
||||
"optimizers": optimizer_states
|
||||
}
|
||||
schedulers = self.schedulers
|
||||
if schedulers:
|
||||
@@ -148,6 +143,7 @@ class TorchRunner:
|
||||
# Check if fp16 is True and if NVIDIA Apex is imported.
|
||||
if self.use_fp16 and self.training_operator._amp:
|
||||
state.update({"amp": self.training_operator._amp.state_dict()})
|
||||
|
||||
return state
|
||||
|
||||
def load_state_dict(self, state):
|
||||
@@ -176,9 +172,20 @@ class TorchRunner:
|
||||
return _buffer.getvalue()
|
||||
|
||||
def load_state_stream(self, byte_obj):
|
||||
"""Loads a bytes object the training state dict."""
|
||||
"""Loads a bytes object the training state dict.
|
||||
|
||||
This is needed because we don't want to deserialize the tensor
|
||||
onto the same device (which is from the driver process). We want to
|
||||
map it onto the actor's specific device.
|
||||
|
||||
From: github.com/pytorch/pytorch/issues/10622#issuecomment-474733769
|
||||
"""
|
||||
_buffer = io.BytesIO(byte_obj)
|
||||
state_dict = torch.load(_buffer)
|
||||
to_gpu = self.use_gpu and torch.cuda.is_available()
|
||||
state_dict = torch.load(
|
||||
_buffer,
|
||||
map_location=("cpu" if not to_gpu else
|
||||
lambda storage, loc: storage.cuda()))
|
||||
return self.load_state_dict(state_dict)
|
||||
|
||||
def apply(self, fn):
|
||||
@@ -193,6 +200,10 @@ class TorchRunner:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_models(self):
|
||||
"""Getter method. Needed for remote actor calls."""
|
||||
return self.models
|
||||
|
||||
@property
|
||||
def models(self):
|
||||
if not hasattr(self.training_operator, "_original_models"):
|
||||
|
||||
@@ -1,29 +1,25 @@
|
||||
from datetime import timedelta
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import logging
|
||||
import os
|
||||
import numbers
|
||||
import tempfile
|
||||
import time
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import ray
|
||||
from ray.exceptions import RayActorError
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.utils.util import merge_dicts
|
||||
from ray.util import log_once
|
||||
from ray.util.sgd.torch.distributed_torch_runner import (
|
||||
DistributedTorchRunner, LocalDistributedRunner, DeactivatedRunner)
|
||||
from ray.util.sgd.utils import check_for_failure, NUM_SAMPLES, BATCH_SIZE
|
||||
from ray.util.sgd.torch.torch_runner import TorchRunner
|
||||
from ray.util.sgd.torch.worker_group import LocalWorkerGroup, \
|
||||
RemoteWorkerGroup, DeactivatedWorkerGroup
|
||||
from ray.util.sgd.utils import NUM_SAMPLES, BATCH_SIZE
|
||||
from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP, NCCL_TIMEOUT_S
|
||||
from ray.util.sgd.torch.utils import setup_address
|
||||
from ray.util.sgd.data import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
RESIZE_COOLDOWN_S = 10
|
||||
|
||||
|
||||
def _validate_scheduler_step_freq(scheduler_step_freq):
|
||||
@@ -86,11 +82,17 @@ class TorchTrainer:
|
||||
that subclasses the TrainingOperator class. This class
|
||||
will be copied onto all remote workers and used to specify
|
||||
training components and custom training and validation operations.
|
||||
initialization_hook (function): A function to call on all training
|
||||
workers when they are first initialized. This could be useful to
|
||||
set environment variables for all the worker processes.
|
||||
config (dict): Custom configuration value to be passed to
|
||||
all operator constructors.
|
||||
num_workers (int): the number of workers used in distributed
|
||||
training. If 1, the worker will not be wrapped with
|
||||
DistributedDataParallel.
|
||||
DistributedDataParallel. TorchTrainer will scale down the number
|
||||
of workers if enough resources are not available, and will scale
|
||||
back up once they are. The total number of
|
||||
workers will never exceed `num_workers` amount.
|
||||
num_cpus_per_worker (int): Sets the cpu requirement for each worker.
|
||||
use_gpu (bool): Sets resource allocation for workers to 1 GPU
|
||||
if true, and automatically moves both the model and optimizer
|
||||
@@ -119,9 +121,13 @@ class TorchTrainer:
|
||||
``step`` will be called after every optimizer step. If "epoch",
|
||||
``step`` will be called after one pass of the DataLoader. If
|
||||
"manual", the scheduler will not be incremented automatically -
|
||||
you are expected to call ``trainer.update_schedulers`` manually.
|
||||
you are expected to call ``trainer.update_scheduler`` manually.
|
||||
If a scheduler is passed in, this value is expected to not be None.
|
||||
|
||||
use_local (bool): If True, 1 worker will be a local worker running
|
||||
on the driver process, and all other workers will be remote. If
|
||||
False, all workers will be remote. Set this to True for easy
|
||||
debugging of worker on driver process, but could also
|
||||
lead to issues with Cuda devices. Defaults to False.
|
||||
"""
|
||||
|
||||
# TODO: Implement autoscaling. If num_workers=-1, the trainer will use as
|
||||
@@ -146,6 +152,7 @@ class TorchTrainer:
|
||||
apex_args=None,
|
||||
add_dist_sampler=True,
|
||||
scheduler_step_freq=None,
|
||||
use_local=False,
|
||||
# Deprecated Args.
|
||||
num_replicas=None,
|
||||
batch_size=None,
|
||||
@@ -168,6 +175,13 @@ class TorchTrainer:
|
||||
"model_creator, ...) and pass in CustomOperator into "
|
||||
"TorchTrainer.")
|
||||
|
||||
if use_local and log_once("use_local"):
|
||||
logger.warning("use_local is set to True. This could lead to "
|
||||
"issues with Cuda devices. If you are seeing this "
|
||||
"issue, try setting use_local to False. For more "
|
||||
"information, see "
|
||||
"https://github.com/ray-project/ray/issues/9202.")
|
||||
|
||||
if num_workers > 1 and not dist.is_available():
|
||||
raise ValueError(
|
||||
("Distributed PyTorch is not supported on macOS. "
|
||||
@@ -225,6 +239,7 @@ class TorchTrainer:
|
||||
self.use_fp16 = use_fp16
|
||||
self.use_tqdm = use_tqdm
|
||||
self.add_dist_sampler = add_dist_sampler
|
||||
self.use_local = use_local
|
||||
|
||||
if apex_args and not isinstance(apex_args, dict):
|
||||
raise ValueError("apex_args needs to be a dict object.")
|
||||
@@ -234,9 +249,6 @@ class TorchTrainer:
|
||||
self._num_failures = 0
|
||||
self._last_resize = float("-inf")
|
||||
|
||||
self.local_worker = DeactivatedRunner()
|
||||
self.remote_workers = []
|
||||
|
||||
if scheduler_step_freq:
|
||||
_validate_scheduler_step_freq(scheduler_step_freq)
|
||||
|
||||
@@ -270,12 +282,10 @@ class TorchTrainer:
|
||||
return batch_size_per_worker
|
||||
|
||||
def _start_workers(self, num_workers):
|
||||
logger.debug(f"start_workers: Setting %d workers." % num_workers)
|
||||
worker_config = self.config.copy()
|
||||
batch_size_per_worker = self._configure_and_split_batch(num_workers)
|
||||
if batch_size_per_worker:
|
||||
worker_config[BATCH_SIZE] = batch_size_per_worker
|
||||
|
||||
params = dict(
|
||||
training_operator_cls=self.training_operator_cls,
|
||||
config=worker_config,
|
||||
@@ -286,57 +296,68 @@ class TorchTrainer:
|
||||
apex_args=self.apex_args,
|
||||
scheduler_step_freq=self.scheduler_step_freq)
|
||||
|
||||
if num_workers == 1:
|
||||
# Start local worker
|
||||
self.local_worker = TorchRunner(**params)
|
||||
if self.initialization_hook:
|
||||
self.apply_all_workers(self.initialization_hook)
|
||||
self.local_worker.setup_operator()
|
||||
dist_params = dict(
|
||||
backend=self.backend,
|
||||
add_dist_sampler=self.add_dist_sampler,
|
||||
wrap_ddp=self.wrap_ddp)
|
||||
|
||||
worker_args = {
|
||||
"max_workers": num_workers,
|
||||
"params": params,
|
||||
"dist_params": dist_params,
|
||||
"initialization_hook": self.initialization_hook,
|
||||
"num_cpus_per_worker": self.num_cpus_per_worker,
|
||||
"use_gpu": self.use_gpu,
|
||||
"timeout_s": self.timeout_s
|
||||
}
|
||||
|
||||
if self.use_local:
|
||||
self.worker_group = LocalWorkerGroup(**worker_args)
|
||||
else:
|
||||
params.update(
|
||||
backend=self.backend,
|
||||
add_dist_sampler=self.add_dist_sampler,
|
||||
wrap_ddp=self.wrap_ddp)
|
||||
self.worker_group = RemoteWorkerGroup(**worker_args)
|
||||
|
||||
# Start local worker
|
||||
self.local_worker = LocalDistributedRunner(
|
||||
num_cpus=self.num_cpus_per_worker,
|
||||
num_gpus=int(self.use_gpu),
|
||||
**params)
|
||||
# TODO(amogkam): If not enough resources are available to create
|
||||
# num_workers workers, this command will hang. Instead,
|
||||
# start_workers should take into account available resources when
|
||||
# determining how many workers to create.
|
||||
self.worker_group.start_workers(num_workers)
|
||||
|
||||
# Generate actor class
|
||||
RemoteRunner = ray.remote(
|
||||
num_cpus=self.num_cpus_per_worker,
|
||||
num_gpus=int(self.use_gpu))(DistributedTorchRunner)
|
||||
# Start workers
|
||||
self.remote_workers = [
|
||||
RemoteRunner.remote(**params) for i in range(num_workers - 1)
|
||||
]
|
||||
if self.initialization_hook:
|
||||
self.apply_all_workers(self.initialization_hook)
|
||||
def _resize_worker_group(self, max_retries=10):
|
||||
"""Resizes the number of remote workers based on available resources.
|
||||
Total number of workers will never exceed `num_workers` amount.
|
||||
|
||||
# Compute URL for initializing distributed PyTorch
|
||||
address = setup_address()
|
||||
Args:
|
||||
max_retries (int): How many times to attempt to resize workers
|
||||
before failing.
|
||||
"""
|
||||
state_dict = self.state_dict()
|
||||
old_workers = self.worker_group.num_workers
|
||||
self.worker_group.reset()
|
||||
|
||||
# Setup the process group among all workers.
|
||||
remote_pgroup_setups = [
|
||||
worker.setup_process_group.remote(address, i + 1, num_workers,
|
||||
timedelta(self.timeout_s))
|
||||
for i, worker in enumerate(self.remote_workers)
|
||||
]
|
||||
self.local_worker.setup_process_group(address, 0, num_workers,
|
||||
timedelta(self.timeout_s))
|
||||
# Get setup tasks in order to throw errors on failure
|
||||
ray.get(remote_pgroup_setups)
|
||||
|
||||
# Runs code that requires all creator functions to have run.
|
||||
remote_operator_setups = [
|
||||
worker.setup_operator.remote()
|
||||
for worker in self.remote_workers
|
||||
]
|
||||
self.local_worker.setup_operator()
|
||||
# Get setup tasks in order to throw errors on failure
|
||||
ray.get(remote_operator_setups)
|
||||
time.sleep(1)
|
||||
for i in range(max_retries):
|
||||
new_workers = self.worker_group.new_workers_size()
|
||||
if new_workers:
|
||||
self._last_resize = time.time()
|
||||
self._start_workers(int(new_workers))
|
||||
self.load_state_dict(state_dict, blocking=True)
|
||||
if self.use_local and new_workers == 1 and old_workers > 1:
|
||||
# Major hack. If we go from LocalDistributedRunner to a
|
||||
# standard TorchRunner we have to manually reset the
|
||||
# dummy actor handle global vars.
|
||||
# TODO(amog): Refactor LocalDistributedTorchRunner to
|
||||
# not use global variables for resource reservation.
|
||||
ray.util.sgd.torch.distributed_torch_runner\
|
||||
._dummy_cuda_actor = None
|
||||
ray.util.sgd.torch.distributed_torch_runner\
|
||||
._dummy_cpu_actor = None
|
||||
return
|
||||
else:
|
||||
delay = 2**i
|
||||
logger.warning(
|
||||
"No new workers found. Retrying in %d sec." % delay)
|
||||
time.sleep(delay)
|
||||
raise RuntimeError("Exceeded max_retries for relaunching workers.")
|
||||
|
||||
def train(self,
|
||||
num_steps=None,
|
||||
@@ -384,10 +405,10 @@ class TorchTrainer:
|
||||
assert isinstance(dataset, Dataset) is not None \
|
||||
or self.data_creator, \
|
||||
"Must specify either a data creator or a dataset"
|
||||
if self._should_resize():
|
||||
if self.worker_group.should_scale_up():
|
||||
logger.info("Resize opportunity detected. Attempting to scale up.")
|
||||
self._resize_workers()
|
||||
success, worker_stats = self._train_epoch(
|
||||
self._resize_worker_group()
|
||||
success, worker_stats = self.worker_group.train(
|
||||
num_steps=num_steps, profile=profile, info=info, dataset=dataset)
|
||||
# Fault handling
|
||||
for i in range(max_retries):
|
||||
@@ -395,10 +416,10 @@ class TorchTrainer:
|
||||
break
|
||||
else:
|
||||
self._num_failures += 1
|
||||
self._resize_workers()
|
||||
self._resize_worker_group()
|
||||
logger.info("Retrying training step with %d workers." %
|
||||
(len(self.remote_workers) + 1))
|
||||
success, worker_stats = self._train_epoch(
|
||||
self.worker_group.num_workers)
|
||||
success, worker_stats = self.worker_group.train(
|
||||
num_steps=num_steps,
|
||||
profile=profile,
|
||||
info=info,
|
||||
@@ -425,43 +446,6 @@ class TorchTrainer:
|
||||
stats[stat_key] = worker_stats[0][stat_key]
|
||||
return stats
|
||||
|
||||
def _train_epoch(self,
|
||||
num_steps=None,
|
||||
profile=False,
|
||||
info=None,
|
||||
dataset=None):
|
||||
params = dict(num_steps=num_steps, profile=profile, info=info)
|
||||
remote_worker_stats = []
|
||||
if dataset:
|
||||
dataset.set_num_shards(self.max_replicas)
|
||||
for i, w in enumerate(self.remote_workers):
|
||||
params = dict(num_steps=num_steps, profile=profile, info=info)
|
||||
if dataset:
|
||||
params["iterator"] = dataset.get_shard(i)
|
||||
stats = w.train_epoch.remote(**params)
|
||||
remote_worker_stats.append(stats)
|
||||
|
||||
try:
|
||||
if dataset:
|
||||
params["iterator"] = dataset.get_shard(
|
||||
len(self.remote_workers))
|
||||
local_worker_stats = self.local_worker.train_epoch(**params)
|
||||
except RuntimeError as err:
|
||||
if "gloo" in err.args[0] and "Timed out" in err.args[0]:
|
||||
logger.warning(err)
|
||||
return False, None
|
||||
if "NCCL" in err.args[0]: # there is no specific error message
|
||||
logger.warning(err)
|
||||
return False, None
|
||||
|
||||
raise err
|
||||
|
||||
success = check_for_failure(remote_worker_stats)
|
||||
if success:
|
||||
return success, [local_worker_stats] + ray.get(remote_worker_stats)
|
||||
|
||||
return success, None
|
||||
|
||||
def apply_all_workers(self, fn):
|
||||
"""Run a function on all operators on the workers.
|
||||
|
||||
@@ -472,9 +456,7 @@ class TorchTrainer:
|
||||
A list of objects returned by ``fn`` on each worker.
|
||||
|
||||
"""
|
||||
remote_calls = [w.apply.remote(fn) for w in self.remote_workers]
|
||||
local_call = self.local_worker.apply(fn)
|
||||
return [local_call] + ray.get(remote_calls)
|
||||
return self.worker_group.apply_all_workers(fn)
|
||||
|
||||
def apply_all_operators(self, fn):
|
||||
"""Run a function on all operators on the workers.
|
||||
@@ -487,11 +469,7 @@ class TorchTrainer:
|
||||
A list of objects returned by ``fn`` on each operator.
|
||||
|
||||
"""
|
||||
remote_calls = [
|
||||
w.apply_operator.remote(fn) for w in self.remote_workers
|
||||
]
|
||||
local_call = self.local_worker.apply_operator(fn)
|
||||
return [local_call] + ray.get(remote_calls)
|
||||
return self.worker_group.apply_all_operators(fn)
|
||||
|
||||
def validate(self,
|
||||
num_steps=None,
|
||||
@@ -517,13 +495,8 @@ class TorchTrainer:
|
||||
You can provide custom metrics by passing in a custom
|
||||
``training_operator_cls``.
|
||||
"""
|
||||
params = dict(num_steps=num_steps, profile=profile, info=info)
|
||||
|
||||
remote_worker_stats = [
|
||||
w.validate.remote(**params) for w in self.remote_workers
|
||||
]
|
||||
local_worker_stats = self.local_worker.validate(**params)
|
||||
worker_stats = [local_worker_stats] + ray.get(remote_worker_stats)
|
||||
worker_stats = self.worker_group.validate(
|
||||
num_steps=num_steps, profile=profile, info=info)
|
||||
|
||||
if reduce_results:
|
||||
return self._process_stats(worker_stats)
|
||||
@@ -535,13 +508,14 @@ class TorchTrainer:
|
||||
|
||||
This is useful for lr_schedulers such as ``ReduceLROnPlateau``.
|
||||
"""
|
||||
self.apply_all_operators(
|
||||
self.worker_group.apply_all_operators(
|
||||
lambda op: [sched.step(metric) for sched in op._schedulers])
|
||||
|
||||
def get_model(self):
|
||||
"""Returns the learned model(s)."""
|
||||
unwrapped = []
|
||||
for model in self.local_worker.models:
|
||||
models = self.worker_group.get_model()
|
||||
for model in models:
|
||||
unwrapped += [model.module if hasattr(model, "module") else model]
|
||||
if len(unwrapped) == 1:
|
||||
return unwrapped[0]
|
||||
@@ -556,23 +530,13 @@ class TorchTrainer:
|
||||
Returns:
|
||||
TrainingOperator: The local TrainingOperator object.
|
||||
"""
|
||||
return self.local_worker.training_operator
|
||||
return self.worker_group.get_local_operator()
|
||||
|
||||
def state_dict(self):
|
||||
return self.local_worker.state_dict()
|
||||
return self.worker_group.state_dict()
|
||||
|
||||
def load_state_dict(self, state_dict, blocking=False):
|
||||
# This is not the most efficient because you have to wait for
|
||||
# the local worker to save then dump to buffer.
|
||||
self.local_worker.load_state_dict(state_dict)
|
||||
state_id = ray.put(self.local_worker.state_stream())
|
||||
|
||||
remote_calls = [
|
||||
worker.load_state_stream.remote(state_id)
|
||||
for worker in self.remote_workers
|
||||
]
|
||||
if blocking:
|
||||
ray.get(remote_calls)
|
||||
self.worker_group.load_state_dict(state_dict, blocking=blocking)
|
||||
|
||||
def save(self, checkpoint):
|
||||
"""Saves the Trainer state to the provided checkpoint path.
|
||||
@@ -596,81 +560,16 @@ class TorchTrainer:
|
||||
raise DeprecationWarning("Use `TorchTrainer.load()` instead.")
|
||||
|
||||
def shutdown(self, force=False):
|
||||
"""Shuts down workers and releases resources."""
|
||||
if not force:
|
||||
cleanup = [
|
||||
worker.shutdown.remote() for worker in self.remote_workers
|
||||
]
|
||||
self.local_worker.shutdown()
|
||||
try:
|
||||
ray.get(cleanup)
|
||||
[
|
||||
worker.__ray_terminate__.remote()
|
||||
for worker in self.remote_workers
|
||||
]
|
||||
except RayActorError:
|
||||
logger.warning(
|
||||
"Failed to shutdown gracefully, forcing a shutdown.")
|
||||
"""Shuts down workers and releases resources.
|
||||
|
||||
for worker in self.remote_workers:
|
||||
logger.warning(f"Killing worker {worker}.")
|
||||
ray.kill(worker)
|
||||
else:
|
||||
self.local_worker.shutdown()
|
||||
for worker in self.remote_workers:
|
||||
logger.debug(f"Killing worker {worker}.")
|
||||
ray.kill(worker)
|
||||
Args:
|
||||
force (bool): If True, forcefully kill all workers. If False,
|
||||
attempt a graceful shutdown first, and then forcefully kill if
|
||||
unsuccessful.
|
||||
|
||||
self.local_worker = DeactivatedRunner()
|
||||
self.remote_workers = []
|
||||
|
||||
def _reset(self):
|
||||
"""Terminates models without giving up local resource reservation."""
|
||||
self.local_worker.shutdown(cleanup=False)
|
||||
for worker in self.remote_workers:
|
||||
logger.debug(f"Killing worker {worker}.")
|
||||
ray.kill(worker)
|
||||
self.local_worker = DeactivatedRunner()
|
||||
self.remote_workers = []
|
||||
|
||||
def _check_potential_remote_workers_size(self):
|
||||
# ASSUME 1 GPU + 1 CPU is already reserved for the local worker
|
||||
remote_resources = ray.available_resources()
|
||||
max_remote_workers = self.max_replicas - 1
|
||||
new_remote_workers = min(
|
||||
remote_resources.get("CPU", 0), max_remote_workers)
|
||||
if self.use_gpu:
|
||||
new_remote_workers = min(
|
||||
remote_resources.get("GPU", 0), new_remote_workers)
|
||||
return new_remote_workers
|
||||
|
||||
def _resize_workers(self, max_retries=10):
|
||||
self._reset()
|
||||
|
||||
time.sleep(1)
|
||||
for i in range(max_retries):
|
||||
new_remote_workers = self._check_potential_remote_workers_size()
|
||||
if new_remote_workers:
|
||||
self._last_resize = time.time()
|
||||
self._start_workers(int(new_remote_workers) + 1)
|
||||
self.load_state_dict(self.state_dict())
|
||||
return
|
||||
else:
|
||||
delay = 2**i
|
||||
logger.warning(
|
||||
"No new workers found. Retrying in %d sec." % delay)
|
||||
time.sleep(delay)
|
||||
raise RuntimeError("Exceeded max_retries for relaunching workers.")
|
||||
|
||||
def _should_resize(self):
|
||||
"""Returns True if past cooldown and exists resources to scale up."""
|
||||
worker_gap = self.max_replicas - 1 - len(self.remote_workers)
|
||||
past_cooldown = (time.time() - self._last_resize) > RESIZE_COOLDOWN_S
|
||||
if past_cooldown and worker_gap:
|
||||
# Assume 1 resource is already reserved for local worker.
|
||||
potential_remote_size = self._check_potential_remote_workers_size()
|
||||
return potential_remote_size > 0
|
||||
return False
|
||||
"""
|
||||
self.worker_group.shutdown(force=force)
|
||||
self.worker_group = DeactivatedWorkerGroup()
|
||||
|
||||
@classmethod
|
||||
def as_trainable(cls, *args, **kwargs):
|
||||
@@ -698,16 +597,26 @@ class TorchTrainer:
|
||||
def default_resource_request(cls, config):
|
||||
num_workers = config.get("num_workers",
|
||||
kwargs.get("num_workers", 1))
|
||||
num_cpus = config.get("num_cpus_per_worker",
|
||||
kwargs.get("num_cpus_per_worker", 1))
|
||||
num_cpus_per_worker = config.get(
|
||||
"num_cpus_per_worker", kwargs.get("num_cpus_per_worker",
|
||||
1))
|
||||
use_gpu = config.get("use_gpu", kwargs.get("use_gpu"))
|
||||
use_local = config.get("use_local",
|
||||
kwargs.get("use_local", False))
|
||||
|
||||
remote_worker_count = num_workers - 1
|
||||
if use_local:
|
||||
remote_worker_count = num_workers - 1
|
||||
local_cpus = 1
|
||||
local_gpus = int(use_gpu)
|
||||
else:
|
||||
remote_worker_count = num_workers
|
||||
local_cpus = 0
|
||||
local_gpus = 0
|
||||
|
||||
return Resources(
|
||||
cpu=num_cpus,
|
||||
gpu=int(use_gpu),
|
||||
extra_cpu=int(remote_worker_count),
|
||||
cpu=int(local_cpus * num_cpus_per_worker),
|
||||
gpu=int(local_gpus),
|
||||
extra_cpu=int(remote_worker_count * num_cpus_per_worker),
|
||||
extra_gpu=int(int(use_gpu) * remote_worker_count))
|
||||
|
||||
def _create_trainer(self, tune_config):
|
||||
|
||||
@@ -255,8 +255,8 @@ class TrainingOperator:
|
||||
else:
|
||||
self._criterion = None
|
||||
|
||||
logger.debug("Setting up Apex.")
|
||||
if self.use_fp16 and amp:
|
||||
logger.debug("Setting up Apex.")
|
||||
self._models, self._optimizers = amp.initialize(
|
||||
self._models, self._optimizers, **self._apex_args)
|
||||
self._amp = amp
|
||||
@@ -649,7 +649,9 @@ class TrainingOperator:
|
||||
"""Override this to return a representation of the operator state.
|
||||
Any argument passed into self.register and self.register_data will
|
||||
automatically be saved.
|
||||
Use this method to save any additional state.
|
||||
Use this method to save any additional state. If your TorchTrainer
|
||||
is on a CPU-only machine, make sure this method converts all state
|
||||
to be CPU-compatible.
|
||||
|
||||
Returns:
|
||||
dict: The state dict of the operator."""
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import os
|
||||
import logging
|
||||
import torch.distributed as dist
|
||||
|
||||
import ray
|
||||
import torch.distributed as dist
|
||||
from ray.util.sgd.utils import find_free_port
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -0,0 +1,574 @@
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
from datetime import timedelta
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from ray.exceptions import RayActorError
|
||||
from ray.util.sgd.torch.distributed_torch_runner import \
|
||||
LocalDistributedRunner, DistributedTorchRunner
|
||||
from ray.util.sgd.torch.torch_runner import TorchRunner
|
||||
from ray.util.sgd.torch.utils import setup_address
|
||||
from ray.util.sgd.utils import check_for_failure
|
||||
|
||||
RESIZE_COOLDOWN_S = 10
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkerGroupInterface:
    """Abstract contract for an object managing a set of TorchRunner workers.

    Every method here raises ``NotImplementedError``; concrete worker groups
    are expected to override them.
    """

    # ---- Lifecycle -------------------------------------------------------

    def start_workers(self, num_workers):
        """Launch the workers used for training.

        Implementations go through four stages:
            1. Create `num_workers` TorchRunner objects, either all of them
               as remote actors or one locally and the rest remotely.
            2. Apply the initialization hook to all workers, if necessary.
            3. Set up the process group for all workers when running in a
               distributed setting.
            4. Instantiate the user provided TrainingOperator on all
               workers so that training state exists.
        """
        raise NotImplementedError

    def reset(self):
        """Reset the worker group."""
        raise NotImplementedError

    def shutdown(self, force=False):
        """See TorchTrainer.shutdown."""
        raise NotImplementedError

    # ---- Scaling ---------------------------------------------------------

    def new_workers_size(self):
        """Return how many workers to create given the available resources.

        The total number of workers never exceeds the `max_workers` amount.
        """
        raise NotImplementedError

    def should_scale_up(self):
        """Return whether the number of remote workers should be increased.

        True is expected when the current number of workers is below the
        max_workers given at startup, enough resources are available to
        scale up, and a sufficient cooldown period has passed.
        """
        raise NotImplementedError

    # ---- Execution -------------------------------------------------------

    def train(self, num_steps=None, profile=False, info=None, dataset=None):
        """Run one epoch of training on all workers.

        Args:
            See TorchTrainer.train.

        Returns:
            Tuple of (bool, list). The first value is True if training was
                successful and False otherwise. The second value is the
                list of training results from all workers when training
                succeeded, and None otherwise.
        """
        raise NotImplementedError

    def validate(self, num_steps=None, profile=False, info=None):
        """Run validation on all workers.

        Args:
            See TorchTrainer.validate.

        Returns:
            List of validation results, one per worker.
        """
        raise NotImplementedError

    def apply_all_operators(self, fn):
        """See TorchTrainer.apply_all_operators."""
        raise NotImplementedError

    def apply_all_workers(self, fn):
        """See TorchTrainer.apply_all_workers."""
        raise NotImplementedError

    # ---- State -----------------------------------------------------------

    def state_dict(self):
        """See TorchTrainer.state_dict."""
        raise NotImplementedError

    def load_state_dict(self, state_dict, blocking=False):
        """See TorchTrainer.load_state_dict."""
        raise NotImplementedError

    def get_model(self):
        """See TorchTrainer.get_model."""
        raise NotImplementedError

    def get_local_operator(self):
        """See TorchTrainer.get_local_operator."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class RemoteWorkerGroup(WorkerGroupInterface):
|
||||
"""A group of TorchRunner workers that are all remote Ray actors.
|
||||
|
||||
Args:
|
||||
max_workers (int): Maximum number of workers to use.
|
||||
params (dict): Parameters to pass into a TorchRunner worker.
|
||||
dist_params (dict): Additional parameters for distributed training
|
||||
to pass into a DistributedTorchRunner worker.
|
||||
initialization_hook (Callable): See TorchTrainer.__init__.
|
||||
timeout_s (float): See TorchTrainer.__init__.
|
||||
num_cpus_per_worker (int): See TorchTrainer.__init__.
|
||||
use_gpu (bool): See TorchTrainer.__init__.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, max_workers, params, dist_params, initialization_hook,
             timeout_s, num_cpus_per_worker, use_gpu):
    """Record the group's configuration; no workers are created here.

    Args:
        max_workers (int): Maximum number of workers to use.
        params (dict): Parameters to pass into a TorchRunner worker.
        dist_params (dict): Additional parameters for a
            DistributedTorchRunner worker.
        initialization_hook (Callable): See TorchTrainer.__init__.
        timeout_s (float): See TorchTrainer.__init__.
        num_cpus_per_worker (int): See TorchTrainer.__init__.
        use_gpu (bool): See TorchTrainer.__init__.
    """
    # Mutable state: the live worker handles and the time of the most
    # recent resize of this group.
    self.remote_workers = []
    self._last_resize = float("-inf")

    # Invariant: everything below is fixed configuration and should never
    # change state after construction!
    self._max_workers = max_workers
    self._params = params
    self._dist_params = dist_params
    self._initialization_hook = initialization_hook
    self._timeout_s = timeout_s
    self._num_cpus_per_worker = num_cpus_per_worker
    self._use_gpu = use_gpu
|
||||
|
||||
def _init_dist_workers(self, num_workers):
    """Create `num_workers` remote DistributedTorchRunner actors.

    The new actor handles replace whatever `self.remote_workers` held.

    Args:
        num_workers (int): Number of remote actors to launch.
    """
    # Turn the runner class into an actor class that reserves this group's
    # per-worker CPU count and, when GPUs are enabled, one GPU.
    actor_cls = ray.remote(
        num_cpus=self._num_cpus_per_worker,
        num_gpus=int(self._use_gpu))(DistributedTorchRunner)

    launched = []
    for _ in range(num_workers):
        # Distributed settings are layered on top of the base parameters.
        runner_kwargs = {**self._params, **self._dist_params}
        launched.append(actor_cls.remote(**runner_kwargs))
    self.remote_workers = launched
|
||||
|
||||
def _setup_process_group(self, address, world_size, starting_rank=0):
    """Sets up process group for all workers.

    Args:
        address (str): Address to use for TCP process group setup. The
            provided address must use the IP address of the node that the
            rank 0 worker is located on.
        world_size (int): Total number of training workers in the
            process group. This may differ from self.num_workers if
            there are additional workers outside this worker group class.
        starting_rank (int): The rank to use for the first worker.
            Worker ranks will be in [starting_rank,
            len(self.remote_workers)+starting_rank). This is useful if
            you want to use worker outside of this group as the rank 0
            worker.

    Returns:
        List of process group set up promises.
    """
    # `_timeout_s` carries a seconds value (per its `_s` suffix), while
    # timedelta's first positional argument is *days*. Name the unit
    # explicitly so the process group does not get a timeout of
    # `timeout_s` days.
    timeout = timedelta(seconds=self._timeout_s)

    # Setup the process group among all workers.
    return [
        worker.setup_process_group.remote(
            url=address,
            world_rank=rank,
            world_size=world_size,
            timeout=timeout)
        for rank, worker in enumerate(
            self.remote_workers, start=starting_rank)
    ]
|
||||
|
||||
def _setup_operator(self):
    """Kick off TrainingOperator instantiation on every worker.

    The calls are only scheduled here; the caller resolves them.

    Returns:
        List of operator setup promises.
    """
    # Runs code that requires all creator functions to have run.
    pending = []
    for worker in self.remote_workers:
        pending.append(worker.setup_operator.remote())
    return pending
|
||||
|
||||
def start_workers(self, num_workers):
    """Create `num_workers` remote workers and prepare them for training.

    With a single worker, a plain TorchRunner actor is created and no
    process group is set up. With more workers, DistributedTorchRunner
    actors are created, a process group is set up across them, and then
    the operator is instantiated on each.

    In both cases the initialization hook (if one was given) is applied
    to the workers before their operator is set up.

    Args:
        num_workers (int): Number of remote workers to start.
    """
    logger.debug("start_workers: Setting %d workers.", num_workers)
    if num_workers == 1:
        RemoteRunner = ray.remote(
            num_cpus=self._num_cpus_per_worker,
            num_gpus=int(self._use_gpu))(TorchRunner)
        self.remote_workers = [RemoteRunner.remote(**self._params)]
        # The hook has to run on the lone worker too, and before the
        # operator is built, so whatever it configures is in place when
        # the creator functions execute.
        if self._initialization_hook:
            self.apply_all_workers(self._initialization_hook)
        ray.get(self.remote_workers[0].setup_operator.remote())
        return

    self._init_dist_workers(num_workers)

    if self._initialization_hook:
        self.apply_all_workers(self._initialization_hook)

    # Make sure to get the IP address of the rank 0 worker node.
    address = ray.get(self.remote_workers[0].setup_address.remote())

    # Block on each stage so that failures surface here.
    ray.get(
        self._setup_process_group(
            address=address, world_size=num_workers))

    ray.get(self._setup_operator())
|
||||
|
||||
def _apply_all_operators(self, fn):
    """Schedule `fn` on each worker's operator without waiting.

    Args:
        fn (Callable): Function handed to every worker's
            ``apply_operator``.

    Returns:
        List of promises, one per remote worker.
    """
    pending = []
    for worker in self.remote_workers:
        pending.append(worker.apply_operator.remote(fn))
    return pending
|
||||
|
||||
def apply_all_operators(self, fn):
    """Run `fn` on every worker's operator and wait for the results.

    Returns:
        List of objects returned by `fn`, one per remote worker.
    """
    pending = self._apply_all_operators(fn)
    return ray.get(pending)
|
||||
|
||||
def _apply_all_workers(self, fn):
    """Schedule `fn` on each worker without waiting.

    Args:
        fn (Callable): Function handed to every worker's ``apply``.

    Returns:
        List of promises, one per remote worker.
    """
    pending = []
    for worker in self.remote_workers:
        pending.append(worker.apply.remote(fn))
    return pending
|
||||
|
||||
def apply_all_workers(self, fn):
    """Run `fn` on every worker and wait for the results.

    Returns:
        List of objects returned by `fn`, one per remote worker.
    """
    pending = self._apply_all_workers(fn)
    return ray.get(pending)
|
||||
|
||||
def get_local_operator(self):
    """Always fails: a group made only of remote workers has no local operator.

    Raises:
        NotImplementedError: Unconditionally, pointing the caller at the
            ``use_local`` flag of TorchTrainer.
    """
    # Each fragment ends with a space so the implicitly concatenated
    # literals form properly separated sentences.
    raise NotImplementedError(
        "Cannot return a local operator if all "
        "workers are remote. Set use_local to True in "
        "TorchTrainer to access a local operator.")
|
||||
|
||||
def get_model(self):
    """Fetch the models from whichever remote worker answers first.

    ``get_models`` is requested from every worker, and the first ready
    result returned by ``ray.wait`` is the one that gets resolved.

    Returns:
        The value returned by that worker's ``get_models``.
    """
    requests = []
    for worker in self.remote_workers:
        requests.append(worker.get_models.remote())
    ready, _ = ray.wait(requests)
    return ray.get(ready[0])
|
||||
|
||||
def _load_state_id(self, state_id):
    """Ask every worker to load the object with id `state_id`.

    Args:
        state_id: Passed unchanged to each worker's
            ``load_state_stream``.

    Returns:
        List of promises, one per remote worker.
    """
    pending = []
    for worker in self.remote_workers:
        pending.append(worker.load_state_stream.remote(state_id))
    return pending
|
||||
|
||||
def load_state_dict(self, state_dict, blocking=False):
    """Send `state_dict` to all remote workers.

    The state is serialized once with ``torch.save`` into an in-memory
    buffer, the bytes are placed in the object store, and each worker is
    asked to load them from there.

    Args:
        state_dict: The state to distribute.
        blocking (bool): If True, wait until every worker has finished
            loading before returning.
    """
    serialized = io.BytesIO()
    torch.save(state_dict, serialized)
    state_id = ray.put(serialized.getvalue())

    pending = self._load_state_id(state_id)
    if blocking:
        ray.get(pending)
|
||||
|
||||
def state_dict(self):
    """Return the state dict obtained from the first worker that answers.

    ``state_stream`` is requested from every remote worker. Results are
    consumed one at a time; a result whose retrieval raises
    ``RayActorError`` is discarded and the next one is tried.

    Returns:
        The object produced by ``torch.load`` on the received bytes,
        mapped to CUDA when ``self._use_gpu`` is set and CUDA is
        available, otherwise to CPU.

    Raises:
        RuntimeError: If no worker delivered a (non-None) state stream.
    """
    # This is needed to handle calling ray.get on a dead actor.
    stream = None
    pending = {worker.state_stream.remote() for worker in self.remote_workers}
    while pending:
        done, _ = ray.wait(list(pending), num_returns=1)
        ref = done[0]
        try:
            stream = ray.get(ref)
        except RayActorError:
            pending.remove(ref)
            continue
        break

    if stream is None:
        raise RuntimeError("Obtaining state_dict from remote workers is "
                           "unsuccessful since all workers have died.")

    def _to_cuda(storage, loc):
        return storage.cuda()

    to_gpu = self._use_gpu and torch.cuda.is_available()
    map_location = _to_cuda if to_gpu else "cpu"
    return torch.load(io.BytesIO(stream), map_location=map_location)
|
||||
|
||||
def _train(self, num_steps, profile, info, dataset=None):
|
||||
"""Runs 1 epoch of training on all workers.
|
||||
Returns training result for all workers as promises.
|
||||
"""
|
||||
remote_worker_stats = []
|
||||
for i, w in enumerate(self.remote_workers):
|
||||
params = dict(num_steps=num_steps, profile=profile, info=info)
|
||||
if dataset:
|
||||
params["iterator"] = dataset.get_shard(i)
|
||||
stats = w.train_epoch.remote(**params)
|
||||
remote_worker_stats.append(stats)
|
||||
return remote_worker_stats
|
||||
|
||||
def train(self, num_steps=None, profile=False, info=None, dataset=None):
    """Runs 1 epoch of training on all workers.

    Has additional logic to check for worker failure.

    Returns:
        A pair ``(success, stats)``. ``success`` is the value returned
        by ``check_for_failure``; ``stats`` is the fetched per-worker
        results when ``success`` is truthy and ``None`` otherwise.
    """
    if dataset:
        dataset.set_num_shards(self.num_workers)
    pending = self._train(num_steps, profile, info, dataset)

    # Check if each worker has failed before calling ray.get.
    success = check_for_failure(pending)
    if not success:
        return success, None
    return success, ray.get(pending)
|
||||
|
||||
def _validate(self, params):
|
||||
"""Runs validation for each worker. Returns results as promises."""
|
||||
remote_worker_stats = [
|
||||
w.validate.remote(**params) for w in self.remote_workers
|
||||
]
|
||||
return remote_worker_stats
|
||||
|
||||
def validate(self, num_steps=None, profile=False, info=None):
    """Run validation on every remote worker and wait for the results.

    Args:
        num_steps: Forwarded to each worker's ``validate`` call.
        profile: Forwarded to each worker's ``validate`` call.
        info: Forwarded to each worker's ``validate`` call.

    Returns:
        The value of ``ray.get`` applied to the pending validation calls.
    """
    kwargs = {"num_steps": num_steps, "profile": profile, "info": info}
    return ray.get(self._validate(kwargs))
|
||||
|
||||
def _shutdown_remote_workers(self):
|
||||
"""Shuts down each worker and returns a list of cleanup promises."""
|
||||
cleanup = [worker.shutdown.remote() for worker in self.remote_workers]
|
||||
return cleanup
|
||||
|
||||
def _terminate_remote_workers(self, cleanup):
    """Blocks on worker shutdown and then terminates each worker actor.

    If graceful shutdown fails, forcefully kills all actors.

    Args:
        cleanup: Pending shutdown calls to wait on with ``ray.get``.
    """
    try:
        ray.get(cleanup)
        # Fire-and-forget: the termination calls are not waited on.
        for worker in self.remote_workers:
            worker.__ray_terminate__.remote()
    except RayActorError:
        logger.warning("Failed to shutdown gracefully, forcing a "
                       "shutdown.")
        self.reset()
|
||||
|
||||
def shutdown(self, force=False):
    """Shut down all remote workers and clear ``self.remote_workers``.

    Args:
        force: When falsy, ask every worker to shut down and hand the
            pending calls to ``self._terminate_remote_workers``. When
            truthy, go straight to ``self.reset()``.
    """
    if force:
        self.reset()
    else:
        pending = [actor.shutdown.remote() for actor in self.remote_workers]
        self._terminate_remote_workers(pending)
    self.remote_workers = []
|
||||
|
||||
def reset(self):
    """Kill every remote worker with ``ray.kill`` and clear the list.

    Each kill is preceded by a debug log line naming the worker.
    """
    for worker in self.remote_workers:
        logger.debug(f"Killing worker {worker}.")
        ray.kill(worker)

    self.remote_workers = []
|
||||
|
||||
def should_scale_up(self):
    """Report whether more workers could be added right now.

    Returns:
        False unless the resize cooldown (``RESIZE_COOLDOWN_S`` seconds
        since ``self._last_resize``) has elapsed and the group is not at
        ``self._max_workers``; otherwise whether
        ``self._check_potential_remote_workers_size()`` is positive.
    """
    worker_gap = self._max_workers - self.num_workers
    past_cooldown = (time.time() - self._last_resize) > RESIZE_COOLDOWN_S
    if not (past_cooldown and worker_gap):
        return False

    # Assume 1 resource is already reserved for local worker.
    # NOTE(review): `_check_potential_remote_workers_size` is not defined
    # in the visible part of this class -- confirm it exists (it may be
    # meant to be `new_workers_size`).
    return self._check_potential_remote_workers_size() > 0
|
||||
|
||||
def new_workers_size(self):
    """Returns number of workers to create based on available resources.

    The result is the smaller of the available "CPU" resource count and
    ``self._max_workers``, further capped by the available "GPU" count
    when ``self._use_gpu`` is set. A resource missing from
    ``ray.available_resources()`` counts as 0.
    """
    available = ray.available_resources()
    size = min(available.get("CPU", 0), self._max_workers)
    if self._use_gpu:
        size = min(available.get("GPU", 0), size)
    return size
|
||||
|
||||
@property
def num_workers(self):
    """Current number of workers being used for training.

    This is the length of ``self.remote_workers``, so it reflects the
    workers that currently exist rather than a configured target.
    """
    return len(self.remote_workers)
|
||||
|
||||
|
||||
class LocalWorkerGroup(WorkerGroupInterface):
    """A group of TorchRunner workers.

    1 worker runs locally, and all the other workers are remote actors.
    The remote workers are managed by an inner ``RemoteWorkerGroup`` that
    is created with ``max_workers - 1`` slots.

    Args:
        Same as RemoteWorkerGroup.
    """

    def __init__(self, max_workers, params, dist_params, initialization_hook,
                 timeout_s, num_cpus_per_worker, use_gpu):

        # Invariant: These variables should never change state!
        self._max_workers = max_workers
        self._params = params
        self._dist_params = dist_params
        self._initialization_hook = initialization_hook
        self._timeout_s = timeout_s
        self._num_cpus_per_worker = num_cpus_per_worker
        self._use_gpu = use_gpu

        self.local_worker = None
        self.remote_worker_group = self._make_remote_worker_group()

    def _make_remote_worker_group(self):
        """Build a fresh RemoteWorkerGroup for all non-local workers."""
        return RemoteWorkerGroup(
            # One of the `max_workers` slots is taken by the local worker.
            max_workers=self._max_workers - 1,
            params=self._params,
            dist_params=self._dist_params,
            initialization_hook=self._initialization_hook,
            timeout_s=self._timeout_s,
            num_cpus_per_worker=self._num_cpus_per_worker,
            use_gpu=self._use_gpu)

    def start_workers(self, num_workers):
        """Create the local worker and ``num_workers - 1`` remote workers.

        With a single worker a plain ``TorchRunner`` is used and no process
        group is set up. Otherwise the local worker becomes rank 0 of a
        process group of size ``num_workers`` and the remote workers take
        the ranks starting at 1.

        Args:
            num_workers: Total number of workers, including the local one.
        """
        # Lazy %-formatting; the original mixed an f-string prefix with
        # eager `%` interpolation.
        logger.debug("start_workers: Setting %d workers.", num_workers)

        if num_workers == 1:
            self.local_worker = TorchRunner(**self._params)
            if self._initialization_hook:
                self.apply_all_workers(self._initialization_hook)
            self.local_worker.setup_operator()
            return

        # Start local worker
        self.local_worker = LocalDistributedRunner(
            num_cpus=self._num_cpus_per_worker,
            num_gpus=int(self._use_gpu),
            **{
                **self._params,
                **self._dist_params
            })
        self.remote_worker_group._init_dist_workers(num_workers - 1)
        if self._initialization_hook:
            self.apply_all_workers(self._initialization_hook)

        # Compute URL for initializing distributed PyTorch.
        address = setup_address()

        remote_pgs = self.remote_worker_group._setup_process_group(
            address=address, world_size=num_workers, starting_rank=1)
        # Use the local worker as rank 0. This will help with debugging.
        self.local_worker.setup_process_group(
            url=address,
            world_rank=0,
            world_size=num_workers,
            # `_timeout_s` is in seconds; timedelta's first positional
            # argument is days, so the unit must be named explicitly.
            timeout=timedelta(seconds=self._timeout_s))
        ray.get(remote_pgs)

        remote_operators = self.remote_worker_group._setup_operator()
        self.local_worker.setup_operator()
        ray.get(remote_operators)

    def apply_all_operators(self, fn):
        """Apply ``fn`` to the operator of every worker.

        Returns:
            A list with the local worker's result first, followed by the
            fetched results of the remote workers.
        """
        remote_calls = self.remote_worker_group._apply_all_operators(fn)
        local_call = self.local_worker.apply_operator(fn)
        return [local_call] + ray.get(remote_calls)

    def apply_all_workers(self, fn):
        """Apply ``fn`` on every worker.

        Returns:
            A list with the local worker's result first, followed by the
            fetched results of the remote workers.
        """
        # Use the non-blocking variant: the blocking
        # `RemoteWorkerGroup.apply_all_workers` already calls `ray.get`,
        # so its return value must not be passed to `ray.get` again.
        remote_calls = self.remote_worker_group._apply_all_workers(fn)
        local_call = self.local_worker.apply(fn)
        return [local_call] + ray.get(remote_calls)

    def get_local_operator(self):
        """Return the local worker's ``training_operator``."""
        return self.local_worker.training_operator

    def get_model(self):
        """Return the local worker's ``models``."""
        return self.local_worker.models

    def load_state_dict(self, state_dict, blocking=False):
        """Load ``state_dict`` locally, then forward it to remote workers.

        Args:
            state_dict: State passed to the local worker's
                ``load_state_dict``.
            blocking: When truthy, wait for the remote loads to finish.
        """
        # This is not the most efficient because you have to wait for
        # the local worker to save then dump to buffer.
        self.local_worker.load_state_dict(state_dict)
        state_id = ray.put(self.local_worker.state_stream())

        remote_calls = self.remote_worker_group._load_state_id(state_id)
        if blocking:
            ray.get(remote_calls)

    def state_dict(self):
        """Return the local worker's state dict."""
        return self.local_worker.state_dict()

    def should_scale_up(self):
        """Delegate the scale-up decision to the remote worker group."""
        return self.remote_worker_group.should_scale_up()

    def reset(self):
        """Terminates models without giving up local resource reservation."""
        self.local_worker.shutdown(cleanup=False)
        self.remote_worker_group.reset()

        self.local_worker = None
        self.remote_worker_group = self._make_remote_worker_group()

    def new_workers_size(self):
        """Return the remote group's ``new_workers_size()`` plus one."""
        return self.remote_worker_group.new_workers_size() + 1

    @staticmethod
    def _is_worker_failure(err):
        """Return True if ``err`` matches a known peer-failure message."""
        # `str(err)` instead of `err.args[0]`: an exception with no args
        # or a non-string first arg is then re-raised by the caller
        # instead of being masked by an IndexError/TypeError.
        message = str(err)
        if "gloo" in message and "Timed out" in message:
            return True
        if "NCCL" in message:  # there is no specific error message
            return True
        return "Connection closed by peer" in message

    def train(self, num_steps=None, profile=False, info=None, dataset=None):
        """Run one training epoch on the remote and the local workers.

        The remote workers are started first; the local worker then
        trains in the calling process. When ``dataset`` is truthy the
        remote worker at position ``i`` gets shard ``i`` and the local
        worker gets the last shard.

        Returns:
            A pair ``(success, stats)``. ``stats`` is the local result
            followed by the remote results, or ``None`` when a worker
            failure was detected.

        Raises:
            RuntimeError: Re-raised from the local worker when the error
                is not one of the recognised peer-failure messages.
        """
        params = dict(num_steps=num_steps, profile=profile, info=info)
        if dataset:
            dataset.set_num_shards(self.num_workers)

        remote_worker_stats = self.remote_worker_group._train(
            num_steps, profile, info, dataset)
        try:
            if dataset:
                params["iterator"] = dataset.get_shard(self.num_workers - 1)
            local_worker_stats = self.local_worker.train_epoch(**params)
        except RuntimeError as err:
            if self._is_worker_failure(err):
                logger.warning(err)
                return False, None
            raise

        success = check_for_failure(remote_worker_stats)
        if success:
            return success, [local_worker_stats] + ray.get(remote_worker_stats)

        return success, None

    def validate(self, num_steps=None, profile=False, info=None):
        """Run validation on all workers.

        Returns:
            A list with the local worker's result first, followed by the
            fetched results of the remote workers.
        """
        params = dict(num_steps=num_steps, profile=profile, info=info)

        remote_worker_stats = self.remote_worker_group._validate(params)
        local_worker_stats = self.local_worker.validate(**params)
        worker_stats = [local_worker_stats] + ray.get(remote_worker_stats)
        return worker_stats

    def shutdown(self, force=False):
        """Shut down all workers and deactivate this group.

        Afterwards ``local_worker`` is ``None`` and the remote group is
        replaced by a ``DeactivatedWorkerGroup``.

        Args:
            force: When truthy, reset the remote workers instead of
                waiting for their graceful shutdown.
        """
        if not force:
            cleanup = self.remote_worker_group._shutdown_remote_workers()
            self.local_worker.shutdown()
            self.remote_worker_group._terminate_remote_workers(cleanup)
        else:
            self.local_worker.shutdown()
            self.remote_worker_group.reset()

        self.local_worker = None
        self.remote_worker_group = DeactivatedWorkerGroup()

    @property
    def num_workers(self):
        """Number of remote workers plus one for the local worker."""
        return self.remote_worker_group.num_workers + 1

    @property
    def remote_workers(self):
        """The remote worker group's ``remote_workers``."""
        return self.remote_worker_group.remote_workers
|
||||
|
||||
|
||||
class DeactivatedWorkerGroup:
    """Placeholder whose attribute lookups fail with ``RuntimeError``.

    Any attribute that normal lookup does not find raises a
    ``RuntimeError`` telling the user to create a new TorchTrainer.
    """

    def __getattr__(self, *args, **kwargs):
        message = (
            "This TorchTrainer is not active (it is likely shutdown already). "
            "Create a new TorchTrainer.")
        raise RuntimeError(message)
|
||||
Reference in New Issue
Block a user