mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 10:28:01 +08:00
[tune] Fix get_best_logdir behaviour (#5429)
* Fix get_best_logdir behaviour * addtest
This commit is contained in:
@@ -75,7 +75,7 @@ class Analysis(object):
|
||||
mode (str): One of [min, max].
|
||||
|
||||
"""
|
||||
df = self.dataframe()
|
||||
df = self.dataframe(metric=metric, mode=mode)
|
||||
if mode == "max":
|
||||
return df.iloc[df[metric].idxmax()].logdir
|
||||
elif mode == "min":
|
||||
|
||||
@@ -141,6 +141,13 @@ class AnalysisSuite(unittest.TestCase):
|
||||
self.assertTrue(logdir2.startswith(self.test_dir))
|
||||
self.assertNotEquals(logdir, logdir2)
|
||||
|
||||
def testBestConfigIsLogdir(self):
|
||||
analysis = Analysis(self.test_dir)
|
||||
for metric, mode in [(self.metric, "min"), (self.metric, "max")]:
|
||||
logdir = analysis.get_best_logdir(metric, mode=mode)
|
||||
best_config = analysis.get_best_config(metric, mode=mode)
|
||||
self.assertEquals(analysis.get_all_configs()[logdir], best_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user