mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:06:31 +08:00
[Serve] Implement Graceful Shutdown (#13028)
This commit is contained in:
@@ -329,7 +329,7 @@ class Client:
|
||||
func_or_class (callable, class): a function or a class implementing
|
||||
__call__, returning a JSON-serializable object or a
|
||||
Starlette Response object.
|
||||
actor_init_args (optional): the arguments to pass to the class.
|
||||
*actor_init_args (optional): the arguments to pass to the class
|
||||
initialization method.
|
||||
ray_actor_options (optional): options to be passed into the
|
||||
@ray.remote decorator for the backend actor.
|
||||
@@ -409,12 +409,18 @@ class Client:
|
||||
return ray.get(self._controller.get_all_backends.remote())
|
||||
|
||||
@_ensure_connected
|
||||
def delete_backend(self, backend_tag: str) -> None:
|
||||
def delete_backend(self, backend_tag: str, force: bool = False) -> None:
|
||||
"""Delete the given backend.
|
||||
|
||||
The backend must not currently be used by any endpoints.
|
||||
|
||||
Args:
|
||||
backend_tag (str): The backend tag to be deleted.
|
||||
force (bool): Whether or not to force the deletion, without waiting
|
||||
for graceful shutdown. Default to false.
|
||||
"""
|
||||
self._get_result(self._controller.delete_backend.remote(backend_tag))
|
||||
self._get_result(
|
||||
self._controller.delete_backend.remote(backend_tag, force))
|
||||
|
||||
@_ensure_connected
|
||||
def set_traffic(self, endpoint_name: str,
|
||||
|
||||
@@ -126,6 +126,9 @@ def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]):
|
||||
def ready(self):
|
||||
pass
|
||||
|
||||
async def drain_pending_queries(self):
|
||||
return await self.backend.drain_pending_queries()
|
||||
|
||||
RayServeWrappedReplica.__name__ = "RayServeReplica_{}".format(
|
||||
func_or_class.__name__)
|
||||
return RayServeWrappedReplica
|
||||
@@ -410,3 +413,25 @@ class RayServeReplica:
|
||||
|
||||
self.num_ongoing_requests -= 1
|
||||
return result
|
||||
|
||||
async def drain_pending_queries(self):
|
||||
"""Perform graceful shutdown.
|
||||
|
||||
Trigger a graceful shutdown protocol that will wait for all the queued
|
||||
tasks to be completed and return to the controller.
|
||||
"""
|
||||
sleep_time = self.config.experimental_graceful_shutdown_wait_loop_s
|
||||
while True:
|
||||
# Sleep first because we want to make sure all the routers receive
|
||||
# the notification to remove this replica first.
|
||||
await asyncio.sleep(sleep_time)
|
||||
|
||||
num_queries_waiting = self.batch_queue.qsize()
|
||||
if (num_queries_waiting == 0) and (self.num_ongoing_requests == 0):
|
||||
break
|
||||
else:
|
||||
logger.info(
|
||||
f"Waiting for an additional {sleep_time}s "
|
||||
f"to shutdown replica {self.replica_tag} because "
|
||||
f"num_queries_waiting {num_queries_waiting} and "
|
||||
f"num_ongoing_requests {self.num_ongoing_requests}")
|
||||
|
||||
+25
-20
@@ -1,6 +1,6 @@
|
||||
import inspect
|
||||
|
||||
from pydantic import BaseModel, PositiveInt, validator
|
||||
from pydantic import BaseModel, PositiveInt, validator, PositiveFloat
|
||||
from ray.serve.constants import ASYNC_CONCURRENCY
|
||||
from typing import Optional, Dict, Any, List
|
||||
from dataclasses import dataclass, field
|
||||
@@ -30,25 +30,27 @@ class BackendMetadata:
|
||||
class BackendConfig(BaseModel):
|
||||
"""Configuration options for a backend, to be set by the user.
|
||||
|
||||
:param num_replicas: The number of processes to start up that will
|
||||
handle requests to this backend. Defaults to 0.
|
||||
:type num_replicas: int, optional
|
||||
:param max_batch_size: The maximum number of requests that will be
|
||||
processed in one batch by this backend. Defaults to None (no
|
||||
maximium).
|
||||
:type max_batch_size: int, optional
|
||||
:param batch_wait_timeout: The time in seconds that backend replicas will
|
||||
wait for a full batch of requests before processing a partial batch.
|
||||
Defaults to 0.
|
||||
:type batch_wait_timeout: float, optional
|
||||
:param max_concurrent_queries: The maximum number of queries that will be
|
||||
sent to a replica of this backend without receiving a response.
|
||||
Defaults to None (no maximum).
|
||||
:type max_concurrent_queries: int, optional
|
||||
:param user_config: Arguments to pass to the reconfigure method of the
|
||||
backend. The reconfigure method is called if user_config is not
|
||||
None.
|
||||
:type user_config: Any, optional
|
||||
Args:
|
||||
num_replicas (Optional[int]): The number of processes to start up that
|
||||
will handle requests to this backend. Defaults to 0.
|
||||
max_batch_size (Optional[int]): The maximum number of requests that
|
||||
will be processed in one batch by this backend. Defaults to None
|
||||
(no maximium).
|
||||
batch_wait_timeout (Optional[float]): The time in seconds that backend
|
||||
replicas will wait for a full batch of requests before processing a
|
||||
partial batch. Defaults to 0.
|
||||
max_concurrent_queries (Optional[int]): The maximum number of queries
|
||||
that will be sent to a replica of this backend without receiving a
|
||||
response. Defaults to None (no maximum).
|
||||
user_config (Optional[Any]): Arguments to pass to the reconfigure
|
||||
method of the backend. The reconfigure method is called if
|
||||
user_config is not None.
|
||||
experimental_graceful_shutdown_wait_loop_s (Optional[float]): Duration
|
||||
that backend workers will wait until there is no more work to be
|
||||
done before shutting down. Defaults to 2s.
|
||||
experimental_graceful_shutdown_timeout_s (Optional[float]):
|
||||
Controller waits for this duration to forcefully kill the replica
|
||||
for shutdown. Defaults to 20s.
|
||||
"""
|
||||
|
||||
internal_metadata: BackendMetadata = BackendMetadata()
|
||||
@@ -58,6 +60,9 @@ class BackendConfig(BaseModel):
|
||||
max_concurrent_queries: Optional[int] = None
|
||||
user_config: Any = None
|
||||
|
||||
experimental_graceful_shutdown_wait_loop_s: PositiveFloat = 2.0
|
||||
experimental_graceful_shutdown_timeout_s: PositiveFloat = 20.0
|
||||
|
||||
class Config:
|
||||
validate_assignment = True
|
||||
extra = "forbid"
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import asyncio
|
||||
from asyncio.futures import Future
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from typing import Dict, Any, List, Optional, Set, Tuple
|
||||
from uuid import uuid4, UUID
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -13,8 +14,11 @@ import ray
|
||||
import ray.cloudpickle as pickle
|
||||
from ray.serve.autoscaling_policy import BasicAutoscalingPolicy
|
||||
from ray.serve.backend_worker import create_backend_replica
|
||||
from ray.serve.constants import (ASYNC_CONCURRENCY, SERVE_PROXY_NAME,
|
||||
LongPollKey)
|
||||
from ray.serve.constants import (
|
||||
ASYNC_CONCURRENCY,
|
||||
SERVE_PROXY_NAME,
|
||||
LongPollKey,
|
||||
)
|
||||
from ray.serve.http_proxy import HTTPProxyActor
|
||||
from ray.serve.kv_store import RayInternalKVStore
|
||||
from ray.serve.exceptions import RayServeException
|
||||
@@ -46,6 +50,7 @@ EndpointTag = str
|
||||
ReplicaTag = str
|
||||
NodeId = str
|
||||
GoalId = int
|
||||
Duration = float
|
||||
|
||||
|
||||
class TrafficPolicy:
|
||||
@@ -230,8 +235,9 @@ class ActorStateReconciler:
|
||||
default_factory=lambda: defaultdict(dict))
|
||||
backend_replicas_to_start: Dict[BackendTag, List[ReplicaTag]] = field(
|
||||
default_factory=lambda: defaultdict(list))
|
||||
backend_replicas_to_stop: Dict[BackendTag, List[ReplicaTag]] = field(
|
||||
default_factory=lambda: defaultdict(list))
|
||||
backend_replicas_to_stop: Dict[BackendTag, List[Tuple[
|
||||
ReplicaTag, Duration]]] = field(
|
||||
default_factory=lambda: defaultdict(list))
|
||||
backends_to_remove: List[BackendTag] = field(default_factory=list)
|
||||
|
||||
# NOTE(ilr): These are not checkpointed, but will be recreated by
|
||||
@@ -300,9 +306,13 @@ class ActorStateReconciler:
|
||||
|
||||
return replica_handle
|
||||
|
||||
def _scale_backend_replicas(self, backends: Dict[BackendTag, BackendInfo],
|
||||
backend_tag: BackendTag,
|
||||
num_replicas: int) -> None:
|
||||
def _scale_backend_replicas(
|
||||
self,
|
||||
backends: Dict[BackendTag, BackendInfo],
|
||||
backend_tag: BackendTag,
|
||||
num_replicas: int,
|
||||
force_kill: bool = False,
|
||||
) -> None:
|
||||
"""Scale the given backend to the number of replicas.
|
||||
|
||||
NOTE: this does not actually start or stop the replicas, but instead
|
||||
@@ -323,7 +333,7 @@ class ActorStateReconciler:
|
||||
current_num_replicas = len(self.backend_replicas[backend_tag])
|
||||
delta_num_replicas = num_replicas - current_num_replicas
|
||||
|
||||
backend_info = backends[backend_tag]
|
||||
backend_info: BackendInfo = backends[backend_tag]
|
||||
if delta_num_replicas > 0:
|
||||
can_schedule = try_schedule_resources_on_nodes(requirements=[
|
||||
backend_info.replica_config.resource_dict
|
||||
@@ -357,7 +367,14 @@ class ActorStateReconciler:
|
||||
if len(self.backend_replicas[backend_tag]) == 0:
|
||||
del self.backend_replicas[backend_tag]
|
||||
|
||||
self.backend_replicas_to_stop[backend_tag].append(replica_tag)
|
||||
graceful_timeout_s = (backend_info.backend_config.
|
||||
experimental_graceful_shutdown_timeout_s)
|
||||
if force_kill:
|
||||
graceful_timeout_s = 0
|
||||
self.backend_replicas_to_stop[backend_tag].append((
|
||||
replica_tag,
|
||||
graceful_timeout_s,
|
||||
))
|
||||
|
||||
async def _enqueue_pending_scale_changes_loop(self,
|
||||
backend_state: BackendState):
|
||||
@@ -370,9 +387,9 @@ class ActorStateReconciler:
|
||||
self.currently_starting_replicas[ready_future] = (
|
||||
backend_tag, replica_tag, replica_handle)
|
||||
|
||||
for backend_tag, replicas_to_stop in self.backend_replicas_to_stop.\
|
||||
items():
|
||||
for replica_tag in replicas_to_stop:
|
||||
for backend_tag, replicas_to_stop in (
|
||||
self.backend_replicas_to_stop.items()):
|
||||
for replica_tag, shutdown_timeout in replicas_to_stop:
|
||||
replica_name = format_actor_name(replica_tag,
|
||||
self.controller_name)
|
||||
|
||||
@@ -384,19 +401,24 @@ class ActorStateReconciler:
|
||||
except ValueError:
|
||||
return
|
||||
|
||||
# TODO(edoakes): this logic isn't ideal because there may
|
||||
# be pending tasks still executing on the replica. However,
|
||||
# if we use replica.__ray_terminate__, we may send it while
|
||||
# the replica is being restarted and there's no way to tell
|
||||
# if it successfully killed the worker or not.
|
||||
ray.kill(replica, no_restart=True)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
replica.drain_pending_queries.remote(),
|
||||
timeout=shutdown_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
# Graceful period passed, kill it forcefully.
|
||||
logger.debug(
|
||||
f"{replica_name_to_use} did not shutdown after "
|
||||
f"{shutdown_timeout}s, killing.")
|
||||
finally:
|
||||
ray.kill(replica, no_restart=True)
|
||||
|
||||
self.currently_stopping_replicas[asyncio.ensure_future(
|
||||
kill_actor(replica_name))] = (backend_tag, replica_tag)
|
||||
|
||||
async def _check_currently_starting_replicas(self) -> bool:
|
||||
"""Returns a boolean specifying if there are more replicas to start"""
|
||||
in_flight = list()
|
||||
async def _check_currently_starting_replicas(self) -> int:
|
||||
"""Returns the number of pending replicas waiting to start"""
|
||||
in_flight: Set[Future[Any]] = set()
|
||||
|
||||
if self.currently_starting_replicas:
|
||||
done, in_flight = await asyncio.wait(
|
||||
@@ -415,11 +437,12 @@ class ActorStateReconciler:
|
||||
pass
|
||||
if len(backend) == 0:
|
||||
del self.backend_replicas_to_start[backend_tag]
|
||||
return len(in_flight) > 0
|
||||
return len(in_flight)
|
||||
|
||||
async def _check_currently_stopping_replicas(self) -> int:
|
||||
"""Returns the number of replicas waiting to stop"""
|
||||
in_flight: Set[Future[Any]] = set()
|
||||
|
||||
async def _check_currently_stopping_replicas(self) -> bool:
|
||||
"""Returns a boolean specifying if there are more replicas to stop"""
|
||||
in_flight = list()
|
||||
if self.currently_stopping_replicas:
|
||||
done_stoppping, in_flight = await asyncio.wait(
|
||||
list(self.currently_stopping_replicas.keys()), timeout=0)
|
||||
@@ -437,22 +460,27 @@ class ActorStateReconciler:
|
||||
if len(backend) == 0:
|
||||
del self.backend_replicas_to_stop[backend_tag]
|
||||
|
||||
return len(in_flight) > 0
|
||||
return len(in_flight)
|
||||
|
||||
async def backend_control_loop(self):
|
||||
start = time.time()
|
||||
prev_warning = start
|
||||
need_to_continue = True
|
||||
num_pending_starts, num_pending_stops = 0, 0
|
||||
while need_to_continue:
|
||||
if time.time() - prev_warning > REPLICA_STARTUP_TIME_WARNING_S:
|
||||
prev_warning = time.time()
|
||||
logger.warning("Waited {:.2f}s for replicas to start up. Make "
|
||||
"sure there are enough resources to create the "
|
||||
"replicas.".format(time.time() - start))
|
||||
delta = time.time() - start
|
||||
logger.warning(
|
||||
f"Waited {delta:.2f}s for {num_pending_starts} replicas "
|
||||
f"to start up or {num_pending_stops} replicas to shutdown."
|
||||
" Make sure there are enough resources to create the "
|
||||
"replicas.")
|
||||
|
||||
need_to_continue = (
|
||||
await self._check_currently_starting_replicas()
|
||||
or await self._check_currently_stopping_replicas())
|
||||
num_pending_starts = await self._check_currently_starting_replicas(
|
||||
)
|
||||
num_pending_stops = await self._check_currently_stopping_replicas()
|
||||
need_to_continue = num_pending_starts or num_pending_stops
|
||||
|
||||
asyncio.sleep(1)
|
||||
|
||||
@@ -952,7 +980,9 @@ class ServeController:
|
||||
self.notify_backend_configs_changed()
|
||||
return return_uuid
|
||||
|
||||
async def delete_backend(self, backend_tag: BackendTag) -> UUID:
|
||||
async def delete_backend(self,
|
||||
backend_tag: BackendTag,
|
||||
force_kill: bool = False) -> UUID:
|
||||
async with self.write_lock:
|
||||
# This method must be idempotent. We should validate that the
|
||||
# specified backend exists on the client.
|
||||
@@ -975,7 +1005,7 @@ class ServeController:
|
||||
|
||||
# This should be a call to the control loop
|
||||
self.actor_reconciler._scale_backend_replicas(
|
||||
self.backend_state.backends, backend_tag, 0)
|
||||
self.backend_state.backends, backend_tag, 0, force_kill)
|
||||
|
||||
# Remove the backend's metadata.
|
||||
del self.backend_state.backends[backend_tag]
|
||||
|
||||
+50
-19
@@ -1,8 +1,11 @@
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, DefaultDict, Dict, Iterable, List, Optional
|
||||
from typing import Any, ChainMap, Dict, Iterable, List, Optional
|
||||
|
||||
from ray.serve.exceptions import RayServeException
|
||||
|
||||
import ray
|
||||
from ray.actor import ActorHandle
|
||||
@@ -10,7 +13,7 @@ from ray.serve.constants import LongPollKey
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.serve.endpoint_policy import EndpointPolicy, RandomEndpointPolicy
|
||||
from ray.serve.long_poll import LongPollAsyncClient
|
||||
from ray.serve.utils import logger
|
||||
from ray.serve.utils import logger, compute_dict_delta, compute_iterable_delta
|
||||
from ray.util import metrics
|
||||
|
||||
REPORT_QUEUE_LENGTH_PERIOD_S = 1.0
|
||||
@@ -76,22 +79,16 @@ class ReplicaSet:
|
||||
self.config_updated_event.set()
|
||||
|
||||
def update_worker_replicas(self, worker_replicas: Iterable[ActorHandle]):
|
||||
current_replica_set = set(self.in_flight_queries.keys())
|
||||
updated_replica_set = set(worker_replicas)
|
||||
added, removed, _ = compute_iterable_delta(
|
||||
self.in_flight_queries.keys(), worker_replicas)
|
||||
|
||||
added = updated_replica_set - current_replica_set
|
||||
for new_replica_handle in added:
|
||||
self.in_flight_queries[new_replica_handle] = set()
|
||||
|
||||
removed = current_replica_set - updated_replica_set
|
||||
for removed_replica_handle in removed:
|
||||
# NOTE(simon): Do we warn if there are still inflight queries?
|
||||
# The current approach is no because the queries objectrefs are
|
||||
# just used to perform backpressure. Caller should decide what to
|
||||
# do with the object refs.
|
||||
# Delete it directly because shutdown is processed by controller.
|
||||
del self.in_flight_queries[removed_replica_handle]
|
||||
|
||||
# State changed, reset the round robin iterator
|
||||
if len(added) > 0 or len(removed) > 0:
|
||||
self.replica_iterator = itertools.cycle(
|
||||
self.in_flight_queries.keys())
|
||||
@@ -156,6 +153,12 @@ class ReplicaSet:
|
||||
return assigned_ref
|
||||
|
||||
|
||||
class _PendingEndpointFound(Enum):
|
||||
"""Enum for the status of pending endpoint registration."""
|
||||
ADDED = 1
|
||||
REMOVED = 2
|
||||
|
||||
|
||||
class Router:
|
||||
def __init__(self, controller_handle: ActorHandle):
|
||||
"""Router process incoming queries: choose backend, and assign replica.
|
||||
@@ -168,8 +171,7 @@ class Router:
|
||||
self.endpoint_policies: Dict[str, EndpointPolicy] = dict()
|
||||
self.backend_replicas: Dict[str, ReplicaSet] = defaultdict(ReplicaSet)
|
||||
|
||||
self._pending_endpoints: DefaultDict[str, asyncio.Event] = defaultdict(
|
||||
asyncio.Event)
|
||||
self._pending_endpoints: Dict[str, asyncio.Future] = dict()
|
||||
|
||||
# -- Metrics Registration -- #
|
||||
self.num_router_requests = metrics.Count(
|
||||
@@ -190,23 +192,45 @@ class Router:
|
||||
})
|
||||
|
||||
async def _update_traffic_policies(self, traffic_policies):
|
||||
for endpoint, traffic_policy in traffic_policies.items():
|
||||
added, removed, updated = compute_dict_delta(self.endpoint_policies,
|
||||
traffic_policies)
|
||||
|
||||
for endpoint, traffic_policy in ChainMap(added, updated).items():
|
||||
self.endpoint_policies[endpoint] = RandomEndpointPolicy(
|
||||
traffic_policy)
|
||||
if endpoint in self._pending_endpoints:
|
||||
event = self._pending_endpoints.pop(endpoint)
|
||||
event.set()
|
||||
future = self._pending_endpoints.pop(endpoint)
|
||||
future.set_result(_PendingEndpointFound.ADDED)
|
||||
|
||||
for endpoint, traffic_policy in removed.items():
|
||||
del self.endpoint_policies[endpoint]
|
||||
if endpoint in self._pending_endpoints:
|
||||
future = self._pending_endpoints.pop(endpoint)
|
||||
future.set_result(_PendingEndpointFound.REMOVED)
|
||||
|
||||
async def _update_replica_handles(self, replica_handles):
|
||||
for backend_tag, replica_handles in replica_handles.items():
|
||||
added, removed, updated = compute_dict_delta(self.backend_replicas,
|
||||
replica_handles)
|
||||
|
||||
for backend_tag, replica_handles in ChainMap(added, updated).items():
|
||||
self.backend_replicas[backend_tag].update_worker_replicas(
|
||||
replica_handles)
|
||||
|
||||
for backend_tag in removed.keys():
|
||||
if backend_tag in self.backend_replicas:
|
||||
del self.backend_replicas[backend_tag]
|
||||
|
||||
async def _update_backend_configs(self, backend_configs):
|
||||
for backend_tag, config in backend_configs.items():
|
||||
added, removed, updated = compute_dict_delta(self.backend_replicas,
|
||||
backend_configs)
|
||||
for backend_tag, config in ChainMap(added, updated).items():
|
||||
self.backend_replicas[backend_tag].set_max_concurrent_queries(
|
||||
config.max_concurrent_queries)
|
||||
|
||||
for backend_tag in removed.keys():
|
||||
if backend_tag in self.backend_replicas:
|
||||
del self.backend_replicas[backend_tag]
|
||||
|
||||
async def assign_request(
|
||||
self,
|
||||
request_meta: RequestMetadata,
|
||||
@@ -226,7 +250,14 @@ class Router:
|
||||
logger.info(
|
||||
f"Endpoint {endpoint} doesn't exist, waiting for registration."
|
||||
)
|
||||
await self._pending_endpoints[endpoint].wait()
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
if endpoint not in self._pending_endpoints:
|
||||
self._pending_endpoints[endpoint] = future
|
||||
endpoint_status = await self._pending_endpoints[endpoint]
|
||||
if endpoint_status == _PendingEndpointFound.REMOVED:
|
||||
raise RayServeException(
|
||||
f"Endpoint {endpoint} was removed. This request "
|
||||
"cannot be completed.")
|
||||
|
||||
endpoint_policy = self.endpoint_policies[endpoint]
|
||||
chosen_backend, *shadow_backends = endpoint_policy.assign(query)
|
||||
|
||||
@@ -35,7 +35,7 @@ def serve_instance(_shared_serve_instance):
|
||||
for endpoint in ray.get(controller.get_all_endpoints.remote()):
|
||||
_shared_serve_instance.delete_endpoint(endpoint)
|
||||
for backend in ray.get(controller.get_all_backends.remote()).keys():
|
||||
_shared_serve_instance.delete_backend(backend)
|
||||
_shared_serve_instance.delete_backend(backend, force=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -2,13 +2,14 @@ import asyncio
|
||||
from collections import defaultdict
|
||||
import time
|
||||
import os
|
||||
import pytest
|
||||
|
||||
import requests
|
||||
import pytest
|
||||
import starlette.responses
|
||||
|
||||
import ray
|
||||
from ray import serve
|
||||
from ray.test_utils import wait_for_condition
|
||||
from ray.test_utils import SignalActor, wait_for_condition
|
||||
from ray.serve.constants import SERVE_PROXY_NAME
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.config import BackendConfig
|
||||
@@ -871,6 +872,79 @@ def test_serve_metrics(serve_instance):
|
||||
verify_metrics()
|
||||
|
||||
|
||||
def test_serve_graceful_shutdown(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
signal = SignalActor.remote()
|
||||
|
||||
class WaitBackend:
|
||||
@serve.accept_batch
|
||||
async def __call__(self, requests):
|
||||
signal_actor = await requests[0].body()
|
||||
await signal_actor.wait.remote()
|
||||
return ["" for _ in range(len(requests))]
|
||||
|
||||
client.create_backend(
|
||||
"wait",
|
||||
WaitBackend,
|
||||
config=BackendConfig(
|
||||
# Make sure we can queue up queries in the replica side.
|
||||
max_concurrent_queries=10,
|
||||
max_batch_size=1,
|
||||
experimental_graceful_shutdown_wait_loop_s=0.5,
|
||||
experimental_graceful_shutdown_timeout_s=1000,
|
||||
))
|
||||
client.create_endpoint("wait", backend="wait")
|
||||
handle = client.get_handle("wait")
|
||||
refs = [handle.remote(signal) for _ in range(10)]
|
||||
|
||||
# Wait for all the queries to be enqueued
|
||||
with pytest.raises(ray.exceptions.GetTimeoutError):
|
||||
ray.get(refs, timeout=1)
|
||||
|
||||
@ray.remote(num_cpus=0)
|
||||
def do_blocking_delete():
|
||||
client = serve.connect()
|
||||
client.delete_endpoint("wait")
|
||||
client.delete_backend("wait")
|
||||
|
||||
# Now delete the backend. This should trigger the shutdown sequence.
|
||||
delete_ref = do_blocking_delete.remote()
|
||||
|
||||
# The queries should be enqueued but not executed becuase they are blocked
|
||||
# by signal actor.
|
||||
with pytest.raises(ray.exceptions.GetTimeoutError):
|
||||
ray.get(refs, timeout=1)
|
||||
|
||||
signal.send.remote()
|
||||
|
||||
# All the queries should be drained and executed without error.
|
||||
ray.get(refs)
|
||||
# Blocking delete should complete.
|
||||
ray.get(delete_ref)
|
||||
|
||||
|
||||
def test_serve_forceful_shutdown(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
def sleeper(_):
|
||||
while True:
|
||||
time.sleep(1000)
|
||||
|
||||
client.create_backend(
|
||||
"sleeper",
|
||||
sleeper,
|
||||
config=BackendConfig(experimental_graceful_shutdown_timeout_s=1))
|
||||
client.create_endpoint("sleeper", backend="sleeper")
|
||||
handle = client.get_handle("sleeper")
|
||||
ref = handle.remote()
|
||||
client.delete_endpoint("sleeper")
|
||||
client.delete_backend("sleeper")
|
||||
|
||||
with pytest.raises(ray.exceptions.RayActorError):
|
||||
ray.get(ref)
|
||||
|
||||
|
||||
def test_starlette_request(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
|
||||
@@ -40,6 +40,9 @@ def setup_worker(name,
|
||||
def update_config(self, new_config):
|
||||
return self.worker.update_config(new_config)
|
||||
|
||||
async def drain_pending_queries(self):
|
||||
return await self.worker.drain_pending_queries()
|
||||
|
||||
worker = WorkerActor.remote()
|
||||
ray.get(worker.ready.remote())
|
||||
return worker
|
||||
@@ -312,6 +315,51 @@ async def test_user_config_update(serve_instance, router,
|
||||
assert await i == "new_val"
|
||||
|
||||
|
||||
async def test_graceful_shutdown(serve_instance, router,
|
||||
mock_controller_with_name):
|
||||
class KeepInflight:
|
||||
def __init__(self):
|
||||
self.events = []
|
||||
|
||||
def reconfigure(self, config):
|
||||
if config["release"]:
|
||||
[event.set() for event in self.events]
|
||||
|
||||
async def __call__(self, _):
|
||||
e = asyncio.Event()
|
||||
self.events.append(e)
|
||||
await e.wait()
|
||||
|
||||
backend_worker = await add_servable_to_router(
|
||||
KeepInflight,
|
||||
router,
|
||||
mock_controller_with_name[0],
|
||||
backend_config=BackendConfig(
|
||||
num_replicas=1,
|
||||
internal_metadata=BackendMetadata(is_blocking=False),
|
||||
user_config={"release": False}))
|
||||
|
||||
query_param = make_request_param()
|
||||
|
||||
refs = [(await router.assign_request.remote(query_param))
|
||||
for _ in range(6)]
|
||||
|
||||
shutdown_ref = backend_worker.drain_pending_queries.remote()
|
||||
|
||||
with pytest.raises(ray.exceptions.GetTimeoutError):
|
||||
# Shutdown should block because there are still inflight queries.
|
||||
ray.get(shutdown_ref, timeout=2)
|
||||
|
||||
config = BackendConfig()
|
||||
config.user_config = {"release": True}
|
||||
await mock_controller_with_name[1].update_backend.remote("backend", config)
|
||||
|
||||
# All queries should complete successfully
|
||||
ray.get(refs)
|
||||
# The draining operation should be completed.
|
||||
ray.get(shutdown_ref)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", "-s", __file__]))
|
||||
|
||||
@@ -7,7 +7,7 @@ import logging
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
from typing import List, Dict
|
||||
from typing import Iterable, List, Dict, Tuple
|
||||
import os
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from collections import UserDict
|
||||
@@ -382,3 +382,41 @@ class MockImportedBackend:
|
||||
|
||||
def other_method(self, request):
|
||||
return request.data
|
||||
|
||||
|
||||
def compute_iterable_delta(old: Iterable,
|
||||
new: Iterable) -> Tuple[set, set, set]:
|
||||
"""Given two iterables, return the entries that's (added, removed, updated).
|
||||
|
||||
Usage:
|
||||
>>> old = {"a", "b"}
|
||||
>>> new = {"a", "d"}
|
||||
>>> compute_dict_delta(old, new)
|
||||
({"d"}, {"b"}, {"a"})
|
||||
"""
|
||||
old_keys, new_keys = set(old), set(new)
|
||||
added_keys = new_keys - old_keys
|
||||
removed_keys = old_keys - new_keys
|
||||
updated_keys = old_keys.intersection(new_keys)
|
||||
return added_keys, removed_keys, updated_keys
|
||||
|
||||
|
||||
def compute_dict_delta(old_dict, new_dict) -> Tuple[dict, dict, dict]:
|
||||
"""Given two dicts, return the entries that's (added, removed, updated).
|
||||
|
||||
Usage:
|
||||
>>> old = {"a": 1, "b": 2}
|
||||
>>> new = {"a": 3, "d": 4}
|
||||
>>> compute_dict_delta(old, new)
|
||||
({"d": 4}, {"b": 2}, {"a": 3})
|
||||
"""
|
||||
added_keys, removed_keys, updated_keys = compute_iterable_delta(
|
||||
old_dict.keys(), new_dict.keys())
|
||||
return (
|
||||
{k: new_dict[k]
|
||||
for k in added_keys},
|
||||
{k: old_dict[k]
|
||||
for k in removed_keys},
|
||||
{k: new_dict[k]
|
||||
for k in updated_keys},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user