From 54a892bb84bc9da9d8b6484b0128d4bcc23b1f0a Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Tue, 24 Mar 2020 20:30:12 -0700 Subject: [PATCH] [tune] Cancel Experiment via Client (#7719) * init cancel * testing * Update python/ray/tune/tests/test_tune_server.py Co-Authored-By: Richard Liaw * Apply suggestions from code review * Apply suggestions from code review * finished * set_finished Co-authored-by: ijrsvt --- python/ray/tune/suggest/basic_variant.py | 5 +-- python/ray/tune/suggest/search.py | 7 +++- python/ray/tune/suggest/suggestion.py | 7 ++-- python/ray/tune/tests/test_trial_runner_3.py | 2 +- python/ray/tune/tests/test_tune_server.py | 19 ++++++++++- python/ray/tune/trial_runner.py | 11 +++++-- python/ray/tune/web_server.py | 34 ++++++++++++-------- python/setup.py | 2 +- 8 files changed, 58 insertions(+), 29 deletions(-) diff --git a/python/ray/tune/suggest/basic_variant.py b/python/ray/tune/suggest/basic_variant.py index 86872690b..2a728697f 100644 --- a/python/ray/tune/suggest/basic_variant.py +++ b/python/ray/tune/suggest/basic_variant.py @@ -73,7 +73,7 @@ class BasicVariantGenerator(SearchAlgorithm): trials = list(self._trial_generator) if self._shuffle: random.shuffle(trials) - self._finished = True + self.set_finished() return trials def _generate_trials(self, num_samples, unresolved_spec, output_path=""): @@ -104,6 +104,3 @@ class BasicVariantGenerator(SearchAlgorithm): evaluated_params=flatten_resolved_vars(resolved_vars), trial_id=trial_id, experiment_tag=experiment_tag) - - def is_finished(self): - return self._finished diff --git a/python/ray/tune/suggest/search.py b/python/ray/tune/suggest/search.py index e80767b40..107150297 100644 --- a/python/ray/tune/suggest/search.py +++ b/python/ray/tune/suggest/search.py @@ -10,6 +10,7 @@ class SearchAlgorithm: See also: `ray.tune.suggest.BasicVariantGenerator`. """ + _finished = False def add_configurations(self, experiments): """Tracks given experiment specifications. @@ -62,4 +63,8 @@ class SearchAlgorithm: Can return True before all trials have finished executing. """ - raise NotImplementedError + return self._finished + + def set_finished(self): + """Marks the search algorithm as finished.""" + self._finished = True diff --git a/python/ray/tune/suggest/suggestion.py b/python/ray/tune/suggest/suggestion.py index df62c0740..532dd2860 100644 --- a/python/ray/tune/suggest/suggestion.py +++ b/python/ray/tune/suggest/suggestion.py @@ -34,11 +34,11 @@ class SuggestionAlgorithm(SearchAlgorithm): self._parser = make_parser() self._trial_generator = [] self._counter = 0 - self._finished = False self._metric = metric assert mode in ["min", "max"] self._mode = mode self._use_early_stopped = use_early_stopped_trials + self._finished = False def add_configurations(self, experiments): """Chains generator given experiment specifications. @@ -69,7 +69,7 @@ class SuggestionAlgorithm(SearchAlgorithm): return trials trials += [trial] - self._finished = True + self.set_finished() return trials def _generate_trials(self, num_samples, experiment_spec, output_path=""): @@ -105,9 +105,6 @@ class SuggestionAlgorithm(SearchAlgorithm): experiment_tag=tag, trial_id=trial_id) - def is_finished(self): - return self._finished - def suggest(self, trial_id): """Queries the algorithm to retrieve the next set of parameters. diff --git a/python/ray/tune/tests/test_trial_runner_3.py b/python/ray/tune/tests/test_trial_runner_3.py index b837d8845..ae9002d93 100644 --- a/python/ray/tune/tests/test_trial_runner_3.py +++ b/python/ray/tune/tests/test_trial_runner_3.py @@ -250,7 +250,7 @@ class TrialRunnerTest3(unittest.TestCase): break if self._index > 4: - self._finished = True + self.set_finished() return trials def suggest(self, trial_id): diff --git a/python/ray/tune/tests/test_tune_server.py b/python/ray/tune/tests/test_tune_server.py index af5556a87..ec2e5046a 100644 --- a/python/ray/tune/tests/test_tune_server.py +++ b/python/ray/tune/tests/test_tune_server.py @@ -1,6 +1,7 @@ -import unittest +import requests import socket import subprocess +import unittest import json import ray @@ -119,6 +120,22 @@ class TuneServerSuite(unittest.TestCase): self.assertEqual( len([t for t in all_trials if t["status"] == Trial.RUNNING]), 0) + def testStopExperiment(self): + """Check if stop_experiment works.""" + runner, client = self.basicSetup() + for i in range(2): + runner.step() + all_trials = client.get_all_trials()["trials"] + self.assertEqual( + len([t for t in all_trials if t["status"] == Trial.RUNNING]), 1) + + client.stop_experiment() + runner.step() + self.assertTrue(runner.is_finished()) + self.assertRaises( + requests.exceptions.ReadTimeout, + lambda: client.get_all_trials(timeout=1)) + def testCurlCommand(self): """Check if Stop Trial works.""" runner, client = self.basicSetup() diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 78403965a..1bae88d2e 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -147,6 +147,7 @@ class TrialRunner: self._trials = [] self._cached_trial_decisions = {} self._stop_queue = [] + self._should_stop_experiment = False # used by TuneServer self._local_checkpoint_dir = local_checkpoint_dir if self._local_checkpoint_dir: @@ -346,7 +347,7 @@ class TrialRunner: if self._server: with warn_if_slow("server"): - self._process_requests() + self._process_stop_requests() if self.is_finished(): self._server.shutdown() @@ -393,7 +394,8 @@ class TrialRunner: def _stop_experiment_if_needed(self): """Stops all trials if the user condition is satisfied.""" - if self._stopper.stop_all(): + if self._stopper.stop_all() or self._should_stop_experiment: + self._search_alg.set_finished() [self.trial_executor.stop_trial(t) for t in self._trials] logger.info("All trials stopped due to ``stopper.stop_all``.") @@ -688,7 +690,10 @@ class TrialRunner: def request_stop_trial(self, trial): self._stop_queue.append(trial) - def _process_requests(self): + def request_stop_experiment(self): + self._should_stop_experiment = True + + def _process_stop_requests(self): while self._stop_queue: t = self._stop_queue.pop() self.stop_trial(t) diff --git a/python/ray/tune/web_server.py b/python/ray/tune/web_server.py index 6802fe0c3..c515a6390 100644 --- a/python/ray/tune/web_server.py +++ b/python/ray/tune/web_server.py @@ -35,15 +35,15 @@ class TuneClient: self._port_forward = port_forward self._path = "http://{}:{}".format(tune_address, port_forward) - def get_all_trials(self): + def get_all_trials(self, timeout=None): """Returns a list of all trials' information.""" - response = requests.get(urljoin(self._path, "trials")) + response = requests.get(urljoin(self._path, "trials"), timeout=timeout) return self._deserialize(response) - def get_trial(self, trial_id): + def get_trial(self, trial_id, timeout=None): """Returns trial information by trial_id.""" response = requests.get( - urljoin(self._path, "trials/{}".format(trial_id))) + urljoin(self._path, "trials/{}".format(trial_id)), timeout=timeout) return self._deserialize(response) def add_trial(self, name, specification): @@ -58,6 +58,11 @@ class TuneClient: urljoin(self._path, "trials/{}".format(trial_id))) return self._deserialize(response) + def stop_experiment(self): + """Requests to stop the entire experiment.""" + response = requests.put(urljoin(self._path, "stop_experiment")) + return self._deserialize(response) + @property def server_address(self): return self._tune_address @@ -137,17 +142,20 @@ def RunnerHandler(runner): response_code = 200 message = "" try: - result = self._get_trial_by_url(self.path) resource = {} - if result: - if isinstance(result, list): - infos = [self._trial_info(t) for t in result] - resource["trials"] = infos - for t in result: + + if self.path.endswith("stop_experiment"): + runner.request_stop_experiment() + trials = list(runner.get_trials()) + else: + trials = self._get_trial_by_url(self.path) + if trials: + if not isinstance(trials, list): + trials = [trials] + for t in trials: runner.request_stop_trial(t) - else: - resource["trial"] = self._trial_info(result) - runner.request_stop_trial(result) + + resource["trials"] = [self._trial_info(t) for t in trials] message = json.dumps(resource) except TuneError as e: response_code = 404 diff --git a/python/setup.py b/python/setup.py index 2dd1638ce..20bef2184 100644 --- a/python/setup.py +++ b/python/setup.py @@ -77,7 +77,7 @@ extras = { "debug": [], "dashboard": [], "serve": ["uvicorn", "pygments", "werkzeug", "flask", "pandas", "blist"], - "tune": ["tabulate", "tensorboardX"] + "tune": ["tabulate", "tensorboardX", "pandas"] } extras["rllib"] = extras["tune"] + [