diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index cdb11dc98..e5d793fd1 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -2,15 +2,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import random -import math import copy +import itertools import logging +import json +import math +import os +import random +import shutil from ray.tune.error import TuneError -from ray.tune.trial import Trial, Checkpoint +from ray.tune.result import TRAINING_ITERATION from ray.tune.schedulers import FIFOScheduler, TrialScheduler from ray.tune.suggest.variant_generator import format_vars +from ray.tune.trial import Trial, Checkpoint logger = logging.getLogger(__name__) @@ -137,6 +142,9 @@ class PopulationBasedTraining(FIFOScheduler): perturbations from `hyperparam_mutations` are applied, and should return `config` updated as needed. You must specify at least one of `hyperparam_mutations` or `custom_explore_fn`. + log_config (bool): Whether to log the ray config of each model to + local_dir at each exploit. Allows config schedule to be + reconstructed. Example: >>> pbt = PopulationBasedTraining( @@ -161,7 +169,8 @@ class PopulationBasedTraining(FIFOScheduler): perturbation_interval=60.0, hyperparam_mutations={}, resample_probability=0.25, - custom_explore_fn=None): + custom_explore_fn=None, + log_config=True): if not hyperparam_mutations and not custom_explore_fn: raise TuneError( "You must specify at least one of `hyperparam_mutations` or " @@ -174,6 +183,7 @@ class PopulationBasedTraining(FIFOScheduler): self._resample_probability = resample_probability self._trial_state = {} self._custom_explore_fn = custom_explore_fn + self._log_config = log_config # Metrics self._num_checkpoints = 0 @@ -212,8 +222,43 @@ class PopulationBasedTraining(FIFOScheduler): return TrialScheduler.CONTINUE + def _log_config_on_step(self, trial_state, new_state, trial, + trial_to_clone, new_config): + """Logs transition during exploit/exploit step. + + For each step, logs: [target trial tag, clone trial tag, target trial + iteration, clone trial iteration, old config, new config]. + """ + trial_name, trial_to_clone_name = (trial_state.orig_tag, + new_state.orig_tag) + trial_id = "".join(itertools.takewhile(str.isdigit, trial_name)) + trial_to_clone_id = "".join( + itertools.takewhile(str.isdigit, trial_to_clone_name)) + trial_path = os.path.join(trial.local_dir, + "pbt_policy_" + trial_id + ".txt") + trial_to_clone_path = os.path.join( + trial_to_clone.local_dir, + "pbt_policy_" + trial_to_clone_id + ".txt") + policy = [ + trial_name, trial_to_clone_name, + trial.last_result[TRAINING_ITERATION], + trial_to_clone.last_result[TRAINING_ITERATION], + trial_to_clone.config, new_config + ] + # Log to global file. + with open(os.path.join(trial.local_dir, "pbt_global.txt"), "a+") as f: + f.write(json.dumps(policy) + "\n") + # Overwrite state in target trial from trial_to_clone. + if os.path.exists(trial_to_clone_path): + shutil.copyfile(trial_to_clone_path, trial_path) + # Log new exploit in target trial log. + with open(trial_path, "a+") as f: + f.write(json.dumps(policy) + "\n") + def _exploit(self, trial_executor, trial, trial_to_clone): - """Transfers perturbed state from trial_to_clone -> trial.""" + """Transfers perturbed state from trial_to_clone -> trial. + + If specified, also logs the updated hyperparam state.""" trial_state = self._trial_state[trial] new_state = self._trial_state[trial_to_clone] @@ -228,6 +273,11 @@ class PopulationBasedTraining(FIFOScheduler): "{} (score {}) -> {} (score {})".format( trial_to_clone, new_state.last_score, trial, trial_state.last_score)) + + if self._log_config: + self._log_config_on_step(trial_state, new_state, trial, + trial_to_clone, new_config) + new_tag = make_experiment_tag(trial_state.orig_tag, new_config, self._hyperparam_mutations) reset_successful = trial_executor.reset_trial(trial, new_config, diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index aaa0dc49c..f574a4921 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -2,14 +2,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os +import json import random import unittest import numpy as np import sys +import tempfile +import shutil import ray + +from ray.tune.result import TRAINING_ITERATION from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler, PopulationBasedTraining, MedianStoppingRule, TrialScheduler) + from ray.tune.schedulers.pbt import explore from ray.tune.trial import Trial, Resources, Checkpoint from ray.tune.trial_executor import TrialExecutor @@ -563,13 +570,13 @@ class HyperbandSuite(unittest.TestCase): def testFilterNoneBracket(self): sched, runner = self.schedulerSetup(100, 20) - # `sched' should contains None brackets + # "sched" should contains None brackets non_brackets = [ b for hyperband in sched._hyperbands for b in hyperband if b is None ] self.assertTrue(non_brackets) - # Make sure `choose_trial_to_run' still works + # Make sure "choose_trial_to_run" still works trial = sched.choose_trial_to_run(runner) self.assertIsNotNone(trial) @@ -578,7 +585,7 @@ class _MockTrial(Trial): def __init__(self, i, config): self.trainable_name = "trial_{}".format(i) self.config = config - self.experiment_tag = "tag" + self.experiment_tag = "{}tag".format(i) self.trial_name_creator = None self.logger_running = False self.restored_checkpoint = None @@ -594,7 +601,7 @@ class PopulationBasedTestingSuite(unittest.TestCase): ray.shutdown() _register_all() # re-register the evicted objects - def basicSetup(self, resample_prob=0.0, explore=None): + def basicSetup(self, resample_prob=0.0, explore=None, log_config=False): pbt = PopulationBasedTraining( time_attr="training_iteration", perturbation_interval=10, @@ -604,7 +611,8 @@ class PopulationBasedTestingSuite(unittest.TestCase): "float_factor": lambda: 100.0, "int_factor": lambda: 10, }, - custom_explore_fn=explore) + custom_explore_fn=explore, + log_config=log_config) runner = _MockTrialRunner(pbt) for i in range(5): trial = _MockTrial( @@ -738,20 +746,17 @@ class PopulationBasedTestingSuite(unittest.TestCase): # Continuous case assertProduces( - lambda: explore( - {"v": 100}, {"v": lambda: random.choice([10, 100])}, 0.0, - lambda x: x), - {80, 120}) + lambda: explore({"v": 100}, { + "v": lambda: random.choice([10, 100]) + }, 0.0, lambda x: x), {80, 120}) assertProduces( - lambda: explore( - {"v": 100.0}, {"v": lambda: random.choice([10, 100])}, 0.0, - lambda x: x), - {80.0, 120.0}) + lambda: explore({"v": 100.0}, { + "v": lambda: random.choice([10, 100]) + }, 0.0, lambda x: x), {80.0, 120.0}) assertProduces( - lambda: explore( - {"v": 100.0}, {"v": lambda: random.choice([10, 100])}, 1.0, - lambda x: x), - {10.0, 100.0}) + lambda: explore({"v": 100.0}, { + "v": lambda: random.choice([10, 100]) + }, 1.0, lambda x: x), {10.0, 100.0}) def deep_add(seen, new_values): for k, new_value in new_values.items(): @@ -776,30 +781,25 @@ class PopulationBasedTestingSuite(unittest.TestCase): # Nested mutation and spec assertNestedProduces( - lambda: explore( - { - "a": { - "b": 4 - }, - "1": { - "2": { - "3": 100 - } - }, + lambda: explore({ + "a": { + "b": 4 }, - { - "a": { - "b": [3, 4, 8, 10] - }, - "1": { - "2": { - "3": lambda: random.choice([10, 100]) - } - }, + "1": { + "2": { + "3": 100 + } }, - 0.0, - lambda x: x), - { + }, { + "a": { + "b": [3, 4, 8, 10] + }, + "1": { + "2": { + "3": lambda: random.choice([10, 100]) + } + }, + }, 0.0, lambda x: x), { "a": { "b": {3, 8} }, @@ -814,30 +814,25 @@ class PopulationBasedTestingSuite(unittest.TestCase): # Nested mutation and spec assertNestedProduces( - lambda: explore( - { - "a": { - "b": 4 - }, - "1": { - "2": { - "3": 100 - } - }, + lambda: explore({ + "a": { + "b": 4 }, - { - "a": { - "b": [3, 4, 8, 10] - }, - "1": { - "2": { - "3": lambda: random.choice([10, 100]) - } - }, + "1": { + "2": { + "3": 100 + } }, - 0.0, - custom_explore_fn), - { + }, { + "a": { + "b": [3, 4, 8, 10] + }, + "1": { + "2": { + "3": lambda: random.choice([10, 100]) + } + }, + }, 0.0, custom_explore_fn), { "a": { "b": {3, 8} }, @@ -889,6 +884,47 @@ class PopulationBasedTestingSuite(unittest.TestCase): pbt.on_trial_result(runner, trials[3], result(11000, 100)) self.assertEqual(pbt._num_perturbations, 2) + def testLogConfig(self): + def check_policy(policy): + self.assertIsInstance(policy[0], str) + self.assertIsInstance(policy[1], str) + self.assertIsInstance(policy[2], int) + self.assertIsInstance(policy[3], int) + self.assertIn(policy[0], ["0tag", "2tag", "3tag", "4tag"]) + self.assertIn(policy[1], ["0tag", "2tag", "3tag", "4tag"]) + self.assertIn(policy[2], [0, 2, 3, 4]) + self.assertIn(policy[3], [0, 2, 3, 4]) + for i in [4, 5]: + self.assertIsInstance(policy[i], dict) + for key in [ + "const_factor", "int_factor", "float_factor", + "id_factor" + ]: + self.assertIn(key, policy[i]) + self.assertIsInstance(policy[i]["float_factor"], float) + self.assertIsInstance(policy[i]["int_factor"], int) + self.assertIn(policy[i]["const_factor"], [3]) + self.assertIn(policy[i]["int_factor"], [8, 10, 12]) + self.assertIn(policy[i]["float_factor"], [2.4, 2, 1.6]) + self.assertIn(policy[i]["id_factor"], [3, 4, 100]) + + pbt, runner = self.basicSetup(log_config=True) + trials = runner.get_trials() + tmpdir = tempfile.mkdtemp() + for i, trial in enumerate(trials): + trial.local_dir = tmpdir + trial.last_result = {TRAINING_ITERATION: i} + pbt.on_trial_result(runner, trials[0], result(15, -100)) + pbt.on_trial_result(runner, trials[0], result(20, -100)) + pbt.on_trial_result(runner, trials[2], result(20, 40)) + log_files = ["pbt_global.txt", "pbt_policy_0.txt", "pbt_policy_2.txt"] + for log_file in log_files: + self.assertTrue(os.path.exists(os.path.join(tmpdir, log_file))) + raw_policy = open(os.path.join(tmpdir, log_file), "r").readlines() + for line in raw_policy: + check_policy(json.loads(line)) + shutil.rmtree(tmpdir) + def testPostprocessingHook(self): def explore(new_config): new_config["id_factor"] = 42