[serve] Don't use mixin class for class-based backends (#7957)

This commit is contained in:
Edward Oakes
2020-04-10 12:01:14 -05:00
committed by GitHub
parent 31b40b00f6
commit d8f5b52265
8 changed files with 190 additions and 203 deletions
+15 -31
View File
@@ -10,7 +10,6 @@ from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
from ray.serve.master import ServeMaster
from ray.serve.handle import RayServeHandle
from ray.serve.kv_store_service import SQLiteKVStore
from ray.serve.task_runner import RayServeMixin, TaskRunnerActor
from ray.serve.utils import block_until_http_ready
from ray.serve.exceptions import RayServeException, batch_annotation_not_found
from ray.serve.backend_config import BackendConfig
@@ -208,8 +207,8 @@ def create_backend(func_or_class,
"""Create a backend using func_or_class and assign backend_tag.
Args:
func_or_class (callable, class): a function or a class implements
__call__ protocol.
func_or_class (callable, class): a function or a class implementing
__call__.
backend_tag (str): a unique tag assign to this backend. It will be used
to associate services in traffic policy.
backend_config (BackendConfig): An object defining backend properties
@@ -224,41 +223,26 @@ def create_backend(func_or_class,
BackendConfig), ("backend_config must be"
" of instance BackendConfig")
# Make sure the batch size is correct
# Validate that func_or_class is a function or class.
if inspect.isfunction(func_or_class):
if len(actor_init_args) != 0:
raise ValueError(
"actor_init_args not supported for function backend.")
elif not inspect.isclass(func_or_class):
raise ValueError(
"Backend must be a function or class, it is {}.".format(
type(func_or_class)))
# Make sure the batch size is correct.
should_accept_batch = backend_config.max_batch_size is not None
if should_accept_batch and not _backend_accept_batch(func_or_class):
raise batch_annotation_not_found
if _backend_accept_batch(func_or_class):
backend_config.has_accept_batch_annotation = True
arg_list = []
if inspect.isfunction(func_or_class):
# arg list for a fn is function itself
arg_list = [func_or_class]
# ignore lint on lambda expression
creator = lambda kwrgs: TaskRunnerActor._remote(**kwrgs) # noqa: E731
elif inspect.isclass(func_or_class):
# Python inheritance order is right-to-left. We put RayServeMixin
# on the left to make sure its methods are not overriden.
@ray.remote
class CustomActor(RayServeMixin, func_or_class):
@wraps(func_or_class.__init__)
def __init__(self, *args, **kwargs):
# Initialize serve so it can be used in backends.
init()
super().__init__(*args, **kwargs)
arg_list = actor_init_args
# ignore lint on lambda expression
creator = lambda kwargs: CustomActor._remote(**kwargs) # noqa: E731
else:
raise TypeError(
"Backend must be a function or class, it is {}.".format(
type(func_or_class)))
ray.get(
master_actor.create_backend.remote(backend_tag, creator,
backend_config, arg_list))
master_actor.create_backend.remote(backend_tag, backend_config,
func_or_class, actor_init_args))
@_ensure_connected
@@ -12,23 +12,69 @@ from ray.serve.exceptions import RayServeException
from ray.async_compat import sync_to_async
class TaskRunner:
"""A simple class that runs a function.
def create_backend_worker(func_or_class):
if inspect.isfunction(func_or_class):
is_function = True
elif inspect.isclass(func_or_class):
is_function = False
else:
assert False, "func_or_class must be function or class."
The purpose of this class is to model what the most basic actor could be.
That is, a ray serve actor should implement the TaskRunner interface.
"""
class RayServeWrappedWorker(object):
def __init__(self,
backend_tag,
replica_tag,
init_args,
self_handle=None,
router_handle=None,
start_running=True):
serve.init()
if is_function:
_callable = func_or_class
else:
_callable = func_or_class(*init_args)
def __init__(self, func_to_run):
serve.init()
if self_handle is None:
assert router_handle is None
master_actor = serve.api._get_master_actor()
# TODO(edoakes): this is a hacky workaround because there is
# a race condition when the master starts up a worker: the
# master starts the worker then adds its handle to a local map
# and the worker queries the master for its own handle from
# that map. If there's a large enough delay in the master, the
# handle may not have been added to the map yet (seen in CI).
# This will be fixed soon when the router just pushes tasks to
# the workers instead of the workers indicating that they're
# available.
start = time.time()
while time.time() - start < 5:
try:
print("Calling get_backend_replica_config")
[self_handle], [router_handle] = ray.get(
master_actor.get_backend_replica_config.remote(
replica_tag))
print("Got get_backend_replica_config")
break
except ray.exceptions.RayTaskError:
pass
self.func = func_to_run
self.backend = RayServeWorker(backend_tag, _callable, self_handle,
router_handle, is_function)
# This parameter let argument inspection work with inner function.
self.__wrapped__ = func_to_run
if start_running:
self.backend.mark_idle_in_router()
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)
def get_metrics(self):
return self.backend.get_metrics()
async def handle_request(self, request):
return await self.backend.handle_request(request)
def ready(self):
pass
RayServeWrappedWorker.__name__ = "RayServeWorker_" + func_or_class.__name__
return RayServeWrappedWorker
def wrap_to_ray_error(exception):
@@ -48,108 +94,78 @@ def ensure_async(func):
return sync_to_async(func)
class RayServeMixin:
"""This mixin class adds the functionality to fetch from router queues.
class RayServeWorker:
"""Fetches requests and handles them with the provided callable."""
Warning:
It assumes the main execution method is `__call__` of the user defined
class. This means that serve will call `your_instance.__call__` when
each request comes in. This behavior will be fixed in the future to
allow assigning artibrary methods.
def __init__(self, name, _callable, self_handle, router_handle,
is_function):
self.name = name
self.callable = _callable
self.self_handle = self_handle
self.router_handle = router_handle
self.is_function = is_function
Example:
>>> # Use ray.remote decorator and RayServeMixin
>>> # to make MyClass servable.
>>> @ray.remote
class RayServeActor(RayServeMixin, MyClass):
pass
"""
_ray_serve_self_handle = None
_ray_serve_router_handle = None
_ray_serve_setup_completed = False
_ray_serve_dequeue_requester_name = None
self.error_counter = 0
self.latency_list = []
# Work token can be unfullfilled from last iteration.
# This cache will be used to determine whether or not we should
# work on the same task as previous iteration or we are ready to
# move on.
_ray_serve_cached_work_token = None
_serve_metric_error_counter = 0
_serve_metric_latency_list = []
def _serve_metric(self):
def get_metrics(self):
# Make a copy of the latency list and clear current list
latency_lst = self._serve_metric_latency_list[:]
self._serve_metric_latency_list = []
my_name = self._ray_serve_dequeue_requester_name
latency_list = self.latency_list[:]
self.latency_list = []
return {
"{}_error_counter".format(my_name): {
"value": self._serve_metric_error_counter,
"{}_error_counter".format(self.name): {
"value": self.error_counter,
"type": "counter",
},
"{}_latency_s".format(my_name): {
"value": latency_lst,
"{}_latency_s".format(self.name): {
"value": latency_list,
"type": "list",
},
}
def _ray_serve_setup(self, my_name, router_handle, my_handle):
self._ray_serve_dequeue_requester_name = my_name
self._ray_serve_router_handle = router_handle
self._ray_serve_self_handle = my_handle
self._ray_serve_setup_completed = True
def mark_idle_in_router(self):
# Tell the router that this worker can accept tasks.
self.router_handle.dequeue_request.remote(self.name, self.self_handle)
def _ray_serve_fetch(self):
assert self._ray_serve_setup_completed
self._ray_serve_router_handle.dequeue_request.remote(
self._ray_serve_dequeue_requester_name,
self._ray_serve_self_handle)
def _ray_serve_get_runner_method(self, request_item):
def get_runner_method(self, request_item):
method_name = request_item.call_method
if not hasattr(self, method_name):
if not hasattr(self.callable, method_name):
raise RayServeException("Backend doesn't have method {} "
"which is specified in the request. "
"The avaiable methods are {}".format(
method_name, dir(self)))
return getattr(self, method_name)
return getattr(self.callable, method_name)
def _ray_serve_count_num_positional(self, f):
def has_positional_args(self, f):
# NOTE:
# In the case of simple functions, not actors, the f will be
# a TaskRunner.__call__. What we really want here is the wrapped
# functionso inspect.signature will figure out the underlying f.
if hasattr(self, "__wrapped__"):
f = self.__wrapped__
# function.__call__, but we need to inspect the function itself.
if self.is_function:
f = self.callable
signature = inspect.signature(f)
counter = 0
for param in signature.parameters.values():
if (param.kind == param.POSITIONAL_OR_KEYWORD
and param.default is param.empty):
counter += 1
return counter
return True
return False
async def invoke_single(self, request_item):
args, kwargs, is_web_context = parse_request_item(request_item)
serve_context.web = is_web_context
start_timestamp = time.time()
method_to_call = self._ray_serve_get_runner_method(request_item)
args = args if self._ray_serve_count_num_positional(
method_to_call) else []
method_to_call = self.get_runner_method(request_item)
args = args if self.has_positional_args(method_to_call) else []
method_to_call = ensure_async(method_to_call)
try:
result = await method_to_call(*args, **kwargs)
except Exception as e:
result = wrap_to_ray_error(e)
self._serve_metric_error_counter += 1
self.error_counter += 1
self._serve_metric_latency_list.append(time.time() - start_timestamp)
self.latency_list.append(time.time() - start_timestamp)
return result
async def invoke_batch(self, request_item_list):
@@ -177,7 +193,7 @@ class RayServeMixin:
args, kwargs, is_web_context = parse_request_item(item)
context_flags.add(is_web_context)
call_method = self._ray_serve_get_runner_method(item)
call_method = self.get_runner_method(item)
call_methods.add(call_method)
if is_web_context:
@@ -191,13 +207,12 @@ class RayServeMixin:
# Set the flask request as a list to conform
# with batching semantics: when in batching
# mode, each argument it turned into list.
if self._ray_serve_count_num_positional(call_method):
# mode, each argument is turned into list.
if self.has_positional_args(call_method):
arg_list.append(FakeFlaskRequest())
try:
# check mixing of query context
# unified context needed
# Check mixing of query context (unified context needed).
if len(context_flags) != 1:
raise RayServeException(
"Batched queries contain mixed context. Please only send "
@@ -217,8 +232,7 @@ class RayServeMixin:
start_timestamp = time.time()
result_list = await call_method(*arg_list, **kwargs_list)
self._serve_metric_latency_list.append(time.time() -
start_timestamp)
self.latency_list.append(time.time() - start_timestamp)
if (not isinstance(result_list,
list)) or (len(result_list) != batch_size):
raise RayServeException("__call__ function "
@@ -229,10 +243,10 @@ class RayServeMixin:
return result_list
except Exception as e:
wrapped_exception = wrap_to_ray_error(e)
self._serve_metric_error_counter += batch_size
self.error_counter += batch_size
return [wrapped_exception for _ in range(batch_size)]
async def _ray_serve_call(self, request):
async def handle_request(self, request):
# check if work_item is a list or not
# if it is list: then batching supported
if not isinstance(request, list):
@@ -243,28 +257,6 @@ class RayServeMixin:
# re-assign to default values
serve_context.web = False
serve_context.batch_size = None
# Tell router that current actor is idle
self._ray_serve_fetch()
self.mark_idle_in_router()
return result
class TaskRunnerBackend(TaskRunner, RayServeMixin):
"""A simple function serving backend
Note that this is not yet an actor. To make it an actor:
>>> @ray.remote
class TaskRunnerActor(TaskRunnerBackend):
pass
Note:
This class is not used in the actual ray serve system. It exists
for documentation purpose.
"""
@ray.remote
class TaskRunnerActor(TaskRunnerBackend):
pass
+29 -17
View File
@@ -6,6 +6,7 @@ from ray.serve.http_proxy import HTTPProxyActor
from ray.serve.kv_store_service import (BackendTable, RoutingTable,
TrafficPolicyTable)
from ray.serve.metric import (MetricMonitor, start_metric_monitor_loop)
from ray.serve.backend_worker import create_backend_worker
from ray.serve.utils import expand, get_random_letters
import numpy as np
@@ -104,32 +105,42 @@ class ServeMaster:
for _ in range(-delta_num_replicas):
self._remove_backend_replica(backend_tag)
async def get_backend_replica_config(self, replica_tag):
return [self.tag_to_actor_handles[replica_tag]], self.get_router()
async def _start_backend_replica(self, backend_tag):
assert (backend_tag in self.backend_table.list_backends()
), "Backend {} is not registered.".format(backend_tag)
replica_tag = "{}#{}".format(backend_tag, get_random_letters(length=6))
# Register the worker in the DB.
# TODO(edoakes): we should guarantee that if calls to the master
# succeed, the cluster state has changed and if they fail, it hasn't.
# Once we have master actor fault tolerance, this breaks that guarantee
# because this method could fail after writing the replica to the DB.
self.backend_table.add_replica(backend_tag, replica_tag)
# Fetch the info to start the replica from the backend table.
creator = self.backend_table.get_backend_creator(backend_tag)
backend_actor = ray.remote(
self.backend_table.get_backend_creator(backend_tag))
backend_config_dict = self.backend_table.get_info(backend_tag)
backend_config = BackendConfig(**backend_config_dict)
init_args = self.backend_table.get_init_args(backend_tag)
init_args = [
backend_tag, replica_tag,
self.backend_table.get_init_args(backend_tag)
]
kwargs = backend_config.get_actor_creation_args(init_args)
runner_handle = creator(kwargs)
self.tag_to_actor_handles[replica_tag] = runner_handle
# Start the worker.
worker_handle = backend_actor._remote(**kwargs)
self.tag_to_actor_handles[replica_tag] = worker_handle
# Set up the worker.
# Wait for the worker to start up.
await worker_handle.ready.remote()
await runner_handle._ray_serve_setup.remote(backend_tag,
self.get_router()[0],
runner_handle)
ray.get(runner_handle._ray_serve_fetch.remote())
# Register the worker in config tables and metric monitor.
self.backend_table.add_replica(backend_tag, replica_tag)
self.get_metric_monitor()[0].add_target.remote(runner_handle)
# Register the worker with the metric monitor.
self.get_metric_monitor()[0].add_target.remote(worker_handle)
def _remove_backend_replica(self, backend_tag):
assert (backend_tag in self.backend_table.list_backends()
@@ -191,18 +202,19 @@ class ServeMaster:
self.route_table.list_service(
include_methods=True, include_headless=False)))
async def create_backend(self, backend_tag, creator, backend_config,
arg_list):
async def create_backend(self, backend_tag, backend_config, func_or_class,
actor_init_args):
backend_config_dict = dict(backend_config)
backend_worker = create_backend_worker(func_or_class)
# Save creator which starts replicas.
self.backend_table.register_backend(backend_tag, creator)
self.backend_table.register_backend(backend_tag, backend_worker)
# Save information about configurations needed to start the replicas.
self.backend_table.register_info(backend_tag, backend_config_dict)
# Save the initial arguments needed by replicas.
self.backend_table.save_init_args(backend_tag, arg_list)
self.backend_table.save_init_args(backend_tag, actor_init_args)
# Set the backend config inside the router
# (particularly for max-batch-size).
+1 -1
View File
@@ -44,7 +44,7 @@ class MetricMonitor:
curr_time = time.time()
result = [
handle._serve_metric.remote()
handle.get_metrics.remote()
for handle in self.actor_handles.values()
]
# TODO(simon): handle the possibility that an actor_handle is removed
+4 -4
View File
@@ -85,7 +85,7 @@ def _make_future_unwrapper(client_futures: List[asyncio.Future],
class CentralizedQueues:
"""A router that routes request to available workers.
Router aceepts each request from the `enqueue_request` method and enqueues
Router accepts each request from the `enqueue_request` method and enqueues
it. It also accepts worker request to work (called work_intention in code)
from workers via the `dequeue_request` method. The traffic policy is used
to match requests with their corresponding workers.
@@ -161,7 +161,7 @@ class CentralizedQueues:
def is_ready(self):
return True
def _serve_metric(self):
def get_metrics(self):
return {
"backend_{}_queue_size".format(backend_name): {
"value": len(queue),
@@ -302,7 +302,7 @@ class CentralizedQueues:
worker = await worker_queue.get()
if max_batch_size is None: # No batching
request = buffer_queue.pop(0)
future = worker._ray_serve_call.remote(request).as_future()
future = worker.handle_request.remote(request).as_future()
# chaining satisfies request.async_future with future result.
asyncio.futures._chain_future(future, request.async_future)
else:
@@ -317,7 +317,7 @@ class CentralizedQueues:
requests_group[request.call_method].append(request)
for group in requests_group.values():
future = worker._ray_serve_call.remote(group).as_future()
future = worker.handle_request.remote(group).as_future()
future.add_done_callback(
_make_future_unwrapper(
client_futures=[req.async_future for req in group],
@@ -6,20 +6,40 @@ import ray
from ray import serve
import ray.serve.context as context
from ray.serve.policy import RoundRobinPolicyQueueActor
from ray.serve.task_runner import (RayServeMixin, TaskRunner, TaskRunnerActor,
wrap_to_ray_error)
from ray.serve.backend_worker import create_backend_worker, wrap_to_ray_error
from ray.serve.request_params import RequestMetadata
from ray.serve.backend_config import BackendConfig
pytestmark = pytest.mark.asyncio
async def test_runner_basic():
def echo(i):
return i
def setup_worker(name, func_or_class, router_handle, init_args=None):
if init_args is None:
init_args = ()
r = TaskRunner(echo)
assert r(1) == 1
@ray.remote
class WorkerActor:
def setup(self, self_handle, router_handle):
self.worker = create_backend_worker(func_or_class)(
name,
name + ":tag",
init_args,
self_handle=self_handle[0],
router_handle=router_handle[0],
start_running=False)
def get_metrics(self):
return self.worker.get_metrics()
def run(self):
self.worker.backend.mark_idle_in_router()
async def handle_request(self, *args, **kwargs):
return await self.worker.handle_request(*args, **kwargs)
worker = WorkerActor.remote()
ray.get(worker.setup.remote([worker], [router_handle]))
return worker
async def test_runner_wraps_error():
@@ -36,9 +56,8 @@ async def test_runner_actor(serve_instance):
CONSUMER_NAME = "runner"
PRODUCER_NAME = "prod"
runner = TaskRunnerActor.remote(echo)
ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner))
runner._ray_serve_fetch.remote()
worker = setup_worker(CONSUMER_NAME, echo, q)
await worker.run.remote()
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
@@ -62,14 +81,8 @@ async def test_ray_serve_mixin(serve_instance):
def __call__(self, flask_request, i=None):
return i + self.increment
@ray.remote
class CustomActor(MyAdder, RayServeMixin):
pass
runner = CustomActor.remote(3)
ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner))
runner._ray_serve_fetch.remote()
worker = setup_worker(CONSUMER_NAME, MyAdder, q, init_args=(3, ))
await worker.run.remote()
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
@@ -90,10 +103,8 @@ async def test_task_runner_check_context(serve_instance):
CONSUMER_NAME = "runner"
PRODUCER_NAME = "producer"
runner = TaskRunnerActor.remote(echo)
ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner))
runner._ray_serve_fetch.remote()
worker = setup_worker(CONSUMER_NAME, echo, q)
await worker.run.remote()
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
@@ -113,17 +124,11 @@ async def test_task_runner_custom_method_single(serve_instance):
def b(self, _):
return "b"
@ray.remote
class CustomActor(NonBatcher, RayServeMixin):
pass
CONSUMER_NAME = "runner"
PRODUCER_NAME = "producer"
runner = CustomActor.remote()
ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner))
runner._ray_serve_fetch.remote()
worker = setup_worker(CONSUMER_NAME, NonBatcher, q)
await worker.run.remote()
q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
@@ -154,16 +159,10 @@ async def test_task_runner_custom_method_batch(serve_instance):
def b(self, _):
return ["b-{}".format(i) for i in range(serve.context.batch_size)]
@ray.remote
class CustomActor(Batcher, RayServeMixin):
pass
CONSUMER_NAME = "runner"
PRODUCER_NAME = "producer"
runner = CustomActor.remote()
ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner))
worker = setup_worker(CONSUMER_NAME, Batcher, q)
await q.link.remote(PRODUCER_NAME, CONSUMER_NAME)
await q.set_backend_config.remote(
@@ -177,7 +176,7 @@ async def test_task_runner_custom_method_batch(serve_instance):
futures = [q.enqueue_request.remote(a_query_param) for _ in range(2)]
futures += [q.enqueue_request.remote(b_query_param) for _ in range(2)]
await runner._ray_serve_fetch.remote()
await worker.run.remote()
gathered = await asyncio.gather(*futures)
assert set(gathered) == {"a-0", "a-1", "b-0", "b-1"}
+1 -1
View File
@@ -12,7 +12,7 @@ def start_target_actor(ray_instance):
def __init__(self):
self.counter_value = 0
def _serve_metric(self):
def get_metrics(self):
self.counter_value += 1
return {
"latency_list": {
+3 -3
View File
@@ -18,9 +18,9 @@ def make_task_runner_mock():
self.query = None
self.queries = []
async def _ray_serve_call(self, request_item):
self.query = request_item
self.queries.append(request_item)
async def handle_request(self, request):
self.query = request
self.queries.append(request)
return "DONE"
def get_recent_call(self):