mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 02:46:49 +08:00
[tune] Add config logging functionality to PBT scheduler (#4680)
This commit is contained in:
@@ -2,15 +2,20 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
import math
|
||||
import copy
|
||||
import itertools
|
||||
import logging
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
from ray.tune.result import TRAINING_ITERATION
|
||||
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.suggest.variant_generator import format_vars
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -137,6 +142,9 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
perturbations from `hyperparam_mutations` are applied, and should
|
||||
return `config` updated as needed. You must specify at least one of
|
||||
`hyperparam_mutations` or `custom_explore_fn`.
|
||||
log_config (bool): Whether to log the ray config of each model to
|
||||
local_dir at each exploit. Allows config schedule to be
|
||||
reconstructed.
|
||||
|
||||
Example:
|
||||
>>> pbt = PopulationBasedTraining(
|
||||
@@ -161,7 +169,8 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
perturbation_interval=60.0,
|
||||
hyperparam_mutations={},
|
||||
resample_probability=0.25,
|
||||
custom_explore_fn=None):
|
||||
custom_explore_fn=None,
|
||||
log_config=True):
|
||||
if not hyperparam_mutations and not custom_explore_fn:
|
||||
raise TuneError(
|
||||
"You must specify at least one of `hyperparam_mutations` or "
|
||||
@@ -174,6 +183,7 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
self._resample_probability = resample_probability
|
||||
self._trial_state = {}
|
||||
self._custom_explore_fn = custom_explore_fn
|
||||
self._log_config = log_config
|
||||
|
||||
# Metrics
|
||||
self._num_checkpoints = 0
|
||||
@@ -212,8 +222,43 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
def _log_config_on_step(self, trial_state, new_state, trial,
|
||||
trial_to_clone, new_config):
|
||||
"""Logs transition during exploit/exploit step.
|
||||
|
||||
For each step, logs: [target trial tag, clone trial tag, target trial
|
||||
iteration, clone trial iteration, old config, new config].
|
||||
"""
|
||||
trial_name, trial_to_clone_name = (trial_state.orig_tag,
|
||||
new_state.orig_tag)
|
||||
trial_id = "".join(itertools.takewhile(str.isdigit, trial_name))
|
||||
trial_to_clone_id = "".join(
|
||||
itertools.takewhile(str.isdigit, trial_to_clone_name))
|
||||
trial_path = os.path.join(trial.local_dir,
|
||||
"pbt_policy_" + trial_id + ".txt")
|
||||
trial_to_clone_path = os.path.join(
|
||||
trial_to_clone.local_dir,
|
||||
"pbt_policy_" + trial_to_clone_id + ".txt")
|
||||
policy = [
|
||||
trial_name, trial_to_clone_name,
|
||||
trial.last_result[TRAINING_ITERATION],
|
||||
trial_to_clone.last_result[TRAINING_ITERATION],
|
||||
trial_to_clone.config, new_config
|
||||
]
|
||||
# Log to global file.
|
||||
with open(os.path.join(trial.local_dir, "pbt_global.txt"), "a+") as f:
|
||||
f.write(json.dumps(policy) + "\n")
|
||||
# Overwrite state in target trial from trial_to_clone.
|
||||
if os.path.exists(trial_to_clone_path):
|
||||
shutil.copyfile(trial_to_clone_path, trial_path)
|
||||
# Log new exploit in target trial log.
|
||||
with open(trial_path, "a+") as f:
|
||||
f.write(json.dumps(policy) + "\n")
|
||||
|
||||
def _exploit(self, trial_executor, trial, trial_to_clone):
|
||||
"""Transfers perturbed state from trial_to_clone -> trial."""
|
||||
"""Transfers perturbed state from trial_to_clone -> trial.
|
||||
|
||||
If specified, also logs the updated hyperparam state."""
|
||||
|
||||
trial_state = self._trial_state[trial]
|
||||
new_state = self._trial_state[trial_to_clone]
|
||||
@@ -228,6 +273,11 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
"{} (score {}) -> {} (score {})".format(
|
||||
trial_to_clone, new_state.last_score, trial,
|
||||
trial_state.last_score))
|
||||
|
||||
if self._log_config:
|
||||
self._log_config_on_step(trial_state, new_state, trial,
|
||||
trial_to_clone, new_config)
|
||||
|
||||
new_tag = make_experiment_tag(trial_state.orig_tag, new_config,
|
||||
self._hyperparam_mutations)
|
||||
reset_successful = trial_executor.reset_trial(trial, new_config,
|
||||
|
||||
@@ -2,14 +2,21 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
import unittest
|
||||
import numpy as np
|
||||
import sys
|
||||
import tempfile
|
||||
import shutil
|
||||
import ray
|
||||
|
||||
from ray.tune.result import TRAINING_ITERATION
|
||||
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
|
||||
PopulationBasedTraining, MedianStoppingRule,
|
||||
TrialScheduler)
|
||||
|
||||
from ray.tune.schedulers.pbt import explore
|
||||
from ray.tune.trial import Trial, Resources, Checkpoint
|
||||
from ray.tune.trial_executor import TrialExecutor
|
||||
@@ -563,13 +570,13 @@ class HyperbandSuite(unittest.TestCase):
|
||||
|
||||
def testFilterNoneBracket(self):
|
||||
sched, runner = self.schedulerSetup(100, 20)
|
||||
# `sched' should contains None brackets
|
||||
# "sched" should contains None brackets
|
||||
non_brackets = [
|
||||
b for hyperband in sched._hyperbands for b in hyperband
|
||||
if b is None
|
||||
]
|
||||
self.assertTrue(non_brackets)
|
||||
# Make sure `choose_trial_to_run' still works
|
||||
# Make sure "choose_trial_to_run" still works
|
||||
trial = sched.choose_trial_to_run(runner)
|
||||
self.assertIsNotNone(trial)
|
||||
|
||||
@@ -578,7 +585,7 @@ class _MockTrial(Trial):
|
||||
def __init__(self, i, config):
|
||||
self.trainable_name = "trial_{}".format(i)
|
||||
self.config = config
|
||||
self.experiment_tag = "tag"
|
||||
self.experiment_tag = "{}tag".format(i)
|
||||
self.trial_name_creator = None
|
||||
self.logger_running = False
|
||||
self.restored_checkpoint = None
|
||||
@@ -594,7 +601,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
ray.shutdown()
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
def basicSetup(self, resample_prob=0.0, explore=None):
|
||||
def basicSetup(self, resample_prob=0.0, explore=None, log_config=False):
|
||||
pbt = PopulationBasedTraining(
|
||||
time_attr="training_iteration",
|
||||
perturbation_interval=10,
|
||||
@@ -604,7 +611,8 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
"float_factor": lambda: 100.0,
|
||||
"int_factor": lambda: 10,
|
||||
},
|
||||
custom_explore_fn=explore)
|
||||
custom_explore_fn=explore,
|
||||
log_config=log_config)
|
||||
runner = _MockTrialRunner(pbt)
|
||||
for i in range(5):
|
||||
trial = _MockTrial(
|
||||
@@ -738,20 +746,17 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
|
||||
# Continuous case
|
||||
assertProduces(
|
||||
lambda: explore(
|
||||
{"v": 100}, {"v": lambda: random.choice([10, 100])}, 0.0,
|
||||
lambda x: x),
|
||||
{80, 120})
|
||||
lambda: explore({"v": 100}, {
|
||||
"v": lambda: random.choice([10, 100])
|
||||
}, 0.0, lambda x: x), {80, 120})
|
||||
assertProduces(
|
||||
lambda: explore(
|
||||
{"v": 100.0}, {"v": lambda: random.choice([10, 100])}, 0.0,
|
||||
lambda x: x),
|
||||
{80.0, 120.0})
|
||||
lambda: explore({"v": 100.0}, {
|
||||
"v": lambda: random.choice([10, 100])
|
||||
}, 0.0, lambda x: x), {80.0, 120.0})
|
||||
assertProduces(
|
||||
lambda: explore(
|
||||
{"v": 100.0}, {"v": lambda: random.choice([10, 100])}, 1.0,
|
||||
lambda x: x),
|
||||
{10.0, 100.0})
|
||||
lambda: explore({"v": 100.0}, {
|
||||
"v": lambda: random.choice([10, 100])
|
||||
}, 1.0, lambda x: x), {10.0, 100.0})
|
||||
|
||||
def deep_add(seen, new_values):
|
||||
for k, new_value in new_values.items():
|
||||
@@ -776,30 +781,25 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
|
||||
# Nested mutation and spec
|
||||
assertNestedProduces(
|
||||
lambda: explore(
|
||||
{
|
||||
"a": {
|
||||
"b": 4
|
||||
},
|
||||
"1": {
|
||||
"2": {
|
||||
"3": 100
|
||||
}
|
||||
},
|
||||
lambda: explore({
|
||||
"a": {
|
||||
"b": 4
|
||||
},
|
||||
{
|
||||
"a": {
|
||||
"b": [3, 4, 8, 10]
|
||||
},
|
||||
"1": {
|
||||
"2": {
|
||||
"3": lambda: random.choice([10, 100])
|
||||
}
|
||||
},
|
||||
"1": {
|
||||
"2": {
|
||||
"3": 100
|
||||
}
|
||||
},
|
||||
0.0,
|
||||
lambda x: x),
|
||||
{
|
||||
}, {
|
||||
"a": {
|
||||
"b": [3, 4, 8, 10]
|
||||
},
|
||||
"1": {
|
||||
"2": {
|
||||
"3": lambda: random.choice([10, 100])
|
||||
}
|
||||
},
|
||||
}, 0.0, lambda x: x), {
|
||||
"a": {
|
||||
"b": {3, 8}
|
||||
},
|
||||
@@ -814,30 +814,25 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
|
||||
# Nested mutation and spec
|
||||
assertNestedProduces(
|
||||
lambda: explore(
|
||||
{
|
||||
"a": {
|
||||
"b": 4
|
||||
},
|
||||
"1": {
|
||||
"2": {
|
||||
"3": 100
|
||||
}
|
||||
},
|
||||
lambda: explore({
|
||||
"a": {
|
||||
"b": 4
|
||||
},
|
||||
{
|
||||
"a": {
|
||||
"b": [3, 4, 8, 10]
|
||||
},
|
||||
"1": {
|
||||
"2": {
|
||||
"3": lambda: random.choice([10, 100])
|
||||
}
|
||||
},
|
||||
"1": {
|
||||
"2": {
|
||||
"3": 100
|
||||
}
|
||||
},
|
||||
0.0,
|
||||
custom_explore_fn),
|
||||
{
|
||||
}, {
|
||||
"a": {
|
||||
"b": [3, 4, 8, 10]
|
||||
},
|
||||
"1": {
|
||||
"2": {
|
||||
"3": lambda: random.choice([10, 100])
|
||||
}
|
||||
},
|
||||
}, 0.0, custom_explore_fn), {
|
||||
"a": {
|
||||
"b": {3, 8}
|
||||
},
|
||||
@@ -889,6 +884,47 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
pbt.on_trial_result(runner, trials[3], result(11000, 100))
|
||||
self.assertEqual(pbt._num_perturbations, 2)
|
||||
|
||||
def testLogConfig(self):
|
||||
def check_policy(policy):
|
||||
self.assertIsInstance(policy[0], str)
|
||||
self.assertIsInstance(policy[1], str)
|
||||
self.assertIsInstance(policy[2], int)
|
||||
self.assertIsInstance(policy[3], int)
|
||||
self.assertIn(policy[0], ["0tag", "2tag", "3tag", "4tag"])
|
||||
self.assertIn(policy[1], ["0tag", "2tag", "3tag", "4tag"])
|
||||
self.assertIn(policy[2], [0, 2, 3, 4])
|
||||
self.assertIn(policy[3], [0, 2, 3, 4])
|
||||
for i in [4, 5]:
|
||||
self.assertIsInstance(policy[i], dict)
|
||||
for key in [
|
||||
"const_factor", "int_factor", "float_factor",
|
||||
"id_factor"
|
||||
]:
|
||||
self.assertIn(key, policy[i])
|
||||
self.assertIsInstance(policy[i]["float_factor"], float)
|
||||
self.assertIsInstance(policy[i]["int_factor"], int)
|
||||
self.assertIn(policy[i]["const_factor"], [3])
|
||||
self.assertIn(policy[i]["int_factor"], [8, 10, 12])
|
||||
self.assertIn(policy[i]["float_factor"], [2.4, 2, 1.6])
|
||||
self.assertIn(policy[i]["id_factor"], [3, 4, 100])
|
||||
|
||||
pbt, runner = self.basicSetup(log_config=True)
|
||||
trials = runner.get_trials()
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
for i, trial in enumerate(trials):
|
||||
trial.local_dir = tmpdir
|
||||
trial.last_result = {TRAINING_ITERATION: i}
|
||||
pbt.on_trial_result(runner, trials[0], result(15, -100))
|
||||
pbt.on_trial_result(runner, trials[0], result(20, -100))
|
||||
pbt.on_trial_result(runner, trials[2], result(20, 40))
|
||||
log_files = ["pbt_global.txt", "pbt_policy_0.txt", "pbt_policy_2.txt"]
|
||||
for log_file in log_files:
|
||||
self.assertTrue(os.path.exists(os.path.join(tmpdir, log_file)))
|
||||
raw_policy = open(os.path.join(tmpdir, log_file), "r").readlines()
|
||||
for line in raw_policy:
|
||||
check_policy(json.loads(line))
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def testPostprocessingHook(self):
|
||||
def explore(new_config):
|
||||
new_config["id_factor"] = 42
|
||||
|
||||
Reference in New Issue
Block a user