diff --git a/python/ray/serve/backend_worker.py b/python/ray/serve/backend_worker.py index 197ee302c..e430dbe38 100644 --- a/python/ray/serve/backend_worker.py +++ b/python/ray/serve/backend_worker.py @@ -121,6 +121,7 @@ def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]): self.backend = RayServeReplica(_callable, backend_config, is_function, controller_handle) + @ray.method(num_returns=2) async def handle_request(self, request): return await self.backend.handle_request(request) @@ -411,7 +412,8 @@ class RayServeReplica: self.replica_tag, request.metadata.request_id, request_time_ms)) self.num_ongoing_requests -= 1 - return result + # Returns a small object for router to track request status. + return b"", result async def drain_pending_queries(self): """Perform graceful shutdown. diff --git a/python/ray/serve/router.py b/python/ray/serve/router.py index 1ee2e3c59..1e118a604 100644 --- a/python/ray/serve/router.py +++ b/python/ray/serve/router.py @@ -103,9 +103,9 @@ class ReplicaSet: continue logger.debug(f"Assigned query {query.metadata.request_id} " f"to replica {replica}.") - ref = replica.handle_request.remote(query) - self.in_flight_queries[replica].add(ref) - return ref + tracker_ref, user_ref = replica.handle_request.remote(query) + self.in_flight_queries[replica].add(tracker_ref) + return user_ref return None @property diff --git a/python/ray/serve/tests/test_backend_worker.py b/python/ray/serve/tests/test_backend_worker.py index ee175a4d1..74c5418df 100644 --- a/python/ray/serve/tests/test_backend_worker.py +++ b/python/ray/serve/tests/test_backend_worker.py @@ -33,6 +33,7 @@ def setup_worker(name, def ready(self): pass + @ray.method(num_returns=2) async def handle_request(self, *args, **kwargs): return await self.worker.handle_request(*args, **kwargs) diff --git a/python/ray/serve/tests/test_router.py b/python/ray/serve/tests/test_router.py index 69ffeb2e0..a3164ada2 100644 --- a/python/ray/serve/tests/test_router.py +++ b/python/ray/serve/tests/test_router.py @@ -32,12 +32,13 @@ def mock_task_runner(): self.query = None self.queries = [] + @ray.method(num_returns=2) async def handle_request(self, request): if isinstance(request, bytes): request = Query.ray_deserialize(request) self.query = request self.queries.append(request) - return "DONE" + return b"", "DONE" def get_recent_call(self): return self.query @@ -195,10 +196,11 @@ async def test_replica_set(ray_instance): class MockWorker: _num_queries = 0 + @ray.method(num_returns=2) async def handle_request(self, request): self._num_queries += 1 await signal.wait.remote() - return "DONE" + return b"", "DONE" async def num_queries(self): return self._num_queries