mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 11:10:02 +08:00
[Tune] PBT Error if metric not available (#9957)
This commit is contained in:
@@ -13,6 +13,8 @@ from ray.tune.schedulers import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.suggest.variant_generator import format_vars
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
|
||||
from ray.util.debug import log_once
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -148,6 +150,9 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
log_config (bool): Whether to log the ray config of each model to
|
||||
local_dir at each exploit. Allows config schedule to be
|
||||
reconstructed.
|
||||
require_attrs (bool): Whether to require time_attr and metric to appear
|
||||
in result for every iteration. If True, error will be raised
|
||||
if these values are not present in trial result.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -182,7 +187,8 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
quantile_fraction=0.25,
|
||||
resample_probability=0.25,
|
||||
custom_explore_fn=None,
|
||||
log_config=True):
|
||||
log_config=True,
|
||||
require_attrs=True):
|
||||
for value in hyperparam_mutations.values():
|
||||
if not (isinstance(value, (list, dict)) or callable(value)):
|
||||
raise TypeError("`hyperparam_mutation` values must be either "
|
||||
@@ -222,6 +228,7 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
self._trial_state = {}
|
||||
self._custom_explore_fn = custom_explore_fn
|
||||
self._log_config = log_config
|
||||
self._require_attrs = require_attrs
|
||||
|
||||
# Metrics
|
||||
self._num_checkpoints = 0
|
||||
@@ -231,8 +238,39 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
self._trial_state[trial] = PBTTrialState(trial)
|
||||
|
||||
def on_trial_result(self, trial_runner, trial, result):
|
||||
if self._time_attr not in result or self._metric not in result:
|
||||
if self._time_attr not in result:
|
||||
time_missing_msg = "Cannot find time_attr {} " \
|
||||
"in trial result {}. Make sure that this " \
|
||||
"attribute is returned in the " \
|
||||
"results of your Trainable.".format(
|
||||
self._time_attr, result)
|
||||
if self._require_attrs:
|
||||
raise RuntimeError(
|
||||
time_missing_msg +
|
||||
"If this error is expected, you can change this to "
|
||||
"a warning message by "
|
||||
"setting PBT(require_attrs=False)")
|
||||
else:
|
||||
if log_once("pbt-time_attr-error"):
|
||||
logger.warning(time_missing_msg)
|
||||
if self._metric not in result:
|
||||
metric_missing_msg = "Cannot find metric {} in trial result {}. " \
|
||||
"Make sure that this attribute is returned " \
|
||||
"in the " \
|
||||
"results of your Trainable.".format(
|
||||
self._metric, result)
|
||||
if self._require_attrs:
|
||||
raise RuntimeError(
|
||||
metric_missing_msg + "If this error is expected, "
|
||||
"you can change this to a warning message by "
|
||||
"setting PBT(require_attrs=False)")
|
||||
else:
|
||||
if log_once("pbt-metric-error"):
|
||||
logger.warning(metric_missing_msg)
|
||||
|
||||
if self._metric not in result or self._time_attr not in result:
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
time = result[self._time_attr]
|
||||
state = self._trial_state[trial]
|
||||
|
||||
|
||||
@@ -716,6 +716,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
explore=None,
|
||||
perturbation_interval=10,
|
||||
log_config=False,
|
||||
require_attrs=True,
|
||||
hyperparams=None,
|
||||
hyperparam_mutations=None,
|
||||
step_once=True):
|
||||
@@ -731,7 +732,8 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
quantile_fraction=0.25,
|
||||
hyperparam_mutations=hyperparam_mutations,
|
||||
custom_explore_fn=explore,
|
||||
log_config=log_config)
|
||||
log_config=log_config,
|
||||
require_attrs=require_attrs)
|
||||
runner = _MockTrialRunner(pbt)
|
||||
for i in range(num_trials):
|
||||
trial_hyperparams = hyperparams or {
|
||||
@@ -750,6 +752,44 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
pbt.reset_stats()
|
||||
return pbt, runner
|
||||
|
||||
def testMetricError(self):
|
||||
pbt, runner = self.basicSetup()
|
||||
trials = runner.get_trials()
|
||||
|
||||
# Should error if training_iteration not in result dict.
|
||||
with self.assertRaises(RuntimeError):
|
||||
pbt.on_trial_result(
|
||||
runner, trials[0], result={"episode_reward_mean": 4})
|
||||
|
||||
# Should error if episode_reward_mean not in result dict.
|
||||
with self.assertRaises(RuntimeError):
|
||||
pbt.on_trial_result(
|
||||
runner,
|
||||
trials[0],
|
||||
result={
|
||||
"random_metric": 10,
|
||||
"training_iteration": 20
|
||||
})
|
||||
|
||||
def testMetricLog(self):
|
||||
pbt, runner = self.basicSetup(require_attrs=False)
|
||||
trials = runner.get_trials()
|
||||
|
||||
# Should not error if training_iteration not in result dict
|
||||
with self.assertLogs("ray.tune.schedulers.pbt", level="WARN"):
|
||||
pbt.on_trial_result(
|
||||
runner, trials[0], result={"episode_reward_mean": 4})
|
||||
|
||||
# Should not error if episode_reward_mean not in result dict.
|
||||
with self.assertLogs("ray.tune.schedulers.pbt", level="WARN"):
|
||||
pbt.on_trial_result(
|
||||
runner,
|
||||
trials[0],
|
||||
result={
|
||||
"random_metric": 10,
|
||||
"training_iteration": 20
|
||||
})
|
||||
|
||||
def testCheckpointsMostPromisingTrials(self):
|
||||
pbt, runner = self.basicSetup()
|
||||
trials = runner.get_trials()
|
||||
|
||||
@@ -99,14 +99,14 @@ class PopulationBasedTrainingResumeTest(unittest.TestCase):
|
||||
"c": 1
|
||||
},
|
||||
fail_fast=True,
|
||||
num_samples=20,
|
||||
num_samples=4,
|
||||
checkpoint_freq=1,
|
||||
checkpoint_at_end=True,
|
||||
keep_checkpoints_num=1,
|
||||
checkpoint_score_attr="min-training_iteration",
|
||||
scheduler=scheduler,
|
||||
name="testPermutationContinuation",
|
||||
stop={"training_iteration": 5})
|
||||
stop={"training_iteration": 3})
|
||||
|
||||
def testPermutationContinuationFunc(self):
|
||||
scheduler = PopulationBasedTraining(
|
||||
|
||||
Reference in New Issue
Block a user