[tune] Support RESTful API for the Web Server (#4080)

Change the client/server API to RESTful design. This includes resource modeling, model URI's, and correct HTTP methods.
This commit is contained in:
Adi Zimmerman
2019-02-26 21:56:02 -08:00
committed by Richard Liaw
parent 33663bef94
commit 5cf388f29d
5 changed files with 209 additions and 112 deletions
+4 -9
View File
@@ -8,7 +8,7 @@
"source": [
"from ray.tune.web_server import TuneClient\n",
"\n",
"manager = TuneClient(tune_address=\"localhost:4321\")\n",
"manager = TuneClient(tune_address=\"localhost\", port_forward=4321)\n",
"\n",
"x = manager.get_all_trials()\n",
"\n",
@@ -19,7 +19,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"scrolled": false
},
"outputs": [],
@@ -31,9 +30,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"import yaml\n",
@@ -45,9 +42,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"name, spec = [x for x in d.items()][0]"
@@ -79,7 +74,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
"version": "3.6.8"
}
},
"nbformat": 4,
-5
View File
@@ -6,8 +6,3 @@ from __future__ import print_function
class TuneError(Exception):
"""General error class raised by ray.tune."""
pass
class TuneManagerError(TuneError):
"""Error raised in operating the Tune Manager."""
pass
+22 -1
View File
@@ -4,6 +4,8 @@ from __future__ import print_function
import unittest
import socket
import subprocess
import json
import ray
from ray import tune
@@ -44,7 +46,7 @@ class TuneServerSuite(unittest.TestCase):
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
client = TuneClient("localhost:{}".format(port))
client = TuneClient("localhost", port)
return runner, client
def tearDown(self):
@@ -126,6 +128,25 @@ class TuneServerSuite(unittest.TestCase):
self.assertEqual(
len([t for t in all_trials if t["status"] == Trial.RUNNING]), 0)
def testCurlCommand(self):
"""Check if Stop Trial works."""
runner, client = self.basicSetup()
for i in range(2):
runner.step()
stdout = subprocess.check_output(
'curl "http://{}:{}/trials"'.format(client.server_address,
client.server_port),
shell=True)
self.assertNotEqual(stdout, None)
curl_trials = json.loads(stdout.decode())["trials"]
client_trials = client.get_all_trials()["trials"]
for curl_trial, client_trial in zip(curl_trials, client_trials):
self.assertEqual(curl_trial.keys(), client_trial.keys())
self.assertEqual(curl_trial["id"], client_trial["id"])
self.assertEqual(curl_trial["trainable_name"],
client_trial["trainable_name"])
self.assertEqual(curl_trial["status"], client_trial["status"])
if __name__ == "__main__":
unittest.main(verbosity=2)
+159 -87
View File
@@ -8,14 +8,16 @@ import sys
import threading
import ray.cloudpickle as cloudpickle
from ray.tune.error import TuneError, TuneManagerError
from ray.tune import TuneError
from ray.tune.suggest import BasicVariantGenerator
from ray.utils import binary_to_hex, hex_to_binary
if sys.version_info[0] == 2:
from urlparse import urljoin, urlparse
from SimpleHTTPServer import SimpleHTTPRequestHandler
from SocketServer import TCPServer as HTTPServer
elif sys.version_info[0] == 3:
from urllib.parse import urljoin, urlparse
from http.server import SimpleHTTPRequestHandler, HTTPServer
logger = logging.getLogger(__name__)
@@ -28,83 +30,159 @@ except ImportError:
"Be sure to install it on the client side.")
def load_trial_info(trial_info):
trial_info["config"] = cloudpickle.loads(
hex_to_binary(trial_info["config"]))
trial_info["result"] = cloudpickle.loads(
hex_to_binary(trial_info["result"]))
class TuneClient(object):
"""Client to interact with ongoing Tune experiment.
"""Client to interact with an ongoing Tune experiment.
Requires server to have started running."""
STOP = "STOP"
ADD = "ADD"
GET_LIST = "GET_LIST"
GET_TRIAL = "GET_TRIAL"
Requires a TuneServer to have started running.
def __init__(self, tune_address):
# TODO(rliaw): Better to specify address and port forward
Attributes:
tune_address (str): Address of running TuneServer
port_forward (int): Port number of running TuneServer
"""
def __init__(self, tune_address, port_forward):
self._tune_address = tune_address
self._path = "http://{}".format(tune_address)
self._port_forward = port_forward
self._path = "http://{}:{}".format(tune_address, port_forward)
def get_all_trials(self):
"""Returns a list of all trials (trial_id, config, status)."""
return self._get_response({"command": TuneClient.GET_LIST})
"""Returns a list of all trials' information."""
response = requests.get(urljoin(self._path, "trials"))
return self._deserialize(response)
def get_trial(self, trial_id):
"""Returns the last result for queried trial."""
return self._get_response({
"command": TuneClient.GET_TRIAL,
"trial_id": trial_id
})
"""Returns trial information by trial_id."""
response = requests.get(
urljoin(self._path, "trials/{}".format(trial_id)))
return self._deserialize(response)
def add_trial(self, name, trial_spec):
"""Adds a trial of `name` with configurations."""
# TODO(rliaw): have better way of specifying a new trial
return self._get_response({
"command": TuneClient.ADD,
"name": name,
"spec": trial_spec
})
def add_trial(self, name, specification):
"""Adds a trial by name and specification (dict)."""
payload = {"name": name, "spec": specification}
response = requests.post(urljoin(self._path, "trials"), json=payload)
return self._deserialize(response)
def stop_trial(self, trial_id):
"""Requests to stop trial."""
return self._get_response({
"command": TuneClient.STOP,
"trial_id": trial_id
})
"""Requests to stop trial by trial_id."""
response = requests.put(
urljoin(self._path, "trials/{}".format(trial_id)))
return self._deserialize(response)
def _get_response(self, data):
payload = json.dumps(data).encode()
response = requests.get(self._path, data=payload)
@property
def server_address(self):
return self._tune_address
@property
def server_port(self):
return self._port_forward
def _load_trial_info(self, trial_info):
trial_info["config"] = cloudpickle.loads(
hex_to_binary(trial_info["config"]))
trial_info["result"] = cloudpickle.loads(
hex_to_binary(trial_info["result"]))
def _deserialize(self, response):
parsed = response.json()
if "trial_info" in parsed:
load_trial_info(parsed["trial_info"])
if "trial" in parsed:
self._load_trial_info(parsed["trial"])
elif "trials" in parsed:
for trial_info in parsed["trials"]:
load_trial_info(trial_info)
self._load_trial_info(trial_info)
return parsed
def RunnerHandler(runner):
class Handler(SimpleHTTPRequestHandler):
"""A Handler is a custom handler for TuneServer.
Handles all requests and responses coming into and from
the TuneServer.
"""
def _do_header(self, response_code=200, headers=None):
"""Sends the header portion of the HTTP response.
Parameters:
response_code (int): Standard HTTP response code
headers (list[tuples]): Standard HTTP response headers
"""
if headers is None:
headers = [('Content-type', 'application/json')]
self.send_response(response_code)
for key, value in headers:
self.send_header(key, value)
self.end_headers()
def do_HEAD(self):
"""HTTP HEAD handler method."""
self._do_header()
def do_GET(self):
"""HTTP GET handler method."""
response_code = 200
message = ""
try:
result = self._get_trial_by_url(self.path)
resource = {}
if result:
if isinstance(result, list):
infos = [self._trial_info(t) for t in result]
resource["trials"] = infos
else:
resource["trial"] = self._trial_info(result)
message = json.dumps(resource)
except TuneError as e:
response_code = 404
message = str(e)
self._do_header(response_code=response_code)
self.wfile.write(message.encode())
def do_PUT(self):
"""HTTP PUT handler method."""
response_code = 200
message = ""
try:
result = self._get_trial_by_url(self.path)
resource = {}
if result:
if isinstance(result, list):
infos = [self._trial_info(t) for t in result]
resource["trials"] = infos
for t in result:
runner.request_stop_trial(t)
else:
resource["trial"] = self._trial_info(result)
runner.request_stop_trial(result)
message = json.dumps(resource)
except TuneError as e:
response_code = 404
message = str(e)
self._do_header(response_code=response_code)
self.wfile.write(message.encode())
def do_POST(self):
"""HTTP POST handler method."""
response_code = 201
content_len = int(self.headers.get('Content-Length'), 0)
raw_body = self.rfile.read(content_len)
parsed_input = json.loads(raw_body.decode())
status, response = self.execute_command(parsed_input)
if status:
self.send_response(200)
else:
self.send_response(400)
self.end_headers()
self.wfile.write(json.dumps(response).encode())
resource = self._add_trials(parsed_input["name"],
parsed_input["spec"])
def trial_info(self, trial):
headers = [('Content-type', 'application/json'), ('Location',
'/trials/')]
self._do_header(response_code=response_code, headers=headers)
self.wfile.write(json.dumps(resource).encode())
def _trial_info(self, trial):
"""Returns trial information as JSON."""
if trial.last_result:
result = trial.last_result.copy()
else:
@@ -118,62 +196,56 @@ def RunnerHandler(runner):
}
return info_dict
def execute_command(self, args):
def get_trial():
trial = runner.get_trial(args["trial_id"])
if trial is None:
error = "Trial ({}) not found.".format(args["trial_id"])
raise TuneManagerError(error)
else:
return trial
def _get_trial_by_url(self, url):
"""Parses url to get either all trials or trial by trial_id."""
parts = urlparse(url)
path = parts.path
command = args["command"]
response = {}
try:
if command == TuneClient.GET_LIST:
response["trials"] = [
self.trial_info(t) for t in runner.get_trials()
]
elif command == TuneClient.GET_TRIAL:
trial = get_trial()
response["trial_info"] = self.trial_info(trial)
elif command == TuneClient.STOP:
trial = get_trial()
runner.request_stop_trial(trial)
elif command == TuneClient.ADD:
name = args["name"]
spec = args["spec"]
trial_generator = BasicVariantGenerator()
trial_generator.add_configurations({name: spec})
for trial in trial_generator.next_trials():
runner.add_trial(trial)
else:
raise TuneManagerError("Unknown command.")
status = True
except TuneError as e:
status = False
response["message"] = str(e)
if path == "/trials":
return [t for t in runner.get_trials()]
else:
trial_id = path.split("/")[-1]
return runner.get_trial(trial_id)
return status, response
def _add_trials(self, name, spec):
"""Add trial by invoking TrialRunner."""
resource = {}
resource["trials"] = []
trial_generator = BasicVariantGenerator()
trial_generator.add_configurations({name: spec})
for trial in trial_generator.next_trials():
runner.add_trial(trial)
resource["trials"].append(self._trial_info(trial))
return resource
return Handler
class TuneServer(threading.Thread):
"""A TuneServer is a thread that initializes and runs a HTTPServer.
The server handles requests from a TuneClient.
Attributes:
runner (TrialRunner): Runner that modifies and accesses trials.
port_forward (int): Port number of TuneServer.
"""
DEFAULT_PORT = 4321
def __init__(self, runner, port=None):
"""Initialize HTTPServer and serve forever by invoking self.run()"""
threading.Thread.__init__(self)
self._port = port if port else self.DEFAULT_PORT
address = ('localhost', self._port)
logger.info("Starting Tune Server...")
self._server = HTTPServer(address, RunnerHandler(runner))
self.daemon = True
self.start()
def run(self):
self._server.serve_forever()
def shutdown(self):
"""Shutdown the underlying server."""
self._server.shutdown()