[SGD] Dataset API (#7839)

This commit is contained in:
Alex Wu
2020-06-01 15:48:15 -07:00
committed by GitHub
parent 21d5b49c56
commit dcf58a43dc
11 changed files with 344 additions and 67 deletions
+5
View File
@@ -0,0 +1,5 @@
from ray.util.sgd.data.dataset import Dataset
import logging
logger = logging.getLogger(__name__)
__all__ = ["Dataset"]
+92
View File
@@ -0,0 +1,92 @@
from ray.util.iter import ParallelIterator, from_iterators
class Dataset():
"""A simple Dataset abstraction for RaySGD.
This dataset is designed to work with RaySGD trainers (currently just
Torch) to provide support for streaming large external datasets, and built
in sharding.
.. code-block:: python
def to_mat(x):
return torch.tensor([[x]]).float()
data = [i * 0.001 for i in range(1000)]
p_iter = iter.from_items(data, num_shards=1, repeat=True)
dataset = Dataset(
p_iter,
batch_size=32,
max_concurrency=1,
download_func=lambda x: (to_mat(x), to_mat(x)))
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=None,
optimizer_creator=optimizer_creator,
loss_creator=torch.nn.MSELoss,
num_workers=5,
)
for i in range(10):
# Train for another epoch using the dataset
trainer.train(dataset=dataset, num_steps=200)
model = trainer.get_model()
print("f(0.5)=", float(model(to_mat(0.5))[0][0]))
Args:
data (iterable[U] or ParallelIterator[U]): Any existing python
iterable (or iterator), or an existing parallel iterator
to use.
batch_size (int): The batch size for training/inference (default 32).
download_func (U -> (S, Y)): A function which returns two values, the
input and the label (default is the identity function).
max_concurrency (int): The maximum number of concurrent calls to the
download function. See ParallelIterator::for_each for details.
transform (S -> X): A final transformation to be applied to the _input
only_. This is guaranteed to run on the same worker that training
will occur on.
"""
def __init__(self,
data,
batch_size=32,
download_func=None,
max_concurrency=0,
transform=None):
par_iter = None
if isinstance(data, ParallelIterator):
par_iter = data.repartition(1)
else:
par_iter = from_iterators([data], repeat=True)
if download_func:
par_iter = par_iter.for_each(
download_func, max_concurrency=max_concurrency)
self.iter = par_iter.batch(batch_size)
self.batch_size = batch_size
self.max_concurrency = max_concurrency
self.transform = transform
def set_num_shards(self, num_shards):
"""
Reshards the iterator if necessary.
"""
if num_shards != self.iter.num_shards():
print("Setting num shards", num_shards)
self.iter = self.iter.repartition(num_shards)
def get_shard(self, i):
"""
Returns a single, iterable shard.
"""
assert i < self.iter.num_shards(), \
"Trying to get shard {} but there are only {} shards." + \
"Are you sure you called set_num_shards already?".format(
i, self.iter.num_shards()
)
return self.iter.get_shard(i)
@@ -0,0 +1 @@
ray-project/*
@@ -0,0 +1,69 @@
import ray
from ray.util.sgd.torch.torch_trainer import TorchTrainer
from ray.util.sgd.data.dataset import Dataset
import torch
from torch import nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(1, 128)
self.fc2 = nn.Linear(128, 1)
def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
return x
def model_creator(config):
return Net()
def optimizer_creator(model, config):
return torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-4))
def to_mat(x):
return torch.tensor([[x]]).float()
def dataset_creator():
num_points = 32 * 100 * 2
data = [i * (1 / num_points) for i in range(num_points)]
dataset = Dataset(
data,
batch_size=32,
max_concurrency=2,
download_func=lambda x: (to_mat(x), to_mat(x)))
return dataset
def main():
dataset = dataset_creator()
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=None,
optimizer_creator=optimizer_creator,
loss_creator=torch.nn.MSELoss,
num_workers=2,
)
for i in range(10):
# Train a full epoch using the data_creator
# trainer.train()
# Train for another epoch using the dataset
trainer.train(dataset=dataset, num_steps=100)
model = trainer.get_model()
print("f(0.5)=", float(model(to_mat(0.5))[0][0]))
if __name__ == "__main__":
ray.init()
main()
+68 -51
View File
@@ -17,6 +17,7 @@ from ray.util.sgd.torch.constants import SCHEDULER_STEP
from ray.util.sgd.utils import (check_for_failure, NUM_SAMPLES, BATCH_COUNT,
BATCH_SIZE)
from ray.util.sgd.data.examples import mlp_identity
from ray.util.sgd.torch.examples.train_example import (
model_creator, optimizer_creator, data_creator, LinearDataset)
@@ -32,6 +33,17 @@ def ray_start_2_cpus():
dist.destroy_process_group()
@pytest.fixture
def ray_start_4_cpus():
address_info = ray.init(num_cpus=4)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()
# Ensure that tests don't ALL fail
if dist.is_initialized():
dist.destroy_process_group()
def test_single_step(ray_start_2_cpus): # noqa: F811
trainer = TorchTrainer(
model_creator=model_creator,
@@ -120,12 +132,12 @@ def test_train(ray_start_2_cpus, num_workers): # noqa: F811
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
def test_multi_model(ray_start_2_cpus, num_workers):
def train(*, model=None, criterion=None, optimizer=None, dataloader=None):
def train(*, model=None, criterion=None, optimizer=None, iterator=None):
model.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(dataloader):
for batch_idx, (inputs, targets) in enumerate(iterator):
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
@@ -143,13 +155,14 @@ def test_multi_model(ray_start_2_cpus, num_workers):
def train_epoch(self, iterator, info):
result = {}
data = list(iterator)
for i, (model, optimizer) in enumerate(
zip(self.models, self.optimizers)):
result["model_{}".format(i)] = train(
model=model,
criterion=self.criterion,
optimizer=optimizer,
dataloader=iterator)
iterator=iter(data))
return result
def multi_model_creator(config):
@@ -310,6 +323,35 @@ def test_profiling(ray_start_2_cpus): # noqa: F811
trainer.shutdown()
def test_dataset(ray_start_4_cpus):
"""
This test tries training the mlp_identity example. We check the accuracy of
the model as an all inclusive way of ensuring that we are properly sharding
and iterating over the entire dataset (instead of repeating the first set
of points for example).
"""
model_creator = mlp_identity.model_creator
optimizer_creator = mlp_identity.optimizer_creator
dataset_creator = mlp_identity.dataset_creator
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=None,
optimizer_creator=optimizer_creator,
loss_creator=torch.nn.MSELoss,
num_workers=2,
)
dataset = dataset_creator()
for i in range(5):
trainer.train(dataset=dataset, num_steps=100)
input = mlp_identity.to_mat(0.5)
prediction = float(trainer.get_model()(input)[0][0])
assert 0.4 <= prediction <= 0.6
trainer.shutdown()
def test_split_batch(ray_start_2_cpus):
if not dist.is_available():
return
@@ -591,20 +633,18 @@ def test_wrap_ddp(ray_start_2_cpus, tmp_path): # noqa: F811
trainer2.shutdown()
def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
if not dist.is_available():
return
def single_loader(config):
dataset = LinearDataset(2, 5, size=1000000)
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
def step_with_fail(self, **params):
def gen_step_with_fail(num_fails):
def step_with_fail(self,
num_steps=None,
profile=False,
info=None,
dataset=None):
params = dict(num_steps=num_steps, profile=profile, info=info)
remote_worker_stats = [
w.train_epoch.remote(**params) for w in self.remote_workers
]
if self._num_failures < 3:
if self._num_failures < num_fails:
time.sleep(1) # Make the batch will fail correctly.
ray.kill(self.remote_workers[0])
@@ -619,6 +659,19 @@ def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
return success, None
return step_with_fail
def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
if not dist.is_available():
return
def single_loader(config):
dataset = LinearDataset(2, 5, size=1000000)
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
step_with_fail = gen_step_with_fail(3)
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
trainer1 = TorchTrainer(
model_creator=model_creator,
@@ -642,25 +695,7 @@ def test_resize(ray_start_2_cpus): # noqa: F811
dataset = LinearDataset(2, 5, size=1000000)
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
def step_with_fail(self, **params):
remote_worker_stats = [
w.train_epoch.remote(**params) for w in self.remote_workers
]
if self._num_failures < 1:
time.sleep(1) # Make the batch will fail correctly.
ray.kill(self.remote_workers[0])
try:
local_worker_stats = self.local_worker.train_epoch(**params)
except RuntimeError:
return False, None
success = check_for_failure(remote_worker_stats)
if success:
return success, [local_worker_stats] + ray.get(remote_worker_stats)
return success, None
step_with_fail = gen_step_with_fail(1)
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
trainer1 = TorchTrainer(
@@ -691,25 +726,7 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811
dataset = LinearDataset(2, 5, size=1000000)
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
def step_with_fail(self, **params):
remote_worker_stats = [
w.train_epoch.remote(**params) for w in self.remote_workers
]
if self._num_failures < 2:
time.sleep(1) # Make the batch will fail correctly.
ray.kill(self.remote_workers[0])
try:
local_worker_stats = self.local_worker.train_epoch(**params)
except RuntimeError:
return False, None
success = check_for_failure(remote_worker_stats)
if success:
return success, [local_worker_stats] + ray.get(remote_worker_stats)
return success, None
step_with_fail = gen_step_with_fail(2)
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
trainer1 = TorchTrainer(
+19 -4
View File
@@ -138,7 +138,9 @@ class TorchRunner:
def setup_components(self):
"""Runs the creator functions without any distributed coordination."""
logger.debug("Loading data.")
self._initialize_dataloaders()
if self.data_creator and callable(self.data_creator):
self._initialize_dataloaders()
logger.debug("Creating model")
self.models = self.model_creator(self.config)
if not isinstance(self.models, Iterable):
@@ -181,7 +183,11 @@ class TorchRunner:
"""Finds a free port on the current node."""
return utils.find_free_port()
def train_epoch(self, num_steps=None, profile=False, info=None):
def train_epoch(self,
num_steps=None,
profile=False,
info=None,
iterator=None):
"""Runs a training epoch and updates the model parameters."""
logger.debug("Begin Training Step {}".format(self.epochs + 1))
info = info or {}
@@ -193,9 +199,18 @@ class TorchRunner:
SCHEDULER_STEP: self.scheduler_step_freq
})
with self.timers.record("train_epoch"):
iterator = self.train_loader
if iterator is None:
iterator = iter(self.train_loader)
else:
# Dataset will provide us with a list of tuples but we
# need two lists.
def format_batch(batch):
features, targets = zip(*batch)
return torch.cat(features), torch.cat(targets)
iterator = map(format_batch, iterator)
if num_steps:
iterator = itertools.islice(iter(self.train_loader), num_steps)
iterator = itertools.islice(iterator, num_steps)
train_stats = self.training_operator.train_epoch(iterator, info)
self.epochs += 1
+32 -12
View File
@@ -17,6 +17,7 @@ from ray.util.sgd.torch.distributed_torch_runner import (
from ray.util.sgd.utils import check_for_failure, NUM_SAMPLES, BATCH_SIZE
from ray.util.sgd.torch.torch_runner import TorchRunner
from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP
from ray.util.sgd.data import Dataset
logger = logging.getLogger(__name__)
RESIZE_COOLDOWN_S = 10
@@ -194,11 +195,9 @@ class TorchTrainer:
"For more information, see "
"https://github.com/pytorch/examples/issues/467."))
if not (callable(model_creator) and callable(optimizer_creator)
and callable(data_creator)):
if not (callable(model_creator) and callable(optimizer_creator)):
raise ValueError(
"Must provide a callable model_creator, optimizer_creator, "
"and data_creator.")
"Must provide a callable model_creator and optimizer_creator.")
if num_replicas is not None:
raise DeprecationWarning(
@@ -379,7 +378,8 @@ class TorchTrainer:
profile=False,
reduce_results=True,
max_retries=3,
info=None):
info=None,
dataset=None):
"""Runs a training epoch.
Calls `operator.train_epoch()` on N parallel workers simultaneously
@@ -405,6 +405,8 @@ class TorchTrainer:
in case of shared cluster usage. Defaults to 3.
info (dict): Optional dictionary passed to the training
operator for ``train_epoch`` and ``train_batch``.
dataset (Dataset): Optional dataset to train with. If specified,
the dataloader passed in via data_creator will be ignored.
Returns:
(dict | list) A dictionary of metrics for training.
@@ -414,11 +416,14 @@ class TorchTrainer:
length will be equal to ``num_workers``.
"""
assert max_retries >= 0, "`max_retries` must be non-negative."
assert isinstance(dataset, Dataset) is not None \
or self.data_creator, \
"Must specify either a data creator or a dataset"
if self._should_resize():
logger.info("Resize opportunity detected. Attempting to scale up.")
self._resize_workers()
success, worker_stats = self._train_epoch(
num_steps=num_steps, profile=profile, info=info)
num_steps=num_steps, profile=profile, info=info, dataset=dataset)
# Fault handling
for i in range(max_retries):
if success:
@@ -429,7 +434,10 @@ class TorchTrainer:
logger.info("Retrying training step with %d workers." %
(len(self.remote_workers) + 1))
success, worker_stats = self._train_epoch(
num_steps=num_steps, profile=profile, info=info)
num_steps=num_steps,
profile=profile,
info=info,
dataset=dataset)
if not success:
raise RuntimeError("Training run failed.")
@@ -452,14 +460,26 @@ class TorchTrainer:
stats[stat_key] = worker_stats[0][stat_key]
return stats
def _train_epoch(self, num_steps=None, profile=False, info=None):
def _train_epoch(self,
num_steps=None,
profile=False,
info=None,
dataset=None):
params = dict(num_steps=num_steps, profile=profile, info=info)
remote_worker_stats = [
w.train_epoch.remote(**params) for w in self.remote_workers
]
remote_worker_stats = []
if dataset:
dataset.set_num_shards(self.max_replicas)
for i, w in enumerate(self.remote_workers):
params = dict(num_steps=num_steps, profile=profile, info=info)
if dataset:
params["iterator"] = dataset.get_shard(i)
stats = w.train_epoch.remote(**params)
remote_worker_stats.append(stats)
try:
if dataset:
params["iterator"] = dataset.get_shard(
len(self.remote_workers))
local_worker_stats = self.local_worker.train_epoch(**params)
except RuntimeError as err:
if "gloo" in err.args[0] and "Timed out" in err.args[0]: