mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 21:29:15 +08:00
[tune] Better Info String and Tweaks (#2874)
This commit is contained in:
@@ -39,6 +39,7 @@ class SuggestionAlgorithm(SearchAlgorithm):
|
||||
"""
|
||||
self._parser = make_parser()
|
||||
self._trial_generator = []
|
||||
self._counter = 0
|
||||
self._finished = False
|
||||
|
||||
def add_configurations(self, experiments):
|
||||
@@ -91,11 +92,14 @@ class SuggestionAlgorithm(SearchAlgorithm):
|
||||
break
|
||||
spec = copy.deepcopy(experiment_spec)
|
||||
spec["config"] = suggested_config
|
||||
self._counter += 1
|
||||
tag = "{0}_{1}".format(
|
||||
str(self._counter), format_vars(spec["config"]))
|
||||
yield create_trial_from_spec(
|
||||
spec,
|
||||
output_path,
|
||||
self._parser,
|
||||
experiment_tag=format_vars(spec["config"]),
|
||||
experiment_tag=tag,
|
||||
trial_id=trial_id)
|
||||
|
||||
def is_finished(self):
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
|
||||
@@ -17,6 +18,12 @@ from ray.tune.web_server import TuneServer
|
||||
MAX_DEBUG_TRIALS = 20
|
||||
|
||||
|
||||
def _naturalize(string):
|
||||
"""Provides a natural representation for string for nice sorting."""
|
||||
splits = re.split("([0-9]+)", string)
|
||||
return [int(text) if text.isdigit() else text.lower() for text in splits]
|
||||
|
||||
|
||||
class TrialRunner(object):
|
||||
"""A TrialRunner implements the event loop for scheduling trials on Ray.
|
||||
|
||||
@@ -158,7 +165,6 @@ class TrialRunner(object):
|
||||
|
||||
def debug_string(self, max_debug=MAX_DEBUG_TRIALS):
|
||||
"""Returns a human readable message for printing to the console."""
|
||||
|
||||
messages = self._debug_messages()
|
||||
states = collections.defaultdict(set)
|
||||
limit_per_state = collections.Counter()
|
||||
@@ -181,11 +187,25 @@ class TrialRunner(object):
|
||||
for state, trials in sorted(states.items()):
|
||||
limit = limit_per_state[state]
|
||||
messages.append("{} trials:".format(state))
|
||||
for t in sorted(trials, key=lambda t: t.experiment_tag)[:limit]:
|
||||
messages.append(" - {}:\t{}".format(t, t.progress_string()))
|
||||
sorted_trials = sorted(
|
||||
trials, key=lambda t: _naturalize(t.experiment_tag))
|
||||
if len(trials) > limit:
|
||||
tail_length = limit // 2
|
||||
first = sorted_trials[:tail_length]
|
||||
for t in first:
|
||||
messages.append(" - {}:\t{}".format(
|
||||
t, t.progress_string()))
|
||||
messages.append(
|
||||
" ... {} more not shown".format(len(trials) - limit))
|
||||
" ... {} not shown".format(len(trials) - tail_length * 2))
|
||||
last = sorted_trials[-tail_length:]
|
||||
for t in last:
|
||||
messages.append(" - {}:\t{}".format(
|
||||
t, t.progress_string()))
|
||||
else:
|
||||
for t in sorted_trials:
|
||||
messages.append(" - {}:\t{}".format(
|
||||
t, t.progress_string()))
|
||||
|
||||
return "\n".join(messages) + "\n"
|
||||
|
||||
def _debug_messages(self):
|
||||
|
||||
Reference in New Issue
Block a user