diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD index f60ef2aaf..ac9d75106 100644 --- a/python/ray/serve/BUILD +++ b/python/ray/serve/BUILD @@ -5,21 +5,101 @@ py_library( srcs = glob(["**/*.py"], exclude=["tests/*.py"]), ) -# This test aggregates all serve tests and run them in a single session -# similar to `pytest .` -# Serve tests need to run in a single session because starting and stopping -# serve cluster take a large chunk of time. All serve tests use a shared -# cluster. +serve_tests_srcs = glob(["tests/*.py"], + exclude=["tests/test_nonblocking.py", + "tests/test_master_crashes.py", + "tests/test_serve.py", + ]) + py_test( - name = "test_serve", + name = "test_api", size = "medium", - srcs = glob(["tests/*.py"], - exclude=["tests/test_nonblocking.py", - "tests/test_master_crashes.py"]), + srcs = serve_tests_srcs, tags = ["exclusive"], deps = [":serve_lib"], ) +py_test( + name = "test_backend_worker", + size = "small", + srcs = serve_tests_srcs, + tags = ["exclusive"], + deps = [":serve_lib"], +) + + +py_test( + name = "test_config", + size = "small", + srcs = serve_tests_srcs, + tags = ["exclusive"], + deps = [":serve_lib"], +) + + +py_test( + name = "test_failure", + size = "medium", + srcs = serve_tests_srcs, + tags = ["exclusive"], + deps = [":serve_lib"], +) + + +py_test( + name = "test_handle", + size = "small", + srcs = serve_tests_srcs, + tags = ["exclusive"], + deps = [":serve_lib"], +) + + +py_test( + name = "test_kv_store", + size = "small", + srcs = serve_tests_srcs, + tags = ["exclusive"], + deps = [":serve_lib"], +) + + +py_test( + name = "test_metric", + size = "small", + srcs = serve_tests_srcs, + tags = ["exclusive"], + deps = [":serve_lib"], +) + + +py_test( + name = "test_persistence", + size = "small", + srcs = serve_tests_srcs, + tags = ["exclusive"], + deps = [":serve_lib"], +) + + +py_test( + name = "test_router", + size = "small", + srcs = serve_tests_srcs, + tags = ["exclusive"], + deps = [":serve_lib"], +) + + +py_test( + name = "test_util", + size = "small", + srcs = serve_tests_srcs, + tags = ["exclusive"], + deps = [":serve_lib"], +) + + # Runs test_api and test_failure with injected failures in the master actor. # TODO(edoakes): reenable this once we're using GCS actor fault tolerance. # py_test( @@ -97,6 +177,14 @@ py_test( deps = [":serve_lib"] ) +py_test( + name = "snippet_model_composition", + size = "small", + srcs = glob(["examples/doc/*.py"]), + tags = ["exclusive"], + deps = [":serve_lib"] +) + # Disable the deployment tutorial test because it requires # ray start --head in the background. # py_test( diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index b59388363..7387bcfdd 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -236,7 +236,8 @@ def create_backend(backend_tag, replica_config = ReplicaConfig( func_or_class, *actor_init_args, ray_actor_options=ray_actor_options) - backend_config = BackendConfig(config, replica_config.accepts_batches) + backend_config = BackendConfig(config, replica_config.accepts_batches, + replica_config.is_blocking) ray.get( master_actor.create_backend.remote(backend_tag, backend_config, diff --git a/python/ray/serve/backend_worker.py b/python/ray/serve/backend_worker.py index 78782fcfb..74a702378 100644 --- a/python/ray/serve/backend_worker.py +++ b/python/ray/serve/backend_worker.py @@ -1,20 +1,61 @@ +import asyncio import traceback import inspect from collections.abc import Iterable +from collections import defaultdict +from itertools import groupby +from operator import attrgetter +import time import ray +from ray.async_compat import sync_to_async + from ray import serve from ray.serve import context as serve_context from ray.serve.context import FakeFlaskRequest -from collections import defaultdict -from ray.serve.utils import (parse_request_item, _get_logger) +from ray.serve.utils import (parse_request_item, _get_logger, chain_future, + unpack_future) from ray.serve.exceptions import RayServeException from ray.serve.metric import MetricClient -from ray.async_compat import sync_to_async +from ray.serve.config import BackendConfig +from ray.serve.router import Query logger = _get_logger() +class WaitableQueue(asyncio.Queue): + async def wait_for_batch(self, num_items: int, timeout_s: float): + """Wait up to num_items in the queue given timeout_s. + + This method will block indefinitely for the first item. Therefore, it + guarantees to return at least one item. + """ + + assert num_items >= 1 + # Wait for the first value without timeout. We will return at least + # one item. Additionally this help the caller context switch on empty + # queue. + start_waiting = time.time() + batch = [ + await self.get(), + ] + + # Adjust the timeout to account for the time waiting for first item. + time_remaining = timeout_s - (time.time() - start_waiting) + time_remaining = max(0, time_remaining) + + # Wait for the remaining batch with the timeout + if num_items > 1: + done_set, not_done_set = await asyncio.wait( + [self.get() for _ in range(num_items - 1)], + timeout=time_remaining) + for task in done_set: + batch.append(task.result()) + for task in not_done_set: + task.cancel() + return batch + + def create_backend_worker(func_or_class): """Creates a worker class wrapping the provided function or class.""" @@ -30,8 +71,10 @@ def create_backend_worker(func_or_class): backend_tag, replica_tag, init_args, + backend_config: BackendConfig, instance_name=None): serve.init(name=instance_name) + if is_function: _callable = func_or_class else: @@ -42,11 +85,15 @@ def create_backend_worker(func_or_class): metric_client = MetricClient( metric_exporter, default_labels={"backend": backend_tag}) self.backend = RayServeWorker(backend_tag, replica_tag, _callable, - is_function, metric_client) + backend_config, is_function, + metric_client) async def handle_request(self, request): return await self.backend.handle_request(request) + def update_config(self, new_config: BackendConfig): + return self.backend.update_config(new_config) + def ready(self): pass @@ -75,13 +122,16 @@ def ensure_async(func): class RayServeWorker: """Handles requests with the provided callable.""" - def __init__(self, name, replica_tag, _callable, is_function, - metric_client): + def __init__(self, name, replica_tag, _callable, + backend_config: BackendConfig, is_function, metric_client): self.name = name self.replica_tag = replica_tag self.callable = _callable self.is_function = is_function + self.config = backend_config + self.query_queue = WaitableQueue() + self.metric_client = metric_client self.request_counter = self.metric_client.new_counter( "backend_request_counter", @@ -101,6 +151,9 @@ class RayServeWorker: self.restart_counter.labels(replica_tag=self.replica_tag).add() + self.loop_task = asyncio.get_event_loop().create_task(self.main_loop()) + self.config_updated = asyncio.Event() + def get_runner_method(self, request_item): method_name = request_item.call_method if not hasattr(self.callable, method_name): @@ -108,6 +161,8 @@ class RayServeWorker: "which is specified in the request. " "The available methods are {}".format( method_name, dir(self.callable))) + if self.is_function: + return self.callable return getattr(self.callable, method_name) def has_positional_args(self, f): @@ -124,6 +179,12 @@ class RayServeWorker: return True return False + def _reset_context(self): + # NOTE(simon): context management won't work in async mode because + # many concurrent queries might be running at the same time. + serve_context.web = None + serve_context.batch_size = None + async def invoke_single(self, request_item): args, kwargs, is_web_context = parse_request_item(request_item) serve_context.web = is_web_context @@ -137,24 +198,12 @@ class RayServeWorker: except Exception as e: result = wrap_to_ray_error(e) self.error_counter.add() + finally: + self._reset_context() return result async def invoke_batch(self, request_item_list): - # TODO(alind) : create no-http services. The enqueues - # from such services will always be TaskContext.Python. - - # Assumption : all the requests in a bacth - # have same serve context. - - # For batching kwargs are modified as follows - - # kwargs [Python Context] : key,val - # kwargs_list : key, [val1,val2, ... , valn] - # or - # args[Web Context] : val - # args_list : [val1,val2, ...... , valn] - # where n (current batch size) <= max_batch_size of a backend - arg_list = [] kwargs_list = defaultdict(list) context_flags = set() @@ -222,22 +271,53 @@ class RayServeWorker: "results with length equal to the batch size" ".".format(batch_size, len(result_list))) raise RayServeException(error_message) + self._reset_context() return result_list except Exception as e: wrapped_exception = wrap_to_ray_error(e) self.error_counter.add() + self._reset_context() return [wrapped_exception for _ in range(batch_size)] - async def handle_request(self, request): - # check if work_item is a list or not - # if it is list: then batching supported - if not isinstance(request, list): - result = await self.invoke_single(request) - else: - result = await self.invoke_batch(request) + async def main_loop(self): + while True: + # NOTE(simon): There's an issue when user updated batch size and + # batch wait timeout during the execution, these values will not be + # updated until after the current iteration. + batch = await self.query_queue.wait_for_batch( + num_items=self.config.max_batch_size or 1, + timeout_s=self.config.batch_wait_timeout) - # re-assign to default values - serve_context.web = False - serve_context.batch_size = None + all_evaluated_futures = [] - return result + if not self.config.accepts_batches: + query = batch[0] + evaluated = asyncio.ensure_future(self.invoke_single(query)) + all_evaluated_futures = [evaluated] + chain_future(evaluated, query.async_future) + else: + get_call_method = attrgetter("call_method") + sorted_batch = sorted(batch, key=get_call_method) + for _, group in groupby(sorted_batch, key=get_call_method): + group = sorted(group) + evaluated = asyncio.ensure_future(self.invoke_batch(group)) + all_evaluated_futures.append(evaluated) + result_futures = [q.async_future for q in group] + chain_future( + unpack_future(evaluated, len(group)), result_futures) + + if self.config.is_blocking: + # We use asyncio.wait here so if the result is exception, + # it will not be raised. + await asyncio.wait(all_evaluated_futures) + + def update_config(self, new_config: BackendConfig): + self.config = new_config + self.config_updated.set() + + async def handle_request(self, request: Query): + assert not isinstance(request, list) + logger.debug("Worker {} got request {}".format(self.name, request)) + request.async_future = asyncio.get_event_loop().create_future() + self.query_queue.put_nowait(request) + return await request.async_future diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index 61fab8a1c..69a506449 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -1,5 +1,7 @@ import inspect +from ray.serve.constants import ASYNC_CONCURRENCY + def _callable_accepts_batch(func_or_class): if inspect.isfunction(func_or_class): @@ -8,15 +10,46 @@ def _callable_accepts_batch(func_or_class): return hasattr(func_or_class.__call__, "_serve_accept_batch") +def _callable_is_blocking(func_or_class): + if inspect.isfunction(func_or_class): + return not inspect.iscoroutinefunction(func_or_class) + elif inspect.isclass(func_or_class): + return not inspect.iscoroutinefunction(func_or_class.__call__) + + class BackendConfig: - def __init__(self, config_dict, accepts_batches=False): + def __init__(self, config_dict, accepts_batches=False, is_blocking=True): assert isinstance(config_dict, dict) # Make a copy so that we don't modify the input dict. config_dict = config_dict.copy() self.accepts_batches = accepts_batches + self.is_blocking = is_blocking self.num_replicas = config_dict.pop("num_replicas", 1) self.max_batch_size = config_dict.pop("max_batch_size", None) + self.batch_wait_timeout = config_dict.pop("batch_wait_timeout", 0) + self.max_concurrent_queries = config_dict.pop("max_concurrent_queries", + None) + + if self.max_concurrent_queries is None: + # Model serving mode: if the servable is blocking and the wait + # timeout is default zero seconds, then we keep the existing + # behavior to allow at most max batch size queries. + if self.is_blocking and self.batch_wait_timeout == 0: + self.max_concurrent_queries = self.max_batch_size or 1 + + # Pipeline/async mode: if the servable is not blocking, + # router should just keep pushing queries to the worker + # replicas until a high limit. + if not self.is_blocking: + self.max_concurrent_queries = ASYNC_CONCURRENCY + + # Batch inference mode: user specifies non zero timeout to wait for + # full batch. We will use 2*max_batch_size to perform double + # buffering to keep the replica busy. + if self.max_batch_size is not None and self.batch_wait_timeout > 0: + self.max_concurrent_queries = 2 * self.max_batch_size + if len(config_dict) != 0: raise ValueError("Unknown options in backend config: {}".format( list(config_dict.keys()))) @@ -64,6 +97,7 @@ class ReplicaConfig: ray_actor_options=None): self.func_or_class = func_or_class self.accepts_batches = _callable_accepts_batch(func_or_class) + self.is_blocking = _callable_is_blocking(func_or_class) self.actor_init_args = list(actor_init_args) if ray_actor_options is None: self.ray_actor_options = {} diff --git a/python/ray/serve/examples/doc/snippet_model_composition.py b/python/ray/serve/examples/doc/snippet_model_composition.py new file mode 100644 index 000000000..29bef42c7 --- /dev/null +++ b/python/ray/serve/examples/doc/snippet_model_composition.py @@ -0,0 +1,57 @@ +from random import random + +import requests + +from ray import serve + +serve.init() + + +def model_one(_unused_flask_request, data=None): + print("Model 1 called with data ", data) + return random() + + +def model_two(_unused_flask_request, data=None): + print("Model 2 called with data ", data) + return data + + +class ComposedModel: + def __init__(self): + self.model_one = serve.get_handle("model_one") + self.model_two = serve.get_handle("model_two") + + async def __call__(self, flask_request): + data = flask_request.data + + score = await self.model_one.remote(data=data) + if score > 0.5: + result = await self.model_two.remote(data=data) + result = {"model_used": 2, "score": score} + else: + result = {"model_used": 1, "score": score} + + return result + + +serve.create_backend("model_one", model_one) +serve.create_endpoint("model_one", backend="model_one") + +serve.create_backend("model_two", model_two) +serve.create_endpoint("model_two", backend="model_two") + +serve.create_backend( + "composed_backend", ComposedModel, config={"max_concurrent_queries": 10}) +serve.create_endpoint( + "composed", backend="composed_backend", route="/composed") + +for _ in range(5): + resp = requests.get("http://127.0.0.1:8000/composed", data="hey!") + print(resp.json()) +# Output +# {'model_used': 2, 'score': 0.6250189863595503} +# {'model_used': 1, 'score': 0.03146855349621436} +# {'model_used': 2, 'score': 0.6916977560006987} +# {'model_used': 2, 'score': 0.8169693450866928} +# {'model_used': 2, 'score': 0.9540681979573862} diff --git a/python/ray/serve/master.py b/python/ray/serve/master.py index 1ded64026..fa5de7a6a 100644 --- a/python/ray/serve/master.py +++ b/python/ray/serve/master.py @@ -258,6 +258,7 @@ class ServeMaster: for backend, (_, backend_config, _) in self.backends.items(): await self.router.set_backend_config.remote( backend, backend_config) + await self.broadcast_backend_config(backend) # Push configuration state to the HTTP proxy. await self.http_proxy.set_route_table.remote(self.routes) @@ -314,6 +315,7 @@ class ServeMaster: backend_tag, replica_tag, replica_config.actor_init_args, + backend_config, instance_name=self.instance_name) # TODO(edoakes): we should probably have a timeout here. await worker_handle.ready.remote() @@ -602,6 +604,7 @@ class ServeMaster: # (particularly for max-batch-size). await self.router.set_backend_config.remote( backend_tag, backend_config) + await self.broadcast_backend_config(backend_tag) async def delete_backend(self, backend_tag): async with self.write_lock: @@ -664,6 +667,22 @@ class ServeMaster: await self._start_pending_replicas() await self._stop_pending_replicas() + await self.broadcast_backend_config(backend_tag) + + async def broadcast_backend_config(self, backend_tag): + _, backend_config, _ = self.backends[backend_tag] + broadcast_futures = [] + for replica_tag in self.replicas[backend_tag]: + try: + replica = ray.get_actor(replica_tag) + except ValueError: + continue + + future = replica.update_config.remote(backend_config).as_future() + broadcast_futures.append(future) + if len(broadcast_futures) > 0: + await asyncio.gather(*broadcast_futures) + def get_backend_config(self, backend_tag): """Get the current config for the specified backend.""" assert (backend_tag in self.backends diff --git a/python/ray/serve/router.py b/python/ray/serve/router.py index 348543c0e..5ee9e37de 100644 --- a/python/ray/serve/router.py +++ b/python/ray/serve/router.py @@ -13,7 +13,7 @@ import ray from ray import serve from ray.serve.metric import MetricClient from ray.serve.policy import RandomEndpointPolicy -from ray.serve.utils import logger +from ray.serve.utils import logger, chain_future class Query: @@ -58,10 +58,6 @@ class Query: def __lt__(self, other): return self.request_slo_ms < other.request_slo_ms - def __repr__(self): - return "".format(self.request_args, - self.request_kwargs) - def _make_future_unwrapper(client_futures: List[asyncio.Future], host_future: asyncio.Future): @@ -117,6 +113,8 @@ class Router: self.backend_info = dict() # replica tag -> worker_handle self.replicas = dict() + # replica_tag -> concurrent queries counter + self.queries_counter = defaultdict(lambda: 0) # -- Synchronization -- # @@ -126,7 +124,7 @@ class Router: # an operation holding the only query and the other flush operation # holding the only idle replica. Additionally, allowing only one flush # operation at a time simplifies design overhead for custom queuing and - # batching polcies. + # batching policies. self.flush_lock = asyncio.Lock() # -- State Restoration -- # @@ -215,11 +213,15 @@ class Router: await self.mark_worker_idle(backend_tag, backend_replica_tag) async def mark_worker_idle(self, backend_tag, backend_replica_tag): + logger.debug( + "Marking backend with tag {} as idle.".format(backend_replica_tag)) if backend_replica_tag not in self.replicas: return async with self.flush_lock: - self.worker_queues[backend_tag].appendleft(backend_replica_tag) + # NOTE(simon): This is a O(n) operation where n=len(worker_queue) + if backend_replica_tag not in self.worker_queues[backend_tag]: + self.worker_queues[backend_tag].appendleft(backend_replica_tag) self.flush_backend_queues([backend_tag]) async def remove_worker(self, backend_tag, replica_tag): @@ -299,12 +301,11 @@ class Router: "queue size {} and worker queue size {}".format( backend, len(buffer_queue), len(worker_queue))) - max_batch_size = None - if backend in self.backend_info: - max_batch_size = self.backend_info[backend].max_batch_size - - self._assign_query_to_worker(backend, buffer_queue, worker_queue, - max_batch_size) + self._assign_query_to_worker( + backend, + buffer_queue, + worker_queue, + ) async def _do_query(self, backend, backend_replica_tag, req): # If the worker died, this will be a RayActorError. Just return it and @@ -317,16 +318,13 @@ class Router: except RayTaskError as error: self.num_error_backend_request.labels(backend=backend).add() result = error + self.queries_counter[backend_replica_tag] -= 1 await self.mark_worker_idle(backend, backend_replica_tag) logger.debug("Got result in {:.2f}s".format(time.time() - start)) return result - def _assign_query_to_worker(self, - backend, - buffer_queue, - worker_queue, - max_batch_size=None): - + def _assign_query_to_worker(self, backend, buffer_queue, worker_queue): + overloaded_replicas = set() while len(buffer_queue) and len(worker_queue): backend_replica_tag = worker_queue.pop() @@ -334,27 +332,30 @@ class Router: if backend_replica_tag not in self.replicas: continue - if max_batch_size is None: # No batching - request = buffer_queue.pop(0) - future = asyncio.get_event_loop().create_task( - self._do_query(backend, backend_replica_tag, request)) - # chaining satisfies request.async_future with future result. - asyncio.futures._chain_future(future, request.async_future) - else: - real_batch_size = min(len(buffer_queue), max_batch_size) - requests = [ - buffer_queue.pop(0) for _ in range(real_batch_size) - ] + # We have reached the end of the worker queue where all replicas + # are overloaded. + if backend_replica_tag in overloaded_replicas: + break - # split requests by method type - requests_group = defaultdict(list) - for request in requests: - requests_group[request.call_method].append(request) + # This replica has too many in flight and processing queries. + max_queries = 1 + if backend in self.backend_info: + max_queries = self.backend_info[backend].max_concurrent_queries + curr_queries = self.queries_counter[backend_replica_tag] + if curr_queries >= max_queries: + # Put the worker back to the queue. + worker_queue.appendleft(backend_replica_tag) + overloaded_replicas.add(backend_replica_tag) + logger.debug( + "Skipping backend {} because it has {} in flight " + "requests which exceeded the concurrency limit.".format( + backend, curr_queries)) + continue - for group in requests_group.values(): - future = asyncio.get_event_loop().create_task( - self._do_query(backend, backend_replica_tag, group)) - future.add_done_callback( - _make_future_unwrapper( - client_futures=[req.async_future for req in group], - host_future=future)) + request = buffer_queue.pop(0) + self.queries_counter[backend_replica_tag] += 1 + future = asyncio.get_event_loop().create_task( + self._do_query(backend, backend_replica_tag, request)) + chain_future(future, request.async_future) + + worker_queue.appendleft(backend_replica_tag) diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index 5f524022b..84a0a3bc0 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -507,3 +507,8 @@ def test_endpoint_input_validation(serve_instance): with pytest.raises(TypeError): serve.create_endpoint("endpoint", backend=2) serve.create_endpoint("endpoint", backend="backend") + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_backend_worker.py b/python/ray/serve/tests/test_backend_worker.py index e20cbc3a2..4cda407a5 100644 --- a/python/ray/serve/tests/test_backend_worker.py +++ b/python/ray/serve/tests/test_backend_worker.py @@ -15,7 +15,10 @@ from ray.serve.exceptions import RayServeException pytestmark = pytest.mark.asyncio -def setup_worker(name, func_or_class, init_args=None): +def setup_worker(name, + func_or_class, + init_args=None, + backend_config=BackendConfig({})): if init_args is None: init_args = () @@ -23,7 +26,7 @@ def setup_worker(name, func_or_class, init_args=None): class WorkerActor: def __init__(self): self.worker = create_backend_worker(func_or_class)( - name, name + ":tag", init_args) + name, name + ":tag", init_args, backend_config) def ready(self): pass @@ -31,6 +34,9 @@ def setup_worker(name, func_or_class, init_args=None): async def handle_request(self, *args, **kwargs): return await self.worker.handle_request(*args, **kwargs) + def update_config(self, new_config): + return self.worker.update_config(new_config) + worker = WorkerActor.remote() ray.get(worker.ready.remote()) return worker @@ -165,14 +171,16 @@ async def test_task_runner_custom_method_batch(serve_instance): CONSUMER_NAME = "runner" PRODUCER_NAME = "producer" - worker = setup_worker(CONSUMER_NAME, Batcher) + backend_config = BackendConfig( + { + "max_batch_size": 4, + "batch_wait_timeout": 2 + }, accepts_batches=True) + worker = setup_worker( + CONSUMER_NAME, Batcher, backend_config=backend_config) await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0}) - await q.set_backend_config.remote( - CONSUMER_NAME, - BackendConfig({ - "max_batch_size": 10 - }, accepts_batches=True)) + await q.set_backend_config.remote(CONSUMER_NAME, backend_config) def make_request_param(call_method): return RequestMetadata( @@ -200,3 +208,77 @@ async def test_task_runner_custom_method_batch(serve_instance): np_array = make_request_param("return_np_array") result_np_value = await q.enqueue_request.remote(np_array) assert isinstance(result_np_value, np.int32) + + +async def test_task_runner_perform_batch(serve_instance): + q = ray.remote(Router).remote() + + def batcher(*args, **kwargs): + return [serve.context.batch_size] * serve.context.batch_size + + CONSUMER_NAME = "runner" + PRODUCER_NAME = "producer" + + config = BackendConfig( + { + "max_batch_size": 2, + "batch_wait_timeout": 10 + }, accepts_batches=True) + + worker = setup_worker(CONSUMER_NAME, batcher, backend_config=config) + await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) + await q.set_backend_config.remote(CONSUMER_NAME, config) + await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0}) + + query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python) + + my_batch_sizes = await asyncio.gather( + *[q.enqueue_request.remote(query_param) for _ in range(3)]) + assert my_batch_sizes == [2, 2, 1] + + +async def test_task_runner_perform_async(serve_instance): + q = ray.remote(Router).remote() + + @ray.remote + class Barrier: + def __init__(self, release_on): + self.release_on = release_on + self.current_waiters = 0 + self.event = asyncio.Event() + + async def wait(self): + self.current_waiters += 1 + if self.current_waiters == self.release_on: + self.event.set() + else: + await self.event.wait() + + barrier = Barrier.remote(release_on=10) + + async def wait_and_go(*args, **kwargs): + await barrier.wait.remote() + return "done!" + + CONSUMER_NAME = "runner" + PRODUCER_NAME = "producer" + + config = BackendConfig({"max_concurrent_queries": 10}, is_blocking=False) + + worker = setup_worker(CONSUMER_NAME, wait_and_go, backend_config=config) + await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) + await q.set_backend_config.remote(CONSUMER_NAME, config) + q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0}) + + query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python) + + done, not_done = await asyncio.wait( + [q.enqueue_request.remote(query_param) for _ in range(10)], timeout=10) + assert len(done) == 10 + for item in done: + await item == "done!" + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_config.py b/python/ray/serve/tests/test_config.py index 07f4088eb..af6958d13 100644 --- a/python/ray/serve/tests/test_config.py +++ b/python/ray/serve/tests/test_config.py @@ -132,3 +132,8 @@ def test_replica_config_validation(): ReplicaConfig(Class, ray_actor_options={"detached": None}) with pytest.raises(ValueError): ReplicaConfig(Class, ray_actor_options={"max_restarts": None}) + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_failure.py b/python/ray/serve/tests/test_failure.py index 21c42fb60..4e7eb7f80 100644 --- a/python/ray/serve/tests/test_failure.py +++ b/python/ray/serve/tests/test_failure.py @@ -241,3 +241,9 @@ def test_worker_replica_failure(serve_instance): break except TimeoutError: time.sleep(0.1) + + +if __name__ == "__main__": + import sys + import pytest + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_handle.py b/python/ray/serve/tests/test_handle.py index 1ad0d65e4..8dc6856ff 100644 --- a/python/ray/serve/tests/test_handle.py +++ b/python/ray/serve/tests/test_handle.py @@ -33,3 +33,9 @@ def test_handle_in_endpoint(serve_instance): methods=["GET", "POST"]) assert requests.get("http://127.0.0.1:8000/endpoint2").text == "hello" + + +if __name__ == "__main__": + import sys + import pytest + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_kv_store.py b/python/ray/serve/tests/test_kv_store.py index 964881f98..3c13dec0d 100644 --- a/python/ray/serve/tests/test_kv_store.py +++ b/python/ray/serve/tests/test_kv_store.py @@ -38,3 +38,8 @@ def test_ray_internal_kv_collisions(serve_instance): kv2.put("1", b"-1") assert kv2.get("1") == b"-1" assert kv1.get("1") == b"1" + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_metric.py b/python/ray/serve/tests/test_metric.py index e5c89f2b7..11773cc6e 100644 --- a/python/ray/serve/tests/test_metric.py +++ b/python/ray/serve/tests/test_metric.py @@ -195,3 +195,8 @@ async def test_system_metric_endpoints(serve_instance): print("Metric not correct, retrying...") if not success: test_metric_endpoint() + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_nonblocking.py b/python/ray/serve/tests/test_nonblocking.py index 747667af2..9d656e9b0 100644 --- a/python/ray/serve/tests/test_nonblocking.py +++ b/python/ray/serve/tests/test_nonblocking.py @@ -20,4 +20,4 @@ def test_nonblocking(): if __name__ == "__main__": import pytest - sys.exit(pytest.main(["-v", __file__])) + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_persistence.py b/python/ray/serve/tests/test_persistence.py index 3a873c9a3..4ed75fd92 100644 --- a/python/ray/serve/tests/test_persistence.py +++ b/python/ray/serve/tests/test_persistence.py @@ -33,3 +33,9 @@ serve.create_endpoint("driver", backend="driver", route="/driver") assert ray.get(handle.remote()) == "OK!" os.remove(path) + + +if __name__ == "__main__": + import sys + import pytest + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_router.py b/python/ray/serve/tests/test_router.py index ced8f737b..642a2fb0b 100644 --- a/python/ray/serve/tests/test_router.py +++ b/python/ray/serve/tests/test_router.py @@ -7,6 +7,8 @@ import ray from ray.serve.router import Router from ray.serve.request_params import RequestMetadata from ray.serve.utils import get_random_letters +from ray.test_utils import SignalActor +from ray.serve.config import BackendConfig pytestmark = pytest.mark.asyncio @@ -172,3 +174,68 @@ async def test_shard_key(serve_instance, task_runner_mock_actor): calls = await runner.get_all_calls.remote() for call in calls: assert call.request_args[0] in runner_shard_keys[i] + + +async def test_router_use_max_concurrency(serve_instance): + signal = SignalActor.remote() + + @ray.remote + class MockWorker: + async def handle_request(self, request): + await signal.wait.remote() + return "DONE" + + def ready(self): + pass + + class VisibleRouter(Router): + def get_queues(self): + return self.queries_counter, self.backend_queues + + worker = MockWorker.remote() + q = ray.remote(VisibleRouter).remote() + BACKEND_NAME = "max-concurrent-test" + config = BackendConfig({"max_concurrent_queries": 1}) + await q.set_traffic.remote("svc", {BACKEND_NAME: 1.0}) + await q.add_new_worker.remote(BACKEND_NAME, "replica-tag", worker) + await q.set_backend_config.remote(BACKEND_NAME, config) + + # We send over two queries + first_query = q.enqueue_request.remote(RequestMetadata("svc", None), 1) + second_query = q.enqueue_request.remote(RequestMetadata("svc", None), 1) + + # Neither queries should be available + with pytest.raises(ray.exceptions.RayTimeoutError): + ray.get([first_query, second_query], timeout=0.2) + + # Let's retrieve the router internal state + queries_counter, backend_queues = await q.get_queues.remote() + # There should be just one inflight request + assert queries_counter["max-concurrent-test:replica-tag"] == 1 + # The second query is buffered + assert len(backend_queues["max-concurrent-test"]) == 1 + + # Let's unblock the first query + await signal.send.remote(clear=True) + assert await first_query == "DONE" + + # The internal state of router should have changed. + queries_counter, backend_queues = await q.get_queues.remote() + # There should still be one inflight request + assert queries_counter["max-concurrent-test:replica-tag"] == 1 + # But there shouldn't be any queries in the queue + assert len(backend_queues["max-concurrent-test"]) == 0 + + # Unblocking the second query + await signal.send.remote(clear=True) + assert await second_query == "DONE" + + # Checking the internal state of the router one more time + queries_counter, backend_queues = await q.get_queues.remote() + assert queries_counter["max-concurrent-test:replica-tag"] == 0 + assert len(backend_queues["max-concurrent-test"]) == 0 + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_util.py b/python/ray/serve/tests/test_util.py index b1c2254c6..01cdf3fed 100644 --- a/python/ray/serve/tests/test_util.py +++ b/python/ray/serve/tests/test_util.py @@ -1,8 +1,10 @@ +import asyncio import json import numpy as np +import pytest -from ray.serve.utils import ServeEncoder +from ray.serve.utils import ServeEncoder, chain_future, unpack_future def test_bytes_encoder(): @@ -20,3 +22,50 @@ def test_numpy_encoding(): assert json.loads(json.dumps(floats, cls=ServeEncoder)) == data assert json.loads(json.dumps(ints, cls=ServeEncoder)) == data assert json.loads(json.dumps(uints, cls=ServeEncoder)) == data + + +@pytest.mark.asyncio +async def test_future_chaining(): + def make(): + return asyncio.get_event_loop().create_future() + + # Test 1 -> 1 chaining + fut1, fut2 = make(), make() + chain_future(fut1, fut2) + fut1.set_result(1) + assert await fut2 == 1 + + # Test 1 -> 1 chaining with exception + fut1, fut2 = make(), make() + chain_future(fut1, fut2) + fut1.set_exception(ValueError("")) + with pytest.raises(ValueError): + await fut2 + + # Test many -> many chaining + src_futs = [make() for _ in range(4)] + dst_futs = [make() for _ in range(4)] + chain_future(src_futs, dst_futs) + [fut.set_result(i) for i, fut in enumerate(src_futs)] + for i, fut in enumerate(dst_futs): + assert await fut == i + + # Test 1 -> many unwrapping + batched_future = make() + single_futures = unpack_future(batched_future, 4) + batched_future.set_result(list(range(4))) + for i, fut in enumerate(single_futures): + assert await fut == i + + # Test 1 -> many unwrapping with exception + batched_future = make() + single_futures = unpack_future(batched_future, 4) + batched_future.set_exception(ValueError("")) + for future in single_futures: + with pytest.raises(ValueError): + await future + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/utils.py b/python/ray/serve/utils.py index a2492542f..4563009a5 100644 --- a/python/ray/serve/utils.py +++ b/python/ray/serve/utils.py @@ -1,8 +1,11 @@ +import asyncio +from functools import singledispatch import json import logging import random import string import time +from typing import List import io import os @@ -114,3 +117,59 @@ def format_actor_name(actor_name, instance_name=None): return actor_name else: return "{}:{}".format(instance_name, actor_name) + + +@singledispatch +def chain_future(src, dst): + """Base method for chaining futures together. + + Chaining futures means the output from source future(s) are written as the + results of the destination future(s). This method can work with the + following inputs: + - src: Future, dst: Future + - src: List[Future], dst: List[Future] + """ + raise NotImplementedError() + + +@chain_future.register(asyncio.Future) +def _chain_future_single(src: asyncio.Future, dst: asyncio.Future): + asyncio.futures._chain_future(src, dst) + + +@chain_future.register(list) +def _chain_future_list(src: List[asyncio.Future], dst: List[asyncio.Future]): + if len(src) != len(dst): + raise ValueError( + "Source and destination list doesn't have the same length. " + "Source: {}. Destination: {}.".foramt(len(src), len(dst))) + + for s, d in zip(src, dst): + chain_future(s, d) + + +def unpack_future(src: asyncio.Future, num_items: int) -> List[asyncio.Future]: + """Unpack the result of source future to num_items futures. + + This function takes in a Future and splits its result into many futures. If + the result of the source future is an exception, then all destination + futures will have the same exception. + """ + dest_futures = [ + asyncio.get_event_loop().create_future() for _ in range(num_items) + ] + + def unwrap_callback(fut: asyncio.Future): + exception = fut.exception() + if exception is not None: + [f.set_exception(exception) for f in dest_futures] + return + + result = fut.result() + assert len(result) == num_items + for item, future in zip(result, dest_futures): + future.set_result(item) + + src.add_done_callback(unwrap_callback) + + return dest_futures diff --git a/python/ray/test_utils.py b/python/ray/test_utils.py index 575277bdf..7bc37ebcc 100644 --- a/python/ray/test_utils.py +++ b/python/ray/test_utils.py @@ -232,8 +232,10 @@ class SignalActor: def __init__(self): self.ready_event = asyncio.Event() - def send(self): + def send(self, clear=False): self.ready_event.set() + if clear: + self.ready_event.clear() async def wait(self, should_wait=True): if should_wait: