diff --git a/python/ray/serve/master.py b/python/ray/serve/master.py index cc3472949..2507e48d8 100644 --- a/python/ray/serve/master.py +++ b/python/ray/serve/master.py @@ -323,6 +323,23 @@ class ServeMaster: await worker_handle.ready.remote() return worker_handle + async def _start_replica(self, backend_tag, replica_tag): + # NOTE(edoakes): the replicas may already be created if we + # failed after creating them but before writing a + # checkpoint. + try: + worker_handle = ray.util.get_actor(replica_tag) + except ValueError: + worker_handle = await self._start_backend_worker( + backend_tag, replica_tag) + + self.replicas[backend_tag].append(replica_tag) + self.workers[backend_tag][replica_tag] = worker_handle + + # Register the worker with the router. + await self.router.add_new_worker.remote(backend_tag, replica_tag, + worker_handle) + async def _start_pending_replicas(self): """Starts the pending backend replicas in self.replicas_to_start. @@ -332,22 +349,14 @@ class ServeMaster: Clears self.replicas_to_start. """ + replica_started_futures = [] for backend_tag, replicas_to_create in self.replicas_to_start.items(): for replica_tag in replicas_to_create: - # NOTE(edoakes): the replicas may already be created if we - # failed after creating them but before writing a checkpoint. - try: - worker_handle = ray.util.get_actor(replica_tag) - except ValueError: - worker_handle = await self._start_backend_worker( - backend_tag, replica_tag) + replica_started_futures.append( + self._start_replica(backend_tag, replica_tag)) - self.replicas[backend_tag].append(replica_tag) - self.workers[backend_tag][replica_tag] = worker_handle - - # Register the worker with the router. - await self.router.add_new_worker.remote( - backend_tag, replica_tag, worker_handle) + # Wait on all creation task futures together. + await asyncio.gather(*replica_started_futures) self.replicas_to_start.clear() diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index 5c9d5397f..e53170591 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -1,4 +1,6 @@ import time +import asyncio + import pytest import requests @@ -386,3 +388,41 @@ def test_cluster_name(): serve.init(cluster_name="cluster1") serve.delete_endpoint(endpoint) serve.delete_backend(backend) + + +def test_parallel_start(serve_instance): + # Test the ability to start multiple replicas in parallel. + # In the past, when Serve scale up a backend, it does so one by one and + # wait for each replica to initialize. This test avoid this by preventing + # the first replica to finish initialization unless the second replica is + # also started. + @ray.remote + class Barrier: + def __init__(self, release_on): + self.release_on = release_on + self.current_waiters = 0 + self.event = asyncio.Event() + + async def wait(self): + self.current_waiters += 1 + if self.current_waiters == self.release_on: + self.event.set() + else: + await self.event.wait() + + barrier = Barrier.remote(release_on=2) + + class LongStartingServable: + def __init__(self): + ray.get(barrier.wait.remote(), timeout=10) + + def __call__(self, _): + return "Ready" + + serve.create_endpoint("test-parallel") + serve.create_backend( + "p:v0", LongStartingServable, config={"num_replicas": 2}) + serve.set_traffic("test-parallel", {"p:v0": 1}) + handle = serve.get_handle("test-parallel") + + ray.get(handle.remote(), timeout=10)