[serve] Use Long Polling in Backend Worker (#12093)

This commit is contained in:
Ian Rodney
2020-11-25 12:11:38 -08:00
committed by GitHub
parent ca6c2b2442
commit 679492a235
6 changed files with 154 additions and 84 deletions
+1 -1
View File
@@ -1527,7 +1527,7 @@ cdef void async_set_result(shared_ptr[CRayObject] obj,
cpython.Py_DECREF(py_future)
return
if isinstance(result, RayError):
if isinstance(result, RayTaskError):
ray.worker.last_task_error_raise_time = time.time()
py_future.set_exception(result.as_instanceof_cause())
else:
+19 -6
View File
@@ -7,6 +7,7 @@ from typing import Union, List, Any, Callable, Type
import time
import ray
from ray.actor import ActorHandle
from ray.async_compat import sync_to_async
from ray.serve.utils import (parse_request_item, _get_logger, chain_future,
@@ -14,6 +15,7 @@ from ray.serve.utils import (parse_request_item, _get_logger, chain_future,
from ray.serve.exceptions import RayServeException
from ray.util import metrics
from ray.serve.config import BackendConfig
from ray.serve.long_poll import LongPollerAsyncClient
from ray.serve.router import Query
from ray.serve.constants import (DEFAULT_LATENCY_BUCKET_MS,
BACKEND_RECONFIGURE_METHOD)
@@ -109,15 +111,15 @@ def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]):
else:
_callable = func_or_class(*init_args)
assert controller_name, "Must provide a valid controller_name"
controller_handle = ray.get_actor(controller_name)
self.backend = RayServeReplica(backend_tag, replica_tag, _callable,
backend_config, is_function)
backend_config, is_function,
controller_handle)
async def handle_request(self, request):
return await self.backend.handle_request(request)
def update_config(self, new_config: BackendConfig):
return self.backend.update_config(new_config)
def ready(self):
pass
@@ -145,7 +147,8 @@ class RayServeReplica:
"""Handles requests with the provided callable."""
def __init__(self, backend_tag: str, replica_tag: str, _callable: Callable,
backend_config: BackendConfig, is_function: bool) -> None:
backend_config: BackendConfig, is_function: bool,
controller_handle: ActorHandle) -> None:
self.backend_tag = backend_tag
self.replica_tag = replica_tag
self.callable = _callable
@@ -165,6 +168,10 @@ class RayServeReplica:
tag_keys=("backend", ))
self.request_counter.set_default_tags({"backend": self.backend_tag})
self.long_poll_client = LongPollerAsyncClient(controller_handle, {
"backend_configs": self._update_backend_configs,
})
self.error_counter = metrics.Count(
"backend_error_counter",
description=("Number of exceptions that have "
@@ -369,7 +376,13 @@ class RayServeReplica:
BACKEND_RECONFIGURE_METHOD)
reconfigure_method(user_config)
def update_config(self, new_config: BackendConfig) -> None:
async def _update_backend_configs(self, backend_configs):
# TODO(ilr) remove this loop when we poll per key
for backend_tag, config in backend_configs.items():
if backend_tag == self.backend_tag:
self._update_config(config)
def _update_config(self, new_config: BackendConfig) -> None:
self.config = new_config
self.batch_queue.set_config(self.config.max_batch_size or 1,
self.config.batch_wait_timeout)
-17
View File
@@ -146,10 +146,6 @@ class ActorStateReconciler:
for replica_dict in self.backend_replicas.values()
]))
def get_replica_handles_for_backend(
self, backend_tag: BackendTag) -> List[ActorHandle]:
return list(self.backend_replicas.get(backend_tag, {}).values())
async def _start_pending_backend_replicas(
self, config_store: ConfigurationStore) -> None:
"""Starts the pending backend replicas in self.backend_replicas_to_start.
@@ -837,7 +833,6 @@ class ServeController:
# Set the backend config inside the router
# (particularly for max_concurrent_queries).
self.notify_backend_configs_changed()
await self.broadcast_backend_config(backend_tag)
async def delete_backend(self, backend_tag: BackendTag) -> None:
async with self.write_lock:
@@ -914,18 +909,6 @@ class ServeController:
self.notify_replica_handles_changed()
self.notify_backend_configs_changed()
await self.broadcast_backend_config(backend_tag)
async def broadcast_backend_config(self, backend_tag: BackendTag) -> None:
backend_config = self.configuration_store.get_backend(
backend_tag).backend_config
broadcast_futures = [
replica.update_config.remote(backend_config).as_future()
for replica in
self.actor_reconciler.get_replica_handles_for_backend(backend_tag)
]
await asyncio.gather(*broadcast_futures)
def get_backend_config(self, backend_tag: BackendTag) -> BackendConfig:
"""Get the current config for the specified backend."""
assert (self.configuration_store.get_backend(backend_tag)
+53
View File
@@ -1,9 +1,12 @@
from collections import defaultdict
import random
import os
import pytest
import ray
from ray import serve
from ray.serve.config import BackendConfig
if os.environ.get("RAY_SERVE_INTENTIONALLY_CRASH", False) == 1:
serve.controller._CRASH_AFTER_CHECKPOINT_PROBABILITY = 0.5
@@ -32,3 +35,53 @@ def serve_instance(_shared_serve_instance):
_shared_serve_instance.delete_endpoint(endpoint)
for backend in ray.get(controller.get_all_backends.remote()).keys():
_shared_serve_instance.delete_backend(backend)
@pytest.fixture
def mock_controller_with_name():
@ray.remote(num_cpus=0)
class MockControllerActor:
def __init__(self):
from ray.serve.long_poll import LongPollerHost
self.host = LongPollerHost()
self.backend_replicas = defaultdict(list)
self.backend_configs = dict()
self.clear()
def clear(self):
self.host.notify_changed("worker_handles", {})
self.host.notify_changed("traffic_policies", {})
self.host.notify_changed("backend_configs", {})
async def listen_for_change(self, snapshot_ids):
return await self.host.listen_for_change(snapshot_ids)
def set_traffic(self, endpoint, traffic_policy):
self.host.notify_changed("traffic_policies",
{endpoint: traffic_policy})
def add_new_replica(self,
backend_tag,
runner_actor,
backend_config=BackendConfig()):
self.backend_replicas[backend_tag].append(runner_actor)
self.backend_configs[backend_tag] = backend_config
self.host.notify_changed(
"worker_handles",
self.backend_replicas,
)
self.host.notify_changed("backend_configs", self.backend_configs)
def update_backend(self, backend_tag: str,
backend_config: BackendConfig):
self.backend_configs[backend_tag] = backend_config
self.host.notify_changed("backend_configs", self.backend_configs)
name = f"MockController{random.randint(0,10e4)}"
yield name, MockControllerActor.options(name=name).remote()
@pytest.fixture
def mock_controller(mock_controller_with_name):
yield mock_controller_with_name[1]
+81 -20
View File
@@ -45,8 +45,9 @@ def setup_worker(name,
return worker
async def add_servable_to_router(servable, router, **kwargs):
worker = setup_worker("backend", servable, **kwargs)
async def add_servable_to_router(servable, router, controller_name, **kwargs):
worker = setup_worker(
"backend", servable, controller_name=controller_name, **kwargs)
await router._update_worker_handles.remote({"backend": [worker]})
await router._update_traffic_policies.remote({
"endpoint": TrafficPolicy({
@@ -81,11 +82,12 @@ async def test_runner_wraps_error():
assert isinstance(wrapped, ray.exceptions.RayTaskError)
async def test_servable_function(serve_instance, router):
async def test_servable_function(serve_instance, router,
mock_controller_with_name):
def echo(request):
return request.args["i"]
_ = await add_servable_to_router(echo, router)
await add_servable_to_router(echo, router, mock_controller_with_name[0])
for query in [333, 444, 555]:
query_param = make_request_param()
@@ -94,7 +96,8 @@ async def test_servable_function(serve_instance, router):
assert result == query
async def test_servable_class(serve_instance, router):
async def test_servable_class(serve_instance, router,
mock_controller_with_name):
class MyAdder:
def __init__(self, inc):
self.increment = inc
@@ -102,7 +105,8 @@ async def test_servable_class(serve_instance, router):
def __call__(self, request):
return request.args["i"] + self.increment
_ = await add_servable_to_router(MyAdder, router, init_args=(3, ))
await add_servable_to_router(
MyAdder, router, mock_controller_with_name[0], init_args=(3, ))
for query in [333, 444, 555]:
query_param = make_request_param()
@@ -111,7 +115,8 @@ async def test_servable_class(serve_instance, router):
assert result == query + 3
async def test_task_runner_custom_method_single(serve_instance, router):
async def test_task_runner_custom_method_single(serve_instance, router,
mock_controller_with_name):
class NonBatcher:
def a(self, _):
return "a"
@@ -119,7 +124,8 @@ async def test_task_runner_custom_method_single(serve_instance, router):
def b(self, _):
return "b"
_ = await add_servable_to_router(NonBatcher, router)
await add_servable_to_router(NonBatcher, router,
mock_controller_with_name[0])
query_param = make_request_param("a")
a_result = await (await router.assign_request.remote(query_param))
@@ -134,7 +140,8 @@ async def test_task_runner_custom_method_single(serve_instance, router):
await (await router.assign_request.remote(query_param))
async def test_task_runner_custom_method_batch(serve_instance, router):
async def test_task_runner_custom_method_batch(serve_instance, router,
mock_controller_with_name):
@serve.accept_batch
class Batcher:
def a(self, requests):
@@ -147,8 +154,11 @@ async def test_task_runner_custom_method_batch(serve_instance, router):
max_batch_size=4,
batch_wait_timeout=10,
internal_metadata=BackendMetadata(accepts_batches=True))
_ = await add_servable_to_router(
Batcher, router, backend_config=backend_config)
await add_servable_to_router(
Batcher,
router,
mock_controller_with_name[0],
backend_config=backend_config)
a_query_param = make_request_param("a")
b_query_param = make_request_param("b")
@@ -164,7 +174,8 @@ async def test_task_runner_custom_method_batch(serve_instance, router):
assert set(gathered) == {"a-0", "a-1", "b-0", "b-1"}
async def test_servable_batch_error(serve_instance, router):
async def test_servable_batch_error(serve_instance, router,
mock_controller_with_name):
@serve.accept_batch
class ErrorBatcher:
def error_different_size(self, requests):
@@ -179,8 +190,11 @@ async def test_servable_batch_error(serve_instance, router):
backend_config = BackendConfig(
max_batch_size=4,
internal_metadata=BackendMetadata(accepts_batches=True))
_ = await add_servable_to_router(
ErrorBatcher, router, backend_config=backend_config)
await add_servable_to_router(
ErrorBatcher,
router,
mock_controller_with_name[0],
backend_config=backend_config)
with pytest.raises(RayServeException, match="doesn't preserve batch size"):
different_size = make_request_param("error_different_size")
@@ -195,7 +209,8 @@ async def test_servable_batch_error(serve_instance, router):
assert isinstance(result_np_value, np.int32)
async def test_task_runner_perform_batch(serve_instance, router):
async def test_task_runner_perform_batch(serve_instance, router,
mock_controller_with_name):
def batcher(requests):
batch_size = len(requests)
return [batch_size] * batch_size
@@ -205,7 +220,8 @@ async def test_task_runner_perform_batch(serve_instance, router):
batch_wait_timeout=10,
internal_metadata=BackendMetadata(accepts_batches=True))
_ = await add_servable_to_router(batcher, router, backend_config=config)
await add_servable_to_router(
batcher, router, mock_controller_with_name[0], backend_config=config)
query_param = make_request_param()
my_batch_sizes = await asyncio.gather(*[(
@@ -213,7 +229,8 @@ async def test_task_runner_perform_batch(serve_instance, router):
assert my_batch_sizes == [2, 2, 1]
async def test_task_runner_perform_async(serve_instance, router):
async def test_task_runner_perform_async(serve_instance, router,
mock_controller_with_name):
@ray.remote
class Barrier:
def __init__(self, release_on):
@@ -238,8 +255,11 @@ async def test_task_runner_perform_async(serve_instance, router):
max_concurrent_queries=10,
internal_metadata=BackendMetadata(is_blocking=False))
_ = await add_servable_to_router(
wait_and_go, router, backend_config=config)
await add_servable_to_router(
wait_and_go,
router,
mock_controller_with_name[0],
backend_config=config)
query_param = make_request_param()
@@ -248,7 +268,48 @@ async def test_task_runner_perform_async(serve_instance, router):
timeout=10)
assert len(done) == 10
for item in done:
await item == "done!"
assert await item == "done!"
async def test_user_config_update(serve_instance, router,
mock_controller_with_name):
class Customizable:
def __init__(self):
self.reval = ""
def __call__(self, flask_request):
return self.retval
def reconfigure(self, config):
self.retval = config["return_val"]
config = BackendConfig(
num_replicas=2, user_config={
"return_val": "original",
"b": 2
})
await add_servable_to_router(
Customizable,
router,
mock_controller_with_name[0],
backend_config=config)
query_param = make_request_param()
done = [(await router.assign_request.remote(query_param))
for _ in range(10)]
for i in done:
assert await i == "original"
config = BackendConfig()
config.user_config = {"return_val": "new_val"}
await mock_controller_with_name[1].update_backend.remote("backend", config)
done = [(await router.assign_request.remote(query_param))
for _ in range(10)]
for i in done:
assert await i == "new_val"
if __name__ == "__main__":
-40
View File
@@ -10,7 +10,6 @@ import pytest
from ray.serve.context import TaskContext
import ray
from ray.serve.config import BackendConfig
from ray.serve.controller import TrafficPolicy
from ray.serve.router import Query, ReplicaSet, RequestMetadata, Router
from ray.serve.utils import get_random_letters
@@ -61,45 +60,6 @@ def task_runner_mock_actor():
yield mock_task_runner()
@pytest.fixture
def mock_controller():
@ray.remote(num_cpus=0)
class MockControllerActor:
def __init__(self):
from ray.serve.long_poll import LongPollerHost
self.host = LongPollerHost()
self.backend_replicas = defaultdict(list)
self.backend_configs = dict()
self.clear()
def clear(self):
self.host.notify_changed("worker_handles", {})
self.host.notify_changed("traffic_policies", {})
self.host.notify_changed("backend_configs", {})
async def listen_for_change(self, snapshot_ids):
return await self.host.listen_for_change(snapshot_ids)
def set_traffic(self, endpoint, traffic_policy):
self.host.notify_changed("traffic_policies",
{endpoint: traffic_policy})
def add_new_replica(self,
backend_tag,
runner_actor,
backend_config=BackendConfig()):
self.backend_replicas[backend_tag].append(runner_actor)
self.backend_configs[backend_tag] = backend_config
self.host.notify_changed(
"worker_handles",
self.backend_replicas,
)
self.host.notify_changed("backend_configs", self.backend_configs)
yield MockControllerActor.remote()
async def test_simple_endpoint_backend_pair(ray_instance, mock_controller,
task_runner_mock_actor):
q = ray.remote(Router).remote(mock_controller)