[tune] Cancel Experiment via Client (#7719)

* init cancel

* testing

* Update python/ray/tune/tests/test_tune_server.py

Co-Authored-By: Richard Liaw <rliaw@berkeley.edu>

* Apply suggestions from code review

* Apply suggestions from code review

* finished

* set_finished

Co-authored-by: ijrsvt <ian.rodney@gmail.com>
This commit is contained in:
Richard Liaw
2020-03-24 20:30:12 -07:00
committed by GitHub
parent a519b4f2a9
commit 54a892bb84
8 changed files with 58 additions and 29 deletions
+1 -4
View File
@@ -73,7 +73,7 @@ class BasicVariantGenerator(SearchAlgorithm):
trials = list(self._trial_generator)
if self._shuffle:
random.shuffle(trials)
self._finished = True
self.set_finished()
return trials
def _generate_trials(self, num_samples, unresolved_spec, output_path=""):
@@ -104,6 +104,3 @@ class BasicVariantGenerator(SearchAlgorithm):
evaluated_params=flatten_resolved_vars(resolved_vars),
trial_id=trial_id,
experiment_tag=experiment_tag)
def is_finished(self):
return self._finished
+6 -1
View File
@@ -10,6 +10,7 @@ class SearchAlgorithm:
See also: `ray.tune.suggest.BasicVariantGenerator`.
"""
_finished = False
def add_configurations(self, experiments):
"""Tracks given experiment specifications.
@@ -62,4 +63,8 @@ class SearchAlgorithm:
Can return True before all trials have finished executing.
"""
raise NotImplementedError
return self._finished
def set_finished(self):
"""Marks the search algorithm as finished."""
self._finished = True
+2 -5
View File
@@ -34,11 +34,11 @@ class SuggestionAlgorithm(SearchAlgorithm):
self._parser = make_parser()
self._trial_generator = []
self._counter = 0
self._finished = False
self._metric = metric
assert mode in ["min", "max"]
self._mode = mode
self._use_early_stopped = use_early_stopped_trials
self._finished = False
def add_configurations(self, experiments):
"""Chains generator given experiment specifications.
@@ -69,7 +69,7 @@ class SuggestionAlgorithm(SearchAlgorithm):
return trials
trials += [trial]
self._finished = True
self.set_finished()
return trials
def _generate_trials(self, num_samples, experiment_spec, output_path=""):
@@ -105,9 +105,6 @@ class SuggestionAlgorithm(SearchAlgorithm):
experiment_tag=tag,
trial_id=trial_id)
def is_finished(self):
return self._finished
def suggest(self, trial_id):
"""Queries the algorithm to retrieve the next set of parameters.
+1 -1
View File
@@ -250,7 +250,7 @@ class TrialRunnerTest3(unittest.TestCase):
break
if self._index > 4:
self._finished = True
self.set_finished()
return trials
def suggest(self, trial_id):
+18 -1
View File
@@ -1,6 +1,7 @@
import unittest
import requests
import socket
import subprocess
import unittest
import json
import ray
@@ -119,6 +120,22 @@ class TuneServerSuite(unittest.TestCase):
self.assertEqual(
len([t for t in all_trials if t["status"] == Trial.RUNNING]), 0)
def testStopExperiment(self):
"""Check if stop_experiment works."""
runner, client = self.basicSetup()
for i in range(2):
runner.step()
all_trials = client.get_all_trials()["trials"]
self.assertEqual(
len([t for t in all_trials if t["status"] == Trial.RUNNING]), 1)
client.stop_experiment()
runner.step()
self.assertTrue(runner.is_finished())
self.assertRaises(
requests.exceptions.ReadTimeout,
lambda: client.get_all_trials(timeout=1))
def testCurlCommand(self):
"""Check if Stop Trial works."""
runner, client = self.basicSetup()
+8 -3
View File
@@ -147,6 +147,7 @@ class TrialRunner:
self._trials = []
self._cached_trial_decisions = {}
self._stop_queue = []
self._should_stop_experiment = False # used by TuneServer
self._local_checkpoint_dir = local_checkpoint_dir
if self._local_checkpoint_dir:
@@ -346,7 +347,7 @@ class TrialRunner:
if self._server:
with warn_if_slow("server"):
self._process_requests()
self._process_stop_requests()
if self.is_finished():
self._server.shutdown()
@@ -393,7 +394,8 @@ class TrialRunner:
def _stop_experiment_if_needed(self):
"""Stops all trials if the user condition is satisfied."""
if self._stopper.stop_all():
if self._stopper.stop_all() or self._should_stop_experiment:
self._search_alg.set_finished()
[self.trial_executor.stop_trial(t) for t in self._trials]
logger.info("All trials stopped due to ``stopper.stop_all``.")
@@ -688,7 +690,10 @@ class TrialRunner:
def request_stop_trial(self, trial):
self._stop_queue.append(trial)
def _process_requests(self):
def request_stop_experiment(self):
self._should_stop_experiment = True
def _process_stop_requests(self):
while self._stop_queue:
t = self._stop_queue.pop()
self.stop_trial(t)
+21 -13
View File
@@ -35,15 +35,15 @@ class TuneClient:
self._port_forward = port_forward
self._path = "http://{}:{}".format(tune_address, port_forward)
def get_all_trials(self):
def get_all_trials(self, timeout=None):
"""Returns a list of all trials' information."""
response = requests.get(urljoin(self._path, "trials"))
response = requests.get(urljoin(self._path, "trials"), timeout=timeout)
return self._deserialize(response)
def get_trial(self, trial_id):
def get_trial(self, trial_id, timeout=None):
"""Returns trial information by trial_id."""
response = requests.get(
urljoin(self._path, "trials/{}".format(trial_id)))
urljoin(self._path, "trials/{}".format(trial_id)), timeout=timeout)
return self._deserialize(response)
def add_trial(self, name, specification):
@@ -58,6 +58,11 @@ class TuneClient:
urljoin(self._path, "trials/{}".format(trial_id)))
return self._deserialize(response)
def stop_experiment(self):
"""Requests to stop the entire experiment."""
response = requests.put(urljoin(self._path, "stop_experiment"))
return self._deserialize(response)
@property
def server_address(self):
return self._tune_address
@@ -137,17 +142,20 @@ def RunnerHandler(runner):
response_code = 200
message = ""
try:
result = self._get_trial_by_url(self.path)
resource = {}
if result:
if isinstance(result, list):
infos = [self._trial_info(t) for t in result]
resource["trials"] = infos
for t in result:
if self.path.endswith("stop_experiment"):
runner.request_stop_experiment()
trials = list(runner.get_trials())
else:
trials = self._get_trial_by_url(self.path)
if trials:
if not isinstance(trials, list):
trials = [trials]
for t in trials:
runner.request_stop_trial(t)
else:
resource["trial"] = self._trial_info(result)
runner.request_stop_trial(result)
resource["trials"] = [self._trial_info(t) for t in trials]
message = json.dumps(resource)
except TuneError as e:
response_code = 404
+1 -1
View File
@@ -77,7 +77,7 @@ extras = {
"debug": [],
"dashboard": [],
"serve": ["uvicorn", "pygments", "werkzeug", "flask", "pandas", "blist"],
"tune": ["tabulate", "tensorboardX"]
"tune": ["tabulate", "tensorboardX", "pandas"]
}
extras["rllib"] = extras["tune"] + [