[serve] Fix worker batch queue waiting logic (#8884)

This commit is contained in:
Edward Oakes
2020-06-10 21:28:16 -05:00
committed by GitHub
parent 950b389581
commit 3a9f45c4b3
2 changed files with 63 additions and 32 deletions
+60 -32
View File
@@ -23,36 +23,65 @@ from ray.serve.router import Query
logger = _get_logger()
class WaitableQueue(asyncio.Queue):
async def wait_for_batch(self, num_items: int, timeout_s: float):
"""Wait up to num_items in the queue given timeout_s.
class BatchQueue:
def __init__(self, max_batch_size, timeout_s):
self.queue = asyncio.Queue()
self.full_batch_event = asyncio.Event()
self.max_batch_size = max_batch_size
self.timeout_s = timeout_s
This method blocks indefinitely while waiting for the first item, so it
is guaranteed to return at least one item.
def set_config(self, max_batch_size, timeout_s):
    """Overwrite the stored batching parameters.

    Args:
        max_batch_size: New value for self.max_batch_size.
        timeout_s: New value for self.timeout_s.
    """
    self.max_batch_size, self.timeout_s = max_batch_size, timeout_s
def put(self, request):
    """Enqueue a request and signal when a full batch is available.

    Args:
        request: The item to append to the queue without blocking.
    """
    self.queue.put_nowait(request)
    # Signal when a full batch is ready. The event will be reset
    # in wait_for_batch.
    #
    # Compare with >= rather than ==: if the queue already holds more
    # than max_batch_size items (e.g. the limit was lowered through
    # set_config while requests were pending), the size never passes
    # through the exact value and the event would never be set.
    if self.queue.qsize() >= self.max_batch_size:
        self.full_batch_event.set()
async def wait_for_batch(self):
"""Wait for batch respecting self.max_batch_size and self.timeout_s.
Returns a batch of up to self.max_batch_size items, waiting for up
to self.timeout_s for a full batch. After the timeout, returns as many
items as are ready.
Always returns a batch with at least one item - will block
indefinitely until an item comes in.
"""
curr_timeout = self.timeout_s
batch = []
while len(batch) == 0:
loop_start = time.time()
assert num_items >= 1
# Wait for the first value without a timeout, so that at least one item is
# always returned. Additionally, this helps the caller context switch when
# the queue is empty.
start_waiting = time.time()
batch = [
await self.get(),
]
# If the timeout is 0, wait for any item to be available on the
# queue.
if curr_timeout == 0:
batch.append(await self.queue.get())
# If the timeout is nonzero, wait for either the timeout to occur
# or the max batch size to be ready.
else:
try:
await asyncio.wait_for(self.full_batch_event.wait(),
curr_timeout)
except asyncio.TimeoutError:
pass
# Adjust the timeout to account for the time waiting for first item.
time_remaining = timeout_s - (time.time() - start_waiting)
time_remaining = max(0, time_remaining)
# Pull up to the max_batch_size requests off the queue.
while len(batch) < self.max_batch_size and not self.queue.empty():
batch.append(self.queue.get_nowait())
# Reset the event if there are fewer than max_batch_size requests
# in the queue.
if (self.queue.qsize() < self.max_batch_size
and self.full_batch_event.is_set()):
self.full_batch_event.clear()
# Adjust the timeout based on the time spent in this iteration.
curr_timeout = max(0, curr_timeout - (time.time() - loop_start))
# Wait for the remaining batch with the timeout
if num_items > 1:
done_set, not_done_set = await asyncio.wait(
[self.get() for _ in range(num_items - 1)],
timeout=time_remaining)
for task in done_set:
batch.append(task.result())
for task in not_done_set:
task.cancel()
return batch
@@ -130,7 +159,8 @@ class RayServeWorker:
self.is_function = is_function
self.config = backend_config
self.query_queue = WaitableQueue()
self.batch_queue = BatchQueue(self.config.max_batch_size or 1,
self.config.batch_wait_timeout)
self.metric_client = metric_client
self.request_counter = self.metric_client.new_counter(
@@ -152,7 +182,6 @@ class RayServeWorker:
self.restart_counter.labels(replica_tag=self.replica_tag).add()
self.loop_task = asyncio.get_event_loop().create_task(self.main_loop())
self.config_updated = asyncio.Event()
def get_runner_method(self, request_item):
method_name = request_item.call_method
@@ -284,9 +313,7 @@ class RayServeWorker:
# NOTE(simon): There is an issue when the user updates the batch size and
# batch wait timeout during execution: these values will not be
# updated until after the current iteration.
batch = await self.query_queue.wait_for_batch(
num_items=self.config.max_batch_size or 1,
timeout_s=self.config.batch_wait_timeout)
batch = await self.batch_queue.wait_for_batch()
all_evaluated_futures = []
@@ -313,11 +340,12 @@ class RayServeWorker:
def update_config(self, new_config: BackendConfig):
    """Install a new backend config and forward it to the batch queue.

    Args:
        new_config: Config object to store; its max_batch_size and
            batch_wait_timeout attributes are read here.
    """
    self.config = new_config
    self.config_updated.set()
    # A falsy max_batch_size (None or 0) is treated as a batch size of 1.
    batch_size = self.config.max_batch_size or 1
    wait_timeout = self.config.batch_wait_timeout
    self.batch_queue.set_config(batch_size, wait_timeout)
async def handle_request(self, request: Query):
assert not isinstance(request, list)
logger.debug("Worker {} got request {}".format(self.name, request))
request.async_future = asyncio.get_event_loop().create_future()
self.query_queue.put_nowait(request)
self.batch_queue.put(request)
return await request.async_future
+3
View File
@@ -65,6 +65,9 @@ class BackendConfig:
self.num_replicas = config_dict.pop("num_replicas")
if "max_batch_size" in config_dict:
self.max_batch_size = config_dict.pop("max_batch_size")
if "max_concurrent_queries" in config_dict:
self.max_concurrent_queries = config_dict.pop(
"max_concurrent_queries")
if len(config_dict) != 0:
raise ValueError("Unknown options in backend config: {}".format(