diff --git a/python/ray/autoscaler/autoscaler.py b/python/ray/autoscaler/autoscaler.py index 2a3734ba4..80245881c 100644 --- a/python/ray/autoscaler/autoscaler.py +++ b/python/ray/autoscaler/autoscaler.py @@ -69,6 +69,7 @@ CLUSTER_CONFIG_SCHEMA = { "project_id": (None, OPTIONAL), # gcp project id, if using gcp "head_ip": (str, OPTIONAL), # local cluster head node "worker_ips": (list, OPTIONAL), # local cluster worker nodes + "use_internal_ips": (bool, OPTIONAL), # don't require public ips }, REQUIRED), diff --git a/python/ray/autoscaler/aws/config.py b/python/ray/autoscaler/aws/config.py index 249970e8f..b6953c149 100644 --- a/python/ray/autoscaler/aws/config.py +++ b/python/ray/autoscaler/aws/config.py @@ -152,9 +152,10 @@ def _configure_key_pair(config): def _configure_subnet(config): ec2 = _resource("ec2", config) + use_internal_ips = config["provider"].get("use_internal_ips", False) subnets = sorted( - (s for s in ec2.subnets.all() - if s.state == "available" and s.map_public_ip_on_launch), + (s for s in ec2.subnets.all() if s.state == "available" and ( + use_internal_ips or s.map_public_ip_on_launch)), reverse=True, # sort from Z-A key=lambda subnet: subnet.availability_zone) if not subnets: @@ -162,7 +163,8 @@ def _configure_subnet(config): "No usable subnets found, try manually creating an instance in " "your specified region to populate the list of subnets " "and trying this again. Note that the subnet must map public IPs " - "on instance launch.") + "on instance launch unless you set 'use_internal_ips': True in " + "the 'provider' config.") if "availability_zone" in config["provider"]: azs = config["provider"]["availability_zone"].split(',') subnets = [s for s in subnets if s.availability_zone in azs] diff --git a/python/ray/autoscaler/updater.py b/python/ray/autoscaler/updater.py index b132971d2..1d6b5e23b 100644 --- a/python/ray/autoscaler/updater.py +++ b/python/ray/autoscaler/updater.py @@ -47,7 +47,8 @@ class NodeUpdater(object): self.daemon = True self.process_runner = process_runner self.node_id = node_id - self.use_internal_ip = use_internal_ip + self.use_internal_ip = (use_internal_ip or provider_config.get( + "use_internal_ips", False)) self.provider = get_node_provider(provider_config, cluster_name) self.ssh_private_key = auth_config["ssh_private_key"] self.ssh_user = auth_config["ssh_user"]