[sgd] Tune interface for Pytorch MultiNode SGD (#5350)

This commit is contained in:
jichan3751
2019-08-10 13:51:44 -07:00
committed by Richard Liaw
parent df47bdf6c9
commit de95117e96
6 changed files with 137 additions and 12 deletions
@@ -3,7 +3,9 @@ from __future__ import division
from __future__ import print_function
import argparse
from ray.experimental.sgd.pytorch import PyTorchTrainer
from ray import tune
from ray.experimental.sgd.pytorch.pytorch_trainer import (PyTorchTrainer,
PyTorchTrainable)
from ray.experimental.sgd.tests.pytorch_utils import (
model_creator, optimizer_creator, data_creator)
@@ -16,12 +18,34 @@ def train_example(num_replicas=1, use_gpu=False):
optimizer_creator,
num_replicas=num_replicas,
use_gpu=use_gpu,
batch_size=512,
backend="gloo")
trainer1.train()
trainer1.shutdown()
print("success!")
def tune_example(num_replicas=1, use_gpu=False):
config = {
"model_creator": tune.function(model_creator),
"data_creator": tune.function(data_creator),
"optimizer_creator": tune.function(optimizer_creator),
"num_replicas": num_replicas,
"use_gpu": use_gpu,
"batch_size": 512,
"backend": "gloo"
}
analysis = tune.run(
PyTorchTrainable,
num_samples=12,
config=config,
stop={"training_iteration": 2},
verbose=1)
return analysis.get_best_config(metric="validation_loss", mode="min")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -40,9 +64,16 @@ if __name__ == "__main__":
action="store_true",
default=False,
help="Enables GPU training")
parser.add_argument(
"--tune", action="store_true", default=False, help="Tune training")
args, _ = parser.parse_known_args()
import ray
ray.init(redis_address=args.redis_address)
train_example(num_replicas=args.num_replicas, use_gpu=args.use_gpu)
if args.tune:
tune_example(num_replicas=args.num_replicas, use_gpu=args.use_gpu)
else:
train_example(num_replicas=args.num_replicas, use_gpu=args.use_gpu)
@@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.experimental.sgd.pytorch.pytorch_trainer import PyTorchTrainer
from ray.experimental.sgd.pytorch.pytorch_trainer import (PyTorchTrainer,
PyTorchTrainable)
__all__ = ["PyTorchTrainer"]
__all__ = ["PyTorchTrainer", "PyTorchTrainable"]
@@ -33,6 +33,7 @@ class DistributedPyTorchRunner(PyTorchRunner):
batch_size (int): batch size used by one replica for an update.
backend (string): see pytorch_trainer.py.
"""
super(DistributedPyTorchRunner, self).__init__(
model_creator, data_creator, optimizer_creator, config, batch_size)
self.backend = backend
@@ -3,12 +3,15 @@ from __future__ import division
from __future__ import print_function
import numpy as np
import os
import torch
import torch.distributed as dist
import logging
import ray
from ray.tune import Trainable
from ray.tune.resources import Resources
from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner
from ray.experimental.sgd.pytorch.distributed_pytorch_runner import (
DistributedPyTorchRunner)
@@ -136,14 +139,25 @@ class PyTorchTrainer(object):
model.load_state_dict(state["model"])
return model
def save(self, ckpt):
"""Saves the model at the provided checkpoint."""
state = ray.get(self.workers[0].get_state.remote())
torch.save(state, ckpt)
def save(self, checkpoint):
"""Saves the model at the provided checkpoint.
def restore(self, ckpt):
"""Restores the model from the provided checkpoint."""
state = torch.load(ckpt)
Args:
checkpoint (str): Path to target checkpoint file.
"""
state = ray.get(self.workers[0].get_state.remote())
torch.save(state, checkpoint)
return checkpoint
def restore(self, checkpoint):
"""Restores the model from the provided checkpoint.
Args:
checkpoint (str): Path to target checkpoint file.
"""
state = torch.load(checkpoint)
state_id = ray.put(state)
ray.get([worker.set_state.remote(state_id) for worker in self.workers])
@@ -152,3 +166,42 @@ class PyTorchTrainer(object):
for worker in self.workers:
worker.shutdown.remote()
worker.__ray_terminate__.remote()
class PyTorchTrainable(Trainable):
@classmethod
def default_resource_request(cls, config):
return Resources(
cpu=0,
gpu=0,
extra_cpu=config["num_replicas"],
extra_gpu=int(config["use_gpu"]) * config["num_replicas"])
def _setup(self, config):
self._trainer = PyTorchTrainer(
model_creator=config["model_creator"],
data_creator=config["data_creator"],
optimizer_creator=config["optimizer_creator"],
config=config,
num_replicas=config["num_replicas"],
use_gpu=config["use_gpu"],
batch_size=config["batch_size"],
backend=config["backend"])
def _train(self):
train_stats = self._trainer.train()
validation_stats = self._trainer.validate()
train_stats.update(validation_stats)
return train_stats
def _save(self, checkpoint_dir):
return self._trainer.save(os.path.join(checkpoint_dir, "model.pth"))
def _restore(self, checkpoint_path):
return self._trainer.restore(checkpoint_path)
def _stop(self):
self._trainer.shutdown()
@@ -8,8 +8,9 @@ import tempfile
import torch
import torch.distributed as dist
from ray import tune
from ray.tests.conftest import ray_start_2_cpus # noqa: F401
from ray.experimental.sgd.pytorch import PyTorchTrainer
from ray.experimental.sgd.pytorch import PyTorchTrainer, PyTorchTrainable
from ray.experimental.sgd.tests.pytorch_utils import (
model_creator, optimizer_creator, data_creator)
@@ -36,6 +37,38 @@ def test_train(ray_start_2_cpus, num_replicas): # noqa: F811
assert validation_loss2 <= validation_loss1
@pytest.mark.parametrize( # noqa: F811
"num_replicas", [1, 2] if dist.is_available() else [1])
def test_tune_train(ray_start_2_cpus, num_replicas): # noqa: F811
config = {
"model_creator": tune.function(model_creator),
"data_creator": tune.function(data_creator),
"optimizer_creator": tune.function(optimizer_creator),
"num_replicas": num_replicas,
"use_gpu": False,
"batch_size": 512,
"backend": "gloo"
}
analysis = tune.run(
PyTorchTrainable,
num_samples=2,
config=config,
stop={"training_iteration": 2},
verbose=1)
# checks loss decreasing for every trials
for path, df in analysis.trial_dataframes.items():
train_loss1 = df.loc[0, "train_loss"]
train_loss2 = df.loc[1, "train_loss"]
validation_loss1 = df.loc[0, "validation_loss"]
validation_loss2 = df.loc[1, "validation_loss"]
assert train_loss2 <= train_loss1
assert validation_loss2 <= validation_loss1
@pytest.mark.parametrize( # noqa: F811
"num_replicas", [1, 2] if dist.is_available() else [1])
def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811