mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
[Serve] Allow ObjectRef for Composition (#12592)
This commit is contained in:
@@ -18,7 +18,7 @@ from ray.serve.exceptions import RayServeException
|
||||
from ray.util import metrics
|
||||
from ray.serve.config import BackendConfig
|
||||
from ray.serve.long_poll import LongPollAsyncClient
|
||||
from ray.serve.router import Query
|
||||
from ray.serve.router import Query, RequestMetadata
|
||||
from ray.serve.constants import (
|
||||
BACKEND_RECONFIGURE_METHOD,
|
||||
DEFAULT_LATENCY_BUCKET_MS,
|
||||
@@ -95,7 +95,11 @@ class BatchQueue:
|
||||
|
||||
|
||||
def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]):
|
||||
"""Creates a replica class wrapping the provided function or class."""
|
||||
"""Creates a replica class wrapping the provided function or class.
|
||||
|
||||
This approach is picked over inheritance to avoid conflict between user
|
||||
provided class and the RayServeReplica class.
|
||||
"""
|
||||
|
||||
if inspect.isfunction(func_or_class):
|
||||
is_function = True
|
||||
@@ -124,8 +128,15 @@ def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]):
|
||||
is_function, controller_handle)
|
||||
|
||||
@ray.method(num_returns=2)
|
||||
async def handle_request(self, request):
|
||||
return await self.backend.handle_request(request)
|
||||
async def handle_request(
|
||||
self,
|
||||
request_metadata: RequestMetadata,
|
||||
*request_args,
|
||||
**request_kwargs,
|
||||
):
|
||||
# Directly receive input because it might contain an ObjectRef.
|
||||
query = Query(request_args, request_kwargs, request_metadata)
|
||||
return await self.backend.handle_request(query)
|
||||
|
||||
def ready(self):
|
||||
pass
|
||||
@@ -149,10 +160,6 @@ def wrap_to_ray_error(exception: Exception) -> RayTaskError:
|
||||
return ray.exceptions.RayTaskError(str(e), traceback_str, e.__class__)
|
||||
|
||||
|
||||
def ensure_async(func: Callable) -> Callable:
|
||||
return sync_to_async(func)
|
||||
|
||||
|
||||
class RayServeReplica:
|
||||
"""Handles requests with the provided callable."""
|
||||
|
||||
@@ -286,7 +293,7 @@ class RayServeReplica:
|
||||
async def invoke_single(self, request_item: Query) -> Any:
|
||||
logger.debug("Replica {} started executing request {}".format(
|
||||
self.replica_tag, request_item.metadata.request_id))
|
||||
method_to_call = ensure_async(self.get_runner_method(request_item))
|
||||
method_to_call = sync_to_async(self.get_runner_method(request_item))
|
||||
arg = parse_request_item(request_item)
|
||||
|
||||
start = time.time()
|
||||
@@ -329,7 +336,7 @@ class RayServeReplica:
|
||||
|
||||
self.request_counter.record(batch_size)
|
||||
|
||||
call_method = ensure_async(call_methods.pop())
|
||||
call_method = sync_to_async(call_methods.pop())
|
||||
result_list = await call_method(args)
|
||||
|
||||
if not isinstance(result_list, Iterable) or isinstance(
|
||||
@@ -430,11 +437,7 @@ class RayServeReplica:
|
||||
self.config.batch_wait_timeout)
|
||||
self.reconfigure(self.config.user_config)
|
||||
|
||||
async def handle_request(self,
|
||||
request: Union[Query, bytes]) -> asyncio.Future:
|
||||
if isinstance(request, bytes):
|
||||
request = Query.ray_deserialize(request)
|
||||
|
||||
async def handle_request(self, request: Query) -> asyncio.Future:
|
||||
request.tick_enter_replica = time.time()
|
||||
logger.debug("Replica {} received request {}".format(
|
||||
self.replica_tag, request.metadata.request_id))
|
||||
|
||||
@@ -4,6 +4,8 @@ from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from enum import Enum
|
||||
|
||||
from ray.serve.router import Router
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HandleOptions:
|
||||
@@ -39,7 +41,7 @@ class RayServeHandle:
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
router,
|
||||
router: Router,
|
||||
endpoint_name,
|
||||
handle_options: Optional[HandleOptions] = None):
|
||||
self.router = router
|
||||
|
||||
@@ -101,9 +101,12 @@ class ReplicaSet:
|
||||
) >= self.max_concurrent_queries:
|
||||
# This replica is overloaded, try next one
|
||||
continue
|
||||
|
||||
logger.debug(f"Assigned query {query.metadata.request_id} "
|
||||
f"to replica {replica}.")
|
||||
tracker_ref, user_ref = replica.handle_request.remote(query)
|
||||
# Directly passing args because it might contain an ObjectRef.
|
||||
tracker_ref, user_ref = replica.handle_request.remote(
|
||||
query.metadata, *query.args, **query.kwargs)
|
||||
self.in_flight_queries[replica].add(tracker_ref)
|
||||
return user_ref
|
||||
return None
|
||||
@@ -144,9 +147,8 @@ class ReplicaSet:
|
||||
return_when=asyncio.FIRST_COMPLETED)
|
||||
if self.config_updated_event.is_set():
|
||||
self.config_updated_event.clear()
|
||||
# We are pretty sure a free replica is ready now, let's recurse and
|
||||
# assign this query a replica.
|
||||
assigned_ref = await self.assign_replica(query)
|
||||
# We are pretty sure a free replica is ready now.
|
||||
assigned_ref = self._try_assign_replica(query)
|
||||
return assigned_ref
|
||||
|
||||
|
||||
|
||||
@@ -2,9 +2,12 @@ import gc
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
from ray.exceptions import GetTimeoutError
|
||||
from ray import serve
|
||||
from ray.test_utils import SignalActor
|
||||
|
||||
|
||||
def test_np_in_composed_model(serve_instance):
|
||||
@@ -63,7 +66,33 @@ def test_backend_worker_memory_growth(serve_instance):
|
||||
assert num_unreachable_objects == 0
|
||||
|
||||
|
||||
def test_ref_in_handle_input(serve_instance):
|
||||
client = serve_instance
|
||||
# https://github.com/ray-project/ray/issues/12593
|
||||
|
||||
unblock_worker_signal = SignalActor.remote()
|
||||
|
||||
async def blocked_by_ref(serve_request):
|
||||
data = await serve_request.body()
|
||||
assert not isinstance(data, ray.ObjectRef)
|
||||
|
||||
client.create_backend("ref", blocked_by_ref)
|
||||
client.create_endpoint("ref", backend="ref")
|
||||
handle = client.get_handle("ref")
|
||||
|
||||
# Pass in a ref that's not ready yet
|
||||
ref = unblock_worker_signal.wait.remote()
|
||||
worker_result = handle.remote(ref)
|
||||
|
||||
# Worker shouldn't execute the request
|
||||
with pytest.raises(GetTimeoutError):
|
||||
ray.get(worker_result, timeout=1)
|
||||
|
||||
# Now unblock the worker
|
||||
unblock_worker_signal.send.remote()
|
||||
ray.get(worker_result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
import pytest
|
||||
sys.exit(pytest.main(["-v", "-s", __file__]))
|
||||
|
||||
@@ -33,11 +33,9 @@ def mock_task_runner():
|
||||
self.queries = []
|
||||
|
||||
@ray.method(num_returns=2)
|
||||
async def handle_request(self, request):
|
||||
if isinstance(request, bytes):
|
||||
request = Query.ray_deserialize(request)
|
||||
self.query = request
|
||||
self.queries.append(request)
|
||||
async def handle_request(self, request_metadata, *args, **kwargs):
|
||||
self.query = Query(args, kwargs, request_metadata)
|
||||
self.queries.append(self.query)
|
||||
return b"", "DONE"
|
||||
|
||||
def get_recent_call(self):
|
||||
@@ -75,7 +73,7 @@ async def test_simple_endpoint_backend_pair(ray_instance, mock_controller,
|
||||
|
||||
# Make sure we get the request result back
|
||||
ref = await q.assign_request.remote(
|
||||
RequestMetadata(get_random_letters(10), "svc", None), 1)
|
||||
RequestMetadata(get_random_letters(10), "svc"), 1)
|
||||
result = await ref
|
||||
assert result == "DONE"
|
||||
|
||||
@@ -98,7 +96,7 @@ async def test_changing_backend(ray_instance, mock_controller,
|
||||
task_runner_mock_actor)
|
||||
|
||||
await (await q.assign_request.remote(
|
||||
RequestMetadata(get_random_letters(10), "svc", None), 1))
|
||||
RequestMetadata(get_random_letters(10), "svc"), 1))
|
||||
got_work = await task_runner_mock_actor.get_recent_call.remote()
|
||||
assert got_work.args[0] == 1
|
||||
|
||||
@@ -109,7 +107,7 @@ async def test_changing_backend(ray_instance, mock_controller,
|
||||
await mock_controller.add_new_replica.remote("backend-alter-2",
|
||||
task_runner_mock_actor)
|
||||
await (await q.assign_request.remote(
|
||||
RequestMetadata(get_random_letters(10), "svc", None), 2))
|
||||
RequestMetadata(get_random_letters(10), "svc"), 2))
|
||||
got_work = await task_runner_mock_actor.get_recent_call.remote()
|
||||
assert got_work.args[0] == 2
|
||||
|
||||
@@ -133,7 +131,7 @@ async def test_split_traffic_random(ray_instance, mock_controller,
|
||||
object_refs = []
|
||||
for _ in range(20):
|
||||
ref = await q.assign_request.remote(
|
||||
RequestMetadata(get_random_letters(10), "svc", None), 1)
|
||||
RequestMetadata(get_random_letters(10), "svc"), 1)
|
||||
object_refs.append(ref)
|
||||
ray.get(object_refs)
|
||||
|
||||
@@ -164,7 +162,7 @@ async def test_shard_key(ray_instance, mock_controller,
|
||||
for shard_key in shard_keys:
|
||||
await (await q.assign_request.remote(
|
||||
RequestMetadata(
|
||||
get_random_letters(10), "svc", None, shard_key=shard_key),
|
||||
get_random_letters(10), "svc", shard_key=shard_key),
|
||||
shard_key))
|
||||
|
||||
# Log the shard keys that were assigned to each backend.
|
||||
@@ -179,7 +177,7 @@ async def test_shard_key(ray_instance, mock_controller,
|
||||
for shard_key in shard_keys:
|
||||
await (await q.assign_request.remote(
|
||||
RequestMetadata(
|
||||
get_random_letters(10), "svc", None, shard_key=shard_key),
|
||||
get_random_letters(10), "svc", shard_key=shard_key),
|
||||
shard_key))
|
||||
|
||||
# Check that the requests were all mapped to the same backends.
|
||||
|
||||
Reference in New Issue
Block a user