mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:22:39 +08:00
[tune] buffer trainable results (#13236)
* Working prototype * Pass buffer length, fix tests * Don't buffer per default * Dispatch and process save in one go, added tests * Fix tests * Pass adaptive seconds to train_buffered, stop result processing after STOP decision * Fix tests, add release test * Update tests * Added detailed logs for slow operations * Update python/ray/tune/trial_runner.py Co-authored-by: Richard Liaw <rliaw@berkeley.edu> * Apply suggestions from code review * Revert tests and go back to old tuning loop * nit Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
@@ -30,6 +30,11 @@ BOTTLENECK_WARN_PERIOD_S = 60
|
||||
NONTRIVIAL_WAIT_TIME_THRESHOLD_S = 1e-3
|
||||
DEFAULT_GET_TIMEOUT = 60.0 # seconds
|
||||
TRIAL_CLEANUP_THRESHOLD = 100
|
||||
TUNE_RESULT_BUFFER_LENGTH = int(os.getenv("TUNE_RESULT_BUFFER_LENGTH", 1000))
|
||||
TUNE_RESULT_BUFFER_MIN_TIME_S = float(
|
||||
os.getenv("TUNE_RESULT_BUFFER_MIN_TIME_S", 0.))
|
||||
TUNE_RESULT_BUFFER_MAX_TIME_S = float(
|
||||
os.getenv("TUNE_RESULT_BUFFER_MAX_TIME_S", 100.))
|
||||
|
||||
|
||||
class _ActorClassCache:
|
||||
@@ -257,8 +262,20 @@ class RayTrialExecutor(TrialExecutor):
|
||||
return
|
||||
|
||||
assert trial.status == Trial.RUNNING, trial.status
|
||||
buffer_time_s = max(
|
||||
TUNE_RESULT_BUFFER_MIN_TIME_S,
|
||||
min(TUNE_RESULT_BUFFER_MAX_TIME_S,
|
||||
len(self._running) // 10))
|
||||
with self._change_working_directory(trial):
|
||||
remote = trial.runner.train.remote()
|
||||
if TUNE_RESULT_BUFFER_LENGTH > 1:
|
||||
buffer_length = TUNE_RESULT_BUFFER_LENGTH
|
||||
if trial.checkpoint_freq > 0:
|
||||
buffer_length = min(buffer_length, trial.checkpoint_freq)
|
||||
|
||||
remote = trial.runner.train_buffered.remote(
|
||||
buffer_time_s, buffer_length)
|
||||
else:
|
||||
remote = trial.runner.train.remote()
|
||||
|
||||
# Local Mode
|
||||
if isinstance(remote, dict):
|
||||
@@ -484,7 +501,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
return self._running[result_id]
|
||||
|
||||
def fetch_result(self, trial):
|
||||
"""Fetches one result of the running trials.
|
||||
"""Fetches result list of the running trials.
|
||||
|
||||
Returns:
|
||||
Result of the most recent trial training run.
|
||||
@@ -499,6 +516,9 @@ class RayTrialExecutor(TrialExecutor):
|
||||
# For local mode
|
||||
if isinstance(result, _LocalWrapper):
|
||||
result = result.unwrap()
|
||||
|
||||
if not isinstance(result, list):
|
||||
return [result]
|
||||
return result
|
||||
|
||||
def _commit_resources(self, resources):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# coding: utf-8
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import ray
|
||||
from ray.rllib import _register_all
|
||||
@@ -35,7 +36,7 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||
trial = Trial("__fake")
|
||||
self.trial_executor.start_trial(trial)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
trial.last_result = self.trial_executor.fetch_result(trial)
|
||||
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
|
||||
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
|
||||
self.assertEqual(checkpoint, trial.saving_to)
|
||||
self.assertEqual(trial.checkpoint.value, None)
|
||||
@@ -48,7 +49,7 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||
trial = Trial("__fake")
|
||||
self.trial_executor.start_trial(trial)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
trial.last_result = self.trial_executor.fetch_result(trial)
|
||||
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
|
||||
self.trial_executor.save(trial, Checkpoint.PERSISTENT)
|
||||
self.process_trial_save(trial)
|
||||
self.trial_executor.restore(trial)
|
||||
@@ -71,7 +72,7 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||
"""Tests that pause checkpoint does not replace restore checkpoint."""
|
||||
trial = Trial("__fake")
|
||||
self.trial_executor.start_trial(trial)
|
||||
trial.last_result = self.trial_executor.fetch_result(trial)
|
||||
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
|
||||
# Save
|
||||
checkpoint = self.trial_executor.save(trial, Checkpoint.PERSISTENT)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
@@ -80,7 +81,7 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||
self.process_trial_save(trial)
|
||||
# Train
|
||||
self.trial_executor.continue_training(trial)
|
||||
trial.last_result = self.trial_executor.fetch_result(trial)
|
||||
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
|
||||
# Pause
|
||||
self.trial_executor.pause_trial(trial)
|
||||
self.assertEqual(Trial.PAUSED, trial.status)
|
||||
@@ -114,23 +115,40 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||
self.trial_executor.stop_trial(trial)
|
||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||
|
||||
def testPauseUnpause(self):
|
||||
def _testPauseUnpause(self, result_buffer_length):
|
||||
"""Tests that unpausing works for trials being processed."""
|
||||
trial = Trial("__fake")
|
||||
self.trial_executor.start_trial(trial)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
trial.last_result = self.trial_executor.fetch_result(trial)
|
||||
self.assertEqual(trial.last_result.get(TRAINING_ITERATION), 1)
|
||||
self.trial_executor.pause_trial(trial)
|
||||
self.assertEqual(Trial.PAUSED, trial.status)
|
||||
self.trial_executor.unpause_trial(trial)
|
||||
self.assertEqual(Trial.PENDING, trial.status)
|
||||
self.trial_executor.start_trial(trial)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
trial.last_result = self.trial_executor.fetch_result(trial)
|
||||
self.assertEqual(trial.last_result.get(TRAINING_ITERATION), 2)
|
||||
self.trial_executor.stop_trial(trial)
|
||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||
with patch(
|
||||
"ray.tune.ray_trial_executor.TUNE_RESULT_BUFFER_LENGTH",
|
||||
result_buffer_length
|
||||
), patch("ray.tune.ray_trial_executor.TUNE_RESULT_BUFFER_MIN_TIME_S",
|
||||
1):
|
||||
base = max(result_buffer_length, 1)
|
||||
|
||||
trial = Trial("__fake")
|
||||
self.trial_executor.start_trial(trial)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
|
||||
self.assertEqual(trial.last_result.get(TRAINING_ITERATION), base)
|
||||
self.trial_executor.pause_trial(trial)
|
||||
self.assertEqual(Trial.PAUSED, trial.status)
|
||||
self.trial_executor.unpause_trial(trial)
|
||||
self.assertEqual(Trial.PENDING, trial.status)
|
||||
self.trial_executor.start_trial(trial)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
trial.last_result = self.trial_executor.fetch_result(trial)[-1]
|
||||
self.assertEqual(
|
||||
trial.last_result.get(TRAINING_ITERATION), base * 2)
|
||||
self.trial_executor.stop_trial(trial)
|
||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||
|
||||
def testPauseUnpauseNoBuffer(self):
|
||||
self._testPauseUnpause(0)
|
||||
|
||||
def testPauseUnpauseTrivialBuffer(self):
|
||||
self._testPauseUnpause(1)
|
||||
|
||||
def testPauseUnpauseActualBuffer(self):
|
||||
self._testPauseUnpause(8)
|
||||
|
||||
def testNoResetTrial(self):
|
||||
"""Tests that reset handles NotImplemented properly."""
|
||||
@@ -182,7 +200,7 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||
def process_trial_save(self, trial):
|
||||
"""Simulates trial runner save."""
|
||||
checkpoint = trial.saving_to
|
||||
checkpoint_value = self.trial_executor.fetch_result(trial)
|
||||
checkpoint_value = self.trial_executor.fetch_result(trial)[-1]
|
||||
checkpoint.value = checkpoint_value
|
||||
trial.on_checkpoint(checkpoint)
|
||||
|
||||
|
||||
@@ -5,11 +5,13 @@ import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import ray
|
||||
from ray.rllib import _register_all
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.result import TRAINING_ITERATION
|
||||
from ray.tune.schedulers import TrialScheduler, FIFOScheduler
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.trial import Trial
|
||||
@@ -586,10 +588,41 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
self.assertEqual(count_checkpoints(tmpdir), 2)
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
@patch("ray.tune.ray_trial_executor.TUNE_RESULT_BUFFER_MIN_TIME_S", 0.5)
|
||||
@patch("ray.tune.ray_trial_executor.TUNE_RESULT_BUFFER_LENGTH", 7)
|
||||
def testCheckpointFreqBuffered(self):
|
||||
def num_checkpoints(trial):
|
||||
return sum(
|
||||
item.startswith("checkpoint_")
|
||||
for item in os.listdir(trial.logdir))
|
||||
|
||||
ray.init(num_cpus=2)
|
||||
|
||||
trial = Trial("__fake", checkpoint_freq=3)
|
||||
runner = TrialRunner(
|
||||
local_checkpoint_dir=self.tmpdir, checkpoint_period=0)
|
||||
runner.add_trial(trial)
|
||||
|
||||
runner.step() # start trial
|
||||
runner.step() # run iteration 1-3
|
||||
runner.step() # process save
|
||||
self.assertEqual(trial.last_result[TRAINING_ITERATION], 3)
|
||||
self.assertEqual(num_checkpoints(trial), 1)
|
||||
|
||||
runner.step() # run iteration 4-6
|
||||
runner.step() # process save
|
||||
self.assertEqual(trial.last_result[TRAINING_ITERATION], 6)
|
||||
self.assertEqual(num_checkpoints(trial), 2)
|
||||
|
||||
runner.step() # run iteration 7-9
|
||||
runner.step() # process save
|
||||
self.assertEqual(trial.last_result[TRAINING_ITERATION], 9)
|
||||
self.assertEqual(num_checkpoints(trial), 3)
|
||||
|
||||
def testUserCheckpoint(self):
|
||||
ray.init(num_cpus=3)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0)
|
||||
runner = TrialRunner(
|
||||
local_checkpoint_dir=self.tmpdir, checkpoint_period=0)
|
||||
runner.add_trial(Trial("__fake", config={"user_checkpoint_freq": 2}))
|
||||
trials = runner.get_trials()
|
||||
|
||||
@@ -604,11 +637,57 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
runner.step() # Process save
|
||||
self.assertTrue(trials[0].has_checkpoint())
|
||||
|
||||
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
|
||||
runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=self.tmpdir)
|
||||
runner2.step() # 5: Start trial and dispatch restore
|
||||
trials2 = runner2.get_trials()
|
||||
self.assertEqual(ray.get(trials2[0].runner.get_info.remote()), 1)
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
@patch("ray.tune.ray_trial_executor.TUNE_RESULT_BUFFER_MIN_TIME_S", 1)
|
||||
@patch("ray.tune.ray_trial_executor.TUNE_RESULT_BUFFER_LENGTH", 8)
|
||||
def testUserCheckpointBuffered(self):
|
||||
def num_checkpoints(trial):
|
||||
return sum(
|
||||
item.startswith("checkpoint_")
|
||||
for item in os.listdir(trial.logdir))
|
||||
|
||||
ray.init(num_cpus=3)
|
||||
runner = TrialRunner(
|
||||
local_checkpoint_dir=self.tmpdir, checkpoint_period=0)
|
||||
runner.add_trial(Trial("__fake", config={"user_checkpoint_freq": 10}))
|
||||
trials = runner.get_trials()
|
||||
|
||||
runner.step() # Start trial, schedule 1-8
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1)
|
||||
self.assertEqual(num_checkpoints(trials[0]), 0)
|
||||
|
||||
runner.step() # Process results 0-8, schedule 9-11 (CP)
|
||||
self.assertEqual(trials[0].last_result.get(TRAINING_ITERATION), 8)
|
||||
self.assertFalse(trials[0].has_checkpoint())
|
||||
self.assertEqual(num_checkpoints(trials[0]), 0)
|
||||
|
||||
runner.step() # Process results 9-11
|
||||
runner.step() # handle CP, schedule 12-19
|
||||
self.assertEqual(trials[0].last_result.get(TRAINING_ITERATION), 11)
|
||||
self.assertTrue(trials[0].has_checkpoint())
|
||||
self.assertEqual(num_checkpoints(trials[0]), 1)
|
||||
|
||||
runner.step() # Process results 12-19, schedule 20-21
|
||||
self.assertEqual(trials[0].last_result.get(TRAINING_ITERATION), 19)
|
||||
self.assertTrue(trials[0].has_checkpoint())
|
||||
self.assertEqual(num_checkpoints(trials[0]), 1)
|
||||
|
||||
runner.step() # Process results 20-21
|
||||
runner.step() # handle CP, schedule 21-29
|
||||
self.assertEqual(trials[0].last_result.get(TRAINING_ITERATION), 21)
|
||||
self.assertTrue(trials[0].has_checkpoint())
|
||||
self.assertEqual(num_checkpoints(trials[0]), 2)
|
||||
|
||||
runner.step() # Process results 21-29, schedule 30-31
|
||||
self.assertEqual(trials[0].last_result.get(TRAINING_ITERATION), 29)
|
||||
self.assertTrue(trials[0].has_checkpoint())
|
||||
self.assertTrue(trials[0].has_checkpoint())
|
||||
self.assertEqual(num_checkpoints(trials[0]), 2)
|
||||
|
||||
|
||||
class SearchAlgorithmTest(unittest.TestCase):
|
||||
|
||||
@@ -62,7 +62,7 @@ class _MockTrialExecutor(RayTrialExecutor):
|
||||
self.failed_trial = None
|
||||
|
||||
def fetch_result(self, trial):
|
||||
return self.results.get(trial, {})
|
||||
return [self.results.get(trial, {})]
|
||||
|
||||
def get_next_available_trial(self):
|
||||
return self.next_trial or super().get_next_available_trial()
|
||||
|
||||
@@ -18,9 +18,10 @@ import uuid
|
||||
import ray
|
||||
from ray.util.debug import log_once
|
||||
from ray.tune.result import (
|
||||
DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S, TIMESTEPS_THIS_ITER, DONE,
|
||||
TIMESTEPS_TOTAL, EPISODES_THIS_ITER, EPISODES_TOTAL, TRAINING_ITERATION,
|
||||
RESULT_DUPLICATE, TRIAL_INFO, STDOUT_FILE, STDERR_FILE)
|
||||
DEFAULT_RESULTS_DIR, SHOULD_CHECKPOINT, TIME_THIS_ITER_S,
|
||||
TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL, EPISODES_THIS_ITER,
|
||||
EPISODES_TOTAL, TRAINING_ITERATION, RESULT_DUPLICATE, TRIAL_INFO,
|
||||
STDOUT_FILE, STDERR_FILE)
|
||||
from ray.tune.utils import UtilMonitor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -139,6 +140,48 @@ class Trainable:
|
||||
self._local_ip = ray.services.get_node_ip_address()
|
||||
return self._local_ip
|
||||
|
||||
def train_buffered(self,
|
||||
buffer_time_s: float,
|
||||
max_buffer_length: int = 1000):
|
||||
"""Runs multiple iterations of training.
|
||||
|
||||
Calls ``train()`` internally. Collects and combines multiple results.
|
||||
This function will run ``self.train()`` repeatedly until one of
|
||||
the following conditions is met: 1) the maximum buffer length is
|
||||
reached, 2) the maximum buffer time is reached, or 3) a checkpoint
|
||||
was created. Even if the maximum time is reached, it will always
|
||||
block until at least one result is received.
|
||||
|
||||
Args:
|
||||
buffer_time_s (float): Maximum time to buffer. The next result
|
||||
received after this amount of time has passed will return
|
||||
the whole buffer.
|
||||
max_buffer_length (int): Maximum number of results to buffer.
|
||||
|
||||
"""
|
||||
results = []
|
||||
|
||||
now = time.time()
|
||||
send_buffer_at = now + buffer_time_s
|
||||
while now < send_buffer_at or not results: # At least one result
|
||||
result = self.train()
|
||||
results.append(result)
|
||||
if result.get(DONE, False):
|
||||
# If the trial is done, return
|
||||
break
|
||||
elif result.get(SHOULD_CHECKPOINT, False):
|
||||
# If a checkpoint was created, return
|
||||
break
|
||||
elif result.get(RESULT_DUPLICATE):
|
||||
# If the function API trainable completed, return
|
||||
break
|
||||
elif len(results) >= max_buffer_length:
|
||||
# If the buffer is full, return
|
||||
break
|
||||
now = time.time()
|
||||
|
||||
return results
|
||||
|
||||
def train(self):
|
||||
"""Runs one logical iteration of training.
|
||||
|
||||
|
||||
+148
-86
@@ -141,6 +141,7 @@ class TrialRunner:
|
||||
|
||||
self._trials = []
|
||||
self._cached_trial_decisions = {}
|
||||
self._queued_trial_decisions = {}
|
||||
self._stop_queue = []
|
||||
self._should_stop_experiment = False # used by TuneServer
|
||||
self._local_checkpoint_dir = local_checkpoint_dir
|
||||
@@ -364,7 +365,15 @@ class TrialRunner:
|
||||
self._stop_experiment_if_needed()
|
||||
|
||||
try:
|
||||
with warn_if_slow("experiment_checkpoint"):
|
||||
with warn_if_slow(
|
||||
"experiment_checkpoint",
|
||||
message="Checkpointing the experiment state took "
|
||||
"{duration:.3f} s, which may be a performance "
|
||||
"bottleneck. Please ensure the "
|
||||
"`TUNE_GLOBAL_CHECKPOINT_S` environment variable is "
|
||||
"something significantly higher than this duration "
|
||||
"to ensure compute time is mostly spent on the main "
|
||||
"training loop."):
|
||||
self.checkpoint()
|
||||
except Exception as e:
|
||||
logger.warning(f"Trial Runner checkpointing failed: {str(e)}")
|
||||
@@ -467,6 +476,7 @@ class TrialRunner:
|
||||
# TODO(ujvl): Consider combining get_next_available_trial and
|
||||
# fetch_result functionality so that we don't timeout on fetch.
|
||||
trial = self.trial_executor.get_next_available_trial() # blocking
|
||||
|
||||
if trial.is_restoring:
|
||||
with warn_if_slow("process_trial_restore"):
|
||||
self._process_trial_restore(trial)
|
||||
@@ -503,6 +513,13 @@ class TrialRunner:
|
||||
with warn_if_slow("process_trial"):
|
||||
self._process_trial(trial)
|
||||
|
||||
# `self._queued_trial_decisions` now contains a final decision
|
||||
# based on all results
|
||||
final_decision = self._queued_trial_decisions.pop(
|
||||
trial.trial_id, None)
|
||||
if final_decision:
|
||||
self._execute_action(trial, final_decision)
|
||||
|
||||
def _process_trial(self, trial):
|
||||
"""Processes a trial result.
|
||||
|
||||
@@ -512,92 +529,38 @@ class TrialRunner:
|
||||
processed (see `_process_trial_save`). Otherwise the decision is
|
||||
acted on immediately.
|
||||
|
||||
If multiple results are received (e.g. because of buffering), all
|
||||
results are processed and the final action is determined. STOP
|
||||
takes precedence over PAUSE, which takes precedence over CONTINUE.
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial with a result ready to be processed.
|
||||
"""
|
||||
try:
|
||||
result = self.trial_executor.fetch_result(trial)
|
||||
result.update(trial_id=trial.trial_id)
|
||||
is_duplicate = RESULT_DUPLICATE in result
|
||||
force_checkpoint = result.get(SHOULD_CHECKPOINT, False)
|
||||
# TrialScheduler and SearchAlgorithm still receive a
|
||||
# notification because there may be special handling for
|
||||
# the `on_trial_complete` hook.
|
||||
if is_duplicate:
|
||||
logger.debug("Trial finished without logging 'done'.")
|
||||
result = trial.last_result
|
||||
result.update(done=True)
|
||||
|
||||
self._validate_result_metrics(result)
|
||||
self._total_time += result.get(TIME_THIS_ITER_S, 0)
|
||||
|
||||
flat_result = flatten_dict(result)
|
||||
if self._stopper(trial.trial_id,
|
||||
result) or trial.should_stop(flat_result):
|
||||
result.update(done=True)
|
||||
|
||||
# Hook into scheduler
|
||||
self._scheduler_alg.on_trial_complete(self, trial, flat_result)
|
||||
self._search_alg.on_trial_complete(
|
||||
trial.trial_id, result=flat_result)
|
||||
|
||||
# If this is not a duplicate result, the callbacks should
|
||||
# be informed about the result.
|
||||
if not is_duplicate:
|
||||
with warn_if_slow("callbacks.on_trial_result"):
|
||||
self._callbacks.on_trial_result(
|
||||
iteration=self._iteration,
|
||||
trials=self._trials,
|
||||
trial=trial,
|
||||
result=result.copy())
|
||||
|
||||
self._callbacks.on_trial_complete(
|
||||
iteration=self._iteration,
|
||||
trials=self._trials,
|
||||
trial=trial)
|
||||
decision = TrialScheduler.STOP
|
||||
else:
|
||||
with warn_if_slow("scheduler.on_trial_result"):
|
||||
decision = self._scheduler_alg.on_trial_result(
|
||||
self, trial, flat_result)
|
||||
if decision == TrialScheduler.STOP:
|
||||
result.update(done=True)
|
||||
with warn_if_slow("search_alg.on_trial_result"):
|
||||
self._search_alg.on_trial_result(trial.trial_id,
|
||||
flat_result)
|
||||
with warn_if_slow("callbacks.on_trial_result"):
|
||||
self._callbacks.on_trial_result(
|
||||
iteration=self._iteration,
|
||||
trials=self._trials,
|
||||
trial=trial,
|
||||
result=result.copy())
|
||||
if decision == TrialScheduler.STOP:
|
||||
with warn_if_slow("search_alg.on_trial_complete"):
|
||||
self._search_alg.on_trial_complete(
|
||||
trial.trial_id, result=flat_result)
|
||||
with warn_if_slow("callbacks.on_trial_complete"):
|
||||
self._callbacks.on_trial_complete(
|
||||
iteration=self._iteration,
|
||||
trials=self._trials,
|
||||
trial=trial)
|
||||
|
||||
if not is_duplicate:
|
||||
trial.update_last_result(
|
||||
result, terminate=(decision == TrialScheduler.STOP))
|
||||
|
||||
# Checkpoints to disk. This should be checked even if
|
||||
# the scheduler decision is STOP or PAUSE. Note that
|
||||
# PAUSE only checkpoints to memory and does not update
|
||||
# the global checkpoint state.
|
||||
self._checkpoint_trial_if_needed(trial, force=force_checkpoint)
|
||||
|
||||
if trial.is_saving:
|
||||
# Cache decision to execute on after the save is processed.
|
||||
# This prevents changing the trial's state or kicking off
|
||||
# another training step prematurely.
|
||||
self._cached_trial_decisions[trial.trial_id] = decision
|
||||
else:
|
||||
self._execute_action(trial, decision)
|
||||
results = self.trial_executor.fetch_result(trial)
|
||||
with warn_if_slow(
|
||||
"process_trial_results",
|
||||
message="Processing trial results took {duration:.3f} s, "
|
||||
"which may be a performance bottleneck. Please consider "
|
||||
"reporting results less frequently to Ray Tune."):
|
||||
for i, result in enumerate(results):
|
||||
with warn_if_slow("process_trial_result"):
|
||||
decision = self._process_trial_result(trial, result)
|
||||
if decision is None:
|
||||
# If we didn't get a decision, this means a
|
||||
# non-training future (e.g. a save) was scheduled.
|
||||
# We do not allow processing more results then.
|
||||
if i < len(results) - 1:
|
||||
raise RuntimeError(
|
||||
f"Trial {trial} has a non-training future "
|
||||
f"scheduled but {len(results)-i} results "
|
||||
f"left to process. This should never "
|
||||
f"happen - please file an issue at "
|
||||
f"https://github.com/ray-project/ray/issues")
|
||||
elif decision == TrialScheduler.STOP:
|
||||
# If the decision is to stop the trial,
|
||||
# ignore all results that came after that.
|
||||
break
|
||||
except Exception:
|
||||
error_msg = "Trial %s: Error processing event." % trial
|
||||
if self._fail_fast == TrialRunner.RAISE:
|
||||
@@ -607,6 +570,88 @@ class TrialRunner:
|
||||
logger.exception(error_msg)
|
||||
self._process_trial_failure(trial, traceback.format_exc())
|
||||
|
||||
def _process_trial_result(self, trial, result):
|
||||
result.update(trial_id=trial.trial_id)
|
||||
is_duplicate = RESULT_DUPLICATE in result
|
||||
force_checkpoint = result.get(SHOULD_CHECKPOINT, False)
|
||||
# TrialScheduler and SearchAlgorithm still receive a
|
||||
# notification because there may be special handling for
|
||||
# the `on_trial_complete` hook.
|
||||
if is_duplicate:
|
||||
logger.debug("Trial finished without logging 'done'.")
|
||||
result = trial.last_result
|
||||
result.update(done=True)
|
||||
|
||||
self._validate_result_metrics(result)
|
||||
self._total_time += result.get(TIME_THIS_ITER_S, 0)
|
||||
|
||||
flat_result = flatten_dict(result)
|
||||
if self._stopper(trial.trial_id,
|
||||
result) or trial.should_stop(flat_result):
|
||||
result.update(done=True)
|
||||
|
||||
# Hook into scheduler
|
||||
self._scheduler_alg.on_trial_complete(self, trial, flat_result)
|
||||
self._search_alg.on_trial_complete(
|
||||
trial.trial_id, result=flat_result)
|
||||
|
||||
# If this is not a duplicate result, the callbacks should
|
||||
# be informed about the result.
|
||||
if not is_duplicate:
|
||||
with warn_if_slow("callbacks.on_trial_result"):
|
||||
self._callbacks.on_trial_result(
|
||||
iteration=self._iteration,
|
||||
trials=self._trials,
|
||||
trial=trial,
|
||||
result=result.copy())
|
||||
|
||||
self._callbacks.on_trial_complete(
|
||||
iteration=self._iteration, trials=self._trials, trial=trial)
|
||||
decision = TrialScheduler.STOP
|
||||
else:
|
||||
with warn_if_slow("scheduler.on_trial_result"):
|
||||
decision = self._scheduler_alg.on_trial_result(
|
||||
self, trial, flat_result)
|
||||
if decision == TrialScheduler.STOP:
|
||||
result.update(done=True)
|
||||
with warn_if_slow("search_alg.on_trial_result"):
|
||||
self._search_alg.on_trial_result(trial.trial_id, flat_result)
|
||||
with warn_if_slow("callbacks.on_trial_result"):
|
||||
self._callbacks.on_trial_result(
|
||||
iteration=self._iteration,
|
||||
trials=self._trials,
|
||||
trial=trial,
|
||||
result=result.copy())
|
||||
if decision == TrialScheduler.STOP:
|
||||
with warn_if_slow("search_alg.on_trial_complete"):
|
||||
self._search_alg.on_trial_complete(
|
||||
trial.trial_id, result=flat_result)
|
||||
with warn_if_slow("callbacks.on_trial_complete"):
|
||||
self._callbacks.on_trial_complete(
|
||||
iteration=self._iteration,
|
||||
trials=self._trials,
|
||||
trial=trial)
|
||||
|
||||
if not is_duplicate:
|
||||
trial.update_last_result(
|
||||
result, terminate=(decision == TrialScheduler.STOP))
|
||||
|
||||
# Checkpoints to disk. This should be checked even if
|
||||
# the scheduler decision is STOP or PAUSE. Note that
|
||||
# PAUSE only checkpoints to memory and does not update
|
||||
# the global checkpoint state.
|
||||
self._checkpoint_trial_if_needed(trial, force=force_checkpoint)
|
||||
|
||||
if trial.is_saving:
|
||||
# Cache decision to execute on after the save is processed.
|
||||
# This prevents changing the trial's state or kicking off
|
||||
# another training step prematurely.
|
||||
self._cached_trial_decisions[trial.trial_id] = decision
|
||||
return None
|
||||
else:
|
||||
self._queue_decision(trial, decision)
|
||||
return decision
|
||||
|
||||
def _validate_result_metrics(self, result):
|
||||
"""
|
||||
Check if any of the required metrics was not reported
|
||||
@@ -669,7 +714,8 @@ class TrialRunner:
|
||||
checkpoint_value = None
|
||||
|
||||
try:
|
||||
checkpoint_value = self.trial_executor.fetch_result(trial)
|
||||
results = self.trial_executor.fetch_result(trial)
|
||||
checkpoint_value = results[-1]
|
||||
except Exception:
|
||||
logger.exception("Trial %s: Error processing result.", trial)
|
||||
if self._fail_fast == TrialRunner.RAISE:
|
||||
@@ -695,7 +741,7 @@ class TrialRunner:
|
||||
trial.saving_to = None
|
||||
decision = self._cached_trial_decisions.pop(trial.trial_id, None)
|
||||
if decision and checkpoint_value:
|
||||
self._execute_action(trial, decision)
|
||||
self._queue_decision(trial, decision)
|
||||
|
||||
def _process_trial_restore(self, trial):
|
||||
"""Processes a trial restore.
|
||||
@@ -739,6 +785,21 @@ class TrialRunner:
|
||||
self.trial_executor.stop_trial(
|
||||
trial, error=True, error_msg=error_msg)
|
||||
|
||||
def _queue_decision(self, trial, decision):
|
||||
# Get old decision, setting it to the current decision if it isn't set
|
||||
old_decision = self._queued_trial_decisions.setdefault(
|
||||
trial.trial_id, decision)
|
||||
|
||||
# Stopping always takes precedence. If we decided to stop, just quit
|
||||
if old_decision is TrialScheduler.STOP:
|
||||
return
|
||||
|
||||
# The old decision wasn't STOP. We update the decision only if it is
|
||||
# STOP or PAUSE. The action will only be CONTINUE if it was set by
|
||||
# the first received result and was never updated after that.
|
||||
if decision is TrialScheduler.STOP or decision is TrialScheduler.PAUSE:
|
||||
self._queued_trial_decisions[trial.trial_id] = decision
|
||||
|
||||
def _execute_action(self, trial, decision):
|
||||
"""Executes action based on decision.
|
||||
|
||||
@@ -878,7 +939,8 @@ class TrialRunner:
|
||||
iteration=self._iteration, trials=self._trials, trial=trial)
|
||||
elif trial.status is Trial.RUNNING:
|
||||
try:
|
||||
result = self.trial_executor.fetch_result(trial)
|
||||
results = self.trial_executor.fetch_result(trial)
|
||||
result = results[-1]
|
||||
trial.update_last_result(result, terminate=True)
|
||||
self._scheduler_alg.on_trial_complete(self, trial, result)
|
||||
self._search_alg.on_trial_complete(
|
||||
|
||||
@@ -11,6 +11,7 @@ from collections import defaultdict, deque
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from threading import Thread
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import ray
|
||||
@@ -124,10 +125,16 @@ class warn_if_slow:
|
||||
"""
|
||||
|
||||
DEFAULT_THRESHOLD = float(os.environ.get("TUNE_WARN_THRESHOLD_S", 0.5))
|
||||
DEFAULT_MESSAGE = "The `{name}` operation took {duration:.3f} s, " \
|
||||
"which may be a performance bottleneck."
|
||||
|
||||
def __init__(self, name, threshold=None):
|
||||
def __init__(self,
|
||||
name: str,
|
||||
threshold: Optional[float] = None,
|
||||
message: Optional[str] = None):
|
||||
self.name = name
|
||||
self.threshold = threshold or self.DEFAULT_THRESHOLD
|
||||
self.message = message or self.DEFAULT_MESSAGE
|
||||
self.too_slow = False
|
||||
|
||||
def __enter__(self):
|
||||
@@ -138,10 +145,9 @@ class warn_if_slow:
|
||||
now = time.time()
|
||||
if now - self.start > self.threshold and now - START_OF_TIME > 60.0:
|
||||
self.too_slow = True
|
||||
_duration = now - self.start
|
||||
duration = now - self.start
|
||||
logger.warning(
|
||||
f"The `{self.name}` operation took {_duration:.3f} s, "
|
||||
"which may be a performance bottleneck.")
|
||||
self.message.format(name=self.name, duration=duration))
|
||||
|
||||
|
||||
class Tee(object):
|
||||
|
||||
Reference in New Issue
Block a user