class ReplicaState(Enum):
    """Lifecycle states a ``BackendReplica`` moves through."""

    SHOULD_START = 1
    STARTING = 2
    RUNNING = 3
    SHOULD_STOP = 4
    STOPPING = 5
    STOPPED = 6


class BackendReplica:
    """Tracks one replica actor of a backend and drives its lifecycle.

    A replica is created in ``ReplicaState.SHOULD_START``. ``start`` moves
    it to STARTING, ``check_started`` to RUNNING, ``set_should_stop`` to
    SHOULD_STOP, ``stop`` to STOPPING and ``check_stopped`` to STOPPED.

    The actor handle and the pending object refs are left out of the
    pickled state; they are re-acquired from the actor name when the
    instance is unpickled (see ``__setstate__``).
    """

    def __init__(self, controller_name: str, detached: bool,
                 replica_tag: "ReplicaTag", backend_tag: "BackendTag"):
        """Record the replica's identity; no actor is created here.

        Args:
            controller_name: Passed to ``format_actor_name`` to build the
                actor name and forwarded to the actor constructor.
            detached: If True the actor is created with
                ``lifetime="detached"``.
            replica_tag: Tag identifying this replica.
            backend_tag: Tag of the backend this replica belongs to.
        """
        self._actor_name = format_actor_name(replica_tag, controller_name)
        self._controller_name = controller_name
        self._detached = detached
        self._replica_tag = replica_tag
        self._backend_tag = backend_tag
        self._actor_handle = None
        self._startup_obj_ref = None
        self._drain_obj_ref = None
        self._state = ReplicaState.SHOULD_START

    def __getstate__(self):
        """Return the picklable state, without the handle and object refs."""
        clean_dict = self.__dict__.copy()
        # These are rebuilt by __setstate__, so they are never checkpointed.
        for transient in ("_actor_handle", "_startup_obj_ref",
                          "_drain_obj_ref"):
            clean_dict.pop(transient, None)
        return clean_dict

    def __setstate__(self, d):
        """Restore pickled state and resume whatever was in flight."""
        self.__dict__ = d
        self._actor_handle = None
        self._startup_obj_ref = None
        self._drain_obj_ref = None
        self._recover_from_checkpoint()

    # The hooks were originally spelled with an extra underscore, which the
    # pickle protocol never calls. Keep the old names as aliases so any
    # direct caller keeps working.
    __get_state__ = __getstate__
    __set_state__ = __setstate__

    def _recover_from_checkpoint(self):
        """Re-acquire the actor handle / pending refs after unpickling."""
        if self._state == ReplicaState.STARTING:
            # We do not need to pass in the class here because the actor
            # creation has already been started if this class was
            # checkpointed in the STARTING state.
            self.start(None)
        elif self._state == ReplicaState.RUNNING:
            # The actor must exist if this class was checkpointed in the
            # RUNNING state.
            self._actor_handle = ray.get_actor(self._actor_name)
        elif self._state == ReplicaState.STOPPING:
            self.stop()

    def start(self, backend_info: Optional["BackendInfo"] = None):
        """Attach to the replica actor, creating it if it does not exist.

        Moves the replica to ``ReplicaState.STARTING`` and stores the
        object ref of the actor's ``ready`` call for ``check_started``.

        Args:
            backend_info: Worker class and configs used to create the
                actor. May be omitted only when the actor already exists,
                e.g. when recovering a replica checkpointed in STARTING.

        Raises:
            RuntimeError: The actor does not exist and ``backend_info``
                was not provided.
        """
        assert self._state in {
            ReplicaState.SHOULD_START, ReplicaState.STARTING
        }, (f"State must be {ReplicaState.SHOULD_START} or "
            f"{ReplicaState.STARTING}, *not* {self._state}")
        try:
            self._actor_handle = ray.get_actor(self._actor_name)
        except ValueError:
            if backend_info is None:
                raise RuntimeError(
                    f"Replica actor '{self._actor_name}' does not exist "
                    "and no backend_info was provided to create it."
                ) from None
            logger.debug("Starting replica '{}' for backend '{}'.".format(
                self._replica_tag, self._backend_tag))
            self._actor_handle = ray.remote(backend_info.worker_class).options(
                name=self._actor_name,
                lifetime="detached" if self._detached else None,
                max_restarts=-1,
                max_task_retries=-1,
                **backend_info.replica_config.ray_actor_options).remote(
                    self._backend_tag, self._replica_tag,
                    backend_info.replica_config.actor_init_args,
                    backend_info.backend_config, self._controller_name)
        self._startup_obj_ref = self._actor_handle.ready.remote()
        self._state = ReplicaState.STARTING

    def check_started(self):
        """Return True once the replica is RUNNING.

        Polls the ``ready`` object ref without blocking (``timeout=0``)
        and moves STARTING -> RUNNING when it has resolved.
        """
        if self._state == ReplicaState.RUNNING:
            return True
        assert self._state == ReplicaState.STARTING, (
            f"State must be {ReplicaState.STARTING}, *not* {self._state}")
        ready, _ = ray.wait([self._startup_obj_ref], timeout=0)
        if len(ready) == 1:
            self._state = ReplicaState.RUNNING
            return True
        return False

    def set_should_stop(self, graceful_shutdown_timeout_s: "Duration"):
        """Mark the replica for shutdown and remember its grace period."""
        self._state = ReplicaState.SHOULD_STOP
        self._graceful_shutdown_timeout_s = graceful_shutdown_timeout_s

    def stop(self):
        """Ask the actor to drain pending queries and start the deadline.

        Moves the replica to ``ReplicaState.STOPPING``. The actual kill
        happens in ``check_stopped``.
        """
        # We need to handle transitions from:
        #  SHOULD_START -> SHOULD_STOP -> STOPPING
        # This means that the replica_handle may not have been created.
        assert self._state in {
            ReplicaState.SHOULD_STOP, ReplicaState.STOPPING
        }, (f"State must be {ReplicaState.SHOULD_STOP} or "
            f"{ReplicaState.STOPPING}, *not* {self._state}")

        self._state = ReplicaState.STOPPING
        # NOTE: the replicas may already be stopped if we failed
        # after stopping them but before writing a checkpoint.
        try:
            replica = ray.get_actor(self._actor_name)
        except ValueError:
            self._drain_obj_ref = None
        else:
            self._drain_obj_ref = replica.drain_pending_queries.remote()
        self._shutdown_deadline = (
            time.time() + self._graceful_shutdown_timeout_s)

    def check_stopped(self):
        """Return True once the replica is STOPPED.

        The replica is considered stopped when its actor no longer
        exists. Otherwise the actor is killed as soon as the drain call
        has finished or the graceful-shutdown deadline has passed.
        """
        if self._state == ReplicaState.STOPPED:
            return True
        assert self._state == ReplicaState.STOPPING, (
            f"State must be {ReplicaState.STOPPING}, *not* {self._state}")

        try:
            replica = ray.get_actor(self._actor_name)
        except ValueError:
            self._state = ReplicaState.STOPPED
            return True

        if self._drain_obj_ref is None:
            # stop() found no actor to drain but one exists now; request
            # the drain here instead of handing None to ray.wait().
            self._drain_obj_ref = replica.drain_pending_queries.remote()

        ready, _ = ray.wait([self._drain_obj_ref], timeout=0)
        timeout_passed = time.time() > self._shutdown_deadline
        if len(ready) != 1 and not timeout_passed:
            return False

        if timeout_passed:
            # Graceful period passed, kill it forcefully.
            logger.debug(
                f"{self._actor_name} did not shutdown after "
                f"{self._graceful_shutdown_timeout_s}s, force-killing.")

        ray.kill(replica, no_restart=True)
        self._state = ReplicaState.STOPPED
        return True

    def get_actor_handle(self):
        """Return the actor handle; only valid while RUNNING."""
        assert self._state == ReplicaState.RUNNING, (
            f"State must be {ReplicaState.RUNNING}, *not* {self._state}")
        return self._actor_handle
- self.currently_starting_replicas: Dict[asyncio.Future, Tuple[ - BackendTag, ReplicaTag, ActorHandle]] = dict() - self.currently_stopping_replicas: Dict[asyncio.Future, Tuple[ - BackendTag, ReplicaTag]] = dict() - - # Checkpointed state. - self.backends: Dict[BackendTag, BackendInfo] = dict() - self.backend_replicas: Dict[BackendTag, Dict[ - ReplicaTag, ActorHandle]] = defaultdict(dict) + self._replicas: Dict[BackendTag, Dict[ReplicaState, List[ + BackendReplica]]] = defaultdict(lambda: defaultdict(list)) + self._backend_metadata: Dict[BackendTag, BackendInfo] = dict() + self._target_replicas: Dict[BackendTag, int] = defaultdict(int) self.backend_goals: Dict[BackendTag, GoalId] = dict() - self.backend_replicas_to_start: Dict[BackendTag, List[ - ReplicaTag]] = defaultdict(list) - self.backend_replicas_to_stop: Dict[BackendTag, List[Tuple[ - ReplicaTag, Duration]]] = defaultdict(list) - self.backends_to_remove: List[BackendTag] = list() + + # Un-Checkpointed state. + self.pending_goals: Dict[GoalId, asyncio.Event] = dict() checkpoint = self._kv_store.get(CHECKPOINT_KEY) if checkpoint is not None: - (self.backends, self.backend_replicas, self.backend_goals, - self.backend_replicas_to_start, self.backend_replicas_to_stop, - self.backend_to_remove, - pending_goal_ids) = pickle.loads(checkpoint) + (self._replicas, self._backend_metadata, self._target_replicas, + self.backend_goals, pending_goal_ids) = pickle.loads(checkpoint) for goal_id in pending_goal_ids: self._goal_manager.create_goal(goal_id) - # Fetch actor handles for all backend replicas in the system. - # All of these backend_replicas are guaranteed to already exist - # because they would not be written to a checkpoint in - # self.backend_replicas until they were created. 
- for backend_tag, replica_dict in self.backend_replicas.items(): - for replica_tag in replica_dict.keys(): - replica_name = format_actor_name(replica_tag, - self._controller_name) - self.backend_replicas[backend_tag][ - replica_tag] = ray.get_actor(replica_name) - self._notify_backend_configs_changed() self._notify_replica_handles_changed() def _checkpoint(self) -> None: self._kv_store.put( CHECKPOINT_KEY, - pickle.dumps( - (self.backends, self.backend_replicas, self.backend_goals, - self.backend_replicas_to_start, self.backend_replicas_to_stop, - self.backends_to_remove, - self._goal_manager.get_pending_goal_ids()))) + pickle.dumps((self._replicas, self._backend_metadata, + self._target_replicas, self.backend_goals, + self._goal_manager.get_pending_goal_ids()))) def _notify_backend_configs_changed(self) -> None: self._long_poll_host.notify_changed(LongPollKey.BACKEND_CONFIGS, self.get_backend_configs()) + def get_running_replica_handles( + self) -> Dict[BackendTag, Dict[ReplicaTag, ActorHandle]]: + return { + backend_tag: { + backend_replica._replica_tag: + backend_replica.get_actor_handle() + for backend_replica in state_to_replica_dict[ + ReplicaState.RUNNING] + } + for backend_tag, state_to_replica_dict in self._replicas.items() + } + def _notify_replica_handles_changed(self) -> None: self._long_poll_host.notify_changed( LongPollKey.REPLICA_HANDLES, { backend_tag: list(replica_dict.values()) - for backend_tag, replica_dict in self.backend_replicas.items() + for backend_tag, replica_dict in + self.get_running_replica_handles().items() }) def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]: return { tag: info.backend_config - for tag, info in self.backends.items() + for tag, info in self._backend_metadata.items() } - def get_replica_handles( - self) -> Dict[BackendTag, Dict[ReplicaTag, ActorHandle]]: - return self.backend_replicas - def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]: - return self.backends.get(backend_tag) + 
return self._backend_metadata.get(backend_tag) def _set_backend_goal(self, backend_tag: BackendTag, backend_info: BackendInfo) -> None: @@ -126,7 +256,11 @@ class BackendState: new_goal_id = self._goal_manager.create_goal() if backend_info is not None: - self.backends[backend_tag] = backend_info + self._backend_metadata[backend_tag] = backend_info + self._target_replicas[ + backend_tag] = backend_info.backend_config.num_replicas + else: + self._target_replicas[backend_tag] = 0 self.backend_goals[backend_tag] = new_goal_id @@ -136,31 +270,25 @@ class BackendState: backend_config: BackendConfig, replica_config: ReplicaConfig) -> Optional[GoalId]: # Ensures this method is idempotent. - backend_info = self.backends.get(backend_tag) + backend_info = self._backend_metadata.get(backend_tag) if backend_info is not None: if (backend_info.backend_config == backend_config and backend_info.replica_config == replica_config): return None - backend_replica = create_backend_replica(replica_config.func_or_class) + backend_replica_class = create_backend_replica( + replica_config.func_or_class) # Save creator that starts replicas, the arguments to be passed in, # and the configuration for the backends. backend_info = BackendInfo( - worker_class=backend_replica, + worker_class=backend_replica_class, backend_config=backend_config, replica_config=replica_config) new_goal_id, existing_goal_id = self._set_backend_goal( backend_tag, backend_info) - try: - self.scale_backend_replicas(backend_tag, - backend_config.num_replicas) - except RayServeException as e: - del self.backends[backend_tag] - raise e - # NOTE(edoakes): we must write a checkpoint before starting new # or pushing the updated config to avoid inconsistent state if we # crash while making the change. @@ -175,20 +303,15 @@ class BackendState: force_kill: bool = False) -> Optional[GoalId]: # This method must be idempotent. We should validate that the # specified backend exists on the client. 
- if backend_tag not in self.backends: + if backend_tag not in self._backend_metadata: return None - # Scale its replicas down to 0. - self.scale_backend_replicas(backend_tag, 0, force_kill) - - # Remove the backend's metadata. - del self.backends[backend_tag] - - # Add the intention to remove the backend from the routers. - self.backends_to_remove.append(backend_tag) - new_goal_id, existing_goal_id = self._set_backend_goal( backend_tag, None) + if force_kill: + self._backend_metadata[ + backend_tag].backend_config.\ + experimental_graceful_shutdown_timeout_s = 0 self._checkpoint() if existing_goal_id is not None: @@ -197,20 +320,18 @@ class BackendState: def update_backend_config(self, backend_tag: BackendTag, config_options: BackendConfig): - if backend_tag not in self.backends: + if backend_tag not in self._backend_metadata: raise ValueError(f"Backend {backend_tag} is not registered") - stored_backend_config = self.backends[backend_tag].backend_config + stored_backend_config = self._backend_metadata[ + backend_tag].backend_config updated_config = stored_backend_config.copy( update=config_options.dict(exclude_unset=True)) updated_config._validate_complete() - self.backends[backend_tag].backend_config = updated_config + self._backend_metadata[backend_tag].backend_config = updated_config new_goal_id, existing_goal_id = self._set_backend_goal( - backend_tag, self.backends[backend_tag]) - - # Scale the replicas with the new configuration. - self.scale_backend_replicas(backend_tag, updated_config.num_replicas) + backend_tag, self._backend_metadata[backend_tag]) # NOTE(edoakes): we must write a checkpoint before pushing the # update to avoid inconsistent state if we crash after pushing the @@ -260,31 +381,38 @@ class BackendState: def scale_backend_replicas( self, backend_tag: BackendTag, - num_replicas: int, - force_kill: bool = False, - ) -> None: + ) -> bool: """Scale the given backend to the number of replicas. 
NOTE: this does not actually start or stop the replicas, but instead - adds the intention to start/stop them to self.backend_replicas_to_start - and self.backend_replicas_to_stop. The caller is responsible for then - first writing a checkpoint and then actually starting/stopping the - intended replicas. This avoids inconsistencies with starting/stopping a - replica and then crashing before writing a checkpoint. + adds them to ReplicaState.SHOULD_START or ReplicaState.SHOULD_STOP. + The caller is responsible for then first writing a checkpoint and then + actually starting/stopping the intended replicas. This avoids + inconsistencies with starting/stopping a replica and then crashing + before writing a checkpoint. """ + num_replicas = self._target_replicas.get(backend_tag, 0) logger.debug("Scaling backend '{}' to {} replicas".format( backend_tag, num_replicas)) - assert (backend_tag in self.backends + assert (backend_tag in self._backend_metadata ), "Backend {} is not registered.".format(backend_tag) assert num_replicas >= 0, ("Number of replicas must be" " greater than or equal to 0.") - current_num_replicas = len(self.backend_replicas[backend_tag]) + current_num_replicas = sum([ + len(self._replicas[backend_tag][ReplicaState.SHOULD_START]), + len(self._replicas[backend_tag][ReplicaState.STARTING]), + len(self._replicas[backend_tag][ReplicaState.RUNNING]), + ]) + delta_num_replicas = num_replicas - current_num_replicas - backend_info: BackendInfo = self.backends[backend_tag] - if delta_num_replicas > 0: + backend_info: BackendInfo = self._backend_metadata[backend_tag] + if delta_num_replicas == 0: + return False + + elif delta_num_replicas > 0: can_schedule = try_schedule_resources_on_nodes(requirements=[ backend_info.replica_config.resource_dict for _ in range(delta_num_replicas) @@ -292,10 +420,11 @@ class BackendState: if _RESOURCE_CHECK_ENABLED and not all(can_schedule): num_possible = sum(can_schedule) - raise RayServeException( + logger.error( "Cannot scale 
backend {} to {} replicas. Ray Serve tried " "to add {} replicas but the resources only allows {} " - "to be added. To fix this, consider scaling to replica to " + "to be added. This is not a problem if the cluster is " + "autoscaling. To fix this, consider scaling to replica to " "{} or add more resources to the cluster. You can check " "avaiable resources with ray.nodes().".format( backend_tag, num_replicas, delta_num_replicas, @@ -305,154 +434,132 @@ class BackendState: delta_num_replicas, backend_tag)) for _ in range(delta_num_replicas): replica_tag = "{}#{}".format(backend_tag, get_random_letters()) - self.backend_replicas_to_start[backend_tag].append(replica_tag) + self._replicas[backend_tag][ReplicaState.SHOULD_START].append( + BackendReplica(self._controller_name, self._detached, + replica_tag, backend_tag)) elif delta_num_replicas < 0: logger.debug("Removing {} replicas from backend '{}'".format( -delta_num_replicas, backend_tag)) - assert len( - self.backend_replicas[backend_tag]) >= delta_num_replicas - replicas_copy = self.backend_replicas.copy() + assert self._target_replicas[backend_tag] >= delta_num_replicas + for _ in range(-delta_num_replicas): - replica_tag, _ = replicas_copy[backend_tag].popitem() + replica_state_dict = self._replicas[backend_tag] + list_to_use = replica_state_dict[ReplicaState.SHOULD_START] \ + or replica_state_dict[ReplicaState.STARTING] \ + or replica_state_dict[ReplicaState.RUNNING] + + assert len(list_to_use), replica_state_dict + replica_to_stop = list_to_use.pop() graceful_timeout_s = (backend_info.backend_config. 
experimental_graceful_shutdown_timeout_s) - if force_kill: - graceful_timeout_s = 0 - self.backend_replicas_to_stop[backend_tag].append(( - replica_tag, - graceful_timeout_s, - )) - def _start_pending_replicas(self): - for backend_tag, replicas_to_create in self.backend_replicas_to_start.\ - items(): - for replica_tag in replicas_to_create: - replica_handle = self._start_backend_replica( - backend_tag, replica_tag) - ready_future = replica_handle.ready.remote().as_future() - self.currently_starting_replicas[ready_future] = ( - backend_tag, replica_tag, replica_handle) + replica_to_stop.set_should_stop(graceful_timeout_s) + self._replicas[backend_tag][ReplicaState.SHOULD_STOP].append( + replica_to_stop) - def _stop_pending_replicas(self): - for backend_tag, replicas_to_stop in ( - self.backend_replicas_to_stop.items()): - for replica_tag, shutdown_timeout in replicas_to_stop: - replica_name = format_actor_name(replica_tag, - self._controller_name) + return True - async def kill_actor(replica_name_to_use): - # NOTE: the replicas may already be stopped if we failed - # after stopping them but before writing a checkpoint. - try: - replica = ray.get_actor(replica_name_to_use) - except ValueError: - return + def scale_all_backends(self): + checkpoint_needed = False + for backend_tag, num_replicas in list(self._target_replicas.items()): + checkpoint_needed = (checkpoint_needed + or self.scale_backend_replicas(backend_tag)) + if num_replicas == 0: + del self._backend_metadata[backend_tag] + del self._target_replicas[backend_tag] - try: - await asyncio.wait_for( - replica.drain_pending_queries.remote(), - timeout=shutdown_timeout) - except asyncio.TimeoutError: - # Graceful period passed, kill it forcefully. 
- logger.debug( - f"{replica_name_to_use} did not shutdown after " - f"{shutdown_timeout}s, killing.") - finally: - ray.kill(replica, no_restart=True) + if checkpoint_needed: + self._checkpoint() - self.currently_stopping_replicas[asyncio.ensure_future( - kill_actor(replica_name))] = (backend_tag, replica_tag) + def _pop_replicas_of_state(self, state: ReplicaState + ) -> List[Tuple[ReplicaState, BackendTag]]: + replicas = [] + for backend_tag, state_to_replica_dict in self._replicas.items(): + if state in state_to_replica_dict: + replicas.extend( + (replica, backend_tag) + for replica in state_to_replica_dict.pop(state)) - async def _check_currently_starting_replicas(self) -> int: - """Returns the number of pending replicas waiting to start""" - in_flight: Set[Future[Any]] = set() - - if self.currently_starting_replicas: - done, in_flight = await asyncio.wait( - list(self.currently_starting_replicas.keys()), timeout=0) - for fut in done: - (backend_tag, replica_tag, - replica_handle) = self.currently_starting_replicas.pop(fut) - self.backend_replicas[backend_tag][ - replica_tag] = replica_handle - - backend = self.backend_replicas_to_start.get(backend_tag) - if backend: - try: - backend.remove(replica_tag) - except ValueError: - pass - if len(backend) == 0: - del self.backend_replicas_to_start[backend_tag] - - async def _check_currently_stopping_replicas(self) -> int: - """Returns the number of replicas waiting to stop""" - in_flight: Set[Future[Any]] = set() - - if self.currently_stopping_replicas: - done_stopping, in_flight = await asyncio.wait( - list(self.currently_stopping_replicas.keys()), timeout=0) - for fut in done_stopping: - (backend_tag, - replica_tag) = self.currently_stopping_replicas.pop(fut) - - backend_to_stop = self.backend_replicas_to_stop.get( - backend_tag) - - if backend_to_stop: - try: - backend_to_stop.remove(replica_tag) - except ValueError: - pass - if len(backend_to_stop) == 0: - del self.backend_replicas_to_stop[backend_tag] - - backend 
= self.backend_replicas.get(backend_tag) - if backend: - try: - del backend[replica_tag] - except KeyError: - pass - - if len(self.backend_replicas[backend_tag]) == 0: - del self.backend_replicas[backend_tag] + return replicas def _completed_goals(self) -> List[GoalId]: completed_goals = [] - all_tags = set(self.backend_replicas.keys()).union( - set(self.backends.keys())) + all_tags = set(self._replicas.keys()).union( + set(self._backend_metadata.keys())) for backend_tag in all_tags: - desired_info = self.backends.get(backend_tag) - existing_info = self.backend_replicas.get(backend_tag) + desired_num_replicas = self._target_replicas.get(backend_tag) + state_dict = self._replicas.get(backend_tag, {}) + existing_info = state_dict.get(ReplicaState.RUNNING, []) + + # If we have pending ops, the current goal is *not* ready + if (state_dict.get(ReplicaState.SHOULD_START) + or state_dict.get(ReplicaState.STARTING) + or state_dict.get(ReplicaState.SHOULD_STOP) + or state_dict.get(ReplicaState.STOPPING)): + continue + + # TODO(ilr): FIX # Check for deleting - if (not desired_info or - desired_info.backend_config.num_replicas == 0) and \ + if (not desired_num_replicas or + desired_num_replicas == 0) and \ (not existing_info or len(existing_info) == 0): - completed_goals.append(self.backend_goals.get(backend_tag)) + completed_goals.append( + self.backend_goals.pop(backend_tag, None)) # Check for a non-zero number of backends - if desired_info and existing_info and desired_info.backend_config.\ - num_replicas == len(existing_info): - completed_goals.append(self.backend_goals.get(backend_tag)) + if (desired_num_replicas and existing_info) \ + and desired_num_replicas == len(existing_info): + completed_goals.append( + self.backend_goals.pop(backend_tag, None)) return [goal for goal in completed_goals if goal] async def update(self) -> bool: + self.scale_all_backends() + for goal_id in self._completed_goals(): self._goal_manager.complete_goal(goal_id) - 
self._start_pending_replicas() - self._stop_pending_replicas() + for replica_state, backend_tag in self._pop_replicas_of_state( + ReplicaState.SHOULD_START): + replica_state.start(self._backend_metadata[backend_tag]) + self._replicas[backend_tag][ReplicaState.STARTING].append( + replica_state) - num_starting = len(self.currently_starting_replicas) - num_stopping = len(self.currently_stopping_replicas) + for replica_state, backend_tag in self._pop_replicas_of_state( + ReplicaState.SHOULD_STOP): + replica_state.stop() + self._replicas[backend_tag][ReplicaState.STOPPING].append( + replica_state) - await self._check_currently_starting_replicas() - await self._check_currently_stopping_replicas() + transition_triggered = False - if (len(self.currently_starting_replicas) != num_starting) or \ - (len(self.currently_stopping_replicas) != num_stopping): + for replica_state, backend_tag in self._pop_replicas_of_state( + ReplicaState.STARTING): + if replica_state.check_started(): + self._replicas[backend_tag][ReplicaState.RUNNING].append( + replica_state) + transition_triggered = True + else: + self._replicas[backend_tag][ReplicaState.STARTING].append( + replica_state) + + for replica_state, backend_tag in self._pop_replicas_of_state( + ReplicaState.STOPPING): + if replica_state.check_stopped(): + transition_triggered = True + else: + self._replicas[backend_tag][ReplicaState.STOPPING].append( + replica_state) + + for backend_tag in list(self._replicas.keys()): + if not any(self._replicas[backend_tag]): + del self._replicas[backend_tag] + del self._backend_metadata[backend_tag] + del self._target_replicas[backend_tag] + + if transition_triggered: self._checkpoint() self._notify_replica_handles_changed() diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index 205af81b0..41a1eca08 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Any, Dict, List, Optional import pydantic -from 
pydantic import BaseModel, PositiveFloat, PositiveInt, validator +from pydantic import BaseModel, confloat, PositiveFloat, PositiveInt, validator from ray.serve.constants import (ASYNC_CONCURRENCY, DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT) @@ -64,7 +64,7 @@ class BackendConfig(BaseModel): user_config: Any = None experimental_graceful_shutdown_wait_loop_s: PositiveFloat = 2.0 - experimental_graceful_shutdown_timeout_s: PositiveFloat = 20.0 + experimental_graceful_shutdown_timeout_s: confloat(ge=0) = 20.0 class Config: validate_assignment = True diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index a3c75c711..b5c65111a 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -118,7 +118,7 @@ class ServeController: def _all_replica_handles( self) -> Dict[BackendTag, Dict[ReplicaTag, ActorHandle]]: """Used for testing.""" - return self.backend_state.get_replica_handles() + return self.backend_state.get_running_replica_handles() def get_all_backends(self) -> Dict[BackendTag, BackendConfig]: """Returns a dictionary of backend tag to backend config.""" @@ -235,7 +235,7 @@ class ServeController: async with self.write_lock: for proxy in self.http_state.get_http_proxy_handles().values(): ray.kill(proxy, no_restart=True) - for replica_dict in self.backend_state.get_replica_handles( + for replica_dict in self.backend_state.get_running_replica_handles( ).values(): for replica in replica_dict.values(): ray.kill(replica, no_restart=True) diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index 202b01386..a35f7e54b 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -683,6 +683,9 @@ def test_endpoint_input_validation(serve_instance): client.create_endpoint("endpoint", backend="backend") +# This error is only printed because creation is run in the control loop, not +# in the API path. 
+@pytest.mark.skip() def test_create_infeasible_error(serve_instance): client = serve_instance