[tune/raysgd] Tune API for TorchTrainer + Fix State Restoration (#7547)

This commit is contained in:
Richard Liaw
2020-03-30 10:58:49 -07:00
committed by GitHub
parent 3a53ea60d9
commit 86cff17e7e
12 changed files with 347 additions and 213 deletions
+56 -47
View File
@@ -1,7 +1,6 @@
import os
import tempfile
from unittest.mock import patch
import numpy as np
import os
import pytest
import time
import torch
@@ -10,7 +9,7 @@ import torch.distributed as dist
import ray
from ray import tune
from ray.util.sgd.torch import TorchTrainer, TorchTrainable
from ray.util.sgd.torch import TorchTrainer
from ray.util.sgd.torch.training_operator import (_TestingOperator,
_TestMetricsOperator)
from ray.util.sgd.torch.constants import SCHEDULER_STEP
@@ -27,6 +26,9 @@ def ray_start_2_cpus():
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()
# Ensure that tests don't ALL fail
if dist.is_initialized():
dist.destroy_process_group()
def test_single_step(ray_start_2_cpus): # noqa: F811
@@ -44,6 +46,19 @@ def test_single_step(ray_start_2_cpus): # noqa: F811
trainer.shutdown()
# Checks that a TorchTrainer which has been shut down raises RuntimeError
# when train() is called on it again.
def test_dead_trainer(ray_start_2_cpus): # noqa: F811
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
num_workers=2)
# Run a single training step, then shut the trainer down.
trainer.train(num_steps=1)
trainer.shutdown()
# The next train() call on the shut-down trainer is expected to raise.
with pytest.raises(RuntimeError):
trainer.train()
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
def test_train(ray_start_2_cpus, num_workers): # noqa: F811
trainer = TorchTrainer(
@@ -53,12 +68,12 @@ def test_train(ray_start_2_cpus, num_workers): # noqa: F811
loss_creator=lambda config: nn.MSELoss(),
num_workers=num_workers)
for i in range(3):
train_loss1 = trainer.train()["mean_train_loss"]
validation_loss1 = trainer.validate()["mean_val_loss"]
train_loss1 = trainer.train()["train_loss"]
validation_loss1 = trainer.validate()["val_loss"]
for i in range(3):
train_loss2 = trainer.train()["mean_train_loss"]
validation_loss2 = trainer.validate()["mean_val_loss"]
train_loss2 = trainer.train()["train_loss"]
validation_loss2 = trainer.validate()["val_loss"]
assert train_loss2 <= train_loss1, (train_loss2, train_loss1)
assert validation_loss2 <= validation_loss1, (validation_loss2,
@@ -118,9 +133,7 @@ def test_multi_model(ray_start_2_cpus, num_workers):
training_operator_cls=_TestingOperator,
num_workers=num_workers)
trainer1.train()
filename = os.path.join(tempfile.mkdtemp(), "checkpoint")
trainer1.save(filename)
state = trainer1.state_dict()
models1 = trainer1.get_model()
@@ -134,9 +147,7 @@ def test_multi_model(ray_start_2_cpus, num_workers):
config={"custom_func": train_epoch},
training_operator_cls=_TestingOperator,
num_workers=num_workers)
trainer2.restore(filename)
os.remove(filename)
trainer2.load_state_dict(state)
models2 = trainer2.get_model()
@@ -336,20 +347,20 @@ def test_metrics(ray_start_2_cpus, num_workers):
stats = trainer.train(num_steps=num_train_steps)
# Test that we output mean and last of custom metrics in an epoch
assert "mean_score" in stats
assert "score" in stats
assert stats["last_score"] == 0
assert stats[NUM_SAMPLES] == num_train_steps * batch_size
expected_score = num_workers * (sum(train_scores) /
(num_train_steps * batch_size))
assert np.allclose(stats["mean_score"], expected_score)
assert np.allclose(stats["score"], expected_score)
val_stats = trainer.validate()
# Test that we output mean and last of custom metrics in validation
assert val_stats["last_score"] == 0
expected_score = (sum(val_scores) /
(num_val_steps * batch_size)) * num_workers
assert np.allclose(val_stats["mean_score"], expected_score)
assert np.allclose(val_stats["score"], expected_score)
assert val_stats[BATCH_COUNT] == np.ceil(num_val_steps / num_workers)
assert val_stats[NUM_SAMPLES] == num_val_steps * batch_size
assert val_stats[NUM_SAMPLES] == val_size
@@ -384,14 +395,14 @@ def test_metrics_nan(ray_start_2_cpus, num_workers):
training_operator_cls=_TestMetricsOperator)
stats = trainer.train(num_steps=num_train_steps)
assert "mean_score" in stats
assert "score" in stats
assert stats["last_score"] == 0
assert np.isnan(stats["mean_score"])
assert np.isnan(stats["score"])
stats = trainer.validate()
assert "mean_score" in stats
assert "score" in stats
assert stats["last_score"] == 0
assert np.isnan(stats["mean_score"])
assert np.isnan(stats["score"])
trainer.shutdown()
@@ -415,41 +426,41 @@ def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
def test_tune_train(ray_start_2_cpus, num_workers): # noqa: F811
config = {
"model_creator": model_creator,
"data_creator": data_creator,
"optimizer_creator": optimizer_creator,
"loss_creator": lambda config: nn.MSELoss(),
"num_workers": num_workers,
"use_gpu": False,
"backend": "gloo",
"config": {
"batch_size": 512,
"lr": 0.001
}
}
TorchTrainable = TorchTrainer.as_trainable(
**{
"model_creator": model_creator,
"data_creator": data_creator,
"optimizer_creator": optimizer_creator,
"loss_creator": lambda config: nn.MSELoss(),
"num_workers": num_workers,
"use_gpu": False,
"backend": "gloo",
"config": {
"batch_size": 512,
"lr": 0.001
}
})
analysis = tune.run(
TorchTrainable,
num_samples=2,
config=config,
stop={"training_iteration": 2},
verbose=1)
# Check that the loss does not increase between iterations for every trial.
for path, df in analysis.trial_dataframes.items():
mean_train_loss1 = df.loc[0, "mean_train_loss"]
mean_train_loss2 = df.loc[1, "mean_train_loss"]
mean_val_loss1 = df.loc[0, "mean_val_loss"]
mean_val_loss2 = df.loc[1, "mean_val_loss"]
mean_train_loss1 = df.loc[0, "train_loss"]
mean_train_loss2 = df.loc[1, "train_loss"]
mean_val_loss1 = df.loc[0, "val_loss"]
mean_val_loss2 = df.loc[1, "val_loss"]
assert mean_train_loss2 <= mean_train_loss1
assert mean_val_loss2 <= mean_val_loss1
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
def test_save_and_restore(ray_start_2_cpus, num_workers): # noqa: F811
def test_save_and_restore(ray_start_2_cpus, num_workers,
tmp_path): # noqa: F811
trainer1 = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
@@ -457,9 +468,8 @@ def test_save_and_restore(ray_start_2_cpus, num_workers): # noqa: F811
loss_creator=lambda config: nn.MSELoss(),
num_workers=num_workers)
trainer1.train()
filename = os.path.join(tempfile.mkdtemp(), "checkpoint")
trainer1.save(filename)
checkpoint_path = os.path.join(tmp_path, "checkpoint")
trainer1.save(checkpoint_path)
model1 = trainer1.get_model()
@@ -471,9 +481,7 @@ def test_save_and_restore(ray_start_2_cpus, num_workers): # noqa: F811
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
num_workers=num_workers)
trainer2.restore(filename)
os.remove(filename)
trainer2.load(checkpoint_path)
model2 = trainer2.get_model()
@@ -619,7 +627,8 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811
loss_creator=lambda config: nn.MSELoss(),
num_workers=2)
trainer1.train(max_retries=2)
# MAX RETRIES SHOULD BE ON BY DEFAULT
trainer1.train()
trainer1.shutdown()
+4 -3
View File
@@ -2,16 +2,17 @@ import logging
logger = logging.getLogger(__name__)
TorchTrainer = None
TorchTrainable = None
TrainingOperator = None
BaseTorchTrainable = None
try:
import torch # noqa: F401
from ray.util.sgd.torch.torch_trainer import (TorchTrainer, TorchTrainable)
from ray.util.sgd.torch.torch_trainer import (TorchTrainer,
BaseTorchTrainable)
from ray.util.sgd.torch.training_operator import TrainingOperator
__all__ = ["TorchTrainer", "TorchTrainable", "TrainingOperator"]
__all__ = ["TorchTrainer", "BaseTorchTrainable", "TrainingOperator"]
except ImportError:
logger.warning("PyTorch not found. TorchTrainer will not be available")
@@ -142,14 +142,7 @@ class DistributedTorchRunner(TorchRunner):
This is needed for PyTorch DistributedDataParallel models.
"""
cpu_state_dicts = []
for model in self.models:
state_dict = model.module.state_dict()
# This is so that we create a duplicate of weights into CPU rather
# than move the model weights out of the GPU so that we can
# resume training while saving intermediate checkpoints.
cpu_state_dicts += [{k: v.cpu() for k, v in state_dict.items()}]
return cpu_state_dicts
return [model.module.state_dict() for model in self.models]
def _set_model_state_dicts(self, model_state_dicts):
for model, model_state_dict in zip(self.models, model_state_dicts):
@@ -212,3 +205,10 @@ class LocalDistributedRunner(DistributedTorchRunner):
def is_actor(self):
actor_id = ray.worker.global_worker.actor_id
return actor_id != actor_id.nil()
class DeactivatedRunner:
    """Placeholder runner that rejects every attribute access.

    ``TorchTrainer`` installs an instance of this class as its
    ``local_worker`` while it has no live worker (before the workers are
    started and after shutdown/reset). Any attribute lookup that is not
    resolved by normal lookup raises ``RuntimeError``.

    Raises:
        RuntimeError: On access to any attribute not defined on the class.
    """

    def __getattr__(self, *args, **kwargs):
        # __getattr__ is only invoked when normal attribute lookup fails,
        # so every runner method (train, validate, state_dict, ...) ends
        # up here and fails with the same explicit message.
        raise RuntimeError(
            "This TorchTrainer is not active (it is likely shutdown already). "
            "Create a new TorchTrainer.")
@@ -1,8 +1,10 @@
import numpy as np
import os
import torch
import torch.nn as nn
import argparse
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
@@ -10,17 +12,20 @@ import torchvision.transforms as transforms
from tqdm import trange
import ray
from ray.util.sgd.torch import (TorchTrainer, TorchTrainable)
from ray.tune import CLIReporter
from ray.util.sgd.torch import TorchTrainer
from ray.util.sgd.torch.resnet import ResNet18
from ray.util.sgd.utils import BATCH_SIZE
def initialization_hook():
print("NCCL DEBUG SET")
# Need this for avoiding a connection restart issue
# Need this for avoiding a connection restart issue on AWS.
os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo"
os.environ["NCCL_LL_THRESHOLD"] = "0"
os.environ["NCCL_DEBUG"] = "INFO"
# set the below if needed
# print("NCCL DEBUG SET")
# os.environ["NCCL_DEBUG"] = "INFO"
def cifar_creator(config):
@@ -55,7 +60,10 @@ def cifar_creator(config):
def optimizer_creator(model, config):
"""Returns optimizer"""
return torch.optim.SGD(model.parameters(), lr=config.get("lr", 0.1))
return torch.optim.SGD(
model.parameters(),
lr=config.get("lr", 0.1),
momentum=config.get("momentum", 0.9))
def scheduler_creator(optimizer, config):
@@ -77,12 +85,11 @@ def train_example(num_workers=1,
initialization_hook=initialization_hook,
num_workers=num_workers,
config={
"lr": 0.01,
"test_mode": test_mode,
BATCH_SIZE: 128,
"lr": 0.1,
"test_mode": test_mode, # user-defined param to subset the data
BATCH_SIZE: 128 * num_workers # this will be split across workers.
},
use_gpu=use_gpu,
backend="nccl" if use_gpu else "gloo",
scheduler_step_freq="epoch",
use_fp16=use_fp16,
use_tqdm=True)
@@ -92,39 +99,65 @@ def train_example(num_workers=1,
info["epoch_idx"] = i
info["num_epochs"] = num_epochs
# Increase `max_retries` to turn on fault tolerance.
stats = trainer1.train(max_retries=1, info=info)
pbar.set_postfix(dict(loss=stats["mean_train_loss"]))
trainer1.train(max_retries=1, info=info)
val_stats = trainer1.validate()
pbar.set_postfix(dict(acc=val_stats["val_accuracy"]))
print(trainer1.validate())
trainer1.shutdown()
print("success!")
def tune_example(num_workers=1, use_gpu=False, test_mode=False):
config = {
"model_creator": ResNet18,
"data_creator": cifar_creator,
"optimizer_creator": optimizer_creator,
"loss_creator": nn.CrossEntropyLoss,
"num_workers": num_workers,
"initialization_hook": initialization_hook,
"use_gpu": use_gpu,
"config": {
"lr": tune.choice([1e-4, 1e-3]),
BATCH_SIZE: 128,
"test_mode": test_mode
def tune_example(num_workers=1, use_gpu=False, use_fp16=False,
test_mode=False):
TorchTrainable = TorchTrainer.as_trainable(
model_creator=ResNet18,
data_creator=cifar_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.CrossEntropyLoss,
scheduler_creator=scheduler_creator,
initialization_hook=initialization_hook,
num_workers=num_workers,
config={
"test_mode": test_mode, # user-defined param to subset the data
BATCH_SIZE: 128 * num_workers,
},
"backend": "nccl" if use_gpu else "gloo"
}
use_gpu=use_gpu,
scheduler_step_freq="epoch",
use_fp16=use_fp16)
pbt_scheduler = PopulationBasedTraining(
time_attr="training_iteration",
metric="val_loss",
mode="min",
perturbation_interval=1,
hyperparam_mutations={
# distribution for resampling
"lr": lambda: np.random.uniform(0.001, 1),
# allow perturbations within this set of categorical values
"momentum": [0.8, 0.9, 0.99],
})
reporter = CLIReporter()
reporter.add_metric_column("val_loss", "loss")
reporter.add_metric_column("val_accuracy", "acc")
analysis = tune.run(
TorchTrainable,
num_samples=2,
config=config,
stop={"training_iteration": 2},
verbose=2)
num_samples=4,
config={
"lr": tune.choice([0.001, 0.01, 0.1]),
"momentum": 0.8
},
stop={"training_iteration": 2 if test_mode else 100},
max_failures=3, # used for fault tolerance
checkpoint_freq=3, # used for fault tolerance
keep_checkpoints_num=1, # used for fault tolerance
verbose=2,
progress_reporter=reporter,
scheduler=pbt_scheduler)
return analysis.get_best_config(metric="mean_accuracy", mode="max")
return analysis.get_best_config(metric="val_loss", mode="min")
if __name__ == "__main__":
+1 -4
View File
@@ -242,16 +242,13 @@ def train_example(num_workers=1, use_gpu=False, test_mode=False):
num_workers=num_workers,
config=config,
use_gpu=use_gpu,
backend="nccl" if use_gpu else "gloo",
use_tqdm=True)
from tabulate import tabulate
pbar = trange(5, unit="epoch")
for itr in pbar:
stats = trainer.train(info=dict(epoch_idx=itr, num_epochs=5))
pbar.set_postfix(
dict(loss_g=stats["mean_loss_g"], loss_d=stats["mean_loss_d"]))
pbar.set_postfix(dict(loss_g=stats["loss_g"], loss_d=stats["loss_d"]))
formatted = tabulate([stats], headers="keys")
if itr > 0: # Get the last line of the stats.
formatted = formatted.split("\n")[-1]
@@ -14,7 +14,8 @@ import torch.nn as nn
import ray
from ray import tune
from ray.util.sgd.torch.torch_trainer import TorchTrainable
from ray.util.sgd.torch import TorchTrainer
from ray.util.sgd.utils import BATCH_SIZE
class LinearDataset(torch.utils.data.Dataset):
@@ -48,30 +49,29 @@ def data_creator(config):
val_dataset = LinearDataset(2, 5, size=400)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=config["batch_size"],
batch_size=config[BATCH_SIZE],
)
validation_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=config["batch_size"])
batch_size=config[BATCH_SIZE])
return train_loader, validation_loader
def tune_example(num_workers=1, use_gpu=False):
config = {
"model_creator": model_creator,
"data_creator": data_creator,
"optimizer_creator": optimizer_creator,
"loss_creator": nn.MSELoss,
"num_workers": num_workers,
"use_gpu": use_gpu,
"config": {"batch_size": 512 // num_workers},
"backend": "gloo"
}
TorchTrainable = TorchTrainer.as_trainable(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.MSELoss,
num_workers=num_workers,
use_gpu=use_gpu,
config={BATCH_SIZE: 128}
)
analysis = tune.run(
TorchTrainable,
num_samples=12,
config=config,
num_samples=3,
config={"lr": tune.grid_search([1e-4, 1e-3])},
stop={"training_iteration": 2},
verbose=1)
+17 -12
View File
@@ -2,6 +2,7 @@ import collections
from filelock import FileLock
import logging
import inspect
import io
import itertools
import os
import tempfile
@@ -225,22 +226,14 @@ class TorchRunner:
self.training_operator._set_timers(self.timers)
def _get_model_state_dicts(self):
# This is so that we create a duplicate of weights into CPU rather than
# move the model weights entirely out of the GPU, so that we can
# resume training while saving intermediate checkpoints.
cpu_state_dicts = []
for model in self.models:
state_dict = model.state_dict()
cpu_state_dicts += [{k: v.cpu() for k, v in state_dict.items()}]
return cpu_state_dicts
return [model.state_dict() for model in self.models]
def _set_model_state_dicts(self, models_state_dicts):
for model, state_dict in zip(self.models, models_state_dicts):
model.load_state_dict(state_dict)
def get_state(self):
def state_dict(self):
"""Returns the state of the runner."""
state = {
"epoch": self.epochs,
"operator": self.training_operator.state_dict(),
@@ -258,9 +251,8 @@ class TorchRunner:
state.update({"amp": amp.state_dict()})
return state
def set_state(self, state):
def load_state_dict(self, state):
"""Sets the state of the model."""
# TODO: restore timer stats
self._set_model_state_dicts(state["models"])
for optimizer, state_dict in zip(self.optimizers, state["optimizers"]):
optimizer.load_state_dict(state_dict)
@@ -274,6 +266,19 @@ class TorchRunner:
self.epochs = state["epoch"]
self.training_operator.load_state_dict(state_dict)
def state_stream(self):
    """Returns the runner state serialized as a bytes object.

    The dict produced by ``self.state_dict()`` is written with
    ``torch.save`` into an in-memory buffer, and the buffer contents
    are returned.

    Returns:
        bytes: The serialized state dict.
    """
    state_dict = self.state_dict()
    _buffer = io.BytesIO()
    torch.save(state_dict, _buffer)
    return _buffer.getvalue()
def load_state_stream(self, byte_obj):
    """Loads the training state dict from a bytes object.

    Args:
        byte_obj (bytes): Serialized state, as produced by
            ``state_stream``.

    Returns:
        Whatever ``self.load_state_dict`` returns for the decoded state.
    """
    # NOTE: torch.load unpickles its input, so byte_obj must come from a
    # trusted source.
    _buffer = io.BytesIO(byte_obj)
    state_dict = torch.load(_buffer)
    return self.load_state_dict(state_dict)
def apply(self, fn):
return fn()
+165 -75
View File
@@ -1,6 +1,6 @@
import numpy as np
import os
import logging
import os
import numbers
import tempfile
import time
@@ -10,9 +10,10 @@ import torch.distributed as dist
import ray
from ray.exceptions import RayActorError
from ray.tune import Trainable
from ray.tune.trial import Resources
from ray.tune.resources import Resources
from ray.tune.utils.util import merge_dicts
from ray.util.sgd.torch.distributed_torch_runner import (
DistributedTorchRunner, LocalDistributedRunner)
DistributedTorchRunner, LocalDistributedRunner, DeactivatedRunner)
from ray.util.sgd.utils import check_for_failure, NUM_SAMPLES, BATCH_SIZE
from ray.util.sgd.torch.torch_runner import TorchRunner
from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP
@@ -130,6 +131,11 @@ class TorchTrainer:
"""
# TODO: Implement autoscaling. If num_workers=-1, the trainer will use as
# many resources as available. Upon each train call, TorchTrainer will
# query the Ray global state for total available resources and resize
# its remote workers to consume all available resources.
def __init__(
self,
*,
@@ -218,6 +224,9 @@ class TorchTrainer:
self._num_failures = 0
self._last_resize = float("-inf")
self.local_worker = DeactivatedRunner()
self.remote_workers = []
_validate_scheduler_step_freq(scheduler_step_freq)
self.scheduler_step_freq = scheduler_step_freq
@@ -250,9 +259,6 @@ class TorchTrainer:
if batch_size_per_worker:
worker_config[BATCH_SIZE] = batch_size_per_worker
self.local_worker = None
self.remote_workers = []
if num_workers == 1:
# Start local worker
self.local_worker = TorchRunner(
@@ -319,8 +325,7 @@ class TorchTrainer:
num_steps=None,
profile=False,
reduce_results=True,
max_retries=0,
checkpoint="auto",
max_retries=3,
info=None):
"""Runs a training epoch.
@@ -339,14 +344,12 @@ class TorchTrainer:
all workers into one dict. If a metric is a non-numerical
value (or nested dictionaries), one value will be randomly
selected among the workers. If False, returns a list of dicts.
max_retries (int): Must be non-negative. If set to N, will
kill all current workers, query the Ray global state for
total available resources, and re-launch up to the
available resources. Behavior is not well-defined
in case of shared cluster usage.
checkpoint (str): Path to checkpoint to restore from if retrying.
If max_retries is set and ``checkpoint == "auto"``,
TorchTrainer will save a checkpoint before starting to train.
max_retries (int): Must be non-negative. If set to N, TorchTrainer
will detect and recover from training failure. The recovery
process will kill all current workers, query the Ray
global state for total available resources, and re-launch up to
the available resources. Behavior is not well-defined
in case of shared cluster usage. Defaults to 3.
info (dict): Optional dictionary passed to the training
operator for ``train_epoch`` and ``train_batch``.
@@ -358,18 +361,9 @@ class TorchTrainer:
length will be equal to ``num_workers``.
"""
assert max_retries >= 0, "`max_retries` must be non-negative."
if max_retries:
if checkpoint == "auto":
logger.debug("Retrying detected. Automatically checkpointing.")
checkpoint = self.save(
os.path.join(self.temp_dir, "tmp_checkpoint"))
elif not checkpoint:
raise ValueError("Cannot retry from empty checkpoint.")
if checkpoint and self._should_resize():
if self._should_resize():
logger.info("Resize opportunity detected. Attempting to scale up.")
self._resize_workers(checkpoint=checkpoint)
self._resize_workers()
success, worker_stats = self._train_epoch(
num_steps=num_steps, profile=profile, info=info)
# Fault handling
@@ -378,7 +372,7 @@ class TorchTrainer:
break
else:
self._num_failures += 1
self._resize_workers(checkpoint=checkpoint)
self._resize_workers()
logger.info("Retrying training step with %d workers." %
(len(self.remote_workers) + 1))
success, worker_stats = self._train_epoch(
@@ -483,7 +477,6 @@ class TorchTrainer:
w.validate.remote(**params) for w in self.remote_workers
]
local_worker_stats = self.local_worker.validate(**params)
return self._process_stats([local_worker_stats] +
ray.get(remote_worker_stats))
@@ -497,47 +490,49 @@ class TorchTrainer:
def get_model(self):
"""Returns the learned model(s)."""
models = self.model_creator(self.config)
state = self.local_worker.get_state()
if len(state["models"]) == 1:
models.load_state_dict(state["models"][0])
else:
for model, state_dict in zip(models, state["models"]):
model.load_state_dict(state_dict)
return models
unwrapped = []
for model in self.local_worker.models:
unwrapped += [model.module if hasattr(model, "module") else model]
if len(unwrapped) == 1:
return unwrapped[0]
return unwrapped
def state_dict(self):
return self.local_worker.get_state()
return self.local_worker.state_dict()
def load_state_dict(self, state):
state_id = ray.put(state)
def load_state_dict(self, state_dict, blocking=False):
# This is not the most efficient because you have to wait for
# the local worker to save then dump to buffer.
self.local_worker.load_state_dict(state_dict)
state_id = ray.put(self.local_worker.state_stream())
remote_calls = [
worker.set_state.remote(state_id) for worker in self.remote_workers
worker.load_state_stream.remote(state_id)
for worker in self.remote_workers
]
self.local_worker.set_state(state)
ray.get(remote_calls)
if blocking:
ray.get(remote_calls)
def save(self, checkpoint):
"""Saves the model(s) to the provided checkpoint.
"""Saves the Trainer state to the provided checkpoint path.
Args:
checkpoint (str): Path to target checkpoint file.
Returns:
checkpoint (str): Path to target checkpoint file.
"""
torch.save(self.state_dict(), checkpoint)
return checkpoint
def restore(self, checkpoint):
"""Restores the Trainer and all workers from the provided checkpoint.
def load(self, checkpoint):
"""Loads the Trainer and all workers from the provided checkpoint.
Args:
checkpoint (str): Path to target checkpoint file.
"""
state = torch.load(checkpoint)
self.load_state_dict(state)
state_dict = torch.load(checkpoint)
self.load_state_dict(state_dict)
def restore(self, *args):
raise DeprecationWarning("Use `TorchTrainer.load()` instead.")
def shutdown(self, force=False):
"""Shuts down workers and releases resources."""
@@ -562,19 +557,19 @@ class TorchTrainer:
else:
self.local_worker.shutdown()
for worker in self.remote_workers:
logger.warning("Killing worker {}.".format(worker))
logger.debug("Killing worker {}.".format(worker))
ray.kill(worker)
self.local_worker = None
self.local_worker = DeactivatedRunner()
self.remote_workers = []
def _reset(self):
"""Terminates models without giving up local resource reservation."""
self.local_worker.shutdown(cleanup=False)
for worker in self.remote_workers:
logger.warning("Killing worker {}.".format(worker))
logger.debug("Killing worker {}.".format(worker))
ray.kill(worker)
self.local_worker = None
self.local_worker = DeactivatedRunner()
self.remote_workers = []
def _check_potential_remote_workers_size(self):
@@ -588,9 +583,8 @@ class TorchTrainer:
remote_resources.get("GPU", 0), new_remote_workers)
return new_remote_workers
def _resize_workers(self, checkpoint, max_retries=10):
def _resize_workers(self, max_retries=10):
self._reset()
assert checkpoint, "Cannot restore without checkpoint."
time.sleep(1)
for i in range(max_retries):
@@ -598,7 +592,7 @@ class TorchTrainer:
if new_remote_workers:
self._last_resize = time.time()
self._start_workers(int(new_remote_workers) + 1)
self.restore(checkpoint)
self.load_state_dict(self.state_dict())
return
else:
delay = 2**i
@@ -617,32 +611,128 @@ class TorchTrainer:
return potential_remote_size > 0
return False
class TorchTrainable(Trainable):
@classmethod
def default_resource_request(cls, config):
remote_worker_count = config["num_workers"] - 1
return Resources(
cpu=1,
gpu=int(config["use_gpu"]),
extra_cpu=int(remote_worker_count),
extra_gpu=int(int(config["use_gpu"]) * remote_worker_count))
def as_trainable(cls, *args, **kwargs):
    """Creates a BaseTorchTrainable class compatible with Tune.

    Any configuration parameters will be overridden by the Tune
    Trial configuration. You can also subclass the provided Trainable
    to implement your own iterative optimization routine.

    .. code-block:: python

        TorchTrainable = TorchTrainer.as_trainable(
            model_creator=ResNet18,
            data_creator=cifar_creator,
            optimizer_creator=optimizer_creator,
            loss_creator=nn.CrossEntropyLoss,
            num_gpus=2
        )
        analysis = tune.run(
            TorchTrainable,
            config={"lr": tune.grid_search([0.01, 0.1])}
        )

    Args:
        *args: Positional arguments forwarded to ``TorchTrainer``.
        **kwargs: Keyword arguments forwarded to ``TorchTrainer``. The
            ``config`` entry, if given, is used as the base config that
            each trial's Tune config is merged on top of.

    Returns:
        A ``BaseTorchTrainable`` subclass bound to the given arguments.
    """

    class TorchTrainable(BaseTorchTrainable):
        @classmethod
        def default_resource_request(cls, config):
            num_workers = config.get("num_workers",
                                     kwargs.get("num_workers", 1))
            # Fall back to False when ``use_gpu`` is given neither in the
            # Tune config nor in the trainer kwargs; the previous fallback
            # of None made ``int(use_gpu)`` raise TypeError.
            use_gpu = config.get("use_gpu", kwargs.get("use_gpu", False))

            remote_worker_count = num_workers - 1

            return Resources(
                cpu=1,
                gpu=int(use_gpu),
                extra_cpu=int(remote_worker_count),
                extra_gpu=int(int(use_gpu) * remote_worker_count))

        def _create_trainer(self, tune_config):
            """Overrides the provided config with Tune config."""
            provided_config = dict(kwargs.get("config") or {})
            provided_config.update(tune_config)
            # Build a fresh kwargs dict rather than writing back into the
            # shared closure, so one trial's Tune config cannot leak into
            # trainers created later from the same Trainable class.
            trainer_kwargs = dict(kwargs, config=provided_config)
            return TorchTrainer(*args, **trainer_kwargs)

    return TorchTrainable
class BaseTorchTrainable(Trainable):
"""Base class for converting TorchTrainer to a Trainable class.
This class is produced when you call ``TorchTrainer.as_trainable(...)``.
You can override the produced Trainable to implement custom iterative
training procedures:
.. code-block:: python
TorchTrainable = TorchTrainer.as_trainable(
model_creator=ResNet18,
data_creator=cifar_creator,
optimizer_creator=optimizer_creator,
loss_creator=nn.CrossEntropyLoss,
num_gpus=2
)
# TorchTrainable is subclass of BaseTorchTrainable.
class CustomTrainable(TorchTrainable):
def _train(self):
for i in range(5):
train_stats = self.trainer.train()
validation_stats = self.trainer.validate()
train_stats.update(validation_stats)
return train_stats
analysis = tune.run(
CustomTrainable,
config={"lr": tune.grid_search([0.01, 0.1])}
)
"""
def _setup(self, config):
self._trainer = TorchTrainer(**config)
"""Constructs a TorchTrainer object as `self.trainer`."""
self._trainer = self._create_trainer(config)
def _train(self):
train_stats = self._trainer.train()
validation_stats = self._trainer.validate()
"""Calls `self.trainer.train()` and `self.trainer.validate()` once.
train_stats.update(validation_stats)
return train_stats
You may want to override this if using a custom LR scheduler.
"""
train_stats = self.trainer.train(max_retries=10, profile=True)
validation_stats = self.trainer.validate(profile=True)
stats = merge_dicts(train_stats, validation_stats)
return stats
def _save(self, checkpoint_dir):
return self._trainer.save(os.path.join(checkpoint_dir, "model.pth"))
"""Returns a path containing the trainer state."""
checkpoint_path = os.path.join(checkpoint_dir, "trainer.checkpoint")
self.trainer.save(checkpoint_path)
return checkpoint_path
def _restore(self, checkpoint_path):
return self._trainer.restore(checkpoint_path)
"""Restores the trainer state.
Override this if you have state external to the Trainer object.
"""
return self.trainer.load(checkpoint_path)
def _stop(self):
self._trainer.shutdown()
"""Shuts down the trainer."""
self.trainer.shutdown()
def _create_trainer(self, config):
raise NotImplementedError
@property
def trainer(self):
"""An instantiated TorchTrainer object.
Use this when specifying custom training procedures for Tune.
"""
return self._trainer
@@ -168,8 +168,10 @@ class TrainingOperator:
if self.use_tqdm and self.world_rank == 0:
_progress_bar.n = batch_idx + 1
postfix = {}
if "train_loss" in metrics:
_progress_bar.set_postfix({"loss": metrics["train_loss"]})
postfix.update(loss=metrics["train_loss"])
_progress_bar.set_postfix(postfix)
if self.scheduler and batch_info.get(
SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
@@ -259,7 +261,7 @@ class TrainingOperator:
Returns:
A dict of metrics from the evaluation.
By default, returns "mean_accuracy" and "mean_val_loss"
By default, returns "val_accuracy" and "val_loss"
which is computed by aggregating "loss" and "correct" values
from ``validate_batch`` and dividing it by the sum of
``num_samples`` from all calls to ``self.validate_batch``.
+1 -1
View File
@@ -192,7 +192,7 @@ class AverageMeterCollection:
"""Returns a dict of average and most recent values for each metric."""
stats = {BATCH_COUNT: self._batch_count, NUM_SAMPLES: self.n}
for metric, meter in self._meters.items():
stats["mean_" + str(metric)] = meter.avg
stats[str(metric)] = meter.avg
stats["last_" + str(metric)] = meter.val
return stats