mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
[tune] Fix Pausing and Error Propogation (#2815)
* add new tests * Try-catch errors from ray get * longer pbt run * Update pbt_example.py * Split trial and result and fix tests
This commit is contained in:
@@ -59,7 +59,10 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init()
|
||||
if args.smoke_test:
|
||||
ray.init(num_cpus=4) # force pausing to happen for test
|
||||
else:
|
||||
ray.init()
|
||||
|
||||
pbt = PopulationBasedTraining(
|
||||
time_attr="training_iteration",
|
||||
@@ -79,7 +82,7 @@ if __name__ == "__main__":
|
||||
"pbt_test": {
|
||||
"run": MyTrainableClass,
|
||||
"stop": {
|
||||
"training_iteration": 2 if args.smoke_test else 99999
|
||||
"training_iteration": 20 if args.smoke_test else 99999
|
||||
},
|
||||
"num_samples": 10,
|
||||
"config": {
|
||||
|
||||
@@ -6,6 +6,7 @@ from __future__ import print_function
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import ray
|
||||
from ray.tune.logger import NoopLogger
|
||||
from ray.tune.trial import Trial, Resources, Checkpoint
|
||||
@@ -17,7 +18,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
|
||||
def __init__(self, queue_trials=False):
|
||||
super(RayTrialExecutor, self).__init__(queue_trials)
|
||||
self._running = {} # TODO
|
||||
self._running = {}
|
||||
# Since trial resume after paused should not run
|
||||
# trial.train.remote(), thus no more new remote object id generated.
|
||||
# We use self._paused to store paused trials here.
|
||||
@@ -58,11 +59,12 @@ class RayTrialExecutor(TrialExecutor):
|
||||
trial.runner = self._setup_runner(trial)
|
||||
if not self.restore(trial, checkpoint):
|
||||
return
|
||||
if prior_status == Trial.PAUSED:
|
||||
# If prev status is PAUSED, self._paused stores its remote_id.
|
||||
remote_id = self._find_item(self._paused, trial)[0]
|
||||
self._paused.pop(remote_id)
|
||||
self._running[remote_id] = trial
|
||||
|
||||
previous_run = self._find_item(self._paused, trial)
|
||||
if (prior_status == Trial.PAUSED and previous_run):
|
||||
# If Trial was in flight when paused, self._paused stores result.
|
||||
self._paused.pop(previous_run[0])
|
||||
self._running[previous_run[0]] = trial
|
||||
else:
|
||||
self._train(trial)
|
||||
|
||||
@@ -144,10 +146,15 @@ class RayTrialExecutor(TrialExecutor):
|
||||
self._train(trial)
|
||||
|
||||
def pause_trial(self, trial):
|
||||
"""Pauses the trial."""
|
||||
"""Pauses the trial.
|
||||
|
||||
remote_id = self._find_item(self._running, trial)[0]
|
||||
self._paused[remote_id] = trial
|
||||
If trial is in-flight, preserves return value in separate queue
|
||||
before pausing, which is restored when Trial is resumed.
|
||||
"""
|
||||
|
||||
trial_future = self._find_item(self._running, trial)
|
||||
if trial_future:
|
||||
self._paused[trial_future[0]] = trial
|
||||
super(RayTrialExecutor, self).pause_trial(trial)
|
||||
|
||||
def get_running_trials(self):
|
||||
@@ -155,18 +162,21 @@ class RayTrialExecutor(TrialExecutor):
|
||||
|
||||
return list(self._running.values())
|
||||
|
||||
def fetch_one_result(self):
|
||||
"""Fetches one result of the running trials."""
|
||||
|
||||
def get_next_available_trial(self):
|
||||
[result_id], _ = ray.wait(list(self._running))
|
||||
trial = self._running.pop(result_id)
|
||||
result = None
|
||||
try:
|
||||
result = ray.get(result_id)
|
||||
except Exception:
|
||||
print("fetch_one_result failed:", traceback.format_exc())
|
||||
return self._running[result_id]
|
||||
|
||||
return trial, result
|
||||
def fetch_result(self, trial):
|
||||
"""Fetches one result of the running trials.
|
||||
|
||||
Returns:
|
||||
Result of the most recent trial training run."""
|
||||
trial_future = self._find_item(self._running, trial)
|
||||
if not trial_future:
|
||||
raise ValueError("Trial was not running.")
|
||||
self._running.pop(trial_future[0])
|
||||
result = ray.get(trial_future[0])
|
||||
return result
|
||||
|
||||
def _commit_resources(self, resources):
|
||||
self._committed_resources = Resources(
|
||||
|
||||
@@ -51,6 +51,31 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||
self.trial_executor.stop_trial(trial)
|
||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||
|
||||
def testPauseResume(self):
|
||||
"""Tests that pausing works for trials in flight."""
|
||||
trial = Trial("__fake")
|
||||
self.trial_executor.start_trial(trial)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
self.trial_executor.pause_trial(trial)
|
||||
self.assertEqual(Trial.PAUSED, trial.status)
|
||||
self.trial_executor.start_trial(trial)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
self.trial_executor.stop_trial(trial)
|
||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||
|
||||
def testPauseResume2(self):
|
||||
"""Tests that pausing works for trials being processed."""
|
||||
trial = Trial("__fake")
|
||||
self.trial_executor.start_trial(trial)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
self.trial_executor.fetch_result(trial)
|
||||
self.trial_executor.pause_trial(trial)
|
||||
self.assertEqual(Trial.PAUSED, trial.status)
|
||||
self.trial_executor.start_trial(trial)
|
||||
self.assertEqual(Trial.RUNNING, trial.status)
|
||||
self.trial_executor.stop_trial(trial)
|
||||
self.assertEqual(Trial.TERMINATED, trial.status)
|
||||
|
||||
def generate_trials(self, spec, name):
|
||||
suggester = BasicVariantGenerator({name: spec})
|
||||
return suggester.next_trials()
|
||||
|
||||
@@ -119,17 +119,23 @@ class TrialExecutor(object):
|
||||
"""A hook called after running one step of the trial event loop."""
|
||||
pass
|
||||
|
||||
def fetch_one_result(self):
|
||||
"""Fetches one result from running trials.
|
||||
def get_next_available_trial(self):
|
||||
"""Blocking call that waits until one result is ready.
|
||||
|
||||
It's a blocking call waits until one result is ready.
|
||||
Returns:
|
||||
Trial object that is ready for intermediate processing.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def fetch_result(self, trial):
|
||||
"""Fetches one result for the trial.
|
||||
|
||||
Assumes the trial is running.
|
||||
|
||||
Return:
|
||||
A tuple of (trial, result). If fetch result failed,
|
||||
return (trial, None) other than raise Exception.
|
||||
Result object for the trial.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses of TrialExecutor must provide "
|
||||
"fetch_one_result() method")
|
||||
raise NotImplementedError
|
||||
|
||||
def debug_string(self):
|
||||
"""Returns a human readable message for printing to the console."""
|
||||
|
||||
@@ -211,10 +211,9 @@ class TrialRunner(object):
|
||||
return trial
|
||||
|
||||
def _process_events(self):
|
||||
trial, result = self.trial_executor.fetch_one_result()
|
||||
trial = self.trial_executor.get_next_available_trial()
|
||||
try:
|
||||
if result is None:
|
||||
raise ValueError("fetch_one_result failed")
|
||||
result = self.trial_executor.fetch_result(trial)
|
||||
self._total_time += result[TIME_THIS_ITER_S]
|
||||
|
||||
if trial.should_stop(result):
|
||||
@@ -323,9 +322,7 @@ class TrialRunner(object):
|
||||
trial.trial_id, early_terminated=True)
|
||||
elif trial.status is Trial.RUNNING:
|
||||
try:
|
||||
_, result = self.trial_executor.fetch_one_result()
|
||||
if result is None:
|
||||
raise ValueError("fetch_one_result failed")
|
||||
result = self.trial_executor.fetch_result(trial)
|
||||
trial.update_last_result(result, terminate=True)
|
||||
self._scheduler_alg.on_trial_complete(self, trial, result)
|
||||
self._search_alg.on_trial_complete(
|
||||
|
||||
Reference in New Issue
Block a user