[tune] Add config logging functionality to PBT scheduler (#4680)

This commit is contained in:
Daniel Ho
2019-04-27 19:32:19 -07:00
committed by Eric Liang
parent 686d4caefe
commit d7d2694b57
2 changed files with 152 additions and 66 deletions
+55 -5
View File
@@ -2,15 +2,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import math
import copy
import itertools
import logging
import json
import math
import os
import random
import shutil
from ray.tune.error import TuneError
from ray.tune.trial import Trial, Checkpoint
from ray.tune.result import TRAINING_ITERATION
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.suggest.variant_generator import format_vars
from ray.tune.trial import Trial, Checkpoint
logger = logging.getLogger(__name__)
@@ -137,6 +142,9 @@ class PopulationBasedTraining(FIFOScheduler):
perturbations from `hyperparam_mutations` are applied, and should
return `config` updated as needed. You must specify at least one of
`hyperparam_mutations` or `custom_explore_fn`.
log_config (bool): Whether to log the ray config of each model to
local_dir at each exploit. Allows config schedule to be
reconstructed.
Example:
>>> pbt = PopulationBasedTraining(
@@ -161,7 +169,8 @@ class PopulationBasedTraining(FIFOScheduler):
perturbation_interval=60.0,
hyperparam_mutations={},
resample_probability=0.25,
custom_explore_fn=None):
custom_explore_fn=None,
log_config=True):
if not hyperparam_mutations and not custom_explore_fn:
raise TuneError(
"You must specify at least one of `hyperparam_mutations` or "
@@ -174,6 +183,7 @@ class PopulationBasedTraining(FIFOScheduler):
self._resample_probability = resample_probability
self._trial_state = {}
self._custom_explore_fn = custom_explore_fn
self._log_config = log_config
# Metrics
self._num_checkpoints = 0
@@ -212,8 +222,43 @@ class PopulationBasedTraining(FIFOScheduler):
return TrialScheduler.CONTINUE
def _log_config_on_step(self, trial_state, new_state, trial,
                        trial_to_clone, new_config):
    """Logs transition during exploit/explore step.

    For each step, logs: [target trial tag, clone trial tag, target trial
    iteration, clone trial iteration, old config, new config]. The value
    written as "old config" is ``trial_to_clone.config``.

    The record is appended as one JSON line to ``pbt_global.txt`` in the
    target trial's ``local_dir``. The target trial's
    ``pbt_policy_<id>.txt`` is then replaced by a copy of the clone's
    policy file (when that file exists) and the same JSON line is
    appended to it.

    Args:
        trial_state: State of the target trial; only ``orig_tag`` is read.
        new_state: State of the cloned trial; only ``orig_tag`` is read.
        trial: Target trial; ``local_dir`` and ``last_result`` are read.
        trial_to_clone: Trial being cloned; ``local_dir``,
            ``last_result`` and ``config`` are read.
        new_config: Config logged as the new config. It is passed to
            ``json.dumps``, so it must be JSON-serializable.
    """
    trial_name = trial_state.orig_tag
    trial_to_clone_name = new_state.orig_tag

    # The id used in the file name is the leading run of digits of the
    # tag; it is the empty string when the tag does not start with one.
    trial_id = "".join(itertools.takewhile(str.isdigit, trial_name))
    trial_to_clone_id = "".join(
        itertools.takewhile(str.isdigit, trial_to_clone_name))

    trial_path = os.path.join(trial.local_dir,
                              "pbt_policy_" + trial_id + ".txt")
    trial_to_clone_path = os.path.join(
        trial_to_clone.local_dir,
        "pbt_policy_" + trial_to_clone_id + ".txt")

    policy = [
        trial_name, trial_to_clone_name,
        trial.last_result[TRAINING_ITERATION],
        trial_to_clone.last_result[TRAINING_ITERATION],
        trial_to_clone.config, new_config
    ]
    # Serialize once; the same line goes to both log files.
    policy_line = json.dumps(policy) + "\n"

    # Log to global file.
    with open(os.path.join(trial.local_dir, "pbt_global.txt"), "a") as f:
        f.write(policy_line)

    # Overwrite state in target trial from trial_to_clone. Both paths
    # collapse to the same file when the two tags share an id prefix
    # (e.g. neither starts with a digit) and the same local_dir;
    # shutil.copyfile raises in that case, so skip the copy and just
    # append below.
    same_file = (
        os.path.abspath(trial_to_clone_path) == os.path.abspath(trial_path))
    if not same_file and os.path.exists(trial_to_clone_path):
        shutil.copyfile(trial_to_clone_path, trial_path)

    # Log new exploit in target trial log.
    with open(trial_path, "a") as f:
        f.write(policy_line)
def _exploit(self, trial_executor, trial, trial_to_clone):
"""Transfers perturbed state from trial_to_clone -> trial."""
"""Transfers perturbed state from trial_to_clone -> trial.
If specified, also logs the updated hyperparam state."""
trial_state = self._trial_state[trial]
new_state = self._trial_state[trial_to_clone]
@@ -228,6 +273,11 @@ class PopulationBasedTraining(FIFOScheduler):
"{} (score {}) -> {} (score {})".format(
trial_to_clone, new_state.last_score, trial,
trial_state.last_score))
if self._log_config:
self._log_config_on_step(trial_state, new_state, trial,
trial_to_clone, new_config)
new_tag = make_experiment_tag(trial_state.orig_tag, new_config,
self._hyperparam_mutations)
reset_successful = trial_executor.reset_trial(trial, new_config,
+97 -61
View File
@@ -2,14 +2,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import json
import random
import unittest
import numpy as np
import sys
import tempfile
import shutil
import ray
from ray.tune.result import TRAINING_ITERATION
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
PopulationBasedTraining, MedianStoppingRule,
TrialScheduler)
from ray.tune.schedulers.pbt import explore
from ray.tune.trial import Trial, Resources, Checkpoint
from ray.tune.trial_executor import TrialExecutor
@@ -563,13 +570,13 @@ class HyperbandSuite(unittest.TestCase):
def testFilterNoneBracket(self):
sched, runner = self.schedulerSetup(100, 20)
# `sched' should contains None brackets
# "sched" should contain None brackets
non_brackets = [
b for hyperband in sched._hyperbands for b in hyperband
if b is None
]
self.assertTrue(non_brackets)
# Make sure `choose_trial_to_run' still works
# Make sure "choose_trial_to_run" still works
trial = sched.choose_trial_to_run(runner)
self.assertIsNotNone(trial)
@@ -578,7 +585,7 @@ class _MockTrial(Trial):
def __init__(self, i, config):
self.trainable_name = "trial_{}".format(i)
self.config = config
self.experiment_tag = "tag"
self.experiment_tag = "{}tag".format(i)
self.trial_name_creator = None
self.logger_running = False
self.restored_checkpoint = None
@@ -594,7 +601,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
ray.shutdown()
_register_all() # re-register the evicted objects
def basicSetup(self, resample_prob=0.0, explore=None):
def basicSetup(self, resample_prob=0.0, explore=None, log_config=False):
pbt = PopulationBasedTraining(
time_attr="training_iteration",
perturbation_interval=10,
@@ -604,7 +611,8 @@ class PopulationBasedTestingSuite(unittest.TestCase):
"float_factor": lambda: 100.0,
"int_factor": lambda: 10,
},
custom_explore_fn=explore)
custom_explore_fn=explore,
log_config=log_config)
runner = _MockTrialRunner(pbt)
for i in range(5):
trial = _MockTrial(
@@ -738,20 +746,17 @@ class PopulationBasedTestingSuite(unittest.TestCase):
# Continuous case
assertProduces(
lambda: explore(
{"v": 100}, {"v": lambda: random.choice([10, 100])}, 0.0,
lambda x: x),
{80, 120})
lambda: explore({"v": 100}, {
"v": lambda: random.choice([10, 100])
}, 0.0, lambda x: x), {80, 120})
assertProduces(
lambda: explore(
{"v": 100.0}, {"v": lambda: random.choice([10, 100])}, 0.0,
lambda x: x),
{80.0, 120.0})
lambda: explore({"v": 100.0}, {
"v": lambda: random.choice([10, 100])
}, 0.0, lambda x: x), {80.0, 120.0})
assertProduces(
lambda: explore(
{"v": 100.0}, {"v": lambda: random.choice([10, 100])}, 1.0,
lambda x: x),
{10.0, 100.0})
lambda: explore({"v": 100.0}, {
"v": lambda: random.choice([10, 100])
}, 1.0, lambda x: x), {10.0, 100.0})
def deep_add(seen, new_values):
for k, new_value in new_values.items():
@@ -776,30 +781,25 @@ class PopulationBasedTestingSuite(unittest.TestCase):
# Nested mutation and spec
assertNestedProduces(
lambda: explore(
{
"a": {
"b": 4
},
"1": {
"2": {
"3": 100
}
},
lambda: explore({
"a": {
"b": 4
},
{
"a": {
"b": [3, 4, 8, 10]
},
"1": {
"2": {
"3": lambda: random.choice([10, 100])
}
},
"1": {
"2": {
"3": 100
}
},
0.0,
lambda x: x),
{
}, {
"a": {
"b": [3, 4, 8, 10]
},
"1": {
"2": {
"3": lambda: random.choice([10, 100])
}
},
}, 0.0, lambda x: x), {
"a": {
"b": {3, 8}
},
@@ -814,30 +814,25 @@ class PopulationBasedTestingSuite(unittest.TestCase):
# Nested mutation and spec
assertNestedProduces(
lambda: explore(
{
"a": {
"b": 4
},
"1": {
"2": {
"3": 100
}
},
lambda: explore({
"a": {
"b": 4
},
{
"a": {
"b": [3, 4, 8, 10]
},
"1": {
"2": {
"3": lambda: random.choice([10, 100])
}
},
"1": {
"2": {
"3": 100
}
},
0.0,
custom_explore_fn),
{
}, {
"a": {
"b": [3, 4, 8, 10]
},
"1": {
"2": {
"3": lambda: random.choice([10, 100])
}
},
}, 0.0, custom_explore_fn), {
"a": {
"b": {3, 8}
},
@@ -889,6 +884,47 @@ class PopulationBasedTestingSuite(unittest.TestCase):
pbt.on_trial_result(runner, trials[3], result(11000, 100))
self.assertEqual(pbt._num_perturbations, 2)
def testLogConfig(self):
    """Checks that PBT with log_config=True writes well-formed policy logs."""

    def check_policy(policy):
        # Entries 0-3 are validated as two tags and two iteration
        # counts; entries 4 and 5 are validated as config dicts.
        self.assertIsInstance(policy[0], str)
        self.assertIsInstance(policy[1], str)
        self.assertIsInstance(policy[2], int)
        self.assertIsInstance(policy[3], int)
        self.assertIn(policy[0], ["0tag", "2tag", "3tag", "4tag"])
        self.assertIn(policy[1], ["0tag", "2tag", "3tag", "4tag"])
        self.assertIn(policy[2], [0, 2, 3, 4])
        self.assertIn(policy[3], [0, 2, 3, 4])
        for i in [4, 5]:
            self.assertIsInstance(policy[i], dict)
            for key in [
                    "const_factor", "int_factor", "float_factor",
                    "id_factor"
            ]:
                self.assertIn(key, policy[i])
            self.assertIsInstance(policy[i]["float_factor"], float)
            self.assertIsInstance(policy[i]["int_factor"], int)
            self.assertIn(policy[i]["const_factor"], [3])
            self.assertIn(policy[i]["int_factor"], [8, 10, 12])
            self.assertIn(policy[i]["float_factor"], [2.4, 2, 1.6])
            self.assertIn(policy[i]["id_factor"], [3, 4, 100])

    pbt, runner = self.basicSetup(log_config=True)
    trials = runner.get_trials()
    tmpdir = tempfile.mkdtemp()
    # Register cleanup immediately so the directory is removed even when
    # an assertion below fails.
    self.addCleanup(shutil.rmtree, tmpdir)
    for i, trial in enumerate(trials):
        trial.local_dir = tmpdir
        trial.last_result = {TRAINING_ITERATION: i}

    # Feed results for trials 0 and 2; the log files asserted below are
    # the global log plus the per-trial logs of those two trials.
    pbt.on_trial_result(runner, trials[0], result(15, -100))
    pbt.on_trial_result(runner, trials[0], result(20, -100))
    pbt.on_trial_result(runner, trials[2], result(20, 40))

    log_files = ["pbt_global.txt", "pbt_policy_0.txt", "pbt_policy_2.txt"]
    for log_file in log_files:
        log_path = os.path.join(tmpdir, log_file)
        self.assertTrue(os.path.exists(log_path))
        # Context manager closes the handle instead of leaking it.
        with open(log_path, "r") as f:
            for line in f:
                check_policy(json.loads(line))
def testPostprocessingHook(self):
def explore(new_config):
new_config["id_factor"] = 42