[Serve] Migrate from Flask.Request to Starlette Request (#12852)

This commit is contained in:
architkulkarni
2020-12-21 13:34:15 -08:00
committed by GitHub
parent 5b48480e29
commit 8b4b4bf0a2
26 changed files with 140 additions and 174 deletions
@@ -10,7 +10,7 @@ class Counter:
def __init__(self):
self.count = 0
def __call__(self, flask_request):
def __call__(self, starlette_request):
self.count += 1
return {"current_counter": self.count}
@@ -6,8 +6,8 @@ ray.init(num_cpus=8)
client = serve.start()
def echo(flask_request):
return "hello " + flask_request.args.get("name", "serve!")
def echo(starlette_request):
return "hello " + starlette_request.query_params.get("name", "serve!")
client.create_backend("hello", echo)
@@ -16,13 +16,13 @@ client = serve.start()
def model_one(request):
print("Model 1 called with data ", request.args.get("data"))
print("Model 1 called with data ", request.query_params.get("data"))
return random()
def model_two(request):
print("Model 2 called with data ", request.args.get("data"))
return request.args.get("data")
print("Model 2 called with data ", request.query_params.get("data"))
return request.query_params.get("data")
class ComposedModel:
@@ -32,8 +32,8 @@ class ComposedModel:
self.model_two = client.get_handle("model_two")
# This method can be called concurrently!
async def __call__(self, flask_request):
data = flask_request.data
async def __call__(self, starlette_request):
data = await starlette_request.body()
score = await self.model_one.remote(data=data)
if score > 0.5:
@@ -14,8 +14,10 @@ import requests
# __doc_define_servable_v0_begin__
@serve.accept_batch
def batch_adder_v0(flask_requests: List):
numbers = [int(request.args["number"]) for request in flask_requests]
def batch_adder_v0(starlette_requests: List):
numbers = [
int(request.query_params["number"]) for request in starlette_requests
]
input_array = np.array(numbers)
print("Our input array has shape:", input_array.shape)
@@ -58,7 +60,7 @@ print("Result returned:", results)
# __doc_define_servable_v1_begin__
@serve.accept_batch
def batch_adder_v1(requests: List):
numbers = [int(request.args["number"]) for request in requests]
numbers = [int(request.query_params["number"]) for request in requests]
input_array = np.array(numbers)
print("Our input array has shape:", input_array.shape)
# Sleep for 200ms, this could be performing CPU intensive computation
@@ -48,9 +48,9 @@ class BoostingModel:
with open("/tmp/iris_labels.json") as f:
self.label_list = json.load(f)
def __call__(self, flask_request):
payload = flask_request.json
print("Worker: received flask request with data", payload)
async def __call__(self, starlette_request):
payload = await starlette_request.json()
print("Worker: received starlette request with data", payload)
input_vector = [
payload["sepal length"],
@@ -143,9 +143,9 @@ class BoostingModelv2:
with open("/tmp/iris_labels_2.json") as f:
self.label_list = json.load(f)
def __call__(self, flask_request):
payload = flask_request.json
print("Worker: received flask request with data", payload)
async def __call__(self, starlette_request):
payload = await starlette_request.json()
print("Worker: received starlette request with data", payload)
input_vector = [
payload["sepal length"],
@@ -27,8 +27,8 @@ class ImageModel:
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def __call__(self, flask_request):
image_payload_bytes = flask_request.data
async def __call__(self, starlette_request):
image_payload_bytes = await starlette_request.body()
pil_image = Image.open(BytesIO(image_payload_bytes))
print("[1/3] Parsed image data: {}".format(pil_image))
@@ -54,9 +54,9 @@ class BoostingModel:
with open(LABEL_PATH) as f:
self.label_list = json.load(f)
def __call__(self, flask_request):
payload = flask_request.json
print("Worker: received flask request with data", payload)
async def __call__(self, starlette_request):
payload = await starlette_request.json()
print("Worker: received starlette request with data", payload)
input_vector = [
payload["sepal length"],
@@ -51,10 +51,10 @@ class TFMnistModel:
self.model_path = model_path
self.model = tf.keras.models.load_model(model_path)
def __call__(self, flask_request):
async def __call__(self, starlette_request):
# Step 1: transform HTTP request -> tensorflow input
# Here we define the request schema to be a json array.
input_array = np.array(flask_request.json["array"])
input_array = np.array((await starlette_request.json())["array"])
reshaped_array = input_array.reshape((1, 28, 28))
# Step 2: tensorflow input -> tensorflow output
+2 -2
View File
@@ -9,8 +9,8 @@ import requests
from ray import serve
def echo(flask_request):
return ["hello " + flask_request.args.get("name", "serve!")]
def echo(starlette_request):
return ["hello " + starlette_request.query_params.get("name", "serve!")]
client = serve.start()
+4 -3
View File
@@ -1,6 +1,6 @@
"""
Example actor that adds an increment to a number. This number can
come from either web (parsing Flask request) or python call.
come from either web (parsing Starlette request) or python call.
This actor can be called from HTTP as well as from Python.
"""
@@ -30,9 +30,10 @@ class MagicCounter:
def __init__(self, increment):
self.increment = increment
def __call__(self, flask_request, base_number=None):
def __call__(self, starlette_request, base_number=None):
if serve.context.web:
base_number = int(flask_request.args.get("base_number", "0"))
base_number = int(
starlette_request.query_params.get("base_number", "0"))
return base_number + self.increment
@@ -1,6 +1,6 @@
"""
Example actor that adds an increment to a number. This number can
come from either web (parsing Flask request) or python call.
come from either web (parsing Starlette request) or python call.
The queries incoming to this actor are batched.
This actor can be called from HTTP as well as from Python.
"""
@@ -31,12 +31,13 @@ class MagicCounter:
self.increment = increment
@serve.accept_batch
def __call__(self, flask_request_list, base_number=None):
def __call__(self, starlette_request_list, base_number=None):
# batch_size = serve.context.batch_size
if serve.context.web:
result = []
for flask_request in flask_request_list:
base_number = int(flask_request.args.get("base_number", "0"))
for starlette_request in starlette_request_list:
base_number = int(
starlette_request.query_params.get("base_number", "0"))
result.append(base_number)
return list(map(lambda x: x + self.increment, result))
else:
+1 -1
View File
@@ -11,7 +11,7 @@ class MagicCounter:
self.increment = increment
@serve.accept_batch
def __call__(self, flask_request, base_number=None):
def __call__(self, starlette_request, base_number=None):
# __call__ fn should preserve the batch size
# base_number is a python list
+3 -3
View File
@@ -12,8 +12,8 @@ client = serve.start()
# a backend can be a function or class.
# it can be made to be invoked from web as well as python.
def echo_v1(flask_request):
response = flask_request.args.get("response", "web")
def echo_v1(starlette_request):
response = starlette_request.query_params.get("response", "web")
return response
@@ -32,7 +32,7 @@ print(ray.get(client.get_handle("my_endpoint").remote(response="hello")))
# We can also add a new backend and split the traffic.
def echo_v2(flask_request):
def echo_v2(starlette_request):
# magic, only from web.
return "something new"
+3 -3
View File
@@ -96,7 +96,7 @@ class RayServeHandle:
def remote(self, request_data: Optional[Union[Dict, Any]] = None,
**kwargs):
"""Issue an asynchrounous request to the endpoint.
"""Issue an asynchronous request to the endpoint.
Returns a Ray ObjectRef whose results can be waited for or retrieved
using ray.wait or ray.get, respectively.
@@ -106,9 +106,9 @@ class RayServeHandle:
Args:
request_data(dict, Any): If it's a dictionary, the data will be
available in ``request.json()`` or ``request.form()``.
Otherwise, it will be available in ``request.data``.
Otherwise, it will be available in ``request.body()``.
``**kwargs``: All keyword arguments will be available in
``request.args``.
``request.query_params``.
"""
if not self.sync:
raise RayServeException(
+14 -65
View File
@@ -1,76 +1,25 @@
import io
import json
import flask
import starlette.requests
def build_flask_request(asgi_scope_dict, request_body):
"""Build and return a flask request from ASGI payload
def build_starlette_request(scope, serialized_body: bytes):
"""Build and return a Starlette Request from ASGI payload.
This function is indented to be used immediately before task invocation
happen.
This function is intended to be used immediately before task invocation
happens.
"""
wsgi_environ = build_wsgi_environ(asgi_scope_dict, request_body)
# We set populate_request=False to prevent self reference, which can lead
# to objects tracked by python garbage collector and memory growth. See
# https://github.com/ray-project/ray/issues/12395.
return flask.Request(wsgi_environ, populate_request=False)
# Simulates receiving HTTP body from TCP socket. In reality, the body has
# already been streamed in chunks and stored in serialized_body.
async def mock_receive():
return {
"body": serialized_body,
"type": "http.request",
"more_body": False
}
def build_wsgi_environ(scope, body):
"""
Builds a scope and request body into a WSGI environ object.
This code snippet is taken from https://github.com/django/asgiref/blob
/36c3e8dc70bf38fe2db87ac20b514f21aaf5ea9d/asgiref/wsgi.py#L52
WSGI specification can be found at
https://www.python.org/dev/peps/pep-0333/
This function helps translate ASGI scope and body into a flask request.
"""
environ = {
"REQUEST_METHOD": scope["method"],
"SCRIPT_NAME": scope.get("root_path", ""),
"PATH_INFO": scope["path"],
"QUERY_STRING": scope["query_string"].decode("ascii"),
"SERVER_PROTOCOL": "HTTP/{}".format(scope["http_version"]),
"wsgi.version": (1, 0),
"wsgi.url_scheme": scope.get("scheme", "http"),
"wsgi.input": body,
"wsgi.errors": io.BytesIO(),
"wsgi.multithread": True,
"wsgi.multiprocess": True,
"wsgi.run_once": False,
}
# Get server name and port - required in WSGI, not in ASGI
environ["SERVER_NAME"] = scope["server"][0]
environ["SERVER_PORT"] = str(scope["server"][1])
environ["REMOTE_ADDR"] = scope["client"][0]
# Transforms headers into environ entries.
for name, value in scope.get("headers", []):
# name, values are both bytes, we need to decode them to string
name = name.decode("latin1")
value = value.decode("latin1")
# Handle name correction to conform to WSGI spec
# https://www.python.org/dev/peps/pep-0333/#environ-variables
if name == "content-length":
corrected_name = "CONTENT_LENGTH"
elif name == "content-type":
corrected_name = "CONTENT_TYPE"
else:
corrected_name = "HTTP_%s" % name.upper().replace("-", "_")
# If the header value repeated,
# we will just concatenate it to the field.
if corrected_name in environ:
value = environ[corrected_name] + "," + value
environ[corrected_name] = value
return environ
return starlette.requests.Request(scope, mock_receive)
class Response:
+25 -5
View File
@@ -19,8 +19,8 @@ from ray.serve.utils import (block_until_http_ready, format_actor_name,
def test_e2e(serve_instance):
client = serve_instance
def function(flask_request):
return {"method": flask_request.method}
def function(starlette_request):
return {"method": starlette_request.method}
client.create_backend("echo:v1", function)
client.create_endpoint(
@@ -97,7 +97,7 @@ def test_backend_user_config(serve_instance):
def __init__(self):
self.count = 10
def __call__(self, flask_request):
def __call__(self, starlette_request):
return self.count, os.getpid()
def reconfigure(self, config):
@@ -820,8 +820,8 @@ def test_serve_metrics(serve_instance):
client = serve_instance
@serve.accept_batch
def batcher(flask_requests):
return ["hello"] * len(flask_requests)
def batcher(starlette_requests):
return ["hello"] * len(starlette_requests)
client.create_backend("metrics", batcher)
client.create_endpoint("metrics", backend="metrics", route="/metrics")
@@ -871,6 +871,26 @@ def test_serve_metrics(serve_instance):
verify_metrics()
def test_starlette_request(serve_instance):
client = serve_instance
async def echo_body(starlette_request):
data = await starlette_request.body()
return data
UVICORN_HIGH_WATER_MARK = 65536 # max bytes in one message
# Long string to test serialization of multiple messages.
long_string = "x" * 10 * UVICORN_HIGH_WATER_MARK
client.create_backend("echo:v1", echo_body)
client.create_endpoint(
"endpoint", backend="echo:v1", route="/api", methods=["GET", "POST"])
resp = requests.post("http://127.0.0.1:8000/api", data=long_string).text
assert resp == long_string
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))
@@ -85,7 +85,7 @@ async def test_runner_wraps_error():
async def test_servable_function(serve_instance, router,
mock_controller_with_name):
def echo(request):
return request.args["i"]
return request.query_params["i"]
await add_servable_to_router(echo, router, mock_controller_with_name[0])
@@ -103,7 +103,7 @@ async def test_servable_class(serve_instance, router,
self.increment = inc
def __call__(self, request):
return request.args["i"] + self.increment
return request.query_params["i"] + self.increment
await add_servable_to_router(
MyAdder, router, mock_controller_with_name[0], init_args=(3, ))
@@ -277,7 +277,7 @@ async def test_user_config_update(serve_instance, router,
def __init__(self):
self.reval = ""
def __call__(self, flask_request):
def __call__(self, starlette_request):
return self.retval
def reconfigure(self, config):
+8 -8
View File
@@ -8,7 +8,7 @@ def test_handle_in_endpoint(serve_instance):
client = serve_instance
class Endpoint1:
def __call__(self, flask_request):
def __call__(self, starlette_request):
return "hello"
class Endpoint2:
@@ -40,12 +40,12 @@ def test_handle_http_args(serve_instance):
client = serve_instance
class Endpoint:
def __call__(self, request):
async def __call__(self, request):
return {
"args": dict(request.args),
"args": dict(request.query_params),
"headers": dict(request.headers),
"method": request.method,
"json": request.json
"json": await request.json()
}
client.create_backend("backend", Endpoint)
@@ -58,7 +58,7 @@ def test_handle_http_args(serve_instance):
"arg2": "2"
},
"headers": {
"X-Custom-Header": "value"
"x-custom-header": "value"
},
"method": "POST",
"json": {
@@ -81,10 +81,10 @@ def test_handle_http_args(serve_instance):
for resp in [resp_web, resp_handle]:
for field in ["args", "method", "json"]:
assert resp[field] == ground_truth[field]
resp["headers"]["X-Custom-Header"] == "value"
resp["headers"]["x-custom-header"] == "value"
def test_handle_inject_flask_request(serve_instance):
def test_handle_inject_starlette_request(serve_instance):
client = serve_instance
def echo_request_type(request):
@@ -103,7 +103,7 @@ def test_handle_inject_flask_request(serve_instance):
for route in ["/echo", "/wrapper"]:
resp = requests.get(f"http://127.0.0.1:8000{route}")
request_type = resp.text
assert request_type == "<class 'flask.wrappers.Request'>"
assert request_type == "<class 'starlette.requests.Request'>"
if __name__ == "__main__":
+1 -1
View File
@@ -12,7 +12,7 @@ ray.init(address="{}")
from ray import serve
client = serve.connect()
def driver(flask_request):
def driver(starlette_request):
return "OK!"
client.create_backend("driver", driver)
+2 -2
View File
@@ -15,7 +15,7 @@ def test_np_in_composed_model(serve_instance):
# in cloudpickle _from_numpy_buffer
def sum_model(request):
return np.sum(request.args["data"])
return np.sum(request.query_params["data"])
class ComposedModel:
def __init__(self):
@@ -42,7 +42,7 @@ def test_backend_worker_memory_growth(serve_instance):
# https://github.com/ray-project/ray/issues/12395
client = serve_instance
def gc_unreachable_objects(flask_request):
def gc_unreachable_objects(starlette_request):
gc.set_debug(gc.DEBUG_SAVEALL)
gc.collect()
return len(gc.garbage)
+16 -17
View File
@@ -8,7 +8,6 @@ import random
import string
import time
from typing import List, Dict
import io
import os
from ray.serve.exceptions import RayServeException
from collections import UserDict
@@ -16,18 +15,18 @@ from collections import UserDict
import requests
import numpy as np
import pydantic
import flask
import starlette.requests
import ray
from ray.serve.constants import HTTP_PROXY_TIMEOUT
from ray.serve.context import TaskContext
from ray.serve.http_util import build_flask_request
from ray.serve.http_util import build_starlette_request
ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
class ServeMultiDict(UserDict):
"""Compatible data structure to simulate Flask.Request.args API."""
"""Compatible data structure to simulate Starlette Request query_args."""
def getlist(self, key):
"""Return the list of items for a given key."""
@@ -35,11 +34,14 @@ class ServeMultiDict(UserDict):
class ServeRequest:
"""The request object used in Python context.
"""The request object used when passing arguments via ServeHandle.
ServeRequest is built to have similar API as Flask.Request. You only need
to write your model serving code once; it can be queried by both HTTP and
Python.
ServeRequest partially implements the API of Starlette Request. You only
need to write your model serving code once; it can be queried by both HTTP
and Python.
To use the full Starlette Request interface with ServeHandle, you may
instead directly pass in a Starlette Request object to the ServeHandle.
"""
def __init__(self, data, kwargs, headers, method):
@@ -59,28 +61,25 @@ class ServeRequest:
return self._method
@property
def args(self):
def query_params(self):
"""The keyword arguments from ``handle.remote(**kwargs)``."""
return self._kwargs
@property
def json(self):
async def json(self):
"""The request dictionary, from ``handle.remote(dict)``."""
if not isinstance(self._data, dict):
raise RayServeException("Request data is not a dictionary. "
f"It is {type(self._data)}.")
return self._data
@property
def form(self):
async def form(self):
"""The request dictionary, from ``handle.remote(dict)``."""
if not isinstance(self._data, dict):
raise RayServeException("Request data is not a dictionary. "
f"It is {type(self._data)}.")
return self._data
@property
def data(self):
async def body(self):
"""The request data from ``handle.remote(obj)``."""
return self._data
@@ -88,13 +87,13 @@ class ServeRequest:
def parse_request_item(request_item):
if request_item.metadata.request_context == TaskContext.Web:
asgi_scope, body_bytes = request_item.args
return build_flask_request(asgi_scope, io.BytesIO(body_bytes))
return build_starlette_request(asgi_scope, body_bytes)
else:
arg = request_item.args[0] if len(request_item.args) == 1 else None
# If the input data from handle is web request, we don't need to wrap
# it in ServeRequest.
if isinstance(arg, flask.Request):
if isinstance(arg, starlette.requests.Request):
return arg
return ServeRequest(