[tune] Fault Tolerance: handle lost checkpoints by restart (#3657)

Checks that node failure with lost checkpoints does not crash. Also adds test.
This commit is contained in:
Richard Liaw
2019-01-04 22:05:27 -08:00
committed by GitHub
parent 7db1f3be2a
commit 960a943503
3 changed files with 99 additions and 55 deletions
+12 -3
View File
@@ -61,11 +61,17 @@ class RayTrialExecutor(TrialExecutor):
self._running[remote] = trial
def _start_trial(self, trial, checkpoint=None):
"""Starts trial and restores last result if trial was paused.
Raises:
ValueError if restoring from checkpoint fails.
"""
prior_status = trial.status
self.set_status(trial, Trial.RUNNING)
trial.runner = self._setup_runner(trial)
if not self.restore(trial, checkpoint):
return
if trial.status == Trial.ERROR:
raise RuntimeError("Restore from checkpoint failed.")
previous_run = self._find_item(self._paused, trial)
if (prior_status == Trial.PAUSED and previous_run):
@@ -127,12 +133,15 @@ class RayTrialExecutor(TrialExecutor):
try:
self._start_trial(trial, checkpoint)
except Exception:
logger.exception("Error stopping runner - retrying...")
logger.exception("Error starting runner. "
"Trying again without checkpoint.")
error_msg = traceback.format_exc()
time.sleep(2)
self._stop_trial(trial, error=True, error_msg=error_msg)
try:
self._start_trial(trial, checkpoint)
# This forces the trial to not start from checkpoint.
trial.clear_checkpoint()
self._start_trial(trial)
except Exception:
logger.exception("Error starting runner, aborting!")
error_msg = traceback.format_exc()
+84 -52
View File
@@ -7,6 +7,7 @@ import json
import time
import os
import pytest
import shutil
try:
import pytest_timeout
except ImportError:
@@ -24,26 +25,6 @@ from ray.tune.trial_runner import TrialRunner
from ray.tune.suggest import BasicVariantGenerator
class _Fail(tune.Trainable):
"""Fails on the 4th iteration."""
def _setup(self, config):
self.state = {"hi": 0}
def _train(self):
self.state["hi"] += 1
time.sleep(0.5)
if self.state["hi"] >= 4:
assert False
return {}
def _save(self, path):
return self.state
def _restore(self, state):
self.state = state
def _start_new_cluster():
cluster = Cluster(
initialize_head=True,
@@ -121,38 +102,33 @@ def test_counting_resources(start_connected_cluster):
assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 2
@pytest.mark.skip("Add this test once reconstruction is fixed")
@pytest.mark.skipif(
pytest_timeout is None,
reason="Timeout package not installed; skipping test.")
@pytest.mark.timeout(10, method="thread")
def test_remove_node_before_result(start_connected_cluster):
"""Removing a node should cause a Trial to be requeued."""
cluster = start_connected_cluster
def test_remove_node_before_result(start_connected_emptyhead_cluster):
"""Tune continues when node is removed before trial returns."""
cluster = start_connected_emptyhead_cluster
node = cluster.add_node(resources=dict(CPU=1))
# TODO(rliaw): Make blocking an option?
assert cluster.wait_for_nodes()
runner = TrialRunner(BasicVariantGenerator())
kwargs = {"stopping_criterion": {"training_iteration": 3}}
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
kwargs = {
"stopping_criterion": {
"training_iteration": 3
},
"checkpoint_freq": 2,
"max_failures": 2
}
trial = Trial("__fake", **kwargs)
runner.add_trial(trial)
runner.step() # run 1
runner.step() # run 2
assert all(t.status == Trial.RUNNING for t in trials)
runner.step() # 1 result
assert trial.status == Trial.RUNNING
cluster.remove_node(node)
cluster.add_node(resources=dict(CPU=1))
cluster.wait_for_nodes()
assert ray.global_state.cluster_resources["CPU"] == 1
assert ray.global_state.cluster_resources()["CPU"] == 1
runner.step() # recover
for i in range(5):
for i in range(3):
runner.step()
assert all(t.status == Trial.TERMINATED for t in trials)
assert trial.status == Trial.TERMINATED
with pytest.raises(TuneError):
runner.step()
@@ -267,6 +243,40 @@ def test_trial_requeue(start_connected_emptyhead_cluster):
runner.step()
def test_migration_checkpoint_removal(start_connected_emptyhead_cluster):
"""Test checks that trial restarts if checkpoint is lost w/ node fail."""
cluster = start_connected_emptyhead_cluster
node = cluster.add_node(resources=dict(CPU=1))
assert cluster.wait_for_nodes()
runner = TrialRunner(BasicVariantGenerator())
kwargs = {
"stopping_criterion": {
"training_iteration": 3
},
"checkpoint_freq": 2,
"max_failures": 2
}
# Test recovery of trial that has been checkpointed
t1 = Trial("__fake", **kwargs)
runner.add_trial(t1)
runner.step() # start
runner.step() # 1 result
runner.step() # 2 result and checkpoint
assert t1.has_checkpoint()
cluster.add_node(resources=dict(CPU=1))
cluster.remove_node(node)
assert cluster.wait_for_nodes()
shutil.rmtree(os.path.dirname(t1._checkpoint.value))
runner.step() # Recovery step
for i in range(3):
runner.step()
assert t1.status == Trial.TERMINATED
def test_cluster_down_simple(start_connected_cluster, tmpdir):
"""Tests that TrialRunner save/restore works on cluster shutdown."""
cluster = start_connected_cluster
@@ -379,7 +389,7 @@ tune.run_experiments(
# The trainable returns every 0.5 seconds, so this should not miss
# the checkpoint.
metadata_checkpoint_dir = os.path.join(dirpath, "experiment")
for i in range(50):
for i in range(100):
if os.path.exists(
os.path.join(metadata_checkpoint_dir,
TrialRunner.CKPT_FILE_NAME)):
@@ -389,7 +399,7 @@ tune.run_experiments(
last_res = trials[0].last_result
if last_res is not None and last_res["training_iteration"]:
break
time.sleep(0.2)
time.sleep(0.3)
if not os.path.exists(
os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME)):
@@ -415,13 +425,35 @@ tune.run_experiments(
def test_cluster_interrupt(start_connected_cluster, tmpdir):
"""Tests run_experiment on cluster shutdown even with atypical trial.
"""Tests run_experiment on cluster shutdown with actual interrupt.
The trial fails on the 4th step, and the checkpointing happens on
the 3rd step, so restoring should actually launch the trial again.
This is an end-to-end test.
"""
cluster = start_connected_cluster
dirpath = str(tmpdir)
# Needs to be in scope for pytest
class _Mock(tune.Trainable):
"""Finishes on the 4th iteration."""
def _setup(self, config):
self.state = {"hi": 0}
def _train(self):
self.state["hi"] += 1
time.sleep(0.5)
return {"done": self.state["hi"] >= 4}
def _save(self, path):
return self.state
def _restore(self, state):
self.state = state
# Removes indent from class.
reformatted = "\n".join(line[4:] if len(line) else line
for line in inspect.getsource(_Mock).split("\n"))
script = """
import time
import ray
@@ -444,8 +476,8 @@ tune.run_experiments(
""".format(
redis_address=cluster.redis_address,
checkpoint_dir=dirpath,
fail_class_code=inspect.getsource(_Fail),
fail_class=_Fail.__name__)
fail_class_code=reformatted,
fail_class=_Mock.__name__)
run_string_as_driver_nonblocking(script)
# Wait until the right checkpoint is saved.
@@ -471,7 +503,7 @@ tune.run_experiments(
ray.shutdown()
cluster.shutdown()
cluster = _start_new_cluster()
Experiment._register_if_needed(_Fail)
Experiment._register_if_needed(_Mock)
# Inspect the internal trialrunner
runner = TrialRunner.restore(metadata_checkpoint_dir)
@@ -483,13 +515,13 @@ tune.run_experiments(
trials2 = tune.run_experiments(
{
"experiment": {
"run": _Fail,
"run": _Mock,
"local_dir": dirpath,
"checkpoint_freq": 1
}
},
resume=True,
raise_on_failed_trial=False)
assert all(t.status == Trial.ERROR for t in trials2)
assert all(t.status == Trial.TERMINATED for t in trials2)
assert {t.trial_id for t in trials2} == {t.trial_id for t in trials}
cluster.shutdown()
+3
View File
@@ -337,6 +337,9 @@ class Trial(object):
def has_checkpoint(self):
return self._checkpoint.value is not None
def clear_checkpoint(self):
self._checkpoint.value = None
def should_recover(self):
"""Returns whether the trial qualifies for restoring.