diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 3706d00f0..384f18802 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -61,11 +61,17 @@ class RayTrialExecutor(TrialExecutor): self._running[remote] = trial def _start_trial(self, trial, checkpoint=None): + """Starts trial and restores last result if trial was paused. + + Raises: + ValueError if restoring from checkpoint fails. + """ prior_status = trial.status self.set_status(trial, Trial.RUNNING) trial.runner = self._setup_runner(trial) if not self.restore(trial, checkpoint): - return + if trial.status == Trial.ERROR: + raise RuntimeError("Restore from checkpoint failed.") previous_run = self._find_item(self._paused, trial) if (prior_status == Trial.PAUSED and previous_run): @@ -127,12 +133,15 @@ class RayTrialExecutor(TrialExecutor): try: self._start_trial(trial, checkpoint) except Exception: - logger.exception("Error stopping runner - retrying...") + logger.exception("Error starting runner. " + "Trying again without checkpoint.") error_msg = traceback.format_exc() time.sleep(2) self._stop_trial(trial, error=True, error_msg=error_msg) try: - self._start_trial(trial, checkpoint) + # This forces the trial to not start from checkpoint. + trial.clear_checkpoint() + self._start_trial(trial) except Exception: logger.exception("Error starting runner, aborting!") error_msg = traceback.format_exc() diff --git a/python/ray/tune/test/cluster_tests.py b/python/ray/tune/test/cluster_tests.py index abb1fe394..71b2675f2 100644 --- a/python/ray/tune/test/cluster_tests.py +++ b/python/ray/tune/test/cluster_tests.py @@ -7,6 +7,7 @@ import json import time import os import pytest +import shutil try: import pytest_timeout except ImportError: @@ -24,26 +25,6 @@ from ray.tune.trial_runner import TrialRunner from ray.tune.suggest import BasicVariantGenerator -class _Fail(tune.Trainable): - """Fails on the 4th iteration.""" - - def _setup(self, config): - self.state = {"hi": 0} - - def _train(self): - self.state["hi"] += 1 - time.sleep(0.5) - if self.state["hi"] >= 4: - assert False - return {} - - def _save(self, path): - return self.state - - def _restore(self, state): - self.state = state - - def _start_new_cluster(): cluster = Cluster( initialize_head=True, @@ -121,38 +102,33 @@ def test_counting_resources(start_connected_cluster): assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 2 -@pytest.mark.skip("Add this test once reconstruction is fixed") -@pytest.mark.skipif( - pytest_timeout is None, - reason="Timeout package not installed; skipping test.") -@pytest.mark.timeout(10, method="thread") -def test_remove_node_before_result(start_connected_cluster): - """Removing a node should cause a Trial to be requeued.""" - cluster = start_connected_cluster +def test_remove_node_before_result(start_connected_emptyhead_cluster): + """Tune continues when node is removed before trial returns.""" + cluster = start_connected_emptyhead_cluster node = cluster.add_node(resources=dict(CPU=1)) - # TODO(rliaw): Make blocking an option? assert cluster.wait_for_nodes() runner = TrialRunner(BasicVariantGenerator()) - kwargs = {"stopping_criterion": {"training_iteration": 3}} - trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)] - for t in trials: - runner.add_trial(t) + kwargs = { + "stopping_criterion": { + "training_iteration": 3 + }, + "checkpoint_freq": 2, + "max_failures": 2 + } + trial = Trial("__fake", **kwargs) + runner.add_trial(trial) runner.step() # run 1 - runner.step() # run 2 - assert all(t.status == Trial.RUNNING for t in trials) - - runner.step() # 1 result - + assert trial.status == Trial.RUNNING cluster.remove_node(node) + cluster.add_node(resources=dict(CPU=1)) cluster.wait_for_nodes() - assert ray.global_state.cluster_resources["CPU"] == 1 + assert ray.global_state.cluster_resources()["CPU"] == 1 - runner.step() # recover - for i in range(5): + for i in range(3): runner.step() - assert all(t.status == Trial.TERMINATED for t in trials) + assert trial.status == Trial.TERMINATED with pytest.raises(TuneError): runner.step() @@ -267,6 +243,40 @@ def test_trial_requeue(start_connected_emptyhead_cluster): runner.step() +def test_migration_checkpoint_removal(start_connected_emptyhead_cluster): + """Test checks that trial restarts if checkpoint is lost w/ node fail.""" + cluster = start_connected_emptyhead_cluster + node = cluster.add_node(resources=dict(CPU=1)) + assert cluster.wait_for_nodes() + + runner = TrialRunner(BasicVariantGenerator()) + kwargs = { + "stopping_criterion": { + "training_iteration": 3 + }, + "checkpoint_freq": 2, + "max_failures": 2 + } + + # Test recovery of trial that has been checkpointed + t1 = Trial("__fake", **kwargs) + runner.add_trial(t1) + runner.step() # start + runner.step() # 1 result + runner.step() # 2 result and checkpoint + assert t1.has_checkpoint() + cluster.add_node(resources=dict(CPU=1)) + cluster.remove_node(node) + assert cluster.wait_for_nodes() + shutil.rmtree(os.path.dirname(t1._checkpoint.value)) + + runner.step() # Recovery step + for i in range(3): + runner.step() + + assert t1.status == Trial.TERMINATED + + def test_cluster_down_simple(start_connected_cluster, tmpdir): """Tests that TrialRunner save/restore works on cluster shutdown.""" cluster = start_connected_cluster @@ -379,7 +389,7 @@ tune.run_experiments( # The trainable returns every 0.5 seconds, so this should not miss # the checkpoint. metadata_checkpoint_dir = os.path.join(dirpath, "experiment") - for i in range(50): + for i in range(100): if os.path.exists( os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME)): @@ -389,7 +399,7 @@ tune.run_experiments( last_res = trials[0].last_result if last_res is not None and last_res["training_iteration"]: break - time.sleep(0.2) + time.sleep(0.3) if not os.path.exists( os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME)): @@ -415,13 +425,35 @@ tune.run_experiments( def test_cluster_interrupt(start_connected_cluster, tmpdir): - """Tests run_experiment on cluster shutdown even with atypical trial. + """Tests run_experiment on cluster shutdown with actual interrupt. - The trial fails on the 4th step, and the checkpointing happens on - the 3rd step, so restoring should actually launch the trial again. + This is an end-to-end test. """ cluster = start_connected_cluster dirpath = str(tmpdir) + + # Needs to be in scope for pytest + class _Mock(tune.Trainable): + """Finishes on the 4th iteration.""" + + def _setup(self, config): + self.state = {"hi": 0} + + def _train(self): + self.state["hi"] += 1 + time.sleep(0.5) + return {"done": self.state["hi"] >= 4} + + def _save(self, path): + return self.state + + def _restore(self, state): + self.state = state + + # Removes indent from class. + reformatted = "\n".join(line[4:] if len(line) else line + for line in inspect.getsource(_Mock).split("\n")) + script = """ import time import ray @@ -444,8 +476,8 @@ tune.run_experiments( """.format( redis_address=cluster.redis_address, checkpoint_dir=dirpath, - fail_class_code=inspect.getsource(_Fail), - fail_class=_Fail.__name__) + fail_class_code=reformatted, + fail_class=_Mock.__name__) run_string_as_driver_nonblocking(script) # Wait until the right checkpoint is saved. @@ -471,7 +503,7 @@ tune.run_experiments( ray.shutdown() cluster.shutdown() cluster = _start_new_cluster() - Experiment._register_if_needed(_Fail) + Experiment._register_if_needed(_Mock) # Inspect the internal trialrunner runner = TrialRunner.restore(metadata_checkpoint_dir) @@ -483,13 +515,13 @@ tune.run_experiments( trials2 = tune.run_experiments( { "experiment": { - "run": _Fail, + "run": _Mock, "local_dir": dirpath, "checkpoint_freq": 1 } }, resume=True, raise_on_failed_trial=False) - assert all(t.status == Trial.ERROR for t in trials2) + assert all(t.status == Trial.TERMINATED for t in trials2) assert {t.trial_id for t in trials2} == {t.trial_id for t in trials} cluster.shutdown() diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 66406231a..4ebfee187 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -337,6 +337,9 @@ class Trial(object): def has_checkpoint(self): return self._checkpoint.value is not None + def clear_checkpoint(self): + self._checkpoint.value = None + def should_recover(self): """Returns whether the trial qualifies for restoring.