[tune] Async Hyperband (#1595)

This commit is contained in:
Richard Liaw
2018-03-04 14:05:56 -08:00
committed by GitHub
parent ecb811c26e
commit 78716094b5
7 changed files with 338 additions and 3 deletions
+148
View File
@@ -0,0 +1,148 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
class AsyncHyperBandScheduler(FIFOScheduler):
    """Implements the Async Successive Halving.

    This should provide similar theoretical performance as HyperBand but
    avoid straggler issues that HyperBand faces. One implementation detail
    is that, when using multiple brackets, each new trial is assigned to a
    bracket at random, with probabilities given by a softmax over the
    number of rungs in each bracket.

    See https://openreview.net/forum?id=S1Y7OOlRZ

    Args:
        time_attr (str): The TrainingResult attr to use for comparing time.
            Note that you can pass in something non-temporal such as
            `training_iteration` as a measure of progress, the only
            requirement is that the attribute should increase monotonically.
        reward_attr (str): The TrainingResult objective value attribute. As
            with `time_attr`, this may refer to any objective value. Stopping
            procedures will use this attribute.
        max_t (float): max time units per trial. Trials will be stopped after
            max_t time units (determined by time_attr) have passed.
        grace_period (float): Only stop trials at least this old in time.
            The units are the same as the attribute named by `time_attr`.
        reduction_factor (float): Used to set halving rate and amount. This
            is simply a unit-less scalar.
        brackets (int): Number of brackets. Each bracket has a different
            halving rate, specified by the reduction factor.
    """

    def __init__(
            self, time_attr='training_iteration',
            reward_attr='episode_reward_mean', max_t=100,
            grace_period=10, reduction_factor=3, brackets=3):
        assert max_t > 0, "Max (time_attr) not valid!"
        assert max_t >= grace_period, "grace_period must be <= max_t!"
        assert grace_period > 0, "grace_period must be positive!"
        assert reduction_factor > 1, "Reduction Factor not valid!"
        assert brackets > 0, "brackets must be positive!"
        FIFOScheduler.__init__(self)
        self._reduction_factor = reduction_factor
        self._max_t = max_t

        # Maps trial_id -> the _Bracket that the trial was assigned to.
        self._trial_info = {}

        # One bracket per index s in [0, brackets); s is forwarded to
        # _Bracket as its last constructor argument.
        self._brackets = [_Bracket(
            grace_period, max_t, reduction_factor, s) for s in range(brackets)]
        # NOTE(review): not read anywhere in this class; kept only so the
        # attribute continues to exist.
        self._counter = 0
        # Number of STOP decisions returned from on_trial_result.
        self._num_stopped = 0
        self._reward_attr = reward_attr
        self._time_attr = time_attr

    def on_trial_add(self, trial_runner, trial):
        """Assign a newly added trial to a randomly chosen bracket."""
        sizes = np.array([len(b._rungs) for b in self._brackets])
        # Softmax over rung counts; subtracting the max keeps the
        # exponentials bounded by 1.
        probs = np.e ** (sizes - sizes.max())
        normalized = probs / probs.sum()
        idx = np.random.choice(len(self._brackets), p=normalized)
        self._trial_info[trial.trial_id] = self._brackets[idx]

    def on_trial_result(self, trial_runner, trial, result):
        """Decide whether the trial continues after reporting `result`.

        A trial whose time attribute has reached `max_t` is stopped
        outright; otherwise the decision is delegated to the trial's
        bracket. Every STOP decision, from either path, is counted in
        `_num_stopped` so that `debug_string` reports early-stopped trials
        as well as those that ran to `max_t`.
        """
        cur_time = getattr(result, self._time_attr)
        if cur_time >= self._max_t:
            action = TrialScheduler.STOP
        else:
            bracket = self._trial_info[trial.trial_id]
            action = bracket.on_result(
                trial, cur_time, getattr(result, self._reward_attr))
        if action == TrialScheduler.STOP:
            self._num_stopped += 1
        return action

    def on_trial_complete(self, trial_runner, trial, result):
        """Record the final result in the trial's bracket and forget it."""
        bracket = self._trial_info[trial.trial_id]
        # The returned action is ignored; the call is made so the bracket
        # can record the final reward.
        bracket.on_result(
            trial,
            getattr(result, self._time_attr),
            getattr(result, self._reward_attr))
        del self._trial_info[trial.trial_id]

    def on_trial_remove(self, trial_runner, trial):
        """Drop the bracket assignment of a removed trial."""
        del self._trial_info[trial.trial_id]

    def debug_string(self):
        """Return a multi-line summary: stop count plus one line per bracket."""
        out = "Using AsyncHyperBand: num_stopped={}".format(
            self._num_stopped)
        out += "\n" + "\n".join([b.debug_str() for b in self._brackets])
        return out
class _Bracket():
    """Bookkeeping system to track the cutoffs.

    Rungs are stored in reversed order (largest milestone first) so that we
    can more easily find the correct rung corresponding to the current
    iteration of the result. Each rung is a `(milestone, recorded)` pair
    where `recorded` maps trial_id -> reward reported at that rung.

    Example:
        >>> b = _Bracket(1, 10, 2, 0)
        >>> b.on_result(trial1, 1, 2)  # CONTINUE
        >>> b.on_result(trial2, 1, 4)  # CONTINUE
        >>> b.cutoff(b._rungs[-1][1]) == 3.0  # rungs are reversed
        >>> b.on_result(trial3, 1, 1)  # STOP
        >>> b.cutoff(b._rungs[-1][1]) == 2.0
    """

    def __init__(self, min_t, max_t, reduction_factor, s):
        """Create rungs at min_t * reduction_factor**e for e = s, s+1, ...

        Rungs are created for every exponent whose milestone does not
        exceed `max_t`; the result may be empty.

        Raises:
            ValueError: if `min_t` is not positive, `reduction_factor` is
                not greater than 1, or `max_t` is not finite (any of these
                would keep the milestone enumeration from terminating).
        """
        if min_t <= 0:
            raise ValueError("min_t must be positive, got {}".format(min_t))
        if reduction_factor <= 1:
            raise ValueError(
                "reduction_factor must be > 1, got {}".format(
                    reduction_factor))
        if not np.isfinite(max_t):
            raise ValueError("max_t must be finite, got {}".format(max_t))

        self.rf = reduction_factor

        # Enumerate milestones directly instead of deriving the rung count
        # from log(max_t / min_t) / log(rf): the float logarithm can land
        # just below an exact integer (e.g. for 1000 / 10 or 243 / 3), which
        # would silently drop the top rung.
        milestones = []
        exponent = s
        milestone = min_t * self.rf**exponent
        while milestone <= max_t:
            milestones.append(milestone)
            exponent += 1
            milestone = min_t * self.rf**exponent

        self._rungs = [(m, {}) for m in reversed(milestones)]

    def cutoff(self, recorded):
        """Return the reward threshold for a rung, or None if it is empty.

        The threshold is the (1 - 1/rf) * 100 percentile of the recorded
        rewards.
        """
        if not recorded:
            return None
        return np.percentile(list(recorded.values()), (1 - 1 / self.rf) * 100)

    def on_result(self, trial, cur_iter, cur_rew):
        """Record a result and return the scheduling action for the trial.

        Walks the rungs from the largest milestone down and picks the first
        one that the trial has reached and not yet been recorded in. The
        trial is stopped if its reward is below that rung's cutoff, which is
        computed before the trial's own reward is added. The reward is
        recorded in that rung either way.
        """
        action = TrialScheduler.CONTINUE
        for milestone, recorded in self._rungs:
            if cur_iter < milestone or trial.trial_id in recorded:
                continue
            cutoff = self.cutoff(recorded)
            if cutoff is not None and cur_rew < cutoff:
                action = TrialScheduler.STOP
            recorded[trial.trial_id] = cur_rew
            # Only one rung is evaluated per reported result.
            break
        return action

    def debug_str(self):
        """Return a one-line summary of each rung's milestone and cutoff."""
        iters = " | ".join(
            ["Iter {:.3f}: {}".format(milestone, self.cutoff(recorded))
             for milestone, recorded in self._rungs])
        return "Bracket: " + iters
if __name__ == '__main__':
    # Manual smoke check: build a scheduler, print its bracket layout, then
    # print the cutoff the first bracket computes for rewards 0..19.
    demo_scheduler = AsyncHyperBandScheduler(
        grace_period=1, max_t=10, reduction_factor=2)
    print(demo_scheduler.debug_string())
    first_bracket = demo_scheduler._brackets[0]
    sample_rewards = {str(i): i for i in range(20)}
    print(first_bracket.cutoff(sample_rewards))
+1 -1
View File
@@ -80,7 +80,7 @@ def make_parser(**kwargs):
"many times. Only applies if checkpointing is enabled.")
parser.add_argument(
"--scheduler", default="FIFO", type=str,
help="FIFO (default), MedianStopping, or HyperBand.")
help="FIFO (default), MedianStopping, AsyncHyperBand, or HyperBand.")
parser.add_argument(
"--scheduler-config", default="{}", type=json.loads,
help="Config options to pass to the scheduler.")
@@ -0,0 +1,77 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import json
import os
import random
import numpy as np
import ray
from ray.tune import Trainable, TrainingResult, register_trainable, \
run_experiments
from ray.tune.async_hyperband import AsyncHyperBandScheduler
class MyTrainableClass(Trainable):
    """Example agent whose reward follows a scaled tanh of the step count.

    The dummy hyperparameters "width" and "height" control how quickly the
    curve rises and the maximum reward value it approaches.
    """

    def _setup(self):
        # Number of _train calls made so far.
        self.timestep = 0

    def _train(self):
        """Advance one step and report the reward for that step."""
        self.timestep += 1
        progress = float(self.timestep) / self.config["width"]
        reward = np.tanh(progress) * self.config["height"]

        # Here we use `episode_reward_mean`, but you can also report other
        # objectives such as loss or accuracy (see tune/result.py).
        return TrainingResult(
            episode_reward_mean=reward, timesteps_this_iter=1)

    def _save(self, checkpoint_dir):
        """Write the step counter as JSON and return the file's path."""
        checkpoint_file = os.path.join(checkpoint_dir, "checkpoint")
        with open(checkpoint_file, "w") as out:
            payload = json.dumps({"timestep": self.timestep})
            out.write(payload)
        return checkpoint_file

    def _restore(self, checkpoint_path):
        """Load the step counter from a file written by `_save`."""
        with open(checkpoint_path) as src:
            state = json.loads(src.read())
            self.timestep = state["timestep"]
register_trainable("my_class", MyTrainableClass)


if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    # Unknown arguments are tolerated and discarded.
    args, _ = cli.parse_known_args()
    ray.init()

    # Asynchronous HyperBand early stopping, configured with
    # `episode_reward_mean` as the objective and `timesteps_total` as the
    # time unit.
    scheduler = AsyncHyperBandScheduler(
        time_attr="timesteps_total", reward_attr="episode_reward_mean",
        grace_period=5, max_t=100)

    # A smoke test stops every trial after a single training iteration.
    max_iterations = 1 if args.smoke_test else 99999
    experiment_spec = {
        "run": "my_class",
        "stop": {"training_iteration": max_iterations},
        "repeat": 20,
        "resources": {"cpu": 1, "gpu": 0},
        "config": {
            # Both hyperparameters are drawn at random via these callables.
            "width": lambda spec: 10 + int(90 * random.random()),
            "height": lambda spec: int(100 * random.random()),
        },
    }
    run_experiments(
        {"asynchyperband_test": experiment_spec}, scheduler=scheduler)
@@ -8,6 +8,7 @@ import unittest
import numpy as np
from ray.tune.hyperband import HyperBandScheduler
from ray.tune.async_hyperband import AsyncHyperBandScheduler
from ray.tune.pbt import PopulationBasedTraining, explore
from ray.tune.median_stopping_rule import MedianStoppingRule
from ray.tune.result import TrainingResult
@@ -757,5 +758,105 @@ class PopulationBasedTestingSuite(unittest.TestCase):
self.assertEqual(trials[0].config["float_factor"], 43)
class AsyncHyperBandSuite(unittest.TestCase):
    """Checks the STOP/CONTINUE decisions of AsyncHyperBandScheduler."""

    def basicSetup(self, scheduler):
        """Add two trials and feed each a run of results that must CONTINUE.

        The first trial reports rewards 0, 100, ..., 900 at times 0..9; the
        second reports a constant 450 at times 0..4. Returns both trials.
        """
        first = Trial("PPO")
        second = Trial("PPO")
        for trial in (first, second):
            scheduler.on_trial_add(None, trial)

        for step in range(10):
            decision = scheduler.on_trial_result(
                None, first, result(step, step * 100))
            self.assertEqual(decision, TrialScheduler.CONTINUE)

        for step in range(5):
            decision = scheduler.on_trial_result(
                None, second, result(step, 450))
            self.assertEqual(decision, TrialScheduler.CONTINUE)

        return first, second

    def testAsyncHBOnComplete(self):
        scheduler = AsyncHyperBandScheduler(
            max_t=10, brackets=1)
        _, second = self.basicSetup(scheduler)

        # A third trial completes immediately with a high reward.
        newcomer = Trial("PPO")
        scheduler.on_trial_add(None, newcomer)
        scheduler.on_trial_complete(None, newcomer, result(10, 1000))

        decision = scheduler.on_trial_result(None, second, result(101, 0))
        self.assertEqual(decision, TrialScheduler.STOP)

    def testAsyncHBGracePeriod(self):
        scheduler = AsyncHyperBandScheduler(
            grace_period=2.5, reduction_factor=3, brackets=1)
        first, second = self.basicSetup(scheduler)
        for finished in (first, second):
            scheduler.on_trial_complete(None, finished, result(10, 1000))

        newcomer = Trial("PPO")
        scheduler.on_trial_add(None, newcomer)

        # Times 1 and 2 are inside the grace period; time 3 is past it.
        expectations = (
            (1, TrialScheduler.CONTINUE),
            (2, TrialScheduler.CONTINUE),
            (3, TrialScheduler.STOP),
        )
        for step, expected in expectations:
            decision = scheduler.on_trial_result(
                None, newcomer, result(step, 10))
            self.assertEqual(decision, expected)

    def testAsyncHBAllCompletes(self):
        scheduler = AsyncHyperBandScheduler(
            max_t=10, brackets=10)
        trials = [Trial("PPO") for _ in range(10)]
        for trial in trials:
            scheduler.on_trial_add(None, trial)

        # Every trial reports at time == max_t and must be stopped.
        for trial in trials:
            decision = scheduler.on_trial_result(None, trial, result(10, -2))
            self.assertEqual(decision, TrialScheduler.STOP)

    def testAsyncHBUsesPercentile(self):
        scheduler = AsyncHyperBandScheduler(
            grace_period=1, max_t=10, reduction_factor=2, brackets=1)
        first, second = self.basicSetup(scheduler)
        for finished in (first, second):
            scheduler.on_trial_complete(None, finished, result(10, 1000))

        newcomer = Trial("PPO")
        scheduler.on_trial_add(None, newcomer)

        # A reward of 260 must be stopped at both time 1 and time 2.
        for step in (1, 2):
            decision = scheduler.on_trial_result(
                None, newcomer, result(step, 260))
            self.assertEqual(decision, TrialScheduler.STOP)

    def testAlternateMetrics(self):
        def result2(t, rew):
            return TrainingResult(training_iteration=t, neg_mean_loss=rew)

        scheduler = AsyncHyperBandScheduler(
            grace_period=1, time_attr='training_iteration',
            reward_attr='neg_mean_loss', brackets=1)

        # Same warm-up as basicSetup, but reported through the alternate
        # time and reward attributes.
        first = Trial("PPO")
        second = Trial("PPO")
        for trial in (first, second):
            scheduler.on_trial_add(None, trial)

        for step in range(10):
            decision = scheduler.on_trial_result(
                None, first, result2(step, step * 100))
            self.assertEqual(decision, TrialScheduler.CONTINUE)

        for step in range(5):
            decision = scheduler.on_trial_result(
                None, second, result2(step, 450))
            self.assertEqual(decision, TrialScheduler.CONTINUE)

        scheduler.on_trial_complete(None, first, result2(10, 1000))

        for step, reward in ((5, 450), (6, 0)):
            decision = scheduler.on_trial_result(
                None, second, result2(step, reward))
            self.assertEqual(decision, TrialScheduler.CONTINUE)
if __name__ == "__main__":
unittest.main(verbosity=2)
+2
View File
@@ -6,6 +6,7 @@ import time
from ray.tune import TuneError
from ray.tune.hyperband import HyperBandScheduler
from ray.tune.async_hyperband import AsyncHyperBandScheduler
from ray.tune.median_stopping_rule import MedianStoppingRule
from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
from ray.tune.log_sync import wait_for_log_sync
@@ -19,6 +20,7 @@ _SCHEDULERS = {
"FIFO": FIFOScheduler,
"MedianStopping": MedianStoppingRule,
"HyperBand": HyperBandScheduler,
"AsyncHyperBand": AsyncHyperBandScheduler,
}