diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py
index 52b8c63f5..86cb6d5a7 100644
--- a/python/ray/tune/ray_trial_executor.py
+++ b/python/ray/tune/ray_trial_executor.py
@@ -19,6 +19,7 @@ from ray.tune.trial import Trial, Checkpoint, Location
 from ray.tune.resources import Resources
 from ray.tune.trial_executor import TrialExecutor
 from ray.tune.util import warn_if_slow
+from ray.tune.error import TuneError
 
 logger = logging.getLogger(__name__)
 
@@ -126,6 +127,23 @@ class RayTrialExecutor(TrialExecutor):
     def _train(self, trial):
         """Start one iteration of training and save remote id."""
 
+        # A trial tracked in `self._paused` must never be trained.
+        if self._find_item(self._paused, trial):
+            raise TuneError(
+                "Should not call `train` on PAUSED trial {}. "
+                "This is an internal error - please file an issue "
+                "on https://github.com/ray-project/ray/issues/.".format(
+                    str(trial)))
+
+        # Skip: the trial already has an entry in `self._running`.
+        if self._find_item(self._running, trial):
+            logger.debug(
+                "Trial {} already has a queued future. Skipping this "
+                "`train` call. This may occur if a trial has "
+                "been unpaused within a scheduler callback.".format(
+                    str(trial)))
+            return
+
         assert trial.status == Trial.RUNNING, trial.status
         remote = trial.runner.train.remote()
 
@@ -134,6 +152,9 @@ class RayTrialExecutor(TrialExecutor):
             remote = _LocalWrapper(remote)
 
         self._running[remote] = trial
+        # Sanity check: `trial` should match at most one `_running` entry.
+        trial_item = self._find_item(self._running, trial)
+        assert len(trial_item) < 2, trial_item
 
     def _start_trial(self, trial, checkpoint=None, runner=None):
         """Starts trial and restores last result if trial was paused.
@@ -308,7 +329,7 @@ class RayTrialExecutor(TrialExecutor):
                     trainable.reset_config.remote(new_config),
                     DEFAULT_GET_TIMEOUT)
             except RayTimeoutError:
-                logger.exception("Trial %s: reset_config timed out.")
+                logger.exception("Trial %s: reset_config timed out.", trial)
                 return False
         return reset_val
 