Use flatbuffers for some messages from Redis. (#341)

* Compile the Ray redis module with C++.

* Redo parsing of object table notifications with flatbuffers.

* Update redis module python tests.

* Redo parsing of task table notifications with flatbuffers.

* Fix linting.

* Redo parsing of db client notifications with flatbuffers.

* Redo publishing of local scheduler heartbeats with flatbuffers.

* Fix linting.

* Remove usage of fixed-width formatting of scheduling state in channel name.

* Reply with flatbuffer object to task table queries, also simplify redis string to flatbuffer string conversion.

* Fix linting and tests.

* fix

* cleanup

* simplify logic in ReplyWithTask
This commit is contained in:
Robert Nishihara
2017-03-10 18:35:25 -08:00
committed by Philipp Moritz
parent 555dcf35a2
commit 53dffe0bf2
17 changed files with 379 additions and 385 deletions
+59 -27
View File
@@ -11,6 +11,10 @@ import unittest
import redis
import ray.services
# Import flatbuffer bindings.
from ray.core.generated.SubscribeToNotificationsReply import SubscribeToNotificationsReply
from ray.core.generated.TaskReply import TaskReply
OBJECT_INFO_PREFIX = "OI:"
OBJECT_LOCATION_PREFIX = "OL:"
OBJECT_SUBSCRIBE_PREFIX = "OS:"
@@ -142,6 +146,15 @@ class TestGlobalStateStore(unittest.TestCase):
self.assertEqual(set(response), set())
def testObjectTableSubscribeToNotifications(self):
# Define a helper method for checking the contents of object notifications.
def check_object_notification(notification_message, object_id, object_size, manager_ids):
notification_object = SubscribeToNotificationsReply.GetRootAsSubscribeToNotificationsReply(notification_message, 0)
self.assertEqual(notification_object.ObjectId(), object_id)
self.assertEqual(notification_object.ObjectSize(), object_size)
self.assertEqual(notification_object.ManagerIdsLength(), len(manager_ids))
for i in range(len(manager_ids)):
self.assertEqual(notification_object.ManagerIds(i), manager_ids[i])
data_size = 0xf1f0
p = self.redis.pubsub()
# Subscribe to an object ID.
@@ -151,8 +164,12 @@ class TestGlobalStateStore(unittest.TestCase):
self.assertEqual(get_next_message(p)["data"], 1)
# Request a notification and receive the data.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id1")
self.assertEqual(get_next_message(p)["data"], b"object_id1 %s MANAGERS manager_id2"\
%integerToAsciiHex(data_size, 8))
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id1",
data_size,
[b"manager_id2"])
# Request a notification for an object that isn't there. Then add the object
# and receive the data. Only the first call to RAY.OBJECT_TABLE_ADD should
# trigger notifications.
@@ -160,15 +177,24 @@ class TestGlobalStateStore(unittest.TestCase):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id2")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id3")
self.assertEqual(get_next_message(p)["data"], b"object_id3 %s MANAGERS manager_id1"\
%integerToAsciiHex(data_size, 8))
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id3",
data_size,
[b"manager_id1"])
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", data_size, "hash1", "manager_id3")
self.assertEqual(get_next_message(p)["data"], b"object_id2 %s MANAGERS manager_id3"\
%integerToAsciiHex(data_size, 8))
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id2",
data_size,
[b"manager_id3"])
# Request notifications for object_id3 again.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id3")
self.assertEqual(get_next_message(p)["data"], b"object_id3 %s MANAGERS manager_id1 manager_id2 manager_id3"\
%integerToAsciiHex(data_size, 8))
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id3",
data_size,
[b"manager_id1", b"manager_id2", b"manager_id3"])
def testResultTableAddAndLookup(self):
# Try looking up something in the result table before anything is added.
@@ -205,10 +231,6 @@ class TestGlobalStateStore(unittest.TestCase):
# Non-integer scheduling states should not be added.
self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id",
"invalid_state", "node_id", "task_spec")
with self.assertRaises(redis.ResponseError):
# Scheduling states with invalid width should not be added.
self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", 101,
"node_id", "task_spec")
with self.assertRaises(redis.ResponseError):
# Should not be able to update a non-existent task.
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", 10,
@@ -219,17 +241,24 @@ class TestGlobalStateStore(unittest.TestCase):
TASK_STATUS_SCHEDULED = 2
TASK_STATUS_QUEUED = 4
def check_task_reply(message, task_args):
task_status, local_scheduler_id, task_spec = task_args
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
self.assertEqual(task_reply_object.State(), task_status)
self.assertEqual(task_reply_object.LocalSchedulerId(), local_scheduler_id)
self.assertEqual(task_reply_object.TaskSpec(), task_spec)
# Check that task table adds, updates, and lookups work correctly.
task_args = [TASK_STATUS_WAITING, b"node_id", b"task_spec"]
response = self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id",
*task_args)
response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
self.assertEqual(response, task_args)
check_task_reply(response, task_args)
task_args[0] = TASK_STATUS_SCHEDULED
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", *task_args[:2])
response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
self.assertEqual(response, task_args)
check_task_reply(response, task_args)
# If the current value, test value, and set value are all the same, the
# update happens, and the response is still the same task.
@@ -237,10 +266,10 @@ class TestGlobalStateStore(unittest.TestCase):
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
self.assertEqual(response, task_args[1:])
check_task_reply(response, task_args[1:])
# Check that the task entry is still the same.
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
self.assertEqual(get_response, task_args[1:])
check_task_reply(get_response, task_args[1:])
# If the current value is the same as the test value, and the set value is
# different, the update happens, and the response is the entire task.
@@ -248,10 +277,10 @@ class TestGlobalStateStore(unittest.TestCase):
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
self.assertEqual(response, task_args[1:])
check_task_reply(response, task_args[1:])
# Check that the update happened.
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
self.assertEqual(get_response, task_args[1:])
check_task_reply(get_response, task_args[1:])
# If the current value is no longer the same as the test value, the
# response is nil.
@@ -271,7 +300,7 @@ class TestGlobalStateStore(unittest.TestCase):
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
self.assertEqual(response, task_args[1:])
check_task_reply(response, task_args[1:])
# If the test value is a bitmask that does not match the current value, the
# update does not happen.
@@ -288,13 +317,13 @@ class TestGlobalStateStore(unittest.TestCase):
def testTaskTableSubscribe(self):
scheduling_state = 1
node_id = "node_id"
local_scheduler_id = "local_scheduler_id"
# Subscribe to the task table.
p = self.redis.pubsub()
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
p.psubscribe("{prefix}*:{state: >2}".format(prefix=TASK_PREFIX, state=scheduling_state))
p.psubscribe("{prefix}{node}:*".format(prefix=TASK_PREFIX, node=node_id))
task_args = [b"task_id", scheduling_state, node_id.encode("ascii"), b"task_spec"]
p.psubscribe("{prefix}*:{state}".format(prefix=TASK_PREFIX, state=scheduling_state))
p.psubscribe("{prefix}{local_scheduler_id}:*".format(prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
task_args = [b"task_id", scheduling_state, local_scheduler_id.encode("ascii"), b"task_spec"]
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
# Receive the acknowledgement message.
self.assertEqual(get_next_message(p)["data"], 1)
@@ -302,10 +331,13 @@ class TestGlobalStateStore(unittest.TestCase):
self.assertEqual(get_next_message(p)["data"], 3)
# Receive the actual data.
for i in range(3):
message = get_next_message(p)["data"]
message = message.split()
message[1] = int(message[1])
self.assertEqual(message, task_args)
message = get_next_message(p)["data"]
# Check that the notification object is correct.
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
self.assertEqual(notification_object.TaskId(), b"task_id")
self.assertEqual(notification_object.State(), scheduling_state)
self.assertEqual(notification_object.LocalSchedulerId(), local_scheduler_id.encode("ascii"))
self.assertEqual(notification_object.TaskSpec(), b"task_spec")
if __name__ == "__main__":
unittest.main(verbosity=2)
+1 -1
View File
@@ -34,7 +34,7 @@ TASK_STATUS_QUEUED = 4
TASK_STATUS_RUNNING = 8
TASK_STATUS_DONE = 16
# These constants are an implementation detail of ray_redis_module.c, so this
# These constants are an implementation detail of ray_redis_module.cc, so this
# must be kept in sync with that file.
DB_CLIENT_PREFIX = "CL:"
TASK_PREFIX = "TT:"
+10 -9
View File
@@ -12,13 +12,16 @@ import time
from ray.services import get_ip_address
from ray.services import get_port
# Import flatbuffer bindings.
from ray.core.generated.SubscribeToDBClientTableReply import SubscribeToDBClientTableReply
# These variables must be kept in sync with the C codebase.
# common/common.h
DB_CLIENT_ID_SIZE = 20
NIL_ID = b"\xff" * DB_CLIENT_ID_SIZE
# common/task.h
TASK_STATUS_LOST = 32
# common/redis_module/ray_redis_module.c
# common/redis_module/ray_redis_module.cc
TASK_PREFIX = "TT:"
DB_CLIENT_PREFIX = "CL:"
DB_CLIENT_TABLE_NAME = b"db_clients"
@@ -89,14 +92,12 @@ class Monitor(object):
# Parse the message.
data = message["data"]
db_client_id = data[:DB_CLIENT_ID_SIZE]
data = data[DB_CLIENT_ID_SIZE + 1:]
data = data.split(b" ")
client_type, auxiliary_address, is_insertion = data
is_insertion = int(is_insertion)
if is_insertion != 1 and is_insertion != 0:
raise Exception("Expected 0 or 1 for insertion field, got {} instead".format(is_insertion))
is_insertion = bool(is_insertion)
notification_object = SubscribeToDBClientTableReply.GetRootAsSubscribeToDBClientTableReply(data, 0)
db_client_id = notification_object.DbClientId()
client_type = notification_object.ClientType()
auxiliary_address = notification_object.AuxAddress()
is_insertion = notification_object.IsInsertion()
return db_client_id, client_type, auxiliary_address, is_insertion
+1 -1
View File
@@ -723,7 +723,7 @@ def get_address_info_from_redis_helper(redis_address, node_ip_address):
# must have run "CONFIG SET protected-mode no".
redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port))
# The client table prefix must be kept in sync with the file
# "src/common/redis_module/ray_redis_module.c" where it is defined.
# "src/common/redis_module/ray_redis_module.cc" where it is defined.
REDIS_CLIENT_TABLE_PREFIX = "CL:"
client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX))
# Filter to clients on the same node and do some basic checking.