[Serve] Add support for variable routes (#13968)

This commit is contained in:
architkulkarni
2021-02-15 09:42:42 -08:00
committed by GitHub
parent 4d727e4cdf
commit 0fb96a61fc
6 changed files with 166 additions and 84 deletions
+7 -4
View File
@@ -163,10 +163,13 @@ class ServeController:
self.endpoint_state.shadow_traffic(endpoint_name, backend_tag,
proportion)
# TODO(architkulkarni): add Optional for route after cloudpickle upgrade
async def create_endpoint(self, endpoint: str,
traffic_dict: Dict[str, float], route,
methods: List[str]) -> None:
async def create_endpoint(
self,
endpoint: str,
traffic_dict: Dict[str, float],
route: Optional[str],
methods: List[str],
) -> None:
"""Create a new endpoint with the specified route and methods.
If the route is None, this is a "headless" endpoint that will not
+1 -1
View File
@@ -20,7 +20,7 @@ class EndpointState:
long_poll_host: LongPollHost):
self._kv_store = kv_store
self._long_poll_host = long_poll_host
self._routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict()
self._routes: Dict[str, Tuple[EndpointTag, Any]] = dict()
self._traffic_policies: Dict[EndpointTag, TrafficPolicy] = dict()
checkpoint = self._kv_store.get(CHECKPOINT_KEY)
+92 -78
View File
@@ -1,23 +1,82 @@
import asyncio
import socket
from typing import List
from typing import List, Dict, Tuple
import uvicorn
import starlette.responses
import starlette.routing
import ray
from ray.exceptions import RayTaskError
from ray.serve.common import EndpointTag
from ray.serve.constants import LongPollKey
from ray.util import metrics
from ray.serve.utils import _get_logger
from ray.serve.http_util import Response, build_starlette_request
from ray.serve.long_poll import LongPollAsyncClient
from ray.serve.router import Router
from ray.serve.handle import DEFAULT
logger = _get_logger()
class ServeStarletteEndpoint:
    """Wraps the given Serve endpoint in a Starlette endpoint.

    Implements the ASGI protocol.  Constructs a Starlette endpoint for use by
    a Starlette app or Starlette Router which calls the given Serve endpoint
    using the given Serve client.

    Usage:
        route = starlette.routing.Route(
            "/api",
            ServeStarletteEndpoint(self.client, endpoint_tag),
            methods=methods)
        app = starlette.applications.Starlette(routes=[route])
    """

    def __init__(self, client, endpoint_tag: "EndpointTag"):
        """Store the Serve client and the tag of the endpoint to call.

        Args:
            client: Serve client; only its `get_handle` method is used here.
            endpoint_tag: Tag of the Serve endpoint this object forwards to.
        """
        self.client = client
        self.endpoint_tag = endpoint_tag
        # Base handle for the endpoint, fetched lazily on the first request
        # and then reused.  It is never replaced by a per-request handle.
        self.handle = None

    async def __call__(self, scope, receive, send):
        """Handle one HTTP request: ASGI in, Serve call, ASGI response out.

        Reads the full request body, calls the Serve endpoint with the
        request, and sends back either the returned Starlette response, a
        500 response if the result is a RayTaskError, or the result wrapped
        in a `Response`.
        """
        http_body_bytes = await self.receive_http_body(scope, receive, send)

        headers = {k.decode(): v.decode() for k, v in scope["headers"]}
        if self.handle is None:
            self.handle = self.client.get_handle(self.endpoint_tag, sync=False)

        # Apply the per-request options to a *local* handle.  Assigning the
        # result back to self.handle would make this request's method name,
        # shard key and headers the starting point for every later request
        # served by this endpoint object.
        handle = self.handle.options(
            method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
                                    DEFAULT.VALUE),
            shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
            http_method=scope["method"].upper(),
            http_headers=headers)

        request = build_starlette_request(scope, http_body_bytes)
        object_ref = await handle.remote(request)
        result = await object_ref

        if isinstance(result, RayTaskError):
            error_message = "Task Error. Traceback: {}.".format(result)
            await Response(
                error_message, status_code=500).send(scope, receive, send)
        elif isinstance(result, starlette.responses.Response):
            await result(scope, receive, send)
        else:
            await Response(result).send(scope, receive, send)

    async def receive_http_body(self, scope, receive, send):
        """Drain the ASGI `receive` channel and return the whole body.

        Args:
            scope: ASGI connection scope (unused here).
            receive: Awaitable callable yielding ASGI event dicts.
            send: ASGI send callable (unused here).

        Returns:
            The concatenated request body as bytes.
        """
        body_buffer = []
        more_body = True
        while more_body:
            message = await receive()
            assert message["type"] == "http.request"

            # Both keys are optional in an ASGI "http.request" event: a
            # missing "more_body" means this is the last chunk and a missing
            # "body" means an empty chunk.
            more_body = message.get("more_body", False)
            body_buffer.append(message.get("body", b""))

        return b"".join(body_buffer)
class HTTPProxy:
"""This class is meant to be instantiated and run by an ASGI HTTP server.
@@ -33,8 +92,12 @@ class HTTPProxy:
self.client = ray.serve.connect()
controller = ray.get_actor(controller_name)
self.route_table = {} # Should be updated via long polling.
self.router = Router(controller)
self.router = starlette.routing.Router(default=self._not_found)
# route -> (endpoint_tag, methods). Updated via long polling.
self.route_table: Dict[str, Tuple[EndpointTag, List[str]]] = {}
self.long_poll_client = LongPollAsyncClient(controller, {
LongPollKey.ROUTE_TABLE: self._update_route_table,
})
@@ -44,40 +107,38 @@ class HTTPProxy:
description="The number of HTTP requests processed.",
tag_keys=("route", ))
async def setup(self):
await self.router.setup_in_async_loop()
async def _update_route_table(self, route_table):
logger.debug(f"HTTP Proxy: Get updated route table: {route_table}.")
self.route_table = route_table
async def receive_http_body(self, scope, receive, send):
body_buffer = []
more_body = True
while more_body:
message = await receive()
assert message["type"] == "http.request"
routes = [
starlette.routing.Route(
route,
ServeStarletteEndpoint(self.client, endpoint_tag),
methods=methods)
for route, (endpoint_tag, methods) in route_table.items()
if not self._is_headless(route)
]
more_body = message["more_body"]
body_buffer.append(message["body"])
routes.append(
starlette.routing.Route("/-/routes", self._display_route_table))
return b"".join(body_buffer)
self.router.routes = routes
def _make_error_sender(self, scope, receive, send):
async def sender(error_message, status_code):
response = Response(error_message, status_code=status_code)
await response.send(scope, receive, send)
return sender
async def _handle_system_request(self, scope, receive, send):
async def _not_found(self, scope, receive, send):
current_path = scope["path"]
if current_path == "/-/routes":
await Response(self.route_table).send(scope, receive, send)
else:
await Response(
"System path {} not found".format(current_path),
status_code=404).send(scope, receive, send)
error_message = ("Path {} not found. "
"Please ping http://.../-/routes for route table."
).format(current_path)
response = Response(error_message, status_code=404)
await response.send(scope, receive, send)
async def _display_route_table(self, request):
return starlette.responses.JSONResponse(self.route_table)
def _is_headless(self, route: str):
"""Returns True if `route` corresponds to a headless endpoint."""
return not route.startswith("/")
async def __call__(self, scope, receive, send):
"""Implements the ASGI protocol.
@@ -86,8 +147,6 @@ class HTTPProxy:
https://asgi.readthedocs.io/en/latest/specs/index.html.
"""
error_sender = self._make_error_sender(scope, receive, send)
assert self.route_table is not None, (
"Route table must be set via set_route_table.")
assert scope["type"] == "http"
@@ -95,51 +154,7 @@ class HTTPProxy:
self.request_counter.record(1, tags={"route": current_path})
if current_path.startswith("/-/"):
await self._handle_system_request(scope, receive, send)
return
try:
endpoint_name, methods_allowed = self.route_table[current_path]
except KeyError:
error_message = (
"Path {} not found. "
"Please ping http://.../-/routes for routing table"
).format(current_path)
await error_sender(error_message, 404)
return
if scope["method"] not in methods_allowed:
error_message = ("Methods {} not allowed. "
"Available HTTP methods are {}.").format(
scope["method"], methods_allowed)
await error_sender(error_message, 405)
return
http_body_bytes = await self.receive_http_body(scope, receive, send)
headers = {k.decode(): v.decode() for k, v in scope["headers"]}
handle = self.client.get_handle(
endpoint_name, sync=False).options(
method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
DEFAULT.VALUE),
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(),
DEFAULT.VALUE),
http_method=scope["method"].upper(),
http_headers=headers)
request = build_starlette_request(scope, http_body_bytes)
object_ref = await handle.remote(request)
result = await object_ref
if isinstance(result, RayTaskError):
error_message = "Task Error. Traceback: {}.".format(result)
await error_sender(error_message, 500)
elif isinstance(result, starlette.responses.Response):
await result(scope, receive, send)
else:
await Response(result).send(scope, receive, send)
await self.router(scope, receive, send)
@ray.remote
@@ -157,7 +172,6 @@ class HTTPProxyActor:
self.setup_complete = asyncio.Event()
self.app = HTTPProxy(controller_name)
await self.app.setup()
self.wrapped_app = self.app
for middleware in http_middlewares:
+10 -1
View File
@@ -19,7 +19,16 @@ def build_starlette_request(scope, serialized_body: bytes):
"more_body": False
}
return starlette.requests.Request(scope, mock_receive)
# scope["router"] and scope["endpoint"] contain references to a router and
# endpoint object, respectively, which each in turn contain a reference to
# the Serve client, which cannot be serialized.
# The solution is to delete these from scope, as they will not be used.
# Per ASGI recommendation, copy scope before passing to child.
child_scope = scope.copy()
del child_scope["router"]
del child_scope["endpoint"]
return starlette.requests.Request(child_scope, mock_receive)
class Response:
+23
View File
@@ -989,6 +989,29 @@ def test_starlette_request(serve_instance):
assert resp == long_string
def test_variable_routes(serve_instance):
    """Check that path parameters in endpoint routes reach the backend.

    Registers one backend that echoes `path_params` and two endpoints on
    it: one with a single plain variable, and one with two typed variables
    to exercise int/float conversion.
    """
    client = serve_instance

    def f(starlette_request):
        return starlette_request.path_params

    client.create_backend("f", f)

    # Single untyped variable.
    client.create_endpoint("basic", backend="f", route="/api/{username}")

    # Test multiple variables and test type conversion
    client.create_endpoint(
        "complex",
        backend="f",
        route="/api/{user_id:int}/{number:float}",
    )

    basic_reply = requests.get("http://127.0.0.1:8000/api/scaly").json()
    assert basic_reply == {"username": "scaly"}

    complex_reply = requests.get("http://127.0.0.1:8000/api/23/12.345").json()
    expected_complex = {"user_id": 23, "number": 12.345}
    assert complex_reply == expected_complex
if __name__ == "__main__":
    import sys

    # Run this file's tests verbosely with output capturing disabled, and
    # use pytest's return value as the process exit status.
    pytest_args = ["-v", "-s", __file__]
    exit_status = pytest.main(pytest_args)
    sys.exit(exit_status)