[tune] Avoid overwriting checkpoint file (#3781)

This commit is contained in:
Richard Liaw
2019-01-16 02:03:16 -08:00
committed by GitHub
parent a237b4a6a1
commit c28e6d41f5
4 changed files with 64 additions and 19 deletions
+4 -10
View File
@@ -390,9 +390,7 @@ tune.run_experiments(
# the checkpoint.
metadata_checkpoint_dir = os.path.join(dirpath, "experiment")
for i in range(100):
if os.path.exists(
os.path.join(metadata_checkpoint_dir,
TrialRunner.CKPT_FILE_NAME)):
if TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
# Inspect the internal trialrunner
runner = TrialRunner.restore(metadata_checkpoint_dir)
trials = runner.get_trials()
@@ -401,8 +399,7 @@ tune.run_experiments(
break
time.sleep(0.3)
if not os.path.exists(
os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME)):
if not TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
raise RuntimeError("Checkpoint file didn't appear.")
ray.shutdown()
@@ -485,9 +482,7 @@ tune.run_experiments(
# the checkpoint.
metadata_checkpoint_dir = os.path.join(dirpath, "experiment")
for i in range(50):
if os.path.exists(
os.path.join(metadata_checkpoint_dir,
TrialRunner.CKPT_FILE_NAME)):
if TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
# Inspect the internal trialrunner
runner = TrialRunner.restore(metadata_checkpoint_dir)
trials = runner.get_trials()
@@ -496,8 +491,7 @@ tune.run_experiments(
break
time.sleep(0.2)
if not os.path.exists(
os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME)):
if not TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
raise RuntimeError("Checkpoint file didn't appear.")
ray.shutdown()
+27
View File
@@ -1796,6 +1796,33 @@ class TrialRunnerTest(unittest.TestCase):
self.assertTrue("on_episode_start" in new_trial.config["callbacks"])
shutil.rmtree(tmpdir)
def testCheckpointOverwrite(self):
def count_checkpoints(cdir):
return sum((fname.startswith("experiment_state")
and fname.endswith(".json"))
for fname in os.listdir(cdir))
ray.init()
trial = Trial("__fake", checkpoint_freq=1)
tmpdir = tempfile.mkdtemp()
runner = TrialRunner(
BasicVariantGenerator(), metadata_checkpoint_dir=tmpdir)
runner.add_trial(trial)
for i in range(5):
runner.step()
# force checkpoint
runner.checkpoint()
self.assertEquals(count_checkpoints(tmpdir), 1)
runner2 = TrialRunner.restore(tmpdir)
for i in range(5):
runner2.step()
self.assertEquals(count_checkpoints(tmpdir), 2)
runner2.checkpoint()
self.assertEquals(count_checkpoints(tmpdir), 2)
shutil.rmtree(tmpdir)
class SearchAlgorithmTest(unittest.TestCase):
def testNestedSuggestion(self):
+32 -7
View File
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import collections
from datetime import datetime
import json
import logging
import os
@@ -28,6 +29,15 @@ def _naturalize(string):
return [int(text) if text.isdigit() else text.lower() for text in splits]
def _find_newest_ckpt(ckpt_dir):
"""Returns path to most recently modified checkpoint."""
full_paths = [
os.path.join(ckpt_dir, fname) for fname in os.listdir(ckpt_dir)
if fname.startswith("experiment_state") and fname.endswith(".json")
]
return max(full_paths)
class TrialRunner(object):
"""A TrialRunner implements the event loop for scheduling trials on Ray.
@@ -50,7 +60,7 @@ class TrialRunner(object):
misleading benchmark results.
"""
CKPT_FILE_NAME = "experiment_state.json"
CKPT_FILE_TMPL = "experiment_state-{}.json"
def __init__(self,
search_alg,
@@ -102,8 +112,22 @@ class TrialRunner(object):
self._stop_queue = []
self._metadata_checkpoint_dir = metadata_checkpoint_dir
self._session = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
@classmethod
def checkpoint_exists(cls, directory):
if not os.path.exists(directory):
return False
return any(
(fname.startswith("experiment_state") and fname.endswith(".json"))
for fname in os.listdir(directory))
def checkpoint(self):
"""Saves execution state to `self._metadata_checkpoint_dir`."""
"""Saves execution state to `self._metadata_checkpoint_dir`.
Overwrites the current session checkpoint, which starts when self
is instantiated.
"""
if not self._metadata_checkpoint_dir:
return
metadata_checkpoint_dir = self._metadata_checkpoint_dir
@@ -121,7 +145,8 @@ class TrialRunner(object):
os.rename(
tmp_file_name,
os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME))
os.path.join(metadata_checkpoint_dir,
TrialRunner.CKPT_FILE_TMPL.format(self._session)))
return metadata_checkpoint_dir
@classmethod
@@ -146,9 +171,9 @@ class TrialRunner(object):
Returns:
runner (TrialRunner): A TrialRunner to resume experiments from.
"""
with open(
os.path.join(metadata_checkpoint_dir,
TrialRunner.CKPT_FILE_NAME), "r") as f:
newest_ckpt_path = _find_newest_ckpt(metadata_checkpoint_dir)
with open(newest_ckpt_path, "r") as f:
runner_state = json.load(f)
logger.warning("".join([
@@ -520,7 +545,7 @@ class TrialRunner(object):
state = self.__dict__.copy()
for k in [
"_trials", "_stop_queue", "_server", "_search_alg",
"_scheduler_alg", "trial_executor"
"_scheduler_alg", "trial_executor", "_session"
]:
del state[k]
state["launch_web_server"] = bool(self._server)
+1 -2
View File
@@ -116,8 +116,7 @@ def run_experiments(experiments,
runner = None
restore = False
if os.path.exists(
os.path.join(checkpoint_dir, TrialRunner.CKPT_FILE_NAME)):
if TrialRunner.checkpoint_exists(checkpoint_dir):
if resume == "prompt":
msg = ("Found incomplete experiment at {}. "
"Would you like to resume it?".format(checkpoint_dir))