mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 03:08:48 +08:00
[tune] support resume for search algorithms (#9972)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from ray.tune.suggest.search import SearchAlgorithm
|
||||
from ray.tune.suggest.basic_variant import BasicVariantGenerator
|
||||
from ray.tune.suggest.suggestion import (SearchGenerator, Searcher,
|
||||
ConcurrencyLimiter)
|
||||
from ray.tune.suggest.suggestion import Searcher, ConcurrencyLimiter
|
||||
from ray.tune.suggest.search_generator import SearchGenerator
|
||||
from ray.tune.suggest.variant_generator import grid_search
|
||||
from ray.tune.suggest.repeater import Repeater
|
||||
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
from ray.tune.suggest.suggestion import Searcher, ConcurrencyLimiter
|
||||
from ray.tune.suggest.search_generator import SearchGenerator
|
||||
|
||||
|
||||
class _MockSearcher(Searcher):
|
||||
def __init__(self, **kwargs):
|
||||
self.live_trials = {}
|
||||
self.counter = {"result": 0, "complete": 0}
|
||||
self.final_results = []
|
||||
self.stall = False
|
||||
self.results = []
|
||||
super(_MockSearcher, self).__init__(**kwargs)
|
||||
|
||||
def suggest(self, trial_id):
|
||||
if not self.stall:
|
||||
self.live_trials[trial_id] = 1
|
||||
return {"test_variable": 2}
|
||||
return None
|
||||
|
||||
def on_trial_result(self, trial_id, result):
|
||||
self.counter["result"] += 1
|
||||
self.results += [result]
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||
self.counter["complete"] += 1
|
||||
if result:
|
||||
self._process_result(result)
|
||||
if trial_id in self.live_trials:
|
||||
del self.live_trials[trial_id]
|
||||
|
||||
def _process_result(self, result):
|
||||
self.final_results += [result]
|
||||
|
||||
|
||||
class _MockSuggestionAlgorithm(SearchGenerator):
|
||||
def __init__(self, max_concurrent=None, **kwargs):
|
||||
self.searcher = _MockSearcher(**kwargs)
|
||||
if max_concurrent:
|
||||
self.searcher = ConcurrencyLimiter(
|
||||
self.searcher, max_concurrent=max_concurrent)
|
||||
super(_MockSuggestionAlgorithm, self).__init__(self.searcher)
|
||||
|
||||
@property
|
||||
def live_trials(self):
|
||||
return self.searcher.live_trials
|
||||
|
||||
@property
|
||||
def results(self):
|
||||
return self.searcher.results
|
||||
@@ -212,13 +212,26 @@ class HyperOptSearch(Searcher):
|
||||
t for t in self._hpopt_trials.trials if t["tid"] == hyperopt_tid
|
||||
][0]
|
||||
|
||||
def get_state(self):
|
||||
return {
|
||||
"hyperopt_trials": self._hpopt_trials,
|
||||
"rstate": self.rstate.get_state()
|
||||
}
|
||||
|
||||
def set_state(self, state):
|
||||
self._hpopt_trials = state["hyperopt_trials"]
|
||||
self.rstate.set_state(state["rstate"])
|
||||
|
||||
def save(self, checkpoint_path):
|
||||
trials_object = (self._hpopt_trials, self.rstate.get_state())
|
||||
with open(checkpoint_path, "wb") as outputFile:
|
||||
pickle.dump(trials_object, outputFile)
|
||||
pickle.dump(self.get_state(), outputFile)
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
with open(checkpoint_path, "rb") as inputFile:
|
||||
trials_object = pickle.load(inputFile)
|
||||
self._hpopt_trials = trials_object[0]
|
||||
self.rstate.set_state(trials_object[1])
|
||||
|
||||
if isinstance(trials_object, tuple):
|
||||
self._hpopt_trials = trials_object[0]
|
||||
self.rstate.set_state(trials_object[1])
|
||||
else:
|
||||
self.set_state(trials_object)
|
||||
|
||||
@@ -44,7 +44,7 @@ class _TrialGroup:
|
||||
|
||||
def add(self, trial_id):
|
||||
assert len(self._trials) < self.max_trials
|
||||
self._trials[trial_id] = None
|
||||
self._trials.setdefault(trial_id, None)
|
||||
|
||||
def full(self):
|
||||
return len(self._trials) == self.max_trials
|
||||
@@ -56,7 +56,8 @@ class _TrialGroup:
|
||||
self._trials[trial_id] = score
|
||||
|
||||
def finished_reporting(self):
|
||||
return None not in self._trials.values()
|
||||
return None not in self._trials.values() and len(
|
||||
self._trials) == self.max_trials
|
||||
|
||||
def scores(self):
|
||||
return list(self._trials.values())
|
||||
@@ -159,8 +160,10 @@ class Repeater(Searcher):
|
||||
result={self.searcher.metric: np.nanmean(scores)},
|
||||
**kwargs)
|
||||
|
||||
def save(self, path):
|
||||
self.searcher.save(path)
|
||||
def get_state(self):
|
||||
self_state = self.__dict__.copy()
|
||||
del self_state["searcher"]
|
||||
return self_state
|
||||
|
||||
def restore(self, path):
|
||||
self.searcher.restore(path)
|
||||
def set_state(self, state):
|
||||
self.__dict__.update(state)
|
||||
|
||||
@@ -63,8 +63,14 @@ class SearchAlgorithm:
|
||||
"""Marks the search algorithm as finished."""
|
||||
self._finished = True
|
||||
|
||||
def save(self, *args):
|
||||
def has_checkpoint(self, dirpath):
|
||||
"""Should return False if not restoring is not implemented."""
|
||||
return False
|
||||
|
||||
def save_to_dir(self, dirpath, **kwargs):
|
||||
"""Saves a search algorithm."""
|
||||
pass
|
||||
|
||||
def restore(self, *args):
|
||||
def restore_from_dir(self, dirpath):
|
||||
"""Restores a search algorithm along with its wrapped state."""
|
||||
pass
|
||||
|
||||
@@ -0,0 +1,222 @@
|
||||
import pickle
|
||||
import os
|
||||
import copy
|
||||
import logging
|
||||
import glob
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.experiment import convert_to_experiment_list
|
||||
from ray.tune.config_parser import make_parser, create_trial_from_spec
|
||||
from ray.tune.suggest.search import SearchAlgorithm
|
||||
from ray.tune.suggest.suggestion import Searcher
|
||||
from ray.tune.suggest.variant_generator import format_vars, resolve_nested_dict
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.utils import flatten_dict, merge_dicts
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _warn_on_repeater(searcher, total_samples):
|
||||
from ray.tune.suggest.repeater import _warn_num_samples
|
||||
_warn_num_samples(searcher, total_samples)
|
||||
|
||||
|
||||
def _atomic_save(state, checkpoint_dir, file_name):
|
||||
"""Atomically saves the object to the checkpoint directory
|
||||
|
||||
This is automatically used by tune.run during a Tune job.
|
||||
"""
|
||||
tmp_search_ckpt_path = os.path.join(checkpoint_dir,
|
||||
".tmp_search_generator_ckpt")
|
||||
with open(tmp_search_ckpt_path, "wb") as f:
|
||||
pickle.dump(state, f)
|
||||
|
||||
os.rename(tmp_search_ckpt_path, os.path.join(checkpoint_dir, file_name))
|
||||
|
||||
|
||||
def _find_newest_ckpt(dirpath, pattern):
|
||||
"""Returns path to most recently modified checkpoint."""
|
||||
full_paths = glob.glob(os.path.join(dirpath, pattern))
|
||||
if not full_paths:
|
||||
return
|
||||
most_recent_checkpoint = max(full_paths)
|
||||
with open(most_recent_checkpoint, "rb") as f:
|
||||
search_alg_state = pickle.load(f)
|
||||
return search_alg_state
|
||||
|
||||
|
||||
class SearchGenerator(SearchAlgorithm):
|
||||
"""Generates trials to be passed to the TrialRunner.
|
||||
|
||||
Uses the provided ``searcher`` object to generate trials. This class
|
||||
transparently handles repeating trials with score aggregation
|
||||
without embedding logic into the Searcher.
|
||||
|
||||
Args:
|
||||
searcher: Search object that subclasses the Searcher base class. This
|
||||
is then used for generating new hyperparameter samples.
|
||||
"""
|
||||
CKPT_FILE_TMPL = "search_gen_state-{}.json"
|
||||
|
||||
def __init__(self, searcher):
|
||||
assert issubclass(
|
||||
type(searcher),
|
||||
Searcher), ("Searcher should be subclassing Searcher.")
|
||||
self.searcher = searcher
|
||||
self._parser = make_parser()
|
||||
self._experiment = None
|
||||
self._counter = 0 # Keeps track of number of trials created.
|
||||
self._total_samples = None # int: total samples to evaluate.
|
||||
self._finished = False
|
||||
|
||||
def add_configurations(self, experiments):
|
||||
"""Registers experiment specifications.
|
||||
|
||||
Arguments:
|
||||
experiments (Experiment | list | dict): Experiments to run.
|
||||
"""
|
||||
assert not self._experiment
|
||||
logger.debug("added configurations")
|
||||
experiment_list = convert_to_experiment_list(experiments)
|
||||
assert len(experiment_list) == 1, (
|
||||
"SearchAlgorithms can only support 1 experiment at a time.")
|
||||
self._experiment = experiment_list[0]
|
||||
experiment_spec = self._experiment.spec
|
||||
self._total_samples = self._experiment.spec.get("num_samples", 1)
|
||||
|
||||
_warn_on_repeater(self.searcher, self._total_samples)
|
||||
if "run" not in experiment_spec:
|
||||
raise TuneError("Must specify `run` in {}".format(experiment_spec))
|
||||
|
||||
def next_trials(self):
|
||||
"""Provides a batch of Trial objects to be queued into the TrialRunner.
|
||||
|
||||
Returns:
|
||||
List[Trial]: A list of trials for the Runner to consume.
|
||||
"""
|
||||
trials = []
|
||||
while not self.is_finished():
|
||||
trial = self.create_trial_if_possible(self._experiment.spec,
|
||||
self._experiment.name)
|
||||
if trial is None:
|
||||
break
|
||||
trials.append(trial)
|
||||
return trials
|
||||
|
||||
def create_trial_if_possible(self, experiment_spec, output_path):
|
||||
logger.debug("creating trial")
|
||||
trial_id = Trial.generate_id()
|
||||
suggested_config = self.searcher.suggest(trial_id)
|
||||
if suggested_config == Searcher.FINISHED:
|
||||
self._finished = True
|
||||
logger.debug("Searcher has finished.")
|
||||
return
|
||||
|
||||
if suggested_config is None:
|
||||
return
|
||||
spec = copy.deepcopy(experiment_spec)
|
||||
spec["config"] = merge_dicts(spec["config"],
|
||||
copy.deepcopy(suggested_config))
|
||||
|
||||
# Create a new trial_id if duplicate trial is created
|
||||
flattened_config = resolve_nested_dict(spec["config"])
|
||||
self._counter += 1
|
||||
tag = "{0}_{1}".format(
|
||||
str(self._counter), format_vars(flattened_config))
|
||||
trial = create_trial_from_spec(
|
||||
spec,
|
||||
output_path,
|
||||
self._parser,
|
||||
evaluated_params=flatten_dict(suggested_config),
|
||||
experiment_tag=tag,
|
||||
trial_id=trial_id)
|
||||
return trial
|
||||
|
||||
def on_trial_result(self, trial_id, result):
|
||||
"""Notifies the underlying searcher."""
|
||||
self.searcher.on_trial_result(trial_id, result)
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||
self.searcher.on_trial_complete(
|
||||
trial_id=trial_id, result=result, error=error)
|
||||
|
||||
def is_finished(self):
|
||||
return self._counter >= self._total_samples or self._finished
|
||||
|
||||
def get_state(self):
|
||||
return {
|
||||
"counter": self._counter,
|
||||
"total_samples": self._total_samples,
|
||||
"finished": self._finished,
|
||||
"experiment": self._experiment
|
||||
}
|
||||
|
||||
def set_state(self, state):
|
||||
self._counter = state["counter"]
|
||||
self._total_samples = state["total_samples"]
|
||||
self._finished = state["finished"]
|
||||
self._experiment = state["experiment"]
|
||||
|
||||
def has_checkpoint(self, dirpath):
|
||||
return bool(
|
||||
_find_newest_ckpt(dirpath, self.CKPT_FILE_TMPL.format("*")))
|
||||
|
||||
def save_to_dir(self, dirpath, session_str):
|
||||
"""Saves self + searcher to dir.
|
||||
|
||||
Separates the "searcher" from its wrappers (concurrency, repeating).
|
||||
This allows the user to easily restore a given searcher.
|
||||
|
||||
The save operation is atomic (write/swap).
|
||||
|
||||
Args:
|
||||
dirpath (str): Filepath to experiment dir.
|
||||
session_str (str): Unique identifier of the current run
|
||||
session.
|
||||
"""
|
||||
searcher = self.searcher
|
||||
search_alg_state = self.get_state()
|
||||
while hasattr(searcher, "searcher"):
|
||||
searcher_name = type(searcher).__name__
|
||||
if searcher_name in search_alg_state:
|
||||
logger.warning(
|
||||
"There was a duplicate when saving {}. "
|
||||
"Restore may not work properly.".format(searcher_name))
|
||||
else:
|
||||
search_alg_state["name:" +
|
||||
searcher_name] = searcher.get_state()
|
||||
searcher = searcher.searcher
|
||||
base_searcher = searcher
|
||||
# We save the base searcher separately for users to easily
|
||||
# separate the searcher.
|
||||
base_searcher.save_to_dir(dirpath, session_str)
|
||||
_atomic_save(search_alg_state, dirpath,
|
||||
self.CKPT_FILE_TMPL.format(session_str))
|
||||
|
||||
def restore_from_dir(self, dirpath):
|
||||
"""Restores self + searcher + search wrappers from dirpath."""
|
||||
|
||||
searcher = self.searcher
|
||||
search_alg_state = _find_newest_ckpt(dirpath,
|
||||
self.CKPT_FILE_TMPL.format("*"))
|
||||
if not search_alg_state:
|
||||
raise RuntimeError(
|
||||
"Unable to find checkpoint in {}.".format(dirpath))
|
||||
while hasattr(searcher, "searcher"):
|
||||
searcher_name = "name:" + type(searcher).__name__
|
||||
if searcher_name not in search_alg_state:
|
||||
names = [
|
||||
key.split("name:")[1] for key in search_alg_state
|
||||
if key.startswith("name:")
|
||||
]
|
||||
logger.warning("{} was not found in the experiment checkpoint "
|
||||
"state when restoring. Found {}.".format(
|
||||
searcher_name, names))
|
||||
else:
|
||||
searcher.set_state(search_alg_state.pop(searcher_name))
|
||||
searcher = searcher.searcher
|
||||
base_searcher = searcher
|
||||
|
||||
logger.debug(f"searching base {base_searcher}")
|
||||
base_searcher.restore_from_dir(dirpath)
|
||||
self.set_state(search_alg_state)
|
||||
@@ -1,23 +1,13 @@
|
||||
import copy
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.experiment import convert_to_experiment_list
|
||||
from ray.tune.config_parser import make_parser, create_trial_from_spec
|
||||
from ray.tune.suggest.search import SearchAlgorithm
|
||||
from ray.tune.suggest.variant_generator import format_vars, resolve_nested_dict
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.utils import merge_dicts, flatten_dict
|
||||
from ray.util.debug import log_once
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _warn_on_repeater(searcher, total_samples):
|
||||
from ray.tune.suggest.repeater import _warn_num_samples
|
||||
_warn_num_samples(searcher, total_samples)
|
||||
|
||||
|
||||
class Searcher:
|
||||
"""Abstract class for wrapping suggesting algorithms.
|
||||
|
||||
@@ -59,7 +49,7 @@ class Searcher:
|
||||
|
||||
"""
|
||||
FINISHED = "FINISHED"
|
||||
CKPT_FILE = "searcher-state.pkl"
|
||||
CKPT_FILE_TMPL = "searcher-state-{}.pkl"
|
||||
|
||||
def __init__(self,
|
||||
metric="episode_reward_mean",
|
||||
@@ -186,23 +176,38 @@ class Searcher:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def save_to_dir(self, checkpoint_dir):
|
||||
def get_state(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_state(self, state):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_to_dir(self, checkpoint_dir, session_str="default"):
|
||||
"""Automatically saves the given searcher to the checkpoint_dir.
|
||||
|
||||
This is automatically used by tune.run during a Tune job.
|
||||
|
||||
Args:
|
||||
checkpoint_dir (str): Filepath to experiment dir.
|
||||
session_str (str): Unique identifier of the current run
|
||||
session.
|
||||
"""
|
||||
tmp_search_ckpt_path = os.path.join(checkpoint_dir,
|
||||
".tmp_searcher_ckpt")
|
||||
success = True
|
||||
try:
|
||||
self.save(tmp_search_ckpt_path)
|
||||
except NotImplementedError as e:
|
||||
logger.warning(e)
|
||||
except NotImplementedError:
|
||||
with log_once("suggest:save_to_dir"):
|
||||
logger.warning(
|
||||
"save not implemented for Searcher. Skipping save.")
|
||||
success = False
|
||||
|
||||
if success and os.path.exists(tmp_search_ckpt_path):
|
||||
os.rename(tmp_search_ckpt_path,
|
||||
os.path.join(checkpoint_dir, Searcher.CKPT_FILE))
|
||||
os.rename(
|
||||
tmp_search_ckpt_path,
|
||||
os.path.join(checkpoint_dir,
|
||||
self.CKPT_FILE_TMPL.format(session_str)))
|
||||
|
||||
def restore_from_dir(self, checkpoint_dir):
|
||||
"""Restores the state of a searcher from a given checkpoint_dir.
|
||||
@@ -225,14 +230,14 @@ class Searcher:
|
||||
os.path.join("~/my_results", self.experiment_name)
|
||||
"""
|
||||
|
||||
checkpoint_path = os.path.join(checkpoint_dir, Searcher.CKPT_FILE)
|
||||
if os.path.exists(checkpoint_path):
|
||||
self.restore(checkpoint_path)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
"{filename} not found in {directory}. Unable to restore "
|
||||
"searcher state from directory.".format(
|
||||
filename=Searcher.CKPT_FILE, directory=checkpoint_dir))
|
||||
pattern = self.CKPT_FILE_TMPL.format("*")
|
||||
full_paths = glob.glob(os.path.join(checkpoint_dir, pattern))
|
||||
if not full_paths:
|
||||
raise RuntimeError(
|
||||
"Searcher unable to find checkpoint in {}".format(
|
||||
checkpoint_dir)) # TODO
|
||||
most_recent_checkpoint = max(full_paths)
|
||||
self.restore(most_recent_checkpoint)
|
||||
|
||||
@property
|
||||
def metric(self):
|
||||
@@ -271,7 +276,13 @@ class ConcurrencyLimiter(Searcher):
|
||||
metric=self.searcher.metric, mode=self.searcher.mode)
|
||||
|
||||
def suggest(self, trial_id):
|
||||
assert trial_id not in self.live_trials, (
|
||||
f"Trial ID {trial_id} must be unique: already found in set.")
|
||||
if len(self.live_trials) >= self.max_concurrent:
|
||||
logger.debug(
|
||||
f"Not providing a suggestion for {trial_id} due to "
|
||||
"concurrency limit: %s/%s.", len(self.live_trials),
|
||||
self.max_concurrent)
|
||||
return
|
||||
suggestion = self.searcher.suggest(trial_id)
|
||||
if suggestion not in (None, Searcher.FINISHED):
|
||||
@@ -286,159 +297,10 @@ class ConcurrencyLimiter(Searcher):
|
||||
trial_id, result=result, error=error)
|
||||
self.live_trials.remove(trial_id)
|
||||
|
||||
def save(self, checkpoint_dir):
|
||||
self.searcher.save(checkpoint_dir)
|
||||
def get_state(self):
|
||||
state = self.__dict__.copy()
|
||||
del state["searcher"]
|
||||
return copy.deepcopy(state)
|
||||
|
||||
def restore(self, checkpoint_dir):
|
||||
self.searcher.restore(checkpoint_dir)
|
||||
|
||||
|
||||
class SearchGenerator(SearchAlgorithm):
|
||||
"""Generates trials to be passed to the TrialRunner.
|
||||
|
||||
Uses the provided ``searcher`` object to generate trials. This class
|
||||
transparently handles repeating trials with score aggregation
|
||||
without embedding logic into the Searcher.
|
||||
|
||||
Args:
|
||||
searcher: Search object that subclasses the Searcher base class. This
|
||||
is then used for generating new hyperparameter samples.
|
||||
"""
|
||||
|
||||
def __init__(self, searcher):
|
||||
assert issubclass(
|
||||
type(searcher),
|
||||
Searcher), ("Searcher should be subclassing Searcher.")
|
||||
self.searcher = searcher
|
||||
self._parser = make_parser()
|
||||
self._experiment = None
|
||||
self._counter = 0 # Keeps track of number of trials created.
|
||||
self._total_samples = 0 # int: total samples to evaluate.
|
||||
self._finished = False
|
||||
|
||||
def add_configurations(self, experiments):
|
||||
"""Registers experiment specifications.
|
||||
|
||||
Arguments:
|
||||
experiments (Experiment | list | dict): Experiments to run.
|
||||
"""
|
||||
logger.debug("added configurations")
|
||||
experiment_list = convert_to_experiment_list(experiments)
|
||||
assert len(experiment_list) == 1, (
|
||||
"SearchAlgorithms can only support 1 experiment at a time.")
|
||||
self._experiment = experiment_list[0]
|
||||
experiment_spec = self._experiment.spec
|
||||
self._total_samples = experiment_spec.get("num_samples", 1)
|
||||
|
||||
_warn_on_repeater(self.searcher, self._total_samples)
|
||||
|
||||
if "run" not in experiment_spec:
|
||||
raise TuneError("Must specify `run` in {}".format(experiment_spec))
|
||||
|
||||
def next_trials(self):
|
||||
"""Provides a batch of Trial objects to be queued into the TrialRunner.
|
||||
|
||||
Returns:
|
||||
List[Trial]: A list of trials for the Runner to consume.
|
||||
"""
|
||||
trials = []
|
||||
while not self.is_finished():
|
||||
trial = self.create_trial_if_possible(self._experiment.spec,
|
||||
self._experiment.name)
|
||||
if trial is None:
|
||||
break
|
||||
trials.append(trial)
|
||||
return trials
|
||||
|
||||
def create_trial_if_possible(self, experiment_spec, output_path):
|
||||
logger.debug("creating trial")
|
||||
trial_id = Trial.generate_id()
|
||||
suggested_config = self.searcher.suggest(trial_id)
|
||||
if suggested_config == Searcher.FINISHED:
|
||||
self._finished = True
|
||||
logger.debug("Searcher has finished.")
|
||||
return
|
||||
|
||||
if suggested_config is None:
|
||||
return
|
||||
spec = copy.deepcopy(experiment_spec)
|
||||
spec["config"] = merge_dicts(spec["config"],
|
||||
copy.deepcopy(suggested_config))
|
||||
|
||||
# Create a new trial_id if duplicate trial is created
|
||||
flattened_config = resolve_nested_dict(spec["config"])
|
||||
self._counter += 1
|
||||
tag = "{0}_{1}".format(
|
||||
str(self._counter), format_vars(flattened_config))
|
||||
trial = create_trial_from_spec(
|
||||
spec,
|
||||
output_path,
|
||||
self._parser,
|
||||
evaluated_params=flatten_dict(suggested_config),
|
||||
experiment_tag=tag,
|
||||
trial_id=trial_id)
|
||||
return trial
|
||||
|
||||
def on_trial_result(self, trial_id, result):
|
||||
"""Notifies the underlying searcher."""
|
||||
self.searcher.on_trial_result(trial_id, result)
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||
self.searcher.on_trial_complete(
|
||||
trial_id=trial_id, result=result, error=error)
|
||||
|
||||
def is_finished(self):
|
||||
return self._counter >= self._total_samples or self._finished
|
||||
|
||||
def save(self, checkpoint_path):
|
||||
self.searcher.save(checkpoint_path)
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
self.searcher.restore(checkpoint_path)
|
||||
|
||||
|
||||
class _MockSearcher(Searcher):
|
||||
def __init__(self, **kwargs):
|
||||
self.live_trials = {}
|
||||
self.counter = {"result": 0, "complete": 0}
|
||||
self.final_results = []
|
||||
self.stall = False
|
||||
self.results = []
|
||||
super(_MockSearcher, self).__init__(**kwargs)
|
||||
|
||||
def suggest(self, trial_id):
|
||||
if not self.stall:
|
||||
self.live_trials[trial_id] = 1
|
||||
return {"test_variable": 2}
|
||||
return None
|
||||
|
||||
def on_trial_result(self, trial_id, result):
|
||||
self.counter["result"] += 1
|
||||
self.results += [result]
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, error=False):
|
||||
self.counter["complete"] += 1
|
||||
if result:
|
||||
self._process_result(result)
|
||||
if trial_id in self.live_trials:
|
||||
del self.live_trials[trial_id]
|
||||
|
||||
def _process_result(self, result):
|
||||
self.final_results += [result]
|
||||
|
||||
|
||||
class _MockSuggestionAlgorithm(SearchGenerator):
|
||||
def __init__(self, max_concurrent=None, **kwargs):
|
||||
self.searcher = _MockSearcher(**kwargs)
|
||||
if max_concurrent:
|
||||
self.searcher = ConcurrencyLimiter(
|
||||
self.searcher, max_concurrent=max_concurrent)
|
||||
super(_MockSuggestionAlgorithm, self).__init__(self.searcher)
|
||||
|
||||
@property
|
||||
def live_trials(self):
|
||||
return self.searcher.live_trials
|
||||
|
||||
@property
|
||||
def results(self):
|
||||
return self.searcher.results
|
||||
def set_state(self, state):
|
||||
self.__dict__.update(state)
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
import argparse
|
||||
|
||||
from ray.tune import run
|
||||
from ray.tune.examples.async_hyperband_example import MyTrainableClass
|
||||
from ray.tune.suggest.hyperopt import HyperOptSearch
|
||||
from ray.tune.suggest.suggestion import ConcurrencyLimiter
|
||||
|
||||
from hyperopt import hp
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="PyTorch Example (FOR TEST ONLY)")
|
||||
parser.add_argument(
|
||||
"--resume", action="store_true", help="Resuming from checkpoint.")
|
||||
parser.add_argument("--local-dir", help="Checkpoint path")
|
||||
parser.add_argument(
|
||||
"--ray-address",
|
||||
help="Address of Ray cluster for seamless distributed execution.")
|
||||
args = parser.parse_args()
|
||||
|
||||
space = {
|
||||
"width": hp.uniform("width", 0, 20),
|
||||
"height": hp.uniform("height", -100, 100),
|
||||
"activation": hp.choice("activation", ["relu", "tanh"])
|
||||
}
|
||||
current_best_params = [
|
||||
{
|
||||
"width": 1,
|
||||
"height": 2,
|
||||
"activation": 0 # Activation will be relu
|
||||
},
|
||||
{
|
||||
"width": 4,
|
||||
"height": 2,
|
||||
"activation": 1 # Activation will be tanh
|
||||
}
|
||||
]
|
||||
algo = HyperOptSearch(
|
||||
space,
|
||||
metric="episode_reward_mean",
|
||||
mode="max",
|
||||
random_state_seed=5,
|
||||
points_to_evaluate=current_best_params)
|
||||
algo = ConcurrencyLimiter(algo, max_concurrent=1)
|
||||
from ray.tune import register_trainable
|
||||
register_trainable("trainable", MyTrainableClass)
|
||||
run("trainable",
|
||||
search_alg=algo,
|
||||
global_checkpoint_period=0,
|
||||
resume=args.resume,
|
||||
verbose=0,
|
||||
num_samples=20,
|
||||
fail_fast=True,
|
||||
stop={"training_iteration": 2},
|
||||
local_dir=args.local_dir,
|
||||
name="experiment")
|
||||
@@ -23,7 +23,7 @@ from ray.tune.logger import Logger
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.suggest import grid_search
|
||||
from ray.tune.suggest.suggestion import _MockSuggestionAlgorithm
|
||||
from ray.tune.suggest._mock import _MockSuggestionAlgorithm
|
||||
from ray.tune.utils import (flatten_dict, get_pinned_object,
|
||||
pin_in_object_store)
|
||||
from ray.tune.utils.mock import mock_storage_client, MOCK_REMOTE_DIR
|
||||
|
||||
@@ -4,6 +4,7 @@ import time
|
||||
import os
|
||||
import pytest
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@@ -713,6 +714,80 @@ tune.run(
|
||||
cluster.shutdown()
|
||||
|
||||
|
||||
def test_cluster_interrupt_searcher(start_connected_cluster, tmpdir):
|
||||
"""Tests restoration of HyperOptSearch experiment on cluster shutdown
|
||||
with actual interrupt.
|
||||
|
||||
Restoration should restore both state of trials
|
||||
and previous search algorithm (HyperOptSearch) state.
|
||||
This is an end-to-end test.
|
||||
"""
|
||||
cluster = start_connected_cluster
|
||||
dirpath = str(tmpdir)
|
||||
local_checkpoint_dir = os.path.join(dirpath, "experiment")
|
||||
from ray.tune.examples.async_hyperband_example import MyTrainableClass
|
||||
from ray.tune import register_trainable
|
||||
register_trainable("trainable", MyTrainableClass)
|
||||
|
||||
def execute_script_with_args(*args):
|
||||
current_dir = os.path.dirname(__file__)
|
||||
script = os.path.join(current_dir,
|
||||
"_test_cluster_interrupt_searcher.py")
|
||||
subprocess.Popen([sys.executable, script] + list(args))
|
||||
|
||||
args = ["--ray-address", cluster.address, "--local-dir", dirpath]
|
||||
execute_script_with_args(*args)
|
||||
# Wait until the right checkpoint is saved.
|
||||
# The trainable returns every 0.5 seconds, so this should not miss
|
||||
# the checkpoint.
|
||||
for i in range(50):
|
||||
if TrialRunner.checkpoint_exists(local_checkpoint_dir):
|
||||
# Inspect the internal trialrunner
|
||||
runner = TrialRunner(
|
||||
resume="LOCAL", local_checkpoint_dir=local_checkpoint_dir)
|
||||
trials = runner.get_trials()
|
||||
if trials and len(trials) >= 10:
|
||||
break
|
||||
time.sleep(.5)
|
||||
|
||||
if not TrialRunner.checkpoint_exists(local_checkpoint_dir):
|
||||
raise RuntimeError(
|
||||
f"Checkpoint file didn't appear in {local_checkpoint_dir}. "
|
||||
f"Current list: {os.listdir(local_checkpoint_dir)}.")
|
||||
|
||||
ray.shutdown()
|
||||
cluster.shutdown()
|
||||
|
||||
cluster = _start_new_cluster()
|
||||
execute_script_with_args(*(args + ["--resume"]))
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
register_trainable("trainable", MyTrainableClass)
|
||||
reached = False
|
||||
for i in range(50):
|
||||
if TrialRunner.checkpoint_exists(local_checkpoint_dir):
|
||||
# Inspect the internal trialrunner
|
||||
runner = TrialRunner(
|
||||
resume="LOCAL", local_checkpoint_dir=local_checkpoint_dir)
|
||||
trials = runner.get_trials()
|
||||
if len(trials) == 0:
|
||||
continue # nonblocking script hasn't resumed yet, wait
|
||||
reached = True
|
||||
assert len(trials) >= 10
|
||||
assert len(trials) <= 20
|
||||
if len(trials) == 20:
|
||||
break
|
||||
else:
|
||||
stop_fn = runner.trial_executor.stop_trial
|
||||
[stop_fn(t) for t in trials if t.status is not Trial.ERROR]
|
||||
time.sleep(.5)
|
||||
assert reached is True
|
||||
|
||||
ray.shutdown()
|
||||
cluster.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from collections import Counter
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
@@ -14,14 +16,17 @@ from ray.tune.trial import Trial
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.resources import Resources, json_to_resources, resources_to_json
|
||||
from ray.tune.suggest.repeater import Repeater
|
||||
from ray.tune.suggest.suggestion import (_MockSuggestionAlgorithm,
|
||||
SearchGenerator, Searcher)
|
||||
from ray.tune.suggest._mock import _MockSuggestionAlgorithm
|
||||
from ray.tune.suggest.suggestion import Searcher, ConcurrencyLimiter
|
||||
from ray.tune.suggest.search_generator import SearchGenerator
|
||||
|
||||
|
||||
class TrialRunnerTest3(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
_register_all() # re-register the evicted objects
|
||||
if "CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
|
||||
def testStepHook(self):
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
@@ -264,6 +269,92 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
self.assertTrue(searcher.is_finished())
|
||||
self.assertRaises(TuneError, runner.step)
|
||||
|
||||
def testSearcherSaveRestore(self):
|
||||
ray.init(num_cpus=8, local_mode=True)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
|
||||
def create_searcher():
|
||||
class TestSuggestion(Searcher):
|
||||
def __init__(self, index):
|
||||
self.index = index
|
||||
self.returned_result = []
|
||||
super().__init__(metric="result", mode="max")
|
||||
|
||||
def suggest(self, trial_id):
|
||||
self.index += 1
|
||||
return {"test_variable": self.index}
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, **kwargs):
|
||||
self.returned_result.append(result)
|
||||
|
||||
def save(self, checkpoint_path):
|
||||
with open(checkpoint_path, "wb") as f:
|
||||
pickle.dump(self.__dict__, f)
|
||||
|
||||
def restore(self, checkpoint_path):
|
||||
with open(checkpoint_path, "rb") as f:
|
||||
self.__dict__.update(pickle.load(f))
|
||||
|
||||
searcher = TestSuggestion(0)
|
||||
searcher = ConcurrencyLimiter(searcher, max_concurrent=2)
|
||||
searcher = Repeater(searcher, repeat=3, set_index=False)
|
||||
search_alg = SearchGenerator(searcher)
|
||||
experiment_spec = {
|
||||
"run": "__fake",
|
||||
"num_samples": 20,
|
||||
"stop": {
|
||||
"training_iteration": 2
|
||||
}
|
||||
}
|
||||
experiments = [Experiment.from_json("test", experiment_spec)]
|
||||
search_alg.add_configurations(experiments)
|
||||
return search_alg
|
||||
|
||||
searcher = create_searcher()
|
||||
runner = TrialRunner(
|
||||
search_alg=searcher,
|
||||
local_checkpoint_dir=tmpdir,
|
||||
checkpoint_period=-1)
|
||||
for i in range(6):
|
||||
runner.step()
|
||||
|
||||
assert len(
|
||||
runner.get_trials()) == 6, [t.config for t in runner.get_trials()]
|
||||
runner.checkpoint()
|
||||
trials = runner.get_trials()
|
||||
[
|
||||
runner.trial_executor.stop_trial(t) for t in trials
|
||||
if t.status is not Trial.ERROR
|
||||
]
|
||||
del runner
|
||||
# stop_all(runner.get_trials())
|
||||
|
||||
searcher = create_searcher()
|
||||
runner2 = TrialRunner(
|
||||
search_alg=searcher, local_checkpoint_dir=tmpdir, resume="LOCAL")
|
||||
assert len(runner2.get_trials()) == 6, [
|
||||
t.config for t in runner2.get_trials()
|
||||
]
|
||||
|
||||
def trial_statuses():
|
||||
return [t.status for t in runner2.get_trials()]
|
||||
|
||||
def num_running_trials():
|
||||
return sum(t.status == Trial.RUNNING for t in runner2.get_trials())
|
||||
|
||||
for i in range(6):
|
||||
runner2.step()
|
||||
assert len(set(trial_statuses())) == 1
|
||||
assert Trial.RUNNING in trial_statuses()
|
||||
for i in range(20):
|
||||
runner2.step()
|
||||
assert 1 <= num_running_trials() <= 6
|
||||
evaluated = [
|
||||
t.evaluated_params["test_variable"] for t in runner2.get_trials()
|
||||
]
|
||||
count = Counter(evaluated)
|
||||
assert all(v <= 3 for v in count.values())
|
||||
|
||||
def testTrialSaveRestore(self):
|
||||
"""Creates different trials to test runner.checkpoint/restore."""
|
||||
ray.init(num_cpus=3)
|
||||
@@ -512,6 +603,90 @@ class SearchAlgorithmTest(unittest.TestCase):
|
||||
parameter_set = {t.evaluated_params["test_variable"] for t in trials}
|
||||
self.assertEquals(len(parameter_set), 3)
|
||||
|
||||
def testSetGetRepeater(self):
|
||||
ray.init(num_cpus=4)
|
||||
|
||||
class TestSuggestion(Searcher):
|
||||
def __init__(self, index):
|
||||
self.index = index
|
||||
self.returned_result = []
|
||||
super().__init__(metric="result", mode="max")
|
||||
|
||||
def suggest(self, trial_id):
|
||||
self.index += 1
|
||||
return {"score": self.index}
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, **kwargs):
|
||||
self.returned_result.append(result)
|
||||
|
||||
searcher = TestSuggestion(0)
|
||||
repeater1 = Repeater(searcher, repeat=3, set_index=False)
|
||||
for i in range(3):
|
||||
assert repeater1.suggest(f"test_{i}")["score"] == 1
|
||||
for i in range(2): # An incomplete set of results
|
||||
assert repeater1.suggest(f"test_{i}_2")["score"] == 2
|
||||
|
||||
# Restore a new one
|
||||
state = repeater1.get_state()
|
||||
del repeater1
|
||||
new_repeater = Repeater(searcher, repeat=1, set_index=True)
|
||||
new_repeater.set_state(state)
|
||||
assert new_repeater.repeat == 3
|
||||
assert new_repeater.suggest("test_2_2")["score"] == 2
|
||||
assert new_repeater.suggest("test_x")["score"] == 3
|
||||
|
||||
# Report results
|
||||
for i in range(3):
|
||||
new_repeater.on_trial_complete(f"test_{i}", {"result": 2})
|
||||
|
||||
for i in range(3):
|
||||
new_repeater.on_trial_complete(f"test_{i}_2", {"result": -i * 10})
|
||||
|
||||
assert len(new_repeater.searcher.returned_result) == 2
|
||||
assert new_repeater.searcher.returned_result[-1] == {"result": -10}
|
||||
|
||||
# Finish the rest of the last trial group
|
||||
new_repeater.on_trial_complete("test_x", {"result": 3})
|
||||
assert new_repeater.suggest("test_y")["score"] == 3
|
||||
new_repeater.on_trial_complete("test_y", {"result": 3})
|
||||
assert len(new_repeater.searcher.returned_result) == 2
|
||||
assert new_repeater.suggest("test_z")["score"] == 3
|
||||
new_repeater.on_trial_complete("test_z", {"result": 3})
|
||||
assert len(new_repeater.searcher.returned_result) == 3
|
||||
assert new_repeater.searcher.returned_result[-1] == {"result": 3}
|
||||
|
||||
def testSetGetLimiter(self):
|
||||
ray.init(num_cpus=4)
|
||||
|
||||
class TestSuggestion(Searcher):
|
||||
def __init__(self, index):
|
||||
self.index = index
|
||||
self.returned_result = []
|
||||
super().__init__(metric="result", mode="max")
|
||||
|
||||
def suggest(self, trial_id):
|
||||
self.index += 1
|
||||
return {"score": self.index}
|
||||
|
||||
def on_trial_complete(self, trial_id, result=None, **kwargs):
|
||||
self.returned_result.append(result)
|
||||
|
||||
searcher = TestSuggestion(0)
|
||||
limiter = ConcurrencyLimiter(searcher, max_concurrent=2)
|
||||
assert limiter.suggest("test_1")["score"] == 1
|
||||
assert limiter.suggest("test_2")["score"] == 2
|
||||
assert limiter.suggest("test_3") is None
|
||||
|
||||
state = limiter.get_state()
|
||||
del limiter
|
||||
limiter2 = ConcurrencyLimiter(searcher, max_concurrent=3)
|
||||
limiter2.set_state(state)
|
||||
assert limiter2.suggest("test_4") is None
|
||||
assert limiter2.suggest("test_5") is None
|
||||
limiter2.on_trial_complete("test_1", {"result": 3})
|
||||
limiter2.on_trial_complete("test_2", {"result": 3})
|
||||
assert limiter2.suggest("test_3")["score"] == 3
|
||||
|
||||
|
||||
class ResourcesTest(unittest.TestCase):
|
||||
def testSubtraction(self):
|
||||
|
||||
@@ -17,7 +17,7 @@ from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
|
||||
from ray.tune.syncer import get_cloud_syncer
|
||||
from ray.tune.trial import Checkpoint, Trial
|
||||
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.suggest import BasicVariantGenerator, Searcher
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.tune.utils import warn_if_slow, flatten_dict
|
||||
from ray.tune.web_server import TuneServer
|
||||
from ray.utils import binary_to_hex, hex_to_binary
|
||||
@@ -168,9 +168,13 @@ class TrialRunner:
|
||||
self.resume()
|
||||
logger.info("Resuming trial.")
|
||||
self._resumed = True
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Runner restore failed. Restarting experiment.")
|
||||
except Exception as e:
|
||||
if self._verbose:
|
||||
logger.error(str(e))
|
||||
logger.exception("Runner restore failed.")
|
||||
if self._fail_fast:
|
||||
raise
|
||||
logger.info("Restarting experiment.")
|
||||
else:
|
||||
logger.debug("Starting a new experiment.")
|
||||
|
||||
@@ -185,6 +189,10 @@ class TrialRunner:
|
||||
self._local_checkpoint_dir,
|
||||
TrialRunner.CKPT_FILE_TMPL.format(self._session_str))
|
||||
|
||||
@property
|
||||
def resumed(self):
|
||||
return self._resumed
|
||||
|
||||
@property
|
||||
def scheduler_alg(self):
|
||||
return self._scheduler_alg
|
||||
@@ -240,12 +248,6 @@ class TrialRunner:
|
||||
(fname.startswith("experiment_state") and fname.endswith(".json"))
|
||||
for fname in os.listdir(directory))
|
||||
|
||||
def add_experiment(self, experiment):
|
||||
if not self._resumed:
|
||||
self._search_alg.add_configurations([experiment])
|
||||
else:
|
||||
logger.info("TrialRunner resumed, ignoring new add_experiment.")
|
||||
|
||||
def checkpoint(self, force=False):
|
||||
"""Saves execution state to `self._local_checkpoint_dir`.
|
||||
|
||||
@@ -280,8 +282,8 @@ class TrialRunner:
|
||||
json.dump(runner_state, f, indent=2, cls=_TuneFunctionEncoder)
|
||||
|
||||
os.replace(tmp_file_name, self.checkpoint_file)
|
||||
|
||||
Searcher.save_to_dir(self._search_alg, self._local_checkpoint_dir)
|
||||
self._search_alg.save_to_dir(
|
||||
self._local_checkpoint_dir, session_str=self._session_str)
|
||||
|
||||
if force:
|
||||
self._syncer.sync_up()
|
||||
@@ -299,14 +301,16 @@ class TrialRunner:
|
||||
with open(newest_ckpt_path, "r") as f:
|
||||
runner_state = json.load(f, cls=_TuneFunctionDecoder)
|
||||
self.checkpoint_file = newest_ckpt_path
|
||||
|
||||
logger.warning("".join([
|
||||
"Attempting to resume experiment from {}. ".format(
|
||||
self._local_checkpoint_dir), "This feature is experimental, "
|
||||
"and may not work with all search algorithms. ",
|
||||
self._local_checkpoint_dir),
|
||||
"This will ignore any new changes to the specification."
|
||||
]))
|
||||
|
||||
self.__setstate__(runner_state["runner_data"])
|
||||
if self._search_alg.has_checkpoint(self._local_checkpoint_dir):
|
||||
self._search_alg.restore_from_dir(self._local_checkpoint_dir)
|
||||
|
||||
trials = []
|
||||
for trial_cp in runner_state["checkpoints"]:
|
||||
|
||||
+11
-5
@@ -3,8 +3,8 @@ import logging
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.experiment import convert_to_experiment_list, Experiment
|
||||
from ray.tune.analysis import ExperimentAnalysis
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.tune.suggest.suggestion import Searcher, SearchGenerator
|
||||
from ray.tune.suggest import BasicVariantGenerator, SearchGenerator
|
||||
from ray.tune.suggest.suggestion import Searcher
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
@@ -300,8 +300,11 @@ def run(run_or_experiment,
|
||||
if issubclass(type(search_alg), Searcher):
|
||||
search_alg = SearchGenerator(search_alg)
|
||||
|
||||
if not search_alg:
|
||||
search_alg = BasicVariantGenerator()
|
||||
|
||||
runner = TrialRunner(
|
||||
search_alg=search_alg or BasicVariantGenerator(),
|
||||
search_alg=search_alg,
|
||||
scheduler=scheduler or FIFOScheduler(),
|
||||
local_checkpoint_dir=experiments[0].checkpoint_dir,
|
||||
remote_checkpoint_dir=experiments[0].remote_checkpoint_dir,
|
||||
@@ -315,8 +318,11 @@ def run(run_or_experiment,
|
||||
fail_fast=fail_fast,
|
||||
trial_executor=trial_executor)
|
||||
|
||||
for exp in experiments:
|
||||
runner.add_experiment(exp)
|
||||
if not runner.resumed:
|
||||
for exp in experiments:
|
||||
search_alg.add_configurations([exp])
|
||||
else:
|
||||
logger.info("TrialRunner resumed, ignoring new add_experiment.")
|
||||
|
||||
if progress_reporter is None:
|
||||
if IS_NOTEBOOK:
|
||||
|
||||
Reference in New Issue
Block a user