[sgd] Make serialization of data creation optional (#8027)

* pytest

* Update python/ray/util/sgd/torch/torch_trainer.py

Co-Authored-By: Ujval Misra <misraujval@gmail.com>

Co-authored-by: Ujval Misra <misraujval@gmail.com>
This commit is contained in:
Richard Liaw
2020-04-16 20:27:51 -07:00
committed by GitHub
parent de1787e5e5
commit a9ea139317
4 changed files with 41 additions and 28 deletions
+23
View File
@@ -60,6 +60,29 @@ def test_resize(ray_start_2_cpus): # noqa: F811
assert len(results) == 2
def test_non_serialized_data(ray_start_2_cpus):  # noqa: F811
    """Check that trainer setup is not slowed to sequential data creation.

    ``data_creator`` is wrapped so that every call first sleeps for
    ``duration`` seconds. The trainer is built with two workers and
    ``serialize_data_creation=False``; the test passes only if building it
    takes less than ``2 * duration`` seconds, i.e. less than two
    back-to-back slowed calls would need.
    """
    duration = 10

    def slow_data(func):
        # Wrap ``func`` so each call is delayed by ``duration`` seconds,
        # making overlapping vs. sequential calls visible in elapsed time.
        def slowed_func(*args, **kwargs):
            time.sleep(duration)
            return func(*args, **kwargs)

        return slowed_func

    # A monotonic clock is immune to system clock adjustments, which could
    # otherwise distort the measured interval.
    start = time.monotonic()
    trainer = TorchTrainer(
        model_creator=model_creator,
        data_creator=slow_data(data_creator),
        optimizer_creator=optimizer_creator,
        serialize_data_creation=False,
        loss_creator=lambda config: nn.MSELoss(),
        num_workers=2)
    elapsed = time.monotonic() - start

    try:
        assert elapsed < duration * 2
    finally:
        # Shut the trainer down even when the timing assertion fails so its
        # workers are not left running for the tests that follow.
        trainer.shutdown()
def test_dead_trainer(ray_start_2_cpus): # noqa: F811
trainer = TorchTrainer(
model_creator=model_creator,
@@ -28,7 +28,6 @@ class DistributedTorchRunner(TorchRunner):
wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel
over each model. If False, you are expected to call it yourself.
kwargs: Keyword arguments for TorchRunner.
"""
def __init__(self,
+11 -27
View File
@@ -29,23 +29,7 @@ except ImportError:
class TorchRunner:
"""Manages a PyTorch model for training.
Args:
model_creator (dict -> Model(s)): see torch_trainer.py
data_creator (dict -> Iterable(s)): see torch_trainer.py.
optimizer_creator ((models, dict) -> optimizers): see torch_trainer.py.
loss_creator (torch.nn.*Loss class | dict -> loss):
see torch_trainer.py.
scheduler_creator ((optimizers, dict) -> scheduler): see
torch_trainer.py.
training_operator_cls: see torch_trainer.py
config (dict): see torch_trainer.py.
use_gpu (bool): see torch_trainer.py.
use_fp16 (bool): see torch_trainer.py.
apex_args (dict|None): see torch_trainer.py.
scheduler_step_freq (str): see torch_trainer.py.
"""
"""Manages a PyTorch model for training."""
def __init__(self,
model_creator,
@@ -56,6 +40,7 @@ class TorchRunner:
training_operator_cls=None,
config=None,
use_gpu=False,
serialize_data_creation=True,
use_fp16=False,
use_tqdm=False,
apex_args=None,
@@ -77,6 +62,7 @@ class TorchRunner:
self.train_loader = None
self.validation_loader = None
self.training_operator = None
self.serialize_data_creation = serialize_data_creation
self.use_gpu = use_gpu
self.use_fp16 = use_fp16
self.use_tqdm = use_tqdm
@@ -102,17 +88,15 @@ class TorchRunner:
def _initialize_dataloaders(self):
# Calls ``self.data_creator(self.config)``, validates the result with
# ``self._validate_loaders`` and stores the resulting pair on
# ``self.train_loader`` / ``self.validation_loader``.
# NOTE(review): this span looks like a diff rendering that interleaves the
# removed and the added version of the method (two FileLock statements with
# different lock file names, ``_validate_loaders`` called twice) -- confirm
# against the real file before relying on the exact statement order.
logger.debug("Instantiating dataloaders.")
# When creating loaders, a filelock will be used to ensure no
# race conditions in data downloading among different workers.
with FileLock(os.path.join(tempfile.gettempdir(), ".ray_data.lock")):
loaders = None
# The lock file lives in the system temp directory and is only taken when
# ``self.serialize_data_creation`` is truthy; otherwise ``data_creator`` is
# called directly without any lock.
if self.serialize_data_creation:
logger.debug("Serializing the dataloading process.")
with FileLock(
os.path.join(tempfile.gettempdir(), ".raydata.lock")):
loaders = self.data_creator(self.config)
else:
loaders = self.data_creator(self.config)
train_loader, val_loader = self._validate_loaders(loaders)
# Warn (but do not fail) when the train loader returned by the user's
# ``data_creator`` is not a ``torch.utils.data.DataLoader``.
if not isinstance(train_loader, torch.utils.data.DataLoader):
logger.warning(
"TorchTrainer data_creator return values are no longer "
"wrapped as DataLoaders. Users must return DataLoader(s) "
"in data_creator. This warning will be removed in "
"a future version of Ray.")
train_loader, val_loader = self._validate_loaders(loaders)
self.train_loader, self.validation_loader = train_loader, val_loader
@@ -131,6 +131,10 @@ class TorchTrainer:
support "nccl", "gloo", and "auto". If "auto", RaySGD will
automatically use "nccl" if `use_gpu` is True, and "gloo"
otherwise.
serialize_data_creation (bool): If True, a filelock is used so that
workers on the same node (sharing the local file system) call
`data_creator` one at a time, avoiding race conditions when
downloading data. If False, no lock is taken. Defaults to True.
wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel
over each model. If False, you are expected to call it yourself.
add_dist_sampler (bool): Whether to automatically add a
@@ -171,6 +175,7 @@ class TorchTrainer:
use_gpu="auto",
backend="auto",
wrap_ddp=True,
serialize_data_creation=True,
use_fp16=False,
use_tqdm=False,
apex_args=None,
@@ -237,6 +242,7 @@ class TorchTrainer:
self.use_gpu = use_gpu
self.max_replicas = num_workers
self.serialize_data_creation = serialize_data_creation
self.wrap_ddp = wrap_ddp
self.use_fp16 = use_fp16
self.use_tqdm = use_tqdm
@@ -298,6 +304,7 @@ class TorchTrainer:
scheduler_creator=self.scheduler_creator,
training_operator_cls=self.training_operator_cls,
config=worker_config,
serialize_data_creation=self.serialize_data_creation,
use_fp16=self.use_fp16,
use_gpu=self.use_gpu,
use_tqdm=self.use_tqdm,