[tune] Handles nan case for AsyncHyperBand (#6916)

This commit is contained in:
hyggan
2020-01-26 02:26:30 +01:00
committed by Richard Liaw
parent ed9de8b2fa
commit 552156f22d
2 changed files with 32 additions and 1 deletions
@@ -141,7 +141,8 @@ class _Bracket():
def cutoff(self, recorded):
if not recorded:
return None
return np.percentile(list(recorded.values()), (1 - 1 / self.rf) * 100)
return np.nanpercentile(
list(recorded.values()), (1 - 1 / self.rf) * 100)
def on_result(self, trial, cur_iter, cur_rew):
action = TrialScheduler.CONTINUE
@@ -1091,6 +1091,21 @@ class AsyncHyperBandSuite(unittest.TestCase):
TrialScheduler.CONTINUE)
return t1, t2
def nanSetup(self, scheduler):
t1 = Trial("PPO") # mean is 450, max 450, t_max=10
t2 = Trial("PPO") # mean is nan, max nan, t_max=10
scheduler.on_trial_add(None, t1)
scheduler.on_trial_add(None, t2)
for i in range(10):
self.assertEqual(
scheduler.on_trial_result(None, t1, result(i, 450)),
TrialScheduler.CONTINUE)
for i in range(10):
self.assertEqual(
scheduler.on_trial_result(None, t2, result(i, np.nan)),
TrialScheduler.CONTINUE)
return t1, t2
def testAsyncHBOnComplete(self):
scheduler = AsyncHyperBandScheduler(max_t=10, brackets=1)
t1, t2 = self.basicSetup(scheduler)
@@ -1145,6 +1160,21 @@ class AsyncHyperBandSuite(unittest.TestCase):
scheduler.on_trial_result(None, t3, result(2, 260)),
TrialScheduler.STOP)
def testAsyncHBNanPercentile(self):
scheduler = AsyncHyperBandScheduler(
grace_period=1, max_t=10, reduction_factor=2, brackets=1)
t1, t2 = self.nanSetup(scheduler)
scheduler.on_trial_complete(None, t1, result(10, 450))
scheduler.on_trial_complete(None, t2, result(10, np.nan))
t3 = Trial("PPO")
scheduler.on_trial_add(None, t3)
self.assertEqual(
scheduler.on_trial_result(None, t3, result(1, 260)),
TrialScheduler.STOP)
self.assertEqual(
scheduler.on_trial_result(None, t3, result(2, 260)),
TrialScheduler.STOP)
def _test_metrics(self, result_func, metric, mode):
scheduler = AsyncHyperBandScheduler(
grace_period=1,