Files
ray/python/ray/tests/test_command_runner.py
T

376 lines
14 KiB
Python

import logging
import pytest
import sys
from unittest.mock import patch
from ray.tests.test_autoscaler import MockProvider, MockProcessRunner
from ray.autoscaler.command_runner import CommandRunnerInterface
from ray.autoscaler._private.command_runner import SSHCommandRunner, \
DockerCommandRunner, KubernetesCommandRunner, _with_environment_variables
from ray.autoscaler.sdk import get_docker_host_mount_location
from getpass import getuser
import hashlib
auth_config = {
"ssh_user": "ray",
"ssh_private_key": "8265.pem",
}
def test_environment_variable_encoder_strings():
env_vars = {"var1": "quote between this \" and this", "var2": "123"}
res = _with_environment_variables("echo hello", env_vars)
expected = """export var1='"quote between this \\" and this"';export var2='"123"';echo hello""" # noqa: E501
assert res == expected
def test_environment_variable_encoder_dict():
env_vars = {"value1": "string1", "value2": {"a": "b", "c": 2}}
res = _with_environment_variables("echo hello", env_vars)
expected = """export value1='"string1"';export value2='{"a":"b","c":2}';echo hello""" # noqa: E501
assert res == expected
def test_command_runner_interface_abstraction_violation():
"""Enforces the CommandRunnerInterface functions on the subclasses.
This is important to make sure the subclasses do not violate the
function abstractions. If you need to add a new function to one of
the CommandRunnerInterface subclasses, you have to add it to
CommandRunnerInterface and all of its subclasses.
"""
cmd_runner_interface_public_functions = dir(CommandRunnerInterface)
allowed_public_interface_functions = {
func
for func in cmd_runner_interface_public_functions
if not func.startswith("_")
}
for subcls in [
SSHCommandRunner, DockerCommandRunner, KubernetesCommandRunner
]:
subclass_available_functions = dir(subcls)
subclass_public_functions = {
func
for func in subclass_available_functions
if not func.startswith("_")
}
assert allowed_public_interface_functions == subclass_public_functions
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
def test_ssh_command_runner():
process_runner = MockProcessRunner()
provider = MockProvider()
provider.create_node({}, {}, 1)
cluster_name = "cluster"
ssh_control_hash = hashlib.md5(cluster_name.encode()).hexdigest()
ssh_user_hash = hashlib.md5(getuser().encode()).hexdigest()
ssh_control_path = "/tmp/ray_ssh_{}/{}".format(ssh_user_hash[:10],
ssh_control_hash[:10])
args = {
"log_prefix": "prefix",
"node_id": 0,
"provider": provider,
"auth_config": auth_config,
"cluster_name": cluster_name,
"process_runner": process_runner,
"use_internal_ip": False,
}
cmd_runner = SSHCommandRunner(**args)
env_vars = {"var1": "quote between this \" and this", "var2": "123"}
cmd_runner.run(
"echo helloo",
port_forward=[(8265, 8265)],
environment_variables=env_vars)
expected = [
"ssh",
"-tt",
"-L",
"8265:localhost:8265",
"-i",
"8265.pem",
"-o",
"StrictHostKeyChecking=no",
"-o",
"UserKnownHostsFile=/dev/null",
"-o",
"IdentitiesOnly=yes",
"-o",
"ExitOnForwardFailure=yes",
"-o",
"ServerAliveInterval=5",
"-o",
"ServerAliveCountMax=3",
"-o",
"ControlMaster=auto",
"-o",
"ControlPath={}/%C".format(ssh_control_path),
"-o",
"ControlPersist=10s",
"-o",
"ConnectTimeout=120s",
"ray@1.2.3.4",
"bash",
"--login",
"-c",
"-i",
"""'true && source ~/.bashrc && export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && (export var1='"'"'"quote between this \\" and this"'"'"';export var2='"'"'"123"'"'"';echo helloo)'""" # noqa: E501
]
# Much easier to debug this loop than the function call.
for x, y in zip(process_runner.calls[0], expected):
assert x == y
process_runner.assert_has_call("1.2.3.4", exact=expected)
def test_kubernetes_command_runner():
fail_cmd = "fail command"
process_runner = MockProcessRunner([fail_cmd])
provider = MockProvider()
provider.create_node({}, {}, 1)
args = {
"log_prefix": "prefix",
"namespace": "namespace",
"node_id": 0,
"auth_config": auth_config,
"process_runner": process_runner,
}
cmd_runner = KubernetesCommandRunner(**args)
env_vars = {"var1": "quote between this \" and this", "var2": "123"}
cmd_runner.run("echo helloo", environment_variables=env_vars)
expected = [
"kubectl",
"-n",
"namespace",
"exec",
"-it",
"0",
"--",
"bash",
"--login",
"-c",
"-i",
"""\'true && source ~/.bashrc && export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && (export var1=\'"\'"\'"quote between this \\" and this"\'"\'"\';export var2=\'"\'"\'"123"\'"\'"\';echo helloo)\'""" # noqa: E501
]
assert process_runner.calls[0] == " ".join(expected)
logger = logging.getLogger("ray.autoscaler._private.command_runner")
with pytest.raises(SystemExit) as pytest_wrapped_e, patch.object(
logger, "error") as mock_logger_error:
cmd_runner.run(fail_cmd, exit_on_fail=True)
failed_cmd_expected = f'prefixCommand failed: \n\n kubectl -n namespace exec -it 0 --\'bash --login -c -i \'"\'"\'true && source ~/.bashrc && export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && ({fail_cmd})\'"\'"\'\'\n' # noqa: E501
mock_logger_error.assert_called_once_with(failed_cmd_expected)
assert pytest_wrapped_e.type == SystemExit
assert pytest_wrapped_e.value.code == 1
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
def test_docker_command_runner():
process_runner = MockProcessRunner()
provider = MockProvider()
provider.create_node({}, {}, 1)
cluster_name = "cluster"
ssh_control_hash = hashlib.md5(cluster_name.encode()).hexdigest()
ssh_user_hash = hashlib.md5(getuser().encode()).hexdigest()
ssh_control_path = "/tmp/ray_ssh_{}/{}".format(ssh_user_hash[:10],
ssh_control_hash[:10])
docker_config = {"container_name": "container"}
args = {
"log_prefix": "prefix",
"node_id": 0,
"provider": provider,
"auth_config": auth_config,
"cluster_name": cluster_name,
"process_runner": process_runner,
"use_internal_ip": False,
"docker_config": docker_config,
}
cmd_runner = DockerCommandRunner(**args)
assert len(process_runner.calls) == 0, "No calls should be made in ctor"
env_vars = {"var1": "quote between this \" and this", "var2": "123"}
cmd_runner.run("echo hello", environment_variables=env_vars)
# This string is insane because there are an absurd number of embedded
# quotes. While this is a ridiculous string, the escape behavior is
# important and somewhat difficult to get right for environment variables.
cmd = """'true && source ~/.bashrc && export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && (docker exec -it container /bin/bash -c '"'"'bash --login -c -i '"'"'"'"'"'"'"'"'true && source ~/.bashrc && export OMP_NUM_THREADS=1 PYTHONWARNINGS=ignore && (export var1='"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"quote between this \\" and this"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"';export var2='"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"123"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"'"';echo hello)'"'"'"'"'"'"'"'"''"'"' )'""" # noqa: E501
expected = [
"ssh", "-tt", "-i", "8265.pem", "-o", "StrictHostKeyChecking=no", "-o",
"UserKnownHostsFile=/dev/null", "-o", "IdentitiesOnly=yes", "-o",
"ExitOnForwardFailure=yes", "-o", "ServerAliveInterval=5", "-o",
"ServerAliveCountMax=3", "-o", "ControlMaster=auto", "-o",
"ControlPath={}/%C".format(ssh_control_path), "-o",
"ControlPersist=10s", "-o", "ConnectTimeout=120s", "ray@1.2.3.4",
"bash", "--login", "-c", "-i", cmd
]
# Much easier to debug this loop than the function call.
for x, y in zip(process_runner.calls[0], expected):
print(f"expeted:\t{y}")
print(f"actual: \t{x}")
assert x == y
process_runner.assert_has_call("1.2.3.4", exact=expected)
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on Windows.")
def test_docker_rsync():
process_runner = MockProcessRunner()
provider = MockProvider()
provider.create_node({}, {}, 1)
cluster_name = "cluster"
docker_config = {"container_name": "container"}
args = {
"log_prefix": "prefix",
"node_id": 0,
"provider": provider,
"auth_config": auth_config,
"cluster_name": cluster_name,
"process_runner": process_runner,
"use_internal_ip": False,
"docker_config": docker_config,
}
cmd_runner = DockerCommandRunner(**args)
local_mount = "/home/ubuntu/base/mount/"
remote_mount = "/root/protected_mount/"
docker_mount_prefix = get_docker_host_mount_location(cluster_name)
remote_host_mount = f"{docker_mount_prefix}{remote_mount}"
local_file = "/home/ubuntu/base-file"
remote_file = "/root/protected-file"
remote_host_file = f"{docker_mount_prefix}{remote_file}"
process_runner.respond_to_call("docker inspect -f", ["true"])
cmd_runner.run_rsync_up(
local_mount, remote_mount, options={"docker_mount_if_possible": True})
# Make sure we do not copy directly to raw destination
process_runner.assert_not_has_call(
"1.2.3.4", pattern=f"-avz {local_mount} ray@1.2.3.4:{remote_mount}")
process_runner.assert_not_has_call(
"1.2.3.4", pattern=f"mkdir -p {remote_mount}")
# No docker cp for file_mounts
process_runner.assert_not_has_call("1.2.3.4", pattern=f"docker cp")
process_runner.assert_has_call(
"1.2.3.4",
pattern=f"-avz {local_mount} ray@1.2.3.4:{remote_host_mount}")
process_runner.clear_history()
##############################
process_runner.respond_to_call("docker inspect -f", ["true"])
cmd_runner.run_rsync_up(
local_file, remote_file, options={"docker_mount_if_possible": False})
# Make sure we do not copy directly to raw destination
process_runner.assert_not_has_call(
"1.2.3.4", pattern=f"-avz {local_file} ray@1.2.3.4:{remote_file}")
process_runner.assert_not_has_call(
"1.2.3.4", pattern=f"mkdir -p {remote_file}")
process_runner.assert_has_call("1.2.3.4", pattern=f"docker cp")
process_runner.assert_has_call(
"1.2.3.4", pattern=f"-avz {local_file} ray@1.2.3.4:{remote_host_file}")
process_runner.clear_history()
##############################
cmd_runner.run_rsync_down(
remote_mount, local_mount, options={"docker_mount_if_possible": True})
process_runner.assert_not_has_call("1.2.3.4", pattern=f"docker cp")
process_runner.assert_not_has_call(
"1.2.3.4", pattern=f"-avz ray@1.2.3.4:{remote_mount} {local_mount}")
process_runner.assert_has_call(
"1.2.3.4",
pattern=f"-avz ray@1.2.3.4:{remote_host_mount} {local_mount}")
process_runner.clear_history()
##############################
cmd_runner.run_rsync_down(
remote_file, local_file, options={"docker_mount_if_possible": False})
process_runner.assert_has_call("1.2.3.4", pattern=f"docker cp")
process_runner.assert_not_has_call(
"1.2.3.4", pattern=f"-avz ray@1.2.3.4:{remote_file} {local_file}")
process_runner.assert_has_call(
"1.2.3.4", pattern=f"-avz ray@1.2.3.4:{remote_host_file} {local_file}")
def test_rsync_exclude_and_filter():
process_runner = MockProcessRunner()
provider = MockProvider()
provider.create_node({}, {}, 1)
cluster_name = "cluster"
args = {
"log_prefix": "prefix",
"node_id": 0,
"provider": provider,
"auth_config": auth_config,
"cluster_name": cluster_name,
"process_runner": process_runner,
"use_internal_ip": False,
}
cmd_runner = SSHCommandRunner(**args)
local_mount = "/home/ubuntu/base/mount/"
remote_mount = "/root/protected_mount/"
process_runner.respond_to_call("docker inspect -f", ["true"])
cmd_runner.run_rsync_up(
local_mount,
remote_mount,
options={
"docker_mount_if_possible": True,
"rsync_exclude": ["test"],
"rsync_filter": [".ignore"]
})
process_runner.assert_has_call(
"1.2.3.4", pattern=f"--exclude test --filter dir-merge,- .ignore")
def test_rsync_without_exclude_and_filter():
process_runner = MockProcessRunner()
provider = MockProvider()
provider.create_node({}, {}, 1)
cluster_name = "cluster"
args = {
"log_prefix": "prefix",
"node_id": 0,
"provider": provider,
"auth_config": auth_config,
"cluster_name": cluster_name,
"process_runner": process_runner,
"use_internal_ip": False,
}
cmd_runner = SSHCommandRunner(**args)
local_mount = "/home/ubuntu/base/mount/"
remote_mount = "/root/protected_mount/"
process_runner.respond_to_call("docker inspect -f", ["true"])
cmd_runner.run_rsync_up(
local_mount,
remote_mount,
options={
"docker_mount_if_possible": True,
})
process_runner.assert_not_has_call("1.2.3.4", pattern=f"--exclude test")
process_runner.assert_not_has_call(
"1.2.3.4", pattern=f"--filter dir-merge,- .ignore")
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))