mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 19:22:51 +08:00
[autoscaler/AWS] Updated AWS Node Provider threading logic (#11422)
This commit is contained in:
@@ -3,6 +3,7 @@ import copy
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
import boto3
|
||||
@@ -22,6 +23,8 @@ from ray.autoscaler._private.cli_logger import cli_logger, cf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TAG_BATCH_DELAY = 1
|
||||
|
||||
|
||||
def to_aws_format(tags):
|
||||
"""Convert the Ray node name tag to the AWS-specific 'Name' tag."""
|
||||
@@ -68,56 +71,23 @@ class AWSNodeProvider(NodeProvider):
|
||||
# Try availability zones round-robin, starting from random offset
|
||||
self.subnet_idx = random.randint(0, 100)
|
||||
|
||||
self.tag_cache = {} # Tags that we believe to actually be on EC2.
|
||||
self.tag_cache_pending = {} # Tags that we will soon upload.
|
||||
# Tags that we believe to actually be on EC2.
|
||||
self.tag_cache = {}
|
||||
# Tags that we will soon upload.
|
||||
self.tag_cache_pending = defaultdict(dict)
|
||||
# Number of threads waiting for a batched tag update.
|
||||
self.batch_thread_count = 0
|
||||
self.batch_update_done = threading.Event()
|
||||
self.batch_update_done.set()
|
||||
self.ready_for_new_batch = threading.Event()
|
||||
self.ready_for_new_batch.set()
|
||||
self.tag_cache_lock = threading.Lock()
|
||||
self.tag_cache_update_event = threading.Event()
|
||||
self.tag_cache_kill_event = threading.Event()
|
||||
self.tag_update_thread = threading.Thread(
|
||||
target=self._node_tag_update_loop)
|
||||
self.tag_update_thread.start()
|
||||
self.count_lock = threading.Lock()
|
||||
|
||||
# Cache of node objects from the last nodes() call. This avoids
|
||||
# excessive DescribeInstances requests.
|
||||
self.cached_nodes = {}
|
||||
|
||||
def _node_tag_update_loop(self):
|
||||
"""Update the AWS tags for a cluster periodically.
|
||||
|
||||
The purpose of this loop is to avoid excessive EC2 calls when a large
|
||||
number of nodes are being launched simultaneously.
|
||||
"""
|
||||
while True:
|
||||
self.tag_cache_update_event.wait()
|
||||
self.tag_cache_update_event.clear()
|
||||
|
||||
batch_updates = defaultdict(list)
|
||||
|
||||
with self.tag_cache_lock:
|
||||
for node_id, tags in self.tag_cache_pending.items():
|
||||
for x in tags.items():
|
||||
batch_updates[x].append(node_id)
|
||||
self.tag_cache[node_id].update(tags)
|
||||
|
||||
self.tag_cache_pending = {}
|
||||
|
||||
for (k, v), node_ids in batch_updates.items():
|
||||
m = "Set tag {}={} on {}".format(k, v, node_ids)
|
||||
with LogTimer("AWSNodeProvider: {}".format(m)):
|
||||
if k == TAG_RAY_NODE_NAME:
|
||||
k = "Name"
|
||||
self.ec2.meta.client.create_tags(
|
||||
Resources=node_ids,
|
||||
Tags=[{
|
||||
"Key": k,
|
||||
"Value": v
|
||||
}],
|
||||
)
|
||||
|
||||
self.tag_cache_kill_event.wait(timeout=5)
|
||||
if self.tag_cache_kill_event.is_set():
|
||||
return
|
||||
|
||||
def non_terminated_nodes(self, tag_filters):
|
||||
# Note that these filters are acceptable because they are set on
|
||||
# node initialization, and so can never be sitting in the cache.
|
||||
@@ -186,13 +156,56 @@ class AWSNodeProvider(NodeProvider):
|
||||
return node.private_ip_address
|
||||
|
||||
def set_node_tags(self, node_id, tags):
    """Queue ``tags`` for ``node_id`` and block until their batch is flushed.

    The first caller that finds ``self.tag_cache_pending`` empty becomes the
    batching thread: it waits ``TAG_BATCH_DELAY`` seconds so that concurrent
    callers can add their tags to the same batch, then flushes everything
    through ``self._update_node_tags()``. Every caller (batching or not)
    waits on ``self.batch_update_done`` before returning.

    Args:
        node_id: Key under which the tags are queued in
            ``self.tag_cache_pending``.
        tags: Mapping of tag keys to values to merge into the pending tags.

    Raises:
        Exception: Whatever ``self._update_node_tags()`` raised, re-raised
            in the batching thread only after the waiting threads have been
            released and the provider is ready for a new batch.
    """
    is_batching_thread = False
    with self.tag_cache_lock:
        if not self.tag_cache_pending:
            is_batching_thread = True
            # Wait for threads in the last batch to exit
            self.ready_for_new_batch.wait()
            self.ready_for_new_batch.clear()
            self.batch_update_done.clear()
        self.tag_cache_pending[node_id].update(tags)

    flush_error = None
    if is_batching_thread:
        time.sleep(TAG_BATCH_DELAY)
        with self.tag_cache_lock:
            try:
                self._update_node_tags()
            except Exception as exc:
                # Hold the error until the handshake below has run;
                # raising here would skip it and leave
                # ready_for_new_batch cleared forever.
                flush_error = exc
            finally:
                # Always release the waiters, even when the flush failed,
                # otherwise they block on batch_update_done indefinitely.
                self.batch_update_done.set()

    with self.count_lock:
        self.batch_thread_count += 1
    self.batch_update_done.wait()

    with self.count_lock:
        self.batch_thread_count -= 1
        # The last thread leaving the batch re-opens the gate for the next
        # batching thread.
        if self.batch_thread_count == 0:
            self.ready_for_new_batch.set()

    if flush_error is not None:
        raise flush_error
|
||||
|
||||
def _update_node_tags(self):
|
||||
batch_updates = defaultdict(list)
|
||||
|
||||
for node_id, tags in self.tag_cache_pending.items():
|
||||
for x in tags.items():
|
||||
batch_updates[x].append(node_id)
|
||||
self.tag_cache[node_id].update(tags)
|
||||
|
||||
self.tag_cache_pending = defaultdict(dict)
|
||||
|
||||
self._create_tags(batch_updates)
|
||||
|
||||
def _create_tags(self, batch_updates):
    """Issue one EC2 ``create_tags`` call per distinct ``(key, value)`` pair.

    Args:
        batch_updates: Mapping of ``(tag_key, tag_value)`` tuples to the
            list of node ids that should receive that tag.
    """
    for (k, v), node_ids in batch_updates.items():
        m = "Set tag {}={} on {}".format(k, v, node_ids)
        with LogTimer("AWSNodeProvider: {}".format(m)):
            # The Ray node-name tag is stored under the key "Name" on AWS.
            # The log message above keeps the Ray-side key.
            aws_key = "Name" if k == TAG_RAY_NODE_NAME else k
            self.ec2.meta.client.create_tags(
                Resources=node_ids,
                Tags=[{
                    "Key": aws_key,
                    "Value": v
                }],
            )
|
||||
|
||||
def create_node(self, node_config, tags, count):
|
||||
tags = copy.deepcopy(tags)
|
||||
@@ -473,10 +486,6 @@ class AWSNodeProvider(NodeProvider):
|
||||
|
||||
return self._get_node(node_id)
|
||||
|
||||
def cleanup(self):
|
||||
self.tag_cache_update_event.set()
|
||||
self.tag_cache_kill_event.set()
|
||||
|
||||
@staticmethod
def bootstrap_config(cluster_config):
    """Return the result of ``bootstrap_aws`` applied to ``cluster_config``."""
    return bootstrap_aws(cluster_config)
|
||||
|
||||
@@ -122,6 +122,13 @@ py_test(
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_aws_batch_tag_update",
|
||||
size = "small",
|
||||
srcs = SRCS + ["aws/test_aws_batch_tag_update.py"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
# Note(simon): typing tests are not included in module list
|
||||
# because they requires globs and it might be refactored in the future.
|
||||
py_test(
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from ray.autoscaler._private.aws.node_provider import AWSNodeProvider
|
||||
from ray.autoscaler._private.aws.node_provider import TAG_BATCH_DELAY
|
||||
|
||||
|
||||
def mock_create_tags(provider, batch_updates):
    """Counting stand-in for ``AWSNodeProvider._create_tags``.

    Args:
        provider: Object carrying the ``batch_counter`` and
            ``tag_update_counter`` attributes to increment.
        batch_updates: Mapping of ``(key, value)`` tag pairs to the list of
            node ids receiving that tag.
    """
    # Increment batches sent.
    provider.batch_counter += 1
    # Increment tags updated.
    provider.tag_update_counter += sum(
        len(node_ids) for node_ids in batch_updates.values())
|
||||
|
||||
|
||||
def batch_test(num_threads, delay):
    """Run AWSNodeProvider.set_node_tags in several threads, with a
    specified delay between thread launches.

    Args:
        num_threads: Number of threads to start; thread ``x`` tags node
            ``str(x)`` with ``{"foo": "bar"}``.
        delay: Seconds to sleep after starting each thread.

    Returns:
        A ``(batches_sent, tags_updated)`` tuple: the number of batches of
        tag updates and the number of tags updated, as counted by
        ``mock_create_tags``.
    """
    with mock.patch("ray.autoscaler._private.aws.node_provider.make_ec2_client"
                    ), mock.patch.object(AWSNodeProvider, "_create_tags",
                                         mock_create_tags):
        provider = AWSNodeProvider(
            provider_config={"region": "nowhere"}, cluster_name="default")
        provider.batch_counter = 0
        provider.tag_update_counter = 0
        # Pre-populate the cache so every node id the threads touch exists.
        provider.tag_cache = {str(x): {} for x in range(num_threads)}

        threads = [
            threading.Thread(
                target=provider.set_node_tags, args=(str(x), {
                    "foo": "bar"
                })) for x in range(num_threads)
        ]

        for thread in threads:
            thread.start()
            time.sleep(delay)
        for thread in threads:
            thread.join()

        return provider.batch_counter, provider.tag_update_counter
|
||||
|
||||
|
||||
class TagBatchTest(unittest.TestCase):
    """Checks how AWSNodeProvider.set_node_tags groups calls into batches."""

    def test_concurrent(self):
        """Threads started back-to-back should share a few batches."""
        num_threads = 100
        batches_sent, tags_updated = batch_test(num_threads, delay=0)
        self.assertLess(batches_sent, num_threads / 10)
        self.assertEqual(tags_updated, num_threads)

    def test_serial(self):
        """Threads spaced wider than the batch delay get one batch each."""
        num_threads = 5
        long_delay = TAG_BATCH_DELAY * 1.2
        batches_sent, tags_updated = batch_test(num_threads, delay=long_delay)
        self.assertEqual(batches_sent, num_threads)
        self.assertEqual(tags_updated, num_threads)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import sys

    # Run this file's tests verbosely and use pytest's result as exit code.
    sys.exit(pytest.main(["-v", __file__]))
|
||||
Reference in New Issue
Block a user