diff --git a/python/ray/autoscaler/_private/aws/node_provider.py b/python/ray/autoscaler/_private/aws/node_provider.py index 25f971a2b..3cd7797ed 100644 --- a/python/ray/autoscaler/_private/aws/node_provider.py +++ b/python/ray/autoscaler/_private/aws/node_provider.py @@ -52,7 +52,8 @@ def make_ec2_client(region, max_retries, aws_credentials=None): "ec2", region_name=region, config=config, **aws_credentials) -def list_ec2_instances(region: str) -> List[Dict[str, Any]]: +def list_ec2_instances(region: str, aws_credentials: Dict[str, Any] = None + ) -> List[Dict[str, Any]]: """Get all instance-types/resources available in the user's AWS region. Args: region (str): the region of the AWS provider. e.g., "us-west-2". @@ -68,13 +69,15 @@ def list_ec2_instances(region: str) -> List[Dict[str, Any]]: """ final_instance_types = [] - instance_types = boto3.client( - "ec2", region_name=region).describe_instance_types() + config = Config(retries={"max_attempts": BOTO_MAX_RETRIES}) + aws_credentials = aws_credentials or {} + ec2 = boto3.client( + "ec2", region_name=region, config=config, **aws_credentials) + instance_types = ec2.describe_instance_types() final_instance_types.extend(copy.deepcopy(instance_types["InstanceTypes"])) while "NextToken" in instance_types: - instance_types = boto3.client( - "ec2", region_name=region).describe_instance_types( - NextToken=instance_types["NextToken"]) + instance_types = ec2.describe_instance_types( + NextToken=instance_types["NextToken"]) final_instance_types.extend( copy.deepcopy(instance_types["InstanceTypes"])) @@ -480,7 +483,8 @@ class AWSNodeProvider(NodeProvider): cluster_config = copy.deepcopy(cluster_config) instances_list = list_ec2_instances( - cluster_config["provider"]["region"]) + cluster_config["provider"]["region"], + cluster_config["provider"].get("aws_credentials")) instances_dict = { instance["InstanceType"]: instance for instance in instances_list diff --git a/python/ray/autoscaler/_private/commands.py b/python/ray/autoscaler/_private/commands.py index a88015a94..c70fdf33e 100644 --- a/python/ray/autoscaler/_private/commands.py +++ b/python/ray/autoscaler/_private/commands.py @@ -260,7 +260,6 @@ def _bootstrap_config(config: Dict[str, Any], "This is normal if cluster launcher was updated.\n" "Config will be re-resolved.", config_cache.get("_version", "none"), CONFIG_CACHE_VERSION) - validate_config(config) importer = _NODE_PROVIDERS.get(config["provider"]["type"]) if not importer: @@ -271,6 +270,13 @@ def _bootstrap_config(config: Dict[str, Any], cli_logger.print("Checking {} environment settings", _PROVIDER_PRETTY_NAMES.get(config["provider"]["type"])) + + config = provider_cls.fillout_available_node_types_resources(config) + + # NOTE: if `resources` field is missing, validate_config for non-AWS will + # fail (the schema error will ask the user to manually fill the resources) + # as we currently support autofilling resources for AWS instances only. + validate_config(config) resolved_config = provider_cls.bootstrap_config(config) if not no_config_cache: @@ -291,8 +297,8 @@ def teardown_cluster(config_file: str, yes: bool, workers_only: bool, config = yaml.safe_load(open(config_file).read()) if override_cluster_name is not None: config["cluster_name"] = override_cluster_name - config = prepare_config(config) - validate_config(config) + + config = _bootstrap_config(config) cli_logger.confirm(yes, "Destroying cluster.", _abort=True) diff --git a/python/ray/autoscaler/_private/util.py b/python/ray/autoscaler/_private/util.py index d5fade4e5..b500e846a 100644 --- a/python/ray/autoscaler/_private/util.py +++ b/python/ray/autoscaler/_private/util.py @@ -9,8 +9,7 @@ from typing import Any, Dict import ray import ray._private.services as services -from ray.autoscaler._private.providers import _get_default_config, \ - _NODE_PROVIDERS +from ray.autoscaler._private.providers import _get_default_config from ray.autoscaler._private.docker import validate_docker_config from ray.autoscaler.tags import NODE_TYPE_LEGACY_WORKER, NODE_TYPE_LEGACY_HEAD @@ -131,29 +130,9 @@ def fillout_defaults(config: Dict[str, Any]) -> Dict[str, Any]: defaults.update(config) defaults["auth"] = defaults.get("auth", {}) defaults = rewrite_legacy_yaml_to_available_node_types(defaults) - try: - defaults = _fillout_available_node_types_resources(defaults) - except ValueError: - # When the user uses a wrong instance type. - raise - except Exception: - # When the user is using e.g., staroid, but it is not installed. - logger.exception("Failed to autodetect node resources.") return defaults -def _fillout_available_node_types_resources( - cluster_config: Dict[str, Any]) -> Dict[str, Any]: - """Fills out missing "resources" field for available_node_types.""" - if "available_node_types" in cluster_config: - importer = _NODE_PROVIDERS.get(cluster_config["provider"]["type"]) - if importer is not None: - provider_cls = importer(cluster_config["provider"]) - return provider_cls.fillout_available_node_types_resources( - cluster_config) - return cluster_config - - def merge_setup_commands(config): config["head_setup_commands"] = ( config["setup_commands"] + config["head_setup_commands"]) diff --git a/python/ray/tests/test_autoscaler.py b/python/ray/tests/test_autoscaler.py index 0e3478cde..7a582cbd2 100644 --- a/python/ray/tests/test_autoscaler.py +++ b/python/ray/tests/test_autoscaler.py @@ -1216,9 +1216,6 @@ class AutoscalingTest(unittest.TestCase): "type": "external", "module": "does-not-exist", } - with pytest.raises(ValueError): - invalid_provider = self.write_config( - config, call_prepare_config=True) invalid_provider = self.write_config(config, call_prepare_config=False) with pytest.raises(ValueError): StandardAutoscaler( diff --git a/python/ray/tests/test_autoscaler_yaml.py b/python/ray/tests/test_autoscaler_yaml.py index 235240b12..b712c8955 100644 --- a/python/ray/tests/test_autoscaler_yaml.py +++ b/python/ray/tests/test_autoscaler_yaml.py @@ -10,6 +10,8 @@ from unittest.mock import MagicMock, Mock, patch import pytest from ray.autoscaler._private.util import prepare_config, validate_config +from ray.autoscaler._private.providers import _NODE_PROVIDERS + from ray.test_utils import recursive_fnmatch RAY_PATH = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) @@ -111,12 +113,19 @@ class AutoscalingConfigTest(unittest.TestCase): boto3=boto3_mock, ): new_config = prepare_config(new_config) + importer = _NODE_PROVIDERS.get(new_config["provider"]["type"]) + provider_cls = importer(new_config["provider"]) - try: - validate_config(new_config) - expected_available_node_types == new_config["available_node_types"] - except Exception: - self.fail("Config did not pass multi node types auto fill test!") + try: + new_config = \ + provider_cls.fillout_available_node_types_resources( + new_config) + validate_config(new_config) + expected_available_node_types == new_config[ + "available_node_types"] + except Exception: + self.fail( + "Config did not pass multi node types auto fill test!") def testValidateNetworkConfig(self): web_yaml = "https://raw.githubusercontent.com/ray-project/ray/" \