mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 00:29:38 +08:00
[sgd] Make serialization of data creation optional (#8027)
* pytest * Update python/ray/util/sgd/torch/torch_trainer.py Co-Authored-By: Ujval Misra <misraujval@gmail.com> Co-authored-by: Ujval Misra <misraujval@gmail.com>
This commit is contained in:
@@ -60,6 +60,29 @@ def test_resize(ray_start_2_cpus): # noqa: F811
|
||||
assert len(results) == 2
|
||||
|
||||
|
||||
def test_non_serialized_data(ray_start_2_cpus): # noqa: F811
|
||||
duration = 10
|
||||
|
||||
def slow_data(func):
|
||||
def slowed_func(*args, **kwargs):
|
||||
time.sleep(duration)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return slowed_func
|
||||
|
||||
start = time.time()
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=slow_data(data_creator),
|
||||
optimizer_creator=optimizer_creator,
|
||||
serialize_data_creation=False,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
num_workers=2)
|
||||
elapsed = time.time() - start
|
||||
assert elapsed < duration * 2
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
def test_dead_trainer(ray_start_2_cpus): # noqa: F811
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
|
||||
@@ -28,7 +28,6 @@ class DistributedTorchRunner(TorchRunner):
|
||||
wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel
|
||||
over each model. If False, you are expected to call it yourself.
|
||||
kwargs: Keyword arguments for TorchRunner.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
||||
@@ -29,23 +29,7 @@ except ImportError:
|
||||
|
||||
|
||||
class TorchRunner:
|
||||
"""Manages a PyTorch model for training.
|
||||
|
||||
Args:
|
||||
model_creator (dict -> Model(s)): see torch_trainer.py
|
||||
data_creator (dict -> Iterable(s)): see torch_trainer.py.
|
||||
optimizer_creator ((models, dict) -> optimizers): see torch_trainer.py.
|
||||
loss_creator (torch.nn.*Loss class | dict -> loss):
|
||||
see torch_trainer.py.
|
||||
scheduler_creator ((optimizers, dict) -> scheduler): see
|
||||
torch_trainer.py.
|
||||
training_operator_cls: see torch_trainer.py
|
||||
config (dict): see torch_trainer.py.
|
||||
use_gpu (bool): see torch_trainer.py.
|
||||
use_fp16 (bool): see torch_trainer.py.
|
||||
apex_args (dict|None): see torch_trainer.py.
|
||||
scheduler_step_freq (str): see torch_trainer.py.
|
||||
"""
|
||||
"""Manages a PyTorch model for training."""
|
||||
|
||||
def __init__(self,
|
||||
model_creator,
|
||||
@@ -56,6 +40,7 @@ class TorchRunner:
|
||||
training_operator_cls=None,
|
||||
config=None,
|
||||
use_gpu=False,
|
||||
serialize_data_creation=True,
|
||||
use_fp16=False,
|
||||
use_tqdm=False,
|
||||
apex_args=None,
|
||||
@@ -77,6 +62,7 @@ class TorchRunner:
|
||||
self.train_loader = None
|
||||
self.validation_loader = None
|
||||
self.training_operator = None
|
||||
self.serialize_data_creation = serialize_data_creation
|
||||
self.use_gpu = use_gpu
|
||||
self.use_fp16 = use_fp16
|
||||
self.use_tqdm = use_tqdm
|
||||
@@ -102,17 +88,15 @@ class TorchRunner:
|
||||
|
||||
def _initialize_dataloaders(self):
|
||||
logger.debug("Instantiating dataloaders.")
|
||||
# When creating loaders, a filelock will be used to ensure no
|
||||
# race conditions in data downloading among different workers.
|
||||
with FileLock(os.path.join(tempfile.gettempdir(), ".ray_data.lock")):
|
||||
loaders = None
|
||||
if self.serialize_data_creation:
|
||||
logger.debug("Serializing the dataloading process.")
|
||||
with FileLock(
|
||||
os.path.join(tempfile.gettempdir(), ".raydata.lock")):
|
||||
loaders = self.data_creator(self.config)
|
||||
else:
|
||||
loaders = self.data_creator(self.config)
|
||||
train_loader, val_loader = self._validate_loaders(loaders)
|
||||
if not isinstance(train_loader, torch.utils.data.DataLoader):
|
||||
logger.warning(
|
||||
"TorchTrainer data_creator return values are no longer "
|
||||
"wrapped as DataLoaders. Users must return DataLoader(s) "
|
||||
"in data_creator. This warning will be removed in "
|
||||
"a future version of Ray.")
|
||||
train_loader, val_loader = self._validate_loaders(loaders)
|
||||
|
||||
self.train_loader, self.validation_loader = train_loader, val_loader
|
||||
|
||||
|
||||
@@ -131,6 +131,10 @@ class TorchTrainer:
|
||||
support "nccl", "gloo", and "auto". If "auto", RaySGD will
|
||||
automatically use "nccl" if `use_gpu` is True, and "gloo"
|
||||
otherwise.
|
||||
serialize_data_creation (bool): A filelock will be used
|
||||
to ensure no race conditions in data downloading among
|
||||
different workers on the same node (using the local file system).
|
||||
Defaults to True.
|
||||
wrap_ddp (bool): Whether to automatically wrap DistributedDataParallel
|
||||
over each model. If False, you are expected to call it yourself.
|
||||
add_dist_sampler (bool): Whether to automatically add a
|
||||
@@ -171,6 +175,7 @@ class TorchTrainer:
|
||||
use_gpu="auto",
|
||||
backend="auto",
|
||||
wrap_ddp=True,
|
||||
serialize_data_creation=True,
|
||||
use_fp16=False,
|
||||
use_tqdm=False,
|
||||
apex_args=None,
|
||||
@@ -237,6 +242,7 @@ class TorchTrainer:
|
||||
self.use_gpu = use_gpu
|
||||
self.max_replicas = num_workers
|
||||
|
||||
self.serialize_data_creation = serialize_data_creation
|
||||
self.wrap_ddp = wrap_ddp
|
||||
self.use_fp16 = use_fp16
|
||||
self.use_tqdm = use_tqdm
|
||||
@@ -298,6 +304,7 @@ class TorchTrainer:
|
||||
scheduler_creator=self.scheduler_creator,
|
||||
training_operator_cls=self.training_operator_cls,
|
||||
config=worker_config,
|
||||
serialize_data_creation=self.serialize_data_creation,
|
||||
use_fp16=self.use_fp16,
|
||||
use_gpu=self.use_gpu,
|
||||
use_tqdm=self.use_tqdm,
|
||||
|
||||
Reference in New Issue
Block a user