diff --git a/doc/source/serve/advanced.rst b/doc/source/serve/advanced.rst index 25355a2cf..8b107b8b3 100644 --- a/doc/source/serve/advanced.rst +++ b/doc/source/serve/advanced.rst @@ -411,3 +411,14 @@ you would with HTTP request. Here are some examples how ServeRequest mirrors Fla print("Request coming from web!") elif isinstance(request, ServeRequest): print("Request coming from Python!") + +.. note:: + + Once special case is when you pass a web request to a handle. + + .. code-block:: python + + handle.remote(flask_request) + + In this case, Serve will `not` wrap it in ServeRequest. You can directly + process the request as a ``flask.Request``. diff --git a/python/ray/serve/tests/test_handle.py b/python/ray/serve/tests/test_handle.py index c0be51a06..6e5c91d85 100644 --- a/python/ray/serve/tests/test_handle.py +++ b/python/ray/serve/tests/test_handle.py @@ -84,6 +84,28 @@ def test_handle_http_args(serve_instance): resp["headers"]["X-Custom-Header"] == "value" +def test_handle_inject_flask_request(serve_instance): + client = serve_instance + + def echo_request_type(request): + return str(type(request)) + + client.create_backend("echo:v0", echo_request_type) + client.create_endpoint("echo", backend="echo:v0", route="/echo") + + def wrapper_model(web_request): + handle = serve.connect().get_handle("echo") + return ray.get(handle.remote(web_request)) + + client.create_backend("wrapper:v0", wrapper_model) + client.create_endpoint("wrapper", backend="wrapper:v0", route="/wrapper") + + for route in ["/echo", "/wrapper"]: + resp = requests.get(f"http://127.0.0.1:8000{route}") + request_type = resp.text + assert request_type == "" + + if __name__ == "__main__": import sys import pytest diff --git a/python/ray/serve/utils.py b/python/ray/serve/utils.py index 9c4a028fc..04351a6da 100644 --- a/python/ray/serve/utils.py +++ b/python/ray/serve/utils.py @@ -15,6 +15,7 @@ from collections import UserDict import requests import numpy as np import pydantic +import flask import ray from ray.serve.constants import HTTP_PROXY_TIMEOUT @@ -88,8 +89,15 @@ def parse_request_item(request_item): asgi_scope, body_bytes = request_item.args return build_flask_request(asgi_scope, io.BytesIO(body_bytes)) else: + arg = request_item.args[0] if len(request_item.args) == 1 else None + + # If the input data from handle is web request, we don't need to wrap + # it in ServeRequest. + if isinstance(arg, flask.Request): + return arg + return ServeRequest( - request_item.args[0] if len(request_item.args) == 1 else None, + arg, request_item.kwargs, headers=request_item.metadata.http_headers, method=request_item.metadata.http_method,