[tune] Better Info String and Tweaks (#2874)

This commit is contained in:
Richard Liaw
2018-09-15 11:02:13 -07:00
committed by GitHub
parent e96817d074
commit e05baed336
2 changed files with 29 additions and 5 deletions
+5 -1
View File
@@ -39,6 +39,7 @@ class SuggestionAlgorithm(SearchAlgorithm):
"""
self._parser = make_parser()
self._trial_generator = []
self._counter = 0
self._finished = False
def add_configurations(self, experiments):
@@ -91,11 +92,14 @@ class SuggestionAlgorithm(SearchAlgorithm):
break
spec = copy.deepcopy(experiment_spec)
spec["config"] = suggested_config
self._counter += 1
tag = "{0}_{1}".format(
str(self._counter), format_vars(spec["config"]))
yield create_trial_from_spec(
spec,
output_path,
self._parser,
experiment_tag=format_vars(spec["config"]),
experiment_tag=tag,
trial_id=trial_id)
def is_finished(self):
+24 -4
View File
@@ -4,6 +4,7 @@ from __future__ import print_function
import collections
import os
import re
import time
import traceback
@@ -17,6 +18,12 @@ from ray.tune.web_server import TuneServer
MAX_DEBUG_TRIALS = 20
def _naturalize(string):
"""Provides a natural representation for string for nice sorting."""
splits = re.split("([0-9]+)", string)
return [int(text) if text.isdigit() else text.lower() for text in splits]
class TrialRunner(object):
"""A TrialRunner implements the event loop for scheduling trials on Ray.
@@ -158,7 +165,6 @@ class TrialRunner(object):
def debug_string(self, max_debug=MAX_DEBUG_TRIALS):
"""Returns a human readable message for printing to the console."""
messages = self._debug_messages()
states = collections.defaultdict(set)
limit_per_state = collections.Counter()
@@ -181,11 +187,25 @@ class TrialRunner(object):
for state, trials in sorted(states.items()):
limit = limit_per_state[state]
messages.append("{} trials:".format(state))
for t in sorted(trials, key=lambda t: t.experiment_tag)[:limit]:
messages.append(" - {}:\t{}".format(t, t.progress_string()))
sorted_trials = sorted(
trials, key=lambda t: _naturalize(t.experiment_tag))
if len(trials) > limit:
tail_length = limit // 2
first = sorted_trials[:tail_length]
for t in first:
messages.append(" - {}:\t{}".format(
t, t.progress_string()))
messages.append(
" ... {} more not shown".format(len(trials) - limit))
" ... {} not shown".format(len(trials) - tail_length * 2))
last = sorted_trials[-tail_length:]
for t in last:
messages.append(" - {}:\t{}".format(
t, t.progress_string()))
else:
for t in sorted_trials:
messages.append(" - {}:\t{}".format(
t, t.progress_string()))
return "\n".join(messages) + "\n"
def _debug_messages(self):