From c8da5555abe98bd4153b29549e9ee86131d8227f Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Thu, 6 Aug 2020 12:49:31 -0700 Subject: [PATCH] [Serve] Add preliminary middleware support (#9940) --- python/ray/serve/api.py | 18 ++++++------------ python/ray/serve/controller.py | 6 ++++-- python/ray/serve/http_proxy.py | 20 +++++++++++++++++--- 3 files changed, 27 insertions(+), 17 deletions(-) diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index db770c702..8a70d1954 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -55,12 +55,11 @@ def accept_batch(f): return f -def init( - name=None, - http_host=DEFAULT_HTTP_HOST, - http_port=DEFAULT_HTTP_PORT, - metric_exporter=InMemoryExporter, -): +def init(name=None, + http_host=DEFAULT_HTTP_HOST, + http_port=DEFAULT_HTTP_PORT, + metric_exporter=InMemoryExporter, + _http_middlewares=[]): """Initialize or connect to a serve cluster. If serve cluster is already initialized, this function will just return. @@ -101,12 +100,7 @@ def init( name=controller_name, max_restarts=-1, max_task_retries=-1, - ).remote( - name, - http_host, - http_port, - metric_exporter, - ) + ).remote(name, http_host, http_port, metric_exporter, _http_middlewares) futures = [] for node_id in ray.state.node_ids(): diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index 8aad3ddb6..53d2144cd 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -92,7 +92,7 @@ class ServeController: """ async def __init__(self, instance_name, http_host, http_port, - metric_exporter_class): + metric_exporter_class, _http_middlewares): # Unique name of the serve instance managed by this actor. Used to # namespace child actors and checkpoints. self.instance_name = instance_name @@ -135,6 +135,7 @@ class ServeController: self.http_host = http_host self.http_port = http_port + self._http_middlewares = _http_middlewares # If starting the actor for the first time, starts up the other system # components. If recovering, fetches their actor handles. @@ -190,7 +191,8 @@ class ServeController: node_id, self.http_host, self.http_port, - instance_name=self.instance_name) + instance_name=self.instance_name, + _http_middlewares=self._http_middlewares) self.routers[node_id] = router diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py index e6aa130ea..f594a3e67 100644 --- a/python/ray/serve/http_proxy.py +++ b/python/ray/serve/http_proxy.py @@ -1,6 +1,7 @@ import asyncio from urllib.parse import parse_qs import socket +from typing import List import uvicorn @@ -170,13 +171,26 @@ class HTTPProxy: @ray.remote class HTTPProxyActor: - async def __init__(self, name, host, port, instance_name=None): + async def __init__( + self, + name, + host, + port, + instance_name=None, + _http_middlewares: List["starlette.middleware.Middleware"] = []): serve.init(name=instance_name) self.app = HTTPProxy() - await self.app.fetch_config_from_controller(name, instance_name) self.host = host self.port = port + self.app = HTTPProxy() + await self.app.fetch_config_from_controller(name, instance_name) + + self.wrapped_app = self.app + for middleware in _http_middlewares: + self.wrapped_app = middleware.cls(self.wrapped_app, + **middleware.options) + # Start running the HTTP server on the event loop. asyncio.get_event_loop().create_task(self.run()) @@ -197,7 +211,7 @@ class HTTPProxyActor: # class because we want to run the server as a coroutine. The only # alternative is to call uvicorn.run which is blocking. config = uvicorn.Config( - self.app, + self.wrapped_app, host=self.host, port=self.port, lifespan="off",