[tune] Fix Median Stopping Rule Verbosity (#1833)

This commit is contained in:
Richard Liaw
2018-04-06 22:58:13 -07:00
committed by GitHub
parent bef1d872b4
commit bc8f62c947
2 changed files with 19 additions and 11 deletions
+9 -7
View File
@@ -66,15 +66,17 @@ class AsyncHyperBandScheduler(FIFOScheduler):
self._trial_info[trial.trial_id] = self._brackets[idx]
def on_trial_result(self, trial_runner, trial, result):
action = TrialScheduler.CONTINUE
if getattr(result, self._time_attr) >= self._max_t:
action = TrialScheduler.STOP
else:
bracket = self._trial_info[trial.trial_id]
action = bracket.on_result(
trial,
getattr(result, self._time_attr),
getattr(result, self._reward_attr))
if action == TrialScheduler.STOP:
self._num_stopped += 1
return TrialScheduler.STOP
bracket = self._trial_info[trial.trial_id]
action = bracket.on_result(
trial,
getattr(result, self._time_attr),
getattr(result, self._reward_attr))
return action
def on_trial_complete(self, trial_runner, trial, result):
+10 -4
View File
@@ -28,11 +28,14 @@ class MedianStoppingRule(FIFOScheduler):
hard_stop (bool): If False, pauses trials instead of stopping
them. When all other trials are complete, paused trials will be
resumed and allowed to run FIFO.
verbose (bool): If True, will output the median and best result each
time a trial reports. Defaults to True.
"""
def __init__(
self, time_attr="time_total_s", reward_attr="episode_reward_mean",
grace_period=60.0, min_samples_required=3, hard_stop=True):
grace_period=60.0, min_samples_required=3,
hard_stop=True, verbose=True):
FIFOScheduler.__init__(self)
self._stopped_trials = set()
self._completed_trials = set()
@@ -42,6 +45,7 @@ class MedianStoppingRule(FIFOScheduler):
self._reward_attr = reward_attr
self._time_attr = time_attr
self._hard_stop = hard_stop
self._verbose = verbose
def on_trial_result(self, trial_runner, trial, result):
"""Callback for early stopping.
@@ -59,10 +63,12 @@ class MedianStoppingRule(FIFOScheduler):
self._results[trial].append(result)
median_result = self._get_median_result(time)
best_result = self._best_result(trial)
print("Trial {} best res={} vs median res={} at t={}".format(
trial, best_result, median_result, time))
if self._verbose:
print("Trial {} best res={} vs median res={} at t={}".format(
trial, best_result, median_result, time))
if best_result < median_result and time > self._grace_period:
print("MedianStoppingRule: early stopping {}".format(trial))
if self._verbose:
print("MedianStoppingRule: early stopping {}".format(trial))
self._stopped_trials.add(trial)
if self._hard_stop:
return TrialScheduler.STOP