[Tune] Synchronous Mode for PBT (#10283)

This commit is contained in:
Amog Kamsetty
2020-08-31 00:00:47 -07:00
committed by GitHub
parent 05fe6dc278
commit afde3db4f0
4 changed files with 558 additions and 87 deletions
+123 -29
View File
@@ -27,10 +27,12 @@ class PBTTrialState:
self.last_score = None
self.last_checkpoint = None
self.last_perturbation_time = 0
self.last_train_time = 0 # Used for synchronous mode.
self.last_result = None # Used for synchronous mode.
def __repr__(self):
return str((self.last_score, self.last_checkpoint,
self.last_perturbation_time))
self.last_train_time, self.last_perturbation_time))
def explore(config, mutations, resample_probability, custom_explore_fn):
@@ -174,6 +176,13 @@ class PopulationBasedTraining(FIFOScheduler):
require_attrs (bool): Whether to require time_attr and metric to appear
in result for every iteration. If True, error will be raised
if these values are not present in trial result.
synch (bool): If False, will use asynchronous implementation of
PBT. Trial perturbations occur every perturbation_interval for each
trial independently. If True, will use synchronous implementation
of PBT. Perturbations will occur only after all trials are
synced at the same time_attr every perturbation_interval.
Defaults to False. See Appendix A.1 here
https://arxiv.org/pdf/1711.09846.pdf.
.. code-block:: python
@@ -215,7 +224,8 @@ class PopulationBasedTraining(FIFOScheduler):
resample_probability=0.25,
custom_explore_fn=None,
log_config=True,
require_attrs=True):
require_attrs=True,
synch=False):
for value in hyperparam_mutations.values():
if not (isinstance(value,
(list, dict, sample_from)) or callable(value)):
@@ -234,10 +244,15 @@ class PopulationBasedTraining(FIFOScheduler):
"`custom_explore_fn` to use PBT.")
if quantile_fraction > 0.5 or quantile_fraction < 0:
raise TuneError(
raise ValueError(
"You must set `quantile_fraction` to a value between 0 and"
"0.5. Current value: '{}'".format(quantile_fraction))
if perturbation_interval <= 0:
raise ValueError(
"perturbation_interval must be a positive number greater "
"than 0. Current value: '{}'".format(perturbation_interval))
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
if reward_attr is not None:
@@ -263,6 +278,8 @@ class PopulationBasedTraining(FIFOScheduler):
self._custom_explore_fn = custom_explore_fn
self._log_config = log_config
self._require_attrs = require_attrs
self._synch = synch
self._next_perturbation_sync = self._perturbation_interval
# Metrics
self._num_checkpoints = 0
@@ -320,34 +337,90 @@ class PopulationBasedTraining(FIFOScheduler):
time = result[self._time_attr]
state = self._trial_state[trial]
# Continue training if perturbation interval has not been reached yet.
if time - state.last_perturbation_time < self._perturbation_interval:
return TrialScheduler.CONTINUE # avoid checkpoint overhead
# This trial has reached its perturbation interval
score = self._metric_op * result[self._metric]
state.last_score = score
state.last_perturbation_time = time
lower_quantile, upper_quantile = self._quantiles()
state.last_train_time = time
state.last_result = result
if not self._synch:
state.last_perturbation_time = time
lower_quantile, upper_quantile = self._quantiles()
self._perturb_trial(trial, trial_runner, upper_quantile,
lower_quantile)
for trial in trial_runner.get_trials():
if trial.status in [Trial.PENDING, Trial.PAUSED]:
return TrialScheduler.PAUSE # yield time to other trials
return TrialScheduler.CONTINUE
else:
# Synchronous mode.
if any(self._trial_state[t].last_train_time <
self._next_perturbation_sync and t != trial
for t in trial_runner.get_trials()):
logger.debug("Pausing trial {}".format(trial))
else:
# All trials are synced at the same timestep.
lower_quantile, upper_quantile = self._quantiles()
all_trials = trial_runner.get_trials()
not_in_quantile = []
for t in all_trials:
if t not in lower_quantile and t not in upper_quantile:
not_in_quantile.append(t)
# Move upper quantile trials to beginning and lower quantile
# to end. This ensures that checkpointing of strong trials
# occurs before exploiting of weaker ones.
all_trials = upper_quantile + not_in_quantile + lower_quantile
for t in all_trials:
logger.debug("Perturbing Trial {}".format(t))
self._trial_state[t].last_perturbation_time = time
self._perturb_trial(t, trial_runner, upper_quantile,
lower_quantile)
all_train_times = [
self._trial_state[trial].last_train_time
for trial in trial_runner.get_trials()
]
max_last_train_time = max(all_train_times)
self._next_perturbation_sync = max(
self._next_perturbation_sync + self._perturbation_interval,
max_last_train_time)
# In sync mode we should pause all trials once result comes in.
# Once a perturbation step happens for all trials, they should
# still all be paused.
# choose_trial_to_run will then pick the next trial to run out of
# the paused trials.
return TrialScheduler.PAUSE
def _perturb_trial(self, trial, trial_runner, upper_quantile,
lower_quantile):
"""Checkpoint if in upper quantile, exploits if in lower."""
state = self._trial_state[trial]
if trial in upper_quantile:
# The trial last result is only updated after the scheduler
# callback. So, we override with the current result.
state.last_checkpoint = trial_runner.trial_executor.save(
trial, Checkpoint.MEMORY, result=result)
logger.debug("Trial {} is in upper quantile".format(trial))
logger.debug("Checkpointing {}".format(trial))
if trial.status == Trial.PAUSED:
# Paused trial will always have an in-memory checkpoint.
state.last_checkpoint = trial.checkpoint
else:
state.last_checkpoint = trial_runner.trial_executor.save(
trial, Checkpoint.MEMORY, result=state.last_result)
self._num_checkpoints += 1
else:
state.last_checkpoint = None # not a top trial
if trial in lower_quantile:
logger.debug("Trial {} is in lower quantile".format(trial))
trial_to_clone = random.choice(upper_quantile)
assert trial is not trial_to_clone
self._exploit(trial_runner.trial_executor, trial, trial_to_clone)
for trial in trial_runner.get_trials():
if trial.status in [Trial.PENDING, Trial.PAUSED]:
return TrialScheduler.PAUSE # yield time to other trials
return TrialScheduler.CONTINUE
def _log_config_on_step(self, trial_state, new_state, trial,
trial_to_clone, new_config):
"""Logs transition during exploit/exploit step.
@@ -417,26 +490,40 @@ class PopulationBasedTraining(FIFOScheduler):
new_tag = make_experiment_tag(trial_state.orig_tag, new_config,
self._hyperparam_mutations)
reset_successful = trial_executor.reset_trial(trial, new_config,
new_tag)
# TODO(ujvl): Refactor Scheduler abstraction to abstract
# mechanism for trial restart away. We block on restore
# and suppress train on start as a stop-gap fix to
# https://github.com/ray-project/ray/issues/7258.
if reset_successful:
trial_executor.restore(
trial, new_state.last_checkpoint, block=True)
else:
trial_executor.stop_trial(trial, stop_logger=False)
if trial.status == Trial.PAUSED:
# If trial is paused we update it with a new checkpoint.
# When the trial is started again, the new checkpoint is used.
if not self._synch:
raise TuneError("Trials should be paused here only if in "
"synchronous mode. If you encounter this error"
" please raise an issue on Ray Github.")
trial.config = new_config
trial.experiment_tag = new_tag
trial_executor.start_trial(
trial, new_state.last_checkpoint, train=False)
trial.on_checkpoint(new_state.last_checkpoint)
else:
# If trial is running, we first try to reset it.
# If that is unsuccessful, then we have to stop it and start it
# again with a new checkpoint.
reset_successful = trial_executor.reset_trial(
trial, new_config, new_tag)
# TODO(ujvl): Refactor Scheduler abstraction to abstract
# mechanism for trial restart away. We block on restore
# and suppress train on start as a stop-gap fix to
# https://github.com/ray-project/ray/issues/7258.
if reset_successful:
trial_executor.restore(
trial, new_state.last_checkpoint, block=True)
else:
trial_executor.stop_trial(trial, stop_logger=False)
trial.config = new_config
trial.experiment_tag = new_tag
trial_executor.start_trial(
trial, new_state.last_checkpoint, train=False)
self._num_perturbations += 1
# Transfer over the last perturbation time as well
trial_state.last_perturbation_time = new_state.last_perturbation_time
trial_state.last_train_time = new_state.last_train_time
def _quantiles(self):
"""Returns trials in the lower and upper `quantile` of the population.
@@ -445,6 +532,9 @@ class PopulationBasedTraining(FIFOScheduler):
"""
trials = []
for trial, state in self._trial_state.items():
logger.debug("Trial {}, state {}".format(trial, state))
if trial.is_finished():
logger.debug("Trial {} is finished".format(trial))
if state.last_score is not None and not trial.is_finished():
trials.append(trial)
trials.sort(key=lambda t: self._trial_state[t].last_score)
@@ -469,9 +559,13 @@ class PopulationBasedTraining(FIFOScheduler):
for trial in trial_runner.get_trials():
if trial.status in [Trial.PENDING, Trial.PAUSED] and \
trial_runner.has_resources(trial.resources):
candidates.append(trial)
if not self._synch:
candidates.append(trial)
elif self._trial_state[trial].last_train_time < \
self._next_perturbation_sync:
candidates.append(trial)
candidates.sort(
key=lambda trial: self._trial_state[trial].last_perturbation_time)
key=lambda trial: self._trial_state[trial].last_train_time)
return candidates[0] if candidates else None
def reset_stats(self):
+308 -5
View File
@@ -241,6 +241,7 @@ class _MockTrialRunner():
return True
def _pause_trial(self, trial):
self.trial_executor.save(trial, Checkpoint.MEMORY, None)
trial.status = Trial.PAUSED
def _launch_trial(self, trial):
@@ -702,6 +703,13 @@ class _MockTrial(Trial):
self.custom_trial_name = None
self.custom_dirname = None
def on_checkpoint(self, checkpoint):
self.restored_checkpoint = checkpoint.value
@property
def checkpoint(self):
return Checkpoint(Checkpoint.MEMORY, self.trainable_name, None)
class PopulationBasedTestingSuite(unittest.TestCase):
def setUp(self):
@@ -720,7 +728,8 @@ class PopulationBasedTestingSuite(unittest.TestCase):
require_attrs=True,
hyperparams=None,
hyperparam_mutations=None,
step_once=True):
step_once=True,
synch=False):
hyperparam_mutations = hyperparam_mutations or {
"float_factor": lambda: 100.0,
"int_factor": lambda: 10,
@@ -734,7 +743,9 @@ class PopulationBasedTestingSuite(unittest.TestCase):
hyperparam_mutations=hyperparam_mutations,
custom_explore_fn=explore,
log_config=log_config,
require_attrs=require_attrs)
synch=synch,
require_attrs=require_attrs,
)
runner = _MockTrialRunner(pbt)
for i in range(num_trials):
trial_hyperparams = hyperparams or {
@@ -746,10 +757,17 @@ class PopulationBasedTestingSuite(unittest.TestCase):
trial = _MockTrial(i, trial_hyperparams)
runner.add_trial(trial)
trial.status = Trial.RUNNING
for i in range(num_trials):
trial = runner.trials[i]
if step_once:
self.assertEqual(
pbt.on_trial_result(runner, trial, result(10, 50 * i)),
TrialScheduler.CONTINUE)
if synch:
self.assertEqual(
pbt.on_trial_result(runner, trial, result(10, 50 * i)),
TrialScheduler.PAUSE)
else:
self.assertEqual(
pbt.on_trial_result(runner, trial, result(10, 50 * i)),
TrialScheduler.CONTINUE)
pbt.reset_stats()
return pbt, runner
@@ -822,6 +840,32 @@ class PopulationBasedTestingSuite(unittest.TestCase):
self.assertEqual(pbt._num_checkpoints, 2)
self.assertEqual(pbt._num_perturbations, 0)
def testCheckpointMostPromisingTrialsSynch(self):
pbt, runner = self.basicSetup(synch=True)
trials = runner.get_trials()
# no checkpoint: haven't hit next perturbation interval yet
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(15, 200)),
TrialScheduler.CONTINUE)
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertEqual(pbt._num_checkpoints, 0)
# trials should be paused until all trials are synced.
for i in range(len(trials) - 1):
self.assertEqual(
pbt.on_trial_result(runner, trials[i], result(20, 200 + i)),
TrialScheduler.PAUSE)
self.assertEqual(pbt.last_scores(trials), [200, 201, 202, 203, 200])
self.assertEqual(pbt._num_checkpoints, 0)
self.assertEqual(
pbt.on_trial_result(runner, trials[-1], result(20, 204)),
TrialScheduler.PAUSE)
self.assertEqual(pbt._num_checkpoints, 2)
def testPerturbsLowPerformingTrials(self):
pbt, runner = self.basicSetup()
trials = runner.get_trials()
@@ -852,6 +896,35 @@ class PopulationBasedTestingSuite(unittest.TestCase):
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
self.assertTrue("@perturbed" in trials[2].experiment_tag)
def testPerturbsLowPerformingTrialsSynch(self):
pbt, runner = self.basicSetup(synch=True)
trials = runner.get_trials()
# no perturbation: haven't hit next perturbation interval
self.assertEqual(
pbt.on_trial_result(runner, trials[-1], result(15, -100)),
TrialScheduler.CONTINUE)
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertTrue("@perturbed" not in trials[-1].experiment_tag)
self.assertEqual(pbt._num_perturbations, 0)
# Don't perturb until all trials are synched.
self.assertEqual(
pbt.on_trial_result(runner, trials[-1], result(20, -100)),
TrialScheduler.PAUSE)
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, -100])
self.assertTrue("@perturbed" not in trials[-1].experiment_tag)
# Synch all trials.
for i in range(len(trials) - 1):
self.assertEqual(
pbt.on_trial_result(runner, trials[i], result(20, -10 * i)),
TrialScheduler.PAUSE)
self.assertEqual(pbt.last_scores(trials), [0, -10, -20, -30, -100])
self.assertIn(trials[-1].restored_checkpoint, ["trial_0", "trial_1"])
self.assertIn(trials[-2].restored_checkpoint, ["trial_0", "trial_1"])
self.assertEqual(pbt._num_perturbations, 2)
def testPerturbWithoutResample(self):
pbt, runner = self.basicSetup(resample_prob=0.0)
trials = runner.get_trials()
@@ -1088,6 +1161,27 @@ class PopulationBasedTestingSuite(unittest.TestCase):
trials[i].status = Trial.PENDING
self.assertEqual(pbt.choose_trial_to_run(runner), trials[3])
def testSchedulesMostBehindTrialToRunSynch(self):
pbt, runner = self.basicSetup(synch=True)
trials = runner.get_trials()
runner.process_action(
trials[0], pbt.on_trial_result(runner, trials[0], result(
800, 1000)))
runner.process_action(
trials[1], pbt.on_trial_result(runner, trials[1], result(
700, 1001)))
runner.process_action(
trials[2], pbt.on_trial_result(runner, trials[2], result(
600, 1002)))
runner.process_action(
trials[3], pbt.on_trial_result(runner, trials[3], result(
500, 1003)))
runner.process_action(
trials[4], pbt.on_trial_result(runner, trials[4], result(
700, 1004)))
self.assertIn(
pbt.choose_trial_to_run(runner), [trials[0], trials[1], trials[3]])
def testPerturbationResetsLastPerturbTime(self):
pbt, runner = self.basicSetup()
trials = runner.get_trials()
@@ -1141,6 +1235,44 @@ class PopulationBasedTestingSuite(unittest.TestCase):
check_policy(json.loads(line))
shutil.rmtree(tmpdir)
def testLogConfigSynch(self):
def check_policy(policy):
self.assertIsInstance(policy[2], int)
self.assertIsInstance(policy[3], int)
self.assertIn(policy[0], ["0tag", "1tag"])
self.assertIn(policy[1], ["3tag", "4tag"])
self.assertIn(policy[2], [0, 1])
self.assertIn(policy[3], [3, 4])
for i in [4, 5]:
self.assertIsInstance(policy[i], dict)
for key in [
"const_factor", "int_factor", "float_factor",
"id_factor"
]:
self.assertIn(key, policy[i])
self.assertIsInstance(policy[i]["float_factor"], float)
self.assertIsInstance(policy[i]["int_factor"], int)
self.assertIn(policy[i]["const_factor"], [3])
self.assertIn(policy[i]["int_factor"], [8, 10, 12])
self.assertIn(policy[i]["float_factor"], [2.4, 2, 1.6])
self.assertIn(policy[i]["id_factor"], [3, 4, 100])
pbt, runner = self.basicSetup(
log_config=True, synch=True, step_once=False)
trials = runner.get_trials()
tmpdir = tempfile.mkdtemp()
for i, trial in enumerate(trials):
trial.local_dir = tmpdir
trial.last_result = {TRAINING_ITERATION: i}
pbt.on_trial_result(runner, trials[i], result(10, i))
log_files = ["pbt_global.txt", "pbt_policy_0.txt", "pbt_policy_1.txt"]
for log_file in log_files:
self.assertTrue(os.path.exists(os.path.join(tmpdir, log_file)))
raw_policy = open(os.path.join(tmpdir, log_file), "r").readlines()
for line in raw_policy:
check_policy(json.loads(line))
shutil.rmtree(tmpdir)
def testReplay(self):
# Returns unique increasing parameter mutations
class _Counter:
@@ -1156,6 +1288,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
perturbation_interval=5,
log_config=True,
step_once=False,
synch=False,
hyperparam_mutations={
"float_factor": lambda: 100.0,
"int_factor": _Counter(1000)
@@ -1292,6 +1425,176 @@ class PopulationBasedTestingSuite(unittest.TestCase):
shutil.rmtree(tmpdir)
def testReplaySynch(self):
# Returns unique increasing parameter mutations
class _Counter:
def __init__(self, start=0):
self.count = start - 1
def __call__(self, *args, **kwargs):
self.count += 1
return self.count
pbt, runner = self.basicSetup(
num_trials=4,
perturbation_interval=5,
log_config=True,
step_once=False,
synch=True,
hyperparam_mutations={
"float_factor": lambda: 100.0,
"int_factor": _Counter(1000)
})
trials = runner.get_trials()
tmpdir = tempfile.mkdtemp()
# Internal trial state to collect the real PBT history
class _TrialState:
def __init__(self, config):
self.step = 0
self.config = config
self.history = []
def forward(self, t):
while self.step < t:
self.history.append(self.config)
self.step += 1
trial_state = []
for i, trial in enumerate(trials):
trial.local_dir = tmpdir
trial.last_result = {TRAINING_ITERATION: 0}
trial_state.append(_TrialState(trial.config))
# Helper function to simulate stepping trial k a number of steps,
# and reporting a score at the end
def trial_step(k, steps, score, synced=False):
res = result(trial_state[k].step + steps, score)
trials[k].last_result = res
trial_state[k].forward(res[TRAINING_ITERATION])
trials[k].status = Trial.RUNNING
if not synced:
action = pbt.on_trial_result(runner, trials[k], res)
runner.process_action(trials[k], action)
return
else:
# Reached synchronization point
old_configs = [trial.config for trial in trials]
action = pbt.on_trial_result(runner, trials[k], res)
runner.process_action(trials[k], action)
new_configs = [trial.config for trial in trials]
for i in range(len(trials)):
old_config = old_configs[i]
new_config = new_configs[i]
if old_config != new_config:
# Copy history from source trial
source = -1
for m, cand in enumerate(trials):
if cand.trainable_name == trials[
i].restored_checkpoint:
source = m
break
assert source >= 0
trial_state[i].history = trial_state[
source].history.copy()
trial_state[i].step = trial_state[source].step
trial_state[i].config = new_config.copy()
# Initial steps
trial_step(0, 10, 0)
trial_step(1, 11, 10)
trial_step(2, 12, 0)
trial_step(3, 13, -1, synced=True)
# 3 <-- 1, new_t 11
# next_perturb_sync = 13
# Next block
trial_step(0, 17, -10) # 20
trial_step(2, 15, -20) # 20
trial_step(3, 16, 0) # 20
trial_step(1, 7, 1, synced=True) # 18
# 2 <-- 1, new_t=11+7=18
# next_perturb_sync = 20
# Next block
trial_step(2, 13, 0) # 31
trial_step(3, 14, 10) # 34
trial_step(0, 11, -1) # 31
trial_step(1, 12, 0, synced=True) # 30
# 0 <-- 3, new_t=11+9+14=34
# next_perturb_sync = 34
# Next block
trial_step(0, 6, 20) # 40
trial_step(3, 9, -40) # 43
trial_step(2, 8, -50) # 39
trial_step(1, 7, 30, synced=True) # 37
# 2 <-- 1, new_t=18+13+8=37
# next_perturb_sync = 43
# Playback trainable to collect configs at each step
class Playback(Trainable):
def setup(self, config):
self.config = config
self.replayed = []
self.iter = 0
def step(self):
self.iter += 1
self.replayed.append(self.config)
return {
"reward": 0,
"done": False,
"replayed": self.replayed,
TRAINING_ITERATION: self.iter
}
def reset_config(self, new_config):
self.config = new_config
return True
def save_checkpoint(self, tmp_checkpoint_dir):
return tmp_checkpoint_dir
def load_checkpoint(self, checkpoint):
pass
# Loop through all trials and check if PBT history is the
# same as the playback history
for i, trial in enumerate(trials):
if trial.trial_id in ["1"]: # Did not exploit anything
continue
replay = PopulationBasedTrainingReplay(
os.path.join(tmpdir,
"pbt_policy_{}.txt".format(trial.trial_id)))
analysis = tune.run(
Playback,
scheduler=replay,
stop={TRAINING_ITERATION: trial_state[i].step})
replayed = analysis.trials[0].last_result["replayed"]
self.assertSequenceEqual(trial_state[i].history, replayed)
# Trial 1 did not exploit anything and should raise an error
with self.assertRaises(ValueError):
replay = PopulationBasedTrainingReplay(
os.path.join(tmpdir,
"pbt_policy_{}.txt".format(trials[1].trial_id)))
tune.run(
Playback,
scheduler=replay,
stop={TRAINING_ITERATION: trial_state[1].step})
shutil.rmtree(tmpdir)
def testPostprocessingHook(self):
def explore(new_config):
new_config["id_factor"] = 42
+126 -53
View File
@@ -4,64 +4,13 @@ import pickle
import random
import unittest
import sys
import time
import ray
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining
class MockTrainable(tune.Trainable):
def setup(self, config):
self.iter = 0
self.a = config["a"]
self.b = config["b"]
self.c = config["c"]
def step(self):
self.iter += 1
return {"mean_accuracy": (self.a - self.iter) * self.b}
def save_checkpoint(self, tmp_checkpoint_dir):
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.mock")
with open(checkpoint_path, "wb") as fp:
pickle.dump((self.a, self.b, self.iter), fp)
return tmp_checkpoint_dir
def load_checkpoint(self, tmp_checkpoint_dir):
checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.mock")
with open(checkpoint_path, "rb") as fp:
self.a, self.b, self.iter = pickle.load(fp)
def MockTrainingFunc(config, checkpoint_dir=None):
iter = 0
a = config["a"]
b = config["b"]
if checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "model.mock")
with open(checkpoint_path, "rb") as fp:
a, b, iter = pickle.load(fp)
while True:
iter += 1
with tune.checkpoint_dir(step=iter) as checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "model.mock")
with open(checkpoint_path, "wb") as fp:
pickle.dump((a, b, iter), fp)
tune.report(mean_accuracy=(a - iter) * b)
def MockTrainingFunc2(config):
a = config["a"]
b = config["b"]
c1 = config["c"]["c1"]
c2 = config["c"]["c2"]
while True:
tune.report(mean_accuracy=a * b * (c1 + c2))
class MockParam(object):
def __init__(self, params):
self._params = params
@@ -73,6 +22,77 @@ class MockParam(object):
return val
class PopulationBasedTrainingSynchTest(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=2)
def MockTrainingFuncSync(config, checkpoint_dir=None):
iter = 0
if checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
with open(checkpoint_path, "rb") as fp:
a, iter = pickle.load(fp)
a = config["a"] # Use the new hyperparameter if perturbed.
while True:
iter += 1
with tune.checkpoint_dir(step=iter) as checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir,
"checkpoint")
with open(checkpoint_path, "wb") as fp:
pickle.dump((a, iter), fp)
# Score gets better every iteration.
time.sleep(1)
tune.report(mean_accuracy=iter + a, a=a)
self.MockTrainingFuncSync = MockTrainingFuncSync
def tearDown(self):
ray.shutdown()
def synchSetup(self, synch, param=[10, 20, 30]):
scheduler = PopulationBasedTraining(
time_attr="training_iteration",
metric="mean_accuracy",
mode="max",
perturbation_interval=1,
log_config=True,
hyperparam_mutations={"c": lambda: 1},
synch=synch)
param_a = MockParam(param)
random.seed(100)
np.random.seed(100)
analysis = tune.run(
self.MockTrainingFuncSync,
config={
"a": tune.sample_from(lambda _: param_a()),
"c": 1
},
fail_fast=True,
num_samples=3,
scheduler=scheduler,
name="testPBTSync",
stop={"training_iteration": 3},
)
return analysis
def testAsynchFail(self):
analysis = self.synchSetup(False)
self.assertTrue(any(analysis.dataframe()["mean_accuracy"] != 33))
def testSynchPass(self):
analysis = self.synchSetup(True)
self.assertTrue(all(analysis.dataframe()["mean_accuracy"] == 33))
def testSynchPassLast(self):
analysis = self.synchSetup(True, param=[30, 20, 10])
self.assertTrue(all(analysis.dataframe()["mean_accuracy"] == 33))
class PopulationBasedTrainingConfigTest(unittest.TestCase):
def setUp(self):
ray.init()
@@ -81,6 +101,15 @@ class PopulationBasedTrainingConfigTest(unittest.TestCase):
ray.shutdown()
def testNoConfig(self):
def MockTrainingFunc(config):
a = config["a"]
b = config["b"]
c1 = config["c"]["c1"]
c2 = config["c"]["c2"]
while True:
tune.report(mean_accuracy=a * b * (c1 + c2))
scheduler = PopulationBasedTraining(
time_attr="training_iteration",
metric="mean_accuracy",
@@ -97,7 +126,7 @@ class PopulationBasedTrainingConfigTest(unittest.TestCase):
)
tune.run(
MockTrainingFunc2,
MockTrainingFunc,
fail_fast=True,
num_samples=4,
scheduler=scheduler,
@@ -120,6 +149,31 @@ class PopulationBasedTrainingResumeTest(unittest.TestCase):
fix was not applied.
See issues #9036, #9036
"""
class MockTrainable(tune.Trainable):
def setup(self, config):
self.iter = 0
self.a = config["a"]
self.b = config["b"]
self.c = config["c"]
def step(self):
self.iter += 1
return {"mean_accuracy": (self.a - self.iter) * self.b}
def save_checkpoint(self, tmp_checkpoint_dir):
checkpoint_path = os.path.join(tmp_checkpoint_dir,
"model.mock")
with open(checkpoint_path, "wb") as fp:
pickle.dump((self.a, self.b, self.iter), fp)
return tmp_checkpoint_dir
def load_checkpoint(self, tmp_checkpoint_dir):
checkpoint_path = os.path.join(tmp_checkpoint_dir,
"model.mock")
with open(checkpoint_path, "rb") as fp:
self.a, self.b, self.iter = pickle.load(fp)
scheduler = PopulationBasedTraining(
time_attr="training_iteration",
metric="mean_accuracy",
@@ -151,6 +205,25 @@ class PopulationBasedTrainingResumeTest(unittest.TestCase):
stop={"training_iteration": 3})
def testPermutationContinuationFunc(self):
def MockTrainingFunc(config, checkpoint_dir=None):
iter = 0
a = config["a"]
b = config["b"]
if checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir, "model.mock")
with open(checkpoint_path, "rb") as fp:
a, b, iter = pickle.load(fp)
while True:
iter += 1
with tune.checkpoint_dir(step=iter) as checkpoint_dir:
checkpoint_path = os.path.join(checkpoint_dir,
"model.mock")
with open(checkpoint_path, "wb") as fp:
pickle.dump((a, b, iter), fp)
tune.report(mean_accuracy=(a - iter) * b)
scheduler = PopulationBasedTraining(
time_attr="training_iteration",
metric="mean_accuracy",
+1
View File
@@ -460,6 +460,7 @@ class TrialRunner:
self._update_trial_queue(blocking=wait_for_trial)
with warn_if_slow("choose_trial_to_run"):
trial = self._scheduler_alg.choose_trial_to_run(self)
logger.debug("Running trial {}".format(trial))
return trial
def _process_events(self):