[autoscaler] Try making GCP node provider thread-safe

This commit is contained in:
Eric Liang
2019-02-20 16:35:20 -08:00
committed by GitHub
parent a99676e39b
commit e3066d1fa5
+143 -130
View File
@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function
from uuid import uuid4
from threading import RLock
import time
import logging
@@ -45,6 +46,7 @@ class GCPNodeProvider(NodeProvider):
def __init__(self, provider_config, cluster_name):
NodeProvider.__init__(self, provider_config, cluster_name)
self.lock = RLock()
self.compute = discovery.build("compute", "v1")
# Cache of node objects from the last nodes() call. This avoids
@@ -52,162 +54,173 @@ class GCPNodeProvider(NodeProvider):
self.cached_nodes = {}
def nodes(self, tag_filters):
if tag_filters:
label_filter_expr = "(" + " AND ".join([
"(labels.{key} = {value})".format(key=key, value=value)
for key, value in tag_filters.items()
with self.lock:
if tag_filters:
label_filter_expr = "(" + " AND ".join([
"(labels.{key} = {value})".format(key=key, value=value)
for key, value in tag_filters.items()
]) + ")"
else:
label_filter_expr = ""
instance_state_filter_expr = "(" + " OR ".join([
"(status = {status})".format(status=status)
for status in {"PROVISIONING", "STAGING", "RUNNING"}
]) + ")"
else:
label_filter_expr = ""
instance_state_filter_expr = "(" + " OR ".join([
"(status = {status})".format(status=status)
for status in {"PROVISIONING", "STAGING", "RUNNING"}
]) + ")"
cluster_name_filter_expr = ("(labels.{key} = {value})"
"".format(
key=TAG_RAY_CLUSTER_NAME,
value=self.cluster_name))
cluster_name_filter_expr = ("(labels.{key} = {value})"
"".format(
key=TAG_RAY_CLUSTER_NAME,
value=self.cluster_name))
not_empty_filters = [
f for f in [
label_filter_expr,
instance_state_filter_expr,
cluster_name_filter_expr,
] if f
]
not_empty_filters = [
f for f in [
label_filter_expr,
instance_state_filter_expr,
cluster_name_filter_expr,
] if f
]
filter_expr = " AND ".join(not_empty_filters)
filter_expr = " AND ".join(not_empty_filters)
response = self.compute.instances().list(
project=self.provider_config["project_id"],
zone=self.provider_config["availability_zone"],
filter=filter_expr,
).execute()
response = self.compute.instances().list(
project=self.provider_config["project_id"],
zone=self.provider_config["availability_zone"],
filter=filter_expr,
).execute()
instances = response.get("items", [])
# Note: All the operations use "name" as the unique instance id
self.cached_nodes = {i["name"]: i for i in instances}
instances = response.get("items", [])
# Note: All the operations use "name" as the unique instance identifier
self.cached_nodes = {i["name"]: i for i in instances}
return [i["name"] for i in instances]
return [i["name"] for i in instances]
def is_running(self, node_id):
node = self._get_cached_node(node_id)
return node["status"] == "RUNNING"
with self.lock:
node = self._get_cached_node(node_id)
return node["status"] == "RUNNING"
def is_terminated(self, node_id):
node = self._get_cached_node(node_id)
return node["status"] not in {"PROVISIONING", "STAGING", "RUNNING"}
with self.lock:
node = self._get_cached_node(node_id)
return node["status"] not in {"PROVISIONING", "STAGING", "RUNNING"}
def node_tags(self, node_id):
node = self._get_cached_node(node_id)
labels = node.get("labels", {})
return labels
with self.lock:
node = self._get_cached_node(node_id)
labels = node.get("labels", {})
return labels
def set_node_tags(self, node_id, tags):
labels = tags
project_id = self.provider_config["project_id"]
availability_zone = self.provider_config["availability_zone"]
with self.lock:
labels = tags
project_id = self.provider_config["project_id"]
availability_zone = self.provider_config["availability_zone"]
node = self._get_node(node_id)
operation = self.compute.instances().setLabels(
project=project_id,
zone=availability_zone,
instance=node_id,
body={
"labels": dict(node["labels"], **labels),
"labelFingerprint": node["labelFingerprint"]
}).execute()
result = wait_for_compute_zone_operation(self.compute, project_id,
operation, availability_zone)
return result
def external_ip(self, node_id):
node = self._get_cached_node(node_id)
def get_external_ip(node):
return node.get("networkInterfaces", [{}])[0].get(
"accessConfigs", [{}])[0].get("natIP", None)
ip = get_external_ip(node)
if ip is None:
node = self._get_node(node_id)
ip = get_external_ip(node)
return ip
def internal_ip(self, node_id):
node = self._get_cached_node(node_id)
def get_internal_ip(node):
return node.get("networkInterfaces", [{}])[0].get("networkIP")
ip = get_internal_ip(node)
if ip is None:
node = self._get_node(node_id)
ip = get_internal_ip(node)
return ip
def create_node(self, base_config, tags, count):
labels = tags # gcp uses "labels" instead of aws "tags"
project_id = self.provider_config["project_id"]
availability_zone = self.provider_config["availability_zone"]
config = base_config.copy()
name_label = labels[TAG_RAY_NODE_NAME]
assert (len(name_label) <=
(INSTANCE_NAME_MAX_LEN - INSTANCE_NAME_UUID_LEN - 1)), (
name_label, len(name_label))
config.update({
"machineType": ("zones/{zone}/machineTypes/{machine_type}"
"".format(
zone=availability_zone,
machine_type=base_config["machineType"])),
"labels": dict(
config.get("labels", {}), **labels,
**{TAG_RAY_CLUSTER_NAME: self.cluster_name}),
})
operations = [
self.compute.instances().insert(
operation = self.compute.instances().setLabels(
project=project_id,
zone=availability_zone,
body=dict(
config, **{
"name": ("{name_label}-{uuid}".format(
name_label=name_label,
uuid=uuid4().hex[:INSTANCE_NAME_UUID_LEN]))
})).execute() for i in range(count)
]
instance=node_id,
body={
"labels": dict(node["labels"], **labels),
"labelFingerprint": node["labelFingerprint"]
}).execute()
results = [
wait_for_compute_zone_operation(self.compute, project_id,
operation, availability_zone)
for operation in operations
]
result = wait_for_compute_zone_operation(
self.compute, project_id, operation, availability_zone)
return results
return result
def external_ip(self, node_id):
with self.lock:
node = self._get_cached_node(node_id)
def get_external_ip(node):
return node.get("networkInterfaces", [{}])[0].get(
"accessConfigs", [{}])[0].get("natIP", None)
ip = get_external_ip(node)
if ip is None:
node = self._get_node(node_id)
ip = get_external_ip(node)
return ip
def internal_ip(self, node_id):
with self.lock:
node = self._get_cached_node(node_id)
def get_internal_ip(node):
return node.get("networkInterfaces", [{}])[0].get("networkIP")
ip = get_internal_ip(node)
if ip is None:
node = self._get_node(node_id)
ip = get_internal_ip(node)
return ip
def create_node(self, base_config, tags, count):
with self.lock:
labels = tags # gcp uses "labels" instead of aws "tags"
project_id = self.provider_config["project_id"]
availability_zone = self.provider_config["availability_zone"]
config = base_config.copy()
name_label = labels[TAG_RAY_NODE_NAME]
assert (len(name_label) <=
(INSTANCE_NAME_MAX_LEN - INSTANCE_NAME_UUID_LEN - 1)), (
name_label, len(name_label))
machine_type = ("zones/{zone}/machineTypes/{machine_type}"
"".format(
zone=availability_zone,
machine_type=base_config["machineType"]))
labels = dict(config.get("labels", {}), **labels)
config.update({
"machineType": machine_type,
"labels": dict(labels,
**{TAG_RAY_CLUSTER_NAME: self.cluster_name}),
})
operations = [
self.compute.instances().insert(
project=project_id,
zone=availability_zone,
body=dict(
config, **{
"name": ("{name_label}-{uuid}".format(
name_label=name_label,
uuid=uuid4().hex[:INSTANCE_NAME_UUID_LEN]))
})).execute() for i in range(count)
]
results = [
wait_for_compute_zone_operation(self.compute, project_id,
operation, availability_zone)
for operation in operations
]
return results
def terminate_node(self, node_id):
project_id = self.provider_config["project_id"]
availability_zone = self.provider_config["availability_zone"]
with self.lock:
project_id = self.provider_config["project_id"]
availability_zone = self.provider_config["availability_zone"]
operation = self.compute.instances().delete(
project=project_id,
zone=availability_zone,
instance=node_id,
).execute()
operation = self.compute.instances().delete(
project=project_id,
zone=availability_zone,
instance=node_id,
).execute()
result = wait_for_compute_zone_operation(self.compute, project_id,
operation, availability_zone)
result = wait_for_compute_zone_operation(
self.compute, project_id, operation, availability_zone)
return result
return result
def _get_node(self, node_id):
self.nodes({}) # Side effect: fetches and caches the node.