mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 01:16:06 +08:00
[autoscaler] Try making GCP node provider thread-safe
This commit is contained in:
@@ -3,6 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from uuid import uuid4
|
||||
from threading import RLock
|
||||
import time
|
||||
import logging
|
||||
|
||||
@@ -45,6 +46,7 @@ class GCPNodeProvider(NodeProvider):
|
||||
def __init__(self, provider_config, cluster_name):
|
||||
NodeProvider.__init__(self, provider_config, cluster_name)
|
||||
|
||||
self.lock = RLock()
|
||||
self.compute = discovery.build("compute", "v1")
|
||||
|
||||
# Cache of node objects from the last nodes() call. This avoids
|
||||
@@ -52,162 +54,173 @@ class GCPNodeProvider(NodeProvider):
|
||||
self.cached_nodes = {}
|
||||
|
||||
def nodes(self, tag_filters):
    """List the names of this cluster's non-terminated instances.

    Builds a Compute Engine list filter that requires every label in
    ``tag_filters``, one of the statuses PROVISIONING / STAGING / RUNNING,
    and the cluster-name label of this provider, then queries the
    configured project and zone.

    Side effect: ``self.cached_nodes`` is replaced with a mapping of
    instance name to the instance record returned by the API.

    Args:
        tag_filters: mapping of label key to required label value; may
            be empty, in which case no label clause is added.

    Returns:
        A list with the ``name`` of every matching instance.
    """
    # The whole request/cache update runs under the provider lock so that
    # concurrent callers do not interleave API calls and cache writes.
    with self.lock:
        if tag_filters:
            label_filter_expr = "(" + " AND ".join([
                "(labels.{key} = {value})".format(key=key, value=value)
                for key, value in tag_filters.items()
            ]) + ")"
        else:
            label_filter_expr = ""

        # A tuple (not a set) keeps the generated expression stable
        # between runs.
        instance_state_filter_expr = "(" + " OR ".join([
            "(status = {status})".format(status=status)
            for status in ("PROVISIONING", "STAGING", "RUNNING")
        ]) + ")"

        cluster_name_filter_expr = "(labels.{key} = {value})".format(
            key=TAG_RAY_CLUSTER_NAME, value=self.cluster_name)

        # Only the label clause can be empty; drop it when it is.
        not_empty_filters = [
            expr for expr in (
                label_filter_expr,
                instance_state_filter_expr,
                cluster_name_filter_expr,
            ) if expr
        ]
        filter_expr = " AND ".join(not_empty_filters)

        response = self.compute.instances().list(
            project=self.provider_config["project_id"],
            zone=self.provider_config["availability_zone"],
            filter=filter_expr,
        ).execute()

        instances = response.get("items", [])
        # Note: All the operations use "name" as the unique instance
        # identifier.
        self.cached_nodes = {i["name"]: i for i in instances}

        return [i["name"] for i in instances]
|
||||
|
||||
def is_running(self, node_id):
    """Return True when the cached record of ``node_id`` is RUNNING.

    The record is obtained through ``self._get_cached_node`` while
    holding the provider lock.
    """
    with self.lock:
        node = self._get_cached_node(node_id)
        return node["status"] == "RUNNING"
|
||||
|
||||
def is_terminated(self, node_id):
    """Return True when ``node_id`` is in none of the live statuses.

    A node counts as live while its cached status is PROVISIONING,
    STAGING or RUNNING; every other status is reported as terminated.
    The record is read through ``self._get_cached_node`` under the
    provider lock.
    """
    with self.lock:
        node = self._get_cached_node(node_id)
        return node["status"] not in {"PROVISIONING", "STAGING", "RUNNING"}
|
||||
|
||||
def node_tags(self, node_id):
    """Return the labels of ``node_id`` from its cached record.

    Returns an empty dict when the record carries no ``labels`` entry.
    The record is read through ``self._get_cached_node`` under the
    provider lock.
    """
    with self.lock:
        node = self._get_cached_node(node_id)
        return node.get("labels", {})
|
||||
|
||||
def set_node_tags(self, node_id, tags):
    """Merge ``tags`` into the labels of instance ``node_id``.

    The node record is fetched with ``self._get_node``, its existing
    labels are overlaid with ``tags`` and sent back through the
    ``setLabels`` call together with the record's ``labelFingerprint``.
    The call then blocks in ``wait_for_compute_zone_operation``.

    Args:
        node_id: instance name.
        tags: mapping of label key to value; on key collisions these
            values win over the existing labels.

    Returns:
        Whatever ``wait_for_compute_zone_operation`` returns for the
        label operation.
    """
    with self.lock:
        labels = tags  # gcp uses "labels" where other providers use "tags"
        project_id = self.provider_config["project_id"]
        availability_zone = self.provider_config["availability_zone"]

        node = self._get_node(node_id)

        # A record may have no "labels" entry at all (node_tags() already
        # allows for that), so fall back to an empty mapping instead of
        # raising KeyError.
        merged_labels = dict(node.get("labels", {}), **labels)

        operation = self.compute.instances().setLabels(
            project=project_id,
            zone=availability_zone,
            instance=node_id,
            body={
                "labels": merged_labels,
                "labelFingerprint": node["labelFingerprint"]
            }).execute()

        result = wait_for_compute_zone_operation(
            self.compute, project_id, operation, availability_zone)

        return result
|
||||
|
||||
def external_ip(self, node_id):
    """Return the ``natIP`` of the node's first access config, or None.

    The cached record is consulted first; when it yields no address the
    record is fetched again through ``self._get_node`` and re-examined.
    Runs under the provider lock, like the other accessors.
    """

    def get_external_ip(node):
        # ``or [{}]`` also covers an empty list, which would otherwise
        # make the ``[0]`` lookup raise IndexError.
        interfaces = node.get("networkInterfaces") or [{}]
        access_configs = interfaces[0].get("accessConfigs") or [{}]
        return access_configs[0].get("natIP", None)

    with self.lock:
        node = self._get_cached_node(node_id)
        ip = get_external_ip(node)
        if ip is None:
            node = self._get_node(node_id)
            ip = get_external_ip(node)

        return ip
|
||||
|
||||
def internal_ip(self, node_id):
    """Return the ``networkIP`` of the node's first interface, or None.

    The cached record is consulted first; when it yields no address the
    record is fetched again through ``self._get_node`` and re-examined.
    Runs under the provider lock, like the other accessors.
    """

    def get_internal_ip(node):
        # ``or [{}]`` also covers an empty list, which would otherwise
        # make the ``[0]`` lookup raise IndexError.
        interfaces = node.get("networkInterfaces") or [{}]
        return interfaces[0].get("networkIP")

    with self.lock:
        node = self._get_cached_node(node_id)
        ip = get_internal_ip(node)
        if ip is None:
            node = self._get_node(node_id)
            ip = get_internal_ip(node)

        return ip
|
||||
|
||||
def create_node(self, base_config, tags, count):
    """Insert ``count`` instances built from ``base_config``.

    Each instance body is a copy of ``base_config`` with the machine type
    expanded to its zonal path and with labels merged from the config,
    ``tags`` and the cluster-name label. Instance names are the
    ``TAG_RAY_NODE_NAME`` label followed by ``-`` and a uuid prefix.
    All insert requests are issued first, then each operation is awaited.

    Args:
        base_config: instance body template; must contain "machineType".
        tags: labels for the new instances; must contain
            ``TAG_RAY_NODE_NAME``.
        count: number of instances to insert.

    Returns:
        A list with the ``wait_for_compute_zone_operation`` result of
        every insert, in request order.

    Raises:
        AssertionError: when the name label is too long to leave room for
            the separator and uuid suffix.
    """
    with self.lock:
        project_id = self.provider_config["project_id"]
        availability_zone = self.provider_config["availability_zone"]

        config = base_config.copy()

        name_label = tags[TAG_RAY_NODE_NAME]
        assert (len(name_label) <=
                (INSTANCE_NAME_MAX_LEN - INSTANCE_NAME_UUID_LEN - 1)), (
                    name_label, len(name_label))

        machine_type = "zones/{zone}/machineTypes/{machine_type}".format(
            zone=availability_zone,
            machine_type=base_config["machineType"])

        # gcp uses "labels" instead of aws "tags". Merged step by step:
        # several ``**`` unpackings in one call are a syntax error on the
        # older interpreters this file still supports.
        merged_labels = dict(config.get("labels", {}))
        merged_labels.update(tags)
        merged_labels[TAG_RAY_CLUSTER_NAME] = self.cluster_name

        config.update({
            "machineType": machine_type,
            "labels": merged_labels,
        })

        operations = [
            self.compute.instances().insert(
                project=project_id,
                zone=availability_zone,
                body=dict(
                    config, **{
                        "name": ("{name_label}-{uuid}".format(
                            name_label=name_label,
                            uuid=uuid4().hex[:INSTANCE_NAME_UUID_LEN]))
                    })).execute() for _ in range(count)
        ]

        results = [
            wait_for_compute_zone_operation(self.compute, project_id,
                                            operation, availability_zone)
            for operation in operations
        ]

        return results
|
||||
|
||||
def external_ip(self, node_id):
    """Return the ``natIP`` of the node's first access config, or None.

    The cached record is consulted first; when it yields no address the
    record is fetched again through ``self._get_node`` and re-examined.
    The whole lookup runs under the provider lock.
    """

    def get_external_ip(node):
        # ``or [{}]`` also covers an empty list, which would otherwise
        # make the ``[0]`` lookup raise IndexError.
        interfaces = node.get("networkInterfaces") or [{}]
        access_configs = interfaces[0].get("accessConfigs") or [{}]
        return access_configs[0].get("natIP", None)

    with self.lock:
        node = self._get_cached_node(node_id)
        ip = get_external_ip(node)
        if ip is None:
            node = self._get_node(node_id)
            ip = get_external_ip(node)

        return ip
|
||||
|
||||
def internal_ip(self, node_id):
    """Return the ``networkIP`` of the node's first interface, or None.

    The cached record is consulted first; when it yields no address the
    record is fetched again through ``self._get_node`` and re-examined.
    The whole lookup runs under the provider lock.
    """

    def get_internal_ip(node):
        # ``or [{}]`` also covers an empty list, which would otherwise
        # make the ``[0]`` lookup raise IndexError.
        interfaces = node.get("networkInterfaces") or [{}]
        return interfaces[0].get("networkIP")

    with self.lock:
        node = self._get_cached_node(node_id)
        ip = get_internal_ip(node)
        if ip is None:
            node = self._get_node(node_id)
            ip = get_internal_ip(node)

        return ip
|
||||
|
||||
def create_node(self, base_config, tags, count):
    """Insert ``count`` instances built from ``base_config``.

    Each instance body is a copy of ``base_config`` with the machine type
    expanded to its zonal path and with labels merged from the config,
    ``tags`` and the cluster-name label. Instance names are the
    ``TAG_RAY_NODE_NAME`` label followed by ``-`` and a uuid prefix.
    All insert requests are issued first, then each operation is awaited;
    the provider lock is held throughout.

    Args:
        base_config: instance body template; must contain "machineType".
        tags: labels for the new instances; must contain
            ``TAG_RAY_NODE_NAME``.
        count: number of instances to insert.

    Returns:
        A list with the ``wait_for_compute_zone_operation`` result of
        every insert, in request order.

    Raises:
        AssertionError: when the name label is too long to leave room for
            the separator and uuid suffix.
    """
    with self.lock:
        project_id = self.provider_config["project_id"]
        availability_zone = self.provider_config["availability_zone"]

        config = base_config.copy()

        name_label = tags[TAG_RAY_NODE_NAME]
        assert (len(name_label) <=
                (INSTANCE_NAME_MAX_LEN - INSTANCE_NAME_UUID_LEN - 1)), (
                    name_label, len(name_label))

        machine_type = "zones/{zone}/machineTypes/{machine_type}".format(
            zone=availability_zone,
            machine_type=base_config["machineType"])

        # gcp uses "labels" instead of aws "tags". Precedence, lowest to
        # highest: labels already in the config, the caller's tags, the
        # cluster-name label.
        merged_labels = dict(config.get("labels", {}))
        merged_labels.update(tags)
        merged_labels[TAG_RAY_CLUSTER_NAME] = self.cluster_name

        config.update({
            "machineType": machine_type,
            "labels": merged_labels,
        })

        operations = [
            self.compute.instances().insert(
                project=project_id,
                zone=availability_zone,
                body=dict(
                    config, **{
                        "name": ("{name_label}-{uuid}".format(
                            name_label=name_label,
                            uuid=uuid4().hex[:INSTANCE_NAME_UUID_LEN]))
                    })).execute() for _ in range(count)
        ]

        results = [
            wait_for_compute_zone_operation(self.compute, project_id,
                                            operation, availability_zone)
            for operation in operations
        ]

        return results
|
||||
|
||||
def terminate_node(self, node_id):
    """Delete instance ``node_id`` and wait for the operation.

    Issues a single ``instances().delete`` request for the configured
    project and zone, then blocks in ``wait_for_compute_zone_operation``.
    The provider lock is held for the whole call.

    Returns:
        Whatever ``wait_for_compute_zone_operation`` returns for the
        delete operation.
    """
    with self.lock:
        project_id = self.provider_config["project_id"]
        availability_zone = self.provider_config["availability_zone"]

        operation = self.compute.instances().delete(
            project=project_id,
            zone=availability_zone,
            instance=node_id,
        ).execute()

        result = wait_for_compute_zone_operation(
            self.compute, project_id, operation, availability_zone)

        return result
|
||||
|
||||
def _get_node(self, node_id):
|
||||
self.nodes({}) # Side effect: fetches and caches the node.
|
||||
|
||||
Reference in New Issue
Block a user