[tune] buffer trainable results (#13236)

* Working prototype

* Pass buffer length, fix tests

* Don't buffer per default

* Dispatch and process save in one go, added tests

* Fix tests

* Pass adaptive seconds to train_buffered, stop result processing after STOP decision

* Fix tests, add release test

* Update tests

* Added detailed logs for slow operations

* Update python/ray/tune/trial_runner.py

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>

* Apply suggestions from code review

* Revert tests and go back to old tuning loop

* nit

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Kai Fricke
2021-01-12 18:52:47 +01:00
committed by GitHub
parent 9eebd090cf
commit 518427627b
13 changed files with 551 additions and 122 deletions
+22 -2
View File
@@ -30,6 +30,11 @@ BOTTLENECK_WARN_PERIOD_S = 60
NONTRIVIAL_WAIT_TIME_THRESHOLD_S = 1e-3
DEFAULT_GET_TIMEOUT = 60.0 # seconds
TRIAL_CLEANUP_THRESHOLD = 100
# Result buffering settings; each can be overridden via the environment
# variable of the same name.
# Upper bound on the number of results buffered per remote call. The
# executor only uses `train_buffered` when this is greater than 1.
TUNE_RESULT_BUFFER_LENGTH = int(
    os.environ.get("TUNE_RESULT_BUFFER_LENGTH", 1000))
# Lower and upper clamp (in seconds) for the buffering time that the
# executor passes to `train_buffered`.
TUNE_RESULT_BUFFER_MIN_TIME_S = float(
    os.environ.get("TUNE_RESULT_BUFFER_MIN_TIME_S", 0.))
TUNE_RESULT_BUFFER_MAX_TIME_S = float(
    os.environ.get("TUNE_RESULT_BUFFER_MAX_TIME_S", 100.))
class _ActorClassCache:
@@ -257,8 +262,20 @@ class RayTrialExecutor(TrialExecutor):
return
assert trial.status == Trial.RUNNING, trial.status
buffer_time_s = max(
TUNE_RESULT_BUFFER_MIN_TIME_S,
min(TUNE_RESULT_BUFFER_MAX_TIME_S,
len(self._running) // 10))
with self._change_working_directory(trial):
remote = trial.runner.train.remote()
if TUNE_RESULT_BUFFER_LENGTH > 1:
buffer_length = TUNE_RESULT_BUFFER_LENGTH
if trial.checkpoint_freq > 0:
buffer_length = min(buffer_length, trial.checkpoint_freq)
remote = trial.runner.train_buffered.remote(
buffer_time_s, buffer_length)
else:
remote = trial.runner.train.remote()
# Local Mode
if isinstance(remote, dict):
@@ -484,7 +501,7 @@ class RayTrialExecutor(TrialExecutor):
return self._running[result_id]
def fetch_result(self, trial):
"""Fetches one result of the running trials.
"""Fetches result list of the running trials.
Returns:
List of results from the most recent trial training run.
@@ -499,6 +516,9 @@ class RayTrialExecutor(TrialExecutor):
# For local mode
if isinstance(result, _LocalWrapper):
result = result.unwrap()
if not isinstance(result, list):
return [result]
return result
def _commit_resources(self, resources):
@@ -1,5 +1,6 @@
# coding: utf-8
import unittest
from unittest.mock import patch
import ray
from ray.rllib import _register_all
@@ -35,7 +36,7 @@ class RayTrialExecutorTest(unittest.TestCase):
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
trial.last_result = self.trial_executor.fetch_result(trial)
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.assertEqual(checkpoint, trial.saving_to)
self.assertEqual(trial.checkpoint.value, None)
@@ -48,7 +49,7 @@ class RayTrialExecutorTest(unittest.TestCase):
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
trial.last_result = self.trial_executor.fetch_result(trial)
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.process_trial_save(trial)
self.trial_executor.restore(trial)
@@ -71,7 +72,7 @@ class RayTrialExecutorTest(unittest.TestCase):
"""Tests that pause checkpoint does not replace restore checkpoint."""
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
trial.last_result = self.trial_executor.fetch_result(trial)
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
# Save
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
self.assertEqual(Trial.RUNNING, trial.status)
@@ -80,7 +81,7 @@ class RayTrialExecutorTest(unittest.TestCase):
self.process_trial_save(trial)
# Train
self.trial_executor.continue_training(trial)
trial.last_result = self.trial_executor.fetch_result(trial)
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
# Pause
self.trial_executor.pause_trial(trial)
self.assertEqual(Trial.PAUSED, trial.status)
@@ -114,23 +115,40 @@ class RayTrialExecutorTest(unittest.TestCase):
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
def testPauseUnpause(self):
def _testPauseUnpause(self, result_buffer_length):
"""Tests that unpausing works for trials being processed."""
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
trial.last_result = self.trial_executor.fetch_result(trial)
self.assertEqual(trial.last_result.get(TRAINING_ITERATION), 1)
self.trial_executor.pause_trial(trial)
self.assertEqual(Trial.PAUSED, trial.status)
self.trial_executor.unpause_trial(trial)
self.assertEqual(Trial.PENDING, trial.status)
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
trial.last_result = self.trial_executor.fetch_result(trial)
self.assertEqual(trial.last_result.get(TRAINING_ITERATION), 2)
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
with patch(
"ray.tune.ray_trial_executor.TUNE_RESULT_BUFFER_LENGTH",
result_buffer_length
), patch("ray.tune.ray_trial_executor.TUNE_RESULT_BUFFER_MIN_TIME_S",
1):
base = max(result_buffer_length, 1)
trial = Trial("__fake")
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base)
self.trial_executor.pause_trial(trial)
self.assertEqual(Trial.PAUSED, trial.status)
self.trial_executor.unpause_trial(trial)
self.assertEqual(Trial.PENDING, trial.status)
self.trial_executor.start_trial(trial)
self.assertEqual(Trial.RUNNING, trial.status)
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
self.assertEqual(
trial.last_result.get(TRAINING_ITERATION), base * 2)
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
# Pause/unpause scenario with a buffer length of 0. The executor only
# buffers when the length is greater than 1, so plain `train()` is used.
def testPauseUnpauseNoBuffer(self):
self._testPauseUnpause(0)
# Pause/unpause scenario with a buffer length of 1, which is still not
# greater than 1: one result is expected per fetch.
def testPauseUnpauseTrivialBuffer(self):
self._testPauseUnpause(1)
# Pause/unpause scenario with a buffer length of 8: the helper expects the
# training iteration to advance by 8 per fetch.
def testPauseUnpauseActualBuffer(self):
self._testPauseUnpause(8)
def testNoResetTrial(self):
"""Tests that reset handles NotImplemented properly."""
@@ -182,7 +200,7 @@ class RayTrialExecutorTest(unittest.TestCase):
def process_trial_save(self, trial):
"""Simulates trial runner save."""
checkpoint = trial.saving_to
checkpoint_value = self.trial_executor.fetch_result(trial)
checkpoint_value = self.trial_executor.fetch_result(trial)[-1]
checkpoint.value = checkpoint_value
trial.on_checkpoint(checkpoint)
+83 -4
View File
@@ -5,11 +5,13 @@ import shutil
import sys
import tempfile
import unittest
from unittest.mock import patch
import ray
from ray.rllib import _register_all
from ray.tune import TuneError
from ray.tune.result import TRAINING_ITERATION
from ray.tune.schedulers import TrialScheduler, FIFOScheduler
from ray.tune.experiment import Experiment
from ray.tune.trial import Trial
@@ -586,10 +588,41 @@ class TrialRunnerTest3(unittest.TestCase):
self.assertEqual(count_checkpoints(tmpdir), 2)
shutil.rmtree(tmpdir)
@patch("ray.tune.ray_trial_executor.TUNE_RESULT_BUFFER_MIN_TIME_S", 0.5)
@patch("ray.tune.ray_trial_executor.TUNE_RESULT_BUFFER_LENGTH", 7)
def testCheckpointFreqBuffered(self):
    # With checkpoint_freq=3 and a buffer length of 7, each buffered run
    # is expected to deliver exactly three iterations followed by a save.
    def count_checkpoint_dirs(trial):
        entries = os.listdir(trial.logdir)
        return len([e for e in entries if e.startswith("checkpoint_")])

    ray.init(num_cpus=2)

    fake_trial = Trial("__fake", checkpoint_freq=3)
    runner = TrialRunner(
        local_checkpoint_dir=self.tmpdir, checkpoint_period=0)
    runner.add_trial(fake_trial)

    runner.step()  # start trial

    for round_no in (1, 2, 3):
        runner.step()  # run the next three iterations
        runner.step()  # process save
        self.assertEqual(fake_trial.last_result[TRAINING_ITERATION],
                         3 * round_no)
        self.assertEqual(count_checkpoint_dirs(fake_trial), round_no)
def testUserCheckpoint(self):
ray.init(num_cpus=3)
tmpdir = tempfile.mkdtemp()
runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
runner = TrialRunner(
local_checkpoint_dir=self.tmpdir, checkpoint_period=0)
runner.add_trial(Trial("__fake", config={"user_checkpoint_freq": 2}))
trials = runner.get_trials()
@@ -604,11 +637,57 @@ class TrialRunnerTest3(unittest.TestCase):
runner.step() # Process save
self.assertTrue(trials[0].has_checkpoint())
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=self.tmpdir)
runner2.step() # 5: Start trial and dispatch restore
trials2 = runner2.get_trials()
self.assertEqual(ray.get(trials2[0].runner.get_info.remote()), 1)
shutil.rmtree(tmpdir)
@patch("ray.tune.ray_trial_executor.TUNE_RESULT_BUFFER_MIN_TIME_S", 1)
@patch("ray.tune.ray_trial_executor.TUNE_RESULT_BUFFER_LENGTH", 8)
def testUserCheckpointBuffered(self):
    # Checks that a checkpoint requested by the trainable cuts the result
    # buffer short: the assertions below expect saves after iterations 11
    # and 21 even though the buffer length is 8.
    def num_checkpoints(trial):
        return sum(
            item.startswith("checkpoint_")
            for item in os.listdir(trial.logdir))

    ray.init(num_cpus=3)
    runner = TrialRunner(
        local_checkpoint_dir=self.tmpdir, checkpoint_period=0)
    runner.add_trial(Trial("__fake", config={"user_checkpoint_freq": 10}))
    trials = runner.get_trials()

    runner.step()  # Start trial, schedule 1-8
    self.assertEqual(trials[0].status, Trial.RUNNING)
    self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
    self.assertEqual(num_checkpoints(trials[0]), 0)

    runner.step()  # Process results 1-8, schedule 9-11 (CP)
    self.assertEqual(trials[0].last_result.get(TRAINING_ITERATION), 8)
    self.assertFalse(trials[0].has_checkpoint())
    self.assertEqual(num_checkpoints(trials[0]), 0)

    runner.step()  # Process results 9-11
    runner.step()  # Handle CP, schedule 12-19
    self.assertEqual(trials[0].last_result.get(TRAINING_ITERATION), 11)
    self.assertTrue(trials[0].has_checkpoint())
    self.assertEqual(num_checkpoints(trials[0]), 1)

    runner.step()  # Process results 12-19, schedule 20-21 (CP)
    self.assertEqual(trials[0].last_result.get(TRAINING_ITERATION), 19)
    self.assertTrue(trials[0].has_checkpoint())
    self.assertEqual(num_checkpoints(trials[0]), 1)

    runner.step()  # Process results 20-21
    runner.step()  # Handle CP, schedule 22-29
    self.assertEqual(trials[0].last_result.get(TRAINING_ITERATION), 21)
    self.assertTrue(trials[0].has_checkpoint())
    self.assertEqual(num_checkpoints(trials[0]), 2)

    runner.step()  # Process results 22-29, schedule 30-31 (CP)
    self.assertEqual(trials[0].last_result.get(TRAINING_ITERATION), 29)
    self.assertTrue(trials[0].has_checkpoint())
    self.assertEqual(num_checkpoints(trials[0]), 2)
class SearchAlgorithmTest(unittest.TestCase):
@@ -62,7 +62,7 @@ class _MockTrialExecutor(RayTrialExecutor):
self.failed_trial = None
def fetch_result(self, trial):
return self.results.get(trial, {})
return [self.results.get(trial, {})]
def get_next_available_trial(self):
return self.next_trial or super().get_next_available_trial()
+46 -3
View File
@@ -18,9 +18,10 @@ import uuid
import ray
from ray.util.debug import log_once
from ray.tune.result import (
DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S, TIMESTEPS_THIS_ITER, DONE,
TIMESTEPS_TOTAL, EPISODES_THIS_ITER, EPISODES_TOTAL, TRAINING_ITERATION,
RESULT_DUPLICATE, TRIAL_INFO, STDOUT_FILE, STDERR_FILE)
DEFAULT_RESULTS_DIR, SHOULD_CHECKPOINT, TIME_THIS_ITER_S,
TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL, EPISODES_THIS_ITER,
EPISODES_TOTAL, TRAINING_ITERATION, RESULT_DUPLICATE, TRIAL_INFO,
STDOUT_FILE, STDERR_FILE)
from ray.tune.utils import UtilMonitor
logger = logging.getLogger(__name__)
@@ -139,6 +140,48 @@ class Trainable:
self._local_ip = ray.services.get_node_ip_address()
return self._local_ip
def train_buffered(self,
                   buffer_time_s: float,
                   max_buffer_length: int = 1000) -> list:
    """Runs multiple iterations of training.

    Calls ``train()`` internally. Collects and combines multiple results.
    This function will run ``self.train()`` repeatedly until one of
    the following conditions is met:

    1) the maximum buffer length is reached,
    2) the maximum buffer time is reached,
    3) a result is flagged with ``SHOULD_CHECKPOINT``,
    4) a result is flagged with ``DONE``, or
    5) a result is flagged with ``RESULT_DUPLICATE``.

    Even if the maximum time is reached, it will always block until at
    least one result is received.

    Args:
        buffer_time_s (float): Maximum time to buffer. The next result
            received after this amount of time has passed will return
            the whole buffer.
        max_buffer_length (int): Maximum number of results to buffer.

    Returns:
        List of the results returned by ``self.train()``, in the order
        they were produced. Contains at least one element.
    """
    results = []

    # A monotonic clock keeps the deadline stable even if the system
    # wall clock is adjusted while results are being buffered.
    now = time.monotonic()
    send_buffer_at = now + buffer_time_s
    while now < send_buffer_at or not results:  # At least one result
        result = self.train()
        results.append(result)
        if result.get(DONE, False):
            # If the trial is done, return
            break
        elif result.get(SHOULD_CHECKPOINT, False):
            # If a checkpoint was created, return
            break
        elif result.get(RESULT_DUPLICATE):
            # If the function API trainable completed, return
            break
        elif len(results) >= max_buffer_length:
            # If the buffer is full, return
            break
        now = time.monotonic()

    return results
def train(self):
"""Runs one logical iteration of training.
+148 -86
View File
@@ -141,6 +141,7 @@ class TrialRunner:
self._trials = []
self._cached_trial_decisions = {}
self._queued_trial_decisions = {}
self._stop_queue = []
self._should_stop_experiment = False # used by TuneServer
self._local_checkpoint_dir = local_checkpoint_dir
@@ -364,7 +365,15 @@ class TrialRunner:
self._stop_experiment_if_needed()
try:
with warn_if_slow("experiment_checkpoint"):
with warn_if_slow(
"experiment_checkpoint",
message="Checkpointing the experiment state took "
"{duration:.3f} s, which may be a performance "
"bottleneck. Please ensure the "
"`TUNE_GLOBAL_CHECKPOINT_S` environment variable is "
"something significantly higher than this duration "
"to ensure compute time is mostly spent on the main "
"training loop."):
self.checkpoint()
except Exception as e:
logger.warning(f"Trial Runner checkpointing failed: {str(e)}")
@@ -467,6 +476,7 @@ class TrialRunner:
# TODO(ujvl): Consider combining get_next_available_trial and
# fetch_result functionality so that we don't timeout on fetch.
trial = self.trial_executor.get_next_available_trial() # blocking
if trial.is_restoring:
with warn_if_slow("process_trial_restore"):
self._process_trial_restore(trial)
@@ -503,6 +513,13 @@ class TrialRunner:
with warn_if_slow("process_trial"):
self._process_trial(trial)
# `self._queued_trial_decisions` now contains a final decision
# based on all results
final_decision = self._queued_trial_decisions.pop(
trial.trial_id, None)
if final_decision:
self._execute_action(trial, final_decision)
def _process_trial(self, trial):
"""Processes a trial result.
@@ -512,92 +529,38 @@ class TrialRunner:
processed (see `_process_trial_save`). Otherwise the decision is
acted on immediately.
If multiple results are received (e.g. because of buffering), all
results are processed and the final action is determined. STOP
takes precedence over PAUSE, which takes precedence over CONTINUE.
Args:
trial (Trial): Trial with a result ready to be processed.
"""
try:
result = self.trial_executor.fetch_result(trial)
result.update(trial_id=trial.trial_id)
is_duplicate = RESULT_DUPLICATE in result
force_checkpoint = result.get(SHOULD_CHECKPOINT, False)
# TrialScheduler and SearchAlgorithm still receive a
# notification because there may be special handling for
# the `on_trial_complete` hook.
if is_duplicate:
logger.debug("Trial finished without logging 'done'.")
result = trial.last_result
result.update(done=True)
self._validate_result_metrics(result)
self._total_time += result.get(TIME_THIS_ITER_S, 0)
flat_result = flatten_dict(result)
if self._stopper(trial.trial_id,
result) or trial.should_stop(flat_result):
result.update(done=True)
# Hook into scheduler
self._scheduler_alg.on_trial_complete(self, trial, flat_result)
self._search_alg.on_trial_complete(
trial.trial_id, result=flat_result)
# If this is not a duplicate result, the callbacks should
# be informed about the result.
if not is_duplicate:
with warn_if_slow("callbacks.on_trial_result"):
self._callbacks.on_trial_result(
iteration=self._iteration,
trials=self._trials,
trial=trial,
result=result.copy())
self._callbacks.on_trial_complete(
iteration=self._iteration,
trials=self._trials,
trial=trial)
decision = TrialScheduler.STOP
else:
with warn_if_slow("scheduler.on_trial_result"):
decision = self._scheduler_alg.on_trial_result(
self, trial, flat_result)
if decision == TrialScheduler.STOP:
result.update(done=True)
with warn_if_slow("search_alg.on_trial_result"):
self._search_alg.on_trial_result(trial.trial_id,
flat_result)
with warn_if_slow("callbacks.on_trial_result"):
self._callbacks.on_trial_result(
iteration=self._iteration,
trials=self._trials,
trial=trial,
result=result.copy())
if decision == TrialScheduler.STOP:
with warn_if_slow("search_alg.on_trial_complete"):
self._search_alg.on_trial_complete(
trial.trial_id, result=flat_result)
with warn_if_slow("callbacks.on_trial_complete"):
self._callbacks.on_trial_complete(
iteration=self._iteration,
trials=self._trials,
trial=trial)
if not is_duplicate:
trial.update_last_result(
result, terminate=(decision == TrialScheduler.STOP))
# Checkpoints to disk. This should be checked even if
# the scheduler decision is STOP or PAUSE. Note that
# PAUSE only checkpoints to memory and does not update
# the global checkpoint state.
self._checkpoint_trial_if_needed(trial, force=force_checkpoint)
if trial.is_saving:
# Cache decision to execute on after the save is processed.
# This prevents changing the trial's state or kicking off
# another training step prematurely.
self._cached_trial_decisions[trial.trial_id] = decision
else:
self._execute_action(trial, decision)
results = self.trial_executor.fetch_result(trial)
with warn_if_slow(
"process_trial_results",
message="Processing trial results took {duration:.3f} s, "
"which may be a performance bottleneck. Please consider "
"reporting results less frequently to Ray Tune."):
for i, result in enumerate(results):
with warn_if_slow("process_trial_result"):
decision = self._process_trial_result(trial, result)
if decision is None:
# If we didn't get a decision, this means a
# non-training future (e.g. a save) was scheduled.
# We do not allow processing more results then.
if i < len(results) - 1:
raise RuntimeError(
f"Trial {trial} has a non-training future "
f"scheduled but {len(results)-i} results "
f"left to process. This should never "
f"happen - please file an issue at "
f"https://github.com/ray-project/ray/issues")
elif decision == TrialScheduler.STOP:
# If the decision is to stop the trial,
# ignore all results that came after that.
break
except Exception:
error_msg = "Trial %s: Error processing event." % trial
if self._fail_fast == TrialRunner.RAISE:
@@ -607,6 +570,88 @@ class TrialRunner:
logger.exception(error_msg)
self._process_trial_failure(trial, traceback.format_exc())
def _process_trial_result(self, trial, result):
"""Processes a single training result of a trial.

Tags the result with the trial id, validates its metrics and hands it to
the stopper, the scheduler, the search algorithm and the callbacks. After
that, the trial is checkpointed if needed.

Args:
trial (Trial): Trial the result belongs to.
result: Result mapping to process. It is updated in place
(``trial_id`` and possibly ``done`` are set).

Returns:
``None`` if the trial is saving after the checkpoint check; the
decision is then stored in ``self._cached_trial_decisions``. Otherwise
the decision is passed to ``self._queue_decision`` and returned.
"""
result.update(trial_id=trial.trial_id)
is_duplicate = RESULT_DUPLICATE in result
# Read the flag now: `result` may be rebound to `trial.last_result` below.
force_checkpoint = result.get(SHOULD_CHECKPOINT, False)
# TrialScheduler and SearchAlgorithm still receive a
# notification because there may be special handling for
# the `on_trial_complete` hook.
if is_duplicate:
logger.debug("Trial finished without logging 'done'.")
result = trial.last_result
result.update(done=True)
self._validate_result_metrics(result)
self._total_time += result.get(TIME_THIS_ITER_S, 0)
flat_result = flatten_dict(result)
if self._stopper(trial.trial_id,
result) or trial.should_stop(flat_result):
result.update(done=True)
# Hook into scheduler
self._scheduler_alg.on_trial_complete(self, trial, flat_result)
self._search_alg.on_trial_complete(
trial.trial_id, result=flat_result)
# If this is not a duplicate result, the callbacks should
# be informed about the result.
if not is_duplicate:
with warn_if_slow("callbacks.on_trial_result"):
self._callbacks.on_trial_result(
iteration=self._iteration,
trials=self._trials,
trial=trial,
result=result.copy())
self._callbacks.on_trial_complete(
iteration=self._iteration, trials=self._trials, trial=trial)
decision = TrialScheduler.STOP
else:
# The trial was not stopped above, so the scheduler decides how to proceed.
with warn_if_slow("scheduler.on_trial_result"):
decision = self._scheduler_alg.on_trial_result(
self, trial, flat_result)
if decision == TrialScheduler.STOP:
result.update(done=True)
with warn_if_slow("search_alg.on_trial_result"):
self._search_alg.on_trial_result(trial.trial_id, flat_result)
with warn_if_slow("callbacks.on_trial_result"):
self._callbacks.on_trial_result(
iteration=self._iteration,
trials=self._trials,
trial=trial,
result=result.copy())
if decision == TrialScheduler.STOP:
with warn_if_slow("search_alg.on_trial_complete"):
self._search_alg.on_trial_complete(
trial.trial_id, result=flat_result)
with warn_if_slow("callbacks.on_trial_complete"):
self._callbacks.on_trial_complete(
iteration=self._iteration,
trials=self._trials,
trial=trial)
if not is_duplicate:
trial.update_last_result(
result, terminate=(decision == TrialScheduler.STOP))
# Checkpoints to disk. This should be checked even if
# the scheduler decision is STOP or PAUSE. Note that
# PAUSE only checkpoints to memory and does not update
# the global checkpoint state.
self._checkpoint_trial_if_needed(trial, force=force_checkpoint)
if trial.is_saving:
# Cache decision to execute on after the save is processed.
# This prevents changing the trial's state or kicking off
# another training step prematurely.
self._cached_trial_decisions[trial.trial_id] = decision
return None
else:
# No save pending: merge the decision with any already queued for this trial.
self._queue_decision(trial, decision)
return decision
def _validate_result_metrics(self, result):
"""
Check if any of the required metrics was not reported
@@ -669,7 +714,8 @@ class TrialRunner:
checkpoint_value = None
try:
checkpoint_value = self.trial_executor.fetch_result(trial)
results = self.trial_executor.fetch_result(trial)
checkpoint_value = results[-1]
except Exception:
logger.exception("Trial %s: Error processing result.", trial)
if self._fail_fast == TrialRunner.RAISE:
@@ -695,7 +741,7 @@ class TrialRunner:
trial.saving_to = None
decision = self._cached_trial_decisions.pop(trial.trial_id, None)
if decision and checkpoint_value:
self._execute_action(trial, decision)
self._queue_decision(trial, decision)
def _process_trial_restore(self, trial):
"""Processes a trial restore.
@@ -739,6 +785,21 @@ class TrialRunner:
self.trial_executor.stop_trial(
trial, error=True, error_msg=error_msg)
def _queue_decision(self, trial, decision):
    """Merges a decision into the queued decision for a trial.

    The first decision received for a trial is stored as is. Later
    decisions only replace it if they are STOP or PAUSE, and a queued
    STOP is never replaced.

    Args:
        trial (Trial): Trial the decision applies to.
        decision: One of the ``TrialScheduler`` decision constants.
    """
    # Get old decision, setting it to the current decision if it isn't set
    old_decision = self._queued_trial_decisions.setdefault(
        trial.trial_id, decision)

    # Stopping always takes precedence. If we decided to stop, just quit.
    # Decisions are compared by value, matching the `==` checks used when
    # results are processed.
    if old_decision == TrialScheduler.STOP:
        return

    # The old decision wasn't STOP. We update the decision only if it is
    # STOP or PAUSE. The action will only be CONTINUE if it was set by
    # the first received result and was never updated after that.
    if decision in (TrialScheduler.STOP, TrialScheduler.PAUSE):
        self._queued_trial_decisions[trial.trial_id] = decision
def _execute_action(self, trial, decision):
"""Executes action based on decision.
@@ -878,7 +939,8 @@ class TrialRunner:
iteration=self._iteration, trials=self._trials, trial=trial)
elif trial.status is Trial.RUNNING:
try:
result = self.trial_executor.fetch_result(trial)
results = self.trial_executor.fetch_result(trial)
result = results[-1]
trial.update_last_result(result, terminate=True)
self._scheduler_alg.on_trial_complete(self, trial, result)
self._search_alg.on_trial_complete(
+10 -4
View File
@@ -11,6 +11,7 @@ from collections import defaultdict, deque
from collections.abc import Mapping, Sequence
from datetime import datetime
from threading import Thread
from typing import Optional
import numpy as np
import ray
@@ -124,10 +125,16 @@ class warn_if_slow:
"""
DEFAULT_THRESHOLD = float(os.environ.get("TUNE_WARN_THRESHOLD_S", 0.5))
DEFAULT_MESSAGE = "The `{name}` operation took {duration:.3f} s, " \
"which may be a performance bottleneck."
def __init__(self, name, threshold=None):
def __init__(self,
name: str,
threshold: Optional[float] = None,
message: Optional[str] = None):
self.name = name
self.threshold = threshold or self.DEFAULT_THRESHOLD
self.message = message or self.DEFAULT_MESSAGE
self.too_slow = False
def __enter__(self):
@@ -138,10 +145,9 @@ class warn_if_slow:
now = time.time()
if now - self.start > self.threshold and now - START_OF_TIME > 60.0:
self.too_slow = True
_duration = now - self.start
duration = now - self.start
logger.warning(
f"The `{self.name}` operation took {_duration:.3f} s, "
"which may be a performance bottleneck.")
self.message.format(name=self.name, duration=duration))
class Tee(object):