mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 14:12:00 +08:00
[tune] Avoid duplication in TrialRunner execution (#6598)
* avoid_duplication * Update python/ray/tune/ray_trial_executor.py Co-Authored-By: Kristian Hartikainen <kristian.hartikainen@gmail.com> Co-authored-by: Kristian Hartikainen <kristian.hartikainen@gmail.com>
This commit is contained in:
@@ -19,6 +19,7 @@ from ray.tune.trial import Trial, Checkpoint, Location
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.trial_executor import TrialExecutor
|
||||
from ray.tune.util import warn_if_slow
|
||||
from ray.tune.error import TuneError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -126,6 +127,21 @@ class RayTrialExecutor(TrialExecutor):
|
||||
def _train(self, trial):
|
||||
"""Start one iteration of training and save remote id."""
|
||||
|
||||
if self._find_item(self._paused, trial):
|
||||
raise TuneError(
|
||||
"Should not call `train` on PAUSED trial {}. "
|
||||
"This is an internal error - please file an issue "
|
||||
"on https://github.com/ray-project/ray/issues/.".format(
|
||||
str(trial)))
|
||||
|
||||
if self._find_item(self._running, trial):
|
||||
logging.debug(
|
||||
"Trial {} already has a queued future. Skipping this "
|
||||
"`train` call. This may occur if a trial has "
|
||||
"been unpaused within a scheduler callback.".format(
|
||||
str(trial)))
|
||||
return
|
||||
|
||||
assert trial.status == Trial.RUNNING, trial.status
|
||||
remote = trial.runner.train.remote()
|
||||
|
||||
@@ -134,6 +150,8 @@ class RayTrialExecutor(TrialExecutor):
|
||||
remote = _LocalWrapper(remote)
|
||||
|
||||
self._running[remote] = trial
|
||||
trial_item = self._find_item(self._running, trial)
|
||||
assert len(trial_item) < 2, trial_item
|
||||
|
||||
def _start_trial(self, trial, checkpoint=None, runner=None):
|
||||
"""Starts trial and restores last result if trial was paused.
|
||||
@@ -308,7 +326,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
trainable.reset_config.remote(new_config),
|
||||
DEFAULT_GET_TIMEOUT)
|
||||
except RayTimeoutError:
|
||||
logger.exception("Trial %s: reset_config timed out.")
|
||||
logger.exception("Trial %s: reset_config timed out.", trial)
|
||||
return False
|
||||
return reset_val
|
||||
|
||||
|
||||
Reference in New Issue
Block a user