From 3838548356f9b186ca2ba0265d587f29a124e88e Mon Sep 17 00:00:00 2001 From: Zachary Barry Date: Sun, 14 Apr 2019 02:55:41 -0400 Subject: [PATCH] Custom SSH socket directories (#4299) * ssh_control_path added as an auth option. * revamped default ssh options to take in control path, nodeupdater checks auth config to see if a custom SSH sockets path was specified, otherwise the original hardcoded path is used. control path is now a nodeupdater instance variable * revert socketdir in auth config and change method for determining dir * new ssh dir method * Lint * ' -> " lint * changed using USER env to getpass.getuser() --- python/ray/autoscaler/updater.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/python/ray/autoscaler/updater.py b/python/ray/autoscaler/updater.py index d679f7548..942d1a151 100644 --- a/python/ray/autoscaler/updater.py +++ b/python/ray/autoscaler/updater.py @@ -13,6 +13,7 @@ import sys import time from threading import Thread +from getpass import getuser from ray.autoscaler.tags import TAG_RAY_NODE_STATUS, TAG_RAY_RUNTIME_CONFIG from ray.autoscaler.log_timer import LogTimer @@ -22,15 +23,14 @@ logger = logging.getLogger(__name__) # How long to wait for a node to start, in seconds NODE_START_WAIT_S = 300 SSH_CHECK_INTERVAL = 5 -SSH_CONTROL_PATH = "/tmp/ray_ssh_sockets" -def get_default_ssh_options(private_key, connect_timeout): +def get_default_ssh_options(private_key, connect_timeout, ssh_control_path): OPTS = [ ("ConnectTimeout", "{}s".format(connect_timeout)), ("StrictHostKeyChecking", "no"), ("ControlMaster", "auto"), - ("ControlPath", "{}/%C".format(SSH_CONTROL_PATH)), + ("ControlPath", "{}/%C".format(ssh_control_path)), ("ControlPersist", "5m"), ] @@ -54,6 +54,10 @@ class NodeUpdater(object): runtime_hash, process_runner=subprocess, use_internal_ip=False): + + ssh_control_path = "/tmp/{}_ray_ssh_sockets/{}".format( + getuser(), cluster_name) + self.daemon = True self.process_runner = process_runner self.node_id = node_id @@ -62,6 +66,7 @@ class NodeUpdater(object): self.provider = provider self.ssh_private_key = auth_config["ssh_private_key"] self.ssh_user = auth_config["ssh_user"] + self.ssh_control_path = ssh_control_path self.ssh_ip = None self.file_mounts = { remote: os.path.expanduser(local) @@ -113,12 +118,12 @@ class NodeUpdater(object): # persistent sessions later on. with open("/dev/null", "w") as redirect: self.get_caller(False)( - ["mkdir", "-p", SSH_CONTROL_PATH], + ["mkdir", "-p", self.ssh_control_path], stdout=redirect, stderr=redirect) self.get_caller(False)( - ["chmod", "0700", SSH_CONTROL_PATH], + ["chmod", "0700", self.ssh_control_path], stdout=redirect, stderr=redirect) @@ -234,9 +239,8 @@ class NodeUpdater(object): self.set_ssh_ip_if_required() self.get_caller(check_error)( [ - "rsync", "-e", - " ".join(["ssh"] + - get_default_ssh_options(self.ssh_private_key, 120)), + "rsync", "-e", " ".join(["ssh"] + get_default_ssh_options( + self.ssh_private_key, 120, self.ssh_control_path)), "--delete", "-avz", source, "{}@{}:{}".format( self.ssh_user, self.ssh_ip, target) ], @@ -247,11 +251,9 @@ class NodeUpdater(object): self.set_ssh_ip_if_required() self.get_caller(check_error)( [ - "rsync", "-e", - " ".join(["ssh"] + - get_default_ssh_options(self.ssh_private_key, 120)), - "-avz", "{}@{}:{}".format(self.ssh_user, self.ssh_ip, - source), target + "rsync", "-e", " ".join(["ssh"] + get_default_ssh_options( + self.ssh_private_key, 120, self.ssh_control_path)), "-avz", + "{}@{}:{}".format(self.ssh_user, self.ssh_ip, source), target ], stdout=redirect or sys.stdout, stderr=redirect or sys.stderr) @@ -287,8 +289,8 @@ class NodeUpdater(object): ] self.get_caller(expect_error)( - ssh + ssh_opt + get_default_ssh_options(self.ssh_private_key, - connect_timeout) + + ssh + ssh_opt + get_default_ssh_options( + self.ssh_private_key, connect_timeout, self.ssh_control_path) + ["{}@{}".format(self.ssh_user, self.ssh_ip), cmd], stdout=redirect or sys.stdout, stderr=redirect or sys.stderr)