diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index 588c247bd..fd19ff933 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -27,10 +27,12 @@ class PBTTrialState: self.last_score = None self.last_checkpoint = None self.last_perturbation_time = 0 + self.last_train_time = 0 # Used for synchronous mode. + self.last_result = None # Used for synchronous mode. def __repr__(self): return str((self.last_score, self.last_checkpoint, - self.last_perturbation_time)) + self.last_train_time, self.last_perturbation_time)) def explore(config, mutations, resample_probability, custom_explore_fn): @@ -174,6 +176,13 @@ class PopulationBasedTraining(FIFOScheduler): require_attrs (bool): Whether to require time_attr and metric to appear in result for every iteration. If True, error will be raised if these values are not present in trial result. + synch (bool): If False, will use asynchronous implementation of + PBT. Trial perturbations occur every perturbation_interval for each + trial independently. If True, will use synchronous implementation + of PBT. Perturbations will occur only after all trials are + synced at the same time_attr every perturbation_interval. + Defaults to False. See Appendix A.1 here + https://arxiv.org/pdf/1711.09846.pdf. .. code-block:: python @@ -215,7 +224,8 @@ class PopulationBasedTraining(FIFOScheduler): resample_probability=0.25, custom_explore_fn=None, log_config=True, - require_attrs=True): + require_attrs=True, + synch=False): for value in hyperparam_mutations.values(): if not (isinstance(value, (list, dict, sample_from)) or callable(value)): @@ -234,10 +244,15 @@ class PopulationBasedTraining(FIFOScheduler): "`custom_explore_fn` to use PBT.") if quantile_fraction > 0.5 or quantile_fraction < 0: - raise TuneError( + raise ValueError( "You must set `quantile_fraction` to a value between 0 and" "0.5. Current value: '{}'".format(quantile_fraction)) + if perturbation_interval <= 0: + raise ValueError( + "perturbation_interval must be a positive number greater " + "than 0. Current value: '{}'".format(perturbation_interval)) + assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!" if reward_attr is not None: @@ -263,6 +278,8 @@ class PopulationBasedTraining(FIFOScheduler): self._custom_explore_fn = custom_explore_fn self._log_config = log_config self._require_attrs = require_attrs + self._synch = synch + self._next_perturbation_sync = self._perturbation_interval # Metrics self._num_checkpoints = 0 @@ -320,34 +337,90 @@ class PopulationBasedTraining(FIFOScheduler): time = result[self._time_attr] state = self._trial_state[trial] + # Continue training if perturbation interval has not been reached yet. if time - state.last_perturbation_time < self._perturbation_interval: return TrialScheduler.CONTINUE # avoid checkpoint overhead + # This trial has reached its perturbation interval score = self._metric_op * result[self._metric] state.last_score = score - state.last_perturbation_time = time - lower_quantile, upper_quantile = self._quantiles() + state.last_train_time = time + state.last_result = result + if not self._synch: + state.last_perturbation_time = time + lower_quantile, upper_quantile = self._quantiles() + self._perturb_trial(trial, trial_runner, upper_quantile, + lower_quantile) + for trial in trial_runner.get_trials(): + if trial.status in [Trial.PENDING, Trial.PAUSED]: + return TrialScheduler.PAUSE # yield time to other trials + + return TrialScheduler.CONTINUE + else: + # Synchronous mode. + if any(self._trial_state[t].last_train_time < + self._next_perturbation_sync and t != trial + for t in trial_runner.get_trials()): + logger.debug("Pausing trial {}".format(trial)) + else: + # All trials are synced at the same timestep. + lower_quantile, upper_quantile = self._quantiles() + all_trials = trial_runner.get_trials() + not_in_quantile = [] + for t in all_trials: + if t not in lower_quantile and t not in upper_quantile: + not_in_quantile.append(t) + # Move upper quantile trials to beginning and lower quantile + # to end. This ensures that checkpointing of strong trials + # occurs before exploiting of weaker ones. + all_trials = upper_quantile + not_in_quantile + lower_quantile + for t in all_trials: + logger.debug("Perturbing Trial {}".format(t)) + self._trial_state[t].last_perturbation_time = time + self._perturb_trial(t, trial_runner, upper_quantile, + lower_quantile) + + all_train_times = [ + self._trial_state[trial].last_train_time + for trial in trial_runner.get_trials() + ] + max_last_train_time = max(all_train_times) + self._next_perturbation_sync = max( + self._next_perturbation_sync + self._perturbation_interval, + max_last_train_time) + # In sync mode we should pause all trials once result comes in. + # Once a perturbation step happens for all trials, they should + # still all be paused. + # choose_trial_to_run will then pick the next trial to run out of + # the paused trials. + return TrialScheduler.PAUSE + + def _perturb_trial(self, trial, trial_runner, upper_quantile, + lower_quantile): + """Checkpoint if in upper quantile, exploits if in lower.""" + state = self._trial_state[trial] if trial in upper_quantile: # The trial last result is only updated after the scheduler # callback. So, we override with the current result. - state.last_checkpoint = trial_runner.trial_executor.save( - trial, Checkpoint.MEMORY, result=result) + logger.debug("Trial {} is in upper quantile".format(trial)) + logger.debug("Checkpointing {}".format(trial)) + if trial.status == Trial.PAUSED: + # Paused trial will always have an in-memory checkpoint. + state.last_checkpoint = trial.checkpoint + else: + state.last_checkpoint = trial_runner.trial_executor.save( + trial, Checkpoint.MEMORY, result=state.last_result) self._num_checkpoints += 1 else: state.last_checkpoint = None # not a top trial if trial in lower_quantile: + logger.debug("Trial {} is in lower quantile".format(trial)) trial_to_clone = random.choice(upper_quantile) assert trial is not trial_to_clone self._exploit(trial_runner.trial_executor, trial, trial_to_clone) - for trial in trial_runner.get_trials(): - if trial.status in [Trial.PENDING, Trial.PAUSED]: - return TrialScheduler.PAUSE # yield time to other trials - - return TrialScheduler.CONTINUE - def _log_config_on_step(self, trial_state, new_state, trial, trial_to_clone, new_config): """Logs transition during exploit/exploit step. @@ -417,26 +490,40 @@ class PopulationBasedTraining(FIFOScheduler): new_tag = make_experiment_tag(trial_state.orig_tag, new_config, self._hyperparam_mutations) - reset_successful = trial_executor.reset_trial(trial, new_config, - new_tag) - - # TODO(ujvl): Refactor Scheduler abstraction to abstract - # mechanism for trial restart away. We block on restore - # and suppress train on start as a stop-gap fix to - # https://github.com/ray-project/ray/issues/7258. - if reset_successful: - trial_executor.restore( - trial, new_state.last_checkpoint, block=True) - else: - trial_executor.stop_trial(trial, stop_logger=False) + if trial.status == Trial.PAUSED: + # If trial is paused we update it with a new checkpoint. + # When the trial is started again, the new checkpoint is used. + if not self._synch: + raise TuneError("Trials should be paused here only if in " + "synchronous mode. If you encounter this error" + " please raise an issue on Ray Github.") trial.config = new_config trial.experiment_tag = new_tag - trial_executor.start_trial( - trial, new_state.last_checkpoint, train=False) + trial.on_checkpoint(new_state.last_checkpoint) + else: + # If trial is running, we first try to reset it. + # If that is unsuccessful, then we have to stop it and start it + # again with a new checkpoint. + reset_successful = trial_executor.reset_trial( + trial, new_config, new_tag) + # TODO(ujvl): Refactor Scheduler abstraction to abstract + # mechanism for trial restart away. We block on restore + # and suppress train on start as a stop-gap fix to + # https://github.com/ray-project/ray/issues/7258. + if reset_successful: + trial_executor.restore( + trial, new_state.last_checkpoint, block=True) + else: + trial_executor.stop_trial(trial, stop_logger=False) + trial.config = new_config + trial.experiment_tag = new_tag + trial_executor.start_trial( + trial, new_state.last_checkpoint, train=False) self._num_perturbations += 1 # Transfer over the last perturbation time as well trial_state.last_perturbation_time = new_state.last_perturbation_time + trial_state.last_train_time = new_state.last_train_time def _quantiles(self): """Returns trials in the lower and upper `quantile` of the population. @@ -445,6 +532,9 @@ class PopulationBasedTraining(FIFOScheduler): """ trials = [] for trial, state in self._trial_state.items(): + logger.debug("Trial {}, state {}".format(trial, state)) + if trial.is_finished(): + logger.debug("Trial {} is finished".format(trial)) if state.last_score is not None and not trial.is_finished(): trials.append(trial) trials.sort(key=lambda t: self._trial_state[t].last_score) @@ -469,9 +559,13 @@ class PopulationBasedTraining(FIFOScheduler): for trial in trial_runner.get_trials(): if trial.status in [Trial.PENDING, Trial.PAUSED] and \ trial_runner.has_resources(trial.resources): - candidates.append(trial) + if not self._synch: + candidates.append(trial) + elif self._trial_state[trial].last_train_time < \ + self._next_perturbation_sync: + candidates.append(trial) candidates.sort( - key=lambda trial: self._trial_state[trial].last_perturbation_time) + key=lambda trial: self._trial_state[trial].last_train_time) return candidates[0] if candidates else None def reset_stats(self): diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index 30c3649bc..ec7e96c5b 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -241,6 +241,7 @@ class _MockTrialRunner(): return True def _pause_trial(self, trial): + self.trial_executor.save(trial, Checkpoint.MEMORY, None) trial.status = Trial.PAUSED def _launch_trial(self, trial): @@ -702,6 +703,13 @@ class _MockTrial(Trial): self.custom_trial_name = None self.custom_dirname = None + def on_checkpoint(self, checkpoint): + self.restored_checkpoint = checkpoint.value + + @property + def checkpoint(self): + return Checkpoint(Checkpoint.MEMORY, self.trainable_name, None) + class PopulationBasedTestingSuite(unittest.TestCase): def setUp(self): @@ -720,7 +728,8 @@ class PopulationBasedTestingSuite(unittest.TestCase): require_attrs=True, hyperparams=None, hyperparam_mutations=None, - step_once=True): + step_once=True, + synch=False): hyperparam_mutations = hyperparam_mutations or { "float_factor": lambda: 100.0, "int_factor": lambda: 10, @@ -734,7 +743,9 @@ class PopulationBasedTestingSuite(unittest.TestCase): hyperparam_mutations=hyperparam_mutations, custom_explore_fn=explore, log_config=log_config, - require_attrs=require_attrs) + synch=synch, + require_attrs=require_attrs, + ) runner = _MockTrialRunner(pbt) for i in range(num_trials): trial_hyperparams = hyperparams or { @@ -746,10 +757,17 @@ class PopulationBasedTestingSuite(unittest.TestCase): trial = _MockTrial(i, trial_hyperparams) runner.add_trial(trial) trial.status = Trial.RUNNING + for i in range(num_trials): + trial = runner.trials[i] if step_once: - self.assertEqual( - pbt.on_trial_result(runner, trial, result(10, 50 * i)), - TrialScheduler.CONTINUE) + if synch: + self.assertEqual( + pbt.on_trial_result(runner, trial, result(10, 50 * i)), + TrialScheduler.PAUSE) + else: + self.assertEqual( + pbt.on_trial_result(runner, trial, result(10, 50 * i)), + TrialScheduler.CONTINUE) pbt.reset_stats() return pbt, runner @@ -822,6 +840,32 @@ class PopulationBasedTestingSuite(unittest.TestCase): self.assertEqual(pbt._num_checkpoints, 2) self.assertEqual(pbt._num_perturbations, 0) + def testCheckpointMostPromisingTrialsSynch(self): + pbt, runner = self.basicSetup(synch=True) + trials = runner.get_trials() + + # no checkpoint: haven't hit next perturbation interval yet + self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200]) + self.assertEqual( + pbt.on_trial_result(runner, trials[0], result(15, 200)), + TrialScheduler.CONTINUE) + self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200]) + self.assertEqual(pbt._num_checkpoints, 0) + + # trials should be paused until all trials are synced. + for i in range(len(trials) - 1): + self.assertEqual( + pbt.on_trial_result(runner, trials[i], result(20, 200 + i)), + TrialScheduler.PAUSE) + + self.assertEqual(pbt.last_scores(trials), [200, 201, 202, 203, 200]) + self.assertEqual(pbt._num_checkpoints, 0) + + self.assertEqual( + pbt.on_trial_result(runner, trials[-1], result(20, 204)), + TrialScheduler.PAUSE) + self.assertEqual(pbt._num_checkpoints, 2) + def testPerturbsLowPerformingTrials(self): pbt, runner = self.basicSetup() trials = runner.get_trials() @@ -852,6 +896,35 @@ class PopulationBasedTestingSuite(unittest.TestCase): self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"]) self.assertTrue("@perturbed" in trials[2].experiment_tag) + def testPerturbsLowPerformingTrialsSynch(self): + pbt, runner = self.basicSetup(synch=True) + trials = runner.get_trials() + + # no perturbation: haven't hit next perturbation interval + self.assertEqual( + pbt.on_trial_result(runner, trials[-1], result(15, -100)), + TrialScheduler.CONTINUE) + self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200]) + self.assertTrue("@perturbed" not in trials[-1].experiment_tag) + self.assertEqual(pbt._num_perturbations, 0) + + # Don't perturb until all trials are synched. + self.assertEqual( + pbt.on_trial_result(runner, trials[-1], result(20, -100)), + TrialScheduler.PAUSE) + self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, -100]) + self.assertTrue("@perturbed" not in trials[-1].experiment_tag) + + # Synch all trials. + for i in range(len(trials) - 1): + self.assertEqual( + pbt.on_trial_result(runner, trials[i], result(20, -10 * i)), + TrialScheduler.PAUSE) + self.assertEqual(pbt.last_scores(trials), [0, -10, -20, -30, -100]) + self.assertIn(trials[-1].restored_checkpoint, ["trial_0", "trial_1"]) + self.assertIn(trials[-2].restored_checkpoint, ["trial_0", "trial_1"]) + self.assertEqual(pbt._num_perturbations, 2) + def testPerturbWithoutResample(self): pbt, runner = self.basicSetup(resample_prob=0.0) trials = runner.get_trials() @@ -1088,6 +1161,27 @@ class PopulationBasedTestingSuite(unittest.TestCase): trials[i].status = Trial.PENDING self.assertEqual(pbt.choose_trial_to_run(runner), trials[3]) + def testSchedulesMostBehindTrialToRunSynch(self): + pbt, runner = self.basicSetup(synch=True) + trials = runner.get_trials() + runner.process_action( + trials[0], pbt.on_trial_result(runner, trials[0], result( + 800, 1000))) + runner.process_action( + trials[1], pbt.on_trial_result(runner, trials[1], result( + 700, 1001))) + runner.process_action( + trials[2], pbt.on_trial_result(runner, trials[2], result( + 600, 1002))) + runner.process_action( + trials[3], pbt.on_trial_result(runner, trials[3], result( + 500, 1003))) + runner.process_action( + trials[4], pbt.on_trial_result(runner, trials[4], result( + 700, 1004))) + self.assertIn( + pbt.choose_trial_to_run(runner), [trials[0], trials[1], trials[3]]) + def testPerturbationResetsLastPerturbTime(self): pbt, runner = self.basicSetup() trials = runner.get_trials() @@ -1141,6 +1235,44 @@ class PopulationBasedTestingSuite(unittest.TestCase): check_policy(json.loads(line)) shutil.rmtree(tmpdir) + def testLogConfigSynch(self): + def check_policy(policy): + self.assertIsInstance(policy[2], int) + self.assertIsInstance(policy[3], int) + self.assertIn(policy[0], ["0tag", "1tag"]) + self.assertIn(policy[1], ["3tag", "4tag"]) + self.assertIn(policy[2], [0, 1]) + self.assertIn(policy[3], [3, 4]) + for i in [4, 5]: + self.assertIsInstance(policy[i], dict) + for key in [ + "const_factor", "int_factor", "float_factor", + "id_factor" + ]: + self.assertIn(key, policy[i]) + self.assertIsInstance(policy[i]["float_factor"], float) + self.assertIsInstance(policy[i]["int_factor"], int) + self.assertIn(policy[i]["const_factor"], [3]) + self.assertIn(policy[i]["int_factor"], [8, 10, 12]) + self.assertIn(policy[i]["float_factor"], [2.4, 2, 1.6]) + self.assertIn(policy[i]["id_factor"], [3, 4, 100]) + + pbt, runner = self.basicSetup( + log_config=True, synch=True, step_once=False) + trials = runner.get_trials() + tmpdir = tempfile.mkdtemp() + for i, trial in enumerate(trials): + trial.local_dir = tmpdir + trial.last_result = {TRAINING_ITERATION: i} + pbt.on_trial_result(runner, trials[i], result(10, i)) + log_files = ["pbt_global.txt", "pbt_policy_0.txt", "pbt_policy_1.txt"] + for log_file in log_files: + self.assertTrue(os.path.exists(os.path.join(tmpdir, log_file))) + raw_policy = open(os.path.join(tmpdir, log_file), "r").readlines() + for line in raw_policy: + check_policy(json.loads(line)) + shutil.rmtree(tmpdir) + def testReplay(self): # Returns unique increasing parameter mutations class _Counter: @@ -1156,6 +1288,7 @@ class PopulationBasedTestingSuite(unittest.TestCase): perturbation_interval=5, log_config=True, step_once=False, + synch=False, hyperparam_mutations={ "float_factor": lambda: 100.0, "int_factor": _Counter(1000) @@ -1292,6 +1425,176 @@ class PopulationBasedTestingSuite(unittest.TestCase): shutil.rmtree(tmpdir) + def testReplaySynch(self): + # Returns unique increasing parameter mutations + class _Counter: + def __init__(self, start=0): + self.count = start - 1 + + def __call__(self, *args, **kwargs): + self.count += 1 + return self.count + + pbt, runner = self.basicSetup( + num_trials=4, + perturbation_interval=5, + log_config=True, + step_once=False, + synch=True, + hyperparam_mutations={ + "float_factor": lambda: 100.0, + "int_factor": _Counter(1000) + }) + trials = runner.get_trials() + tmpdir = tempfile.mkdtemp() + + # Internal trial state to collect the real PBT history + class _TrialState: + def __init__(self, config): + self.step = 0 + self.config = config + self.history = [] + + def forward(self, t): + while self.step < t: + self.history.append(self.config) + self.step += 1 + + trial_state = [] + for i, trial in enumerate(trials): + trial.local_dir = tmpdir + trial.last_result = {TRAINING_ITERATION: 0} + trial_state.append(_TrialState(trial.config)) + + # Helper function to simulate stepping trial k a number of steps, + # and reporting a score at the end + def trial_step(k, steps, score, synced=False): + res = result(trial_state[k].step + steps, score) + + trials[k].last_result = res + trial_state[k].forward(res[TRAINING_ITERATION]) + + trials[k].status = Trial.RUNNING + if not synced: + action = pbt.on_trial_result(runner, trials[k], res) + runner.process_action(trials[k], action) + return + else: + # Reached synchronization point + old_configs = [trial.config for trial in trials] + action = pbt.on_trial_result(runner, trials[k], res) + runner.process_action(trials[k], action) + new_configs = [trial.config for trial in trials] + + for i in range(len(trials)): + old_config = old_configs[i] + new_config = new_configs[i] + if old_config != new_config: + # Copy history from source trial + source = -1 + for m, cand in enumerate(trials): + if cand.trainable_name == trials[ + i].restored_checkpoint: + source = m + break + assert source >= 0 + trial_state[i].history = trial_state[ + source].history.copy() + trial_state[i].step = trial_state[source].step + trial_state[i].config = new_config.copy() + + # Initial steps + trial_step(0, 10, 0) + trial_step(1, 11, 10) + trial_step(2, 12, 0) + trial_step(3, 13, -1, synced=True) + + # 3 <-- 1, new_t 11 + # next_perturb_sync = 13 + + # Next block + trial_step(0, 17, -10) # 20 + trial_step(2, 15, -20) # 20 + trial_step(3, 16, 0) # 20 + trial_step(1, 7, 1, synced=True) # 18 + + # 2 <-- 1, new_t=11+7=18 + # next_perturb_sync = 20 + + # Next block + trial_step(2, 13, 0) # 31 + trial_step(3, 14, 10) # 34 + trial_step(0, 11, -1) # 31 + trial_step(1, 12, 0, synced=True) # 30 + + # 0 <-- 3, new_t=11+9+14=34 + # next_perturb_sync = 34 + + # Next block + trial_step(0, 6, 20) # 40 + trial_step(3, 9, -40) # 43 + trial_step(2, 8, -50) # 39 + trial_step(1, 7, 30, synced=True) # 37 + + # 2 <-- 1, new_t=18+13+8=37 + # next_perturb_sync = 43 + + # Playback trainable to collect configs at each step + class Playback(Trainable): + def setup(self, config): + self.config = config + self.replayed = [] + self.iter = 0 + + def step(self): + self.iter += 1 + self.replayed.append(self.config) + return { + "reward": 0, + "done": False, + "replayed": self.replayed, + TRAINING_ITERATION: self.iter + } + + def reset_config(self, new_config): + self.config = new_config + return True + + def save_checkpoint(self, tmp_checkpoint_dir): + return tmp_checkpoint_dir + + def load_checkpoint(self, checkpoint): + pass + + # Loop through all trials and check if PBT history is the + # same as the playback history + for i, trial in enumerate(trials): + if trial.trial_id in ["1"]: # Did not exploit anything + continue + + replay = PopulationBasedTrainingReplay( + os.path.join(tmpdir, + "pbt_policy_{}.txt".format(trial.trial_id))) + analysis = tune.run( + Playback, + scheduler=replay, + stop={TRAINING_ITERATION: trial_state[i].step}) + + replayed = analysis.trials[0].last_result["replayed"] + self.assertSequenceEqual(trial_state[i].history, replayed) + + # Trial 1 did not exploit anything and should raise an error + with self.assertRaises(ValueError): + replay = PopulationBasedTrainingReplay( + os.path.join(tmpdir, + "pbt_policy_{}.txt".format(trials[1].trial_id))) + tune.run( + Playback, + scheduler=replay, + stop={TRAINING_ITERATION: trial_state[1].step}) + + shutil.rmtree(tmpdir) + def testPostprocessingHook(self): def explore(new_config): new_config["id_factor"] = 42 diff --git a/python/ray/tune/tests/test_trial_scheduler_pbt.py b/python/ray/tune/tests/test_trial_scheduler_pbt.py index f4cfbdaad..740616e8c 100644 --- a/python/ray/tune/tests/test_trial_scheduler_pbt.py +++ b/python/ray/tune/tests/test_trial_scheduler_pbt.py @@ -4,64 +4,13 @@ import pickle import random import unittest import sys +import time import ray from ray import tune from ray.tune.schedulers import PopulationBasedTraining -class MockTrainable(tune.Trainable): - def setup(self, config): - self.iter = 0 - self.a = config["a"] - self.b = config["b"] - self.c = config["c"] - - def step(self): - self.iter += 1 - return {"mean_accuracy": (self.a - self.iter) * self.b} - - def save_checkpoint(self, tmp_checkpoint_dir): - checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.mock") - with open(checkpoint_path, "wb") as fp: - pickle.dump((self.a, self.b, self.iter), fp) - return tmp_checkpoint_dir - - def load_checkpoint(self, tmp_checkpoint_dir): - checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.mock") - with open(checkpoint_path, "rb") as fp: - self.a, self.b, self.iter = pickle.load(fp) - - -def MockTrainingFunc(config, checkpoint_dir=None): - iter = 0 - a = config["a"] - b = config["b"] - - if checkpoint_dir: - checkpoint_path = os.path.join(checkpoint_dir, "model.mock") - with open(checkpoint_path, "rb") as fp: - a, b, iter = pickle.load(fp) - - while True: - iter += 1 - with tune.checkpoint_dir(step=iter) as checkpoint_dir: - checkpoint_path = os.path.join(checkpoint_dir, "model.mock") - with open(checkpoint_path, "wb") as fp: - pickle.dump((a, b, iter), fp) - tune.report(mean_accuracy=(a - iter) * b) - - -def MockTrainingFunc2(config): - a = config["a"] - b = config["b"] - c1 = config["c"]["c1"] - c2 = config["c"]["c2"] - - while True: - tune.report(mean_accuracy=a * b * (c1 + c2)) - - class MockParam(object): def __init__(self, params): self._params = params @@ -73,6 +22,77 @@ class MockParam(object): return val +class PopulationBasedTrainingSynchTest(unittest.TestCase): + def setUp(self): + ray.init(num_cpus=2) + + def MockTrainingFuncSync(config, checkpoint_dir=None): + iter = 0 + + if checkpoint_dir: + checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") + with open(checkpoint_path, "rb") as fp: + a, iter = pickle.load(fp) + + a = config["a"] # Use the new hyperparameter if perturbed. + + while True: + iter += 1 + with tune.checkpoint_dir(step=iter) as checkpoint_dir: + checkpoint_path = os.path.join(checkpoint_dir, + "checkpoint") + with open(checkpoint_path, "wb") as fp: + pickle.dump((a, iter), fp) + # Score gets better every iteration. + time.sleep(1) + tune.report(mean_accuracy=iter + a, a=a) + + self.MockTrainingFuncSync = MockTrainingFuncSync + + def tearDown(self): + ray.shutdown() + + def synchSetup(self, synch, param=[10, 20, 30]): + scheduler = PopulationBasedTraining( + time_attr="training_iteration", + metric="mean_accuracy", + mode="max", + perturbation_interval=1, + log_config=True, + hyperparam_mutations={"c": lambda: 1}, + synch=synch) + + param_a = MockParam(param) + + random.seed(100) + np.random.seed(100) + analysis = tune.run( + self.MockTrainingFuncSync, + config={ + "a": tune.sample_from(lambda _: param_a()), + "c": 1 + }, + fail_fast=True, + num_samples=3, + scheduler=scheduler, + name="testPBTSync", + stop={"training_iteration": 3}, + ) + return analysis + + def testAsynchFail(self): + analysis = self.synchSetup(False) + self.assertTrue(any(analysis.dataframe()["mean_accuracy"] != 33)) + + def testSynchPass(self): + analysis = self.synchSetup(True) + self.assertTrue(all(analysis.dataframe()["mean_accuracy"] == 33)) + + def testSynchPassLast(self): + analysis = self.synchSetup(True, param=[30, 20, 10]) + self.assertTrue(all(analysis.dataframe()["mean_accuracy"] == 33)) + + class PopulationBasedTrainingConfigTest(unittest.TestCase): def setUp(self): ray.init() @@ -81,6 +101,15 @@ class PopulationBasedTrainingConfigTest(unittest.TestCase): ray.shutdown() def testNoConfig(self): + def MockTrainingFunc(config): + a = config["a"] + b = config["b"] + c1 = config["c"]["c1"] + c2 = config["c"]["c2"] + + while True: + tune.report(mean_accuracy=a * b * (c1 + c2)) + scheduler = PopulationBasedTraining( time_attr="training_iteration", metric="mean_accuracy", @@ -97,7 +126,7 @@ class PopulationBasedTrainingConfigTest(unittest.TestCase): ) tune.run( - MockTrainingFunc2, + MockTrainingFunc, fail_fast=True, num_samples=4, scheduler=scheduler, @@ -120,6 +149,31 @@ class PopulationBasedTrainingResumeTest(unittest.TestCase): fix was not applied. See issues #9036, #9036 """ + + class MockTrainable(tune.Trainable): + def setup(self, config): + self.iter = 0 + self.a = config["a"] + self.b = config["b"] + self.c = config["c"] + + def step(self): + self.iter += 1 + return {"mean_accuracy": (self.a - self.iter) * self.b} + + def save_checkpoint(self, tmp_checkpoint_dir): + checkpoint_path = os.path.join(tmp_checkpoint_dir, + "model.mock") + with open(checkpoint_path, "wb") as fp: + pickle.dump((self.a, self.b, self.iter), fp) + return tmp_checkpoint_dir + + def load_checkpoint(self, tmp_checkpoint_dir): + checkpoint_path = os.path.join(tmp_checkpoint_dir, + "model.mock") + with open(checkpoint_path, "rb") as fp: + self.a, self.b, self.iter = pickle.load(fp) + scheduler = PopulationBasedTraining( time_attr="training_iteration", metric="mean_accuracy", @@ -151,6 +205,25 @@ class PopulationBasedTrainingResumeTest(unittest.TestCase): stop={"training_iteration": 3}) def testPermutationContinuationFunc(self): + def MockTrainingFunc(config, checkpoint_dir=None): + iter = 0 + a = config["a"] + b = config["b"] + + if checkpoint_dir: + checkpoint_path = os.path.join(checkpoint_dir, "model.mock") + with open(checkpoint_path, "rb") as fp: + a, b, iter = pickle.load(fp) + + while True: + iter += 1 + with tune.checkpoint_dir(step=iter) as checkpoint_dir: + checkpoint_path = os.path.join(checkpoint_dir, + "model.mock") + with open(checkpoint_path, "wb") as fp: + pickle.dump((a, b, iter), fp) + tune.report(mean_accuracy=(a - iter) * b) + scheduler = PopulationBasedTraining( time_attr="training_iteration", metric="mean_accuracy", diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 8e644c452..57c1d8119 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -460,6 +460,7 @@ class TrialRunner: self._update_trial_queue(blocking=wait_for_trial) with warn_if_slow("choose_trial_to_run"): trial = self._scheduler_alg.choose_trial_to_run(self) + logger.debug("Running trial {}".format(trial)) return trial def _process_events(self):