mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:16:19 +08:00
[Dashboard] Ray memory dashboard backend (#8461)
This commit is contained in:
@@ -38,6 +38,7 @@ from ray.core.generated import core_worker_pb2
|
||||
from ray.core.generated import core_worker_pb2_grpc
|
||||
from ray.dashboard.interface import BaseDashboardController
|
||||
from ray.dashboard.interface import BaseDashboardRouteHandler
|
||||
from ray.dashboard.memory import construct_memory_table, MemoryTable
|
||||
from ray.dashboard.metrics_exporter.client import Exporter
|
||||
from ray.dashboard.metrics_exporter.client import MetricsExportClient
|
||||
|
||||
@@ -126,6 +127,7 @@ class DashboardController(BaseDashboardController):
|
||||
redis_address, redis_password=redis_password)
|
||||
if Analysis is not None:
|
||||
self.tune_stats = TuneCollector(2.0)
|
||||
self.memory_table = MemoryTable([])
|
||||
|
||||
def _construct_raylet_info(self):
|
||||
D = self.raylet_stats.get_raylet_stats()
|
||||
@@ -133,6 +135,7 @@ class DashboardController(BaseDashboardController):
|
||||
data["nodeId"]: data.get("workersStats")
|
||||
for data in D.values()
|
||||
}
|
||||
|
||||
infeasible_tasks = sum(
|
||||
(data.get("infeasibleTasks", []) for data in D.values()), [])
|
||||
# ready_tasks are used to render tasks that are not schedulable
|
||||
@@ -142,6 +145,7 @@ class DashboardController(BaseDashboardController):
|
||||
[])
|
||||
actor_tree = self.node_stats.get_actor_tree(
|
||||
workers_info_by_node, infeasible_tasks, ready_tasks)
|
||||
|
||||
for address, data in D.items():
|
||||
# process view data
|
||||
measures_dicts = {}
|
||||
@@ -224,6 +228,23 @@ class DashboardController(BaseDashboardController):
|
||||
def get_raylet_info(self):
|
||||
return self._construct_raylet_info()
|
||||
|
||||
def get_memory_table_info(self) -> MemoryTable:
    """Rebuild the memory table from fresh raylet stats and return it.

    Switches ``include_memory_info`` on for the raylet stats collector,
    reads the current stats, and caches the resulting table on
    ``self.memory_table``.
    """
    # Collecting memory info adds big overhead to the cluster, so the flag
    # is only turned on when the table is actually requested.
    self.raylet_stats.include_memory_info = True
    raylet_stats = self.raylet_stats.get_raylet_stats()

    workers_by_node = {}
    for node_stats in raylet_stats.values():
        workers_by_node[node_stats["nodeId"]] = node_stats.get(
            "workersStats")

    self.memory_table = construct_memory_table(workers_by_node)
    return self.memory_table
|
||||
|
||||
def stop_collecting_memory_table_info(self):
    """Drop the cached memory table and stop requesting memory info."""
    # Replace the cached table with an empty one.
    self.memory_table = MemoryTable([])
    # Subsequent raylet stats requests no longer ask for memory info.
    self.raylet_stats.include_memory_info = False
|
||||
|
||||
def tune_info(self):
|
||||
if Analysis is not None:
|
||||
D = self.tune_stats.get_stats()
|
||||
@@ -313,6 +334,15 @@ class DashboardRouteHandler(BaseDashboardRouteHandler):
|
||||
result = self.dashboard_controller.get_raylet_info()
|
||||
return await json_response(self.is_dev, result=result)
|
||||
|
||||
async def memory_table_info(self, req) -> aiohttp.web.Response:
    """Serve the controller's current memory table as a JSON response."""
    table = self.dashboard_controller.get_memory_table_info()
    # MemoryTable exposes its serializable form through a __dict__() method.
    payload = table.__dict__()
    return await json_response(self.is_dev, result=payload)
|
||||
|
||||
async def stop_collecting_memory_table_info(self,
                                            req) -> aiohttp.web.Response:
    """Ask the controller to stop collecting memory info.

    Always answers with an empty result object.
    """
    controller = self.dashboard_controller
    controller.stop_collecting_memory_table_info()
    return await json_response(self.is_dev, result={})
|
||||
|
||||
async def tune_info(self, req) -> aiohttp.web.Response:
|
||||
result = self.dashboard_controller.tune_info()
|
||||
return await json_response(self.is_dev, result=result)
|
||||
@@ -462,7 +492,9 @@ def setup_dashboard_route(app: aiohttp.web.Application,
|
||||
get_profiling_info=None,
|
||||
kill_actor=None,
|
||||
logs=None,
|
||||
errors=None):
|
||||
errors=None,
|
||||
memory_table=None,
|
||||
stop_memory_table=None):
|
||||
def add_get_route(route, handler_func):
|
||||
if route is not None:
|
||||
app.router.add_get(route, handler_func)
|
||||
@@ -480,6 +512,8 @@ def setup_dashboard_route(app: aiohttp.web.Application,
|
||||
add_get_route(kill_actor, handler.kill_actor)
|
||||
add_get_route(logs, handler.logs)
|
||||
add_get_route(errors, handler.errors)
|
||||
add_get_route(memory_table, handler.memory_table_info)
|
||||
add_get_route(stop_memory_table, handler.stop_collecting_memory_table_info)
|
||||
|
||||
|
||||
class Dashboard:
|
||||
@@ -548,7 +582,9 @@ class Dashboard:
|
||||
get_profiling_info="/api/get_profiling_info",
|
||||
kill_actor="/api/kill_actor",
|
||||
logs="/api/logs",
|
||||
errors="/api/errors")
|
||||
errors="/api/errors",
|
||||
memory_table="/api/memory_table",
|
||||
stop_memory_table="/api/stop_memory_table")
|
||||
self.app.router.add_get("/{_}", route_handler.get_forbidden)
|
||||
self.app.router.add_post("/api/set_tune_experiment",
|
||||
route_handler.set_tune_experiment)
|
||||
@@ -849,6 +885,7 @@ class RayletStats(threading.Thread):
|
||||
self._profiling_stats = {}
|
||||
|
||||
self._update_nodes()
|
||||
self.include_memory_info = False
|
||||
|
||||
super().__init__()
|
||||
|
||||
@@ -950,7 +987,9 @@ class RayletStats(threading.Thread):
|
||||
node_id = node["NodeID"]
|
||||
stub = self.stubs[node_id]
|
||||
reply = stub.GetNodeStats(
|
||||
node_manager_pb2.GetNodeStatsRequest(), timeout=2)
|
||||
node_manager_pb2.GetNodeStatsRequest(
|
||||
include_memory_info=self.include_memory_info),
|
||||
timeout=2)
|
||||
reply_dict = MessageToDict(reply)
|
||||
reply_dict["nodeId"] = node_id
|
||||
replies[node["NodeManagerAddress"]] = reply_dict
|
||||
|
||||
@@ -0,0 +1,295 @@
|
||||
import base64
|
||||
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
import ray
|
||||
|
||||
from ray._raylet import (TaskID, ActorID, JobID)
|
||||
|
||||
# These values are used to calculate if objectIDs are actor handles.
|
||||
TASKID_BYTES_SIZE = TaskID.size()
|
||||
ACTORID_BYTES_SIZE = ActorID.size()
|
||||
JOBID_BYTES_SIZE = JobID.size()
|
||||
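# We multiply by 2 because these sizes are used to slice the hex string of an
# object ID, where every byte takes two characters (so, despite the names,
# the values below are hex-character counts rather than bit counts).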
|
||||
TASKID_RANDOM_BITS_SIZE = (TASKID_BYTES_SIZE - ACTORID_BYTES_SIZE) * 2
|
||||
ACTORID_RANDOM_BITS_SIZE = (ACTORID_BYTES_SIZE - JOBID_BYTES_SIZE) * 2
|
||||
|
||||
|
||||
def decode_object_id_if_needed(object_id: str) -> bytes:
    """Turn an object ID string into its binary form.

    A gRPC reply carries the object ID encoded as Base64, so that form is
    decoded here. Sometimes the object ID has already been converted to a
    hex string; in that case it is converted from hex to binary instead.
    """
    # A trailing "=" marks the Base64 form: object IDs are always 20B, so
    # their Base64 encoding always ends with padding.
    is_base64 = object_id.endswith("=")
    if not is_base64:
        return ray.utils.hex_to_binary(object_id)
    return base64.standard_b64decode(object_id)
|
||||
|
||||
|
||||
class SortingType(Enum):
    """Keys by which the entries of a memory table can be sorted."""

    PID = 1  # sort by the pid of the owning worker
    OBJECT_SIZE = 3  # sort by object size
    REFERENCE_TYPE = 4  # sort by the reference type string
|
||||
|
||||
|
||||
class GroupByType(Enum):
    """Keys by which the entries of a memory table can be grouped."""

    NODE_ADDRESS = 2  # group by the node address of the owning worker
|
||||
|
||||
|
||||
class ReferenceType:
    """String constants naming why an object reference is being held.

    Plain strings are used rather than an Enum because an enum member is
    not json serializable.
    """

    ACTOR_HANDLE = "ACTOR_HANDLE"
    PINNED_IN_MEMORY = "PINNED_IN_MEMORY"
    LOCAL_REFERENCE = "LOCAL_REFERENCE"
    USED_BY_PENDING_TASK = "USED_BY_PENDING_TASK"
    CAPTURED_IN_OBJECT = "CAPTURED_IN_OBJECT"
    UNKNOWN_STATUS = "UNKNOWN_STATUS"
|
||||
|
||||
|
||||
class MemoryTableEntry:
    """One object reference held by one worker, as a row of the memory table.

    Built from a single ``object_ref`` dict (keys such as ``objectId``,
    ``objectSize``, ``callSite``, ``localRefCount``, ``pinnedInMemory``,
    ``submittedTaskRefCount`` and ``containedInOwned``) plus the info of
    the worker that reported it.
    """

    def __init__(self, *, object_ref: dict, node_address: str, is_driver: bool,
                 pid: int):
        # worker info
        self.is_driver = is_driver
        self.pid = pid
        self.node_address = node_address

        # object info; a missing size is recorded as -1.
        self.object_size = int(object_ref.get("objectSize", -1))
        self.call_site = object_ref.get("callSite", "<Unknown>")
        # NOTE(review): "objectId" is the only key read without a default,
        # so a reference lacking it raises KeyError here, before is_valid()
        # gets a chance to reject it -- confirm that is intended.
        self.object_id = ray.ObjectID(
            decode_object_id_if_needed(object_ref["objectId"]))

        # reference info
        self.local_ref_count = int(object_ref.get("localRefCount", 0))
        self.pinned_in_memory = bool(object_ref.get("pinnedInMemory", False))
        self.submitted_task_ref_count = int(
            object_ref.get("submittedTaskRefCount", 0))
        self.contained_in_owned = [
            ray.ObjectID(decode_object_id_if_needed(object_id))
            for object_id in object_ref.get("containedInOwned", [])
        ]
        self.reference_type = self._get_reference_type()

    def is_valid(self) -> bool:
        """Return False for entries that hold no reference or a nil ID."""
        has_reference = (self.pinned_in_memory or self.local_ref_count != 0
                         or self.submitted_task_ref_count != 0
                         or len(self.contained_in_owned) != 0)
        if not has_reference:
            return False
        return not self.object_id.is_nil()

    def group_key(self, group_by_type: GroupByType) -> str:
        """Return the key this entry is grouped under.

        Raises:
            ValueError: if ``group_by_type`` is not a supported grouping.
        """
        if group_by_type == GroupByType.NODE_ADDRESS:
            return self.node_address
        raise ValueError("group by type {} is invalid.".format(group_by_type))

    def _get_reference_type(self) -> str:
        """Classify the reference; the first matching rule wins."""
        if self._is_object_id_actor_handle():
            return ReferenceType.ACTOR_HANDLE
        if self.pinned_in_memory:
            return ReferenceType.PINNED_IN_MEMORY
        if self.submitted_task_ref_count > 0:
            return ReferenceType.USED_BY_PENDING_TASK
        if self.local_ref_count > 0:
            return ReferenceType.LOCAL_REFERENCE
        if len(self.contained_in_owned) > 0:
            return ReferenceType.CAPTURED_IN_OBJECT
        return ReferenceType.UNKNOWN_STATUS

    def _is_object_id_actor_handle(self) -> bool:
        """Tell whether the object ID looks like an actor creation task's."""
        object_id_hex = self.object_id.hex()

        # random (8B) | ActorID(6B) | flag (2B) | index (4B) -- 20B in total
        # ActorID(6B) == ActorRandomByte(4B) + JobID(2B)
        # If random bytes are all 'f', but ActorRandomBytes
        # are not all 'f', that means it is an actor creation
        # task, which is an actor handle.
        task_end = TASKID_RANDOM_BITS_SIZE
        actor_end = task_end + ACTORID_RANDOM_BITS_SIZE
        random_bits = object_id_hex[:task_end]
        actor_random_bits = object_id_hex[task_end:actor_end]
        # Compare against the same widths used for slicing instead of
        # repeating them as literal numbers.
        return (random_bits == "f" * TASKID_RANDOM_BITS_SIZE
                and actor_random_bits != "f" * ACTORID_RANDOM_BITS_SIZE)

    def __dict__(self):
        """Return the json-serializable form of this entry.

        NOTE: this is deliberately a method named ``__dict__`` (callers
        invoke ``entry.__dict__()``), so ``entry.__dict__`` is this bound
        method rather than the usual attribute mapping.
        """
        return {
            "object_id": self.object_id.hex(),
            "pid": self.pid,
            "node_ip_address": self.node_address,
            "object_size": self.object_size,
            "reference_type": self.reference_type,
            "call_site": self.call_site,
            "local_ref_count": self.local_ref_count,
            "pinned_in_memory": self.pinned_in_memory,
            "submitted_task_ref_count": self.submitted_task_ref_count,
            "contained_in_owned": [
                object_id.hex() for object_id in self.contained_in_owned
            ],
            "type": "Driver" if self.is_driver else "Worker"
        }

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return str(self.__dict__())
|
||||
|
||||
|
||||
class MemoryTable:
    """A sorted, grouped and summarized collection of MemoryTableEntry.

    ``table`` holds the entries, ``group`` maps a group key to a nested
    MemoryTable containing the entries of that group, and ``summary``
    holds the aggregate counters computed by ``summarize``.
    """

    def __init__(self,
                 entries: List[MemoryTableEntry],
                 group_by_type: GroupByType = GroupByType.NODE_ADDRESS,
                 sort_by_type: SortingType = SortingType.PID):
        # Work on a copy: sorting and insert_entry operate in place and
        # must not reorder or grow the list owned by the caller.
        self.table = list(entries)
        # Group is a list of memory tables grouped by a group key.
        self.group = {}
        self.summary = defaultdict(int)
        if group_by_type and sort_by_type:
            self.setup(group_by_type, sort_by_type)
        elif group_by_type:
            self._group_by(group_by_type)
        elif sort_by_type:
            self._sort_by(sort_by_type)

    def setup(self, group_by_type: GroupByType, sort_by_type: SortingType):
        """Setup memory table.

        This will sort entries first and group them after.
        Sort order will be still kept.
        """
        # _group_by already summarizes every group table it builds, so only
        # the top-level summary is left to compute here.
        self._sort_by(sort_by_type)._group_by(group_by_type)
        self.summarize()
        return self

    def insert_entry(self, entry: MemoryTableEntry):
        """Append an entry; groups and summary are not refreshed."""
        self.table.append(entry)

    def summarize(self):
        """Recompute ``summary`` from the current entries and return self."""
        total_object_size = 0
        # One counter per reference type that is reported in the summary;
        # entries of any other type (e.g. UNKNOWN_STATUS) are not counted.
        counts = {
            ReferenceType.LOCAL_REFERENCE: 0,
            ReferenceType.PINNED_IN_MEMORY: 0,
            ReferenceType.USED_BY_PENDING_TASK: 0,
            ReferenceType.CAPTURED_IN_OBJECT: 0,
            ReferenceType.ACTOR_HANDLE: 0,
        }

        for entry in self.table:
            # Unknown sizes are stored as -1 and must not reduce the total.
            if entry.object_size > 0:
                total_object_size += entry.object_size
            if entry.reference_type in counts:
                counts[entry.reference_type] += 1

        self.summary = {
            "total_object_size": total_object_size,
            "total_local_ref_count": counts[ReferenceType.LOCAL_REFERENCE],
            "total_pinned_in_memory": counts[ReferenceType.PINNED_IN_MEMORY],
            "total_used_by_pending_task": counts[
                ReferenceType.USED_BY_PENDING_TASK],
            "total_captured_in_objects": counts[
                ReferenceType.CAPTURED_IN_OBJECT],
            "total_actor_handles": counts[ReferenceType.ACTOR_HANDLE]
        }
        return self

    def _sort_by(self, sorting_type: SortingType):
        """Sort ``table`` in place by the given key and return self.

        Raises:
            ValueError: if ``sorting_type`` is not a supported sorting.
        """
        if sorting_type == SortingType.PID:
            self.table.sort(key=lambda entry: entry.pid)
        elif sorting_type == SortingType.OBJECT_SIZE:
            self.table.sort(key=lambda entry: entry.object_size)
        elif sorting_type == SortingType.REFERENCE_TYPE:
            self.table.sort(key=lambda entry: entry.reference_type)
        else:
            raise ValueError(
                "Give sorting type: {} is invalid.".format(sorting_type))
        return self

    def _group_by(self, group_by_type: GroupByType):
        """Group entries and summarize the result.

        NOTE: Each group is another MemoryTable.
        """
        # Build entries per group; iteration order keeps the sort order.
        entries_by_key = defaultdict(list)
        for entry in self.table:
            entries_by_key[entry.group_key(group_by_type)].append(entry)

        # Build a group table per key. The nested tables are neither
        # grouped nor sorted again, only summarized.
        self.group = {
            group_key: MemoryTable(
                entries, group_by_type=None, sort_by_type=None).summarize()
            for group_key, entries in entries_by_key.items()
        }
        return self

    def __dict__(self):
        """Return the json-serializable form of the table.

        NOTE: this is deliberately a method named ``__dict__`` (callers
        invoke ``table.__dict__()``).
        """
        return {
            "summary": self.summary,
            "group": {
                group_key: {
                    "entries": group_memory_table.get_entries(),
                    "summary": group_memory_table.summary
                }
                for group_key, group_memory_table in self.group.items()
            }
        }

    def get_entries(self) -> List[dict]:
        """Return every entry in its json-serializable form."""
        return [entry.__dict__() for entry in self.table]

    def __repr__(self):
        return str(self.__dict__())

    def __str__(self):
        return self.__repr__()
|
||||
|
||||
|
||||
def construct_memory_table(workers_info_by_node: dict) -> MemoryTable:
    """Build a MemoryTable from the per-node worker stats.

    Args:
        workers_info_by_node: maps a node id to the list of worker stats
            dicts of that node. A value of None (which is what
            ``data.get("workersStats")`` yields when a node reports no
            worker stats) is treated as an empty list.

    Returns:
        A MemoryTable holding every valid object reference of every worker.
    """
    memory_table_entries = []
    # Only the worker stats are needed; the node id keys are not used.
    for worker_infos in workers_info_by_node.values():
        for worker_info in worker_infos or []:
            pid = worker_info["pid"]
            is_driver = worker_info.get("isDriver", False)
            core_worker_stats = worker_info["coreWorkerStats"]
            node_address = core_worker_stats["ipAddress"]

            for object_ref in core_worker_stats.get("objectRefs", []):
                memory_table_entry = MemoryTableEntry(
                    object_ref=object_ref,
                    node_address=node_address,
                    is_driver=is_driver,
                    pid=pid)
                # Entries without any reference or with a nil id are dropped.
                if memory_table_entry.is_valid():
                    memory_table_entries.append(memory_table_entry)

    return MemoryTable(memory_table_entries)
|
||||
@@ -10,9 +10,11 @@ from ray.core.generated import node_manager_pb2
|
||||
from ray.core.generated import node_manager_pb2_grpc
|
||||
from ray.core.generated import reporter_pb2
|
||||
from ray.core.generated import reporter_pb2_grpc
|
||||
from ray.dashboard.memory import (ReferenceType, decode_object_id_if_needed,
|
||||
MemoryTableEntry, MemoryTable, SortingType)
|
||||
from ray.test_utils import (RayTestTimeoutException,
|
||||
wait_until_succeeded_without_exception,
|
||||
wait_until_server_available)
|
||||
wait_until_server_available, wait_for_condition)
|
||||
|
||||
import psutil # We must import psutil after ray because we bundle it with ray.
|
||||
|
||||
@@ -377,6 +379,446 @@ def test_profiling_info_endpoint(shutdown_only):
|
||||
assert profiling_stats is not None
|
||||
|
||||
|
||||
# This variable is used inside test_memory_dashboard.
|
||||
# It is defined as a global variable to be used across all nested test
|
||||
# functions. We use it because memory table is updated every one second,
|
||||
# and we need to have a way to verify if the test is running with a fresh
|
||||
# new memory table.
|
||||
prev_memory_table = MemoryTable([]).__dict__()["group"]
|
||||
|
||||
|
||||
def test_memory_dashboard(shutdown_only):
    """Test Memory table.

    These tests verify examples in this document.
    https://docs.ray.io/en/latest/memory-management.html#debugging-using-ray-memory
    """
    addresses = ray.init(num_cpus=2)
    webui_url = addresses["webui_url"].replace("localhost", "http://127.0.0.1")
    assert (wait_until_server_available(addresses["webui_url"]) is True)

    def get_memory_table():
        memory_table = requests.get(webui_url + "/api/memory_table").json()
        return memory_table["result"]

    def memory_table_ready():
        """Wait until the new fresh memory table is ready."""
        # The table counts as fresh once its groups differ from the ones
        # seen by the previous poll; the module-level variable carries that
        # state across the nested tests.
        global prev_memory_table
        current_group = get_memory_table()["group"]
        is_ready = current_group != prev_memory_table
        prev_memory_table = current_group
        return is_ready

    def stop_memory_table():
        requests.get(webui_url + "/api/stop_memory_table").json()

    def test_local_reference():
        @ray.remote
        def f(arg):
            return arg

        # a and b are local references.
        a = ray.put(None)  # Noqa F841
        b = f.remote(None)  # Noqa F841

        wait_for_condition(memory_table_ready)
        memory_table = get_memory_table()
        summary = memory_table["summary"]
        group = memory_table["group"]
        assert summary["total_captured_in_objects"] == 0
        assert summary["total_pinned_in_memory"] == 0
        assert summary["total_used_by_pending_task"] == 0
        assert summary["total_local_ref_count"] == 2
        for table in group.values():
            for entry in table["entries"]:
                assert (
                    entry["reference_type"] == ReferenceType.LOCAL_REFERENCE)
        stop_memory_table()
        return True

    def test_object_pinned_in_memory():
        import numpy as np

        a = ray.put(np.zeros(1))
        b = ray.get(a)  # Noqa F841
        del a

        wait_for_condition(memory_table_ready)
        memory_table = get_memory_table()
        summary = memory_table["summary"]
        group = memory_table["group"]
        assert summary["total_captured_in_objects"] == 0
        assert summary["total_pinned_in_memory"] == 1
        assert summary["total_used_by_pending_task"] == 0
        assert summary["total_local_ref_count"] == 0
        for table in group.values():
            for entry in table["entries"]:
                assert (
                    entry["reference_type"] == ReferenceType.PINNED_IN_MEMORY)
        stop_memory_table()
        return True

    def test_pending_task_references():
        @ray.remote
        def f(arg):
            time.sleep(1)

        a = ray.put(None)  # Noqa F841
        b = f.remote(a)  # Noqa F841

        wait_for_condition(memory_table_ready)
        memory_table = get_memory_table()
        summary = memory_table["summary"]
        assert summary["total_captured_in_objects"] == 0
        assert summary["total_pinned_in_memory"] == 1
        assert summary["total_used_by_pending_task"] == 1
        assert summary["total_local_ref_count"] == 1
        # Make sure the function f is done before going to the next test.
        # Otherwise, the memory table will be corrupted because the
        # task f won't be done when the next test is running.
        ray.get(b)
        stop_memory_table()
        return True

    def test_serialized_object_id_reference():
        @ray.remote
        def f(arg):
            time.sleep(1)

        a = ray.put(None)  # Noqa F841
        b = f.remote([a])  # Noqa F841

        wait_for_condition(memory_table_ready)
        memory_table = get_memory_table()
        summary = memory_table["summary"]
        assert summary["total_captured_in_objects"] == 0
        assert summary["total_pinned_in_memory"] == 0
        assert summary["total_used_by_pending_task"] == 1
        assert summary["total_local_ref_count"] == 2
        # Make sure the function f is done before going to the next test.
        # Otherwise, the memory table will be corrupted because the
        # task f won't be done when the next test is running.
        ray.get(b)
        stop_memory_table()
        return True

    def test_captured_object_id_reference():
        a = ray.put(None)
        b = ray.put([a])  # Noqa F841
        del a

        wait_for_condition(memory_table_ready)
        memory_table = get_memory_table()
        summary = memory_table["summary"]
        assert summary["total_captured_in_objects"] == 1
        assert summary["total_pinned_in_memory"] == 0
        assert summary["total_used_by_pending_task"] == 0
        assert summary["total_local_ref_count"] == 1
        stop_memory_table()
        return True

    def test_actor_handle_reference():
        @ray.remote
        class Actor:
            pass

        a = Actor.remote()  # Noqa F841
        b = Actor.remote()  # Noqa F841
        c = Actor.remote()  # Noqa F841

        wait_for_condition(memory_table_ready)
        memory_table = get_memory_table()
        summary = memory_table["summary"]
        group = memory_table["group"]
        assert summary["total_captured_in_objects"] == 0
        assert summary["total_pinned_in_memory"] == 0
        assert summary["total_used_by_pending_task"] == 0
        assert summary["total_local_ref_count"] == 0
        assert summary["total_actor_handles"] == 3
        for table in group.values():
            for entry in table["entries"]:
                assert (entry["reference_type"] == ReferenceType.ACTOR_HANDLE)
        stop_memory_table()
        return True

    # These tests should be retried because it takes at least one second
    # to get the fresh new memory table. It is because memory table is updated
    # whenever raylet and node info is renewed which takes 1 second.
    assert (wait_for_condition(
        test_local_reference, timeout=30000, retry_interval_ms=1000) is True)

    assert (wait_for_condition(
        test_object_pinned_in_memory, timeout=30000, retry_interval_ms=1000) is
            True)

    assert (wait_for_condition(
        test_pending_task_references, timeout=30000, retry_interval_ms=1000) is
            True)

    assert (wait_for_condition(
        test_serialized_object_id_reference,
        timeout=30000,
        retry_interval_ms=1000) is True)

    assert (wait_for_condition(
        test_captured_object_id_reference,
        timeout=30000,
        retry_interval_ms=1000) is True)

    assert (wait_for_condition(
        test_actor_handle_reference, timeout=30000, retry_interval_ms=1000) is
            True)
|
||||
|
||||
|
||||
"""Memory Table Unit Test"""
|
||||
|
||||
NODE_ADDRESS = "127.0.0.1"
|
||||
IS_DRIVER = True
|
||||
PID = 1
|
||||
OBJECT_ID = "7wpsIhgZiBz/////AQAAyAEAAAA="
|
||||
ACTOR_ID = "fffffffffffffffff66d17ba010000c801000000"
|
||||
DECODED_ID = decode_object_id_if_needed(OBJECT_ID)
|
||||
OBJECT_SIZE = 100
|
||||
|
||||
|
||||
def build_memory_entry(*,
                       local_ref_count,
                       pinned_in_memory,
                       submitted_task_reference_count,
                       contained_in_owned,
                       object_size,
                       pid,
                       object_id=OBJECT_ID,
                       node_address=NODE_ADDRESS):
    """Create a MemoryTableEntry from explicit reference counters.

    The entry always belongs to a driver (``IS_DRIVER``) and carries a
    fixed call site.
    """
    # Same keys as the object reference dicts consumed by MemoryTableEntry.
    object_ref = dict(
        objectId=object_id,
        callSite="(task call) /Users:458",
        objectSize=object_size,
        localRefCount=local_ref_count,
        pinnedInMemory=pinned_in_memory,
        submittedTaskRefCount=submitted_task_reference_count,
        containedInOwned=contained_in_owned)
    entry = MemoryTableEntry(
        object_ref=object_ref,
        node_address=node_address,
        is_driver=IS_DRIVER,
        pid=pid)
    return entry
|
||||
|
||||
|
||||
def build_local_reference_entry(object_size=OBJECT_SIZE,
                                pid=PID,
                                node_address=NODE_ADDRESS):
    """Create an entry whose only reference is one local reference."""
    reference_fields = dict(
        local_ref_count=1,
        pinned_in_memory=False,
        submitted_task_reference_count=0,
        contained_in_owned=[])
    return build_memory_entry(
        object_size=object_size,
        pid=pid,
        node_address=node_address,
        **reference_fields)
|
||||
|
||||
|
||||
def build_used_by_pending_task_entry(object_size=OBJECT_SIZE,
                                     pid=PID,
                                     node_address=NODE_ADDRESS):
    """Create an entry referenced only by submitted tasks (count of 2)."""
    reference_fields = dict(
        local_ref_count=0,
        pinned_in_memory=False,
        submitted_task_reference_count=2,
        contained_in_owned=[])
    return build_memory_entry(
        object_size=object_size,
        pid=pid,
        node_address=node_address,
        **reference_fields)
|
||||
|
||||
|
||||
def build_captured_in_object_entry(object_size=OBJECT_SIZE,
                                   pid=PID,
                                   node_address=NODE_ADDRESS):
    """Create an entry that is only referenced from inside another object."""
    reference_fields = dict(
        local_ref_count=0,
        pinned_in_memory=False,
        submitted_task_reference_count=0,
        contained_in_owned=[OBJECT_ID])
    return build_memory_entry(
        object_size=object_size,
        pid=pid,
        node_address=node_address,
        **reference_fields)
|
||||
|
||||
|
||||
def build_actor_handle_entry(object_size=OBJECT_SIZE,
                             pid=PID,
                             node_address=NODE_ADDRESS):
    """Create an entry whose object id is ``ACTOR_ID``, held locally once."""
    reference_fields = dict(
        local_ref_count=1,
        pinned_in_memory=False,
        submitted_task_reference_count=0,
        contained_in_owned=[])
    return build_memory_entry(
        object_size=object_size,
        pid=pid,
        node_address=node_address,
        object_id=ACTOR_ID,
        **reference_fields)
|
||||
|
||||
|
||||
def build_pinned_in_memory_entry(object_size=OBJECT_SIZE,
                                 pid=PID,
                                 node_address=NODE_ADDRESS):
    """Create an entry that is pinned in memory and has no other reference."""
    reference_fields = dict(
        local_ref_count=0,
        pinned_in_memory=True,
        submitted_task_reference_count=0,
        contained_in_owned=[])
    return build_memory_entry(
        object_size=object_size,
        pid=pid,
        node_address=node_address,
        **reference_fields)
|
||||
|
||||
|
||||
def build_entry(object_size=OBJECT_SIZE,
                pid=PID,
                node_address=NODE_ADDRESS,
                reference_type=ReferenceType.PINNED_IN_MEMORY):
    """Create an entry of the requested reference type.

    Raises:
        ValueError: if there is no builder for ``reference_type``. Without
            this, an unsupported type would silently produce ``None`` and
            only fail later, far from the cause.
    """
    builders = {
        ReferenceType.USED_BY_PENDING_TASK: build_used_by_pending_task_entry,
        ReferenceType.LOCAL_REFERENCE: build_local_reference_entry,
        ReferenceType.PINNED_IN_MEMORY: build_pinned_in_memory_entry,
        ReferenceType.ACTOR_HANDLE: build_actor_handle_entry,
        ReferenceType.CAPTURED_IN_OBJECT: build_captured_in_object_entry,
    }
    builder = builders.get(reference_type)
    if builder is None:
        raise ValueError(
            "No entry builder for reference type {}.".format(reference_type))
    return builder(
        pid=pid, object_size=object_size, node_address=node_address)
|
||||
|
||||
|
||||
def test_invalid_memory_entry():
    """An entry without any reference is invalid, whatever its size."""
    for object_size in (OBJECT_SIZE, -1):
        memory_entry = build_memory_entry(
            local_ref_count=0,
            pinned_in_memory=False,
            submitted_task_reference_count=0,
            contained_in_owned=[],
            object_size=object_size,
            pid=PID)
        assert memory_entry.is_valid() is False
|
||||
|
||||
|
||||
def test_valid_reference_memory_entry():
    """A locally referenced entry is valid and keeps its decoded object id."""
    entry = build_local_reference_entry()
    assert entry.reference_type == ReferenceType.LOCAL_REFERENCE
    expected_object_id = ray.ObjectID(decode_object_id_if_needed(OBJECT_ID))
    assert entry.object_id == expected_object_id
    assert entry.is_valid() is True
|
||||
|
||||
|
||||
def test_reference_type():
    """Each builder yields an entry classified with the matching type."""
    cases = [
        # pinned in memory
        (build_pinned_in_memory_entry, ReferenceType.PINNED_IN_MEMORY),
        # used by pending task
        (build_used_by_pending_task_entry,
         ReferenceType.USED_BY_PENDING_TASK),
        # captured in object
        (build_captured_in_object_entry, ReferenceType.CAPTURED_IN_OBJECT),
        # actor handle
        (build_actor_handle_entry, ReferenceType.ACTOR_HANDLE),
    ]
    for build, expected_type in cases:
        memory_entry = build()
        assert memory_entry.reference_type == expected_type
|
||||
|
||||
|
||||
def test_memory_table_summary():
    """The summary counts one entry per type, two local ones, and all sizes."""
    builders = [
        build_pinned_in_memory_entry,
        build_used_by_pending_task_entry,
        build_captured_in_object_entry,
        build_actor_handle_entry,
        build_local_reference_entry,
        build_local_reference_entry,
    ]
    entries = [build() for build in builders]
    memory_table = MemoryTable(entries)
    # Every entry uses the default node address, hence a single group.
    assert len(memory_table.group) == 1

    expected_summary = {
        "total_actor_handles": 1,
        "total_captured_in_objects": 1,
        "total_local_ref_count": 2,
        "total_object_size": len(entries) * OBJECT_SIZE,
        "total_pinned_in_memory": 1,
        "total_used_by_pending_task": 1,
    }
    for key, expected in expected_summary.items():
        assert memory_table.summary[key] == expected
|
||||
|
||||
|
||||
def test_memory_table_sort_by_pid():
    """Sorting by PID orders the table entries by ascending pid."""
    pids = [1, 3, 2]
    memory_table = MemoryTable(
        [build_entry(pid=pid) for pid in pids], sort_by_type=SortingType.PID)
    expected_pids = sorted(pids)
    for index, entry in enumerate(memory_table.table):
        assert expected_pids[index] == entry.pid
|
||||
|
||||
|
||||
def test_memory_table_sort_by_reference_type():
    """Sorting by REFERENCE_TYPE orders entries by their type string."""
    reference_types = [
        ReferenceType.USED_BY_PENDING_TASK, ReferenceType.LOCAL_REFERENCE,
        ReferenceType.LOCAL_REFERENCE, ReferenceType.PINNED_IN_MEMORY
    ]
    entries = []
    for reference_type in reference_types:
        entries.append(build_entry(reference_type=reference_type))
    memory_table = MemoryTable(
        entries, sort_by_type=SortingType.REFERENCE_TYPE)
    expected_types = sorted(reference_types)
    for index, entry in enumerate(memory_table.table):
        assert expected_types[index] == entry.reference_type
|
||||
|
||||
|
||||
def test_memory_table_sort_by_object_size():
    """Sorting by OBJECT_SIZE orders entries by ascending size (-1 first)."""
    object_sizes = [312, 214, -1, 1244, 642]
    memory_table = MemoryTable(
        [build_entry(object_size=size) for size in object_sizes],
        sort_by_type=SortingType.OBJECT_SIZE)
    expected_sizes = sorted(object_sizes)
    for index, entry in enumerate(memory_table.table):
        assert expected_sizes[index] == entry.object_size
|
||||
|
||||
|
||||
def test_group_by():
    """Entries are grouped per node address and stay pid-sorted per group."""
    node_first = "127.0.0.1"
    node_second = "127.0.0.2"
    # Deliberately out of order, both by node and by pid.
    placements = [(node_second, 2), (node_second, 1), (node_first, 2),
                  (node_first, 1)]
    entries = [
        build_entry(node_address=node_address, pid=pid)
        for node_address, pid in placements
    ]
    memory_table = MemoryTable(entries)

    # Make sure it is correctly grouped
    assert node_first in memory_table.group
    assert node_second in memory_table.group

    # make sure pid is sorted in the right order.
    for group_memory_table in memory_table.group.values():
        for expected_pid, entry in enumerate(group_memory_table.table, 1):
            assert expected_pid == entry.pid
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
Reference in New Issue
Block a user