diff --git a/python/ray/serve/backend_worker.py b/python/ray/serve/backend_worker.py index 7a7198010..da087efa5 100644 --- a/python/ray/serve/backend_worker.py +++ b/python/ray/serve/backend_worker.py @@ -18,7 +18,7 @@ from ray.serve.exceptions import RayServeException from ray.util import metrics from ray.serve.config import BackendConfig from ray.serve.long_poll import LongPollAsyncClient -from ray.serve.router import Query +from ray.serve.router import Query, RequestMetadata from ray.serve.constants import ( BACKEND_RECONFIGURE_METHOD, DEFAULT_LATENCY_BUCKET_MS, @@ -95,7 +95,11 @@ class BatchQueue: def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]): - """Creates a replica class wrapping the provided function or class.""" + """Creates a replica class wrapping the provided function or class. + + This approach is picked over inheritance to avoid conflict between user + provided class and the RayServeReplica class. + """ if inspect.isfunction(func_or_class): is_function = True @@ -124,8 +128,15 @@ def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]): is_function, controller_handle) @ray.method(num_returns=2) - async def handle_request(self, request): - return await self.backend.handle_request(request) + async def handle_request( + self, + request_metadata: RequestMetadata, + *request_args, + **request_kwargs, + ): + # Directly receive input because it might contain an ObjectRef. + query = Query(request_args, request_kwargs, request_metadata) + return await self.backend.handle_request(query) def ready(self): pass @@ -149,10 +160,6 @@ def wrap_to_ray_error(exception: Exception) -> RayTaskError: return ray.exceptions.RayTaskError(str(e), traceback_str, e.__class__) -def ensure_async(func: Callable) -> Callable: - return sync_to_async(func) - - class RayServeReplica: """Handles requests with the provided callable.""" @@ -286,7 +293,7 @@ class RayServeReplica: async def invoke_single(self, request_item: Query) -> Any: logger.debug("Replica {} started executing request {}".format( self.replica_tag, request_item.metadata.request_id)) - method_to_call = ensure_async(self.get_runner_method(request_item)) + method_to_call = sync_to_async(self.get_runner_method(request_item)) arg = parse_request_item(request_item) start = time.time() @@ -329,7 +336,7 @@ class RayServeReplica: self.request_counter.record(batch_size) - call_method = ensure_async(call_methods.pop()) + call_method = sync_to_async(call_methods.pop()) result_list = await call_method(args) if not isinstance(result_list, Iterable) or isinstance( @@ -430,11 +437,7 @@ class RayServeReplica: self.config.batch_wait_timeout) self.reconfigure(self.config.user_config) - async def handle_request(self, - request: Union[Query, bytes]) -> asyncio.Future: - if isinstance(request, bytes): - request = Query.ray_deserialize(request) - + async def handle_request(self, request: Query) -> asyncio.Future: request.tick_enter_replica = time.time() logger.debug("Replica {} received request {}".format( self.replica_tag, request.metadata.request_id)) diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index 80ea835d3..c6951c638 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -4,6 +4,8 @@ from dataclasses import dataclass, field from typing import Any, Dict, Optional, Union from enum import Enum +from ray.serve.router import Router + @dataclass(frozen=True) class HandleOptions: @@ -39,7 +41,7 @@ class RayServeHandle: """ def __init__(self, - router, + router: Router, endpoint_name, handle_options: Optional[HandleOptions] = None): self.router = router diff --git a/python/ray/serve/router.py b/python/ray/serve/router.py index 1e118a604..477f037fd 100644 --- a/python/ray/serve/router.py +++ b/python/ray/serve/router.py @@ -101,9 +101,12 @@ class ReplicaSet: ) >= self.max_concurrent_queries: # This replica is overloaded, try next one continue + logger.debug(f"Assigned query {query.metadata.request_id} " f"to replica {replica}.") - tracker_ref, user_ref = replica.handle_request.remote(query) + # Directly passing args because it might contain an ObjectRef. + tracker_ref, user_ref = replica.handle_request.remote( + query.metadata, *query.args, **query.kwargs) self.in_flight_queries[replica].add(tracker_ref) return user_ref return None @@ -144,9 +147,8 @@ class ReplicaSet: return_when=asyncio.FIRST_COMPLETED) if self.config_updated_event.is_set(): self.config_updated_event.clear() - # We are pretty sure a free replica is ready now, let's recurse and - # assign this query a replica. - assigned_ref = await self.assign_replica(query) + # We are pretty sure a free replica is ready now. + assigned_ref = self._try_assign_replica(query) return assigned_ref diff --git a/python/ray/serve/tests/test_regression.py b/python/ray/serve/tests/test_regression.py index e2425edf9..401837cab 100644 --- a/python/ray/serve/tests/test_regression.py +++ b/python/ray/serve/tests/test_regression.py @@ -2,9 +2,12 @@ import gc import numpy as np import requests +import pytest import ray +from ray.exceptions import GetTimeoutError from ray import serve +from ray.test_utils import SignalActor def test_np_in_composed_model(serve_instance): @@ -63,7 +66,33 @@ def test_backend_worker_memory_growth(serve_instance): assert num_unreachable_objects == 0 +def test_ref_in_handle_input(serve_instance): + client = serve_instance + # https://github.com/ray-project/ray/issues/12593 + + unblock_worker_signal = SignalActor.remote() + + async def blocked_by_ref(serve_request): + data = await serve_request.body() + assert not isinstance(data, ray.ObjectRef) + + client.create_backend("ref", blocked_by_ref) + client.create_endpoint("ref", backend="ref") + handle = client.get_handle("ref") + + # Pass in a ref that's not ready yet + ref = unblock_worker_signal.wait.remote() + worker_result = handle.remote(ref) + + # Worker shouldn't execute the request + with pytest.raises(GetTimeoutError): + ray.get(worker_result, timeout=1) + + # Now unblock the worker + unblock_worker_signal.send.remote() + ray.get(worker_result) + + if __name__ == "__main__": import sys - import pytest sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_router.py b/python/ray/serve/tests/test_router.py index a3164ada2..231ac11a5 100644 --- a/python/ray/serve/tests/test_router.py +++ b/python/ray/serve/tests/test_router.py @@ -33,11 +33,9 @@ def mock_task_runner(): self.queries = [] @ray.method(num_returns=2) - async def handle_request(self, request): - if isinstance(request, bytes): - request = Query.ray_deserialize(request) - self.query = request - self.queries.append(request) + async def handle_request(self, request_metadata, *args, **kwargs): + self.query = Query(args, kwargs, request_metadata) + self.queries.append(self.query) return b"", "DONE" def get_recent_call(self): @@ -75,7 +73,7 @@ async def test_simple_endpoint_backend_pair(ray_instance, mock_controller, # Make sure we get the request result back ref = await q.assign_request.remote( - RequestMetadata(get_random_letters(10), "svc", None), 1) + RequestMetadata(get_random_letters(10), "svc"), 1) result = await ref assert result == "DONE" @@ -98,7 +96,7 @@ async def test_changing_backend(ray_instance, mock_controller, task_runner_mock_actor) await (await q.assign_request.remote( - RequestMetadata(get_random_letters(10), "svc", None), 1)) + RequestMetadata(get_random_letters(10), "svc"), 1)) got_work = await task_runner_mock_actor.get_recent_call.remote() assert got_work.args[0] == 1 @@ -109,7 +107,7 @@ async def test_changing_backend(ray_instance, mock_controller, await mock_controller.add_new_replica.remote("backend-alter-2", task_runner_mock_actor) await (await q.assign_request.remote( - RequestMetadata(get_random_letters(10), "svc", None), 2)) + RequestMetadata(get_random_letters(10), "svc"), 2)) got_work = await task_runner_mock_actor.get_recent_call.remote() assert got_work.args[0] == 2 @@ -133,7 +131,7 @@ async def test_split_traffic_random(ray_instance, mock_controller, object_refs = [] for _ in range(20): ref = await q.assign_request.remote( - RequestMetadata(get_random_letters(10), "svc", None), 1) + RequestMetadata(get_random_letters(10), "svc"), 1) object_refs.append(ref) ray.get(object_refs) @@ -164,7 +162,7 @@ async def test_shard_key(ray_instance, mock_controller, for shard_key in shard_keys: await (await q.assign_request.remote( RequestMetadata( - get_random_letters(10), "svc", None, shard_key=shard_key), + get_random_letters(10), "svc", shard_key=shard_key), shard_key)) # Log the shard keys that were assigned to each backend. @@ -179,7 +177,7 @@ async def test_shard_key(ray_instance, mock_controller, for shard_key in shard_keys: await (await q.assign_request.remote( RequestMetadata( - get_random_letters(10), "svc", None, shard_key=shard_key), + get_random_letters(10), "svc", shard_key=shard_key), shard_key)) # Check that the requests were all mapped to the same backends.