diff --git a/python/ray/dashboard/dashboard.py b/python/ray/dashboard/dashboard.py
index 3d94dd0a1..1ff6b961f 100644
--- a/python/ray/dashboard/dashboard.py
+++ b/python/ray/dashboard/dashboard.py
@@ -38,6 +38,7 @@ from ray.core.generated import core_worker_pb2
 from ray.core.generated import core_worker_pb2_grpc
 from ray.dashboard.interface import BaseDashboardController
 from ray.dashboard.interface import BaseDashboardRouteHandler
+from ray.dashboard.memory import construct_memory_table, MemoryTable
 from ray.dashboard.metrics_exporter.client import Exporter
 from ray.dashboard.metrics_exporter.client import MetricsExportClient
 
@@ -126,6 +127,7 @@ class DashboardController(BaseDashboardController):
             redis_address, redis_password=redis_password)
         if Analysis is not None:
             self.tune_stats = TuneCollector(2.0)
+        self.memory_table = MemoryTable([])
 
     def _construct_raylet_info(self):
         D = self.raylet_stats.get_raylet_stats()
@@ -133,6 +135,7 @@ class DashboardController(BaseDashboardController):
             data["nodeId"]: data.get("workersStats")
             for data in D.values()
         }
+
         infeasible_tasks = sum(
             (data.get("infeasibleTasks", []) for data in D.values()), [])
         # ready_tasks are used to render tasks that are not schedulable
@@ -142,6 +145,7 @@ class DashboardController(BaseDashboardController):
                           [])
         actor_tree = self.node_stats.get_actor_tree(
             workers_info_by_node, infeasible_tasks, ready_tasks)
+
         for address, data in D.items():
             # process view data
             measures_dicts = {}
@@ -224,6 +228,23 @@ class DashboardController(BaseDashboardController):
     def get_raylet_info(self):
         return self._construct_raylet_info()
 
+    def get_memory_table_info(self) -> MemoryTable:
+        # Collecting memory info adds big overhead to the cluster.
+        # This must be collected only when it is necessary.
+        self.raylet_stats.include_memory_info = True
+        D = self.raylet_stats.get_raylet_stats()
+        workers_info_by_node = {
+            data["nodeId"]: data.get("workersStats")
+            for data in D.values()
+        }
+        self.memory_table = construct_memory_table(workers_info_by_node)
+        return self.memory_table
+
+    def stop_collecting_memory_table_info(self):
+        # Reset memory table.
+        self.memory_table = MemoryTable([])
+        self.raylet_stats.include_memory_info = False
+
     def tune_info(self):
         if Analysis is not None:
             D = self.tune_stats.get_stats()
@@ -313,6 +334,15 @@ class DashboardRouteHandler(BaseDashboardRouteHandler):
         result = self.dashboard_controller.get_raylet_info()
         return await json_response(self.is_dev, result=result)
 
+    async def memory_table_info(self, req) -> aiohttp.web.Response:
+        memory_table = self.dashboard_controller.get_memory_table_info()
+        return await json_response(self.is_dev, result=memory_table.__dict__())
+
+    async def stop_collecting_memory_table_info(self,
+                                                req) -> aiohttp.web.Response:
+        self.dashboard_controller.stop_collecting_memory_table_info()
+        return await json_response(self.is_dev, result={})
+
     async def tune_info(self, req) -> aiohttp.web.Response:
         result = self.dashboard_controller.tune_info()
         return await json_response(self.is_dev, result=result)
@@ -462,7 +492,9 @@ def setup_dashboard_route(app: aiohttp.web.Application,
                           get_profiling_info=None,
                           kill_actor=None,
                           logs=None,
-                          errors=None):
+                          errors=None,
+                          memory_table=None,
+                          stop_memory_table=None):
     def add_get_route(route, handler_func):
         if route is not None:
             app.router.add_get(route, handler_func)
@@ -480,6 +512,8 @@ def setup_dashboard_route(app: aiohttp.web.Application,
     add_get_route(kill_actor, handler.kill_actor)
     add_get_route(logs, handler.logs)
     add_get_route(errors, handler.errors)
+    add_get_route(memory_table, handler.memory_table_info)
+    add_get_route(stop_memory_table, handler.stop_collecting_memory_table_info)
 
 
 class Dashboard:
@@ -548,7 +582,9 @@ class Dashboard:
             get_profiling_info="/api/get_profiling_info",
             kill_actor="/api/kill_actor",
             logs="/api/logs",
-            errors="/api/errors")
+            errors="/api/errors",
+            memory_table="/api/memory_table",
+            stop_memory_table="/api/stop_memory_table")
         self.app.router.add_get("/{_}", route_handler.get_forbidden)
         self.app.router.add_post("/api/set_tune_experiment",
                                  route_handler.set_tune_experiment)
@@ -849,6 +885,7 @@ class RayletStats(threading.Thread):
         self._profiling_stats = {}
 
         self._update_nodes()
+        self.include_memory_info = False
 
         super().__init__()
 
@@ -950,7 +987,9 @@ class RayletStats(threading.Thread):
                     node_id = node["NodeID"]
                     stub = self.stubs[node_id]
                     reply = stub.GetNodeStats(
-                        node_manager_pb2.GetNodeStatsRequest(), timeout=2)
+                        node_manager_pb2.GetNodeStatsRequest(
+                            include_memory_info=self.include_memory_info),
+                        timeout=2)
                     reply_dict = MessageToDict(reply)
                     reply_dict["nodeId"] = node_id
                     replies[node["NodeManagerAddress"]] = reply_dict
diff --git a/python/ray/dashboard/memory.py b/python/ray/dashboard/memory.py
new file mode 100644
index 000000000..735411663
--- /dev/null
+++ b/python/ray/dashboard/memory.py
@@ -0,0 +1,307 @@
+import base64
+
+from collections import defaultdict
+from enum import Enum
+from typing import List
+
+import ray
+
+from ray._raylet import (TaskID, ActorID, JobID)
+
+# These values are used to calculate if objectIDs are actor handles.
+TASKID_BYTES_SIZE = TaskID.size()
+ACTORID_BYTES_SIZE = ActorID.size()
+JOBID_BYTES_SIZE = JobID.size()
+# Multiply by 2: IDs are compared as hex strings, i.e. 2 characters per byte.
+TASKID_RANDOM_BITS_SIZE = (TASKID_BYTES_SIZE - ACTORID_BYTES_SIZE) * 2
+ACTORID_RANDOM_BITS_SIZE = (ACTORID_BYTES_SIZE - JOBID_BYTES_SIZE) * 2
+
+
+def decode_object_id_if_needed(object_id: str) -> bytes:
+    """Decode objectID bytes string.
+
+    gRPC reply contains an objectID that is encoded by Base64.
+    This function is used to decode the objectID.
+    Note that there are times that objectID is already decoded as
+    a hex string. In this case, just convert it to a binary number.
+    """
+    if object_id.endswith("="):
+        # If the object id ends with =, that means it is base64 encoded.
+        # Object ids will always have = as a padding
+        # when it is base64 encoded because objectID is always 20B.
+        return base64.standard_b64decode(object_id)
+    else:
+        return ray.utils.hex_to_binary(object_id)
+
+
+class SortingType(Enum):
+    PID = 1
+    OBJECT_SIZE = 3
+    REFERENCE_TYPE = 4
+
+
+class GroupByType(Enum):
+    NODE_ADDRESS = 2
+
+
+class ReferenceType:
+    # We don't use enum because enum is not json serializable.
+    ACTOR_HANDLE = "ACTOR_HANDLE"
+    PINNED_IN_MEMORY = "PINNED_IN_MEMORY"
+    LOCAL_REFERENCE = "LOCAL_REFERENCE"
+    USED_BY_PENDING_TASK = "USED_BY_PENDING_TASK"
+    CAPTURED_IN_OBJECT = "CAPTURED_IN_OBJECT"
+    UNKNOWN_STATUS = "UNKNOWN_STATUS"
+
+
+class MemoryTableEntry:
+    def __init__(self, *, object_ref: dict, node_address: str, is_driver: bool,
+                 pid: int):
+        # worker info
+        self.is_driver = is_driver
+        self.pid = pid
+        self.node_address = node_address
+
+        # object info
+        self.object_size = int(object_ref.get("objectSize", -1))
+        self.call_site = object_ref.get("callSite", "")
+        self.object_id = ray.ObjectID(
+            decode_object_id_if_needed(object_ref["objectId"]))
+
+        # reference info
+        self.local_ref_count = int(object_ref.get("localRefCount", 0))
+        self.pinned_in_memory = bool(object_ref.get("pinnedInMemory", False))
+        self.submitted_task_ref_count = int(
+            object_ref.get("submittedTaskRefCount", 0))
+        self.contained_in_owned = [
+            ray.ObjectID(decode_object_id_if_needed(object_id))
+            for object_id in object_ref.get("containedInOwned", [])
+        ]
+        self.reference_type = self._get_reference_type()
+
+    def is_valid(self) -> bool:
+        # If the entry doesn't have a reference type or some invalid state,
+        # (e.g., no object ID presented), it is considered invalid.
+        if (not self.pinned_in_memory and self.local_ref_count == 0
+                and self.submitted_task_ref_count == 0
+                and len(self.contained_in_owned) == 0):
+            return False
+        elif self.object_id.is_nil():
+            return False
+        else:
+            return True
+
+    def group_key(self, group_by_type: GroupByType) -> str:
+        if group_by_type == GroupByType.NODE_ADDRESS:
+            return self.node_address
+        else:
+            raise ValueError(
+                "group by type {} is invalid.".format(group_by_type))
+
+    def _get_reference_type(self) -> str:
+        if self._is_object_id_actor_handle():
+            return ReferenceType.ACTOR_HANDLE
+        if self.pinned_in_memory:
+            return ReferenceType.PINNED_IN_MEMORY
+        elif self.submitted_task_ref_count > 0:
+            return ReferenceType.USED_BY_PENDING_TASK
+        elif self.local_ref_count > 0:
+            return ReferenceType.LOCAL_REFERENCE
+        elif len(self.contained_in_owned) > 0:
+            return ReferenceType.CAPTURED_IN_OBJECT
+        else:
+            return ReferenceType.UNKNOWN_STATUS
+
+    def _is_object_id_actor_handle(self) -> bool:
+        object_id_hex = self.object_id.hex()
+
+        # random (8B) | ActorID(6B) | flag (2B) | index (6B)
+        # ActorID(6B) == ActorRandomByte(4B) + JobID(2B)
+        # If random bytes are all 'f', but ActorRandomBytes
+        # are not all 'f', that means it is an actor creation
+        # task, which is an actor handle.
+        random_bits = object_id_hex[:TASKID_RANDOM_BITS_SIZE]
+        actor_random_bits = object_id_hex[TASKID_RANDOM_BITS_SIZE:
+                                          TASKID_RANDOM_BITS_SIZE +
+                                          ACTORID_RANDOM_BITS_SIZE]
+        # Compare against the computed widths (not literal 16 / 8) so this
+        # check stays consistent with the slices above.
+        return (random_bits == "f" * TASKID_RANDOM_BITS_SIZE
+                and actor_random_bits != "f" * ACTORID_RANDOM_BITS_SIZE)
+
+    def __dict__(self):
+        return {
+            "object_id": self.object_id.hex(),
+            "pid": self.pid,
+            "node_ip_address": self.node_address,
+            "object_size": self.object_size,
+            "reference_type": self.reference_type,
+            "call_site": self.call_site,
+            "local_ref_count": self.local_ref_count,
+            "pinned_in_memory": self.pinned_in_memory,
+            "submitted_task_ref_count": self.submitted_task_ref_count,
+            "contained_in_owned": [
+                object_id.hex() for object_id in self.contained_in_owned
+            ],
+            "type": "Driver" if self.is_driver else "Worker"
+        }
+
+    def __str__(self):
+        return self.__repr__()
+
+    def __repr__(self):
+        return str(self.__dict__())
+
+
+class MemoryTable:
+    def __init__(self,
+                 entries: List[MemoryTableEntry],
+                 group_by_type: GroupByType = GroupByType.NODE_ADDRESS,
+                 sort_by_type: SortingType = SortingType.PID):
+        # Copy so that in-place sorting never reorders the caller's list.
+        self.table = list(entries)
+        # Maps each group key to the MemoryTable holding that group's entries.
+        self.group = {}
+        self.summary = defaultdict(int)
+        if group_by_type and sort_by_type:
+            self.setup(group_by_type, sort_by_type)
+        elif group_by_type:
+            self._group_by(group_by_type)
+        elif sort_by_type:
+            self._sort_by(sort_by_type)
+
+    def setup(self, group_by_type: GroupByType, sort_by_type: SortingType):
+        """Setup memory table.
+
+        This will sort entries first and group them after.
+        Sort order will be still kept.
+        """
+        self._sort_by(sort_by_type)._group_by(group_by_type)
+        for group_memory_table in self.group.values():
+            group_memory_table.summarize()
+        self.summarize()
+        return self
+
+    def insert_entry(self, entry: MemoryTableEntry):
+        self.table.append(entry)
+
+    def summarize(self):
+        # Reset summary.
+        total_object_size = 0
+        total_local_ref_count = 0
+        total_pinned_in_memory = 0
+        total_used_by_pending_task = 0
+        total_captured_in_objects = 0
+        total_actor_handles = 0
+
+        for entry in self.table:
+            if entry.object_size > 0:
+                total_object_size += entry.object_size
+            if entry.reference_type == ReferenceType.LOCAL_REFERENCE:
+                total_local_ref_count += 1
+            elif entry.reference_type == ReferenceType.PINNED_IN_MEMORY:
+                total_pinned_in_memory += 1
+            elif entry.reference_type == ReferenceType.USED_BY_PENDING_TASK:
+                total_used_by_pending_task += 1
+            elif entry.reference_type == ReferenceType.CAPTURED_IN_OBJECT:
+                total_captured_in_objects += 1
+            elif entry.reference_type == ReferenceType.ACTOR_HANDLE:
+                total_actor_handles += 1
+
+        self.summary = {
+            "total_object_size": total_object_size,
+            "total_local_ref_count": total_local_ref_count,
+            "total_pinned_in_memory": total_pinned_in_memory,
+            "total_used_by_pending_task": total_used_by_pending_task,
+            "total_captured_in_objects": total_captured_in_objects,
+            "total_actor_handles": total_actor_handles
+        }
+        return self
+
+    def _sort_by(self, sorting_type: SortingType):
+        if sorting_type == SortingType.PID:
+            self.table.sort(key=lambda entry: entry.pid)
+        elif sorting_type == SortingType.OBJECT_SIZE:
+            self.table.sort(key=lambda entry: entry.object_size)
+        elif sorting_type == SortingType.REFERENCE_TYPE:
+            self.table.sort(key=lambda entry: entry.reference_type)
+        else:
+            raise ValueError(
+                "Give sorting type: {} is invalid.".format(sorting_type))
+        return self
+
+    def _group_by(self, group_by_type: GroupByType):
+        """Group entries and summarize the result.
+
+        NOTE: Each group is another MemoryTable.
+        """
+        # Reset group
+        self.group = {}
+
+        # Build entries per group.
+        group = defaultdict(list)
+        for entry in self.table:
+            group[entry.group_key(group_by_type)].append(entry)
+
+        # Build a group table.
+        for group_key, entries in group.items():
+            self.group[group_key] = MemoryTable(
+                entries, group_by_type=None, sort_by_type=None)
+        for group_key, group_memory_table in self.group.items():
+            group_memory_table.summarize()
+        return self
+
+    def __dict__(self):
+        return {
+            "summary": self.summary,
+            "group": {
+                group_key: {
+                    "entries": group_memory_table.get_entries(),
+                    "summary": group_memory_table.summary
+                }
+                for group_key, group_memory_table in self.group.items()
+            }
+        }
+
+    def get_entries(self) -> List[dict]:
+        return [entry.__dict__() for entry in self.table]
+
+    def __repr__(self):
+        return str(self.__dict__())
+
+    def __str__(self):
+        return self.__repr__()
+
+
+def construct_memory_table(workers_info_by_node: dict) -> MemoryTable:
+    """Build a MemoryTable from per-node worker stats.
+
+    Args:
+        workers_info_by_node: Maps a node ID to that node's list of worker
+            stats dicts, or to None when the node has no "workersStats".
+
+    Returns:
+        A MemoryTable containing every entry for which is_valid() is true.
+    """
+    memory_table_entries = []
+    for worker_infos in workers_info_by_node.values():
+        # The dashboard fills this mapping with dict.get("workersStats"), so
+        # a node that reported no worker stats maps to None, not to a list.
+        for worker_info in worker_infos or []:
+            pid = worker_info["pid"]
+            is_driver = worker_info.get("isDriver", False)
+            core_worker_stats = worker_info["coreWorkerStats"]
+            node_address = core_worker_stats["ipAddress"]
+            object_refs = core_worker_stats.get("objectRefs", [])
+
+            for object_ref in object_refs:
+                memory_table_entry = MemoryTableEntry(
+                    object_ref=object_ref,
+                    node_address=node_address,
+                    is_driver=is_driver,
+                    pid=pid)
+                if memory_table_entry.is_valid():
+                    memory_table_entries.append(memory_table_entry)
+
+    memory_table = MemoryTable(memory_table_entries)
+    return memory_table
diff --git a/python/ray/tests/test_metrics.py b/python/ray/tests/test_metrics.py
index 8f2c79846..591ca2f44 100644
--- a/python/ray/tests/test_metrics.py
+++ b/python/ray/tests/test_metrics.py
@@ -10,9 +10,11 @@ from ray.core.generated import node_manager_pb2
 from ray.core.generated import node_manager_pb2_grpc
 from ray.core.generated import reporter_pb2
 from ray.core.generated import reporter_pb2_grpc
+from ray.dashboard.memory import (ReferenceType, decode_object_id_if_needed,
+                                  MemoryTableEntry, MemoryTable, SortingType)
 from ray.test_utils import (RayTestTimeoutException,
                             wait_until_succeeded_without_exception,
-                            wait_until_server_available)
+                            wait_until_server_available, wait_for_condition)
 
 import psutil  # We must import psutil after ray because we bundle it with ray.
 
@@ -377,6 +379,446 @@ def test_profiling_info_endpoint(shutdown_only):
     assert profiling_stats is not None
 
 
+# This variable is used inside test_memory_dashboard.
+# It is defined as a global variable to be used across all nested test
+# functions. We use it because memory table is updated every one second,
+# and we need to have a way to verify if the test is running with a fresh
+# new memory table.
+prev_memory_table = MemoryTable([]).__dict__()["group"]
+
+
+def test_memory_dashboard(shutdown_only):
+    """Test Memory table.
+
+    These tests verify examples in this document.
+    https://docs.ray.io/en/latest/memory-management.html#debugging-using-ray-memory
+    """
+    addresses = ray.init(num_cpus=2)
+    webui_url = addresses["webui_url"].replace("localhost", "http://127.0.0.1")
+    assert (wait_until_server_available(addresses["webui_url"]) is True)
+
+    def get_memory_table():
+        memory_table = requests.get(webui_url + "/api/memory_table").json()
+        return memory_table["result"]
+
+    def memory_table_ready():
+        """Wait until the new fresh memory table is ready."""
+        global prev_memory_table
+        memory_table = get_memory_table()
+        from pprint import pprint
+        print("Current")
+        pprint(memory_table)
+        print("Prev")
+        pprint(prev_memory_table)
+        is_ready = memory_table["group"] != prev_memory_table
+        prev_memory_table = memory_table["group"]
+        return is_ready
+
+    def stop_memory_table():
+        requests.get(webui_url + "/api/stop_memory_table").json()
+
+    def test_local_reference():
+        @ray.remote
+        def f(arg):
+            return arg
+
+        # a and b are local references.
+        a = ray.put(None)  # Noqa F841
+        b = f.remote(None)  # Noqa F841
+
+        wait_for_condition(memory_table_ready)
+        memory_table = get_memory_table()
+        summary = memory_table["summary"]
+        group = memory_table["group"]
+        assert summary["total_captured_in_objects"] == 0
+        assert summary["total_pinned_in_memory"] == 0
+        assert summary["total_used_by_pending_task"] == 0
+        assert summary["total_local_ref_count"] == 2
+        for table in group.values():
+            for entry in table["entries"]:
+                assert (
+                    entry["reference_type"] == ReferenceType.LOCAL_REFERENCE)
+        stop_memory_table()
+        return True
+
+    def test_object_pineed_in_memory():
+        import numpy as np
+
+        a = ray.put(np.zeros(1))
+        b = ray.get(a)  # Noqa F841
+        del a
+
+        wait_for_condition(memory_table_ready)
+        memory_table = get_memory_table()
+        summary = memory_table["summary"]
+        group = memory_table["group"]
+        assert summary["total_captured_in_objects"] == 0
+        assert summary["total_pinned_in_memory"] == 1
+        assert summary["total_used_by_pending_task"] == 0
+        assert summary["total_local_ref_count"] == 0
+        for table in group.values():
+            for entry in table["entries"]:
+                assert (
+                    entry["reference_type"] == ReferenceType.PINNED_IN_MEMORY)
+        stop_memory_table()
+        return True
+
+    def test_pending_task_references():
+        @ray.remote
+        def f(arg):
+            time.sleep(1)
+
+        a = ray.put(None)  # Noqa F841
+        b = f.remote(a)  # Noqa F841
+
+        wait_for_condition(memory_table_ready)
+        memory_table = get_memory_table()
+        summary = memory_table["summary"]
+        assert summary["total_captured_in_objects"] == 0
+        assert summary["total_pinned_in_memory"] == 1
+        assert summary["total_used_by_pending_task"] == 1
+        assert summary["total_local_ref_count"] == 1
+        # Make sure the function f is done before going to the next test.
+        # Otherwise, the memory table will be corrupted because the
+        # task f won't be done when the next test is running.
+        ray.get(b)
+        stop_memory_table()
+        return True
+
+    def test_serialized_object_id_reference():
+        @ray.remote
+        def f(arg):
+            time.sleep(1)
+
+        a = ray.put(None)  # Noqa F841
+        b = f.remote([a])  # Noqa F841
+
+        wait_for_condition(memory_table_ready)
+        memory_table = get_memory_table()
+        summary = memory_table["summary"]
+        assert summary["total_captured_in_objects"] == 0
+        assert summary["total_pinned_in_memory"] == 0
+        assert summary["total_used_by_pending_task"] == 1
+        assert summary["total_local_ref_count"] == 2
+        # Make sure the function f is done before going to the next test.
+        # Otherwise, the memory table will be corrupted because the
+        # task f won't be done when the next test is running.
+        ray.get(b)
+        stop_memory_table()
+        return True
+
+    def test_captured_object_id_reference():
+        a = ray.put(None)
+        b = ray.put([a])  # Noqa F841
+        del a
+
+        wait_for_condition(memory_table_ready)
+        memory_table = get_memory_table()
+        summary = memory_table["summary"]
+        assert summary["total_captured_in_objects"] == 1
+        assert summary["total_pinned_in_memory"] == 0
+        assert summary["total_used_by_pending_task"] == 0
+        assert summary["total_local_ref_count"] == 1
+        stop_memory_table()
+        return True
+
+    def test_actor_handle_reference():
+        @ray.remote
+        class Actor:
+            pass
+
+        a = Actor.remote()  # Noqa F841
+        b = Actor.remote()  # Noqa F841
+        c = Actor.remote()  # Noqa F841
+
+        wait_for_condition(memory_table_ready)
+        memory_table = get_memory_table()
+        summary = memory_table["summary"]
+        group = memory_table["group"]
+        assert summary["total_captured_in_objects"] == 0
+        assert summary["total_pinned_in_memory"] == 0
+        assert summary["total_used_by_pending_task"] == 0
+        assert summary["total_local_ref_count"] == 0
+        assert summary["total_actor_handles"] == 3
+        for table in group.values():
+            for entry in table["entries"]:
+                assert (entry["reference_type"] == ReferenceType.ACTOR_HANDLE)
+        stop_memory_table()
+        return True
+
+    # These tests should be retried because it takes at least one second
+    # to get the fresh new memory table. It is because memory table is updated
+    # whenever raylet and node info is renewed which takes 1 second.
+    assert (wait_for_condition(
+        test_local_reference, timeout=30000, retry_interval_ms=1000) is True)
+
+    assert (wait_for_condition(
+        test_object_pineed_in_memory, timeout=30000, retry_interval_ms=1000) is
+            True)
+
+    assert (wait_for_condition(
+        test_pending_task_references, timeout=30000, retry_interval_ms=1000) is
+            True)
+
+    assert (wait_for_condition(
+        test_serialized_object_id_reference,
+        timeout=30000,
+        retry_interval_ms=1000) is True)
+
+    assert (wait_for_condition(
+        test_captured_object_id_reference,
+        timeout=30000,
+        retry_interval_ms=1000) is True)
+
+    assert (wait_for_condition(
+        test_actor_handle_reference, timeout=30000, retry_interval_ms=1000) is
+            True)
+
+
+"""Memory Table Unit Test"""
+
+NODE_ADDRESS = "127.0.0.1"
+IS_DRIVER = True
+PID = 1
+OBJECT_ID = "7wpsIhgZiBz/////AQAAyAEAAAA="
+ACTOR_ID = "fffffffffffffffff66d17ba010000c801000000"
+DECODED_ID = decode_object_id_if_needed(OBJECT_ID)
+OBJECT_SIZE = 100
+
+
+def build_memory_entry(*,
+                       local_ref_count,
+                       pinned_in_memory,
+                       submitted_task_reference_count,
+                       contained_in_owned,
+                       object_size,
+                       pid,
+                       object_id=OBJECT_ID,
+                       node_address=NODE_ADDRESS):
+    object_ref = {
+        "objectId": object_id,
+        "callSite": "(task call) /Users:458",
+        "objectSize": object_size,
+        "localRefCount": local_ref_count,
+        "pinnedInMemory": pinned_in_memory,
+        "submittedTaskRefCount": submitted_task_reference_count,
+        "containedInOwned": contained_in_owned
+    }
+    return MemoryTableEntry(
+        object_ref=object_ref,
+        node_address=node_address,
+        is_driver=IS_DRIVER,
+        pid=pid)
+
+
+def build_local_reference_entry(object_size=OBJECT_SIZE,
+                                pid=PID,
+                                node_address=NODE_ADDRESS):
+    return build_memory_entry(
+        local_ref_count=1,
+        pinned_in_memory=False,
+        submitted_task_reference_count=0,
+        contained_in_owned=[],
+        object_size=object_size,
+        pid=pid,
+        node_address=node_address)
+
+
+def build_used_by_pending_task_entry(object_size=OBJECT_SIZE,
+                                     pid=PID,
+                                     node_address=NODE_ADDRESS):
+    return build_memory_entry(
+        local_ref_count=0,
+        pinned_in_memory=False,
+        submitted_task_reference_count=2,
+        contained_in_owned=[],
+        object_size=object_size,
+        pid=pid,
+        node_address=node_address)
+
+
+def build_captured_in_object_entry(object_size=OBJECT_SIZE,
+                                   pid=PID,
+                                   node_address=NODE_ADDRESS):
+    return build_memory_entry(
+        local_ref_count=0,
+        pinned_in_memory=False,
+        submitted_task_reference_count=0,
+        contained_in_owned=[OBJECT_ID],
+        object_size=object_size,
+        pid=pid,
+        node_address=node_address)
+
+
+def build_actor_handle_entry(object_size=OBJECT_SIZE,
+                             pid=PID,
+                             node_address=NODE_ADDRESS):
+    return build_memory_entry(
+        local_ref_count=1,
+        pinned_in_memory=False,
+        submitted_task_reference_count=0,
+        contained_in_owned=[],
+        object_size=object_size,
+        pid=pid,
+        node_address=node_address,
+        object_id=ACTOR_ID)
+
+
+def build_pinned_in_memory_entry(object_size=OBJECT_SIZE,
+                                 pid=PID,
+                                 node_address=NODE_ADDRESS):
+    return build_memory_entry(
+        local_ref_count=0,
+        pinned_in_memory=True,
+        submitted_task_reference_count=0,
+        contained_in_owned=[],
+        object_size=object_size,
+        pid=pid,
+        node_address=node_address)
+
+
+def build_entry(object_size=OBJECT_SIZE,
+                pid=PID,
+                node_address=NODE_ADDRESS,
+                reference_type=ReferenceType.PINNED_IN_MEMORY):
+    if reference_type == ReferenceType.USED_BY_PENDING_TASK:
+        return build_used_by_pending_task_entry(
+            pid=pid, object_size=object_size, node_address=node_address)
+    elif reference_type == ReferenceType.LOCAL_REFERENCE:
+        return build_local_reference_entry(
+            pid=pid, object_size=object_size, node_address=node_address)
+    elif reference_type == ReferenceType.PINNED_IN_MEMORY:
+        return build_pinned_in_memory_entry(
+            pid=pid, object_size=object_size, node_address=node_address)
+    elif reference_type == ReferenceType.ACTOR_HANDLE:
+        return build_actor_handle_entry(
+            pid=pid, object_size=object_size, node_address=node_address)
+    elif reference_type == ReferenceType.CAPTURED_IN_OBJECT:
+        return build_captured_in_object_entry(
+            pid=pid, object_size=object_size, node_address=node_address)
+
+
+def test_invalid_memory_entry():
+    memory_entry = build_memory_entry(
+        local_ref_count=0,
+        pinned_in_memory=False,
+        submitted_task_reference_count=0,
+        contained_in_owned=[],
+        object_size=OBJECT_SIZE,
+        pid=PID)
+    assert memory_entry.is_valid() is False
+    memory_entry = build_memory_entry(
+        local_ref_count=0,
+        pinned_in_memory=False,
+        submitted_task_reference_count=0,
+        contained_in_owned=[],
+        object_size=-1,
+        pid=PID)
+    assert memory_entry.is_valid() is False
+
+
+def test_valid_reference_memory_entry():
+    memory_entry = build_local_reference_entry()
+    assert memory_entry.reference_type == ReferenceType.LOCAL_REFERENCE
+    assert memory_entry.object_id == ray.ObjectID(
+        decode_object_id_if_needed(OBJECT_ID))
+    assert memory_entry.is_valid() is True
+
+
+def test_reference_type():
+    # pinned in memory
+    memory_entry = build_pinned_in_memory_entry()
+    assert memory_entry.reference_type == ReferenceType.PINNED_IN_MEMORY
+
+    # used by pending task
+    memory_entry = build_used_by_pending_task_entry()
+    assert memory_entry.reference_type == ReferenceType.USED_BY_PENDING_TASK
+
+    # captured in object
+    memory_entry = build_captured_in_object_entry()
+    assert memory_entry.reference_type == ReferenceType.CAPTURED_IN_OBJECT
+
+    # actor handle
+    memory_entry = build_actor_handle_entry()
+    assert memory_entry.reference_type == ReferenceType.ACTOR_HANDLE
+
+
+def test_memory_table_summary():
+    entries = [
+        build_pinned_in_memory_entry(),
+        build_used_by_pending_task_entry(),
+        build_captured_in_object_entry(),
+        build_actor_handle_entry(),
+        build_local_reference_entry(),
+        build_local_reference_entry()
+    ]
+    memory_table = MemoryTable(entries)
+    assert len(memory_table.group) == 1
+    assert memory_table.summary["total_actor_handles"] == 1
+    assert memory_table.summary["total_captured_in_objects"] == 1
+    assert memory_table.summary["total_local_ref_count"] == 2
+    assert memory_table.summary[
+        "total_object_size"] == len(entries) * OBJECT_SIZE
+    assert memory_table.summary["total_pinned_in_memory"] == 1
+    assert memory_table.summary["total_used_by_pending_task"] == 1
+
+
+def test_memory_table_sort_by_pid():
+    unsort = [1, 3, 2]
+    entries = [build_entry(pid=pid) for pid in unsort]
+    memory_table = MemoryTable(entries, sort_by_type=SortingType.PID)
+    sort = sorted(unsort)
+    for pid, entry in zip(sort, memory_table.table):
+        assert pid == entry.pid
+
+
+def test_memory_table_sort_by_reference_type():
+    unsort = [
+        ReferenceType.USED_BY_PENDING_TASK, ReferenceType.LOCAL_REFERENCE,
+        ReferenceType.LOCAL_REFERENCE, ReferenceType.PINNED_IN_MEMORY
+    ]
+    entries = [
+        build_entry(reference_type=reference_type) for reference_type in unsort
+    ]
+    memory_table = MemoryTable(
+        entries, sort_by_type=SortingType.REFERENCE_TYPE)
+    sort = sorted(unsort)
+    for reference_type, entry in zip(sort, memory_table.table):
+        assert reference_type == entry.reference_type
+
+
+def test_memory_table_sort_by_object_size():
+    unsort = [312, 214, -1, 1244, 642]
+    entries = [build_entry(object_size=object_size) for object_size in unsort]
+    memory_table = MemoryTable(entries, sort_by_type=SortingType.OBJECT_SIZE)
+    sort = sorted(unsort)
+    for object_size, entry in zip(sort, memory_table.table):
+        assert object_size == entry.object_size
+
+
+def test_group_by():
+    node_second = "127.0.0.2"
+    node_first = "127.0.0.1"
+    entries = [
+        build_entry(node_address=node_second, pid=2),
+        build_entry(node_address=node_second, pid=1),
+        build_entry(node_address=node_first, pid=2),
+        build_entry(node_address=node_first, pid=1)
+    ]
+    memory_table = MemoryTable(entries)
+
+    # Make sure it is correctly grouped
+    assert node_first in memory_table.group
+    assert node_second in memory_table.group
+
+    # make sure pid is sorted in the right order.
+    for group_key, group_memory_table in memory_table.group.items():
+        pid = 1
+        for entry in group_memory_table.table:
+            assert pid == entry.pid
+            pid += 1
+
+
 if __name__ == "__main__":
     import pytest
     import sys