[raysgd] Improve raysgd examples (#7818)

* better_example

* test

* improve some usability things

* submit

* fix

* flake

* Update python/ray/util/sgd/torch/training_operator.py

* trythis

* fix

* fix

* smoke

* fail

* fix

* fix
This commit is contained in:
Richard Liaw
2020-04-01 08:58:39 -07:00
committed by GitHub
parent f4239d27fa
commit 24bf6ad607
11 changed files with 346 additions and 200 deletions
+13 -8
View File
@@ -6,6 +6,7 @@ import time
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader
import ray
from ray import tune
@@ -269,7 +270,7 @@ def test_split_batch(ray_start_2_cpus):
def data_creator(config):
"""Returns training dataloader, validation dataloader."""
train_dataset = LinearDataset(2, 5, size=config["data_size"])
return torch.utils.data.DataLoader(
return DataLoader(
train_dataset,
batch_size=config[BATCH_SIZE],
)
@@ -301,7 +302,10 @@ def test_reduce_result(ray_start_2_cpus):
def data_creator(config):
"""Returns training dataloader, validation dataloader."""
train_dataset = LinearDataset(2, 5, size=config["data_size"])
return torch.utils.data.DataLoader(train_dataset, batch_size=1)
test_dataset = LinearDataset(2, 5, size=config["data_size"])
return DataLoader(
train_dataset, batch_size=1), DataLoader(
test_dataset, batch_size=1)
data_size = 600
@@ -316,6 +320,10 @@ def test_reduce_result(ray_start_2_cpus):
assert len(list_stats) == 2
assert [stats[NUM_SAMPLES] == data_size for stats in list_stats]
assert [stats[BATCH_COUNT] == (data_size // 2) for stats in list_stats]
list_stats = trainer.validate(reduce_results=False, profile=True)
assert len(list_stats) == 2
assert [stats[NUM_SAMPLES] == data_size for stats in list_stats]
assert [stats[BATCH_COUNT] == (data_size // 2) for stats in list_stats]
trainer.shutdown()
@@ -501,8 +509,7 @@ def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
def single_loader(config):
dataset = LinearDataset(2, 5, size=1000000)
return torch.utils.data.DataLoader(
dataset, batch_size=config.get("batch_size", 32))
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
def step_with_fail(self, **params):
remote_worker_stats = [
@@ -545,8 +552,7 @@ def test_resize(ray_start_2_cpus): # noqa: F811
def single_loader(config):
dataset = LinearDataset(2, 5, size=1000000)
return torch.utils.data.DataLoader(
dataset, batch_size=config.get("batch_size", 32))
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
def step_with_fail(self, **params):
remote_worker_stats = [
@@ -595,8 +601,7 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811
def single_loader(config):
dataset = LinearDataset(2, 5, size=1000000)
return torch.utils.data.DataLoader(
dataset, batch_size=config.get("batch_size", 32))
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
def step_with_fail(self, **params):
remote_worker_stats = [
@@ -12,7 +12,7 @@ from torch.utils.data.distributed import DistributedSampler
from ray.util.sgd.torch.constants import NCCL_TIMEOUT_IN_SECONDS
import ray
from ray.util.sgd.torch.torch_runner import TorchRunner
from ray.util.sgd.torch.torch_runner import TorchRunner, _remind_gpu_usage
logger = logging.getLogger(__name__)
@@ -23,16 +23,19 @@ class DistributedTorchRunner(TorchRunner):
Args:
args: Arguments for TorchRunner.
backend (string): Backend used by distributed PyTorch.
backend (str): Backend used by distributed PyTorch.
add_dist_sampler (bool): Whether to automatically add a
DistributedSampler to all created dataloaders.
kwargs: Keyword arguments for TorchRunner.
"""
def __init__(self, *args, backend="gloo", **kwargs):
def __init__(self, *args, backend="gloo", add_dist_sampler=True, **kwargs):
super(DistributedTorchRunner, self).__init__(*args, **kwargs)
if backend not in ("gloo", "nccl"):
raise ValueError("Backend must be one of 'gloo' or 'nccl'.")
self.backend = backend
self.add_dist_sampler = add_dist_sampler
def setup(self, url, world_rank, world_size):
"""Connects to the distributed PyTorch backend and initializes the model.
@@ -42,6 +45,7 @@ class DistributedTorchRunner(TorchRunner):
world_rank (int): the index of the runner.
world_size (int): the total number of runners.
"""
_remind_gpu_usage(self.use_gpu, is_chief=world_rank == 0)
self._setup_distributed_pytorch(url, world_rank, world_size)
self._setup_training()
@@ -65,14 +69,25 @@ class DistributedTorchRunner(TorchRunner):
world_size=world_size,
timeout=timeout)
self.device_ids = None
if self.use_gpu and torch.cuda.is_available():
# https://github.com/allenai/allennlp/issues/1090
self._set_cuda_device_id()
def _set_cuda_device_id(self):
self.device_ids = [0]
def _setup_training(self):
logger.debug("Loading data.")
self._initialize_dataloaders()
logger.debug("Creating model")
self.models = self.model_creator(self.config)
if not isinstance(self.models, collections.Iterable):
self.models = [self.models]
assert all(isinstance(model, nn.Module) for model in self.models), (
"All models must be PyTorch models: {}.".format(self.models))
if torch.cuda.is_available():
if self.use_gpu and torch.cuda.is_available():
self.models = [model.cuda() for model in self.models]
logger.debug("Creating optimizer.")
@@ -83,11 +98,14 @@ class DistributedTorchRunner(TorchRunner):
self._create_schedulers_if_available()
self._try_setup_apex()
# This needs to happen after apex
self.models = [DistributedDataParallel(model) for model in self.models]
self.models = [
DistributedDataParallel(model, device_ids=self.device_ids)
for model in self.models
]
self._create_loss()
self._initialize_dataloaders()
self.training_operator = self.training_operator_cls(
self.config,
@@ -98,6 +116,7 @@ class DistributedTorchRunner(TorchRunner):
validation_loader=self.validation_loader,
world_rank=self.world_rank,
schedulers=self.schedulers,
use_gpu=self.use_gpu,
use_fp16=self.use_fp16,
use_tqdm=self.use_tqdm)
@@ -121,11 +140,13 @@ class DistributedTorchRunner(TorchRunner):
return DataLoader(**data_loader_args)
if isinstance(self.train_loader, DataLoader):
self.train_loader = with_sampler(self.train_loader)
if self.add_dist_sampler:
self.train_loader = with_sampler(self.train_loader)
if self.validation_loader and isinstance(self.validation_loader,
DataLoader):
self.validation_loader = with_sampler(self.validation_loader)
if self.add_dist_sampler:
self.validation_loader = with_sampler(self.validation_loader)
def train_epoch(self, **kwargs):
"""Runs a training epoch and updates the model parameters.
@@ -190,10 +211,22 @@ class LocalDistributedRunner(DistributedTorchRunner):
num_gpus=num_gpus,
resources={"node:" + ip: 0.1})(_DummyActor).remote()
head_cuda = ray.get(_dummy_actor.cuda_devices.remote())
os.environ["CUDA_VISIBLE_DEVICES"] = head_cuda
self.local_device = ray.get(_dummy_actor.cuda_devices.remote())
# This is a pretty annoying workaround. To enable SyncBatchNorm,
# we need to signify that we are using only 1 CUDA device (via
# the DDP constructor). However, on the local worker,
# we set the CUDA_VISIBLE_DEVICES at runtime rather at process
# start. This means that we have to make sure that DDP knows which
# specific device we are using.
os.environ["CUDA_VISIBLE_DEVICES"] = self.local_device
if self.local_device:
torch.cuda.set_device(int(self.local_device))
super(LocalDistributedRunner, self).__init__(*args, **kwargs)
def _set_cuda_device_id(self):
self.device_ids = [int(self.local_device)]
def shutdown(self, cleanup=True):
super(LocalDistributedRunner, self).shutdown()
global _dummy_actor
@@ -1,18 +1,14 @@
import numpy as np
import os
import torch
import torch.nn as nn
import argparse
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
from torch.utils.data import DataLoader, Subset
import torchvision
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from tqdm import trange
import ray
from ray.tune import CLIReporter
from ray.util.sgd.torch import TorchTrainer
from ray.util.sgd.torch.resnet import ResNet18
from ray.util.sgd.utils import BATCH_SIZE
@@ -42,12 +38,12 @@ def cifar_creator(config):
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
])
train_dataset = torchvision.datasets.CIFAR10(
train_dataset = CIFAR10(
root="~/data", train=True, download=True, transform=transform_train)
validation_dataset = torchvision.datasets.CIFAR10(
validation_dataset = CIFAR10(
root="~/data", train=False, download=False, transform=transform_test)
if config.get("test_mode"):
if config["test_mode"]:
train_dataset = Subset(train_dataset, list(range(64)))
validation_dataset = Subset(validation_dataset, list(range(64)))
@@ -71,95 +67,6 @@ def scheduler_creator(optimizer, config):
optimizer, milestones=[150, 250, 350], gamma=0.1)
def train_example(num_workers=1,
num_epochs=5,
use_gpu=False,
use_fp16=False,
test_mode=False):
trainer1 = TorchTrainer(
model_creator=ResNet18,
data_creator=cifar_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.CrossEntropyLoss,
scheduler_creator=scheduler_creator,
initialization_hook=initialization_hook,
num_workers=num_workers,
config={
"lr": 0.1,
"test_mode": test_mode, # user-defined param to subset the data
BATCH_SIZE: 128 * num_workers # this will be split across workers.
},
use_gpu=use_gpu,
scheduler_step_freq="epoch",
use_fp16=use_fp16,
use_tqdm=True)
pbar = trange(num_epochs, unit="epoch")
for i in pbar:
info = {"num_steps": 1} if test_mode else {}
info["epoch_idx"] = i
info["num_epochs"] = num_epochs
# Increase `max_retries` to turn on fault tolerance.
trainer1.train(max_retries=1, info=info)
val_stats = trainer1.validate()
pbar.set_postfix(dict(acc=val_stats["val_accuracy"]))
print(trainer1.validate())
trainer1.shutdown()
print("success!")
def tune_example(num_workers=1, use_gpu=False, use_fp16=False,
test_mode=False):
TorchTrainable = TorchTrainer.as_trainable(
model_creator=ResNet18,
data_creator=cifar_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.CrossEntropyLoss,
scheduler_creator=scheduler_creator,
initialization_hook=initialization_hook,
num_workers=num_workers,
config={
"test_mode": test_mode, # user-defined param to subset the data
BATCH_SIZE: 128 * num_workers,
},
use_gpu=use_gpu,
scheduler_step_freq="epoch",
use_fp16=use_fp16)
pbt_scheduler = PopulationBasedTraining(
time_attr="training_iteration",
metric="val_loss",
mode="min",
perturbation_interval=1,
hyperparam_mutations={
# distribution for resampling
"lr": lambda: np.random.uniform(0.001, 1),
# allow perturbations within this set of categorical values
"momentum": [0.8, 0.9, 0.99],
})
reporter = CLIReporter()
reporter.add_metric_column("val_loss", "loss")
reporter.add_metric_column("val_accuracy", "acc")
analysis = tune.run(
TorchTrainable,
num_samples=4,
config={
"lr": tune.choice([0.001, 0.01, 0.1]),
"momentum": 0.8
},
stop={"training_iteration": 2 if test_mode else 100},
max_failures=3, # used for fault tolerance
checkpoint_freq=3, # used for fault tolerance
keep_checkpoints_num=1, # used for fault tolerance
verbose=2,
progress_reporter=reporter,
scheduler=pbt_scheduler)
return analysis.get_best_config(metric="val_loss", mode="min")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -194,18 +101,37 @@ if __name__ == "__main__":
"--tune", action="store_true", default=False, help="Tune training")
args, _ = parser.parse_known_args()
num_cpus = 4 if args.smoke_test else None
ray.init(address=args.address, num_cpus=num_cpus, log_to_driver=True)
ray.init(address=args.address, log_to_driver=True)
trainer1 = TorchTrainer(
model_creator=ResNet18,
data_creator=cifar_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.CrossEntropyLoss,
scheduler_creator=scheduler_creator,
initialization_hook=initialization_hook,
num_workers=args.num_workers,
config={
"lr": 0.1,
"test_mode": args.smoke_test, # subset the data
# this will be split across workers.
BATCH_SIZE: 128 * args.num_workers
},
use_gpu=args.use_gpu,
scheduler_step_freq="epoch",
use_fp16=args.fp16,
use_tqdm=True)
pbar = trange(args.num_epochs, unit="epoch")
for i in pbar:
info = {"num_steps": 1} if args.smoke_test else {}
info["epoch_idx"] = i
info["num_epochs"] = args.num_epochs
# Increase `max_retries` to turn on fault tolerance.
trainer1.train(max_retries=1, info=info)
val_stats = trainer1.validate()
pbar.set_postfix(dict(acc=val_stats["val_accuracy"]))
if args.tune:
tune_example(
num_workers=args.num_workers,
use_gpu=args.use_gpu,
test_mode=args.smoke_test)
else:
train_example(
num_workers=args.num_workers,
num_epochs=args.num_epochs,
use_gpu=args.use_gpu,
use_fp16=args.fp16,
test_mode=args.smoke_test)
print(trainer1.validate())
trainer1.shutdown()
print("success!")
@@ -0,0 +1,148 @@
import numpy as np
import os
import torch
import torch.nn as nn
import argparse
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
import ray
from ray.tune import CLIReporter
from ray.util.sgd.torch import TorchTrainer
from ray.util.sgd.torch.resnet import ResNet18
from ray.util.sgd.utils import BATCH_SIZE
def initialization_hook():
# Need this for avoiding a connection restart issue on AWS.
os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo"
os.environ["NCCL_LL_THRESHOLD"] = "0"
# set the below if needed
# print("NCCL DEBUG SET")
# os.environ["NCCL_DEBUG"] = "INFO"
def cifar_creator(config):
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
]) # meanstd transformation
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
])
train_dataset = CIFAR10(
root="~/data", train=True, download=True, transform=transform_train)
validation_dataset = CIFAR10(
root="~/data", train=False, download=False, transform=transform_test)
if config.get("test_mode"):
train_dataset = Subset(train_dataset, list(range(64)))
validation_dataset = Subset(validation_dataset, list(range(64)))
train_loader = DataLoader(
train_dataset, batch_size=config[BATCH_SIZE], num_workers=2)
validation_loader = DataLoader(
validation_dataset, batch_size=config[BATCH_SIZE], num_workers=2)
return train_loader, validation_loader
def optimizer_creator(model, config):
"""Returns optimizer"""
return torch.optim.SGD(
model.parameters(),
lr=config.get("lr", 0.1),
momentum=config.get("momentum", 0.9))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--address",
required=False,
type=str,
help="the address to use for Redis")
parser.add_argument(
"--num-workers",
"-n",
type=int,
default=1,
help="Sets number of workers for training.")
parser.add_argument(
"--num-epochs", type=int, default=5, help="Number of epochs to train.")
parser.add_argument(
"--use-gpu",
action="store_true",
default=False,
help="Enables GPU training")
parser.add_argument(
"--fp16",
action="store_true",
default=False,
help="Enables FP16 training with apex. Requires `use-gpu`.")
parser.add_argument(
"--smoke-test",
action="store_true",
default=False,
help="Finish quickly for testing.")
parser.add_argument(
"--tune", action="store_true", default=False, help="Tune training")
args, _ = parser.parse_known_args()
ray.init(address=args.address, log_to_driver=True)
TorchTrainable = TorchTrainer.as_trainable(
model_creator=ResNet18,
data_creator=cifar_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.CrossEntropyLoss,
initialization_hook=initialization_hook,
num_workers=args.num_workers,
config={
"test_mode": args.smoke_test, # whether to to subset the data
BATCH_SIZE: 128 * args.num_workers,
},
use_gpu=args.use_gpu,
use_fp16=args.fp16)
pbt_scheduler = PopulationBasedTraining(
time_attr="training_iteration",
metric="val_loss",
mode="min",
perturbation_interval=1,
hyperparam_mutations={
# distribution for resampling
"lr": lambda: np.random.uniform(0.001, 1),
# allow perturbations within this set of categorical values
"momentum": [0.8, 0.9, 0.99],
})
reporter = CLIReporter()
reporter.add_metric_column("val_loss", "loss")
reporter.add_metric_column("val_accuracy", "acc")
analysis = tune.run(
TorchTrainable,
num_samples=4,
config={
"lr": tune.choice([0.001, 0.01, 0.1]),
"momentum": 0.8
},
stop={"training_iteration": 2 if args.smoke_test else 100},
max_failures=3, # used for fault tolerance
checkpoint_freq=3, # used for fault tolerance
keep_checkpoints_num=1, # used for fault tolerance
verbose=2,
progress_reporter=reporter,
scheduler=pbt_scheduler)
print(analysis.get_best_config(metric="val_loss", mode="min"))
+1 -3
View File
@@ -128,9 +128,6 @@ def optimizer_creator(models, config):
class GANOperator(TrainingOperator):
def setup(self, config):
self.device = torch.device("cuda"
if torch.cuda.is_available() else "cpu")
self.classifier = LeNet()
self.classifier.load_state_dict(
torch.load(config["classification_model_path"]))
@@ -183,6 +180,7 @@ class GANOperator(TrainingOperator):
# Compute a discriminator update for real images
discriminator.zero_grad()
# self.device is set automatically
real_cpu = batch[0].to(self.device)
batch_size = real_cpu.size(0)
label = torch.full((batch_size, ), real_label, device=self.device)
@@ -7,31 +7,15 @@ but we put comments right after code blocks to prevent large white spaces
in the documentation.
"""
# __torch_tune_example__
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import ray
from ray import tune
from ray.util.sgd.torch import TorchTrainer
from ray.util.sgd.utils import BATCH_SIZE
class LinearDataset(torch.utils.data.Dataset):
"""y = a * x + b"""
def __init__(self, a, b, size=1000):
x = np.random.random(size).astype(np.float32) * 10
x = np.arange(0, 10, 10 / size, dtype=np.float32)
self.x = torch.from_numpy(x)
self.y = torch.from_numpy(a * x + b)
def __getitem__(self, index):
return self.x[index, None], self.y[index, None]
def __len__(self):
return len(self.x)
from ray.util.sgd.torch.examples.train_example import LinearDataset
def model_creator(config):
@@ -47,22 +31,18 @@ def data_creator(config):
"""Returns training dataloader, validation dataloader."""
train_dataset = LinearDataset(2, 5)
val_dataset = LinearDataset(2, 5, size=400)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=config[BATCH_SIZE],
)
validation_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=config[BATCH_SIZE])
train_loader = DataLoader(train_dataset, batch_size=config[BATCH_SIZE])
validation_loader = DataLoader(val_dataset, batch_size=config[BATCH_SIZE])
return train_loader, validation_loader
# __torch_tune_example__
def tune_example(num_workers=1, use_gpu=False):
TorchTrainable = TorchTrainer.as_trainable(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.MSELoss,
loss_creator=nn.MSELoss, # Note that we specify a Loss class.
num_workers=num_workers,
use_gpu=use_gpu,
config={BATCH_SIZE: 128}
@@ -76,6 +56,7 @@ def tune_example(num_workers=1, use_gpu=False):
verbose=1)
return analysis.get_best_config(metric="validation_loss", mode="min")
# __end_torch_tune_example__
if __name__ == "__main__":
+18 -4
View File
@@ -23,6 +23,14 @@ except ImportError:
pass
def _remind_gpu_usage(use_gpu, is_chief):
if not is_chief:
return
if not use_gpu and torch.cuda.is_available():
logger.info("GPUs detected but not using them. Set `use_gpu` to "
"enable GPU usage. ")
class TorchRunner:
"""Manages a PyTorch model for training.
@@ -36,6 +44,7 @@ class TorchRunner:
torch_trainer.py.
training_operator_cls: see torch_trainer.py
config (dict): see torch_trainer.py.
use_gpu (bool): see torch_trainer.py.
use_fp16 (bool): see torch_trainer.py.
apex_args (dict|None): see torch_trainer.py.
scheduler_step_freq (str): see torch_trainer.py.
@@ -49,6 +58,7 @@ class TorchRunner:
scheduler_creator=None,
training_operator_cls=None,
config=None,
use_gpu=False,
use_fp16=False,
use_tqdm=False,
apex_args=None,
@@ -69,6 +79,7 @@ class TorchRunner:
self.schedulers = None
self.train_loader = None
self.validation_loader = None
self.use_gpu = use_gpu
self.use_fp16 = use_fp16
self.use_tqdm = use_tqdm
self.apex_args = apex_args or {}
@@ -117,8 +128,9 @@ class TorchRunner:
else:
self.criterion = self.loss_creator(self.config)
if torch.cuda.is_available() and hasattr(self.criterion, "cuda"):
self.criterion = self.criterion.cuda()
if self.use_gpu and torch.cuda.is_available():
if hasattr(self.criterion, "cuda"):
self.criterion = self.criterion.cuda()
def _create_schedulers_if_available(self):
# Learning rate schedules are optional.
@@ -138,11 +150,13 @@ class TorchRunner:
def setup(self):
"""Initializes the model."""
_remind_gpu_usage(self.use_gpu, is_chief=True)
self._initialize_dataloaders()
logger.debug("Creating model")
self.models = self.model_creator(self.config)
if not isinstance(self.models, collections.Iterable):
self.models = [self.models]
if torch.cuda.is_available():
if self.use_gpu and torch.cuda.is_available():
self.models = [model.cuda() for model in self.models]
logger.debug("Creating optimizer")
@@ -153,7 +167,6 @@ class TorchRunner:
self._create_schedulers_if_available()
self._try_setup_apex()
self._create_loss()
self._initialize_dataloaders()
self.training_operator = self.training_operator_cls(
self.config,
models=self.models,
@@ -163,6 +176,7 @@ class TorchRunner:
validation_loader=self.validation_loader,
world_rank=0,
schedulers=self.schedulers,
use_gpu=self.use_gpu,
use_fp16=self.use_fp16,
use_tqdm=self.use_tqdm)
+60 -36
View File
@@ -34,12 +34,13 @@ class TorchTrainer:
"""Train a PyTorch model using distributed PyTorch.
Launches a set of actors which connect via distributed PyTorch and
coordinate gradient updates to train the provided model.
coordinate gradient updates to train the provided model. If Ray is not
initialized, TorchTrainer will automatically initialize a local Ray
cluster for you. Be sure to run `ray.init(address="auto")` to leverage
multi-node training.
.. code-block:: python
ray.init()
def model_creator(config):
return nn.Linear(1, 1)
@@ -116,6 +117,9 @@ class TorchTrainer:
support "nccl", "gloo", and "auto". If "auto", RaySGD will
automatically use "nccl" if `use_gpu` is True, and "gloo"
otherwise.
add_dist_sampler (bool): Whether to automatically add a
DistributedSampler to all created dataloaders. Only applicable
if num_workers > 1.
use_fp16 (bool): Enables mixed precision training via apex if apex
is installed. This is automatically done after the model and
optimizers are constructed and will work for multi-model training.
@@ -148,11 +152,12 @@ class TorchTrainer:
initialization_hook=None,
config=None,
num_workers=1,
use_gpu=False,
use_gpu="auto",
backend="auto",
use_fp16=False,
use_tqdm=False,
apex_args=None,
add_dist_sampler=True,
scheduler_step_freq="batch",
num_replicas=None,
batch_size=None,
@@ -202,19 +207,20 @@ class TorchTrainer:
self.initialization_hook = initialization_hook
self.config = {} if config is None else config
if use_gpu == "auto":
use_gpu = torch.cuda.is_available()
if backend == "auto":
backend = "nccl" if use_gpu else "gloo"
logger.debug("Using {} as backend.".format(backend))
self.backend = backend
# TODO: Have an auto "use_gpu" option to detect and use GPUs.
self.use_gpu = use_gpu
self.max_replicas = num_workers
self.use_fp16 = use_fp16
self.use_tqdm = use_tqdm
self.add_dist_sampler = add_dist_sampler
if apex_args and not isinstance(apex_args, dict):
raise ValueError("apex_args needs to be a dict object.")
@@ -230,6 +236,11 @@ class TorchTrainer:
_validate_scheduler_step_freq(scheduler_step_freq)
self.scheduler_step_freq = scheduler_step_freq
if not ray.is_initialized() and self.max_replicas > 1:
logger.info("Automatically initializing single-node Ray. To use "
"multi-node training, be sure to run `ray.init("
"address='auto')` before instantiating the Trainer.")
ray.init()
self._start_workers(self.max_replicas)
def _configure_and_split_batch(self, num_workers):
@@ -259,39 +270,29 @@ class TorchTrainer:
if batch_size_per_worker:
worker_config[BATCH_SIZE] = batch_size_per_worker
params = dict(
model_creator=self.model_creator,
data_creator=self.data_creator,
optimizer_creator=self.optimizer_creator,
loss_creator=self.loss_creator,
scheduler_creator=self.scheduler_creator,
training_operator_cls=self.training_operator_cls,
config=worker_config,
use_fp16=self.use_fp16,
use_gpu=self.use_gpu,
use_tqdm=self.use_tqdm,
apex_args=self.apex_args,
scheduler_step_freq=self.scheduler_step_freq)
if num_workers == 1:
# Start local worker
self.local_worker = TorchRunner(
model_creator=self.model_creator,
data_creator=self.data_creator,
optimizer_creator=self.optimizer_creator,
loss_creator=self.loss_creator,
scheduler_creator=self.scheduler_creator,
training_operator_cls=self.training_operator_cls,
config=worker_config,
use_fp16=self.use_fp16,
use_tqdm=self.use_tqdm,
apex_args=self.apex_args,
scheduler_step_freq=self.scheduler_step_freq)
self.local_worker = TorchRunner(**params)
if self.initialization_hook:
self.apply_all_workers(self.initialization_hook)
self.local_worker.setup()
else:
params = dict(
model_creator=self.model_creator,
data_creator=self.data_creator,
optimizer_creator=self.optimizer_creator,
loss_creator=self.loss_creator,
scheduler_creator=self.scheduler_creator,
backend=self.backend,
training_operator_cls=self.training_operator_cls,
config=worker_config,
use_fp16=self.use_fp16,
use_tqdm=self.use_tqdm,
apex_args=self.apex_args,
scheduler_step_freq=self.scheduler_step_freq)
params.update(
backend=self.backend, add_dist_sampler=self.add_dist_sampler)
# Start local worker
self.local_worker = LocalDistributedRunner(
@@ -455,7 +456,11 @@ class TorchTrainer:
local_call = self.local_worker.apply_operator(fn)
return [local_call] + ray.get(remote_calls)
def validate(self, num_steps=None, profile=False, info=None):
def validate(self,
num_steps=None,
profile=False,
reduce_results=True,
info=None):
"""Evaluates the model on the validation data set.
Args:
@@ -463,6 +468,10 @@ class TorchTrainer:
This corresponds also to the number of times
``TrainingOperator.validate_batch`` is called.
profile (bool): Returns time stats for the evaluation procedure.
reduce_results (bool): Whether to average all metrics across
all workers into one dict. If a metric is a non-numerical
value (or nested dictionaries), one value will be randomly
selected among the workers. If False, returns a list of dicts.
info (dict): Optional dictionary passed to the training
operator for `validate` and `validate_batch`.
@@ -477,8 +486,12 @@ class TorchTrainer:
w.validate.remote(**params) for w in self.remote_workers
]
local_worker_stats = self.local_worker.validate(**params)
return self._process_stats([local_worker_stats] +
ray.get(remote_worker_stats))
worker_stats = [local_worker_stats] + ray.get(remote_worker_stats)
if reduce_results:
return self._process_stats(worker_stats)
else:
return worker_stats
def update_scheduler(self, metric):
"""Calls ``scheduler.step(metric)`` on all schedulers.
@@ -497,6 +510,17 @@ class TorchTrainer:
return unwrapped[0]
return unwrapped
def get_local_operator(self):
"""Returns the local TrainingOperator object.
Be careful not to perturb its state, or else you can cause the system
to enter an inconsistent state.
Returns:
TrainingOperator: The local TrainingOperator object.
"""
return self.local_worker.training_operator
def state_dict(self):
return self.local_worker.state_dict()
+10 -2
View File
@@ -60,6 +60,7 @@ class TrainingOperator:
world_rank,
criterion=None,
schedulers=None,
use_gpu=False,
use_fp16=False,
use_tqdm=False):
# You are not expected to override this method.
@@ -80,6 +81,8 @@ class TrainingOperator:
type(schedulers)))
self._config = config
self._use_fp16 = use_fp16
self._use_gpu = use_gpu and torch.cuda.is_available()
self._device = torch.device("cuda" if self._use_gpu else "cpu")
if tqdm is None and use_tqdm:
raise ValueError("tqdm must be installed to use tqdm in training.")
self._use_tqdm = use_tqdm
@@ -324,13 +327,18 @@ class TrainingOperator:
}
def state_dict(self):
"""Returns a serializable representation of the operator state."""
"""Override this to return a representation of the operator state."""
pass
def load_state_dict(self, state_dict):
"""Loads a serializable representation of the operator state."""
"""Override this to load the representation of the operator state."""
pass
@property
def device(self):
"""The torch device, at your convenience."""
return self._device
@property
def config(self):
"""Dictionary as provided into TorchTrainer."""