mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 08:40:02 +08:00
Use flatbuffers for some messages from Redis. (#341)
* Compile the Ray redis module with C++. * Redo parsing of object table notifications with flatbuffers. * Update redis module python tests. * Redo parsing of task table notifications with flatbuffers. * Fix linting. * Redo parsing of db client notifications with flatbuffers. * Redo publishing of local scheduler heartbeats with flatbuffers. * Fix linting. * Remove usage of fixed-width formatting of scheduling state in channel name. * Reply with flatbuffer object to task table queries, also simplify redis string to flatbuffer string conversion. * Fix linting and tests. * fix * cleanup * simplify logic in ReplyWithTask
This commit is contained in:
committed by
Philipp Moritz
parent
555dcf35a2
commit
53dffe0bf2
@@ -11,6 +11,10 @@ import unittest
|
||||
import redis
|
||||
import ray.services
|
||||
|
||||
# Import flatbuffer bindings.
|
||||
from ray.core.generated.SubscribeToNotificationsReply import SubscribeToNotificationsReply
|
||||
from ray.core.generated.TaskReply import TaskReply
|
||||
|
||||
OBJECT_INFO_PREFIX = "OI:"
|
||||
OBJECT_LOCATION_PREFIX = "OL:"
|
||||
OBJECT_SUBSCRIBE_PREFIX = "OS:"
|
||||
@@ -142,6 +146,15 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
self.assertEqual(set(response), set())
|
||||
|
||||
def testObjectTableSubscribeToNotifications(self):
|
||||
# Define a helper method for checking the contents of object notifications.
|
||||
def check_object_notification(notification_message, object_id, object_size, manager_ids):
|
||||
notification_object = SubscribeToNotificationsReply.GetRootAsSubscribeToNotificationsReply(notification_message, 0)
|
||||
self.assertEqual(notification_object.ObjectId(), object_id)
|
||||
self.assertEqual(notification_object.ObjectSize(), object_size)
|
||||
self.assertEqual(notification_object.ManagerIdsLength(), len(manager_ids))
|
||||
for i in range(len(manager_ids)):
|
||||
self.assertEqual(notification_object.ManagerIds(i), manager_ids[i])
|
||||
|
||||
data_size = 0xf1f0
|
||||
p = self.redis.pubsub()
|
||||
# Subscribe to an object ID.
|
||||
@@ -151,8 +164,12 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
self.assertEqual(get_next_message(p)["data"], 1)
|
||||
# Request a notification and receive the data.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id1")
|
||||
self.assertEqual(get_next_message(p)["data"], b"object_id1 %s MANAGERS manager_id2"\
|
||||
%integerToAsciiHex(data_size, 8))
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id1",
|
||||
data_size,
|
||||
[b"manager_id2"])
|
||||
|
||||
# Request a notification for an object that isn't there. Then add the object
|
||||
# and receive the data. Only the first call to RAY.OBJECT_TABLE_ADD should
|
||||
# trigger notifications.
|
||||
@@ -160,15 +177,24 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id2")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id3")
|
||||
self.assertEqual(get_next_message(p)["data"], b"object_id3 %s MANAGERS manager_id1"\
|
||||
%integerToAsciiHex(data_size, 8))
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id3",
|
||||
data_size,
|
||||
[b"manager_id1"])
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", data_size, "hash1", "manager_id3")
|
||||
self.assertEqual(get_next_message(p)["data"], b"object_id2 %s MANAGERS manager_id3"\
|
||||
%integerToAsciiHex(data_size, 8))
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id2",
|
||||
data_size,
|
||||
[b"manager_id3"])
|
||||
# Request notifications for object_id3 again.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id3")
|
||||
self.assertEqual(get_next_message(p)["data"], b"object_id3 %s MANAGERS manager_id1 manager_id2 manager_id3"\
|
||||
%integerToAsciiHex(data_size, 8))
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id3",
|
||||
data_size,
|
||||
[b"manager_id1", b"manager_id2", b"manager_id3"])
|
||||
|
||||
def testResultTableAddAndLookup(self):
|
||||
# Try looking up something in the result table before anything is added.
|
||||
@@ -205,10 +231,6 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
# Non-integer scheduling states should not be added.
|
||||
self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id",
|
||||
"invalid_state", "node_id", "task_spec")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
# Scheduling states with invalid width should not be added.
|
||||
self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", 101,
|
||||
"node_id", "task_spec")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
# Should not be able to update a non-existent task.
|
||||
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", 10,
|
||||
@@ -219,17 +241,24 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
TASK_STATUS_SCHEDULED = 2
|
||||
TASK_STATUS_QUEUED = 4
|
||||
|
||||
def check_task_reply(message, task_args):
|
||||
task_status, local_scheduler_id, task_spec = task_args
|
||||
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
|
||||
self.assertEqual(task_reply_object.State(), task_status)
|
||||
self.assertEqual(task_reply_object.LocalSchedulerId(), local_scheduler_id)
|
||||
self.assertEqual(task_reply_object.TaskSpec(), task_spec)
|
||||
|
||||
# Check that task table adds, updates, and lookups work correctly.
|
||||
task_args = [TASK_STATUS_WAITING, b"node_id", b"task_spec"]
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id",
|
||||
*task_args)
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
|
||||
self.assertEqual(response, task_args)
|
||||
check_task_reply(response, task_args)
|
||||
|
||||
task_args[0] = TASK_STATUS_SCHEDULED
|
||||
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", *task_args[:2])
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
|
||||
self.assertEqual(response, task_args)
|
||||
check_task_reply(response, task_args)
|
||||
|
||||
# If the current value, test value, and set value are all the same, the
|
||||
# update happens, and the response is still the same task.
|
||||
@@ -237,10 +266,10 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*task_args[:3])
|
||||
self.assertEqual(response, task_args[1:])
|
||||
check_task_reply(response, task_args[1:])
|
||||
# Check that the task entry is still the same.
|
||||
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
|
||||
self.assertEqual(get_response, task_args[1:])
|
||||
check_task_reply(get_response, task_args[1:])
|
||||
|
||||
# If the current value is the same as the test value, and the set value is
|
||||
# different, the update happens, and the response is the entire task.
|
||||
@@ -248,10 +277,10 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*task_args[:3])
|
||||
self.assertEqual(response, task_args[1:])
|
||||
check_task_reply(response, task_args[1:])
|
||||
# Check that the update happened.
|
||||
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
|
||||
self.assertEqual(get_response, task_args[1:])
|
||||
check_task_reply(get_response, task_args[1:])
|
||||
|
||||
# If the current value is no longer the same as the test value, the
|
||||
# response is nil.
|
||||
@@ -271,7 +300,7 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*task_args[:3])
|
||||
self.assertEqual(response, task_args[1:])
|
||||
check_task_reply(response, task_args[1:])
|
||||
|
||||
# If the test value is a bitmask that does not match the current value, the
|
||||
# update does not happen.
|
||||
@@ -288,13 +317,13 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
|
||||
def testTaskTableSubscribe(self):
|
||||
scheduling_state = 1
|
||||
node_id = "node_id"
|
||||
local_scheduler_id = "local_scheduler_id"
|
||||
# Subscribe to the task table.
|
||||
p = self.redis.pubsub()
|
||||
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
|
||||
p.psubscribe("{prefix}*:{state: >2}".format(prefix=TASK_PREFIX, state=scheduling_state))
|
||||
p.psubscribe("{prefix}{node}:*".format(prefix=TASK_PREFIX, node=node_id))
|
||||
task_args = [b"task_id", scheduling_state, node_id.encode("ascii"), b"task_spec"]
|
||||
p.psubscribe("{prefix}*:{state}".format(prefix=TASK_PREFIX, state=scheduling_state))
|
||||
p.psubscribe("{prefix}{local_scheduler_id}:*".format(prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
|
||||
task_args = [b"task_id", scheduling_state, local_scheduler_id.encode("ascii"), b"task_spec"]
|
||||
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
|
||||
# Receive the acknowledgement message.
|
||||
self.assertEqual(get_next_message(p)["data"], 1)
|
||||
@@ -302,10 +331,13 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
self.assertEqual(get_next_message(p)["data"], 3)
|
||||
# Receive the actual data.
|
||||
for i in range(3):
|
||||
message = get_next_message(p)["data"]
|
||||
message = message.split()
|
||||
message[1] = int(message[1])
|
||||
self.assertEqual(message, task_args)
|
||||
message = get_next_message(p)["data"]
|
||||
# Check that the notification object is correct.
|
||||
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
|
||||
self.assertEqual(notification_object.TaskId(), b"task_id")
|
||||
self.assertEqual(notification_object.State(), scheduling_state)
|
||||
self.assertEqual(notification_object.LocalSchedulerId(), local_scheduler_id.encode("ascii"))
|
||||
self.assertEqual(notification_object.TaskSpec(), b"task_spec")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -34,7 +34,7 @@ TASK_STATUS_QUEUED = 4
|
||||
TASK_STATUS_RUNNING = 8
|
||||
TASK_STATUS_DONE = 16
|
||||
|
||||
# These constants are an implementation detail of ray_redis_module.c, so this
|
||||
# These constants are an implementation detail of ray_redis_module.cc, so this
|
||||
# must be kept in sync with that file.
|
||||
DB_CLIENT_PREFIX = "CL:"
|
||||
TASK_PREFIX = "TT:"
|
||||
|
||||
+10
-9
@@ -12,13 +12,16 @@ import time
|
||||
from ray.services import get_ip_address
|
||||
from ray.services import get_port
|
||||
|
||||
# Import flatbuffer bindings.
|
||||
from ray.core.generated.SubscribeToDBClientTableReply import SubscribeToDBClientTableReply
|
||||
|
||||
# These variables must be kept in sync with the C codebase.
|
||||
# common/common.h
|
||||
DB_CLIENT_ID_SIZE = 20
|
||||
NIL_ID = b"\xff" * DB_CLIENT_ID_SIZE
|
||||
# common/task.h
|
||||
TASK_STATUS_LOST = 32
|
||||
# common/redis_module/ray_redis_module.c
|
||||
# common/redis_module/ray_redis_module.cc
|
||||
TASK_PREFIX = "TT:"
|
||||
DB_CLIENT_PREFIX = "CL:"
|
||||
DB_CLIENT_TABLE_NAME = b"db_clients"
|
||||
@@ -89,14 +92,12 @@ class Monitor(object):
|
||||
|
||||
# Parse the message.
|
||||
data = message["data"]
|
||||
db_client_id = data[:DB_CLIENT_ID_SIZE]
|
||||
data = data[DB_CLIENT_ID_SIZE + 1:]
|
||||
data = data.split(b" ")
|
||||
client_type, auxiliary_address, is_insertion = data
|
||||
is_insertion = int(is_insertion)
|
||||
if is_insertion != 1 and is_insertion != 0:
|
||||
raise Exception("Expected 0 or 1 for insertion field, got {} instead".format(is_insertion))
|
||||
is_insertion = bool(is_insertion)
|
||||
|
||||
notification_object = SubscribeToDBClientTableReply.GetRootAsSubscribeToDBClientTableReply(data, 0)
|
||||
db_client_id = notification_object.DbClientId()
|
||||
client_type = notification_object.ClientType()
|
||||
auxiliary_address = notification_object.AuxAddress()
|
||||
is_insertion = notification_object.IsInsertion()
|
||||
|
||||
return db_client_id, client_type, auxiliary_address, is_insertion
|
||||
|
||||
|
||||
@@ -723,7 +723,7 @@ def get_address_info_from_redis_helper(redis_address, node_ip_address):
|
||||
# must have run "CONFIG SET protected-mode no".
|
||||
redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port))
|
||||
# The client table prefix must be kept in sync with the file
|
||||
# "src/common/redis_module/ray_redis_module.c" where it is defined.
|
||||
# "src/common/redis_module/ray_redis_module.cc" where it is defined.
|
||||
REDIS_CLIENT_TABLE_PREFIX = "CL:"
|
||||
client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX))
|
||||
# Filter to clients on the same node and do some basic checking.
|
||||
|
||||
Reference in New Issue
Block a user