Custom SSH socket directories (#4299)

* ssh_control_path added as an auth option.

* Revamped the default SSH options to take in a control path. NodeUpdater checks the auth config to see whether a custom SSH sockets path was specified; otherwise the original hardcoded path is used. The control path is now a NodeUpdater instance variable.

* Revert the socket dir option in the auth config and change the method for determining the directory

* new ssh dir method

* Lint

* ' -> " lint

* Changed from using the USER environment variable to getpass.getuser()
This commit is contained in:
Zachary Barry
2019-04-14 02:55:41 -04:00
committed by Kristian Hartikainen
parent 3e1adafbce
commit 3838548356
+17 -15
View File
@@ -13,6 +13,7 @@ import sys
import time
from threading import Thread
from getpass import getuser
from ray.autoscaler.tags import TAG_RAY_NODE_STATUS, TAG_RAY_RUNTIME_CONFIG
from ray.autoscaler.log_timer import LogTimer
@@ -22,15 +23,14 @@ logger = logging.getLogger(__name__)
# How long to wait for a node to start, in seconds
NODE_START_WAIT_S = 300
SSH_CHECK_INTERVAL = 5
SSH_CONTROL_PATH = "/tmp/ray_ssh_sockets"
def get_default_ssh_options(private_key, connect_timeout):
def get_default_ssh_options(private_key, connect_timeout, ssh_control_path):
OPTS = [
("ConnectTimeout", "{}s".format(connect_timeout)),
("StrictHostKeyChecking", "no"),
("ControlMaster", "auto"),
("ControlPath", "{}/%C".format(SSH_CONTROL_PATH)),
("ControlPath", "{}/%C".format(ssh_control_path)),
("ControlPersist", "5m"),
]
@@ -54,6 +54,10 @@ class NodeUpdater(object):
runtime_hash,
process_runner=subprocess,
use_internal_ip=False):
ssh_control_path = "/tmp/{}_ray_ssh_sockets/{}".format(
getuser(), cluster_name)
self.daemon = True
self.process_runner = process_runner
self.node_id = node_id
@@ -62,6 +66,7 @@ class NodeUpdater(object):
self.provider = provider
self.ssh_private_key = auth_config["ssh_private_key"]
self.ssh_user = auth_config["ssh_user"]
self.ssh_control_path = ssh_control_path
self.ssh_ip = None
self.file_mounts = {
remote: os.path.expanduser(local)
@@ -113,12 +118,12 @@ class NodeUpdater(object):
# persistent sessions later on.
with open("/dev/null", "w") as redirect:
self.get_caller(False)(
["mkdir", "-p", SSH_CONTROL_PATH],
["mkdir", "-p", self.ssh_control_path],
stdout=redirect,
stderr=redirect)
self.get_caller(False)(
["chmod", "0700", SSH_CONTROL_PATH],
["chmod", "0700", self.ssh_control_path],
stdout=redirect,
stderr=redirect)
@@ -234,9 +239,8 @@ class NodeUpdater(object):
self.set_ssh_ip_if_required()
self.get_caller(check_error)(
[
"rsync", "-e",
" ".join(["ssh"] +
get_default_ssh_options(self.ssh_private_key, 120)),
"rsync", "-e", " ".join(["ssh"] + get_default_ssh_options(
self.ssh_private_key, 120, self.ssh_control_path)),
"--delete", "-avz", source, "{}@{}:{}".format(
self.ssh_user, self.ssh_ip, target)
],
@@ -247,11 +251,9 @@ class NodeUpdater(object):
self.set_ssh_ip_if_required()
self.get_caller(check_error)(
[
"rsync", "-e",
" ".join(["ssh"] +
get_default_ssh_options(self.ssh_private_key, 120)),
"-avz", "{}@{}:{}".format(self.ssh_user, self.ssh_ip,
source), target
"rsync", "-e", " ".join(["ssh"] + get_default_ssh_options(
self.ssh_private_key, 120, self.ssh_control_path)), "-avz",
"{}@{}:{}".format(self.ssh_user, self.ssh_ip, source), target
],
stdout=redirect or sys.stdout,
stderr=redirect or sys.stderr)
@@ -287,8 +289,8 @@ class NodeUpdater(object):
]
self.get_caller(expect_error)(
ssh + ssh_opt + get_default_ssh_options(self.ssh_private_key,
connect_timeout) +
ssh + ssh_opt + get_default_ssh_options(
self.ssh_private_key, connect_timeout, self.ssh_control_path) +
["{}@{}".format(self.ssh_user, self.ssh_ip), cmd],
stdout=redirect or sys.stdout,
stderr=redirect or sys.stderr)