diff --git a/doc/source/tune/user-guide.rst b/doc/source/tune/user-guide.rst index 8458fcabc..111f166e6 100644 --- a/doc/source/tune/user-guide.rst +++ b/doc/source/tune/user-guide.rst @@ -608,6 +608,8 @@ These are the environment variables Ray Tune currently considers: or a search algorithm, Tune will error if the metric was not reported in the result. Setting this environment variable to ``1`` will disable this check. +* **TUNE_FUNCTION_THREAD_TIMEOUT_S**: Time in seconds the function API waits + for threads to finish after instructing them to complete. Defaults to ``2``. * **TUNE_GLOBAL_CHECKPOINT_S**: Time in seconds that limits how often Tune's experiment state is checkpointed. If not set this will default to ``10``. * **TUNE_MAX_LEN_IDENTIFIER**: Maximum length of trial subdirectory names (those diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index b55c4b6bf..5b0c59abf 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -1,5 +1,6 @@ import logging import os +import sys import time import inspect import shutil @@ -120,12 +121,21 @@ class StatusReporter: def __init__(self, result_queue, continue_semaphore, + end_event, trial_name=None, trial_id=None, logdir=None): self._queue = result_queue self._last_report_time = None self._continue_semaphore = continue_semaphore + self._end_event = end_event + self._trial_name = trial_name + self._trial_id = trial_id + self._logdir = logdir + self._last_checkpoint = None + self._fresh_checkpoint = False + + def reset(self, trial_name=None, trial_id=None, logdir=None): self._trial_name = trial_name self._trial_id = trial_id self._logdir = logdir @@ -171,6 +181,11 @@ class StatusReporter: # resume training. self._continue_semaphore.acquire() + # If the trial should be terminated, exit gracefully. + if self._end_event.is_set(): + self._end_event.clear() + sys.exit(0) + def make_checkpoint_dir(self, step): checkpoint_dir = TrainableUtil.make_checkpoint_dir( self.logdir, index=step) @@ -264,6 +279,10 @@ class FunctionRunner(Trainable): # and to generate the next result. self._continue_semaphore = threading.Semaphore(0) + # Event for notifying the reporter to exit gracefully, terminating + # the thread. + self._end_event = threading.Event() + # Queue for passing results between threads self._results_queue = queue.Queue(1) @@ -275,6 +294,7 @@ class FunctionRunner(Trainable): self._status_reporter = StatusReporter( self._results_queue, self._continue_semaphore, + self._end_event, trial_name=self.trial_name, trial_id=self.trial_id, logdir=self.logdir) @@ -363,7 +383,7 @@ class FunctionRunner(Trainable): # This keyword appears if the train_func using the Function API # finishes without "done=True". This duplicates the last result, but # the TrialRunner will not log this result again. - if "__duplicate__" in result: + if RESULT_DUPLICATE in result: new_result = self._last_result.copy() new_result.update(result) result = new_result @@ -441,6 +461,11 @@ class FunctionRunner(Trainable): self.restore(checkpoint_path) def cleanup(self): + # Trigger thread termination + self._end_event.set() + self._continue_semaphore.release() + # Do not wait for thread termination here. + # If everything stayed in synch properly, this should never happen. if not self._results_queue.empty(): logger.warning( @@ -457,6 +482,29 @@ class FunctionRunner(Trainable): logger.debug("Clearing temporary checkpoint: %s", self.temp_checkpoint_dir) + def reset_config(self, new_config): + if self._runner and self._runner.is_alive(): + self._end_event.set() + self._continue_semaphore.release() + # Wait for thread termination so it is save to re-use the same + # actor. + thread_timeout = int( + os.environ.get("TUNE_FUNCTION_THREAD_TIMEOUT_S", 2)) + self._runner.join(timeout=thread_timeout) + if self._runner.is_alive(): + # Did not finish within timeout, reset unsuccessful. + return False + + self._runner = None + self._last_result = {} + + self._status_reporter.reset( + trial_name=self.trial_name, + trial_id=self.trial_id, + logdir=self.logdir) + + return True + def _report_thread_runner_error(self, block=False): try: err_tb_str = self._error_queue.get( diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 4aaf19e03..bbe2def83 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -14,6 +14,7 @@ from ray import ray_constants from ray.resource_spec import ResourceSpec from ray.tune.durable_trainable import DurableTrainable from ray.tune.error import AbortTrialExecution, TuneError +from ray.tune.function_runner import FunctionRunner from ray.tune.logger import NoopLogger from ray.tune.result import TRIAL_INFO, STDOUT_FILE, STDERR_FILE from ray.tune.resources import Resources @@ -276,13 +277,13 @@ class RayTrialExecutor(TrialExecutor): """ prior_status = trial.status if runner is None: - # TODO: Right now, we only support reuse if there has been - # previously instantiated state on the worker. However, - # we should consider the case where function evaluations - # can be very fast - thereby extending the need to support - # reuse to cases where there has not been previously - # instantiated state before. - reuse_allowed = checkpoint is not None or trial.has_checkpoint() + # We reuse actors when there is previously instantiated state on + # the actor. Function API calls are also supported when there is + # no checkpoint to continue from. + # TODO: Check preconditions - why is previous state needed? + reuse_allowed = checkpoint is not None or trial.has_checkpoint() \ + or issubclass(trial.get_trainable_cls(), + FunctionRunner) runner = self._setup_remote_runner(trial, reuse_allowed) trial.set_runner(runner) self.restore(trial, checkpoint) diff --git a/python/ray/tune/tests/test_actor_reuse.py b/python/ray/tune/tests/test_actor_reuse.py index 1d06628ed..8dfc2e9a2 100644 --- a/python/ray/tune/tests/test_actor_reuse.py +++ b/python/ray/tune/tests/test_actor_reuse.py @@ -1,11 +1,14 @@ import os +import pickle import unittest import sys +from collections import defaultdict import ray from ray import tune, logger from ray.tune import Trainable, run_experiments, register_trainable from ray.tune.error import TuneError +from ray.tune.function_runner import wrap_function from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler @@ -30,6 +33,7 @@ def create_resettable_class(): logger.info("LOG_STDERR: {}".format(self.msg)) return { + "id": self.config["id"], "num_resets": self.num_resets, "done": self.iter > 1, "iter": self.iter @@ -51,6 +55,35 @@ def create_resettable_class(): return MyResettableClass +def create_resettable_function(num_resets: defaultdict): + def trainable(config, checkpoint_dir=None): + if checkpoint_dir: + with open(os.path.join(checkpoint_dir, "chkpt"), "rb") as fp: + step = pickle.load(fp) + else: + step = 0 + + while step < 2: + step += 1 + with tune.checkpoint_dir(step) as checkpoint_dir: + with open(os.path.join(checkpoint_dir, "chkpt"), "wb") as fp: + pickle.dump(step, fp) + tune.report(**{ + "done": step >= 2, + "iter": step, + "id": config["id"] + }) + + trainable = wrap_function(trainable) + + class ResetCountTrainable(trainable): + def reset_config(self, new_config): + num_resets[self.trial_id] += 1 + return super().reset_config(new_config) + + return ResetCountTrainable + + class ActorReuseTest(unittest.TestCase): def setUp(self): ray.init(num_cpus=1, num_gpus=0) @@ -58,38 +91,56 @@ class ActorReuseTest(unittest.TestCase): def tearDown(self): ray.shutdown() - def testTrialReuseDisabled(self): + def _run_trials_with_frequent_pauses(self, trainable, reuse=False): trials = run_experiments( { "foo": { - "run": create_resettable_class(), - "num_samples": 4, - "config": {}, + "run": trainable, + "num_samples": 1, + "config": { + "id": tune.grid_search([0, 1, 2, 3]) + }, } }, - reuse_actors=False, + reuse_actors=reuse, scheduler=FrequentPausesScheduler(), verbose=0) + return trials + + def testTrialReuseDisabled(self): + trials = self._run_trials_with_frequent_pauses( + create_resettable_class(), reuse=False) + self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3]) self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2]) self.assertEqual([t.last_result["num_resets"] for t in trials], [0, 0, 0, 0]) + def testTrialReuseDisabledFunction(self): + num_resets = defaultdict(lambda: 0) + trials = self._run_trials_with_frequent_pauses( + create_resettable_function(num_resets), reuse=False) + self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3]) + self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2]) + self.assertEqual([num_resets[t.trial_id] for t in trials], + [0, 0, 0, 0]) + def testTrialReuseEnabled(self): - trials = run_experiments( - { - "foo": { - "run": create_resettable_class(), - "num_samples": 4, - "config": {}, - } - }, - reuse_actors=True, - scheduler=FrequentPausesScheduler(), - verbose=0) + trials = self._run_trials_with_frequent_pauses( + create_resettable_class(), reuse=True) + self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3]) self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2]) self.assertEqual([t.last_result["num_resets"] for t in trials], [1, 2, 3, 4]) + def testTrialReuseEnabledFunction(self): + num_resets = defaultdict(lambda: 0) + trials = self._run_trials_with_frequent_pauses( + create_resettable_function(num_resets), reuse=True) + self.assertEqual([t.last_result["id"] for t in trials], [0, 1, 2, 3]) + self.assertEqual([t.last_result["iter"] for t in trials], [2, 2, 2, 2]) + self.assertEqual([num_resets[t.trial_id] for t in trials], + [0, 0, 0, 0]) + def testReuseEnabledError(self): def run(): run_experiments( @@ -97,8 +148,9 @@ class ActorReuseTest(unittest.TestCase): "foo": { "run": create_resettable_class(), "max_failures": 1, - "num_samples": 4, + "num_samples": 1, "config": { + "id": tune.grid_search([0, 1, 2, 3]), "fake_reset_not_supported": True }, } @@ -115,7 +167,8 @@ class ActorReuseTest(unittest.TestCase): [trial1, trial2] = tune.run( "foo2", config={ - "message": tune.grid_search(["First", "Second"]) + "message": tune.grid_search(["First", "Second"]), + "id": -1 }, log_to_file=True, scheduler=FrequentPausesScheduler(), diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index e8aff4bfb..3b17267a1 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -552,7 +552,22 @@ class Trainable: self._close_logfiles() self._open_logfiles(stdout_file, stderr_file) - return self.reset_config(new_config) + success = self.reset_config(new_config) + if not success: + return False + + # Reset attributes. Will be overwritten by `restore` if a checkpoint + # is provided. + self._iteration = 0 + self._time_total = 0.0 + self._timesteps_total = None + self._episodes_total = None + self._time_since_restore = 0.0 + self._timesteps_since_restore = 0 + self._iterations_since_restore = 0 + self._restored = False + + return True def reset_config(self, new_config): """Resets configuration without restarting the trial. diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 2334b668f..ac3a47fe8 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -724,7 +724,6 @@ class TrialRunner: """ try: result = self.trial_executor.fetch_result(trial) - is_duplicate = RESULT_DUPLICATE in result force_checkpoint = result.get(SHOULD_CHECKPOINT, False) # TrialScheduler and SearchAlgorithm still receive a diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 30dbde888..3d189f409 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -1,5 +1,6 @@ import logging import sys +import time from ray.tune.error import TuneError from ray.tune.experiment import convert_to_experiment_list, Experiment @@ -255,6 +256,7 @@ def run( Raises: TuneError: Any trials failed and `raise_on_failed_trial` is True. """ + all_start = time.time() if global_checkpoint_period: raise ValueError("global_checkpoint_period is deprecated. Set env var " "'TUNE_GLOBAL_CHECKPOINT_S' instead.") @@ -404,10 +406,12 @@ def run( "`Trainable.default_resource_request` if using the " "Trainable API.") + tune_start = time.time() while not runner.is_finished(): runner.step() if verbose: _report_progress(runner, progress_reporter) + tune_taken = time.time() - tune_start try: runner.checkpoint(force=True) @@ -431,6 +435,10 @@ def run( else: logger.error("Trials did not complete: %s", incomplete_trials) + all_taken = time.time() - all_start + logger.info(f"Total run time: {all_taken:.2f} seconds " + f"({tune_taken:.2f} seconds for the tuning loop).") + trials = runner.get_trials() return ExperimentAnalysis( runner.checkpoint_file,