diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 67b658a26..69e0651f1 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -11,7 +11,7 @@ import os from collections import namedtuple from ray.tune import TuneError from ray.tune.logger import NoopLogger, UnifiedLogger -from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR +from ray.tune.result import TrainingResult, DEFAULT_RESULTS_DIR, pretty_print from ray.tune.registry import _default_registry, get_registry, TRAINABLE_CLASS @@ -285,6 +285,14 @@ class Trial(object): print("Error restoring runner:", traceback.format_exc()) self.status = Trial.ERROR + def update_last_result(self, result, terminate=False): + if terminate: + result = result._replace(done=True) + print("TrainingResult for {}:".format(self)) + print(" {}".format(pretty_print(result).replace("\n", "\n "))) + self.last_result = result + self.result_logger.on_result(self.last_result) + def _setup_runner(self): self.status = Trial.RUNNING trainable_cls = get_registry().get( diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 63aeba821..b2bae6b1a 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -8,7 +8,6 @@ import time import traceback from ray.tune import TuneError -from ray.tune.result import pretty_print from ray.tune.trial import Trial, Resources from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler @@ -157,35 +156,33 @@ class TrialRunner(object): # have been lost def _process_events(self): - [result_id], _ = ray.wait(list(self._running.keys())) - trial = self._running[result_id] - del self._running[result_id] + [result_id], _ = ray.wait(list(self._running)) + trial = self._running.pop(result_id) try: result = ray.get(result_id) - trial.result_logger.on_result(result) - print("TrainingResult for {}:".format(trial)) - print(" {}".format(pretty_print(result).replace("\n", "\n "))) - trial.last_result = result self._total_time += result.time_this_iter_s if trial.should_stop(result): self._scheduler_alg.on_trial_complete(self, trial, result) - self._stop_trial(trial) + decision = TrialScheduler.STOP else: decision = self._scheduler_alg.on_trial_result( self, trial, result) - if decision == TrialScheduler.CONTINUE: - if trial.should_checkpoint(): - # TODO(rliaw): This is a blocking call - trial.checkpoint() - self._running[trial.train_remote()] = trial - elif decision == TrialScheduler.PAUSE: - self._pause_trial(trial) - elif decision == TrialScheduler.STOP: - self._stop_trial(trial) - else: - assert False, "Invalid scheduling decision: {}".format( - decision) + trial.update_last_result( + result, terminate=(decision == TrialScheduler.STOP)) + + if decision == TrialScheduler.CONTINUE: + if trial.should_checkpoint(): + # TODO(rliaw): This is a blocking call + trial.checkpoint() + self._running[trial.train_remote()] = trial + elif decision == TrialScheduler.PAUSE: + self._pause_trial(trial) + elif decision == TrialScheduler.STOP: + self._stop_trial(trial) + else: + assert False, "Invalid scheduling decision: {}".format( + decision) except Exception: print("Error processing event:", traceback.format_exc()) if trial.status == Trial.RUNNING: diff --git a/test/trial_runner_test.py b/test/trial_runner_test.py index 26be34db8..5cd58d9c2 100644 --- a/test/trial_runner_test.py +++ b/test/trial_runner_test.py @@ -450,6 +450,24 @@ class TrialRunnerTest(unittest.TestCase): self.assertEqual(ray.get(trials[1].runner.get_info.remote()), 1) self.addCleanup(os.remove, path) + def testResultDone(self): + """Tests that last_result is marked `done` after trial is complete.""" + ray.init(num_cpus=1, num_gpus=1) + runner = TrialRunner() + kwargs = { + "stopping_criterion": {"training_iteration": 2}, + "resources": Resources(cpu=1, gpu=1), + } + runner.add_trial(Trial("__fake", **kwargs)) + trials = runner.get_trials() + + runner.step() + self.assertEqual(trials[0].status, Trial.RUNNING) + runner.step() + self.assertNotEqual(trials[0].last_result.done, True) + runner.step() + self.assertEqual(trials[0].last_result.done, True) + def testPauseThenResume(self): ray.init(num_cpus=1, num_gpus=1) runner = TrialRunner()