mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 21:12:15 +08:00
[RaySGD] Add tqdm logging to TorchTrainer (#7588)
* Update issue templates * Init fp16 * fp16 and schedulers * scheduler linking and fp16 * to fp16 * loss scaling and documentation * more documentation * add tests, refactor config * moredocs * more docs * fix logo, add test mode, add fp16 flag * fix tests * fix scheduler * fix apex * improve safety * fix tests * fix tests * remove pin memory default * rm * fix * Update doc/examples/doc_code/raysgd_torch_signatures.py * fix * migrate changes from other PR * ok thanks * pass * signatures * lint' * Update python/ray/experimental/sgd/pytorch/utils.py * Apply suggestions from code review Co-Authored-By: Edward Oakes <ed.nmi.oakes@gmail.com> * should address most comments * comments * fix this ci * first_pass * add overrides * override * fixing up operators * format * sgd * constants * rm * revert * Checkpoint the basics * End of day checkpoint * Checkpoint log-to-head implementation * Checkpoint * Add actor-based batch log reporting, currently segfaults * Work around progress segfault * Fix some stuff in quicktorch * Make things more customizable * Quality of life fixes * More quality of life * Move tqdm logic to training_operator * Update examples * Fix some minor bugs * Fix merge * Fix small things, add pbar to dcgan * Run format.sh * Fix missing epoch number for batch pbar * Address PR comments * Fix float is not subscriptable * Add train_loss to pbar by default * Isolate tqdm code into a handler system * Format * Remove the batch_logs_reporter from distributed runner as well * Check if the train_loss is avaialbale before using it * Enable tqdm in the dcgan example * Fix a crash in no-handler trainers * Fix * Allow not calling set_reporters for tests Co-authored-by: Philipp Moritz <pcmoritz@gmail.com> Co-authored-by: Richard Liaw <rliaw@berkeley.edu> Co-authored-by: Edward Oakes <ed.nmi.oakes@gmail.com>
This commit is contained in:
@@ -2,5 +2,6 @@ USE_FP16 = "__use_fp16__"
|
||||
SCHEDULER_STEP = "scheduler_step"
|
||||
SCHEDULER_STEP_BATCH = "batch"
|
||||
SCHEDULER_STEP_EPOCH = "epoch"
|
||||
BATCH_LOGS_RATE_LIMIT = .2
|
||||
|
||||
VALID_SCHEDULER_STEP = {SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH}
|
||||
|
||||
@@ -80,6 +80,9 @@ class DistributedTorchRunner(TorchRunner):
|
||||
models=self.models,
|
||||
optimizers=self.optimizers,
|
||||
criterion=self.criterion,
|
||||
train_loader=self.train_loader,
|
||||
validation_loader=self.validation_loader,
|
||||
world_rank=self.world_rank,
|
||||
schedulers=self.schedulers,
|
||||
use_fp16=self.use_fp16)
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@ from torch.utils.data import DataLoader, Subset
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from tqdm import trange
|
||||
|
||||
import ray
|
||||
from ray.util.sgd.torch import (TorchTrainer, TorchTrainable)
|
||||
from ray.util.sgd.torch.resnet import ResNet18
|
||||
@@ -82,11 +84,16 @@ def train_example(num_workers=1,
|
||||
use_gpu=use_gpu,
|
||||
backend="nccl" if use_gpu else "gloo",
|
||||
scheduler_step_freq="epoch",
|
||||
use_fp16=use_fp16)
|
||||
for i in range(num_epochs):
|
||||
use_fp16=use_fp16,
|
||||
tqdm=True)
|
||||
pbar = trange(num_epochs, unit="epoch")
|
||||
for i in pbar:
|
||||
info = {"num_steps": 1} if test_mode else {}
|
||||
info["epoch_idx"] = i
|
||||
info["num_epochs"] = num_epochs
|
||||
# Increase `max_retries` to turn on fault tolerance.
|
||||
stats = trainer1.train(max_retries=0)
|
||||
print(stats)
|
||||
stats = trainer1.train(max_retries=0, info=info)
|
||||
pbar.set_postfix(dict(loss=stats["mean_train_loss"]))
|
||||
|
||||
print(trainer1.validate())
|
||||
trainer1.shutdown()
|
||||
|
||||
@@ -10,6 +10,8 @@ import torchvision.datasets as datasets
|
||||
import torchvision.transforms as transforms
|
||||
import numpy as np
|
||||
|
||||
from tqdm import trange
|
||||
|
||||
from torch.autograd import Variable
|
||||
from torch.nn import functional as F
|
||||
from scipy.stats import entropy
|
||||
@@ -240,15 +242,20 @@ def train_example(num_workers=1, use_gpu=False, test_mode=False):
|
||||
num_workers=num_workers,
|
||||
config=config,
|
||||
use_gpu=use_gpu,
|
||||
backend="nccl" if use_gpu else "gloo")
|
||||
backend="nccl" if use_gpu else "gloo",
|
||||
tqdm=True)
|
||||
|
||||
from tabulate import tabulate
|
||||
for itr in range(5):
|
||||
stats = trainer.train()
|
||||
pbar = trange(5, unit="epoch")
|
||||
for itr in pbar:
|
||||
stats = trainer.train(info=dict(epoch_idx=itr, num_epochs=5))
|
||||
pbar.set_postfix(
|
||||
dict(loss_g=stats["mean_loss_g"], loss_d=stats["mean_loss_d"]))
|
||||
|
||||
formatted = tabulate([stats], headers="keys")
|
||||
if itr > 0: # Get the last line of the stats.
|
||||
formatted = formatted.split("\n")[-1]
|
||||
print(formatted)
|
||||
pbar.write(formatted)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
@@ -133,6 +133,9 @@ class TorchRunner:
|
||||
self.models, self.optimizers = amp.initialize(
|
||||
self.models, self.optimizers, **self.apex_args)
|
||||
|
||||
def set_reporters(self, reporters):
|
||||
return self.training_operator.set_reporters(reporters)
|
||||
|
||||
def setup(self):
|
||||
"""Initializes the model."""
|
||||
logger.debug("Creating model")
|
||||
@@ -156,6 +159,9 @@ class TorchRunner:
|
||||
models=self.models,
|
||||
optimizers=self.optimizers,
|
||||
criterion=self.criterion,
|
||||
train_loader=self.train_loader,
|
||||
validation_loader=self.validation_loader,
|
||||
world_rank=0,
|
||||
schedulers=self.schedulers,
|
||||
use_fp16=self.use_fp16)
|
||||
|
||||
|
||||
@@ -4,19 +4,22 @@ import logging
|
||||
import numbers
|
||||
import tempfile
|
||||
import time
|
||||
import asyncio
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import ray
|
||||
|
||||
from ray.exceptions import RayActorError
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.trial import Resources
|
||||
from ray.util.sgd.torch.distributed_torch_runner import (
|
||||
DistributedTorchRunner)
|
||||
from ray.util.sgd import utils
|
||||
from ray.util.sgd.utils import NUM_SAMPLES, BATCH_SIZE
|
||||
from ray.util.sgd.utils import check_for_failure, NUM_SAMPLES, BATCH_SIZE
|
||||
from ray.util.sgd.torch.torch_runner import TorchRunner
|
||||
from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP
|
||||
from ray.util.sgd.torch.constants import (VALID_SCHEDULER_STEP,
|
||||
BATCH_LOGS_RATE_LIMIT)
|
||||
from ray.util.sgd.torch.tqdm_handler import TqdmHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
RESIZE_COOLDOWN_S = 10
|
||||
@@ -146,6 +149,7 @@ class TorchTrainer:
|
||||
use_gpu=False,
|
||||
backend="auto",
|
||||
use_fp16=False,
|
||||
tqdm=False,
|
||||
apex_args=None,
|
||||
scheduler_step_freq="batch",
|
||||
num_replicas=None,
|
||||
@@ -217,6 +221,10 @@ class TorchTrainer:
|
||||
self._num_failures = 0
|
||||
self._last_resize = float("-inf")
|
||||
|
||||
self.handlers = []
|
||||
if tqdm:
|
||||
self.handlers.append(TqdmHandler())
|
||||
|
||||
_validate_scheduler_step_freq(scheduler_step_freq)
|
||||
self.scheduler_step_freq = scheduler_step_freq
|
||||
|
||||
@@ -271,6 +279,8 @@ class TorchTrainer:
|
||||
self.apply_all_workers(self.initialization_hook)
|
||||
# Get setup tasks in order to throw errors on failure
|
||||
ray.get(self.workers[0].setup.remote())
|
||||
ray.get(self.workers[0].set_reporters.remote(
|
||||
[h.create_reporter() for h in self.handlers]))
|
||||
else:
|
||||
# Generate actor class
|
||||
Runner = ray.remote(
|
||||
@@ -303,6 +313,11 @@ class TorchTrainer:
|
||||
worker.setup.remote(address, i, len(self.workers))
|
||||
for i, worker in enumerate(self.workers)
|
||||
])
|
||||
ray.get([
|
||||
w.set_reporters.remote(
|
||||
[h.create_reporter() for h in self.handlers])
|
||||
for w in self.workers
|
||||
])
|
||||
|
||||
def train(self,
|
||||
num_steps=None,
|
||||
@@ -359,6 +374,9 @@ class TorchTrainer:
|
||||
logger.info("Resize opportunity detected. Attempting to scale up.")
|
||||
self._resize_workers(checkpoint=checkpoint)
|
||||
|
||||
for h in self.handlers:
|
||||
h.record_train_info(info, num_steps)
|
||||
|
||||
success, worker_stats = self._train_epoch(
|
||||
num_steps=num_steps, profile=profile, info=info)
|
||||
# Fault handling
|
||||
@@ -395,14 +413,42 @@ class TorchTrainer:
|
||||
stats[stat_key] = worker_stats[0][stat_key]
|
||||
return stats
|
||||
|
||||
def _train_epoch(self, num_steps=None, profile=False, info=None):
|
||||
worker_stats = [
|
||||
def _train_epoch(self,
|
||||
num_steps=None,
|
||||
profile=False,
|
||||
info=None,
|
||||
batch_logs_handler=None):
|
||||
worker_trains = [
|
||||
w.train_epoch.remote(
|
||||
num_steps=num_steps, profile=profile, info=info)
|
||||
for w in self.workers
|
||||
]
|
||||
success = utils.check_for_failure(worker_stats)
|
||||
return success, worker_stats
|
||||
|
||||
if not self.handlers:
|
||||
success = check_for_failure(worker_trains)
|
||||
return success, worker_trains
|
||||
|
||||
unfinished = worker_trains
|
||||
try:
|
||||
while len(unfinished) > 0:
|
||||
finished, unfinished = ray.wait(
|
||||
unfinished, timeout=BATCH_LOGS_RATE_LIMIT)
|
||||
|
||||
# throw errors on agent failure
|
||||
finished = ray.get(finished)
|
||||
|
||||
futures = [h.update() for h in self.handlers]
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(asyncio.wait(futures))
|
||||
loop.close()
|
||||
|
||||
return True, worker_trains
|
||||
except RayActorError as exc:
|
||||
logger.exception(str(exc))
|
||||
return False, worker_trains
|
||||
|
||||
def apply_all_workers(self, fn):
|
||||
"""Run a function on all operators on the workers.
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import ray
|
||||
from ray.util.sgd.torch.constants import BATCH_LOGS_RATE_LIMIT
|
||||
|
||||
|
||||
@ray.remote(num_cpus=0)
|
||||
class _ReporterActor:
|
||||
def __init__(self):
|
||||
# we need the new_data field to allow sending back None as the legs
|
||||
self._logs = {"new_data": False, "data": None}
|
||||
self._setup = {"new_data": False, "data": None}
|
||||
|
||||
def _send_setup(self, data):
|
||||
self._setup = {"new_data": True, "data": data}
|
||||
|
||||
def _send_logs(self, data):
|
||||
self._logs = {"new_data": True, "data": data}
|
||||
|
||||
def _read_logs(self):
|
||||
res = self._logs
|
||||
|
||||
self._logs = {"new_data": False, "data": None}
|
||||
|
||||
return res
|
||||
|
||||
def _read_setup(self):
|
||||
res = self._setup
|
||||
|
||||
self._setup = {"new_data": False, "data": None}
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class TqdmReporter:
|
||||
def __init__(self, actor):
|
||||
self.actor = actor
|
||||
|
||||
self.last_packet_time = 0
|
||||
|
||||
def _send_setup(self, packet):
|
||||
ray.get(self.actor._send_setup.remote(packet))
|
||||
|
||||
def _send_logs(self, packet):
|
||||
cur_time = time.monotonic()
|
||||
if cur_time - self.last_packet_time < BATCH_LOGS_RATE_LIMIT:
|
||||
return
|
||||
|
||||
self.last_packet_time = cur_time
|
||||
ray.get(self.actor._send_logs.remote(packet))
|
||||
|
||||
def on_epoch_begin(self, info, training_op):
|
||||
if training_op.world_rank != 0:
|
||||
return
|
||||
|
||||
self.last_packet_time = 0
|
||||
|
||||
self._send_setup({"loader_len": len(training_op.train_loader)})
|
||||
|
||||
def on_batch_end(self, batch_info, metrics, training_op):
|
||||
if training_op.world_rank != 0:
|
||||
return
|
||||
|
||||
pbar_metrics = {}
|
||||
if "train_loss" in metrics:
|
||||
pbar_metrics["loss"] = metrics["train_loss"]
|
||||
|
||||
self._send_logs({
|
||||
"batch_idx": batch_info["batch_idx"],
|
||||
"pbar_metrics": pbar_metrics
|
||||
})
|
||||
|
||||
|
||||
class TqdmHandler:
|
||||
def __init__(self):
|
||||
self.batch_pbar = None
|
||||
self.reporter_actor = _ReporterActor.remote()
|
||||
|
||||
def create_reporter(self):
|
||||
return TqdmReporter(self.reporter_actor)
|
||||
|
||||
def handle_setup_packet(self, packet):
|
||||
n = self.num_steps
|
||||
if n is None:
|
||||
n = packet["loader_len"]
|
||||
|
||||
desc = ""
|
||||
if self.train_info is not None and "epoch_idx" in self.train_info:
|
||||
if "num_epochs" in self.train_info:
|
||||
desc = "{}/{}e".format(self.train_info["epoch_idx"] + 1,
|
||||
self.train_info["num_epochs"])
|
||||
else:
|
||||
desc = "{}e".format(self.train_info["epoch_idx"] + 1)
|
||||
|
||||
self.batch_pbar = tqdm(total=n, desc=desc, unit="batch", leave=False)
|
||||
|
||||
def handle_logs_packet(self, packet):
|
||||
self.batch_pbar.n = packet["batch_idx"] + 1
|
||||
self.batch_pbar.set_postfix(packet["pbar_metrics"])
|
||||
|
||||
def record_train_info(self, info, num_steps):
|
||||
self.train_info = info
|
||||
self.num_steps = num_steps
|
||||
|
||||
async def update(self):
|
||||
setup_read, logs_read = await asyncio.gather(
|
||||
self.reporter_actor._read_setup.remote(),
|
||||
self.reporter_actor._read_logs.remote())
|
||||
|
||||
if setup_read["new_data"]:
|
||||
self.handle_setup_packet(setup_read["data"])
|
||||
if logs_read["new_data"]:
|
||||
self.handle_logs_packet(logs_read["data"])
|
||||
@@ -1,4 +1,5 @@
|
||||
import collections
|
||||
|
||||
import torch
|
||||
|
||||
from ray.util.sgd.utils import (TimerCollection, AverageMeterCollection,
|
||||
@@ -49,6 +50,9 @@ class TrainingOperator:
|
||||
config,
|
||||
models,
|
||||
optimizers,
|
||||
train_loader,
|
||||
validation_loader,
|
||||
world_rank,
|
||||
criterion=None,
|
||||
schedulers=None,
|
||||
use_fp16=False):
|
||||
@@ -59,6 +63,9 @@ class TrainingOperator:
|
||||
self._optimizers = optimizers # List of optimizers
|
||||
assert isinstance(optimizers, collections.Iterable), (
|
||||
"Components need to be iterable. Got: {}".format(type(optimizers)))
|
||||
self._train_loader = train_loader
|
||||
self._validation_loader = validation_loader
|
||||
self._world_rank = world_rank
|
||||
self._criterion = criterion
|
||||
self._schedulers = schedulers
|
||||
if schedulers:
|
||||
@@ -77,8 +84,12 @@ class TrainingOperator:
|
||||
"TrainingOperator if using multi-scheduler, "
|
||||
"multi-model or multi-optimizer training/validation.")
|
||||
self.timers = TimerCollection()
|
||||
self.reporters = []
|
||||
self.setup(config)
|
||||
|
||||
def set_reporters(self, reporters):
|
||||
self.reporters = reporters
|
||||
|
||||
def _set_timers(self, timers):
|
||||
"""Passes in the timers from the Runner."""
|
||||
self.timers = timers
|
||||
@@ -131,6 +142,9 @@ class TrainingOperator:
|
||||
Returns:
|
||||
A dict of metrics from training.
|
||||
"""
|
||||
for r in self.reporters:
|
||||
r.on_epoch_begin(info, self)
|
||||
|
||||
metric_meters = AverageMeterCollection()
|
||||
|
||||
self.model.train()
|
||||
@@ -142,6 +156,9 @@ class TrainingOperator:
|
||||
batch_info.update(info)
|
||||
metrics = self.train_batch(batch, batch_info=batch_info)
|
||||
|
||||
for r in self.reporters:
|
||||
r.on_batch_end(batch_info, metrics, self)
|
||||
|
||||
if self.scheduler and batch_info.get(
|
||||
SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
|
||||
self.scheduler.step()
|
||||
@@ -209,6 +226,7 @@ class TrainingOperator:
|
||||
# Call step of optimizer to update model params.
|
||||
with self.timers.record("apply"):
|
||||
self.optimizer.step()
|
||||
|
||||
return {"train_loss": loss.item(), NUM_SAMPLES: features.size(0)}
|
||||
|
||||
def validate(self, val_iterator, info):
|
||||
@@ -318,6 +336,25 @@ class TrainingOperator:
|
||||
"""List of optimizers created by the ``optimizer_creator``."""
|
||||
return self._optimizers
|
||||
|
||||
@property
|
||||
def train_loader(self):
|
||||
"""
|
||||
Data loader for the validation dataset created by the ``data_creator``.
|
||||
"""
|
||||
return self._train_loader
|
||||
|
||||
@property
|
||||
def validation_loader(self):
|
||||
"""
|
||||
Data loader for the train dataset created by the ``data_creator``.
|
||||
"""
|
||||
return self._validation_loader
|
||||
|
||||
@property
|
||||
def world_rank(self):
|
||||
"""The rank of the parent runner. Always 0 if not distributed."""
|
||||
return self._world_rank
|
||||
|
||||
@property
|
||||
def criterion(self):
|
||||
"""Criterion created by the provided ``loss_creator``."""
|
||||
|
||||
Reference in New Issue
Block a user