From c28e6d41f56ab6b15fe1c2cfb7bf450ae63b2b36 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Wed, 16 Jan 2019 02:03:16 -0800 Subject: [PATCH] [tune] Avoid overwriting checkpoint file (#3781) --- python/ray/tune/test/cluster_tests.py | 14 +++----- python/ray/tune/test/trial_runner_test.py | 27 ++++++++++++++++ python/ray/tune/trial_runner.py | 39 +++++++++++++++++++---- python/ray/tune/tune.py | 3 +- 4 files changed, 64 insertions(+), 19 deletions(-) diff --git a/python/ray/tune/test/cluster_tests.py b/python/ray/tune/test/cluster_tests.py index 115a07d1b..05a597e06 100644 --- a/python/ray/tune/test/cluster_tests.py +++ b/python/ray/tune/test/cluster_tests.py @@ -390,9 +390,7 @@ tune.run_experiments( # the checkpoint. metadata_checkpoint_dir = os.path.join(dirpath, "experiment") for i in range(100): - if os.path.exists( - os.path.join(metadata_checkpoint_dir, - TrialRunner.CKPT_FILE_NAME)): + if TrialRunner.checkpoint_exists(metadata_checkpoint_dir): # Inspect the internal trialrunner runner = TrialRunner.restore(metadata_checkpoint_dir) trials = runner.get_trials() @@ -401,8 +399,7 @@ tune.run_experiments( break time.sleep(0.3) - if not os.path.exists( - os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME)): + if not TrialRunner.checkpoint_exists(metadata_checkpoint_dir): raise RuntimeError("Checkpoint file didn't appear.") ray.shutdown() @@ -485,9 +482,7 @@ tune.run_experiments( # the checkpoint. metadata_checkpoint_dir = os.path.join(dirpath, "experiment") for i in range(50): - if os.path.exists( - os.path.join(metadata_checkpoint_dir, - TrialRunner.CKPT_FILE_NAME)): + if TrialRunner.checkpoint_exists(metadata_checkpoint_dir): # Inspect the internal trialrunner runner = TrialRunner.restore(metadata_checkpoint_dir) trials = runner.get_trials() @@ -496,8 +491,7 @@ tune.run_experiments( break time.sleep(0.2) - if not os.path.exists( - os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME)): + if not TrialRunner.checkpoint_exists(metadata_checkpoint_dir): raise RuntimeError("Checkpoint file didn't appear.") ray.shutdown() diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index 4e8357989..2be8d9d11 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -1796,6 +1796,33 @@ class TrialRunnerTest(unittest.TestCase): self.assertTrue("on_episode_start" in new_trial.config["callbacks"]) shutil.rmtree(tmpdir) + def testCheckpointOverwrite(self): + def count_checkpoints(cdir): + return sum((fname.startswith("experiment_state") + and fname.endswith(".json")) + for fname in os.listdir(cdir)) + + ray.init() + trial = Trial("__fake", checkpoint_freq=1) + tmpdir = tempfile.mkdtemp() + runner = TrialRunner( + BasicVariantGenerator(), metadata_checkpoint_dir=tmpdir) + runner.add_trial(trial) + for i in range(5): + runner.step() + # force checkpoint + runner.checkpoint() + self.assertEquals(count_checkpoints(tmpdir), 1) + + runner2 = TrialRunner.restore(tmpdir) + for i in range(5): + runner2.step() + self.assertEquals(count_checkpoints(tmpdir), 2) + + runner2.checkpoint() + self.assertEquals(count_checkpoints(tmpdir), 2) + shutil.rmtree(tmpdir) + class SearchAlgorithmTest(unittest.TestCase): def testNestedSuggestion(self): diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index eddfbc488..ce3c648d2 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -3,6 +3,7 @@ from __future__ import division from __future__ import print_function import collections +from datetime import datetime import json import logging import os @@ -28,6 +29,15 @@ def _naturalize(string): return [int(text) if text.isdigit() else text.lower() for text in splits] +def _find_newest_ckpt(ckpt_dir): + """Returns path to most recently modified checkpoint.""" + full_paths = [ + os.path.join(ckpt_dir, fname) for fname in os.listdir(ckpt_dir) + if fname.startswith("experiment_state") and fname.endswith(".json") + ] + return max(full_paths) + + class TrialRunner(object): """A TrialRunner implements the event loop for scheduling trials on Ray. @@ -50,7 +60,7 @@ class TrialRunner(object): misleading benchmark results. """ - CKPT_FILE_NAME = "experiment_state.json" + CKPT_FILE_TMPL = "experiment_state-{}.json" def __init__(self, search_alg, @@ -102,8 +112,22 @@ class TrialRunner(object): self._stop_queue = [] self._metadata_checkpoint_dir = metadata_checkpoint_dir + self._session = datetime.today().strftime("%Y-%m-%d_%H-%M-%S") + + @classmethod + def checkpoint_exists(cls, directory): + if not os.path.exists(directory): + return False + return any( + (fname.startswith("experiment_state") and fname.endswith(".json")) + for fname in os.listdir(directory)) + def checkpoint(self): - """Saves execution state to `self._metadata_checkpoint_dir`.""" + """Saves execution state to `self._metadata_checkpoint_dir`. + + Overwrites the current session checkpoint, which starts when self + is instantiated. + """ if not self._metadata_checkpoint_dir: return metadata_checkpoint_dir = self._metadata_checkpoint_dir @@ -121,7 +145,8 @@ class TrialRunner(object): os.rename( tmp_file_name, - os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME)) + os.path.join(metadata_checkpoint_dir, + TrialRunner.CKPT_FILE_TMPL.format(self._session))) return metadata_checkpoint_dir @classmethod @@ -146,9 +171,9 @@ class TrialRunner(object): Returns: runner (TrialRunner): A TrialRunner to resume experiments from. """ - with open( - os.path.join(metadata_checkpoint_dir, - TrialRunner.CKPT_FILE_NAME), "r") as f: + + newest_ckpt_path = _find_newest_ckpt(metadata_checkpoint_dir) + with open(newest_ckpt_path, "r") as f: runner_state = json.load(f) logger.warning("".join([ @@ -520,7 +545,7 @@ class TrialRunner(object): state = self.__dict__.copy() for k in [ "_trials", "_stop_queue", "_server", "_search_alg", - "_scheduler_alg", "trial_executor" + "_scheduler_alg", "trial_executor", "_session" ]: del state[k] state["launch_web_server"] = bool(self._server) diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index f93e3490b..e216dc95e 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -116,8 +116,7 @@ def run_experiments(experiments, runner = None restore = False - if os.path.exists( - os.path.join(checkpoint_dir, TrialRunner.CKPT_FILE_NAME)): + if TrialRunner.checkpoint_exists(checkpoint_dir): if resume == "prompt": msg = ("Found incomplete experiment at {}. " "Would you like to resume it?".format(checkpoint_dir))