[Tune] PBT Error if metric not available (#9957)

This commit is contained in:
Amog Kamsetty
2020-08-17 16:12:14 -07:00
committed by GitHub
parent 4b14bf85e4
commit d3bac298d5
3 changed files with 83 additions and 5 deletions
+40 -2
View File
@@ -13,6 +13,8 @@ from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.suggest.variant_generator import format_vars
from ray.tune.trial import Trial, Checkpoint
from ray.util.debug import log_once
logger = logging.getLogger(__name__)
@@ -148,6 +150,9 @@ class PopulationBasedTraining(FIFOScheduler):
log_config (bool): Whether to log the ray config of each model to
local_dir at each exploit. Allows config schedule to be
reconstructed.
require_attrs (bool): Whether to require time_attr and metric to appear
in result for every iteration. If True, error will be raised
if these values are not present in trial result.
.. code-block:: python
@@ -182,7 +187,8 @@ class PopulationBasedTraining(FIFOScheduler):
quantile_fraction=0.25,
resample_probability=0.25,
custom_explore_fn=None,
log_config=True):
log_config=True,
require_attrs=True):
for value in hyperparam_mutations.values():
if not (isinstance(value, (list, dict)) or callable(value)):
raise TypeError("`hyperparam_mutation` values must be either "
@@ -222,6 +228,7 @@ class PopulationBasedTraining(FIFOScheduler):
self._trial_state = {}
self._custom_explore_fn = custom_explore_fn
self._log_config = log_config
self._require_attrs = require_attrs
# Metrics
self._num_checkpoints = 0
@@ -231,8 +238,39 @@ class PopulationBasedTraining(FIFOScheduler):
self._trial_state[trial] = PBTTrialState(trial)
def on_trial_result(self, trial_runner, trial, result):
if self._time_attr not in result or self._metric not in result:
if self._time_attr not in result:
time_missing_msg = "Cannot find time_attr {} " \
"in trial result {}. Make sure that this " \
"attribute is returned in the " \
"results of your Trainable.".format(
self._time_attr, result)
if self._require_attrs:
raise RuntimeError(
time_missing_msg +
"If this error is expected, you can change this to "
"a warning message by "
"setting PBT(require_attrs=False)")
else:
if log_once("pbt-time_attr-error"):
logger.warning(time_missing_msg)
if self._metric not in result:
metric_missing_msg = "Cannot find metric {} in trial result {}. " \
"Make sure that this attribute is returned " \
"in the " \
"results of your Trainable.".format(
self._metric, result)
if self._require_attrs:
raise RuntimeError(
metric_missing_msg + "If this error is expected, "
"you can change this to a warning message by "
"setting PBT(require_attrs=False)")
else:
if log_once("pbt-metric-error"):
logger.warning(metric_missing_msg)
if self._metric not in result or self._time_attr not in result:
return TrialScheduler.CONTINUE
time = result[self._time_attr]
state = self._trial_state[trial]
+41 -1
View File
@@ -716,6 +716,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
explore=None,
perturbation_interval=10,
log_config=False,
require_attrs=True,
hyperparams=None,
hyperparam_mutations=None,
step_once=True):
@@ -731,7 +732,8 @@ class PopulationBasedTestingSuite(unittest.TestCase):
quantile_fraction=0.25,
hyperparam_mutations=hyperparam_mutations,
custom_explore_fn=explore,
log_config=log_config)
log_config=log_config,
require_attrs=require_attrs)
runner = _MockTrialRunner(pbt)
for i in range(num_trials):
trial_hyperparams = hyperparams or {
@@ -750,6 +752,44 @@ class PopulationBasedTestingSuite(unittest.TestCase):
pbt.reset_stats()
return pbt, runner
def testMetricError(self):
pbt, runner = self.basicSetup()
trials = runner.get_trials()
# Should error if training_iteration not in result dict.
with self.assertRaises(RuntimeError):
pbt.on_trial_result(
runner, trials[0], result={"episode_reward_mean": 4})
# Should error if episode_reward_mean not in result dict.
with self.assertRaises(RuntimeError):
pbt.on_trial_result(
runner,
trials[0],
result={
"random_metric": 10,
"training_iteration": 20
})
def testMetricLog(self):
pbt, runner = self.basicSetup(require_attrs=False)
trials = runner.get_trials()
# Should not error if training_iteration not in result dict
with self.assertLogs("ray.tune.schedulers.pbt", level="WARN"):
pbt.on_trial_result(
runner, trials[0], result={"episode_reward_mean": 4})
# Should not error if episode_reward_mean not in result dict.
with self.assertLogs("ray.tune.schedulers.pbt", level="WARN"):
pbt.on_trial_result(
runner,
trials[0],
result={
"random_metric": 10,
"training_iteration": 20
})
def testCheckpointsMostPromisingTrials(self):
pbt, runner = self.basicSetup()
trials = runner.get_trials()
@@ -99,14 +99,14 @@ class PopulationBasedTrainingResumeTest(unittest.TestCase):
"c": 1
},
fail_fast=True,
num_samples=20,
num_samples=4,
checkpoint_freq=1,
checkpoint_at_end=True,
keep_checkpoints_num=1,
checkpoint_score_attr="min-training_iteration",
scheduler=scheduler,
name="testPermutationContinuation",
stop={"training_iteration": 5})
stop={"training_iteration": 3})
def testPermutationContinuationFunc(self):
scheduler = PopulationBasedTraining(