[Ray SGD] LightningModule integration + MNIST Example (#11042)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Amog Kamsetty
2020-10-01 20:07:32 -07:00
committed by GitHub
parent 9dc7b7b11d
commit 874da9a25f
9 changed files with 1120 additions and 118 deletions
+17
View File
@@ -2,6 +2,14 @@
# Tests from the python/ray/util/sgd/tests directory.
# Please keep these sorted alphabetically.
# --------------------------------------------------------------------
py_test(
name = "test_ptl",
size = "small",
srcs = ["tests/test_ptl.py"],
tags = ["exclusive", "pytorch-lightning", "pytorch"],
deps = [":sgd_lib"],
)
py_test(
name = "test_tensorflow",
size = "small",
@@ -214,6 +222,15 @@ py_test(
args = ["--no-gpu", "--mock-data", "--smoke-test", "--ray-num-workers=2", "--model=mobilenetv3_small_075", "data"]
)
py_test(
name = "mnist-ptl",
size = "small",
srcs = ["torch/examples/pytorch-lightning/mnist-ptl.py"],
tags = ["exclusive", "pytorch", "pytorch-lightning"],
deps = [":sgd_lib"],
args = ["--smoke-test"]
)
# This is a dummy test dependency that causes the above tests to be
# re-run if any of these files changes.
py_library(
+211
View File
@@ -0,0 +1,211 @@
import copy
import os
import pytest
import ray
import torch
from ray.util.sgd.utils import BATCH_COUNT
import torch.distributed as dist
from pytorch_lightning import LightningModule
from ray.util.sgd import TorchTrainer
from ray.util.sgd.torch import TrainingOperator
from ray.util.sgd.torch.examples.train_example import \
optimizer_creator, data_creator, scheduler_creator, model_creator
from torch import nn
import numpy as np
torch.manual_seed(0)
np.random.seed(0)
@pytest.fixture
def ray_start_2_cpus():
    """Start Ray with two CPUs for one test and tear it down afterwards.

    Yields:
        Whatever ``ray.init(num_cpus=2)`` returned.
    """
    cluster_info = ray.init(num_cpus=2)
    yield cluster_info
    # Everything after the yield runs as fixture teardown.
    ray.shutdown()
    # Drop any process group still registered so that it does not leak into
    # the tests that run afterwards.
    process_group_alive = dist.is_initialized()
    if process_group_alive:
        dist.destroy_process_group()
class PTL_Module(LightningModule):
    """LightningModule used by these tests, built from the train_example creators.

    The wrapped layer is either a deep copy of ``config["layer"]`` (when that
    key is present) or whatever ``model_creator`` returns. ``rand_int`` is an
    extra piece of state that is written to and read from checkpoints through
    the ``on_save_checkpoint`` / ``on_load_checkpoint`` hooks.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        layer_supplied = "layer" in config
        # A layer handed in through the config is copied, never used directly.
        self.layer = (copy.deepcopy(config["layer"])
                      if layer_supplied else model_creator(self.config))
        self.rand_int = np.random.randint(10)

    def forward(self, x):
        """Apply the wrapped layer to ``x``."""
        return self.layer.forward(x)

    def configure_optimizers(self):
        """Return one optimizer and one scheduler, each wrapped in a list."""
        opt = optimizer_creator(self, self.config)
        sched = scheduler_creator(opt, self.config)
        return [opt], [sched]

    def training_step(self, batch, batch_idx):
        """Return the loss for one training batch."""
        inputs, targets = batch
        predictions = self(inputs)
        return self.loss(predictions, targets)

    def validation_step(self, batch, batch_idx):
        """Return ``val_loss`` and ``val_acc`` for one validation batch."""
        inputs, targets = batch
        predictions = self(inputs)
        batch_loss = self.loss(predictions, targets)
        # Index 1 of torch.max(..., dim) holds the arg-max positions.
        predicted = torch.max(predictions.data, 1)[1]
        hits = (predicted == targets).sum().item()
        total = targets.size(0)
        return {"val_loss": batch_loss.item(), "val_acc": hits / total}

    def setup(self, stage):
        """Create the data loaders and the loss function."""
        loaders = data_creator(self.config)
        self.train_loader, self.val_loader = loaders
        self.loss = nn.MSELoss()

    def train_dataloader(self):
        """Return the loader created in ``setup``."""
        return self.train_loader

    def val_dataloader(self):
        """Return the validation loader created in ``setup``."""
        return self.val_loader

    def on_save_checkpoint(self, checkpoint):
        """Store ``rand_int`` in the checkpoint under the key ``"int"``."""
        checkpoint["int"] = self.rand_int

    def on_load_checkpoint(self, checkpoint):
        """Restore ``rand_int`` from the checkpoint key ``"int"``."""
        self.rand_int = checkpoint["int"]
Operator = TrainingOperator.from_ptl(PTL_Module)
@pytest.mark.parametrize("use_local", [True, False])
def test_single_step(ray_start_2_cpus, use_local):  # noqa: F811
    """One training step and one validation step each report a batch count of 1."""
    trainer = TorchTrainer(
        training_operator_cls=Operator,
        num_workers=1,
        use_local=use_local,
        use_gpu=False)

    train_stats = trainer.train(num_steps=1)
    assert train_stats[BATCH_COUNT] == 1

    validation_stats = trainer.validate(num_steps=1)
    assert validation_stats[BATCH_COUNT] == 1

    trainer.shutdown()
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
@pytest.mark.parametrize("use_local", [True, False])
def test_train(ray_start_2_cpus, num_workers, use_local):  # noqa: F811
    """Losses after six epochs must not exceed the losses after three."""
    trainer = TorchTrainer(
        training_operator_cls=Operator,
        num_workers=num_workers,
        use_local=use_local,
        use_gpu=False)

    def losses_after(rounds):
        # Run `rounds` train/validate cycles and keep only the final losses.
        for _ in range(rounds):
            train_loss = trainer.train()["train_loss"]
            val_loss = trainer.validate()["val_loss"]
        return train_loss, val_loss

    train_loss1, validation_loss1 = losses_after(3)
    train_loss2, validation_loss2 = losses_after(3)

    assert train_loss2 <= train_loss1, (train_loss2, train_loss1)
    assert validation_loss2 <= validation_loss1, (validation_loss2,
                                                  validation_loss1)
    trainer.shutdown()
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
@pytest.mark.parametrize("use_local", [True, False])
def test_save_and_restore(ray_start_2_cpus, num_workers, use_local,
                          tmp_path):  # noqa: F811
    """A checkpoint restores both the model weights and the module's ``rand_int``."""
    trainer_kwargs = {
        "training_operator_cls": Operator,
        "num_workers": num_workers,
        "use_local": use_local,
    }

    # Train once, save, and remember what was saved.
    trainer1 = TorchTrainer(**trainer_kwargs)
    trainer1.train()
    checkpoint_path = os.path.join(tmp_path, "checkpoint")
    trainer1.save(checkpoint_path)
    model1 = trainer1.get_model()
    ints1 = trainer1.apply_all_operators(lambda op: op.get_model().rand_int)[0]
    trainer1.shutdown()

    # Restore into a brand-new trainer.
    trainer2 = TorchTrainer(**trainer_kwargs)
    trainer2.load(checkpoint_path)
    model2 = trainer2.get_model()
    ints2 = trainer2.apply_all_operators(lambda op: op.get_model().rand_int)

    saved_state = model1.state_dict()
    restored_state = model2.state_dict()
    assert set(saved_state.keys()) == set(restored_state.keys())
    for name, saved_tensor in saved_state.items():
        assert torch.equal(saved_tensor, restored_state[name])

    # Every worker of the second trainer must hold the saved integer.
    for restored_int in ints2:
        assert restored_int == ints1
    trainer2.shutdown()
class CorrectnessOperator(TrainingOperator):
    """Plain TrainingOperator wrapping ``PTL_Module``.

    Used by ``test_correctness`` as the reference against which the operator
    produced by ``TrainingOperator.from_ptl`` is compared.
    """

    def setup(self, config):
        """Register the model, optimizer, criterion, scheduler and data loaders."""
        module = PTL_Module(config)
        optimizer = optimizer_creator(module, config)
        lr_scheduler = scheduler_creator(optimizer, config)

        registered = self.register(
            models=module,
            optimizers=optimizer,
            criterion=nn.MSELoss(),
            schedulers=lr_scheduler)
        self.model, self.optimizer, self.criterion, self.scheduler = registered

        loaders = data_creator(config)
        train_loader, val_loader = loaders
        self.register_data(
            train_loader=train_loader, validation_loader=val_loader)
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
@pytest.mark.parametrize("use_local", [True, False])
def test_correctness(ray_start_2_cpus, num_workers, use_local):
    """The from_ptl operator and the hand-written operator report equal metrics."""
    layer = nn.Linear(1, 1)

    def make_config():
        # Build a fresh dict per trainer; both share the same starting layer.
        return {"layer": layer, "data_size": 3, "batch_size": 1}

    # Trainer driven by the operator generated from the LightningModule.
    ptl_op = TrainingOperator.from_ptl(PTL_Module)
    trainer1 = TorchTrainer(
        training_operator_cls=ptl_op,
        config=make_config(),
        num_workers=num_workers,
        use_local=use_local)
    train1_stats = trainer1.train()
    val1_stats = trainer1.validate()
    trainer1.shutdown()

    # Reference trainer driven by the hand-written operator.
    trainer2 = TorchTrainer(
        training_operator_cls=CorrectnessOperator,
        scheduler_step_freq="manual",
        config=make_config(),
        num_workers=num_workers,
        use_local=use_local)
    train2_stats = trainer2.train()
    val2_stats = trainer2.validate()
    trainer2.shutdown()

    assert train1_stats["train_loss"] == train2_stats["train_loss"]
    assert val1_stats["val_loss"] == val2_stats["val_loss"]
    assert val1_stats["val_acc"] == val2_stats["val_accuracy"]
@@ -3,8 +3,6 @@ import os
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, IterableDataset
from torch.utils.data.distributed import DistributedSampler
from ray.util.sgd.torch.utils import setup_process_group
import ray
@@ -42,6 +40,7 @@ class DistributedTorchRunner(TorchRunner):
self.wrap_ddp = wrap_ddp
self.add_dist_sampler = add_dist_sampler
self.world_rank = None
self.local_rank = None
def setup_address(self):
return setup_address()
@@ -61,6 +60,14 @@ class DistributedTorchRunner(TorchRunner):
setup_process_group(
url, world_rank, world_size, timeout, backend=self.backend)
def set_local_rank(self, local_rank):
"""Sets the local rank of this runner.
Args:
local_rank (int): the index of the runner on its node.
"""
self.local_rank = local_rank
def setup_operator(self):
"""Runs distributed coordination components.
@@ -74,13 +81,14 @@ class DistributedTorchRunner(TorchRunner):
self.training_operator = self.training_operator_cls(
self.config,
world_rank=self.world_rank,
local_rank=self.local_rank,
is_distributed=True,
device_ids=device_ids,
use_gpu=self.use_gpu,
use_fp16=self.use_fp16,
use_tqdm=self.use_tqdm,
apex_args=self.apex_args,
wrap_ddp=self.wrap_ddp,
wrap_distributed_sampler=True,
add_dist_sampler=self.add_dist_sampler,
scheduler_step_freq=self.scheduler_step_freq)
@@ -88,45 +96,21 @@ class DistributedTorchRunner(TorchRunner):
"""Needed for SyncBatchNorm, which needs 1 GPU per process."""
return [0]
def _wrap_dataloaders(self):
def with_sampler(loader):
# Automatically set the DistributedSampler
data_loader_args = {
"dataset": loader.dataset,
"batch_size": loader.batch_size,
"shuffle": False,
"num_workers": loader.num_workers,
"collate_fn": loader.collate_fn,
"pin_memory": loader.pin_memory,
"drop_last": loader.drop_last,
"timeout": loader.timeout,
"worker_init_fn": loader.worker_init_fn,
"sampler": DistributedSampler(loader.dataset)
}
return DataLoader(**data_loader_args)
def should_wrap_dataloader(loader):
return (isinstance(loader, DataLoader)
and not isinstance(loader.dataset, IterableDataset))
if should_wrap_dataloader(self.train_loader):
if self.add_dist_sampler:
self.train_loader = with_sampler(self.train_loader)
if self.validation_loader is not None and should_wrap_dataloader(
self.validation_loader):
if self.add_dist_sampler:
self.validation_loader = with_sampler(self.validation_loader)
def train_epoch(self, **kwargs):
def train_epoch(self,
num_steps=None,
profile=False,
info=None,
iterator=None):
"""Runs a training epoch and updates the model parameters.
Automatically sets epoch of sampler if possible.
"""
if hasattr(self.train_loader, "sampler") and hasattr(
if iterator is None and hasattr(self.train_loader, "sampler") and \
hasattr(
self.train_loader.sampler, "set_epoch"):
self.train_loader.sampler.set_epoch(self.epochs)
return super(DistributedTorchRunner, self).train_epoch(**kwargs)
return super(DistributedTorchRunner, self).train_epoch(
num_steps=num_steps, profile=profile, info=info, iterator=iterator)
def shutdown(self):
"""Attempts to shut down the worker."""
@@ -0,0 +1,145 @@
import argparse
import torch
from ray.util.sgd import TorchTrainer
from ray.util.sgd.torch import TrainingOperator
from torch.nn import functional as F
from pytorch_lightning.core.lightning import LightningModule
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import os
from torchvision import transforms
class LitMNIST(LightningModule):
    """Three-layer fully connected MNIST classifier written as a LightningModule.

    ``config`` must provide ``"lr"`` (read in ``configure_optimizers``) and
    ``"batch_size"`` (read by both data loaders).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a 4-D image batch."""
        # Unpacking four values also rejects inputs that are not 4-D.
        batch_size, _channels, _width, _height = x.size()
        # (b, 1, 28, 28) -> (b, 1*28*28)
        flat = x.view(batch_size, -1)
        hidden = torch.relu(self.layer_1(flat))
        hidden = torch.relu(self.layer_2(hidden))
        logits = self.layer_3(hidden)
        return torch.log_softmax(logits, dim=1)

    def configure_optimizers(self):
        """Return an Adam optimizer using the configured learning rate."""
        learning_rate = self.config["lr"]
        return Adam(self.parameters(), lr=learning_rate)

    def setup(self, stage):
        """Download MNIST into the working directory and split it 55000/5000."""
        # Standard MNIST preprocessing: tensor conversion plus normalization.
        preprocessing = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])
        full_train_set = MNIST(
            os.getcwd(), train=True, download=True, transform=preprocessing)
        splits = random_split(full_train_set, [55000, 5000])
        self.mnist_train, self.mnist_val = splits

    def train_dataloader(self):
        """Return a loader over the training split."""
        batch_size = self.config["batch_size"]
        return DataLoader(self.mnist_train, batch_size=batch_size)

    def val_dataloader(self):
        """Return a loader over the validation split."""
        batch_size = self.config["batch_size"]
        return DataLoader(self.mnist_val, batch_size=batch_size)

    def training_step(self, batch, batch_idx):
        """Return the negative log-likelihood loss for one batch."""
        images, labels = batch
        log_probs = self(images)
        return F.nll_loss(log_probs, labels)

    def validation_step(self, batch, batch_idx):
        """Return ``val_loss`` and ``val_acc`` for one batch."""
        images, labels = batch
        log_probs = self(images)
        batch_loss = F.nll_loss(log_probs, labels)
        # Index 1 of torch.max(..., dim) holds the arg-max positions.
        predicted = torch.max(log_probs.data, 1)[1]
        hits = (predicted == labels).sum().item()
        total = labels.size(0)
        return {"val_loss": batch_loss.item(), "val_acc": hits / total}
def train_mnist(num_workers=1, use_gpu=False, num_epochs=5):
    """Train ``LitMNIST`` with a TorchTrainer, validate it and save a checkpoint.

    Args:
        num_workers: forwarded to ``TorchTrainer``.
        use_gpu: forwarded to ``TorchTrainer``.
        num_epochs: number of ``trainer.train()`` calls to make.
    """
    operator_cls = TrainingOperator.from_ptl(LitMNIST)
    hyperparameters = {"lr": 1e-3, "batch_size": 64}
    trainer = TorchTrainer(
        training_operator_cls=operator_cls,
        num_workers=num_workers,
        config=hyperparameters,
        use_gpu=use_gpu,
        use_tqdm=True,
    )

    for _ in range(num_epochs):
        print(trainer.train())

    print(trainer.validate())

    print("Saving model checkpoint to ./model.pt")
    trainer.save("./model.pt")
    print("Model Checkpointed!")

    trainer.shutdown()
    print("success!")
if __name__ == "__main__":
    # Script entry point: parse the options, start Ray, then train.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--address",
        required=False,
        type=str,
        help="the address to use for Ray")
    cli.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=1,
        help="Sets number of workers for training.")
    cli.add_argument(
        "--use-gpu",
        action="store_true",
        default=False,
        help="Enables GPU training")
    cli.add_argument(
        "--num-epochs",
        type=int,
        default=5,
        help="How many epochs to train for.")
    cli.add_argument(
        "--smoke-test",
        action="store_true",
        default=False,
        help="Finish quickly for testing.")
    # Unknown arguments are tolerated rather than rejected.
    args, _ = cli.parse_known_args()

    import ray

    if not args.smoke_test:
        ray.init(address=args.address)
    else:
        # Smoke test: small local Ray instance and a single epoch.
        ray.init(num_cpus=2)
        args.num_epochs = 1

    train_mnist(
        num_workers=args.num_workers,
        use_gpu=args.use_gpu,
        num_epochs=args.num_epochs)
+502
View File
@@ -0,0 +1,502 @@
import inspect
import logging
import torch
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.overrides.data_parallel import \
LightningDistributedDataParallel
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
import pytorch_lightning as ptl
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
from ray.util.sgd.torch import TrainingOperator
from ray.util.sgd.torch.constants import NUM_STEPS, SCHEDULER_STEP_BATCH, \
SCHEDULER_STEP_EPOCH
from ray.util.sgd.utils import AverageMeterCollection, NUM_SAMPLES
tqdm = None
try:
from tqdm import tqdm
except ImportError:
pass
logger = logging.getLogger(__name__)
class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
TrainerOptimizersMixin):
def _configure_amp(self, amp, models, optimizers):
    """Set up mixed precision through the LightningModule's ``configure_apex``.

    Args:
        amp: forwarded to ``configure_apex`` as its first argument.
        models: list that must hold exactly one LightningModule.
        optimizers: optimizers forwarded to ``configure_apex``.

    Returns:
        ``([model], optimizers)`` as returned by ``configure_apex``.
    """
    assert len(models) == 1
    ptl_module = models[0]
    assert isinstance(ptl_module, ptl.LightningModule)
    # Fall back to "O2" when the apex args do not name an opt_level.
    opt_level = self._apex_args.get("opt_level", "O2")
    ptl_module, optimizers = ptl_module.configure_apex(
        amp, ptl_module, optimizers, amp_level=opt_level)
    return [ptl_module], optimizers
def _configure_ddp(self, models, device_ids):
    """Wrap the single LightningModule using its own ``configure_ddp`` hook.

    Args:
        models: list that must hold exactly one LightningModule.
        device_ids: forwarded to ``configure_ddp``.

    Returns:
        A one-element list holding whatever ``configure_ddp`` returned.
    """
    assert len(models) == 1
    ptl_module = models[0]
    assert isinstance(ptl_module, ptl.LightningModule)
    # This will default to LightningDistributedDataParallel.
    wrapped = ptl_module.configure_ddp(model=ptl_module, device_ids=device_ids)
    return [wrapped]
@property
def model(self):
    """The LightningModule used for training.

    When training is distributed this is the DDP-wrapped module; use
    ``get_model`` for the unwrapped one.
    """
    return self._model
@property
def scheduler_dicts(self):
    """List of scheduler dictionaries.

    Empty when ``configure_optimizers`` of the LightningModule returns no
    schedulers. When it returns scheduler objects rather than scheduler
    dicts, the default configuration is used. See
    https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html#configure-optimizers
    """
    return self._scheduler_dicts
@property
def optimizers(self):
    """List of optimizers, as produced by ``configure_optimizers``."""
    return self._optimizers
@property
def schedulers(self):
    """List of schedulers, as produced by ``configure_optimizers``.

    Empty when ``configure_optimizers`` returns no schedulers.
    """
    return self._schedulers
def get_model(self):
    """Return the original LightningModule, unwrapped from DDP if necessary."""
    candidate = self.model
    is_wrapped = isinstance(candidate, LightningDistributedDataParallel)
    # The DDP wrapper keeps the user's module on its `.module` attribute.
    return candidate.module if is_wrapped else candidate
def setup(self, config):
    """Instantiate the LightningModule and register it with this operator.

    Steps, in order: build the module (passing ``config`` only when its
    ``__init__`` accepts a ``config`` parameter), run the ``on_fit_start``,
    ``prepare_data`` (local rank 0 only) and ``setup("fit")`` hooks, collect
    optimizers and schedulers via ``init_optimizers``, register them, run
    ``on_pretrain_routine_start`` and finally register the data loaders.

    Data loaders attached to the operator class take precedence over the
    module's ``train_dataloader`` / ``val_dataloader`` methods.

    Args:
        config: configuration handed to the LightningModule constructor.

    Raises:
        TypeError: if the configured class is not a LightningModule subclass.
        MisconfigurationException: if ``configure_optimizers`` is not
            overridden by the module.
    """
    # Pass in config if ptl_module accepts it.
    ptl_class = self.__class__._lightning_module_cls
    if not issubclass(ptl_class, ptl.LightningModule):
        raise TypeError("Argument must be subclass of "
                        "pytorch_lightning.LightningModule. Got class {} "
                        "instead.".format(ptl_class))
    if "config" in inspect.signature(ptl_class.__init__).parameters:
        ptl_module = ptl_class(config=config)
    else:
        ptl_module = ptl_class()

    # This is needed for LightningDistributedDataParallel.
    ptl_module.testing = False

    # Call on_fit_start on instantiation.
    if self.is_function_implemented("on_fit_start", ptl_module):
        ptl_module.on_fit_start()

    # Only run data preparation once per node.
    if self.local_rank == 0 and self.is_function_implemented(
            "prepare_data", ptl_module):
        ptl_module.prepare_data()

    # Call model.setup.
    ptl_module.setup("fit")

    if not self.is_overridden("configure_optimizers", ptl_module):
        raise MisconfigurationException(
            "No `configure_optimizers()` method defined.")

    optimizers, self._scheduler_dicts, optimizer_frequencies = \
        self.init_optimizers(model=ptl_module)

    if len(optimizer_frequencies) > 0:
        logger.warning("Optimizer frequencies will be ignored. When "
                       "passing in multiple optimizers, you should "
                       "implement your own custom training loop.")

    lr_schedulers = []
    for scheduler in self.scheduler_dicts:
        if isinstance(scheduler, dict):
            # A scheduler dictionary is passed in.
            if "reduce_on_plateau" in scheduler and "monitor" in \
                    scheduler and scheduler["reduce_on_plateau"] is True:
                logger.info(
                    "reduce_on_plateau and monitor will be "
                    "ignored "
                    "from the scheduler dict {}. To update a "
                    "ReduceLROnPlateau scheduler, you should use "
                    "TorchTrainer.update_schedulers.".format(scheduler))
            if "frequency" in scheduler and scheduler["frequency"] > 1:
                logger.info("frequency will be ignored from the "
                            "scheduler dict {}.".format(scheduler))
            lr_schedulers.append(scheduler["scheduler"])
        else:
            lr_schedulers.append(scheduler)

    # Set this so register doesn't complain.
    self._scheduler_step_freq = "ptl"
    ddp_model, self._optimizers, self._schedulers = self.register(
        models=[ptl_module],
        optimizers=optimizers,
        schedulers=lr_schedulers)

    assert len(ddp_model) == 1
    self._model = ddp_model[0]

    model = self.get_model()
    if self.is_function_implemented("on_pretrain_routine_start", model):
        model.on_pretrain_routine_start()

    # Compare against None instead of relying on truthiness: bool() on a
    # DataLoader goes through __len__, which raises TypeError for iterable
    # datasets without a length and is False for an empty loader.
    train_data_loader = None
    class_train_loader = self.__class__._train_dataloader
    if class_train_loader is not None:
        train_data_loader = class_train_loader
    elif self.is_function_implemented("train_dataloader", model):
        train_data_loader = model.train_dataloader()

    val_data_loader = None
    class_val_loader = self.__class__._val_dataloader
    if class_val_loader is not None:
        val_data_loader = class_val_loader
    elif self.is_function_implemented("val_dataloader", model):
        val_data_loader = model.val_dataloader()

    self.register_data(
        train_loader=train_data_loader, validation_loader=val_data_loader)
def train_epoch(self, iterator, info):
    """Run one training epoch over ``iterator`` and return reduced metrics.

    Each batch goes through ``train_batch``. Schedulers whose ``interval``
    is ``SCHEDULER_STEP_BATCH`` are stepped after every batch, those with
    ``SCHEDULER_STEP_EPOCH`` once at the end. If the LightningModule
    overrides ``training_epoch_end`` its return value determines the
    result; otherwise the per-batch outputs are averaged with
    ``AverageMeterCollection``.

    Args:
        iterator: iterable yielding training batches.
        info: dict merged into every batch's info; ``info[NUM_STEPS]`` is
            read when the progress bar is enabled.

    Returns:
        dict of metrics for the epoch.

    Raises:
        ValueError: if ``training_epoch_end`` returns a ``Result``.
        TypeError: if ``training_epoch_end`` returns an unsupported type.
    """
    model = self.get_model()

    # Enable train mode.
    self.model.train()

    # Enable gradients.
    torch.set_grad_enabled(True)

    if self.is_function_implemented("on_train_epoch_start", model):
        model.on_train_epoch_start()

    show_progress = self.use_tqdm and self.world_rank == 0
    if show_progress:
        desc = ""
        if info is not None and "epoch_idx" in info:
            if "num_epochs" in info:
                desc = f"{info['epoch_idx'] + 1}/{info['num_epochs']}e"
            else:
                desc = f"{info['epoch_idx'] + 1}e"

        # TODO: Implement len for Dataset?
        total = info[NUM_STEPS]
        if total is None:
            if hasattr(iterator, "__len__"):
                total = len(iterator)

        _progress_bar = tqdm(
            total=total, desc=desc, unit="batch", leave=False)

    # Output for each batch.
    epoch_outputs = []
    for batch_idx, batch in enumerate(iterator):
        batch_info = {
            "batch_idx": batch_idx,
            "global_step": self.global_step
        }
        batch_info.update(info)
        batch_output = self.train_batch(batch, batch_info=batch_info)

        should_stop = batch_output["signal"] == -1
        # A batch skipped by on_train_batch_start only carries "signal";
        # keeping it would raise KeyError on "raw_output" during the
        # reduction below, so it is left out of the epoch outputs.
        if not should_stop:
            epoch_outputs.append(batch_output)

        if show_progress:
            _progress_bar.n = batch_idx + 1
            postfix = {}
            if "training_loss" in batch_output:
                postfix.update(loss=batch_output["training_loss"])
            _progress_bar.set_postfix(postfix)

        for s_dict, scheduler in zip(self.scheduler_dicts,
                                     self.schedulers):
            if s_dict["interval"] == SCHEDULER_STEP_BATCH:
                scheduler.step()

        self.global_step += 1

        if should_stop:
            break

    if show_progress:
        # Release the bar so it does not linger once the epoch is over.
        _progress_bar.close()

    processed_outputs = None
    if self.is_overridden("training_epoch_end", model):
        raw_outputs = [eo["raw_output"] for eo in epoch_outputs]
        processed_outputs = model.training_epoch_end(raw_outputs)

    if processed_outputs is not None:
        if isinstance(processed_outputs, torch.Tensor):
            return_output = {"train_loss": processed_outputs}
        elif isinstance(processed_outputs, Result):
            raise ValueError("Result objects are not supported. Please "
                             "return a dictionary instead.")
        elif isinstance(processed_outputs, dict):
            return_output = processed_outputs
        else:
            raise TypeError("training_epoch_end returned an invalid "
                            "type. It must return a Tensor, Result, "
                            "or dict.")
    else:
        # User did not override training_epoch_end
        assert isinstance(epoch_outputs, list)
        # Use AverageMeterCollection util to reduce results.
        meter_collection = AverageMeterCollection()
        for o in epoch_outputs:
            num_samples = o.pop(NUM_SAMPLES, 1)
            raw_output = o["raw_output"]
            if isinstance(raw_output, dict):
                meter_collection.update(raw_output, num_samples)
            elif isinstance(raw_output, torch.Tensor):
                meter_collection.update({
                    "train_loss": o["training_loss"]
                }, num_samples)
        return_output = meter_collection.summary()

    if self.is_function_implemented("on_train_epoch_end", model):
        model.on_train_epoch_end()

    for s_dict, scheduler in zip(self.scheduler_dicts, self.schedulers):
        if s_dict["interval"] == SCHEDULER_STEP_EPOCH:
            scheduler.step()

    return return_output
# Runs a single optimization step on one batch: optional on_train_batch_start
# hook, forward pass, loss extraction, backward pass, optimizer step and
# zero-grad, optional on_train_batch_end hook.
# Reads "batch_idx" and "epoch_idx" from batch_info.
# Returns {"signal": -1} when on_train_batch_start asks to skip the rest of
# the epoch, otherwise a dict with "signal" (0), "training_loss" (a float) and
# "raw_output" (the detached output of training_step).
# Raises ValueError if training_step returns a Result, RuntimeError if no
# loss can be extracted, TypeError if the output is neither Tensor nor dict.
# NOTE(review): indentation is not visible in this view of the file, so the
# exact extent of the `with` blocks below should be confirmed against the
# real source.
def train_batch(self, batch, batch_info):
# Get the original PTL module.
model = self.get_model()
# Only the first optimizer is ever used by this method.
optimizer = self.optimizers[0]
batch_idx = batch_info["batch_idx"]
epoch_idx = batch_info["epoch_idx"]
if self.is_function_implemented("on_train_batch_start", model):
response = model.on_train_batch_start(
batch=batch, batch_idx=batch_idx, dataloader_idx=0)
# Skip remainder of epoch if response is -1.
if response == -1:
return {"signal": -1}
args = [batch, batch_idx]
# With several optimizers, pass optimizer index 0 as an extra argument when
# has_arg reports an "optimizer_idx" argument for training_step.
if len(self.optimizers) > 1:
if self.has_arg("training_step", "optimizer_idx"):
args.append(0)
with self.timers.record("fwd"):
if self._is_distributed:
# Use the DDP wrapped model (self.model).
output = self.model(*args)
elif self.use_gpu:
# Using single GPU.
# Don't copy the batch since there is a single gpu that
# the batch could be referenced from and if there are
# multiple optimizers the batch will wind up copying it to
# the same device repeatedly.
device = self.device
batch = model.transfer_batch_to_device(batch, device=device)
args[0] = batch
output = model.training_step(*args)
else:
# Using CPU.
output = model.training_step(*args)
if isinstance(output, Result):
raise ValueError("TrainResult objects are not supported. Please "
"return a dictionary instead.")
# allow any mode to define training_step_end
# do something will all the dp outputs (like softmax)
if self.is_overridden("training_step_end", model):
output = model.training_step_end(output)
# Extract loss from output if dictionary.
# Any failure of the ["loss"] lookup falls back to treating a Tensor output
# as the loss itself.
try:
loss = output["loss"]
except Exception:
if isinstance(output, torch.Tensor):
loss = output
else:
raise RuntimeError(
"No `loss` value in the dictionary returned from "
"`model.training_step()`.")
# If output contains tensors, detach them all.
if isinstance(output, torch.Tensor):
output = output.detach()
elif isinstance(output, dict):
output = recursive_detach(output)
else:
raise TypeError("training_step returned invalid type. It must "
"return either a Tensor, Result, or dict.")
# Detached copy of the loss, taken before the backward pass; this is the
# value reported as "training_loss".
untouched_loss = loss.detach().clone()
with self.timers.record("grad"):
if self.use_fp16:
# The scaled loss from amp is what gets back-propagated under fp16.
with self._amp.scale_loss(loss, optimizer) as scaled_loss:
model.backward(
self, scaled_loss, optimizer, optimizer_idx=0)
else:
model.backward(self, loss, optimizer, optimizer_idx=0)
if self.is_function_implemented("on_after_backward", model):
model.on_after_backward()
with self.timers.record("apply"):
model.optimizer_step(
epoch=epoch_idx,
batch_idx=batch_idx,
optimizer=optimizer,
optimizer_idx=0)
# NOTE(review): unclear from this view whether the two zero-grad calls below
# are inside the "apply" timer block — confirm.
model.on_before_zero_grad(optimizer)
model.optimizer_zero_grad(
epoch=epoch_idx,
batch_idx=batch_idx,
optimizer=optimizer,
optimizer_idx=0)
if self.is_function_implemented("on_train_batch_end", model):
model.on_train_batch_end(
batch=batch, batch_idx=batch_idx, dataloader_idx=0)
return {
"signal": 0,
"training_loss": untouched_loss.item(),
"raw_output": output,
# NUM_SAMPLES: len(batch)
}
def validate(self, val_iterator, info):
    """Run one pass over ``val_iterator`` and return reduced metrics.

    Gradients are disabled for the duration of the pass and re-enabled
    afterwards, even when validation raises. Each batch goes through
    ``validate_batch``. If the LightningModule overrides
    ``validation_epoch_end`` its return value determines the result;
    otherwise the per-batch outputs are averaged with
    ``AverageMeterCollection``.

    Args:
        val_iterator: iterable yielding validation batches.
        info: dict merged into every batch's info.

    Returns:
        dict of validation metrics.

    Raises:
        ValueError: if ``validation_epoch_end`` returns a ``Result``.
        TypeError: if ``validation_epoch_end`` returns an unsupported type.
    """
    self.model.zero_grad()
    self.model.eval()

    torch.set_grad_enabled(False)
    try:
        model = self.get_model()
        if self.is_function_implemented("on_validation_epoch_start",
                                        model):
            model.on_validation_epoch_start()

        val_outputs = []
        for batch_idx, batch in enumerate(val_iterator):
            batch_info = {"batch_idx": batch_idx}
            batch_info.update(info)
            batch_output = self.validate_batch(batch, batch_info)

            if batch_output is not None:
                val_outputs.append(batch_output)

        processed_outputs = None
        if self.is_overridden("validation_epoch_end", model):
            raw_outputs = [vo["raw_output"] for vo in val_outputs]
            # Call the validation hook that was just checked for; the
            # training hook must not see validation outputs.
            processed_outputs = model.validation_epoch_end(raw_outputs)

        if processed_outputs is not None:
            if isinstance(processed_outputs, torch.Tensor):
                return_output = {"val_loss": processed_outputs}
            elif isinstance(processed_outputs, Result):
                raise ValueError("Result objects are not supported. Please "
                                 "return a dictionary instead.")
            elif isinstance(processed_outputs, dict):
                return_output = processed_outputs
            else:
                raise TypeError("validation_epoch_end returned an invalid "
                                "type. It must return a Tensor, Result, "
                                "or dict.")
        else:
            # User did not override validation_epoch_end
            assert isinstance(val_outputs, list)
            # Use AverageMeterCollection util to reduce results.
            meter_collection = AverageMeterCollection()
            for v in val_outputs:
                num_samples = v.pop(NUM_SAMPLES, 1)
                raw_output = v["raw_output"]
                if isinstance(raw_output, dict):
                    meter_collection.update(raw_output, num_samples)
                elif isinstance(raw_output, torch.Tensor):
                    meter_collection.update({
                        "val_loss": raw_output.item()
                    }, num_samples)
            return_output = meter_collection.summary()

        if self.is_function_implemented("on_validation_epoch_end", model):
            model.on_validation_epoch_end()
    finally:
        # Set back to True so training will work.
        torch.set_grad_enabled(True)

    return return_output
def validate_batch(self, batch, batch_info):
    """Run the validation step for one batch.

    Calls the optional ``on_validation_batch_start`` hook, the forward
    pass, the optional ``validation_step_end`` hook and the optional
    ``on_validation_batch_end`` hook, in that order.

    Args:
        batch: one batch from the validation iterator.
        batch_info: dict from which ``"batch_idx"`` is read.

    Returns:
        dict with a single ``"raw_output"`` entry holding the step output.

    Raises:
        ValueError: if the validation step returns a ``Result``.
    """
    model = self.get_model()
    batch_idx = batch_info["batch_idx"]
    if self.is_overridden("on_validation_batch_start", model):
        model.on_validation_batch_start(
            batch=batch, batch_idx=batch_idx, dataloader_idx=0)
    args = [batch, batch_idx]
    with self.timers.record("eval_fwd"):
        if self._is_distributed:
            # Use the DDP wrapped model (self.model).
            output = self.model(*args)
        elif self.use_gpu:
            # Using single GPU.
            device = self.device
            batch = model.transfer_batch_to_device(batch, device=device)
            args[0] = batch
            output = model.validation_step(*args)
        else:
            # Using CPU.
            output = model.validation_step(*args)

    if isinstance(output, Result):
        raise ValueError("EvalResult objects are not supported. Please "
                         "return a dictionary instead.")

    # Check for the hook that is actually invoked below; the name checked
    # previously ("on_validation_step_end") did not match it.
    if self.is_overridden("validation_step_end", model):
        output = model.validation_step_end(output)

    if self.is_function_implemented("on_validation_batch_end", model):
        model.on_validation_batch_end(
            batch=batch, batch_idx=batch_idx, dataloader_idx=0)
    return {
        "raw_output": output,
        # NUM_SAMPLES: len(batch)
    }
def state_dict(self):
    """Return the operator state gathered by the module's ``on_save_checkpoint``.

    The hook receives an empty dict and fills it in place.
    """
    checkpoint = {}
    module = self.get_model()
    module.on_save_checkpoint(checkpoint=checkpoint)
    return checkpoint
def load_state_dict(self, state_dict):
    """Hand ``state_dict`` to the module's ``on_load_checkpoint`` hook."""
    module = self.get_model()
    module.on_load_checkpoint(checkpoint=state_dict)
def _get_train_loader(self):
if not hasattr(self, "_train_loader") or \
self._train_loader is None:
raise RuntimeError("Training Operator does not have any "
"registered training loader. Make sure "
"to pass in a training loader to "
"TrainingOperator.from_ptl or implement "
"train_dataloader in your LightningModule.")
return self._train_loader
def _get_validation_loader(self):
if not hasattr(self, "_validation_loader") or \
self._validation_loader is None:
raise RuntimeError("Training Operator does not have any "
"registered validation loader. Make sure "
"to pass in a validation loader to "
"TrainingOperator.from_ptl or implement "
"val_dataloader in your LightningModule.")
return self._validation_loader
+13 -43
View File
@@ -1,6 +1,8 @@
import logging
import io
import itertools
import ray
import torch
from ray.util.sgd.torch.constants import USE_FP16, NUM_STEPS
@@ -50,6 +52,8 @@ class TorchRunner:
self.training_operator = self.training_operator_cls(
self.config,
world_rank=0,
local_rank=0,
is_distributed=False,
use_gpu=self.use_gpu,
use_fp16=self.use_fp16,
use_tqdm=self.use_tqdm,
@@ -69,6 +73,7 @@ class TorchRunner:
info.update({
NUM_STEPS: num_steps,
USE_FP16: self.use_fp16,
"epoch_idx": self.epochs,
})
with self.timers.record("train_epoch"):
if iterator is None:
@@ -94,10 +99,6 @@ class TorchRunner:
def validate(self, num_steps=None, profile=False, info=None):
"""Evaluates the model on the validation data set."""
if self.validation_loader is None:
raise ValueError("No validation dataloader provided. Make sure"
"you pass in a validation_loader to "
"TrainingOperator.register_data.")
info = info or {}
self._toggle_profiling(profile=profile)
validation_loader = self.validation_loader
@@ -204,62 +205,31 @@ class TorchRunner:
"""Getter method. Needed for remote actor calls."""
return self.models
def get_node_ip(self):
return ray.services.get_node_ip_address()
@property
def models(self):
if not hasattr(self.training_operator, "_original_models"):
raise RuntimeError("Training Operator does not have any "
"registered models. Are you calling "
"self.register(...) inside the setup method "
"of your Training Operator?")
return self.training_operator._original_models
return self.training_operator._get_original_models()
@property
def optimizers(self):
if not hasattr(self.training_operator, "_optimizers"):
raise RuntimeError("Training Operator does not have any "
"registered optimizers. Are you calling "
"self.register(...) inside the setup method "
"of your Training Operator?")
return self.training_operator._optimizers
return self.training_operator._get_optimizers()
@property
def schedulers(self):
if not hasattr(self.training_operator, "_schedulers"):
raise RuntimeError("Training Operator does not have any "
"registered schedulers. Are you calling "
"self.register(...) inside the setup method "
"of your Training Operator?")
return self.training_operator._schedulers
return self.training_operator._get_schedulers()
@property
def train_loader(self):
if not hasattr(self.training_operator, "_train_loader"):
logger.warning("Training Operator does not have any "
"registered train loader. If this is "
"unexepected, make sure to call "
"self.register_data(...) inside the setup method "
"of your Training Operator.")
return None
return self.training_operator._train_loader
return self.training_operator._get_train_loader()
@property
def validation_loader(self):
if not hasattr(self.training_operator, "_validation_loader"):
logger.warning("Training Operator does not have any "
"registered validation loader. If this is "
"unexepected, make sure to call "
"self.register_data(...) inside the setup method "
"of your Training Operator.")
return None
return self.training_operator._validation_loader
return self.training_operator._get_validation_loader()
@property
def criterion(self):
if not hasattr(self.training_operator, "_criterion"):
raise RuntimeError("Training Operator does not have any "
"registered criterion. Are you calling "
"self.register(...) inside the setup method "
"of your Training Operator?")
return self.training_operator._criterion
@property
+1 -1
View File
@@ -439,7 +439,7 @@ class TorchTrainer:
}
for stat_key in worker_stats[0]:
if isinstance(worker_stats[0], numbers.Number):
if isinstance(worker_stats[0][stat_key], numbers.Number):
stats[stat_key] = np.nanmean(
[s.get(stat_key, np.nan) for s in worker_stats])
else:
+189 -38
View File
@@ -14,6 +14,7 @@ from ray.util.sgd.torch.constants import (
NUM_STEPS,
SCHEDULER_STEP_BATCH,
)
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DistributedSampler, DataLoader, IterableDataset
@@ -78,7 +79,6 @@ class TrainingOperator:
train_loader=train_loader,
validation_loader=val_loader)
trainer = TorchTrainer(
training_operator_cls=MyTrainingOperator,
config={"batch_size": 32},
@@ -119,18 +119,21 @@ class TrainingOperator:
def __init__(self,
config,
world_rank,
local_rank,
is_distributed=False,
device_ids=None,
use_gpu=False,
use_fp16=False,
use_tqdm=False,
apex_args=None,
wrap_ddp=False,
wrap_distributed_sampler=False,
add_dist_sampler=False,
scheduler_step_freq=None):
# You are not expected to override this method.
self._world_rank = world_rank
self._local_rank = local_rank
self._config = config
self._is_distributed = is_distributed
self._use_fp16 = use_fp16
self._device_ids = device_ids
self._use_gpu = use_gpu and torch.cuda.is_available()
@@ -141,7 +144,6 @@ class TrainingOperator:
self.global_step = 0
self._apex_args = apex_args if apex_args else {}
self._wrap_ddp = wrap_ddp
self._wrap_distributed_sampler = wrap_distributed_sampler
self._add_dist_sampler = add_dist_sampler
self._scheduler_step_freq = scheduler_step_freq
@@ -152,6 +154,28 @@ class TrainingOperator:
"""Passes in the timers from the Runner."""
self.timers = timers
def _configure_amp(self, amp, models, optimizers):
models, optimizers = amp.initialize(models, optimizers,
**self._apex_args)
return models, optimizers
def _configure_ddp(self, models, device_ids):
return [
DistributedDataParallel(model, device_ids=device_ids)
for model in models
]
def _return_items(self, items, original_items):
"""Helper method to return items in same format as original_items."""
if isinstance(original_items, tuple):
return tuple(items)
elif isinstance(original_items, Iterable):
# Items is already a list.
return items
else:
assert len(items) == 1
return items[0]
def setup(self, config):
"""Override this method to implement operator setup.
@@ -218,7 +242,6 @@ class TrainingOperator:
Returns:
Tuple of model, optimizer, criterion if not None, and scheduler
if not None.
"""
return_vals = []
logger.debug("Registering models.")
@@ -244,7 +267,10 @@ class TrainingOperator:
if not isinstance(self._schedulers, Iterable):
self._schedulers = [self._schedulers]
else:
self._schedulers = None
if isinstance(schedulers, Iterable):
self._schedulers = []
else:
self._schedulers = None
if criterion:
logger.debug("Registering loss.")
@@ -257,28 +283,19 @@ class TrainingOperator:
if self.use_fp16 and amp:
logger.debug("Setting up Apex.")
self._original_models, self._optimizers = amp.initialize(
self._original_models, self._optimizers, **self._apex_args)
self._amp = amp
self._original_models, self._optimizers = self._configure_amp(
self._amp, self._original_models, self._optimizers)
if self._wrap_ddp:
logging.debug("Setting up DDP for models.")
self._models = [
DistributedDataParallel(model, device_ids=self.device_ids)
for model in self._original_models
]
self._models = self._configure_ddp(
models=self._original_models, device_ids=self.device_ids)
else:
self._models = self._original_models
if len(self._models) == 1:
return_vals.append(self._models[0])
else:
return_vals.append(self._models)
if len(self._optimizers) == 1:
return_vals.append(self._optimizers[0])
else:
return_vals.append(self._optimizers)
return_vals.append(self._return_items(self._models, models))
return_vals.append(self._return_items(self._optimizers, optimizers))
if self._criterion is not None:
return_vals.append(self._criterion)
@@ -290,10 +307,8 @@ class TrainingOperator:
"are registering schedulers. Set this to "
"'manual' if you will be manually stepping "
"the schedulers.")
if len(self._schedulers) == 1:
return_vals.append(self._schedulers[0])
else:
return_vals.append(self._schedulers)
return_vals.append(
self._return_items(self._schedulers, schedulers))
return tuple(return_vals)
@@ -348,8 +363,7 @@ class TrainingOperator:
self._train_loader = train_loader
self._validation_loader = validation_loader
if self._wrap_distributed_sampler:
logging.debug("Wrapping data loaders with DistributedSampler.")
if self._is_distributed:
def with_sampler(loader):
# Automatically set the DistributedSampler
@@ -373,11 +387,15 @@ class TrainingOperator:
if should_wrap_dataloader(self._train_loader):
if self._add_dist_sampler:
logging.debug("Wrapping train data loader with "
"DistributedSampler.")
self._train_loader = with_sampler(self._train_loader)
if self._validation_loader is not None and should_wrap_dataloader(
self._validation_loader):
if self._add_dist_sampler:
logging.debug("Wrapping validation data loader with "
"DistributedSampler.")
self._validation_loader = with_sampler(
self._validation_loader)
@@ -665,6 +683,90 @@ class TrainingOperator:
state_dict (dict): State dict as returned by the operator. """
pass
def _get_original_models(self):
if not hasattr(self, "_original_models"):
raise RuntimeError("Training Operator does not have any "
"registered models. Are you calling "
"self.register(...) inside the setup method "
"of your Training Operator?")
return self._original_models
def _get_optimizers(self):
if not hasattr(self, "_optimizers"):
raise RuntimeError("Training Operator does not have any "
"registered optimizers. Are you calling "
"self.register(...) inside the setup method "
"of your Training Operator?")
return self._optimizers
def _get_schedulers(self):
if not hasattr(self, "_schedulers"):
raise RuntimeError("Training Operator does not have any "
"registered schedulers. Are you calling "
"self.register(...) inside the setup method "
"of your Training Operator?")
return self._schedulers
def _get_train_loader(self):
if not hasattr(self, "_train_loader") or \
self._train_loader is None:
raise RuntimeError(
"Training Operator does not have any "
"registered train loader. If this is "
"unexepected, make sure to call "
"self.register_data(...) inside the setup method "
"of your Training Operator.")
return self._train_loader
def _get_validation_loader(self):
if not hasattr(self, "_validation_loader") or \
self._validation_loader is None:
raise RuntimeError(
"Training Operator does not have any "
"registered validation loader. If this is "
"unexepected, make sure to call "
"self.register_data(...) inside the setup method "
"of your Training Operator.")
return self._validation_loader
def _get_criterion(self):
if not hasattr(self, "_criterion"):
raise RuntimeError("Training Operator does not have any "
"registered criterion. Are you calling "
"self.register(...) inside the setup method "
"of your Training Operator?")
return self._criterion
@classmethod
def from_ptl(cls,
             lightning_module_cls,
             train_dataloader=None,
             val_dataloader=None):
    """Build a TrainingOperator class around a PyTorch Lightning module.

    Args:
        lightning_module_cls: Your LightningModule class (the class
            itself, not an instance).
        train_dataloader: Data loader to use for training. When None,
            the operator is expected to fall back to
            ``LightningModule.train_dataloader`` (handled by
            ``LightningOperator``, which is not shown here).
        val_dataloader: Data loader to use for validation. When None,
            the operator is expected to fall back to
            ``LightningModule.val_dataloader`` (likewise handled by
            ``LightningOperator``).

    Returns:
        A ``LightningOperator`` subclass that carries the three
        arguments as the class attributes ``_lightning_module_cls``,
        ``_train_dataloader`` and ``_val_dataloader``.
    """
    # Imported at call time rather than at module import time.
    from ray.util.sgd.torch.ptl_operator import LightningOperator

    class CustomLightningOperator(LightningOperator):
        _lightning_module_cls = lightning_module_cls
        _train_dataloader = train_dataloader
        _val_dataloader = val_dataloader

    return CustomLightningOperator
@classmethod
def from_creators(cls,
model_creator,
@@ -749,6 +851,11 @@ class TrainingOperator:
"""int: The rank of the parent runner. Always 0 if not distributed."""
return self._world_rank
@property
def local_rank(self):
    """int: The ``local_rank`` value stored on this operator.

    Documented upstream as the local rank of the parent runner, always 0
    when not distributed.
    """
    return self._local_rank
@property
def use_gpu(self):
"""Returns True if cuda is available and use_gpu is True."""
@@ -849,29 +956,73 @@ class CreatorOperator(TrainingOperator):
kwargs["criterion"] = criterion
state = self.register(**kwargs)
self.models, self.optimizers = state[:2]
if isinstance(self.models, tuple):
self.model = self.models[0]
self._registered_models, self._registered_optimizers = state[:2]
if isinstance(self.models, (list, tuple)):
logger.info("Multiple models have been registered. If custom "
"training methods are not provided, only the first "
"model will be used.")
self._registered_model = self.models[0]
else:
self.model = self.models
self._registered_model = self.models
if isinstance(self.optimizers, tuple):
self.optimizer = self.optimizers[0]
if isinstance(self.optimizers, (list, tuple)):
logger.info("Multiple optimizers have been registered. If custom "
"training methods are not provided, only the first "
"optimizer will be used.")
# Must be ``_registered_optimizer``: that is the attribute the
# ``optimizer`` property reads.
self._registered_optimizer = self.optimizers[0]
else:
self.optimizer = self.optimizers
self._registered_optimizer = self.optimizers
if len(state) >= 3:
self.criterion = state[2]
self._registered_criterion = state[2]
if len(state) == 4:
self.schedulers = state[3]
if isinstance(self.schedulers, tuple):
self.scheduler = self.schedulers[0]
self._registered_schedulers = state[3]
if isinstance(self.schedulers, (list, tuple)):
logger.info("Multiple schedulers have been registered. If "
"custom training methods are not provided, "
"only the first scheduler will be used.")
self._registered_scheduler = self.schedulers[0]
else:
self.scheduler = self.schedulers
self._registered_scheduler = self.schedulers
self.register_data(
train_loader=train_loader, validation_loader=validation_loader)
@property
def model(self):
    """The first registered model, or the only one.

    Returns the ``_registered_model`` attribute assigned in ``setup``.
    """
    return self._registered_model
@property
def optimizer(self):
    """The first registered optimizer, or the only one.

    Returns the ``_registered_optimizer`` attribute.
    """
    return self._registered_optimizer
@property
def scheduler(self):
    """The first registered scheduler, or the only one.

    Returns the ``_registered_scheduler`` attribute, which ``setup``
    assigns only when ``register`` returned a fourth element.
    """
    return self._registered_scheduler
@property
def criterion(self):
    """The registered criterion (loss).

    Returns the ``_registered_criterion`` attribute, which ``setup``
    assigns only when ``register`` returned at least three elements.
    """
    return self._registered_criterion
@property
def models(self):
    """The registered model(s) as returned by ``self.register``.

    May be a single model or a list/tuple of models; ``setup`` handles
    both shapes.
    """
    return self._registered_models
@property
def optimizers(self):
    """The registered optimizer(s) as returned by ``self.register``.

    May be a single optimizer or a list/tuple of optimizers; ``setup``
    handles both shapes.
    """
    return self._registered_optimizers
@property
def schedulers(self):
    """The registered scheduler(s) as returned by ``self.register``.

    Returns the ``_registered_schedulers`` attribute, which ``setup``
    assigns only when ``register`` returned a fourth element.
    """
    return self._registered_schedulers
def get_test_operator(operator_cls):
class _TestingOperator(operator_cls):
+22
View File
@@ -1,6 +1,7 @@
import io
import logging
import time
from collections import defaultdict
from datetime import timedelta
import ray
@@ -191,6 +192,19 @@ class RemoteWorkerGroup(WorkerGroupInterface):
]
return remote_operator_setups
def _setup_local_rank(self, rank_counter_dict=None):
    """Assign every remote worker a local rank, counted per node IP.

    Args:
        rank_counter_dict (dict): Maps a node IP to the next local rank
            to hand out on that node. It is updated in place. When None,
            a fresh ``defaultdict(int)`` is used so every node starts
            at 0.

    Returns:
        List of the objects returned by each worker's
        ``set_local_rank.remote(...)`` call, in worker order. They are
        not waited on here.
    """
    counters = (defaultdict(int)
                if rank_counter_dict is None else rank_counter_dict)
    ip_requests = [
        worker.get_node_ip.remote() for worker in self.remote_workers
    ]
    pending = []
    for node_ip, worker in zip(ray.get(ip_requests), self.remote_workers):
        next_rank = counters[node_ip]
        pending.append(worker.set_local_rank.remote(next_rank))
        counters[node_ip] += 1
    return pending
def start_workers(self, num_workers):
logger.debug(f"start_workers: Setting %d workers." % num_workers)
if num_workers == 1:
@@ -212,6 +226,8 @@ class RemoteWorkerGroup(WorkerGroupInterface):
self._setup_process_group(
address=address, world_size=num_workers))
ray.get(self._setup_local_rank())
ray.get(self._setup_operator())
def _apply_all_operators(self, fn):
@@ -454,6 +470,12 @@ class LocalWorkerGroup(WorkerGroupInterface):
timeout=timedelta(self._timeout_s))
ray.get(remote_pgs)
local_node_ip = ray.services.get_node_ip_address()
rank_dict = defaultdict(int)
self.local_worker.set_local_rank(local_rank=0)
rank_dict[local_node_ip] += 1
self.remote_worker_group._setup_local_rank(rank_dict)
remote_operators = self.remote_worker_group._setup_operator()
self.local_worker.setup_operator()
ray.get(remote_operators)