Auto-scale ray clusters based on GCS load metrics (#1348)

This adds (experimental) auto-scaling support for Ray clusters based on GCS load metrics. The auto-scaling algorithm is as follows:

Based on current (instantaneous) load information, we compute the approximate number of "used workers". This is based on the bottleneck resource, e.g. if 8/8 GPUs are used in an 8-node cluster but all the CPUs are idle, the number of used nodes is still counted as 8. This number can also be fractional.
We scale that number by 1 / target_utilization_fraction and round up to determine the target cluster size (subject to the max_workers constraint). The autoscaler control loop takes care of launching new nodes until the target cluster size is met.
When a node is idle for more than idle_timeout_minutes, we remove it from the cluster if that would not drop the cluster size below min_workers.
Note that we'll need to update the wheel in the example yaml file after this PR is merged.
This commit is contained in:
Eric Liang
2017-12-31 14:39:57 -08:00
committed by GitHub
parent e970e24ea5
commit b6c42f96be
12 changed files with 657 additions and 176 deletions
+44 -3
View File
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
import argparse
import binascii
import json
import logging
import os
@@ -14,9 +15,11 @@ import ray.utils
import redis
# Import flatbuffer bindings.
from ray.core.generated.DriverTableMessage import DriverTableMessage
from ray.core.generated.LocalSchedulerInfoMessage import \
LocalSchedulerInfoMessage
from ray.core.generated.SubscribeToDBClientTableReply import \
SubscribeToDBClientTableReply
from ray.autoscaler.autoscaler import StandardAutoscaler
SubscribeToDBClientTableReply
from ray.autoscaler.autoscaler import LoadMetrics, StandardAutoscaler
from ray.core.generated.TaskInfo import TaskInfo
from ray.services import get_ip_address, get_port
from ray.utils import binary_to_hex, binary_to_object_id, hex_to_binary
@@ -31,6 +34,7 @@ NIL_ID = b"\xff" * DB_CLIENT_ID_SIZE
TASK_STATUS_LOST = 32
# common/state/redis.cc
LOCAL_SCHEDULER_INFO_CHANNEL = b"local_schedulers"
PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
DRIVER_DEATH_CHANNEL = b"driver_deaths"
@@ -92,8 +96,10 @@ class Monitor(object):
self.dead_local_schedulers = set()
self.live_plasma_managers = Counter()
self.dead_plasma_managers = set()
self.load_metrics = LoadMetrics()
if autoscaling_config:
self.autoscaler = StandardAutoscaler(autoscaling_config)
self.autoscaler = StandardAutoscaler(
autoscaling_config, self.load_metrics)
else:
self.autoscaler = None
@@ -286,6 +292,36 @@ class Monitor(object):
# already dead.
del self.live_plasma_managers[db_client_id]
def local_scheduler_info_handler(self, unused_channel, data):
    """Handle a local scheduler heartbeat from Redis.

    Decodes the heartbeat, collects its static and dynamic resource
    entries into dictionaries, resolves the sending local scheduler's IP
    address from the client table, and forwards everything to
    ``self.load_metrics``. If no live local scheduler entry matches the
    sender, a warning is printed and the load metrics are left untouched.

    Args:
        unused_channel: The channel the message arrived on (not used).
        data: The serialized LocalSchedulerInfoMessage flatbuffer.
    """
    message = LocalSchedulerInfoMessage.GetRootAsLocalSchedulerInfoMessage(
        data, 0)

    # Read each resource vector using its own length. Indexing the static
    # vector with the dynamic vector's length would drop entries (or read
    # past the end) whenever the two vectors differ in size.
    dynamic_resources = {}
    for i in range(message.DynamicResourcesLength()):
        pair = message.DynamicResources(i)
        dynamic_resources[pair.Key().decode("utf-8")] = pair.Value()

    static_resources = {}
    for i in range(message.StaticResourcesLength()):
        pair = message.StaticResources(i)
        static_resources[pair.Key().decode("utf-8")] = pair.Value()

    # The client table is keyed by hex ids, so convert the raw id first.
    client_id = binascii.hexlify(message.DbClientId()).decode("utf-8")

    # Only local schedulers that have not been marked deleted are eligible.
    clients = ray.global_state.client_table()
    local_schedulers = [
        entry for client in clients.values() for entry in client
        if (entry["ClientType"] == "local_scheduler" and not
            entry["Deleted"])
    ]

    # AuxAddress is split on ":" and only the first component (the host
    # part) is kept. If several entries match, the last one wins.
    ip = None
    for ls in local_schedulers:
        if ls["DBClientID"] == client_id:
            ip = ls["AuxAddress"].split(":")[0]

    if ip:
        self.load_metrics.update(ip, static_resources, dynamic_resources)
    else:
        print("Warning: could not find ip for client {} in {}".format(
            client_id, local_schedulers))
def plasma_manager_heartbeat_handler(self, unused_channel, data):
"""Handle a plasma manager heartbeat from Redis.
@@ -513,6 +549,10 @@ class Monitor(object):
assert self.subscribed[channel]
# The message was a heartbeat from a plasma manager.
message_handler = self.plasma_manager_heartbeat_handler
elif channel == LOCAL_SCHEDULER_INFO_CHANNEL:
assert self.subscribed[channel]
# The message was a heartbeat from a local scheduler
message_handler = self.local_scheduler_info_handler
elif channel == DB_CLIENT_TABLE_NAME:
assert self.subscribed[channel]
# The message was a notification from the db_client table.
@@ -537,6 +577,7 @@ class Monitor(object):
"""
# Initialize the subscription channel.
self.subscribe(DB_CLIENT_TABLE_NAME)
self.subscribe(LOCAL_SCHEDULER_INFO_CHANNEL)
self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL)
self.subscribe(DRIVER_DEATH_CHANNEL)