diff --git a/python/ray/serve/queues.py b/python/ray/serve/queues.py index 4e944e967..674180e8e 100644 --- a/python/ray/serve/queues.py +++ b/python/ray/serve/queues.py @@ -17,8 +17,12 @@ from ray.serve.utils import logger class Query: - def __init__(self, request_args, request_kwargs, request_context, - request_slo_ms): + def __init__(self, + request_args, + request_kwargs, + request_context, + request_slo_ms, + call_method="__call__"): self.request_args = request_args self.request_kwargs = request_kwargs self.request_context = request_context @@ -29,6 +33,8 @@ class Query: # absolute time since unix epoch. self.request_slo_ms = request_slo_ms + self.call_method = call_method + def ray_serialize(self): # NOTE: this method is needed because Query need to be serialized and # sent to the replica worker. However, after we send the query to @@ -175,8 +181,12 @@ class CentralizedQueues: else: request_slo_ms = request_in_object.adjust_relative_slo_ms() request_context = request_in_object.request_context - query = Query(request_args, request_kwargs, request_context, - request_slo_ms) + query = Query( + request_args, + request_kwargs, + request_context, + request_slo_ms, + call_method=request_in_object.call_method) await self.service_queues[service].put(query) await self.flush() @@ -298,8 +308,15 @@ class CentralizedQueues: requests = [ buffer_queue.pop(0) for _ in range(real_batch_size) ] - future = worker._ray_serve_call.remote(requests).as_future() - future.add_done_callback( - _make_future_unwrapper( - client_futures=[req.async_future for req in requests], - host_future=future)) + + # split requests by method type + requests_group = defaultdict(list) + for request in requests: + requests_group[request.call_method].append(request) + + for group in requests_group.values(): + future = worker._ray_serve_call.remote(group).as_future() + future.add_done_callback( + _make_future_unwrapper( + client_futures=[req.async_future for req in group], + host_future=future)) diff --git a/python/ray/serve/request_params.py b/python/ray/serve/request_params.py index 06e0e1b04..be17ef5ce 100644 --- a/python/ray/serve/request_params.py +++ b/python/ray/serve/request_params.py @@ -19,12 +19,14 @@ class RequestMetadata: service, request_context, relative_slo_ms=None, - absolute_slo_ms=None): + absolute_slo_ms=None, + call_method="__call__"): self.service = service self.request_context = request_context self.relative_slo_ms = relative_slo_ms self.absolute_slo_ms = absolute_slo_ms + self.call_method = call_method def adjust_relative_slo_ms(self) -> float: """Normalize the input latency objective to absoluate timestamp. diff --git a/python/ray/serve/task_runner.py b/python/ray/serve/task_runner.py index 7e0704656..65232e9f9 100644 --- a/python/ray/serve/task_runner.py +++ b/python/ray/serve/task_runner.py @@ -1,5 +1,6 @@ import time import traceback +import inspect import ray from ray.serve import context as serve_context @@ -7,6 +8,7 @@ from ray.serve.context import FakeFlaskRequest from collections import defaultdict from ray.serve.utils import parse_request_item from ray.serve.exceptions import RayServeException +from ray.async_compat import sync_to_async class TaskRunner: @@ -33,6 +35,13 @@ def wrap_to_ray_error(exception): return ray.exceptions.RayTaskError(str(e), traceback_str, e.__class__) +def ensure_async(func): + if inspect.iscoroutinefunction(func): + return func + else: + return sync_to_async(func) + + class RayServeMixin: """This mixin class adds the functionality to fetch from router queues. @@ -94,13 +103,25 @@ class RayServeMixin: self._ray_serve_dequeue_requester_name, self._ray_serve_self_handle) - def invoke_single(self, request_item): + def _ray_serve_get_runner_method(self, request_item): + method_name = request_item.call_method + if not hasattr(self, method_name): + raise RayServeException("Backend doesn't have method {} " + "which is specified in the request. " + "The avaiable methods are {}".format( + method_name, dir(self))) + + return getattr(self, method_name) + + async def invoke_single(self, request_item): args, kwargs, is_web_context = parse_request_item(request_item) serve_context.web = is_web_context start_timestamp = time.time() try: - result = self.__call__(*args, **kwargs) + result = await ensure_async( + self._ray_serve_get_runner_method(request_item))(*args, + **kwargs) except Exception as e: result = wrap_to_ray_error(e) self._serve_metric_error_counter += 1 @@ -108,7 +129,7 @@ class RayServeMixin: self._serve_metric_latency_list.append(time.time() - start_timestamp) return result - def invoke_batch(self, request_item_list): + async def invoke_batch(self, request_item_list): # TODO(alind) : create no-http services. The enqueues # from such services will always be TaskContext.Python. @@ -127,11 +148,14 @@ class RayServeMixin: kwargs_list = defaultdict(list) context_flags = set() batch_size = len(request_item_list) + call_methods = set() for item in request_item_list: args, kwargs, is_web_context = parse_request_item(item) context_flags.add(is_web_context) + call_methods.add(self._ray_serve_get_runner_method(item)) + if is_web_context: # Python context only have kwargs flask_request = args[0] @@ -153,14 +177,20 @@ class RayServeMixin: raise RayServeException( "Batched queries contain mixed context. Please only send " "the same type of requests in batching mode.") - serve_context.web = context_flags.pop() + + if len(call_methods) != 1: + raise RayServeException( + "Queries contain mixed calling methods. Please only send " + "the same type of requests in batching mode.") + call_method = ensure_async(call_methods.pop()) + serve_context.batch_size = batch_size # Flask requests are passed to __call__ as a list arg_list = [arg_list] start_timestamp = time.time() - result_list = self.__call__(*arg_list, **kwargs_list) + result_list = await call_method(*arg_list, **kwargs_list) self._serve_metric_latency_list.append(time.time() - start_timestamp) @@ -177,13 +207,13 @@ class RayServeMixin: self._serve_metric_error_counter += batch_size return [wrapped_exception for _ in range(batch_size)] - def _ray_serve_call(self, request): + async def _ray_serve_call(self, request): # check if work_item is a list or not # if it is list: then batching supported if not isinstance(request, list): - result = self.invoke_single(request) + result = await self.invoke_single(request) else: - result = self.invoke_batch(request) + result = await self.invoke_batch(request) # re-assign to default values serve_context.web = False diff --git a/python/ray/serve/tests/test_persistence.py b/python/ray/serve/tests/test_persistence.py index 602f3a825..b746995c5 100644 --- a/python/ray/serve/tests/test_persistence.py +++ b/python/ray/serve/tests/test_persistence.py @@ -9,7 +9,7 @@ from ray import serve def test_new_driver(serve_instance): script = """ import ray -ray.init(address="auto") +ray.init(address="{}") from ray import serve serve.init() @@ -17,7 +17,7 @@ serve.init() @serve.route("/driver") def driver(flask_request): return "OK!" -""" +""".format(ray.worker._global_node._redis_address) with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: path = f.name diff --git a/python/ray/serve/tests/test_task_runner.py b/python/ray/serve/tests/test_task_runner.py index 5a0236c6e..fb6ecd8a4 100644 --- a/python/ray/serve/tests/test_task_runner.py +++ b/python/ray/serve/tests/test_task_runner.py @@ -1,11 +1,15 @@ +import asyncio + import pytest import ray +from ray import serve import ray.serve.context as context from ray.serve.policy import RoundRobinPolicyQueueActor from ray.serve.task_runner import (RayServeMixin, TaskRunner, TaskRunnerActor, wrap_to_ray_error) from ray.serve.request_params import RequestMetadata +from ray.serve.backend_config import BackendConfig pytestmark = pytest.mark.asyncio @@ -97,3 +101,83 @@ async def test_task_runner_check_context(serve_instance): with pytest.raises(ray.exceptions.RayTaskError): await result_oid + + +async def test_task_runner_custom_method_single(serve_instance): + q = RoundRobinPolicyQueueActor.remote() + + class NonBatcher: + def a(self, _): + return "a" + + def b(self, _): + return "b" + + @ray.remote + class CustomActor(NonBatcher, RayServeMixin): + pass + + CONSUMER_NAME = "runner" + PRODUCER_NAME = "producer" + + runner = CustomActor.remote() + + runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner) + runner._ray_serve_fetch.remote() + + q.link.remote(PRODUCER_NAME, CONSUMER_NAME) + + query_param = RequestMetadata( + PRODUCER_NAME, context.TaskContext.Python, call_method="a") + a_result = await q.enqueue_request.remote(query_param) + assert a_result == "a" + + query_param = RequestMetadata( + PRODUCER_NAME, context.TaskContext.Python, call_method="b") + b_result = await q.enqueue_request.remote(query_param) + assert b_result == "b" + + query_param = RequestMetadata( + PRODUCER_NAME, context.TaskContext.Python, call_method="non_exist") + with pytest.raises(ray.exceptions.RayTaskError): + await q.enqueue_request.remote(query_param) + + +async def test_task_runner_custom_method_batch(serve_instance): + q = RoundRobinPolicyQueueActor.remote() + + @serve.accept_batch + class Batcher: + def a(self, _): + return ["a-{}".format(i) for i in range(serve.context.batch_size)] + + def b(self, _): + return ["b-{}".format(i) for i in range(serve.context.batch_size)] + + @ray.remote + class CustomActor(Batcher, RayServeMixin): + pass + + CONSUMER_NAME = "runner" + PRODUCER_NAME = "producer" + + runner = CustomActor.remote() + + runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner) + + q.link.remote(PRODUCER_NAME, CONSUMER_NAME) + q.set_backend_config.remote( + CONSUMER_NAME, BackendConfig(max_batch_size=2).__dict__) + + a_query_param = RequestMetadata( + PRODUCER_NAME, context.TaskContext.Python, call_method="a") + b_query_param = RequestMetadata( + PRODUCER_NAME, context.TaskContext.Python, call_method="b") + + futures = [q.enqueue_request.remote(a_query_param) for _ in range(2)] + futures += [q.enqueue_request.remote(b_query_param) for _ in range(2)] + + runner._ray_serve_fetch.remote() + + gathered = await asyncio.gather(*futures) + assert gathered == ["a-0", "a-1", "b-0", "b-1"]