[autoscaler] Allowing users to provide extra configs for AWS (#7844)

* Allowing users to provide custom key names & security group inbound rules

* linting

* getting aws credentials passed in

* one more thing

* one more thing part 2

* formatting

* addressing comments

* update

* update

* update

* update

* update

* update

* remove tests

* rerun tests

Co-authored-by: Allen Yin <allenyin@anyscale.io>
This commit is contained in:
Allen
2020-04-04 18:36:51 -07:00
committed by GitHub
parent 630b3b1752
commit 3c91ff1f63
4 changed files with 51 additions and 25 deletions
+38 -11
View File
@@ -37,13 +37,21 @@ assert StrictVersion(boto3.__version__) >= StrictVersion("1.4.8"), \
"Boto3 version >= 1.4.8 required, try `pip install -U boto3`"
def key_pair(i, region):
"""Returns the ith default (aws_key_pair_name, key_pair_path)."""
def key_pair(i, region, key_name):
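"""
Returns the ith candidate (aws_key_pair_name, key_pair_path).
If key_name is None, the default name built from RAY and the region is
used. Otherwise the key pair is named after key_name, with a "_key-{i}"
suffix appended when i > 0.
"""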
if i == 0:
return ("{}_{}".format(RAY, region),
os.path.expanduser("~/.ssh/{}_{}.pem".format(RAY, region)))
return ("{}_{}_{}".format(RAY, i, region),
os.path.expanduser("~/.ssh/{}_{}_{}.pem".format(RAY, i, region)))
key_pair_name = ("{}_{}".format(RAY, region)
if key_name is None else key_name)
return (key_pair_name,
os.path.expanduser("~/.ssh/{}.pem".format(key_pair_name)))
key_pair_name = ("{}_{}_{}".format(RAY, i, region)
if key_name is None else key_name + "_key-{}".format(i))
return (key_pair_name,
os.path.expanduser("~/.ssh/{}.pem".format(key_pair_name)))
# Suppress excessive connection dropped logs from boto
@@ -136,7 +144,11 @@ def _configure_key_pair(config):
# Try a few times to get or create a good key pair.
MAX_NUM_KEYS = 30
for i in range(MAX_NUM_KEYS):
key_name, key_path = key_pair(i, config["provider"]["region"])
key_name = config["provider"].get("key_pair", {}).get("key_name")
key_name, key_path = key_pair(i, config["provider"]["region"],
key_name)
key = _get_key(key_name, config)
# Found a good key.
@@ -236,7 +248,7 @@ def _configure_security_group(config):
assert security_group, "Failed to create security group"
if not security_group.ip_permissions:
security_group.authorize_ingress(IpPermissions=[{
IpPermissions = [{
"FromPort": -1,
"ToPort": -1,
"IpProtocol": "-1",
@@ -250,7 +262,13 @@ def _configure_security_group(config):
"IpRanges": [{
"CidrIp": "0.0.0.0/0"
}]
}])
}]
additional_IpPermissions = config["provider"].get(
"security_group", {}).get("IpPermissions", [])
IpPermissions.extend(additional_IpPermissions)
security_group.authorize_ingress(IpPermissions=IpPermissions)
if "SecurityGroupIds" not in config["head_node"]:
logger.info(
@@ -359,10 +377,19 @@ def _get_key(key_name, config):
def _client(name, config):
boto_config = Config(retries={"max_attempts": BOTO_MAX_RETRIES})
return boto3.client(name, config["provider"]["region"], config=boto_config)
aws_credentials = config["provider"].get("aws_credentials", {})
return boto3.client(
name,
config["provider"]["region"],
config=boto_config,
**aws_credentials)
def _resource(name, config):
boto_config = Config(retries={"max_attempts": BOTO_MAX_RETRIES})
aws_credentials = config["provider"].get("aws_credentials", {})
return boto3.resource(
name, config["provider"]["region"], config=boto_config)
name,
config["provider"]["region"],
config=boto_config,
**aws_credentials)
+12 -4
View File
@@ -34,10 +34,12 @@ def from_aws_format(tags):
return tags
def make_ec2_client(region, max_retries):
def make_ec2_client(region, max_retries, aws_credentials=None):
"""Make client, retrying requests up to `max_retries`."""
config = Config(retries={"max_attempts": max_retries})
return boto3.resource("ec2", region_name=region, config=config)
aws_credentials = aws_credentials or {}
return boto3.resource(
"ec2", region_name=region, config=config, **aws_credentials)
class AWSNodeProvider(NodeProvider):
@@ -45,10 +47,16 @@ class AWSNodeProvider(NodeProvider):
NodeProvider.__init__(self, provider_config, cluster_name)
self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes",
True)
aws_credentials = provider_config.get("aws_credentials")
self.ec2 = make_ec2_client(
region=provider_config["region"], max_retries=BOTO_MAX_RETRIES)
region=provider_config["region"],
max_retries=BOTO_MAX_RETRIES,
aws_credentials=aws_credentials)
self.ec2_fail_fast = make_ec2_client(
region=provider_config["region"], max_retries=0)
region=provider_config["region"],
max_retries=0,
aws_credentials=aws_credentials)
# Try availability zones round-robin, starting from random offset
self.subnet_idx = random.randint(0, 100)
+1 -5
View File
@@ -58,7 +58,7 @@
"type": "object",
"description": "Cloud-provider specific configuration.",
"required": [ "type" ],
"additionalProperties": false,
"additionalProperties": true,
"properties": {
"type": {
"type": "string",
@@ -128,10 +128,6 @@
"type": "object",
"description": "k8s autoscaler permissions, if using k8s"
},
"extra_config": {
"type": "object",
"description": "provider-specific config"
},
"cache_stopped_nodes": {
"type": "boolean",
"description": "Whether to try to reuse previously stopped nodes instead of launching new nodes. This will also cause the autoscaler to stop nodes instead of terminating them. Only implemented for AWS."
-5
View File
@@ -328,11 +328,6 @@ class AutoscalingTest(unittest.TestCase):
validate_config(config)
del config["blah"]
config["provider"]["blah"] = "blah"
with pytest.raises(ValidationError):
validate_config(config)
del config["provider"]["blah"]
del config["provider"]
with pytest.raises(ValidationError):
validate_config(config)