From 53d5a8a45f3972e348a13692748b65e0f9df6858 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Fri, 5 Jul 2019 16:05:10 -0700 Subject: [PATCH] [tune] Fix sort (#5111) * fix sort * fix tune list-experiments * Update python/ray/tune/tests/test_commands.py --- python/ray/tune/commands.py | 11 +++++++---- python/ray/tune/tests/test_commands.py | 2 ++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/python/ray/tune/commands.py b/python/ray/tune/commands.py index 0202bd9d9..b8901566a 100644 --- a/python/ray/tune/commands.py +++ b/python/ray/tune/commands.py @@ -163,8 +163,10 @@ def list_trials(experiment_path, checkpoints_df = checkpoints_df[filtered_index] if sort: - if sort not in checkpoints_df: - raise KeyError("{} not in: {}".format(sort, list(checkpoints_df))) + for key in sort: + if key not in checkpoints_df: + raise KeyError("{} not in: {}".format(key, + list(checkpoints_df))) ascending = not desc checkpoints_df = checkpoints_df.sort_values( by=sort, ascending=ascending) @@ -274,8 +276,9 @@ def list_experiments(project_path, info_df = info_df[filtered_index] if sort: - if sort not in info_df: - raise KeyError("{} not in: {}".format(sort, list(info_df))) + for key in sort: + if key not in info_df: + raise KeyError("{} not in: {}".format(key, list(info_df))) ascending = not desc info_df = info_df.sort_values(by=sort, ascending=ascending) diff --git a/python/ray/tune/tests/test_commands.py b/python/ray/tune/tests/test_commands.py index f55dc8336..ee6452127 100644 --- a/python/ray/tune/tests/test_commands.py +++ b/python/ray/tune/tests/test_commands.py @@ -90,6 +90,7 @@ def test_ls(start_ray, tmpdir): with Capturing() as output: commands.list_trials( experiment_path, + sort=["status"], info_keys=("status", ), filter_op="status == TERMINATED") lines = output.captured @@ -126,6 +127,7 @@ def test_lsx(start_ray, tmpdir): with Capturing() as output: commands.list_experiments( project_path, + sort=["total_trials"], info_keys=("total_trials", ), filter_op="total_trials == 1") lines = output.captured