[gRPC] Migrate gcs data structures to protobuf (#5024)

This commit is contained in:
Hao Chen
2019-06-26 05:31:19 +08:00
committed by Philipp Moritz
parent bd8aceb896
commit 0131353d42
52 changed files with 1465 additions and 1642 deletions
+31 -40
View File
@@ -2,38 +2,39 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import flatbuffers
import ray.core.generated.ErrorTableData
from ray.core.generated.ActorCheckpointIdData import ActorCheckpointIdData
from ray.core.generated.ClientTableData import ClientTableData
from ray.core.generated.DriverTableData import DriverTableData
from ray.core.generated.ErrorTableData import ErrorTableData
from ray.core.generated.GcsEntry import GcsEntry
from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData
from ray.core.generated.HeartbeatTableData import HeartbeatTableData
from ray.core.generated.Language import Language
from ray.core.generated.ObjectTableData import ObjectTableData
from ray.core.generated.ProfileTableData import ProfileTableData
from ray.core.generated.TablePrefix import TablePrefix
from ray.core.generated.TablePubsub import TablePubsub
from ray.core.generated.ray.protocol.Task import Task
from ray.core.generated.gcs_pb2 import (
ActorCheckpointIdData,
ClientTableData,
DriverTableData,
ErrorTableData,
ErrorType,
GcsEntry,
HeartbeatBatchTableData,
HeartbeatTableData,
ObjectTableData,
ProfileTableData,
TablePrefix,
TablePubsub,
TaskTableData,
)
__all__ = [
"ActorCheckpointIdData",
"ClientTableData",
"DriverTableData",
"ErrorTableData",
"ErrorType",
"GcsEntry",
"HeartbeatBatchTableData",
"HeartbeatTableData",
"Language",
"ObjectTableData",
"ProfileTableData",
"TablePrefix",
"TablePubsub",
"Task",
"TaskTableData",
"construct_error_message",
]
@@ -42,13 +43,16 @@ LOG_FILE_CHANNEL = "RAY_LOG_CHANNEL"
REPORTER_CHANNEL = "RAY_REPORTER"
# xray heartbeats
XRAY_HEARTBEAT_CHANNEL = str(TablePubsub.HEARTBEAT).encode("ascii")
XRAY_HEARTBEAT_BATCH_CHANNEL = str(TablePubsub.HEARTBEAT_BATCH).encode("ascii")
XRAY_HEARTBEAT_CHANNEL = str(
TablePubsub.Value("HEARTBEAT_PUBSUB")).encode("ascii")
XRAY_HEARTBEAT_BATCH_CHANNEL = str(
TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB")).encode("ascii")
# xray driver updates
XRAY_DRIVER_CHANNEL = str(TablePubsub.DRIVER).encode("ascii")
XRAY_DRIVER_CHANNEL = str(TablePubsub.Value("DRIVER_PUBSUB")).encode("ascii")
# These prefixes must be kept up-to-date with the TablePrefix enum in gcs.fbs.
# These prefixes must be kept up-to-date with the TablePrefix enum in
# gcs.proto.
# TODO(rkn): We should use scoped enums, in which case we should be able to
# just access the flatbuffer generated values.
TablePrefix_RAYLET_TASK_string = "RAYLET_TASK"
@@ -70,22 +74,9 @@ def construct_error_message(driver_id, error_type, message, timestamp):
Returns:
The serialized object.
"""
builder = flatbuffers.Builder(0)
driver_offset = builder.CreateString(driver_id.binary())
error_type_offset = builder.CreateString(error_type)
message_offset = builder.CreateString(message)
ray.core.generated.ErrorTableData.ErrorTableDataStart(builder)
ray.core.generated.ErrorTableData.ErrorTableDataAddDriverId(
builder, driver_offset)
ray.core.generated.ErrorTableData.ErrorTableDataAddType(
builder, error_type_offset)
ray.core.generated.ErrorTableData.ErrorTableDataAddErrorMessage(
builder, message_offset)
ray.core.generated.ErrorTableData.ErrorTableDataAddTimestamp(
builder, timestamp)
error_data_offset = ray.core.generated.ErrorTableData.ErrorTableDataEnd(
builder)
builder.Finish(error_data_offset)
return bytes(builder.Output())
data = ErrorTableData()
data.driver_id = driver_id.binary()
data.type = error_type
data.error_message = message
data.timestamp = timestamp
return data.SerializeToString()
+15 -18
View File
@@ -101,28 +101,26 @@ class Monitor(object):
def xray_heartbeat_batch_handler(self, unused_channel, data):
"""Handle an xray heartbeat batch message from Redis."""
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0)
heartbeat_data = gcs_entries.Entries(0)
gcs_entries = ray.gcs_utils.GcsEntry.FromString(data)
heartbeat_data = gcs_entries.entries[0]
message = (ray.gcs_utils.HeartbeatBatchTableData.
GetRootAsHeartbeatBatchTableData(heartbeat_data, 0))
message = ray.gcs_utils.HeartbeatBatchTableData.FromString(
heartbeat_data)
for j in range(message.BatchLength()):
heartbeat_message = message.Batch(j)
num_resources = heartbeat_message.ResourcesTotalLabelLength()
for heartbeat_message in message.batch:
num_resources = len(heartbeat_message.resources_available_label)
static_resources = {}
dynamic_resources = {}
for i in range(num_resources):
dyn = heartbeat_message.ResourcesAvailableLabel(i)
static = heartbeat_message.ResourcesTotalLabel(i)
dyn = heartbeat_message.resources_available_label[i]
static = heartbeat_message.resources_total_label[i]
dynamic_resources[dyn] = (
heartbeat_message.ResourcesAvailableCapacity(i))
heartbeat_message.resources_available_capacity[i])
static_resources[static] = (
heartbeat_message.ResourcesTotalCapacity(i))
heartbeat_message.resources_total_capacity[i])
# Update the load metrics for this raylet.
client_id = ray.utils.binary_to_hex(heartbeat_message.ClientId())
client_id = ray.utils.binary_to_hex(heartbeat_message.client_id)
ip = self.raylet_id_to_ip_map.get(client_id)
if ip:
self.load_metrics.update(ip, static_resources,
@@ -207,11 +205,10 @@ class Monitor(object):
unused_channel: The message channel.
data: The message data.
"""
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0)
driver_data = gcs_entries.Entries(0)
message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData(
driver_data, 0)
driver_id = message.DriverId()
gcs_entries = ray.gcs_utils.GcsEntry.FromString(data)
driver_data = gcs_entries.entries[0]
message = ray.gcs_utils.DriverTableData.FromString(driver_data)
driver_id = message.driver_id
logger.info("Monitor: "
"XRay Driver {} has been removed.".format(
binary_to_hex(driver_id)))
+95 -135
View File
@@ -10,11 +10,11 @@ import time
import ray
from ray.function_manager import FunctionDescriptor
import ray.gcs_utils
from ray.ray_constants import ID_SIZE
from ray import services
from ray.core.generated.EntryType import EntryType
from ray import (
gcs_utils,
services,
)
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
hex_to_binary)
@@ -31,9 +31,9 @@ def _parse_client_table(redis_client):
A list of information about the nodes in the cluster.
"""
NIL_CLIENT_ID = ray.ObjectID.nil().binary()
message = redis_client.execute_command("RAY.TABLE_LOOKUP",
ray.gcs_utils.TablePrefix.CLIENT,
"", NIL_CLIENT_ID)
message = redis_client.execute_command(
"RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("CLIENT"), "",
NIL_CLIENT_ID)
# Handle the case where no clients are returned. This should only
# occur potentially immediately after the cluster is started.
@@ -41,36 +41,31 @@ def _parse_client_table(redis_client):
return []
node_info = {}
gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
gcs_entry = gcs_utils.GcsEntry.FromString(message)
ordered_client_ids = []
# Since GCS entries are append-only, we override so that
# only the latest entries are kept.
for i in range(gcs_entry.EntriesLength()):
client = (ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
gcs_entry.Entries(i), 0))
for entry in gcs_entry.entries:
client = gcs_utils.ClientTableData.FromString(entry)
resources = {
decode(client.ResourcesTotalLabel(i)):
client.ResourcesTotalCapacity(i)
for i in range(client.ResourcesTotalLabelLength())
client.resources_total_label[i]: client.resources_total_capacity[i]
for i in range(len(client.resources_total_label))
}
client_id = ray.utils.binary_to_hex(client.ClientId())
client_id = ray.utils.binary_to_hex(client.client_id)
if client.EntryType() == EntryType.INSERTION:
if client.entry_type == gcs_utils.ClientTableData.INSERTION:
ordered_client_ids.append(client_id)
node_info[client_id] = {
"ClientID": client_id,
"EntryType": client.EntryType(),
"NodeManagerAddress": decode(
client.NodeManagerAddress(), allow_none=True),
"NodeManagerPort": client.NodeManagerPort(),
"ObjectManagerPort": client.ObjectManagerPort(),
"ObjectStoreSocketName": decode(
client.ObjectStoreSocketName(), allow_none=True),
"RayletSocketName": decode(
client.RayletSocketName(), allow_none=True),
"EntryType": client.entry_type,
"NodeManagerAddress": client.node_manager_address,
"NodeManagerPort": client.node_manager_port,
"ObjectManagerPort": client.object_manager_port,
"ObjectStoreSocketName": client.object_store_socket_name,
"RayletSocketName": client.raylet_socket_name,
"Resources": resources
}
@@ -79,22 +74,23 @@ def _parse_client_table(redis_client):
# it cannot have previously been removed.
else:
assert client_id in node_info, "Client not found!"
assert node_info[client_id]["EntryType"] != EntryType.DELETION, (
"Unexpected updation of deleted client.")
is_deletion = (node_info[client_id]["EntryType"] !=
gcs_utils.ClientTableData.DELETION)
assert is_deletion, "Unexpected updation of deleted client."
res_map = node_info[client_id]["Resources"]
if client.EntryType() == EntryType.RES_CREATEUPDATE:
if client.entry_type == gcs_utils.ClientTableData.RES_CREATEUPDATE:
for res in resources:
res_map[res] = resources[res]
elif client.EntryType() == EntryType.RES_DELETE:
elif client.entry_type == gcs_utils.ClientTableData.RES_DELETE:
for res in resources:
res_map.pop(res, None)
elif client.EntryType() == EntryType.DELETION:
elif client.entry_type == gcs_utils.ClientTableData.DELETION:
pass # Do nothing with the resmap if client deletion
else:
raise RuntimeError("Unexpected EntryType {}".format(
client.EntryType()))
client.entry_type))
node_info[client_id]["Resources"] = res_map
node_info[client_id]["EntryType"] = client.EntryType()
node_info[client_id]["EntryType"] = client.entry_type
# NOTE: We return the list comprehension below instead of simply doing
# 'list(node_info.values())' in order to have the nodes appear in the order
# that they joined the cluster. Python dictionaries do not preserve
@@ -244,20 +240,19 @@ class GlobalState(object):
# Return information about a single object ID.
message = self._execute_command(object_id, "RAY.TABLE_LOOKUP",
ray.gcs_utils.TablePrefix.OBJECT, "",
object_id.binary())
gcs_utils.TablePrefix.Value("OBJECT"),
"", object_id.binary())
if message is None:
return {}
gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
gcs_entry = gcs_utils.GcsEntry.FromString(message)
assert gcs_entry.EntriesLength() > 0
assert len(gcs_entry.entries) > 0
entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData(
gcs_entry.Entries(0), 0)
entry = gcs_utils.ObjectTableData.FromString(gcs_entry.entries[0])
object_info = {
"DataSize": entry.ObjectSize(),
"Manager": entry.Manager(),
"DataSize": entry.object_size,
"Manager": entry.manager,
}
return object_info
@@ -278,10 +273,9 @@ class GlobalState(object):
return self._object_table(object_id)
else:
# Return the entire object table.
object_keys = self._keys(ray.gcs_utils.TablePrefix_OBJECT_string +
"*")
object_keys = self._keys(gcs_utils.TablePrefix_OBJECT_string + "*")
object_ids_binary = {
key[len(ray.gcs_utils.TablePrefix_OBJECT_string):]
key[len(gcs_utils.TablePrefix_OBJECT_string):]
for key in object_keys
}
@@ -301,17 +295,18 @@ class GlobalState(object):
A dictionary with information about the task ID in question.
"""
assert isinstance(task_id, ray.TaskID)
message = self._execute_command(task_id, "RAY.TABLE_LOOKUP",
ray.gcs_utils.TablePrefix.RAYLET_TASK,
"", task_id.binary())
message = self._execute_command(
task_id, "RAY.TABLE_LOOKUP",
gcs_utils.TablePrefix.Value("RAYLET_TASK"), "", task_id.binary())
if message is None:
return {}
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
gcs_entries = gcs_utils.GcsEntry.FromString(message)
assert gcs_entries.EntriesLength() == 1
task_table_message = ray.gcs_utils.Task.GetRootAsTask(
gcs_entries.Entries(0), 0)
assert len(gcs_entries.entries) == 1
task_table_data = gcs_utils.TaskTableData.FromString(
gcs_entries.entries[0])
task_table_message = gcs_utils.Task.GetRootAsTask(
task_table_data.task, 0)
execution_spec = task_table_message.TaskExecutionSpec()
task_spec = task_table_message.TaskSpecification()
@@ -368,9 +363,9 @@ class GlobalState(object):
return self._task_table(task_id)
else:
task_table_keys = self._keys(
ray.gcs_utils.TablePrefix_RAYLET_TASK_string + "*")
gcs_utils.TablePrefix_RAYLET_TASK_string + "*")
task_ids_binary = [
key[len(ray.gcs_utils.TablePrefix_RAYLET_TASK_string):]
key[len(gcs_utils.TablePrefix_RAYLET_TASK_string):]
for key in task_table_keys
]
@@ -380,27 +375,6 @@ class GlobalState(object):
ray.TaskID(task_id_binary))
return results
def function_table(self, function_id=None):
"""Fetch and parse the function table.
Returns:
A dictionary that maps function IDs to information about the
function.
"""
self._check_connected()
function_table_keys = self.redis_client.keys(
ray.gcs_utils.FUNCTION_PREFIX + "*")
results = {}
for key in function_table_keys:
info = self.redis_client.hgetall(key)
function_info_parsed = {
"DriverID": binary_to_hex(info[b"driver_id"]),
"Module": decode(info[b"module"]),
"Name": decode(info[b"name"])
}
results[binary_to_hex(info[b"function_id"])] = function_info_parsed
return results
def client_table(self):
"""Fetch and parse the Redis DB client table.
@@ -423,37 +397,32 @@ class GlobalState(object):
# TODO(rkn): This method should support limiting the number of log
# events and should also support returning a window of events.
message = self._execute_command(batch_id, "RAY.TABLE_LOOKUP",
ray.gcs_utils.TablePrefix.PROFILE, "",
batch_id.binary())
gcs_utils.TablePrefix.Value("PROFILE"),
"", batch_id.binary())
if message is None:
return []
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
gcs_entries = gcs_utils.GcsEntry.FromString(message)
profile_events = []
for i in range(gcs_entries.EntriesLength()):
profile_table_message = (
ray.gcs_utils.ProfileTableData.GetRootAsProfileTableData(
gcs_entries.Entries(i), 0))
for entry in gcs_entries.entries:
profile_table_message = gcs_utils.ProfileTableData.FromString(
entry)
component_type = decode(profile_table_message.ComponentType())
component_id = binary_to_hex(profile_table_message.ComponentId())
node_ip_address = decode(
profile_table_message.NodeIpAddress(), allow_none=True)
for j in range(profile_table_message.ProfileEventsLength()):
profile_event_message = profile_table_message.ProfileEvents(j)
component_type = profile_table_message.component_type
component_id = binary_to_hex(profile_table_message.component_id)
node_ip_address = profile_table_message.node_ip_address
for profile_event_message in profile_table_message.profile_events:
profile_event = {
"event_type": decode(profile_event_message.EventType()),
"event_type": profile_event_message.event_type,
"component_id": component_id,
"node_ip_address": node_ip_address,
"component_type": component_type,
"start_time": profile_event_message.StartTime(),
"end_time": profile_event_message.EndTime(),
"extra_data": json.loads(
decode(profile_event_message.ExtraData())),
"start_time": profile_event_message.start_time,
"end_time": profile_event_message.end_time,
"extra_data": json.loads(profile_event_message.extra_data),
}
profile_events.append(profile_event)
@@ -462,10 +431,10 @@ class GlobalState(object):
def profile_table(self):
self._check_connected()
profile_table_keys = self._keys(
ray.gcs_utils.TablePrefix_PROFILE_string + "*")
profile_table_keys = self._keys(gcs_utils.TablePrefix_PROFILE_string +
"*")
batch_identifiers_binary = [
key[len(ray.gcs_utils.TablePrefix_PROFILE_string):]
key[len(gcs_utils.TablePrefix_PROFILE_string):]
for key in profile_table_keys
]
@@ -766,7 +735,7 @@ class GlobalState(object):
clients = self.client_table()
for client in clients:
# Only count resources from latest entries of live clients.
if client["EntryType"] != EntryType.DELETION:
if client["EntryType"] != gcs_utils.ClientTableData.DELETION:
for key, value in client["Resources"].items():
resources[key] += value
return dict(resources)
@@ -776,7 +745,7 @@ class GlobalState(object):
return {
client["ClientID"]
for client in self.client_table()
if (client["EntryType"] != EntryType.DELETION)
if (client["EntryType"] != gcs_utils.ClientTableData.DELETION)
}
def available_resources(self):
@@ -800,7 +769,7 @@ class GlobalState(object):
for redis_client in self.redis_clients
]
for subscribe_client in subscribe_clients:
subscribe_client.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL)
subscribe_client.subscribe(gcs_utils.XRAY_HEARTBEAT_CHANNEL)
client_ids = self._live_client_ids()
@@ -809,24 +778,23 @@ class GlobalState(object):
# Parse client message
raw_message = subscribe_client.get_message()
if (raw_message is None or raw_message["channel"] !=
ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL):
gcs_utils.XRAY_HEARTBEAT_CHANNEL):
continue
data = raw_message["data"]
gcs_entries = (ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(
data, 0))
heartbeat_data = gcs_entries.Entries(0)
message = (ray.gcs_utils.HeartbeatTableData.
GetRootAsHeartbeatTableData(heartbeat_data, 0))
gcs_entries = gcs_utils.GcsEntry.FromString(data)
heartbeat_data = gcs_entries.entries[0]
message = gcs_utils.HeartbeatTableData.FromString(
heartbeat_data)
# Calculate available resources for this client
num_resources = message.ResourcesAvailableLabelLength()
num_resources = len(message.resources_available_label)
dynamic_resources = {}
for i in range(num_resources):
resource_id = decode(message.ResourcesAvailableLabel(i))
resource_id = message.resources_available_label[i]
dynamic_resources[resource_id] = (
message.ResourcesAvailableCapacity(i))
message.resources_available_capacity[i])
# Update available resources for this client
client_id = ray.utils.binary_to_hex(message.ClientId())
client_id = ray.utils.binary_to_hex(message.client_id)
available_resources_by_id[client_id] = dynamic_resources
# Update clients in cluster
@@ -860,23 +828,22 @@ class GlobalState(object):
"""
assert isinstance(driver_id, ray.DriverID)
message = self.redis_client.execute_command(
"RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.ERROR_INFO, "",
"RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("ERROR_INFO"), "",
driver_id.binary())
# If there are no errors, return early.
if message is None:
return []
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
gcs_entries = gcs_utils.GcsEntry.FromString(message)
error_messages = []
for i in range(gcs_entries.EntriesLength()):
error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData(
gcs_entries.Entries(i), 0)
assert driver_id.binary() == error_data.DriverId()
for entry in gcs_entries.entries:
error_data = gcs_utils.ErrorTableData.FromString(entry)
assert driver_id.binary() == error_data.driver_id
error_message = {
"type": decode(error_data.Type()),
"message": decode(error_data.ErrorMessage()),
"timestamp": error_data.Timestamp(),
"type": error_data.type,
"message": error_data.error_message,
"timestamp": error_data.timestamp,
}
error_messages.append(error_message)
return error_messages
@@ -899,9 +866,9 @@ class GlobalState(object):
return self._error_messages(driver_id)
error_table_keys = self.redis_client.keys(
ray.gcs_utils.TablePrefix_ERROR_INFO_string + "*")
gcs_utils.TablePrefix_ERROR_INFO_string + "*")
driver_ids = [
key[len(ray.gcs_utils.TablePrefix_ERROR_INFO_string):]
key[len(gcs_utils.TablePrefix_ERROR_INFO_string):]
for key in error_table_keys
]
@@ -923,30 +890,23 @@ class GlobalState(object):
message = self._execute_command(
actor_id,
"RAY.TABLE_LOOKUP",
ray.gcs_utils.TablePrefix.ACTOR_CHECKPOINT_ID,
gcs_utils.TablePrefix.Value("ACTOR_CHECKPOINT_ID"),
"",
actor_id.binary(),
)
if message is None:
return None
gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
entry = (
ray.gcs_utils.ActorCheckpointIdData.GetRootAsActorCheckpointIdData(
gcs_entry.Entries(0), 0))
checkpoint_ids_str = entry.CheckpointIds()
num_checkpoints = len(checkpoint_ids_str) // ID_SIZE
assert len(checkpoint_ids_str) % ID_SIZE == 0
gcs_entry = gcs_utils.GcsEntry.FromString(message)
entry = gcs_utils.ActorCheckpointIdData.FromString(
gcs_entry.entries[0])
checkpoint_ids = [
ray.ActorCheckpointID(
checkpoint_ids_str[(i * ID_SIZE):((i + 1) * ID_SIZE)])
for i in range(num_checkpoints)
ray.ActorCheckpointID(checkpoint_id)
for checkpoint_id in entry.checkpoint_ids
]
return {
"ActorID": ray.utils.binary_to_hex(entry.ActorId()),
"ActorID": ray.utils.binary_to_hex(entry.actor_id),
"CheckpointIds": checkpoint_ids,
"Timestamps": [
entry.Timestamps(i) for i in range(num_checkpoints)
],
"Timestamps": list(entry.timestamps),
}
+2 -2
View File
@@ -8,7 +8,7 @@ import time
import redis
import ray
from ray.core.generated.EntryType import EntryType
from ray.gcs_utils import ClientTableData
logger = logging.getLogger(__name__)
@@ -177,7 +177,7 @@ class Cluster(object):
clients = ray.state._parse_client_table(redis_client)
live_clients = [
client for client in clients
if client["EntryType"] == EntryType.INSERTION
if client["EntryType"] == ClientTableData.INSERTION
]
expected = len(self.list_all_nodes())
+8 -6
View File
@@ -2736,15 +2736,17 @@ def test_duplicate_error_messages(shutdown_only):
r = ray.worker.global_worker.redis_client
r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO,
ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.binary(),
error_data)
r.execute_command("RAY.TABLE_APPEND",
ray.gcs_utils.TablePrefix.Value("ERROR_INFO"),
ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"),
driver_id.binary(), error_data)
# Before https://github.com/ray-project/ray/pull/3316 this would
# give an error
r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO,
ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.binary(),
error_data)
r.execute_command("RAY.TABLE_APPEND",
ray.gcs_utils.TablePrefix.Value("ERROR_INFO"),
ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"),
driver_id.binary(), error_data)
@pytest.mark.skipif(
+3 -2
View File
@@ -493,8 +493,9 @@ def test_warning_monitor_died(shutdown_only):
malformed_message = "asdf"
redis_client = ray.worker.global_worker.redis_client
redis_client.execute_command(
"RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.HEARTBEAT_BATCH,
ray.gcs_utils.TablePubsub.HEARTBEAT_BATCH, fake_id, malformed_message)
"RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.Value("HEARTBEAT_BATCH"),
ray.gcs_utils.TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB"), fake_id,
malformed_message)
wait_for_errors(ray_constants.MONITOR_DIED_ERROR, 1)
+4 -4
View File
@@ -93,10 +93,10 @@ def push_error_to_driver_through_redis(redis_client,
# of through the raylet.
error_data = ray.gcs_utils.construct_error_message(driver_id, error_type,
message, time.time())
redis_client.execute_command("RAY.TABLE_APPEND",
ray.gcs_utils.TablePrefix.ERROR_INFO,
ray.gcs_utils.TablePubsub.ERROR_INFO,
driver_id.binary(), error_data)
redis_client.execute_command(
"RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.Value("ERROR_INFO"),
ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"),
driver_id.binary(), error_data)
def is_cython(obj):
+20 -20
View File
@@ -47,7 +47,7 @@ from ray import (
from ray import import_thread
from ray import profiling
from ray.core.generated.ErrorType import ErrorType
from ray.gcs_utils import ErrorType
from ray.exceptions import (
RayActorError,
RayError,
@@ -461,11 +461,11 @@ class Worker(object):
# Otherwise, return an exception object based on
# the error type.
error_type = int(metadata)
if error_type == ErrorType.WORKER_DIED:
if error_type == ErrorType.Value("WORKER_DIED"):
return RayWorkerError()
elif error_type == ErrorType.ACTOR_DIED:
elif error_type == ErrorType.Value("ACTOR_DIED"):
return RayActorError()
elif error_type == ErrorType.OBJECT_UNRECONSTRUCTABLE:
elif error_type == ErrorType.Value("OBJECT_UNRECONSTRUCTABLE"):
return UnreconstructableError(ray.ObjectID(object_id.binary()))
else:
assert False, "Unrecognized error type " + str(error_type)
@@ -1637,7 +1637,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
# Really we should just subscribe to the errors for this specific job.
# However, currently all errors seem to be published on the same channel.
error_pubsub_channel = str(
ray.gcs_utils.TablePubsub.ERROR_INFO).encode("ascii")
ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB")).encode("ascii")
worker.error_message_pubsub_client.subscribe(error_pubsub_channel)
# worker.error_message_pubsub_client.psubscribe("*")
@@ -1656,21 +1656,19 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
if msg is None:
threads_stopped.wait(timeout=0.01)
continue
gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(
msg["data"], 0)
assert gcs_entry.EntriesLength() == 1
error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData(
gcs_entry.Entries(0), 0)
driver_id = error_data.DriverId()
gcs_entry = ray.gcs_utils.GcsEntry.FromString(msg["data"])
assert len(gcs_entry.entries) == 1
error_data = ray.gcs_utils.ErrorTableData.FromString(
gcs_entry.entries[0])
driver_id = error_data.driver_id
if driver_id not in [
worker.task_driver_id.binary(),
DriverID.nil().binary()
]:
continue
error_message = ray.utils.decode(error_data.ErrorMessage())
if (ray.utils.decode(
error_data.Type()) == ray_constants.TASK_PUSH_ERROR):
error_message = error_data.error_message
if (error_data.type == ray_constants.TASK_PUSH_ERROR):
# Delay it a bit to see if we can suppress it
task_error_queue.put((error_message, time.time()))
else:
@@ -1878,14 +1876,16 @@ def connect(node,
{}, # resource_map.
{}, # placement_resource_map.
)
task_table_data = ray.gcs_utils.TaskTableData()
task_table_data.task = driver_task._serialized_raylet_task()
# Add the driver task to the task table.
ray.state.state._execute_command(driver_task.task_id(),
"RAY.TABLE_ADD",
ray.gcs_utils.TablePrefix.RAYLET_TASK,
ray.gcs_utils.TablePubsub.RAYLET_TASK,
driver_task.task_id().binary(),
driver_task._serialized_raylet_task())
ray.state.state._execute_command(
driver_task.task_id(), "RAY.TABLE_ADD",
ray.gcs_utils.TablePrefix.Value("RAYLET_TASK"),
ray.gcs_utils.TablePubsub.Value("RAYLET_TASK_PUBSUB"),
driver_task.task_id().binary(),
task_table_data.SerializeToString())
# Set the driver's current task ID to the task ID assigned to the
# driver task.
+1
View File
@@ -150,6 +150,7 @@ requires = [
"six >= 1.0.0",
"flatbuffers",
"faulthandler;python_version<'3.3'",
"protobuf",
]
setup(