[raysgd] Better user errors! (#7546)

* format

* callable

* Update python/ray/util/sgd/torch/torch_trainer.py

Co-Authored-By: Edward Oakes <ed.nmi.oakes@gmail.com>

* Update python/ray/util/sgd/torch/torch_trainer.py

Co-Authored-By: Edward Oakes <ed.nmi.oakes@gmail.com>

* data

* torchtrainer

* num_rep

Co-authored-by: Edward Oakes <ed.nmi.oakes@gmail.com>
This commit is contained in:
Richard Liaw
2020-03-10 18:58:19 -07:00
committed by GitHub
parent 7b609ca211
commit 6163b21458
2 changed files with 51 additions and 18 deletions
@@ -95,6 +95,12 @@ class TorchRunner:
with FileLock(os.path.join(tempfile.gettempdir(), ".ray_data.lock")):
loaders = self.data_creator(self.config)
train_loader, val_loader = self._validate_loaders(loaders)
if not isinstance(train_loader, torch.utils.data.DataLoader):
logger.warning(
"TorchTrainer data_creator return values are no longer "
"wrapped as DataLoaders. Users must return DataLoader(s) "
"in data_creator. This warning will be removed in "
"a future version of Ray.")
self.train_loader, self.validation_loader = train_loader, val_loader
+45 -18
View File
@@ -131,22 +131,27 @@ class TorchTrainer:
"""
def __init__(self,
*,
model_creator=None,
data_creator=None,
optimizer_creator=None,
loss_creator=None,
scheduler_creator=None,
training_operator_cls=None,
initialization_hook=None,
config=None,
num_workers=1,
use_gpu=False,
backend="auto",
use_fp16=False,
apex_args=None,
scheduler_step_freq="batch"):
def __init__(
self,
*,
model_creator,
data_creator,
optimizer_creator,
loss_creator=None,
scheduler_creator=None,
training_operator_cls=None,
initialization_hook=None,
config=None,
num_workers=1,
use_gpu=False,
backend="auto",
use_fp16=False,
apex_args=None,
scheduler_step_freq="batch",
num_replicas=None,
batch_size=None,
data_loader_args=None,
):
if num_workers > 1 and not dist.is_available():
raise ValueError(
("Distributed PyTorch is not supported on macOS. "
@@ -154,8 +159,30 @@ class TorchTrainer:
"For more information, see "
"https://github.com/pytorch/examples/issues/467."))
if not (model_creator and optimizer_creator and data_creator):
raise ValueError("Must provide a Model, Optimizer, Data creator.")
if not (callable(model_creator) and callable(optimizer_creator)
and callable(data_creator)):
raise ValueError(
"Must provide a callable model_creator, optimizer_creator, "
"and data_creator.")
if num_replicas is not None:
raise DeprecationWarning(
"num_replicas is deprecated. Use num_workers instead.")
if batch_size is not None:
raise DeprecationWarning(
"batch_size is deprecated. Use config={'batch_size': N} "
"specify a batch size for each worker or "
"config={ray.util.sgd.utils.BATCH_SIZE: N} to specify a "
"batch size to be used across all workers.")
if data_loader_args:
raise ValueError(
"data_loader_args is deprecated. You can return a "
"torch.utils.data.DataLoader in data_creator. Ray will "
"automatically set a DistributedSampler if a DataLoader is "
"returned and num_workers > 1.")
self.model_creator = model_creator
self.optimizer_creator = optimizer_creator
self.loss_creator = loss_creator