mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 09:42:22 +08:00
[gRPC] Migrate gcs data structures to protobuf (#5024)
This commit is contained in:
+31
-40
@@ -2,38 +2,39 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import flatbuffers
|
||||
import ray.core.generated.ErrorTableData
|
||||
|
||||
from ray.core.generated.ActorCheckpointIdData import ActorCheckpointIdData
|
||||
from ray.core.generated.ClientTableData import ClientTableData
|
||||
from ray.core.generated.DriverTableData import DriverTableData
|
||||
from ray.core.generated.ErrorTableData import ErrorTableData
|
||||
from ray.core.generated.GcsEntry import GcsEntry
|
||||
from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData
|
||||
from ray.core.generated.HeartbeatTableData import HeartbeatTableData
|
||||
from ray.core.generated.Language import Language
|
||||
from ray.core.generated.ObjectTableData import ObjectTableData
|
||||
from ray.core.generated.ProfileTableData import ProfileTableData
|
||||
from ray.core.generated.TablePrefix import TablePrefix
|
||||
from ray.core.generated.TablePubsub import TablePubsub
|
||||
|
||||
from ray.core.generated.ray.protocol.Task import Task
|
||||
|
||||
from ray.core.generated.gcs_pb2 import (
|
||||
ActorCheckpointIdData,
|
||||
ClientTableData,
|
||||
DriverTableData,
|
||||
ErrorTableData,
|
||||
ErrorType,
|
||||
GcsEntry,
|
||||
HeartbeatBatchTableData,
|
||||
HeartbeatTableData,
|
||||
ObjectTableData,
|
||||
ProfileTableData,
|
||||
TablePrefix,
|
||||
TablePubsub,
|
||||
TaskTableData,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ActorCheckpointIdData",
|
||||
"ClientTableData",
|
||||
"DriverTableData",
|
||||
"ErrorTableData",
|
||||
"ErrorType",
|
||||
"GcsEntry",
|
||||
"HeartbeatBatchTableData",
|
||||
"HeartbeatTableData",
|
||||
"Language",
|
||||
"ObjectTableData",
|
||||
"ProfileTableData",
|
||||
"TablePrefix",
|
||||
"TablePubsub",
|
||||
"Task",
|
||||
"TaskTableData",
|
||||
"construct_error_message",
|
||||
]
|
||||
|
||||
@@ -42,13 +43,16 @@ LOG_FILE_CHANNEL = "RAY_LOG_CHANNEL"
|
||||
REPORTER_CHANNEL = "RAY_REPORTER"
|
||||
|
||||
# xray heartbeats
|
||||
XRAY_HEARTBEAT_CHANNEL = str(TablePubsub.HEARTBEAT).encode("ascii")
|
||||
XRAY_HEARTBEAT_BATCH_CHANNEL = str(TablePubsub.HEARTBEAT_BATCH).encode("ascii")
|
||||
XRAY_HEARTBEAT_CHANNEL = str(
|
||||
TablePubsub.Value("HEARTBEAT_PUBSUB")).encode("ascii")
|
||||
XRAY_HEARTBEAT_BATCH_CHANNEL = str(
|
||||
TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB")).encode("ascii")
|
||||
|
||||
# xray driver updates
|
||||
XRAY_DRIVER_CHANNEL = str(TablePubsub.DRIVER).encode("ascii")
|
||||
XRAY_DRIVER_CHANNEL = str(TablePubsub.Value("DRIVER_PUBSUB")).encode("ascii")
|
||||
|
||||
# These prefixes must be kept up-to-date with the TablePrefix enum in gcs.fbs.
|
||||
# These prefixes must be kept up-to-date with the TablePrefix enum in
|
||||
# gcs.proto.
|
||||
# TODO(rkn): We should use scoped enums, in which case we should be able to
|
||||
# just access the flatbuffer generated values.
|
||||
TablePrefix_RAYLET_TASK_string = "RAYLET_TASK"
|
||||
@@ -70,22 +74,9 @@ def construct_error_message(driver_id, error_type, message, timestamp):
|
||||
Returns:
|
||||
The serialized object.
|
||||
"""
|
||||
builder = flatbuffers.Builder(0)
|
||||
driver_offset = builder.CreateString(driver_id.binary())
|
||||
error_type_offset = builder.CreateString(error_type)
|
||||
message_offset = builder.CreateString(message)
|
||||
|
||||
ray.core.generated.ErrorTableData.ErrorTableDataStart(builder)
|
||||
ray.core.generated.ErrorTableData.ErrorTableDataAddDriverId(
|
||||
builder, driver_offset)
|
||||
ray.core.generated.ErrorTableData.ErrorTableDataAddType(
|
||||
builder, error_type_offset)
|
||||
ray.core.generated.ErrorTableData.ErrorTableDataAddErrorMessage(
|
||||
builder, message_offset)
|
||||
ray.core.generated.ErrorTableData.ErrorTableDataAddTimestamp(
|
||||
builder, timestamp)
|
||||
error_data_offset = ray.core.generated.ErrorTableData.ErrorTableDataEnd(
|
||||
builder)
|
||||
builder.Finish(error_data_offset)
|
||||
|
||||
return bytes(builder.Output())
|
||||
data = ErrorTableData()
|
||||
data.driver_id = driver_id.binary()
|
||||
data.type = error_type
|
||||
data.error_message = message
|
||||
data.timestamp = timestamp
|
||||
return data.SerializeToString()
|
||||
|
||||
+15
-18
@@ -101,28 +101,26 @@ class Monitor(object):
|
||||
def xray_heartbeat_batch_handler(self, unused_channel, data):
|
||||
"""Handle an xray heartbeat batch message from Redis."""
|
||||
|
||||
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0)
|
||||
heartbeat_data = gcs_entries.Entries(0)
|
||||
gcs_entries = ray.gcs_utils.GcsEntry.FromString(data)
|
||||
heartbeat_data = gcs_entries.entries[0]
|
||||
|
||||
message = (ray.gcs_utils.HeartbeatBatchTableData.
|
||||
GetRootAsHeartbeatBatchTableData(heartbeat_data, 0))
|
||||
message = ray.gcs_utils.HeartbeatBatchTableData.FromString(
|
||||
heartbeat_data)
|
||||
|
||||
for j in range(message.BatchLength()):
|
||||
heartbeat_message = message.Batch(j)
|
||||
|
||||
num_resources = heartbeat_message.ResourcesTotalLabelLength()
|
||||
for heartbeat_message in message.batch:
|
||||
num_resources = len(heartbeat_message.resources_available_label)
|
||||
static_resources = {}
|
||||
dynamic_resources = {}
|
||||
for i in range(num_resources):
|
||||
dyn = heartbeat_message.ResourcesAvailableLabel(i)
|
||||
static = heartbeat_message.ResourcesTotalLabel(i)
|
||||
dyn = heartbeat_message.resources_available_label[i]
|
||||
static = heartbeat_message.resources_total_label[i]
|
||||
dynamic_resources[dyn] = (
|
||||
heartbeat_message.ResourcesAvailableCapacity(i))
|
||||
heartbeat_message.resources_available_capacity[i])
|
||||
static_resources[static] = (
|
||||
heartbeat_message.ResourcesTotalCapacity(i))
|
||||
heartbeat_message.resources_total_capacity[i])
|
||||
|
||||
# Update the load metrics for this raylet.
|
||||
client_id = ray.utils.binary_to_hex(heartbeat_message.ClientId())
|
||||
client_id = ray.utils.binary_to_hex(heartbeat_message.client_id)
|
||||
ip = self.raylet_id_to_ip_map.get(client_id)
|
||||
if ip:
|
||||
self.load_metrics.update(ip, static_resources,
|
||||
@@ -207,11 +205,10 @@ class Monitor(object):
|
||||
unused_channel: The message channel.
|
||||
data: The message data.
|
||||
"""
|
||||
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0)
|
||||
driver_data = gcs_entries.Entries(0)
|
||||
message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData(
|
||||
driver_data, 0)
|
||||
driver_id = message.DriverId()
|
||||
gcs_entries = ray.gcs_utils.GcsEntry.FromString(data)
|
||||
driver_data = gcs_entries.entries[0]
|
||||
message = ray.gcs_utils.DriverTableData.FromString(driver_data)
|
||||
driver_id = message.driver_id
|
||||
logger.info("Monitor: "
|
||||
"XRay Driver {} has been removed.".format(
|
||||
binary_to_hex(driver_id)))
|
||||
|
||||
+95
-135
@@ -10,11 +10,11 @@ import time
|
||||
|
||||
import ray
|
||||
from ray.function_manager import FunctionDescriptor
|
||||
import ray.gcs_utils
|
||||
|
||||
from ray.ray_constants import ID_SIZE
|
||||
from ray import services
|
||||
from ray.core.generated.EntryType import EntryType
|
||||
from ray import (
|
||||
gcs_utils,
|
||||
services,
|
||||
)
|
||||
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
|
||||
hex_to_binary)
|
||||
|
||||
@@ -31,9 +31,9 @@ def _parse_client_table(redis_client):
|
||||
A list of information about the nodes in the cluster.
|
||||
"""
|
||||
NIL_CLIENT_ID = ray.ObjectID.nil().binary()
|
||||
message = redis_client.execute_command("RAY.TABLE_LOOKUP",
|
||||
ray.gcs_utils.TablePrefix.CLIENT,
|
||||
"", NIL_CLIENT_ID)
|
||||
message = redis_client.execute_command(
|
||||
"RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("CLIENT"), "",
|
||||
NIL_CLIENT_ID)
|
||||
|
||||
# Handle the case where no clients are returned. This should only
|
||||
# occur potentially immediately after the cluster is started.
|
||||
@@ -41,36 +41,31 @@ def _parse_client_table(redis_client):
|
||||
return []
|
||||
|
||||
node_info = {}
|
||||
gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
|
||||
gcs_entry = gcs_utils.GcsEntry.FromString(message)
|
||||
|
||||
ordered_client_ids = []
|
||||
|
||||
# Since GCS entries are append-only, we override so that
|
||||
# only the latest entries are kept.
|
||||
for i in range(gcs_entry.EntriesLength()):
|
||||
client = (ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
|
||||
gcs_entry.Entries(i), 0))
|
||||
for entry in gcs_entry.entries:
|
||||
client = gcs_utils.ClientTableData.FromString(entry)
|
||||
|
||||
resources = {
|
||||
decode(client.ResourcesTotalLabel(i)):
|
||||
client.ResourcesTotalCapacity(i)
|
||||
for i in range(client.ResourcesTotalLabelLength())
|
||||
client.resources_total_label[i]: client.resources_total_capacity[i]
|
||||
for i in range(len(client.resources_total_label))
|
||||
}
|
||||
client_id = ray.utils.binary_to_hex(client.ClientId())
|
||||
client_id = ray.utils.binary_to_hex(client.client_id)
|
||||
|
||||
if client.EntryType() == EntryType.INSERTION:
|
||||
if client.entry_type == gcs_utils.ClientTableData.INSERTION:
|
||||
ordered_client_ids.append(client_id)
|
||||
node_info[client_id] = {
|
||||
"ClientID": client_id,
|
||||
"EntryType": client.EntryType(),
|
||||
"NodeManagerAddress": decode(
|
||||
client.NodeManagerAddress(), allow_none=True),
|
||||
"NodeManagerPort": client.NodeManagerPort(),
|
||||
"ObjectManagerPort": client.ObjectManagerPort(),
|
||||
"ObjectStoreSocketName": decode(
|
||||
client.ObjectStoreSocketName(), allow_none=True),
|
||||
"RayletSocketName": decode(
|
||||
client.RayletSocketName(), allow_none=True),
|
||||
"EntryType": client.entry_type,
|
||||
"NodeManagerAddress": client.node_manager_address,
|
||||
"NodeManagerPort": client.node_manager_port,
|
||||
"ObjectManagerPort": client.object_manager_port,
|
||||
"ObjectStoreSocketName": client.object_store_socket_name,
|
||||
"RayletSocketName": client.raylet_socket_name,
|
||||
"Resources": resources
|
||||
}
|
||||
|
||||
@@ -79,22 +74,23 @@ def _parse_client_table(redis_client):
|
||||
# it cannot have previously been removed.
|
||||
else:
|
||||
assert client_id in node_info, "Client not found!"
|
||||
assert node_info[client_id]["EntryType"] != EntryType.DELETION, (
|
||||
"Unexpected updation of deleted client.")
|
||||
is_deletion = (node_info[client_id]["EntryType"] !=
|
||||
gcs_utils.ClientTableData.DELETION)
|
||||
assert is_deletion, "Unexpected updation of deleted client."
|
||||
res_map = node_info[client_id]["Resources"]
|
||||
if client.EntryType() == EntryType.RES_CREATEUPDATE:
|
||||
if client.entry_type == gcs_utils.ClientTableData.RES_CREATEUPDATE:
|
||||
for res in resources:
|
||||
res_map[res] = resources[res]
|
||||
elif client.EntryType() == EntryType.RES_DELETE:
|
||||
elif client.entry_type == gcs_utils.ClientTableData.RES_DELETE:
|
||||
for res in resources:
|
||||
res_map.pop(res, None)
|
||||
elif client.EntryType() == EntryType.DELETION:
|
||||
elif client.entry_type == gcs_utils.ClientTableData.DELETION:
|
||||
pass # Do nothing with the resmap if client deletion
|
||||
else:
|
||||
raise RuntimeError("Unexpected EntryType {}".format(
|
||||
client.EntryType()))
|
||||
client.entry_type))
|
||||
node_info[client_id]["Resources"] = res_map
|
||||
node_info[client_id]["EntryType"] = client.EntryType()
|
||||
node_info[client_id]["EntryType"] = client.entry_type
|
||||
# NOTE: We return the list comprehension below instead of simply doing
|
||||
# 'list(node_info.values())' in order to have the nodes appear in the order
|
||||
# that they joined the cluster. Python dictionaries do not preserve
|
||||
@@ -244,20 +240,19 @@ class GlobalState(object):
|
||||
|
||||
# Return information about a single object ID.
|
||||
message = self._execute_command(object_id, "RAY.TABLE_LOOKUP",
|
||||
ray.gcs_utils.TablePrefix.OBJECT, "",
|
||||
object_id.binary())
|
||||
gcs_utils.TablePrefix.Value("OBJECT"),
|
||||
"", object_id.binary())
|
||||
if message is None:
|
||||
return {}
|
||||
gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
|
||||
gcs_entry = gcs_utils.GcsEntry.FromString(message)
|
||||
|
||||
assert gcs_entry.EntriesLength() > 0
|
||||
assert len(gcs_entry.entries) > 0
|
||||
|
||||
entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData(
|
||||
gcs_entry.Entries(0), 0)
|
||||
entry = gcs_utils.ObjectTableData.FromString(gcs_entry.entries[0])
|
||||
|
||||
object_info = {
|
||||
"DataSize": entry.ObjectSize(),
|
||||
"Manager": entry.Manager(),
|
||||
"DataSize": entry.object_size,
|
||||
"Manager": entry.manager,
|
||||
}
|
||||
|
||||
return object_info
|
||||
@@ -278,10 +273,9 @@ class GlobalState(object):
|
||||
return self._object_table(object_id)
|
||||
else:
|
||||
# Return the entire object table.
|
||||
object_keys = self._keys(ray.gcs_utils.TablePrefix_OBJECT_string +
|
||||
"*")
|
||||
object_keys = self._keys(gcs_utils.TablePrefix_OBJECT_string + "*")
|
||||
object_ids_binary = {
|
||||
key[len(ray.gcs_utils.TablePrefix_OBJECT_string):]
|
||||
key[len(gcs_utils.TablePrefix_OBJECT_string):]
|
||||
for key in object_keys
|
||||
}
|
||||
|
||||
@@ -301,17 +295,18 @@ class GlobalState(object):
|
||||
A dictionary with information about the task ID in question.
|
||||
"""
|
||||
assert isinstance(task_id, ray.TaskID)
|
||||
message = self._execute_command(task_id, "RAY.TABLE_LOOKUP",
|
||||
ray.gcs_utils.TablePrefix.RAYLET_TASK,
|
||||
"", task_id.binary())
|
||||
message = self._execute_command(
|
||||
task_id, "RAY.TABLE_LOOKUP",
|
||||
gcs_utils.TablePrefix.Value("RAYLET_TASK"), "", task_id.binary())
|
||||
if message is None:
|
||||
return {}
|
||||
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
|
||||
gcs_entries = gcs_utils.GcsEntry.FromString(message)
|
||||
|
||||
assert gcs_entries.EntriesLength() == 1
|
||||
|
||||
task_table_message = ray.gcs_utils.Task.GetRootAsTask(
|
||||
gcs_entries.Entries(0), 0)
|
||||
assert len(gcs_entries.entries) == 1
|
||||
task_table_data = gcs_utils.TaskTableData.FromString(
|
||||
gcs_entries.entries[0])
|
||||
task_table_message = gcs_utils.Task.GetRootAsTask(
|
||||
task_table_data.task, 0)
|
||||
|
||||
execution_spec = task_table_message.TaskExecutionSpec()
|
||||
task_spec = task_table_message.TaskSpecification()
|
||||
@@ -368,9 +363,9 @@ class GlobalState(object):
|
||||
return self._task_table(task_id)
|
||||
else:
|
||||
task_table_keys = self._keys(
|
||||
ray.gcs_utils.TablePrefix_RAYLET_TASK_string + "*")
|
||||
gcs_utils.TablePrefix_RAYLET_TASK_string + "*")
|
||||
task_ids_binary = [
|
||||
key[len(ray.gcs_utils.TablePrefix_RAYLET_TASK_string):]
|
||||
key[len(gcs_utils.TablePrefix_RAYLET_TASK_string):]
|
||||
for key in task_table_keys
|
||||
]
|
||||
|
||||
@@ -380,27 +375,6 @@ class GlobalState(object):
|
||||
ray.TaskID(task_id_binary))
|
||||
return results
|
||||
|
||||
def function_table(self, function_id=None):
|
||||
"""Fetch and parse the function table.
|
||||
|
||||
Returns:
|
||||
A dictionary that maps function IDs to information about the
|
||||
function.
|
||||
"""
|
||||
self._check_connected()
|
||||
function_table_keys = self.redis_client.keys(
|
||||
ray.gcs_utils.FUNCTION_PREFIX + "*")
|
||||
results = {}
|
||||
for key in function_table_keys:
|
||||
info = self.redis_client.hgetall(key)
|
||||
function_info_parsed = {
|
||||
"DriverID": binary_to_hex(info[b"driver_id"]),
|
||||
"Module": decode(info[b"module"]),
|
||||
"Name": decode(info[b"name"])
|
||||
}
|
||||
results[binary_to_hex(info[b"function_id"])] = function_info_parsed
|
||||
return results
|
||||
|
||||
def client_table(self):
|
||||
"""Fetch and parse the Redis DB client table.
|
||||
|
||||
@@ -423,37 +397,32 @@ class GlobalState(object):
|
||||
# TODO(rkn): This method should support limiting the number of log
|
||||
# events and should also support returning a window of events.
|
||||
message = self._execute_command(batch_id, "RAY.TABLE_LOOKUP",
|
||||
ray.gcs_utils.TablePrefix.PROFILE, "",
|
||||
batch_id.binary())
|
||||
gcs_utils.TablePrefix.Value("PROFILE"),
|
||||
"", batch_id.binary())
|
||||
|
||||
if message is None:
|
||||
return []
|
||||
|
||||
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
|
||||
gcs_entries = gcs_utils.GcsEntry.FromString(message)
|
||||
|
||||
profile_events = []
|
||||
for i in range(gcs_entries.EntriesLength()):
|
||||
profile_table_message = (
|
||||
ray.gcs_utils.ProfileTableData.GetRootAsProfileTableData(
|
||||
gcs_entries.Entries(i), 0))
|
||||
for entry in gcs_entries.entries:
|
||||
profile_table_message = gcs_utils.ProfileTableData.FromString(
|
||||
entry)
|
||||
|
||||
component_type = decode(profile_table_message.ComponentType())
|
||||
component_id = binary_to_hex(profile_table_message.ComponentId())
|
||||
node_ip_address = decode(
|
||||
profile_table_message.NodeIpAddress(), allow_none=True)
|
||||
|
||||
for j in range(profile_table_message.ProfileEventsLength()):
|
||||
profile_event_message = profile_table_message.ProfileEvents(j)
|
||||
component_type = profile_table_message.component_type
|
||||
component_id = binary_to_hex(profile_table_message.component_id)
|
||||
node_ip_address = profile_table_message.node_ip_address
|
||||
|
||||
for profile_event_message in profile_table_message.profile_events:
|
||||
profile_event = {
|
||||
"event_type": decode(profile_event_message.EventType()),
|
||||
"event_type": profile_event_message.event_type,
|
||||
"component_id": component_id,
|
||||
"node_ip_address": node_ip_address,
|
||||
"component_type": component_type,
|
||||
"start_time": profile_event_message.StartTime(),
|
||||
"end_time": profile_event_message.EndTime(),
|
||||
"extra_data": json.loads(
|
||||
decode(profile_event_message.ExtraData())),
|
||||
"start_time": profile_event_message.start_time,
|
||||
"end_time": profile_event_message.end_time,
|
||||
"extra_data": json.loads(profile_event_message.extra_data),
|
||||
}
|
||||
|
||||
profile_events.append(profile_event)
|
||||
@@ -462,10 +431,10 @@ class GlobalState(object):
|
||||
|
||||
def profile_table(self):
|
||||
self._check_connected()
|
||||
profile_table_keys = self._keys(
|
||||
ray.gcs_utils.TablePrefix_PROFILE_string + "*")
|
||||
profile_table_keys = self._keys(gcs_utils.TablePrefix_PROFILE_string +
|
||||
"*")
|
||||
batch_identifiers_binary = [
|
||||
key[len(ray.gcs_utils.TablePrefix_PROFILE_string):]
|
||||
key[len(gcs_utils.TablePrefix_PROFILE_string):]
|
||||
for key in profile_table_keys
|
||||
]
|
||||
|
||||
@@ -766,7 +735,7 @@ class GlobalState(object):
|
||||
clients = self.client_table()
|
||||
for client in clients:
|
||||
# Only count resources from latest entries of live clients.
|
||||
if client["EntryType"] != EntryType.DELETION:
|
||||
if client["EntryType"] != gcs_utils.ClientTableData.DELETION:
|
||||
for key, value in client["Resources"].items():
|
||||
resources[key] += value
|
||||
return dict(resources)
|
||||
@@ -776,7 +745,7 @@ class GlobalState(object):
|
||||
return {
|
||||
client["ClientID"]
|
||||
for client in self.client_table()
|
||||
if (client["EntryType"] != EntryType.DELETION)
|
||||
if (client["EntryType"] != gcs_utils.ClientTableData.DELETION)
|
||||
}
|
||||
|
||||
def available_resources(self):
|
||||
@@ -800,7 +769,7 @@ class GlobalState(object):
|
||||
for redis_client in self.redis_clients
|
||||
]
|
||||
for subscribe_client in subscribe_clients:
|
||||
subscribe_client.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL)
|
||||
subscribe_client.subscribe(gcs_utils.XRAY_HEARTBEAT_CHANNEL)
|
||||
|
||||
client_ids = self._live_client_ids()
|
||||
|
||||
@@ -809,24 +778,23 @@ class GlobalState(object):
|
||||
# Parse client message
|
||||
raw_message = subscribe_client.get_message()
|
||||
if (raw_message is None or raw_message["channel"] !=
|
||||
ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL):
|
||||
gcs_utils.XRAY_HEARTBEAT_CHANNEL):
|
||||
continue
|
||||
data = raw_message["data"]
|
||||
gcs_entries = (ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(
|
||||
data, 0))
|
||||
heartbeat_data = gcs_entries.Entries(0)
|
||||
message = (ray.gcs_utils.HeartbeatTableData.
|
||||
GetRootAsHeartbeatTableData(heartbeat_data, 0))
|
||||
gcs_entries = gcs_utils.GcsEntry.FromString(data)
|
||||
heartbeat_data = gcs_entries.entries[0]
|
||||
message = gcs_utils.HeartbeatTableData.FromString(
|
||||
heartbeat_data)
|
||||
# Calculate available resources for this client
|
||||
num_resources = message.ResourcesAvailableLabelLength()
|
||||
num_resources = len(message.resources_available_label)
|
||||
dynamic_resources = {}
|
||||
for i in range(num_resources):
|
||||
resource_id = decode(message.ResourcesAvailableLabel(i))
|
||||
resource_id = message.resources_available_label[i]
|
||||
dynamic_resources[resource_id] = (
|
||||
message.ResourcesAvailableCapacity(i))
|
||||
message.resources_available_capacity[i])
|
||||
|
||||
# Update available resources for this client
|
||||
client_id = ray.utils.binary_to_hex(message.ClientId())
|
||||
client_id = ray.utils.binary_to_hex(message.client_id)
|
||||
available_resources_by_id[client_id] = dynamic_resources
|
||||
|
||||
# Update clients in cluster
|
||||
@@ -860,23 +828,22 @@ class GlobalState(object):
|
||||
"""
|
||||
assert isinstance(driver_id, ray.DriverID)
|
||||
message = self.redis_client.execute_command(
|
||||
"RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.ERROR_INFO, "",
|
||||
"RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("ERROR_INFO"), "",
|
||||
driver_id.binary())
|
||||
|
||||
# If there are no errors, return early.
|
||||
if message is None:
|
||||
return []
|
||||
|
||||
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
|
||||
gcs_entries = gcs_utils.GcsEntry.FromString(message)
|
||||
error_messages = []
|
||||
for i in range(gcs_entries.EntriesLength()):
|
||||
error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData(
|
||||
gcs_entries.Entries(i), 0)
|
||||
assert driver_id.binary() == error_data.DriverId()
|
||||
for entry in gcs_entries.entries:
|
||||
error_data = gcs_utils.ErrorTableData.FromString(entry)
|
||||
assert driver_id.binary() == error_data.driver_id
|
||||
error_message = {
|
||||
"type": decode(error_data.Type()),
|
||||
"message": decode(error_data.ErrorMessage()),
|
||||
"timestamp": error_data.Timestamp(),
|
||||
"type": error_data.type,
|
||||
"message": error_data.error_message,
|
||||
"timestamp": error_data.timestamp,
|
||||
}
|
||||
error_messages.append(error_message)
|
||||
return error_messages
|
||||
@@ -899,9 +866,9 @@ class GlobalState(object):
|
||||
return self._error_messages(driver_id)
|
||||
|
||||
error_table_keys = self.redis_client.keys(
|
||||
ray.gcs_utils.TablePrefix_ERROR_INFO_string + "*")
|
||||
gcs_utils.TablePrefix_ERROR_INFO_string + "*")
|
||||
driver_ids = [
|
||||
key[len(ray.gcs_utils.TablePrefix_ERROR_INFO_string):]
|
||||
key[len(gcs_utils.TablePrefix_ERROR_INFO_string):]
|
||||
for key in error_table_keys
|
||||
]
|
||||
|
||||
@@ -923,30 +890,23 @@ class GlobalState(object):
|
||||
message = self._execute_command(
|
||||
actor_id,
|
||||
"RAY.TABLE_LOOKUP",
|
||||
ray.gcs_utils.TablePrefix.ACTOR_CHECKPOINT_ID,
|
||||
gcs_utils.TablePrefix.Value("ACTOR_CHECKPOINT_ID"),
|
||||
"",
|
||||
actor_id.binary(),
|
||||
)
|
||||
if message is None:
|
||||
return None
|
||||
gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
|
||||
entry = (
|
||||
ray.gcs_utils.ActorCheckpointIdData.GetRootAsActorCheckpointIdData(
|
||||
gcs_entry.Entries(0), 0))
|
||||
checkpoint_ids_str = entry.CheckpointIds()
|
||||
num_checkpoints = len(checkpoint_ids_str) // ID_SIZE
|
||||
assert len(checkpoint_ids_str) % ID_SIZE == 0
|
||||
gcs_entry = gcs_utils.GcsEntry.FromString(message)
|
||||
entry = gcs_utils.ActorCheckpointIdData.FromString(
|
||||
gcs_entry.entries[0])
|
||||
checkpoint_ids = [
|
||||
ray.ActorCheckpointID(
|
||||
checkpoint_ids_str[(i * ID_SIZE):((i + 1) * ID_SIZE)])
|
||||
for i in range(num_checkpoints)
|
||||
ray.ActorCheckpointID(checkpoint_id)
|
||||
for checkpoint_id in entry.checkpoint_ids
|
||||
]
|
||||
return {
|
||||
"ActorID": ray.utils.binary_to_hex(entry.ActorId()),
|
||||
"ActorID": ray.utils.binary_to_hex(entry.actor_id),
|
||||
"CheckpointIds": checkpoint_ids,
|
||||
"Timestamps": [
|
||||
entry.Timestamps(i) for i in range(num_checkpoints)
|
||||
],
|
||||
"Timestamps": list(entry.timestamps),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import time
|
||||
import redis
|
||||
|
||||
import ray
|
||||
from ray.core.generated.EntryType import EntryType
|
||||
from ray.gcs_utils import ClientTableData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -177,7 +177,7 @@ class Cluster(object):
|
||||
clients = ray.state._parse_client_table(redis_client)
|
||||
live_clients = [
|
||||
client for client in clients
|
||||
if client["EntryType"] == EntryType.INSERTION
|
||||
if client["EntryType"] == ClientTableData.INSERTION
|
||||
]
|
||||
|
||||
expected = len(self.list_all_nodes())
|
||||
|
||||
@@ -2736,15 +2736,17 @@ def test_duplicate_error_messages(shutdown_only):
|
||||
|
||||
r = ray.worker.global_worker.redis_client
|
||||
|
||||
r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO,
|
||||
ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.binary(),
|
||||
error_data)
|
||||
r.execute_command("RAY.TABLE_APPEND",
|
||||
ray.gcs_utils.TablePrefix.Value("ERROR_INFO"),
|
||||
ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"),
|
||||
driver_id.binary(), error_data)
|
||||
|
||||
# Before https://github.com/ray-project/ray/pull/3316 this would
|
||||
# give an error
|
||||
r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO,
|
||||
ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.binary(),
|
||||
error_data)
|
||||
r.execute_command("RAY.TABLE_APPEND",
|
||||
ray.gcs_utils.TablePrefix.Value("ERROR_INFO"),
|
||||
ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"),
|
||||
driver_id.binary(), error_data)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
|
||||
@@ -493,8 +493,9 @@ def test_warning_monitor_died(shutdown_only):
|
||||
malformed_message = "asdf"
|
||||
redis_client = ray.worker.global_worker.redis_client
|
||||
redis_client.execute_command(
|
||||
"RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.HEARTBEAT_BATCH,
|
||||
ray.gcs_utils.TablePubsub.HEARTBEAT_BATCH, fake_id, malformed_message)
|
||||
"RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.Value("HEARTBEAT_BATCH"),
|
||||
ray.gcs_utils.TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB"), fake_id,
|
||||
malformed_message)
|
||||
|
||||
wait_for_errors(ray_constants.MONITOR_DIED_ERROR, 1)
|
||||
|
||||
|
||||
+4
-4
@@ -93,10 +93,10 @@ def push_error_to_driver_through_redis(redis_client,
|
||||
# of through the raylet.
|
||||
error_data = ray.gcs_utils.construct_error_message(driver_id, error_type,
|
||||
message, time.time())
|
||||
redis_client.execute_command("RAY.TABLE_APPEND",
|
||||
ray.gcs_utils.TablePrefix.ERROR_INFO,
|
||||
ray.gcs_utils.TablePubsub.ERROR_INFO,
|
||||
driver_id.binary(), error_data)
|
||||
redis_client.execute_command(
|
||||
"RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.Value("ERROR_INFO"),
|
||||
ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"),
|
||||
driver_id.binary(), error_data)
|
||||
|
||||
|
||||
def is_cython(obj):
|
||||
|
||||
+20
-20
@@ -47,7 +47,7 @@ from ray import (
|
||||
from ray import import_thread
|
||||
from ray import profiling
|
||||
|
||||
from ray.core.generated.ErrorType import ErrorType
|
||||
from ray.gcs_utils import ErrorType
|
||||
from ray.exceptions import (
|
||||
RayActorError,
|
||||
RayError,
|
||||
@@ -461,11 +461,11 @@ class Worker(object):
|
||||
# Otherwise, return an exception object based on
|
||||
# the error type.
|
||||
error_type = int(metadata)
|
||||
if error_type == ErrorType.WORKER_DIED:
|
||||
if error_type == ErrorType.Value("WORKER_DIED"):
|
||||
return RayWorkerError()
|
||||
elif error_type == ErrorType.ACTOR_DIED:
|
||||
elif error_type == ErrorType.Value("ACTOR_DIED"):
|
||||
return RayActorError()
|
||||
elif error_type == ErrorType.OBJECT_UNRECONSTRUCTABLE:
|
||||
elif error_type == ErrorType.Value("OBJECT_UNRECONSTRUCTABLE"):
|
||||
return UnreconstructableError(ray.ObjectID(object_id.binary()))
|
||||
else:
|
||||
assert False, "Unrecognized error type " + str(error_type)
|
||||
@@ -1637,7 +1637,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
|
||||
# Really we should just subscribe to the errors for this specific job.
|
||||
# However, currently all errors seem to be published on the same channel.
|
||||
error_pubsub_channel = str(
|
||||
ray.gcs_utils.TablePubsub.ERROR_INFO).encode("ascii")
|
||||
ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB")).encode("ascii")
|
||||
worker.error_message_pubsub_client.subscribe(error_pubsub_channel)
|
||||
# worker.error_message_pubsub_client.psubscribe("*")
|
||||
|
||||
@@ -1656,21 +1656,19 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
|
||||
if msg is None:
|
||||
threads_stopped.wait(timeout=0.01)
|
||||
continue
|
||||
gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(
|
||||
msg["data"], 0)
|
||||
assert gcs_entry.EntriesLength() == 1
|
||||
error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData(
|
||||
gcs_entry.Entries(0), 0)
|
||||
driver_id = error_data.DriverId()
|
||||
gcs_entry = ray.gcs_utils.GcsEntry.FromString(msg["data"])
|
||||
assert len(gcs_entry.entries) == 1
|
||||
error_data = ray.gcs_utils.ErrorTableData.FromString(
|
||||
gcs_entry.entries[0])
|
||||
driver_id = error_data.driver_id
|
||||
if driver_id not in [
|
||||
worker.task_driver_id.binary(),
|
||||
DriverID.nil().binary()
|
||||
]:
|
||||
continue
|
||||
|
||||
error_message = ray.utils.decode(error_data.ErrorMessage())
|
||||
if (ray.utils.decode(
|
||||
error_data.Type()) == ray_constants.TASK_PUSH_ERROR):
|
||||
error_message = error_data.error_message
|
||||
if (error_data.type == ray_constants.TASK_PUSH_ERROR):
|
||||
# Delay it a bit to see if we can suppress it
|
||||
task_error_queue.put((error_message, time.time()))
|
||||
else:
|
||||
@@ -1878,14 +1876,16 @@ def connect(node,
|
||||
{}, # resource_map.
|
||||
{}, # placement_resource_map.
|
||||
)
|
||||
task_table_data = ray.gcs_utils.TaskTableData()
|
||||
task_table_data.task = driver_task._serialized_raylet_task()
|
||||
|
||||
# Add the driver task to the task table.
|
||||
ray.state.state._execute_command(driver_task.task_id(),
|
||||
"RAY.TABLE_ADD",
|
||||
ray.gcs_utils.TablePrefix.RAYLET_TASK,
|
||||
ray.gcs_utils.TablePubsub.RAYLET_TASK,
|
||||
driver_task.task_id().binary(),
|
||||
driver_task._serialized_raylet_task())
|
||||
ray.state.state._execute_command(
|
||||
driver_task.task_id(), "RAY.TABLE_ADD",
|
||||
ray.gcs_utils.TablePrefix.Value("RAYLET_TASK"),
|
||||
ray.gcs_utils.TablePubsub.Value("RAYLET_TASK_PUBSUB"),
|
||||
driver_task.task_id().binary(),
|
||||
task_table_data.SerializeToString())
|
||||
|
||||
# Set the driver's current task ID to the task ID assigned to the
|
||||
# driver task.
|
||||
|
||||
@@ -150,6 +150,7 @@ requires = [
|
||||
"six >= 1.0.0",
|
||||
"flatbuffers",
|
||||
"faulthandler;python_version<'3.3'",
|
||||
"protobuf",
|
||||
]
|
||||
|
||||
setup(
|
||||
|
||||
Reference in New Issue
Block a user