[serve] Basic autoscaling policy (#9845)

This commit is contained in:
Edward Oakes
2020-08-05 21:11:35 -05:00
committed by GitHub
parent 1760586628
commit 38408574c4
9 changed files with 221 additions and 38 deletions
+125
View File
@@ -0,0 +1,125 @@
from abc import ABCMeta, abstractmethod
from ray.serve.utils import logger
class AutoscalingPolicy(metaclass=ABCMeta):
    """Defines the interface for an autoscaling policy.

    To add a new autoscaling policy, a class should be defined that provides
    this interface. The class may be stateful, in which case it may also want
    to provide a non-default constructor. However, this state will be lost when
    the controller recovers from a failure.

    The class is declared with ``metaclass=ABCMeta`` (the Python 3 spelling)
    so that ``@abstractmethod`` is actually enforced: a subclass that does not
    override ``scale`` cannot be instantiated.
    """

    def __init__(self, config):
        """Initialize the policy using the specified config dictionary.

        Arguments:
            config: The policy configuration; stored unchanged on
                ``self.config``.
        """
        self.config = config

    @abstractmethod
    def scale(self, router_queue_lens, curr_replicas):
        """Make a decision to scale backends.

        Arguments:
            router_queue_lens (Dict[str, int]): map of routers to their most
                recent queue length of unsent queries for this backend.
            curr_replicas (int): The number of replicas that the backend
                currently has.

        Returns:
            int The new number of replicas to scale this backend to.
        """
        # Default body for subclasses that delegate via super(): no change.
        return curr_replicas
class BasicAutoscalingPolicy(AutoscalingPolicy):
    """The default autoscaling policy based on basic thresholds for scaling.

    There is a minimum threshold for the average queue length in the cluster
    to scale up and a maximum threshold to scale down. Each period, a 'scale
    up' or 'scale down' decision is made. This decision must be made for a
    specified number of periods in a row before the number of replicas is
    actually scaled. See config options for more details.
    """

    def __init__(self, backend, config):
        """Read the thresholds from `config`, falling back to defaults.

        Arguments:
            backend: Identifier of the backend this policy scales; only used
                in log messages.
            config: Mapping of option name to value. Missing options take the
                defaults shown below.
        """
        self.backend = backend
        # The minimum average queue length to trigger scaling up.
        self.scale_up_threshold = config.get("scale_up_threshold", 5)
        # The maximum average queue length to trigger scaling down.
        self.scale_down_threshold = config.get("scale_down_threshold", 1)
        # The number of replicas to be added when scaling up.
        self.scale_up_num_replicas = config.get("scale_up_num_replicas", 2)
        # The number of replicas to be removed when scaling down.
        self.scale_down_num_replicas = config.get("scale_down_num_replicas", 1)
        # The number of consecutive 'scale up' decisions that need to be made
        # before the number of replicas is actually increased.
        self.scale_up_consecutive_periods = config.get(
            "scale_up_consecutive_periods", 2)
        # The number of consecutive 'scale down' decisions that need to be made
        # before the number of replicas is actually decreased.
        self.scale_down_consecutive_periods = config.get(
            "scale_down_consecutive_periods", 5)

        # Keeps track of previous decisions. Each time the load is above
        # 'scale_up_threshold', the counter is incremented and each time it is
        # below 'scale_down_threshold', the counter is decremented. When the
        # load is between the thresholds or a scaling decision is made, the
        # counter is reset to 0.
        self.decision_counter = 0

    def scale(self, router_queue_lens, curr_replicas):
        """Decide how many replicas the backend should have after this period.

        Arguments:
            router_queue_lens (Dict[str, int]): map of routers to their most
                recent queue length for this backend.
            curr_replicas (int): The number of replicas that the backend
                currently has.

        Returns:
            int: -1 if no queue lengths were supplied; otherwise the target
                number of replicas (equal to `curr_replicas` when no scaling
                is triggered, and never below 1 when scaling down).
        """
        queue_lens = list(router_queue_lens.values())
        if not queue_lens:
            # No data to base a decision on; -1 signals "no decision".
            return -1

        new_replicas = curr_replicas
        avg_queue_len = sum(queue_lens) / len(queue_lens)

        # Scale up.
        if avg_queue_len > self.scale_up_threshold:
            # If the previous decision was to scale down (the counter was
            # negative), we reset it and then increment it (set to 1).
            # Otherwise, just increment.
            if self.decision_counter < 0:
                self.decision_counter = 1
            else:
                self.decision_counter += 1

            # Only actually scale the replicas if we've made this decision for
            # 'scale_up_consecutive_periods' in a row.
            if self.decision_counter >= self.scale_up_consecutive_periods:
                # TODO(edoakes): should we be resetting the counter here?
                self.decision_counter = 0
                new_replicas = curr_replicas + self.scale_up_num_replicas
                logger.info("Increasing number of replicas for backend '{}' "
                            "from {} to {}".format(self.backend, curr_replicas,
                                                   new_replicas))

        # Scale down.
        elif avg_queue_len < self.scale_down_threshold and curr_replicas > 1:
            # If the previous decision was to scale up (the counter was
            # positive), we reset it and then decrement it (set to -1).
            # Otherwise, just decrement.
            if self.decision_counter > 0:
                self.decision_counter = -1
            else:
                self.decision_counter -= 1

            # Only actually scale the replicas if we've made this decision for
            # 'scale_down_consecutive_periods' in a row. The counter reaches
            # -N after exactly N consecutive 'scale down' decisions, which
            # mirrors the '>= N' check used for scaling up.
            if self.decision_counter <= -self.scale_down_consecutive_periods:
                # TODO(edoakes): should we be resetting the counter here?
                self.decision_counter = 0
                # Never go below one replica, even if the configured step is
                # larger than the number of replicas that can be removed.
                new_replicas = max(
                    1, curr_replicas - self.scale_down_num_replicas)
                logger.info("Decreasing number of replicas for backend '{}' "
                            "from {} to {}".format(self.backend, curr_replicas,
                                                   new_replicas))

        # Do nothing.
        else:
            self.decision_counter = 0

        return new_replicas
+8 -4
View File
@@ -42,6 +42,9 @@ class BatchQueue:
if self.queue.qsize() == self.max_batch_size:
self.full_batch_event.set()
def qsize(self):
    """Return the size reported by the underlying queue's ``qsize()``."""
    pending = self.queue.qsize()
    return pending
async def wait_for_batch(self):
"""Wait for batch respecting self.max_batch_size and self.timeout_s.
@@ -153,9 +156,9 @@ def ensure_async(func):
class RayServeWorker:
"""Handles requests with the provided callable."""
def __init__(self, name, replica_tag, _callable,
def __init__(self, backend_tag, replica_tag, _callable,
backend_config: BackendConfig, is_function, metric_client):
self.name = name
self.backend_tag = backend_tag
self.replica_tag = replica_tag
self.callable = _callable
self.is_function = is_function
@@ -183,7 +186,7 @@ class RayServeWorker:
self.restart_counter.labels(replica_tag=self.replica_tag).add()
self.loop_task = asyncio.get_event_loop().create_task(self.main_loop())
asyncio.get_event_loop().create_task(self.main_loop())
def get_runner_method(self, request_item):
method_name = request_item.call_method
@@ -348,7 +351,8 @@ class RayServeWorker:
async def handle_request(self, request: Union[Query, bytes]):
if isinstance(request, bytes):
request = Query.ray_deserialize(request)
logger.debug("Worker {} got request {}".format(self.name, request))
logger.debug("Worker {} got request {}".format(self.replica_tag,
request))
request.async_future = asyncio.get_event_loop().create_future()
self.batch_queue.put(request)
return await request.async_future
+2 -1
View File
@@ -30,6 +30,7 @@ class BackendConfig:
self.batch_wait_timeout = config_dict.pop("batch_wait_timeout", 0)
self.max_concurrent_queries = config_dict.pop("max_concurrent_queries",
None)
self.autoscaling_config = config_dict.pop("autoscaling", None)
if self.max_concurrent_queries is None:
# Model serving mode: if the servable is blocking and the wait
@@ -136,7 +137,7 @@ class ReplicaConfig:
raise ValueError("Specifying max_restarts in "
"actor_init_args is not allowed.")
else:
num_cpus = self.ray_actor_options.get("num_cpus", 0)
num_cpus = self.ray_actor_options.get("num_cpus", 1)
if not isinstance(num_cpus, (int, float)):
raise TypeError(
"num_cpus in ray_actor_options must be an int or a float.")
+33
View File
@@ -6,6 +6,7 @@ import time
import ray
import ray.cloudpickle as pickle
from ray.serve.autoscaling_policy import BasicAutoscalingPolicy
from ray.serve.backend_worker import create_backend_worker
from ray.serve.constants import (ASYNC_CONCURRENCY, SERVE_PROXY_NAME,
SERVE_METRIC_SINK_NAME)
@@ -101,6 +102,8 @@ class ServeController:
self.routes = dict()
# backend -> BackendInfo.
self.backends = dict()
# backend -> AutoscalingPolicy
self.autoscaling_policies = dict()
# backend -> replica_tags.
self.replicas = defaultdict(list)
# replicas that should be started if recovering from a checkpoint.
@@ -118,6 +121,8 @@ class ServeController:
# Dictionary of backend tag to dictionaries of replica tag to worker.
# TODO(edoakes): consider removing this and just using the names.
self.workers = defaultdict(dict)
# Dictionary of backend_tag -> router_name -> most recent queue length.
self.backend_stats = defaultdict(lambda: defaultdict(dict))
# Used to ensure that only a single state-changing operation happens
# at any given time.
@@ -182,6 +187,7 @@ class ServeController:
node_resource: 0.01
},
).remote(
node_id,
self.http_host,
self.http_port,
instance_name=self.instance_name)
@@ -324,6 +330,9 @@ class ServeController:
for router in self.routers.values()
])
await self.broadcast_backend_config(backend)
if info.backend_config.autoscaling_config is not None:
self.autoscaling_policies[backend] = BasicAutoscalingPolicy(
backend, info.backend_config.autoscaling_config)
# Push configuration state to the routers.
await asyncio.gather(*[
@@ -344,8 +353,21 @@ class ServeController:
self.write_lock.release()
async def do_autoscale(self):
    """Run one autoscaling pass over every backend that has a policy.

    For each such backend, the policy is given the most recently reported
    queue lengths and the backend's configured replica count. A positive
    result is applied through ``update_backend_config``; a non-positive
    result is ignored.
    """
    # Iterate over a snapshot of the keys: awaiting update_backend_config()
    # yields to the event loop, and self.backends may gain or lose entries
    # before this coroutine resumes. Iterating the live dict would then
    # raise "dictionary changed size during iteration".
    for backend in list(self.backends):
        policy = self.autoscaling_policies.get(backend)
        info = self.backends.get(backend)
        # Skip backends without a policy, and backends that were removed
        # while an earlier update in this pass was awaited.
        if policy is None or info is None:
            continue

        new_num_replicas = policy.scale(self.backend_stats[backend],
                                        info.backend_config.num_replicas)
        if new_num_replicas > 0:
            await self.update_backend_config(
                backend, {"num_replicas": new_num_replicas})
async def run_control_loop(self):
while True:
await self.do_autoscale()
async with self.write_lock:
self._start_routers_if_needed()
checkpoint_required = self._stop_routers_if_needed()
@@ -730,6 +752,10 @@ class ServeController:
# and the configuration for the backends.
self.backends[backend_tag] = BackendInfo(
backend_worker, backend_config, replica_config)
if backend_config.autoscaling_config is not None:
self.autoscaling_policies[
backend_tag] = BasicAutoscalingPolicy(
backend_tag, backend_config.autoscaling_config)
self._scale_replicas(backend_tag, backend_config.num_replicas)
@@ -769,6 +795,8 @@ class ServeController:
# Remove the backend's metadata.
del self.backends[backend_tag]
if backend_tag in self.autoscaling_policies:
del self.autoscaling_policies[backend_tag]
# Add the intention to remove the backend from the router.
self.backends_to_remove.append(backend_tag)
@@ -840,3 +868,8 @@ class ServeController:
for replica in replica_dict.values():
ray.kill(replica, no_restart=True)
self.kv_store.delete(CHECKPOINT_KEY)
async def report_queue_lengths(self, router_name, queue_lengths):
    """Store the queue lengths reported by one router.

    Arguments:
        router_name: Key under which this router's values are stored.
        queue_lengths: Mapping of backend to that router's queue length
            for the backend.
    """
    # TODO: remove old router stats when removing them.
    stats = self.backend_stats
    for backend in queue_lengths:
        stats[backend][router_name] = queue_lengths[backend]
+4 -4
View File
@@ -27,7 +27,7 @@ class HTTPProxy:
# blocks forever
"""
async def fetch_config_from_controller(self, instance_name=None):
async def fetch_config_from_controller(self, name, instance_name=None):
assert ray.is_initialized()
controller = serve.api._get_controller()
@@ -43,7 +43,7 @@ class HTTPProxy:
label_names=("route", ))
self.router = Router()
await self.router.setup(instance_name)
await self.router.setup(name, instance_name)
def set_route_table(self, route_table):
self.route_table = route_table
@@ -170,10 +170,10 @@ class HTTPProxy:
@ray.remote
class HTTPProxyActor:
async def __init__(self, host, port, instance_name=None):
async def __init__(self, name, host, port, instance_name=None):
serve.init(name=instance_name)
self.app = HTTPProxy()
await self.app.fetch_config_from_controller(instance_name)
await self.app.fetch_config_from_controller(name, instance_name)
self.host = host
self.port = port
+29 -12
View File
@@ -12,9 +12,11 @@ from ray.exceptions import RayTaskError
import ray
from ray import serve
from ray.serve.metric import MetricClient
from ray.serve.policy import RandomEndpointPolicy
from ray.serve.endpoint_policy import RandomEndpointPolicy
from ray.serve.utils import logger, chain_future
REPORT_QUEUE_LENGTH_PERIOD_S = 1.0
class Query:
def __init__(
@@ -87,7 +89,7 @@ def _make_future_unwrapper(client_futures: List[asyncio.Future],
class Router:
"""A router that routes request to available workers."""
async def setup(self, instance_name=None):
async def setup(self, name, instance_name=None):
# Note: Several queues are used in the router
# - When a request comes in, it's placed inside its corresponding
# endpoint_queue.
@@ -98,6 +100,8 @@ class Router:
# handles are dequeued during the second stage of flush operation,
# which assign queries in buffer_queue to actor handle.
self.name = name
# -- Queues -- #
# endpoint_name -> request queue
@@ -117,8 +121,8 @@ class Router:
self.backend_info = dict()
# replica tag -> worker_handle
self.replicas = dict()
# replica_tag -> concurrent queries counter
self.queries_counter = defaultdict(lambda: 0)
# backend_name -> replica_tag -> concurrent queries counter
self.queries_counter = defaultdict(lambda: defaultdict(int))
# -- Synchronization -- #
@@ -137,23 +141,25 @@ class Router:
# them from the controller so that the router can transparently recover
# from failure.
serve.init(name=instance_name)
controller = serve.api._get_controller()
self.controller = serve.api._get_controller()
traffic_policies = ray.get(controller.get_traffic_policies.remote())
traffic_policies = ray.get(
self.controller.get_traffic_policies.remote())
for endpoint, traffic_policy in traffic_policies.items():
await self.set_traffic(endpoint, traffic_policy)
backend_dict = ray.get(controller.get_all_worker_handles.remote())
backend_dict = ray.get(self.controller.get_all_worker_handles.remote())
for backend_tag, replica_dict in backend_dict.items():
for replica_tag, worker in replica_dict.items():
await self.add_new_worker(backend_tag, replica_tag, worker)
backend_configs = ray.get(controller.get_backend_configs.remote())
backend_configs = ray.get(self.controller.get_backend_configs.remote())
for backend, backend_config in backend_configs.items():
await self.set_backend_config(backend, backend_config)
# -- Metric Registration -- #
[metric_exporter] = ray.get(controller.get_metric_exporter.remote())
[metric_exporter] = ray.get(
self.controller.get_metric_exporter.remote())
self.metric_client = MetricClient(metric_exporter)
self.num_router_requests = self.metric_client.new_counter(
"num_router_requests",
@@ -170,6 +176,8 @@ class Router:
"from backend."),
label_names=("backend", ))
asyncio.get_event_loop().create_task(self.report_queue_lengths())
async def enqueue_request(self, request_meta, *request_args,
**request_kwargs):
endpoint = request_meta.endpoint
@@ -324,7 +332,7 @@ class Router:
except RayTaskError as error:
self.num_error_backend_request.labels(backend=backend).add()
result = error
self.queries_counter[backend_replica_tag] -= 1
self.queries_counter[backend][backend_replica_tag] -= 1
await self.mark_worker_idle(backend, backend_replica_tag)
logger.debug("Got result in {:.2f}s".format(time.time() - start))
return result
@@ -347,7 +355,7 @@ class Router:
max_queries = 1
if backend in self.backend_info:
max_queries = self.backend_info[backend].max_concurrent_queries
curr_queries = self.queries_counter[backend_replica_tag]
curr_queries = self.queries_counter[backend][backend_replica_tag]
if curr_queries >= max_queries:
# Put the worker back to the queue.
worker_queue.appendleft(backend_replica_tag)
@@ -359,7 +367,7 @@ class Router:
continue
request = buffer_queue.pop(0)
self.queries_counter[backend_replica_tag] += 1
self.queries_counter[backend][backend_replica_tag] += 1
future = asyncio.get_event_loop().create_task(
self._do_query(backend, backend_replica_tag, request))
@@ -368,3 +376,12 @@ class Router:
chain_future(future, request.async_future)
worker_queue.appendleft(backend_replica_tag)
async def report_queue_lengths(self):
    """Periodically send this router's per-backend queue lengths.

    Loops forever: each iteration passes ``self.name`` and a mapping of
    backend to the current length of its queue to the controller's
    ``report_queue_lengths``, then sleeps for
    ``REPORT_QUEUE_LENGTH_PERIOD_S`` seconds. The call's result is not
    awaited.
    """
    while True:
        report = self.controller.report_queue_lengths.remote
        router_name = self.name
        lengths = {}
        for backend, pending in self.backend_queues.items():
            lengths[backend] = len(pending)
        report(router_name, lengths)
        await asyncio.sleep(REPORT_QUEUE_LENGTH_PERIOD_S)
@@ -50,7 +50,7 @@ async def test_runner_wraps_error():
async def test_runner_actor(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote()
await q.setup.remote("")
def echo(flask_request, i=None):
return i
@@ -72,7 +72,7 @@ async def test_runner_actor(serve_instance):
async def test_ray_serve_mixin(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote()
await q.setup.remote("")
CONSUMER_NAME = "runner-cls"
PRODUCER_NAME = "prod-cls"
@@ -98,7 +98,7 @@ async def test_ray_serve_mixin(serve_instance):
async def test_task_runner_check_context(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote()
await q.setup.remote("")
def echo(flask_request, i=None):
# Accessing the flask_request without web context should throw.
@@ -120,7 +120,7 @@ async def test_task_runner_check_context(serve_instance):
async def test_task_runner_custom_method_single(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote()
await q.setup.remote("")
class NonBatcher:
def a(self, _):
@@ -155,7 +155,7 @@ async def test_task_runner_custom_method_single(serve_instance):
async def test_task_runner_custom_method_batch(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote()
await q.setup.remote("")
@serve.accept_batch
class Batcher:
@@ -221,7 +221,7 @@ async def test_task_runner_custom_method_batch(serve_instance):
async def test_task_runner_perform_batch(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote()
await q.setup.remote("")
def batcher(*args, **kwargs):
return [serve.context.batch_size] * serve.context.batch_size
@@ -252,7 +252,7 @@ async def test_task_runner_perform_batch(serve_instance):
async def test_task_runner_perform_async(serve_instance):
q = ray.remote(Router).remote()
await q.setup.remote()
await q.setup.remote("")
@ray.remote
class Barrier:
+13 -10
View File
@@ -50,7 +50,7 @@ def task_runner_mock_actor():
async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote()
await q.setup.remote()
await q.setup.remote("")
q.set_traffic.remote("svc", TrafficPolicy({"backend-single-prod": 1.0}))
q.add_new_worker.remote("backend-single-prod", "replica-1",
@@ -68,7 +68,7 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):
async def test_slo(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote()
await q.setup.remote()
await q.setup.remote("")
await q.set_traffic.remote("svc", TrafficPolicy({"backend-slo": 1.0}))
all_request_sent = []
@@ -93,7 +93,7 @@ async def test_slo(serve_instance, task_runner_mock_actor):
async def test_alter_backend(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote()
await q.setup.remote()
await q.setup.remote("")
await q.set_traffic.remote("svc", TrafficPolicy({"backend-alter": 1}))
await q.add_new_worker.remote("backend-alter", "replica-1",
@@ -112,7 +112,7 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor):
async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote()
await q.setup.remote()
await q.setup.remote("")
await q.set_traffic.remote(
"svc", TrafficPolicy({
@@ -142,7 +142,7 @@ async def test_queue_remove_replicas(serve_instance):
temp_actor = mock_task_runner()
q = ray.remote(TestRouter).remote()
await q.setup.remote()
await q.setup.remote("")
await q.add_new_worker.remote("backend-remove", "replica-1", temp_actor)
await q.remove_worker.remote("backend-remove", "replica-1")
assert ray.get(q.worker_queue_size.remote("backend")) == 0
@@ -150,7 +150,7 @@ async def test_queue_remove_replicas(serve_instance):
async def test_shard_key(serve_instance, task_runner_mock_actor):
q = ray.remote(Router).remote()
await q.setup.remote()
await q.setup.remote("")
num_backends = 5
traffic_dict = {}
@@ -211,7 +211,7 @@ async def test_router_use_max_concurrency(serve_instance):
worker = MockWorker.remote()
q = ray.remote(VisibleRouter).remote()
await q.setup.remote()
await q.setup.remote("")
backend_name = "max-concurrent-test"
config = BackendConfig({"max_concurrent_queries": 1})
await q.set_traffic.remote("svc", TrafficPolicy({backend_name: 1.0}))
@@ -229,7 +229,8 @@ async def test_router_use_max_concurrency(serve_instance):
# Let's retrieve the router internal state
queries_counter, backend_queues = await q.get_queues.remote()
# There should be just one inflight request
assert queries_counter["max-concurrent-test:replica-tag"] == 1
assert queries_counter[backend_name][
"max-concurrent-test:replica-tag"] == 1
# The second query is buffered
assert len(backend_queues["max-concurrent-test"]) == 1
@@ -240,7 +241,8 @@ async def test_router_use_max_concurrency(serve_instance):
# The internal state of router should have changed.
queries_counter, backend_queues = await q.get_queues.remote()
# There should still be one inflight request
assert queries_counter["max-concurrent-test:replica-tag"] == 1
assert queries_counter[backend_name][
"max-concurrent-test:replica-tag"] == 1
# But there shouldn't be any queries in the queue
assert len(backend_queues["max-concurrent-test"]) == 0
@@ -250,7 +252,8 @@ async def test_router_use_max_concurrency(serve_instance):
# Checking the internal state of the router one more time
queries_counter, backend_queues = await q.get_queues.remote()
assert queries_counter["max-concurrent-test:replica-tag"] == 0
assert queries_counter[backend_name][
"max-concurrent-test:replica-tag"] == 0
assert len(backend_queues["max-concurrent-test"]) == 0