From d966d987293cba0c5adcbf49167e6d9e6eed923e Mon Sep 17 00:00:00 2001 From: Ameer Haj Ali Date: Thu, 4 Jun 2020 13:27:17 -0700 Subject: [PATCH] cleanup to support provider's custom ssh command runner (#8720) * cleanup to support provider's custom ssh command runner * clean up * trailing white spaces fix * k8s signature fix Co-authored-by: Ameer Haj Ali --- .../autoscaler/kubernetes/node_provider.py | 6 +++++ python/ray/autoscaler/node_provider.py | 22 +++++++++++++++++++ python/ray/autoscaler/updater.py | 15 +++++-------- 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/python/ray/autoscaler/kubernetes/node_provider.py b/python/ray/autoscaler/kubernetes/node_provider.py index 1775cece1..3471e4e76 100644 --- a/python/ray/autoscaler/kubernetes/node_provider.py +++ b/python/ray/autoscaler/kubernetes/node_provider.py @@ -3,6 +3,7 @@ import logging from ray.autoscaler.kubernetes import core_api, log_prefix from ray.autoscaler.node_provider import NodeProvider from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME +from ray.autoscaler.updater import KubernetesCommandRunner logger = logging.getLogger(__name__) @@ -85,3 +86,8 @@ class KubernetesNodeProvider(NodeProvider): def terminate_nodes(self, node_ids): for node_id in node_ids: self.terminate_node(node_id) + + def get_command_runner(self, log_prefix, node_id, auth_config, + cluster_name, process_runner, use_internal_ip): + return KubernetesCommandRunner(log_prefix, self.namespace, node_id, + auth_config, process_runner) diff --git a/python/ray/autoscaler/node_provider.py b/python/ray/autoscaler/node_provider.py index 27a4a61ba..b0cc39f2a 100644 --- a/python/ray/autoscaler/node_provider.py +++ b/python/ray/autoscaler/node_provider.py @@ -3,6 +3,8 @@ import logging import os import yaml +from ray.autoscaler.updater import SSHCommandRunner + logger = logging.getLogger(__name__) @@ -208,3 +210,23 @@ class NodeProvider: def cleanup(self): """Clean-up when a Provider is no longer required.""" pass + + def get_command_runner(self, log_prefix, node_id, auth_config, + cluster_name, process_runner, use_internal_ip): + """ Returns the CommandRunner class used to perform SSH commands. + + Args: + log_prefix(str): stores "NodeUpdater: {}: ".format(). Used + to print progress in the CommandRunner. + node_id(str): the node ID. + auth_config(dict): the authentication configs from the autoscaler + yaml file. + cluster_name(str): the name of the cluster. + process_runner(module): the module to use to run the commands + in the CommandRunner. E.g., subprocess. + use_internal_ip(bool): whether the node_id belongs to an internal ip + or external ip. + """ + + return SSHCommandRunner(log_prefix, node_id, self, auth_config, + cluster_name, process_runner, use_internal_ip) diff --git a/python/ray/autoscaler/updater.py b/python/ray/autoscaler/updater.py index 8d1a2d469..1ecf11363 100644 --- a/python/ray/autoscaler/updater.py +++ b/python/ray/autoscaler/updater.py @@ -315,16 +315,11 @@ class NodeUpdater: use_internal_ip=False): self.log_prefix = "NodeUpdater: {}: ".format(node_id) - if provider_config["type"] == "kubernetes": - self.cmd_runner = KubernetesCommandRunner( - self.log_prefix, provider.namespace, node_id, auth_config, - process_runner) - else: - use_internal_ip = (use_internal_ip or provider_config.get( - "use_internal_ips", False)) - self.cmd_runner = SSHCommandRunner( - self.log_prefix, node_id, provider, auth_config, cluster_name, - process_runner, use_internal_ip) + use_internal_ip = (use_internal_ip + or provider_config.get("use_internal_ips", False)) + self.cmd_runner = provider.get_command_runner( + self.log_prefix, node_id, auth_config, cluster_name, + process_runner, use_internal_ip) self.daemon = True self.process_runner = process_runner