diff --git a/python/ray/autoscaler/command_runner.py b/python/ray/autoscaler/command_runner.py index 6a10654ea..4b38108eb 100644 --- a/python/ray/autoscaler/command_runner.py +++ b/python/ray/autoscaler/command_runner.py @@ -220,7 +220,8 @@ class SSHOptions: self.arg_dict["ConnectTimeout"] = "{}s".format(timeout) return ["-i", self.ssh_key] + [ x for y in (["-o", "{}={}".format(k, v)] - for k, v in self.arg_dict.items()) for x in y + for k, v in self.arg_dict.items() + if v is not None) for x in y ] @@ -243,8 +244,11 @@ class SSHCommandRunner(CommandRunnerInterface): self.ssh_user = auth_config["ssh_user"] self.ssh_control_path = ssh_control_path self.ssh_ip = None - self.base_ssh_options = SSHOptions(self.ssh_private_key, - self.ssh_control_path) + self.ssh_proxy_command = auth_config.get("ssh_proxy_command", None) + self.ssh_options = SSHOptions( + self.ssh_private_key, + self.ssh_control_path, + ProxyCommand=self.ssh_proxy_command) def _get_node_ip(self): if self.use_internal_ip: @@ -292,7 +296,7 @@ class SSHCommandRunner(CommandRunnerInterface): with_output=False, ssh_options_override=None, **kwargs): - ssh_options = ssh_options_override or self.base_ssh_options + ssh_options = ssh_options_override or self.ssh_options assert isinstance( ssh_options, SSHOptions @@ -342,8 +346,8 @@ class SSHCommandRunner(CommandRunnerInterface): self._set_ssh_ip_if_required() self.process_runner.check_call([ "rsync", "--rsh", - " ".join(["ssh"] + - self.base_ssh_options.to_ssh_options_list(timeout=120)), + subprocess.list2cmdline( + ["ssh"] + self.ssh_options.to_ssh_options_list(timeout=120)), "-avz", source, "{}@{}:{}".format(self.ssh_user, self.ssh_ip, target) ]) @@ -352,8 +356,8 @@ class SSHCommandRunner(CommandRunnerInterface): self._set_ssh_ip_if_required() self.process_runner.check_call([ "rsync", "--rsh", - " ".join(["ssh"] + - self.base_ssh_options.to_ssh_options_list(timeout=120)), + subprocess.list2cmdline( + ["ssh"] + self.ssh_options.to_ssh_options_list(timeout=120)), "-avz", "{}@{}:{}".format(self.ssh_user, self.ssh_ip, source), target ]) diff --git a/python/ray/autoscaler/commands.py b/python/ray/autoscaler/commands.py index a60f5a9bb..1840dcdc6 100644 --- a/python/ray/autoscaler/commands.py +++ b/python/ray/autoscaler/commands.py @@ -321,6 +321,9 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes, # Rewrite the auth config so that the head node can update the workers remote_config = copy.deepcopy(config) + # drop proxy options if they exist, otherwise + # head node won't be able to connect to workers + remote_config["auth"].pop("ssh_proxy_command", None) if config["provider"]["type"] != "kubernetes": remote_key_path = "~/ray_bootstrap_key.pem" remote_config["auth"]["ssh_private_key"] = remote_key_path diff --git a/python/ray/autoscaler/ray-schema.json b/python/ray/autoscaler/ray-schema.json index ad299ebf5..2c3fead8e 100644 --- a/python/ray/autoscaler/ray-schema.json +++ b/python/ray/autoscaler/ray-schema.json @@ -173,6 +173,10 @@ }, "ssh_private_key": { "type": "string" + }, + "ssh_proxy_command": { + "description": "A value for ProxyCommand ssh option, for connecting through proxies. Example: nc -x proxy.example.com:1234 %h %p", + "type": "string" } } },