From cf1c7378957282ebc45bca71b3f1c12139c1fb36 Mon Sep 17 00:00:00 2001 From: Gekho457 <62982571+Gekho457@users.noreply.github.com> Date: Wed, 21 Oct 2020 21:42:38 -0400 Subject: [PATCH] [autoscaler/AWS] Updated AWS Node Provider threading logic (#11422) --- .../autoscaler/_private/aws/node_provider.py | 115 ++++++++++-------- python/ray/tests/BUILD | 7 ++ .../tests/aws/test_aws_batch_tag_update.py | 70 +++++++++++ 3 files changed, 139 insertions(+), 53 deletions(-) create mode 100644 python/ray/tests/aws/test_aws_batch_tag_update.py diff --git a/python/ray/autoscaler/_private/aws/node_provider.py b/python/ray/autoscaler/_private/aws/node_provider.py index 19b55d625..e7e318b40 100644 --- a/python/ray/autoscaler/_private/aws/node_provider.py +++ b/python/ray/autoscaler/_private/aws/node_provider.py @@ -3,6 +3,7 @@ import copy import threading from collections import defaultdict import logging +import time from typing import Any, Dict import boto3 @@ -22,6 +23,8 @@ from ray.autoscaler._private.cli_logger import cli_logger, cf logger = logging.getLogger(__name__) +TAG_BATCH_DELAY = 1 + def to_aws_format(tags): """Convert the Ray node name tag to the AWS-specific 'Name' tag.""" @@ -68,56 +71,23 @@ class AWSNodeProvider(NodeProvider): # Try availability zones round-robin, starting from random offset self.subnet_idx = random.randint(0, 100) - self.tag_cache = {} # Tags that we believe to actually be on EC2. - self.tag_cache_pending = {} # Tags that we will soon upload. + # Tags that we believe to actually be on EC2. + self.tag_cache = {} + # Tags that we will soon upload. + self.tag_cache_pending = defaultdict(dict) + # Number of threads waiting for a batched tag update. + self.batch_thread_count = 0 + self.batch_update_done = threading.Event() + self.batch_update_done.set() + self.ready_for_new_batch = threading.Event() + self.ready_for_new_batch.set() self.tag_cache_lock = threading.Lock() - self.tag_cache_update_event = threading.Event() - self.tag_cache_kill_event = threading.Event() - self.tag_update_thread = threading.Thread( - target=self._node_tag_update_loop) - self.tag_update_thread.start() + self.count_lock = threading.Lock() # Cache of node objects from the last nodes() call. This avoids # excessive DescribeInstances requests. self.cached_nodes = {} - def _node_tag_update_loop(self): - """Update the AWS tags for a cluster periodically. - - The purpose of this loop is to avoid excessive EC2 calls when a large - number of nodes are being launched simultaneously. - """ - while True: - self.tag_cache_update_event.wait() - self.tag_cache_update_event.clear() - - batch_updates = defaultdict(list) - - with self.tag_cache_lock: - for node_id, tags in self.tag_cache_pending.items(): - for x in tags.items(): - batch_updates[x].append(node_id) - self.tag_cache[node_id].update(tags) - - self.tag_cache_pending = {} - - for (k, v), node_ids in batch_updates.items(): - m = "Set tag {}={} on {}".format(k, v, node_ids) - with LogTimer("AWSNodeProvider: {}".format(m)): - if k == TAG_RAY_NODE_NAME: - k = "Name" - self.ec2.meta.client.create_tags( - Resources=node_ids, - Tags=[{ - "Key": k, - "Value": v - }], - ) - - self.tag_cache_kill_event.wait(timeout=5) - if self.tag_cache_kill_event.is_set(): - return - def non_terminated_nodes(self, tag_filters): # Note that these filters are acceptable because they are set on # node initialization, and so can never be sitting in the cache. @@ -186,13 +156,56 @@ class AWSNodeProvider(NodeProvider): return node.private_ip_address def set_node_tags(self, node_id, tags): + is_batching_thread = False with self.tag_cache_lock: - try: - self.tag_cache_pending[node_id].update(tags) - except KeyError: - self.tag_cache_pending[node_id] = tags + if not self.tag_cache_pending: + is_batching_thread = True + # Wait for threads in the last batch to exit + self.ready_for_new_batch.wait() + self.ready_for_new_batch.clear() + self.batch_update_done.clear() + self.tag_cache_pending[node_id].update(tags) - self.tag_cache_update_event.set() + if is_batching_thread: + time.sleep(TAG_BATCH_DELAY) + with self.tag_cache_lock: + self._update_node_tags() + self.batch_update_done.set() + + with self.count_lock: + self.batch_thread_count += 1 + self.batch_update_done.wait() + + with self.count_lock: + self.batch_thread_count -= 1 + if self.batch_thread_count == 0: + self.ready_for_new_batch.set() + + def _update_node_tags(self): + batch_updates = defaultdict(list) + + for node_id, tags in self.tag_cache_pending.items(): + for x in tags.items(): + batch_updates[x].append(node_id) + self.tag_cache[node_id].update(tags) + + self.tag_cache_pending = defaultdict(dict) + + self._create_tags(batch_updates) + + def _create_tags(self, batch_updates): + for (k, v), node_ids in batch_updates.items(): + m = "Set tag {}={} on {}".format(k, v, node_ids) + with LogTimer("AWSNodeProvider: {}".format(m)): + if k == TAG_RAY_NODE_NAME: + k = "Name" + self.ec2.meta.client.create_tags( + Resources=node_ids, + Tags=[{ + "Key": k, + "Value": v + }], + ) def create_node(self, node_config, tags, count): tags = copy.deepcopy(tags) @@ -473,10 +486,6 @@ class AWSNodeProvider(NodeProvider): return self._get_node(node_id) - def cleanup(self): - self.tag_cache_update_event.set() - self.tag_cache_kill_event.set() - @staticmethod def bootstrap_config(cluster_config): return bootstrap_aws(cluster_config) diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 9fc69cfd2..ff7c2ab43 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -122,6 +122,13 @@ py_test( deps = ["//:ray_lib"], ) +py_test( + name = "test_aws_batch_tag_update", + size = "small", + srcs = SRCS + ["aws/test_aws_batch_tag_update.py"], + deps = ["//:ray_lib"], +) + # Note(simon): typing tests are not included in module list # because they requires globs and it might be refactored in the future. py_test( diff --git a/python/ray/tests/aws/test_aws_batch_tag_update.py b/python/ray/tests/aws/test_aws_batch_tag_update.py new file mode 100644 index 000000000..228e1c1b0 --- /dev/null +++ b/python/ray/tests/aws/test_aws_batch_tag_update.py @@ -0,0 +1,70 @@ +import threading +import time +import unittest +from unittest import mock + +import pytest + +from ray.autoscaler._private.aws.node_provider import AWSNodeProvider +from ray.autoscaler._private.aws.node_provider import TAG_BATCH_DELAY + + +def mock_create_tags(provider, batch_updates): + # Increment batches sent. + provider.batch_counter += 1 + # Increment tags updated. + provider.tag_update_counter += sum( + len(batch_updates[x]) for x in batch_updates) + + +def batch_test(num_threads, delay): + """Run AWSNodeProvider.set_node_tags in several threads, with a + specified delay between thread launches. + + Return the number of batches of tag updates and the number of tags + updated. + """ + with mock.patch("ray.autoscaler._private.aws.node_provider.make_ec2_client" + ), mock.patch.object(AWSNodeProvider, "_create_tags", + mock_create_tags): + provider = AWSNodeProvider( + provider_config={"region": "nowhere"}, cluster_name="default") + provider.batch_counter = 0 + provider.tag_update_counter = 0 + provider.tag_cache = {str(x): {} for x in range(num_threads)} + + threads = [] + for x in range(num_threads): + thread = threading.Thread( + target=provider.set_node_tags, args=(str(x), { + "foo": "bar" + })) + threads.append(thread) + + for thread in threads: + thread.start() + time.sleep(delay) + for thread in threads: + thread.join() + + return provider.batch_counter, provider.tag_update_counter + + +class TagBatchTest(unittest.TestCase): + def test_concurrent(self): + num_threads = 100 + batches_sent, tags_updated = batch_test(num_threads, delay=0) + self.assertLess(batches_sent, num_threads / 10) + self.assertEqual(tags_updated, num_threads) + + def test_serial(self): + num_threads = 5 + long_delay = TAG_BATCH_DELAY * 1.2 + batches_sent, tags_updated = batch_test(num_threads, delay=long_delay) + self.assertEqual(batches_sent, num_threads) + self.assertEqual(tags_updated, num_threads) + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main(["-v", __file__]))