mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 20:40:09 +08:00
[Ray SGD] LightningModule integration + MNIST Example (#11042)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -2,6 +2,14 @@
|
||||
# Tests from the python/ray/util/sgd/tests directory.
|
||||
# Please keep these sorted alphabetically.
|
||||
# --------------------------------------------------------------------
|
||||
py_test(
|
||||
name = "test_ptl",
|
||||
size = "small",
|
||||
srcs = ["tests/test_ptl.py"],
|
||||
tags = ["exclusive", "pytorch-lightning", "pytorch"],
|
||||
deps = [":sgd_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_tensorflow",
|
||||
size = "small",
|
||||
@@ -214,6 +222,15 @@ py_test(
|
||||
args = ["--no-gpu", "--mock-data", "--smoke-test", "--ray-num-workers=2", "--model=mobilenetv3_small_075", "data"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "mnist-ptl",
|
||||
size = "small",
|
||||
srcs = ["torch/examples/pytorch-lightning/mnist-ptl.py"],
|
||||
tags = ["exclusive", "pytorch", "pytorch-lightning"],
|
||||
deps = [":sgd_lib"],
|
||||
args = ["--smoke-test"]
|
||||
)
|
||||
|
||||
# This is a dummy test dependency that causes the above tests to be
|
||||
# re-run if any of these files changes.
|
||||
py_library(
|
||||
|
||||
@@ -0,0 +1,211 @@
|
||||
import copy
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
import torch
|
||||
from ray.util.sgd.utils import BATCH_COUNT
|
||||
import torch.distributed as dist
|
||||
from pytorch_lightning import LightningModule
|
||||
from ray.util.sgd import TorchTrainer
|
||||
from ray.util.sgd.torch import TrainingOperator
|
||||
from ray.util.sgd.torch.examples.train_example import \
|
||||
optimizer_creator, data_creator, scheduler_creator, model_creator
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ray_start_2_cpus():
|
||||
address_info = ray.init(num_cpus=2)
|
||||
yield address_info
|
||||
# The code after the yield will run as teardown code.
|
||||
ray.shutdown()
|
||||
# Ensure that tests don't ALL fail
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
class PTL_Module(LightningModule):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
if "layer" in config:
|
||||
self.layer = copy.deepcopy(config["layer"])
|
||||
else:
|
||||
self.layer = model_creator(self.config)
|
||||
|
||||
self.rand_int = np.random.randint(10)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layer.forward(x)
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = optimizer_creator(self, self.config)
|
||||
scheduler = scheduler_creator(optimizer, self.config)
|
||||
return [optimizer], [scheduler]
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
output = self(x)
|
||||
loss = self.loss(output, y)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
output = self(x)
|
||||
loss = self.loss(output, y)
|
||||
_, predicted = torch.max(output.data, 1)
|
||||
num_correct = (predicted == y).sum().item()
|
||||
num_samples = y.size(0)
|
||||
return {"val_loss": loss.item(), "val_acc": num_correct / num_samples}
|
||||
|
||||
def setup(self, stage):
|
||||
self.train_loader, self.val_loader = data_creator(self.config)
|
||||
self.loss = nn.MSELoss()
|
||||
|
||||
def train_dataloader(self):
|
||||
return self.train_loader
|
||||
|
||||
def val_dataloader(self):
|
||||
return self.val_loader
|
||||
|
||||
def on_save_checkpoint(self, checkpoint):
|
||||
checkpoint["int"] = self.rand_int
|
||||
|
||||
def on_load_checkpoint(self, checkpoint):
|
||||
self.rand_int = checkpoint["int"]
|
||||
|
||||
|
||||
Operator = TrainingOperator.from_ptl(PTL_Module)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_single_step(ray_start_2_cpus, use_local): # noqa: F811
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=Operator,
|
||||
num_workers=1,
|
||||
use_local=use_local,
|
||||
use_gpu=False)
|
||||
metrics = trainer.train(num_steps=1)
|
||||
assert metrics[BATCH_COUNT] == 1
|
||||
|
||||
val_metrics = trainer.validate(num_steps=1)
|
||||
assert val_metrics[BATCH_COUNT] == 1
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_train(ray_start_2_cpus, num_workers, use_local): # noqa: F811
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=Operator,
|
||||
num_workers=num_workers,
|
||||
use_local=use_local,
|
||||
use_gpu=False)
|
||||
for i in range(3):
|
||||
train_loss1 = trainer.train()["train_loss"]
|
||||
validation_loss1 = trainer.validate()["val_loss"]
|
||||
|
||||
for i in range(3):
|
||||
train_loss2 = trainer.train()["train_loss"]
|
||||
validation_loss2 = trainer.validate()["val_loss"]
|
||||
|
||||
assert train_loss2 <= train_loss1, (train_loss2, train_loss1)
|
||||
assert validation_loss2 <= validation_loss1, (validation_loss2,
|
||||
validation_loss1)
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_save_and_restore(ray_start_2_cpus, num_workers, use_local,
|
||||
tmp_path): # noqa: F811
|
||||
trainer1 = TorchTrainer(
|
||||
training_operator_cls=Operator,
|
||||
num_workers=num_workers,
|
||||
use_local=use_local)
|
||||
trainer1.train()
|
||||
checkpoint_path = os.path.join(tmp_path, "checkpoint")
|
||||
trainer1.save(checkpoint_path)
|
||||
|
||||
model1 = trainer1.get_model()
|
||||
ints1 = trainer1.apply_all_operators(lambda op: op.get_model().rand_int)[0]
|
||||
|
||||
trainer1.shutdown()
|
||||
|
||||
trainer2 = TorchTrainer(
|
||||
training_operator_cls=Operator,
|
||||
num_workers=num_workers,
|
||||
use_local=use_local)
|
||||
trainer2.load(checkpoint_path)
|
||||
|
||||
model2 = trainer2.get_model()
|
||||
ints2 = trainer2.apply_all_operators(lambda op: op.get_model().rand_int)
|
||||
|
||||
model1_state_dict = model1.state_dict()
|
||||
model2_state_dict = model2.state_dict()
|
||||
|
||||
assert set(model1_state_dict.keys()) == set(model2_state_dict.keys())
|
||||
|
||||
for k in model1_state_dict:
|
||||
assert torch.equal(model1_state_dict[k], model2_state_dict[k])
|
||||
for i in ints2:
|
||||
assert i == ints1
|
||||
trainer2.shutdown()
|
||||
|
||||
|
||||
class CorrectnessOperator(TrainingOperator):
|
||||
def setup(self, config):
|
||||
model = PTL_Module(config)
|
||||
opt = optimizer_creator(model, config)
|
||||
scheduler = scheduler_creator(opt, config)
|
||||
self.model, self.optimizer, self.criterion, self.scheduler = \
|
||||
self.register(
|
||||
models=model, optimizers=opt, criterion=nn.MSELoss(),
|
||||
schedulers=scheduler)
|
||||
|
||||
train_loader, val_loader = data_creator(config)
|
||||
self.register_data(
|
||||
train_loader=train_loader, validation_loader=val_loader)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_correctness(ray_start_2_cpus, num_workers, use_local):
|
||||
layer = nn.Linear(1, 1)
|
||||
ptl_op = TrainingOperator.from_ptl(PTL_Module)
|
||||
trainer1 = TorchTrainer(
|
||||
training_operator_cls=ptl_op,
|
||||
config={
|
||||
"layer": layer,
|
||||
"data_size": 3,
|
||||
"batch_size": 1
|
||||
},
|
||||
num_workers=num_workers,
|
||||
use_local=use_local)
|
||||
train1_stats = trainer1.train()
|
||||
val1_stats = trainer1.validate()
|
||||
trainer1.shutdown()
|
||||
|
||||
trainer2 = TorchTrainer(
|
||||
training_operator_cls=CorrectnessOperator,
|
||||
scheduler_step_freq="manual",
|
||||
config={
|
||||
"layer": layer,
|
||||
"data_size": 3,
|
||||
"batch_size": 1
|
||||
},
|
||||
num_workers=num_workers,
|
||||
use_local=use_local)
|
||||
train2_stats = trainer2.train()
|
||||
val2_stats = trainer2.validate()
|
||||
trainer2.shutdown()
|
||||
|
||||
assert train1_stats["train_loss"] == train2_stats["train_loss"]
|
||||
assert val1_stats["val_loss"] == val2_stats["val_loss"]
|
||||
assert val1_stats["val_acc"] == val2_stats["val_accuracy"]
|
||||
@@ -3,8 +3,6 @@ import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import DataLoader, IterableDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from ray.util.sgd.torch.utils import setup_process_group
|
||||
|
||||
import ray
|
||||
@@ -42,6 +40,7 @@ class DistributedTorchRunner(TorchRunner):
|
||||
self.wrap_ddp = wrap_ddp
|
||||
self.add_dist_sampler = add_dist_sampler
|
||||
self.world_rank = None
|
||||
self.local_rank = None
|
||||
|
||||
def setup_address(self):
|
||||
return setup_address()
|
||||
@@ -61,6 +60,14 @@ class DistributedTorchRunner(TorchRunner):
|
||||
setup_process_group(
|
||||
url, world_rank, world_size, timeout, backend=self.backend)
|
||||
|
||||
def set_local_rank(self, local_rank):
|
||||
"""Sets the local rank of this runner.
|
||||
|
||||
Args:
|
||||
local_rank (int): the index of the runner on its node.
|
||||
"""
|
||||
self.local_rank = local_rank
|
||||
|
||||
def setup_operator(self):
|
||||
"""Runs distributed coordination components.
|
||||
|
||||
@@ -74,13 +81,14 @@ class DistributedTorchRunner(TorchRunner):
|
||||
self.training_operator = self.training_operator_cls(
|
||||
self.config,
|
||||
world_rank=self.world_rank,
|
||||
local_rank=self.local_rank,
|
||||
is_distributed=True,
|
||||
device_ids=device_ids,
|
||||
use_gpu=self.use_gpu,
|
||||
use_fp16=self.use_fp16,
|
||||
use_tqdm=self.use_tqdm,
|
||||
apex_args=self.apex_args,
|
||||
wrap_ddp=self.wrap_ddp,
|
||||
wrap_distributed_sampler=True,
|
||||
add_dist_sampler=self.add_dist_sampler,
|
||||
scheduler_step_freq=self.scheduler_step_freq)
|
||||
|
||||
@@ -88,45 +96,21 @@ class DistributedTorchRunner(TorchRunner):
|
||||
"""Needed for SyncBatchNorm, which needs 1 GPU per process."""
|
||||
return [0]
|
||||
|
||||
def _wrap_dataloaders(self):
|
||||
def with_sampler(loader):
|
||||
# Automatically set the DistributedSampler
|
||||
data_loader_args = {
|
||||
"dataset": loader.dataset,
|
||||
"batch_size": loader.batch_size,
|
||||
"shuffle": False,
|
||||
"num_workers": loader.num_workers,
|
||||
"collate_fn": loader.collate_fn,
|
||||
"pin_memory": loader.pin_memory,
|
||||
"drop_last": loader.drop_last,
|
||||
"timeout": loader.timeout,
|
||||
"worker_init_fn": loader.worker_init_fn,
|
||||
"sampler": DistributedSampler(loader.dataset)
|
||||
}
|
||||
return DataLoader(**data_loader_args)
|
||||
|
||||
def should_wrap_dataloader(loader):
|
||||
return (isinstance(loader, DataLoader)
|
||||
and not isinstance(loader.dataset, IterableDataset))
|
||||
|
||||
if should_wrap_dataloader(self.train_loader):
|
||||
if self.add_dist_sampler:
|
||||
self.train_loader = with_sampler(self.train_loader)
|
||||
|
||||
if self.validation_loader is not None and should_wrap_dataloader(
|
||||
self.validation_loader):
|
||||
if self.add_dist_sampler:
|
||||
self.validation_loader = with_sampler(self.validation_loader)
|
||||
|
||||
def train_epoch(self, **kwargs):
|
||||
def train_epoch(self,
|
||||
num_steps=None,
|
||||
profile=False,
|
||||
info=None,
|
||||
iterator=None):
|
||||
"""Runs a training epoch and updates the model parameters.
|
||||
|
||||
Automatically sets epoch of sampler if possible.
|
||||
"""
|
||||
if hasattr(self.train_loader, "sampler") and hasattr(
|
||||
if iterator is None and hasattr(self.train_loader, "sampler") and \
|
||||
hasattr(
|
||||
self.train_loader.sampler, "set_epoch"):
|
||||
self.train_loader.sampler.set_epoch(self.epochs)
|
||||
return super(DistributedTorchRunner, self).train_epoch(**kwargs)
|
||||
return super(DistributedTorchRunner, self).train_epoch(
|
||||
num_steps=num_steps, profile=profile, info=info, iterator=iterator)
|
||||
|
||||
def shutdown(self):
|
||||
"""Attempts to shut down the worker."""
|
||||
|
||||
@@ -0,0 +1,145 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from ray.util.sgd import TorchTrainer
|
||||
from ray.util.sgd.torch import TrainingOperator
|
||||
from torch.nn import functional as F
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from torchvision.datasets import MNIST
|
||||
import os
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
class LitMNIST(LightningModule):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
# mnist images are (1, 28, 28) (channels, width, height)
|
||||
self.layer_1 = torch.nn.Linear(28 * 28, 128)
|
||||
self.layer_2 = torch.nn.Linear(128, 256)
|
||||
self.layer_3 = torch.nn.Linear(256, 10)
|
||||
|
||||
self.config = config
|
||||
|
||||
def forward(self, x):
|
||||
batch_size, channels, width, height = x.size()
|
||||
|
||||
# (b, 1, 28, 28) -> (b, 1*28*28)
|
||||
x = x.view(batch_size, -1)
|
||||
x = self.layer_1(x)
|
||||
x = torch.relu(x)
|
||||
x = self.layer_2(x)
|
||||
x = torch.relu(x)
|
||||
x = self.layer_3(x)
|
||||
|
||||
x = torch.log_softmax(x, dim=1)
|
||||
return x
|
||||
|
||||
def configure_optimizers(self):
|
||||
return Adam(self.parameters(), lr=self.config["lr"])
|
||||
|
||||
def setup(self, stage):
|
||||
# transforms for images
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307, ), (0.3081, ))
|
||||
])
|
||||
|
||||
# prepare transforms standard to MNIST
|
||||
mnist_train = MNIST(
|
||||
os.getcwd(), train=True, download=True, transform=transform)
|
||||
|
||||
self.mnist_train, self.mnist_val = random_split(
|
||||
mnist_train, [55000, 5000])
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(
|
||||
self.mnist_train, batch_size=self.config["batch_size"])
|
||||
|
||||
def val_dataloader(self):
|
||||
return DataLoader(self.mnist_val, batch_size=self.config["batch_size"])
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
logits = self(x)
|
||||
loss = F.nll_loss(logits, y)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
logits = self(x)
|
||||
loss = F.nll_loss(logits, y)
|
||||
_, predicted = torch.max(logits.data, 1)
|
||||
num_correct = (predicted == y).sum().item()
|
||||
num_samples = y.size(0)
|
||||
return {"val_loss": loss.item(), "val_acc": num_correct / num_samples}
|
||||
|
||||
|
||||
def train_mnist(num_workers=1, use_gpu=False, num_epochs=5):
|
||||
Operator = TrainingOperator.from_ptl(LitMNIST)
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=Operator,
|
||||
num_workers=num_workers,
|
||||
config={
|
||||
"lr": 1e-3,
|
||||
"batch_size": 64
|
||||
},
|
||||
use_gpu=use_gpu,
|
||||
use_tqdm=True,
|
||||
)
|
||||
for i in range(num_epochs):
|
||||
stats = trainer.train()
|
||||
print(stats)
|
||||
|
||||
print(trainer.validate())
|
||||
print("Saving model checkpoint to ./model.pt")
|
||||
trainer.save("./model.pt")
|
||||
print("Model Checkpointed!")
|
||||
trainer.shutdown()
|
||||
print("success!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--address",
|
||||
required=False,
|
||||
type=str,
|
||||
help="the address to use for Ray")
|
||||
parser.add_argument(
|
||||
"--num-workers",
|
||||
"-n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Sets number of workers for training.")
|
||||
parser.add_argument(
|
||||
"--use-gpu",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enables GPU training")
|
||||
parser.add_argument(
|
||||
"--num-epochs",
|
||||
type=int,
|
||||
default=5,
|
||||
help="How many epochs to train "
|
||||
"for.")
|
||||
parser.add_argument(
|
||||
"--smoke-test",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Finish quickly for testing.")
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
import ray
|
||||
if args.smoke_test:
|
||||
ray.init(num_cpus=2)
|
||||
args.num_epochs = 1
|
||||
else:
|
||||
ray.init(address=args.address)
|
||||
train_mnist(
|
||||
num_workers=args.num_workers,
|
||||
use_gpu=args.use_gpu,
|
||||
num_epochs=args.num_epochs)
|
||||
@@ -0,0 +1,502 @@
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from pytorch_lightning.core.step_result import Result
|
||||
from pytorch_lightning.overrides.data_parallel import \
|
||||
LightningDistributedDataParallel
|
||||
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
|
||||
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
|
||||
import pytorch_lightning as ptl
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.memory import recursive_detach
|
||||
from ray.util.sgd.torch import TrainingOperator
|
||||
from ray.util.sgd.torch.constants import NUM_STEPS, SCHEDULER_STEP_BATCH, \
|
||||
SCHEDULER_STEP_EPOCH
|
||||
from ray.util.sgd.utils import AverageMeterCollection, NUM_SAMPLES
|
||||
|
||||
tqdm = None
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
|
||||
TrainerOptimizersMixin):
|
||||
def _configure_amp(self, amp, models, optimizers):
|
||||
assert len(models) == 1
|
||||
model = models[0]
|
||||
assert isinstance(model, ptl.LightningModule)
|
||||
amp_level = self._apex_args.get("opt_level", "O2")
|
||||
model, optimizers = model.configure_apex(
|
||||
amp, model, optimizers, amp_level=amp_level)
|
||||
return [model], optimizers
|
||||
|
||||
def _configure_ddp(self, models, device_ids):
|
||||
assert len(models) == 1
|
||||
model = models[0]
|
||||
assert isinstance(model, ptl.LightningModule)
|
||||
# This will default to LightningDistributedDataParallel.
|
||||
model = model.configure_ddp(model=model, device_ids=device_ids)
|
||||
return [model]
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
"""The LightningModule to use for training.
|
||||
|
||||
The returned model is wrapped in DDP if using distributed training.
|
||||
"""
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def scheduler_dicts(self):
|
||||
"""Returns list of scheduler dictionaries.
|
||||
|
||||
List is empty if no schedulers are returned in the
|
||||
configure_optimizers method of your LightningModule. Default
|
||||
configuration is used if configure_optimizers returns scheduler
|
||||
objects instead of scheduler dicts. See
|
||||
https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html#configure-optimizers
|
||||
"""
|
||||
return self._scheduler_dicts
|
||||
|
||||
@property
|
||||
def optimizers(self):
|
||||
"""Returns list of optimizers as returned by configure_optimizers."""
|
||||
return self._optimizers
|
||||
|
||||
@property
|
||||
def schedulers(self):
|
||||
"""Returns list of schedulers as returned by configure_optimizers.
|
||||
|
||||
List is empty if no schedulers are returned in configure_optimizers.
|
||||
"""
|
||||
return self._schedulers
|
||||
|
||||
def get_model(self):
|
||||
"""Returns original LightningModule, not wrapped in DDP."""
|
||||
if isinstance(self.model, LightningDistributedDataParallel):
|
||||
return self.model.module
|
||||
else:
|
||||
return self.model
|
||||
|
||||
def setup(self, config):
|
||||
# Pass in config if ptl_module accepts it.
|
||||
ptl_class = self.__class__._lightning_module_cls
|
||||
if not issubclass(ptl_class, ptl.LightningModule):
|
||||
raise TypeError("Argument must be subclass of "
|
||||
"pytorch_lightning.LightningModule. Got class {} "
|
||||
"instead.".format(ptl_class))
|
||||
if "config" in inspect.signature(ptl_class.__init__).parameters:
|
||||
ptl_module = ptl_class(config=config)
|
||||
else:
|
||||
ptl_module = ptl_class()
|
||||
|
||||
# This is needed for LightningDistributedDataParallel.
|
||||
ptl_module.testing = False
|
||||
|
||||
# Call on_fit_start on instantiation.
|
||||
if self.is_function_implemented("on_fit_start", ptl_module):
|
||||
ptl_module.on_fit_start()
|
||||
|
||||
# Only run data preparation once per node.
|
||||
if self.local_rank == 0 and self.is_function_implemented(
|
||||
"prepare_data", ptl_module):
|
||||
ptl_module.prepare_data()
|
||||
|
||||
# Call model.setup.
|
||||
ptl_module.setup("fit")
|
||||
|
||||
if not self.is_overridden("configure_optimizers", ptl_module):
|
||||
raise MisconfigurationException(
|
||||
"No `configure_optimizers()` method defined.")
|
||||
|
||||
optimizers, self._scheduler_dicts, optimizer_frequencies = \
|
||||
self.init_optimizers(model=ptl_module)
|
||||
|
||||
if len(optimizer_frequencies) > 0:
|
||||
logger.warning("Optimizer frequencies will be ignored. When "
|
||||
"passing in multiple optimizers, you should "
|
||||
"implement your own custom training loop.")
|
||||
|
||||
lr_schedulers = []
|
||||
for scheduler in self.scheduler_dicts:
|
||||
if isinstance(scheduler, dict):
|
||||
# A scheduler dictionary is passed in.
|
||||
if "reduce_on_plateau" in scheduler and "monitor" in \
|
||||
scheduler and scheduler["reduce_on_plateau"] is True:
|
||||
logger.info(
|
||||
"reduce_on_plateau and monitor will be "
|
||||
"ignored "
|
||||
"from the scheduler dict {}. To update a "
|
||||
"ReduceLROnPlateau scheduler, you should use "
|
||||
"TorchTrainer.update_schedulers.".format(scheduler))
|
||||
if "frequency" in scheduler and scheduler["frequency"] > 1:
|
||||
logger.info("frequency will be ignored from the "
|
||||
"scheduler dict {}.".format(scheduler))
|
||||
lr_schedulers.append(scheduler["scheduler"])
|
||||
else:
|
||||
lr_schedulers.append(scheduler)
|
||||
|
||||
# Set this so register doesn't complain.
|
||||
self._scheduler_step_freq = "ptl"
|
||||
ddp_model, self._optimizers, self._schedulers = self.register(
|
||||
models=[ptl_module],
|
||||
optimizers=optimizers,
|
||||
schedulers=lr_schedulers)
|
||||
|
||||
assert len(ddp_model) == 1
|
||||
self._model = ddp_model[0]
|
||||
|
||||
model = self.get_model()
|
||||
if self.is_function_implemented("on_pretrain_routine_start", model):
|
||||
model.on_pretrain_routine_start()
|
||||
|
||||
train_data_loader = None
|
||||
if self.__class__._train_dataloader:
|
||||
train_data_loader = self.__class__._train_dataloader
|
||||
elif self.is_function_implemented("train_dataloader", model):
|
||||
train_data_loader = model.train_dataloader()
|
||||
|
||||
val_data_loader = None
|
||||
if self.__class__._val_dataloader:
|
||||
val_data_loader = self.__class__._val_dataloader
|
||||
elif self.is_function_implemented("val_dataloader", model):
|
||||
val_data_loader = model.val_dataloader()
|
||||
|
||||
self.register_data(
|
||||
train_loader=train_data_loader, validation_loader=val_data_loader)
|
||||
|
||||
def train_epoch(self, iterator, info):
|
||||
model = self.get_model()
|
||||
|
||||
# Enable train mode.
|
||||
self.model.train()
|
||||
|
||||
# Enable gradients.
|
||||
torch.set_grad_enabled(True)
|
||||
|
||||
if self.is_function_implemented("on_train_epoch_start", model):
|
||||
model.on_train_epoch_start()
|
||||
|
||||
if self.use_tqdm and self.world_rank == 0:
|
||||
desc = ""
|
||||
if info is not None and "epoch_idx" in info:
|
||||
if "num_epochs" in info:
|
||||
desc = f"{info['epoch_idx'] + 1}/{info['num_epochs']}e"
|
||||
else:
|
||||
desc = f"{info['epoch_idx'] + 1}e"
|
||||
|
||||
# TODO: Implement len for Dataset?
|
||||
total = info[NUM_STEPS]
|
||||
if total is None:
|
||||
if hasattr(iterator, "__len__"):
|
||||
total = len(iterator)
|
||||
|
||||
_progress_bar = tqdm(
|
||||
total=total, desc=desc, unit="batch", leave=False)
|
||||
|
||||
# Output for each batch.
|
||||
epoch_outputs = []
|
||||
|
||||
for batch_idx, batch in enumerate(iterator):
|
||||
batch_info = {
|
||||
"batch_idx": batch_idx,
|
||||
"global_step": self.global_step
|
||||
}
|
||||
batch_info.update(info)
|
||||
batch_output = self.train_batch(batch, batch_info=batch_info)
|
||||
# batch output for each optimizer.
|
||||
epoch_outputs.append(batch_output)
|
||||
|
||||
should_stop = batch_output["signal"] == -1
|
||||
|
||||
if self.use_tqdm and self.world_rank == 0:
|
||||
_progress_bar.n = batch_idx + 1
|
||||
postfix = {}
|
||||
if "training_loss" in batch_output:
|
||||
postfix.update(loss=batch_output["training_loss"])
|
||||
_progress_bar.set_postfix(postfix)
|
||||
|
||||
for s_dict, scheduler in zip(self.scheduler_dicts,
|
||||
self.schedulers):
|
||||
if s_dict["interval"] == SCHEDULER_STEP_BATCH:
|
||||
scheduler.step()
|
||||
|
||||
self.global_step += 1
|
||||
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
processed_outputs = None
|
||||
if self.is_overridden("training_epoch_end", model):
|
||||
raw_outputs = [eo["raw_output"] for eo in epoch_outputs]
|
||||
processed_outputs = model.training_epoch_end(raw_outputs)
|
||||
|
||||
if processed_outputs is not None:
|
||||
if isinstance(processed_outputs, torch.Tensor):
|
||||
return_output = {"train_loss": processed_outputs}
|
||||
elif isinstance(processed_outputs, Result):
|
||||
raise ValueError("Result objects are not supported. Please "
|
||||
"return a dictionary instead.")
|
||||
elif isinstance(processed_outputs, dict):
|
||||
return_output = processed_outputs
|
||||
else:
|
||||
raise TypeError("training_epoch_end returned an invalid "
|
||||
"type. It must return a Tensor, Result, "
|
||||
"or dict.")
|
||||
else:
|
||||
# User did not override training_epoch_end
|
||||
assert isinstance(epoch_outputs, list)
|
||||
# Use AverageMeterCollection util to reduce results.
|
||||
meter_collection = AverageMeterCollection()
|
||||
for o in epoch_outputs:
|
||||
num_samples = o.pop(NUM_SAMPLES, 1)
|
||||
raw_output = o["raw_output"]
|
||||
if isinstance(raw_output, dict):
|
||||
meter_collection.update(raw_output, num_samples)
|
||||
elif isinstance(raw_output, torch.Tensor):
|
||||
meter_collection.update({
|
||||
"train_loss": o["training_loss"]
|
||||
}, num_samples)
|
||||
return_output = meter_collection.summary()
|
||||
|
||||
if self.is_function_implemented("on_train_epoch_end", model):
|
||||
model.on_train_epoch_end()
|
||||
|
||||
for s_dict, scheduler in zip(self.scheduler_dicts, self.schedulers):
|
||||
if s_dict["interval"] == SCHEDULER_STEP_EPOCH:
|
||||
scheduler.step()
|
||||
|
||||
return return_output
|
||||
|
||||
def train_batch(self, batch, batch_info):
|
||||
# Get the original PTL module.
|
||||
model = self.get_model()
|
||||
optimizer = self.optimizers[0]
|
||||
batch_idx = batch_info["batch_idx"]
|
||||
epoch_idx = batch_info["epoch_idx"]
|
||||
|
||||
if self.is_function_implemented("on_train_batch_start", model):
|
||||
response = model.on_train_batch_start(
|
||||
batch=batch, batch_idx=batch_idx, dataloader_idx=0)
|
||||
# Skip remainder of epoch if response is -1.
|
||||
if response == -1:
|
||||
return {"signal": -1}
|
||||
|
||||
args = [batch, batch_idx]
|
||||
if len(self.optimizers) > 1:
|
||||
if self.has_arg("training_step", "optimizer_idx"):
|
||||
args.append(0)
|
||||
|
||||
with self.timers.record("fwd"):
|
||||
if self._is_distributed:
|
||||
# Use the DDP wrapped model (self.model).
|
||||
output = self.model(*args)
|
||||
elif self.use_gpu:
|
||||
# Using single GPU.
|
||||
# Don't copy the batch since there is a single gpu that
|
||||
# the batch could be referenced from and if there are
|
||||
# multiple optimizers the batch will wind up copying it to
|
||||
# the same device repeatedly.
|
||||
device = self.device
|
||||
batch = model.transfer_batch_to_device(batch, device=device)
|
||||
args[0] = batch
|
||||
output = model.training_step(*args)
|
||||
else:
|
||||
# Using CPU.
|
||||
output = model.training_step(*args)
|
||||
|
||||
if isinstance(output, Result):
|
||||
raise ValueError("TrainResult objects are not supported. Please "
|
||||
"return a dictionary instead.")
|
||||
|
||||
# allow any mode to define training_step_end
|
||||
# do something will all the dp outputs (like softmax)
|
||||
if self.is_overridden("training_step_end", model):
|
||||
output = model.training_step_end(output)
|
||||
|
||||
# Extract loss from output if dictionary.
|
||||
try:
|
||||
loss = output["loss"]
|
||||
except Exception:
|
||||
if isinstance(output, torch.Tensor):
|
||||
loss = output
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"No `loss` value in the dictionary returned from "
|
||||
"`model.training_step()`.")
|
||||
|
||||
# If output contains tensors, detach them all.
|
||||
if isinstance(output, torch.Tensor):
|
||||
output = output.detach()
|
||||
elif isinstance(output, dict):
|
||||
output = recursive_detach(output)
|
||||
else:
|
||||
raise TypeError("training_step returned invalid type. It must "
|
||||
"return either a Tensor, Result, or dict.")
|
||||
|
||||
untouched_loss = loss.detach().clone()
|
||||
|
||||
with self.timers.record("grad"):
|
||||
if self.use_fp16:
|
||||
with self._amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
model.backward(
|
||||
self, scaled_loss, optimizer, optimizer_idx=0)
|
||||
else:
|
||||
model.backward(self, loss, optimizer, optimizer_idx=0)
|
||||
|
||||
if self.is_function_implemented("on_after_backward", model):
|
||||
model.on_after_backward()
|
||||
|
||||
with self.timers.record("apply"):
|
||||
model.optimizer_step(
|
||||
epoch=epoch_idx,
|
||||
batch_idx=batch_idx,
|
||||
optimizer=optimizer,
|
||||
optimizer_idx=0)
|
||||
|
||||
model.on_before_zero_grad(optimizer)
|
||||
|
||||
model.optimizer_zero_grad(
|
||||
epoch=epoch_idx,
|
||||
batch_idx=batch_idx,
|
||||
optimizer=optimizer,
|
||||
optimizer_idx=0)
|
||||
|
||||
if self.is_function_implemented("on_train_batch_end", model):
|
||||
model.on_train_batch_end(
|
||||
batch=batch, batch_idx=batch_idx, dataloader_idx=0)
|
||||
|
||||
return {
|
||||
"signal": 0,
|
||||
"training_loss": untouched_loss.item(),
|
||||
"raw_output": output,
|
||||
# NUM_SAMPLES: len(batch)
|
||||
}
|
||||
|
||||
def validate(self, val_iterator, info):
|
||||
self.model.zero_grad()
|
||||
self.model.eval()
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
model = self.get_model()
|
||||
if self.is_function_implemented("on_validation_epoch_start", model):
|
||||
model.on_validation_epoch_start()
|
||||
|
||||
val_outputs = []
|
||||
for batch_idx, batch in enumerate(val_iterator):
|
||||
batch_info = {"batch_idx": batch_idx}
|
||||
batch_info.update(info)
|
||||
batch_output = self.validate_batch(batch, batch_info)
|
||||
if batch_output is not None:
|
||||
val_outputs.append(batch_output)
|
||||
|
||||
processed_outputs = None
|
||||
if self.is_overridden("validation_epoch_end", model):
|
||||
raw_outputs = [vo["raw_output"] for vo in val_outputs]
|
||||
processed_outputs = model.training_epoch_end(raw_outputs)
|
||||
|
||||
if processed_outputs is not None:
|
||||
if isinstance(processed_outputs, torch.Tensor):
|
||||
return_output = {"val_loss": processed_outputs}
|
||||
elif isinstance(processed_outputs, Result):
|
||||
raise ValueError("Result objects are not supported. Please "
|
||||
"return a dictionary instead.")
|
||||
elif isinstance(processed_outputs, dict):
|
||||
return_output = processed_outputs
|
||||
else:
|
||||
raise TypeError("validation_epoch_end returned an invalid "
|
||||
"type. It must return a Tensor, Result, "
|
||||
"or dict.")
|
||||
else:
|
||||
# User did not override training_epoch_end
|
||||
assert isinstance(val_outputs, list)
|
||||
# Use AverageMeterCollection util to reduce results.
|
||||
meter_collection = AverageMeterCollection()
|
||||
for v in val_outputs:
|
||||
num_samples = v.pop(NUM_SAMPLES, 1)
|
||||
raw_output = v["raw_output"]
|
||||
if isinstance(raw_output, dict):
|
||||
meter_collection.update(raw_output, num_samples)
|
||||
elif isinstance(raw_output, torch.Tensor):
|
||||
meter_collection.update({
|
||||
"val_loss": raw_output.item()
|
||||
}, num_samples)
|
||||
return_output = meter_collection.summary()
|
||||
|
||||
if self.is_function_implemented("on_validation_epoch_end", model):
|
||||
model.on_validation_epoch_end()
|
||||
|
||||
# Set back to True so training will work.
|
||||
torch.set_grad_enabled(True)
|
||||
|
||||
return return_output
|
||||
|
||||
def validate_batch(self, batch, batch_info):
|
||||
model = self.get_model()
|
||||
batch_idx = batch_info["batch_idx"]
|
||||
if self.is_overridden("on_validation_batch_start", model):
|
||||
model.on_validation_batch_start(
|
||||
batch=batch, batch_idx=batch_idx, dataloader_idx=0)
|
||||
args = [batch, batch_idx]
|
||||
with self.timers.record("eval_fwd"):
|
||||
if self._is_distributed:
|
||||
# Use the DDP wrapped model (self.model).
|
||||
output = self.model(*args)
|
||||
elif self.use_gpu:
|
||||
# Using single GPU.
|
||||
device = self.device
|
||||
batch = model.transfer_batch_to_device(batch, device=device)
|
||||
args[0] = batch
|
||||
output = model.validation_step(*args)
|
||||
else:
|
||||
# Using CPU.
|
||||
output = model.validation_step(*args)
|
||||
|
||||
if isinstance(output, Result):
|
||||
raise ValueError("EvalResult objects are not supported. Please "
|
||||
"return a dictionary instead.")
|
||||
|
||||
if self.is_overridden("on_validation_step_end", model):
|
||||
output = model.validation_step_end(output)
|
||||
|
||||
if self.is_function_implemented("on_validation_batch_end", model):
|
||||
model.on_validation_batch_end(
|
||||
batch=batch, batch_idx=batch_idx, dataloader_idx=0)
|
||||
return {
|
||||
"raw_output": output,
|
||||
# NUM_SAMPLES: len(batch)
|
||||
}
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {}
|
||||
self.get_model().on_save_checkpoint(checkpoint=state_dict)
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.get_model().on_load_checkpoint(checkpoint=state_dict)
|
||||
|
||||
def _get_train_loader(self):
|
||||
if not hasattr(self, "_train_loader") or \
|
||||
self._train_loader is None:
|
||||
raise RuntimeError("Training Operator does not have any "
|
||||
"registered training loader. Make sure "
|
||||
"to pass in a training loader to "
|
||||
"TrainingOperator.from_ptl or implement "
|
||||
"train_dataloader in your LightningModule.")
|
||||
return self._train_loader
|
||||
|
||||
def _get_validation_loader(self):
|
||||
if not hasattr(self, "_validation_loader") or \
|
||||
self._validation_loader is None:
|
||||
raise RuntimeError("Training Operator does not have any "
|
||||
"registered validation loader. Make sure "
|
||||
"to pass in a validation loader to "
|
||||
"TrainingOperator.from_ptl or implement "
|
||||
"val_dataloader in your LightningModule.")
|
||||
return self._validation_loader
|
||||
@@ -1,6 +1,8 @@
|
||||
import logging
|
||||
import io
|
||||
import itertools
|
||||
|
||||
import ray
|
||||
import torch
|
||||
|
||||
from ray.util.sgd.torch.constants import USE_FP16, NUM_STEPS
|
||||
@@ -50,6 +52,8 @@ class TorchRunner:
|
||||
self.training_operator = self.training_operator_cls(
|
||||
self.config,
|
||||
world_rank=0,
|
||||
local_rank=0,
|
||||
is_distributed=False,
|
||||
use_gpu=self.use_gpu,
|
||||
use_fp16=self.use_fp16,
|
||||
use_tqdm=self.use_tqdm,
|
||||
@@ -69,6 +73,7 @@ class TorchRunner:
|
||||
info.update({
|
||||
NUM_STEPS: num_steps,
|
||||
USE_FP16: self.use_fp16,
|
||||
"epoch_idx": self.epochs,
|
||||
})
|
||||
with self.timers.record("train_epoch"):
|
||||
if iterator is None:
|
||||
@@ -94,10 +99,6 @@ class TorchRunner:
|
||||
|
||||
def validate(self, num_steps=None, profile=False, info=None):
|
||||
"""Evaluates the model on the validation data set."""
|
||||
if self.validation_loader is None:
|
||||
raise ValueError("No validation dataloader provided. Make sure"
|
||||
"you pass in a validation_loader to "
|
||||
"TrainingOperator.register_data.")
|
||||
info = info or {}
|
||||
self._toggle_profiling(profile=profile)
|
||||
validation_loader = self.validation_loader
|
||||
@@ -204,62 +205,31 @@ class TorchRunner:
|
||||
"""Getter method. Needed for remote actor calls."""
|
||||
return self.models
|
||||
|
||||
def get_node_ip(self):
|
||||
return ray.services.get_node_ip_address()
|
||||
|
||||
@property
|
||||
def models(self):
|
||||
if not hasattr(self.training_operator, "_original_models"):
|
||||
raise RuntimeError("Training Operator does not have any "
|
||||
"registered models. Are you calling "
|
||||
"self.register(...) inside the setup method "
|
||||
"of your Training Operator?")
|
||||
return self.training_operator._original_models
|
||||
return self.training_operator._get_original_models()
|
||||
|
||||
@property
|
||||
def optimizers(self):
|
||||
if not hasattr(self.training_operator, "_optimizers"):
|
||||
raise RuntimeError("Training Operator does not have any "
|
||||
"registered optimizers. Are you calling "
|
||||
"self.register(...) inside the setup method "
|
||||
"of your Training Operator?")
|
||||
return self.training_operator._optimizers
|
||||
return self.training_operator._get_optimizers()
|
||||
|
||||
@property
|
||||
def schedulers(self):
|
||||
if not hasattr(self.training_operator, "_schedulers"):
|
||||
raise RuntimeError("Training Operator does not have any "
|
||||
"registered schedulers. Are you calling "
|
||||
"self.register(...) inside the setup method "
|
||||
"of your Training Operator?")
|
||||
return self.training_operator._schedulers
|
||||
return self.training_operator._get_schedulers()
|
||||
|
||||
@property
|
||||
def train_loader(self):
|
||||
if not hasattr(self.training_operator, "_train_loader"):
|
||||
logger.warning("Training Operator does not have any "
|
||||
"registered train loader. If this is "
|
||||
"unexepected, make sure to call "
|
||||
"self.register_data(...) inside the setup method "
|
||||
"of your Training Operator.")
|
||||
return None
|
||||
return self.training_operator._train_loader
|
||||
return self.training_operator._get_train_loader()
|
||||
|
||||
@property
|
||||
def validation_loader(self):
|
||||
if not hasattr(self.training_operator, "_validation_loader"):
|
||||
logger.warning("Training Operator does not have any "
|
||||
"registered validation loader. If this is "
|
||||
"unexepected, make sure to call "
|
||||
"self.register_data(...) inside the setup method "
|
||||
"of your Training Operator.")
|
||||
return None
|
||||
return self.training_operator._validation_loader
|
||||
return self.training_operator._get_validation_loader()
|
||||
|
||||
@property
|
||||
def criterion(self):
|
||||
if not hasattr(self.training_operator, "_criterion"):
|
||||
raise RuntimeError("Training Operator does not have any "
|
||||
"registered criterion. Are you calling "
|
||||
"self.register(...) inside the setup method "
|
||||
"of your Training Operator?")
|
||||
return self.training_operator._criterion
|
||||
|
||||
@property
|
||||
|
||||
@@ -439,7 +439,7 @@ class TorchTrainer:
|
||||
}
|
||||
|
||||
for stat_key in worker_stats[0]:
|
||||
if isinstance(worker_stats[0], numbers.Number):
|
||||
if isinstance(worker_stats[0][stat_key], numbers.Number):
|
||||
stats[stat_key] = np.nanmean(
|
||||
[s.get(stat_key, np.nan) for s in worker_stats])
|
||||
else:
|
||||
|
||||
@@ -14,6 +14,7 @@ from ray.util.sgd.torch.constants import (
|
||||
NUM_STEPS,
|
||||
SCHEDULER_STEP_BATCH,
|
||||
)
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DistributedSampler, DataLoader, IterableDataset
|
||||
|
||||
@@ -78,7 +79,6 @@ class TrainingOperator:
|
||||
train_loader=train_loader,
|
||||
validation_loader=val_loader)
|
||||
|
||||
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=MyTrainingOperator,
|
||||
config={"batch_size": 32},
|
||||
@@ -119,18 +119,21 @@ class TrainingOperator:
|
||||
def __init__(self,
|
||||
config,
|
||||
world_rank,
|
||||
local_rank,
|
||||
is_distributed=False,
|
||||
device_ids=None,
|
||||
use_gpu=False,
|
||||
use_fp16=False,
|
||||
use_tqdm=False,
|
||||
apex_args=None,
|
||||
wrap_ddp=False,
|
||||
wrap_distributed_sampler=False,
|
||||
add_dist_sampler=False,
|
||||
scheduler_step_freq=None):
|
||||
# You are not expected to override this method.
|
||||
self._world_rank = world_rank
|
||||
self._local_rank = local_rank
|
||||
self._config = config
|
||||
self._is_distributed = is_distributed
|
||||
self._use_fp16 = use_fp16
|
||||
self._device_ids = device_ids
|
||||
self._use_gpu = use_gpu and torch.cuda.is_available()
|
||||
@@ -141,7 +144,6 @@ class TrainingOperator:
|
||||
self.global_step = 0
|
||||
self._apex_args = apex_args if apex_args else {}
|
||||
self._wrap_ddp = wrap_ddp
|
||||
self._wrap_distributed_sampler = wrap_distributed_sampler
|
||||
self._add_dist_sampler = add_dist_sampler
|
||||
self._scheduler_step_freq = scheduler_step_freq
|
||||
|
||||
@@ -152,6 +154,28 @@ class TrainingOperator:
|
||||
"""Passes in the timers from the Runner."""
|
||||
self.timers = timers
|
||||
|
||||
def _configure_amp(self, amp, models, optimizers):
|
||||
models, optimizers = amp.initialize(models, optimizers,
|
||||
**self._apex_args)
|
||||
return models, optimizers
|
||||
|
||||
def _configure_ddp(self, models, device_ids):
|
||||
return [
|
||||
DistributedDataParallel(model, device_ids=device_ids)
|
||||
for model in models
|
||||
]
|
||||
|
||||
def _return_items(self, items, original_items):
|
||||
"""Helper method to return items in same format as original_items."""
|
||||
if isinstance(original_items, tuple):
|
||||
return tuple(items)
|
||||
elif isinstance(original_items, Iterable):
|
||||
# Items is already a list.
|
||||
return items
|
||||
else:
|
||||
assert len(items) == 1
|
||||
return items[0]
|
||||
|
||||
def setup(self, config):
|
||||
"""Override this method to implement operator setup.
|
||||
|
||||
@@ -218,7 +242,6 @@ class TrainingOperator:
|
||||
Returns:
|
||||
Tuple of model, optimizer, criterion if not None, and scheduler
|
||||
if not None.
|
||||
|
||||
"""
|
||||
return_vals = []
|
||||
logger.debug("Registering models.")
|
||||
@@ -244,7 +267,10 @@ class TrainingOperator:
|
||||
if not isinstance(self._schedulers, Iterable):
|
||||
self._schedulers = [self._schedulers]
|
||||
else:
|
||||
self._schedulers = None
|
||||
if isinstance(schedulers, Iterable):
|
||||
self._schedulers = []
|
||||
else:
|
||||
self._schedulers = None
|
||||
|
||||
if criterion:
|
||||
logger.debug("Registering loss.")
|
||||
@@ -257,28 +283,19 @@ class TrainingOperator:
|
||||
|
||||
if self.use_fp16 and amp:
|
||||
logger.debug("Setting up Apex.")
|
||||
self._original_models, self._optimizers = amp.initialize(
|
||||
self._original_models, self._optimizers, **self._apex_args)
|
||||
self._amp = amp
|
||||
self._original_models, self._optimizers = self._configure_amp(
|
||||
self._amp, self._original_models, self._optimizers)
|
||||
|
||||
if self._wrap_ddp:
|
||||
logging.debug("Setting up DDP for models.")
|
||||
self._models = [
|
||||
DistributedDataParallel(model, device_ids=self.device_ids)
|
||||
for model in self._original_models
|
||||
]
|
||||
self._models = self._configure_ddp(
|
||||
models=self._original_models, device_ids=self.device_ids)
|
||||
else:
|
||||
self._models = self._original_models
|
||||
|
||||
if len(self._models) == 1:
|
||||
return_vals.append(self._models[0])
|
||||
else:
|
||||
return_vals.append(self._models)
|
||||
|
||||
if len(self._optimizers) == 1:
|
||||
return_vals.append(self._optimizers[0])
|
||||
else:
|
||||
return_vals.append(self._optimizers)
|
||||
return_vals.append(self._return_items(self._models, models))
|
||||
return_vals.append(self._return_items(self._optimizers, optimizers))
|
||||
|
||||
if self._criterion is not None:
|
||||
return_vals.append(self._criterion)
|
||||
@@ -290,10 +307,8 @@ class TrainingOperator:
|
||||
"are registering schedulers. Set this to "
|
||||
"'manual' if you will be manually stepping "
|
||||
"the schedulers.")
|
||||
if len(self._schedulers) == 1:
|
||||
return_vals.append(self._schedulers[0])
|
||||
else:
|
||||
return_vals.append(self._schedulers)
|
||||
return_vals.append(
|
||||
self._return_items(self._schedulers, schedulers))
|
||||
|
||||
return tuple(return_vals)
|
||||
|
||||
@@ -348,8 +363,7 @@ class TrainingOperator:
|
||||
self._train_loader = train_loader
|
||||
self._validation_loader = validation_loader
|
||||
|
||||
if self._wrap_distributed_sampler:
|
||||
logging.debug("Wrapping data loaders with DistributedSampler.")
|
||||
if self._is_distributed:
|
||||
|
||||
def with_sampler(loader):
|
||||
# Automatically set the DistributedSampler
|
||||
@@ -373,11 +387,15 @@ class TrainingOperator:
|
||||
|
||||
if should_wrap_dataloader(self._train_loader):
|
||||
if self._add_dist_sampler:
|
||||
logging.debug("Wrapping train data loader with "
|
||||
"DistributedSampler.")
|
||||
self._train_loader = with_sampler(self._train_loader)
|
||||
|
||||
if self._validation_loader is not None and should_wrap_dataloader(
|
||||
self._validation_loader):
|
||||
if self._add_dist_sampler:
|
||||
logging.debug("Wrapping validation data loader with "
|
||||
"DistributedSampler.")
|
||||
self._validation_loader = with_sampler(
|
||||
self._validation_loader)
|
||||
|
||||
@@ -665,6 +683,90 @@ class TrainingOperator:
|
||||
state_dict (dict): State dict as returned by the operator. """
|
||||
pass
|
||||
|
||||
def _get_original_models(self):
|
||||
if not hasattr(self, "_original_models"):
|
||||
raise RuntimeError("Training Operator does not have any "
|
||||
"registered models. Are you calling "
|
||||
"self.register(...) inside the setup method "
|
||||
"of your Training Operator?")
|
||||
return self._original_models
|
||||
|
||||
def _get_optimizers(self):
|
||||
if not hasattr(self, "_optimizers"):
|
||||
raise RuntimeError("Training Operator does not have any "
|
||||
"registered optimizers. Are you calling "
|
||||
"self.register(...) inside the setup method "
|
||||
"of your Training Operator?")
|
||||
return self._optimizers
|
||||
|
||||
def _get_schedulers(self):
|
||||
if not hasattr(self, "_schedulers"):
|
||||
raise RuntimeError("Training Operator does not have any "
|
||||
"registered schedulers. Are you calling "
|
||||
"self.register(...) inside the setup method "
|
||||
"of your Training Operator?")
|
||||
return self._schedulers
|
||||
|
||||
def _get_train_loader(self):
|
||||
if not hasattr(self, "_train_loader") or \
|
||||
self._train_loader is None:
|
||||
raise RuntimeError(
|
||||
"Training Operator does not have any "
|
||||
"registered train loader. If this is "
|
||||
"unexepected, make sure to call "
|
||||
"self.register_data(...) inside the setup method "
|
||||
"of your Training Operator.")
|
||||
return self._train_loader
|
||||
|
||||
def _get_validation_loader(self):
|
||||
if not hasattr(self, "_validation_loader") or \
|
||||
self._validation_loader is None:
|
||||
raise RuntimeError(
|
||||
"Training Operator does not have any "
|
||||
"registered validation loader. If this is "
|
||||
"unexepected, make sure to call "
|
||||
"self.register_data(...) inside the setup method "
|
||||
"of your Training Operator.")
|
||||
return self._validation_loader
|
||||
|
||||
def _get_criterion(self):
|
||||
if not hasattr(self, "_criterion"):
|
||||
raise RuntimeError("Training Operator does not have any "
|
||||
"registered criterion. Are you calling "
|
||||
"self.register(...) inside the setup method "
|
||||
"of your Training Operator?")
|
||||
return self._criterion
|
||||
|
||||
@classmethod
|
||||
def from_ptl(cls,
|
||||
lightning_module_cls,
|
||||
train_dataloader=None,
|
||||
val_dataloader=None):
|
||||
"""Creates a TrainingOperator from a Pytorch Lightning Module.
|
||||
|
||||
Args:
|
||||
lightning_module_cls: Your LightningModule class. An object of
|
||||
this class will get instantiated on each worker.
|
||||
train_dataloader: The data loader to use for training. If None
|
||||
is provided, LightningModule.train_dataloader will be used
|
||||
instead.
|
||||
val_dataloader: The data loader to use for validation. If None
|
||||
is provided, LightningModule.val_dataloader will be used
|
||||
instead.
|
||||
|
||||
Returns:
|
||||
A TrainingOperator class properly configured given the
|
||||
LightningModule.
|
||||
"""
|
||||
from ray.util.sgd.torch.ptl_operator import LightningOperator
|
||||
|
||||
class CustomLightningOperator(LightningOperator):
|
||||
_lightning_module_cls = lightning_module_cls
|
||||
_train_dataloader = train_dataloader
|
||||
_val_dataloader = val_dataloader
|
||||
|
||||
return CustomLightningOperator
|
||||
|
||||
@classmethod
|
||||
def from_creators(cls,
|
||||
model_creator,
|
||||
@@ -749,6 +851,11 @@ class TrainingOperator:
|
||||
"""int: The rank of the parent runner. Always 0 if not distributed."""
|
||||
return self._world_rank
|
||||
|
||||
@property
|
||||
def local_rank(self):
|
||||
"""int: Local rank of parent runner. Always 0 if not distributed."""
|
||||
return self._local_rank
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
"""Returns True if cuda is available and use_gpu is True."""
|
||||
@@ -849,29 +956,73 @@ class CreatorOperator(TrainingOperator):
|
||||
kwargs["criterion"] = criterion
|
||||
|
||||
state = self.register(**kwargs)
|
||||
self.models, self.optimizers = state[:2]
|
||||
if isinstance(self.models, tuple):
|
||||
self.model = self.models[0]
|
||||
self._registered_models, self._registered_optimizers = state[:2]
|
||||
if isinstance(self.models, (list, tuple)):
|
||||
logger.info("Multiple models have been registered. If custom "
|
||||
"training methods are not provided, only the first "
|
||||
"model will be used.")
|
||||
self._registered_model = self.models[0]
|
||||
else:
|
||||
self.model = self.models
|
||||
self._registered_model = self.models
|
||||
|
||||
if isinstance(self.optimizers, tuple):
|
||||
self.optimizer = self.optimizers[0]
|
||||
if isinstance(self.optimizers, (list, tuple)):
|
||||
logger.info("Multiple optimizers have been registered. If custom "
|
||||
"training methods are not provided, only the first "
|
||||
"optimizer will be used.")
|
||||
self._reigstered_optimizer = self.optimizers[0]
|
||||
else:
|
||||
self.optimizer = self.optimizers
|
||||
self._registered_optimizer = self.optimizers
|
||||
|
||||
if len(state) >= 3:
|
||||
self.criterion = state[2]
|
||||
self._registered_criterion = state[2]
|
||||
if len(state) == 4:
|
||||
self.schedulers = state[3]
|
||||
if isinstance(self.schedulers, tuple):
|
||||
self.scheduler = self.schedulers[0]
|
||||
self._registered_schedulers = state[3]
|
||||
if isinstance(self.schedulers, (list, tuple)):
|
||||
logger.info("Multiple schedulers have been registered. If "
|
||||
"custom training methods are not provided, "
|
||||
"only the first scheduler will be used.")
|
||||
self._registered_scheduler = self.schedulers[0]
|
||||
else:
|
||||
self.scheduler = self.schedulers
|
||||
self._registered_scheduler = self.schedulers
|
||||
|
||||
self.register_data(
|
||||
train_loader=train_loader, validation_loader=validation_loader)
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
"""First or only model created by the provided ``model_creator``."""
|
||||
return self._registered_model
|
||||
|
||||
@property
|
||||
def optimizer(self):
|
||||
"""First or only optimizer(s) created by the ``optimizer_creator``."""
|
||||
return self._registered_optimizer
|
||||
|
||||
@property
|
||||
def scheduler(self):
|
||||
"""First or only scheduler(s) created by the ``scheduler_creator``."""
|
||||
return self._registered_scheduler
|
||||
|
||||
@property
|
||||
def criterion(self):
|
||||
"""Criterion created by the provided ``loss_creator``."""
|
||||
return self._registered_criterion
|
||||
|
||||
@property
|
||||
def models(self):
|
||||
"""List of models created by the provided ``model_creator``."""
|
||||
return self._registered_models
|
||||
|
||||
@property
|
||||
def optimizers(self):
|
||||
"""List of optimizers created by the ``optimizer_creator``."""
|
||||
return self._registered_optimizers
|
||||
|
||||
@property
|
||||
def schedulers(self):
|
||||
"""List of schedulers created by the ``scheduler_creator``."""
|
||||
return self._registered_schedulers
|
||||
|
||||
|
||||
def get_test_operator(operator_cls):
|
||||
class _TestingOperator(operator_cls):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from datetime import timedelta
|
||||
|
||||
import ray
|
||||
@@ -191,6 +192,19 @@ class RemoteWorkerGroup(WorkerGroupInterface):
|
||||
]
|
||||
return remote_operator_setups
|
||||
|
||||
def _setup_local_rank(self, rank_counter_dict=None):
|
||||
"""Sets local rank for all workers."""
|
||||
if rank_counter_dict is None:
|
||||
rank_counter_dict = defaultdict(int)
|
||||
node_ips = ray.get(
|
||||
[w.get_node_ip.remote() for w in self.remote_workers])
|
||||
futures = []
|
||||
for ip, worker in zip(node_ips, self.remote_workers):
|
||||
rank = rank_counter_dict[ip]
|
||||
futures.append(worker.set_local_rank.remote(rank))
|
||||
rank_counter_dict[ip] += 1
|
||||
return futures
|
||||
|
||||
def start_workers(self, num_workers):
|
||||
logger.debug(f"start_workers: Setting %d workers." % num_workers)
|
||||
if num_workers == 1:
|
||||
@@ -212,6 +226,8 @@ class RemoteWorkerGroup(WorkerGroupInterface):
|
||||
self._setup_process_group(
|
||||
address=address, world_size=num_workers))
|
||||
|
||||
ray.get(self._setup_local_rank())
|
||||
|
||||
ray.get(self._setup_operator())
|
||||
|
||||
def _apply_all_operators(self, fn):
|
||||
@@ -454,6 +470,12 @@ class LocalWorkerGroup(WorkerGroupInterface):
|
||||
timeout=timedelta(self._timeout_s))
|
||||
ray.get(remote_pgs)
|
||||
|
||||
local_node_ip = ray.services.get_node_ip_address()
|
||||
rank_dict = defaultdict(int)
|
||||
self.local_worker.set_local_rank(local_rank=0)
|
||||
rank_dict[local_node_ip] += 1
|
||||
self.remote_worker_group._setup_local_rank(rank_dict)
|
||||
|
||||
remote_operators = self.remote_worker_group._setup_operator()
|
||||
self.local_worker.setup_operator()
|
||||
ray.get(remote_operators)
|
||||
|
||||
Reference in New Issue
Block a user