mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:06:31 +08:00
[tune] support experiment checkpointing for grid search (#13357)
This commit is contained in:
@@ -1,16 +1,154 @@
|
||||
import copy
|
||||
import glob
|
||||
import itertools
|
||||
import os
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Union
|
||||
import warnings
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.experiment import Experiment, convert_to_experiment_list
|
||||
from ray.tune.config_parser import make_parser, create_trial_from_spec
|
||||
from ray.tune.suggest.variant_generator import (
|
||||
count_variants, generate_variants, format_vars, flatten_resolved_vars,
|
||||
get_preset_variants)
|
||||
count_variants, count_spec_samples, generate_variants, format_vars,
|
||||
flatten_resolved_vars, get_preset_variants)
|
||||
from ray.tune.suggest.search import SearchAlgorithm
|
||||
from ray.tune.utils.util import atomic_save, load_newest_checkpoint
|
||||
|
||||
SERIALIZATION_THRESHOLD = 1e6
|
||||
|
||||
|
||||
class _VariantIterator:
|
||||
"""Iterates over generated variants from the search space.
|
||||
|
||||
This object also toggles between lazy evaluation and
|
||||
eager evaluation of samples. If lazy evaluation is enabled,
|
||||
this object cannot be serialized.
|
||||
"""
|
||||
|
||||
def __init__(self, iterable, lazy_eval=False):
|
||||
self.lazy_eval = lazy_eval
|
||||
self.iterable = iterable
|
||||
self._has_next = True
|
||||
if lazy_eval:
|
||||
self._load_value()
|
||||
else:
|
||||
self.iterable = list(iterable)
|
||||
self._has_next = bool(self.iterable)
|
||||
|
||||
def _load_value(self):
|
||||
try:
|
||||
self.next_value = next(self.iterable)
|
||||
except StopIteration:
|
||||
self._has_next = False
|
||||
|
||||
def has_next(self):
|
||||
return self._has_next
|
||||
|
||||
def __next__(self):
|
||||
if self.lazy_eval:
|
||||
current_value = self.next_value
|
||||
self._load_value()
|
||||
return current_value
|
||||
current_value = self.iterable.pop(0)
|
||||
self._has_next = bool(self.iterable)
|
||||
return current_value
|
||||
|
||||
|
||||
class _TrialIterator:
|
||||
"""Generates trials from the spec.
|
||||
|
||||
Args:
|
||||
uuid_prefix (str): Used in creating the trial name.
|
||||
num_samples (int): Number of samples from distribution
|
||||
(same as tune.run).
|
||||
unresolved_spec (dict): Experiment specification
|
||||
that might have unresolved distributions.
|
||||
output_path (str): A specific output path within the local_dir.
|
||||
points_to_evaluate (list): Same as tune.run.
|
||||
lazy_eval (bool): Whether variants should be generated
|
||||
lazily or eagerly. This is toggled depending
|
||||
on the size of the grid search.
|
||||
start (int): index at which to start counting trials.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
uuid_prefix: str,
|
||||
num_samples: int,
|
||||
unresolved_spec: dict,
|
||||
output_path: str = "",
|
||||
points_to_evaluate: Optional[List] = None,
|
||||
lazy_eval: bool = False,
|
||||
start: int = 0):
|
||||
self.parser = make_parser()
|
||||
self.num_samples = num_samples
|
||||
self.uuid_prefix = uuid_prefix
|
||||
self.num_samples_left = num_samples
|
||||
self.unresolved_spec = unresolved_spec
|
||||
self.output_path = output_path
|
||||
self.points_to_evaluate = points_to_evaluate or []
|
||||
self.num_points_to_evaluate = len(self.points_to_evaluate)
|
||||
self.counter = start
|
||||
self.lazy_eval = lazy_eval
|
||||
self.variants = None
|
||||
|
||||
def create_trial(self, resolved_vars, spec):
|
||||
trial_id = self.uuid_prefix + ("%05d" % self.counter)
|
||||
experiment_tag = str(self.counter)
|
||||
# Always append resolved vars to experiment tag?
|
||||
if resolved_vars:
|
||||
experiment_tag += "_{}".format(format_vars(resolved_vars))
|
||||
self.counter += 1
|
||||
return create_trial_from_spec(
|
||||
spec,
|
||||
self.output_path,
|
||||
self.parser,
|
||||
evaluated_params=flatten_resolved_vars(resolved_vars),
|
||||
trial_id=trial_id,
|
||||
experiment_tag=experiment_tag)
|
||||
|
||||
def __next__(self):
|
||||
"""Generates Trial objects with the variant generation process.
|
||||
|
||||
Uses a fixed point iteration to resolve variants. All trials
|
||||
should be able to be generated at once.
|
||||
|
||||
See also: `ray.tune.suggest.variant_generator`.
|
||||
|
||||
Returns:
|
||||
Trial object
|
||||
"""
|
||||
|
||||
if "run" not in self.unresolved_spec:
|
||||
raise TuneError("Must specify `run` in {}".format(
|
||||
self.unresolved_spec))
|
||||
|
||||
if self.variants and self.variants.has_next():
|
||||
# This block will be skipped upon instantiation.
|
||||
# `variants` will be set later after the first loop.
|
||||
resolved_vars, spec = next(self.variants)
|
||||
return self.create_trial(resolved_vars, spec)
|
||||
|
||||
if self.points_to_evaluate:
|
||||
config = self.points_to_evaluate.pop(0)
|
||||
self.num_samples_left -= 1
|
||||
self.variants = _VariantIterator(
|
||||
get_preset_variants(self.unresolved_spec, config),
|
||||
lazy_eval=self.lazy_eval)
|
||||
resolved_vars, spec = next(self.variants)
|
||||
return self.create_trial(resolved_vars, spec)
|
||||
elif self.num_samples_left > 0:
|
||||
self.variants = _VariantIterator(
|
||||
generate_variants(self.unresolved_spec),
|
||||
lazy_eval=self.lazy_eval)
|
||||
self.num_samples_left -= 1
|
||||
resolved_vars, spec = next(self.variants)
|
||||
return self.create_trial(resolved_vars, spec)
|
||||
else:
|
||||
raise StopIteration
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
|
||||
class BasicVariantGenerator(SearchAlgorithm):
|
||||
@@ -88,15 +226,12 @@ class BasicVariantGenerator(SearchAlgorithm):
|
||||
both of these trials.
|
||||
|
||||
"""
|
||||
CKPT_FILE_TMPL = "basic-variant-state-{}.json"
|
||||
|
||||
def __init__(self, points_to_evaluate: Optional[List[Dict]] = None):
|
||||
"""Initializes the Variant Generator.
|
||||
|
||||
"""
|
||||
self._parser = make_parser()
|
||||
self._trial_generator = []
|
||||
self._iterators = []
|
||||
self._trial_iter = None
|
||||
self._counter = 0
|
||||
self._finished = False
|
||||
|
||||
self._points_to_evaluate = points_to_evaluate or []
|
||||
@@ -125,14 +260,31 @@ class BasicVariantGenerator(SearchAlgorithm):
|
||||
"""
|
||||
experiment_list = convert_to_experiment_list(experiments)
|
||||
for experiment in experiment_list:
|
||||
grid_vals = count_spec_samples(experiment.spec, num_samples=1)
|
||||
lazy_eval = grid_vals > SERIALIZATION_THRESHOLD
|
||||
if lazy_eval:
|
||||
warnings.warn(
|
||||
f"The number of pre-generated samples ({grid_vals}) "
|
||||
"exceeds the serialization threshold "
|
||||
f"({int(SERIALIZATION_THRESHOLD)}). Resume ability is "
|
||||
"disabled. To fix this, reduce the number of "
|
||||
"dimensions/size of the provided grid search.")
|
||||
|
||||
previous_samples = self._total_samples
|
||||
points_to_evaluate = copy.deepcopy(self._points_to_evaluate)
|
||||
self._total_samples += count_variants(experiment.spec,
|
||||
points_to_evaluate)
|
||||
self._trial_generator = itertools.chain(
|
||||
self._trial_generator,
|
||||
self._generate_trials(
|
||||
experiment.spec.get("num_samples", 1), experiment.spec,
|
||||
experiment.dir_name, points_to_evaluate))
|
||||
iterator = _TrialIterator(
|
||||
uuid_prefix=self._uuid_prefix,
|
||||
num_samples=experiment.spec.get("num_samples", 1),
|
||||
unresolved_spec=experiment.spec,
|
||||
output_path=experiment.dir_name,
|
||||
points_to_evaluate=points_to_evaluate,
|
||||
lazy_eval=lazy_eval,
|
||||
start=previous_samples)
|
||||
self._iterators.append(iterator)
|
||||
self._trial_generator = itertools.chain(self._trial_generator,
|
||||
iterator)
|
||||
|
||||
def next_trial(self):
|
||||
"""Provides one Trial object to be queued into the TrialRunner.
|
||||
@@ -150,57 +302,39 @@ class BasicVariantGenerator(SearchAlgorithm):
|
||||
self.set_finished()
|
||||
return None
|
||||
|
||||
def _generate_trials(self,
|
||||
num_samples,
|
||||
unresolved_spec,
|
||||
output_path="",
|
||||
points_to_evaluate=None):
|
||||
"""Generates Trial objects with the variant generation process.
|
||||
def get_state(self):
    """Return a snapshot of the instance attributes for checkpointing.

    Returns:
        False if any trial iterator uses lazy evaluation; otherwise a
        shallow copy of the instance ``__dict__`` without the
        ``_trial_generator`` entry.
    """
    for trial_iter in self._iterators:
        if trial_iter.lazy_eval:
            return False
    snapshot = dict(self.__dict__)
    # pop() without a default keeps the KeyError if the entry is missing.
    snapshot.pop("_trial_generator")
    return snapshot
|
||||
|
||||
Uses a fixed point iteration to resolve variants. All trials
|
||||
should be able to be generated at once.
|
||||
def set_state(self, state):
    """Apply a saved state and re-chain the trial iterators.

    Args:
        state (dict): Attribute values to merge into this instance.
    """
    self.__dict__.update(state)
    # Append every (restored) iterator after whatever the current
    # trial generator already holds.
    chained = self._trial_generator
    for trial_iter in self._iterators:
        chained = itertools.chain(chained, trial_iter)
    self._trial_generator = chained
|
||||
|
||||
See also: `ray.tune.suggest.variant_generator`.
|
||||
def save_to_dir(self, dirpath, session_str):
    """Write the generator state to a checkpoint file in ``dirpath``.

    Args:
        dirpath (str): Directory that receives the checkpoint file.
        session_str (str): Value substituted into ``CKPT_FILE_TMPL`` to
            form the checkpoint file name.

    Returns:
        False if any trial iterator uses lazy evaluation (nothing is
        written in that case); otherwise None.
    """
    for trial_iter in self._iterators:
        if trial_iter.lazy_eval:
            return False
    target_name = self.CKPT_FILE_TMPL.format(session_str)
    atomic_save(
        state=self.get_state(),
        checkpoint_dir=dirpath,
        file_name=target_name,
        tmp_file_name=".tmp_generator")
|
||||
|
||||
Yields:
|
||||
Trial object
|
||||
"""
|
||||
def has_checkpoint(self, dirpath: str):
    """Return True if any file in dirpath matches the checkpoint pattern."""
    pattern = os.path.join(dirpath, self.CKPT_FILE_TMPL.format("*"))
    matches = glob.glob(pattern)
    return len(matches) > 0
|
||||
|
||||
if "run" not in unresolved_spec:
|
||||
raise TuneError("Must specify `run` in {}".format(unresolved_spec))
|
||||
|
||||
points_to_evaluate = points_to_evaluate or []
|
||||
|
||||
while points_to_evaluate:
|
||||
config = points_to_evaluate.pop(0)
|
||||
for resolved_vars, spec in get_preset_variants(
|
||||
unresolved_spec, config):
|
||||
trial_id = self._uuid_prefix + ("%05d" % self._counter)
|
||||
experiment_tag = str(self._counter)
|
||||
self._counter += 1
|
||||
yield create_trial_from_spec(
|
||||
spec,
|
||||
output_path,
|
||||
self._parser,
|
||||
evaluated_params=flatten_resolved_vars(resolved_vars),
|
||||
trial_id=trial_id,
|
||||
experiment_tag=experiment_tag)
|
||||
num_samples -= 1
|
||||
|
||||
if num_samples <= 0:
|
||||
return
|
||||
|
||||
for _ in range(num_samples):
|
||||
for resolved_vars, spec in generate_variants(unresolved_spec):
|
||||
trial_id = self._uuid_prefix + ("%05d" % self._counter)
|
||||
experiment_tag = str(self._counter)
|
||||
if resolved_vars:
|
||||
experiment_tag += "_{}".format(format_vars(resolved_vars))
|
||||
self._counter += 1
|
||||
yield create_trial_from_spec(
|
||||
spec,
|
||||
output_path,
|
||||
self._parser,
|
||||
evaluated_params=flatten_resolved_vars(resolved_vars),
|
||||
trial_id=trial_id,
|
||||
experiment_tag=experiment_tag)
|
||||
def restore_from_dir(self, dirpath: str):
    """Load the newest saved generator state from dirpath and apply it.

    Args:
        dirpath (str): Directory searched for checkpoint files matching
            ``CKPT_FILE_TMPL``.

    Raises:
        RuntimeError: If no (non-empty) checkpoint state could be loaded.
    """
    pattern = self.CKPT_FILE_TMPL.format("*")
    state_dict = load_newest_checkpoint(dirpath, pattern)
    if state_dict:
        self.set_state(state_dict)
        return
    raise RuntimeError("Unable to find checkpoint in {}.".format(dirpath))
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import os
|
||||
import copy
|
||||
import logging
|
||||
import glob
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.experiment import Experiment, convert_to_experiment_list
|
||||
from ray.tune.config_parser import make_parser, create_trial_from_spec
|
||||
@@ -12,7 +9,8 @@ from ray.tune.suggest.search import SearchAlgorithm
|
||||
from ray.tune.suggest.suggestion import Searcher
|
||||
from ray.tune.suggest.variant_generator import format_vars, resolve_nested_dict
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.utils import flatten_dict, merge_dicts
|
||||
from ray.tune.utils.util import (flatten_dict, merge_dicts, atomic_save,
|
||||
load_newest_checkpoint)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -22,30 +20,6 @@ def _warn_on_repeater(searcher, total_samples):
|
||||
_warn_num_samples(searcher, total_samples)
|
||||
|
||||
|
||||
def _atomic_save(state: Dict, checkpoint_dir: str, file_name: str):
    """Atomically saves the object to the checkpoint directory

    This is automatically used by tune.run during a Tune job.

    The state is first pickled to a temporary file in ``checkpoint_dir``
    and then moved onto the final name, so readers never observe a
    partially written checkpoint.

    Args:
        state (dict): Object state to be serialized.
        checkpoint_dir (str): Directory location for the checkpoint.
        file_name (str): Final name of file.
    """
    tmp_search_ckpt_path = os.path.join(checkpoint_dir,
                                        ".tmp_search_generator_ckpt")
    with open(tmp_search_ckpt_path, "wb") as f:
        cloudpickle.dump(state, f)

    # os.replace overwrites an existing destination on every platform,
    # whereas os.rename raises on Windows when the target already exists.
    os.replace(tmp_search_ckpt_path, os.path.join(checkpoint_dir, file_name))
|
||||
|
||||
|
||||
def _find_newest_ckpt(dirpath: str, pattern: str):
|
||||
"""Returns path to most recently modified checkpoint."""
|
||||
full_paths = glob.glob(os.path.join(dirpath, pattern))
|
||||
if not full_paths:
|
||||
return
|
||||
most_recent_checkpoint = max(full_paths)
|
||||
with open(most_recent_checkpoint, "rb") as f:
|
||||
search_alg_state = cloudpickle.load(f)
|
||||
return search_alg_state
|
||||
|
||||
|
||||
class SearchGenerator(SearchAlgorithm):
|
||||
"""Generates trials to be passed to the TrialRunner.
|
||||
|
||||
@@ -174,7 +148,7 @@ class SearchGenerator(SearchAlgorithm):
|
||||
|
||||
def has_checkpoint(self, dirpath: str):
|
||||
return bool(
|
||||
_find_newest_ckpt(dirpath, self.CKPT_FILE_TMPL.format("*")))
|
||||
load_newest_checkpoint(dirpath, self.CKPT_FILE_TMPL.format("*")))
|
||||
|
||||
def save_to_dir(self, dirpath: str, session_str: str):
|
||||
"""Saves self + searcher to dir.
|
||||
@@ -205,15 +179,18 @@ class SearchGenerator(SearchAlgorithm):
|
||||
# We save the base searcher separately for users to easily
|
||||
# separate the searcher.
|
||||
base_searcher.save_to_dir(dirpath, session_str)
|
||||
_atomic_save(search_alg_state, dirpath,
|
||||
self.CKPT_FILE_TMPL.format(session_str))
|
||||
atomic_save(
|
||||
state=search_alg_state,
|
||||
checkpoint_dir=dirpath,
|
||||
file_name=self.CKPT_FILE_TMPL.format(session_str),
|
||||
tmp_file_name=".tmp_search_generator_ckpt")
|
||||
|
||||
def restore_from_dir(self, dirpath: str):
|
||||
"""Restores self + searcher + search wrappers from dirpath."""
|
||||
|
||||
searcher = self.searcher
|
||||
search_alg_state = _find_newest_ckpt(dirpath,
|
||||
self.CKPT_FILE_TMPL.format("*"))
|
||||
search_alg_state = load_newest_checkpoint(
|
||||
dirpath, self.CKPT_FILE_TMPL.format("*"))
|
||||
if not search_alg_state:
|
||||
raise RuntimeError(
|
||||
"Unable to find checkpoint in {}.".format(dirpath))
|
||||
|
||||
@@ -139,6 +139,15 @@ def parse_spec_vars(spec: Dict) -> Tuple[List[Tuple[Tuple, Any]], List[Tuple[
|
||||
return resolved_vars, domain_vars, grid_vars
|
||||
|
||||
|
||||
def count_spec_samples(spec: Dict, num_samples=1) -> int:
    """Count the samples a single spec produces.

    Args:
        spec (dict): Experiment spec to inspect.
        num_samples (int): Multiplier applied to the grid size.

    Returns:
        ``num_samples`` times the product of the number of categories of
        every grid-search variable found in the spec.
    """
    _resolved, _domains, grid_entries = parse_spec_vars(spec)
    grid_size = 1
    for _path, grid_domain in grid_entries:
        grid_size *= len(grid_domain.categories)
    return num_samples * grid_size
|
||||
|
||||
|
||||
def count_variants(spec: Dict, presets: Optional[List[Dict]] = None) -> int:
|
||||
# Helper function: Deep update dictionary
|
||||
def deep_update(d, u):
|
||||
@@ -149,14 +158,6 @@ def count_variants(spec: Dict, presets: Optional[List[Dict]] = None) -> int:
|
||||
d[k] = v
|
||||
return d
|
||||
|
||||
# Count samples for a specific spec
|
||||
def spec_samples(spec, num_samples=1):
|
||||
_, domain_vars, grid_vars = parse_spec_vars(spec)
|
||||
grid_count = 1
|
||||
for path, domain in grid_vars:
|
||||
grid_count *= len(domain.categories)
|
||||
return num_samples * grid_count
|
||||
|
||||
total_samples = 0
|
||||
total_num_samples = spec.get("num_samples", 1)
|
||||
# For each preset, overwrite the spec and count the samples generated
|
||||
@@ -164,12 +165,12 @@ def count_variants(spec: Dict, presets: Optional[List[Dict]] = None) -> int:
|
||||
for preset in presets:
|
||||
preset_spec = copy.deepcopy(spec)
|
||||
deep_update(preset_spec["config"], preset)
|
||||
total_samples += spec_samples(preset_spec, 1)
|
||||
total_samples += count_spec_samples(preset_spec, 1)
|
||||
total_num_samples -= 1
|
||||
|
||||
# Add the remaining samples
|
||||
if total_num_samples > 0:
|
||||
total_samples += spec_samples(spec, total_num_samples)
|
||||
total_samples += count_spec_samples(spec, total_num_samples)
|
||||
return total_samples
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# coding: utf-8
|
||||
from collections import Counter
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
@@ -13,6 +14,8 @@ import ray
|
||||
from ray import tune
|
||||
from ray.test_utils import recursive_fnmatch
|
||||
from ray.rllib import _register_all
|
||||
from ray.tune.callback import Callback
|
||||
from ray.tune.suggest.basic_variant import BasicVariantGenerator
|
||||
from ray.tune.suggest import ConcurrencyLimiter, Searcher
|
||||
from ray.tune.suggest.hyperopt import HyperOptSearch
|
||||
from ray.tune.suggest.dragonfly import DragonflySearch
|
||||
@@ -23,6 +26,7 @@ from ray.tune.suggest.optuna import OptunaSearch, param as ot_param
|
||||
from ray.tune.suggest.sigopt import SigOptSearch
|
||||
from ray.tune.suggest.zoopt import ZOOptSearch
|
||||
from ray.tune.utils import validate_save_restore
|
||||
from ray.tune.utils._mock_trainable import MyTrainableClass
|
||||
|
||||
|
||||
class TuneRestoreTest(unittest.TestCase):
|
||||
@@ -83,6 +87,211 @@ class TuneRestoreTest(unittest.TestCase):
|
||||
self.assertTrue(os.path.isfile(self.checkpoint_path))
|
||||
|
||||
|
||||
class TuneFailResumeGridTest(unittest.TestCase):
    """Resume behaviour of grid-search experiments after an injected failure."""

    class FailureInjectorCallback(Callback):
        """Raises RuntimeError once ``steps`` trial starts have been seen."""

        def __init__(self, steps=20):
            self._step = 0
            self.steps = steps

        def on_trial_start(self, trials, **info):
            self._step += 1
            if self._step < self.steps:
                return
            raise RuntimeError

    class CheckStateCallback(Callback):
        """Checks the number of trials on the first step of the run."""

        def __init__(self, expected_trials=20):
            self.expected_trials = expected_trials
            self._checked = False

        def on_step_begin(self, iteration, trials, **kwargs):
            if self._checked:
                return
            assert len(trials) == self.expected_trials
            self._checked = True

    @classmethod
    def setUpClass(cls):
        ray.init(local_mode=True, num_cpus=2)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def setUp(self):
        self.logdir = tempfile.mkdtemp()
        os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "0"
        from ray.tune import register_trainable
        register_trainable("trainable", MyTrainableClass)

    def tearDown(self):
        os.environ.pop("TUNE_GLOBAL_CHECKPOINT_S")
        shutil.rmtree(self.logdir)

    def _run_kwargs(self, num_samples, search_space):
        """Keyword arguments shared by the tune.run calls of these tests."""
        return {
            "num_samples": num_samples,
            "fail_fast": True,
            "config": search_space,
            "stop": {
                "training_iteration": 2
            },
            "local_dir": self.logdir,
            "verbose": 1,
        }

    @staticmethod
    def _small_grid():
        """3 x 3 grid over the `test` and `test2` keys."""
        return {
            "test": tune.grid_search([1, 2, 3]),
            "test2": tune.grid_search([1, 2, 3]),
        }

    @staticmethod
    def _preset_search_alg():
        """Variant generator with three preset points."""
        presets = [
            {
                "test": -1,
                "test2": -1
            },
            {
                "test": -1
            },
            {
                "test2": -1
            },
        ]
        return BasicVariantGenerator(points_to_evaluate=presets)

    @staticmethod
    def _check_preset_counts(analysis):
        """Assert the trial counts expected with three presets + 3 samples."""
        assert len(analysis.trials) == 34
        for key in ("test", "test2"):
            seen = Counter(t.config[key] for t in analysis.trials)
            assert seen.pop(-1) == 4
            assert all(v == 10 for v in seen.values())

    def testFailResumeGridSearch(self):
        run_kwargs = self._run_kwargs(3, self._small_grid())

        with self.assertRaises(RuntimeError):
            tune.run(
                "trainable",
                callbacks=[self.FailureInjectorCallback()],
                **run_kwargs)

        analysis = tune.run(
            "trainable",
            resume=True,
            callbacks=[self.CheckStateCallback()],
            **run_kwargs)
        assert len(analysis.trials) == 27
        for key in ("test", "test2"):
            seen = Counter(t.config[key] for t in analysis.trials)
            assert all(v == 9 for v in seen.values())

    def testFailResumeWithPreset(self):
        search_alg = self._preset_search_alg()
        # 3 preset, 3 samples
        run_kwargs = self._run_kwargs(3 + 3, self._small_grid())

        with self.assertRaises(RuntimeError):
            tune.run(
                "trainable",
                callbacks=[self.FailureInjectorCallback(5)],
                search_alg=search_alg,
                **run_kwargs)

        analysis = tune.run(
            "trainable",
            resume=True,
            callbacks=[self.CheckStateCallback(expected_trials=5)],
            search_alg=search_alg,
            **run_kwargs)
        self._check_preset_counts(analysis)

    def testFailResumeAfterPreset(self):
        search_alg = self._preset_search_alg()
        # 3 preset, 3 samples
        run_kwargs = self._run_kwargs(3 + 3, self._small_grid())

        with self.assertRaises(RuntimeError):
            tune.run(
                "trainable",
                callbacks=[self.FailureInjectorCallback(15)],
                search_alg=search_alg,
                **run_kwargs)

        analysis = tune.run(
            "trainable",
            resume=True,
            callbacks=[self.CheckStateCallback(expected_trials=15)],
            search_alg=search_alg,
            **run_kwargs)
        self._check_preset_counts(analysis)

    def testMultiExperimentFail(self):
        experiments = [
            tune.Experiment(
                run=MyTrainableClass,
                name="trainable",
                num_samples=2,
                config={
                    "test": tune.grid_search([1, 2, 3]),
                },
                stop={"training_iteration": 1},
                local_dir=self.logdir) for _ in range(3)
        ]

        with self.assertRaises(RuntimeError):
            tune.run(
                experiments,
                callbacks=[self.FailureInjectorCallback(10)],
                fail_fast=True)

        analysis = tune.run(
            experiments,
            resume=True,
            callbacks=[self.CheckStateCallback(expected_trials=10)],
            fail_fast=True)
        assert len(analysis.trials) == 18

    def testWarningLargeGrid(self):
        # Five 20-way grid dimensions.
        large_grid = {
            name: tune.grid_search(list(range(20)))
            for name in ("test", "test2", "test3", "test4", "test5")
        }
        run_kwargs = self._run_kwargs(3, large_grid)

        with self.assertWarnsRegex(UserWarning,
                                   "exceeds the serialization threshold"), \
                self.assertRaises(RuntimeError):
            tune.run(
                "trainable",
                callbacks=[self.FailureInjectorCallback(10)],
                **run_kwargs)
|
||||
|
||||
|
||||
class TuneExampleTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init(num_cpus=2)
|
||||
@@ -484,4 +693,4 @@ class SearcherTest(unittest.TestCase):
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
sys.exit(pytest.main(["-v", __file__] + sys.argv[1:]))
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from typing import Dict
|
||||
import copy
|
||||
import json
|
||||
import glob
|
||||
import logging
|
||||
import numbers
|
||||
import os
|
||||
@@ -412,6 +414,50 @@ def diagnose_serialization(trainable):
|
||||
return failure_set
|
||||
|
||||
|
||||
def atomic_save(state: Dict, checkpoint_dir: str, file_name: str,
                tmp_file_name: str):
    """Atomically saves the state object to the checkpoint directory.

    This is automatically used by tune.run during a Tune job.

    The state is first pickled to ``tmp_file_name`` inside
    ``checkpoint_dir`` and then moved onto ``file_name``, so readers never
    observe a partially written checkpoint.

    Args:
        state (dict): Object state to be serialized.
        checkpoint_dir (str): Directory location for the checkpoint.
        file_name (str): Final name of file.
        tmp_file_name (str): Temporary name of file.
    """
    import ray.cloudpickle as cloudpickle
    tmp_search_ckpt_path = os.path.join(checkpoint_dir, tmp_file_name)
    with open(tmp_search_ckpt_path, "wb") as f:
        cloudpickle.dump(state, f)

    # os.replace overwrites an existing destination on every platform,
    # whereas os.rename raises on Windows when the target already exists
    # (which happens whenever the same checkpoint name is saved again).
    os.replace(tmp_search_ckpt_path, os.path.join(checkpoint_dir, file_name))
|
||||
|
||||
|
||||
def load_newest_checkpoint(dirpath: str, ckpt_pattern: str) -> dict:
    """Load the checkpoint whose path sorts last among the matches.

    "Newest" is decided by comparing the matching path strings with
    ``max``, so this relies on checkpoint files being saved under names
    that sort in creation order (most likely by :obj:atomic_save).

    Args:
        dirpath (str): Directory in which to look for the checkpoint file.
        ckpt_pattern (str): File name pattern to match to find checkpoint
            files.

    Returns:
        (dict) Deserialized state dict, or None if no file matches.
    """
    import ray.cloudpickle as cloudpickle
    candidates = glob.glob(os.path.join(dirpath, ckpt_pattern))
    if candidates:
        newest = max(candidates)
        with open(newest, "rb") as ckpt_file:
            return cloudpickle.load(ckpt_file)
    return None
|
||||
|
||||
|
||||
def wait_for_gpu(gpu_id=None, gpu_memory_limit=0.1, retry=20):
|
||||
"""Checks if a given GPU has freed memory.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user