diff --git a/python/ray/tune/commands.py b/python/ray/tune/commands.py index 0202bd9d9..b8901566a 100644 --- a/python/ray/tune/commands.py +++ b/python/ray/tune/commands.py @@ -163,8 +163,10 @@ def list_trials(experiment_path, checkpoints_df = checkpoints_df[filtered_index] if sort: - if sort not in checkpoints_df: - raise KeyError("{} not in: {}".format(sort, list(checkpoints_df))) + for key in sort: + if key not in checkpoints_df: + raise KeyError("{} not in: {}".format(key, + list(checkpoints_df))) ascending = not desc checkpoints_df = checkpoints_df.sort_values( by=sort, ascending=ascending) @@ -274,8 +276,9 @@ def list_experiments(project_path, info_df = info_df[filtered_index] if sort: - if sort not in info_df: - raise KeyError("{} not in: {}".format(sort, list(info_df))) + for key in sort: + if key not in info_df: + raise KeyError("{} not in: {}".format(key, list(info_df))) ascending = not desc info_df = info_df.sort_values(by=sort, ascending=ascending) diff --git a/python/ray/tune/tests/test_commands.py b/python/ray/tune/tests/test_commands.py index f55dc8336..ee6452127 100644 --- a/python/ray/tune/tests/test_commands.py +++ b/python/ray/tune/tests/test_commands.py @@ -90,6 +90,7 @@ def test_ls(start_ray, tmpdir): with Capturing() as output: commands.list_trials( experiment_path, + sort=["status"], info_keys=("status", ), filter_op="status == TERMINATED") lines = output.captured @@ -126,6 +127,7 @@ def test_lsx(start_ray, tmpdir): with Capturing() as output: commands.list_experiments( project_path, + sort=["total_trials"], info_keys=("total_trials", ), filter_op="total_trials == 1") lines = output.captured