mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 14:19:24 +08:00
[tune] Fault Tolerance: handle lost checkpoints by restart (#3657)
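Ensures that a node failure which loses a trial's checkpoint does not crash Tune: if starting the trial from its checkpoint fails, the checkpoint is cleared and the trial is restarted without it. Also adds a test for this case.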
This commit is contained in:
@@ -61,11 +61,17 @@ class RayTrialExecutor(TrialExecutor):
|
||||
self._running[remote] = trial
|
||||
|
||||
def _start_trial(self, trial, checkpoint=None):
|
||||
"""Starts trial and restores last result if trial was paused.
|
||||
|
||||
Raises:
|
||||
RuntimeError if restoring from checkpoint fails.
|
||||
"""
|
||||
prior_status = trial.status
|
||||
self.set_status(trial, Trial.RUNNING)
|
||||
trial.runner = self._setup_runner(trial)
|
||||
if not self.restore(trial, checkpoint):
|
||||
return
|
||||
if trial.status == Trial.ERROR:
|
||||
raise RuntimeError("Restore from checkpoint failed.")
|
||||
|
||||
previous_run = self._find_item(self._paused, trial)
|
||||
if (prior_status == Trial.PAUSED and previous_run):
|
||||
@@ -127,12 +133,15 @@ class RayTrialExecutor(TrialExecutor):
|
||||
try:
|
||||
self._start_trial(trial, checkpoint)
|
||||
except Exception:
|
||||
logger.exception("Error stopping runner - retrying...")
|
||||
logger.exception("Error starting runner. "
|
||||
"Trying again without checkpoint.")
|
||||
error_msg = traceback.format_exc()
|
||||
time.sleep(2)
|
||||
self._stop_trial(trial, error=True, error_msg=error_msg)
|
||||
try:
|
||||
self._start_trial(trial, checkpoint)
|
||||
# This forces the trial to not start from checkpoint.
|
||||
trial.clear_checkpoint()
|
||||
self._start_trial(trial)
|
||||
except Exception:
|
||||
logger.exception("Error starting runner, aborting!")
|
||||
error_msg = traceback.format_exc()
|
||||
|
||||
@@ -7,6 +7,7 @@ import json
|
||||
import time
|
||||
import os
|
||||
import pytest
|
||||
import shutil
|
||||
try:
|
||||
import pytest_timeout
|
||||
except ImportError:
|
||||
@@ -24,26 +25,6 @@ from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
|
||||
|
||||
class _Fail(tune.Trainable):
|
||||
"""Fails on the 4th iteration."""
|
||||
|
||||
def _setup(self, config):
|
||||
self.state = {"hi": 0}
|
||||
|
||||
def _train(self):
|
||||
self.state["hi"] += 1
|
||||
time.sleep(0.5)
|
||||
if self.state["hi"] >= 4:
|
||||
assert False
|
||||
return {}
|
||||
|
||||
def _save(self, path):
|
||||
return self.state
|
||||
|
||||
def _restore(self, state):
|
||||
self.state = state
|
||||
|
||||
|
||||
def _start_new_cluster():
|
||||
cluster = Cluster(
|
||||
initialize_head=True,
|
||||
@@ -121,38 +102,33 @@ def test_counting_resources(start_connected_cluster):
|
||||
assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 2
|
||||
|
||||
|
||||
@pytest.mark.skip("Add this test once reconstruction is fixed")
|
||||
@pytest.mark.skipif(
|
||||
pytest_timeout is None,
|
||||
reason="Timeout package not installed; skipping test.")
|
||||
@pytest.mark.timeout(10, method="thread")
|
||||
def test_remove_node_before_result(start_connected_cluster):
|
||||
"""Removing a node should cause a Trial to be requeued."""
|
||||
cluster = start_connected_cluster
|
||||
def test_remove_node_before_result(start_connected_emptyhead_cluster):
|
||||
"""Tune continues when node is removed before trial returns."""
|
||||
cluster = start_connected_emptyhead_cluster
|
||||
node = cluster.add_node(resources=dict(CPU=1))
|
||||
# TODO(rliaw): Make blocking an option?
|
||||
assert cluster.wait_for_nodes()
|
||||
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {"stopping_criterion": {"training_iteration": 3}}
|
||||
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 3
|
||||
},
|
||||
"checkpoint_freq": 2,
|
||||
"max_failures": 2
|
||||
}
|
||||
trial = Trial("__fake", **kwargs)
|
||||
runner.add_trial(trial)
|
||||
|
||||
runner.step() # run 1
|
||||
runner.step() # run 2
|
||||
assert all(t.status == Trial.RUNNING for t in trials)
|
||||
|
||||
runner.step() # 1 result
|
||||
|
||||
assert trial.status == Trial.RUNNING
|
||||
cluster.remove_node(node)
|
||||
cluster.add_node(resources=dict(CPU=1))
|
||||
cluster.wait_for_nodes()
|
||||
assert ray.global_state.cluster_resources["CPU"] == 1
|
||||
assert ray.global_state.cluster_resources()["CPU"] == 1
|
||||
|
||||
runner.step() # recover
|
||||
for i in range(5):
|
||||
for i in range(3):
|
||||
runner.step()
|
||||
assert all(t.status == Trial.TERMINATED for t in trials)
|
||||
assert trial.status == Trial.TERMINATED
|
||||
|
||||
with pytest.raises(TuneError):
|
||||
runner.step()
|
||||
@@ -267,6 +243,40 @@ def test_trial_requeue(start_connected_emptyhead_cluster):
|
||||
runner.step()
|
||||
|
||||
|
||||
def test_migration_checkpoint_removal(start_connected_emptyhead_cluster):
|
||||
"""Test checks that trial restarts if checkpoint is lost w/ node fail."""
|
||||
cluster = start_connected_emptyhead_cluster
|
||||
node = cluster.add_node(resources=dict(CPU=1))
|
||||
assert cluster.wait_for_nodes()
|
||||
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 3
|
||||
},
|
||||
"checkpoint_freq": 2,
|
||||
"max_failures": 2
|
||||
}
|
||||
|
||||
# Test recovery of trial that has been checkpointed
|
||||
t1 = Trial("__fake", **kwargs)
|
||||
runner.add_trial(t1)
|
||||
runner.step() # start
|
||||
runner.step() # 1 result
|
||||
runner.step() # 2 result and checkpoint
|
||||
assert t1.has_checkpoint()
|
||||
cluster.add_node(resources=dict(CPU=1))
|
||||
cluster.remove_node(node)
|
||||
assert cluster.wait_for_nodes()
|
||||
shutil.rmtree(os.path.dirname(t1._checkpoint.value))
|
||||
|
||||
runner.step() # Recovery step
|
||||
for i in range(3):
|
||||
runner.step()
|
||||
|
||||
assert t1.status == Trial.TERMINATED
|
||||
|
||||
|
||||
def test_cluster_down_simple(start_connected_cluster, tmpdir):
|
||||
"""Tests that TrialRunner save/restore works on cluster shutdown."""
|
||||
cluster = start_connected_cluster
|
||||
@@ -379,7 +389,7 @@ tune.run_experiments(
|
||||
# The trainable returns every 0.5 seconds, so this should not miss
|
||||
# the checkpoint.
|
||||
metadata_checkpoint_dir = os.path.join(dirpath, "experiment")
|
||||
for i in range(50):
|
||||
for i in range(100):
|
||||
if os.path.exists(
|
||||
os.path.join(metadata_checkpoint_dir,
|
||||
TrialRunner.CKPT_FILE_NAME)):
|
||||
@@ -389,7 +399,7 @@ tune.run_experiments(
|
||||
last_res = trials[0].last_result
|
||||
if last_res is not None and last_res["training_iteration"]:
|
||||
break
|
||||
time.sleep(0.2)
|
||||
time.sleep(0.3)
|
||||
|
||||
if not os.path.exists(
|
||||
os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME)):
|
||||
@@ -415,13 +425,35 @@ tune.run_experiments(
|
||||
|
||||
|
||||
def test_cluster_interrupt(start_connected_cluster, tmpdir):
|
||||
"""Tests run_experiment on cluster shutdown even with atypical trial.
|
||||
"""Tests run_experiment on cluster shutdown with actual interrupt.
|
||||
|
||||
The trial fails on the 4th step, and the checkpointing happens on
|
||||
the 3rd step, so restoring should actually launch the trial again.
|
||||
This is an end-to-end test.
|
||||
"""
|
||||
cluster = start_connected_cluster
|
||||
dirpath = str(tmpdir)
|
||||
|
||||
# Needs to be in scope for pytest
|
||||
class _Mock(tune.Trainable):
|
||||
"""Finishes on the 4th iteration."""
|
||||
|
||||
def _setup(self, config):
|
||||
self.state = {"hi": 0}
|
||||
|
||||
def _train(self):
|
||||
self.state["hi"] += 1
|
||||
time.sleep(0.5)
|
||||
return {"done": self.state["hi"] >= 4}
|
||||
|
||||
def _save(self, path):
|
||||
return self.state
|
||||
|
||||
def _restore(self, state):
|
||||
self.state = state
|
||||
|
||||
# Removes indent from class.
|
||||
reformatted = "\n".join(line[4:] if len(line) else line
|
||||
for line in inspect.getsource(_Mock).split("\n"))
|
||||
|
||||
script = """
|
||||
import time
|
||||
import ray
|
||||
@@ -444,8 +476,8 @@ tune.run_experiments(
|
||||
""".format(
|
||||
redis_address=cluster.redis_address,
|
||||
checkpoint_dir=dirpath,
|
||||
fail_class_code=inspect.getsource(_Fail),
|
||||
fail_class=_Fail.__name__)
|
||||
fail_class_code=reformatted,
|
||||
fail_class=_Mock.__name__)
|
||||
run_string_as_driver_nonblocking(script)
|
||||
|
||||
# Wait until the right checkpoint is saved.
|
||||
@@ -471,7 +503,7 @@ tune.run_experiments(
|
||||
ray.shutdown()
|
||||
cluster.shutdown()
|
||||
cluster = _start_new_cluster()
|
||||
Experiment._register_if_needed(_Fail)
|
||||
Experiment._register_if_needed(_Mock)
|
||||
|
||||
# Inspect the internal trialrunner
|
||||
runner = TrialRunner.restore(metadata_checkpoint_dir)
|
||||
@@ -483,13 +515,13 @@ tune.run_experiments(
|
||||
trials2 = tune.run_experiments(
|
||||
{
|
||||
"experiment": {
|
||||
"run": _Fail,
|
||||
"run": _Mock,
|
||||
"local_dir": dirpath,
|
||||
"checkpoint_freq": 1
|
||||
}
|
||||
},
|
||||
resume=True,
|
||||
raise_on_failed_trial=False)
|
||||
assert all(t.status == Trial.ERROR for t in trials2)
|
||||
assert all(t.status == Trial.TERMINATED for t in trials2)
|
||||
assert {t.trial_id for t in trials2} == {t.trial_id for t in trials}
|
||||
cluster.shutdown()
|
||||
|
||||
@@ -337,6 +337,9 @@ class Trial(object):
|
||||
def has_checkpoint(self):
|
||||
return self._checkpoint.value is not None
|
||||
|
||||
def clear_checkpoint(self):
|
||||
self._checkpoint.value = None
|
||||
|
||||
def should_recover(self):
|
||||
"""Returns whether the trial qualifies for restoring.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user