[tune] Search alg checkpointing during training (#9803)

Co-authored-by: krfricke <krfricke@users.noreply.github.com>
This commit is contained in:
Richard Liaw
2020-08-03 15:07:31 -07:00
committed by GitHub
parent db09f70315
commit c6404e8cf6
11 changed files with 320 additions and 45 deletions
+4 -4
View File
@@ -270,16 +270,16 @@ class BayesOptSearch(Searcher):
"""Register given tuple of params and results."""
self.optimizer.register(params, self._metric_op * result[self.metric])
def save(self, checkpoint_dir):
def save(self, checkpoint_path):
"""Storing current optimizer state."""
with open(checkpoint_dir, "wb") as f:
with open(checkpoint_path, "wb") as f:
pickle.dump(
(self.optimizer, self._buffered_trial_results,
self._total_random_search_trials, self._config_counter), f)
def restore(self, checkpoint_dir):
def restore(self, checkpoint_path):
"""Restoring current optimizer state."""
with open(checkpoint_dir, "rb") as f:
with open(checkpoint_path, "rb") as f:
(self.optimizer, self._buffered_trial_results,
self._total_random_search_trials,
self._config_counter) = pickle.load(f)
+4 -4
View File
@@ -212,13 +212,13 @@ class HyperOptSearch(Searcher):
t for t in self._hpopt_trials.trials if t["tid"] == hyperopt_tid
][0]
def save(self, checkpoint_dir):
def save(self, checkpoint_path):
trials_object = (self._hpopt_trials, self.rstate.get_state())
with open(checkpoint_dir, "wb") as outputFile:
with open(checkpoint_path, "wb") as outputFile:
pickle.dump(trials_object, outputFile)
def restore(self, checkpoint_dir):
with open(checkpoint_dir, "rb") as inputFile:
def restore(self, checkpoint_path):
with open(checkpoint_path, "rb") as inputFile:
trials_object = pickle.load(inputFile)
self._hpopt_trials = trials_object[0]
self.rstate.set_state(trials_object[1])
+4 -4
View File
@@ -137,13 +137,13 @@ class NevergradSearch(Searcher):
self._nevergrad_opt.tell(ng_trial_info,
self._metric_op * result[self._metric])
def save(self, checkpoint_dir):
def save(self, checkpoint_path):
trials_object = (self._nevergrad_opt, self._parameters)
with open(checkpoint_dir, "wb") as outputFile:
with open(checkpoint_path, "wb") as outputFile:
pickle.dump(trials_object, outputFile)
def restore(self, checkpoint_dir):
with open(checkpoint_dir, "rb") as inputFile:
def restore(self, checkpoint_path):
with open(checkpoint_path, "rb") as inputFile:
trials_object = pickle.load(inputFile)
self._nevergrad_opt = trials_object[0]
self._parameters = trials_object[1]
+6
View File
@@ -62,3 +62,9 @@ class SearchAlgorithm:
def set_finished(self):
"""Marks the search algorithm as finished."""
self._finished = True
def save(self, *args):
    """Save the state of this search algorithm (no-op by default).

    The base implementation ignores every argument and writes nothing.
    """
    pass
def restore(self, *args):
    """Restore the state of this search algorithm (no-op by default).

    The base implementation ignores every argument and reads nothing.
    """
    pass
+4 -4
View File
@@ -130,13 +130,13 @@ class SigOptSearch(Searcher):
failed=True, suggestion=self._live_trial_mapping[trial_id].id)
del self._live_trial_mapping[trial_id]
def save(self, checkpoint_dir):
def save(self, checkpoint_path):
trials_object = (self.conn, self.experiment)
with open(checkpoint_dir, "wb") as outputFile:
with open(checkpoint_path, "wb") as outputFile:
pickle.dump(trials_object, outputFile)
def restore(self, checkpoint_dir):
with open(checkpoint_dir, "rb") as inputFile:
def restore(self, checkpoint_path):
with open(checkpoint_path, "rb") as inputFile:
trials_object = pickle.load(inputFile)
self.conn = trials_object[0]
self.experiment = trials_object[1]
+4 -4
View File
@@ -157,13 +157,13 @@ class SkOptSearch(Searcher):
self._skopt_opt.tell(skopt_trial_info,
self._metric_op * result[self._metric])
def save(self, checkpoint_dir):
def save(self, checkpoint_path):
trials_object = (self._initial_points, self._skopt_opt)
with open(checkpoint_dir, "wb") as outputFile:
with open(checkpoint_path, "wb") as outputFile:
pickle.dump(trials_object, outputFile)
def restore(self, checkpoint_dir):
with open(checkpoint_dir, "rb") as inputFile:
def restore(self, checkpoint_path):
with open(checkpoint_path, "rb") as inputFile:
trials_object = pickle.load(inputFile)
self._initial_points = trials_object[0]
self._skopt_opt = trials_object[1]
+106 -4
View File
@@ -1,5 +1,6 @@
import copy
import logging
import os
from ray.tune.error import TuneError
from ray.tune.experiment import convert_to_experiment_list
@@ -58,6 +59,7 @@ class Searcher:
"""
FINISHED = "FINISHED"
CKPT_FILE = "searcher-state.pkl"
def __init__(self,
metric="episode_reward_mean",
@@ -130,14 +132,108 @@ class Searcher:
"""
raise NotImplementedError
def save(self, checkpoint_dir):
"""Save function for this object."""
def save(self, checkpoint_path):
    """Save state to path for this search algorithm.

    Args:
        checkpoint_path (str): File where the search algorithm
            state is saved. This path should be used later when
            restoring from file.

    Raises:
        NotImplementedError: Always, in this base implementation.
            Subclasses override this method to write their state.

    Example:

    .. code-block:: python

        search_alg = Searcher(...)

        analysis = tune.run(
            cost,
            num_samples=5,
            search_alg=search_alg,
            name="my_experiment",
            local_dir="~/my_results")

        search_alg.save("./my_favorite_path.pkl")

    .. versionchanged:: 0.8.7
        Save is automatically called by `tune.run`. You can use
        `restore_from_dir` to restore from an experiment directory
        such as `~/ray_results/trainable`.

    """
    # Carry a message so that callers which log the exception do not
    # end up printing an empty line.
    raise NotImplementedError(
        "{} does not implement save().".format(type(self).__name__))
def restore(self, checkpoint_dir):
"""Restore function for this object."""
def restore(self, checkpoint_path):
    """Restore state for this search algorithm.

    Args:
        checkpoint_path (str): File where the search algorithm
            state is saved. This path should be the same
            as the one provided to "save".

    Raises:
        NotImplementedError: Always, in this base implementation.
            Subclasses override this method to read their state.

    Example:

    .. code-block:: python

        search_alg.save("./my_favorite_path.pkl")

        search_alg2 = Searcher(...)
        search_alg2 = ConcurrencyLimiter(search_alg2, 1)
        search_alg2.restore("./my_favorite_path.pkl")
        tune.run(cost, num_samples=5, search_alg=search_alg2)

    """
    # Carry a message so that callers which log the exception do not
    # end up printing an empty line.
    raise NotImplementedError(
        "{} does not implement restore().".format(type(self).__name__))
def save_to_dir(self, checkpoint_dir):
    """Automatically saves the given searcher to the checkpoint_dir.

    This is automatically used by tune.run during a Tune job.

    The state is first written through ``self.save`` to a temporary
    file inside ``checkpoint_dir`` and only then moved onto
    ``Searcher.CKPT_FILE``, so an existing checkpoint is replaced only
    after ``save`` has returned without raising.

    Args:
        checkpoint_dir (str): Directory in which the searcher
            checkpoint file is placed.
    """
    tmp_search_ckpt_path = os.path.join(checkpoint_dir,
                                        ".tmp_searcher_ckpt")
    try:
        self.save(tmp_search_ckpt_path)
    except NotImplementedError as e:
        # Not implementing ``save`` is tolerated: warn and skip. The
        # type name is included because the exception itself may carry
        # no message at all.
        logger.warning(
            "Searcher state was not checkpointed: %s does not "
            "implement save(). %s",
            type(self).__name__, e)
        return

    # ``save`` may be a no-op that writes nothing; only publish the
    # checkpoint when a file was actually produced.
    if os.path.exists(tmp_search_ckpt_path):
        # os.replace overwrites an existing destination on every
        # platform, whereas os.rename raises on Windows when the
        # destination already exists (i.e. from the second save on).
        os.replace(tmp_search_ckpt_path,
                   os.path.join(checkpoint_dir, Searcher.CKPT_FILE))
def restore_from_dir(self, checkpoint_dir):
    """Restores the state of a searcher from a given checkpoint_dir.

    Typically, you should use this function to restore from an
    experiment directory such as `~/ray_results/trainable`.

    .. code-block:: python

        experiment_1 = tune.run(
            cost,
            num_samples=5,
            search_alg=search_alg,
            verbose=0,
            name=self.experiment_name,
            local_dir="~/my_results")

        search_alg2 = Searcher()
        search_alg2.restore_from_dir(
            os.path.join("~/my_results", self.experiment_name))

    Args:
        checkpoint_dir (str): Directory that contains the file named
            ``Searcher.CKPT_FILE``. A leading ``~`` is expanded to the
            user's home directory.

    Raises:
        FileNotFoundError: If ``checkpoint_dir`` does not contain a
            searcher checkpoint file.
    """
    # The documented usage passes "~/..." paths, which os.path does not
    # resolve on its own.
    checkpoint_path = os.path.join(
        os.path.expanduser(checkpoint_dir), Searcher.CKPT_FILE)
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            "{filename} not found in {directory}. Unable to restore "
            "searcher state from directory.".format(
                filename=Searcher.CKPT_FILE, directory=checkpoint_dir))
    self.restore(checkpoint_path)
@property
def metric(self):
"""The training result objective value attribute."""
@@ -294,6 +390,12 @@ class SearchGenerator(SearchAlgorithm):
def is_finished(self):
return self._counter >= self._total_samples or self._finished
def save(self, checkpoint_path):
    """Save the wrapped searcher's state by delegating to its ``save``.

    Args:
        checkpoint_path (str): Passed unchanged to
            ``self.searcher.save``.
    """
    self.searcher.save(checkpoint_path)
def restore(self, checkpoint_path):
    """Restore the wrapped searcher by delegating to its ``restore``.

    Args:
        checkpoint_path (str): Passed unchanged to
            ``self.searcher.restore``.
    """
    self.searcher.restore(checkpoint_path)
class _MockSearcher(Searcher):
def __init__(self, **kwargs):
+4 -4
View File
@@ -133,12 +133,12 @@ class ZOOptSearch(Searcher):
del self._live_trial_mapping[trial_id]
def save(self, checkpoint_dir):
def save(self, checkpoint_path):
trials_object = self.optimizer
with open(checkpoint_dir, "wb") as output:
with open(checkpoint_path, "wb") as output:
pickle.dump(trials_object, output)
def restore(self, checkpoint_dir):
with open(checkpoint_dir, "rb") as input:
def restore(self, checkpoint_path):
with open(checkpoint_path, "rb") as input:
trials_object = pickle.load(input)
self.optimizer = trials_object
+121 -14
View File
@@ -13,8 +13,9 @@ import ray
from ray import tune
from ray.test_utils import recursive_fnmatch
from ray.rllib import _register_all
from ray.tune.suggest import ConcurrencyLimiter
from ray.tune.suggest import ConcurrencyLimiter, Searcher
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.suggest.dragonfly import DragonflySearch
from ray.tune.suggest.bayesopt import BayesOptSearch
from ray.tune.suggest.skopt import SkOptSearch
from ray.tune.suggest.nevergrad import NevergradSearch
@@ -138,6 +139,7 @@ class AbstractWarmStartTest:
def setUp(self):
ray.init(num_cpus=1, local_mode=True)
self.tmpdir = tempfile.mkdtemp()
self.experiment_name = "results"
def tearDown(self):
shutil.rmtree(self.tmpdir)
@@ -147,24 +149,43 @@ class AbstractWarmStartTest:
def set_basic_conf(self):
raise NotImplementedError()
def run_exp_1(self):
def run_part_from_scratch(self):
np.random.seed(162)
search_alg, cost = self.set_basic_conf()
search_alg = ConcurrencyLimiter(search_alg, 1)
results_exp_1 = tune.run(
cost, num_samples=5, search_alg=search_alg, verbose=0)
self.log_dir = os.path.join(self.tmpdir, "warmStartTest.pkl")
search_alg.save(self.log_dir)
return results_exp_1
cost,
num_samples=5,
search_alg=search_alg,
verbose=0,
name=self.experiment_name,
local_dir=self.tmpdir)
checkpoint_path = os.path.join(self.tmpdir, "warmStartTest.pkl")
search_alg.save(checkpoint_path)
return results_exp_1, np.random.get_state(), checkpoint_path
def run_exp_2(self):
def run_from_experiment_restore(self, random_state):
    """Run 5 more samples after restoring from the experiment dir.

    The searcher state is loaded with ``restore_from_dir`` from
    ``<tmpdir>/<experiment_name>``.

    Args:
        random_state: NumPy global RNG state (as returned by
            ``np.random.get_state()``) to reinstate before the
            searcher is built.

    Returns:
        The object returned by ``tune.run``.
    """
    # Apply the captured RNG state, as run_explicit_restore does;
    # previously the argument was accepted but silently ignored.
    np.random.set_state(random_state)
    search_alg, cost = self.set_basic_conf()
    search_alg = ConcurrencyLimiter(search_alg, 1)
    search_alg.restore_from_dir(
        os.path.join(self.tmpdir, self.experiment_name))
    results = tune.run(
        cost,
        num_samples=5,
        search_alg=search_alg,
        verbose=0,
        name=self.experiment_name,
        local_dir=self.tmpdir)
    return results
def run_explicit_restore(self, random_state, checkpoint_path):
np.random.set_state(random_state)
search_alg2, cost = self.set_basic_conf()
search_alg2 = ConcurrencyLimiter(search_alg2, 1)
search_alg2.restore(self.log_dir)
search_alg2.restore(checkpoint_path)
return tune.run(cost, num_samples=5, search_alg=search_alg2, verbose=0)
def run_exp_3(self):
print("FULL RUN")
def run_full(self):
np.random.seed(162)
search_alg3, cost = self.set_basic_conf()
search_alg3 = ConcurrencyLimiter(search_alg3, 1)
@@ -172,9 +193,19 @@ class AbstractWarmStartTest:
cost, num_samples=10, search_alg=search_alg3, verbose=0)
def testWarmStart(self):
results_exp_1 = self.run_exp_1()
results_exp_2 = self.run_exp_2()
results_exp_3 = self.run_exp_3()
results_exp_1, r_state, checkpoint_path = self.run_part_from_scratch()
results_exp_2 = self.run_explicit_restore(r_state, checkpoint_path)
results_exp_3 = self.run_full()
trials_1_config = [trial.config for trial in results_exp_1.trials]
trials_2_config = [trial.config for trial in results_exp_2.trials]
trials_3_config = [trial.config for trial in results_exp_3.trials]
self.assertEqual(trials_1_config + trials_2_config, trials_3_config)
def testRestore(self):
results_exp_1, r_state, checkpoint_path = self.run_part_from_scratch()
results_exp_2 = self.run_from_experiment_restore(r_state)
results_exp_3 = self.run_full()
trials_1_config = [trial.config for trial in results_exp_1.trials]
trials_2_config = [trial.config for trial in results_exp_2.trials]
trials_3_config = [trial.config for trial in results_exp_3.trials]
@@ -216,7 +247,7 @@ class BayesoptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
return search_alg, cost
def testBootStrapAnalysis(self):
analysis = self.run_exp_3()
analysis = self.run_full()
search_alg3, cost = self.set_basic_conf(analysis)
search_alg3 = ConcurrencyLimiter(search_alg3, 1)
tune.run(cost, num_samples=10, search_alg=search_alg3, verbose=0)
@@ -261,6 +292,50 @@ class NevergradWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
return search_alg, cost
class DragonflyWarmSTartTest(AbstractWarmStartTest, unittest.TestCase):
    """Warm-start test configuration for ``DragonflySearch``."""

    def set_basic_conf(self):
        """Build the Dragonfly searcher and the cost function to tune.

        Returns:
            A ``(search_alg, cost)`` tuple.
        """
        from dragonfly.opt.gp_bandit import EuclideanGPBandit
        from dragonfly.exd.experiment_caller import EuclideanFunctionCaller
        from dragonfly import load_config

        def cost(space, reporter):
            # The suggested point is unpacked as (height, width).
            height, width = space["point"]
            height_term = (height - 14)**2
            width_term = abs(width - 3)
            reporter(loss=height_term - width_term)

        height_var = {
            "name": "height",
            "type": "float",
            "min": -10,
            "max": 10
        }
        width_var = {
            "name": "width",
            "type": "float",
            "min": 0,
            "max": 20
        }
        domain_config = load_config({"domain": [height_var, width_var]})

        euclidean_domain = domain_config.domain.list_of_domains[0]
        func_caller = EuclideanFunctionCaller(None, euclidean_domain)
        optimizer = EuclideanGPBandit(func_caller, ask_tell_mode=True)

        search_alg = DragonflySearch(
            optimizer,
            metric="loss",
            mode="min",
            max_concurrent=1000,  # Here to avoid breaking back-compat.
        )
        return search_alg, cost

    @unittest.skip("Skip because this doesn't seem to work.")
    def testWarmStart(self):
        """Disabled override of the inherited warm-start test."""

    @unittest.skip("Skip because this doesn't seem to work.")
    def testRestore(self):
        """Disabled override of the inherited restore test."""
class SigOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
def set_basic_conf(self):
space = [
@@ -299,6 +374,11 @@ class SigOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
super().testWarmStart()
def testRestore(self):
if ("SIGOPT_KEY" not in os.environ):
return
super().testRestore()
class ZOOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
def set_basic_conf(self):
@@ -319,6 +399,33 @@ class ZOOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
return search_alg, cost
@unittest.skip("Skip because this seems to have leaking state.")
def testRestore(self):
pass
class SearcherTest(unittest.TestCase):
    """Tests for ``Searcher.save_to_dir`` / ``restore_from_dir``."""

    class MockSearcher(Searcher):
        """Searcher whose entire state is a single string."""

        def __init__(self, data):
            # NOTE: Searcher.__init__ is not invoked here; save/restore
            # below only touch ``self.data``.
            self.data = data

        def save(self, path):
            """Write ``self.data`` to ``path`` as text."""
            with open(path, "w") as f:
                f.write(self.data)

        def restore(self, path):
            """Replace ``self.data`` with the text content of ``path``."""
            with open(path, "r") as f:
                self.data = f.read()

    def testSaveRestoreDir(self):
        """State saved to a directory is read back by a new searcher."""
        tmpdir = tempfile.mkdtemp()
        # Remove the scratch directory even when an assertion fails;
        # it was previously left behind on disk.
        self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)

        original_data = "hello-its-me"
        searcher = self.MockSearcher(original_data)
        searcher.save_to_dir(tmpdir)

        searcher_2 = self.MockSearcher("no-its-not-me")
        searcher_2.restore_from_dir(tmpdir)
        # unittest assertions are not stripped under ``python -O`` and
        # print both values on failure.
        self.assertEqual(searcher_2.data, original_data)
if __name__ == "__main__":
import pytest
+7 -1
View File
@@ -17,7 +17,7 @@ from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
from ray.tune.syncer import get_cloud_syncer
from ray.tune.trial import Checkpoint, Trial
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.suggest import BasicVariantGenerator, Searcher
from ray.tune.utils import warn_if_slow, flatten_dict
from ray.tune.web_server import TuneServer
from ray.utils import binary_to_hex, hex_to_binary
@@ -252,6 +252,9 @@ class TrialRunner:
Overwrites the current session checkpoint, which starts when self
is instantiated. Throttle depends on self._checkpoint_period.
Also automatically saves the search algorithm to the local
checkpoint dir.
Args:
force (bool): Forces a checkpoint despite checkpoint_period.
"""
@@ -277,6 +280,9 @@ class TrialRunner:
json.dump(runner_state, f, indent=2, cls=_TuneFunctionEncoder)
os.replace(tmp_file_name, self.checkpoint_file)
Searcher.save_to_dir(self._search_alg, self._local_checkpoint_dir)
if force:
self._syncer.sync_up()
else: