From ef498e8aa5cdfe09247d9c9efc7ef636378165c2 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Fri, 15 May 2020 16:18:52 -0500 Subject: [PATCH] [serve] Add basic session affinity via shard key (#8449) --- doc/source/rayserve/overview.rst | 18 +++ python/ray/serve/__init__.py | 4 +- python/ray/serve/api.py | 9 +- python/ray/serve/handle.py | 31 ++--- python/ray/serve/http_proxy.py | 4 +- python/ray/serve/master.py | 9 +- python/ray/serve/policy.py | 130 +++--------------- python/ray/serve/request_params.py | 4 +- python/ray/serve/router.py | 14 +- python/ray/serve/tests/test_api.py | 42 +++++- python/ray/serve/tests/test_backend_worker.py | 11 +- python/ray/serve/tests/test_router.py | 130 +++++++----------- 12 files changed, 157 insertions(+), 249 deletions(-) diff --git a/doc/source/rayserve/overview.rst b/doc/source/rayserve/overview.rst index b75e6e9f5..a101d6db4 100644 --- a/doc/source/rayserve/overview.rst +++ b/doc/source/rayserve/overview.rst @@ -228,6 +228,24 @@ You can also have RayServe batch requests for performance. You'll configure this serve.create_backend("counter1", BatchingExample, config=config) serve.set_traffic("counter1", {"counter1": 1.0}) +Session Affinity +++++++++++++++++ + +In some cases, you may want to ensure that requests from the same client, user, etc. get mapped to the same backend. +To do this, you can specify a "shard key" that will deterministically map requests to a backend. +The shard key can either be specified via the X-SERVE-SHARD-KEY HTTP header or ``handle.options(shard_key="key")``. + +.. note:: The mapping from shard key to backend may change when you update the traffic policy for an endpoint. + +.. code-block:: python + + # Specifying the shard key via an HTTP header. + requests.get("http://127.0.0.1:8000/api", headers={"X-SERVE-SHARD-KEY": session_id}) + + # Specifying the shard key in a call made via serve handle. 
+ handle = serve.get_handle("api_endpoint") + handle.options(shard_key=session_id).remote(args) + Other Resources --------------- diff --git a/python/ray/serve/__init__.py b/python/ray/serve/__init__.py index 91327d401..a6300880f 100644 --- a/python/ray/serve/__init__.py +++ b/python/ray/serve/__init__.py @@ -1,4 +1,3 @@ -from ray.serve.policy import RoutePolicy from ray.serve.api import (init, create_backend, delete_backend, create_endpoint, delete_endpoint, set_traffic, get_handle, stat, update_backend_config, @@ -7,6 +6,5 @@ from ray.serve.api import (init, create_backend, delete_backend, __all__ = [ "init", "create_backend", "delete_backend", "create_endpoint", "delete_endpoint", "set_traffic", "get_handle", "stat", - "update_backend_config", "get_backend_config", "RoutePolicy", - "accept_batch" + "update_backend_config", "get_backend_config", "accept_batch" ] diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 61722e597..63f7ed74a 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -10,7 +10,6 @@ from ray.serve.handle import RayServeHandle from ray.serve.utils import block_until_http_ready, retry_actor_failures from ray.serve.exceptions import RayServeException from ray.serve.config import BackendConfig, ReplicaConfig -from ray.serve.policy import RoutePolicy from ray.serve.router import Query from ray.serve.request_params import RequestMetadata from ray.serve.metric import InMemoryExporter @@ -69,8 +68,6 @@ def init(blocking=False, "object_store_memory": int(1e8), "num_cpus": max(cpu_count(), 8) }, - queueing_policy=RoutePolicy.Random, - policy_kwargs={}, metric_exporter=InMemoryExporter): """Initialize a serve cluster. @@ -90,9 +87,6 @@ def init(blocking=False, ray_init_kwargs (dict): Argument passed to ray.init, if there is no ray connection. Default to {"object_store_memory": int(1e8)} for performance stability reason - queueing_policy(RoutePolicy): Define the queueing policy for selecting - the backend for a service. 
(Default: RoutePolicy.Random) - policy_kwargs: Arguments required to instantiate a queueing policy metric_exporter(ExporterInterface): The class aggregates metrics from all RayServe actors and optionally export them to external services. RayServe has two options built in: InMemoryExporter and @@ -132,8 +126,7 @@ def init(blocking=False, detached=True, name=SERVE_MASTER_NAME, max_restarts=-1, - ).remote(queueing_policy.value, policy_kwargs, start_server, http_node_id, - http_host, http_port, metric_exporter) + ).remote(start_server, http_node_id, http_host, http_port, metric_exporter) if start_server and blocking: block_until_http_ready("http://{}:{}/-/routes".format( diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index df9b08fe9..4b886eb56 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -2,7 +2,6 @@ import ray from ray import serve from ray.serve.context import TaskContext from ray.serve.exceptions import RayServeException -from ray.serve.constants import DEFAULT_HTTP_ADDRESS from ray.serve.request_params import RequestMetadata @@ -35,6 +34,7 @@ class RayServeHandle: relative_slo_ms=None, absolute_slo_ms=None, method_name=None, + shard_key=None, ): self.router_handle = router_handle self.endpoint_name = endpoint_name @@ -45,6 +45,7 @@ class RayServeHandle: self.relative_slo_ms = self._check_slo_ms(relative_slo_ms) self.absolute_slo_ms = self._check_slo_ms(absolute_slo_ms) self.method_name = method_name + self.shard_key = shard_key def _check_slo_ms(self, slo_value): if slo_value is not None: @@ -75,12 +76,14 @@ class RayServeHandle: self.relative_slo_ms, self.absolute_slo_ms, call_method=method_name, + shard_key=self.shard_key, ) return self.router_handle.enqueue_request.remote( request_in_object, **kwargs) def options(self, method_name=None, + shard_key=None, relative_slo_ms=None, absolute_slo_ms=None): # If both the slo's are None then then we use a high default @@ -95,48 +98,30 @@ class RayServeHandle: if 
method_name is None and self.method_name is not None: method_name = self.method_name + if shard_key is None and self.shard_key is not None: + shard_key = self.shard_key + return RayServeHandle( self.router_handle, self.endpoint_name, relative_slo_ms, absolute_slo_ms, method_name=method_name, + shard_key=shard_key, ) - def get_http_endpoint(self): - return DEFAULT_HTTP_ADDRESS - def get_traffic_policy(self): master_actor = serve.api._get_master_actor() return ray.get( master_actor.get_traffic_policy.remote(self.endpoint_name)) - def _ensure_backend_unique(self, backend_tag=None): - traffic_policy = self.get_traffic_policy() - if backend_tag is None: - assert len(traffic_policy) == 1, ( - "Multiple backends detected. " - "Please pass in backend_tag=... argument to specify backend.") - backends = set(traffic_policy.keys()) - return backends.pop() - else: - assert (backend_tag in traffic_policy - ), "Backend {} not found in avaiable backends: {}.".format( - backend_tag, list(traffic_policy.keys())) - return backend_tag - def __repr__(self): return """ RayServeHandle( Endpoint="{endpoint_name}", - URL="{http_endpoint}/{endpoint_name}", Traffic={traffic_policy} ) """.format( endpoint_name=self.endpoint_name, - http_endpoint=self.get_http_endpoint(), traffic_policy=self.get_traffic_policy(), ) - - # TODO(simon): a convenience function that dumps equivalent requests - # code for a given call. 
diff --git a/python/ray/serve/http_proxy.py b/python/ray/serve/http_proxy.py index f3ebffaac..2975ebfd8 100644 --- a/python/ray/serve/http_proxy.py +++ b/python/ray/serve/http_proxy.py @@ -168,7 +168,9 @@ class HTTPProxy: TaskContext.Web, relative_slo_ms=relative_slo_ms, absolute_slo_ms=absolute_slo_ms, - call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__")) + call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__"), + shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), None), + ) retries = 0 while retries <= MAX_ACTOR_DEAD_RETRIES: diff --git a/python/ray/serve/master.py b/python/ray/serve/master.py index 47eb17a8c..f680b57a6 100644 --- a/python/ray/serve/master.py +++ b/python/ray/serve/master.py @@ -49,8 +49,7 @@ class ServeMaster: requires all implementations here to be idempotent. """ - async def __init__(self, router_policy, router_policy_kwargs, - start_http_proxy, http_node_id, http_proxy_host, + async def __init__(self, start_http_proxy, http_node_id, http_proxy_host, http_proxy_port, metric_exporter_class): # Used to read/write checkpoints. # TODO(edoakes): namespace the master actor and its checkpoints. @@ -89,7 +88,7 @@ class ServeMaster: # If starting the actor for the first time, starts up the other system # components. If recovering, fetches their actor handles. self._get_or_start_metric_exporter(metric_exporter_class) - self._get_or_start_router(router_policy, router_policy_kwargs) + self._get_or_start_router() if start_http_proxy: self._get_or_start_http_proxy(http_node_id, http_proxy_host, http_proxy_port) @@ -114,7 +113,7 @@ class ServeMaster: asyncio.get_event_loop().create_task( self._recover_from_checkpoint(checkpoint)) - def _get_or_start_router(self, policy, policy_kwargs): + def _get_or_start_router(self): """Get the router belonging to this serve cluster. If the router does not already exist, it will be started. 
@@ -128,7 +127,7 @@ class ServeMaster: detached=True, name=SERVE_ROUTER_NAME, max_concurrency=ASYNC_CONCURRENCY, - max_restarts=-1).remote(policy, policy_kwargs) + max_restarts=-1).remote() def get_router(self): """Returns a handle to the router managed by this actor.""" diff --git a/python/ray/serve/policy.py b/python/ray/serve/policy.py index 2ddebf03e..f1df1ef83 100644 --- a/python/ray/serve/policy.py +++ b/python/ray/serve/policy.py @@ -1,13 +1,12 @@ from abc import ABCMeta, abstractmethod -from enum import Enum -import itertools +from hashlib import sha256 import numpy as np from ray.serve.utils import logger -class RoutingPolicy: +class EndpointPolicy: """Defines the interface for a routing policy for a single endpoint. To add a new routing policy, a class should be defined that provides this @@ -39,15 +38,18 @@ class RoutingPolicy: return assigned_backends -class RandomPolicy(RoutingPolicy): +class RandomEndpointPolicy(EndpointPolicy): """ A stateless policy that makes a weighted random decision to map each query to a backend using the specified weights. + + If a shard key is provided in a query, the weighted random selection will + be made deterministically based on the hash of the shard key. 
""" def __init__(self, traffic_dict): - self.backend_names = list(traffic_dict.keys()) - self.backend_weights = list(traffic_dict.values()) + self.backend_names, self.backend_weights = zip( + *sorted(traffic_dict.items())) async def flush(self, endpoint_queue, backend_queues): if len(self.backend_names) == 0: @@ -56,113 +58,19 @@ class RandomPolicy(RoutingPolicy): assigned_backends = set() while endpoint_queue.qsize(): - chosen_backend = np.random.choice( + query = await endpoint_queue.get() + if query.shard_key is None: + rstate = np.random + else: + sha256_seed = sha256(query.shard_key.encode("utf-8")) + seed = np.frombuffer(sha256_seed.digest(), dtype=np.uint32) + rstate = np.random.RandomState(seed) + + chosen_backend = rstate.choice( self.backend_names, replace=False, p=self.backend_weights).squeeze() + assigned_backends.add(chosen_backend) - backend_queues[chosen_backend].add(await endpoint_queue.get()) + backend_queues[chosen_backend].add(query) return assigned_backends - - -class RoundRobinPolicy(RoutingPolicy): - """A stateful policy that assigns queries in round-robin order.""" - - def __init__(self, traffic_dict): - # NOTE(edoakes): the backend weights are not used. - self.backend_names = list(traffic_dict.keys()) - # Saves the information about last assigned backend for every endpoint. - self.round_robin_iterator = itertools.cycle(self.backend_names) - - async def flush(self, endpoint_queue, backend_queues): - if len(self.backend_names) == 0: - logger.info("No backends to assign traffic to.") - return set() - - assigned_backends = set() - while endpoint_queue.qsize(): - chosen_backend = next(self.round_robin_iterator) - assigned_backends.add(chosen_backend) - backend_queues[chosen_backend].add(await endpoint_queue.get()) - - return assigned_backends - - -class PowerOfTwoPolicy(RoutingPolicy): - """A stateless policy that uses the "power of two" policy. - - For each query, two random backends are chosen. 
Of those two, the query is - assigned to the backend whose queue length is shorter. - """ - - def __init__(self, traffic_dict): - self.backend_names = list(traffic_dict.keys()) - self.backend_weights = list(traffic_dict.values()) - - async def flush(self, endpoint_queue, backend_queues): - if len(self.backend_names) == 0: - logger.info("No backends to assign traffic to.") - return set() - - assigned_backends = set() - while endpoint_queue.qsize(): - if len(self.backend_names) >= 2: - backend1, backend2 = np.random.choice( - self.backend_names, - 2, - replace=False, - p=self.backend_weights) - - # Choose the backend that has a shorter queue. - if (len(backend_queues[backend1]) <= len( - backend_queues[backend2])): - chosen_backend = backend1 - else: - chosen_backend = backend2 - else: - chosen_backend = np.random.choice( - self.backend_names, replace=False, - p=self.backend_weights).squeeze() - backend_queues[chosen_backend].add(await endpoint_queue.get()) - assigned_backends.add(chosen_backend) - - return assigned_backends - - -class FixedPackingPolicy(RoutingPolicy): - """A stateful policy that uses a "fixed packing" policy. - - The policy round-robins groups of packing_num queries across backends. For - example, the first packing_num queries are handled by backend-1, then the - next packing_num queries are handled by backend-2, etc. - """ - - def __init__(self, traffic_dict, packing_num=3): - # NOTE(edoakes): the backend weights are not used. 
- self.backend_names = list(traffic_dict.keys()) - self.fixed_packing_iterator = itertools.cycle( - itertools.chain.from_iterable( - itertools.repeat(x, self.packing_num) - for x in self.backend_names)) - self.packing_num = packing_num - - async def flush(self, endpoint_queue, backend_queues): - if len(self.backend_names) == 0: - logger.info("No backends to assign traffic to.") - return set() - - assigned_backends = set() - while endpoint_queue.qsize(): - chosen_backend = next(self.fixed_packing_iterator) - backend_queues[chosen_backend].add(await endpoint_queue.get()) - assigned_backends.add(chosen_backend) - - return assigned_backends - - -class RoutePolicy(Enum): - """All builtin routing policies.""" - Random = RandomPolicy - RoundRobin = RoundRobinPolicy - PowerOfTwo = PowerOfTwoPolicy - FixedPacking = FixedPackingPolicy diff --git a/python/ray/serve/request_params.py b/python/ray/serve/request_params.py index 156c93903..6c5a0c61d 100644 --- a/python/ray/serve/request_params.py +++ b/python/ray/serve/request_params.py @@ -21,13 +21,15 @@ class RequestMetadata: request_context, relative_slo_ms=None, absolute_slo_ms=None, - call_method="__call__"): + call_method="__call__", + shard_key=None): self.endpoint = endpoint self.request_context = request_context self.relative_slo_ms = relative_slo_ms self.absolute_slo_ms = absolute_slo_ms self.call_method = call_method + self.shard_key = shard_key def adjust_relative_slo_ms(self) -> float: """Normalize the input latency objective to absolute timestamp. 
diff --git a/python/ray/serve/router.py b/python/ray/serve/router.py index ae94250ec..b07ff3dbe 100644 --- a/python/ray/serve/router.py +++ b/python/ray/serve/router.py @@ -16,6 +16,7 @@ import ray import ray.cloudpickle as pickle from ray.exceptions import RayTaskError from ray.serve.metric import MetricClient +from ray.serve.policy import RandomEndpointPolicy from ray.serve.utils import logger, retry_actor_failures @@ -26,6 +27,7 @@ class Query: request_context, request_slo_ms, call_method="__call__", + shard_key=None, async_future=None): self.request_args = request_args self.request_kwargs = request_kwargs @@ -38,6 +40,7 @@ class Query: self.request_slo_ms = request_slo_ms self.call_method = call_method + self.shard_key = shard_key def ray_serialize(self): # NOTE: this method is needed because Query need to be serialized and @@ -103,7 +106,7 @@ class Router: 3. When there is only 1 backend ready, we will only use that backend. """ - async def __init__(self, policy, policy_kwargs): + async def __init__(self): # Note: Several queues are used in the router # - When a request come in, it's placed inside its corresponding # endpoint_queue. @@ -114,11 +117,6 @@ class Router: # handles are dequed during the second stage of flush operation, # which assign queries in buffer_queue to actor handle. - # policy.RoutePolicy. - self.policy = policy - # kwargs to pass into the policy when it's constructed. 
- self.policy_kwargs = policy_kwargs - # -- Queues -- # # endpoint_name -> request queue @@ -211,6 +209,7 @@ class Router: request_context, request_slo_ms, call_method=request_meta.call_method, + shard_key=request_meta.shard_key, async_future=asyncio.get_event_loop().create_future()) await self.endpoint_queues[endpoint].put(query) async with self.flush_lock: @@ -268,8 +267,7 @@ class Router: logger.debug("Setting traffic for endpoint %s to %s", endpoint, traffic_dict) async with self.flush_lock: - self.traffic[endpoint] = self.policy(traffic_dict, - **self.policy_kwargs) + self.traffic[endpoint] = RandomEndpointPolicy(traffic_dict) await self.flush_endpoint_queue(endpoint) async def remove_endpoint(self, endpoint): diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index 913012ebc..32a9332ba 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -2,8 +2,9 @@ import time import pytest import requests -from ray import serve import ray +from ray import serve +from ray.serve.utils import get_random_letters def test_e2e(serve_instance): @@ -302,3 +303,42 @@ def test_delete_endpoint(serve_instance, route): else: handle = serve.get_handle(endpoint_name) assert ray.get(handle.remote()) == "hello" + + +@pytest.mark.parametrize("route", [None, "/shard"]) +def test_shard_key(serve_instance, route): + serve.create_endpoint("endpoint", route=route) + + # Create five backends that return different integers. 
+ num_backends = 5 + traffic_dict = {} + for i in range(num_backends): + + def function(): + return i + + backend_name = "backend-split-" + str(i) + traffic_dict[backend_name] = 1.0 / num_backends + serve.create_backend(backend_name, function) + + serve.set_traffic("endpoint", traffic_dict) + + def do_request(shard_key): + if route is not None: + url = "http://127.0.0.1:8000" + route + headers = {"X-SERVE-SHARD-KEY": shard_key} + result = requests.get(url, headers=headers).text + else: + handle = serve.get_handle("endpoint").options(shard_key=shard_key) + result = ray.get(handle.options(shard_key=shard_key).remote()) + return result + + # Send requests with different shard keys and log the backends they go to. + shard_keys = [get_random_letters() for _ in range(20)] + results = {} + for shard_key in shard_keys: + results[shard_key] = do_request(shard_key) + + # Check that the shard keys are mapped to the same backends. + for shard_key in shard_keys: + assert do_request(shard_key) == results[shard_key] diff --git a/python/ray/serve/tests/test_backend_worker.py b/python/ray/serve/tests/test_backend_worker.py index 97b701938..e20cbc3a2 100644 --- a/python/ray/serve/tests/test_backend_worker.py +++ b/python/ray/serve/tests/test_backend_worker.py @@ -6,7 +6,6 @@ import numpy as np import ray from ray import serve import ray.serve.context as context -from ray.serve.policy import RoundRobinPolicy from ray.serve.backend_worker import create_backend_worker, wrap_to_ray_error from ray.serve.request_params import RequestMetadata from ray.serve.router import Router @@ -43,7 +42,7 @@ async def test_runner_wraps_error(): async def test_runner_actor(serve_instance): - q = ray.remote(Router).remote(RoundRobinPolicy, {}) + q = ray.remote(Router).remote() def echo(flask_request, i=None): return i @@ -64,7 +63,7 @@ async def test_runner_actor(serve_instance): async def test_ray_serve_mixin(serve_instance): - q = ray.remote(Router).remote(RoundRobinPolicy, {}) + q = 
ray.remote(Router).remote() CONSUMER_NAME = "runner-cls" PRODUCER_NAME = "prod-cls" @@ -89,7 +88,7 @@ async def test_ray_serve_mixin(serve_instance): async def test_task_runner_check_context(serve_instance): - q = ray.remote(Router).remote(RoundRobinPolicy, {}) + q = ray.remote(Router).remote() def echo(flask_request, i=None): # Accessing the flask_request without web context should throw. @@ -110,7 +109,7 @@ async def test_task_runner_check_context(serve_instance): async def test_task_runner_custom_method_single(serve_instance): - q = ray.remote(Router).remote(RoundRobinPolicy, {}) + q = ray.remote(Router).remote() class NonBatcher: def a(self, _): @@ -144,7 +143,7 @@ async def test_task_runner_custom_method_single(serve_instance): async def test_task_runner_custom_method_batch(serve_instance): - q = ray.remote(Router).remote(RoundRobinPolicy, {}) + q = ray.remote(Router).remote() @serve.accept_batch class Batcher: diff --git a/python/ray/serve/tests/test_router.py b/python/ray/serve/tests/test_router.py index 444a7fa90..ef7d850ff 100644 --- a/python/ray/serve/tests/test_router.py +++ b/python/ray/serve/tests/test_router.py @@ -1,12 +1,12 @@ import asyncio +from collections import defaultdict import pytest import ray -from ray.serve.policy import (RandomPolicy, RoundRobinPolicy, PowerOfTwoPolicy, - FixedPackingPolicy) from ray.serve.router import Router from ray.serve.request_params import RequestMetadata +from ray.serve.utils import get_random_letters pytestmark = pytest.mark.asyncio @@ -29,6 +29,9 @@ def mock_task_runner(): def get_all_calls(self): return self.queries + def clear_calls(self): + self.queries = [] + def ready(self): pass @@ -41,7 +44,7 @@ def task_runner_mock_actor(): async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor): - q = ray.remote(Router).remote(RandomPolicy, {}) + q = ray.remote(Router).remote() q.set_traffic.remote("svc", {"backend-single-prod": 1.0}) q.add_new_worker.remote("backend-single-prod", "replica-1", 
task_runner_mock_actor) @@ -57,7 +60,7 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor): async def test_slo(serve_instance, task_runner_mock_actor): - q = ray.remote(Router).remote(RandomPolicy, {}) + q = ray.remote(Router).remote() await q.set_traffic.remote("svc", {"backend-slo": 1.0}) all_request_sent = [] @@ -81,7 +84,7 @@ async def test_slo(serve_instance, task_runner_mock_actor): async def test_alter_backend(serve_instance, task_runner_mock_actor): - q = ray.remote(Router).remote(RandomPolicy, {}) + q = ray.remote(Router).remote() await q.set_traffic.remote("svc", {"backend-alter": 1}) await q.add_new_worker.remote("backend-alter", "replica-1", @@ -99,7 +102,7 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor): async def test_split_traffic_random(serve_instance, task_runner_mock_actor): - q = ray.remote(Router).remote(RandomPolicy, {}) + q = ray.remote(Router).remote() await q.set_traffic.remote("svc", { "backend-split": 0.5, @@ -121,88 +124,51 @@ async def test_split_traffic_random(serve_instance, task_runner_mock_actor): assert [g.request_args[0] for g in got_work] == [1, 1] -async def test_round_robin(serve_instance, task_runner_mock_actor): - q = ray.remote(Router).remote(RoundRobinPolicy, {}) - - await q.set_traffic.remote("svc", {"backend-rr": 0.5, "backend-rr-2": 0.5}) - runner_1, runner_2 = [mock_task_runner() for _ in range(2)] - - # NOTE: this is the only difference between the - # test_split_traffic_random and test_round_robin - await q.add_new_worker.remote("backend-rr", "replica-1", runner_1) - await q.add_new_worker.remote("backend-rr-2", "replica-1", runner_2) - - for _ in range(20): - await q.enqueue_request.remote(RequestMetadata("svc", None), 1) - - got_work = [ - await runner.get_recent_call.remote() - for runner in (runner_1, runner_2) - ] - assert [g.request_args[0] for g in got_work] == [1, 1] - - -async def test_fixed_packing(serve_instance): - packing_num = 4 - q = 
ray.remote(Router).remote(FixedPackingPolicy, - {"packing_num": packing_num}) - await q.set_traffic.remote("svc", { - "backend-fixed": 0.5, - "backend-fixed-2": 0.5 - }) - - runner_1, runner_2 = (mock_task_runner() for _ in range(2)) - # both the backends will get equal number of queries - # as it is packed round robin - await q.add_new_worker.remote("backend-fixed", "replica-1", runner_1) - await q.add_new_worker.remote("backend-fixed-2", "replica-1", runner_2) - - for backend, runner in zip(["1", "2"], [runner_1, runner_2]): - for _ in range(packing_num): - input_value = "should-go-to-backend-{}".format(backend) - await q.enqueue_request.remote( - RequestMetadata("svc", None), input_value) - all_calls = await runner.get_all_calls.remote() - for call in all_calls: - assert call.request_args[0] == input_value - - -async def test_power_of_two_choices(serve_instance): - q = ray.remote(Router).remote(PowerOfTwoPolicy, {}) - enqueue_futures = [] - - # First, fill the queue for backend-1 with 3 requests - await q.set_traffic.remote("svc", {"backend-pow2": 1.0}) - for _ in range(3): - future = q.enqueue_request.remote(RequestMetadata("svc", None), "1") - enqueue_futures.append(future) - - # Then, add a new backend, this backend should be filled next - await q.set_traffic.remote("svc", { - "backend-pow2": 0.5, - "backend-pow2-2": 0.5 - }) - for _ in range(2): - future = q.enqueue_request.remote(RequestMetadata("svc", None), "2") - enqueue_futures.append(future) - - runner_1, runner_2 = (mock_task_runner() for _ in range(2)) - await q.add_new_worker.remote("backend-pow2", "replica-1", runner_1) - await q.add_new_worker.remote("backend-pow2-2", "replica-1", runner_2) - - await asyncio.gather(*enqueue_futures) - - assert len(await runner_1.get_all_calls.remote()) == 3 - assert len(await runner_2.get_all_calls.remote()) == 2 - - async def test_queue_remove_replicas(serve_instance): class TestRouter(Router): def worker_queue_size(self, backend): return 
self.worker_queues["backend-remove"].qsize() temp_actor = mock_task_runner() - q = ray.remote(TestRouter).remote(RandomPolicy, {}) + q = ray.remote(TestRouter).remote() await q.add_new_worker.remote("backend-remove", "replica-1", temp_actor) await q.remove_worker.remote("backend-remove", "replica-1") assert ray.get(q.worker_queue_size.remote("backend")) == 0 + + +async def test_shard_key(serve_instance, task_runner_mock_actor): + q = ray.remote(Router).remote() + + num_backends = 5 + traffic_dict = {} + runners = [mock_task_runner() for _ in range(num_backends)] + for i, runner in enumerate(runners): + backend_name = "backend-split-" + str(i) + traffic_dict[backend_name] = 1.0 / num_backends + await q.add_new_worker.remote(backend_name, "replica-1", runner) + await q.set_traffic.remote("svc", traffic_dict) + + # Generate random shard keys and send one request for each. + shard_keys = [get_random_letters() for _ in range(100)] + for shard_key in shard_keys: + await q.enqueue_request.remote( + RequestMetadata("svc", None, shard_key=shard_key), shard_key) + + # Log the shard keys that were assigned to each backend. + runner_shard_keys = defaultdict(set) + for i, runner in enumerate(runners): + calls = await runner.get_all_calls.remote() + for call in calls: + runner_shard_keys[i].add(call.request_args[0]) + await runner.clear_calls.remote() + + # Send queries with the same shard keys a second time. + for shard_key in shard_keys: + await q.enqueue_request.remote( + RequestMetadata("svc", None, shard_key=shard_key), shard_key) + + # Check that the requests were all mapped to the same backends. + for i, runner in enumerate(runners): + calls = await runner.get_all_calls.remote() + for call in calls: + assert call.request_args[0] in runner_shard_keys[i]