[autoscaler] Move command runners into separate file and clean up interface. (#9340)

* cleanup

* wip

* fix imports

* fix lint
This commit is contained in:
Eric Liang
2020-07-09 15:40:56 -07:00
committed by GitHub
parent 8a76f4cbb5
commit 09b9b81ea4
4 changed files with 467 additions and 410 deletions
+463
View File
@@ -0,0 +1,463 @@
from getpass import getuser
from shlex import quote
from typing import List, Tuple
import click
import hashlib
import logging
import os
import subprocess
import sys
import time
from ray.autoscaler.docker import check_docker_running_cmd, with_docker_exec
from ray.autoscaler.log_timer import LogTimer
logger = logging.getLogger(__name__)
# How long to wait for a node to start, in seconds
NODE_START_WAIT_S = 300
HASH_MAX_LENGTH = 10
KUBECTL_RSYNC = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "kubernetes/kubectl-rsync.sh")
def _with_interactive(cmd):
force_interactive = ("true && source ~/.bashrc && "
"export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && ")
return ["bash", "--login", "-c", "-i", quote(force_interactive + cmd)]
class CommandRunnerInterface:
"""Interface to run commands on a remote cluster node.
Command runner instances are returned by provider.get_command_runner()."""
def run(self,
cmd: str = None,
timeout: int = 120,
exit_on_fail: bool = False,
port_forward: List[Tuple[int, int]] = None,
with_output: bool = False,
**kwargs) -> str:
"""Run the given command on the cluster node and optionally get output.
Args:
cmd (str): The command to run.
timeout (int): The command timeout in seconds.
exit_on_fail (bool): Whether to sys exit on failure.
port_forward (list): List of (local, remote) ports to forward, or
a single tuple.
with_output (bool): Whether to return output.
"""
raise NotImplementedError
def run_rsync_up(self, source: str, target: str) -> None:
"""Rsync files up to the cluster node.
Args:
source (str): The (local) source directory or file.
target (str): The (remote) destination path.
"""
raise NotImplementedError
def run_rsync_down(self, source: str, target: str) -> None:
"""Rsync files down from the cluster node.
Args:
source (str): The (remote) source directory or file.
target (str): The (local) destination path.
"""
raise NotImplementedError
def remote_shell_command_str(self) -> str:
"""Return the command the user can use to open a shell."""
raise NotImplementedError
class KubernetesCommandRunner(CommandRunnerInterface):
def __init__(self, log_prefix, namespace, node_id, auth_config,
process_runner):
self.log_prefix = log_prefix
self.process_runner = process_runner
self.node_id = node_id
self.namespace = namespace
self.kubectl = ["kubectl", "-n", self.namespace]
def run(self,
cmd=None,
timeout=120,
exit_on_fail=False,
port_forward=None,
with_output=False,
**kwargs):
if cmd and port_forward:
raise Exception(
"exec with Kubernetes can't forward ports and execute"
"commands together.")
if port_forward:
if not isinstance(port_forward, list):
port_forward = [port_forward]
port_forward_cmd = self.kubectl + [
"port-forward",
self.node_id,
] + [
"{}:{}".format(local, remote) for local, remote in port_forward
]
logger.info("Port forwarding with: {}".format(
" ".join(port_forward_cmd)))
port_forward_process = subprocess.Popen(port_forward_cmd)
port_forward_process.wait()
# We should never get here, this indicates that port forwarding
# failed, likely because we couldn't bind to a port.
pout, perr = port_forward_process.communicate()
exception_str = " ".join(
port_forward_cmd) + " failed with error: " + perr
raise Exception(exception_str)
else:
final_cmd = self.kubectl + ["exec", "-it"]
final_cmd += [
self.node_id,
"--",
]
final_cmd += _with_interactive(cmd)
logger.info(self.log_prefix + "Running {}".format(final_cmd))
try:
if with_output:
return self.process_runner.check_output(
" ".join(final_cmd), shell=True)
else:
self.process_runner.check_call(
" ".join(final_cmd), shell=True)
except subprocess.CalledProcessError:
if exit_on_fail:
quoted_cmd = " ".join(final_cmd[:-1] +
[quote(final_cmd[-1])])
logger.error(
self.log_prefix +
"Command failed: \n\n {}\n".format(quoted_cmd))
sys.exit(1)
else:
raise
def run_rsync_up(self, source, target):
if target.startswith("~"):
target = "/root" + target[1:]
try:
self.process_runner.check_call([
KUBECTL_RSYNC,
"-avz",
source,
"{}@{}:{}".format(self.node_id, self.namespace, target),
])
except Exception as e:
logger.warning(self.log_prefix +
"rsync failed: '{}'. Falling back to 'kubectl cp'"
.format(e))
self.process_runner.check_call(self.kubectl + [
"cp", source, "{}/{}:{}".format(self.namespace, self.node_id,
target)
])
def run_rsync_down(self, source, target):
if target.startswith("~"):
target = "/root" + target[1:]
try:
self.process_runner.check_call([
KUBECTL_RSYNC,
"-avz",
"{}@{}:{}".format(self.node_id, self.namespace, source),
target,
])
except Exception as e:
logger.warning(self.log_prefix +
"rsync failed: '{}'. Falling back to 'kubectl cp'"
.format(e))
self.process_runner.check_call(self.kubectl + [
"cp", "{}/{}:{}".format(self.namespace, self.node_id, source),
target
])
def remote_shell_command_str(self):
return "{} exec -it {} bash".format(" ".join(self.kubectl),
self.node_id)
class SSHOptions:
def __init__(self, ssh_key, control_path=None, **kwargs):
self.ssh_key = ssh_key
self.arg_dict = {
# Supresses initial fingerprint verification.
"StrictHostKeyChecking": "no",
# SSH IP and fingerprint pairs no longer added to known_hosts.
# This is to remove a "REMOTE HOST IDENTIFICATION HAS CHANGED"
# warning if a new node has the same IP as a previously
# deleted node, because the fingerprints will not match in
# that case.
"UserKnownHostsFile": os.devnull,
# Try fewer extraneous key pairs.
"IdentitiesOnly": "yes",
# Abort if port forwarding fails (instead of just printing to
# stderr).
"ExitOnForwardFailure": "yes",
# Quickly kill the connection if network connection breaks (as
# opposed to hanging/blocking).
"ServerAliveInterval": 5,
"ServerAliveCountMax": 3
}
if control_path:
self.arg_dict.update({
"ControlMaster": "auto",
"ControlPath": "{}/%C".format(control_path),
"ControlPersist": "10s",
})
self.arg_dict.update(kwargs)
def to_ssh_options_list(self, *, timeout=60):
self.arg_dict["ConnectTimeout"] = "{}s".format(timeout)
return ["-i", self.ssh_key] + [
x for y in (["-o", "{}={}".format(k, v)]
for k, v in self.arg_dict.items()) for x in y
]
class SSHCommandRunner(CommandRunnerInterface):
def __init__(self, log_prefix, node_id, provider, auth_config,
cluster_name, process_runner, use_internal_ip):
ssh_control_hash = hashlib.md5(cluster_name.encode()).hexdigest()
ssh_user_hash = hashlib.md5(getuser().encode()).hexdigest()
ssh_control_path = "/tmp/ray_ssh_{}/{}".format(
ssh_user_hash[:HASH_MAX_LENGTH],
ssh_control_hash[:HASH_MAX_LENGTH])
self.log_prefix = log_prefix
self.process_runner = process_runner
self.node_id = node_id
self.use_internal_ip = use_internal_ip
self.provider = provider
self.ssh_private_key = auth_config["ssh_private_key"]
self.ssh_user = auth_config["ssh_user"]
self.ssh_control_path = ssh_control_path
self.ssh_ip = None
self.base_ssh_options = SSHOptions(self.ssh_private_key,
self.ssh_control_path)
def _get_node_ip(self):
if self.use_internal_ip:
return self.provider.internal_ip(self.node_id)
else:
return self.provider.external_ip(self.node_id)
def _wait_for_ip(self, deadline):
while time.time() < deadline and \
not self.provider.is_terminated(self.node_id):
logger.info(self.log_prefix + "Waiting for IP...")
ip = self._get_node_ip()
if ip is not None:
return ip
time.sleep(10)
return None
def _set_ssh_ip_if_required(self):
if self.ssh_ip is not None:
return
# We assume that this never changes.
# I think that's reasonable.
deadline = time.time() + NODE_START_WAIT_S
with LogTimer(self.log_prefix + "Got IP"):
ip = self._wait_for_ip(deadline)
assert ip is not None, "Unable to find IP of node"
self.ssh_ip = ip
# This should run before any SSH commands and therefore ensure that
# the ControlPath directory exists, allowing SSH to maintain
# persistent sessions later on.
try:
os.makedirs(self.ssh_control_path, mode=0o700, exist_ok=True)
except OSError as e:
logger.warning(e)
def run(self,
cmd,
timeout=120,
exit_on_fail=False,
port_forward=None,
with_output=False,
ssh_options_override=None,
**kwargs):
ssh_options = ssh_options_override or self.base_ssh_options
assert isinstance(
ssh_options, SSHOptions
), "ssh_options must be of type SSHOptions, got {}".format(
type(ssh_options))
self._set_ssh_ip_if_required()
ssh = ["ssh", "-tt"]
if port_forward:
if not isinstance(port_forward, list):
port_forward = [port_forward]
for local, remote in port_forward:
logger.info(self.log_prefix + "Forwarding " +
"{} -> localhost:{}".format(local, remote))
ssh += ["-L", "{}:localhost:{}".format(remote, local)]
final_cmd = ssh + ssh_options.to_ssh_options_list(timeout=timeout) + [
"{}@{}".format(self.ssh_user, self.ssh_ip)
]
if cmd:
final_cmd += _with_interactive(cmd)
logger.info(self.log_prefix +
"Running {}".format(" ".join(final_cmd)))
else:
# We do this because `-o ControlMaster` causes the `-N` flag to
# still create an interactive shell in some ssh versions.
final_cmd.append(quote("while true; do sleep 86400; done"))
try:
if with_output:
return self.process_runner.check_output(final_cmd)
else:
self.process_runner.check_call(final_cmd)
except subprocess.CalledProcessError:
if exit_on_fail:
quoted_cmd = " ".join(final_cmd[:-1] + [quote(final_cmd[-1])])
raise click.ClickException(
"Command failed: \n\n {}\n".format(quoted_cmd)) from None
else:
raise click.ClickException(
"SSH command Failed. See above for the output from the"
" failure.") from None
def run_rsync_up(self, source, target):
self._set_ssh_ip_if_required()
self.process_runner.check_call([
"rsync", "--rsh",
" ".join(["ssh"] +
self.base_ssh_options.to_ssh_options_list(timeout=120)),
"-avz", source, "{}@{}:{}".format(self.ssh_user, self.ssh_ip,
target)
])
def run_rsync_down(self, source, target):
self._set_ssh_ip_if_required()
self.process_runner.check_call([
"rsync", "--rsh",
" ".join(["ssh"] +
self.base_ssh_options.to_ssh_options_list(timeout=120)),
"-avz", "{}@{}:{}".format(self.ssh_user, self.ssh_ip,
source), target
])
def remote_shell_command_str(self):
return "ssh -o IdentitiesOnly=yes -i {} {}@{}\n".format(
self.ssh_private_key, self.ssh_user, self.ssh_ip)
class DockerCommandRunner(SSHCommandRunner):
def __init__(self, docker_config, **common_args):
self.ssh_command_runner = SSHCommandRunner(**common_args)
self.docker_name = docker_config["container_name"]
self.docker_config = docker_config
self.home_dir = None
self._check_docker_installed()
self.shutdown = False
def run(self,
cmd,
timeout=120,
exit_on_fail=False,
port_forward=None,
with_output=False,
run_env=True,
ssh_options_override=None,
**kwargs):
if run_env == "auto":
run_env = "host" if cmd.find("docker") == 0 else "docker"
if run_env == "docker":
cmd = self._docker_expand_user(cmd, any_char=True)
cmd = with_docker_exec(
[cmd], container_name=self.docker_name,
with_interactive=True)[0]
if self.shutdown:
cmd += "; sudo shutdown -h now"
return self.ssh_command_runner.run(
cmd,
timeout=timeout,
exit_on_fail=exit_on_fail,
port_forward=None,
with_output=False,
ssh_options_override=ssh_options_override)
def run_rsync_up(self, source, target):
self.ssh_command_runner.run_rsync_up(source, target)
if self._check_container_status():
self.ssh_command_runner.run("docker cp {} {}:{}".format(
target, self.docker_name, self._docker_expand_user(target)))
def run_rsync_down(self, source, target):
self.ssh_command_runner.run("docker cp {}:{} {}".format(
self.docker_name, self._docker_expand_user(source), source))
self.ssh_command_runner.run_rsync_down(source, target)
def remote_shell_command_str(self):
inner_str = self.ssh_command_runner.remote_shell_command_str().replace(
"ssh", "ssh -tt", 1).strip("\n")
return inner_str + " docker exec -it {} /bin/bash\n".format(
self.docker_name)
def _check_docker_installed(self):
try:
self.ssh_command_runner.run("command -v docker")
return
except Exception:
install_commands = [
"curl -fsSL https://get.docker.com -o get-docker.sh",
"sudo sh get-docker.sh", "sudo usermod -aG docker $USER",
"sudo systemctl restart docker -f"
]
logger.error(
"Docker not installed. You can install Docker by adding the "
"following commands to 'initialization_commands':\n" +
"\n".join(install_commands))
def _shutdown_after_next_cmd(self):
self.shutdown = True
def _check_container_status(self):
no_exist = "not_present"
cmd = check_docker_running_cmd(self.docker_name) + " ".join(
["||", "echo", quote(no_exist)])
output = self.ssh_command_runner.run(
cmd, with_output=True).decode("utf-8").strip()
if no_exist in output:
return False
return "true" in output.lower()
def _docker_expand_user(self, string, any_char=False):
user_pos = string.find("~")
if user_pos > -1:
if self.home_dir is None:
self.home_dir = self.ssh_command_runner.run(
"docker exec {} env | grep HOME | cut -d'=' -f2".format(
self.docker_name),
with_output=True).decode("utf-8").strip()
if any_char:
return string.replace("~/", self.home_dir + "/")
elif not any_char and user_pos == 0:
return string.replace("~", self.home_dir, 1)
return string
+2 -1
View File
@@ -25,7 +25,8 @@ from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_LAUNCH_CONFIG, \
TAG_RAY_NODE_NAME, NODE_TYPE_WORKER, NODE_TYPE_HEAD
from ray.ray_constants import AUTOSCALER_RESOURCE_REQUEST_CHANNEL
from ray.autoscaler.updater import NodeUpdaterThread, DockerCommandRunner
from ray.autoscaler.updater import NodeUpdaterThread
from ray.autoscaler.command_runner import DockerCommandRunner
from ray.autoscaler.log_timer import LogTimer
from ray.worker import global_worker
+1 -1
View File
@@ -3,7 +3,7 @@ import logging
import os
import yaml
from ray.autoscaler.updater import SSHCommandRunner, DockerCommandRunner
from ray.autoscaler.command_runner import SSHCommandRunner, DockerCommandRunner
logger = logging.getLogger(__name__)
+1 -408
View File
@@ -1,427 +1,20 @@
try: # py3
from shlex import quote
except ImportError: # py2
from pipes import quote
import click
import hashlib
import logging
import os
import subprocess
import sys
import time
from threading import Thread
from getpass import getuser
from ray.autoscaler.tags import TAG_RAY_NODE_STATUS, TAG_RAY_RUNTIME_CONFIG, \
STATUS_UP_TO_DATE, STATUS_UPDATE_FAILED, STATUS_WAITING_FOR_SSH, \
STATUS_SETTING_UP, STATUS_SYNCING_FILES
from ray.autoscaler.command_runner import NODE_START_WAIT_S, SSHOptions
from ray.autoscaler.log_timer import LogTimer
from ray.autoscaler.docker import check_docker_running_cmd, with_docker_exec
logger = logging.getLogger(__name__)
# How long to wait for a node to start, in seconds
NODE_START_WAIT_S = 300
READY_CHECK_INTERVAL = 5
HASH_MAX_LENGTH = 10
KUBECTL_RSYNC = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "kubernetes/kubectl-rsync.sh")
def with_interactive(cmd):
force_interactive = ("true && source ~/.bashrc && "
"export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && ")
return ["bash", "--login", "-c", "-i", quote(force_interactive + cmd)]
class KubernetesCommandRunner:
def __init__(self, log_prefix, namespace, node_id, auth_config,
process_runner):
self.log_prefix = log_prefix
self.process_runner = process_runner
self.node_id = node_id
self.namespace = namespace
self.kubectl = ["kubectl", "-n", self.namespace]
def run(self,
cmd=None,
timeout=120,
exit_on_fail=False,
port_forward=None,
with_output=False,
**kwargs):
if cmd and port_forward:
raise Exception(
"exec with Kubernetes can't forward ports and execute"
"commands together.")
if port_forward:
if not isinstance(port_forward, list):
port_forward = [port_forward]
port_forward_cmd = self.kubectl + [
"port-forward",
self.node_id,
] + [
"{}:{}".format(local, remote) for local, remote in port_forward
]
logger.info("Port forwarding with: {}".format(
" ".join(port_forward_cmd)))
port_forward_process = subprocess.Popen(port_forward_cmd)
port_forward_process.wait()
# We should never get here, this indicates that port forwarding
# failed, likely because we couldn't bind to a port.
pout, perr = port_forward_process.communicate()
exception_str = " ".join(
port_forward_cmd) + " failed with error: " + perr
raise Exception(exception_str)
else:
final_cmd = self.kubectl + ["exec", "-it"]
final_cmd += [
self.node_id,
"--",
]
final_cmd += with_interactive(cmd)
logger.info(self.log_prefix + "Running {}".format(final_cmd))
try:
if with_output:
return self.process_runner.check_output(
" ".join(final_cmd), shell=True)
else:
self.process_runner.check_call(
" ".join(final_cmd), shell=True)
except subprocess.CalledProcessError:
if exit_on_fail:
quoted_cmd = " ".join(final_cmd[:-1] +
[quote(final_cmd[-1])])
logger.error(
self.log_prefix +
"Command failed: \n\n {}\n".format(quoted_cmd))
sys.exit(1)
else:
raise
def run_rsync_up(self, source, target):
if target.startswith("~"):
target = "/root" + target[1:]
try:
self.process_runner.check_call([
KUBECTL_RSYNC,
"-avz",
source,
"{}@{}:{}".format(self.node_id, self.namespace, target),
])
except Exception as e:
logger.warning(self.log_prefix +
"rsync failed: '{}'. Falling back to 'kubectl cp'"
.format(e))
self.process_runner.check_call(self.kubectl + [
"cp", source, "{}/{}:{}".format(self.namespace, self.node_id,
target)
])
def run_rsync_down(self, source, target):
if target.startswith("~"):
target = "/root" + target[1:]
try:
self.process_runner.check_call([
KUBECTL_RSYNC,
"-avz",
"{}@{}:{}".format(self.node_id, self.namespace, source),
target,
])
except Exception as e:
logger.warning(self.log_prefix +
"rsync failed: '{}'. Falling back to 'kubectl cp'"
.format(e))
self.process_runner.check_call(self.kubectl + [
"cp", "{}/{}:{}".format(self.namespace, self.node_id, source),
target
])
def remote_shell_command_str(self):
return "{} exec -it {} bash".format(" ".join(self.kubectl),
self.node_id)
class SSHOptions:
def __init__(self, ssh_key, control_path=None, **kwargs):
self.ssh_key = ssh_key
self.arg_dict = {
# Supresses initial fingerprint verification.
"StrictHostKeyChecking": "no",
# SSH IP and fingerprint pairs no longer added to known_hosts.
# This is to remove a "REMOTE HOST IDENTIFICATION HAS CHANGED"
# warning if a new node has the same IP as a previously
# deleted node, because the fingerprints will not match in
# that case.
"UserKnownHostsFile": os.devnull,
# Try fewer extraneous key pairs.
"IdentitiesOnly": "yes",
# Abort if port forwarding fails (instead of just printing to
# stderr).
"ExitOnForwardFailure": "yes",
# Quickly kill the connection if network connection breaks (as
# opposed to hanging/blocking).
"ServerAliveInterval": 5,
"ServerAliveCountMax": 3
}
if control_path:
self.arg_dict.update({
"ControlMaster": "auto",
"ControlPath": "{}/%C".format(control_path),
"ControlPersist": "10s",
})
self.arg_dict.update(kwargs)
def to_ssh_options_list(self, *, timeout=60):
self.arg_dict["ConnectTimeout"] = "{}s".format(timeout)
return ["-i", self.ssh_key] + [
x for y in (["-o", "{}={}".format(k, v)]
for k, v in self.arg_dict.items()) for x in y
]
class SSHCommandRunner:
def __init__(self, log_prefix, node_id, provider, auth_config,
cluster_name, process_runner, use_internal_ip):
ssh_control_hash = hashlib.md5(cluster_name.encode()).hexdigest()
ssh_user_hash = hashlib.md5(getuser().encode()).hexdigest()
ssh_control_path = "/tmp/ray_ssh_{}/{}".format(
ssh_user_hash[:HASH_MAX_LENGTH],
ssh_control_hash[:HASH_MAX_LENGTH])
self.log_prefix = log_prefix
self.process_runner = process_runner
self.node_id = node_id
self.use_internal_ip = use_internal_ip
self.provider = provider
self.ssh_private_key = auth_config["ssh_private_key"]
self.ssh_user = auth_config["ssh_user"]
self.ssh_control_path = ssh_control_path
self.ssh_ip = None
self.base_ssh_options = SSHOptions(self.ssh_private_key,
self.ssh_control_path)
def get_node_ip(self):
if self.use_internal_ip:
return self.provider.internal_ip(self.node_id)
else:
return self.provider.external_ip(self.node_id)
def wait_for_ip(self, deadline):
while time.time() < deadline and \
not self.provider.is_terminated(self.node_id):
logger.info(self.log_prefix + "Waiting for IP...")
ip = self.get_node_ip()
if ip is not None:
return ip
time.sleep(10)
return None
def set_ssh_ip_if_required(self):
if self.ssh_ip is not None:
return
# We assume that this never changes.
# I think that's reasonable.
deadline = time.time() + NODE_START_WAIT_S
with LogTimer(self.log_prefix + "Got IP"):
ip = self.wait_for_ip(deadline)
assert ip is not None, "Unable to find IP of node"
self.ssh_ip = ip
# This should run before any SSH commands and therefore ensure that
# the ControlPath directory exists, allowing SSH to maintain
# persistent sessions later on.
try:
os.makedirs(self.ssh_control_path, mode=0o700, exist_ok=True)
except OSError as e:
logger.warning(e)
def run(self,
cmd,
timeout=120,
exit_on_fail=False,
port_forward=None,
with_output=False,
ssh_options_override=None,
**kwargs):
ssh_options = ssh_options_override or self.base_ssh_options
assert isinstance(
ssh_options, SSHOptions
), "ssh_options must be of type SSHOptions, got {}".format(
type(ssh_options))
self.set_ssh_ip_if_required()
ssh = ["ssh", "-tt"]
if port_forward:
if not isinstance(port_forward, list):
port_forward = [port_forward]
for local, remote in port_forward:
logger.info(self.log_prefix + "Forwarding " +
"{} -> localhost:{}".format(local, remote))
ssh += ["-L", "{}:localhost:{}".format(remote, local)]
final_cmd = ssh + ssh_options.to_ssh_options_list(timeout=timeout) + [
"{}@{}".format(self.ssh_user, self.ssh_ip)
]
if cmd:
final_cmd += with_interactive(cmd)
logger.info(self.log_prefix +
"Running {}".format(" ".join(final_cmd)))
else:
# We do this because `-o ControlMaster` causes the `-N` flag to
# still create an interactive shell in some ssh versions.
final_cmd.append(quote("while true; do sleep 86400; done"))
try:
if with_output:
return self.process_runner.check_output(final_cmd)
else:
self.process_runner.check_call(final_cmd)
except subprocess.CalledProcessError:
if exit_on_fail:
quoted_cmd = " ".join(final_cmd[:-1] + [quote(final_cmd[-1])])
raise click.ClickException(
"Command failed: \n\n {}\n".format(quoted_cmd)) from None
else:
raise click.ClickException(
"SSH command Failed. See above for the output from the"
" failure.") from None
def run_rsync_up(self, source, target):
self.set_ssh_ip_if_required()
self.process_runner.check_call([
"rsync", "--rsh",
" ".join(["ssh"] +
self.base_ssh_options.to_ssh_options_list(timeout=120)),
"-avz", source, "{}@{}:{}".format(self.ssh_user, self.ssh_ip,
target)
])
def run_rsync_down(self, source, target):
self.set_ssh_ip_if_required()
self.process_runner.check_call([
"rsync", "--rsh",
" ".join(["ssh"] +
self.base_ssh_options.to_ssh_options_list(timeout=120)),
"-avz", "{}@{}:{}".format(self.ssh_user, self.ssh_ip,
source), target
])
def remote_shell_command_str(self):
return "ssh -o IdentitiesOnly=yes -i {} {}@{}\n".format(
self.ssh_private_key, self.ssh_user, self.ssh_ip)
class DockerCommandRunner(SSHCommandRunner):
def __init__(self, docker_config, **common_args):
self.ssh_command_runner = SSHCommandRunner(**common_args)
self.docker_name = docker_config["container_name"]
self.docker_config = docker_config
self.home_dir = None
self.check_docker_installed()
self.shutdown = False
def run(self,
cmd,
timeout=120,
exit_on_fail=False,
port_forward=None,
with_output=False,
run_env=True,
ssh_options_override=None,
**kwargs):
if run_env == "auto":
run_env = "host" if cmd.find("docker") == 0 else "docker"
if run_env == "docker":
cmd = self.docker_expand_user(cmd, any_char=True)
cmd = with_docker_exec(
[cmd], container_name=self.docker_name,
with_interactive=True)[0]
if self.shutdown:
cmd += "; sudo shutdown -h now"
return self.ssh_command_runner.run(
cmd,
timeout=timeout,
exit_on_fail=exit_on_fail,
port_forward=None,
with_output=False,
ssh_options_override=ssh_options_override)
def check_docker_installed(self):
try:
self.ssh_command_runner.run("command -v docker")
return
except Exception:
install_commands = [
"curl -fsSL https://get.docker.com -o get-docker.sh",
"sudo sh get-docker.sh", "sudo usermod -aG docker $USER",
"sudo systemctl restart docker -f"
]
logger.error(
"Docker not installed. You can install Docker by adding the "
"following commands to 'initialization_commands':\n" +
"\n".join(install_commands))
def shutdown_after_next_cmd(self):
self.shutdown = True
def check_container_status(self):
no_exist = "not_present"
cmd = check_docker_running_cmd(self.docker_name) + " ".join(
["||", "echo", quote(no_exist)])
output = self.ssh_command_runner.run(
cmd, with_output=True).decode("utf-8").strip()
if no_exist in output:
return False
return "true" in output.lower()
def run_rsync_up(self, source, target):
self.ssh_command_runner.run_rsync_up(source, target)
if self.check_container_status():
self.ssh_command_runner.run("docker cp {} {}:{}".format(
target, self.docker_name, self.docker_expand_user(target)))
def run_rsync_down(self, source, target):
self.ssh_command_runner.run("docker cp {}:{} {}".format(
self.docker_name, self.docker_expand_user(source), source))
self.ssh_command_runner.run_rsync_down(source, target)
def remote_shell_command_str(self):
inner_str = self.ssh_command_runner.remote_shell_command_str().replace(
"ssh", "ssh -tt", 1).strip("\n")
return inner_str + " docker exec -it {} /bin/bash\n".format(
self.docker_name)
def docker_expand_user(self, string, any_char=False):
user_pos = string.find("~")
if user_pos > -1:
if self.home_dir is None:
self.home_dir = self.ssh_command_runner.run(
"docker exec {} env | grep HOME | cut -d'=' -f2".format(
self.docker_name),
with_output=True).decode("utf-8").strip()
if any_char:
return string.replace("~/", self.home_dir + "/")
elif not any_char and user_pos == 0:
return string.replace("~", self.home_dir, 1)
return string
class NodeUpdater: