mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 18:06:25 +08:00
Make AWSNodeProvider.create_node return nodes created (#13498)
* Make AWSNodeProvider.create_node return node config
* return-dict
* Node provider interface create node return type Any
* Type clarification.
* Delete debug code
* Oops reset example-full changes
* Return type specified. GCP create node returns None.
* Article
This commit is contained in:
@@ -239,8 +239,15 @@ class AWSNodeProvider(NodeProvider):
|
||||
}],
|
||||
)
|
||||
|
||||
def create_node(self, node_config, tags, count):
|
||||
def create_node(self, node_config, tags, count) -> Dict[str, Any]:
|
||||
"""Creates instances.
|
||||
|
||||
Returns dict mapping instance id to ec2.Instance object for the created
|
||||
instances.
|
||||
"""
|
||||
tags = copy.deepcopy(tags)
|
||||
|
||||
reused_nodes_dict = {}
|
||||
# Try to reuse previously stopped nodes with compatible configs
|
||||
if self.cache_stopped_nodes:
|
||||
# TODO(ekl) this is breaking the abstraction boundary a little by
|
||||
@@ -273,6 +280,7 @@ class AWSNodeProvider(NodeProvider):
|
||||
reuse_nodes = list(
|
||||
self.ec2.instances.filter(Filters=filters))[:count]
|
||||
reuse_node_ids = [n.id for n in reuse_nodes]
|
||||
reused_nodes_dict = {n.id: n for n in reuse_nodes}
|
||||
if reuse_nodes:
|
||||
cli_logger.print(
|
||||
# todo: handle plural vs singular?
|
||||
@@ -298,10 +306,17 @@ class AWSNodeProvider(NodeProvider):
|
||||
self.set_node_tags(node_id, tags)
|
||||
count -= len(reuse_node_ids)
|
||||
|
||||
created_nodes_dict = {}
|
||||
if count:
|
||||
self._create_node(node_config, tags, count)
|
||||
created_nodes_dict = self._create_node(node_config, tags, count)
|
||||
|
||||
all_created_nodes = reused_nodes_dict
|
||||
all_created_nodes.update(created_nodes_dict)
|
||||
return all_created_nodes
|
||||
|
||||
def _create_node(self, node_config, tags, count):
|
||||
created_nodes_dict = {}
|
||||
|
||||
tags = to_aws_format(tags)
|
||||
conf = node_config.copy()
|
||||
|
||||
@@ -353,6 +368,7 @@ class AWSNodeProvider(NodeProvider):
|
||||
"TagSpecifications": tag_specs
|
||||
})
|
||||
created = self.ec2_fail_fast.create_instances(**conf)
|
||||
created_nodes_dict = {n.id: n for n in created}
|
||||
|
||||
# todo: timed?
|
||||
# todo: handle plurality?
|
||||
@@ -390,6 +406,7 @@ class AWSNodeProvider(NodeProvider):
|
||||
cli_logger.print(
|
||||
"create_instances: Attempt failed with {}, retrying.",
|
||||
exc)
|
||||
return created_nodes_dict
|
||||
|
||||
def terminate_node(self, node_id):
|
||||
node = self._get_cached_node(node_id)
|
||||
|
||||
@@ -7,7 +7,7 @@ from ray.autoscaler.node_provider import NodeProvider
|
||||
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
|
||||
from ray.autoscaler._private.gcp.config import bootstrap_gcp
|
||||
from ray.autoscaler._private.gcp.config import MAX_POLLS, POLL_INTERVAL, \
|
||||
construct_clients_from_provider_config
|
||||
construct_clients_from_provider_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -158,7 +158,7 @@ class GCPNodeProvider(NodeProvider):
|
||||
|
||||
return ip
|
||||
|
||||
def create_node(self, base_config, tags, count):
|
||||
def create_node(self, base_config, tags, count) -> None:
|
||||
with self.lock:
|
||||
labels = tags # gcp uses "labels" instead of aws "tags"
|
||||
project_id = self.provider_config["project_id"]
|
||||
@@ -195,13 +195,9 @@ class GCPNodeProvider(NodeProvider):
|
||||
})).execute() for i in range(count)
|
||||
]
|
||||
|
||||
results = [
|
||||
for operation in operations:
|
||||
wait_for_compute_zone_operation(self.compute, project_id,
|
||||
operation, availability_zone)
|
||||
for operation in operations
|
||||
]
|
||||
|
||||
return results
|
||||
|
||||
def terminate_node(self, node_id):
|
||||
with self.lock:
|
||||
|
||||
@@ -110,8 +110,11 @@ class NodeProvider:
|
||||
return find_node_id()
|
||||
|
||||
def create_node(self, node_config: Dict[str, Any], tags: Dict[str, str],
|
||||
count: int) -> None:
|
||||
"""Creates a number of nodes within the namespace."""
|
||||
count: int) -> Optional[Dict[str, Any]]:
|
||||
"""Creates a number of nodes within the namespace.
|
||||
|
||||
Optionally returns a mapping from created node ids to node metadata.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def set_node_tags(self, node_id: str, tags: Dict[str, str]) -> None:
|
||||
|
||||
Reference in New Issue
Block a user