diff --git a/python/ray/util/sgd/BUILD b/python/ray/util/sgd/BUILD index 081bc97d6..f614c812e 100644 --- a/python/ray/util/sgd/BUILD +++ b/python/ray/util/sgd/BUILD @@ -2,6 +2,14 @@ # Tests from the python/ray/util/sgd/tests directory. # Please keep these sorted alphabetically. # -------------------------------------------------------------------- +py_test( + name = "test_ptl", + size = "small", + srcs = ["tests/test_ptl.py"], + tags = ["exclusive", "pytorch-lightning", "pytorch"], + deps = [":sgd_lib"], +) + py_test( name = "test_tensorflow", size = "small", @@ -214,6 +222,15 @@ py_test( args = ["--no-gpu", "--mock-data", "--smoke-test", "--ray-num-workers=2", "--model=mobilenetv3_small_075", "data"] ) +py_test( + name = "mnist-ptl", + size = "small", + srcs = ["torch/examples/pytorch-lightning/mnist-ptl.py"], + tags = ["exclusive", "pytorch", "pytorch-lightning"], + deps = [":sgd_lib"], + args = ["--smoke-test"] +) + # This is a dummy test dependency that causes the above tests to be # re-run if any of these files changes. py_library( diff --git a/python/ray/util/sgd/tests/test_ptl.py b/python/ray/util/sgd/tests/test_ptl.py new file mode 100644 index 000000000..e0e4dd3db --- /dev/null +++ b/python/ray/util/sgd/tests/test_ptl.py @@ -0,0 +1,211 @@ +import copy +import os + +import pytest +import ray +import torch +from ray.util.sgd.utils import BATCH_COUNT +import torch.distributed as dist +from pytorch_lightning import LightningModule +from ray.util.sgd import TorchTrainer +from ray.util.sgd.torch import TrainingOperator +from ray.util.sgd.torch.examples.train_example import \ + optimizer_creator, data_creator, scheduler_creator, model_creator +from torch import nn +import numpy as np + +torch.manual_seed(0) +np.random.seed(0) + + +@pytest.fixture +def ray_start_2_cpus(): + address_info = ray.init(num_cpus=2) + yield address_info + # The code after the yield will run as teardown code. + ray.shutdown() + # Ensure that tests don't ALL fail + if dist.is_initialized(): + dist.destroy_process_group() + + +class PTL_Module(LightningModule): + def __init__(self, config): + super().__init__() + + self.config = config + if "layer" in config: + self.layer = copy.deepcopy(config["layer"]) + else: + self.layer = model_creator(self.config) + + self.rand_int = np.random.randint(10) + + def forward(self, x): + return self.layer.forward(x) + + def configure_optimizers(self): + optimizer = optimizer_creator(self, self.config) + scheduler = scheduler_creator(optimizer, self.config) + return [optimizer], [scheduler] + + def training_step(self, batch, batch_idx): + x, y = batch + output = self(x) + loss = self.loss(output, y) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + output = self(x) + loss = self.loss(output, y) + _, predicted = torch.max(output.data, 1) + num_correct = (predicted == y).sum().item() + num_samples = y.size(0) + return {"val_loss": loss.item(), "val_acc": num_correct / num_samples} + + def setup(self, stage): + self.train_loader, self.val_loader = data_creator(self.config) + self.loss = nn.MSELoss() + + def train_dataloader(self): + return self.train_loader + + def val_dataloader(self): + return self.val_loader + + def on_save_checkpoint(self, checkpoint): + checkpoint["int"] = self.rand_int + + def on_load_checkpoint(self, checkpoint): + self.rand_int = checkpoint["int"] + + +Operator = TrainingOperator.from_ptl(PTL_Module) + + +@pytest.mark.parametrize("use_local", [True, False]) +def test_single_step(ray_start_2_cpus, use_local): # noqa: F811 + trainer = TorchTrainer( + training_operator_cls=Operator, + num_workers=1, + use_local=use_local, + use_gpu=False) + metrics = trainer.train(num_steps=1) + assert metrics[BATCH_COUNT] == 1 + + val_metrics = trainer.validate(num_steps=1) + assert val_metrics[BATCH_COUNT] == 1 + trainer.shutdown() + + +@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) +@pytest.mark.parametrize("use_local", [True, False]) +def test_train(ray_start_2_cpus, num_workers, use_local): # noqa: F811 + trainer = TorchTrainer( + training_operator_cls=Operator, + num_workers=num_workers, + use_local=use_local, + use_gpu=False) + for i in range(3): + train_loss1 = trainer.train()["train_loss"] + validation_loss1 = trainer.validate()["val_loss"] + + for i in range(3): + train_loss2 = trainer.train()["train_loss"] + validation_loss2 = trainer.validate()["val_loss"] + + assert train_loss2 <= train_loss1, (train_loss2, train_loss1) + assert validation_loss2 <= validation_loss1, (validation_loss2, + validation_loss1) + trainer.shutdown() + + +@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) +@pytest.mark.parametrize("use_local", [True, False]) +def test_save_and_restore(ray_start_2_cpus, num_workers, use_local, + tmp_path): # noqa: F811 + trainer1 = TorchTrainer( + training_operator_cls=Operator, + num_workers=num_workers, + use_local=use_local) + trainer1.train() + checkpoint_path = os.path.join(tmp_path, "checkpoint") + trainer1.save(checkpoint_path) + + model1 = trainer1.get_model() + ints1 = trainer1.apply_all_operators(lambda op: op.get_model().rand_int)[0] + + trainer1.shutdown() + + trainer2 = TorchTrainer( + training_operator_cls=Operator, + num_workers=num_workers, + use_local=use_local) + trainer2.load(checkpoint_path) + + model2 = trainer2.get_model() + ints2 = trainer2.apply_all_operators(lambda op: op.get_model().rand_int) + + model1_state_dict = model1.state_dict() + model2_state_dict = model2.state_dict() + + assert set(model1_state_dict.keys()) == set(model2_state_dict.keys()) + + for k in model1_state_dict: + assert torch.equal(model1_state_dict[k], model2_state_dict[k]) + for i in ints2: + assert i == ints1 + trainer2.shutdown() + + +class CorrectnessOperator(TrainingOperator): + def setup(self, config): + model = PTL_Module(config) + opt = optimizer_creator(model, config) + scheduler = scheduler_creator(opt, config) + self.model, self.optimizer, self.criterion, self.scheduler = \ + self.register( + models=model, optimizers=opt, criterion=nn.MSELoss(), + schedulers=scheduler) + + train_loader, val_loader = data_creator(config) + self.register_data( + train_loader=train_loader, validation_loader=val_loader) + + +@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1]) +@pytest.mark.parametrize("use_local", [True, False]) +def test_correctness(ray_start_2_cpus, num_workers, use_local): + layer = nn.Linear(1, 1) + ptl_op = TrainingOperator.from_ptl(PTL_Module) + trainer1 = TorchTrainer( + training_operator_cls=ptl_op, + config={ + "layer": layer, + "data_size": 3, + "batch_size": 1 + }, + num_workers=num_workers, + use_local=use_local) + train1_stats = trainer1.train() + val1_stats = trainer1.validate() + trainer1.shutdown() + + trainer2 = TorchTrainer( + training_operator_cls=CorrectnessOperator, + scheduler_step_freq="manual", + config={ + "layer": layer, + "data_size": 3, + "batch_size": 1 + }, + num_workers=num_workers, + use_local=use_local) + train2_stats = trainer2.train() + val2_stats = trainer2.validate() + trainer2.shutdown() + + assert train1_stats["train_loss"] == train2_stats["train_loss"] + assert val1_stats["val_loss"] == val2_stats["val_loss"] + assert val1_stats["val_acc"] == val2_stats["val_accuracy"] diff --git a/python/ray/util/sgd/torch/distributed_torch_runner.py b/python/ray/util/sgd/torch/distributed_torch_runner.py index c21e61a8e..fc9015b68 100644 --- a/python/ray/util/sgd/torch/distributed_torch_runner.py +++ b/python/ray/util/sgd/torch/distributed_torch_runner.py @@ -3,8 +3,6 @@ import os import torch import torch.distributed as dist -from torch.utils.data import DataLoader, IterableDataset -from torch.utils.data.distributed import DistributedSampler from ray.util.sgd.torch.utils import setup_process_group import ray @@ -42,6 +40,7 @@ class DistributedTorchRunner(TorchRunner): self.wrap_ddp = wrap_ddp self.add_dist_sampler = add_dist_sampler self.world_rank = None + self.local_rank = None def setup_address(self): return setup_address() @@ -61,6 +60,14 @@ class DistributedTorchRunner(TorchRunner): setup_process_group( url, world_rank, world_size, timeout, backend=self.backend) + def set_local_rank(self, local_rank): + """Sets the local rank of this runner. + + Args: + local_rank (int): the index of the runner on its node. + """ + self.local_rank = local_rank + def setup_operator(self): """Runs distributed coordination components. @@ -74,13 +81,14 @@ class DistributedTorchRunner(TorchRunner): self.training_operator = self.training_operator_cls( self.config, world_rank=self.world_rank, + local_rank=self.local_rank, + is_distributed=True, device_ids=device_ids, use_gpu=self.use_gpu, use_fp16=self.use_fp16, use_tqdm=self.use_tqdm, apex_args=self.apex_args, wrap_ddp=self.wrap_ddp, - wrap_distributed_sampler=True, add_dist_sampler=self.add_dist_sampler, scheduler_step_freq=self.scheduler_step_freq) @@ -88,45 +96,21 @@ class DistributedTorchRunner(TorchRunner): """Needed for SyncBatchNorm, which needs 1 GPU per process.""" return [0] - def _wrap_dataloaders(self): - def with_sampler(loader): - # Automatically set the DistributedSampler - data_loader_args = { - "dataset": loader.dataset, - "batch_size": loader.batch_size, - "shuffle": False, - "num_workers": loader.num_workers, - "collate_fn": loader.collate_fn, - "pin_memory": loader.pin_memory, - "drop_last": loader.drop_last, - "timeout": loader.timeout, - "worker_init_fn": loader.worker_init_fn, - "sampler": DistributedSampler(loader.dataset) - } - return DataLoader(**data_loader_args) - - def should_wrap_dataloader(loader): - return (isinstance(loader, DataLoader) - and not isinstance(loader.dataset, IterableDataset)) - - if should_wrap_dataloader(self.train_loader): - if self.add_dist_sampler: - self.train_loader = with_sampler(self.train_loader) - - if self.validation_loader is not None and should_wrap_dataloader( - self.validation_loader): - if self.add_dist_sampler: - self.validation_loader = with_sampler(self.validation_loader) - - def train_epoch(self, **kwargs): + def train_epoch(self, + num_steps=None, + profile=False, + info=None, + iterator=None): """Runs a training epoch and updates the model parameters. Automatically sets epoch of sampler if possible. """ - if hasattr(self.train_loader, "sampler") and hasattr( + if iterator is None and hasattr(self.train_loader, "sampler") and \ + hasattr( self.train_loader.sampler, "set_epoch"): self.train_loader.sampler.set_epoch(self.epochs) - return super(DistributedTorchRunner, self).train_epoch(**kwargs) + return super(DistributedTorchRunner, self).train_epoch( + num_steps=num_steps, profile=profile, info=info, iterator=iterator) def shutdown(self): """Attempts to shut down the worker.""" diff --git a/python/ray/util/sgd/torch/examples/pytorch-lightning/mnist-ptl.py b/python/ray/util/sgd/torch/examples/pytorch-lightning/mnist-ptl.py new file mode 100644 index 000000000..da7731403 --- /dev/null +++ b/python/ray/util/sgd/torch/examples/pytorch-lightning/mnist-ptl.py @@ -0,0 +1,145 @@ +import argparse + +import torch +from ray.util.sgd import TorchTrainer +from ray.util.sgd.torch import TrainingOperator +from torch.nn import functional as F +from pytorch_lightning.core.lightning import LightningModule +from torch.optim import Adam +from torch.utils.data import DataLoader, random_split +from torchvision.datasets import MNIST +import os +from torchvision import transforms + + +class LitMNIST(LightningModule): + def __init__(self, config): + super().__init__() + + # mnist images are (1, 28, 28) (channels, width, height) + self.layer_1 = torch.nn.Linear(28 * 28, 128) + self.layer_2 = torch.nn.Linear(128, 256) + self.layer_3 = torch.nn.Linear(256, 10) + + self.config = config + + def forward(self, x): + batch_size, channels, width, height = x.size() + + # (b, 1, 28, 28) -> (b, 1*28*28) + x = x.view(batch_size, -1) + x = self.layer_1(x) + x = torch.relu(x) + x = self.layer_2(x) + x = torch.relu(x) + x = self.layer_3(x) + + x = torch.log_softmax(x, dim=1) + return x + + def configure_optimizers(self): + return Adam(self.parameters(), lr=self.config["lr"]) + + def setup(self, stage): + # transforms for images + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307, ), (0.3081, )) + ]) + + # prepare transforms standard to MNIST + mnist_train = MNIST( + os.getcwd(), train=True, download=True, transform=transform) + + self.mnist_train, self.mnist_val = random_split( + mnist_train, [55000, 5000]) + + def train_dataloader(self): + return DataLoader( + self.mnist_train, batch_size=self.config["batch_size"]) + + def val_dataloader(self): + return DataLoader(self.mnist_val, batch_size=self.config["batch_size"]) + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.nll_loss(logits, y) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.nll_loss(logits, y) + _, predicted = torch.max(logits.data, 1) + num_correct = (predicted == y).sum().item() + num_samples = y.size(0) + return {"val_loss": loss.item(), "val_acc": num_correct / num_samples} + + +def train_mnist(num_workers=1, use_gpu=False, num_epochs=5): + Operator = TrainingOperator.from_ptl(LitMNIST) + trainer = TorchTrainer( + training_operator_cls=Operator, + num_workers=num_workers, + config={ + "lr": 1e-3, + "batch_size": 64 + }, + use_gpu=use_gpu, + use_tqdm=True, + ) + for i in range(num_epochs): + stats = trainer.train() + print(stats) + + print(trainer.validate()) + print("Saving model checkpoint to ./model.pt") + trainer.save("./model.pt") + print("Model Checkpointed!") + trainer.shutdown() + print("success!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--address", + required=False, + type=str, + help="the address to use for Ray") + parser.add_argument( + "--num-workers", + "-n", + type=int, + default=1, + help="Sets number of workers for training.") + parser.add_argument( + "--use-gpu", + action="store_true", + default=False, + help="Enables GPU training") + parser.add_argument( + "--num-epochs", + type=int, + default=5, + help="How many epochs to train " + "for.") + parser.add_argument( + "--smoke-test", + action="store_true", + default=False, + help="Finish quickly for testing.") + + args, _ = parser.parse_known_args() + + import ray + if args.smoke_test: + ray.init(num_cpus=2) + args.num_epochs = 1 + else: + ray.init(address=args.address) + train_mnist( + num_workers=args.num_workers, + use_gpu=args.use_gpu, + num_epochs=args.num_epochs) diff --git a/python/ray/util/sgd/torch/ptl_operator.py b/python/ray/util/sgd/torch/ptl_operator.py new file mode 100644 index 000000000..8ce3829ff --- /dev/null +++ b/python/ray/util/sgd/torch/ptl_operator.py @@ -0,0 +1,502 @@ +import inspect +import logging + +import torch +from pytorch_lightning.core.step_result import Result +from pytorch_lightning.overrides.data_parallel import \ + LightningDistributedDataParallel +from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin +from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin +import pytorch_lightning as ptl +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.memory import recursive_detach +from ray.util.sgd.torch import TrainingOperator +from ray.util.sgd.torch.constants import NUM_STEPS, SCHEDULER_STEP_BATCH, \ + SCHEDULER_STEP_EPOCH +from ray.util.sgd.utils import AverageMeterCollection, NUM_SAMPLES + +tqdm = None +try: + from tqdm import tqdm +except ImportError: + pass + +logger = logging.getLogger(__name__) + + +class LightningOperator(TrainingOperator, TrainerModelHooksMixin, + TrainerOptimizersMixin): + def _configure_amp(self, amp, models, optimizers): + assert len(models) == 1 + model = models[0] + assert isinstance(model, ptl.LightningModule) + amp_level = self._apex_args.get("opt_level", "O2") + model, optimizers = model.configure_apex( + amp, model, optimizers, amp_level=amp_level) + return [model], optimizers + + def _configure_ddp(self, models, device_ids): + assert len(models) == 1 + model = models[0] + assert isinstance(model, ptl.LightningModule) + # This will default to LightningDistributedDataParallel. + model = model.configure_ddp(model=model, device_ids=device_ids) + return [model] + + @property + def model(self): + """The LightningModule to use for training. + + The returned model is wrapped in DDP if using distributed training. + """ + return self._model + + @property + def scheduler_dicts(self): + """Returns list of scheduler dictionaries. + + List is empty if no schedulers are returned in the + configure_optimizers method of your LightningModule. Default + configuration is used if configure_optimizers returns scheduler + objects instead of scheduler dicts. See + https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html#configure-optimizers + """ + return self._scheduler_dicts + + @property + def optimizers(self): + """Returns list of optimizers as returned by configure_optimizers.""" + return self._optimizers + + @property + def schedulers(self): + """Returns list of schedulers as returned by configure_optimizers. + + List is empty if no schedulers are returned in configure_optimizers. + """ + return self._schedulers + + def get_model(self): + """Returns original LightningModule, not wrapped in DDP.""" + if isinstance(self.model, LightningDistributedDataParallel): + return self.model.module + else: + return self.model + + def setup(self, config): + # Pass in config if ptl_module accepts it. + ptl_class = self.__class__._lightning_module_cls + if not issubclass(ptl_class, ptl.LightningModule): + raise TypeError("Argument must be subclass of " + "pytorch_lightning.LightningModule. Got class {} " + "instead.".format(ptl_class)) + if "config" in inspect.signature(ptl_class.__init__).parameters: + ptl_module = ptl_class(config=config) + else: + ptl_module = ptl_class() + + # This is needed for LightningDistributedDataParallel. + ptl_module.testing = False + + # Call on_fit_start on instantiation. + if self.is_function_implemented("on_fit_start", ptl_module): + ptl_module.on_fit_start() + + # Only run data preparation once per node. + if self.local_rank == 0 and self.is_function_implemented( + "prepare_data", ptl_module): + ptl_module.prepare_data() + + # Call model.setup. + ptl_module.setup("fit") + + if not self.is_overridden("configure_optimizers", ptl_module): + raise MisconfigurationException( + "No `configure_optimizers()` method defined.") + + optimizers, self._scheduler_dicts, optimizer_frequencies = \ + self.init_optimizers(model=ptl_module) + + if len(optimizer_frequencies) > 0: + logger.warning("Optimizer frequencies will be ignored. When " + "passing in multiple optimizers, you should " + "implement your own custom training loop.") + + lr_schedulers = [] + for scheduler in self.scheduler_dicts: + if isinstance(scheduler, dict): + # A scheduler dictionary is passed in. + if "reduce_on_plateau" in scheduler and "monitor" in \ + scheduler and scheduler["reduce_on_plateau"] is True: + logger.info( + "reduce_on_plateau and monitor will be " + "ignored " + "from the scheduler dict {}. To update a " + "ReduceLROnPlateau scheduler, you should use " + "TorchTrainer.update_schedulers.".format(scheduler)) + if "frequency" in scheduler and scheduler["frequency"] > 1: + logger.info("frequency will be ignored from the " + "scheduler dict {}.".format(scheduler)) + lr_schedulers.append(scheduler["scheduler"]) + else: + lr_schedulers.append(scheduler) + + # Set this so register doesn't complain. + self._scheduler_step_freq = "ptl" + ddp_model, self._optimizers, self._schedulers = self.register( + models=[ptl_module], + optimizers=optimizers, + schedulers=lr_schedulers) + + assert len(ddp_model) == 1 + self._model = ddp_model[0] + + model = self.get_model() + if self.is_function_implemented("on_pretrain_routine_start", model): + model.on_pretrain_routine_start() + + train_data_loader = None + if self.__class__._train_dataloader: + train_data_loader = self.__class__._train_dataloader + elif self.is_function_implemented("train_dataloader", model): + train_data_loader = model.train_dataloader() + + val_data_loader = None + if self.__class__._val_dataloader: + val_data_loader = self.__class__._val_dataloader + elif self.is_function_implemented("val_dataloader", model): + val_data_loader = model.val_dataloader() + + self.register_data( + train_loader=train_data_loader, validation_loader=val_data_loader) + + def train_epoch(self, iterator, info): + model = self.get_model() + + # Enable train mode. + self.model.train() + + # Enable gradients. + torch.set_grad_enabled(True) + + if self.is_function_implemented("on_train_epoch_start", model): + model.on_train_epoch_start() + + if self.use_tqdm and self.world_rank == 0: + desc = "" + if info is not None and "epoch_idx" in info: + if "num_epochs" in info: + desc = f"{info['epoch_idx'] + 1}/{info['num_epochs']}e" + else: + desc = f"{info['epoch_idx'] + 1}e" + + # TODO: Implement len for Dataset? + total = info[NUM_STEPS] + if total is None: + if hasattr(iterator, "__len__"): + total = len(iterator) + + _progress_bar = tqdm( + total=total, desc=desc, unit="batch", leave=False) + + # Output for each batch. + epoch_outputs = [] + + for batch_idx, batch in enumerate(iterator): + batch_info = { + "batch_idx": batch_idx, + "global_step": self.global_step + } + batch_info.update(info) + batch_output = self.train_batch(batch, batch_info=batch_info) + # batch output for each optimizer. + epoch_outputs.append(batch_output) + + should_stop = batch_output["signal"] == -1 + + if self.use_tqdm and self.world_rank == 0: + _progress_bar.n = batch_idx + 1 + postfix = {} + if "training_loss" in batch_output: + postfix.update(loss=batch_output["training_loss"]) + _progress_bar.set_postfix(postfix) + + for s_dict, scheduler in zip(self.scheduler_dicts, + self.schedulers): + if s_dict["interval"] == SCHEDULER_STEP_BATCH: + scheduler.step() + + self.global_step += 1 + + if should_stop: + break + + processed_outputs = None + if self.is_overridden("training_epoch_end", model): + raw_outputs = [eo["raw_output"] for eo in epoch_outputs] + processed_outputs = model.training_epoch_end(raw_outputs) + + if processed_outputs is not None: + if isinstance(processed_outputs, torch.Tensor): + return_output = {"train_loss": processed_outputs} + elif isinstance(processed_outputs, Result): + raise ValueError("Result objects are not supported. Please " + "return a dictionary instead.") + elif isinstance(processed_outputs, dict): + return_output = processed_outputs + else: + raise TypeError("training_epoch_end returned an invalid " + "type. It must return a Tensor, Result, " + "or dict.") + else: + # User did not override training_epoch_end + assert isinstance(epoch_outputs, list) + # Use AverageMeterCollection util to reduce results. + meter_collection = AverageMeterCollection() + for o in epoch_outputs: + num_samples = o.pop(NUM_SAMPLES, 1) + raw_output = o["raw_output"] + if isinstance(raw_output, dict): + meter_collection.update(raw_output, num_samples) + elif isinstance(raw_output, torch.Tensor): + meter_collection.update({ + "train_loss": o["training_loss"] + }, num_samples) + return_output = meter_collection.summary() + + if self.is_function_implemented("on_train_epoch_end", model): + model.on_train_epoch_end() + + for s_dict, scheduler in zip(self.scheduler_dicts, self.schedulers): + if s_dict["interval"] == SCHEDULER_STEP_EPOCH: + scheduler.step() + + return return_output + + def train_batch(self, batch, batch_info): + # Get the original PTL module. + model = self.get_model() + optimizer = self.optimizers[0] + batch_idx = batch_info["batch_idx"] + epoch_idx = batch_info["epoch_idx"] + + if self.is_function_implemented("on_train_batch_start", model): + response = model.on_train_batch_start( + batch=batch, batch_idx=batch_idx, dataloader_idx=0) + # Skip remainder of epoch if response is -1. + if response == -1: + return {"signal": -1} + + args = [batch, batch_idx] + if len(self.optimizers) > 1: + if self.has_arg("training_step", "optimizer_idx"): + args.append(0) + + with self.timers.record("fwd"): + if self._is_distributed: + # Use the DDP wrapped model (self.model). + output = self.model(*args) + elif self.use_gpu: + # Using single GPU. + # Don't copy the batch since there is a single gpu that + # the batch could be referenced from and if there are + # multiple optimizers the batch will wind up copying it to + # the same device repeatedly. + device = self.device + batch = model.transfer_batch_to_device(batch, device=device) + args[0] = batch + output = model.training_step(*args) + else: + # Using CPU. + output = model.training_step(*args) + + if isinstance(output, Result): + raise ValueError("TrainResult objects are not supported. Please " + "return a dictionary instead.") + + # allow any mode to define training_step_end + # do something will all the dp outputs (like softmax) + if self.is_overridden("training_step_end", model): + output = model.training_step_end(output) + + # Extract loss from output if dictionary. + try: + loss = output["loss"] + except Exception: + if isinstance(output, torch.Tensor): + loss = output + else: + raise RuntimeError( + "No `loss` value in the dictionary returned from " + "`model.training_step()`.") + + # If output contains tensors, detach them all. + if isinstance(output, torch.Tensor): + output = output.detach() + elif isinstance(output, dict): + output = recursive_detach(output) + else: + raise TypeError("training_step returned invalid type. It must " + "return either a Tensor, Result, or dict.") + + untouched_loss = loss.detach().clone() + + with self.timers.record("grad"): + if self.use_fp16: + with self._amp.scale_loss(loss, optimizer) as scaled_loss: + model.backward( + self, scaled_loss, optimizer, optimizer_idx=0) + else: + model.backward(self, loss, optimizer, optimizer_idx=0) + + if self.is_function_implemented("on_after_backward", model): + model.on_after_backward() + + with self.timers.record("apply"): + model.optimizer_step( + epoch=epoch_idx, + batch_idx=batch_idx, + optimizer=optimizer, + optimizer_idx=0) + + model.on_before_zero_grad(optimizer) + + model.optimizer_zero_grad( + epoch=epoch_idx, + batch_idx=batch_idx, + optimizer=optimizer, + optimizer_idx=0) + + if self.is_function_implemented("on_train_batch_end", model): + model.on_train_batch_end( + batch=batch, batch_idx=batch_idx, dataloader_idx=0) + + return { + "signal": 0, + "training_loss": untouched_loss.item(), + "raw_output": output, + # NUM_SAMPLES: len(batch) + } + + def validate(self, val_iterator, info): + self.model.zero_grad() + self.model.eval() + + torch.set_grad_enabled(False) + + model = self.get_model() + if self.is_function_implemented("on_validation_epoch_start", model): + model.on_validation_epoch_start() + + val_outputs = [] + for batch_idx, batch in enumerate(val_iterator): + batch_info = {"batch_idx": batch_idx} + batch_info.update(info) + batch_output = self.validate_batch(batch, batch_info) + if batch_output is not None: + val_outputs.append(batch_output) + + processed_outputs = None + if self.is_overridden("validation_epoch_end", model): + raw_outputs = [vo["raw_output"] for vo in val_outputs] + processed_outputs = model.training_epoch_end(raw_outputs) + + if processed_outputs is not None: + if isinstance(processed_outputs, torch.Tensor): + return_output = {"val_loss": processed_outputs} + elif isinstance(processed_outputs, Result): + raise ValueError("Result objects are not supported. Please " + "return a dictionary instead.") + elif isinstance(processed_outputs, dict): + return_output = processed_outputs + else: + raise TypeError("validation_epoch_end returned an invalid " + "type. It must return a Tensor, Result, " + "or dict.") + else: + # User did not override training_epoch_end + assert isinstance(val_outputs, list) + # Use AverageMeterCollection util to reduce results. + meter_collection = AverageMeterCollection() + for v in val_outputs: + num_samples = v.pop(NUM_SAMPLES, 1) + raw_output = v["raw_output"] + if isinstance(raw_output, dict): + meter_collection.update(raw_output, num_samples) + elif isinstance(raw_output, torch.Tensor): + meter_collection.update({ + "val_loss": raw_output.item() + }, num_samples) + return_output = meter_collection.summary() + + if self.is_function_implemented("on_validation_epoch_end", model): + model.on_validation_epoch_end() + + # Set back to True so training will work. + torch.set_grad_enabled(True) + + return return_output + + def validate_batch(self, batch, batch_info): + model = self.get_model() + batch_idx = batch_info["batch_idx"] + if self.is_overridden("on_validation_batch_start", model): + model.on_validation_batch_start( + batch=batch, batch_idx=batch_idx, dataloader_idx=0) + args = [batch, batch_idx] + with self.timers.record("eval_fwd"): + if self._is_distributed: + # Use the DDP wrapped model (self.model). + output = self.model(*args) + elif self.use_gpu: + # Using single GPU. + device = self.device + batch = model.transfer_batch_to_device(batch, device=device) + args[0] = batch + output = model.validation_step(*args) + else: + # Using CPU. + output = model.validation_step(*args) + + if isinstance(output, Result): + raise ValueError("EvalResult objects are not supported. Please " + "return a dictionary instead.") + + if self.is_overridden("on_validation_step_end", model): + output = model.validation_step_end(output) + + if self.is_function_implemented("on_validation_batch_end", model): + model.on_validation_batch_end( + batch=batch, batch_idx=batch_idx, dataloader_idx=0) + return { + "raw_output": output, + # NUM_SAMPLES: len(batch) + } + + def state_dict(self): + state_dict = {} + self.get_model().on_save_checkpoint(checkpoint=state_dict) + return state_dict + + def load_state_dict(self, state_dict): + self.get_model().on_load_checkpoint(checkpoint=state_dict) + + def _get_train_loader(self): + if not hasattr(self, "_train_loader") or \ + self._train_loader is None: + raise RuntimeError("Training Operator does not have any " + "registered training loader. Make sure " + "to pass in a training loader to " + "TrainingOperator.from_ptl or implement " + "train_dataloader in your LightningModule.") + return self._train_loader + + def _get_validation_loader(self): + if not hasattr(self, "_validation_loader") or \ + self._validation_loader is None: + raise RuntimeError("Training Operator does not have any " + "registered validation loader. Make sure " + "to pass in a validation loader to " + "TrainingOperator.from_ptl or implement " + "val_dataloader in your LightningModule.") + return self._validation_loader diff --git a/python/ray/util/sgd/torch/torch_runner.py b/python/ray/util/sgd/torch/torch_runner.py index 3486458cb..34d737b1d 100644 --- a/python/ray/util/sgd/torch/torch_runner.py +++ b/python/ray/util/sgd/torch/torch_runner.py @@ -1,6 +1,8 @@ import logging import io import itertools + +import ray import torch from ray.util.sgd.torch.constants import USE_FP16, NUM_STEPS @@ -50,6 +52,8 @@ class TorchRunner: self.training_operator = self.training_operator_cls( self.config, world_rank=0, + local_rank=0, + is_distributed=False, use_gpu=self.use_gpu, use_fp16=self.use_fp16, use_tqdm=self.use_tqdm, @@ -69,6 +73,7 @@ class TorchRunner: info.update({ NUM_STEPS: num_steps, USE_FP16: self.use_fp16, + "epoch_idx": self.epochs, }) with self.timers.record("train_epoch"): if iterator is None: @@ -94,10 +99,6 @@ class TorchRunner: def validate(self, num_steps=None, profile=False, info=None): """Evaluates the model on the validation data set.""" - if self.validation_loader is None: - raise ValueError("No validation dataloader provided. Make sure" - "you pass in a validation_loader to " - "TrainingOperator.register_data.") info = info or {} self._toggle_profiling(profile=profile) validation_loader = self.validation_loader @@ -204,62 +205,31 @@ class TorchRunner: """Getter method. Needed for remote actor calls.""" return self.models + def get_node_ip(self): + return ray.services.get_node_ip_address() + @property def models(self): - if not hasattr(self.training_operator, "_original_models"): - raise RuntimeError("Training Operator does not have any " - "registered models. Are you calling " - "self.register(...) inside the setup method " - "of your Training Operator?") - return self.training_operator._original_models + return self.training_operator._get_original_models() @property def optimizers(self): - if not hasattr(self.training_operator, "_optimizers"): - raise RuntimeError("Training Operator does not have any " - "registered optimizers. Are you calling " - "self.register(...) inside the setup method " - "of your Training Operator?") - return self.training_operator._optimizers + return self.training_operator._get_optimizers() @property def schedulers(self): - if not hasattr(self.training_operator, "_schedulers"): - raise RuntimeError("Training Operator does not have any " - "registered schedulers. Are you calling " - "self.register(...) inside the setup method " - "of your Training Operator?") - return self.training_operator._schedulers + return self.training_operator._get_schedulers() @property def train_loader(self): - if not hasattr(self.training_operator, "_train_loader"): - logger.warning("Training Operator does not have any " - "registered train loader. If this is " - "unexepected, make sure to call " - "self.register_data(...) inside the setup method " - "of your Training Operator.") - return None - return self.training_operator._train_loader + return self.training_operator._get_train_loader() @property def validation_loader(self): - if not hasattr(self.training_operator, "_validation_loader"): - logger.warning("Training Operator does not have any " - "registered validation loader. If this is " - "unexepected, make sure to call " - "self.register_data(...) inside the setup method " - "of your Training Operator.") - return None - return self.training_operator._validation_loader + return self.training_operator._get_validation_loader() @property def criterion(self): - if not hasattr(self.training_operator, "_criterion"): - raise RuntimeError("Training Operator does not have any " - "registered criterion. Are you calling " - "self.register(...) inside the setup method " - "of your Training Operator?") return self.training_operator._criterion @property diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index 9293eb548..6efa476a5 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -439,7 +439,7 @@ class TorchTrainer: } for stat_key in worker_stats[0]: - if isinstance(worker_stats[0], numbers.Number): + if isinstance(worker_stats[0][stat_key], numbers.Number): stats[stat_key] = np.nanmean( [s.get(stat_key, np.nan) for s in worker_stats]) else: diff --git a/python/ray/util/sgd/torch/training_operator.py b/python/ray/util/sgd/torch/training_operator.py index aad6db064..4b8b00cc2 100644 --- a/python/ray/util/sgd/torch/training_operator.py +++ b/python/ray/util/sgd/torch/training_operator.py @@ -14,6 +14,7 @@ from ray.util.sgd.torch.constants import ( NUM_STEPS, SCHEDULER_STEP_BATCH, ) + from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DistributedSampler, DataLoader, IterableDataset @@ -78,7 +79,6 @@ class TrainingOperator: train_loader=train_loader, validation_loader=val_loader) - trainer = TorchTrainer( training_operator_cls=MyTrainingOperator, config={"batch_size": 32}, @@ -119,18 +119,21 @@ class TrainingOperator: def __init__(self, config, world_rank, + local_rank, + is_distributed=False, device_ids=None, use_gpu=False, use_fp16=False, use_tqdm=False, apex_args=None, wrap_ddp=False, - wrap_distributed_sampler=False, add_dist_sampler=False, scheduler_step_freq=None): # You are not expected to override this method. self._world_rank = world_rank + self._local_rank = local_rank self._config = config + self._is_distributed = is_distributed self._use_fp16 = use_fp16 self._device_ids = device_ids self._use_gpu = use_gpu and torch.cuda.is_available() @@ -141,7 +144,6 @@ class TrainingOperator: self.global_step = 0 self._apex_args = apex_args if apex_args else {} self._wrap_ddp = wrap_ddp - self._wrap_distributed_sampler = wrap_distributed_sampler self._add_dist_sampler = add_dist_sampler self._scheduler_step_freq = scheduler_step_freq @@ -152,6 +154,28 @@ class TrainingOperator: """Passes in the timers from the Runner.""" self.timers = timers + def _configure_amp(self, amp, models, optimizers): + models, optimizers = amp.initialize(models, optimizers, + **self._apex_args) + return models, optimizers + + def _configure_ddp(self, models, device_ids): + return [ + DistributedDataParallel(model, device_ids=device_ids) + for model in models + ] + + def _return_items(self, items, original_items): + """Helper method to return items in same format as original_items.""" + if isinstance(original_items, tuple): + return tuple(items) + elif isinstance(original_items, Iterable): + # Items is already a list. + return items + else: + assert len(items) == 1 + return items[0] + def setup(self, config): """Override this method to implement operator setup. @@ -218,7 +242,6 @@ class TrainingOperator: Returns: Tuple of model, optimizer, criterion if not None, and scheduler if not None. - """ return_vals = [] logger.debug("Registering models.") @@ -244,7 +267,10 @@ class TrainingOperator: if not isinstance(self._schedulers, Iterable): self._schedulers = [self._schedulers] else: - self._schedulers = None + if isinstance(schedulers, Iterable): + self._schedulers = [] + else: + self._schedulers = None if criterion: logger.debug("Registering loss.") @@ -257,28 +283,19 @@ class TrainingOperator: if self.use_fp16 and amp: logger.debug("Setting up Apex.") - self._original_models, self._optimizers = amp.initialize( - self._original_models, self._optimizers, **self._apex_args) self._amp = amp + self._original_models, self._optimizers = self._configure_amp( + self._amp, self._original_models, self._optimizers) if self._wrap_ddp: logging.debug("Setting up DDP for models.") - self._models = [ - DistributedDataParallel(model, device_ids=self.device_ids) - for model in self._original_models - ] + self._models = self._configure_ddp( + models=self._original_models, device_ids=self.device_ids) else: self._models = self._original_models - if len(self._models) == 1: - return_vals.append(self._models[0]) - else: - return_vals.append(self._models) - - if len(self._optimizers) == 1: - return_vals.append(self._optimizers[0]) - else: - return_vals.append(self._optimizers) + return_vals.append(self._return_items(self._models, models)) + return_vals.append(self._return_items(self._optimizers, optimizers)) if self._criterion is not None: return_vals.append(self._criterion) @@ -290,10 +307,8 @@ class TrainingOperator: "are registering schedulers. Set this to " "'manual' if you will be manually stepping " "the schedulers.") - if len(self._schedulers) == 1: - return_vals.append(self._schedulers[0]) - else: - return_vals.append(self._schedulers) + return_vals.append( + self._return_items(self._schedulers, schedulers)) return tuple(return_vals) @@ -348,8 +363,7 @@ class TrainingOperator: self._train_loader = train_loader self._validation_loader = validation_loader - if self._wrap_distributed_sampler: - logging.debug("Wrapping data loaders with DistributedSampler.") + if self._is_distributed: def with_sampler(loader): # Automatically set the DistributedSampler @@ -373,11 +387,15 @@ class TrainingOperator: if should_wrap_dataloader(self._train_loader): if self._add_dist_sampler: + logging.debug("Wrapping train data loader with " + "DistributedSampler.") self._train_loader = with_sampler(self._train_loader) if self._validation_loader is not None and should_wrap_dataloader( self._validation_loader): if self._add_dist_sampler: + logging.debug("Wrapping validation data loader with " + "DistributedSampler.") self._validation_loader = with_sampler( self._validation_loader) @@ -665,6 +683,90 @@ class TrainingOperator: state_dict (dict): State dict as returned by the operator. """ pass + def _get_original_models(self): + if not hasattr(self, "_original_models"): + raise RuntimeError("Training Operator does not have any " + "registered models. Are you calling " + "self.register(...) inside the setup method " + "of your Training Operator?") + return self._original_models + + def _get_optimizers(self): + if not hasattr(self, "_optimizers"): + raise RuntimeError("Training Operator does not have any " + "registered optimizers. Are you calling " + "self.register(...) inside the setup method " + "of your Training Operator?") + return self._optimizers + + def _get_schedulers(self): + if not hasattr(self, "_schedulers"): + raise RuntimeError("Training Operator does not have any " + "registered schedulers. Are you calling " + "self.register(...) inside the setup method " + "of your Training Operator?") + return self._schedulers + + def _get_train_loader(self): + if not hasattr(self, "_train_loader") or \ + self._train_loader is None: + raise RuntimeError( + "Training Operator does not have any " + "registered train loader. If this is " + "unexepected, make sure to call " + "self.register_data(...) inside the setup method " + "of your Training Operator.") + return self._train_loader + + def _get_validation_loader(self): + if not hasattr(self, "_validation_loader") or \ + self._validation_loader is None: + raise RuntimeError( + "Training Operator does not have any " + "registered validation loader. If this is " + "unexepected, make sure to call " + "self.register_data(...) inside the setup method " + "of your Training Operator.") + return self._validation_loader + + def _get_criterion(self): + if not hasattr(self, "_criterion"): + raise RuntimeError("Training Operator does not have any " + "registered criterion. Are you calling " + "self.register(...) inside the setup method " + "of your Training Operator?") + return self._criterion + + @classmethod + def from_ptl(cls, + lightning_module_cls, + train_dataloader=None, + val_dataloader=None): + """Creates a TrainingOperator from a Pytorch Lightning Module. + + Args: + lightning_module_cls: Your LightningModule class. An object of + this class will get instantiated on each worker. + train_dataloader: The data loader to use for training. If None + is provided, LightningModule.train_dataloader will be used + instead. + val_dataloader: The data loader to use for validation. If None + is provided, LightningModule.val_dataloader will be used + instead. + + Returns: + A TrainingOperator class properly configured given the + LightningModule. + """ + from ray.util.sgd.torch.ptl_operator import LightningOperator + + class CustomLightningOperator(LightningOperator): + _lightning_module_cls = lightning_module_cls + _train_dataloader = train_dataloader + _val_dataloader = val_dataloader + + return CustomLightningOperator + @classmethod def from_creators(cls, model_creator, @@ -749,6 +851,11 @@ class TrainingOperator: """int: The rank of the parent runner. Always 0 if not distributed.""" return self._world_rank + @property + def local_rank(self): + """int: Local rank of parent runner. Always 0 if not distributed.""" + return self._local_rank + @property def use_gpu(self): """Returns True if cuda is available and use_gpu is True.""" @@ -849,29 +956,73 @@ class CreatorOperator(TrainingOperator): kwargs["criterion"] = criterion state = self.register(**kwargs) - self.models, self.optimizers = state[:2] - if isinstance(self.models, tuple): - self.model = self.models[0] + self._registered_models, self._registered_optimizers = state[:2] + if isinstance(self.models, (list, tuple)): + logger.info("Multiple models have been registered. If custom " + "training methods are not provided, only the first " + "model will be used.") + self._registered_model = self.models[0] else: - self.model = self.models + self._registered_model = self.models - if isinstance(self.optimizers, tuple): - self.optimizer = self.optimizers[0] + if isinstance(self.optimizers, (list, tuple)): + logger.info("Multiple optimizers have been registered. If custom " + "training methods are not provided, only the first " + "optimizer will be used.") + self._reigstered_optimizer = self.optimizers[0] else: - self.optimizer = self.optimizers + self._registered_optimizer = self.optimizers if len(state) >= 3: - self.criterion = state[2] + self._registered_criterion = state[2] if len(state) == 4: - self.schedulers = state[3] - if isinstance(self.schedulers, tuple): - self.scheduler = self.schedulers[0] + self._registered_schedulers = state[3] + if isinstance(self.schedulers, (list, tuple)): + logger.info("Multiple schedulers have been registered. If " + "custom training methods are not provided, " + "only the first scheduler will be used.") + self._registered_scheduler = self.schedulers[0] else: - self.scheduler = self.schedulers + self._registered_scheduler = self.schedulers self.register_data( train_loader=train_loader, validation_loader=validation_loader) + @property + def model(self): + """First or only model created by the provided ``model_creator``.""" + return self._registered_model + + @property + def optimizer(self): + """First or only optimizer(s) created by the ``optimizer_creator``.""" + return self._registered_optimizer + + @property + def scheduler(self): + """First or only scheduler(s) created by the ``scheduler_creator``.""" + return self._registered_scheduler + + @property + def criterion(self): + """Criterion created by the provided ``loss_creator``.""" + return self._registered_criterion + + @property + def models(self): + """List of models created by the provided ``model_creator``.""" + return self._registered_models + + @property + def optimizers(self): + """List of optimizers created by the ``optimizer_creator``.""" + return self._registered_optimizers + + @property + def schedulers(self): + """List of schedulers created by the ``scheduler_creator``.""" + return self._registered_schedulers + def get_test_operator(operator_cls): class _TestingOperator(operator_cls): diff --git a/python/ray/util/sgd/torch/worker_group.py b/python/ray/util/sgd/torch/worker_group.py index 626ea9f78..9a42bf46b 100644 --- a/python/ray/util/sgd/torch/worker_group.py +++ b/python/ray/util/sgd/torch/worker_group.py @@ -1,6 +1,7 @@ import io import logging import time +from collections import defaultdict from datetime import timedelta import ray @@ -191,6 +192,19 @@ class RemoteWorkerGroup(WorkerGroupInterface): ] return remote_operator_setups + def _setup_local_rank(self, rank_counter_dict=None): + """Sets local rank for all workers.""" + if rank_counter_dict is None: + rank_counter_dict = defaultdict(int) + node_ips = ray.get( + [w.get_node_ip.remote() for w in self.remote_workers]) + futures = [] + for ip, worker in zip(node_ips, self.remote_workers): + rank = rank_counter_dict[ip] + futures.append(worker.set_local_rank.remote(rank)) + rank_counter_dict[ip] += 1 + return futures + def start_workers(self, num_workers): logger.debug(f"start_workers: Setting %d workers." % num_workers) if num_workers == 1: @@ -212,6 +226,8 @@ class RemoteWorkerGroup(WorkerGroupInterface): self._setup_process_group( address=address, world_size=num_workers)) + ray.get(self._setup_local_rank()) + ray.get(self._setup_operator()) def _apply_all_operators(self, fn): @@ -454,6 +470,12 @@ class LocalWorkerGroup(WorkerGroupInterface): timeout=timedelta(self._timeout_s)) ray.get(remote_pgs) + local_node_ip = ray.services.get_node_ip_address() + rank_dict = defaultdict(int) + self.local_worker.set_local_rank(local_rank=0) + rank_dict[local_node_ip] += 1 + self.remote_worker_group._setup_local_rank(rank_dict) + remote_operators = self.remote_worker_group._setup_operator() self.local_worker.setup_operator() ray.get(remote_operators)