diff --git a/python/ray/autoscaler/commands.py b/python/ray/autoscaler/commands.py
index 807851b0f..be028afc3 100644
--- a/python/ray/autoscaler/commands.py
+++ b/python/ray/autoscaler/commands.py
@@ -446,7 +446,12 @@ def _exec(updater, cmd, screen, tmux, port_forward=None):
         cmd, allocate_tty=True, exit_on_fail=True, port_forward=port_forward)
 
 
-def rsync(config_file, source, target, override_cluster_name, down):
+def rsync(config_file,
+          source,
+          target,
+          override_cluster_name,
+          down,
+          all_nodes=False):
     """Rsyncs files.
 
     Arguments:
@@ -455,6 +460,7 @@ def rsync(config_file, source, target, override_cluster_name, down):
         target: target dir
         override_cluster_name: set the name of the cluster
         down: whether we're syncing remote -> local
+        all_nodes: whether to sync worker nodes in addition to the head node
     """
     assert bool(source) == bool(target), (
         "Must either provide both or neither source and target.")
@@ -463,32 +469,46 @@ def rsync(config_file, source, target, override_cluster_name, down):
     if override_cluster_name is not None:
         config["cluster_name"] = override_cluster_name
     config = _bootstrap_config(config)
-    head_node = _get_head_node(
-        config, config_file, override_cluster_name, create_if_needed=False)
 
     provider = get_node_provider(config["provider"], config["cluster_name"])
     try:
-        updater = NodeUpdaterThread(
-            node_id=head_node,
-            provider_config=config["provider"],
-            provider=provider,
-            auth_config=config["auth"],
-            cluster_name=config["cluster_name"],
-            file_mounts=config["file_mounts"],
-            initialization_commands=[],
-            setup_commands=[],
-            ray_start_commands=[],
-            runtime_hash="",
-        )
-        if down:
-            rsync = updater.rsync_down
-        else:
-            rsync = updater.rsync_up
+        nodes = []
+        if all_nodes:
+            # _get_worker_nodes opens (and cleans up) its own provider
+            # instead of reusing ours; redundant, but it keeps the helper
+            # self-contained, the same way _get_head_node works.
+            nodes = _get_worker_nodes(config, override_cluster_name)
 
-        if source and target:
-            rsync(source, target)
-        else:
-            updater.sync_file_mounts(rsync)
+        nodes += [
+            _get_head_node(
+                config,
+                config_file,
+                override_cluster_name,
+                create_if_needed=False)
+        ]
+
+        for node_id in nodes:
+            updater = NodeUpdaterThread(
+                node_id=node_id,
+                provider_config=config["provider"],
+                provider=provider,
+                auth_config=config["auth"],
+                cluster_name=config["cluster_name"],
+                file_mounts=config["file_mounts"],
+                initialization_commands=[],
+                setup_commands=[],
+                ray_start_commands=[],
+                runtime_hash="",
+            )
+            if down:
+                rsync = updater.rsync_down
+            else:
+                rsync = updater.rsync_up
+
+            if source and target:
+                rsync(source, target)
+            else:
+                updater.sync_file_mounts(rsync)
     finally:
         provider.cleanup()
 
@@ -535,6 +555,30 @@ def get_worker_node_ips(config_file, override_cluster_name):
         provider.cleanup()
 
 
+def _get_worker_nodes(config, override_cluster_name):
+    """Returns worker node ids for given configuration.
+
+    Arguments:
+        config: cluster config dict; its "cluster_name" entry is
+            overwritten in place when override_cluster_name is given
+        override_cluster_name: set the name of the cluster
+
+    Returns:
+        The provider's non-terminated nodes tagged as workers.
+    """
+    # TODO: get_worker_node_ips could be built on top of this helper
+    if override_cluster_name is not None:
+        config["cluster_name"] = override_cluster_name
+
+    provider = get_node_provider(config["provider"], config["cluster_name"])
+    try:
+        return provider.non_terminated_nodes({
+            TAG_RAY_NODE_TYPE: NODE_TYPE_WORKER
+        })
+    finally:
+        provider.cleanup()
+
+
 def _get_head_node(config,
                    config_file,
                    override_cluster_name,
diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py
index 1d3626c41..cacf2bfe0 100644
--- a/python/ray/scripts/scripts.py
+++ b/python/ray/scripts/scripts.py
@@ -696,8 +696,20 @@ def rsync_down(cluster_config_file, source, target, cluster_name):
     required=False,
     type=str,
     help="Override the configured cluster name.")
-def rsync_up(cluster_config_file, source, target, cluster_name):
-    rsync(cluster_config_file, source, target, cluster_name, down=False)
+@click.option(
+    "--all-nodes",
+    "-A",
+    is_flag=True,
+    required=False,
+    help="Upload to all nodes (workers and head).")
+def rsync_up(cluster_config_file, source, target, cluster_name, all_nodes):
+    rsync(
+        cluster_config_file,
+        source,
+        target,
+        cluster_name,
+        down=False,
+        all_nodes=all_nodes)
 
 
 @cli.command(context_settings={"ignore_unknown_options": True})