[autoscaler/AWS] Updated AWS Node Provider threading logic (#11422)

This commit is contained in:
Gekho457
2020-10-21 21:42:38 -04:00
committed by Alex Wu
parent 7200ddb72d
commit cf1c737895
3 changed files with 139 additions and 53 deletions
@@ -3,6 +3,7 @@ import copy
import threading
from collections import defaultdict
import logging
import time
from typing import Any, Dict
import boto3
@@ -22,6 +23,8 @@ from ray.autoscaler._private.cli_logger import cli_logger, cf
logger = logging.getLogger(__name__)
# Seconds the batching thread in set_node_tags sleeps before flushing the
# pending tag updates; other callers can add their tags during this window.
TAG_BATCH_DELAY = 1
def to_aws_format(tags):
"""Convert the Ray node name tag to the AWS-specific 'Name' tag."""
@@ -68,56 +71,23 @@ class AWSNodeProvider(NodeProvider):
# Try availability zones round-robin, starting from random offset
self.subnet_idx = random.randint(0, 100)
self.tag_cache = {} # Tags that we believe to actually be on EC2.
self.tag_cache_pending = {} # Tags that we will soon upload.
# Tags that we believe to actually be on EC2.
self.tag_cache = {}
# Tags that we will soon upload.
self.tag_cache_pending = defaultdict(dict)
# Number of threads waiting for a batched tag update.
self.batch_thread_count = 0
self.batch_update_done = threading.Event()
self.batch_update_done.set()
self.ready_for_new_batch = threading.Event()
self.ready_for_new_batch.set()
self.tag_cache_lock = threading.Lock()
self.tag_cache_update_event = threading.Event()
self.tag_cache_kill_event = threading.Event()
self.tag_update_thread = threading.Thread(
target=self._node_tag_update_loop)
self.tag_update_thread.start()
self.count_lock = threading.Lock()
# Cache of node objects from the last nodes() call. This avoids
# excessive DescribeInstances requests.
self.cached_nodes = {}
def _node_tag_update_loop(self):
    """Periodically push pending node tags to AWS in batches.

    Batching keeps the number of EC2 calls low when many nodes are being
    launched at the same time. Each iteration waits for the update event,
    flushes everything queued in ``tag_cache_pending``, and then pauses for
    up to five seconds on the kill event, returning once that event is set.
    """
    while True:
        self.tag_cache_update_event.wait()
        self.tag_cache_update_event.clear()

        # Group node ids by (key, value) so each distinct tag costs one call.
        nodes_by_tag = defaultdict(list)
        with self.tag_cache_lock:
            for node_id, node_tags in self.tag_cache_pending.items():
                for tag_pair in node_tags.items():
                    nodes_by_tag[tag_pair].append(node_id)
                self.tag_cache[node_id].update(node_tags)
            self.tag_cache_pending = {}

        # The EC2 calls are made after the cache lock has been released.
        for (key, value), node_ids in nodes_by_tag.items():
            description = "Set tag {}={} on {}".format(key, value, node_ids)
            with LogTimer("AWSNodeProvider: {}".format(description)):
                # The Ray node name tag is uploaded under the key "Name".
                aws_key = "Name" if key == TAG_RAY_NODE_NAME else key
                self.ec2.meta.client.create_tags(
                    Resources=node_ids,
                    Tags=[{
                        "Key": aws_key,
                        "Value": value
                    }],
                )

        self.tag_cache_kill_event.wait(timeout=5)
        if self.tag_cache_kill_event.is_set():
            return
def non_terminated_nodes(self, tag_filters):
# Note that these filters are acceptable because they are set on
# node initialization, and so can never be sitting in the cache.
@@ -186,13 +156,56 @@ class AWSNodeProvider(NodeProvider):
return node.private_ip_address
def set_node_tags(self, node_id, tags):
    """Queue ``tags`` for ``node_id`` and block until they are flushed.

    The first caller to find the pending map empty becomes the batching
    thread: it waits for the previous batch to drain, sleeps for
    ``TAG_BATCH_DELAY`` seconds so that other callers can add their tags,
    and then flushes everything through ``_update_node_tags``. Every caller
    (batching or not) returns only after ``batch_update_done`` is set.

    Args:
        node_id: Id of the node whose tags should be updated.
        tags: Mapping of tag keys to values to merge into the pending set.

    Raises:
        Whatever ``_update_node_tags`` raises, in the batching thread only.
        The batch is still marked done so other callers are not left
        blocked.
    """
    is_batching_thread = False
    with self.tag_cache_lock:
        if not self.tag_cache_pending:
            is_batching_thread = True
            # Wait for threads in the last batch to exit.
            self.ready_for_new_batch.wait()
            self.ready_for_new_batch.clear()
            self.batch_update_done.clear()
        # setdefault works whether the pending map is a plain dict or a
        # defaultdict(dict).
        self.tag_cache_pending.setdefault(node_id, {}).update(tags)

    try:
        if is_batching_thread:
            try:
                time.sleep(TAG_BATCH_DELAY)
                with self.tag_cache_lock:
                    self._update_node_tags()
            finally:
                # Release the waiters even if the flush failed; otherwise
                # every thread in this batch would block forever.
                self.batch_update_done.set()
    finally:
        # Runs on the error path too, so that the last thread out always
        # re-opens the gate for the next batch.
        with self.count_lock:
            self.batch_thread_count += 1
        self.batch_update_done.wait()

        with self.count_lock:
            self.batch_thread_count -= 1
            if self.batch_thread_count == 0:
                self.ready_for_new_batch.set()
def _update_node_tags(self):
    """Flush all pending tag updates as one batch.

    Pending per-node tags are regrouped by (key, value) pair so each
    distinct pair maps to the list of nodes that need it. The tags are
    merged into ``self.tag_cache``, the pending map is reset, and the
    grouped updates are handed to ``self._create_tags``.
    """
    nodes_by_tag = defaultdict(list)
    for node_id, node_tags in self.tag_cache_pending.items():
        for tag_pair in node_tags.items():
            nodes_by_tag[tag_pair].append(node_id)
        self.tag_cache[node_id].update(node_tags)

    # Start a fresh pending map before the grouped updates are sent.
    self.tag_cache_pending = defaultdict(dict)
    self._create_tags(nodes_by_tag)
def _create_tags(self, batch_updates):
    """Issue one EC2 ``create_tags`` call per distinct (key, value) pair.

    Args:
        batch_updates: Mapping from a (key, value) tag pair to the list of
            node ids that should receive that tag.
    """
    for (key, value), node_ids in batch_updates.items():
        description = "Set tag {}={} on {}".format(key, value, node_ids)
        with LogTimer("AWSNodeProvider: {}".format(description)):
            # The Ray node name tag is uploaded under the key "Name"; the
            # log message above still shows the original key.
            aws_key = "Name" if key == TAG_RAY_NODE_NAME else key
            self.ec2.meta.client.create_tags(
                Resources=node_ids,
                Tags=[{
                    "Key": aws_key,
                    "Value": value
                }],
            )
def create_node(self, node_config, tags, count):
tags = copy.deepcopy(tags)
@@ -473,10 +486,6 @@ class AWSNodeProvider(NodeProvider):
return self._get_node(node_id)
def cleanup(self):
    """Set the tag update event and then the tag kill event.

    The tag update loop waits on the first event and returns once it sees
    the second, so setting both lets that loop finish.
    """
    for event in (self.tag_cache_update_event, self.tag_cache_kill_event):
        event.set()
@staticmethod
def bootstrap_config(cluster_config):
    """Return the result of running ``bootstrap_aws`` on ``cluster_config``."""
    bootstrapped = bootstrap_aws(cluster_config)
    return bootstrapped
+7
View File
@@ -122,6 +122,13 @@ py_test(
deps = ["//:ray_lib"],
)
# Small py_test target that runs aws/test_aws_batch_tag_update.py.
py_test(
name = "test_aws_batch_tag_update",
size = "small",
srcs = SRCS + ["aws/test_aws_batch_tag_update.py"],
deps = ["//:ray_lib"],
)
# Note(simon): typing tests are not included in the module list
# because they require globs, and this might be refactored in the future.
py_test(
@@ -0,0 +1,70 @@
import threading
import time
import unittest
from unittest import mock
import pytest
from ray.autoscaler._private.aws.node_provider import AWSNodeProvider
from ray.autoscaler._private.aws.node_provider import TAG_BATCH_DELAY
def mock_create_tags(provider, batch_updates):
    """Stand-in for ``AWSNodeProvider._create_tags`` that only counts.

    Args:
        provider: Object carrying the ``batch_counter`` and
            ``tag_update_counter`` attributes to increment.
        batch_updates: Mapping from a tag pair to the list of node ids
            receiving it.
    """
    # One call corresponds to one batch being sent.
    provider.batch_counter += 1
    # Every node id listed under any tag pair counts as one tag update.
    updated = sum(len(node_ids) for node_ids in batch_updates.values())
    provider.tag_update_counter += updated
def batch_test(num_threads, delay):
    """Call AWSNodeProvider.set_node_tags from several threads.

    Each thread tags a different node id, and ``delay`` seconds elapse
    between consecutive thread launches. The EC2 client factory and
    ``_create_tags`` are patched out, so only counters are updated.

    Returns:
        A ``(batches_sent, tags_updated)`` tuple read from the provider's
        counters after every thread has finished.
    """
    ec2_patch = mock.patch(
        "ray.autoscaler._private.aws.node_provider.make_ec2_client")
    tags_patch = mock.patch.object(AWSNodeProvider, "_create_tags",
                                   mock_create_tags)
    with ec2_patch, tags_patch:
        provider = AWSNodeProvider(
            provider_config={"region": "nowhere"}, cluster_name="default")
        provider.batch_counter = 0
        provider.tag_update_counter = 0
        # Pre-populate the cache with one entry per node id that the
        # threads below will tag.
        provider.tag_cache = {str(i): {} for i in range(num_threads)}

        workers = [
            threading.Thread(
                target=provider.set_node_tags, args=(str(i), {
                    "foo": "bar"
                })) for i in range(num_threads)
        ]

        for worker in workers:
            worker.start()
            time.sleep(delay)
        for worker in workers:
            worker.join()

        return provider.batch_counter, provider.tag_update_counter
class TagBatchTest(unittest.TestCase):
    """Checks how set_node_tags calls are grouped into batches."""

    def test_concurrent(self):
        # Threads launched with no delay should end up in far fewer
        # batches than there are threads.
        thread_total = 100
        batch_count, tag_count = batch_test(thread_total, delay=0)
        self.assertLess(batch_count, thread_total / 10)
        self.assertEqual(tag_count, thread_total)

    def test_serial(self):
        # Threads spaced further apart than the batch delay should each be
        # sent in their own batch.
        thread_total = 5
        spacing = TAG_BATCH_DELAY * 1.2
        batch_count, tag_count = batch_test(thread_total, delay=spacing)
        self.assertEqual(batch_count, thread_total)
        self.assertEqual(tag_count, thread_total)
if __name__ == "__main__":
    # Run this file's tests verbosely under pytest and exit with its status.
    import sys

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)