From 25f561469127c69f1669ccad89bf25dc28b17da9 Mon Sep 17 00:00:00 2001 From: raoul-khour-ts <69156393+raoul-khour-ts@users.noreply.github.com> Date: Mon, 31 Aug 2020 16:15:12 -0400 Subject: [PATCH] [tune] Rrk/sigopt_searcher_improvements (#10446) --- python/ray/tune/suggest/sigopt.py | 39 ++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/python/ray/tune/suggest/sigopt.py b/python/ray/tune/suggest/sigopt.py index 34a6226b5..d371dd66b 100644 --- a/python/ray/tune/suggest/sigopt.py +++ b/python/ray/tune/suggest/sigopt.py @@ -35,6 +35,10 @@ class SigOptSearch(Searcher): name (str): Name of experiment. Required by SigOpt. max_concurrent (int): Number of maximum concurrent trials supported based on the user's SigOpt plan. Defaults to 1. + connection (Connection): An existing connection to SigOpt. + observation_budget (int): Optional, can improve SigOpt performance. + project (str): Optional, Project name to assign this experiment to. + SigOpt can group experiments by project metric (str): The training result objective value attribute. mode (str): One of {min, max}. Determines whether objective is minimizing or maximizing the metric attribute. @@ -71,15 +75,25 @@ class SigOptSearch(Searcher): name="Default Tune Experiment", max_concurrent=1, reward_attr=None, + connection=None, + observation_budget=None, + project=None, metric="episode_reward_mean", mode="max", **kwargs): - assert sgo is not None, "SigOpt must be installed!" assert type(max_concurrent) is int and max_concurrent > 0 - assert "SIGOPT_KEY" in os.environ, \ - "SigOpt API key must be stored as environ variable at SIGOPT_KEY" assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!" + if connection is not None: + self.conn = connection + else: + assert sgo is not None, "SigOpt must be installed!" + assert "SIGOPT_KEY" in os.environ, \ + "SigOpt API key must be stored as " \ + "environ variable at SIGOPT_KEY" + # Create a connection with SigOpt API, requires API key + self.conn = sgo.Connection(client_token=os.environ["SIGOPT_KEY"]) + self._max_concurrent = max_concurrent self._metric = metric if mode == "max": @@ -88,18 +102,25 @@ class SigOptSearch(Searcher): self._metric_op = -1. self._live_trial_mapping = {} - # Create a connection with SigOpt API, requires API key - self.conn = sgo.Connection(client_token=os.environ["SIGOPT_KEY"]) - - self.experiment = self.conn.experiments().create( + sigopt_params = dict( name=name, parameters=space, - parallel_bandwidth=self._max_concurrent, - ) + parallel_bandwidth=self._max_concurrent) + + if observation_budget is not None: + sigopt_params["observation_budget"] = observation_budget + + if project is not None: + sigopt_params["project"] = project + + self.experiment = self.conn.experiments().create(**sigopt_params) super(SigOptSearch, self).__init__(metric=metric, mode=mode, **kwargs) def suggest(self, trial_id): + if self._max_concurrent: + if len(self._live_trial_mapping) >= self._max_concurrent: + return None # Get new suggestion from SigOpt suggestion = self.conn.experiments( self.experiment.id).suggestions().create()