[Serve] Add async, multi methods support for serve actors (#7682)

This commit is contained in:
Simon Mo
2020-03-23 00:45:26 -07:00
committed by GitHub
parent 039961b63a
commit afad0ed085
5 changed files with 153 additions and 20 deletions
+26 -9
View File
@@ -17,8 +17,12 @@ from ray.serve.utils import logger
class Query:
def __init__(self, request_args, request_kwargs, request_context,
request_slo_ms):
def __init__(self,
request_args,
request_kwargs,
request_context,
request_slo_ms,
call_method="__call__"):
self.request_args = request_args
self.request_kwargs = request_kwargs
self.request_context = request_context
@@ -29,6 +33,8 @@ class Query:
# absolute time since unix epoch.
self.request_slo_ms = request_slo_ms
self.call_method = call_method
def ray_serialize(self):
# NOTE: this method is needed because Query needs to be serialized and
# sent to the replica worker. However, after we send the query to
@@ -175,8 +181,12 @@ class CentralizedQueues:
else:
request_slo_ms = request_in_object.adjust_relative_slo_ms()
request_context = request_in_object.request_context
query = Query(request_args, request_kwargs, request_context,
request_slo_ms)
query = Query(
request_args,
request_kwargs,
request_context,
request_slo_ms,
call_method=request_in_object.call_method)
await self.service_queues[service].put(query)
await self.flush()
@@ -298,8 +308,15 @@ class CentralizedQueues:
requests = [
buffer_queue.pop(0) for _ in range(real_batch_size)
]
future = worker._ray_serve_call.remote(requests).as_future()
future.add_done_callback(
_make_future_unwrapper(
client_futures=[req.async_future for req in requests],
host_future=future))
# split requests by method type
requests_group = defaultdict(list)
for request in requests:
requests_group[request.call_method].append(request)
for group in requests_group.values():
future = worker._ray_serve_call.remote(group).as_future()
future.add_done_callback(
_make_future_unwrapper(
client_futures=[req.async_future for req in group],
host_future=future))
+3 -1
View File
@@ -19,12 +19,14 @@ class RequestMetadata:
service,
request_context,
relative_slo_ms=None,
absolute_slo_ms=None):
absolute_slo_ms=None,
call_method="__call__"):
self.service = service
self.request_context = request_context
self.relative_slo_ms = relative_slo_ms
self.absolute_slo_ms = absolute_slo_ms
self.call_method = call_method
def adjust_relative_slo_ms(self) -> float:
"""Normalize the input latency objective to absoluate timestamp.
+38 -8
View File
@@ -1,5 +1,6 @@
import time
import traceback
import inspect
import ray
from ray.serve import context as serve_context
@@ -7,6 +8,7 @@ from ray.serve.context import FakeFlaskRequest
from collections import defaultdict
from ray.serve.utils import parse_request_item
from ray.serve.exceptions import RayServeException
from ray.async_compat import sync_to_async
class TaskRunner:
@@ -33,6 +35,13 @@ def wrap_to_ray_error(exception):
return ray.exceptions.RayTaskError(str(e), traceback_str, e.__class__)
def ensure_async(func):
    """Return a coroutine-function form of ``func``.

    A callable that ``inspect.iscoroutinefunction`` already recognises as a
    coroutine function is handed back untouched; anything else is passed
    through ``sync_to_async`` and that wrapper is returned instead.
    """
    already_async = inspect.iscoroutinefunction(func)
    return func if already_async else sync_to_async(func)
class RayServeMixin:
"""This mixin class adds the functionality to fetch from router queues.
@@ -94,13 +103,25 @@ class RayServeMixin:
self._ray_serve_dequeue_requester_name,
self._ray_serve_self_handle)
def invoke_single(self, request_item):
def _ray_serve_get_runner_method(self, request_item):
method_name = request_item.call_method
if not hasattr(self, method_name):
raise RayServeException("Backend doesn't have method {} "
"which is specified in the request. "
"The avaiable methods are {}".format(
method_name, dir(self)))
return getattr(self, method_name)
async def invoke_single(self, request_item):
args, kwargs, is_web_context = parse_request_item(request_item)
serve_context.web = is_web_context
start_timestamp = time.time()
try:
result = self.__call__(*args, **kwargs)
result = await ensure_async(
self._ray_serve_get_runner_method(request_item))(*args,
**kwargs)
except Exception as e:
result = wrap_to_ray_error(e)
self._serve_metric_error_counter += 1
@@ -108,7 +129,7 @@ class RayServeMixin:
self._serve_metric_latency_list.append(time.time() - start_timestamp)
return result
def invoke_batch(self, request_item_list):
async def invoke_batch(self, request_item_list):
# TODO(alind) : create no-http services. The enqueues
# from such services will always be TaskContext.Python.
@@ -127,11 +148,14 @@ class RayServeMixin:
kwargs_list = defaultdict(list)
context_flags = set()
batch_size = len(request_item_list)
call_methods = set()
for item in request_item_list:
args, kwargs, is_web_context = parse_request_item(item)
context_flags.add(is_web_context)
call_methods.add(self._ray_serve_get_runner_method(item))
if is_web_context:
# Python context only has kwargs
flask_request = args[0]
@@ -153,14 +177,20 @@ class RayServeMixin:
raise RayServeException(
"Batched queries contain mixed context. Please only send "
"the same type of requests in batching mode.")
serve_context.web = context_flags.pop()
if len(call_methods) != 1:
raise RayServeException(
"Queries contain mixed calling methods. Please only send "
"the same type of requests in batching mode.")
call_method = ensure_async(call_methods.pop())
serve_context.batch_size = batch_size
# Flask requests are passed to __call__ as a list
arg_list = [arg_list]
start_timestamp = time.time()
result_list = self.__call__(*arg_list, **kwargs_list)
result_list = await call_method(*arg_list, **kwargs_list)
self._serve_metric_latency_list.append(time.time() -
start_timestamp)
@@ -177,13 +207,13 @@ class RayServeMixin:
self._serve_metric_error_counter += batch_size
return [wrapped_exception for _ in range(batch_size)]
def _ray_serve_call(self, request):
async def _ray_serve_call(self, request):
# check if work_item is a list or not
# if it is list: then batching supported
if not isinstance(request, list):
result = self.invoke_single(request)
result = await self.invoke_single(request)
else:
result = self.invoke_batch(request)
result = await self.invoke_batch(request)
# re-assign to default values
serve_context.web = False
+2 -2
View File
@@ -9,7 +9,7 @@ from ray import serve
def test_new_driver(serve_instance):
script = """
import ray
ray.init(address="auto")
ray.init(address="{}")
from ray import serve
serve.init()
@@ -17,7 +17,7 @@ serve.init()
@serve.route("/driver")
def driver(flask_request):
return "OK!"
"""
""".format(ray.worker._global_node._redis_address)
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
path = f.name
@@ -1,11 +1,15 @@
import asyncio
import pytest
import ray
from ray import serve
import ray.serve.context as context
from ray.serve.policy import RoundRobinPolicyQueueActor
from ray.serve.task_runner import (RayServeMixin, TaskRunner, TaskRunnerActor,
wrap_to_ray_error)
from ray.serve.request_params import RequestMetadata
from ray.serve.backend_config import BackendConfig
pytestmark = pytest.mark.asyncio
@@ -97,3 +101,83 @@ async def test_task_runner_check_context(serve_instance):
with pytest.raises(ray.exceptions.RayTaskError):
await result_oid
async def test_task_runner_custom_method_single(serve_instance):
    """Check that non-batched requests reach the method named in call_method.

    Requests carrying ``call_method="a"`` and ``call_method="b"`` must
    return "a" and "b" respectively, while a request naming a method the
    backend does not define must surface as a ``RayTaskError``.
    """
    queue_actor = RoundRobinPolicyQueueActor.remote()

    class NonBatcher:
        def a(self, _):
            return "a"

        def b(self, _):
            return "b"

    @ray.remote
    class CustomActor(NonBatcher, RayServeMixin):
        pass

    consumer_name, producer_name = "runner", "producer"

    runner = CustomActor.remote()
    runner._ray_serve_setup.remote(consumer_name, queue_actor, runner)
    runner._ray_serve_fetch.remote()

    queue_actor.link.remote(producer_name, consumer_name)

    # Each existing method simply returns its own name.
    for method_name in ("a", "b"):
        metadata = RequestMetadata(
            producer_name,
            context.TaskContext.Python,
            call_method=method_name)
        result = await queue_actor.enqueue_request.remote(metadata)
        assert result == method_name

    missing_metadata = RequestMetadata(
        producer_name, context.TaskContext.Python, call_method="non_exist")
    with pytest.raises(ray.exceptions.RayTaskError):
        await queue_actor.enqueue_request.remote(missing_metadata)
async def test_task_runner_custom_method_batch(serve_instance):
    """Check that batched requests are served by the method they name.

    With ``max_batch_size=2``, two requests for method "a" and two for
    method "b" are enqueued before the runner starts fetching; the gathered
    results are expected to be "a-0", "a-1", "b-0", "b-1" in that order.
    """
    queue_actor = RoundRobinPolicyQueueActor.remote()

    @serve.accept_batch
    class Batcher:
        def a(self, _):
            size = serve.context.batch_size
            return ["a-{}".format(idx) for idx in range(size)]

        def b(self, _):
            size = serve.context.batch_size
            return ["b-{}".format(idx) for idx in range(size)]

    @ray.remote
    class CustomActor(Batcher, RayServeMixin):
        pass

    consumer_name, producer_name = "runner", "producer"

    runner = CustomActor.remote()
    runner._ray_serve_setup.remote(consumer_name, queue_actor, runner)

    queue_actor.link.remote(producer_name, consumer_name)
    queue_actor.set_backend_config.remote(
        consumer_name, BackendConfig(max_batch_size=2).__dict__)

    a_query_param = RequestMetadata(
        producer_name, context.TaskContext.Python, call_method="a")
    b_query_param = RequestMetadata(
        producer_name, context.TaskContext.Python, call_method="b")

    # Enqueue order matters: both "a" requests first, then both "b" requests.
    futures = [
        queue_actor.enqueue_request.remote(param)
        for param in (a_query_param, b_query_param)
        for _ in range(2)
    ]

    # Fetching starts only after all four requests have been submitted.
    runner._ray_serve_fetch.remote()

    gathered = await asyncio.gather(*futures)
    assert gathered == ["a-0", "a-1", "b-0", "b-1"]