mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 22:53:49 +08:00
[tune] Async Hyperband (#1595)
This commit is contained in:
@@ -0,0 +1,148 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
|
||||
|
||||
class AsyncHyperBandScheduler(FIFOScheduler):
|
||||
"""Implements the Async Successive Halving.
|
||||
|
||||
This should provide similar theoretical performance as HyperBand but
|
||||
avoid straggler issues that HyperBand faces. One implementation detail
|
||||
is when using multiple brackets, trial allocation to bracket is done
|
||||
randomly with over a softmax probability.
|
||||
|
||||
See https://openreview.net/forum?id=S1Y7OOlRZ
|
||||
|
||||
Args:
|
||||
time_attr (str): The TrainingResult attr to use for comparing time.
|
||||
Note that you can pass in something non-temporal such as
|
||||
`training_iteration` as a measure of progress, the only requirement
|
||||
is that the attribute should increase monotonically.
|
||||
reward_attr (str): The TrainingResult objective value attribute. As
|
||||
with `time_attr`, this may refer to any objective value. Stopping
|
||||
procedures will use this attribute.
|
||||
max_t (float): max time units per trial. Trials will be stopped after
|
||||
max_t time units (determined by time_attr) have passed.
|
||||
grace_period (float): Only stop trials at least this old in time.
|
||||
The units are the same as the attribute named by `time_attr`.
|
||||
reduction_factor (float): Used to set halving rate and amount. This
|
||||
is simply a unit-less scalar.
|
||||
brackets (int): Number of brackets. Each bracket has a different
|
||||
halving rate, specified by the reduction factor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, time_attr='training_iteration',
|
||||
reward_attr='episode_reward_mean', max_t=100,
|
||||
grace_period=10, reduction_factor=3, brackets=3):
|
||||
assert max_t > 0, "Max (time_attr) not valid!"
|
||||
assert max_t >= grace_period, "grace_period must be <= max_t!"
|
||||
assert grace_period > 0, "grace_period must be positive!"
|
||||
assert reduction_factor > 1, "Reduction Factor not valid!"
|
||||
assert brackets > 0, "brackets must be positive!"
|
||||
FIFOScheduler.__init__(self)
|
||||
self._reduction_factor = reduction_factor
|
||||
self._max_t = max_t
|
||||
|
||||
self._trial_info = {} # Stores Trial -> Bracket
|
||||
|
||||
# Tracks state for new trial add
|
||||
self._brackets = [_Bracket(
|
||||
grace_period, max_t, reduction_factor, s) for s in range(brackets)]
|
||||
self._counter = 0 # for
|
||||
self._num_stopped = 0
|
||||
self._reward_attr = reward_attr
|
||||
self._time_attr = time_attr
|
||||
|
||||
def on_trial_add(self, trial_runner, trial):
|
||||
sizes = np.array([len(b._rungs) for b in self._brackets])
|
||||
probs = np.e ** (sizes - sizes.max())
|
||||
normalized = probs / probs.sum()
|
||||
idx = np.random.choice(len(self._brackets), p=normalized)
|
||||
self._trial_info[trial.trial_id] = self._brackets[idx]
|
||||
|
||||
def on_trial_result(self, trial_runner, trial, result):
|
||||
if getattr(result, self._time_attr) >= self._max_t:
|
||||
self._num_stopped += 1
|
||||
return TrialScheduler.STOP
|
||||
|
||||
bracket = self._trial_info[trial.trial_id]
|
||||
action = bracket.on_result(
|
||||
trial,
|
||||
getattr(result, self._time_attr),
|
||||
getattr(result, self._reward_attr))
|
||||
return action
|
||||
|
||||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
bracket = self._trial_info[trial.trial_id]
|
||||
bracket.on_result(
|
||||
trial,
|
||||
getattr(result, self._time_attr),
|
||||
getattr(result, self._reward_attr))
|
||||
del self._trial_info[trial.trial_id]
|
||||
|
||||
def on_trial_remove(self, trial_runner, trial):
|
||||
del self._trial_info[trial.trial_id]
|
||||
|
||||
def debug_string(self):
|
||||
out = "Using AsyncHyperBand: num_stopped={}".format(
|
||||
self._num_stopped)
|
||||
out += "\n" + "\n".join([b.debug_str() for b in self._brackets])
|
||||
return out
|
||||
|
||||
|
||||
class _Bracket():
|
||||
"""Bookkeeping system to track the cutoffs.
|
||||
|
||||
Rungs are created in reversed order so that we can more easily find
|
||||
the correct rung corresponding to the current iteration of the result.
|
||||
|
||||
Example:
|
||||
>>> b = _Bracket(1, 10, 2, 3)
|
||||
>>> b.on_result(trial1, 1, 2) # CONTINUE
|
||||
>>> b.on_result(trial2, 1, 4) # CONTINUE
|
||||
>>> b.cutoff(b._rungs[-1][1]) == 3.0 # rungs are reversed
|
||||
>>> b.on_result(trial3, 1, 1) # STOP
|
||||
>>> b.cutoff(b._rungs[0][1]) == 2.0
|
||||
"""
|
||||
def __init__(self, min_t, max_t, reduction_factor, s):
|
||||
self.rf = reduction_factor
|
||||
MAX_RUNGS = int(np.log(max_t / min_t) / np.log(self.rf) - s + 1)
|
||||
self._rungs = [(min_t * self.rf**(k + s), {})
|
||||
for k in reversed(range(MAX_RUNGS))]
|
||||
|
||||
def cutoff(self, recorded):
|
||||
if not recorded:
|
||||
return None
|
||||
return np.percentile(list(recorded.values()), (1 - 1 / self.rf) * 100)
|
||||
|
||||
def on_result(self, trial, cur_iter, cur_rew):
|
||||
action = TrialScheduler.CONTINUE
|
||||
for milestone, recorded in self._rungs:
|
||||
if cur_iter < milestone or trial.trial_id in recorded:
|
||||
continue
|
||||
else:
|
||||
cutoff = self.cutoff(recorded)
|
||||
if cutoff is not None and cur_rew < cutoff:
|
||||
action = TrialScheduler.STOP
|
||||
recorded[trial.trial_id] = cur_rew
|
||||
break
|
||||
return action
|
||||
|
||||
def debug_str(self):
|
||||
iters = " | ".join(
|
||||
["Iter {:.3f}: {}".format(milestone, self.cutoff(recorded))
|
||||
for milestone, recorded in self._rungs])
|
||||
return "Bracket: " + iters
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sched = AsyncHyperBandScheduler(
|
||||
grace_period=1, max_t=10, reduction_factor=2)
|
||||
print(sched.debug_string())
|
||||
bracket = sched._brackets[0]
|
||||
print(bracket.cutoff({str(i): i for i in range(20)}))
|
||||
@@ -80,7 +80,7 @@ def make_parser(**kwargs):
|
||||
"many times. Only applies if checkpointing is enabled.")
|
||||
parser.add_argument(
|
||||
"--scheduler", default="FIFO", type=str,
|
||||
help="FIFO (default), MedianStopping, or HyperBand.")
|
||||
help="FIFO (default), MedianStopping, AsyncHyperBand, or HyperBand.")
|
||||
parser.add_argument(
|
||||
"--scheduler-config", default="{}", type=json.loads,
|
||||
help="Config options to pass to the scheduler.")
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.tune import Trainable, TrainingResult, register_trainable, \
|
||||
run_experiments
|
||||
from ray.tune.async_hyperband import AsyncHyperBandScheduler
|
||||
|
||||
|
||||
class MyTrainableClass(Trainable):
|
||||
"""Example agent whose learning curve is a random sigmoid.
|
||||
|
||||
The dummy hyperparameters "width" and "height" determine the slope and
|
||||
maximum reward value reached.
|
||||
"""
|
||||
|
||||
def _setup(self):
|
||||
self.timestep = 0
|
||||
|
||||
def _train(self):
|
||||
self.timestep += 1
|
||||
v = np.tanh(float(self.timestep) / self.config["width"])
|
||||
v *= self.config["height"]
|
||||
|
||||
# Here we use `episode_reward_mean`, but you can also report other
|
||||
# objectives such as loss or accuracy (see tune/result.py).
|
||||
return TrainingResult(episode_reward_mean=v, timesteps_this_iter=1)
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps({"timestep": self.timestep}))
|
||||
return path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
with open(checkpoint_path) as f:
|
||||
self.timestep = json.loads(f.read())["timestep"]
|
||||
|
||||
|
||||
register_trainable("my_class", MyTrainableClass)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init()
|
||||
|
||||
# asynchronous hyperband early stopping, configured with
|
||||
# `episode_reward_mean` as the
|
||||
# objective and `timesteps_total` as the time unit.
|
||||
ahb = AsyncHyperBandScheduler(
|
||||
time_attr="timesteps_total", reward_attr="episode_reward_mean",
|
||||
grace_period=5, max_t=100)
|
||||
|
||||
run_experiments({
|
||||
"asynchyperband_test": {
|
||||
"run": "my_class",
|
||||
"stop": {"training_iteration": 1 if args.smoke_test else 99999},
|
||||
"repeat": 20,
|
||||
"resources": {"cpu": 1, "gpu": 0},
|
||||
"config": {
|
||||
"width": lambda spec: 10 + int(90 * random.random()),
|
||||
"height": lambda spec: int(100 * random.random()),
|
||||
},
|
||||
}
|
||||
}, scheduler=ahb)
|
||||
@@ -8,6 +8,7 @@ import unittest
|
||||
import numpy as np
|
||||
|
||||
from ray.tune.hyperband import HyperBandScheduler
|
||||
from ray.tune.async_hyperband import AsyncHyperBandScheduler
|
||||
from ray.tune.pbt import PopulationBasedTraining, explore
|
||||
from ray.tune.median_stopping_rule import MedianStoppingRule
|
||||
from ray.tune.result import TrainingResult
|
||||
@@ -757,5 +758,105 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
self.assertEqual(trials[0].config["float_factor"], 43)
|
||||
|
||||
|
||||
class AsyncHyperBandSuite(unittest.TestCase):
|
||||
def basicSetup(self, scheduler):
|
||||
t1 = Trial("PPO") # mean is 450, max 900, t_max=10
|
||||
t2 = Trial("PPO") # mean is 450, max 450, t_max=5
|
||||
scheduler.on_trial_add(None, t1)
|
||||
scheduler.on_trial_add(None, t2)
|
||||
for i in range(10):
|
||||
self.assertEqual(
|
||||
scheduler.on_trial_result(None, t1, result(i, i * 100)),
|
||||
TrialScheduler.CONTINUE)
|
||||
for i in range(5):
|
||||
self.assertEqual(
|
||||
scheduler.on_trial_result(None, t2, result(i, 450)),
|
||||
TrialScheduler.CONTINUE)
|
||||
return t1, t2
|
||||
|
||||
def testAsyncHBOnComplete(self):
|
||||
scheduler = AsyncHyperBandScheduler(
|
||||
max_t=10, brackets=1)
|
||||
t1, t2 = self.basicSetup(scheduler)
|
||||
t3 = Trial("PPO")
|
||||
scheduler.on_trial_add(None, t3)
|
||||
scheduler.on_trial_complete(None, t3, result(10, 1000))
|
||||
self.assertEqual(
|
||||
scheduler.on_trial_result(None, t2, result(101, 0)),
|
||||
TrialScheduler.STOP)
|
||||
|
||||
def testAsyncHBGracePeriod(self):
|
||||
scheduler = AsyncHyperBandScheduler(
|
||||
grace_period=2.5, reduction_factor=3, brackets=1)
|
||||
t1, t2 = self.basicSetup(scheduler)
|
||||
scheduler.on_trial_complete(None, t1, result(10, 1000))
|
||||
scheduler.on_trial_complete(None, t2, result(10, 1000))
|
||||
t3 = Trial("PPO")
|
||||
scheduler.on_trial_add(None, t3)
|
||||
self.assertEqual(
|
||||
scheduler.on_trial_result(None, t3, result(1, 10)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
scheduler.on_trial_result(None, t3, result(2, 10)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
scheduler.on_trial_result(None, t3, result(3, 10)),
|
||||
TrialScheduler.STOP)
|
||||
|
||||
def testAsyncHBAllCompletes(self):
|
||||
scheduler = AsyncHyperBandScheduler(
|
||||
max_t=10, brackets=10)
|
||||
trials = [Trial("PPO") for i in range(10)]
|
||||
for t in trials:
|
||||
scheduler.on_trial_add(None, t)
|
||||
|
||||
for t in trials:
|
||||
self.assertEqual(
|
||||
scheduler.on_trial_result(None, t, result(10, -2)),
|
||||
TrialScheduler.STOP)
|
||||
|
||||
def testAsyncHBUsesPercentile(self):
|
||||
scheduler = AsyncHyperBandScheduler(
|
||||
grace_period=1, max_t=10, reduction_factor=2, brackets=1)
|
||||
t1, t2 = self.basicSetup(scheduler)
|
||||
scheduler.on_trial_complete(None, t1, result(10, 1000))
|
||||
scheduler.on_trial_complete(None, t2, result(10, 1000))
|
||||
t3 = Trial("PPO")
|
||||
scheduler.on_trial_add(None, t3)
|
||||
self.assertEqual(
|
||||
scheduler.on_trial_result(None, t3, result(1, 260)),
|
||||
TrialScheduler.STOP)
|
||||
self.assertEqual(
|
||||
scheduler.on_trial_result(None, t3, result(2, 260)),
|
||||
TrialScheduler.STOP)
|
||||
|
||||
def testAlternateMetrics(self):
|
||||
def result2(t, rew):
|
||||
return TrainingResult(training_iteration=t, neg_mean_loss=rew)
|
||||
|
||||
scheduler = AsyncHyperBandScheduler(
|
||||
grace_period=1, time_attr='training_iteration',
|
||||
reward_attr='neg_mean_loss', brackets=1)
|
||||
t1 = Trial("PPO") # mean is 450, max 900, t_max=10
|
||||
t2 = Trial("PPO") # mean is 450, max 450, t_max=5
|
||||
scheduler.on_trial_add(None, t1)
|
||||
scheduler.on_trial_add(None, t2)
|
||||
for i in range(10):
|
||||
self.assertEqual(
|
||||
scheduler.on_trial_result(None, t1, result2(i, i * 100)),
|
||||
TrialScheduler.CONTINUE)
|
||||
for i in range(5):
|
||||
self.assertEqual(
|
||||
scheduler.on_trial_result(None, t2, result2(i, 450)),
|
||||
TrialScheduler.CONTINUE)
|
||||
scheduler.on_trial_complete(None, t1, result2(10, 1000))
|
||||
self.assertEqual(
|
||||
scheduler.on_trial_result(None, t2, result2(5, 450)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
scheduler.on_trial_result(None, t2, result2(6, 0)),
|
||||
TrialScheduler.CONTINUE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -6,6 +6,7 @@ import time
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.hyperband import HyperBandScheduler
|
||||
from ray.tune.async_hyperband import AsyncHyperBandScheduler
|
||||
from ray.tune.median_stopping_rule import MedianStoppingRule
|
||||
from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL
|
||||
from ray.tune.log_sync import wait_for_log_sync
|
||||
@@ -19,6 +20,7 @@ _SCHEDULERS = {
|
||||
"FIFO": FIFOScheduler,
|
||||
"MedianStopping": MedianStoppingRule,
|
||||
"HyperBand": HyperBandScheduler,
|
||||
"AsyncHyperBand": AsyncHyperBandScheduler,
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user