[xray] Add error table and push error messages to driver through node manager. (#2256)

* Fix documentation indentation.

* Add error table to GCS and push error messages through node manager.

* Add type to error data.

* Linting

* Fix failure_test bug.

* Linting.

* Enable one more test.

* Attempt to fix doc building.

* Restructuring

* Fixes

* More fixes.

* Move current_time_ms function into util.h.
This commit is contained in:
Robert Nishihara
2018-06-20 21:29:28 -07:00
committed by Philipp Moritz
parent 6bf48f47bc
commit ff2217251f
27 changed files with 610 additions and 204 deletions
+4 -4
View File
@@ -164,7 +164,7 @@ def save_and_log_checkpoint(worker, actor):
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
ray.utils.push_error_to_driver(
worker.redis_client,
worker,
ray_constants.CHECKPOINT_PUSH_ERROR,
traceback_str,
driver_id=worker.task_driver_id.id(),
@@ -188,7 +188,7 @@ def restore_and_log_checkpoint(worker, actor):
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
ray.utils.push_error_to_driver(
worker.redis_client,
worker,
ray_constants.CHECKPOINT_PUSH_ERROR,
traceback_str,
driver_id=worker.task_driver_id.id(),
@@ -330,7 +330,7 @@ def fetch_and_register_actor(actor_class_key, worker):
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
push_error_to_driver(
worker.redis_client,
worker,
ray_constants.REGISTER_ACTOR_PUSH_ERROR,
traceback_str,
driver_id,
@@ -402,7 +402,7 @@ def export_actor_class(class_id, Class, actor_method_names,
.format(actor_class_info["class_name"],
len(actor_class_info["class"])))
ray.utils.push_error_to_driver(
worker.redis_client,
worker,
ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
warning_message,
driver_id=worker.task_driver_id.id())
+20 -25
View File
@@ -8,20 +8,9 @@ import sys
import time
import unittest
import ray.gcs_utils
import ray.services
# Import flatbuffer bindings.
from ray.core.generated.SubscribeToNotificationsReply \
import SubscribeToNotificationsReply
from ray.core.generated.TaskReply import TaskReply
from ray.core.generated.ResultTableReply import ResultTableReply
OBJECT_INFO_PREFIX = "OI:"
OBJECT_LOCATION_PREFIX = "OL:"
OBJECT_SUBSCRIBE_PREFIX = "OS:"
TASK_PREFIX = "TT:"
OBJECT_CHANNEL_PREFIX = "OC:"
def integerToAsciiHex(num, numbytes):
retstr = b""
@@ -194,7 +183,7 @@ class TestGlobalStateStore(unittest.TestCase):
# notifications.
def check_object_notification(notification_message, object_id,
object_size, manager_ids):
notification_object = (SubscribeToNotificationsReply.
notification_object = (ray.gcs_utils.SubscribeToNotificationsReply.
GetRootAsSubscribeToNotificationsReply(
notification_message, 0))
self.assertEqual(notification_object.ObjectId(), object_id)
@@ -208,7 +197,8 @@ class TestGlobalStateStore(unittest.TestCase):
data_size = 0xf1f0
p = self.redis.pubsub()
# Subscribe to an object ID.
p.psubscribe("{}manager_id1".format(OBJECT_CHANNEL_PREFIX))
p.psubscribe("{}manager_id1".format(
ray.gcs_utils.OBJECT_CHANNEL_PREFIX))
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1",
data_size, "hash1", "manager_id2")
# Receive the acknowledgement message.
@@ -252,8 +242,9 @@ class TestGlobalStateStore(unittest.TestCase):
def testResultTableAddAndLookup(self):
def check_result_table_entry(message, task_id, is_put):
result_table_reply = ResultTableReply.GetRootAsResultTableReply(
message, 0)
result_table_reply = (
ray.gcs_utils.ResultTableReply.GetRootAsResultTableReply(
message, 0))
self.assertEqual(result_table_reply.TaskId(), task_id)
self.assertEqual(result_table_reply.IsPut(), is_put)
@@ -315,12 +306,13 @@ class TestGlobalStateStore(unittest.TestCase):
# make sure somebody will get a notification (checked in the redis
# module)
p = self.redis.pubsub()
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
p.psubscribe("{prefix}*:*".format(prefix=ray.gcs_utils.TASK_PREFIX))
def check_task_reply(message, task_args, updated=False):
(task_status, local_scheduler_id, execution_dependencies_string,
spillback_count, task_spec) = task_args
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
task_reply_object = ray.gcs_utils.TaskReply.GetRootAsTaskReply(
message, 0)
self.assertEqual(task_reply_object.State(), task_status)
self.assertEqual(task_reply_object.LocalSchedulerId(),
local_scheduler_id)
@@ -409,7 +401,8 @@ class TestGlobalStateStore(unittest.TestCase):
# Receive the data.
message = get_next_message(p)["data"]
# Check that the notification object is correct.
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
notification_object = ray.gcs_utils.TaskReply.GetRootAsTaskReply(
message, 0)
self.assertEqual(notification_object.TaskId(), task_args[0])
self.assertEqual(notification_object.State(), task_args[1])
self.assertEqual(notification_object.LocalSchedulerId(), task_args[2])
@@ -422,32 +415,34 @@ class TestGlobalStateStore(unittest.TestCase):
local_scheduler_id = "local_scheduler_id"
# Subscribe to the task table.
p = self.redis.pubsub()
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
p.psubscribe("{prefix}*:*".format(prefix=ray.gcs_utils.TASK_PREFIX))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
# unsubscribe to make sure there is only one subscriber at a given time
p.punsubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
p.punsubscribe("{prefix}*:*".format(prefix=ray.gcs_utils.TASK_PREFIX))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)
p.psubscribe("{prefix}*:{state}".format(
prefix=TASK_PREFIX, state=scheduling_state))
prefix=ray.gcs_utils.TASK_PREFIX, state=scheduling_state))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
p.punsubscribe("{prefix}*:{state}".format(
prefix=TASK_PREFIX, state=scheduling_state))
prefix=ray.gcs_utils.TASK_PREFIX, state=scheduling_state))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)
p.psubscribe("{prefix}{local_scheduler_id}:*".format(
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
prefix=ray.gcs_utils.TASK_PREFIX,
local_scheduler_id=local_scheduler_id))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
p.punsubscribe("{prefix}{local_scheduler_id}:*".format(
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
prefix=ray.gcs_utils.TASK_PREFIX,
local_scheduler_id=local_scheduler_id))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)
+104 -59
View File
@@ -12,41 +12,10 @@ import sys
import time
import ray
import ray.gcs_utils
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
hex_to_binary)
# Import flatbuffer bindings.
from ray.core.generated.TaskReply import TaskReply
from ray.core.generated.ResultTableReply import ResultTableReply
from ray.core.generated.TaskExecutionDependencies import \
TaskExecutionDependencies
from ray.core.generated.ClientTableData import ClientTableData
from ray.core.generated.GcsTableEntry import GcsTableEntry
from ray.core.generated.ObjectTableData import ObjectTableData
from ray.core.generated.ray.protocol.Task import Task
# These prefixes must be kept up-to-date with the definitions in
# ray_redis_module.cc.
DB_CLIENT_PREFIX = "CL:"
OBJECT_INFO_PREFIX = "OI:"
OBJECT_LOCATION_PREFIX = "OL:"
OBJECT_SUBSCRIBE_PREFIX = "OS:"
TASK_PREFIX = "TT:"
FUNCTION_PREFIX = "RemoteFunction:"
OBJECT_CHANNEL_PREFIX = "OC:"
# These prefixes must be kept up-to-date with the TablePrefix enum in gcs.fbs.
# TODO(rkn): We should use scoped enums, in which case we should be able to
# just access the flatbuffer generated values.
TablePrefix_RAYLET_TASK = 2
TablePrefix_RAYLET_TASK_string = "TASK"
TablePrefix_CLIENT = 3
TablePrefix_CLIENT_string = "CLIENT"
TablePrefix_OBJECT = 4
TablePrefix_OBJECT_string = "OBJECT"
# This mapping from integer to task state string must be kept up-to-date with
# the scheduling_state enum in task.h.
TASK_STATUS_WAITING = 1
@@ -231,8 +200,9 @@ class GlobalState(object):
result_table_response = self._execute_command(
object_id, "RAY.RESULT_TABLE_LOOKUP", object_id.id())
result_table_message = ResultTableReply.GetRootAsResultTableReply(
result_table_response, 0)
result_table_message = (
ray.gcs_utils.ResultTableReply.GetRootAsResultTableReply(
result_table_response, 0))
result = {
"ManagerIDs": manager_ids,
@@ -245,12 +215,14 @@ class GlobalState(object):
else:
# Use the raylet code path.
message = self.redis_client.execute_command(
"RAY.TABLE_LOOKUP", TablePrefix_OBJECT, "", object_id.id())
"RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.OBJECT, "",
object_id.id())
result = []
gcs_entry = GcsTableEntry.GetRootAsGcsTableEntry(message, 0)
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
for i in range(gcs_entry.EntriesLength()):
entry = ObjectTableData.GetRootAsObjectTableData(
entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData(
gcs_entry.Entries(i), 0)
object_info = {
"DataSize": entry.ObjectSize(),
@@ -279,19 +251,22 @@ class GlobalState(object):
else:
# Return the entire object table.
if not self.use_raylet:
object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*")
object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*")
object_info_keys = self._keys(
ray.gcs_utils.OBJECT_INFO_PREFIX + "*")
object_location_keys = self._keys(
ray.gcs_utils.OBJECT_LOCATION_PREFIX + "*")
object_ids_binary = set([
key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys
key[len(ray.gcs_utils.OBJECT_INFO_PREFIX):]
for key in object_info_keys
] + [
key[len(OBJECT_LOCATION_PREFIX):]
key[len(ray.gcs_utils.OBJECT_LOCATION_PREFIX):]
for key in object_location_keys
])
else:
object_keys = self.redis_client.keys(
TablePrefix_OBJECT_string + ":*")
ray.gcs_utils.TablePrefix_OBJECT_string + "*")
object_ids_binary = {
key[len(TablePrefix_OBJECT_string + ":"):]
key[len(ray.gcs_utils.TablePrefix_OBJECT_string):]
for key in object_keys
}
@@ -320,7 +295,7 @@ class GlobalState(object):
if task_table_response is None:
raise Exception("There is no entry for task ID {} in the task "
"table.".format(binary_to_hex(task_id.id())))
task_table_message = TaskReply.GetRootAsTaskReply(
task_table_message = ray.gcs_utils.TaskReply.GetRootAsTaskReply(
task_table_response, 0)
task_spec = task_table_message.TaskSpec()
task_spec = ray.local_scheduler.task_from_string(task_spec)
@@ -343,7 +318,8 @@ class GlobalState(object):
}
execution_dependencies_message = (
TaskExecutionDependencies.GetRootAsTaskExecutionDependencies(
ray.gcs_utils.TaskExecutionDependencies.
GetRootAsTaskExecutionDependencies(
task_table_message.ExecutionDependencies(), 0))
execution_dependencies = [
ray.ObjectID(
@@ -371,15 +347,17 @@ class GlobalState(object):
else:
# Use the raylet code path.
message = self.redis_client.execute_command(
"RAY.TABLE_LOOKUP", TablePrefix_RAYLET_TASK, "", task_id.id())
gcs_entries = GcsTableEntry.GetRootAsGcsTableEntry(message, 0)
"RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.RAYLET_TASK, "",
task_id.id())
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
info = []
for i in range(gcs_entries.EntriesLength()):
task_table_message = Task.GetRootAsTask(
task_table_message = ray.gcs_utils.Task.GetRootAsTask(
gcs_entries.Entries(i), 0)
task_table_message = Task.GetRootAsTask(
task_table_message = ray.gcs_utils.Task.GetRootAsTask(
gcs_entries.Entries(0), 0)
execution_spec = task_table_message.TaskExecutionSpec()
task_spec = task_table_message.TaskSpecification()
@@ -432,15 +410,16 @@ class GlobalState(object):
return self._task_table(task_id)
else:
if not self.use_raylet:
task_table_keys = self._keys(TASK_PREFIX + "*")
task_table_keys = self._keys(ray.gcs_utils.TASK_PREFIX + "*")
task_ids_binary = [
key[len(TASK_PREFIX):] for key in task_table_keys
key[len(ray.gcs_utils.TASK_PREFIX):]
for key in task_table_keys
]
else:
task_table_keys = self.redis_client.keys(
TablePrefix_RAYLET_TASK_string + ":*")
ray.gcs_utils.TablePrefix_RAYLET_TASK_string + "*")
task_ids_binary = [
key[len(TablePrefix_RAYLET_TASK_string + ":"):]
key[len(ray.gcs_utils.TablePrefix_RAYLET_TASK_string):]
for key in task_table_keys
]
@@ -458,7 +437,8 @@ class GlobalState(object):
function.
"""
self._check_connected()
function_table_keys = self.redis_client.keys(FUNCTION_PREFIX + "*")
function_table_keys = self.redis_client.keys(
ray.gcs_utils.FUNCTION_PREFIX + "*")
results = {}
for key in function_table_keys:
info = self.redis_client.hgetall(key)
@@ -478,7 +458,8 @@ class GlobalState(object):
"""
self._check_connected()
if not self.use_raylet:
db_client_keys = self.redis_client.keys(DB_CLIENT_PREFIX + "*")
db_client_keys = self.redis_client.keys(
ray.gcs_utils.DB_CLIENT_PREFIX + "*")
node_info = {}
for key in db_client_keys:
client_info = self.redis_client.hgetall(key)
@@ -520,13 +501,16 @@ class GlobalState(object):
# This is the raylet code path.
NIL_CLIENT_ID = 20 * b"\xff"
message = self.redis_client.execute_command(
"RAY.TABLE_LOOKUP", TablePrefix_CLIENT, "", NIL_CLIENT_ID)
"RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.CLIENT, "",
NIL_CLIENT_ID)
node_info = []
gcs_entry = GcsTableEntry.GetRootAsGcsTableEntry(message, 0)
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
for i in range(gcs_entry.EntriesLength()):
client = ClientTableData.GetRootAsClientTableData(
gcs_entry.Entries(i), 0)
client = (
ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
gcs_entry.Entries(i), 0))
resources = {
client.ResourcesTotalLabel(i).decode("ascii"):
@@ -1146,3 +1130,64 @@ class GlobalState(object):
resources[key] += value
return dict(resources)
def _error_messages(self, job_id):
    """Get the error messages for a specific job.

    Args:
        job_id: The ID of the job to get the errors for. Its binary form
            (job_id.id()) is used as the key of the error table lookup.

    Returns:
        A list of dictionaries, one per entry in the error table for this
            job. Each dictionary has the keys "type", "message" and
            "timestamp". The list is empty if the lookup returns nothing.
    """
    message = self.redis_client.execute_command(
        "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.ERROR_INFO, "",
        job_id.id())

    # If there are no errors, return early.
    if message is None:
        return []

    gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
        message, 0)
    error_messages = []
    for i in range(gcs_entries.EntriesLength()):
        error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData(
            gcs_entries.Entries(i), 0)
        # Decode as UTF-8 rather than ASCII: the Python writer builds these
        # strings with flatbuffers' CreateString, which encodes text as
        # UTF-8 by default, and tracebacks can contain non-ASCII
        # characters. UTF-8 decoding is a superset of ASCII decoding.
        error_messages.append({
            "type": error_data.Type().decode("utf-8"),
            "message": error_data.ErrorMessage().decode("utf-8"),
            "timestamp": error_data.Timestamp(),
        })
    return error_messages
def error_messages(self, job_id=None):
    """Get the error messages for all jobs or a specific job.

    Args:
        job_id: The specific job to get the errors for. If this is None,
            then this method retrieves the errors for all jobs.

    Returns:
        If job_id is given, the list of error messages for that job.
            Otherwise, a dictionary mapping each job ID (as a hex string)
            to the list of the error messages for that job.

    Raises:
        Exception: If the raylet code path is not being used.
    """
    if not self.use_raylet:
        raise Exception("The error_messages method is only supported in "
                        "the raylet code path.")

    # A single job was requested, so no key scan is needed.
    if job_id is not None:
        return self._error_messages(job_id)

    # Every key in the error table starts with this prefix, followed by the
    # binary job ID.
    prefix = ray.gcs_utils.TablePrefix_ERROR_INFO_string
    binary_job_ids = []
    for key in self.redis_client.keys(prefix + "*"):
        binary_job_ids.append(key[len(prefix):])

    errors_by_job = {}
    for binary_job_id in binary_job_ids:
        errors_by_job[binary_to_hex(binary_job_id)] = self._error_messages(
            ray.ObjectID(binary_job_id))
    return errors_by_job
+84
View File
@@ -0,0 +1,84 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import flatbuffers
from ray.core.generated.ResultTableReply import ResultTableReply
from ray.core.generated.SubscribeToNotificationsReply \
import SubscribeToNotificationsReply
from ray.core.generated.TaskExecutionDependencies import \
TaskExecutionDependencies
from ray.core.generated.TaskReply import TaskReply
from ray.core.generated.DriverTableMessage import DriverTableMessage
from ray.core.generated.LocalSchedulerInfoMessage import \
LocalSchedulerInfoMessage
from ray.core.generated.SubscribeToDBClientTableReply import \
SubscribeToDBClientTableReply
from ray.core.generated.TaskInfo import TaskInfo
import ray.core.generated.ErrorTableData
from ray.core.generated.GcsTableEntry import GcsTableEntry
from ray.core.generated.ClientTableData import ClientTableData
from ray.core.generated.ErrorTableData import ErrorTableData
from ray.core.generated.HeartbeatTableData import HeartbeatTableData
from ray.core.generated.ObjectTableData import ObjectTableData
from ray.core.generated.ray.protocol.Task import Task
from ray.core.generated.TablePrefix import TablePrefix
from ray.core.generated.TablePubsub import TablePubsub
__all__ = [
"SubscribeToNotificationsReply", "ResultTableReply",
"TaskExecutionDependencies", "TaskReply", "DriverTableMessage",
"LocalSchedulerInfoMessage", "SubscribeToDBClientTableReply", "TaskInfo",
"GcsTableEntry", "ClientTableData", "ErrorTableData", "HeartbeatTableData",
"ObjectTableData", "Task", "TablePrefix", "TablePubsub",
"construct_error_message"
]
# These prefixes must be kept up-to-date with the definitions in
# ray_redis_module.cc.
DB_CLIENT_PREFIX = "CL:"
TASK_PREFIX = "TT:"
OBJECT_CHANNEL_PREFIX = "OC:"
OBJECT_INFO_PREFIX = "OI:"
OBJECT_LOCATION_PREFIX = "OL:"
FUNCTION_PREFIX = "RemoteFunction:"
# These prefixes must be kept up-to-date with the TablePrefix enum in gcs.fbs.
# TODO(rkn): We should use scoped enums, in which case we should be able to
# just access the flatbuffer generated values.
TablePrefix_RAYLET_TASK_string = "RAYLET_TASK"
TablePrefix_OBJECT_string = "OBJECT"
TablePrefix_ERROR_INFO_string = "ERROR_INFO"
def construct_error_message(error_type, message, timestamp):
    """Serialize an error into a flatbuffer ErrorTableData object.

    Args:
        error_type: The type of the error.
        message: The error message.
        timestamp: The time of the error.

    Returns:
        The serialized ErrorTableData object, as bytes.
    """
    # The generated module exposes the table-building helpers as
    # module-level functions; alias it so the calls below stay short.
    error_table = ray.core.generated.ErrorTableData

    builder = flatbuffers.Builder(0)
    # Create the strings first; only their offsets are added to the table.
    type_offset = builder.CreateString(error_type)
    text_offset = builder.CreateString(message)

    # NOTE(review): only the type, message and timestamp fields are written
    # here; no job ID is stored in the table -- confirm that is intended.
    error_table.ErrorTableDataStart(builder)
    error_table.ErrorTableDataAddType(builder, type_offset)
    error_table.ErrorTableDataAddErrorMessage(builder, text_offset)
    error_table.ErrorTableDataAddTimestamp(builder, timestamp)
    builder.Finish(error_table.ErrorTableDataEnd(builder))

    return bytes(builder.Output())
-5
View File
@@ -29,11 +29,6 @@ NIL_WORKER_ID = 20 * b"\xff"
NIL_OBJECT_ID = 20 * b"\xff"
NIL_ACTOR_ID = 20 * b"\xff"
# These constants are an implementation detail of ray_redis_module.cc, so this
# must be kept in sync with that file.
DB_CLIENT_PREFIX = "CL:"
TASK_PREFIX = "TT:"
def random_driver_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
+15 -19
View File
@@ -9,20 +9,13 @@ import os
import time
from collections import Counter, defaultdict
import ray
import ray.cloudpickle as pickle
import ray.utils
import redis
# Import flatbuffer bindings.
from ray.core.generated.DriverTableMessage import DriverTableMessage
from ray.core.generated.GcsTableEntry import GcsTableEntry
from ray.core.generated.HeartbeatTableData import HeartbeatTableData
from ray.core.generated.LocalSchedulerInfoMessage import \
LocalSchedulerInfoMessage
from ray.core.generated.SubscribeToDBClientTableReply import \
SubscribeToDBClientTableReply
import ray
from ray.autoscaler.autoscaler import LoadMetrics, StandardAutoscaler
from ray.core.generated.TaskInfo import TaskInfo
import ray.cloudpickle as pickle
import ray.gcs_utils
import ray.utils
from ray.services import get_ip_address, get_port
from ray.utils import binary_to_hex, binary_to_object_id, hex_to_binary
from ray.worker import NIL_ACTOR_ID
@@ -259,7 +252,7 @@ class Monitor(object):
the associated state in the state tables should be handled by the
caller.
"""
notification_object = (SubscribeToDBClientTableReply.
notification_object = (ray.gcs_utils.SubscribeToDBClientTableReply.
GetRootAsSubscribeToDBClientTableReply(data, 0))
db_client_id = binary_to_hex(notification_object.DbClientId())
client_type = notification_object.ClientType()
@@ -285,8 +278,8 @@ class Monitor(object):
def local_scheduler_info_handler(self, unused_channel, data):
"""Handle a local scheduler heartbeat from Redis."""
message = LocalSchedulerInfoMessage.GetRootAsLocalSchedulerInfoMessage(
data, 0)
message = (ray.gcs_utils.LocalSchedulerInfoMessage.
GetRootAsLocalSchedulerInfoMessage(data, 0))
num_resources = message.DynamicResourcesLength()
static_resources = {}
dynamic_resources = {}
@@ -308,9 +301,10 @@ class Monitor(object):
def xray_heartbeat_handler(self, unused_channel, data):
"""Handle an xray heartbeat message from Redis."""
gcs_entries = GcsTableEntry.GetRootAsGcsTableEntry(data, 0)
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
data, 0)
heartbeat_data = gcs_entries.Entries(0)
message = HeartbeatTableData.GetRootAsHeartbeatTableData(
message = ray.gcs_utils.HeartbeatTableData.GetRootAsHeartbeatTableData(
heartbeat_data, 0)
num_resources = message.ResourcesAvailableLabelLength()
static_resources = {}
@@ -363,7 +357,8 @@ class Monitor(object):
# driver. Use a cursor in order not to block the redis shards.
for key in redis.scan_iter(match=TASK_TABLE_PREFIX + b"*"):
entry = redis.hgetall(key)
task_info = TaskInfo.GetRootAsTaskInfo(entry[b"TaskSpec"], 0)
task_info = ray.gcs_utils.TaskInfo.GetRootAsTaskInfo(
entry[b"TaskSpec"], 0)
if driver_id != task_info.DriverId():
# Ignore tasks that aren't from this driver.
continue
@@ -475,7 +470,8 @@ class Monitor(object):
This releases any GPU resources that were reserved for that driver in
Redis.
"""
message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
message = ray.gcs_utils.DriverTableMessage.GetRootAsDriverTableMessage(
data, 0)
driver_id = message.DriverId()
log.info("Driver {} has been removed.".format(
binary_to_hex(driver_id)))
+5
View File
@@ -5,6 +5,8 @@ from __future__ import print_function
import os
import ray
def env_integer(key, default):
if key in os.environ:
@@ -12,6 +14,9 @@ def env_integer(key, default):
return default
ID_SIZE = 20
NIL_JOB_ID = ray.ObjectID(ID_SIZE * b"\x00")
# If a remote function or actor (or some other export) has serialized size
# greater than this quantity, print a warning.
PICKLE_OBJECT_WARNING_SIZE = 10**7
+62 -9
View File
@@ -7,9 +7,12 @@ import hashlib
import numpy as np
import os
import sys
import time
import uuid
import ray.gcs_utils
import ray.local_scheduler
import ray.ray_constants as ray_constants
ERROR_KEY_PREFIX = b"Error:"
DRIVER_ID_LENGTH = 20
@@ -45,7 +48,7 @@ def format_error_message(exception_message, task_exception=False):
return "\n".join(lines)
def push_error_to_driver(redis_client,
def push_error_to_driver(worker,
error_type,
message,
driver_id=None,
@@ -53,7 +56,7 @@ def push_error_to_driver(redis_client,
"""Push an error message to the driver to be printed in the background.
Args:
redis_client: The redis client to use.
worker: The worker to use.
error_type (str): The type of the error.
message (str): The message that will be printed in the background
on the driver.
@@ -63,15 +66,65 @@ def push_error_to_driver(redis_client,
will be serialized with json and stored in Redis.
"""
if driver_id is None:
driver_id = DRIVER_ID_LENGTH * b"\x00"
driver_id = ray_constants.NIL_JOB_ID.id()
error_key = ERROR_KEY_PREFIX + driver_id + b":" + _random_string()
data = {} if data is None else data
redis_client.hmset(error_key, {
"type": error_type,
"message": message,
"data": data
})
redis_client.rpush("ErrorKeys", error_key)
if not worker.use_raylet:
worker.redis_client.hmset(error_key, {
"type": error_type,
"message": message,
"data": data
})
worker.redis_client.rpush("ErrorKeys", error_key)
else:
worker.local_scheduler_client.push_error(
ray.ObjectID(driver_id), error_type, message, time.time())
def push_error_to_driver_through_redis(redis_client,
                                       use_raylet,
                                       error_type,
                                       message,
                                       driver_id=None,
                                       data=None):
    """Push an error message to the driver to be printed in the background.

    Normally the push_error_to_driver function should be used. However, in some
    instances, the local scheduler client is not available, e.g., because the
    error happens in Python before the driver or worker has connected to the
    backend processes.

    Args:
        redis_client: The redis client to use.
        use_raylet: True if we are using the Raylet code path and false
            otherwise.
        error_type (str): The type of the error.
        message (str): The message that will be printed in the background
            on the driver.
        driver_id: The ID of the driver to push the error message to. If this
            is None, then the message will be pushed to all drivers.
        data: This should be a dictionary mapping strings to strings. It is
            only used in the non-raylet code path, where it is stored in
            Redis under the "data" field of the error entry. It is ignored
            in the raylet code path.
    """
    if driver_id is None:
        driver_id = ray_constants.NIL_JOB_ID.id()

    if use_raylet:
        # Do everything in Python and through the Python Redis client instead
        # of through the raylet. The legacy error key and the data dictionary
        # are not part of the error table entry, so they are not built here.
        error_data = ray.gcs_utils.construct_error_message(
            error_type, message, time.time())
        redis_client.execute_command(
            "RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO,
            ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id, error_data)
        return

    # Legacy code path: store the error in a hash under a unique key and
    # record that key in the "ErrorKeys" list.
    error_key = ERROR_KEY_PREFIX + driver_id + b":" + _random_string()
    redis_client.hmset(error_key, {
        "type": error_type,
        "message": message,
        "data": {} if data is None else data
    })
    redis_client.rpush("ErrorKeys", error_key)
def is_cython(obj):
+91 -21
View File
@@ -22,6 +22,7 @@ import pyarrow
import pyarrow.plasma as plasma
import ray.cloudpickle as pickle
import ray.experimental.state as state
import ray.gcs_utils
import ray.remote_function
import ray.serialization as serialization
import ray.services as services
@@ -31,9 +32,6 @@ import ray.plasma
import ray.ray_constants as ray_constants
from ray.utils import random_string, binary_to_hex, is_cython
# Import flatbuffer bindings.
from ray.core.generated.ClientTableData import ClientTableData
SCRIPT_MODE = 0
WORKER_MODE = 1
PYTHON_MODE = 2
@@ -415,7 +413,7 @@ class Worker(object):
"may be a bug.")
if not warning_sent:
ray.utils.push_error_to_driver(
self.redis_client,
self,
ray_constants.WAIT_FOR_CLASS_PUSH_ERROR,
warning_message,
driver_id=self.task_driver_id.id())
@@ -663,7 +661,7 @@ class Worker(object):
"large array or other object.".format(
function_name, len(pickled_function)))
ray.utils.push_error_to_driver(
self.redis_client,
self,
ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
warning_message,
driver_id=self.task_driver_id.id())
@@ -726,7 +724,7 @@ class Worker(object):
.format(function.__name__,
len(pickled_function)))
ray.utils.push_error_to_driver(
self.redis_client,
self,
ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
warning_message,
driver_id=self.task_driver_id.id())
@@ -781,7 +779,7 @@ class Worker(object):
"Ray.")
if not warning_sent:
ray.utils.push_error_to_driver(
self.redis_client,
self,
ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
warning_message,
driver_id=driver_id)
@@ -942,7 +940,7 @@ class Worker(object):
self._store_outputs_in_objstore(return_object_ids, failure_objects)
# Log the error message.
ray.utils.push_error_to_driver(
self.redis_client,
self,
ray_constants.TASK_PUSH_ERROR,
str(failure_object),
driver_id=self.task_driver_id.id(),
@@ -1200,6 +1198,11 @@ def error_info(worker=global_worker):
"""Return information about failed tasks."""
worker.check_connected()
check_main_thread()
if worker.use_raylet:
return (global_state.error_messages(job_id=worker.task_driver_id) +
global_state.error_messages(job_id=ray_constants.NIL_JOB_ID))
error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1)
errors = []
for error_key in error_keys:
@@ -1291,9 +1294,8 @@ def get_address_info_from_redis_helper(redis_address,
if not use_raylet:
# The client table prefix must be kept in sync with the file
# "src/common/redis_module/ray_redis_module.cc" where it is defined.
REDIS_CLIENT_TABLE_PREFIX = "CL:"
client_keys = redis_client.keys(
"{}*".format(REDIS_CLIENT_TABLE_PREFIX))
client_keys = redis_client.keys("{}*".format(
ray.gcs_utils.DB_CLIENT_PREFIX))
# Filter to live clients on the same node and do some basic checking.
plasma_managers = []
local_schedulers = []
@@ -1350,11 +1352,11 @@ def get_address_info_from_redis_helper(redis_address,
else:
# In the raylet code path, all client data is stored in a zset at the
# key for the nil client.
client_key = b"CLIENT:" + NIL_CLIENT_ID
client_key = b"CLIENT" + NIL_CLIENT_ID
clients = redis_client.zrange(client_key, 0, -1)
raylets = []
for client_message in clients:
client = ClientTableData.GetRootAsClientTableData(
client = ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
client_message, 0)
client_node_ip_address = client.NodeManagerAddress().decode(
"ascii")
@@ -1819,6 +1821,71 @@ def custom_excepthook(type, value, tb):
sys.excepthook = custom_excepthook
def print_error_messages_raylet(worker):
    """Print error messages in the background on the driver.

    This runs in a separate thread on the driver and prints error messages in
    the background.

    Args:
        worker: The worker object of the driver. Its Redis client is used to
            subscribe to the error channel, and the pubsub client is stored
            on it as error_message_pubsub_client.

    Raises:
        Exception: If the worker is not using the raylet code path.
    """
    if not worker.use_raylet:
        raise Exception("This function is specific to the raylet code path.")

    worker.error_message_pubsub_client = worker.redis_client.pubsub(
        ignore_subscribe_messages=True)
    # Errors that are published after the call to
    # error_message_pubsub_client.subscribe and before the call to
    # error_message_pubsub_client.listen will still be processed in the loop.

    # Really we should just subscribe to the errors for this specific job.
    # However, currently all errors seem to be published on the same channel.
    error_pubsub_channel = str(
        ray.gcs_utils.TablePubsub.ERROR_INFO).encode("ascii")
    worker.error_message_pubsub_client.subscribe(error_pubsub_channel)

    # Keep a set of all the error messages that we've seen so far in order to
    # avoid printing the same error message repeatedly. This is especially
    # important when running a script inside of a tool like screen where
    # scrolling is difficult.
    old_error_messages = set()

    def print_if_new(error_message):
        # Print the message text once; later repeats are only acknowledged.
        if error_message not in old_error_messages:
            print(error_message)
            old_error_messages.add(error_message)
        else:
            print("Suppressing duplicate error message.")

    # Get the errors that occurred before the call to subscribe.
    with worker.lock:
        error_messages = global_state.error_messages(worker.task_driver_id)
        for error in error_messages:
            # Each entry is a dictionary, which cannot be stored in a set,
            # so deduplicate on the message text, as the loop below does.
            print_if_new(error["message"])

    # Errors that were pushed without a specific driver carry the nil job ID.
    nil_job_id = ray_constants.NIL_JOB_ID.id()

    try:
        for msg in worker.error_message_pubsub_client.listen():
            # The flatbuffer bindings live in ray.gcs_utils.
            gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
                msg["data"], 0)
            assert gcs_entry.EntriesLength() == 1
            error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData(
                gcs_entry.Entries(0), 0)

            # Skip errors that belong to other jobs.
            job_id = error_data.JobId()
            if job_id not in [worker.task_driver_id.id(), nil_job_id]:
                continue

            # UTF-8 is a superset of ASCII, so non-ASCII text in a traceback
            # does not abort this thread.
            print_if_new(error_data.ErrorMessage().decode("utf-8"))

    except redis.ConnectionError:
        # When Redis terminates the listen call will throw a ConnectionError,
        # which we catch here.
        pass
def print_error_messages(worker):
"""Print error messages in the background on the driver.
@@ -1907,7 +1974,7 @@ def fetch_and_register_remote_function(key, worker=global_worker):
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
ray.utils.push_error_to_driver(
worker.redis_client,
worker,
ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR,
traceback_str,
driver_id=driver_id,
@@ -1952,7 +2019,7 @@ def fetch_and_execute_function_to_run(key, worker=global_worker):
name = function.__name__ if ("function" in locals()
and hasattr(function, "__name__")) else ""
ray.utils.push_error_to_driver(
worker.redis_client,
worker,
ray_constants.FUNCTION_TO_RUN_PUSH_ERROR,
traceback_str,
driver_id=driver_id,
@@ -2111,8 +2178,9 @@ def connect(info,
raise e
elif mode == WORKER_MODE:
traceback_str = traceback.format_exc()
ray.utils.push_error_to_driver(
ray.utils.push_error_to_driver_through_redis(
worker.redis_client,
worker.use_raylet,
ray_constants.VERSION_MISMATCH_PUSH_ERROR,
traceback_str,
driver_id=None)
@@ -2237,13 +2305,11 @@ def connect(info,
driver_task.execution_dependencies_string(), 0,
ray.local_scheduler.task_to_string(driver_task))
else:
TablePubsub_RAYLET_TASK = 2
# TODO(rkn): When we shard the GCS in xray, we will need to change
# this to use _execute_command.
global_state.redis_client.execute_command(
"RAY.TABLE_ADD", state.TablePrefix_RAYLET_TASK,
TablePubsub_RAYLET_TASK,
"RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.RAYLET_TASK,
ray.gcs_utils.TablePubsub.RAYLET_TASK,
driver_task.task_id().id(),
driver_task._serialized_raylet_task())
@@ -2271,7 +2337,11 @@ def connect(info,
# temporarily using this implementation which constantly queries the
# scheduler for new error messages.
if mode == SCRIPT_MODE:
t = threading.Thread(target=print_error_messages, args=(worker, ))
if not worker.use_raylet:
t = threading.Thread(target=print_error_messages, args=(worker, ))
else:
t = threading.Thread(
target=print_error_messages_raylet, args=(worker, ))
# Making the thread a daemon causes it to exit when the main thread
# exits.
t.daemon = True
+4 -3
View File
@@ -69,10 +69,11 @@ if __name__ == "__main__":
ray.worker.global_worker.main_loop()
except Exception as e:
traceback_str = traceback.format_exc() + error_explanation
# Create a Redis client.
redis_client = ray.services.create_redis_client(args.redis_address)
ray.utils.push_error_to_driver(
redis_client, "worker_crash", traceback_str, driver_id=None)
ray.worker.global_worker,
"worker_crash",
traceback_str,
driver_id=None)
# TODO(rkn): Note that if the worker was in the middle of executing
# a task, then any worker or driver that is blocking in a get call
# and waiting for the output of that task will hang. We need to