[Dashboard] New dashboard skeleton (#9099)

This commit is contained in:
fyrestone
2020-07-27 11:34:47 +08:00
committed by GitHub
parent 44ccca1acb
commit 4d08ddbf24
15 changed files with 1406 additions and 1 deletions
View File
+229
View File
@@ -0,0 +1,229 @@
import argparse
import asyncio
import logging
import logging.handlers
import os
import sys
import traceback
import aiohttp
import aioredis
from grpc.experimental import aio as aiogrpc
import ray
import ray.new_dashboard.consts as dashboard_consts
import ray.new_dashboard.utils as dashboard_utils
import ray.ray_constants as ray_constants
import ray.services
import ray.utils
import psutil
logger = logging.getLogger(__name__)
aiogrpc.init_grpc_aio()
class DashboardAgent(object):
def __init__(self,
redis_address,
redis_password=None,
temp_dir=None,
log_dir=None,
node_manager_port=None,
object_store_name=None,
raylet_name=None):
"""Initialize the DashboardAgent object."""
self._agent_cls_list = dashboard_utils.get_all_modules(
dashboard_utils.DashboardAgentModule)
ip, port = redis_address.split(":")
# Public attributes are accessible for all agent modules.
self.redis_address = (ip, int(port))
self.redis_password = redis_password
self.temp_dir = temp_dir
self.log_dir = log_dir
self.node_manager_port = node_manager_port
self.object_store_name = object_store_name
self.raylet_name = raylet_name
self.ip = ray.services.get_node_ip_address()
self.server = aiogrpc.server(options=(("grpc.so_reuseport", 0), ))
listen_address = "[::]:0"
logger.info("Dashboard agent listen at: %s", listen_address)
self.port = self.server.add_insecure_port(listen_address)
self.aioredis_client = None
self.aiogrpc_raylet_channel = aiogrpc.insecure_channel("{}:{}".format(
self.ip, self.node_manager_port))
self.http_session = aiohttp.ClientSession(
loop=asyncio.get_event_loop())
def _load_modules(self):
"""Load dashboard agent modules."""
modules = []
for cls in self._agent_cls_list:
logger.info("Load %s: %s",
dashboard_utils.DashboardAgentModule.__name__, cls)
c = cls(self)
modules.append(c)
logger.info("Load {} modules.".format(len(modules)))
return modules
async def run(self):
# Create an aioredis client for all modules.
self.aioredis_client = await aioredis.create_redis_pool(
address=self.redis_address, password=self.redis_password)
# Start a grpc asyncio server.
await self.server.start()
# Write the dashboard agent port to redis.
await self.aioredis_client.set(
"{}{}".format(dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX,
self.ip), self.port)
async def _check_parent():
"""Check if raylet is dead."""
curr_proc = psutil.Process()
while True:
parent = curr_proc.parent()
if parent is None or parent.pid == 1:
logger.error("raylet is dead, agent will die because "
"it fate-shares with raylet.")
sys.exit(0)
await asyncio.sleep(
dashboard_consts.
DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS)
modules = self._load_modules()
await asyncio.gather(_check_parent(),
*(m.run(self.server) for m in modules))
await self.server.wait_for_termination()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Dashboard agent.")
parser.add_argument(
"--redis-address",
required=True,
type=str,
help="The address to use for Redis.")
parser.add_argument(
"--node-manager-port",
required=True,
type=int,
help="The port to use for starting the node manager")
parser.add_argument(
"--object-store-name",
required=True,
type=str,
default=None,
help="The socket name of the plasma store")
parser.add_argument(
"--raylet-name",
required=True,
type=str,
default=None,
help="The socket path of the raylet process")
parser.add_argument(
"--redis-password",
required=False,
type=str,
default=None,
help="The password to use for Redis")
parser.add_argument(
"--logging-level",
required=False,
type=lambda s: logging.getLevelName(s.upper()),
default=ray_constants.LOGGER_LEVEL,
choices=ray_constants.LOGGER_LEVEL_CHOICES,
help=ray_constants.LOGGER_LEVEL_HELP)
parser.add_argument(
"--logging-format",
required=False,
type=str,
default=ray_constants.LOGGER_FORMAT,
help=ray_constants.LOGGER_FORMAT_HELP)
parser.add_argument(
"--logging-filename",
required=False,
type=str,
default=dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME,
help="Specify the name of log file, "
"log to stdout if set empty, default is \"{}\".".format(
dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME))
parser.add_argument(
"--logging-rotate-bytes",
required=False,
type=int,
default=dashboard_consts.LOGGING_ROTATE_BYTES,
help="Specify the max bytes for rotating "
"log file, default is {} bytes.".format(
dashboard_consts.LOGGING_ROTATE_BYTES))
parser.add_argument(
"--logging-rotate-backup-count",
required=False,
type=int,
default=dashboard_consts.LOGGING_ROTATE_BACKUP_COUNT,
help="Specify the backup count of rotated log file, default is {}.".
format(dashboard_consts.LOGGING_ROTATE_BACKUP_COUNT))
parser.add_argument(
"--log-dir",
required=False,
type=str,
default=None,
help="Specify the path of log directory.")
parser.add_argument(
"--temp-dir",
required=False,
type=str,
default=None,
help="Specify the path of the temporary directory use by Ray process.")
args = parser.parse_args()
try:
if args.temp_dir:
temp_dir = "/" + args.temp_dir.strip("/")
else:
temp_dir = "/tmp/ray"
os.makedirs(temp_dir, exist_ok=True)
if args.log_dir:
log_dir = args.log_dir
else:
log_dir = os.path.join(temp_dir, "session_latest/logs")
os.makedirs(log_dir, exist_ok=True)
if args.logging_filename:
logging_handlers = [
logging.handlers.RotatingFileHandler(
os.path.join(log_dir, args.logging_filename),
maxBytes=args.logging_rotate_bytes,
backupCount=args.logging_rotate_backup_count)
]
else:
logging_handlers = None
logging.basicConfig(
level=args.logging_level,
format=args.logging_format,
handlers=logging_handlers)
agent = DashboardAgent(
args.redis_address,
redis_password=args.redis_password,
temp_dir=temp_dir,
log_dir=log_dir,
node_manager_port=args.node_manager_port,
object_store_name=args.object_store_name,
raylet_name=args.raylet_name)
loop = asyncio.get_event_loop()
loop.create_task(agent.run())
loop.run_forever()
except Exception as e:
# Something went wrong, so push an error to all drivers.
redis_client = ray.services.create_redis_client(
args.redis_address, password=args.redis_password)
traceback_str = ray.utils.format_error_message(traceback.format_exc())
message = ("The agent on node {} failed with the following "
"error:\n{}".format(os.uname()[1], traceback_str))
ray.utils.push_error_to_driver_through_redis(
redis_client, ray_constants.DASHBOARD_AGENT_DIED_ERROR, message)
raise e
+16
View File
@@ -0,0 +1,16 @@
DASHBOARD_AGENT_PORT_PREFIX = "DASHBOARD_AGENT_PORT_PREFIX:"
DASHBOARD_AGENT_LOG_FILENAME = "dashboard_agent.log"
DASHBOARD_AGENT_CHECK_PARENT_INTERVAL_SECONDS = 2
MAX_COUNT_OF_GCS_RPC_ERROR = 10
UPDATE_NODES_INTERVAL_SECONDS = 5
CONNECT_GCS_INTERVAL_SECONDS = 2
PURGE_DATA_INTERVAL_SECONDS = 60 * 10
REDIS_KEY_DASHBOARD = "dashboard"
REDIS_KEY_GCS_SERVER_ADDRESS = "GcsServerAddress"
REPORT_METRICS_TIMEOUT_SECONDS = 2
REPORT_METRICS_INTERVAL_SECONDS = 10
# Named signals
SIGNAL_NODE_INFO_FETCHED = "node_info_fetched"
# Default param for RotatingFileHandler
LOGGING_ROTATE_BYTES = 100 * 1000 # maxBytes
LOGGING_ROTATE_BACKUP_COUNT = 5 # backupCount
+240
View File
@@ -0,0 +1,240 @@
try:
import aiohttp.web
except ImportError:
print("The dashboard requires aiohttp to run.")
import sys
sys.exit(1)
import argparse
import asyncio
import errno
import logging
import logging.handlers
import os
import traceback
import uuid
import aioredis
import ray
import ray.new_dashboard.consts as dashboard_consts
import ray.new_dashboard.head as dashboard_head
import ray.new_dashboard.utils as dashboard_utils
import ray.ray_constants as ray_constants
import ray.services
import ray.utils
# Logger for this module. It should be configured at the entry point
# into the program using Ray. Ray provides a default configuration at
# entry/init points.
logger = logging.getLogger(__name__)
routes = dashboard_utils.ClassMethodRouteTable
def setup_static_dir(app):
build_dir = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "client/build")
module_name = os.path.basename(os.path.dirname(__file__))
if not os.path.isdir(build_dir):
raise OSError(
errno.ENOENT, "Dashboard build directory not found. If installing "
"from source, please follow the additional steps "
"required to build the dashboard"
"(cd python/ray/{}/client "
"&& npm install "
"&& npm ci "
"&& npm run build)".format(module_name), build_dir)
static_dir = os.path.join(build_dir, "static")
app.router.add_static("/static", static_dir, follow_symlinks=True)
return build_dir
class Dashboard:
"""A dashboard process for monitoring Ray nodes.
This dashboard is made up of a REST API which collates data published by
Reporter processes on nodes into a json structure, and a webserver
which polls said API for display purposes.
Args:
host(str): Host address of dashboard aiohttp server.
port(int): Port number of dashboard aiohttp server.
redis_address(str): GCS address of a Ray cluster
temp_dir (str): The temporary directory used for log files and
information for this Ray session.
redis_password(str): Redis password to access GCS
"""
def __init__(self,
host,
port,
redis_address,
temp_dir,
redis_password=None):
self.host = host
self.port = port
self.temp_dir = temp_dir
self.dashboard_id = str(uuid.uuid4())
self.dashboard_head = dashboard_head.DashboardHead(
redis_address=redis_address, redis_password=redis_password)
self.app = aiohttp.web.Application()
self.app.add_routes(routes=routes.routes())
# Setup Dashboard Routes
build_dir = setup_static_dir(self.app)
logger.info("Setup static dir for dashboard: %s", build_dir)
dashboard_utils.ClassMethodRouteTable.bind(self)
@routes.get("/")
async def get_index(self, req) -> aiohttp.web.FileResponse:
return aiohttp.web.FileResponse(
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"client/build/index.html"))
@routes.get("/favicon.ico")
async def get_favicon(self, req) -> aiohttp.web.FileResponse:
return aiohttp.web.FileResponse(
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"client/build/favicon.ico"))
async def run(self):
coroutines = [
self.dashboard_head.run(),
aiohttp.web._run_app(self.app, host=self.host, port=self.port)
]
ip = ray.services.get_node_ip_address()
aioredis_client = await aioredis.create_redis_pool(
address=self.dashboard_head.redis_address,
password=self.dashboard_head.redis_password)
await aioredis_client.set(dashboard_consts.REDIS_KEY_DASHBOARD,
ip + ":" + str(self.port))
await asyncio.gather(*coroutines)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=("Parse Redis server for the "
"dashboard to connect to."))
parser.add_argument(
"--host",
required=True,
type=str,
help="The host to use for the HTTP server.")
parser.add_argument(
"--port",
required=True,
type=int,
help="The port to use for the HTTP server.")
parser.add_argument(
"--redis-address",
required=True,
type=str,
help="The address to use for Redis.")
parser.add_argument(
"--redis-password",
required=False,
type=str,
default=None,
help="The password to use for Redis")
parser.add_argument(
"--logging-level",
required=False,
type=lambda s: logging.getLevelName(s.upper()),
default=ray_constants.LOGGER_LEVEL,
choices=ray_constants.LOGGER_LEVEL_CHOICES,
help=ray_constants.LOGGER_LEVEL_HELP)
parser.add_argument(
"--logging-format",
required=False,
type=str,
default=ray_constants.LOGGER_FORMAT,
help=ray_constants.LOGGER_FORMAT_HELP)
parser.add_argument(
"--logging-filename",
required=False,
type=str,
default="",
help="Specify the name of log file, "
"log to stdout if set empty, default is \"\"")
parser.add_argument(
"--logging-rotate-bytes",
required=False,
type=int,
default=dashboard_consts.LOGGING_ROTATE_BYTES,
help="Specify the max bytes for rotating "
"log file, default is {} bytes.".format(
dashboard_consts.LOGGING_ROTATE_BYTES))
parser.add_argument(
"--logging-rotate-backup-count",
required=False,
type=int,
default=dashboard_consts.LOGGING_ROTATE_BACKUP_COUNT,
help="Specify the backup count of rotated log file, default is {}.".
format(dashboard_consts.LOGGING_ROTATE_BACKUP_COUNT))
parser.add_argument(
"--log-dir",
required=False,
type=str,
default=None,
help="Specify the path of log directory.")
parser.add_argument(
"--temp-dir",
required=False,
type=str,
default=None,
help="Specify the path of the temporary directory use by Ray process.")
args = parser.parse_args()
try:
if args.temp_dir:
temp_dir = "/" + args.temp_dir.strip("/")
else:
temp_dir = "/tmp/ray"
os.makedirs(temp_dir, exist_ok=True)
if args.log_dir:
log_dir = args.log_dir
else:
log_dir = os.path.join(temp_dir, "session_latest/logs")
os.makedirs(log_dir, exist_ok=True)
if args.logging_filename:
logging_handlers = [
logging.handlers.RotatingFileHandler(
os.path.join(log_dir, args.logging_filename),
maxBytes=args.logging_rotate_bytes,
backupCount=args.logging_rotate_backup_count)
]
else:
logging_handlers = None
logging.basicConfig(
level=args.logging_level,
format=args.logging_format,
handlers=logging_handlers)
dashboard = Dashboard(
args.host,
args.port,
args.redis_address,
temp_dir,
redis_password=args.redis_password)
loop = asyncio.get_event_loop()
loop.run_until_complete(dashboard.run())
except Exception as e:
# Something went wrong, so push an error to all drivers.
redis_client = ray.services.create_redis_client(
args.redis_address, password=args.redis_password)
traceback_str = ray.utils.format_error_message(traceback.format_exc())
message = ("The dashboard on node {} failed with the following "
"error:\n{}".format(os.uname()[1], traceback_str))
ray.utils.push_error_to_driver_through_redis(
redis_client, ray_constants.DASHBOARD_DIED_ERROR, message)
if isinstance(e, OSError) and e.errno == errno.ENOENT:
logger.warning(message)
else:
raise e
+108
View File
@@ -0,0 +1,108 @@
import logging
import ray.new_dashboard.consts as dashboard_consts
from ray.new_dashboard.utils import Dict, Signal
logger = logging.getLogger(__name__)
class GlobalSignals:
node_info_fetched = Signal(dashboard_consts.SIGNAL_NODE_INFO_FETCHED)
class DataSource:
# {ip address(str): node stats(dict of GetNodeStatsReply
# in node_manager.proto)}
node_stats = Dict()
# {ip address(str): node physical stats(dict from reporter_agent.py)}
node_physical_stats = Dict()
# {actor id hex(str): actor table data(dict of ActorTableData
# in gcs.proto)}
actors = Dict()
# {ip address(str): dashboard agent grpc server port(int)}
agents = Dict()
# {ip address(str): gcs node info(dict of GcsNodeInfo in gcs.proto)}
nodes = Dict()
# {hostname(str): ip address(str)}
hostname_to_ip = Dict()
# {ip address(str): hostname(str)}
ip_to_hostname = Dict()
class DataOrganizer:
@staticmethod
async def purge():
# Purge data that is out of date.
# These data sources are maintained by DashboardHead,
# we do not needs to purge them:
# * agents
# * nodes
# * hostname_to_ip
# * ip_to_hostname
logger.info("Purge data.")
valid_keys = DataSource.ip_to_hostname.keys()
for key in DataSource.node_stats.keys() - valid_keys:
DataSource.node_stats.pop(key)
for key in DataSource.node_physical_stats.keys() - valid_keys:
DataSource.node_physical_stats.pop(key)
@classmethod
async def get_node_actors(cls, hostname):
ip = DataSource.hostname_to_ip[hostname]
node_stats = DataSource.node_stats.get(ip, {})
node_worker_id_set = set()
for worker_stats in node_stats.get("workersStats", []):
node_worker_id_set.add(worker_stats["workerId"])
node_actors = {}
for actor_id, actor_table_data in DataSource.actors.items():
if actor_table_data["workerId"] in node_worker_id_set:
node_actors[actor_id] = actor_table_data
return node_actors
@classmethod
async def get_node_info(cls, hostname):
ip = DataSource.hostname_to_ip[hostname]
node_physical_stats = DataSource.node_physical_stats.get(ip, {})
node_stats = DataSource.node_stats.get(ip, {})
# Merge coreWorkerStats (node stats) to workers (node physical stats)
workers_stats = node_stats.pop("workersStats", {})
pid_to_worker_stats = {}
pid_to_language = {}
pid_to_job_id = {}
for stats in workers_stats:
d = pid_to_worker_stats.setdefault(stats["pid"], {}).setdefault(
stats["workerId"], stats["coreWorkerStats"])
d["workerId"] = stats["workerId"]
pid_to_language.setdefault(stats["pid"],
stats.get("language", "PYTHON"))
pid_to_job_id.setdefault(stats["pid"],
stats["coreWorkerStats"]["jobId"])
for worker in node_physical_stats.get("workers", []):
worker_stats = pid_to_worker_stats.get(worker["pid"], {})
worker["coreWorkerStats"] = list(worker_stats.values())
worker["language"] = pid_to_language.get(worker["pid"], "")
worker["jobId"] = pid_to_job_id.get(worker["pid"], "ffff")
# Merge node stats to node physical stats
node_info = node_physical_stats
node_info["raylet"] = node_stats
node_info["actors"] = await cls.get_node_actors(hostname)
node_info["state"] = DataSource.nodes.get(ip, {}).get("state", "DEAD")
await GlobalSignals.node_info_fetched.send(node_info)
return node_info
@classmethod
async def get_all_node_summary(cls):
all_nodes_summary = []
for hostname in DataSource.hostname_to_ip.keys():
node_info = await cls.get_node_info(hostname)
node_info.pop("workers", None)
node_info["raylet"].pop("workersStats", None)
node_info["raylet"].pop("viewData", None)
all_nodes_summary.append(node_info)
return all_nodes_summary
+170
View File
@@ -0,0 +1,170 @@
import sys
import asyncio
import logging
import aiohttp
import aioredis
from grpc.experimental import aio as aiogrpc
import ray.services
import ray.new_dashboard.consts as dashboard_consts
import ray.new_dashboard.utils as dashboard_utils
from ray.core.generated import gcs_service_pb2
from ray.core.generated import gcs_service_pb2_grpc
from ray.new_dashboard.datacenter import DataSource, DataOrganizer
logger = logging.getLogger(__name__)
routes = dashboard_utils.ClassMethodRouteTable
aiogrpc.init_grpc_aio()
def gcs_node_info_to_dict(message):
return dashboard_utils.message_to_dict(
message, {"nodeId"}, including_default_value_fields=True)
class DashboardHead:
def __init__(self, redis_address, redis_password):
# Scan and import head modules for collecting http routes.
self._head_cls_list = dashboard_utils.get_all_modules(
dashboard_utils.DashboardHeadModule)
ip, port = redis_address.split(":")
# NodeInfoGcsService
self._gcs_node_info_stub = None
self._gcs_rpc_error_counter = 0
# Public attributes are accessible for all head modules.
self.redis_address = (ip, int(port))
self.redis_password = redis_password
self.aioredis_client = None
self.aiogrpc_gcs_channel = None
self.http_session = aiohttp.ClientSession(
loop=asyncio.get_event_loop())
self.ip = ray.services.get_node_ip_address()
async def _get_nodes(self):
"""Read the client table.
Returns:
A list of information about the nodes in the cluster.
"""
request = gcs_service_pb2.GetAllNodeInfoRequest()
reply = await self._gcs_node_info_stub.GetAllNodeInfo(
request, timeout=2)
if reply.status.code == 0:
results = []
node_id_set = set()
for node_info in reply.node_info_list:
if node_info.node_id in node_id_set:
continue
node_id_set.add(node_info.node_id)
node_info_dict = gcs_node_info_to_dict(node_info)
results.append(node_info_dict)
return results
else:
logger.error("Failed to GetAllNodeInfo: %s", reply.status.message)
async def _update_nodes(self):
while True:
try:
nodes = await self._get_nodes()
self._gcs_rpc_error_counter = 0
node_ips = [node["nodeManagerAddress"] for node in nodes]
node_hostnames = [
node["nodeManagerHostname"] for node in nodes
]
agents = dict(DataSource.agents)
for node in nodes:
node_ip = node["nodeManagerAddress"]
if node_ip not in agents:
key = "{}{}".format(
dashboard_consts.DASHBOARD_AGENT_PORT_PREFIX,
node_ip)
agent_port = await self.aioredis_client.get(key)
if agent_port:
agents[node_ip] = agent_port
for ip in agents.keys() - set(node_ips):
agents.pop(ip, None)
DataSource.agents.reset(agents)
DataSource.nodes.reset(dict(zip(node_ips, nodes)))
DataSource.hostname_to_ip.reset(
dict(zip(node_hostnames, node_ips)))
DataSource.ip_to_hostname.reset(
dict(zip(node_ips, node_hostnames)))
except aiogrpc.AioRpcError as ex:
logger.exception(ex)
self._gcs_rpc_error_counter += 1
if self._gcs_rpc_error_counter > \
dashboard_consts.MAX_COUNT_OF_GCS_RPC_ERROR:
logger.error(
"Dashboard suicide, the GCS RPC error count %s > %s",
self._gcs_rpc_error_counter,
dashboard_consts.MAX_COUNT_OF_GCS_RPC_ERROR)
sys.exit(-1)
except Exception as ex:
logger.exception(ex)
finally:
await asyncio.sleep(
dashboard_consts.UPDATE_NODES_INTERVAL_SECONDS)
def _load_modules(self):
"""Load dashboard head modules."""
modules = []
for cls in self._head_cls_list:
logger.info("Load %s: %s",
dashboard_utils.DashboardHeadModule.__name__, cls)
c = cls(self)
dashboard_utils.ClassMethodRouteTable.bind(c)
modules.append(c)
return modules
async def run(self):
# Create an aioredis client for all modules.
self.aioredis_client = await aioredis.create_redis_pool(
address=self.redis_address, password=self.redis_password)
# Waiting for GCS is ready.
while True:
try:
gcs_address = await self.aioredis_client.get(
dashboard_consts.REDIS_KEY_GCS_SERVER_ADDRESS)
if not gcs_address:
raise Exception("GCS address not found.")
logger.info("Connect to GCS at %s", gcs_address)
channel = aiogrpc.insecure_channel(gcs_address)
except Exception as ex:
logger.error("Connect to GCS failed: %s, retry...", ex)
await asyncio.sleep(
dashboard_consts.CONNECT_GCS_INTERVAL_SECONDS)
else:
self.aiogrpc_gcs_channel = channel
break
# Create a NodeInfoGcsServiceStub.
self._gcs_node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(
self.aiogrpc_gcs_channel)
async def _async_notify():
"""Notify signals from queue."""
while True:
co = await dashboard_utils.NotifyQueue.get()
try:
await co
except Exception as e:
logger.exception(e)
async def _purge_data():
"""Purge data in datacenter."""
while True:
await asyncio.sleep(
dashboard_consts.PURGE_DATA_INTERVAL_SECONDS)
try:
await DataOrganizer.purge()
except Exception as e:
logger.exception(e)
modules = self._load_modules()
# Freeze signal after all modules loaded.
dashboard_utils.SignalManager.freeze()
await asyncio.gather(self._update_nodes(), _async_notify(),
_purge_data(), *(m.run() for m in modules))
View File
@@ -0,0 +1,199 @@
import asyncio
import datetime
import json
import logging
import os
import socket
import subprocess
import sys
import aioredis
import ray
import ray.gcs_utils
import ray.new_dashboard.modules.reporter.reporter_consts as reporter_consts
import ray.new_dashboard.utils as dashboard_utils
import ray.services
import ray.utils
from ray.core.generated import reporter_pb2
from ray.core.generated import reporter_pb2_grpc
import psutil
logger = logging.getLogger(__name__)
def recursive_asdict(o):
if isinstance(o, tuple) and hasattr(o, "_asdict"):
return recursive_asdict(o._asdict())
if isinstance(o, (tuple, list)):
L = []
for k in o:
L.append(recursive_asdict(k))
return L
if isinstance(o, dict):
D = {k: recursive_asdict(v) for k, v in o.items()}
return D
return o
def jsonify_asdict(o):
return json.dumps(dashboard_utils.to_google_style(recursive_asdict(o)))
class ReporterAgent(dashboard_utils.DashboardAgentModule,
reporter_pb2_grpc.ReporterServiceServicer):
"""A monitor process for monitoring Ray nodes.
Attributes:
dashboard_agent: The DashboardAgent object contains global config
"""
def __init__(self, dashboard_agent):
"""Initialize the reporter object."""
super().__init__(dashboard_agent)
self._cpu_counts = (psutil.cpu_count(),
psutil.cpu_count(logical=False))
self._ip = ray.services.get_node_ip_address()
self._hostname = socket.gethostname()
self._workers = set()
self._network_stats_hist = [(0, (0.0, 0.0))] # time, (sent, recv)
async def GetProfilingStats(self, request, context):
pid = request.pid
duration = request.duration
profiling_file_path = os.path.join(ray.utils.get_ray_temp_dir(),
"{}_profiling.txt".format(pid))
process = subprocess.Popen(
"sudo $(which py-spy) record -o {} -p {} -d {} -f speedscope"
.format(profiling_file_path, pid, duration),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True)
stdout, stderr = process.communicate()
if process.returncode != 0:
profiling_stats = ""
else:
with open(profiling_file_path, "r") as f:
profiling_stats = f.read()
return reporter_pb2.GetProfilingStatsReply(
profiling_stats=profiling_stats, stdout=stdout, stderr=stderr)
@staticmethod
def _get_cpu_percent():
return psutil.cpu_percent()
@staticmethod
def _get_boot_time():
return psutil.boot_time()
@staticmethod
def _get_network_stats():
ifaces = [
v for k, v in psutil.net_io_counters(pernic=True).items()
if k[0] == "e"
]
sent = sum((iface.bytes_sent for iface in ifaces))
recv = sum((iface.bytes_recv for iface in ifaces))
return sent, recv
@staticmethod
def _get_mem_usage():
vm = psutil.virtual_memory()
return vm.total, vm.available, vm.percent
@staticmethod
def _get_disk_usage():
dirs = [
os.environ["USERPROFILE"] if sys.platform == "win32" else os.sep,
ray.utils.get_user_temp_dir(),
]
return {x: psutil.disk_usage(x) for x in dirs}
def _get_workers(self):
curr_proc = psutil.Process()
parent = curr_proc.parent()
if parent is None or parent.pid == 1:
return []
else:
workers = set(parent.children())
self._workers.intersection_update(workers)
self._workers.update(workers)
self._workers.discard(curr_proc)
return [
w.as_dict(attrs=[
"pid",
"create_time",
"cpu_percent",
"cpu_times",
"cmdline",
"memory_info",
]) for w in self._workers if w.status() != psutil.STATUS_ZOMBIE
]
@staticmethod
def _get_raylet_cmdline():
curr_proc = psutil.Process()
parent = curr_proc.parent()
if parent.pid == 1:
return ""
else:
return parent.cmdline()
def _get_load_avg(self):
if sys.platform == "win32":
cpu_percent = psutil.cpu_percent()
load = (cpu_percent, cpu_percent, cpu_percent)
else:
load = os.getloadavg()
per_cpu_load = tuple((round(x / self._cpu_counts[0], 2) for x in load))
return load, per_cpu_load
def _get_all_stats(self):
now = dashboard_utils.to_posix_time(datetime.datetime.utcnow())
network_stats = self._get_network_stats()
self._network_stats_hist.append((now, network_stats))
self._network_stats_hist = self._network_stats_hist[-7:]
then, prev_network_stats = self._network_stats_hist[0]
netstats = ((network_stats[0] - prev_network_stats[0]) / (now - then),
(network_stats[1] - prev_network_stats[1]) / (now - then))
return {
"now": now,
"hostname": self._hostname,
"ip": self._ip,
"cpu": self._get_cpu_percent(),
"cpus": self._cpu_counts,
"mem": self._get_mem_usage(),
"workers": self._get_workers(),
"bootTime": self._get_boot_time(),
"loadAvg": self._get_load_avg(),
"disk": self._get_disk_usage(),
"net": netstats,
"cmdline": self._get_raylet_cmdline(),
}
async def _perform_iteration(self):
"""Get any changes to the log files and push updates to Redis."""
aioredis_client = await aioredis.create_redis_pool(
address=self._dashboard_agent.redis_address,
password=self._dashboard_agent.redis_password)
while True:
try:
stats = self._get_all_stats()
await aioredis_client.publish(
"{}{}".format(reporter_consts.REPORTER_PREFIX,
self._hostname), jsonify_asdict(stats))
except Exception as ex:
logger.exception(ex)
await asyncio.sleep(
reporter_consts.REPORTER_UPDATE_INTERVAL_MS / 1000)
async def run(self, server):
reporter_pb2_grpc.add_ReporterServiceServicer_to_server(self, server)
await self._perform_iteration()
@@ -0,0 +1,6 @@
import ray.ray_constants as ray_constants
REPORTER_PREFIX = "RAY_REPORTER:"
# The reporter will report its statistics this often (milliseconds).
REPORTER_UPDATE_INTERVAL_MS = ray_constants.env_integer(
"REPORTER_UPDATE_INTERVAL_MS", 2500)
@@ -0,0 +1,94 @@
import json
import logging
import uuid
import aiohttp.web
from aioredis.pubsub import Receiver
from grpc.experimental import aio as aiogrpc
import ray
import ray.gcs_utils
import ray.new_dashboard.modules.reporter.reporter_consts as reporter_consts
import ray.new_dashboard.utils as dashboard_utils
import ray.services
import ray.utils
from ray.core.generated import reporter_pb2
from ray.core.generated import reporter_pb2_grpc
from ray.new_dashboard.datacenter import DataSource
logger = logging.getLogger(__name__)
routes = dashboard_utils.ClassMethodRouteTable
class ReportHead(dashboard_utils.DashboardHeadModule):
def __init__(self, dashboard_head):
super().__init__(dashboard_head)
self._stubs = {}
self._profiling_stats = {}
DataSource.agents.signal.append(self._update_stubs)
async def _update_stubs(self, change):
if change.new:
ip, port = next(iter(change.new.items()))
channel = aiogrpc.insecure_channel("{}:{}".format(ip, int(port)))
stub = reporter_pb2_grpc.ReporterServiceStub(channel)
self._stubs[ip] = stub
if change.old:
ip, port = next(iter(change.old.items()))
self._stubs.pop(ip)
@routes.get("/api/launch_profiling")
async def launch_profiling(self, req) -> aiohttp.web.Response:
node_id = req.query.get("node_id")
pid = int(req.query.get("pid"))
duration = int(req.query.get("duration"))
profiling_id = str(uuid.uuid4())
reporter_stub = self._stubs[node_id]
reply = await reporter_stub.GetProfilingStats(
reporter_pb2.GetProfilingStatsRequest(pid=pid, duration=duration))
self._profiling_stats[profiling_id] = reply
return await dashboard_utils.rest_response(
success=True,
message="Profiling launched.",
profiling_id=profiling_id)
@routes.get("/api/check_profiling_status")
async def check_profiling_status(self, req) -> aiohttp.web.Response:
profiling_id = req.query.get("profiling_id")
is_present = profiling_id in self._profiling_stats
if not is_present:
status = {"status": "pending"}
else:
reply = self._profiling_stats[profiling_id]
if reply.stderr:
status = {"status": "error", "error": reply.stderr}
else:
status = {"status": "finished"}
return await dashboard_utils.rest_response(
success=True, message="Profiling status fetched.", status=status)
@routes.get("/api/get_profiling_info")
async def get_profiling_info(self, req) -> aiohttp.web.Response:
profiling_id = req.query.get("profiling_id")
profiling_stats = self._profiling_stats.get(profiling_id)
assert profiling_stats, "profiling not finished"
return await dashboard_utils.rest_response(
success=True,
message="Profiling info fetched.",
profiling_info=json.loads(profiling_stats.profiling_stats))
async def run(self):
p = self._dashboard_head.aioredis_client
mpsc = Receiver()
reporter_key = "{}*".format(reporter_consts.REPORTER_PREFIX)
await p.psubscribe(mpsc.pattern(reporter_key))
logger.info("Subscribed to {}".format(reporter_key))
async for sender, msg in mpsc.iter():
try:
_, data = msg
data = json.loads(ray.utils.decode(data))
DataSource.node_physical_stats[data["ip"]] = data
except Exception as ex:
logger.exception(ex)
+340
View File
@@ -0,0 +1,340 @@
import abc
import asyncio
import collections
import copy
import json
import datetime
import functools
import importlib
import inspect
import logging
import pkgutil
import traceback
from base64 import b64decode
from collections.abc import MutableMapping, Mapping
import aiohttp.web
from aiohttp import hdrs
from aiohttp.frozenlist import FrozenList
import aiohttp.signals
from google.protobuf.json_format import MessageToDict
from ray.utils import binary_to_hex
logger = logging.getLogger(__name__)
class DashboardAgentModule(abc.ABC):
def __init__(self, dashboard_agent):
"""
Initialize current module when DashboardAgent loading modules.
:param dashboard_agent: The DashboardAgent instance.
"""
self._dashboard_agent = dashboard_agent
@abc.abstractmethod
async def run(self, server):
"""
Run the module in an asyncio loop. An agent module can provide
servicers to the server.
:param server: Asyncio GRPC server.
"""
class DashboardHeadModule(abc.ABC):
def __init__(self, dashboard_head):
"""
Initialize current module when DashboardHead loading modules.
:param dashboard_head: The DashboardHead instance.
"""
self._dashboard_head = dashboard_head
@abc.abstractmethod
async def run(self):
"""
Run the module in an asyncio loop.
"""
class ClassMethodRouteTable:
"""A helper class to bind http route to class method."""
_bind_map = collections.defaultdict(dict)
_routes = aiohttp.web.RouteTableDef()
class _BindInfo:
def __init__(self, filename, lineno, instance):
self.filename = filename
self.lineno = lineno
self.instance = instance
@classmethod
def routes(cls):
return cls._routes
@classmethod
def _register_route(cls, method, path, **kwargs):
def _wrapper(handler):
if path in cls._bind_map[method]:
bind_info = cls._bind_map[method][path]
raise Exception("Duplicated route path: {}, "
"previous one registered at {}:{}".format(
path, bind_info.filename,
bind_info.lineno))
bind_info = cls._BindInfo(handler.__code__.co_filename,
handler.__code__.co_firstlineno, None)
@functools.wraps(handler)
async def _handler_route(*args, **kwargs):
if len(args) and args[0] == bind_info.instance:
args = args[1:]
try:
return await handler(bind_info.instance, *args, **kwargs)
except Exception:
return await rest_response(
success=False, message=traceback.format_exc())
cls._bind_map[method][path] = bind_info
_handler_route.__route_method__ = method
_handler_route.__route_path__ = path
return cls._routes.route(method, path, **kwargs)(_handler_route)
return _wrapper
@classmethod
def head(cls, path, **kwargs):
return cls._register_route(hdrs.METH_HEAD, path, **kwargs)
@classmethod
def get(cls, path, **kwargs):
return cls._register_route(hdrs.METH_GET, path, **kwargs)
@classmethod
def post(cls, path, **kwargs):
return cls._register_route(hdrs.METH_POST, path, **kwargs)
@classmethod
def put(cls, path, **kwargs):
return cls._register_route(hdrs.METH_PUT, path, **kwargs)
@classmethod
def patch(cls, path, **kwargs):
return cls._register_route(hdrs.METH_PATCH, path, **kwargs)
@classmethod
def delete(cls, path, **kwargs):
return cls._register_route(hdrs.METH_DELETE, path, **kwargs)
@classmethod
def view(cls, path, **kwargs):
return cls._register_route(hdrs.METH_ANY, path, **kwargs)
@classmethod
def bind(cls, instance):
def predicate(o):
if inspect.ismethod(o):
return hasattr(o, "__route_method__") and hasattr(
o, "__route_path__")
return False
handler_routes = inspect.getmembers(instance, predicate)
for _, h in handler_routes:
cls._bind_map[h.__func__.__route_method__][
h.__func__.__route_path__].instance = instance
def get_all_modules(module_type):
logger.info("Get all modules by type: {}".format(module_type.__name__))
import ray.new_dashboard.modules
for module_loader, name, ispkg in pkgutil.walk_packages(
ray.new_dashboard.modules.__path__,
ray.new_dashboard.modules.__name__ + "."):
importlib.import_module(name)
return module_type.__subclasses__()
def to_posix_time(dt):
return (dt - datetime.datetime(1970, 1, 1)).total_seconds()
class CustomEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, bytes):
return binary_to_hex(obj)
# Let the base class default method raise the TypeError
return json.JSONEncoder.default(self, obj)
async def rest_response(success, message, **kwargs) -> aiohttp.web.Response:
return aiohttp.web.json_response(
{
"result": success,
"msg": message,
"data": to_google_style(kwargs)
},
dumps=functools.partial(json.dumps, cls=CustomEncoder))
def to_camel_case(snake_str):
"""Convert a snake str to camel case."""
components = snake_str.split("_")
# We capitalize the first letter of each component except the first one
# with the 'title' method and join them together.
return components[0] + "".join(x.title() for x in components[1:])
def to_google_style(d):
"""Recursive convert all keys in dict to google style."""
new_dict = {}
for k, v in d.items():
if isinstance(v, dict):
new_dict[to_camel_case(k)] = to_google_style(v)
elif isinstance(v, list):
new_list = []
for i in v:
if isinstance(i, dict):
new_list.append(to_google_style(i))
else:
new_list.append(i)
new_dict[to_camel_case(k)] = new_list
else:
new_dict[to_camel_case(k)] = v
return new_dict
def message_to_dict(message, decode_keys=None, **kwargs):
"""Convert protobuf message to Python dict."""
def _decode_keys(d):
for k, v in d.items():
if isinstance(v, dict):
d[k] = _decode_keys(v)
if isinstance(v, list):
new_list = []
for i in v:
if isinstance(i, dict):
new_list.append(_decode_keys(i))
else:
new_list.append(i)
d[k] = new_list
else:
if k in decode_keys:
d[k] = binary_to_hex(b64decode(v))
else:
d[k] = v
return d
if decode_keys:
return _decode_keys(
MessageToDict(message, use_integers_for_enums=False, **kwargs))
else:
return MessageToDict(message, use_integers_for_enums=False, **kwargs)
class SignalManager:
_signals = FrozenList()
@classmethod
def register(cls, sig):
cls._signals.append(sig)
@classmethod
def freeze(cls):
cls._signals.freeze()
for sig in cls._signals:
sig.freeze()
class Signal(aiohttp.signals.Signal):
__slots__ = ()
def __init__(self, owner):
super().__init__(owner)
SignalManager.register(self)
class Bunch(dict):
"""A dict with attribute-access."""
def __getattr__(self, key):
try:
return self.__getitem__(key)
except KeyError:
raise AttributeError(key)
def __setattr__(self, key, value):
self.__setitem__(key, value)
class Change:
"""Notify change object."""
def __init__(self, owner=None, old=None, new=None):
self.owner = owner
self.old = old
self.new = new
def __str__(self):
return "Change(owner: {}, old: {}, new: {}".format(
self.owner, self.old, self.new)
class NotifyQueue:
"""Asyncio notify queue for Dict signal."""
_queue = asyncio.Queue()
@classmethod
def put(cls, co):
cls._queue.put_nowait(co)
@classmethod
async def get(cls):
return await cls._queue.get()
class Dict(MutableMapping):
"""A simple descriptor for dict type to notify data changes.
:note: Only the first level data report change.
"""
def __init__(self, *args, **kwargs):
self._data = dict(*args, **kwargs)
self.signal = Signal(self)
def __setitem__(self, key, value):
old = self._data.pop(key, None)
self._data[key] = value
if len(self.signal) and old != value:
if old is None:
co = self.signal.send(Change(owner=self, new={key: value}))
else:
co = self.signal.send(
Change(owner=self, old={key: old}, new={key: value}))
NotifyQueue.put(co)
def __getitem__(self, item):
return copy.deepcopy(self._data[item])
def __delitem__(self, key):
old = self._data.pop(key, None)
if len(self.signal) and old is not None:
co = self.signal.send(Change(owner=self, old={key: old}))
NotifyQueue.put(co)
def __len__(self):
return len(self._data)
def __iter__(self):
return iter(copy.deepcopy(self._data))
def reset(self, d):
assert isinstance(d, Mapping)
for key in self._data.keys() - d.keys():
self.pop(key)
self.update(d)
+1
View File
@@ -0,0 +1 @@
../../dashboard
+1
View File
@@ -123,6 +123,7 @@ REMOVED_NODE_ERROR = "node_removed"
MONITOR_DIED_ERROR = "monitor_died"
LOG_MONITOR_DIED_ERROR = "log_monitor_died"
REPORTER_DIED_ERROR = "reporter_died"
DASHBOARD_AGENT_DIED_ERROR = "dashboard_agent_died"
DASHBOARD_DIED_ERROR = "dashboard_died"
RAYLET_CONNECTION_ERROR = "raylet_connection_error"
+2 -1
View File
@@ -301,13 +301,14 @@ def find_version(*filepath):
install_requires = [
"aiohttp",
"aioredis",
"click >= 7.0",
"colorama",
"colorful",
"filelock",
"google",
"gpustat",
"grpcio",
"grpcio >= 1.28.1",
"jsonschema",
"msgpack >= 0.6.0, < 2.0.0",
"numpy >= 1.16",