[tune] Avoid duplication in TrialRunner execution (#6598)

* avoid_duplication

* Update python/ray/tune/ray_trial_executor.py

Co-Authored-By: Kristian Hartikainen <kristian.hartikainen@gmail.com>

Co-authored-by: Kristian Hartikainen <kristian.hartikainen@gmail.com>
This commit is contained in:
Richard Liaw
2019-12-26 02:13:55 +01:00
committed by GitHub
parent 8707a721d9
commit 93e8c85e72
+19 -1
View File
@@ -19,6 +19,7 @@ from ray.tune.trial import Trial, Checkpoint, Location
from ray.tune.resources import Resources
from ray.tune.trial_executor import TrialExecutor
from ray.tune.util import warn_if_slow
from ray.tune.error import TuneError
logger = logging.getLogger(__name__)
@@ -126,6 +127,21 @@ class RayTrialExecutor(TrialExecutor):
def _train(self, trial):
"""Start one iteration of training and save remote id."""
if self._find_item(self._paused, trial):
raise TuneError(
"Should not call `train` on PAUSED trial {}. "
"This is an internal error - please file an issue "
"on https://github.com/ray-project/ray/issues/.".format(
str(trial)))
if self._find_item(self._running, trial):
logging.debug(
"Trial {} already has a queued future. Skipping this "
"`train` call. This may occur if a trial has "
"been unpaused within a scheduler callback.".format(
str(trial)))
return
assert trial.status == Trial.RUNNING, trial.status
remote = trial.runner.train.remote()
@@ -134,6 +150,8 @@ class RayTrialExecutor(TrialExecutor):
remote = _LocalWrapper(remote)
self._running[remote] = trial
trial_item = self._find_item(self._running, trial)
assert len(trial_item) < 2, trial_item
def _start_trial(self, trial, checkpoint=None, runner=None):
"""Starts trial and restores last result if trial was paused.
@@ -308,7 +326,7 @@ class RayTrialExecutor(TrialExecutor):
trainable.reset_config.remote(new_config),
DEFAULT_GET_TIMEOUT)
except RayTimeoutError:
logger.exception("Trial %s: reset_config timed out.")
logger.exception("Trial %s: reset_config timed out.", trial)
return False
return reset_val