mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 13:19:38 +08:00
Custom SSH socket directories (#4299)
* ssh_control_path added as an auth option. * revamped default ssh options to take in control path, nodeupdater checks auth config to see if a custom SSH sockets path was specified, otherwise the original hardcoded path is used. control path is now a nodeupdater instance variable * revert socketdir in auth config and change method for determining dir * new ssh dir method * Lint * ' -> " lint * changed using USER env to getpass.getuser()
This commit is contained in:
committed by
Kristian Hartikainen
parent
3e1adafbce
commit
3838548356
@@ -13,6 +13,7 @@ import sys
|
||||
import time
|
||||
|
||||
from threading import Thread
|
||||
from getpass import getuser
|
||||
|
||||
from ray.autoscaler.tags import TAG_RAY_NODE_STATUS, TAG_RAY_RUNTIME_CONFIG
|
||||
from ray.autoscaler.log_timer import LogTimer
|
||||
@@ -22,15 +23,14 @@ logger = logging.getLogger(__name__)
|
||||
# How long to wait for a node to start, in seconds
|
||||
NODE_START_WAIT_S = 300
|
||||
SSH_CHECK_INTERVAL = 5
|
||||
SSH_CONTROL_PATH = "/tmp/ray_ssh_sockets"
|
||||
|
||||
|
||||
def get_default_ssh_options(private_key, connect_timeout):
|
||||
def get_default_ssh_options(private_key, connect_timeout, ssh_control_path):
|
||||
OPTS = [
|
||||
("ConnectTimeout", "{}s".format(connect_timeout)),
|
||||
("StrictHostKeyChecking", "no"),
|
||||
("ControlMaster", "auto"),
|
||||
("ControlPath", "{}/%C".format(SSH_CONTROL_PATH)),
|
||||
("ControlPath", "{}/%C".format(ssh_control_path)),
|
||||
("ControlPersist", "5m"),
|
||||
]
|
||||
|
||||
@@ -54,6 +54,10 @@ class NodeUpdater(object):
|
||||
runtime_hash,
|
||||
process_runner=subprocess,
|
||||
use_internal_ip=False):
|
||||
|
||||
ssh_control_path = "/tmp/{}_ray_ssh_sockets/{}".format(
|
||||
getuser(), cluster_name)
|
||||
|
||||
self.daemon = True
|
||||
self.process_runner = process_runner
|
||||
self.node_id = node_id
|
||||
@@ -62,6 +66,7 @@ class NodeUpdater(object):
|
||||
self.provider = provider
|
||||
self.ssh_private_key = auth_config["ssh_private_key"]
|
||||
self.ssh_user = auth_config["ssh_user"]
|
||||
self.ssh_control_path = ssh_control_path
|
||||
self.ssh_ip = None
|
||||
self.file_mounts = {
|
||||
remote: os.path.expanduser(local)
|
||||
@@ -113,12 +118,12 @@ class NodeUpdater(object):
|
||||
# persistent sessions later on.
|
||||
with open("/dev/null", "w") as redirect:
|
||||
self.get_caller(False)(
|
||||
["mkdir", "-p", SSH_CONTROL_PATH],
|
||||
["mkdir", "-p", self.ssh_control_path],
|
||||
stdout=redirect,
|
||||
stderr=redirect)
|
||||
|
||||
self.get_caller(False)(
|
||||
["chmod", "0700", SSH_CONTROL_PATH],
|
||||
["chmod", "0700", self.ssh_control_path],
|
||||
stdout=redirect,
|
||||
stderr=redirect)
|
||||
|
||||
@@ -234,9 +239,8 @@ class NodeUpdater(object):
|
||||
self.set_ssh_ip_if_required()
|
||||
self.get_caller(check_error)(
|
||||
[
|
||||
"rsync", "-e",
|
||||
" ".join(["ssh"] +
|
||||
get_default_ssh_options(self.ssh_private_key, 120)),
|
||||
"rsync", "-e", " ".join(["ssh"] + get_default_ssh_options(
|
||||
self.ssh_private_key, 120, self.ssh_control_path)),
|
||||
"--delete", "-avz", source, "{}@{}:{}".format(
|
||||
self.ssh_user, self.ssh_ip, target)
|
||||
],
|
||||
@@ -247,11 +251,9 @@ class NodeUpdater(object):
|
||||
self.set_ssh_ip_if_required()
|
||||
self.get_caller(check_error)(
|
||||
[
|
||||
"rsync", "-e",
|
||||
" ".join(["ssh"] +
|
||||
get_default_ssh_options(self.ssh_private_key, 120)),
|
||||
"-avz", "{}@{}:{}".format(self.ssh_user, self.ssh_ip,
|
||||
source), target
|
||||
"rsync", "-e", " ".join(["ssh"] + get_default_ssh_options(
|
||||
self.ssh_private_key, 120, self.ssh_control_path)), "-avz",
|
||||
"{}@{}:{}".format(self.ssh_user, self.ssh_ip, source), target
|
||||
],
|
||||
stdout=redirect or sys.stdout,
|
||||
stderr=redirect or sys.stderr)
|
||||
@@ -287,8 +289,8 @@ class NodeUpdater(object):
|
||||
]
|
||||
|
||||
self.get_caller(expect_error)(
|
||||
ssh + ssh_opt + get_default_ssh_options(self.ssh_private_key,
|
||||
connect_timeout) +
|
||||
ssh + ssh_opt + get_default_ssh_options(
|
||||
self.ssh_private_key, connect_timeout, self.ssh_control_path) +
|
||||
["{}@{}".format(self.ssh_user, self.ssh_ip), cmd],
|
||||
stdout=redirect or sys.stdout,
|
||||
stderr=redirect or sys.stderr)
|
||||
|
||||
Reference in New Issue
Block a user