mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 13:37:39 +08:00
[tune] Avoid overwriting checkpoint file (#3781)
This commit is contained in:
@@ -390,9 +390,7 @@ tune.run_experiments(
|
||||
# the checkpoint.
|
||||
metadata_checkpoint_dir = os.path.join(dirpath, "experiment")
|
||||
for i in range(100):
|
||||
if os.path.exists(
|
||||
os.path.join(metadata_checkpoint_dir,
|
||||
TrialRunner.CKPT_FILE_NAME)):
|
||||
if TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
|
||||
# Inspect the internal trialrunner
|
||||
runner = TrialRunner.restore(metadata_checkpoint_dir)
|
||||
trials = runner.get_trials()
|
||||
@@ -401,8 +399,7 @@ tune.run_experiments(
|
||||
break
|
||||
time.sleep(0.3)
|
||||
|
||||
if not os.path.exists(
|
||||
os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME)):
|
||||
if not TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
|
||||
raise RuntimeError("Checkpoint file didn't appear.")
|
||||
|
||||
ray.shutdown()
|
||||
@@ -485,9 +482,7 @@ tune.run_experiments(
|
||||
# the checkpoint.
|
||||
metadata_checkpoint_dir = os.path.join(dirpath, "experiment")
|
||||
for i in range(50):
|
||||
if os.path.exists(
|
||||
os.path.join(metadata_checkpoint_dir,
|
||||
TrialRunner.CKPT_FILE_NAME)):
|
||||
if TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
|
||||
# Inspect the internal trialrunner
|
||||
runner = TrialRunner.restore(metadata_checkpoint_dir)
|
||||
trials = runner.get_trials()
|
||||
@@ -496,8 +491,7 @@ tune.run_experiments(
|
||||
break
|
||||
time.sleep(0.2)
|
||||
|
||||
if not os.path.exists(
|
||||
os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME)):
|
||||
if not TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
|
||||
raise RuntimeError("Checkpoint file didn't appear.")
|
||||
|
||||
ray.shutdown()
|
||||
|
||||
@@ -1796,6 +1796,33 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
self.assertTrue("on_episode_start" in new_trial.config["callbacks"])
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def testCheckpointOverwrite(self):
|
||||
def count_checkpoints(cdir):
|
||||
return sum((fname.startswith("experiment_state")
|
||||
and fname.endswith(".json"))
|
||||
for fname in os.listdir(cdir))
|
||||
|
||||
ray.init()
|
||||
trial = Trial("__fake", checkpoint_freq=1)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
runner = TrialRunner(
|
||||
BasicVariantGenerator(), metadata_checkpoint_dir=tmpdir)
|
||||
runner.add_trial(trial)
|
||||
for i in range(5):
|
||||
runner.step()
|
||||
# force checkpoint
|
||||
runner.checkpoint()
|
||||
self.assertEquals(count_checkpoints(tmpdir), 1)
|
||||
|
||||
runner2 = TrialRunner.restore(tmpdir)
|
||||
for i in range(5):
|
||||
runner2.step()
|
||||
self.assertEquals(count_checkpoints(tmpdir), 2)
|
||||
|
||||
runner2.checkpoint()
|
||||
self.assertEquals(count_checkpoints(tmpdir), 2)
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
|
||||
class SearchAlgorithmTest(unittest.TestCase):
|
||||
def testNestedSuggestion(self):
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -28,6 +29,15 @@ def _naturalize(string):
|
||||
return [int(text) if text.isdigit() else text.lower() for text in splits]
|
||||
|
||||
|
||||
def _find_newest_ckpt(ckpt_dir):
|
||||
"""Returns path to most recently modified checkpoint."""
|
||||
full_paths = [
|
||||
os.path.join(ckpt_dir, fname) for fname in os.listdir(ckpt_dir)
|
||||
if fname.startswith("experiment_state") and fname.endswith(".json")
|
||||
]
|
||||
return max(full_paths)
|
||||
|
||||
|
||||
class TrialRunner(object):
|
||||
"""A TrialRunner implements the event loop for scheduling trials on Ray.
|
||||
|
||||
@@ -50,7 +60,7 @@ class TrialRunner(object):
|
||||
misleading benchmark results.
|
||||
"""
|
||||
|
||||
CKPT_FILE_NAME = "experiment_state.json"
|
||||
CKPT_FILE_TMPL = "experiment_state-{}.json"
|
||||
|
||||
def __init__(self,
|
||||
search_alg,
|
||||
@@ -102,8 +112,22 @@ class TrialRunner(object):
|
||||
self._stop_queue = []
|
||||
self._metadata_checkpoint_dir = metadata_checkpoint_dir
|
||||
|
||||
self._session = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
@classmethod
|
||||
def checkpoint_exists(cls, directory):
|
||||
if not os.path.exists(directory):
|
||||
return False
|
||||
return any(
|
||||
(fname.startswith("experiment_state") and fname.endswith(".json"))
|
||||
for fname in os.listdir(directory))
|
||||
|
||||
def checkpoint(self):
|
||||
"""Saves execution state to `self._metadata_checkpoint_dir`."""
|
||||
"""Saves execution state to `self._metadata_checkpoint_dir`.
|
||||
|
||||
Overwrites the current session checkpoint, which starts when self
|
||||
is instantiated.
|
||||
"""
|
||||
if not self._metadata_checkpoint_dir:
|
||||
return
|
||||
metadata_checkpoint_dir = self._metadata_checkpoint_dir
|
||||
@@ -121,7 +145,8 @@ class TrialRunner(object):
|
||||
|
||||
os.rename(
|
||||
tmp_file_name,
|
||||
os.path.join(metadata_checkpoint_dir, TrialRunner.CKPT_FILE_NAME))
|
||||
os.path.join(metadata_checkpoint_dir,
|
||||
TrialRunner.CKPT_FILE_TMPL.format(self._session)))
|
||||
return metadata_checkpoint_dir
|
||||
|
||||
@classmethod
|
||||
@@ -146,9 +171,9 @@ class TrialRunner(object):
|
||||
Returns:
|
||||
runner (TrialRunner): A TrialRunner to resume experiments from.
|
||||
"""
|
||||
with open(
|
||||
os.path.join(metadata_checkpoint_dir,
|
||||
TrialRunner.CKPT_FILE_NAME), "r") as f:
|
||||
|
||||
newest_ckpt_path = _find_newest_ckpt(metadata_checkpoint_dir)
|
||||
with open(newest_ckpt_path, "r") as f:
|
||||
runner_state = json.load(f)
|
||||
|
||||
logger.warning("".join([
|
||||
@@ -520,7 +545,7 @@ class TrialRunner(object):
|
||||
state = self.__dict__.copy()
|
||||
for k in [
|
||||
"_trials", "_stop_queue", "_server", "_search_alg",
|
||||
"_scheduler_alg", "trial_executor"
|
||||
"_scheduler_alg", "trial_executor", "_session"
|
||||
]:
|
||||
del state[k]
|
||||
state["launch_web_server"] = bool(self._server)
|
||||
|
||||
@@ -116,8 +116,7 @@ def run_experiments(experiments,
|
||||
runner = None
|
||||
restore = False
|
||||
|
||||
if os.path.exists(
|
||||
os.path.join(checkpoint_dir, TrialRunner.CKPT_FILE_NAME)):
|
||||
if TrialRunner.checkpoint_exists(checkpoint_dir):
|
||||
if resume == "prompt":
|
||||
msg = ("Found incomplete experiment at {}. "
|
||||
"Would you like to resume it?".format(checkpoint_dir))
|
||||
|
||||
Reference in New Issue
Block a user