[serve] Merge ActorReconciler and BackendState (#13139)

This commit is contained in:
Edward Oakes
2021-01-05 09:56:22 -06:00
committed by GitHub
parent 4150970226
commit e8162f1b1f
2 changed files with 123 additions and 256 deletions
+1 -1
View File
@@ -11,7 +11,7 @@ serve_tests_srcs = glob(["tests/*.py"],
py_test(
name = "test_api",
size = "medium",
size = "large",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
+122 -255
View File
@@ -1,18 +1,16 @@
import asyncio
from asyncio.futures import Future
from collections import defaultdict
from itertools import chain
import os
import random
import time
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Dict, Any, List, Optional, Set, Tuple
from uuid import uuid4, UUID
from pydantic import BaseModel
import ray
import ray.cloudpickle as pickle
from ray.serve.autoscaling_policy import BasicAutoscalingPolicy
from ray.serve.backend_worker import create_backend_replica
from ray.serve.constants import (
ASYNC_CONCURRENCY,
@@ -170,15 +168,51 @@ class BackendInfo(BaseModel):
class BackendState:
def __init__(self, checkpoint: bytes = None):
def __init__(self,
controller_name: str,
detached: bool,
checkpoint: bytes = None):
self.controller_name = controller_name
self.detached = detached
# Non-checkpointed state.
self.currently_starting_replicas: Dict[asyncio.Future, Tuple[
BackendTag, ReplicaTag, ActorHandle]] = dict()
self.currently_stopping_replicas: Dict[asyncio.Future, Tuple[
BackendTag, ReplicaTag]] = dict()
# Checkpointed state.
self.backends: Dict[BackendTag, BackendInfo] = dict()
self.backend_replicas: Dict[BackendTag, Dict[
ReplicaTag, ActorHandle]] = defaultdict(dict)
self.goals: Dict[BackendTag, GoalId] = dict()
self.backend_replicas_to_start: Dict[BackendTag, List[
ReplicaTag]] = defaultdict(list)
self.backend_replicas_to_stop: Dict[BackendTag, List[Tuple[
ReplicaTag, Duration]]] = defaultdict(list)
self.backends_to_remove: List[BackendTag] = list()
if checkpoint is not None:
self.backends, self.goals = pickle.loads(checkpoint)
(self.backends, self.backend_replicas, self.goals,
self.backend_replicas_to_start, self.backend_replicas_to_stop,
self.backend_to_remove) = pickle.loads(checkpoint)
# Fetch actor handles for all of the backend replicas in the system.
# All of these backend_replicas are guaranteed to already exist because
# they would not be written to a checkpoint in self.backend_replicas
# until they were created.
for backend_tag, replica_dict in self.backend_replicas.items():
for replica_tag in replica_dict.keys():
replica_name = format_actor_name(replica_tag,
self.controller_name)
self.backend_replicas[backend_tag][
replica_tag] = ray.get_actor(replica_name)
def checkpoint(self):
return pickle.dumps([self.backends, self.goals])
return pickle.dumps(
(self.backends, self.backend_replicas, self.goals,
self.backend_replicas_to_start, self.backend_replicas_to_stop,
self.backends_to_remove))
def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]:
return {
@@ -186,6 +220,10 @@ class BackendState:
for tag, info in self.backends.items()
}
def get_replica_handles(
        self) -> Dict[BackendTag, Dict[ReplicaTag, ActorHandle]]:
    """Return the backend tag -> replica tag -> actor handle mapping.

    The returned object is ``self.backend_replicas`` itself, not a copy.
    """
    return self.backend_replicas
def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]:
    """Return the entry stored for ``backend_tag``, or ``None`` if absent."""
    return self.backends.get(backend_tag)
@@ -199,17 +237,14 @@ class BackendState:
self.goals[backend_tag] = goal_id
return existing_goal
def completed_goals(
self,
current_replicas: Dict[BackendTag, Dict[ReplicaTag, ActorHandle]]
) -> List[GoalId]:
def completed_goals(self) -> List[GoalId]:
completed_goals = []
all_tags = set(current_replicas.keys()).union(
all_tags = set(self.backend_replicas.keys()).union(
set(self.backends.keys()))
for backend_tag in all_tags:
desired_info = self.backends.get(backend_tag)
existing_info = current_replicas.get(backend_tag)
existing_info = self.backend_replicas.get(backend_tag)
# Check for deleting
if (not desired_info or
desired_info.backend_config.num_replicas == 0) and \
@@ -222,89 +257,8 @@ class BackendState:
completed_goals.append(self.goals[backend_tag])
return completed_goals
class EndpointState:
    """Route table and per-endpoint traffic policies, with pickle checkpoints."""

    def __init__(self, checkpoint: Optional[bytes] = None):
        """Start empty, or restore from bytes produced by ``checkpoint()``.

        Args:
            checkpoint: Pickled ``(routes, traffic_policies)`` pair, or
                ``None`` for empty state.
        """
        # Keyed by route string: get_endpoints() calls route.startswith("/")
        # on each key. Values are (endpoint, methods) pairs.
        self.routes: Dict[str, Tuple[EndpointTag, Any]] = dict()
        self.traffic_policies: Dict[EndpointTag, TrafficPolicy] = dict()

        if checkpoint is not None:
            # NOTE: pickle.loads must only see bytes written by checkpoint().
            self.routes, self.traffic_policies = pickle.loads(checkpoint)

    def checkpoint(self):
        """Serialize ``(routes, traffic_policies)`` to bytes."""
        return pickle.dumps((self.routes, self.traffic_policies))

    def get_endpoints(self) -> Dict[EndpointTag, Dict[str, Any]]:
        """Summarize every routed endpoint.

        Returns:
            A dict keyed by endpoint. Each value has ``"route"`` (the route
            key if it starts with ``"/"``, else ``None``), ``"methods"``,
            ``"traffic"`` and ``"shadows"``. The last two are empty dicts
            when the endpoint has no traffic policy.
        """
        endpoints = {}
        for route, (endpoint, methods) in self.routes.items():
            if endpoint in self.traffic_policies:
                traffic_policy = self.traffic_policies[endpoint]
                traffic_dict = traffic_policy.traffic_dict
                shadow_dict = traffic_policy.shadow_dict
            else:
                traffic_dict = {}
                shadow_dict = {}

            endpoints[endpoint] = {
                "route": route if route.startswith("/") else None,
                "methods": methods,
                "traffic": traffic_dict,
                "shadows": shadow_dict,
            }
        return endpoints
@dataclass
class ActorStateReconciler:
controller_name: str = field(init=True)
detached: bool = field(init=True)
backend_replicas: Dict[BackendTag, Dict[ReplicaTag, ActorHandle]] = field(
default_factory=lambda: defaultdict(dict))
backend_replicas_to_start: Dict[BackendTag, List[ReplicaTag]] = field(
default_factory=lambda: defaultdict(list))
backend_replicas_to_stop: Dict[BackendTag, List[Tuple[
ReplicaTag, Duration]]] = field(
default_factory=lambda: defaultdict(list))
backends_to_remove: List[BackendTag] = field(default_factory=list)
# NOTE(ilr): These are not checkpointed, but will be recreated by
# `_enqueue_pending_scale_changes_loop`.
currently_starting_replicas: Dict[asyncio.Future, Tuple[
BackendTag, ReplicaTag, ActorHandle]] = field(default_factory=dict)
currently_stopping_replicas: Dict[asyncio.Future, Tuple[
BackendTag, ReplicaTag]] = field(default_factory=dict)
def __getstate__(self):
state = self.__dict__.copy()
del state["currently_stopping_replicas"]
del state["currently_starting_replicas"]
return state
def __setstate__(self, state):
self.__dict__.update(state)
self.currently_stopping_replicas = {}
self.currently_starting_replicas = {}
# TODO(edoakes): consider removing this and just using the names.
def get_replica_handles(self) -> List[ActorHandle]:
    """Return every replica handle across all backends as one flat list."""
    return [
        handle
        for per_backend in self.backend_replicas.values()
        for handle in per_backend.values()
    ]
def get_replica_tags(self) -> List[ReplicaTag]:
    """Return every replica tag across all backends as one flat list."""
    return [
        tag
        for per_backend in self.backend_replicas.values()
        for tag in per_backend
    ]
async def _start_backend_replica(self, backend_state: BackendState,
backend_tag: BackendTag,
replica_tag: ReplicaTag) -> ActorHandle:
def _start_backend_replica(self, backend_tag: BackendTag,
replica_tag: ReplicaTag) -> ActorHandle:
"""Start a replica and return its actor handle.
Checks if the named actor already exists before starting a new one.
@@ -320,7 +274,7 @@ class ActorStateReconciler:
except ValueError:
logger.debug("Starting replica '{}' for backend '{}'.".format(
replica_tag, backend_tag))
backend_info = backend_state.get_backend(backend_tag)
backend_info = self.get_backend(backend_tag)
replica_handle = ray.remote(backend_info.worker_class).options(
name=replica_name,
@@ -334,9 +288,8 @@ class ActorStateReconciler:
return replica_handle
def _scale_backend_replicas(
def scale_backend_replicas(
self,
backends: Dict[BackendTag, BackendInfo],
backend_tag: BackendTag,
num_replicas: int,
force_kill: bool = False,
@@ -353,7 +306,7 @@ class ActorStateReconciler:
logger.debug("Scaling backend '{}' to {} replicas".format(
backend_tag, num_replicas))
assert (backend_tag in backends
assert (backend_tag in self.backends
), "Backend {} is not registered.".format(backend_tag)
assert num_replicas >= 0, ("Number of replicas must be"
" greater than or equal to 0.")
@@ -361,7 +314,7 @@ class ActorStateReconciler:
current_num_replicas = len(self.backend_replicas[backend_tag])
delta_num_replicas = num_replicas - current_num_replicas
backend_info: BackendInfo = backends[backend_tag]
backend_info: BackendInfo = self.backends[backend_tag]
if delta_num_replicas > 0:
can_schedule = try_schedule_resources_on_nodes(requirements=[
backend_info.replica_config.resource_dict
@@ -403,17 +356,17 @@ class ActorStateReconciler:
graceful_timeout_s,
))
async def _enqueue_pending_scale_changes_loop(self,
backend_state: BackendState):
def _start_pending_replicas(self):
for backend_tag, replicas_to_create in self.backend_replicas_to_start.\
items():
for replica_tag in replicas_to_create:
replica_handle = await self._start_backend_replica(
backend_state, backend_tag, replica_tag)
replica_handle = self._start_backend_replica(
backend_tag, replica_tag)
ready_future = replica_handle.ready.remote().as_future()
self.currently_starting_replicas[ready_future] = (
backend_tag, replica_tag, replica_handle)
def _stop_pending_replicas(self):
for backend_tag, replicas_to_stop in (
self.backend_replicas_to_stop.items()):
for replica_tag, shutdown_timeout in replicas_to_stop:
@@ -464,7 +417,6 @@ class ActorStateReconciler:
pass
if len(backend) == 0:
del self.backend_replicas_to_start[backend_tag]
return len(in_flight)
async def _check_currently_stopping_replicas(self) -> int:
"""Returns the number of replicas waiting to stop"""
@@ -498,56 +450,50 @@ class ActorStateReconciler:
if len(self.backend_replicas[backend_tag]) == 0:
del self.backend_replicas[backend_tag]
return len(in_flight)
async def update(self) -> bool:
    """Kick off pending replica starts/stops and poll the in-flight ones.

    Returns:
        True if the number of entries in ``currently_starting_replicas``
        or ``currently_stopping_replicas`` differs before and after the
        two ``_check_currently_*`` calls, otherwise False.
    """
    self._start_pending_replicas()
    self._stop_pending_replicas()

    # Snapshot the in-flight counts after enqueueing, so that only changes
    # made by the checks below are reported.
    num_starting = len(self.currently_starting_replicas)
    num_stopping = len(self.currently_stopping_replicas)

    await self._check_currently_starting_replicas()
    await self._check_currently_stopping_replicas()

    return (len(self.currently_starting_replicas) != num_starting) or \
        (len(self.currently_stopping_replicas) != num_stopping)
def _recover_actor_handles(self) -> None:
    """Overwrite each stored replica entry with a handle looked up by name."""
    # Every replica recorded in self.backend_replicas is guaranteed to
    # already exist: a replica is not written to a checkpoint there until
    # it has been created.
    for replica_dict in self.backend_replicas.values():
        # Only existing keys are reassigned, so the dict does not change
        # size while it is being iterated.
        for replica_tag in replica_dict:
            actor_name = format_actor_name(replica_tag, self.controller_name)
            replica_dict[replica_tag] = ray.get_actor(actor_name)
async def _recover_from_checkpoint(
self, backend_state: BackendState, controller: "ServeController"
) -> Dict[BackendTag, BasicAutoscalingPolicy]:
self._recover_actor_handles()
autoscaling_policies = dict()
class EndpointState:
def __init__(self, checkpoint: bytes = None):
self.routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict()
self.traffic_policies: Dict[EndpointTag, TrafficPolicy] = dict()
for backend, info in backend_state.backends.items():
metadata = info.backend_config.internal_metadata
if metadata.autoscaling_config is not None:
autoscaling_policies[backend] = BasicAutoscalingPolicy(
backend, metadata.autoscaling_config)
if checkpoint is not None:
self.routes, self.traffic_policies = pickle.loads(checkpoint)
# Start/stop any pending backend replicas.
await self._enqueue_pending_scale_changes_loop(backend_state)
def checkpoint(self):
return pickle.dumps((self.routes, self.traffic_policies))
return autoscaling_policies
def get_endpoints(self) -> Dict[EndpointTag, Dict[str, Any]]:
endpoints = {}
for route, (endpoint, methods) in self.routes.items():
if endpoint in self.traffic_policies:
traffic_policy = self.traffic_policies[endpoint]
traffic_dict = traffic_policy.traffic_dict
shadow_dict = traffic_policy.shadow_dict
else:
traffic_dict = {}
shadow_dict = {}
endpoints[endpoint] = {
"route": route if route.startswith("/") else None,
"methods": methods,
"traffic": traffic_dict,
"shadows": shadow_dict,
}
return endpoints
@dataclass
@@ -560,7 +506,6 @@ class FutureResult:
class Checkpoint:
endpoint_state_checkpoint: bytes
backend_state_checkpoint: bytes
reconciler: ActorStateReconciler
# TODO(ilr) Rename reconciler to PendingState
inflight_reqs: Dict[uuid4, FutureResult]
@@ -597,10 +542,6 @@ class ServeController:
detached: bool = False):
# Used to read/write checkpoints.
self.kv_store = RayInternalKVStore(namespace=controller_name)
self.actor_reconciler = ActorStateReconciler(controller_name, detached)
# backend -> AutoscalingPolicy
self.autoscaling_policies = dict()
# Dictionary of backend_tag -> proxy_name -> most recent queue length.
self.backend_stats = defaultdict(lambda: defaultdict(dict))
@@ -620,15 +561,21 @@ class ServeController:
checkpoint_bytes = self.kv_store.get(CHECKPOINT_KEY)
if checkpoint_bytes is None:
logger.debug("No checkpoint found")
self.backend_state = BackendState()
self.backend_state = BackendState(controller_name, detached)
self.endpoint_state = EndpointState()
else:
checkpoint: Checkpoint = pickle.loads(checkpoint_bytes)
self.backend_state = BackendState(
controller_name,
detached,
checkpoint=checkpoint.backend_state_checkpoint)
self.endpoint_state = EndpointState(
checkpoint=checkpoint.endpoint_state_checkpoint)
await self._recover_from_checkpoint(checkpoint)
self._serializable_inflight_results = checkpoint.inflight_reqs
for uuid, fut_result in self._serializable_inflight_results.items(
):
self._create_event_with_result(fut_result.requested_goal, uuid)
# NOTE(simon): Currently we do all-to-all broadcast. This means
# any listeners will receive notification for all changes. This
@@ -637,9 +584,6 @@ class ServeController:
# optimize the logic to support subscription by key.
self.long_poll_host = LongPollHost()
# The configs pushed out here get updated by
# self._recover_from_checkpoint in the failure scenario, so that must
# be run before we notify the changes.
self.notify_backend_configs_changed()
self.notify_replica_handles_changed()
self.notify_traffic_policies_changed()
@@ -670,7 +614,6 @@ class ServeController:
event = asyncio.Event()
event.result = FutureResult(goal_state)
uuid_val = recreation_uuid or uuid4()
logger.debug(f"Creating uuid {uuid_val} for result of {goal_state}")
self.inflight_results[uuid_val] = event
self._serializable_inflight_results[uuid_val] = event.result
return uuid_val
@@ -683,7 +626,7 @@ class ServeController:
LongPollKey.REPLICA_HANDLES, {
backend_tag: list(replica_dict.values())
for backend_tag, replica_dict in
self.actor_reconciler.backend_replicas.items()
self.backend_state.backend_replicas.items()
})
def notify_traffic_policies_changed(self):
@@ -724,7 +667,7 @@ class ServeController:
checkpoint = pickle.dumps(
Checkpoint(self.endpoint_state.checkpoint(),
self.backend_state.checkpoint(), self.actor_reconciler,
self.backend_state.checkpoint(),
self._serializable_inflight_results))
self.kv_store.put(CHECKPOINT_KEY, checkpoint)
@@ -735,96 +678,34 @@ class ServeController:
logger.warning("Intentionally crashing after checkpoint")
os._exit(0)
async def _recover_from_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Recover the instance state from the provided checkpoint.
This should be called in the constructor to ensure that the internal
state is updated before any other operations run. After running this,
internal state will be updated and long-poll clients may be notified.
Performs the following operations:
1) Deserializes the internal state from the checkpoint.
2) Starts/stops any replicas that are pending creation or
deletion.
"""
start = time.time()
logger.info("Recovering from checkpoint")
self.actor_reconciler = checkpoint.reconciler
self._serializable_inflight_results = checkpoint.inflight_reqs
for uuid, fut_result in self._serializable_inflight_results.items():
self._create_event_with_result(fut_result.requested_goal, uuid)
# NOTE(edoakes): unfortunately, we can't completely recover from a
# checkpoint in the constructor because we block while waiting for
# other actors to start up, and those actors fetch soft state from
# this actor. Because no other tasks will start executing until after
# the constructor finishes, if we were to run this logic in the
# constructor it could lead to deadlock between this actor and a child.
# However, we do need to guarantee that we have fully recovered from a
# checkpoint before any other state-changing calls run. We address this
# by acquiring the write_lock and then posting the task to recover from
# a checkpoint to the event loop. Other state-changing calls acquire
# this lock and will be blocked until recovering from the checkpoint
# finishes. This can be removed once we move to the async control loop.
async def finish_recover_from_checkpoint():
assert self.write_lock.locked()
self.autoscaling_policies = await self.actor_reconciler.\
_recover_from_checkpoint(self.backend_state, self)
self.write_lock.release()
logger.info(
"Recovered from checkpoint in {:.3f}s".format(time.time() -
start))
await self.write_lock.acquire()
asyncio.get_event_loop().create_task(finish_recover_from_checkpoint())
async def do_autoscale(self) -> None:
    """Apply each backend's autoscaling policy, if it has one.

    For every backend with an entry in ``self.autoscaling_policies``, the
    policy's ``scale`` is called with that backend's stats and its current
    ``num_replicas``. A positive result is pushed via
    ``update_backend_config``.
    """
    for backend, info in self.backend_state.backends.items():
        if backend in self.autoscaling_policies:
            policy = self.autoscaling_policies[backend]
            target = policy.scale(self.backend_stats[backend],
                                  info.backend_config.num_replicas)
            if target > 0:
                await self.update_backend_config(
                    backend, BackendConfig(num_replicas=target))
async def reconcile_current_and_goal_backends(self):
    """Placeholder: currently does nothing."""
    pass
def set_goal_id(self, goal_id: UUID) -> None:
    """Set the in-flight result event registered under ``goal_id``, if any.

    An unknown ``goal_id`` is ignored apart from the debug log line.
    """
    event = self.inflight_results.get(goal_id)
    # Log once per call.
    logger.debug(f"Setting goal id {goal_id}")
    if event:
        event.set()
async def run_control_loop(self) -> None:
start_time = time.time()
while True:
await self.do_autoscale()
async with self.write_lock:
self.http_state.update()
delta_workers = await self.actor_reconciler.update_actor_state(
start_time)
if delta_workers:
self.notify_replica_handles_changed()
self.notify_backend_configs_changed()
self._checkpoint()
else:
start_time = time.time()
completed_ids = self.backend_state.completed_goals(
self.actor_reconciler.backend_replicas)
completed_ids = self.backend_state.completed_goals()
for done_id in completed_ids:
self.set_goal_id(done_id)
delta_workers = await self.backend_state.update()
if delta_workers:
self.notify_replica_handles_changed()
self._checkpoint()
await asyncio.sleep(CONTROL_LOOP_PERIOD_S)
def _all_replica_handles(
self) -> Dict[BackendTag, Dict[ReplicaTag, ActorHandle]]:
"""Used for testing."""
return self.actor_reconciler.backend_replicas
return self.backend_state.get_replica_handles()
def get_all_backends(self) -> Dict[BackendTag, BackendConfig]:
"""Returns a dictionary of backend tag to backend config."""
@@ -1014,11 +895,6 @@ class ServeController:
worker_class=backend_replica,
backend_config=backend_config,
replica_config=replica_config)
metadata = backend_config.internal_metadata
if metadata.autoscaling_config is not None:
self.autoscaling_policies[
backend_tag] = BasicAutoscalingPolicy(
backend_tag, metadata.autoscaling_config)
return_uuid = self._create_event_with_result({
backend_tag: backend_info
@@ -1028,9 +904,8 @@ class ServeController:
try:
# This call should be to run control loop
self.actor_reconciler._scale_backend_replicas(
self.backend_state.backends, backend_tag,
backend_config.num_replicas)
self.backend_state.scale_backend_replicas(
backend_tag, backend_config.num_replicas)
except RayServeException as e:
del self.backend_state.backends[backend_tag]
raise e
@@ -1039,9 +914,7 @@ class ServeController:
# or pushing the updated config to avoid inconsistent state if we
# crash while making the change.
self._checkpoint()
await self.actor_reconciler._enqueue_pending_scale_changes_loop(
self.backend_state)
self.notify_backend_configs_changed()
return return_uuid
async def delete_backend(self,
@@ -1067,16 +940,14 @@ class ServeController:
# from self.backend_state.backends and
# This should be a call to the control loop
self.actor_reconciler._scale_backend_replicas(
self.backend_state.backends, backend_tag, 0, force_kill)
self.backend_state.scale_backend_replicas(backend_tag, 0,
force_kill)
# Remove the backend's metadata.
del self.backend_state.backends[backend_tag]
if backend_tag in self.autoscaling_policies:
del self.autoscaling_policies[backend_tag]
# Add the intention to remove the backend from the routers.
self.actor_reconciler.backends_to_remove.append(backend_tag)
self.backend_state.backends_to_remove.append(backend_tag)
return_uuid = self._create_event_with_result({backend_tag: None})
# Remove the backend's metadata.
@@ -1085,8 +956,6 @@ class ServeController:
# backend from the routers to avoid inconsistent state if we crash
# after pushing the update.
self._checkpoint()
await self.actor_reconciler._enqueue_pending_scale_changes_loop(
self.backend_state)
return return_uuid
async def update_backend_config(self, backend_tag: BackendTag,
@@ -1114,20 +983,16 @@ class ServeController:
# Scale the replicas with the new configuration.
# This should be to run the control loop
self.actor_reconciler._scale_backend_replicas(
self.backend_state.backends, backend_tag,
backend_config.num_replicas)
self.backend_state.scale_backend_replicas(
backend_tag, backend_config.num_replicas)
# NOTE(edoakes): we must write a checkpoint before pushing the
# update to avoid inconsistent state if we crash after pushing the
# update.
self._checkpoint()
# Inform the routers about change in configuration
# (particularly for setting max_batch_size).
await self.actor_reconciler._enqueue_pending_scale_changes_loop(
self.backend_state)
# Inform the routers and backend replicas about config changes.
self.notify_backend_configs_changed()
return return_uuid
@@ -1146,6 +1011,8 @@ class ServeController:
async with self.write_lock:
for proxy in self.http_state.get_http_proxy_handles().values():
ray.kill(proxy, no_restart=True)
for replica in self.actor_reconciler.get_replica_handles():
ray.kill(replica, no_restart=True)
for replica_dict in self.backend_state.get_replica_handles(
).values():
for replica in replica_dict.values():
ray.kill(replica, no_restart=True)
self.kv_store.delete(CHECKPOINT_KEY)