[Dashboard] Ray memory dashboard backend (#8461)

This commit is contained in:
SangBin Cho
2020-05-21 12:22:28 -07:00
committed by GitHub
parent 9a83908c46
commit aa1cbe8abc
3 changed files with 780 additions and 4 deletions
+42 -3
View File
@@ -38,6 +38,7 @@ from ray.core.generated import core_worker_pb2
from ray.core.generated import core_worker_pb2_grpc
from ray.dashboard.interface import BaseDashboardController
from ray.dashboard.interface import BaseDashboardRouteHandler
from ray.dashboard.memory import construct_memory_table, MemoryTable
from ray.dashboard.metrics_exporter.client import Exporter
from ray.dashboard.metrics_exporter.client import MetricsExportClient
@@ -126,6 +127,7 @@ class DashboardController(BaseDashboardController):
redis_address, redis_password=redis_password)
if Analysis is not None:
self.tune_stats = TuneCollector(2.0)
self.memory_table = MemoryTable([])
def _construct_raylet_info(self):
D = self.raylet_stats.get_raylet_stats()
@@ -133,6 +135,7 @@ class DashboardController(BaseDashboardController):
data["nodeId"]: data.get("workersStats")
for data in D.values()
}
infeasible_tasks = sum(
(data.get("infeasibleTasks", []) for data in D.values()), [])
# ready_tasks are used to render tasks that are not schedulable
@@ -142,6 +145,7 @@ class DashboardController(BaseDashboardController):
[])
actor_tree = self.node_stats.get_actor_tree(
workers_info_by_node, infeasible_tasks, ready_tasks)
for address, data in D.items():
# process view data
measures_dicts = {}
@@ -224,6 +228,23 @@ class DashboardController(BaseDashboardController):
def get_raylet_info(self):
    """Return whatever ``_construct_raylet_info`` builds for this controller."""
    raylet_info = self._construct_raylet_info()
    return raylet_info
def get_memory_table_info(self) -> MemoryTable:
    """Turn on memory-info collection and rebuild the memory table.

    Returns:
        The newly constructed MemoryTable. It is also stored on
        ``self.memory_table``.
    """
    # Collecting memory info adds big overhead to the cluster.
    # This must be collected only when it is necessary.
    self.raylet_stats.include_memory_info = True
    raylet_stats = self.raylet_stats.get_raylet_stats()
    # "workersStats" may be missing from a node's reply (presumably when the
    # node reports no workers). Fall back to an empty list so that
    # construct_memory_table never has to iterate over None.
    workers_info_by_node = {
        data["nodeId"]: data.get("workersStats") or []
        for data in raylet_stats.values()
    }
    self.memory_table = construct_memory_table(workers_info_by_node)
    return self.memory_table
def stop_collecting_memory_table_info(self):
    """Discard the cached memory table and switch memory-info collection off."""
    # Swap the cached table for an empty one.
    empty_table = MemoryTable([])
    self.memory_table = empty_table
    # Tell the raylet stats collector to stop requesting memory info.
    self.raylet_stats.include_memory_info = False
def tune_info(self):
if Analysis is not None:
D = self.tune_stats.get_stats()
@@ -313,6 +334,15 @@ class DashboardRouteHandler(BaseDashboardRouteHandler):
result = self.dashboard_controller.get_raylet_info()
return await json_response(self.is_dev, result=result)
async def memory_table_info(self, req) -> aiohttp.web.Response:
    """Respond with the controller's memory table converted to a dict."""
    table = self.dashboard_controller.get_memory_table_info()
    payload = table.__dict__()
    return await json_response(self.is_dev, result=payload)
async def stop_collecting_memory_table_info(self,
                                            req) -> aiohttp.web.Response:
    """Ask the controller to stop memory collection; reply with an empty dict."""
    controller = self.dashboard_controller
    controller.stop_collecting_memory_table_info()
    empty_result = {}
    return await json_response(self.is_dev, result=empty_result)
async def tune_info(self, req) -> aiohttp.web.Response:
    """Respond with the controller's Tune info."""
    tune_result = self.dashboard_controller.tune_info()
    response = await json_response(self.is_dev, result=tune_result)
    return response
@@ -462,7 +492,9 @@ def setup_dashboard_route(app: aiohttp.web.Application,
get_profiling_info=None,
kill_actor=None,
logs=None,
errors=None):
errors=None,
memory_table=None,
stop_memory_table=None):
def add_get_route(route, handler_func):
if route is not None:
app.router.add_get(route, handler_func)
@@ -480,6 +512,8 @@ def setup_dashboard_route(app: aiohttp.web.Application,
add_get_route(kill_actor, handler.kill_actor)
add_get_route(logs, handler.logs)
add_get_route(errors, handler.errors)
add_get_route(memory_table, handler.memory_table_info)
add_get_route(stop_memory_table, handler.stop_collecting_memory_table_info)
class Dashboard:
@@ -548,7 +582,9 @@ class Dashboard:
get_profiling_info="/api/get_profiling_info",
kill_actor="/api/kill_actor",
logs="/api/logs",
errors="/api/errors")
errors="/api/errors",
memory_table="/api/memory_table",
stop_memory_table="/api/stop_memory_table")
self.app.router.add_get("/{_}", route_handler.get_forbidden)
self.app.router.add_post("/api/set_tune_experiment",
route_handler.set_tune_experiment)
@@ -849,6 +885,7 @@ class RayletStats(threading.Thread):
self._profiling_stats = {}
self._update_nodes()
self.include_memory_info = False
super().__init__()
@@ -950,7 +987,9 @@ class RayletStats(threading.Thread):
node_id = node["NodeID"]
stub = self.stubs[node_id]
reply = stub.GetNodeStats(
node_manager_pb2.GetNodeStatsRequest(), timeout=2)
node_manager_pb2.GetNodeStatsRequest(
include_memory_info=self.include_memory_info),
timeout=2)
reply_dict = MessageToDict(reply)
reply_dict["nodeId"] = node_id
replies[node["NodeManagerAddress"]] = reply_dict
+295
View File
@@ -0,0 +1,295 @@
import base64
from collections import defaultdict
from enum import Enum
from typing import List
import ray
from ray._raylet import (TaskID, ActorID, JobID)
# ID sizes in bytes. These values are used to calculate if objectIDs are
# actor handles.
TASKID_BYTES_SIZE = TaskID.size()
ACTORID_BYTES_SIZE = ActorID.size()
JOBID_BYTES_SIZE = JobID.size()
# NOTE: despite the "BITS" in their names, these two values are lengths in
# hexadecimal characters, not bits. One byte is two hex digits, hence the
# multiplication by 2. They are used to slice the hex string of an object ID.
TASKID_RANDOM_BITS_SIZE = (TASKID_BYTES_SIZE - ACTORID_BYTES_SIZE) * 2
ACTORID_RANDOM_BITS_SIZE = (ACTORID_BYTES_SIZE - JOBID_BYTES_SIZE) * 2
def decode_object_id_if_needed(object_id: str) -> bytes:
    """Convert an object ID string into its binary form.

    Object IDs in a gRPC reply are Base64 encoded, so those are Base64
    decoded here. Sometimes the ID has already been decoded into a hex
    string; in that case the hex string is converted to binary instead.
    """
    # A trailing "=" marks Base64: a 20-byte object ID always needs padding
    # when it is Base64 encoded.
    is_base64 = object_id.endswith("=")
    if not is_base64:
        return ray.utils.hex_to_binary(object_id)
    return base64.standard_b64decode(object_id)
class SortingType(Enum):
    """Attributes of a memory table entry that the table can be sorted by."""

    # Sort by the pid of the worker holding the reference.
    PID = 1
    # Sort by the size of the referenced object.
    OBJECT_SIZE = 3
    # Sort by the reference type string.
    REFERENCE_TYPE = 4
class GroupByType(Enum):
    """Attributes of a memory table entry that the table can be grouped by."""

    # Group entries by the address of the node they were reported from.
    NODE_ADDRESS = 2
class ReferenceType:
    """String constants naming the kinds of reference shown in the table."""

    # Plain strings are used here rather than an enum because enum is not
    # json serializable.
    ACTOR_HANDLE = "ACTOR_HANDLE"
    PINNED_IN_MEMORY = "PINNED_IN_MEMORY"
    LOCAL_REFERENCE = "LOCAL_REFERENCE"
    USED_BY_PENDING_TASK = "USED_BY_PENDING_TASK"
    CAPTURED_IN_OBJECT = "CAPTURED_IN_OBJECT"
    UNKNOWN_STATUS = "UNKNOWN_STATUS"
class MemoryTableEntry:
    """One object reference reported by a worker, as shown in the memory table.

    Args:
        object_ref: Dict describing the reference. "objectId" is required;
            "objectSize", "callSite", "localRefCount", "pinnedInMemory",
            "submittedTaskRefCount" and "containedInOwned" are optional.
        node_address: Address of the node the worker runs on.
        is_driver: Whether the reporting worker is a driver.
        pid: Process id of the reporting worker.
    """

    def __init__(self, *, object_ref: dict, node_address: str, is_driver: bool,
                 pid: int):
        # worker info
        self.is_driver = is_driver
        self.pid = pid
        self.node_address = node_address

        # object info
        # -1 stands for "size not reported".
        self.object_size = int(object_ref.get("objectSize", -1))
        self.call_site = object_ref.get("callSite", "<Unknown>")
        self.object_id = ray.ObjectID(
            decode_object_id_if_needed(object_ref["objectId"]))

        # reference info
        self.local_ref_count = int(object_ref.get("localRefCount", 0))
        self.pinned_in_memory = bool(object_ref.get("pinnedInMemory", False))
        self.submitted_task_ref_count = int(
            object_ref.get("submittedTaskRefCount", 0))
        self.contained_in_owned = [
            ray.ObjectID(decode_object_id_if_needed(object_id))
            for object_id in object_ref.get("containedInOwned", [])
        ]
        # Must be computed last: it reads the fields assigned above.
        self.reference_type = self._get_reference_type()

    def is_valid(self) -> bool:
        """Return False for entries with no reference at all or a nil ID."""
        has_reference = (self.pinned_in_memory or self.local_ref_count != 0
                         or self.submitted_task_ref_count != 0
                         or len(self.contained_in_owned) != 0)
        return has_reference and not self.object_id.is_nil()

    def group_key(self, group_by_type: GroupByType) -> str:
        """Return this entry's key for the given grouping.

        Raises:
            ValueError: If ``group_by_type`` is not a supported grouping.
        """
        if group_by_type == GroupByType.NODE_ADDRESS:
            return self.node_address
        raise ValueError("group by type {} is invalid.".format(group_by_type))

    def _get_reference_type(self) -> str:
        """Classify the reference; the first matching rule below wins."""
        if self._is_object_id_actor_handle():
            return ReferenceType.ACTOR_HANDLE
        if self.pinned_in_memory:
            return ReferenceType.PINNED_IN_MEMORY
        if self.submitted_task_ref_count > 0:
            return ReferenceType.USED_BY_PENDING_TASK
        if self.local_ref_count > 0:
            return ReferenceType.LOCAL_REFERENCE
        if len(self.contained_in_owned) > 0:
            return ReferenceType.CAPTURED_IN_OBJECT
        return ReferenceType.UNKNOWN_STATUS

    def _is_object_id_actor_handle(self) -> bool:
        """Tell from the hex form of the object ID whether it is an actor handle."""
        object_id_hex = self.object_id.hex()
        # random (8B) | ActorID(6B) | flag (2B) | index (6B)
        # ActorID(6B) == ActorRandomByte(4B) + JobID(2B)
        # If random bytes are all 'f', but ActorRandomBytes
        # are not all 'f', that means it is an actor creation
        # task, which is an actor handle.
        task_random_end = TASKID_RANDOM_BITS_SIZE
        actor_random_end = task_random_end + ACTORID_RANDOM_BITS_SIZE
        random_bits = object_id_hex[:task_random_end]
        actor_random_bits = object_id_hex[task_random_end:actor_random_end]
        # Build the all-'f' patterns from the same constants used for slicing
        # so that the comparison lengths always match the slice lengths.
        return (random_bits == "f" * TASKID_RANDOM_BITS_SIZE
                and actor_random_bits != "f" * ACTORID_RANDOM_BITS_SIZE)

    def __dict__(self):
        """Return a JSON-friendly dict describing this entry.

        NOTE: this is a method, so it must be called as ``entry.__dict__()``.
        The name shadows the ``__dict__`` attribute instances normally expose;
        it is kept because existing callers invoke it this way.
        """
        return {
            "object_id": self.object_id.hex(),
            "pid": self.pid,
            "node_ip_address": self.node_address,
            "object_size": self.object_size,
            "reference_type": self.reference_type,
            "call_site": self.call_site,
            "local_ref_count": self.local_ref_count,
            "pinned_in_memory": self.pinned_in_memory,
            "submitted_task_ref_count": self.submitted_task_ref_count,
            "contained_in_owned": [
                object_id.hex() for object_id in self.contained_in_owned
            ],
            "type": "Driver" if self.is_driver else "Worker"
        }

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return str(self.__dict__())
class MemoryTable:
    """Memory table entries with optional sorting, grouping and a summary.

    Args:
        entries: Entries to put in the table. The list is copied, so the
            caller's list is never reordered or appended to.
        group_by_type: How to group entries, or None to skip grouping.
        sort_by_type: How to sort entries, or None to skip sorting.
    """

    def __init__(self,
                 entries: List[MemoryTableEntry],
                 group_by_type: GroupByType = GroupByType.NODE_ADDRESS,
                 sort_by_type: SortingType = SortingType.PID):
        # Copy the input: _sort_by sorts in place and insert_entry appends,
        # and neither should leak into the list the caller passed in.
        self.table = list(entries)
        # Maps a group key to the MemoryTable holding that group's entries.
        self.group = {}
        self.summary = defaultdict(int)
        if group_by_type and sort_by_type:
            self.setup(group_by_type, sort_by_type)
        elif group_by_type:
            self._group_by(group_by_type)
        elif sort_by_type:
            self._sort_by(sort_by_type)

    def setup(self, group_by_type: GroupByType, sort_by_type: SortingType):
        """Sort, group and summarize the table.

        Entries are sorted first and grouped afterwards, so the sort order
        is kept inside every group.
        """
        # _group_by already summarizes every group table, so only the
        # top-level summary is computed here.
        self._sort_by(sort_by_type)._group_by(group_by_type)
        self.summarize()
        return self

    def insert_entry(self, entry: MemoryTableEntry):
        """Append an entry; sorting, groups and summary are not refreshed."""
        self.table.append(entry)

    def summarize(self):
        """Recompute ``self.summary`` from the entries currently in the table."""
        summary_key_by_type = {
            ReferenceType.LOCAL_REFERENCE: "total_local_ref_count",
            ReferenceType.PINNED_IN_MEMORY: "total_pinned_in_memory",
            ReferenceType.USED_BY_PENDING_TASK: "total_used_by_pending_task",
            ReferenceType.CAPTURED_IN_OBJECT: "total_captured_in_objects",
            ReferenceType.ACTOR_HANDLE: "total_actor_handles",
        }
        # Start every counter at zero so all keys are always present.
        summary = {"total_object_size": 0}
        for summary_key in summary_key_by_type.values():
            summary[summary_key] = 0

        for entry in self.table:
            # Non-positive sizes (e.g. -1 for "unknown") are not summed.
            if entry.object_size > 0:
                summary["total_object_size"] += entry.object_size
            # Reference types without a counter (UNKNOWN_STATUS) are skipped.
            summary_key = summary_key_by_type.get(entry.reference_type)
            if summary_key is not None:
                summary[summary_key] += 1

        self.summary = summary
        return self

    def _sort_by(self, sorting_type: SortingType):
        """Sort the table in place by the given attribute.

        Raises:
            ValueError: If ``sorting_type`` is not a supported sorting type.
        """
        if sorting_type == SortingType.PID:
            self.table.sort(key=lambda entry: entry.pid)
        elif sorting_type == SortingType.OBJECT_SIZE:
            self.table.sort(key=lambda entry: entry.object_size)
        elif sorting_type == SortingType.REFERENCE_TYPE:
            self.table.sort(key=lambda entry: entry.reference_type)
        else:
            raise ValueError(
                "Give sorting type: {} is invalid.".format(sorting_type))
        return self

    def _group_by(self, group_by_type: GroupByType):
        """Group entries and summarize the result.

        NOTE: Each group is another MemoryTable.
        """
        # Reset group
        self.group = {}

        # Collect entries per group key, keeping the table's current order.
        entries_by_key = defaultdict(list)
        for entry in self.table:
            entries_by_key[entry.group_key(group_by_type)].append(entry)

        # Each group table is neither sorted nor grouped again, only
        # summarized.
        for group_key, entries in entries_by_key.items():
            group_table = MemoryTable(
                entries, group_by_type=None, sort_by_type=None)
            self.group[group_key] = group_table.summarize()
        return self

    def __dict__(self):
        """Return the summary and every group as a JSON-friendly dict.

        NOTE: this is a method, so it must be called as ``table.__dict__()``.
        """
        return {
            "summary": self.summary,
            "group": {
                group_key: {
                    "entries": group_memory_table.get_entries(),
                    "summary": group_memory_table.summary
                }
                for group_key, group_memory_table in self.group.items()
            }
        }

    def get_entries(self) -> List[dict]:
        """Return every entry of this table converted to a dict."""
        return [entry.__dict__() for entry in self.table]

    def __repr__(self):
        return str(self.__dict__())

    def __str__(self):
        return self.__repr__()
def construct_memory_table(workers_info_by_node: dict) -> MemoryTable:
    """Build a MemoryTable from per-node worker stats.

    Args:
        workers_info_by_node: Maps a node id to that node's list of worker
            info dicts. A value of None is treated as "no workers".

    Returns:
        A MemoryTable holding every valid object reference reported by the
        workers.
    """
    memory_table_entries = []
    # Only the worker lists are needed; the node ids are not used here.
    for worker_infos in workers_info_by_node.values():
        # Callers build this mapping with dict.get(), so a node without
        # worker stats shows up as None rather than an empty list.
        for worker_info in worker_infos or []:
            pid = worker_info["pid"]
            is_driver = worker_info.get("isDriver", False)
            core_worker_stats = worker_info["coreWorkerStats"]
            node_address = core_worker_stats["ipAddress"]
            for object_ref in core_worker_stats.get("objectRefs", []):
                entry = MemoryTableEntry(
                    object_ref=object_ref,
                    node_address=node_address,
                    is_driver=is_driver,
                    pid=pid)
                # Entries without any reference or with a nil ID are dropped.
                if entry.is_valid():
                    memory_table_entries.append(entry)
    return MemoryTable(memory_table_entries)
+443 -1
View File
@@ -10,9 +10,11 @@ from ray.core.generated import node_manager_pb2
from ray.core.generated import node_manager_pb2_grpc
from ray.core.generated import reporter_pb2
from ray.core.generated import reporter_pb2_grpc
from ray.dashboard.memory import (ReferenceType, decode_object_id_if_needed,
MemoryTableEntry, MemoryTable, SortingType)
from ray.test_utils import (RayTestTimeoutException,
wait_until_succeeded_without_exception,
wait_until_server_available)
wait_until_server_available, wait_for_condition)
import psutil # We must import psutil after ray because we bundle it with ray.
@@ -377,6 +379,446 @@ def test_profiling_info_endpoint(shutdown_only):
assert profiling_stats is not None
# This variable is used inside test_memory_dashboard.
# It is a module-level global so that every nested test function there can
# share it. The memory table is only refreshed about once per second, so the
# tests compare each fetched table against this previous value to check that
# they are looking at a fresh memory table.
# It starts out as the "group" dict of an empty MemoryTable.
prev_memory_table = MemoryTable([]).__dict__()["group"]
def test_memory_dashboard(shutdown_only):
    """Test Memory table.

    These tests verify examples in this document.
    https://docs.ray.io/en/latest/memory-management.html#debugging-using-ray-memory

    Each scenario below is a nested function that creates some references,
    waits for a fresh memory table, checks the summary counters and then
    asks the dashboard to stop collecting memory info.
    """
    addresses = ray.init(num_cpus=2)
    # Requests below go to the rewritten URL; the availability check uses the
    # original "webui_url" value.
    webui_url = addresses["webui_url"].replace("localhost", "http://127.0.0.1")
    assert (wait_until_server_available(addresses["webui_url"]) is True)

    def get_memory_table():
        """Fetch /api/memory_table and return the "result" part of the JSON."""
        memory_table = requests.get(webui_url + "/api/memory_table").json()
        return memory_table["result"]

    def memory_table_ready():
        """Return True when the fetched table's groups differ from last time.

        The module-level prev_memory_table is updated on every call.
        """
        global prev_memory_table
        memory_table = get_memory_table()
        # Print both tables to help debugging when the test fails.
        from pprint import pprint
        print("Current")
        pprint(memory_table)
        print("Prev")
        pprint(prev_memory_table)
        is_ready = memory_table["group"] != prev_memory_table
        prev_memory_table = memory_table["group"]
        return is_ready

    def stop_memory_table():
        """Hit /api/stop_memory_table so the dashboard resets its table."""
        requests.get(webui_url + "/api/stop_memory_table").json()

    def test_local_reference():
        """Two live object IDs should be reported as two local references."""

        @ray.remote
        def f(arg):
            return arg

        # a and b are local references.
        a = ray.put(None)  # Noqa F841
        b = f.remote(None)  # Noqa F841

        wait_for_condition(memory_table_ready)
        memory_table = get_memory_table()
        summary = memory_table["summary"]
        group = memory_table["group"]
        assert summary["total_captured_in_objects"] == 0
        assert summary["total_pinned_in_memory"] == 0
        assert summary["total_used_by_pending_task"] == 0
        assert summary["total_local_ref_count"] == 2
        for table in group.values():
            for entry in table["entries"]:
                assert (
                    entry["reference_type"] == ReferenceType.LOCAL_REFERENCE)
        stop_memory_table()
        return True

    def test_object_pineed_in_memory():
        """A value obtained with ray.get whose ID was deleted stays pinned."""
        import numpy as np

        a = ray.put(np.zeros(1))
        b = ray.get(a)  # Noqa F841
        del a

        wait_for_condition(memory_table_ready)
        memory_table = get_memory_table()
        summary = memory_table["summary"]
        group = memory_table["group"]
        assert summary["total_captured_in_objects"] == 0
        assert summary["total_pinned_in_memory"] == 1
        assert summary["total_used_by_pending_task"] == 0
        assert summary["total_local_ref_count"] == 0
        for table in group.values():
            for entry in table["entries"]:
                assert (
                    entry["reference_type"] == ReferenceType.PINNED_IN_MEMORY)
        stop_memory_table()
        return True

    def test_pending_task_references():
        """An object passed directly to a still-running task."""

        @ray.remote
        def f(arg):
            time.sleep(1)

        a = ray.put(None)  # Noqa F841
        b = f.remote(a)  # Noqa F841

        wait_for_condition(memory_table_ready)
        memory_table = get_memory_table()
        summary = memory_table["summary"]
        assert summary["total_captured_in_objects"] == 0
        assert summary["total_pinned_in_memory"] == 1
        assert summary["total_used_by_pending_task"] == 1
        assert summary["total_local_ref_count"] == 1
        # Make sure the function f is done before going to the next test.
        # Otherwise, the memory table will be corrupted because the
        # task f won't be done when the next test is running.
        ray.get(b)
        stop_memory_table()
        return True

    def test_serialized_object_id_reference():
        """An object ID passed inside a list to a still-running task."""

        @ray.remote
        def f(arg):
            time.sleep(1)

        a = ray.put(None)  # Noqa F841
        b = f.remote([a])  # Noqa F841

        wait_for_condition(memory_table_ready)
        memory_table = get_memory_table()
        summary = memory_table["summary"]
        assert summary["total_captured_in_objects"] == 0
        assert summary["total_pinned_in_memory"] == 0
        assert summary["total_used_by_pending_task"] == 1
        assert summary["total_local_ref_count"] == 2
        # Make sure the function f is done before going to the next test.
        # Otherwise, the memory table will be corrupted because the
        # task f won't be done when the next test is running.
        ray.get(b)
        stop_memory_table()
        return True

    def test_captured_object_id_reference():
        """An object ID stored inside another object, then deleted locally."""
        a = ray.put(None)
        b = ray.put([a])  # Noqa F841
        del a

        wait_for_condition(memory_table_ready)
        memory_table = get_memory_table()
        summary = memory_table["summary"]
        assert summary["total_captured_in_objects"] == 1
        assert summary["total_pinned_in_memory"] == 0
        assert summary["total_used_by_pending_task"] == 0
        assert summary["total_local_ref_count"] == 1
        stop_memory_table()
        return True

    def test_actor_handle_reference():
        """Three actor handles should be reported as three actor handles."""

        @ray.remote
        class Actor:
            pass

        a = Actor.remote()  # Noqa F841
        b = Actor.remote()  # Noqa F841
        c = Actor.remote()  # Noqa F841

        wait_for_condition(memory_table_ready)
        memory_table = get_memory_table()
        summary = memory_table["summary"]
        group = memory_table["group"]
        assert summary["total_captured_in_objects"] == 0
        assert summary["total_pinned_in_memory"] == 0
        assert summary["total_used_by_pending_task"] == 0
        assert summary["total_local_ref_count"] == 0
        assert summary["total_actor_handles"] == 3
        for table in group.values():
            for entry in table["entries"]:
                assert (entry["reference_type"] == ReferenceType.ACTOR_HANDLE)
        stop_memory_table()
        return True

    # These tests should be retried because it takes at least one second
    # to get the fresh new memory table. It is because memory table is updated
    # Whenever raylet and node info is renewed which takes 1 second.
    # NOTE(review): each scenario asserts internally, so whether a failed
    # assertion is retried depends on how wait_for_condition handles
    # exceptions - confirm against ray.test_utils.
    assert (wait_for_condition(
        test_local_reference, timeout=30000, retry_interval_ms=1000) is True)

    assert (wait_for_condition(
        test_object_pineed_in_memory, timeout=30000, retry_interval_ms=1000) is
            True)

    assert (wait_for_condition(
        test_pending_task_references, timeout=30000, retry_interval_ms=1000) is
            True)

    assert (wait_for_condition(
        test_serialized_object_id_reference,
        timeout=30000,
        retry_interval_ms=1000) is True)

    assert (wait_for_condition(
        test_captured_object_id_reference,
        timeout=30000,
        retry_interval_ms=1000) is True)

    assert (wait_for_condition(
        test_actor_handle_reference, timeout=30000, retry_interval_ms=1000) is
            True)
"""Memory Table Unit Test"""
# Shared fixture values for the MemoryTableEntry / MemoryTable unit tests.
NODE_ADDRESS = "127.0.0.1"
IS_DRIVER = True
PID = 1
# Ends with "=", so decode_object_id_if_needed treats it as Base64.
OBJECT_ID = "7wpsIhgZiBz/////AQAAyAEAAAA="
# A hex-string object ID; build_actor_handle_entry uses it to produce an
# entry whose reference type is ACTOR_HANDLE.
ACTOR_ID = "fffffffffffffffff66d17ba010000c801000000"
DECODED_ID = decode_object_id_if_needed(OBJECT_ID)
OBJECT_SIZE = 100
def build_memory_entry(*,
                       local_ref_count,
                       pinned_in_memory,
                       submitted_task_reference_count,
                       contained_in_owned,
                       object_size,
                       pid,
                       object_id=OBJECT_ID,
                       node_address=NODE_ADDRESS):
    """Create a MemoryTableEntry from explicit reference fields.

    The arguments are packed into an object-ref dict shaped like the one
    MemoryTableEntry reads; the entry is always created as a driver entry.
    """
    entry = MemoryTableEntry(
        object_ref={
            "objectId": object_id,
            "callSite": "(task call) /Users:458",
            "objectSize": object_size,
            "localRefCount": local_ref_count,
            "pinnedInMemory": pinned_in_memory,
            "submittedTaskRefCount": submitted_task_reference_count,
            "containedInOwned": contained_in_owned
        },
        node_address=node_address,
        is_driver=IS_DRIVER,
        pid=pid)
    return entry
def build_local_reference_entry(object_size=OBJECT_SIZE,
                                pid=PID,
                                node_address=NODE_ADDRESS):
    """Create an entry whose only reference is one local reference."""
    reference_fields = dict(
        local_ref_count=1,
        pinned_in_memory=False,
        submitted_task_reference_count=0,
        contained_in_owned=[])
    return build_memory_entry(
        object_size=object_size,
        pid=pid,
        node_address=node_address,
        **reference_fields)
def build_used_by_pending_task_entry(object_size=OBJECT_SIZE,
                                     pid=PID,
                                     node_address=NODE_ADDRESS):
    """Create an entry referenced only by submitted tasks (count of 2)."""
    reference_fields = dict(
        local_ref_count=0,
        pinned_in_memory=False,
        submitted_task_reference_count=2,
        contained_in_owned=[])
    return build_memory_entry(
        object_size=object_size,
        pid=pid,
        node_address=node_address,
        **reference_fields)
def build_captured_in_object_entry(object_size=OBJECT_SIZE,
                                   pid=PID,
                                   node_address=NODE_ADDRESS):
    """Create an entry whose only reference is being contained in OBJECT_ID."""
    reference_fields = dict(
        local_ref_count=0,
        pinned_in_memory=False,
        submitted_task_reference_count=0,
        contained_in_owned=[OBJECT_ID])
    return build_memory_entry(
        object_size=object_size,
        pid=pid,
        node_address=node_address,
        **reference_fields)
def build_actor_handle_entry(object_size=OBJECT_SIZE,
                             pid=PID,
                             node_address=NODE_ADDRESS):
    """Create a locally referenced entry that uses ACTOR_ID as object ID."""
    reference_fields = dict(
        local_ref_count=1,
        pinned_in_memory=False,
        submitted_task_reference_count=0,
        contained_in_owned=[])
    return build_memory_entry(
        object_size=object_size,
        pid=pid,
        node_address=node_address,
        object_id=ACTOR_ID,
        **reference_fields)
def build_pinned_in_memory_entry(object_size=OBJECT_SIZE,
                                 pid=PID,
                                 node_address=NODE_ADDRESS):
    """Create an entry that is pinned in memory and has no other reference."""
    reference_fields = dict(
        local_ref_count=0,
        pinned_in_memory=True,
        submitted_task_reference_count=0,
        contained_in_owned=[])
    return build_memory_entry(
        object_size=object_size,
        pid=pid,
        node_address=node_address,
        **reference_fields)
def build_entry(object_size=OBJECT_SIZE,
                pid=PID,
                node_address=NODE_ADDRESS,
                reference_type=ReferenceType.PINNED_IN_MEMORY):
    """Create an entry of the requested reference type.

    Raises:
        ValueError: If no builder exists for ``reference_type``. Without this
            check the helper would return None and the failure would only
            show up later, far from its cause.
    """
    builders = {
        ReferenceType.USED_BY_PENDING_TASK: build_used_by_pending_task_entry,
        ReferenceType.LOCAL_REFERENCE: build_local_reference_entry,
        ReferenceType.PINNED_IN_MEMORY: build_pinned_in_memory_entry,
        ReferenceType.ACTOR_HANDLE: build_actor_handle_entry,
        ReferenceType.CAPTURED_IN_OBJECT: build_captured_in_object_entry,
    }
    builder = builders.get(reference_type)
    if builder is None:
        raise ValueError(
            "No entry builder for reference type {}.".format(reference_type))
    return builder(
        pid=pid, object_size=object_size, node_address=node_address)
def test_invalid_memory_entry():
    """An entry without any kind of reference is invalid whatever its size."""
    for object_size in (OBJECT_SIZE, -1):
        memory_entry = build_memory_entry(
            local_ref_count=0,
            pinned_in_memory=False,
            submitted_task_reference_count=0,
            contained_in_owned=[],
            object_size=object_size,
            pid=PID)
        assert memory_entry.is_valid() is False
def test_valid_reference_memory_entry():
    """A locally referenced entry is valid and keeps the decoded object ID."""
    expected_object_id = ray.ObjectID(decode_object_id_if_needed(OBJECT_ID))
    memory_entry = build_local_reference_entry()
    assert memory_entry.reference_type == ReferenceType.LOCAL_REFERENCE
    assert memory_entry.object_id == expected_object_id
    assert memory_entry.is_valid() is True
def test_reference_type():
    """Each builder yields an entry classified with the matching type."""
    cases = [
        # pinned in memory
        (build_pinned_in_memory_entry, ReferenceType.PINNED_IN_MEMORY),
        # used by pending task
        (build_used_by_pending_task_entry, ReferenceType.USED_BY_PENDING_TASK),
        # captured in object
        (build_captured_in_object_entry, ReferenceType.CAPTURED_IN_OBJECT),
        # actor handle
        (build_actor_handle_entry, ReferenceType.ACTOR_HANDLE),
    ]
    for build, expected_type in cases:
        memory_entry = build()
        assert memory_entry.reference_type == expected_type
def test_memory_table_summary():
    """The summary counts each reference type and adds up object sizes."""
    entries = [
        build_pinned_in_memory_entry(),
        build_used_by_pending_task_entry(),
        build_captured_in_object_entry(),
        build_actor_handle_entry(),
        build_local_reference_entry(),
        build_local_reference_entry()
    ]
    expected_total_size = len(entries) * OBJECT_SIZE
    memory_table = MemoryTable(entries)
    summary = memory_table.summary
    # Every entry uses the same node address, hence a single group.
    assert len(memory_table.group) == 1
    assert summary["total_actor_handles"] == 1
    assert summary["total_captured_in_objects"] == 1
    assert summary["total_local_ref_count"] == 2
    assert summary["total_object_size"] == expected_total_size
    assert summary["total_pinned_in_memory"] == 1
    assert summary["total_used_by_pending_task"] == 1
def test_memory_table_sort_by_pid():
    """Sorting by PID orders the table's entries by ascending pid."""
    unsorted_pids = [1, 3, 2]
    entries = [build_entry(pid=pid) for pid in unsorted_pids]
    memory_table = MemoryTable(entries, sort_by_type=SortingType.PID)
    # Compare complete lists: zip() would silently stop at the shorter
    # sequence and let a table with missing entries pass.
    actual_pids = [entry.pid for entry in memory_table.table]
    assert actual_pids == sorted(unsorted_pids)
def test_memory_table_sort_by_reference_type():
    """Sorting by reference type orders entries by their type string."""
    unsorted_types = [
        ReferenceType.USED_BY_PENDING_TASK, ReferenceType.LOCAL_REFERENCE,
        ReferenceType.LOCAL_REFERENCE, ReferenceType.PINNED_IN_MEMORY
    ]
    entries = [
        build_entry(reference_type=reference_type)
        for reference_type in unsorted_types
    ]
    memory_table = MemoryTable(
        entries, sort_by_type=SortingType.REFERENCE_TYPE)
    # Compare complete lists: zip() would silently stop at the shorter
    # sequence and let a table with missing entries pass.
    actual_types = [entry.reference_type for entry in memory_table.table]
    assert actual_types == sorted(unsorted_types)
def test_memory_table_sort_by_object_size():
    """Sorting by object size orders entries by ascending size."""
    unsorted_sizes = [312, 214, -1, 1244, 642]
    entries = [
        build_entry(object_size=object_size) for object_size in unsorted_sizes
    ]
    memory_table = MemoryTable(entries, sort_by_type=SortingType.OBJECT_SIZE)
    # Compare complete lists: zip() would silently stop at the shorter
    # sequence and let a table with missing entries pass.
    actual_sizes = [entry.object_size for entry in memory_table.table]
    assert actual_sizes == sorted(unsorted_sizes)
def test_group_by():
    """Entries are grouped per node address and stay pid-sorted per group."""
    node_first = "127.0.0.1"
    node_second = "127.0.0.2"
    entries = [
        build_entry(node_address=address, pid=pid)
        for address in (node_second, node_first) for pid in (2, 1)
    ]
    memory_table = MemoryTable(entries)

    # Make sure it is correctly grouped
    assert node_first in memory_table.group
    assert node_second in memory_table.group

    # make sure pid is sorted in the right order.
    for group_memory_table in memory_table.group.values():
        for expected_pid, entry in enumerate(
                group_memory_table.table, start=1):
            assert expected_pid == entry.pid
if __name__ == "__main__":
import pytest
import sys