mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:22:39 +08:00
[sgd] Avoid parameter "gotcha" for learning rate scheduler (#8107)
* with-scheduler-creator * none * add_freq * runner * torch
This commit is contained in:
@@ -248,6 +248,7 @@ def test_multi_model_matrix(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
optimizer_creator=multi_optimizer_creator,
|
||||
loss_creator=nn.MSELoss,
|
||||
scheduler_creator=multi_scheduler_creator,
|
||||
scheduler_step_freq="epoch",
|
||||
training_operator_cls=_TestingOperator,
|
||||
num_workers=num_workers,
|
||||
config={
|
||||
@@ -260,7 +261,7 @@ def test_multi_model_matrix(ray_start_2_cpus, num_workers): # noqa: F811
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("scheduler_freq", ["epoch", "batch"])
|
||||
@pytest.mark.parametrize("scheduler_freq", ["epoch", "batch", "manual", None])
|
||||
def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811
|
||||
def train_epoch(self, iterator, info):
|
||||
assert info[SCHEDULER_STEP] == scheduler_freq
|
||||
@@ -270,19 +271,29 @@ def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811
|
||||
return torch.optim.lr_scheduler.StepLR(
|
||||
optimizer, step_size=30, gamma=0.1)
|
||||
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
config={"custom_func": train_epoch},
|
||||
training_operator_cls=_TestingOperator,
|
||||
scheduler_creator=scheduler_creator,
|
||||
scheduler_step_freq=scheduler_freq)
|
||||
if scheduler_freq is None:
|
||||
with pytest.raises(ValueError):
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
scheduler_creator=scheduler_creator,
|
||||
scheduler_step_freq=scheduler_freq)
|
||||
else:
|
||||
trainer = TorchTrainer(
|
||||
model_creator=model_creator,
|
||||
data_creator=data_creator,
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
config={"custom_func": train_epoch},
|
||||
training_operator_cls=_TestingOperator,
|
||||
scheduler_creator=scheduler_creator,
|
||||
scheduler_step_freq=scheduler_freq)
|
||||
|
||||
for i in range(3):
|
||||
trainer.train()
|
||||
trainer.shutdown()
|
||||
for i in range(3):
|
||||
trainer.train()
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
def test_profiling(ray_start_2_cpus): # noqa: F811
|
||||
@@ -459,6 +470,7 @@ def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=lambda config: nn.MSELoss(),
|
||||
scheduler_creator=lambda optimizer, cfg: ReduceLROnPlateau(optimizer),
|
||||
scheduler_step_freq="manual",
|
||||
training_operator_cls=_TestingOperator)
|
||||
trainer.update_scheduler(0.5)
|
||||
trainer.update_scheduler(0.5)
|
||||
|
||||
@@ -5,6 +5,9 @@ NUM_STEPS = "__num_steps__"
|
||||
SCHEDULER_STEP = "scheduler_step"
|
||||
SCHEDULER_STEP_BATCH = "batch"
|
||||
SCHEDULER_STEP_EPOCH = "epoch"
|
||||
SCHEDULER_STEP_MANUAL = "manual"
|
||||
NCCL_TIMEOUT_S = env_integer("NCCL_TIMEOUT_S", 10)
|
||||
|
||||
VALID_SCHEDULER_STEP = {SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH}
|
||||
VALID_SCHEDULER_STEP = {
|
||||
SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH, SCHEDULER_STEP_MANUAL
|
||||
}
|
||||
|
||||
@@ -119,6 +119,7 @@ trainer = TorchTrainer(
|
||||
optimizer_creator=optimizer_creator,
|
||||
loss_creator=nn.MSELoss,
|
||||
scheduler_creator=scheduler_creator,
|
||||
scheduler_step_freq="epoch", # if scheduler_creator is set
|
||||
config={"lr": 0.001, "batch_size": 64})
|
||||
|
||||
# __torch_trainer_end__
|
||||
|
||||
@@ -44,7 +44,7 @@ class TorchRunner:
|
||||
use_fp16=False,
|
||||
use_tqdm=False,
|
||||
apex_args=None,
|
||||
scheduler_step_freq="batch"):
|
||||
scheduler_step_freq=None):
|
||||
self.model_creator = model_creator
|
||||
self.optimizer_creator = optimizer_creator
|
||||
self.loss_creator = loss_creator
|
||||
|
||||
@@ -23,11 +23,10 @@ RESIZE_COOLDOWN_S = 10
|
||||
|
||||
|
||||
def _validate_scheduler_step_freq(scheduler_step_freq):
|
||||
if scheduler_step_freq:
|
||||
if scheduler_step_freq not in VALID_SCHEDULER_STEP:
|
||||
raise ValueError(
|
||||
"Scheduler step freq must be in {}. Got {}".format(
|
||||
VALID_SCHEDULER_STEP, scheduler_step_freq))
|
||||
"""This validation check only happens if a scheduler is passed in."""
|
||||
if scheduler_step_freq not in VALID_SCHEDULER_STEP:
|
||||
raise ValueError("Scheduler step freq must be in {}. Got {}".format(
|
||||
VALID_SCHEDULER_STEP, scheduler_step_freq))
|
||||
|
||||
|
||||
def _remind_gpu_usage(use_gpu):
|
||||
@@ -148,10 +147,13 @@ class TorchTrainer:
|
||||
See https://nvidia.github.io/apex/amp.html#module-apex.amp. By
|
||||
default, the models and optimizers are passed in. Consider using
|
||||
"num_losses" if operating over multiple models and optimizers.
|
||||
scheduler_step_freq: "batch", "epoch", or None. This will
|
||||
scheduler_step_freq: "batch", "epoch", "manual", or None. This will
|
||||
determine when ``scheduler.step`` is called. If "batch",
|
||||
``step`` will be called after every optimizer step. If "epoch",
|
||||
``step`` will be called after one pass of the DataLoader.
|
||||
``step`` will be called after one pass of the DataLoader. If
|
||||
"manual", the scheduler will not be incremented automatically -
|
||||
you are expected to call ``trainer.update_schedulers`` manually.
|
||||
If a scheduler is passed in, this value is expected to not be None.
|
||||
|
||||
"""
|
||||
|
||||
@@ -180,7 +182,7 @@ class TorchTrainer:
|
||||
use_tqdm=False,
|
||||
apex_args=None,
|
||||
add_dist_sampler=True,
|
||||
scheduler_step_freq="batch",
|
||||
scheduler_step_freq=None,
|
||||
num_replicas=None,
|
||||
batch_size=None,
|
||||
data_loader_args=None,
|
||||
@@ -259,7 +261,9 @@ class TorchTrainer:
|
||||
self.local_worker = DeactivatedRunner()
|
||||
self.remote_workers = []
|
||||
|
||||
_validate_scheduler_step_freq(scheduler_step_freq)
|
||||
if scheduler_creator:
|
||||
_validate_scheduler_step_freq(scheduler_step_freq)
|
||||
|
||||
self.scheduler_step_freq = scheduler_step_freq
|
||||
|
||||
if not ray.is_initialized() and self.max_replicas > 1:
|
||||
|
||||
Reference in New Issue
Block a user