diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD index dd1b91359..8252c2a4b 100644 --- a/python/ray/serve/BUILD +++ b/python/ray/serve/BUILD @@ -11,7 +11,7 @@ serve_tests_srcs = glob(["tests/*.py"], py_test( name = "test_api", - size = "medium", + size = "large", srcs = serve_tests_srcs, tags = ["exclusive"], deps = [":serve_lib"], diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index e7870edd4..b73cb5667 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -1,18 +1,16 @@ import asyncio from asyncio.futures import Future from collections import defaultdict -from itertools import chain import os import random import time -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Dict, Any, List, Optional, Set, Tuple from uuid import uuid4, UUID from pydantic import BaseModel import ray import ray.cloudpickle as pickle -from ray.serve.autoscaling_policy import BasicAutoscalingPolicy from ray.serve.backend_worker import create_backend_replica from ray.serve.constants import ( ASYNC_CONCURRENCY, @@ -170,15 +168,51 @@ class BackendInfo(BaseModel): class BackendState: - def __init__(self, checkpoint: bytes = None): + def __init__(self, + controller_name: str, + detached: bool, + checkpoint: bytes = None): + self.controller_name = controller_name + self.detached = detached + + # Non-checkpointed state. + self.currently_starting_replicas: Dict[asyncio.Future, Tuple[ + BackendTag, ReplicaTag, ActorHandle]] = dict() + self.currently_stopping_replicas: Dict[asyncio.Future, Tuple[ + BackendTag, ReplicaTag]] = dict() + + # Checkpointed state. self.backends: Dict[BackendTag, BackendInfo] = dict() + self.backend_replicas: Dict[BackendTag, Dict[ + ReplicaTag, ActorHandle]] = defaultdict(dict) self.goals: Dict[BackendTag, GoalId] = dict() + self.backend_replicas_to_start: Dict[BackendTag, List[ + ReplicaTag]] = defaultdict(list) + self.backend_replicas_to_stop: Dict[BackendTag, List[Tuple[ + ReplicaTag, Duration]]] = defaultdict(list) + self.backends_to_remove: List[BackendTag] = list() if checkpoint is not None: - self.backends, self.goals = pickle.loads(checkpoint) + (self.backends, self.backend_replicas, self.goals, + self.backend_replicas_to_start, self.backend_replicas_to_stop, + self.backend_to_remove) = pickle.loads(checkpoint) + + # Fetch actor handles for all of the backend replicas in the system. + # All of these backend_replicas are guaranteed to already exist because + # they would not be written to a checkpoint in self.backend_replicas + # until they were created. + for backend_tag, replica_dict in self.backend_replicas.items(): + for replica_tag in replica_dict.keys(): + replica_name = format_actor_name(replica_tag, + self.controller_name) + self.backend_replicas[backend_tag][ + replica_tag] = ray.get_actor(replica_name) def checkpoint(self): - return pickle.dumps([self.backends, self.goals]) + return pickle.dumps( + (self.backends, self.backend_replicas, self.goals, + self.backend_replicas_to_start, self.backend_replicas_to_stop, + self.backends_to_remove)) def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]: return { @@ -186,6 +220,10 @@ class BackendState: for tag, info in self.backends.items() } + def get_replica_handles( + self) -> Dict[BackendTag, Dict[ReplicaTag, ActorHandle]]: + return self.backend_replicas + def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]: return self.backends.get(backend_tag) @@ -199,17 +237,14 @@ class BackendState: self.goals[backend_tag] = goal_id return existing_goal - def completed_goals( - self, - current_replicas: Dict[BackendTag, Dict[ReplicaTag, ActorHandle]] - ) -> List[GoalId]: + def completed_goals(self) -> List[GoalId]: completed_goals = [] - all_tags = set(current_replicas.keys()).union( + all_tags = set(self.backend_replicas.keys()).union( set(self.backends.keys())) for backend_tag in all_tags: desired_info = self.backends.get(backend_tag) - existing_info = current_replicas.get(backend_tag) + existing_info = self.backend_replicas.get(backend_tag) # Check for deleting if (not desired_info or desired_info.backend_config.num_replicas == 0) and \ @@ -222,89 +257,8 @@ class BackendState: completed_goals.append(self.goals[backend_tag]) return completed_goals - -class EndpointState: - def __init__(self, checkpoint: bytes = None): - self.routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict() - self.traffic_policies: Dict[EndpointTag, TrafficPolicy] = dict() - - if checkpoint is not None: - self.routes, self.traffic_policies = pickle.loads(checkpoint) - - def checkpoint(self): - return pickle.dumps((self.routes, self.traffic_policies)) - - def get_endpoints(self) -> Dict[EndpointTag, Dict[str, Any]]: - endpoints = {} - for route, (endpoint, methods) in self.routes.items(): - if endpoint in self.traffic_policies: - traffic_policy = self.traffic_policies[endpoint] - traffic_dict = traffic_policy.traffic_dict - shadow_dict = traffic_policy.shadow_dict - else: - traffic_dict = {} - shadow_dict = {} - - endpoints[endpoint] = { - "route": route if route.startswith("/") else None, - "methods": methods, - "traffic": traffic_dict, - "shadows": shadow_dict, - } - return endpoints - - -@dataclass -class ActorStateReconciler: - controller_name: str = field(init=True) - detached: bool = field(init=True) - - backend_replicas: Dict[BackendTag, Dict[ReplicaTag, ActorHandle]] = field( - default_factory=lambda: defaultdict(dict)) - backend_replicas_to_start: Dict[BackendTag, List[ReplicaTag]] = field( - default_factory=lambda: defaultdict(list)) - backend_replicas_to_stop: Dict[BackendTag, List[Tuple[ - ReplicaTag, Duration]]] = field( - default_factory=lambda: defaultdict(list)) - backends_to_remove: List[BackendTag] = field(default_factory=list) - - # NOTE(ilr): These are not checkpointed, but will be recreated by - # `_enqueue_pending_scale_changes_loop`. - currently_starting_replicas: Dict[asyncio.Future, Tuple[ - BackendTag, ReplicaTag, ActorHandle]] = field(default_factory=dict) - currently_stopping_replicas: Dict[asyncio.Future, Tuple[ - BackendTag, ReplicaTag]] = field(default_factory=dict) - - def __getstate__(self): - state = self.__dict__.copy() - del state["currently_stopping_replicas"] - del state["currently_starting_replicas"] - return state - - def __setstate__(self, state): - self.__dict__.update(state) - self.currently_stopping_replicas = {} - self.currently_starting_replicas = {} - - # TODO(edoakes): consider removing this and just using the names. - - def get_replica_handles(self) -> List[ActorHandle]: - return list( - chain.from_iterable([ - replica_dict.values() - for replica_dict in self.backend_replicas.values() - ])) - - def get_replica_tags(self) -> List[ReplicaTag]: - return list( - chain.from_iterable([ - replica_dict.keys() - for replica_dict in self.backend_replicas.values() - ])) - - async def _start_backend_replica(self, backend_state: BackendState, - backend_tag: BackendTag, - replica_tag: ReplicaTag) -> ActorHandle: + def _start_backend_replica(self, backend_tag: BackendTag, + replica_tag: ReplicaTag) -> ActorHandle: """Start a replica and return its actor handle. Checks if the named actor already exists before starting a new one. @@ -320,7 +274,7 @@ class ActorStateReconciler: except ValueError: logger.debug("Starting replica '{}' for backend '{}'.".format( replica_tag, backend_tag)) - backend_info = backend_state.get_backend(backend_tag) + backend_info = self.get_backend(backend_tag) replica_handle = ray.remote(backend_info.worker_class).options( name=replica_name, @@ -334,9 +288,8 @@ class ActorStateReconciler: return replica_handle - def _scale_backend_replicas( + def scale_backend_replicas( self, - backends: Dict[BackendTag, BackendInfo], backend_tag: BackendTag, num_replicas: int, force_kill: bool = False, @@ -353,7 +306,7 @@ class ActorStateReconciler: logger.debug("Scaling backend '{}' to {} replicas".format( backend_tag, num_replicas)) - assert (backend_tag in backends + assert (backend_tag in self.backends ), "Backend {} is not registered.".format(backend_tag) assert num_replicas >= 0, ("Number of replicas must be" " greater than or equal to 0.") @@ -361,7 +314,7 @@ class ActorStateReconciler: current_num_replicas = len(self.backend_replicas[backend_tag]) delta_num_replicas = num_replicas - current_num_replicas - backend_info: BackendInfo = backends[backend_tag] + backend_info: BackendInfo = self.backends[backend_tag] if delta_num_replicas > 0: can_schedule = try_schedule_resources_on_nodes(requirements=[ backend_info.replica_config.resource_dict @@ -403,17 +356,17 @@ class ActorStateReconciler: graceful_timeout_s, )) - async def _enqueue_pending_scale_changes_loop(self, - backend_state: BackendState): + def _start_pending_replicas(self): for backend_tag, replicas_to_create in self.backend_replicas_to_start.\ items(): for replica_tag in replicas_to_create: - replica_handle = await self._start_backend_replica( - backend_state, backend_tag, replica_tag) + replica_handle = self._start_backend_replica( + backend_tag, replica_tag) ready_future = replica_handle.ready.remote().as_future() self.currently_starting_replicas[ready_future] = ( backend_tag, replica_tag, replica_handle) + def _stop_pending_replicas(self): for backend_tag, replicas_to_stop in ( self.backend_replicas_to_stop.items()): for replica_tag, shutdown_timeout in replicas_to_stop: @@ -464,7 +417,6 @@ class ActorStateReconciler: pass if len(backend) == 0: del self.backend_replicas_to_start[backend_tag] - return len(in_flight) async def _check_currently_stopping_replicas(self) -> int: """Returns the number of replicas waiting to stop""" @@ -498,56 +450,50 @@ class ActorStateReconciler: if len(self.backend_replicas[backend_tag]) == 0: del self.backend_replicas[backend_tag] - return len(in_flight) - - async def update_actor_state(self, start_time: float) -> bool: + async def update(self) -> bool: """Returns whether the number of backends has changed.""" + self._start_pending_replicas() + self._stop_pending_replicas() + num_starting = len(self.currently_starting_replicas) num_stopping = len(self.currently_stopping_replicas) - num_pending_starts = await self._check_currently_starting_replicas() - num_pending_stops = await self._check_currently_stopping_replicas() - time_running = int(time.time() - start_time) - if (time_running > 0 - and time_running % REPLICA_STARTUP_TIME_WARNING_S == 0): - delta = time.time() - start_time - logger.warning( - f"Waited {delta:.2f}s for {num_pending_starts} replicas " - f"to start up or {num_pending_stops} replicas to shutdown." - " Make sure there are enough resources to create the " - "replicas.") + await self._check_currently_starting_replicas() + await self._check_currently_stopping_replicas() return (len(self.currently_starting_replicas) != num_starting) or \ (len(self.currently_stopping_replicas) != num_stopping) - def _recover_actor_handles(self) -> None: - # Fetch actor handles for all of the backend replicas in the system. - # All of these backend_replicas are guaranteed to already exist because - # they would not be written to a checkpoint in self.backend_replicas - # until they were created. - for backend_tag, replica_dict in self.backend_replicas.items(): - for replica_tag in replica_dict.keys(): - replica_name = format_actor_name(replica_tag, - self.controller_name) - self.backend_replicas[backend_tag][ - replica_tag] = ray.get_actor(replica_name) - async def _recover_from_checkpoint( - self, backend_state: BackendState, controller: "ServeController" - ) -> Dict[BackendTag, BasicAutoscalingPolicy]: - self._recover_actor_handles() - autoscaling_policies = dict() +class EndpointState: + def __init__(self, checkpoint: bytes = None): + self.routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict() + self.traffic_policies: Dict[EndpointTag, TrafficPolicy] = dict() - for backend, info in backend_state.backends.items(): - metadata = info.backend_config.internal_metadata - if metadata.autoscaling_config is not None: - autoscaling_policies[backend] = BasicAutoscalingPolicy( - backend, metadata.autoscaling_config) + if checkpoint is not None: + self.routes, self.traffic_policies = pickle.loads(checkpoint) - # Start/stop any pending backend replicas. - await self._enqueue_pending_scale_changes_loop(backend_state) + def checkpoint(self): + return pickle.dumps((self.routes, self.traffic_policies)) - return autoscaling_policies + def get_endpoints(self) -> Dict[EndpointTag, Dict[str, Any]]: + endpoints = {} + for route, (endpoint, methods) in self.routes.items(): + if endpoint in self.traffic_policies: + traffic_policy = self.traffic_policies[endpoint] + traffic_dict = traffic_policy.traffic_dict + shadow_dict = traffic_policy.shadow_dict + else: + traffic_dict = {} + shadow_dict = {} + + endpoints[endpoint] = { + "route": route if route.startswith("/") else None, + "methods": methods, + "traffic": traffic_dict, + "shadows": shadow_dict, + } + return endpoints @dataclass @@ -560,7 +506,6 @@ class FutureResult: class Checkpoint: endpoint_state_checkpoint: bytes backend_state_checkpoint: bytes - reconciler: ActorStateReconciler # TODO(ilr) Rename reconciler to PendingState inflight_reqs: Dict[uuid4, FutureResult] @@ -597,10 +542,6 @@ class ServeController: detached: bool = False): # Used to read/write checkpoints. self.kv_store = RayInternalKVStore(namespace=controller_name) - self.actor_reconciler = ActorStateReconciler(controller_name, detached) - - # backend -> AutoscalingPolicy - self.autoscaling_policies = dict() # Dictionary of backend_tag -> proxy_name -> most recent queue length. self.backend_stats = defaultdict(lambda: defaultdict(dict)) @@ -620,15 +561,21 @@ class ServeController: checkpoint_bytes = self.kv_store.get(CHECKPOINT_KEY) if checkpoint_bytes is None: logger.debug("No checkpoint found") - self.backend_state = BackendState() + self.backend_state = BackendState(controller_name, detached) self.endpoint_state = EndpointState() else: checkpoint: Checkpoint = pickle.loads(checkpoint_bytes) self.backend_state = BackendState( + controller_name, + detached, checkpoint=checkpoint.backend_state_checkpoint) self.endpoint_state = EndpointState( checkpoint=checkpoint.endpoint_state_checkpoint) - await self._recover_from_checkpoint(checkpoint) + + self._serializable_inflight_results = checkpoint.inflight_reqs + for uuid, fut_result in self._serializable_inflight_results.items( + ): + self._create_event_with_result(fut_result.requested_goal, uuid) # NOTE(simon): Currently we do all-to-all broadcast. This means # any listeners will receive notification for all changes. This @@ -637,9 +584,6 @@ class ServeController: # optimize the logic to support subscription by key. self.long_poll_host = LongPollHost() - # The configs pushed out here get updated by - # self._recover_from_checkpoint in the failure scenario, so that must - # be run before we notify the changes. self.notify_backend_configs_changed() self.notify_replica_handles_changed() self.notify_traffic_policies_changed() @@ -670,7 +614,6 @@ class ServeController: event = asyncio.Event() event.result = FutureResult(goal_state) uuid_val = recreation_uuid or uuid4() - logger.debug(f"Creating uuid {uuid_val} for result of {goal_state}") self.inflight_results[uuid_val] = event self._serializable_inflight_results[uuid_val] = event.result return uuid_val @@ -683,7 +626,7 @@ class ServeController: LongPollKey.REPLICA_HANDLES, { backend_tag: list(replica_dict.values()) for backend_tag, replica_dict in - self.actor_reconciler.backend_replicas.items() + self.backend_state.backend_replicas.items() }) def notify_traffic_policies_changed(self): @@ -724,7 +667,7 @@ class ServeController: checkpoint = pickle.dumps( Checkpoint(self.endpoint_state.checkpoint(), - self.backend_state.checkpoint(), self.actor_reconciler, + self.backend_state.checkpoint(), self._serializable_inflight_results)) self.kv_store.put(CHECKPOINT_KEY, checkpoint) @@ -735,96 +678,34 @@ class ServeController: logger.warning("Intentionally crashing after checkpoint") os._exit(0) - async def _recover_from_checkpoint(self, checkpoint: Checkpoint) -> None: - """Recover the instance state from the provided checkpoint. - - This should be called in the constructor to ensure that the internal - state is updated before any other operations run. After running this, - internal state will be updated and long-poll clients may be notified. - - Performs the following operations: - 1) Deserializes the internal state from the checkpoint. - 2) Starts/stops any replicas that are pending creation or - deletion. - """ - start = time.time() - logger.info("Recovering from checkpoint") - - self.actor_reconciler = checkpoint.reconciler - - self._serializable_inflight_results = checkpoint.inflight_reqs - for uuid, fut_result in self._serializable_inflight_results.items(): - self._create_event_with_result(fut_result.requested_goal, uuid) - - # NOTE(edoakes): unfortunately, we can't completely recover from a - # checkpoint in the constructor because we block while waiting for - # other actors to start up, and those actors fetch soft state from - # this actor. Because no other tasks will start executing until after - # the constructor finishes, if we were to run this logic in the - # constructor it could lead to deadlock between this actor and a child. - # However, we do need to guarantee that we have fully recovered from a - # checkpoint before any other state-changing calls run. We address this - # by acquiring the write_lock and then posting the task to recover from - # a checkpoint to the event loop. Other state-changing calls acquire - # this lock and will be blocked until recovering from the checkpoint - # finishes. This can be removed once we move to the async control loop. - - async def finish_recover_from_checkpoint(): - assert self.write_lock.locked() - self.autoscaling_policies = await self.actor_reconciler.\ - _recover_from_checkpoint(self.backend_state, self) - self.write_lock.release() - logger.info( - "Recovered from checkpoint in {:.3f}s".format(time.time() - - start)) - - await self.write_lock.acquire() - asyncio.get_event_loop().create_task(finish_recover_from_checkpoint()) - - async def do_autoscale(self) -> None: - for backend, info in self.backend_state.backends.items(): - if backend not in self.autoscaling_policies: - continue - - new_num_replicas = self.autoscaling_policies[backend].scale( - self.backend_stats[backend], info.backend_config.num_replicas) - if new_num_replicas > 0: - await self.update_backend_config( - backend, BackendConfig(num_replicas=new_num_replicas)) - async def reconcile_current_and_goal_backends(self): pass def set_goal_id(self, goal_id: UUID) -> None: event = self.inflight_results.get(goal_id) - logger.debug(f"Setting Goal Id: {goal_id}") + logger.debug(f"Setting goal id {goal_id}") if event: event.set() async def run_control_loop(self) -> None: - start_time = time.time() while True: - await self.do_autoscale() async with self.write_lock: self.http_state.update() - delta_workers = await self.actor_reconciler.update_actor_state( - start_time) - if delta_workers: - self.notify_replica_handles_changed() - self.notify_backend_configs_changed() - self._checkpoint() - else: - start_time = time.time() - completed_ids = self.backend_state.completed_goals( - self.actor_reconciler.backend_replicas) + + completed_ids = self.backend_state.completed_goals() for done_id in completed_ids: self.set_goal_id(done_id) + delta_workers = await self.backend_state.update() + if delta_workers: + self.notify_replica_handles_changed() + self._checkpoint() + await asyncio.sleep(CONTROL_LOOP_PERIOD_S) def _all_replica_handles( self) -> Dict[BackendTag, Dict[ReplicaTag, ActorHandle]]: """Used for testing.""" - return self.actor_reconciler.backend_replicas + return self.backend_state.get_replica_handles() def get_all_backends(self) -> Dict[BackendTag, BackendConfig]: """Returns a dictionary of backend tag to backend config.""" @@ -1014,11 +895,6 @@ class ServeController: worker_class=backend_replica, backend_config=backend_config, replica_config=replica_config) - metadata = backend_config.internal_metadata - if metadata.autoscaling_config is not None: - self.autoscaling_policies[ - backend_tag] = BasicAutoscalingPolicy( - backend_tag, metadata.autoscaling_config) return_uuid = self._create_event_with_result({ backend_tag: backend_info @@ -1028,9 +904,8 @@ class ServeController: try: # This call should be to run control loop - self.actor_reconciler._scale_backend_replicas( - self.backend_state.backends, backend_tag, - backend_config.num_replicas) + self.backend_state.scale_backend_replicas( + backend_tag, backend_config.num_replicas) except RayServeException as e: del self.backend_state.backends[backend_tag] raise e @@ -1039,9 +914,7 @@ class ServeController: # or pushing the updated config to avoid inconsistent state if we # crash while making the change. self._checkpoint() - await self.actor_reconciler._enqueue_pending_scale_changes_loop( - self.backend_state) - + self.notify_backend_configs_changed() return return_uuid async def delete_backend(self, @@ -1067,16 +940,14 @@ class ServeController: # from self.backend_state.backends and # This should be a call to the control loop - self.actor_reconciler._scale_backend_replicas( - self.backend_state.backends, backend_tag, 0, force_kill) + self.backend_state.scale_backend_replicas(backend_tag, 0, + force_kill) # Remove the backend's metadata. del self.backend_state.backends[backend_tag] - if backend_tag in self.autoscaling_policies: - del self.autoscaling_policies[backend_tag] # Add the intention to remove the backend from the routers. - self.actor_reconciler.backends_to_remove.append(backend_tag) + self.backend_state.backends_to_remove.append(backend_tag) return_uuid = self._create_event_with_result({backend_tag: None}) # Remove the backend's metadata. @@ -1085,8 +956,6 @@ class ServeController: # backend from the routers to avoid inconsistent state if we crash # after pushing the update. self._checkpoint() - await self.actor_reconciler._enqueue_pending_scale_changes_loop( - self.backend_state) return return_uuid async def update_backend_config(self, backend_tag: BackendTag, @@ -1114,20 +983,16 @@ class ServeController: # Scale the replicas with the new configuration. # This should be to run the control loop - self.actor_reconciler._scale_backend_replicas( - self.backend_state.backends, backend_tag, - backend_config.num_replicas) + self.backend_state.scale_backend_replicas( + backend_tag, backend_config.num_replicas) # NOTE(edoakes): we must write a checkpoint before pushing the # update to avoid inconsistent state if we crash after pushing the # update. self._checkpoint() - # Inform the routers about change in configuration - # (particularly for setting max_batch_size). - - await self.actor_reconciler._enqueue_pending_scale_changes_loop( - self.backend_state) + # Inform the routers and backend replicas about config changes. + self.notify_backend_configs_changed() return return_uuid @@ -1146,6 +1011,8 @@ class ServeController: async with self.write_lock: for proxy in self.http_state.get_http_proxy_handles().values(): ray.kill(proxy, no_restart=True) - for replica in self.actor_reconciler.get_replica_handles(): - ray.kill(replica, no_restart=True) + for replica_dict in self.backend_state.get_replica_handles( + ).values(): + for replica in replica_dict.values(): + ray.kill(replica, no_restart=True) self.kv_store.delete(CHECKPOINT_KEY)