mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 06:08:03 +08:00
[tune] Fix sort (#5111)
* fix sort * fix tune list-experiments * Update python/ray/tune/tests/test_commands.py
This commit is contained in:
@@ -163,8 +163,10 @@ def list_trials(experiment_path,
|
||||
checkpoints_df = checkpoints_df[filtered_index]
|
||||
|
||||
if sort:
|
||||
if sort not in checkpoints_df:
|
||||
raise KeyError("{} not in: {}".format(sort, list(checkpoints_df)))
|
||||
for key in sort:
|
||||
if key not in checkpoints_df:
|
||||
raise KeyError("{} not in: {}".format(key,
|
||||
list(checkpoints_df)))
|
||||
ascending = not desc
|
||||
checkpoints_df = checkpoints_df.sort_values(
|
||||
by=sort, ascending=ascending)
|
||||
@@ -274,8 +276,9 @@ def list_experiments(project_path,
|
||||
info_df = info_df[filtered_index]
|
||||
|
||||
if sort:
|
||||
if sort not in info_df:
|
||||
raise KeyError("{} not in: {}".format(sort, list(info_df)))
|
||||
for key in sort:
|
||||
if key not in info_df:
|
||||
raise KeyError("{} not in: {}".format(key, list(info_df)))
|
||||
ascending = not desc
|
||||
info_df = info_df.sort_values(by=sort, ascending=ascending)
|
||||
|
||||
|
||||
@@ -90,6 +90,7 @@ def test_ls(start_ray, tmpdir):
|
||||
with Capturing() as output:
|
||||
commands.list_trials(
|
||||
experiment_path,
|
||||
sort=["status"],
|
||||
info_keys=("status", ),
|
||||
filter_op="status == TERMINATED")
|
||||
lines = output.captured
|
||||
@@ -126,6 +127,7 @@ def test_lsx(start_ray, tmpdir):
|
||||
with Capturing() as output:
|
||||
commands.list_experiments(
|
||||
project_path,
|
||||
sort=["total_trials"],
|
||||
info_keys=("total_trials", ),
|
||||
filter_op="total_trials == 1")
|
||||
lines = output.captured
|
||||
|
||||
Reference in New Issue
Block a user