From 7b4a97c6100ba5a05e07d9b3712ccedd9f734e82 Mon Sep 17 00:00:00 2001 From: Dmitri Gekhtman <62982571+DmitriGekhtman@users.noreply.github.com> Date: Tue, 19 Jan 2021 12:17:46 -0800 Subject: [PATCH] Make AWSNodeProvider.create_node return nodes created (#13498) * Make AWSNodeProvider.create_node return node config * return-dict * Node provider interface create node return type Any * Type clarification. * Delete debug code * Oops reset example-full changes * Return type specified. GCP create node returns None. * Article --- .../autoscaler/_private/aws/node_provider.py | 21 +++++++++++++++++-- .../autoscaler/_private/gcp/node_provider.py | 10 +++------ python/ray/autoscaler/node_provider.py | 7 +++++-- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/python/ray/autoscaler/_private/aws/node_provider.py b/python/ray/autoscaler/_private/aws/node_provider.py index 3cd7797ed..0eed4a7da 100644 --- a/python/ray/autoscaler/_private/aws/node_provider.py +++ b/python/ray/autoscaler/_private/aws/node_provider.py @@ -239,8 +239,15 @@ class AWSNodeProvider(NodeProvider): }], ) - def create_node(self, node_config, tags, count): + def create_node(self, node_config, tags, count) -> Dict[str, Any]: + """Creates instances. + + Returns dict mapping instance id to ec2.Instance object for the created + instances. + """ tags = copy.deepcopy(tags) + + reused_nodes_dict = {} # Try to reuse previously stopped nodes with compatible configs if self.cache_stopped_nodes: # TODO(ekl) this is breaking the abstraction boundary a little by @@ -273,6 +280,7 @@ class AWSNodeProvider(NodeProvider): reuse_nodes = list( self.ec2.instances.filter(Filters=filters))[:count] reuse_node_ids = [n.id for n in reuse_nodes] + reused_nodes_dict = {n.id: n for n in reuse_nodes} if reuse_nodes: cli_logger.print( # todo: handle plural vs singular? @@ -298,10 +306,17 @@ class AWSNodeProvider(NodeProvider): self.set_node_tags(node_id, tags) count -= len(reuse_node_ids) + created_nodes_dict = {} if count: - self._create_node(node_config, tags, count) + created_nodes_dict = self._create_node(node_config, tags, count) + + all_created_nodes = reused_nodes_dict + all_created_nodes.update(created_nodes_dict) + return all_created_nodes def _create_node(self, node_config, tags, count): + created_nodes_dict = {} + tags = to_aws_format(tags) conf = node_config.copy() @@ -353,6 +368,7 @@ class AWSNodeProvider(NodeProvider): "TagSpecifications": tag_specs }) created = self.ec2_fail_fast.create_instances(**conf) + created_nodes_dict = {n.id: n for n in created} # todo: timed? # todo: handle plurality? @@ -390,6 +406,7 @@ class AWSNodeProvider(NodeProvider): cli_logger.print( "create_instances: Attempt failed with {}, retrying.", exc) + return created_nodes_dict def terminate_node(self, node_id): node = self._get_cached_node(node_id) diff --git a/python/ray/autoscaler/_private/gcp/node_provider.py b/python/ray/autoscaler/_private/gcp/node_provider.py index 79da7ab3e..853a5b63c 100644 --- a/python/ray/autoscaler/_private/gcp/node_provider.py +++ b/python/ray/autoscaler/_private/gcp/node_provider.py @@ -7,7 +7,7 @@ from ray.autoscaler.node_provider import NodeProvider from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME from ray.autoscaler._private.gcp.config import bootstrap_gcp from ray.autoscaler._private.gcp.config import MAX_POLLS, POLL_INTERVAL, \ - construct_clients_from_provider_config + construct_clients_from_provider_config logger = logging.getLogger(__name__) @@ -158,7 +158,7 @@ class GCPNodeProvider(NodeProvider): return ip - def create_node(self, base_config, tags, count): + def create_node(self, base_config, tags, count) -> None: with self.lock: labels = tags # gcp uses "labels" instead of aws "tags" project_id = self.provider_config["project_id"] @@ -195,13 +195,9 @@ class GCPNodeProvider(NodeProvider): })).execute() for i in range(count) ] - results = [ + for operation in operations: wait_for_compute_zone_operation(self.compute, project_id, operation, availability_zone) - for operation in operations - ] - - return results def terminate_node(self, node_id): with self.lock: diff --git a/python/ray/autoscaler/node_provider.py b/python/ray/autoscaler/node_provider.py index 3f1af0ada..8ac3c1233 100644 --- a/python/ray/autoscaler/node_provider.py +++ b/python/ray/autoscaler/node_provider.py @@ -110,8 +110,11 @@ class NodeProvider: return find_node_id() def create_node(self, node_config: Dict[str, Any], tags: Dict[str, str], - count: int) -> None: - """Creates a number of nodes within the namespace.""" + count: int) -> Optional[Dict[str, Any]]: + """Creates a number of nodes within the namespace. + + Optionally returns a mapping from created node ids to node metadata. + """ raise NotImplementedError def set_node_tags(self, node_id: str, tags: Dict[str, str]) -> None: