mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 05:56:25 +08:00
[sgd] Tune interface for Pytorch MultiNode SGD (#5350)
This commit is contained in:
@@ -3,7 +3,9 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
from ray.experimental.sgd.pytorch import PyTorchTrainer
|
||||
from ray import tune
|
||||
from ray.experimental.sgd.pytorch.pytorch_trainer import (PyTorchTrainer,
|
||||
PyTorchTrainable)
|
||||
|
||||
from ray.experimental.sgd.tests.pytorch_utils import (
|
||||
model_creator, optimizer_creator, data_creator)
|
||||
@@ -16,12 +18,34 @@ def train_example(num_replicas=1, use_gpu=False):
|
||||
optimizer_creator,
|
||||
num_replicas=num_replicas,
|
||||
use_gpu=use_gpu,
|
||||
batch_size=512,
|
||||
backend="gloo")
|
||||
trainer1.train()
|
||||
trainer1.shutdown()
|
||||
print("success!")
|
||||
|
||||
|
||||
def tune_example(num_replicas=1, use_gpu=False):
    """Run a short Tune experiment over ``PyTorchTrainable``.

    Twelve samples of the same configuration are launched, each stopped
    once ``training_iteration`` reaches 2.

    Args:
        num_replicas (int): stored under the ``"num_replicas"`` config key.
        use_gpu (bool): stored under the ``"use_gpu"`` config key.

    Returns:
        The value of ``analysis.get_best_config`` when minimizing
        ``validation_loss``.
    """
    # Creator callables are wrapped with ``tune.function`` before going into
    # the config (presumably so Tune passes them through untouched --
    # confirm against the Tune documentation).
    config = {
        "model_creator": tune.function(model_creator),
        "data_creator": tune.function(data_creator),
        "optimizer_creator": tune.function(optimizer_creator),
    }
    # Plain settings that the trainable reads back out of the config.
    config.update({
        "num_replicas": num_replicas,
        "use_gpu": use_gpu,
        "batch_size": 512,
        "backend": "gloo",
    })

    stop_criteria = {"training_iteration": 2}
    analysis = tune.run(
        PyTorchTrainable,
        num_samples=12,
        config=config,
        stop=stop_criteria,
        verbose=1)

    return analysis.get_best_config(metric="validation_loss", mode="min")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@@ -40,9 +64,16 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enables GPU training")
|
||||
parser.add_argument(
|
||||
"--tune", action="store_true", default=False, help="Tune training")
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
import ray
|
||||
|
||||
ray.init(redis_address=args.redis_address)
|
||||
train_example(num_replicas=args.num_replicas, use_gpu=args.use_gpu)
|
||||
|
||||
if args.tune:
|
||||
tune_example(num_replicas=args.num_replicas, use_gpu=args.use_gpu)
|
||||
else:
|
||||
train_example(num_replicas=args.num_replicas, use_gpu=args.use_gpu)
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.experimental.sgd.pytorch.pytorch_trainer import PyTorchTrainer
|
||||
from ray.experimental.sgd.pytorch.pytorch_trainer import (PyTorchTrainer,
|
||||
PyTorchTrainable)
|
||||
|
||||
__all__ = ["PyTorchTrainer"]
|
||||
__all__ = ["PyTorchTrainer", "PyTorchTrainable"]
|
||||
|
||||
@@ -33,6 +33,7 @@ class DistributedPyTorchRunner(PyTorchRunner):
|
||||
batch_size (int): batch size used by one replica for an update.
|
||||
backend (string): see pytorch_trainer.py.
|
||||
"""
|
||||
|
||||
super(DistributedPyTorchRunner, self).__init__(
|
||||
model_creator, data_creator, optimizer_creator, config, batch_size)
|
||||
self.backend = backend
|
||||
|
||||
@@ -3,12 +3,15 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import logging
|
||||
|
||||
import ray
|
||||
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.resources import Resources
|
||||
from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner
|
||||
from ray.experimental.sgd.pytorch.distributed_pytorch_runner import (
|
||||
DistributedPyTorchRunner)
|
||||
@@ -136,14 +139,25 @@ class PyTorchTrainer(object):
|
||||
model.load_state_dict(state["model"])
|
||||
return model
|
||||
|
||||
def save(self, ckpt):
|
||||
"""Saves the model at the provided checkpoint."""
|
||||
state = ray.get(self.workers[0].get_state.remote())
|
||||
torch.save(state, ckpt)
|
||||
def save(self, checkpoint):
|
||||
"""Saves the model at the provided checkpoint.
|
||||
|
||||
def restore(self, ckpt):
|
||||
"""Restores the model from the provided checkpoint."""
|
||||
state = torch.load(ckpt)
|
||||
Args:
|
||||
checkpoint (str): Path to target checkpoint file.
|
||||
|
||||
"""
|
||||
state = ray.get(self.workers[0].get_state.remote())
|
||||
torch.save(state, checkpoint)
|
||||
return checkpoint
|
||||
|
||||
def restore(self, checkpoint):
|
||||
"""Restores the model from the provided checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint (str): Path to target checkpoint file.
|
||||
|
||||
"""
|
||||
state = torch.load(checkpoint)
|
||||
state_id = ray.put(state)
|
||||
ray.get([worker.set_state.remote(state_id) for worker in self.workers])
|
||||
|
||||
@@ -152,3 +166,42 @@ class PyTorchTrainer(object):
|
||||
for worker in self.workers:
|
||||
worker.shutdown.remote()
|
||||
worker.__ray_terminate__.remote()
|
||||
|
||||
|
||||
class PyTorchTrainable(Trainable):
    """Tune ``Trainable`` that delegates its work to a ``PyTorchTrainer``.

    The trial config is expected to carry the keys read in ``_setup``:
    ``model_creator``, ``data_creator``, ``optimizer_creator``,
    ``num_replicas``, ``use_gpu``, ``batch_size`` and ``backend``.
    """

    @classmethod
    def default_resource_request(cls, config):
        """Build the ``Resources`` request for one trial.

        ``cpu`` and ``gpu`` are 0; ``extra_cpu`` equals
        ``config["num_replicas"]`` and ``extra_gpu`` equals that same count
        when ``config["use_gpu"]`` is truthy, otherwise 0.

        Args:
            config (dict): trial config with ``num_replicas`` and
                ``use_gpu`` entries.

        Returns:
            Resources: the resource request described above.
        """
        replica_count = config["num_replicas"]
        gpu_flag = int(config["use_gpu"])
        return Resources(
            cpu=0,
            gpu=0,
            extra_cpu=replica_count,
            extra_gpu=gpu_flag * config["num_replicas"])

    def _setup(self, config):
        """Create the underlying ``PyTorchTrainer`` from the trial config.

        Args:
            config (dict): trial config; it is also forwarded whole as the
                trainer's own ``config`` argument.
        """
        # Keys that map one-to-one onto PyTorchTrainer keyword arguments.
        forwarded_keys = (
            "model_creator",
            "data_creator",
            "optimizer_creator",
            "num_replicas",
            "use_gpu",
            "batch_size",
            "backend",
        )
        trainer_kwargs = {key: config[key] for key in forwarded_keys}
        self._trainer = PyTorchTrainer(config=config, **trainer_kwargs)

    def _train(self):
        """Run one training pass followed by one validation pass.

        Returns:
            The stats object returned by ``train()``, updated in place with
            the stats returned by ``validate()``.
        """
        stats = self._trainer.train()
        stats.update(self._trainer.validate())
        return stats

    def _save(self, checkpoint_dir):
        """Save the trainer state to ``model.pth`` inside ``checkpoint_dir``.

        Args:
            checkpoint_dir (str): directory in which to place the file.

        Returns:
            Whatever ``PyTorchTrainer.save`` returns for that path.
        """
        target_path = os.path.join(checkpoint_dir, "model.pth")
        return self._trainer.save(target_path)

    def _restore(self, checkpoint_path):
        """Restore the trainer state from ``checkpoint_path``.

        Args:
            checkpoint_path (str): path handed to ``PyTorchTrainer.restore``.

        Returns:
            Whatever ``PyTorchTrainer.restore`` returns.
        """
        return self._trainer.restore(checkpoint_path)

    def _stop(self):
        """Shut down the underlying trainer."""
        self._trainer.shutdown()
|
||||
|
||||
@@ -8,8 +8,9 @@ import tempfile
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from ray import tune
|
||||
from ray.tests.conftest import ray_start_2_cpus # noqa: F401
|
||||
from ray.experimental.sgd.pytorch import PyTorchTrainer
|
||||
from ray.experimental.sgd.pytorch import PyTorchTrainer, PyTorchTrainable
|
||||
|
||||
from ray.experimental.sgd.tests.pytorch_utils import (
|
||||
model_creator, optimizer_creator, data_creator)
|
||||
@@ -36,6 +37,38 @@ def test_train(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
assert validation_loss2 <= validation_loss1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(  # noqa: F811
    "num_replicas", [1, 2] if dist.is_available() else [1])
def test_tune_train(ray_start_2_cpus, num_replicas):  # noqa: F811
    """Run ``PyTorchTrainable`` through ``tune.run`` and check the losses.

    Two samples are run for two training iterations each; for every trial
    the second-iteration train and validation losses must not exceed the
    first-iteration ones.
    """
    config = {
        "model_creator": tune.function(model_creator),
        "data_creator": tune.function(data_creator),
        "optimizer_creator": tune.function(optimizer_creator),
        "num_replicas": num_replicas,
        "use_gpu": False,
        "batch_size": 512,
        "backend": "gloo",
    }

    analysis = tune.run(
        PyTorchTrainable,
        num_samples=2,
        config=config,
        stop={"training_iteration": 2},
        verbose=1)

    # Check that the loss is non-increasing for every trial.
    metrics = ("train_loss", "validation_loss")
    for df in analysis.trial_dataframes.values():
        # Read all values first (rows 0 and 1 of each metric), then assert.
        observed = {
            metric: (df.loc[0, metric], df.loc[1, metric])
            for metric in metrics
        }
        for metric in metrics:
            first_iter, second_iter = observed[metric]
            assert second_iter <= first_iter
|
||||
|
||||
|
||||
@pytest.mark.parametrize( # noqa: F811
|
||||
"num_replicas", [1, 2] if dist.is_available() else [1])
|
||||
def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
|
||||
Reference in New Issue
Block a user