[autoscaler] Add `--all-nodes` option to rsync-up (#7065)

* Add option to sync workers to rsync-up

* Format

* Rename --sync-workers to --all-nodes
This commit is contained in:
Maksim Smolin
2020-02-10 16:27:59 -08:00
committed by GitHub
parent ad1848b623
commit 4139e02f01
2 changed files with 72 additions and 25 deletions
+58 -23
View File
@@ -446,7 +446,12 @@ def _exec(updater, cmd, screen, tmux, port_forward=None):
cmd, allocate_tty=True, exit_on_fail=True, port_forward=port_forward)
def rsync(config_file, source, target, override_cluster_name, down):
def rsync(config_file,
source,
target,
override_cluster_name,
down,
all_nodes=False):
"""Rsyncs files.
Arguments:
@@ -455,6 +460,7 @@ def rsync(config_file, source, target, override_cluster_name, down):
target: target dir
override_cluster_name: set the name of the cluster
down: whether we're syncing remote -> local
all_nodes: whether to sync worker nodes in addition to the head node
"""
assert bool(source) == bool(target), (
"Must either provide both or neither source and target.")
@@ -463,32 +469,46 @@ def rsync(config_file, source, target, override_cluster_name, down):
if override_cluster_name is not None:
config["cluster_name"] = override_cluster_name
config = _bootstrap_config(config)
head_node = _get_head_node(
config, config_file, override_cluster_name, create_if_needed=False)
provider = get_node_provider(config["provider"], config["cluster_name"])
try:
updater = NodeUpdaterThread(
node_id=head_node,
provider_config=config["provider"],
provider=provider,
auth_config=config["auth"],
cluster_name=config["cluster_name"],
file_mounts=config["file_mounts"],
initialization_commands=[],
setup_commands=[],
ray_start_commands=[],
runtime_hash="",
)
if down:
rsync = updater.rsync_down
else:
rsync = updater.rsync_up
nodes = []
if all_nodes:
# technically we re-open the provider for no reason
# in get_worker_nodes but it's cleaner this way
# and _get_head_node does this too
nodes = _get_worker_nodes(config, override_cluster_name)
if source and target:
rsync(source, target)
else:
updater.sync_file_mounts(rsync)
nodes += [
_get_head_node(
config,
config_file,
override_cluster_name,
create_if_needed=False)
]
for node_id in nodes:
updater = NodeUpdaterThread(
node_id=node_id,
provider_config=config["provider"],
provider=provider,
auth_config=config["auth"],
cluster_name=config["cluster_name"],
file_mounts=config["file_mounts"],
initialization_commands=[],
setup_commands=[],
ray_start_commands=[],
runtime_hash="",
)
if down:
rsync = updater.rsync_down
else:
rsync = updater.rsync_up
if source and target:
rsync(source, target)
else:
updater.sync_file_mounts(rsync)
finally:
provider.cleanup()
@@ -535,6 +555,21 @@ def get_worker_node_ips(config_file, override_cluster_name):
provider.cleanup()
def _get_worker_nodes(config, override_cluster_name):
    """Look up the ids of the cluster's non-terminated worker nodes.

    Arguments:
        config: cluster configuration dict; its "provider" and
            "cluster_name" entries are read. Note that when
            override_cluster_name is given, "cluster_name" is
            overwritten in place on this dict.
        override_cluster_name: cluster name to use instead of the
            configured one, or None to keep the configured name.

    Returns:
        The result of provider.non_terminated_nodes() filtered on the
        worker node-type tag, i.e. the worker node ids.
    """
    # TODO: get_worker_node_ips could probably be built on this helper.
    if override_cluster_name is not None:
        config["cluster_name"] = override_cluster_name

    node_provider = get_node_provider(config["provider"],
                                      config["cluster_name"])
    try:
        worker_tag_filter = {TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER}
        worker_ids = node_provider.non_terminated_nodes(worker_tag_filter)
    finally:
        # Always release the provider, even if the lookup raised.
        node_provider.cleanup()
    return worker_ids
def _get_head_node(config,
config_file,
override_cluster_name,
+14 -2
View File
@@ -696,8 +696,20 @@ def rsync_down(cluster_config_file, source, target, cluster_name):
required=False,
type=str,
help="Override the configured cluster name.")
def rsync_up(cluster_config_file, source, target, cluster_name):
rsync(cluster_config_file, source, target, cluster_name, down=False)
@click.option(
"--all-nodes",
"-A",
is_flag=True,
required=False,
help="Upload to all nodes (workers and head).")
def rsync_up(cluster_config_file, source, target, cluster_name, all_nodes):
    # CLI entry point for uploads: forwards the parsed options to rsync()
    # with down=False. all_nodes (the --all-nodes / -A flag) is passed
    # through so that workers are synced in addition to the head node.
    #
    # Documented with comments rather than a docstring on purpose: click
    # would surface a docstring as the command's help text.
    rsync(
        config_file=cluster_config_file,
        source=source,
        target=target,
        override_cluster_name=cluster_name,
        down=False,
        all_nodes=all_nodes,
    )
@cli.command(context_settings={"ignore_unknown_options": True})