[autoscaler] Small fixes for local cluster usability (#4864)

This commit is contained in:
Richard Liaw
2019-07-06 21:55:18 -07:00
committed by Eric Liang
parent 1798d4f077
commit 6a14f1a540
2 changed files with 22 additions and 12 deletions
+15 -4
View File
@@ -28,10 +28,15 @@ class ClusterState(object):
with self.file_lock:
if os.path.exists(self.save_path):
workers = json.loads(open(self.save_path).read())
head_config = workers.get(provider_config["head_ip"])
if not head_config or head_config.get(
"tags", {}).get(TAG_RAY_NODE_TYPE) != "head":
workers = {}
logger.info("Head IP changed - recreating cluster.")
else:
workers = {}
logger.info("ClusterState: "
"Loaded cluster state: {}".format(workers))
"Loaded cluster state: {}".format(list(workers)))
for worker_ip in provider_config["worker_ips"]:
if worker_ip not in workers:
workers[worker_ip] = {
@@ -55,8 +60,8 @@ class ClusterState(object):
TAG_RAY_NODE_TYPE] == "head"
assert len(workers) == len(provider_config["worker_ips"]) + 1
with open(self.save_path, "w") as f:
logger.info("ClusterState: "
"Writing cluster state: {}".format(workers))
logger.debug("ClusterState: "
"Writing cluster state: {}".format(workers))
f.write(json.dumps(workers))
def get(self):
@@ -74,11 +79,17 @@ class ClusterState(object):
workers[worker_id] = info
with open(self.save_path, "w") as f:
logger.info("ClusterState: "
"Writing cluster state: {}".format(workers))
"Writing cluster state: {}".format(
list(workers)))
f.write(json.dumps(workers))
class LocalNodeProvider(NodeProvider):
"""NodeProvider for private/local clusters.
`node_id` is overloaded to also be `node_ip` in this class.
"""
def __init__(self, provider_config, cluster_name):
NodeProvider.__init__(self, provider_config, cluster_name)
self.state = ClusterState("/tmp/cluster-{}.lock".format(cluster_name),
+7 -8
View File
@@ -23,6 +23,7 @@ logger = logging.getLogger(__name__)
# How long to wait for a node to start, in seconds
NODE_START_WAIT_S = 300
SSH_CHECK_INTERVAL = 5
CONTROL_PATH_MAX_LENGTH = 70
def get_default_ssh_options(private_key, connect_timeout, ssh_control_path):
@@ -56,7 +57,7 @@ class NodeUpdater(object):
use_internal_ip=False):
ssh_control_path = "/tmp/{}_ray_ssh_sockets/{}".format(
getuser(), cluster_name)
getuser(), cluster_name)[:CONTROL_PATH_MAX_LENGTH]
self.daemon = True
self.process_runner = process_runner
@@ -197,12 +198,11 @@ class NodeUpdater(object):
m = "{}: Synced {} to {}".format(self.node_id, local_path,
remote_path)
with LogTimer("NodeUpdater {}".format(m)):
with open("/dev/null", "w") as redirect:
self.ssh_cmd(
"mkdir -p {}".format(os.path.dirname(remote_path)),
redirect=redirect,
)
sync_cmd(local_path, remote_path, redirect=redirect)
self.ssh_cmd(
"mkdir -p {}".format(os.path.dirname(remote_path)),
redirect=None,
)
sync_cmd(local_path, remote_path, redirect=None)
def do_update(self):
self.provider.set_node_tags(self.node_id,
@@ -223,7 +223,6 @@ class NodeUpdater(object):
# Run init commands
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "setting-up"})
m = "{}: Initialization commands completed".format(self.node_id)
with LogTimer("NodeUpdater: {}".format(m)):
for cmd in self.initialization_commands: