mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 01:31:08 +08:00
[Serve] Add preliminary middleware support (#9940)
This commit is contained in:
+6
-12
@@ -55,12 +55,11 @@ def accept_batch(f):
|
||||
return f
|
||||
|
||||
|
||||
def init(
|
||||
name=None,
|
||||
http_host=DEFAULT_HTTP_HOST,
|
||||
http_port=DEFAULT_HTTP_PORT,
|
||||
metric_exporter=InMemoryExporter,
|
||||
):
|
||||
def init(name=None,
|
||||
http_host=DEFAULT_HTTP_HOST,
|
||||
http_port=DEFAULT_HTTP_PORT,
|
||||
metric_exporter=InMemoryExporter,
|
||||
_http_middlewares=[]):
|
||||
"""Initialize or connect to a serve cluster.
|
||||
|
||||
If serve cluster is already initialized, this function will just return.
|
||||
@@ -101,12 +100,7 @@ def init(
|
||||
name=controller_name,
|
||||
max_restarts=-1,
|
||||
max_task_retries=-1,
|
||||
).remote(
|
||||
name,
|
||||
http_host,
|
||||
http_port,
|
||||
metric_exporter,
|
||||
)
|
||||
).remote(name, http_host, http_port, metric_exporter, _http_middlewares)
|
||||
|
||||
futures = []
|
||||
for node_id in ray.state.node_ids():
|
||||
|
||||
@@ -92,7 +92,7 @@ class ServeController:
|
||||
"""
|
||||
|
||||
async def __init__(self, instance_name, http_host, http_port,
|
||||
metric_exporter_class):
|
||||
metric_exporter_class, _http_middlewares):
|
||||
# Unique name of the serve instance managed by this actor. Used to
|
||||
# namespace child actors and checkpoints.
|
||||
self.instance_name = instance_name
|
||||
@@ -135,6 +135,7 @@ class ServeController:
|
||||
|
||||
self.http_host = http_host
|
||||
self.http_port = http_port
|
||||
self._http_middlewares = _http_middlewares
|
||||
|
||||
# If starting the actor for the first time, starts up the other system
|
||||
# components. If recovering, fetches their actor handles.
|
||||
@@ -190,7 +191,8 @@ class ServeController:
|
||||
node_id,
|
||||
self.http_host,
|
||||
self.http_port,
|
||||
instance_name=self.instance_name)
|
||||
instance_name=self.instance_name,
|
||||
_http_middlewares=self._http_middlewares)
|
||||
|
||||
self.routers[node_id] = router
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
from urllib.parse import parse_qs
|
||||
import socket
|
||||
from typing import List
|
||||
|
||||
import uvicorn
|
||||
|
||||
@@ -170,13 +171,26 @@ class HTTPProxy:
|
||||
|
||||
@ray.remote
|
||||
class HTTPProxyActor:
|
||||
async def __init__(self, name, host, port, instance_name=None):
|
||||
async def __init__(
|
||||
self,
|
||||
name,
|
||||
host,
|
||||
port,
|
||||
instance_name=None,
|
||||
_http_middlewares: List["starlette.middleware.Middleware"] = []):
|
||||
serve.init(name=instance_name)
|
||||
self.app = HTTPProxy()
|
||||
await self.app.fetch_config_from_controller(name, instance_name)
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
self.app = HTTPProxy()
|
||||
await self.app.fetch_config_from_controller(name, instance_name)
|
||||
|
||||
self.wrapped_app = self.app
|
||||
for middleware in _http_middlewares:
|
||||
self.wrapped_app = middleware.cls(self.wrapped_app,
|
||||
**middleware.options)
|
||||
|
||||
# Start running the HTTP server on the event loop.
|
||||
asyncio.get_event_loop().create_task(self.run())
|
||||
|
||||
@@ -197,7 +211,7 @@ class HTTPProxyActor:
|
||||
# class because we want to run the server as a coroutine. The only
|
||||
# alternative is to call uvicorn.run which is blocking.
|
||||
config = uvicorn.Config(
|
||||
self.app,
|
||||
self.wrapped_app,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
lifespan="off",
|
||||
|
||||
Reference in New Issue
Block a user