From 6163b21458e0e3075bc1cf3b64dd14b8d6f07fdf Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Tue, 10 Mar 2020 18:58:19 -0700 Subject: [PATCH] [raysgd] Better user errors! (#7546) * format * callable * Update python/ray/util/sgd/torch/torch_trainer.py Co-Authored-By: Edward Oakes * Update python/ray/util/sgd/torch/torch_trainer.py Co-Authored-By: Edward Oakes * data * torchtrainer * num_rep Co-authored-by: Edward Oakes --- python/ray/util/sgd/torch/torch_runner.py | 6 +++ python/ray/util/sgd/torch/torch_trainer.py | 63 +++++++++++++++------- 2 files changed, 51 insertions(+), 18 deletions(-) diff --git a/python/ray/util/sgd/torch/torch_runner.py b/python/ray/util/sgd/torch/torch_runner.py index 71e7bec1f..bf812dfb7 100644 --- a/python/ray/util/sgd/torch/torch_runner.py +++ b/python/ray/util/sgd/torch/torch_runner.py @@ -95,6 +95,12 @@ class TorchRunner: with FileLock(os.path.join(tempfile.gettempdir(), ".ray_data.lock")): loaders = self.data_creator(self.config) train_loader, val_loader = self._validate_loaders(loaders) + if not isinstance(train_loader, torch.utils.data.DataLoader): + logger.warning( + "TorchTrainer data_creator return values are no longer " + "wrapped as DataLoaders. Users must return DataLoader(s) " + "in data_creator. This warning will be removed in " + "a future version of Ray.") self.train_loader, self.validation_loader = train_loader, val_loader diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index 6f7a8915e..4a638b1f7 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -131,22 +131,27 @@ class TorchTrainer: """ - def __init__(self, - *, - model_creator=None, - data_creator=None, - optimizer_creator=None, - loss_creator=None, - scheduler_creator=None, - training_operator_cls=None, - initialization_hook=None, - config=None, - num_workers=1, - use_gpu=False, - backend="auto", - use_fp16=False, - apex_args=None, - scheduler_step_freq="batch"): + def __init__( + self, + *, + model_creator, + data_creator, + optimizer_creator, + loss_creator=None, + scheduler_creator=None, + training_operator_cls=None, + initialization_hook=None, + config=None, + num_workers=1, + use_gpu=False, + backend="auto", + use_fp16=False, + apex_args=None, + scheduler_step_freq="batch", + num_replicas=None, + batch_size=None, + data_loader_args=None, + ): if num_workers > 1 and not dist.is_available(): raise ValueError( ("Distributed PyTorch is not supported on macOS. " @@ -154,8 +159,30 @@ class TorchTrainer: "For more information, see " "https://github.com/pytorch/examples/issues/467.")) - if not (model_creator and optimizer_creator and data_creator): - raise ValueError("Must provide a Model, Optimizer, Data creator.") + if not (callable(model_creator) and callable(optimizer_creator) + and callable(data_creator)): + raise ValueError( + "Must provide a callable model_creator, optimizer_creator, " + "and data_creator.") + + if num_replicas is not None: + raise DeprecationWarning( + "num_replicas is deprecated. Use num_workers instead.") + + if batch_size is not None: + raise DeprecationWarning( + "batch_size is deprecated. Use config={'batch_size': N} " + "specify a batch size for each worker or " + "config={ray.util.sgd.utils.BATCH_SIZE: N} to specify a " + "batch size to be used across all workers.") + + if data_loader_args: + raise ValueError( + "data_loader_args is deprecated. You can return a " + "torch.utils.data.DataLoader in data_creator. Ray will " + "automatically set a DistributedSampler if a DataLoader is " + "returned and num_workers > 1.") + self.model_creator = model_creator self.optimizer_creator = optimizer_creator self.loss_creator = loss_creator