[serve] Add shadow traffic API (#9106)

This commit is contained in:
Edward Oakes
2020-06-25 10:55:22 -05:00
committed by GitHub
parent 536795ef79
commit aa3fd62cac
10 changed files with 294 additions and 87 deletions
+4 -2
View File
@@ -1,7 +1,8 @@
from ray.serve.api import (
init, create_backend, delete_backend, create_endpoint, delete_endpoint,
set_traffic, get_handle, stat, update_backend_config, get_backend_config,
accept_batch, list_backends, list_endpoints, shutdown) # noqa: E402
set_traffic, shadow_traffic, get_handle, stat, update_backend_config,
get_backend_config, accept_batch, list_backends, list_endpoints,
shutdown) # noqa: E402
__all__ = [
"init",
@@ -10,6 +11,7 @@ __all__ = [
"create_endpoint",
"delete_endpoint",
"set_traffic",
"shadow_traffic",
"get_handle",
"stat",
"update_backend_config",
+25
View File
@@ -295,6 +295,31 @@ def set_traffic(endpoint_name, traffic_policy_dictionary):
traffic_policy_dictionary))
@_ensure_connected
def shadow_traffic(endpoint_name, backend_tag, proportion):
    """Duplicate a share of an endpoint's requests to a shadow backend.

    The given proportion of the endpoint's requests is copied and sent to
    the backend; responses to the copies are ignored. The backend must not
    already be in use.

    Passing a proportion of 0 stops shadowing traffic to the backend.

    Args:
        endpoint_name (str): A registered service endpoint.
        backend_tag (str): A registered backend.
        proportion (float): The proportion of traffic from 0 to 1.

    Raises:
        TypeError: If proportion is not an int/float within [0, 1].
    """
    is_number = isinstance(proportion, (float, int))
    # The range comparison only runs once the type check has passed.
    if not (is_number and 0 <= proportion <= 1):
        raise TypeError("proportion must be a float from 0 to 1.")

    # Block until the master has applied the update.
    pending_update = master_actor.shadow_traffic.remote(
        endpoint_name, backend_tag, proportion)
    ray.get(pending_update)
@_ensure_connected
def get_handle(endpoint_name,
relative_slo_ms=None,
+99 -47
View File
@@ -1,5 +1,5 @@
import asyncio
from collections import defaultdict
from collections import defaultdict, namedtuple
import os
import random
import time
@@ -29,6 +29,39 @@ CHECKPOINT_KEY = "serve-master-checkpoint"
_RESOURCE_CHECK_ENABLED = True
class TrafficPolicy:
    """Traffic assignment for a single endpoint.

    Holds the weighted split of real traffic across backends and the
    proportions of traffic that are duplicated to shadow backends.
    """

    def __init__(self, traffic_dict):
        # backend -> weight; weights are non-negative and sum to 1.
        self.traffic_dict = dict()
        # backend -> proportion of requests duplicated to that backend.
        self.shadow_dict = dict()
        self.set_traffic_dict(traffic_dict)

    def __repr__(self):
        return "{}(traffic_dict={!r}, shadow_dict={!r})".format(
            type(self).__name__, self.traffic_dict, self.shadow_dict)

    def set_traffic_dict(self, traffic_dict):
        """Validate and store the backend -> weight mapping.

        Args:
            traffic_dict (dict): Maps backend tag to its traffic weight.

        Raises:
            ValueError: If any weight is negative or the weights do not
                sum to 1 (within a tolerance of 1e-8).
        """
        prob = 0
        for backend, weight in traffic_dict.items():
            if weight < 0:
                raise ValueError(
                    "Attempted to assign a weight of {} to backend '{}'. "
                    "Weights cannot be negative.".format(weight, backend))
            prob += weight

        # These weights will later be plugged into np.random.choice, which
        # uses a tolerance of 1e-8.
        if not np.isclose(prob, 1, atol=1e-8):
            raise ValueError("Traffic dictionary weights must sum to 1, "
                             "currently they sum to {}".format(prob))

        self.traffic_dict = traffic_dict

    def set_shadow(self, backend, proportion):
        """Set, update, or remove the shadow proportion for a backend.

        Args:
            backend (str): The backend tag to shadow traffic to.
            proportion (float): Proportion of requests to duplicate. A
                value of 0 removes the backend from the shadow set.
        """
        if proportion == 0:
            # A proportion of 0 always means "not shadowed". Never store a
            # zero entry, otherwise the backend would still look in use.
            self.shadow_dict.pop(backend, None)
        else:
            self.shadow_dict[backend] = proportion
# Per-backend record stored in ServeMaster.backends (backend -> BackendInfo):
# the worker class used to start replicas, the backend's config, and the
# replica config.
BackendInfo = namedtuple("BackendInfo",
["worker_class", "backend_config", "replica_config"])
@ray.remote
class ServeMaster:
"""Responsible for managing the state of the serving system.
@@ -63,9 +96,9 @@ class ServeMaster:
# Used to read/write checkpoints.
self.kv_store = RayInternalKVStore(namespace=instance_name)
# path -> (endpoint, methods).
self.routes = {}
# backend -> (backend_worker, backend_config, replica_config).
self.backends = {}
self.routes = dict()
# backend -> BackendInfo.
self.backends = dict()
# backend -> replica_tags.
self.replicas = defaultdict(list)
# replicas that should be started if recovering from a checkpoint.
@@ -78,7 +111,7 @@ class ServeMaster:
# endpoints that should be removed from the router if recovering from a
# checkpoint.
self.endpoints_to_remove = list()
# endpoint -> traffic_dict
# endpoint -> TrafficPolicy
self.traffic_policies = dict()
# Dictionary of backend tag to dictionaries of replica tag to worker.
# TODO(edoakes): consider removing this and just using the names.
@@ -258,9 +291,9 @@ class ServeMaster:
await self.router.add_new_worker.remote(
backend_tag, replica_tag, worker)
for backend, (_, backend_config, _) in self.backends.items():
for backend, info in self.backends.items():
await self.router.set_backend_config.remote(
backend, backend_config)
backend, info.backend_config)
await self.broadcast_backend_config(backend)
# Push configuration state to the HTTP proxy.
@@ -282,8 +315,8 @@ class ServeMaster:
def get_backend_configs(self):
"""Fetched by the router on startup."""
backend_configs = {}
for backend, (_, backend_config, _) in self.backends.items():
backend_configs[backend] = backend_config
for backend, info in self.backends.items():
backend_configs[backend] = info.backend_config
return backend_configs
def get_traffic_policies(self):
@@ -306,19 +339,18 @@ class ServeMaster:
"""
logger.debug("Starting worker '{}' for backend '{}'.".format(
replica_tag, backend_tag))
(backend_worker, backend_config,
replica_config) = self.backends[backend_tag]
backend_info = self.backends[backend_tag]
replica_name = format_actor_name(replica_tag, self.instance_name)
worker_handle = ray.remote(backend_worker).options(
worker_handle = ray.remote(backend_info.worker_class).options(
name=replica_name,
max_restarts=-1,
max_task_retries=-1,
**replica_config.ray_actor_options).remote(
**backend_info.replica_config.ray_actor_options).remote(
backend_tag,
replica_tag,
replica_config.actor_init_args,
backend_config,
backend_info.replica_config.actor_init_args,
backend_info.backend_config,
instance_name=self.instance_name)
# TODO(edoakes): we should probably have a timeout here.
await worker_handle.ready.remote()
@@ -427,11 +459,11 @@ class ServeMaster:
current_num_replicas = len(self.replicas[backend_tag])
delta_num_replicas = num_replicas - current_num_replicas
_, _, replica_config = self.backends[backend_tag]
backend_info = self.backends[backend_tag]
if delta_num_replicas > 0:
can_schedule = try_schedule_resources_on_nodes(
requirements=[
replica_config.resource_dict
backend_info.replica_config.resource_dict
for _ in range(delta_num_replicas)
],
ray_nodes=ray.nodes())
@@ -473,18 +505,27 @@ class ServeMaster:
def get_all_backends(self):
"""Returns a dictionary of backend tag to backend config dict."""
backends = {}
for backend_tag, (_, config, _) in self.backends.items():
backends[backend_tag] = config.__dict__
for backend_tag, backend_info in self.backends.items():
backends[backend_tag] = backend_info.backend_config.__dict__
return backends
def get_all_endpoints(self):
"""Returns a dictionary of endpoint to endpoint config."""
endpoints = {}
for route, (endpoint, methods) in self.routes.items():
if endpoint in self.traffic_policies:
traffic_policy = self.traffic_policies[endpoint]
traffic_dict = traffic_policy.traffic_dict
shadow_dict = traffic_policy.shadow_dict
else:
traffic_dict = {}
shadow_dict = {}
endpoints[endpoint] = {
"route": route if route.startswith("/") else None,
"methods": methods,
"traffic": self.traffic_policies.get(endpoint, {})
"traffic": traffic_dict,
"shadows": shadow_dict,
}
return endpoints
@@ -494,38 +535,51 @@ class ServeMaster:
" that is not registered.".format(endpoint_name))
assert isinstance(traffic_dict,
dict), "Traffic policy must be dictionary"
prob = 0
for backend, weight in traffic_dict.items():
if weight < 0:
raise ValueError(
"Attempted to assign a weight of {} to backend '{}'. "
"Weights cannot be negative.".format(weight, backend))
prob += weight
dict), "Traffic policy must be a dictionary."
for backend in traffic_dict:
if backend not in self.backends:
raise ValueError(
"Attempted to assign traffic to a backend '{}' that "
"is not registered.".format(backend))
# These weights will later be plugged into np.random.choice, which
# uses a tolerance of 1e-8.
assert np.isclose(
prob, 1, atol=1e-8
), "weights must sum to 1, currently they sum to {}".format(prob)
self.traffic_policies[endpoint_name] = traffic_dict
traffic_policy = TrafficPolicy(traffic_dict)
self.traffic_policies[endpoint_name] = traffic_policy
# NOTE(edoakes): we must write a checkpoint before pushing the
# update to avoid inconsistent state if we crash after pushing the
# update.
self._checkpoint()
await self.router.set_traffic.remote(endpoint_name, traffic_dict)
await self.router.set_traffic.remote(endpoint_name, traffic_policy)
async def set_traffic(self, endpoint_name, traffic_dict):
"""Sets the traffic policy for the specified endpoint."""
async with self.write_lock:
await self._set_traffic(endpoint_name, traffic_dict)
async def shadow_traffic(self, endpoint_name, backend_tag, proportion):
    """Shadow a proportion of an endpoint's traffic to a backend.

    Validates that both the endpoint and the backend are registered,
    records the shadow proportion on the endpoint's traffic policy,
    checkpoints, and then pushes the updated policy to the router.

    Raises:
        ValueError: If the endpoint or the backend is not registered.
    """
    async with self.write_lock:
        endpoint_known = endpoint_name in self.get_all_endpoints()
        if not endpoint_known:
            raise ValueError("Attempted to shadow traffic from an "
                             "endpoint '{}' that is not registered."
                             .format(endpoint_name))

        backend_known = backend_tag in self.backends
        if not backend_known:
            raise ValueError(
                "Attempted to shadow traffic to a backend '{}' that "
                "is not registered.".format(backend_tag))

        policy = self.traffic_policies[endpoint_name]
        policy.set_shadow(backend_tag, proportion)

        # NOTE(edoakes): the checkpoint has to be written before the update
        # is pushed; otherwise a crash right after the push would leave
        # the router and the checkpoint in inconsistent states.
        self._checkpoint()
        await self.router.set_traffic.remote(endpoint_name, policy)
async def create_endpoint(self, endpoint, traffic_dict, route, methods):
"""Create a new endpoint with the specified route and methods.
@@ -610,8 +664,8 @@ class ServeMaster:
# Save creator that starts replicas, the arguments to be passed in,
# and the configuration for the backends.
self.backends[backend_tag] = (backend_worker, backend_config,
replica_config)
self.backends[backend_tag] = BackendInfo(
backend_worker, backend_config, replica_config)
self._scale_replicas(backend_tag, backend_config.num_replicas)
@@ -635,8 +689,9 @@ class ServeMaster:
return
# Check that the specified backend isn't used by any endpoints.
for endpoint, traffic_dict in self.traffic_policies.items():
if backend_tag in traffic_dict:
for endpoint, traffic_policy in self.traffic_policies.items():
if (backend_tag in traffic_policy.traffic_dict
or backend_tag in traffic_policy.shadow_dict):
raise ValueError("Backend '{}' is used by endpoint '{}' "
"and cannot be deleted. Please remove "
"the backend from all endpoints and try "
@@ -665,12 +720,9 @@ class ServeMaster:
assert (backend_tag in self.backends
), "Backend {} is not registered.".format(backend_tag)
assert isinstance(config_options, dict)
backend_worker, backend_config, replica_config = self.backends[
backend_tag]
backend_config.update(config_options)
self.backends[backend_tag] = (backend_worker, backend_config,
replica_config)
self.backends[backend_tag].backend_config.update(config_options)
backend_config = self.backends[backend_tag].backend_config
# Scale the replicas with the new configuration.
self._scale_replicas(backend_tag, backend_config.num_replicas)
@@ -691,7 +743,7 @@ class ServeMaster:
await self.broadcast_backend_config(backend_tag)
async def broadcast_backend_config(self, backend_tag):
_, backend_config, _ = self.backends[backend_tag]
backend_config = self.backends[backend_tag].backend_config
broadcast_futures = []
for replica_tag in self.replicas[backend_tag]:
try:
@@ -708,7 +760,7 @@ class ServeMaster:
"""Get the current config for the specified backend."""
assert (backend_tag in self.backends
), "Backend {} is not registered.".format(backend_tag)
return self.backends[backend_tag][2]
return self.backends[backend_tag].backend_config
async def shutdown(self):
"""Shuts down the serve instance completely."""
+31 -7
View File
@@ -1,4 +1,5 @@
from abc import ABCMeta, abstractmethod
import copy
from hashlib import sha256
import numpy as np
@@ -47,12 +48,29 @@ class RandomEndpointPolicy(EndpointPolicy):
be made deterministically based on the hash of the shard key.
"""
def __init__(self, traffic_dict):
self.backend_names, self.backend_weights = zip(
*sorted(traffic_dict.items()))
def __init__(self, traffic_policy):
self.backends = sorted(traffic_policy.traffic_dict.items())
self.shadow_backends = list(traffic_policy.shadow_dict.items())
def _select_backends(self, val):
curr_sum = 0
for name, weight in self.backends:
curr_sum += weight
if curr_sum > val:
chosen_backend = name
break
else:
assert False, "This should never be reached."
shadow_backends = []
for backend, backend_weight in self.shadow_backends:
if val < backend_weight:
shadow_backends.append(backend)
return chosen_backend, shadow_backends
def flush(self, endpoint_queue, backend_queues):
if len(self.backend_names) == 0:
if len(self.backends) == 0:
logger.info("No backends to assign traffic to.")
return set()
@@ -67,11 +85,17 @@ class RandomEndpointPolicy(EndpointPolicy):
# Note(simon): This constructor takes 100+us, maybe cache this?
rstate = np.random.RandomState(seed)
chosen_backend = rstate.choice(
self.backend_names, replace=False,
p=self.backend_weights).squeeze()
chosen_backend, shadow_backends = self._select_backends(
rstate.random())
assigned_backends.add(chosen_backend)
backend_queues[chosen_backend].add(query)
if len(shadow_backends) > 0:
shadow_query = copy.copy(query)
shadow_query.async_future = None
shadow_query.is_shadow_query = True
for shadow_backend in shadow_backends:
assigned_backends.add(shadow_backend)
backend_queues[shadow_backend].add(shadow_query)
return assigned_backends
+26 -13
View File
@@ -17,14 +17,17 @@ from ray.serve.utils import logger, chain_future
class Query:
def __init__(self,
request_args,
request_kwargs,
request_context,
request_slo_ms,
call_method="__call__",
shard_key=None,
async_future=None):
def __init__(
self,
request_args,
request_kwargs,
request_context,
request_slo_ms,
call_method="__call__",
shard_key=None,
async_future=None,
is_shadow_query=False,
):
self.request_args = request_args
self.request_kwargs = request_kwargs
self.request_context = request_context
@@ -37,6 +40,7 @@ class Query:
self.call_method = call_method
self.shard_key = shard_key
self.is_shadow_query = is_shadow_query
def ray_serialize(self):
# NOTE: this method is needed because Query need to be serialized and
@@ -241,11 +245,11 @@ class Router:
# result.
pass
async def set_traffic(self, endpoint, traffic_dict):
async def set_traffic(self, endpoint, traffic_policy):
logger.debug("Setting traffic for endpoint %s to %s", endpoint,
traffic_dict)
traffic_policy)
async with self.flush_lock:
self.traffic[endpoint] = RandomEndpointPolicy(traffic_dict)
self.traffic[endpoint] = RandomEndpointPolicy(traffic_policy)
self.flush_endpoint_queue(endpoint)
async def remove_endpoint(self, endpoint):
@@ -314,7 +318,13 @@ class Router:
start = time.time()
worker = self.replicas[backend_replica_tag]
try:
result = await worker.handle_request.remote(req)
if req.is_shadow_query:
# No need to actually get the result, but we do need to wait
# until the call completes to mark the worker idle.
asyncio.wait([worker.handle_request.remote(req)])
result = ""
else:
result = await worker.handle_request.remote(req)
except RayTaskError as error:
self.num_error_backend_request.labels(backend=backend).add()
result = error
@@ -356,6 +366,9 @@ class Router:
self.queries_counter[backend_replica_tag] += 1
future = asyncio.get_event_loop().create_task(
self._do_query(backend, backend_replica_tag, request))
chain_future(future, request.async_future)
# For shadow queries, just ignore the result.
if not request.is_shadow_query:
chain_future(future, request.async_future)
worker_queue.appendleft(backend_replica_tag)
+50 -1
View File
@@ -444,9 +444,11 @@ def test_list_endpoints(serve_instance):
serve.create_backend("backend", f)
serve.create_backend("backend2", f)
serve.create_backend("backend3", f)
serve.create_endpoint(
"endpoint", backend="backend", route="/api", methods=["GET", "POST"])
serve.create_endpoint("endpoint2", backend="backend2", methods=["POST"])
serve.shadow_traffic("endpoint", "backend3", 0.5)
endpoints = serve.list_endpoints()
assert "endpoint" in endpoints
@@ -455,6 +457,9 @@ def test_list_endpoints(serve_instance):
"methods": ["GET", "POST"],
"traffic": {
"backend": 1.0
},
"shadows": {
"backend3": 0.5
}
}
@@ -464,7 +469,8 @@ def test_list_endpoints(serve_instance):
"methods": ["POST"],
"traffic": {
"backend2": 1.0
}
},
"shadows": {}
}
serve.delete_endpoint("endpoint")
@@ -576,6 +582,49 @@ def test_shutdown(serve_instance):
assert wait_for_condition(check_dead)
def test_shadow_traffic(serve_instance):
    """Shadow backends receive their share of traffic; responses are unaffected.

    One real backend serves the endpoint while three shadow backends are
    attached with proportions 1.0, 0.5 and 0.1. Every HTTP response must
    come from the real backend, and the per-backend request counters must
    reflect the configured proportions.
    """

    def f():
        return "hello"

    def f_shadow():
        return "oops"

    serve.create_backend("backend1", f)
    for shadow_name in ("backend2", "backend3", "backend4"):
        serve.create_backend(shadow_name, f_shadow)

    serve.create_endpoint("endpoint", backend="backend1", route="/api")

    shadow_proportions = [("backend2", 1.0), ("backend3", 0.5),
                          ("backend4", 0.1)]
    for shadow_name, share in shadow_proportions:
        serve.shadow_traffic("endpoint", shadow_name, share)

    start = time.time()
    num_requests = 100
    for _ in range(num_requests):
        response = requests.get("http://127.0.0.1:8000/api")
        assert response.text == "hello"
    print("Finished 100 requests in {}s.".format(time.time() - start))

    def requests_to_backend(backend):
        # Look up the request counter recorded for the given backend.
        for entry in serve.stat():
            info = entry["info"]
            if (info["name"] == "backend_request_counter"
                    and info["backend"] == backend):
                return entry["value"]
        return 0

    def check_requests():
        # Counters are re-read for every comparison, as they may still be
        # changing while the condition is being polled.
        conditions = [
            requests_to_backend("backend1") == num_requests,
            requests_to_backend("backend2") == requests_to_backend("backend1"),
            requests_to_backend("backend3") < requests_to_backend("backend2"),
            requests_to_backend("backend4") < requests_to_backend("backend3"),
            requests_to_backend("backend4") > 0,
        ]
        return all(conditions)

    assert wait_for_condition(check_requests)
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))
+14 -7
View File
@@ -7,6 +7,7 @@ import ray
from ray import serve
import ray.serve.context as context
from ray.serve.backend_worker import create_backend_worker, wrap_to_ray_error
from ray.serve.master import TrafficPolicy
from ray.serve.request_params import RequestMetadata
from ray.serve.router import Router
from ray.serve.config import BackendConfig
@@ -59,7 +60,7 @@ async def test_runner_actor(serve_instance):
worker = setup_worker(CONSUMER_NAME, echo)
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0}))
for query in [333, 444, 555]:
query_param = RequestMetadata(PRODUCER_NAME,
@@ -84,7 +85,7 @@ async def test_ray_serve_mixin(serve_instance):
worker = setup_worker(CONSUMER_NAME, MyAdder, init_args=(3, ))
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0}))
for query in [333, 444, 555]:
query_param = RequestMetadata(PRODUCER_NAME,
@@ -106,7 +107,7 @@ async def test_task_runner_check_context(serve_instance):
worker = setup_worker(CONSUMER_NAME, echo)
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0}))
query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
result_oid = q.enqueue_request.remote(query_param, i=42)
@@ -130,7 +131,7 @@ async def test_task_runner_custom_method_single(serve_instance):
worker = setup_worker(CONSUMER_NAME, NonBatcher)
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0}))
query_param = RequestMetadata(
PRODUCER_NAME, context.TaskContext.Python, call_method="a")
@@ -179,7 +180,10 @@ async def test_task_runner_custom_method_batch(serve_instance):
worker = setup_worker(
CONSUMER_NAME, Batcher, backend_config=backend_config)
await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
await q.set_traffic.remote(PRODUCER_NAME,
TrafficPolicy({
CONSUMER_NAME: 1.0
}))
await q.set_backend_config.remote(CONSUMER_NAME, backend_config)
def make_request_param(call_method):
@@ -228,7 +232,10 @@ async def test_task_runner_perform_batch(serve_instance):
worker = setup_worker(CONSUMER_NAME, batcher, backend_config=config)
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
await q.set_backend_config.remote(CONSUMER_NAME, config)
await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
await q.set_traffic.remote(PRODUCER_NAME,
TrafficPolicy({
CONSUMER_NAME: 1.0
}))
query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
@@ -268,7 +275,7 @@ async def test_task_runner_perform_async(serve_instance):
worker = setup_worker(CONSUMER_NAME, wait_and_go, backend_config=config)
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
await q.set_backend_config.remote(CONSUMER_NAME, config)
q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0}))
query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
+12 -10
View File
@@ -4,6 +4,7 @@ from collections import defaultdict
import pytest
import ray
from ray.serve.master import TrafficPolicy
from ray.serve.router import Router
from ray.serve.request_params import RequestMetadata
from ray.serve.utils import get_random_letters
@@ -47,7 +48,7 @@ def task_runner_mock_actor():
async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote()
q.set_traffic.remote("svc", {"backend-single-prod": 1.0})
q.set_traffic.remote("svc", TrafficPolicy({"backend-single-prod": 1.0}))
q.add_new_worker.remote("backend-single-prod", "replica-1",
task_runner_mock_actor)
@@ -63,7 +64,7 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
async def test_slo(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote()
await q.set_traffic.remote("svc", {"backend-slo": 1.0})
await q.set_traffic.remote("svc", TrafficPolicy({"backend-slo": 1.0}))
all_request_sent = []
for i in range(10):
@@ -88,14 +89,14 @@ async def test_slo(serve_instance, task_runner_mock_actor):
async def test_alter_backend(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote()
await q.set_traffic.remote("svc", {"backend-alter": 1})
await q.set_traffic.remote("svc", TrafficPolicy({"backend-alter": 1}))
await q.add_new_worker.remote("backend-alter", "replica-1",
task_runner_mock_actor)
await q.enqueue_request.remote(RequestMetadata("svc", None), 1)
got_work = await task_runner_mock_actor.get_recent_call.remote()
assert got_work.request_args[0] == 1
await q.set_traffic.remote("svc", {"backend-alter-2": 1})
await q.set_traffic.remote("svc", TrafficPolicy({"backend-alter-2": 1}))
await q.add_new_worker.remote("backend-alter-2", "replica-1",
task_runner_mock_actor)
await q.enqueue_request.remote(RequestMetadata("svc", None), 2)
@@ -106,10 +107,11 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor):
async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote()
await q.set_traffic.remote("svc", {
"backend-split": 0.5,
"backend-split-2": 0.5
})
await q.set_traffic.remote(
"svc", TrafficPolicy({
"backend-split": 0.5,
"backend-split-2": 0.5
}))
runner_1, runner_2 = [mock_task_runner() for _ in range(2)]
await q.add_new_worker.remote("backend-split", "replica-1", runner_1)
await q.add_new_worker.remote("backend-split-2", "replica-1", runner_2)
@@ -148,7 +150,7 @@ async def test_shard_key(serve_instance, task_runner_mock_actor):
backend_name = "backend-split-" + str(i)
traffic_dict[backend_name] = 1.0 / num_backends
await q.add_new_worker.remote(backend_name, "replica-1", runner)
await q.set_traffic.remote("svc", traffic_dict)
await q.set_traffic.remote("svc", TrafficPolicy(traffic_dict))
# Generate random shard keys and send one request for each.
shard_keys = [get_random_letters() for _ in range(100)]
@@ -196,7 +198,7 @@ async def test_router_use_max_concurrency(serve_instance):
q = ray.remote(VisibleRouter).remote()
BACKEND_NAME = "max-concurrent-test"
config = BackendConfig({"max_concurrent_queries": 1})
await q.set_traffic.remote("svc", {BACKEND_NAME: 1.0})
await q.set_traffic.remote("svc", TrafficPolicy({BACKEND_NAME: 1.0}))
await q.add_new_worker.remote(BACKEND_NAME, "replica-tag", worker)
await q.set_backend_config.remote(BACKEND_NAME, config)