[tune] support resume for search algorithms (#9972)

This commit is contained in:
Richard Liaw
2020-08-10 13:43:14 -07:00
committed by GitHub
parent 5331c30e35
commit be8e63d477
13 changed files with 688 additions and 217 deletions
+2 -2
View File
@@ -1,7 +1,7 @@
from ray.tune.suggest.search import SearchAlgorithm
from ray.tune.suggest.basic_variant import BasicVariantGenerator
from ray.tune.suggest.suggestion import (SearchGenerator, Searcher,
ConcurrencyLimiter)
from ray.tune.suggest.suggestion import Searcher, ConcurrencyLimiter
from ray.tune.suggest.search_generator import SearchGenerator
from ray.tune.suggest.variant_generator import grid_search
from ray.tune.suggest.repeater import Repeater
+49
View File
@@ -0,0 +1,49 @@
from ray.tune.suggest.suggestion import Searcher, ConcurrencyLimiter
from ray.tune.suggest.search_generator import SearchGenerator
class _MockSearcher(Searcher):
def __init__(self, **kwargs):
self.live_trials = {}
self.counter = {"result": 0, "complete": 0}
self.final_results = []
self.stall = False
self.results = []
super(_MockSearcher, self).__init__(**kwargs)
def suggest(self, trial_id):
if not self.stall:
self.live_trials[trial_id] = 1
return {"test_variable": 2}
return None
def on_trial_result(self, trial_id, result):
self.counter["result"] += 1
self.results += [result]
def on_trial_complete(self, trial_id, result=None, error=False):
self.counter["complete"] += 1
if result:
self._process_result(result)
if trial_id in self.live_trials:
del self.live_trials[trial_id]
def _process_result(self, result):
self.final_results += [result]
class _MockSuggestionAlgorithm(SearchGenerator):
def __init__(self, max_concurrent=None, **kwargs):
self.searcher = _MockSearcher(**kwargs)
if max_concurrent:
self.searcher = ConcurrencyLimiter(
self.searcher, max_concurrent=max_concurrent)
super(_MockSuggestionAlgorithm, self).__init__(self.searcher)
@property
def live_trials(self):
return self.searcher.live_trials
@property
def results(self):
return self.searcher.results
+17 -4
View File
@@ -212,13 +212,26 @@ class HyperOptSearch(Searcher):
t for t in self._hpopt_trials.trials if t["tid"] == hyperopt_tid
][0]
def get_state(self):
return {
"hyperopt_trials": self._hpopt_trials,
"rstate": self.rstate.get_state()
}
def set_state(self, state):
self._hpopt_trials = state["hyperopt_trials"]
self.rstate.set_state(state["rstate"])
def save(self, checkpoint_path):
trials_object = (self._hpopt_trials, self.rstate.get_state())
with open(checkpoint_path, "wb") as outputFile:
pickle.dump(trials_object, outputFile)
pickle.dump(self.get_state(), outputFile)
def restore(self, checkpoint_path):
with open(checkpoint_path, "rb") as inputFile:
trials_object = pickle.load(inputFile)
self._hpopt_trials = trials_object[0]
self.rstate.set_state(trials_object[1])
if isinstance(trials_object, tuple):
self._hpopt_trials = trials_object[0]
self.rstate.set_state(trials_object[1])
else:
self.set_state(trials_object)
+9 -6
View File
@@ -44,7 +44,7 @@ class _TrialGroup:
def add(self, trial_id):
assert len(self._trials) < self.max_trials
self._trials[trial_id] = None
self._trials.setdefault(trial_id, None)
def full(self):
return len(self._trials) == self.max_trials
@@ -56,7 +56,8 @@ class _TrialGroup:
self._trials[trial_id] = score
def finished_reporting(self):
return None not in self._trials.values()
return None not in self._trials.values() and len(
self._trials) == self.max_trials
def scores(self):
return list(self._trials.values())
@@ -159,8 +160,10 @@ class Repeater(Searcher):
result={self.searcher.metric: np.nanmean(scores)},
**kwargs)
def save(self, path):
self.searcher.save(path)
def get_state(self):
self_state = self.__dict__.copy()
del self_state["searcher"]
return self_state
def restore(self, path):
self.searcher.restore(path)
def set_state(self, state):
self.__dict__.update(state)
+8 -2
View File
@@ -63,8 +63,14 @@ class SearchAlgorithm:
"""Marks the search algorithm as finished."""
self._finished = True
def save(self, *args):
def has_checkpoint(self, dirpath):
"""Should return False if not restoring is not implemented."""
return False
def save_to_dir(self, dirpath, **kwargs):
"""Saves a search algorithm."""
pass
def restore(self, *args):
def restore_from_dir(self, dirpath):
"""Restores a search algorithm along with its wrapped state."""
pass
+222
View File
@@ -0,0 +1,222 @@
import pickle
import os
import copy
import logging
import glob
from ray.tune.error import TuneError
from ray.tune.experiment import convert_to_experiment_list
from ray.tune.config_parser import make_parser, create_trial_from_spec
from ray.tune.suggest.search import SearchAlgorithm
from ray.tune.suggest.suggestion import Searcher
from ray.tune.suggest.variant_generator import format_vars, resolve_nested_dict
from ray.tune.trial import Trial
from ray.tune.utils import flatten_dict, merge_dicts
logger = logging.getLogger(__name__)
def _warn_on_repeater(searcher, total_samples):
from ray.tune.suggest.repeater import _warn_num_samples
_warn_num_samples(searcher, total_samples)
def _atomic_save(state, checkpoint_dir, file_name):
"""Atomically saves the object to the checkpoint directory
This is automatically used by tune.run during a Tune job.
"""
tmp_search_ckpt_path = os.path.join(checkpoint_dir,
".tmp_search_generator_ckpt")
with open(tmp_search_ckpt_path, "wb") as f:
pickle.dump(state, f)
os.rename(tmp_search_ckpt_path, os.path.join(checkpoint_dir, file_name))
def _find_newest_ckpt(dirpath, pattern):
"""Returns path to most recently modified checkpoint."""
full_paths = glob.glob(os.path.join(dirpath, pattern))
if not full_paths:
return
most_recent_checkpoint = max(full_paths)
with open(most_recent_checkpoint, "rb") as f:
search_alg_state = pickle.load(f)
return search_alg_state
class SearchGenerator(SearchAlgorithm):
"""Generates trials to be passed to the TrialRunner.
Uses the provided ``searcher`` object to generate trials. This class
transparently handles repeating trials with score aggregation
without embedding logic into the Searcher.
Args:
searcher: Search object that subclasses the Searcher base class. This
is then used for generating new hyperparameter samples.
"""
CKPT_FILE_TMPL = "search_gen_state-{}.json"
def __init__(self, searcher):
assert issubclass(
type(searcher),
Searcher), ("Searcher should be subclassing Searcher.")
self.searcher = searcher
self._parser = make_parser()
self._experiment = None
self._counter = 0 # Keeps track of number of trials created.
self._total_samples = None # int: total samples to evaluate.
self._finished = False
def add_configurations(self, experiments):
"""Registers experiment specifications.
Arguments:
experiments (Experiment | list | dict): Experiments to run.
"""
assert not self._experiment
logger.debug("added configurations")
experiment_list = convert_to_experiment_list(experiments)
assert len(experiment_list) == 1, (
"SearchAlgorithms can only support 1 experiment at a time.")
self._experiment = experiment_list[0]
experiment_spec = self._experiment.spec
self._total_samples = self._experiment.spec.get("num_samples", 1)
_warn_on_repeater(self.searcher, self._total_samples)
if "run" not in experiment_spec:
raise TuneError("Must specify `run` in {}".format(experiment_spec))
def next_trials(self):
"""Provides a batch of Trial objects to be queued into the TrialRunner.
Returns:
List[Trial]: A list of trials for the Runner to consume.
"""
trials = []
while not self.is_finished():
trial = self.create_trial_if_possible(self._experiment.spec,
self._experiment.name)
if trial is None:
break
trials.append(trial)
return trials
def create_trial_if_possible(self, experiment_spec, output_path):
logger.debug("creating trial")
trial_id = Trial.generate_id()
suggested_config = self.searcher.suggest(trial_id)
if suggested_config == Searcher.FINISHED:
self._finished = True
logger.debug("Searcher has finished.")
return
if suggested_config is None:
return
spec = copy.deepcopy(experiment_spec)
spec["config"] = merge_dicts(spec["config"],
copy.deepcopy(suggested_config))
# Create a new trial_id if duplicate trial is created
flattened_config = resolve_nested_dict(spec["config"])
self._counter += 1
tag = "{0}_{1}".format(
str(self._counter), format_vars(flattened_config))
trial = create_trial_from_spec(
spec,
output_path,
self._parser,
evaluated_params=flatten_dict(suggested_config),
experiment_tag=tag,
trial_id=trial_id)
return trial
def on_trial_result(self, trial_id, result):
"""Notifies the underlying searcher."""
self.searcher.on_trial_result(trial_id, result)
def on_trial_complete(self, trial_id, result=None, error=False):
self.searcher.on_trial_complete(
trial_id=trial_id, result=result, error=error)
def is_finished(self):
return self._counter >= self._total_samples or self._finished
def get_state(self):
return {
"counter": self._counter,
"total_samples": self._total_samples,
"finished": self._finished,
"experiment": self._experiment
}
def set_state(self, state):
self._counter = state["counter"]
self._total_samples = state["total_samples"]
self._finished = state["finished"]
self._experiment = state["experiment"]
def has_checkpoint(self, dirpath):
return bool(
_find_newest_ckpt(dirpath, self.CKPT_FILE_TMPL.format("*")))
def save_to_dir(self, dirpath, session_str):
"""Saves self + searcher to dir.
Separates the "searcher" from its wrappers (concurrency, repeating).
This allows the user to easily restore a given searcher.
The save operation is atomic (write/swap).
Args:
dirpath (str): Filepath to experiment dir.
session_str (str): Unique identifier of the current run
session.
"""
searcher = self.searcher
search_alg_state = self.get_state()
while hasattr(searcher, "searcher"):
searcher_name = type(searcher).__name__
if searcher_name in search_alg_state:
logger.warning(
"There was a duplicate when saving {}. "
"Restore may not work properly.".format(searcher_name))
else:
search_alg_state["name:" +
searcher_name] = searcher.get_state()
searcher = searcher.searcher
base_searcher = searcher
# We save the base searcher separately for users to easily
# separate the searcher.
base_searcher.save_to_dir(dirpath, session_str)
_atomic_save(search_alg_state, dirpath,
self.CKPT_FILE_TMPL.format(session_str))
def restore_from_dir(self, dirpath):
"""Restores self + searcher + search wrappers from dirpath."""
searcher = self.searcher
search_alg_state = _find_newest_ckpt(dirpath,
self.CKPT_FILE_TMPL.format("*"))
if not search_alg_state:
raise RuntimeError(
"Unable to find checkpoint in {}.".format(dirpath))
while hasattr(searcher, "searcher"):
searcher_name = "name:" + type(searcher).__name__
if searcher_name not in search_alg_state:
names = [
key.split("name:")[1] for key in search_alg_state
if key.startswith("name:")
]
logger.warning("{} was not found in the experiment checkpoint "
"state when restoring. Found {}.".format(
searcher_name, names))
else:
searcher.set_state(search_alg_state.pop(searcher_name))
searcher = searcher.searcher
base_searcher = searcher
logger.debug(f"searching base {base_searcher}")
base_searcher.restore_from_dir(dirpath)
self.set_state(search_alg_state)
+43 -181
View File
@@ -1,23 +1,13 @@
import copy
import glob
import logging
import os
from ray.tune.error import TuneError
from ray.tune.experiment import convert_to_experiment_list
from ray.tune.config_parser import make_parser, create_trial_from_spec
from ray.tune.suggest.search import SearchAlgorithm
from ray.tune.suggest.variant_generator import format_vars, resolve_nested_dict
from ray.tune.trial import Trial
from ray.tune.utils import merge_dicts, flatten_dict
from ray.util.debug import log_once
logger = logging.getLogger(__name__)
def _warn_on_repeater(searcher, total_samples):
from ray.tune.suggest.repeater import _warn_num_samples
_warn_num_samples(searcher, total_samples)
class Searcher:
"""Abstract class for wrapping suggesting algorithms.
@@ -59,7 +49,7 @@ class Searcher:
"""
FINISHED = "FINISHED"
CKPT_FILE = "searcher-state.pkl"
CKPT_FILE_TMPL = "searcher-state-{}.pkl"
def __init__(self,
metric="episode_reward_mean",
@@ -186,23 +176,38 @@ class Searcher:
"""
raise NotImplementedError
def save_to_dir(self, checkpoint_dir):
def get_state(self):
raise NotImplementedError
def set_state(self, state):
raise NotImplementedError
def save_to_dir(self, checkpoint_dir, session_str="default"):
"""Automatically saves the given searcher to the checkpoint_dir.
This is automatically used by tune.run during a Tune job.
Args:
checkpoint_dir (str): Filepath to experiment dir.
session_str (str): Unique identifier of the current run
session.
"""
tmp_search_ckpt_path = os.path.join(checkpoint_dir,
".tmp_searcher_ckpt")
success = True
try:
self.save(tmp_search_ckpt_path)
except NotImplementedError as e:
logger.warning(e)
except NotImplementedError:
with log_once("suggest:save_to_dir"):
logger.warning(
"save not implemented for Searcher. Skipping save.")
success = False
if success and os.path.exists(tmp_search_ckpt_path):
os.rename(tmp_search_ckpt_path,
os.path.join(checkpoint_dir, Searcher.CKPT_FILE))
os.rename(
tmp_search_ckpt_path,
os.path.join(checkpoint_dir,
self.CKPT_FILE_TMPL.format(session_str)))
def restore_from_dir(self, checkpoint_dir):
"""Restores the state of a searcher from a given checkpoint_dir.
@@ -225,14 +230,14 @@ class Searcher:
os.path.join("~/my_results", self.experiment_name)
"""
checkpoint_path = os.path.join(checkpoint_dir, Searcher.CKPT_FILE)
if os.path.exists(checkpoint_path):
self.restore(checkpoint_path)
else:
raise FileNotFoundError(
"{filename} not found in {directory}. Unable to restore "
"searcher state from directory.".format(
filename=Searcher.CKPT_FILE, directory=checkpoint_dir))
pattern = self.CKPT_FILE_TMPL.format("*")
full_paths = glob.glob(os.path.join(checkpoint_dir, pattern))
if not full_paths:
raise RuntimeError(
"Searcher unable to find checkpoint in {}".format(
checkpoint_dir)) # TODO
most_recent_checkpoint = max(full_paths)
self.restore(most_recent_checkpoint)
@property
def metric(self):
@@ -271,7 +276,13 @@ class ConcurrencyLimiter(Searcher):
metric=self.searcher.metric, mode=self.searcher.mode)
def suggest(self, trial_id):
assert trial_id not in self.live_trials, (
f"Trial ID {trial_id} must be unique: already found in set.")
if len(self.live_trials) >= self.max_concurrent:
logger.debug(
f"Not providing a suggestion for {trial_id} due to "
"concurrency limit: %s/%s.", len(self.live_trials),
self.max_concurrent)
return
suggestion = self.searcher.suggest(trial_id)
if suggestion not in (None, Searcher.FINISHED):
@@ -286,159 +297,10 @@ class ConcurrencyLimiter(Searcher):
trial_id, result=result, error=error)
self.live_trials.remove(trial_id)
def save(self, checkpoint_dir):
self.searcher.save(checkpoint_dir)
def get_state(self):
state = self.__dict__.copy()
del state["searcher"]
return copy.deepcopy(state)
def restore(self, checkpoint_dir):
self.searcher.restore(checkpoint_dir)
class SearchGenerator(SearchAlgorithm):
"""Generates trials to be passed to the TrialRunner.
Uses the provided ``searcher`` object to generate trials. This class
transparently handles repeating trials with score aggregation
without embedding logic into the Searcher.
Args:
searcher: Search object that subclasses the Searcher base class. This
is then used for generating new hyperparameter samples.
"""
def __init__(self, searcher):
assert issubclass(
type(searcher),
Searcher), ("Searcher should be subclassing Searcher.")
self.searcher = searcher
self._parser = make_parser()
self._experiment = None
self._counter = 0 # Keeps track of number of trials created.
self._total_samples = 0 # int: total samples to evaluate.
self._finished = False
def add_configurations(self, experiments):
"""Registers experiment specifications.
Arguments:
experiments (Experiment | list | dict): Experiments to run.
"""
logger.debug("added configurations")
experiment_list = convert_to_experiment_list(experiments)
assert len(experiment_list) == 1, (
"SearchAlgorithms can only support 1 experiment at a time.")
self._experiment = experiment_list[0]
experiment_spec = self._experiment.spec
self._total_samples = experiment_spec.get("num_samples", 1)
_warn_on_repeater(self.searcher, self._total_samples)
if "run" not in experiment_spec:
raise TuneError("Must specify `run` in {}".format(experiment_spec))
def next_trials(self):
"""Provides a batch of Trial objects to be queued into the TrialRunner.
Returns:
List[Trial]: A list of trials for the Runner to consume.
"""
trials = []
while not self.is_finished():
trial = self.create_trial_if_possible(self._experiment.spec,
self._experiment.name)
if trial is None:
break
trials.append(trial)
return trials
def create_trial_if_possible(self, experiment_spec, output_path):
logger.debug("creating trial")
trial_id = Trial.generate_id()
suggested_config = self.searcher.suggest(trial_id)
if suggested_config == Searcher.FINISHED:
self._finished = True
logger.debug("Searcher has finished.")
return
if suggested_config is None:
return
spec = copy.deepcopy(experiment_spec)
spec["config"] = merge_dicts(spec["config"],
copy.deepcopy(suggested_config))
# Create a new trial_id if duplicate trial is created
flattened_config = resolve_nested_dict(spec["config"])
self._counter += 1
tag = "{0}_{1}".format(
str(self._counter), format_vars(flattened_config))
trial = create_trial_from_spec(
spec,
output_path,
self._parser,
evaluated_params=flatten_dict(suggested_config),
experiment_tag=tag,
trial_id=trial_id)
return trial
def on_trial_result(self, trial_id, result):
"""Notifies the underlying searcher."""
self.searcher.on_trial_result(trial_id, result)
def on_trial_complete(self, trial_id, result=None, error=False):
self.searcher.on_trial_complete(
trial_id=trial_id, result=result, error=error)
def is_finished(self):
return self._counter >= self._total_samples or self._finished
def save(self, checkpoint_path):
self.searcher.save(checkpoint_path)
def restore(self, checkpoint_path):
self.searcher.restore(checkpoint_path)
class _MockSearcher(Searcher):
def __init__(self, **kwargs):
self.live_trials = {}
self.counter = {"result": 0, "complete": 0}
self.final_results = []
self.stall = False
self.results = []
super(_MockSearcher, self).__init__(**kwargs)
def suggest(self, trial_id):
if not self.stall:
self.live_trials[trial_id] = 1
return {"test_variable": 2}
return None
def on_trial_result(self, trial_id, result):
self.counter["result"] += 1
self.results += [result]
def on_trial_complete(self, trial_id, result=None, error=False):
self.counter["complete"] += 1
if result:
self._process_result(result)
if trial_id in self.live_trials:
del self.live_trials[trial_id]
def _process_result(self, result):
self.final_results += [result]
class _MockSuggestionAlgorithm(SearchGenerator):
def __init__(self, max_concurrent=None, **kwargs):
self.searcher = _MockSearcher(**kwargs)
if max_concurrent:
self.searcher = ConcurrencyLimiter(
self.searcher, max_concurrent=max_concurrent)
super(_MockSuggestionAlgorithm, self).__init__(self.searcher)
@property
def live_trials(self):
return self.searcher.live_trials
@property
def results(self):
return self.searcher.results
def set_state(self, state):
self.__dict__.update(state)
@@ -0,0 +1,56 @@
import argparse
from ray.tune import run
from ray.tune.examples.async_hyperband_example import MyTrainableClass
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.suggest.suggestion import ConcurrencyLimiter
from hyperopt import hp
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="PyTorch Example (FOR TEST ONLY)")
parser.add_argument(
"--resume", action="store_true", help="Resuming from checkpoint.")
parser.add_argument("--local-dir", help="Checkpoint path")
parser.add_argument(
"--ray-address",
help="Address of Ray cluster for seamless distributed execution.")
args = parser.parse_args()
space = {
"width": hp.uniform("width", 0, 20),
"height": hp.uniform("height", -100, 100),
"activation": hp.choice("activation", ["relu", "tanh"])
}
current_best_params = [
{
"width": 1,
"height": 2,
"activation": 0 # Activation will be relu
},
{
"width": 4,
"height": 2,
"activation": 1 # Activation will be tanh
}
]
algo = HyperOptSearch(
space,
metric="episode_reward_mean",
mode="max",
random_state_seed=5,
points_to_evaluate=current_best_params)
algo = ConcurrencyLimiter(algo, max_concurrent=1)
from ray.tune import register_trainable
register_trainable("trainable", MyTrainableClass)
run("trainable",
search_alg=algo,
global_checkpoint_period=0,
resume=args.resume,
verbose=0,
num_samples=20,
fail_fast=True,
stop={"training_iteration": 2},
local_dir=args.local_dir,
name="experiment")
+1 -1
View File
@@ -23,7 +23,7 @@ from ray.tune.logger import Logger
from ray.tune.experiment import Experiment
from ray.tune.resources import Resources
from ray.tune.suggest import grid_search
from ray.tune.suggest.suggestion import _MockSuggestionAlgorithm
from ray.tune.suggest._mock import _MockSuggestionAlgorithm
from ray.tune.utils import (flatten_dict, get_pinned_object,
pin_in_object_store)
from ray.tune.utils.mock import mock_storage_client, MOCK_REMOTE_DIR
+75
View File
@@ -4,6 +4,7 @@ import time
import os
import pytest
import shutil
import subprocess
import sys
from unittest.mock import MagicMock, patch
@@ -713,6 +714,80 @@ tune.run(
cluster.shutdown()
def test_cluster_interrupt_searcher(start_connected_cluster, tmpdir):
"""Tests restoration of HyperOptSearch experiment on cluster shutdown
with actual interrupt.
Restoration should restore both state of trials
and previous search algorithm (HyperOptSearch) state.
This is an end-to-end test.
"""
cluster = start_connected_cluster
dirpath = str(tmpdir)
local_checkpoint_dir = os.path.join(dirpath, "experiment")
from ray.tune.examples.async_hyperband_example import MyTrainableClass
from ray.tune import register_trainable
register_trainable("trainable", MyTrainableClass)
def execute_script_with_args(*args):
current_dir = os.path.dirname(__file__)
script = os.path.join(current_dir,
"_test_cluster_interrupt_searcher.py")
subprocess.Popen([sys.executable, script] + list(args))
args = ["--ray-address", cluster.address, "--local-dir", dirpath]
execute_script_with_args(*args)
# Wait until the right checkpoint is saved.
# The trainable returns every 0.5 seconds, so this should not miss
# the checkpoint.
for i in range(50):
if TrialRunner.checkpoint_exists(local_checkpoint_dir):
# Inspect the internal trialrunner
runner = TrialRunner(
resume="LOCAL", local_checkpoint_dir=local_checkpoint_dir)
trials = runner.get_trials()
if trials and len(trials) >= 10:
break
time.sleep(.5)
if not TrialRunner.checkpoint_exists(local_checkpoint_dir):
raise RuntimeError(
f"Checkpoint file didn't appear in {local_checkpoint_dir}. "
f"Current list: {os.listdir(local_checkpoint_dir)}.")
ray.shutdown()
cluster.shutdown()
cluster = _start_new_cluster()
execute_script_with_args(*(args + ["--resume"]))
time.sleep(2)
register_trainable("trainable", MyTrainableClass)
reached = False
for i in range(50):
if TrialRunner.checkpoint_exists(local_checkpoint_dir):
# Inspect the internal trialrunner
runner = TrialRunner(
resume="LOCAL", local_checkpoint_dir=local_checkpoint_dir)
trials = runner.get_trials()
if len(trials) == 0:
continue # nonblocking script hasn't resumed yet, wait
reached = True
assert len(trials) >= 10
assert len(trials) <= 20
if len(trials) == 20:
break
else:
stop_fn = runner.trial_executor.stop_trial
[stop_fn(t) for t in trials if t.status is not Trial.ERROR]
time.sleep(.5)
assert reached is True
ray.shutdown()
cluster.shutdown()
if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", __file__]))
+177 -2
View File
@@ -1,4 +1,6 @@
from collections import Counter
import os
import pickle
import shutil
import sys
import tempfile
@@ -14,14 +16,17 @@ from ray.tune.trial import Trial
from ray.tune.trial_runner import TrialRunner
from ray.tune.resources import Resources, json_to_resources, resources_to_json
from ray.tune.suggest.repeater import Repeater
from ray.tune.suggest.suggestion import (_MockSuggestionAlgorithm,
SearchGenerator, Searcher)
from ray.tune.suggest._mock import _MockSuggestionAlgorithm
from ray.tune.suggest.suggestion import Searcher, ConcurrencyLimiter
from ray.tune.suggest.search_generator import SearchGenerator
class TrialRunnerTest3(unittest.TestCase):
def tearDown(self):
ray.shutdown()
_register_all() # re-register the evicted objects
if "CUDA_VISIBLE_DEVICES" in os.environ:
del os.environ["CUDA_VISIBLE_DEVICES"]
def testStepHook(self):
ray.init(num_cpus=4, num_gpus=2)
@@ -264,6 +269,92 @@ class TrialRunnerTest3(unittest.TestCase):
self.assertTrue(searcher.is_finished())
self.assertRaises(TuneError, runner.step)
def testSearcherSaveRestore(self):
ray.init(num_cpus=8, local_mode=True)
tmpdir = tempfile.mkdtemp()
def create_searcher():
class TestSuggestion(Searcher):
def __init__(self, index):
self.index = index
self.returned_result = []
super().__init__(metric="result", mode="max")
def suggest(self, trial_id):
self.index += 1
return {"test_variable": self.index}
def on_trial_complete(self, trial_id, result=None, **kwargs):
self.returned_result.append(result)
def save(self, checkpoint_path):
with open(checkpoint_path, "wb") as f:
pickle.dump(self.__dict__, f)
def restore(self, checkpoint_path):
with open(checkpoint_path, "rb") as f:
self.__dict__.update(pickle.load(f))
searcher = TestSuggestion(0)
searcher = ConcurrencyLimiter(searcher, max_concurrent=2)
searcher = Repeater(searcher, repeat=3, set_index=False)
search_alg = SearchGenerator(searcher)
experiment_spec = {
"run": "__fake",
"num_samples": 20,
"stop": {
"training_iteration": 2
}
}
experiments = [Experiment.from_json("test", experiment_spec)]
search_alg.add_configurations(experiments)
return search_alg
searcher = create_searcher()
runner = TrialRunner(
search_alg=searcher,
local_checkpoint_dir=tmpdir,
checkpoint_period=-1)
for i in range(6):
runner.step()
assert len(
runner.get_trials()) == 6, [t.config for t in runner.get_trials()]
runner.checkpoint()
trials = runner.get_trials()
[
runner.trial_executor.stop_trial(t) for t in trials
if t.status is not Trial.ERROR
]
del runner
# stop_all(runner.get_trials())
searcher = create_searcher()
runner2 = TrialRunner(
search_alg=searcher, local_checkpoint_dir=tmpdir, resume="LOCAL")
assert len(runner2.get_trials()) == 6, [
t.config for t in runner2.get_trials()
]
def trial_statuses():
return [t.status for t in runner2.get_trials()]
def num_running_trials():
return sum(t.status == Trial.RUNNING for t in runner2.get_trials())
for i in range(6):
runner2.step()
assert len(set(trial_statuses())) == 1
assert Trial.RUNNING in trial_statuses()
for i in range(20):
runner2.step()
assert 1 <= num_running_trials() <= 6
evaluated = [
t.evaluated_params["test_variable"] for t in runner2.get_trials()
]
count = Counter(evaluated)
assert all(v <= 3 for v in count.values())
def testTrialSaveRestore(self):
"""Creates different trials to test runner.checkpoint/restore."""
ray.init(num_cpus=3)
@@ -512,6 +603,90 @@ class SearchAlgorithmTest(unittest.TestCase):
parameter_set = {t.evaluated_params["test_variable"] for t in trials}
self.assertEquals(len(parameter_set), 3)
def testSetGetRepeater(self):
ray.init(num_cpus=4)
class TestSuggestion(Searcher):
def __init__(self, index):
self.index = index
self.returned_result = []
super().__init__(metric="result", mode="max")
def suggest(self, trial_id):
self.index += 1
return {"score": self.index}
def on_trial_complete(self, trial_id, result=None, **kwargs):
self.returned_result.append(result)
searcher = TestSuggestion(0)
repeater1 = Repeater(searcher, repeat=3, set_index=False)
for i in range(3):
assert repeater1.suggest(f"test_{i}")["score"] == 1
for i in range(2): # An incomplete set of results
assert repeater1.suggest(f"test_{i}_2")["score"] == 2
# Restore a new one
state = repeater1.get_state()
del repeater1
new_repeater = Repeater(searcher, repeat=1, set_index=True)
new_repeater.set_state(state)
assert new_repeater.repeat == 3
assert new_repeater.suggest("test_2_2")["score"] == 2
assert new_repeater.suggest("test_x")["score"] == 3
# Report results
for i in range(3):
new_repeater.on_trial_complete(f"test_{i}", {"result": 2})
for i in range(3):
new_repeater.on_trial_complete(f"test_{i}_2", {"result": -i * 10})
assert len(new_repeater.searcher.returned_result) == 2
assert new_repeater.searcher.returned_result[-1] == {"result": -10}
# Finish the rest of the last trial group
new_repeater.on_trial_complete("test_x", {"result": 3})
assert new_repeater.suggest("test_y")["score"] == 3
new_repeater.on_trial_complete("test_y", {"result": 3})
assert len(new_repeater.searcher.returned_result) == 2
assert new_repeater.suggest("test_z")["score"] == 3
new_repeater.on_trial_complete("test_z", {"result": 3})
assert len(new_repeater.searcher.returned_result) == 3
assert new_repeater.searcher.returned_result[-1] == {"result": 3}
def testSetGetLimiter(self):
ray.init(num_cpus=4)
class TestSuggestion(Searcher):
def __init__(self, index):
self.index = index
self.returned_result = []
super().__init__(metric="result", mode="max")
def suggest(self, trial_id):
self.index += 1
return {"score": self.index}
def on_trial_complete(self, trial_id, result=None, **kwargs):
self.returned_result.append(result)
searcher = TestSuggestion(0)
limiter = ConcurrencyLimiter(searcher, max_concurrent=2)
assert limiter.suggest("test_1")["score"] == 1
assert limiter.suggest("test_2")["score"] == 2
assert limiter.suggest("test_3") is None
state = limiter.get_state()
del limiter
limiter2 = ConcurrencyLimiter(searcher, max_concurrent=3)
limiter2.set_state(state)
assert limiter2.suggest("test_4") is None
assert limiter2.suggest("test_5") is None
limiter2.on_trial_complete("test_1", {"result": 3})
limiter2.on_trial_complete("test_2", {"result": 3})
assert limiter2.suggest("test_3")["score"] == 3
class ResourcesTest(unittest.TestCase):
def testSubtraction(self):
+18 -14
View File
@@ -17,7 +17,7 @@ from ray.tune.result import (TIME_THIS_ITER_S, RESULT_DUPLICATE,
from ray.tune.syncer import get_cloud_syncer
from ray.tune.trial import Checkpoint, Trial
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.suggest import BasicVariantGenerator, Searcher
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.utils import warn_if_slow, flatten_dict
from ray.tune.web_server import TuneServer
from ray.utils import binary_to_hex, hex_to_binary
@@ -168,9 +168,13 @@ class TrialRunner:
self.resume()
logger.info("Resuming trial.")
self._resumed = True
except Exception:
logger.exception(
"Runner restore failed. Restarting experiment.")
except Exception as e:
if self._verbose:
logger.error(str(e))
logger.exception("Runner restore failed.")
if self._fail_fast:
raise
logger.info("Restarting experiment.")
else:
logger.debug("Starting a new experiment.")
@@ -185,6 +189,10 @@ class TrialRunner:
self._local_checkpoint_dir,
TrialRunner.CKPT_FILE_TMPL.format(self._session_str))
@property
def resumed(self):
return self._resumed
@property
def scheduler_alg(self):
return self._scheduler_alg
@@ -240,12 +248,6 @@ class TrialRunner:
(fname.startswith("experiment_state") and fname.endswith(".json"))
for fname in os.listdir(directory))
def add_experiment(self, experiment):
if not self._resumed:
self._search_alg.add_configurations([experiment])
else:
logger.info("TrialRunner resumed, ignoring new add_experiment.")
def checkpoint(self, force=False):
"""Saves execution state to `self._local_checkpoint_dir`.
@@ -280,8 +282,8 @@ class TrialRunner:
json.dump(runner_state, f, indent=2, cls=_TuneFunctionEncoder)
os.replace(tmp_file_name, self.checkpoint_file)
Searcher.save_to_dir(self._search_alg, self._local_checkpoint_dir)
self._search_alg.save_to_dir(
self._local_checkpoint_dir, session_str=self._session_str)
if force:
self._syncer.sync_up()
@@ -299,14 +301,16 @@ class TrialRunner:
with open(newest_ckpt_path, "r") as f:
runner_state = json.load(f, cls=_TuneFunctionDecoder)
self.checkpoint_file = newest_ckpt_path
logger.warning("".join([
"Attempting to resume experiment from {}. ".format(
self._local_checkpoint_dir), "This feature is experimental, "
"and may not work with all search algorithms. ",
self._local_checkpoint_dir),
"This will ignore any new changes to the specification."
]))
self.__setstate__(runner_state["runner_data"])
if self._search_alg.has_checkpoint(self._local_checkpoint_dir):
self._search_alg.restore_from_dir(self._local_checkpoint_dir)
trials = []
for trial_cp in runner_state["checkpoints"]:
+11 -5
View File
@@ -3,8 +3,8 @@ import logging
from ray.tune.error import TuneError
from ray.tune.experiment import convert_to_experiment_list, Experiment
from ray.tune.analysis import ExperimentAnalysis
from ray.tune.suggest import BasicVariantGenerator
from ray.tune.suggest.suggestion import Searcher, SearchGenerator
from ray.tune.suggest import BasicVariantGenerator, SearchGenerator
from ray.tune.suggest.suggestion import Searcher
from ray.tune.trial import Trial
from ray.tune.trainable import Trainable
from ray.tune.ray_trial_executor import RayTrialExecutor
@@ -300,8 +300,11 @@ def run(run_or_experiment,
if issubclass(type(search_alg), Searcher):
search_alg = SearchGenerator(search_alg)
if not search_alg:
search_alg = BasicVariantGenerator()
runner = TrialRunner(
search_alg=search_alg or BasicVariantGenerator(),
search_alg=search_alg,
scheduler=scheduler or FIFOScheduler(),
local_checkpoint_dir=experiments[0].checkpoint_dir,
remote_checkpoint_dir=experiments[0].remote_checkpoint_dir,
@@ -315,8 +318,11 @@ def run(run_or_experiment,
fail_fast=fail_fast,
trial_executor=trial_executor)
for exp in experiments:
runner.add_experiment(exp)
if not runner.resumed:
for exp in experiments:
search_alg.add_configurations([exp])
else:
logger.info("TrialRunner resumed, ignoring new add_experiment.")
if progress_reporter is None:
if IS_NOTEBOOK: