diff --git a/ci/travis/test-wheels.sh b/ci/travis/test-wheels.sh index 3a9809f55..9d9f62b22 100755 --- a/ci/travis/test-wheels.sh +++ b/ci/travis/test-wheels.sh @@ -64,7 +64,7 @@ if [[ "$platform" == "linux" ]]; then $PYTHON_EXE -m pytest -v "$INSTALLED_RAY_DIRECTORY/$TEST_SCRIPT" # Run the UI test to make sure that the packaged UI works. - $PIP_CMD install -q aiohttp psutil requests setproctitle + $PIP_CMD install -q aiohttp google grpcio psutil requests setproctitle $PYTHON_EXE -m pytest -v "$INSTALLED_RAY_DIRECTORY/$UI_TEST_SCRIPT" # Check that the other wheels are present. @@ -106,7 +106,7 @@ elif [[ "$platform" == "macosx" ]]; then if (( $(echo "$PY_MM >= 3.0" | bc) )); then # Run the UI test to make sure that the packaged UI works. - $PIP_CMD install -q aiohttp psutil requests setproctitle + $PIP_CMD install -q aiohttp google grpcio psutil requests setproctitle $PYTHON_EXE -m pytest -v "$INSTALLED_RAY_DIRECTORY/$UI_TEST_SCRIPT" fi done diff --git a/python/ray/dashboard/dashboard.py b/python/ray/dashboard/dashboard.py index 4cdba6bae..e8c510d42 100644 --- a/python/ray/dashboard/dashboard.py +++ b/python/ray/dashboard/dashboard.py @@ -10,12 +10,14 @@ except ImportError: sys.exit(1) import argparse +import copy import datetime import json import logging import os import re import threading +import time import traceback import yaml @@ -25,7 +27,11 @@ from collections import defaultdict from operator import itemgetter from typing import Dict +import grpc +from google.protobuf.json_format import MessageToDict import ray +from ray.core.generated import node_manager_pb2 +from ray.core.generated import node_manager_pb2_grpc import ray.ray_constants as ray_constants import ray.utils @@ -64,6 +70,7 @@ class Dashboard(object): self.temp_dir = temp_dir self.node_stats = NodeStats(redis_address, redis_password) + self.raylet_stats = RayletStats(redis_address, redis_password) # Setting the environment variable RAY_DASHBOARD_DEV=1 disables some # security checks in the 
class RayletStats(threading.Thread):
    """Background thread that polls every raylet for its node statistics.

    Once per second the thread calls ``GetNodeStats`` on the node manager of
    each node reported by ``ray.nodes()`` and caches the reply (converted to
    a plain dict) keyed by the node's ``NodeManagerAddress``. Every
    ``_NODE_REFRESH_PERIOD`` polls the node list is re-read so that nodes
    joining the cluster are picked up.

    Args:
        redis_address: Accepted for interface parity with the caller;
            not used by this class.
        redis_password: Accepted for interface parity with the caller;
            not used by this class.
    """

    # Number of polling iterations between two refreshes of the node list.
    _NODE_REFRESH_PERIOD = 10

    def __init__(self, redis_address, redis_password=None):
        # Guards self.nodes / self.stubs / self._channels.
        self.nodes_lock = threading.Lock()
        self.nodes = []
        self.stubs = []
        # Open gRPC channels keyed by "address:port", kept so they can be
        # reused across refreshes and closed once their node disappears.
        self._channels = {}

        # Guards self._raylet_stats.
        self._raylet_stats_lock = threading.Lock()
        self._raylet_stats = {}

        self.update_nodes()

        super().__init__()

    @staticmethod
    def _should_refresh_nodes(counter):
        """Return True on every ``_NODE_REFRESH_PERIOD``-th polling pass."""
        return counter % RayletStats._NODE_REFRESH_PERIOD == 0

    def update_nodes(self):
        """Re-read the cluster's node list and rebuild the gRPC stubs.

        Channels to addresses that are still present are reused; channels
        whose address is no longer listed are closed so they do not leak.
        """
        with self.nodes_lock:
            nodes = ray.nodes()
            channels = {}
            stubs = []

            for node in nodes:
                address = "{}:{}".format(node["NodeManagerAddress"],
                                         node["NodeManagerPort"])
                channel = channels.get(address)
                if channel is None:
                    channel = self._channels.get(address)
                if channel is None:
                    channel = grpc.insecure_channel(address)
                channels[address] = channel
                stubs.append(
                    node_manager_pb2_grpc.NodeManagerServiceStub(channel))

            stale = [
                channel for address, channel in self._channels.items()
                if address not in channels
            ]
            self.nodes = nodes
            self.stubs = stubs
            self._channels = channels

        for channel in stale:
            channel.close()

    def get_raylet_stats(self) -> Dict:
        """Return a deep copy of the latest stats, keyed by node address."""
        with self._raylet_stats_lock:
            return copy.deepcopy(self._raylet_stats)

    def run(self):
        """Poll all known raylets once per second, forever."""
        counter = 0
        while True:
            time.sleep(1.0)

            # Snapshot the targets so the RPCs below run without any lock.
            with self.nodes_lock:
                targets = list(zip(self.nodes, self.stubs))

            fresh = {}
            for node, stub in targets:
                try:
                    reply = stub.GetNodeStats(
                        node_manager_pb2.NodeStatsRequest())
                except grpc.RpcError:
                    # One unreachable raylet must not kill the whole
                    # polling thread; keep its previous stats and move on.
                    logging.getLogger(__name__).warning(
                        "Failed to fetch raylet stats from %s",
                        node["NodeManagerAddress"])
                    continue
                fresh[node["NodeManagerAddress"]] = MessageToDict(reply)

            # Hold the stats lock only for the in-memory update so readers
            # are never blocked behind network calls.
            with self._raylet_stats_lock:
                self._raylet_stats.update(fresh)

            counter += 1
            # From time to time, check if new nodes have joined the cluster
            # and update self.nodes
            if self._should_refresh_nodes(counter):
                self.update_nodes()