From aa3fd62cac3a0f759f7d884eae3d202a5c72103b Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Thu, 25 Jun 2020 10:55:22 -0500 Subject: [PATCH] [serve] Add shadow traffic API (#9106) --- doc/source/serve/advanced.rst | 32 ++++ doc/source/serve/package-ref.rst | 1 + python/ray/serve/__init__.py | 6 +- python/ray/serve/api.py | 25 +++ python/ray/serve/master.py | 146 ++++++++++++------ python/ray/serve/policy.py | 38 ++++- python/ray/serve/router.py | 39 +++-- python/ray/serve/tests/test_api.py | 51 +++++- python/ray/serve/tests/test_backend_worker.py | 21 ++- python/ray/serve/tests/test_router.py | 22 +-- 10 files changed, 294 insertions(+), 87 deletions(-) diff --git a/doc/source/serve/advanced.rst b/doc/source/serve/advanced.rst index 9980ff757..2c61bb945 100644 --- a/doc/source/serve/advanced.rst +++ b/doc/source/serve/advanced.rst @@ -185,6 +185,38 @@ The shard key can either be specified via the X-SERVE-SHARD-KEY HTTP header or ` handle = serve.get_handle("api_endpoint") handler.options(shard_key=session_id).remote(args) +Shadow Testing +-------------- + +Sometimes when deploying a new backend, you may want to test it out without affecting the results seen by users. +You can do this with :mod:`shadow_traffic `, which allows you to duplicate requests to multiple backends for testing while still having them served by the set of backends specified via :mod:`set_traffic `. +Metrics about these requests are recorded as usual so you can use them to validate model performance. +This is demonstrated in the example below, where we create an endpoint serviced by a single backend but shadow traffic to two other backends for testing. + +.. code-block:: python + + serve.create_backend("existing_backend", MyClass) + + # All traffic is served by the existing backend. + serve.create_endpoint("shadowed_endpoint", backend="existing_backend", route="/shadow") + + # Create two new backends that we want to test. + serve.create_backend("new_backend_1", MyNewClass) + serve.create_backend("new_backend_2", MyNewClass) + + # Shadow traffic to the two new backends. This does not influence the result + # of requests to the endpoint, but a proportion of requests are + # *additionally* sent to these backends. + + # Send 50% of all queries to the endpoint new_backend_1. + serve.shadow_traffic("shadowed_endpoint", "new_backend_1", 0.5) + # Send 10% of all queries to the endpoint new_backend_2. + serve.shadow_traffic("shadowed_endpoint", "new_backend_2", 0.1) + + # Stop shadowing traffic to the backends. + serve.shadow_traffic("shadowed_endpoint", "new_backend_1", 0) + serve.shadow_traffic("shadowed_endpoint", "new_backend_2", 0) + Composing Multiple Models ========================= Ray Serve supports composing individually scalable models into a single model diff --git a/doc/source/serve/package-ref.rst b/doc/source/serve/package-ref.rst index 1236a1081..2d4ca81a7 100644 --- a/doc/source/serve/package-ref.rst +++ b/doc/source/serve/package-ref.rst @@ -14,6 +14,7 @@ APIs for Managing Endpoints .. autofunction:: ray.serve.list_endpoints .. autofunction:: ray.serve.delete_endpoint .. autofunction:: ray.serve.set_traffic +.. autofunction:: ray.serve.shadow_traffic APIs for Managing Backends diff --git a/python/ray/serve/__init__.py b/python/ray/serve/__init__.py index ed6d9f697..91845cca2 100644 --- a/python/ray/serve/__init__.py +++ b/python/ray/serve/__init__.py @@ -1,7 +1,8 @@ from ray.serve.api import ( init, create_backend, delete_backend, create_endpoint, delete_endpoint, - set_traffic, get_handle, stat, update_backend_config, get_backend_config, - accept_batch, list_backends, list_endpoints, shutdown) # noqa: E402 + set_traffic, shadow_traffic, get_handle, stat, update_backend_config, + get_backend_config, accept_batch, list_backends, list_endpoints, + shutdown) # noqa: E402 __all__ = [ "init", @@ -10,6 +11,7 @@ __all__ = [ "create_endpoint", "delete_endpoint", "set_traffic", + "shadow_traffic", "get_handle", "stat", "update_backend_config", diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 736bfc895..1bb5ae990 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -295,6 +295,31 @@ def set_traffic(endpoint_name, traffic_policy_dictionary): traffic_policy_dictionary)) +@_ensure_connected +def shadow_traffic(endpoint_name, backend_tag, proportion): + """Shadow traffic from an endpoint to a backend. + + The specified proportion of requests will be duplicated and sent to the + backend. Responses of the duplicated traffic will be ignored. + The backend must not already be in use. + + To stop shadowing traffic to a backend, call `shadow_traffic` with + proportion equal to 0. + + Args: + endpoint_name (str): A registered service endpoint. + backend_tag (str): A registered backend. + proportion (float): The proportion of traffic from 0 to 1. + """ + + if not isinstance(proportion, (float, int)) or not 0 <= proportion <= 1: + raise TypeError("proportion must be a float from 0 to 1.") + + ray.get( + master_actor.shadow_traffic.remote(endpoint_name, backend_tag, + proportion)) + + @_ensure_connected def get_handle(endpoint_name, relative_slo_ms=None, diff --git a/python/ray/serve/master.py b/python/ray/serve/master.py index ee39f71d6..02c7a4e54 100644 --- a/python/ray/serve/master.py +++ b/python/ray/serve/master.py @@ -1,5 +1,5 @@ import asyncio -from collections import defaultdict +from collections import defaultdict, namedtuple import os import random import time @@ -29,6 +29,39 @@ CHECKPOINT_KEY = "serve-master-checkpoint" _RESOURCE_CHECK_ENABLED = True +class TrafficPolicy: + def __init__(self, traffic_dict): + self.traffic_dict = dict() + self.shadow_dict = dict() + self.set_traffic_dict(traffic_dict) + + def set_traffic_dict(self, traffic_dict): + prob = 0 + for backend, weight in traffic_dict.items(): + if weight < 0: + raise ValueError( + "Attempted to assign a weight of {} to backend '{}'. " + "Weights cannot be negative.".format(weight, backend)) + prob += weight + + # These weights will later be plugged into np.random.choice, which + # uses a tolerance of 1e-8. + if not np.isclose(prob, 1, atol=1e-8): + raise ValueError("Traffic dictionary weights must sum to 1, " + "currently they sum to {}".format(prob)) + self.traffic_dict = traffic_dict + + def set_shadow(self, backend, proportion): + if proportion == 0 and backend in self.shadow_dict: + del self.shadow_dict[backend] + else: + self.shadow_dict[backend] = proportion + + +BackendInfo = namedtuple("BackendInfo", + ["worker_class", "backend_config", "replica_config"]) + + @ray.remote class ServeMaster: """Responsible for managing the state of the serving system. @@ -63,9 +96,9 @@ class ServeMaster: # Used to read/write checkpoints. self.kv_store = RayInternalKVStore(namespace=instance_name) # path -> (endpoint, methods). - self.routes = {} - # backend -> (backend_worker, backend_config, replica_config). - self.backends = {} + self.routes = dict() + # backend -> BackendInfo. + self.backends = dict() # backend -> replica_tags. self.replicas = defaultdict(list) # replicas that should be started if recovering from a checkpoint. @@ -78,7 +111,7 @@ class ServeMaster: # endpoints that should be removed from the router if recovering from a # checkpoint. self.endpoints_to_remove = list() - # endpoint -> traffic_dict + # endpoint -> TrafficPolicy self.traffic_policies = dict() # Dictionary of backend tag to dictionaries of replica tag to worker. # TODO(edoakes): consider removing this and just using the names. @@ -258,9 +291,9 @@ class ServeMaster: await self.router.add_new_worker.remote( backend_tag, replica_tag, worker) - for backend, (_, backend_config, _) in self.backends.items(): + for backend, info in self.backends.items(): await self.router.set_backend_config.remote( - backend, backend_config) + backend, info.backend_config) await self.broadcast_backend_config(backend) # Push configuration state to the HTTP proxy. @@ -282,8 +315,8 @@ class ServeMaster: def get_backend_configs(self): """Fetched by the router on startup.""" backend_configs = {} - for backend, (_, backend_config, _) in self.backends.items(): - backend_configs[backend] = backend_config + for backend, info in self.backends.items(): + backend_configs[backend] = info.backend_config return backend_configs def get_traffic_policies(self): @@ -306,19 +339,18 @@ class ServeMaster: """ logger.debug("Starting worker '{}' for backend '{}'.".format( replica_tag, backend_tag)) - (backend_worker, backend_config, - replica_config) = self.backends[backend_tag] + backend_info = self.backends[backend_tag] replica_name = format_actor_name(replica_tag, self.instance_name) - worker_handle = ray.remote(backend_worker).options( + worker_handle = ray.remote(backend_info.worker_class).options( name=replica_name, max_restarts=-1, max_task_retries=-1, - **replica_config.ray_actor_options).remote( + **backend_info.replica_config.ray_actor_options).remote( backend_tag, replica_tag, - replica_config.actor_init_args, - backend_config, + backend_info.replica_config.actor_init_args, + backend_info.backend_config, instance_name=self.instance_name) # TODO(edoakes): we should probably have a timeout here. await worker_handle.ready.remote() @@ -427,11 +459,11 @@ class ServeMaster: current_num_replicas = len(self.replicas[backend_tag]) delta_num_replicas = num_replicas - current_num_replicas - _, _, replica_config = self.backends[backend_tag] + backend_info = self.backends[backend_tag] if delta_num_replicas > 0: can_schedule = try_schedule_resources_on_nodes( requirements=[ - replica_config.resource_dict + backend_info.replica_config.resource_dict for _ in range(delta_num_replicas) ], ray_nodes=ray.nodes()) @@ -473,18 +505,27 @@ class ServeMaster: def get_all_backends(self): """Returns a dictionary of backend tag to backend config dict.""" backends = {} - for backend_tag, (_, config, _) in self.backends.items(): - backends[backend_tag] = config.__dict__ + for backend_tag, backend_info in self.backends.items(): + backends[backend_tag] = backend_info.backend_config.__dict__ return backends def get_all_endpoints(self): """Returns a dictionary of endpoint to endpoint config.""" endpoints = {} for route, (endpoint, methods) in self.routes.items(): + if endpoint in self.traffic_policies: + traffic_policy = self.traffic_policies[endpoint] + traffic_dict = traffic_policy.traffic_dict + shadow_dict = traffic_policy.shadow_dict + else: + traffic_dict = {} + shadow_dict = {} + endpoints[endpoint] = { "route": route if route.startswith("/") else None, "methods": methods, - "traffic": self.traffic_policies.get(endpoint, {}) + "traffic": traffic_dict, + "shadows": shadow_dict, } return endpoints @@ -494,38 +535,51 @@ class ServeMaster: " that is not registered.".format(endpoint_name)) assert isinstance(traffic_dict, - dict), "Traffic policy must be dictionary" - prob = 0 - for backend, weight in traffic_dict.items(): - if weight < 0: - raise ValueError( - "Attempted to assign a weight of {} to backend '{}'. " - "Weights cannot be negative.".format(weight, backend)) - prob += weight + dict), "Traffic policy must be a dictionary." + + for backend in traffic_dict: if backend not in self.backends: raise ValueError( "Attempted to assign traffic to a backend '{}' that " "is not registered.".format(backend)) - # These weights will later be plugged into np.random.choice, which - # uses a tolerance of 1e-8. - assert np.isclose( - prob, 1, atol=1e-8 - ), "weights must sum to 1, currently they sum to {}".format(prob) - - self.traffic_policies[endpoint_name] = traffic_dict + traffic_policy = TrafficPolicy(traffic_dict) + self.traffic_policies[endpoint_name] = traffic_policy # NOTE(edoakes): we must write a checkpoint before pushing the # update to avoid inconsistent state if we crash after pushing the # update. self._checkpoint() - await self.router.set_traffic.remote(endpoint_name, traffic_dict) + await self.router.set_traffic.remote(endpoint_name, traffic_policy) async def set_traffic(self, endpoint_name, traffic_dict): """Sets the traffic policy for the specified endpoint.""" async with self.write_lock: await self._set_traffic(endpoint_name, traffic_dict) + async def shadow_traffic(self, endpoint_name, backend_tag, proportion): + """Shadow traffic from the endpoint to the backend.""" + async with self.write_lock: + if endpoint_name not in self.get_all_endpoints(): + raise ValueError("Attempted to shadow traffic from an " + "endpoint '{}' that is not registered." + .format(endpoint_name)) + + if backend_tag not in self.backends: + raise ValueError( + "Attempted to shadow traffic to a backend '{}' that " + "is not registered.".format(backend_tag)) + + self.traffic_policies[endpoint_name].set_shadow( + backend_tag, proportion) + + # NOTE(edoakes): we must write a checkpoint before pushing the + # update to avoid inconsistent state if we crash after pushing the + # update. + self._checkpoint() + await self.router.set_traffic.remote( + endpoint_name, self.traffic_policies[endpoint_name]) + async def create_endpoint(self, endpoint, traffic_dict, route, methods): """Create a new endpoint with the specified route and methods. @@ -610,8 +664,8 @@ class ServeMaster: # Save creator that starts replicas, the arguments to be passed in, # and the configuration for the backends. - self.backends[backend_tag] = (backend_worker, backend_config, - replica_config) + self.backends[backend_tag] = BackendInfo( + backend_worker, backend_config, replica_config) self._scale_replicas(backend_tag, backend_config.num_replicas) @@ -635,8 +689,9 @@ class ServeMaster: return # Check that the specified backend isn't used by any endpoints. - for endpoint, traffic_dict in self.traffic_policies.items(): - if backend_tag in traffic_dict: + for endpoint, traffic_policy in self.traffic_policies.items(): + if (backend_tag in traffic_policy.traffic_dict + or backend_tag in traffic_policy.shadow_dict): raise ValueError("Backend '{}' is used by endpoint '{}' " "and cannot be deleted. Please remove " "the backend from all endpoints and try " @@ -665,12 +720,9 @@ class ServeMaster: assert (backend_tag in self.backends ), "Backend {} is not registered.".format(backend_tag) assert isinstance(config_options, dict) - backend_worker, backend_config, replica_config = self.backends[ - backend_tag] - backend_config.update(config_options) - self.backends[backend_tag] = (backend_worker, backend_config, - replica_config) + self.backends[backend_tag].backend_config.update(config_options) + backend_config = self.backends[backend_tag].backend_config # Scale the replicas with the new configuration. self._scale_replicas(backend_tag, backend_config.num_replicas) @@ -691,7 +743,7 @@ class ServeMaster: await self.broadcast_backend_config(backend_tag) async def broadcast_backend_config(self, backend_tag): - _, backend_config, _ = self.backends[backend_tag] + backend_config = self.backends[backend_tag].backend_config broadcast_futures = [] for replica_tag in self.replicas[backend_tag]: try: @@ -708,7 +760,7 @@ class ServeMaster: """Get the current config for the specified backend.""" assert (backend_tag in self.backends ), "Backend {} is not registered.".format(backend_tag) - return self.backends[backend_tag][2] + return self.backends[backend_tag].backend_config async def shutdown(self): """Shuts down the serve instance completely.""" diff --git a/python/ray/serve/policy.py b/python/ray/serve/policy.py index aaa511af8..9666e9fd5 100644 --- a/python/ray/serve/policy.py +++ b/python/ray/serve/policy.py @@ -1,4 +1,5 @@ from abc import ABCMeta, abstractmethod +import copy from hashlib import sha256 import numpy as np @@ -47,12 +48,29 @@ class RandomEndpointPolicy(EndpointPolicy): be made deterministically based on the hash of the shard key. """ - def __init__(self, traffic_dict): - self.backend_names, self.backend_weights = zip( - *sorted(traffic_dict.items())) + def __init__(self, traffic_policy): + self.backends = sorted(traffic_policy.traffic_dict.items()) + self.shadow_backends = list(traffic_policy.shadow_dict.items()) + + def _select_backends(self, val): + curr_sum = 0 + for name, weight in self.backends: + curr_sum += weight + if curr_sum > val: + chosen_backend = name + break + else: + assert False, "This should never be reached." + + shadow_backends = [] + for backend, backend_weight in self.shadow_backends: + if val < backend_weight: + shadow_backends.append(backend) + + return chosen_backend, shadow_backends def flush(self, endpoint_queue, backend_queues): - if len(self.backend_names) == 0: + if len(self.backends) == 0: logger.info("No backends to assign traffic to.") return set() @@ -67,11 +85,17 @@ class RandomEndpointPolicy(EndpointPolicy): # Note(simon): This constructor takes 100+us, maybe cache this? rstate = np.random.RandomState(seed) - chosen_backend = rstate.choice( - self.backend_names, replace=False, - p=self.backend_weights).squeeze() + chosen_backend, shadow_backends = self._select_backends( + rstate.random()) assigned_backends.add(chosen_backend) backend_queues[chosen_backend].add(query) + if len(shadow_backends) > 0: + shadow_query = copy.copy(query) + shadow_query.async_future = None + shadow_query.is_shadow_query = True + for shadow_backend in shadow_backends: + assigned_backends.add(shadow_backend) + backend_queues[shadow_backend].add(shadow_query) return assigned_backends diff --git a/python/ray/serve/router.py b/python/ray/serve/router.py index 5ee9e37de..ab364e3ca 100644 --- a/python/ray/serve/router.py +++ b/python/ray/serve/router.py @@ -17,14 +17,17 @@ from ray.serve.utils import logger, chain_future class Query: - def __init__(self, - request_args, - request_kwargs, - request_context, - request_slo_ms, - call_method="__call__", - shard_key=None, - async_future=None): + def __init__( + self, + request_args, + request_kwargs, + request_context, + request_slo_ms, + call_method="__call__", + shard_key=None, + async_future=None, + is_shadow_query=False, + ): self.request_args = request_args self.request_kwargs = request_kwargs self.request_context = request_context @@ -37,6 +40,7 @@ class Query: self.call_method = call_method self.shard_key = shard_key + self.is_shadow_query = is_shadow_query def ray_serialize(self): # NOTE: this method is needed because Query need to be serialized and @@ -241,11 +245,11 @@ class Router: # result. pass - async def set_traffic(self, endpoint, traffic_dict): + async def set_traffic(self, endpoint, traffic_policy): logger.debug("Setting traffic for endpoint %s to %s", endpoint, - traffic_dict) + traffic_policy) async with self.flush_lock: - self.traffic[endpoint] = RandomEndpointPolicy(traffic_dict) + self.traffic[endpoint] = RandomEndpointPolicy(traffic_policy) self.flush_endpoint_queue(endpoint) async def remove_endpoint(self, endpoint): @@ -314,7 +318,13 @@ class Router: start = time.time() worker = self.replicas[backend_replica_tag] try: - result = await worker.handle_request.remote(req) + if req.is_shadow_query: + # No need to actually get the result, but we do need to wait + # until the call completes to mark the worker idle. + asyncio.wait([worker.handle_request.remote(req)]) + result = "" + else: + result = await worker.handle_request.remote(req) except RayTaskError as error: self.num_error_backend_request.labels(backend=backend).add() result = error @@ -356,6 +366,9 @@ class Router: self.queries_counter[backend_replica_tag] += 1 future = asyncio.get_event_loop().create_task( self._do_query(backend, backend_replica_tag, request)) - chain_future(future, request.async_future) + + # For shadow queries, just ignore the result. + if not request.is_shadow_query: + chain_future(future, request.async_future) worker_queue.appendleft(backend_replica_tag) diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index 527165cfc..dad1fd7da 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -444,9 +444,11 @@ def test_list_endpoints(serve_instance): serve.create_backend("backend", f) serve.create_backend("backend2", f) + serve.create_backend("backend3", f) serve.create_endpoint( "endpoint", backend="backend", route="/api", methods=["GET", "POST"]) serve.create_endpoint("endpoint2", backend="backend2", methods=["POST"]) + serve.shadow_traffic("endpoint", "backend3", 0.5) endpoints = serve.list_endpoints() assert "endpoint" in endpoints @@ -455,6 +457,9 @@ def test_list_endpoints(serve_instance): "methods": ["GET", "POST"], "traffic": { "backend": 1.0 + }, + "shadows": { + "backend3": 0.5 } } @@ -464,7 +469,8 @@ def test_list_endpoints(serve_instance): "methods": ["POST"], "traffic": { "backend2": 1.0 - } + }, + "shadows": {} } serve.delete_endpoint("endpoint") @@ -576,6 +582,49 @@ def test_shutdown(serve_instance): assert wait_for_condition(check_dead) +def test_shadow_traffic(serve_instance): + def f(): + return "hello" + + def f_shadow(): + return "oops" + + serve.create_backend("backend1", f) + serve.create_backend("backend2", f_shadow) + serve.create_backend("backend3", f_shadow) + serve.create_backend("backend4", f_shadow) + + serve.create_endpoint("endpoint", backend="backend1", route="/api") + serve.shadow_traffic("endpoint", "backend2", 1.0) + serve.shadow_traffic("endpoint", "backend3", 0.5) + serve.shadow_traffic("endpoint", "backend4", 0.1) + + start = time.time() + num_requests = 100 + for _ in range(num_requests): + assert requests.get("http://127.0.0.1:8000/api").text == "hello" + print("Finished 100 requests in {}s.".format(time.time() - start)) + + def requests_to_backend(backend): + for entry in serve.stat(): + if entry["info"]["name"] == "backend_request_counter": + if entry["info"]["backend"] == backend: + return entry["value"] + + return 0 + + def check_requests(): + return all([ + requests_to_backend("backend1") == num_requests, + requests_to_backend("backend2") == requests_to_backend("backend1"), + requests_to_backend("backend3") < requests_to_backend("backend2"), + requests_to_backend("backend4") < requests_to_backend("backend3"), + requests_to_backend("backend4") > 0, + ]) + + assert wait_for_condition(check_requests) + + if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_backend_worker.py b/python/ray/serve/tests/test_backend_worker.py index 4cda407a5..5beff28ec 100644 --- a/python/ray/serve/tests/test_backend_worker.py +++ b/python/ray/serve/tests/test_backend_worker.py @@ -7,6 +7,7 @@ import ray from ray import serve import ray.serve.context as context from ray.serve.backend_worker import create_backend_worker, wrap_to_ray_error +from ray.serve.master import TrafficPolicy from ray.serve.request_params import RequestMetadata from ray.serve.router import Router from ray.serve.config import BackendConfig @@ -59,7 +60,7 @@ async def test_runner_actor(serve_instance): worker = setup_worker(CONSUMER_NAME, echo) await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) - q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0}) + q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0})) for query in [333, 444, 555]: query_param = RequestMetadata(PRODUCER_NAME, @@ -84,7 +85,7 @@ async def test_ray_serve_mixin(serve_instance): worker = setup_worker(CONSUMER_NAME, MyAdder, init_args=(3, )) await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) - q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0}) + q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0})) for query in [333, 444, 555]: query_param = RequestMetadata(PRODUCER_NAME, @@ -106,7 +107,7 @@ async def test_task_runner_check_context(serve_instance): worker = setup_worker(CONSUMER_NAME, echo) await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) - q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0}) + q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0})) query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python) result_oid = q.enqueue_request.remote(query_param, i=42) @@ -130,7 +131,7 @@ async def test_task_runner_custom_method_single(serve_instance): worker = setup_worker(CONSUMER_NAME, NonBatcher) await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) - q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0}) + q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0})) query_param = RequestMetadata( PRODUCER_NAME, context.TaskContext.Python, call_method="a") @@ -179,7 +180,10 @@ async def test_task_runner_custom_method_batch(serve_instance): worker = setup_worker( CONSUMER_NAME, Batcher, backend_config=backend_config) - await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0}) + await q.set_traffic.remote(PRODUCER_NAME, + TrafficPolicy({ + CONSUMER_NAME: 1.0 + })) await q.set_backend_config.remote(CONSUMER_NAME, backend_config) def make_request_param(call_method): @@ -228,7 +232,10 @@ async def test_task_runner_perform_batch(serve_instance): worker = setup_worker(CONSUMER_NAME, batcher, backend_config=config) await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) await q.set_backend_config.remote(CONSUMER_NAME, config) - await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0}) + await q.set_traffic.remote(PRODUCER_NAME, + TrafficPolicy({ + CONSUMER_NAME: 1.0 + })) query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python) @@ -268,7 +275,7 @@ async def test_task_runner_perform_async(serve_instance): worker = setup_worker(CONSUMER_NAME, wait_and_go, backend_config=config) await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker) await q.set_backend_config.remote(CONSUMER_NAME, config) - q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0}) + q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0})) query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python) diff --git a/python/ray/serve/tests/test_router.py b/python/ray/serve/tests/test_router.py index 642a2fb0b..54b29c60a 100644 --- a/python/ray/serve/tests/test_router.py +++ b/python/ray/serve/tests/test_router.py @@ -4,6 +4,7 @@ from collections import defaultdict import pytest import ray +from ray.serve.master import TrafficPolicy from ray.serve.router import Router from ray.serve.request_params import RequestMetadata from ray.serve.utils import get_random_letters @@ -47,7 +48,7 @@ def task_runner_mock_actor(): async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor): q = ray.remote(Router).remote() - q.set_traffic.remote("svc", {"backend-single-prod": 1.0}) + q.set_traffic.remote("svc", TrafficPolicy({"backend-single-prod": 1.0})) q.add_new_worker.remote("backend-single-prod", "replica-1", task_runner_mock_actor) @@ -63,7 +64,7 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor): async def test_slo(serve_instance, task_runner_mock_actor): q = ray.remote(Router).remote() - await q.set_traffic.remote("svc", {"backend-slo": 1.0}) + await q.set_traffic.remote("svc", TrafficPolicy({"backend-slo": 1.0})) all_request_sent = [] for i in range(10): @@ -88,14 +89,14 @@ async def test_slo(serve_instance, task_runner_mock_actor): async def test_alter_backend(serve_instance, task_runner_mock_actor): q = ray.remote(Router).remote() - await q.set_traffic.remote("svc", {"backend-alter": 1}) + await q.set_traffic.remote("svc", TrafficPolicy({"backend-alter": 1})) await q.add_new_worker.remote("backend-alter", "replica-1", task_runner_mock_actor) await q.enqueue_request.remote(RequestMetadata("svc", None), 1) got_work = await task_runner_mock_actor.get_recent_call.remote() assert got_work.request_args[0] == 1 - await q.set_traffic.remote("svc", {"backend-alter-2": 1}) + await q.set_traffic.remote("svc", TrafficPolicy({"backend-alter-2": 1})) await q.add_new_worker.remote("backend-alter-2", "replica-1", task_runner_mock_actor) await q.enqueue_request.remote(RequestMetadata("svc", None), 2) @@ -106,10 +107,11 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor): async def test_split_traffic_random(serve_instance, task_runner_mock_actor): q = ray.remote(Router).remote() - await q.set_traffic.remote("svc", { - "backend-split": 0.5, - "backend-split-2": 0.5 - }) + await q.set_traffic.remote( + "svc", TrafficPolicy({ + "backend-split": 0.5, + "backend-split-2": 0.5 + })) runner_1, runner_2 = [mock_task_runner() for _ in range(2)] await q.add_new_worker.remote("backend-split", "replica-1", runner_1) await q.add_new_worker.remote("backend-split-2", "replica-1", runner_2) @@ -148,7 +150,7 @@ async def test_shard_key(serve_instance, task_runner_mock_actor): backend_name = "backend-split-" + str(i) traffic_dict[backend_name] = 1.0 / num_backends await q.add_new_worker.remote(backend_name, "replica-1", runner) - await q.set_traffic.remote("svc", traffic_dict) + await q.set_traffic.remote("svc", TrafficPolicy(traffic_dict)) # Generate random shard keys and send one request for each. shard_keys = [get_random_letters() for _ in range(100)] @@ -196,7 +198,7 @@ async def test_router_use_max_concurrency(serve_instance): q = ray.remote(VisibleRouter).remote() BACKEND_NAME = "max-concurrent-test" config = BackendConfig({"max_concurrent_queries": 1}) - await q.set_traffic.remote("svc", {BACKEND_NAME: 1.0}) + await q.set_traffic.remote("svc", TrafficPolicy({BACKEND_NAME: 1.0})) await q.add_new_worker.remote(BACKEND_NAME, "replica-tag", worker) await q.set_backend_config.remote(BACKEND_NAME, config)