diff --git a/doc/source/serve/advanced.rst b/doc/source/serve/advanced.rst index 7a6027ad5..ca9b8e9ce 100644 --- a/doc/source/serve/advanced.rst +++ b/doc/source/serve/advanced.rst @@ -421,3 +421,36 @@ in :mod:`serve.start `: Using the "EveryNode" option, you can point a cloud load balancer to the instance group of Ray cluster to achieve high availability of Serve's HTTP proxies. + +Variable HTTP Routes +==================== + +Ray Serve supports capturing path parameters. For example, in a call of the form + +.. code-block:: python + + client.create_endpoint("my_endpoint", backend="my_backend", route="/api/{username}") + +the ``username`` parameter will be accessible in your backend code as follows: + +.. code-block:: python + + def my_backend(request): + username = request.path_params["username"] + ... + +Ray Serve uses Starlette's Router class under the hood for routing, so type +conversion for path parameters is also supported, as well as multiple path parameters. +For example, suppose this route is used: + +.. code-block:: python + + client.create_endpoint( + "complex", backend="f", route="/api/{user_id:int}/{number:float}") + +Then for a query to the route ``/api/123/3.14``, the ``request.path_params`` dictionary +available in the backend will be ``{"user_id": 123, "number": 3.14}``, where ``123`` is +a Python int and ``3.14`` is a Python float. + +For full details on the supported path parameters, see Starlette's +`path parameters documentation <https://www.starlette.io/routing/#path-parameters>`_. 
diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index 0ad444a54..8996c342d 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -163,10 +163,13 @@ class ServeController: self.endpoint_state.shadow_traffic(endpoint_name, backend_tag, proportion) - # TODO(architkulkarni): add Optional for route after cloudpickle upgrade - async def create_endpoint(self, endpoint: str, - traffic_dict: Dict[str, float], route, - methods: List[str]) -> None: + async def create_endpoint( + self, + endpoint: str, + traffic_dict: Dict[str, float], + route: Optional[str], + methods: List[str], + ) -> None: """Create a new endpoint with the specified route and methods. If the route is None, this is a "headless" endpoint that will not diff --git a/python/ray/serve/endpoint_state.py b/python/ray/serve/endpoint_state.py index bdbfe2c39..39a67d090 100644 --- a/python/ray/serve/endpoint_state.py +++ b/python/ray/serve/endpoint_state.py @@ -20,7 +20,7 @@ class EndpointState: long_poll_host: LongPollHost): self._kv_store = kv_store self._long_poll_host = long_poll_host - self._routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict() + self._routes: Dict[str, Tuple[EndpointTag, Any]] = dict() self._traffic_policies: Dict[EndpointTag, TrafficPolicy] = dict() checkpoint = self._kv_store.get(CHECKPOINT_KEY) diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py index 5f722276e..f6fa25bb3 100644 --- a/python/ray/serve/http_proxy.py +++ b/python/ray/serve/http_proxy.py @@ -1,23 +1,82 @@ import asyncio import socket -from typing import List +from typing import List, Dict, Tuple import uvicorn import starlette.responses +import starlette.routing import ray from ray.exceptions import RayTaskError +from ray.serve.common import EndpointTag from ray.serve.constants import LongPollKey from ray.util import metrics from ray.serve.utils import _get_logger from ray.serve.http_util import Response, build_starlette_request from 
class ServeStarletteEndpoint:
    """Wraps the given Serve endpoint in a Starlette endpoint.

    Implements the ASGI protocol. Constructs a Starlette endpoint for use by
    a Starlette app or Starlette Router which calls the given Serve endpoint
    using the given Serve client.

    Usage:
        route = starlette.routing.Route(
            "/api",
            ServeStarletteEndpoint(self.client, endpoint_tag),
            methods=methods)
        app = starlette.applications.Starlette(routes=[route])
    """

    def __init__(self, client, endpoint_tag: "EndpointTag"):
        self.client = client
        self.endpoint_tag = endpoint_tag
        # Base handle for the endpoint; fetched lazily on the first request
        # and then reused unchanged for every later request.
        self.handle = None

    async def __call__(self, scope, receive, send):
        """Forward one HTTP request to the Serve endpoint and send the reply.

        Args:
            scope: ASGI connection scope of the incoming request.
            receive: ASGI receive callable used to read the request body.
            send: ASGI send callable used to write the response.
        """
        http_body_bytes = await self.receive_http_body(scope, receive, send)

        headers = {k.decode(): v.decode() for k, v in scope["headers"]}
        if self.handle is None:
            self.handle = self.client.get_handle(self.endpoint_tag, sync=False)

        # Keep the per-request options on a local handle. Re-assigning
        # self.handle here would stack this request's options onto the cached
        # handle, so a later request that omits a header could inherit the
        # method name / shard key of an earlier one.
        # NOTE(review): assumes DEFAULT.VALUE means "leave the handle's
        # current setting untouched" -- confirm against handle.options().
        request_handle = self.handle.options(
            method_name=headers.get("X-SERVE-CALL-METHOD".lower(),
                                    DEFAULT.VALUE),
            shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), DEFAULT.VALUE),
            http_method=scope["method"].upper(),
            http_headers=headers)
        request = build_starlette_request(scope, http_body_bytes)
        object_ref = await request_handle.remote(request)
        result = await object_ref

        if isinstance(result, RayTaskError):
            # The backend raised: surface the traceback as a 500 response.
            error_message = "Task Error. Traceback: {}.".format(result)
            await Response(
                error_message, status_code=500).send(scope, receive, send)
        elif isinstance(result, starlette.responses.Response):
            # The backend already built a full response; send it as-is.
            await result(scope, receive, send)
        else:
            await Response(result).send(scope, receive, send)

    async def receive_http_body(self, scope, receive, send):
        """Read the complete request body from the ASGI receive channel.

        Args:
            scope: ASGI connection scope (unused).
            receive: ASGI receive callable yielding request events.
            send: ASGI send callable (unused).

        Returns:
            The concatenated body bytes of all received chunks.
        """
        body_buffer = []
        more_body = True
        while more_body:
            message = await receive()
            assert message["type"] == "http.request"

            # Both keys are optional in an ASGI "http.request" event:
            # "more_body" defaults to False and "body" to b"".
            more_body = message.get("more_body", False)
            body_buffer.append(message.get("body", b""))

        return b"".join(body_buffer)
routes.append( + starlette.routing.Route("/-/routes", self._display_route_table)) - return b"".join(body_buffer) + self.router.routes = routes - def _make_error_sender(self, scope, receive, send): - async def sender(error_message, status_code): - response = Response(error_message, status_code=status_code) - await response.send(scope, receive, send) - - return sender - - async def _handle_system_request(self, scope, receive, send): + async def _not_found(self, scope, receive, send): current_path = scope["path"] - if current_path == "/-/routes": - await Response(self.route_table).send(scope, receive, send) - else: - await Response( - "System path {} not found".format(current_path), - status_code=404).send(scope, receive, send) + error_message = ("Path {} not found. " + "Please ping http://.../-/routes for route table." + ).format(current_path) + response = Response(error_message, status_code=404) + await response.send(scope, receive, send) + + async def _display_route_table(self, request): + return starlette.responses.JSONResponse(self.route_table) + + def _is_headless(self, route: str): + """Returns True if `route` corresponds to a headless endpoint.""" + return not route.startswith("/") async def __call__(self, scope, receive, send): """Implements the ASGI protocol. @@ -86,8 +147,6 @@ class HTTPProxy: https://asgi.readthedocs.io/en/latest/specs/index.html. """ - error_sender = self._make_error_sender(scope, receive, send) - assert self.route_table is not None, ( "Route table must be set via set_route_table.") assert scope["type"] == "http" @@ -95,51 +154,7 @@ class HTTPProxy: self.request_counter.record(1, tags={"route": current_path}) - if current_path.startswith("/-/"): - await self._handle_system_request(scope, receive, send) - return - - try: - endpoint_name, methods_allowed = self.route_table[current_path] - except KeyError: - error_message = ( - "Path {} not found. 
" - "Please ping http://.../-/routes for routing table" - ).format(current_path) - await error_sender(error_message, 404) - return - - if scope["method"] not in methods_allowed: - error_message = ("Methods {} not allowed. " - "Available HTTP methods are {}.").format( - scope["method"], methods_allowed) - await error_sender(error_message, 405) - return - - http_body_bytes = await self.receive_http_body(scope, receive, send) - - headers = {k.decode(): v.decode() for k, v in scope["headers"]} - - handle = self.client.get_handle( - endpoint_name, sync=False).options( - method_name=headers.get("X-SERVE-CALL-METHOD".lower(), - DEFAULT.VALUE), - shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), - DEFAULT.VALUE), - http_method=scope["method"].upper(), - http_headers=headers) - - request = build_starlette_request(scope, http_body_bytes) - object_ref = await handle.remote(request) - result = await object_ref - - if isinstance(result, RayTaskError): - error_message = "Task Error. Traceback: {}.".format(result) - await error_sender(error_message, 500) - elif isinstance(result, starlette.responses.Response): - await result(scope, receive, send) - else: - await Response(result).send(scope, receive, send) + await self.router(scope, receive, send) @ray.remote @@ -157,7 +172,6 @@ class HTTPProxyActor: self.setup_complete = asyncio.Event() self.app = HTTPProxy(controller_name) - await self.app.setup() self.wrapped_app = self.app for middleware in http_middlewares: diff --git a/python/ray/serve/http_util.py b/python/ray/serve/http_util.py index 0aa4ccf84..e8a51adf3 100644 --- a/python/ray/serve/http_util.py +++ b/python/ray/serve/http_util.py @@ -19,7 +19,16 @@ def build_starlette_request(scope, serialized_body: bytes): "more_body": False } - return starlette.requests.Request(scope, mock_receive) + # scope["router"] and scope["endpoint"] contain references to a router and + # endpoint object, respectively, which each in turn contain a reference to + # the Serve client, which cannot 
be serialized. + # The solution is to delete these from scope, as they will not be used. + # Per ASGI recommendation, copy scope before passing to child. + child_scope = scope.copy() + del child_scope["router"] + del child_scope["endpoint"] + + return starlette.requests.Request(child_scope, mock_receive) class Response: diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index 62f239f78..abfdbf1fb 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -989,6 +989,29 @@ def test_starlette_request(serve_instance): assert resp == long_string +def test_variable_routes(serve_instance): + client = serve_instance + + def f(starlette_request): + return starlette_request.path_params + + client.create_backend("f", f) + client.create_endpoint("basic", backend="f", route="/api/{username}") + + # Test multiple variables and test type conversion + client.create_endpoint( + "complex", backend="f", route="/api/{user_id:int}/{number:float}") + + assert requests.get("http://127.0.0.1:8000/api/scaly").json() == { + "username": "scaly" + } + + assert requests.get("http://127.0.0.1:8000/api/23/12.345").json() == { + "user_id": 23, + "number": 12.345 + } + + if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", "-s", __file__]))