mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 07:34:30 +08:00
[serve] Add basic session affinity via shard key (#8449)
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
from ray.serve.policy import RoutePolicy
|
||||
from ray.serve.api import (init, create_backend, delete_backend,
|
||||
create_endpoint, delete_endpoint, set_traffic,
|
||||
get_handle, stat, update_backend_config,
|
||||
@@ -7,6 +6,5 @@ from ray.serve.api import (init, create_backend, delete_backend,
|
||||
__all__ = [
|
||||
"init", "create_backend", "delete_backend", "create_endpoint",
|
||||
"delete_endpoint", "set_traffic", "get_handle", "stat",
|
||||
"update_backend_config", "get_backend_config", "RoutePolicy",
|
||||
"accept_batch"
|
||||
"update_backend_config", "get_backend_config", "accept_batch"
|
||||
]
|
||||
|
||||
@@ -10,7 +10,6 @@ from ray.serve.handle import RayServeHandle
|
||||
from ray.serve.utils import block_until_http_ready, retry_actor_failures
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.config import BackendConfig, ReplicaConfig
|
||||
from ray.serve.policy import RoutePolicy
|
||||
from ray.serve.router import Query
|
||||
from ray.serve.request_params import RequestMetadata
|
||||
from ray.serve.metric import InMemoryExporter
|
||||
@@ -69,8 +68,6 @@ def init(blocking=False,
|
||||
"object_store_memory": int(1e8),
|
||||
"num_cpus": max(cpu_count(), 8)
|
||||
},
|
||||
queueing_policy=RoutePolicy.Random,
|
||||
policy_kwargs={},
|
||||
metric_exporter=InMemoryExporter):
|
||||
"""Initialize a serve cluster.
|
||||
|
||||
@@ -90,9 +87,6 @@ def init(blocking=False,
|
||||
ray_init_kwargs (dict): Argument passed to ray.init, if there is no ray
|
||||
connection. Default to {"object_store_memory": int(1e8)} for
|
||||
performance stability reason
|
||||
queueing_policy(RoutePolicy): Define the queueing policy for selecting
|
||||
the backend for a service. (Default: RoutePolicy.Random)
|
||||
policy_kwargs: Arguments required to instantiate a queueing policy
|
||||
metric_exporter(ExporterInterface): The class aggregates metrics from
|
||||
all RayServe actors and optionally export them to external
|
||||
services. RayServe has two options built in: InMemoryExporter and
|
||||
@@ -132,8 +126,7 @@ def init(blocking=False,
|
||||
detached=True,
|
||||
name=SERVE_MASTER_NAME,
|
||||
max_restarts=-1,
|
||||
).remote(queueing_policy.value, policy_kwargs, start_server, http_node_id,
|
||||
http_host, http_port, metric_exporter)
|
||||
).remote(start_server, http_node_id, http_host, http_port, metric_exporter)
|
||||
|
||||
if start_server and blocking:
|
||||
block_until_http_ready("http://{}:{}/-/routes".format(
|
||||
|
||||
@@ -2,7 +2,6 @@ import ray
|
||||
from ray import serve
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.constants import DEFAULT_HTTP_ADDRESS
|
||||
from ray.serve.request_params import RequestMetadata
|
||||
|
||||
|
||||
@@ -35,6 +34,7 @@ class RayServeHandle:
|
||||
relative_slo_ms=None,
|
||||
absolute_slo_ms=None,
|
||||
method_name=None,
|
||||
shard_key=None,
|
||||
):
|
||||
self.router_handle = router_handle
|
||||
self.endpoint_name = endpoint_name
|
||||
@@ -45,6 +45,7 @@ class RayServeHandle:
|
||||
self.relative_slo_ms = self._check_slo_ms(relative_slo_ms)
|
||||
self.absolute_slo_ms = self._check_slo_ms(absolute_slo_ms)
|
||||
self.method_name = method_name
|
||||
self.shard_key = shard_key
|
||||
|
||||
def _check_slo_ms(self, slo_value):
|
||||
if slo_value is not None:
|
||||
@@ -75,12 +76,14 @@ class RayServeHandle:
|
||||
self.relative_slo_ms,
|
||||
self.absolute_slo_ms,
|
||||
call_method=method_name,
|
||||
shard_key=self.shard_key,
|
||||
)
|
||||
return self.router_handle.enqueue_request.remote(
|
||||
request_in_object, **kwargs)
|
||||
|
||||
def options(self,
|
||||
method_name=None,
|
||||
shard_key=None,
|
||||
relative_slo_ms=None,
|
||||
absolute_slo_ms=None):
|
||||
# If both the slo's are None then then we use a high default
|
||||
@@ -95,48 +98,30 @@ class RayServeHandle:
|
||||
if method_name is None and self.method_name is not None:
|
||||
method_name = self.method_name
|
||||
|
||||
if shard_key is None and self.shard_key is not None:
|
||||
shard_key = self.shard_key
|
||||
|
||||
return RayServeHandle(
|
||||
self.router_handle,
|
||||
self.endpoint_name,
|
||||
relative_slo_ms,
|
||||
absolute_slo_ms,
|
||||
method_name=method_name,
|
||||
shard_key=shard_key,
|
||||
)
|
||||
|
||||
def get_http_endpoint(self):
|
||||
return DEFAULT_HTTP_ADDRESS
|
||||
|
||||
def get_traffic_policy(self):
|
||||
master_actor = serve.api._get_master_actor()
|
||||
return ray.get(
|
||||
master_actor.get_traffic_policy.remote(self.endpoint_name))
|
||||
|
||||
def _ensure_backend_unique(self, backend_tag=None):
|
||||
traffic_policy = self.get_traffic_policy()
|
||||
if backend_tag is None:
|
||||
assert len(traffic_policy) == 1, (
|
||||
"Multiple backends detected. "
|
||||
"Please pass in backend_tag=... argument to specify backend.")
|
||||
backends = set(traffic_policy.keys())
|
||||
return backends.pop()
|
||||
else:
|
||||
assert (backend_tag in traffic_policy
|
||||
), "Backend {} not found in avaiable backends: {}.".format(
|
||||
backend_tag, list(traffic_policy.keys()))
|
||||
return backend_tag
|
||||
|
||||
def __repr__(self):
|
||||
return """
|
||||
RayServeHandle(
|
||||
Endpoint="{endpoint_name}",
|
||||
URL="{http_endpoint}/{endpoint_name}",
|
||||
Traffic={traffic_policy}
|
||||
)
|
||||
""".format(
|
||||
endpoint_name=self.endpoint_name,
|
||||
http_endpoint=self.get_http_endpoint(),
|
||||
traffic_policy=self.get_traffic_policy(),
|
||||
)
|
||||
|
||||
# TODO(simon): a convenience function that dumps equivalent requests
|
||||
# code for a given call.
|
||||
|
||||
@@ -168,7 +168,9 @@ class HTTPProxy:
|
||||
TaskContext.Web,
|
||||
relative_slo_ms=relative_slo_ms,
|
||||
absolute_slo_ms=absolute_slo_ms,
|
||||
call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__"))
|
||||
call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__"),
|
||||
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), None),
|
||||
)
|
||||
|
||||
retries = 0
|
||||
while retries <= MAX_ACTOR_DEAD_RETRIES:
|
||||
|
||||
@@ -49,8 +49,7 @@ class ServeMaster:
|
||||
requires all implementations here to be idempotent.
|
||||
"""
|
||||
|
||||
async def __init__(self, router_policy, router_policy_kwargs,
|
||||
start_http_proxy, http_node_id, http_proxy_host,
|
||||
async def __init__(self, start_http_proxy, http_node_id, http_proxy_host,
|
||||
http_proxy_port, metric_exporter_class):
|
||||
# Used to read/write checkpoints.
|
||||
# TODO(edoakes): namespace the master actor and its checkpoints.
|
||||
@@ -89,7 +88,7 @@ class ServeMaster:
|
||||
# If starting the actor for the first time, starts up the other system
|
||||
# components. If recovering, fetches their actor handles.
|
||||
self._get_or_start_metric_exporter(metric_exporter_class)
|
||||
self._get_or_start_router(router_policy, router_policy_kwargs)
|
||||
self._get_or_start_router()
|
||||
if start_http_proxy:
|
||||
self._get_or_start_http_proxy(http_node_id, http_proxy_host,
|
||||
http_proxy_port)
|
||||
@@ -114,7 +113,7 @@ class ServeMaster:
|
||||
asyncio.get_event_loop().create_task(
|
||||
self._recover_from_checkpoint(checkpoint))
|
||||
|
||||
def _get_or_start_router(self, policy, policy_kwargs):
|
||||
def _get_or_start_router(self):
|
||||
"""Get the router belonging to this serve cluster.
|
||||
|
||||
If the router does not already exist, it will be started.
|
||||
@@ -128,7 +127,7 @@ class ServeMaster:
|
||||
detached=True,
|
||||
name=SERVE_ROUTER_NAME,
|
||||
max_concurrency=ASYNC_CONCURRENCY,
|
||||
max_restarts=-1).remote(policy, policy_kwargs)
|
||||
max_restarts=-1).remote()
|
||||
|
||||
def get_router(self):
|
||||
"""Returns a handle to the router managed by this actor."""
|
||||
|
||||
+19
-111
@@ -1,13 +1,12 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from enum import Enum
|
||||
import itertools
|
||||
from hashlib import sha256
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ray.serve.utils import logger
|
||||
|
||||
|
||||
class RoutingPolicy:
|
||||
class EndpointPolicy:
|
||||
"""Defines the interface for a routing policy for a single endpoint.
|
||||
|
||||
To add a new routing policy, a class should be defined that provides this
|
||||
@@ -39,15 +38,18 @@ class RoutingPolicy:
|
||||
return assigned_backends
|
||||
|
||||
|
||||
class RandomPolicy(RoutingPolicy):
|
||||
class RandomEndpointPolicy(EndpointPolicy):
|
||||
"""
|
||||
A stateless policy that makes a weighted random decision to map each query
|
||||
to a backend using the specified weights.
|
||||
|
||||
If a shard key is provided in a query, the weighted random selection will
|
||||
be made deterministically based on the hash of the shard key.
|
||||
"""
|
||||
|
||||
def __init__(self, traffic_dict):
|
||||
self.backend_names = list(traffic_dict.keys())
|
||||
self.backend_weights = list(traffic_dict.values())
|
||||
self.backend_names, self.backend_weights = zip(
|
||||
*sorted(traffic_dict.items()))
|
||||
|
||||
async def flush(self, endpoint_queue, backend_queues):
|
||||
if len(self.backend_names) == 0:
|
||||
@@ -56,113 +58,19 @@ class RandomPolicy(RoutingPolicy):
|
||||
|
||||
assigned_backends = set()
|
||||
while endpoint_queue.qsize():
|
||||
chosen_backend = np.random.choice(
|
||||
query = await endpoint_queue.get()
|
||||
if query.shard_key is None:
|
||||
rstate = np.random
|
||||
else:
|
||||
sha256_seed = sha256(query.shard_key.encode("utf-8"))
|
||||
seed = np.frombuffer(sha256_seed.digest(), dtype=np.uint32)
|
||||
rstate = np.random.RandomState(seed)
|
||||
|
||||
chosen_backend = rstate.choice(
|
||||
self.backend_names, replace=False,
|
||||
p=self.backend_weights).squeeze()
|
||||
|
||||
assigned_backends.add(chosen_backend)
|
||||
backend_queues[chosen_backend].add(await endpoint_queue.get())
|
||||
backend_queues[chosen_backend].add(query)
|
||||
|
||||
return assigned_backends
|
||||
|
||||
|
||||
class RoundRobinPolicy(RoutingPolicy):
|
||||
"""A stateful policy that assigns queries in round-robin order."""
|
||||
|
||||
def __init__(self, traffic_dict):
|
||||
# NOTE(edoakes): the backend weights are not used.
|
||||
self.backend_names = list(traffic_dict.keys())
|
||||
# Saves the information about last assigned backend for every endpoint.
|
||||
self.round_robin_iterator = itertools.cycle(self.backend_names)
|
||||
|
||||
async def flush(self, endpoint_queue, backend_queues):
|
||||
if len(self.backend_names) == 0:
|
||||
logger.info("No backends to assign traffic to.")
|
||||
return set()
|
||||
|
||||
assigned_backends = set()
|
||||
while endpoint_queue.qsize():
|
||||
chosen_backend = next(self.round_robin_iterator)
|
||||
assigned_backends.add(chosen_backend)
|
||||
backend_queues[chosen_backend].add(await endpoint_queue.get())
|
||||
|
||||
return assigned_backends
|
||||
|
||||
|
||||
class PowerOfTwoPolicy(RoutingPolicy):
|
||||
"""A stateless policy that uses the "power of two" policy.
|
||||
|
||||
For each query, two random backends are chosen. Of those two, the query is
|
||||
assigned to the backend whose queue length is shorter.
|
||||
"""
|
||||
|
||||
def __init__(self, traffic_dict):
|
||||
self.backend_names = list(traffic_dict.keys())
|
||||
self.backend_weights = list(traffic_dict.values())
|
||||
|
||||
async def flush(self, endpoint_queue, backend_queues):
|
||||
if len(self.backend_names) == 0:
|
||||
logger.info("No backends to assign traffic to.")
|
||||
return set()
|
||||
|
||||
assigned_backends = set()
|
||||
while endpoint_queue.qsize():
|
||||
if len(self.backend_names) >= 2:
|
||||
backend1, backend2 = np.random.choice(
|
||||
self.backend_names,
|
||||
2,
|
||||
replace=False,
|
||||
p=self.backend_weights)
|
||||
|
||||
# Choose the backend that has a shorter queue.
|
||||
if (len(backend_queues[backend1]) <= len(
|
||||
backend_queues[backend2])):
|
||||
chosen_backend = backend1
|
||||
else:
|
||||
chosen_backend = backend2
|
||||
else:
|
||||
chosen_backend = np.random.choice(
|
||||
self.backend_names, replace=False,
|
||||
p=self.backend_weights).squeeze()
|
||||
backend_queues[chosen_backend].add(await endpoint_queue.get())
|
||||
assigned_backends.add(chosen_backend)
|
||||
|
||||
return assigned_backends
|
||||
|
||||
|
||||
class FixedPackingPolicy(RoutingPolicy):
|
||||
"""A stateful policy that uses a "fixed packing" policy.
|
||||
|
||||
The policy round-robins groups of packing_num queries across backends. For
|
||||
example, the first packing_num queries are handled by backend-1, then the
|
||||
next packing_num queries are handled by backend-2, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, traffic_dict, packing_num=3):
|
||||
# NOTE(edoakes): the backend weights are not used.
|
||||
self.backend_names = list(traffic_dict.keys())
|
||||
self.fixed_packing_iterator = itertools.cycle(
|
||||
itertools.chain.from_iterable(
|
||||
itertools.repeat(x, self.packing_num)
|
||||
for x in self.backend_names))
|
||||
self.packing_num = packing_num
|
||||
|
||||
async def flush(self, endpoint_queue, backend_queues):
|
||||
if len(self.backend_names) == 0:
|
||||
logger.info("No backends to assign traffic to.")
|
||||
return set()
|
||||
|
||||
assigned_backends = set()
|
||||
while endpoint_queue.qsize():
|
||||
chosen_backend = next(self.fixed_packing_iterator)
|
||||
backend_queues[chosen_backend].add(await endpoint_queue.get())
|
||||
assigned_backends.add(chosen_backend)
|
||||
|
||||
return assigned_backends
|
||||
|
||||
|
||||
class RoutePolicy(Enum):
|
||||
"""All builtin routing policies."""
|
||||
Random = RandomPolicy
|
||||
RoundRobin = RoundRobinPolicy
|
||||
PowerOfTwo = PowerOfTwoPolicy
|
||||
FixedPacking = FixedPackingPolicy
|
||||
|
||||
@@ -21,13 +21,15 @@ class RequestMetadata:
|
||||
request_context,
|
||||
relative_slo_ms=None,
|
||||
absolute_slo_ms=None,
|
||||
call_method="__call__"):
|
||||
call_method="__call__",
|
||||
shard_key=None):
|
||||
|
||||
self.endpoint = endpoint
|
||||
self.request_context = request_context
|
||||
self.relative_slo_ms = relative_slo_ms
|
||||
self.absolute_slo_ms = absolute_slo_ms
|
||||
self.call_method = call_method
|
||||
self.shard_key = shard_key
|
||||
|
||||
def adjust_relative_slo_ms(self) -> float:
|
||||
"""Normalize the input latency objective to absolute timestamp.
|
||||
|
||||
@@ -16,6 +16,7 @@ import ray
|
||||
import ray.cloudpickle as pickle
|
||||
from ray.exceptions import RayTaskError
|
||||
from ray.serve.metric import MetricClient
|
||||
from ray.serve.policy import RandomEndpointPolicy
|
||||
from ray.serve.utils import logger, retry_actor_failures
|
||||
|
||||
|
||||
@@ -26,6 +27,7 @@ class Query:
|
||||
request_context,
|
||||
request_slo_ms,
|
||||
call_method="__call__",
|
||||
shard_key=None,
|
||||
async_future=None):
|
||||
self.request_args = request_args
|
||||
self.request_kwargs = request_kwargs
|
||||
@@ -38,6 +40,7 @@ class Query:
|
||||
self.request_slo_ms = request_slo_ms
|
||||
|
||||
self.call_method = call_method
|
||||
self.shard_key = shard_key
|
||||
|
||||
def ray_serialize(self):
|
||||
# NOTE: this method is needed because Query need to be serialized and
|
||||
@@ -103,7 +106,7 @@ class Router:
|
||||
3. When there is only 1 backend ready, we will only use that backend.
|
||||
"""
|
||||
|
||||
async def __init__(self, policy, policy_kwargs):
|
||||
async def __init__(self):
|
||||
# Note: Several queues are used in the router
|
||||
# - When a request come in, it's placed inside its corresponding
|
||||
# endpoint_queue.
|
||||
@@ -114,11 +117,6 @@ class Router:
|
||||
# handles are dequed during the second stage of flush operation,
|
||||
# which assign queries in buffer_queue to actor handle.
|
||||
|
||||
# policy.RoutePolicy.
|
||||
self.policy = policy
|
||||
# kwargs to pass into the policy when it's constructed.
|
||||
self.policy_kwargs = policy_kwargs
|
||||
|
||||
# -- Queues -- #
|
||||
|
||||
# endpoint_name -> request queue
|
||||
@@ -211,6 +209,7 @@ class Router:
|
||||
request_context,
|
||||
request_slo_ms,
|
||||
call_method=request_meta.call_method,
|
||||
shard_key=request_meta.shard_key,
|
||||
async_future=asyncio.get_event_loop().create_future())
|
||||
await self.endpoint_queues[endpoint].put(query)
|
||||
async with self.flush_lock:
|
||||
@@ -268,8 +267,7 @@ class Router:
|
||||
logger.debug("Setting traffic for endpoint %s to %s", endpoint,
|
||||
traffic_dict)
|
||||
async with self.flush_lock:
|
||||
self.traffic[endpoint] = self.policy(traffic_dict,
|
||||
**self.policy_kwargs)
|
||||
self.traffic[endpoint] = RandomEndpointPolicy(traffic_dict)
|
||||
await self.flush_endpoint_queue(endpoint)
|
||||
|
||||
async def remove_endpoint(self, endpoint):
|
||||
|
||||
@@ -2,8 +2,9 @@ import time
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from ray import serve
|
||||
import ray
|
||||
from ray import serve
|
||||
from ray.serve.utils import get_random_letters
|
||||
|
||||
|
||||
def test_e2e(serve_instance):
|
||||
@@ -302,3 +303,42 @@ def test_delete_endpoint(serve_instance, route):
|
||||
else:
|
||||
handle = serve.get_handle(endpoint_name)
|
||||
assert ray.get(handle.remote()) == "hello"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("route", [None, "/shard"])
|
||||
def test_shard_key(serve_instance, route):
|
||||
serve.create_endpoint("endpoint", route=route)
|
||||
|
||||
# Create five backends that return different integers.
|
||||
num_backends = 5
|
||||
traffic_dict = {}
|
||||
for i in range(num_backends):
|
||||
|
||||
def function():
|
||||
return i
|
||||
|
||||
backend_name = "backend-split-" + str(i)
|
||||
traffic_dict[backend_name] = 1.0 / num_backends
|
||||
serve.create_backend(backend_name, function)
|
||||
|
||||
serve.set_traffic("endpoint", traffic_dict)
|
||||
|
||||
def do_request(shard_key):
|
||||
if route is not None:
|
||||
url = "http://127.0.0.1:8000" + route
|
||||
headers = {"X-SERVE-SHARD-KEY": shard_key}
|
||||
result = requests.get(url, headers=headers).text
|
||||
else:
|
||||
handle = serve.get_handle("endpoint").options(shard_key=shard_key)
|
||||
result = ray.get(handle.options(shard_key=shard_key).remote())
|
||||
return result
|
||||
|
||||
# Send requests with different shard keys and log the backends they go to.
|
||||
shard_keys = [get_random_letters() for _ in range(20)]
|
||||
results = {}
|
||||
for shard_key in shard_keys:
|
||||
results[shard_key] = do_request(shard_key)
|
||||
|
||||
# Check that the shard keys are mapped to the same backends.
|
||||
for shard_key in shard_keys:
|
||||
assert do_request(shard_key) == results[shard_key]
|
||||
|
||||
@@ -6,7 +6,6 @@ import numpy as np
|
||||
import ray
|
||||
from ray import serve
|
||||
import ray.serve.context as context
|
||||
from ray.serve.policy import RoundRobinPolicy
|
||||
from ray.serve.backend_worker import create_backend_worker, wrap_to_ray_error
|
||||
from ray.serve.request_params import RequestMetadata
|
||||
from ray.serve.router import Router
|
||||
@@ -43,7 +42,7 @@ async def test_runner_wraps_error():
|
||||
|
||||
|
||||
async def test_runner_actor(serve_instance):
|
||||
q = ray.remote(Router).remote(RoundRobinPolicy, {})
|
||||
q = ray.remote(Router).remote()
|
||||
|
||||
def echo(flask_request, i=None):
|
||||
return i
|
||||
@@ -64,7 +63,7 @@ async def test_runner_actor(serve_instance):
|
||||
|
||||
|
||||
async def test_ray_serve_mixin(serve_instance):
|
||||
q = ray.remote(Router).remote(RoundRobinPolicy, {})
|
||||
q = ray.remote(Router).remote()
|
||||
|
||||
CONSUMER_NAME = "runner-cls"
|
||||
PRODUCER_NAME = "prod-cls"
|
||||
@@ -89,7 +88,7 @@ async def test_ray_serve_mixin(serve_instance):
|
||||
|
||||
|
||||
async def test_task_runner_check_context(serve_instance):
|
||||
q = ray.remote(Router).remote(RoundRobinPolicy, {})
|
||||
q = ray.remote(Router).remote()
|
||||
|
||||
def echo(flask_request, i=None):
|
||||
# Accessing the flask_request without web context should throw.
|
||||
@@ -110,7 +109,7 @@ async def test_task_runner_check_context(serve_instance):
|
||||
|
||||
|
||||
async def test_task_runner_custom_method_single(serve_instance):
|
||||
q = ray.remote(Router).remote(RoundRobinPolicy, {})
|
||||
q = ray.remote(Router).remote()
|
||||
|
||||
class NonBatcher:
|
||||
def a(self, _):
|
||||
@@ -144,7 +143,7 @@ async def test_task_runner_custom_method_single(serve_instance):
|
||||
|
||||
|
||||
async def test_task_runner_custom_method_batch(serve_instance):
|
||||
q = ray.remote(Router).remote(RoundRobinPolicy, {})
|
||||
q = ray.remote(Router).remote()
|
||||
|
||||
@serve.accept_batch
|
||||
class Batcher:
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
|
||||
from ray.serve.policy import (RandomPolicy, RoundRobinPolicy, PowerOfTwoPolicy,
|
||||
FixedPackingPolicy)
|
||||
from ray.serve.router import Router
|
||||
from ray.serve.request_params import RequestMetadata
|
||||
from ray.serve.utils import get_random_letters
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
@@ -29,6 +29,9 @@ def mock_task_runner():
|
||||
def get_all_calls(self):
|
||||
return self.queries
|
||||
|
||||
def clear_calls(self):
|
||||
self.queries = []
|
||||
|
||||
def ready(self):
|
||||
pass
|
||||
|
||||
@@ -41,7 +44,7 @@ def task_runner_mock_actor():
|
||||
|
||||
|
||||
async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote(RandomPolicy, {})
|
||||
q = ray.remote(Router).remote()
|
||||
q.set_traffic.remote("svc", {"backend-single-prod": 1.0})
|
||||
q.add_new_worker.remote("backend-single-prod", "replica-1",
|
||||
task_runner_mock_actor)
|
||||
@@ -57,7 +60,7 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
|
||||
|
||||
|
||||
async def test_slo(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote(RandomPolicy, {})
|
||||
q = ray.remote(Router).remote()
|
||||
await q.set_traffic.remote("svc", {"backend-slo": 1.0})
|
||||
|
||||
all_request_sent = []
|
||||
@@ -81,7 +84,7 @@ async def test_slo(serve_instance, task_runner_mock_actor):
|
||||
|
||||
|
||||
async def test_alter_backend(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote(RandomPolicy, {})
|
||||
q = ray.remote(Router).remote()
|
||||
|
||||
await q.set_traffic.remote("svc", {"backend-alter": 1})
|
||||
await q.add_new_worker.remote("backend-alter", "replica-1",
|
||||
@@ -99,7 +102,7 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor):
|
||||
|
||||
|
||||
async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote(RandomPolicy, {})
|
||||
q = ray.remote(Router).remote()
|
||||
|
||||
await q.set_traffic.remote("svc", {
|
||||
"backend-split": 0.5,
|
||||
@@ -121,88 +124,51 @@ async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
|
||||
assert [g.request_args[0] for g in got_work] == [1, 1]
|
||||
|
||||
|
||||
async def test_round_robin(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote(RoundRobinPolicy, {})
|
||||
|
||||
await q.set_traffic.remote("svc", {"backend-rr": 0.5, "backend-rr-2": 0.5})
|
||||
runner_1, runner_2 = [mock_task_runner() for _ in range(2)]
|
||||
|
||||
# NOTE: this is the only difference between the
|
||||
# test_split_traffic_random and test_round_robin
|
||||
await q.add_new_worker.remote("backend-rr", "replica-1", runner_1)
|
||||
await q.add_new_worker.remote("backend-rr-2", "replica-1", runner_2)
|
||||
|
||||
for _ in range(20):
|
||||
await q.enqueue_request.remote(RequestMetadata("svc", None), 1)
|
||||
|
||||
got_work = [
|
||||
await runner.get_recent_call.remote()
|
||||
for runner in (runner_1, runner_2)
|
||||
]
|
||||
assert [g.request_args[0] for g in got_work] == [1, 1]
|
||||
|
||||
|
||||
async def test_fixed_packing(serve_instance):
|
||||
packing_num = 4
|
||||
q = ray.remote(Router).remote(FixedPackingPolicy,
|
||||
{"packing_num": packing_num})
|
||||
await q.set_traffic.remote("svc", {
|
||||
"backend-fixed": 0.5,
|
||||
"backend-fixed-2": 0.5
|
||||
})
|
||||
|
||||
runner_1, runner_2 = (mock_task_runner() for _ in range(2))
|
||||
# both the backends will get equal number of queries
|
||||
# as it is packed round robin
|
||||
await q.add_new_worker.remote("backend-fixed", "replica-1", runner_1)
|
||||
await q.add_new_worker.remote("backend-fixed-2", "replica-1", runner_2)
|
||||
|
||||
for backend, runner in zip(["1", "2"], [runner_1, runner_2]):
|
||||
for _ in range(packing_num):
|
||||
input_value = "should-go-to-backend-{}".format(backend)
|
||||
await q.enqueue_request.remote(
|
||||
RequestMetadata("svc", None), input_value)
|
||||
all_calls = await runner.get_all_calls.remote()
|
||||
for call in all_calls:
|
||||
assert call.request_args[0] == input_value
|
||||
|
||||
|
||||
async def test_power_of_two_choices(serve_instance):
|
||||
q = ray.remote(Router).remote(PowerOfTwoPolicy, {})
|
||||
enqueue_futures = []
|
||||
|
||||
# First, fill the queue for backend-1 with 3 requests
|
||||
await q.set_traffic.remote("svc", {"backend-pow2": 1.0})
|
||||
for _ in range(3):
|
||||
future = q.enqueue_request.remote(RequestMetadata("svc", None), "1")
|
||||
enqueue_futures.append(future)
|
||||
|
||||
# Then, add a new backend, this backend should be filled next
|
||||
await q.set_traffic.remote("svc", {
|
||||
"backend-pow2": 0.5,
|
||||
"backend-pow2-2": 0.5
|
||||
})
|
||||
for _ in range(2):
|
||||
future = q.enqueue_request.remote(RequestMetadata("svc", None), "2")
|
||||
enqueue_futures.append(future)
|
||||
|
||||
runner_1, runner_2 = (mock_task_runner() for _ in range(2))
|
||||
await q.add_new_worker.remote("backend-pow2", "replica-1", runner_1)
|
||||
await q.add_new_worker.remote("backend-pow2-2", "replica-1", runner_2)
|
||||
|
||||
await asyncio.gather(*enqueue_futures)
|
||||
|
||||
assert len(await runner_1.get_all_calls.remote()) == 3
|
||||
assert len(await runner_2.get_all_calls.remote()) == 2
|
||||
|
||||
|
||||
async def test_queue_remove_replicas(serve_instance):
|
||||
class TestRouter(Router):
|
||||
def worker_queue_size(self, backend):
|
||||
return self.worker_queues["backend-remove"].qsize()
|
||||
|
||||
temp_actor = mock_task_runner()
|
||||
q = ray.remote(TestRouter).remote(RandomPolicy, {})
|
||||
q = ray.remote(TestRouter).remote()
|
||||
await q.add_new_worker.remote("backend-remove", "replica-1", temp_actor)
|
||||
await q.remove_worker.remote("backend-remove", "replica-1")
|
||||
assert ray.get(q.worker_queue_size.remote("backend")) == 0
|
||||
|
||||
|
||||
async def test_shard_key(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
|
||||
num_backends = 5
|
||||
traffic_dict = {}
|
||||
runners = [mock_task_runner() for _ in range(num_backends)]
|
||||
for i, runner in enumerate(runners):
|
||||
backend_name = "backend-split-" + str(i)
|
||||
traffic_dict[backend_name] = 1.0 / num_backends
|
||||
await q.add_new_worker.remote(backend_name, "replica-1", runner)
|
||||
await q.set_traffic.remote("svc", traffic_dict)
|
||||
|
||||
# Generate random shard keys and send one request for each.
|
||||
shard_keys = [get_random_letters() for _ in range(100)]
|
||||
for shard_key in shard_keys:
|
||||
await q.enqueue_request.remote(
|
||||
RequestMetadata("svc", None, shard_key=shard_key), shard_key)
|
||||
|
||||
# Log the shard keys that were assigned to each backend.
|
||||
runner_shard_keys = defaultdict(set)
|
||||
for i, runner in enumerate(runners):
|
||||
calls = await runner.get_all_calls.remote()
|
||||
for call in calls:
|
||||
runner_shard_keys[i].add(call.request_args[0])
|
||||
await runner.clear_calls.remote()
|
||||
|
||||
# Send queries with the same shard keys a second time.
|
||||
for shard_key in shard_keys:
|
||||
await q.enqueue_request.remote(
|
||||
RequestMetadata("svc", None, shard_key=shard_key), shard_key)
|
||||
|
||||
# Check that the requests were all mapped to the same backends.
|
||||
for i, runner in enumerate(runners):
|
||||
calls = await runner.get_all_calls.remote()
|
||||
for call in calls:
|
||||
assert call.request_args[0] in runner_shard_keys[i]
|
||||
|
||||
Reference in New Issue
Block a user