mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 05:17:38 +08:00
[raysgd] Improve raysgd examples (#7818)
* better_example * test * improve some usability things * submit * fix * flake * Update python/ray/util/sgd/torch/training_operator.py * trythis * fix * fix * smoke * fail * fix * fix
This commit is contained in:
@@ -6,6 +6,7 @@ import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
@@ -269,7 +270,7 @@ def test_split_batch(ray_start_2_cpus):
|
||||
def data_creator(config):
|
||||
"""Returns training dataloader, validation dataloader."""
|
||||
train_dataset = LinearDataset(2, 5, size=config["data_size"])
|
||||
return torch.utils.data.DataLoader(
|
||||
return DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config[BATCH_SIZE],
|
||||
)
|
||||
@@ -301,7 +302,10 @@ def test_reduce_result(ray_start_2_cpus):
|
||||
def data_creator(config):
|
||||
"""Returns training dataloader, validation dataloader."""
|
||||
train_dataset = LinearDataset(2, 5, size=config["data_size"])
|
||||
return torch.utils.data.DataLoader(train_dataset, batch_size=1)
|
||||
test_dataset = LinearDataset(2, 5, size=config["data_size"])
|
||||
return DataLoader(
|
||||
train_dataset, batch_size=1), DataLoader(
|
||||
test_dataset, batch_size=1)
|
||||
|
||||
data_size = 600
|
||||
|
||||
@@ -316,6 +320,10 @@ def test_reduce_result(ray_start_2_cpus):
|
||||
assert len(list_stats) == 2
|
||||
assert [stats[NUM_SAMPLES] == data_size for stats in list_stats]
|
||||
assert [stats[BATCH_COUNT] == (data_size // 2) for stats in list_stats]
|
||||
list_stats = trainer.validate(reduce_results=False, profile=True)
|
||||
assert len(list_stats) == 2
|
||||
assert [stats[NUM_SAMPLES] == data_size for stats in list_stats]
|
||||
assert [stats[BATCH_COUNT] == (data_size // 2) for stats in list_stats]
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
@@ -501,8 +509,7 @@ def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
|
||||
|
||||
def single_loader(config):
|
||||
dataset = LinearDataset(2, 5, size=1000000)
|
||||
return torch.utils.data.DataLoader(
|
||||
dataset, batch_size=config.get("batch_size", 32))
|
||||
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
|
||||
|
||||
def step_with_fail(self, **params):
|
||||
remote_worker_stats = [
|
||||
@@ -545,8 +552,7 @@ def test_resize(ray_start_2_cpus): # noqa: F811
|
||||
|
||||
def single_loader(config):
|
||||
dataset = LinearDataset(2, 5, size=1000000)
|
||||
return torch.utils.data.DataLoader(
|
||||
dataset, batch_size=config.get("batch_size", 32))
|
||||
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
|
||||
|
||||
def step_with_fail(self, **params):
|
||||
remote_worker_stats = [
|
||||
@@ -595,8 +601,7 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811
|
||||
|
||||
def single_loader(config):
|
||||
dataset = LinearDataset(2, 5, size=1000000)
|
||||
return torch.utils.data.DataLoader(
|
||||
dataset, batch_size=config.get("batch_size", 32))
|
||||
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
|
||||
|
||||
def step_with_fail(self, **params):
|
||||
remote_worker_stats = [
|
||||
|
||||
@@ -12,7 +12,7 @@ from torch.utils.data.distributed import DistributedSampler
|
||||
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_IN_SECONDS
|
||||
|
||||
import ray
|
||||
from ray.util.sgd.torch.torch_runner import TorchRunner
|
||||
from ray.util.sgd.torch.torch_runner import TorchRunner, _remind_gpu_usage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,16 +23,19 @@ class DistributedTorchRunner(TorchRunner):
|
||||
|
||||
Args:
|
||||
args: Arguments for TorchRunner.
|
||||
backend (string): Backend used by distributed PyTorch.
|
||||
backend (str): Backend used by distributed PyTorch.
|
||||
add_dist_sampler (bool): Whether to automatically add a
|
||||
DistributedSampler to all created dataloaders.
|
||||
kwargs: Keyword arguments for TorchRunner.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, *args, backend="gloo", **kwargs):
|
||||
def __init__(self, *args, backend="gloo", add_dist_sampler=True, **kwargs):
|
||||
super(DistributedTorchRunner, self).__init__(*args, **kwargs)
|
||||
if backend not in ("gloo", "nccl"):
|
||||
raise ValueError("Backend must be one of 'gloo' or 'nccl'.")
|
||||
self.backend = backend
|
||||
self.add_dist_sampler = add_dist_sampler
|
||||
|
||||
def setup(self, url, world_rank, world_size):
|
||||
"""Connects to the distributed PyTorch backend and initializes the model.
|
||||
@@ -42,6 +45,7 @@ class DistributedTorchRunner(TorchRunner):
|
||||
world_rank (int): the index of the runner.
|
||||
world_size (int): the total number of runners.
|
||||
"""
|
||||
_remind_gpu_usage(self.use_gpu, is_chief=world_rank == 0)
|
||||
self._setup_distributed_pytorch(url, world_rank, world_size)
|
||||
self._setup_training()
|
||||
|
||||
@@ -65,14 +69,25 @@ class DistributedTorchRunner(TorchRunner):
|
||||
world_size=world_size,
|
||||
timeout=timeout)
|
||||
|
||||
self.device_ids = None
|
||||
|
||||
if self.use_gpu and torch.cuda.is_available():
|
||||
# https://github.com/allenai/allennlp/issues/1090
|
||||
self._set_cuda_device_id()
|
||||
|
||||
def _set_cuda_device_id(self):
    """Select the CUDA device index later handed to DistributedDataParallel.

    Stores ``[0]`` in ``self.device_ids``; that list is passed as the
    ``device_ids`` argument when the models are wrapped. Subclasses override
    this hook when the process must use a different device index.
    """
    # NOTE(review): index 0 presumably relies on each worker process seeing
    # exactly one GPU -- confirm against how workers are scheduled.
    self.device_ids = [0]
|
||||
|
||||
def _setup_training(self):
|
||||
logger.debug("Loading data.")
|
||||
self._initialize_dataloaders()
|
||||
logger.debug("Creating model")
|
||||
self.models = self.model_creator(self.config)
|
||||
if not isinstance(self.models, collections.Iterable):
|
||||
self.models = [self.models]
|
||||
assert all(isinstance(model, nn.Module) for model in self.models), (
|
||||
"All models must be PyTorch models: {}.".format(self.models))
|
||||
if torch.cuda.is_available():
|
||||
if self.use_gpu and torch.cuda.is_available():
|
||||
self.models = [model.cuda() for model in self.models]
|
||||
|
||||
logger.debug("Creating optimizer.")
|
||||
@@ -83,11 +98,14 @@ class DistributedTorchRunner(TorchRunner):
|
||||
|
||||
self._create_schedulers_if_available()
|
||||
self._try_setup_apex()
|
||||
|
||||
# This needs to happen after apex
|
||||
self.models = [DistributedDataParallel(model) for model in self.models]
|
||||
self.models = [
|
||||
DistributedDataParallel(model, device_ids=self.device_ids)
|
||||
for model in self.models
|
||||
]
|
||||
|
||||
self._create_loss()
|
||||
self._initialize_dataloaders()
|
||||
|
||||
self.training_operator = self.training_operator_cls(
|
||||
self.config,
|
||||
@@ -98,6 +116,7 @@ class DistributedTorchRunner(TorchRunner):
|
||||
validation_loader=self.validation_loader,
|
||||
world_rank=self.world_rank,
|
||||
schedulers=self.schedulers,
|
||||
use_gpu=self.use_gpu,
|
||||
use_fp16=self.use_fp16,
|
||||
use_tqdm=self.use_tqdm)
|
||||
|
||||
@@ -121,11 +140,13 @@ class DistributedTorchRunner(TorchRunner):
|
||||
return DataLoader(**data_loader_args)
|
||||
|
||||
if isinstance(self.train_loader, DataLoader):
|
||||
self.train_loader = with_sampler(self.train_loader)
|
||||
if self.add_dist_sampler:
|
||||
self.train_loader = with_sampler(self.train_loader)
|
||||
|
||||
if self.validation_loader and isinstance(self.validation_loader,
|
||||
DataLoader):
|
||||
self.validation_loader = with_sampler(self.validation_loader)
|
||||
if self.add_dist_sampler:
|
||||
self.validation_loader = with_sampler(self.validation_loader)
|
||||
|
||||
def train_epoch(self, **kwargs):
|
||||
"""Runs a training epoch and updates the model parameters.
|
||||
@@ -190,10 +211,22 @@ class LocalDistributedRunner(DistributedTorchRunner):
|
||||
num_gpus=num_gpus,
|
||||
resources={"node:" + ip: 0.1})(_DummyActor).remote()
|
||||
|
||||
head_cuda = ray.get(_dummy_actor.cuda_devices.remote())
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = head_cuda
|
||||
self.local_device = ray.get(_dummy_actor.cuda_devices.remote())
|
||||
|
||||
# This is a pretty annoying workaround. To enable SyncBatchNorm,
|
||||
# we need to signify that we are using only 1 CUDA device (via
|
||||
# the DDP constructor). However, on the local worker,
|
||||
# we set the CUDA_VISIBLE_DEVICES at runtime rather at process
|
||||
# start. This means that we have to make sure that DDP knows which
|
||||
# specific device we are using.
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = self.local_device
|
||||
if self.local_device:
|
||||
torch.cuda.set_device(int(self.local_device))
|
||||
super(LocalDistributedRunner, self).__init__(*args, **kwargs)
|
||||
|
||||
def _set_cuda_device_id(self):
    """Point DistributedDataParallel at this local worker's CUDA device.

    Overrides the parent hook: instead of a fixed index, the device id is
    parsed from ``self.local_device``, the string that ``__init__`` also
    writes into ``CUDA_VISIBLE_DEVICES``.
    """
    # NOTE(review): int() assumes ``local_device`` names exactly one device
    # (e.g. "0"); an empty or comma-separated value would raise ValueError.
    self.device_ids = [int(self.local_device)]
|
||||
|
||||
def shutdown(self, cleanup=True):
|
||||
super(LocalDistributedRunner, self).shutdown()
|
||||
global _dummy_actor
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import argparse
|
||||
from ray import tune
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
import torchvision
|
||||
from torchvision.datasets import CIFAR10
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from tqdm import trange
|
||||
|
||||
import ray
|
||||
from ray.tune import CLIReporter
|
||||
from ray.util.sgd.torch import TorchTrainer
|
||||
from ray.util.sgd.torch.resnet import ResNet18
|
||||
from ray.util.sgd.utils import BATCH_SIZE
|
||||
@@ -42,12 +38,12 @@ def cifar_creator(config):
|
||||
transforms.Normalize((0.4914, 0.4822, 0.4465),
|
||||
(0.2023, 0.1994, 0.2010)),
|
||||
])
|
||||
train_dataset = torchvision.datasets.CIFAR10(
|
||||
train_dataset = CIFAR10(
|
||||
root="~/data", train=True, download=True, transform=transform_train)
|
||||
validation_dataset = torchvision.datasets.CIFAR10(
|
||||
validation_dataset = CIFAR10(
|
||||
root="~/data", train=False, download=False, transform=transform_test)
|
||||
|
||||
if config.get("test_mode"):
|
||||
if config["test_mode"]:
|
||||
train_dataset = Subset(train_dataset, list(range(64)))
|
||||
validation_dataset = Subset(validation_dataset, list(range(64)))
|
||||
|
||||
@@ -71,95 +67,6 @@ def scheduler_creator(optimizer, config):
|
||||
optimizer, milestones=[150, 250, 350], gamma=0.1)
|
||||
|
||||
|
||||
def train_example(num_workers=1,
|
||||
num_epochs=5,
|
||||
use_gpu=False,
|
||||
use_fp16=False,
|
||||
test_mode=False):
|
||||
trainer1 = TorchTrainer(
|
||||
model_creator=ResNet18,
|
||||
data_creator=cifar_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.CrossEntropyLoss,
|
||||
scheduler_creator=scheduler_creator,
|
||||
initialization_hook=initialization_hook,
|
||||
num_workers=num_workers,
|
||||
config={
|
||||
"lr": 0.1,
|
||||
"test_mode": test_mode, # user-defined param to subset the data
|
||||
BATCH_SIZE: 128 * num_workers # this will be split across workers.
|
||||
},
|
||||
use_gpu=use_gpu,
|
||||
scheduler_step_freq="epoch",
|
||||
use_fp16=use_fp16,
|
||||
use_tqdm=True)
|
||||
pbar = trange(num_epochs, unit="epoch")
|
||||
for i in pbar:
|
||||
info = {"num_steps": 1} if test_mode else {}
|
||||
info["epoch_idx"] = i
|
||||
info["num_epochs"] = num_epochs
|
||||
# Increase `max_retries` to turn on fault tolerance.
|
||||
trainer1.train(max_retries=1, info=info)
|
||||
val_stats = trainer1.validate()
|
||||
pbar.set_postfix(dict(acc=val_stats["val_accuracy"]))
|
||||
|
||||
print(trainer1.validate())
|
||||
trainer1.shutdown()
|
||||
print("success!")
|
||||
|
||||
|
||||
def tune_example(num_workers=1, use_gpu=False, use_fp16=False,
|
||||
test_mode=False):
|
||||
TorchTrainable = TorchTrainer.as_trainable(
|
||||
model_creator=ResNet18,
|
||||
data_creator=cifar_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.CrossEntropyLoss,
|
||||
scheduler_creator=scheduler_creator,
|
||||
initialization_hook=initialization_hook,
|
||||
num_workers=num_workers,
|
||||
config={
|
||||
"test_mode": test_mode, # user-defined param to subset the data
|
||||
BATCH_SIZE: 128 * num_workers,
|
||||
},
|
||||
use_gpu=use_gpu,
|
||||
scheduler_step_freq="epoch",
|
||||
use_fp16=use_fp16)
|
||||
|
||||
pbt_scheduler = PopulationBasedTraining(
|
||||
time_attr="training_iteration",
|
||||
metric="val_loss",
|
||||
mode="min",
|
||||
perturbation_interval=1,
|
||||
hyperparam_mutations={
|
||||
# distribution for resampling
|
||||
"lr": lambda: np.random.uniform(0.001, 1),
|
||||
# allow perturbations within this set of categorical values
|
||||
"momentum": [0.8, 0.9, 0.99],
|
||||
})
|
||||
|
||||
reporter = CLIReporter()
|
||||
reporter.add_metric_column("val_loss", "loss")
|
||||
reporter.add_metric_column("val_accuracy", "acc")
|
||||
|
||||
analysis = tune.run(
|
||||
TorchTrainable,
|
||||
num_samples=4,
|
||||
config={
|
||||
"lr": tune.choice([0.001, 0.01, 0.1]),
|
||||
"momentum": 0.8
|
||||
},
|
||||
stop={"training_iteration": 2 if test_mode else 100},
|
||||
max_failures=3, # used for fault tolerance
|
||||
checkpoint_freq=3, # used for fault tolerance
|
||||
keep_checkpoints_num=1, # used for fault tolerance
|
||||
verbose=2,
|
||||
progress_reporter=reporter,
|
||||
scheduler=pbt_scheduler)
|
||||
|
||||
return analysis.get_best_config(metric="val_loss", mode="min")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@@ -194,18 +101,37 @@ if __name__ == "__main__":
|
||||
"--tune", action="store_true", default=False, help="Tune training")
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
num_cpus = 4 if args.smoke_test else None
|
||||
ray.init(address=args.address, num_cpus=num_cpus, log_to_driver=True)
|
||||
|
||||
ray.init(address=args.address, log_to_driver=True)
|
||||
trainer1 = TorchTrainer(
|
||||
model_creator=ResNet18,
|
||||
data_creator=cifar_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.CrossEntropyLoss,
|
||||
scheduler_creator=scheduler_creator,
|
||||
initialization_hook=initialization_hook,
|
||||
num_workers=args.num_workers,
|
||||
config={
|
||||
"lr": 0.1,
|
||||
"test_mode": args.smoke_test, # subset the data
|
||||
# this will be split across workers.
|
||||
BATCH_SIZE: 128 * args.num_workers
|
||||
},
|
||||
use_gpu=args.use_gpu,
|
||||
scheduler_step_freq="epoch",
|
||||
use_fp16=args.fp16,
|
||||
use_tqdm=True)
|
||||
pbar = trange(args.num_epochs, unit="epoch")
|
||||
for i in pbar:
|
||||
info = {"num_steps": 1} if args.smoke_test else {}
|
||||
info["epoch_idx"] = i
|
||||
info["num_epochs"] = args.num_epochs
|
||||
# Increase `max_retries` to turn on fault tolerance.
|
||||
trainer1.train(max_retries=1, info=info)
|
||||
val_stats = trainer1.validate()
|
||||
pbar.set_postfix(dict(acc=val_stats["val_accuracy"]))
|
||||
|
||||
if args.tune:
|
||||
tune_example(
|
||||
num_workers=args.num_workers,
|
||||
use_gpu=args.use_gpu,
|
||||
test_mode=args.smoke_test)
|
||||
else:
|
||||
train_example(
|
||||
num_workers=args.num_workers,
|
||||
num_epochs=args.num_epochs,
|
||||
use_gpu=args.use_gpu,
|
||||
use_fp16=args.fp16,
|
||||
test_mode=args.smoke_test)
|
||||
print(trainer1.validate())
|
||||
trainer1.shutdown()
|
||||
print("success!")
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import argparse
|
||||
from ray import tune
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
from torchvision.datasets import CIFAR10
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
import ray
|
||||
from ray.tune import CLIReporter
|
||||
from ray.util.sgd.torch import TorchTrainer
|
||||
from ray.util.sgd.torch.resnet import ResNet18
|
||||
from ray.util.sgd.utils import BATCH_SIZE
|
||||
|
||||
|
||||
def initialization_hook():
    """Export the NCCL environment settings used by this example.

    The two variables work around a connection restart issue seen on AWS.
    They are written into ``os.environ`` of the calling process.
    """
    nccl_settings = {
        "NCCL_SOCKET_IFNAME": "^docker0,lo",
        "NCCL_LL_THRESHOLD": "0",
    }
    os.environ.update(nccl_settings)

    # For verbose NCCL diagnostics, additionally export NCCL_DEBUG=INFO:
    # print("NCCL DEBUG SET")
    # os.environ["NCCL_DEBUG"] = "INFO"
|
||||
|
||||
|
||||
def cifar_creator(config):
    """Build the CIFAR-10 training and validation data loaders.

    Args:
        config (dict): Must contain ``BATCH_SIZE``, which is used as the
            batch size of both loaders. A truthy ``"test_mode"`` entry
            restricts each dataset to its first 64 samples.

    Returns:
        tuple: ``(train_loader, validation_loader)``.
    """
    # Mean / std normalization constants shared by both pipelines.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    # Training images are randomly cropped and flipped before normalization.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    # Validation images are only converted and normalized.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    data_root = "~/data"
    # Only the training split requests a download; the validation split is
    # created with download=False.
    train_dataset = CIFAR10(
        root=data_root, train=True, download=True, transform=transform_train)
    validation_dataset = CIFAR10(
        root=data_root, train=False, download=False, transform=transform_test)

    if config.get("test_mode"):
        # Keep only a handful of samples so a smoke test finishes quickly.
        train_dataset = Subset(train_dataset, list(range(64)))
        validation_dataset = Subset(validation_dataset, list(range(64)))

    loader_options = dict(batch_size=config[BATCH_SIZE], num_workers=2)
    train_loader = DataLoader(train_dataset, **loader_options)
    validation_loader = DataLoader(validation_dataset, **loader_options)
    return train_loader, validation_loader
|
||||
|
||||
|
||||
def optimizer_creator(model, config):
    """Create the SGD optimizer for ``model``.

    Args:
        model: Module whose ``parameters()`` are optimized.
        config (dict): May provide ``"lr"`` (default 0.1) and
            ``"momentum"`` (default 0.9).

    Returns:
        torch.optim.SGD: The configured optimizer.
    """
    trainable_params = model.parameters()
    learning_rate = config.get("lr", 0.1)
    momentum = config.get("momentum", 0.9)
    return torch.optim.SGD(
        trainable_params, lr=learning_rate, momentum=momentum)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # ---- Command-line interface -------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--address",
        required=False,
        type=str,
        help="the address to use for Redis")
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=1,
        help="Sets number of workers for training.")
    parser.add_argument(
        "--num-epochs", type=int, default=5, help="Number of epochs to train.")

    # The remaining options are plain on/off switches that default to False.
    # `--num-epochs` and `--tune` are accepted but not read by this script.
    boolean_flags = (
        ("--use-gpu", "Enables GPU training"),
        ("--fp16", "Enables FP16 training with apex. Requires `use-gpu`."),
        ("--smoke-test", "Finish quickly for testing."),
        ("--tune", "Tune training"),
    )
    for flag, help_text in boolean_flags:
        parser.add_argument(
            flag, action="store_true", default=False, help=help_text)

    # parse_known_args() tolerates arguments this parser does not declare.
    args, _ = parser.parse_known_args()
    ray.init(address=args.address, log_to_driver=True)

    # ---- Trainable wrapping the RaySGD trainer ----------------------------
    trainer_config = {
        "test_mode": args.smoke_test,  # whether to subset the data
        BATCH_SIZE: 128 * args.num_workers,
    }
    TorchTrainable = TorchTrainer.as_trainable(
        model_creator=ResNet18,
        data_creator=cifar_creator,
        optimizer_creator=optimizer_creator,
        loss_creator=nn.CrossEntropyLoss,
        initialization_hook=initialization_hook,
        num_workers=args.num_workers,
        config=trainer_config,
        use_gpu=args.use_gpu,
        use_fp16=args.fp16)

    # ---- Population Based Training scheduler ------------------------------
    mutations = {
        # distribution for resampling
        "lr": lambda: np.random.uniform(0.001, 1),
        # allow perturbations within this set of categorical values
        "momentum": [0.8, 0.9, 0.99],
    }
    pbt_scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="val_loss",
        mode="min",
        perturbation_interval=1,
        hyperparam_mutations=mutations)

    # Show validation loss / accuracy as extra columns in the CLI output.
    reporter = CLIReporter()
    reporter.add_metric_column("val_loss", "loss")
    reporter.add_metric_column("val_accuracy", "acc")

    # ---- Run the search ----------------------------------------------------
    search_space = {
        "lr": tune.choice([0.001, 0.01, 0.1]),
        "momentum": 0.8
    }
    # A smoke test stops after two iterations instead of one hundred.
    max_iterations = 2 if args.smoke_test else 100
    analysis = tune.run(
        TorchTrainable,
        num_samples=4,
        config=search_space,
        stop={"training_iteration": max_iterations},
        max_failures=3,  # used for fault tolerance
        checkpoint_freq=3,  # used for fault tolerance
        keep_checkpoints_num=1,  # used for fault tolerance
        verbose=2,
        progress_reporter=reporter,
        scheduler=pbt_scheduler)

    best_config = analysis.get_best_config(metric="val_loss", mode="min")
    print(best_config)
|
||||
@@ -128,9 +128,6 @@ def optimizer_creator(models, config):
|
||||
|
||||
class GANOperator(TrainingOperator):
|
||||
def setup(self, config):
|
||||
self.device = torch.device("cuda"
|
||||
if torch.cuda.is_available() else "cpu")
|
||||
|
||||
self.classifier = LeNet()
|
||||
self.classifier.load_state_dict(
|
||||
torch.load(config["classification_model_path"]))
|
||||
@@ -183,6 +180,7 @@ class GANOperator(TrainingOperator):
|
||||
|
||||
# Compute a discriminator update for real images
|
||||
discriminator.zero_grad()
|
||||
# self.device is set automatically
|
||||
real_cpu = batch[0].to(self.device)
|
||||
batch_size = real_cpu.size(0)
|
||||
label = torch.full((batch_size, ), real_label, device=self.device)
|
||||
|
||||
@@ -7,31 +7,15 @@ but we put comments right after code blocks to prevent large white spaces
|
||||
in the documentation.
|
||||
"""
|
||||
|
||||
# __torch_tune_example__
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.util.sgd.torch import TorchTrainer
|
||||
from ray.util.sgd.utils import BATCH_SIZE
|
||||
|
||||
|
||||
class LinearDataset(torch.utils.data.Dataset):
|
||||
"""y = a * x + b"""
|
||||
|
||||
def __init__(self, a, b, size=1000):
|
||||
x = np.random.random(size).astype(np.float32) * 10
|
||||
x = np.arange(0, 10, 10 / size, dtype=np.float32)
|
||||
self.x = torch.from_numpy(x)
|
||||
self.y = torch.from_numpy(a * x + b)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.x[index, None], self.y[index, None]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.x)
|
||||
from ray.util.sgd.torch.examples.train_example import LinearDataset
|
||||
|
||||
|
||||
def model_creator(config):
|
||||
@@ -47,22 +31,18 @@ def data_creator(config):
|
||||
"""Returns training dataloader, validation dataloader."""
|
||||
train_dataset = LinearDataset(2, 5)
|
||||
val_dataset = LinearDataset(2, 5, size=400)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config[BATCH_SIZE],
|
||||
)
|
||||
validation_loader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
batch_size=config[BATCH_SIZE])
|
||||
train_loader = DataLoader(train_dataset, batch_size=config[BATCH_SIZE])
|
||||
validation_loader = DataLoader(val_dataset, batch_size=config[BATCH_SIZE])
|
||||
return train_loader, validation_loader
|
||||
|
||||
|
||||
# __torch_tune_example__
|
||||
def tune_example(num_workers=1, use_gpu=False):
|
||||
TorchTrainable = TorchTrainer.as_trainable(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.MSELoss,
|
||||
loss_creator=nn.MSELoss, # Note that we specify a Loss class.
|
||||
num_workers=num_workers,
|
||||
use_gpu=use_gpu,
|
||||
config={BATCH_SIZE: 128}
|
||||
@@ -76,6 +56,7 @@ def tune_example(num_workers=1, use_gpu=False):
|
||||
verbose=1)
|
||||
|
||||
return analysis.get_best_config(metric="validation_loss", mode="min")
|
||||
# __end_torch_tune_example__
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -23,6 +23,14 @@ except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _remind_gpu_usage(use_gpu, is_chief):
|
||||
if not is_chief:
|
||||
return
|
||||
if not use_gpu and torch.cuda.is_available():
|
||||
logger.info("GPUs detected but not using them. Set `use_gpu` to "
|
||||
"enable GPU usage. ")
|
||||
|
||||
|
||||
class TorchRunner:
|
||||
"""Manages a PyTorch model for training.
|
||||
|
||||
@@ -36,6 +44,7 @@ class TorchRunner:
|
||||
torch_trainer.py.
|
||||
training_operator_cls: see torch_trainer.py
|
||||
config (dict): see torch_trainer.py.
|
||||
use_gpu (bool): see torch_trainer.py.
|
||||
use_fp16 (bool): see torch_trainer.py.
|
||||
apex_args (dict|None): see torch_trainer.py.
|
||||
scheduler_step_freq (str): see torch_trainer.py.
|
||||
@@ -49,6 +58,7 @@ class TorchRunner:
|
||||
scheduler_creator=None,
|
||||
training_operator_cls=None,
|
||||
config=None,
|
||||
use_gpu=False,
|
||||
use_fp16=False,
|
||||
use_tqdm=False,
|
||||
apex_args=None,
|
||||
@@ -69,6 +79,7 @@ class TorchRunner:
|
||||
self.schedulers = None
|
||||
self.train_loader = None
|
||||
self.validation_loader = None
|
||||
self.use_gpu = use_gpu
|
||||
self.use_fp16 = use_fp16
|
||||
self.use_tqdm = use_tqdm
|
||||
self.apex_args = apex_args or {}
|
||||
@@ -117,8 +128,9 @@ class TorchRunner:
|
||||
else:
|
||||
self.criterion = self.loss_creator(self.config)
|
||||
|
||||
if torch.cuda.is_available() and hasattr(self.criterion, "cuda"):
|
||||
self.criterion = self.criterion.cuda()
|
||||
if self.use_gpu and torch.cuda.is_available():
|
||||
if hasattr(self.criterion, "cuda"):
|
||||
self.criterion = self.criterion.cuda()
|
||||
|
||||
def _create_schedulers_if_available(self):
|
||||
# Learning rate schedules are optional.
|
||||
@@ -138,11 +150,13 @@ class TorchRunner:
|
||||
|
||||
def setup(self):
|
||||
"""Initializes the model."""
|
||||
_remind_gpu_usage(self.use_gpu, is_chief=True)
|
||||
self._initialize_dataloaders()
|
||||
logger.debug("Creating model")
|
||||
self.models = self.model_creator(self.config)
|
||||
if not isinstance(self.models, collections.Iterable):
|
||||
self.models = [self.models]
|
||||
if torch.cuda.is_available():
|
||||
if self.use_gpu and torch.cuda.is_available():
|
||||
self.models = [model.cuda() for model in self.models]
|
||||
|
||||
logger.debug("Creating optimizer")
|
||||
@@ -153,7 +167,6 @@ class TorchRunner:
|
||||
self._create_schedulers_if_available()
|
||||
self._try_setup_apex()
|
||||
self._create_loss()
|
||||
self._initialize_dataloaders()
|
||||
self.training_operator = self.training_operator_cls(
|
||||
self.config,
|
||||
models=self.models,
|
||||
@@ -163,6 +176,7 @@ class TorchRunner:
|
||||
validation_loader=self.validation_loader,
|
||||
world_rank=0,
|
||||
schedulers=self.schedulers,
|
||||
use_gpu=self.use_gpu,
|
||||
use_fp16=self.use_fp16,
|
||||
use_tqdm=self.use_tqdm)
|
||||
|
||||
|
||||
@@ -34,12 +34,13 @@ class TorchTrainer:
|
||||
"""Train a PyTorch model using distributed PyTorch.
|
||||
|
||||
Launches a set of actors which connect via distributed PyTorch and
|
||||
coordinate gradient updates to train the provided model.
|
||||
coordinate gradient updates to train the provided model. If Ray is not
|
||||
initialized, TorchTrainer will automatically initialize a local Ray
|
||||
cluster for you. Be sure to run `ray.init(address="auto")` to leverage
|
||||
multi-node training.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
ray.init()
|
||||
|
||||
def model_creator(config):
|
||||
return nn.Linear(1, 1)
|
||||
|
||||
@@ -116,6 +117,9 @@ class TorchTrainer:
|
||||
support "nccl", "gloo", and "auto". If "auto", RaySGD will
|
||||
automatically use "nccl" if `use_gpu` is True, and "gloo"
|
||||
otherwise.
|
||||
add_dist_sampler (bool): Whether to automatically add a
|
||||
DistributedSampler to all created dataloaders. Only applicable
|
||||
if num_workers > 1.
|
||||
use_fp16 (bool): Enables mixed precision training via apex if apex
|
||||
is installed. This is automatically done after the model and
|
||||
optimizers are constructed and will work for multi-model training.
|
||||
@@ -148,11 +152,12 @@ class TorchTrainer:
|
||||
initialization_hook=None,
|
||||
config=None,
|
||||
num_workers=1,
|
||||
use_gpu=False,
|
||||
use_gpu="auto",
|
||||
backend="auto",
|
||||
use_fp16=False,
|
||||
use_tqdm=False,
|
||||
apex_args=None,
|
||||
add_dist_sampler=True,
|
||||
scheduler_step_freq="batch",
|
||||
num_replicas=None,
|
||||
batch_size=None,
|
||||
@@ -202,19 +207,20 @@ class TorchTrainer:
|
||||
|
||||
self.initialization_hook = initialization_hook
|
||||
self.config = {} if config is None else config
|
||||
if use_gpu == "auto":
|
||||
use_gpu = torch.cuda.is_available()
|
||||
|
||||
if backend == "auto":
|
||||
backend = "nccl" if use_gpu else "gloo"
|
||||
|
||||
logger.debug("Using {} as backend.".format(backend))
|
||||
self.backend = backend
|
||||
|
||||
# TODO: Have an auto "use_gpu" option to detect and use GPUs.
|
||||
self.use_gpu = use_gpu
|
||||
self.max_replicas = num_workers
|
||||
|
||||
self.use_fp16 = use_fp16
|
||||
self.use_tqdm = use_tqdm
|
||||
self.add_dist_sampler = add_dist_sampler
|
||||
|
||||
if apex_args and not isinstance(apex_args, dict):
|
||||
raise ValueError("apex_args needs to be a dict object.")
|
||||
@@ -230,6 +236,11 @@ class TorchTrainer:
|
||||
_validate_scheduler_step_freq(scheduler_step_freq)
|
||||
self.scheduler_step_freq = scheduler_step_freq
|
||||
|
||||
if not ray.is_initialized() and self.max_replicas > 1:
|
||||
logger.info("Automatically initializing single-node Ray. To use "
|
||||
"multi-node training, be sure to run `ray.init("
|
||||
"address='auto')` before instantiating the Trainer.")
|
||||
ray.init()
|
||||
self._start_workers(self.max_replicas)
|
||||
|
||||
def _configure_and_split_batch(self, num_workers):
|
||||
@@ -259,39 +270,29 @@ class TorchTrainer:
|
||||
if batch_size_per_worker:
|
||||
worker_config[BATCH_SIZE] = batch_size_per_worker
|
||||
|
||||
params = dict(
|
||||
model_creator=self.model_creator,
|
||||
data_creator=self.data_creator,
|
||||
optimizer_creator=self.optimizer_creator,
|
||||
loss_creator=self.loss_creator,
|
||||
scheduler_creator=self.scheduler_creator,
|
||||
training_operator_cls=self.training_operator_cls,
|
||||
config=worker_config,
|
||||
use_fp16=self.use_fp16,
|
||||
use_gpu=self.use_gpu,
|
||||
use_tqdm=self.use_tqdm,
|
||||
apex_args=self.apex_args,
|
||||
scheduler_step_freq=self.scheduler_step_freq)
|
||||
|
||||
if num_workers == 1:
|
||||
# Start local worker
|
||||
self.local_worker = TorchRunner(
|
||||
model_creator=self.model_creator,
|
||||
data_creator=self.data_creator,
|
||||
optimizer_creator=self.optimizer_creator,
|
||||
loss_creator=self.loss_creator,
|
||||
scheduler_creator=self.scheduler_creator,
|
||||
training_operator_cls=self.training_operator_cls,
|
||||
config=worker_config,
|
||||
use_fp16=self.use_fp16,
|
||||
use_tqdm=self.use_tqdm,
|
||||
apex_args=self.apex_args,
|
||||
scheduler_step_freq=self.scheduler_step_freq)
|
||||
|
||||
self.local_worker = TorchRunner(**params)
|
||||
if self.initialization_hook:
|
||||
self.apply_all_workers(self.initialization_hook)
|
||||
|
||||
self.local_worker.setup()
|
||||
else:
|
||||
params = dict(
|
||||
model_creator=self.model_creator,
|
||||
data_creator=self.data_creator,
|
||||
optimizer_creator=self.optimizer_creator,
|
||||
loss_creator=self.loss_creator,
|
||||
scheduler_creator=self.scheduler_creator,
|
||||
backend=self.backend,
|
||||
training_operator_cls=self.training_operator_cls,
|
||||
config=worker_config,
|
||||
use_fp16=self.use_fp16,
|
||||
use_tqdm=self.use_tqdm,
|
||||
apex_args=self.apex_args,
|
||||
scheduler_step_freq=self.scheduler_step_freq)
|
||||
params.update(
|
||||
backend=self.backend, add_dist_sampler=self.add_dist_sampler)
|
||||
|
||||
# Start local worker
|
||||
self.local_worker = LocalDistributedRunner(
|
||||
@@ -455,7 +456,11 @@ class TorchTrainer:
|
||||
local_call = self.local_worker.apply_operator(fn)
|
||||
return [local_call] + ray.get(remote_calls)
|
||||
|
||||
def validate(self, num_steps=None, profile=False, info=None):
|
||||
def validate(self,
|
||||
num_steps=None,
|
||||
profile=False,
|
||||
reduce_results=True,
|
||||
info=None):
|
||||
"""Evaluates the model on the validation data set.
|
||||
|
||||
Args:
|
||||
@@ -463,6 +468,10 @@ class TorchTrainer:
|
||||
This corresponds also to the number of times
|
||||
``TrainingOperator.validate_batch`` is called.
|
||||
profile (bool): Returns time stats for the evaluation procedure.
|
||||
reduce_results (bool): Whether to average all metrics across
|
||||
all workers into one dict. If a metric is a non-numerical
|
||||
value (or nested dictionaries), one value will be randomly
|
||||
selected among the workers. If False, returns a list of dicts.
|
||||
info (dict): Optional dictionary passed to the training
|
||||
operator for `validate` and `validate_batch`.
|
||||
|
||||
@@ -477,8 +486,12 @@ class TorchTrainer:
|
||||
w.validate.remote(**params) for w in self.remote_workers
|
||||
]
|
||||
local_worker_stats = self.local_worker.validate(**params)
|
||||
return self._process_stats([local_worker_stats] +
|
||||
ray.get(remote_worker_stats))
|
||||
worker_stats = [local_worker_stats] + ray.get(remote_worker_stats)
|
||||
|
||||
if reduce_results:
|
||||
return self._process_stats(worker_stats)
|
||||
else:
|
||||
return worker_stats
|
||||
|
||||
def update_scheduler(self, metric):
|
||||
"""Calls ``scheduler.step(metric)`` on all schedulers.
|
||||
@@ -497,6 +510,17 @@ class TorchTrainer:
|
||||
return unwrapped[0]
|
||||
return unwrapped
|
||||
|
||||
def get_local_operator(self):
    """Return the TrainingOperator owned by the local worker.

    The live object is handed back, not a copy. Avoid perturbing its
    state, or the system can end up in an inconsistent state.

    Returns:
        TrainingOperator: The local TrainingOperator object.
    """
    local_worker = self.local_worker
    return local_worker.training_operator
|
||||
|
||||
def state_dict(self):
|
||||
return self.local_worker.state_dict()
|
||||
|
||||
|
||||
@@ -60,6 +60,7 @@ class TrainingOperator:
|
||||
world_rank,
|
||||
criterion=None,
|
||||
schedulers=None,
|
||||
use_gpu=False,
|
||||
use_fp16=False,
|
||||
use_tqdm=False):
|
||||
# You are not expected to override this method.
|
||||
@@ -80,6 +81,8 @@ class TrainingOperator:
|
||||
type(schedulers)))
|
||||
self._config = config
|
||||
self._use_fp16 = use_fp16
|
||||
self._use_gpu = use_gpu and torch.cuda.is_available()
|
||||
self._device = torch.device("cuda" if self._use_gpu else "cpu")
|
||||
if tqdm is None and use_tqdm:
|
||||
raise ValueError("tqdm must be installed to use tqdm in training.")
|
||||
self._use_tqdm = use_tqdm
|
||||
@@ -324,13 +327,18 @@ class TrainingOperator:
|
||||
}
|
||||
|
||||
def state_dict(self):
|
||||
"""Returns a serializable representation of the operator state."""
|
||||
"""Override this to return a representation of the operator state."""
|
||||
pass
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Loads a serializable representation of the operator state."""
|
||||
"""Override this to load the representation of the operator state."""
|
||||
pass
|
||||
|
||||
@property
def device(self):
    """The torch device, at your convenience.

    Read-only view of ``self._device``, which the constructor sets to a
    ``"cuda"`` device when ``use_gpu`` was requested and CUDA is available,
    and to ``"cpu"`` otherwise.
    """
    return self._device
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
"""Dictionary as provided into TorchTrainer."""
|
||||
|
||||
Reference in New Issue
Block a user