mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 02:01:24 +08:00
[raysgd] Cleanup User API (#7384)
* Init fp16 * fp16 and schedulers * scheduler linking and fp16 * to fp16 * loss scaling and documentation * more documentation * add tests, refactor config * moredocs * more docs * fix logo, add test mode, add fp16 flag * fix tests * fix scheduler * fix apex * improve safety * fix tests * fix tests * remove pin memory default * rm * fix * Update doc/examples/doc_code/raysgd_torch_signatures.py * fix * migrate changes from other PR * ok thanks * pass * signatures * lint' * Update python/ray/experimental/sgd/pytorch/utils.py * Apply suggestions from code review Co-Authored-By: Edward Oakes <ed.nmi.oakes@gmail.com> * should address most comments * comments * fix this ci * first_pass * add overrides * override * fixing up operators * format * sgd * constants * rm * revert * save * failures * fixes * trainer * run test * operator * code * op * ok done * operator * sgd test fixes * ok * trainer * format * Apply suggestions from code review Co-Authored-By: Edward Oakes <ed.nmi.oakes@gmail.com> * Update doc/source/raysgd/raysgd_pytorch.rst * docstring * dcgan * doc * commits * nit * testing * revert * Start renaming pytorch to torch * Rename PyTorchTrainer to TorchTrainer * Rename PyTorch runners to Torch runners * Finish renaming API * Rename to torch in tests * Finish renaming docs + tests * Run format + fix DeprecationWarning * fix * move tests up * benchmarks * rename * remove some args * better metrics output * fix up the benchmark * benchmark-yaml * horovod-benchmark * benchmarks * Remove benchmark code for cleanups * makedatacreator * relax * metrics * autosetsampler * profile * movements * OK * smoothen * fix * nitdocs * loss * comments * fix * fix * runner_tests * codes * example * fix_test * fix * tests Co-authored-by: Edward Oakes <ed.nmi.oakes@gmail.com> Co-authored-by: Maksim Smolin <maximsmol@gmail.com>
This commit is contained in:
@@ -2,6 +2,7 @@ try: # py3
|
||||
from shlex import quote
|
||||
except ImportError: # py2
|
||||
from pipes import quote
|
||||
import click
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
@@ -274,7 +275,7 @@ class SSHCommandRunner:
|
||||
except subprocess.CalledProcessError:
|
||||
if exit_on_fail:
|
||||
quoted_cmd = " ".join(final_cmd[:-1] + [quote(final_cmd[-1])])
|
||||
raise Exception(
|
||||
raise click.ClickException(
|
||||
"Command failed: \n\n {}\n".format(quoted_cmd))
|
||||
else:
|
||||
raise
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import time
|
||||
import torch
|
||||
@@ -11,9 +11,11 @@ import torch.distributed as dist
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.util.sgd.torch import TorchTrainer, TorchTrainable
|
||||
from ray.util.sgd.torch.training_operator import _TestingOperator
|
||||
from ray.util.sgd.torch.constants import BATCH_COUNT, SCHEDULER_STEP
|
||||
from ray.util.sgd.utils import check_for_failure
|
||||
from ray.util.sgd.torch.training_operator import (_TestingOperator,
|
||||
_TestMetricsOperator)
|
||||
from ray.util.sgd.torch.constants import SCHEDULER_STEP
|
||||
from ray.util.sgd.utils import (check_for_failure, NUM_SAMPLES, BATCH_COUNT,
|
||||
BATCH_SIZE)
|
||||
|
||||
from ray.util.sgd.torch.examples.train_example import (
|
||||
model_creator, optimizer_creator, data_creator, LinearDataset)
|
||||
@@ -29,11 +31,11 @@ def ray_start_2_cpus():
|
||||
|
||||
def test_single_step(ray_start_2_cpus): # noqa: F811
|
||||
trainer = TorchTrainer(
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_replicas=1)
|
||||
num_workers=1)
|
||||
metrics = trainer.train(num_steps=1)
|
||||
assert metrics[BATCH_COUNT] == 1
|
||||
|
||||
@@ -41,31 +43,29 @@ def test_single_step(ray_start_2_cpus): # noqa: F811
|
||||
assert val_metrics[BATCH_COUNT] == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_replicas", [1, 2]
|
||||
if dist.is_available() else [1])
|
||||
def test_train(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
def test_train(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
trainer = TorchTrainer(
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_replicas=num_replicas)
|
||||
num_workers=num_workers)
|
||||
for i in range(3):
|
||||
train_loss1 = trainer.train()["mean_train_loss"]
|
||||
validation_loss1 = trainer.validate()["mean_validation_loss"]
|
||||
validation_loss1 = trainer.validate()["mean_val_loss"]
|
||||
|
||||
for i in range(3):
|
||||
train_loss2 = trainer.train()["mean_train_loss"]
|
||||
validation_loss2 = trainer.validate()["mean_validation_loss"]
|
||||
validation_loss2 = trainer.validate()["mean_val_loss"]
|
||||
|
||||
assert train_loss2 <= train_loss1, (train_loss2, train_loss1)
|
||||
assert validation_loss2 <= validation_loss1, (validation_loss2,
|
||||
validation_loss1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_replicas", [1, 2]
|
||||
if dist.is_available() else [1])
|
||||
def test_multi_model(ray_start_2_cpus, num_replicas):
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
def test_multi_model(ray_start_2_cpus, num_workers):
|
||||
def train(*, model=None, criterion=None, optimizer=None, dataloader=None):
|
||||
model.train()
|
||||
train_loss = 0
|
||||
@@ -108,13 +108,13 @@ def test_multi_model(ray_start_2_cpus, num_replicas):
|
||||
return opts[0], opts[1]
|
||||
|
||||
trainer1 = TorchTrainer(
|
||||
multi_model_creator,
|
||||
data_creator,
|
||||
multi_optimizer_creator,
|
||||
model_creator=multi_model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=multi_optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
config={"custom_func": train_epoch},
|
||||
training_operator_cls=_TestingOperator,
|
||||
num_replicas=num_replicas)
|
||||
num_workers=num_workers)
|
||||
trainer1.train()
|
||||
|
||||
filename = os.path.join(tempfile.mkdtemp(), "checkpoint")
|
||||
@@ -125,13 +125,13 @@ def test_multi_model(ray_start_2_cpus, num_replicas):
|
||||
trainer1.shutdown()
|
||||
|
||||
trainer2 = TorchTrainer(
|
||||
multi_model_creator,
|
||||
data_creator,
|
||||
multi_optimizer_creator,
|
||||
model_creator=multi_model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=multi_optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
config={"custom_func": train_epoch},
|
||||
training_operator_cls=_TestingOperator,
|
||||
num_replicas=num_replicas)
|
||||
num_workers=num_workers)
|
||||
trainer2.restore(filename)
|
||||
|
||||
os.remove(filename)
|
||||
@@ -151,9 +151,8 @@ def test_multi_model(ray_start_2_cpus, num_replicas):
|
||||
trainer2.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_replicas", [1, 2]
|
||||
if dist.is_available() else [1])
|
||||
def test_multi_model_matrix(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
def test_multi_model_matrix(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
def train_epoch(self, iterator, info):
|
||||
if self.config.get("models", 1) > 1:
|
||||
assert len(self.models) == self.config["models"], self.config
|
||||
@@ -194,13 +193,13 @@ def test_multi_model_matrix(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
for optimizer_count in range(1, 3):
|
||||
for scheduler_count in range(1, 3):
|
||||
trainer = TorchTrainer(
|
||||
multi_model_creator,
|
||||
data_creator,
|
||||
multi_optimizer_creator,
|
||||
model_creator=multi_model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=multi_optimizer_creator,
|
||||
loss_creator=nn.MSELoss,
|
||||
scheduler_creator=multi_scheduler_creator,
|
||||
training_operator_cls=_TestingOperator,
|
||||
num_replicas=num_replicas,
|
||||
num_workers=num_workers,
|
||||
config={
|
||||
"models": model_count,
|
||||
"optimizers": optimizer_count,
|
||||
@@ -222,9 +221,9 @@ def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811
|
||||
optimizer, step_size=30, gamma=0.1)
|
||||
|
||||
trainer = TorchTrainer(
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
config={"custom_func": train_epoch},
|
||||
training_operator_cls=_TestingOperator,
|
||||
@@ -236,13 +235,168 @@ def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
def test_profiling(ray_start_2_cpus): # noqa: F811
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss())
|
||||
|
||||
stats = trainer.train(profile=True)
|
||||
assert "profile" in stats
|
||||
stats = trainer.validate(profile=True)
|
||||
assert "profile" in stats
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
def test_split_batch(ray_start_2_cpus):
|
||||
if not dist.is_available():
|
||||
return
|
||||
|
||||
def data_creator(config):
|
||||
"""Returns training dataloader, validation dataloader."""
|
||||
train_dataset = LinearDataset(2, 5, size=config["data_size"])
|
||||
return torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config[BATCH_SIZE],
|
||||
)
|
||||
|
||||
data_size = 600
|
||||
batch_size = 21
|
||||
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_workers=2,
|
||||
config={
|
||||
BATCH_SIZE: batch_size,
|
||||
"data_size": data_size,
|
||||
})
|
||||
stats = trainer.train()
|
||||
assert trainer.config[BATCH_SIZE] == (batch_size - 1)
|
||||
assert stats[NUM_SAMPLES] == 600
|
||||
assert stats[BATCH_COUNT] == (data_size // 20)
|
||||
|
||||
|
||||
def test_reduce_result(ray_start_2_cpus):
|
||||
if not dist.is_available():
|
||||
return
|
||||
|
||||
def data_creator(config):
|
||||
"""Returns training dataloader, validation dataloader."""
|
||||
train_dataset = LinearDataset(2, 5, size=config["data_size"])
|
||||
return torch.utils.data.DataLoader(train_dataset, batch_size=1)
|
||||
|
||||
data_size = 600
|
||||
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_workers=2,
|
||||
config={"data_size": data_size})
|
||||
list_stats = trainer.train(reduce_results=False, profile=True)
|
||||
assert len(list_stats) == 2
|
||||
assert [stats[NUM_SAMPLES] == data_size for stats in list_stats]
|
||||
assert [stats[BATCH_COUNT] == (data_size // 2) for stats in list_stats]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
def test_metrics(ray_start_2_cpus, num_workers):
|
||||
data_size, val_size = 600, 500
|
||||
batch_size = 4
|
||||
|
||||
num_train_steps = int(data_size / batch_size)
|
||||
num_val_steps = int(val_size / batch_size)
|
||||
|
||||
train_scores = [1] + ([0] * num_train_steps)
|
||||
val_scores = [1] + ([0] * num_val_steps)
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_workers=num_workers,
|
||||
config={
|
||||
"scores": train_scores,
|
||||
"val_scores": val_scores,
|
||||
"key": "score",
|
||||
"batch_size": batch_size,
|
||||
"data_size": data_size,
|
||||
"val_size": val_size
|
||||
},
|
||||
training_operator_cls=_TestMetricsOperator)
|
||||
|
||||
stats = trainer.train(num_steps=num_train_steps)
|
||||
# Test that we output mean and last of custom metrics in an epoch
|
||||
assert "mean_score" in stats
|
||||
assert stats["last_score"] == 0
|
||||
|
||||
assert stats[NUM_SAMPLES] == num_train_steps * batch_size
|
||||
expected_score = num_workers * (sum(train_scores) /
|
||||
(num_train_steps * batch_size))
|
||||
assert np.allclose(stats["mean_score"], expected_score)
|
||||
|
||||
val_stats = trainer.validate()
|
||||
# Test that we output mean and last of custom metrics in validation
|
||||
assert val_stats["last_score"] == 0
|
||||
expected_score = (sum(val_scores) /
|
||||
(num_val_steps * batch_size)) * num_workers
|
||||
assert np.allclose(val_stats["mean_score"], expected_score)
|
||||
assert val_stats[BATCH_COUNT] == np.ceil(num_val_steps / num_workers)
|
||||
assert val_stats[NUM_SAMPLES] == num_val_steps * batch_size
|
||||
assert val_stats[NUM_SAMPLES] == val_size
|
||||
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
def test_metrics_nan(ray_start_2_cpus, num_workers):
|
||||
data_size, val_size = 100, 100
|
||||
batch_size = 10
|
||||
|
||||
num_train_steps = int(data_size / batch_size)
|
||||
num_val_steps = int(val_size / batch_size)
|
||||
|
||||
train_scores = [np.nan] + ([0] * num_train_steps)
|
||||
val_scores = [np.nan] + ([0] * num_val_steps)
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_workers=num_workers,
|
||||
config={
|
||||
"scores": train_scores,
|
||||
"val_scores": val_scores,
|
||||
"key": "score",
|
||||
"batch_size": batch_size,
|
||||
"data_size": data_size,
|
||||
"val_size": val_size
|
||||
},
|
||||
training_operator_cls=_TestMetricsOperator)
|
||||
|
||||
stats = trainer.train(num_steps=num_train_steps)
|
||||
assert "mean_score" in stats
|
||||
assert stats["last_score"] == 0
|
||||
assert np.isnan(stats["mean_score"])
|
||||
|
||||
stats = trainer.validate()
|
||||
assert "mean_score" in stats
|
||||
assert stats["last_score"] == 0
|
||||
assert np.isnan(stats["mean_score"])
|
||||
|
||||
|
||||
def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
|
||||
trainer = TorchTrainer(
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
scheduler_creator=lambda optimizer, cfg: ReduceLROnPlateau(optimizer),
|
||||
training_operator_cls=_TestingOperator)
|
||||
@@ -254,20 +408,19 @@ def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_replicas", [1, 2]
|
||||
if dist.is_available() else [1])
|
||||
def test_tune_train(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
def test_tune_train(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
|
||||
config = {
|
||||
"model_creator": model_creator,
|
||||
"data_creator": data_creator,
|
||||
"optimizer_creator": optimizer_creator,
|
||||
"loss_creator": lambda config: nn.MSELoss(),
|
||||
"num_replicas": num_replicas,
|
||||
"num_workers": num_workers,
|
||||
"use_gpu": False,
|
||||
"batch_size": 512,
|
||||
"backend": "gloo",
|
||||
"config": {
|
||||
"batch_size": 512,
|
||||
"lr": 0.001
|
||||
}
|
||||
}
|
||||
@@ -283,22 +436,21 @@ def test_tune_train(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
for path, df in analysis.trial_dataframes.items():
|
||||
mean_train_loss1 = df.loc[0, "mean_train_loss"]
|
||||
mean_train_loss2 = df.loc[1, "mean_train_loss"]
|
||||
mean_validation_loss1 = df.loc[0, "mean_validation_loss"]
|
||||
mean_validation_loss2 = df.loc[1, "mean_validation_loss"]
|
||||
mean_val_loss1 = df.loc[0, "mean_val_loss"]
|
||||
mean_val_loss2 = df.loc[1, "mean_val_loss"]
|
||||
|
||||
assert mean_train_loss2 <= mean_train_loss1
|
||||
assert mean_validation_loss2 <= mean_validation_loss1
|
||||
assert mean_val_loss2 <= mean_val_loss1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_replicas", [1, 2]
|
||||
if dist.is_available() else [1])
|
||||
def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
def test_save_and_restore(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
trainer1 = TorchTrainer(
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_replicas=num_replicas)
|
||||
num_workers=num_workers)
|
||||
trainer1.train()
|
||||
|
||||
filename = os.path.join(tempfile.mkdtemp(), "checkpoint")
|
||||
@@ -309,11 +461,11 @@ def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
trainer1.shutdown()
|
||||
|
||||
trainer2 = TorchTrainer(
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_replicas=num_replicas)
|
||||
num_workers=num_workers)
|
||||
trainer2.restore(filename)
|
||||
|
||||
os.remove(filename)
|
||||
@@ -334,7 +486,9 @@ def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
|
||||
return
|
||||
|
||||
def single_loader(config):
|
||||
return LinearDataset(2, 5, size=1000000)
|
||||
dataset = LinearDataset(2, 5, size=1000000)
|
||||
return torch.utils.data.DataLoader(
|
||||
dataset, batch_size=config.get("batch_size", 32))
|
||||
|
||||
def step_with_fail(self, *args, **kwargs):
|
||||
worker_stats = [
|
||||
@@ -348,12 +502,12 @@ def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
|
||||
|
||||
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
|
||||
trainer1 = TorchTrainer(
|
||||
model_creator,
|
||||
single_loader,
|
||||
optimizer_creator,
|
||||
batch_size=100000,
|
||||
model_creator=model_creator,
|
||||
data_creator=single_loader,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_replicas=2)
|
||||
config={"batch_size": 100000},
|
||||
num_workers=2)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
trainer1.train(max_retries=1)
|
||||
@@ -364,7 +518,9 @@ def test_resize(ray_start_2_cpus): # noqa: F811
|
||||
return
|
||||
|
||||
def single_loader(config):
|
||||
return LinearDataset(2, 5, size=1000000)
|
||||
dataset = LinearDataset(2, 5, size=1000000)
|
||||
return torch.utils.data.DataLoader(
|
||||
dataset, batch_size=config.get("batch_size", 32))
|
||||
|
||||
def step_with_fail(self, *args, **kwargs):
|
||||
worker_stats = [
|
||||
@@ -378,12 +534,12 @@ def test_resize(ray_start_2_cpus): # noqa: F811
|
||||
|
||||
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
|
||||
trainer1 = TorchTrainer(
|
||||
model_creator,
|
||||
single_loader,
|
||||
optimizer_creator,
|
||||
batch_size=100000,
|
||||
model_creator=model_creator,
|
||||
data_creator=single_loader,
|
||||
optimizer_creator=optimizer_creator,
|
||||
config={"batch_size": 100000},
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_replicas=2)
|
||||
num_workers=2)
|
||||
|
||||
@ray.remote
|
||||
def try_test():
|
||||
@@ -400,7 +556,9 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811
|
||||
return
|
||||
|
||||
def single_loader(config):
|
||||
return LinearDataset(2, 5, size=1000000)
|
||||
dataset = LinearDataset(2, 5, size=1000000)
|
||||
return torch.utils.data.DataLoader(
|
||||
dataset, batch_size=config.get("batch_size", 32))
|
||||
|
||||
def step_with_fail(self, *args, **kwargs):
|
||||
worker_stats = [
|
||||
@@ -414,12 +572,12 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811
|
||||
|
||||
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
|
||||
trainer1 = TorchTrainer(
|
||||
model_creator,
|
||||
single_loader,
|
||||
optimizer_creator,
|
||||
batch_size=100000,
|
||||
model_creator=model_creator,
|
||||
data_creator=single_loader,
|
||||
optimizer_creator=optimizer_creator,
|
||||
config={"batch_size": 100000},
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_replicas=2)
|
||||
num_workers=2)
|
||||
|
||||
trainer1.train(max_retries=2)
|
||||
|
||||
|
||||
@@ -61,11 +61,11 @@ class TestTorchRunner(unittest.TestCase):
|
||||
runner.setup()
|
||||
runner.train_epoch()
|
||||
runner.train_epoch()
|
||||
runner.train_epoch()
|
||||
result = runner.train_epoch()
|
||||
self.assertEqual(runner.training_operator.validate.call_count, 0)
|
||||
runner.validate()
|
||||
self.assertTrue(runner.training_operator.validate.called)
|
||||
self.assertEqual(runner.stats()["epoch"], 3)
|
||||
self.assertEqual(result["epoch"], 3)
|
||||
|
||||
def testtrain_epoch(self):
|
||||
class MockOperator(TrainingOperator):
|
||||
@@ -88,7 +88,7 @@ class TestTorchRunner(unittest.TestCase):
|
||||
result = runner.train_epoch()
|
||||
self.assertEqual(runner.training_operator.count, 3)
|
||||
self.assertEqual(result["count"], 3)
|
||||
self.assertEqual(runner.stats()["epoch"], 3)
|
||||
self.assertEqual(result["epoch"], 3)
|
||||
|
||||
def testGivens(self):
|
||||
class MockOperator(TrainingOperator):
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
USE_FP16 = "__use_fp16__"
|
||||
BATCH_COUNT = "batch_count"
|
||||
SCHEDULER_STEP = "scheduler_step"
|
||||
SCHEDULER_STEP_BATCH = "batch"
|
||||
SCHEDULER_STEP_EPOCH = "epoch"
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import collections
|
||||
from filelock import FileLock
|
||||
import logging
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
import torch.utils.data
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from ray.util.sgd.torch.torch_runner import TorchRunner
|
||||
|
||||
@@ -39,17 +39,15 @@ class DistributedTorchRunner(TorchRunner):
|
||||
self._setup_training()
|
||||
|
||||
def _setup_distributed_pytorch(self, url, world_rank, world_size):
|
||||
with self._timers["setup_proc"]:
|
||||
self.world_rank = world_rank
|
||||
logger.debug(
|
||||
"Connecting to {} world_rank: {} world_size: {}".format(
|
||||
url, world_rank, world_size))
|
||||
logger.debug("using {}".format(self.backend))
|
||||
dist.init_process_group(
|
||||
backend=self.backend,
|
||||
init_method=url,
|
||||
rank=world_rank,
|
||||
world_size=world_size)
|
||||
self.world_rank = world_rank
|
||||
logger.debug("Connecting to {} world_rank: {} world_size: {}".format(
|
||||
url, world_rank, world_size))
|
||||
logger.debug("using {}".format(self.backend))
|
||||
dist.init_process_group(
|
||||
backend=self.backend,
|
||||
init_method=url,
|
||||
rank=world_rank,
|
||||
world_size=world_size)
|
||||
|
||||
def _setup_training(self):
|
||||
logger.debug("Creating model")
|
||||
@@ -68,32 +66,12 @@ class DistributedTorchRunner(TorchRunner):
|
||||
self.optimizers = [self.optimizers]
|
||||
|
||||
self._create_schedulers_if_available()
|
||||
|
||||
self._try_setup_apex()
|
||||
|
||||
# This needs to happen after apex
|
||||
self.models = [DistributedDataParallel(model) for model in self.models]
|
||||
|
||||
logger.debug("Creating loss.")
|
||||
self._create_loss()
|
||||
|
||||
logger.debug("Creating dataset.")
|
||||
with FileLock(os.path.expanduser("~/.ray_data.lock")):
|
||||
datasets = self.data_creator(self.config)
|
||||
train_set, val_set = self._validate_datasets(datasets)
|
||||
|
||||
train_loader_config = self.dataloader_config.copy()
|
||||
train_loader_config.update(
|
||||
sampler=torch.utils.data.distributed.DistributedSampler(train_set),
|
||||
shuffle=False)
|
||||
|
||||
self.train_loader = torch.utils.data.DataLoader(
|
||||
train_set, batch_size=self.batch_size, **train_loader_config)
|
||||
|
||||
self.validation_loader = None
|
||||
if val_set:
|
||||
self.validation_loader = torch.utils.data.DataLoader(
|
||||
val_set, batch_size=self.batch_size, **self.dataloader_config)
|
||||
self._initialize_dataloaders()
|
||||
|
||||
self.training_operator = self.training_operator_cls(
|
||||
self.config,
|
||||
@@ -103,12 +81,39 @@ class DistributedTorchRunner(TorchRunner):
|
||||
schedulers=self.schedulers,
|
||||
use_fp16=self.use_fp16)
|
||||
|
||||
def _initialize_dataloaders(self):
|
||||
super(DistributedTorchRunner, self)._initialize_dataloaders()
|
||||
|
||||
def with_sampler(loader):
|
||||
# Automatically set the DistributedSampler
|
||||
data_loader_args = {
|
||||
"dataset": loader.dataset,
|
||||
"batch_size": loader.batch_size,
|
||||
"shuffle": False,
|
||||
"num_workers": loader.num_workers,
|
||||
"collate_fn": loader.collate_fn,
|
||||
"pin_memory": loader.pin_memory,
|
||||
"drop_last": loader.drop_last,
|
||||
"timeout": loader.timeout,
|
||||
"worker_init_fn": loader.worker_init_fn,
|
||||
"sampler": DistributedSampler(loader.dataset)
|
||||
}
|
||||
return DataLoader(**data_loader_args)
|
||||
|
||||
if isinstance(self.train_loader, DataLoader):
|
||||
self.train_loader = with_sampler(self.train_loader)
|
||||
|
||||
if self.validation_loader and isinstance(self.validation_loader,
|
||||
DataLoader):
|
||||
self.validation_loader = with_sampler(self.validation_loader)
|
||||
|
||||
def train_epoch(self, **kwargs):
|
||||
"""Runs a training epoch and updates the model parameters.
|
||||
|
||||
Automatically sets epoch of sampler if possible.
|
||||
"""
|
||||
if hasattr(self.train_loader.sampler, "set_epoch"):
|
||||
if hasattr(self.train_loader, "sampler") and hasattr(
|
||||
self.train_loader.sampler, "set_epoch"):
|
||||
self.train_loader.sampler.set_epoch(self.epochs)
|
||||
return super(DistributedTorchRunner, self).train_epoch(**kwargs)
|
||||
|
||||
|
||||
@@ -3,13 +3,14 @@ import torch
|
||||
import torch.nn as nn
|
||||
import argparse
|
||||
from ray import tune
|
||||
import torch.utils.data
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
import ray
|
||||
from ray.util.sgd.torch import (TorchTrainer, TorchTrainable)
|
||||
from ray.util.sgd.torch.resnet import ResNet18
|
||||
from ray.util.sgd.utils import BATCH_SIZE
|
||||
|
||||
|
||||
def initialization_hook():
|
||||
@@ -40,11 +41,14 @@ def cifar_creator(config):
|
||||
root="~/data", train=False, download=False, transform=transform_test)
|
||||
|
||||
if config.get("test_mode"):
|
||||
train_dataset = torch.utils.data.Subset(train_dataset, list(range(64)))
|
||||
validation_dataset = torch.utils.data.Subset(validation_dataset,
|
||||
list(range(64)))
|
||||
train_dataset = Subset(train_dataset, list(range(64)))
|
||||
validation_dataset = Subset(validation_dataset, list(range(64)))
|
||||
|
||||
return train_dataset, validation_dataset
|
||||
train_loader = DataLoader(
|
||||
train_dataset, batch_size=config[BATCH_SIZE], num_workers=2)
|
||||
validation_loader = DataLoader(
|
||||
validation_dataset, batch_size=config[BATCH_SIZE], num_workers=2)
|
||||
return train_loader, validation_loader
|
||||
|
||||
|
||||
def optimizer_creator(model, config):
|
||||
@@ -57,25 +61,25 @@ def scheduler_creator(optimizer, config):
|
||||
optimizer, milestones=[150, 250, 350], gamma=0.1)
|
||||
|
||||
|
||||
def train_example(num_replicas=1,
|
||||
def train_example(num_workers=1,
|
||||
num_epochs=5,
|
||||
use_gpu=False,
|
||||
use_fp16=False,
|
||||
test_mode=False):
|
||||
trainer1 = TorchTrainer(
|
||||
ResNet18,
|
||||
cifar_creator,
|
||||
optimizer_creator,
|
||||
nn.CrossEntropyLoss,
|
||||
model_creator=ResNet18,
|
||||
data_creator=cifar_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.CrossEntropyLoss,
|
||||
scheduler_creator=scheduler_creator,
|
||||
initialization_hook=initialization_hook,
|
||||
num_replicas=num_replicas,
|
||||
num_workers=num_workers,
|
||||
config={
|
||||
"lr": 0.01,
|
||||
"test_mode": test_mode
|
||||
"test_mode": test_mode,
|
||||
BATCH_SIZE: 128,
|
||||
},
|
||||
use_gpu=use_gpu,
|
||||
batch_size=16 if test_mode else 512,
|
||||
backend="nccl" if use_gpu else "gloo",
|
||||
scheduler_step_freq="epoch",
|
||||
use_fp16=use_fp16)
|
||||
@@ -89,18 +93,18 @@ def train_example(num_replicas=1,
|
||||
print("success!")
|
||||
|
||||
|
||||
def tune_example(num_replicas=1, use_gpu=False, test_mode=False):
|
||||
def tune_example(num_workers=1, use_gpu=False, test_mode=False):
|
||||
config = {
|
||||
"model_creator": ResNet18,
|
||||
"data_creator": cifar_creator,
|
||||
"optimizer_creator": optimizer_creator,
|
||||
"loss_creator": nn.CrossEntropyLoss,
|
||||
"num_replicas": num_replicas,
|
||||
"num_workers": num_workers,
|
||||
"initialization_hook": initialization_hook,
|
||||
"use_gpu": use_gpu,
|
||||
"batch_size": 16 if test_mode else 512,
|
||||
"config": {
|
||||
"lr": tune.choice([1e-4, 1e-3]),
|
||||
BATCH_SIZE: 128,
|
||||
"test_mode": test_mode
|
||||
},
|
||||
"backend": "nccl" if use_gpu else "gloo"
|
||||
@@ -124,11 +128,11 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
help="the address to use for Redis")
|
||||
parser.add_argument(
|
||||
"--num-replicas",
|
||||
"--num-workers",
|
||||
"-n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Sets number of replicas for training.")
|
||||
help="Sets number of workers for training.")
|
||||
parser.add_argument(
|
||||
"--num-epochs", type=int, default=5, help="Number of epochs to train.")
|
||||
parser.add_argument(
|
||||
@@ -155,12 +159,12 @@ if __name__ == "__main__":
|
||||
|
||||
if args.tune:
|
||||
tune_example(
|
||||
num_replicas=args.num_replicas,
|
||||
num_workers=args.num_workers,
|
||||
use_gpu=args.use_gpu,
|
||||
test_mode=args.smoke_test)
|
||||
else:
|
||||
train_example(
|
||||
num_replicas=args.num_replicas,
|
||||
num_workers=args.num_workers,
|
||||
num_epochs=args.num_epochs,
|
||||
use_gpu=args.use_gpu,
|
||||
use_fp16=args.fp16,
|
||||
|
||||
@@ -31,7 +31,9 @@ def data_creator(config):
|
||||
]))
|
||||
if config.get("test_mode"):
|
||||
dataset = torch.utils.data.Subset(dataset, list(range(64)))
|
||||
return dataset
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset, batch_size=config.get("batch_size", 32))
|
||||
return dataloader
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
@@ -137,12 +139,15 @@ class GANOperator(TrainingOperator):
|
||||
N = len(imgs)
|
||||
dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
|
||||
up = nn.Upsample(
|
||||
size=(28, 28), mode="bilinear").type(torch.FloatTensor)
|
||||
size=(28, 28),
|
||||
mode="bilinear",
|
||||
align_corners=False # This is to reduce user warnings from torch.
|
||||
).type(torch.FloatTensor)
|
||||
|
||||
def get_pred(x):
|
||||
x = up(x)
|
||||
x = self.classifier(x)
|
||||
return F.softmax(x).data.cpu().numpy()
|
||||
return F.softmax(x, dim=1).data.cpu().numpy()
|
||||
|
||||
# Obtain predictions for the fake provided images
|
||||
preds = np.zeros((N, 10))
|
||||
@@ -218,27 +223,32 @@ class GANOperator(TrainingOperator):
|
||||
}
|
||||
|
||||
|
||||
def train_example(num_replicas=1, use_gpu=False, test_mode=False):
|
||||
def train_example(num_workers=1, use_gpu=False, test_mode=False):
|
||||
config = {
|
||||
"test_mode": test_mode,
|
||||
"batch_size": 16 if test_mode else 512 // num_workers,
|
||||
"classification_model_path": os.path.join(
|
||||
os.path.dirname(ray.__file__),
|
||||
"util/sgd/torch/examples/mnist_cnn.pt")
|
||||
}
|
||||
trainer = TorchTrainer(
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
nn.BCELoss,
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.BCELoss,
|
||||
training_operator_cls=GANOperator,
|
||||
num_replicas=num_replicas,
|
||||
num_workers=num_workers,
|
||||
config=config,
|
||||
use_gpu=use_gpu,
|
||||
batch_size=16 if test_mode else 512,
|
||||
backend="nccl" if use_gpu else "gloo")
|
||||
for i in range(5):
|
||||
|
||||
from tabulate import tabulate
|
||||
for itr in range(5):
|
||||
stats = trainer.train()
|
||||
print(stats)
|
||||
formatted = tabulate([stats], headers="keys")
|
||||
if itr > 0: # Get the last line of the stats.
|
||||
formatted = formatted.split("\n")[-1]
|
||||
print(formatted)
|
||||
|
||||
return trainer
|
||||
|
||||
@@ -253,21 +263,21 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
help="the address to use to connect to a cluster.")
|
||||
parser.add_argument(
|
||||
"--num-replicas",
|
||||
"--num-workers",
|
||||
"-n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Sets number of replicas for training.")
|
||||
help="Sets number of workers for training.")
|
||||
parser.add_argument(
|
||||
"--use-gpu",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enables GPU training")
|
||||
args, _ = parser.parse_known_args()
|
||||
args = parser.parse_args()
|
||||
ray.init(address=args.address)
|
||||
|
||||
trainer = train_example(
|
||||
num_replicas=args.num_replicas,
|
||||
num_workers=args.num_workers,
|
||||
use_gpu=args.use_gpu,
|
||||
test_mode=args.smoke_test)
|
||||
models = trainer.get_model()
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
# flake8: noqa
|
||||
"""This file holds code for the TorchTrainer creator signatures.
|
||||
|
||||
It ignores yapf because yapf doesn't allow comments right after code blocks,
|
||||
but we put comments right after code blocks to prevent large white spaces
|
||||
in the documentation.
|
||||
"""
|
||||
# yapf: disable
|
||||
|
||||
# __torch_model_start__
|
||||
import torch.nn as nn
|
||||
|
||||
def model_creator(config):
|
||||
"""Constructor function for the model(s) to be optimized.
|
||||
|
||||
You will also need to provide a custom training
|
||||
function to specify the optimization procedure for multiple models.
|
||||
|
||||
Args:
|
||||
config (dict): Configuration dictionary passed into ``TorchTrainer``.
|
||||
|
||||
Returns:
|
||||
One or more torch.nn.Module objects.
|
||||
"""
|
||||
return nn.Linear(1, 1)
|
||||
# __torch_model_end__
|
||||
|
||||
|
||||
# __torch_optimizer_start__
|
||||
import torch
|
||||
|
||||
def optimizer_creator(model, config):
|
||||
"""Constructor of one or more Torch optimizers.
|
||||
|
||||
Args:
|
||||
models: The return values from ``model_creator``. This can be one
|
||||
or more torch nn modules.
|
||||
config (dict): Configuration dictionary passed into ``TorchTrainer``.
|
||||
|
||||
Returns:
|
||||
One or more Torch optimizer objects.
|
||||
"""
|
||||
return torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-4))
|
||||
# __torch_optimizer_end__
|
||||
|
||||
|
||||
# __torch_data_start__
|
||||
from torch.utils.data import DataLoader
|
||||
from ray.util.sgd.torch.examples.train_example import LinearDataset
|
||||
|
||||
def data_creator(config):
|
||||
"""Constructs Iterables for training and validation.
|
||||
|
||||
Note that even though two Iterable objects can be returned,
|
||||
only one Iterable will be used for training.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary passed into ``TorchTrainer``
|
||||
|
||||
Returns:
|
||||
One or Two Iterable objects. If only one Iterable object is provided,
|
||||
``trainer.validate()`` will throw a ValueError.
|
||||
"""
|
||||
train_dataset, val_dataset = LinearDataset(2, 5), LinearDataset(2, 5)
|
||||
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"])
|
||||
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"])
|
||||
return train_loader, val_loader
|
||||
# __torch_data_end__
|
||||
|
||||
# __torch_loss_start__
|
||||
import torch
|
||||
|
||||
def loss_creator(config):
|
||||
"""Constructs the Torch Loss object.
|
||||
|
||||
Note that optionally, you can pass in a Torch Loss constructor directly
|
||||
into the TorchTrainer (i.e., ``TorchTrainer(loss_creator=nn.BCELoss, ...)``).
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary passed into ``TorchTrainer``
|
||||
|
||||
Returns:
|
||||
Torch Loss object.
|
||||
"""
|
||||
return torch.nn.BCELoss()
|
||||
# __torch_loss_end__
|
||||
|
||||
# __torch_scheduler_start__
|
||||
import torch
|
||||
|
||||
def scheduler_creator(optimizer, config):
|
||||
"""Constructor of one or more Torch optimizer schedulers.
|
||||
|
||||
Args:
|
||||
optimizers: The return values from ``optimizer_creator``.
|
||||
This can be one or more torch optimizer objects.
|
||||
config: Configuration dictionary passed into ``TorchTrainer``
|
||||
|
||||
Returns:
|
||||
One or more Torch scheduler objects.
|
||||
"""
|
||||
return torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)
|
||||
|
||||
# __torch_scheduler_end__
|
||||
|
||||
# __torch_ray_start__
|
||||
import ray
|
||||
|
||||
ray.init()
|
||||
# or ray.init(address="auto") to connect to a running cluster.
|
||||
# __torch_ray_end__
|
||||
|
||||
# __torch_trainer_start__
|
||||
from ray.util.sgd import TorchTrainer
|
||||
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.MSELoss,
|
||||
scheduler_creator=scheduler_creator,
|
||||
config={"lr": 0.001, "batch_size": 64})
|
||||
|
||||
# __torch_trainer_end__
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
This file holds code for a Training guide for PytorchSGD in the documentation.
|
||||
"""Example code for RaySGD Torch in the documentation.
|
||||
|
||||
It ignores yapf because yapf doesn't allow comments right after code blocks,
|
||||
but we put comments right after code blocks to prevent large white spaces
|
||||
@@ -55,20 +54,32 @@ def scheduler_creator(optimizer, config):
|
||||
|
||||
def data_creator(config):
|
||||
"""Returns training dataloader, validation dataloader."""
|
||||
return LinearDataset(2, 5), LinearDataset(2, 5, size=400)
|
||||
train_dataset = LinearDataset(2, 5, size=config.get("data_size", 1000))
|
||||
val_dataset = LinearDataset(2, 5, size=config.get("val_size", 400))
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config.get("batch_size", 32),
|
||||
)
|
||||
validation_loader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
batch_size=config.get("batch_size", 32))
|
||||
return train_loader, validation_loader
|
||||
|
||||
|
||||
def train_example(num_replicas=1, use_gpu=False):
|
||||
def train_example(num_workers=1, use_gpu=False):
|
||||
trainer1 = TorchTrainer(
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.MSELoss,
|
||||
scheduler_creator=scheduler_creator,
|
||||
num_replicas=num_replicas,
|
||||
num_workers=num_workers,
|
||||
use_gpu=use_gpu,
|
||||
batch_size=num_replicas * 4,
|
||||
config={"lr": 1e-2, "hidden_size": 1},
|
||||
config={
|
||||
"lr": 1e-2, # used in optimizer_creator
|
||||
"hidden_size": 1, # used in model_creator
|
||||
"batch_size": 4, # used in data_creator
|
||||
},
|
||||
backend="gloo",
|
||||
scheduler_step_freq="epoch")
|
||||
for i in range(5):
|
||||
@@ -91,11 +102,11 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
help="the address to use for Ray")
|
||||
parser.add_argument(
|
||||
"--num-replicas",
|
||||
"--num-workers",
|
||||
"-n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Sets number of replicas for training.")
|
||||
help="Sets number of workers for training.")
|
||||
parser.add_argument(
|
||||
"--use-gpu",
|
||||
action="store_true",
|
||||
@@ -109,4 +120,4 @@ if __name__ == "__main__":
|
||||
import ray
|
||||
|
||||
ray.init(address=args.address)
|
||||
train_example(num_replicas=args.num_replicas, use_gpu=args.use_gpu)
|
||||
train_example(num_workers=args.num_workers, use_gpu=args.use_gpu)
|
||||
|
||||
@@ -44,18 +44,27 @@ def optimizer_creator(model, config):
|
||||
|
||||
def data_creator(config):
|
||||
"""Returns training dataloader, validation dataloader."""
|
||||
return LinearDataset(2, 5), LinearDataset(2, 5, size=400)
|
||||
train_dataset = LinearDataset(2, 5)
|
||||
val_dataset = LinearDataset(2, 5, size=400)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config["batch_size"],
|
||||
)
|
||||
validation_loader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
batch_size=config["batch_size"])
|
||||
return train_loader, validation_loader
|
||||
|
||||
|
||||
def tune_example(num_replicas=1, use_gpu=False):
|
||||
def tune_example(num_workers=1, use_gpu=False):
|
||||
config = {
|
||||
"model_creator": tune.function(model_creator),
|
||||
"data_creator": tune.function(data_creator),
|
||||
"optimizer_creator": tune.function(optimizer_creator),
|
||||
"loss_creator": tune.function(nn.MSELoss),
|
||||
"num_replicas": num_replicas,
|
||||
"model_creator": model_creator,
|
||||
"data_creator": data_creator,
|
||||
"optimizer_creator": optimizer_creator,
|
||||
"loss_creator": nn.MSELoss,
|
||||
"num_workers": num_workers,
|
||||
"use_gpu": use_gpu,
|
||||
"batch_size": 512,
|
||||
"config": {"batch_size": 512 // num_workers},
|
||||
"backend": "gloo"
|
||||
}
|
||||
|
||||
@@ -77,20 +86,18 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
help="the address to use for Ray")
|
||||
parser.add_argument(
|
||||
"--num-replicas",
|
||||
"--num-workers",
|
||||
"-n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Sets number of replicas for training.")
|
||||
help="Sets number of workers for training.")
|
||||
parser.add_argument(
|
||||
"--use-gpu",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enables GPU training")
|
||||
parser.add_argument(
|
||||
"--tune", action="store_true", default=False, help="Tune training")
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
ray.init(address=args.address)
|
||||
tune_example(num_replicas=args.num_replicas, use_gpu=args.use_gpu)
|
||||
tune_example(num_workers=args.num_workers, use_gpu=args.use_gpu)
|
||||
|
||||
@@ -4,9 +4,8 @@ import logging
|
||||
import inspect
|
||||
import itertools
|
||||
import os
|
||||
import tempfile
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import ray
|
||||
from ray.util.sgd.torch.constants import USE_FP16, SCHEDULER_STEP
|
||||
@@ -27,16 +26,15 @@ class TorchRunner:
|
||||
"""Manages a PyTorch model for training.
|
||||
|
||||
Args:
|
||||
model_creator (dict -> *): see torch_trainer.py
|
||||
data_creator (dict -> Dataset, Dataset): see torch_trainer.py.
|
||||
optimizer_creator (models, dict -> optimizers): see torch_trainer.py.
|
||||
loss_creator (dict -> loss | Loss class): see torch_trainer.py.
|
||||
scheduler_creator (optimizers, dict -> schedulers): see
|
||||
model_creator (dict -> Model(s)): see torch_trainer.py
|
||||
data_creator (dict -> Iterable(s)): see torch_trainer.py.
|
||||
optimizer_creator ((models, dict) -> optimizers): see torch_trainer.py.
|
||||
loss_creator (torch.nn.*Loss class | dict -> loss):
|
||||
see torch_trainer.py.
|
||||
scheduler_creator ((optimizers, dict) -> scheduler): see
|
||||
torch_trainer.py.
|
||||
training_operator_cls: see torch_trainer.py
|
||||
config (dict): see torch_trainer.py.
|
||||
dataloader_config (dict): See torch_trainer.py.
|
||||
batch_size (int): see torch_trainer.py.
|
||||
use_fp16 (bool): see torch_trainer.py.
|
||||
apex_args (dict|None): see torch_trainer.py.
|
||||
scheduler_step_freq (str): see torch_trainer.py.
|
||||
@@ -46,36 +44,23 @@ class TorchRunner:
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
loss_creator,
|
||||
loss_creator=None,
|
||||
scheduler_creator=None,
|
||||
training_operator_cls=None,
|
||||
config=None,
|
||||
dataloader_config=None,
|
||||
batch_size=16,
|
||||
use_fp16=False,
|
||||
apex_args=None,
|
||||
scheduler_step_freq="batch"):
|
||||
self.model_creator = model_creator
|
||||
self.data_creator = data_creator
|
||||
self.optimizer_creator = optimizer_creator
|
||||
self.loss_creator = loss_creator
|
||||
self.data_creator = data_creator
|
||||
self.scheduler_creator = scheduler_creator
|
||||
self.training_operator_cls = training_operator_cls or TrainingOperator
|
||||
self.config = {} if config is None else config
|
||||
self.dataloader_config = {
|
||||
"num_workers": 2
|
||||
} if dataloader_config is None else dataloader_config
|
||||
self.batch_size = batch_size
|
||||
self.verbose = True
|
||||
|
||||
self.timers = utils.TimerCollection()
|
||||
self.epochs = 0
|
||||
self._timers = {
|
||||
k: utils.TimerStat(window_size=1)
|
||||
for k in [
|
||||
"setup_proc", "setup_model", "get_state", "set_state",
|
||||
"validation", "training"
|
||||
]
|
||||
}
|
||||
self.models = None
|
||||
self.optimizers = None
|
||||
self.criterion = None
|
||||
@@ -90,23 +75,40 @@ class TorchRunner:
|
||||
"https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
self.scheduler_step_freq = scheduler_step_freq
|
||||
|
||||
def _validate_datasets(self, dataset):
|
||||
assert dataset, "Datasets need to be returned in data_creator."
|
||||
if issubclass(type(dataset), Dataset):
|
||||
return dataset, None
|
||||
elif len(dataset) == 2 and issubclass(type(dataset[0]), Dataset):
|
||||
return dataset
|
||||
else:
|
||||
raise ValueError("Datasets must be <= 2. Got {}".format(dataset))
|
||||
def _validate_loaders(self, loaders):
|
||||
assert loaders, "Loaders need to be returned in data_creator."
|
||||
if isinstance(loaders, (tuple, list)):
|
||||
if len(loaders) == 1:
|
||||
return loaders, None
|
||||
elif len(loaders) == 2:
|
||||
return loaders
|
||||
else:
|
||||
raise ValueError(
|
||||
"Number of loaders must be <= 2. Got {}".format(loaders))
|
||||
# No great way of checking type otherwise
|
||||
return loaders, None
|
||||
|
||||
def _initialize_dataloaders(self):
|
||||
logger.debug("Instantiating dataloaders.")
|
||||
# When creating loaders, a filelock will be used to ensure no
|
||||
# race conditions in data downloading among different workers.
|
||||
with FileLock(os.path.join(tempfile.gettempdir(), ".ray_data.lock")):
|
||||
loaders = self.data_creator(self.config)
|
||||
train_loader, val_loader = self._validate_loaders(loaders)
|
||||
|
||||
self.train_loader, self.validation_loader = train_loader, val_loader
|
||||
|
||||
def _create_loss(self):
|
||||
if not self.loss_creator:
|
||||
return
|
||||
logger.debug("Creating loss.")
|
||||
if inspect.isclass(self.loss_creator) and issubclass(
|
||||
self.loss_creator, torch.nn.modules.loss._Loss):
|
||||
self.criterion = self.loss_creator()
|
||||
else:
|
||||
self.criterion = self.loss_creator(self.config)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
if torch.cuda.is_available() and hasattr("cuda", self.criterion):
|
||||
self.criterion = self.criterion.cuda()
|
||||
|
||||
def _create_schedulers_if_available(self):
|
||||
@@ -142,22 +144,7 @@ class TorchRunner:
|
||||
self._create_schedulers_if_available()
|
||||
self._try_setup_apex()
|
||||
self._create_loss()
|
||||
|
||||
logger.debug("Creating dataset")
|
||||
# When creating datasets, a filelock will be used to ensure no
|
||||
# race conditions in data downloading among different workers.
|
||||
with FileLock(os.path.expanduser("~/.ray_data.lock")):
|
||||
datasets = self.data_creator(self.config)
|
||||
train_set, val_set = self._validate_datasets(datasets)
|
||||
|
||||
self.train_loader = torch.utils.data.DataLoader(
|
||||
train_set, batch_size=self.batch_size, **self.dataloader_config)
|
||||
|
||||
self.validation_loader = None
|
||||
if val_set:
|
||||
self.validation_loader = torch.utils.data.DataLoader(
|
||||
val_set, batch_size=self.batch_size, **self.dataloader_config)
|
||||
|
||||
self._initialize_dataloaders()
|
||||
self.training_operator = self.training_operator_cls(
|
||||
self.config,
|
||||
models=self.models,
|
||||
@@ -174,47 +161,55 @@ class TorchRunner:
|
||||
"""Finds a free port on the current node."""
|
||||
return utils.find_free_port()
|
||||
|
||||
def train_epoch(self, num_steps=None, info=None):
|
||||
def train_epoch(self, num_steps=None, profile=False, info=None):
|
||||
"""Runs a training epoch and updates the model parameters."""
|
||||
logger.debug("Begin Training Step {}".format(self.epochs + 1))
|
||||
info = info or {}
|
||||
self._toggle_profiling(profile=profile)
|
||||
|
||||
info.update({
|
||||
USE_FP16: self.use_fp16,
|
||||
SCHEDULER_STEP: self.scheduler_step_freq
|
||||
})
|
||||
with self._timers["training"]:
|
||||
with self.timers.record("train_epoch"):
|
||||
iterator = self.train_loader
|
||||
if num_steps:
|
||||
iterator = itertools.islice(iter(self.train_loader), num_steps)
|
||||
train_stats = self.training_operator.train_epoch(iterator, info)
|
||||
|
||||
self.epochs += 1
|
||||
train_stats.update(self.stats())
|
||||
return train_stats
|
||||
# This is so that `epochs` is first in ordering.
|
||||
stats = dict(epoch=self.epochs, **train_stats)
|
||||
if profile:
|
||||
stats.update(profile=self.timers.stats())
|
||||
return stats
|
||||
|
||||
def validate(self, num_steps=None, info=None):
|
||||
def validate(self, num_steps=None, profile=False, info=None):
|
||||
"""Evaluates the model on the validation data set."""
|
||||
if self.validation_loader is None:
|
||||
raise ValueError("No validation dataloader provided.")
|
||||
info = info or {}
|
||||
with self._timers["validation"]:
|
||||
self._toggle_profiling(profile=profile)
|
||||
|
||||
with self.timers.record("validation"):
|
||||
iterator = self.validation_loader
|
||||
if num_steps:
|
||||
iterator = itertools.islice(
|
||||
iter(self.validation_loader), num_steps)
|
||||
validation_stats = self.training_operator.validate(iterator, info)
|
||||
|
||||
validation_stats.update(self.stats())
|
||||
validation_stats = self.training_operator.validate(
|
||||
iterator, info=info)
|
||||
if profile:
|
||||
validation_stats.update(profile=self.timers.stats())
|
||||
return validation_stats
|
||||
|
||||
def stats(self):
|
||||
"""Returns a dictionary of statistics collected."""
|
||||
stats = {"epoch": self.epochs}
|
||||
for k, t in self._timers.items():
|
||||
stats[k + "_time_mean"] = t.mean
|
||||
stats[k + "_time_total"] = t.sum
|
||||
t.reset()
|
||||
return stats
|
||||
def _toggle_profiling(self, profile=False):
|
||||
"""Enables/Disables and resets timing profiles."""
|
||||
if profile:
|
||||
self.timers.enable()
|
||||
self.timers.reset()
|
||||
else:
|
||||
self.timers.disable()
|
||||
self.training_operator._set_timers(self.timers)
|
||||
|
||||
def _get_model_state_dicts(self):
|
||||
# This is so that we create a duplicate of weights into CPU rather than
|
||||
@@ -237,8 +232,7 @@ class TorchRunner:
|
||||
"epoch": self.epochs,
|
||||
"operator": self.training_operator.state_dict(),
|
||||
"models": self._get_model_state_dicts(),
|
||||
"optimizers": [opt.state_dict() for opt in self.optimizers],
|
||||
"stats": self.stats()
|
||||
"optimizers": [opt.state_dict() for opt in self.optimizers]
|
||||
}
|
||||
if self.schedulers:
|
||||
state.update({
|
||||
@@ -264,7 +258,7 @@ class TorchRunner:
|
||||
|
||||
if self.use_fp16 and "amp" in state and amp:
|
||||
amp.load_state_dict(state["amp"])
|
||||
self.epochs = state["stats"]["epoch"]
|
||||
self.epochs = state["epoch"]
|
||||
self.training_operator.load_state_dict(state_dict)
|
||||
|
||||
def apply(self, fn):
|
||||
|
||||
@@ -14,6 +14,7 @@ from ray.tune.trial import Resources
|
||||
from ray.util.sgd.torch.distributed_torch_runner import (
|
||||
DistributedTorchRunner)
|
||||
from ray.util.sgd import utils
|
||||
from ray.util.sgd.utils import NUM_SAMPLES, BATCH_SIZE
|
||||
from ray.util.sgd.torch.torch_runner import TorchRunner
|
||||
from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP
|
||||
|
||||
@@ -37,6 +38,8 @@ class TorchTrainer:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
ray.init()
|
||||
|
||||
def model_creator(config):
|
||||
return nn.Linear(1, 1)
|
||||
|
||||
@@ -47,13 +50,19 @@ class TorchTrainer:
|
||||
|
||||
|
||||
def data_creator(config):
|
||||
return LinearDataset(2, 5), LinearDataset(2, 5, size=400)
|
||||
batch_size = config["batch_size"]
|
||||
train_data, val_data = LinearDataset(2, 5), LinearDataset(2, 5)
|
||||
train_loader = DataLoader(train_data, batch_size=batch_size)
|
||||
val_loader = DataLoader(val_data, batch_size=batch_size)
|
||||
return train_loader, val_loader
|
||||
|
||||
|
||||
trainer = TorchTrainer(
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.MSELoss,
|
||||
config={"batch_size": 32},
|
||||
use_gpu=True
|
||||
)
|
||||
for i in range(4):
|
||||
@@ -67,12 +76,12 @@ class TorchTrainer:
|
||||
a ``training_operator_cls`` must be specified. You do not need to
|
||||
handle GPU/devices in this function; RaySGD will do that under
|
||||
the hood.
|
||||
data_creator (dict -> Dataset(s)): Constructor function
|
||||
data_creator (dict -> Iterable(s)): Constructor function
|
||||
that takes in the passed config and returns one or
|
||||
two ``torch.utils.data.Dataset`` objects.
|
||||
Note that even though two Dataset objects can be returned,
|
||||
only one dataset will be used for training. RaySGD
|
||||
will automatically wrap the objects with a ``DataLoader``.
|
||||
two Iterable objects. Note that even though two Iterable objects
|
||||
can be returned, only one will be used for training, and the
|
||||
other will be used for validation. If not provided, you must
|
||||
provide a custom TrainingOperator.
|
||||
optimizer_creator ((models, dict) -> optimizers): Constructor
|
||||
function that takes in the return values from
|
||||
``model_creator`` and the passed config and returns One or
|
||||
@@ -83,7 +92,8 @@ class TorchTrainer:
|
||||
takes in the provided config for customization or a subclass
|
||||
of ``torch.nn.modules.loss._Loss``, which is most Pytorch
|
||||
loss classes. For example, ``loss_creator=torch.nn.BCELoss``.
|
||||
scheduler_creator (optimizers, dict -> loss):
|
||||
If not provided, you must provide a custom TrainingOperator.
|
||||
scheduler_creator ((optimizers, dict) -> scheduler):
|
||||
A constructor function for the torch scheduler. This is
|
||||
a function that takes in the generated optimizers (from
|
||||
``optimizer_creator``) provided config for customization.
|
||||
@@ -96,20 +106,12 @@ class TorchTrainer:
|
||||
TrainingOperator.
|
||||
config (dict): Custom configuration value to be passed to
|
||||
all creator and operator constructors.
|
||||
dataloader_config (dict): Configuration values to be passed into
|
||||
the ``torch.utils.data.DataLoader`` object that wraps
|
||||
the dataset on each parallel worker for both training
|
||||
and validation. Note that if ``num_replicas``
|
||||
is greater than 1, ``shuffle`` and ``sampler`` will be
|
||||
automatically set. See the available arguments
|
||||
here https://pytorch.org/docs/stable/data.html.
|
||||
num_replicas (int): the number of workers used in distributed
|
||||
training.
|
||||
num_workers (int): the number of workers used in distributed
|
||||
training. If 1, the worker will not be wrapped with
|
||||
DistributedDataParallel.
|
||||
use_gpu (bool): Sets resource allocation for workers to 1 GPU
|
||||
if true, and automatically moves both the model and optimizer
|
||||
to the available CUDA device.
|
||||
batch_size (int): Total batch size for each minibatch. This
|
||||
value is divided among all workers and rounded.
|
||||
backend (string): backend used by distributed PyTorch. Currently
|
||||
support "nccl", "gloo", and "auto". If "auto", RaySGD will
|
||||
automatically use "nccl" if `use_gpu` is True, and "gloo"
|
||||
@@ -130,50 +132,53 @@ class TorchTrainer:
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
loss_creator,
|
||||
*,
|
||||
model_creator=None,
|
||||
data_creator=None,
|
||||
optimizer_creator=None,
|
||||
loss_creator=None,
|
||||
scheduler_creator=None,
|
||||
training_operator_cls=None,
|
||||
initialization_hook=None,
|
||||
config=None,
|
||||
dataloader_config=None,
|
||||
num_replicas=1,
|
||||
num_workers=1,
|
||||
use_gpu=False,
|
||||
batch_size=16,
|
||||
backend="auto",
|
||||
use_fp16=False,
|
||||
apex_args=None,
|
||||
scheduler_step_freq="batch"):
|
||||
if num_replicas > 1 and not dist.is_available():
|
||||
if num_workers > 1 and not dist.is_available():
|
||||
raise ValueError(
|
||||
("Distributed PyTorch is not supported on macOS. "
|
||||
"To run without distributed PyTorch, set 'num_replicas=1'. "
|
||||
"To run without distributed PyTorch, set 'num_workers=1'. "
|
||||
"For more information, see "
|
||||
"https://github.com/pytorch/examples/issues/467."))
|
||||
|
||||
if not (model_creator and optimizer_creator and data_creator):
|
||||
raise ValueError("Must provide a Model, Optimizer, Data creator.")
|
||||
self.model_creator = model_creator
|
||||
self.data_creator = data_creator
|
||||
self.optimizer_creator = optimizer_creator
|
||||
self.loss_creator = loss_creator
|
||||
self.data_creator = data_creator
|
||||
self.scheduler_creator = scheduler_creator
|
||||
self.training_operator_cls = training_operator_cls
|
||||
|
||||
if not training_operator_cls and not loss_creator:
|
||||
raise ValueError("If a loss_creator is not provided, you must "
|
||||
"provide a custom training operator.")
|
||||
|
||||
self.initialization_hook = initialization_hook
|
||||
self.config = {} if config is None else config
|
||||
self.dataloader_config = dataloader_config
|
||||
self.optimizer_timer = utils.TimerStat(window_size=1)
|
||||
|
||||
if backend == "auto":
|
||||
backend = "nccl" if use_gpu else "gloo"
|
||||
|
||||
logger.info("Using {} as backend.".format(backend))
|
||||
logger.debug("Using {} as backend.".format(backend))
|
||||
self.backend = backend
|
||||
|
||||
# TODO: Have an auto "use_gpu" option to detect and use GPUs.
|
||||
self.use_gpu = use_gpu
|
||||
self.batch_size = batch_size
|
||||
self.max_replicas = num_replicas
|
||||
self.max_replicas = num_workers
|
||||
|
||||
self.use_fp16 = use_fp16
|
||||
|
||||
@@ -190,24 +195,46 @@ class TorchTrainer:
|
||||
|
||||
self._start_workers(self.max_replicas)
|
||||
|
||||
def _start_workers(self, num_replicas):
|
||||
logger.info(f"start_workers: Setting %d replicas." % num_replicas)
|
||||
if num_replicas == 1:
|
||||
def _configure_and_split_batch(self, num_workers):
|
||||
"""If sgd.utils.BATCH_SIZE is provided, split among workers."""
|
||||
if BATCH_SIZE not in self.config:
|
||||
return
|
||||
# Compute batch size per worker
|
||||
logger.debug("BATCH_SIZE parameter detected. Splitting among workers.")
|
||||
batch_size = self.config[BATCH_SIZE]
|
||||
batch_size_per_worker = batch_size // num_workers
|
||||
if batch_size % num_workers > 0:
|
||||
new_batch_size = batch_size_per_worker * num_workers
|
||||
logger.warning(
|
||||
("Changing batch size from {old_batch_size} to "
|
||||
"{new_batch_size} to evenly distribute batches across "
|
||||
"{num_workers} workers.").format(
|
||||
old_batch_size=batch_size,
|
||||
new_batch_size=new_batch_size,
|
||||
num_workers=num_workers))
|
||||
self.config[BATCH_SIZE] = new_batch_size
|
||||
return batch_size_per_worker
|
||||
|
||||
def _start_workers(self, num_workers):
|
||||
logger.debug(f"start_workers: Setting %d workers." % num_workers)
|
||||
worker_config = self.config.copy()
|
||||
batch_size_per_worker = self._configure_and_split_batch(num_workers)
|
||||
if batch_size_per_worker:
|
||||
worker_config[BATCH_SIZE] = batch_size_per_worker
|
||||
if num_workers == 1:
|
||||
# Generate actor class
|
||||
Runner = ray.remote(
|
||||
num_cpus=1, num_gpus=int(self.use_gpu))(TorchRunner)
|
||||
# Start workers
|
||||
self.workers = [
|
||||
Runner.remote(
|
||||
self.model_creator,
|
||||
self.data_creator,
|
||||
self.optimizer_creator,
|
||||
self.loss_creator,
|
||||
self.scheduler_creator,
|
||||
model_creator=self.model_creator,
|
||||
data_creator=self.data_creator,
|
||||
optimizer_creator=self.optimizer_creator,
|
||||
loss_creator=self.loss_creator,
|
||||
scheduler_creator=self.scheduler_creator,
|
||||
training_operator_cls=self.training_operator_cls,
|
||||
config=self.config,
|
||||
dataloader_config=self.dataloader_config,
|
||||
batch_size=self.batch_size,
|
||||
config=worker_config,
|
||||
use_fp16=self.use_fp16,
|
||||
apex_args=self.apex_args,
|
||||
scheduler_step_freq=self.scheduler_step_freq,
|
||||
@@ -221,34 +248,21 @@ class TorchTrainer:
|
||||
# Generate actor class
|
||||
Runner = ray.remote(
|
||||
num_cpus=1, num_gpus=int(self.use_gpu))(DistributedTorchRunner)
|
||||
# Compute batch size per replica
|
||||
batch_size_per_replica = self.batch_size // num_replicas
|
||||
if self.batch_size % num_replicas > 0:
|
||||
new_batch_size = batch_size_per_replica * num_replicas
|
||||
logger.warning(
|
||||
("Changing batch size from {old_batch_size} to "
|
||||
"{new_batch_size} to evenly distribute batches across "
|
||||
"{num_replicas} replicas.").format(
|
||||
old_batch_size=self.batch_size,
|
||||
new_batch_size=new_batch_size,
|
||||
num_replicas=num_replicas))
|
||||
# Start workers
|
||||
self.workers = [
|
||||
Runner.remote(
|
||||
self.model_creator,
|
||||
self.data_creator,
|
||||
self.optimizer_creator,
|
||||
self.loss_creator,
|
||||
self.scheduler_creator,
|
||||
model_creator=self.model_creator,
|
||||
data_creator=self.data_creator,
|
||||
optimizer_creator=self.optimizer_creator,
|
||||
loss_creator=self.loss_creator,
|
||||
scheduler_creator=self.scheduler_creator,
|
||||
backend=self.backend,
|
||||
training_operator_cls=self.training_operator_cls,
|
||||
config=self.config,
|
||||
dataloader_config=self.dataloader_config,
|
||||
batch_size=batch_size_per_replica,
|
||||
config=worker_config,
|
||||
use_fp16=self.use_fp16,
|
||||
apex_args=self.apex_args,
|
||||
scheduler_step_freq=self.scheduler_step_freq)
|
||||
for i in range(num_replicas)
|
||||
for i in range(num_workers)
|
||||
]
|
||||
if self.initialization_hook:
|
||||
self.apply_all_workers(self.initialization_hook)
|
||||
@@ -265,18 +279,28 @@ class TorchTrainer:
|
||||
|
||||
def train(self,
|
||||
num_steps=None,
|
||||
profile=False,
|
||||
reduce_results=True,
|
||||
max_retries=0,
|
||||
checkpoint="auto",
|
||||
info=None):
|
||||
"""Runs a training epoch.
|
||||
|
||||
Runs an average over all values returned from workers. Set
|
||||
`max_retries` to enable fault handling in case of instance preemption.
|
||||
Calls `operator.train_epoch()` on N parallel workers simultaneously
|
||||
underneath the hood.
|
||||
|
||||
Set `max_retries` to enable fault handling in case of
|
||||
instance preemption.
|
||||
|
||||
Args:
|
||||
num_steps (int): Number of batches to compute update steps on.
|
||||
This corresponds also to the number of times
|
||||
``TrainingOperator.train_batch`` is called.
|
||||
profile (bool): Returns time stats for the training procedure.
|
||||
reduce_results (bool): Whether to average all metrics across
|
||||
all workers into one dict. If a metric is a non-numerical
|
||||
value (or nested dictionaries), one value will be randomly
|
||||
selected among the workers. If False, returns a list of dicts.
|
||||
max_retries (int): Must be non-negative. If set to N, will
|
||||
kill all current workers, query the Ray global state for
|
||||
total available resources, and re-launch up to the
|
||||
@@ -289,9 +313,11 @@ class TorchTrainer:
|
||||
operator for ``train_epoch`` and ``train_batch``.
|
||||
|
||||
Returns:
|
||||
A dictionary of metrics for training.
|
||||
(dict | list) A dictionary of metrics for training.
|
||||
You can provide custom metrics by passing in a custom
|
||||
``training_operator_cls``.
|
||||
``training_operator_cls``. If ``reduce_results=False``,
|
||||
this will return a list of metric dictionaries whose
|
||||
length will be equal to ``num_workers``.
|
||||
"""
|
||||
assert max_retries >= 0, "`max_retries` must be non-negative."
|
||||
if max_retries:
|
||||
@@ -306,37 +332,46 @@ class TorchTrainer:
|
||||
logger.info("Resize opportunity detected. Attempting to scale up.")
|
||||
self._resize_workers(checkpoint=checkpoint)
|
||||
|
||||
with self.optimizer_timer:
|
||||
success, worker_stats = self._train_epoch(
|
||||
num_steps=num_steps, profile=profile, info=info)
|
||||
# Fault handling
|
||||
for i in range(max_retries):
|
||||
if success:
|
||||
break
|
||||
else:
|
||||
self._num_failures += 1
|
||||
self._resize_workers(checkpoint=checkpoint)
|
||||
logger.info(
|
||||
"Retrying training step with %d workers." % len(self.workers))
|
||||
success, worker_stats = self._train_epoch(
|
||||
num_steps=num_steps, info=info)
|
||||
# Fault handling
|
||||
for i in range(max_retries):
|
||||
if success:
|
||||
break
|
||||
else:
|
||||
self._num_failures += 1
|
||||
self._resize_workers(checkpoint=checkpoint)
|
||||
logger.info("Retrying training step with %d workers." % len(
|
||||
self.workers))
|
||||
success, worker_stats = self._train_epoch(
|
||||
num_steps=num_steps, info=info)
|
||||
num_steps=num_steps, profile=profile, info=info)
|
||||
if not success:
|
||||
raise RuntimeError("Training run failed.")
|
||||
|
||||
worker_stats = ray.get(worker_stats)
|
||||
if reduce_results:
|
||||
return self._process_stats(worker_stats)
|
||||
else:
|
||||
return worker_stats
|
||||
|
||||
def _process_stats(self, worker_stats):
|
||||
stats = {
|
||||
NUM_SAMPLES: sum(
|
||||
stats.pop(NUM_SAMPLES, np.nan) for stats in worker_stats)
|
||||
}
|
||||
|
||||
train_stats = {}
|
||||
for stat_key in worker_stats[0]:
|
||||
if isinstance(worker_stats[0], numbers.Number):
|
||||
train_stats[stat_key] = np.nanmean(
|
||||
stats[stat_key] = np.nanmean(
|
||||
[s.get(stat_key, np.nan) for s in worker_stats])
|
||||
else:
|
||||
train_stats[stat_key] = worker_stats[0][stat_key]
|
||||
return train_stats
|
||||
stats[stat_key] = worker_stats[0][stat_key]
|
||||
return stats
|
||||
|
||||
def _train_epoch(self, num_steps=None, info=None):
|
||||
def _train_epoch(self, num_steps=None, profile=False, info=None):
|
||||
worker_stats = [
|
||||
w.train_epoch.remote(num_steps=num_steps, info=info)
|
||||
w.train_epoch.remote(
|
||||
num_steps=num_steps, profile=profile, info=info)
|
||||
for w in self.workers
|
||||
]
|
||||
success = utils.check_for_failure(worker_stats)
|
||||
@@ -367,13 +402,14 @@ class TorchTrainer:
|
||||
"""
|
||||
return ray.get([w.apply_operator.remote(fn) for w in self.workers])
|
||||
|
||||
def validate(self, num_steps=None, info=None):
|
||||
def validate(self, num_steps=None, profile=False, info=None):
|
||||
"""Evaluates the model on the validation data set.
|
||||
|
||||
Args:
|
||||
num_steps (int): Number of batches to compute update steps on.
|
||||
This corresponds also to the number of times
|
||||
``TrainingOperator.validate_batch`` is called.
|
||||
profile (bool): Returns time stats for the evaluation procedure.
|
||||
info (dict): Optional dictionary passed to the training
|
||||
operator for `validate` and `validate_batch`.
|
||||
|
||||
@@ -383,15 +419,11 @@ class TorchTrainer:
|
||||
``training_operator_cls``.
|
||||
"""
|
||||
worker_stats = ray.get([
|
||||
w.validate.remote(num_steps=num_steps, info=info)
|
||||
w.validate.remote(num_steps=num_steps, profile=profile, info=info)
|
||||
for w in self.workers
|
||||
])
|
||||
|
||||
validation_stats = {}
|
||||
for stat_key in worker_stats[0]:
|
||||
validation_stats[stat_key] = np.nanmean(
|
||||
[s.get(stat_key, np.nan) for s in worker_stats])
|
||||
return validation_stats
|
||||
return self._process_stats(worker_stats)
|
||||
|
||||
def update_scheduler(self, metric):
|
||||
"""Calls ``scheduler.step(metric)`` on all schedulers.
|
||||
@@ -492,8 +524,8 @@ class TorchTrainable(Trainable):
|
||||
return Resources(
|
||||
cpu=0,
|
||||
gpu=0,
|
||||
extra_cpu=config["num_replicas"],
|
||||
extra_gpu=int(config["use_gpu"]) * config["num_replicas"])
|
||||
extra_cpu=config["num_workers"],
|
||||
extra_gpu=int(config["use_gpu"]) * config["num_workers"])
|
||||
|
||||
def _setup(self, config):
|
||||
self._trainer = TorchTrainer(**config)
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import collections
|
||||
import torch
|
||||
|
||||
from ray.util.sgd.utils import TimerStat, AverageMeter
|
||||
from ray.util.sgd.torch.constants import (
|
||||
SCHEDULER_STEP_EPOCH, SCHEDULER_STEP_BATCH, SCHEDULER_STEP, BATCH_COUNT)
|
||||
from ray.util.sgd.utils import (TimerCollection, AverageMeterCollection,
|
||||
NUM_SAMPLES)
|
||||
from ray.util.sgd.torch.constants import (SCHEDULER_STEP_EPOCH,
|
||||
SCHEDULER_STEP_BATCH, SCHEDULER_STEP)
|
||||
|
||||
amp = None
|
||||
|
||||
@@ -48,15 +49,10 @@ class TrainingOperator:
|
||||
config,
|
||||
models,
|
||||
optimizers,
|
||||
criterion,
|
||||
criterion=None,
|
||||
schedulers=None,
|
||||
use_fp16=False):
|
||||
# You are not expected to override this method.
|
||||
self.timers = {
|
||||
k: TimerStat()
|
||||
for k in ["fwd", "grad", "apply", "epoch_time"]
|
||||
}
|
||||
self._validated_customization = False
|
||||
self._models = models # List of models
|
||||
assert isinstance(models, collections.Iterable), (
|
||||
"Components need to be iterable. Got: {}".format(type(models)))
|
||||
@@ -80,9 +76,13 @@ class TrainingOperator:
|
||||
"Need to provide a custom operator subclassing "
|
||||
"TrainingOperator if using multi-scheduler, "
|
||||
"multi-model or multi-optimizer training/validation.")
|
||||
|
||||
self.timers = TimerCollection()
|
||||
self.setup(config)
|
||||
|
||||
def _set_timers(self, timers):
|
||||
"""Passes in the timers from the Runner."""
|
||||
self.timers = timers
|
||||
|
||||
def setup(self, config):
|
||||
"""Override this method to implement custom operator setup.
|
||||
|
||||
@@ -93,17 +93,35 @@ class TrainingOperator:
|
||||
pass
|
||||
|
||||
def train_epoch(self, iterator, info):
|
||||
"""Runs one standard training pass over the train_iterator.
|
||||
"""Runs one standard training pass over the training dataloader.
|
||||
|
||||
By default, this method will iterate over the given iterator and
|
||||
call ``self.train_batch`` over each batch.
|
||||
|
||||
If ``scheduler_step_freq`` is set, this class will also step the
|
||||
scheduler accordingly.
|
||||
call ``self.train_batch`` over each batch. If ``scheduler_step_freq``
|
||||
is set, this default method will also step the scheduler accordingly.
|
||||
|
||||
You do not need to call ``train_batch`` in this method if you plan
|
||||
to implement a custom optimization/training routine here.
|
||||
|
||||
You may find ``ray.util.sgd.utils.AverageMeterCollection`` useful
|
||||
when overriding this method. See example below:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def train_epoch(self, ...):
|
||||
meter_collection = AverageMeterCollection()
|
||||
self.model.train()
|
||||
for batch in iterator:
|
||||
# do some processing
|
||||
metrics = {"metric_1": 1, "metric_2": 3} # dict of metrics
|
||||
|
||||
# This keeps track of all metrics across multiple batches
|
||||
meter_collection.update(metrics, n=len(batch))
|
||||
|
||||
# Returns stats of the meters.
|
||||
stats = meter_collection.summary()
|
||||
return stats
|
||||
|
||||
|
||||
Args:
|
||||
iterator (iter): Iterator over the training data for the entire
|
||||
epoch. This iterator is expected to be entirely consumed.
|
||||
@@ -113,41 +131,28 @@ class TrainingOperator:
|
||||
Returns:
|
||||
A dict of metrics from training.
|
||||
"""
|
||||
self._losses = AverageMeter()
|
||||
metric_meters = AverageMeterCollection()
|
||||
|
||||
self.model.train()
|
||||
with self.timers["epoch_time"]:
|
||||
for batch_idx, batch in enumerate(iterator):
|
||||
batch_info = {
|
||||
"batch_idx": batch_idx,
|
||||
"global_step": self.global_step
|
||||
}
|
||||
batch_info.update(info)
|
||||
metrics = self.train_batch(batch, batch_info=batch_info)
|
||||
for batch_idx, batch in enumerate(iterator):
|
||||
batch_info = {
|
||||
"batch_idx": batch_idx,
|
||||
"global_step": self.global_step
|
||||
}
|
||||
batch_info.update(info)
|
||||
metrics = self.train_batch(batch, batch_info=batch_info)
|
||||
|
||||
if self.scheduler and batch_info.get(
|
||||
SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
|
||||
self.scheduler.step()
|
||||
if self.scheduler and batch_info.get(
|
||||
SCHEDULER_STEP) == SCHEDULER_STEP_BATCH:
|
||||
self.scheduler.step()
|
||||
|
||||
if "loss" in metrics:
|
||||
self._losses.update(
|
||||
metrics["loss"], n=metrics.get("num_samples", 1))
|
||||
self.global_step += 1
|
||||
metric_meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
|
||||
self.global_step += 1
|
||||
|
||||
if self.scheduler and info.get(SCHEDULER_STEP) == SCHEDULER_STEP_EPOCH:
|
||||
self.scheduler.step()
|
||||
|
||||
stats = {
|
||||
BATCH_COUNT: batch_idx + 1,
|
||||
"mean_train_loss": self._losses.avg,
|
||||
"last_train_loss": self._losses.val,
|
||||
"epoch_time": self.timers["epoch_time"].last
|
||||
}
|
||||
stats.update({
|
||||
timer_tag: timer.mean
|
||||
for timer_tag, timer in self.timers.items()
|
||||
})
|
||||
return stats
|
||||
return metric_meters.summary()
|
||||
|
||||
def train_batch(self, batch, batch_info):
|
||||
"""Computes loss and updates the model over one batch.
|
||||
@@ -176,6 +181,9 @@ class TrainingOperator:
|
||||
By default, this dictionary contains "loss" and "num_samples".
|
||||
"num_samples" corresponds to number of datapoints in the batch.
|
||||
However, you can provide any number of other values.
|
||||
Consider returning "num_samples" in the metrics because
|
||||
by default, ``train_epoch`` uses "num_samples" to
|
||||
calculate averages.
|
||||
|
||||
"""
|
||||
features, target = batch
|
||||
@@ -185,12 +193,12 @@ class TrainingOperator:
|
||||
target = target.cuda(non_blocking=True)
|
||||
|
||||
# Compute output.
|
||||
with self.timers["fwd"]:
|
||||
with self.timers.record("fwd"):
|
||||
output = self.model(features)
|
||||
loss = self.criterion(output, target)
|
||||
|
||||
# Compute gradients in a backward pass.
|
||||
with self.timers["grad"]:
|
||||
with self.timers.record("grad"):
|
||||
self.optimizer.zero_grad()
|
||||
if self.use_fp16:
|
||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||
@@ -199,35 +207,34 @@ class TrainingOperator:
|
||||
loss.backward()
|
||||
|
||||
# Call step of optimizer to update model params.
|
||||
with self.timers["apply"]:
|
||||
with self.timers.record("apply"):
|
||||
self.optimizer.step()
|
||||
return {"loss": loss.item(), "num_samples": features.size(0)}
|
||||
return {"train_loss": loss.item(), NUM_SAMPLES: features.size(0)}
|
||||
|
||||
def validate(self, val_iterator, info):
|
||||
"""Runs one standard validation pass over the val_iterator.
|
||||
|
||||
This will call ``model.eval()`` and ``torch.no_grad`` when iterating
|
||||
over the validation dataset.
|
||||
over the validation dataloader.
|
||||
|
||||
If overriding this method, you can access model, criterion via
|
||||
``self.model`` and ``self.criterion``. You also do not need to call
|
||||
``validate_batch`` if overriding this method.
|
||||
|
||||
Args:
|
||||
val_iterator (iter): Iterable constructed over the
|
||||
validation dataset.
|
||||
val_iterator (iter): Iterable constructed from the
|
||||
validation dataloader.
|
||||
info: (dict): Dictionary for information to be used for custom
|
||||
validation operations.
|
||||
|
||||
Returns:
|
||||
A dict of metrics from the evaluation.
|
||||
By default, returns "mean_accuracy" and "mean_validation_loss"
|
||||
By default, returns "mean_accuracy" and "mean_val_loss"
|
||||
which is computed by aggregating "loss" and "correct" values
|
||||
from ``validate_batch`` and dividing it by the sum of
|
||||
``num_samples`` from all calls to ``self.validate_batch``.
|
||||
"""
|
||||
losses = AverageMeter()
|
||||
total_correct = 0
|
||||
metric_meters = AverageMeterCollection()
|
||||
|
||||
# switch to evaluate mode
|
||||
self.model.eval()
|
||||
@@ -236,19 +243,9 @@ class TrainingOperator:
|
||||
batch_info = {"batch_idx": batch_idx}
|
||||
batch_info.update(info)
|
||||
metrics = self.validate_batch(batch, batch_info)
|
||||
if "loss" in metrics:
|
||||
losses.update(
|
||||
metrics["loss"], n=metrics.get("num_samples", 1))
|
||||
metric_meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
|
||||
|
||||
if "num_correct" in metrics:
|
||||
total_correct += metrics["num_correct"]
|
||||
|
||||
stats = {
|
||||
"batch_count": batch_idx + 1,
|
||||
"mean_validation_loss": losses.avg,
|
||||
"mean_accuracy": total_correct / losses.count
|
||||
}
|
||||
return stats
|
||||
return metric_meters.summary()
|
||||
|
||||
def validate_batch(self, batch, batch_info):
|
||||
"""Calcuates the loss and accuracy over a given batch.
|
||||
@@ -262,7 +259,11 @@ class TrainingOperator:
|
||||
|
||||
Returns:
|
||||
A dict of metrics.
|
||||
By default, returns "loss", "num_correct", and "num_samples".
|
||||
By default, returns "val_loss", "val_accuracy", and
|
||||
"num_samples". When overriding, consider returning
|
||||
"num_samples" in the metrics because
|
||||
by default, ``validate`` uses "num_samples" to
|
||||
calculate averages.
|
||||
"""
|
||||
features, target = batch
|
||||
if torch.cuda.is_available():
|
||||
@@ -270,14 +271,18 @@ class TrainingOperator:
|
||||
target = target.cuda(non_blocking=True)
|
||||
|
||||
# compute output
|
||||
output = self.model(features)
|
||||
loss = self.criterion(output, target)
|
||||
_, predicted = torch.max(output.data, 1)
|
||||
|
||||
with self.timers.record("eval_fwd"):
|
||||
output = self.model(features)
|
||||
loss = self.criterion(output, target)
|
||||
_, predicted = torch.max(output.data, 1)
|
||||
|
||||
num_correct = (predicted == target).sum().item()
|
||||
num_samples = target.size(0)
|
||||
return {
|
||||
"loss": loss.item(),
|
||||
"num_correct": (predicted == target).sum().item(),
|
||||
"num_samples": target.size(0)
|
||||
"val_loss": loss.item(),
|
||||
"val_accuracy": num_correct / num_samples,
|
||||
NUM_SAMPLES: num_samples
|
||||
}
|
||||
|
||||
def state_dict(self):
|
||||
@@ -341,3 +346,24 @@ class _TestingOperator(TrainingOperator):
|
||||
if callable(func):
|
||||
return func(self, iterator, info)
|
||||
return {"done": 1}
|
||||
|
||||
|
||||
class _TestMetricsOperator(TrainingOperator):
|
||||
def setup(self, config):
|
||||
self._train_scores = config["scores"].copy()
|
||||
self._val_scores = config["val_scores"].copy()
|
||||
self.key = config["key"]
|
||||
|
||||
def train_batch(self, batch, batch_info=None):
|
||||
metrics = super(_TestMetricsOperator, self).train_batch(
|
||||
batch, batch_info)
|
||||
num_samples = metrics[NUM_SAMPLES]
|
||||
metrics.update({self.key: self._train_scores.pop(0) / num_samples})
|
||||
return metrics
|
||||
|
||||
def validate_batch(self, batch, batch_info=None):
|
||||
metrics = super(_TestMetricsOperator, self).validate_batch(
|
||||
batch, batch_info)
|
||||
num_samples = metrics[NUM_SAMPLES]
|
||||
metrics.update({self.key: self._val_scores.pop(0) / num_samples})
|
||||
return metrics
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from contextlib import closing
|
||||
import collections
|
||||
from contextlib import closing, contextmanager
|
||||
import logging
|
||||
import numpy as np
|
||||
import socket
|
||||
@@ -9,6 +10,10 @@ from ray.exceptions import RayActorError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BATCH_COUNT = "batch_count"
|
||||
NUM_SAMPLES = "num_samples"
|
||||
BATCH_SIZE = "*batch_size"
|
||||
|
||||
|
||||
class TimerStat:
|
||||
"""A running stat for conveniently logging the duration of a code block.
|
||||
@@ -103,6 +108,46 @@ class TimerStat:
|
||||
self.count = 0
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _nullcontext(enter_result=None):
|
||||
"""Used for mocking timer context."""
|
||||
yield enter_result
|
||||
|
||||
|
||||
class TimerCollection:
|
||||
"""A grouping of Timers."""
|
||||
|
||||
def __init__(self):
|
||||
self._timers = collections.defaultdict(TimerStat)
|
||||
self._enabled = True
|
||||
|
||||
def disable(self):
|
||||
self._enabled = False
|
||||
|
||||
def enable(self):
|
||||
self._enabled = True
|
||||
|
||||
def reset(self):
|
||||
for timer in self._timers.values():
|
||||
timer.reset()
|
||||
|
||||
def record(self, key):
|
||||
if self._enabled:
|
||||
return self._timers[key]
|
||||
else:
|
||||
return _nullcontext()
|
||||
|
||||
def stats(self, mean=True, last=False):
|
||||
aggregates = {}
|
||||
for k, t in self._timers.items():
|
||||
if t.count > 0:
|
||||
if mean:
|
||||
aggregates["mean_%s_s" % k] = t.mean
|
||||
if last:
|
||||
aggregates["last_%s_s" % k] = t.last
|
||||
return aggregates
|
||||
|
||||
|
||||
def find_free_port():
|
||||
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
|
||||
s.bind(("", 0))
|
||||
@@ -129,6 +174,29 @@ class AverageMeter:
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
class AverageMeterCollection:
|
||||
"""A grouping of AverageMeters."""
|
||||
|
||||
def __init__(self):
|
||||
self._batch_count = 0
|
||||
self.n = 0
|
||||
self._meters = collections.defaultdict(AverageMeter)
|
||||
|
||||
def update(self, metrics, n=1):
|
||||
self._batch_count += 1
|
||||
self.n += n
|
||||
for metric, value in metrics.items():
|
||||
self._meters[metric].update(value, n=n)
|
||||
|
||||
def summary(self):
|
||||
"""Returns a dict of average and most recent values for each metric."""
|
||||
stats = {BATCH_COUNT: self._batch_count, NUM_SAMPLES: self.n}
|
||||
for metric, meter in self._meters.items():
|
||||
stats["mean_" + str(metric)] = meter.avg
|
||||
stats["last_" + str(metric)] = meter.val
|
||||
return stats
|
||||
|
||||
|
||||
def check_for_failure(remote_values):
|
||||
"""Checks remote values for any that returned and failed.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user