class AsyncGoalManager:
    """Track asynchronous goals and let coroutines wait for their completion.

    Every pending goal is keyed by its ``GoalId`` and backed by an
    ``asyncio.Event``. Completing a goal sets its event and removes the goal
    from the pending set, so waiting on an unknown or already-completed goal
    returns immediately.
    """

    def __init__(self):
        # Goals that were created but not completed yet, mapped to the event
        # their waiters block on.
        self._pending_goals: Dict[GoalId, asyncio.Event] = {}

    def get_pending_goal_ids(self) -> List[GoalId]:
        """Return a new list with the ids of all not-yet-completed goals."""
        return list(self._pending_goals)

    def num_pending_goals(self) -> int:
        """Return how many goals are still pending."""
        return len(self._pending_goals)

    def create_goal(self, goal_id: Optional[GoalId] = None) -> GoalId:
        """Register a pending goal and return its id.

        Args:
            goal_id: Id to register. When ``None``, a fresh id is generated
                with ``uuid4()``.

        Returns:
            The id of the registered goal.
        """
        if goal_id is None:
            goal_id = uuid4()
        # Keep the existing event when the goal is already pending: replacing
        # it would orphan coroutines that are awaiting the old event, and
        # they would then never be woken by complete_goal().
        if goal_id not in self._pending_goals:
            self._pending_goals[goal_id] = asyncio.Event()
        return goal_id

    def complete_goal(self, goal_id: GoalId) -> None:
        """Mark a goal as completed and wake up everything waiting on it.

        Completing an unknown or already-completed goal is a no-op.
        """
        logger.debug(f"Completing goal {goal_id}")
        event = self._pending_goals.pop(goal_id, None)
        if event is not None:
            event.set()

    async def wait_for_goal(self, goal_id: GoalId) -> None:
        """Block until the given goal is completed.

        Returns immediately when ``goal_id`` is not pending (never created or
        already completed).
        """
        # Monotonic clock: the elapsed time must not be affected by
        # wall-clock adjustments.
        start = time.monotonic()
        event = self._pending_goals.get(goal_id)
        if event is None:
            logger.debug(f"Goal {goal_id} not found")
            return

        await event.wait()
        elapsed = time.monotonic() - start
        logger.debug(f"Waiting for goal {goal_id} took {elapsed} seconds")
class BackendState: """ def __init__(self, controller_name: str, detached: bool, - kv_store: RayInternalKVStore, long_poll_host: LongPollHost): + kv_store: RayInternalKVStore, long_poll_host: LongPollHost, + goal_manager: AsyncGoalManager): self._controller_name = controller_name self._detached = detached self._kv_store = kv_store self._long_poll_host = long_poll_host + self._goal_manager = goal_manager # Non-checkpointed state. self.currently_starting_replicas: Dict[asyncio.Future, Tuple[ @@ -55,23 +56,22 @@ class BackendState: self.backends: Dict[BackendTag, BackendInfo] = dict() self.backend_replicas: Dict[BackendTag, Dict[ ReplicaTag, ActorHandle]] = defaultdict(dict) - self.goals: Dict[BackendTag, GoalId] = dict() + self.backend_goals: Dict[BackendTag, GoalId] = dict() self.backend_replicas_to_start: Dict[BackendTag, List[ ReplicaTag]] = defaultdict(list) self.backend_replicas_to_stop: Dict[BackendTag, List[Tuple[ ReplicaTag, Duration]]] = defaultdict(list) self.backends_to_remove: List[BackendTag] = list() - self.pending_goals: Dict[GoalId, asyncio.Event] = dict() checkpoint = self._kv_store.get(CHECKPOINT_KEY) if checkpoint is not None: - (self.backends, self.backend_replicas, self.goals, + (self.backends, self.backend_replicas, self.backend_goals, self.backend_replicas_to_start, self.backend_replicas_to_stop, self.backend_to_remove, pending_goal_ids) = pickle.loads(checkpoint) for goal_id in pending_goal_ids: - self._create_goal(goal_id) + self._goal_manager.create_goal(goal_id) # Fetch actor handles for all backend replicas in the system. 
# All of these backend_replicas are guaranteed to already exist @@ -91,9 +91,10 @@ class BackendState: self._kv_store.put( CHECKPOINT_KEY, pickle.dumps( - (self.backends, self.backend_replicas, self.goals, + (self.backends, self.backend_replicas, self.backend_goals, self.backend_replicas_to_start, self.backend_replicas_to_stop, - self.backends_to_remove, list(self.pending_goals.keys())))) + self.backends_to_remove, + self._goal_manager.get_pending_goal_ids()))) def _notify_backend_configs_changed(self) -> None: self._long_poll_host.notify_changed(LongPollKey.BACKEND_CONFIGS, @@ -119,44 +120,17 @@ class BackendState: def get_backend(self, backend_tag: BackendTag) -> Optional[BackendInfo]: return self.backends.get(backend_tag) - def num_pending_goals(self) -> int: - return len(self.pending_goals) - - async def wait_for_goal(self, goal_id: GoalId) -> bool: - start = time.time() - if goal_id not in self.pending_goals: - logger.debug(f"Goal {goal_id} not found") - return True - event = self.pending_goals[goal_id] - await event.wait() - logger.debug( - f"Waiting for goal {goal_id} took {time.time() - start} seconds") - return True - - def _complete_goal(self, goal_id: GoalId) -> None: - logger.debug(f"Completing goal {goal_id}") - event = self.pending_goals.pop(goal_id, None) - if event: - event.set() - - def _create_goal(self, goal_id: Optional[GoalId] = None) -> GoalId: - if goal_id is None: - goal_id = uuid4() - event = asyncio.Event() - self.pending_goals[goal_id] = event - return goal_id - def _set_backend_goal(self, backend_tag: BackendTag, backend_info: BackendInfo) -> None: - existing_goal = self.goals.get(backend_tag) - new_goal = self._create_goal() + existing_goal_id = self.backend_goals.get(backend_tag) + new_goal_id = self._goal_manager.create_goal() if backend_info is not None: self.backends[backend_tag] = backend_info - self.goals[backend_tag] = new_goal + self.backend_goals[backend_tag] = new_goal_id - return new_goal, existing_goal + return new_goal_id, 
existing_goal_id def create_backend(self, backend_tag: BackendTag, backend_config: BackendConfig, @@ -177,7 +151,7 @@ class BackendState: backend_config=backend_config, replica_config=replica_config) - new_goal, existing_goal = self._set_backend_goal( + new_goal_id, existing_goal_id = self._set_backend_goal( backend_tag, backend_info) try: @@ -193,9 +167,9 @@ class BackendState: self._checkpoint() self._notify_backend_configs_changed() - if existing_goal is not None: - self._complete_goal(existing_goal) - return new_goal + if existing_goal_id is not None: + self._goal_manager.complete_goal(existing_goal_id) + return new_goal_id def delete_backend(self, backend_tag: BackendTag, force_kill: bool = False) -> Optional[GoalId]: @@ -213,12 +187,13 @@ class BackendState: # Add the intention to remove the backend from the routers. self.backends_to_remove.append(backend_tag) - new_goal, existing_goal = self._set_backend_goal(backend_tag, None) + new_goal_id, existing_goal_id = self._set_backend_goal( + backend_tag, None) self._checkpoint() - if existing_goal is not None: - self._complete_goal(existing_goal) - return new_goal + if existing_goal_id is not None: + self._goal_manager.complete_goal(existing_goal_id) + return new_goal_id def update_backend_config(self, backend_tag: BackendTag, config_options: BackendConfig): @@ -231,7 +206,7 @@ class BackendState: updated_config._validate_complete() self.backends[backend_tag].backend_config = updated_config - new_goal, existing_goal = self._set_backend_goal( + new_goal_id, existing_goal_id = self._set_backend_goal( backend_tag, self.backends[backend_tag]) # Scale the replicas with the new configuration. @@ -241,15 +216,15 @@ class BackendState: # update to avoid inconsistent state if we crash after pushing the # update. 
self._checkpoint() - if existing_goal is not None: - self._complete_goal(existing_goal) + if existing_goal_id is not None: + self._goal_manager.complete_goal(existing_goal_id) # Inform the routers and backend replicas about config changes. # TODO(edoakes): this should only happen if we change something other # than num_replicas. self._notify_backend_configs_changed() - return new_goal + return new_goal_id def _start_backend_replica(self, backend_tag: BackendTag, replica_tag: ReplicaTag) -> ActorHandle: @@ -456,17 +431,17 @@ class BackendState: if (not desired_info or desired_info.backend_config.num_replicas == 0) and \ (not existing_info or len(existing_info) == 0): - completed_goals.append(self.goals.get(backend_tag)) + completed_goals.append(self.backend_goals.get(backend_tag)) # Check for a non-zero number of backends if desired_info and existing_info and desired_info.backend_config.\ num_replicas == len(existing_info): - completed_goals.append(self.goals.get(backend_tag)) + completed_goals.append(self.backend_goals.get(backend_tag)) return [goal for goal in completed_goals if goal] async def update(self) -> bool: for goal_id in self._completed_goals(): - self._complete_goal(goal_id) + self._goal_manager.complete_goal(goal_id) self._start_pending_replicas() self._stop_pending_replicas() diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index e6d0aa82a..a3c75c711 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -4,6 +4,7 @@ from typing import Dict, Any, List, Optional import ray from ray.actor import ActorHandle +from ray.serve.async_goal_manager import AsyncGoalManager from ray.serve.backend_state import BackendState from ray.serve.common import ( BackendTag, @@ -76,18 +77,20 @@ class ServeController: # optimize the logic to support subscription by key. 
self.long_poll_host = LongPollHost() + self.goal_manager = AsyncGoalManager() self.http_state = HTTPState(controller_name, detached, http_config) self.endpoint_state = EndpointState(self.kv_store, self.long_poll_host) self.backend_state = BackendState(controller_name, detached, - self.kv_store, self.long_poll_host) + self.kv_store, self.long_poll_host, + self.goal_manager) asyncio.get_event_loop().create_task(self.run_control_loop()) async def wait_for_goal(self, goal_id: GoalId) -> None: - await self.backend_state.wait_for_goal(goal_id) + await self.goal_manager.wait_for_goal(goal_id) async def _num_pending_goals(self) -> int: - return self.backend_state.num_pending_goals() + return self.goal_manager.num_pending_goals() async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]): """Proxy long pull client's listen request. diff --git a/python/ray/serve/tests/test_async_goal_manager.py b/python/ray/serve/tests/test_async_goal_manager.py new file mode 100644 index 000000000..770bab888 --- /dev/null +++ b/python/ray/serve/tests/test_async_goal_manager.py @@ -0,0 +1,28 @@ +import asyncio +import pytest + +from ray.serve.async_goal_manager import AsyncGoalManager + + +@pytest.mark.asyncio +async def test_wait_for_goals(): + manager = AsyncGoalManager() + + # Check empty goal + await manager.wait_for_goal(None) + + goal_id = manager.create_goal() + waiting = asyncio.create_task(manager.wait_for_goal(goal_id)) + + assert not waiting.done(), "Unfinished task should not be done" + manager.complete_goal(goal_id) + + await waiting + + # Test double waiting is okay + await manager.wait_for_goal(goal_id) + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_backend_state.py b/python/ray/serve/tests/test_backend_state.py index 3f4114954..3030b63fd 100644 --- a/python/ray/serve/tests/test_backend_state.py +++ b/python/ray/serve/tests/test_backend_state.py @@ -1,4 +1,3 @@ -import asyncio 
import pytest from typing import Optional, Tuple from unittest.mock import patch, Mock @@ -24,36 +23,13 @@ def generate_mock_backend_info( def mock_backend_state_inputs() -> Tuple[BackendState, Mock, Mock]: with patch( "ray.serve.kv_store.RayInternalKVStore") as mock_kv_store, patch( - "ray.serve.long_poll.LongPollHost") as mock_long_poll: + "ray.serve.long_poll.LongPollHost") as mock_long_poll, patch( + "ray.serve.async_goal_manager.AsyncGoalManager" + ) as mock_goal_manager: mock_kv_store.get = Mock(return_value=None) backend_state = BackendState("name", True, mock_kv_store, - mock_long_poll) - yield backend_state, mock_kv_store, mock_long_poll - - -@pytest.mark.asyncio -async def test_wait_for_goals(mock_backend_state_inputs): - backend_state = mock_backend_state_inputs[0] - - # Check empty goal - assert await backend_state.wait_for_goal(None) - - goal_id = backend_state._create_goal() - waiting = asyncio.create_task(backend_state.wait_for_goal(goal_id)) - - assert not waiting.done(), "Unfinished task should not be done" - backend_state._complete_goal(goal_id) - - assert await waiting - - # Test double waiting is okay - assert await backend_state.wait_for_goal(goal_id) - - -@pytest.mark.asyncio -async def test_recreate_and_wait_for_goals(): - pass - # Ensure that futures are recreated on startup + mock_long_poll, mock_goal_manager) + yield backend_state, mock_kv_store, mock_long_poll, mock_goal_manager def test_completed_goals_deleted_backend(mock_backend_state_inputs): @@ -62,16 +38,16 @@ def test_completed_goals_deleted_backend(mock_backend_state_inputs): backend_state.backends[b1] = None backend_state.backend_replicas[b1] = {} result_uuid_b1 = uuid4() - backend_state.goals[b1] = result_uuid_b1 + backend_state.backend_goals[b1] = result_uuid_b1 assert backend_state._completed_goals() == [result_uuid_b1] - backend_state.goals = {} + backend_state.backend_goals = {} b2 = "backend_two" backend_state.backends[b2] = None result_uuid_b2 = uuid4() - 
backend_state.goals[b2] = result_uuid_b2 + backend_state.backend_goals[b2] = result_uuid_b2 assert backend_state._completed_goals() == [result_uuid_b2] @@ -87,7 +63,7 @@ def test_completed_goals_delta_backend(mock_backend_state_inputs): backend_state.backends[b1] = generate_mock_backend_info(30) result_uuid = uuid4() - backend_state.goals[b1] = result_uuid + backend_state.backend_goals[b1] = result_uuid assert len(backend_state._completed_goals()) == 0 backend_state.backend_replicas[b1] = {i: i for i in range(30)} @@ -101,7 +77,7 @@ def test_completed_goals_created_backend(mock_backend_state_inputs): b1 = "backend_one" backend_state.backends[b1] = generate_mock_backend_info() result_uuid = uuid4() - backend_state.goals[b1] = result_uuid + backend_state.backend_goals[b1] = result_uuid assert len(backend_state._completed_goals()) == 0