[tune] support experiment checkpointing for grid search (#13357)

This commit is contained in:
Richard Liaw
2021-01-18 19:24:36 -08:00
committed by GitHub
parent 1fbc3ddfac
commit 7a2997ea8c
5 changed files with 473 additions and 106 deletions
+196 -62
View File
@@ -1,16 +1,154 @@
import copy
import glob
import itertools
import os
import uuid
from typing import Dict, List, Optional, Union
import warnings
from ray.tune.error import TuneError
from ray.tune.experiment import Experiment, convert_to_experiment_list
from ray.tune.config_parser import make_parser, create_trial_from_spec
from ray.tune.suggest.variant_generator import (
count_variants, generate_variants, format_vars, flatten_resolved_vars,
get_preset_variants)
count_variants, count_spec_samples, generate_variants, format_vars,
flatten_resolved_vars, get_preset_variants)
from ray.tune.suggest.search import SearchAlgorithm
from ray.tune.utils.util import atomic_save, load_newest_checkpoint
SERIALIZATION_THRESHOLD = 1e6
class _VariantIterator:
"""Iterates over generated variants from the search space.
This object also toggles between lazy evaluation and
eager evaluation of samples. If lazy evaluation is enabled,
this object cannot be serialized.
"""
def __init__(self, iterable, lazy_eval=False):
self.lazy_eval = lazy_eval
self.iterable = iterable
self._has_next = True
if lazy_eval:
self._load_value()
else:
self.iterable = list(iterable)
self._has_next = bool(self.iterable)
def _load_value(self):
try:
self.next_value = next(self.iterable)
except StopIteration:
self._has_next = False
def has_next(self):
return self._has_next
def __next__(self):
if self.lazy_eval:
current_value = self.next_value
self._load_value()
return current_value
current_value = self.iterable.pop(0)
self._has_next = bool(self.iterable)
return current_value
class _TrialIterator:
"""Generates trials from the spec.
Args:
uuid_prefix (str): Used in creating the trial name.
num_samples (int): Number of samples from distribution
(same as tune.run).
unresolved_spec (dict): Experiment specification
that might have unresolved distributions.
output_path (str): A specific output path within the local_dir.
points_to_evaluate (list): Same as tune.run.
lazy_eval (bool): Whether variants should be generated
lazily or eagerly. This is toggled depending
on the size of the grid search.
start (int): index at which to start counting trials.
"""
def __init__(self,
uuid_prefix: str,
num_samples: int,
unresolved_spec: dict,
output_path: str = "",
points_to_evaluate: Optional[List] = None,
lazy_eval: bool = False,
start: int = 0):
self.parser = make_parser()
self.num_samples = num_samples
self.uuid_prefix = uuid_prefix
self.num_samples_left = num_samples
self.unresolved_spec = unresolved_spec
self.output_path = output_path
self.points_to_evaluate = points_to_evaluate or []
self.num_points_to_evaluate = len(self.points_to_evaluate)
self.counter = start
self.lazy_eval = lazy_eval
self.variants = None
def create_trial(self, resolved_vars, spec):
trial_id = self.uuid_prefix + ("%05d" % self.counter)
experiment_tag = str(self.counter)
# Always append resolved vars to experiment tag?
if resolved_vars:
experiment_tag += "_{}".format(format_vars(resolved_vars))
self.counter += 1
return create_trial_from_spec(
spec,
self.output_path,
self.parser,
evaluated_params=flatten_resolved_vars(resolved_vars),
trial_id=trial_id,
experiment_tag=experiment_tag)
def __next__(self):
"""Generates Trial objects with the variant generation process.
Uses a fixed point iteration to resolve variants. All trials
should be able to be generated at once.
See also: `ray.tune.suggest.variant_generator`.
Returns:
Trial object
"""
if "run" not in self.unresolved_spec:
raise TuneError("Must specify `run` in {}".format(
self.unresolved_spec))
if self.variants and self.variants.has_next():
# This block will be skipped upon instantiation.
# `variants` will be set later after the first loop.
resolved_vars, spec = next(self.variants)
return self.create_trial(resolved_vars, spec)
if self.points_to_evaluate:
config = self.points_to_evaluate.pop(0)
self.num_samples_left -= 1
self.variants = _VariantIterator(
get_preset_variants(self.unresolved_spec, config),
lazy_eval=self.lazy_eval)
resolved_vars, spec = next(self.variants)
return self.create_trial(resolved_vars, spec)
elif self.num_samples_left > 0:
self.variants = _VariantIterator(
generate_variants(self.unresolved_spec),
lazy_eval=self.lazy_eval)
self.num_samples_left -= 1
resolved_vars, spec = next(self.variants)
return self.create_trial(resolved_vars, spec)
else:
raise StopIteration
def __iter__(self):
return self
class BasicVariantGenerator(SearchAlgorithm):
@@ -88,15 +226,12 @@ class BasicVariantGenerator(SearchAlgorithm):
both of these trials.
"""
CKPT_FILE_TMPL = "basic-variant-state-{}.json"
def __init__(self, points_to_evaluate: Optional[List[Dict]] = None):
"""Initializes the Variant Generator.
"""
self._parser = make_parser()
self._trial_generator = []
self._iterators = []
self._trial_iter = None
self._counter = 0
self._finished = False
self._points_to_evaluate = points_to_evaluate or []
@@ -125,14 +260,31 @@ class BasicVariantGenerator(SearchAlgorithm):
"""
experiment_list = convert_to_experiment_list(experiments)
for experiment in experiment_list:
grid_vals = count_spec_samples(experiment.spec, num_samples=1)
lazy_eval = grid_vals > SERIALIZATION_THRESHOLD
if lazy_eval:
warnings.warn(
f"The number of pre-generated samples ({grid_vals}) "
"exceeds the serialization threshold "
f"({int(SERIALIZATION_THRESHOLD)}). Resume ability is "
"disabled. To fix this, reduce the number of "
"dimensions/size of the provided grid search.")
previous_samples = self._total_samples
points_to_evaluate = copy.deepcopy(self._points_to_evaluate)
self._total_samples += count_variants(experiment.spec,
points_to_evaluate)
self._trial_generator = itertools.chain(
self._trial_generator,
self._generate_trials(
experiment.spec.get("num_samples", 1), experiment.spec,
experiment.dir_name, points_to_evaluate))
iterator = _TrialIterator(
uuid_prefix=self._uuid_prefix,
num_samples=experiment.spec.get("num_samples", 1),
unresolved_spec=experiment.spec,
output_path=experiment.dir_name,
points_to_evaluate=points_to_evaluate,
lazy_eval=lazy_eval,
start=previous_samples)
self._iterators.append(iterator)
self._trial_generator = itertools.chain(self._trial_generator,
iterator)
def next_trial(self):
"""Provides one Trial object to be queued into the TrialRunner.
@@ -150,57 +302,39 @@ class BasicVariantGenerator(SearchAlgorithm):
self.set_finished()
return None
def _generate_trials(self,
num_samples,
unresolved_spec,
output_path="",
points_to_evaluate=None):
"""Generates Trial objects with the variant generation process.
def get_state(self):
    """Return a copy of the instance attributes for checkpointing.

    The ``_trial_generator`` entry is left out of the returned dict.

    Returns:
        dict of attributes, or ``False`` when any registered iterator
        uses lazy evaluation.
    """
    for iterator in self._iterators:
        if iterator.lazy_eval:
            return False
    snapshot = dict(self.__dict__)
    snapshot.pop("_trial_generator")
    return snapshot
Uses a fixed point iteration to resolve variants. All trials
should be able to be generated at once.
def set_state(self, state):
    """Load ``state`` into the instance and re-attach its iterators.

    After the attributes are updated, every iterator in ``_iterators`` is
    chained, in order, behind the current ``_trial_generator``.
    """
    self.__dict__.update(state)
    restored = self._iterators
    if restored:
        self._trial_generator = itertools.chain(self._trial_generator,
                                                *restored)
See also: `ray.tune.suggest.variant_generator`.
def save_to_dir(self, dirpath, session_str):
    """Write the generator state into ``dirpath`` via ``atomic_save``.

    The file name is ``CKPT_FILE_TMPL`` formatted with ``session_str``.

    Returns:
        ``False`` (without writing anything) when any registered iterator
        uses lazy evaluation; otherwise ``None``.
    """
    for iterator in self._iterators:
        if iterator.lazy_eval:
            return False
    snapshot = self.get_state()
    atomic_save(
        state=snapshot,
        checkpoint_dir=dirpath,
        file_name=self.CKPT_FILE_TMPL.format(session_str),
        tmp_file_name=".tmp_generator")
Yields:
Trial object
"""
def has_checkpoint(self, dirpath: str):
    """Return True if any file in ``dirpath`` matches the checkpoint name
    template (``CKPT_FILE_TMPL`` with a ``*`` wildcard)."""
    pattern = os.path.join(dirpath, self.CKPT_FILE_TMPL.format("*"))
    matches = glob.glob(pattern)
    return len(matches) > 0
if "run" not in unresolved_spec:
raise TuneError("Must specify `run` in {}".format(unresolved_spec))
points_to_evaluate = points_to_evaluate or []
while points_to_evaluate:
config = points_to_evaluate.pop(0)
for resolved_vars, spec in get_preset_variants(
unresolved_spec, config):
trial_id = self._uuid_prefix + ("%05d" % self._counter)
experiment_tag = str(self._counter)
self._counter += 1
yield create_trial_from_spec(
spec,
output_path,
self._parser,
evaluated_params=flatten_resolved_vars(resolved_vars),
trial_id=trial_id,
experiment_tag=experiment_tag)
num_samples -= 1
if num_samples <= 0:
return
for _ in range(num_samples):
for resolved_vars, spec in generate_variants(unresolved_spec):
trial_id = self._uuid_prefix + ("%05d" % self._counter)
experiment_tag = str(self._counter)
if resolved_vars:
experiment_tag += "_{}".format(format_vars(resolved_vars))
self._counter += 1
yield create_trial_from_spec(
spec,
output_path,
self._parser,
evaluated_params=flatten_resolved_vars(resolved_vars),
trial_id=trial_id,
experiment_tag=experiment_tag)
def restore_from_dir(self, dirpath: str):
    """Load the newest generator checkpoint in ``dirpath`` into self.

    The checkpoint is looked up with ``load_newest_checkpoint`` using the
    ``CKPT_FILE_TMPL`` wildcard pattern and applied through ``set_state``.

    Raises:
        RuntimeError: If no checkpoint state could be loaded.
    """
    pattern = self.CKPT_FILE_TMPL.format("*")
    state_dict = load_newest_checkpoint(dirpath, pattern)
    if state_dict:
        self.set_state(state_dict)
        return
    raise RuntimeError("Unable to find checkpoint in {}.".format(dirpath))
+10 -33
View File
@@ -1,10 +1,7 @@
import os
import copy
import logging
import glob
from typing import Dict, List, Optional, Union
import ray.cloudpickle as cloudpickle
from ray.tune.error import TuneError
from ray.tune.experiment import Experiment, convert_to_experiment_list
from ray.tune.config_parser import make_parser, create_trial_from_spec
@@ -12,7 +9,8 @@ from ray.tune.suggest.search import SearchAlgorithm
from ray.tune.suggest.suggestion import Searcher
from ray.tune.suggest.variant_generator import format_vars, resolve_nested_dict
from ray.tune.trial import Trial
from ray.tune.utils import flatten_dict, merge_dicts
from ray.tune.utils.util import (flatten_dict, merge_dicts, atomic_save,
load_newest_checkpoint)
logger = logging.getLogger(__name__)
@@ -22,30 +20,6 @@ def _warn_on_repeater(searcher, total_samples):
_warn_num_samples(searcher, total_samples)
def _atomic_save(state: Dict, checkpoint_dir: str, file_name: str):
    """Atomically saves the object to the checkpoint directory

    This is automatically used by tune.run during a Tune job.

    Args:
        state (dict): Object state to be pickled.
        checkpoint_dir (str): Directory the checkpoint is written to.
        file_name (str): Final name of the checkpoint file.
    """
    tmp_search_ckpt_path = os.path.join(checkpoint_dir,
                                        ".tmp_search_generator_ckpt")
    with open(tmp_search_ckpt_path, "wb") as f:
        cloudpickle.dump(state, f)

    # os.replace overwrites an existing destination on every platform;
    # os.rename raises on Windows once a checkpoint with this name exists.
    os.replace(tmp_search_ckpt_path, os.path.join(checkpoint_dir, file_name))
def _find_newest_ckpt(dirpath: str, pattern: str):
"""Returns path to most recently modified checkpoint."""
full_paths = glob.glob(os.path.join(dirpath, pattern))
if not full_paths:
return
most_recent_checkpoint = max(full_paths)
with open(most_recent_checkpoint, "rb") as f:
search_alg_state = cloudpickle.load(f)
return search_alg_state
class SearchGenerator(SearchAlgorithm):
"""Generates trials to be passed to the TrialRunner.
@@ -174,7 +148,7 @@ class SearchGenerator(SearchAlgorithm):
def has_checkpoint(self, dirpath: str):
return bool(
_find_newest_ckpt(dirpath, self.CKPT_FILE_TMPL.format("*")))
load_newest_checkpoint(dirpath, self.CKPT_FILE_TMPL.format("*")))
def save_to_dir(self, dirpath: str, session_str: str):
"""Saves self + searcher to dir.
@@ -205,15 +179,18 @@ class SearchGenerator(SearchAlgorithm):
# We save the base searcher separately for users to easily
# separate the searcher.
base_searcher.save_to_dir(dirpath, session_str)
_atomic_save(search_alg_state, dirpath,
self.CKPT_FILE_TMPL.format(session_str))
atomic_save(
state=search_alg_state,
checkpoint_dir=dirpath,
file_name=self.CKPT_FILE_TMPL.format(session_str),
tmp_file_name=".tmp_search_generator_ckpt")
def restore_from_dir(self, dirpath: str):
"""Restores self + searcher + search wrappers from dirpath."""
searcher = self.searcher
search_alg_state = _find_newest_ckpt(dirpath,
self.CKPT_FILE_TMPL.format("*"))
search_alg_state = load_newest_checkpoint(
dirpath, self.CKPT_FILE_TMPL.format("*"))
if not search_alg_state:
raise RuntimeError(
"Unable to find checkpoint in {}.".format(dirpath))
+11 -10
View File
@@ -139,6 +139,15 @@ def parse_spec_vars(spec: Dict) -> Tuple[List[Tuple[Tuple, Any]], List[Tuple[
return resolved_vars, domain_vars, grid_vars
def count_spec_samples(spec: Dict, num_samples=1) -> int:
    """Return how many samples ``spec`` expands to.

    This is ``num_samples`` times the product of the number of categories
    of every grid variable reported by ``parse_spec_vars``.
    """
    grid_vars = parse_spec_vars(spec)[2]
    combinations = 1
    for _path, grid_domain in grid_vars:
        combinations = combinations * len(grid_domain.categories)
    return num_samples * combinations
def count_variants(spec: Dict, presets: Optional[List[Dict]] = None) -> int:
# Helper function: Deep update dictionary
def deep_update(d, u):
@@ -149,14 +158,6 @@ def count_variants(spec: Dict, presets: Optional[List[Dict]] = None) -> int:
d[k] = v
return d
# Count samples for a specific spec
def spec_samples(spec, num_samples=1):
_, domain_vars, grid_vars = parse_spec_vars(spec)
grid_count = 1
for path, domain in grid_vars:
grid_count *= len(domain.categories)
return num_samples * grid_count
total_samples = 0
total_num_samples = spec.get("num_samples", 1)
# For each preset, overwrite the spec and count the samples generated
@@ -164,12 +165,12 @@ def count_variants(spec: Dict, presets: Optional[List[Dict]] = None) -> int:
for preset in presets:
preset_spec = copy.deepcopy(spec)
deep_update(preset_spec["config"], preset)
total_samples += spec_samples(preset_spec, 1)
total_samples += count_spec_samples(preset_spec, 1)
total_num_samples -= 1
# Add the remaining samples
if total_num_samples > 0:
total_samples += spec_samples(spec, total_num_samples)
total_samples += count_spec_samples(spec, total_num_samples)
return total_samples
+210 -1
View File
@@ -1,4 +1,5 @@
# coding: utf-8
from collections import Counter
import os
import shutil
import tempfile
@@ -13,6 +14,8 @@ import ray
from ray import tune
from ray.test_utils import recursive_fnmatch
from ray.rllib import _register_all
from ray.tune.callback import Callback
from ray.tune.suggest.basic_variant import BasicVariantGenerator
from ray.tune.suggest import ConcurrencyLimiter, Searcher
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.suggest.dragonfly import DragonflySearch
@@ -23,6 +26,7 @@ from ray.tune.suggest.optuna import OptunaSearch, param as ot_param
from ray.tune.suggest.sigopt import SigOptSearch
from ray.tune.suggest.zoopt import ZOOptSearch
from ray.tune.utils import validate_save_restore
from ray.tune.utils._mock_trainable import MyTrainableClass
class TuneRestoreTest(unittest.TestCase):
@@ -83,6 +87,211 @@ class TuneRestoreTest(unittest.TestCase):
self.assertTrue(os.path.isfile(self.checkpoint_path))
class TuneFailResumeGridTest(unittest.TestCase):
    """Fails grid-search runs part-way, resumes them, and checks the trials."""

    class FailureInjectorCallback(Callback):
        """Adds random failure injection to the TrialExecutor."""

        def __init__(self, steps=20):
            self._step = 0
            self.steps = steps

        def on_trial_start(self, trials, **info):
            # Count every trial start; raise from the `steps`-th one onwards.
            self._step += 1
            if self._step < self.steps:
                return
            raise RuntimeError

    class CheckStateCallback(Callback):
        """Checks state for the experiment initialization."""

        def __init__(self, expected_trials=20):
            self.expected_trials = expected_trials
            self._checked = False

        def on_step_begin(self, iteration, trials, **kwargs):
            # Only the first step of a run is inspected.
            if self._checked:
                return
            assert len(trials) == self.expected_trials
            self._checked = True

    @classmethod
    def setUpClass(cls):
        ray.init(local_mode=True, num_cpus=2)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def setUp(self):
        self.logdir = tempfile.mkdtemp()
        os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "0"
        from ray.tune import register_trainable
        register_trainable("trainable", MyTrainableClass)

    def tearDown(self):
        os.environ.pop("TUNE_GLOBAL_CHECKPOINT_S")
        shutil.rmtree(self.logdir)

    def _grid_config(self, num_samples):
        """Returns tune.run kwargs with a 3x3 grid over `test`/`test2`."""
        return {
            "num_samples": num_samples,
            "fail_fast": True,
            "config": {
                "test": tune.grid_search([1, 2, 3]),
                "test2": tune.grid_search([1, 2, 3]),
            },
            "stop": {
                "training_iteration": 2
            },
            "local_dir": self.logdir,
            "verbose": 1,
        }

    @staticmethod
    def _preset_search_alg():
        """Returns a BasicVariantGenerator with three preset points."""
        presets = [
            {
                "test": -1,
                "test2": -1
            },
            {
                "test": -1
            },
            {
                "test2": -1
            },
        ]
        return BasicVariantGenerator(points_to_evaluate=presets)

    def _fail_then_resume(self, run_kwargs, steps, **extra):
        """Runs until the injected failure, then resumes and returns the
        analysis of the resumed run.

        The resumed run asserts that it starts with `steps` trials.
        """
        with self.assertRaises(RuntimeError):
            tune.run(
                "trainable",
                callbacks=[self.FailureInjectorCallback(steps)],
                **extra,
                **run_kwargs)

        return tune.run(
            "trainable",
            resume=True,
            callbacks=[self.CheckStateCallback(expected_trials=steps)],
            **extra,
            **run_kwargs)

    @staticmethod
    def _check_preset_counts(analysis):
        """Asserts the trial totals expected for 3 presets + 3 samples."""
        assert len(analysis.trials) == 34
        for key in ("test", "test2"):
            counts = Counter(t.config[key] for t in analysis.trials)
            assert counts.pop(-1) == 4
            assert all(v == 10 for v in counts.values())

    def testFailResumeGridSearch(self):
        analysis = self._fail_then_resume(self._grid_config(3), 20)

        assert len(analysis.trials) == 27
        for key in ("test", "test2"):
            counts = Counter(t.config[key] for t in analysis.trials)
            assert all(v == 9 for v in counts.values())

    def testFailResumeWithPreset(self):
        run_kwargs = self._grid_config(3 + 3)  # 3 preset, 3 samples
        analysis = self._fail_then_resume(
            run_kwargs, 5, search_alg=self._preset_search_alg())
        self._check_preset_counts(analysis)

    def testFailResumeAfterPreset(self):
        run_kwargs = self._grid_config(3 + 3)  # 3 preset, 3 samples
        analysis = self._fail_then_resume(
            run_kwargs, 15, search_alg=self._preset_search_alg())
        self._check_preset_counts(analysis)

    def testMultiExperimentFail(self):
        experiments = [
            tune.Experiment(
                run=MyTrainableClass,
                name="trainable",
                num_samples=2,
                config={
                    "test": tune.grid_search([1, 2, 3]),
                },
                stop={"training_iteration": 1},
                local_dir=self.logdir) for _ in range(3)
        ]

        with self.assertRaises(RuntimeError):
            tune.run(
                experiments,
                callbacks=[self.FailureInjectorCallback(10)],
                fail_fast=True)

        analysis = tune.run(
            experiments,
            resume=True,
            callbacks=[self.CheckStateCallback(expected_trials=10)],
            fail_fast=True)
        assert len(analysis.trials) == 18

    def testWarningLargeGrid(self):
        # Five grid dimensions with 20 values each.
        grid_keys = ["test", "test2", "test3", "test4", "test5"]
        run_kwargs = dict(
            num_samples=3,
            fail_fast=True,
            config={
                key: tune.grid_search(list(range(20)))
                for key in grid_keys
            },
            stop={"training_iteration": 2},
            local_dir=self.logdir,
            verbose=1)
        with self.assertWarnsRegex(
                UserWarning,
                "exceeds the serialization threshold"), self.assertRaises(
                    RuntimeError):
            tune.run(
                "trainable",
                callbacks=[self.FailureInjectorCallback(10)],
                **run_kwargs)
class TuneExampleTest(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=2)
@@ -484,4 +693,4 @@ class SearcherTest(unittest.TestCase):
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))
sys.exit(pytest.main(["-v", __file__] + sys.argv[1:]))
+46
View File
@@ -1,5 +1,7 @@
from typing import Dict
import copy
import json
import glob
import logging
import numbers
import os
@@ -412,6 +414,50 @@ def diagnose_serialization(trainable):
return failure_set
def atomic_save(state: Dict, checkpoint_dir: str, file_name: str,
                tmp_file_name: str):
    """Atomically saves the state object to the checkpoint directory.

    This is automatically used by tune.run during a Tune job.

    The state is pickled to ``tmp_file_name`` first and then moved onto
    ``file_name``, so a partially written file never carries the final name.

    Args:
        state (dict): Object state to be serialized.
        checkpoint_dir (str): Directory location for the checkpoint.
        file_name (str): Final name of file.
        tmp_file_name (str): Temporary name of file.
    """
    import ray.cloudpickle as cloudpickle
    tmp_search_ckpt_path = os.path.join(checkpoint_dir, tmp_file_name)
    with open(tmp_search_ckpt_path, "wb") as f:
        cloudpickle.dump(state, f)

    # os.replace overwrites an existing destination on every platform;
    # os.rename raises on Windows once a checkpoint with this name exists.
    os.replace(tmp_search_ckpt_path, os.path.join(checkpoint_dir, file_name))
def load_newest_checkpoint(dirpath: str, ckpt_pattern: str) -> dict:
    """Loads the checkpoint whose path sorts last among the matches.

    "Newest" is decided by comparing path strings, not modification
    times, so this assumes files are saved with an ordered name, most
    likely by :obj:atomic_save.

    Args:
        dirpath (str): Directory in which to look for the checkpoint file.
        ckpt_pattern (str): File name pattern to match to find checkpoint
            files.

    Returns:
        (dict) Deserialized state dict, or None if no file matches.
    """
    import ray.cloudpickle as cloudpickle
    candidates = glob.glob(os.path.join(dirpath, ckpt_pattern))
    if candidates:
        newest_path = max(candidates)
        with open(newest_path, "rb") as ckpt_file:
            return cloudpickle.load(ckpt_file)
    return None
def wait_for_gpu(gpu_id=None, gpu_memory_limit=0.1, retry=20):
"""Checks if a given GPU has freed memory.