diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index bae885970..7dbf02ef8 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -3,15 +3,24 @@ from __future__ import division from __future__ import print_function import logging +import sys import time import threading +from six.moves import queue from ray.tune import TuneError from ray.tune.trainable import Trainable -from ray.tune.result import TIMESTEPS_TOTAL, TRAINING_ITERATION +from ray.tune.result import TIME_THIS_ITER_S logger = logging.getLogger(__name__) +# Time between FunctionRunner checks when fetching +# new results after signaling the reporter to continue +RESULT_FETCH_TIMEOUT = 0.2 + +ERROR_REPORT_TIMEOUT = 10 +ERROR_FETCH_TIMEOUT = 1 + class StatusReporter(object): """Object passed into your function that you can report status through. @@ -19,16 +28,13 @@ class StatusReporter(object): Example: >>> def trainable_function(config, reporter): >>> assert isinstance(reporter, StatusReporter) - >>> reporter(timesteps_total=1) + >>> reporter(timesteps_this_iter=1) """ - def __init__(self): - self._latest_result = None - self._last_result = None - self._lock = threading.Lock() - self._error = None - self._done = False - self._iteration = 0 + def __init__(self, result_queue, continue_semaphore): + self._queue = result_queue + self._last_report_time = None + self._continue_semaphore = continue_semaphore def __call__(self, **kwargs): """Report updated training status. @@ -41,82 +47,101 @@ class StatusReporter(object): Example: >>> reporter(mean_accuracy=1, training_iteration=4) >>> reporter(mean_accuracy=1, training_iteration=4, done=True) + + Raises: + StopIteration: A StopIteration exception is raised if the trial has + been signaled to stop. """ - with self._lock: - self._latest_result = self._last_result = kwargs.copy() - self._iteration += 1 + assert self._last_report_time is not None, ( + "StatusReporter._start() must be called before the first " + "report __call__ is made to ensure correct runtime metrics.") - def _get_and_clear_status(self): - if self._error: - raise TuneError("Error running trial: " + str(self._error)) - if self._done and not self._latest_result: - if not self._last_result: - raise TuneError("Trial finished without reporting result!") - logger.warning("Trial detected as completed; re-reporting " - "last result. To avoid this, include done=True " - "upon the last reporter call.") - self._last_result.update(done=True) - self._last_result.setdefault(TRAINING_ITERATION, self._iteration) - return self._last_result - with self._lock: - res = self._latest_result - self._latest_result = None - if res: - res.setdefault(TRAINING_ITERATION, self._iteration) - return res + # time per iteration is recorded directly in the reporter to ensure + # any delays in logging results aren't counted + report_time = time.time() + if TIME_THIS_ITER_S not in kwargs: + kwargs[TIME_THIS_ITER_S] = report_time - self._last_report_time + self._last_report_time = report_time - def _stop(self): - self._error = "Agent stopped" + # add results to a thread-safe queue + self._queue.put(kwargs.copy(), block=True) + # This blocks until notification from the FunctionRunner that the last + # result has been returned to Tune and that the function is safe to + # resume training. + self._continue_semaphore.acquire() -DEFAULT_CONFIG = { - # batch results to at least this granularity - "script_min_iter_time_s": 1, -} + def _start(self): + self._last_report_time = time.time() class _RunnerThread(threading.Thread): """Supervisor thread that runs your script.""" - def __init__(self, entrypoint, config, status_reporter): - self._entrypoint = entrypoint - self._entrypoint_args = [config, status_reporter] - self._status_reporter = status_reporter + def __init__(self, entrypoint, error_queue): threading.Thread.__init__(self) + self._entrypoint = entrypoint + self._error_queue = error_queue self.daemon = True def run(self): try: - self._entrypoint(*self._entrypoint_args) + self._entrypoint() + except StopIteration: + logger.debug( + ("Thread runner raised StopIteration. Interperting it as a " + "signal to terminate the thread without error.")) except Exception as e: - self._status_reporter._error = e logger.exception("Runner Thread raised error.") + try: + # report the error but avoid indefinite blocking which would + # prevent the exception from being propagated in the unlikely + # case that something went terribly wrong + err_type, err_value, err_tb = sys.exc_info() + err_tb = err_tb.format_exc() + self._error_queue.put( + (err_type, err_value, err_tb), + block=True, + timeout=ERROR_REPORT_TIMEOUT) + except queue.Full: + logger.critical( + ("Runner Thread was unable to report error to main " + "function runner thread. This means a previous error " + "was not processed. This should never happen.")) raise e - finally: - self._status_reporter._done = True class FunctionRunner(Trainable): - """Trainable that runs a user function returning training results. + """Trainable that runs a user function reporting results. This mode of execution does not support checkpoint/restore.""" _name = "func" - _default_config = DEFAULT_CONFIG def _setup(self, config): - entrypoint = self._trainable_func() - self._status_reporter = StatusReporter() - scrubbed_config = config.copy() - for k in self._default_config: - if k in scrubbed_config: - del scrubbed_config[k] - self._runner = _RunnerThread(entrypoint, scrubbed_config, - self._status_reporter) - self._start_time = time.time() - self._last_reported_timestep = 0 - self._runner.start() + # Semaphore for notifying the reporter to continue with the computation + # and to generate the next result. + self._continue_semaphore = threading.Semaphore(0) + + # Queue for passing results between threads + self._results_queue = queue.Queue(1) + + # Queue for passing errors back from the thread runner. The error queue + # has a max size of one to prevent stacking error and force error + # reporting to block until finished. + self._error_queue = queue.Queue(1) + + self._status_reporter = StatusReporter(self._results_queue, + self._continue_semaphore) + self._last_result = {} + config = config.copy() + + def entrypoint(): + return self._trainable_func(config, self._status_reporter) + + # the runner thread is not started until the first call to _train + self._runner = _RunnerThread(entrypoint, self._error_queue) def _trainable_func(self): """Subclasses can override this to set the trainable func.""" @@ -124,22 +149,108 @@ class FunctionRunner(Trainable): raise NotImplementedError def _train(self): - time.sleep( - self.config.get("script_min_iter_time_s", - self._default_config["script_min_iter_time_s"])) - result = self._status_reporter._get_and_clear_status() - while result is None: - time.sleep(1) - result = self._status_reporter._get_and_clear_status() + """Implements train() for a Function API. - curr_ts_total = result.get(TIMESTEPS_TOTAL) - if curr_ts_total is not None: - result.update( - timesteps_this_iter=( - curr_ts_total - self._last_reported_timestep)) - self._last_reported_timestep = curr_ts_total + If the RunnerThread finishes without reporting "done", + Tune will automatically provide a magic keyword __duplicate__ + along with a result with "done=True". The TrialRunner will handle the + result accordingly (see tune/trial_runner.py). + """ + if self._runner.is_alive(): + # if started and alive, inform the reporter to continue and + # generate the next result + self._continue_semaphore.release() + else: + # if not alive, try to start + self._status_reporter._start() + try: + self._runner.start() + except RuntimeError: + # If this is reached, it means the thread was started and is + # now done or has raised an exception. + pass + result = None + while result is None and self._runner.is_alive(): + # fetch the next produced result + try: + result = self._results_queue.get( + block=True, timeout=RESULT_FETCH_TIMEOUT) + except queue.Empty: + pass + + # if no result were found, then the runner must no longer be alive + if result is None: + # Try one last time to fetch results in case results were reported + # in between the time of the last check and the termination of the + # thread runner. + try: + result = self._results_queue.get(block=False) + except queue.Empty: + pass + + # check if error occured inside the thread runner + if result is None: + # only raise an error from the runner if all results are consumed + self._report_thread_runner_error(block=True) + + # Under normal conditions, this code should never be reached since + # this branch should only be visited if the runner thread raised + # an exception. If no exception were raised, it means that the + # runner thread never reported any results which should not be + # possible when wrapping functions with `wrap_function`. + raise TuneError( + ("Wrapped function ran until completion without reporting " + "results or raising an exception.")) + + else: + if not self._error_queue.empty(): + logger.warning( + ("Runner error waiting to be raised in main thread. " + "Logging all available results first.")) + + # This keyword appears if the train_func using the Function API + # finishes without "done=True". This duplicates the last result, but + # the TrialRunner will not log this result again. + if "__duplicate__" in result: + new_result = self._last_result.copy() + new_result.update(result) + result = new_result + + self._last_result = result return result def _stop(self): - self._status_reporter._stop() + # If everything stayed in synch properly, this should never happen. + if not self._results_queue.empty(): + logger.warning( + ("Some results were added after the trial stop condition. " + "These results won't be logged.")) + + # Check for any errors that might have been missed. + self._report_thread_runner_error() + + def _report_thread_runner_error(self, block=False): + try: + err_type, err_value, err_tb = self._error_queue.get( + block=block, timeout=ERROR_FETCH_TIMEOUT) + raise TuneError(("Trial raised a {err_type} exception with value: " + "{err_value}\nWith traceback:\n{err_tb}").format( + err_type=err_type, + err_value=err_value, + err_tb=err_tb)) + except queue.Empty: + pass + + +def wrap_function(train_func): + class WrappedFunc(FunctionRunner): + def _trainable_func(self, config, reporter): + output = train_func(config, reporter) + # If train_func returns, we need to notify the main event loop + # of the last result while avoiding double logging. This is done + # with the keyword "__duplicate__" -- see tune/trial_runner.py, + reporter(done=True, __duplicate__=True) + return output + + return WrappedFunc diff --git a/python/ray/tune/registry.py b/python/ray/tune/registry.py index 212ac031b..6202013c7 100644 --- a/python/ray/tune/registry.py +++ b/python/ray/tune/registry.py @@ -31,7 +31,8 @@ def register_trainable(name, trainable): automatically converted into a class during registration. """ - from ray.tune.trainable import Trainable, wrap_function + from ray.tune.trainable import Trainable + from ray.tune.function_runner import wrap_function if isinstance(trainable, type): logger.debug("Detected class for trainable.") diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index 2da8d85b2..7e3ee1071 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -2,6 +2,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import copy import os import shutil import sys @@ -19,7 +20,8 @@ from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.schedulers import TrialScheduler, FIFOScheduler from ray.tune.registry import _global_registry, TRAINABLE_CLASS from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE, - EPISODES_TOTAL, TRAINING_ITERATION) + EPISODES_TOTAL, TRAINING_ITERATION, + TIMESTEPS_THIS_ITER) from ray.tune.logger import Logger from ray.tune.util import pin_in_object_store, get_pinned_object from ray.tune.experiment import Experiment @@ -46,6 +48,93 @@ class TrainableFunctionApiTest(unittest.TestCase): ray.shutdown() _register_all() # re-register the evicted objects + def checkAndReturnConsistentLogs(self, results, sleep_per_iter=None): + """Checks logging is the same between APIs. + + Ignore "DONE" for logging but checks that the + scheduler is notified properly with the last result. + """ + class_results = copy.deepcopy(results) + function_results = copy.deepcopy(results) + + class_output = [] + function_output = [] + scheduler_notif = [] + + class MockScheduler(FIFOScheduler): + def on_trial_complete(self, runner, trial, result): + scheduler_notif.append(result) + + class ClassAPILogger(Logger): + def on_result(self, result): + class_output.append(result) + + class FunctionAPILogger(Logger): + def on_result(self, result): + function_output.append(result) + + class _WrappedTrainable(Trainable): + def _setup(self, config): + del config + self._result_iter = copy.deepcopy(class_results) + + def _train(self): + if sleep_per_iter: + time.sleep(sleep_per_iter) + res = self._result_iter.pop(0) # This should not fail + if not self._result_iter: # Mark "Done" for last result + res[DONE] = True + return res + + def _function_trainable(config, reporter): + for result in function_results: + if sleep_per_iter: + time.sleep(sleep_per_iter) + reporter(**result) + + class_trainable_name = "class_trainable" + register_trainable(class_trainable_name, _WrappedTrainable) + + trials = run_experiments( + { + "function_api": { + "run": _function_trainable, + "loggers": [FunctionAPILogger], + }, + "class_api": { + "run": class_trainable_name, + "loggers": [ClassAPILogger], + }, + }, + raise_on_failed_trial=False, + scheduler=MockScheduler()) + + # Only compare these result fields. Metadata handling + # may be different across APIs. + COMPARE_FIELDS = {field for res in results for field in res} + + self.assertEqual(len(class_output), len(results)) + self.assertEqual(len(function_output), len(results)) + + def as_comparable_result(result): + return {k: v for k, v in result.items() if k in COMPARE_FIELDS} + + function_comparable = [ + as_comparable_result(result) for result in function_output + ] + class_comparable = [ + as_comparable_result(result) for result in class_output + ] + + self.assertEqual(function_comparable, class_comparable) + + self.assertEqual(sum(t.get(DONE) for t in scheduler_notif), 2) + self.assertEqual( + as_comparable_result(scheduler_notif[0]), + as_comparable_result(scheduler_notif[1])) + + return function_output, trials + def testPinObject(self): X = pin_in_object_store("hello") @@ -66,9 +155,6 @@ class TrainableFunctionApiTest(unittest.TestCase): [trial] = run_experiments({ "foo": { "run": "f1", - "config": { - "script_min_iter_time_s": 0, - }, } }) self.assertEqual(trial.status, Trial.TERMINATED) @@ -90,9 +176,6 @@ class TrainableFunctionApiTest(unittest.TestCase): [trial] = run_experiments({ "foo": { "run": "f1", - "config": { - "script_min_iter_time_s": 0, - }, } }) self.assertEqual(trial.status, Trial.TERMINATED) @@ -123,9 +206,6 @@ class TrainableFunctionApiTest(unittest.TestCase): [trial] = run_experiments({ "foo": { "run": "test", - "config": { - "script_min_iter_time_s": 0, - }, } }) self.assertEqual(trial.status, Trial.TERMINATED) @@ -337,9 +417,6 @@ class TrainableFunctionApiTest(unittest.TestCase): "stop": { "time": 10 }, - "config": { - "script_min_iter_time_s": 0, - }, } }) @@ -354,25 +431,6 @@ class TrainableFunctionApiTest(unittest.TestCase): [trial] = run_experiments({ "foo": { "run": "f1", - "config": { - "script_min_iter_time_s": 0, - }, - } - }) - self.assertEqual(trial.status, Trial.TERMINATED) - self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 100) - - def testAbruptReturn(self): - def train(config, reporter): - reporter(timesteps_total=100) - - register_trainable("f1", train) - [trial] = run_experiments({ - "foo": { - "run": "f1", - "config": { - "script_min_iter_time_s": 0, - }, } }) self.assertEqual(trial.status, Trial.TERMINATED) @@ -388,9 +446,6 @@ class TrainableFunctionApiTest(unittest.TestCase): run_experiments({ "foo": { "run": "f1", - "config": { - "script_min_iter_time_s": 0, - }, } }) @@ -405,9 +460,6 @@ class TrainableFunctionApiTest(unittest.TestCase): [trial] = run_experiments({ "foo": { "run": "f1", - "config": { - "script_min_iter_time_s": 0, - }, } }) self.assertEqual(trial.status, Trial.TERMINATED) @@ -415,9 +467,7 @@ class TrainableFunctionApiTest(unittest.TestCase): def testNoRaiseFlag(self): def train(config, reporter): - # Finish this trial without any metric, - # which leads to a failed trial - return + raise Exception() register_trainable("f1", train) @@ -425,12 +475,8 @@ class TrainableFunctionApiTest(unittest.TestCase): { "foo": { "run": "f1", - "config": { - "script_min_iter_time_s": 0, - }, } - }, - raise_on_failed_trial=False) + }, raise_on_failed_trial=False) self.assertEqual(trial.status, Trial.ERROR) def testReportInfinity(self): @@ -442,58 +488,116 @@ class TrainableFunctionApiTest(unittest.TestCase): [trial] = run_experiments({ "foo": { "run": "f1", - "config": { - "script_min_iter_time_s": 0, - }, } }) self.assertEqual(trial.status, Trial.TERMINATED) self.assertEqual(trial.last_result['mean_accuracy'], float('inf')) def testReportTimeStep(self): - def train(config, reporter): - for i in range(100): - reporter(mean_accuracy=5) + # Test that no timestep count are logged if never the Trainable never + # returns any. + results1 = [dict(mean_accuracy=5, done=i == 99) for i in range(100)] + logs1, _ = self.checkAndReturnConsistentLogs(results1) - [trial] = run_experiments({ - "foo": { - "run": train, - "config": { - "script_min_iter_time_s": 0, - }, - } - }) - self.assertIsNone(trial.last_result[TIMESTEPS_TOTAL]) + self.assertTrue(all(log[TIMESTEPS_TOTAL] is None for log in logs1)) - def train2(config, reporter): - for i in range(10): - reporter(timesteps_total=5) + # Test that no timesteps_this_iter are logged if only timesteps_total + # are returned. + results2 = [dict(timesteps_total=5, done=i == 9) for i in range(10)] + logs2, _ = self.checkAndReturnConsistentLogs(results2) - [trial2] = run_experiments({ - "foo": { - "run": train2, - "config": { - "script_min_iter_time_s": 0, - }, - } - }) - self.assertEqual(trial2.last_result[TIMESTEPS_TOTAL], 5) - self.assertEqual(trial2.last_result["timesteps_this_iter"], 0) + # Re-run the same trials but with added delay. This is to catch some + # inconsistent timestep counting that was present in the multi-threaded + # FunctionRunner. This part of the test can be removed once the + # multi-threaded FunctionRunner is removed from ray/tune. + # TODO: remove once the multi-threaded function runner is gone. + logs2, _ = self.checkAndReturnConsistentLogs(results2, 0.5) - def train3(config, reporter): - for i in range(10): - reporter(timesteps_this_iter=0, episodes_this_iter=0) + # check all timesteps_total report the same value + self.assertTrue(all(log[TIMESTEPS_TOTAL] == 5 for log in logs2)) + # check that none of the logs report timesteps_this_iter + self.assertFalse( + any(hasattr(log, TIMESTEPS_THIS_ITER) for log in logs2)) - [trial3] = run_experiments({ - "foo": { - "run": train3, - "config": { - "script_min_iter_time_s": 0, - }, - } - }) - self.assertEqual(trial3.last_result[TIMESTEPS_TOTAL], 0) - self.assertEqual(trial3.last_result[EPISODES_TOTAL], 0) + # Test that timesteps_total and episodes_total are reported when + # timesteps_this_iter and episodes_this_iter despite only return zeros. + results3 = [ + dict(timesteps_this_iter=0, episodes_this_iter=0) + for i in range(10) + ] + logs3, _ = self.checkAndReturnConsistentLogs(results3) + + self.assertTrue(all(log[TIMESTEPS_TOTAL] == 0 for log in logs3)) + self.assertTrue(all(log[EPISODES_TOTAL] == 0 for log in logs3)) + + # Test that timesteps_total and episodes_total are properly counted + # when timesteps_this_iter and episodes_this_iter report non-zero + # values. + results4 = [ + dict(timesteps_this_iter=3, episodes_this_iter=i) + for i in range(10) + ] + logs4, _ = self.checkAndReturnConsistentLogs(results4) + + # The last reported result should not be double-logged. + self.assertEqual(logs4[-1][TIMESTEPS_TOTAL], 30) + self.assertNotEqual(logs4[-2][TIMESTEPS_TOTAL], + logs4[-1][TIMESTEPS_TOTAL]) + self.assertEqual(logs4[-1][EPISODES_TOTAL], 45) + self.assertNotEqual(logs4[-2][EPISODES_TOTAL], + logs4[-1][EPISODES_TOTAL]) + + def testAllValuesReceived(self): + results1 = [ + dict(timesteps_total=(i + 1), my_score=i**2, done=i == 4) + for i in range(5) + ] + + logs1, _ = self.checkAndReturnConsistentLogs(results1) + + # check if the correct number of results were reported + self.assertEqual(len(logs1), len(results1)) + + def check_no_missing(reported_result, result): + common_results = [reported_result[k] == result[k] for k in result] + return all(common_results) + + # check that no result was dropped or modified + complete_results = [ + check_no_missing(log, result) + for log, result in zip(logs1, results1) + ] + self.assertTrue(all(complete_results)) + + # check if done was logged exactly once + self.assertEqual(len([r for r in logs1 if r.get("done")]), 1) + + def testNoDoneReceived(self): + # repeat same test but without explicitly reporting done=True + results1 = [ + dict(timesteps_total=(i + 1), my_score=i**2) for i in range(5) + ] + + logs1, trials = self.checkAndReturnConsistentLogs(results1) + + # check if the correct number of results were reported. + self.assertEqual(len(logs1), len(results1)) + + # We should not double-log + trial = [t for t in trials if "function" in str(t)][0] + self.assertEqual(trial.status, Trial.TERMINATED) + self.assertEqual(trial.last_result[DONE], False) + + def check_no_missing(reported_result, result): + common_results = [reported_result[k] == result[k] for k in result] + return all(common_results) + + # check that no result was dropped or modified + complete_results1 = [ + check_no_missing(log, result) + for log, result in zip(logs1, results1) + ] + self.assertTrue(all(complete_results1)) def testCheckpointDict(self): class TestTrain(Trainable): @@ -563,7 +667,7 @@ class TrainableFunctionApiTest(unittest.TestCase): def testIterationCounter(self): def train(config, reporter): for i in range(100): - reporter(itr=i, done=i == 99) + reporter(itr=i, timesteps_this_iter=1) register_trainable("exp", train) config = { @@ -600,15 +704,9 @@ class RunExperimentTest(unittest.TestCase): trials = run_experiments({ "foo": { "run": "f1", - "config": { - "script_min_iter_time_s": 0 - } }, "bar": { "run": "f1", - "config": { - "script_min_iter_time_s": 0 - } } }) for trial in trials: @@ -624,9 +722,6 @@ class RunExperimentTest(unittest.TestCase): exp1 = Experiment(**{ "name": "foo", "run": "f1", - "config": { - "script_min_iter_time_s": 0 - } }) [trial] = run_experiments(exp1) self.assertEqual(trial.status, Trial.TERMINATED) @@ -641,16 +736,10 @@ class RunExperimentTest(unittest.TestCase): exp1 = Experiment(**{ "name": "foo", "run": "f1", - "config": { - "script_min_iter_time_s": 0 - } }) exp2 = Experiment(**{ "name": "bar", "run": "f1", - "config": { - "script_min_iter_time_s": 0 - } }) trials = run_experiments([exp1, exp2]) for trial in trials: @@ -670,9 +759,6 @@ class RunExperimentTest(unittest.TestCase): trials = run_experiments({ "foo": { "run": train, - "config": { - "script_min_iter_time_s": 0 - } }, "bar": { "run": B diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 8b640cbb4..7661d01dd 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -445,13 +445,3 @@ class Trainable(object): A dict that maps ExportFormats to successfully exported models. """ return {} - - -def wrap_function(train_func): - from ray.tune.function_runner import FunctionRunner - - class WrappedFunc(FunctionRunner): - def _trainable_func(self): - return train_func - - return WrappedFunc diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index b3e939b2f..1dc33c03d 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -415,7 +415,6 @@ class TrialRunner(object): self._search_alg.on_trial_complete( trial.trial_id, result=result) decision = TrialScheduler.STOP - else: with warn_if_slow("scheduler.on_trial_result"): decision = self._scheduler_alg.on_trial_result( @@ -426,8 +425,15 @@ class TrialRunner(object): with warn_if_slow("search_alg.on_trial_complete"): self._search_alg.on_trial_complete( trial.trial_id, early_terminated=True) - trial.update_last_result( - result, terminate=(decision == TrialScheduler.STOP)) + + # __duplicate__ is a magic keyword used internally to + # avoid double-logging results when using the Function API. + # TrialScheduler and SearchAlgorithm still receive a + # notification because there may be special handling for + # the `on_trial_complete` hook. + if "__duplicate__" not in result: + trial.update_last_result( + result, terminate=(decision == TrialScheduler.STOP)) # Checkpoints to disk. This should be checked even if # the scheduler decision is STOP or PAUSE. Note that