class BatchQueue:
    """Async request queue that hands out batches of queued requests.

    Requests are added with ``put`` and consumed with ``wait_for_batch``.
    ``full_batch_event`` is set whenever at least ``max_batch_size``
    requests are queued, which lets a waiting consumer wake up before its
    timeout expires.
    """

    def __init__(self, max_batch_size, timeout_s):
        """Create an empty queue.

        Args:
            max_batch_size: Maximum number of requests returned per batch.
            timeout_s: Seconds ``wait_for_batch`` waits for a full batch
                before returning whatever is queued. ``0`` means "do not
                wait for a full batch".
        """
        self.queue = asyncio.Queue()
        self.full_batch_event = asyncio.Event()
        self.max_batch_size = max_batch_size
        self.timeout_s = timeout_s

    def set_config(self, max_batch_size, timeout_s):
        """Update the batching parameters.

        The new ``max_batch_size`` is used immediately by ``put`` and by any
        in-flight ``wait_for_batch`` call; the new ``timeout_s`` is picked up
        by the next ``wait_for_batch`` call (an in-flight call keeps the
        remaining timeout it already computed).
        """
        self.max_batch_size = max_batch_size
        self.timeout_s = timeout_s
        # The backlog may already satisfy (or no longer satisfy) the new
        # batch size, so bring the event in line with the new limit.
        self._sync_full_batch_event()

    def put(self, request):
        """Enqueue ``request`` without blocking."""
        self.queue.put_nowait(request)
        # Signal when a full batch is ready. Use >= rather than ==: if
        # max_batch_size was lowered below the current backlog, the queue
        # size would otherwise never hit the limit exactly and the event
        # would never fire. The event is reset in wait_for_batch.
        if self.queue.qsize() >= self.max_batch_size:
            self.full_batch_event.set()

    def _sync_full_batch_event(self):
        """Set the event iff a full batch is currently queued."""
        if self.queue.qsize() >= self.max_batch_size:
            self.full_batch_event.set()
        else:
            self.full_batch_event.clear()

    async def wait_for_batch(self):
        """Wait for batch respecting self.max_batch_size and self.timeout_s.

        Returns a batch of up to self.max_batch_size items, waiting for up
        to self.timeout_s for a full batch. After the timeout, returns as many
        items as are ready.

        Always returns a batch with at least one item - will block
        indefinitely until an item comes in.

        Returns:
            A non-empty list of queued requests, in FIFO order.
        """
        curr_timeout = self.timeout_s
        batch = []
        while not batch:
            # Monotonic clock: immune to wall-clock adjustments.
            loop_start = time.monotonic()

            if curr_timeout == 0:
                # No time budget left (or none configured): block until
                # any single item is available.
                batch.append(await self.queue.get())
            else:
                # Wait for either a full batch or the timeout, whichever
                # comes first. A timeout is not an error here.
                try:
                    await asyncio.wait_for(self.full_batch_event.wait(),
                                           curr_timeout)
                except asyncio.TimeoutError:
                    pass

            # Pull up to max_batch_size requests off the queue.
            while (len(batch) < self.max_batch_size
                   and not self.queue.empty()):
                batch.append(self.queue.get_nowait())

            # Leave the event set only if another full batch is still queued.
            self._sync_full_batch_event()

            # Charge the time spent in this iteration against the budget;
            # once it reaches 0 the next iteration blocks on queue.get().
            curr_timeout = max(
                0, curr_timeout - (time.monotonic() - loop_start))

        return batch
batch size and # batch wait timeout during the execution, these values will not be # updated until after the current iteration. - batch = await self.query_queue.wait_for_batch( - num_items=self.config.max_batch_size or 1, - timeout_s=self.config.batch_wait_timeout) + batch = await self.batch_queue.wait_for_batch() all_evaluated_futures = [] @@ -313,11 +340,12 @@ class RayServeWorker: def update_config(self, new_config: BackendConfig): self.config = new_config - self.config_updated.set() + self.batch_queue.set_config(self.config.max_batch_size or 1, + self.config.batch_wait_timeout) async def handle_request(self, request: Query): assert not isinstance(request, list) logger.debug("Worker {} got request {}".format(self.name, request)) request.async_future = asyncio.get_event_loop().create_future() - self.query_queue.put_nowait(request) + self.batch_queue.put(request) return await request.async_future diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index 69a506449..fbbb7fd3f 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -65,6 +65,9 @@ class BackendConfig: self.num_replicas = config_dict.pop("num_replicas") if "max_batch_size" in config_dict: self.max_batch_size = config_dict.pop("max_batch_size") + if "max_concurrent_queries" in config_dict: + self.max_concurrent_queries = config_dict.pop( + "max_concurrent_queries") if len(config_dict) != 0: raise ValueError("Unknown options in backend config: {}".format(