diff --git a/python/ray/util/sgd/tests/test_torch.py b/python/ray/util/sgd/tests/test_torch.py index b36aeb70c..8d67892f4 100644 --- a/python/ray/util/sgd/tests/test_torch.py +++ b/python/ray/util/sgd/tests/test_torch.py @@ -60,6 +60,29 @@ def test_resize(ray_start_2_cpus): # noqa: F811 assert len(results) == 2 +def test_non_serialized_data(ray_start_2_cpus): # noqa: F811 + duration = 10 + + def slow_data(func): + def slowed_func(*args, **kwargs): + time.sleep(duration) + return func(*args, **kwargs) + + return slowed_func + + start = time.time() + trainer = TorchTrainer( + model_creator=model_creator, + data_creator=slow_data(data_creator), + optimizer_creator=optimizer_creator, + serialize_data_creation=False, + loss_creator=lambda config: nn.MSELoss(), + num_workers=2) + elapsed = time.time() - start + assert elapsed < duration * 2 + trainer.shutdown() + + def test_dead_trainer(ray_start_2_cpus): # noqa: F811 trainer = TorchTrainer( model_creator=model_creator, diff --git a/python/ray/util/sgd/torch/distributed_torch_runner.py b/python/ray/util/sgd/torch/distributed_torch_runner.py index b7d410cfb..abf2f0d0e 100644 --- a/python/ray/util/sgd/torch/distributed_torch_runner.py +++ b/python/ray/util/sgd/torch/distributed_torch_runner.py @@ -28,7 +28,6 @@ class DistributedTorchRunner(TorchRunner): wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel over each model. If False, you are expected to call it yourself. kwargs: Keyword arguments for TorchRunner. - """ def __init__(self, diff --git a/python/ray/util/sgd/torch/torch_runner.py b/python/ray/util/sgd/torch/torch_runner.py index f37ff6e93..2ce819d3e 100644 --- a/python/ray/util/sgd/torch/torch_runner.py +++ b/python/ray/util/sgd/torch/torch_runner.py @@ -29,23 +29,7 @@ except ImportError: class TorchRunner: - """Manages a PyTorch model for training. - - Args: - model_creator (dict -> Model(s)): see torch_trainer.py - data_creator (dict -> Iterable(s)): see torch_trainer.py. - optimizer_creator ((models, dict) -> optimizers): see torch_trainer.py. - loss_creator (torch.nn.*Loss class | dict -> loss): - see torch_trainer.py. - scheduler_creator ((optimizers, dict) -> scheduler): see - torch_trainer.py. - training_operator_cls: see torch_trainer.py - config (dict): see torch_trainer.py. - use_gpu (bool): see torch_trainer.py. - use_fp16 (bool): see torch_trainer.py. - apex_args (dict|None): see torch_trainer.py. - scheduler_step_freq (str): see torch_trainer.py. - """ + """Manages a PyTorch model for training.""" def __init__(self, model_creator, @@ -56,6 +40,7 @@ class TorchRunner: training_operator_cls=None, config=None, use_gpu=False, + serialize_data_creation=True, use_fp16=False, use_tqdm=False, apex_args=None, @@ -77,6 +62,7 @@ class TorchRunner: self.train_loader = None self.validation_loader = None self.training_operator = None + self.serialize_data_creation = serialize_data_creation self.use_gpu = use_gpu self.use_fp16 = use_fp16 self.use_tqdm = use_tqdm @@ -102,17 +88,15 @@ class TorchRunner: def _initialize_dataloaders(self): logger.debug("Instantiating dataloaders.") - # When creating loaders, a filelock will be used to ensure no - # race conditions in data downloading among different workers. - with FileLock(os.path.join(tempfile.gettempdir(), ".ray_data.lock")): + loaders = None + if self.serialize_data_creation: + logger.debug("Serializing the dataloading process.") + with FileLock( + os.path.join(tempfile.gettempdir(), ".raydata.lock")): + loaders = self.data_creator(self.config) + else: loaders = self.data_creator(self.config) - train_loader, val_loader = self._validate_loaders(loaders) - if not isinstance(train_loader, torch.utils.data.DataLoader): - logger.warning( - "TorchTrainer data_creator return values are no longer " - "wrapped as DataLoaders. Users must return DataLoader(s) " - "in data_creator. This warning will be removed in " - "a future version of Ray.") + train_loader, val_loader = self._validate_loaders(loaders) self.train_loader, self.validation_loader = train_loader, val_loader diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index 0c988a5f3..1b4e5f738 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -131,6 +131,10 @@ class TorchTrainer: support "nccl", "gloo", and "auto". If "auto", RaySGD will automatically use "nccl" if `use_gpu` is True, and "gloo" otherwise. + serialize_data_creation (bool): A filelock will be used + to ensure no race conditions in data downloading among + different workers on the same node (using the local file system). + Defaults to True. wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel over each model. If False, you are expected to call it yourself. add_dist_sampler (bool): Whether to automatically add a @@ -171,6 +175,7 @@ class TorchTrainer: use_gpu="auto", backend="auto", wrap_ddp=True, + serialize_data_creation=True, use_fp16=False, use_tqdm=False, apex_args=None, @@ -237,6 +242,7 @@ class TorchTrainer: self.use_gpu = use_gpu self.max_replicas = num_workers + self.serialize_data_creation = serialize_data_creation self.wrap_ddp = wrap_ddp self.use_fp16 = use_fp16 self.use_tqdm = use_tqdm @@ -298,6 +304,7 @@ class TorchTrainer: scheduler_creator=self.scheduler_creator, training_operator_cls=self.training_operator_cls, config=worker_config, + serialize_data_creation=self.serialize_data_creation, use_fp16=self.use_fp16, use_gpu=self.use_gpu, use_tqdm=self.use_tqdm,