diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py
index d5868560e..5e8c1f42a 100644
--- a/python/ray/tune/ray_trial_executor.py
+++ b/python/ray/tune/ray_trial_executor.py
@@ -301,18 +301,24 @@ class RayTrialExecutor(TrialExecutor):
     def get_current_trial_ips(self):
         return {t.node_ip for t in self.get_running_trials()}
 
-    def get_next_available_trial(self):
+    def get_next_failed_trial(self):
+        """Gets the first trial found to be running on a node presumed dead.
+
+        Returns:
+            A Trial object that is ready for failure processing. None if
+            no failure detected.
+        """
         if ray.worker._mode() != ray.worker.LOCAL_MODE:
             live_cluster_ips = self.get_alive_node_ips()
-            if live_cluster_ips - self.get_current_trial_ips():
+            # Scan the trials only when some trial's IP is missing from the
+            # set of live nodes.
+            if self.get_current_trial_ips() - live_cluster_ips:
                 for trial in self.get_running_trials():
                     if trial.node_ip and trial.node_ip not in live_cluster_ips:
-                        logger.warning(
-                            "{} (ip: {}) detected as stale. This is likely "
-                            "because the node was lost. Processing this "
-                            "trial first.".format(trial, trial.node_ip))
                         return trial
+        return None
 
+    def get_next_available_trial(self):
         shuffled_results = list(self._running.keys())
         random.shuffle(shuffled_results)
         # Note: We shuffle the results because `ray.wait` by default returns
diff --git a/python/ray/tune/tests/test_cluster.py b/python/ray/tune/tests/test_cluster.py
index ad18079f8..ae287cfda 100644
--- a/python/ray/tune/tests/test_cluster.py
+++ b/python/ray/tune/tests/test_cluster.py
@@ -8,6 +8,7 @@ import time
 import os
 import pytest
 import shutil
+import sys
 
 import ray
 from ray import tune
@@ -20,6 +21,11 @@ from ray.tune.trial import Trial
 from ray.tune.trial_runner import TrialRunner
 from ray.tune.suggest import BasicVariantGenerator
 
+if sys.version_info >= (3, 3):
+    from unittest.mock import MagicMock
+else:
+    from mock import MagicMock
+
 
 def _start_new_cluster():
     cluster = Cluster(
@@ -98,6 +104,26 @@ def test_counting_resources(start_connected_cluster):
     assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 2
 
 
+def test_trial_processed_after_node_failure(start_connected_emptyhead_cluster):
+    """Tests that Tune processes a trial as failed if its node died."""
+    cluster = start_connected_emptyhead_cluster
+    node = cluster.add_node(num_cpus=1)
+    cluster.wait_for_nodes()
+
+    runner = TrialRunner(BasicVariantGenerator())
+    mock_process_failure = MagicMock(side_effect=runner._process_trial_failure)
+    runner._process_trial_failure = mock_process_failure
+
+    runner.add_trial(Trial("__fake"))
+    runner.step()
+    runner.step()
+    assert not mock_process_failure.called
+
+    cluster.remove_node(node)
+    runner.step()
+    assert mock_process_failure.called
+
+
 def test_remove_node_before_result(start_connected_emptyhead_cluster):
     """Tune continues when node is removed before trial returns."""
     cluster = start_connected_emptyhead_cluster
diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py
index 8bb944bec..1364ddebe 100644
--- a/python/ray/tune/trial_executor.py
+++ b/python/ray/tune/trial_executor.py
@@ -158,6 +158,15 @@ class TrialExecutor(object):
         """
         raise NotImplementedError
 
+    def get_next_failed_trial(self):
+        """Non-blocking call that detects and returns one failed trial.
+
+        Returns:
+            A Trial object that is ready for failure processing. None if
+            no failure detected.
+        """
+        raise NotImplementedError
+
     def fetch_result(self, trial):
         """Fetches one result for the trial.
 
diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py
index 21bd11191..38dd07c56 100644
--- a/python/ray/tune/trial_runner.py
+++ b/python/ray/tune/trial_runner.py
@@ -497,9 +497,18 @@ class TrialRunner(object):
         return trial
 
     def _process_events(self):
-        trial = self.trial_executor.get_next_available_trial()  # blocking
-        with warn_if_slow("process_trial"):
-            self._process_trial(trial)
+        failed_trial = self.trial_executor.get_next_failed_trial()
+        if failed_trial:
+            with warn_if_slow("process_failed_trial"):
+                self._process_trial_failure(
+                    failed_trial,
+                    error_msg="{} (ip: {}) detected as stale. This is likely "
+                    "because the node was lost".format(failed_trial,
+                                                       failed_trial.node_ip))
+        else:
+            trial = self.trial_executor.get_next_available_trial()  # blocking
+            with warn_if_slow("process_trial"):
+                self._process_trial(trial)
 
     def _process_trial(self, trial):
         try:
@@ -558,16 +567,25 @@ class TrialRunner(object):
                         decision)
         except Exception:
             logger.exception("Error processing event.")
-            error_msg = traceback.format_exc()
-            if trial.status == Trial.RUNNING:
-                if trial.should_recover():
-                    self._try_recover(trial, error_msg)
-                else:
-                    self._scheduler_alg.on_trial_error(self, trial)
-                    self._search_alg.on_trial_complete(
-                        trial.trial_id, error=True)
-                    self.trial_executor.stop_trial(
-                        trial, error=True, error_msg=error_msg)
+            self._process_trial_failure(trial, traceback.format_exc())
+
+    def _process_trial_failure(self, trial, error_msg):
+        """Handle trial failure.
+
+        Attempt trial recovery if possible, clean up state otherwise.
+
+        Args:
+            trial (Trial): Failed trial.
+            error_msg (str): Error message prior to invoking this method.
+        """
+        if trial.status == Trial.RUNNING:
+            if trial.should_recover():
+                self._try_recover(trial, error_msg)
+            else:
+                self._scheduler_alg.on_trial_error(self, trial)
+                self._search_alg.on_trial_complete(trial.trial_id, error=True)
+                self.trial_executor.stop_trial(
+                    trial, error=True, error_msg=error_msg)
 
     def _checkpoint_trial_if_needed(self, trial, force=False):
         """Checkpoints trial based off trial.last_result."""