diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD index 0daf7ade6..be8707d86 100644 --- a/python/ray/serve/BUILD +++ b/python/ray/serve/BUILD @@ -43,14 +43,13 @@ py_test( ) -# TODO(simon): Test skipped until #11683 fixed. -# py_test( -# name = "test_failure", -# size = "medium", -# srcs = serve_tests_srcs, -# tags = ["exclusive"], -# deps = [":serve_lib"], -# ) +py_test( + name = "test_failure", + size = "medium", + srcs = serve_tests_srcs, + tags = ["exclusive"], + deps = [":serve_lib"], +) py_test( diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py index 7330a722d..5237d453c 100644 --- a/python/ray/serve/controller.py +++ b/python/ray/serve/controller.py @@ -504,25 +504,11 @@ class ServeController: self.inflight_results: Dict[UUID, asyncio.Event] = dict() self._serializable_inflight_results: Dict[UUID, FutureResult] = dict() - # NOTE(edoakes): unfortunately, we can't completely recover from a - # checkpoint in the constructor because we block while waiting for - # other actors to start up, and those actors fetch soft state from - # this actor. Because no other tasks will start executing until after - # the constructor finishes, if we were to run this logic in the - # constructor it could lead to deadlock between this actor and a child. - # However we do need to guarantee that we have fully recovered from a - # checkpoint before any other state-changing calls run. We address this - # by acquiring the write_lock and then posting the task to recover from - # a checkpoint to the event loop. Other state-changing calls acquire - # this lock and will be blocked until recovering from the checkpoint - # finishes. checkpoint = self.kv_store.get(CHECKPOINT_KEY) if checkpoint is None: logger.debug("No checkpoint found") else: - await self.write_lock.acquire() - asyncio.get_event_loop().create_task( - self._recover_from_checkpoint(checkpoint)) + await self._recover_from_checkpoint(checkpoint) # NOTE(simon): Currently we do all-to-all broadcast. This means # any listeners will receive notification for all changes. This @@ -530,6 +516,10 @@ class ServeController: # will send over the entire configs. In the future, we should # optimize the logic to support subscription by key. self.long_poll_host = LongPollHost() + + # The configs pushed out here get updated by + # self._recover_from_checkpoint in the failure scenario, so that must + # be run before we notify the changes. self.notify_backend_configs_changed() self.notify_replica_handles_changed() self.notify_traffic_policies_changed() @@ -625,40 +615,51 @@ class ServeController: async def _recover_from_checkpoint(self, checkpoint_bytes: bytes) -> None: """Recover the instance state from the provided checkpoint. + This should be called in the constructor to ensure that the internal + state is updated before any other operations run. After running this, + internal state will be updated and long-poll clients may be notified. + Performs the following operations: 1) Deserializes the internal state from the checkpoint. - 2) Pushes the latest configuration to the HTTP proxies - in case we crashed before updating them. - 3) Starts/stops any replicas that are pending creation or + 2) Starts/stops any replicas that are pending creation or deletion. - - NOTE: this requires that self.write_lock is already acquired and will - release it before returning. """ - assert self.write_lock.locked() - start = time.time() logger.info("Recovering from checkpoint") restored_checkpoint: Checkpoint = pickle.loads(checkpoint_bytes) - # Restore SystemState self.current_state = restored_checkpoint.current_state - # Restore ActorStateReconciler self.actor_reconciler = restored_checkpoint.reconciler - # Recreate self.inflight_requests! self._serializable_inflight_results = restored_checkpoint.inflight_reqs for uuid, fut_result in self._serializable_inflight_results.items(): self._create_event_with_result(fut_result.requested_goal, uuid) - self.autoscaling_policies = await self.actor_reconciler.\ - _recover_from_checkpoint(self.current_state, self) + # NOTE(edoakes): unfortunately, we can't completely recover from a + # checkpoint in the constructor because we block while waiting for + # other actors to start up, and those actors fetch soft state from + # this actor. Because no other tasks will start executing until after + # the constructor finishes, if we were to run this logic in the + # constructor it could lead to deadlock between this actor and a child. + # However, we do need to guarantee that we have fully recovered from a + # checkpoint before any other state-changing calls run. We address this + # by acquiring the write_lock and then posting the task to recover from + # a checkpoint to the event loop. Other state-changing calls acquire + # this lock and will be blocked until recovering from the checkpoint + # finishes. This can be removed once we move to the async control loop. - logger.info( - "Recovered from checkpoint in {:.3f}s".format(time.time() - start)) + async def finish_recover_from_checkpoint(): + assert self.write_lock.locked() + self.autoscaling_policies = await self.actor_reconciler.\ + _recover_from_checkpoint(self.current_state, self) + self.write_lock.release() + logger.info( + "Recovered from checkpoint in {:.3f}s".format(time.time() - + start)) - self.write_lock.release() + await self.write_lock.acquire() + asyncio.get_event_loop().create_task(finish_recover_from_checkpoint()) async def do_autoscale(self) -> None: for backend, info in self.current_state.backends.items(): diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py index 00c59ceae..3215f3578 100644 --- a/python/ray/serve/http_proxy.py +++ b/python/ray/serve/http_proxy.py @@ -27,6 +27,7 @@ class HTTPProxy: def __init__(self, controller_name): controller = ray.get_actor(controller_name) + self.route_table = {} # Should be updated via long polling. self.router = Router(controller) self.long_poll_client = LongPollAsyncClient(controller, { LongPollKey.ROUTE_TABLE: self._update_route_table, @@ -41,6 +42,7 @@ class HTTPProxy: await self.router.setup_in_async_loop() async def _update_route_table(self, route_table): + logger.debug(f"HTTP Proxy: Get updated route table: {route_table}.") self.route_table = route_table async def receive_http_body(self, scope, receive, send): diff --git a/python/ray/serve/tests/test_failure.py b/python/ray/serve/tests/test_failure.py index 6312e56f2..7e2fbcb65 100644 --- a/python/ray/serve/tests/test_failure.py +++ b/python/ray/serve/tests/test_failure.py @@ -4,6 +4,7 @@ import tempfile import time import ray +from ray.test_utils import wait_for_condition from ray import serve from ray.serve.config import BackendConfig, ReplicaConfig @@ -53,9 +54,11 @@ def test_controller_failure(serve_instance): client.create_backend("controller_failure:v2", function) client.set_traffic("controller_failure", {"controller_failure:v2": 1.0}) - for _ in range(10): + def check_controller_failure(): response = request_with_retries("/controller_failure", timeout=30) - assert response.text == "hello2" + return response.text == "hello2" + + wait_for_condition(check_controller_failure) def function(_): return "hello3" @@ -124,7 +127,7 @@ def test_worker_restart(serve_instance): client = serve_instance class Worker1: - def __call__(self): + def __call__(self, *args): return os.getpid() client.create_backend("worker_failure:v1", Worker1) @@ -176,7 +179,7 @@ def test_worker_replica_failure(serve_instance): while True: pass - def __call__(self): + def __call__(self, *args): pass temp_path = os.path.join(tempfile.gettempdir(),