mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 23:08:32 +08:00
[tune/sgd] Document func_trainable and add checkpoint context (#9739)
Co-authored-by: krfricke <krfricke@users.noreply.github.com> Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
This commit is contained in:
@@ -287,6 +287,14 @@ py_test(
|
||||
args = ["--smoke-test"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_torch_trainable",
|
||||
size = "medium",
|
||||
srcs = ["tests/test_torch_trainable.py"],
|
||||
tags = ["exclusive", "example", "pytorch"],
|
||||
deps = [":tune_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "ddp_mnist_torch",
|
||||
size = "small",
|
||||
|
||||
@@ -9,7 +9,7 @@ from ray.tune.durable_trainable import DurableTrainable
|
||||
from ray.tune.suggest import grid_search
|
||||
from ray.tune.session import (report, get_trial_dir, get_trial_name,
|
||||
get_trial_id, make_checkpoint_dir,
|
||||
save_checkpoint)
|
||||
save_checkpoint, checkpoint_dir)
|
||||
from ray.tune.progress_reporter import (ProgressReporter, CLIReporter,
|
||||
JupyterNotebookReporter)
|
||||
from ray.tune.sample import (function, sample_from, uniform, choice, randint,
|
||||
@@ -22,5 +22,5 @@ __all__ = [
|
||||
"uniform", "choice", "randint", "randn", "loguniform",
|
||||
"ExperimentAnalysis", "Analysis", "CLIReporter", "JupyterNotebookReporter",
|
||||
"ProgressReporter", "report", "get_trial_dir", "get_trial_name",
|
||||
"get_trial_id", "make_checkpoint_dir", "save_checkpoint"
|
||||
"get_trial_id", "make_checkpoint_dir", "save_checkpoint", "checkpoint_dir"
|
||||
]
|
||||
|
||||
@@ -58,7 +58,7 @@ class Net(nn.Module):
|
||||
|
||||
|
||||
# __train_begin__
|
||||
def train_cifar(config, checkpoint=None, data_dir=None):
|
||||
def train_cifar(config, checkpoint_dir=None, data_dir=None):
|
||||
net = Net(config["l1"], config["l2"])
|
||||
|
||||
device = "cpu"
|
||||
@@ -71,8 +71,8 @@ def train_cifar(config, checkpoint=None, data_dir=None):
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
|
||||
|
||||
if checkpoint:
|
||||
print("loading checkpoint {}".format(checkpoint))
|
||||
if checkpoint_dir:
|
||||
checkpoint = os.path.join(checkpoint_dir, "checkpoint")
|
||||
model_state, optimizer_state = torch.load(checkpoint)
|
||||
net.load_state_dict(model_state)
|
||||
optimizer.load_state_dict(optimizer_state)
|
||||
@@ -138,10 +138,10 @@ def train_cifar(config, checkpoint=None, data_dir=None):
|
||||
val_loss += loss.cpu().numpy()
|
||||
val_steps += 1
|
||||
|
||||
checkpoint_dir = tune.make_checkpoint_dir(epoch)
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
torch.save((net.state_dict(), optimizer.state_dict()), path)
|
||||
tune.save_checkpoint(path)
|
||||
with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
torch.save(
|
||||
(net.state_dict(), optimizer.state_dict()), path)
|
||||
|
||||
tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
|
||||
print("Finished Training")
|
||||
@@ -213,7 +213,9 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
|
||||
best_trained_model = nn.DataParallel(best_trained_model)
|
||||
best_trained_model.to(device)
|
||||
|
||||
model_state, optimizer_state = torch.load(best_trial.checkpoint.value)
|
||||
checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint")
|
||||
|
||||
model_state, optimizer_state = torch.load(checkpoint_path)
|
||||
best_trained_model.load_state_dict(model_state)
|
||||
|
||||
test_acc = test_accuracy(best_trained_model, device)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# https://github.com/pytorch/examples/blob/master/mnist/main.py
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
@@ -10,21 +11,21 @@ import ray
|
||||
from ray import tune
|
||||
from ray.tune.examples.mnist_pytorch import (train, test, get_data_loaders,
|
||||
ConvNet)
|
||||
from ray.util.sgd.torch.func_trainable import (DistributedTrainableCreator,
|
||||
distributed_checkpoint)
|
||||
from ray.tune.integration.torch import (DistributedTrainableCreator,
|
||||
distributed_checkpoint_dir)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def train_mnist(config, checkpoint=False):
|
||||
def train_mnist(config, checkpoint_dir=False):
|
||||
use_cuda = torch.cuda.is_available()
|
||||
device = torch.device("cuda" if use_cuda else "cpu")
|
||||
train_loader, test_loader = get_data_loaders()
|
||||
model = ConvNet().to(device)
|
||||
optimizer = optim.SGD(model.parameters(), lr=0.1)
|
||||
|
||||
if checkpoint:
|
||||
with open(checkpoint) as f:
|
||||
if checkpoint_dir:
|
||||
with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
|
||||
model_state, optimizer_state = torch.load(f)
|
||||
|
||||
model.load_state_dict(model_state)
|
||||
@@ -37,7 +38,8 @@ def train_mnist(config, checkpoint=False):
|
||||
acc = test(model, test_loader, device)
|
||||
|
||||
if epoch % 3 == 0:
|
||||
with distributed_checkpoint(label=epoch) as path:
|
||||
with distributed_checkpoint_dir(step=epoch) as checkpoint_dir:
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
torch.save((model.state_dict(), optimizer.state_dict()), path)
|
||||
tune.report(mean_accuracy=acc)
|
||||
|
||||
|
||||
@@ -11,10 +11,10 @@ from ray import tune
|
||||
from ray.tune.schedulers import HyperBandScheduler
|
||||
|
||||
|
||||
def train(config, checkpoint=None):
|
||||
def train(config, checkpoint_dir=None):
|
||||
step = 0
|
||||
if checkpoint:
|
||||
with open(checkpoint) as f:
|
||||
if checkpoint_dir:
|
||||
with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
|
||||
step = json.loads(f.read())["timestep"]
|
||||
|
||||
for timestep in range(step, 100):
|
||||
@@ -22,11 +22,10 @@ def train(config, checkpoint=None):
|
||||
v *= config.get("height", 1)
|
||||
|
||||
if timestep % 3 == 0:
|
||||
checkpoint_dir = tune.make_checkpoint_dir(step=timestep)
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"timestep": timestep}))
|
||||
tune.save_checkpoint(path)
|
||||
with tune.checkpoint_dir(step=timestep) as checkpoint_dir:
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"timestep": timestep}))
|
||||
|
||||
# Here we use `episode_reward_mean`, but you can also report other
|
||||
# objectives such as loss or accuracy.
|
||||
|
||||
@@ -158,16 +158,15 @@ def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
|
||||
# __tune_checkpoint_callback_begin__
|
||||
class CheckpointCallback(Callback):
|
||||
def on_validation_end(self, trainer, pl_module):
|
||||
path = tune.make_checkpoint_dir(trainer.global_step)
|
||||
trainer.save_checkpoint(os.path.join(path, "checkpoint"))
|
||||
tune.save_checkpoint(path)
|
||||
with tune.checkpoint_dir(step=trainer.global_step) as checkpoint_dir:
|
||||
trainer.save_checkpoint(os.path.join(checkpoint_dir, "checkpoint"))
|
||||
# __tune_checkpoint_callback_end__
|
||||
|
||||
|
||||
# __tune_train_checkpoint_begin__
|
||||
def train_mnist_tune_checkpoint(
|
||||
config,
|
||||
checkpoint=None,
|
||||
checkpoint_dir=None,
|
||||
data_dir=None,
|
||||
num_epochs=10,
|
||||
num_gpus=0):
|
||||
@@ -179,13 +178,13 @@ def train_mnist_tune_checkpoint(
|
||||
progress_bar_refresh_rate=0,
|
||||
callbacks=[CheckpointCallback(),
|
||||
TuneReportCallback()])
|
||||
if checkpoint:
|
||||
if checkpoint_dir:
|
||||
# Currently, this leads to errors:
|
||||
# model = LightningMNISTClassifier.load_from_checkpoint(
|
||||
# os.path.join(checkpoint, "checkpoint"))
|
||||
# Workaround:
|
||||
ckpt = pl_load(
|
||||
os.path.join(checkpoint, "checkpoint"),
|
||||
os.path.join(checkpoint_dir, "checkpoint"),
|
||||
map_location=lambda storage, loc: storage)
|
||||
model = LightningMNISTClassifier._load_model_state(ckpt, config=config)
|
||||
trainer.current_epoch = ckpt["epoch"]
|
||||
|
||||
@@ -11,7 +11,7 @@ from ray import tune
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
|
||||
|
||||
def pbt_function(config, checkpoint=None):
|
||||
def pbt_function(config, checkpoint_dir=None):
|
||||
"""Toy PBT problem for benchmarking adaptive learning rate.
|
||||
|
||||
The goal is to optimize this trainable's accuracy. The accuracy increases
|
||||
@@ -35,8 +35,8 @@ def pbt_function(config, checkpoint=None):
|
||||
lr = config["lr"]
|
||||
accuracy = 0.0 # end = 1000
|
||||
start = 0
|
||||
if checkpoint:
|
||||
with open(checkpoint) as f:
|
||||
if checkpoint_dir:
|
||||
with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
|
||||
state = json.loads(f.read())
|
||||
accuracy = state["acc"]
|
||||
start = state["step"]
|
||||
@@ -65,11 +65,10 @@ def pbt_function(config, checkpoint=None):
|
||||
accuracy = max(0, accuracy)
|
||||
|
||||
if step % 3 == 0:
|
||||
checkpoint_dir = tune.make_checkpoint_dir(step=step)
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"acc": accuracy, "step": start}))
|
||||
tune.save_checkpoint(path)
|
||||
with tune.checkpoint_dir(step=step) as checkpoint_dir:
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"acc": accuracy, "step": start}))
|
||||
|
||||
tune.report(
|
||||
mean_accuracy=accuracy,
|
||||
|
||||
@@ -350,35 +350,47 @@ class FunctionRunner(Trainable):
|
||||
pass
|
||||
|
||||
|
||||
def detect_checkpoint_function(train_func):
|
||||
func_args = inspect.getfullargspec(train_func).args
|
||||
use_checkpoint = "checkpoint" in func_args
|
||||
return use_checkpoint
|
||||
def detect_checkpoint_function(train_func, abort=False):
    """Report whether ``train_func`` has a checkpoint-enabled signature.

    A signature is accepted when either:

    * it has exactly two positional arguments and at least one of their
      names contains ``checkpoint_dir``; or
    * it has exactly one positional argument and at least one keyword-only
      argument whose name contains ``checkpoint_dir``.

    Args:
        train_func (callable): The training function to inspect.
        abort (bool): If True, raise instead of returning False when the
            signature is not accepted.

    Returns:
        bool: True if the signature is accepted, False otherwise.

    Raises:
        ValueError: If ``abort`` is True and the signature is not accepted.
    """
    spec = inspect.getfullargspec(train_func)
    positional = spec.args
    keyword_only = spec.kwonlyargs

    def _mentions_checkpoint_dir(names):
        # Substring match on the argument names, not an exact comparison.
        for name in names:
            if "checkpoint_dir" in name:
                return True
        return False

    if len(positional) == 2:
        validated = _mentions_checkpoint_dir(positional)
    elif len(positional) == 1:
        validated = _mentions_checkpoint_dir(keyword_only)
    else:
        validated = False

    if validated or not abort:
        return validated

    raise ValueError(
        "Provided training function must have 2 args "
        "in the signature, and the latter arg must "
        "contain `checkpoint_dir`. For example: "
        "`func(config, checkpoint_dir=None)`. Got {}".format(positional))
|
||||
|
||||
|
||||
def wrap_function(train_func):
|
||||
class ImplicitFunc(FunctionRunner):
|
||||
def _trainable_func(self, config, reporter, checkpoint):
|
||||
def _trainable_func(self, config, reporter, checkpoint_dir):
|
||||
func_args = inspect.getfullargspec(train_func).args
|
||||
if len(func_args) > 1: # more arguments than just the config
|
||||
if "reporter" not in func_args and (
|
||||
"checkpoint" not in func_args):
|
||||
not detect_checkpoint_function(train_func)):
|
||||
raise ValueError(
|
||||
"Unknown argument found in the Trainable function. "
|
||||
"Arguments other than the 'config' arg must be one "
|
||||
"of ['reporter', 'checkpoint']. Found: {}".format(
|
||||
"of ['reporter', 'checkpoint_dir']. Found: {}".format(
|
||||
func_args))
|
||||
use_reporter = "reporter" in func_args
|
||||
use_checkpoint = "checkpoint" in func_args
|
||||
use_checkpoint = detect_checkpoint_function(train_func)
|
||||
if not use_checkpoint and not use_reporter:
|
||||
logger.warning(
|
||||
"Function checkpointing is disabled. This may result in "
|
||||
"unexpected behavior when using checkpointing features or "
|
||||
"certain schedulers. To enable, set the train function "
|
||||
"arguments to be `func(config, checkpoint)`.")
|
||||
"arguments to be `func(config, checkpoint_dir=None)`.")
|
||||
output = train_func(config)
|
||||
elif use_checkpoint:
|
||||
output = train_func(config, checkpoint=checkpoint)
|
||||
output = train_func(config, checkpoint_dir=checkpoint_dir)
|
||||
else:
|
||||
output = train_func(config, reporter)
|
||||
|
||||
|
||||
@@ -0,0 +1,290 @@
|
||||
# Original Code here:
|
||||
# https://github.com/pytorch/examples/blob/master/mnist/main.py
|
||||
from contextlib import contextmanager
|
||||
import os
|
||||
import logging
|
||||
import shutil
|
||||
import tempfile
|
||||
import torch
|
||||
from datetime import timedelta
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune.result import RESULT_DUPLICATE
|
||||
from ray.tune.logger import NoopLogger
|
||||
from ray.tune.function_runner import (wrap_function,
|
||||
detect_checkpoint_function)
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.trainable import TrainableUtil
|
||||
from ray.util.sgd.torch.utils import setup_process_group, setup_address
|
||||
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Module-level flag; only ``enable_distributed_trainable`` flips it to True.
_distributed_enabled = False


def is_distributed_trainable():
    """Return the module-level "distributed enabled" flag.

    The flag starts out False and becomes True once
    ``enable_distributed_trainable`` has been called in this process.
    """
    enabled = _distributed_enabled
    return enabled


def enable_distributed_trainable():
    """Mark this process as running inside a distributed trainable."""
    global _distributed_enabled
    _distributed_enabled = True
|
||||
|
||||
|
||||
def logger_creator(log_config, logdir, rank):
    """Create a ``NoopLogger`` rooted in a per-worker directory.

    Args:
        log_config: Configuration forwarded unchanged to ``NoopLogger``.
        logdir (str): Parent directory for the worker directories.
        rank: Worker identifier; used to name the ``worker_<rank>``
            subdirectory.

    Returns:
        A ``NoopLogger`` built with ``log_config`` and the worker directory.
    """
    subdir_name = "worker_{}".format(rank)
    worker_dir = os.path.join(logdir, subdir_name)
    # The directory may already exist; that is fine.
    os.makedirs(worker_dir, exist_ok=True)
    return NoopLogger(log_config, worker_dir)
|
||||
|
||||
|
||||
class _TorchTrainable(tune.Trainable):
    """Base class for distributed training on Tune.

    A wrapper class is needed to actually create a working
    version of this trainable: the class attributes below are
    placeholders that a subclass is expected to override.
    """
    # Training function that is wrapped and run on every worker.
    _function = None
    # Number of remote worker actors to launch.
    _num_workers = None
    # Whether each worker requests one GPU.
    _use_gpu = None
    # CPUs requested per worker (falls back to 1 when unset).
    _num_cpus_per_worker = None

    __slots__ = ["workers", "_finished"]

    @classmethod
    def default_process_group_parameters(cls):
        """Return the keyword arguments used to set up the process group.

        Returns:
            dict: ``timeout`` (a ``timedelta``) and ``backend``.
        """
        # NCCL_TIMEOUT_S is expressed in seconds; timedelta's first
        # positional argument is *days*, so the unit must be named.
        return dict(
            timeout=timedelta(seconds=NCCL_TIMEOUT_S), backend="gloo")

    @classmethod
    def get_remote_worker_options(cls):
        """Return the resource options applied to each remote worker.

        Returns:
            dict: ``num_cpus`` and ``num_gpus`` for one worker.
        """
        num_gpus = 1 if cls._use_gpu else 0
        num_cpus = int(cls._num_cpus_per_worker or 1)
        return dict(num_cpus=num_cpus, num_gpus=num_gpus)

    def setup(self, config):
        """Launch the workers and initialize their process group.

        Args:
            config (dict): Configuration passed to every worker.
        """
        from functools import partial

        self._finished = False
        num_workers = self._num_workers
        logdir = self.logdir
        assert self._function

        func_trainable = wrap_function(self.__class__._function)

        remote_trainable = ray.remote(func_trainable)
        remote_trainable = remote_trainable.options(
            **self.get_remote_worker_options())

        address = setup_address()
        # ``rank`` is bound eagerly (via partial / default argument) so that
        # each callable carries its own rank regardless of when it is
        # serialized or invoked.
        self.workers = [
            remote_trainable.remote(
                config=config,
                logger_creator=partial(
                    logger_creator, logdir=logdir, rank=rank))
            for rank in range(num_workers)
        ]

        pgroup_params = self.default_process_group_parameters()
        setup_on_worker = partial(
            setup_process_group,
            url=address,
            world_size=num_workers,
            **pgroup_params)
        # Block until every worker has joined the process group.
        ray.get([
            w.execute.remote(
                lambda _, rank=rank: setup_on_worker(world_rank=rank))
            for rank, w in enumerate(self.workers)
        ])

        ray.get([
            w.execute.remote(lambda _: enable_distributed_trainable())
            for w in self.workers
        ])

    def step(self):
        """Run one step on every worker and return the first worker's result.

        Returns:
            The result from ``self.workers[0]``.

        Raises:
            RuntimeError: If a previous step already reported
                ``RESULT_DUPLICATE``.
        """
        if self._finished:
            raise RuntimeError("Training has already finished.")
        result = ray.get([w.step.remote() for w in self.workers])[0]
        if RESULT_DUPLICATE in result:
            self._finished = True
        return result

    def save_checkpoint(self, checkpoint_dir):
        """Materialize the first worker's checkpoint in ``checkpoint_dir``.

        Args:
            checkpoint_dir (str): Directory handed to
                ``TrainableUtil.create_from_pickle``.

        Returns:
            The value returned by ``TrainableUtil.create_from_pickle``.
        """
        # TODO: optimize if colocated
        save_obj = ray.get(self.workers[0].save_to_object.remote())
        checkpoint_path = TrainableUtil.create_from_pickle(
            save_obj, checkpoint_dir)
        return checkpoint_path

    def load_checkpoint(self, checkpoint_dir):
        """Restore every worker from the checkpoint at ``checkpoint_dir``.

        Args:
            checkpoint_dir (str): Checkpoint location handed to
                ``TrainableUtil.checkpoint_to_object``.

        Returns:
            list: One restore result per worker.
        """
        checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
        # ray.get expects an object ref or a list of them, not a generator.
        return ray.get([
            w.restore_from_object.remote(checkpoint_obj)
            for w in self.workers
        ])

    def stop(self):
        """Stop all workers and wait for them to finish stopping."""
        ray.get([worker.stop.remote() for worker in self.workers])
|
||||
|
||||
|
||||
def DistributedTrainableCreator(func,
                                use_gpu=False,
                                num_workers=1,
                                num_cpus_per_worker=1,
                                backend="gloo",
                                timeout_s=NCCL_TIMEOUT_S):
    """Creates a class that executes distributed training.

    Similar to running `torch.distributed.launch`.

    Note that you typically should not instantiate the object
    created.

    Args:
        func (callable): This function is a Tune trainable function.
            This function must have 2 args in the signature, and the
            latter arg must contain `checkpoint_dir`. For example:
            `func(config, checkpoint_dir=None)`.
        use_gpu (bool): Sets resource allocation for workers to 1 GPU
            if true. Also automatically sets CUDA_VISIBLE_DEVICES
            for each training worker.
        num_workers (int): Number of training workers to include in
            world.
        num_cpus_per_worker (int): Number of CPU resources to reserve
            per training worker.
        backend (str): One of "gloo", "nccl".
        timeout_s (float): Seconds before the torch process group
            times out. Useful when machines are unreliable. Defaults
            to ``NCCL_TIMEOUT_S``.

    Returns:
        A trainable class object that can be passed to Tune. Resources
            are automatically set within the object, so users do
            not need to set `resources_per_trainable`.

    Raises:
        ValueError: If ``func`` does not have a checkpoint-enabled
            signature (raised by ``detect_checkpoint_function``).

    Example:

    .. code-block::

        trainable_cls = DistributedTrainableCreator(
            train_func, num_workers=2)
        analysis = tune.run(trainable_cls)
    """
    # Fail fast on an unsupported signature instead of at training time.
    detect_checkpoint_function(func, abort=True)

    class WrappedDistributedTorchTrainable(_TorchTrainable):
        _function = func
        _num_workers = num_workers
        _use_gpu = use_gpu
        _num_cpus_per_worker = num_cpus_per_worker

        @classmethod
        def default_process_group_parameters(cls):
            # ``timeout_s`` is in seconds; timedelta's first positional
            # argument is days, so the unit must be passed by keyword.
            return dict(
                timeout=timedelta(seconds=timeout_s), backend=backend)

        @classmethod
        def default_resource_request(cls, config):
            # Values in ``config`` take precedence over the creator's
            # arguments when computing the resource request.
            num_workers_ = int(config.get("num_workers", num_workers))
            num_cpus = int(
                config.get("num_cpus_per_worker", num_cpus_per_worker))
            use_gpu_ = config.get("use_gpu", use_gpu)

            # The trainable itself requests nothing; everything is
            # requested as "extra" resources for the workers.
            return Resources(
                cpu=0,
                gpu=0,
                extra_cpu=num_cpus * num_workers_,
                extra_gpu=num_workers_ if use_gpu_ else 0)

    return WrappedDistributedTorchTrainable
|
||||
|
||||
|
||||
@contextmanager
def distributed_checkpoint_dir(step, disable=False):
    """ContextManager for creating a distributed checkpoint.

    Only checkpoints a file on the "main" training actor, avoiding
    redundant work. On every other rank, and whenever ``disable`` is
    set, a throwaway temporary directory is yielded instead and removed
    when the context exits (also when the body raises).

    Args:
        step (int): Used to label the checkpoint
        disable (bool): Disable for prototyping. When True, the torch
            process group is not queried at all.

    Yields:
        path (str): A path to a directory. This path will be used
            again when invoking the training_function.

    Example:

    .. code-block::

        def train_func(config, checkpoint_dir):
            if checkpoint_dir:
                path = os.path.join(checkpoint_dir, "checkpoint")
                model_state_dict = torch.load(path)

            if epoch % 3 == 0:
                with distributed_checkpoint_dir(step=epoch) as checkpoint_dir:
                    path = os.path.join(checkpoint_dir, "checkpoint")
                    torch.save(model.state_dict(), path)
    """

    # Test ``disable`` first so prototyping works without an initialized
    # process group.
    if not disable and torch.distributed.get_rank() == 0:
        with tune.checkpoint_dir(step=step) as checkpoint_dir:
            yield checkpoint_dir
    else:
        path = tempfile.mkdtemp()
        try:
            yield path
        finally:
            # Always clean up, even if the caller's block raised.
            shutil.rmtree(path)
|
||||
|
||||
|
||||
def _train_check_global(config, checkpoint_dir=None):
    """Test helper: assert the distributed flag is set, then report once.

    For testing only. Lives in this module because Ray has problems
    serializing functions defined within the test file.

    Args:
        config: Unused.
        checkpoint_dir: Unused.
    """
    import time

    assert is_distributed_trainable()
    pause_s = 0.1
    time.sleep(pause_s)
    tune.report(is_distributed=True)
|
||||
|
||||
|
||||
def _train_simple(config, checkpoint_dir=None):
    """Test helper: train a tiny DDP model and checkpoint periodically.

    For testing only. Putting this here because Ray has problems
    serializing within the test file.

    Args:
        config (dict): Reads ``epochs`` (default 10) and
            ``enable_checkpoint`` (default True).
        checkpoint_dir (str): If set, model and optimizer state are
            restored from the ``checkpoint`` file in this directory.
    """
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel
    import torch.optim as optim
    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 8, 5, 5, 5

    # Create random Tensors to hold inputs and outputs
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)
    loss_fn = nn.MSELoss()

    # Use the nn package to define our model and loss function.
    model = torch.nn.Sequential(
        torch.nn.Linear(D_in, H),
        torch.nn.ReLU(),
        torch.nn.Linear(H, D_out),
    )
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    if checkpoint_dir:
        # torch.load needs a path or a binary file object; a text-mode
        # handle cannot be unpickled from.
        restore_path = os.path.join(checkpoint_dir, "checkpoint")
        model_state, optimizer_state = torch.load(restore_path)

        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    model = DistributedDataParallel(model)

    for epoch in range(config.get("epochs", 10)):
        optimizer.zero_grad()
        output = model(x)
        loss = loss_fn(output, y)
        loss.backward()
        optimizer.step()

        if epoch % 3 == 0:
            if config.get("enable_checkpoint", True):
                with distributed_checkpoint_dir(step=epoch) as ckpt_dir:
                    path = os.path.join(ckpt_dir, "checkpoint")
                    # Save the wrapped module's state so the keys match
                    # the unwrapped model that is restored above (the DDP
                    # wrapper prefixes every key with "module.").
                    torch.save(
                        (model.module.state_dict(), optimizer.state_dict()),
                        path)
        tune.report(mean_loss=loss.item())
|
||||
+40
-51
@@ -1,3 +1,4 @@
|
||||
from contextlib import contextmanager
|
||||
import os
|
||||
import logging
|
||||
|
||||
@@ -75,48 +76,34 @@ def report(**kwargs):
|
||||
def make_checkpoint_dir(step=None):
|
||||
"""Gets the next checkpoint dir.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import time
|
||||
from ray import tune
|
||||
|
||||
def func(config, checkpoint=None):
|
||||
start = 0
|
||||
if checkpoint:
|
||||
with open(checkpoint) as f:
|
||||
state = json.loads(f.read())
|
||||
start = state["step"] + 1
|
||||
|
||||
for iter in range(start, 100):
|
||||
time.sleep(1)
|
||||
|
||||
checkpoint_dir = tune.make_checkpoint_dir(step=step)
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"step": start}))
|
||||
tune.save_checkpoint(path)
|
||||
|
||||
tune.report(hello="world", ray="tune")
|
||||
|
||||
.. warning:: Do not call this function within the Trainable Class API.
|
||||
|
||||
Args:
|
||||
step (int): Current training iteration - used for setting
|
||||
an index to uniquely identify the checkpoint.
|
||||
|
||||
.. versionadded:: 0.8.6
|
||||
|
||||
.. deprecated:: 0.8.7
|
||||
Use tune.checkpoint_dir instead.
|
||||
"""
|
||||
_session = get_session()
|
||||
if _session:
|
||||
return _session.make_checkpoint_dir(step=step)
|
||||
else:
|
||||
return os.path.abspath("./")
|
||||
raise DeprecationWarning(
|
||||
"Deprecated method. Use `tune.checkpoint_dir` instead.")
|
||||
|
||||
|
||||
def save_checkpoint(checkpoint):
|
||||
"""Register the given checkpoint.
|
||||
|
||||
.. versionadded:: 0.8.6
|
||||
|
||||
.. deprecated:: 0.8.7
|
||||
Use tune.checkpoint_dir instead.
|
||||
"""
|
||||
raise DeprecationWarning(
|
||||
"Deprecated method. Use `tune.checkpoint_dir` instead.")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def checkpoint_dir(step=None):
|
||||
"""Returns a checkpoint dir inside a context.
|
||||
|
||||
Store any files related to restoring state within the
|
||||
provided checkpoint dir.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import os
|
||||
@@ -124,10 +111,10 @@ def save_checkpoint(checkpoint):
|
||||
import time
|
||||
from ray import tune
|
||||
|
||||
def func(config, checkpoint=None):
|
||||
def func(config, checkpoint_dir=None):
|
||||
start = 0
|
||||
if checkpoint:
|
||||
with open(checkpoint) as f:
|
||||
if checkpoint_dir:
|
||||
with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
|
||||
state = json.loads(f.read())
|
||||
accuracy = state["acc"]
|
||||
start = state["step"] + 1
|
||||
@@ -135,27 +122,29 @@ def save_checkpoint(checkpoint):
|
||||
for iter in range(start, 10):
|
||||
time.sleep(1)
|
||||
|
||||
checkpoint_dir = tune.make_checkpoint_dir(step=iter)
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"step": start}))
|
||||
tune.save_checkpoint(path)
|
||||
with tune.checkpoint_dir(step=iter) as checkpoint_dir:
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"step": start}))
|
||||
|
||||
tune.report(hello="world", ray="tune")
|
||||
|
||||
analysis = tune.run(func)
|
||||
Yields:
|
||||
checkpoint_dir (str): Directory for checkpointing.
|
||||
|
||||
.. warning:: Do not call this function within the Trainable Class API.
|
||||
|
||||
Args:
|
||||
step (int): Current training iteration - used for setting
|
||||
an index to uniquely identify the checkpoint.
|
||||
|
||||
.. versionadded:: 0.8.6
|
||||
.. versionadded:: 0.8.7
|
||||
"""
|
||||
_session = get_session()
|
||||
|
||||
if _session:
|
||||
return _session.save_checkpoint(checkpoint)
|
||||
_checkpoint_dir = _session.make_checkpoint_dir(step=step)
|
||||
else:
|
||||
_checkpoint_dir = os.path.abspath("./")
|
||||
|
||||
yield _checkpoint_dir
|
||||
|
||||
if _session:
|
||||
_session.save_checkpoint(_checkpoint_dir)
|
||||
|
||||
|
||||
def get_trial_dir():
|
||||
|
||||
@@ -19,7 +19,7 @@ class FunctionApiTest(unittest.TestCase):
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
def testFunctionNoCheckpointing(self):
|
||||
def train(config, checkpoint=None):
|
||||
def train(config, checkpoint_dir=None):
|
||||
for i in range(10):
|
||||
tune.report(test=i)
|
||||
|
||||
@@ -40,14 +40,13 @@ class FunctionApiTest(unittest.TestCase):
|
||||
def testFunctionRecurringSave(self):
|
||||
"""This tests that save and restore are commutative."""
|
||||
|
||||
def train(config, checkpoint=None):
|
||||
def train(config, checkpoint_dir=None):
|
||||
for step in range(10):
|
||||
if step % 3 == 0:
|
||||
checkpoint_dir = tune.make_checkpoint_dir(step=step)
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"step": step}))
|
||||
tune.save_checkpoint(path)
|
||||
with tune.checkpoint_dir(step=step) as checkpoint_dir:
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"step": step}))
|
||||
tune.report(test=step)
|
||||
|
||||
wrapped = wrap_function(train)
|
||||
@@ -65,49 +64,58 @@ class FunctionApiTest(unittest.TestCase):
|
||||
new_trainable2.stop()
|
||||
|
||||
def testCheckpointFunctionAtEnd(self):
|
||||
def train(config, checkpoint=False):
|
||||
def train(config, checkpoint_dir=False):
|
||||
for i in range(10):
|
||||
tune.report(test=i)
|
||||
checkpoint_dir = tune.make_checkpoint_dir(step=10)
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "hello")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write("hello")
|
||||
tune.save_checkpoint(checkpoint_path)
|
||||
|
||||
[trial] = tune.run(train).trials
|
||||
assert "hello" in trial.checkpoint.value
|
||||
|
||||
def testVariousCheckpointFunctionAtEnd(self):
|
||||
def train(config, checkpoint=False):
|
||||
for i in range(10):
|
||||
checkpoint_dir = tune.make_checkpoint_dir()
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "hello")
|
||||
with tune.checkpoint_dir(step=10) as checkpoint_dir:
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write("hello")
|
||||
tune.save_checkpoint(checkpoint_path)
|
||||
|
||||
[trial] = tune.run(train).trials
|
||||
assert os.path.exists(os.path.join(trial.checkpoint.value, "ckpt.log"))
|
||||
|
||||
def testCheckpointFunctionAtEndContext(self):
|
||||
def train(config, checkpoint_dir=False):
|
||||
for i in range(10):
|
||||
tune.report(test=i)
|
||||
checkpoint_dir = tune.make_checkpoint_dir()
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write("goodbye")
|
||||
tune.save_checkpoint(checkpoint_path)
|
||||
with tune.checkpoint_dir(step=10) as checkpoint_dir:
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write("hello")
|
||||
|
||||
[trial] = tune.run(train).trials
|
||||
assert os.path.exists(os.path.join(trial.checkpoint.value, "ckpt.log"))
|
||||
|
||||
def testVariousCheckpointFunctionAtEnd(self):
|
||||
def train(config, checkpoint_dir=False):
|
||||
for i in range(10):
|
||||
with tune.checkpoint_dir() as checkpoint_dir:
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write("hello")
|
||||
tune.report(test=i)
|
||||
with tune.checkpoint_dir() as checkpoint_dir:
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log2")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write("goodbye")
|
||||
|
||||
[trial] = tune.run(train, keep_checkpoints_num=3).trials
|
||||
assert "goodbye" in trial.checkpoint.value
|
||||
assert os.path.exists(
|
||||
os.path.join(trial.checkpoint.value, "ckpt.log2"))
|
||||
|
||||
def testReuseCheckpoint(self):
|
||||
def train(config, checkpoint=False):
|
||||
def train(config, checkpoint_dir=None):
|
||||
itr = 0
|
||||
if checkpoint:
|
||||
with open(checkpoint, "r") as f:
|
||||
if checkpoint_dir:
|
||||
with open(os.path.join(checkpoint_dir, "ckpt.log"), "r") as f:
|
||||
itr = int(f.read()) + 1
|
||||
|
||||
for i in range(itr, config["max_iter"]):
|
||||
checkpoint_dir = tune.make_checkpoint_dir(step=i)
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write(str(i))
|
||||
tune.save_checkpoint(checkpoint_path)
|
||||
with tune.checkpoint_dir(step=i) as checkpoint_dir:
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write(str(i))
|
||||
tune.report(test=i, training_iteration=i)
|
||||
|
||||
[trial] = tune.run(
|
||||
@@ -117,51 +125,49 @@ class FunctionApiTest(unittest.TestCase):
|
||||
},
|
||||
).trials
|
||||
last_ckpt = trial.checkpoint.value
|
||||
assert "goodbye" in last_ckpt
|
||||
assert os.path.exists(os.path.join(trial.checkpoint.value, "ckpt.log"))
|
||||
analysis = tune.run(train, config={"max_iter": 10}, restore=last_ckpt)
|
||||
trial_dfs = list(analysis.trial_dataframes.values())
|
||||
assert len(trial_dfs[0]["training_iteration"]) == 5
|
||||
|
||||
def testRetry(self):
|
||||
def train(config, checkpoint=None):
|
||||
restored = bool(checkpoint)
|
||||
def train(config, checkpoint_dir=None):
|
||||
restored = bool(checkpoint_dir)
|
||||
itr = 0
|
||||
if checkpoint:
|
||||
with open(checkpoint, "r") as f:
|
||||
if checkpoint_dir:
|
||||
with open(os.path.join(checkpoint_dir, "ckpt.log"), "r") as f:
|
||||
itr = int(f.read()) + 1
|
||||
|
||||
for i in range(itr, 10):
|
||||
if i == 5 and not restored:
|
||||
raise Exception("try to fail me")
|
||||
checkpoint_dir = tune.make_checkpoint_dir(step=i)
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write(str(i))
|
||||
tune.save_checkpoint(checkpoint_path)
|
||||
with tune.checkpoint_dir(step=i) as checkpoint_dir:
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write(str(i))
|
||||
tune.report(test=i, training_iteration=i)
|
||||
|
||||
analysis = tune.run(train, max_failures=3)
|
||||
last_ckpt = analysis.trials[0].checkpoint.value
|
||||
assert "goodbye" in last_ckpt
|
||||
assert os.path.exists(os.path.join(last_ckpt, "ckpt.log"))
|
||||
trial_dfs = list(analysis.trial_dataframes.values())
|
||||
assert len(trial_dfs[0]["training_iteration"]) == 10
|
||||
|
||||
def testBlankCheckpoint(self):
|
||||
def train(config, checkpoint=None):
|
||||
restored = bool(checkpoint)
|
||||
def train(config, checkpoint_dir=None):
|
||||
restored = bool(checkpoint_dir)
|
||||
itr = 0
|
||||
if checkpoint:
|
||||
with open(checkpoint, "r") as f:
|
||||
if checkpoint_dir:
|
||||
with open(os.path.join(checkpoint_dir, "ckpt.log"), "r") as f:
|
||||
itr = int(f.read()) + 1
|
||||
|
||||
for i in range(itr, 10):
|
||||
if i == 5 and not restored:
|
||||
raise Exception("try to fail me")
|
||||
checkpoint_dir = tune.make_checkpoint_dir()
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "goodbye")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write(str(i))
|
||||
tune.save_checkpoint(checkpoint_path)
|
||||
with tune.checkpoint_dir() as checkpoint_dir:
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "ckpt.log")
|
||||
with open(checkpoint_path, "w") as f:
|
||||
f.write(str(i))
|
||||
tune.report(test=i, training_iteration=i)
|
||||
|
||||
analysis = tune.run(train, max_failures=3)
|
||||
|
||||
+25
-7
@@ -6,8 +6,9 @@ import torch.distributed as dist
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.util.sgd.torch.func_trainable import (
|
||||
DistributedTrainableCreator, distributed_checkpoint, _train_simple)
|
||||
from ray.tune.integration.torch import (DistributedTrainableCreator,
|
||||
distributed_checkpoint_dir,
|
||||
_train_simple, _train_check_global)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -47,12 +48,29 @@ def test_step_after_completion(ray_start_2_cpus): # noqa: F811
|
||||
trainer.train()
|
||||
|
||||
|
||||
def test_validation(ray_start_2_cpus):  # noqa: F811
    """Creating a trainable from an unsupported function raises ValueError."""

    def three_positional_args(a, b, c):
        # The body is irrelevant; only the (a, b, c) signature is under test.
        return 1

    with pytest.raises(ValueError):
        DistributedTrainableCreator(three_positional_args, num_workers=2)
|
||||
|
||||
|
||||
def test_set_global(ray_start_2_cpus):  # noqa: F811
    """Run one training step on two workers and check the reported result.

    The result returned by ``train()`` must carry a truthy
    ``"is_distributed"`` entry.
    """
    trainable_cls = DistributedTrainableCreator(
        _train_check_global, num_workers=2)
    trainable = trainable_cls()
    try:
        result = trainable.train()
        assert result["is_distributed"]
    finally:
        # Always shut the trainable down, even when the assertion fails,
        # so it does not outlive this test.
        trainable.stop()
|
||||
|
||||
|
||||
def test_save_checkpoint(ray_start_2_cpus): # noqa: F811
|
||||
trainable_cls = DistributedTrainableCreator(_train_simple, num_workers=2)
|
||||
trainer = trainable_cls(config={"epochs": 1})
|
||||
trainer.train()
|
||||
path = trainer.save()
|
||||
model_state_dict, opt_state_dict = torch.load(path)
|
||||
checkpoint_dir = trainer.save()
|
||||
model_state_dict, opt_state_dict = torch.load(
|
||||
os.path.join(checkpoint_dir, "checkpoint"))
|
||||
trainer.stop()
|
||||
|
||||
|
||||
@@ -72,11 +90,11 @@ def test_simple_tune(ray_start_4_cpus, enabled_checkpoint):
|
||||
def test_checkpoint(ray_start_2_cpus, rank): # noqa: F811
|
||||
with patch("torch.distributed.get_rank") as rank_method:
|
||||
rank_method.return_value = rank
|
||||
with distributed_checkpoint(label="test") as path:
|
||||
with distributed_checkpoint_dir(step="test") as path:
|
||||
if rank == 0:
|
||||
assert path
|
||||
else:
|
||||
assert path == os.devnull
|
||||
if rank != 0:
|
||||
assert not os.path.exists(path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -26,14 +26,6 @@ py_test(
|
||||
deps = [":sgd_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_torch_trainable",
|
||||
size = "small",
|
||||
srcs = ["tests/test_torch_trainable.py"],
|
||||
tags = ["exclusive", "pytorch"],
|
||||
deps = [":sgd_lib"],
|
||||
)
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# Tests from the python/ray/util/sgd/tf/examples directory.
|
||||
# Please keep these sorted alphabetically.
|
||||
|
||||
@@ -12,13 +12,8 @@ try:
|
||||
BaseTorchTrainable)
|
||||
|
||||
from ray.util.sgd.torch.training_operator import TrainingOperator
|
||||
from ray.util.sgd.torch.func_trainable import (DistributedTrainableCreator,
|
||||
distributed_checkpoint)
|
||||
|
||||
__all__ = [
|
||||
"TorchTrainer", "BaseTorchTrainable", "TrainingOperator",
|
||||
"distributed_checkpoint", "DistributedTrainableCreator"
|
||||
]
|
||||
__all__ = ["TorchTrainer", "BaseTorchTrainable", "TrainingOperator"]
|
||||
except ImportError as e:
|
||||
logger.warning(e)
|
||||
logger.warning("PyTorch not found. TorchTrainer will not be available")
|
||||
|
||||
@@ -1,235 +0,0 @@
|
||||
# Original Code here:
|
||||
# https://github.com/pytorch/examples/blob/master/mnist/main.py
|
||||
import os
|
||||
import logging
|
||||
import torch
|
||||
from datetime import timedelta
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune.result import RESULT_DUPLICATE
|
||||
from ray.tune.logger import NoopLogger
|
||||
from ray.tune.function_runner import wrap_function
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.trainable import TrainableUtil
|
||||
from ray.util.sgd.torch.utils import setup_process_group
|
||||
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_S
|
||||
from ray.util.sgd.torch.utils import setup_address
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def logger_creator(log_config, logdir, rank):
    """Return a ``NoopLogger`` that lives in a per-worker directory.

    Args:
        log_config: Configuration handed straight to ``NoopLogger``.
        logdir (str): Parent directory for all worker directories.
        rank: Worker identifier; used to name the ``worker_<rank>``
            subdirectory.
    """
    rank_subdir = "worker_{}".format(rank)
    target_dir = os.path.join(logdir, rank_subdir)
    # exist_ok: the directory may already be there from an earlier run.
    os.makedirs(target_dir, exist_ok=True)
    return NoopLogger(log_config, target_dir)
|
||||
|
||||
|
||||
class _TorchTrainable(tune.Trainable):
    """Base class for distributed training on Tune.

    A wrapper class is needed to actually create a working
    version of this trainable: the class attributes below default to
    ``None`` and are read by ``setup`` and ``get_remote_worker_options``.
    """
    # Training function that is wrapped and run on every worker.
    _function = None
    # Number of remote worker actors started in ``setup``.
    _num_workers = None
    # When truthy, each worker requests one GPU.
    _use_gpu = None
    # CPUs requested per worker; a falsy value falls back to 1.
    _num_cpus_per_worker = None

    __slots__ = ["workers", "_finished"]

    @classmethod
    def default_process_group_parameters(cls):
        """Return the keyword arguments passed to ``setup_process_group``."""
        # timedelta's first positional argument is *days*; the constant is a
        # number of seconds (per its ``_S`` suffix), so name the unit.
        return dict(
            timeout=timedelta(seconds=NCCL_TIMEOUT_S), backend="gloo")

    @classmethod
    def get_remote_worker_options(cls):
        """Return the ``num_cpus``/``num_gpus`` options for each worker."""
        num_gpus = 1 if cls._use_gpu else 0
        num_cpus = int(cls._num_cpus_per_worker or 1)
        return dict(num_cpus=num_cpus, num_gpus=num_gpus)

    def setup(self, config):
        """Start the remote workers and set up their process group.

        Args:
            config (dict): Configuration forwarded to every worker.
        """
        self._finished = False
        num_workers = self._num_workers
        logdir = self.logdir
        assert self._function

        func_trainable = wrap_function(self.__class__._function)

        remote_trainable = ray.remote(func_trainable)
        remote_trainable = remote_trainable.options(
            **self.get_remote_worker_options())

        address = setup_address()
        # ``rank=rank`` binds the current value as a default argument, so
        # each callable keeps its own rank regardless of when it is invoked
        # or serialized.
        self.workers = [
            remote_trainable.remote(
                config=config,
                logger_creator=lambda cfg, rank=rank: logger_creator(
                    cfg, logdir, rank))
            for rank in range(num_workers)
        ]

        pgroup_params = self.default_process_group_parameters()
        from functools import partial
        setup_on_worker = partial(
            setup_process_group,
            url=address,
            world_size=num_workers,
            **pgroup_params)
        # Block until every worker has finished the process-group setup call.
        ray.get([
            w.execute.remote(
                lambda _, rank=rank: setup_on_worker(world_rank=rank))
            for rank, w in enumerate(self.workers)
        ])

    def step(self):
        """Run one step on all workers and return the first worker's result.

        Raises:
            RuntimeError: If a previous result already contained
                ``RESULT_DUPLICATE``.
        """
        if self._finished:
            raise RuntimeError("Training has already finished.")
        result = ray.get([w.step.remote() for w in self.workers])[0]
        if RESULT_DUPLICATE in result:
            self._finished = True
        return result

    def save_checkpoint(self, checkpoint_dir):
        """Write the first worker's saved state into ``checkpoint_dir``.

        Returns:
            The value returned by ``TrainableUtil.create_from_pickle``.
        """
        # TODO: optimize if colocated
        save_obj = ray.get(self.workers[0].save_to_object.remote())
        checkpoint_path = TrainableUtil.create_from_pickle(
            save_obj, checkpoint_dir)
        return checkpoint_path

    def load_checkpoint(self, checkpoint_dir):
        """Restore every worker from the checkpoint at ``checkpoint_dir``."""
        checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
        # ray.get expects an object ref or a list of them, not a generator.
        return ray.get([
            w.restore_from_object.remote(checkpoint_obj)
            for w in self.workers
        ])

    def stop(self):
        """Stop all workers and wait for them to finish stopping."""
        ray.get([worker.stop.remote() for worker in self.workers])
|
||||
|
||||
|
||||
def DistributedTrainableCreator(func,
                                use_gpu=False,
                                num_workers=1,
                                num_cpus_per_worker=1,
                                backend="gloo",
                                timeout_s=NCCL_TIMEOUT_S):
    """Creates a class that executes distributed training.

    Note that you typically should not instantiate the object
    created.

    Args:
        func (callable): Training function stored on the class as
            ``_function``.
        use_gpu (bool): Whether each worker is counted as needing one GPU.
        num_workers (int): Number of workers.
        num_cpus_per_worker (int): CPUs counted per worker.
        backend (str): Backend name placed in the process group
            parameters.
        timeout_s (float): Process group timeout, in seconds.

    Returns:
        A ``_TorchTrainable`` subclass configured with the arguments above.

    Example:

    .. code-block::

        trainable_cls = DistributedTrainableCreator(
            train_func, num_workers=2)
        analysis = tune.run(trainable_cls)
    """

    class WrappedDistributedTorchTrainable(_TorchTrainable):
        _function = func
        _num_workers = num_workers
        _use_gpu = use_gpu
        _num_cpus_per_worker = num_cpus_per_worker

        @classmethod
        def default_process_group_parameters(cls):
            # timedelta's first positional argument is days, so the unit
            # has to be spelled out for a value given in seconds.
            return dict(
                timeout=timedelta(seconds=timeout_s), backend=backend)

        @classmethod
        def default_resource_request(cls, config):
            # Entries in ``config`` take precedence over the creator's
            # arguments when sizing the resource request.
            num_workers_ = int(config.get("num_workers", num_workers))
            num_cpus = int(
                config.get("num_cpus_per_worker", num_cpus_per_worker))
            use_gpu_ = config.get("use_gpu", use_gpu)

            # Everything is requested through the ``extra_*`` fields; the
            # ``cpu``/``gpu`` fields are explicitly zero.
            return Resources(
                cpu=0,
                gpu=0,
                extra_cpu=num_cpus * num_workers_,
                extra_gpu=num_workers_ if use_gpu_ else 0)

    return WrappedDistributedTorchTrainable
|
||||
|
||||
|
||||
class distributed_checkpoint:
    """ContextManager for creating a distributed checkpoint.

    Only checkpoints a file on the "main" training actor (the one whose
    ``torch.distributed`` rank is 0), avoiding redundant work. Every other
    rank, and every rank when ``disable`` is set, receives ``os.devnull``
    as the path.

    Args:
        label (int | str): Used to label the checkpoint
        disable (bool): Disable for prototyping. When set,
            ``torch.distributed`` is not queried at all.

    Example:

    .. code-block::

        if epoch % 3 == 0:
            with distributed_checkpoint(label=epoch) as path:
                torch.save(model.state_dict(), path)
    """

    def __init__(self, label, disable=False):
        self.label = label
        self.file = None
        self.disable = disable

    def _should_checkpoint(self):
        """Return True when this process is the one writing the checkpoint."""
        # Test ``disable`` first so a disabled context never calls
        # ``get_rank()``.
        return not self.disable and torch.distributed.get_rank() == 0

    def __enter__(self):
        if self._should_checkpoint():
            checkpoint_dir = tune.make_checkpoint_dir(step=self.label)
            path = os.path.join(checkpoint_dir, "checkpoint")
        else:
            path = os.devnull
        self.file = path
        return path

    def __exit__(self, exc_type, value, traceback):
        # If the body raised, the file may be incomplete: do not hand it to
        # Tune. Returning None lets the exception propagate unchanged.
        if exc_type is None and self._should_checkpoint():
            tune.save_checkpoint(self.file)
|
||||
|
||||
|
||||
def _train_simple(config, checkpoint=False):
    """For testing only. Putting this here because Ray has problems
    serializing within the test file.

    Trains a small two-layer network on random data, wrapped in
    ``DistributedDataParallel``, and reports the loss after every epoch.

    Args:
        config (dict): Reads ``"epochs"`` (default 10) and
            ``"enable_checkpoint"`` (default True).
        checkpoint (str | bool): Path of a checkpoint to restore from, or a
            falsy value to start fresh.
    """
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel
    import torch.optim as optim
    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 8, 5, 5, 5

    # Create random Tensors to hold inputs and outputs
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)
    loss_fn = nn.MSELoss()

    # Use the nn package to define our model and loss function.
    model = torch.nn.Sequential(
        torch.nn.Linear(D_in, H),
        torch.nn.ReLU(),
        torch.nn.Linear(H, D_out),
    )
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    if checkpoint:
        # Pass the path itself: a file opened in text mode cannot be read
        # by torch.load.
        model_state, optimizer_state = torch.load(checkpoint)

        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    model = DistributedDataParallel(model)

    for epoch in range(config.get("epochs", 10)):
        optimizer.zero_grad()
        output = model(x)
        loss = loss_fn(output, y)
        loss.backward()
        optimizer.step()

        if epoch % 3 == 0:
            if config.get("enable_checkpoint", True):
                with distributed_checkpoint(label=epoch) as path:
                    # Save the wrapped module's weights so the keys match
                    # the plain model that the restore branch loads into
                    # (the wrapper's own state dict prefixes every key
                    # with "module.").
                    torch.save(
                        (model.module.state_dict(), optimizer.state_dict()),
                        path)
        tune.report(mean_loss=loss.item())
|
||||
|
||||
@@ -140,6 +140,8 @@ class TorchTrainer:
|
||||
Defaults to True.
|
||||
wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel
|
||||
over each model. If False, you are expected to call it yourself.
|
||||
timeout_s (float): Seconds before the torch process group
|
||||
times out. Useful when machines are unreliable.
|
||||
add_dist_sampler (bool): Whether to automatically add a
|
||||
DistributedSampler to all created dataloaders. Only applicable
|
||||
if num_workers > 1.
|
||||
|
||||
Reference in New Issue
Block a user