[tune/sgd] Document func_trainable and add checkpoint context (#9739)

Co-authored-by: krfricke <krfricke@users.noreply.github.com>
Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
This commit is contained in:
Richard Liaw
2020-07-30 09:46:37 -07:00
committed by GitHub
parent e540e425e4
commit 0c3b9ebeef
23 changed files with 619 additions and 452 deletions
+8
View File
@@ -287,6 +287,14 @@ py_test(
args = ["--smoke-test"]
)
py_test(
name = "test_torch_trainable",
size = "medium",
srcs = ["tests/test_torch_trainable.py"],
tags = ["exclusive", "example", "pytorch"],
deps = [":tune_lib"],
)
py_test(
name = "ddp_mnist_torch",
size = "small",
+2 -2
View File
@@ -9,7 +9,7 @@ from ray.tune.durable_trainable import DurableTrainable
from ray.tune.suggest import grid_search
from ray.tune.session import (report, get_trial_dir, get_trial_name,
get_trial_id, make_checkpoint_dir,
save_checkpoint)
save_checkpoint, checkpoint_dir)
from ray.tune.progress_reporter import (ProgressReporter, CLIReporter,
JupyterNotebookReporter)
from ray.tune.sample import (function, sample_from, uniform, choice, randint,
@@ -22,5 +22,5 @@ __all__ = [
"uniform", "choice", "randint", "randn", "loguniform",
"ExperimentAnalysis", "Analysis", "CLIReporter", "JupyterNotebookReporter",
"ProgressReporter", "report", "get_trial_dir", "get_trial_name",
"get_trial_id", "make_checkpoint_dir", "save_checkpoint"
"get_trial_id", "make_checkpoint_dir", "save_checkpoint", "checkpoint_dir"
]
+10 -8
View File
@@ -58,7 +58,7 @@ class Net(nn.Module):
# __train_begin__
def train_cifar(config, checkpoint=None, data_dir=None):
def train_cifar(config, checkpoint_dir=None, data_dir=None):
net = Net(config["l1"], config["l2"])
device = "cpu"
@@ -71,8 +71,8 @@ def train_cifar(config, checkpoint=None, data_dir=None):
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
if checkpoint:
print("loading checkpoint {}".format(checkpoint))
if checkpoint_dir:
checkpoint = os.path.join(checkpoint_dir, "checkpoint")
model_state, optimizer_state = torch.load(checkpoint)
net.load_state_dict(model_state)
optimizer.load_state_dict(optimizer_state)
@@ -138,10 +138,10 @@ def train_cifar(config, checkpoint=None, data_dir=None):
val_loss += loss.cpu().numpy()
val_steps += 1
checkpoint_dir = tune.make_checkpoint_dir(epoch)
path = os.path.join(checkpoint_dir, "checkpoint")
torch.save((net.state_dict(), optimizer.state_dict()), path)
tune.save_checkpoint(path)
with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
torch.save(
(net.state_dict(), optimizer.state_dict()), path)
tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
print("Finished Training")
@@ -213,7 +213,9 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
best_trained_model = nn.DataParallel(best_trained_model)
best_trained_model.to(device)
model_state, optimizer_state = torch.load(best_trial.checkpoint.value)
checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint")
model_state, optimizer_state = torch.load(checkpoint_path)
best_trained_model.load_state_dict(model_state)
test_acc = test_accuracy(best_trained_model, device)
+8 -6
View File
@@ -2,6 +2,7 @@
# https://github.com/pytorch/examples/blob/master/mnist/main.py
import argparse
import logging
import os
import torch
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel
@@ -10,21 +11,21 @@ import ray
from ray import tune
from ray.tune.examples.mnist_pytorch import (train, test, get_data_loaders,
ConvNet)
from ray.util.sgd.torch.func_trainable import (DistributedTrainableCreator,
distributed_checkpoint)
from ray.tune.integration.torch import (DistributedTrainableCreator,
distributed_checkpoint_dir)
logger = logging.getLogger(__name__)
def train_mnist(config, checkpoint=False):
def train_mnist(config, checkpoint_dir=False):
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
train_loader, test_loader = get_data_loaders()
model = ConvNet().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.1)
if checkpoint:
with open(checkpoint) as f:
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
model_state, optimizer_state = torch.load(f)
model.load_state_dict(model_state)
@@ -37,7 +38,8 @@ def train_mnist(config, checkpoint=False):
acc = test(model, test_loader, device)
if epoch % 3 == 0:
with distributed_checkpoint(label=epoch) as path:
with distributed_checkpoint_dir(step=epoch) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
torch.save((model.state_dict(), optimizer.state_dict()), path)
tune.report(mean_accuracy=acc)
@@ -11,10 +11,10 @@ from ray import tune
from ray.tune.schedulers import HyperBandScheduler
def train(config, checkpoint=None):
def train(config, checkpoint_dir=None):
step = 0
if checkpoint:
with open(checkpoint) as f:
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
step = json.loads(f.read())["timestep"]
for timestep in range(step, 100):
@@ -22,11 +22,10 @@ def train(config, checkpoint=None):
v *= config.get("height", 1)
if timestep % 3 == 0:
checkpoint_dir = tune.make_checkpoint_dir(step=timestep)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"timestep": timestep}))
tune.save_checkpoint(path)
with tune.checkpoint_dir(step=timestep) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"timestep": timestep}))
# Here we use `episode_reward_mean`, but you can also report other
# objectives such as loss or accuracy.
@@ -158,16 +158,15 @@ def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
# __tune_checkpoint_callback_begin__
class CheckpointCallback(Callback):
def on_validation_end(self, trainer, pl_module):
path = tune.make_checkpoint_dir(trainer.global_step)
trainer.save_checkpoint(os.path.join(path, "checkpoint"))
tune.save_checkpoint(path)
with tune.checkpoint_dir(step=trainer.global_step) as checkpoint_dir:
trainer.save_checkpoint(os.path.join(checkpoint_dir, "checkpoint"))
# __tune_checkpoint_callback_end__
# __tune_train_checkpoint_begin__
def train_mnist_tune_checkpoint(
config,
checkpoint=None,
checkpoint_dir=None,
data_dir=None,
num_epochs=10,
num_gpus=0):
@@ -179,13 +178,13 @@ def train_mnist_tune_checkpoint(
progress_bar_refresh_rate=0,
callbacks=[CheckpointCallback(),
TuneReportCallback()])
if checkpoint:
if checkpoint_dir:
# Currently, this leads to errors:
# model = LightningMNISTClassifier.load_from_checkpoint(
# os.path.join(checkpoint, "checkpoint"))
# Workaround:
ckpt = pl_load(
os.path.join(checkpoint, "checkpoint"),
os.path.join(checkpoint_dir, "checkpoint"),
map_location=lambda storage, loc: storage)
model = LightningMNISTClassifier._load_model_state(ckpt, config=config)
trainer.current_epoch = ckpt["epoch"]
+7 -8
View File
@@ -11,7 +11,7 @@ from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
def pbt_function(config, checkpoint=None):
def pbt_function(config, checkpoint_dir=None):
"""Toy PBT problem for benchmarking adaptive learning rate.
The goal is to optimize this trainable's accuracy. The accuracy increases
@@ -35,8 +35,8 @@ def pbt_function(config, checkpoint=None):
lr = config["lr"]
accuracy = 0.0 # end = 1000
start = 0
if checkpoint:
with open(checkpoint) as f:
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
state = json.loads(f.read())
accuracy = state["acc"]
start = state["step"]
@@ -65,11 +65,10 @@ def pbt_function(config, checkpoint=None):
accuracy = max(0, accuracy)
if step % 3 == 0:
checkpoint_dir = tune.make_checkpoint_dir(step=step)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"acc": accuracy, "step": start}))
tune.save_checkpoint(path)
with tune.checkpoint_dir(step=step) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"acc": accuracy, "step": start}))
tune.report(
mean_accuracy=accuracy,
+22 -10
View File
@@ -350,35 +350,47 @@ class FunctionRunner(Trainable):
pass
def detect_checkpoint_function(train_func):
func_args = inspect.getfullargspec(train_func).args
use_checkpoint = "checkpoint" in func_args
return use_checkpoint
def detect_checkpoint_function(train_func, abort=False):
"""Use checkpointing if any arg has "checkpoint_dir" and args = 2"""
argspec = inspect.getfullargspec(train_func)
func_args = argspec.args
func_kwargs = argspec.kwonlyargs
validated = len(func_args) == 2 and any("checkpoint_dir" in arg
for arg in func_args)
validated = validated or (len(func_args) == 1) and any(
"checkpoint_dir" in arg for arg in func_kwargs)
if abort and not validated:
raise ValueError(
"Provided training function must have 2 args "
"in the signature, and the latter arg must "
"contain `checkpoint_dir`. For example: "
"`func(config, checkpoint_dir=None)`. Got {}".format(func_args))
return validated
def wrap_function(train_func):
class ImplicitFunc(FunctionRunner):
def _trainable_func(self, config, reporter, checkpoint):
def _trainable_func(self, config, reporter, checkpoint_dir):
func_args = inspect.getfullargspec(train_func).args
if len(func_args) > 1: # more arguments than just the config
if "reporter" not in func_args and (
"checkpoint" not in func_args):
not detect_checkpoint_function(train_func)):
raise ValueError(
"Unknown argument found in the Trainable function. "
"Arguments other than the 'config' arg must be one "
"of ['reporter', 'checkpoint']. Found: {}".format(
"of ['reporter', 'checkpoint_dir']. Found: {}".format(
func_args))
use_reporter = "reporter" in func_args
use_checkpoint = "checkpoint" in func_args
use_checkpoint = detect_checkpoint_function(train_func)
if not use_checkpoint and not use_reporter:
logger.warning(
"Function checkpointing is disabled. This may result in "
"unexpected behavior when using checkpointing features or "
"certain schedulers. To enable, set the train function "
"arguments to be `func(config, checkpoint)`.")
"arguments to be `func(config, checkpoint_dir=None)`.")
output = train_func(config)
elif use_checkpoint:
output = train_func(config, checkpoint=checkpoint)
output = train_func(config, checkpoint_dir=checkpoint_dir)
else:
output = train_func(config, reporter)
+290
View File
@@ -0,0 +1,290 @@
# Original Code here:
# https://github.com/pytorch/examples/blob/master/mnist/main.py
from contextlib import contextmanager
import os
import logging
import shutil
import tempfile
import torch
from datetime import timedelta
import ray
from ray import tune
from ray.tune.result import RESULT_DUPLICATE
from ray.tune.logger import NoopLogger
from ray.tune.function_runner import (wrap_function,
detect_checkpoint_function)
from ray.tune.resources import Resources
from ray.tune.trainable import TrainableUtil
from ray.util.sgd.torch.utils import setup_process_group, setup_address
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S
logger = logging.getLogger(__name__)
_distributed_enabled = False
def is_distributed_trainable():
"""Returns True if executing within a DistributedTrainable."""
return _distributed_enabled
def enable_distributed_trainable():
global _distributed_enabled
_distributed_enabled = True
def logger_creator(log_config, logdir, rank):
worker_dir = os.path.join(logdir, "worker_{}".format(rank))
os.makedirs(worker_dir, exist_ok=True)
return NoopLogger(log_config, worker_dir)
class _TorchTrainable(tune.Trainable):
"""Base class for distributed training on Tune.
A wrapper class is needed to actually create a working
version of this trainable.
"""
_function = None
_num_workers = None
_use_gpu = None
_num_cpus_per_worker = None
__slots__ = ["workers", "_finished"]
@classmethod
def default_process_group_parameters(self):
return dict(timeout=timedelta(NCCL_TIMEOUT_S), backend="gloo")
@classmethod
def get_remote_worker_options(self):
num_gpus = 1 if self._use_gpu else 0
num_cpus = int(self._num_cpus_per_worker or 1)
return dict(num_cpus=num_cpus, num_gpus=num_gpus)
def setup(self, config):
self._finished = False
num_workers = self._num_workers
logdir = self.logdir
assert self._function
func_trainable = wrap_function(self.__class__._function)
remote_trainable = ray.remote(func_trainable)
remote_trainable = remote_trainable.options(
**self.get_remote_worker_options())
address = setup_address()
self.workers = [
remote_trainable.remote(
config=config,
logger_creator=lambda cfg: logger_creator(cfg, logdir, rank))
for rank in range(num_workers)
]
pgroup_params = self.default_process_group_parameters()
from functools import partial
setup_on_worker = partial(
setup_process_group,
url=address,
world_size=num_workers,
**pgroup_params)
ray.get([
w.execute.remote(lambda _: setup_on_worker(world_rank=rank))
for rank, w in enumerate(self.workers)
])
ray.get([
w.execute.remote(lambda _: enable_distributed_trainable())
for rank, w in enumerate(self.workers)
])
def step(self):
if self._finished:
raise RuntimeError("Training has already finished.")
result = ray.get([w.step.remote() for w in self.workers])[0]
if RESULT_DUPLICATE in result:
self._finished = True
return result
def save_checkpoint(self, checkpoint_dir):
# TODO: optimize if colocated
save_obj = ray.get(self.workers[0].save_to_object.remote())
checkpoint_path = TrainableUtil.create_from_pickle(
save_obj, checkpoint_dir)
return checkpoint_path
def load_checkpoint(self, checkpoint_dir):
checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
return ray.get(
w.restore_from_object.remote(checkpoint_obj) for w in self.workers)
def stop(self):
ray.get([worker.stop.remote() for worker in self.workers])
def DistributedTrainableCreator(func,
use_gpu=False,
num_workers=1,
num_cpus_per_worker=1,
backend="gloo",
timeout_s=NCCL_TIMEOUT_S):
"""Creates a class that executes distributed training.
Similar to running `torch.distributed.launch`.
Note that you typically should not instantiate the object
created.
Args:
func (callable): This function is a Tune trainable function.
This function must have 2 args in the signature, and the
latter arg must contain `checkpoint_dir`. For example:
`func(config, checkpoint_dir=None)`.
use_gpu (bool): Sets resource allocation for workers to 1 GPU
if true. Also automatically sets CUDA_VISIBLE_DEVICES
for each training worker.
num_workers (int): Number of training workers to include in
world.
num_cpus_per_worker (int): Number of CPU resources to reserve
per training worker.
backend (str): One of "gloo", "nccl".
timeout_s (float): Seconds before the torch process group
times out. Useful when machines are unreliable. Defaults
to 60 seconds.
Returns:
A trainable class object that can be passed to Tune. Resources
are automatically set within the object, so users do
not need to set `resources_per_trainable`.
Example:
.. code-block::
trainable_cls = DistributedTrainableCreator(
train_func, num_workers=2)
analysis = tune.run(trainable_cls)
"""
detect_checkpoint_function(func, abort=True)
class WrappedDistributedTorchTrainable(_TorchTrainable):
_function = func
_num_workers = num_workers
_use_gpu = use_gpu
_num_cpus_per_worker = num_cpus_per_worker
@classmethod
def default_process_group_parameters(self):
return dict(timeout=timedelta(timeout_s), backend=backend)
@classmethod
def default_resource_request(cls, config):
num_workers_ = int(config.get("num_workers", num_workers))
num_cpus = int(
config.get("num_cpus_per_worker", num_cpus_per_worker))
use_gpu_ = config.get("use_gpu", use_gpu)
return Resources(
cpu=0,
gpu=0,
extra_cpu=num_cpus * num_workers_,
extra_gpu=num_workers_ if use_gpu_ else 0)
return WrappedDistributedTorchTrainable
@contextmanager
def distributed_checkpoint_dir(step, disable=False):
"""ContextManager for creating a distributed checkpoint.
Only checkpoints a file on the "main" training actor, avoiding
redundant work.
Args:
step (int): Used to label the checkpoint
disable (bool): Disable for prototyping.
Yields:
path (str): A path to a directory. This path will be used
again when invoking the training_function.
Example:
.. code-block::
def train_func(config, checkpoint_dir):
if checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
model_state_dict = torch.load(path)
if epoch % 3 == 0:
with distributed_checkpoint_dir(step=epoch) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
torch.save(model.state_dict(), path)
"""
if torch.distributed.get_rank() == 0 and not disable:
with tune.checkpoint_dir(step=step) as checkpoint_dir:
yield checkpoint_dir
else:
path = tempfile.mkdtemp()
yield path
shutil.rmtree(path)
def _train_check_global(config, checkpoint_dir=None):
"""For testing only. Putting this here because Ray has problems
serializing within the test file."""
assert is_distributed_trainable()
import time
time.sleep(0.1)
tune.report(is_distributed=True)
def _train_simple(config, checkpoint_dir=None):
"""For testing only. Putting this here because Ray has problems
serializing within the test file."""
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
import torch.optim as optim
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 8, 5, 5, 5
# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
loss_fn = nn.MSELoss()
# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.ReLU(),
torch.nn.Linear(H, D_out),
)
optimizer = optim.SGD(model.parameters(), lr=0.1)
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
model_state, optimizer_state = torch.load(f)
model.load_state_dict(model_state)
optimizer.load_state_dict(optimizer_state)
model = DistributedDataParallel(model)
for epoch in range(config.get("epochs", 10)):
optimizer.zero_grad()
output = model(x)
loss = loss_fn(output, y)
loss.backward()
optimizer.step()
if epoch % 3 == 0:
if config.get("enable_checkpoint", True):
with distributed_checkpoint_dir(step=epoch) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
torch.save((model.state_dict(), optimizer.state_dict()),
path)
tune.report(mean_loss=loss.item())
+40 -51
View File
@@ -1,3 +1,4 @@
from contextlib import contextmanager
import os
import logging
@@ -75,48 +76,34 @@ def report(**kwargs):
def make_checkpoint_dir(step=None):
"""Gets the next checkpoint dir.
.. code-block:: python
import time
from ray import tune
def func(config, checkpoint=None):
start = 0
if checkpoint:
with open(checkpoint) as f:
state = json.loads(f.read())
start = state["step"] + 1
for iter in range(start, 100):
time.sleep(1)
checkpoint_dir = tune.make_checkpoint_dir(step=step)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": start}))
tune.save_checkpoint(path)
tune.report(hello="world", ray="tune")
.. warning:: Do not call this function within the Trainable Class API.
Args:
step (int): Current training iteration - used for setting
an index to uniquely identify the checkpoint.
.. versionadded:: 0.8.6
.. deprecated:: 0.8.7
Use tune.checkpoint_dir instead.
"""
_session = get_session()
if _session:
return _session.make_checkpoint_dir(step=step)
else:
return os.path.abspath("./")
raise DeprecationWarning(
"Deprecated method. Use `tune.checkpoint_dir` instead.")
def save_checkpoint(checkpoint):
"""Register the given checkpoint.
.. versionadded:: 0.8.6
.. deprecated:: 0.8.7
Use tune.checkpoint_dir instead.
"""
raise DeprecationWarning(
"Deprecated method. Use `tune.checkpoint_dir` instead.")
@contextmanager
def checkpoint_dir(step=None):
"""Returns a checkpoint dir inside a context.
Store any files related to restoring state within the
provided checkpoint dir.
.. code-block:: python
import os
@@ -124,10 +111,10 @@ def save_checkpoint(checkpoint):
import time
from ray import tune
def func(config, checkpoint=None):
def func(config, checkpoint_dir=None):
start = 0
if checkpoint:
with open(checkpoint) as f:
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
state = json.loads(f.read())
accuracy = state["acc"]
start = state["step"] + 1
@@ -135,27 +122,29 @@ def save_checkpoint(checkpoint):
for iter in range(start, 10):
time.sleep(1)
checkpoint_dir = tune.make_checkpoint_dir(step=iter)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": start}))
tune.save_checkpoint(path)
with tune.checkpoint_dir(step=iter) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": start}))
tune.report(hello="world", ray="tune")
analysis = tune.run(run_me)
Yields:
checkpoint_dir (str): Directory for checkpointing.
.. warning:: Do not call this function within the Trainable Class API.
Args:
**kwargs: Any key value pair to be logged by Tune. Any of these
metrics can be used for early stopping or optimization.
.. versionadded:: 0.8.6
.. versionadded:: 0.8.7
"""
_session = get_session()
if _session:
return _session.save_checkpoint(checkpoint)
_checkpoint_dir = _session.make_checkpoint_dir(step=step)
else:
_checkpoint_dir = os.path.abspath("./")
yield _checkpoint_dir
if _session:
_session.save_checkpoint(_checkpoint_dir)
def get_trial_dir():
+63 -57
View File
@@ -19,7 +19,7 @@ class FunctionApiTest(unittest.TestCase):
_register_all() # re-register the evicted objects
def testFunctionNoCheckpointing(self):
def train(config, checkpoint=None):
def train(config, checkpoint_dir=None):
for i in range(10):
tune.report(test=i)
@@ -40,14 +40,13 @@ class FunctionApiTest(unittest.TestCase):
def testFunctionRecurringSave(self):
"""This tests that save and restore are commutative."""
def train(config, checkpoint=None):
def train(config, checkpoint_dir=None):
for step in range(10):
if step % 3 == 0:
checkpoint_dir = tune.make_checkpoint_dir(step=step)
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": step}))
tune.save_checkpoint(path)
with tune.checkpoint_dir(step=step) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps({"step": step}))
tune.report(test=step)
wrapped = wrap_function(train)
@@ -65,49 +64,58 @@ class FunctionApiTest(unittest.TestCase):
new_trainable2.stop()
def testCheckpointFunctionAtEnd(self):
def train(config, checkpoint=False):
def train(config, checkpoint_dir=False):
for i in range(10):
tune.report(test=i)
checkpoint_dir = tune.make_checkpoint_dir(step=10)
checkpoint_path = os.path.join(checkpoint_dir, "hello")
with open(checkpoint_path, "w") as f:
f.write("hello")
tune.save_checkpoint(checkpoint_path)
[trial] = tune.run(train).trials
assert "hello" in trial.checkpoint.value
def testVariousCheckpointFunctionAtEnd(self):
def train(config, checkpoint=False):
for i in range(10):
checkpoint_dir = tune.make_checkpoint_dir()
checkpoint_path = os.path.join(checkpoint_dir, "hello")
with tune.checkpoint_dir(step=10) as checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
with open(checkpoint_path, "w") as f:
f.write("hello")
tune.save_checkpoint(checkpoint_path)
[trial] = tune.run(train).trials
assert os.path.exists(os.path.join(trial.checkpoint.value, "ckpt.log"))
def testCheckpointFunctionAtEndContext(self):
def train(config, checkpoint_dir=False):
for i in range(10):
tune.report(test=i)
checkpoint_dir = tune.make_checkpoint_dir()
checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
with open(checkpoint_path, "w") as f:
f.write("goodbye")
tune.save_checkpoint(checkpoint_path)
with tune.checkpoint_dir(step=10) as checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
with open(checkpoint_path, "w") as f:
f.write("hello")
[trial] = tune.run(train).trials
assert os.path.exists(os.path.join(trial.checkpoint.value, "ckpt.log"))
def testVariousCheckpointFunctionAtEnd(self):
def train(config, checkpoint_dir=False):
for i in range(10):
with tune.checkpoint_dir() as checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
with open(checkpoint_path, "w") as f:
f.write("hello")
tune.report(test=i)
with tune.checkpoint_dir() as checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log2")
with open(checkpoint_path, "w") as f:
f.write("goodbye")
[trial] = tune.run(train, keep_checkpoints_num=3).trials
assert "goodbye" in trial.checkpoint.value
assert os.path.exists(
os.path.join(trial.checkpoint.value, "ckpt.log2"))
def testReuseCheckpoint(self):
def train(config, checkpoint=False):
def train(config, checkpoint_dir=None):
itr = 0
if checkpoint:
with open(checkpoint, "r") as f:
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "ckpt.log"), "r") as f:
itr = int(f.read()) + 1
for i in range(itr, config["max_iter"]):
checkpoint_dir = tune.make_checkpoint_dir(step=i)
checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
with open(checkpoint_path, "w") as f:
f.write(str(i))
tune.save_checkpoint(checkpoint_path)
with tune.checkpoint_dir(step=i) as checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
with open(checkpoint_path, "w") as f:
f.write(str(i))
tune.report(test=i, training_iteration=i)
[trial] = tune.run(
@@ -117,51 +125,49 @@ class FunctionApiTest(unittest.TestCase):
},
).trials
last_ckpt = trial.checkpoint.value
assert "goodbye" in last_ckpt
assert os.path.exists(os.path.join(trial.checkpoint.value, "ckpt.log"))
analysis = tune.run(train, config={"max_iter": 10}, restore=last_ckpt)
trial_dfs = list(analysis.trial_dataframes.values())
assert len(trial_dfs[0]["training_iteration"]) == 5
def testRetry(self):
def train(config, checkpoint=None):
restored = bool(checkpoint)
def train(config, checkpoint_dir=None):
restored = bool(checkpoint_dir)
itr = 0
if checkpoint:
with open(checkpoint, "r") as f:
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "ckpt.log"), "r") as f:
itr = int(f.read()) + 1
for i in range(itr, 10):
if i == 5 and not restored:
raise Exception("try to fail me")
checkpoint_dir = tune.make_checkpoint_dir(step=i)
checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
with open(checkpoint_path, "w") as f:
f.write(str(i))
tune.save_checkpoint(checkpoint_path)
with tune.checkpoint_dir(step=i) as checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
with open(checkpoint_path, "w") as f:
f.write(str(i))
tune.report(test=i, training_iteration=i)
analysis = tune.run(train, max_failures=3)
last_ckpt = analysis.trials[0].checkpoint.value
assert "goodbye" in last_ckpt
assert os.path.exists(os.path.join(last_ckpt, "ckpt.log"))
trial_dfs = list(analysis.trial_dataframes.values())
assert len(trial_dfs[0]["training_iteration"]) == 10
def testBlankCheckpoint(self):
def train(config, checkpoint=None):
restored = bool(checkpoint)
def train(config, checkpoint_dir=None):
restored = bool(checkpoint_dir)
itr = 0
if checkpoint:
with open(checkpoint, "r") as f:
if checkpoint_dir:
with open(os.path.join(checkpoint_dir, "ckpt.log"), "r") as f:
itr = int(f.read()) + 1
for i in range(itr, 10):
if i == 5 and not restored:
raise Exception("try to fail me")
checkpoint_dir = tune.make_checkpoint_dir()
checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
with open(checkpoint_path, "w") as f:
f.write(str(i))
tune.save_checkpoint(checkpoint_path)
with tune.checkpoint_dir() as checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
with open(checkpoint_path, "w") as f:
f.write(str(i))
tune.report(test=i, training_iteration=i)
analysis = tune.run(train, max_failures=3)
@@ -6,8 +6,9 @@ import torch.distributed as dist
import ray
from ray import tune
from ray.util.sgd.torch.func_trainable import (
DistributedTrainableCreator, distributed_checkpoint, _train_simple)
from ray.tune.integration.torch import (DistributedTrainableCreator,
distributed_checkpoint_dir,
_train_simple, _train_check_global)
@pytest.fixture
@@ -47,12 +48,29 @@ def test_step_after_completion(ray_start_2_cpus): # noqa: F811
trainer.train()
def test_validation(ray_start_2_cpus): # noqa: F811
def bad_func(a, b, c):
return 1
with pytest.raises(ValueError):
DistributedTrainableCreator(bad_func, num_workers=2)
def test_set_global(ray_start_2_cpus): # noqa: F811
trainable_cls = DistributedTrainableCreator(
_train_check_global, num_workers=2)
trainable = trainable_cls()
result = trainable.train()
assert result["is_distributed"]
def test_save_checkpoint(ray_start_2_cpus): # noqa: F811
trainable_cls = DistributedTrainableCreator(_train_simple, num_workers=2)
trainer = trainable_cls(config={"epochs": 1})
trainer.train()
path = trainer.save()
model_state_dict, opt_state_dict = torch.load(path)
checkpoint_dir = trainer.save()
model_state_dict, opt_state_dict = torch.load(
os.path.join(checkpoint_dir, "checkpoint"))
trainer.stop()
@@ -72,11 +90,11 @@ def test_simple_tune(ray_start_4_cpus, enabled_checkpoint):
def test_checkpoint(ray_start_2_cpus, rank): # noqa: F811
with patch("torch.distributed.get_rank") as rank_method:
rank_method.return_value = rank
with distributed_checkpoint(label="test") as path:
with distributed_checkpoint_dir(step="test") as path:
if rank == 0:
assert path
else:
assert path == os.devnull
if rank != 0:
assert not os.path.exists(path)
if __name__ == "__main__":
-8
View File
@@ -26,14 +26,6 @@ py_test(
deps = [":sgd_lib"],
)
py_test(
name = "test_torch_trainable",
size = "small",
srcs = ["tests/test_torch_trainable.py"],
tags = ["exclusive", "pytorch"],
deps = [":sgd_lib"],
)
# --------------------------------------------------------------------
# Tests from the python/ray/util/sgd/tf/examples directory.
# Please keep these sorted alphabetically.
+1 -6
View File
@@ -12,13 +12,8 @@ try:
BaseTorchTrainable)
from ray.util.sgd.torch.training_operator import TrainingOperator
from ray.util.sgd.torch.func_trainable import (DistributedTrainableCreator,
distributed_checkpoint)
__all__ = [
"TorchTrainer", "BaseTorchTrainable", "TrainingOperator",
"distributed_checkpoint", "DistributedTrainableCreator"
]
__all__ = ["TorchTrainer", "BaseTorchTrainable", "TrainingOperator"]
except ImportError as e:
logger.warning(e)
logger.warning("PyTorch not found. TorchTrainer will not be available")
-235
View File
@@ -1,235 +0,0 @@
# Original Code here:
# https://github.com/pytorch/examples/blob/master/mnist/main.py
import os
import logging
import torch
from datetime import timedelta
import ray
from ray import tune
from ray.tune.result import RESULT_DUPLICATE
from ray.tune.logger import NoopLogger
from ray.tune.function_runner import wrap_function
from ray.tune.resources import Resources
from ray.tune.trainable import TrainableUtil
from ray.util.sgd.torch.utils import setup_process_group
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S
from ray.util.sgd.torch.utils import setup_address
logger = logging.getLogger(__name__)
def logger_creator(log_config, logdir, rank):
worker_dir = os.path.join(logdir, "worker_{}".format(rank))
os.makedirs(worker_dir, exist_ok=True)
return NoopLogger(log_config, worker_dir)
class _TorchTrainable(tune.Trainable):
"""Base class for distributed training on Tune.
A wrapper class is needed to actually create a working
version of this trainable.
"""
_function = None
_num_workers = None
_use_gpu = None
_num_cpus_per_worker = None
__slots__ = ["workers", "_finished"]
@classmethod
def default_process_group_parameters(self):
return dict(timeout=timedelta(NCCL_TIMEOUT_S), backend="gloo")
@classmethod
def get_remote_worker_options(self):
num_gpus = 1 if self._use_gpu else 0
num_cpus = int(self._num_cpus_per_worker or 1)
return dict(num_cpus=num_cpus, num_gpus=num_gpus)
def setup(self, config):
self._finished = False
num_workers = self._num_workers
logdir = self.logdir
assert self._function
func_trainable = wrap_function(self.__class__._function)
remote_trainable = ray.remote(func_trainable)
remote_trainable = remote_trainable.options(
**self.get_remote_worker_options())
address = setup_address()
self.workers = [
remote_trainable.remote(
config=config,
logger_creator=lambda cfg: logger_creator(cfg, logdir, rank))
for rank in range(num_workers)
]
pgroup_params = self.default_process_group_parameters()
from functools import partial
setup_on_worker = partial(
setup_process_group,
url=address,
world_size=num_workers,
**pgroup_params)
ray.get([
w.execute.remote(lambda _: setup_on_worker(world_rank=rank))
for rank, w in enumerate(self.workers)
])
def step(self):
if self._finished:
raise RuntimeError("Training has already finished.")
result = ray.get([w.step.remote() for w in self.workers])[0]
if RESULT_DUPLICATE in result:
self._finished = True
return result
def save_checkpoint(self, checkpoint_dir):
# TODO: optimize if colocated
save_obj = ray.get(self.workers[0].save_to_object.remote())
checkpoint_path = TrainableUtil.create_from_pickle(
save_obj, checkpoint_dir)
return checkpoint_path
def load_checkpoint(self, checkpoint_dir):
checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
return ray.get(
w.restore_from_object.remote(checkpoint_obj) for w in self.workers)
def stop(self):
ray.get([worker.stop.remote() for worker in self.workers])
def DistributedTrainableCreator(func,
use_gpu=False,
num_workers=1,
num_cpus_per_worker=1,
backend="gloo",
timeout_s=NCCL_TIMEOUT_S):
"""Creates a class that executes distributed training.
Note that you typically should not instantiate the object
created.
Example:
.. code-block::
trainable_cls = DistributedTrainableCreator(
train_func, num_workers=2)
analysis = tune.run(trainable_cls)
"""
class WrappedDistributedTorchTrainable(_TorchTrainable):
_function = func
_num_workers = num_workers
_use_gpu = use_gpu
_num_cpus_per_worker = num_cpus_per_worker
@classmethod
def default_process_group_parameters(self):
return dict(timeout=timedelta(timeout_s), backend=backend)
@classmethod
def default_resource_request(cls, config):
num_workers_ = int(config.get("num_workers", num_workers))
num_cpus = int(
config.get("num_cpus_per_worker", num_cpus_per_worker))
use_gpu_ = config.get("use_gpu", use_gpu)
return Resources(
cpu=0,
gpu=0,
extra_cpu=num_cpus * num_workers_,
extra_gpu=num_workers_ if use_gpu_ else 0)
return WrappedDistributedTorchTrainable
class distributed_checkpoint:
"""ContextManager for creating a distributed checkpoint.
Only checkpoints a file on the "main" training actor, avoiding
redundant work.
Args:
label (int | str): Used to label the checkpoint
disable (bool): Disable for prototyping.
Example:
.. code-block::
if epoch % 3 == 0:
with distributed_checkpoint(label=epoch) as path:
torch.save(model.state_dict(), path)
"""
def __init__(self, label, disable=False):
self.label = label
self.file = None
self.disable = disable
def __enter__(self):
if torch.distributed.get_rank() == 0 and not self.disable:
checkpoint_dir = tune.make_checkpoint_dir(step=self.label)
path = os.path.join(checkpoint_dir, "checkpoint")
else:
path = os.devnull
self.file = path
return path
def __exit__(self, type, value, traceback):
if torch.distributed.get_rank() == 0 and not self.disable:
tune.save_checkpoint(self.file)
def _train_simple(config, checkpoint=False):
"""For testing only. Putting this here because Ray has problems
serializing within the test file."""
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
import torch.optim as optim
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 8, 5, 5, 5
# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
loss_fn = nn.MSELoss()
# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.ReLU(),
torch.nn.Linear(H, D_out),
)
optimizer = optim.SGD(model.parameters(), lr=0.1)
if checkpoint:
with open(checkpoint) as f:
model_state, optimizer_state = torch.load(f)
model.load_state_dict(model_state)
optimizer.load_state_dict(optimizer_state)
model = DistributedDataParallel(model)
for epoch in range(config.get("epochs", 10)):
optimizer.zero_grad()
output = model(x)
loss = loss_fn(output, y)
loss.backward()
optimizer.step()
if epoch % 3 == 0:
if config.get("enable_checkpoint", True):
with distributed_checkpoint(label=epoch) as path:
torch.save((model.state_dict(), optimizer.state_dict()),
path)
tune.report(mean_loss=loss.item())
@@ -140,6 +140,8 @@ class TorchTrainer:
Defaults to True.
wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel
over each model. If False, you are expected to call it yourself.
timeout_s (float): Seconds before the torch process group
times out. Useful when machines are unreliable.
add_dist_sampler (bool): Whether to automatically add a
DistributedSampler to all created dataloaders. Only applicable
if num_workers > 1.