mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 21:26:08 +08:00
[tune] lazy trials (#10802)
* Lazily fill trial queue * Update interface * Update end to end reporter test * Removed `next_trials()` method * Lint * Print total number of samples to be generated in progress reporter. Allow infinite samples. * Nit check
This commit is contained in:
@@ -56,6 +56,8 @@ class AutoMLSearcher(SearchAlgorithm):
|
||||
self._unfinished_count = 0
|
||||
self._running_trials = {}
|
||||
self._completed_trials = {}
|
||||
self._next_trials = []
|
||||
self._next_trial_iter = None
|
||||
|
||||
self._iteration = 0
|
||||
self._total_trial_num = 0
|
||||
@@ -68,10 +70,27 @@ class AutoMLSearcher(SearchAlgorithm):
|
||||
"""Returns the Trial object with the best reward_attr"""
|
||||
return self.best_trial
|
||||
|
||||
def next_trials(self):
|
||||
def next_trial(self):
|
||||
if not self._next_trial_iter:
|
||||
self._generate_next_trials()
|
||||
if not self._next_trials:
|
||||
self.set_finished()
|
||||
return None
|
||||
self._next_trial_iter = iter(self._next_trials)
|
||||
|
||||
try:
|
||||
return next(self._next_trial_iter)
|
||||
except StopIteration:
|
||||
self._next_trials = []
|
||||
self._next_trial_iter = None
|
||||
return None
|
||||
|
||||
def _generate_next_trials(self):
|
||||
self._next_trials = []
|
||||
|
||||
if self._unfinished_count > 0:
|
||||
# Last round not finished
|
||||
return []
|
||||
return
|
||||
|
||||
trials = []
|
||||
raw_param_list, extra_arg_list = self._select()
|
||||
@@ -110,7 +129,7 @@ class AutoMLSearcher(SearchAlgorithm):
|
||||
"new": ntrial,
|
||||
"total": self._total_trial_num
|
||||
})
|
||||
return trials
|
||||
self._next_trials = trials
|
||||
|
||||
def on_trial_result(self, trial_id, result):
|
||||
if not result:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
@@ -94,10 +96,12 @@ class TuneReporterBase(ProgressReporter):
|
||||
def __init__(self,
|
||||
metric_columns=None,
|
||||
parameter_columns=None,
|
||||
total_samples=None,
|
||||
max_progress_rows=20,
|
||||
max_error_rows=20,
|
||||
max_report_frequency=5,
|
||||
infer_limit=3):
|
||||
self._total_samples = total_samples
|
||||
self._metrics_override = metric_columns is not None
|
||||
self._inferred_metrics = {}
|
||||
self._metric_columns = metric_columns or self.DEFAULT_COLUMNS.copy()
|
||||
@@ -109,6 +113,9 @@ class TuneReporterBase(ProgressReporter):
|
||||
self._max_report_freqency = max_report_frequency
|
||||
self._last_report_time = 0
|
||||
|
||||
def set_total_samples(self, total_samples):
|
||||
self._total_samples = total_samples
|
||||
|
||||
def should_report(self, trials, done=False):
|
||||
if time.time() - self._last_report_time > self._max_report_freqency:
|
||||
self._last_report_time = time.time()
|
||||
@@ -191,6 +198,7 @@ class TuneReporterBase(ProgressReporter):
|
||||
trials,
|
||||
metric_columns=self._metric_columns,
|
||||
parameter_columns=self._parameter_columns,
|
||||
total_samples=self._total_samples,
|
||||
fmt=fmt,
|
||||
max_rows=max_progress))
|
||||
messages.append(trial_errors_str(trials, fmt=fmt, max_rows=max_error))
|
||||
@@ -243,12 +251,13 @@ class JupyterNotebookReporter(TuneReporterBase):
|
||||
overwrite,
|
||||
metric_columns=None,
|
||||
parameter_columns=None,
|
||||
total_samples=None,
|
||||
max_progress_rows=20,
|
||||
max_error_rows=20,
|
||||
max_report_frequency=5):
|
||||
super(JupyterNotebookReporter, self).__init__(
|
||||
metric_columns, parameter_columns, max_progress_rows,
|
||||
max_error_rows, max_report_frequency)
|
||||
metric_columns, parameter_columns, total_samples,
|
||||
max_progress_rows, max_error_rows, max_report_frequency)
|
||||
self._overwrite = overwrite
|
||||
|
||||
def report(self, trials, done, *sys_info):
|
||||
@@ -287,13 +296,14 @@ class CLIReporter(TuneReporterBase):
|
||||
def __init__(self,
|
||||
metric_columns=None,
|
||||
parameter_columns=None,
|
||||
total_samples=None,
|
||||
max_progress_rows=20,
|
||||
max_error_rows=20,
|
||||
max_report_frequency=5):
|
||||
|
||||
super(CLIReporter, self).__init__(metric_columns, parameter_columns,
|
||||
max_progress_rows, max_error_rows,
|
||||
max_report_frequency)
|
||||
total_samples, max_progress_rows,
|
||||
max_error_rows, max_report_frequency)
|
||||
|
||||
def report(self, trials, done, *sys_info):
|
||||
print(self._progress_str(trials, done, *sys_info))
|
||||
@@ -324,6 +334,7 @@ def memory_debug_str():
|
||||
def trial_progress_str(trials,
|
||||
metric_columns,
|
||||
parameter_columns=None,
|
||||
total_samples=0,
|
||||
fmt="psql",
|
||||
max_rows=None):
|
||||
"""Returns a human readable message for printing to the console.
|
||||
@@ -342,6 +353,7 @@ def trial_progress_str(trials,
|
||||
values are the names to use in the message. If this is a list,
|
||||
the parameter name is used in the message directly. If this is
|
||||
empty, all parameters are used in the message.
|
||||
total_samples (int): Total number of trials that will be generated.
|
||||
fmt (str): Output format (see tablefmt in tabulate API).
|
||||
max_rows (int): Maximum number of rows in the trial table. Defaults to
|
||||
unlimited.
|
||||
@@ -381,8 +393,13 @@ def trial_progress_str(trials,
|
||||
overflow_str = ", ".join(overflow_strs)
|
||||
else:
|
||||
overflow = False
|
||||
messages.append("Number of trials: {} ({})".format(
|
||||
num_trials, ", ".join(num_trials_strs)))
|
||||
|
||||
if total_samples >= sys.maxsize:
|
||||
total_samples = "infinite"
|
||||
|
||||
messages.append("Number of trials: {}{} ({})".format(
|
||||
num_trials, f"/{total_samples}"
|
||||
if total_samples else "", ", ".join(num_trials_strs)))
|
||||
|
||||
# Pre-process trials to figure out what columns to show.
|
||||
if isinstance(metric_columns, Mapping):
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import itertools
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.experiment import Experiment, convert_to_experiment_list
|
||||
from ray.tune.config_parser import make_parser, create_trial_from_spec
|
||||
from ray.tune.suggest.variant_generator import (generate_variants, format_vars,
|
||||
flatten_resolved_vars)
|
||||
from ray.tune.suggest.variant_generator import (
|
||||
count_variants, generate_variants, format_vars, flatten_resolved_vars)
|
||||
from ray.tune.suggest.search import SearchAlgorithm
|
||||
|
||||
|
||||
@@ -17,10 +16,6 @@ class BasicVariantGenerator(SearchAlgorithm):
|
||||
|
||||
See also: `ray.tune.suggest.variant_generator`.
|
||||
|
||||
|
||||
Parameters:
|
||||
shuffle (bool): Shuffles the generated list of configurations.
|
||||
|
||||
User API:
|
||||
|
||||
.. code-block:: python
|
||||
@@ -39,19 +34,19 @@ class BasicVariantGenerator(SearchAlgorithm):
|
||||
|
||||
searcher = BasicVariantGenerator()
|
||||
searcher.add_configurations({"experiment": { ... }})
|
||||
list_of_trials = searcher.next_trials()
|
||||
trial = searcher.next_trial()
|
||||
searcher.is_finished == True
|
||||
"""
|
||||
|
||||
def __init__(self, shuffle: bool = False):
|
||||
def __init__(self):
|
||||
"""Initializes the Variant Generator.
|
||||
|
||||
"""
|
||||
self._parser = make_parser()
|
||||
self._trial_generator = []
|
||||
self._trial_iter = None
|
||||
self._counter = 0
|
||||
self._finished = False
|
||||
self._shuffle = shuffle
|
||||
|
||||
# Unique prefix for all trials generated, e.g., trial ids start as
|
||||
# 2f1e_00001, 2f1ef_00002, 2f1ef_0003, etc. Overridable for testing.
|
||||
@@ -61,6 +56,12 @@ class BasicVariantGenerator(SearchAlgorithm):
|
||||
else:
|
||||
self._uuid_prefix = str(uuid.uuid1().hex)[:5] + "_"
|
||||
|
||||
self._total_samples = 0
|
||||
|
||||
@property
|
||||
def total_samples(self):
|
||||
return self._total_samples
|
||||
|
||||
def add_configurations(
|
||||
self,
|
||||
experiments: Union[Experiment, List[Experiment], Dict[str, Dict]]):
|
||||
@@ -71,23 +72,28 @@ class BasicVariantGenerator(SearchAlgorithm):
|
||||
"""
|
||||
experiment_list = convert_to_experiment_list(experiments)
|
||||
for experiment in experiment_list:
|
||||
self._total_samples += count_variants(experiment.spec)
|
||||
self._trial_generator = itertools.chain(
|
||||
self._trial_generator,
|
||||
self._generate_trials(
|
||||
experiment.spec.get("num_samples", 1), experiment.spec,
|
||||
experiment.name))
|
||||
|
||||
def next_trials(self):
|
||||
"""Provides Trial objects to be queued into the TrialRunner.
|
||||
def next_trial(self):
|
||||
"""Provides one Trial object to be queued into the TrialRunner.
|
||||
|
||||
Returns:
|
||||
trials (list): Returns a list of trials.
|
||||
Trial: Returns a single trial.
|
||||
"""
|
||||
trials = list(self._trial_generator)
|
||||
if self._shuffle:
|
||||
random.shuffle(trials)
|
||||
self.set_finished()
|
||||
return trials
|
||||
if not self._trial_iter:
|
||||
self._trial_iter = iter(self._trial_generator)
|
||||
try:
|
||||
return next(self._trial_iter)
|
||||
except StopIteration:
|
||||
self._trial_generator = []
|
||||
self._trial_iter = None
|
||||
self.set_finished()
|
||||
return None
|
||||
|
||||
def _generate_trials(self, num_samples, unresolved_spec, output_path=""):
|
||||
"""Generates Trial objects with the variant generation process.
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.trial import Trial
|
||||
|
||||
|
||||
class SearchAlgorithm:
|
||||
@@ -36,6 +35,11 @@ class SearchAlgorithm:
|
||||
"""
|
||||
return True
|
||||
|
||||
@property
|
||||
def total_samples(self):
|
||||
"""Get number of total trials to be generated"""
|
||||
return 0
|
||||
|
||||
def add_configurations(
|
||||
self,
|
||||
experiments: Union[Experiment, List[Experiment], Dict[str, Dict]]):
|
||||
@@ -46,11 +50,11 @@ class SearchAlgorithm:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def next_trials(self) -> List[Trial]:
|
||||
"""Provides Trial objects to be queued into the TrialRunner.
|
||||
def next_trial(self):
|
||||
"""Returns single Trial object to be queued into the TrialRunner.
|
||||
|
||||
Returns:
|
||||
trials (list): Returns a list of trials.
|
||||
trial (Trial): Returns a Trial object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -67,13 +67,17 @@ class SearchGenerator(SearchAlgorithm):
|
||||
self._parser = make_parser()
|
||||
self._experiment = None
|
||||
self._counter = 0 # Keeps track of number of trials created.
|
||||
self._total_samples = None # int: total samples to evaluate.
|
||||
self._total_samples = 0 # int: total samples to evaluate.
|
||||
self._finished = False
|
||||
|
||||
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
|
||||
config: Dict) -> bool:
|
||||
return self.searcher.set_search_properties(metric, mode, config)
|
||||
|
||||
@property
|
||||
def total_samples(self):
|
||||
return self._total_samples
|
||||
|
||||
def add_configurations(
|
||||
self,
|
||||
experiments: Union[Experiment, List[Experiment], Dict[str, Dict]]):
|
||||
@@ -95,20 +99,16 @@ class SearchGenerator(SearchAlgorithm):
|
||||
if "run" not in experiment_spec:
|
||||
raise TuneError("Must specify `run` in {}".format(experiment_spec))
|
||||
|
||||
def next_trials(self) -> List[Trial]:
|
||||
"""Provides a batch of Trial objects to be queued into the TrialRunner.
|
||||
def next_trial(self):
|
||||
"""Provides one Trial object to be queued into the TrialRunner.
|
||||
|
||||
Returns:
|
||||
List[Trial]: A list of trials for the Runner to consume.
|
||||
Trial: Returns a single trial.
|
||||
"""
|
||||
trials = []
|
||||
while not self.is_finished():
|
||||
trial = self.create_trial_if_possible(self._experiment.spec,
|
||||
self._experiment.name)
|
||||
if trial is None:
|
||||
break
|
||||
trials.append(trial)
|
||||
return trials
|
||||
if not self.is_finished():
|
||||
return self.create_trial_if_possible(self._experiment.spec,
|
||||
self._experiment.name)
|
||||
return None
|
||||
|
||||
def create_trial_if_possible(self, experiment_spec: Dict,
|
||||
output_path: str) -> Optional[Trial]:
|
||||
|
||||
@@ -138,6 +138,15 @@ def parse_spec_vars(spec: Dict) -> Tuple[List[Tuple[Tuple, Any]], List[Tuple[
|
||||
return resolved_vars, domain_vars, grid_vars
|
||||
|
||||
|
||||
def count_variants(spec: Dict) -> int:
|
||||
spec = copy.deepcopy(spec)
|
||||
_, domain_vars, grid_vars = parse_spec_vars(spec)
|
||||
grid_count = 1
|
||||
for path, domain in grid_vars:
|
||||
grid_count *= len(domain.categories)
|
||||
return spec.get("num_samples", 1) * grid_count
|
||||
|
||||
|
||||
def _generate_variants(spec: Dict) -> Tuple[Dict, Dict]:
|
||||
spec = copy.deepcopy(spec)
|
||||
_, domain_vars, grid_vars = parse_spec_vars(spec)
|
||||
|
||||
@@ -5,6 +5,16 @@ from ray.tune import register_trainable
|
||||
from ray.tune.automl import SearchSpace, DiscreteSpace, GridSearch
|
||||
|
||||
|
||||
def next_trials(searcher):
|
||||
trials = []
|
||||
while not searcher.is_finished():
|
||||
trial = searcher.next_trial()
|
||||
if not trial:
|
||||
break
|
||||
trials.append(trial)
|
||||
return trials
|
||||
|
||||
|
||||
class AutoMLSearcherTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
def dummy_train(config, reporter):
|
||||
@@ -20,7 +30,7 @@ class AutoMLSearcherTest(unittest.TestCase):
|
||||
])
|
||||
searcher = GridSearch(space, "reward")
|
||||
searcher.add_configurations(exp)
|
||||
trials = searcher.next_trials()
|
||||
trials = next_trials(searcher)
|
||||
|
||||
self.assertEqual(len(trials), 4)
|
||||
self.assertTrue(trials[0].config["a"]["b"]["c"] in [1, 2])
|
||||
@@ -34,9 +44,9 @@ class AutoMLSearcherTest(unittest.TestCase):
|
||||
])
|
||||
searcher = GridSearch(space, "reward")
|
||||
searcher.add_configurations(exp)
|
||||
trials = searcher.next_trials()
|
||||
trials = next_trials(searcher)
|
||||
|
||||
self.assertEqual(len(searcher.next_trials()), 0)
|
||||
self.assertEqual(searcher.next_trial(), None)
|
||||
for trial in trials[1:]:
|
||||
searcher.on_trial_complete(trial.trial_id)
|
||||
searcher.on_trial_complete(trials[0].trial_id, error=True)
|
||||
@@ -51,10 +61,11 @@ class AutoMLSearcherTest(unittest.TestCase):
|
||||
])
|
||||
searcher = GridSearch(space, "reward")
|
||||
searcher.add_configurations(exp)
|
||||
trials = searcher.next_trials()
|
||||
trials = next_trials(searcher)
|
||||
|
||||
self.assertEqual(len(searcher.next_trials()), 0)
|
||||
self.assertEqual(searcher.next_trial(), None)
|
||||
for i, trial in enumerate(trials):
|
||||
print("TRIAL {}".format(trial))
|
||||
rewards = list(range(i, i + 10))
|
||||
random.shuffle(rewards)
|
||||
for reward in rewards:
|
||||
|
||||
@@ -73,34 +73,14 @@ tune.run_experiments({
|
||||
},
|
||||
}, verbose=1)"""
|
||||
|
||||
EXPECTED_END_TO_END_START = """Number of trials: 30 (29 PENDING, 1 RUNNING)
|
||||
+---------------+----------+-------+-----+-----+
|
||||
| Trial name | status | loc | a | b |
|
||||
|---------------+----------+-------+-----+-----|
|
||||
| f_xxxxx_00001 | PENDING | | 1 | |
|
||||
| f_xxxxx_00002 | PENDING | | 2 | |
|
||||
| f_xxxxx_00003 | PENDING | | 3 | |
|
||||
| f_xxxxx_00004 | PENDING | | 4 | |
|
||||
| f_xxxxx_00005 | PENDING | | 5 | |
|
||||
| f_xxxxx_00006 | PENDING | | 6 | |
|
||||
| f_xxxxx_00007 | PENDING | | 7 | |
|
||||
| f_xxxxx_00008 | PENDING | | 8 | |
|
||||
| f_xxxxx_00009 | PENDING | | 9 | |
|
||||
| f_xxxxx_00010 | PENDING | | | 0 |
|
||||
| f_xxxxx_00011 | PENDING | | | 1 |
|
||||
| f_xxxxx_00012 | PENDING | | | 2 |
|
||||
| f_xxxxx_00013 | PENDING | | | 3 |
|
||||
| f_xxxxx_00014 | PENDING | | | 4 |
|
||||
| f_xxxxx_00015 | PENDING | | | 5 |
|
||||
| f_xxxxx_00016 | PENDING | | | 6 |
|
||||
| f_xxxxx_00017 | PENDING | | | 7 |
|
||||
| f_xxxxx_00018 | PENDING | | | 8 |
|
||||
| f_xxxxx_00019 | PENDING | | | 9 |
|
||||
| f_xxxxx_00000 | RUNNING | | 0 | |
|
||||
+---------------+----------+-------+-----+-----+
|
||||
... 10 more trials not shown (10 PENDING)"""
|
||||
EXPECTED_END_TO_END_START = """Number of trials: 1/30 (1 RUNNING)
|
||||
+---------------+----------+-------+-----+
|
||||
| Trial name | status | loc | a |
|
||||
|---------------+----------+-------+-----|
|
||||
| f_xxxxx_00000 | RUNNING | | 0 |
|
||||
+---------------+----------+-------+-----+"""
|
||||
|
||||
EXPECTED_END_TO_END_END = """Number of trials: 30 (30 TERMINATED)
|
||||
EXPECTED_END_TO_END_END = """Number of trials: 30/30 (30 TERMINATED)
|
||||
+---------------+------------+-------+-----+-----+-----+
|
||||
| Trial name | status | loc | a | b | c |
|
||||
|---------------+------------+-------+-----+-----+-----|
|
||||
@@ -136,7 +116,7 @@ EXPECTED_END_TO_END_END = """Number of trials: 30 (30 TERMINATED)
|
||||
| f_xxxxx_00029 | TERMINATED | | | | 9 |
|
||||
+---------------+------------+-------+-----+-----+-----+"""
|
||||
|
||||
EXPECTED_END_TO_END_AC = """Number of trials: 30 (30 TERMINATED)
|
||||
EXPECTED_END_TO_END_AC = """Number of trials: 30/30 (30 TERMINATED)
|
||||
+---------------+------------+-------+-----+-----+-----+
|
||||
| Trial name | status | loc | a | b | c |
|
||||
|---------------+------------+-------+-----+-----+-----|
|
||||
|
||||
@@ -170,7 +170,14 @@ class RayTrialExecutorTest(unittest.TestCase):
|
||||
def generate_trials(spec, name):
|
||||
suggester = BasicVariantGenerator()
|
||||
suggester.add_configurations({name: spec})
|
||||
return suggester.next_trials()
|
||||
trials = []
|
||||
while not suggester.is_finished():
|
||||
trial = suggester.next_trial()
|
||||
if trial:
|
||||
trials.append(trial)
|
||||
else:
|
||||
break
|
||||
return trials
|
||||
|
||||
def process_trial_save(self, trial):
|
||||
"""Simulates trial runner save."""
|
||||
|
||||
@@ -57,7 +57,10 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
for name, spec in experiments.items():
|
||||
trial_generator = BasicVariantGenerator()
|
||||
trial_generator.add_configurations({name: spec})
|
||||
for trial in trial_generator.next_trials():
|
||||
while not trial_generator.is_finished():
|
||||
trial = trial_generator.next_trial()
|
||||
if not trial:
|
||||
break
|
||||
trial_executor.start_trial(trial)
|
||||
self.assertLessEqual(len(os.path.basename(trial.logdir)), 200)
|
||||
trial_executor.stop_trial(trial)
|
||||
|
||||
@@ -228,19 +228,18 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
class FinishFastAlg(_MockSuggestionAlgorithm):
|
||||
_index = 0
|
||||
|
||||
def next_trials(self):
|
||||
def next_trial(self):
|
||||
spec = self._experiment.spec
|
||||
trials = []
|
||||
trial = None
|
||||
if self._index < spec["num_samples"]:
|
||||
trial = Trial(
|
||||
spec.get("run"), stopping_criterion=spec.get("stop"))
|
||||
trials.append(trial)
|
||||
self._index += 1
|
||||
|
||||
if self._index > 4:
|
||||
self.set_finished()
|
||||
|
||||
return trials
|
||||
return trial
|
||||
|
||||
def suggest(self, trial_id):
|
||||
return {}
|
||||
@@ -625,7 +624,7 @@ class SearchAlgorithmTest(unittest.TestCase):
|
||||
searcher = TestSuggestion()
|
||||
alg = SearchGenerator(searcher)
|
||||
alg.add_configurations({"test": {"run": "__fake"}})
|
||||
trial = alg.next_trials()[0]
|
||||
trial = alg.next_trial()
|
||||
self.assertTrue("e=5" in trial.experiment_tag)
|
||||
self.assertTrue("d=4" in trial.experiment_tag)
|
||||
|
||||
|
||||
@@ -24,7 +24,14 @@ class VariantGeneratorTest(unittest.TestCase):
|
||||
def generate_trials(self, spec, name):
|
||||
suggester = BasicVariantGenerator()
|
||||
suggester.add_configurations({name: spec})
|
||||
return suggester.next_trials()
|
||||
trials = []
|
||||
while not suggester.is_finished():
|
||||
trial = suggester.next_trial()
|
||||
if trial:
|
||||
trials.append(trial)
|
||||
else:
|
||||
break
|
||||
return trials
|
||||
|
||||
def testParseToTrials(self):
|
||||
trials = self.generate_trials({
|
||||
|
||||
@@ -734,19 +734,19 @@ class TrialRunner:
|
||||
or is_finished (timeout or search algorithm finishes).
|
||||
timeout (int): Seconds before blocking times out.
|
||||
"""
|
||||
trials = self._search_alg.next_trials()
|
||||
if blocking and not trials:
|
||||
trial = self._search_alg.next_trial()
|
||||
if blocking and not trial:
|
||||
start = time.time()
|
||||
# Checking `is_finished` instead of _search_alg.is_finished
|
||||
# is fine because blocking only occurs if all trials are
|
||||
# finished and search_algorithm is not yet finished
|
||||
while (not trials and not self.is_finished()
|
||||
while (not trial and not self.is_finished()
|
||||
and time.time() - start < timeout):
|
||||
logger.info("Blocking for next trial...")
|
||||
trials = self._search_alg.next_trials()
|
||||
trial = self._search_alg.next_trial()
|
||||
time.sleep(1)
|
||||
|
||||
for trial in trials:
|
||||
if trial:
|
||||
self.add_trial(trial)
|
||||
|
||||
def request_stop_trial(self, trial):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.experiment import convert_to_experiment_list, Experiment
|
||||
@@ -177,7 +178,8 @@ def run(
|
||||
num_samples (int): Number of times to sample from the
|
||||
hyperparameter space. Defaults to 1. If `grid_search` is
|
||||
provided as an argument, the grid will be repeated
|
||||
`num_samples` of times.
|
||||
`num_samples` of times. If this is -1, (virtually) infinite
|
||||
samples are generated until a stopping condition is met.
|
||||
local_dir (str): Local dir to save training results to.
|
||||
Defaults to ``~/ray_results``.
|
||||
search_alg (Searcher): Search algorithm for optimization.
|
||||
@@ -292,6 +294,9 @@ def run(
|
||||
sync_config = sync_config or SyncConfig()
|
||||
set_sync_periods(sync_config)
|
||||
|
||||
if num_samples == -1:
|
||||
num_samples = sys.maxsize
|
||||
|
||||
trial_executor = trial_executor or RayTrialExecutor(
|
||||
reuse_actors=reuse_actors, queue_trials=queue_trials)
|
||||
if isinstance(run_or_experiment, list):
|
||||
@@ -383,6 +388,8 @@ def run(
|
||||
else:
|
||||
progress_reporter = CLIReporter()
|
||||
|
||||
progress_reporter.set_total_samples(search_alg.total_samples)
|
||||
|
||||
# User Warning for GPUs
|
||||
if trial_executor.has_gpus():
|
||||
if isinstance(resources_per_trial,
|
||||
|
||||
@@ -211,7 +211,10 @@ def RunnerHandler(runner):
|
||||
resource["trials"] = []
|
||||
trial_generator = BasicVariantGenerator()
|
||||
trial_generator.add_configurations({name: spec})
|
||||
for trial in trial_generator.next_trials():
|
||||
while not trial_generator.is_finished():
|
||||
trial = trial_generator.next_trial()
|
||||
if not trial:
|
||||
break
|
||||
runner.add_trial(trial)
|
||||
resource["trials"].append(self._trial_info(trial))
|
||||
return resource
|
||||
|
||||
Reference in New Issue
Block a user