mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 19:17:04 +08:00
[Serve] Start Replicas in Parallel (#8433)
This commit is contained in:
+22
-13
@@ -323,6 +323,23 @@ class ServeMaster:
|
||||
await worker_handle.ready.remote()
|
||||
return worker_handle
|
||||
|
||||
async def _start_replica(self, backend_tag, replica_tag):
|
||||
# NOTE(edoakes): the replicas may already be created if we
|
||||
# failed after creating them but before writing a
|
||||
# checkpoint.
|
||||
try:
|
||||
worker_handle = ray.util.get_actor(replica_tag)
|
||||
except ValueError:
|
||||
worker_handle = await self._start_backend_worker(
|
||||
backend_tag, replica_tag)
|
||||
|
||||
self.replicas[backend_tag].append(replica_tag)
|
||||
self.workers[backend_tag][replica_tag] = worker_handle
|
||||
|
||||
# Register the worker with the router.
|
||||
await self.router.add_new_worker.remote(backend_tag, replica_tag,
|
||||
worker_handle)
|
||||
|
||||
async def _start_pending_replicas(self):
|
||||
"""Starts the pending backend replicas in self.replicas_to_start.
|
||||
|
||||
@@ -332,22 +349,14 @@ class ServeMaster:
|
||||
|
||||
Clears self.replicas_to_start.
|
||||
"""
|
||||
replica_started_futures = []
|
||||
for backend_tag, replicas_to_create in self.replicas_to_start.items():
|
||||
for replica_tag in replicas_to_create:
|
||||
# NOTE(edoakes): the replicas may already be created if we
|
||||
# failed after creating them but before writing a checkpoint.
|
||||
try:
|
||||
worker_handle = ray.util.get_actor(replica_tag)
|
||||
except ValueError:
|
||||
worker_handle = await self._start_backend_worker(
|
||||
backend_tag, replica_tag)
|
||||
replica_started_futures.append(
|
||||
self._start_replica(backend_tag, replica_tag))
|
||||
|
||||
self.replicas[backend_tag].append(replica_tag)
|
||||
self.workers[backend_tag][replica_tag] = worker_handle
|
||||
|
||||
# Register the worker with the router.
|
||||
await self.router.add_new_worker.remote(
|
||||
backend_tag, replica_tag, worker_handle)
|
||||
# Wait on all creation task futures together.
|
||||
await asyncio.gather(*replica_started_futures)
|
||||
|
||||
self.replicas_to_start.clear()
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
@@ -386,3 +388,41 @@ def test_cluster_name():
|
||||
serve.init(cluster_name="cluster1")
|
||||
serve.delete_endpoint(endpoint)
|
||||
serve.delete_backend(backend)
|
||||
|
||||
|
||||
def test_parallel_start(serve_instance):
|
||||
# Test the ability to start multiple replicas in parallel.
|
||||
# In the past, when Serve scale up a backend, it does so one by one and
|
||||
# wait for each replica to initialize. This test avoid this by preventing
|
||||
# the first replica to finish initialization unless the second replica is
|
||||
# also started.
|
||||
@ray.remote
|
||||
class Barrier:
|
||||
def __init__(self, release_on):
|
||||
self.release_on = release_on
|
||||
self.current_waiters = 0
|
||||
self.event = asyncio.Event()
|
||||
|
||||
async def wait(self):
|
||||
self.current_waiters += 1
|
||||
if self.current_waiters == self.release_on:
|
||||
self.event.set()
|
||||
else:
|
||||
await self.event.wait()
|
||||
|
||||
barrier = Barrier.remote(release_on=2)
|
||||
|
||||
class LongStartingServable:
|
||||
def __init__(self):
|
||||
ray.get(barrier.wait.remote(), timeout=10)
|
||||
|
||||
def __call__(self, _):
|
||||
return "Ready"
|
||||
|
||||
serve.create_endpoint("test-parallel")
|
||||
serve.create_backend(
|
||||
"p:v0", LongStartingServable, config={"num_replicas": 2})
|
||||
serve.set_traffic("test-parallel", {"p:v0": 1})
|
||||
handle = serve.get_handle("test-parallel")
|
||||
|
||||
ray.get(handle.remote(), timeout=10)
|
||||
|
||||
Reference in New Issue
Block a user