Use flatbuffers for some messages from Redis. (#341)

* Compile the Ray redis module with C++.

* Redo parsing of object table notifications with flatbuffers.

* Update redis module python tests.

* Redo parsing of task table notifications with flatbuffers.

* Fix linting.

* Redo parsing of db client notifications with flatbuffers.

* Redo publishing of local scheduler heartbeats with flatbuffers.

* Fix linting.

* Remove usage of fixed-width formatting of scheduling state in channel name.

* Reply with flatbuffer object to task table queries, also simplify redis string to flatbuffer string conversion.

* Fix linting and tests.

* fix

* cleanup

* simplify logic in ReplyWithTask
This commit is contained in:
Robert Nishihara
2017-03-10 18:35:25 -08:00
committed by Philipp Moritz
parent 555dcf35a2
commit 53dffe0bf2
17 changed files with 379 additions and 385 deletions
+4 -4
View File
@@ -20,7 +20,7 @@ fi
if [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "linux" ]]; then
sudo apt-get update
sudo apt-get install -y cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip
sudo pip install cloudpickle funcsigs colorama psutil redis tensorflow
sudo pip install cloudpickle funcsigs colorama psutil redis tensorflow flatbuffers
elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then
sudo apt-get update
sudo apt-get install -y cmake python-dev python-numpy build-essential autoconf curl libtool libboost-all-dev unzip
@@ -28,7 +28,7 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then
wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh
bash miniconda.sh -b -p $HOME/miniconda
export PATH="$HOME/miniconda/bin:$PATH"
pip install numpy cloudpickle funcsigs colorama psutil redis tensorflow
pip install numpy cloudpickle funcsigs colorama psutil redis tensorflow flatbuffers
elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then
# check that brew is installed
which -s brew
@@ -41,7 +41,7 @@ elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then
fi
brew install cmake automake autoconf libtool boost
sudo easy_install pip
sudo pip install numpy cloudpickle funcsigs colorama psutil redis tensorflow --ignore-installed six
sudo pip install numpy cloudpickle funcsigs colorama psutil redis tensorflow flatbuffers --ignore-installed six
elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then
# check that brew is installed
which -s brew
@@ -57,7 +57,7 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then
wget https://repo.continuum.io/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -O miniconda.sh
bash miniconda.sh -b -p $HOME/miniconda
export PATH="$HOME/miniconda/bin:$PATH"
pip install numpy cloudpickle funcsigs colorama psutil redis tensorflow
pip install numpy cloudpickle funcsigs colorama psutil redis tensorflow flatbuffers
elif [[ "$LINT" == "1" ]]; then
sudo apt-get update
sudo apt-get install -y cmake build-essential autoconf curl libtool libboost-all-dev unzip
+1 -1
View File
@@ -11,7 +11,7 @@ To install Ray, first install the following dependencies. We recommend using
brew update
brew install cmake automake autoconf libtool boost wget
pip install numpy cloudpickle funcsigs colorama psutil redis --ignore-installed six
pip install numpy cloudpickle funcsigs colorama psutil redis flatbuffers --ignore-installed six
```
If you are using Anaconda, you may also need to run the following.
+1 -1
View File
@@ -12,7 +12,7 @@ To install Ray, first install the following dependencies. We recommend using
sudo apt-get update
sudo apt-get install -y cmake build-essential autoconf curl libtool libboost-all-dev unzip python-dev python-pip # If you're using Anaconda, then python-dev and python-pip are unnecessary.
pip install numpy cloudpickle funcsigs colorama psutil redis
pip install numpy cloudpickle funcsigs colorama psutil redis flatbuffers
```
If you are using Anaconda, you may also need to run the following.
+1
View File
@@ -11,4 +11,5 @@ RUN echo 'export PATH=/opt/conda/bin:$PATH' > /etc/profile.d/conda.sh \
&& rm /tmp/anaconda.sh
ENV PATH "/opt/conda/bin:$PATH"
RUN conda install -y libgcc
RUN pip install flatbuffers
RUN pip install --upgrade pip cloudpickle
+59 -27
View File
@@ -11,6 +11,10 @@ import unittest
import redis
import ray.services
# Import flatbuffer bindings.
from ray.core.generated.SubscribeToNotificationsReply import SubscribeToNotificationsReply
from ray.core.generated.TaskReply import TaskReply
OBJECT_INFO_PREFIX = "OI:"
OBJECT_LOCATION_PREFIX = "OL:"
OBJECT_SUBSCRIBE_PREFIX = "OS:"
@@ -142,6 +146,15 @@ class TestGlobalStateStore(unittest.TestCase):
self.assertEqual(set(response), set())
def testObjectTableSubscribeToNotifications(self):
# Helper that decodes a published object notification and checks every field.
def check_object_notification(notification_message, object_id, object_size, manager_ids):
  """Assert that a serialized object notification has the expected contents.

  Args:
    notification_message: The raw bytes received on the notification channel,
      parsed here as a SubscribeToNotificationsReply flatbuffer.
    object_id: The object ID the notification should carry.
    object_size: The object size the notification should carry.
    manager_ids: The manager IDs the notification should list, in order.
  """
  reply = SubscribeToNotificationsReply.GetRootAsSubscribeToNotificationsReply(
      notification_message, 0)
  self.assertEqual(reply.ObjectId(), object_id)
  self.assertEqual(reply.ObjectSize(), object_size)
  # Compare the count first so that a missing or extra manager is reported
  # before the element-wise comparison below.
  self.assertEqual(reply.ManagerIdsLength(), len(manager_ids))
  for index, expected_manager_id in enumerate(manager_ids):
    self.assertEqual(reply.ManagerIds(index), expected_manager_id)
data_size = 0xf1f0
p = self.redis.pubsub()
# Subscribe to an object ID.
@@ -151,8 +164,12 @@ class TestGlobalStateStore(unittest.TestCase):
self.assertEqual(get_next_message(p)["data"], 1)
# Request a notification and receive the data.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id1")
self.assertEqual(get_next_message(p)["data"], b"object_id1 %s MANAGERS manager_id2"\
%integerToAsciiHex(data_size, 8))
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id1",
data_size,
[b"manager_id2"])
# Request a notification for an object that isn't there. Then add the object
# and receive the data. Only the first call to RAY.OBJECT_TABLE_ADD should
# trigger notifications.
@@ -160,15 +177,24 @@ class TestGlobalStateStore(unittest.TestCase):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id2")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id3")
self.assertEqual(get_next_message(p)["data"], b"object_id3 %s MANAGERS manager_id1"\
%integerToAsciiHex(data_size, 8))
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id3",
data_size,
[b"manager_id1"])
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", data_size, "hash1", "manager_id3")
self.assertEqual(get_next_message(p)["data"], b"object_id2 %s MANAGERS manager_id3"\
%integerToAsciiHex(data_size, 8))
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id2",
data_size,
[b"manager_id3"])
# Request notifications for object_id3 again.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id3")
self.assertEqual(get_next_message(p)["data"], b"object_id3 %s MANAGERS manager_id1 manager_id2 manager_id3"\
%integerToAsciiHex(data_size, 8))
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id3",
data_size,
[b"manager_id1", b"manager_id2", b"manager_id3"])
def testResultTableAddAndLookup(self):
# Try looking up something in the result table before anything is added.
@@ -205,10 +231,6 @@ class TestGlobalStateStore(unittest.TestCase):
# Non-integer scheduling states should not be added.
self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id",
"invalid_state", "node_id", "task_spec")
with self.assertRaises(redis.ResponseError):
# Scheduling states with invalid width should not be added.
self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", 101,
"node_id", "task_spec")
with self.assertRaises(redis.ResponseError):
# Should not be able to update a non-existent task.
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", 10,
@@ -219,17 +241,24 @@ class TestGlobalStateStore(unittest.TestCase):
TASK_STATUS_SCHEDULED = 2
TASK_STATUS_QUEUED = 4
def check_task_reply(message, task_args):
  """Assert that a serialized TaskReply matches the expected task fields.

  Args:
    message: The raw bytes returned by the task table command, parsed here as
      a TaskReply flatbuffer.
    task_args: A sequence of exactly three items: the expected scheduling
      state, the expected local scheduler ID, and the expected task spec.
  """
  expected_state, expected_scheduler_id, expected_spec = task_args
  reply = TaskReply.GetRootAsTaskReply(message, 0)
  self.assertEqual(reply.State(), expected_state)
  self.assertEqual(reply.LocalSchedulerId(), expected_scheduler_id)
  self.assertEqual(reply.TaskSpec(), expected_spec)
# Check that task table adds, updates, and lookups work correctly.
task_args = [TASK_STATUS_WAITING, b"node_id", b"task_spec"]
response = self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id",
*task_args)
response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
self.assertEqual(response, task_args)
check_task_reply(response, task_args)
task_args[0] = TASK_STATUS_SCHEDULED
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", *task_args[:2])
response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
self.assertEqual(response, task_args)
check_task_reply(response, task_args)
# If the current value, test value, and set value are all the same, the
# update happens, and the response is still the same task.
@@ -237,10 +266,10 @@ class TestGlobalStateStore(unittest.TestCase):
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
self.assertEqual(response, task_args[1:])
check_task_reply(response, task_args[1:])
# Check that the task entry is still the same.
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
self.assertEqual(get_response, task_args[1:])
check_task_reply(get_response, task_args[1:])
# If the current value is the same as the test value, and the set value is
# different, the update happens, and the response is the entire task.
@@ -248,10 +277,10 @@ class TestGlobalStateStore(unittest.TestCase):
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
self.assertEqual(response, task_args[1:])
check_task_reply(response, task_args[1:])
# Check that the update happened.
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
self.assertEqual(get_response, task_args[1:])
check_task_reply(get_response, task_args[1:])
# If the current value is no longer the same as the test value, the
# response is nil.
@@ -271,7 +300,7 @@ class TestGlobalStateStore(unittest.TestCase):
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
self.assertEqual(response, task_args[1:])
check_task_reply(response, task_args[1:])
# If the test value is a bitmask that does not match the current value, the
# update does not happen.
@@ -288,13 +317,13 @@ class TestGlobalStateStore(unittest.TestCase):
def testTaskTableSubscribe(self):
scheduling_state = 1
node_id = "node_id"
local_scheduler_id = "local_scheduler_id"
# Subscribe to the task table.
p = self.redis.pubsub()
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
p.psubscribe("{prefix}*:{state: >2}".format(prefix=TASK_PREFIX, state=scheduling_state))
p.psubscribe("{prefix}{node}:*".format(prefix=TASK_PREFIX, node=node_id))
task_args = [b"task_id", scheduling_state, node_id.encode("ascii"), b"task_spec"]
p.psubscribe("{prefix}*:{state}".format(prefix=TASK_PREFIX, state=scheduling_state))
p.psubscribe("{prefix}{local_scheduler_id}:*".format(prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
task_args = [b"task_id", scheduling_state, local_scheduler_id.encode("ascii"), b"task_spec"]
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
# Receive the acknowledgement message.
self.assertEqual(get_next_message(p)["data"], 1)
@@ -302,10 +331,13 @@ class TestGlobalStateStore(unittest.TestCase):
self.assertEqual(get_next_message(p)["data"], 3)
# Receive the actual data.
for i in range(3):
message = get_next_message(p)["data"]
message = message.split()
message[1] = int(message[1])
self.assertEqual(message, task_args)
message = get_next_message(p)["data"]
# Check that the notification object is correct.
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
self.assertEqual(notification_object.TaskId(), b"task_id")
self.assertEqual(notification_object.State(), scheduling_state)
self.assertEqual(notification_object.LocalSchedulerId(), local_scheduler_id.encode("ascii"))
self.assertEqual(notification_object.TaskSpec(), b"task_spec")
if __name__ == "__main__":
unittest.main(verbosity=2)
+1 -1
View File
@@ -34,7 +34,7 @@ TASK_STATUS_QUEUED = 4
TASK_STATUS_RUNNING = 8
TASK_STATUS_DONE = 16
# These constants are an implementation detail of ray_redis_module.c, so this
# These constants are an implementation detail of ray_redis_module.cc, so this
# must be kept in sync with that file.
DB_CLIENT_PREFIX = "CL:"
TASK_PREFIX = "TT:"
+10 -9
View File
@@ -12,13 +12,16 @@ import time
from ray.services import get_ip_address
from ray.services import get_port
# Import flatbuffer bindings.
from ray.core.generated.SubscribeToDBClientTableReply import SubscribeToDBClientTableReply
# These variables must be kept in sync with the C codebase.
# common/common.h
DB_CLIENT_ID_SIZE = 20
NIL_ID = b"\xff" * DB_CLIENT_ID_SIZE
# common/task.h
TASK_STATUS_LOST = 32
# common/redis_module/ray_redis_module.c
# common/redis_module/ray_redis_module.cc
TASK_PREFIX = "TT:"
DB_CLIENT_PREFIX = "CL:"
DB_CLIENT_TABLE_NAME = b"db_clients"
@@ -89,14 +92,12 @@ class Monitor(object):
# Parse the message.
data = message["data"]
db_client_id = data[:DB_CLIENT_ID_SIZE]
data = data[DB_CLIENT_ID_SIZE + 1:]
data = data.split(b" ")
client_type, auxiliary_address, is_insertion = data
is_insertion = int(is_insertion)
if is_insertion != 1 and is_insertion != 0:
raise Exception("Expected 0 or 1 for insertion field, got {} instead".format(is_insertion))
is_insertion = bool(is_insertion)
notification_object = SubscribeToDBClientTableReply.GetRootAsSubscribeToDBClientTableReply(data, 0)
db_client_id = notification_object.DbClientId()
client_type = notification_object.ClientType()
auxiliary_address = notification_object.AuxAddress()
is_insertion = notification_object.IsInsertion()
return db_client_id, client_type, auxiliary_address, is_insertion
+1 -1
View File
@@ -723,7 +723,7 @@ def get_address_info_from_redis_helper(redis_address, node_ip_address):
# must have run "CONFIG SET protected-mode no".
redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port))
# The client table prefix must be kept in sync with the file
# "src/common/redis_module/ray_redis_module.c" where it is defined.
# "src/common/redis_module/ray_redis_module.cc" where it is defined.
REDIS_CLIENT_TABLE_PREFIX = "CL:"
client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX))
# Filter to clients on the same node and do some basic checking.
+1 -1
View File
@@ -21,7 +21,7 @@ def check_no_existing_redis_clients(node_ip_address, redis_address):
redis_ip_address, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port))
# The client table prefix must be kept in sync with the file
# "src/common/redis_module/ray_redis_module.c" where it is defined.
# "src/common/redis_module/ray_redis_module.cc" where it is defined.
REDIS_CLIENT_TABLE_PREFIX = "CL:"
client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX))
# Filter to clients on the same node and do some basic checking.
+9
View File
@@ -24,6 +24,15 @@ add_custom_command(
COMMENT "Running flatc compiler on ${COMMON_FBS_SRC}"
VERBATIM)
# Generate Python bindings for the flatbuffers objects. The command is attached
# to the gen_common_fbs target as an explicit POST_BUILD step, so it runs after
# the other rules of that target. The TARGET form of add_custom_command does
# not accept DEPENDS, so no DEPENDS clause is given here.
set(PYTHON_OUTPUT_DIR ${CMAKE_BINARY_DIR}/generated/)
add_custom_command(
  TARGET gen_common_fbs
  POST_BUILD
  COMMAND ${FLATBUFFERS_COMPILER} -p -o ${PYTHON_OUTPUT_DIR} ${COMMON_FBS_SRC}
  COMMENT "Running flatc compiler on ${COMMON_FBS_SRC}"
  VERBATIM)
add_dependencies(gen_common_fbs flatbuffers_ep)
add_custom_target(
+59
View File
@@ -50,3 +50,62 @@ table TaskInfo {
}
root_type TaskInfo;
// Notification describing one object in the object table. The Redis module
// (PublishObjectNotification in ray_redis_module.cc) serializes this table and
// publishes it on the channel named by OBJECT_CHANNEL_PREFIX followed by the
// ID of the client that is being notified.
table SubscribeToNotificationsReply {
// The object ID of the object that the notification is about.
object_id: string;
// The size of the object.
object_size: long;
// The IDs of the managers that contain this object, in the order in which
// they are read from the object table entry.
manager_ids: [string];
}
root_type SubscribeToNotificationsReply;
// Description of one task table entry. The Redis module uses this table both
// as the reply to task table queries (ReplyWithTask in ray_redis_module.cc)
// and as the payload published to task table subscribers (TaskTableWrite).
table TaskReply {
// The task ID of the task that the message is about.
task_id: string;
// The state of the task. This is encoded as a bit mask of scheduling_state
// enum values in task.h.
state: long;
// The ID of the local scheduler associated with the task, if any.
local_scheduler_id: string;
// A string of bytes representing the task specification.
task_spec: string;
}
root_type TaskReply;
// Notification about an insertion into or a deletion from the db client
// table. The Redis module (PublishDBClientNotification in
// ray_redis_module.cc) serializes this table and publishes it on the
// "db_clients" channel.
table SubscribeToDBClientTableReply {
// The db client ID of the client that the message is about.
db_client_id: string;
// The type of the client.
client_type: string;
// If the client is a local scheduler, this is the address of the plasma
// manager that the local scheduler is connected to. Otherwise, it is empty.
aux_address: string;
// True if the message is about the addition of a client and false if it is
// about the deletion of a client.
is_insertion: bool;
}
root_type SubscribeToDBClientTableReply;
// Snapshot of the load and resources of one local scheduler.
// NOTE(review): presumably this is the payload of the local scheduler
// heartbeats; the code that publishes it is not visible here, so confirm
// against the local scheduler sources.
table LocalSchedulerInfoMessage {
// The db client ID of the client that the message is about.
db_client_id: string;
// The total number of workers that are connected to this local scheduler.
total_num_workers: long;
// The number of tasks queued in this local scheduler.
task_queue_length: long;
// The number of workers that are available and waiting for tasks.
available_workers: long;
// The resource vector of resources generally available to this local
// scheduler.
static_resources: [double];
// The resource vector of resources currently available to this local
// scheduler.
dynamic_resources: [double];
}
root_type LocalSchedulerInfoMessage;
+3 -3
View File
@@ -3,15 +3,15 @@ cmake_minimum_required(VERSION 2.8)
project(ray_redis_module)
if(APPLE)
set(REDIS_MODULE_CFLAGS -W -Wall -dynamic -fno-common -g -ggdb -std=c99 -O2)
set(REDIS_MODULE_CFLAGS -W -Wall -dynamic -fno-common -g -ggdb -std=c++11 -O2)
set(REDIS_MODULE_LDFLAGS "-undefined dynamic_lookup")
set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
else()
set(REDIS_MODULE_CFLAGS -W -Wall -fno-common -g -ggdb -std=c99 -O2)
set(REDIS_MODULE_CFLAGS -W -Wall -fno-common -g -ggdb -std=c++11 -O2)
set(REDIS_MODULE_LDFLAGS -shared)
endif()
add_library(ray_redis_module SHARED ray_redis_module.c)
add_library(ray_redis_module SHARED ray_redis_module.cc)
target_compile_options(ray_redis_module PUBLIC ${REDIS_MODULE_CFLAGS} -fPIC)
target_link_libraries(ray_redis_module ${REDIS_MODULE_LDFLAGS})
@@ -5,6 +5,10 @@
#include "redis_string.h"
#include "format/common_generated.h"
#include "common_protocol.h"
/**
* Various tables are maintained in redis:
*
@@ -42,11 +46,29 @@ RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx,
int mode) {
RedisModuleString *prefixed_keyname =
RedisString_Format(ctx, "%s%S", prefix, keyname);
RedisModuleKey *key = RedisModule_OpenKey(ctx, prefixed_keyname, mode);
RedisModuleKey *key =
(RedisModuleKey *) RedisModule_OpenKey(ctx, prefixed_keyname, mode);
RedisModule_FreeString(ctx, prefixed_keyname);
return key;
}
/**
 * Build a flatbuffer string holding the same bytes as a redis module string.
 *
 * The length reported by redis is passed to the builder explicitly instead of
 * being recomputed from the character data.
 *
 * @param fbb The flatbuffer builder in which the string is created.
 * @param redis_string The redis string whose contents are used.
 * @return The offset of the flatbuffer string inside the builder.
 */
flatbuffers::Offset<flatbuffers::String> RedisStringToFlatbuf(
    flatbuffers::FlatBufferBuilder &fbb,
    RedisModuleString *redis_string) {
  size_t length = 0;
  const char *data = RedisModule_StringPtrLen(redis_string, &length);
  return fbb.CreateString(data, length);
}
/**
* Publish a notification to a client's notification channel about an insertion
* or deletion to the db client table.
@@ -69,16 +91,23 @@ bool PublishDBClientNotification(RedisModuleCtx *ctx,
/* Construct strings to publish on the db client channel. */
RedisModuleString *channel_name =
RedisModule_CreateString(ctx, "db_clients", strlen("db_clients"));
RedisModuleString *client_info;
const char *is_insertion_string = is_insertion ? "1" : "0";
if (aux_address) {
client_info =
RedisString_Format(ctx, "%S:%S %S %s", ray_client_id, client_type,
aux_address, is_insertion_string);
/* Construct the flatbuffers object to publish over the channel. */
flatbuffers::FlatBufferBuilder fbb;
/* Use an empty aux address if one is not passed in. */
flatbuffers::Offset<flatbuffers::String> aux_address_str;
if (aux_address != NULL) {
aux_address_str = RedisStringToFlatbuf(fbb, aux_address);
} else {
client_info = RedisString_Format(ctx, "%S:%S : %s", ray_client_id,
client_type, is_insertion_string);
aux_address_str = fbb.CreateString("", strlen(""));
}
/* Create the flatbuffers message. */
auto message = CreateSubscribeToDBClientTableReply(
fbb, RedisStringToFlatbuf(fbb, ray_client_id),
RedisStringToFlatbuf(fbb, client_type), aux_address_str, is_insertion);
fbb.Finish(message);
/* Create a Redis string to publish by serializing the flatbuffers object. */
RedisModuleString *client_info = RedisModule_CreateString(
ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize());
/* Publish the client info on the db client channel. */
RedisModuleCallReply *reply;
@@ -328,48 +357,42 @@ bool PublishObjectNotification(RedisModuleCtx *ctx,
RedisModuleString *object_id,
RedisModuleString *data_size,
RedisModuleKey *key) {
/* Create a string formatted as "<object id> MANAGERS <size> <manager id1>
* <manager id2> ..." */
flatbuffers::FlatBufferBuilder fbb;
long long data_size_value;
if (RedisModule_StringToLongLong(data_size, &data_size_value) !=
REDISMODULE_OK) {
return RedisModule_ReplyWithError(ctx, "data_size must be integer");
}
RedisModuleString *manager_list = RedisString_Format(ctx, "%S ", object_id);
/* Append binary data size for this object. */
/* TODO(pcm): Replace by a formatted fix length version of the size. */
RedisModule_StringAppendBuffer(ctx, manager_list,
(const char *) &data_size_value,
sizeof(data_size_value));
RedisModule_StringAppendBuffer(ctx, manager_list, " MANAGERS",
strlen(" MANAGERS"));
std::vector<flatbuffers::Offset<flatbuffers::String>> manager_ids;
CHECK_ERROR(
RedisModule_ZsetFirstInScoreRange(key, REDISMODULE_NEGATIVE_INFINITE,
REDISMODULE_POSITIVE_INFINITE, 1, 1),
"Unable to initialize zset iterator");
/* Loop over the managers in the object table for this object ID. */
do {
RedisModuleString *curr = RedisModule_ZsetRangeCurrentElement(key, NULL);
RedisModule_StringAppendBuffer(ctx, manager_list, " ", 1);
size_t size;
const char *val = RedisModule_StringPtrLen(curr, &size);
RedisModule_StringAppendBuffer(ctx, manager_list, val, size);
manager_ids.push_back(RedisStringToFlatbuf(fbb, curr));
} while (RedisModule_ZsetRangeNext(key));
auto message = CreateSubscribeToNotificationsReply(
fbb, RedisStringToFlatbuf(fbb, object_id), data_size_value,
fbb.CreateVector(manager_ids));
fbb.Finish(message);
/* Publish the notification to the clients notification channel.
* TODO(rkn): These notifications could be batched together. */
RedisModuleString *channel_name =
RedisString_Format(ctx, "%s%S", OBJECT_CHANNEL_PREFIX, client_id);
RedisModuleString *payload = RedisModule_CreateString(
ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize());
RedisModuleCallReply *reply;
reply = RedisModule_Call(ctx, "PUBLISH", "ss", channel_name, manager_list);
reply = RedisModule_Call(ctx, "PUBLISH", "ss", channel_name, payload);
RedisModule_FreeString(ctx, channel_name);
RedisModule_FreeString(ctx, manager_list);
RedisModule_FreeString(ctx, payload);
if (reply == NULL) {
return false;
}
@@ -668,35 +691,6 @@ int ResultTableAdd_RedisCommand(RedisModuleCtx *ctx,
return REDISMODULE_OK;
}
int ParseTaskState(RedisModuleString *state) {
size_t state_length;
const char *state_string = RedisModule_StringPtrLen(state, &state_length);
int state_integer;
int scanned = sscanf(state_string, "%2d", &state_integer);
if (scanned != 1 || state_length != 2) {
return -1;
}
return state_integer;
}
RedisModuleString *NormalizeTaskState(RedisModuleCtx *ctx,
RedisModuleString *state) {
/* Pad the state integer to a fixed-width integer, and make sure it has width
* less than or equal to 2. */
long long state_integer;
int status = RedisModule_StringToLongLong(state, &state_integer);
if (status != REDISMODULE_OK) {
return NULL;
}
state = RedisModule_CreateStringPrintf(ctx, "%2d", state_integer);
size_t length;
RedisModule_StringPtrLen(state, &length);
if (length != 2) {
return NULL;
}
return state;
}
/**
* Reply with information about a task ID. This is used by
* RAY.RESULT_TABLE_LOOKUP and RAY.TASK_TABLE_GET.
@@ -716,8 +710,9 @@ int ReplyWithTask(RedisModuleCtx *ctx, RedisModuleString *task_id) {
RedisModuleString *state = NULL;
RedisModuleString *local_scheduler_id = NULL;
RedisModuleString *task_spec = NULL;
RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "state", &state, "node",
&local_scheduler_id, "TaskSpec", &task_spec, NULL);
RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "state", &state,
"local_scheduler_id", &local_scheduler_id, "TaskSpec",
&task_spec, NULL);
if (state == NULL || local_scheduler_id == NULL || task_spec == NULL) {
/* We must have either all fields or no fields. */
RedisModule_CloseKey(key);
@@ -725,21 +720,26 @@ int ReplyWithTask(RedisModuleCtx *ctx, RedisModuleString *task_id) {
ctx, "Missing fields in the task table entry");
}
int state_integer = ParseTaskState(state);
if (state_integer < 0) {
long long state_integer;
if (RedisModule_StringToLongLong(state, &state_integer) != REDISMODULE_OK ||
state_integer < 0) {
RedisModule_CloseKey(key);
RedisModule_FreeString(ctx, state);
RedisModule_FreeString(ctx, local_scheduler_id);
RedisModule_FreeString(ctx, task_spec);
return RedisModule_ReplyWithError(ctx,
"Found invalid scheduling state (must "
"be an integer of width 2");
return RedisModule_ReplyWithError(ctx, "Found invalid scheduling state.");
}
RedisModule_ReplyWithArray(ctx, 3);
RedisModule_ReplyWithLongLong(ctx, state_integer);
RedisModule_ReplyWithString(ctx, local_scheduler_id);
RedisModule_ReplyWithString(ctx, task_spec);
flatbuffers::FlatBufferBuilder fbb;
auto message =
CreateTaskReply(fbb, RedisStringToFlatbuf(fbb, task_id), state_integer,
RedisStringToFlatbuf(fbb, local_scheduler_id),
RedisStringToFlatbuf(fbb, task_spec));
fbb.Finish(message);
RedisModuleString *reply = RedisModule_CreateString(
ctx, (char *) fbb.GetBufferPointer(), fbb.GetSize());
RedisModule_ReplyWithString(ctx, reply);
RedisModule_FreeString(ctx, state);
RedisModule_FreeString(ctx, local_scheduler_id);
@@ -801,54 +801,59 @@ int ResultTableLookup_RedisCommand(RedisModuleCtx *ctx,
int TaskTableWrite(RedisModuleCtx *ctx,
RedisModuleString *task_id,
RedisModuleString *state,
RedisModuleString *node_id,
RedisModuleString *local_scheduler_id,
RedisModuleString *task_spec) {
/* Pad the state integer to a fixed-width integer, and make sure it has width
* less than or equal to 2. */
state = NormalizeTaskState(ctx, state);
if (state == NULL) {
return RedisModule_ReplyWithError(
ctx,
"Invalid scheduling state (must be an integer of width at most 2)");
/* Extract the scheduling state. */
long long state_value;
if (RedisModule_StringToLongLong(state, &state_value) != REDISMODULE_OK) {
return RedisModule_ReplyWithError(ctx, "scheduling state must be integer");
}
/* Add the task to the task table. If no spec was provided, get the existing
* spec out of the task table so we can publish it. */
RedisModuleString *existing_task_spec = NULL;
RedisModuleKey *key =
OpenPrefixedKey(ctx, TASK_PREFIX, task_id, REDISMODULE_WRITE);
if (task_spec == NULL) {
RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state, "node",
node_id, NULL);
RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state,
"local_scheduler_id", local_scheduler_id, NULL);
RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "TaskSpec",
&existing_task_spec, NULL);
if (existing_task_spec == NULL) {
RedisModule_CloseKey(key);
RedisModule_FreeString(ctx, state);
return RedisModule_ReplyWithError(
ctx, "Cannot update a task that doesn't exist yet");
}
} else {
RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state, "node",
node_id, "TaskSpec", task_spec, NULL);
RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state,
"local_scheduler_id", local_scheduler_id, "TaskSpec",
task_spec, NULL);
}
RedisModule_CloseKey(key);
/* Build the PUBLISH topic and message for task table subscribers. The topic
* is a string in the format "TASK_PREFIX:<node ID>:<state>". The
* message is a string in the format: "<task ID> <state> <node ID> <task
* specification>". */
RedisModuleString *publish_topic =
RedisString_Format(ctx, "%s%S:%S", TASK_PREFIX, node_id, state);
RedisModuleString *publish_message;
* is a string in the format "TASK_PREFIX:<local scheduler ID>:<state>". The
* message is a serialized TaskReply flatbuffer object. */
RedisModuleString *publish_topic = RedisString_Format(
ctx, "%s%S:%S", TASK_PREFIX, local_scheduler_id, state);
/* Construct the flatbuffers object for the payload. */
flatbuffers::FlatBufferBuilder fbb;
/* Use the old task spec if the current one is NULL. */
RedisModuleString *task_spec_to_use;
if (task_spec != NULL) {
publish_message = RedisString_Format(ctx, "%S %S %S %S", task_id, state,
node_id, task_spec);
task_spec_to_use = task_spec;
} else {
publish_message = RedisString_Format(ctx, "%S %S %S %S", task_id, state,
node_id, existing_task_spec);
task_spec_to_use = existing_task_spec;
}
RedisModule_FreeString(ctx, state);
/* Create the flatbuffers message. */
auto message =
CreateTaskReply(fbb, RedisStringToFlatbuf(fbb, task_id), state_value,
RedisStringToFlatbuf(fbb, local_scheduler_id),
RedisStringToFlatbuf(fbb, task_spec_to_use));
fbb.Finish(message);
RedisModuleString *publish_message = RedisModule_CreateString(
ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize());
RedisModuleCallReply *reply =
RedisModule_Call(ctx, "PUBLISH", "ss", publish_topic, publish_message);
@@ -878,10 +883,7 @@ int TaskTableWrite(RedisModuleCtx *ctx,
*
* @param task_id A string that is the ID of the task.
* @param state A string that is the current scheduling state (a
* scheduling_state enum instance). The string's value must be a
* nonnegative integer less than 100, so that it has width at most 2. If
* less than 2, the value will be left-padded with spaces to a width of
* 2.
* scheduling_state enum instance).
* @param local_scheduler_id A string that is the ray client ID of the
* associated local scheduler, if any.
* @param task_spec A string that is the specification of the task, which can
@@ -908,10 +910,7 @@ int TaskTableAddTask_RedisCommand(RedisModuleCtx *ctx,
*
* @param task_id A string that is the ID of the task.
* @param state A string that is the current scheduling state (a
* scheduling_state enum instance). The string's value must be a
* nonnegative integer less than 100, so that it has width at most 2. If
* less than 2, the value will be left-padded with spaces to a width of
* 2.
* scheduling_state enum instance).
* @param ray_client_id A string that is the ray client ID of the associated
* local scheduler, if any.
* @return OK if the operation was successful.
@@ -941,18 +940,15 @@ int TaskTableUpdate_RedisCommand(RedisModuleCtx *ctx,
* scheduling state. The update happens if and only if the current
* scheduling state AND-ed with the bitmask is greater than 0.
* @param state A string that is the scheduling state (a scheduling_state enum
* instance) to update the task entry with. The string's value must be a
* nonnegative integer less than 100, so that it has width at most 2. If
* less than 2, the value will be left-padded with spaces to a width of
* 2.
* instance) to update the task entry with.
* @param ray_client_id A string that is the ray client ID of the associated
* local scheduler, if any, to update the task entry with.
* @return If the current scheduling state does not match the test bitmask,
* returns nil. Else, returns the same as RAY.TASK_TABLE_GET: an array
* of strings representing the updated task fields in the following
* order: 1) (integer) scheduling state 2) (string) associated node ID,
* if any 3) (string) the task specification, which can be casted to a
* task_spec.
* order: 1) (integer) scheduling state 2) (string) associated local
* scheduler ID, if any 3) (string) the task specification, which can be
* cast to a task_spec.
*/
int TaskTableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx,
RedisModuleString **argv,
@@ -961,18 +957,12 @@ int TaskTableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx,
return RedisModule_WrongArity(ctx);
}
RedisModuleString *state = NormalizeTaskState(ctx, argv[3]);
if (state == NULL) {
return RedisModule_ReplyWithError(
ctx,
"Invalid scheduling state (must be an integer of width at most 2)");
}
RedisModuleString *state = argv[3];
RedisModuleKey *key = OpenPrefixedKey(ctx, TASK_PREFIX, argv[1],
REDISMODULE_READ | REDISMODULE_WRITE);
if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) {
RedisModule_CloseKey(key);
RedisModule_FreeString(ctx, state);
return RedisModule_ReplyWithNull(ctx);
}
@@ -980,19 +970,20 @@ int TaskTableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx,
RedisModuleString *current_state = NULL;
RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "state", &current_state,
NULL);
int current_state_integer = ParseTaskState(current_state);
long long current_state_integer;
if (RedisModule_StringToLongLong(current_state, &current_state_integer) !=
REDISMODULE_OK) {
return RedisModule_ReplyWithError(ctx, "current_state must be integer");
}
if (current_state_integer < 0) {
RedisModule_CloseKey(key);
RedisModule_FreeString(ctx, state);
return RedisModule_ReplyWithError(ctx,
"Found invalid scheduling state (must "
"be an integer of width 2");
return RedisModule_ReplyWithError(ctx, "Found invalid scheduling state.");
}
long long test_state_bitmask;
int status = RedisModule_StringToLongLong(argv[2], &test_state_bitmask);
if (status != REDISMODULE_OK) {
RedisModule_CloseKey(key);
RedisModule_FreeString(ctx, state);
return RedisModule_ReplyWithError(
ctx, "Invalid test value for scheduling state");
}
@@ -1000,16 +991,14 @@ int TaskTableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx,
/* The current value does not match the test bitmask, so do not perform the
* update. */
RedisModule_CloseKey(key);
RedisModule_FreeString(ctx, state);
return RedisModule_ReplyWithNull(ctx);
}
/* The test passed, so perform the update. */
RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state, "node",
argv[4], NULL);
RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state,
"local_scheduler_id", argv[4], NULL);
/* Clean up. */
RedisModule_CloseKey(key);
RedisModule_FreeString(ctx, state);
/* Construct a reply by getting the task from the task ID. */
return ReplyWithTask(ctx, argv[1]);
}
@@ -1023,9 +1012,9 @@ int TaskTableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx,
*
* @param task_id A string of the task ID to look up.
* @return An array of strings representing the task fields in the following
* order: 1) (integer) scheduling state 2) (string) associated node ID,
* if any 3) (string) the task specification, which can be casted to a
* task_spec. If the task ID is not in the table, returns nil.
* order: 1) (integer) scheduling state 2) (string) associated local
* scheduler ID, if any 3) (string) the task specification, which can be
* cast to a task_spec. If the task ID is not in the table, returns nil.
*/
int TaskTableGet_RedisCommand(RedisModuleCtx *ctx,
RedisModuleString **argv,
@@ -1038,6 +1027,8 @@ int TaskTableGet_RedisCommand(RedisModuleCtx *ctx,
return ReplyWithTask(ctx, argv[1]);
}
extern "C" {
/* This function must be present on each Redis module. It is used in order to
* register the commands into the Redis server. */
int RedisModule_OnLoad(RedisModuleCtx *ctx,
@@ -1135,3 +1126,5 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx,
return REDISMODULE_OK;
}
} /* extern "C" */
+96 -204
View File
@@ -27,6 +27,10 @@ extern "C" {
#include "redis.h"
#include "io.h"
#include "format/common_generated.h"
#include "common_protocol.h"
#ifndef _WIN32
/* This function is actually not declared in standard POSIX, so declare it. */
extern int usleep(useconds_t usec);
@@ -358,27 +362,14 @@ Task *parse_and_construct_task_from_redis_reply(redisReply *reply) {
if (reply->type == REDIS_REPLY_NIL) {
/* There is no task in the reply, so return NULL. */
task = NULL;
} else if (reply->type == REDIS_REPLY_ARRAY) {
/* Check that the reply is as expected. The 0th element is the scheduling
* state. The 1st element is the db_client_id of the associated local
* scheduler, and the 2nd element is the TaskSpec. */
CHECK(reply->elements == 3);
CHECK(reply->element[0]->type == REDIS_REPLY_INTEGER);
CHECK(reply->element[1]->type == REDIS_REPLY_STRING);
CHECK(reply->element[2]->type == REDIS_REPLY_STRING);
/* Parse the scheduling state. */
long long state = reply->element[0]->integer;
/* Parse the local scheduler db_client_id. */
DBClientID local_scheduler_id;
CHECK(sizeof(local_scheduler_id) == reply->element[1]->len);
memcpy(local_scheduler_id.id, reply->element[1]->str,
reply->element[1]->len);
/* Parse the task spec. */
TaskSpec *spec = (TaskSpec *) malloc(reply->element[2]->len);
memcpy(spec, reply->element[2]->str, reply->element[2]->len);
task = Task_alloc(spec, reply->element[2]->len, state, local_scheduler_id);
/* Free the task spec. */
TaskSpec_free(spec);
} else if (reply->type == REDIS_REPLY_STRING) {
/* The reply is a flatbuffer TaskReply object. Parse it and construct the
* task. */
auto message = flatbuffers::GetRoot<TaskReply>(reply->str);
TaskSpec *spec = (TaskSpec *) message->task_spec()->data();
int64_t task_spec_size = message->task_spec()->size();
task = Task_alloc(spec, task_spec_size, message->state(),
from_flatbuf(message->local_scheduler_id()));
} else {
LOG_FATAL("Unexpected reply type %d", reply->type);
}
@@ -502,81 +493,6 @@ void redis_object_table_lookup_callback(redisAsyncContext *c,
}
}
/**
* Parse a payload string published on the object notification channel. The
* payload must have the format:
*
* <object id> <data size> MANAGERS <manager id1> <manager id2> ...
*
* where the fields are separated by single spaces, the object ID and the
* manager IDs are raw bytes, the data size is the raw bytes of a long long,
* and there must be at least one manager ID. The payload length must match
* this layout exactly; every deviation is rejected through CHECK/CHECKM.
*
* @param db The db handle, passed to redis_get_cached_db_client to look up
* each manager ID found in the payload.
* @param payload The payload string.
* @param length The length of the string.
* @param data_size This method will write the data size read from the payload
* at this address.
* @param manager_count This method will write the number of managers at this
* address.
* @param manager_vector This method will allocate an array of pointers to
* manager addresses and write the address of the array at this address.
* The caller is responsible for freeing this array.
* @return The object ID that the notification is about.
*/
ObjectID parse_subscribe_to_notifications_payload(
DBHandle *db,
char *payload,
int length,
int64_t *data_size,
int *manager_count,
const char ***manager_vector) {
long long data_size_value = 0;
/* Everything after the fixed-size header (object ID, space, data size,
* space, "MANAGERS") is a sequence of (space, manager ID) pairs, so the
* number of managers follows from the payload length. */
int num_managers = (length - sizeof(ObjectID) - 1 - sizeof(data_size_value) -
1 - strlen("MANAGERS")) /
(1 + sizeof(DBClientID));
/* Recompute the expected length from num_managers so that a payload whose
* length is not an exact fit for the format is detected below. */
int64_t rval = sizeof(ObjectID) + 1 + sizeof(data_size_value) + 1 +
strlen("MANAGERS") + num_managers * (1 + sizeof(DBClientID));
CHECKM(length == rval,
"length mismatch: num_managers = %d, length = %d, rval = %" PRId64,
num_managers, length, rval);
CHECK(num_managers > 0);
ObjectID obj_id;
/* Track our current offset in the payload. */
int offset = 0;
/* Parse the object ID. */
memcpy(&obj_id.id, &payload[offset], sizeof(obj_id.id));
offset += sizeof(obj_id.id);
/* The next part of the payload is a space. */
const char *space_str = " ";
CHECK(memcmp(&payload[offset], space_str, strlen(space_str)) == 0);
offset += strlen(space_str);
/* The next part of the payload is binary data_size. */
memcpy(&data_size_value, &payload[offset], sizeof(data_size_value));
offset += sizeof(data_size_value);
/* The next part of the payload is the string " MANAGERS" with leading ' '. */
const char *managers_str = " MANAGERS";
CHECK(memcmp(&payload[offset], managers_str, strlen(managers_str)) == 0);
offset += strlen(managers_str);
/* Parse the managers. */
const char **managers = (const char **) malloc(num_managers * sizeof(char *));
for (int i = 0; i < num_managers; ++i) {
/* First there is a space. */
CHECK(memcmp(&payload[offset], " ", strlen(" ")) == 0);
offset += strlen(" ");
/* Get the manager ID. */
DBClientID manager_id;
memcpy(&manager_id.id, &payload[offset], sizeof(manager_id.id));
offset += sizeof(manager_id.id);
/* Write the address of the corresponding manager to the returned array. */
redis_get_cached_db_client(db, manager_id, &managers[i]);
}
/* The whole payload must have been consumed by the parsing above. */
CHECK(offset == length);
/* Return the manager array and the object ID. */
*manager_count = num_managers;
*manager_vector = managers;
*data_size = data_size_value;
return obj_id;
}
void object_table_redis_subscribe_to_notifications_callback(
redisAsyncContext *c,
void *r,
@@ -603,13 +519,22 @@ void object_table_redis_subscribe_to_notifications_callback(
message_type->str);
if (strcmp(message_type->str, "message") == 0) {
/* Handle an object notification. */
int64_t data_size = 0;
int manager_count;
const char **manager_vector;
ObjectID obj_id = parse_subscribe_to_notifications_payload(
db, reply->element[2]->str, reply->element[2]->len, &data_size,
&manager_count, &manager_vector);
/* We received an object notification. Parse the payload. */
auto message = flatbuffers::GetRoot<SubscribeToNotificationsReply>(
reply->element[2]->str);
/* Extract the object ID. */
ObjectID obj_id = from_flatbuf(message->object_id());
/* Extract the data size. */
int64_t data_size = message->object_size();
int manager_count = message->manager_ids()->size();
/* Construct the manager vector from the flatbuffers object. */
const char **manager_vector =
(const char **) malloc(manager_count * sizeof(char *));
for (int i = 0; i < manager_count; ++i) {
DBClientID manager_id = from_flatbuf(message->manager_ids()->Get(i));
redis_get_cached_db_client(db, manager_id, &manager_vector[i]);
}
/* Call the subscribe callback. */
ObjectTableSubscribeData *data =
(ObjectTableSubscribeData *) callback_data->data;
@@ -641,7 +566,7 @@ void redis_object_table_subscribe_to_notifications(
TableCallbackData *callback_data) {
DBHandle *db = callback_data->db_handle;
/* The object channel prefix must match the value defined in
* src/common/redismodule/ray_redis_module.c. */
* src/common/redismodule/ray_redis_module.cc. */
const char *object_channel_prefix = "OC:";
const char *object_channel_bcast = "BCAST";
int status = REDIS_OK;
@@ -869,47 +794,6 @@ void redis_task_table_test_and_update(TableCallbackData *callback_data) {
}
}
/* Parse a payload published on a task table channel into its fields.
*
* The format of the payload is described in ray_redis_module.c and is
* "<task ID> <state> <local scheduler ID> <task specification>". TODO(rkn):
* Make this code nicer.
*
* The task ID and local scheduler ID are copied as raw bytes, the state is a
* decimal integer occupying exactly two characters, and the task
* specification is whatever remains after the last space.
*
* @param payload The payload string.
* @param length The length of the payload in bytes.
* @param task_id The task ID is written at this address.
* @param state The parsed scheduling state is written at this address.
* @param local_scheduler_id The local scheduler ID is written at this address.
* @param spec A newly malloc'd copy of the task specification is written at
* this address. The caller owns this buffer (the caller in this file
* releases it with TaskSpec_free).
* @param task_spec_size The size in bytes of the task specification is written
* at this address.
*
* NOTE(review): nothing here checks that length covers the fixed-size part of
* the payload; a shorter payload would not yield a valid
* task_spec_payload_size. Confirm that publishers always send the full
* format. */
void parse_task_table_subscribe_callback(char *payload,
int length,
TaskID *task_id,
int *state,
DBClientID *local_scheduler_id,
TaskSpec **spec,
int64_t *task_spec_size) {
/* Note that the state is padded with spaces to consist of precisely two
* characters. */
/* The task spec is what is left after the task ID, the two-character state,
* the local scheduler ID, and the three separating spaces. */
int task_spec_payload_size =
length - sizeof(*task_id) - 1 - 2 - 1 - sizeof(*local_scheduler_id) - 1;
int offset = 0;
/* Read in the task ID. */
memcpy(task_id, &payload[offset], sizeof(*task_id));
offset += sizeof(*task_id);
/* Read in a space. */
const char *space_str = (const char *) " ";
CHECK(memcmp(space_str, &payload[offset], strlen(space_str)) == 0);
offset += strlen(space_str);
/* Read in the state, which is an integer left-padded with spaces to two
* characters. */
CHECK(sscanf(&payload[offset], "%2d", state) == 1);
offset += 2;
/* Read in a space. */
CHECK(memcmp(space_str, &payload[offset], strlen(space_str)) == 0);
offset += strlen(space_str);
/* Read in the local scheduler ID. */
memcpy(local_scheduler_id, &payload[offset], sizeof(*local_scheduler_id));
offset += sizeof(*local_scheduler_id);
/* Read in a space. */
CHECK(memcmp(space_str, &payload[offset], strlen(space_str)) == 0);
offset += strlen(space_str);
/* Read in the task spec. */
*spec = (TaskSpec *) malloc(task_spec_payload_size);
memcpy(*spec, &payload[offset], task_spec_payload_size);
*task_spec_size = task_spec_payload_size;
}
void redis_task_table_subscribe_callback(redisAsyncContext *c,
void *r,
void *privdata) {
@@ -917,7 +801,7 @@ void redis_task_table_subscribe_callback(redisAsyncContext *c,
redisReply *reply = (redisReply *) r;
CHECK(reply->type == REDIS_REPLY_ARRAY);
/* The number of elements is 3 for a reply to SUBSCRIBE, and 4 for a reply to
/* The number of elements is 3 for a reply to SUBSCRIBE, and 4 for a reply to
* PSUBSCRIBE. */
CHECKM(reply->elements == 3 || reply->elements == 4, "reply->elements is %zu",
reply->elements);
@@ -929,20 +813,22 @@ void redis_task_table_subscribe_callback(redisAsyncContext *c,
if (strcmp(message_type->str, "message") == 0 ||
strcmp(message_type->str, "pmessage") == 0) {
/* Handle a task table event. Parse the payload and call the callback. */
auto message = flatbuffers::GetRoot<TaskReply>(payload->str);
/* Extract the task ID. */
TaskID task_id = from_flatbuf(message->task_id());
/* Extract the scheduling state. */
int64_t state = message->state();
/* Extract the local scheduler ID. */
DBClientID local_scheduler_id = from_flatbuf(message->local_scheduler_id());
/* Extract the task spec. */
TaskSpec *spec = (TaskSpec *) message->task_spec()->data();
int64_t task_spec_size = message->task_spec()->size();
/* Create a task. */
Task *task = Task_alloc(spec, task_spec_size, state, local_scheduler_id);
/* Call the subscribe callback if there is one. */
TaskTableSubscribeData *data =
(TaskTableSubscribeData *) callback_data->data;
/* Read out the information from the payload. */
TaskID task_id;
int state;
DBClientID local_scheduler_id;
TaskSpec *spec;
int64_t task_spec_size;
parse_task_table_subscribe_callback(payload->str, payload->len, &task_id,
&state, &local_scheduler_id, &spec,
&task_spec_size);
Task *task = Task_alloc(spec, task_spec_size, state, local_scheduler_id);
TaskSpec_free(spec);
/* Call the subscribe callback if there is one. */
if (data->subscribe_callback != NULL) {
data->subscribe_callback(task, data->subscribe_context);
}
@@ -969,24 +855,24 @@ void redis_task_table_subscribe_callback(redisAsyncContext *c,
void redis_task_table_subscribe(TableCallbackData *callback_data) {
DBHandle *db = callback_data->db_handle;
TaskTableSubscribeData *data = (TaskTableSubscribeData *) callback_data->data;
/* TASK_CHANNEL_PREFIX is defined in ray_redis_module.c and must be kept in
/* TASK_CHANNEL_PREFIX is defined in ray_redis_module.cc and must be kept in
* sync with that file. */
const char *TASK_CHANNEL_PREFIX = "TT:";
int status;
if (IS_NIL_ID(data->local_scheduler_id)) {
/* TODO(swang): Implement the state_filter by translating the bitmask into
* a Redis key-matching pattern. */
status = redisAsyncCommand(
db->sub_context, redis_task_table_subscribe_callback,
(void *) callback_data->timer_id, "PSUBSCRIBE %s*:%2d",
TASK_CHANNEL_PREFIX, data->state_filter);
status =
redisAsyncCommand(db->sub_context, redis_task_table_subscribe_callback,
(void *) callback_data->timer_id, "PSUBSCRIBE %s*:%d",
TASK_CHANNEL_PREFIX, data->state_filter);
} else {
DBClientID local_scheduler_id = data->local_scheduler_id;
status = redisAsyncCommand(
db->sub_context, redis_task_table_subscribe_callback,
(void *) callback_data->timer_id, "SUBSCRIBE %s%b:%2d",
TASK_CHANNEL_PREFIX, (char *) local_scheduler_id.id,
sizeof(local_scheduler_id.id), data->state_filter);
status =
redisAsyncCommand(db->sub_context, redis_task_table_subscribe_callback,
(void *) callback_data->timer_id, "SUBSCRIBE %s%b:%d",
TASK_CHANNEL_PREFIX, (char *) local_scheduler_id.id,
sizeof(local_scheduler_id.id), data->state_filter);
}
if ((status == REDIS_ERR) || db->sub_context->err) {
LOG_REDIS_DEBUG(db->sub_context, "error in redis_task_table_subscribe");
@@ -1052,40 +938,23 @@ void redis_db_client_table_subscribe_callback(redisAsyncContext *c,
return;
}
/* Otherwise, parse the payload and call the callback. */
DBClientTableSubscribeData *data =
(DBClientTableSubscribeData *) callback_data->data;
DBClientID client;
memcpy(client.id, payload->str, sizeof(client.id));
/* We subtract 1 + sizeof(client.id) to compute the length of the
* client_type string, and we add 1 to null-terminate the string. */
int client_type_length = payload->len - 1 - sizeof(client.id) + 1;
CHECK(client_type_length > 0);
auto message =
flatbuffers::GetRoot<SubscribeToDBClientTableReply>(payload->str);
DBClientID client = from_flatbuf(message->db_client_id());
/* Parse the client type and auxiliary address from the response. If there is
* only client type, then the update was a delete. */
char *client_type = (char *) malloc(client_type_length);
char *aux_address = (char *) malloc(client_type_length);
int is_insertion;
memset(aux_address, 0, client_type_length);
/* Published message format: <client_id:client_type aux_addr> */
int rv = sscanf(&payload->str[1 + sizeof(client.id)], "%s %s %d", client_type,
aux_address, &is_insertion);
CHECKM(rv == 3,
"redis_db_client_table_subscribe_callback: expected 2 parsed args, "
"Got %d instead.",
rv);
CHECKM(is_insertion == 1 || is_insertion == 0,
"redis_db_client_table_subscribe_callback: expected 0 or 1 for "
"insertion field, got %d instead.",
is_insertion);
char *client_type = (char *) message->client_type()->data();
char *aux_address = (char *) message->aux_address()->data();
bool is_insertion = message->is_insertion();
/* Call the subscription callback. */
DBClientTableSubscribeData *data =
(DBClientTableSubscribeData *) callback_data->data;
if (data->subscribe_callback) {
data->subscribe_callback(client, client_type, aux_address,
(bool) is_insertion, data->subscribe_context);
data->subscribe_callback(client, client_type, aux_address, is_insertion,
data->subscribe_context);
}
free(client_type);
free(aux_address);
}
void redis_db_client_table_subscribe(TableCallbackData *callback_data) {
@@ -1114,15 +983,26 @@ void redis_local_scheduler_table_subscribe_callback(redisAsyncContext *c,
if (strcmp(message_type->str, "message") == 0) {
/* Handle a local scheduler heartbeat. Parse the payload and call the
* subscribe callback. */
redisReply *payload = reply->element[2];
auto message =
flatbuffers::GetRoot<LocalSchedulerInfoMessage>(reply->element[2]->str);
/* Extract the client ID. */
DBClientID client_id = from_flatbuf(message->db_client_id());
/* Extract the fields of the local scheduler info struct. */
LocalSchedulerInfo info;
info.total_num_workers = message->total_num_workers();
info.task_queue_length = message->task_queue_length();
info.available_workers = message->available_workers();
for (int i = 0; i < ResourceIndex_MAX; ++i) {
info.static_resources[i] = message->static_resources()->Get(i);
}
for (int i = 0; i < ResourceIndex_MAX; ++i) {
info.dynamic_resources[i] = message->dynamic_resources()->Get(i);
}
/* Call the subscribe callback. */
LocalSchedulerTableSubscribeData *data =
(LocalSchedulerTableSubscribeData *) callback_data->data;
DBClientID client_id;
LocalSchedulerInfo info;
/* The payload should be the concatenation of these two structs. */
CHECK(sizeof(client_id) + sizeof(info) == payload->len);
memcpy(&client_id, payload->str, sizeof(client_id));
memcpy(&info, payload->str + sizeof(client_id), sizeof(info));
if (data->subscribe_callback) {
data->subscribe_callback(client_id, info, data->subscribe_context);
}
@@ -1167,10 +1047,22 @@ void redis_local_scheduler_table_send_info(TableCallbackData *callback_data) {
DBHandle *db = callback_data->db_handle;
LocalSchedulerTableSendInfoData *data =
(LocalSchedulerTableSendInfoData *) callback_data->data;
/* Create a flatbuffer object to serialize and publish. */
flatbuffers::FlatBufferBuilder fbb;
/* Create the flatbuffers message. */
LocalSchedulerInfo info = data->info;
auto message = CreateLocalSchedulerInfoMessage(
fbb, to_flatbuf(fbb, db->client), info.total_num_workers,
info.task_queue_length, info.available_workers,
fbb.CreateVector(info.static_resources, ResourceIndex_MAX),
fbb.CreateVector(info.dynamic_resources, ResourceIndex_MAX));
fbb.Finish(message);
int status = redisAsyncCommand(
db->context, redis_local_scheduler_table_send_info_callback,
(void *) callback_data->timer_id, "PUBLISH local_schedulers %b%b",
db->client.id, sizeof(db->client.id), &data->info, sizeof(data->info));
(void *) callback_data->timer_id, "PUBLISH local_schedulers %b",
fbb.GetBufferPointer(), fbb.GetSize());
if ((status == REDIS_ERR) || db->context->err) {
LOG_REDIS_DEBUG(db->context,
"error in redis_local_scheduler_table_send_info");
+10 -3
View File
@@ -8,6 +8,9 @@ import numpy as np
import time
import redis
# Import flatbuffer bindings.
from ray.core.generated.TaskReply import TaskReply
class TaskTests(unittest.TestCase):
def testSubmittingTasks(self):
@@ -164,9 +167,13 @@ class ReconstructionTests(unittest.TestCase):
r = redis.StrictRedis(port=self.redis_port)
task_ids = r.keys("TT:*")
task_ids = [task_id[3:] for task_id in task_ids]
node_ids = [r.execute_command("ray.task_table_get", task_id)[1] for task_id
in task_ids]
self.assertEqual(len(set(node_ids)), self.num_local_schedulers)
local_scheduler_ids = []
for task_id in task_ids:
message = r.execute_command("ray.task_table_get", task_id)
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
local_scheduler_ids.append(task_reply_object.LocalSchedulerId())
self.assertEqual(len(set(local_scheduler_ids)), self.num_local_schedulers)
# Clean up the Ray cluster.
ray.worker.cleanup()
+1 -1
View File
@@ -19,7 +19,7 @@ loop = asyncio.get_event_loop()
IDENTIFIER_LENGTH = 20
# This prefix must match the value defined in ray_redis_module.c.
# This prefix must match the value defined in ray_redis_module.cc.
DB_CLIENT_PREFIX = b"CL:"
def hex_identifier(identifier):