mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 20:18:33 +08:00
[Serve] Migrate from Flask.Request to Starlette Request (#12852)
This commit is contained in:
@@ -10,7 +10,7 @@ class Counter:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
|
||||
def __call__(self, flask_request):
|
||||
def __call__(self, starlette_request):
|
||||
self.count += 1
|
||||
return {"current_counter": self.count}
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ ray.init(num_cpus=8)
|
||||
client = serve.start()
|
||||
|
||||
|
||||
def echo(flask_request):
|
||||
return "hello " + flask_request.args.get("name", "serve!")
|
||||
def echo(starlette_request):
|
||||
return "hello " + starlette_request.query_params.get("name", "serve!")
|
||||
|
||||
|
||||
client.create_backend("hello", echo)
|
||||
|
||||
@@ -16,13 +16,13 @@ client = serve.start()
|
||||
|
||||
|
||||
def model_one(request):
|
||||
print("Model 1 called with data ", request.args.get("data"))
|
||||
print("Model 1 called with data ", request.query_params.get("data"))
|
||||
return random()
|
||||
|
||||
|
||||
def model_two(request):
|
||||
print("Model 2 called with data ", request.args.get("data"))
|
||||
return request.args.get("data")
|
||||
print("Model 2 called with data ", request.query_params.get("data"))
|
||||
return request.query_params.get("data")
|
||||
|
||||
|
||||
class ComposedModel:
|
||||
@@ -32,8 +32,8 @@ class ComposedModel:
|
||||
self.model_two = client.get_handle("model_two")
|
||||
|
||||
# This method can be called concurrently!
|
||||
async def __call__(self, flask_request):
|
||||
data = flask_request.data
|
||||
async def __call__(self, starlette_request):
|
||||
data = await starlette_request.body()
|
||||
|
||||
score = await self.model_one.remote(data=data)
|
||||
if score > 0.5:
|
||||
|
||||
@@ -14,8 +14,10 @@ import requests
|
||||
|
||||
# __doc_define_servable_v0_begin__
|
||||
@serve.accept_batch
|
||||
def batch_adder_v0(flask_requests: List):
|
||||
numbers = [int(request.args["number"]) for request in flask_requests]
|
||||
def batch_adder_v0(starlette_requests: List):
|
||||
numbers = [
|
||||
int(request.query_params["number"]) for request in starlette_requests
|
||||
]
|
||||
|
||||
input_array = np.array(numbers)
|
||||
print("Our input array has shape:", input_array.shape)
|
||||
@@ -58,7 +60,7 @@ print("Result returned:", results)
|
||||
# __doc_define_servable_v1_begin__
|
||||
@serve.accept_batch
|
||||
def batch_adder_v1(requests: List):
|
||||
numbers = [int(request.args["number"]) for request in requests]
|
||||
numbers = [int(request.query_params["number"]) for request in requests]
|
||||
input_array = np.array(numbers)
|
||||
print("Our input array has shape:", input_array.shape)
|
||||
# Sleep for 200ms, this could be performing CPU intensive computation
|
||||
|
||||
@@ -48,9 +48,9 @@ class BoostingModel:
|
||||
with open("/tmp/iris_labels.json") as f:
|
||||
self.label_list = json.load(f)
|
||||
|
||||
def __call__(self, flask_request):
|
||||
payload = flask_request.json
|
||||
print("Worker: received flask request with data", payload)
|
||||
async def __call__(self, starlette_request):
|
||||
payload = await starlette_request.json()
|
||||
print("Worker: received starlette request with data", payload)
|
||||
|
||||
input_vector = [
|
||||
payload["sepal length"],
|
||||
@@ -143,9 +143,9 @@ class BoostingModelv2:
|
||||
with open("/tmp/iris_labels_2.json") as f:
|
||||
self.label_list = json.load(f)
|
||||
|
||||
def __call__(self, flask_request):
|
||||
payload = flask_request.json
|
||||
print("Worker: received flask request with data", payload)
|
||||
async def __call__(self, starlette_request):
|
||||
payload = await starlette_request.json()
|
||||
print("Worker: received starlette request with data", payload)
|
||||
|
||||
input_vector = [
|
||||
payload["sepal length"],
|
||||
|
||||
@@ -27,8 +27,8 @@ class ImageModel:
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
|
||||
def __call__(self, flask_request):
|
||||
image_payload_bytes = flask_request.data
|
||||
async def __call__(self, starlette_request):
|
||||
image_payload_bytes = await starlette_request.body()
|
||||
pil_image = Image.open(BytesIO(image_payload_bytes))
|
||||
print("[1/3] Parsed image data: {}".format(pil_image))
|
||||
|
||||
|
||||
@@ -54,9 +54,9 @@ class BoostingModel:
|
||||
with open(LABEL_PATH) as f:
|
||||
self.label_list = json.load(f)
|
||||
|
||||
def __call__(self, flask_request):
|
||||
payload = flask_request.json
|
||||
print("Worker: received flask request with data", payload)
|
||||
async def __call__(self, starlette_request):
|
||||
payload = await starlette_request.json()
|
||||
print("Worker: received starlette request with data", payload)
|
||||
|
||||
input_vector = [
|
||||
payload["sepal length"],
|
||||
|
||||
@@ -51,10 +51,10 @@ class TFMnistModel:
|
||||
self.model_path = model_path
|
||||
self.model = tf.keras.models.load_model(model_path)
|
||||
|
||||
def __call__(self, flask_request):
|
||||
async def __call__(self, starlette_request):
|
||||
# Step 1: transform HTTP request -> tensorflow input
|
||||
# Here we define the request schema to be a json array.
|
||||
input_array = np.array(flask_request.json["array"])
|
||||
input_array = np.array((await starlette_request.json())["array"])
|
||||
reshaped_array = input_array.reshape((1, 28, 28))
|
||||
|
||||
# Step 2: tensorflow input -> tensorflow output
|
||||
|
||||
@@ -9,8 +9,8 @@ import requests
|
||||
from ray import serve
|
||||
|
||||
|
||||
def echo(flask_request):
|
||||
return ["hello " + flask_request.args.get("name", "serve!")]
|
||||
def echo(starlette_request):
|
||||
return ["hello " + starlette_request.query_params.get("name", "serve!")]
|
||||
|
||||
|
||||
client = serve.start()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
Example actor that adds an increment to a number. This number can
|
||||
come from either web (parsing Flask request) or python call.
|
||||
come from either web (parsing Starlette request) or python call.
|
||||
|
||||
This actor can be called from HTTP as well as from Python.
|
||||
"""
|
||||
@@ -30,9 +30,10 @@ class MagicCounter:
|
||||
def __init__(self, increment):
|
||||
self.increment = increment
|
||||
|
||||
def __call__(self, flask_request, base_number=None):
|
||||
def __call__(self, starlette_request, base_number=None):
|
||||
if serve.context.web:
|
||||
base_number = int(flask_request.args.get("base_number", "0"))
|
||||
base_number = int(
|
||||
starlette_request.query_params.get("base_number", "0"))
|
||||
return base_number + self.increment
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
Example actor that adds an increment to a number. This number can
|
||||
come from either web (parsing Flask request) or python call.
|
||||
come from either web (parsing Starlette request) or python call.
|
||||
The queries incoming to this actor are batched.
|
||||
This actor can be called from HTTP as well as from Python.
|
||||
"""
|
||||
@@ -31,12 +31,13 @@ class MagicCounter:
|
||||
self.increment = increment
|
||||
|
||||
@serve.accept_batch
|
||||
def __call__(self, flask_request_list, base_number=None):
|
||||
def __call__(self, starlette_request_list, base_number=None):
|
||||
# batch_size = serve.context.batch_size
|
||||
if serve.context.web:
|
||||
result = []
|
||||
for flask_request in flask_request_list:
|
||||
base_number = int(flask_request.args.get("base_number", "0"))
|
||||
for starlette_request in starlette_request_list:
|
||||
base_number = int(
|
||||
starlette_request.query_params.get("base_number", "0"))
|
||||
result.append(base_number)
|
||||
return list(map(lambda x: x + self.increment, result))
|
||||
else:
|
||||
|
||||
@@ -11,7 +11,7 @@ class MagicCounter:
|
||||
self.increment = increment
|
||||
|
||||
@serve.accept_batch
|
||||
def __call__(self, flask_request, base_number=None):
|
||||
def __call__(self, starlette_request, base_number=None):
|
||||
# __call__ fn should preserve the batch size
|
||||
# base_number is a python list
|
||||
|
||||
|
||||
@@ -12,8 +12,8 @@ client = serve.start()
|
||||
|
||||
# a backend can be a function or class.
|
||||
# it can be made to be invoked from web as well as python.
|
||||
def echo_v1(flask_request):
|
||||
response = flask_request.args.get("response", "web")
|
||||
def echo_v1(starlette_request):
|
||||
response = starlette_request.query_params.get("response", "web")
|
||||
return response
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ print(ray.get(client.get_handle("my_endpoint").remote(response="hello")))
|
||||
|
||||
|
||||
# We can also add a new backend and split the traffic.
|
||||
def echo_v2(flask_request):
|
||||
def echo_v2(starlette_request):
|
||||
# magic, only from web.
|
||||
return "something new"
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ class RayServeHandle:
|
||||
|
||||
def remote(self, request_data: Optional[Union[Dict, Any]] = None,
|
||||
**kwargs):
|
||||
"""Issue an asynchrounous request to the endpoint.
|
||||
"""Issue an asynchronous request to the endpoint.
|
||||
|
||||
Returns a Ray ObjectRef whose results can be waited for or retrieved
|
||||
using ray.wait or ray.get, respectively.
|
||||
@@ -106,9 +106,9 @@ class RayServeHandle:
|
||||
Args:
|
||||
request_data(dict, Any): If it's a dictionary, the data will be
|
||||
available in ``request.json()`` or ``request.form()``.
|
||||
Otherwise, it will be available in ``request.data``.
|
||||
Otherwise, it will be available in ``request.body()``.
|
||||
``**kwargs``: All keyword arguments will be available in
|
||||
``request.args``.
|
||||
``request.query_params``.
|
||||
"""
|
||||
if not self.sync:
|
||||
raise RayServeException(
|
||||
|
||||
@@ -1,76 +1,25 @@
|
||||
import io
|
||||
import json
|
||||
|
||||
import flask
|
||||
import starlette.requests
|
||||
|
||||
|
||||
def build_flask_request(asgi_scope_dict, request_body):
|
||||
"""Build and return a flask request from ASGI payload
|
||||
def build_starlette_request(scope, serialized_body: bytes):
|
||||
"""Build and return a Starlette Request from ASGI payload.
|
||||
|
||||
This function is indented to be used immediately before task invocation
|
||||
happen.
|
||||
This function is intended to be used immediately before task invocation
|
||||
happens.
|
||||
"""
|
||||
wsgi_environ = build_wsgi_environ(asgi_scope_dict, request_body)
|
||||
# We set populate_request=False to prevent self reference, which can lead
|
||||
# to objects tracked by python garbage collector and memory growth. See
|
||||
# https://github.com/ray-project/ray/issues/12395.
|
||||
return flask.Request(wsgi_environ, populate_request=False)
|
||||
|
||||
# Simulates receiving HTTP body from TCP socket. In reality, the body has
|
||||
# already been streamed in chunks and stored in serialized_body.
|
||||
async def mock_receive():
|
||||
return {
|
||||
"body": serialized_body,
|
||||
"type": "http.request",
|
||||
"more_body": False
|
||||
}
|
||||
|
||||
def build_wsgi_environ(scope, body):
|
||||
"""
|
||||
Builds a scope and request body into a WSGI environ object.
|
||||
|
||||
This code snippet is taken from https://github.com/django/asgiref/blob
|
||||
/36c3e8dc70bf38fe2db87ac20b514f21aaf5ea9d/asgiref/wsgi.py#L52
|
||||
|
||||
WSGI specification can be found at
|
||||
https://www.python.org/dev/peps/pep-0333/
|
||||
|
||||
This function helps translate ASGI scope and body into a flask request.
|
||||
"""
|
||||
environ = {
|
||||
"REQUEST_METHOD": scope["method"],
|
||||
"SCRIPT_NAME": scope.get("root_path", ""),
|
||||
"PATH_INFO": scope["path"],
|
||||
"QUERY_STRING": scope["query_string"].decode("ascii"),
|
||||
"SERVER_PROTOCOL": "HTTP/{}".format(scope["http_version"]),
|
||||
"wsgi.version": (1, 0),
|
||||
"wsgi.url_scheme": scope.get("scheme", "http"),
|
||||
"wsgi.input": body,
|
||||
"wsgi.errors": io.BytesIO(),
|
||||
"wsgi.multithread": True,
|
||||
"wsgi.multiprocess": True,
|
||||
"wsgi.run_once": False,
|
||||
}
|
||||
|
||||
# Get server name and port - required in WSGI, not in ASGI
|
||||
environ["SERVER_NAME"] = scope["server"][0]
|
||||
environ["SERVER_PORT"] = str(scope["server"][1])
|
||||
environ["REMOTE_ADDR"] = scope["client"][0]
|
||||
|
||||
# Transforms headers into environ entries.
|
||||
for name, value in scope.get("headers", []):
|
||||
# name, values are both bytes, we need to decode them to string
|
||||
name = name.decode("latin1")
|
||||
value = value.decode("latin1")
|
||||
|
||||
# Handle name correction to conform to WSGI spec
|
||||
# https://www.python.org/dev/peps/pep-0333/#environ-variables
|
||||
if name == "content-length":
|
||||
corrected_name = "CONTENT_LENGTH"
|
||||
elif name == "content-type":
|
||||
corrected_name = "CONTENT_TYPE"
|
||||
else:
|
||||
corrected_name = "HTTP_%s" % name.upper().replace("-", "_")
|
||||
|
||||
# If the header value repeated,
|
||||
# we will just concatenate it to the field.
|
||||
if corrected_name in environ:
|
||||
value = environ[corrected_name] + "," + value
|
||||
|
||||
environ[corrected_name] = value
|
||||
return environ
|
||||
return starlette.requests.Request(scope, mock_receive)
|
||||
|
||||
|
||||
class Response:
|
||||
|
||||
@@ -19,8 +19,8 @@ from ray.serve.utils import (block_until_http_ready, format_actor_name,
|
||||
def test_e2e(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
def function(flask_request):
|
||||
return {"method": flask_request.method}
|
||||
def function(starlette_request):
|
||||
return {"method": starlette_request.method}
|
||||
|
||||
client.create_backend("echo:v1", function)
|
||||
client.create_endpoint(
|
||||
@@ -97,7 +97,7 @@ def test_backend_user_config(serve_instance):
|
||||
def __init__(self):
|
||||
self.count = 10
|
||||
|
||||
def __call__(self, flask_request):
|
||||
def __call__(self, starlette_request):
|
||||
return self.count, os.getpid()
|
||||
|
||||
def reconfigure(self, config):
|
||||
@@ -820,8 +820,8 @@ def test_serve_metrics(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
@serve.accept_batch
|
||||
def batcher(flask_requests):
|
||||
return ["hello"] * len(flask_requests)
|
||||
def batcher(starlette_requests):
|
||||
return ["hello"] * len(starlette_requests)
|
||||
|
||||
client.create_backend("metrics", batcher)
|
||||
client.create_endpoint("metrics", backend="metrics", route="/metrics")
|
||||
@@ -871,6 +871,26 @@ def test_serve_metrics(serve_instance):
|
||||
verify_metrics()
|
||||
|
||||
|
||||
def test_starlette_request(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
async def echo_body(starlette_request):
|
||||
data = await starlette_request.body()
|
||||
return data
|
||||
|
||||
UVICORN_HIGH_WATER_MARK = 65536 # max bytes in one message
|
||||
|
||||
# Long string to test serialization of multiple messages.
|
||||
long_string = "x" * 10 * UVICORN_HIGH_WATER_MARK
|
||||
|
||||
client.create_backend("echo:v1", echo_body)
|
||||
client.create_endpoint(
|
||||
"endpoint", backend="echo:v1", route="/api", methods=["GET", "POST"])
|
||||
|
||||
resp = requests.post("http://127.0.0.1:8000/api", data=long_string).text
|
||||
assert resp == long_string
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", "-s", __file__]))
|
||||
|
||||
@@ -85,7 +85,7 @@ async def test_runner_wraps_error():
|
||||
async def test_servable_function(serve_instance, router,
|
||||
mock_controller_with_name):
|
||||
def echo(request):
|
||||
return request.args["i"]
|
||||
return request.query_params["i"]
|
||||
|
||||
await add_servable_to_router(echo, router, mock_controller_with_name[0])
|
||||
|
||||
@@ -103,7 +103,7 @@ async def test_servable_class(serve_instance, router,
|
||||
self.increment = inc
|
||||
|
||||
def __call__(self, request):
|
||||
return request.args["i"] + self.increment
|
||||
return request.query_params["i"] + self.increment
|
||||
|
||||
await add_servable_to_router(
|
||||
MyAdder, router, mock_controller_with_name[0], init_args=(3, ))
|
||||
@@ -277,7 +277,7 @@ async def test_user_config_update(serve_instance, router,
|
||||
def __init__(self):
|
||||
self.reval = ""
|
||||
|
||||
def __call__(self, flask_request):
|
||||
def __call__(self, starlette_request):
|
||||
return self.retval
|
||||
|
||||
def reconfigure(self, config):
|
||||
|
||||
@@ -8,7 +8,7 @@ def test_handle_in_endpoint(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
class Endpoint1:
|
||||
def __call__(self, flask_request):
|
||||
def __call__(self, starlette_request):
|
||||
return "hello"
|
||||
|
||||
class Endpoint2:
|
||||
@@ -40,12 +40,12 @@ def test_handle_http_args(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
class Endpoint:
|
||||
def __call__(self, request):
|
||||
async def __call__(self, request):
|
||||
return {
|
||||
"args": dict(request.args),
|
||||
"args": dict(request.query_params),
|
||||
"headers": dict(request.headers),
|
||||
"method": request.method,
|
||||
"json": request.json
|
||||
"json": await request.json()
|
||||
}
|
||||
|
||||
client.create_backend("backend", Endpoint)
|
||||
@@ -58,7 +58,7 @@ def test_handle_http_args(serve_instance):
|
||||
"arg2": "2"
|
||||
},
|
||||
"headers": {
|
||||
"X-Custom-Header": "value"
|
||||
"x-custom-header": "value"
|
||||
},
|
||||
"method": "POST",
|
||||
"json": {
|
||||
@@ -81,10 +81,10 @@ def test_handle_http_args(serve_instance):
|
||||
for resp in [resp_web, resp_handle]:
|
||||
for field in ["args", "method", "json"]:
|
||||
assert resp[field] == ground_truth[field]
|
||||
resp["headers"]["X-Custom-Header"] == "value"
|
||||
resp["headers"]["x-custom-header"] == "value"
|
||||
|
||||
|
||||
def test_handle_inject_flask_request(serve_instance):
|
||||
def test_handle_inject_starlette_request(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
def echo_request_type(request):
|
||||
@@ -103,7 +103,7 @@ def test_handle_inject_flask_request(serve_instance):
|
||||
for route in ["/echo", "/wrapper"]:
|
||||
resp = requests.get(f"http://127.0.0.1:8000{route}")
|
||||
request_type = resp.text
|
||||
assert request_type == "<class 'flask.wrappers.Request'>"
|
||||
assert request_type == "<class 'starlette.requests.Request'>"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -12,7 +12,7 @@ ray.init(address="{}")
|
||||
from ray import serve
|
||||
client = serve.connect()
|
||||
|
||||
def driver(flask_request):
|
||||
def driver(starlette_request):
|
||||
return "OK!"
|
||||
|
||||
client.create_backend("driver", driver)
|
||||
|
||||
@@ -15,7 +15,7 @@ def test_np_in_composed_model(serve_instance):
|
||||
# in cloudpickle _from_numpy_buffer
|
||||
|
||||
def sum_model(request):
|
||||
return np.sum(request.args["data"])
|
||||
return np.sum(request.query_params["data"])
|
||||
|
||||
class ComposedModel:
|
||||
def __init__(self):
|
||||
@@ -42,7 +42,7 @@ def test_backend_worker_memory_growth(serve_instance):
|
||||
# https://github.com/ray-project/ray/issues/12395
|
||||
client = serve_instance
|
||||
|
||||
def gc_unreachable_objects(flask_request):
|
||||
def gc_unreachable_objects(starlette_request):
|
||||
gc.set_debug(gc.DEBUG_SAVEALL)
|
||||
gc.collect()
|
||||
return len(gc.garbage)
|
||||
|
||||
+16
-17
@@ -8,7 +8,6 @@ import random
|
||||
import string
|
||||
import time
|
||||
from typing import List, Dict
|
||||
import io
|
||||
import os
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from collections import UserDict
|
||||
@@ -16,18 +15,18 @@ from collections import UserDict
|
||||
import requests
|
||||
import numpy as np
|
||||
import pydantic
|
||||
import flask
|
||||
import starlette.requests
|
||||
|
||||
import ray
|
||||
from ray.serve.constants import HTTP_PROXY_TIMEOUT
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.serve.http_util import build_flask_request
|
||||
from ray.serve.http_util import build_starlette_request
|
||||
|
||||
ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
|
||||
|
||||
|
||||
class ServeMultiDict(UserDict):
|
||||
"""Compatible data structure to simulate Flask.Request.args API."""
|
||||
"""Compatible data structure to simulate Starlette Request query_args."""
|
||||
|
||||
def getlist(self, key):
|
||||
"""Return the list of items for a given key."""
|
||||
@@ -35,11 +34,14 @@ class ServeMultiDict(UserDict):
|
||||
|
||||
|
||||
class ServeRequest:
|
||||
"""The request object used in Python context.
|
||||
"""The request object used when passing arguments via ServeHandle.
|
||||
|
||||
ServeRequest is built to have similar API as Flask.Request. You only need
|
||||
to write your model serving code once; it can be queried by both HTTP and
|
||||
Python.
|
||||
ServeRequest partially implements the API of Starlette Request. You only
|
||||
need to write your model serving code once; it can be queried by both HTTP
|
||||
and Python.
|
||||
|
||||
To use the full Starlette Request interface with ServeHandle, you may
|
||||
instead directly pass in a Starlette Request object to the ServeHandle.
|
||||
"""
|
||||
|
||||
def __init__(self, data, kwargs, headers, method):
|
||||
@@ -59,28 +61,25 @@ class ServeRequest:
|
||||
return self._method
|
||||
|
||||
@property
|
||||
def args(self):
|
||||
def query_params(self):
|
||||
"""The keyword arguments from ``handle.remote(**kwargs)``."""
|
||||
return self._kwargs
|
||||
|
||||
@property
|
||||
def json(self):
|
||||
async def json(self):
|
||||
"""The request dictionary, from ``handle.remote(dict)``."""
|
||||
if not isinstance(self._data, dict):
|
||||
raise RayServeException("Request data is not a dictionary. "
|
||||
f"It is {type(self._data)}.")
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def form(self):
|
||||
async def form(self):
|
||||
"""The request dictionary, from ``handle.remote(dict)``."""
|
||||
if not isinstance(self._data, dict):
|
||||
raise RayServeException("Request data is not a dictionary. "
|
||||
f"It is {type(self._data)}.")
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
async def body(self):
|
||||
"""The request data from ``handle.remote(obj)``."""
|
||||
return self._data
|
||||
|
||||
@@ -88,13 +87,13 @@ class ServeRequest:
|
||||
def parse_request_item(request_item):
|
||||
if request_item.metadata.request_context == TaskContext.Web:
|
||||
asgi_scope, body_bytes = request_item.args
|
||||
return build_flask_request(asgi_scope, io.BytesIO(body_bytes))
|
||||
return build_starlette_request(asgi_scope, body_bytes)
|
||||
else:
|
||||
arg = request_item.args[0] if len(request_item.args) == 1 else None
|
||||
|
||||
# If the input data from handle is web request, we don't need to wrap
|
||||
# it in ServeRequest.
|
||||
if isinstance(arg, flask.Request):
|
||||
if isinstance(arg, starlette.requests.Request):
|
||||
return arg
|
||||
|
||||
return ServeRequest(
|
||||
|
||||
Reference in New Issue
Block a user