[tune] Rrk/sigopt_searcher_improvements (#10446)

This commit is contained in:
raoul-khour-ts
2020-08-31 16:15:12 -04:00
committed by GitHub
parent 6917efabc4
commit 25f5614691
+30 -9
View File
@@ -35,6 +35,10 @@ class SigOptSearch(Searcher):
name (str): Name of experiment. Required by SigOpt.
max_concurrent (int): Number of maximum concurrent trials supported
based on the user's SigOpt plan. Defaults to 1.
connection (Connection): An existing connection to SigOpt.
observation_budget (int): Optional, can improve SigOpt performance.
project (str): Optional, Project name to assign this experiment to.
SigOpt can group experiments by project
metric (str): The training result objective value attribute.
mode (str): One of {min, max}. Determines whether objective is
minimizing or maximizing the metric attribute.
@@ -71,15 +75,25 @@ class SigOptSearch(Searcher):
name="Default Tune Experiment",
max_concurrent=1,
reward_attr=None,
connection=None,
observation_budget=None,
project=None,
metric="episode_reward_mean",
mode="max",
**kwargs):
assert sgo is not None, "SigOpt must be installed!"
assert type(max_concurrent) is int and max_concurrent > 0
assert "SIGOPT_KEY" in os.environ, \
"SigOpt API key must be stored as environ variable at SIGOPT_KEY"
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if connection is not None:
self.conn = connection
else:
assert sgo is not None, "SigOpt must be installed!"
assert "SIGOPT_KEY" in os.environ, \
"SigOpt API key must be stored as " \
"environ variable at SIGOPT_KEY"
# Create a connection with SigOpt API, requires API key
self.conn = sgo.Connection(client_token=os.environ["SIGOPT_KEY"])
self._max_concurrent = max_concurrent
self._metric = metric
if mode == "max":
@@ -88,18 +102,25 @@ class SigOptSearch(Searcher):
self._metric_op = -1.
self._live_trial_mapping = {}
# Create a connection with SigOpt API, requires API key
self.conn = sgo.Connection(client_token=os.environ["SIGOPT_KEY"])
self.experiment = self.conn.experiments().create(
sigopt_params = dict(
name=name,
parameters=space,
parallel_bandwidth=self._max_concurrent,
)
parallel_bandwidth=self._max_concurrent)
if observation_budget is not None:
sigopt_params["observation_budget"] = observation_budget
if project is not None:
sigopt_params["project"] = project
self.experiment = self.conn.experiments().create(**sigopt_params)
super(SigOptSearch, self).__init__(metric=metric, mode=mode, **kwargs)
def suggest(self, trial_id):
if self._max_concurrent:
if len(self._live_trial_mapping) >= self._max_concurrent:
return None
# Get new suggestion from SigOpt
suggestion = self.conn.experiments(
self.experiment.id).suggestions().create()