diff --git a/python/ray/tune/schedulers/async_hyperband.py b/python/ray/tune/schedulers/async_hyperband.py
index 80c52c8c0..e30cfb0a9 100644
--- a/python/ray/tune/schedulers/async_hyperband.py
+++ b/python/ray/tune/schedulers/async_hyperband.py
@@ -141,7 +141,8 @@ class _Bracket():
     def cutoff(self, recorded):
         if not recorded:
             return None
-        return np.percentile(list(recorded.values()), (1 - 1 / self.rf) * 100)
+        return np.nanpercentile(
+            list(recorded.values()), (1 - 1 / self.rf) * 100)
 
     def on_result(self, trial, cur_iter, cur_rew):
         action = TrialScheduler.CONTINUE
diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py
index e1e0d0277..ffdebc93d 100644
--- a/python/ray/tune/tests/test_trial_scheduler.py
+++ b/python/ray/tune/tests/test_trial_scheduler.py
@@ -1091,6 +1091,21 @@ class AsyncHyperBandSuite(unittest.TestCase):
                 TrialScheduler.CONTINUE)
         return t1, t2
 
+    def nanSetup(self, scheduler):
+        t1 = Trial("PPO")  # mean is 450, max 450, t_max=10
+        t2 = Trial("PPO")  # mean is nan, max nan, t_max=10
+        scheduler.on_trial_add(None, t1)
+        scheduler.on_trial_add(None, t2)
+        for i in range(10):
+            self.assertEqual(
+                scheduler.on_trial_result(None, t1, result(i, 450)),
+                TrialScheduler.CONTINUE)
+        for i in range(10):
+            self.assertEqual(
+                scheduler.on_trial_result(None, t2, result(i, np.nan)),
+                TrialScheduler.CONTINUE)
+        return t1, t2
+
     def testAsyncHBOnComplete(self):
         scheduler = AsyncHyperBandScheduler(max_t=10, brackets=1)
         t1, t2 = self.basicSetup(scheduler)
@@ -1145,6 +1160,21 @@ class AsyncHyperBandSuite(unittest.TestCase):
             scheduler.on_trial_result(None, t3, result(2, 260)),
             TrialScheduler.STOP)
 
+    def testAsyncHBNanPercentile(self):
+        scheduler = AsyncHyperBandScheduler(
+            grace_period=1, max_t=10, reduction_factor=2, brackets=1)
+        t1, t2 = self.nanSetup(scheduler)
+        scheduler.on_trial_complete(None, t1, result(10, 450))
+        scheduler.on_trial_complete(None, t2, result(10, np.nan))
+        t3 = Trial("PPO")
+        scheduler.on_trial_add(None, t3)
+        self.assertEqual(
+            scheduler.on_trial_result(None, t3, result(1, 260)),
+            TrialScheduler.STOP)
+        self.assertEqual(
+            scheduler.on_trial_result(None, t3, result(2, 260)),
+            TrialScheduler.STOP)
+
     def _test_metrics(self, result_func, metric, mode):
         scheduler = AsyncHyperBandScheduler(
             grace_period=1,