diff --git a/python/ray/autoscaler/_private/aws/config.py b/python/ray/autoscaler/_private/aws/config.py index 4c3a1c448..2fb90787b 100644 --- a/python/ray/autoscaler/_private/aws/config.py +++ b/python/ray/autoscaler/_private/aws/config.py @@ -496,11 +496,13 @@ def _check_ami(config): # If we do not provide a default AMI for the given region, noop. return - if config["head_node"].get("ImageId", "").lower() == "latest_dlami": + head_ami = config["head_node"].get("ImageId", "").lower() + if head_ami in ["", "latest_dlami"]: config["head_node"]["ImageId"] = default_ami _set_config_info(head_ami_src="dlami") - if config["worker_nodes"].get("ImageId", "").lower() == "latest_dlami": + worker_ami = config["worker_nodes"].get("ImageId", "").lower() + if worker_ami in ["", "latest_dlami"]: config["worker_nodes"]["ImageId"] = default_ami _set_config_info(workers_ami_src="dlami") diff --git a/python/ray/tests/aws/test_autoscaler_aws.py b/python/ray/tests/aws/test_autoscaler_aws.py index 52ceb9fb8..acf6c2d62 100644 --- a/python/ray/tests/aws/test_autoscaler_aws.py +++ b/python/ray/tests/aws/test_autoscaler_aws.py @@ -1,6 +1,8 @@ import pytest -from ray.autoscaler._private.aws.config import _get_vpc_id_or_die +from ray.autoscaler._private.aws.config import _get_vpc_id_or_die, \ + bootstrap_aws, \ + DEFAULT_AMI import ray.tests.aws.utils.stubs as stubs import ray.tests.aws.utils.helpers as helpers from ray.tests.aws.utils.constants import AUX_SUBNET, DEFAULT_SUBNET, \ @@ -133,6 +135,33 @@ def test_subnet_given_head_and_worker_sg(iam_client_stub, ec2_client_stub): ec2_client_stub.assert_no_pending_responses() +def test_fills_out_amis(iam_client_stub, ec2_client_stub): + # Setup stubs to mock out boto3 + stubs.configure_iam_role_default(iam_client_stub) + stubs.configure_key_pair_default(ec2_client_stub) + stubs.describe_a_security_group(ec2_client_stub, DEFAULT_SG) + stubs.configure_subnet_default(ec2_client_stub) + + config = helpers.load_aws_example_config_file("example-full.yaml") + del config["head_node"]["ImageId"] + del config["worker_nodes"]["ImageId"] + + # Pass in SG for stub to work + config["head_node"]["SecurityGroupIds"] = ["sg-1234abcd"] + config["worker_nodes"]["SecurityGroupIds"] = ["sg-1234abcd"] + + defaults_filled = bootstrap_aws(config) + + ami = DEFAULT_AMI.get(config.get("provider", {}).get("region")) + + assert defaults_filled["head_node"].get("ImageId") == ami + + assert defaults_filled["worker_nodes"].get("ImageId") == ami + + iam_client_stub.assert_no_pending_responses() + ec2_client_stub.assert_no_pending_responses() + + if __name__ == "__main__": import sys sys.exit(pytest.main(["-v", __file__]))