From 7c3274e65b75d4533b744639e8fb62fb05bf7f75 Mon Sep 17 00:00:00 2001 From: gehring Date: Mon, 18 Mar 2019 22:14:26 -0400 Subject: [PATCH] [tune] Make the logging of the function API consistent and predictable (#4011) ## What do these changes do? This is a re-implementation of the `FunctionRunner` which enforces some synchronicity between the thread running the training function and the thread running the Trainable which logs results. The main purpose is to make logging consistent across APIs in anticipation of a new function API which will be generator based (through `yield` statements). Without these changes, it will be impossible for the (possibly soon to be) deprecated reporter based API to behave the same as the generator based API. This new implementation provides additional guarantees to prevent results from being dropped. This makes the logging behavior more intuitive and consistent with how results are handled in custom subclasses of Trainable. New guarantees for the tune function API: - Every reported result, i.e., `reporter(**kwargs)` calls, is forwarded to the appropriate loggers instead of being dropped if not enough time has elapsed since the last results. - The wrapped function only runs if the `FunctionRunner` expects a result, i.e., when `FunctionRunner._train()` has been called. This removes the possibility that a result will be generated by the function but never logged. - The wrapped function is not called until the first `_train()` call. Currently, the wrapped function is started during the setup phase which could result in dropped results if the trial is cancelled between `_setup()` and the first `_train()` call. - Exceptions raised by the wrapped function won't be propagated until all results are logged to prevent dropped results. - The thread running the wrapped function is explicitly stopped when the `FunctionRunner` is stopped with `_stop()`. - If the wrapped function terminates without reporting `done=True`, a duplicate result with `{"done": True}`, is reported to explicitly terminate the trial, and components will be notified with a duplicate of the last reported result, but this duplicate will not be logged. ## Related issue number Closes #3956. #3949 #3834 --- python/ray/tune/function_runner.py | 253 ++++++++++++----- python/ray/tune/registry.py | 3 +- python/ray/tune/tests/test_trial_runner.py | 298 +++++++++++++-------- python/ray/tune/trainable.py | 10 - python/ray/tune/trial_runner.py | 12 +- 5 files changed, 385 insertions(+), 191 deletions(-) diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index bae885970..7dbf02ef8 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -3,15 +3,24 @@ from __future__ import division from __future__ import print_function import logging +import sys import time import threading +from six.moves import queue from ray.tune import TuneError from ray.tune.trainable import Trainable -from ray.tune.result import TIMESTEPS_TOTAL, TRAINING_ITERATION +from ray.tune.result import TIME_THIS_ITER_S logger = logging.getLogger(__name__) +# Time between FunctionRunner checks when fetching +# new results after signaling the reporter to continue +RESULT_FETCH_TIMEOUT = 0.2 + +ERROR_REPORT_TIMEOUT = 10 +ERROR_FETCH_TIMEOUT = 1 + class StatusReporter(object): """Object passed into your function that you can report status through. @@ -19,16 +28,13 @@ class StatusReporter(object): Example: >>> def trainable_function(config, reporter): >>> assert isinstance(reporter, StatusReporter) - >>> reporter(timesteps_total=1) + >>> reporter(timesteps_this_iter=1) """ - def __init__(self): - self._latest_result = None - self._last_result = None - self._lock = threading.Lock() - self._error = None - self._done = False - self._iteration = 0 + def __init__(self, result_queue, continue_semaphore): + self._queue = result_queue + self._last_report_time = None + self._continue_semaphore = continue_semaphore def __call__(self, **kwargs): """Report updated training status. @@ -41,82 +47,101 @@ class StatusReporter(object): Example: >>> reporter(mean_accuracy=1, training_iteration=4) >>> reporter(mean_accuracy=1, training_iteration=4, done=True) + + Raises: + StopIteration: A StopIteration exception is raised if the trial has + been signaled to stop. """ - with self._lock: - self._latest_result = self._last_result = kwargs.copy() - self._iteration += 1 + assert self._last_report_time is not None, ( + "StatusReporter._start() must be called before the first " + "report __call__ is made to ensure correct runtime metrics.") - def _get_and_clear_status(self): - if self._error: - raise TuneError("Error running trial: " + str(self._error)) - if self._done and not self._latest_result: - if not self._last_result: - raise TuneError("Trial finished without reporting result!") - logger.warning("Trial detected as completed; re-reporting " - "last result. To avoid this, include done=True " - "upon the last reporter call.") - self._last_result.update(done=True) - self._last_result.setdefault(TRAINING_ITERATION, self._iteration) - return self._last_result - with self._lock: - res = self._latest_result - self._latest_result = None - if res: - res.setdefault(TRAINING_ITERATION, self._iteration) - return res + # time per iteration is recorded directly in the reporter to ensure + # any delays in logging results aren't counted + report_time = time.time() + if TIME_THIS_ITER_S not in kwargs: + kwargs[TIME_THIS_ITER_S] = report_time - self._last_report_time + self._last_report_time = report_time - def _stop(self): - self._error = "Agent stopped" + # add results to a thread-safe queue + self._queue.put(kwargs.copy(), block=True) + # This blocks until notification from the FunctionRunner that the last + # result has been returned to Tune and that the function is safe to + # resume training. + self._continue_semaphore.acquire() -DEFAULT_CONFIG = { - # batch results to at least this granularity - "script_min_iter_time_s": 1, -} + def _start(self): + self._last_report_time = time.time() class _RunnerThread(threading.Thread): """Supervisor thread that runs your script.""" - def __init__(self, entrypoint, config, status_reporter): - self._entrypoint = entrypoint - self._entrypoint_args = [config, status_reporter] - self._status_reporter = status_reporter + def __init__(self, entrypoint, error_queue): threading.Thread.__init__(self) + self._entrypoint = entrypoint + self._error_queue = error_queue self.daemon = True def run(self): try: - self._entrypoint(*self._entrypoint_args) + self._entrypoint() + except StopIteration: + logger.debug( + ("Thread runner raised StopIteration. Interperting it as a " + "signal to terminate the thread without error.")) except Exception as e: - self._status_reporter._error = e logger.exception("Runner Thread raised error.") + try: + # report the error but avoid indefinite blocking which would + # prevent the exception from being propagated in the unlikely + # case that something went terribly wrong + err_type, err_value, err_tb = sys.exc_info() + err_tb = err_tb.format_exc() + self._error_queue.put( + (err_type, err_value, err_tb), + block=True, + timeout=ERROR_REPORT_TIMEOUT) + except queue.Full: + logger.critical( + ("Runner Thread was unable to report error to main " + "function runner thread. This means a previous error " + "was not processed. This should never happen.")) raise e - finally: - self._status_reporter._done = True class FunctionRunner(Trainable): - """Trainable that runs a user function returning training results. + """Trainable that runs a user function reporting results. This mode of execution does not support checkpoint/restore.""" _name = "func" - _default_config = DEFAULT_CONFIG def _setup(self, config): - entrypoint = self._trainable_func() - self._status_reporter = StatusReporter() - scrubbed_config = config.copy() - for k in self._default_config: - if k in scrubbed_config: - del scrubbed_config[k] - self._runner = _RunnerThread(entrypoint, scrubbed_config, - self._status_reporter) - self._start_time = time.time() - self._last_reported_timestep = 0 - self._runner.start() + # Semaphore for notifying the reporter to continue with the computation + # and to generate the next result. + self._continue_semaphore = threading.Semaphore(0) + + # Queue for passing results between threads + self._results_queue = queue.Queue(1) + + # Queue for passing errors back from the thread runner. The error queue + # has a max size of one to prevent stacking error and force error + # reporting to block until finished. + self._error_queue = queue.Queue(1) + + self._status_reporter = StatusReporter(self._results_queue, + self._continue_semaphore) + self._last_result = {} + config = config.copy() + + def entrypoint(): + return self._trainable_func(config, self._status_reporter) + + # the runner thread is not started until the first call to _train + self._runner = _RunnerThread(entrypoint, self._error_queue) def _trainable_func(self): """Subclasses can override this to set the trainable func.""" @@ -124,22 +149,108 @@ class FunctionRunner(Trainable): raise NotImplementedError def _train(self): - time.sleep( - self.config.get("script_min_iter_time_s", - self._default_config["script_min_iter_time_s"])) - result = self._status_reporter._get_and_clear_status() - while result is None: - time.sleep(1) - result = self._status_reporter._get_and_clear_status() + """Implements train() for a Function API. - curr_ts_total = result.get(TIMESTEPS_TOTAL) - if curr_ts_total is not None: - result.update( - timesteps_this_iter=( - curr_ts_total - self._last_reported_timestep)) - self._last_reported_timestep = curr_ts_total + If the RunnerThread finishes without reporting "done", + Tune will automatically provide a magic keyword __duplicate__ + along with a result with "done=True". The TrialRunner will handle the + result accordingly (see tune/trial_runner.py). + """ + if self._runner.is_alive(): + # if started and alive, inform the reporter to continue and + # generate the next result + self._continue_semaphore.release() + else: + # if not alive, try to start + self._status_reporter._start() + try: + self._runner.start() + except RuntimeError: + # If this is reached, it means the thread was started and is + # now done or has raised an exception. + pass + result = None + while result is None and self._runner.is_alive(): + # fetch the next produced result + try: + result = self._results_queue.get( + block=True, timeout=RESULT_FETCH_TIMEOUT) + except queue.Empty: + pass + + # if no result were found, then the runner must no longer be alive + if result is None: + # Try one last time to fetch results in case results were reported + # in between the time of the last check and the termination of the + # thread runner. + try: + result = self._results_queue.get(block=False) + except queue.Empty: + pass + + # check if error occured inside the thread runner + if result is None: + # only raise an error from the runner if all results are consumed + self._report_thread_runner_error(block=True) + + # Under normal conditions, this code should never be reached since + # this branch should only be visited if the runner thread raised + # an exception. If no exception were raised, it means that the + # runner thread never reported any results which should not be + # possible when wrapping functions with `wrap_function`. + raise TuneError( + ("Wrapped function ran until completion without reporting " + "results or raising an exception.")) + + else: + if not self._error_queue.empty(): + logger.warning( + ("Runner error waiting to be raised in main thread. " + "Logging all available results first.")) + + # This keyword appears if the train_func using the Function API + # finishes without "done=True". This duplicates the last result, but + # the TrialRunner will not log this result again. + if "__duplicate__" in result: + new_result = self._last_result.copy() + new_result.update(result) + result = new_result + + self._last_result = result return result def _stop(self): - self._status_reporter._stop() + # If everything stayed in synch properly, this should never happen. + if not self._results_queue.empty(): + logger.warning( + ("Some results were added after the trial stop condition. " + "These results won't be logged.")) + + # Check for any errors that might have been missed. + self._report_thread_runner_error() + + def _report_thread_runner_error(self, block=False): + try: + err_type, err_value, err_tb = self._error_queue.get( + block=block, timeout=ERROR_FETCH_TIMEOUT) + raise TuneError(("Trial raised a {err_type} exception with value: " + "{err_value}\nWith traceback:\n{err_tb}").format( + err_type=err_type, + err_value=err_value, + err_tb=err_tb)) + except queue.Empty: + pass + + +def wrap_function(train_func): + class WrappedFunc(FunctionRunner): + def _trainable_func(self, config, reporter): + output = train_func(config, reporter) + # If train_func returns, we need to notify the main event loop + # of the last result while avoiding double logging. This is done + # with the keyword "__duplicate__" -- see tune/trial_runner.py, + reporter(done=True, __duplicate__=True) + return output + + return WrappedFunc diff --git a/python/ray/tune/registry.py b/python/ray/tune/registry.py index 212ac031b..6202013c7 100644 --- a/python/ray/tune/registry.py +++ b/python/ray/tune/registry.py @@ -31,7 +31,8 @@ def register_trainable(name, trainable): automatically converted into a class during registration. """ - from ray.tune.trainable import Trainable, wrap_function + from ray.tune.trainable import Trainable + from ray.tune.function_runner import wrap_function if isinstance(trainable, type): logger.debug("Detected class for trainable.") diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index 2da8d85b2..7e3ee1071 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -2,6 +2,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import copy import os import shutil import sys @@ -19,7 +20,8 @@ from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.schedulers import TrialScheduler, FIFOScheduler from ray.tune.registry import _global_registry, TRAINABLE_CLASS from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE, - EPISODES_TOTAL, TRAINING_ITERATION) + EPISODES_TOTAL, TRAINING_ITERATION, + TIMESTEPS_THIS_ITER) from ray.tune.logger import Logger from ray.tune.util import pin_in_object_store, get_pinned_object from ray.tune.experiment import Experiment @@ -46,6 +48,93 @@ class TrainableFunctionApiTest(unittest.TestCase): ray.shutdown() _register_all() # re-register the evicted objects + def checkAndReturnConsistentLogs(self, results, sleep_per_iter=None): + """Checks logging is the same between APIs. + + Ignore "DONE" for logging but checks that the + scheduler is notified properly with the last result. + """ + class_results = copy.deepcopy(results) + function_results = copy.deepcopy(results) + + class_output = [] + function_output = [] + scheduler_notif = [] + + class MockScheduler(FIFOScheduler): + def on_trial_complete(self, runner, trial, result): + scheduler_notif.append(result) + + class ClassAPILogger(Logger): + def on_result(self, result): + class_output.append(result) + + class FunctionAPILogger(Logger): + def on_result(self, result): + function_output.append(result) + + class _WrappedTrainable(Trainable): + def _setup(self, config): + del config + self._result_iter = copy.deepcopy(class_results) + + def _train(self): + if sleep_per_iter: + time.sleep(sleep_per_iter) + res = self._result_iter.pop(0) # This should not fail + if not self._result_iter: # Mark "Done" for last result + res[DONE] = True + return res + + def _function_trainable(config, reporter): + for result in function_results: + if sleep_per_iter: + time.sleep(sleep_per_iter) + reporter(**result) + + class_trainable_name = "class_trainable" + register_trainable(class_trainable_name, _WrappedTrainable) + + trials = run_experiments( + { + "function_api": { + "run": _function_trainable, + "loggers": [FunctionAPILogger], + }, + "class_api": { + "run": class_trainable_name, + "loggers": [ClassAPILogger], + }, + }, + raise_on_failed_trial=False, + scheduler=MockScheduler()) + + # Only compare these result fields. Metadata handling + # may be different across APIs. + COMPARE_FIELDS = {field for res in results for field in res} + + self.assertEqual(len(class_output), len(results)) + self.assertEqual(len(function_output), len(results)) + + def as_comparable_result(result): + return {k: v for k, v in result.items() if k in COMPARE_FIELDS} + + function_comparable = [ + as_comparable_result(result) for result in function_output + ] + class_comparable = [ + as_comparable_result(result) for result in class_output + ] + + self.assertEqual(function_comparable, class_comparable) + + self.assertEqual(sum(t.get(DONE) for t in scheduler_notif), 2) + self.assertEqual( + as_comparable_result(scheduler_notif[0]), + as_comparable_result(scheduler_notif[1])) + + return function_output, trials + def testPinObject(self): X = pin_in_object_store("hello") @@ -66,9 +155,6 @@ class TrainableFunctionApiTest(unittest.TestCase): [trial] = run_experiments({ "foo": { "run": "f1", - "config": { - "script_min_iter_time_s": 0, - }, } }) self.assertEqual(trial.status, Trial.TERMINATED) @@ -90,9 +176,6 @@ class TrainableFunctionApiTest(unittest.TestCase): [trial] = run_experiments({ "foo": { "run": "f1", - "config": { - "script_min_iter_time_s": 0, - }, } }) self.assertEqual(trial.status, Trial.TERMINATED) @@ -123,9 +206,6 @@ class TrainableFunctionApiTest(unittest.TestCase): [trial] = run_experiments({ "foo": { "run": "test", - "config": { - "script_min_iter_time_s": 0, - }, } }) self.assertEqual(trial.status, Trial.TERMINATED) @@ -337,9 +417,6 @@ class TrainableFunctionApiTest(unittest.TestCase): "stop": { "time": 10 }, - "config": { - "script_min_iter_time_s": 0, - }, } }) @@ -354,25 +431,6 @@ class TrainableFunctionApiTest(unittest.TestCase): [trial] = run_experiments({ "foo": { "run": "f1", - "config": { - "script_min_iter_time_s": 0, - }, - } - }) - self.assertEqual(trial.status, Trial.TERMINATED) - self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 100) - - def testAbruptReturn(self): - def train(config, reporter): - reporter(timesteps_total=100) - - register_trainable("f1", train) - [trial] = run_experiments({ - "foo": { - "run": "f1", - "config": { - "script_min_iter_time_s": 0, - }, } }) self.assertEqual(trial.status, Trial.TERMINATED) @@ -388,9 +446,6 @@ class TrainableFunctionApiTest(unittest.TestCase): run_experiments({ "foo": { "run": "f1", - "config": { - "script_min_iter_time_s": 0, - }, } }) @@ -405,9 +460,6 @@ class TrainableFunctionApiTest(unittest.TestCase): [trial] = run_experiments({ "foo": { "run": "f1", - "config": { - "script_min_iter_time_s": 0, - }, } }) self.assertEqual(trial.status, Trial.TERMINATED) @@ -415,9 +467,7 @@ class TrainableFunctionApiTest(unittest.TestCase): def testNoRaiseFlag(self): def train(config, reporter): - # Finish this trial without any metric, - # which leads to a failed trial - return + raise Exception() register_trainable("f1", train) @@ -425,12 +475,8 @@ class TrainableFunctionApiTest(unittest.TestCase): { "foo": { "run": "f1", - "config": { - "script_min_iter_time_s": 0, - }, } - }, - raise_on_failed_trial=False) + }, raise_on_failed_trial=False) self.assertEqual(trial.status, Trial.ERROR) def testReportInfinity(self): @@ -442,58 +488,116 @@ class TrainableFunctionApiTest(unittest.TestCase): [trial] = run_experiments({ "foo": { "run": "f1", - "config": { - "script_min_iter_time_s": 0, - }, } }) self.assertEqual(trial.status, Trial.TERMINATED) self.assertEqual(trial.last_result['mean_accuracy'], float('inf')) def testReportTimeStep(self): - def train(config, reporter): - for i in range(100): - reporter(mean_accuracy=5) + # Test that no timestep count are logged if never the Trainable never + # returns any. + results1 = [dict(mean_accuracy=5, done=i == 99) for i in range(100)] + logs1, _ = self.checkAndReturnConsistentLogs(results1) - [trial] = run_experiments({ - "foo": { - "run": train, - "config": { - "script_min_iter_time_s": 0, - }, - } - }) - self.assertIsNone(trial.last_result[TIMESTEPS_TOTAL]) + self.assertTrue(all(log[TIMESTEPS_TOTAL] is None for log in logs1)) - def train2(config, reporter): - for i in range(10): - reporter(timesteps_total=5) + # Test that no timesteps_this_iter are logged if only timesteps_total + # are returned. + results2 = [dict(timesteps_total=5, done=i == 9) for i in range(10)] + logs2, _ = self.checkAndReturnConsistentLogs(results2) - [trial2] = run_experiments({ - "foo": { - "run": train2, - "config": { - "script_min_iter_time_s": 0, - }, - } - }) - self.assertEqual(trial2.last_result[TIMESTEPS_TOTAL], 5) - self.assertEqual(trial2.last_result["timesteps_this_iter"], 0) + # Re-run the same trials but with added delay. This is to catch some + # inconsistent timestep counting that was present in the multi-threaded + # FunctionRunner. This part of the test can be removed once the + # multi-threaded FunctionRunner is removed from ray/tune. + # TODO: remove once the multi-threaded function runner is gone. + logs2, _ = self.checkAndReturnConsistentLogs(results2, 0.5) - def train3(config, reporter): - for i in range(10): - reporter(timesteps_this_iter=0, episodes_this_iter=0) + # check all timesteps_total report the same value + self.assertTrue(all(log[TIMESTEPS_TOTAL] == 5 for log in logs2)) + # check that none of the logs report timesteps_this_iter + self.assertFalse( + any(hasattr(log, TIMESTEPS_THIS_ITER) for log in logs2)) - [trial3] = run_experiments({ - "foo": { - "run": train3, - "config": { - "script_min_iter_time_s": 0, - }, - } - }) - self.assertEqual(trial3.last_result[TIMESTEPS_TOTAL], 0) - self.assertEqual(trial3.last_result[EPISODES_TOTAL], 0) + # Test that timesteps_total and episodes_total are reported when + # timesteps_this_iter and episodes_this_iter despite only return zeros. + results3 = [ + dict(timesteps_this_iter=0, episodes_this_iter=0) + for i in range(10) + ] + logs3, _ = self.checkAndReturnConsistentLogs(results3) + + self.assertTrue(all(log[TIMESTEPS_TOTAL] == 0 for log in logs3)) + self.assertTrue(all(log[EPISODES_TOTAL] == 0 for log in logs3)) + + # Test that timesteps_total and episodes_total are properly counted + # when timesteps_this_iter and episodes_this_iter report non-zero + # values. + results4 = [ + dict(timesteps_this_iter=3, episodes_this_iter=i) + for i in range(10) + ] + logs4, _ = self.checkAndReturnConsistentLogs(results4) + + # The last reported result should not be double-logged. + self.assertEqual(logs4[-1][TIMESTEPS_TOTAL], 30) + self.assertNotEqual(logs4[-2][TIMESTEPS_TOTAL], + logs4[-1][TIMESTEPS_TOTAL]) + self.assertEqual(logs4[-1][EPISODES_TOTAL], 45) + self.assertNotEqual(logs4[-2][EPISODES_TOTAL], + logs4[-1][EPISODES_TOTAL]) + + def testAllValuesReceived(self): + results1 = [ + dict(timesteps_total=(i + 1), my_score=i**2, done=i == 4) + for i in range(5) + ] + + logs1, _ = self.checkAndReturnConsistentLogs(results1) + + # check if the correct number of results were reported + self.assertEqual(len(logs1), len(results1)) + + def check_no_missing(reported_result, result): + common_results = [reported_result[k] == result[k] for k in result] + return all(common_results) + + # check that no result was dropped or modified + complete_results = [ + check_no_missing(log, result) + for log, result in zip(logs1, results1) + ] + self.assertTrue(all(complete_results)) + + # check if done was logged exactly once + self.assertEqual(len([r for r in logs1 if r.get("done")]), 1) + + def testNoDoneReceived(self): + # repeat same test but without explicitly reporting done=True + results1 = [ + dict(timesteps_total=(i + 1), my_score=i**2) for i in range(5) + ] + + logs1, trials = self.checkAndReturnConsistentLogs(results1) + + # check if the correct number of results were reported. + self.assertEqual(len(logs1), len(results1)) + + # We should not double-log + trial = [t for t in trials if "function" in str(t)][0] + self.assertEqual(trial.status, Trial.TERMINATED) + self.assertEqual(trial.last_result[DONE], False) + + def check_no_missing(reported_result, result): + common_results = [reported_result[k] == result[k] for k in result] + return all(common_results) + + # check that no result was dropped or modified + complete_results1 = [ + check_no_missing(log, result) + for log, result in zip(logs1, results1) + ] + self.assertTrue(all(complete_results1)) def testCheckpointDict(self): class TestTrain(Trainable): @@ -563,7 +667,7 @@ class TrainableFunctionApiTest(unittest.TestCase): def testIterationCounter(self): def train(config, reporter): for i in range(100): - reporter(itr=i, done=i == 99) + reporter(itr=i, timesteps_this_iter=1) register_trainable("exp", train) config = { @@ -600,15 +704,9 @@ class RunExperimentTest(unittest.TestCase): trials = run_experiments({ "foo": { "run": "f1", - "config": { - "script_min_iter_time_s": 0 - } }, "bar": { "run": "f1", - "config": { - "script_min_iter_time_s": 0 - } } }) for trial in trials: @@ -624,9 +722,6 @@ class RunExperimentTest(unittest.TestCase): exp1 = Experiment(**{ "name": "foo", "run": "f1", - "config": { - "script_min_iter_time_s": 0 - } }) [trial] = run_experiments(exp1) self.assertEqual(trial.status, Trial.TERMINATED) @@ -641,16 +736,10 @@ class RunExperimentTest(unittest.TestCase): exp1 = Experiment(**{ "name": "foo", "run": "f1", - "config": { - "script_min_iter_time_s": 0 - } }) exp2 = Experiment(**{ "name": "bar", "run": "f1", - "config": { - "script_min_iter_time_s": 0 - } }) trials = run_experiments([exp1, exp2]) for trial in trials: @@ -670,9 +759,6 @@ class RunExperimentTest(unittest.TestCase): trials = run_experiments({ "foo": { "run": train, - "config": { - "script_min_iter_time_s": 0 - } }, "bar": { "run": B diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 8b640cbb4..7661d01dd 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -445,13 +445,3 @@ class Trainable(object): A dict that maps ExportFormats to successfully exported models. """ return {} - - -def wrap_function(train_func): - from ray.tune.function_runner import FunctionRunner - - class WrappedFunc(FunctionRunner): - def _trainable_func(self): - return train_func - - return WrappedFunc diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index b3e939b2f..1dc33c03d 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -415,7 +415,6 @@ class TrialRunner(object): self._search_alg.on_trial_complete( trial.trial_id, result=result) decision = TrialScheduler.STOP - else: with warn_if_slow("scheduler.on_trial_result"): decision = self._scheduler_alg.on_trial_result( @@ -426,8 +425,15 @@ class TrialRunner(object): with warn_if_slow("search_alg.on_trial_complete"): self._search_alg.on_trial_complete( trial.trial_id, early_terminated=True) - trial.update_last_result( - result, terminate=(decision == TrialScheduler.STOP)) + + # __duplicate__ is a magic keyword used internally to + # avoid double-logging results when using the Function API. + # TrialScheduler and SearchAlgorithm still receive a + # notification because there may be special handling for + # the `on_trial_complete` hook. + if "__duplicate__" not in result: + trial.update_last_result( + result, terminate=(decision == TrialScheduler.STOP)) # Checkpoints to disk. This should be checked even if # the scheduler decision is STOP or PAUSE. Note that