mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 08:48:13 +08:00
Auto-scale ray clusters based on GCS load metrics (#1348)
This adds (experimental) auto-scaling support for Ray clusters based on GCS load metrics. The auto-scaling algorithm is as follows: Based on current (instantaneous) load information, we compute the approximate number of "used workers". This is based on the bottleneck resource, e.g. if 8/8 GPUs are used in a 8-node cluster but all the CPUs are idle, the number of used nodes is still counted as 8. This number can also be fractional. We scale that number by 1 / target_utilization_fraction and round up to determine the target cluster size (subject to the max_workers constraint). The autoscaler control loop takes care of launching new nodes until the target cluster size is met. When a node is idle for more than idle_timeout_minutes, we remove it from the cluster if that would not drop the cluster size below min_workers. Note that we'll need to update the wheel in the example yaml file after this PR is merged.
This commit is contained in:
+44
-3
@@ -3,6 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import binascii
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -14,9 +15,11 @@ import ray.utils
|
||||
import redis
|
||||
# Import flatbuffer bindings.
|
||||
from ray.core.generated.DriverTableMessage import DriverTableMessage
|
||||
from ray.core.generated.LocalSchedulerInfoMessage import \
|
||||
LocalSchedulerInfoMessage
|
||||
from ray.core.generated.SubscribeToDBClientTableReply import \
|
||||
SubscribeToDBClientTableReply
|
||||
from ray.autoscaler.autoscaler import StandardAutoscaler
|
||||
SubscribeToDBClientTableReply
|
||||
from ray.autoscaler.autoscaler import LoadMetrics, StandardAutoscaler
|
||||
from ray.core.generated.TaskInfo import TaskInfo
|
||||
from ray.services import get_ip_address, get_port
|
||||
from ray.utils import binary_to_hex, binary_to_object_id, hex_to_binary
|
||||
@@ -31,6 +34,7 @@ NIL_ID = b"\xff" * DB_CLIENT_ID_SIZE
|
||||
TASK_STATUS_LOST = 32
|
||||
|
||||
# common/state/redis.cc
|
||||
LOCAL_SCHEDULER_INFO_CHANNEL = b"local_schedulers"
|
||||
PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
|
||||
DRIVER_DEATH_CHANNEL = b"driver_deaths"
|
||||
|
||||
@@ -92,8 +96,10 @@ class Monitor(object):
|
||||
self.dead_local_schedulers = set()
|
||||
self.live_plasma_managers = Counter()
|
||||
self.dead_plasma_managers = set()
|
||||
self.load_metrics = LoadMetrics()
|
||||
if autoscaling_config:
|
||||
self.autoscaler = StandardAutoscaler(autoscaling_config)
|
||||
self.autoscaler = StandardAutoscaler(
|
||||
autoscaling_config, self.load_metrics)
|
||||
else:
|
||||
self.autoscaler = None
|
||||
|
||||
@@ -286,6 +292,36 @@ class Monitor(object):
|
||||
# already dead.
|
||||
del self.live_plasma_managers[db_client_id]
|
||||
|
||||
def local_scheduler_info_handler(self, unused_channel, data):
|
||||
"""Handle a local scheduler heartbeat from Redis."""
|
||||
|
||||
message = LocalSchedulerInfoMessage.GetRootAsLocalSchedulerInfoMessage(
|
||||
data, 0)
|
||||
num_resources = message.DynamicResourcesLength()
|
||||
static_resources = {}
|
||||
dynamic_resources = {}
|
||||
for i in range(num_resources):
|
||||
dyn = message.DynamicResources(i)
|
||||
static = message.StaticResources(i)
|
||||
dynamic_resources[dyn.Key().decode("utf-8")] = dyn.Value()
|
||||
static_resources[static.Key().decode("utf-8")] = static.Value()
|
||||
client_id = binascii.hexlify(message.DbClientId()).decode("utf-8")
|
||||
clients = ray.global_state.client_table()
|
||||
local_schedulers = [
|
||||
entry for client in clients.values() for entry in client
|
||||
if (entry["ClientType"] == "local_scheduler" and not
|
||||
entry["Deleted"])
|
||||
]
|
||||
ip = None
|
||||
for ls in local_schedulers:
|
||||
if ls["DBClientID"] == client_id:
|
||||
ip = ls["AuxAddress"].split(":")[0]
|
||||
if ip:
|
||||
self.load_metrics.update(ip, static_resources, dynamic_resources)
|
||||
else:
|
||||
print("Warning: could not find ip for client {} in {}".format(
|
||||
client_id, local_schedulers))
|
||||
|
||||
def plasma_manager_heartbeat_handler(self, unused_channel, data):
|
||||
"""Handle a plasma manager heartbeat from Redis.
|
||||
|
||||
@@ -513,6 +549,10 @@ class Monitor(object):
|
||||
assert self.subscribed[channel]
|
||||
# The message was a heartbeat from a plasma manager.
|
||||
message_handler = self.plasma_manager_heartbeat_handler
|
||||
elif channel == LOCAL_SCHEDULER_INFO_CHANNEL:
|
||||
assert self.subscribed[channel]
|
||||
# The message was a heartbeat from a local scheduler
|
||||
message_handler = self.local_scheduler_info_handler
|
||||
elif channel == DB_CLIENT_TABLE_NAME:
|
||||
assert self.subscribed[channel]
|
||||
# The message was a notification from the db_client table.
|
||||
@@ -537,6 +577,7 @@ class Monitor(object):
|
||||
"""
|
||||
# Initialize the subscription channel.
|
||||
self.subscribe(DB_CLIENT_TABLE_NAME)
|
||||
self.subscribe(LOCAL_SCHEDULER_INFO_CHANNEL)
|
||||
self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL)
|
||||
self.subscribe(DRIVER_DEATH_CHANNEL)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user