diff --git a/python/ray/autoscaler/autoscaler.py b/python/ray/autoscaler/autoscaler.py index 0d2f39563..be08b86d6 100644 --- a/python/ray/autoscaler/autoscaler.py +++ b/python/ray/autoscaler/autoscaler.py @@ -321,7 +321,8 @@ class StandardAutoscaler: self.config["worker_start_ray_commands"]), runtime_hash=self.runtime_hash, process_runner=self.process_runner, - use_internal_ip=True) + use_internal_ip=True, + docker_config=self.config["docker"]) updater.start() self.updaters[node_id] = updater @@ -360,7 +361,8 @@ class StandardAutoscaler: ray_start_commands=with_head_node_ip(ray_start_commands), runtime_hash=self.runtime_hash, process_runner=self.process_runner, - use_internal_ip=True) + use_internal_ip=True, + docker_config=self.config["docker"]) updater.start() self.updaters[node_id] = updater diff --git a/python/ray/autoscaler/commands.py b/python/ray/autoscaler/commands.py index 121d2970d..e8260545d 100644 --- a/python/ray/autoscaler/commands.py +++ b/python/ray/autoscaler/commands.py @@ -147,7 +147,8 @@ def kill_node(config_file, yes, hard, override_cluster_name): initialization_commands=[], setup_commands=[], ray_start_commands=[], - runtime_hash="") + runtime_hash="", + docker_config=config["docker"]) _exec(updater, "ray stop", False, False) @@ -286,7 +287,7 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes, setup_commands=init_commands, ray_start_commands=ray_start_commands, runtime_hash=runtime_hash, - ) + docker_config=config["docker"]) updater.start() updater.join() @@ -407,7 +408,7 @@ def exec_cluster(config_file, setup_commands=[], ray_start_commands=[], runtime_hash="", - ) + docker_config=config["docker"]) def wrap_docker(command): container_name = config["docker"]["container_name"] @@ -529,7 +530,7 @@ def rsync(config_file, setup_commands=[], ray_start_commands=[], runtime_hash="", - ) + docker_config=config["docker"]) if down: rsync = updater.rsync_down else: diff --git a/python/ray/autoscaler/kubernetes/node_provider.py b/python/ray/autoscaler/kubernetes/node_provider.py index 3471e4e76..88c3a6d83 100644 --- a/python/ray/autoscaler/kubernetes/node_provider.py +++ b/python/ray/autoscaler/kubernetes/node_provider.py @@ -87,7 +87,13 @@ class KubernetesNodeProvider(NodeProvider): for node_id in node_ids: self.terminate_node(node_id) - def get_command_runner(self, log_prefix, node_id, auth_config, - cluster_name, process_runner, use_internal_ip): + def get_command_runner(self, + log_prefix, + node_id, + auth_config, + cluster_name, + process_runner, + use_internal_ip, + docker_config=None): return KubernetesCommandRunner(log_prefix, self.namespace, node_id, auth_config, process_runner) diff --git a/python/ray/autoscaler/node_provider.py b/python/ray/autoscaler/node_provider.py index b0cc39f2a..15b6f8573 100644 --- a/python/ray/autoscaler/node_provider.py +++ b/python/ray/autoscaler/node_provider.py @@ -3,7 +3,7 @@ import logging import os import yaml -from ray.autoscaler.updater import SSHCommandRunner +from ray.autoscaler.updater import SSHCommandRunner, DockerCommandRunner logger = logging.getLogger(__name__) @@ -211,8 +211,14 @@ class NodeProvider: """Clean-up when a Provider is no longer required.""" pass - def get_command_runner(self, log_prefix, node_id, auth_config, - cluster_name, process_runner, use_internal_ip): + def get_command_runner(self, + log_prefix, + node_id, + auth_config, + cluster_name, + process_runner, + use_internal_ip, + docker_config=None): """ Returns the CommandRunner class used to perform SSH commands. Args: @@ -226,7 +232,19 @@ class NodeProvider: in the CommandRunner. E.g., subprocess. use_internal_ip(bool): whether the node_id belongs to an internal ip or external ip. + docker_config(dict): If set, the docker information of the docker + container that commands should be run on. """ - - return SSHCommandRunner(log_prefix, node_id, self, auth_config, - cluster_name, process_runner, use_internal_ip) + common_args = { + "log_prefix": log_prefix, + "node_id": node_id, + "provider": self, + "auth_config": auth_config, + "cluster_name": cluster_name, + "process_runner": process_runner, + "use_internal_ip": use_internal_ip + } + if docker_config and docker_config["container_name"] != "": + return DockerCommandRunner(docker_config, **common_args) + else: + return SSHCommandRunner(**common_args) diff --git a/python/ray/autoscaler/updater.py b/python/ray/autoscaler/updater.py index cf253e775..968ca2f93 100644 --- a/python/ray/autoscaler/updater.py +++ b/python/ray/autoscaler/updater.py @@ -294,6 +294,48 @@ class SSHCommandRunner: self.ssh_private_key, self.ssh_user, self.ssh_ip) +class DockerCommandRunner(SSHCommandRunner): + def __init__(self, docker_config, **common_args): + self.ssh_command_runner = SSHCommandRunner(**common_args) + self.docker_name = docker_config["container_name"] + self.docker_config = docker_config + + def run(self, + cmd, + timeout=120, + allocate_tty=False, + exit_on_fail=False, + port_forward=None, + with_output=False): + return self.ssh_command_runner.run(cmd, timeout, allocate_tty, + exit_on_fail, port_forward, + with_output) + + def run_rsync_up(self, source, target): + self.ssh_command_runner.run_rsync_up(source, target) + self.ssh_command_runner.run("docker cp {} {}:{}".format( + target, self.docker_name, self.docker_expand_user(target))) + + def run_rsync_down(self, source, target): + self.ssh_command_runner.run("docker cp {}:{} {}".format( + self.docker_name, self.docker_expand_user(source), source)) + self.ssh_command_runner.run_rsync_down(source, target) + + def remote_shell_command_str(self): + inner_str = self.ssh_command_runner.remote_shell_command_str().replace( + "ssh", "ssh -tt", 1).strip("\n") + return inner_str + " docker exec -it {} /bin/bash\n".format( + self.docker_name) + + def docker_expand_user(self, string): + if string.find("~") == 0: + return string.replace( + "~", + "`docker exec ray_docker env | grep HOME | cut -d'=' -f2`", 1) + else: + return string + + class NodeUpdater: """A process for syncing files and running init commands on a node.""" @@ -309,14 +351,15 @@ class NodeUpdater: ray_start_commands, runtime_hash, process_runner=subprocess, - use_internal_ip=False): + use_internal_ip=False, + docker_config=None): self.log_prefix = "NodeUpdater: {}: ".format(node_id) use_internal_ip = (use_internal_ip or provider_config.get("use_internal_ips", False)) self.cmd_runner = provider.get_command_runner( self.log_prefix, node_id, auth_config, cluster_name, - process_runner, use_internal_ip) + process_runner, use_internal_ip, docker_config) self.daemon = True self.process_runner = process_runner