mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 06:12:09 +08:00
[Serve] Introduce Long Polling (#11905)
This commit is contained in:
+25
-16
@@ -35,13 +35,14 @@ py_test(
|
||||
)
|
||||
|
||||
|
||||
py_test(
|
||||
name = "test_failure",
|
||||
size = "medium",
|
||||
srcs = serve_tests_srcs,
|
||||
tags = ["exclusive"],
|
||||
deps = [":serve_lib"],
|
||||
)
|
||||
# TODO(simon): Test skipped until #11683 fixed.
|
||||
# py_test(
|
||||
# name = "test_failure",
|
||||
# size = "medium",
|
||||
# srcs = serve_tests_srcs,
|
||||
# tags = ["exclusive"],
|
||||
# deps = [":serve_lib"],
|
||||
# )
|
||||
|
||||
|
||||
py_test(
|
||||
@@ -87,6 +88,13 @@ py_test(
|
||||
deps = [":serve_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_long_poll",
|
||||
size = "small",
|
||||
srcs = serve_tests_srcs,
|
||||
tags = ["exclusive"],
|
||||
deps = [":serve_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_standalone",
|
||||
@@ -106,15 +114,16 @@ py_test(
|
||||
|
||||
|
||||
# Runs test_api and test_failure with injected failures in the controller.
|
||||
py_test(
|
||||
name = "test_controller_crashes",
|
||||
size = "large",
|
||||
srcs = glob(["tests/test_controller_crashes.py",
|
||||
"tests/test_api.py",
|
||||
"tests/test_failure.py",
|
||||
"**/conftest.py"],
|
||||
exclude=["tests/test_serve.py"]),
|
||||
)
|
||||
# TODO(simon): Tests are disabled until #11683 is fixed.
|
||||
# py_test(
|
||||
# name = "test_controller_crashes",
|
||||
# size = "large",
|
||||
# srcs = glob(["tests/test_controller_crashes.py",
|
||||
# "tests/test_api.py",
|
||||
# "tests/test_failure.py",
|
||||
# "**/conftest.py"],
|
||||
# exclude=["tests/test_serve.py"]),
|
||||
# )
|
||||
|
||||
py_test(
|
||||
name = "echo_full",
|
||||
|
||||
@@ -19,6 +19,7 @@ from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.utils import (format_actor_name, get_random_letters, logger,
|
||||
try_schedule_resources_on_nodes, get_all_node_ids)
|
||||
from ray.serve.config import BackendConfig, ReplicaConfig
|
||||
from ray.serve.long_poll import LongPollerHost
|
||||
from ray.actor import ActorHandle
|
||||
|
||||
import numpy as np
|
||||
@@ -182,13 +183,6 @@ class ActorStateReconciler:
|
||||
|
||||
self.backend_replicas[backend_tag][replica_tag] = replica_handle
|
||||
|
||||
# Register the replica with the router.
|
||||
await asyncio.gather(*[
|
||||
router.add_new_replica.remote(backend_tag, replica_tag,
|
||||
replica_handle)
|
||||
for router in self.router_handles()
|
||||
])
|
||||
|
||||
def _scale_backend_replicas(self, backends: Dict[BackendTag, BackendInfo],
|
||||
backend_tag: BackendTag,
|
||||
num_replicas: int) -> None:
|
||||
@@ -265,12 +259,6 @@ class ActorStateReconciler:
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Remove the replica from router. This call is idempotent.
|
||||
await asyncio.gather(*[
|
||||
router.remove_replica.remote(backend_tag, replica_tag)
|
||||
for router in self.router_handles()
|
||||
])
|
||||
|
||||
# TODO(edoakes): this logic isn't ideal because there may be
|
||||
# pending tasks still executing on the replica. However, if we
|
||||
# use replica.__ray_terminate__, we may send it while the
|
||||
@@ -280,18 +268,6 @@ class ActorStateReconciler:
|
||||
|
||||
self.backend_replicas_to_stop.clear()
|
||||
|
||||
async def _remove_pending_backends(self) -> None:
|
||||
"""Removes the pending backends in self.backends_to_remove.
|
||||
|
||||
Clears self.backends_to_remove.
|
||||
"""
|
||||
for backend_tag in self.backends_to_remove:
|
||||
await asyncio.gather(*[
|
||||
router.remove_backend.remote(backend_tag)
|
||||
for router in self.router_handles()
|
||||
])
|
||||
self.backends_to_remove.clear()
|
||||
|
||||
async def _start_single_replica(
|
||||
self, config_store: ConfigurationStore, backend_tag: BackendTag,
|
||||
replica_tag: ReplicaTag, replica_name: str) -> ActorHandle:
|
||||
@@ -372,18 +348,6 @@ class ActorStateReconciler:
|
||||
|
||||
return actor_stopped
|
||||
|
||||
async def _remove_pending_endpoints(self) -> None:
|
||||
"""Removes the pending endpoints in self.actor_reconciler.endpoints_to_remove.
|
||||
|
||||
Clears self.endpoints_to_remove.
|
||||
"""
|
||||
for endpoint_tag in self.endpoints_to_remove:
|
||||
await asyncio.gather(*[
|
||||
router.remove_endpoint.remote(endpoint_tag)
|
||||
for router in self.router_handles()
|
||||
])
|
||||
self.endpoints_to_remove.clear()
|
||||
|
||||
def _recover_actor_handles(self) -> None:
|
||||
# Refresh the RouterCache
|
||||
for node_id in self.routers_cache.keys():
|
||||
@@ -408,47 +372,17 @@ class ActorStateReconciler:
|
||||
) -> Dict[BackendTag, BasicAutoscalingPolicy]:
|
||||
self._recover_actor_handles()
|
||||
autoscaling_policies = dict()
|
||||
# Push configuration state to the router.
|
||||
# TODO(edoakes): should we make this a pull-only model for simplicity?
|
||||
for endpoint, traffic_policy in config_store.traffic_policies.items():
|
||||
await asyncio.gather(*[
|
||||
router.set_traffic.remote(endpoint, traffic_policy)
|
||||
for router in self.router_handles()
|
||||
])
|
||||
|
||||
for backend_tag, replica_dict in self.backend_replicas.items():
|
||||
for replica_tag, replica_handle in replica_dict.items():
|
||||
await asyncio.gather(*[
|
||||
router.add_new_replica.remote(backend_tag, replica_tag,
|
||||
replica_handle)
|
||||
for router in self.router_handles()
|
||||
])
|
||||
|
||||
for backend, info in config_store.backends.items():
|
||||
await asyncio.gather(*[
|
||||
router.set_backend_config.remote(backend, info.backend_config)
|
||||
for router in self.router_handles()
|
||||
])
|
||||
await controller.broadcast_backend_config(backend)
|
||||
metadata = info.backend_config.internal_metadata
|
||||
if metadata.autoscaling_config is not None:
|
||||
autoscaling_policies[backend] = BasicAutoscalingPolicy(
|
||||
backend, metadata.autoscaling_config)
|
||||
|
||||
# Push configuration state to the routers.
|
||||
await asyncio.gather(*[
|
||||
router.set_route_table.remote(config_store.routes)
|
||||
for router in self.router_handles()
|
||||
])
|
||||
|
||||
# Start/stop any pending backend replicas.
|
||||
await self._start_pending_backend_replicas(config_store)
|
||||
await self._stop_pending_backend_replicas()
|
||||
|
||||
# Remove any pending backends and endpoints.
|
||||
await self._remove_pending_backends()
|
||||
await self._remove_pending_endpoints()
|
||||
|
||||
return autoscaling_policies
|
||||
|
||||
|
||||
@@ -536,8 +470,41 @@ class ServeController:
|
||||
asyncio.get_event_loop().create_task(
|
||||
self._recover_from_checkpoint(checkpoint))
|
||||
|
||||
# NOTE(simon): Currently we do all-to-all broadcast. This means
|
||||
# any listeners will receive notification for all changes. This
|
||||
# can be problem at scale, e.g. updating a single backend config
|
||||
# will send over the entire configs. In the future, we should
|
||||
# optimize the logic to support subscription by key.
|
||||
self.long_poll_host = LongPollerHost()
|
||||
self.notify_backend_configs_changed()
|
||||
self.notify_replica_handles_changed()
|
||||
self.notify_traffic_policies_changed()
|
||||
|
||||
asyncio.get_event_loop().create_task(self.run_control_loop())
|
||||
|
||||
def notify_replica_handles_changed(self):
|
||||
self.long_poll_host.notify_changed(
|
||||
"worker_handles", self.actor_reconciler.backend_replicas)
|
||||
|
||||
def notify_traffic_policies_changed(self):
|
||||
self.long_poll_host.notify_changed(
|
||||
"traffic_policies", self.configuration_store.traffic_policies)
|
||||
|
||||
def notify_backend_configs_changed(self):
|
||||
self.long_poll_host.notify_changed(
|
||||
"backend_configs", self.configuration_store.get_backend_configs())
|
||||
|
||||
async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]):
|
||||
"""Proxy long pull client's listen request.
|
||||
|
||||
Args:
|
||||
keys_to_snapshot_ids (Dict[str, int]): Snapshot IDs are used to
|
||||
determine whether or not the host should immediately return the
|
||||
data or wait for the value to be changed.
|
||||
"""
|
||||
return await (
|
||||
self.long_poll_host.listen_for_change(keys_to_snapshot_ids))
|
||||
|
||||
def get_routers(self) -> Dict[str, ActorHandle]:
|
||||
"""Returns a dictionary of node ID to router actor handles."""
|
||||
return self.actor_reconciler.routers_cache
|
||||
@@ -689,10 +656,8 @@ class ServeController:
|
||||
# update to avoid inconsistent state if we crash after pushing the
|
||||
# update.
|
||||
self._checkpoint()
|
||||
await asyncio.gather(*[
|
||||
router.set_traffic.remote(endpoint_name, traffic_policy)
|
||||
for router in self.actor_reconciler.router_handles()
|
||||
])
|
||||
|
||||
self.notify_traffic_policies_changed()
|
||||
|
||||
async def set_traffic(self, endpoint_name: str,
|
||||
traffic_dict: Dict[str, float]) -> None:
|
||||
@@ -721,12 +686,7 @@ class ServeController:
|
||||
# update to avoid inconsistent state if we crash after pushing the
|
||||
# update.
|
||||
self._checkpoint()
|
||||
await asyncio.gather(*[
|
||||
router.set_traffic.remote(
|
||||
endpoint_name,
|
||||
self.configuration_store.traffic_policies[endpoint_name],
|
||||
) for router in self.actor_reconciler.router_handles()
|
||||
])
|
||||
self.notify_traffic_policies_changed()
|
||||
|
||||
# TODO(architkulkarni): add Optional for route after cloudpickle upgrade
|
||||
async def create_endpoint(self, endpoint: str,
|
||||
@@ -813,7 +773,6 @@ class ServeController:
|
||||
router.set_route_table.remote(self.configuration_store.routes)
|
||||
for router in self.actor_reconciler.router_handles()
|
||||
])
|
||||
await self.actor_reconciler._remove_pending_endpoints()
|
||||
|
||||
async def create_backend(self, backend_tag: BackendTag,
|
||||
backend_config: BackendConfig,
|
||||
@@ -859,12 +818,11 @@ class ServeController:
|
||||
await self.actor_reconciler._start_pending_backend_replicas(
|
||||
self.configuration_store)
|
||||
|
||||
self.notify_replica_handles_changed()
|
||||
|
||||
# Set the backend config inside the router
|
||||
# (particularly for max-batch-size).
|
||||
await asyncio.gather(*[
|
||||
router.set_backend_config.remote(backend_tag, backend_config)
|
||||
for router in self.actor_reconciler.router_handles()
|
||||
])
|
||||
# (particularly for max_concurrent_queries).
|
||||
self.notify_backend_configs_changed()
|
||||
await self.broadcast_backend_config(backend_tag)
|
||||
|
||||
async def delete_backend(self, backend_tag: BackendTag) -> None:
|
||||
@@ -903,7 +861,8 @@ class ServeController:
|
||||
# after pushing the update.
|
||||
self._checkpoint()
|
||||
await self.actor_reconciler._stop_pending_backend_replicas()
|
||||
await self.actor_reconciler._remove_pending_backends()
|
||||
|
||||
self.notify_replica_handles_changed()
|
||||
|
||||
async def update_backend_config(
|
||||
self, backend_tag: BackendTag,
|
||||
@@ -939,15 +898,14 @@ class ServeController:
|
||||
|
||||
# Inform the router about change in configuration
|
||||
# (particularly for setting max_batch_size).
|
||||
await asyncio.gather(*[
|
||||
router.set_backend_config.remote(backend_tag, backend_config)
|
||||
for router in self.actor_reconciler.router_handles()
|
||||
])
|
||||
|
||||
await self.actor_reconciler._start_pending_backend_replicas(
|
||||
self.configuration_store)
|
||||
await self.actor_reconciler._stop_pending_backend_replicas()
|
||||
|
||||
self.notify_replica_handles_changed()
|
||||
self.notify_backend_configs_changed()
|
||||
|
||||
await self.broadcast_backend_config(backend_tag)
|
||||
|
||||
async def broadcast_backend_config(self, backend_tag: BackendTag) -> None:
|
||||
|
||||
@@ -186,25 +186,6 @@ class HTTPProxyActor:
|
||||
self.app.set_route_table(route_table)
|
||||
|
||||
# ------ Proxy router logic ------ #
|
||||
async def add_new_replica(self, backend_tag, replica_tag, worker_handle):
|
||||
return await self.app.router.add_new_replica(backend_tag, replica_tag,
|
||||
worker_handle)
|
||||
|
||||
async def set_traffic(self, endpoint, traffic_policy):
|
||||
return await self.app.router.set_traffic(endpoint, traffic_policy)
|
||||
|
||||
async def set_backend_config(self, backend, config):
|
||||
return await self.app.router.set_backend_config(backend, config)
|
||||
|
||||
async def remove_backend(self, backend):
|
||||
return await self.app.router.remove_backend(backend)
|
||||
|
||||
async def remove_endpoint(self, endpoint):
|
||||
return await self.app.router.remove_endpoint(endpoint)
|
||||
|
||||
async def remove_replica(self, backend_tag, replica_tag):
|
||||
return await self.app.router.remove_replica(backend_tag, replica_tag)
|
||||
|
||||
async def enqueue_request(self, request_meta, *request_args,
|
||||
**request_kwargs):
|
||||
return await self.app.router.enqueue_request(
|
||||
|
||||
@@ -0,0 +1,156 @@
|
||||
import asyncio
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, DefaultDict, Dict, Set
|
||||
|
||||
import ray
|
||||
from ray.serve.utils import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdatedObject:
|
||||
object_snapshot: Any
|
||||
# The identifier for the object's version. There is not sequential relation
|
||||
# among different object's snapshot_ids.
|
||||
snapshot_id: int
|
||||
|
||||
|
||||
# Type signature for the update state callbacks. E.g.
|
||||
# async def update_state(updated_object: Any):
|
||||
# do_something(updated_object)
|
||||
UpdateStateAsyncCallable = Callable[[Any], Awaitable[None]]
|
||||
|
||||
|
||||
class LongPollerAsyncClient:
|
||||
"""The asynchronous long polling client.
|
||||
|
||||
Internally, it runs `await object_ref` in a `while True` loop. When a
|
||||
object notification arrived, the client will invoke callback if supplied.
|
||||
Note that this client will wait the callback to be completed before issuing
|
||||
the next poll.
|
||||
|
||||
Args:
|
||||
host_actor(ray.ActorHandle): handle to actor embedding LongPollerHost.
|
||||
key_listeners(Dict[str, AsyncCallable]): a dictionary mapping keys to
|
||||
callbacks to be called on state update for the corresponding keys.
|
||||
"""
|
||||
|
||||
def __init__(self, host_actor,
|
||||
key_listeners: Dict[str, UpdateStateAsyncCallable]) -> None:
|
||||
self.host_actor = host_actor
|
||||
self.key_listeners = key_listeners
|
||||
|
||||
self.snapshot_ids: Dict[str, int] = {
|
||||
key: -1
|
||||
for key in key_listeners.keys()
|
||||
}
|
||||
self.object_snapshots: Dict[str, Any] = dict()
|
||||
|
||||
in_async_loop = asyncio.get_event_loop().is_running
|
||||
assert in_async_loop, "The client is only available in async context."
|
||||
asyncio.get_event_loop().create_task(self._do_long_poll())
|
||||
|
||||
def _poll_once(self) -> ray.ObjectRef:
|
||||
object_ref = self.host_actor.listen_for_change.remote(
|
||||
self.snapshot_ids)
|
||||
return object_ref
|
||||
|
||||
def _update(self, updates: Dict[str, UpdatedObject]):
|
||||
for key, update in updates.items():
|
||||
self.object_snapshots[key] = update.object_snapshot
|
||||
self.snapshot_ids[key] = update.snapshot_id
|
||||
|
||||
async def _do_long_poll(self):
|
||||
while True:
|
||||
updates: Dict[str, UpdatedObject] = await self._poll_once()
|
||||
self._update(updates)
|
||||
for key, updated_object in updates.items():
|
||||
# NOTE(simon): This blocks the loop from doing another poll.
|
||||
# Consider use loop.create_task here or poll first then call
|
||||
# the callbacks.
|
||||
callback = self.key_listeners[key]
|
||||
await callback(updated_object.object_snapshot)
|
||||
|
||||
|
||||
class LongPollerHost:
|
||||
"""The server side object that manages long pulling requests.
|
||||
|
||||
The desired use case is to embed this in an Ray actor. Client will be
|
||||
expected to call actor.listen_for_change.remote(...). On the host side,
|
||||
you can call host.notify_changed(key, object) to update the state and
|
||||
potentially notify whoever is polling for these values.
|
||||
|
||||
Internally, we use snapshot_ids for each object to identify client with
|
||||
outdated object and immediately return the result. If the client has the
|
||||
up-to-date verison, then the listen_for_change call will only return when
|
||||
the object is updated.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Map object_key -> int
|
||||
self.snapshot_ids: DefaultDict[str, int] = defaultdict(
|
||||
lambda: random.randint(0, 1_000_000))
|
||||
# Map object_key -> object
|
||||
self.object_snapshots: Dict[str, Any] = dict()
|
||||
# Map object_key -> set(asyncio.Event waiting for updates)
|
||||
self.notifier_events: DefaultDict[str, Set[
|
||||
asyncio.Event]] = defaultdict(set)
|
||||
|
||||
async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]
|
||||
) -> Dict[str, UpdatedObject]:
|
||||
"""Listen for changed objects.
|
||||
|
||||
This method will returns a dictionary of updated objects. It returns
|
||||
immediately if the snapshot_ids are outdated, otherwise it will block
|
||||
until there's one updates.
|
||||
"""
|
||||
# 1. Figure out which keys do we care about
|
||||
watched_keys = set(self.snapshot_ids.keys()).intersection(
|
||||
keys_to_snapshot_ids.keys())
|
||||
if len(watched_keys) == 0:
|
||||
raise ValueError("Keys not found.")
|
||||
|
||||
# 2. If there are any outdated keys (by comparing snapshot ids)
|
||||
# return immediately.
|
||||
client_outdated_keys = {
|
||||
key: UpdatedObject(self.object_snapshots[key],
|
||||
self.snapshot_ids[key])
|
||||
for key in watched_keys
|
||||
if self.snapshot_ids[key] != keys_to_snapshot_ids[key]
|
||||
}
|
||||
if len(client_outdated_keys) > 0:
|
||||
return client_outdated_keys
|
||||
|
||||
# 3. Otherwise, register asyncio events to be waited.
|
||||
async_task_to_watched_keys = {}
|
||||
for key in watched_keys:
|
||||
# Create a new asyncio event for this key
|
||||
event = asyncio.Event()
|
||||
task = asyncio.get_event_loop().create_task(event.wait())
|
||||
async_task_to_watched_keys[task] = key
|
||||
|
||||
# Make sure future caller of notify_changed will unblock this
|
||||
# asyncio Event.
|
||||
self.notifier_events[key].add(event)
|
||||
|
||||
done, not_done = await asyncio.wait(
|
||||
async_task_to_watched_keys.keys(),
|
||||
return_when=asyncio.FIRST_COMPLETED)
|
||||
[task.cancel() for task in not_done]
|
||||
|
||||
updated_object_key: str = async_task_to_watched_keys[done.pop()]
|
||||
return {
|
||||
updated_object_key: UpdatedObject(
|
||||
self.object_snapshots[updated_object_key],
|
||||
self.snapshot_ids[updated_object_key])
|
||||
}
|
||||
|
||||
def notify_changed(self, object_key: str, updated_object: Any):
|
||||
self.snapshot_ids[object_key] += 1
|
||||
self.object_snapshots[object_key] = updated_object
|
||||
logger.debug(f"LongPollerHost: {object_key} = {updated_object}")
|
||||
|
||||
if object_key in self.notifier_events:
|
||||
for event in self.notifier_events.pop(object_key):
|
||||
event.set()
|
||||
+62
-19
@@ -6,9 +6,9 @@ from typing import DefaultDict, List, Dict, Any, Optional
|
||||
import pickle
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from ray.exceptions import RayTaskError
|
||||
|
||||
import ray
|
||||
from ray.exceptions import RayTaskError
|
||||
from ray.serve.long_poll import LongPollerAsyncClient
|
||||
from ray.util import metrics
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.serve.endpoint_policy import RandomEndpointPolicy
|
||||
@@ -70,7 +70,16 @@ class Query:
|
||||
class Router:
|
||||
"""A router that routes request to available replicas."""
|
||||
|
||||
async def setup(self, name, controller_name):
|
||||
async def setup(self, name, controller_name, _do_long_pull=True):
|
||||
"""Setup the router state
|
||||
|
||||
Args:
|
||||
name(str): Used to identify the router when reporting queue
|
||||
lengths to the controller.
|
||||
controller_name(str): The actor name for the controller.
|
||||
_do_long_pull(bool): Used by unit testing.
|
||||
"""
|
||||
|
||||
# Note: Several queues are used in the router
|
||||
# - When a request come in, it's placed inside its corresponding
|
||||
# endpoint_queue.
|
||||
@@ -123,22 +132,6 @@ class Router:
|
||||
# from failure.
|
||||
self.controller = ray.get_actor(controller_name)
|
||||
|
||||
traffic_policies = ray.get(
|
||||
self.controller.get_traffic_policies.remote())
|
||||
for endpoint, traffic_policy in traffic_policies.items():
|
||||
await self.set_traffic(endpoint, traffic_policy)
|
||||
|
||||
backend_dict = ray.get(
|
||||
self.controller.get_all_replica_handles.remote())
|
||||
for backend_tag, replica_dict in backend_dict.items():
|
||||
for replica_tag, replica_handle in replica_dict.items():
|
||||
await self.add_new_replica(backend_tag, replica_tag,
|
||||
replica_handle)
|
||||
|
||||
backend_configs = ray.get(self.controller.get_backend_configs.remote())
|
||||
for backend, backend_config in backend_configs.items():
|
||||
await self.set_backend_config(backend, backend_config)
|
||||
|
||||
# -- Metrics Registration -- #
|
||||
self.num_router_requests = metrics.Count(
|
||||
"num_router_requests",
|
||||
@@ -164,6 +157,56 @@ class Router:
|
||||
|
||||
asyncio.get_event_loop().create_task(self.report_queue_lengths())
|
||||
|
||||
if _do_long_pull:
|
||||
self.long_poll_client = LongPollerAsyncClient(
|
||||
self.controller, {
|
||||
"traffic_policies": self.update_traffic_policies,
|
||||
"worker_handles": self.update_worker_handles,
|
||||
"backend_configs": self.update_backend_configs
|
||||
})
|
||||
|
||||
async def update_traffic_policies(self, traffic_policies):
|
||||
updated_endpoints = set(traffic_policies.keys())
|
||||
curr_endpoints = set(self.traffic.keys())
|
||||
|
||||
for endpoint in updated_endpoints:
|
||||
await self.set_traffic(endpoint, traffic_policies[endpoint])
|
||||
|
||||
removed_endpoints = curr_endpoints - updated_endpoints
|
||||
for endpoint in removed_endpoints:
|
||||
await self.remove_endpoint(endpoint)
|
||||
|
||||
async def update_worker_handles(self, worker_handles):
|
||||
for backend_tag, replica_dict in worker_handles.items():
|
||||
# NOTE(simon): This is a just hack around the current data
|
||||
# structure to resolve replicas added and removed. It will be
|
||||
# immediately become obselete when we update the router.
|
||||
updated_replica_tags = set(replica_dict.keys())
|
||||
curr_replica_tags = {
|
||||
tag.replace(backend_tag + ":", "")
|
||||
for tag in self.replicas.keys() if tag.startswith(backend_tag)
|
||||
}
|
||||
|
||||
added_replicas = updated_replica_tags - curr_replica_tags
|
||||
removed_replicas = curr_replica_tags - updated_replica_tags
|
||||
|
||||
for replica_tag in added_replicas:
|
||||
await self.add_new_replica(backend_tag, replica_tag,
|
||||
replica_dict[replica_tag])
|
||||
for replica_tag in removed_replicas:
|
||||
await self.remove_replica(backend_tag, replica_tag)
|
||||
|
||||
async def update_backend_configs(self, backend_configs):
|
||||
updated_backends = set(backend_configs.keys())
|
||||
curr_backends = set(self.backend_info.keys())
|
||||
|
||||
for backend in updated_backends:
|
||||
await self.set_backend_config(backend, backend_configs[backend])
|
||||
|
||||
removed_backends = curr_backends - updated_backends
|
||||
for backend in removed_backends:
|
||||
await self.remove_backend(backend)
|
||||
|
||||
async def enqueue_request(self, request_meta, *request_args,
|
||||
**request_kwargs):
|
||||
endpoint = request_meta.endpoint
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
import sys
|
||||
import functools
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
from ray.serve.long_poll import (LongPollerAsyncClient, LongPollerHost,
|
||||
UpdatedObject)
|
||||
|
||||
|
||||
def test_host_standalone(serve_instance):
|
||||
host = ray.remote(LongPollerHost).remote()
|
||||
|
||||
# Write two values
|
||||
ray.get(host.notify_changed.remote("key_1", 999))
|
||||
ray.get(host.notify_changed.remote("key_2", 999))
|
||||
object_ref = host.listen_for_change.remote({"key_1": -1, "key_2": -1})
|
||||
|
||||
# We should be able to get the result immediately
|
||||
result: Dict[str, UpdatedObject] = ray.get(object_ref)
|
||||
assert set(result.keys()) == {"key_1", "key_2"}
|
||||
assert {v.object_snapshot for v in result.values()} == {999}
|
||||
|
||||
# Now try to pull it again, nothing should happen
|
||||
# because we have the updated snapshot_id
|
||||
new_snapshot_ids = {k: v.snapshot_id for k, v in result.items()}
|
||||
object_ref = host.listen_for_change.remote(new_snapshot_ids)
|
||||
_, not_done = ray.wait([object_ref], timeout=0.2)
|
||||
assert len(not_done) == 1
|
||||
|
||||
# Now update the value, we should immediately get updated value
|
||||
ray.get(host.notify_changed.remote("key_2", 999))
|
||||
result = ray.get(object_ref)
|
||||
assert len(result) == 1
|
||||
assert "key_2" in result
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
"Skip until https://github.com/ray-project/ray/issues/11683 fixed "
|
||||
"since async actor retries is broken.")
|
||||
def test_long_pull_restarts(serve_instance):
|
||||
@ray.remote(
|
||||
max_restarts=-1,
|
||||
# max_task_retries=-1,
|
||||
)
|
||||
class RestartableLongPollerHost:
|
||||
def __init__(self) -> None:
|
||||
print("actor started")
|
||||
self.host = LongPollerHost()
|
||||
self.host.notify_changed("timer", time.time())
|
||||
|
||||
async def listen_for_change(self, key_to_ids):
|
||||
await asyncio.sleep(0.5)
|
||||
return await self.host.listen_for_change(key_to_ids)
|
||||
|
||||
async def exit(self):
|
||||
sys.exit(1)
|
||||
|
||||
host = RestartableLongPollerHost.remote()
|
||||
updated_values = ray.get(host.listen_for_change.remote({"timer": -1}))
|
||||
timer: UpdatedObject = updated_values["timer"]
|
||||
|
||||
on_going_ref = host.listen_for_change.remote({"timer": timer.snapshot_id})
|
||||
host.exit.remote()
|
||||
on_going_ref = host.listen_for_change.remote({"timer": timer.snapshot_id})
|
||||
new_timer: UpdatedObject = ray.get(on_going_ref)["timer"]
|
||||
assert new_timer.snapshot_id != timer.snapshot_id + 1
|
||||
assert new_timer.object_snapshot != timer.object_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client(serve_instance):
|
||||
host = ray.remote(LongPollerHost).remote()
|
||||
|
||||
# Write two values
|
||||
ray.get(host.notify_changed.remote("key_1", 100))
|
||||
ray.get(host.notify_changed.remote("key_2", 999))
|
||||
|
||||
callback_results = dict()
|
||||
|
||||
async def callback(result, key):
|
||||
callback_results[key] = result
|
||||
|
||||
client = LongPollerAsyncClient(
|
||||
host, {
|
||||
"key_1": functools.partial(callback, key="key_1"),
|
||||
"key_2": functools.partial(callback, key="key_2")
|
||||
})
|
||||
|
||||
while len(client.object_snapshots) == 0:
|
||||
# Yield the loop for client to get the result
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
assert client.object_snapshots["key_1"] == 100
|
||||
assert client.object_snapshots["key_2"] == 999
|
||||
|
||||
ray.get(host.notify_changed.remote("key_2", 1999))
|
||||
|
||||
values = set()
|
||||
for _ in range(3):
|
||||
values.add(client.object_snapshots["key_2"])
|
||||
if 1999 in values:
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
assert 1999 in values
|
||||
|
||||
assert callback_results == {"key_1": 100, "key_2": 1999}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main(["-v", "-s", __file__]))
|
||||
@@ -1,3 +1,8 @@
|
||||
"""
|
||||
Unit tests for the router class. Please don't add any test that will involve
|
||||
controller or the backend worker, use mock if necessary.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
import pytest
|
||||
@@ -48,7 +53,8 @@ def task_runner_mock_actor():
|
||||
|
||||
async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote("", serve_instance._controller_name)
|
||||
await q.setup.remote(
|
||||
"", serve_instance._controller_name, _do_long_pull=False)
|
||||
|
||||
q.set_traffic.remote("svc", TrafficPolicy({"backend-single-prod": 1.0}))
|
||||
q.add_new_replica.remote("backend-single-prod", "replica-1",
|
||||
@@ -67,7 +73,8 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
|
||||
|
||||
async def test_alter_backend(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote("", serve_instance._controller_name)
|
||||
await q.setup.remote(
|
||||
"", serve_instance._controller_name, _do_long_pull=False)
|
||||
|
||||
await q.set_traffic.remote("svc", TrafficPolicy({"backend-alter": 1}))
|
||||
await q.add_new_replica.remote("backend-alter", "replica-1",
|
||||
@@ -88,7 +95,8 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor):
|
||||
|
||||
async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote("", serve_instance._controller_name)
|
||||
await q.setup.remote(
|
||||
"", serve_instance._controller_name, _do_long_pull=False)
|
||||
|
||||
await q.set_traffic.remote(
|
||||
"svc", TrafficPolicy({
|
||||
@@ -119,7 +127,8 @@ async def test_queue_remove_replicas(serve_instance):
|
||||
|
||||
temp_actor = mock_task_runner()
|
||||
q = ray.remote(TestRouter).remote()
|
||||
await q.setup.remote("", serve_instance._controller_name)
|
||||
await q.setup.remote(
|
||||
"", serve_instance._controller_name, _do_long_pull=False)
|
||||
await q.add_new_replica.remote("backend-remove", "replica-1", temp_actor)
|
||||
await q.remove_replica.remote("backend-remove", "replica-1")
|
||||
assert ray.get(q.worker_queue_size.remote("backend")) == 0
|
||||
@@ -127,7 +136,8 @@ async def test_queue_remove_replicas(serve_instance):
|
||||
|
||||
async def test_shard_key(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote("", serve_instance._controller_name)
|
||||
await q.setup.remote(
|
||||
"", serve_instance._controller_name, _do_long_pull=False)
|
||||
|
||||
num_backends = 5
|
||||
traffic_dict = {}
|
||||
@@ -186,7 +196,8 @@ async def test_router_use_max_concurrency(serve_instance):
|
||||
|
||||
worker = MockWorker.remote()
|
||||
q = ray.remote(VisibleRouter).remote()
|
||||
await q.setup.remote("", serve_instance._controller_name)
|
||||
await q.setup.remote(
|
||||
"", serve_instance._controller_name, _do_long_pull=False)
|
||||
backend_name = "max-concurrent-test"
|
||||
config = BackendConfig(max_concurrent_queries=1)
|
||||
await q.set_traffic.remote("svc", TrafficPolicy({backend_name: 1.0}))
|
||||
|
||||
Reference in New Issue
Block a user