From 38867eea4e011115515b77018c2310e7dd74ff55 Mon Sep 17 00:00:00 2001 From: joyyoj Date: Thu, 23 Aug 2018 01:55:45 +0800 Subject: [PATCH] [tune] Cross-Framework Compatibility (#2646) This commit is a first pass at restructuring the Trial execution logic to support running on multiple frameworks. --- .travis.yml | 2 + python/ray/tune/ray_trial_executor.py | 274 ++++++++++++++++++ python/ray/tune/schedulers/hyperband.py | 2 +- python/ray/tune/schedulers/pbt.py | 16 +- .../ray/tune/test/ray_trial_executor_test.py | 60 ++++ python/ray/tune/test/trial_runner_test.py | 53 +++- python/ray/tune/test/trial_scheduler_test.py | 48 +-- python/ray/tune/test/tune_server_test.py | 2 +- python/ray/tune/trial.py | 224 ++++---------- python/ray/tune/trial_executor.py | 166 +++++++++++ python/ray/tune/trial_runner.py | 164 +++-------- python/ray/tune/tune.py | 9 +- 12 files changed, 680 insertions(+), 340 deletions(-) create mode 100644 python/ray/tune/ray_trial_executor.py create mode 100644 python/ray/tune/test/ray_trial_executor_test.py create mode 100644 python/ray/tune/trial_executor.py diff --git a/.travis.yml b/.travis.yml index 682bfb558..4177e6515 100644 --- a/.travis.yml +++ b/.travis.yml @@ -158,6 +158,7 @@ matrix: - python -m pytest python/ray/tune/test/trial_scheduler_test.py - python -m pytest python/ray/tune/test/experiment_test.py - python -m pytest python/ray/tune/test/tune_server_test.py + - python -m pytest python/ray/tune/test/ray_trial_executor_test.py # ray rllib tests - python -m pytest python/ray/rllib/test/test_catalog.py @@ -228,6 +229,7 @@ script: - python -m pytest python/ray/tune/test/trial_scheduler_test.py - python -m pytest python/ray/tune/test/experiment_test.py - python -m pytest python/ray/tune/test/tune_server_test.py + - python -m pytest python/ray/tune/test/ray_trial_executor_test.py # ray rllib tests - python -m pytest python/ray/rllib/test/test_catalog.py diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py new file mode 100644 index 000000000..b364cd2f4 --- /dev/null +++ b/python/ray/tune/ray_trial_executor.py @@ -0,0 +1,274 @@ +# coding: utf-8 +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import time +import traceback +import ray +from ray.tune.logger import NoopLogger +from ray.tune.trial import Trial, Resources, Checkpoint +from ray.tune.trial_executor import TrialExecutor + + +class RayTrialExecutor(TrialExecutor): + """An implemention of TrialExecutor based on Ray.""" + + def __init__(self, queue_trials=False): + super(RayTrialExecutor, self).__init__(queue_trials) + self._running = {} # TODO + # Since trial resume after paused should not run + # trial.train.remote(), thus no more new remote object id generated. + # We use self._paused to store paused trials here. + self._paused = {} + self._avail_resources = Resources(cpu=0, gpu=0) + self._committed_resources = Resources(cpu=0, gpu=0) + self._resources_initialized = False + + def _setup_runner(self, trial): + cls = ray.remote( + num_cpus=trial.resources.cpu, + num_gpus=trial.resources.gpu)(trial._get_trainable_cls()) + + trial.init_logger() + remote_logdir = trial.logdir + + def logger_creator(config): + # Set the working dir in the remote process, for user file writes + if not os.path.exists(remote_logdir): + os.makedirs(remote_logdir) + os.chdir(remote_logdir) + return NoopLogger(config, remote_logdir) + + # Logging for trials is handled centrally by TrialRunner, so + # configure the remote runner to use a noop-logger. + return cls.remote(config=trial.config, logger_creator=logger_creator) + + def _train(self, trial): + """Start one iteration of training and save remote id.""" + + assert trial.status == Trial.RUNNING, trial.status + remote = trial.runner.train.remote() + self._running[remote] = trial + + def _start_trial(self, trial, checkpoint=None): + prior_status = trial.status + trial.status = Trial.RUNNING + trial.runner = self._setup_runner(trial) + if not self.restore(trial, checkpoint): + return + if prior_status == Trial.PAUSED: + # If prev status is PAUSED, self._paused stores its remote_id. + remote_id = self._find_item(self._paused, trial)[0] + self._paused.pop(remote_id) + self._running[remote_id] = trial + else: + self._train(trial) + + def _stop_trial(self, trial, error=False, error_msg=None, + stop_logger=True): + """Stops this trial. + + Stops this trial, releasing all allocating resources. If stopping the + trial fails, the run will be marked as terminated in error, but no + exception will be thrown. + + Args: + error (bool): Whether to mark this trial as terminated in error. + error_msg (str): Optional error message. + stop_logger (bool): Whether to shut down the trial logger. + """ + + if error: + trial.status = Trial.ERROR + else: + trial.status = Trial.TERMINATED + + try: + trial.write_error_log(error_msg) + if hasattr(trial, 'runner') and trial.runner: + stop_tasks = [] + stop_tasks.append(trial.runner.stop.remote()) + stop_tasks.append(trial.runner.__ray_terminate__.remote()) + # TODO(ekl) seems like wait hangs when killing actors + _, unfinished = ray.wait( + stop_tasks, num_returns=2, timeout=250) + except Exception: + print("Error stopping runner:", traceback.format_exc()) + trial.status = Trial.ERROR + finally: + trial.runner = None + + if stop_logger: + trial.close_logger() + + def start_trial(self, trial, checkpoint_obj=None): + """Starts the trial.""" + + self._commit_resources(trial.resources) + try: + self._start_trial(trial, checkpoint_obj) + except Exception: + error_msg = traceback.format_exc() + print("Error starting runner, retrying:", error_msg) + time.sleep(2) + self._stop_trial(trial, error=True, error_msg=error_msg) + try: + self._start_trial(trial) + except Exception: + error_msg = traceback.format_exc() + print("Error starting runner, abort:", error_msg) + self._stop_trial(trial, error=True, error_msg=error_msg) + # note that we don't return the resources, since they may + # have been lost + + def _find_item(self, dictionary, item): + out = [rid for rid, t in dictionary.items() if t is item] + return out + + def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True): + """Only returns resources if resources allocated.""" + prior_status = trial.status + self._stop_trial( + trial, error=error, error_msg=error_msg, stop_logger=stop_logger) + if prior_status == Trial.RUNNING: + self._return_resources(trial.resources) + out = self._find_item(self._running, trial) + for result_id in out: + self._running.pop(result_id) + + def continue_training(self, trial): + """Continues the training of this trial.""" + + self._train(trial) + + def pause_trial(self, trial): + """Pauses the trial.""" + + remote_id = self._find_item(self._running, trial)[0] + self._paused[remote_id] = trial + super(RayTrialExecutor, self).pause_trial(trial) + + def get_running_trials(self): + """Returns the running trials.""" + + return list(self._running.values()) + + def fetch_one_result(self): + """Fetches one result of the running trials.""" + + [result_id], _ = ray.wait(list(self._running)) + trial = self._running.pop(result_id) + result = None + try: + result = ray.get(result_id) + except Exception as e: + print("fetch_one_result failed:", traceback.format_exc()) + + return trial, result + + def _commit_resources(self, resources): + self._committed_resources = Resources( + self._committed_resources.cpu + resources.cpu_total(), + self._committed_resources.gpu + resources.gpu_total()) + + def _return_resources(self, resources): + self._committed_resources = Resources( + self._committed_resources.cpu - resources.cpu_total(), + self._committed_resources.gpu - resources.gpu_total()) + assert self._committed_resources.cpu >= 0 + assert self._committed_resources.gpu >= 0 + + def _update_avail_resources(self): + clients = ray.global_state.client_table() + if ray.worker.global_worker.use_raylet: + # TODO(rliaw): Remove once raylet flag is swapped + num_cpus = sum(cl['Resources']['CPU'] for cl in clients) + num_gpus = sum(cl['Resources'].get('GPU', 0) for cl in clients) + else: + local_schedulers = [ + entry for client in clients.values() for entry in client + if (entry['ClientType'] == 'local_scheduler' + and not entry['Deleted']) + ] + num_cpus = sum(ls['CPU'] for ls in local_schedulers) + num_gpus = sum(ls.get('GPU', 0) for ls in local_schedulers) + self._avail_resources = Resources(int(num_cpus), int(num_gpus)) + self._resources_initialized = True + + def has_resources(self, resources): + """Returns whether this runner has at least the specified resources.""" + + cpu_avail = self._avail_resources.cpu - self._committed_resources.cpu + gpu_avail = self._avail_resources.gpu - self._committed_resources.gpu + + have_space = (resources.cpu_total() <= cpu_avail + and resources.gpu_total() <= gpu_avail) + + if have_space: + return True + + can_overcommit = self._queue_trials + + if (resources.cpu_total() > 0 and cpu_avail <= 0) or \ + (resources.gpu_total() > 0 and gpu_avail <= 0): + can_overcommit = False # requested resource is already saturated + + if can_overcommit: + print("WARNING:tune:allowing trial to start even though the " + "cluster does not have enough free resources. Trial actors " + "may appear to hang until enough resources are added to the " + "cluster (e.g., via autoscaling). You can disable this " + "behavior by specifying `queue_trials=False` in " + "ray.tune.run_experiments().") + return True + + return False + + def debug_string(self): + """Returns a human readable message for printing to the console.""" + + if self._resources_initialized: + return "Resources requested: {}/{} CPUs, {}/{} GPUs".format( + self._committed_resources.cpu, self._avail_resources.cpu, + self._committed_resources.gpu, self._avail_resources.gpu) + else: + return "" + + def on_step_begin(self): + """Before step() called, update the available resources.""" + + self._update_avail_resources() + + def save(self, trial, storage=Checkpoint.DISK): + """Saves the trial's state to a checkpoint.""" + trial._checkpoint.storage = storage + if storage == Checkpoint.MEMORY: + trial._checkpoint.value = trial.runner.save_to_object.remote() + else: + trial._checkpoint.value = ray.get(trial.runner.save.remote()) + return trial._checkpoint.value + + def restore(self, trial, checkpoint=None): + """Restores training state from a given model checkpoint.""" + if checkpoint is None or checkpoint.value is None: + checkpoint = trial._checkpoint + if checkpoint is None or checkpoint.value is None: + return True + if trial.runner is None: + print("Unable to restore - no runner") + trial.status = Trial.ERROR + return False + try: + value = checkpoint.value + if checkpoint.storage == Checkpoint.MEMORY: + assert type(value) != Checkpoint, type(value) + ray.get(trial.runner.restore_from_object.remote(value)) + else: + ray.get(trial.runner.restore.remote(value)) + return True + except Exception: + print("Error restoring runner:", traceback.format_exc()) + trial.status = Trial.ERROR + return False diff --git a/python/ray/tune/schedulers/hyperband.py b/python/ray/tune/schedulers/hyperband.py index 8b30d97e5..9d0ffd777 100644 --- a/python/ray/tune/schedulers/hyperband.py +++ b/python/ray/tune/schedulers/hyperband.py @@ -185,7 +185,7 @@ class HyperBandScheduler(FIFOScheduler): raise Exception("Trial with unexpected status encountered") if bracket.continue_trial(t): if t.status == Trial.PAUSED: - t.unpause() + trial_runner.trial_executor.unpause_trial(t) elif t.status == Trial.RUNNING: action = TrialScheduler.CONTINUE return action diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index 460e0b99b..2e8d97112 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -7,8 +7,8 @@ import math import copy from ray.tune.error import TuneError -from ray.tune.trial import Trial -from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler +from ray.tune.trial import Trial, Checkpoint +from ray.tune.schedulers import FIFOScheduler, TrialScheduler from ray.tune.suggest.variant_generator import format_vars # Parameters are transferred from the top PBT_QUANTILE fraction of trials to @@ -187,7 +187,8 @@ class PopulationBasedTraining(FIFOScheduler): lower_quantile, upper_quantile = self._quantiles() if trial in upper_quantile: - state.last_checkpoint = trial.checkpoint(to_object_store=True) + state.last_checkpoint = trial_runner.trial_executor.save( + trial, Checkpoint.MEMORY) self._num_checkpoints += 1 else: state.last_checkpoint = None # not a top trial @@ -195,7 +196,7 @@ class PopulationBasedTraining(FIFOScheduler): if trial in lower_quantile: trial_to_clone = random.choice(upper_quantile) assert trial is not trial_to_clone - self._exploit(trial, trial_to_clone) + self._exploit(trial_runner.trial_executor, trial, trial_to_clone) for trial in trial_runner.get_trials(): if trial.status in [Trial.PENDING, Trial.PAUSED]: @@ -203,7 +204,7 @@ class PopulationBasedTraining(FIFOScheduler): return TrialScheduler.CONTINUE - def _exploit(self, trial, trial_to_clone): + def _exploit(self, trial_executor, trial, trial_to_clone): """Transfers perturbed state from trial_to_clone -> trial.""" trial_state = self._trial_state[trial] @@ -220,11 +221,12 @@ class PopulationBasedTraining(FIFOScheduler): trial_state.last_score)) # TODO(ekl) restarting the trial is expensive. We should implement a # lighter way reset() method that can alter the trial config. - trial.stop(stop_logger=False) + trial_executor.stop_trial(trial, stop_logger=False) trial.config = new_config trial.experiment_tag = make_experiment_tag( trial_state.orig_tag, new_config, self._hyperparam_mutations) - trial.start(new_state.last_checkpoint) + trial_executor.start_trial( + trial, Checkpoint.from_object(new_state.last_checkpoint)) self._num_perturbations += 1 # Transfer over the last perturbation time as well trial_state.last_perturbation_time = new_state.last_perturbation_time diff --git a/python/ray/tune/test/ray_trial_executor_test.py b/python/ray/tune/test/ray_trial_executor_test.py new file mode 100644 index 000000000..b17a28739 --- /dev/null +++ b/python/ray/tune/test/ray_trial_executor_test.py @@ -0,0 +1,60 @@ +# coding: utf-8 +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + +import ray +from ray.rllib import _register_all +from ray.tune.ray_trial_executor import RayTrialExecutor +from ray.tune.suggest import BasicVariantGenerator +from ray.tune.trial import Trial, Checkpoint + + +class RayTrialExecutorTest(unittest.TestCase): + def setUp(self): + self.trial_executor = RayTrialExecutor(queue_trials=False) + ray.init() + + def tearDown(self): + ray.shutdown() + _register_all() # re-register the evicted objects + + def _get_trials(self): + trials = self.generate_trials({ + "run": "PPO", + "config": { + "bar": { + "grid_search": [True, False] + }, + "foo": { + "grid_search": [1, 2, 3] + }, + }, + }, "grid_search") + return list(trials) + + def testStartStop(self): + trial = Trial("__fake") + self.trial_executor.start_trial(trial) + running = self.trial_executor.get_running_trials() + self.assertEqual(1, len(running)) + self.trial_executor.stop_trial(trial) + + def testSaveRestore(self): + trial = Trial("__fake") + self.trial_executor.start_trial(trial) + self.assertEqual(Trial.RUNNING, trial.status) + self.trial_executor.save(trial, Checkpoint.DISK) + self.trial_executor.restore(trial) + self.trial_executor.stop_trial(trial) + self.assertEqual(Trial.TERMINATED, trial.status) + + def generate_trials(self, spec, name): + suggester = BasicVariantGenerator({name: spec}) + return suggester.next_trials() + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index 02083ba10..abfe9e97e 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -11,6 +11,7 @@ from ray.rllib import _register_all from ray.tune import Trainable, TuneError from ray.tune import register_env, register_trainable, run_experiments +from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.schedulers import TrialScheduler, FIFOScheduler from ray.tune.registry import _global_registry, TRAINABLE_CLASS from ray.tune.result import DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE @@ -667,12 +668,13 @@ class TrialRunnerTest(unittest.TestCase): def testTrialStatus(self): ray.init() trial = Trial("__fake") + trial_executor = RayTrialExecutor() self.assertEqual(trial.status, Trial.PENDING) - trial.start() + trial_executor.start_trial(trial) self.assertEqual(trial.status, Trial.RUNNING) - trial.stop() + trial_executor.stop_trial(trial) self.assertEqual(trial.status, Trial.TERMINATED) - trial.stop(error=True) + trial_executor.stop_trial(trial, error=True) self.assertEqual(trial.status, Trial.ERROR) def testExperimentTagTruncation(self): @@ -681,6 +683,7 @@ class TrialRunnerTest(unittest.TestCase): def train(config, reporter): reporter(timesteps_total=1) + trial_executor = RayTrialExecutor() register_trainable("f1", train) experiments = { @@ -697,16 +700,17 @@ class TrialRunnerTest(unittest.TestCase): trial_generator = BasicVariantGenerator() trial_generator.add_configurations({name: spec}) for trial in trial_generator.next_trials(): - trial.start() + trial_executor.start_trial(trial) self.assertLessEqual(len(trial.logdir), 200) - trial.stop() + trial_executor.stop_trial(trial) def testTrialErrorOnStart(self): ray.init() + trial_executor = RayTrialExecutor() _global_registry.register(TRAINABLE_CLASS, "asdf", None) trial = Trial("asdf", resources=Resources(1, 0)) try: - trial.start() + trial_executor.start_trial(trial) except Exception as e: self.assertIn("a class", str(e)) @@ -901,8 +905,7 @@ class TrialRunnerTest(unittest.TestCase): runner.step() self.assertEqual(trials[0].status, Trial.RUNNING) self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1) - - path = trials[0].checkpoint() + path = runner.trial_executor.save(trials[0]) kwargs["restore_path"] = path runner.add_trial(Trial("__fake", **kwargs)) @@ -956,10 +959,10 @@ class TrialRunnerTest(unittest.TestCase): self.assertEqual(ray.get(trials[0].runner.set_info.remote(1)), 1) - trials[0].pause() + runner.trial_executor.pause_trial(trials[0]) self.assertEqual(trials[0].status, Trial.PAUSED) - trials[0].resume() + runner.trial_executor.resume_trial(trials[0]) self.assertEqual(trials[0].status, Trial.RUNNING) runner.step() @@ -969,6 +972,36 @@ class TrialRunnerTest(unittest.TestCase): runner.step() self.assertEqual(trials[0].status, Trial.TERMINATED) + def testStepHook(self): + ray.init(num_cpus=4, num_gpus=2) + runner = TrialRunner(BasicVariantGenerator()) + + def on_step_begin(self): + self._update_avail_resources() + cnt = self.pre_step if hasattr(self, 'pre_step') else 0 + setattr(self, 'pre_step', cnt + 1) + + def on_step_end(self): + cnt = self.pre_step if hasattr(self, 'post_step') else 0 + setattr(self, 'post_step', 1 + cnt) + + import types + runner.trial_executor.on_step_begin = types.MethodType( + on_step_begin, runner.trial_executor) + runner.trial_executor.on_step_end = types.MethodType( + on_step_end, runner.trial_executor) + + kwargs = { + "stopping_criterion": { + "training_iteration": 5 + }, + "resources": Resources(cpu=1, gpu=1), + } + runner.add_trial(Trial("__fake", **kwargs)) + runner.step() + self.assertEqual(runner.trial_executor.pre_step, 1) + self.assertEqual(runner.trial_executor.post_step, 1) + def testStopTrial(self): ray.init(num_cpus=4, num_gpus=2) runner = TrialRunner(BasicVariantGenerator()) diff --git a/python/ray/tune/test/trial_scheduler_test.py b/python/ray/tune/test/trial_scheduler_test.py index 36416cf03..c67219992 100644 --- a/python/ray/tune/test/trial_scheduler_test.py +++ b/python/ray/tune/test/trial_scheduler_test.py @@ -11,7 +11,8 @@ from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler, PopulationBasedTraining, MedianStoppingRule, TrialScheduler) from ray.tune.schedulers.pbt import explore -from ray.tune.trial import Trial, Resources +from ray.tune.trial import Trial, Resources, Checkpoint +from ray.tune.trial_executor import TrialExecutor from ray.rllib import _register_all _register_all() @@ -150,10 +151,29 @@ class EarlyStoppingSuite(unittest.TestCase): TrialScheduler.CONTINUE) +class _MockTrialExecutor(TrialExecutor): + def start_trial(self, trial, checkpoint_obj=None): + trial.logger_running = True + trial.restored_checkpoint = checkpoint_obj.value + trial.status = Trial.RUNNING + + def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True): + trial.status = Trial.ERROR if error else Trial.TERMINATED + if stop_logger: + trial.logger_running = False + + def restore(self, trial, checkpoint=None): + pass + + def save(self, trial, type=Checkpoint.DISK): + return trial.trainable_name + + class _MockTrialRunner(): def __init__(self, scheduler): self._scheduler_alg = scheduler self.trials = [] + self.trial_executor = _MockTrialExecutor() def process_action(self, trial, action): if action == TrialScheduler.CONTINUE: @@ -161,7 +181,7 @@ class _MockTrialRunner(): elif action == TrialScheduler.PAUSE: self._pause_trial(trial) elif action == TrialScheduler.STOP: - trial.stop() + self.trial_executor.stop_trial(trial) def stop_trial(self, trial): if trial.status in [Trial.ERROR, Trial.TERMINATED]: @@ -169,7 +189,6 @@ class _MockTrialRunner(): elif trial.status in [Trial.PENDING, Trial.PAUSED]: self._scheduler_alg.on_trial_remove(self, trial) else: - self._scheduler_alg.on_trial_complete(self, trial, result(100, 10)) def add_trial(self, trial): @@ -198,7 +217,7 @@ class HyperbandSuite(unittest.TestCase): _register_all() # re-register the evicted objects def schedulerSetup(self, num_trials): - """Setup a scheduler and Runner with max Iter = 9 + """Setup a scheduler and Runner with max Iter = 9. Bracketing is placed as follows: (5, 81); @@ -214,7 +233,7 @@ class HyperbandSuite(unittest.TestCase): return sched, runner def default_statistics(self): - """Default statistics for HyperBand""" + """Default statistics for HyperBand.""" sched = HyperBandScheduler() res = { str(s): { @@ -232,8 +251,8 @@ class HyperbandSuite(unittest.TestCase): return int(np.ceil(n / sched._eta)) def basicSetup(self): - """Setup and verify full band. - """ + """Setup and verify full band.""" + stats = self.default_statistics() sched, _ = self.schedulerSetup(stats["max_trials"]) @@ -299,6 +318,7 @@ class HyperbandSuite(unittest.TestCase): def testSuccessiveHalving(self): """Setup full band, then iterate through last bracket (n=81) to make sure successive halving is correct.""" + stats = self.default_statistics() sched, mock_runner = self.schedulerSetup(stats["max_trials"]) big_bracket = sched._state["bracket"] @@ -360,6 +380,7 @@ class HyperbandSuite(unittest.TestCase): def testTrialErrored(self): """If a trial errored, make sure successive halving still happens""" + stats = self.default_statistics() trial_count = stats[str(0)]["n"] + 3 sched, mock_runner = self.schedulerSetup(trial_count) @@ -510,7 +531,7 @@ class HyperbandSuite(unittest.TestCase): self.assertLess(current_length, 27) def testRemove(self): - """Test with 4: start 1, remove 1 pending, add 2, remove 1 pending""" + """Test with 4: start 1, remove 1 pending, add 2, remove 1 pending.""" sched, runner = self.schedulerSetup(4) trials = sorted(list(sched._trial_info), key=lambda t: t.trial_id) runner._launch_trial(trials[0]) @@ -542,17 +563,6 @@ class _MockTrial(Trial): self.restored_checkpoint = None self.resources = Resources(1, 0) - def checkpoint(self, to_object_store=False): - return self.trainable_name - - def start(self, checkpoint=None): - self.logger_running = True - self.restored_checkpoint = checkpoint - - def stop(self, stop_logger=False): - if stop_logger: - self.logger_running = False - class PopulationBasedTestingSuite(unittest.TestCase): def setUp(self): diff --git a/python/ray/tune/test/tune_server_test.py b/python/ray/tune/test/tune_server_test.py index c40832b48..a535b421b 100644 --- a/python/ray/tune/test/tune_server_test.py +++ b/python/ray/tune/test/tune_server_test.py @@ -88,7 +88,7 @@ class TuneServerSuite(unittest.TestCase): self.assertEqual(len(all_trials), 2) def testStopTrial(self): - """Check if Stop Trial works""" + """Check if Stop Trial works.""" runner, client = self.basicSetup() for i in range(2): runner.step() diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index ddcc29637..e764369f9 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -2,16 +2,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import tempfile from collections import namedtuple from datetime import datetime -import tempfile import time -import traceback import ray import os from ray.tune import TuneError -from ray.tune.logger import NoopLogger, UnifiedLogger, pretty_print +from ray.tune.logger import pretty_print, UnifiedLogger # NOTE(rkn): We import ray.tune.registry here instead of importing the names we # need because there are cyclic imports that may cause specific names to not # have been defined yet. See https://github.com/ray-project/ray/issues/1716. @@ -39,7 +38,9 @@ class Resources( launch additional Ray actors that use CPUs. extra_gpu (int): Extra GPUs to reserve in case the trial needs to launch additional Ray actors that use GPUs. + """ + __slots__ = () def __new__(cls, cpu, gpu, extra_cpu=0, extra_gpu=0): @@ -62,6 +63,30 @@ def has_trainable(trainable_name): ray.tune.registry.TRAINABLE_CLASS, trainable_name) +class Checkpoint(object): + """Describes a checkpoint of trial state. + + Checkpoint may be saved in different storage. + + Attributes: + storage (str): Storage type. + value (str): If storage==MEMORY,value is a Python object. + If storage==DISK,value is a path points to the checkpoint in disk. + """ + + MEMORY = "memory" + DISK = "disk" + + def __init__(self, storage, value): + self.storage = storage + self.value = value + + @staticmethod + def from_object(value=None): + """Creates a checkpoint from a Python object.""" + return Checkpoint(Checkpoint.MEMORY, value) + + class Trial(object): """A trial object holds the state for one model training run. @@ -110,16 +135,15 @@ class Trial(object): resources or self._get_trainable_cls().default_resource_request(self.config)) self.stopping_criterion = stopping_criterion or {} - self.checkpoint_freq = checkpoint_freq self.upload_dir = upload_dir self.verbose = True self.max_failures = max_failures # Local trial state that is updated during the run self.last_result = None - self._checkpoint_path = restore_path - self._checkpoint_obj = None - self.runner = None + self.checkpoint_freq = checkpoint_freq + self._checkpoint = Checkpoint( + storage=Checkpoint.DISK, value=restore_path) self.status = Trial.PENDING self.location = None self.logdir = None @@ -136,96 +160,34 @@ class Trial(object): def generate_id(cls): return binary_to_hex(random_string())[:8] - def start(self, checkpoint_obj=None): - """Starts this trial. + def init_logger(self): + """Init logger.""" - If an error is encountered when starting the trial, an exception will - be thrown. + if not self.result_logger: + if not os.path.exists(self.local_dir): + os.makedirs(self.local_dir) + self.logdir = tempfile.mkdtemp( + prefix="{}_{}".format( + str(self)[:MAX_LEN_IDENTIFIER], date_str()), + dir=self.local_dir) + self.result_logger = UnifiedLogger(self.config, self.logdir, + self.upload_dir) - Args: - checkpoint_obj (obj): Optional checkpoint to resume from. - """ + def close_logger(self): + """Close logger.""" - self._setup_runner() - if checkpoint_obj: - self.restore_from_obj(checkpoint_obj) - elif self._checkpoint_path: - self.restore_from_path(self._checkpoint_path) - elif self._checkpoint_obj: - self.restore_from_obj(self._checkpoint_obj) - - def stop(self, error=False, error_msg=None, stop_logger=True): - """Stops this trial. - - Stops this trial, releasing all allocating resources. If stopping the - trial fails, the run will be marked as terminated in error, but no - exception will be thrown. - - Args: - error (bool): Whether to mark this trial as terminated in error. - error_msg (str): Optional error message. - stop_logger (bool): Whether to shut down the trial logger. - """ - - if error: - self.status = Trial.ERROR - else: - self.status = Trial.TERMINATED - - try: - if error_msg and self.logdir: - self.num_failures += 1 - error_file = os.path.join(self.logdir, - "error_{}.txt".format(date_str())) - with open(error_file, "w") as f: - f.write(error_msg) - self.error_file = error_file - if self.runner: - stop_tasks = [] - stop_tasks.append(self.runner.stop.remote()) - stop_tasks.append(self.runner.__ray_terminate__.remote()) - # TODO(ekl) seems like wait hangs when killing actors - _, unfinished = ray.wait( - stop_tasks, num_returns=2, timeout=250) - except Exception: - print("Error stopping runner:", traceback.format_exc()) - self.status = Trial.ERROR - finally: - self.runner = None - - if stop_logger and self.result_logger: + if self.result_logger: self.result_logger.close() self.result_logger = None - def pause(self): - """We want to release resources (specifically GPUs) when pausing an - experiment. This results in a state similar to TERMINATED.""" - - assert self.status == Trial.RUNNING, self.status - try: - self.checkpoint(to_object_store=True) - self.stop(stop_logger=False) - self.status = Trial.PAUSED - except Exception: - print("Error pausing runner:", traceback.format_exc()) - self.status = Trial.ERROR - - def unpause(self): - """Sets PAUSED trial to pending to allow scheduler to start.""" - assert self.status == Trial.PAUSED, self.status - self.status = Trial.PENDING - - def resume(self): - """Resume PAUSED trials. This is a blocking call.""" - - assert self.status == Trial.PAUSED, self.status - self.start() - - def train_remote(self): - """Returns Ray future for one iteration of training.""" - - assert self.status == Trial.RUNNING, self.status - return self.runner.train.remote() + def write_error_log(self, error_msg): + if error_msg and self.logdir: + self.num_failures += 1 # may be moved to outer scope? + error_file = os.path.join(self.logdir, + "error_{}.txt".format(date_str())) + with open(error_file, "w") as f: + f.write(error_msg) + self.error_file = error_file def should_stop(self, result): """Whether the given result meets this trial's stopping criteria.""" @@ -294,57 +256,7 @@ class Trial(object): if self.error_file else "") def has_checkpoint(self): - return self._checkpoint_path is not None or \ - self._checkpoint_obj is not None - - def checkpoint(self, to_object_store=False): - """Checkpoints the state of this trial. - - Args: - to_object_store (bool): Whether to save to the Ray object store - (async) vs a path on local disk (sync). - """ - - obj = None - path = None - if to_object_store: - obj = self.runner.save_to_object.remote() - else: - path = ray.get(self.runner.save.remote()) - self._checkpoint_path = path - self._checkpoint_obj = obj - - if self.verbose: - print("Saved checkpoint for {} to {}".format(self, path or obj)) - return path or obj - - def restore_from_path(self, path): - """Restores runner state from specified path. - - Args: - path (str): A path where state will be restored. - """ - - if self.runner is None: - print("Unable to restore - no runner") - else: - try: - ray.get(self.runner.restore.remote(path)) - except Exception: - print("Error restoring runner:", traceback.format_exc()) - self.status = Trial.ERROR - - def restore_from_obj(self, obj): - """Restores runner state from the specified object.""" - - if self.runner is None: - print("Unable to restore - no runner") - else: - try: - ray.get(self.runner.restore_from_object.remote(obj)) - except Exception: - print("Error restoring runner:", traceback.format_exc()) - self.status = Trial.ERROR + return self._checkpoint.value is not None def update_last_result(self, result, terminate=False): if terminate: @@ -357,34 +269,6 @@ class Trial(object): self.last_result = result self.result_logger.on_result(self.last_result) - def _setup_runner(self): - self.status = Trial.RUNNING - cls = ray.remote( - num_cpus=self.resources.cpu, - num_gpus=self.resources.gpu)(self._get_trainable_cls()) - if not self.result_logger: - if not os.path.exists(self.local_dir): - os.makedirs(self.local_dir) - self.logdir = tempfile.mkdtemp( - prefix="{}_{}".format( - str(self)[:MAX_LEN_IDENTIFIER], date_str()), - dir=self.local_dir) - self.result_logger = UnifiedLogger(self.config, self.logdir, - self.upload_dir) - remote_logdir = self.logdir - - def logger_creator(config): - # Set the working dir in the remote process, for user file writes - if not os.path.exists(remote_logdir): - os.makedirs(remote_logdir) - os.chdir(remote_logdir) - return NoopLogger(config, remote_logdir) - - # Logging for trials is handled centrally by TrialRunner, so - # configure the remote runner to use a noop-logger. - self.runner = cls.remote( - config=self.config, logger_creator=logger_creator) - def _get_trainable_cls(self): return ray.tune.registry._global_registry.get( ray.tune.registry.TRAINABLE_CLASS, self.trainable_name) diff --git a/python/ray/tune/trial_executor.py b/python/ray/tune/trial_executor.py new file mode 100644 index 000000000..270869ac9 --- /dev/null +++ b/python/ray/tune/trial_executor.py @@ -0,0 +1,166 @@ +# coding: utf-8 +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import traceback + +from ray.tune.trial import Trial, Checkpoint + + +class TrialExecutor(object): + """Manages platform-specific details such as resource handling + and starting/stopping trials. + """ + + def __init__(self, queue_trials=False): + """Initializes a new TrialExecutor. + + Args: + queue_trials (bool): Whether to queue trials when the cluster does + not currently have enough resources to launch one. This should + be set to True when running on an autoscaling cluster to enable + automatic scale-up. + """ + self._queue_trials = queue_trials + + def has_resources(self, resources): + """Returns whether this runner has at least the specified resources.""" + raise NotImplementedError("Subclasses of TrialExecutor must provide " + "has_resources() method") + + def start_trial(self, trial, checkpoint=None): + """Starts the trial restoring from checkpoint if checkpoint != None. + + If an error is encountered when starting the trial, an exception will + be thrown. + + Args: + checkpoint(Checkpoint): A Python object or path storing the state + of trial. + """ + raise NotImplementedError("Subclasses of TrialExecutor must provide " + "start_trial() method") + + def stop_trial(self, trial, error=False, error_msg=None, stop_logger=True): + """Stops the trial. + + Stops this trial, releasing all allocating resources. + If stopping the trial fails, the run will be marked as terminated + in error, but no exception will be thrown. + + Args: + error (bool): Whether to mark this trial as terminated in error. + error_msg (str): Optional error message. + stop_logger (bool): Whether to shut down the trial logger. + """ + raise NotImplementedError("Subclasses of TrialExecutor must provide " + "stop_trial() method") + + def restart_trial(self, trial, error_msg=None): + """Restarts the trial. + + The state of the trial should restore from the last checkpoint. + + Args: + error_msg (str): Optional error message. + """ + try: + print("Attempting to recover trial state from last checkpoint") + self.stop_trial( + trial, error=True, error_msg=error_msg, stop_logger=False) + trial.result_logger.flush() + self.start_trial(trial) + except Exception: + error_msg = traceback.format_exc() + print("Error recovering trial from checkpoint, abort:", error_msg) + self.stop_trial(trial, error=True, error_msg=error_msg) + + def continue_training(self, trial): + """Continues the training of this trial.""" + pass + + def pause_trial(self, trial): + """Pauses the trial. + + We want to release resources (specifically GPUs) when pausing an + experiment. This results in PAUSED state that similar to TERMINATED. + """ + assert trial.status == Trial.RUNNING, trial.status + try: + self.save(trial, Checkpoint.MEMORY) + self.stop_trial(trial, stop_logger=False) + trial.status = Trial.PAUSED + except Exception: + print("Error pausing runner:", traceback.format_exc()) + trial.status = Trial.ERROR + + def unpause_trial(self, trial): + """Sets PAUSED trial to pending to allow scheduler to start.""" + assert trial.status == Trial.PAUSED, trial.status + trial.status = Trial.PENDING + + def resume_trial(self, trial): + """Resumes PAUSED trials. This is a blocking call.""" + + assert trial.status == Trial.PAUSED, trial.status + self.start_trial(trial) + + def get_running_trials(self): + """Returns all running trials.""" + raise NotImplementedError("Subclasses of TrialExecutor must provide " + "get_running_trials() method") + + def on_step_begin(self): + """A hook called before running one step of the trial event loop.""" + pass + + def on_step_end(self): + """A hook called after running one step of the trial event loop.""" + pass + + def fetch_one_result(self): + """Fetches one result from running trials. + + It's a blocking call waits until one result is ready. + + Return: + A tuple of (trial, result). If fetch result failed, + return (trial, None) other than raise Exception. + """ + raise NotImplementedError("Subclasses of TrialExecutor must provide " + "fetch_one_result() method") + + def debug_string(self): + """Returns a human readable message for printing to the console.""" + pass + + def restore(self, trial, checkpoint=None): + """Restores training state from a checkpoint. + + If checkpoint is None, try to restore from trial._checkpoint. + If restoring fails, the trial status will be set to ERROR. + + Args: + trial (Trial): Trial to be restored. + checkpoint (Checkpoint): Checkpoint to restore from. + + Return: + False if error occurred, otherwise return True. + """ + raise NotImplementedError("Subclasses of TrialExecutor must provide " + "restore() method") + + def save(self, trial, storage=Checkpoint.DISK): + """Saves training state of this trial to a checkpoint. + + Args: + trial (Trial): The state of this trial to be saved. + storage (str): Where to store the checkpoint. Defaults to DISK. + + Return: + A Python object if storage==Checkpoint.MEMORY otherwise + a path to the checkpoint. + """ + raise NotImplementedError("Subclasses of TrialExecutor must provide " + "save() method") diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 0ac729ef5..c69b11835 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -4,15 +4,15 @@ from __future__ import print_function import collections import os -import ray import time import traceback from ray.tune import TuneError +from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.result import TIME_THIS_ITER_S -from ray.tune.web_server import TuneServer -from ray.tune.trial import Trial, Resources +from ray.tune.trial import Trial from ray.tune.schedulers import FIFOScheduler, TrialScheduler +from ray.tune.web_server import TuneServer MAX_DEBUG_TRIALS = 20 @@ -45,7 +45,8 @@ class TrialRunner(object): launch_web_server=False, server_port=TuneServer.DEFAULT_PORT, verbose=True, - queue_trials=False): + queue_trials=False, + trial_executor=None): """Initializes a new TrialRunner. Args: @@ -60,14 +61,13 @@ class TrialRunner(object): not currently have enough resources to launch one. This should be set to True when running on an autoscaling cluster to enable automatic scale-up. + trial_executor (TrialExecutor): Defaults to RayTrialExecutor. """ self._search_alg = search_alg self._scheduler_alg = scheduler or FIFOScheduler() self._trials = [] - self._running = {} - self._avail_resources = Resources(cpu=0, gpu=0) - self._committed_resources = Resources(cpu=0, gpu=0) - self._resources_initialized = False + self.trial_executor = trial_executor or \ + RayTrialExecutor(queue_trials=queue_trials) # For debugging, it may be useful to halt trials after some time has # elapsed. TODO(ekl) consider exposing this in the API. @@ -98,10 +98,11 @@ class TrialRunner(object): Callers should typically run this method repeatedly in a loop. They may inspect or modify the runner's state in between calls to step(). """ + self.trial_executor.on_step_begin() next_trial = self._get_next_trial() if next_trial is not None: - self._launch_trial(next_trial) - elif self._running: + self.trial_executor.start_trial(next_trial) + elif self.trial_executor.get_running_trials(): self._process_events() else: for trial in self._trials: @@ -109,13 +110,13 @@ class TrialRunner(object): if not self.has_resources(trial.resources): raise TuneError( ("Insufficient cluster resources to launch trial: " - "trial requested {} but the cluster only has {} " - "available. Pass `queue_trials=True` in " + "trial requested {} but the cluster summary: {} " + "Pass `queue_trials=True` in " "ray.tune.run_experiments() or on the command " "line to queue trials until the cluster scales " "up. {}").format( trial.resources.summary_string(), - self._avail_resources.summary_string(), + self.trial_executor.debug_string(), trial._get_trainable_cls().resource_help( trial.config))) elif trial.status == Trial.PAUSED: @@ -129,6 +130,7 @@ class TrialRunner(object): if self.is_finished(): self._server.shutdown() + self.trial_executor.on_step_end() def get_trial(self, tid): trial = [t for t in self._trials if t.trial_id == tid] @@ -189,41 +191,12 @@ class TrialRunner(object): def _debug_messages(self): messages = ["== Status =="] messages.append(self._scheduler_alg.debug_string()) - if self._resources_initialized: - messages.append( - "Resources requested: {}/{} CPUs, {}/{} GPUs".format( - self._committed_resources.cpu, self._avail_resources.cpu, - self._committed_resources.gpu, self._avail_resources.gpu)) + messages.append(self.trial_executor.debug_string()) return messages def has_resources(self, resources): """Returns whether this runner has at least the specified resources.""" - - cpu_avail = self._avail_resources.cpu - self._committed_resources.cpu - gpu_avail = self._avail_resources.gpu - self._committed_resources.gpu - - have_space = (resources.cpu_total() <= cpu_avail - and resources.gpu_total() <= gpu_avail) - - if have_space: - return True - - can_overcommit = self._queue_trials - - if ((resources.cpu_total() > 0 and cpu_avail <= 0) - or (resources.gpu_total() > 0 and gpu_avail <= 0)): - can_overcommit = False # requested resource is already saturated - - if can_overcommit: - print("WARNING:tune:allowing trial to start even though the " - "cluster does not have enough free resources. Trial actors " - "may appear to hang until enough resources are added to the " - "cluster (e.g., via autoscaling). You can disable this " - "behavior by specifying `queue_trials=False` in " - "ray.tune.run_experiments().") - return True - - return False + return self.trial_executor.has_resources(resources) def _get_next_trial(self): """Replenishes queue. @@ -231,38 +204,17 @@ class TrialRunner(object): Blocks if all trials queued have finished, but search algorithm is still not finished. """ - self._update_avail_resources() trials_done = all(trial.is_finished() for trial in self._trials) wait_for_trial = trials_done and not self._search_alg.is_finished() self._update_trial_queue(blocking=wait_for_trial) trial = self._scheduler_alg.choose_trial_to_run(self) return trial - def _launch_trial(self, trial): - self._commit_resources(trial.resources) - try: - trial.start() - self._running[trial.train_remote()] = trial - except Exception: - error_msg = traceback.format_exc() - print("Error starting runner, retrying:", error_msg) - time.sleep(2) - trial.stop(error=True, error_msg=error_msg) - try: - trial.start() - self._running[trial.train_remote()] = trial - except Exception: - error_msg = traceback.format_exc() - print("Error starting runner, abort:", error_msg) - trial.stop(error=True, error_msg=error_msg) - # note that we don't return the resources, since they may - # have been lost - def _process_events(self): - [result_id], _ = ray.wait(list(self._running)) - trial = self._running.pop(result_id) + trial, result = self.trial_executor.fetch_one_result() try: - result = ray.get(result_id) + if result is None: + raise ValueError("fetch_one_result failed") self._total_time += result[TIME_THIS_ITER_S] if trial.should_stop(result): @@ -284,12 +236,12 @@ class TrialRunner(object): if decision == TrialScheduler.CONTINUE: if trial.should_checkpoint(): # TODO(rliaw): This is a blocking call - trial.checkpoint() - self._running[trial.train_remote()] = trial + self.trial_executor.save(trial) + self.trial_executor.continue_training(trial) elif decision == TrialScheduler.PAUSE: - self._pause_trial(trial) + self.trial_executor.pause_trial(trial) elif decision == TrialScheduler.STOP: - self._stop_trial(trial) + self.trial_executor.stop_trial(trial) else: assert False, "Invalid scheduling decision: {}".format( decision) @@ -304,30 +256,28 @@ class TrialRunner(object): self._scheduler_alg.on_trial_error(self, trial) self._search_alg.on_trial_complete( trial.trial_id, error=True) - self._stop_trial(trial, error=True, error_msg=error_msg) + self.trial_executor.stop_trial(trial, True, error_msg) def _try_recover(self, trial, error_msg): try: print("Attempting to recover trial state from last checkpoint") - trial.stop(error=True, error_msg=error_msg, stop_logger=False) - trial.result_logger.flush() # make sure checkpoint is synced - trial.start() - self._running[trial.train_remote()] = trial + self.trial_executor.restart_trial(trial, error_msg) except Exception: error_msg = traceback.format_exc() print("Error recovering trial from checkpoint, abort:", error_msg) - self._stop_trial(trial, error=True, error_msg=error_msg) + self.trial_executor.stop_trial(trial, True, error_msg=error_msg) def _update_trial_queue(self, blocking=False, timeout=600): """Adds next trials to queue if possible. Note that the timeout is currently unexposed to the user. - Arguments: + Args: blocking (bool): Blocks until either a trial is available or the Runner finishes (i.e., timeout or search algorithm finishes). - timeout (int): Seconds before blocking times out.""" + timeout (int): Seconds before blocking times out. + """ trials = self._search_alg.next_trials() if blocking and not trials: start = time.time() @@ -340,18 +290,6 @@ class TrialRunner(object): for trial in trials: self.add_trial(trial) - def _commit_resources(self, resources): - self._committed_resources = Resources( - self._committed_resources.cpu + resources.cpu_total(), - self._committed_resources.gpu + resources.gpu_total()) - - def _return_resources(self, resources): - self._committed_resources = Resources( - self._committed_resources.cpu - resources.cpu_total(), - self._committed_resources.gpu - resources.gpu_total()) - assert self._committed_resources.cpu >= 0 - assert self._committed_resources.gpu >= 0 - def request_stop_trial(self, trial): self._stop_queue.append(trial) @@ -379,13 +317,10 @@ class TrialRunner(object): self._search_alg.on_trial_complete( trial.trial_id, early_terminated=True) elif trial.status is Trial.RUNNING: - # NOTE: There should only be one... - result_id = [ - rid for rid, t in self._running.items() if t is trial - ][0] - self._running.pop(result_id) try: - result = ray.get(result_id) + _, result = self.trial_executor.fetch_one_result() + if result is None: + raise ValueError("fetch_one_result failed") trial.update_last_result(result, terminate=True) self._scheduler_alg.on_trial_complete(self, trial, result) self._search_alg.on_trial_complete( @@ -397,35 +332,4 @@ class TrialRunner(object): self._search_alg.on_trial_complete(trial.trial_id, error=True) error = True - self._stop_trial(trial, error=error, error_msg=error_msg) - - def _stop_trial(self, trial, error=False, error_msg=None): - """Only returns resources if resources allocated.""" - prior_status = trial.status - trial.stop(error=error, error_msg=error_msg) - if prior_status == Trial.RUNNING: - self._return_resources(trial.resources) - - def _pause_trial(self, trial): - """Only returns resources if resources allocated.""" - prior_status = trial.status - trial.pause() - if prior_status == Trial.RUNNING: - self._return_resources(trial.resources) - - def _update_avail_resources(self): - clients = ray.global_state.client_table() - if ray.worker.global_worker.use_raylet: - # TODO(rliaw): Remove once raylet flag is swapped - num_cpus = sum(cl['Resources']['CPU'] for cl in clients) - num_gpus = sum(cl['Resources'].get('GPU', 0) for cl in clients) - else: - local_schedulers = [ - entry for client in clients.values() for entry in client - if (entry['ClientType'] == 'local_scheduler' - and not entry['Deleted']) - ] - num_cpus = sum(ls['CPU'] for ls in local_schedulers) - num_gpus = sum(ls.get('GPU', 0) for ls in local_schedulers) - self._avail_resources = Resources(int(num_cpus), int(num_gpus)) - self._resources_initialized = True + self.trial_executor.stop_trial(trial, error=error, error_msg=error_msg) diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 3d122091e..f783f98a2 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -35,7 +35,8 @@ def run_experiments(experiments=None, with_server=False, server_port=TuneServer.DEFAULT_PORT, verbose=True, - queue_trials=False): + queue_trials=False, + trial_executor=None): """Runs and blocks until all trials finish. Args: @@ -54,6 +55,7 @@ def run_experiments(experiments=None, not currently have enough resources to launch one. This should be set to True when running on an autoscaling cluster to enable automatic scale-up. + trial_executor (TrialExecutor): Manage the execution of trials. Examples: >>> experiment_spec = Experiment("experiment", my_func) @@ -73,7 +75,9 @@ def run_experiments(experiments=None, Returns: List of Trial objects, holding data for each executed trial. + """ + if scheduler is None: scheduler = FIFOScheduler() @@ -88,7 +92,8 @@ def run_experiments(experiments=None, launch_web_server=with_server, server_port=server_port, verbose=verbose, - queue_trials=queue_trials) + queue_trials=queue_trials, + trial_executor=trial_executor) print(runner.debug_string(max_debug=99999))