[Ray SGD] Support num_steps continue training (#11142)

This commit is contained in:
Amog Kamsetty
2020-10-02 23:43:26 -07:00
committed by GitHub
parent 90e0054b60
commit 6325a973a2
4 changed files with 146 additions and 18 deletions
+77 -1
View File
@@ -370,6 +370,82 @@ def test_dataset(ray_start_4_cpus, use_local):
trainer.shutdown()
@pytest.mark.parametrize("use_local", [True, False])
def test_num_steps(ray_start_2_cpus, use_local):
    """Check that ``num_steps`` resumes from where the previous run stopped.

    Each worker iterates over five 0s followed by five 1s with a batch size
    of 1, so the average of the consumed batches reveals exactly which slice
    of the dataset a given ``train``/``validate`` call walked over.
    """

    def data_creator(config):
        train_dataset = [0] * 5 + [1] * 5
        val_dataset = [0] * 5 + [1] * 5
        train_loader = DataLoader(
            train_dataset, batch_size=config["batch_size"])
        val_loader = DataLoader(val_dataset, batch_size=config["batch_size"])
        return train_loader, val_loader

    batch_size = 1
    Operator = TrainingOperator.from_creators(model_creator, optimizer_creator,
                                              data_creator)

    def train_func(self, iterator, info=None):
        # Report the mean of every batch handed to this call.
        running_total = 0
        seen = 0
        for batch in iterator:
            running_total += batch
            seen += 1
        return {"average": running_total.item() / seen}

    TestOperator = get_test_operator(Operator)
    trainer = TorchTrainer(
        training_operator_cls=TestOperator,
        num_workers=2,
        use_local=use_local,
        add_dist_sampler=False,
        config={
            "batch_size": batch_size,
            "custom_func": train_func
        })

    # (num_steps, expected average, expected epoch) for each consecutive run.
    # ``None`` means the call is made without a ``num_steps`` argument.
    schedule = [
        # No num_steps: one full epoch, five 0s and five 1s.
        (None, 0.5, 1),
        # First five batches of a new epoch: five 0s.
        (5, 0, 2),
        # Continues where the last run left off: three 1s.
        (3, 1, 2),
        # Finishes the epoch (two 1s), then cycles to the start (three 0s).
        (5, 0.4, 3),
        # No num_steps: just finishes the epoch, two 0s and five 1s.
        (None, 5 / 7, 3),
    ]
    for num_steps, expected_average, expected_epoch in schedule:
        call_kwargs = {} if num_steps is None else {"num_steps": num_steps}

        result = trainer.train(**call_kwargs)
        assert result["average"] == expected_average
        assert result["epoch"] == expected_epoch

        val_result = trainer.validate(**call_kwargs)
        assert val_result["average"] == expected_average

    trainer.shutdown()
@pytest.mark.parametrize("use_local", [True, False])
def test_split_batch(ray_start_2_cpus, use_local):
if not dist.is_available():
@@ -467,7 +543,7 @@ def test_metrics(ray_start_2_cpus, num_workers, use_local):
"val_size": val_size
})
stats = trainer.train(num_steps=num_train_steps)
stats = trainer.train()
# Test that we output mean and last of custom metrics in an epoch
assert "score" in stats
assert stats["last_score"] == 0
+60 -10
View File
@@ -47,6 +47,13 @@ class TorchRunner:
"https://www.github.com/nvidia/apex to use fp16 training.")
self.scheduler_step_freq = scheduler_step_freq
# Training and Validation iterators
self.train_iterator = None
self._should_reset_train_loader = True
self.val_iterator = None
self._should_reset_val_loader = True
def setup_operator(self):
"""Create the training operator."""
self.training_operator = self.training_operator_cls(
@@ -60,6 +67,49 @@ class TorchRunner:
apex_args=self.apex_args,
scheduler_step_freq=self.scheduler_step_freq)
def get_iterator(self, training=True):
    """Return the stored train or validation iterator, rebuilding if flagged.

    Args:
        training (bool): Selects the training iterator when True and the
            validation iterator otherwise.

    Returns:
        The iterator kept on ``self`` for the requested phase. When the
        matching ``_should_reset_*_loader`` flag is set, a new iterator is
        created from the loader first and the flag is cleared. Rebuilding
        the training iterator also increments ``self.epochs``.
    """
    if not training:
        # Validation path: rebuilding never touches the epoch counter.
        if self._should_reset_val_loader:
            self.val_iterator = iter(self.validation_loader)
            self._should_reset_val_loader = False
        return self.val_iterator

    # Training path: a rebuilt iterator marks the start of another epoch.
    if self._should_reset_train_loader:
        self.epochs += 1
        self.train_iterator = iter(self.train_loader)
        self._should_reset_train_loader = False
    return self.train_iterator
def make_iterator(self, training=True, num_steps=None):
    """Yield batches, resuming from wherever the previous call stopped.

    Args:
        training (bool): Draw from the training iterator when True, from
            the validation iterator otherwise.
        num_steps (int | None): Maximum number of batches to yield. When
            the loader runs out before ``num_steps`` is reached, iteration
            wraps around to the start of the loader. When None, iteration
            stops at the end of the current pass over the loader.

    Yields:
        Items produced by the iterator returned from ``get_iterator``.

    Note:
        A previous call may have stopped on ``num_steps`` exactly at the end
        of the loader, leaving an exhausted iterator whose reset flag was
        never set. In that case this call starts a new pass instead of
        yielding nothing. A loader that is genuinely empty still ends the
        generator immediately rather than looping forever.
    """
    reset_flag = ("_should_reset_train_loader"
                  if training else "_should_reset_val_loader")
    steps = 0
    # Needed to make sure we don't loop forever if the loader is empty.
    has_at_least_one = False
    # True once this call has asked ``get_iterator`` for a rebuilt iterator.
    started_fresh_pass = False
    while True:
        if getattr(self, reset_flag):
            # ``get_iterator`` rebuilds the iterator when the flag is set.
            started_fresh_pass = True
        iterator = self.get_iterator(training=training)
        if num_steps is not None and steps >= num_steps:
            # Stop iterating after reaching num_steps.
            break
        try:
            item = next(iterator)
        except StopIteration:
            # Ask for a rebuilt iterator on the next ``get_iterator`` call.
            setattr(self, reset_flag, True)
            if has_at_least_one:
                if num_steps is None:
                    # End after the current pass over the loader.
                    break
                # Else, start cycling through the loader again.
                continue
            if started_fresh_pass:
                # A freshly built iterator produced nothing: empty loader.
                break
            # The iterator was left exhausted by an earlier call; restart
            # once on a fresh iterator instead of yielding no batches.
            continue
        steps += 1
        has_at_least_one = True
        yield item
def train_epoch(self,
num_steps=None,
profile=False,
@@ -76,9 +126,7 @@ class TorchRunner:
"epoch_idx": self.epochs,
})
with self.timers.record("train_epoch"):
if iterator is None:
iterator = iter(self.train_loader)
else:
if iterator is not None:
# Dataset will provide us with a list of tuples but we
# need two lists.
def format_batch(batch):
@@ -86,11 +134,14 @@ class TorchRunner:
return torch.cat(features), torch.cat(targets)
iterator = map(format_batch, iterator)
if num_steps:
iterator = itertools.islice(iterator, num_steps)
if num_steps:
iterator = itertools.islice(iterator, num_steps)
self.epochs += 1
else:
iterator = self.make_iterator(
training=True, num_steps=num_steps)
train_stats = self.training_operator.train_epoch(iterator, info)
self.epochs += 1
# This is so that `epochs` is first in ordering.
stats = dict(epoch=self.epochs, **train_stats)
if profile:
@@ -101,12 +152,9 @@ class TorchRunner:
"""Evaluates the model on the validation data set."""
info = info or {}
self._toggle_profiling(profile=profile)
validation_loader = self.validation_loader
with self.timers.record("validation"):
iterator = validation_loader
if num_steps:
iterator = itertools.islice(iterator, num_steps)
iterator = self.make_iterator(training=False, num_steps=num_steps)
validation_stats = self.training_operator.validate(
iterator, info=info)
if profile:
@@ -197,6 +245,8 @@ class TorchRunner:
def shutdown(self):
"""Attempts to shut down the worker."""
del self.train_iterator
del self.val_iterator
del self.training_operator
if torch.cuda.is_available():
torch.cuda.empty_cache()
+6 -7
View File
@@ -375,9 +375,9 @@ class TorchTrainer:
instance preemption.
Args:
num_steps (int): Number of batches to compute update steps on.
This corresponds also to the number of times
``TrainingOperator.train_batch`` is called.
num_steps (int): Number of batches to compute update steps on
per worker. This corresponds also to the number of times
``TrainingOperator.train_batch`` is called per worker.
profile (bool): Returns time stats for the training procedure.
reduce_results (bool): Whether to average all metrics across
all workers into one dict. If a metric is a non-numerical
@@ -437,7 +437,6 @@ class TorchTrainer:
NUM_SAMPLES: sum(
stats.pop(NUM_SAMPLES, np.nan) for stats in worker_stats)
}
for stat_key in worker_stats[0]:
if isinstance(worker_stats[0][stat_key], numbers.Number):
stats[stat_key] = np.nanmean(
@@ -479,9 +478,9 @@ class TorchTrainer:
"""Evaluates the model on the validation data set.
Args:
num_steps (int): Number of batches to compute update steps on.
This corresponds also to the number of times
``TrainingOperator.validate_batch`` is called.
num_steps (int): Number of batches to compute update steps on
per worker. This corresponds also to the number of times
``TrainingOperator.validate_batch`` is called per worker.
profile (bool): Returns time stats for the evaluation procedure.
reduce_results (bool): Whether to average all metrics across
all workers into one dict. If a metric is a non-numerical
@@ -1032,6 +1032,9 @@ def get_test_operator(operator_cls):
return func(self, iterator, info)
return {"done": 1}
def validate(self, iterator, info):
    """Run validation by delegating to this operator's ``train_epoch``.

    ``iterator`` and ``info`` are forwarded unchanged, and whatever
    ``train_epoch`` returns is returned as the validation result.
    """
    return self.train_epoch(iterator, info)
return _TestingOperator