mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 07:53:50 +08:00
[Serve] Add async, multi methods support for serve actors (#7682)
This commit is contained in:
@@ -17,8 +17,12 @@ from ray.serve.utils import logger
|
||||
|
||||
|
||||
class Query:
|
||||
def __init__(self, request_args, request_kwargs, request_context,
|
||||
request_slo_ms):
|
||||
def __init__(self,
|
||||
request_args,
|
||||
request_kwargs,
|
||||
request_context,
|
||||
request_slo_ms,
|
||||
call_method="__call__"):
|
||||
self.request_args = request_args
|
||||
self.request_kwargs = request_kwargs
|
||||
self.request_context = request_context
|
||||
@@ -29,6 +33,8 @@ class Query:
|
||||
# absolute time since unix epoch.
|
||||
self.request_slo_ms = request_slo_ms
|
||||
|
||||
self.call_method = call_method
|
||||
|
||||
def ray_serialize(self):
|
||||
# NOTE: this method is needed because Query needs to be serialized and
|
||||
# sent to the replica worker. However, after we send the query to
|
||||
@@ -175,8 +181,12 @@ class CentralizedQueues:
|
||||
else:
|
||||
request_slo_ms = request_in_object.adjust_relative_slo_ms()
|
||||
request_context = request_in_object.request_context
|
||||
query = Query(request_args, request_kwargs, request_context,
|
||||
request_slo_ms)
|
||||
query = Query(
|
||||
request_args,
|
||||
request_kwargs,
|
||||
request_context,
|
||||
request_slo_ms,
|
||||
call_method=request_in_object.call_method)
|
||||
await self.service_queues[service].put(query)
|
||||
await self.flush()
|
||||
|
||||
@@ -298,8 +308,15 @@ class CentralizedQueues:
|
||||
requests = [
|
||||
buffer_queue.pop(0) for _ in range(real_batch_size)
|
||||
]
|
||||
future = worker._ray_serve_call.remote(requests).as_future()
|
||||
future.add_done_callback(
|
||||
_make_future_unwrapper(
|
||||
client_futures=[req.async_future for req in requests],
|
||||
host_future=future))
|
||||
|
||||
# split requests by method type
|
||||
requests_group = defaultdict(list)
|
||||
for request in requests:
|
||||
requests_group[request.call_method].append(request)
|
||||
|
||||
for group in requests_group.values():
|
||||
future = worker._ray_serve_call.remote(group).as_future()
|
||||
future.add_done_callback(
|
||||
_make_future_unwrapper(
|
||||
client_futures=[req.async_future for req in group],
|
||||
host_future=future))
|
||||
|
||||
@@ -19,12 +19,14 @@ class RequestMetadata:
|
||||
service,
|
||||
request_context,
|
||||
relative_slo_ms=None,
|
||||
absolute_slo_ms=None):
|
||||
absolute_slo_ms=None,
|
||||
call_method="__call__"):
|
||||
|
||||
self.service = service
|
||||
self.request_context = request_context
|
||||
self.relative_slo_ms = relative_slo_ms
|
||||
self.absolute_slo_ms = absolute_slo_ms
|
||||
self.call_method = call_method
|
||||
|
||||
def adjust_relative_slo_ms(self) -> float:
|
||||
"""Normalize the input latency objective to an absolute timestamp.
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import time
|
||||
import traceback
|
||||
import inspect
|
||||
|
||||
import ray
|
||||
from ray.serve import context as serve_context
|
||||
@@ -7,6 +8,7 @@ from ray.serve.context import FakeFlaskRequest
|
||||
from collections import defaultdict
|
||||
from ray.serve.utils import parse_request_item
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.async_compat import sync_to_async
|
||||
|
||||
|
||||
class TaskRunner:
|
||||
@@ -33,6 +35,13 @@ def wrap_to_ray_error(exception):
|
||||
return ray.exceptions.RayTaskError(str(e), traceback_str, e.__class__)
|
||||
|
||||
|
||||
def ensure_async(func):
    """Return ``func`` in a form that can be awaited when called.

    A callable that ``inspect.iscoroutinefunction`` reports as a coroutine
    function is handed back unchanged; anything else is wrapped with
    ``sync_to_async``.
    """
    already_async = inspect.iscoroutinefunction(func)
    return func if already_async else sync_to_async(func)
|
||||
|
||||
|
||||
class RayServeMixin:
|
||||
"""This mixin class adds the functionality to fetch from router queues.
|
||||
|
||||
@@ -94,13 +103,25 @@ class RayServeMixin:
|
||||
self._ray_serve_dequeue_requester_name,
|
||||
self._ray_serve_self_handle)
|
||||
|
||||
def invoke_single(self, request_item):
|
||||
def _ray_serve_get_runner_method(self, request_item):
    """Look up the backend method named by ``request_item.call_method``.

    Args:
        request_item: Request object exposing a ``call_method`` attribute
            that holds the name of the method to invoke.

    Returns:
        The attribute of ``self`` carrying that name.

    Raises:
        RayServeException: If ``self`` has no attribute with that name.
            The message lists ``dir(self)`` to help the caller spot the
            intended method.
    """
    method_name = request_item.call_method
    if not hasattr(self, method_name):
        raise RayServeException("Backend doesn't have method {} "
                                "which is specified in the request. "
                                "The available methods are {}".format(
                                    method_name, dir(self)))

    return getattr(self, method_name)
|
||||
|
||||
async def invoke_single(self, request_item):
|
||||
args, kwargs, is_web_context = parse_request_item(request_item)
|
||||
serve_context.web = is_web_context
|
||||
start_timestamp = time.time()
|
||||
|
||||
try:
|
||||
result = self.__call__(*args, **kwargs)
|
||||
result = await ensure_async(
|
||||
self._ray_serve_get_runner_method(request_item))(*args,
|
||||
**kwargs)
|
||||
except Exception as e:
|
||||
result = wrap_to_ray_error(e)
|
||||
self._serve_metric_error_counter += 1
|
||||
@@ -108,7 +129,7 @@ class RayServeMixin:
|
||||
self._serve_metric_latency_list.append(time.time() - start_timestamp)
|
||||
return result
|
||||
|
||||
def invoke_batch(self, request_item_list):
|
||||
async def invoke_batch(self, request_item_list):
|
||||
# TODO(alind) : create no-http services. The enqueues
|
||||
# from such services will always be TaskContext.Python.
|
||||
|
||||
@@ -127,11 +148,14 @@ class RayServeMixin:
|
||||
kwargs_list = defaultdict(list)
|
||||
context_flags = set()
|
||||
batch_size = len(request_item_list)
|
||||
call_methods = set()
|
||||
|
||||
for item in request_item_list:
|
||||
args, kwargs, is_web_context = parse_request_item(item)
|
||||
context_flags.add(is_web_context)
|
||||
|
||||
call_methods.add(self._ray_serve_get_runner_method(item))
|
||||
|
||||
if is_web_context:
|
||||
# Python context only has kwargs
|
||||
flask_request = args[0]
|
||||
@@ -153,14 +177,20 @@ class RayServeMixin:
|
||||
raise RayServeException(
|
||||
"Batched queries contain mixed context. Please only send "
|
||||
"the same type of requests in batching mode.")
|
||||
|
||||
serve_context.web = context_flags.pop()
|
||||
|
||||
if len(call_methods) != 1:
|
||||
raise RayServeException(
|
||||
"Queries contain mixed calling methods. Please only send "
|
||||
"the same type of requests in batching mode.")
|
||||
call_method = ensure_async(call_methods.pop())
|
||||
|
||||
serve_context.batch_size = batch_size
|
||||
# Flask requests are passed to __call__ as a list
|
||||
arg_list = [arg_list]
|
||||
|
||||
start_timestamp = time.time()
|
||||
result_list = self.__call__(*arg_list, **kwargs_list)
|
||||
result_list = await call_method(*arg_list, **kwargs_list)
|
||||
|
||||
self._serve_metric_latency_list.append(time.time() -
|
||||
start_timestamp)
|
||||
@@ -177,13 +207,13 @@ class RayServeMixin:
|
||||
self._serve_metric_error_counter += batch_size
|
||||
return [wrapped_exception for _ in range(batch_size)]
|
||||
|
||||
def _ray_serve_call(self, request):
|
||||
async def _ray_serve_call(self, request):
|
||||
# check if request is a list or not
|
||||
# if it is a list, then batching is supported
|
||||
if not isinstance(request, list):
|
||||
result = self.invoke_single(request)
|
||||
result = await self.invoke_single(request)
|
||||
else:
|
||||
result = self.invoke_batch(request)
|
||||
result = await self.invoke_batch(request)
|
||||
|
||||
# re-assign to default values
|
||||
serve_context.web = False
|
||||
|
||||
@@ -9,7 +9,7 @@ from ray import serve
|
||||
def test_new_driver(serve_instance):
|
||||
script = """
|
||||
import ray
|
||||
ray.init(address="auto")
|
||||
ray.init(address="{}")
|
||||
|
||||
from ray import serve
|
||||
serve.init()
|
||||
@@ -17,7 +17,7 @@ serve.init()
|
||||
@serve.route("/driver")
|
||||
def driver(flask_request):
|
||||
return "OK!"
|
||||
"""
|
||||
""".format(ray.worker._global_node._redis_address)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
|
||||
path = f.name
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
from ray import serve
|
||||
import ray.serve.context as context
|
||||
from ray.serve.policy import RoundRobinPolicyQueueActor
|
||||
from ray.serve.task_runner import (RayServeMixin, TaskRunner, TaskRunnerActor,
|
||||
wrap_to_ray_error)
|
||||
from ray.serve.request_params import RequestMetadata
|
||||
from ray.serve.backend_config import BackendConfig
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
@@ -97,3 +101,83 @@ async def test_task_runner_check_context(serve_instance):
|
||||
|
||||
with pytest.raises(ray.exceptions.RayTaskError):
|
||||
await result_oid
|
||||
|
||||
|
||||
async def test_task_runner_custom_method_single(serve_instance):
    """Check per-request method dispatch on a non-batching backend.

    Requests whose metadata names method ``a`` or ``b`` must come back with
    that method's return value, and a request naming a method the backend
    does not define must surface as ``ray.exceptions.RayTaskError``.
    """
    queue_actor = RoundRobinPolicyQueueActor.remote()

    class NonBatcher:
        def a(self, _):
            return "a"

        def b(self, _):
            return "b"

    @ray.remote
    class CustomActor(NonBatcher, RayServeMixin):
        pass

    consumer = "runner"
    producer = "producer"

    worker = CustomActor.remote()
    worker._ray_serve_setup.remote(consumer, queue_actor, worker)
    worker._ray_serve_fetch.remote()

    queue_actor.link.remote(producer, consumer)

    # Each existing method simply returns its own name.
    for method_name in ("a", "b"):
        metadata = RequestMetadata(
            producer, context.TaskContext.Python, call_method=method_name)
        outcome = await queue_actor.enqueue_request.remote(metadata)
        assert outcome == method_name

    missing = RequestMetadata(
        producer, context.TaskContext.Python, call_method="non_exist")
    with pytest.raises(ray.exceptions.RayTaskError):
        await queue_actor.enqueue_request.remote(missing)
|
||||
|
||||
|
||||
async def test_task_runner_custom_method_batch(serve_instance):
    """Check per-request method dispatch on a batching backend.

    With ``max_batch_size=2`` configured, two requests for method ``a`` and
    two for method ``b`` are enqueued before the runner starts fetching; the
    gathered results must be ``["a-0", "a-1", "b-0", "b-1"]``.
    """
    queue_actor = RoundRobinPolicyQueueActor.remote()

    @serve.accept_batch
    class Batcher:
        def a(self, _):
            return ["a-{}".format(i) for i in range(serve.context.batch_size)]

        def b(self, _):
            return ["b-{}".format(i) for i in range(serve.context.batch_size)]

    @ray.remote
    class CustomActor(Batcher, RayServeMixin):
        pass

    consumer = "runner"
    producer = "producer"

    worker = CustomActor.remote()
    worker._ray_serve_setup.remote(consumer, queue_actor, worker)

    queue_actor.link.remote(producer, consumer)
    queue_actor.set_backend_config.remote(
        consumer, BackendConfig(max_batch_size=2).__dict__)

    a_metadata = RequestMetadata(
        producer, context.TaskContext.Python, call_method="a")
    b_metadata = RequestMetadata(
        producer, context.TaskContext.Python, call_method="b")

    # Enqueue in the order a, a, b, b before the runner begins fetching.
    pending = [
        queue_actor.enqueue_request.remote(metadata)
        for metadata in (a_metadata, a_metadata, b_metadata, b_metadata)
    ]

    worker._ray_serve_fetch.remote()

    results = await asyncio.gather(*pending)
    assert results == ["a-0", "a-1", "b-0", "b-1"]
|
||||
|
||||
Reference in New Issue
Block a user