mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:53:18 +08:00
[serve] Fix worker batch queue waiting logic (#8884)
This commit is contained in:
@@ -23,36 +23,65 @@ from ray.serve.router import Query
|
||||
logger = _get_logger()
|
||||
|
||||
|
||||
class WaitableQueue(asyncio.Queue):
|
||||
async def wait_for_batch(self, num_items: int, timeout_s: float):
|
||||
"""Wait up to num_items in the queue given timeout_s.
|
||||
class BatchQueue:
|
||||
def __init__(self, max_batch_size, timeout_s):
|
||||
self.queue = asyncio.Queue()
|
||||
self.full_batch_event = asyncio.Event()
|
||||
self.max_batch_size = max_batch_size
|
||||
self.timeout_s = timeout_s
|
||||
|
||||
This method will block indefinitely for the first item. Therefore, it
|
||||
guarantees to return at least one item.
|
||||
def set_config(self, max_batch_size, timeout_s):
|
||||
self.max_batch_size = max_batch_size
|
||||
self.timeout_s = timeout_s
|
||||
|
||||
def put(self, request):
|
||||
self.queue.put_nowait(request)
|
||||
# Signal when the full batch is ready. The event will be reset
|
||||
# in wait_for_batch.
|
||||
if self.queue.qsize() == self.max_batch_size:
|
||||
self.full_batch_event.set()
|
||||
|
||||
async def wait_for_batch(self):
|
||||
"""Wait for batch respecting self.max_batch_size and self.timeout_s.
|
||||
|
||||
Returns a batch of up to self.max_batch_size items, waiting for up
|
||||
to self.timeout_s for a full batch. After the timeout, returns as many
|
||||
items as are ready.
|
||||
|
||||
Always returns a batch with at least one item - will block
|
||||
indefinitely until an item comes in.
|
||||
"""
|
||||
curr_timeout = self.timeout_s
|
||||
batch = []
|
||||
while len(batch) == 0:
|
||||
loop_start = time.time()
|
||||
|
||||
assert num_items >= 1
|
||||
# Wait for the first value without timeout. We will return at least
|
||||
# one item. Additionally this help the caller context switch on empty
|
||||
# queue.
|
||||
start_waiting = time.time()
|
||||
batch = [
|
||||
await self.get(),
|
||||
]
|
||||
# If the timeout is 0, wait for any item to be available on the
|
||||
# queue.
|
||||
if curr_timeout == 0:
|
||||
batch.append(await self.queue.get())
|
||||
# If the timeout is nonzero, wait for either the timeout to occur
|
||||
# or the max batch size to be ready.
|
||||
else:
|
||||
try:
|
||||
await asyncio.wait_for(self.full_batch_event.wait(),
|
||||
curr_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
# Adjust the timeout to account for the time waiting for first item.
|
||||
time_remaining = timeout_s - (time.time() - start_waiting)
|
||||
time_remaining = max(0, time_remaining)
|
||||
# Pull up to the max_batch_size requests off the queue.
|
||||
while len(batch) < self.max_batch_size and not self.queue.empty():
|
||||
batch.append(self.queue.get_nowait())
|
||||
|
||||
# Reset the event if there are fewer than max_batch_size requests
|
||||
# in the queue.
|
||||
if (self.queue.qsize() < self.max_batch_size
|
||||
and self.full_batch_event.is_set()):
|
||||
self.full_batch_event.clear()
|
||||
|
||||
# Adjust the timeout based on the time spent in this iteration.
|
||||
curr_timeout = max(0, curr_timeout - (time.time() - loop_start))
|
||||
|
||||
# Wait for the remaining batch with the timeout
|
||||
if num_items > 1:
|
||||
done_set, not_done_set = await asyncio.wait(
|
||||
[self.get() for _ in range(num_items - 1)],
|
||||
timeout=time_remaining)
|
||||
for task in done_set:
|
||||
batch.append(task.result())
|
||||
for task in not_done_set:
|
||||
task.cancel()
|
||||
return batch
|
||||
|
||||
|
||||
@@ -130,7 +159,8 @@ class RayServeWorker:
|
||||
self.is_function = is_function
|
||||
|
||||
self.config = backend_config
|
||||
self.query_queue = WaitableQueue()
|
||||
self.batch_queue = BatchQueue(self.config.max_batch_size or 1,
|
||||
self.config.batch_wait_timeout)
|
||||
|
||||
self.metric_client = metric_client
|
||||
self.request_counter = self.metric_client.new_counter(
|
||||
@@ -152,7 +182,6 @@ class RayServeWorker:
|
||||
self.restart_counter.labels(replica_tag=self.replica_tag).add()
|
||||
|
||||
self.loop_task = asyncio.get_event_loop().create_task(self.main_loop())
|
||||
self.config_updated = asyncio.Event()
|
||||
|
||||
def get_runner_method(self, request_item):
|
||||
method_name = request_item.call_method
|
||||
@@ -284,9 +313,7 @@ class RayServeWorker:
|
||||
# NOTE(simon): There's an issue when user updated batch size and
|
||||
# batch wait timeout during the execution, these values will not be
|
||||
# updated until after the current iteration.
|
||||
batch = await self.query_queue.wait_for_batch(
|
||||
num_items=self.config.max_batch_size or 1,
|
||||
timeout_s=self.config.batch_wait_timeout)
|
||||
batch = await self.batch_queue.wait_for_batch()
|
||||
|
||||
all_evaluated_futures = []
|
||||
|
||||
@@ -313,11 +340,12 @@ class RayServeWorker:
|
||||
|
||||
def update_config(self, new_config: BackendConfig):
|
||||
self.config = new_config
|
||||
self.config_updated.set()
|
||||
self.batch_queue.set_config(self.config.max_batch_size or 1,
|
||||
self.config.batch_wait_timeout)
|
||||
|
||||
async def handle_request(self, request: Query):
|
||||
assert not isinstance(request, list)
|
||||
logger.debug("Worker {} got request {}".format(self.name, request))
|
||||
request.async_future = asyncio.get_event_loop().create_future()
|
||||
self.query_queue.put_nowait(request)
|
||||
self.batch_queue.put(request)
|
||||
return await request.async_future
|
||||
|
||||
@@ -65,6 +65,9 @@ class BackendConfig:
|
||||
self.num_replicas = config_dict.pop("num_replicas")
|
||||
if "max_batch_size" in config_dict:
|
||||
self.max_batch_size = config_dict.pop("max_batch_size")
|
||||
if "max_concurrent_queries" in config_dict:
|
||||
self.max_concurrent_queries = config_dict.pop(
|
||||
"max_concurrent_queries")
|
||||
|
||||
if len(config_dict) != 0:
|
||||
raise ValueError("Unknown options in backend config: {}".format(
|
||||
|
||||
Reference in New Issue
Block a user