From e95455b7d7300e3b9fe13cf0827a7ec0f5ea5332 Mon Sep 17 00:00:00 2001 From: Maksim Smolin Date: Tue, 24 Mar 2020 23:43:56 -0700 Subject: [PATCH] [RaySGD] Add tqdm logging to TorchTrainer (#7588) * Update issue templates * Init fp16 * fp16 and schedulers * scheduler linking and fp16 * to fp16 * loss scaling and documentation * more documentation * add tests, refactor config * moredocs * more docs * fix logo, add test mode, add fp16 flag * fix tests * fix scheduler * fix apex * improve safety * fix tests * fix tests * remove pin memory default * rm * fix * Update doc/examples/doc_code/raysgd_torch_signatures.py * fix * migrate changes from other PR * ok thanks * pass * signatures * lint' * Update python/ray/experimental/sgd/pytorch/utils.py * Apply suggestions from code review Co-Authored-By: Edward Oakes * should address most comments * comments * fix this ci * first_pass * add overrides * override * fixing up operators * format * sgd * constants * rm * revert * Checkpoint the basics * End of day checkpoint * Checkpoint log-to-head implementation * Checkpoint * Add actor-based batch log reporting, currently segfaults * Work around progress segfault * Fix some stuff in quicktorch * Make things more customizable * Quality of life fixes * More quality of life * Move tqdm logic to training_operator * Update examples * Fix some minor bugs * Fix merge * Fix small things, add pbar to dcgan * Run format.sh * Fix missing epoch number for batch pbar * Address PR comments * Fix float is not subscriptable * Add train_loss to pbar by default * Isolate tqdm code into a handler system * Format * Remove the batch_logs_reporter from distributed runner as well * Check if the train_loss is avaialbale before using it * Enable tqdm in the dcgan example * Fix a crash in no-handler trainers * Fix * Allow not calling set_reporters for tests Co-authored-by: Philipp Moritz Co-authored-by: Richard Liaw Co-authored-by: Edward Oakes --- python/ray/util/sgd/torch/constants.py | 1 + .../sgd/torch/distributed_torch_runner.py | 3 + .../torch/examples/cifar_pytorch_example.py | 15 ++- python/ray/util/sgd/torch/examples/dcgan.py | 15 ++- python/ray/util/sgd/torch/torch_runner.py | 6 + python/ray/util/sgd/torch/torch_trainer.py | 60 +++++++-- python/ray/util/sgd/torch/tqdm_handler.py | 116 ++++++++++++++++++ .../ray/util/sgd/torch/training_operator.py | 37 ++++++ 8 files changed, 238 insertions(+), 15 deletions(-) create mode 100644 python/ray/util/sgd/torch/tqdm_handler.py diff --git a/python/ray/util/sgd/torch/constants.py b/python/ray/util/sgd/torch/constants.py index 53e27d3f2..2f8e8fcc8 100644 --- a/python/ray/util/sgd/torch/constants.py +++ b/python/ray/util/sgd/torch/constants.py @@ -2,5 +2,6 @@ USE_FP16 = "__use_fp16__" SCHEDULER_STEP = "scheduler_step" SCHEDULER_STEP_BATCH = "batch" SCHEDULER_STEP_EPOCH = "epoch" +BATCH_LOGS_RATE_LIMIT = .2 VALID_SCHEDULER_STEP = {SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH} diff --git a/python/ray/util/sgd/torch/distributed_torch_runner.py b/python/ray/util/sgd/torch/distributed_torch_runner.py index 5c588ed80..6f051647c 100644 --- a/python/ray/util/sgd/torch/distributed_torch_runner.py +++ b/python/ray/util/sgd/torch/distributed_torch_runner.py @@ -80,6 +80,9 @@ class DistributedTorchRunner(TorchRunner): models=self.models, optimizers=self.optimizers, criterion=self.criterion, + train_loader=self.train_loader, + validation_loader=self.validation_loader, + world_rank=self.world_rank, schedulers=self.schedulers, use_fp16=self.use_fp16) diff --git a/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py b/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py index 759b26b23..642155a9e 100644 --- a/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py +++ b/python/ray/util/sgd/torch/examples/cifar_pytorch_example.py @@ -7,6 +7,8 @@ from torch.utils.data import DataLoader, Subset import torchvision import torchvision.transforms as transforms +from tqdm import trange + import ray from ray.util.sgd.torch import (TorchTrainer, TorchTrainable) from ray.util.sgd.torch.resnet import ResNet18 @@ -82,11 +84,16 @@ def train_example(num_workers=1, use_gpu=use_gpu, backend="nccl" if use_gpu else "gloo", scheduler_step_freq="epoch", - use_fp16=use_fp16) - for i in range(num_epochs): + use_fp16=use_fp16, + tqdm=True) + pbar = trange(num_epochs, unit="epoch") + for i in pbar: + info = {"num_steps": 1} if test_mode else {} + info["epoch_idx"] = i + info["num_epochs"] = num_epochs # Increase `max_retries` to turn on fault tolerance. - stats = trainer1.train(max_retries=0) - print(stats) + stats = trainer1.train(max_retries=0, info=info) + pbar.set_postfix(dict(loss=stats["mean_train_loss"])) print(trainer1.validate()) trainer1.shutdown() diff --git a/python/ray/util/sgd/torch/examples/dcgan.py b/python/ray/util/sgd/torch/examples/dcgan.py index d25a0e6a7..6fcfda809 100644 --- a/python/ray/util/sgd/torch/examples/dcgan.py +++ b/python/ray/util/sgd/torch/examples/dcgan.py @@ -10,6 +10,8 @@ import torchvision.datasets as datasets import torchvision.transforms as transforms import numpy as np +from tqdm import trange + from torch.autograd import Variable from torch.nn import functional as F from scipy.stats import entropy @@ -240,15 +242,20 @@ def train_example(num_workers=1, use_gpu=False, test_mode=False): num_workers=num_workers, config=config, use_gpu=use_gpu, - backend="nccl" if use_gpu else "gloo") + backend="nccl" if use_gpu else "gloo", + tqdm=True) from tabulate import tabulate - for itr in range(5): - stats = trainer.train() + pbar = trange(5, unit="epoch") + for itr in pbar: + stats = trainer.train(info=dict(epoch_idx=itr, num_epochs=5)) + pbar.set_postfix( + dict(loss_g=stats["mean_loss_g"], loss_d=stats["mean_loss_d"])) + formatted = tabulate([stats], headers="keys") if itr > 0: # Get the last line of the stats. formatted = formatted.split("\n")[-1] - print(formatted) + pbar.write(formatted) return trainer diff --git a/python/ray/util/sgd/torch/torch_runner.py b/python/ray/util/sgd/torch/torch_runner.py index 022801e7d..0a784748e 100644 --- a/python/ray/util/sgd/torch/torch_runner.py +++ b/python/ray/util/sgd/torch/torch_runner.py @@ -133,6 +133,9 @@ class TorchRunner: self.models, self.optimizers = amp.initialize( self.models, self.optimizers, **self.apex_args) + def set_reporters(self, reporters): + return self.training_operator.set_reporters(reporters) + def setup(self): """Initializes the model.""" logger.debug("Creating model") @@ -156,6 +159,9 @@ class TorchRunner: models=self.models, optimizers=self.optimizers, criterion=self.criterion, + train_loader=self.train_loader, + validation_loader=self.validation_loader, + world_rank=0, schedulers=self.schedulers, use_fp16=self.use_fp16) diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index 269cab493..0dc6c0c88 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -4,19 +4,22 @@ import logging import numbers import tempfile import time +import asyncio import torch import torch.distributed as dist import ray +from ray.exceptions import RayActorError from ray.tune import Trainable from ray.tune.trial import Resources from ray.util.sgd.torch.distributed_torch_runner import ( DistributedTorchRunner) -from ray.util.sgd import utils -from ray.util.sgd.utils import NUM_SAMPLES, BATCH_SIZE +from ray.util.sgd.utils import check_for_failure, NUM_SAMPLES, BATCH_SIZE from ray.util.sgd.torch.torch_runner import TorchRunner -from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP +from ray.util.sgd.torch.constants import (VALID_SCHEDULER_STEP, + BATCH_LOGS_RATE_LIMIT) +from ray.util.sgd.torch.tqdm_handler import TqdmHandler logger = logging.getLogger(__name__) RESIZE_COOLDOWN_S = 10 @@ -146,6 +149,7 @@ class TorchTrainer: use_gpu=False, backend="auto", use_fp16=False, + tqdm=False, apex_args=None, scheduler_step_freq="batch", num_replicas=None, @@ -217,6 +221,10 @@ class TorchTrainer: self._num_failures = 0 self._last_resize = float("-inf") + self.handlers = [] + if tqdm: + self.handlers.append(TqdmHandler()) + _validate_scheduler_step_freq(scheduler_step_freq) self.scheduler_step_freq = scheduler_step_freq @@ -271,6 +279,8 @@ class TorchTrainer: self.apply_all_workers(self.initialization_hook) # Get setup tasks in order to throw errors on failure ray.get(self.workers[0].setup.remote()) + ray.get(self.workers[0].set_reporters.remote( + [h.create_reporter() for h in self.handlers])) else: # Generate actor class Runner = ray.remote( @@ -303,6 +313,11 @@ class TorchTrainer: worker.setup.remote(address, i, len(self.workers)) for i, worker in enumerate(self.workers) ]) + ray.get([ + w.set_reporters.remote( + [h.create_reporter() for h in self.handlers]) + for w in self.workers + ]) def train(self, num_steps=None, @@ -359,6 +374,9 @@ class TorchTrainer: logger.info("Resize opportunity detected. Attempting to scale up.") self._resize_workers(checkpoint=checkpoint) + for h in self.handlers: + h.record_train_info(info, num_steps) + success, worker_stats = self._train_epoch( num_steps=num_steps, profile=profile, info=info) # Fault handling @@ -395,14 +413,42 @@ class TorchTrainer: stats[stat_key] = worker_stats[0][stat_key] return stats - def _train_epoch(self, num_steps=None, profile=False, info=None): - worker_stats = [ + def _train_epoch(self, + num_steps=None, + profile=False, + info=None, + batch_logs_handler=None): + worker_trains = [ w.train_epoch.remote( num_steps=num_steps, profile=profile, info=info) for w in self.workers ] - success = utils.check_for_failure(worker_stats) - return success, worker_stats + + if not self.handlers: + success = check_for_failure(worker_trains) + return success, worker_trains + + unfinished = worker_trains + try: + while len(unfinished) > 0: + finished, unfinished = ray.wait( + unfinished, timeout=BATCH_LOGS_RATE_LIMIT) + + # throw errors on agent failure + finished = ray.get(finished) + + futures = [h.update() for h in self.handlers] + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(asyncio.wait(futures)) + loop.close() + + return True, worker_trains + except RayActorError as exc: + logger.exception(str(exc)) + return False, worker_trains def apply_all_workers(self, fn): """Run a function on all operators on the workers. diff --git a/python/ray/util/sgd/torch/tqdm_handler.py b/python/ray/util/sgd/torch/tqdm_handler.py new file mode 100644 index 000000000..7e25cac4f --- /dev/null +++ b/python/ray/util/sgd/torch/tqdm_handler.py @@ -0,0 +1,116 @@ +import asyncio +import time + +from tqdm import tqdm + +import ray +from ray.util.sgd.torch.constants import BATCH_LOGS_RATE_LIMIT + + +@ray.remote(num_cpus=0) +class _ReporterActor: + def __init__(self): + # we need the new_data field to allow sending back None as the legs + self._logs = {"new_data": False, "data": None} + self._setup = {"new_data": False, "data": None} + + def _send_setup(self, data): + self._setup = {"new_data": True, "data": data} + + def _send_logs(self, data): + self._logs = {"new_data": True, "data": data} + + def _read_logs(self): + res = self._logs + + self._logs = {"new_data": False, "data": None} + + return res + + def _read_setup(self): + res = self._setup + + self._setup = {"new_data": False, "data": None} + + return res + + +class TqdmReporter: + def __init__(self, actor): + self.actor = actor + + self.last_packet_time = 0 + + def _send_setup(self, packet): + ray.get(self.actor._send_setup.remote(packet)) + + def _send_logs(self, packet): + cur_time = time.monotonic() + if cur_time - self.last_packet_time < BATCH_LOGS_RATE_LIMIT: + return + + self.last_packet_time = cur_time + ray.get(self.actor._send_logs.remote(packet)) + + def on_epoch_begin(self, info, training_op): + if training_op.world_rank != 0: + return + + self.last_packet_time = 0 + + self._send_setup({"loader_len": len(training_op.train_loader)}) + + def on_batch_end(self, batch_info, metrics, training_op): + if training_op.world_rank != 0: + return + + pbar_metrics = {} + if "train_loss" in metrics: + pbar_metrics["loss"] = metrics["train_loss"] + + self._send_logs({ + "batch_idx": batch_info["batch_idx"], + "pbar_metrics": pbar_metrics + }) + + +class TqdmHandler: + def __init__(self): + self.batch_pbar = None + self.reporter_actor = _ReporterActor.remote() + + def create_reporter(self): + return TqdmReporter(self.reporter_actor) + + def handle_setup_packet(self, packet): + n = self.num_steps + if n is None: + n = packet["loader_len"] + + desc = "" + if self.train_info is not None and "epoch_idx" in self.train_info: + if "num_epochs" in self.train_info: + desc = "{}/{}e".format(self.train_info["epoch_idx"] + 1, + self.train_info["num_epochs"]) + else: + desc = "{}e".format(self.train_info["epoch_idx"] + 1) + + self.batch_pbar = tqdm(total=n, desc=desc, unit="batch", leave=False) + + def handle_logs_packet(self, packet): + self.batch_pbar.n = packet["batch_idx"] + 1 + self.batch_pbar.set_postfix(packet["pbar_metrics"]) + + def record_train_info(self, info, num_steps): + self.train_info = info + self.num_steps = num_steps + + async def update(self): + setup_read, logs_read = await asyncio.gather( + self.reporter_actor._read_setup.remote(), + self.reporter_actor._read_logs.remote()) + + if setup_read["new_data"]: + self.handle_setup_packet(setup_read["data"]) + if logs_read["new_data"]: + self.handle_logs_packet(logs_read["data"]) diff --git a/python/ray/util/sgd/torch/training_operator.py b/python/ray/util/sgd/torch/training_operator.py index 0a32f1653..c641be579 100644 --- a/python/ray/util/sgd/torch/training_operator.py +++ b/python/ray/util/sgd/torch/training_operator.py @@ -1,4 +1,5 @@ import collections + import torch from ray.util.sgd.utils import (TimerCollection, AverageMeterCollection, @@ -49,6 +50,9 @@ class TrainingOperator: config, models, optimizers, + train_loader, + validation_loader, + world_rank, criterion=None, schedulers=None, use_fp16=False): @@ -59,6 +63,9 @@ class TrainingOperator: self._optimizers = optimizers # List of optimizers assert isinstance(optimizers, collections.Iterable), ( "Components need to be iterable. Got: {}".format(type(optimizers))) + self._train_loader = train_loader + self._validation_loader = validation_loader + self._world_rank = world_rank self._criterion = criterion self._schedulers = schedulers if schedulers: @@ -77,8 +84,12 @@ class TrainingOperator: "TrainingOperator if using multi-scheduler, " "multi-model or multi-optimizer training/validation.") self.timers = TimerCollection() + self.reporters = [] self.setup(config) + def set_reporters(self, reporters): + self.reporters = reporters + def _set_timers(self, timers): """Passes in the timers from the Runner.""" self.timers = timers @@ -131,6 +142,9 @@ class TrainingOperator: Returns: A dict of metrics from training. """ + for r in self.reporters: + r.on_epoch_begin(info, self) + metric_meters = AverageMeterCollection() self.model.train() @@ -142,6 +156,9 @@ class TrainingOperator: batch_info.update(info) metrics = self.train_batch(batch, batch_info=batch_info) + for r in self.reporters: + r.on_batch_end(batch_info, metrics, self) + if self.scheduler and batch_info.get( SCHEDULER_STEP) == SCHEDULER_STEP_BATCH: self.scheduler.step() @@ -209,6 +226,7 @@ class TrainingOperator: # Call step of optimizer to update model params. with self.timers.record("apply"): self.optimizer.step() + return {"train_loss": loss.item(), NUM_SAMPLES: features.size(0)} def validate(self, val_iterator, info): @@ -318,6 +336,25 @@ class TrainingOperator: """List of optimizers created by the ``optimizer_creator``.""" return self._optimizers + @property + def train_loader(self): + """ + Data loader for the validation dataset created by the ``data_creator``. + """ + return self._train_loader + + @property + def validation_loader(self): + """ + Data loader for the train dataset created by the ``data_creator``. + """ + return self._validation_loader + + @property + def world_rank(self): + """The rank of the parent runner. Always 0 if not distributed.""" + return self._world_rank + @property def criterion(self): """Criterion created by the provided ``loss_creator``."""