mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 09:55:49 +08:00
[serve] Basic autoscaling policy (#9845)
This commit is contained in:
@@ -0,0 +1,125 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
from ray.serve.utils import logger
|
||||
|
||||
|
||||
class AutoscalingPolicy:
|
||||
"""Defines the interface for an autoscaling policy.
|
||||
|
||||
To add a new autoscaling policy, a class should be defined that provides
|
||||
this interface. The class may be stateful, in which case it may also want
|
||||
to provide a non-default constructor. However, this state will be lost when
|
||||
the controller recovers from a failure.
|
||||
"""
|
||||
__metaclass__ = ABCMeta
|
||||
|
||||
def __init__(self, config):
|
||||
"""Initialize the policy using the specified config dictionary."""
|
||||
self.config = config
|
||||
|
||||
@abstractmethod
|
||||
def scale(self, router_queue_lens, curr_replicas):
|
||||
"""Make a decision to scale backends.
|
||||
|
||||
Arguments:
|
||||
router_queue_lens (Dict[str, int]): map of routers to their most
|
||||
recent queue length of unsent queries for this backend.
|
||||
curr_replicas (int): The number of replicas that the backend
|
||||
currently has.
|
||||
|
||||
Returns:
|
||||
int The new number of replicas to scale this backend to.
|
||||
"""
|
||||
return curr_replicas
|
||||
|
||||
|
||||
class BasicAutoscalingPolicy(AutoscalingPolicy):
|
||||
"""The default autoscaling policy based on basic thresholds for scaling.
|
||||
|
||||
There is a minimum threshold for the average queue length in the cluster
|
||||
to scale up and a maximum threshold to scale down. Each period, a 'scale
|
||||
up' or 'scale down' decision is made. This decision must be made for a
|
||||
specified number of periods in a row before the number of replicas is
|
||||
actually scaled. See config options for more details.
|
||||
"""
|
||||
|
||||
def __init__(self, backend, config):
|
||||
self.backend = backend
|
||||
|
||||
# The minimum average queue length to trigger scaling up.
|
||||
self.scale_up_threshold = config.get("scale_up_threshold", 5)
|
||||
# The maximum average queue length to trigger scaling down.
|
||||
self.scale_down_threshold = config.get("scale_down_threshold", 1)
|
||||
# The number of replicas to be added when scaling up.
|
||||
self.scale_up_num_replicas = config.get("scale_up_num_replicas", 2)
|
||||
# The number of replicas to be removed when scaling down.
|
||||
self.scale_down_num_replicas = config.get("scale_down_num_replicas", 1)
|
||||
# The number of consecutive 'scale up' decisions that need to be made
|
||||
# before the number of replicas is actually increased.
|
||||
self.scale_up_consecutive_periods = config.get(
|
||||
"scale_up_consecutive_periods", 2)
|
||||
# The number of consecutive 'scale down' decisions that need to be made
|
||||
# before the number of replicas is actually decreased.
|
||||
self.scale_down_consecutive_periods = config.get(
|
||||
"scale_down_consecutive_periods", 5)
|
||||
|
||||
# Keeps track of previous decisions. Each time the load is above
|
||||
# 'scale_up_threshold', the counter is incremented and each time it is
|
||||
# below 'scale_down_threshold', the counter is decremented. When the
|
||||
# load is between the thresholds or a scaling decision is made, the
|
||||
# counter is reset to 0.
|
||||
self.decision_counter = 0
|
||||
|
||||
def scale(self, router_queue_lens, curr_replicas):
|
||||
queue_lens = list(router_queue_lens.values())
|
||||
if len(queue_lens) == 0:
|
||||
return -1
|
||||
|
||||
new_replicas = curr_replicas
|
||||
avg_queue_len = sum(queue_lens) / len(queue_lens)
|
||||
|
||||
# Scale up.
|
||||
if avg_queue_len > self.scale_up_threshold:
|
||||
# If the previous decision was to scale down (the counter was
|
||||
# negative), we reset it and then increment it (set to 1).
|
||||
# Otherwise, just increment.
|
||||
if self.decision_counter < 0:
|
||||
self.decision_counter = 1
|
||||
else:
|
||||
self.decision_counter += 1
|
||||
|
||||
# Only actually scale the replicas if we've made this decision for
|
||||
# 'scale_up_consecutive_periods' in a row.
|
||||
if self.decision_counter >= self.scale_up_consecutive_periods:
|
||||
# TODO(edoakes): should we be resetting the counter here?
|
||||
self.decision_counter = 0
|
||||
new_replicas = curr_replicas + self.scale_up_num_replicas
|
||||
logger.info("Increasing number of replicas for backend '{}' "
|
||||
"from {} to {}".format(self.backend, curr_replicas,
|
||||
new_replicas))
|
||||
|
||||
# Scale down.
|
||||
elif avg_queue_len < self.scale_down_threshold and curr_replicas > 1:
|
||||
# If the previous decision was to scale up (the counter was
|
||||
# positive), reset it to zero before decrementing.
|
||||
if self.decision_counter > 0:
|
||||
self.decision_counter = -1
|
||||
else:
|
||||
self.decision_counter -= 1
|
||||
|
||||
# Only actually scale the replicas if we've made this decision for
|
||||
# 'scale_down_consecutive_periods' in a row.
|
||||
if (self.decision_counter <=
|
||||
-self.scale_down_consecutive_periods + 1):
|
||||
# TODO(edoakes): should we be resetting the counter here?
|
||||
self.decision_counter = 0
|
||||
new_replicas = curr_replicas - self.scale_down_num_replicas
|
||||
logger.info("Decreasing number of replicas for backend '{}' "
|
||||
"from {} to {}".format(self.backend, curr_replicas,
|
||||
new_replicas))
|
||||
|
||||
# Do nothing.
|
||||
else:
|
||||
self.decision_counter = 0
|
||||
|
||||
return new_replicas
|
||||
@@ -42,6 +42,9 @@ class BatchQueue:
|
||||
if self.queue.qsize() == self.max_batch_size:
|
||||
self.full_batch_event.set()
|
||||
|
||||
def qsize(self):
|
||||
return self.queue.qsize()
|
||||
|
||||
async def wait_for_batch(self):
|
||||
"""Wait for batch respecting self.max_batch_size and self.timeout_s.
|
||||
|
||||
@@ -153,9 +156,9 @@ def ensure_async(func):
|
||||
class RayServeWorker:
|
||||
"""Handles requests with the provided callable."""
|
||||
|
||||
def __init__(self, name, replica_tag, _callable,
|
||||
def __init__(self, backend_tag, replica_tag, _callable,
|
||||
backend_config: BackendConfig, is_function, metric_client):
|
||||
self.name = name
|
||||
self.backend_tag = backend_tag
|
||||
self.replica_tag = replica_tag
|
||||
self.callable = _callable
|
||||
self.is_function = is_function
|
||||
@@ -183,7 +186,7 @@ class RayServeWorker:
|
||||
|
||||
self.restart_counter.labels(replica_tag=self.replica_tag).add()
|
||||
|
||||
self.loop_task = asyncio.get_event_loop().create_task(self.main_loop())
|
||||
asyncio.get_event_loop().create_task(self.main_loop())
|
||||
|
||||
def get_runner_method(self, request_item):
|
||||
method_name = request_item.call_method
|
||||
@@ -348,7 +351,8 @@ class RayServeWorker:
|
||||
async def handle_request(self, request: Union[Query, bytes]):
|
||||
if isinstance(request, bytes):
|
||||
request = Query.ray_deserialize(request)
|
||||
logger.debug("Worker {} got request {}".format(self.name, request))
|
||||
logger.debug("Worker {} got request {}".format(self.replica_tag,
|
||||
request))
|
||||
request.async_future = asyncio.get_event_loop().create_future()
|
||||
self.batch_queue.put(request)
|
||||
return await request.async_future
|
||||
|
||||
@@ -30,6 +30,7 @@ class BackendConfig:
|
||||
self.batch_wait_timeout = config_dict.pop("batch_wait_timeout", 0)
|
||||
self.max_concurrent_queries = config_dict.pop("max_concurrent_queries",
|
||||
None)
|
||||
self.autoscaling_config = config_dict.pop("autoscaling", None)
|
||||
|
||||
if self.max_concurrent_queries is None:
|
||||
# Model serving mode: if the servable is blocking and the wait
|
||||
@@ -136,7 +137,7 @@ class ReplicaConfig:
|
||||
raise ValueError("Specifying max_restarts in "
|
||||
"actor_init_args is not allowed.")
|
||||
else:
|
||||
num_cpus = self.ray_actor_options.get("num_cpus", 0)
|
||||
num_cpus = self.ray_actor_options.get("num_cpus", 1)
|
||||
if not isinstance(num_cpus, (int, float)):
|
||||
raise TypeError(
|
||||
"num_cpus in ray_actor_options must be an int or a float.")
|
||||
|
||||
@@ -6,6 +6,7 @@ import time
|
||||
|
||||
import ray
|
||||
import ray.cloudpickle as pickle
|
||||
from ray.serve.autoscaling_policy import BasicAutoscalingPolicy
|
||||
from ray.serve.backend_worker import create_backend_worker
|
||||
from ray.serve.constants import (ASYNC_CONCURRENCY, SERVE_PROXY_NAME,
|
||||
SERVE_METRIC_SINK_NAME)
|
||||
@@ -101,6 +102,8 @@ class ServeController:
|
||||
self.routes = dict()
|
||||
# backend -> BackendInfo.
|
||||
self.backends = dict()
|
||||
# backend -> AutoscalingPolicy
|
||||
self.autoscaling_policies = dict()
|
||||
# backend -> replica_tags.
|
||||
self.replicas = defaultdict(list)
|
||||
# replicas that should be started if recovering from a checkpoint.
|
||||
@@ -118,6 +121,8 @@ class ServeController:
|
||||
# Dictionary of backend tag to dictionaries of replica tag to worker.
|
||||
# TODO(edoakes): consider removing this and just using the names.
|
||||
self.workers = defaultdict(dict)
|
||||
# Dictionary of backend_tag -> router_name -> most recent queue length.
|
||||
self.backend_stats = defaultdict(lambda: defaultdict(dict))
|
||||
|
||||
# Used to ensure that only a single state-changing operation happens
|
||||
# at any given time.
|
||||
@@ -182,6 +187,7 @@ class ServeController:
|
||||
node_resource: 0.01
|
||||
},
|
||||
).remote(
|
||||
node_id,
|
||||
self.http_host,
|
||||
self.http_port,
|
||||
instance_name=self.instance_name)
|
||||
@@ -324,6 +330,9 @@ class ServeController:
|
||||
for router in self.routers.values()
|
||||
])
|
||||
await self.broadcast_backend_config(backend)
|
||||
if info.backend_config.autoscaling_config is not None:
|
||||
self.autoscaling_policies[backend] = BasicAutoscalingPolicy(
|
||||
backend, info.backend_config.autoscaling_config)
|
||||
|
||||
# Push configuration state to the routers.
|
||||
await asyncio.gather(*[
|
||||
@@ -344,8 +353,21 @@ class ServeController:
|
||||
|
||||
self.write_lock.release()
|
||||
|
||||
async def do_autoscale(self):
|
||||
for backend in self.backends:
|
||||
if backend not in self.autoscaling_policies:
|
||||
continue
|
||||
|
||||
new_num_replicas = self.autoscaling_policies[backend].scale(
|
||||
self.backend_stats[backend],
|
||||
self.backends[backend].backend_config.num_replicas)
|
||||
if new_num_replicas > 0:
|
||||
await self.update_backend_config(
|
||||
backend, {"num_replicas": new_num_replicas})
|
||||
|
||||
async def run_control_loop(self):
|
||||
while True:
|
||||
await self.do_autoscale()
|
||||
async with self.write_lock:
|
||||
self._start_routers_if_needed()
|
||||
checkpoint_required = self._stop_routers_if_needed()
|
||||
@@ -730,6 +752,10 @@ class ServeController:
|
||||
# and the configuration for the backends.
|
||||
self.backends[backend_tag] = BackendInfo(
|
||||
backend_worker, backend_config, replica_config)
|
||||
if backend_config.autoscaling_config is not None:
|
||||
self.autoscaling_policies[
|
||||
backend_tag] = BasicAutoscalingPolicy(
|
||||
backend_tag, backend_config.autoscaling_config)
|
||||
|
||||
self._scale_replicas(backend_tag, backend_config.num_replicas)
|
||||
|
||||
@@ -769,6 +795,8 @@ class ServeController:
|
||||
|
||||
# Remove the backend's metadata.
|
||||
del self.backends[backend_tag]
|
||||
if backend_tag in self.autoscaling_policies:
|
||||
del self.autoscaling_policies[backend_tag]
|
||||
|
||||
# Add the intention to remove the backend from the router.
|
||||
self.backends_to_remove.append(backend_tag)
|
||||
@@ -840,3 +868,8 @@ class ServeController:
|
||||
for replica in replica_dict.values():
|
||||
ray.kill(replica, no_restart=True)
|
||||
self.kv_store.delete(CHECKPOINT_KEY)
|
||||
|
||||
async def report_queue_lengths(self, router_name, queue_lengths):
|
||||
# TODO: remove old router stats when removing them.
|
||||
for backend, queue_length in queue_lengths.items():
|
||||
self.backend_stats[backend][router_name] = queue_length
|
||||
|
||||
@@ -27,7 +27,7 @@ class HTTPProxy:
|
||||
# blocks forever
|
||||
"""
|
||||
|
||||
async def fetch_config_from_controller(self, instance_name=None):
|
||||
async def fetch_config_from_controller(self, name, instance_name=None):
|
||||
assert ray.is_initialized()
|
||||
controller = serve.api._get_controller()
|
||||
|
||||
@@ -43,7 +43,7 @@ class HTTPProxy:
|
||||
label_names=("route", ))
|
||||
|
||||
self.router = Router()
|
||||
await self.router.setup(instance_name)
|
||||
await self.router.setup(name, instance_name)
|
||||
|
||||
def set_route_table(self, route_table):
|
||||
self.route_table = route_table
|
||||
@@ -170,10 +170,10 @@ class HTTPProxy:
|
||||
|
||||
@ray.remote
|
||||
class HTTPProxyActor:
|
||||
async def __init__(self, host, port, instance_name=None):
|
||||
async def __init__(self, name, host, port, instance_name=None):
|
||||
serve.init(name=instance_name)
|
||||
self.app = HTTPProxy()
|
||||
await self.app.fetch_config_from_controller(instance_name)
|
||||
await self.app.fetch_config_from_controller(name, instance_name)
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
|
||||
+29
-12
@@ -12,9 +12,11 @@ from ray.exceptions import RayTaskError
|
||||
import ray
|
||||
from ray import serve
|
||||
from ray.serve.metric import MetricClient
|
||||
from ray.serve.policy import RandomEndpointPolicy
|
||||
from ray.serve.endpoint_policy import RandomEndpointPolicy
|
||||
from ray.serve.utils import logger, chain_future
|
||||
|
||||
REPORT_QUEUE_LENGTH_PERIOD_S = 1.0
|
||||
|
||||
|
||||
class Query:
|
||||
def __init__(
|
||||
@@ -87,7 +89,7 @@ def _make_future_unwrapper(client_futures: List[asyncio.Future],
|
||||
class Router:
|
||||
"""A router that routes request to available workers."""
|
||||
|
||||
async def setup(self, instance_name=None):
|
||||
async def setup(self, name, instance_name=None):
|
||||
# Note: Several queues are used in the router
|
||||
# - When a request come in, it's placed inside its corresponding
|
||||
# endpoint_queue.
|
||||
@@ -98,6 +100,8 @@ class Router:
|
||||
# handles are dequed during the second stage of flush operation,
|
||||
# which assign queries in buffer_queue to actor handle.
|
||||
|
||||
self.name = name
|
||||
|
||||
# -- Queues -- #
|
||||
|
||||
# endpoint_name -> request queue
|
||||
@@ -117,8 +121,8 @@ class Router:
|
||||
self.backend_info = dict()
|
||||
# replica tag -> worker_handle
|
||||
self.replicas = dict()
|
||||
# replica_tag -> concurrent queries counter
|
||||
self.queries_counter = defaultdict(lambda: 0)
|
||||
# backend_name -> replica_tag -> concurrent queries counter
|
||||
self.queries_counter = defaultdict(lambda: defaultdict(int))
|
||||
|
||||
# -- Synchronization -- #
|
||||
|
||||
@@ -137,23 +141,25 @@ class Router:
|
||||
# them from the controller so that the router can transparently recover
|
||||
# from failure.
|
||||
serve.init(name=instance_name)
|
||||
controller = serve.api._get_controller()
|
||||
self.controller = serve.api._get_controller()
|
||||
|
||||
traffic_policies = ray.get(controller.get_traffic_policies.remote())
|
||||
traffic_policies = ray.get(
|
||||
self.controller.get_traffic_policies.remote())
|
||||
for endpoint, traffic_policy in traffic_policies.items():
|
||||
await self.set_traffic(endpoint, traffic_policy)
|
||||
|
||||
backend_dict = ray.get(controller.get_all_worker_handles.remote())
|
||||
backend_dict = ray.get(self.controller.get_all_worker_handles.remote())
|
||||
for backend_tag, replica_dict in backend_dict.items():
|
||||
for replica_tag, worker in replica_dict.items():
|
||||
await self.add_new_worker(backend_tag, replica_tag, worker)
|
||||
|
||||
backend_configs = ray.get(controller.get_backend_configs.remote())
|
||||
backend_configs = ray.get(self.controller.get_backend_configs.remote())
|
||||
for backend, backend_config in backend_configs.items():
|
||||
await self.set_backend_config(backend, backend_config)
|
||||
|
||||
# -- Metric Registration -- #
|
||||
[metric_exporter] = ray.get(controller.get_metric_exporter.remote())
|
||||
[metric_exporter] = ray.get(
|
||||
self.controller.get_metric_exporter.remote())
|
||||
self.metric_client = MetricClient(metric_exporter)
|
||||
self.num_router_requests = self.metric_client.new_counter(
|
||||
"num_router_requests",
|
||||
@@ -170,6 +176,8 @@ class Router:
|
||||
"from backend."),
|
||||
label_names=("backend", ))
|
||||
|
||||
asyncio.get_event_loop().create_task(self.report_queue_lengths())
|
||||
|
||||
async def enqueue_request(self, request_meta, *request_args,
|
||||
**request_kwargs):
|
||||
endpoint = request_meta.endpoint
|
||||
@@ -324,7 +332,7 @@ class Router:
|
||||
except RayTaskError as error:
|
||||
self.num_error_backend_request.labels(backend=backend).add()
|
||||
result = error
|
||||
self.queries_counter[backend_replica_tag] -= 1
|
||||
self.queries_counter[backend][backend_replica_tag] -= 1
|
||||
await self.mark_worker_idle(backend, backend_replica_tag)
|
||||
logger.debug("Got result in {:.2f}s".format(time.time() - start))
|
||||
return result
|
||||
@@ -347,7 +355,7 @@ class Router:
|
||||
max_queries = 1
|
||||
if backend in self.backend_info:
|
||||
max_queries = self.backend_info[backend].max_concurrent_queries
|
||||
curr_queries = self.queries_counter[backend_replica_tag]
|
||||
curr_queries = self.queries_counter[backend][backend_replica_tag]
|
||||
if curr_queries >= max_queries:
|
||||
# Put the worker back to the queue.
|
||||
worker_queue.appendleft(backend_replica_tag)
|
||||
@@ -359,7 +367,7 @@ class Router:
|
||||
continue
|
||||
|
||||
request = buffer_queue.pop(0)
|
||||
self.queries_counter[backend_replica_tag] += 1
|
||||
self.queries_counter[backend][backend_replica_tag] += 1
|
||||
future = asyncio.get_event_loop().create_task(
|
||||
self._do_query(backend, backend_replica_tag, request))
|
||||
|
||||
@@ -368,3 +376,12 @@ class Router:
|
||||
chain_future(future, request.async_future)
|
||||
|
||||
worker_queue.appendleft(backend_replica_tag)
|
||||
|
||||
async def report_queue_lengths(self):
|
||||
while True:
|
||||
self.controller.report_queue_lengths.remote(
|
||||
self.name, {
|
||||
backend: len(q)
|
||||
for backend, q in self.backend_queues.items()
|
||||
})
|
||||
await asyncio.sleep(REPORT_QUEUE_LENGTH_PERIOD_S)
|
||||
|
||||
@@ -50,7 +50,7 @@ async def test_runner_wraps_error():
|
||||
|
||||
async def test_runner_actor(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
await q.setup.remote("")
|
||||
|
||||
def echo(flask_request, i=None):
|
||||
return i
|
||||
@@ -72,7 +72,7 @@ async def test_runner_actor(serve_instance):
|
||||
|
||||
async def test_ray_serve_mixin(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
await q.setup.remote("")
|
||||
|
||||
CONSUMER_NAME = "runner-cls"
|
||||
PRODUCER_NAME = "prod-cls"
|
||||
@@ -98,7 +98,7 @@ async def test_ray_serve_mixin(serve_instance):
|
||||
|
||||
async def test_task_runner_check_context(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
await q.setup.remote("")
|
||||
|
||||
def echo(flask_request, i=None):
|
||||
# Accessing the flask_request without web context should throw.
|
||||
@@ -120,7 +120,7 @@ async def test_task_runner_check_context(serve_instance):
|
||||
|
||||
async def test_task_runner_custom_method_single(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
await q.setup.remote("")
|
||||
|
||||
class NonBatcher:
|
||||
def a(self, _):
|
||||
@@ -155,7 +155,7 @@ async def test_task_runner_custom_method_single(serve_instance):
|
||||
|
||||
async def test_task_runner_custom_method_batch(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
await q.setup.remote("")
|
||||
|
||||
@serve.accept_batch
|
||||
class Batcher:
|
||||
@@ -221,7 +221,7 @@ async def test_task_runner_custom_method_batch(serve_instance):
|
||||
|
||||
async def test_task_runner_perform_batch(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
await q.setup.remote("")
|
||||
|
||||
def batcher(*args, **kwargs):
|
||||
return [serve.context.batch_size] * serve.context.batch_size
|
||||
@@ -252,7 +252,7 @@ async def test_task_runner_perform_batch(serve_instance):
|
||||
|
||||
async def test_task_runner_perform_async(serve_instance):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
await q.setup.remote("")
|
||||
|
||||
@ray.remote
|
||||
class Barrier:
|
||||
|
||||
@@ -50,7 +50,7 @@ def task_runner_mock_actor():
|
||||
|
||||
async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
await q.setup.remote("")
|
||||
|
||||
q.set_traffic.remote("svc", TrafficPolicy({"backend-single-prod": 1.0}))
|
||||
q.add_new_worker.remote("backend-single-prod", "replica-1",
|
||||
@@ -68,7 +68,7 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
|
||||
|
||||
async def test_slo(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
await q.setup.remote("")
|
||||
await q.set_traffic.remote("svc", TrafficPolicy({"backend-slo": 1.0}))
|
||||
|
||||
all_request_sent = []
|
||||
@@ -93,7 +93,7 @@ async def test_slo(serve_instance, task_runner_mock_actor):
|
||||
|
||||
async def test_alter_backend(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
await q.setup.remote("")
|
||||
|
||||
await q.set_traffic.remote("svc", TrafficPolicy({"backend-alter": 1}))
|
||||
await q.add_new_worker.remote("backend-alter", "replica-1",
|
||||
@@ -112,7 +112,7 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor):
|
||||
|
||||
async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
await q.setup.remote("")
|
||||
|
||||
await q.set_traffic.remote(
|
||||
"svc", TrafficPolicy({
|
||||
@@ -142,7 +142,7 @@ async def test_queue_remove_replicas(serve_instance):
|
||||
|
||||
temp_actor = mock_task_runner()
|
||||
q = ray.remote(TestRouter).remote()
|
||||
await q.setup.remote()
|
||||
await q.setup.remote("")
|
||||
await q.add_new_worker.remote("backend-remove", "replica-1", temp_actor)
|
||||
await q.remove_worker.remote("backend-remove", "replica-1")
|
||||
assert ray.get(q.worker_queue_size.remote("backend")) == 0
|
||||
@@ -150,7 +150,7 @@ async def test_queue_remove_replicas(serve_instance):
|
||||
|
||||
async def test_shard_key(serve_instance, task_runner_mock_actor):
|
||||
q = ray.remote(Router).remote()
|
||||
await q.setup.remote()
|
||||
await q.setup.remote("")
|
||||
|
||||
num_backends = 5
|
||||
traffic_dict = {}
|
||||
@@ -211,7 +211,7 @@ async def test_router_use_max_concurrency(serve_instance):
|
||||
|
||||
worker = MockWorker.remote()
|
||||
q = ray.remote(VisibleRouter).remote()
|
||||
await q.setup.remote()
|
||||
await q.setup.remote("")
|
||||
backend_name = "max-concurrent-test"
|
||||
config = BackendConfig({"max_concurrent_queries": 1})
|
||||
await q.set_traffic.remote("svc", TrafficPolicy({backend_name: 1.0}))
|
||||
@@ -229,7 +229,8 @@ async def test_router_use_max_concurrency(serve_instance):
|
||||
# Let's retrieve the router internal state
|
||||
queries_counter, backend_queues = await q.get_queues.remote()
|
||||
# There should be just one inflight request
|
||||
assert queries_counter["max-concurrent-test:replica-tag"] == 1
|
||||
assert queries_counter[backend_name][
|
||||
"max-concurrent-test:replica-tag"] == 1
|
||||
# The second query is buffered
|
||||
assert len(backend_queues["max-concurrent-test"]) == 1
|
||||
|
||||
@@ -240,7 +241,8 @@ async def test_router_use_max_concurrency(serve_instance):
|
||||
# The internal state of router should have changed.
|
||||
queries_counter, backend_queues = await q.get_queues.remote()
|
||||
# There should still be one inflight request
|
||||
assert queries_counter["max-concurrent-test:replica-tag"] == 1
|
||||
assert queries_counter[backend_name][
|
||||
"max-concurrent-test:replica-tag"] == 1
|
||||
# But there shouldn't be any queries in the queue
|
||||
assert len(backend_queues["max-concurrent-test"]) == 0
|
||||
|
||||
@@ -250,7 +252,8 @@ async def test_router_use_max_concurrency(serve_instance):
|
||||
|
||||
# Checking the internal state of the router one more time
|
||||
queries_counter, backend_queues = await q.get_queues.remote()
|
||||
assert queries_counter["max-concurrent-test:replica-tag"] == 0
|
||||
assert queries_counter[backend_name][
|
||||
"max-concurrent-test:replica-tag"] == 0
|
||||
assert len(backend_queues["max-concurrent-test"]) == 0
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user