diff --git a/python/ray/autoscaler/_private/kubernetes/node_provider.py b/python/ray/autoscaler/_private/kubernetes/node_provider.py index aac45c904..997a28718 100644 --- a/python/ray/autoscaler/_private/kubernetes/node_provider.py +++ b/python/ray/autoscaler/_private/kubernetes/node_provider.py @@ -1,4 +1,5 @@ import logging +import time from uuid import uuid4 from kubernetes.client.rest import ApiException @@ -11,6 +12,9 @@ from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME logger = logging.getLogger(__name__) +MAX_TAG_RETRIES = 3 +DELAY_BEFORE_TAG_RETRY = .5 + def to_label_selector(tags): label_selector = "" @@ -71,7 +75,23 @@ class KubernetesNodeProvider(NodeProvider): raise ValueError("Must use internal IPs with Kubernetes.") return super().get_node_id(ip_address, use_internal_ip=use_internal_ip) - def set_node_tags(self, node_id, tags): + def set_node_tags(self, node_ids, tags): + for _ in range(MAX_TAG_RETRIES - 1): + try: + self._set_node_tags(node_ids, tags) + return + except ApiException as e: + if e.status == 409: + logger.info(log_prefix + "Caught a 409 error while setting" + " node tags. Retrying...") + time.sleep(DELAY_BEFORE_TAG_RETRY) + continue + else: + raise + # One more try + self._set_node_tags(node_ids, tags) + + def _set_node_tags(self, node_id, tags): pod = core_api().read_namespaced_pod(node_id, self.namespace) pod.metadata.labels.update(tags) core_api().patch_namespaced_pod(node_id, self.namespace, pod) diff --git a/python/ray/autoscaler/_private/updater.py b/python/ray/autoscaler/_private/updater.py index 98877f913..04989bb85 100644 --- a/python/ray/autoscaler/_private/updater.py +++ b/python/ray/autoscaler/_private/updater.py @@ -75,7 +75,6 @@ class NodeUpdater: process_runner, use_internal_ip, docker_config) self.daemon = True - self.process_runner = process_runner self.node_id = node_id self.provider = provider # Some node providers don't specify empty structures as diff --git a/python/ray/node.py b/python/ray/node.py index cf0d89195..d645b70cd 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -224,7 +224,7 @@ class Node: ray.utils.set_sigterm_handler(sigterm_handler) def _init_temp(self, redis_client): - # Create an dictionary to store temp file index. + # Create a dictionary to store temp file index. self._incremental_dict = collections.defaultdict(lambda: 0) if self.head: