diff --git a/python/ray/serve/backend_state.py b/python/ray/serve/backend_state.py index 4aad2671e..673c4b2cf 100644 --- a/python/ray/serve/backend_state.py +++ b/python/ray/serve/backend_state.py @@ -1,8 +1,7 @@ import asyncio +from asyncio.futures import Future from collections import defaultdict -from enum import Enum -import time -from typing import Dict, List, Optional, Tuple +from typing import Dict, Any, List, Optional, Set, Tuple import ray import ray.cloudpickle as pickle @@ -18,6 +17,7 @@ from ray.serve.common import ( ) from ray.serve.config import BackendConfig, ReplicaConfig from ray.serve.constants import LongPollKey +from ray.serve.exceptions import RayServeException from ray.serve.kv_store import RayInternalKVStore from ray.serve.long_poll import LongPollHost from ray.serve.utils import (format_actor_name, get_random_letters, logger, @@ -30,150 +30,6 @@ CHECKPOINT_KEY = "serve-backend-state-checkpoint" _RESOURCE_CHECK_ENABLED = True -class ReplicaState(Enum): - SHOULD_START = 1 - STARTING = 2 - RUNNING = 3 - SHOULD_STOP = 4 - STOPPING = 5 - STOPPED = 6 - - -class BackendReplica: - def __init__(self, controller_name: str, detached: bool, - replica_tag: ReplicaTag, backend_tag: BackendTag): - self._actor_name = format_actor_name(replica_tag, controller_name) - self._controller_name = controller_name - self._detached = detached - self._replica_tag = replica_tag - self._backend_tag = backend_tag - self._actor_handle = None - self._startup_obj_ref = None - self._drain_obj_ref = None - self._state = ReplicaState.SHOULD_START - - def __get_state__(self): - clean_dict = self.__dict__.copy() - del clean_dict["_actor_handle"] - del clean_dict["_startup_obj_ref"] - del clean_dict["_drain_obj_ref"] - return clean_dict - - def __set_state__(self, d): - self.__dict__ = d - self._actor_handle = None - self._startup_obj_ref = None - self._drain_obj_ref = None - self._recover_from_checkpoint() - - def _recover_from_checkpoint(self): - if self._state == ReplicaState.STARTING: - # We do not need to pass in the class here because the actor - # creation has already been started if this class was checkpointed - # in the STARTING state. - self.start() - elif self._state == ReplicaState.RUNNING: - # Fetch actor handles for all backend replicas in the system. - # The actors must exist if this class was checkpointed in the - # RUNNING state. - self._actor_handle = ray.get_actor(self._actor_name) - elif self._state == ReplicaState.STOPPING: - self.stop() - - def start(self, backend_info: Optional[BackendInfo]): - assert self._state in { - ReplicaState.SHOULD_START, ReplicaState.STARTING - }, (f"State must be {ReplicaState.SHOULD_START} or " - f"{ReplicaState.STARTING}, *not* {self._state}") - try: - self._actor_handle = ray.get_actor(self._actor_name) - except ValueError: - logger.debug("Starting replica '{}' for backend '{}'.".format( - self._replica_tag, self._backend_tag)) - self._actor_handle = ray.remote(backend_info.worker_class).options( - name=self._actor_name, - lifetime="detached" if self._detached else None, - max_restarts=-1, - max_task_retries=-1, - **backend_info.replica_config.ray_actor_options).remote( - self._backend_tag, self._replica_tag, - backend_info.replica_config.actor_init_args, - backend_info.backend_config, self._controller_name) - self._startup_obj_ref = self._actor_handle.ready.remote() - self._state = ReplicaState.STARTING - - def check_started(self): - if self._state == ReplicaState.RUNNING: - return True - assert self._state == ReplicaState.STARTING, ( - f"State must be {ReplicaState.STARTING}, *not* {self._state}") - ready, _ = ray.wait([self._startup_obj_ref], timeout=0) - if len(ready) == 1: - self._state = ReplicaState.RUNNING - return True - return False - - def set_should_stop(self, graceful_shutdown_timeout_s: Duration): - self._state = ReplicaState.SHOULD_STOP - self._graceful_shutdown_timeout_s = graceful_shutdown_timeout_s - - def stop(self): - # We need to handle transitions from: - # SHOULD_START -> SHOULD_STOP -> STOPPING - # This means that the replica_handle may not have been created. - - assert self._state in { - ReplicaState.SHOULD_STOP, ReplicaState.STOPPING - }, (f"State must be {ReplicaState.SHOULD_STOP} or " - f"{ReplicaState.STOPPING}, *not* {self._state}") - - def drain_actor(actor_name): - # NOTE: the replicas may already be stopped if we failed - # after stopping them but before writing a checkpoint. - try: - replica = ray.get_actor(actor_name) - except ValueError: - return None - return replica.drain_pending_queries.remote() - - self._state = ReplicaState.STOPPING - self._drain_obj_ref = drain_actor(self._actor_name) - self._shutdown_deadline = time.time( - ) + self._graceful_shutdown_timeout_s - - def check_stopped(self): - if self._state == ReplicaState.STOPPED: - return True - assert self._state == ReplicaState.STOPPING, ( - f"State must be {ReplicaState.STOPPING}, *not* {self._state}") - - try: - replica = ray.get_actor(self._actor_name) - except ValueError: - self._state = ReplicaState.STOPPED - return True - - ready, _ = ray.wait([self._drain_obj_ref], timeout=0) - timeout_passed = time.time() > self._shutdown_deadline - - if len(ready) == 1 or timeout_passed: - if timeout_passed: - # Graceful period passed, kill it forcefully. - logger.debug( - f"{self._actor_name} did not shutdown after " - f"{self._graceful_shutdown_timeout_s}s, force-killing.") - - ray.kill(replica, no_restart=True) - self._state = ReplicaState.STOPPED - return True - return False - - def get_actor_handle(self): - assert self._state == ReplicaState.RUNNING, ( - f"State must be {ReplicaState.RUNNING}, *not* {self._state}") - return self._actor_handle - - class BackendState: """Manages all state for backends in the system. @@ -190,65 +46,79 @@ class BackendState: self._long_poll_host = long_poll_host self._goal_manager = goal_manager - self._replicas: Dict[BackendTag, Dict[ReplicaState, List[ - BackendReplica]]] = defaultdict(lambda: defaultdict(list)) - self._backend_metadata: Dict[BackendTag, BackendInfo] = dict() - self._target_replicas: Dict[BackendTag, int] = defaultdict(int) - self.backend_goals: Dict[BackendTag, GoalId] = dict() + # Non-checkpointed state. + self.currently_starting_replicas: Dict[asyncio.Future, Tuple[ + BackendTag, ReplicaTag, ActorHandle]] = dict() + self.currently_stopping_replicas: Dict[asyncio.Future, Tuple[ + BackendTag, ReplicaTag]] = dict() - # Un-Checkpointed state. - self.pending_goals: Dict[GoalId, asyncio.Event] = dict() + # Checkpointed state. + self.backends: Dict[BackendTag, BackendInfo] = dict() + self.backend_replicas: Dict[BackendTag, Dict[ + ReplicaTag, ActorHandle]] = defaultdict(dict) + self.backend_goals: Dict[BackendTag, GoalId] = dict() + self.backend_replicas_to_start: Dict[BackendTag, List[ + ReplicaTag]] = defaultdict(list) + self.backend_replicas_to_stop: Dict[BackendTag, List[Tuple[ + ReplicaTag, Duration]]] = defaultdict(list) + self.backends_to_remove: List[BackendTag] = list() checkpoint = self._kv_store.get(CHECKPOINT_KEY) if checkpoint is not None: - (self._replicas, self._backend_metadata, self._target_replicas, - self.backend_goals, pending_goal_ids) = pickle.loads(checkpoint) + (self.backends, self.backend_replicas, self.backend_goals, + self.backend_replicas_to_start, self.backend_replicas_to_stop, + self.backend_to_remove, + pending_goal_ids) = pickle.loads(checkpoint) for goal_id in pending_goal_ids: self._goal_manager.create_goal(goal_id) + # Fetch actor handles for all backend replicas in the system. + # All of these backend_replicas are guaranteed to already exist + # because they would not be written to a checkpoint in + # self.backend_replicas until they were created. + for backend_tag, replica_dict in self.backend_replicas.items(): + for replica_tag in replica_dict.keys(): + replica_name = format_actor_name(replica_tag, + self._controller_name) + self.backend_replicas[backend_tag][ + replica_tag] = ray.get_actor(replica_name) + self._notify_backend_configs_changed() self._notify_replica_handles_changed() def _checkpoint(self) -> None: self._kv_store.put( CHECKPOINT_KEY, - pickle.dumps((self._replicas, self._backend_metadata, - self._target_replicas, self.backend_goals, - self._goal_manager.get_pending_goal_ids()))) + pickle.dumps( + (self.backends, self.backend_replicas, self.backend_goals, + self.backend_replicas_to_start, self.backend_replicas_to_stop, + self.backends_to_remove, + self._goal_manager.get_pending_goal_ids()))) def _notify_backend_configs_changed(self) -> None: self._long_poll_host.notify_changed(LongPollKey.BACKEND_CONFIGS, self.get_backend_configs()) - def get_running_replica_handles( - self) -> Dict[BackendTag, Dict[ReplicaTag, ActorHandle]]: - return { - backend_tag: { - backend_replica._replica_tag: - backend_replica.get_actor_handle() - for backend_replica in state_to_replica_dict[ - ReplicaState.RUNNING] - } - for backend_tag, state_to_replica_dict in self._replicas.items() - } - def _notify_replica_handles_changed(self) -> None: self._long_poll_host.notify_changed( LongPollKey.REPLICA_HANDLES, { backend_tag: list(replica_dict.values()) - for backend_tag, replica_dict in - self.get_running_replica_handles().items() + for backend_tag, replica_dict in self.backend_replicas.items() }) def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]: return { tag: info.backend_config - for tag, info in self._backend_metadata.items() + for tag, info in self.backends.items() } + def get_replica_handles( + self) -> Dict[BackendTag, Dict[ReplicaTag, ActorHandle]]: + return self.backend_replicas + def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]: - return self._backend_metadata.get(backend_tag) + return self.backends.get(backend_tag) def _set_backend_goal(self, backend_tag: BackendTag, backend_info: BackendInfo) -> None: @@ -256,11 +126,7 @@ class BackendState: new_goal_id = self._goal_manager.create_goal() if backend_info is not None: - self._backend_metadata[backend_tag] = backend_info - self._target_replicas[ - backend_tag] = backend_info.backend_config.num_replicas - else: - self._target_replicas[backend_tag] = 0 + self.backends[backend_tag] = backend_info self.backend_goals[backend_tag] = new_goal_id @@ -270,25 +136,31 @@ class BackendState: backend_config: BackendConfig, replica_config: ReplicaConfig) -> Optional[GoalId]: # Ensures this method is idempotent. - backend_info = self._backend_metadata.get(backend_tag) + backend_info = self.backends.get(backend_tag) if backend_info is not None: if (backend_info.backend_config == backend_config and backend_info.replica_config == replica_config): return None - backend_replica_class = create_backend_replica( - replica_config.func_or_class) + backend_replica = create_backend_replica(replica_config.func_or_class) # Save creator that starts replicas, the arguments to be passed in, # and the configuration for the backends. backend_info = BackendInfo( - worker_class=backend_replica_class, + worker_class=backend_replica, backend_config=backend_config, replica_config=replica_config) new_goal_id, existing_goal_id = self._set_backend_goal( backend_tag, backend_info) + try: + self.scale_backend_replicas(backend_tag, + backend_config.num_replicas) + except RayServeException as e: + del self.backends[backend_tag] + raise e + # NOTE(edoakes): we must write a checkpoint before starting new # or pushing the updated config to avoid inconsistent state if we # crash while making the change. @@ -303,15 +175,20 @@ class BackendState: force_kill: bool = False) -> Optional[GoalId]: # This method must be idempotent. We should validate that the # specified backend exists on the client. - if backend_tag not in self._backend_metadata: + if backend_tag not in self.backends: return None + # Scale its replicas down to 0. + self.scale_backend_replicas(backend_tag, 0, force_kill) + + # Remove the backend's metadata. + del self.backends[backend_tag] + + # Add the intention to remove the backend from the routers. + self.backends_to_remove.append(backend_tag) + new_goal_id, existing_goal_id = self._set_backend_goal( backend_tag, None) - if force_kill: - self._backend_metadata[ - backend_tag].backend_config.\ - experimental_graceful_shutdown_timeout_s = 0 self._checkpoint() if existing_goal_id is not None: @@ -320,18 +197,20 @@ class BackendState: def update_backend_config(self, backend_tag: BackendTag, config_options: BackendConfig): - if backend_tag not in self._backend_metadata: + if backend_tag not in self.backends: raise ValueError(f"Backend {backend_tag} is not registered") - stored_backend_config = self._backend_metadata[ - backend_tag].backend_config + stored_backend_config = self.backends[backend_tag].backend_config updated_config = stored_backend_config.copy( update=config_options.dict(exclude_unset=True)) updated_config._validate_complete() - self._backend_metadata[backend_tag].backend_config = updated_config + self.backends[backend_tag].backend_config = updated_config new_goal_id, existing_goal_id = self._set_backend_goal( - backend_tag, self._backend_metadata[backend_tag]) + backend_tag, self.backends[backend_tag]) + + # Scale the replicas with the new configuration. + self.scale_backend_replicas(backend_tag, updated_config.num_replicas) # NOTE(edoakes): we must write a checkpoint before pushing the # update to avoid inconsistent state if we crash after pushing the @@ -381,38 +260,31 @@ class BackendState: def scale_backend_replicas( self, backend_tag: BackendTag, - ) -> bool: + num_replicas: int, + force_kill: bool = False, + ) -> None: """Scale the given backend to the number of replicas. NOTE: this does not actually start or stop the replicas, but instead - adds them to ReplicaState.SHOULD_START or ReplicaState.SHOULD_STOP. - The caller is responsible for then first writing a checkpoint and then - actually starting/stopping the intended replicas. This avoids - inconsistencies with starting/stopping a replica and then crashing - before writing a checkpoint. + adds the intention to start/stop them to self.backend_replicas_to_start + and self.backend_replicas_to_stop. The caller is responsible for then + first writing a checkpoint and then actually starting/stopping the + intended replicas. This avoids inconsistencies with starting/stopping a + replica and then crashing before writing a checkpoint. """ - num_replicas = self._target_replicas.get(backend_tag, 0) logger.debug("Scaling backend '{}' to {} replicas".format( backend_tag, num_replicas)) - assert (backend_tag in self._backend_metadata + assert (backend_tag in self.backends ), "Backend {} is not registered.".format(backend_tag) assert num_replicas >= 0, ("Number of replicas must be" " greater than or equal to 0.") - current_num_replicas = sum([ - len(self._replicas[backend_tag][ReplicaState.SHOULD_START]), - len(self._replicas[backend_tag][ReplicaState.STARTING]), - len(self._replicas[backend_tag][ReplicaState.RUNNING]), - ]) - + current_num_replicas = len(self.backend_replicas[backend_tag]) delta_num_replicas = num_replicas - current_num_replicas - backend_info: BackendInfo = self._backend_metadata[backend_tag] - if delta_num_replicas == 0: - return False - - elif delta_num_replicas > 0: + backend_info: BackendInfo = self.backends[backend_tag] + if delta_num_replicas > 0: can_schedule = try_schedule_resources_on_nodes(requirements=[ backend_info.replica_config.resource_dict for _ in range(delta_num_replicas) @@ -420,11 +292,10 @@ class BackendState: if _RESOURCE_CHECK_ENABLED and not all(can_schedule): num_possible = sum(can_schedule) - logger.error( + raise RayServeException( "Cannot scale backend {} to {} replicas. Ray Serve tried " "to add {} replicas but the resources only allows {} " - "to be added. This is not a problem if the cluster is " - "autoscaling. To fix this, consider scaling to replica to " + "to be added. To fix this, consider scaling to replica to " "{} or add more resources to the cluster. You can check " "avaiable resources with ray.nodes().".format( backend_tag, num_replicas, delta_num_replicas, @@ -434,132 +305,154 @@ class BackendState: delta_num_replicas, backend_tag)) for _ in range(delta_num_replicas): replica_tag = "{}#{}".format(backend_tag, get_random_letters()) - self._replicas[backend_tag][ReplicaState.SHOULD_START].append( - BackendReplica(self._controller_name, self._detached, - replica_tag, backend_tag)) + self.backend_replicas_to_start[backend_tag].append(replica_tag) elif delta_num_replicas < 0: logger.debug("Removing {} replicas from backend '{}'".format( -delta_num_replicas, backend_tag)) - assert self._target_replicas[backend_tag] >= delta_num_replicas - + assert len( + self.backend_replicas[backend_tag]) >= delta_num_replicas + replicas_copy = self.backend_replicas.copy() for _ in range(-delta_num_replicas): - replica_state_dict = self._replicas[backend_tag] - list_to_use = replica_state_dict[ReplicaState.SHOULD_START] \ - or replica_state_dict[ReplicaState.STARTING] \ - or replica_state_dict[ReplicaState.RUNNING] - - assert len(list_to_use), replica_state_dict - replica_to_stop = list_to_use.pop() + replica_tag, _ = replicas_copy[backend_tag].popitem() graceful_timeout_s = (backend_info.backend_config. experimental_graceful_shutdown_timeout_s) + if force_kill: + graceful_timeout_s = 0 + self.backend_replicas_to_stop[backend_tag].append(( + replica_tag, + graceful_timeout_s, + )) - replica_to_stop.set_should_stop(graceful_timeout_s) - self._replicas[backend_tag][ReplicaState.SHOULD_STOP].append( - replica_to_stop) + def _start_pending_replicas(self): + for backend_tag, replicas_to_create in self.backend_replicas_to_start.\ + items(): + for replica_tag in replicas_to_create: + replica_handle = self._start_backend_replica( + backend_tag, replica_tag) + ready_future = replica_handle.ready.remote().as_future() + self.currently_starting_replicas[ready_future] = ( + backend_tag, replica_tag, replica_handle) - return True + def _stop_pending_replicas(self): + for backend_tag, replicas_to_stop in ( + self.backend_replicas_to_stop.items()): + for replica_tag, shutdown_timeout in replicas_to_stop: + replica_name = format_actor_name(replica_tag, + self._controller_name) - def scale_all_backends(self): - checkpoint_needed = False - for backend_tag, num_replicas in list(self._target_replicas.items()): - checkpoint_needed = (checkpoint_needed - or self.scale_backend_replicas(backend_tag)) - if num_replicas == 0: - del self._backend_metadata[backend_tag] - del self._target_replicas[backend_tag] + async def kill_actor(replica_name_to_use): + # NOTE: the replicas may already be stopped if we failed + # after stopping them but before writing a checkpoint. + try: + replica = ray.get_actor(replica_name_to_use) + except ValueError: + return - if checkpoint_needed: - self._checkpoint() + try: + await asyncio.wait_for( + replica.drain_pending_queries.remote(), + timeout=shutdown_timeout) + except asyncio.TimeoutError: + # Graceful period passed, kill it forcefully. + logger.debug( + f"{replica_name_to_use} did not shutdown after " + f"{shutdown_timeout}s, killing.") + finally: + ray.kill(replica, no_restart=True) - def _pop_replicas_of_state(self, state: ReplicaState - ) -> List[Tuple[ReplicaState, BackendTag]]: - replicas = [] - for backend_tag, state_to_replica_dict in self._replicas.items(): - if state in state_to_replica_dict: - replicas.extend( - (replica, backend_tag) - for replica in state_to_replica_dict.pop(state)) + self.currently_stopping_replicas[asyncio.ensure_future( + kill_actor(replica_name))] = (backend_tag, replica_tag) - return replicas + async def _check_currently_starting_replicas(self) -> int: + """Returns the number of pending replicas waiting to start""" + in_flight: Set[Future[Any]] = set() + + if self.currently_starting_replicas: + done, in_flight = await asyncio.wait( + list(self.currently_starting_replicas.keys()), timeout=0) + for fut in done: + (backend_tag, replica_tag, + replica_handle) = self.currently_starting_replicas.pop(fut) + self.backend_replicas[backend_tag][ + replica_tag] = replica_handle + + backend = self.backend_replicas_to_start.get(backend_tag) + if backend: + try: + backend.remove(replica_tag) + except ValueError: + pass + if len(backend) == 0: + del self.backend_replicas_to_start[backend_tag] + + async def _check_currently_stopping_replicas(self) -> int: + """Returns the number of replicas waiting to stop""" + in_flight: Set[Future[Any]] = set() + + if self.currently_stopping_replicas: + done_stopping, in_flight = await asyncio.wait( + list(self.currently_stopping_replicas.keys()), timeout=0) + for fut in done_stopping: + (backend_tag, + replica_tag) = self.currently_stopping_replicas.pop(fut) + + backend_to_stop = self.backend_replicas_to_stop.get( + backend_tag) + + if backend_to_stop: + try: + backend_to_stop.remove(replica_tag) + except ValueError: + pass + if len(backend_to_stop) == 0: + del self.backend_replicas_to_stop[backend_tag] + + backend = self.backend_replicas.get(backend_tag) + if backend: + try: + del backend[replica_tag] + except KeyError: + pass + + if len(self.backend_replicas[backend_tag]) == 0: + del self.backend_replicas[backend_tag] def _completed_goals(self) -> List[GoalId]: completed_goals = [] - all_tags = set(self._replicas.keys()).union( - set(self._backend_metadata.keys())) + all_tags = set(self.backend_replicas.keys()).union( + set(self.backends.keys())) for backend_tag in all_tags: - desired_num_replicas = self._target_replicas.get(backend_tag) - state_dict = self._replicas.get(backend_tag, {}) - existing_info = state_dict.get(ReplicaState.RUNNING, []) - - # If we have pending ops, the current goal is *not* ready - if (state_dict.get(ReplicaState.SHOULD_START) - or state_dict.get(ReplicaState.STARTING) - or state_dict.get(ReplicaState.SHOULD_STOP) - or state_dict.get(ReplicaState.STOPPING)): - continue - - # TODO(ilr): FIX + desired_info = self.backends.get(backend_tag) + existing_info = self.backend_replicas.get(backend_tag) # Check for deleting - if (not desired_num_replicas or - desired_num_replicas == 0) and \ + if (not desired_info or + desired_info.backend_config.num_replicas == 0) and \ (not existing_info or len(existing_info) == 0): - completed_goals.append( - self.backend_goals.pop(backend_tag, None)) + completed_goals.append(self.backend_goals.get(backend_tag)) # Check for a non-zero number of backends - if (desired_num_replicas and existing_info) \ - and desired_num_replicas == len(existing_info): - completed_goals.append( - self.backend_goals.pop(backend_tag, None)) + if desired_info and existing_info and desired_info.backend_config.\ + num_replicas == len(existing_info): + completed_goals.append(self.backend_goals.get(backend_tag)) return [goal for goal in completed_goals if goal] async def update(self) -> bool: - self.scale_all_backends() - for goal_id in self._completed_goals(): self._goal_manager.complete_goal(goal_id) - for replica_state, backend_tag in self._pop_replicas_of_state( - ReplicaState.SHOULD_START): - replica_state.start(self._backend_metadata[backend_tag]) - self._replicas[backend_tag][ReplicaState.STARTING].append( - replica_state) + self._start_pending_replicas() + self._stop_pending_replicas() - for replica_state, backend_tag in self._pop_replicas_of_state( - ReplicaState.SHOULD_STOP): - replica_state.stop() - self._replicas[backend_tag][ReplicaState.STOPPING].append( - replica_state) + num_starting = len(self.currently_starting_replicas) + num_stopping = len(self.currently_stopping_replicas) - transition_triggered = False + await self._check_currently_starting_replicas() + await self._check_currently_stopping_replicas() - for replica_state, backend_tag in self._pop_replicas_of_state( - ReplicaState.STARTING): - if replica_state.check_started(): - self._replicas[backend_tag][ReplicaState.RUNNING].append( - replica_state) - transition_triggered = True - else: - self._replicas[backend_tag][ReplicaState.STARTING].append( - replica_state) - - for replica_state, backend_tag in self._pop_replicas_of_state( - ReplicaState.STOPPING): - if replica_state.check_stopped(): - transition_triggered = True - else: - self._replicas[backend_tag][ReplicaState.STOPPING].append( - replica_state) - - for backend_tag in list(self._replicas.keys()): - if not any(self._replicas[backend_tag]): - del self._replicas[backend_tag] - del self._backend_metadata[backend_tag] - del self._target_replicas[backend_tag] - - if transition_triggered: + if (len(self.currently_starting_replicas) != num_starting) or \ + (len(self.currently_stopping_replicas) != num_stopping): self._checkpoint() self._notify_replica_handles_changed() diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index 41a1eca08..205af81b0 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Any, Dict, List, Optional import pydantic -from pydantic import BaseModel, confloat, PositiveFloat, PositiveInt, validator +from pydantic import BaseModel, PositiveFloat, PositiveInt, validator from ray.serve.constants import (ASYNC_CONCURRENCY, DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT) @@ -64,7 +64,7 @@ class BackendConfig(BaseModel): user_config: Any = None experimental_graceful_shutdown_wait_loop_s: PositiveFloat = 2.0 - experimental_graceful_shutdown_timeout_s: confloat(ge=0) = 20.0 + experimental_graceful_shutdown_timeout_s: PositiveFloat = 20.0 class Config: validate_assignment = True diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index b5c65111a..a3c75c711 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -118,7 +118,7 @@ class ServeController: def _all_replica_handles( self) -> Dict[BackendTag, Dict[ReplicaTag, ActorHandle]]: """Used for testing.""" - return self.backend_state.get_running_replica_handles() + return self.backend_state.get_replica_handles() def get_all_backends(self) -> Dict[BackendTag, BackendConfig]: """Returns a dictionary of backend tag to backend config.""" @@ -235,7 +235,7 @@ class ServeController: async with self.write_lock: for proxy in self.http_state.get_http_proxy_handles().values(): ray.kill(proxy, no_restart=True) - for replica_dict in self.backend_state.get_running_replica_handles( + for replica_dict in self.backend_state.get_replica_handles( ).values(): for replica in replica_dict.values(): ray.kill(replica, no_restart=True) diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index a35f7e54b..202b01386 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -683,9 +683,6 @@ def test_endpoint_input_validation(serve_instance): client.create_endpoint("endpoint", backend="backend") -# This error is only printed because creation is run in the control loop, not -# in the API path. -@pytest.mark.skip() def test_create_infeasible_error(serve_instance): client = serve_instance