[Serve] Implement Graceful Shutdown (#13028)

This commit is contained in:
Simon Mo
2020-12-28 17:53:53 -08:00
committed by GitHub
parent 350917958c
commit 30c22921d9
9 changed files with 338 additions and 81 deletions
+9 -3
View File
@@ -329,7 +329,7 @@ class Client:
func_or_class (callable, class): a function or a class implementing
__call__, returning a JSON-serializable object or a
Starlette Response object.
actor_init_args (optional): the arguments to pass to the class.
*actor_init_args (optional): the arguments to pass to the class
initialization method.
ray_actor_options (optional): options to be passed into the
@ray.remote decorator for the backend actor.
@@ -409,12 +409,18 @@ class Client:
return ray.get(self._controller.get_all_backends.remote())
@_ensure_connected
def delete_backend(self, backend_tag: str) -> None:
def delete_backend(self, backend_tag: str, force: bool = False) -> None:
"""Delete the given backend.
The backend must not currently be used by any endpoints.
Args:
backend_tag (str): The backend tag to be deleted.
force (bool): Whether or not to force the deletion, without waiting
for graceful shutdown. Default to false.
"""
self._get_result(self._controller.delete_backend.remote(backend_tag))
self._get_result(
self._controller.delete_backend.remote(backend_tag, force))
@_ensure_connected
def set_traffic(self, endpoint_name: str,
+25
View File
@@ -126,6 +126,9 @@ def create_backend_replica(func_or_class: Union[Callable, Type[Callable]]):
def ready(self):
pass
async def drain_pending_queries(self):
return await self.backend.drain_pending_queries()
RayServeWrappedReplica.__name__ = "RayServeReplica_{}".format(
func_or_class.__name__)
return RayServeWrappedReplica
@@ -410,3 +413,25 @@ class RayServeReplica:
self.num_ongoing_requests -= 1
return result
async def drain_pending_queries(self):
"""Perform graceful shutdown.
Trigger a graceful shutdown protocol that will wait for all the queued
tasks to be completed and return to the controller.
"""
sleep_time = self.config.experimental_graceful_shutdown_wait_loop_s
while True:
# Sleep first because we want to make sure all the routers receive
# the notification to remove this replica first.
await asyncio.sleep(sleep_time)
num_queries_waiting = self.batch_queue.qsize()
if (num_queries_waiting == 0) and (self.num_ongoing_requests == 0):
break
else:
logger.info(
f"Waiting for an additional {sleep_time}s "
f"to shutdown replica {self.replica_tag} because "
f"num_queries_waiting {num_queries_waiting} and "
f"num_ongoing_requests {self.num_ongoing_requests}")
+25 -20
View File
@@ -1,6 +1,6 @@
import inspect
from pydantic import BaseModel, PositiveInt, validator
from pydantic import BaseModel, PositiveInt, validator, PositiveFloat
from ray.serve.constants import ASYNC_CONCURRENCY
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
@@ -30,25 +30,27 @@ class BackendMetadata:
class BackendConfig(BaseModel):
"""Configuration options for a backend, to be set by the user.
:param num_replicas: The number of processes to start up that will
handle requests to this backend. Defaults to 0.
:type num_replicas: int, optional
:param max_batch_size: The maximum number of requests that will be
processed in one batch by this backend. Defaults to None (no
maximium).
:type max_batch_size: int, optional
:param batch_wait_timeout: The time in seconds that backend replicas will
wait for a full batch of requests before processing a partial batch.
Defaults to 0.
:type batch_wait_timeout: float, optional
:param max_concurrent_queries: The maximum number of queries that will be
sent to a replica of this backend without receiving a response.
Defaults to None (no maximum).
:type max_concurrent_queries: int, optional
:param user_config: Arguments to pass to the reconfigure method of the
backend. The reconfigure method is called if user_config is not
None.
:type user_config: Any, optional
Args:
num_replicas (Optional[int]): The number of processes to start up that
will handle requests to this backend. Defaults to 0.
max_batch_size (Optional[int]): The maximum number of requests that
will be processed in one batch by this backend. Defaults to None
(no maximium).
batch_wait_timeout (Optional[float]): The time in seconds that backend
replicas will wait for a full batch of requests before processing a
partial batch. Defaults to 0.
max_concurrent_queries (Optional[int]): The maximum number of queries
that will be sent to a replica of this backend without receiving a
response. Defaults to None (no maximum).
user_config (Optional[Any]): Arguments to pass to the reconfigure
method of the backend. The reconfigure method is called if
user_config is not None.
experimental_graceful_shutdown_wait_loop_s (Optional[float]): Duration
that backend workers will wait until there is no more work to be
done before shutting down. Defaults to 2s.
experimental_graceful_shutdown_timeout_s (Optional[float]):
Controller waits for this duration to forcefully kill the replica
for shutdown. Defaults to 20s.
"""
internal_metadata: BackendMetadata = BackendMetadata()
@@ -58,6 +60,9 @@ class BackendConfig(BaseModel):
max_concurrent_queries: Optional[int] = None
user_config: Any = None
experimental_graceful_shutdown_wait_loop_s: PositiveFloat = 2.0
experimental_graceful_shutdown_timeout_s: PositiveFloat = 20.0
class Config:
validate_assignment = True
extra = "forbid"
+65 -35
View File
@@ -1,11 +1,12 @@
import asyncio
from asyncio.futures import Future
from collections import defaultdict
from itertools import chain
import os
import random
import time
from dataclasses import dataclass, field
from typing import Dict, Any, List, Optional, Tuple
from typing import Dict, Any, List, Optional, Set, Tuple
from uuid import uuid4, UUID
from pydantic import BaseModel
@@ -13,8 +14,11 @@ import ray
import ray.cloudpickle as pickle
from ray.serve.autoscaling_policy import BasicAutoscalingPolicy
from ray.serve.backend_worker import create_backend_replica
from ray.serve.constants import (ASYNC_CONCURRENCY, SERVE_PROXY_NAME,
LongPollKey)
from ray.serve.constants import (
ASYNC_CONCURRENCY,
SERVE_PROXY_NAME,
LongPollKey,
)
from ray.serve.http_proxy import HTTPProxyActor
from ray.serve.kv_store import RayInternalKVStore
from ray.serve.exceptions import RayServeException
@@ -46,6 +50,7 @@ EndpointTag = str
ReplicaTag = str
NodeId = str
GoalId = int
Duration = float
class TrafficPolicy:
@@ -230,8 +235,9 @@ class ActorStateReconciler:
default_factory=lambda: defaultdict(dict))
backend_replicas_to_start: Dict[BackendTag, List[ReplicaTag]] = field(
default_factory=lambda: defaultdict(list))
backend_replicas_to_stop: Dict[BackendTag, List[ReplicaTag]] = field(
default_factory=lambda: defaultdict(list))
backend_replicas_to_stop: Dict[BackendTag, List[Tuple[
ReplicaTag, Duration]]] = field(
default_factory=lambda: defaultdict(list))
backends_to_remove: List[BackendTag] = field(default_factory=list)
# NOTE(ilr): These are not checkpointed, but will be recreated by
@@ -300,9 +306,13 @@ class ActorStateReconciler:
return replica_handle
def _scale_backend_replicas(self, backends: Dict[BackendTag, BackendInfo],
backend_tag: BackendTag,
num_replicas: int) -> None:
def _scale_backend_replicas(
self,
backends: Dict[BackendTag, BackendInfo],
backend_tag: BackendTag,
num_replicas: int,
force_kill: bool = False,
) -> None:
"""Scale the given backend to the number of replicas.
NOTE: this does not actually start or stop the replicas, but instead
@@ -323,7 +333,7 @@ class ActorStateReconciler:
current_num_replicas = len(self.backend_replicas[backend_tag])
delta_num_replicas = num_replicas - current_num_replicas
backend_info = backends[backend_tag]
backend_info: BackendInfo = backends[backend_tag]
if delta_num_replicas > 0:
can_schedule = try_schedule_resources_on_nodes(requirements=[
backend_info.replica_config.resource_dict
@@ -357,7 +367,14 @@ class ActorStateReconciler:
if len(self.backend_replicas[backend_tag]) == 0:
del self.backend_replicas[backend_tag]
self.backend_replicas_to_stop[backend_tag].append(replica_tag)
graceful_timeout_s = (backend_info.backend_config.
experimental_graceful_shutdown_timeout_s)
if force_kill:
graceful_timeout_s = 0
self.backend_replicas_to_stop[backend_tag].append((
replica_tag,
graceful_timeout_s,
))
async def _enqueue_pending_scale_changes_loop(self,
backend_state: BackendState):
@@ -370,9 +387,9 @@ class ActorStateReconciler:
self.currently_starting_replicas[ready_future] = (
backend_tag, replica_tag, replica_handle)
for backend_tag, replicas_to_stop in self.backend_replicas_to_stop.\
items():
for replica_tag in replicas_to_stop:
for backend_tag, replicas_to_stop in (
self.backend_replicas_to_stop.items()):
for replica_tag, shutdown_timeout in replicas_to_stop:
replica_name = format_actor_name(replica_tag,
self.controller_name)
@@ -384,19 +401,24 @@ class ActorStateReconciler:
except ValueError:
return
# TODO(edoakes): this logic isn't ideal because there may
# be pending tasks still executing on the replica. However,
# if we use replica.__ray_terminate__, we may send it while
# the replica is being restarted and there's no way to tell
# if it successfully killed the worker or not.
ray.kill(replica, no_restart=True)
try:
await asyncio.wait_for(
replica.drain_pending_queries.remote(),
timeout=shutdown_timeout)
except asyncio.TimeoutError:
# Graceful period passed, kill it forcefully.
logger.debug(
f"{replica_name_to_use} did not shutdown after "
f"{shutdown_timeout}s, killing.")
finally:
ray.kill(replica, no_restart=True)
self.currently_stopping_replicas[asyncio.ensure_future(
kill_actor(replica_name))] = (backend_tag, replica_tag)
async def _check_currently_starting_replicas(self) -> bool:
"""Returns a boolean specifying if there are more replicas to start"""
in_flight = list()
async def _check_currently_starting_replicas(self) -> int:
"""Returns the number of pending replicas waiting to start"""
in_flight: Set[Future[Any]] = set()
if self.currently_starting_replicas:
done, in_flight = await asyncio.wait(
@@ -415,11 +437,12 @@ class ActorStateReconciler:
pass
if len(backend) == 0:
del self.backend_replicas_to_start[backend_tag]
return len(in_flight) > 0
return len(in_flight)
async def _check_currently_stopping_replicas(self) -> int:
"""Returns the number of replicas waiting to stop"""
in_flight: Set[Future[Any]] = set()
async def _check_currently_stopping_replicas(self) -> bool:
"""Returns a boolean specifying if there are more replicas to stop"""
in_flight = list()
if self.currently_stopping_replicas:
done_stoppping, in_flight = await asyncio.wait(
list(self.currently_stopping_replicas.keys()), timeout=0)
@@ -437,22 +460,27 @@ class ActorStateReconciler:
if len(backend) == 0:
del self.backend_replicas_to_stop[backend_tag]
return len(in_flight) > 0
return len(in_flight)
async def backend_control_loop(self):
start = time.time()
prev_warning = start
need_to_continue = True
num_pending_starts, num_pending_stops = 0, 0
while need_to_continue:
if time.time() - prev_warning > REPLICA_STARTUP_TIME_WARNING_S:
prev_warning = time.time()
logger.warning("Waited {:.2f}s for replicas to start up. Make "
"sure there are enough resources to create the "
"replicas.".format(time.time() - start))
delta = time.time() - start
logger.warning(
f"Waited {delta:.2f}s for {num_pending_starts} replicas "
f"to start up or {num_pending_stops} replicas to shutdown."
" Make sure there are enough resources to create the "
"replicas.")
need_to_continue = (
await self._check_currently_starting_replicas()
or await self._check_currently_stopping_replicas())
num_pending_starts = await self._check_currently_starting_replicas(
)
num_pending_stops = await self._check_currently_stopping_replicas()
need_to_continue = num_pending_starts or num_pending_stops
asyncio.sleep(1)
@@ -952,7 +980,9 @@ class ServeController:
self.notify_backend_configs_changed()
return return_uuid
async def delete_backend(self, backend_tag: BackendTag) -> UUID:
async def delete_backend(self,
backend_tag: BackendTag,
force_kill: bool = False) -> UUID:
async with self.write_lock:
# This method must be idempotent. We should validate that the
# specified backend exists on the client.
@@ -975,7 +1005,7 @@ class ServeController:
# This should be a call to the control loop
self.actor_reconciler._scale_backend_replicas(
self.backend_state.backends, backend_tag, 0)
self.backend_state.backends, backend_tag, 0, force_kill)
# Remove the backend's metadata.
del self.backend_state.backends[backend_tag]
+50 -19
View File
@@ -1,8 +1,11 @@
import asyncio
from enum import Enum
import itertools
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, DefaultDict, Dict, Iterable, List, Optional
from typing import Any, ChainMap, Dict, Iterable, List, Optional
from ray.serve.exceptions import RayServeException
import ray
from ray.actor import ActorHandle
@@ -10,7 +13,7 @@ from ray.serve.constants import LongPollKey
from ray.serve.context import TaskContext
from ray.serve.endpoint_policy import EndpointPolicy, RandomEndpointPolicy
from ray.serve.long_poll import LongPollAsyncClient
from ray.serve.utils import logger
from ray.serve.utils import logger, compute_dict_delta, compute_iterable_delta
from ray.util import metrics
REPORT_QUEUE_LENGTH_PERIOD_S = 1.0
@@ -76,22 +79,16 @@ class ReplicaSet:
self.config_updated_event.set()
def update_worker_replicas(self, worker_replicas: Iterable[ActorHandle]):
current_replica_set = set(self.in_flight_queries.keys())
updated_replica_set = set(worker_replicas)
added, removed, _ = compute_iterable_delta(
self.in_flight_queries.keys(), worker_replicas)
added = updated_replica_set - current_replica_set
for new_replica_handle in added:
self.in_flight_queries[new_replica_handle] = set()
removed = current_replica_set - updated_replica_set
for removed_replica_handle in removed:
# NOTE(simon): Do we warn if there are still inflight queries?
# The current approach is no because the queries objectrefs are
# just used to perform backpressure. Caller should decide what to
# do with the object refs.
# Delete it directly because shutdown is processed by controller.
del self.in_flight_queries[removed_replica_handle]
# State changed, reset the round robin iterator
if len(added) > 0 or len(removed) > 0:
self.replica_iterator = itertools.cycle(
self.in_flight_queries.keys())
@@ -156,6 +153,12 @@ class ReplicaSet:
return assigned_ref
class _PendingEndpointFound(Enum):
"""Enum for the status of pending endpoint registration."""
ADDED = 1
REMOVED = 2
class Router:
def __init__(self, controller_handle: ActorHandle):
"""Router process incoming queries: choose backend, and assign replica.
@@ -168,8 +171,7 @@ class Router:
self.endpoint_policies: Dict[str, EndpointPolicy] = dict()
self.backend_replicas: Dict[str, ReplicaSet] = defaultdict(ReplicaSet)
self._pending_endpoints: DefaultDict[str, asyncio.Event] = defaultdict(
asyncio.Event)
self._pending_endpoints: Dict[str, asyncio.Future] = dict()
# -- Metrics Registration -- #
self.num_router_requests = metrics.Count(
@@ -190,23 +192,45 @@ class Router:
})
async def _update_traffic_policies(self, traffic_policies):
for endpoint, traffic_policy in traffic_policies.items():
added, removed, updated = compute_dict_delta(self.endpoint_policies,
traffic_policies)
for endpoint, traffic_policy in ChainMap(added, updated).items():
self.endpoint_policies[endpoint] = RandomEndpointPolicy(
traffic_policy)
if endpoint in self._pending_endpoints:
event = self._pending_endpoints.pop(endpoint)
event.set()
future = self._pending_endpoints.pop(endpoint)
future.set_result(_PendingEndpointFound.ADDED)
for endpoint, traffic_policy in removed.items():
del self.endpoint_policies[endpoint]
if endpoint in self._pending_endpoints:
future = self._pending_endpoints.pop(endpoint)
future.set_result(_PendingEndpointFound.REMOVED)
async def _update_replica_handles(self, replica_handles):
for backend_tag, replica_handles in replica_handles.items():
added, removed, updated = compute_dict_delta(self.backend_replicas,
replica_handles)
for backend_tag, replica_handles in ChainMap(added, updated).items():
self.backend_replicas[backend_tag].update_worker_replicas(
replica_handles)
for backend_tag in removed.keys():
if backend_tag in self.backend_replicas:
del self.backend_replicas[backend_tag]
async def _update_backend_configs(self, backend_configs):
for backend_tag, config in backend_configs.items():
added, removed, updated = compute_dict_delta(self.backend_replicas,
backend_configs)
for backend_tag, config in ChainMap(added, updated).items():
self.backend_replicas[backend_tag].set_max_concurrent_queries(
config.max_concurrent_queries)
for backend_tag in removed.keys():
if backend_tag in self.backend_replicas:
del self.backend_replicas[backend_tag]
async def assign_request(
self,
request_meta: RequestMetadata,
@@ -226,7 +250,14 @@ class Router:
logger.info(
f"Endpoint {endpoint} doesn't exist, waiting for registration."
)
await self._pending_endpoints[endpoint].wait()
future = asyncio.get_event_loop().create_future()
if endpoint not in self._pending_endpoints:
self._pending_endpoints[endpoint] = future
endpoint_status = await self._pending_endpoints[endpoint]
if endpoint_status == _PendingEndpointFound.REMOVED:
raise RayServeException(
f"Endpoint {endpoint} was removed. This request "
"cannot be completed.")
endpoint_policy = self.endpoint_policies[endpoint]
chosen_backend, *shadow_backends = endpoint_policy.assign(query)
+1 -1
View File
@@ -35,7 +35,7 @@ def serve_instance(_shared_serve_instance):
for endpoint in ray.get(controller.get_all_endpoints.remote()):
_shared_serve_instance.delete_endpoint(endpoint)
for backend in ray.get(controller.get_all_backends.remote()).keys():
_shared_serve_instance.delete_backend(backend)
_shared_serve_instance.delete_backend(backend, force=True)
@pytest.fixture
+76 -2
View File
@@ -2,13 +2,14 @@ import asyncio
from collections import defaultdict
import time
import os
import pytest
import requests
import pytest
import starlette.responses
import ray
from ray import serve
from ray.test_utils import wait_for_condition
from ray.test_utils import SignalActor, wait_for_condition
from ray.serve.constants import SERVE_PROXY_NAME
from ray.serve.exceptions import RayServeException
from ray.serve.config import BackendConfig
@@ -871,6 +872,79 @@ def test_serve_metrics(serve_instance):
verify_metrics()
def test_serve_graceful_shutdown(serve_instance):
client = serve_instance
signal = SignalActor.remote()
class WaitBackend:
@serve.accept_batch
async def __call__(self, requests):
signal_actor = await requests[0].body()
await signal_actor.wait.remote()
return ["" for _ in range(len(requests))]
client.create_backend(
"wait",
WaitBackend,
config=BackendConfig(
# Make sure we can queue up queries in the replica side.
max_concurrent_queries=10,
max_batch_size=1,
experimental_graceful_shutdown_wait_loop_s=0.5,
experimental_graceful_shutdown_timeout_s=1000,
))
client.create_endpoint("wait", backend="wait")
handle = client.get_handle("wait")
refs = [handle.remote(signal) for _ in range(10)]
# Wait for all the queries to be enqueued
with pytest.raises(ray.exceptions.GetTimeoutError):
ray.get(refs, timeout=1)
@ray.remote(num_cpus=0)
def do_blocking_delete():
client = serve.connect()
client.delete_endpoint("wait")
client.delete_backend("wait")
# Now delete the backend. This should trigger the shutdown sequence.
delete_ref = do_blocking_delete.remote()
# The queries should be enqueued but not executed becuase they are blocked
# by signal actor.
with pytest.raises(ray.exceptions.GetTimeoutError):
ray.get(refs, timeout=1)
signal.send.remote()
# All the queries should be drained and executed without error.
ray.get(refs)
# Blocking delete should complete.
ray.get(delete_ref)
def test_serve_forceful_shutdown(serve_instance):
client = serve_instance
def sleeper(_):
while True:
time.sleep(1000)
client.create_backend(
"sleeper",
sleeper,
config=BackendConfig(experimental_graceful_shutdown_timeout_s=1))
client.create_endpoint("sleeper", backend="sleeper")
handle = client.get_handle("sleeper")
ref = handle.remote()
client.delete_endpoint("sleeper")
client.delete_backend("sleeper")
with pytest.raises(ray.exceptions.RayActorError):
ray.get(ref)
def test_starlette_request(serve_instance):
client = serve_instance
@@ -40,6 +40,9 @@ def setup_worker(name,
def update_config(self, new_config):
return self.worker.update_config(new_config)
async def drain_pending_queries(self):
return await self.worker.drain_pending_queries()
worker = WorkerActor.remote()
ray.get(worker.ready.remote())
return worker
@@ -312,6 +315,51 @@ async def test_user_config_update(serve_instance, router,
assert await i == "new_val"
async def test_graceful_shutdown(serve_instance, router,
mock_controller_with_name):
class KeepInflight:
def __init__(self):
self.events = []
def reconfigure(self, config):
if config["release"]:
[event.set() for event in self.events]
async def __call__(self, _):
e = asyncio.Event()
self.events.append(e)
await e.wait()
backend_worker = await add_servable_to_router(
KeepInflight,
router,
mock_controller_with_name[0],
backend_config=BackendConfig(
num_replicas=1,
internal_metadata=BackendMetadata(is_blocking=False),
user_config={"release": False}))
query_param = make_request_param()
refs = [(await router.assign_request.remote(query_param))
for _ in range(6)]
shutdown_ref = backend_worker.drain_pending_queries.remote()
with pytest.raises(ray.exceptions.GetTimeoutError):
# Shutdown should block because there are still inflight queries.
ray.get(shutdown_ref, timeout=2)
config = BackendConfig()
config.user_config = {"release": True}
await mock_controller_with_name[1].update_backend.remote("backend", config)
# All queries should complete successfully
ray.get(refs)
# The draining operation should be completed.
ray.get(shutdown_ref)
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))
+39 -1
View File
@@ -7,7 +7,7 @@ import logging
import random
import string
import time
from typing import List, Dict
from typing import Iterable, List, Dict, Tuple
import os
from ray.serve.exceptions import RayServeException
from collections import UserDict
@@ -382,3 +382,41 @@ class MockImportedBackend:
def other_method(self, request):
return request.data
def compute_iterable_delta(old: Iterable,
new: Iterable) -> Tuple[set, set, set]:
"""Given two iterables, return the entries that's (added, removed, updated).
Usage:
>>> old = {"a", "b"}
>>> new = {"a", "d"}
>>> compute_dict_delta(old, new)
({"d"}, {"b"}, {"a"})
"""
old_keys, new_keys = set(old), set(new)
added_keys = new_keys - old_keys
removed_keys = old_keys - new_keys
updated_keys = old_keys.intersection(new_keys)
return added_keys, removed_keys, updated_keys
def compute_dict_delta(old_dict, new_dict) -> Tuple[dict, dict, dict]:
"""Given two dicts, return the entries that's (added, removed, updated).
Usage:
>>> old = {"a": 1, "b": 2}
>>> new = {"a": 3, "d": 4}
>>> compute_dict_delta(old, new)
({"d": 4}, {"b": 2}, {"a": 3})
"""
added_keys, removed_keys, updated_keys = compute_iterable_delta(
old_dict.keys(), new_dict.keys())
return (
{k: new_dict[k]
for k in added_keys},
{k: old_dict[k]
for k in removed_keys},
{k: new_dict[k]
for k in updated_keys},
)