mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:16:19 +08:00
[Serve] Add support for variable routes (#13968)
This commit is contained in:
@@ -421,3 +421,36 @@ in :mod:`serve.start <ray.serve.start>`:
|
||||
Using the "EveryNode" option, you can point a cloud load balancer to the
|
||||
instance group of the Ray cluster to achieve high availability of Serve's HTTP
|
||||
proxies.
|
||||
|
||||
Variable HTTP Routes
|
||||
====================
|
||||
|
||||
Ray Serve supports capturing path parameters. For example, in a call of the form
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
client.create_endpoint("my_endpoint", backend="my_backend", route="/api/{username}")
|
||||
|
||||
the ``username`` parameter will be accessible in your backend code as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def my_backend(request):
|
||||
username = request.path_params["username"]
|
||||
...
|
||||
|
||||
Ray Serve uses Starlette's Router class under the hood for routing, so type
|
||||
conversion for path parameters is also supported, as well as multiple path parameters.
|
||||
For example, suppose this route is used:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
client.create_endpoint(
|
||||
"complex", backend="f", route="/api/{user_id:int}/{number:float}")
|
||||
|
||||
Then for a query to the route ``/api/123/3.14``, the ``request.path_params`` dictionary
|
||||
available in the backend will be ``{"user_id": 123, "number": 3.14}``, where ``123`` is
|
||||
a Python int and ``3.14`` is a Python float.
|
||||
|
||||
For full details on the supported path parameters, see Starlette's
|
||||
`path parameters documentation <https://www.starlette.io/routing/#path-parameters>`_.
|
||||
|
||||
@@ -163,10 +163,13 @@ class ServeController:
|
||||
self.endpoint_state.shadow_traffic(endpoint_name, backend_tag,
|
||||
proportion)
|
||||
|
||||
# TODO(architkulkarni): add Optional for route after cloudpickle upgrade
|
||||
async def create_endpoint(self, endpoint: str,
|
||||
traffic_dict: Dict[str, float], route,
|
||||
methods: List[str]) -> None:
|
||||
async def create_endpoint(
|
||||
self,
|
||||
endpoint: str,
|
||||
traffic_dict: Dict[str, float],
|
||||
route: Optional[str],
|
||||
methods: List[str],
|
||||
) -> None:
|
||||
"""Create a new endpoint with the specified route and methods.
|
||||
|
||||
If the route is None, this is a "headless" endpoint that will not
|
||||
|
||||
@@ -20,7 +20,7 @@ class EndpointState:
|
||||
long_poll_host: LongPollHost):
|
||||
self._kv_store = kv_store
|
||||
self._long_poll_host = long_poll_host
|
||||
self._routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict()
|
||||
self._routes: Dict[str, Tuple[EndpointTag, Any]] = dict()
|
||||
self._traffic_policies: Dict[EndpointTag, TrafficPolicy] = dict()
|
||||
|
||||
checkpoint = self._kv_store.get(CHECKPOINT_KEY)
|
||||
|
||||
@@ -1,23 +1,82 @@
|
||||
import asyncio
|
||||
import socket
|
||||
from typing import List
|
||||
from typing import List, Dict, Tuple
|
||||
|
||||
import uvicorn
|
||||
import starlette.responses
|
||||
import starlette.routing
|
||||
|
||||
import ray
|
||||
from ray.exceptions import RayTaskError
|
||||
from ray.serve.common import EndpointTag
|
||||
from ray.serve.constants import LongPollKey
|
||||
from ray.util import metrics
|
||||
from ray.serve.utils import _get_logger
|
||||
from ray.serve.http_util import Response, build_starlette_request
|
||||
from ray.serve.long_poll import LongPollAsyncClient
|
||||
from ray.serve.router import Router
|
||||
from ray.serve.handle import DEFAULT
|
||||
|
||||
logger = _get_logger()
|
||||
|
||||
|
||||
class ServeStarletteEndpoint:
    """Wraps the given Serve endpoint in a Starlette endpoint.

    Implements the ASGI protocol.  Constructs a Starlette endpoint for use by
    a Starlette app or Starlette Router which calls the given Serve endpoint
    using the given Serve client.

    Usage:
        route = starlette.routing.Route(
            "/api",
            ServeStarletteEndpoint(self.client, endpoint_tag),
            methods=methods)
        app = starlette.applications.Starlette(routes=[route])
    """

    def __init__(self, client, endpoint_tag: EndpointTag):
        """Store the client and endpoint tag; the handle is created lazily.

        Args:
            client: Serve client used to obtain a handle to the endpoint.
            endpoint_tag: Tag of the Serve endpoint this object forwards to.
        """
        self.client = client
        self.endpoint_tag = endpoint_tag
        # Base handle for the endpoint, fetched on the first request and
        # reused afterwards.  Per-request options are never stored on it.
        self.handle = None

    async def __call__(self, scope, receive, send):
        """Handle one HTTP request by forwarding it to the Serve endpoint.

        Reads the full request body, calls the endpoint through a handle
        configured from the request headers and method, awaits the result
        and writes the response.

        Args:
            scope: ASGI connection scope of the request.
            receive: ASGI receive callable.
            send: ASGI send callable.
        """
        http_body_bytes = await self.receive_http_body(scope, receive, send)

        headers = {k.decode(): v.decode() for k, v in scope["headers"]}
        if self.handle is None:
            self.handle = self.client.get_handle(self.endpoint_tag, sync=False)

        # Derive a request-local handle instead of overwriting self.handle,
        # so options taken from this request's headers cannot carry over to
        # the handle used by later requests.
        request_handle = self.handle.options(
            method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
                                    DEFAULT.VALUE),
            shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
            http_method=scope["method"].upper(),
            http_headers=headers)
        request = build_starlette_request(scope, http_body_bytes)
        object_ref = await request_handle.remote(request)
        result = await object_ref

        if isinstance(result, RayTaskError):
            # The backend raised: report the error as a 500 response.
            error_message = "Task Error. Traceback: {}.".format(result)
            await Response(
                error_message, status_code=500).send(scope, receive, send)
        elif isinstance(result, starlette.responses.Response):
            # The backend already built a Starlette response: send it as-is.
            await result(scope, receive, send)
        else:
            await Response(result).send(scope, receive, send)

    async def receive_http_body(self, scope, receive, send):
        """Read and return the complete request body as bytes.

        Keeps awaiting ``receive()`` until a message without a truthy
        ``more_body`` arrives and concatenates the ``body`` chunks.

        Args:
            scope: ASGI connection scope (unused).
            receive: ASGI receive callable yielding ``http.request`` messages.
            send: ASGI send callable (unused).

        Returns:
            The concatenated body chunks; ``b""`` when no body was sent.

        Raises:
            AssertionError: If a received message is not ``http.request``.
        """
        body_buffer = []
        more_body = True
        while more_body:
            message = await receive()
            assert message["type"] == "http.request"

            # Both keys are optional in ASGI "http.request" messages and
            # default to False / b"" when absent.
            more_body = message.get("more_body", False)
            body_buffer.append(message.get("body", b""))

        return b"".join(body_buffer)
|
||||
|
||||
|
||||
class HTTPProxy:
|
||||
"""This class is meant to be instantiated and run by an ASGI HTTP server.
|
||||
|
||||
@@ -33,8 +92,12 @@ class HTTPProxy:
|
||||
self.client = ray.serve.connect()
|
||||
|
||||
controller = ray.get_actor(controller_name)
|
||||
self.route_table = {} # Should be updated via long polling.
|
||||
self.router = Router(controller)
|
||||
|
||||
self.router = starlette.routing.Router(default=self._not_found)
|
||||
|
||||
# route -> (endpoint_tag, methods). Updated via long polling.
|
||||
self.route_table: Dict[str, Tuple[EndpointTag, List[str]]] = {}
|
||||
|
||||
self.long_poll_client = LongPollAsyncClient(controller, {
|
||||
LongPollKey.ROUTE_TABLE: self._update_route_table,
|
||||
})
|
||||
@@ -44,40 +107,38 @@ class HTTPProxy:
|
||||
description="The number of HTTP requests processed.",
|
||||
tag_keys=("route", ))
|
||||
|
||||
async def setup(self):
|
||||
await self.router.setup_in_async_loop()
|
||||
|
||||
async def _update_route_table(self, route_table):
|
||||
logger.debug(f"HTTP Proxy: Get updated route table: {route_table}.")
|
||||
self.route_table = route_table
|
||||
|
||||
async def receive_http_body(self, scope, receive, send):
|
||||
body_buffer = []
|
||||
more_body = True
|
||||
while more_body:
|
||||
message = await receive()
|
||||
assert message["type"] == "http.request"
|
||||
routes = [
|
||||
starlette.routing.Route(
|
||||
route,
|
||||
ServeStarletteEndpoint(self.client, endpoint_tag),
|
||||
methods=methods)
|
||||
for route, (endpoint_tag, methods) in route_table.items()
|
||||
if not self._is_headless(route)
|
||||
]
|
||||
|
||||
more_body = message["more_body"]
|
||||
body_buffer.append(message["body"])
|
||||
routes.append(
|
||||
starlette.routing.Route("/-/routes", self._display_route_table))
|
||||
|
||||
return b"".join(body_buffer)
|
||||
self.router.routes = routes
|
||||
|
||||
def _make_error_sender(self, scope, receive, send):
|
||||
async def sender(error_message, status_code):
|
||||
response = Response(error_message, status_code=status_code)
|
||||
await response.send(scope, receive, send)
|
||||
|
||||
return sender
|
||||
|
||||
async def _handle_system_request(self, scope, receive, send):
|
||||
async def _not_found(self, scope, receive, send):
|
||||
current_path = scope["path"]
|
||||
if current_path == "/-/routes":
|
||||
await Response(self.route_table).send(scope, receive, send)
|
||||
else:
|
||||
await Response(
|
||||
"System path {} not found".format(current_path),
|
||||
status_code=404).send(scope, receive, send)
|
||||
error_message = ("Path {} not found. "
|
||||
"Please ping http://.../-/routes for route table."
|
||||
).format(current_path)
|
||||
response = Response(error_message, status_code=404)
|
||||
await response.send(scope, receive, send)
|
||||
|
||||
async def _display_route_table(self, request):
|
||||
return starlette.responses.JSONResponse(self.route_table)
|
||||
|
||||
def _is_headless(self, route: str):
|
||||
"""Returns True if `route` corresponds to a headless endpoint."""
|
||||
return not route.startswith("/")
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
"""Implements the ASGI protocol.
|
||||
@@ -86,8 +147,6 @@ class HTTPProxy:
|
||||
https://asgi.readthedocs.io/en/latest/specs/index.html.
|
||||
"""
|
||||
|
||||
error_sender = self._make_error_sender(scope, receive, send)
|
||||
|
||||
assert self.route_table is not None, (
|
||||
"Route table must be set via set_route_table.")
|
||||
assert scope["type"] == "http"
|
||||
@@ -95,51 +154,7 @@ class HTTPProxy:
|
||||
|
||||
self.request_counter.record(1, tags={"route": current_path})
|
||||
|
||||
if current_path.startswith("/-/"):
|
||||
await self._handle_system_request(scope, receive, send)
|
||||
return
|
||||
|
||||
try:
|
||||
endpoint_name, methods_allowed = self.route_table[current_path]
|
||||
except KeyError:
|
||||
error_message = (
|
||||
"Path {} not found. "
|
||||
"Please ping http://.../-/routes for routing table"
|
||||
).format(current_path)
|
||||
await error_sender(error_message, 404)
|
||||
return
|
||||
|
||||
if scope["method"] not in methods_allowed:
|
||||
error_message = ("Methods {} not allowed. "
|
||||
"Available HTTP methods are {}.").format(
|
||||
scope["method"], methods_allowed)
|
||||
await error_sender(error_message, 405)
|
||||
return
|
||||
|
||||
http_body_bytes = await self.receive_http_body(scope, receive, send)
|
||||
|
||||
headers = {k.decode(): v.decode() for k, v in scope["headers"]}
|
||||
|
||||
handle = self.client.get_handle(
|
||||
endpoint_name, sync=False).options(
|
||||
method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
|
||||
DEFAULT.VALUE),
|
||||
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(),
|
||||
DEFAULT.VALUE),
|
||||
http_method=scope["method"].upper(),
|
||||
http_headers=headers)
|
||||
|
||||
request = build_starlette_request(scope, http_body_bytes)
|
||||
object_ref = await handle.remote(request)
|
||||
result = await object_ref
|
||||
|
||||
if isinstance(result, RayTaskError):
|
||||
error_message = "Task Error. Traceback: {}.".format(result)
|
||||
await error_sender(error_message, 500)
|
||||
elif isinstance(result, starlette.responses.Response):
|
||||
await result(scope, receive, send)
|
||||
else:
|
||||
await Response(result).send(scope, receive, send)
|
||||
await self.router(scope, receive, send)
|
||||
|
||||
|
||||
@ray.remote
|
||||
@@ -157,7 +172,6 @@ class HTTPProxyActor:
|
||||
self.setup_complete = asyncio.Event()
|
||||
|
||||
self.app = HTTPProxy(controller_name)
|
||||
await self.app.setup()
|
||||
|
||||
self.wrapped_app = self.app
|
||||
for middleware in http_middlewares:
|
||||
|
||||
@@ -19,7 +19,16 @@ def build_starlette_request(scope, serialized_body: bytes):
|
||||
"more_body": False
|
||||
}
|
||||
|
||||
return starlette.requests.Request(scope, mock_receive)
|
||||
# scope["router"] and scope["endpoint"] contain references to a router and
|
||||
# endpoint object, respectively, which each in turn contain a reference to
|
||||
# the Serve client, which cannot be serialized.
|
||||
# The solution is to delete these from scope, as they will not be used.
|
||||
# Per ASGI recommendation, copy scope before passing to child.
|
||||
child_scope = scope.copy()
|
||||
del child_scope["router"]
|
||||
del child_scope["endpoint"]
|
||||
|
||||
return starlette.requests.Request(child_scope, mock_receive)
|
||||
|
||||
|
||||
class Response:
|
||||
|
||||
@@ -989,6 +989,29 @@ def test_starlette_request(serve_instance):
|
||||
assert resp == long_string
|
||||
|
||||
|
||||
def test_variable_routes(serve_instance):
    """Path parameters captured by a route reach the backend request.

    Registers one backend that echoes ``request.path_params`` behind two
    endpoints: a single untyped parameter, and two typed parameters
    (``int`` and ``float``).  Each URL must return the expected mapping.
    """

    def f(starlette_request):
        return starlette_request.path_params

    serve_instance.create_backend("f", f)
    serve_instance.create_endpoint(
        "basic", backend="f", route="/api/{username}")

    # Test multiple variables and test type conversion
    serve_instance.create_endpoint(
        "complex", backend="f", route="/api/{user_id:int}/{number:float}")

    expectations = [
        ("http://127.0.0.1:8000/api/scaly", {
            "username": "scaly"
        }),
        ("http://127.0.0.1:8000/api/23/12.345", {
            "user_id": 23,
            "number": 12.345
        }),
    ]
    for url, expected_params in expectations:
        assert requests.get(url).json() == expected_params
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", "-s", __file__]))
|
||||
|
||||
Reference in New Issue
Block a user