diff --git a/python/ray/autoscaler/gcp/node_provider.py b/python/ray/autoscaler/gcp/node_provider.py index 787b9d8e1..c79014cc6 100644 --- a/python/ray/autoscaler/gcp/node_provider.py +++ b/python/ray/autoscaler/gcp/node_provider.py @@ -3,6 +3,7 @@ from __future__ import division from __future__ import print_function from uuid import uuid4 +from threading import RLock import time import logging @@ -45,6 +46,7 @@ class GCPNodeProvider(NodeProvider): def __init__(self, provider_config, cluster_name): NodeProvider.__init__(self, provider_config, cluster_name) + self.lock = RLock() self.compute = discovery.build("compute", "v1") # Cache of node objects from the last nodes() call. This avoids @@ -52,162 +54,173 @@ class GCPNodeProvider(NodeProvider): self.cached_nodes = {} def nodes(self, tag_filters): - if tag_filters: - label_filter_expr = "(" + " AND ".join([ - "(labels.{key} = {value})".format(key=key, value=value) - for key, value in tag_filters.items() + with self.lock: + if tag_filters: + label_filter_expr = "(" + " AND ".join([ + "(labels.{key} = {value})".format(key=key, value=value) + for key, value in tag_filters.items() + ]) + ")" + else: + label_filter_expr = "" + + instance_state_filter_expr = "(" + " OR ".join([ + "(status = {status})".format(status=status) + for status in {"PROVISIONING", "STAGING", "RUNNING"} ]) + ")" - else: - label_filter_expr = "" - instance_state_filter_expr = "(" + " OR ".join([ - "(status = {status})".format(status=status) - for status in {"PROVISIONING", "STAGING", "RUNNING"} - ]) + ")" + cluster_name_filter_expr = ("(labels.{key} = {value})" + "".format( + key=TAG_RAY_CLUSTER_NAME, + value=self.cluster_name)) - cluster_name_filter_expr = ("(labels.{key} = {value})" - "".format( - key=TAG_RAY_CLUSTER_NAME, - value=self.cluster_name)) + not_empty_filters = [ + f for f in [ + label_filter_expr, + instance_state_filter_expr, + cluster_name_filter_expr, + ] if f + ] - not_empty_filters = [ - f for f in [ - label_filter_expr, - instance_state_filter_expr, - cluster_name_filter_expr, - ] if f - ] + filter_expr = " AND ".join(not_empty_filters) - filter_expr = " AND ".join(not_empty_filters) + response = self.compute.instances().list( + project=self.provider_config["project_id"], + zone=self.provider_config["availability_zone"], + filter=filter_expr, + ).execute() - response = self.compute.instances().list( - project=self.provider_config["project_id"], - zone=self.provider_config["availability_zone"], - filter=filter_expr, - ).execute() + instances = response.get("items", []) + # Note: All the operations use "name" as the unique instance id + self.cached_nodes = {i["name"]: i for i in instances} - instances = response.get("items", []) - # Note: All the operations use "name" as the unique instance identifier - self.cached_nodes = {i["name"]: i for i in instances} - - return [i["name"] for i in instances] + return [i["name"] for i in instances] def is_running(self, node_id): - node = self._get_cached_node(node_id) - return node["status"] == "RUNNING" + with self.lock: + node = self._get_cached_node(node_id) + return node["status"] == "RUNNING" def is_terminated(self, node_id): - node = self._get_cached_node(node_id) - return node["status"] not in {"PROVISIONING", "STAGING", "RUNNING"} + with self.lock: + node = self._get_cached_node(node_id) + return node["status"] not in {"PROVISIONING", "STAGING", "RUNNING"} def node_tags(self, node_id): - node = self._get_cached_node(node_id) - labels = node.get("labels", {}) - return labels + with self.lock: + node = self._get_cached_node(node_id) + labels = node.get("labels", {}) + return labels def set_node_tags(self, node_id, tags): - labels = tags - project_id = self.provider_config["project_id"] - availability_zone = self.provider_config["availability_zone"] + with self.lock: + labels = tags + project_id = self.provider_config["project_id"] + availability_zone = self.provider_config["availability_zone"] - node = self._get_node(node_id) - operation = self.compute.instances().setLabels( - project=project_id, - zone=availability_zone, - instance=node_id, - body={ - "labels": dict(node["labels"], **labels), - "labelFingerprint": node["labelFingerprint"] - }).execute() - - result = wait_for_compute_zone_operation(self.compute, project_id, - operation, availability_zone) - - return result - - def external_ip(self, node_id): - node = self._get_cached_node(node_id) - - def get_external_ip(node): - return node.get("networkInterfaces", [{}])[0].get( - "accessConfigs", [{}])[0].get("natIP", None) - - ip = get_external_ip(node) - if ip is None: node = self._get_node(node_id) - ip = get_external_ip(node) - - return ip - - def internal_ip(self, node_id): - node = self._get_cached_node(node_id) - - def get_internal_ip(node): - return node.get("networkInterfaces", [{}])[0].get("networkIP") - - ip = get_internal_ip(node) - if ip is None: - node = self._get_node(node_id) - ip = get_internal_ip(node) - - return ip - - def create_node(self, base_config, tags, count): - labels = tags # gcp uses "labels" instead of aws "tags" - project_id = self.provider_config["project_id"] - availability_zone = self.provider_config["availability_zone"] - - config = base_config.copy() - - name_label = labels[TAG_RAY_NODE_NAME] - assert (len(name_label) <= - (INSTANCE_NAME_MAX_LEN - INSTANCE_NAME_UUID_LEN - 1)), ( - name_label, len(name_label)) - - config.update({ - "machineType": ("zones/{zone}/machineTypes/{machine_type}" - "".format( - zone=availability_zone, - machine_type=base_config["machineType"])), - "labels": dict( - config.get("labels", {}), **labels, - **{TAG_RAY_CLUSTER_NAME: self.cluster_name}), - }) - - operations = [ - self.compute.instances().insert( + operation = self.compute.instances().setLabels( project=project_id, zone=availability_zone, - body=dict( - config, **{ - "name": ("{name_label}-{uuid}".format( - name_label=name_label, - uuid=uuid4().hex[:INSTANCE_NAME_UUID_LEN])) - })).execute() for i in range(count) - ] + instance=node_id, + body={ + "labels": dict(node["labels"], **labels), + "labelFingerprint": node["labelFingerprint"] + }).execute() - results = [ - wait_for_compute_zone_operation(self.compute, project_id, - operation, availability_zone) - for operation in operations - ] + result = wait_for_compute_zone_operation( + self.compute, project_id, operation, availability_zone) - return results + return result + + def external_ip(self, node_id): + with self.lock: + node = self._get_cached_node(node_id) + + def get_external_ip(node): + return node.get("networkInterfaces", [{}])[0].get( + "accessConfigs", [{}])[0].get("natIP", None) + + ip = get_external_ip(node) + if ip is None: + node = self._get_node(node_id) + ip = get_external_ip(node) + + return ip + + def internal_ip(self, node_id): + with self.lock: + node = self._get_cached_node(node_id) + + def get_internal_ip(node): + return node.get("networkInterfaces", [{}])[0].get("networkIP") + + ip = get_internal_ip(node) + if ip is None: + node = self._get_node(node_id) + ip = get_internal_ip(node) + + return ip + + def create_node(self, base_config, tags, count): + with self.lock: + labels = tags # gcp uses "labels" instead of aws "tags" + project_id = self.provider_config["project_id"] + availability_zone = self.provider_config["availability_zone"] + + config = base_config.copy() + + name_label = labels[TAG_RAY_NODE_NAME] + assert (len(name_label) <= + (INSTANCE_NAME_MAX_LEN - INSTANCE_NAME_UUID_LEN - 1)), ( + name_label, len(name_label)) + + machine_type = ("zones/{zone}/machineTypes/{machine_type}" + "".format( + zone=availability_zone, + machine_type=base_config["machineType"])) + labels = dict(config.get("labels", {}), **labels) + + config.update({ + "machineType": machine_type, + "labels": dict(labels, + **{TAG_RAY_CLUSTER_NAME: self.cluster_name}), + }) + + operations = [ + self.compute.instances().insert( + project=project_id, + zone=availability_zone, + body=dict( + config, **{ + "name": ("{name_label}-{uuid}".format( + name_label=name_label, + uuid=uuid4().hex[:INSTANCE_NAME_UUID_LEN])) + })).execute() for i in range(count) + ] + + results = [ + wait_for_compute_zone_operation(self.compute, project_id, + operation, availability_zone) + for operation in operations + ] + + return results def terminate_node(self, node_id): - project_id = self.provider_config["project_id"] - availability_zone = self.provider_config["availability_zone"] + with self.lock: + project_id = self.provider_config["project_id"] + availability_zone = self.provider_config["availability_zone"] - operation = self.compute.instances().delete( - project=project_id, - zone=availability_zone, - instance=node_id, - ).execute() + operation = self.compute.instances().delete( + project=project_id, + zone=availability_zone, + instance=node_id, + ).execute() - result = wait_for_compute_zone_operation(self.compute, project_id, - operation, availability_zone) + result = wait_for_compute_zone_operation( + self.compute, project_id, operation, availability_zone) - return result + return result def _get_node(self, node_id): self.nodes({}) # Side effect: fetches and caches the node.