mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 04:44:28 +08:00
[tune] Support RESTful API for the Web Server (#4080)
Change the client/server API to RESTful design. This includes resource modeling, model URI's, and correct HTTP methods.
This commit is contained in:
committed by
Richard Liaw
parent
33663bef94
commit
5cf388f29d
@@ -8,7 +8,7 @@
|
||||
"source": [
|
||||
"from ray.tune.web_server import TuneClient\n",
|
||||
"\n",
|
||||
"manager = TuneClient(tune_address=\"localhost:4321\")\n",
|
||||
"manager = TuneClient(tune_address=\"localhost\", port_forward=4321)\n",
|
||||
"\n",
|
||||
"x = manager.get_all_trials()\n",
|
||||
"\n",
|
||||
@@ -19,7 +19,6 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
@@ -31,9 +30,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import yaml\n",
|
||||
@@ -45,9 +42,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"name, spec = [x for x in d.items()][0]"
|
||||
@@ -79,7 +74,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.2"
|
||||
"version": "3.6.8"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -6,8 +6,3 @@ from __future__ import print_function
|
||||
class TuneError(Exception):
|
||||
"""General error class raised by ray.tune."""
|
||||
pass
|
||||
|
||||
|
||||
class TuneManagerError(TuneError):
|
||||
"""Error raised in operating the Tune Manager."""
|
||||
pass
|
||||
|
||||
@@ -4,6 +4,8 @@ from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import socket
|
||||
import subprocess
|
||||
import json
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
@@ -44,7 +46,7 @@ class TuneServerSuite(unittest.TestCase):
|
||||
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
client = TuneClient("localhost:{}".format(port))
|
||||
client = TuneClient("localhost", port)
|
||||
return runner, client
|
||||
|
||||
def tearDown(self):
|
||||
@@ -126,6 +128,25 @@ class TuneServerSuite(unittest.TestCase):
|
||||
self.assertEqual(
|
||||
len([t for t in all_trials if t["status"] == Trial.RUNNING]), 0)
|
||||
|
||||
def testCurlCommand(self):
|
||||
"""Check if Stop Trial works."""
|
||||
runner, client = self.basicSetup()
|
||||
for i in range(2):
|
||||
runner.step()
|
||||
stdout = subprocess.check_output(
|
||||
'curl "http://{}:{}/trials"'.format(client.server_address,
|
||||
client.server_port),
|
||||
shell=True)
|
||||
self.assertNotEqual(stdout, None)
|
||||
curl_trials = json.loads(stdout.decode())["trials"]
|
||||
client_trials = client.get_all_trials()["trials"]
|
||||
for curl_trial, client_trial in zip(curl_trials, client_trials):
|
||||
self.assertEqual(curl_trial.keys(), client_trial.keys())
|
||||
self.assertEqual(curl_trial["id"], client_trial["id"])
|
||||
self.assertEqual(curl_trial["trainable_name"],
|
||||
client_trial["trainable_name"])
|
||||
self.assertEqual(curl_trial["status"], client_trial["status"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
+159
-87
@@ -8,14 +8,16 @@ import sys
|
||||
import threading
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from ray.tune.error import TuneError, TuneManagerError
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.utils import binary_to_hex, hex_to_binary
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
from urlparse import urljoin, urlparse
|
||||
from SimpleHTTPServer import SimpleHTTPRequestHandler
|
||||
from SocketServer import TCPServer as HTTPServer
|
||||
elif sys.version_info[0] == 3:
|
||||
from urllib.parse import urljoin, urlparse
|
||||
from http.server import SimpleHTTPRequestHandler, HTTPServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -28,83 +30,159 @@ except ImportError:
|
||||
"Be sure to install it on the client side.")
|
||||
|
||||
|
||||
def load_trial_info(trial_info):
|
||||
trial_info["config"] = cloudpickle.loads(
|
||||
hex_to_binary(trial_info["config"]))
|
||||
trial_info["result"] = cloudpickle.loads(
|
||||
hex_to_binary(trial_info["result"]))
|
||||
|
||||
|
||||
class TuneClient(object):
|
||||
"""Client to interact with ongoing Tune experiment.
|
||||
"""Client to interact with an ongoing Tune experiment.
|
||||
|
||||
Requires server to have started running."""
|
||||
STOP = "STOP"
|
||||
ADD = "ADD"
|
||||
GET_LIST = "GET_LIST"
|
||||
GET_TRIAL = "GET_TRIAL"
|
||||
Requires a TuneServer to have started running.
|
||||
|
||||
def __init__(self, tune_address):
|
||||
# TODO(rliaw): Better to specify address and port forward
|
||||
Attributes:
|
||||
tune_address (str): Address of running TuneServer
|
||||
port_forward (int): Port number of running TuneServer
|
||||
"""
|
||||
|
||||
def __init__(self, tune_address, port_forward):
|
||||
self._tune_address = tune_address
|
||||
self._path = "http://{}".format(tune_address)
|
||||
self._port_forward = port_forward
|
||||
self._path = "http://{}:{}".format(tune_address, port_forward)
|
||||
|
||||
def get_all_trials(self):
|
||||
"""Returns a list of all trials (trial_id, config, status)."""
|
||||
return self._get_response({"command": TuneClient.GET_LIST})
|
||||
"""Returns a list of all trials' information."""
|
||||
response = requests.get(urljoin(self._path, "trials"))
|
||||
return self._deserialize(response)
|
||||
|
||||
def get_trial(self, trial_id):
|
||||
"""Returns the last result for queried trial."""
|
||||
return self._get_response({
|
||||
"command": TuneClient.GET_TRIAL,
|
||||
"trial_id": trial_id
|
||||
})
|
||||
"""Returns trial information by trial_id."""
|
||||
response = requests.get(
|
||||
urljoin(self._path, "trials/{}".format(trial_id)))
|
||||
return self._deserialize(response)
|
||||
|
||||
def add_trial(self, name, trial_spec):
|
||||
"""Adds a trial of `name` with configurations."""
|
||||
# TODO(rliaw): have better way of specifying a new trial
|
||||
return self._get_response({
|
||||
"command": TuneClient.ADD,
|
||||
"name": name,
|
||||
"spec": trial_spec
|
||||
})
|
||||
def add_trial(self, name, specification):
|
||||
"""Adds a trial by name and specification (dict)."""
|
||||
payload = {"name": name, "spec": specification}
|
||||
response = requests.post(urljoin(self._path, "trials"), json=payload)
|
||||
return self._deserialize(response)
|
||||
|
||||
def stop_trial(self, trial_id):
|
||||
"""Requests to stop trial."""
|
||||
return self._get_response({
|
||||
"command": TuneClient.STOP,
|
||||
"trial_id": trial_id
|
||||
})
|
||||
"""Requests to stop trial by trial_id."""
|
||||
response = requests.put(
|
||||
urljoin(self._path, "trials/{}".format(trial_id)))
|
||||
return self._deserialize(response)
|
||||
|
||||
def _get_response(self, data):
|
||||
payload = json.dumps(data).encode()
|
||||
response = requests.get(self._path, data=payload)
|
||||
@property
|
||||
def server_address(self):
|
||||
return self._tune_address
|
||||
|
||||
@property
|
||||
def server_port(self):
|
||||
return self._port_forward
|
||||
|
||||
def _load_trial_info(self, trial_info):
|
||||
trial_info["config"] = cloudpickle.loads(
|
||||
hex_to_binary(trial_info["config"]))
|
||||
trial_info["result"] = cloudpickle.loads(
|
||||
hex_to_binary(trial_info["result"]))
|
||||
|
||||
def _deserialize(self, response):
|
||||
parsed = response.json()
|
||||
|
||||
if "trial_info" in parsed:
|
||||
load_trial_info(parsed["trial_info"])
|
||||
if "trial" in parsed:
|
||||
self._load_trial_info(parsed["trial"])
|
||||
elif "trials" in parsed:
|
||||
for trial_info in parsed["trials"]:
|
||||
load_trial_info(trial_info)
|
||||
self._load_trial_info(trial_info)
|
||||
|
||||
return parsed
|
||||
|
||||
|
||||
def RunnerHandler(runner):
|
||||
class Handler(SimpleHTTPRequestHandler):
|
||||
"""A Handler is a custom handler for TuneServer.
|
||||
|
||||
Handles all requests and responses coming into and from
|
||||
the TuneServer.
|
||||
"""
|
||||
|
||||
def _do_header(self, response_code=200, headers=None):
|
||||
"""Sends the header portion of the HTTP response.
|
||||
|
||||
Parameters:
|
||||
response_code (int): Standard HTTP response code
|
||||
headers (list[tuples]): Standard HTTP response headers
|
||||
"""
|
||||
if headers is None:
|
||||
headers = [('Content-type', 'application/json')]
|
||||
|
||||
self.send_response(response_code)
|
||||
for key, value in headers:
|
||||
self.send_header(key, value)
|
||||
self.end_headers()
|
||||
|
||||
def do_HEAD(self):
|
||||
"""HTTP HEAD handler method."""
|
||||
self._do_header()
|
||||
|
||||
def do_GET(self):
|
||||
"""HTTP GET handler method."""
|
||||
response_code = 200
|
||||
message = ""
|
||||
try:
|
||||
result = self._get_trial_by_url(self.path)
|
||||
resource = {}
|
||||
if result:
|
||||
if isinstance(result, list):
|
||||
infos = [self._trial_info(t) for t in result]
|
||||
resource["trials"] = infos
|
||||
else:
|
||||
resource["trial"] = self._trial_info(result)
|
||||
message = json.dumps(resource)
|
||||
except TuneError as e:
|
||||
response_code = 404
|
||||
message = str(e)
|
||||
|
||||
self._do_header(response_code=response_code)
|
||||
self.wfile.write(message.encode())
|
||||
|
||||
def do_PUT(self):
|
||||
"""HTTP PUT handler method."""
|
||||
response_code = 200
|
||||
message = ""
|
||||
try:
|
||||
result = self._get_trial_by_url(self.path)
|
||||
resource = {}
|
||||
if result:
|
||||
if isinstance(result, list):
|
||||
infos = [self._trial_info(t) for t in result]
|
||||
resource["trials"] = infos
|
||||
for t in result:
|
||||
runner.request_stop_trial(t)
|
||||
else:
|
||||
resource["trial"] = self._trial_info(result)
|
||||
runner.request_stop_trial(result)
|
||||
message = json.dumps(resource)
|
||||
except TuneError as e:
|
||||
response_code = 404
|
||||
message = str(e)
|
||||
|
||||
self._do_header(response_code=response_code)
|
||||
self.wfile.write(message.encode())
|
||||
|
||||
def do_POST(self):
|
||||
"""HTTP POST handler method."""
|
||||
response_code = 201
|
||||
|
||||
content_len = int(self.headers.get('Content-Length'), 0)
|
||||
raw_body = self.rfile.read(content_len)
|
||||
parsed_input = json.loads(raw_body.decode())
|
||||
status, response = self.execute_command(parsed_input)
|
||||
if status:
|
||||
self.send_response(200)
|
||||
else:
|
||||
self.send_response(400)
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(response).encode())
|
||||
resource = self._add_trials(parsed_input["name"],
|
||||
parsed_input["spec"])
|
||||
|
||||
def trial_info(self, trial):
|
||||
headers = [('Content-type', 'application/json'), ('Location',
|
||||
'/trials/')]
|
||||
self._do_header(response_code=response_code, headers=headers)
|
||||
self.wfile.write(json.dumps(resource).encode())
|
||||
|
||||
def _trial_info(self, trial):
|
||||
"""Returns trial information as JSON."""
|
||||
if trial.last_result:
|
||||
result = trial.last_result.copy()
|
||||
else:
|
||||
@@ -118,62 +196,56 @@ def RunnerHandler(runner):
|
||||
}
|
||||
return info_dict
|
||||
|
||||
def execute_command(self, args):
|
||||
def get_trial():
|
||||
trial = runner.get_trial(args["trial_id"])
|
||||
if trial is None:
|
||||
error = "Trial ({}) not found.".format(args["trial_id"])
|
||||
raise TuneManagerError(error)
|
||||
else:
|
||||
return trial
|
||||
def _get_trial_by_url(self, url):
|
||||
"""Parses url to get either all trials or trial by trial_id."""
|
||||
parts = urlparse(url)
|
||||
path = parts.path
|
||||
|
||||
command = args["command"]
|
||||
response = {}
|
||||
try:
|
||||
if command == TuneClient.GET_LIST:
|
||||
response["trials"] = [
|
||||
self.trial_info(t) for t in runner.get_trials()
|
||||
]
|
||||
elif command == TuneClient.GET_TRIAL:
|
||||
trial = get_trial()
|
||||
response["trial_info"] = self.trial_info(trial)
|
||||
elif command == TuneClient.STOP:
|
||||
trial = get_trial()
|
||||
runner.request_stop_trial(trial)
|
||||
elif command == TuneClient.ADD:
|
||||
name = args["name"]
|
||||
spec = args["spec"]
|
||||
trial_generator = BasicVariantGenerator()
|
||||
trial_generator.add_configurations({name: spec})
|
||||
for trial in trial_generator.next_trials():
|
||||
runner.add_trial(trial)
|
||||
else:
|
||||
raise TuneManagerError("Unknown command.")
|
||||
status = True
|
||||
except TuneError as e:
|
||||
status = False
|
||||
response["message"] = str(e)
|
||||
if path == "/trials":
|
||||
return [t for t in runner.get_trials()]
|
||||
else:
|
||||
trial_id = path.split("/")[-1]
|
||||
return runner.get_trial(trial_id)
|
||||
|
||||
return status, response
|
||||
def _add_trials(self, name, spec):
|
||||
"""Add trial by invoking TrialRunner."""
|
||||
resource = {}
|
||||
resource["trials"] = []
|
||||
trial_generator = BasicVariantGenerator()
|
||||
trial_generator.add_configurations({name: spec})
|
||||
for trial in trial_generator.next_trials():
|
||||
runner.add_trial(trial)
|
||||
resource["trials"].append(self._trial_info(trial))
|
||||
return resource
|
||||
|
||||
return Handler
|
||||
|
||||
|
||||
class TuneServer(threading.Thread):
|
||||
"""A TuneServer is a thread that initializes and runs a HTTPServer.
|
||||
|
||||
The server handles requests from a TuneClient.
|
||||
|
||||
Attributes:
|
||||
runner (TrialRunner): Runner that modifies and accesses trials.
|
||||
port_forward (int): Port number of TuneServer.
|
||||
"""
|
||||
|
||||
DEFAULT_PORT = 4321
|
||||
|
||||
def __init__(self, runner, port=None):
|
||||
|
||||
"""Initialize HTTPServer and serve forever by invoking self.run()"""
|
||||
threading.Thread.__init__(self)
|
||||
self._port = port if port else self.DEFAULT_PORT
|
||||
address = ('localhost', self._port)
|
||||
logger.info("Starting Tune Server...")
|
||||
self._server = HTTPServer(address, RunnerHandler(runner))
|
||||
self.daemon = True
|
||||
self.start()
|
||||
|
||||
def run(self):
|
||||
self._server.serve_forever()
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown the underlying server."""
|
||||
self._server.shutdown()
|
||||
|
||||
Reference in New Issue
Block a user