From d3bac298d5af73df0a1fc71d5e8f0c598dd21342 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Mon, 17 Aug 2020 16:12:14 -0700 Subject: [PATCH] [Tune] PBT Error if metric not available (#9957) --- python/ray/tune/schedulers/pbt.py | 42 ++++++++++++++++++- python/ray/tune/tests/test_trial_scheduler.py | 42 ++++++++++++++++++- .../tune/tests/test_trial_scheduler_pbt.py | 4 +- 3 files changed, 83 insertions(+), 5 deletions(-) diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index a8632de3d..6711f0817 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -13,6 +13,8 @@ from ray.tune.schedulers import FIFOScheduler, TrialScheduler from ray.tune.suggest.variant_generator import format_vars from ray.tune.trial import Trial, Checkpoint +from ray.util.debug import log_once + logger = logging.getLogger(__name__) @@ -148,6 +150,9 @@ class PopulationBasedTraining(FIFOScheduler): log_config (bool): Whether to log the ray config of each model to local_dir at each exploit. Allows config schedule to be reconstructed. + require_attrs (bool): Whether to require time_attr and metric to appear + in result for every iteration. If True, an error will be raised + if these values are not present in the trial result. .. 
code-block:: python @@ -182,7 +187,8 @@ class PopulationBasedTraining(FIFOScheduler): quantile_fraction=0.25, resample_probability=0.25, custom_explore_fn=None, - log_config=True): + log_config=True, + require_attrs=True): for value in hyperparam_mutations.values(): if not (isinstance(value, (list, dict)) or callable(value)): raise TypeError("`hyperparam_mutation` values must be either " @@ -222,6 +228,7 @@ class PopulationBasedTraining(FIFOScheduler): self._trial_state = {} self._custom_explore_fn = custom_explore_fn self._log_config = log_config + self._require_attrs = require_attrs # Metrics self._num_checkpoints = 0 @@ -231,8 +238,39 @@ class PopulationBasedTraining(FIFOScheduler): self._trial_state[trial] = PBTTrialState(trial) def on_trial_result(self, trial_runner, trial, result): - if self._time_attr not in result or self._metric not in result: + if self._time_attr not in result: + time_missing_msg = "Cannot find time_attr {} " \ + "in trial result {}. Make sure that this " \ + "attribute is returned in the " \ + "results of your Trainable.".format( + self._time_attr, result) + if self._require_attrs: + raise RuntimeError( + time_missing_msg + + " If this error is expected, you can change this to " + "a warning message by " + "setting PBT(require_attrs=False)") + else: + if log_once("pbt-time_attr-error"): + logger.warning(time_missing_msg) + if self._metric not in result: + metric_missing_msg = "Cannot find metric {} in trial result {}. 
" \ + "Make sure that this attribute is returned " \ + "in the " \ + "results of your Trainable.".format( + self._metric, result) + if self._require_attrs: + raise RuntimeError( + metric_missing_msg + " If this error is expected, " + "you can change this to a warning message by " + "setting PBT(require_attrs=False)") + else: + if log_once("pbt-metric-error"): + logger.warning(metric_missing_msg) + + if self._metric not in result or self._time_attr not in result: return TrialScheduler.CONTINUE + time = result[self._time_attr] state = self._trial_state[trial] diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index ad3d7f298..32a5713dd 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -716,6 +716,7 @@ class PopulationBasedTestingSuite(unittest.TestCase): explore=None, perturbation_interval=10, log_config=False, + require_attrs=True, hyperparams=None, hyperparam_mutations=None, step_once=True): @@ -731,7 +732,8 @@ class PopulationBasedTestingSuite(unittest.TestCase): quantile_fraction=0.25, hyperparam_mutations=hyperparam_mutations, custom_explore_fn=explore, - log_config=log_config) + log_config=log_config, + require_attrs=require_attrs) runner = _MockTrialRunner(pbt) for i in range(num_trials): trial_hyperparams = hyperparams or { @@ -750,6 +752,44 @@ class PopulationBasedTestingSuite(unittest.TestCase): pbt.reset_stats() return pbt, runner + def testMetricError(self): + pbt, runner = self.basicSetup() + trials = runner.get_trials() + + # Should error if training_iteration not in result dict. + with self.assertRaises(RuntimeError): + pbt.on_trial_result( + runner, trials[0], result={"episode_reward_mean": 4}) + + # Should error if episode_reward_mean not in result dict. 
+ with self.assertRaises(RuntimeError): + pbt.on_trial_result( + runner, + trials[0], + result={ + "random_metric": 10, + "training_iteration": 20 + }) + + def testMetricLog(self): + pbt, runner = self.basicSetup(require_attrs=False) + trials = runner.get_trials() + + # Should not error if training_iteration not in result dict + with self.assertLogs("ray.tune.schedulers.pbt", level="WARN"): + pbt.on_trial_result( + runner, trials[0], result={"episode_reward_mean": 4}) + + # Should not error if episode_reward_mean not in result dict. + with self.assertLogs("ray.tune.schedulers.pbt", level="WARN"): + pbt.on_trial_result( + runner, + trials[0], + result={ + "random_metric": 10, + "training_iteration": 20 + }) + def testCheckpointsMostPromisingTrials(self): pbt, runner = self.basicSetup() trials = runner.get_trials() diff --git a/python/ray/tune/tests/test_trial_scheduler_pbt.py b/python/ray/tune/tests/test_trial_scheduler_pbt.py index 523d6e778..0cf4f8035 100644 --- a/python/ray/tune/tests/test_trial_scheduler_pbt.py +++ b/python/ray/tune/tests/test_trial_scheduler_pbt.py @@ -99,14 +99,14 @@ class PopulationBasedTrainingResumeTest(unittest.TestCase): "c": 1 }, fail_fast=True, - num_samples=20, + num_samples=4, checkpoint_freq=1, checkpoint_at_end=True, keep_checkpoints_num=1, checkpoint_score_attr="min-training_iteration", scheduler=scheduler, name="testPermutationContinuation", - stop={"training_iteration": 5}) + stop={"training_iteration": 3}) def testPermutationContinuationFunc(self): scheduler = PopulationBasedTraining(