diff --git a/python/ray/tune/analysis/experiment_analysis.py b/python/ray/tune/analysis/experiment_analysis.py index f52fda893..3093b73ef 100644 --- a/python/ray/tune/analysis/experiment_analysis.py +++ b/python/ray/tune/analysis/experiment_analysis.py @@ -75,7 +75,7 @@ class Analysis(object): mode (str): One of [min, max]. """ - df = self.dataframe() + df = self.dataframe(metric=metric, mode=mode) if mode == "max": return df.iloc[df[metric].idxmax()].logdir elif mode == "min": diff --git a/python/ray/tune/tests/test_experiment_analysis.py b/python/ray/tune/tests/test_experiment_analysis.py index 6a12b39f2..99d9c8e1c 100644 --- a/python/ray/tune/tests/test_experiment_analysis.py +++ b/python/ray/tune/tests/test_experiment_analysis.py @@ -141,6 +141,13 @@ class AnalysisSuite(unittest.TestCase): self.assertTrue(logdir2.startswith(self.test_dir)) self.assertNotEquals(logdir, logdir2) + def testBestConfigIsLogdir(self): + analysis = Analysis(self.test_dir) + for metric, mode in [(self.metric, "min"), (self.metric, "max")]: + logdir = analysis.get_best_logdir(metric, mode=mode) + best_config = analysis.get_best_config(metric, mode=mode) + self.assertEquals(analysis.get_all_configs()[logdir], best_config) + if __name__ == "__main__": unittest.main(verbosity=2)