From d8f5b522655f2b848ba20817dd18d6e4aefb8a42 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Fri, 10 Apr 2020 12:01:14 -0500 Subject: [PATCH] [serve] Don't use mixin class for class-based backends (#7957) --- python/ray/serve/api.py | 46 ++-- .../{task_runner.py => backend_worker.py} | 208 +++++++++--------- python/ray/serve/master.py | 46 ++-- python/ray/serve/metric.py | 2 +- python/ray/serve/queues.py | 8 +- ..._task_runner.py => test_backend_worker.py} | 75 ++++--- python/ray/serve/tests/test_metric.py | 2 +- python/ray/serve/tests/test_queue.py | 6 +- 8 files changed, 190 insertions(+), 203 deletions(-) rename python/ray/serve/{task_runner.py => backend_worker.py} (54%) rename python/ray/serve/tests/{test_task_runner.py => test_backend_worker.py} (75%) diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 2f8652dab..53e01457e 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -10,7 +10,6 @@ from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT, from ray.serve.master import ServeMaster from ray.serve.handle import RayServeHandle from ray.serve.kv_store_service import SQLiteKVStore -from ray.serve.task_runner import RayServeMixin, TaskRunnerActor from ray.serve.utils import block_until_http_ready from ray.serve.exceptions import RayServeException, batch_annotation_not_found from ray.serve.backend_config import BackendConfig @@ -208,8 +207,8 @@ def create_backend(func_or_class, """Create a backend using func_or_class and assign backend_tag. Args: - func_or_class (callable, class): a function or a class implements - __call__ protocol. + func_or_class (callable, class): a function or a class implementing + __call__. backend_tag (str): a unique tag assign to this backend. It will be used to associate services in traffic policy. backend_config (BackendConfig): An object defining backend properties @@ -224,41 +223,26 @@ def create_backend(func_or_class, BackendConfig), ("backend_config must be" " of instance BackendConfig") - # Make sure the batch size is correct + # Validate that func_or_class is a function or class. + if inspect.isfunction(func_or_class): + if len(actor_init_args) != 0: + raise ValueError( + "actor_init_args not supported for function backend.") + elif not inspect.isclass(func_or_class): + raise ValueError( + "Backend must be a function or class, it is {}.".format( + type(func_or_class))) + + # Make sure the batch size is correct. should_accept_batch = backend_config.max_batch_size is not None if should_accept_batch and not _backend_accept_batch(func_or_class): raise batch_annotation_not_found if _backend_accept_batch(func_or_class): backend_config.has_accept_batch_annotation = True - arg_list = [] - if inspect.isfunction(func_or_class): - # arg list for a fn is function itself - arg_list = [func_or_class] - # ignore lint on lambda expression - creator = lambda kwrgs: TaskRunnerActor._remote(**kwrgs) # noqa: E731 - elif inspect.isclass(func_or_class): - # Python inheritance order is right-to-left. We put RayServeMixin - # on the left to make sure its methods are not overriden. - @ray.remote - class CustomActor(RayServeMixin, func_or_class): - @wraps(func_or_class.__init__) - def __init__(self, *args, **kwargs): - # Initialize serve so it can be used in backends. - init() - super().__init__(*args, **kwargs) - - arg_list = actor_init_args - # ignore lint on lambda expression - creator = lambda kwargs: CustomActor._remote(**kwargs) # noqa: E731 - else: - raise TypeError( - "Backend must be a function or class, it is {}.".format( - type(func_or_class))) - ray.get( - master_actor.create_backend.remote(backend_tag, creator, - backend_config, arg_list)) + master_actor.create_backend.remote(backend_tag, backend_config, + func_or_class, actor_init_args)) @_ensure_connected diff --git a/python/ray/serve/task_runner.py b/python/ray/serve/backend_worker.py similarity index 54% rename from python/ray/serve/task_runner.py rename to python/ray/serve/backend_worker.py index 8da8374b9..6d500e446 100644 --- a/python/ray/serve/task_runner.py +++ b/python/ray/serve/backend_worker.py @@ -12,23 +12,69 @@ from ray.serve.exceptions import RayServeException from ray.async_compat import sync_to_async -class TaskRunner: - """A simple class that runs a function. +def create_backend_worker(func_or_class): + if inspect.isfunction(func_or_class): + is_function = True + elif inspect.isclass(func_or_class): + is_function = False + else: + assert False, "func_or_class must be function or class." - The purpose of this class is to model what the most basic actor could be. - That is, a ray serve actor should implement the TaskRunner interface. - """ + class RayServeWrappedWorker(object): + def __init__(self, + backend_tag, + replica_tag, + init_args, + self_handle=None, + router_handle=None, + start_running=True): + serve.init() + if is_function: + _callable = func_or_class + else: + _callable = func_or_class(*init_args) - def __init__(self, func_to_run): - serve.init() + if self_handle is None: + assert router_handle is None + master_actor = serve.api._get_master_actor() + # TODO(edoakes): this is a hacky workaround because there is + # a race condition when the master starts up a worker: the + # master starts the worker then adds its handle to a local map + # and the worker queries the master for its own handle from + # that map. If there's a large enough delay in the master, the + # handle may not have been added to the map yet (seen in CI). + # This will be fixed soon when the router just pushes tasks to + # the workers instead of the workers indicating that they're + # available. + start = time.time() + while time.time() - start < 5: + try: + print("Calling get_backend_replica_config") + [self_handle], [router_handle] = ray.get( + master_actor.get_backend_replica_config.remote( + replica_tag)) + print("Got get_backend_replica_config") + break + except ray.exceptions.RayTaskError: + pass - self.func = func_to_run + self.backend = RayServeWorker(backend_tag, _callable, self_handle, + router_handle, is_function) - # This parameter let argument inspection work with inner function. - self.__wrapped__ = func_to_run + if start_running: + self.backend.mark_idle_in_router() - def __call__(self, *args, **kwargs): - return self.func(*args, **kwargs) + def get_metrics(self): + return self.backend.get_metrics() + + async def handle_request(self, request): + return await self.backend.handle_request(request) + + def ready(self): + pass + + RayServeWrappedWorker.__name__ = "RayServeWorker_" + func_or_class.__name__ + return RayServeWrappedWorker def wrap_to_ray_error(exception): @@ -48,108 +94,78 @@ def ensure_async(func): return sync_to_async(func) -class RayServeMixin: - """This mixin class adds the functionality to fetch from router queues. +class RayServeWorker: + """Fetches requests and handles them with the provided callable.""" - Warning: - It assumes the main execution method is `__call__` of the user defined - class. This means that serve will call `your_instance.__call__` when - each request comes in. This behavior will be fixed in the future to - allow assigning artibrary methods. + def __init__(self, name, _callable, self_handle, router_handle, + is_function): + self.name = name + self.callable = _callable + self.self_handle = self_handle + self.router_handle = router_handle + self.is_function = is_function - Example: - >>> # Use ray.remote decorator and RayServeMixin - >>> # to make MyClass servable. - >>> @ray.remote - class RayServeActor(RayServeMixin, MyClass): - pass - """ - _ray_serve_self_handle = None - _ray_serve_router_handle = None - _ray_serve_setup_completed = False - _ray_serve_dequeue_requester_name = None + self.error_counter = 0 + self.latency_list = [] - # Work token can be unfullfilled from last iteration. - # This cache will be used to determine whether or not we should - # work on the same task as previous iteration or we are ready to - # move on. - _ray_serve_cached_work_token = None - - _serve_metric_error_counter = 0 - _serve_metric_latency_list = [] - - def _serve_metric(self): + def get_metrics(self): # Make a copy of the latency list and clear current list - latency_lst = self._serve_metric_latency_list[:] - self._serve_metric_latency_list = [] - - my_name = self._ray_serve_dequeue_requester_name + latency_list = self.latency_list[:] + self.latency_list = [] return { - "{}_error_counter".format(my_name): { - "value": self._serve_metric_error_counter, + "{}_error_counter".format(self.name): { + "value": self.error_counter, "type": "counter", }, - "{}_latency_s".format(my_name): { - "value": latency_lst, + "{}_latency_s".format(self.name): { + "value": latency_list, "type": "list", }, } - def _ray_serve_setup(self, my_name, router_handle, my_handle): - self._ray_serve_dequeue_requester_name = my_name - self._ray_serve_router_handle = router_handle - self._ray_serve_self_handle = my_handle - self._ray_serve_setup_completed = True + def mark_idle_in_router(self): + # Tell the router that this worker can accept tasks. + self.router_handle.dequeue_request.remote(self.name, self.self_handle) - def _ray_serve_fetch(self): - assert self._ray_serve_setup_completed - - self._ray_serve_router_handle.dequeue_request.remote( - self._ray_serve_dequeue_requester_name, - self._ray_serve_self_handle) - - def _ray_serve_get_runner_method(self, request_item): + def get_runner_method(self, request_item): method_name = request_item.call_method - if not hasattr(self, method_name): + if not hasattr(self.callable, method_name): raise RayServeException("Backend doesn't have method {} " "which is specified in the request. " "The avaiable methods are {}".format( method_name, dir(self))) - return getattr(self, method_name) + return getattr(self.callable, method_name) - def _ray_serve_count_num_positional(self, f): + def has_positional_args(self, f): # NOTE: # In the case of simple functions, not actors, the f will be - # a TaskRunner.__call__. What we really want here is the wrapped - # functionso inspect.signature will figure out the underlying f. - if hasattr(self, "__wrapped__"): - f = self.__wrapped__ + # function.__call__, but we need to inspect the function itself. + if self.is_function: + f = self.callable signature = inspect.signature(f) - counter = 0 for param in signature.parameters.values(): if (param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty): - counter += 1 - return counter + return True + return False async def invoke_single(self, request_item): args, kwargs, is_web_context = parse_request_item(request_item) serve_context.web = is_web_context start_timestamp = time.time() - method_to_call = self._ray_serve_get_runner_method(request_item) - args = args if self._ray_serve_count_num_positional( - method_to_call) else [] + method_to_call = self.get_runner_method(request_item) + args = args if self.has_positional_args(method_to_call) else [] method_to_call = ensure_async(method_to_call) try: result = await method_to_call(*args, **kwargs) except Exception as e: result = wrap_to_ray_error(e) - self._serve_metric_error_counter += 1 + self.error_counter += 1 - self._serve_metric_latency_list.append(time.time() - start_timestamp) + self.latency_list.append(time.time() - start_timestamp) return result async def invoke_batch(self, request_item_list): @@ -177,7 +193,7 @@ class RayServeMixin: args, kwargs, is_web_context = parse_request_item(item) context_flags.add(is_web_context) - call_method = self._ray_serve_get_runner_method(item) + call_method = self.get_runner_method(item) call_methods.add(call_method) if is_web_context: @@ -191,13 +207,12 @@ class RayServeMixin: # Set the flask request as a list to conform # with batching semantics: when in batching - # mode, each argument it turned into list. - if self._ray_serve_count_num_positional(call_method): + # mode, each argument is turned into list. + if self.has_positional_args(call_method): arg_list.append(FakeFlaskRequest()) try: - # check mixing of query context - # unified context needed + # Check mixing of query context (unified context needed). if len(context_flags) != 1: raise RayServeException( "Batched queries contain mixed context. Please only send " @@ -217,8 +232,7 @@ class RayServeMixin: start_timestamp = time.time() result_list = await call_method(*arg_list, **kwargs_list) - self._serve_metric_latency_list.append(time.time() - - start_timestamp) + self.latency_list.append(time.time() - start_timestamp) if (not isinstance(result_list, list)) or (len(result_list) != batch_size): raise RayServeException("__call__ function " @@ -229,10 +243,10 @@ class RayServeMixin: return result_list except Exception as e: wrapped_exception = wrap_to_ray_error(e) - self._serve_metric_error_counter += batch_size + self.error_counter += batch_size return [wrapped_exception for _ in range(batch_size)] - async def _ray_serve_call(self, request): + async def handle_request(self, request): # check if work_item is a list or not # if it is list: then batching supported if not isinstance(request, list): @@ -243,28 +257,6 @@ class RayServeMixin: # re-assign to default values serve_context.web = False serve_context.batch_size = None - - # Tell router that current actor is idle - self._ray_serve_fetch() + self.mark_idle_in_router() return result - - -class TaskRunnerBackend(TaskRunner, RayServeMixin): - """A simple function serving backend - - Note that this is not yet an actor. To make it an actor: - - >>> @ray.remote - class TaskRunnerActor(TaskRunnerBackend): - pass - - Note: - This class is not used in the actual ray serve system. It exists - for documentation purpose. - """ - - -@ray.remote -class TaskRunnerActor(TaskRunnerBackend): - pass diff --git a/python/ray/serve/master.py b/python/ray/serve/master.py index 8f9af5a1c..8794161fd 100644 --- a/python/ray/serve/master.py +++ b/python/ray/serve/master.py @@ -6,6 +6,7 @@ from ray.serve.http_proxy import HTTPProxyActor from ray.serve.kv_store_service import (BackendTable, RoutingTable, TrafficPolicyTable) from ray.serve.metric import (MetricMonitor, start_metric_monitor_loop) +from ray.serve.backend_worker import create_backend_worker from ray.serve.utils import expand, get_random_letters import numpy as np @@ -104,32 +105,42 @@ class ServeMaster: for _ in range(-delta_num_replicas): self._remove_backend_replica(backend_tag) + async def get_backend_replica_config(self, replica_tag): + return [self.tag_to_actor_handles[replica_tag]], self.get_router() + async def _start_backend_replica(self, backend_tag): assert (backend_tag in self.backend_table.list_backends() ), "Backend {} is not registered.".format(backend_tag) replica_tag = "{}#{}".format(backend_tag, get_random_letters(length=6)) + # Register the worker in the DB. + # TODO(edoakes): we should guarantee that if calls to the master + # succeed, the cluster state has changed and if they fail, it hasn't. + # Once we have master actor fault tolerance, this breaks that guarantee + # because this method could fail after writing the replica to the DB. + self.backend_table.add_replica(backend_tag, replica_tag) + # Fetch the info to start the replica from the backend table. - creator = self.backend_table.get_backend_creator(backend_tag) + backend_actor = ray.remote( + self.backend_table.get_backend_creator(backend_tag)) backend_config_dict = self.backend_table.get_info(backend_tag) backend_config = BackendConfig(**backend_config_dict) - init_args = self.backend_table.get_init_args(backend_tag) + init_args = [ + backend_tag, replica_tag, + self.backend_table.get_init_args(backend_tag) + ] kwargs = backend_config.get_actor_creation_args(init_args) - runner_handle = creator(kwargs) - self.tag_to_actor_handles[replica_tag] = runner_handle + # Start the worker. + worker_handle = backend_actor._remote(**kwargs) + self.tag_to_actor_handles[replica_tag] = worker_handle - # Set up the worker. + # Wait for the worker to start up. + await worker_handle.ready.remote() - await runner_handle._ray_serve_setup.remote(backend_tag, - self.get_router()[0], - runner_handle) - ray.get(runner_handle._ray_serve_fetch.remote()) - - # Register the worker in config tables and metric monitor. - self.backend_table.add_replica(backend_tag, replica_tag) - self.get_metric_monitor()[0].add_target.remote(runner_handle) + # Register the worker with the metric monitor. + self.get_metric_monitor()[0].add_target.remote(worker_handle) def _remove_backend_replica(self, backend_tag): assert (backend_tag in self.backend_table.list_backends() @@ -191,18 +202,19 @@ class ServeMaster: self.route_table.list_service( include_methods=True, include_headless=False))) - async def create_backend(self, backend_tag, creator, backend_config, - arg_list): + async def create_backend(self, backend_tag, backend_config, func_or_class, + actor_init_args): backend_config_dict = dict(backend_config) + backend_worker = create_backend_worker(func_or_class) # Save creator which starts replicas. - self.backend_table.register_backend(backend_tag, creator) + self.backend_table.register_backend(backend_tag, backend_worker) # Save information about configurations needed to start the replicas. self.backend_table.register_info(backend_tag, backend_config_dict) # Save the initial arguments needed by replicas. - self.backend_table.save_init_args(backend_tag, arg_list) + self.backend_table.save_init_args(backend_tag, actor_init_args) # Set the backend config inside the router # (particularly for max-batch-size). diff --git a/python/ray/serve/metric.py b/python/ray/serve/metric.py index 901901489..9ce1ce7b2 100644 --- a/python/ray/serve/metric.py +++ b/python/ray/serve/metric.py @@ -44,7 +44,7 @@ class MetricMonitor: curr_time = time.time() result = [ - handle._serve_metric.remote() + handle.get_metrics.remote() for handle in self.actor_handles.values() ] # TODO(simon): handle the possibility that an actor_handle is removed diff --git a/python/ray/serve/queues.py b/python/ray/serve/queues.py index f37498096..e872b03a0 100644 --- a/python/ray/serve/queues.py +++ b/python/ray/serve/queues.py @@ -85,7 +85,7 @@ def _make_future_unwrapper(client_futures: List[asyncio.Future], class CentralizedQueues: """A router that routes request to available workers. - Router aceepts each request from the `enqueue_request` method and enqueues + Router accepts each request from the `enqueue_request` method and enqueues it. It also accepts worker request to work (called work_intention in code) from workers via the `dequeue_request` method. The traffic policy is used to match requests with their corresponding workers. @@ -161,7 +161,7 @@ class CentralizedQueues: def is_ready(self): return True - def _serve_metric(self): + def get_metrics(self): return { "backend_{}_queue_size".format(backend_name): { "value": len(queue), @@ -302,7 +302,7 @@ class CentralizedQueues: worker = await worker_queue.get() if max_batch_size is None: # No batching request = buffer_queue.pop(0) - future = worker._ray_serve_call.remote(request).as_future() + future = worker.handle_request.remote(request).as_future() # chaining satisfies request.async_future with future result. asyncio.futures._chain_future(future, request.async_future) else: @@ -317,7 +317,7 @@ class CentralizedQueues: requests_group[request.call_method].append(request) for group in requests_group.values(): - future = worker._ray_serve_call.remote(group).as_future() + future = worker.handle_request.remote(group).as_future() future.add_done_callback( _make_future_unwrapper( client_futures=[req.async_future for req in group], diff --git a/python/ray/serve/tests/test_task_runner.py b/python/ray/serve/tests/test_backend_worker.py similarity index 75% rename from python/ray/serve/tests/test_task_runner.py rename to python/ray/serve/tests/test_backend_worker.py index d52ca7c3a..a126084fa 100644 --- a/python/ray/serve/tests/test_task_runner.py +++ b/python/ray/serve/tests/test_backend_worker.py @@ -6,20 +6,40 @@ import ray from ray import serve import ray.serve.context as context from ray.serve.policy import RoundRobinPolicyQueueActor -from ray.serve.task_runner import (RayServeMixin, TaskRunner, TaskRunnerActor, - wrap_to_ray_error) +from ray.serve.backend_worker import create_backend_worker, wrap_to_ray_error from ray.serve.request_params import RequestMetadata from ray.serve.backend_config import BackendConfig pytestmark = pytest.mark.asyncio -async def test_runner_basic(): - def echo(i): - return i +def setup_worker(name, func_or_class, router_handle, init_args=None): + if init_args is None: + init_args = () - r = TaskRunner(echo) - assert r(1) == 1 + @ray.remote + class WorkerActor: + def setup(self, self_handle, router_handle): + self.worker = create_backend_worker(func_or_class)( + name, + name + ":tag", + init_args, + self_handle=self_handle[0], + router_handle=router_handle[0], + start_running=False) + + def get_metrics(self): + return self.worker.get_metrics() + + def run(self): + self.worker.backend.mark_idle_in_router() + + async def handle_request(self, *args, **kwargs): + return await self.worker.handle_request(*args, **kwargs) + + worker = WorkerActor.remote() + ray.get(worker.setup.remote([worker], [router_handle])) + return worker async def test_runner_wraps_error(): @@ -36,9 +56,8 @@ async def test_runner_actor(serve_instance): CONSUMER_NAME = "runner" PRODUCER_NAME = "prod" - runner = TaskRunnerActor.remote(echo) - ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)) - runner._ray_serve_fetch.remote() + worker = setup_worker(CONSUMER_NAME, echo, q) + await worker.run.remote() q.link.remote(PRODUCER_NAME, CONSUMER_NAME) @@ -62,14 +81,8 @@ async def test_ray_serve_mixin(serve_instance): def __call__(self, flask_request, i=None): return i + self.increment - @ray.remote - class CustomActor(MyAdder, RayServeMixin): - pass - - runner = CustomActor.remote(3) - - ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)) - runner._ray_serve_fetch.remote() + worker = setup_worker(CONSUMER_NAME, MyAdder, q, init_args=(3, )) + await worker.run.remote() q.link.remote(PRODUCER_NAME, CONSUMER_NAME) @@ -90,10 +103,8 @@ async def test_task_runner_check_context(serve_instance): CONSUMER_NAME = "runner" PRODUCER_NAME = "producer" - runner = TaskRunnerActor.remote(echo) - - ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)) - runner._ray_serve_fetch.remote() + worker = setup_worker(CONSUMER_NAME, echo, q) + await worker.run.remote() q.link.remote(PRODUCER_NAME, CONSUMER_NAME) query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python) @@ -113,17 +124,11 @@ async def test_task_runner_custom_method_single(serve_instance): def b(self, _): return "b" - @ray.remote - class CustomActor(NonBatcher, RayServeMixin): - pass - CONSUMER_NAME = "runner" PRODUCER_NAME = "producer" - runner = CustomActor.remote() - - ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)) - runner._ray_serve_fetch.remote() + worker = setup_worker(CONSUMER_NAME, NonBatcher, q) + await worker.run.remote() q.link.remote(PRODUCER_NAME, CONSUMER_NAME) @@ -154,16 +159,10 @@ async def test_task_runner_custom_method_batch(serve_instance): def b(self, _): return ["b-{}".format(i) for i in range(serve.context.batch_size)] - @ray.remote - class CustomActor(Batcher, RayServeMixin): - pass - CONSUMER_NAME = "runner" PRODUCER_NAME = "producer" - runner = CustomActor.remote() - - ray.get(runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner)) + worker = setup_worker(CONSUMER_NAME, Batcher, q) await q.link.remote(PRODUCER_NAME, CONSUMER_NAME) await q.set_backend_config.remote( @@ -177,7 +176,7 @@ async def test_task_runner_custom_method_batch(serve_instance): futures = [q.enqueue_request.remote(a_query_param) for _ in range(2)] futures += [q.enqueue_request.remote(b_query_param) for _ in range(2)] - await runner._ray_serve_fetch.remote() + await worker.run.remote() gathered = await asyncio.gather(*futures) assert set(gathered) == {"a-0", "a-1", "b-0", "b-1"} diff --git a/python/ray/serve/tests/test_metric.py b/python/ray/serve/tests/test_metric.py index ed8d69c81..f3efb385f 100644 --- a/python/ray/serve/tests/test_metric.py +++ b/python/ray/serve/tests/test_metric.py @@ -12,7 +12,7 @@ def start_target_actor(ray_instance): def __init__(self): self.counter_value = 0 - def _serve_metric(self): + def get_metrics(self): self.counter_value += 1 return { "latency_list": { diff --git a/python/ray/serve/tests/test_queue.py b/python/ray/serve/tests/test_queue.py index 6ee5f264b..c368dc718 100644 --- a/python/ray/serve/tests/test_queue.py +++ b/python/ray/serve/tests/test_queue.py @@ -18,9 +18,9 @@ def make_task_runner_mock(): self.query = None self.queries = [] - async def _ray_serve_call(self, request_item): - self.query = request_item - self.queries.append(request_item) + async def handle_request(self, request): + self.query = request + self.queries.append(request) return "DONE" def get_recent_call(self):