From eff4375c3da095cb0f1a7ebc441f4acfbf3abff1 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Thu, 3 Sep 2020 17:31:38 -0700 Subject: [PATCH] [Serve] Produtionize Starlette Middlewares (#10529) --- doc/source/serve/advanced.rst | 28 +++++++++++++++ python/ray/serve/BUILD | 2 +- python/ray/serve/api.py | 4 +-- python/ray/serve/controller.py | 6 ++-- python/ray/serve/http_proxy.py | 4 +-- .../{test_scaling.py => test_standalone.py} | 34 +++++++++++++++++++ python/requirements.txt | 1 + 7 files changed, 71 insertions(+), 8 deletions(-) rename python/ray/serve/tests/{test_scaling.py => test_standalone.py} (69%) diff --git a/doc/source/serve/advanced.rst b/doc/source/serve/advanced.rst index ebf4fa62f..c7595f900 100644 --- a/doc/source/serve/advanced.rst +++ b/doc/source/serve/advanced.rst @@ -316,3 +316,31 @@ To call a method via Python, do the following: handle = serve.get_handle("backend_name") handle.options(method_name="other_method").remote(5) + +How do I enable CORS and other HTTP features? +--------------------------------------------- + +Serve supports arbitrary `Starlette middlewares `_ +and custom middlewares in Starlette format. The example below shows how to enable +`Cross-Origin Resource Sharing (CORS) `_. +You can follow the same pattern for other Starlette middlewares. + +.. note:: + + Serve does not list ``Starlette`` as one of its dependencies. To utilize this feature, + you will need to: + + .. code-block:: bash + + pip install starlette + +.. code-block:: python + + from starlette.middleware import Middleware + from starlette.middleware.cors import CORSMiddleware + + serve.init( + http_middlewares=[ + Middleware( + CORSMiddleware, allow_origins=["*"], allow_methods=["*"]) + ]) diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD index 178298f66..c1f35697a 100644 --- a/python/ray/serve/BUILD +++ b/python/ray/serve/BUILD @@ -91,7 +91,7 @@ py_test( py_test( - name = "test_scaling", + name = "test_standalone", size = "small", srcs = serve_tests_srcs, tags = ["exclusive"], diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 0dc4d0f06..2e738d3c8 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -59,7 +59,7 @@ def accept_batch(f: Callable) -> Callable: def init(name: Optional[str] = None, http_host: str = DEFAULT_HTTP_HOST, http_port: int = DEFAULT_HTTP_PORT, - _http_middlewares: List[Any] = []) -> None: + http_middlewares: List[Any] = []) -> None: """Initialize or connect to a serve cluster. If serve cluster is already initialized, this function will just return. @@ -97,7 +97,7 @@ def init(name: Optional[str] = None, lifetime="detached", max_restarts=-1, max_task_retries=-1, - ).remote(name, http_host, http_port, _http_middlewares) + ).remote(name, http_host, http_port, http_middlewares) futures = [] for node_id in ray.state.node_ids(): diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index d42541eed..95d6c346a 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -103,7 +103,7 @@ class ServeController: """ async def __init__(self, instance_name: str, http_host: str, - http_port: str, _http_middlewares: List[Any]) -> None: + http_port: str, http_middlewares: List[Any]) -> None: # Unique name of the serve instance managed by this actor. Used to # namespace child actors and checkpoints. self.instance_name = instance_name @@ -145,7 +145,7 @@ class ServeController: self.http_host = http_host self.http_port = http_port - self._http_middlewares = _http_middlewares + self.http_middlewares = http_middlewares # If starting the actor for the first time, starts up the other system # components. If recovering, fetches their actor handles. @@ -202,7 +202,7 @@ class ServeController: self.http_host, self.http_port, instance_name=self.instance_name, - _http_middlewares=self._http_middlewares) + http_middlewares=self.http_middlewares) self.routers[node_id] = router diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py index b1a7672ad..5a47061db 100644 --- a/python/ray/serve/http_proxy.py +++ b/python/ray/serve/http_proxy.py @@ -134,7 +134,7 @@ class HTTPProxyActor: host, port, instance_name=None, - _http_middlewares: List["starlette.middleware.Middleware"] = []): + http_middlewares: List["starlette.middleware.Middleware"] = []): serve.init(name=instance_name) self.app = HTTPProxy() self.host = host @@ -144,7 +144,7 @@ class HTTPProxyActor: await self.app.fetch_config_from_controller(name, instance_name) self.wrapped_app = self.app - for middleware in _http_middlewares: + for middleware in http_middlewares: self.wrapped_app = middleware.cls(self.wrapped_app, **middleware.options) diff --git a/python/ray/serve/tests/test_scaling.py b/python/ray/serve/tests/test_standalone.py similarity index 69% rename from python/ray/serve/tests/test_scaling.py rename to python/ray/serve/tests/test_standalone.py index 05b616fde..d60617e4a 100644 --- a/python/ray/serve/tests/test_scaling.py +++ b/python/ray/serve/tests/test_standalone.py @@ -1,7 +1,12 @@ +""" +The test file for all standalone tests that doesn't +requires a shared Serve instance. +""" import sys import socket import pytest +import requests import ray from ray import serve @@ -9,6 +14,7 @@ from ray.cluster_utils import Cluster from ray.serve.constants import SERVE_PROXY_NAME from ray.serve.utils import block_until_http_ready from ray.test_utils import wait_for_condition +from ray.services import new_port @pytest.mark.skipif( @@ -79,5 +85,33 @@ def test_multiple_routers(): cluster.shutdown() +def test_middleware(): + from starlette.middleware import Middleware + from starlette.middleware.cors import CORSMiddleware + + port = new_port() + serve.init( + http_port=port, + http_middlewares=[ + Middleware( + CORSMiddleware, allow_origins=["*"], allow_methods=["*"]) + ]) + ray.get(block_until_http_ready.remote(f"http://127.0.0.1:{port}/-/routes")) + + # Snatched several test cases from Starlette + # https://github.com/encode/starlette/blob/master/tests/ + # middleware/test_cors.py + headers = { + "Origin": "https://example.org", + "Access-Control-Request-Method": "GET", + } + root = f"http://localhost:{port}" + resp = requests.options(root, headers=headers) + assert resp.headers["access-control-allow-origin"] == "*" + + resp = requests.get(f"{root}/-/routes", headers=headers) + assert resp.headers["access-control-allow-origin"] == "*" + + if __name__ == "__main__": sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/requirements.txt b/python/requirements.txt index b4411342f..472e925de 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -69,3 +69,4 @@ tensorflow testfixtures werkzeug xlrd +starlette