mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 23:57:45 +08:00
[tune] Handles nan case for AsyncHyperBand (#6916)
This commit is contained in:
@@ -141,7 +141,8 @@ class _Bracket():
|
||||
def cutoff(self, recorded):
|
||||
if not recorded:
|
||||
return None
|
||||
return np.percentile(list(recorded.values()), (1 - 1 / self.rf) * 100)
|
||||
return np.nanpercentile(
|
||||
list(recorded.values()), (1 - 1 / self.rf) * 100)
|
||||
|
||||
def on_result(self, trial, cur_iter, cur_rew):
|
||||
action = TrialScheduler.CONTINUE
|
||||
|
||||
@@ -1091,6 +1091,21 @@ class AsyncHyperBandSuite(unittest.TestCase):
|
||||
TrialScheduler.CONTINUE)
|
||||
return t1, t2
|
||||
|
||||
def nanSetup(self, scheduler):
|
||||
t1 = Trial("PPO") # mean is 450, max 450, t_max=10
|
||||
t2 = Trial("PPO") # mean is nan, max nan, t_max=10
|
||||
scheduler.on_trial_add(None, t1)
|
||||
scheduler.on_trial_add(None, t2)
|
||||
for i in range(10):
|
||||
self.assertEqual(
|
||||
scheduler.on_trial_result(None, t1, result(i, 450)),
|
||||
TrialScheduler.CONTINUE)
|
||||
for i in range(10):
|
||||
self.assertEqual(
|
||||
scheduler.on_trial_result(None, t2, result(i, np.nan)),
|
||||
TrialScheduler.CONTINUE)
|
||||
return t1, t2
|
||||
|
||||
def testAsyncHBOnComplete(self):
|
||||
scheduler = AsyncHyperBandScheduler(max_t=10, brackets=1)
|
||||
t1, t2 = self.basicSetup(scheduler)
|
||||
@@ -1145,6 +1160,21 @@ class AsyncHyperBandSuite(unittest.TestCase):
|
||||
scheduler.on_trial_result(None, t3, result(2, 260)),
|
||||
TrialScheduler.STOP)
|
||||
|
||||
def testAsyncHBNanPercentile(self):
|
||||
scheduler = AsyncHyperBandScheduler(
|
||||
grace_period=1, max_t=10, reduction_factor=2, brackets=1)
|
||||
t1, t2 = self.nanSetup(scheduler)
|
||||
scheduler.on_trial_complete(None, t1, result(10, 450))
|
||||
scheduler.on_trial_complete(None, t2, result(10, np.nan))
|
||||
t3 = Trial("PPO")
|
||||
scheduler.on_trial_add(None, t3)
|
||||
self.assertEqual(
|
||||
scheduler.on_trial_result(None, t3, result(1, 260)),
|
||||
TrialScheduler.STOP)
|
||||
self.assertEqual(
|
||||
scheduler.on_trial_result(None, t3, result(2, 260)),
|
||||
TrialScheduler.STOP)
|
||||
|
||||
def _test_metrics(self, result_func, metric, mode):
|
||||
scheduler = AsyncHyperBandScheduler(
|
||||
grace_period=1,
|
||||
|
||||
Reference in New Issue
Block a user