mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:00:36 +08:00
[tune] Cancel Experiment via Client (#7719)
* init cancel * testing * Update python/ray/tune/tests/test_tune_server.py Co-Authored-By: Richard Liaw <rliaw@berkeley.edu> * Apply suggestions from code review * Apply suggestions from code review * finished * set_finished Co-authored-by: ijrsvt <ian.rodney@gmail.com>
This commit is contained in:
@@ -73,7 +73,7 @@ class BasicVariantGenerator(SearchAlgorithm):
|
||||
trials = list(self._trial_generator)
|
||||
if self._shuffle:
|
||||
random.shuffle(trials)
|
||||
self._finished = True
|
||||
self.set_finished()
|
||||
return trials
|
||||
|
||||
def _generate_trials(self, num_samples, unresolved_spec, output_path=""):
|
||||
@@ -104,6 +104,3 @@ class BasicVariantGenerator(SearchAlgorithm):
|
||||
evaluated_params=flatten_resolved_vars(resolved_vars),
|
||||
trial_id=trial_id,
|
||||
experiment_tag=experiment_tag)
|
||||
|
||||
def is_finished(self):
|
||||
return self._finished
|
||||
|
||||
@@ -10,6 +10,7 @@ class SearchAlgorithm:
|
||||
|
||||
See also: `ray.tune.suggest.BasicVariantGenerator`.
|
||||
"""
|
||||
_finished = False
|
||||
|
||||
def add_configurations(self, experiments):
|
||||
"""Tracks given experiment specifications.
|
||||
@@ -62,4 +63,8 @@ class SearchAlgorithm:
|
||||
|
||||
Can return True before all trials have finished executing.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
return self._finished
|
||||
|
||||
def set_finished(self):
|
||||
"""Marks the search algorithm as finished."""
|
||||
self._finished = True
|
||||
|
||||
@@ -34,11 +34,11 @@ class SuggestionAlgorithm(SearchAlgorithm):
|
||||
self._parser = make_parser()
|
||||
self._trial_generator = []
|
||||
self._counter = 0
|
||||
self._finished = False
|
||||
self._metric = metric
|
||||
assert mode in ["min", "max"]
|
||||
self._mode = mode
|
||||
self._use_early_stopped = use_early_stopped_trials
|
||||
self._finished = False
|
||||
|
||||
def add_configurations(self, experiments):
|
||||
"""Chains generator given experiment specifications.
|
||||
@@ -69,7 +69,7 @@ class SuggestionAlgorithm(SearchAlgorithm):
|
||||
return trials
|
||||
trials += [trial]
|
||||
|
||||
self._finished = True
|
||||
self.set_finished()
|
||||
return trials
|
||||
|
||||
def _generate_trials(self, num_samples, experiment_spec, output_path=""):
|
||||
@@ -105,9 +105,6 @@ class SuggestionAlgorithm(SearchAlgorithm):
|
||||
experiment_tag=tag,
|
||||
trial_id=trial_id)
|
||||
|
||||
def is_finished(self):
|
||||
return self._finished
|
||||
|
||||
def suggest(self, trial_id):
|
||||
"""Queries the algorithm to retrieve the next set of parameters.
|
||||
|
||||
|
||||
@@ -250,7 +250,7 @@ class TrialRunnerTest3(unittest.TestCase):
|
||||
break
|
||||
|
||||
if self._index > 4:
|
||||
self._finished = True
|
||||
self.set_finished()
|
||||
return trials
|
||||
|
||||
def suggest(self, trial_id):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
import requests
|
||||
import socket
|
||||
import subprocess
|
||||
import unittest
|
||||
import json
|
||||
|
||||
import ray
|
||||
@@ -119,6 +120,22 @@ class TuneServerSuite(unittest.TestCase):
|
||||
self.assertEqual(
|
||||
len([t for t in all_trials if t["status"] == Trial.RUNNING]), 0)
|
||||
|
||||
def testStopExperiment(self):
|
||||
"""Check if stop_experiment works."""
|
||||
runner, client = self.basicSetup()
|
||||
for i in range(2):
|
||||
runner.step()
|
||||
all_trials = client.get_all_trials()["trials"]
|
||||
self.assertEqual(
|
||||
len([t for t in all_trials if t["status"] == Trial.RUNNING]), 1)
|
||||
|
||||
client.stop_experiment()
|
||||
runner.step()
|
||||
self.assertTrue(runner.is_finished())
|
||||
self.assertRaises(
|
||||
requests.exceptions.ReadTimeout,
|
||||
lambda: client.get_all_trials(timeout=1))
|
||||
|
||||
def testCurlCommand(self):
|
||||
"""Check if Stop Trial works."""
|
||||
runner, client = self.basicSetup()
|
||||
|
||||
@@ -147,6 +147,7 @@ class TrialRunner:
|
||||
self._trials = []
|
||||
self._cached_trial_decisions = {}
|
||||
self._stop_queue = []
|
||||
self._should_stop_experiment = False # used by TuneServer
|
||||
self._local_checkpoint_dir = local_checkpoint_dir
|
||||
|
||||
if self._local_checkpoint_dir:
|
||||
@@ -346,7 +347,7 @@ class TrialRunner:
|
||||
|
||||
if self._server:
|
||||
with warn_if_slow("server"):
|
||||
self._process_requests()
|
||||
self._process_stop_requests()
|
||||
|
||||
if self.is_finished():
|
||||
self._server.shutdown()
|
||||
@@ -393,7 +394,8 @@ class TrialRunner:
|
||||
def _stop_experiment_if_needed(self):
|
||||
"""Stops all trials if the user condition is satisfied."""
|
||||
|
||||
if self._stopper.stop_all():
|
||||
if self._stopper.stop_all() or self._should_stop_experiment:
|
||||
self._search_alg.set_finished()
|
||||
[self.trial_executor.stop_trial(t) for t in self._trials]
|
||||
logger.info("All trials stopped due to ``stopper.stop_all``.")
|
||||
|
||||
@@ -688,7 +690,10 @@ class TrialRunner:
|
||||
def request_stop_trial(self, trial):
|
||||
self._stop_queue.append(trial)
|
||||
|
||||
def _process_requests(self):
|
||||
def request_stop_experiment(self):
|
||||
self._should_stop_experiment = True
|
||||
|
||||
def _process_stop_requests(self):
|
||||
while self._stop_queue:
|
||||
t = self._stop_queue.pop()
|
||||
self.stop_trial(t)
|
||||
|
||||
@@ -35,15 +35,15 @@ class TuneClient:
|
||||
self._port_forward = port_forward
|
||||
self._path = "http://{}:{}".format(tune_address, port_forward)
|
||||
|
||||
def get_all_trials(self):
|
||||
def get_all_trials(self, timeout=None):
|
||||
"""Returns a list of all trials' information."""
|
||||
response = requests.get(urljoin(self._path, "trials"))
|
||||
response = requests.get(urljoin(self._path, "trials"), timeout=timeout)
|
||||
return self._deserialize(response)
|
||||
|
||||
def get_trial(self, trial_id):
|
||||
def get_trial(self, trial_id, timeout=None):
|
||||
"""Returns trial information by trial_id."""
|
||||
response = requests.get(
|
||||
urljoin(self._path, "trials/{}".format(trial_id)))
|
||||
urljoin(self._path, "trials/{}".format(trial_id)), timeout=timeout)
|
||||
return self._deserialize(response)
|
||||
|
||||
def add_trial(self, name, specification):
|
||||
@@ -58,6 +58,11 @@ class TuneClient:
|
||||
urljoin(self._path, "trials/{}".format(trial_id)))
|
||||
return self._deserialize(response)
|
||||
|
||||
def stop_experiment(self):
|
||||
"""Requests to stop the entire experiment."""
|
||||
response = requests.put(urljoin(self._path, "stop_experiment"))
|
||||
return self._deserialize(response)
|
||||
|
||||
@property
|
||||
def server_address(self):
|
||||
return self._tune_address
|
||||
@@ -137,17 +142,20 @@ def RunnerHandler(runner):
|
||||
response_code = 200
|
||||
message = ""
|
||||
try:
|
||||
result = self._get_trial_by_url(self.path)
|
||||
resource = {}
|
||||
if result:
|
||||
if isinstance(result, list):
|
||||
infos = [self._trial_info(t) for t in result]
|
||||
resource["trials"] = infos
|
||||
for t in result:
|
||||
|
||||
if self.path.endswith("stop_experiment"):
|
||||
runner.request_stop_experiment()
|
||||
trials = list(runner.get_trials())
|
||||
else:
|
||||
trials = self._get_trial_by_url(self.path)
|
||||
if trials:
|
||||
if not isinstance(trials, list):
|
||||
trials = [trials]
|
||||
for t in trials:
|
||||
runner.request_stop_trial(t)
|
||||
else:
|
||||
resource["trial"] = self._trial_info(result)
|
||||
runner.request_stop_trial(result)
|
||||
|
||||
resource["trials"] = [self._trial_info(t) for t in trials]
|
||||
message = json.dumps(resource)
|
||||
except TuneError as e:
|
||||
response_code = 404
|
||||
|
||||
+1
-1
@@ -77,7 +77,7 @@ extras = {
|
||||
"debug": [],
|
||||
"dashboard": [],
|
||||
"serve": ["uvicorn", "pygments", "werkzeug", "flask", "pandas", "blist"],
|
||||
"tune": ["tabulate", "tensorboardX"]
|
||||
"tune": ["tabulate", "tensorboardX", "pandas"]
|
||||
}
|
||||
|
||||
extras["rllib"] = extras["tune"] + [
|
||||
|
||||
Reference in New Issue
Block a user