Dashboard minor refactor and first unit tests (#8705)

This commit is contained in:
Max Fitton
2020-06-03 09:04:55 -07:00
committed by GitHub
parent b37a162076
commit b9f0f7ae5b
9 changed files with 369 additions and 297 deletions
+13
View File
@@ -0,0 +1,13 @@
# This is a dummy test dependency that causes the test target below to be
# re-run if any of these files changes.
# It bundles every Python source under this package except the tests.
py_library(
name = "dashboard_lib",
srcs = glob(["**/*.py"],exclude=["tests/*"]),
)
# Small-size test target built from every Python file under tests/; it
# depends on the library above.
py_test(
name = "test_node_stats",
size = "small",
srcs = glob(["tests/*.py"]),
deps = [":dashboard_lib"]
)
+3 -297
View File
@@ -12,19 +12,12 @@ import errno
import json
import logging
import os
import re
import socket
import threading
import time
import traceback
import yaml
import uuid
from base64 import b64decode
from collections import defaultdict
from operator import itemgetter
from typing import Dict
import grpc
from google.protobuf.json_format import MessageToDict
import ray
@@ -41,6 +34,8 @@ from ray.dashboard.interface import BaseDashboardRouteHandler
from ray.dashboard.memory import construct_memory_table, MemoryTable
from ray.dashboard.metrics_exporter.client import Exporter
from ray.dashboard.metrics_exporter.client import MetricsExportClient
from ray.dashboard.node_stats import NodeStats
from ray.dashboard.util import to_unix_time, measures_to_dict, format_resource
try:
from ray.tune import Analysis
@@ -54,54 +49,6 @@ except ImportError:
logger = logging.getLogger(__name__)
def to_unix_time(dt):
    """Return the number of seconds between the Unix epoch and *dt*.

    Naive datetimes are interpreted as UTC wall-clock time. Timezone-aware
    datetimes are compared against an aware UTC epoch, because subtracting
    a naive epoch from an aware datetime raises ``TypeError``.

    Args:
        dt: A ``datetime.datetime`` instance (naive or aware).

    Returns:
        Seconds since 1970-01-01T00:00:00 UTC, as a float.
    """
    if dt.tzinfo is not None and dt.utcoffset() is not None:
        epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
    else:
        epoch = datetime.datetime(1970, 1, 1)
    return (dt - epoch).total_seconds()
def round_resource_value(quantity):
    """Return *quantity* as an int when it is whole, else rounded to 2 places.

    Plain ``int`` values are accepted as well as floats; ``int`` only gained
    an ``is_integer`` method in Python 3.12, so it is checked explicitly.

    Args:
        quantity: A numeric resource amount.

    Returns:
        ``int(quantity)`` for whole numbers, otherwise ``round(quantity, 2)``.
    """
    if isinstance(quantity, int) or quantity.is_integer():
        return int(quantity)
    return round(quantity, 2)
def format_resource(resource_name, quantity):
    """Render a resource quantity as a display string.

    Args:
        resource_name: Name of the resource.
        quantity: Amount of the resource.

    Returns:
        ``"<value> GiB"`` for ``"memory"`` and ``"object_store_memory"``,
        otherwise just ``"<value>"``; the value is rounded by
        ``round_resource_value``.
    """
    if resource_name not in ("object_store_memory", "memory"):
        return "{}".format(round_resource_value(quantity))
    # Memory quantities are multiplied by 50 MiB and expressed in GiB.
    gib = quantity * (50 * 1024 * 1024) / (1024 * 1024 * 1024)
    return "{} GiB".format(round_resource_value(gib))
def format_reply_id(reply):
    """Rewrite, in place, every ``*Id`` value in *reply* as a hex string.

    Walks nested dicts and lists. For each dict entry whose key ends with
    ``"Id"`` and whose value is neither a dict nor a list, the value is
    base64-decoded and replaced by ``ray.utils.binary_to_hex`` of the result.
    Anything that is not a dict or a list is ignored.

    Args:
        reply: The structure to rewrite.

    Returns:
        None; *reply* is modified in place.
    """
    if isinstance(reply, list):
        for element in reply:
            format_reply_id(element)
        return
    if not isinstance(reply, dict):
        return
    for key, value in reply.items():
        if isinstance(value, (dict, list)):
            format_reply_id(value)
        elif key.endswith("Id"):
            raw_id = b64decode(value)
            reply[key] = ray.utils.binary_to_hex(raw_id)
def measures_to_dict(measures):
    """Collapse a list of measure dicts into a ``{tag: value}`` mapping.

    The key is the text after the last comma of each measure's ``"tags"``
    string. The value is ``"intValue"`` when present, otherwise
    ``"doubleValue"``; measures carrying neither are skipped.

    Args:
        measures: Iterable of dicts, each with a ``"tags"`` string.

    Returns:
        Dict mapping the last tag to the measure's value.
    """
    result = {}
    for entry in measures:
        last_tag = entry["tags"].rpartition(",")[2]
        for value_field in ("intValue", "doubleValue"):
            if value_field in entry:
                result[last_tag] = entry[value_field]
                break
    return result
def b64_decode(reply):
    """Base64-decode *reply* and return the result as a UTF-8 string."""
    raw = b64decode(reply)
    return raw.decode("utf-8")
async def json_response(is_dev, result=None, error=None,
ts=None) -> aiohttp.web.Response:
if ts is None:
@@ -629,247 +576,6 @@ class Dashboard:
aiohttp.web.run_app(self.app, host=self.host, port=self.port)
class NodeStats(threading.Thread):
    """Thread that caches node stats, logs, errors and actor metadata.

    When run, the thread subscribes to Redis pubsub channels (reporter
    stats, log lines, error info and actor updates) and stores what it
    receives in memory. The cached state is guarded by ``_node_stats_lock``;
    the public getters acquire that lock themselves, so callers must not
    hold it.
    """

    def __init__(self, redis_address, redis_password=None):
        """Create the Redis client and the initially empty caches.

        Args:
            redis_address: Address handed to
                ``ray.services.create_redis_client``.
            redis_password: Optional Redis password.
        """
        self.redis_key = "{}.*".format(ray.gcs_utils.REPORTER_CHANNEL)
        self.redis_client = ray.services.create_redis_client(
            redis_address, password=redis_password)

        self._node_stats = {}
        self._addr_to_owner_addr = {}
        self._addr_to_actor_id = {}
        self._addr_to_extra_info_dict = {}
        self._node_stats_lock = threading.Lock()

        # Template copied for every actor before its real data is merged in.
        self._default_info = {
            "actorId": "",
            "children": {},
            "currentTaskFuncDesc": [],
            "ipAddress": "",
            "jobId": "",
            "numExecutedTasks": 0,
            "numLocalObjects": 0,
            "numObjectIdsInScope": 0,
            "port": 0,
            "state": 0,
            "taskQueueLength": 0,
            "usedObjectStoreMemory": 0,
            "usedResources": {},
        }

        # Mapping from IP address to PID to list of log lines
        self._logs = defaultdict(lambda: defaultdict(list))

        # Mapping from IP address to PID to list of error messages
        self._errors = defaultdict(lambda: defaultdict(list))

        ray.state.state._initialize_global_state(
            redis_address=redis_address, redis_password=redis_password)

        super().__init__()

    def _calculate_log_counts(self):
        """Return ``{ip: {pid: number_of_log_lines}}``."""
        return {
            ip: {pid: len(lines)
                 for pid, lines in logs_for_ip.items()}
            for ip, logs_for_ip in self._logs.items()
        }

    def _calculate_error_counts(self):
        """Return ``{ip: {pid: number_of_errors}}``."""
        return {
            ip: {pid: len(errors)
                 for pid, errors in errors_for_ip.items()}
            for ip, errors_for_ip in self._errors.items()
        }

    def _purge_outdated_stats(self):
        """Drop node stats whose ``"now"`` stamp is more than 5s old."""
        now = to_unix_time(datetime.datetime.utcnow())
        self._node_stats = {
            hostname: stats
            for hostname, stats in self._node_stats.items()
            if now - stats["now"] <= 5
        }

    def get_node_stats(self) -> Dict:
        """Return current node stats plus per-process log/error counts.

        Returns:
            Dict with ``"clients"`` (node stats sorted by ``"boot_time"``),
            ``"log_counts"`` and ``"error_counts"``.
        """
        with self._node_stats_lock:
            self._purge_outdated_stats()
            node_stats = sorted(
                self._node_stats.values(), key=itemgetter("boot_time"))
            return {
                "clients": node_stats,
                "log_counts": self._calculate_log_counts(),
                "error_counts": self._calculate_error_counts(),
            }

    def get_actor_tree(self, workers_info_by_node, infeasible_tasks,
                       ready_tasks) -> Dict:
        """Build the nested actor tree.

        Args:
            workers_info_by_node: Mapping of node id to a list of worker
                info dicts; entries with ``"coreWorkerStats"`` are merged
                into the matching actor.
            infeasible_tasks: Actor creation tasks reported as infeasible.
            ready_tasks: Actor creation tasks reported as ready (pending).

        Returns:
            The ``"children"`` mapping of the synthetic root node, keyed by
            actor id.
        """
        now = time.time()
        # construct flattened actor tree
        flattened_tree = {"root": {"children": {}}}
        child_to_parent = {}
        with self._node_stats_lock:
            # Work from a snapshot so every actor id looked up later is
            # guaranteed to have an entry in flattened_tree.
            addr_to_actor_id = dict(self._addr_to_actor_id)
            for addr, actor_id in addr_to_actor_id.items():
                flattened_tree[actor_id] = copy.deepcopy(self._default_info)
                flattened_tree[actor_id].update(
                    self._addr_to_extra_info_dict[addr])
                parent_id = addr_to_actor_id.get(
                    self._addr_to_owner_addr[addr], "root")
                child_to_parent[actor_id] = parent_id

        for node_id, workers_info in workers_info_by_node.items():
            for worker_info in workers_info:
                if "coreWorkerStats" not in worker_info:
                    continue
                core_worker_stats = worker_info["coreWorkerStats"]
                addr = (core_worker_stats["ipAddress"],
                        str(core_worker_stats["port"]))
                if addr not in addr_to_actor_id:
                    continue
                actor_info = flattened_tree[addr_to_actor_id[addr]]
                format_reply_id(core_worker_stats)
                actor_info.update(core_worker_stats)
                # "timestamp" is divided by 1000 before being compared with
                # time.time(); guard against a non-positive elapsed time.
                elapsed_s = now - actor_info["timestamp"] / 1000
                if elapsed_s > 0:
                    speed = round(
                        actor_info["numExecutedTasks"] / elapsed_s, 2)
                else:
                    speed = 0
                actor_info["averageTaskExecutionSpeed"] = speed
                actor_info["nodeId"] = node_id
                actor_info["pid"] = worker_info["pid"]

        def _update_flatten_tree(task, task_spec_type, invalid_state_type):
            actor_id = ray.utils.binary_to_hex(
                b64decode(task[task_spec_type]["actorId"]))
            caller_addr = (task["callerAddress"]["ipAddress"],
                           str(task["callerAddress"]["port"]))
            caller_id = addr_to_actor_id.get(caller_addr, "root")
            child_to_parent[actor_id] = caller_id
            task["state"] = -1
            task["invalidStateType"] = invalid_state_type
            task["actorTitle"] = task["functionDescriptor"][
                "pythonFunctionDescriptor"]["className"]
            format_reply_id(task)
            flattened_tree[actor_id] = task

        for infeasible_task in infeasible_tasks:
            _update_flatten_tree(infeasible_task, "actorCreationTaskSpec",
                                 "infeasibleActor")

        for ready_task in ready_tasks:
            _update_flatten_tree(ready_task, "actorCreationTaskSpec",
                                 "pendingActor")

        # construct actor tree
        actor_tree = flattened_tree
        for actor_id, parent_id in child_to_parent.items():
            # A task dict stored by _update_flatten_tree has no "children"
            # key, so create it on demand instead of raising KeyError.
            children = actor_tree[parent_id].setdefault("children", {})
            children[actor_id] = actor_tree[actor_id]
        return actor_tree["root"]["children"]

    def _entries_for_host(self, table, hostname, pid):
        """Snapshot the ``table`` entries of ``hostname``.

        Must be called with ``_node_stats_lock`` held. The lists are copied
        so the caller never shares them with the listener thread.
        """
        ip = self._node_stats.get(hostname, {"ip": None})["ip"]
        by_pid = table.get(ip, {})
        if pid:
            return {pid: list(by_pid.get(pid, []))}
        return {key: list(entries) for key, entries in by_pid.items()}

    def get_logs(self, hostname, pid):
        """Return ``{pid: log_lines}`` for *hostname*.

        Args:
            hostname: Hostname key of the node stats entry.
            pid: If truthy, restrict the result to this PID.
        """
        with self._node_stats_lock:
            return self._entries_for_host(self._logs, hostname, pid)

    def get_errors(self, hostname, pid):
        """Return ``{pid: error_entries}`` for *hostname*.

        Args:
            hostname: Hostname key of the node stats entry.
            pid: If truthy, restrict the result to this PID.
        """
        with self._node_stats_lock:
            return self._entries_for_host(self._errors, hostname, pid)

    def run(self):
        """Subscribe to the pubsub channels and process messages forever."""
        p = self.redis_client.pubsub(ignore_subscribe_messages=True)

        p.psubscribe(self.redis_key)
        logger.info("NodeStats: subscribed to {}".format(self.redis_key))

        log_channel = ray.gcs_utils.LOG_FILE_CHANNEL
        p.subscribe(log_channel)
        logger.info("NodeStats: subscribed to {}".format(log_channel))

        error_channel = ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB")
        p.subscribe(error_channel)
        logger.info("NodeStats: subscribed to {}".format(error_channel))

        actor_channel = ray.gcs_utils.TablePubsub.Value("ACTOR_PUBSUB")
        p.subscribe(actor_channel)
        logger.info("NodeStats: subscribed to {}".format(actor_channel))

        # Seed the actor mappings from the current actor table.
        current_actor_table = ray.actors()
        with self._node_stats_lock:
            for actor_data in current_actor_table.values():
                addr = (actor_data["Address"]["IPAddress"],
                        str(actor_data["Address"]["Port"]))
                owner_addr = (actor_data["OwnerAddress"]["IPAddress"],
                              str(actor_data["OwnerAddress"]["Port"]))
                self._addr_to_owner_addr[addr] = owner_addr
                self._addr_to_actor_id[addr] = actor_data["ActorID"]
                self._addr_to_extra_info_dict[addr] = {
                    "jobId": actor_data["JobID"],
                    "state": actor_data["State"],
                    "timestamp": actor_data["Timestamp"]
                }

        for x in p.listen():
            try:
                with self._node_stats_lock:
                    channel = ray.utils.decode(x["channel"])
                    data = x["data"]
                    if channel == log_channel:
                        data = json.loads(ray.utils.decode(data))
                        ip = data["ip"]
                        pid = str(data["pid"])
                        self._logs[ip][pid].extend(data["lines"])
                    elif channel == str(error_channel):
                        gcs_entry = ray.gcs_utils.GcsEntry.FromString(data)
                        error_data = ray.gcs_utils.ErrorTableData.FromString(
                            gcs_entry.entries[0])
                        message = error_data.error_message
                        # Strip "\x1b[<n>m" escape sequences.
                        message = re.sub(r"\x1b\[\d+m", "", message)
                        match = re.search(r"\(pid=(\d+), ip=(.*?)\)", message)
                        if match:
                            pid = match.group(1)
                            ip = match.group(2)
                            self._errors[ip][pid].append({
                                "message": message,
                                "timestamp": error_data.timestamp,
                                "type": error_data.type
                            })
                    elif channel == str(actor_channel):
                        gcs_entry = ray.gcs_utils.PubSubMessage.FromString(
                            data)
                        actor_data = ray.gcs_utils.ActorTableData.FromString(
                            gcs_entry.entries[0])
                        addr = (actor_data.address.ip_address,
                                str(actor_data.address.port))
                        owner_addr = (actor_data.owner_address.ip_address,
                                      str(actor_data.owner_address.port))
                        self._addr_to_owner_addr[addr] = owner_addr
                        self._addr_to_actor_id[addr] = ray.utils.binary_to_hex(
                            actor_data.actor_id)
                        self._addr_to_extra_info_dict[addr] = {
                            "jobId": ray.utils.binary_to_hex(
                                actor_data.job_id),
                            "state": actor_data.state,
                            "timestamp": actor_data.timestamp
                        }
                    else:
                        data = json.loads(ray.utils.decode(data))
                        self._node_stats[data["hostname"]] = data
            except Exception:
                # logger.exception appends the active traceback itself, so
                # passing the formatted traceback as the message logged it
                # twice. Keep listening after a bad message.
                logger.exception(
                    "NodeStats: failed to process pubsub message")
class RayletStats(threading.Thread):
def __init__(self, redis_address, redis_password=None):
self.nodes_lock = threading.Lock()
@@ -927,7 +633,7 @@ class RayletStats(threading.Thread):
self.reporter_stubs), (self.stubs.keys(),
self.reporter_stubs.keys())
def get_raylet_stats(self) -> Dict:
def get_raylet_stats(self):
# Return a deep copy of the cached raylet stats, made while holding
# _raylet_stats_lock, so the caller gets its own independent copy.
with self._raylet_stats_lock:
return copy.deepcopy(self._raylet_stats)
+257
View File
@@ -0,0 +1,257 @@
from collections import defaultdict
from ray.dashboard.util import to_unix_time, format_reply_id
from base64 import b64decode
import ray
import threading
import json
import traceback
import copy
import logging
import datetime
import time
import re
from operator import itemgetter
logger = logging.getLogger(__name__)
class NodeStats(threading.Thread):
    """Thread that caches node stats, logs, errors and actor metadata.

    When run, the thread subscribes to Redis pubsub channels (reporter
    stats, log lines, error info and actor updates) and stores what it
    receives in memory. The cached state is guarded by ``_node_stats_lock``;
    the public getters acquire that lock themselves, so callers must not
    hold it.
    """

    def __init__(self, redis_address, redis_password=None):
        """Create the Redis client and the initially empty caches.

        Args:
            redis_address: Address handed to
                ``ray.services.create_redis_client``.
            redis_password: Optional Redis password.
        """
        self.redis_key = "{}.*".format(ray.gcs_utils.REPORTER_CHANNEL)
        self.redis_client = ray.services.create_redis_client(
            redis_address, password=redis_password)

        self._node_stats = {}
        self._addr_to_owner_addr = {}
        self._addr_to_actor_id = {}
        self._addr_to_extra_info_dict = {}
        self._node_stats_lock = threading.Lock()

        # Template copied for every actor before its real data is merged in.
        self._default_info = {
            "actorId": "",
            "children": {},
            "currentTaskFuncDesc": [],
            "ipAddress": "",
            "jobId": "",
            "numExecutedTasks": 0,
            "numLocalObjects": 0,
            "numObjectIdsInScope": 0,
            "port": 0,
            "state": 0,
            "taskQueueLength": 0,
            "usedObjectStoreMemory": 0,
            "usedResources": {},
        }

        # Mapping from IP address to PID to list of log lines
        self._logs = defaultdict(lambda: defaultdict(list))

        # Mapping from IP address to PID to list of error messages
        self._errors = defaultdict(lambda: defaultdict(list))

        ray.state.state._initialize_global_state(
            redis_address=redis_address, redis_password=redis_password)

        super().__init__()

    def _calculate_log_counts(self):
        """Return ``{ip: {pid: number_of_log_lines}}``."""
        return {
            ip: {pid: len(lines)
                 for pid, lines in logs_for_ip.items()}
            for ip, logs_for_ip in self._logs.items()
        }

    def _calculate_error_counts(self):
        """Return ``{ip: {pid: number_of_errors}}``."""
        return {
            ip: {pid: len(errors)
                 for pid, errors in errors_for_ip.items()}
            for ip, errors_for_ip in self._errors.items()
        }

    def _purge_outdated_stats(self):
        """Drop node stats whose ``"now"`` stamp is more than 5s old."""
        now = to_unix_time(datetime.datetime.utcnow())
        self._node_stats = {
            hostname: stats
            for hostname, stats in self._node_stats.items()
            if now - stats["now"] <= 5
        }

    def get_node_stats(self):
        """Return current node stats plus per-process log/error counts.

        Returns:
            Dict with ``"clients"`` (node stats sorted by ``"boot_time"``),
            ``"log_counts"`` and ``"error_counts"``.
        """
        with self._node_stats_lock:
            self._purge_outdated_stats()
            node_stats = sorted(
                self._node_stats.values(), key=itemgetter("boot_time"))
            return {
                "clients": node_stats,
                "log_counts": self._calculate_log_counts(),
                "error_counts": self._calculate_error_counts(),
            }

    def get_actor_tree(self, workers_info_by_node, infeasible_tasks,
                       ready_tasks):
        """Build the nested actor tree.

        Args:
            workers_info_by_node: Mapping of node id to a list of worker
                info dicts; entries with ``"coreWorkerStats"`` are merged
                into the matching actor.
            infeasible_tasks: Actor creation tasks reported as infeasible.
            ready_tasks: Actor creation tasks reported as ready (pending).

        Returns:
            The ``"children"`` mapping of the synthetic root node, keyed by
            actor id.
        """
        now = time.time()
        # construct flattened actor tree
        flattened_tree = {"root": {"children": {}}}
        child_to_parent = {}
        with self._node_stats_lock:
            # Work from a snapshot so every actor id looked up later is
            # guaranteed to have an entry in flattened_tree.
            addr_to_actor_id = dict(self._addr_to_actor_id)
            for addr, actor_id in addr_to_actor_id.items():
                flattened_tree[actor_id] = copy.deepcopy(self._default_info)
                flattened_tree[actor_id].update(
                    self._addr_to_extra_info_dict[addr])
                parent_id = addr_to_actor_id.get(
                    self._addr_to_owner_addr[addr], "root")
                child_to_parent[actor_id] = parent_id

        for node_id, workers_info in workers_info_by_node.items():
            for worker_info in workers_info:
                if "coreWorkerStats" not in worker_info:
                    continue
                core_worker_stats = worker_info["coreWorkerStats"]
                addr = (core_worker_stats["ipAddress"],
                        str(core_worker_stats["port"]))
                if addr not in addr_to_actor_id:
                    continue
                actor_info = flattened_tree[addr_to_actor_id[addr]]
                format_reply_id(core_worker_stats)
                actor_info.update(core_worker_stats)
                # "timestamp" is divided by 1000 before being compared with
                # time.time(); guard against a non-positive elapsed time.
                elapsed_s = now - actor_info["timestamp"] / 1000
                if elapsed_s > 0:
                    speed = round(
                        actor_info["numExecutedTasks"] / elapsed_s, 2)
                else:
                    speed = 0
                actor_info["averageTaskExecutionSpeed"] = speed
                actor_info["nodeId"] = node_id
                actor_info["pid"] = worker_info["pid"]

        def _update_flatten_tree(task, task_spec_type, invalid_state_type):
            actor_id = ray.utils.binary_to_hex(
                b64decode(task[task_spec_type]["actorId"]))
            caller_addr = (task["callerAddress"]["ipAddress"],
                           str(task["callerAddress"]["port"]))
            caller_id = addr_to_actor_id.get(caller_addr, "root")
            child_to_parent[actor_id] = caller_id
            task["state"] = -1
            task["invalidStateType"] = invalid_state_type
            task["actorTitle"] = task["functionDescriptor"][
                "pythonFunctionDescriptor"]["className"]
            format_reply_id(task)
            flattened_tree[actor_id] = task

        for infeasible_task in infeasible_tasks:
            _update_flatten_tree(infeasible_task, "actorCreationTaskSpec",
                                 "infeasibleActor")

        for ready_task in ready_tasks:
            _update_flatten_tree(ready_task, "actorCreationTaskSpec",
                                 "pendingActor")

        # construct actor tree
        actor_tree = flattened_tree
        for actor_id, parent_id in child_to_parent.items():
            # A task dict stored by _update_flatten_tree has no "children"
            # key, so create it on demand instead of raising KeyError.
            children = actor_tree[parent_id].setdefault("children", {})
            children[actor_id] = actor_tree[actor_id]
        return actor_tree["root"]["children"]

    def _entries_for_host(self, table, hostname, pid):
        """Snapshot the ``table`` entries of ``hostname``.

        Must be called with ``_node_stats_lock`` held. The lists are copied
        so the caller never shares them with the listener thread.
        """
        ip = self._node_stats.get(hostname, {"ip": None})["ip"]
        by_pid = table.get(ip, {})
        if pid:
            return {pid: list(by_pid.get(pid, []))}
        return {key: list(entries) for key, entries in by_pid.items()}

    def get_logs(self, hostname, pid):
        """Return ``{pid: log_lines}`` for *hostname*.

        Args:
            hostname: Hostname key of the node stats entry.
            pid: If truthy, restrict the result to this PID.
        """
        with self._node_stats_lock:
            return self._entries_for_host(self._logs, hostname, pid)

    def get_errors(self, hostname, pid):
        """Return ``{pid: error_entries}`` for *hostname*.

        Args:
            hostname: Hostname key of the node stats entry.
            pid: If truthy, restrict the result to this PID.
        """
        with self._node_stats_lock:
            return self._entries_for_host(self._errors, hostname, pid)

    def run(self):
        """Subscribe to the pubsub channels and process messages forever."""
        p = self.redis_client.pubsub(ignore_subscribe_messages=True)

        p.psubscribe(self.redis_key)
        logger.info("NodeStats: subscribed to {}".format(self.redis_key))

        log_channel = ray.gcs_utils.LOG_FILE_CHANNEL
        p.subscribe(log_channel)
        logger.info("NodeStats: subscribed to {}".format(log_channel))

        error_channel = ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB")
        p.subscribe(error_channel)
        logger.info("NodeStats: subscribed to {}".format(error_channel))

        actor_channel = ray.gcs_utils.TablePubsub.Value("ACTOR_PUBSUB")
        p.subscribe(actor_channel)
        logger.info("NodeStats: subscribed to {}".format(actor_channel))

        # Seed the actor mappings from the current actor table.
        current_actor_table = ray.actors()
        with self._node_stats_lock:
            for actor_data in current_actor_table.values():
                addr = (actor_data["Address"]["IPAddress"],
                        str(actor_data["Address"]["Port"]))
                owner_addr = (actor_data["OwnerAddress"]["IPAddress"],
                              str(actor_data["OwnerAddress"]["Port"]))
                self._addr_to_owner_addr[addr] = owner_addr
                self._addr_to_actor_id[addr] = actor_data["ActorID"]
                self._addr_to_extra_info_dict[addr] = {
                    "jobId": actor_data["JobID"],
                    "state": actor_data["State"],
                    "timestamp": actor_data["Timestamp"]
                }

        for x in p.listen():
            try:
                with self._node_stats_lock:
                    channel = ray.utils.decode(x["channel"])
                    data = x["data"]
                    if channel == log_channel:
                        data = json.loads(ray.utils.decode(data))
                        ip = data["ip"]
                        pid = str(data["pid"])
                        self._logs[ip][pid].extend(data["lines"])
                    elif channel == str(error_channel):
                        gcs_entry = ray.gcs_utils.GcsEntry.FromString(data)
                        error_data = ray.gcs_utils.ErrorTableData.FromString(
                            gcs_entry.entries[0])
                        message = error_data.error_message
                        # Strip "\x1b[<n>m" escape sequences.
                        message = re.sub(r"\x1b\[\d+m", "", message)
                        match = re.search(r"\(pid=(\d+), ip=(.*?)\)", message)
                        if match:
                            pid = match.group(1)
                            ip = match.group(2)
                            self._errors[ip][pid].append({
                                "message": message,
                                "timestamp": error_data.timestamp,
                                "type": error_data.type
                            })
                    elif channel == str(actor_channel):
                        gcs_entry = ray.gcs_utils.PubSubMessage.FromString(
                            data)
                        actor_data = ray.gcs_utils.ActorTableData.FromString(
                            gcs_entry.entries[0])
                        addr = (actor_data.address.ip_address,
                                str(actor_data.address.port))
                        owner_addr = (actor_data.owner_address.ip_address,
                                      str(actor_data.owner_address.port))
                        self._addr_to_owner_addr[addr] = owner_addr
                        self._addr_to_actor_id[addr] = ray.utils.binary_to_hex(
                            actor_data.actor_id)
                        self._addr_to_extra_info_dict[addr] = {
                            "jobId": ray.utils.binary_to_hex(
                                actor_data.job_id),
                            "state": actor_data.state,
                            "timestamp": actor_data.timestamp
                        }
                    else:
                        data = json.loads(ray.utils.decode(data))
                        self._node_stats[data["hostname"]] = data
            except Exception:
                # logger.exception appends the active traceback itself, so
                # passing the formatted traceback as the message logged it
                # twice. Keep listening after a bad message.
                logger.exception(
                    "NodeStats: failed to process pubsub message")
+1
View File
@@ -0,0 +1 @@
from ray.tests.conftest import * # noqa
@@ -0,0 +1,32 @@
from ray.dashboard.node_stats import NodeStats
from ray.ray_constants import REDIS_DEFAULT_PASSWORD
from datetime import datetime
from time import sleep
import pytest
def test_basic(ray_start_with_dashboard):
    """Start a Ray cluster with the dashboard enabled, attach a NodeStats
    thread to its Redis and assert that it reports sensible node data."""
    redis_address = ray_start_with_dashboard["redis_address"]
    redis_password = REDIS_DEFAULT_PASSWORD
    node_stats = NodeStats(redis_address, redis_password)
    node_stats.start()

    # Wait for node stats to fire up.
    MAX_START_TIME_S = 30
    POLL_INTERVAL_S = 3
    t_start = datetime.now()
    client_stats = None
    while not client_stats:
        # total_seconds() covers the whole delta; ``.seconds`` is only the
        # seconds component and wraps around after a day.
        if (datetime.now() - t_start).total_seconds() > MAX_START_TIME_S:
            pytest.fail("Node stats took too long to start up")
        try:
            stats = node_stats.get_node_stats()
            client_stats = stats and stats.get("clients")
        except Exception:
            # Stats may not be readable yet; retry, but still sleep and
            # honour the deadline instead of spinning forever.
            client_stats = None
        if not client_stats:
            sleep(POLL_INTERVAL_S)

    assert len(client_stats) == 1
    client = client_stats[0]
    assert len(client["workers"]) == 1
+47
View File
@@ -0,0 +1,47 @@
from base64 import b64decode
import datetime
import ray
def to_unix_time(dt):
    """Return the number of seconds between the Unix epoch and *dt*.

    Naive datetimes are interpreted as UTC wall-clock time. Timezone-aware
    datetimes are compared against an aware UTC epoch, because subtracting
    a naive epoch from an aware datetime raises ``TypeError``.

    Args:
        dt: A ``datetime.datetime`` instance (naive or aware).

    Returns:
        Seconds since 1970-01-01T00:00:00 UTC, as a float.
    """
    if dt.tzinfo is not None and dt.utcoffset() is not None:
        epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
    else:
        epoch = datetime.datetime(1970, 1, 1)
    return (dt - epoch).total_seconds()
def round_resource_value(quantity):
    """Return *quantity* as an int when it is whole, else rounded to 2 places.

    Plain ``int`` values are accepted as well as floats; ``int`` only gained
    an ``is_integer`` method in Python 3.12, so it is checked explicitly.

    Args:
        quantity: A numeric resource amount.

    Returns:
        ``int(quantity)`` for whole numbers, otherwise ``round(quantity, 2)``.
    """
    if isinstance(quantity, int) or quantity.is_integer():
        return int(quantity)
    return round(quantity, 2)
def format_reply_id(reply):
    """Rewrite, in place, every ``*Id`` value in *reply* as a hex string.

    Walks nested dicts and lists. For each dict entry whose key ends with
    ``"Id"`` and whose value is neither a dict nor a list, the value is
    base64-decoded and replaced by ``ray.utils.binary_to_hex`` of the result.
    Anything that is not a dict or a list is ignored.

    Args:
        reply: The structure to rewrite.

    Returns:
        None; *reply* is modified in place.
    """
    if isinstance(reply, list):
        for element in reply:
            format_reply_id(element)
        return
    if not isinstance(reply, dict):
        return
    for key, value in reply.items():
        if isinstance(value, (dict, list)):
            format_reply_id(value)
        elif key.endswith("Id"):
            raw_id = b64decode(value)
            reply[key] = ray.utils.binary_to_hex(raw_id)
def format_resource(resource_name, quantity):
    """Render a resource quantity as a display string.

    Args:
        resource_name: Name of the resource.
        quantity: Amount of the resource.

    Returns:
        ``"<value> GiB"`` for ``"memory"`` and ``"object_store_memory"``,
        otherwise just ``"<value>"``; the value is rounded by
        ``round_resource_value``.
    """
    if resource_name not in ("object_store_memory", "memory"):
        return "{}".format(round_resource_value(quantity))
    # Memory quantities are multiplied by 50 MiB and expressed in GiB.
    gib = quantity * (50 * 1024 * 1024) / (1024 * 1024 * 1024)
    return "{} GiB".format(round_resource_value(gib))
def measures_to_dict(measures):
    """Collapse a list of measure dicts into a ``{tag: value}`` mapping.

    The key is the text after the last comma of each measure's ``"tags"``
    string. The value is ``"intValue"`` when present, otherwise
    ``"doubleValue"``; measures carrying neither are skipped.

    Args:
        measures: Iterable of dicts, each with a ``"tags"`` string.

    Returns:
        Dict mapping the last tag to the measure's value.
    """
    result = {}
    for entry in measures:
        last_tag = entry["tags"].rpartition(",")[2]
        for value_field in ("intValue", "doubleValue"):
            if value_field in entry:
                result[last_tag] = entry[value_field]
                break
    return result
+7
View File
@@ -49,6 +49,13 @@ def _ray_start(**kwargs):
ray.shutdown()
@pytest.fixture
def ray_start_with_dashboard(request):
    """Yield the address info of a 1-CPU Ray instance started with the web
    UI enabled; extra start-up kwargs are read from ``request.param``."""
    extra_kwargs = getattr(request, "param", {})
    ray_context = _ray_start(
        num_cpus=1, include_webui=True, **extra_kwargs)
    with ray_context as address_info:
        yield address_info
# The following fixture will start ray with 0 cpu.
@pytest.fixture
def ray_start_no_cpu(request):