mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 08:10:28 +08:00
[tune] Implement median stopping rule (#1170)
* trial scheduler interface * remove * wip median stopping * remove * median stopping rule * update * docs * update * Revrt * update * comments * fix tesT
This commit is contained in:
@@ -23,6 +23,7 @@ import yaml
|
||||
|
||||
import ray
|
||||
from ray.tune.config_parser import make_parser, parse_to_trials
|
||||
from ray.tune.trial_scheduler import MedianStoppingRule
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.trial import Trial
|
||||
|
||||
@@ -46,7 +47,7 @@ parser.add_argument("-f", "--config-file", default=None, type=str,
|
||||
|
||||
def main(argv):
|
||||
args = parser.parse_args(argv)
|
||||
runner = TrialRunner()
|
||||
runner = TrialRunner(MedianStoppingRule())
|
||||
|
||||
if args.config_file:
|
||||
with open(args.config_file) as f:
|
||||
|
||||
@@ -148,6 +148,7 @@ class TrialRunner(object):
|
||||
trial.last_result = result
|
||||
|
||||
if trial.should_stop(result):
|
||||
self._scheduler_alg.on_trial_complete(self, trial, result)
|
||||
self._stop_trial(trial)
|
||||
else:
|
||||
decision = self._scheduler_alg.on_trial_result(
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
import collections
|
||||
import numpy as np
|
||||
|
||||
from ray.tune.trial import Trial
|
||||
|
||||
|
||||
@@ -17,6 +20,13 @@ class TrialScheduler(object):
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
"""Notification for the completion of trial.
|
||||
|
||||
This will only be called when the trial completes naturally."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def choose_trial_to_run(self, trial_runner, trials):
|
||||
"""Called to choose a new trial to run.
|
||||
|
||||
@@ -32,9 +42,14 @@ class TrialScheduler(object):
|
||||
|
||||
|
||||
class FIFOScheduler(TrialScheduler):
|
||||
"""Simple scheduler that just runs trials in submission order."""
|
||||
|
||||
def on_trial_result(self, trial_runner, trial, result):
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
pass
|
||||
|
||||
def choose_trial_to_run(self, trial_runner):
|
||||
for trial in trial_runner.get_trials():
|
||||
if (trial.status == Trial.PENDING and
|
||||
@@ -44,3 +59,85 @@ class FIFOScheduler(TrialScheduler):
|
||||
|
||||
def debug_string(self):
|
||||
return "Using FIFO scheduling algorithm."
|
||||
|
||||
|
||||
# TODO(ekl) expose this in the command line API
|
||||
class MedianStoppingRule(FIFOScheduler):
    """Implements the median stopping rule as described in the Vizier paper:

    https://research.google.com/pubs/pub46180.html

    Args:
        time_attr (str): The TrainingResult attr to use for comparing time.
            Note that you can pass in something non-temporal such as
            `training_iteration` as a measure of progress, the only
            requirement is that the attribute should increase monotonically.
        reward_attr (str): The TrainingResult objective value attribute. As
            with `time_attr`, this may refer to any objective value that
            is supposed to increase with time.
        grace_period (float): Only stop trials at least this old in time.
            The units are the same as the attribute named by `time_attr`.
        min_samples_required (int): Min samples to compute median over.
    """

    def __init__(
            self, time_attr='time_total_s', reward_attr='episode_reward_mean',
            grace_period=60.0, min_samples_required=3):
        FIFOScheduler.__init__(self)
        # Trials that were reported through on_trial_complete(); only these
        # contribute to the median.
        self._completed_trials = set()
        # trial -> list of every result object reported for that trial.
        self._results = collections.defaultdict(list)
        self._grace_period = grace_period
        self._min_samples_required = min_samples_required
        self._reward_attr = reward_attr
        self._time_attr = time_attr
        # Number of trials this rule has stopped early (for debug_string).
        self._num_stopped = 0

    def on_trial_result(self, trial_runner, trial, result):
        """Callback for early stopping.

        This stopping rule stops a running trial if the trial's best objective
        value by step `t` is strictly worse than the median of the running
        averages of all completed trials' objectives reported up to step `t`.

        Args:
            trial_runner: The runner reporting the result (unused here).
            trial: The trial that produced `result`.
            result: Result object exposing the attributes named by
                `time_attr` and `reward_attr`.

        Returns:
            TrialScheduler.STOP if the trial should be stopped early,
            otherwise TrialScheduler.CONTINUE.
        """
        time = getattr(result, self._time_attr)
        self._results[trial].append(result)
        median_result = self._get_median_result(time)
        best_result = self._best_result(trial)
        print("Trial {} best res={} vs median res={} at t={}".format(
            trial, best_result, median_result, time))
        # median_result is -inf while there are too few samples, so the
        # strict comparison below can never trigger a stop in that case.
        if best_result < median_result and time > self._grace_period:
            print("MedianStoppingRule: early stopping {}".format(trial))
            self._num_stopped += 1
            return TrialScheduler.STOP
        else:
            return TrialScheduler.CONTINUE

    def on_trial_complete(self, trial_runner, trial, result):
        """Record the final result of `trial` and mark it as completed.

        From this point on the trial's results are included when computing
        the median used by `on_trial_result`.
        """
        self._results[trial].append(result)
        self._completed_trials.add(trial)

    def debug_string(self):
        """Return a one-line summary including the early-stop count."""
        return "Using MedianStoppingRule: num_stopped={}.".format(
            self._num_stopped)

    def _get_median_result(self, time):
        """Return the median running-average reward of completed trials.

        Only completed trials that reported at least one result at or before
        `time` are counted. A trial without data yet has no running average,
        and including it would turn the whole median into NaN, which silently
        disables the stopping rule for this step.

        Returns:
            The median score, or float('-inf') when fewer than
            `min_samples_required` completed trials have data by `time`.
        """
        scores = []
        for trial in self._completed_trials:
            score = self._running_result(trial, time)
            if score is not None:
                scores.append(score)
        if len(scores) >= self._min_samples_required:
            return np.median(scores)
        else:
            return float('-inf')

    def _running_result(self, trial, t_max=float('inf')):
        """Return the mean reward of `trial` over results up to `t_max`.

        Returns:
            The mean of the reward attribute over all recorded results whose
            time attribute is <= `t_max`, or None if there are no such
            results (instead of letting np.mean([]) produce NaN and a
            RuntimeWarning).
        """
        results = self._results[trial]
        # TODO(ekl) we could do interpolation to be more precise, but for now
        # assume len(results) is large and the time diffs are roughly equal
        rewards = [
            getattr(r, self._reward_attr)
            for r in results if getattr(r, self._time_attr) <= t_max]
        if not rewards:
            return None
        return np.mean(rewards)

    def _best_result(self, trial):
        """Return the highest reward recorded so far for `trial`.

        Callers must have recorded at least one result for `trial`;
        `on_trial_result` appends the current result before calling this.
        """
        results = self._results[trial]
        return max([getattr(r, self._reward_attr) for r in results])
|
||||
|
||||
Reference in New Issue
Block a user