mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 18:29:08 +08:00
341 lines
12 KiB
Python
341 lines
12 KiB
Python
import copy
|
|
import glob
|
|
import itertools
|
|
import os
|
|
import uuid
|
|
from typing import Dict, List, Optional, Union
|
|
import warnings
|
|
|
|
from ray.tune.error import TuneError
|
|
from ray.tune.experiment import Experiment, convert_to_experiment_list
|
|
from ray.tune.config_parser import make_parser, create_trial_from_spec
|
|
from ray.tune.suggest.variant_generator import (
|
|
count_variants, count_spec_samples, generate_variants, format_vars,
|
|
flatten_resolved_vars, get_preset_variants)
|
|
from ray.tune.suggest.search import SearchAlgorithm
|
|
from ray.tune.utils.util import atomic_save, load_newest_checkpoint
|
|
|
|
# Grid-search size above which ``BasicVariantGenerator.add_configurations``
# switches an experiment to lazy variant generation. Once any experiment is
# lazy, ``get_state``/``save_to_dir`` return ``False`` instead of saving, so
# resuming the generator is disabled.
SERIALIZATION_THRESHOLD = 1e6
|
|
|
|
|
|
class _VariantIterator:
|
|
"""Iterates over generated variants from the search space.
|
|
|
|
This object also toggles between lazy evaluation and
|
|
eager evaluation of samples. If lazy evaluation is enabled,
|
|
this object cannot be serialized.
|
|
"""
|
|
|
|
def __init__(self, iterable, lazy_eval=False):
|
|
self.lazy_eval = lazy_eval
|
|
self.iterable = iterable
|
|
self._has_next = True
|
|
if lazy_eval:
|
|
self._load_value()
|
|
else:
|
|
self.iterable = list(iterable)
|
|
self._has_next = bool(self.iterable)
|
|
|
|
def _load_value(self):
|
|
try:
|
|
self.next_value = next(self.iterable)
|
|
except StopIteration:
|
|
self._has_next = False
|
|
|
|
def has_next(self):
|
|
return self._has_next
|
|
|
|
def __next__(self):
|
|
if self.lazy_eval:
|
|
current_value = self.next_value
|
|
self._load_value()
|
|
return current_value
|
|
current_value = self.iterable.pop(0)
|
|
self._has_next = bool(self.iterable)
|
|
return current_value
|
|
|
|
|
|
class _TrialIterator:
|
|
"""Generates trials from the spec.
|
|
|
|
Args:
|
|
uuid_prefix (str): Used in creating the trial name.
|
|
num_samples (int): Number of samples from distribution
|
|
(same as tune.run).
|
|
unresolved_spec (dict): Experiment specification
|
|
that might have unresolved distributions.
|
|
output_path (str): A specific output path within the local_dir.
|
|
points_to_evaluate (list): Same as tune.run.
|
|
lazy_eval (bool): Whether variants should be generated
|
|
lazily or eagerly. This is toggled depending
|
|
on the size of the grid search.
|
|
start (int): index at which to start counting trials.
|
|
"""
|
|
|
|
def __init__(self,
|
|
uuid_prefix: str,
|
|
num_samples: int,
|
|
unresolved_spec: dict,
|
|
output_path: str = "",
|
|
points_to_evaluate: Optional[List] = None,
|
|
lazy_eval: bool = False,
|
|
start: int = 0):
|
|
self.parser = make_parser()
|
|
self.num_samples = num_samples
|
|
self.uuid_prefix = uuid_prefix
|
|
self.num_samples_left = num_samples
|
|
self.unresolved_spec = unresolved_spec
|
|
self.output_path = output_path
|
|
self.points_to_evaluate = points_to_evaluate or []
|
|
self.num_points_to_evaluate = len(self.points_to_evaluate)
|
|
self.counter = start
|
|
self.lazy_eval = lazy_eval
|
|
self.variants = None
|
|
|
|
def create_trial(self, resolved_vars, spec):
|
|
trial_id = self.uuid_prefix + ("%05d" % self.counter)
|
|
experiment_tag = str(self.counter)
|
|
# Always append resolved vars to experiment tag?
|
|
if resolved_vars:
|
|
experiment_tag += "_{}".format(format_vars(resolved_vars))
|
|
self.counter += 1
|
|
return create_trial_from_spec(
|
|
spec,
|
|
self.output_path,
|
|
self.parser,
|
|
evaluated_params=flatten_resolved_vars(resolved_vars),
|
|
trial_id=trial_id,
|
|
experiment_tag=experiment_tag)
|
|
|
|
def __next__(self):
|
|
"""Generates Trial objects with the variant generation process.
|
|
|
|
Uses a fixed point iteration to resolve variants. All trials
|
|
should be able to be generated at once.
|
|
|
|
See also: `ray.tune.suggest.variant_generator`.
|
|
|
|
Returns:
|
|
Trial object
|
|
"""
|
|
|
|
if "run" not in self.unresolved_spec:
|
|
raise TuneError("Must specify `run` in {}".format(
|
|
self.unresolved_spec))
|
|
|
|
if self.variants and self.variants.has_next():
|
|
# This block will be skipped upon instantiation.
|
|
# `variants` will be set later after the first loop.
|
|
resolved_vars, spec = next(self.variants)
|
|
return self.create_trial(resolved_vars, spec)
|
|
|
|
if self.points_to_evaluate:
|
|
config = self.points_to_evaluate.pop(0)
|
|
self.num_samples_left -= 1
|
|
self.variants = _VariantIterator(
|
|
get_preset_variants(self.unresolved_spec, config),
|
|
lazy_eval=self.lazy_eval)
|
|
resolved_vars, spec = next(self.variants)
|
|
return self.create_trial(resolved_vars, spec)
|
|
elif self.num_samples_left > 0:
|
|
self.variants = _VariantIterator(
|
|
generate_variants(self.unresolved_spec),
|
|
lazy_eval=self.lazy_eval)
|
|
self.num_samples_left -= 1
|
|
resolved_vars, spec = next(self.variants)
|
|
return self.create_trial(resolved_vars, spec)
|
|
else:
|
|
raise StopIteration
|
|
|
|
def __iter__(self):
|
|
return self
|
|
|
|
|
|
class BasicVariantGenerator(SearchAlgorithm):
    """Uses Tune's variant generation for resolving variables.

    This is the default search algorithm used if no other search algorithm
    is specified.


    Args:
        points_to_evaluate (list): Initial parameter suggestions to be run
            first. This is for when you already have some good parameters
            you want to run first to help the algorithm make better suggestions
            for future parameters. Needs to be a list of dicts containing the
            configurations.


    Example:

    .. code-block:: python

        from ray import tune

        # This will automatically use the `BasicVariantGenerator`
        tune.run(
            lambda config: config["a"] + config["b"],
            config={
                "a": tune.grid_search([1, 2]),
                "b": tune.randint(0, 3)
            },
            num_samples=4)

    In the example above, 8 trials will be generated: For each sample
    (``4``), each of the grid search variants for ``a`` will be sampled
    once. The ``b`` parameter will be sampled randomly.

    The generator accepts a pre-set list of points that should be evaluated.
    The points will replace the first samples of each experiment passed to
    the ``BasicVariantGenerator``.

    Each point will replace one sample of the specified ``num_samples``. If
    grid search variables are overwritten with the values specified in the
    presets, the number of samples will thus be reduced.

    Example:

    .. code-block:: python

        from ray import tune
        from ray.tune.suggest.basic_variant import BasicVariantGenerator


        tune.run(
            lambda config: config["a"] + config["b"],
            config={
                "a": tune.grid_search([1, 2]),
                "b": tune.randint(0, 3)
            },
            search_alg=BasicVariantGenerator(points_to_evaluate=[
                {"a": 2, "b": 2},
                {"a": 1},
                {"b": 2}
            ]),
            num_samples=4)

    The example above will produce six trials via four samples:

    - The first sample will produce one trial with ``a=2`` and ``b=2``.
    - The second sample will produce one trial with ``a=1`` and ``b`` sampled
      randomly
    - The third sample will produce two trials, one for each grid search
      value of ``a``. It will be ``b=2`` for both of these trials.
    - The fourth sample will produce two trials, one for each grid search
      value of ``a``. ``b`` will be sampled randomly and independently for
      both of these trials.

    """
    CKPT_FILE_TMPL = "basic-variant-state-{}.json"

    def __init__(self, points_to_evaluate: Optional[List[Dict]] = None):
        self._trial_generator = []
        self._iterators = []
        self._trial_iter = None
        self._finished = False

        self._points_to_evaluate = points_to_evaluate or []

        # Every trial id is "<prefix>_<5-digit counter>", e.g. 2f1ef_00000,
        # 2f1ef_00001, 2f1ef_00002, ... The prefix is random per generator
        # unless the _TEST_TUNE_TRIAL_UUID environment variable pins it.
        test_uuid = os.environ.get("_TEST_TUNE_TRIAL_UUID")
        prefix = test_uuid if test_uuid else str(uuid.uuid1().hex)[:5]
        self._uuid_prefix = prefix + "_"

        self._total_samples = 0

    @property
    def total_samples(self):
        """Running total of variants counted over all added experiments."""
        return self._total_samples

    def add_configurations(
            self,
            experiments: Union[Experiment, List[Experiment], Dict[str, Dict]]):
        """Chains generator given experiment specifications.

        One `_TrialIterator` is created per experiment and appended to the
        chain that `next_trial` draws from.

        Arguments:
            experiments (Experiment | list | dict): Experiments to run.
        """
        for experiment in convert_to_experiment_list(experiments):
            grid_vals = count_spec_samples(experiment.spec, num_samples=1)

            # Very large grids are generated lazily rather than held in
            # memory; lazy iterators are excluded from checkpointing.
            lazy_eval = grid_vals > SERIALIZATION_THRESHOLD
            if lazy_eval:
                warnings.warn(
                    f"The number of pre-generated samples ({grid_vals}) "
                    "exceeds the serialization threshold "
                    f"({int(SERIALIZATION_THRESHOLD)}). Resume ability is "
                    "disabled. To fix this, reduce the number of "
                    "dimensions/size of the provided grid search.")

            # Trial numbering of this experiment continues where the
            # previous one stopped.
            first_index = self._total_samples

            # Each experiment works on its own copy of the preset points
            # because the trial iterator pops entries as it consumes them.
            presets = copy.deepcopy(self._points_to_evaluate)
            self._total_samples += count_variants(experiment.spec, presets)

            trial_iterator = _TrialIterator(
                uuid_prefix=self._uuid_prefix,
                num_samples=experiment.spec.get("num_samples", 1),
                unresolved_spec=experiment.spec,
                output_path=experiment.dir_name,
                points_to_evaluate=presets,
                lazy_eval=lazy_eval,
                start=first_index)
            self._iterators.append(trial_iterator)
            self._trial_generator = itertools.chain(self._trial_generator,
                                                    trial_iterator)

    def next_trial(self):
        """Provides one Trial object to be queued into the TrialRunner.

        Returns:
            Trial: Returns a single trial, or None once the chained
                iterators are exhausted (the generator is then marked
                finished).
        """
        if not self._trial_iter:
            self._trial_iter = iter(self._trial_generator)

        try:
            trial = next(self._trial_iter)
        except StopIteration:
            # Reset so experiments added later start a fresh chain.
            self._trial_generator = []
            self._trial_iter = None
            self.set_finished()
            return None
        return trial

    def get_state(self):
        """Return the instance state without the chained generator.

        Returns:
            False if any trial iterator is lazily evaluated; otherwise a
            shallow copy of the instance ``__dict__`` with
            ``_trial_generator`` removed.
        """
        uses_lazy_eval = any(it.lazy_eval for it in self._iterators)
        if uses_lazy_eval:
            return False

        snapshot = dict(self.__dict__)
        # `set_state` rebuilds the chain from `_iterators`.
        snapshot.pop("_trial_generator")
        return snapshot

    def set_state(self, state):
        """Apply a saved state and re-chain the restored trial iterators."""
        self.__dict__.update(state)
        for trial_iterator in self._iterators:
            self._trial_generator = itertools.chain(self._trial_generator,
                                                    trial_iterator)

    def save_to_dir(self, dirpath, session_str):
        """Write the generator state into dirpath.

        Returns False without saving when any trial iterator is lazily
        evaluated.
        """
        uses_lazy_eval = any(it.lazy_eval for it in self._iterators)
        if uses_lazy_eval:
            return False

        atomic_save(
            state=self.get_state(),
            checkpoint_dir=dirpath,
            file_name=self.CKPT_FILE_TMPL.format(session_str),
            tmp_file_name=".tmp_generator")

    def has_checkpoint(self, dirpath: str):
        """Whether a checkpoint file exists within dirpath."""
        pattern = os.path.join(dirpath, self.CKPT_FILE_TMPL.format("*"))
        return bool(glob.glob(pattern))

    def restore_from_dir(self, dirpath: str):
        """Restore this generator from a checkpoint found in dirpath.

        Raises:
            RuntimeError: If no checkpoint state could be loaded.
        """
        restored = load_newest_checkpoint(dirpath,
                                          self.CKPT_FILE_TMPL.format("*"))
        if not restored:
            raise RuntimeError(
                "Unable to find checkpoint in {}.".format(dirpath))
        self.set_state(restored)
|