mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 17:02:43 +08:00
[tune/raysgd] Tune API for TorchTrainer + Fix State Restoration (#7547)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
import numpy as np
|
||||
import os
|
||||
import pytest
|
||||
import time
|
||||
import torch
|
||||
@@ -10,7 +9,7 @@ import torch.distributed as dist
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.util.sgd.torch import TorchTrainer, TorchTrainable
|
||||
from ray.util.sgd.torch import TorchTrainer
|
||||
from ray.util.sgd.torch.training_operator import (_TestingOperator,
|
||||
_TestMetricsOperator)
|
||||
from ray.util.sgd.torch.constants import SCHEDULER_STEP
|
||||
@@ -27,6 +26,9 @@ def ray_start_2_cpus():
|
||||
yield address_info
|
||||
# The code after the yield will run as teardown code.
|
||||
ray.shutdown()
|
||||
# Ensure that tests don't ALL fail
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def test_single_step(ray_start_2_cpus): # noqa: F811
|
||||
@@ -44,6 +46,19 @@ def test_single_step(ray_start_2_cpus): # noqa: F811
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
def test_dead_trainer(ray_start_2_cpus):  # noqa: F811
    """Check that a trainer refuses to train once it has been shut down."""
    trainer_kwargs = dict(
        model_creator=model_creator,
        data_creator=data_creator,
        optimizer_creator=optimizer_creator,
        loss_creator=lambda config: nn.MSELoss(),
        num_workers=2,
    )
    trainer = TorchTrainer(**trainer_kwargs)

    # Run a single step first, then shut the trainer down.
    trainer.train(num_steps=1)
    trainer.shutdown()

    # Training after shutdown is expected to raise rather than proceed.
    with pytest.raises(RuntimeError):
        trainer.train()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
def test_train(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
trainer = TorchTrainer(
|
||||
@@ -53,12 +68,12 @@ def test_train(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_workers=num_workers)
|
||||
for i in range(3):
|
||||
train_loss1 = trainer.train()["mean_train_loss"]
|
||||
validation_loss1 = trainer.validate()["mean_val_loss"]
|
||||
train_loss1 = trainer.train()["train_loss"]
|
||||
validation_loss1 = trainer.validate()["val_loss"]
|
||||
|
||||
for i in range(3):
|
||||
train_loss2 = trainer.train()["mean_train_loss"]
|
||||
validation_loss2 = trainer.validate()["mean_val_loss"]
|
||||
train_loss2 = trainer.train()["train_loss"]
|
||||
validation_loss2 = trainer.validate()["val_loss"]
|
||||
|
||||
assert train_loss2 <= train_loss1, (train_loss2, train_loss1)
|
||||
assert validation_loss2 <= validation_loss1, (validation_loss2,
|
||||
@@ -118,9 +133,7 @@ def test_multi_model(ray_start_2_cpus, num_workers):
|
||||
training_operator_cls=_TestingOperator,
|
||||
num_workers=num_workers)
|
||||
trainer1.train()
|
||||
|
||||
filename = os.path.join(tempfile.mkdtemp(), "checkpoint")
|
||||
trainer1.save(filename)
|
||||
state = trainer1.state_dict()
|
||||
|
||||
models1 = trainer1.get_model()
|
||||
|
||||
@@ -134,9 +147,7 @@ def test_multi_model(ray_start_2_cpus, num_workers):
|
||||
config={"custom_func": train_epoch},
|
||||
training_operator_cls=_TestingOperator,
|
||||
num_workers=num_workers)
|
||||
trainer2.restore(filename)
|
||||
|
||||
os.remove(filename)
|
||||
trainer2.load_state_dict(state)
|
||||
|
||||
models2 = trainer2.get_model()
|
||||
|
||||
@@ -336,20 +347,20 @@ def test_metrics(ray_start_2_cpus, num_workers):
|
||||
|
||||
stats = trainer.train(num_steps=num_train_steps)
|
||||
# Test that we output mean and last of custom metrics in an epoch
|
||||
assert "mean_score" in stats
|
||||
assert "score" in stats
|
||||
assert stats["last_score"] == 0
|
||||
|
||||
assert stats[NUM_SAMPLES] == num_train_steps * batch_size
|
||||
expected_score = num_workers * (sum(train_scores) /
|
||||
(num_train_steps * batch_size))
|
||||
assert np.allclose(stats["mean_score"], expected_score)
|
||||
assert np.allclose(stats["score"], expected_score)
|
||||
|
||||
val_stats = trainer.validate()
|
||||
# Test that we output mean and last of custom metrics in validation
|
||||
assert val_stats["last_score"] == 0
|
||||
expected_score = (sum(val_scores) /
|
||||
(num_val_steps * batch_size)) * num_workers
|
||||
assert np.allclose(val_stats["mean_score"], expected_score)
|
||||
assert np.allclose(val_stats["score"], expected_score)
|
||||
assert val_stats[BATCH_COUNT] == np.ceil(num_val_steps / num_workers)
|
||||
assert val_stats[NUM_SAMPLES] == num_val_steps * batch_size
|
||||
assert val_stats[NUM_SAMPLES] == val_size
|
||||
@@ -384,14 +395,14 @@ def test_metrics_nan(ray_start_2_cpus, num_workers):
|
||||
training_operator_cls=_TestMetricsOperator)
|
||||
|
||||
stats = trainer.train(num_steps=num_train_steps)
|
||||
assert "mean_score" in stats
|
||||
assert "score" in stats
|
||||
assert stats["last_score"] == 0
|
||||
assert np.isnan(stats["mean_score"])
|
||||
assert np.isnan(stats["score"])
|
||||
|
||||
stats = trainer.validate()
|
||||
assert "mean_score" in stats
|
||||
assert "score" in stats
|
||||
assert stats["last_score"] == 0
|
||||
assert np.isnan(stats["mean_score"])
|
||||
assert np.isnan(stats["score"])
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
@@ -415,41 +426,41 @@ def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
def test_tune_train(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
|
||||
config = {
|
||||
"model_creator": model_creator,
|
||||
"data_creator": data_creator,
|
||||
"optimizer_creator": optimizer_creator,
|
||||
"loss_creator": lambda config: nn.MSELoss(),
|
||||
"num_workers": num_workers,
|
||||
"use_gpu": False,
|
||||
"backend": "gloo",
|
||||
"config": {
|
||||
"batch_size": 512,
|
||||
"lr": 0.001
|
||||
}
|
||||
}
|
||||
TorchTrainable = TorchTrainer.as_trainable(
|
||||
**{
|
||||
"model_creator": model_creator,
|
||||
"data_creator": data_creator,
|
||||
"optimizer_creator": optimizer_creator,
|
||||
"loss_creator": lambda config: nn.MSELoss(),
|
||||
"num_workers": num_workers,
|
||||
"use_gpu": False,
|
||||
"backend": "gloo",
|
||||
"config": {
|
||||
"batch_size": 512,
|
||||
"lr": 0.001
|
||||
}
|
||||
})
|
||||
|
||||
analysis = tune.run(
|
||||
TorchTrainable,
|
||||
num_samples=2,
|
||||
config=config,
|
||||
stop={"training_iteration": 2},
|
||||
verbose=1)
|
||||
|
||||
# checks loss decreasing for every trials
|
||||
for path, df in analysis.trial_dataframes.items():
|
||||
mean_train_loss1 = df.loc[0, "mean_train_loss"]
|
||||
mean_train_loss2 = df.loc[1, "mean_train_loss"]
|
||||
mean_val_loss1 = df.loc[0, "mean_val_loss"]
|
||||
mean_val_loss2 = df.loc[1, "mean_val_loss"]
|
||||
mean_train_loss1 = df.loc[0, "train_loss"]
|
||||
mean_train_loss2 = df.loc[1, "train_loss"]
|
||||
mean_val_loss1 = df.loc[0, "val_loss"]
|
||||
mean_val_loss2 = df.loc[1, "val_loss"]
|
||||
|
||||
assert mean_train_loss2 <= mean_train_loss1
|
||||
assert mean_val_loss2 <= mean_val_loss1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
def test_save_and_restore(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
def test_save_and_restore(ray_start_2_cpus, num_workers,
|
||||
tmp_path): # noqa: F811
|
||||
trainer1 = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
@@ -457,9 +468,8 @@ def test_save_and_restore(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_workers=num_workers)
|
||||
trainer1.train()
|
||||
|
||||
filename = os.path.join(tempfile.mkdtemp(), "checkpoint")
|
||||
trainer1.save(filename)
|
||||
checkpoint_path = os.path.join(tmp_path, "checkpoint")
|
||||
trainer1.save(checkpoint_path)
|
||||
|
||||
model1 = trainer1.get_model()
|
||||
|
||||
@@ -471,9 +481,7 @@ def test_save_and_restore(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_workers=num_workers)
|
||||
trainer2.restore(filename)
|
||||
|
||||
os.remove(filename)
|
||||
trainer2.load(checkpoint_path)
|
||||
|
||||
model2 = trainer2.get_model()
|
||||
|
||||
@@ -619,7 +627,8 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_workers=2)
|
||||
|
||||
trainer1.train(max_retries=2)
|
||||
# MAX RETRIES SHOULD BE ON BY DEFAULT
|
||||
trainer1.train()
|
||||
trainer1.shutdown()
|
||||
|
||||
|
||||
|
||||
@@ -2,16 +2,17 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TorchTrainer = None
|
||||
TorchTrainable = None
|
||||
TrainingOperator = None
|
||||
BaseTorchTrainable = None
|
||||
|
||||
try:
|
||||
import torch # noqa: F401
|
||||
|
||||
from ray.util.sgd.torch.torch_trainer import (TorchTrainer, TorchTrainable)
|
||||
from ray.util.sgd.torch.torch_trainer import (TorchTrainer,
|
||||
BaseTorchTrainable)
|
||||
|
||||
from ray.util.sgd.torch.training_operator import TrainingOperator
|
||||
|
||||
__all__ = ["TorchTrainer", "TorchTrainable", "TrainingOperator"]
|
||||
__all__ = ["TorchTrainer", "BaseTorchTrainable", "TrainingOperator"]
|
||||
except ImportError:
|
||||
logger.warning("PyTorch not found. TorchTrainer will not be available")
|
||||
|
||||
@@ -142,14 +142,7 @@ class DistributedTorchRunner(TorchRunner):
|
||||
|
||||
This is needed for PyTorch DistributedDataParallel models.
|
||||
"""
|
||||
cpu_state_dicts = []
|
||||
for model in self.models:
|
||||
state_dict = model.module.state_dict()
|
||||
# This is so that we create a duplicate of weights into CPU rather
|
||||
# than move the model weights out of the GPU so that we can
|
||||
# resume training while saving intermediate checkpoints.
|
||||
cpu_state_dicts += [{k: v.cpu() for k, v in state_dict.items()}]
|
||||
return cpu_state_dicts
|
||||
return [model.module.state_dict() for model in self.models]
|
||||
|
||||
def _set_model_state_dicts(self, model_state_dicts):
|
||||
for model, model_state_dict in zip(self.models, model_state_dicts):
|
||||
@@ -212,3 +205,10 @@ class LocalDistributedRunner(DistributedTorchRunner):
|
||||
def is_actor(self):
|
||||
actor_id = ray.worker.global_worker.actor_id
|
||||
return actor_id != actor_id.nil()
|
||||
|
||||
|
||||
class DeactivatedRunner:
    """Placeholder runner whose regular attributes all raise ``RuntimeError``.

    The error message tells the user that the TorchTrainer is not active and
    that a new one must be created.
    """

    def __getattr__(self, *args, **kwargs):
        """Reject access to any attribute that normal lookup did not find.

        Python only calls ``__getattr__`` after regular lookup has failed, so
        every name that reaches this method is missing.

        Raises:
            AttributeError: For special ``__dunder__`` names, so protocol
                probes (``hasattr``, ``getattr`` with a default, copy/pickle
                hooks) see a normal "attribute missing" answer.
            RuntimeError: For every other name.
        """
        name = args[0] if args else ""
        # Special-method probes must get AttributeError; anything else would
        # escape ``hasattr``/``getattr(..., default)`` and crash the caller.
        if isinstance(name, str) and name.startswith("__") \
                and name.endswith("__"):
            raise AttributeError(name)
        raise RuntimeError(
            "This TorchTrainer is not active (it is likely shutdown already). "
            "Create a new TorchTrainer.")
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import argparse
|
||||
from ray import tune
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
@@ -10,17 +12,20 @@ import torchvision.transforms as transforms
|
||||
from tqdm import trange
|
||||
|
||||
import ray
|
||||
from ray.util.sgd.torch import (TorchTrainer, TorchTrainable)
|
||||
from ray.tune import CLIReporter
|
||||
from ray.util.sgd.torch import TorchTrainer
|
||||
from ray.util.sgd.torch.resnet import ResNet18
|
||||
from ray.util.sgd.utils import BATCH_SIZE
|
||||
|
||||
|
||||
def initialization_hook():
|
||||
print("NCCL DEBUG SET")
|
||||
# Need this for avoiding a connection restart issue
|
||||
# Need this for avoiding a connection restart issue on AWS.
|
||||
os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo"
|
||||
os.environ["NCCL_LL_THRESHOLD"] = "0"
|
||||
os.environ["NCCL_DEBUG"] = "INFO"
|
||||
|
||||
# set the below if needed
|
||||
# print("NCCL DEBUG SET")
|
||||
# os.environ["NCCL_DEBUG"] = "INFO"
|
||||
|
||||
|
||||
def cifar_creator(config):
|
||||
@@ -55,7 +60,10 @@ def cifar_creator(config):
|
||||
|
||||
def optimizer_creator(model, config):
|
||||
"""Returns optimizer"""
|
||||
return torch.optim.SGD(model.parameters(), lr=config.get("lr", 0.1))
|
||||
return torch.optim.SGD(
|
||||
model.parameters(),
|
||||
lr=config.get("lr", 0.1),
|
||||
momentum=config.get("momentum", 0.9))
|
||||
|
||||
|
||||
def scheduler_creator(optimizer, config):
|
||||
@@ -77,12 +85,11 @@ def train_example(num_workers=1,
|
||||
initialization_hook=initialization_hook,
|
||||
num_workers=num_workers,
|
||||
config={
|
||||
"lr": 0.01,
|
||||
"test_mode": test_mode,
|
||||
BATCH_SIZE: 128,
|
||||
"lr": 0.1,
|
||||
"test_mode": test_mode, # user-defined param to subset the data
|
||||
BATCH_SIZE: 128 * num_workers # this will be split across workers.
|
||||
},
|
||||
use_gpu=use_gpu,
|
||||
backend="nccl" if use_gpu else "gloo",
|
||||
scheduler_step_freq="epoch",
|
||||
use_fp16=use_fp16,
|
||||
use_tqdm=True)
|
||||
@@ -92,39 +99,65 @@ def train_example(num_workers=1,
|
||||
info["epoch_idx"] = i
|
||||
info["num_epochs"] = num_epochs
|
||||
# Increase `max_retries` to turn on fault tolerance.
|
||||
stats = trainer1.train(max_retries=1, info=info)
|
||||
pbar.set_postfix(dict(loss=stats["mean_train_loss"]))
|
||||
trainer1.train(max_retries=1, info=info)
|
||||
val_stats = trainer1.validate()
|
||||
pbar.set_postfix(dict(acc=val_stats["val_accuracy"]))
|
||||
|
||||
print(trainer1.validate())
|
||||
trainer1.shutdown()
|
||||
print("success!")
|
||||
|
||||
|
||||
def tune_example(num_workers=1, use_gpu=False, test_mode=False):
|
||||
config = {
|
||||
"model_creator": ResNet18,
|
||||
"data_creator": cifar_creator,
|
||||
"optimizer_creator": optimizer_creator,
|
||||
"loss_creator": nn.CrossEntropyLoss,
|
||||
"num_workers": num_workers,
|
||||
"initialization_hook": initialization_hook,
|
||||
"use_gpu": use_gpu,
|
||||
"config": {
|
||||
"lr": tune.choice([1e-4, 1e-3]),
|
||||
BATCH_SIZE: 128,
|
||||
"test_mode": test_mode
|
||||
def tune_example(num_workers=1, use_gpu=False, use_fp16=False,
|
||||
test_mode=False):
|
||||
TorchTrainable = TorchTrainer.as_trainable(
|
||||
model_creator=ResNet18,
|
||||
data_creator=cifar_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.CrossEntropyLoss,
|
||||
scheduler_creator=scheduler_creator,
|
||||
initialization_hook=initialization_hook,
|
||||
num_workers=num_workers,
|
||||
config={
|
||||
"test_mode": test_mode, # user-defined param to subset the data
|
||||
BATCH_SIZE: 128 * num_workers,
|
||||
},
|
||||
"backend": "nccl" if use_gpu else "gloo"
|
||||
}
|
||||
use_gpu=use_gpu,
|
||||
scheduler_step_freq="epoch",
|
||||
use_fp16=use_fp16)
|
||||
|
||||
pbt_scheduler = PopulationBasedTraining(
|
||||
time_attr="training_iteration",
|
||||
metric="val_loss",
|
||||
mode="min",
|
||||
perturbation_interval=1,
|
||||
hyperparam_mutations={
|
||||
# distribution for resampling
|
||||
"lr": lambda: np.random.uniform(0.001, 1),
|
||||
# allow perturbations within this set of categorical values
|
||||
"momentum": [0.8, 0.9, 0.99],
|
||||
})
|
||||
|
||||
reporter = CLIReporter()
|
||||
reporter.add_metric_column("val_loss", "loss")
|
||||
reporter.add_metric_column("val_accuracy", "acc")
|
||||
|
||||
analysis = tune.run(
|
||||
TorchTrainable,
|
||||
num_samples=2,
|
||||
config=config,
|
||||
stop={"training_iteration": 2},
|
||||
verbose=2)
|
||||
num_samples=4,
|
||||
config={
|
||||
"lr": tune.choice([0.001, 0.01, 0.1]),
|
||||
"momentum": 0.8
|
||||
},
|
||||
stop={"training_iteration": 2 if test_mode else 100},
|
||||
max_failures=3, # used for fault tolerance
|
||||
checkpoint_freq=3, # used for fault tolerance
|
||||
keep_checkpoints_num=1, # used for fault tolerance
|
||||
verbose=2,
|
||||
progress_reporter=reporter,
|
||||
scheduler=pbt_scheduler)
|
||||
|
||||
return analysis.get_best_config(metric="mean_accuracy", mode="max")
|
||||
return analysis.get_best_config(metric="val_loss", mode="min")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -242,16 +242,13 @@ def train_example(num_workers=1, use_gpu=False, test_mode=False):
|
||||
num_workers=num_workers,
|
||||
config=config,
|
||||
use_gpu=use_gpu,
|
||||
backend="nccl" if use_gpu else "gloo",
|
||||
use_tqdm=True)
|
||||
|
||||
from tabulate import tabulate
|
||||
pbar = trange(5, unit="epoch")
|
||||
for itr in pbar:
|
||||
stats = trainer.train(info=dict(epoch_idx=itr, num_epochs=5))
|
||||
pbar.set_postfix(
|
||||
dict(loss_g=stats["mean_loss_g"], loss_d=stats["mean_loss_d"]))
|
||||
|
||||
pbar.set_postfix(dict(loss_g=stats["loss_g"], loss_d=stats["loss_d"]))
|
||||
formatted = tabulate([stats], headers="keys")
|
||||
if itr > 0: # Get the last line of the stats.
|
||||
formatted = formatted.split("\n")[-1]
|
||||
|
||||
@@ -14,7 +14,8 @@ import torch.nn as nn
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.util.sgd.torch.torch_trainer import TorchTrainable
|
||||
from ray.util.sgd.torch import TorchTrainer
|
||||
from ray.util.sgd.utils import BATCH_SIZE
|
||||
|
||||
|
||||
class LinearDataset(torch.utils.data.Dataset):
|
||||
@@ -48,30 +49,29 @@ def data_creator(config):
|
||||
val_dataset = LinearDataset(2, 5, size=400)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config["batch_size"],
|
||||
batch_size=config[BATCH_SIZE],
|
||||
)
|
||||
validation_loader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
batch_size=config["batch_size"])
|
||||
batch_size=config[BATCH_SIZE])
|
||||
return train_loader, validation_loader
|
||||
|
||||
|
||||
def tune_example(num_workers=1, use_gpu=False):
|
||||
config = {
|
||||
"model_creator": model_creator,
|
||||
"data_creator": data_creator,
|
||||
"optimizer_creator": optimizer_creator,
|
||||
"loss_creator": nn.MSELoss,
|
||||
"num_workers": num_workers,
|
||||
"use_gpu": use_gpu,
|
||||
"config": {"batch_size": 512 // num_workers},
|
||||
"backend": "gloo"
|
||||
}
|
||||
TorchTrainable = TorchTrainer.as_trainable(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.MSELoss,
|
||||
num_workers=num_workers,
|
||||
use_gpu=use_gpu,
|
||||
config={BATCH_SIZE: 128}
|
||||
)
|
||||
|
||||
analysis = tune.run(
|
||||
TorchTrainable,
|
||||
num_samples=12,
|
||||
config=config,
|
||||
num_samples=3,
|
||||
config={"lr": tune.grid_search([1e-4, 1e-3])},
|
||||
stop={"training_iteration": 2},
|
||||
verbose=1)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import collections
|
||||
from filelock import FileLock
|
||||
import logging
|
||||
import inspect
|
||||
import io
|
||||
import itertools
|
||||
import os
|
||||
import tempfile
|
||||
@@ -225,22 +226,14 @@ class TorchRunner:
|
||||
self.training_operator._set_timers(self.timers)
|
||||
|
||||
def _get_model_state_dicts(self):
|
||||
# This is so that we create a duplicate of weights into CPU rather than
|
||||
# move the model weights entirely out of the GPU, so that we can
|
||||
# resume training while saving intermediate checkpoints.
|
||||
cpu_state_dicts = []
|
||||
for model in self.models:
|
||||
state_dict = model.state_dict()
|
||||
cpu_state_dicts += [{k: v.cpu() for k, v in state_dict.items()}]
|
||||
return cpu_state_dicts
|
||||
return [model.state_dict() for model in self.models]
|
||||
|
||||
def _set_model_state_dicts(self, models_state_dicts):
|
||||
for model, state_dict in zip(self.models, models_state_dicts):
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
def get_state(self):
|
||||
def state_dict(self):
|
||||
"""Returns the state of the runner."""
|
||||
|
||||
state = {
|
||||
"epoch": self.epochs,
|
||||
"operator": self.training_operator.state_dict(),
|
||||
@@ -258,9 +251,8 @@ class TorchRunner:
|
||||
state.update({"amp": amp.state_dict()})
|
||||
return state
|
||||
|
||||
def set_state(self, state):
|
||||
def load_state_dict(self, state):
|
||||
"""Sets the state of the model."""
|
||||
# TODO: restore timer stats
|
||||
self._set_model_state_dicts(state["models"])
|
||||
for optimizer, state_dict in zip(self.optimizers, state["optimizers"]):
|
||||
optimizer.load_state_dict(state_dict)
|
||||
@@ -274,6 +266,19 @@ class TorchRunner:
|
||||
self.epochs = state["epoch"]
|
||||
self.training_operator.load_state_dict(state_dict)
|
||||
|
||||
def state_stream(self):
    """Serialize the runner's state dict and return it as raw bytes."""
    stream = io.BytesIO()
    # torch.save writes into the in-memory stream rather than a file on disk.
    torch.save(self.state_dict(), stream)
    return stream.getvalue()
|
||||
|
||||
def load_state_stream(self, byte_obj):
    """Restore the training state dict from a serialized bytes object."""
    # Wrap the raw bytes in an in-memory file object for torch.load.
    restored_state = torch.load(io.BytesIO(byte_obj))
    return self.load_state_dict(restored_state)
|
||||
|
||||
def apply(self, fn):
|
||||
return fn()
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import logging
|
||||
import os
|
||||
import numbers
|
||||
import tempfile
|
||||
import time
|
||||
@@ -10,9 +10,10 @@ import torch.distributed as dist
|
||||
import ray
|
||||
from ray.exceptions import RayActorError
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.trial import Resources
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.utils.util import merge_dicts
|
||||
from ray.util.sgd.torch.distributed_torch_runner import (
|
||||
DistributedTorchRunner, LocalDistributedRunner)
|
||||
DistributedTorchRunner, LocalDistributedRunner, DeactivatedRunner)
|
||||
from ray.util.sgd.utils import check_for_failure, NUM_SAMPLES, BATCH_SIZE
|
||||
from ray.util.sgd.torch.torch_runner import TorchRunner
|
||||
from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP
|
||||
@@ -130,6 +131,11 @@ class TorchTrainer:
|
||||
|
||||
"""
|
||||
|
||||
# TODO: Implement autoscaling. If num_workers=-1, the trainer will use as
|
||||
# many resources as available. Upon each train call, TorchTrainer will
|
||||
# query the Ray global state for total available resources and resize
|
||||
# its remote workers to consume all available resources.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -218,6 +224,9 @@ class TorchTrainer:
|
||||
self._num_failures = 0
|
||||
self._last_resize = float("-inf")
|
||||
|
||||
self.local_worker = DeactivatedRunner()
|
||||
self.remote_workers = []
|
||||
|
||||
_validate_scheduler_step_freq(scheduler_step_freq)
|
||||
self.scheduler_step_freq = scheduler_step_freq
|
||||
|
||||
@@ -250,9 +259,6 @@ class TorchTrainer:
|
||||
if batch_size_per_worker:
|
||||
worker_config[BATCH_SIZE] = batch_size_per_worker
|
||||
|
||||
self.local_worker = None
|
||||
self.remote_workers = []
|
||||
|
||||
if num_workers == 1:
|
||||
# Start local worker
|
||||
self.local_worker = TorchRunner(
|
||||
@@ -319,8 +325,7 @@ class TorchTrainer:
|
||||
num_steps=None,
|
||||
profile=False,
|
||||
reduce_results=True,
|
||||
max_retries=0,
|
||||
checkpoint="auto",
|
||||
max_retries=3,
|
||||
info=None):
|
||||
"""Runs a training epoch.
|
||||
|
||||
@@ -339,14 +344,12 @@ class TorchTrainer:
|
||||
all workers into one dict. If a metric is a non-numerical
|
||||
value (or nested dictionaries), one value will be randomly
|
||||
selected among the workers. If False, returns a list of dicts.
|
||||
max_retries (int): Must be non-negative. If set to N, will
|
||||
kill all current workers, query the Ray global state for
|
||||
total available resources, and re-launch up to the
|
||||
available resources. Behavior is not well-defined
|
||||
in case of shared cluster usage.
|
||||
checkpoint (str): Path to checkpoint to restore from if retrying.
|
||||
If max_retries is set and ``checkpoint == "auto"``,
|
||||
TorchTrainer will save a checkpoint before starting to train.
|
||||
max_retries (int): Must be non-negative. If set to N, TorchTrainer
|
||||
will detect and recover from training failure. The recovery
|
||||
process will kill all current workers, query the Ray
|
||||
global state for total available resources, and re-launch up to
|
||||
the available resources. Behavior is not well-defined
|
||||
in case of shared cluster usage. Defaults to 3.
|
||||
info (dict): Optional dictionary passed to the training
|
||||
operator for ``train_epoch`` and ``train_batch``.
|
||||
|
||||
@@ -358,18 +361,9 @@ class TorchTrainer:
|
||||
length will be equal to ``num_workers``.
|
||||
"""
|
||||
assert max_retries >= 0, "`max_retries` must be non-negative."
|
||||
if max_retries:
|
||||
if checkpoint == "auto":
|
||||
logger.debug("Retrying detected. Automatically checkpointing.")
|
||||
checkpoint = self.save(
|
||||
os.path.join(self.temp_dir, "tmp_checkpoint"))
|
||||
elif not checkpoint:
|
||||
raise ValueError("Cannot retry from empty checkpoint.")
|
||||
|
||||
if checkpoint and self._should_resize():
|
||||
if self._should_resize():
|
||||
logger.info("Resize opportunity detected. Attempting to scale up.")
|
||||
self._resize_workers(checkpoint=checkpoint)
|
||||
|
||||
self._resize_workers()
|
||||
success, worker_stats = self._train_epoch(
|
||||
num_steps=num_steps, profile=profile, info=info)
|
||||
# Fault handling
|
||||
@@ -378,7 +372,7 @@ class TorchTrainer:
|
||||
break
|
||||
else:
|
||||
self._num_failures += 1
|
||||
self._resize_workers(checkpoint=checkpoint)
|
||||
self._resize_workers()
|
||||
logger.info("Retrying training step with %d workers." %
|
||||
(len(self.remote_workers) + 1))
|
||||
success, worker_stats = self._train_epoch(
|
||||
@@ -483,7 +477,6 @@ class TorchTrainer:
|
||||
w.validate.remote(**params) for w in self.remote_workers
|
||||
]
|
||||
local_worker_stats = self.local_worker.validate(**params)
|
||||
|
||||
return self._process_stats([local_worker_stats] +
|
||||
ray.get(remote_worker_stats))
|
||||
|
||||
@@ -497,47 +490,49 @@ class TorchTrainer:
|
||||
|
||||
def get_model(self):
|
||||
"""Returns the learned model(s)."""
|
||||
models = self.model_creator(self.config)
|
||||
state = self.local_worker.get_state()
|
||||
if len(state["models"]) == 1:
|
||||
models.load_state_dict(state["models"][0])
|
||||
else:
|
||||
for model, state_dict in zip(models, state["models"]):
|
||||
model.load_state_dict(state_dict)
|
||||
return models
|
||||
unwrapped = []
|
||||
for model in self.local_worker.models:
|
||||
unwrapped += [model.module if hasattr(model, "module") else model]
|
||||
if len(unwrapped) == 1:
|
||||
return unwrapped[0]
|
||||
return unwrapped
|
||||
|
||||
def state_dict(self):
|
||||
return self.local_worker.get_state()
|
||||
return self.local_worker.state_dict()
|
||||
|
||||
def load_state_dict(self, state):
|
||||
state_id = ray.put(state)
|
||||
def load_state_dict(self, state_dict, blocking=False):
|
||||
# This is not the most efficient because you have to wait for
|
||||
# the local worker to save then dump to buffer.
|
||||
self.local_worker.load_state_dict(state_dict)
|
||||
state_id = ray.put(self.local_worker.state_stream())
|
||||
|
||||
remote_calls = [
|
||||
worker.set_state.remote(state_id) for worker in self.remote_workers
|
||||
worker.load_state_stream.remote(state_id)
|
||||
for worker in self.remote_workers
|
||||
]
|
||||
self.local_worker.set_state(state)
|
||||
ray.get(remote_calls)
|
||||
if blocking:
|
||||
ray.get(remote_calls)
|
||||
|
||||
def save(self, checkpoint):
|
||||
"""Saves the model(s) to the provided checkpoint.
|
||||
"""Saves the Trainer state to the provided checkpoint path.
|
||||
|
||||
Args:
|
||||
checkpoint (str): Path to target checkpoint file.
|
||||
|
||||
Returns:
|
||||
checkpoint (str): Path to target checkpoint file.
|
||||
"""
|
||||
torch.save(self.state_dict(), checkpoint)
|
||||
return checkpoint
|
||||
|
||||
def restore(self, checkpoint):
|
||||
"""Restores the Trainer and all workers from the provided checkpoint.
|
||||
def load(self, checkpoint):
|
||||
"""Loads the Trainer and all workers from the provided checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint (str): Path to target checkpoint file.
|
||||
"""
|
||||
state = torch.load(checkpoint)
|
||||
self.load_state_dict(state)
|
||||
state_dict = torch.load(checkpoint)
|
||||
self.load_state_dict(state_dict)
|
||||
|
||||
def restore(self, *args, **kwargs):
    """Deprecated entry point that always raises in favor of ``load()``.

    Any positional or keyword arguments are accepted and ignored, so a
    keyword-style call such as ``restore(checkpoint=path)`` gets the
    deprecation message instead of an unrelated ``TypeError``.

    Raises:
        DeprecationWarning: Always.
    """
    raise DeprecationWarning("Use `TorchTrainer.load()` instead.")
|
||||
|
||||
def shutdown(self, force=False):
|
||||
"""Shuts down workers and releases resources."""
|
||||
@@ -562,19 +557,19 @@ class TorchTrainer:
|
||||
else:
|
||||
self.local_worker.shutdown()
|
||||
for worker in self.remote_workers:
|
||||
logger.warning("Killing worker {}.".format(worker))
|
||||
logger.debug("Killing worker {}.".format(worker))
|
||||
ray.kill(worker)
|
||||
|
||||
self.local_worker = None
|
||||
self.local_worker = DeactivatedRunner()
|
||||
self.remote_workers = []
|
||||
|
||||
def _reset(self):
|
||||
"""Terminates models without giving up local resource reservation."""
|
||||
self.local_worker.shutdown(cleanup=False)
|
||||
for worker in self.remote_workers:
|
||||
logger.warning("Killing worker {}.".format(worker))
|
||||
logger.debug("Killing worker {}.".format(worker))
|
||||
ray.kill(worker)
|
||||
self.local_worker = None
|
||||
self.local_worker = DeactivatedRunner()
|
||||
self.remote_workers = []
|
||||
|
||||
def _check_potential_remote_workers_size(self):
|
||||
@@ -588,9 +583,8 @@ class TorchTrainer:
|
||||
remote_resources.get("GPU", 0), new_remote_workers)
|
||||
return new_remote_workers
|
||||
|
||||
def _resize_workers(self, checkpoint, max_retries=10):
|
||||
def _resize_workers(self, max_retries=10):
|
||||
self._reset()
|
||||
assert checkpoint, "Cannot restore without checkpoint."
|
||||
|
||||
time.sleep(1)
|
||||
for i in range(max_retries):
|
||||
@@ -598,7 +592,7 @@ class TorchTrainer:
|
||||
if new_remote_workers:
|
||||
self._last_resize = time.time()
|
||||
self._start_workers(int(new_remote_workers) + 1)
|
||||
self.restore(checkpoint)
|
||||
self.load_state_dict(self.state_dict())
|
||||
return
|
||||
else:
|
||||
delay = 2**i
|
||||
@@ -617,32 +611,128 @@ class TorchTrainer:
|
||||
return potential_remote_size > 0
|
||||
return False
|
||||
|
||||
|
||||
class TorchTrainable(Trainable):
|
||||
@classmethod
|
||||
def default_resource_request(cls, config):
|
||||
remote_worker_count = config["num_workers"] - 1
|
||||
return Resources(
|
||||
cpu=1,
|
||||
gpu=int(config["use_gpu"]),
|
||||
extra_cpu=int(remote_worker_count),
|
||||
extra_gpu=int(int(config["use_gpu"]) * remote_worker_count))
|
||||
def as_trainable(cls, *args, **kwargs):
    """Creates a BaseTorchTrainable class compatible with Tune.

    Any configuration parameters will be overridden by the Tune
    Trial configuration. You can also subclass the provided Trainable
    to implement your own iterative optimization routine.

    .. code-block:: python

        TorchTrainable = TorchTrainer.as_trainable(
            model_creator=ResNet18,
            data_creator=cifar_creator,
            optimizer_creator=optimizer_creator,
            loss_creator=nn.CrossEntropyLoss,
            num_workers=2
        )
        analysis = tune.run(
            TorchTrainable,
            config={"lr": tune.grid_search([0.01, 0.1])}
        )

    Args:
        *args: Positional arguments forwarded to ``TorchTrainer``.
        **kwargs: Keyword arguments forwarded to ``TorchTrainer``. The
            ``config`` entry, if any, is used as the base config that the
            Tune trial config is merged on top of.

    Returns:
        A ``BaseTorchTrainable`` subclass bound to the given arguments.
    """

    class TorchTrainable(BaseTorchTrainable):
        @classmethod
        def default_resource_request(cls, config):
            """Computes the resources needed by one trial."""
            num_workers = config.get("num_workers",
                                     kwargs.get("num_workers", 1))
            # A missing (or None) ``use_gpu`` means "no GPU"; without the
            # fallback, int(None) below would raise TypeError.
            use_gpu = config.get("use_gpu", kwargs.get("use_gpu")) or False

            remote_worker_count = num_workers - 1

            return Resources(
                cpu=1,
                gpu=int(use_gpu),
                extra_cpu=int(remote_worker_count),
                extra_gpu=int(int(use_gpu) * remote_worker_count))

        def _create_trainer(self, tune_config):
            """Overrides the provided config with Tune config."""
            provided_config = dict(kwargs.get("config") or {})
            provided_config.update(tune_config)
            # Build fresh kwargs per call instead of writing back into the
            # shared closure, so one trial's config cannot leak into a
            # trainer created later from the same class.
            trainer_kwargs = dict(kwargs, config=provided_config)
            trainer = TorchTrainer(*args, **trainer_kwargs)
            return trainer

    return TorchTrainable
|
||||
|
||||
|
||||
class BaseTorchTrainable(Trainable):
|
||||
"""Base class for converting TorchTrainer to a Trainable class.
|
||||
|
||||
This class is produced when you call ``TorchTrainer.as_trainable(...)``.
|
||||
|
||||
You can override the produced Trainable to implement custom iterative
|
||||
training procedures:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
TorchTrainable = TorchTrainer.as_trainable(
|
||||
model_creator=ResNet18,
|
||||
data_creator=cifar_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.CrossEntropyLoss,
|
||||
num_gpus=2
|
||||
)
|
||||
# TorchTrainable is a subclass of BaseTorchTrainable.
|
||||
|
||||
class CustomTrainable(TorchTrainable):
|
||||
def _train(self):
|
||||
for i in range(5):
|
||||
train_stats = self.trainer.train()
|
||||
validation_stats = self.trainer.validate()
|
||||
train_stats.update(validation_stats)
|
||||
return train_stats
|
||||
|
||||
analysis = tune.run(
|
||||
CustomTrainable,
|
||||
config={"lr": tune.grid_search([0.01, 0.1])}
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
def _setup(self, config):
|
||||
self._trainer = TorchTrainer(**config)
|
||||
"""Constructs a TorchTrainer object as `self.trainer`."""
|
||||
self._trainer = self._create_trainer(config)
|
||||
|
||||
def _train(self):
|
||||
train_stats = self._trainer.train()
|
||||
validation_stats = self._trainer.validate()
|
||||
"""Calls `self.trainer.train()` and `self.trainer.validate()` once.
|
||||
|
||||
train_stats.update(validation_stats)
|
||||
return train_stats
|
||||
You may want to override this if using a custom LR scheduler.
|
||||
"""
|
||||
train_stats = self.trainer.train(max_retries=10, profile=True)
|
||||
validation_stats = self.trainer.validate(profile=True)
|
||||
stats = merge_dicts(train_stats, validation_stats)
|
||||
return stats
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
return self._trainer.save(os.path.join(checkpoint_dir, "model.pth"))
|
||||
"""Returns a path containing the trainer state."""
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "trainer.checkpoint")
|
||||
self.trainer.save(checkpoint_path)
|
||||
return checkpoint_path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
return self._trainer.restore(checkpoint_path)
|
||||
"""Restores the trainer state.
|
||||
|
||||
Override this if you have state external to the Trainer object.
|
||||
"""
|
||||
return self.trainer.load(checkpoint_path)
|
||||
|
||||
def _stop(self):
|
||||
self._trainer.shutdown()
|
||||
"""Shuts down the trainer."""
|
||||
self.trainer.shutdown()
|
||||
|
||||
def _create_trainer(self, config):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def trainer(self):
|
||||
"""An instantiated TorchTrainer object.
|
||||
|
||||
Use this when specifying custom training procedures for Tune.
|
||||
"""
|
||||
return self._trainer
|
||||
|
||||
@@ -168,8 +168,10 @@ class TrainingOperator:
|
||||
|
||||
if self.use_tqdm and self.world_rank == 0:
|
||||
_progress_bar.n = batch_idx + 1
|
||||
postfix = {}
|
||||
if "train_loss" in metrics:
|
||||
_progress_bar.set_postfix({"loss": metrics["train_loss"]})
|
||||
postfix.update(loss=metrics["train_loss"])
|
||||
_progress_bar.set_postfix(postfix)
|
||||
|
||||
if self.scheduler and batch_info.get(
|
||||
SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
|
||||
@@ -259,7 +261,7 @@ class TrainingOperator:
|
||||
|
||||
Returns:
|
||||
A dict of metrics from the evaluation.
|
||||
By default, returns "mean_accuracy" and "mean_val_loss"
|
||||
By default, returns "val_accuracy" and "val_loss"
|
||||
which are computed by aggregating "loss" and "correct" values
|
||||
from ``validate_batch`` and dividing it by the sum of
|
||||
``num_samples`` from all calls to ``self.validate_batch``.
|
||||
|
||||
@@ -192,7 +192,7 @@ class AverageMeterCollection:
|
||||
"""Returns a dict of average and most recent values for each metric."""
|
||||
stats = {BATCH_COUNT: self._batch_count, NUM_SAMPLES: self.n}
|
||||
for metric, meter in self._meters.items():
|
||||
stats["mean_" + str(metric)] = meter.avg
|
||||
stats[str(metric)] = meter.avg
|
||||
stats["last_" + str(metric)] = meter.val
|
||||
return stats
|
||||
|
||||
|
||||
Reference in New Issue
Block a user