diff --git a/doc/source/serve/advanced.rst b/doc/source/serve/advanced.rst
index ebf4fa62f..c7595f900 100644
--- a/doc/source/serve/advanced.rst
+++ b/doc/source/serve/advanced.rst
@@ -316,3 +316,31 @@ To call a method via Python, do the following:
handle = serve.get_handle("backend_name")
handle.options(method_name="other_method").remote(5)
+
+How do I enable CORS and other HTTP features?
+---------------------------------------------
+
+Serve supports arbitrary `Starlette middlewares <https://www.starlette.io/middleware/>`_
+and custom middlewares in Starlette format. The example below shows how to enable
+`Cross-Origin Resource Sharing (CORS) <https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS>`_.
+You can follow the same pattern for other Starlette middlewares.
+
+.. note::
+
+ Serve does not list ``Starlette`` as one of its dependencies. To utilize this feature,
+ you will need to:
+
+ .. code-block:: bash
+
+ pip install starlette
+
+.. code-block:: python
+
+ from starlette.middleware import Middleware
+ from starlette.middleware.cors import CORSMiddleware
+
+ serve.init(
+ http_middlewares=[
+ Middleware(
+ CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
+ ])
diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD
index 178298f66..c1f35697a 100644
--- a/python/ray/serve/BUILD
+++ b/python/ray/serve/BUILD
@@ -91,7 +91,7 @@ py_test(
py_test(
- name = "test_scaling",
+ name = "test_standalone",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive"],
diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py
index 0dc4d0f06..2e738d3c8 100644
--- a/python/ray/serve/api.py
+++ b/python/ray/serve/api.py
@@ -59,7 +59,7 @@ def accept_batch(f: Callable) -> Callable:
def init(name: Optional[str] = None,
http_host: str = DEFAULT_HTTP_HOST,
http_port: int = DEFAULT_HTTP_PORT,
- _http_middlewares: List[Any] = []) -> None:
+ http_middlewares: List[Any] = []) -> None:
"""Initialize or connect to a serve cluster.
If serve cluster is already initialized, this function will just return.
@@ -97,7 +97,7 @@ def init(name: Optional[str] = None,
lifetime="detached",
max_restarts=-1,
max_task_retries=-1,
- ).remote(name, http_host, http_port, _http_middlewares)
+ ).remote(name, http_host, http_port, http_middlewares)
futures = []
for node_id in ray.state.node_ids():
diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py
index d42541eed..95d6c346a 100644
--- a/python/ray/serve/controller.py
+++ b/python/ray/serve/controller.py
@@ -103,7 +103,7 @@ class ServeController:
"""
async def __init__(self, instance_name: str, http_host: str,
- http_port: str, _http_middlewares: List[Any]) -> None:
+ http_port: str, http_middlewares: List[Any]) -> None:
# Unique name of the serve instance managed by this actor. Used to
# namespace child actors and checkpoints.
self.instance_name = instance_name
@@ -145,7 +145,7 @@ class ServeController:
self.http_host = http_host
self.http_port = http_port
- self._http_middlewares = _http_middlewares
+ self.http_middlewares = http_middlewares
# If starting the actor for the first time, starts up the other system
# components. If recovering, fetches their actor handles.
@@ -202,7 +202,7 @@ class ServeController:
self.http_host,
self.http_port,
instance_name=self.instance_name,
- _http_middlewares=self._http_middlewares)
+ http_middlewares=self.http_middlewares)
self.routers[node_id] = router
diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py
index b1a7672ad..5a47061db 100644
--- a/python/ray/serve/http_proxy.py
+++ b/python/ray/serve/http_proxy.py
@@ -134,7 +134,7 @@ class HTTPProxyActor:
host,
port,
instance_name=None,
- _http_middlewares: List["starlette.middleware.Middleware"] = []):
+ http_middlewares: List["starlette.middleware.Middleware"] = []):
serve.init(name=instance_name)
self.app = HTTPProxy()
self.host = host
@@ -144,7 +144,7 @@ class HTTPProxyActor:
await self.app.fetch_config_from_controller(name, instance_name)
self.wrapped_app = self.app
- for middleware in _http_middlewares:
+ for middleware in http_middlewares:
self.wrapped_app = middleware.cls(self.wrapped_app,
**middleware.options)
diff --git a/python/ray/serve/tests/test_scaling.py b/python/ray/serve/tests/test_standalone.py
similarity index 69%
rename from python/ray/serve/tests/test_scaling.py
rename to python/ray/serve/tests/test_standalone.py
index 05b616fde..d60617e4a 100644
--- a/python/ray/serve/tests/test_scaling.py
+++ b/python/ray/serve/tests/test_standalone.py
@@ -1,7 +1,12 @@
+"""
+The test file for all standalone tests that don't
+require a shared Serve instance.
+"""
import sys
import socket
import pytest
+import requests
import ray
from ray import serve
@@ -9,6 +14,7 @@ from ray.cluster_utils import Cluster
from ray.serve.constants import SERVE_PROXY_NAME
from ray.serve.utils import block_until_http_ready
from ray.test_utils import wait_for_condition
+from ray.services import new_port
@pytest.mark.skipif(
@@ -79,5 +85,33 @@ def test_multiple_routers():
cluster.shutdown()
+def test_middleware():
+ from starlette.middleware import Middleware
+ from starlette.middleware.cors import CORSMiddleware
+
+ port = new_port()
+ serve.init(
+ http_port=port,
+ http_middlewares=[
+ Middleware(
+ CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
+ ])
+ ray.get(block_until_http_ready.remote(f"http://127.0.0.1:{port}/-/routes"))
+
+ # Snatched several test cases from Starlette
+ # https://github.com/encode/starlette/blob/master/tests/
+ # middleware/test_cors.py
+ headers = {
+ "Origin": "https://example.org",
+ "Access-Control-Request-Method": "GET",
+ }
+ root = f"http://localhost:{port}"
+ resp = requests.options(root, headers=headers)
+ assert resp.headers["access-control-allow-origin"] == "*"
+
+ resp = requests.get(f"{root}/-/routes", headers=headers)
+ assert resp.headers["access-control-allow-origin"] == "*"
+
+
if __name__ == "__main__":
sys.exit(pytest.main(["-v", "-s", __file__]))
diff --git a/python/requirements.txt b/python/requirements.txt
index b4411342f..472e925de 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -69,3 +69,4 @@ tensorflow
testfixtures
werkzeug
xlrd
+starlette