diff --git a/python/ray/tune/examples/pbt_example.py b/python/ray/tune/examples/pbt_example.py index 56557ee58..b6532144f 100755 --- a/python/ray/tune/examples/pbt_example.py +++ b/python/ray/tune/examples/pbt_example.py @@ -59,7 +59,10 @@ if __name__ == "__main__": parser.add_argument( "--smoke-test", action="store_true", help="Finish quickly for testing") args, _ = parser.parse_known_args() - ray.init() + if args.smoke_test: + ray.init(num_cpus=4) # force pausing to happen for test + else: + ray.init() pbt = PopulationBasedTraining( time_attr="training_iteration", @@ -79,7 +82,7 @@ if __name__ == "__main__": "pbt_test": { "run": MyTrainableClass, "stop": { - "training_iteration": 2 if args.smoke_test else 99999 + "training_iteration": 20 if args.smoke_test else 99999 }, "num_samples": 10, "config": { diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index d53d928c5..1d25fc2b9 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -6,6 +6,7 @@ from __future__ import print_function import os import time import traceback + import ray from ray.tune.logger import NoopLogger from ray.tune.trial import Trial, Resources, Checkpoint @@ -17,7 +18,7 @@ class RayTrialExecutor(TrialExecutor): def __init__(self, queue_trials=False): super(RayTrialExecutor, self).__init__(queue_trials) - self._running = {} # TODO + self._running = {} # Since trial resume after paused should not run # trial.train.remote(), thus no more new remote object id generated. # We use self._paused to store paused trials here. @@ -58,11 +59,12 @@ class RayTrialExecutor(TrialExecutor): trial.runner = self._setup_runner(trial) if not self.restore(trial, checkpoint): return - if prior_status == Trial.PAUSED: - # If prev status is PAUSED, self._paused stores its remote_id. 
- remote_id = self._find_item(self._paused, trial)[0] - self._paused.pop(remote_id) - self._running[remote_id] = trial + + previous_run = self._find_item(self._paused, trial) + if (prior_status == Trial.PAUSED and previous_run): + # If Trial was in flight when paused, self._paused stores result. + self._paused.pop(previous_run[0]) + self._running[previous_run[0]] = trial else: self._train(trial) @@ -144,10 +146,15 @@ class RayTrialExecutor(TrialExecutor): self._train(trial) def pause_trial(self, trial): - """Pauses the trial.""" + """Pauses the trial. - remote_id = self._find_item(self._running, trial)[0] - self._paused[remote_id] = trial + If trial is in-flight, preserves return value in separate queue + before pausing, which is restored when Trial is resumed. + """ + + trial_future = self._find_item(self._running, trial) + if trial_future: + self._paused[trial_future[0]] = trial super(RayTrialExecutor, self).pause_trial(trial) def get_running_trials(self): @@ -155,18 +162,21 @@ class RayTrialExecutor(TrialExecutor): return list(self._running.values()) - def fetch_one_result(self): - """Fetches one result of the running trials.""" - + def get_next_available_trial(self): [result_id], _ = ray.wait(list(self._running)) - trial = self._running.pop(result_id) - result = None - try: - result = ray.get(result_id) - except Exception: - print("fetch_one_result failed:", traceback.format_exc()) + return self._running[result_id] - return trial, result + def fetch_result(self, trial): + """Fetches the result of the given running trial. 
+ + Returns: + Result of the most recent trial training run.""" + trial_future = self._find_item(self._running, trial) + if not trial_future: + raise ValueError("Trial was not running.") + self._running.pop(trial_future[0]) + result = ray.get(trial_future[0]) + return result def _commit_resources(self, resources): self._committed_resources = Resources( diff --git a/python/ray/tune/test/ray_trial_executor_test.py b/python/ray/tune/test/ray_trial_executor_test.py index b17a28739..ddd5995ea 100644 --- a/python/ray/tune/test/ray_trial_executor_test.py +++ b/python/ray/tune/test/ray_trial_executor_test.py @@ -51,6 +51,31 @@ class RayTrialExecutorTest(unittest.TestCase): self.trial_executor.stop_trial(trial) self.assertEqual(Trial.TERMINATED, trial.status) + def testPauseResume(self): + """Tests that pausing works for trials in flight.""" + trial = Trial("__fake") + self.trial_executor.start_trial(trial) + self.assertEqual(Trial.RUNNING, trial.status) + self.trial_executor.pause_trial(trial) + self.assertEqual(Trial.PAUSED, trial.status) + self.trial_executor.start_trial(trial) + self.assertEqual(Trial.RUNNING, trial.status) + self.trial_executor.stop_trial(trial) + self.assertEqual(Trial.TERMINATED, trial.status) + + def testPauseResume2(self): + """Tests that pausing works for trials being processed.""" + trial = Trial("__fake") + self.trial_executor.start_trial(trial) + self.assertEqual(Trial.RUNNING, trial.status) + self.trial_executor.fetch_result(trial) + self.trial_executor.pause_trial(trial) + self.assertEqual(Trial.PAUSED, trial.status) + self.trial_executor.start_trial(trial) + self.assertEqual(Trial.RUNNING, trial.status) + self.trial_executor.stop_trial(trial) + self.assertEqual(Trial.TERMINATED, trial.status) + def generate_trials(self, spec, name): suggester = BasicVariantGenerator({name: spec}) return suggester.next_trials() diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py index 270869ac9..61bdf45a8 100644 --- 
a/python/ray/tune/trial_executor.py +++ b/python/ray/tune/trial_executor.py @@ -119,17 +119,23 @@ class TrialExecutor(object): """A hook called after running one step of the trial event loop.""" pass - def fetch_one_result(self): - """Fetches one result from running trials. + def get_next_available_trial(self): + """Blocking call that waits until one result is ready. - It's a blocking call waits until one result is ready. + Returns: + Trial object that is ready for intermediate processing. + """ + raise NotImplementedError + + def fetch_result(self, trial): + """Fetches one result for the trial. + + Assumes the trial is running. Return: - A tuple of (trial, result). If fetch result failed, - return (trial, None) other than raise Exception. + Result object for the trial. """ - raise NotImplementedError("Subclasses of TrialExecutor must provide " - "fetch_one_result() method") + raise NotImplementedError def debug_string(self): """Returns a human readable message for printing to the console.""" diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index c4ed36afe..d8759e48d 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -211,10 +211,9 @@ class TrialRunner(object): return trial def _process_events(self): - trial, result = self.trial_executor.fetch_one_result() + trial = self.trial_executor.get_next_available_trial() try: - if result is None: - raise ValueError("fetch_one_result failed") + result = self.trial_executor.fetch_result(trial) self._total_time += result[TIME_THIS_ITER_S] if trial.should_stop(result): @@ -323,9 +322,7 @@ class TrialRunner(object): trial.trial_id, early_terminated=True) elif trial.status is Trial.RUNNING: try: - _, result = self.trial_executor.fetch_one_result() - if result is None: - raise ValueError("fetch_one_result failed") + result = self.trial_executor.fetch_result(trial) trial.update_last_result(result, terminate=True) self._scheduler_alg.on_trial_complete(self, trial, 
result) self._search_alg.on_trial_complete(