[sgd] Avoid parameter "gotcha" for learning rate scheduler (#8107)

* with-scheduler-creator

* none

* add_freq

* runner

* torch
This commit is contained in:
Richard Liaw
2020-04-21 01:01:04 -07:00
committed by GitHub
parent d15609ba2a
commit fa7eecf48a
5 changed files with 44 additions and 24 deletions
+25 -13
View File
@@ -248,6 +248,7 @@ def test_multi_model_matrix(ray_start_2_cpus, num_workers): # noqa: F811
optimizer_creator=multi_optimizer_creator,
loss_creator=nn.MSELoss,
scheduler_creator=multi_scheduler_creator,
scheduler_step_freq="epoch",
training_operator_cls=_TestingOperator,
num_workers=num_workers,
config={
@@ -260,7 +261,7 @@ def test_multi_model_matrix(ray_start_2_cpus, num_workers): # noqa: F811
trainer.shutdown()
@pytest.mark.parametrize("scheduler_freq", ["epoch", "batch"])
@pytest.mark.parametrize("scheduler_freq", ["epoch", "batch", "manual", None])
def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811
def train_epoch(self, iterator, info):
assert info[SCHEDULER_STEP] == scheduler_freq
@@ -270,19 +271,29 @@ def test_scheduler_freq(ray_start_2_cpus, scheduler_freq): # noqa: F811
return torch.optim.lr_scheduler.StepLR(
optimizer, step_size=30, gamma=0.1)
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
config={"custom_func": train_epoch},
training_operator_cls=_TestingOperator,
scheduler_creator=scheduler_creator,
scheduler_step_freq=scheduler_freq)
if scheduler_freq is None:
with pytest.raises(ValueError):
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
scheduler_creator=scheduler_creator,
scheduler_step_freq=scheduler_freq)
else:
trainer = TorchTrainer(
model_creator=model_creator,
data_creator=data_creator,
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
config={"custom_func": train_epoch},
training_operator_cls=_TestingOperator,
scheduler_creator=scheduler_creator,
scheduler_step_freq=scheduler_freq)
for i in range(3):
trainer.train()
trainer.shutdown()
for i in range(3):
trainer.train()
trainer.shutdown()
def test_profiling(ray_start_2_cpus): # noqa: F811
@@ -459,6 +470,7 @@ def test_scheduler_validate(ray_start_2_cpus): # noqa: F811
optimizer_creator=optimizer_creator,
loss_creator=lambda config: nn.MSELoss(),
scheduler_creator=lambda optimizer, cfg: ReduceLROnPlateau(optimizer),
scheduler_step_freq="manual",
training_operator_cls=_TestingOperator)
trainer.update_scheduler(0.5)
trainer.update_scheduler(0.5)
+4 -1
View File
@@ -5,6 +5,9 @@ NUM_STEPS = "__num_steps__"
SCHEDULER_STEP = "scheduler_step"
SCHEDULER_STEP_BATCH = "batch"
SCHEDULER_STEP_EPOCH = "epoch"
SCHEDULER_STEP_MANUAL = "manual"
NCCL_TIMEOUT_S = env_integer("NCCL_TIMEOUT_S", 10)
VALID_SCHEDULER_STEP = {SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH}
VALID_SCHEDULER_STEP = {
SCHEDULER_STEP_BATCH, SCHEDULER_STEP_EPOCH, SCHEDULER_STEP_MANUAL
}
@@ -119,6 +119,7 @@ trainer = TorchTrainer(
optimizer_creator=optimizer_creator,
loss_creator=nn.MSELoss,
scheduler_creator=scheduler_creator,
scheduler_step_freq="epoch", # if scheduler_creator is set
config={"lr": 0.001, "batch_size": 64})
# __torch_trainer_end__
+1 -1
View File
@@ -44,7 +44,7 @@ class TorchRunner:
use_fp16=False,
use_tqdm=False,
apex_args=None,
scheduler_step_freq="batch"):
scheduler_step_freq=None):
self.model_creator = model_creator
self.optimizer_creator = optimizer_creator
self.loss_creator = loss_creator
+13 -9
View File
@@ -23,11 +23,10 @@ RESIZE_COOLDOWN_S = 10
def _validate_scheduler_step_freq(scheduler_step_freq):
if scheduler_step_freq:
if scheduler_step_freq not in VALID_SCHEDULER_STEP:
raise ValueError(
"Scheduler step freq must be in {}. Got {}".format(
VALID_SCHEDULER_STEP, scheduler_step_freq))
"""This validation check only happens if a scheduler is passed in."""
if scheduler_step_freq not in VALID_SCHEDULER_STEP:
raise ValueError("Scheduler step freq must be in {}. Got {}".format(
VALID_SCHEDULER_STEP, scheduler_step_freq))
def _remind_gpu_usage(use_gpu):
@@ -148,10 +147,13 @@ class TorchTrainer:
See https://nvidia.github.io/apex/amp.html#module-apex.amp. By
default, the models and optimizers are passed in. Consider using
"num_losses" if operating over multiple models and optimizers.
scheduler_step_freq: "batch", "epoch", or None. This will
scheduler_step_freq: "batch", "epoch", "manual", or None. This will
determine when ``scheduler.step`` is called. If "batch",
``step`` will be called after every optimizer step. If "epoch",
``step`` will be called after one pass of the DataLoader.
``step`` will be called after one pass of the DataLoader. If
"manual", the scheduler will not be incremented automatically -
you are expected to call ``trainer.update_schedulers`` manually.
If a scheduler is passed in, this value is expected to not be None.
"""
@@ -180,7 +182,7 @@ class TorchTrainer:
use_tqdm=False,
apex_args=None,
add_dist_sampler=True,
scheduler_step_freq="batch",
scheduler_step_freq=None,
num_replicas=None,
batch_size=None,
data_loader_args=None,
@@ -259,7 +261,9 @@ class TorchTrainer:
self.local_worker = DeactivatedRunner()
self.remote_workers = []
_validate_scheduler_step_freq(scheduler_step_freq)
if scheduler_creator:
_validate_scheduler_step_freq(scheduler_step_freq)
self.scheduler_step_freq = scheduler_step_freq
if not ray.is_initialized() and self.max_replicas > 1: