mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 17:04:56 +08:00
[raysgd] Better user errors! (#7546)
* format * callable * Update python/ray/util/sgd/torch/torch_trainer.py Co-Authored-By: Edward Oakes <ed.nmi.oakes@gmail.com> * Update python/ray/util/sgd/torch/torch_trainer.py Co-Authored-By: Edward Oakes <ed.nmi.oakes@gmail.com> * data * torchtrainer * num_rep Co-authored-by: Edward Oakes <ed.nmi.oakes@gmail.com>
This commit is contained in:
@@ -95,6 +95,12 @@ class TorchRunner:
|
||||
with FileLock(os.path.join(tempfile.gettempdir(), ".ray_data.lock")):
|
||||
loaders = self.data_creator(self.config)
|
||||
train_loader, val_loader = self._validate_loaders(loaders)
|
||||
if not isinstance(train_loader, torch.utils.data.DataLoader):
|
||||
logger.warning(
|
||||
"TorchTrainer data_creator return values are no longer "
|
||||
"wrapped as DataLoaders. Users must return DataLoader(s) "
|
||||
"in data_creator. This warning will be removed in "
|
||||
"a future version of Ray.")
|
||||
|
||||
self.train_loader, self.validation_loader = train_loader, val_loader
|
||||
|
||||
|
||||
@@ -131,22 +131,27 @@ class TorchTrainer:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
model_creator=None,
|
||||
data_creator=None,
|
||||
optimizer_creator=None,
|
||||
loss_creator=None,
|
||||
scheduler_creator=None,
|
||||
training_operator_cls=None,
|
||||
initialization_hook=None,
|
||||
config=None,
|
||||
num_workers=1,
|
||||
use_gpu=False,
|
||||
backend="auto",
|
||||
use_fp16=False,
|
||||
apex_args=None,
|
||||
scheduler_step_freq="batch"):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
loss_creator=None,
|
||||
scheduler_creator=None,
|
||||
training_operator_cls=None,
|
||||
initialization_hook=None,
|
||||
config=None,
|
||||
num_workers=1,
|
||||
use_gpu=False,
|
||||
backend="auto",
|
||||
use_fp16=False,
|
||||
apex_args=None,
|
||||
scheduler_step_freq="batch",
|
||||
num_replicas=None,
|
||||
batch_size=None,
|
||||
data_loader_args=None,
|
||||
):
|
||||
if num_workers > 1 and not dist.is_available():
|
||||
raise ValueError(
|
||||
("Distributed PyTorch is not supported on macOS. "
|
||||
@@ -154,8 +159,30 @@ class TorchTrainer:
|
||||
"For more information, see "
|
||||
"https://github.com/pytorch/examples/issues/467."))
|
||||
|
||||
if not (model_creator and optimizer_creator and data_creator):
|
||||
raise ValueError("Must provide a Model, Optimizer, Data creator.")
|
||||
if not (callable(model_creator) and callable(optimizer_creator)
|
||||
and callable(data_creator)):
|
||||
raise ValueError(
|
||||
"Must provide a callable model_creator, optimizer_creator, "
|
||||
"and data_creator.")
|
||||
|
||||
if num_replicas is not None:
|
||||
raise DeprecationWarning(
|
||||
"num_replicas is deprecated. Use num_workers instead.")
|
||||
|
||||
if batch_size is not None:
|
||||
raise DeprecationWarning(
|
||||
"batch_size is deprecated. Use config={'batch_size': N} "
|
||||
"specify a batch size for each worker or "
|
||||
"config={ray.util.sgd.utils.BATCH_SIZE: N} to specify a "
|
||||
"batch size to be used across all workers.")
|
||||
|
||||
if data_loader_args:
|
||||
raise ValueError(
|
||||
"data_loader_args is deprecated. You can return a "
|
||||
"torch.utils.data.DataLoader in data_creator. Ray will "
|
||||
"automatically set a DistributedSampler if a DataLoader is "
|
||||
"returned and num_workers > 1.")
|
||||
|
||||
self.model_creator = model_creator
|
||||
self.optimizer_creator = optimizer_creator
|
||||
self.loss_creator = loss_creator
|
||||
|
||||
Reference in New Issue
Block a user