mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:16:19 +08:00
[Serve] Support Starlette streaming response (#13328)
This commit is contained in:
@@ -6,6 +6,8 @@ from itertools import groupby
|
||||
from typing import Union, List, Any, Callable, Type
|
||||
import time
|
||||
|
||||
import starlette.responses
|
||||
|
||||
import ray
|
||||
from ray.actor import ActorHandle
|
||||
from ray.async_compat import sync_to_async
|
||||
@@ -251,6 +253,36 @@ class RayServeReplica:
|
||||
return self.callable
|
||||
return getattr(self.callable, method_name)
|
||||
|
||||
async def ensure_serializable_response(self, response: Any) -> Any:
    """Convert a Starlette StreamingResponse into a buffered Response.

    A ``StreamingResponse`` contains a generator/iterator which is not
    serializable. It is run to completion against in-memory ASGI
    callbacks, and the emitted body chunks are joined into a plain
    ``starlette.responses.Response`` that keeps the original status code,
    headers and media type.

    Args:
        response: The value returned by the user's callable.

    Returns:
        A buffered ``starlette.responses.Response`` when ``response`` is a
        ``StreamingResponse``; otherwise ``response`` itself, unchanged.
    """
    if not isinstance(response, starlette.responses.StreamingResponse):
        return response

    # Exhaust the generator/iterator and store the results in a buffer.
    body_chunks = []

    async def mock_send(message):
        assert message["type"] in {
            "http.response.start", "http.response.body"
        }
        if message["type"] == "http.response.body":
            # The ASGI spec makes "body" optional on this event, with a
            # default of b"", so do not assume the key is present.
            body_chunks.append(message.get("body", b""))

    async def mock_receive():
        # This is called in a tight loop in response() just to check
        # for an http disconnect. So rather than return immediately
        # we should suspend execution to avoid wasting CPU cycles.
        never_set_event = asyncio.Event()
        await never_set_event.wait()

    await response(scope=None, receive=mock_receive, send=mock_send)

    return starlette.responses.Response(
        b"".join(body_chunks),
        status_code=response.status_code,
        headers=response.headers,
        media_type=response.media_type)
|
||||
|
||||
async def invoke_single(self, request_item: Query) -> Any:
|
||||
logger.debug("Replica {} started executing request {}".format(
|
||||
self.replica_tag, request_item.metadata.request_id))
|
||||
@@ -260,6 +292,7 @@ class RayServeReplica:
|
||||
start = time.time()
|
||||
try:
|
||||
result = await method_to_call(arg)
|
||||
result = await self.ensure_serializable_response(result)
|
||||
self.request_counter.record(1)
|
||||
except Exception as e:
|
||||
import os
|
||||
@@ -317,6 +350,9 @@ class RayServeReplica:
|
||||
"results with length equal to the batch size"
|
||||
".".format(batch_size, len(result_list)))
|
||||
raise RayServeException(error_message)
|
||||
for i, result in enumerate(result_list):
|
||||
result_list[i] = (await
|
||||
self.ensure_serializable_response(result))
|
||||
except Exception as e:
|
||||
wrapped_exception = wrap_to_ray_error(e)
|
||||
self.error_counter.record(1)
|
||||
|
||||
@@ -137,16 +137,6 @@ class HTTPProxy:
|
||||
error_message = "Task Error. Traceback: {}.".format(result)
|
||||
await error_sender(error_message, 500)
|
||||
elif isinstance(result, starlette.responses.Response):
|
||||
if isinstance(result, starlette.responses.StreamingResponse):
|
||||
raise TypeError("Starlette StreamingResponse returned by "
|
||||
f"backend for endpoint {endpoint_name}. "
|
||||
"StreamingResponse is unserializable and not "
|
||||
"supported by Ray Serve. Consider using "
|
||||
"another Starlette response type such as "
|
||||
"Response, HTMLResponse, PlainTextResponse, "
|
||||
"or JSONResponse. If support for "
|
||||
"StreamingResponse is desired, please let "
|
||||
"the Ray team know by making a Github issue!")
|
||||
await result(scope, receive, send)
|
||||
else:
|
||||
await Response(result).send(scope, receive, send)
|
||||
|
||||
@@ -90,6 +90,23 @@ def test_starlette_response(serve_instance):
|
||||
assert requests.get(
|
||||
"http://127.0.0.1:8000/redirect_response").text == "Hello, world!"
|
||||
|
||||
def streaming_response(_):
    """Backend that streams the strings "1", "2" and "3" as text/plain."""

    async def delayed_digits():
        # Sleep briefly after each chunk so the body is produced
        # piece by piece rather than all at once.
        for value in (1, 2, 3):
            yield str(value)
            await asyncio.sleep(0.01)

    digit_stream = delayed_digits()
    return starlette.responses.StreamingResponse(
        digit_stream, media_type="text/plain")
|
||||
|
||||
client.create_backend("streaming_response", streaming_response)
|
||||
client.create_endpoint(
|
||||
"streaming_response",
|
||||
backend="streaming_response",
|
||||
route="/streaming_response")
|
||||
assert requests.get(
|
||||
"http://127.0.0.1:8000/streaming_response").text == "123"
|
||||
|
||||
|
||||
def test_backend_user_config(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
Reference in New Issue
Block a user