[serve] Add basic session affinity via shard key (#8449)

This commit is contained in:
Edward Oakes
2020-05-15 16:18:52 -05:00
committed by GitHub
parent c9435cad43
commit ef498e8aa5
12 changed files with 157 additions and 249 deletions
+1 -3
View File
@@ -1,4 +1,3 @@
from ray.serve.policy import RoutePolicy
from ray.serve.api import (init, create_backend, delete_backend,
create_endpoint, delete_endpoint, set_traffic,
get_handle, stat, update_backend_config,
@@ -7,6 +6,5 @@ from ray.serve.api import (init, create_backend, delete_backend,
__all__ = [
"init", "create_backend", "delete_backend", "create_endpoint",
"delete_endpoint", "set_traffic", "get_handle", "stat",
"update_backend_config", "get_backend_config", "RoutePolicy",
"accept_batch"
"update_backend_config", "get_backend_config", "accept_batch"
]
+1 -8
View File
@@ -10,7 +10,6 @@ from ray.serve.handle import RayServeHandle
from ray.serve.utils import block_until_http_ready, retry_actor_failures
from ray.serve.exceptions import RayServeException
from ray.serve.config import BackendConfig, ReplicaConfig
from ray.serve.policy import RoutePolicy
from ray.serve.router import Query
from ray.serve.request_params import RequestMetadata
from ray.serve.metric import InMemoryExporter
@@ -69,8 +68,6 @@ def init(blocking=False,
"object_store_memory": int(1e8),
"num_cpus": max(cpu_count(), 8)
},
queueing_policy=RoutePolicy.Random,
policy_kwargs={},
metric_exporter=InMemoryExporter):
"""Initialize a serve cluster.
@@ -90,9 +87,6 @@ def init(blocking=False,
ray_init_kwargs (dict): Argument passed to ray.init, if there is no ray
connection. Default to {"object_store_memory": int(1e8)} for
performance stability reason
queueing_policy(RoutePolicy): Define the queueing policy for selecting
the backend for a service. (Default: RoutePolicy.Random)
policy_kwargs: Arguments required to instantiate a queueing policy
metric_exporter(ExporterInterface): The class aggregates metrics from
all RayServe actors and optionally export them to external
services. RayServe has two options built in: InMemoryExporter and
@@ -132,8 +126,7 @@ def init(blocking=False,
detached=True,
name=SERVE_MASTER_NAME,
max_restarts=-1,
).remote(queueing_policy.value, policy_kwargs, start_server, http_node_id,
http_host, http_port, metric_exporter)
).remote(start_server, http_node_id, http_host, http_port, metric_exporter)
if start_server and blocking:
block_until_http_ready("http://{}:{}/-/routes".format(
+8 -23
View File
@@ -2,7 +2,6 @@ import ray
from ray import serve
from ray.serve.context import TaskContext
from ray.serve.exceptions import RayServeException
from ray.serve.constants import DEFAULT_HTTP_ADDRESS
from ray.serve.request_params import RequestMetadata
@@ -35,6 +34,7 @@ class RayServeHandle:
relative_slo_ms=None,
absolute_slo_ms=None,
method_name=None,
shard_key=None,
):
self.router_handle = router_handle
self.endpoint_name = endpoint_name
@@ -45,6 +45,7 @@ class RayServeHandle:
self.relative_slo_ms = self._check_slo_ms(relative_slo_ms)
self.absolute_slo_ms = self._check_slo_ms(absolute_slo_ms)
self.method_name = method_name
self.shard_key = shard_key
def _check_slo_ms(self, slo_value):
if slo_value is not None:
@@ -75,12 +76,14 @@ class RayServeHandle:
self.relative_slo_ms,
self.absolute_slo_ms,
call_method=method_name,
shard_key=self.shard_key,
)
return self.router_handle.enqueue_request.remote(
request_in_object, **kwargs)
def options(self,
method_name=None,
shard_key=None,
relative_slo_ms=None,
absolute_slo_ms=None):
# If both the slo's are None then then we use a high default
@@ -95,48 +98,30 @@ class RayServeHandle:
if method_name is None and self.method_name is not None:
method_name = self.method_name
if shard_key is None and self.shard_key is not None:
shard_key = self.shard_key
return RayServeHandle(
self.router_handle,
self.endpoint_name,
relative_slo_ms,
absolute_slo_ms,
method_name=method_name,
shard_key=shard_key,
)
def get_http_endpoint(self):
return DEFAULT_HTTP_ADDRESS
def get_traffic_policy(self):
master_actor = serve.api._get_master_actor()
return ray.get(
master_actor.get_traffic_policy.remote(self.endpoint_name))
def _ensure_backend_unique(self, backend_tag=None):
traffic_policy = self.get_traffic_policy()
if backend_tag is None:
assert len(traffic_policy) == 1, (
"Multiple backends detected. "
"Please pass in backend_tag=... argument to specify backend.")
backends = set(traffic_policy.keys())
return backends.pop()
else:
assert (backend_tag in traffic_policy
), "Backend {} not found in avaiable backends: {}.".format(
backend_tag, list(traffic_policy.keys()))
return backend_tag
def __repr__(self):
return """
RayServeHandle(
Endpoint="{endpoint_name}",
URL="{http_endpoint}/{endpoint_name}",
Traffic={traffic_policy}
)
""".format(
endpoint_name=self.endpoint_name,
http_endpoint=self.get_http_endpoint(),
traffic_policy=self.get_traffic_policy(),
)
# TODO(simon): a convenience function that dumps equivalent requests
# code for a given call.
+3 -1
View File
@@ -168,7 +168,9 @@ class HTTPProxy:
TaskContext.Web,
relative_slo_ms=relative_slo_ms,
absolute_slo_ms=absolute_slo_ms,
call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__"))
call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__"),
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), None),
)
retries = 0
while retries <= MAX_ACTOR_DEAD_RETRIES:
+4 -5
View File
@@ -49,8 +49,7 @@ class ServeMaster:
requires all implementations here to be idempotent.
"""
async def __init__(self, router_policy, router_policy_kwargs,
start_http_proxy, http_node_id, http_proxy_host,
async def __init__(self, start_http_proxy, http_node_id, http_proxy_host,
http_proxy_port, metric_exporter_class):
# Used to read/write checkpoints.
# TODO(edoakes): namespace the master actor and its checkpoints.
@@ -89,7 +88,7 @@ class ServeMaster:
# If starting the actor for the first time, starts up the other system
# components. If recovering, fetches their actor handles.
self._get_or_start_metric_exporter(metric_exporter_class)
self._get_or_start_router(router_policy, router_policy_kwargs)
self._get_or_start_router()
if start_http_proxy:
self._get_or_start_http_proxy(http_node_id, http_proxy_host,
http_proxy_port)
@@ -114,7 +113,7 @@ class ServeMaster:
asyncio.get_event_loop().create_task(
self._recover_from_checkpoint(checkpoint))
def _get_or_start_router(self, policy, policy_kwargs):
def _get_or_start_router(self):
"""Get the router belonging to this serve cluster.
If the router does not already exist, it will be started.
@@ -128,7 +127,7 @@ class ServeMaster:
detached=True,
name=SERVE_ROUTER_NAME,
max_concurrency=ASYNC_CONCURRENCY,
max_restarts=-1).remote(policy, policy_kwargs)
max_restarts=-1).remote()
def get_router(self):
"""Returns a handle to the router managed by this actor."""
+19 -111
View File
@@ -1,13 +1,12 @@
from abc import ABCMeta, abstractmethod
from enum import Enum
import itertools
from hashlib import sha256
import numpy as np
from ray.serve.utils import logger
class RoutingPolicy:
class EndpointPolicy:
"""Defines the interface for a routing policy for a single endpoint.
To add a new routing policy, a class should be defined that provides this
@@ -39,15 +38,18 @@ class RoutingPolicy:
return assigned_backends
class RandomPolicy(RoutingPolicy):
class RandomEndpointPolicy(EndpointPolicy):
"""
A stateless policy that makes a weighted random decision to map each query
to a backend using the specified weights.
If a shard key is provided in a query, the weighted random selection will
be made deterministically based on the hash of the shard key.
"""
def __init__(self, traffic_dict):
self.backend_names = list(traffic_dict.keys())
self.backend_weights = list(traffic_dict.values())
self.backend_names, self.backend_weights = zip(
*sorted(traffic_dict.items()))
async def flush(self, endpoint_queue, backend_queues):
if len(self.backend_names) == 0:
@@ -56,113 +58,19 @@ class RandomPolicy(RoutingPolicy):
assigned_backends = set()
while endpoint_queue.qsize():
chosen_backend = np.random.choice(
query = await endpoint_queue.get()
if query.shard_key is None:
rstate = np.random
else:
sha256_seed = sha256(query.shard_key.encode("utf-8"))
seed = np.frombuffer(sha256_seed.digest(), dtype=np.uint32)
rstate = np.random.RandomState(seed)
chosen_backend = rstate.choice(
self.backend_names, replace=False,
p=self.backend_weights).squeeze()
assigned_backends.add(chosen_backend)
backend_queues[chosen_backend].add(await endpoint_queue.get())
backend_queues[chosen_backend].add(query)
return assigned_backends
class RoundRobinPolicy(RoutingPolicy):
"""A stateful policy that assigns queries in round-robin order."""
def __init__(self, traffic_dict):
# NOTE(edoakes): the backend weights are not used.
self.backend_names = list(traffic_dict.keys())
# Saves the information about last assigned backend for every endpoint.
self.round_robin_iterator = itertools.cycle(self.backend_names)
async def flush(self, endpoint_queue, backend_queues):
if len(self.backend_names) == 0:
logger.info("No backends to assign traffic to.")
return set()
assigned_backends = set()
while endpoint_queue.qsize():
chosen_backend = next(self.round_robin_iterator)
assigned_backends.add(chosen_backend)
backend_queues[chosen_backend].add(await endpoint_queue.get())
return assigned_backends
class PowerOfTwoPolicy(RoutingPolicy):
"""A stateless policy that uses the "power of two" policy.
For each query, two random backends are chosen. Of those two, the query is
assigned to the backend whose queue length is shorter.
"""
def __init__(self, traffic_dict):
self.backend_names = list(traffic_dict.keys())
self.backend_weights = list(traffic_dict.values())
async def flush(self, endpoint_queue, backend_queues):
if len(self.backend_names) == 0:
logger.info("No backends to assign traffic to.")
return set()
assigned_backends = set()
while endpoint_queue.qsize():
if len(self.backend_names) >= 2:
backend1, backend2 = np.random.choice(
self.backend_names,
2,
replace=False,
p=self.backend_weights)
# Choose the backend that has a shorter queue.
if (len(backend_queues[backend1]) <= len(
backend_queues[backend2])):
chosen_backend = backend1
else:
chosen_backend = backend2
else:
chosen_backend = np.random.choice(
self.backend_names, replace=False,
p=self.backend_weights).squeeze()
backend_queues[chosen_backend].add(await endpoint_queue.get())
assigned_backends.add(chosen_backend)
return assigned_backends
class FixedPackingPolicy(RoutingPolicy):
"""A stateful policy that uses a "fixed packing" policy.
The policy round-robins groups of packing_num queries across backends. For
example, the first packing_num queries are handled by backend-1, then the
next packing_num queries are handled by backend-2, etc.
"""
def __init__(self, traffic_dict, packing_num=3):
# NOTE(edoakes): the backend weights are not used.
self.backend_names = list(traffic_dict.keys())
self.fixed_packing_iterator = itertools.cycle(
itertools.chain.from_iterable(
itertools.repeat(x, self.packing_num)
for x in self.backend_names))
self.packing_num = packing_num
async def flush(self, endpoint_queue, backend_queues):
if len(self.backend_names) == 0:
logger.info("No backends to assign traffic to.")
return set()
assigned_backends = set()
while endpoint_queue.qsize():
chosen_backend = next(self.fixed_packing_iterator)
backend_queues[chosen_backend].add(await endpoint_queue.get())
assigned_backends.add(chosen_backend)
return assigned_backends
class RoutePolicy(Enum):
"""All builtin routing policies."""
Random = RandomPolicy
RoundRobin = RoundRobinPolicy
PowerOfTwo = PowerOfTwoPolicy
FixedPacking = FixedPackingPolicy
+3 -1
View File
@@ -21,13 +21,15 @@ class RequestMetadata:
request_context,
relative_slo_ms=None,
absolute_slo_ms=None,
call_method="__call__"):
call_method="__call__",
shard_key=None):
self.endpoint = endpoint
self.request_context = request_context
self.relative_slo_ms = relative_slo_ms
self.absolute_slo_ms = absolute_slo_ms
self.call_method = call_method
self.shard_key = shard_key
def adjust_relative_slo_ms(self) -> float:
"""Normalize the input latency objective to absolute timestamp.
+6 -8
View File
@@ -16,6 +16,7 @@ import ray
import ray.cloudpickle as pickle
from ray.exceptions import RayTaskError
from ray.serve.metric import MetricClient
from ray.serve.policy import RandomEndpointPolicy
from ray.serve.utils import logger, retry_actor_failures
@@ -26,6 +27,7 @@ class Query:
request_context,
request_slo_ms,
call_method="__call__",
shard_key=None,
async_future=None):
self.request_args = request_args
self.request_kwargs = request_kwargs
@@ -38,6 +40,7 @@ class Query:
self.request_slo_ms = request_slo_ms
self.call_method = call_method
self.shard_key = shard_key
def ray_serialize(self):
# NOTE: this method is needed because Query need to be serialized and
@@ -103,7 +106,7 @@ class Router:
3. When there is only 1 backend ready, we will only use that backend.
"""
async def __init__(self, policy, policy_kwargs):
async def __init__(self):
# Note: Several queues are used in the router
# - When a request come in, it's placed inside its corresponding
# endpoint_queue.
@@ -114,11 +117,6 @@ class Router:
# handles are dequed during the second stage of flush operation,
# which assign queries in buffer_queue to actor handle.
# policy.RoutePolicy.
self.policy = policy
# kwargs to pass into the policy when it's constructed.
self.policy_kwargs = policy_kwargs
# -- Queues -- #
# endpoint_name -> request queue
@@ -211,6 +209,7 @@ class Router:
request_context,
request_slo_ms,
call_method=request_meta.call_method,
shard_key=request_meta.shard_key,
async_future=asyncio.get_event_loop().create_future())
await self.endpoint_queues[endpoint].put(query)
async with self.flush_lock:
@@ -268,8 +267,7 @@ class Router:
logger.debug("Setting traffic for endpoint %s to %s", endpoint,
traffic_dict)
async with self.flush_lock:
self.traffic[endpoint] = self.policy(traffic_dict,
**self.policy_kwargs)
self.traffic[endpoint] = RandomEndpointPolicy(traffic_dict)
await self.flush_endpoint_queue(endpoint)
async def remove_endpoint(self, endpoint):
+41 -1
View File
@@ -2,8 +2,9 @@ import time
import pytest
import requests
from ray import serve
import ray
from ray import serve
from ray.serve.utils import get_random_letters
def test_e2e(serve_instance):
@@ -302,3 +303,42 @@ def test_delete_endpoint(serve_instance, route):
else:
handle = serve.get_handle(endpoint_name)
assert ray.get(handle.remote()) == "hello"
@pytest.mark.parametrize("route", [None, "/shard"])
def test_shard_key(serve_instance, route):
serve.create_endpoint("endpoint", route=route)
# Create five backends that return different integers.
num_backends = 5
traffic_dict = {}
for i in range(num_backends):
def function():
return i
backend_name = "backend-split-" + str(i)
traffic_dict[backend_name] = 1.0 / num_backends
serve.create_backend(backend_name, function)
serve.set_traffic("endpoint", traffic_dict)
def do_request(shard_key):
if route is not None:
url = "http://127.0.0.1:8000" + route
headers = {"X-SERVE-SHARD-KEY": shard_key}
result = requests.get(url, headers=headers).text
else:
handle = serve.get_handle("endpoint").options(shard_key=shard_key)
result = ray.get(handle.options(shard_key=shard_key).remote())
return result
# Send requests with different shard keys and log the backends they go to.
shard_keys = [get_random_letters() for _ in range(20)]
results = {}
for shard_key in shard_keys:
results[shard_key] = do_request(shard_key)
# Check that the shard keys are mapped to the same backends.
for shard_key in shard_keys:
assert do_request(shard_key) == results[shard_key]
@@ -6,7 +6,6 @@ import numpy as np
import ray
from ray import serve
import ray.serve.context as context
from ray.serve.policy import RoundRobinPolicy
from ray.serve.backend_worker import create_backend_worker, wrap_to_ray_error
from ray.serve.request_params import RequestMetadata
from ray.serve.router import Router
@@ -43,7 +42,7 @@ async def test_runner_wraps_error():
async def test_runner_actor(serve_instance):
q = ray.remote(Router).remote(RoundRobinPolicy, {})
q = ray.remote(Router).remote()
def echo(flask_request, i=None):
return i
@@ -64,7 +63,7 @@ async def test_runner_actor(serve_instance):
async def test_ray_serve_mixin(serve_instance):
q = ray.remote(Router).remote(RoundRobinPolicy, {})
q = ray.remote(Router).remote()
CONSUMER_NAME = "runner-cls"
PRODUCER_NAME = "prod-cls"
@@ -89,7 +88,7 @@ async def test_ray_serve_mixin(serve_instance):
async def test_task_runner_check_context(serve_instance):
q = ray.remote(Router).remote(RoundRobinPolicy, {})
q = ray.remote(Router).remote()
def echo(flask_request, i=None):
# Accessing the flask_request without web context should throw.
@@ -110,7 +109,7 @@ async def test_task_runner_check_context(serve_instance):
async def test_task_runner_custom_method_single(serve_instance):
q = ray.remote(Router).remote(RoundRobinPolicy, {})
q = ray.remote(Router).remote()
class NonBatcher:
def a(self, _):
@@ -144,7 +143,7 @@ async def test_task_runner_custom_method_single(serve_instance):
async def test_task_runner_custom_method_batch(serve_instance):
q = ray.remote(Router).remote(RoundRobinPolicy, {})
q = ray.remote(Router).remote()
@serve.accept_batch
class Batcher:
+48 -82
View File
@@ -1,12 +1,12 @@
import asyncio
from collections import defaultdict
import pytest
import ray
from ray.serve.policy import (RandomPolicy, RoundRobinPolicy, PowerOfTwoPolicy,
FixedPackingPolicy)
from ray.serve.router import Router
from ray.serve.request_params import RequestMetadata
from ray.serve.utils import get_random_letters
pytestmark = pytest.mark.asyncio
@@ -29,6 +29,9 @@ def mock_task_runner():
def get_all_calls(self):
return self.queries
def clear_calls(self):
self.queries = []
def ready(self):
pass
@@ -41,7 +44,7 @@ def task_runner_mock_actor():
async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote(RandomPolicy, {})
q = ray.remote(Router).remote()
q.set_traffic.remote("svc", {"backend-single-prod": 1.0})
q.add_new_worker.remote("backend-single-prod", "replica-1",
task_runner_mock_actor)
@@ -57,7 +60,7 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
async def test_slo(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote(RandomPolicy, {})
q = ray.remote(Router).remote()
await q.set_traffic.remote("svc", {"backend-slo": 1.0})
all_request_sent = []
@@ -81,7 +84,7 @@ async def test_slo(serve_instance, task_runner_mock_actor):
async def test_alter_backend(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote(RandomPolicy, {})
q = ray.remote(Router).remote()
await q.set_traffic.remote("svc", {"backend-alter": 1})
await q.add_new_worker.remote("backend-alter", "replica-1",
@@ -99,7 +102,7 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor):
async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote(RandomPolicy, {})
q = ray.remote(Router).remote()
await q.set_traffic.remote("svc", {
"backend-split": 0.5,
@@ -121,88 +124,51 @@ async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
assert [g.request_args[0] for g in got_work] == [1, 1]
async def test_round_robin(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote(RoundRobinPolicy, {})
await q.set_traffic.remote("svc", {"backend-rr": 0.5, "backend-rr-2": 0.5})
runner_1, runner_2 = [mock_task_runner() for _ in range(2)]
# NOTE: this is the only difference between the
# test_split_traffic_random and test_round_robin
await q.add_new_worker.remote("backend-rr", "replica-1", runner_1)
await q.add_new_worker.remote("backend-rr-2", "replica-1", runner_2)
for _ in range(20):
await q.enqueue_request.remote(RequestMetadata("svc", None), 1)
got_work = [
await runner.get_recent_call.remote()
for runner in (runner_1, runner_2)
]
assert [g.request_args[0] for g in got_work] == [1, 1]
async def test_fixed_packing(serve_instance):
packing_num = 4
q = ray.remote(Router).remote(FixedPackingPolicy,
{"packing_num": packing_num})
await q.set_traffic.remote("svc", {
"backend-fixed": 0.5,
"backend-fixed-2": 0.5
})
runner_1, runner_2 = (mock_task_runner() for _ in range(2))
# both the backends will get equal number of queries
# as it is packed round robin
await q.add_new_worker.remote("backend-fixed", "replica-1", runner_1)
await q.add_new_worker.remote("backend-fixed-2", "replica-1", runner_2)
for backend, runner in zip(["1", "2"], [runner_1, runner_2]):
for _ in range(packing_num):
input_value = "should-go-to-backend-{}".format(backend)
await q.enqueue_request.remote(
RequestMetadata("svc", None), input_value)
all_calls = await runner.get_all_calls.remote()
for call in all_calls:
assert call.request_args[0] == input_value
async def test_power_of_two_choices(serve_instance):
q = ray.remote(Router).remote(PowerOfTwoPolicy, {})
enqueue_futures = []
# First, fill the queue for backend-1 with 3 requests
await q.set_traffic.remote("svc", {"backend-pow2": 1.0})
for _ in range(3):
future = q.enqueue_request.remote(RequestMetadata("svc", None), "1")
enqueue_futures.append(future)
# Then, add a new backend, this backend should be filled next
await q.set_traffic.remote("svc", {
"backend-pow2": 0.5,
"backend-pow2-2": 0.5
})
for _ in range(2):
future = q.enqueue_request.remote(RequestMetadata("svc", None), "2")
enqueue_futures.append(future)
runner_1, runner_2 = (mock_task_runner() for _ in range(2))
await q.add_new_worker.remote("backend-pow2", "replica-1", runner_1)
await q.add_new_worker.remote("backend-pow2-2", "replica-1", runner_2)
await asyncio.gather(*enqueue_futures)
assert len(await runner_1.get_all_calls.remote()) == 3
assert len(await runner_2.get_all_calls.remote()) == 2
async def test_queue_remove_replicas(serve_instance):
class TestRouter(Router):
def worker_queue_size(self, backend):
return self.worker_queues["backend-remove"].qsize()
temp_actor = mock_task_runner()
q = ray.remote(TestRouter).remote(RandomPolicy, {})
q = ray.remote(TestRouter).remote()
await q.add_new_worker.remote("backend-remove", "replica-1", temp_actor)
await q.remove_worker.remote("backend-remove", "replica-1")
assert ray.get(q.worker_queue_size.remote("backend")) == 0
async def test_shard_key(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote()
num_backends = 5
traffic_dict = {}
runners = [mock_task_runner() for _ in range(num_backends)]
for i, runner in enumerate(runners):
backend_name = "backend-split-" + str(i)
traffic_dict[backend_name] = 1.0 / num_backends
await q.add_new_worker.remote(backend_name, "replica-1", runner)
await q.set_traffic.remote("svc", traffic_dict)
# Generate random shard keys and send one request for each.
shard_keys = [get_random_letters() for _ in range(100)]
for shard_key in shard_keys:
await q.enqueue_request.remote(
RequestMetadata("svc", None, shard_key=shard_key), shard_key)
# Log the shard keys that were assigned to each backend.
runner_shard_keys = defaultdict(set)
for i, runner in enumerate(runners):
calls = await runner.get_all_calls.remote()
for call in calls:
runner_shard_keys[i].add(call.request_args[0])
await runner.clear_calls.remote()
# Send queries with the same shard keys a second time.
for shard_key in shard_keys:
await q.enqueue_request.remote(
RequestMetadata("svc", None, shard_key=shard_key), shard_key)
# Check that the requests were all mapped to the same backends.
for i, runner in enumerate(runners):
calls = await runner.get_all_calls.remote()
for call in calls:
assert call.request_args[0] in runner_shard_keys[i]