class _VariantIterator:
    """Iterates over generated variants from the search space.

    This object also toggles between lazy evaluation and eager
    evaluation of samples. With ``lazy_eval=True`` the wrapped iterable
    is consumed one element ahead of the caller; otherwise all variants
    are materialized up front. If lazy evaluation is enabled, this
    object cannot be serialized.

    Args:
        iterable: Source of variants. Consumed lazily or eagerly
            depending on ``lazy_eval``.
        lazy_eval (bool): Whether to pull variants on demand instead of
            materializing them all at construction time.
    """

    def __init__(self, iterable, lazy_eval=False):
        # Function-scope import keeps this block self-contained.
        from collections import deque

        self.lazy_eval = lazy_eval
        self._has_next = True
        if lazy_eval:
            # ``iter`` is a no-op on generators and lets plain sequences
            # be passed in lazy mode as well.
            self.iterable = iter(iterable)
            self._load_value()
        else:
            # A deque pops from the left in O(1); ``list.pop(0)`` shifts
            # every remaining element on each call.
            self.iterable = deque(iterable)
            self._has_next = bool(self.iterable)

    def _load_value(self):
        """Prefetch the next variant, or record that none is left."""
        try:
            self.next_value = next(self.iterable)
        except StopIteration:
            self._has_next = False

    def has_next(self):
        """Return True while at least one more variant can be produced."""
        return self._has_next

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next variant.

        Raises:
            StopIteration: When all variants have been consumed.
        """
        # Without this guard an exhausted lazy iterator hands back its
        # last prefetched value again, and an exhausted eager one raises
        # IndexError instead of following the iterator protocol.
        if not self._has_next:
            raise StopIteration
        if self.lazy_eval:
            current_value = self.next_value
            self._load_value()
            return current_value
        current_value = self.iterable.popleft()
        self._has_next = bool(self.iterable)
        return current_value
class _TrialIterator:
    """Generates trials from the spec.

    Args:
        uuid_prefix (str): Used in creating the trial name.
        num_samples (int): Number of samples from distribution
            (same as tune.run).
        unresolved_spec (dict): Experiment specification
            that might have unresolved distributions.
        output_path (str): A specific output path within the local_dir.
        points_to_evaluate (list): Same as tune.run. The list is copied,
            so the caller's list is never modified.
        lazy_eval (bool): Whether variants should be generated
            lazily or eagerly. This is toggled depending
            on the size of the grid search.
        start (int): index at which to start counting trials.
    """

    def __init__(self,
                 uuid_prefix: str,
                 num_samples: int,
                 unresolved_spec: dict,
                 output_path: str = "",
                 points_to_evaluate: Optional[List] = None,
                 lazy_eval: bool = False,
                 start: int = 0):
        self.parser = make_parser()
        self.num_samples = num_samples
        self.uuid_prefix = uuid_prefix
        self.num_samples_left = num_samples
        self.unresolved_spec = unresolved_spec
        self.output_path = output_path
        # Copy: ``__next__`` pops from this list, which must not drain a
        # list still owned by the caller.
        self.points_to_evaluate = list(points_to_evaluate or [])
        self.num_points_to_evaluate = len(self.points_to_evaluate)
        self.counter = start
        self.lazy_eval = lazy_eval
        self.variants = None

    def create_trial(self, resolved_vars, spec):
        """Build one trial from a resolved spec and advance the counter.

        The trial id is ``uuid_prefix`` followed by the zero-padded
        five-digit counter. The experiment tag is the counter, plus the
        formatted resolved variables when there are any.
        """
        trial_id = self.uuid_prefix + ("%05d" % self.counter)
        experiment_tag = str(self.counter)
        # Always append resolved vars to experiment tag?
        if resolved_vars:
            experiment_tag += "_{}".format(format_vars(resolved_vars))
        self.counter += 1
        return create_trial_from_spec(
            spec,
            self.output_path,
            self.parser,
            evaluated_params=flatten_resolved_vars(resolved_vars),
            trial_id=trial_id,
            experiment_tag=experiment_tag)

    def _start_sample(self, variant_source):
        """Begin a new sample from ``variant_source``; return its first trial."""
        self.variants = _VariantIterator(
            variant_source, lazy_eval=self.lazy_eval)
        self.num_samples_left -= 1
        resolved_vars, spec = next(self.variants)
        return self.create_trial(resolved_vars, spec)

    def __next__(self):
        """Generates Trial objects with the variant generation process.

        Variants of the current sample are drained first. After that,
        each remaining preset point and then each remaining sample
        starts a fresh variant iterator.

        See also: `ray.tune.suggest.variant_generator`.

        Returns:
            Trial object

        Raises:
            TuneError: If the spec has no `run` entry.
            StopIteration: When presets and samples are all used up.
        """
        if "run" not in self.unresolved_spec:
            raise TuneError("Must specify `run` in {}".format(
                self.unresolved_spec))

        if self.variants and self.variants.has_next():
            # This block will be skipped upon instantiation.
            # `variants` will be set later after the first loop.
            resolved_vars, spec = next(self.variants)
            return self.create_trial(resolved_vars, spec)

        if self.points_to_evaluate:
            config = self.points_to_evaluate.pop(0)
            return self._start_sample(
                get_preset_variants(self.unresolved_spec, config))

        if self.num_samples_left > 0:
            return self._start_sample(
                generate_variants(self.unresolved_spec))

        raise StopIteration

    def __iter__(self):
        return self
def get_state(self):
    """Return the attributes needed to resume trial generation.

    Method of ``BasicVariantGenerator``.

    Returns:
        ``False`` if any trial iterator uses lazy evaluation (such
        iterators cannot be serialized); otherwise a shallow copy of the
        instance ``__dict__`` without the derived chain objects.
    """
    if any(iterator.lazy_eval for iterator in self._iterators):
        return False
    state = self.__dict__.copy()
    # Both entries are itertools.chain views over ``_iterators``.
    # ``set_state`` rebuilds them, so the state does not rely on pickling
    # itertools objects.
    state.pop("_trial_generator", None)
    state.pop("_trial_iter", None)
    return state


def set_state(self, state):
    """Restore attributes produced by ``get_state``.

    Method of ``BasicVariantGenerator``.

    Args:
        state (dict): Mapping of attribute names to values.
    """
    self.__dict__.update(state)
    # Rebuild the chain from the restored iterators only. Appending to a
    # previously held generator would replay iterators that belong to
    # the pre-restore state.
    self._trial_generator = itertools.chain(*self._iterators)
    # Drop any iterator handle (including one carried by an older state
    # dict) so the next consumer starts from the rebuilt chain.
    self._trial_iter = None
def save_to_dir(self, dirpath, session_str):
    """Write the generator state to a checkpoint file in ``dirpath``.

    Method of ``BasicVariantGenerator``.

    Args:
        dirpath (str): Directory that receives the checkpoint file.
        session_str (str): Value substituted into ``CKPT_FILE_TMPL`` to
            form the checkpoint file name.

    Returns:
        ``False`` when a lazily evaluated iterator prevents saving;
        otherwise ``None``.
    """
    for iterator in self._iterators:
        if iterator.lazy_eval:
            return False
    snapshot = self.get_state()
    checkpoint_name = self.CKPT_FILE_TMPL.format(session_str)
    atomic_save(
        state=snapshot,
        checkpoint_dir=dirpath,
        file_name=checkpoint_name,
        tmp_file_name=".tmp_generator")


def has_checkpoint(self, dirpath: str):
    """Whether a checkpoint file exists within dirpath."""
    pattern = os.path.join(dirpath, self.CKPT_FILE_TMPL.format("*"))
    matches = glob.glob(pattern)
    return len(matches) > 0


def restore_from_dir(self, dirpath: str):
    """Restore the generator state from a checkpoint in ``dirpath``.

    Raises:
        RuntimeError: If no checkpoint state could be loaded.
    """
    pattern = self.CKPT_FILE_TMPL.format("*")
    restored = load_newest_checkpoint(dirpath, pattern)
    if restored:
        self.set_state(restored)
        return
    raise RuntimeError("Unable to find checkpoint in {}.".format(dirpath))
a/python/ray/tune/suggest/search_generator.py b/python/ray/tune/suggest/search_generator.py index b7a066269..396da90c6 100644 --- a/python/ray/tune/suggest/search_generator.py +++ b/python/ray/tune/suggest/search_generator.py @@ -1,10 +1,7 @@ -import os import copy import logging -import glob from typing import Dict, List, Optional, Union -import ray.cloudpickle as cloudpickle from ray.tune.error import TuneError from ray.tune.experiment import Experiment, convert_to_experiment_list from ray.tune.config_parser import make_parser, create_trial_from_spec @@ -12,7 +9,8 @@ from ray.tune.suggest.search import SearchAlgorithm from ray.tune.suggest.suggestion import Searcher from ray.tune.suggest.variant_generator import format_vars, resolve_nested_dict from ray.tune.trial import Trial -from ray.tune.utils import flatten_dict, merge_dicts +from ray.tune.utils.util import (flatten_dict, merge_dicts, atomic_save, + load_newest_checkpoint) logger = logging.getLogger(__name__) @@ -22,30 +20,6 @@ def _warn_on_repeater(searcher, total_samples): _warn_num_samples(searcher, total_samples) -def _atomic_save(state: Dict, checkpoint_dir: str, file_name: str): - """Atomically saves the object to the checkpoint directory - - This is automatically used by tune.run during a Tune job. - """ - tmp_search_ckpt_path = os.path.join(checkpoint_dir, - ".tmp_search_generator_ckpt") - with open(tmp_search_ckpt_path, "wb") as f: - cloudpickle.dump(state, f) - - os.rename(tmp_search_ckpt_path, os.path.join(checkpoint_dir, file_name)) - - -def _find_newest_ckpt(dirpath: str, pattern: str): - """Returns path to most recently modified checkpoint.""" - full_paths = glob.glob(os.path.join(dirpath, pattern)) - if not full_paths: - return - most_recent_checkpoint = max(full_paths) - with open(most_recent_checkpoint, "rb") as f: - search_alg_state = cloudpickle.load(f) - return search_alg_state - - class SearchGenerator(SearchAlgorithm): """Generates trials to be passed to the TrialRunner. 
def count_spec_samples(spec: Dict, num_samples=1) -> int:
    """Count the samples that a single spec expands to.

    The result is ``num_samples`` multiplied by the number of categories
    of every grid variable reported by ``parse_spec_vars``.

    Args:
        spec (dict): Experiment spec to inspect.
        num_samples (int): Multiplier applied to the grid size.

    Returns:
        Number of samples for this spec.
    """
    _resolved, _domains, grid_entries = parse_spec_vars(spec)
    grid_size = 1
    for _path, grid_domain in grid_entries:
        grid_size = grid_size * len(grid_domain.categories)
    return num_samples * grid_size
+158,6 @@ def count_variants(spec: Dict, presets: Optional[List[Dict]] = None) -> int: d[k] = v return d - # Count samples for a specific spec - def spec_samples(spec, num_samples=1): - _, domain_vars, grid_vars = parse_spec_vars(spec) - grid_count = 1 - for path, domain in grid_vars: - grid_count *= len(domain.categories) - return num_samples * grid_count - total_samples = 0 total_num_samples = spec.get("num_samples", 1) # For each preset, overwrite the spec and count the samples generated @@ -164,12 +165,12 @@ def count_variants(spec: Dict, presets: Optional[List[Dict]] = None) -> int: for preset in presets: preset_spec = copy.deepcopy(spec) deep_update(preset_spec["config"], preset) - total_samples += spec_samples(preset_spec, 1) + total_samples += count_spec_samples(preset_spec, 1) total_num_samples -= 1 # Add the remaining samples if total_num_samples > 0: - total_samples += spec_samples(spec, total_num_samples) + total_samples += count_spec_samples(spec, total_num_samples) return total_samples diff --git a/python/ray/tune/tests/test_tune_restore.py b/python/ray/tune/tests/test_tune_restore.py index fc61c81ff..baabd2b03 100644 --- a/python/ray/tune/tests/test_tune_restore.py +++ b/python/ray/tune/tests/test_tune_restore.py @@ -1,4 +1,5 @@ # coding: utf-8 +from collections import Counter import os import shutil import tempfile @@ -13,6 +14,8 @@ import ray from ray import tune from ray.test_utils import recursive_fnmatch from ray.rllib import _register_all +from ray.tune.callback import Callback +from ray.tune.suggest.basic_variant import BasicVariantGenerator from ray.tune.suggest import ConcurrencyLimiter, Searcher from ray.tune.suggest.hyperopt import HyperOptSearch from ray.tune.suggest.dragonfly import DragonflySearch @@ -23,6 +26,7 @@ from ray.tune.suggest.optuna import OptunaSearch, param as ot_param from ray.tune.suggest.sigopt import SigOptSearch from ray.tune.suggest.zoopt import ZOOptSearch from ray.tune.utils import validate_save_restore +from 
class TuneFailResumeGridTest(unittest.TestCase):
    """Resumes grid-search experiments after an injected failure.

    Each test runs an experiment that is aborted by
    ``FailureInjectorCallback``, resumes it, and checks the number of
    trials seen on resume and in the final analysis.
    """

    class FailureInjectorCallback(Callback):
        """Raises ``RuntimeError`` once ``steps`` trials have been started."""

        def __init__(self, steps=20):
            self._step = 0
            self.steps = steps

        def on_trial_start(self, trials, **info):
            self._step += 1
            if self._step >= self.steps:
                raise RuntimeError(
                    "Injected failure after {} trial starts.".format(
                        self._step))

    class CheckStateCallback(Callback):
        """Checks state for the experiment initialization."""

        def __init__(self, expected_trials=20):
            self.expected_trials = expected_trials
            self._checked = False

        def on_step_begin(self, iteration, trials, **kwargs):
            # Only the first step reflects the restored state; later
            # steps see newly generated trials as well.
            if not self._checked:
                assert len(trials) == self.expected_trials
                self._checked = True

    @classmethod
    def setUpClass(cls):
        ray.init(local_mode=True, num_cpus=2)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def setUp(self):
        self.logdir = tempfile.mkdtemp()
        # Remember any pre-existing value so tearDown can put it back
        # instead of silently dropping it.
        self._prev_checkpoint_s = os.environ.get("TUNE_GLOBAL_CHECKPOINT_S")
        os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "0"
        from ray.tune import register_trainable
        register_trainable("trainable", MyTrainableClass)

    def tearDown(self):
        if self._prev_checkpoint_s is None:
            os.environ.pop("TUNE_GLOBAL_CHECKPOINT_S", None)
        else:
            os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = self._prev_checkpoint_s
        shutil.rmtree(self.logdir)

    def _grid_config(self, num_samples):
        """Return tune.run kwargs for a 3x3 grid over `test` and `test2`."""
        return dict(
            num_samples=num_samples,
            fail_fast=True,
            config={
                "test": tune.grid_search([1, 2, 3]),
                "test2": tune.grid_search([1, 2, 3]),
            },
            stop={"training_iteration": 2},
            local_dir=self.logdir,
            verbose=1)

    @staticmethod
    def _preset_search_alg():
        """Return a generator with three preset points to evaluate."""
        return BasicVariantGenerator(points_to_evaluate=[{
            "test": -1,
            "test2": -1
        }, {
            "test": -1
        }, {
            "test2": -1
        }])

    def _run_preset_fail_resume(self, fail_after):
        """Fail after ``fail_after`` trial starts, resume, check counts."""
        search_alg = self._preset_search_alg()
        config = self._grid_config(num_samples=3 + 3)  # 3 preset, 3 samples

        with self.assertRaises(RuntimeError):
            tune.run(
                "trainable",
                callbacks=[self.FailureInjectorCallback(fail_after)],
                search_alg=search_alg,
                **config)

        analysis = tune.run(
            "trainable",
            resume=True,
            callbacks=[self.CheckStateCallback(expected_trials=fail_after)],
            search_alg=search_alg,
            **config)
        assert len(analysis.trials) == 34
        test_counter = Counter([t.config["test"] for t in analysis.trials])
        assert test_counter.pop(-1) == 4
        assert all(v == 10 for v in test_counter.values())
        test2_counter = Counter([t.config["test2"] for t in analysis.trials])
        assert test2_counter.pop(-1) == 4
        assert all(v == 10 for v in test2_counter.values())

    def testFailResumeGridSearch(self):
        config = self._grid_config(num_samples=3)

        with self.assertRaises(RuntimeError):
            tune.run(
                "trainable",
                callbacks=[self.FailureInjectorCallback()],
                **config)

        analysis = tune.run(
            "trainable",
            resume=True,
            callbacks=[self.CheckStateCallback()],
            **config)
        assert len(analysis.trials) == 27
        test_counter = Counter([t.config["test"] for t in analysis.trials])
        assert all(v == 9 for v in test_counter.values())
        test2_counter = Counter([t.config["test2"] for t in analysis.trials])
        assert all(v == 9 for v in test2_counter.values())

    def testFailResumeWithPreset(self):
        # Failure is injected on the 5th trial start.
        self._run_preset_fail_resume(5)

    def testFailResumeAfterPreset(self):
        # Failure is injected on the 15th trial start.
        self._run_preset_fail_resume(15)

    def testMultiExperimentFail(self):
        experiments = []
        for i in range(3):
            experiments.append(
                tune.Experiment(
                    run=MyTrainableClass,
                    name="trainable",
                    num_samples=2,
                    config={
                        "test": tune.grid_search([1, 2, 3]),
                    },
                    stop={"training_iteration": 1},
                    local_dir=self.logdir))

        with self.assertRaises(RuntimeError):
            tune.run(
                experiments,
                callbacks=[self.FailureInjectorCallback(10)],
                fail_fast=True)

        analysis = tune.run(
            experiments,
            resume=True,
            callbacks=[self.CheckStateCallback(expected_trials=10)],
            fail_fast=True)
        assert len(analysis.trials) == 18

    def testWarningLargeGrid(self):
        # Five grid dimensions of 20 values each.
        config = dict(
            num_samples=3,
            fail_fast=True,
            config={
                "test": tune.grid_search(list(range(20))),
                "test2": tune.grid_search(list(range(20))),
                "test3": tune.grid_search(list(range(20))),
                "test4": tune.grid_search(list(range(20))),
                "test5": tune.grid_search(list(range(20))),
            },
            stop={"training_iteration": 2},
            local_dir=self.logdir,
            verbose=1)
        with self.assertWarnsRegex(UserWarning,
                                   "exceeds the serialization threshold"):
            with self.assertRaises(RuntimeError):
                tune.run(
                    "trainable",
                    callbacks=[self.FailureInjectorCallback(10)],
                    **config)
def atomic_save(state: Dict, checkpoint_dir: str, file_name: str,
                tmp_file_name: str):
    """Atomically saves the state object to the checkpoint directory.

    The state is first written to ``tmp_file_name`` and then moved over
    ``file_name``, so a reader never observes a partially written
    checkpoint.

    Args:
        state (dict): Object state to be serialized.
        checkpoint_dir (str): Directory location for the checkpoint.
        file_name (str): Final name of file.
        tmp_file_name (str): Temporary name of file.
    """
    import ray.cloudpickle as cloudpickle
    tmp_search_ckpt_path = os.path.join(checkpoint_dir, tmp_file_name)
    with open(tmp_search_ckpt_path, "wb") as f:
        cloudpickle.dump(state, f)

    # os.replace overwrites an existing destination on every platform;
    # os.rename raises on Windows when the target already exists.
    os.replace(tmp_search_ckpt_path, os.path.join(checkpoint_dir, file_name))


def load_newest_checkpoint(dirpath: str, ckpt_pattern: str) -> dict:
    """Returns the checkpoint whose path sorts last among the matches.

    "Newest" is decided by comparing path strings, not modification
    times, so files must be saved with an ordered name, most likely by
    :obj:atomic_save.

    Args:
        dirpath (str): Directory in which to look for the checkpoint file.
        ckpt_pattern (str): File name pattern to match to find checkpoint
            files.

    Returns:
        (dict) Deserialized state dict, or ``None`` when no file matches
        the pattern.
    """
    full_paths = glob.glob(os.path.join(dirpath, ckpt_pattern))
    if not full_paths:
        return None
    # Imported only once there is something to deserialize.
    import ray.cloudpickle as cloudpickle
    most_recent_checkpoint = max(full_paths)
    # NOTE: unpickling executes code from the file; only point this at
    # checkpoint directories you trust.
    with open(most_recent_checkpoint, "rb") as f:
        checkpoint_state = cloudpickle.load(f)
    return checkpoint_state