mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 03:10:54 +08:00
[tune] Rrk/sigopt_searcher_improvements (#10446)
This commit is contained in:
@@ -35,6 +35,10 @@ class SigOptSearch(Searcher):
|
||||
name (str): Name of experiment. Required by SigOpt.
|
||||
max_concurrent (int): Number of maximum concurrent trials supported
|
||||
based on the user's SigOpt plan. Defaults to 1.
|
||||
connection (Connection): An existing connection to SigOpt.
|
||||
observation_budget (int): Optional, can improve SigOpt performance.
|
||||
project (str): Optional, Project name to assign this experiment to.
|
||||
SigOpt can group experiments by project
|
||||
metric (str): The training result objective value attribute.
|
||||
mode (str): One of {min, max}. Determines whether objective is
|
||||
minimizing or maximizing the metric attribute.
|
||||
@@ -71,15 +75,25 @@ class SigOptSearch(Searcher):
|
||||
name="Default Tune Experiment",
|
||||
max_concurrent=1,
|
||||
reward_attr=None,
|
||||
connection=None,
|
||||
observation_budget=None,
|
||||
project=None,
|
||||
metric="episode_reward_mean",
|
||||
mode="max",
|
||||
**kwargs):
|
||||
assert sgo is not None, "SigOpt must be installed!"
|
||||
assert type(max_concurrent) is int and max_concurrent > 0
|
||||
assert "SIGOPT_KEY" in os.environ, \
|
||||
"SigOpt API key must be stored as environ variable at SIGOPT_KEY"
|
||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
|
||||
|
||||
if connection is not None:
|
||||
self.conn = connection
|
||||
else:
|
||||
assert sgo is not None, "SigOpt must be installed!"
|
||||
assert "SIGOPT_KEY" in os.environ, \
|
||||
"SigOpt API key must be stored as " \
|
||||
"environ variable at SIGOPT_KEY"
|
||||
# Create a connection with SigOpt API, requires API key
|
||||
self.conn = sgo.Connection(client_token=os.environ["SIGOPT_KEY"])
|
||||
|
||||
self._max_concurrent = max_concurrent
|
||||
self._metric = metric
|
||||
if mode == "max":
|
||||
@@ -88,18 +102,25 @@ class SigOptSearch(Searcher):
|
||||
self._metric_op = -1.
|
||||
self._live_trial_mapping = {}
|
||||
|
||||
# Create a connection with SigOpt API, requires API key
|
||||
self.conn = sgo.Connection(client_token=os.environ["SIGOPT_KEY"])
|
||||
|
||||
self.experiment = self.conn.experiments().create(
|
||||
sigopt_params = dict(
|
||||
name=name,
|
||||
parameters=space,
|
||||
parallel_bandwidth=self._max_concurrent,
|
||||
)
|
||||
parallel_bandwidth=self._max_concurrent)
|
||||
|
||||
if observation_budget is not None:
|
||||
sigopt_params["observation_budget"] = observation_budget
|
||||
|
||||
if project is not None:
|
||||
sigopt_params["project"] = project
|
||||
|
||||
self.experiment = self.conn.experiments().create(**sigopt_params)
|
||||
|
||||
super(SigOptSearch, self).__init__(metric=metric, mode=mode, **kwargs)
|
||||
|
||||
def suggest(self, trial_id):
|
||||
if self._max_concurrent:
|
||||
if len(self._live_trial_mapping) >= self._max_concurrent:
|
||||
return None
|
||||
# Get new suggestion from SigOpt
|
||||
suggestion = self.conn.experiments(
|
||||
self.experiment.id).suggestions().create()
|
||||
|
||||
Reference in New Issue
Block a user