mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 06:08:03 +08:00
[tune] Search alg checkpointing during training (#9803)
Co-authored-by: krfricke <krfricke@users.noreply.github.com>
This commit is contained in:
@@ -270,16 +270,16 @@ class BayesOptSearch(Searcher):
|
||||
"""Register given tuple of params and results."""
|
||||
self.optimizer.register(params, self._metric_op * result[self.metric])
|
||||
|
||||
def save(self, checkpoint_dir):
|
||||
def save(self, checkpoint_path):
|
||||
"""Storing current optimizer state."""
|
||||
with open(checkpoint_dir, "wb") as f:
|
||||
with open(checkpoint_path, "wb") as f:
|
||||
pickle.dump(
|
||||
(self.optimizer, self._buffered_trial_results,
|
||||
self._total_random_search_trials, self._config_counter), f)
|
||||
|
||||
def restore(self, checkpoint_dir):
|
||||
def restore(self, checkpoint_path):
|
||||
"""Restoring current optimizer state."""
|
||||
with open(checkpoint_dir, "rb") as f:
|
||||
with open(checkpoint_path, "rb") as f:
|
||||
(self.optimizer, self._buffered_trial_results,
|
||||
self._total_random_search_trials,
|
||||
self._config_counter) = pickle.load(f)
|
||||
|
||||
@@ -212,13 +212,13 @@ class HyperOptSearch(Searcher):
|
||||
t for t in self._hpopt_trials.trials if t["tid"] == hyperopt_tid
|
||||
][0]
|
||||
|
||||
def save(self, checkpoint_dir):
|
||||
def save(self, checkpoint_path):
|
||||
trials_object = (self._hpopt_trials, self.rstate.get_state())
|
||||
with open(checkpoint_dir, "wb") as outputFile:
|
||||
with open(checkpoint_path, "wb") as outputFile:
|
||||
pickle.dump(trials_object, outputFile)
|
||||
|
||||
def restore(self, checkpoint_dir):
|
||||
with open(checkpoint_dir, "rb") as inputFile:
|
||||
def restore(self, checkpoint_path):
|
||||
with open(checkpoint_path, "rb") as inputFile:
|
||||
trials_object = pickle.load(inputFile)
|
||||
self._hpopt_trials = trials_object[0]
|
||||
self.rstate.set_state(trials_object[1])
|
||||
|
||||
@@ -137,13 +137,13 @@ class NevergradSearch(Searcher):
|
||||
self._nevergrad_opt.tell(ng_trial_info,
|
||||
self._metric_op * result[self._metric])
|
||||
|
||||
def save(self, checkpoint_dir):
|
||||
def save(self, checkpoint_path):
|
||||
trials_object = (self._nevergrad_opt, self._parameters)
|
||||
with open(checkpoint_dir, "wb") as outputFile:
|
||||
with open(checkpoint_path, "wb") as outputFile:
|
||||
pickle.dump(trials_object, outputFile)
|
||||
|
||||
def restore(self, checkpoint_dir):
|
||||
with open(checkpoint_dir, "rb") as inputFile:
|
||||
def restore(self, checkpoint_path):
|
||||
with open(checkpoint_path, "rb") as inputFile:
|
||||
trials_object = pickle.load(inputFile)
|
||||
self._nevergrad_opt = trials_object[0]
|
||||
self._parameters = trials_object[1]
|
||||
|
||||
@@ -62,3 +62,9 @@ class SearchAlgorithm:
|
||||
def set_finished(self):
|
||||
"""Marks the search algorithm as finished."""
|
||||
self._finished = True
|
||||
|
||||
def save(self, *args):
    """No-op save hook; any positional arguments are ignored."""
    return None
|
||||
|
||||
def restore(self, *args):
    """No-op restore hook; any positional arguments are ignored."""
    return None
|
||||
|
||||
@@ -130,13 +130,13 @@ class SigOptSearch(Searcher):
|
||||
failed=True, suggestion=self._live_trial_mapping[trial_id].id)
|
||||
del self._live_trial_mapping[trial_id]
|
||||
|
||||
def save(self, checkpoint_dir):
|
||||
def save(self, checkpoint_path):
|
||||
trials_object = (self.conn, self.experiment)
|
||||
with open(checkpoint_dir, "wb") as outputFile:
|
||||
with open(checkpoint_path, "wb") as outputFile:
|
||||
pickle.dump(trials_object, outputFile)
|
||||
|
||||
def restore(self, checkpoint_dir):
|
||||
with open(checkpoint_dir, "rb") as inputFile:
|
||||
def restore(self, checkpoint_path):
|
||||
with open(checkpoint_path, "rb") as inputFile:
|
||||
trials_object = pickle.load(inputFile)
|
||||
self.conn = trials_object[0]
|
||||
self.experiment = trials_object[1]
|
||||
|
||||
@@ -157,13 +157,13 @@ class SkOptSearch(Searcher):
|
||||
self._skopt_opt.tell(skopt_trial_info,
|
||||
self._metric_op * result[self._metric])
|
||||
|
||||
def save(self, checkpoint_dir):
|
||||
def save(self, checkpoint_path):
|
||||
trials_object = (self._initial_points, self._skopt_opt)
|
||||
with open(checkpoint_dir, "wb") as outputFile:
|
||||
with open(checkpoint_path, "wb") as outputFile:
|
||||
pickle.dump(trials_object, outputFile)
|
||||
|
||||
def restore(self, checkpoint_dir):
|
||||
with open(checkpoint_dir, "rb") as inputFile:
|
||||
def restore(self, checkpoint_path):
|
||||
with open(checkpoint_path, "rb") as inputFile:
|
||||
trials_object = pickle.load(inputFile)
|
||||
self._initial_points = trials_object[0]
|
||||
self._skopt_opt = trials_object[1]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.experiment import convert_to_experiment_list
|
||||
@@ -58,6 +59,7 @@ class Searcher:
|
||||
|
||||
"""
|
||||
FINISHED = "FINISHED"
|
||||
CKPT_FILE = "searcher-state.pkl"
|
||||
|
||||
def __init__(self,
|
||||
metric="episode_reward_mean",
|
||||
@@ -130,14 +132,108 @@ class Searcher:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self, checkpoint_dir):
|
||||
"""Save function for this object."""
|
||||
def save(self, checkpoint_path):
|
||||
"""Save state to path for this search algorithm.
|
||||
|
||||
Args:
|
||||
checkpoint_path (str): File where the search algorithm
|
||||
state is saved. This path should be used later when
|
||||
restoring from file.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
search_alg = Searcher(...)
|
||||
|
||||
analysis = tune.run(
|
||||
cost,
|
||||
num_samples=5,
|
||||
search_alg=search_alg,
|
||||
name=self.experiment_name,
|
||||
local_dir=self.tmpdir)
|
||||
|
||||
search_alg.save("./my_favorite_path.pkl")
|
||||
|
||||
.. versionchanged:: 0.8.7
|
||||
Save is automatically called by `tune.run`. You can use
|
||||
`restore_from_dir` to restore from an experiment directory
|
||||
such as `~/ray_results/trainable`.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def restore(self, checkpoint_dir):
|
||||
"""Restore function for this object."""
|
||||
def restore(self, checkpoint_path):
|
||||
"""Restore state for this search algorithm.
|
||||
|
||||
|
||||
Args:
|
||||
checkpoint_path (str): File where the search algorithm
|
||||
state is saved. This path should be the same
|
||||
as the one provided to "save".
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
search_alg.save("./my_favorite_path.pkl")
|
||||
|
||||
search_alg2 = Searcher(...)
|
||||
search_alg2 = ConcurrencyLimiter(search_alg2, 1)
|
||||
search_alg2.restore("./my_favorite_path.pkl")
|
||||
tune.run(cost, num_samples=5, search_alg=search_alg2)
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def save_to_dir(self, checkpoint_dir):
    """Automatically saves the given searcher to the checkpoint_dir.

    This is automatically used by tune.run during a Tune job.

    The state is first written by ``self.save`` to a temporary file
    (``.tmp_searcher_ckpt``) inside ``checkpoint_dir`` and only then moved
    to ``Searcher.CKPT_FILE``, so a half-written checkpoint never sits
    under the final file name.

    If ``self.save`` raises ``NotImplementedError``, the error is logged as
    a warning and nothing is moved. If ``self.save`` returns without
    creating the temporary file, the call has no further effect.

    Args:
        checkpoint_dir (str): Directory that receives the checkpoint file.
    """
    tmp_search_ckpt_path = os.path.join(checkpoint_dir,
                                        ".tmp_searcher_ckpt")
    try:
        self.save(tmp_search_ckpt_path)
    except NotImplementedError as e:
        logger.warning(e)
        return

    if os.path.exists(tmp_search_ckpt_path):
        # os.replace overwrites an existing checkpoint on every platform.
        # os.rename raises on Windows when the destination already exists,
        # which would break every checkpoint after the first one.
        os.replace(tmp_search_ckpt_path,
                   os.path.join(checkpoint_dir, Searcher.CKPT_FILE))
|
||||
|
||||
def restore_from_dir(self, checkpoint_dir):
    """Restores the state of a searcher from a given checkpoint_dir.

    Looks for ``Searcher.CKPT_FILE`` inside ``checkpoint_dir`` and hands
    its path to ``self.restore``. A leading ``~`` in ``checkpoint_dir`` is
    expanded to the user's home directory before the lookup.

    Typically, you should use this function to restore from an
    experiment directory such as `~/ray_results/trainable`.

    .. code-block:: python

        experiment_1 = tune.run(
            cost,
            num_samples=5,
            search_alg=search_alg,
            verbose=0,
            name=self.experiment_name,
            local_dir="~/my_results")

        search_alg2 = Searcher()
        search_alg2.restore_from_dir(
            os.path.join("~/my_results", self.experiment_name))

    Args:
        checkpoint_dir (str): Directory expected to contain
            ``Searcher.CKPT_FILE``.

    Raises:
        FileNotFoundError: If ``Searcher.CKPT_FILE`` does not exist in
            ``checkpoint_dir``.
    """
    # os.path.exists does not expand "~", so without this the "~/..." paths
    # shown in the example above would always be reported as missing.
    checkpoint_dir = os.path.expanduser(checkpoint_dir)
    checkpoint_path = os.path.join(checkpoint_dir, Searcher.CKPT_FILE)
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            "{filename} not found in {directory}. Unable to restore "
            "searcher state from directory.".format(
                filename=Searcher.CKPT_FILE, directory=checkpoint_dir))
    self.restore(checkpoint_path)
|
||||
|
||||
@property
|
||||
def metric(self):
|
||||
"""The training result objective value attribute."""
|
||||
@@ -294,6 +390,12 @@ class SearchGenerator(SearchAlgorithm):
|
||||
def is_finished(self):
|
||||
return self._counter >= self._total_samples or self._finished
|
||||
|
||||
def save(self, checkpoint_path):
    """Delegate saving to the wrapped searcher.

    Args:
        checkpoint_path (str): Forwarded unchanged to
            ``self.searcher.save``.
    """
    wrapped = self.searcher
    wrapped.save(checkpoint_path)
|
||||
|
||||
def restore(self, checkpoint_path):
    """Delegate restoring to the wrapped searcher.

    Args:
        checkpoint_path (str): Forwarded unchanged to
            ``self.searcher.restore``.
    """
    wrapped = self.searcher
    wrapped.restore(checkpoint_path)
|
||||
|
||||
|
||||
class _MockSearcher(Searcher):
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
@@ -133,12 +133,12 @@ class ZOOptSearch(Searcher):
|
||||
|
||||
del self._live_trial_mapping[trial_id]
|
||||
|
||||
def save(self, checkpoint_dir):
|
||||
def save(self, checkpoint_path):
|
||||
trials_object = self.optimizer
|
||||
with open(checkpoint_dir, "wb") as output:
|
||||
with open(checkpoint_path, "wb") as output:
|
||||
pickle.dump(trials_object, output)
|
||||
|
||||
def restore(self, checkpoint_dir):
|
||||
with open(checkpoint_dir, "rb") as input:
|
||||
def restore(self, checkpoint_path):
|
||||
with open(checkpoint_path, "rb") as input:
|
||||
trials_object = pickle.load(input)
|
||||
self.optimizer = trials_object
|
||||
|
||||
@@ -13,8 +13,9 @@ import ray
|
||||
from ray import tune
|
||||
from ray.test_utils import recursive_fnmatch
|
||||
from ray.rllib import _register_all
|
||||
from ray.tune.suggest import ConcurrencyLimiter
|
||||
from ray.tune.suggest import ConcurrencyLimiter, Searcher
|
||||
from ray.tune.suggest.hyperopt import HyperOptSearch
|
||||
from ray.tune.suggest.dragonfly import DragonflySearch
|
||||
from ray.tune.suggest.bayesopt import BayesOptSearch
|
||||
from ray.tune.suggest.skopt import SkOptSearch
|
||||
from ray.tune.suggest.nevergrad import NevergradSearch
|
||||
@@ -138,6 +139,7 @@ class AbstractWarmStartTest:
|
||||
def setUp(self):
|
||||
ray.init(num_cpus=1, local_mode=True)
|
||||
self.tmpdir = tempfile.mkdtemp()
|
||||
self.experiment_name = "results"
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdir)
|
||||
@@ -147,24 +149,43 @@ class AbstractWarmStartTest:
|
||||
def set_basic_conf(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def run_exp_1(self):
|
||||
def run_part_from_scratch(self):
|
||||
np.random.seed(162)
|
||||
search_alg, cost = self.set_basic_conf()
|
||||
search_alg = ConcurrencyLimiter(search_alg, 1)
|
||||
results_exp_1 = tune.run(
|
||||
cost, num_samples=5, search_alg=search_alg, verbose=0)
|
||||
self.log_dir = os.path.join(self.tmpdir, "warmStartTest.pkl")
|
||||
search_alg.save(self.log_dir)
|
||||
return results_exp_1
|
||||
cost,
|
||||
num_samples=5,
|
||||
search_alg=search_alg,
|
||||
verbose=0,
|
||||
name=self.experiment_name,
|
||||
local_dir=self.tmpdir)
|
||||
checkpoint_path = os.path.join(self.tmpdir, "warmStartTest.pkl")
|
||||
search_alg.save(checkpoint_path)
|
||||
return results_exp_1, np.random.get_state(), checkpoint_path
|
||||
|
||||
def run_exp_2(self):
|
||||
def run_from_experiment_restore(self, random_state):
    """Run five samples with a searcher restored from the experiment dir.

    A fresh searcher from ``set_basic_conf`` is wrapped in a
    ``ConcurrencyLimiter`` (limit 1), restored from
    ``<tmpdir>/<experiment_name>`` via ``restore_from_dir``, and then
    passed to ``tune.run`` under the same experiment name and local dir.

    Args:
        random_state: NumPy random state captured by the caller.

    Returns:
        The object returned by ``tune.run``.
    """
    # NOTE(review): random_state is accepted but never applied here, unlike
    # run_explicit_restore which calls np.random.set_state - confirm intent.
    base_searcher, objective = self.set_basic_conf()
    limited_searcher = ConcurrencyLimiter(base_searcher, 1)
    experiment_dir = os.path.join(self.tmpdir, self.experiment_name)
    limited_searcher.restore_from_dir(experiment_dir)
    return tune.run(
        objective,
        num_samples=5,
        search_alg=limited_searcher,
        verbose=0,
        name=self.experiment_name,
        local_dir=self.tmpdir)
|
||||
|
||||
def run_explicit_restore(self, random_state, checkpoint_path):
|
||||
np.random.set_state(random_state)
|
||||
search_alg2, cost = self.set_basic_conf()
|
||||
search_alg2 = ConcurrencyLimiter(search_alg2, 1)
|
||||
search_alg2.restore(self.log_dir)
|
||||
search_alg2.restore(checkpoint_path)
|
||||
return tune.run(cost, num_samples=5, search_alg=search_alg2, verbose=0)
|
||||
|
||||
def run_exp_3(self):
|
||||
print("FULL RUN")
|
||||
def run_full(self):
|
||||
np.random.seed(162)
|
||||
search_alg3, cost = self.set_basic_conf()
|
||||
search_alg3 = ConcurrencyLimiter(search_alg3, 1)
|
||||
@@ -172,9 +193,19 @@ class AbstractWarmStartTest:
|
||||
cost, num_samples=10, search_alg=search_alg3, verbose=0)
|
||||
|
||||
def testWarmStart(self):
|
||||
results_exp_1 = self.run_exp_1()
|
||||
results_exp_2 = self.run_exp_2()
|
||||
results_exp_3 = self.run_exp_3()
|
||||
results_exp_1, r_state, checkpoint_path = self.run_part_from_scratch()
|
||||
results_exp_2 = self.run_explicit_restore(r_state, checkpoint_path)
|
||||
results_exp_3 = self.run_full()
|
||||
trials_1_config = [trial.config for trial in results_exp_1.trials]
|
||||
trials_2_config = [trial.config for trial in results_exp_2.trials]
|
||||
trials_3_config = [trial.config for trial in results_exp_3.trials]
|
||||
self.assertEqual(trials_1_config + trials_2_config, trials_3_config)
|
||||
|
||||
def testRestore(self):
|
||||
results_exp_1, r_state, checkpoint_path = self.run_part_from_scratch()
|
||||
results_exp_2 = self.run_from_experiment_restore(r_state)
|
||||
results_exp_3 = self.run_full()
|
||||
|
||||
trials_1_config = [trial.config for trial in results_exp_1.trials]
|
||||
trials_2_config = [trial.config for trial in results_exp_2.trials]
|
||||
trials_3_config = [trial.config for trial in results_exp_3.trials]
|
||||
@@ -216,7 +247,7 @@ class BayesoptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
||||
return search_alg, cost
|
||||
|
||||
def testBootStrapAnalysis(self):
|
||||
analysis = self.run_exp_3()
|
||||
analysis = self.run_full()
|
||||
search_alg3, cost = self.set_basic_conf(analysis)
|
||||
search_alg3 = ConcurrencyLimiter(search_alg3, 1)
|
||||
tune.run(cost, num_samples=10, search_alg=search_alg3, verbose=0)
|
||||
@@ -261,6 +292,50 @@ class NevergradWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
||||
return search_alg, cost
|
||||
|
||||
|
||||
class DragonflyWarmSTartTest(AbstractWarmStartTest, unittest.TestCase):
    """Warm-start test configuration for ``DragonflySearch``.

    Both inherited tests (``testWarmStart`` and ``testRestore``) are
    overridden here and skipped.
    """

    def set_basic_conf(self):
        """Build a ``DragonflySearch`` and the cost function it optimizes.

        Returns:
            tuple: ``(search_alg, cost)``.
        """
        from dragonfly.opt.gp_bandit import EuclideanGPBandit
        from dragonfly.exd.experiment_caller import EuclideanFunctionCaller
        from dragonfly import load_config

        def cost(space, reporter):
            # The suggested point is unpacked as a (height, width) pair.
            point = space["point"]
            height, width = point
            loss = (height - 14)**2 - abs(width - 3)
            reporter(loss=loss)

        height_var = {
            "name": "height",
            "type": "float",
            "min": -10,
            "max": 10
        }
        width_var = {"name": "width", "type": "float", "min": 0, "max": 20}
        domain_config = load_config({"domain": [height_var, width_var]})

        first_domain = domain_config.domain.list_of_domains[0]
        func_caller = EuclideanFunctionCaller(None, first_domain)
        optimizer = EuclideanGPBandit(func_caller, ask_tell_mode=True)
        search_alg = DragonflySearch(
            optimizer,
            metric="loss",
            mode="min",
            max_concurrent=1000,  # Here to avoid breaking back-compat.
        )
        return search_alg, cost

    @unittest.skip("Skip because this doesn't seem to work.")
    def testWarmStart(self):
        """Skipped override of the inherited warm-start test."""

    @unittest.skip("Skip because this doesn't seem to work.")
    def testRestore(self):
        """Skipped override of the inherited restore test."""
|
||||
|
||||
|
||||
class SigOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
||||
def set_basic_conf(self):
|
||||
space = [
|
||||
@@ -299,6 +374,11 @@ class SigOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
||||
|
||||
super().testWarmStart()
|
||||
|
||||
def testRestore(self):
|
||||
if ("SIGOPT_KEY" not in os.environ):
|
||||
return
|
||||
super().testRestore()
|
||||
|
||||
|
||||
class ZOOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
||||
def set_basic_conf(self):
|
||||
@@ -319,6 +399,33 @@ class ZOOptWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
|
||||
|
||||
return search_alg, cost
|
||||
|
||||
@unittest.skip("Skip because this seems to have leaking state.")
|
||||
def testRestore(self):
|
||||
pass
|
||||
|
||||
|
||||
class SearcherTest(unittest.TestCase):
    """Round-trip test for ``save_to_dir`` / ``restore_from_dir``."""

    class MockSearcher(Searcher):
        """Minimal searcher whose entire state is a single string."""

        def __init__(self, data):
            # Only ``data`` is set; Searcher.__init__ is not called here.
            self.data = data

        def save(self, path):
            """Write ``self.data`` to ``path`` as text."""
            with open(path, "w") as f:
                f.write(self.data)

        def restore(self, path):
            """Replace ``self.data`` with the text stored at ``path``."""
            with open(path, "r") as f:
                self.data = f.read()

    def testSaveRestoreDir(self):
        """State saved by one searcher is restored into another."""
        tmpdir = tempfile.mkdtemp()
        # Remove the scratch directory even if an assertion below fails;
        # previously it was left behind after every run.
        self.addCleanup(shutil.rmtree, tmpdir)

        original_data = "hello-its-me"
        searcher = self.MockSearcher(original_data)
        searcher.save_to_dir(tmpdir)

        searcher_2 = self.MockSearcher("no-its-not-me")
        searcher_2.restore_from_dir(tmpdir)
        self.assertEqual(searcher_2.data, original_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
||||
@@ -17,7 +17,7 @@ from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
|
||||
from ray.tune.syncer import get_cloud_syncer
|
||||
from ray.tune.trial import Checkpoint, Trial
|
||||
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.tune.suggest import BasicVariantGenerator, Searcher
|
||||
from ray.tune.utils import warn_if_slow, flatten_dict
|
||||
from ray.tune.web_server import TuneServer
|
||||
from ray.utils import binary_to_hex, hex_to_binary
|
||||
@@ -252,6 +252,9 @@ class TrialRunner:
|
||||
Overwrites the current session checkpoint, which starts when self
|
||||
is instantiated. Throttle depends on self._checkpoint_period.
|
||||
|
||||
Also automatically saves the search algorithm to the local
|
||||
checkpoint dir.
|
||||
|
||||
Args:
|
||||
force (bool): Forces a checkpoint despite checkpoint_period.
|
||||
"""
|
||||
@@ -277,6 +280,9 @@ class TrialRunner:
|
||||
json.dump(runner_state, f, indent=2, cls=_TuneFunctionEncoder)
|
||||
|
||||
os.replace(tmp_file_name, self.checkpoint_file)
|
||||
|
||||
Searcher.save_to_dir(self._search_alg, self._local_checkpoint_dir)
|
||||
|
||||
if force:
|
||||
self._syncer.sync_up()
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user