From 8b4b4bf0a2ae0f1c08095f1d23a74698528c52e0 Mon Sep 17 00:00:00 2001 From: architkulkarni Date: Mon, 21 Dec 2020 13:34:15 -0800 Subject: [PATCH] [Serve] Migrate from Flask.Request to Starlette Request (#12852) --- doc/source/serve/faq.rst | 37 ++++----- doc/source/serve/index.rst | 3 + doc/source/serve/key-concepts.rst | 8 +- doc/source/serve/package-ref.rst | 2 +- doc/source/serve/tutorials/batch.rst | 10 +-- .../serve/examples/doc/quickstart_class.py | 2 +- .../serve/examples/doc/quickstart_function.py | 4 +- .../examples/doc/snippet_model_composition.py | 10 +-- .../ray/serve/examples/doc/tutorial_batch.py | 8 +- .../ray/serve/examples/doc/tutorial_deploy.py | 12 +-- .../serve/examples/doc/tutorial_pytorch.py | 4 +- .../serve/examples/doc/tutorial_sklearn.py | 6 +- .../serve/examples/doc/tutorial_tensorflow.py | 4 +- python/ray/serve/examples/echo.py | 4 +- python/ray/serve/examples/echo_actor.py | 7 +- python/ray/serve/examples/echo_actor_batch.py | 9 ++- python/ray/serve/examples/echo_batching.py | 2 +- python/ray/serve/examples/echo_full.py | 6 +- python/ray/serve/handle.py | 6 +- python/ray/serve/http_util.py | 79 ++++--------------- python/ray/serve/tests/test_api.py | 30 +++++-- python/ray/serve/tests/test_backend_worker.py | 6 +- python/ray/serve/tests/test_handle.py | 16 ++-- python/ray/serve/tests/test_persistence.py | 2 +- python/ray/serve/tests/test_regression.py | 4 +- python/ray/serve/utils.py | 33 ++++---- 26 files changed, 140 insertions(+), 174 deletions(-) diff --git a/doc/source/serve/faq.rst b/doc/source/serve/faq.rst index 4f6791c5b..451307ce3 100644 --- a/doc/source/serve/faq.rst +++ b/doc/source/serve/faq.rst @@ -117,7 +117,7 @@ policies `, finding the next available replica, and batching requests together. When the request arrives in the model, you can access the data similarly to how -you would with HTTP request. Here are some examples how ServeRequest mirrors Flask.Request: +you would with HTTP request. Here are some examples how ServeRequest mirrors Starlette.Request: .. list-table:: :header-rows: 1 @@ -125,25 +125,25 @@ you would with HTTP request. Here are some examples how ServeRequest mirrors Fla * - HTTP - ServeHandle - | Request - | (Flask.Request and ServeRequest) + | (Starlette.Request and ServeRequest) * - ``requests.get(..., headers={...})`` - ``handle.options(http_headers={...})`` - ``request.headers`` * - ``requests.post(...)`` - ``handle.options(http_method="POST")`` - - ``requests.method`` - * - ``request.get(..., json={...})`` + - ``request.method`` + * - ``requests.get(..., json={...})`` - ``handle.remote({...})`` - - ``request.json`` - * - ``request.get(..., form={...})`` + - ``await request.json()`` + * - ``requests.get(..., form={...})`` - ``handle.remote({...})`` - - ``request.form`` - * - ``request.get(..., params={"a":"b"})`` + - ``await request.form()`` + * - ``requests.get(..., params={"a":"b"})`` - ``handle.remote(a="b")`` - - ``request.args`` - * - ``request.get(..., data="long string")`` + - ``request.query_params`` + * - ``requests.get(..., data="long string")`` - ``handle.remote("long string")`` - - ``request.data`` + - ``await request.body()`` * - ``N/A`` - ``handle.remote(python_object)`` - ``request.data`` @@ -157,9 +157,9 @@ you would with HTTP request. Here are some examples how ServeRequest mirrors Fla .. code-block:: python - import flask + import starlette.requests - if isinstance(request, flask.Request): + if isinstance(request, starlette.requests.Request): print("Request coming from web!") elif isinstance(request, ServeRequest): print("Request coming from Python!") @@ -170,10 +170,10 @@ you would with HTTP request. Here are some examples how ServeRequest mirrors Fla .. code-block:: python - handle.remote(flask_request) + handle.remote(starlette_request) In this case, Serve will `not` wrap it in ServeRequest. You can directly - process the request as a ``flask.Request``. + process the request as a ``starlette.requests.Request``. How fast is Ray Serve? ---------------------- @@ -187,13 +187,6 @@ You can checkout our `microbenchmark instruction `_ as our web server, -alongside with the power of Python asyncio. -**Flask is ONLY the request object that we are using, Uvicorn (not flask) provides the webserver.** - Can I use asyncio along with Ray Serve? --------------------------------------- Yes! You can make your servable methods ``async def`` and Serve will run them diff --git a/doc/source/serve/index.rst b/doc/source/serve/index.rst index b9c0d1497..af64475a3 100644 --- a/doc/source/serve/index.rst +++ b/doc/source/serve/index.rst @@ -33,6 +33,9 @@ Since Serve is built on Ray, it also allows you to scale to many machines, in yo If you want to try out Serve, join our `community slack `_ and discuss in the #serve channel. +.. note:: + Starting with Ray version 1.3.0, Ray Serve backends must take in a Starlette Request object instead of a Flask Request object. + See the `migration guide `_ for details. Installation ============ diff --git a/doc/source/serve/key-concepts.rst b/doc/source/serve/key-concepts.rst index deada7233..f215ad4b5 100644 --- a/doc/source/serve/key-concepts.rst +++ b/doc/source/serve/key-concepts.rst @@ -19,10 +19,8 @@ Backends Backends define the implementation of your business logic or models that will handle requests when queries come in to :ref:`serve-endpoint`. In order to support seamless scalability backends can have many replicas, which are individual processes running in the Ray cluster to handle requests. To define a backend, first you must define the "handler" or the business logic you'd like to respond with. -The handler should take as input a `Flask Request object `_. -The handler should return any JSON-serializable object as output. For a more customizable response type, the handler may return a +The handler should take as input a `Starlette Request object `_ and return any JSON-serializable object as output. For a more customizable response type, the handler may return a `Starlette Response object `_. -In the future, Ray Serve will support `Starlette Request objects `_ as input as well. A backend is defined using :mod:`client.create_backend `, and the implementation can be defined as either a function or a class. Use a function when your response is stateless and a class when you might need to maintain some state (like a model). @@ -32,7 +30,7 @@ A backend consists of a number of *replicas*, which are individual copies of the .. code-block:: python - def handle_request(flask_request): + def handle_request(starlette_request): return "hello world" class RequestHandler: @@ -40,7 +38,7 @@ A backend consists of a number of *replicas*, which are individual copies of the def __init__(self, msg): self.msg = msg - def __call__(self, flask_request): + def __call__(self, starlette_request): return self.msg client.create_backend("simple_backend", handle_request) diff --git a/doc/source/serve/package-ref.rst b/doc/source/serve/package-ref.rst index 5a9e947ff..e64794500 100644 --- a/doc/source/serve/package-ref.rst +++ b/doc/source/serve/package-ref.rst @@ -23,7 +23,7 @@ Handle API :members: remote, options When calling from Python, the backend implementation will receive ``ServeRequest`` -objects instead of Flask requests. +objects instead of Starlette requests. .. autoclass:: ray.serve.utils.ServeRequest :members: diff --git a/doc/source/serve/tutorials/batch.rst b/doc/source/serve/tutorials/batch.rst index e9a74f94f..48fbc5344 100644 --- a/doc/source/serve/tutorials/batch.rst +++ b/doc/source/serve/tutorials/batch.rst @@ -30,13 +30,13 @@ You can use the ``@serve.accept_batch`` decorator to annotate a function or a cl This annotation is needed because batched backends have different APIs compared to single request backends. In a batched backend, the inputs are a list of values. -For single query backend, the input type is a single Flask request or +For single query backend, the input type is a single Starlette request or :mod:`ServeRequest `: .. code-block:: python def single_request( - request: Union[Flask.Request, ServeRequest], + request: Union[starlette.requests.Request, ServeRequest], ): pass @@ -47,7 +47,7 @@ types: @serve.accept_batch def batched_request( - request: List[Union[Flask.Request, ServeRequest]], + request: List[Union[starlette.requests.Request, ServeRequest]], ): pass @@ -84,8 +84,8 @@ Ray Serve was able to evaluate them in batches. What if you want to evaluate a whole batch in Python? Ray Serve allows you to send queries via the Python API. A batch of queries can either come from the web server -or the Python API. Requests coming from the Python API will have the similar API -as Flask.Request. See more on the API :ref:`here`. +or the Python API. Requests coming from the Python API will have a similar API +to Starlette Request. See more on the API :ref:`here`. .. literalinclude:: ../../../../python/ray/serve/examples/doc/tutorial_batch.py :start-after: __doc_define_servable_v1_begin__ diff --git a/python/ray/serve/examples/doc/quickstart_class.py b/python/ray/serve/examples/doc/quickstart_class.py index 6c6ba2808..d4238ea6a 100644 --- a/python/ray/serve/examples/doc/quickstart_class.py +++ b/python/ray/serve/examples/doc/quickstart_class.py @@ -10,7 +10,7 @@ class Counter: def __init__(self): self.count = 0 - def __call__(self, flask_request): + def __call__(self, starlette_request): self.count += 1 return {"current_counter": self.count} diff --git a/python/ray/serve/examples/doc/quickstart_function.py b/python/ray/serve/examples/doc/quickstart_function.py index 9e7c9bb1f..81ae4b7f1 100644 --- a/python/ray/serve/examples/doc/quickstart_function.py +++ b/python/ray/serve/examples/doc/quickstart_function.py @@ -6,8 +6,8 @@ ray.init(num_cpus=8) client = serve.start() -def echo(flask_request): - return "hello " + flask_request.args.get("name", "serve!") +def echo(starlette_request): + return "hello " + starlette_request.query_params.get("name", "serve!") client.create_backend("hello", echo) diff --git a/python/ray/serve/examples/doc/snippet_model_composition.py b/python/ray/serve/examples/doc/snippet_model_composition.py index 67a1890bf..6439bb9bf 100644 --- a/python/ray/serve/examples/doc/snippet_model_composition.py +++ b/python/ray/serve/examples/doc/snippet_model_composition.py @@ -16,13 +16,13 @@ client = serve.start() def model_one(request): - print("Model 1 called with data ", request.args.get("data")) + print("Model 1 called with data ", request.query_params.get("data")) return random() def model_two(request): - print("Model 2 called with data ", request.args.get("data")) - return request.args.get("data") + print("Model 2 called with data ", request.query_params.get("data")) + return request.query_params.get("data") class ComposedModel: @@ -32,8 +32,8 @@ class ComposedModel: self.model_two = client.get_handle("model_two") # This method can be called concurrently! - async def __call__(self, flask_request): - data = flask_request.data + async def __call__(self, starlette_request): + data = await starlette_request.body() score = await self.model_one.remote(data=data) if score > 0.5: diff --git a/python/ray/serve/examples/doc/tutorial_batch.py b/python/ray/serve/examples/doc/tutorial_batch.py index 3f8129a05..8aa8828b8 100644 --- a/python/ray/serve/examples/doc/tutorial_batch.py +++ b/python/ray/serve/examples/doc/tutorial_batch.py @@ -14,8 +14,10 @@ import requests # __doc_define_servable_v0_begin__ @serve.accept_batch -def batch_adder_v0(flask_requests: List): - numbers = [int(request.args["number"]) for request in flask_requests] +def batch_adder_v0(starlette_requests: List): + numbers = [ + int(request.query_params["number"]) for request in starlette_requests + ] input_array = np.array(numbers) print("Our input array has shape:", input_array.shape) @@ -58,7 +60,7 @@ print("Result returned:", results) # __doc_define_servable_v1_begin__ @serve.accept_batch def batch_adder_v1(requests: List): - numbers = [int(request.args["number"]) for request in requests] + numbers = [int(request.query_params["number"]) for request in requests] input_array = np.array(numbers) print("Our input array has shape:", input_array.shape) # Sleep for 200ms, this could be performing CPU intensive computation diff --git a/python/ray/serve/examples/doc/tutorial_deploy.py b/python/ray/serve/examples/doc/tutorial_deploy.py index 018f28dca..a7d5c75ca 100644 --- a/python/ray/serve/examples/doc/tutorial_deploy.py +++ b/python/ray/serve/examples/doc/tutorial_deploy.py @@ -48,9 +48,9 @@ class BoostingModel: with open("/tmp/iris_labels.json") as f: self.label_list = json.load(f) - def __call__(self, flask_request): - payload = flask_request.json - print("Worker: received flask request with data", payload) + async def __call__(self, starlette_request): + payload = await starlette_request.json() + print("Worker: received starlette request with data", payload) input_vector = [ payload["sepal length"], @@ -143,9 +143,9 @@ class BoostingModelv2: with open("/tmp/iris_labels_2.json") as f: self.label_list = json.load(f) - def __call__(self, flask_request): - payload = flask_request.json - print("Worker: received flask request with data", payload) + async def __call__(self, starlette_request): + payload = await starlette_request.json() + print("Worker: received starlette request with data", payload) input_vector = [ payload["sepal length"], diff --git a/python/ray/serve/examples/doc/tutorial_pytorch.py b/python/ray/serve/examples/doc/tutorial_pytorch.py index 80e3b9d32..bc43534ef 100644 --- a/python/ray/serve/examples/doc/tutorial_pytorch.py +++ b/python/ray/serve/examples/doc/tutorial_pytorch.py @@ -27,8 +27,8 @@ class ImageModel: mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) - def __call__(self, flask_request): - image_payload_bytes = flask_request.data + async def __call__(self, starlette_request): + image_payload_bytes = await starlette_request.body() pil_image = Image.open(BytesIO(image_payload_bytes)) print("[1/3] Parsed image data: {}".format(pil_image)) diff --git a/python/ray/serve/examples/doc/tutorial_sklearn.py b/python/ray/serve/examples/doc/tutorial_sklearn.py index 69f17953b..c2eeb8c8f 100644 --- a/python/ray/serve/examples/doc/tutorial_sklearn.py +++ b/python/ray/serve/examples/doc/tutorial_sklearn.py @@ -54,9 +54,9 @@ class BoostingModel: with open(LABEL_PATH) as f: self.label_list = json.load(f) - def __call__(self, flask_request): - payload = flask_request.json - print("Worker: received flask request with data", payload) + async def __call__(self, starlette_request): + payload = await starlette_request.json() + print("Worker: received starlette request with data", payload) input_vector = [ payload["sepal length"], diff --git a/python/ray/serve/examples/doc/tutorial_tensorflow.py b/python/ray/serve/examples/doc/tutorial_tensorflow.py index 07fb36381..022526296 100644 --- a/python/ray/serve/examples/doc/tutorial_tensorflow.py +++ b/python/ray/serve/examples/doc/tutorial_tensorflow.py @@ -51,10 +51,10 @@ class TFMnistModel: self.model_path = model_path self.model = tf.keras.models.load_model(model_path) - def __call__(self, flask_request): + async def __call__(self, starlette_request): # Step 1: transform HTTP request -> tensorflow input # Here we define the request schema to be a json array. - input_array = np.array(flask_request.json["array"]) + input_array = np.array((await starlette_request.json())["array"]) reshaped_array = input_array.reshape((1, 28, 28)) # Step 2: tensorflow input -> tensorflow output diff --git a/python/ray/serve/examples/echo.py b/python/ray/serve/examples/echo.py index 4f73d3da9..218beef01 100644 --- a/python/ray/serve/examples/echo.py +++ b/python/ray/serve/examples/echo.py @@ -9,8 +9,8 @@ import requests from ray import serve -def echo(flask_request): - return ["hello " + flask_request.args.get("name", "serve!")] +def echo(starlette_request): + return ["hello " + starlette_request.query_params.get("name", "serve!")] client = serve.start() diff --git a/python/ray/serve/examples/echo_actor.py b/python/ray/serve/examples/echo_actor.py index c8d94080a..e6ed0c1b8 100644 --- a/python/ray/serve/examples/echo_actor.py +++ b/python/ray/serve/examples/echo_actor.py @@ -1,6 +1,6 @@ """ Example actor that adds an increment to a number. This number can -come from either web (parsing Flask request) or python call. +come from either web (parsing Starlette request) or python call. This actor can be called from HTTP as well as from Python. """ @@ -30,9 +30,10 @@ class MagicCounter: def __init__(self, increment): self.increment = increment - def __call__(self, flask_request, base_number=None): + def __call__(self, starlette_request, base_number=None): if serve.context.web: - base_number = int(flask_request.args.get("base_number", "0")) + base_number = int( + starlette_request.query_params.get("base_number", "0")) return base_number + self.increment diff --git a/python/ray/serve/examples/echo_actor_batch.py b/python/ray/serve/examples/echo_actor_batch.py index f9ed1df4f..15c7ae67d 100644 --- a/python/ray/serve/examples/echo_actor_batch.py +++ b/python/ray/serve/examples/echo_actor_batch.py @@ -1,6 +1,6 @@ """ Example actor that adds an increment to a number. This number can -come from either web (parsing Flask request) or python call. +come from either web (parsing Starlette request) or python call. The queries incoming to this actor are batched. This actor can be called from HTTP as well as from Python. """ @@ -31,12 +31,13 @@ class MagicCounter: self.increment = increment @serve.accept_batch - def __call__(self, flask_request_list, base_number=None): + def __call__(self, starlette_request_list, base_number=None): # batch_size = serve.context.batch_size if serve.context.web: result = [] - for flask_request in flask_request_list: - base_number = int(flask_request.args.get("base_number", "0")) + for starlette_request in starlette_request_list: + base_number = int( + starlette_request.query_params.get("base_number", "0")) result.append(base_number) return list(map(lambda x: x + self.increment, result)) else: diff --git a/python/ray/serve/examples/echo_batching.py b/python/ray/serve/examples/echo_batching.py index b4c55249a..9ffb214eb 100644 --- a/python/ray/serve/examples/echo_batching.py +++ b/python/ray/serve/examples/echo_batching.py @@ -11,7 +11,7 @@ class MagicCounter: self.increment = increment @serve.accept_batch - def __call__(self, flask_request, base_number=None): + def __call__(self, starlette_request, base_number=None): # __call__ fn should preserve the batch size # base_number is a python list diff --git a/python/ray/serve/examples/echo_full.py b/python/ray/serve/examples/echo_full.py index 9639f1a25..046ec2d6a 100644 --- a/python/ray/serve/examples/echo_full.py +++ b/python/ray/serve/examples/echo_full.py @@ -12,8 +12,8 @@ client = serve.start() # a backend can be a function or class. # it can be made to be invoked from web as well as python. -def echo_v1(flask_request): - response = flask_request.args.get("response", "web") +def echo_v1(starlette_request): + response = starlette_request.query_params.get("response", "web") return response @@ -32,7 +32,7 @@ print(ray.get(client.get_handle("my_endpoint").remote(response="hello"))) # We can also add a new backend and split the traffic. -def echo_v2(flask_request): +def echo_v2(starlette_request): # magic, only from web. return "something new" diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index 9e88959c8..4bfd663fd 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -96,7 +96,7 @@ class RayServeHandle: def remote(self, request_data: Optional[Union[Dict, Any]] = None, **kwargs): - """Issue an asynchrounous request to the endpoint. + """Issue an asynchronous request to the endpoint. Returns a Ray ObjectRef whose results can be waited for or retrieved using ray.wait or ray.get, respectively. @@ -106,9 +106,9 @@ class RayServeHandle: Args: request_data(dict, Any): If it's a dictionary, the data will be available in ``request.json()`` or ``request.form()``. - Otherwise, it will be available in ``request.data``. + Otherwise, it will be available in ``request.body()``. ``**kwargs``: All keyword arguments will be available in - ``request.args``. + ``request.query_params``. """ if not self.sync: raise RayServeException( diff --git a/python/ray/serve/http_util.py b/python/ray/serve/http_util.py index 1a057a88e..0aa4ccf84 100644 --- a/python/ray/serve/http_util.py +++ b/python/ray/serve/http_util.py @@ -1,76 +1,25 @@ -import io import json -import flask +import starlette.requests -def build_flask_request(asgi_scope_dict, request_body): - """Build and return a flask request from ASGI payload +def build_starlette_request(scope, serialized_body: bytes): + """Build and return a Starlette Request from ASGI payload. - This function is indented to be used immediately before task invocation - happen. + This function is intended to be used immediately before task invocation + happens. """ - wsgi_environ = build_wsgi_environ(asgi_scope_dict, request_body) - # We set populate_request=False to prevent self reference, which can lead - # to objects tracked by python garbage collector and memory growth. See - # https://github.com/ray-project/ray/issues/12395. - return flask.Request(wsgi_environ, populate_request=False) + # Simulates receiving HTTP body from TCP socket. In reality, the body has + # already been streamed in chunks and stored in serialized_body. + async def mock_receive(): + return { + "body": serialized_body, + "type": "http.request", + "more_body": False + } -def build_wsgi_environ(scope, body): - """ - Builds a scope and request body into a WSGI environ object. - - This code snippet is taken from https://github.com/django/asgiref/blob - /36c3e8dc70bf38fe2db87ac20b514f21aaf5ea9d/asgiref/wsgi.py#L52 - - WSGI specification can be found at - https://www.python.org/dev/peps/pep-0333/ - - This function helps translate ASGI scope and body into a flask request. - """ - environ = { - "REQUEST_METHOD": scope["method"], - "SCRIPT_NAME": scope.get("root_path", ""), - "PATH_INFO": scope["path"], - "QUERY_STRING": scope["query_string"].decode("ascii"), - "SERVER_PROTOCOL": "HTTP/{}".format(scope["http_version"]), - "wsgi.version": (1, 0), - "wsgi.url_scheme": scope.get("scheme", "http"), - "wsgi.input": body, - "wsgi.errors": io.BytesIO(), - "wsgi.multithread": True, - "wsgi.multiprocess": True, - "wsgi.run_once": False, - } - - # Get server name and port - required in WSGI, not in ASGI - environ["SERVER_NAME"] = scope["server"][0] - environ["SERVER_PORT"] = str(scope["server"][1]) - environ["REMOTE_ADDR"] = scope["client"][0] - - # Transforms headers into environ entries. - for name, value in scope.get("headers", []): - # name, values are both bytes, we need to decode them to string - name = name.decode("latin1") - value = value.decode("latin1") - - # Handle name correction to conform to WSGI spec - # https://www.python.org/dev/peps/pep-0333/#environ-variables - if name == "content-length": - corrected_name = "CONTENT_LENGTH" - elif name == "content-type": - corrected_name = "CONTENT_TYPE" - else: - corrected_name = "HTTP_%s" % name.upper().replace("-", "_") - - # If the header value repeated, - # we will just concatenate it to the field. - if corrected_name in environ: - value = environ[corrected_name] + "," + value - - environ[corrected_name] = value - return environ + return starlette.requests.Request(scope, mock_receive) class Response: diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index 1f236f915..ea0f6f35d 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -19,8 +19,8 @@ from ray.serve.utils import (block_until_http_ready, format_actor_name, def test_e2e(serve_instance): client = serve_instance - def function(flask_request): - return {"method": flask_request.method} + def function(starlette_request): + return {"method": starlette_request.method} client.create_backend("echo:v1", function) client.create_endpoint( @@ -97,7 +97,7 @@ def test_backend_user_config(serve_instance): def __init__(self): self.count = 10 - def __call__(self, flask_request): + def __call__(self, starlette_request): return self.count, os.getpid() def reconfigure(self, config): @@ -820,8 +820,8 @@ def test_serve_metrics(serve_instance): client = serve_instance @serve.accept_batch - def batcher(flask_requests): - return ["hello"] * len(flask_requests) + def batcher(starlette_requests): + return ["hello"] * len(starlette_requests) client.create_backend("metrics", batcher) client.create_endpoint("metrics", backend="metrics", route="/metrics") @@ -871,6 +871,26 @@ def test_serve_metrics(serve_instance): verify_metrics() +def test_starlette_request(serve_instance): + client = serve_instance + + async def echo_body(starlette_request): + data = await starlette_request.body() + return data + + UVICORN_HIGH_WATER_MARK = 65536 # max bytes in one message + + # Long string to test serialization of multiple messages. + long_string = "x" * 10 * UVICORN_HIGH_WATER_MARK + + client.create_backend("echo:v1", echo_body) + client.create_endpoint( + "endpoint", backend="echo:v1", route="/api", methods=["GET", "POST"]) + + resp = requests.post("http://127.0.0.1:8000/api", data=long_string).text + assert resp == long_string + + if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_backend_worker.py b/python/ray/serve/tests/test_backend_worker.py index 1b03c0835..4ae745ba6 100644 --- a/python/ray/serve/tests/test_backend_worker.py +++ b/python/ray/serve/tests/test_backend_worker.py @@ -85,7 +85,7 @@ async def test_runner_wraps_error(): async def test_servable_function(serve_instance, router, mock_controller_with_name): def echo(request): - return request.args["i"] + return request.query_params["i"] await add_servable_to_router(echo, router, mock_controller_with_name[0]) @@ -103,7 +103,7 @@ async def test_servable_class(serve_instance, router, self.increment = inc def __call__(self, request): - return request.args["i"] + self.increment + return request.query_params["i"] + self.increment await add_servable_to_router( MyAdder, router, mock_controller_with_name[0], init_args=(3, )) @@ -277,7 +277,7 @@ async def test_user_config_update(serve_instance, router, def __init__(self): self.reval = "" - def __call__(self, flask_request): + def __call__(self, starlette_request): return self.retval def reconfigure(self, config): diff --git a/python/ray/serve/tests/test_handle.py b/python/ray/serve/tests/test_handle.py index 6e5c91d85..cc6b1e72b 100644 --- a/python/ray/serve/tests/test_handle.py +++ b/python/ray/serve/tests/test_handle.py @@ -8,7 +8,7 @@ def test_handle_in_endpoint(serve_instance): client = serve_instance class Endpoint1: - def __call__(self, flask_request): + def __call__(self, starlette_request): return "hello" class Endpoint2: @@ -40,12 +40,12 @@ def test_handle_http_args(serve_instance): client = serve_instance class Endpoint: - def __call__(self, request): + async def __call__(self, request): return { - "args": dict(request.args), + "args": dict(request.query_params), "headers": dict(request.headers), "method": request.method, - "json": request.json + "json": await request.json() } client.create_backend("backend", Endpoint) @@ -58,7 +58,7 @@ def test_handle_http_args(serve_instance): "arg2": "2" }, "headers": { - "X-Custom-Header": "value" + "x-custom-header": "value" }, "method": "POST", "json": { @@ -81,10 +81,10 @@ def test_handle_http_args(serve_instance): for resp in [resp_web, resp_handle]: for field in ["args", "method", "json"]: assert resp[field] == ground_truth[field] - resp["headers"]["X-Custom-Header"] == "value" + resp["headers"]["x-custom-header"] == "value" -def test_handle_inject_flask_request(serve_instance): +def test_handle_inject_starlette_request(serve_instance): client = serve_instance def echo_request_type(request): @@ -103,7 +103,7 @@ def test_handle_inject_flask_request(serve_instance): for route in ["/echo", "/wrapper"]: resp = requests.get(f"http://127.0.0.1:8000{route}") request_type = resp.text - assert request_type == "" + assert request_type == "" if __name__ == "__main__": diff --git a/python/ray/serve/tests/test_persistence.py b/python/ray/serve/tests/test_persistence.py index fec43f838..6124f41a0 100644 --- a/python/ray/serve/tests/test_persistence.py +++ b/python/ray/serve/tests/test_persistence.py @@ -12,7 +12,7 @@ ray.init(address="{}") from ray import serve client = serve.connect() -def driver(flask_request): +def driver(starlette_request): return "OK!" client.create_backend("driver", driver) diff --git a/python/ray/serve/tests/test_regression.py b/python/ray/serve/tests/test_regression.py index 027b53a72..e2425edf9 100644 --- a/python/ray/serve/tests/test_regression.py +++ b/python/ray/serve/tests/test_regression.py @@ -15,7 +15,7 @@ def test_np_in_composed_model(serve_instance): # in cloudpickle _from_numpy_buffer def sum_model(request): - return np.sum(request.args["data"]) + return np.sum(request.query_params["data"]) class ComposedModel: def __init__(self): @@ -42,7 +42,7 @@ def test_backend_worker_memory_growth(serve_instance): # https://github.com/ray-project/ray/issues/12395 client = serve_instance - def gc_unreachable_objects(flask_request): + def gc_unreachable_objects(starlette_request): gc.set_debug(gc.DEBUG_SAVEALL) gc.collect() return len(gc.garbage) diff --git a/python/ray/serve/utils.py b/python/ray/serve/utils.py index e8c5a6d13..99a65125e 100644 --- a/python/ray/serve/utils.py +++ b/python/ray/serve/utils.py @@ -8,7 +8,6 @@ import random import string import time from typing import List, Dict -import io import os from ray.serve.exceptions import RayServeException from collections import UserDict @@ -16,18 +15,18 @@ from collections import UserDict import requests import numpy as np import pydantic -import flask +import starlette.requests import ray from ray.serve.constants import HTTP_PROXY_TIMEOUT from ray.serve.context import TaskContext -from ray.serve.http_util import build_flask_request +from ray.serve.http_util import build_starlette_request ACTOR_FAILURE_RETRY_TIMEOUT_S = 60 class ServeMultiDict(UserDict): - """Compatible data structure to simulate Flask.Request.args API.""" + """Compatible data structure to simulate Starlette Request query_args.""" def getlist(self, key): """Return the list of items for a given key.""" @@ -35,11 +34,14 @@ class ServeMultiDict(UserDict): class ServeRequest: - """The request object used in Python context. + """The request object used when passing arguments via ServeHandle. - ServeRequest is built to have similar API as Flask.Request. You only need - to write your model serving code once; it can be queried by both HTTP and - Python. + ServeRequest partially implements the API of Starlette Request. You only + need to write your model serving code once; it can be queried by both HTTP + and Python. + + To use the full Starlette Request interface with ServeHandle, you may + instead directly pass in a Starlette Request object to the ServeHandle. """ def __init__(self, data, kwargs, headers, method): @@ -59,28 +61,25 @@ class ServeRequest: return self._method @property - def args(self): + def query_params(self): """The keyword arguments from ``handle.remote(**kwargs)``.""" return self._kwargs - @property - def json(self): + async def json(self): """The request dictionary, from ``handle.remote(dict)``.""" if not isinstance(self._data, dict): raise RayServeException("Request data is not a dictionary. " f"It is {type(self._data)}.") return self._data - @property - def form(self): + async def form(self): """The request dictionary, from ``handle.remote(dict)``.""" if not isinstance(self._data, dict): raise RayServeException("Request data is not a dictionary. " f"It is {type(self._data)}.") return self._data - @property - def data(self): + async def body(self): """The request data from ``handle.remote(obj)``.""" return self._data @@ -88,13 +87,13 @@ class ServeRequest: def parse_request_item(request_item): if request_item.metadata.request_context == TaskContext.Web: asgi_scope, body_bytes = request_item.args - return build_flask_request(asgi_scope, io.BytesIO(body_bytes)) + return build_starlette_request(asgi_scope, body_bytes) else: arg = request_item.args[0] if len(request_item.args) == 1 else None # If the input data from handle is web request, we don't need to wrap # it in ServeRequest. - if isinstance(arg, flask.Request): + if isinstance(arg, starlette.requests.Request): return arg return ServeRequest(