mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 04:03:14 +08:00
[autoscaler] rsync cluster (#4785)
This commit is contained in:
@@ -423,6 +423,8 @@ def rsync(config_file, source, target, override_cluster_name, down):
|
||||
override_cluster_name: set the name of the cluster
|
||||
down: whether we're syncing remote -> local
|
||||
"""
|
||||
assert bool(source) == bool(target), (
|
||||
"Must either provide both or neither source and target.")
|
||||
|
||||
config = yaml.load(open(config_file).read())
|
||||
if override_cluster_name is not None:
|
||||
@@ -448,7 +450,12 @@ def rsync(config_file, source, target, override_cluster_name, down):
|
||||
rsync = updater.rsync_down
|
||||
else:
|
||||
rsync = updater.rsync_up
|
||||
rsync(source, target, check_error=False)
|
||||
|
||||
if source and target:
|
||||
rsync(source, target, check_error=False)
|
||||
else:
|
||||
updater.sync_file_mounts(rsync)
|
||||
|
||||
finally:
|
||||
provider.cleanup()
|
||||
|
||||
|
||||
@@ -183,25 +183,9 @@ class NodeUpdater(object):
|
||||
|
||||
return False
|
||||
|
||||
def do_update(self):
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "waiting-for-ssh"})
|
||||
|
||||
deadline = time.time() + NODE_START_WAIT_S
|
||||
self.set_ssh_ip_if_required()
|
||||
|
||||
# Wait for SSH access
|
||||
with LogTimer("NodeUpdater: " "{}: Got SSH".format(self.node_id)):
|
||||
ssh_ok = self.wait_for_ssh(deadline)
|
||||
assert ssh_ok, "Unable to SSH to node"
|
||||
|
||||
def sync_file_mounts(self, sync_cmd):
|
||||
# Rsync file mounts
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "syncing-files"})
|
||||
for remote_path, local_path in self.file_mounts.items():
|
||||
logger.info("NodeUpdater: "
|
||||
"{}: Syncing {} to {}...".format(
|
||||
self.node_id, local_path, remote_path))
|
||||
assert os.path.exists(local_path), local_path
|
||||
if os.path.isdir(local_path):
|
||||
if not local_path.endswith("/"):
|
||||
@@ -217,7 +201,23 @@ class NodeUpdater(object):
|
||||
"mkdir -p {}".format(os.path.dirname(remote_path)),
|
||||
redirect=redirect,
|
||||
)
|
||||
self.rsync_up(local_path, remote_path, redirect=redirect)
|
||||
sync_cmd(local_path, remote_path, redirect=redirect)
|
||||
|
||||
def do_update(self):
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "waiting-for-ssh"})
|
||||
|
||||
deadline = time.time() + NODE_START_WAIT_S
|
||||
self.set_ssh_ip_if_required()
|
||||
|
||||
# Wait for SSH access
|
||||
with LogTimer("NodeUpdater: " "{}: Got SSH".format(self.node_id)):
|
||||
ssh_ok = self.wait_for_ssh(deadline)
|
||||
assert ssh_ok, "Unable to SSH to node"
|
||||
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "syncing-files"})
|
||||
self.sync_file_mounts(self.rsync_up)
|
||||
|
||||
# Run init commands
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
@@ -236,6 +236,9 @@ class NodeUpdater(object):
|
||||
self.ssh_cmd(cmd, redirect=redirect)
|
||||
|
||||
def rsync_up(self, source, target, redirect=None, check_error=True):
|
||||
logger.info("NodeUpdater: "
|
||||
"{}: Syncing {} to {}...".format(self.node_id, source,
|
||||
target))
|
||||
self.set_ssh_ip_if_required()
|
||||
self.get_caller(check_error)(
|
||||
[
|
||||
@@ -247,6 +250,9 @@ class NodeUpdater(object):
|
||||
stderr=redirect or sys.stderr)
|
||||
|
||||
def rsync_down(self, source, target, redirect=None, check_error=True):
|
||||
logger.info("NodeUpdater: "
|
||||
"{}: Syncing {} from {}...".format(self.node_id, source,
|
||||
target))
|
||||
self.set_ssh_ip_if_required()
|
||||
self.get_caller(check_error)(
|
||||
[
|
||||
|
||||
@@ -529,8 +529,8 @@ def attach(cluster_config_file, start, tmux, cluster_name, new):
|
||||
|
||||
@cli.command()
|
||||
@click.argument("cluster_config_file", required=True, type=str)
|
||||
@click.argument("source", required=True, type=str)
|
||||
@click.argument("target", required=True, type=str)
|
||||
@click.argument("source", required=False, type=str)
|
||||
@click.argument("target", required=False, type=str)
|
||||
@click.option(
|
||||
"--cluster-name",
|
||||
"-n",
|
||||
@@ -543,8 +543,8 @@ def rsync_down(cluster_config_file, source, target, cluster_name):
|
||||
|
||||
@cli.command()
|
||||
@click.argument("cluster_config_file", required=True, type=str)
|
||||
@click.argument("source", required=True, type=str)
|
||||
@click.argument("target", required=True, type=str)
|
||||
@click.argument("source", required=False, type=str)
|
||||
@click.argument("target", required=False, type=str)
|
||||
@click.option(
|
||||
"--cluster-name",
|
||||
"-n",
|
||||
|
||||
@@ -49,6 +49,11 @@ parser.add_argument(
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="disables CUDA training")
|
||||
parser.add_argument(
|
||||
"--redis-address",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The Redis address of the cluster.")
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
@@ -173,7 +178,7 @@ if __name__ == "__main__":
|
||||
from ray import tune
|
||||
from ray.tune.schedulers import HyperBandScheduler
|
||||
|
||||
ray.init()
|
||||
ray.init(redis_address=args.redis_address)
|
||||
sched = HyperBandScheduler(
|
||||
time_attr="training_iteration", reward_attr="neg_mean_loss")
|
||||
tune.run(
|
||||
|
||||
Reference in New Issue
Block a user