mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 04:23:03 +08:00
[tune] Avoid breakage - soft deprecation warning for search algs (#8258)
This commit is contained in:
committed by
SangBin Cho
parent
40bb225d7a
commit
aba50e1a47
@@ -28,7 +28,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init(configure_logging=False)
|
||||
ray.init()
|
||||
|
||||
space = {
|
||||
"width": hp.uniform("width", 0, 20),
|
||||
|
||||
@@ -70,15 +70,18 @@ class AxSearch(Searcher):
|
||||
logger.warning("Detected sequential enforcement. Setting max "
|
||||
"concurrency to 1.")
|
||||
max_concurrent = 1
|
||||
self.max_concurrent = max_concurrent
|
||||
self._parameters = list(exp.parameters)
|
||||
self._live_index_mapping = {}
|
||||
super(AxSearch, self).__init__(
|
||||
metric=self._objective_name,
|
||||
mode=mode,
|
||||
max_concurrent=max_concurrent,
|
||||
use_early_stopped_trials=use_early_stopped_trials)
|
||||
|
||||
def suggest(self, trial_id):
|
||||
if self.max_concurrent:
|
||||
if len(self._live_trial_mapping) >= self.max_concurrent:
|
||||
return None
|
||||
parameters, trial_index = self._ax.get_next_trial()
|
||||
self._live_index_mapping[trial_id] = trial_index
|
||||
return parameters
|
||||
|
||||
@@ -60,7 +60,7 @@ class BayesOptSearch(Searcher):
|
||||
assert utility_kwargs is not None, (
|
||||
"Must define arguments for the utility function!")
|
||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
|
||||
|
||||
self.max_concurrent = max_concurrent
|
||||
super(BayesOptSearch, self).__init__(
|
||||
metric=metric,
|
||||
mode=mode,
|
||||
@@ -79,6 +79,9 @@ class BayesOptSearch(Searcher):
|
||||
self.utility = byo.UtilityFunction(**utility_kwargs)
|
||||
|
||||
def suggest(self, trial_id):
|
||||
if self.max_concurrent:
|
||||
if len(self._live_trial_mapping) >= self.max_concurrent:
|
||||
return None
|
||||
new_trial = self.optimizer.suggest(self.utility)
|
||||
|
||||
self._live_trial_mapping[trial_id] = new_trial
|
||||
|
||||
@@ -88,6 +88,7 @@ class HyperOptSearch(Searcher):
|
||||
mode=mode,
|
||||
max_concurrent=max_concurrent,
|
||||
use_early_stopped_trials=use_early_stopped_trials)
|
||||
self.max_concurrent = max_concurrent
|
||||
# hyperopt internally minimizes, so "max" => -1
|
||||
if mode == "max":
|
||||
self.metric_op = -1.
|
||||
@@ -118,6 +119,9 @@ class HyperOptSearch(Searcher):
|
||||
self.rstate = np.random.RandomState(random_state_seed)
|
||||
|
||||
def suggest(self, trial_id):
|
||||
if self.max_concurrent:
|
||||
if len(self._live_trial_mapping) >= self.max_concurrent:
|
||||
return None
|
||||
if self._points_to_evaluate > 0:
|
||||
new_trial = self._hpopt_trials.trials[self._points_to_evaluate - 1]
|
||||
self._points_to_evaluate -= 1
|
||||
|
||||
@@ -62,6 +62,7 @@ class NevergradSearch(Searcher):
|
||||
parameter_names,
|
||||
metric="episode_reward_mean",
|
||||
mode="max",
|
||||
max_concurrent=None,
|
||||
**kwargs):
|
||||
assert ng is not None, "Nevergrad must be installed!"
|
||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
|
||||
@@ -74,8 +75,9 @@ class NevergradSearch(Searcher):
|
||||
self._metric_op = 1.
|
||||
self._nevergrad_opt = optimizer
|
||||
self._live_trial_mapping = {}
|
||||
self.max_concurrent = max_concurrent
|
||||
super(NevergradSearch, self).__init__(
|
||||
metric=metric, mode=mode, **kwargs)
|
||||
metric=metric, mode=mode, max_concurrent=max_concurrent, **kwargs)
|
||||
# validate parameters
|
||||
if hasattr(optimizer, "instrumentation"): # added in v0.2.0
|
||||
if optimizer.instrumentation.kwargs:
|
||||
@@ -98,6 +100,9 @@ class NevergradSearch(Searcher):
|
||||
"dimension for non-instrumented optimizers")
|
||||
|
||||
def suggest(self, trial_id):
|
||||
if self.max_concurrent:
|
||||
if len(self._live_trial_mapping) >= self.max_concurrent:
|
||||
return None
|
||||
suggested_config = self._nevergrad_opt.ask()
|
||||
self._live_trial_mapping[trial_id] = suggested_config
|
||||
# in v0.2.0+, output of ask() is a Candidate,
|
||||
|
||||
@@ -93,6 +93,7 @@ class SkOptSearch(Searcher):
|
||||
_validate_warmstart(parameter_names, points_to_evaluate,
|
||||
evaluated_rewards)
|
||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
|
||||
self.max_concurrent = max_concurrent
|
||||
super(SkOptSearch, self).__init__(
|
||||
metric=metric,
|
||||
mode=mode,
|
||||
@@ -114,6 +115,9 @@ class SkOptSearch(Searcher):
|
||||
self._live_trial_mapping = {}
|
||||
|
||||
def suggest(self, trial_id):
|
||||
if self.max_concurrent:
|
||||
if len(self._live_trial_mapping) >= self.max_concurrent:
|
||||
return None
|
||||
if self._initial_points:
|
||||
suggested_config = self._initial_points[0]
|
||||
del self._initial_points[0]
|
||||
|
||||
@@ -68,9 +68,10 @@ class Searcher:
|
||||
"Early stopped trials are now always used. If this is a "
|
||||
"problem, file an issue: https://github.com/ray-project/ray.")
|
||||
if max_concurrent is not None:
|
||||
raise DeprecationWarning(
|
||||
"max_concurrent is now deprecated for this search algorithm. "
|
||||
"Please use tune.suggest.ConcurrencyLimiter instead.")
|
||||
logger.warning(
|
||||
"DeprecationWarning: `max_concurrent` is deprecated for this "
|
||||
"search algorithm. Use tune.suggest.ConcurrencyLimiter() "
|
||||
"instead. This will raise an error in future versions of Ray.")
|
||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
|
||||
self._metric = metric
|
||||
self._mode = mode
|
||||
|
||||
Reference in New Issue
Block a user