[Serve] Start Replicas in Parallel (#8433)

This commit is contained in:
Simon Mo
2020-05-20 19:46:03 -07:00
committed by GitHub
parent a76434ccde
commit ed2f434593
2 changed files with 62 additions and 13 deletions
+22 -13
View File
@@ -323,6 +323,23 @@ class ServeMaster:
await worker_handle.ready.remote()
return worker_handle
async def _start_replica(self, backend_tag, replica_tag):
    """Start (or re-attach to) one replica and register it with the router.

    Looks up an existing actor named ``replica_tag`` first; only if that
    lookup raises ``ValueError`` is a new backend worker started. The
    replica is then recorded in ``self.replicas`` and ``self.workers``
    before being handed to the router.

    Args:
        backend_tag: Tag of the backend the replica belongs to.
        replica_tag: Name of the replica actor.
    """
    # NOTE(edoakes): the replicas may already be created if we
    # failed after creating them but before writing a
    # checkpoint.
    try:
        replica_handle = ray.util.get_actor(replica_tag)
    except ValueError:
        replica_handle = await self._start_backend_worker(
            backend_tag, replica_tag)

    # Book-keeping for this backend: known replica tags and their handles.
    self.replicas[backend_tag].append(replica_tag)
    self.workers[backend_tag][replica_tag] = replica_handle

    # Register the worker with the router.
    await self.router.add_new_worker.remote(
        backend_tag, replica_tag, replica_handle)
async def _start_pending_replicas(self):
"""Starts the pending backend replicas in self.replicas_to_start.
@@ -332,22 +349,14 @@ class ServeMaster:
Clears self.replicas_to_start.
"""
replica_started_futures = []
for backend_tag, replicas_to_create in self.replicas_to_start.items():
for replica_tag in replicas_to_create:
# NOTE(edoakes): the replicas may already be created if we
# failed after creating them but before writing a checkpoint.
try:
worker_handle = ray.util.get_actor(replica_tag)
except ValueError:
worker_handle = await self._start_backend_worker(
backend_tag, replica_tag)
replica_started_futures.append(
self._start_replica(backend_tag, replica_tag))
self.replicas[backend_tag].append(replica_tag)
self.workers[backend_tag][replica_tag] = worker_handle
# Register the worker with the router.
await self.router.add_new_worker.remote(
backend_tag, replica_tag, worker_handle)
# Wait on all creation task futures together.
await asyncio.gather(*replica_started_futures)
self.replicas_to_start.clear()
+40
View File
@@ -1,4 +1,6 @@
import time
import asyncio
import pytest
import requests
@@ -386,3 +388,41 @@ def test_cluster_name():
serve.init(cluster_name="cluster1")
serve.delete_endpoint(endpoint)
serve.delete_backend(backend)
def test_parallel_start(serve_instance):
    """Check that Serve can start several replicas of a backend in parallel.

    In the past, when Serve scaled up a backend it started replicas one by
    one and waited for each to initialize. This test guards against that:
    every replica's constructor blocks on a shared two-party barrier, so the
    first replica cannot finish initializing unless the second one has also
    been started.
    """

    @ray.remote
    class Barrier:
        """Actor that holds callers until ``release_on`` have arrived."""

        def __init__(self, release_on):
            self.release_on = release_on
            self.current_waiters = 0
            self.event = asyncio.Event()

        async def wait(self):
            self.current_waiters += 1
            if self.current_waiters == self.release_on:
                # Last expected arrival: release everyone already waiting.
                self.event.set()
                return
            await self.event.wait()

    barrier = Barrier.remote(release_on=2)

    class LongStartingServable:
        def __init__(self):
            # Initialization completes only once both replicas have reached
            # the barrier (bounded by the 10 second timeout on the call).
            ray.get(barrier.wait.remote(), timeout=10)

        def __call__(self, _):
            return "Ready"

    endpoint_name = "test-parallel"
    backend_name = "p:v0"

    serve.create_endpoint(endpoint_name)
    serve.create_backend(
        backend_name, LongStartingServable, config={"num_replicas": 2})
    serve.set_traffic(endpoint_name, {backend_name: 1})

    # A request through the endpoint must complete within the timeout.
    endpoint_handle = serve.get_handle(endpoint_name)
    ray.get(endpoint_handle.remote(), timeout=10)