[tune] lazy trials (#10802)

* Lazily fill trial queue

* Update interface

* Update end to end reporter test

* Removed `next_trials()` method

* Lint

* Print total number of samples to be generated in progress reporter. Allow infinite samples.

* Nit check
This commit is contained in:
Kai Fricke
2020-09-17 16:51:46 +01:00
committed by GitHub
parent 89da3f9ba7
commit ee99c919e3
15 changed files with 163 additions and 91 deletions
+22 -3
View File
@@ -56,6 +56,8 @@ class AutoMLSearcher(SearchAlgorithm):
self._unfinished_count = 0
self._running_trials = {}
self._completed_trials = {}
self._next_trials = []
self._next_trial_iter = None
self._iteration = 0
self._total_trial_num = 0
@@ -68,10 +70,27 @@ class AutoMLSearcher(SearchAlgorithm):
"""Returns the Trial object with the best reward_attr"""
return self.best_trial
def next_trials(self):
def next_trial(self):
if not self._next_trial_iter:
self._generate_next_trials()
if not self._next_trials:
self.set_finished()
return None
self._next_trial_iter = iter(self._next_trials)
try:
return next(self._next_trial_iter)
except StopIteration:
self._next_trials = []
self._next_trial_iter = None
return None
def _generate_next_trials(self):
self._next_trials = []
if self._unfinished_count > 0:
# Last round not finished
return []
return
trials = []
raw_param_list, extra_arg_list = self._select()
@@ -110,7 +129,7 @@ class AutoMLSearcher(SearchAlgorithm):
"new": ntrial,
"total": self._total_trial_num
})
return trials
self._next_trials = trials
def on_trial_result(self, trial_id, result):
if not result:
+23 -6
View File
@@ -1,6 +1,8 @@
from __future__ import print_function
import collections
import sys
import numpy as np
import time
@@ -94,10 +96,12 @@ class TuneReporterBase(ProgressReporter):
def __init__(self,
metric_columns=None,
parameter_columns=None,
total_samples=None,
max_progress_rows=20,
max_error_rows=20,
max_report_frequency=5,
infer_limit=3):
self._total_samples = total_samples
self._metrics_override = metric_columns is not None
self._inferred_metrics = {}
self._metric_columns = metric_columns or self.DEFAULT_COLUMNS.copy()
@@ -109,6 +113,9 @@ class TuneReporterBase(ProgressReporter):
self._max_report_freqency = max_report_frequency
self._last_report_time = 0
def set_total_samples(self, total_samples):
self._total_samples = total_samples
def should_report(self, trials, done=False):
if time.time() - self._last_report_time > self._max_report_freqency:
self._last_report_time = time.time()
@@ -191,6 +198,7 @@ class TuneReporterBase(ProgressReporter):
trials,
metric_columns=self._metric_columns,
parameter_columns=self._parameter_columns,
total_samples=self._total_samples,
fmt=fmt,
max_rows=max_progress))
messages.append(trial_errors_str(trials, fmt=fmt, max_rows=max_error))
@@ -243,12 +251,13 @@ class JupyterNotebookReporter(TuneReporterBase):
overwrite,
metric_columns=None,
parameter_columns=None,
total_samples=None,
max_progress_rows=20,
max_error_rows=20,
max_report_frequency=5):
super(JupyterNotebookReporter, self).__init__(
metric_columns, parameter_columns, max_progress_rows,
max_error_rows, max_report_frequency)
metric_columns, parameter_columns, total_samples,
max_progress_rows, max_error_rows, max_report_frequency)
self._overwrite = overwrite
def report(self, trials, done, *sys_info):
@@ -287,13 +296,14 @@ class CLIReporter(TuneReporterBase):
def __init__(self,
metric_columns=None,
parameter_columns=None,
total_samples=None,
max_progress_rows=20,
max_error_rows=20,
max_report_frequency=5):
super(CLIReporter, self).__init__(metric_columns, parameter_columns,
max_progress_rows, max_error_rows,
max_report_frequency)
total_samples, max_progress_rows,
max_error_rows, max_report_frequency)
def report(self, trials, done, *sys_info):
print(self._progress_str(trials, done, *sys_info))
@@ -324,6 +334,7 @@ def memory_debug_str():
def trial_progress_str(trials,
metric_columns,
parameter_columns=None,
total_samples=0,
fmt="psql",
max_rows=None):
"""Returns a human readable message for printing to the console.
@@ -342,6 +353,7 @@ def trial_progress_str(trials,
values are the names to use in the message. If this is a list,
the parameter name is used in the message directly. If this is
empty, all parameters are used in the message.
total_samples (int): Total number of trials that will be generated.
fmt (str): Output format (see tablefmt in tabulate API).
max_rows (int): Maximum number of rows in the trial table. Defaults to
unlimited.
@@ -381,8 +393,13 @@ def trial_progress_str(trials,
overflow_str = ", ".join(overflow_strs)
else:
overflow = False
messages.append("Number of trials: {} ({})".format(
num_trials, ", ".join(num_trials_strs)))
if total_samples >= sys.maxsize:
total_samples = "infinite"
messages.append("Number of trials: {}{} ({})".format(
num_trials, f"/{total_samples}"
if total_samples else "", ", ".join(num_trials_strs)))
# Pre-process trials to figure out what columns to show.
if isinstance(metric_columns, Mapping):
+24 -18
View File
@@ -1,14 +1,13 @@
import itertools
import os
import random
import uuid
from typing import Dict, List, Union
from ray.tune.error import TuneError
from ray.tune.experiment import Experiment, convert_to_experiment_list
from ray.tune.config_parser import make_parser, create_trial_from_spec
from ray.tune.suggest.variant_generator import (generate_variants, format_vars,
flatten_resolved_vars)
from ray.tune.suggest.variant_generator import (
count_variants, generate_variants, format_vars, flatten_resolved_vars)
from ray.tune.suggest.search import SearchAlgorithm
@@ -17,10 +16,6 @@ class BasicVariantGenerator(SearchAlgorithm):
See also: `ray.tune.suggest.variant_generator`.
Parameters:
shuffle (bool): Shuffles the generated list of configurations.
User API:
.. code-block:: python
@@ -39,19 +34,19 @@ class BasicVariantGenerator(SearchAlgorithm):
searcher = BasicVariantGenerator()
searcher.add_configurations({"experiment": { ... }})
list_of_trials = searcher.next_trials()
trial = searcher.next_trial()
searcher.is_finished == True
"""
def __init__(self, shuffle: bool = False):
def __init__(self):
"""Initializes the Variant Generator.
"""
self._parser = make_parser()
self._trial_generator = []
self._trial_iter = None
self._counter = 0
self._finished = False
self._shuffle = shuffle
# Unique prefix for all trials generated, e.g., trial ids start as
# 2f1e_00001, 2f1ef_00002, 2f1ef_0003, etc. Overridable for testing.
@@ -61,6 +56,12 @@ class BasicVariantGenerator(SearchAlgorithm):
else:
self._uuid_prefix = str(uuid.uuid1().hex)[:5] + "_"
self._total_samples = 0
@property
def total_samples(self):
return self._total_samples
def add_configurations(
self,
experiments: Union[Experiment, List[Experiment], Dict[str, Dict]]):
@@ -71,23 +72,28 @@ class BasicVariantGenerator(SearchAlgorithm):
"""
experiment_list = convert_to_experiment_list(experiments)
for experiment in experiment_list:
self._total_samples += count_variants(experiment.spec)
self._trial_generator = itertools.chain(
self._trial_generator,
self._generate_trials(
experiment.spec.get("num_samples", 1), experiment.spec,
experiment.name))
def next_trials(self):
"""Provides Trial objects to be queued into the TrialRunner.
def next_trial(self):
"""Provides one Trial object to be queued into the TrialRunner.
Returns:
trials (list): Returns a list of trials.
Trial: Returns a single trial.
"""
trials = list(self._trial_generator)
if self._shuffle:
random.shuffle(trials)
self.set_finished()
return trials
if not self._trial_iter:
self._trial_iter = iter(self._trial_generator)
try:
return next(self._trial_iter)
except StopIteration:
self._trial_generator = []
self._trial_iter = None
self.set_finished()
return None
def _generate_trials(self, num_samples, unresolved_spec, output_path=""):
"""Generates Trial objects with the variant generation process.
+8 -4
View File
@@ -1,7 +1,6 @@
from typing import Dict, List, Optional, Union
from ray.tune.experiment import Experiment
from ray.tune.trial import Trial
class SearchAlgorithm:
@@ -36,6 +35,11 @@ class SearchAlgorithm:
"""
return True
@property
def total_samples(self):
"""Get number of total trials to be generated"""
return 0
def add_configurations(
self,
experiments: Union[Experiment, List[Experiment], Dict[str, Dict]]):
@@ -46,11 +50,11 @@ class SearchAlgorithm:
"""
raise NotImplementedError
def next_trials(self) -> List[Trial]:
"""Provides Trial objects to be queued into the TrialRunner.
def next_trial(self):
"""Returns single Trial object to be queued into the TrialRunner.
Returns:
trials (list): Returns a list of trials.
trial (Trial): Returns a Trial object.
"""
raise NotImplementedError
+12 -12
View File
@@ -67,13 +67,17 @@ class SearchGenerator(SearchAlgorithm):
self._parser = make_parser()
self._experiment = None
self._counter = 0 # Keeps track of number of trials created.
self._total_samples = None # int: total samples to evaluate.
self._total_samples = 0 # int: total samples to evaluate.
self._finished = False
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
config: Dict) -> bool:
return self.searcher.set_search_properties(metric, mode, config)
@property
def total_samples(self):
return self._total_samples
def add_configurations(
self,
experiments: Union[Experiment, List[Experiment], Dict[str, Dict]]):
@@ -95,20 +99,16 @@ class SearchGenerator(SearchAlgorithm):
if "run" not in experiment_spec:
raise TuneError("Must specify `run` in {}".format(experiment_spec))
def next_trials(self) -> List[Trial]:
"""Provides a batch of Trial objects to be queued into the TrialRunner.
def next_trial(self):
"""Provides one Trial object to be queued into the TrialRunner.
Returns:
List[Trial]: A list of trials for the Runner to consume.
Trial: Returns a single trial.
"""
trials = []
while not self.is_finished():
trial = self.create_trial_if_possible(self._experiment.spec,
self._experiment.name)
if trial is None:
break
trials.append(trial)
return trials
if not self.is_finished():
return self.create_trial_if_possible(self._experiment.spec,
self._experiment.name)
return None
def create_trial_if_possible(self, experiment_spec: Dict,
output_path: str) -> Optional[Trial]:
@@ -138,6 +138,15 @@ def parse_spec_vars(spec: Dict) -> Tuple[List[Tuple[Tuple, Any]], List[Tuple[
return resolved_vars, domain_vars, grid_vars
def count_variants(spec: Dict) -> int:
spec = copy.deepcopy(spec)
_, domain_vars, grid_vars = parse_spec_vars(spec)
grid_count = 1
for path, domain in grid_vars:
grid_count *= len(domain.categories)
return spec.get("num_samples", 1) * grid_count
def _generate_variants(spec: Dict) -> Tuple[Dict, Dict]:
spec = copy.deepcopy(spec)
_, domain_vars, grid_vars = parse_spec_vars(spec)
+16 -5
View File
@@ -5,6 +5,16 @@ from ray.tune import register_trainable
from ray.tune.automl import SearchSpace, DiscreteSpace, GridSearch
def next_trials(searcher):
trials = []
while not searcher.is_finished():
trial = searcher.next_trial()
if not trial:
break
trials.append(trial)
return trials
class AutoMLSearcherTest(unittest.TestCase):
def setUp(self):
def dummy_train(config, reporter):
@@ -20,7 +30,7 @@ class AutoMLSearcherTest(unittest.TestCase):
])
searcher = GridSearch(space, "reward")
searcher.add_configurations(exp)
trials = searcher.next_trials()
trials = next_trials(searcher)
self.assertEqual(len(trials), 4)
self.assertTrue(trials[0].config["a"]["b"]["c"] in [1, 2])
@@ -34,9 +44,9 @@ class AutoMLSearcherTest(unittest.TestCase):
])
searcher = GridSearch(space, "reward")
searcher.add_configurations(exp)
trials = searcher.next_trials()
trials = next_trials(searcher)
self.assertEqual(len(searcher.next_trials()), 0)
self.assertEqual(searcher.next_trial(), None)
for trial in trials[1:]:
searcher.on_trial_complete(trial.trial_id)
searcher.on_trial_complete(trials[0].trial_id, error=True)
@@ -51,10 +61,11 @@ class AutoMLSearcherTest(unittest.TestCase):
])
searcher = GridSearch(space, "reward")
searcher.add_configurations(exp)
trials = searcher.next_trials()
trials = next_trials(searcher)
self.assertEqual(len(searcher.next_trials()), 0)
self.assertEqual(searcher.next_trial(), None)
for i, trial in enumerate(trials):
print("TRIAL {}".format(trial))
rewards = list(range(i, i + 10))
random.shuffle(rewards)
for reward in rewards:
@@ -73,34 +73,14 @@ tune.run_experiments({
},
}, verbose=1)"""
EXPECTED_END_TO_END_START = """Number of trials: 30 (29 PENDING, 1 RUNNING)
+---------------+----------+-------+-----+-----+
| Trial name | status | loc | a | b |
|---------------+----------+-------+-----+-----|
| f_xxxxx_00001 | PENDING | | 1 | |
| f_xxxxx_00002 | PENDING | | 2 | |
| f_xxxxx_00003 | PENDING | | 3 | |
| f_xxxxx_00004 | PENDING | | 4 | |
| f_xxxxx_00005 | PENDING | | 5 | |
| f_xxxxx_00006 | PENDING | | 6 | |
| f_xxxxx_00007 | PENDING | | 7 | |
| f_xxxxx_00008 | PENDING | | 8 | |
| f_xxxxx_00009 | PENDING | | 9 | |
| f_xxxxx_00010 | PENDING | | | 0 |
| f_xxxxx_00011 | PENDING | | | 1 |
| f_xxxxx_00012 | PENDING | | | 2 |
| f_xxxxx_00013 | PENDING | | | 3 |
| f_xxxxx_00014 | PENDING | | | 4 |
| f_xxxxx_00015 | PENDING | | | 5 |
| f_xxxxx_00016 | PENDING | | | 6 |
| f_xxxxx_00017 | PENDING | | | 7 |
| f_xxxxx_00018 | PENDING | | | 8 |
| f_xxxxx_00019 | PENDING | | | 9 |
| f_xxxxx_00000 | RUNNING | | 0 | |
+---------------+----------+-------+-----+-----+
... 10 more trials not shown (10 PENDING)"""
EXPECTED_END_TO_END_START = """Number of trials: 1/30 (1 RUNNING)
+---------------+----------+-------+-----+
| Trial name | status | loc | a |
|---------------+----------+-------+-----|
| f_xxxxx_00000 | RUNNING | | 0 |
+---------------+----------+-------+-----+"""
EXPECTED_END_TO_END_END = """Number of trials: 30 (30 TERMINATED)
EXPECTED_END_TO_END_END = """Number of trials: 30/30 (30 TERMINATED)
+---------------+------------+-------+-----+-----+-----+
| Trial name | status | loc | a | b | c |
|---------------+------------+-------+-----+-----+-----|
@@ -136,7 +116,7 @@ EXPECTED_END_TO_END_END = """Number of trials: 30 (30 TERMINATED)
| f_xxxxx_00029 | TERMINATED | | | | 9 |
+---------------+------------+-------+-----+-----+-----+"""
EXPECTED_END_TO_END_AC = """Number of trials: 30 (30 TERMINATED)
EXPECTED_END_TO_END_AC = """Number of trials: 30/30 (30 TERMINATED)
+---------------+------------+-------+-----+-----+-----+
| Trial name | status | loc | a | b | c |
|---------------+------------+-------+-----+-----+-----|
@@ -170,7 +170,14 @@ class RayTrialExecutorTest(unittest.TestCase):
def generate_trials(spec, name):
suggester = BasicVariantGenerator()
suggester.add_configurations({name: spec})
return suggester.next_trials()
trials = []
while not suggester.is_finished():
trial = suggester.next_trial()
if trial:
trials.append(trial)
else:
break
return trials
def process_trial_save(self, trial):
"""Simulates trial runner save."""
+4 -1
View File
@@ -57,7 +57,10 @@ class TrialRunnerTest(unittest.TestCase):
for name, spec in experiments.items():
trial_generator = BasicVariantGenerator()
trial_generator.add_configurations({name: spec})
for trial in trial_generator.next_trials():
while not trial_generator.is_finished():
trial = trial_generator.next_trial()
if not trial:
break
trial_executor.start_trial(trial)
self.assertLessEqual(len(os.path.basename(trial.logdir)), 200)
trial_executor.stop_trial(trial)
+4 -5
View File
@@ -228,19 +228,18 @@ class TrialRunnerTest3(unittest.TestCase):
class FinishFastAlg(_MockSuggestionAlgorithm):
_index = 0
def next_trials(self):
def next_trial(self):
spec = self._experiment.spec
trials = []
trial = None
if self._index < spec["num_samples"]:
trial = Trial(
spec.get("run"), stopping_criterion=spec.get("stop"))
trials.append(trial)
self._index += 1
if self._index > 4:
self.set_finished()
return trials
return trial
def suggest(self, trial_id):
return {}
@@ -625,7 +624,7 @@ class SearchAlgorithmTest(unittest.TestCase):
searcher = TestSuggestion()
alg = SearchGenerator(searcher)
alg.add_configurations({"test": {"run": "__fake"}})
trial = alg.next_trials()[0]
trial = alg.next_trial()
self.assertTrue("e=5" in trial.experiment_tag)
self.assertTrue("d=4" in trial.experiment_tag)
+8 -1
View File
@@ -24,7 +24,14 @@ class VariantGeneratorTest(unittest.TestCase):
def generate_trials(self, spec, name):
suggester = BasicVariantGenerator()
suggester.add_configurations({name: spec})
return suggester.next_trials()
trials = []
while not suggester.is_finished():
trial = suggester.next_trial()
if trial:
trials.append(trial)
else:
break
return trials
def testParseToTrials(self):
trials = self.generate_trials({
+5 -5
View File
@@ -734,19 +734,19 @@ class TrialRunner:
or is_finished (timeout or search algorithm finishes).
timeout (int): Seconds before blocking times out.
"""
trials = self._search_alg.next_trials()
if blocking and not trials:
trial = self._search_alg.next_trial()
if blocking and not trial:
start = time.time()
# Checking `is_finished` instead of _search_alg.is_finished
# is fine because blocking only occurs if all trials are
# finished and search_algorithm is not yet finished
while (not trials and not self.is_finished()
while (not trial and not self.is_finished()
and time.time() - start < timeout):
logger.info("Blocking for next trial...")
trials = self._search_alg.next_trials()
trial = self._search_alg.next_trial()
time.sleep(1)
for trial in trials:
if trial:
self.add_trial(trial)
def request_stop_trial(self, trial):
+8 -1
View File
@@ -1,4 +1,5 @@
import logging
import sys
from ray.tune.error import TuneError
from ray.tune.experiment import convert_to_experiment_list, Experiment
@@ -177,7 +178,8 @@ def run(
num_samples (int): Number of times to sample from the
hyperparameter space. Defaults to 1. If `grid_search` is
provided as an argument, the grid will be repeated
`num_samples` of times.
`num_samples` of times. If this is -1, (virtually) infinite
samples are generated until a stopping condition is met.
local_dir (str): Local dir to save training results to.
Defaults to ``~/ray_results``.
search_alg (Searcher): Search algorithm for optimization.
@@ -292,6 +294,9 @@ def run(
sync_config = sync_config or SyncConfig()
set_sync_periods(sync_config)
if num_samples == -1:
num_samples = sys.maxsize
trial_executor = trial_executor or RayTrialExecutor(
reuse_actors=reuse_actors, queue_trials=queue_trials)
if isinstance(run_or_experiment, list):
@@ -383,6 +388,8 @@ def run(
else:
progress_reporter = CLIReporter()
progress_reporter.set_total_samples(search_alg.total_samples)
# User Warning for GPUs
if trial_executor.has_gpus():
if isinstance(resources_per_trial,
+4 -1
View File
@@ -211,7 +211,10 @@ def RunnerHandler(runner):
resource["trials"] = []
trial_generator = BasicVariantGenerator()
trial_generator.add_configurations({name: spec})
for trial in trial_generator.next_trials():
while not trial_generator.is_finished():
trial = trial_generator.next_trial()
if not trial:
break
runner.add_trial(trial)
resource["trials"].append(self._trial_info(trial))
return resource