[Serve] Support Starlette streaming response (#13328)

This commit is contained in:
architkulkarni
2021-01-11 13:27:44 -08:00
committed by GitHub
parent c39658f368
commit c43fa12e73
3 changed files with 53 additions and 10 deletions
+36
View File
@@ -6,6 +6,8 @@ from itertools import groupby
from typing import Union, List, Any, Callable, Type
import time
import starlette.responses
import ray
from ray.actor import ActorHandle
from ray.async_compat import sync_to_async
@@ -251,6 +253,36 @@ class RayServeReplica:
return self.callable
return getattr(self.callable, method_name)
async def ensure_serializable_response(self, response: Any) -> Any:
if isinstance(response, starlette.responses.StreamingResponse):
# response contains a generator/iterator which is not serializable.
# Exhaust the generator/iterator and store the results in a buffer.
body_buffer = []
async def mock_send(message):
assert message["type"] in {
"http.response.start", "http.response.body"
}
if (message["type"] == "http.response.body"):
body_buffer.append(message["body"])
async def mock_receive():
# This is called in a tight loop in response() just to check
# for an http disconnect. So rather than return immediately
# we should suspend execution to avoid wasting CPU cycles.
never_set_event = asyncio.Event()
await never_set_event.wait()
await response(scope=None, receive=mock_receive, send=mock_send)
content = b"".join(body_buffer)
return starlette.responses.Response(
content,
status_code=response.status_code,
headers=response.headers,
media_type=response.media_type)
else:
return response
async def invoke_single(self, request_item: Query) -> Any:
logger.debug("Replica {} started executing request {}".format(
self.replica_tag, request_item.metadata.request_id))
@@ -260,6 +292,7 @@ class RayServeReplica:
start = time.time()
try:
result = await method_to_call(arg)
result = await self.ensure_serializable_response(result)
self.request_counter.record(1)
except Exception as e:
import os
@@ -317,6 +350,9 @@ class RayServeReplica:
"results with length equal to the batch size"
".".format(batch_size, len(result_list)))
raise RayServeException(error_message)
for i, result in enumerate(result_list):
result_list[i] = (await
self.ensure_serializable_response(result))
except Exception as e:
wrapped_exception = wrap_to_ray_error(e)
self.error_counter.record(1)
-10
View File
@@ -137,16 +137,6 @@ class HTTPProxy:
error_message = "Task Error. Traceback: {}.".format(result)
await error_sender(error_message, 500)
elif isinstance(result, starlette.responses.Response):
if isinstance(result, starlette.responses.StreamingResponse):
raise TypeError("Starlette StreamingResponse returned by "
f"backend for endpoint {endpoint_name}. "
"StreamingResponse is unserializable and not "
"supported by Ray Serve. Consider using "
"another Starlette response type such as "
"Response, HTMLResponse, PlainTextResponse, "
"or JSONResponse. If support for "
"StreamingResponse is desired, please let "
"the Ray team know by making a Github issue!")
await result(scope, receive, send)
else:
await Response(result).send(scope, receive, send)
+17
View File
@@ -90,6 +90,23 @@ def test_starlette_response(serve_instance):
assert requests.get(
"http://127.0.0.1:8000/redirect_response").text == "Hello, world!"
def streaming_response(_):
async def slow_numbers():
for number in range(1, 4):
yield str(number)
await asyncio.sleep(0.01)
return starlette.responses.StreamingResponse(
slow_numbers(), media_type="text/plain")
client.create_backend("streaming_response", streaming_response)
client.create_endpoint(
"streaming_response",
backend="streaming_response",
route="/streaming_response")
assert requests.get(
"http://127.0.0.1:8000/streaming_response").text == "123"
def test_backend_user_config(serve_instance):
client = serve_instance