[RaySGD] Add tqdm logging to TorchTrainer (#7588)

* Update issue templates

* Init fp16

* fp16 and schedulers

* scheduler linking and fp16

* to fp16

* loss scaling and documentation

* more documentation

* add tests, refactor config

* moredocs

* more docs

* fix logo, add test mode, add fp16 flag

* fix tests

* fix scheduler

* fix apex

* improve safety

* fix tests

* fix tests

* remove pin memory default

* rm

* fix

* Update doc/examples/doc_code/raysgd_torch_signatures.py

* fix

* migrate changes from other PR

* ok thanks

* pass

* signatures

* lint'

* Update python/ray/experimental/sgd/pytorch/utils.py

* Apply suggestions from code review

Co-Authored-By: Edward Oakes <ed.nmi.oakes@gmail.com>

* should address most comments

* comments

* fix this ci

* first_pass

* add overrides

* override

* fixing up operators

* format

* sgd

* constants

* rm

* revert

* Checkpoint the basics

* End of day checkpoint

* Checkpoint log-to-head implementation

* Checkpoint

* Add actor-based batch log reporting, currently segfaults

* Work around progress segfault

* Fix some stuff in quicktorch

* Make things more customizable

* Quality of life fixes

* More quality of life

* Move tqdm logic to training_operator

* Update examples

* Fix some minor bugs

* Fix merge

* Fix small things, add pbar to dcgan

* Run format.sh

* Fix missing epoch number for batch pbar

* Address PR comments

* Fix float is not subscriptable

* Add train_loss to pbar by default

* Isolate tqdm code into a handler system

* Format

* Remove the batch_logs_reporter from distributed runner as well

* Check if the train_loss is avaialbale before using it

* Enable tqdm in the dcgan example

* Fix a crash in no-handler trainers

* Fix

* Allow not calling set_reporters for tests

Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
Co-authored-by: Edward Oakes <ed.nmi.oakes@gmail.com>
This commit is contained in:
Maksim Smolin
2020-03-24 23:43:56 -07:00
committed by GitHub
parent 54a892bb84
commit e95455b7d7
8 changed files with 238 additions and 15 deletions
+1
View File
@@ -2,5 +2,6 @@ USE_FP16 = "__use_fp16__"
SCHEDULER_STEP = "scheduler_step"
SCHEDULER_STEP_BATCH = "batch"
SCHEDULER_STEP_EPOCH = "epoch"
BATCH_LOGS_RATE_LIMIT = .2
VALID_SCHEDULER_STEP = {SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH}
@@ -80,6 +80,9 @@ class DistributedTorchRunner(TorchRunner):
models=self.models,
optimizers=self.optimizers,
criterion=self.criterion,
train_loader=self.train_loader,
validation_loader=self.validation_loader,
world_rank=self.world_rank,
schedulers=self.schedulers,
use_fp16=self.use_fp16)
@@ -7,6 +7,8 @@ from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
from tqdm import trange
import ray
from ray.util.sgd.torch import (TorchTrainer, TorchTrainable)
from ray.util.sgd.torch.resnet import ResNet18
@@ -82,11 +84,16 @@ def train_example(num_workers=1,
use_gpu=use_gpu,
backend="nccl" if use_gpu else "gloo",
scheduler_step_freq="epoch",
use_fp16=use_fp16)
for i in range(num_epochs):
use_fp16=use_fp16,
tqdm=True)
pbar = trange(num_epochs, unit="epoch")
for i in pbar:
info = {"num_steps": 1} if test_mode else {}
info["epoch_idx"] = i
info["num_epochs"] = num_epochs
# Increase `max_retries` to turn on fault tolerance.
stats = trainer1.train(max_retries=0)
print(stats)
stats = trainer1.train(max_retries=0, info=info)
pbar.set_postfix(dict(loss=stats["mean_train_loss"]))
print(trainer1.validate())
trainer1.shutdown()
+11 -4
View File
@@ -10,6 +10,8 @@ import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
from tqdm import trange
from torch.autograd import Variable
from torch.nn import functional as F
from scipy.stats import entropy
@@ -240,15 +242,20 @@ def train_example(num_workers=1, use_gpu=False, test_mode=False):
num_workers=num_workers,
config=config,
use_gpu=use_gpu,
backend="nccl" if use_gpu else "gloo")
backend="nccl" if use_gpu else "gloo",
tqdm=True)
from tabulate import tabulate
for itr in range(5):
stats = trainer.train()
pbar = trange(5, unit="epoch")
for itr in pbar:
stats = trainer.train(info=dict(epoch_idx=itr, num_epochs=5))
pbar.set_postfix(
dict(loss_g=stats["mean_loss_g"], loss_d=stats["mean_loss_d"]))
formatted = tabulate([stats], headers="keys")
if itr > 0: # Get the last line of the stats.
formatted = formatted.split("\n")[-1]
print(formatted)
pbar.write(formatted)
return trainer
@@ -133,6 +133,9 @@ class TorchRunner:
self.models, self.optimizers = amp.initialize(
self.models, self.optimizers, **self.apex_args)
def set_reporters(self, reporters):
return self.training_operator.set_reporters(reporters)
def setup(self):
"""Initializes the model."""
logger.debug("Creating model")
@@ -156,6 +159,9 @@ class TorchRunner:
models=self.models,
optimizers=self.optimizers,
criterion=self.criterion,
train_loader=self.train_loader,
validation_loader=self.validation_loader,
world_rank=0,
schedulers=self.schedulers,
use_fp16=self.use_fp16)
+53 -7
View File
@@ -4,19 +4,22 @@ import logging
import numbers
import tempfile
import time
import asyncio
import torch
import torch.distributed as dist
import ray
from ray.exceptions import RayActorError
from ray.tune import Trainable
from ray.tune.trial import Resources
from ray.util.sgd.torch.distributed_torch_runner import (
DistributedTorchRunner)
from ray.util.sgd import utils
from ray.util.sgd.utils import NUM_SAMPLES, BATCH_SIZE
from ray.util.sgd.utils import check_for_failure, NUM_SAMPLES, BATCH_SIZE
from ray.util.sgd.torch.torch_runner import TorchRunner
from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP
from ray.util.sgd.torch.constants import (VALID_SCHEDULER_STEP,
BATCH_LOGS_RATE_LIMIT)
from ray.util.sgd.torch.tqdm_handler import TqdmHandler
logger = logging.getLogger(__name__)
RESIZE_COOLDOWN_S = 10
@@ -146,6 +149,7 @@ class TorchTrainer:
use_gpu=False,
backend="auto",
use_fp16=False,
tqdm=False,
apex_args=None,
scheduler_step_freq="batch",
num_replicas=None,
@@ -217,6 +221,10 @@ class TorchTrainer:
self._num_failures = 0
self._last_resize = float("-inf")
self.handlers = []
if tqdm:
self.handlers.append(TqdmHandler())
_validate_scheduler_step_freq(scheduler_step_freq)
self.scheduler_step_freq = scheduler_step_freq
@@ -271,6 +279,8 @@ class TorchTrainer:
self.apply_all_workers(self.initialization_hook)
# Get setup tasks in order to throw errors on failure
ray.get(self.workers[0].setup.remote())
ray.get(self.workers[0].set_reporters.remote(
[h.create_reporter() for h in self.handlers]))
else:
# Generate actor class
Runner = ray.remote(
@@ -303,6 +313,11 @@ class TorchTrainer:
worker.setup.remote(address, i, len(self.workers))
for i, worker in enumerate(self.workers)
])
ray.get([
w.set_reporters.remote(
[h.create_reporter() for h in self.handlers])
for w in self.workers
])
def train(self,
num_steps=None,
@@ -359,6 +374,9 @@ class TorchTrainer:
logger.info("Resize opportunity detected. Attempting to scale up.")
self._resize_workers(checkpoint=checkpoint)
for h in self.handlers:
h.record_train_info(info, num_steps)
success, worker_stats = self._train_epoch(
num_steps=num_steps, profile=profile, info=info)
# Fault handling
@@ -395,14 +413,42 @@ class TorchTrainer:
stats[stat_key] = worker_stats[0][stat_key]
return stats
def _train_epoch(self, num_steps=None, profile=False, info=None):
worker_stats = [
def _train_epoch(self,
num_steps=None,
profile=False,
info=None,
batch_logs_handler=None):
worker_trains = [
w.train_epoch.remote(
num_steps=num_steps, profile=profile, info=info)
for w in self.workers
]
success = utils.check_for_failure(worker_stats)
return success, worker_stats
if not self.handlers:
success = check_for_failure(worker_trains)
return success, worker_trains
unfinished = worker_trains
try:
while len(unfinished) > 0:
finished, unfinished = ray.wait(
unfinished, timeout=BATCH_LOGS_RATE_LIMIT)
# throw errors on agent failure
finished = ray.get(finished)
futures = [h.update() for h in self.handlers]
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(asyncio.wait(futures))
loop.close()
return True, worker_trains
except RayActorError as exc:
logger.exception(str(exc))
return False, worker_trains
def apply_all_workers(self, fn):
"""Run a function on all operators on the workers.
+116
View File
@@ -0,0 +1,116 @@
import asyncio
import time
from tqdm import tqdm
import ray
from ray.util.sgd.torch.constants import BATCH_LOGS_RATE_LIMIT
@ray.remote(num_cpus=0)
class _ReporterActor:
def __init__(self):
# we need the new_data field to allow sending back None as the legs
self._logs = {"new_data": False, "data": None}
self._setup = {"new_data": False, "data": None}
def _send_setup(self, data):
self._setup = {"new_data": True, "data": data}
def _send_logs(self, data):
self._logs = {"new_data": True, "data": data}
def _read_logs(self):
res = self._logs
self._logs = {"new_data": False, "data": None}
return res
def _read_setup(self):
res = self._setup
self._setup = {"new_data": False, "data": None}
return res
class TqdmReporter:
def __init__(self, actor):
self.actor = actor
self.last_packet_time = 0
def _send_setup(self, packet):
ray.get(self.actor._send_setup.remote(packet))
def _send_logs(self, packet):
cur_time = time.monotonic()
if cur_time - self.last_packet_time < BATCH_LOGS_RATE_LIMIT:
return
self.last_packet_time = cur_time
ray.get(self.actor._send_logs.remote(packet))
def on_epoch_begin(self, info, training_op):
if training_op.world_rank != 0:
return
self.last_packet_time = 0
self._send_setup({"loader_len": len(training_op.train_loader)})
def on_batch_end(self, batch_info, metrics, training_op):
if training_op.world_rank != 0:
return
pbar_metrics = {}
if "train_loss" in metrics:
pbar_metrics["loss"] = metrics["train_loss"]
self._send_logs({
"batch_idx": batch_info["batch_idx"],
"pbar_metrics": pbar_metrics
})
class TqdmHandler:
def __init__(self):
self.batch_pbar = None
self.reporter_actor = _ReporterActor.remote()
def create_reporter(self):
return TqdmReporter(self.reporter_actor)
def handle_setup_packet(self, packet):
n = self.num_steps
if n is None:
n = packet["loader_len"]
desc = ""
if self.train_info is not None and "epoch_idx" in self.train_info:
if "num_epochs" in self.train_info:
desc = "{}/{}e".format(self.train_info["epoch_idx"] + 1,
self.train_info["num_epochs"])
else:
desc = "{}e".format(self.train_info["epoch_idx"] + 1)
self.batch_pbar = tqdm(total=n, desc=desc, unit="batch", leave=False)
def handle_logs_packet(self, packet):
self.batch_pbar.n = packet["batch_idx"] + 1
self.batch_pbar.set_postfix(packet["pbar_metrics"])
def record_train_info(self, info, num_steps):
self.train_info = info
self.num_steps = num_steps
async def update(self):
setup_read, logs_read = await asyncio.gather(
self.reporter_actor._read_setup.remote(),
self.reporter_actor._read_logs.remote())
if setup_read["new_data"]:
self.handle_setup_packet(setup_read["data"])
if logs_read["new_data"]:
self.handle_logs_packet(logs_read["data"])
@@ -1,4 +1,5 @@
import collections
import torch
from ray.util.sgd.utils import (TimerCollection, AverageMeterCollection,
@@ -49,6 +50,9 @@ class TrainingOperator:
config,
models,
optimizers,
train_loader,
validation_loader,
world_rank,
criterion=None,
schedulers=None,
use_fp16=False):
@@ -59,6 +63,9 @@ class TrainingOperator:
self._optimizers = optimizers # List of optimizers
assert isinstance(optimizers, collections.Iterable), (
"Components need to be iterable. Got: {}".format(type(optimizers)))
self._train_loader = train_loader
self._validation_loader = validation_loader
self._world_rank = world_rank
self._criterion = criterion
self._schedulers = schedulers
if schedulers:
@@ -77,8 +84,12 @@ class TrainingOperator:
"TrainingOperator if using multi-scheduler, "
"multi-model or multi-optimizer training/validation.")
self.timers = TimerCollection()
self.reporters = []
self.setup(config)
def set_reporters(self, reporters):
self.reporters = reporters
def _set_timers(self, timers):
"""Passes in the timers from the Runner."""
self.timers = timers
@@ -131,6 +142,9 @@ class TrainingOperator:
Returns:
A dict of metrics from training.
"""
for r in self.reporters:
r.on_epoch_begin(info, self)
metric_meters = AverageMeterCollection()
self.model.train()
@@ -142,6 +156,9 @@ class TrainingOperator:
batch_info.update(info)
metrics = self.train_batch(batch, batch_info=batch_info)
for r in self.reporters:
r.on_batch_end(batch_info, metrics, self)
if self.scheduler and batch_info.get(
SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
self.scheduler.step()
@@ -209,6 +226,7 @@ class TrainingOperator:
# Call step of optimizer to update model params.
with self.timers.record("apply"):
self.optimizer.step()
return {"train_loss": loss.item(), NUM_SAMPLES: features.size(0)}
def validate(self, val_iterator, info):
@@ -318,6 +336,25 @@ class TrainingOperator:
"""List of optimizers created by the ``optimizer_creator``."""
return self._optimizers
@property
def train_loader(self):
"""
Data loader for the validation dataset created by the ``data_creator``.
"""
return self._train_loader
@property
def validation_loader(self):
"""
Data loader for the train dataset created by the ``data_creator``.
"""
return self._validation_loader
@property
def world_rank(self):
"""The rank of the parent runner. Always 0 if not distributed."""
return self._world_rank
@property
def criterion(self):
"""Criterion created by the provided ``loss_creator``."""