mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 09:24:28 +08:00
[serve] Use Long Polling in Backend Worker (#12093)
This commit is contained in:
@@ -1527,7 +1527,7 @@ cdef void async_set_result(shared_ptr[CRayObject] obj,
|
||||
cpython.Py_DECREF(py_future)
|
||||
return
|
||||
|
||||
if isinstance(result, RayError):
|
||||
if isinstance(result, RayTaskError):
|
||||
ray.worker.last_task_error_raise_time = time.time()
|
||||
py_future.set_exception(result.as_instanceof_cause())
|
||||
else:
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Union, List, Any, Callable, Type
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.actor import ActorHandle
|
||||
from ray.async_compat import sync_to_async
|
||||
|
||||
from ray.serve.utils import (parse_request_item, _get_logger, chain_future,
|
||||
@@ -14,6 +15,7 @@ from ray.serve.utils import (parse_request_item, _get_logger, chain_future,
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.util import metrics
|
||||
from ray.serve.config import BackendConfig
|
||||
from ray.serve.long_poll import LongPollerAsyncClient
|
||||
from ray.serve.router import Query
|
||||
from ray.serve.constants import (DEFAULT_LATENCY_BUCKET_MS,
|
||||
BACKEND_RECONFIGURE_METHOD)
|
||||
@@ -109,15 +111,15 @@ def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]):
|
||||
else:
|
||||
_callable = func_or_class(*init_args)
|
||||
|
||||
assert controller_name, "Must provide a valid controller_name"
|
||||
controller_handle = ray.get_actor(controller_name)
|
||||
self.backend = RayServeReplica(backend_tag, replica_tag, _callable,
|
||||
backend_config, is_function)
|
||||
backend_config, is_function,
|
||||
controller_handle)
|
||||
|
||||
async def handle_request(self, request):
    """Delegate ``request`` to the wrapped backend and return its result.

    The backend call is awaited here, so the caller receives the value the
    backend produced rather than a pending coroutine.
    """
    response = await self.backend.handle_request(request)
    return response
|
||||
|
||||
def update_config(self, new_config: BackendConfig):
|
||||
return self.backend.update_config(new_config)
|
||||
|
||||
def ready(self):
    """Do nothing and return ``None``.

    NOTE(review): presumably invoked remotely as a probe that the actor
    finished constructing -- confirm against the callers.
    """
    return None
|
||||
|
||||
@@ -145,7 +147,8 @@ class RayServeReplica:
|
||||
"""Handles requests with the provided callable."""
|
||||
|
||||
def __init__(self, backend_tag: str, replica_tag: str, _callable: Callable,
|
||||
backend_config: BackendConfig, is_function: bool) -> None:
|
||||
backend_config: BackendConfig, is_function: bool,
|
||||
controller_handle: ActorHandle) -> None:
|
||||
self.backend_tag = backend_tag
|
||||
self.replica_tag = replica_tag
|
||||
self.callable = _callable
|
||||
@@ -165,6 +168,10 @@ class RayServeReplica:
|
||||
tag_keys=("backend", ))
|
||||
self.request_counter.set_default_tags({"backend": self.backend_tag})
|
||||
|
||||
self.long_poll_client = LongPollerAsyncClient(controller_handle, {
|
||||
"backend_configs": self._update_backend_configs,
|
||||
})
|
||||
|
||||
self.error_counter = metrics.Count(
|
||||
"backend_error_counter",
|
||||
description=("Number of exceptions that have "
|
||||
@@ -369,7 +376,13 @@ class RayServeReplica:
|
||||
BACKEND_RECONFIGURE_METHOD)
|
||||
reconfigure_method(user_config)
|
||||
|
||||
def update_config(self, new_config: BackendConfig) -> None:
|
||||
async def _update_backend_configs(self, backend_configs):
    """Apply the config published for this replica's backend, if present.

    Args:
        backend_configs: Mapping of backend tag to that backend's config.
            Entries for other backends are ignored; if this replica's tag
            is absent nothing happens.
    """
    # TODO(ilr) remove this filtering when we poll per key
    # Mapping keys are unique, so a direct lookup selects the same entry a
    # full scan over .items() would, without visiting unrelated backends.
    if self.backend_tag in backend_configs:
        self._update_config(backend_configs[self.backend_tag])
|
||||
|
||||
def _update_config(self, new_config: BackendConfig) -> None:
|
||||
self.config = new_config
|
||||
self.batch_queue.set_config(self.config.max_batch_size or 1,
|
||||
self.config.batch_wait_timeout)
|
||||
|
||||
@@ -146,10 +146,6 @@ class ActorStateReconciler:
|
||||
for replica_dict in self.backend_replicas.values()
|
||||
]))
|
||||
|
||||
def get_replica_handles_for_backend(
|
||||
self, backend_tag: BackendTag) -> List[ActorHandle]:
|
||||
return list(self.backend_replicas.get(backend_tag, {}).values())
|
||||
|
||||
async def _start_pending_backend_replicas(
|
||||
self, config_store: ConfigurationStore) -> None:
|
||||
"""Starts the pending backend replicas in self.backend_replicas_to_start.
|
||||
@@ -837,7 +833,6 @@ class ServeController:
|
||||
# Set the backend config inside the router
|
||||
# (particularly for max_concurrent_queries).
|
||||
self.notify_backend_configs_changed()
|
||||
await self.broadcast_backend_config(backend_tag)
|
||||
|
||||
async def delete_backend(self, backend_tag: BackendTag) -> None:
|
||||
async with self.write_lock:
|
||||
@@ -914,18 +909,6 @@ class ServeController:
|
||||
self.notify_replica_handles_changed()
|
||||
self.notify_backend_configs_changed()
|
||||
|
||||
await self.broadcast_backend_config(backend_tag)
|
||||
|
||||
async def broadcast_backend_config(self, backend_tag: BackendTag) -> None:
|
||||
backend_config = self.configuration_store.get_backend(
|
||||
backend_tag).backend_config
|
||||
broadcast_futures = [
|
||||
replica.update_config.remote(backend_config).as_future()
|
||||
for replica in
|
||||
self.actor_reconciler.get_replica_handles_for_backend(backend_tag)
|
||||
]
|
||||
await asyncio.gather(*broadcast_futures)
|
||||
|
||||
def get_backend_config(self, backend_tag: BackendTag) -> BackendConfig:
|
||||
"""Get the current config for the specified backend."""
|
||||
assert (self.configuration_store.get_backend(backend_tag)
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from collections import defaultdict
|
||||
import random
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
from ray import serve
|
||||
from ray.serve.config import BackendConfig
|
||||
|
||||
# Environment variable values are always strings, so the flag must be
# compared with "1"; a comparison with the integer 1 can never succeed.
if os.environ.get("RAY_SERVE_INTENTIONALLY_CRASH") == "1":
    serve.controller._CRASH_AFTER_CHECKPOINT_PROBABILITY = 0.5
|
||||
@@ -32,3 +35,53 @@ def serve_instance(_shared_serve_instance):
|
||||
_shared_serve_instance.delete_endpoint(endpoint)
|
||||
for backend in ray.get(controller.get_all_backends.remote()).keys():
|
||||
_shared_serve_instance.delete_backend(backend)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_controller_with_name():
    """Yield ``(name, handle)`` for a newly started, named mock controller.

    The actor owns a ``LongPollerHost`` and publishes snapshots under the
    ``"worker_handles"``, ``"traffic_policies"`` and ``"backend_configs"``
    keys. The actor name is yielded alongside the handle so a test can pass
    it wherever a controller name is expected.
    """

    @ray.remote(num_cpus=0)
    class MockControllerActor:
        def __init__(self):
            from ray.serve.long_poll import LongPollerHost
            self.host = LongPollerHost()
            self.backend_replicas = defaultdict(list)
            self.backend_configs = dict()
            self.clear()

        def clear(self):
            """Publish an empty snapshot for every long-poll key.

            Only the published snapshots are emptied; the replica and
            config dictionaries kept on the actor are left untouched.
            """
            self.host.notify_changed("worker_handles", {})
            self.host.notify_changed("traffic_policies", {})
            self.host.notify_changed("backend_configs", {})

        async def listen_for_change(self, snapshot_ids):
            """Forward a long-poll request to the host and return its reply."""
            return await self.host.listen_for_change(snapshot_ids)

        def set_traffic(self, endpoint, traffic_policy):
            """Publish a traffic-policy snapshot holding only ``endpoint``."""
            self.host.notify_changed("traffic_policies",
                                     {endpoint: traffic_policy})

        def add_new_replica(self,
                            backend_tag,
                            runner_actor,
                            backend_config=None):
            """Register a replica for ``backend_tag`` and publish the change.

            Args:
                backend_tag: Backend the replica belongs to.
                runner_actor: Replica handle appended to the backend's list.
                backend_config: Config stored for the backend. When omitted
                    a fresh ``BackendConfig()`` is created per call, so no
                    instance is shared between backends.
            """
            if backend_config is None:
                backend_config = BackendConfig()
            self.backend_replicas[backend_tag].append(runner_actor)
            self.backend_configs[backend_tag] = backend_config

            self.host.notify_changed(
                "worker_handles",
                self.backend_replicas,
            )
            self.host.notify_changed("backend_configs", self.backend_configs)

        def update_backend(self, backend_tag: str,
                           backend_config: BackendConfig):
            """Replace the stored config for ``backend_tag`` and publish it."""
            self.backend_configs[backend_tag] = backend_config
            self.host.notify_changed("backend_configs", self.backend_configs)

    # randint requires integer bounds; the float literal 10e4 is rejected by
    # newer Python versions, so spell the same upper bound as an int.
    name = f"MockController{random.randint(0, 10**5)}"
    yield name, MockControllerActor.options(name=name).remote()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_controller(mock_controller_with_name):
    """Yield only the actor handle produced by ``mock_controller_with_name``."""
    controller_handle = mock_controller_with_name[1]
    yield controller_handle
|
||||
|
||||
@@ -45,8 +45,9 @@ def setup_worker(name,
|
||||
return worker
|
||||
|
||||
|
||||
async def add_servable_to_router(servable, router, **kwargs):
|
||||
worker = setup_worker("backend", servable, **kwargs)
|
||||
async def add_servable_to_router(servable, router, controller_name, **kwargs):
|
||||
worker = setup_worker(
|
||||
"backend", servable, controller_name=controller_name, **kwargs)
|
||||
await router._update_worker_handles.remote({"backend": [worker]})
|
||||
await router._update_traffic_policies.remote({
|
||||
"endpoint": TrafficPolicy({
|
||||
@@ -81,11 +82,12 @@ async def test_runner_wraps_error():
|
||||
assert isinstance(wrapped, ray.exceptions.RayTaskError)
|
||||
|
||||
|
||||
async def test_servable_function(serve_instance, router):
|
||||
async def test_servable_function(serve_instance, router,
|
||||
mock_controller_with_name):
|
||||
def echo(request):
|
||||
return request.args["i"]
|
||||
|
||||
_ = await add_servable_to_router(echo, router)
|
||||
await add_servable_to_router(echo, router, mock_controller_with_name[0])
|
||||
|
||||
for query in [333, 444, 555]:
|
||||
query_param = make_request_param()
|
||||
@@ -94,7 +96,8 @@ async def test_servable_function(serve_instance, router):
|
||||
assert result == query
|
||||
|
||||
|
||||
async def test_servable_class(serve_instance, router):
|
||||
async def test_servable_class(serve_instance, router,
|
||||
mock_controller_with_name):
|
||||
class MyAdder:
|
||||
def __init__(self, inc):
|
||||
self.increment = inc
|
||||
@@ -102,7 +105,8 @@ async def test_servable_class(serve_instance, router):
|
||||
def __call__(self, request):
|
||||
return request.args["i"] + self.increment
|
||||
|
||||
_ = await add_servable_to_router(MyAdder, router, init_args=(3, ))
|
||||
await add_servable_to_router(
|
||||
MyAdder, router, mock_controller_with_name[0], init_args=(3, ))
|
||||
|
||||
for query in [333, 444, 555]:
|
||||
query_param = make_request_param()
|
||||
@@ -111,7 +115,8 @@ async def test_servable_class(serve_instance, router):
|
||||
assert result == query + 3
|
||||
|
||||
|
||||
async def test_task_runner_custom_method_single(serve_instance, router):
|
||||
async def test_task_runner_custom_method_single(serve_instance, router,
|
||||
mock_controller_with_name):
|
||||
class NonBatcher:
|
||||
def a(self, _):
|
||||
return "a"
|
||||
@@ -119,7 +124,8 @@ async def test_task_runner_custom_method_single(serve_instance, router):
|
||||
def b(self, _):
|
||||
return "b"
|
||||
|
||||
_ = await add_servable_to_router(NonBatcher, router)
|
||||
await add_servable_to_router(NonBatcher, router,
|
||||
mock_controller_with_name[0])
|
||||
|
||||
query_param = make_request_param("a")
|
||||
a_result = await (await router.assign_request.remote(query_param))
|
||||
@@ -134,7 +140,8 @@ async def test_task_runner_custom_method_single(serve_instance, router):
|
||||
await (await router.assign_request.remote(query_param))
|
||||
|
||||
|
||||
async def test_task_runner_custom_method_batch(serve_instance, router):
|
||||
async def test_task_runner_custom_method_batch(serve_instance, router,
|
||||
mock_controller_with_name):
|
||||
@serve.accept_batch
|
||||
class Batcher:
|
||||
def a(self, requests):
|
||||
@@ -147,8 +154,11 @@ async def test_task_runner_custom_method_batch(serve_instance, router):
|
||||
max_batch_size=4,
|
||||
batch_wait_timeout=10,
|
||||
internal_metadata=BackendMetadata(accepts_batches=True))
|
||||
_ = await add_servable_to_router(
|
||||
Batcher, router, backend_config=backend_config)
|
||||
await add_servable_to_router(
|
||||
Batcher,
|
||||
router,
|
||||
mock_controller_with_name[0],
|
||||
backend_config=backend_config)
|
||||
|
||||
a_query_param = make_request_param("a")
|
||||
b_query_param = make_request_param("b")
|
||||
@@ -164,7 +174,8 @@ async def test_task_runner_custom_method_batch(serve_instance, router):
|
||||
assert set(gathered) == {"a-0", "a-1", "b-0", "b-1"}
|
||||
|
||||
|
||||
async def test_servable_batch_error(serve_instance, router):
|
||||
async def test_servable_batch_error(serve_instance, router,
|
||||
mock_controller_with_name):
|
||||
@serve.accept_batch
|
||||
class ErrorBatcher:
|
||||
def error_different_size(self, requests):
|
||||
@@ -179,8 +190,11 @@ async def test_servable_batch_error(serve_instance, router):
|
||||
backend_config = BackendConfig(
|
||||
max_batch_size=4,
|
||||
internal_metadata=BackendMetadata(accepts_batches=True))
|
||||
_ = await add_servable_to_router(
|
||||
ErrorBatcher, router, backend_config=backend_config)
|
||||
await add_servable_to_router(
|
||||
ErrorBatcher,
|
||||
router,
|
||||
mock_controller_with_name[0],
|
||||
backend_config=backend_config)
|
||||
|
||||
with pytest.raises(RayServeException, match="doesn't preserve batch size"):
|
||||
different_size = make_request_param("error_different_size")
|
||||
@@ -195,7 +209,8 @@ async def test_servable_batch_error(serve_instance, router):
|
||||
assert isinstance(result_np_value, np.int32)
|
||||
|
||||
|
||||
async def test_task_runner_perform_batch(serve_instance, router):
|
||||
async def test_task_runner_perform_batch(serve_instance, router,
|
||||
mock_controller_with_name):
|
||||
def batcher(requests):
|
||||
batch_size = len(requests)
|
||||
return [batch_size] * batch_size
|
||||
@@ -205,7 +220,8 @@ async def test_task_runner_perform_batch(serve_instance, router):
|
||||
batch_wait_timeout=10,
|
||||
internal_metadata=BackendMetadata(accepts_batches=True))
|
||||
|
||||
_ = await add_servable_to_router(batcher, router, backend_config=config)
|
||||
await add_servable_to_router(
|
||||
batcher, router, mock_controller_with_name[0], backend_config=config)
|
||||
|
||||
query_param = make_request_param()
|
||||
my_batch_sizes = await asyncio.gather(*[(
|
||||
@@ -213,7 +229,8 @@ async def test_task_runner_perform_batch(serve_instance, router):
|
||||
assert my_batch_sizes == [2, 2, 1]
|
||||
|
||||
|
||||
async def test_task_runner_perform_async(serve_instance, router):
|
||||
async def test_task_runner_perform_async(serve_instance, router,
|
||||
mock_controller_with_name):
|
||||
@ray.remote
|
||||
class Barrier:
|
||||
def __init__(self, release_on):
|
||||
@@ -238,8 +255,11 @@ async def test_task_runner_perform_async(serve_instance, router):
|
||||
max_concurrent_queries=10,
|
||||
internal_metadata=BackendMetadata(is_blocking=False))
|
||||
|
||||
_ = await add_servable_to_router(
|
||||
wait_and_go, router, backend_config=config)
|
||||
await add_servable_to_router(
|
||||
wait_and_go,
|
||||
router,
|
||||
mock_controller_with_name[0],
|
||||
backend_config=config)
|
||||
|
||||
query_param = make_request_param()
|
||||
|
||||
@@ -248,7 +268,48 @@ async def test_task_runner_perform_async(serve_instance, router):
|
||||
timeout=10)
|
||||
assert len(done) == 10
|
||||
for item in done:
|
||||
await item == "done!"
|
||||
assert await item == "done!"
|
||||
|
||||
|
||||
async def test_user_config_update(serve_instance, router,
                                  mock_controller_with_name):
    """Check that a replica serves a new ``user_config`` after an update.

    A backend is started with ``return_val == "original"``; a new config is
    then published through the mock controller and every subsequent request
    is expected to return ``"new_val"``.
    """

    class Customizable:
        def __init__(self):
            # Must be the same attribute that __call__ reads and that
            # reconfigure writes.
            self.retval = ""

        def __call__(self, flask_request):
            return self.retval

        def reconfigure(self, config):
            self.retval = config["return_val"]

    config = BackendConfig(
        num_replicas=2, user_config={
            "return_val": "original",
            "b": 2
        })
    await add_servable_to_router(
        Customizable,
        router,
        mock_controller_with_name[0],
        backend_config=config)

    query_param = make_request_param()

    done = [(await router.assign_request.remote(query_param))
            for _ in range(10)]
    for i in done:
        assert await i == "original"

    config = BackendConfig()
    config.user_config = {"return_val": "new_val"}
    await mock_controller_with_name[1].update_backend.remote("backend", config)

    # update_backend only publishes the config on the mock controller; the
    # replica may not have applied it yet when that call returns. Probe with
    # a bounded wait so the assertions below do not race the propagation.
    for _ in range(20):
        probe = await router.assign_request.remote(query_param)
        if await probe == "new_val":
            break
        await asyncio.sleep(0.5)

    done = [(await router.assign_request.remote(query_param))
            for _ in range(10)]

    for i in done:
        assert await i == "new_val"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -10,7 +10,6 @@ import pytest
|
||||
from ray.serve.context import TaskContext
|
||||
|
||||
import ray
|
||||
from ray.serve.config import BackendConfig
|
||||
from ray.serve.controller import TrafficPolicy
|
||||
from ray.serve.router import Query, ReplicaSet, RequestMetadata, Router
|
||||
from ray.serve.utils import get_random_letters
|
||||
@@ -61,45 +60,6 @@ def task_runner_mock_actor():
|
||||
yield mock_task_runner()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_controller():
|
||||
@ray.remote(num_cpus=0)
|
||||
class MockControllerActor:
|
||||
def __init__(self):
|
||||
from ray.serve.long_poll import LongPollerHost
|
||||
self.host = LongPollerHost()
|
||||
self.backend_replicas = defaultdict(list)
|
||||
self.backend_configs = dict()
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
self.host.notify_changed("worker_handles", {})
|
||||
self.host.notify_changed("traffic_policies", {})
|
||||
self.host.notify_changed("backend_configs", {})
|
||||
|
||||
async def listen_for_change(self, snapshot_ids):
|
||||
return await self.host.listen_for_change(snapshot_ids)
|
||||
|
||||
def set_traffic(self, endpoint, traffic_policy):
|
||||
self.host.notify_changed("traffic_policies",
|
||||
{endpoint: traffic_policy})
|
||||
|
||||
def add_new_replica(self,
|
||||
backend_tag,
|
||||
runner_actor,
|
||||
backend_config=BackendConfig()):
|
||||
self.backend_replicas[backend_tag].append(runner_actor)
|
||||
self.backend_configs[backend_tag] = backend_config
|
||||
|
||||
self.host.notify_changed(
|
||||
"worker_handles",
|
||||
self.backend_replicas,
|
||||
)
|
||||
self.host.notify_changed("backend_configs", self.backend_configs)
|
||||
|
||||
yield MockControllerActor.remote()
|
||||
|
||||
|
||||
async def test_simple_endpoint_backend_pair(ray_instance, mock_controller,
|
||||
task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote(mock_controller)
|
||||
|
||||
Reference in New Issue
Block a user