mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:16:19 +08:00
258 lines
8.5 KiB
Python
258 lines
8.5 KiB
Python
"""
|
|
Unit tests for the router class. Please don't add any test that will involve
|
|
controller or the backend worker, use mock if necessary.
|
|
"""
|
|
import asyncio
|
|
from collections import defaultdict
|
|
import os
|
|
|
|
import pytest
|
|
|
|
import ray
|
|
from ray.serve.controller import TrafficPolicy
|
|
from ray.serve.router import Query, ReplicaSet, RequestMetadata, Router
|
|
from ray.serve.utils import get_random_letters
|
|
from ray.test_utils import SignalActor
|
|
|
|
# Apply ``pytest.mark.asyncio`` to every test in this module so the
# ``async def`` tests below are driven by an event loop (pytest-asyncio).
pytestmark = pytest.mark.asyncio
|
|
|
|
|
|
@pytest.fixture
def ray_instance():
    """Initialize Ray (16 CPUs) for one test and shut it down afterwards.

    ``SERVE_LOG_DEBUG`` is set to ``"1"`` for the duration of the test so
    Serve emits debug logs. On teardown the variable is restored to its
    previous value (or removed if it was unset), so the setting does not leak
    into other tests that run later in the same process.
    """
    previous_log_debug = os.environ.get("SERVE_LOG_DEBUG")
    os.environ["SERVE_LOG_DEBUG"] = "1"  # Turns on debug log for tests
    try:
        ray.init(num_cpus=16)
        try:
            yield
        finally:
            # Runs on normal teardown and also if the generator is closed early.
            ray.shutdown()
    finally:
        # Undo the environment change even if ray.init() itself failed.
        if previous_log_debug is None:
            os.environ.pop("SERVE_LOG_DEBUG", None)
        else:
            os.environ["SERVE_LOG_DEBUG"] = previous_log_debug
|
|
|
|
|
|
def mock_task_runner():
    """Create a fresh ``TaskRunnerMock`` actor and return its handle.

    The actor stands in for a backend replica. It records each request
    routed to it as a ``Query`` instead of executing user code, so tests can
    inspect exactly what was dispatched.
    """

    class TaskRunnerMock:
        """Replica stand-in that records the queries it receives."""

        def __init__(self):
            # Most recent query seen; deliberately NOT reset by clear_calls.
            self._last_query = None
            # Every query seen since construction or the last clear_calls.
            self._history = []

        @ray.method(num_returns=2)
        async def handle_request(self, request_metadata, *args, **kwargs):
            """Record the request and reply with ``(b"", "DONE")``."""
            received = Query(args, kwargs, request_metadata)
            self._last_query = received
            self._history.append(received)
            return b"", "DONE"

        def get_recent_call(self):
            """Return the most recently recorded query, or ``None``."""
            return self._last_query

        def get_all_calls(self):
            """Return the queries recorded since the last ``clear_calls``."""
            return self._history

        def clear_calls(self):
            """Drop the recorded history; the most recent call is kept."""
            self._history = []

        def ready(self):
            """No-op; returns ``None`` (usable as a simple liveness ping)."""
            return None

    remote_runner_cls = ray.remote(num_cpus=0)(TaskRunnerMock)
    return remote_runner_cls.remote()
|
|
|
|
|
|
@pytest.fixture
def task_runner_mock_actor():
    """Provide each requesting test with its own ``TaskRunnerMock`` actor."""
    runner_handle = mock_task_runner()
    yield runner_handle
|
|
|
|
|
|
async def test_simple_endpoint_backend_pair(ray_instance, mock_controller,
                                            task_runner_mock_actor):
    """Route a request through a single backend owning all endpoint traffic."""
    router = ray.remote(Router).remote(mock_controller)
    await router.setup_in_async_loop.remote()

    # Propagate configs: all "svc" traffic goes to one backend, one replica.
    single_backend_policy = TrafficPolicy({"backend-single-prod": 1.0})
    await mock_controller.set_traffic.remote("svc", single_backend_policy)
    await mock_controller.add_new_replica.remote("backend-single-prod",
                                                 task_runner_mock_actor)

    # The router hands back a ref that should resolve to the replica's result.
    metadata = RequestMetadata(get_random_letters(10), "svc")
    result_ref = await router.assign_request.remote(metadata, 1)
    outcome = await result_ref
    assert outcome == "DONE"

    # The replica must have received exactly the arguments we sent.
    recorded = await task_runner_mock_actor.get_recent_call.remote()
    assert recorded.args[0] == 1
    assert recorded.kwargs == {}
|
|
|
|
|
|
async def test_changing_backend(ray_instance, mock_controller,
                                task_runner_mock_actor):
    """Re-pointing an endpoint's traffic at a new backend takes effect."""
    router = ray.remote(Router).remote(mock_controller)
    await router.setup_in_async_loop.remote()

    # Each stage moves all "svc" traffic to a new backend (served by the same
    # mock replica), sends one request, and checks it reached the replica.
    stages = [("backend-alter", 1), ("backend-alter-2", 2)]
    for backend_name, payload in stages:
        await mock_controller.set_traffic.remote(
            "svc", TrafficPolicy({backend_name: 1}))
        await mock_controller.add_new_replica.remote(backend_name,
                                                     task_runner_mock_actor)

        result_ref = await router.assign_request.remote(
            RequestMetadata(get_random_letters(10), "svc"), payload)
        await result_ref

        recorded = await task_runner_mock_actor.get_recent_call.remote()
        assert recorded.args[0] == payload
|
|
|
|
|
|
async def test_split_traffic_random(ray_instance, mock_controller,
                                    task_runner_mock_actor):
    """A 50/50 traffic split should send requests to both backends.

    The ``task_runner_mock_actor`` fixture is requested but not used; this
    test creates its own two runners, one per backend.
    """
    router = ray.remote(Router).remote(mock_controller)
    await router.setup_in_async_loop.remote()

    await mock_controller.set_traffic.remote(
        "svc", TrafficPolicy({
            "backend-split": 0.5,
            "backend-split-2": 0.5
        }))
    runner_1, runner_2 = [mock_task_runner() for _ in range(2)]
    await mock_controller.add_new_replica.remote("backend-split", runner_1)
    await mock_controller.add_new_replica.remote("backend-split-2", runner_2)

    # Assuming a 50% split, the probability that all 20 requests land on the
    # same backend is 2 * 0.5^20 ~= 2e-6, so this check is effectively
    # deterministic.
    object_refs = []
    for _ in range(20):
        ref = await router.assign_request.remote(
            RequestMetadata(get_random_letters(10), "svc"), 1)
        object_refs.append(ref)
    # Await the results rather than calling the blocking ``ray.get`` so the
    # event loop running this coroutine is not stalled.
    await asyncio.gather(*object_refs)

    got_work = [
        await runner.get_recent_call.remote()
        for runner in (runner_1, runner_2)
    ]
    # A runner that received no request reports ``None``; check explicitly so
    # the failure is a clear assertion instead of an AttributeError on .args.
    assert all(work is not None for work in got_work), got_work
    assert [work.args[0] for work in got_work] == [1, 1]
|
|
|
|
|
|
async def test_shard_key(ray_instance, mock_controller,
                         task_runner_mock_actor):
    """Requests with the same shard key are routed to the same backend."""
    router = ray.remote(Router).remote(mock_controller)
    await router.setup_in_async_loop.remote()

    num_backends = 5
    runners = [mock_task_runner() for _ in range(num_backends)]
    traffic_weights = {}
    for idx, runner in enumerate(runners):
        backend_name = "backend-split-{}".format(idx)
        traffic_weights[backend_name] = 1.0 / num_backends
        await mock_controller.add_new_replica.remote(backend_name, runner)
    await mock_controller.set_traffic.remote("svc",
                                             TrafficPolicy(traffic_weights))

    async def send_one_per_key(keys):
        # Issue one request per shard key, passing the key as the sole
        # positional argument so runners record which keys they served.
        for key in keys:
            result_ref = await router.assign_request.remote(
                RequestMetadata(get_random_letters(10), "svc", shard_key=key),
                key)
            await result_ref

    # First round: random shard keys, one request each.
    shard_keys = [get_random_letters() for _ in range(100)]
    await send_one_per_key(shard_keys)

    # Remember which keys each backend received, then reset the runners.
    keys_by_runner = defaultdict(set)
    for idx, runner in enumerate(runners):
        first_round = await runner.get_all_calls.remote()
        keys_by_runner[idx].update(call.args[0] for call in first_round)
        await runner.clear_calls.remote()

    # Second round: replay the same keys.
    await send_one_per_key(shard_keys)

    # Every replayed key must land on the backend that served it before.
    for idx, runner in enumerate(runners):
        for call in await runner.get_all_calls.remote():
            assert call.args[0] in keys_by_runner[idx]
|
|
|
|
|
|
async def test_replica_set(ray_instance):
    """Check that ``ReplicaSet`` caps in-flight queries per replica.

    Two ``MockWorker`` replicas are registered with
    ``max_concurrent_queries=1``. Each worker blocks inside
    ``handle_request`` until ``signal`` is sent. After two assignments both
    replicas are busy, so a third ``assign_replica`` call must stay pending
    until the workers are released.
    """
    signal = SignalActor.remote()

    @ray.remote(num_cpus=0)
    class MockWorker:
        """Replica stand-in that counts its requests and blocks on ``signal``."""

        # Class-level default; ``self._num_queries += 1`` below creates a
        # per-instance counter on first use.
        _num_queries = 0

        @ray.method(num_returns=2)
        async def handle_request(self, request):
            self._num_queries += 1
            # Hold the request open until the test sends the signal.
            await signal.wait.remote()
            return b"", "DONE"

        async def num_queries(self):
            """Return how many requests this replica has received."""
            return self._num_queries

    # We will test a scenario with two replicas in the replica set, each
    # allowed at most one in-flight query.
    rs = ReplicaSet("my_backend")
    workers = [MockWorker.remote() for _ in range(2)]
    rs.set_max_concurrent_queries(1)
    rs.update_worker_replicas(workers)

    # Send two queries. They should be assigned to replicas and then block
    # inside the workers on the signal actor.
    query = Query([], {}, RequestMetadata("request-id", "endpoint"))
    first_ref = await rs.assign_replica(query)
    second_ref = await rs.assign_replica(query)

    # Neither result is available yet because both workers are waiting on the
    # signal actor.
    with pytest.raises(ray.exceptions.GetTimeoutError):
        ray.get([first_ref, second_ref], timeout=1)

    # Each replica should have exactly one in-flight query. Let's make sure
    # the queries actually arrived there (note: this poll has no upper bound).
    for worker in workers:
        while await worker.num_queries.remote() != 1:
            await asyncio.sleep(1)

    # Let's try to send a third query.
    third_ref_pending_task = asyncio.get_event_loop().create_task(
        rs.assign_replica(query))
    # Both replicas are at their concurrency limit, so we should fail to
    # assign a replica and this task should still be pending after a while.
    await asyncio.sleep(0.2)
    assert not third_ref_pending_task.done()

    # Let's unblock the two workers.
    await signal.send.remote()
    assert await first_ref == "DONE"
    assert await second_ref == "DONE"

    # With a replica free again, the third request should be unblocked and
    # sent to one of the workers. This means we can now get its object ref.
    third_ref = await third_ref_pending_task

    # Now that we have the object ref, let's get its result.
    # NOTE(review): the signal was already sent above; this second send looks
    # redundant unless SignalActor resets after each send. Confirm.
    await signal.send.remote()
    assert await third_ref == "DONE"

    # Finally, make sure exactly one of the replicas processed the third
    # query: one worker saw two queries, the other saw one.
    num_queries_set = {(await worker.num_queries.remote())
                       for worker in workers}
    assert num_queries_set == {2, 1}
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Run this module's tests verbosely without output capture, and hand
    # pytest's exit status back to the shell.
    exit_code = pytest.main(["-v", "-s", __file__])
    sys.exit(exit_code)
|