diff --git a/python/ray/serve/backend_state.py b/python/ray/serve/backend_state.py index 9df9f8d03..40005cc16 100644 --- a/python/ray/serve/backend_state.py +++ b/python/ray/serve/backend_state.py @@ -122,7 +122,7 @@ class BackendState: def num_pending_goals(self) -> int: return len(self.pending_goals) - async def wait_for_goal(self, goal_id: GoalId) -> None: + async def wait_for_goal(self, goal_id: GoalId) -> bool: start = time.time() if goal_id not in self.pending_goals: logger.debug(f"Goal {goal_id} not found") @@ -131,6 +131,7 @@ class BackendState: await event.wait() logger.debug( f"Waiting for goal {goal_id} took {time.time() - start} seconds") + return True def _complete_goal(self, goal_id: GoalId) -> None: logger.debug(f"Completing goal {goal_id}") @@ -455,13 +456,13 @@ class BackendState: if (not desired_info or desired_info.backend_config.num_replicas == 0) and \ (not existing_info or len(existing_info) == 0): - completed_goals.append(self.goals[backend_tag]) + completed_goals.append(self.goals.get(backend_tag)) # Check for a non-zero number of backends if desired_info and existing_info and desired_info.backend_config.\ num_replicas == len(existing_info): - completed_goals.append(self.goals[backend_tag]) - return completed_goals + completed_goals.append(self.goals.get(backend_tag)) + return [goal for goal in completed_goals if goal] async def update(self) -> bool: for goal_id in self._completed_goals(): diff --git a/python/ray/serve/tests/test_backend_state.py b/python/ray/serve/tests/test_backend_state.py new file mode 100644 index 000000000..3f4114954 --- /dev/null +++ b/python/ray/serve/tests/test_backend_state.py @@ -0,0 +1,115 @@ +import asyncio +import pytest +from typing import Optional, Tuple +from unittest.mock import patch, Mock +from uuid import uuid4 + +from ray.serve.common import BackendConfig, BackendInfo, ReplicaConfig +from ray.serve.backend_state import BackendState + + +def generate_mock_backend_info( + num_replicas: Optional[int] = None) -> BackendInfo: + backend_info = BackendInfo( + worker_class=lambda x: x, + backend_config=BackendConfig(), + replica_config=ReplicaConfig(lambda x: x)) + if num_replicas: + backend_info.backend_config.num_replicas = num_replicas + + return backend_info + + +@pytest.fixture +def mock_backend_state_inputs() -> Tuple[BackendState, Mock, Mock]: + with patch( + "ray.serve.kv_store.RayInternalKVStore") as mock_kv_store, patch( + "ray.serve.long_poll.LongPollHost") as mock_long_poll: + mock_kv_store.get = Mock(return_value=None) + backend_state = BackendState("name", True, mock_kv_store, + mock_long_poll) + yield backend_state, mock_kv_store, mock_long_poll + + +@pytest.mark.asyncio +async def test_wait_for_goals(mock_backend_state_inputs): + backend_state = mock_backend_state_inputs[0] + + # Check empty goal + assert await backend_state.wait_for_goal(None) + + goal_id = backend_state._create_goal() + waiting = asyncio.create_task(backend_state.wait_for_goal(goal_id)) + + assert not waiting.done(), "Unfinished task should not be done" + backend_state._complete_goal(goal_id) + + assert await waiting + + # Test double waiting is okay + assert await backend_state.wait_for_goal(goal_id) + + +@pytest.mark.asyncio +async def test_recreate_and_wait_for_goals(): + pass + # Ensure that futures are recreated on startup + + +def test_completed_goals_deleted_backend(mock_backend_state_inputs): + backend_state = mock_backend_state_inputs[0] + b1 = "backend_one" + backend_state.backends[b1] = None + backend_state.backend_replicas[b1] = {} + result_uuid_b1 = uuid4() + backend_state.goals[b1] = result_uuid_b1 + + assert backend_state._completed_goals() == [result_uuid_b1] + + backend_state.goals = {} + + b2 = "backend_two" + backend_state.backends[b2] = None + result_uuid_b2 = uuid4() + backend_state.goals[b2] = result_uuid_b2 + + assert backend_state._completed_goals() == [result_uuid_b2] + + +def test_completed_goals_delta_backend(mock_backend_state_inputs): + backend_state = mock_backend_state_inputs[0] + b1 = "backend_one" + backend_state.backends[b1] = generate_mock_backend_info() + backend_state.backend_replicas[b1] = {i: i for i in range(1)} + # NOTE(ilr): This test made it clear that the _completed_goals function + # should (.get) from the dict. + assert len(backend_state._completed_goals()) == 0 + + backend_state.backends[b1] = generate_mock_backend_info(30) + result_uuid = uuid4() + backend_state.goals[b1] = result_uuid + assert len(backend_state._completed_goals()) == 0 + + backend_state.backend_replicas[b1] = {i: i for i in range(30)} + assert backend_state._completed_goals() == [result_uuid] + + +def test_completed_goals_created_backend(mock_backend_state_inputs): + backend_state = mock_backend_state_inputs[0] + assert len(backend_state._completed_goals()) == 0 + + b1 = "backend_one" + backend_state.backends[b1] = generate_mock_backend_info() + result_uuid = uuid4() + backend_state.goals[b1] = result_uuid + + assert len(backend_state._completed_goals()) == 0 + + backend_state.backend_replicas[b1] = {i: i for i in range(1)} + + assert backend_state._completed_goals() == [result_uuid] + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main(["-v", "-s", __file__]))