diff --git a/python/ray/dashboard/dashboard.py b/python/ray/dashboard/dashboard.py index e8c510d42..6d7cc8e05 100644 --- a/python/ray/dashboard/dashboard.py +++ b/python/ray/dashboard/dashboard.py @@ -366,7 +366,7 @@ class RayletStats(threading.Thread): def __init__(self, redis_address, redis_password=None): self.nodes_lock = threading.Lock() self.nodes = [] - self.stubs = [] + self.stubs = {} self._raylet_stats_lock = threading.Lock() self._raylet_stats = {} @@ -378,13 +378,23 @@ class RayletStats(threading.Thread): def update_nodes(self): with self.nodes_lock: self.nodes = ray.nodes() - self.stubs = [] + node_ids = [node["NodeID"] for node in self.nodes] + # First remove node connections of disconnected nodes. + for node_id in self.stubs.keys(): + if node_id not in node_ids: + stub = self.stubs.pop(node_id) + stub.close() + + # Now add node connections of new nodes. for node in self.nodes: - channel = grpc.insecure_channel("{}:{}".format( - node["NodeManagerAddress"], node["NodeManagerPort"])) - stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) - self.stubs.append(stub) + node_id = node["NodeID"] + if node_id not in self.stubs: + channel = grpc.insecure_channel("{}:{}".format( + node["NodeManagerAddress"], node["NodeManagerPort"])) + stub = node_manager_pb2_grpc.NodeManagerServiceStub( + channel) + self.stubs[node_id] = stub def get_raylet_stats(self) -> Dict: with self._raylet_stats_lock: @@ -395,7 +405,9 @@ class RayletStats(threading.Thread): while True: time.sleep(1.0) with self._raylet_stats_lock: - for node, stub in zip(self.nodes, self.stubs): + for node in self.nodes: + node_id = node["NodeID"] + stub = self.stubs[node_id] reply = stub.GetNodeStats( node_manager_pb2.NodeStatsRequest()) self._raylet_stats[node[