[autoscaler] rsync cluster (#4785)

This commit is contained in:
Richard Liaw
2019-05-16 23:11:06 -07:00
committed by GitHub
parent ffe61fcc70
commit 88b45a53d6
4 changed files with 42 additions and 24 deletions
+8 -1
View File
@@ -423,6 +423,8 @@ def rsync(config_file, source, target, override_cluster_name, down):
override_cluster_name: set the name of the cluster
down: whether we're syncing remote -> local
"""
assert bool(source) == bool(target), (
"Must either provide both or neither source and target.")
config = yaml.load(open(config_file).read())
if override_cluster_name is not None:
@@ -448,7 +450,12 @@ def rsync(config_file, source, target, override_cluster_name, down):
rsync = updater.rsync_down
else:
rsync = updater.rsync_up
rsync(source, target, check_error=False)
if source and target:
rsync(source, target, check_error=False)
else:
updater.sync_file_mounts(rsync)
finally:
provider.cleanup()
+24 -18
View File
@@ -183,25 +183,9 @@ class NodeUpdater(object):
return False
def do_update(self):
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "waiting-for-ssh"})
deadline = time.time() + NODE_START_WAIT_S
self.set_ssh_ip_if_required()
# Wait for SSH access
with LogTimer("NodeUpdater: " "{}: Got SSH".format(self.node_id)):
ssh_ok = self.wait_for_ssh(deadline)
assert ssh_ok, "Unable to SSH to node"
def sync_file_mounts(self, sync_cmd):
# Rsync file mounts
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "syncing-files"})
for remote_path, local_path in self.file_mounts.items():
logger.info("NodeUpdater: "
"{}: Syncing {} to {}...".format(
self.node_id, local_path, remote_path))
assert os.path.exists(local_path), local_path
if os.path.isdir(local_path):
if not local_path.endswith("/"):
@@ -217,7 +201,23 @@ class NodeUpdater(object):
"mkdir -p {}".format(os.path.dirname(remote_path)),
redirect=redirect,
)
self.rsync_up(local_path, remote_path, redirect=redirect)
sync_cmd(local_path, remote_path, redirect=redirect)
def do_update(self):
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "waiting-for-ssh"})
deadline = time.time() + NODE_START_WAIT_S
self.set_ssh_ip_if_required()
# Wait for SSH access
with LogTimer("NodeUpdater: " "{}: Got SSH".format(self.node_id)):
ssh_ok = self.wait_for_ssh(deadline)
assert ssh_ok, "Unable to SSH to node"
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "syncing-files"})
self.sync_file_mounts(self.rsync_up)
# Run init commands
self.provider.set_node_tags(self.node_id,
@@ -236,6 +236,9 @@ class NodeUpdater(object):
self.ssh_cmd(cmd, redirect=redirect)
def rsync_up(self, source, target, redirect=None, check_error=True):
logger.info("NodeUpdater: "
"{}: Syncing {} to {}...".format(self.node_id, source,
target))
self.set_ssh_ip_if_required()
self.get_caller(check_error)(
[
@@ -247,6 +250,9 @@ class NodeUpdater(object):
stderr=redirect or sys.stderr)
def rsync_down(self, source, target, redirect=None, check_error=True):
logger.info("NodeUpdater: "
"{}: Syncing {} from {}...".format(self.node_id, source,
target))
self.set_ssh_ip_if_required()
self.get_caller(check_error)(
[
+4 -4
View File
@@ -529,8 +529,8 @@ def attach(cluster_config_file, start, tmux, cluster_name, new):
@cli.command()
@click.argument("cluster_config_file", required=True, type=str)
@click.argument("source", required=True, type=str)
@click.argument("target", required=True, type=str)
@click.argument("source", required=False, type=str)
@click.argument("target", required=False, type=str)
@click.option(
"--cluster-name",
"-n",
@@ -543,8 +543,8 @@ def rsync_down(cluster_config_file, source, target, cluster_name):
@cli.command()
@click.argument("cluster_config_file", required=True, type=str)
@click.argument("source", required=True, type=str)
@click.argument("target", required=True, type=str)
@click.argument("source", required=False, type=str)
@click.argument("target", required=False, type=str)
@click.option(
"--cluster-name",
"-n",
@@ -49,6 +49,11 @@ parser.add_argument(
action="store_true",
default=False,
help="disables CUDA training")
parser.add_argument(
"--redis-address",
default=None,
type=str,
help="The Redis address of the cluster.")
parser.add_argument(
"--seed",
type=int,
@@ -173,7 +178,7 @@ if __name__ == "__main__":
from ray import tune
from ray.tune.schedulers import HyperBandScheduler
ray.init()
ray.init(redis_address=args.redis_address)
sched = HyperBandScheduler(
time_attr="training_iteration", reward_attr="neg_mean_loss")
tune.run(