mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 23:57:45 +08:00
471 lines
17 KiB
Python
471 lines
17 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
try:
|
|
import aiohttp.web
|
|
except ImportError:
|
|
print("The dashboard requires aiohttp to run.")
|
|
import sys
|
|
sys.exit(1)
|
|
|
|
import argparse
|
|
import copy
|
|
import datetime
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import threading
|
|
import time
|
|
import traceback
|
|
import yaml
|
|
|
|
from collections import defaultdict
|
|
from operator import itemgetter
|
|
from typing import Dict
|
|
|
|
import grpc
|
|
from google.protobuf.json_format import MessageToDict
|
|
import ray
|
|
from ray.core.generated import node_manager_pb2
|
|
from ray.core.generated import node_manager_pb2_grpc
|
|
import ray.ray_constants as ray_constants
|
|
import ray.utils
|
|
|
|
# Logger for this module. It should be configured at the entry point
|
|
# into the program using Ray. Ray provides a default configuration at
|
|
# entry/init points.
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def to_unix_time(dt):
|
|
return (dt - datetime.datetime(1970, 1, 1)).total_seconds()
|
|
|
|
|
|
def round_resource_value(quantity):
|
|
if quantity.is_integer():
|
|
return int(quantity)
|
|
else:
|
|
return round(quantity, 2)
|
|
|
|
|
|
def format_resource(resource_name, quantity):
|
|
if resource_name == "object_store_memory" or resource_name == "memory":
|
|
# Convert to 50MiB chunks and then to GiB
|
|
quantity = quantity * (50 * 1024 * 1024) / (1024 * 1024 * 1024)
|
|
return "{} GiB".format(round_resource_value(quantity))
|
|
return "{}".format(round_resource_value(quantity))
|
|
|
|
|
|
class Dashboard(object):
|
|
"""A dashboard process for monitoring Ray nodes.
|
|
|
|
This dashboard is made up of a REST API which collates data published by
|
|
Reporter processes on nodes into a json structure, and a webserver
|
|
which polls said API for display purposes.
|
|
|
|
Attributes:
|
|
redis_client: A client used to communicate with the Redis server.
|
|
"""
|
|
|
|
def __init__(self,
|
|
host,
|
|
port,
|
|
redis_address,
|
|
temp_dir,
|
|
redis_password=None):
|
|
"""Initialize the dashboard object."""
|
|
self.host = host
|
|
self.port = port
|
|
self.redis_client = ray.services.create_redis_client(
|
|
redis_address, password=redis_password)
|
|
self.temp_dir = temp_dir
|
|
|
|
self.node_stats = NodeStats(redis_address, redis_password)
|
|
self.raylet_stats = RayletStats(redis_address, redis_password)
|
|
|
|
# Setting the environment variable RAY_DASHBOARD_DEV=1 disables some
|
|
# security checks in the dashboard server to ease development while
|
|
# using the React dev server. Specifically, when this option is set, we
|
|
# allow cross-origin requests to be made.
|
|
self.is_dev = os.environ.get("RAY_DASHBOARD_DEV") == "1"
|
|
|
|
self.app = aiohttp.web.Application()
|
|
self.setup_routes()
|
|
|
|
def setup_routes(self):
|
|
def forbidden() -> aiohttp.web.Response:
|
|
return aiohttp.web.Response(status=403, text="403 Forbidden")
|
|
|
|
def get_forbidden(_) -> aiohttp.web.Response:
|
|
return forbidden()
|
|
|
|
async def get_index(req) -> aiohttp.web.Response:
|
|
return aiohttp.web.FileResponse(
|
|
os.path.join(
|
|
os.path.dirname(os.path.abspath(__file__)),
|
|
"client/build/index.html"))
|
|
|
|
async def json_response(result=None, error=None,
|
|
ts=None) -> aiohttp.web.Response:
|
|
if ts is None:
|
|
ts = datetime.datetime.utcnow()
|
|
|
|
headers = None
|
|
if self.is_dev:
|
|
headers = {"Access-Control-Allow-Origin": "*"}
|
|
|
|
return aiohttp.web.json_response(
|
|
{
|
|
"result": result,
|
|
"timestamp": to_unix_time(ts),
|
|
"error": error,
|
|
},
|
|
headers=headers)
|
|
|
|
async def ray_config(_) -> aiohttp.web.Response:
|
|
try:
|
|
config_path = os.path.expanduser("~/ray_bootstrap_config.yaml")
|
|
with open(config_path) as f:
|
|
cfg = yaml.safe_load(f)
|
|
except Exception:
|
|
return await json_response(error="No config")
|
|
|
|
D = {
|
|
"min_workers": cfg["min_workers"],
|
|
"max_workers": cfg["max_workers"],
|
|
"initial_workers": cfg["initial_workers"],
|
|
"autoscaling_mode": cfg["autoscaling_mode"],
|
|
"idle_timeout_minutes": cfg["idle_timeout_minutes"],
|
|
}
|
|
|
|
try:
|
|
D["head_type"] = cfg["head_node"]["InstanceType"]
|
|
except KeyError:
|
|
D["head_type"] = "unknown"
|
|
|
|
try:
|
|
D["worker_type"] = cfg["worker_nodes"]["InstanceType"]
|
|
except KeyError:
|
|
D["worker_type"] = "unknown"
|
|
|
|
return await json_response(result=D)
|
|
|
|
async def node_info(req) -> aiohttp.web.Response:
|
|
now = datetime.datetime.utcnow()
|
|
D = self.node_stats.get_node_stats()
|
|
return await json_response(result=D, ts=now)
|
|
|
|
async def raylet_info(req) -> aiohttp.web.Response:
|
|
D = self.raylet_stats.get_raylet_stats()
|
|
for address, data in D.items():
|
|
available_resources = data["availableResources"]
|
|
total_resources = data["totalResources"]
|
|
extra_info = []
|
|
for resource_name in sorted(available_resources.keys()):
|
|
total = total_resources[resource_name]
|
|
occupied = total - available_resources[resource_name]
|
|
total = format_resource(resource_name, total)
|
|
occupied = format_resource(resource_name, occupied)
|
|
extra_info.append("{}: {} / {}".format(
|
|
resource_name, occupied, total))
|
|
data["extraInfo"] = ", ".join(extra_info)
|
|
return await json_response(result=D)
|
|
|
|
async def logs(req) -> aiohttp.web.Response:
|
|
hostname = req.query.get("hostname")
|
|
pid = req.query.get("pid")
|
|
result = self.node_stats.get_logs(hostname, pid)
|
|
return await json_response(result=result)
|
|
|
|
async def errors(req) -> aiohttp.web.Response:
|
|
hostname = req.query.get("hostname")
|
|
pid = req.query.get("pid")
|
|
result = self.node_stats.get_errors(hostname, pid)
|
|
return await json_response(result=result)
|
|
|
|
self.app.router.add_get("/", get_index)
|
|
|
|
static_dir = os.path.join(
|
|
os.path.dirname(os.path.abspath(__file__)), "client/build/static")
|
|
if not os.path.isdir(static_dir):
|
|
raise ValueError(
|
|
"Dashboard static asset directory not found at '{}'. If "
|
|
"installing from source, please follow the additional steps "
|
|
"required to build the dashboard: "
|
|
"cd python/ray/dashboard/client && npm ci && "
|
|
"npm run build".format(static_dir))
|
|
self.app.router.add_static("/static", static_dir)
|
|
|
|
self.app.router.add_get("/api/ray_config", ray_config)
|
|
self.app.router.add_get("/api/node_info", node_info)
|
|
self.app.router.add_get("/api/raylet_info", raylet_info)
|
|
self.app.router.add_get("/api/logs", logs)
|
|
self.app.router.add_get("/api/errors", errors)
|
|
|
|
self.app.router.add_get("/{_}", get_forbidden)
|
|
|
|
def log_dashboard_url(self):
|
|
url = ray.services.get_webui_url_from_redis(self.redis_client)
|
|
with open(os.path.join(self.temp_dir, "dashboard_url"), "w") as f:
|
|
f.write(url)
|
|
logger.info("Dashboard running on {}".format(url))
|
|
|
|
def run(self):
|
|
self.log_dashboard_url()
|
|
self.node_stats.start()
|
|
self.raylet_stats.start()
|
|
aiohttp.web.run_app(self.app, host=self.host, port=self.port)
|
|
|
|
|
|
class NodeStats(threading.Thread):
|
|
def __init__(self, redis_address, redis_password=None):
|
|
self.redis_key = "{}.*".format(ray.gcs_utils.REPORTER_CHANNEL)
|
|
self.redis_client = ray.services.create_redis_client(
|
|
redis_address, password=redis_password)
|
|
|
|
self._node_stats = {}
|
|
self._node_stats_lock = threading.Lock()
|
|
|
|
# Mapping from IP address to PID to list of log lines
|
|
self._logs = defaultdict(lambda: defaultdict(list))
|
|
|
|
# Mapping from IP address to PID to list of error messages
|
|
self._errors = defaultdict(lambda: defaultdict(list))
|
|
|
|
ray.init(redis_address=redis_address, redis_password=redis_password)
|
|
|
|
super().__init__()
|
|
|
|
def calculate_log_counts(self):
|
|
return {
|
|
ip: {
|
|
pid: len(logs_for_pid)
|
|
for pid, logs_for_pid in logs_for_ip.items()
|
|
}
|
|
for ip, logs_for_ip in self._logs.items()
|
|
}
|
|
|
|
def calculate_error_counts(self):
|
|
return {
|
|
ip: {
|
|
pid: len(errors_for_pid)
|
|
for pid, errors_for_pid in errors_for_ip.items()
|
|
}
|
|
for ip, errors_for_ip in self._errors.items()
|
|
}
|
|
|
|
def purge_outdated_stats(self):
|
|
def current(then, now):
|
|
if (now - then) > 5:
|
|
return False
|
|
|
|
return True
|
|
|
|
now = to_unix_time(datetime.datetime.utcnow())
|
|
self._node_stats = {
|
|
k: v
|
|
for k, v in self._node_stats.items() if current(v["now"], now)
|
|
}
|
|
|
|
def get_node_stats(self) -> Dict:
|
|
with self._node_stats_lock:
|
|
self.purge_outdated_stats()
|
|
node_stats = sorted(
|
|
(v for v in self._node_stats.values()),
|
|
key=itemgetter("boot_time"))
|
|
return {
|
|
"clients": node_stats,
|
|
"log_counts": self.calculate_log_counts(),
|
|
"error_counts": self.calculate_error_counts(),
|
|
}
|
|
|
|
def get_logs(self, hostname, pid):
|
|
ip = self._node_stats.get(hostname, {"ip": None})["ip"]
|
|
logs = self._logs.get(ip, {})
|
|
if pid:
|
|
logs = {pid: logs.get(pid, [])}
|
|
return logs
|
|
|
|
def get_errors(self, hostname, pid):
|
|
ip = self._node_stats.get(hostname, {"ip": None})["ip"]
|
|
errors = self._errors.get(ip, {})
|
|
if pid:
|
|
errors = {pid: errors.get(pid, [])}
|
|
return errors
|
|
|
|
def run(self):
|
|
p = self.redis_client.pubsub(ignore_subscribe_messages=True)
|
|
|
|
p.psubscribe(self.redis_key)
|
|
logger.info("NodeStats: subscribed to {}".format(self.redis_key))
|
|
|
|
log_channel = ray.gcs_utils.LOG_FILE_CHANNEL
|
|
p.subscribe(log_channel)
|
|
logger.info("NodeStats: subscribed to {}".format(log_channel))
|
|
|
|
error_channel = ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB")
|
|
p.subscribe(error_channel)
|
|
logger.info("NodeStats: subscribed to {}".format(error_channel))
|
|
|
|
for x in p.listen():
|
|
try:
|
|
with self._node_stats_lock:
|
|
channel = ray.utils.decode(x["channel"])
|
|
data = x["data"]
|
|
if channel == log_channel:
|
|
data = json.loads(ray.utils.decode(data))
|
|
ip = data["ip"]
|
|
pid = str(data["pid"])
|
|
self._logs[ip][pid].extend(data["lines"])
|
|
elif channel == str(error_channel):
|
|
gcs_entry = ray.gcs_utils.GcsEntry.FromString(data)
|
|
error_data = ray.gcs_utils.ErrorTableData.FromString(
|
|
gcs_entry.entries[0])
|
|
message = error_data.error_message
|
|
message = re.sub(r"\x1b\[\d+m", "", message)
|
|
match = re.search(r"\(pid=(\d+), ip=(.*?)\)", message)
|
|
if match:
|
|
pid = match.group(1)
|
|
ip = match.group(2)
|
|
self._errors[ip][pid].append({
|
|
"message": message,
|
|
"timestamp": error_data.timestamp,
|
|
"type": error_data.type
|
|
})
|
|
else:
|
|
data = json.loads(ray.utils.decode(data))
|
|
self._node_stats[data["hostname"]] = data
|
|
except Exception:
|
|
logger.exception(traceback.format_exc())
|
|
continue
|
|
|
|
|
|
class RayletStats(threading.Thread):
|
|
def __init__(self, redis_address, redis_password=None):
|
|
self.nodes_lock = threading.Lock()
|
|
self.nodes = []
|
|
self.stubs = {}
|
|
|
|
self._raylet_stats_lock = threading.Lock()
|
|
self._raylet_stats = {}
|
|
|
|
self.update_nodes()
|
|
|
|
super().__init__()
|
|
|
|
def update_nodes(self):
|
|
with self.nodes_lock:
|
|
self.nodes = ray.nodes()
|
|
node_ids = [node["NodeID"] for node in self.nodes]
|
|
|
|
# First remove node connections of disconnected nodes.
|
|
for node_id in self.stubs.keys():
|
|
if node_id not in node_ids:
|
|
stub = self.stubs.pop(node_id)
|
|
stub.close()
|
|
|
|
# Now add node connections of new nodes.
|
|
for node in self.nodes:
|
|
node_id = node["NodeID"]
|
|
if node_id not in self.stubs:
|
|
channel = grpc.insecure_channel("{}:{}".format(
|
|
node["NodeManagerAddress"], node["NodeManagerPort"]))
|
|
stub = node_manager_pb2_grpc.NodeManagerServiceStub(
|
|
channel)
|
|
self.stubs[node_id] = stub
|
|
|
|
def get_raylet_stats(self) -> Dict:
|
|
with self._raylet_stats_lock:
|
|
return copy.deepcopy(self._raylet_stats)
|
|
|
|
def run(self):
|
|
counter = 0
|
|
while True:
|
|
time.sleep(1.0)
|
|
replies = {}
|
|
for node in self.nodes:
|
|
node_id = node["NodeID"]
|
|
stub = self.stubs[node_id]
|
|
reply = stub.GetNodeStats(
|
|
node_manager_pb2.NodeStatsRequest(), timeout=2)
|
|
replies[node["NodeManagerAddress"]] = reply
|
|
with self._raylet_stats_lock:
|
|
for address, reply in replies.items():
|
|
self._raylet_stats[address] = MessageToDict(reply)
|
|
counter += 1
|
|
# From time to time, check if new nodes have joined the cluster
|
|
# and update self.nodes
|
|
if counter % 10:
|
|
self.update_nodes()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser(
|
|
description=("Parse Redis server for the "
|
|
"dashboard to connect to."))
|
|
parser.add_argument(
|
|
"--host",
|
|
required=True,
|
|
type=str,
|
|
choices=["127.0.0.1", "0.0.0.0"],
|
|
help="The host to use for the HTTP server.")
|
|
parser.add_argument(
|
|
"--port",
|
|
required=True,
|
|
type=int,
|
|
help="The port to use for the HTTP server.")
|
|
parser.add_argument(
|
|
"--redis-address",
|
|
required=True,
|
|
type=str,
|
|
help="The address to use for Redis.")
|
|
parser.add_argument(
|
|
"--redis-password",
|
|
required=False,
|
|
type=str,
|
|
default=None,
|
|
help="the password to use for Redis")
|
|
parser.add_argument(
|
|
"--logging-level",
|
|
required=False,
|
|
type=str,
|
|
default=ray_constants.LOGGER_LEVEL,
|
|
choices=ray_constants.LOGGER_LEVEL_CHOICES,
|
|
help=ray_constants.LOGGER_LEVEL_HELP)
|
|
parser.add_argument(
|
|
"--logging-format",
|
|
required=False,
|
|
type=str,
|
|
default=ray_constants.LOGGER_FORMAT,
|
|
help=ray_constants.LOGGER_FORMAT_HELP)
|
|
parser.add_argument(
|
|
"--temp-dir",
|
|
required=False,
|
|
type=str,
|
|
default=None,
|
|
help="Specify the path of the temporary directory use by Ray process.")
|
|
args = parser.parse_args()
|
|
ray.utils.setup_logger(args.logging_level, args.logging_format)
|
|
|
|
try:
|
|
dashboard = Dashboard(
|
|
args.host,
|
|
args.port,
|
|
args.redis_address,
|
|
args.temp_dir,
|
|
redis_password=args.redis_password,
|
|
)
|
|
dashboard.run()
|
|
except Exception as e:
|
|
# Something went wrong, so push an error to all drivers.
|
|
redis_client = ray.services.create_redis_client(
|
|
args.redis_address, password=args.redis_password)
|
|
traceback_str = ray.utils.format_error_message(traceback.format_exc())
|
|
message = ("The dashboard on node {} failed with the following "
|
|
"error:\n{}".format(os.uname()[1], traceback_str))
|
|
ray.utils.push_error_to_driver_through_redis(
|
|
redis_client, ray_constants.DASHBOARD_DIED_ERROR, message)
|
|
raise e
|