mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 00:55:31 +08:00
[autoscaler] Small fixes for local cluster usability (#4864)
This commit is contained in:
@@ -28,10 +28,15 @@ class ClusterState(object):
|
||||
with self.file_lock:
|
||||
if os.path.exists(self.save_path):
|
||||
workers = json.loads(open(self.save_path).read())
|
||||
head_config = workers.get(provider_config["head_ip"])
|
||||
if not head_config or head_config.get(
|
||||
"tags", {}).get(TAG_RAY_NODE_TYPE) != "head":
|
||||
workers = {}
|
||||
logger.info("Head IP changed - recreating cluster.")
|
||||
else:
|
||||
workers = {}
|
||||
logger.info("ClusterState: "
|
||||
"Loaded cluster state: {}".format(workers))
|
||||
"Loaded cluster state: {}".format(list(workers)))
|
||||
for worker_ip in provider_config["worker_ips"]:
|
||||
if worker_ip not in workers:
|
||||
workers[worker_ip] = {
|
||||
@@ -55,8 +60,8 @@ class ClusterState(object):
|
||||
TAG_RAY_NODE_TYPE] == "head"
|
||||
assert len(workers) == len(provider_config["worker_ips"]) + 1
|
||||
with open(self.save_path, "w") as f:
|
||||
logger.info("ClusterState: "
|
||||
"Writing cluster state: {}".format(workers))
|
||||
logger.debug("ClusterState: "
|
||||
"Writing cluster state: {}".format(workers))
|
||||
f.write(json.dumps(workers))
|
||||
|
||||
def get(self):
|
||||
@@ -74,11 +79,17 @@ class ClusterState(object):
|
||||
workers[worker_id] = info
|
||||
with open(self.save_path, "w") as f:
|
||||
logger.info("ClusterState: "
|
||||
"Writing cluster state: {}".format(workers))
|
||||
"Writing cluster state: {}".format(
|
||||
list(workers)))
|
||||
f.write(json.dumps(workers))
|
||||
|
||||
|
||||
class LocalNodeProvider(NodeProvider):
|
||||
"""NodeProvider for private/local clusters.
|
||||
|
||||
`node_id` is overloaded to also be `node_ip` in this class.
|
||||
"""
|
||||
|
||||
def __init__(self, provider_config, cluster_name):
|
||||
NodeProvider.__init__(self, provider_config, cluster_name)
|
||||
self.state = ClusterState("/tmp/cluster-{}.lock".format(cluster_name),
|
||||
|
||||
@@ -23,6 +23,7 @@ logger = logging.getLogger(__name__)
|
||||
# How long to wait for a node to start, in seconds
|
||||
NODE_START_WAIT_S = 300
|
||||
SSH_CHECK_INTERVAL = 5
|
||||
CONTROL_PATH_MAX_LENGTH = 70
|
||||
|
||||
|
||||
def get_default_ssh_options(private_key, connect_timeout, ssh_control_path):
|
||||
@@ -56,7 +57,7 @@ class NodeUpdater(object):
|
||||
use_internal_ip=False):
|
||||
|
||||
ssh_control_path = "/tmp/{}_ray_ssh_sockets/{}".format(
|
||||
getuser(), cluster_name)
|
||||
getuser(), cluster_name)[:CONTROL_PATH_MAX_LENGTH]
|
||||
|
||||
self.daemon = True
|
||||
self.process_runner = process_runner
|
||||
@@ -197,12 +198,11 @@ class NodeUpdater(object):
|
||||
m = "{}: Synced {} to {}".format(self.node_id, local_path,
|
||||
remote_path)
|
||||
with LogTimer("NodeUpdater {}".format(m)):
|
||||
with open("/dev/null", "w") as redirect:
|
||||
self.ssh_cmd(
|
||||
"mkdir -p {}".format(os.path.dirname(remote_path)),
|
||||
redirect=redirect,
|
||||
)
|
||||
sync_cmd(local_path, remote_path, redirect=redirect)
|
||||
self.ssh_cmd(
|
||||
"mkdir -p {}".format(os.path.dirname(remote_path)),
|
||||
redirect=None,
|
||||
)
|
||||
sync_cmd(local_path, remote_path, redirect=None)
|
||||
|
||||
def do_update(self):
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
@@ -223,7 +223,6 @@ class NodeUpdater(object):
|
||||
# Run init commands
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "setting-up"})
|
||||
|
||||
m = "{}: Initialization commands completed".format(self.node_id)
|
||||
with LogTimer("NodeUpdater: {}".format(m)):
|
||||
for cmd in self.initialization_commands:
|
||||
|
||||
Reference in New Issue
Block a user