[tune] Fix Pausing and Error Propagation (#2815)

* add new tests

* Try-catch errors from ray get

* longer pbt run

* Update pbt_example.py

* Split trial and result and fix tests
This commit is contained in:
Richard Liaw
2018-09-04 15:22:11 -07:00
committed by Eric Liang
parent dfb7c2be1e
commit 72542c9016
5 changed files with 75 additions and 34 deletions
+5 -2
View File
@@ -59,7 +59,10 @@ if __name__ == "__main__":
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()
ray.init()
if args.smoke_test:
ray.init(num_cpus=4) # force pausing to happen for test
else:
ray.init()
pbt = PopulationBasedTraining(
time_attr="training_iteration",
@@ -79,7 +82,7 @@ if __name__ == "__main__":
"pbt_test": {
"run": MyTrainableClass,
"stop": {
"training_iteration": 2 if args.smoke_test else 99999
"training_iteration": 20 if args.smoke_test else 99999
},
"num_samples": 10,
"config": {
+29 -19
View File
@@ -6,6 +6,7 @@ from __future__ import print_function
import os
import time
import traceback
import ray
from ray.tune.logger import NoopLogger
from ray.tune.trial import Trial, Resources, Checkpoint
@@ -17,7 +18,7 @@ class RayTrialExecutor(TrialExecutor):
def __init__(self, queue_trials=False):
super(RayTrialExecutor, self).__init__(queue_trials)
self._running = {} # TODO
self._running = {}
# When a paused trial is resumed, trial.train.remote() should not be run
# again, so no new remote object id is generated for it.
# self._paused is used to store the paused trials for that purpose.
@@ -58,11 +59,12 @@ class RayTrialExecutor(TrialExecutor):
trial.runner = self._setup_runner(trial)
if not self.restore(trial, checkpoint):
return
if prior_status == Trial.PAUSED:
# If prev status is PAUSED, self._paused stores its remote_id.
remote_id = self._find_item(self._paused, trial)[0]
self._paused.pop(remote_id)
self._running[remote_id] = trial
previous_run = self._find_item(self._paused, trial)
if (prior_status == Trial.PAUSED and previous_run):
# If Trial was in flight when paused, self._paused stores result.
self._paused.pop(previous_run[0])
self._running[previous_run[0]] = trial
else:
self._train(trial)
@@ -144,10 +146,15 @@ class RayTrialExecutor(TrialExecutor):
self._train(trial)
def pause_trial(self, trial):
"""Pauses the trial."""
"""Pauses the trial.
remote_id = self._find_item(self._running, trial)[0]
self._paused[remote_id] = trial
If trial is in-flight, preserves return value in separate queue
before pausing, which is restored when Trial is resumed.
"""
trial_future = self._find_item(self._running, trial)
if trial_future:
self._paused[trial_future[0]] = trial
super(RayTrialExecutor, self).pause_trial(trial)
def get_running_trials(self):
@@ -155,18 +162,21 @@ class RayTrialExecutor(TrialExecutor):
return list(self._running.values())
def fetch_one_result(self):
"""Fetches one result of the running trials."""
def get_next_available_trial(self):
[result_id], _ = ray.wait(list(self._running))
trial = self._running.pop(result_id)
result = None
try:
result = ray.get(result_id)
except Exception:
print("fetch_one_result failed:", traceback.format_exc())
return self._running[result_id]
return trial, result
def fetch_result(self, trial):
    """Fetches the pending result of the given trial.

    Looks up the entry tracked for ``trial`` in ``self._running``,
    removes it, and resolves it with ``ray.get``.

    Returns:
        Result of the most recent trial training run.

    Raises:
        ValueError: If no entry for the trial is found in
            ``self._running``.
    """
    matches = self._find_item(self._running, trial)
    if not matches:
        raise ValueError("Trial was not running.")
    result_id = matches[0]
    # The entry is dropped before it is resolved, so it is no longer in
    # self._running even when ray.get raises.
    self._running.pop(result_id)
    return ray.get(result_id)
def _commit_resources(self, resources):
self._committed_resources = Resources(
@@ -51,6 +51,31 @@ class RayTrialExecutorTest(unittest.TestCase):
self.trial_executor.stop_trial(trial)
self.assertEqual(Trial.TERMINATED, trial.status)
def testPauseResume(self):
    """Tests that pausing works for trials in flight."""
    executor = self.trial_executor
    fake_trial = Trial("__fake")

    # Start, then pause right away without fetching a result.
    executor.start_trial(fake_trial)
    self.assertEqual(Trial.RUNNING, fake_trial.status)
    executor.pause_trial(fake_trial)
    self.assertEqual(Trial.PAUSED, fake_trial.status)

    # Resume the paused trial and shut it down.
    executor.start_trial(fake_trial)
    self.assertEqual(Trial.RUNNING, fake_trial.status)
    executor.stop_trial(fake_trial)
    self.assertEqual(Trial.TERMINATED, fake_trial.status)
def testPauseResume2(self):
    """Tests that pausing works for trials being processed."""
    executor = self.trial_executor
    fake_trial = Trial("__fake")

    executor.start_trial(fake_trial)
    self.assertEqual(Trial.RUNNING, fake_trial.status)

    # Fetch the result first, then pause.
    executor.fetch_result(fake_trial)
    executor.pause_trial(fake_trial)
    self.assertEqual(Trial.PAUSED, fake_trial.status)

    # Resume the paused trial and shut it down.
    executor.start_trial(fake_trial)
    self.assertEqual(Trial.RUNNING, fake_trial.status)
    executor.stop_trial(fake_trial)
    self.assertEqual(Trial.TERMINATED, fake_trial.status)
def generate_trials(self, spec, name):
    """Returns next_trials() of a BasicVariantGenerator for {name: spec}."""
    return BasicVariantGenerator({name: spec}).next_trials()
+13 -7
View File
@@ -119,17 +119,23 @@ class TrialExecutor(object):
"""A hook called after running one step of the trial event loop."""
pass
def fetch_one_result(self):
"""Fetches one result from running trials.
def get_next_available_trial(self):
"""Blocking call that waits until one result is ready.
It's a blocking call waits until one result is ready.
Returns:
Trial object that is ready for intermediate processing.
"""
raise NotImplementedError
def fetch_result(self, trial):
"""Fetches one result for the trial.
Assumes the trial is running.
Returns:
A tuple of (trial, result). If fetch result failed,
return (trial, None) other than raise Exception.
Result object for the trial.
"""
raise NotImplementedError("Subclasses of TrialExecutor must provide "
"fetch_one_result() method")
raise NotImplementedError
def debug_string(self):
"""Returns a human readable message for printing to the console."""
+3 -6
View File
@@ -211,10 +211,9 @@ class TrialRunner(object):
return trial
def _process_events(self):
trial, result = self.trial_executor.fetch_one_result()
trial = self.trial_executor.get_next_available_trial()
try:
if result is None:
raise ValueError("fetch_one_result failed")
result = self.trial_executor.fetch_result(trial)
self._total_time += result[TIME_THIS_ITER_S]
if trial.should_stop(result):
@@ -323,9 +322,7 @@ class TrialRunner(object):
trial.trial_id, early_terminated=True)
elif trial.status is Trial.RUNNING:
try:
_, result = self.trial_executor.fetch_one_result()
if result is None:
raise ValueError("fetch_one_result failed")
result = self.trial_executor.fetch_result(trial)
trial.update_last_result(result, terminate=True)
self._scheduler_alg.on_trial_complete(self, trial, result)
self._search_alg.on_trial_complete(