mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 16:13:54 +08:00
231 lines
8.0 KiB
Python
231 lines
8.0 KiB
Python
from collections import defaultdict
|
|
|
|
import pytest
|
|
import ray
|
|
|
|
from ray.serve.controller import TrafficPolicy
|
|
from ray.serve.router import Router, Query
|
|
from ray.serve.request_params import RequestMetadata
|
|
from ray.serve.utils import get_random_letters
|
|
from ray.test_utils import SignalActor
|
|
from ray.serve.config import BackendConfig
|
|
|
|
# Apply the asyncio marker to every test in this module so that the
# coroutine test functions below are collected and run as async tests.
pytestmark = pytest.mark.asyncio
|
|
|
|
|
|
def mock_task_runner():
    """Start a stand-in replica actor and return its handle.

    The actor records every request passed to ``handle_request`` so that
    tests can inspect what was delivered to it.
    """

    @ray.remote(num_cpus=0)
    class TaskRunnerMock:
        """Fake worker that remembers the requests it receives."""

        def __init__(self):
            # Full history of requests, plus the single most recent one.
            self.queries = []
            self.query = None

        async def handle_request(self, request):
            """Record ``request`` and reply with the constant ``"DONE"``."""
            # A request that arrives as bytes is decoded back into a Query
            # before it is recorded; anything else is stored as received.
            decoded = (Query.ray_deserialize(request)
                       if isinstance(request, bytes) else request)
            self.query = decoded
            self.queries.append(decoded)
            return "DONE"

        def get_recent_call(self):
            """Return the last recorded request (None before any call)."""
            return self.query

        def get_all_calls(self):
            """Return the list of recorded requests."""
            return self.queries

        def clear_calls(self):
            """Empty the history; the most recent request is left as is."""
            self.queries = []

        def ready(self):
            """Do nothing (presumably a readiness probe -- TODO confirm)."""
            pass

    return TaskRunnerMock.remote()
|
|
|
|
|
|
@pytest.fixture
def task_runner_mock_actor():
    """Provide a test with the handle of a newly started mock runner."""
    actor_handle = mock_task_runner()
    yield actor_handle
|
|
|
|
|
|
async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
    """Send one request through one backend with a single replica.

    Checks both that the caller receives the worker's reply and that the
    worker saw the arguments that were enqueued.
    """
    q = ray.remote(Router).remote()
    await q.setup.remote("")

    # Await the configuration calls, as the other tests in this module do,
    # so they have finished (and any error they raise has surfaced) before
    # the request is enqueued.
    await q.set_traffic.remote(
        "svc", TrafficPolicy({"backend-single-prod": 1.0}))
    await q.add_new_worker.remote("backend-single-prod", "replica-1",
                                  task_runner_mock_actor)

    # Make sure we get the request result back
    result = await q.enqueue_request.remote(RequestMetadata("svc", None), 1)
    assert result == "DONE"

    # Make sure it's the right request
    got_work = await task_runner_mock_actor.get_recent_call.remote()
    assert got_work.request_args[0] == 1
    assert got_work.request_kwargs == {}
|
|
|
|
|
|
async def test_alter_backend(serve_instance, task_runner_mock_actor):
    """Repoint the endpoint at a second backend and check delivery.

    The same mock runner serves as the replica of both backends; after each
    switch the request just sent must be the runner's most recent call.
    """
    q = ray.remote(Router).remote()
    await q.setup.remote("")

    # Each phase: route all traffic to the backend, register the replica,
    # send one request, then verify the runner saw that request's payload.
    phases = (("backend-alter", 1), ("backend-alter-2", 2))
    for backend_tag, payload in phases:
        await q.set_traffic.remote("svc", TrafficPolicy({backend_tag: 1}))
        await q.add_new_worker.remote(backend_tag, "replica-1",
                                      task_runner_mock_actor)
        await q.enqueue_request.remote(RequestMetadata("svc", None), payload)
        latest = await task_runner_mock_actor.get_recent_call.remote()
        assert latest.request_args[0] == payload
|
|
|
|
|
|
async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
    """Split traffic 50/50 and check that both backends receive requests."""
    q = ray.remote(Router).remote()
    await q.setup.remote("")

    even_split = TrafficPolicy({
        "backend-split": 0.5,
        "backend-split-2": 0.5
    })
    await q.set_traffic.remote("svc", even_split)

    runner_1 = mock_task_runner()
    runner_2 = mock_task_runner()
    await q.add_new_worker.remote("backend-split", "replica-1", runner_1)
    await q.add_new_worker.remote("backend-split-2", "replica-1", runner_2)

    # Assuming a 50% split, the probability that all 20 requests go to a
    # single queue is 0.5^20 ~ 1e-6.
    for _ in range(20):
        await q.enqueue_request.remote(RequestMetadata("svc", None), 1)

    # Fetch the latest call from each runner first, then compare payloads.
    last_calls = []
    for runner in (runner_1, runner_2):
        last_calls.append(await runner.get_recent_call.remote())
    first_args = [call.request_args[0] for call in last_calls]
    assert first_args == [1, 1]
|
|
|
|
|
|
async def test_queue_remove_replicas(serve_instance):
    """Removing a backend's only replica leaves no workers for it."""

    class TestRouter(Router):
        """Router subclass exposing the number of workers of a backend."""

        def worker_queue_size(self, backend):
            # Look up the backend that was asked for. The previous version
            # ignored this argument and always inspected "backend-remove".
            return len(self.worker_queues[backend])

    temp_actor = mock_task_runner()
    q = ray.remote(TestRouter).remote()
    await q.setup.remote("")
    await q.add_new_worker.remote("backend-remove", "replica-1", temp_actor)
    await q.remove_worker.remote("backend-remove", "replica-1")

    # Query the backend this test actually populated, and await the result
    # like the rest of this module rather than blocking in ray.get().
    assert await q.worker_queue_size.remote("backend-remove") == 0
|
|
|
|
|
|
async def test_shard_key(serve_instance, task_runner_mock_actor):
    """Requests that repeat a shard key must land on the same backend."""
    q = ray.remote(Router).remote()
    await q.setup.remote("")

    num_backends = 5
    runners = [mock_task_runner() for _ in range(num_backends)]
    traffic_dict = {}
    for index, runner in enumerate(runners):
        backend_name = "backend-split-" + str(index)
        traffic_dict[backend_name] = 1.0 / num_backends
        await q.add_new_worker.remote(backend_name, "replica-1", runner)
    await q.set_traffic.remote("svc", TrafficPolicy(traffic_dict))

    async def send_one_request_per_key(keys):
        # The shard key doubles as the request argument so that it can be
        # read back from the calls recorded by the runners.
        for key in keys:
            metadata = RequestMetadata("svc", None, shard_key=key)
            await q.enqueue_request.remote(metadata, key)

    # Generate random shard keys and send one request for each.
    shard_keys = [get_random_letters() for _ in range(100)]
    await send_one_request_per_key(shard_keys)

    # Log the shard keys that were assigned to each backend.
    runner_shard_keys = defaultdict(set)
    for index, runner in enumerate(runners):
        for call in await runner.get_all_calls.remote():
            runner_shard_keys[index].add(call.request_args[0])
        await runner.clear_calls.remote()

    # Send queries with the same shard keys a second time.
    await send_one_request_per_key(shard_keys)

    # Check that the requests were all mapped to the same backends.
    for index, runner in enumerate(runners):
        repeated_calls = await runner.get_all_calls.remote()
        assert all(call.request_args[0] in runner_shard_keys[index]
                   for call in repeated_calls)
|
|
|
|
|
|
async def test_router_use_max_concurrency(serve_instance):
    """With max_concurrent_queries=1, only one query is in flight at a time.

    Two queries are sent to a worker whose handler blocks on a signal. The
    router's counters are inspected before and after each query is released.
    """
    signal = SignalActor.remote()

    @ray.remote
    class MockWorker:
        """Worker whose handler waits on the shared signal before replying."""

        async def handle_request(self, request):
            await signal.wait.remote()
            return "DONE"

        def ready(self):
            pass

    class VisibleRouter(Router):
        """Router subclass that exposes its internal bookkeeping."""

        def get_queues(self):
            return self.queries_counter, self.backend_queues

    worker = MockWorker.remote()
    q = ray.remote(VisibleRouter).remote()
    await q.setup.remote("")
    backend_name = "max-concurrent-test"
    config = BackendConfig({"max_concurrent_queries": 1})
    await q.set_traffic.remote("svc", TrafficPolicy({backend_name: 1.0}))
    await q.add_new_worker.remote(backend_name, "replica-tag", worker)
    await q.set_backend_config.remote(backend_name, config)

    async def router_state():
        # Returns (count recorded for the replica, queued query count).
        counter, queues = await q.get_queues.remote()
        in_flight = counter[backend_name]["max-concurrent-test:replica-tag"]
        buffered = len(queues["max-concurrent-test"])
        return in_flight, buffered

    # We send over two queries
    first_query = q.enqueue_request.remote(RequestMetadata("svc", None), 1)
    second_query = q.enqueue_request.remote(RequestMetadata("svc", None), 1)

    # Neither query should be available yet
    with pytest.raises(ray.exceptions.GetTimeoutError):
        ray.get([first_query, second_query], timeout=0.2)

    # Just one request is in flight; the second query is buffered.
    in_flight, buffered = await router_state()
    assert in_flight == 1
    assert buffered == 1

    # Let's unblock the first query
    await signal.send.remote(clear=True)
    assert await first_query == "DONE"

    # The router state should have changed: still one request in flight,
    # but nothing left in the queue.
    in_flight, buffered = await router_state()
    assert in_flight == 1
    assert buffered == 0

    # Unblocking the second query
    await signal.send.remote(clear=True)
    assert await second_query == "DONE"

    # Checking the internal state of the router one more time
    in_flight, buffered = await router_state()
    assert in_flight == 0
    assert buffered == 0
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Run this file's tests under pytest (-v -s) and exit with its result.
    exit_code = pytest.main(["-v", "-s", __file__])
    sys.exit(exit_code)
|