From c43fa12e73c5de592ae52366f01ccd392646cf9b Mon Sep 17 00:00:00 2001 From: architkulkarni Date: Mon, 11 Jan 2021 13:27:44 -0800 Subject: [PATCH] [Serve] Support Starlette streaming response (#13328) --- python/ray/serve/backend_worker.py | 36 ++++++++++++++++++++++++++++++ python/ray/serve/http_proxy.py | 10 --------- python/ray/serve/tests/test_api.py | 17 ++++++++++++++ 3 files changed, 53 insertions(+), 10 deletions(-) diff --git a/python/ray/serve/backend_worker.py b/python/ray/serve/backend_worker.py index e430dbe38..7a7198010 100644 --- a/python/ray/serve/backend_worker.py +++ b/python/ray/serve/backend_worker.py @@ -6,6 +6,8 @@ from itertools import groupby from typing import Union, List, Any, Callable, Type import time +import starlette.responses + import ray from ray.actor import ActorHandle from ray.async_compat import sync_to_async @@ -251,6 +253,36 @@ class RayServeReplica: return self.callable return getattr(self.callable, method_name) + async def ensure_serializable_response(self, response: Any) -> Any: + if isinstance(response, starlette.responses.StreamingResponse): + # response contains a generator/iterator which is not serializable. + # Exhaust the generator/iterator and store the results in a buffer. + body_buffer = [] + + async def mock_send(message): + assert message["type"] in { + "http.response.start", "http.response.body" + } + if (message["type"] == "http.response.body"): + body_buffer.append(message["body"]) + + async def mock_receive(): + # This is called in a tight loop in response() just to check + # for an http disconnect. So rather than return immediately + # we should suspend execution to avoid wasting CPU cycles. + never_set_event = asyncio.Event() + await never_set_event.wait() + + await response(scope=None, receive=mock_receive, send=mock_send) + content = b"".join(body_buffer) + return starlette.responses.Response( + content, + status_code=response.status_code, + headers=response.headers, + media_type=response.media_type) + else: + return response + async def invoke_single(self, request_item: Query) -> Any: logger.debug("Replica {} started executing request {}".format( self.replica_tag, request_item.metadata.request_id)) @@ -260,6 +292,7 @@ class RayServeReplica: start = time.time() try: result = await method_to_call(arg) + result = await self.ensure_serializable_response(result) self.request_counter.record(1) except Exception as e: import os @@ -317,6 +350,9 @@ class RayServeReplica: "results with length equal to the batch size" ".".format(batch_size, len(result_list))) raise RayServeException(error_message) + for i, result in enumerate(result_list): + result_list[i] = (await + self.ensure_serializable_response(result)) except Exception as e: wrapped_exception = wrap_to_ray_error(e) self.error_counter.record(1) diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py index dad0c0034..5f722276e 100644 --- a/python/ray/serve/http_proxy.py +++ b/python/ray/serve/http_proxy.py @@ -137,16 +137,6 @@ class HTTPProxy: error_message = "Task Error. Traceback: {}.".format(result) await error_sender(error_message, 500) elif isinstance(result, starlette.responses.Response): - if isinstance(result, starlette.responses.StreamingResponse): - raise TypeError("Starlette StreamingResponse returned by " - f"backend for endpoint {endpoint_name}. " - "StreamingResponse is unserializable and not " - "supported by Ray Serve. Consider using " - "another Starlette response type such as " - "Response, HTMLResponse, PlainTextResponse, " - "or JSONResponse. If support for " - "StreamingResponse is desired, please let " - "the Ray team know by making a Github issue!") await result(scope, receive, send) else: await Response(result).send(scope, receive, send) diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index c0b443296..202b01386 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -90,6 +90,23 @@ def test_starlette_response(serve_instance): assert requests.get( "http://127.0.0.1:8000/redirect_response").text == "Hello, world!" + def streaming_response(_): + async def slow_numbers(): + for number in range(1, 4): + yield str(number) + await asyncio.sleep(0.01) + + return starlette.responses.StreamingResponse( + slow_numbers(), media_type="text/plain") + + client.create_backend("streaming_response", streaming_response) + client.create_endpoint( + "streaming_response", + backend="streaming_response", + route="/streaming_response") + assert requests.get( + "http://127.0.0.1:8000/streaming_response").text == "123" + def test_backend_user_config(serve_instance): client = serve_instance