[Serve] Support basic Starlette response types (#12811)

This commit is contained in:
architkulkarni
2020-12-14 15:03:56 -08:00
committed by GitHub
parent d0813c1c58
commit 231518e86f
9 changed files with 83 additions and 5 deletions
+2 -1
View File
@@ -255,7 +255,8 @@ class Client:
Args:
backend_tag (str): a unique tag assign to identify this backend.
func_or_class (callable, class): a function or a class implementing
__call__.
__call__, returning a JSON-serializable object or a
Starlette Response object.
actor_init_args (optional): the arguments to pass to the class.
initialization method.
ray_actor_options (optional): options to be passed into the
+13
View File
@@ -3,6 +3,7 @@ import socket
from typing import List
import uvicorn
import starlette.responses
import ray
from ray.exceptions import RayTaskError
@@ -126,6 +127,18 @@ class HTTPProxy:
if isinstance(result, RayTaskError):
error_message = "Task Error. Traceback: {}.".format(result)
await error_sender(error_message, 500)
elif isinstance(result, starlette.responses.Response):
if isinstance(result, starlette.responses.StreamingResponse):
raise TypeError("Starlette StreamingResponse returned by "
f"backend for endpoint {endpoint_name}. "
"StreamingResponse is unserializable and not "
"supported by Ray Serve. Consider using "
"another Starlette response type such as "
"Response, HTMLResponse, PlainTextResponse, "
"or JSONResponse. If support for "
"StreamingResponse is desired, please let "
"the Ray team know by making a Github issue!")
await result(scope, receive, send)
else:
await Response(result).send(scope, receive, send)
+1 -1
View File
@@ -117,7 +117,7 @@ class Response:
elif content_type == "json":
self.raw_headers.append([b"content-type", b"application/json"])
else:
raise ValueError("Invalid content type {}".foramt(content_type))
raise ValueError("Invalid content type {}".format(content_type))
async def send(self, scope, receive, send):
await send({
+58
View File
@@ -4,6 +4,7 @@ import time
import os
import pytest
import requests
import starlette.responses
import ray
from ray import serve
@@ -32,6 +33,63 @@ def test_e2e(serve_instance):
assert resp == "POST"
def test_starlette_response(serve_instance):
client = serve_instance
def basic_response(_):
return starlette.responses.Response(
"Hello, world!", media_type="text/plain")
client.create_backend("basic_response", basic_response)
client.create_endpoint(
"basic_response", backend="basic_response", route="/basic_response")
assert requests.get(
"http://127.0.0.1:8000/basic_response").text == "Hello, world!"
def html_response(_):
return starlette.responses.HTMLResponse(
"<html><body><h1>Hello, world!</h1></body></html>")
client.create_backend("html_response", html_response)
client.create_endpoint(
"html_response", backend="html_response", route="/html_response")
assert requests.get(
"http://127.0.0.1:8000/html_response"
).text == "<html><body><h1>Hello, world!</h1></body></html>"
def plain_text_response(_):
return starlette.responses.PlainTextResponse("Hello, world!")
client.create_backend("plain_text_response", plain_text_response)
client.create_endpoint(
"plain_text_response",
backend="plain_text_response",
route="/plain_text_response")
assert requests.get(
"http://127.0.0.1:8000/plain_text_response").text == "Hello, world!"
def json_response(_):
return starlette.responses.JSONResponse({"hello": "world"})
client.create_backend("json_response", json_response)
client.create_endpoint(
"json_response", backend="json_response", route="/json_response")
assert requests.get("http://127.0.0.1:8000/json_response").json()[
"hello"] == "world"
def redirect_response(_):
return starlette.responses.RedirectResponse(
url="http://127.0.0.1:8000/basic_response")
client.create_backend("redirect_response", redirect_response)
client.create_endpoint(
"redirect_response",
backend="redirect_response",
route="/redirect_response")
assert requests.get(
"http://127.0.0.1:8000/redirect_response").text == "Hello, world!"
def test_backend_user_config(serve_instance):
client = serve_instance
+1
View File
@@ -38,6 +38,7 @@ tensorboardX
uvicorn
pydantic<1.7
dataclasses; python_version < '3.7'
starlette
# Requirements for running tests
blist; platform_system != "Windows"
+1 -1
View File
@@ -97,7 +97,7 @@ ray_files += [
extras = {
"serve": [
"uvicorn", "flask", "requests", "pydantic<1.7",
"dataclasses; python_version < '3.7'"
"dataclasses; python_version < '3.7'", "starlette"
],
"tune": [
"dataclasses; python_version < '3.7'", "pandas", "tabulate",