mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:23:17 +08:00
[serve] Add shadow traffic API (#9106)
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
from ray.serve.api import (
|
||||
init, create_backend, delete_backend, create_endpoint, delete_endpoint,
|
||||
set_traffic, get_handle, stat, update_backend_config, get_backend_config,
|
||||
accept_batch, list_backends, list_endpoints, shutdown) # noqa: E402
|
||||
set_traffic, shadow_traffic, get_handle, stat, update_backend_config,
|
||||
get_backend_config, accept_batch, list_backends, list_endpoints,
|
||||
shutdown) # noqa: E402
|
||||
|
||||
__all__ = [
|
||||
"init",
|
||||
@@ -10,6 +11,7 @@ __all__ = [
|
||||
"create_endpoint",
|
||||
"delete_endpoint",
|
||||
"set_traffic",
|
||||
"shadow_traffic",
|
||||
"get_handle",
|
||||
"stat",
|
||||
"update_backend_config",
|
||||
|
||||
@@ -295,6 +295,31 @@ def set_traffic(endpoint_name, traffic_policy_dictionary):
|
||||
traffic_policy_dictionary))
|
||||
|
||||
|
||||
@_ensure_connected
def shadow_traffic(endpoint_name, backend_tag, proportion):
    """Duplicate a share of an endpoint's requests onto another backend.

    The given proportion of requests arriving at the endpoint is copied and
    also sent to the backend. Responses of the duplicated traffic are
    ignored. The backend must not already be in use.

    Calling `shadow_traffic` again with a proportion of 0 stops shadowing
    traffic to that backend.

    Args:
        endpoint_name (str): A registered service endpoint.
        backend_tag (str): A registered backend.
        proportion (float): The proportion of traffic from 0 to 1.

    Raises:
        TypeError: If proportion is not a float or int, or lies outside
            the closed interval [0, 1].
    """
    # The range comparison is only evaluated for numeric values, so
    # non-numeric input never reaches the `<=` operators.
    is_number = isinstance(proportion, (float, int))
    if not (is_number and 0 <= proportion <= 1):
        raise TypeError("proportion must be a float from 0 to 1.")

    # Hand the update to the master actor and wait for it via ray.get.
    update_ref = master_actor.shadow_traffic.remote(
        endpoint_name, backend_tag, proportion)
    ray.get(update_ref)
|
||||
|
||||
|
||||
@_ensure_connected
|
||||
def get_handle(endpoint_name,
|
||||
relative_slo_ms=None,
|
||||
|
||||
+99
-47
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from collections import defaultdict, namedtuple
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
@@ -29,6 +29,39 @@ CHECKPOINT_KEY = "serve-master-checkpoint"
|
||||
_RESOURCE_CHECK_ENABLED = True
|
||||
|
||||
|
||||
class TrafficPolicy:
    """Routing state for one endpoint: primary weights plus shadow weights.

    `traffic_dict` maps backend -> weight for primary routing and
    `shadow_dict` maps backend -> proportion of requests to duplicate.
    """

    def __init__(self, traffic_dict):
        """Create a policy from a backend -> weight dictionary.

        Raises:
            ValueError: If the weights are invalid (see set_traffic_dict).
        """
        self.traffic_dict = dict()
        self.shadow_dict = dict()
        self.set_traffic_dict(traffic_dict)

    def set_traffic_dict(self, traffic_dict):
        """Validate and install a new backend -> weight dictionary.

        The current dictionary is left untouched when validation fails.

        Args:
            traffic_dict (dict): Maps backend to a non-negative weight.

        Raises:
            ValueError: If any weight is negative or the weights do not sum
                to 1 (within floating-point tolerance).
        """
        prob = 0
        for backend, weight in traffic_dict.items():
            if weight < 0:
                raise ValueError(
                    "Attempted to assign a weight of {} to backend '{}'. "
                    "Weights cannot be negative.".format(weight, backend))
            prob += weight

        # Compare with a tolerance so floating-point rounding in the sum
        # does not reject an otherwise valid set of weights.
        if not np.isclose(prob, 1, atol=1e-8):
            raise ValueError("Traffic dictionary weights must sum to 1, "
                             "currently they sum to {}".format(prob))
        self.traffic_dict = traffic_dict

    def set_shadow(self, backend, proportion):
        """Set, update or remove the shadow proportion for a backend.

        Args:
            backend (str): Backend to receive duplicated requests.
            proportion (float): Share of requests to duplicate. A value of 0
                removes the backend from the shadow set.
        """
        if proportion == 0:
            # Removing a backend that is not shadowed must be a no-op.
            # Storing a 0 entry would leave the backend looking "in use".
            self.shadow_dict.pop(backend, None)
        else:
            self.shadow_dict[backend] = proportion
|
||||
|
||||
|
||||
# Record kept per registered backend: the worker class used to start
# replicas, the backend's config, and the replica config.
BackendInfo = namedtuple(
    "BackendInfo", "worker_class backend_config replica_config")
|
||||
|
||||
|
||||
@ray.remote
|
||||
class ServeMaster:
|
||||
"""Responsible for managing the state of the serving system.
|
||||
@@ -63,9 +96,9 @@ class ServeMaster:
|
||||
# Used to read/write checkpoints.
|
||||
self.kv_store = RayInternalKVStore(namespace=instance_name)
|
||||
# path -> (endpoint, methods).
|
||||
self.routes = {}
|
||||
# backend -> (backend_worker, backend_config, replica_config).
|
||||
self.backends = {}
|
||||
self.routes = dict()
|
||||
# backend -> BackendInfo.
|
||||
self.backends = dict()
|
||||
# backend -> replica_tags.
|
||||
self.replicas = defaultdict(list)
|
||||
# replicas that should be started if recovering from a checkpoint.
|
||||
@@ -78,7 +111,7 @@ class ServeMaster:
|
||||
# endpoints that should be removed from the router if recovering from a
|
||||
# checkpoint.
|
||||
self.endpoints_to_remove = list()
|
||||
# endpoint -> traffic_dict
|
||||
# endpoint -> TrafficPolicy
|
||||
self.traffic_policies = dict()
|
||||
# Dictionary of backend tag to dictionaries of replica tag to worker.
|
||||
# TODO(edoakes): consider removing this and just using the names.
|
||||
@@ -258,9 +291,9 @@ class ServeMaster:
|
||||
await self.router.add_new_worker.remote(
|
||||
backend_tag, replica_tag, worker)
|
||||
|
||||
for backend, (_, backend_config, _) in self.backends.items():
|
||||
for backend, info in self.backends.items():
|
||||
await self.router.set_backend_config.remote(
|
||||
backend, backend_config)
|
||||
backend, info.backend_config)
|
||||
await self.broadcast_backend_config(backend)
|
||||
|
||||
# Push configuration state to the HTTP proxy.
|
||||
@@ -282,8 +315,8 @@ class ServeMaster:
|
||||
def get_backend_configs(self):
|
||||
"""Fetched by the router on startup."""
|
||||
backend_configs = {}
|
||||
for backend, (_, backend_config, _) in self.backends.items():
|
||||
backend_configs[backend] = backend_config
|
||||
for backend, info in self.backends.items():
|
||||
backend_configs[backend] = info.backend_config
|
||||
return backend_configs
|
||||
|
||||
def get_traffic_policies(self):
|
||||
@@ -306,19 +339,18 @@ class ServeMaster:
|
||||
"""
|
||||
logger.debug("Starting worker '{}' for backend '{}'.".format(
|
||||
replica_tag, backend_tag))
|
||||
(backend_worker, backend_config,
|
||||
replica_config) = self.backends[backend_tag]
|
||||
backend_info = self.backends[backend_tag]
|
||||
|
||||
replica_name = format_actor_name(replica_tag, self.instance_name)
|
||||
worker_handle = ray.remote(backend_worker).options(
|
||||
worker_handle = ray.remote(backend_info.worker_class).options(
|
||||
name=replica_name,
|
||||
max_restarts=-1,
|
||||
max_task_retries=-1,
|
||||
**replica_config.ray_actor_options).remote(
|
||||
**backend_info.replica_config.ray_actor_options).remote(
|
||||
backend_tag,
|
||||
replica_tag,
|
||||
replica_config.actor_init_args,
|
||||
backend_config,
|
||||
backend_info.replica_config.actor_init_args,
|
||||
backend_info.backend_config,
|
||||
instance_name=self.instance_name)
|
||||
# TODO(edoakes): we should probably have a timeout here.
|
||||
await worker_handle.ready.remote()
|
||||
@@ -427,11 +459,11 @@ class ServeMaster:
|
||||
current_num_replicas = len(self.replicas[backend_tag])
|
||||
delta_num_replicas = num_replicas - current_num_replicas
|
||||
|
||||
_, _, replica_config = self.backends[backend_tag]
|
||||
backend_info = self.backends[backend_tag]
|
||||
if delta_num_replicas > 0:
|
||||
can_schedule = try_schedule_resources_on_nodes(
|
||||
requirements=[
|
||||
replica_config.resource_dict
|
||||
backend_info.replica_config.resource_dict
|
||||
for _ in range(delta_num_replicas)
|
||||
],
|
||||
ray_nodes=ray.nodes())
|
||||
@@ -473,18 +505,27 @@ class ServeMaster:
|
||||
def get_all_backends(self):
|
||||
"""Returns a dictionary of backend tag to backend config dict."""
|
||||
backends = {}
|
||||
for backend_tag, (_, config, _) in self.backends.items():
|
||||
backends[backend_tag] = config.__dict__
|
||||
for backend_tag, backend_info in self.backends.items():
|
||||
backends[backend_tag] = backend_info.backend_config.__dict__
|
||||
return backends
|
||||
|
||||
def get_all_endpoints(self):
|
||||
"""Returns a dictionary of endpoint to endpoint config."""
|
||||
endpoints = {}
|
||||
for route, (endpoint, methods) in self.routes.items():
|
||||
if endpoint in self.traffic_policies:
|
||||
traffic_policy = self.traffic_policies[endpoint]
|
||||
traffic_dict = traffic_policy.traffic_dict
|
||||
shadow_dict = traffic_policy.shadow_dict
|
||||
else:
|
||||
traffic_dict = {}
|
||||
shadow_dict = {}
|
||||
|
||||
endpoints[endpoint] = {
|
||||
"route": route if route.startswith("/") else None,
|
||||
"methods": methods,
|
||||
"traffic": self.traffic_policies.get(endpoint, {})
|
||||
"traffic": traffic_dict,
|
||||
"shadows": shadow_dict,
|
||||
}
|
||||
return endpoints
|
||||
|
||||
@@ -494,38 +535,51 @@ class ServeMaster:
|
||||
" that is not registered.".format(endpoint_name))
|
||||
|
||||
assert isinstance(traffic_dict,
|
||||
dict), "Traffic policy must be dictionary"
|
||||
prob = 0
|
||||
for backend, weight in traffic_dict.items():
|
||||
if weight < 0:
|
||||
raise ValueError(
|
||||
"Attempted to assign a weight of {} to backend '{}'. "
|
||||
"Weights cannot be negative.".format(weight, backend))
|
||||
prob += weight
|
||||
dict), "Traffic policy must be a dictionary."
|
||||
|
||||
for backend in traffic_dict:
|
||||
if backend not in self.backends:
|
||||
raise ValueError(
|
||||
"Attempted to assign traffic to a backend '{}' that "
|
||||
"is not registered.".format(backend))
|
||||
|
||||
# These weights will later be plugged into np.random.choice, which
|
||||
# uses a tolerance of 1e-8.
|
||||
assert np.isclose(
|
||||
prob, 1, atol=1e-8
|
||||
), "weights must sum to 1, currently they sum to {}".format(prob)
|
||||
|
||||
self.traffic_policies[endpoint_name] = traffic_dict
|
||||
traffic_policy = TrafficPolicy(traffic_dict)
|
||||
self.traffic_policies[endpoint_name] = traffic_policy
|
||||
|
||||
# NOTE(edoakes): we must write a checkpoint before pushing the
|
||||
# update to avoid inconsistent state if we crash after pushing the
|
||||
# update.
|
||||
self._checkpoint()
|
||||
await self.router.set_traffic.remote(endpoint_name, traffic_dict)
|
||||
await self.router.set_traffic.remote(endpoint_name, traffic_policy)
|
||||
|
||||
async def set_traffic(self, endpoint_name, traffic_dict):
|
||||
"""Sets the traffic policy for the specified endpoint."""
|
||||
async with self.write_lock:
|
||||
await self._set_traffic(endpoint_name, traffic_dict)
|
||||
|
||||
async def shadow_traffic(self, endpoint_name, backend_tag, proportion):
    """Shadow traffic from the endpoint to the backend.

    Updates the endpoint's traffic policy with the shadow proportion,
    checkpoints, and then pushes the policy to the router. All of this
    happens while holding the write lock.

    Raises:
        ValueError: If the endpoint or the backend is not registered.
    """
    async with self.write_lock:
        if endpoint_name not in self.get_all_endpoints():
            endpoint_error = ("Attempted to shadow traffic from an "
                              "endpoint '{}' that is not registered.")
            raise ValueError(endpoint_error.format(endpoint_name))

        if backend_tag not in self.backends:
            backend_error = ("Attempted to shadow traffic to a backend '{}' "
                             "that is not registered.")
            raise ValueError(backend_error.format(backend_tag))

        policy = self.traffic_policies[endpoint_name]
        policy.set_shadow(backend_tag, proportion)

        # NOTE(edoakes): we must write a checkpoint before pushing the
        # update to avoid inconsistent state if we crash after pushing the
        # update.
        self._checkpoint()
        updated_policy = self.traffic_policies[endpoint_name]
        await self.router.set_traffic.remote(endpoint_name, updated_policy)
|
||||
|
||||
async def create_endpoint(self, endpoint, traffic_dict, route, methods):
|
||||
"""Create a new endpoint with the specified route and methods.
|
||||
|
||||
@@ -610,8 +664,8 @@ class ServeMaster:
|
||||
|
||||
# Save creator that starts replicas, the arguments to be passed in,
|
||||
# and the configuration for the backends.
|
||||
self.backends[backend_tag] = (backend_worker, backend_config,
|
||||
replica_config)
|
||||
self.backends[backend_tag] = BackendInfo(
|
||||
backend_worker, backend_config, replica_config)
|
||||
|
||||
self._scale_replicas(backend_tag, backend_config.num_replicas)
|
||||
|
||||
@@ -635,8 +689,9 @@ class ServeMaster:
|
||||
return
|
||||
|
||||
# Check that the specified backend isn't used by any endpoints.
|
||||
for endpoint, traffic_dict in self.traffic_policies.items():
|
||||
if backend_tag in traffic_dict:
|
||||
for endpoint, traffic_policy in self.traffic_policies.items():
|
||||
if (backend_tag in traffic_policy.traffic_dict
|
||||
or backend_tag in traffic_policy.shadow_dict):
|
||||
raise ValueError("Backend '{}' is used by endpoint '{}' "
|
||||
"and cannot be deleted. Please remove "
|
||||
"the backend from all endpoints and try "
|
||||
@@ -665,12 +720,9 @@ class ServeMaster:
|
||||
assert (backend_tag in self.backends
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
assert isinstance(config_options, dict)
|
||||
backend_worker, backend_config, replica_config = self.backends[
|
||||
backend_tag]
|
||||
|
||||
backend_config.update(config_options)
|
||||
self.backends[backend_tag] = (backend_worker, backend_config,
|
||||
replica_config)
|
||||
self.backends[backend_tag].backend_config.update(config_options)
|
||||
backend_config = self.backends[backend_tag].backend_config
|
||||
|
||||
# Scale the replicas with the new configuration.
|
||||
self._scale_replicas(backend_tag, backend_config.num_replicas)
|
||||
@@ -691,7 +743,7 @@ class ServeMaster:
|
||||
await self.broadcast_backend_config(backend_tag)
|
||||
|
||||
async def broadcast_backend_config(self, backend_tag):
|
||||
_, backend_config, _ = self.backends[backend_tag]
|
||||
backend_config = self.backends[backend_tag].backend_config
|
||||
broadcast_futures = []
|
||||
for replica_tag in self.replicas[backend_tag]:
|
||||
try:
|
||||
@@ -708,7 +760,7 @@ class ServeMaster:
|
||||
"""Get the current config for the specified backend."""
|
||||
assert (backend_tag in self.backends
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
return self.backends[backend_tag][2]
|
||||
return self.backends[backend_tag].backend_config
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shuts down the serve instance completely."""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import copy
|
||||
from hashlib import sha256
|
||||
|
||||
import numpy as np
|
||||
@@ -47,12 +48,29 @@ class RandomEndpointPolicy(EndpointPolicy):
|
||||
be made deterministically based on the hash of the shard key.
|
||||
"""
|
||||
|
||||
def __init__(self, traffic_dict):
|
||||
self.backend_names, self.backend_weights = zip(
|
||||
*sorted(traffic_dict.items()))
|
||||
def __init__(self, traffic_policy):
|
||||
self.backends = sorted(traffic_policy.traffic_dict.items())
|
||||
self.shadow_backends = list(traffic_policy.shadow_dict.items())
|
||||
|
||||
def _select_backends(self, val):
|
||||
curr_sum = 0
|
||||
for name, weight in self.backends:
|
||||
curr_sum += weight
|
||||
if curr_sum > val:
|
||||
chosen_backend = name
|
||||
break
|
||||
else:
|
||||
assert False, "This should never be reached."
|
||||
|
||||
shadow_backends = []
|
||||
for backend, backend_weight in self.shadow_backends:
|
||||
if val < backend_weight:
|
||||
shadow_backends.append(backend)
|
||||
|
||||
return chosen_backend, shadow_backends
|
||||
|
||||
def flush(self, endpoint_queue, backend_queues):
|
||||
if len(self.backend_names) == 0:
|
||||
if len(self.backends) == 0:
|
||||
logger.info("No backends to assign traffic to.")
|
||||
return set()
|
||||
|
||||
@@ -67,11 +85,17 @@ class RandomEndpointPolicy(EndpointPolicy):
|
||||
# Note(simon): This constructor takes 100+us, maybe cache this?
|
||||
rstate = np.random.RandomState(seed)
|
||||
|
||||
chosen_backend = rstate.choice(
|
||||
self.backend_names, replace=False,
|
||||
p=self.backend_weights).squeeze()
|
||||
chosen_backend, shadow_backends = self._select_backends(
|
||||
rstate.random())
|
||||
|
||||
assigned_backends.add(chosen_backend)
|
||||
backend_queues[chosen_backend].add(query)
|
||||
if len(shadow_backends) > 0:
|
||||
shadow_query = copy.copy(query)
|
||||
shadow_query.async_future = None
|
||||
shadow_query.is_shadow_query = True
|
||||
for shadow_backend in shadow_backends:
|
||||
assigned_backends.add(shadow_backend)
|
||||
backend_queues[shadow_backend].add(shadow_query)
|
||||
|
||||
return assigned_backends
|
||||
|
||||
+26
-13
@@ -17,14 +17,17 @@ from ray.serve.utils import logger, chain_future
|
||||
|
||||
|
||||
class Query:
|
||||
def __init__(self,
|
||||
request_args,
|
||||
request_kwargs,
|
||||
request_context,
|
||||
request_slo_ms,
|
||||
call_method="__call__",
|
||||
shard_key=None,
|
||||
async_future=None):
|
||||
def __init__(
|
||||
self,
|
||||
request_args,
|
||||
request_kwargs,
|
||||
request_context,
|
||||
request_slo_ms,
|
||||
call_method="__call__",
|
||||
shard_key=None,
|
||||
async_future=None,
|
||||
is_shadow_query=False,
|
||||
):
|
||||
self.request_args = request_args
|
||||
self.request_kwargs = request_kwargs
|
||||
self.request_context = request_context
|
||||
@@ -37,6 +40,7 @@ class Query:
|
||||
|
||||
self.call_method = call_method
|
||||
self.shard_key = shard_key
|
||||
self.is_shadow_query = is_shadow_query
|
||||
|
||||
def ray_serialize(self):
|
||||
# NOTE: this method is needed because Query need to be serialized and
|
||||
@@ -241,11 +245,11 @@ class Router:
|
||||
# result.
|
||||
pass
|
||||
|
||||
async def set_traffic(self, endpoint, traffic_dict):
|
||||
async def set_traffic(self, endpoint, traffic_policy):
|
||||
logger.debug("Setting traffic for endpoint %s to %s", endpoint,
|
||||
traffic_dict)
|
||||
traffic_policy)
|
||||
async with self.flush_lock:
|
||||
self.traffic[endpoint] = RandomEndpointPolicy(traffic_dict)
|
||||
self.traffic[endpoint] = RandomEndpointPolicy(traffic_policy)
|
||||
self.flush_endpoint_queue(endpoint)
|
||||
|
||||
async def remove_endpoint(self, endpoint):
|
||||
@@ -314,7 +318,13 @@ class Router:
|
||||
start = time.time()
|
||||
worker = self.replicas[backend_replica_tag]
|
||||
try:
|
||||
result = await worker.handle_request.remote(req)
|
||||
if req.is_shadow_query:
|
||||
# No need to actually get the result, but we do need to wait
|
||||
# until the call completes to mark the worker idle.
|
||||
asyncio.wait([worker.handle_request.remote(req)])
|
||||
result = ""
|
||||
else:
|
||||
result = await worker.handle_request.remote(req)
|
||||
except RayTaskError as error:
|
||||
self.num_error_backend_request.labels(backend=backend).add()
|
||||
result = error
|
||||
@@ -356,6 +366,9 @@ class Router:
|
||||
self.queries_counter[backend_replica_tag] += 1
|
||||
future = asyncio.get_event_loop().create_task(
|
||||
self._do_query(backend, backend_replica_tag, request))
|
||||
chain_future(future, request.async_future)
|
||||
|
||||
# For shadow queries, just ignore the result.
|
||||
if not request.is_shadow_query:
|
||||
chain_future(future, request.async_future)
|
||||
|
||||
worker_queue.appendleft(backend_replica_tag)
|
||||
|
||||
@@ -444,9 +444,11 @@ def test_list_endpoints(serve_instance):
|
||||
|
||||
serve.create_backend("backend", f)
|
||||
serve.create_backend("backend2", f)
|
||||
serve.create_backend("backend3", f)
|
||||
serve.create_endpoint(
|
||||
"endpoint", backend="backend", route="/api", methods=["GET", "POST"])
|
||||
serve.create_endpoint("endpoint2", backend="backend2", methods=["POST"])
|
||||
serve.shadow_traffic("endpoint", "backend3", 0.5)
|
||||
|
||||
endpoints = serve.list_endpoints()
|
||||
assert "endpoint" in endpoints
|
||||
@@ -455,6 +457,9 @@ def test_list_endpoints(serve_instance):
|
||||
"methods": ["GET", "POST"],
|
||||
"traffic": {
|
||||
"backend": 1.0
|
||||
},
|
||||
"shadows": {
|
||||
"backend3": 0.5
|
||||
}
|
||||
}
|
||||
|
||||
@@ -464,7 +469,8 @@ def test_list_endpoints(serve_instance):
|
||||
"methods": ["POST"],
|
||||
"traffic": {
|
||||
"backend2": 1.0
|
||||
}
|
||||
},
|
||||
"shadows": {}
|
||||
}
|
||||
|
||||
serve.delete_endpoint("endpoint")
|
||||
@@ -576,6 +582,49 @@ def test_shutdown(serve_instance):
|
||||
assert wait_for_condition(check_dead)
|
||||
|
||||
|
||||
def test_shadow_traffic(serve_instance):
    """Check that shadow backends receive duplicated requests.

    One primary backend serves the endpoint while three others shadow it at
    proportions 1.0, 0.5 and 0.1. Every HTTP response must come from the
    primary backend, and the per-backend request counters must reflect the
    ordering of the shadow proportions.
    """

    def f():
        return "hello"

    def f_shadow():
        return "oops"

    serve.create_backend("backend1", f)
    for shadow_name in ("backend2", "backend3", "backend4"):
        serve.create_backend(shadow_name, f_shadow)

    serve.create_endpoint("endpoint", backend="backend1", route="/api")
    shadow_plan = (("backend2", 1.0), ("backend3", 0.5), ("backend4", 0.1))
    for shadow_name, fraction in shadow_plan:
        serve.shadow_traffic("endpoint", shadow_name, fraction)

    start = time.time()
    num_requests = 100
    for _ in range(num_requests):
        # Only the primary backend's response may reach the client.
        response = requests.get("http://127.0.0.1:8000/api")
        assert response.text == "hello"
    print("Finished 100 requests in {}s.".format(time.time() - start))

    def requests_to_backend(backend):
        # Look up the request counter reported for this backend, if any.
        for entry in serve.stat():
            info = entry["info"]
            if (info["name"] == "backend_request_counter"
                    and info["backend"] == backend):
                return entry["value"]

        return 0

    def check_requests():
        conditions = [
            requests_to_backend("backend1") == num_requests,
            requests_to_backend("backend2") == requests_to_backend("backend1"),
            requests_to_backend("backend3") < requests_to_backend("backend2"),
            requests_to_backend("backend4") < requests_to_backend("backend3"),
            requests_to_backend("backend4") > 0,
        ]
        return all(conditions)

    assert wait_for_condition(check_requests)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", "-s", __file__]))
|
||||
|
||||
@@ -7,6 +7,7 @@ import ray
|
||||
from ray import serve
|
||||
import ray.serve.context as context
|
||||
from ray.serve.backend_worker import create_backend_worker, wrap_to_ray_error
|
||||
from ray.serve.master import TrafficPolicy
|
||||
from ray.serve.request_params import RequestMetadata
|
||||
from ray.serve.router import Router
|
||||
from ray.serve.config import BackendConfig
|
||||
@@ -59,7 +60,7 @@ async def test_runner_actor(serve_instance):
|
||||
worker = setup_worker(CONSUMER_NAME, echo)
|
||||
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
|
||||
|
||||
q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
|
||||
q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0}))
|
||||
|
||||
for query in [333, 444, 555]:
|
||||
query_param = RequestMetadata(PRODUCER_NAME,
|
||||
@@ -84,7 +85,7 @@ async def test_ray_serve_mixin(serve_instance):
|
||||
worker = setup_worker(CONSUMER_NAME, MyAdder, init_args=(3, ))
|
||||
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
|
||||
|
||||
q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
|
||||
q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0}))
|
||||
|
||||
for query in [333, 444, 555]:
|
||||
query_param = RequestMetadata(PRODUCER_NAME,
|
||||
@@ -106,7 +107,7 @@ async def test_task_runner_check_context(serve_instance):
|
||||
worker = setup_worker(CONSUMER_NAME, echo)
|
||||
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
|
||||
|
||||
q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
|
||||
q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0}))
|
||||
query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
|
||||
result_oid = q.enqueue_request.remote(query_param, i=42)
|
||||
|
||||
@@ -130,7 +131,7 @@ async def test_task_runner_custom_method_single(serve_instance):
|
||||
worker = setup_worker(CONSUMER_NAME, NonBatcher)
|
||||
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
|
||||
|
||||
q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
|
||||
q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0}))
|
||||
|
||||
query_param = RequestMetadata(
|
||||
PRODUCER_NAME, context.TaskContext.Python, call_method="a")
|
||||
@@ -179,7 +180,10 @@ async def test_task_runner_custom_method_batch(serve_instance):
|
||||
worker = setup_worker(
|
||||
CONSUMER_NAME, Batcher, backend_config=backend_config)
|
||||
|
||||
await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
|
||||
await q.set_traffic.remote(PRODUCER_NAME,
|
||||
TrafficPolicy({
|
||||
CONSUMER_NAME: 1.0
|
||||
}))
|
||||
await q.set_backend_config.remote(CONSUMER_NAME, backend_config)
|
||||
|
||||
def make_request_param(call_method):
|
||||
@@ -228,7 +232,10 @@ async def test_task_runner_perform_batch(serve_instance):
|
||||
worker = setup_worker(CONSUMER_NAME, batcher, backend_config=config)
|
||||
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
|
||||
await q.set_backend_config.remote(CONSUMER_NAME, config)
|
||||
await q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
|
||||
await q.set_traffic.remote(PRODUCER_NAME,
|
||||
TrafficPolicy({
|
||||
CONSUMER_NAME: 1.0
|
||||
}))
|
||||
|
||||
query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
|
||||
|
||||
@@ -268,7 +275,7 @@ async def test_task_runner_perform_async(serve_instance):
|
||||
worker = setup_worker(CONSUMER_NAME, wait_and_go, backend_config=config)
|
||||
await q.add_new_worker.remote(CONSUMER_NAME, "replica1", worker)
|
||||
await q.set_backend_config.remote(CONSUMER_NAME, config)
|
||||
q.set_traffic.remote(PRODUCER_NAME, {CONSUMER_NAME: 1.0})
|
||||
q.set_traffic.remote(PRODUCER_NAME, TrafficPolicy({CONSUMER_NAME: 1.0}))
|
||||
|
||||
query_param = RequestMetadata(PRODUCER_NAME, context.TaskContext.Python)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections import defaultdict
|
||||
import pytest
|
||||
import ray
|
||||
|
||||
from ray.serve.master import TrafficPolicy
|
||||
from ray.serve.router import Router
|
||||
from ray.serve.request_params import RequestMetadata
|
||||
from ray.serve.utils import get_random_letters
|
||||
@@ -47,7 +48,7 @@ def task_runner_mock_actor():
|
||||
|
||||
async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
q.set_traffic.remote("svc", {"backend-single-prod": 1.0})
|
||||
q.set_traffic.remote("svc", TrafficPolicy({"backend-single-prod": 1.0}))
|
||||
q.add_new_worker.remote("backend-single-prod", "replica-1",
|
||||
task_runner_mock_actor)
|
||||
|
||||
@@ -63,7 +64,7 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
|
||||
|
||||
async def test_slo(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.set_traffic.remote("svc", {"backend-slo": 1.0})
|
||||
await q.set_traffic.remote("svc", TrafficPolicy({"backend-slo": 1.0}))
|
||||
|
||||
all_request_sent = []
|
||||
for i in range(10):
|
||||
@@ -88,14 +89,14 @@ async def test_slo(serve_instance, task_runner_mock_actor):
|
||||
async def test_alter_backend(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
|
||||
await q.set_traffic.remote("svc", {"backend-alter": 1})
|
||||
await q.set_traffic.remote("svc", TrafficPolicy({"backend-alter": 1}))
|
||||
await q.add_new_worker.remote("backend-alter", "replica-1",
|
||||
task_runner_mock_actor)
|
||||
await q.enqueue_request.remote(RequestMetadata("svc", None), 1)
|
||||
got_work = await task_runner_mock_actor.get_recent_call.remote()
|
||||
assert got_work.request_args[0] == 1
|
||||
|
||||
await q.set_traffic.remote("svc", {"backend-alter-2": 1})
|
||||
await q.set_traffic.remote("svc", TrafficPolicy({"backend-alter-2": 1}))
|
||||
await q.add_new_worker.remote("backend-alter-2", "replica-1",
|
||||
task_runner_mock_actor)
|
||||
await q.enqueue_request.remote(RequestMetadata("svc", None), 2)
|
||||
@@ -106,10 +107,11 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor):
|
||||
async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
|
||||
await q.set_traffic.remote("svc", {
|
||||
"backend-split": 0.5,
|
||||
"backend-split-2": 0.5
|
||||
})
|
||||
await q.set_traffic.remote(
|
||||
"svc", TrafficPolicy({
|
||||
"backend-split": 0.5,
|
||||
"backend-split-2": 0.5
|
||||
}))
|
||||
runner_1, runner_2 = [mock_task_runner() for _ in range(2)]
|
||||
await q.add_new_worker.remote("backend-split", "replica-1", runner_1)
|
||||
await q.add_new_worker.remote("backend-split-2", "replica-1", runner_2)
|
||||
@@ -148,7 +150,7 @@ async def test_shard_key(serve_instance, task_runner_mock_actor):
|
||||
backend_name = "backend-split-" + str(i)
|
||||
traffic_dict[backend_name] = 1.0 / num_backends
|
||||
await q.add_new_worker.remote(backend_name, "replica-1", runner)
|
||||
await q.set_traffic.remote("svc", traffic_dict)
|
||||
await q.set_traffic.remote("svc", TrafficPolicy(traffic_dict))
|
||||
|
||||
# Generate random shard keys and send one request for each.
|
||||
shard_keys = [get_random_letters() for _ in range(100)]
|
||||
@@ -196,7 +198,7 @@ async def test_router_use_max_concurrency(serve_instance):
|
||||
q = ray.remote(VisibleRouter).remote()
|
||||
BACKEND_NAME = "max-concurrent-test"
|
||||
config = BackendConfig({"max_concurrent_queries": 1})
|
||||
await q.set_traffic.remote("svc", {BACKEND_NAME: 1.0})
|
||||
await q.set_traffic.remote("svc", TrafficPolicy({BACKEND_NAME: 1.0}))
|
||||
await q.add_new_worker.remote(BACKEND_NAME, "replica-tag", worker)
|
||||
await q.set_backend_config.remote(BACKEND_NAME, config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user