From 336aef1774aecb3db41f6e2c1d35f28e41279e1a Mon Sep 17 00:00:00 2001 From: Hersh Godse Date: Tue, 10 Sep 2019 13:11:59 -0700 Subject: [PATCH] [tune] Save and Restore for bayesopt (#5623) --- python/ray/tune/suggest/bayesopt.py | 11 ++++ python/ray/tune/suggest/skopt.py | 11 ++++ python/ray/tune/tests/test_tune_restore.py | 58 +++++++++++++++++++++- 3 files changed, 79 insertions(+), 1 deletion(-) diff --git a/python/ray/tune/suggest/bayesopt.py b/python/ray/tune/suggest/bayesopt.py index 7b2530411..65e220354 100644 --- a/python/ray/tune/suggest/bayesopt.py +++ b/python/ray/tune/suggest/bayesopt.py @@ -4,6 +4,7 @@ from __future__ import print_function import copy import logging +import pickle try: # Python 3 only -- needed for lint test. import bayes_opt as byo except ImportError: @@ -111,3 +112,13 @@ class BayesOptSearch(SuggestionAlgorithm): def _num_live_trials(self): return len(self._live_trial_mapping) + + def save(self, checkpoint_dir): + trials_object = self.optimizer + with open(checkpoint_dir, "wb") as output: + pickle.dump(trials_object, output) + + def restore(self, checkpoint_dir): + with open(checkpoint_dir, "rb") as input: + trials_object = pickle.load(input) + self.optimizer = trials_object diff --git a/python/ray/tune/suggest/skopt.py b/python/ray/tune/suggest/skopt.py index f60a0856e..0950b31bb 100644 --- a/python/ray/tune/suggest/skopt.py +++ b/python/ray/tune/suggest/skopt.py @@ -3,6 +3,7 @@ from __future__ import division from __future__ import print_function import logging +import pickle try: import skopt as sko except ImportError: @@ -157,3 +158,13 @@ class SkOptSearch(SuggestionAlgorithm): def _num_live_trials(self): return len(self._live_trial_mapping) + + def save(self, checkpoint_dir): + trials_object = self._skopt_opt + with open(checkpoint_dir, "wb") as output: + pickle.dump(trials_object, output) + + def restore(self, checkpoint_dir): + with open(checkpoint_dir, "rb") as input: + trials_object = pickle.load(input) + self._skopt_opt = trials_object diff --git a/python/ray/tune/tests/test_tune_restore.py b/python/ray/tune/tests/test_tune_restore.py index 974fab6d3..e558eb5d9 100644 --- a/python/ray/tune/tests/test_tune_restore.py +++ b/python/ray/tune/tests/test_tune_restore.py @@ -15,6 +15,7 @@ from ray.tests.utils import recursive_fnmatch from ray.tune.util import validate_save_restore from ray.rllib import _register_all from ray.tune.suggest.hyperopt import HyperOptSearch +from ray.tune.suggest.bayesopt import BayesOptSearch class TuneRestoreTest(unittest.TestCase): @@ -146,7 +147,7 @@ class HyperoptWarmStartTest(unittest.TestCase): def run_exp_1(self): search_alg, cost = self.set_basic_conf() results_exp_1 = tune.run(cost, num_samples=15, search_alg=search_alg) - self.log_dir = os.path.join(self.tmpdir, "trials_algo1.pkl") + self.log_dir = os.path.join(self.tmpdir, "trials_algo_hyo.pkl") search_alg.save(self.log_dir) return results_exp_1 @@ -169,5 +170,60 @@ class HyperoptWarmStartTest(unittest.TestCase): self.assertEqual(trials_1_config + trials_2_config, trials_3_config) +class BayesoptWarmStartTest(unittest.TestCase): + def setUp(self): + ray.init(local_mode=True) + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmpdir) + ray.shutdown() + _register_all() + + def set_basic_conf(self): + space = {"width": (0, 20), "height": (-100, 100)} + + def cost(space, reporter): + loss = space["width"]**2 + space["height"]**2 + reporter(loss=loss) + + search_alg = BayesOptSearch( + space, + max_concurrent=1, + metric="loss", + mode="min", + utility_kwargs={ + "kind": "ucb", + "kappa": 2.5, + "xi": 0.0 + }) + return search_alg, cost + + def run_exp_1(self): + search_alg, cost = self.set_basic_conf() + results_exp_1 = tune.run(cost, num_samples=15, search_alg=search_alg) + self.log_dir = os.path.join(self.tmpdir, "trials_algo_byo.pkl") + search_alg.save(self.log_dir) + return results_exp_1 + + def run_exp_2(self): + search_alg2, cost = self.set_basic_conf() + search_alg2.restore(self.log_dir) + return tune.run(cost, num_samples=15, search_alg=search_alg2) + + def run_exp_3(self): + search_alg3, cost = self.set_basic_conf() + return tune.run(cost, num_samples=30, search_alg=search_alg3) + + def testBayesoptWarmStart(self): + results_exp_1 = self.run_exp_1() + results_exp_2 = self.run_exp_2() + results_exp_3 = self.run_exp_3() + trials_1_config = [trial.config for trial in results_exp_1.trials] + trials_2_config = [trial.config for trial in results_exp_2.trials] + trials_3_config = [trial.config for trial in results_exp_3.trials] + self.assertEqual(trials_1_config + trials_2_config, trials_3_config) + + if __name__ == "__main__": unittest.main(verbosity=2)