import asyncio
import traceback
import inspect
from collections.abc import Iterable
from collections import defaultdict
from itertools import groupby
from operator import attrgetter
import time

import ray
from ray.async_compat import sync_to_async

from ray import serve
from ray.serve import context as serve_context
from ray.serve.context import FakeFlaskRequest
from ray.serve.utils import (parse_request_item, _get_logger, chain_future,
                             unpack_future)
from ray.serve.exceptions import RayServeException
from ray.serve.metric import MetricClient
from ray.serve.config import BackendConfig
from ray.serve.router import Query

logger = _get_logger()


class BatchQueue:
    """Asyncio-backed queue that hands requests out in batches.

    Requests are enqueued with ``put`` and drained with ``wait_for_batch``,
    which returns up to ``max_batch_size`` requests after waiting at most
    ``timeout_s`` seconds for a full batch to accumulate.
    """

    def __init__(self, max_batch_size, timeout_s):
        self.queue = asyncio.Queue()
        # Set by put() once a full batch is queued; cleared again by
        # wait_for_batch() when fewer than max_batch_size requests remain.
        self.full_batch_event = asyncio.Event()
        self.max_batch_size = max_batch_size
        self.timeout_s = timeout_s

    def set_config(self, max_batch_size, timeout_s):
        """Replace the batch size and timeout used for subsequent batches."""
        self.max_batch_size = max_batch_size
        self.timeout_s = timeout_s

    def put(self, request):
        """Enqueue ``request`` and signal if a full batch is now available."""
        self.queue.put_nowait(request)
        # Signal when the full batch is ready. The event will be reset
        # in wait_for_batch.
        # Use >= rather than ==: if set_config() lowered max_batch_size
        # below the current queue length, an equality check would never
        # fire and every batch would wait for the full timeout.
        if self.queue.qsize() >= self.max_batch_size:
            self.full_batch_event.set()

    async def wait_for_batch(self):
        """Wait for batch respecting self.max_batch_size and self.timeout_s.

        Returns a batch of up to self.max_batch_size items, waiting for up
        to self.timeout_s for a full batch. After the timeout, returns as many
        items as are ready.

        Always returns a batch with at least one item - will block
        indefinitely until an item comes in.
        """
        curr_timeout = self.timeout_s
        batch = []
        while len(batch) == 0:
            # Monotonic clock: the elapsed-time computation below must not
            # be affected by wall-clock adjustments.
            loop_start = time.monotonic()

            # If the timeout is 0, wait for any item to be available on the
            # queue.
            if curr_timeout == 0:
                batch.append(await self.queue.get())
            # If the timeout is nonzero, wait for either the timeout to occur
            # or the max batch size to be ready.
            else:
                try:
                    await asyncio.wait_for(self.full_batch_event.wait(),
                                           curr_timeout)
                except asyncio.TimeoutError:
                    # Timing out is expected: fall through and return
                    # whatever is queued (possibly nothing yet).
                    pass

            # Pull up to the max_batch_size requests off the queue.
            while len(batch) < self.max_batch_size and not self.queue.empty():
                batch.append(self.queue.get_nowait())

            # Reset the event if there are fewer than max_batch_size requests
            # in the queue.
            if (self.queue.qsize() < self.max_batch_size
                    and self.full_batch_event.is_set()):
                self.full_batch_event.clear()

            # Adjust the timeout based on the time spent in this iteration.
            # Once it reaches 0 the next iteration blocks on queue.get().
            curr_timeout = max(
                0, curr_timeout - (time.monotonic() - loop_start))

        return batch


def create_backend_worker(func_or_class):
    """Creates a worker class wrapping the provided function or class.

    Args:
        func_or_class: A plain function, or a class that is instantiated
            with ``init_args`` when the worker is constructed.

    Returns:
        A new class named ``"RayServeWorker_" + func_or_class.__name__``
        whose instances delegate request handling to a ``RayServeWorker``.

    Raises:
        AssertionError: If ``func_or_class`` is neither a function nor a
            class.
    """

    if inspect.isfunction(func_or_class):
        is_function = True
    elif inspect.isclass(func_or_class):
        is_function = False
    else:
        # Raised explicitly instead of via ``assert False`` so the check
        # still runs under ``python -O`` (which strips assert statements
        # and would otherwise leave ``is_function`` unbound). The exception
        # type is unchanged.
        raise AssertionError("func_or_class must be function or class.")

    class RayServeWrappedWorker(object):
        def __init__(self,
                     backend_tag,
                     replica_tag,
                     init_args,
                     backend_config: BackendConfig,
                     instance_name=None):
            serve.init(name=instance_name)

            if is_function:
                _callable = func_or_class
            else:
                _callable = func_or_class(*init_args)

            master = serve.api._get_master_actor()
            # The remote call is expected to yield exactly one element.
            [metric_exporter] = ray.get(master.get_metric_exporter.remote())
            metric_client = MetricClient(
                metric_exporter, default_labels={"backend": backend_tag})
            self.backend = RayServeWorker(backend_tag, replica_tag, _callable,
                                          backend_config, is_function,
                                          metric_client)

        async def handle_request(self, request):
            return await self.backend.handle_request(request)

        def update_config(self, new_config: BackendConfig):
            return self.backend.update_config(new_config)

        def ready(self):
            # Intentionally empty; it only needs to be callable.
            pass

    RayServeWrappedWorker.__name__ = "RayServeWorker_" + func_or_class.__name__
    return RayServeWrappedWorker


def wrap_to_ray_error(exception):
    """Utility method to wrap exceptions in user code.

    Returns a ``ray.exceptions.RayTaskError`` built from the exception's
    message, its formatted traceback and its class.
    """

    try:
        # Raise and catch so we can access traceback.format_exc()
        raise exception
    except Exception as e:
        traceback_str = ray.utils.format_error_message(traceback.format_exc())
        return ray.exceptions.RayTaskError(str(e), traceback_str, e.__class__)


def ensure_async(func):
    """Return ``func`` if it is a coroutine function, else an async wrapper."""
    if inspect.iscoroutinefunction(func):
        return func
    else:
        return sync_to_async(func)


class RayServeWorker:
    """Handles requests with the provided callable."""

    def __init__(self, name, replica_tag, _callable,
                 backend_config: BackendConfig, is_function, metric_client):
        self.name = name
        self.replica_tag = replica_tag
        self.callable = _callable
        self.is_function = is_function

        self.config = backend_config
        # A falsy max_batch_size (e.g. None) means one request per batch.
        self.batch_queue = BatchQueue(self.config.max_batch_size or 1,
                                      self.config.batch_wait_timeout)

        self.metric_client = metric_client
        self.request_counter = self.metric_client.new_counter(
            "backend_request_counter",
            description=("Number of queries that have been "
                         "processed in this replica"),
        )
        self.error_counter = self.metric_client.new_counter(
            "backend_error_counter",
            description=("Number of exceptions that have "
                         "occurred in the backend"),
        )
        self.restart_counter = self.metric_client.new_counter(
            "backend_worker_starts",
            description=("The number of time this replica workers "
                         "has been restarted due to failure."),
            label_names=("replica_tag", ))

        self.restart_counter.labels(replica_tag=self.replica_tag).add()

        # Start consuming batches in the background on the current loop.
        self.loop_task = asyncio.get_event_loop().create_task(self.main_loop())

    def get_runner_method(self, request_item):
        """Resolve the callable that should serve ``request_item``.

        Returns ``self.callable`` itself for function backends, otherwise
        the attribute named by ``request_item.call_method``.

        Raises:
            RayServeException: If the backend has no attribute with the
                requested method name.
        """
        method_name = request_item.call_method
        if not hasattr(self.callable, method_name):
            raise RayServeException("Backend doesn't have method {} "
                                    "which is specified in the request. "
                                    "The available methods are {}".format(
                                        method_name, dir(self.callable)))
        if self.is_function:
            return self.callable
        return getattr(self.callable, method_name)

    def has_positional_args(self, f):
        """Return True if ``f`` has a required positional-or-keyword param."""
        # NOTE:
        # In the case of simple functions, not actors, the f will be
        # function.__call__, but we need to inspect the function itself.
        if self.is_function:
            f = self.callable

        signature = inspect.signature(f)
        for param in signature.parameters.values():
            if (param.kind == param.POSITIONAL_OR_KEYWORD
                    and param.default is param.empty):
                return True
        return False

    def _reset_context(self):
        # NOTE(simon): context management won't work in async mode because
        # many concurrent queries might be running at the same time.
        serve_context.web = None
        serve_context.batch_size = None

    async def invoke_single(self, request_item):
        """Run the backend on one request.

        Returns the backend's result, or a ``RayTaskError`` wrapping the
        exception if the backend call raised.
        """
        args, kwargs, is_web_context = parse_request_item(request_item)
        serve_context.web = is_web_context

        method_to_call = self.get_runner_method(request_item)
        # Drop positional args when the target declares no required
        # positional parameter.
        args = args if self.has_positional_args(method_to_call) else []
        method_to_call = ensure_async(method_to_call)
        try:
            result = await method_to_call(*args, **kwargs)
            self.request_counter.add()
        except Exception as e:
            result = wrap_to_ray_error(e)
            self.error_counter.add()
        finally:
            self._reset_context()

        return result

    async def invoke_batch(self, request_item_list):
        """Run the backend once on a batch of requests.

        Each positional/keyword argument is passed to the backend as a list
        with one entry per request.

        Returns a list with one entry per request: the backend's results in
        order, or the same wrapped ``RayTaskError`` repeated ``batch_size``
        times if anything failed (mixed contexts, mixed call methods, a
        backend exception, or a malformed return value).
        """
        arg_list = []
        kwargs_list = defaultdict(list)
        context_flags = set()
        batch_size = len(request_item_list)
        call_methods = set()

        for item in request_item_list:
            args, kwargs, is_web_context = parse_request_item(item)
            context_flags.add(is_web_context)

            call_method = self.get_runner_method(item)
            call_methods.add(call_method)

            if is_web_context:
                # Web context only has one positional argument: the
                # flask request.
                flask_request = args[0]
                arg_list.append(flask_request)
            else:
                # Python context only has kwargs.
                for k, v in kwargs.items():
                    kwargs_list[k].append(v)

                # Set the flask request as a list to conform
                # with batching semantics: when in batching
                # mode, each argument is turned into list.
                if self.has_positional_args(call_method):
                    arg_list.append(FakeFlaskRequest())

        try:
            # Check mixing of query context (unified context needed).
            if len(context_flags) != 1:
                raise RayServeException(
                    "Batched queries contain mixed context. Please only send "
                    "the same type of requests in batching mode.")
            serve_context.web = context_flags.pop()

            if len(call_methods) != 1:
                raise RayServeException(
                    "Queries contain mixed calling methods. Please only send "
                    "the same type of requests in batching mode.")
            call_method = ensure_async(call_methods.pop())

            serve_context.batch_size = batch_size
            # Flask requests are passed to __call__ as a list
            arg_list = [arg_list]

            self.request_counter.add(batch_size)
            result_list = await call_method(*arg_list, **kwargs_list)

            # dict and set are Iterable but are rejected here because the
            # results must map back to requests by position.
            if not isinstance(result_list, Iterable) or isinstance(
                    result_list, (dict, set)):
                error_message = ("RayServe expects an ordered iterable object "
                                 "but the worker returned a {}".format(
                                     type(result_list)))
                raise RayServeException(error_message)

            # Normalize the result into a list type. This operation is fast
            # in Python because it doesn't copy anything.
            result_list = list(result_list)

            if (len(result_list) != batch_size):
                error_message = ("Worker doesn't preserve batch size. The "
                                 "input has length {} but the returned list "
                                 "has length {}. Please return a list of "
                                 "results with length equal to the batch size"
                                 ".".format(batch_size, len(result_list)))
                raise RayServeException(error_message)
            return result_list
        except Exception as e:
            wrapped_exception = wrap_to_ray_error(e)
            self.error_counter.add()
            return [wrapped_exception for _ in range(batch_size)]
        finally:
            # Runs on both the success and the failure path.
            self._reset_context()

    async def main_loop(self):
        """Forever pull batches off the queue and dispatch them."""
        while True:
            # NOTE(simon): There's an issue when user updated batch size and
            # batch wait timeout during the execution, these values will not be
            # updated until after the current iteration.
            batch = await self.batch_queue.wait_for_batch()

            all_evaluated_futures = []

            if not self.config.accepts_batches:
                # wait_for_batch() returns up to max_batch_size requests.
                # Dispatch each one individually so that no request is
                # dropped with a future that never resolves.
                for query in batch:
                    evaluated = asyncio.ensure_future(
                        self.invoke_single(query))
                    all_evaluated_futures.append(evaluated)
                    # Link the evaluation to the future that
                    # handle_request() is awaiting.
                    chain_future(evaluated, query.async_future)
            else:
                get_call_method = attrgetter("call_method")
                # groupby only groups adjacent items, so sort by the same
                # key first.
                sorted_batch = sorted(batch, key=get_call_method)
                for _, group in groupby(sorted_batch, key=get_call_method):
                    # Materialize the group (it is a one-shot iterator) in
                    # the queries' own sort order.
                    group = sorted(group)
                    evaluated = asyncio.ensure_future(self.invoke_batch(group))
                    all_evaluated_futures.append(evaluated)
                    result_futures = [q.async_future for q in group]
                    chain_future(
                        unpack_future(evaluated, len(group)), result_futures)

            if self.config.is_blocking:
                # We use asyncio.wait here so if the result is exception,
                # it will not be raised.
                await asyncio.wait(all_evaluated_futures)

    def update_config(self, new_config: BackendConfig):
        """Swap in ``new_config`` and push batch settings to the queue."""
        self.config = new_config
        self.batch_queue.set_config(self.config.max_batch_size or 1,
                                    self.config.batch_wait_timeout)

    async def handle_request(self, request: Query):
        """Enqueue ``request`` and wait for its result.

        A fresh future is attached to the request as ``async_future``; the
        main loop resolves it once the request has been evaluated.
        """
        assert not isinstance(request, list)
        logger.debug("Worker {} got request {}".format(self.name, request))
        request.async_future = asyncio.get_event_loop().create_future()
        self.batch_queue.put(request)
        return await request.async_future