mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 00:29:38 +08:00
[SGD] Dataset API (#7839)
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
from ray.util.sgd.data.dataset import Dataset
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ["Dataset"]
|
||||
@@ -0,0 +1,92 @@
|
||||
from ray.util.iter import ParallelIterator, from_iterators
|
||||
|
||||
|
||||
class Dataset():
    """A simple Dataset abstraction for RaySGD.

    This dataset is designed to work with RaySGD trainers (currently just
    Torch) to provide support for streaming large external datasets, and built
    in sharding.

    .. code-block:: python

        def to_mat(x):
            return torch.tensor([[x]]).float()

        data = [i * 0.001 for i in range(1000)]
        p_iter = iter.from_items(data, num_shards=1, repeat=True)
        dataset = Dataset(
            p_iter,
            batch_size=32,
            max_concurrency=1,
            download_func=lambda x: (to_mat(x), to_mat(x)))

        trainer = TorchTrainer(
            model_creator=model_creator,
            data_creator=None,
            optimizer_creator=optimizer_creator,
            loss_creator=torch.nn.MSELoss,
            num_workers=5,
        )

        for i in range(10):
            # Train for another epoch using the dataset
            trainer.train(dataset=dataset, num_steps=200)

        model = trainer.get_model()
        print("f(0.5)=", float(model(to_mat(0.5))[0][0]))

    Args:
        data (iterable[U] or ParallelIterator[U]): Any existing python
            iterable (or iterator), or an existing parallel iterator
            to use. A ParallelIterator is repartitioned to a single shard;
            anything else is wrapped with ``from_iterators([data],
            repeat=True)``.
        batch_size (int): The batch size for training/inference (default 32).
        download_func (U -> (S, Y)): A function which returns two values, the
            input and the label. When omitted (or falsy) no function is
            applied, i.e. items pass through unchanged.
        max_concurrency (int): The maximum number of concurrent calls to the
            download function (default 0). See ParallelIterator::for_each
            for details.
        transform (S -> X): A final transformation to be applied to the _input
            only_. This class only stores it on ``self.transform``; it is
            intended to run on the same worker that training will occur on.
    """

    def __init__(self,
                 data,
                 batch_size=32,
                 download_func=None,
                 max_concurrency=0,
                 transform=None):
        if isinstance(data, ParallelIterator):
            # Start from one shard; set_num_shards() reshards later.
            par_iter = data.repartition(1)
        else:
            par_iter = from_iterators([data], repeat=True)

        if download_func:
            par_iter = par_iter.for_each(
                download_func, max_concurrency=max_concurrency)

        self.iter = par_iter.batch(batch_size)

        self.batch_size = batch_size
        self.max_concurrency = max_concurrency
        self.transform = transform

    def set_num_shards(self, num_shards):
        """Reshards the iterator if necessary.

        Args:
            num_shards (int): Desired number of shards. Nothing happens when
                the underlying iterator already has this many shards.
        """
        if num_shards != self.iter.num_shards():
            print("Setting num shards", num_shards)
            self.iter = self.iter.repartition(num_shards)

    def get_shard(self, i):
        """Returns a single, iterable shard.

        Args:
            i (int): Index of the shard to fetch; must be smaller than the
                current number of shards.

        Raises:
            AssertionError: If ``i`` is not smaller than the number of
                shards.
        """
        num_shards = self.iter.num_shards()
        # Adjacent literals are joined before ``.format`` is applied, so both
        # placeholders are filled (``"a" + "b".format(...)`` would only
        # format the second literal).
        assert i < num_shards, (
            "Trying to get shard {} but there are only {} shards. "
            "Are you sure you called set_num_shards already?".format(
                i, num_shards))

        return self.iter.get_shard(i)
|
||||
@@ -0,0 +1 @@
|
||||
ray-project/*
|
||||
@@ -0,0 +1,69 @@
|
||||
import ray
|
||||
from ray.util.sgd.torch.torch_trainer import TorchTrainer
|
||||
from ray.util.sgd.data.dataset import Dataset
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Net(nn.Module):
    """Small MLP: Linear(1 -> 128), ReLU, Linear(128 -> 1)."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """Run ``x`` through the hidden layer, ReLU, and the output layer."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)
|
||||
|
||||
|
||||
def model_creator(config):
    """Return a fresh ``Net``; ``config`` is accepted but not used."""
    net = Net()
    return net
|
||||
|
||||
|
||||
def optimizer_creator(model, config):
    """Create an SGD optimizer over ``model``'s parameters.

    The learning rate is read from ``config["lr"]`` and falls back to 1e-4
    when that key is absent.
    """
    learning_rate = config.get("lr", 1e-4)
    return torch.optim.SGD(model.parameters(), lr=learning_rate)
|
||||
|
||||
|
||||
def to_mat(x):
    """Wrap the scalar ``x`` in a 1x1 float tensor."""
    nested = [[x]]
    return torch.tensor(nested).float()
|
||||
|
||||
|
||||
def dataset_creator():
    """Build the example ``Dataset`` of evenly spaced points.

    Each raw point ``x`` is turned into the pair ``(to_mat(x), to_mat(x))``,
    so the feature and the label are the same value.
    """
    num_points = 32 * 100 * 2
    step = 1 / num_points
    points = [index * step for index in range(num_points)]

    def as_identity_pair(x):
        return (to_mat(x), to_mat(x))

    return Dataset(
        points,
        batch_size=32,
        max_concurrency=2,
        download_func=as_identity_pair)
|
||||
|
||||
|
||||
def main():
    """Train the example model on the example dataset and print f(0.5)."""
    dataset = dataset_creator()

    # No data_creator is given; the data comes from ``dataset`` passed to
    # each ``train`` call below.
    trainer_kwargs = dict(
        model_creator=model_creator,
        data_creator=None,
        optimizer_creator=optimizer_creator,
        loss_creator=torch.nn.MSELoss,
        num_workers=2,
    )
    trainer = TorchTrainer(**trainer_kwargs)

    # Ten rounds of 100 steps each, all drawing from the same dataset.
    for _ in range(10):
        trainer.train(dataset=dataset, num_steps=100)

    model = trainer.get_model()
    prediction = float(model(to_mat(0.5))[0][0])
    print("f(0.5)=", prediction)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Initialize Ray with default settings before running the example.
    ray.init()
    main()
|
||||
@@ -17,6 +17,7 @@ from ray.util.sgd.torch.constants import SCHEDULER_STEP
|
||||
from ray.util.sgd.utils import (check_for_failure, NUM_SAMPLES, BATCH_COUNT,
|
||||
BATCH_SIZE)
|
||||
|
||||
from ray.util.sgd.data.examples import mlp_identity
|
||||
from ray.util.sgd.torch.examples.train_example import (
|
||||
model_creator, optimizer_creator, data_creator, LinearDataset)
|
||||
|
||||
@@ -32,6 +33,17 @@ def ray_start_2_cpus():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
@pytest.fixture
def ray_start_4_cpus():
    """Fixture that starts Ray with 4 CPUs and tears it down afterwards."""
    info = ray.init(num_cpus=4)
    yield info
    # Everything below runs as teardown once the test has finished.
    ray.shutdown()
    # Destroy any leftover process group so that tests don't ALL fail.
    if dist.is_initialized():
        dist.destroy_process_group()
|
||||
|
||||
|
||||
def test_single_step(ray_start_2_cpus): # noqa: F811
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
@@ -120,12 +132,12 @@ def test_train(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
|
||||
@pytest.mark.parametrize("num_workers", [1, 2] if dist.is_available() else [1])
|
||||
def test_multi_model(ray_start_2_cpus, num_workers):
|
||||
def train(*, model=None, criterion=None, optimizer=None, dataloader=None):
|
||||
def train(*, model=None, criterion=None, optimizer=None, iterator=None):
|
||||
model.train()
|
||||
train_loss = 0
|
||||
correct = 0
|
||||
total = 0
|
||||
for batch_idx, (inputs, targets) in enumerate(dataloader):
|
||||
for batch_idx, (inputs, targets) in enumerate(iterator):
|
||||
optimizer.zero_grad()
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, targets)
|
||||
@@ -143,13 +155,14 @@ def test_multi_model(ray_start_2_cpus, num_workers):
|
||||
|
||||
def train_epoch(self, iterator, info):
|
||||
result = {}
|
||||
data = list(iterator)
|
||||
for i, (model, optimizer) in enumerate(
|
||||
zip(self.models, self.optimizers)):
|
||||
result["model_{}".format(i)] = train(
|
||||
model=model,
|
||||
criterion=self.criterion,
|
||||
optimizer=optimizer,
|
||||
dataloader=iterator)
|
||||
iterator=iter(data))
|
||||
return result
|
||||
|
||||
def multi_model_creator(config):
|
||||
@@ -310,6 +323,35 @@ def test_profiling(ray_start_2_cpus): # noqa: F811
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
def test_dataset(ray_start_4_cpus):
    """
    Train the mlp_identity example and check the model's accuracy.

    Accuracy is used as an all inclusive way of ensuring that we are properly
    sharding and iterating over the entire dataset (instead of repeating the
    first set of points for example).
    """
    trainer = TorchTrainer(
        model_creator=mlp_identity.model_creator,
        data_creator=None,
        optimizer_creator=mlp_identity.optimizer_creator,
        loss_creator=torch.nn.MSELoss,
        num_workers=2,
    )

    dataset = mlp_identity.dataset_creator()
    for _ in range(5):
        trainer.train(dataset=dataset, num_steps=100)

    # The example learns the identity function, so f(0.5) should be near 0.5.
    sample = mlp_identity.to_mat(0.5)
    prediction = float(trainer.get_model()(sample)[0][0])
    assert 0.4 <= prediction <= 0.6
    trainer.shutdown()
|
||||
|
||||
|
||||
def test_split_batch(ray_start_2_cpus):
|
||||
if not dist.is_available():
|
||||
return
|
||||
@@ -591,20 +633,18 @@ def test_wrap_ddp(ray_start_2_cpus, tmp_path): # noqa: F811
|
||||
trainer2.shutdown()
|
||||
|
||||
|
||||
def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
|
||||
if not dist.is_available():
|
||||
return
|
||||
|
||||
def single_loader(config):
|
||||
dataset = LinearDataset(2, 5, size=1000000)
|
||||
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
|
||||
|
||||
def step_with_fail(self, **params):
|
||||
def gen_step_with_fail(num_fails):
|
||||
def step_with_fail(self,
|
||||
num_steps=None,
|
||||
profile=False,
|
||||
info=None,
|
||||
dataset=None):
|
||||
params = dict(num_steps=num_steps, profile=profile, info=info)
|
||||
remote_worker_stats = [
|
||||
w.train_epoch.remote(**params) for w in self.remote_workers
|
||||
]
|
||||
|
||||
if self._num_failures < 3:
|
||||
if self._num_failures < num_fails:
|
||||
time.sleep(1) # Make the batch will fail correctly.
|
||||
ray.kill(self.remote_workers[0])
|
||||
|
||||
@@ -619,6 +659,19 @@ def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
|
||||
|
||||
return success, None
|
||||
|
||||
return step_with_fail
|
||||
|
||||
|
||||
def test_fail_with_recover(ray_start_2_cpus): # noqa: F811
|
||||
if not dist.is_available():
|
||||
return
|
||||
|
||||
def single_loader(config):
|
||||
dataset = LinearDataset(2, 5, size=1000000)
|
||||
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
|
||||
|
||||
step_with_fail = gen_step_with_fail(3)
|
||||
|
||||
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
|
||||
trainer1 = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
@@ -642,25 +695,7 @@ def test_resize(ray_start_2_cpus): # noqa: F811
|
||||
dataset = LinearDataset(2, 5, size=1000000)
|
||||
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
|
||||
|
||||
def step_with_fail(self, **params):
|
||||
remote_worker_stats = [
|
||||
w.train_epoch.remote(**params) for w in self.remote_workers
|
||||
]
|
||||
|
||||
if self._num_failures < 1:
|
||||
time.sleep(1) # Make the batch will fail correctly.
|
||||
ray.kill(self.remote_workers[0])
|
||||
|
||||
try:
|
||||
local_worker_stats = self.local_worker.train_epoch(**params)
|
||||
except RuntimeError:
|
||||
return False, None
|
||||
|
||||
success = check_for_failure(remote_worker_stats)
|
||||
if success:
|
||||
return success, [local_worker_stats] + ray.get(remote_worker_stats)
|
||||
|
||||
return success, None
|
||||
step_with_fail = gen_step_with_fail(1)
|
||||
|
||||
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
|
||||
trainer1 = TorchTrainer(
|
||||
@@ -691,25 +726,7 @@ def test_fail_twice(ray_start_2_cpus): # noqa: F811
|
||||
dataset = LinearDataset(2, 5, size=1000000)
|
||||
return DataLoader(dataset, batch_size=config.get("batch_size", 32))
|
||||
|
||||
def step_with_fail(self, **params):
|
||||
remote_worker_stats = [
|
||||
w.train_epoch.remote(**params) for w in self.remote_workers
|
||||
]
|
||||
|
||||
if self._num_failures < 2:
|
||||
time.sleep(1) # Make the batch will fail correctly.
|
||||
ray.kill(self.remote_workers[0])
|
||||
|
||||
try:
|
||||
local_worker_stats = self.local_worker.train_epoch(**params)
|
||||
except RuntimeError:
|
||||
return False, None
|
||||
|
||||
success = check_for_failure(remote_worker_stats)
|
||||
if success:
|
||||
return success, [local_worker_stats] + ray.get(remote_worker_stats)
|
||||
|
||||
return success, None
|
||||
step_with_fail = gen_step_with_fail(2)
|
||||
|
||||
with patch.object(TorchTrainer, "_train_epoch", step_with_fail):
|
||||
trainer1 = TorchTrainer(
|
||||
|
||||
@@ -138,7 +138,9 @@ class TorchRunner:
|
||||
def setup_components(self):
|
||||
"""Runs the creator functions without any distributed coordination."""
|
||||
logger.debug("Loading data.")
|
||||
self._initialize_dataloaders()
|
||||
if self.data_creator and callable(self.data_creator):
|
||||
self._initialize_dataloaders()
|
||||
|
||||
logger.debug("Creating model")
|
||||
self.models = self.model_creator(self.config)
|
||||
if not isinstance(self.models, Iterable):
|
||||
@@ -181,7 +183,11 @@ class TorchRunner:
|
||||
"""Finds a free port on the current node."""
|
||||
return utils.find_free_port()
|
||||
|
||||
def train_epoch(self, num_steps=None, profile=False, info=None):
|
||||
def train_epoch(self,
|
||||
num_steps=None,
|
||||
profile=False,
|
||||
info=None,
|
||||
iterator=None):
|
||||
"""Runs a training epoch and updates the model parameters."""
|
||||
logger.debug("Begin Training Step {}".format(self.epochs + 1))
|
||||
info = info or {}
|
||||
@@ -193,9 +199,18 @@ class TorchRunner:
|
||||
SCHEDULER_STEP: self.scheduler_step_freq
|
||||
})
|
||||
with self.timers.record("train_epoch"):
|
||||
iterator = self.train_loader
|
||||
if iterator is None:
|
||||
iterator = iter(self.train_loader)
|
||||
else:
|
||||
# Dataset will provide us with a list of tuples but we
|
||||
# need two lists.
|
||||
def format_batch(batch):
|
||||
features, targets = zip(*batch)
|
||||
return torch.cat(features), torch.cat(targets)
|
||||
|
||||
iterator = map(format_batch, iterator)
|
||||
if num_steps:
|
||||
iterator = itertools.islice(iter(self.train_loader), num_steps)
|
||||
iterator = itertools.islice(iterator, num_steps)
|
||||
train_stats = self.training_operator.train_epoch(iterator, info)
|
||||
|
||||
self.epochs += 1
|
||||
|
||||
@@ -17,6 +17,7 @@ from ray.util.sgd.torch.distributed_torch_runner import (
|
||||
from ray.util.sgd.utils import check_for_failure, NUM_SAMPLES, BATCH_SIZE
|
||||
from ray.util.sgd.torch.torch_runner import TorchRunner
|
||||
from ray.util.sgd.torch.constants import VALID_SCHEDULER_STEP
|
||||
from ray.util.sgd.data import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
RESIZE_COOLDOWN_S = 10
|
||||
@@ -194,11 +195,9 @@ class TorchTrainer:
|
||||
"For more information, see "
|
||||
"https://github.com/pytorch/examples/issues/467."))
|
||||
|
||||
if not (callable(model_creator) and callable(optimizer_creator)
|
||||
and callable(data_creator)):
|
||||
if not (callable(model_creator) and callable(optimizer_creator)):
|
||||
raise ValueError(
|
||||
"Must provide a callable model_creator, optimizer_creator, "
|
||||
"and data_creator.")
|
||||
"Must provide a callable model_creator and optimizer_creator.")
|
||||
|
||||
if num_replicas is not None:
|
||||
raise DeprecationWarning(
|
||||
@@ -379,7 +378,8 @@ class TorchTrainer:
|
||||
profile=False,
|
||||
reduce_results=True,
|
||||
max_retries=3,
|
||||
info=None):
|
||||
info=None,
|
||||
dataset=None):
|
||||
"""Runs a training epoch.
|
||||
|
||||
Calls `operator.train_epoch()` on N parallel workers simultaneously
|
||||
@@ -405,6 +405,8 @@ class TorchTrainer:
|
||||
in case of shared cluster usage. Defaults to 3.
|
||||
info (dict): Optional dictionary passed to the training
|
||||
operator for ``train_epoch`` and ``train_batch``.
|
||||
dataset (Dataset): Optional dataset to train with. If specified,
|
||||
the dataloader passed in via data_creator will be ignored.
|
||||
|
||||
Returns:
|
||||
(dict | list) A dictionary of metrics for training.
|
||||
@@ -414,11 +416,14 @@ class TorchTrainer:
|
||||
length will be equal to ``num_workers``.
|
||||
"""
|
||||
assert max_retries >= 0, "`max_retries` must be non-negative."
|
||||
assert isinstance(dataset, Dataset) is not None \
|
||||
or self.data_creator, \
|
||||
"Must specify either a data creator or a dataset"
|
||||
if self._should_resize():
|
||||
logger.info("Resize opportunity detected. Attempting to scale up.")
|
||||
self._resize_workers()
|
||||
success, worker_stats = self._train_epoch(
|
||||
num_steps=num_steps, profile=profile, info=info)
|
||||
num_steps=num_steps, profile=profile, info=info, dataset=dataset)
|
||||
# Fault handling
|
||||
for i in range(max_retries):
|
||||
if success:
|
||||
@@ -429,7 +434,10 @@ class TorchTrainer:
|
||||
logger.info("Retrying training step with %d workers." %
|
||||
(len(self.remote_workers) + 1))
|
||||
success, worker_stats = self._train_epoch(
|
||||
num_steps=num_steps, profile=profile, info=info)
|
||||
num_steps=num_steps,
|
||||
profile=profile,
|
||||
info=info,
|
||||
dataset=dataset)
|
||||
if not success:
|
||||
raise RuntimeError("Training run failed.")
|
||||
|
||||
@@ -452,14 +460,26 @@ class TorchTrainer:
|
||||
stats[stat_key] = worker_stats[0][stat_key]
|
||||
return stats
|
||||
|
||||
def _train_epoch(self, num_steps=None, profile=False, info=None):
|
||||
def _train_epoch(self,
|
||||
num_steps=None,
|
||||
profile=False,
|
||||
info=None,
|
||||
dataset=None):
|
||||
params = dict(num_steps=num_steps, profile=profile, info=info)
|
||||
|
||||
remote_worker_stats = [
|
||||
w.train_epoch.remote(**params) for w in self.remote_workers
|
||||
]
|
||||
remote_worker_stats = []
|
||||
if dataset:
|
||||
dataset.set_num_shards(self.max_replicas)
|
||||
for i, w in enumerate(self.remote_workers):
|
||||
params = dict(num_steps=num_steps, profile=profile, info=info)
|
||||
if dataset:
|
||||
params["iterator"] = dataset.get_shard(i)
|
||||
stats = w.train_epoch.remote(**params)
|
||||
remote_worker_stats.append(stats)
|
||||
|
||||
try:
|
||||
if dataset:
|
||||
params["iterator"] = dataset.get_shard(
|
||||
len(self.remote_workers))
|
||||
local_worker_stats = self.local_worker.train_epoch(**params)
|
||||
except RuntimeError as err:
|
||||
if "gloo" in err.args[0] and "Timed out" in err.args[0]:
|
||||
|
||||
Reference in New Issue
Block a user