[Serve] Allow ObjectRef for Composition (#12592)

This commit is contained in:
Simon Mo
2021-01-18 15:26:35 -08:00
committed by GitHub
parent dc42abb2f5
commit 6341f1fa2e
5 changed files with 66 additions and 32 deletions
+18 -15
View File
@@ -18,7 +18,7 @@ from ray.serve.exceptions import RayServeException
from ray.util import metrics
from ray.serve.config import BackendConfig
from ray.serve.long_poll import LongPollAsyncClient
from ray.serve.router import Query
from ray.serve.router import Query, RequestMetadata
from ray.serve.constants import (
BACKEND_RECONFIGURE_METHOD,
DEFAULT_LATENCY_BUCKET_MS,
@@ -95,7 +95,11 @@ class BatchQueue:
def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]):
"""Creates a replica class wrapping the provided function or class."""
"""Creates a replica class wrapping the provided function or class.
This approach is picked over inheritance to avoid conflicts between the
user-provided class and the RayServeReplica class.
"""
if inspect.isfunction(func_or_class):
is_function = True
@@ -124,8 +128,15 @@ def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]):
is_function, controller_handle)
@ray.method(num_returns=2)
async def handle_request(self, request):
return await self.backend.handle_request(request)
async def handle_request(
        self,
        request_metadata: RequestMetadata,
        *request_args,
        **request_kwargs,
):
    """Rebuild a Query from the unpacked call and hand it to the backend.

    The positional and keyword arguments are received directly, rather
    than pre-bundled in a Query, because they might contain an ObjectRef.

    Args:
        request_metadata: Metadata describing the incoming request.
        *request_args: Positional arguments of the request.
        **request_kwargs: Keyword arguments of the request.

    Returns:
        Whatever ``self.backend.handle_request`` returns for the query.
    """
    return await self.backend.handle_request(
        Query(request_args, request_kwargs, request_metadata))
def ready(self):
    """Do nothing and return None.

    NOTE(review): presumably called remotely as a liveness/readiness
    check once the replica is constructed -- confirm against callers.
    """
@@ -149,10 +160,6 @@ def wrap_to_ray_error(exception: Exception) -> RayTaskError:
return ray.exceptions.RayTaskError(str(e), traceback_str, e.__class__)
def ensure_async(func: Callable) -> Callable:
return sync_to_async(func)
class RayServeReplica:
"""Handles requests with the provided callable."""
@@ -286,7 +293,7 @@ class RayServeReplica:
async def invoke_single(self, request_item: Query) -> Any:
logger.debug("Replica {} started executing request {}".format(
self.replica_tag, request_item.metadata.request_id))
method_to_call = ensure_async(self.get_runner_method(request_item))
method_to_call = sync_to_async(self.get_runner_method(request_item))
arg = parse_request_item(request_item)
start = time.time()
@@ -329,7 +336,7 @@ class RayServeReplica:
self.request_counter.record(batch_size)
call_method = ensure_async(call_methods.pop())
call_method = sync_to_async(call_methods.pop())
result_list = await call_method(args)
if not isinstance(result_list, Iterable) or isinstance(
@@ -430,11 +437,7 @@ class RayServeReplica:
self.config.batch_wait_timeout)
self.reconfigure(self.config.user_config)
async def handle_request(self,
request: Union[Query, bytes]) -> asyncio.Future:
if isinstance(request, bytes):
request = Query.ray_deserialize(request)
async def handle_request(self, request: Query) -> asyncio.Future:
request.tick_enter_replica = time.time()
logger.debug("Replica {} received request {}".format(
self.replica_tag, request.metadata.request_id))
+3 -1
View File
@@ -4,6 +4,8 @@ from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Union
from enum import Enum
from ray.serve.router import Router
@dataclass(frozen=True)
class HandleOptions:
@@ -39,7 +41,7 @@ class RayServeHandle:
"""
def __init__(self,
router,
router: Router,
endpoint_name,
handle_options: Optional[HandleOptions] = None):
self.router = router
+6 -4
View File
@@ -101,9 +101,12 @@ class ReplicaSet:
) >= self.max_concurrent_queries:
# This replica is overloaded, try next one
continue
logger.debug(f"Assigned query {query.metadata.request_id} "
f"to replica {replica}.")
tracker_ref, user_ref = replica.handle_request.remote(query)
# Pass the args directly because they might contain an ObjectRef.
tracker_ref, user_ref = replica.handle_request.remote(
query.metadata, *query.args, **query.kwargs)
self.in_flight_queries[replica].add(tracker_ref)
return user_ref
return None
@@ -144,9 +147,8 @@ class ReplicaSet:
return_when=asyncio.FIRST_COMPLETED)
if self.config_updated_event.is_set():
self.config_updated_event.clear()
# We are pretty sure a free replica is ready now, let's recurse and
# assign this query a replica.
assigned_ref = await self.assign_replica(query)
# We are pretty sure a free replica is ready now.
assigned_ref = self._try_assign_replica(query)
return assigned_ref
+30 -1
View File
@@ -2,9 +2,12 @@ import gc
import numpy as np
import requests
import pytest
import ray
from ray.exceptions import GetTimeoutError
from ray import serve
from ray.test_utils import SignalActor
def test_np_in_composed_model(serve_instance):
@@ -63,7 +66,33 @@ def test_backend_worker_memory_growth(serve_instance):
assert num_unreachable_objects == 0
def test_ref_in_handle_input(serve_instance):
    """Check that a handle call is held back until its ObjectRef is ready.

    Regression test for https://github.com/ray-project/ray/issues/12593
    """
    signal = SignalActor.remote()

    async def blocked_by_ref(serve_request):
        # The backend must see the resolved value, never the raw ref.
        body = await serve_request.body()
        assert not isinstance(body, ray.ObjectRef)

    serve_instance.create_backend("ref", blocked_by_ref)
    serve_instance.create_endpoint("ref", backend="ref")
    handle = serve_instance.get_handle("ref")

    # Hand the endpoint a ref that cannot resolve until the signal is sent.
    pending_ref = signal.wait.remote()
    result_ref = handle.remote(pending_ref)

    # While the input is still pending, the request must not complete.
    with pytest.raises(GetTimeoutError):
        ray.get(result_ref, timeout=1)

    # Release the signal; the request should now run to completion.
    signal.send.remote()
    ray.get(result_ref)
if __name__ == "__main__":
import sys
import pytest
sys.exit(pytest.main(["-v", "-s", __file__]))
+9 -11
View File
@@ -33,11 +33,9 @@ def mock_task_runner():
self.queries = []
@ray.method(num_returns=2)
async def handle_request(self, request):
if isinstance(request, bytes):
request = Query.ray_deserialize(request)
self.query = request
self.queries.append(request)
async def handle_request(self, request_metadata, *args, **kwargs):
    """Record the incoming call as a Query and reply with a canned result.

    The rebuilt Query is stored as the most recent one (``self.query``)
    and appended to ``self.queries``.

    Returns:
        The fixed pair ``(b"", "DONE")``.
    """
    received = Query(args, kwargs, request_metadata)
    self.query = received
    self.queries.append(received)
    return b"", "DONE"
def get_recent_call(self):
@@ -75,7 +73,7 @@ async def test_simple_endpoint_backend_pair(ray_instance, mock_controller,
# Make sure we get the request result back
ref = await q.assign_request.remote(
RequestMetadata(get_random_letters(10), "svc", None), 1)
RequestMetadata(get_random_letters(10), "svc"), 1)
result = await ref
assert result == "DONE"
@@ -98,7 +96,7 @@ async def test_changing_backend(ray_instance, mock_controller,
task_runner_mock_actor)
await (await q.assign_request.remote(
RequestMetadata(get_random_letters(10), "svc", None), 1))
RequestMetadata(get_random_letters(10), "svc"), 1))
got_work = await task_runner_mock_actor.get_recent_call.remote()
assert got_work.args[0] == 1
@@ -109,7 +107,7 @@ async def test_changing_backend(ray_instance, mock_controller,
await mock_controller.add_new_replica.remote("backend-alter-2",
task_runner_mock_actor)
await (await q.assign_request.remote(
RequestMetadata(get_random_letters(10), "svc", None), 2))
RequestMetadata(get_random_letters(10), "svc"), 2))
got_work = await task_runner_mock_actor.get_recent_call.remote()
assert got_work.args[0] == 2
@@ -133,7 +131,7 @@ async def test_split_traffic_random(ray_instance, mock_controller,
object_refs = []
for _ in range(20):
ref = await q.assign_request.remote(
RequestMetadata(get_random_letters(10), "svc", None), 1)
RequestMetadata(get_random_letters(10), "svc"), 1)
object_refs.append(ref)
ray.get(object_refs)
@@ -164,7 +162,7 @@ async def test_shard_key(ray_instance, mock_controller,
for shard_key in shard_keys:
await (await q.assign_request.remote(
RequestMetadata(
get_random_letters(10), "svc", None, shard_key=shard_key),
get_random_letters(10), "svc", shard_key=shard_key),
shard_key))
# Log the shard keys that were assigned to each backend.
@@ -179,7 +177,7 @@ async def test_shard_key(ray_instance, mock_controller,
for shard_key in shard_keys:
await (await q.assign_request.remote(
RequestMetadata(
get_random_letters(10), "svc", None, shard_key=shard_key),
get_random_letters(10), "svc", shard_key=shard_key),
shard_key))
# Check that the requests were all mapped to the same backends.