mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 00:35:01 +08:00
[Ray SGD] Support num_steps continue training (#11142)
This commit is contained in:
@@ -370,6 +370,82 @@ def test_dataset(ray_start_4_cpus, use_local):
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_num_steps(ray_start_2_cpus, use_local):
|
||||
"""Tests if num_steps continues training from the subsampled dataset."""
|
||||
|
||||
def data_creator(config):
|
||||
train_dataset = [0] * 5 + [1] * 5
|
||||
val_dataset = [0] * 5 + [1] * 5
|
||||
return DataLoader(train_dataset, batch_size=config["batch_size"]), \
|
||||
DataLoader(val_dataset, batch_size=config["batch_size"])
|
||||
|
||||
batch_size = 1
|
||||
Operator = TrainingOperator.from_creators(model_creator, optimizer_creator,
|
||||
data_creator)
|
||||
|
||||
def train_func(self, iterator, info=None):
|
||||
total_sum = 0
|
||||
num_items = 0
|
||||
for e in iterator:
|
||||
total_sum += e
|
||||
num_items += 1
|
||||
return {"average": total_sum.item() / num_items}
|
||||
|
||||
TestOperator = get_test_operator(Operator)
|
||||
trainer = TorchTrainer(
|
||||
training_operator_cls=TestOperator,
|
||||
num_workers=2,
|
||||
use_local=use_local,
|
||||
add_dist_sampler=False,
|
||||
config={
|
||||
"batch_size": batch_size,
|
||||
"custom_func": train_func
|
||||
})
|
||||
|
||||
# If num_steps not passed, should do one full epoch.
|
||||
result = trainer.train()
|
||||
# Average of 5 0s and 5 1s
|
||||
assert result["average"] == 0.5
|
||||
assert result["epoch"] == 1
|
||||
val_result = trainer.validate()
|
||||
assert val_result["average"] == 0.5
|
||||
|
||||
# Train again with num_steps.
|
||||
result = trainer.train(num_steps=5)
|
||||
# 5 zeros
|
||||
assert result["average"] == 0
|
||||
assert result["epoch"] == 2
|
||||
val_result = trainer.validate(num_steps=5)
|
||||
assert val_result["average"] == 0
|
||||
|
||||
# Should continue where last train run left off.
|
||||
result = trainer.train(num_steps=3)
|
||||
# 3 ones.
|
||||
assert result["average"] == 1
|
||||
assert result["epoch"] == 2
|
||||
val_result = trainer.validate(num_steps=3)
|
||||
assert val_result["average"] == 1
|
||||
|
||||
# Should continue from last train run, and cycle to beginning.
|
||||
result = trainer.train(num_steps=5)
|
||||
# 2 ones and 3 zeros.
|
||||
assert result["average"] == 0.4
|
||||
assert result["epoch"] == 3
|
||||
val_result = trainer.validate(num_steps=5)
|
||||
assert val_result["average"] == 0.4
|
||||
|
||||
# Should continue, and since num_steps not passed in, just finishes epoch.
|
||||
result = trainer.train()
|
||||
# 2 zeros and 5 ones.
|
||||
assert result["average"] == 5 / 7
|
||||
assert result["epoch"] == 3
|
||||
val_result = trainer.validate()
|
||||
assert val_result["average"] == 5 / 7
|
||||
|
||||
trainer.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_local", [True, False])
|
||||
def test_split_batch(ray_start_2_cpus, use_local):
|
||||
if not dist.is_available():
|
||||
@@ -467,7 +543,7 @@ def test_metrics(ray_start_2_cpus, num_workers, use_local):
|
||||
"val_size": val_size
|
||||
})
|
||||
|
||||
stats = trainer.train(num_steps=num_train_steps)
|
||||
stats = trainer.train()
|
||||
# Test that we output mean and last of custom metrics in an epoch
|
||||
assert "score" in stats
|
||||
assert stats["last_score"] == 0
|
||||
|
||||
@@ -47,6 +47,13 @@ class TorchRunner:
|
||||
"https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
self.scheduler_step_freq = scheduler_step_freq
|
||||
|
||||
# Training and Validation iterators
|
||||
self.train_iterator = None
|
||||
self._should_reset_train_loader = True
|
||||
|
||||
self.val_iterator = None
|
||||
self._should_reset_val_loader = True
|
||||
|
||||
def setup_operator(self):
|
||||
"""Create the training operator."""
|
||||
self.training_operator = self.training_operator_cls(
|
||||
@@ -60,6 +67,49 @@ class TorchRunner:
|
||||
apex_args=self.apex_args,
|
||||
scheduler_step_freq=self.scheduler_step_freq)
|
||||
|
||||
def get_iterator(self, training=True):
|
||||
if training:
|
||||
# In training.
|
||||
if self._should_reset_train_loader:
|
||||
self.epochs += 1
|
||||
self.train_iterator = iter(self.train_loader)
|
||||
self._should_reset_train_loader = False
|
||||
return self.train_iterator
|
||||
else:
|
||||
# In validation.
|
||||
if self._should_reset_val_loader:
|
||||
self.val_iterator = iter(self.validation_loader)
|
||||
self._should_reset_val_loader = False
|
||||
return self.val_iterator
|
||||
|
||||
def make_iterator(self, training=True, num_steps=None):
|
||||
steps = 0
|
||||
# Needed to make sure we don't loop forever if iterator is empty
|
||||
has_at_least_one = False
|
||||
while True:
|
||||
iterator = self.get_iterator(training=training)
|
||||
if num_steps is not None and steps >= num_steps:
|
||||
# Stop iterating after reaching num_steps.
|
||||
break
|
||||
try:
|
||||
item = next(iterator)
|
||||
steps += 1
|
||||
if not has_at_least_one:
|
||||
has_at_least_one = True
|
||||
yield item
|
||||
except StopIteration:
|
||||
# Set should reset iterator on next cycle to True.
|
||||
if training:
|
||||
self._should_reset_train_loader = True
|
||||
else:
|
||||
self._should_reset_val_loader = True
|
||||
if num_steps is None or not has_at_least_one:
|
||||
# End after current epoch or if iterator has no elements.
|
||||
break
|
||||
else:
|
||||
# Else, start cycling through the iterator again.
|
||||
pass
|
||||
|
||||
def train_epoch(self,
|
||||
num_steps=None,
|
||||
profile=False,
|
||||
@@ -76,9 +126,7 @@ class TorchRunner:
|
||||
"epoch_idx": self.epochs,
|
||||
})
|
||||
with self.timers.record("train_epoch"):
|
||||
if iterator is None:
|
||||
iterator = iter(self.train_loader)
|
||||
else:
|
||||
if iterator is not None:
|
||||
# Dataset will provide us with a list of tuples but we
|
||||
# need two lists.
|
||||
def format_batch(batch):
|
||||
@@ -86,11 +134,14 @@ class TorchRunner:
|
||||
return torch.cat(features), torch.cat(targets)
|
||||
|
||||
iterator = map(format_batch, iterator)
|
||||
if num_steps:
|
||||
iterator = itertools.islice(iterator, num_steps)
|
||||
if num_steps:
|
||||
iterator = itertools.islice(iterator, num_steps)
|
||||
self.epochs += 1
|
||||
else:
|
||||
iterator = self.make_iterator(
|
||||
training=True, num_steps=num_steps)
|
||||
train_stats = self.training_operator.train_epoch(iterator, info)
|
||||
|
||||
self.epochs += 1
|
||||
# This is so that `epochs` is first in ordering.
|
||||
stats = dict(epoch=self.epochs, **train_stats)
|
||||
if profile:
|
||||
@@ -101,12 +152,9 @@ class TorchRunner:
|
||||
"""Evaluates the model on the validation data set."""
|
||||
info = info or {}
|
||||
self._toggle_profiling(profile=profile)
|
||||
validation_loader = self.validation_loader
|
||||
|
||||
with self.timers.record("validation"):
|
||||
iterator = validation_loader
|
||||
if num_steps:
|
||||
iterator = itertools.islice(iterator, num_steps)
|
||||
iterator = self.make_iterator(training=False, num_steps=num_steps)
|
||||
validation_stats = self.training_operator.validate(
|
||||
iterator, info=info)
|
||||
if profile:
|
||||
@@ -197,6 +245,8 @@ class TorchRunner:
|
||||
|
||||
def shutdown(self):
|
||||
"""Attempts to shut down the worker."""
|
||||
del self.train_iterator
|
||||
del self.val_iterator
|
||||
del self.training_operator
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -375,9 +375,9 @@ class TorchTrainer:
|
||||
instance preemption.
|
||||
|
||||
Args:
|
||||
num_steps (int): Number of batches to compute update steps on.
|
||||
This corresponds also to the number of times
|
||||
``TrainingOperator.train_batch`` is called.
|
||||
num_steps (int): Number of batches to compute update steps on
|
||||
per worker. This corresponds also to the number of times
|
||||
``TrainingOperator.train_batch`` is called per worker.
|
||||
profile (bool): Returns time stats for the training procedure.
|
||||
reduce_results (bool): Whether to average all metrics across
|
||||
all workers into one dict. If a metric is a non-numerical
|
||||
@@ -437,7 +437,6 @@ class TorchTrainer:
|
||||
NUM_SAMPLES: sum(
|
||||
stats.pop(NUM_SAMPLES, np.nan) for stats in worker_stats)
|
||||
}
|
||||
|
||||
for stat_key in worker_stats[0]:
|
||||
if isinstance(worker_stats[0][stat_key], numbers.Number):
|
||||
stats[stat_key] = np.nanmean(
|
||||
@@ -479,9 +478,9 @@ class TorchTrainer:
|
||||
"""Evaluates the model on the validation data set.
|
||||
|
||||
Args:
|
||||
num_steps (int): Number of batches to compute update steps on.
|
||||
This corresponds also to the number of times
|
||||
``TrainingOperator.validate_batch`` is called.
|
||||
num_steps (int): Number of batches to compute update steps on
|
||||
per worker. This corresponds also to the number of times
|
||||
``TrainingOperator.validate_batch`` is called per worker.
|
||||
profile (bool): Returns time stats for the evaluation procedure.
|
||||
reduce_results (bool): Whether to average all metrics across
|
||||
all workers into one dict. If a metric is a non-numerical
|
||||
|
||||
@@ -1032,6 +1032,9 @@ def get_test_operator(operator_cls):
|
||||
return func(self, iterator, info)
|
||||
return {"done": 1}
|
||||
|
||||
def validate(self, iterator, info):
|
||||
return self.train_epoch(iterator, info)
|
||||
|
||||
return _TestingOperator
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user