mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 04:07:01 +08:00
[Serve] Produtionize Starlette Middlewares (#10529)
This commit is contained in:
@@ -91,7 +91,7 @@ py_test(
|
||||
|
||||
|
||||
py_test(
|
||||
name = "test_scaling",
|
||||
name = "test_standalone",
|
||||
size = "small",
|
||||
srcs = serve_tests_srcs,
|
||||
tags = ["exclusive"],
|
||||
|
||||
@@ -59,7 +59,7 @@ def accept_batch(f: Callable) -> Callable:
|
||||
def init(name: Optional[str] = None,
|
||||
http_host: str = DEFAULT_HTTP_HOST,
|
||||
http_port: int = DEFAULT_HTTP_PORT,
|
||||
_http_middlewares: List[Any] = []) -> None:
|
||||
http_middlewares: List[Any] = []) -> None:
|
||||
"""Initialize or connect to a serve cluster.
|
||||
|
||||
If serve cluster is already initialized, this function will just return.
|
||||
@@ -97,7 +97,7 @@ def init(name: Optional[str] = None,
|
||||
lifetime="detached",
|
||||
max_restarts=-1,
|
||||
max_task_retries=-1,
|
||||
).remote(name, http_host, http_port, _http_middlewares)
|
||||
).remote(name, http_host, http_port, http_middlewares)
|
||||
|
||||
futures = []
|
||||
for node_id in ray.state.node_ids():
|
||||
|
||||
@@ -103,7 +103,7 @@ class ServeController:
|
||||
"""
|
||||
|
||||
async def __init__(self, instance_name: str, http_host: str,
|
||||
http_port: str, _http_middlewares: List[Any]) -> None:
|
||||
http_port: str, http_middlewares: List[Any]) -> None:
|
||||
# Unique name of the serve instance managed by this actor. Used to
|
||||
# namespace child actors and checkpoints.
|
||||
self.instance_name = instance_name
|
||||
@@ -145,7 +145,7 @@ class ServeController:
|
||||
|
||||
self.http_host = http_host
|
||||
self.http_port = http_port
|
||||
self._http_middlewares = _http_middlewares
|
||||
self.http_middlewares = http_middlewares
|
||||
|
||||
# If starting the actor for the first time, starts up the other system
|
||||
# components. If recovering, fetches their actor handles.
|
||||
@@ -202,7 +202,7 @@ class ServeController:
|
||||
self.http_host,
|
||||
self.http_port,
|
||||
instance_name=self.instance_name,
|
||||
_http_middlewares=self._http_middlewares)
|
||||
http_middlewares=self.http_middlewares)
|
||||
|
||||
self.routers[node_id] = router
|
||||
|
||||
|
||||
@@ -134,7 +134,7 @@ class HTTPProxyActor:
|
||||
host,
|
||||
port,
|
||||
instance_name=None,
|
||||
_http_middlewares: List["starlette.middleware.Middleware"] = []):
|
||||
http_middlewares: List["starlette.middleware.Middleware"] = []):
|
||||
serve.init(name=instance_name)
|
||||
self.app = HTTPProxy()
|
||||
self.host = host
|
||||
@@ -144,7 +144,7 @@ class HTTPProxyActor:
|
||||
await self.app.fetch_config_from_controller(name, instance_name)
|
||||
|
||||
self.wrapped_app = self.app
|
||||
for middleware in _http_middlewares:
|
||||
for middleware in http_middlewares:
|
||||
self.wrapped_app = middleware.cls(self.wrapped_app,
|
||||
**middleware.options)
|
||||
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
"""
|
||||
The test file for all standalone tests that doesn't
|
||||
requires a shared Serve instance.
|
||||
"""
|
||||
import sys
|
||||
import socket
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
import ray
|
||||
from ray import serve
|
||||
@@ -9,6 +14,7 @@ from ray.cluster_utils import Cluster
|
||||
from ray.serve.constants import SERVE_PROXY_NAME
|
||||
from ray.serve.utils import block_until_http_ready
|
||||
from ray.test_utils import wait_for_condition
|
||||
from ray.services import new_port
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -79,5 +85,33 @@ def test_multiple_routers():
|
||||
cluster.shutdown()
|
||||
|
||||
|
||||
def test_middleware():
|
||||
from starlette.middleware import Middleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
port = new_port()
|
||||
serve.init(
|
||||
http_port=port,
|
||||
http_middlewares=[
|
||||
Middleware(
|
||||
CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
|
||||
])
|
||||
ray.get(block_until_http_ready.remote(f"http://127.0.0.1:{port}/-/routes"))
|
||||
|
||||
# Snatched several test cases from Starlette
|
||||
# https://github.com/encode/starlette/blob/master/tests/
|
||||
# middleware/test_cors.py
|
||||
headers = {
|
||||
"Origin": "https://example.org",
|
||||
"Access-Control-Request-Method": "GET",
|
||||
}
|
||||
root = f"http://localhost:{port}"
|
||||
resp = requests.options(root, headers=headers)
|
||||
assert resp.headers["access-control-allow-origin"] == "*"
|
||||
|
||||
resp = requests.get(f"{root}/-/routes", headers=headers)
|
||||
assert resp.headers["access-control-allow-origin"] == "*"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main(["-v", "-s", __file__]))
|
||||
@@ -69,3 +69,4 @@ tensorflow
|
||||
testfixtures
|
||||
werkzeug
|
||||
xlrd
|
||||
starlette
|
||||
|
||||
Reference in New Issue
Block a user