[autoscaler] Run initialization_commands without a persistent connection (#9020)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
Co-authored-by: Edward Oakes <ed.nmi.oakes@gmail.com>
This commit is contained in:
Ian Rodney
2020-07-06 16:34:59 -07:00
committed by GitHub
parent 139d21e068
commit 6fecd3cfce
2 changed files with 92 additions and 39 deletions
+80 -37
View File
@@ -147,6 +147,44 @@ class KubernetesCommandRunner:
self.node_id)
class SSHOptions:
    """Bundle of ``ssh`` command-line options for reaching a node.

    Stores the identity file plus a dict of ``-o Key=Value`` settings and
    renders them as an argument list via ``to_ssh_options_list``.

    Args:
        ssh_key: Path of the private key handed to ``ssh -i``. When falsy
            (e.g. ``None``), no ``-i`` argument is emitted.
        control_path: Directory used for the ``ControlPath`` socket. When
            falsy, the ``ControlMaster``/``ControlPath``/``ControlPersist``
            options are left out entirely.
        **kwargs: Additional ``Key=Value`` options. They are merged last and
            therefore override any of the defaults below.
    """

    def __init__(self, ssh_key, control_path=None, **kwargs):
        self.ssh_key = ssh_key
        self.arg_dict = {
            # Suppresses initial fingerprint verification.
            "StrictHostKeyChecking": "no",
            # SSH IP and fingerprint pairs no longer added to known_hosts.
            # This is to remove a "REMOTE HOST IDENTIFICATION HAS CHANGED"
            # warning if a new node has the same IP as a previously
            # deleted node, because the fingerprints will not match in
            # that case.
            "UserKnownHostsFile": os.devnull,
            # Try fewer extraneous key pairs.
            "IdentitiesOnly": "yes",
            # Abort if port forwarding fails (instead of just printing to
            # stderr).
            "ExitOnForwardFailure": "yes",
            # Quickly kill the connection if network connection breaks (as
            # opposed to hanging/blocking).
            "ServerAliveInterval": 5,
            "ServerAliveCountMax": 3,
        }
        if control_path:
            # Connection-sharing options are only added when a control
            # path was supplied.
            self.arg_dict.update({
                "ControlMaster": "auto",
                "ControlPath": "{}/%C".format(control_path),
                "ControlPersist": "10s",
            })
        self.arg_dict.update(kwargs)

    def to_ssh_options_list(self, *, timeout=60):
        """Render the options as a flat ``ssh`` argument list.

        Args:
            timeout: Value (in seconds) used for ``ConnectTimeout``. It
                overrides any ``ConnectTimeout`` stored in ``arg_dict``.

        Returns:
            A list of the form ``["-i", key, "-o", "K=V", "-o", "K=V", ...]``.
            The leading ``-i`` pair is omitted when ``ssh_key`` is falsy.
        """
        # Build from a copy so the per-call timeout never leaks into the
        # stored options, which may be shared between many invocations.
        options = dict(self.arg_dict, ConnectTimeout="{}s".format(timeout))
        # A missing key must not yield ["-i", None]: a None entry cannot be
        # joined into a command string or passed as a process argument.
        args = ["-i", self.ssh_key] if self.ssh_key else []
        for key, value in options.items():
            args += ["-o", "{}={}".format(key, value)]
        return args
class SSHCommandRunner:
def __init__(self, log_prefix, node_id, provider, auth_config,
cluster_name, process_runner, use_internal_ip):
@@ -166,36 +204,8 @@ class SSHCommandRunner:
self.ssh_user = auth_config["ssh_user"]
self.ssh_control_path = ssh_control_path
self.ssh_ip = None
def get_default_ssh_options(self, connect_timeout):
    """Return the default ``ssh`` argument list for this runner.

    Args:
        connect_timeout: Number of seconds used for ``ConnectTimeout``.

    Returns:
        ``["-i", <private key>]`` followed by one ``-o Key=Value`` pair per
        default option, in the order listed below.
    """
    option_pairs = (
        ("ConnectTimeout", "{}s".format(connect_timeout)),
        # Suppresses initial fingerprint verification.
        ("StrictHostKeyChecking", "no"),
        # Host/fingerprint pairs are not recorded in known_hosts. This
        # avoids a "REMOTE HOST IDENTIFICATION HAS CHANGED" warning when a
        # new node reuses the IP of a previously deleted node, since the
        # fingerprints would not match in that case.
        ("UserKnownHostsFile", os.devnull),
        ("ControlMaster", "auto"),
        ("ControlPath", "{}/%C".format(self.ssh_control_path)),
        ("ControlPersist", "10s"),
        # Try fewer extraneous key pairs.
        ("IdentitiesOnly", "yes"),
        # Abort if port forwarding fails (instead of only printing to
        # stderr).
        ("ExitOnForwardFailure", "yes"),
        # Kill the connection quickly if the network breaks (rather than
        # hanging/blocking).
        ("ServerAliveInterval", 5),
        ("ServerAliveCountMax", 3),
    )
    ssh_args = ["-i", self.ssh_private_key]
    for name, value in option_pairs:
        ssh_args.append("-o")
        ssh_args.append("{}={}".format(name, value))
    return ssh_args
self.base_ssh_options = SSHOptions(self.ssh_private_key,
self.ssh_control_path)
def get_node_ip(self):
if self.use_internal_ip:
@@ -241,7 +251,14 @@ class SSHCommandRunner:
exit_on_fail=False,
port_forward=None,
with_output=False,
ssh_options_override=None,
**kwargs):
ssh_options = ssh_options_override or self.base_ssh_options
assert isinstance(
ssh_options, SSHOptions
), "ssh_options must be of type SSHOptions, got {}".format(
type(ssh_options))
self.set_ssh_ip_if_required()
@@ -255,7 +272,7 @@ class SSHCommandRunner:
"{} -> localhost:{}".format(local, remote))
ssh += ["-L", "{}:localhost:{}".format(remote, local)]
final_cmd = ssh + self.get_default_ssh_options(timeout) + [
final_cmd = ssh + ssh_options.to_ssh_options_list(timeout=timeout) + [
"{}@{}".format(self.ssh_user, self.ssh_ip)
]
if cmd:
@@ -286,16 +303,20 @@ class SSHCommandRunner:
self.set_ssh_ip_if_required()
self.process_runner.check_call([
"rsync", "--rsh",
" ".join(["ssh"] + self.get_default_ssh_options(120)), "-avz",
source, "{}@{}:{}".format(self.ssh_user, self.ssh_ip, target)
" ".join(["ssh"] +
self.base_ssh_options.to_ssh_options_list(timeout=120)),
"-avz", source, "{}@{}:{}".format(self.ssh_user, self.ssh_ip,
target)
])
def run_rsync_down(self, source, target):
self.set_ssh_ip_if_required()
self.process_runner.check_call([
"rsync", "--rsh",
" ".join(["ssh"] + self.get_default_ssh_options(120)), "-avz",
"{}@{}:{}".format(self.ssh_user, self.ssh_ip, source), target
" ".join(["ssh"] +
self.base_ssh_options.to_ssh_options_list(timeout=120)),
"-avz", "{}@{}:{}".format(self.ssh_user, self.ssh_ip,
source), target
])
def remote_shell_command_str(self):
@@ -309,6 +330,7 @@ class DockerCommandRunner(SSHCommandRunner):
self.docker_name = docker_config["container_name"]
self.docker_config = docker_config
self.home_dir = None
self.check_docker_installed()
self.shutdown = False
def run(self,
@@ -318,6 +340,7 @@ class DockerCommandRunner(SSHCommandRunner):
port_forward=None,
with_output=False,
run_env=True,
ssh_options_override=None,
**kwargs):
if run_env == "auto":
run_env = "host" if cmd.find("docker") == 0 else "docker"
@@ -335,7 +358,23 @@ class DockerCommandRunner(SSHCommandRunner):
timeout=timeout,
exit_on_fail=exit_on_fail,
port_forward=None,
with_output=False)
with_output=False,
ssh_options_override=ssh_options_override)
def check_docker_installed(self):
    """Probe for the ``docker`` executable and log install hints if absent.

    Runs ``command -v docker`` through ``self.ssh_command_runner``. Any
    exception raised by that probe is treated as "Docker not installed":
    an error is logged listing commands the user can add to
    'initialization_commands'. Nothing is raised to the caller.
    """
    try:
        self.ssh_command_runner.run("command -v docker")
    except Exception:
        # Best-effort check: report how to install Docker instead of
        # propagating the failure.
        suggested_install = "\n".join((
            "curl -fsSL https://get.docker.com -o get-docker.sh",
            "sudo sh get-docker.sh",
            "sudo usermod -aG docker $USER",
            "sudo systemctl restart docker -f",
        ))
        logger.error(
            "Docker not installed. You can install Docker by adding the "
            "following commands to 'initialization_commands':\n" +
            suggested_install)
def shutdown_after_next_cmd(self):
self.shutdown = True
@@ -422,6 +461,7 @@ class NodeUpdater:
self.setup_commands = setup_commands
self.ray_start_commands = ray_start_commands
self.runtime_hash = runtime_hash
self.auth_config = auth_config
def run(self):
logger.info(self.log_prefix +
@@ -516,7 +556,10 @@ class NodeUpdater:
self.log_prefix + "Initialization commands",
show_status=True):
for cmd in self.initialization_commands:
self.cmd_runner.run(cmd)
self.cmd_runner.run(
cmd,
ssh_options_override=SSHOptions(
self.auth_config.get("ssh_private_key")))
with LogTimer(
self.log_prefix + "Setup commands", show_status=True):