mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 15:06:28 +08:00
[Serve] Added support for composing arbitrary DAGs (#7015)
This commit is contained in:
@@ -417,11 +417,15 @@ def split(endpoint_name, traffic_policy_dictionary):
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def get_handle(endpoint_name):
|
||||
def get_handle(endpoint_name, relative_slo_ms=None, absolute_slo_ms=None):
|
||||
"""Retrieve RayServeHandle for service endpoint to invoke it from Python.
|
||||
|
||||
Args:
|
||||
endpoint_name (str): A registered service endpoint.
|
||||
relative_slo_ms(float): Specify relative deadline in milliseconds for
|
||||
queries fired using this handle. (Default: None)
|
||||
absolute_slo_ms(float): Specify absolute deadline in milliseconds for
|
||||
queries fired using this handle. (Default: None)
|
||||
|
||||
Returns:
|
||||
RayServeHandle
|
||||
@@ -431,7 +435,8 @@ def get_handle(endpoint_name):
|
||||
# Delay import due to its dependency on global_state
|
||||
from ray.experimental.serve.handle import RayServeHandle
|
||||
|
||||
return RayServeHandle(global_state.init_or_get_router(), endpoint_name)
|
||||
return RayServeHandle(global_state.init_or_get_router(), endpoint_name,
|
||||
relative_slo_ms, absolute_slo_ms)
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
Ray serve pipeline example
|
||||
"""
|
||||
import ray
|
||||
import ray.experimental.serve as serve
|
||||
import time
|
||||
|
||||
# Initialize the ray serve system.
# blocking=True will wait for the HTTP server to be ready to serve requests.
serve.init(blocking=True)


# A backend can be a function or a class.
# It can be invoked from the web as well as from Python.
@serve.route("/echo_v1")
def echo_v1(_, response="hello from python!"):
    """Wrap ``response`` in an ``echo_v1(...)`` marker."""
    return f"echo_v1({response})"


@serve.route("/echo_v2")
def echo_v2(_, relay=""):
    """Wrap the relayed value in an ``echo_v2(...)`` marker."""
    return f"echo_v2({relay})"


@serve.route("/echo_v3")
def echo_v3(_, relay=""):
    """Wrap the relayed value in an ``echo_v3(...)`` marker."""
    return f"echo_v3({relay})"


@serve.route("/echo_v4")
def echo_v4(_, relay1="", relay2=""):
    """Join the two relayed values inside an ``echo_v4(...)`` marker."""
    return f"echo_v4({relay1} , {relay2})"


# The pipeline created is as follows -
#
#                "my_endpoint1"
#                      /\
#                     /  \
#                    /    \
#                   /      \
#                  /        \
#                 /          \
#     "my_endpoint2"      "my_endpoint3"
#                \            /
#                 \          /
#                  \        /
#                   \      /
#                    \    /
#                     \  /
#                      \/
#                "my_endpoint4"
#
# The diagram is kept as comments rather than a bare string literal because
# the "\ " and "\/" sequences are invalid escapes in a non-raw string.
# In the code below, the result of handle1 is fed to both handle2 and
# handle3, and their two results are fed to handle4.

# get the handle of the endpoints
handle1 = serve.get_handle("echo_v1")
handle2 = serve.get_handle("echo_v2")
handle3 = serve.get_handle("echo_v3")
handle4 = serve.get_handle("echo_v4")

start = time.time()
print("Start firing to the pipeline: {} s".format(time.time()))

# The object id returned by handle1 is passed straight into the downstream
# calls; nothing is fetched with ray.get until the final result.
handle1_oid = handle1.remote(response="hello")
handle4_oid = handle4.remote(
    relay1=handle2.remote(relay=handle1_oid),
    relay2=handle3.remote(relay=handle1_oid))
print("Firing ended now waiting for the result, "
      "time taken: {} s".format(time.time() - start))

result = ray.get(handle4_oid)
print("Result: {}, time taken: {} s".format(result, time.time() - start))
|
||||
@@ -31,11 +31,20 @@ serve.link("my_endpoint", "echo:v1")
|
||||
# wait for routing table to get populated
|
||||
time.sleep(2)
|
||||
|
||||
# slo (10 milliseconds deadline) can be specified via http
|
||||
# relative slo (10 ms deadline) can be specified via http
|
||||
slo_ms = 10.0
|
||||
print("> [HTTP] Pinging http://127.0.0.1:8000/echo?slo_ms={}".format(slo_ms))
|
||||
# absolute slo (10 ms deadline) can be specified via http
|
||||
abs_slo_ms = 11.9
|
||||
print("> [HTTP] Pinging http://127.0.0.1:8000/"
|
||||
"echo?relative_slo_ms={}".format(slo_ms))
|
||||
print(
|
||||
requests.get("http://127.0.0.1:8000/echo?slo_ms={}".format(slo_ms)).json())
|
||||
requests.get("http://127.0.0.1:8000/"
|
||||
"echo?relative_slo_ms={}".format(slo_ms)).json())
|
||||
print("> [HTTP] Pinging http://127.0.0.1:8000/"
|
||||
"echo?absolute_slo_ms={}".format(abs_slo_ms))
|
||||
print(
|
||||
requests.get("http://127.0.0.1:8000/"
|
||||
"echo?absolute_slo_ms={}".format(abs_slo_ms)).json())
|
||||
|
||||
# get the handle of the endpoint
|
||||
handle = serve.get_handle("my_endpoint")
|
||||
@@ -49,8 +58,11 @@ for r in range(10):
|
||||
response = "hello from request: {} slo: {}".format(r, slo_ms)
|
||||
print("> [REMOTE] Pinging handle.remote(response='{}',slo_ms={})".format(
|
||||
response, slo_ms))
|
||||
# slo can be specified via remote function
|
||||
f = handle.remote(response=response, slo_ms=slo_ms)
|
||||
|
||||
# overriding slo for each query.
|
||||
# Generally slo is specified for a service handle but it can
|
||||
# be overridden using options for query specific demands
|
||||
f = handle.options(relative_slo_ms=slo_ms).remote(response=response)
|
||||
future_list.append(f)
|
||||
|
||||
# get results of queries as they complete
|
||||
|
||||
@@ -2,6 +2,7 @@ from ray.experimental import serve
|
||||
from ray.experimental.serve.context import TaskContext
|
||||
from ray.experimental.serve.exceptions import RayServeException
|
||||
from ray.experimental.serve.constants import DEFAULT_HTTP_ADDRESS
|
||||
from ray.experimental.serve.request_params import RequestMetadata
|
||||
|
||||
|
||||
class RayServeHandle:
|
||||
@@ -26,33 +27,55 @@ class RayServeHandle:
|
||||
# raises RayTaskError Exception
|
||||
"""
|
||||
|
||||
def __init__(self, router_handle, endpoint_name):
|
||||
def __init__(self,
|
||||
router_handle,
|
||||
endpoint_name,
|
||||
relative_slo_ms=None,
|
||||
absolute_slo_ms=None):
|
||||
self.router_handle = router_handle
|
||||
self.endpoint_name = endpoint_name
|
||||
assert (relative_slo_ms is None
|
||||
or absolute_slo_ms is None), ("Can't specify both "
|
||||
"relative and absolute "
|
||||
"slo's together!")
|
||||
self.relative_slo_ms = self._check_slo_ms(relative_slo_ms)
|
||||
self.absolute_slo_ms = self._check_slo_ms(absolute_slo_ms)
|
||||
|
||||
def _check_slo_ms(self, slo_value):
|
||||
if slo_value is not None:
|
||||
try:
|
||||
slo_value = float(slo_value)
|
||||
if slo_value < 0:
|
||||
raise ValueError(
|
||||
"Request SLO must be positive, it is {}".format(
|
||||
slo_value))
|
||||
return slo_value
|
||||
except ValueError as e:
|
||||
raise RayServeException(str(e))
|
||||
return None
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
if len(args) != 0:
|
||||
raise RayServeException(
|
||||
"handle.remote must be invoked with keyword arguments.")
|
||||
|
||||
# get slo_ms before enqueuing the query
|
||||
request_slo_ms = kwargs.pop("slo_ms", None)
|
||||
if request_slo_ms is not None:
|
||||
try:
|
||||
request_slo_ms = float(request_slo_ms)
|
||||
if request_slo_ms < 0:
|
||||
raise ValueError(
|
||||
"Request SLO must be positive, it is {}".format(
|
||||
request_slo_ms))
|
||||
except ValueError as e:
|
||||
raise RayServeException(str(e))
|
||||
|
||||
# create RequestMetadata instance
|
||||
request_in_object = RequestMetadata(
|
||||
self.endpoint_name, TaskContext.Python, self.relative_slo_ms,
|
||||
self.absolute_slo_ms)
|
||||
return self.router_handle.enqueue_request.remote(
|
||||
service=self.endpoint_name,
|
||||
request_args=(),
|
||||
request_kwargs=kwargs,
|
||||
request_context=TaskContext.Python,
|
||||
request_slo_ms=request_slo_ms)
|
||||
request_in_object, **kwargs)
|
||||
|
||||
def options(self, relative_slo_ms=None, absolute_slo_ms=None):
    """Return a new handle for the same endpoint with the given SLOs.

    The new handle shares this handle's router and endpoint name but takes
    exactly the SLO values passed here; the SLOs stored on this handle are
    not carried over.

    Args:
        relative_slo_ms: Relative deadline in milliseconds, or None.
        absolute_slo_ms: Absolute deadline in milliseconds, or None.

    Returns:
        RayServeHandle
    """
    # If both SLOs are None then we use a high default value so that
    # other queries can be prioritized and put in front of these queries.
    assert (relative_slo_ms is None
            or absolute_slo_ms is None), ("Can't specify both "
                                          "relative and absolute "
                                          "slo's together!")
    return RayServeHandle(self.router_handle, self.endpoint_name,
                          relative_slo_ms, absolute_slo_ms)
|
||||
|
||||
def get_traffic_policy(self):
|
||||
# TODO(simon): This method is implemented via checking global state
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import asyncio
|
||||
import copy
|
||||
from collections import defaultdict
|
||||
import time
|
||||
from typing import DefaultDict, Union, List
|
||||
from typing import DefaultDict, List
|
||||
import pickle
|
||||
|
||||
# Note on choosing blist instead of stdlib heapq
|
||||
@@ -15,7 +14,6 @@ import blist
|
||||
|
||||
import ray
|
||||
from ray.experimental.serve.utils import logger
|
||||
from ray.experimental.serve.constants import DEFAULT_LATENCY_SLO_MS
|
||||
|
||||
|
||||
class Query:
|
||||
@@ -56,20 +54,6 @@ class Query:
|
||||
self.request_kwargs)
|
||||
|
||||
|
||||
def _adjust_latency_slo(slo_ms: Union[float, int, None]) -> float:
|
||||
"""Normalize the input latency objective to absoluate timestamp.
|
||||
|
||||
Input:
|
||||
slo_ms(float, int, None): If value is None, then we use a high default
|
||||
value so other queries can be prioritize and put in front of these
|
||||
queries.
|
||||
"""
|
||||
if slo_ms is None:
|
||||
slo_ms = DEFAULT_LATENCY_SLO_MS
|
||||
current_time_ms = time.time() * 1000
|
||||
return current_time_ms + slo_ms
|
||||
|
||||
|
||||
def _make_future_unwrapper(client_futures: List[asyncio.Future],
|
||||
host_future: asyncio.Future):
|
||||
"""Distribute the result of host_future to each of client_future"""
|
||||
@@ -179,15 +163,18 @@ class CentralizedQueues:
|
||||
for backend_name, queue in self.buffer_queues.items()
|
||||
}
|
||||
|
||||
async def enqueue_request(self,
|
||||
service,
|
||||
request_args,
|
||||
request_kwargs,
|
||||
request_context,
|
||||
request_slo_ms=None):
|
||||
async def enqueue_request(self, request_in_object, *request_args,
|
||||
**request_kwargs):
|
||||
service = request_in_object.service
|
||||
logger.debug("Received a request for service {}".format(service))
|
||||
|
||||
request_slo_ms = _adjust_latency_slo(request_slo_ms)
|
||||
# check if the slo specified is directly the
|
||||
# wall clock time
|
||||
if request_in_object.absolute_slo_ms is not None:
|
||||
request_slo_ms = request_in_object.absolute_slo_ms
|
||||
else:
|
||||
request_slo_ms = request_in_object.adjust_relative_slo_ms()
|
||||
request_context = request_in_object.request_context
|
||||
query = Query(request_args, request_kwargs, request_context,
|
||||
request_slo_ms)
|
||||
await self.service_queues[service].put(query)
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
import time
|
||||
from ray.experimental.serve.constants import DEFAULT_LATENCY_SLO_MS
|
||||
|
||||
|
||||
class RequestMetadata:
    """
    Request arguments required for enqueuing a request to the service
    queue.

    Args:
        service(str): A registered service endpoint.
        request_context(TaskContext): Context of a request.
        relative_slo_ms(float): Deadline in milliseconds, relative to the
            moment `adjust_relative_slo_ms` is called. (Default: None)
        absolute_slo_ms(float): Absolute deadline in milliseconds. It is
            stored as given and never adjusted by this class; presumably
            it is on the same scale as `time.time() * 1000`.
            (Default: None)
    """

    def __init__(self,
                 service,
                 request_context,
                 relative_slo_ms=None,
                 absolute_slo_ms=None):

        self.service = service
        self.request_context = request_context
        self.relative_slo_ms = relative_slo_ms
        self.absolute_slo_ms = absolute_slo_ms

    def adjust_relative_slo_ms(self) -> float:
        """Normalize the relative latency objective to an absolute timestamp.

        Returns:
            The current time in milliseconds plus `relative_slo_ms`. When
            `relative_slo_ms` is None, DEFAULT_LATENCY_SLO_MS is used in
            its place.
        """
        slo_ms = self.relative_slo_ms
        if slo_ms is None:
            slo_ms = DEFAULT_LATENCY_SLO_MS
        current_time_ms = time.time() * 1000
        return current_time_ms + slo_ms
|
||||
@@ -8,6 +8,8 @@ from ray.experimental.async_api import _async_init
|
||||
from ray.experimental.serve.constants import HTTP_ROUTER_CHECKER_INTERVAL_S
|
||||
from ray.experimental.serve.context import TaskContext
|
||||
from ray.experimental.serve.utils import BytesEncoder
|
||||
from ray.experimental.serve.request_params import RequestMetadata
|
||||
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
|
||||
@@ -103,6 +105,20 @@ class HTTPProxy:
|
||||
|
||||
return b"".join(body_buffer)
|
||||
|
||||
def _check_slo_ms(self, request_slo_ms):
|
||||
if request_slo_ms is not None:
|
||||
if len(request_slo_ms) != 1:
|
||||
raise ValueError(
|
||||
"Multiple SLO specified, please specific only one.")
|
||||
request_slo_ms = request_slo_ms[0]
|
||||
request_slo_ms = float(request_slo_ms)
|
||||
if request_slo_ms < 0:
|
||||
raise ValueError(
|
||||
"Request SLO must be positive, it is {}".format(
|
||||
request_slo_ms))
|
||||
return request_slo_ms
|
||||
return None
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
# NOTE: This implements ASGI protocol specified in
|
||||
# https://asgi.readthedocs.io/en/latest/specs/index.html
|
||||
@@ -134,29 +150,32 @@ class HTTPProxy:
|
||||
# get slo_ms before enqueuing the query
|
||||
query_string = scope["query_string"].decode("ascii")
|
||||
query_kwargs = parse_qs(query_string)
|
||||
request_slo_ms = query_kwargs.pop("slo_ms", None)
|
||||
if request_slo_ms is not None:
|
||||
try:
|
||||
if len(request_slo_ms) != 1:
|
||||
raise ValueError(
|
||||
"Multiple SLO specified, please specific only one.")
|
||||
request_slo_ms = request_slo_ms[0]
|
||||
request_slo_ms = float(request_slo_ms)
|
||||
if request_slo_ms < 0:
|
||||
raise ValueError(
|
||||
"Request SLO must be positive, it is {}".format(
|
||||
request_slo_ms))
|
||||
except ValueError as e:
|
||||
await JSONResponse({"error": str(e)})(scope, receive, send)
|
||||
return
|
||||
relative_slo_ms = query_kwargs.pop("relative_slo_ms", None)
|
||||
absolute_slo_ms = query_kwargs.pop("absolute_slo_ms", None)
|
||||
try:
|
||||
relative_slo_ms = self._check_slo_ms(relative_slo_ms)
|
||||
absolute_slo_ms = self._check_slo_ms(absolute_slo_ms)
|
||||
if relative_slo_ms is not None and absolute_slo_ms is not None:
|
||||
raise ValueError("Both relative and absolute slo's"
|
||||
"cannot be specified.")
|
||||
except ValueError as e:
|
||||
await JSONResponse({"error": str(e)})(scope, receive, send)
|
||||
return
|
||||
|
||||
# create objects necessary for enqueue
|
||||
# enclosing http_body_bytes to list due to
|
||||
# https://github.com/ray-project/ray/issues/6944
|
||||
# TODO(alind): remove list enclosing after issue is fixed
|
||||
args = (scope, [http_body_bytes])
|
||||
request_in_object = RequestMetadata(
|
||||
endpoint_name,
|
||||
TaskContext.Web,
|
||||
relative_slo_ms=relative_slo_ms,
|
||||
absolute_slo_ms=absolute_slo_ms)
|
||||
|
||||
actual_result = await (self.serve_global_state.init_or_get_router()
|
||||
.enqueue_request.remote(
|
||||
service=endpoint_name,
|
||||
request_args=(scope, http_body_bytes),
|
||||
request_kwargs=dict(),
|
||||
request_context=TaskContext.Web,
|
||||
request_slo_ms=request_slo_ms))
|
||||
.enqueue_request.remote(request_in_object,
|
||||
*args))
|
||||
result = actual_result
|
||||
|
||||
if isinstance(result, ray.exceptions.RayTaskError):
|
||||
|
||||
@@ -6,6 +6,7 @@ import ray
|
||||
from ray.experimental.serve.policy import (
|
||||
RandomPolicyQueue, RandomPolicyQueueActor, RoundRobinPolicyQueueActor,
|
||||
PowerOfTwoPolicyQueueActor, FixedPackingPolicyQueueActor)
|
||||
from ray.experimental.serve.request_params import RequestMetadata
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
@@ -42,13 +43,13 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
|
||||
q.dequeue_request.remote("backend", task_runner_mock_actor)
|
||||
|
||||
# Make sure we get the request result back
|
||||
result = await q.enqueue_request.remote("svc", 1, "kwargs", None)
|
||||
result = await q.enqueue_request.remote(RequestMetadata("svc", None), 1)
|
||||
assert result == "DONE"
|
||||
|
||||
# Make sure it's the right request
|
||||
got_work = await task_runner_mock_actor.get_recent_call.remote()
|
||||
assert got_work.request_args == 1
|
||||
assert got_work.request_kwargs == "kwargs"
|
||||
assert got_work.request_args[0] == 1
|
||||
assert got_work.request_kwargs == {}
|
||||
|
||||
|
||||
async def test_slo(serve_instance, task_runner_mock_actor):
|
||||
@@ -60,7 +61,7 @@ async def test_slo(serve_instance, task_runner_mock_actor):
|
||||
slo_ms = 1000 - 100 * i
|
||||
all_request_sent.append(
|
||||
q.enqueue_request.remote(
|
||||
"svc", i, "kwargs", None, request_slo_ms=slo_ms))
|
||||
RequestMetadata("svc", None, relative_slo_ms=slo_ms), i))
|
||||
|
||||
for i in range(10):
|
||||
await q.dequeue_request.remote("backend", task_runner_mock_actor)
|
||||
@@ -71,7 +72,7 @@ async def test_slo(serve_instance, task_runner_mock_actor):
|
||||
all_calls = await task_runner_mock_actor.get_all_calls.remote()
|
||||
all_calls = all_calls[-10:]
|
||||
for call in all_calls:
|
||||
assert call.request_args == i_should_be
|
||||
assert call.request_args[0] == i_should_be
|
||||
i_should_be -= 1
|
||||
|
||||
|
||||
@@ -80,15 +81,15 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor):
|
||||
|
||||
await q.set_traffic.remote("svc", {"backend-1": 1})
|
||||
await q.dequeue_request.remote("backend-1", task_runner_mock_actor)
|
||||
await q.enqueue_request.remote("svc", 1, "kwargs", None)
|
||||
await q.enqueue_request.remote(RequestMetadata("svc", None), 1)
|
||||
got_work = await task_runner_mock_actor.get_recent_call.remote()
|
||||
assert got_work.request_args == 1
|
||||
assert got_work.request_args[0] == 1
|
||||
|
||||
await q.set_traffic.remote("svc", {"backend-2": 1})
|
||||
await q.dequeue_request.remote("backend-2", task_runner_mock_actor)
|
||||
await q.enqueue_request.remote("svc", 2, "kwargs", None)
|
||||
await q.enqueue_request.remote(RequestMetadata("svc", None), 2)
|
||||
got_work = await task_runner_mock_actor.get_recent_call.remote()
|
||||
assert got_work.request_args == 2
|
||||
assert got_work.request_args[0] == 2
|
||||
|
||||
|
||||
async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
|
||||
@@ -103,13 +104,13 @@ async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
|
||||
# assume 50% split, the probability of all 20 requests goes to a
|
||||
# single queue is 0.5^20 ~ 1e-6
|
||||
for _ in range(20):
|
||||
await q.enqueue_request.remote("svc", 1, "kwargs", None)
|
||||
await q.enqueue_request.remote(RequestMetadata("svc", None), 1)
|
||||
|
||||
got_work = [
|
||||
await runner.get_recent_call.remote()
|
||||
for runner in (runner_1, runner_2)
|
||||
]
|
||||
assert [g.request_args for g in got_work] == [1, 1]
|
||||
assert [g.request_args[0] for g in got_work] == [1, 1]
|
||||
|
||||
|
||||
async def test_round_robin(serve_instance, task_runner_mock_actor):
|
||||
@@ -125,13 +126,13 @@ async def test_round_robin(serve_instance, task_runner_mock_actor):
|
||||
await q.dequeue_request.remote("backend-2", runner_2)
|
||||
|
||||
for _ in range(20):
|
||||
await q.enqueue_request.remote("svc", 1, "kwargs", None)
|
||||
await q.enqueue_request.remote(RequestMetadata("svc", None), 1)
|
||||
|
||||
got_work = [
|
||||
await runner.get_recent_call.remote()
|
||||
for runner in (runner_1, runner_2)
|
||||
]
|
||||
assert [g.request_args for g in got_work] == [1, 1]
|
||||
assert [g.request_args[0] for g in got_work] == [1, 1]
|
||||
|
||||
|
||||
async def test_fixed_packing(serve_instance):
|
||||
@@ -149,10 +150,11 @@ async def test_fixed_packing(serve_instance):
|
||||
for backend, runner in zip(["1", "2"], [runner_1, runner_2]):
|
||||
for _ in range(packing_num):
|
||||
input_value = "should-go-to-backend-{}".format(backend)
|
||||
await q.enqueue_request.remote("svc", input_value, "kwargs", None)
|
||||
await q.enqueue_request.remote(
|
||||
RequestMetadata("svc", None), input_value)
|
||||
all_calls = await runner.get_all_calls.remote()
|
||||
for call in all_calls:
|
||||
assert call.request_args == input_value
|
||||
assert call.request_args[0] == input_value
|
||||
|
||||
|
||||
async def test_power_of_two_choices(serve_instance):
|
||||
@@ -162,13 +164,13 @@ async def test_power_of_two_choices(serve_instance):
|
||||
# First, fill the queue for backend-1 with 3 requests
|
||||
await q.set_traffic.remote("svc", {"backend-1": 1.0})
|
||||
for _ in range(3):
|
||||
future = q.enqueue_request.remote("svc", "1", "", None)
|
||||
future = q.enqueue_request.remote(RequestMetadata("svc", None), "1")
|
||||
enqueue_futures.append(future)
|
||||
|
||||
# Then, add a new backend, this backend should be filled next
|
||||
await q.set_traffic.remote("svc", {"backend-1": 0.5, "backend-2": 0.5})
|
||||
for _ in range(2):
|
||||
future = q.enqueue_request.remote("svc", "2", "", None)
|
||||
future = q.enqueue_request.remote(RequestMetadata("svc", None), "2")
|
||||
enqueue_futures.append(future)
|
||||
|
||||
runner_1, runner_2 = (make_task_runner_mock() for _ in range(2))
|
||||
|
||||
@@ -5,6 +5,7 @@ import ray.experimental.serve.context as context
|
||||
from ray.experimental.serve.policy import RoundRobinPolicyQueueActor
|
||||
from ray.experimental.serve.task_runner import (
|
||||
RayServeMixin, TaskRunner, TaskRunnerActor, wrap_to_ray_error)
|
||||
from ray.experimental.serve.request_params import RequestMetadata
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
@@ -38,11 +39,9 @@ async def test_runner_actor(serve_instance):
|
||||
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
|
||||
|
||||
for query in [333, 444, 555]:
|
||||
result = await q.enqueue_request.remote(
|
||||
PRODUCER_NAME,
|
||||
request_args=None,
|
||||
request_kwargs={"i": query},
|
||||
request_context=context.TaskContext.Python)
|
||||
query_param = RequestMetadata(PRODUCER_NAME,
|
||||
context.TaskContext.Python)
|
||||
result = await q.enqueue_request.remote(query_param, i=query)
|
||||
assert result == query
|
||||
|
||||
|
||||
@@ -71,11 +70,9 @@ async def test_ray_serve_mixin(serve_instance):
|
||||
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
|
||||
|
||||
for query in [333, 444, 555]:
|
||||
result = await q.enqueue_request.remote(
|
||||
PRODUCER_NAME,
|
||||
request_args=None,
|
||||
request_kwargs={"i": query},
|
||||
request_context=context.TaskContext.Python)
|
||||
query_param = RequestMetadata(PRODUCER_NAME,
|
||||
context.TaskContext.Python)
|
||||
result = await q.enqueue_request.remote(query_param, i=query)
|
||||
assert result == query + 3
|
||||
|
||||
|
||||
@@ -95,11 +92,8 @@ async def test_task_runner_check_context(serve_instance):
|
||||
runner._ray_serve_fetch.remote()
|
||||
|
||||
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
|
||||
result_oid = q.enqueue_request.remote(
|
||||
PRODUCER_NAME,
|
||||
request_args=None,
|
||||
request_kwargs={"i": 42},
|
||||
request_context=context.TaskContext.Python)
|
||||
query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
|
||||
result_oid = q.enqueue_request.remote(query_param, i=42)
|
||||
|
||||
with pytest.raises(ray.exceptions.RayTaskError):
|
||||
await result_oid
|
||||
|
||||
@@ -16,7 +16,12 @@ def parse_request_item(request_item):
|
||||
if request_item.request_context == TaskContext.Web:
|
||||
is_web_context = True
|
||||
asgi_scope, body_bytes = request_item.request_args
|
||||
flask_request = build_flask_request(asgi_scope, io.BytesIO(body_bytes))
|
||||
|
||||
# http_body_bytes enclosed in list due to
|
||||
# https://github.com/ray-project/ray/issues/6944
|
||||
# TODO(alind): remove list enclosing after issue is fixed
|
||||
flask_request = build_flask_request(asgi_scope,
|
||||
io.BytesIO(body_bytes[0]))
|
||||
args = (flask_request, )
|
||||
kwargs = {}
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user