From 53dffe0bf23fe93bc4f00e9a68c0765d5a95fef5 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Fri, 10 Mar 2017 18:35:25 -0800 Subject: [PATCH] Use flatbuffers for some messages from Redis. (#341) * Compile the Ray redis module with C++. * Redo parsing of object table notifications with flatbuffers. * Update redis module python tests. * Redo parsing of task table notifications with flatbuffers. * Fix linting. * Redo parsing of db client notifications with flatbuffers. * Redo publishing of local scheduler heartbeats with flatbuffers. * Fix linting. * Remove usage of fixed-width formatting of scheduling state in channel name. * Reply with flatbuffer object to task table queries, also simplify redis string to flatbuffer string conversion. * Fix linting and tests. * fix * cleanup * simplify logic in ReplyWithTask --- .travis/install-dependencies.sh | 8 +- doc/source/install-on-macosx.md | 2 +- doc/source/install-on-ubuntu.md | 2 +- docker/base-deps/Dockerfile | 1 + python/ray/common/redis_module/runtest.py | 86 +++-- python/ray/core/generated/__init__.py | 0 python/ray/global_scheduler/test/test.py | 2 +- python/ray/monitor.py | 19 +- python/ray/worker.py | 2 +- scripts/start_ray.py | 2 +- src/common/CMakeLists.txt | 9 + src/common/format/common.fbs | 59 ++++ src/common/redis_module/CMakeLists.txt | 6 +- ...ray_redis_module.c => ray_redis_module.cc} | 251 +++++++-------- src/common/state/redis.cc | 300 ++++++------------ test/stress_tests.py | 13 +- webui/backend/ray_ui.py | 2 +- 17 files changed, 379 insertions(+), 385 deletions(-) create mode 100644 python/ray/core/generated/__init__.py rename src/common/redis_module/{ray_redis_module.c => ray_redis_module.cc} (86%) diff --git a/.travis/install-dependencies.sh b/.travis/install-dependencies.sh index c0de83c73..83077febe 100755 --- a/.travis/install-dependencies.sh +++ b/.travis/install-dependencies.sh @@ -20,7 +20,7 @@ fi if [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "linux" ]]; then sudo apt-get update sudo 
apt-get install -y cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip - sudo pip install cloudpickle funcsigs colorama psutil redis tensorflow + sudo pip install cloudpickle funcsigs colorama psutil redis tensorflow flatbuffers elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then sudo apt-get update sudo apt-get install -y cmake python-dev python-numpy build-essential autoconf curl libtool libboost-all-dev unzip @@ -28,7 +28,7 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh bash miniconda.sh -b -p $HOME/miniconda export PATH="$HOME/miniconda/bin:$PATH" - pip install numpy cloudpickle funcsigs colorama psutil redis tensorflow + pip install numpy cloudpickle funcsigs colorama psutil redis tensorflow flatbuffers elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then # check that brew is installed which -s brew @@ -41,7 +41,7 @@ elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then fi brew install cmake automake autoconf libtool boost sudo easy_install pip - sudo pip install numpy cloudpickle funcsigs colorama psutil redis tensorflow --ignore-installed six + sudo pip install numpy cloudpickle funcsigs colorama psutil redis tensorflow flatbuffers --ignore-installed six elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then # check that brew is installed which -s brew @@ -57,7 +57,7 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then wget https://repo.continuum.io/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -O miniconda.sh bash miniconda.sh -b -p $HOME/miniconda export PATH="$HOME/miniconda/bin:$PATH" - pip install numpy cloudpickle funcsigs colorama psutil redis tensorflow + pip install numpy cloudpickle funcsigs colorama psutil redis tensorflow flatbuffers elif [[ "$LINT" == "1" ]]; then sudo apt-get update sudo apt-get install 
-y cmake build-essential autoconf curl libtool libboost-all-dev unzip diff --git a/doc/source/install-on-macosx.md b/doc/source/install-on-macosx.md index 235abac5a..8cf85f653 100644 --- a/doc/source/install-on-macosx.md +++ b/doc/source/install-on-macosx.md @@ -11,7 +11,7 @@ To install Ray, first install the following dependencies. We recommend using brew update brew install cmake automake autoconf libtool boost wget -pip install numpy cloudpickle funcsigs colorama psutil redis --ignore-installed six +pip install numpy cloudpickle funcsigs colorama psutil redis flatbuffers --ignore-installed six ``` If you are using Anaconda, you may also need to run the following. diff --git a/doc/source/install-on-ubuntu.md b/doc/source/install-on-ubuntu.md index c94fcdb02..31ac80296 100644 --- a/doc/source/install-on-ubuntu.md +++ b/doc/source/install-on-ubuntu.md @@ -12,7 +12,7 @@ To install Ray, first install the following dependencies. We recommend using sudo apt-get update sudo apt-get install -y cmake build-essential autoconf curl libtool libboost-all-dev unzip python-dev python-pip # If you're using Anaconda, then python-dev and python-pip are unnecessary. -pip install numpy cloudpickle funcsigs colorama psutil redis +pip install numpy cloudpickle funcsigs colorama psutil redis flatbuffers ``` If you are using Anaconda, you may also need to run the following. 
diff --git a/docker/base-deps/Dockerfile b/docker/base-deps/Dockerfile index a0b91c379..5dabbfa44 100644 --- a/docker/base-deps/Dockerfile +++ b/docker/base-deps/Dockerfile @@ -11,4 +11,5 @@ RUN echo 'export PATH=/opt/conda/bin:$PATH' > /etc/profile.d/conda.sh \ && rm /tmp/anaconda.sh ENV PATH "/opt/conda/bin:$PATH" RUN conda install -y libgcc +RUN pip install flatbuffers RUN pip install --upgrade pip cloudpickle diff --git a/python/ray/common/redis_module/runtest.py b/python/ray/common/redis_module/runtest.py index b45a4c2df..be1d35ed7 100644 --- a/python/ray/common/redis_module/runtest.py +++ b/python/ray/common/redis_module/runtest.py @@ -11,6 +11,10 @@ import unittest import redis import ray.services +# Import flatbuffer bindings. +from ray.core.generated.SubscribeToNotificationsReply import SubscribeToNotificationsReply +from ray.core.generated.TaskReply import TaskReply + OBJECT_INFO_PREFIX = "OI:" OBJECT_LOCATION_PREFIX = "OL:" OBJECT_SUBSCRIBE_PREFIX = "OS:" @@ -142,6 +146,15 @@ class TestGlobalStateStore(unittest.TestCase): self.assertEqual(set(response), set()) def testObjectTableSubscribeToNotifications(self): + # Define a helper method for checking the contents of object notifications. + def check_object_notification(notification_message, object_id, object_size, manager_ids): + notification_object = SubscribeToNotificationsReply.GetRootAsSubscribeToNotificationsReply(notification_message, 0) + self.assertEqual(notification_object.ObjectId(), object_id) + self.assertEqual(notification_object.ObjectSize(), object_size) + self.assertEqual(notification_object.ManagerIdsLength(), len(manager_ids)) + for i in range(len(manager_ids)): + self.assertEqual(notification_object.ManagerIds(i), manager_ids[i]) + data_size = 0xf1f0 p = self.redis.pubsub() # Subscribe to an object ID. @@ -151,8 +164,12 @@ class TestGlobalStateStore(unittest.TestCase): self.assertEqual(get_next_message(p)["data"], 1) # Request a notification and receive the data. 
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id1") - self.assertEqual(get_next_message(p)["data"], b"object_id1 %s MANAGERS manager_id2"\ - %integerToAsciiHex(data_size, 8)) + # Verify that the notification is correct. + check_object_notification(get_next_message(p)["data"], + b"object_id1", + data_size, + [b"manager_id2"]) + # Request a notification for an object that isn't there. Then add the object # and receive the data. Only the first call to RAY.OBJECT_TABLE_ADD should # trigger notifications. @@ -160,15 +177,24 @@ class TestGlobalStateStore(unittest.TestCase): self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id1") self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id2") self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size, "hash1", "manager_id3") - self.assertEqual(get_next_message(p)["data"], b"object_id3 %s MANAGERS manager_id1"\ - %integerToAsciiHex(data_size, 8)) + # Verify that the notification is correct. + check_object_notification(get_next_message(p)["data"], + b"object_id3", + data_size, + [b"manager_id1"]) self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", data_size, "hash1", "manager_id3") - self.assertEqual(get_next_message(p)["data"], b"object_id2 %s MANAGERS manager_id3"\ - %integerToAsciiHex(data_size, 8)) + # Verify that the notification is correct. + check_object_notification(get_next_message(p)["data"], + b"object_id2", + data_size, + [b"manager_id3"]) # Request notifications for object_id3 again. self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS", "manager_id1", "object_id3") - self.assertEqual(get_next_message(p)["data"], b"object_id3 %s MANAGERS manager_id1 manager_id2 manager_id3"\ - %integerToAsciiHex(data_size, 8)) + # Verify that the notification is correct. 
+ check_object_notification(get_next_message(p)["data"], + b"object_id3", + data_size, + [b"manager_id1", b"manager_id2", b"manager_id3"]) def testResultTableAddAndLookup(self): # Try looking up something in the result table before anything is added. @@ -205,10 +231,6 @@ class TestGlobalStateStore(unittest.TestCase): # Non-integer scheduling states should not be added. self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", "invalid_state", "node_id", "task_spec") - with self.assertRaises(redis.ResponseError): - # Scheduling states with invalid width should not be added. - self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", 101, - "node_id", "task_spec") with self.assertRaises(redis.ResponseError): # Should not be able to update a non-existent task. self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", 10, @@ -219,17 +241,24 @@ class TestGlobalStateStore(unittest.TestCase): TASK_STATUS_SCHEDULED = 2 TASK_STATUS_QUEUED = 4 + def check_task_reply(message, task_args): + task_status, local_scheduler_id, task_spec = task_args + task_reply_object = TaskReply.GetRootAsTaskReply(message, 0) + self.assertEqual(task_reply_object.State(), task_status) + self.assertEqual(task_reply_object.LocalSchedulerId(), local_scheduler_id) + self.assertEqual(task_reply_object.TaskSpec(), task_spec) + # Check that task table adds, updates, and lookups work correctly. 
task_args = [TASK_STATUS_WAITING, b"node_id", b"task_spec"] response = self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", *task_args) response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") - self.assertEqual(response, task_args) + check_task_reply(response, task_args) task_args[0] = TASK_STATUS_SCHEDULED self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", *task_args[:2]) response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") - self.assertEqual(response, task_args) + check_task_reply(response, task_args) # If the current value, test value, and set value are all the same, the # update happens, and the response is still the same task. @@ -237,10 +266,10 @@ class TestGlobalStateStore(unittest.TestCase): response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", "task_id", *task_args[:3]) - self.assertEqual(response, task_args[1:]) + check_task_reply(response, task_args[1:]) # Check that the task entry is still the same. get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") - self.assertEqual(get_response, task_args[1:]) + check_task_reply(get_response, task_args[1:]) # If the current value is the same as the test value, and the set value is # different, the update happens, and the response is the entire task. @@ -248,10 +277,10 @@ class TestGlobalStateStore(unittest.TestCase): response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", "task_id", *task_args[:3]) - self.assertEqual(response, task_args[1:]) + check_task_reply(response, task_args[1:]) # Check that the update happened. get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") - self.assertEqual(get_response, task_args[1:]) + check_task_reply(get_response, task_args[1:]) # If the current value is no longer the same as the test value, the # response is nil. 
@@ -271,7 +300,7 @@ class TestGlobalStateStore(unittest.TestCase): response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", "task_id", *task_args[:3]) - self.assertEqual(response, task_args[1:]) + check_task_reply(response, task_args[1:]) # If the test value is a bitmask that does not match the current value, the # update does not happen. @@ -288,13 +317,13 @@ class TestGlobalStateStore(unittest.TestCase): def testTaskTableSubscribe(self): scheduling_state = 1 - node_id = "node_id" + local_scheduler_id = "local_scheduler_id" # Subscribe to the task table. p = self.redis.pubsub() p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX)) - p.psubscribe("{prefix}*:{state: >2}".format(prefix=TASK_PREFIX, state=scheduling_state)) - p.psubscribe("{prefix}{node}:*".format(prefix=TASK_PREFIX, node=node_id)) - task_args = [b"task_id", scheduling_state, node_id.encode("ascii"), b"task_spec"] + p.psubscribe("{prefix}*:{state}".format(prefix=TASK_PREFIX, state=scheduling_state)) + p.psubscribe("{prefix}{local_scheduler_id}:*".format(prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id)) + task_args = [b"task_id", scheduling_state, local_scheduler_id.encode("ascii"), b"task_spec"] self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args) # Receive the acknowledgement message. self.assertEqual(get_next_message(p)["data"], 1) @@ -302,10 +331,13 @@ class TestGlobalStateStore(unittest.TestCase): self.assertEqual(get_next_message(p)["data"], 3) # Receive the actual data. for i in range(3): - message = get_next_message(p)["data"] - message = message.split() - message[1] = int(message[1]) - self.assertEqual(message, task_args) + message = get_next_message(p)["data"] + # Check that the notification object is correct. 
+ notification_object = TaskReply.GetRootAsTaskReply(message, 0) + self.assertEqual(notification_object.TaskId(), b"task_id") + self.assertEqual(notification_object.State(), scheduling_state) + self.assertEqual(notification_object.LocalSchedulerId(), local_scheduler_id.encode("ascii")) + self.assertEqual(notification_object.TaskSpec(), b"task_spec") if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/python/ray/core/generated/__init__.py b/python/ray/core/generated/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/ray/global_scheduler/test/test.py b/python/ray/global_scheduler/test/test.py index 9be25d6ae..cdf6f651d 100644 --- a/python/ray/global_scheduler/test/test.py +++ b/python/ray/global_scheduler/test/test.py @@ -34,7 +34,7 @@ TASK_STATUS_QUEUED = 4 TASK_STATUS_RUNNING = 8 TASK_STATUS_DONE = 16 -# These constants are an implementation detail of ray_redis_module.c, so this +# These constants are an implementation detail of ray_redis_module.cc, so this # must be kept in sync with that file. DB_CLIENT_PREFIX = "CL:" TASK_PREFIX = "TT:" diff --git a/python/ray/monitor.py b/python/ray/monitor.py index b52d612a0..0355e40b6 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -12,13 +12,16 @@ import time from ray.services import get_ip_address from ray.services import get_port +# Import flatbuffer bindings. +from ray.core.generated.SubscribeToDBClientTableReply import SubscribeToDBClientTableReply + # These variables must be kept in sync with the C codebase. # common/common.h DB_CLIENT_ID_SIZE = 20 NIL_ID = b"\xff" * DB_CLIENT_ID_SIZE # common/task.h TASK_STATUS_LOST = 32 -# common/redis_module/ray_redis_module.c +# common/redis_module/ray_redis_module.cc TASK_PREFIX = "TT:" DB_CLIENT_PREFIX = "CL:" DB_CLIENT_TABLE_NAME = b"db_clients" @@ -89,14 +92,12 @@ class Monitor(object): # Parse the message. 
data = message["data"] - db_client_id = data[:DB_CLIENT_ID_SIZE] - data = data[DB_CLIENT_ID_SIZE + 1:] - data = data.split(b" ") - client_type, auxiliary_address, is_insertion = data - is_insertion = int(is_insertion) - if is_insertion != 1 and is_insertion != 0: - raise Exception("Expected 0 or 1 for insertion field, got {} instead".format(is_insertion)) - is_insertion = bool(is_insertion) + + notification_object = SubscribeToDBClientTableReply.GetRootAsSubscribeToDBClientTableReply(data, 0) + db_client_id = notification_object.DbClientId() + client_type = notification_object.ClientType() + auxiliary_address = notification_object.AuxAddress() + is_insertion = notification_object.IsInsertion() return db_client_id, client_type, auxiliary_address, is_insertion diff --git a/python/ray/worker.py b/python/ray/worker.py index 56257fea7..379eb3510 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -723,7 +723,7 @@ def get_address_info_from_redis_helper(redis_address, node_ip_address): # must have run "CONFIG SET protected-mode no". redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port)) # The client table prefix must be kept in sync with the file - # "src/common/redis_module/ray_redis_module.c" where it is defined. + # "src/common/redis_module/ray_redis_module.cc" where it is defined. REDIS_CLIENT_TABLE_PREFIX = "CL:" client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX)) # Filter to clients on the same node and do some basic checking. 
diff --git a/scripts/start_ray.py b/scripts/start_ray.py index e8754a5f7..d49382aa4 100644 --- a/scripts/start_ray.py +++ b/scripts/start_ray.py @@ -21,7 +21,7 @@ def check_no_existing_redis_clients(node_ip_address, redis_address): redis_ip_address, redis_port = redis_address.split(":") redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port)) # The client table prefix must be kept in sync with the file - # "src/common/redis_module/ray_redis_module.c" where it is defined. + # "src/common/redis_module/ray_redis_module.cc" where it is defined. REDIS_CLIENT_TABLE_PREFIX = "CL:" client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX)) # Filter to clients on the same node and do some basic checking. diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index f6e7b629f..209aff1ea 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -24,6 +24,15 @@ add_custom_command( COMMENT "Running flatc compiler on ${COMMON_FBS_SRC}" VERBATIM) +# Generate Python bindings for the flatbuffers objects. +set(PYTHON_OUTPUT_DIR ${CMAKE_BINARY_DIR}/generated/) +add_custom_command( + TARGET gen_common_fbs + COMMAND ${FLATBUFFERS_COMPILER} -p -o ${PYTHON_OUTPUT_DIR} ${COMMON_FBS_SRC} + DEPENDS ${FBS_DEPENDS} + COMMENT "Running flatc compiler on ${COMMON_FBS_SRC}" + VERBATIM) + add_dependencies(gen_common_fbs flatbuffers_ep) add_custom_target( diff --git a/src/common/format/common.fbs b/src/common/format/common.fbs index f94ae879e..762b97f77 100644 --- a/src/common/format/common.fbs +++ b/src/common/format/common.fbs @@ -50,3 +50,62 @@ table TaskInfo { } root_type TaskInfo; + +table SubscribeToNotificationsReply { + // The object ID of the object that the notification is about. + object_id: string; + // The size of the object. + object_size: long; + // The IDs of the managers that contain this object. 
+ manager_ids: [string]; +} + +root_type SubscribeToNotificationsReply; + +table TaskReply { + // The task ID of the task that the message is about. + task_id: string; + // The state of the task. This is encoded as a bit mask of scheduling_state + // enum values in task.h. + state: long; + // A local scheduler ID. + local_scheduler_id: string; + // A string of bytes representing the task specification. + task_spec: string; +} + +root_type TaskReply; + +table SubscribeToDBClientTableReply { + // The db client ID of the client that the message is about. + db_client_id: string; + // The type of the client. + client_type: string; + // If the client is a local scheduler, this is the address of the plasma + // manager that the local scheduler is connected to. Otherwise, it is empty. + aux_address: string; + // True if the message is about the addition of a client and false if it is + // about the deletion of a client. + is_insertion: bool; +} + +root_type SubscribeToDBClientTableReply; + +table LocalSchedulerInfoMessage { + // The db client ID of the client that the message is about. + db_client_id: string; + // The total number of workers that are connected to this local scheduler. + total_num_workers: long; + // The number of tasks queued in this local scheduler. + task_queue_length: long; + // The number of workers that are available and waiting for tasks. + available_workers: long; + // The resource vector of resources generally available to this local + // scheduler. + static_resources: [double]; + // The resource vector of resources currently available to this local + // scheduler. 
+ dynamic_resources: [double]; +} + +root_type LocalSchedulerInfoMessage; diff --git a/src/common/redis_module/CMakeLists.txt b/src/common/redis_module/CMakeLists.txt index 2b9fe7c33..afea36d7b 100644 --- a/src/common/redis_module/CMakeLists.txt +++ b/src/common/redis_module/CMakeLists.txt @@ -3,15 +3,15 @@ cmake_minimum_required(VERSION 2.8) project(ray_redis_module) if(APPLE) - set(REDIS_MODULE_CFLAGS -W -Wall -dynamic -fno-common -g -ggdb -std=c99 -O2) + set(REDIS_MODULE_CFLAGS -W -Wall -dynamic -fno-common -g -ggdb -std=c++11 -O2) set(REDIS_MODULE_LDFLAGS "-undefined dynamic_lookup") set(CMAKE_SHARED_LIBRARY_SUFFIX ".so") else() - set(REDIS_MODULE_CFLAGS -W -Wall -fno-common -g -ggdb -std=c99 -O2) + set(REDIS_MODULE_CFLAGS -W -Wall -fno-common -g -ggdb -std=c++11 -O2) set(REDIS_MODULE_LDFLAGS -shared) endif() -add_library(ray_redis_module SHARED ray_redis_module.c) +add_library(ray_redis_module SHARED ray_redis_module.cc) target_compile_options(ray_redis_module PUBLIC ${REDIS_MODULE_CFLAGS} -fPIC) target_link_libraries(ray_redis_module ${REDIS_MODULE_LDFLAGS}) diff --git a/src/common/redis_module/ray_redis_module.c b/src/common/redis_module/ray_redis_module.cc similarity index 86% rename from src/common/redis_module/ray_redis_module.c rename to src/common/redis_module/ray_redis_module.cc index 9c7e0b654..f83b2c64b 100644 --- a/src/common/redis_module/ray_redis_module.c +++ b/src/common/redis_module/ray_redis_module.cc @@ -5,6 +5,10 @@ #include "redis_string.h" +#include "format/common_generated.h" + +#include "common_protocol.h" + /** * Various tables are maintained in redis: * @@ -42,11 +46,29 @@ RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx, int mode) { RedisModuleString *prefixed_keyname = RedisString_Format(ctx, "%s%S", prefix, keyname); - RedisModuleKey *key = RedisModule_OpenKey(ctx, prefixed_keyname, mode); + RedisModuleKey *key = + (RedisModuleKey *) RedisModule_OpenKey(ctx, prefixed_keyname, mode); RedisModule_FreeString(ctx, prefixed_keyname); 
return key; } +/** + * This is a helper method to convert a redis module string to a flatbuffer + * string. + * + * @param fbb The flatbuffer builder. + * @param redis_string The redis string. + * @return The flatbuffer string. + */ +flatbuffers::Offset RedisStringToFlatbuf( + flatbuffers::FlatBufferBuilder &fbb, + RedisModuleString *redis_string) { + size_t redis_string_size; + const char *redis_string_str = + RedisModule_StringPtrLen(redis_string, &redis_string_size); + return fbb.CreateString(redis_string_str, redis_string_size); +} + /** * Publish a notification to a client's notification channel about an insertion * or deletion to the db client table. @@ -69,16 +91,23 @@ bool PublishDBClientNotification(RedisModuleCtx *ctx, /* Construct strings to publish on the db client channel. */ RedisModuleString *channel_name = RedisModule_CreateString(ctx, "db_clients", strlen("db_clients")); - RedisModuleString *client_info; - const char *is_insertion_string = is_insertion ? "1" : "0"; - if (aux_address) { - client_info = - RedisString_Format(ctx, "%S:%S %S %s", ray_client_id, client_type, - aux_address, is_insertion_string); + /* Construct the flatbuffers object to publish over the channel. */ + flatbuffers::FlatBufferBuilder fbb; + /* Use an empty aux address if one is not passed in. */ + flatbuffers::Offset aux_address_str; + if (aux_address != NULL) { + aux_address_str = RedisStringToFlatbuf(fbb, aux_address); } else { - client_info = RedisString_Format(ctx, "%S:%S : %s", ray_client_id, - client_type, is_insertion_string); + aux_address_str = fbb.CreateString("", strlen("")); } + /* Create the flatbuffers message. */ + auto message = CreateSubscribeToDBClientTableReply( + fbb, RedisStringToFlatbuf(fbb, ray_client_id), + RedisStringToFlatbuf(fbb, client_type), aux_address_str, is_insertion); + fbb.Finish(message); + /* Create a Redis string to publish by serializing the flatbuffers object. 
*/ + RedisModuleString *client_info = RedisModule_CreateString( + ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize()); /* Publish the client info on the db client channel. */ RedisModuleCallReply *reply; @@ -328,48 +357,42 @@ bool PublishObjectNotification(RedisModuleCtx *ctx, RedisModuleString *object_id, RedisModuleString *data_size, RedisModuleKey *key) { - /* Create a string formatted as " MANAGERS - * ..." */ + flatbuffers::FlatBufferBuilder fbb; + long long data_size_value; if (RedisModule_StringToLongLong(data_size, &data_size_value) != REDISMODULE_OK) { return RedisModule_ReplyWithError(ctx, "data_size must be integer"); } - RedisModuleString *manager_list = RedisString_Format(ctx, "%S ", object_id); - - /* Append binary data size for this object. */ - /* TODO(pcm): Replace by a formatted fix length version of the size. */ - RedisModule_StringAppendBuffer(ctx, manager_list, - (const char *) &data_size_value, - sizeof(data_size_value)); - - RedisModule_StringAppendBuffer(ctx, manager_list, " MANAGERS", - strlen(" MANAGERS")); - + std::vector> manager_ids; CHECK_ERROR( RedisModule_ZsetFirstInScoreRange(key, REDISMODULE_NEGATIVE_INFINITE, REDISMODULE_POSITIVE_INFINITE, 1, 1), "Unable to initialize zset iterator"); - /* Loop over the managers in the object table for this object ID. */ do { RedisModuleString *curr = RedisModule_ZsetRangeCurrentElement(key, NULL); - RedisModule_StringAppendBuffer(ctx, manager_list, " ", 1); - size_t size; - const char *val = RedisModule_StringPtrLen(curr, &size); - RedisModule_StringAppendBuffer(ctx, manager_list, val, size); + manager_ids.push_back(RedisStringToFlatbuf(fbb, curr)); } while (RedisModule_ZsetRangeNext(key)); + auto message = CreateSubscribeToNotificationsReply( + fbb, RedisStringToFlatbuf(fbb, object_id), data_size_value, + fbb.CreateVector(manager_ids)); + fbb.Finish(message); + /* Publish the notification to the clients notification channel. * TODO(rkn): These notifications could be batched together. 
*/ RedisModuleString *channel_name = RedisString_Format(ctx, "%s%S", OBJECT_CHANNEL_PREFIX, client_id); + RedisModuleString *payload = RedisModule_CreateString( + ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize()); + RedisModuleCallReply *reply; - reply = RedisModule_Call(ctx, "PUBLISH", "ss", channel_name, manager_list); + reply = RedisModule_Call(ctx, "PUBLISH", "ss", channel_name, payload); RedisModule_FreeString(ctx, channel_name); - RedisModule_FreeString(ctx, manager_list); + RedisModule_FreeString(ctx, payload); if (reply == NULL) { return false; } @@ -668,35 +691,6 @@ int ResultTableAdd_RedisCommand(RedisModuleCtx *ctx, return REDISMODULE_OK; } -int ParseTaskState(RedisModuleString *state) { - size_t state_length; - const char *state_string = RedisModule_StringPtrLen(state, &state_length); - int state_integer; - int scanned = sscanf(state_string, "%2d", &state_integer); - if (scanned != 1 || state_length != 2) { - return -1; - } - return state_integer; -} - -RedisModuleString *NormalizeTaskState(RedisModuleCtx *ctx, - RedisModuleString *state) { - /* Pad the state integer to a fixed-width integer, and make sure it has width - * less than or equal to 2. */ - long long state_integer; - int status = RedisModule_StringToLongLong(state, &state_integer); - if (status != REDISMODULE_OK) { - return NULL; - } - state = RedisModule_CreateStringPrintf(ctx, "%2d", state_integer); - size_t length; - RedisModule_StringPtrLen(state, &length); - if (length != 2) { - return NULL; - } - return state; -} - /** * Reply with information about a task ID. This is used by * RAY.RESULT_TABLE_LOOKUP and RAY.TASK_TABLE_GET. 
@@ -716,8 +710,9 @@ int ReplyWithTask(RedisModuleCtx *ctx, RedisModuleString *task_id) { RedisModuleString *state = NULL; RedisModuleString *local_scheduler_id = NULL; RedisModuleString *task_spec = NULL; - RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "state", &state, "node", - &local_scheduler_id, "TaskSpec", &task_spec, NULL); + RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "state", &state, + "local_scheduler_id", &local_scheduler_id, "TaskSpec", + &task_spec, NULL); if (state == NULL || local_scheduler_id == NULL || task_spec == NULL) { /* We must have either all fields or no fields. */ RedisModule_CloseKey(key); @@ -725,21 +720,26 @@ int ReplyWithTask(RedisModuleCtx *ctx, RedisModuleString *task_id) { ctx, "Missing fields in the task table entry"); } - int state_integer = ParseTaskState(state); - if (state_integer < 0) { + long long state_integer; + if (RedisModule_StringToLongLong(state, &state_integer) != REDISMODULE_OK || + state_integer < 0) { RedisModule_CloseKey(key); RedisModule_FreeString(ctx, state); RedisModule_FreeString(ctx, local_scheduler_id); RedisModule_FreeString(ctx, task_spec); - return RedisModule_ReplyWithError(ctx, - "Found invalid scheduling state (must " - "be an integer of width 2"); + return RedisModule_ReplyWithError(ctx, "Found invalid scheduling state."); } - RedisModule_ReplyWithArray(ctx, 3); - RedisModule_ReplyWithLongLong(ctx, state_integer); - RedisModule_ReplyWithString(ctx, local_scheduler_id); - RedisModule_ReplyWithString(ctx, task_spec); + flatbuffers::FlatBufferBuilder fbb; + auto message = + CreateTaskReply(fbb, RedisStringToFlatbuf(fbb, task_id), state_integer, + RedisStringToFlatbuf(fbb, local_scheduler_id), + RedisStringToFlatbuf(fbb, task_spec)); + fbb.Finish(message); + + RedisModuleString *reply = RedisModule_CreateString( + ctx, (char *) fbb.GetBufferPointer(), fbb.GetSize()); + RedisModule_ReplyWithString(ctx, reply); RedisModule_FreeString(ctx, state); RedisModule_FreeString(ctx, local_scheduler_id); 
@@ -801,54 +801,59 @@ int ResultTableLookup_RedisCommand(RedisModuleCtx *ctx, int TaskTableWrite(RedisModuleCtx *ctx, RedisModuleString *task_id, RedisModuleString *state, - RedisModuleString *node_id, + RedisModuleString *local_scheduler_id, RedisModuleString *task_spec) { - /* Pad the state integer to a fixed-width integer, and make sure it has width - * less than or equal to 2. */ - state = NormalizeTaskState(ctx, state); - if (state == NULL) { - return RedisModule_ReplyWithError( - ctx, - "Invalid scheduling state (must be an integer of width at most 2)"); + /* Extract the scheduling state. */ + long long state_value; + if (RedisModule_StringToLongLong(state, &state_value) != REDISMODULE_OK) { + return RedisModule_ReplyWithError(ctx, "scheduling state must be integer"); } - /* Add the task to the task table. If no spec was provided, get the existing * spec out of the task table so we can publish it. */ RedisModuleString *existing_task_spec = NULL; RedisModuleKey *key = OpenPrefixedKey(ctx, TASK_PREFIX, task_id, REDISMODULE_WRITE); if (task_spec == NULL) { - RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state, "node", - node_id, NULL); + RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state, + "local_scheduler_id", local_scheduler_id, NULL); RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "TaskSpec", &existing_task_spec, NULL); if (existing_task_spec == NULL) { RedisModule_CloseKey(key); - RedisModule_FreeString(ctx, state); return RedisModule_ReplyWithError( ctx, "Cannot update a task that doesn't exist yet"); } } else { - RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state, "node", - node_id, "TaskSpec", task_spec, NULL); + RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state, + "local_scheduler_id", local_scheduler_id, "TaskSpec", + task_spec, NULL); } RedisModule_CloseKey(key); /* Build the PUBLISH topic and message for task table subscribers. The topic - * is a string in the format "TASK_PREFIX:<node ID>:<state>".
The - * message is a string in the format: "<task ID> <state> <node ID> <task spec>". */ - RedisModuleString *publish_topic = - RedisString_Format(ctx, "%s%S:%S", TASK_PREFIX, node_id, state); - RedisModuleString *publish_message; + * is a string in the format "TASK_PREFIX:<local scheduler ID>:<state>". The + * message is a serialized SubscribeToTasksReply flatbuffer object. */ + RedisModuleString *publish_topic = RedisString_Format( + ctx, "%s%S:%S", TASK_PREFIX, local_scheduler_id, state); + + /* Construct the flatbuffers object for the payload. */ + flatbuffers::FlatBufferBuilder fbb; + /* Use the old task spec if the current one is NULL. */ + RedisModuleString *task_spec_to_use; if (task_spec != NULL) { - publish_message = RedisString_Format(ctx, "%S %S %S %S", task_id, state, - node_id, task_spec); + task_spec_to_use = task_spec; } else { - publish_message = RedisString_Format(ctx, "%S %S %S %S", task_id, state, - node_id, existing_task_spec); + task_spec_to_use = existing_task_spec; } - RedisModule_FreeString(ctx, state); + /* Create the flatbuffers message. */ + auto message = + CreateTaskReply(fbb, RedisStringToFlatbuf(fbb, task_id), state_value, + RedisStringToFlatbuf(fbb, local_scheduler_id), + RedisStringToFlatbuf(fbb, task_spec_to_use)); + fbb.Finish(message); + + RedisModuleString *publish_message = RedisModule_CreateString( + ctx, (const char *) fbb.GetBufferPointer(), fbb.GetSize()); RedisModuleCallReply *reply = RedisModule_Call(ctx, "PUBLISH", "ss", publish_topic, publish_message); @@ -878,10 +883,7 @@ int TaskTableWrite(RedisModuleCtx *ctx, * * @param task_id A string that is the ID of the task. * @param state A string that is the current scheduling state (a - * scheduling_state enum instance). The string's value must be a - * nonnegative integer less than 100, so that it has width at most 2. If - * less than 2, the value will be left-padded with spaces to a width of - * 2. + * scheduling_state enum instance).
* @param local_scheduler_id A string that is the ray client ID of the * associated local scheduler, if any. * @param task_spec A string that is the specification of the task, which can @@ -908,10 +910,7 @@ int TaskTableAddTask_RedisCommand(RedisModuleCtx *ctx, * * @param task_id A string that is the ID of the task. * @param state A string that is the current scheduling state (a - * scheduling_state enum instance). The string's value must be a - * nonnegative integer less than 100, so that it has width at most 2. If - * less than 2, the value will be left-padded with spaces to a width of - * 2. + * scheduling_state enum instance). * @param ray_client_id A string that is the ray client ID of the associated * local scheduler, if any. * @return OK if the operation was successful. @@ -941,18 +940,15 @@ int TaskTableUpdate_RedisCommand(RedisModuleCtx *ctx, * scheduling state. The update happens if and only if the current * scheduling state AND-ed with the bitmask is greater than 0. * @param state A string that is the scheduling state (a scheduling_state enum - * instance) to update the task entry with. The string's value must be a - * nonnegative integer less than 100, so that it has width at most 2. If - * less than 2, the value will be left-padded with spaces to a width of - * 2. + * instance) to update the task entry with. * @param ray_client_id A string that is the ray client ID of the associated * local scheduler, if any, to update the task entry with. * @return If the current scheduling state does not match the test bitmask, * returns nil. Else, returns the same as RAY.TASK_TABLE_GET: an array * of strings representing the updated task fields in the following - * order: 1) (integer) scheduling state 2) (string) associated node ID, - * if any 3) (string) the task specification, which can be casted to a - * task_spec. 
+ * order: 1) (integer) scheduling state 2) (string) associated local + * scheduler ID, if any 3) (string) the task specification, which can be + * cast to a task_spec. */ int TaskTableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, @@ -961,18 +957,12 @@ int TaskTableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, return RedisModule_WrongArity(ctx); } - RedisModuleString *state = NormalizeTaskState(ctx, argv[3]); - if (state == NULL) { - return RedisModule_ReplyWithError( - ctx, - "Invalid scheduling state (must be an integer of width at most 2)"); - } + RedisModuleString *state = argv[3]; RedisModuleKey *key = OpenPrefixedKey(ctx, TASK_PREFIX, argv[1], REDISMODULE_READ | REDISMODULE_WRITE); if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) { RedisModule_CloseKey(key); - RedisModule_FreeString(ctx, state); return RedisModule_ReplyWithNull(ctx); } @@ -980,19 +970,20 @@ int TaskTableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString *current_state = NULL; RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "state", &current_state, NULL); - int current_state_integer = ParseTaskState(current_state); + long long current_state_integer; + if (RedisModule_StringToLongLong(current_state, &current_state_integer) != + REDISMODULE_OK) { + return RedisModule_ReplyWithError(ctx, "current_state must be integer"); + } + if (current_state_integer < 0) { RedisModule_CloseKey(key); - RedisModule_FreeString(ctx, state); - return RedisModule_ReplyWithError(ctx, - "Found invalid scheduling state (must " - "be an integer of width 2"); + return RedisModule_ReplyWithError(ctx, "Found invalid scheduling state."); } long long test_state_bitmask; int status = RedisModule_StringToLongLong(argv[2], &test_state_bitmask); if (status != REDISMODULE_OK) { RedisModule_CloseKey(key); - RedisModule_FreeString(ctx, state); return RedisModule_ReplyWithError( ctx, "Invalid test value for scheduling state"); } @@ -1000,16 +991,14 @@ int
TaskTableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, /* The current value does not match the test bitmask, so do not perform the * update. */ RedisModule_CloseKey(key); - RedisModule_FreeString(ctx, state); return RedisModule_ReplyWithNull(ctx); } /* The test passed, so perform the update. */ - RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state, "node", - argv[4], NULL); + RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state, + "local_scheduler_id", argv[4], NULL); /* Clean up. */ RedisModule_CloseKey(key); - RedisModule_FreeString(ctx, state); /* Construct a reply by getting the task from the task ID. */ return ReplyWithTask(ctx, argv[1]); } @@ -1023,9 +1012,9 @@ int TaskTableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, * * @param task_id A string of the task ID to look up. * @return An array of strings representing the task fields in the following - * order: 1) (integer) scheduling state 2) (string) associated node ID, - * if any 3) (string) the task specification, which can be casted to a - * task_spec. If the task ID is not in the table, returns nil. + * order: 1) (integer) scheduling state 2) (string) associated local + * scheduler ID, if any 3) (string) the task specification, which can be + * cast to a task_spec. If the task ID is not in the table, returns nil. */ int TaskTableGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, @@ -1038,6 +1027,8 @@ int TaskTableGet_RedisCommand(RedisModuleCtx *ctx, return ReplyWithTask(ctx, argv[1]); } +extern "C" { + /* This function must be present on each Redis module. It is used in order to * register the commands into the Redis server. 
*/ int RedisModule_OnLoad(RedisModuleCtx *ctx, @@ -1135,3 +1126,5 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, return REDISMODULE_OK; } + +} /* extern "C" */ diff --git a/src/common/state/redis.cc b/src/common/state/redis.cc index 74ef8227f..0b5cece64 100644 --- a/src/common/state/redis.cc +++ b/src/common/state/redis.cc @@ -27,6 +27,10 @@ extern "C" { #include "redis.h" #include "io.h" +#include "format/common_generated.h" + +#include "common_protocol.h" + #ifndef _WIN32 /* This function is actually not declared in standard POSIX, so declare it. */ extern int usleep(useconds_t usec); @@ -358,27 +362,14 @@ Task *parse_and_construct_task_from_redis_reply(redisReply *reply) { if (reply->type == REDIS_REPLY_NIL) { /* There is no task in the reply, so return NULL. */ task = NULL; - } else if (reply->type == REDIS_REPLY_ARRAY) { - /* Check that the reply is as expected. The 0th element is the scheduling - * state. The 1st element is the db_client_id of the associated local - * scheduler, and the 2nd element is the TaskSpec. */ - CHECK(reply->elements == 3); - CHECK(reply->element[0]->type == REDIS_REPLY_INTEGER); - CHECK(reply->element[1]->type == REDIS_REPLY_STRING); - CHECK(reply->element[2]->type == REDIS_REPLY_STRING); - /* Parse the scheduling state. */ - long long state = reply->element[0]->integer; - /* Parse the local scheduler db_client_id. */ - DBClientID local_scheduler_id; - CHECK(sizeof(local_scheduler_id) == reply->element[1]->len); - memcpy(local_scheduler_id.id, reply->element[1]->str, - reply->element[1]->len); - /* Parse the task spec. */ - TaskSpec *spec = (TaskSpec *) malloc(reply->element[2]->len); - memcpy(spec, reply->element[2]->str, reply->element[2]->len); - task = Task_alloc(spec, reply->element[2]->len, state, local_scheduler_id); - /* Free the task spec. */ - TaskSpec_free(spec); + } else if (reply->type == REDIS_REPLY_STRING) { + /* The reply is a flatbuffer TaskReply object. Parse it and construct the + * task. 
*/ + auto message = flatbuffers::GetRoot<TaskReply>(reply->str); + TaskSpec *spec = (TaskSpec *) message->task_spec()->data(); + int64_t task_spec_size = message->task_spec()->size(); + task = Task_alloc(spec, task_spec_size, message->state(), + from_flatbuf(message->local_scheduler_id())); } else { LOG_FATAL("Unexpected reply type %d", reply->type); } @@ -502,81 +493,6 @@ void redis_object_table_lookup_callback(redisAsyncContext *c, } } -/** - * This will parse a payload string published on the object notification - * channel. The string must have the format: - * - * <object ID> <data size> MANAGERS <manager ID 1> <manager ID 2> ... - * - * where there may be any positive number of manager IDs. - * - * @param db The db handle. - * @param payload The payload string. - * @param length The length of the string. - * @param manager_count This method will write the number of managers at this - * address. - * @param manager_vector This method will allocate an array of pointers to - * manager addresses and write the address of the array at this address. - * The caller is responsible for freeing this array. - * @return The object ID that the notification is about. - */ -ObjectID parse_subscribe_to_notifications_payload( - DBHandle *db, - char *payload, - int length, - int64_t *data_size, - int *manager_count, - const char ***manager_vector) { - long long data_size_value = 0; - int num_managers = (length - sizeof(ObjectID) - 1 - sizeof(data_size_value) - - 1 - strlen("MANAGERS")) / - (1 + sizeof(DBClientID)); - - int64_t rval = sizeof(ObjectID) + 1 + sizeof(data_size_value) + 1 + - strlen("MANAGERS") + num_managers * (1 + sizeof(DBClientID)); - - CHECKM(length == rval, - "length mismatch: num_managers = %d, length = %d, rval = %" PRId64, - num_managers, length, rval); - CHECK(num_managers > 0); - ObjectID obj_id; - /* Track our current offset in the payload. */ - int offset = 0; - /* Parse the object ID.
*/ - memcpy(&obj_id.id, &payload[offset], sizeof(obj_id.id)); - offset += sizeof(obj_id.id); - /* The next part of the payload is a space. */ - const char *space_str = " "; - CHECK(memcmp(&payload[offset], space_str, strlen(space_str)) == 0); - offset += strlen(space_str); - /* The next part of the payload is binary data_size. */ - memcpy(&data_size_value, &payload[offset], sizeof(data_size_value)); - offset += sizeof(data_size_value); - /* The next part of the payload is the string " MANAGERS" with leading ' '. */ - const char *managers_str = " MANAGERS"; - CHECK(memcmp(&payload[offset], managers_str, strlen(managers_str)) == 0); - offset += strlen(managers_str); - /* Parse the managers. */ - const char **managers = (const char **) malloc(num_managers * sizeof(char *)); - for (int i = 0; i < num_managers; ++i) { - /* First there is a space. */ - CHECK(memcmp(&payload[offset], " ", strlen(" ")) == 0); - offset += strlen(" "); - /* Get the manager ID. */ - DBClientID manager_id; - memcpy(&manager_id.id, &payload[offset], sizeof(manager_id.id)); - offset += sizeof(manager_id.id); - /* Write the address of the corresponding manager to the returned array. */ - redis_get_cached_db_client(db, manager_id, &managers[i]); - } - CHECK(offset == length); - /* Return the manager array and the object ID. */ - *manager_count = num_managers; - *manager_vector = managers; - *data_size = data_size_value; - return obj_id; -} - void object_table_redis_subscribe_to_notifications_callback( redisAsyncContext *c, void *r, @@ -603,13 +519,22 @@ void object_table_redis_subscribe_to_notifications_callback( message_type->str); if (strcmp(message_type->str, "message") == 0) { - /* Handle an object notification. */ - int64_t data_size = 0; - int manager_count; - const char **manager_vector; - ObjectID obj_id = parse_subscribe_to_notifications_payload( - db, reply->element[2]->str, reply->element[2]->len, &data_size, - &manager_count, &manager_vector); + /* We received an object notification. 
Parse the payload. */ + auto message = flatbuffers::GetRoot<SubscribeToNotificationsReply>( + reply->element[2]->str); + /* Extract the object ID. */ + ObjectID obj_id = from_flatbuf(message->object_id()); + /* Extract the data size. */ + int64_t data_size = message->object_size(); + int manager_count = message->manager_ids()->size(); + /* Construct the manager vector from the flatbuffers object. */ + const char **manager_vector = + (const char **) malloc(manager_count * sizeof(char *)); + for (int i = 0; i < manager_count; ++i) { + DBClientID manager_id = from_flatbuf(message->manager_ids()->Get(i)); + redis_get_cached_db_client(db, manager_id, &manager_vector[i]); + } + /* Call the subscribe callback. */ ObjectTableSubscribeData *data = (ObjectTableSubscribeData *) callback_data->data; @@ -641,7 +566,7 @@ void redis_object_table_subscribe_to_notifications( TableCallbackData *callback_data) { DBHandle *db = callback_data->db_handle; /* The object channel prefix must match the value defined in - * src/common/redismodule/ray_redis_module.c. */ + * src/common/redismodule/ray_redis_module.cc. */ const char *object_channel_prefix = "OC:"; const char *object_channel_bcast = "BCAST"; int status = REDIS_OK; @@ -869,47 +794,6 @@ void redis_task_table_test_and_update(TableCallbackData *callback_data) { } } -/* The format of the payload is described in ray_redis_module.c and is - * "<task ID> <state> <local scheduler ID> <task spec>". TODO(rkn): - * Make this code nicer. */ -void parse_task_table_subscribe_callback(char *payload, - int length, - TaskID *task_id, - int *state, - DBClientID *local_scheduler_id, - TaskSpec **spec, - int64_t *task_spec_size) { - /* Note that the state is padded with spaces to consist of precisely two - * characters. */ - int task_spec_payload_size = - length - sizeof(*task_id) - 1 - 2 - 1 - sizeof(*local_scheduler_id) - 1; - int offset = 0; - /* Read in the task ID. */ - memcpy(task_id, &payload[offset], sizeof(*task_id)); - offset += sizeof(*task_id); - /* Read in a space.
*/ - const char *space_str = (const char *) " "; - CHECK(memcmp(space_str, &payload[offset], strlen(space_str)) == 0); - offset += strlen(space_str); - /* Read in the state, which is an integer left-padded with spaces to two - * characters. */ - CHECK(sscanf(&payload[offset], "%2d", state) == 1); - offset += 2; - /* Read in a space. */ - CHECK(memcmp(space_str, &payload[offset], strlen(space_str)) == 0); - offset += strlen(space_str); - /* Read in the local scheduler ID. */ - memcpy(local_scheduler_id, &payload[offset], sizeof(*local_scheduler_id)); - offset += sizeof(*local_scheduler_id); - /* Read in a space. */ - CHECK(memcmp(space_str, &payload[offset], strlen(space_str)) == 0); - offset += strlen(space_str); - /* Read in the task spec. */ - *spec = (TaskSpec *) malloc(task_spec_payload_size); - memcpy(*spec, &payload[offset], task_spec_payload_size); - *task_spec_size = task_spec_payload_size; -} - void redis_task_table_subscribe_callback(redisAsyncContext *c, void *r, void *privdata) { @@ -917,7 +801,7 @@ void redis_task_table_subscribe_callback(redisAsyncContext *c, redisReply *reply = (redisReply *) r; CHECK(reply->type == REDIS_REPLY_ARRAY); - /* The number of elements is 3 for a reply to SUBSCRIBE, and 4 for a reply to + /* The number of elements is 3 for a reply to SUBSCRIBE, and 4 for a reply to * PSUBSCRIBE. */ CHECKM(reply->elements == 3 || reply->elements == 4, "reply->elements is %zu", reply->elements); @@ -929,20 +813,22 @@ void redis_task_table_subscribe_callback(redisAsyncContext *c, if (strcmp(message_type->str, "message") == 0 || strcmp(message_type->str, "pmessage") == 0) { /* Handle a task table event. Parse the payload and call the callback. */ + auto message = flatbuffers::GetRoot<TaskReply>(payload->str); + /* Extract the task ID. */ + TaskID task_id = from_flatbuf(message->task_id()); + /* Extract the scheduling state. */ + int64_t state = message->state(); + /* Extract the local scheduler ID.
*/ + DBClientID local_scheduler_id = from_flatbuf(message->local_scheduler_id()); + /* Extract the task spec. */ + TaskSpec *spec = (TaskSpec *) message->task_spec()->data(); + int64_t task_spec_size = message->task_spec()->size(); + /* Create a task. */ + Task *task = Task_alloc(spec, task_spec_size, state, local_scheduler_id); + + /* Call the subscribe callback if there is one. */ TaskTableSubscribeData *data = (TaskTableSubscribeData *) callback_data->data; - /* Read out the information from the payload. */ - TaskID task_id; - int state; - DBClientID local_scheduler_id; - TaskSpec *spec; - int64_t task_spec_size; - parse_task_table_subscribe_callback(payload->str, payload->len, &task_id, - &state, &local_scheduler_id, &spec, - &task_spec_size); - Task *task = Task_alloc(spec, task_spec_size, state, local_scheduler_id); - TaskSpec_free(spec); - /* Call the subscribe callback if there is one. */ if (data->subscribe_callback != NULL) { data->subscribe_callback(task, data->subscribe_context); } @@ -969,24 +855,24 @@ void redis_task_table_subscribe_callback(redisAsyncContext *c, void redis_task_table_subscribe(TableCallbackData *callback_data) { DBHandle *db = callback_data->db_handle; TaskTableSubscribeData *data = (TaskTableSubscribeData *) callback_data->data; - /* TASK_CHANNEL_PREFIX is defined in ray_redis_module.c and must be kept in + /* TASK_CHANNEL_PREFIX is defined in ray_redis_module.cc and must be kept in * sync with that file. */ const char *TASK_CHANNEL_PREFIX = "TT:"; int status; if (IS_NIL_ID(data->local_scheduler_id)) { /* TODO(swang): Implement the state_filter by translating the bitmask into * a Redis key-matching pattern. 
*/ - status = redisAsyncCommand( - db->sub_context, redis_task_table_subscribe_callback, - (void *) callback_data->timer_id, "PSUBSCRIBE %s*:%2d", - TASK_CHANNEL_PREFIX, data->state_filter); + status = + redisAsyncCommand(db->sub_context, redis_task_table_subscribe_callback, + (void *) callback_data->timer_id, "PSUBSCRIBE %s*:%d", + TASK_CHANNEL_PREFIX, data->state_filter); } else { DBClientID local_scheduler_id = data->local_scheduler_id; - status = redisAsyncCommand( - db->sub_context, redis_task_table_subscribe_callback, - (void *) callback_data->timer_id, "SUBSCRIBE %s%b:%2d", - TASK_CHANNEL_PREFIX, (char *) local_scheduler_id.id, - sizeof(local_scheduler_id.id), data->state_filter); + status = + redisAsyncCommand(db->sub_context, redis_task_table_subscribe_callback, + (void *) callback_data->timer_id, "SUBSCRIBE %s%b:%d", + TASK_CHANNEL_PREFIX, (char *) local_scheduler_id.id, + sizeof(local_scheduler_id.id), data->state_filter); } if ((status == REDIS_ERR) || db->sub_context->err) { LOG_REDIS_DEBUG(db->sub_context, "error in redis_task_table_subscribe"); @@ -1052,40 +938,23 @@ void redis_db_client_table_subscribe_callback(redisAsyncContext *c, return; } /* Otherwise, parse the payload and call the callback. */ - DBClientTableSubscribeData *data = - (DBClientTableSubscribeData *) callback_data->data; - DBClientID client; - memcpy(client.id, payload->str, sizeof(client.id)); - /* We subtract 1 + sizeof(client.id) to compute the length of the - * client_type string, and we add 1 to null-terminate the string. */ - int client_type_length = payload->len - 1 - sizeof(client.id) + 1; - CHECK(client_type_length > 0); + auto message = + flatbuffers::GetRoot<SubscribeToDBClientTableReply>(payload->str); + DBClientID client = from_flatbuf(message->db_client_id()); /* Parse the client type and auxiliary address from the response. If there is * only client type, then the update was a delete.
*/ - char *client_type = (char *) malloc(client_type_length); - char *aux_address = (char *) malloc(client_type_length); - int is_insertion; - memset(aux_address, 0, client_type_length); - /* Published message format: <client_id:client_type> <aux_addr> <is_insertion> */ - int rv = sscanf(&payload->str[1 + sizeof(client.id)], "%s %s %d", client_type, - aux_address, &is_insertion); - CHECKM(rv == 3, - "redis_db_client_table_subscribe_callback: expected 2 parsed args, " - "Got %d instead.", - rv); - CHECKM(is_insertion == 1 || is_insertion == 0, - "redis_db_client_table_subscribe_callback: expected 0 or 1 for " - "insertion field, got %d instead.", - is_insertion); + char *client_type = (char *) message->client_type()->data(); + char *aux_address = (char *) message->aux_address()->data(); + bool is_insertion = message->is_insertion(); /* Call the subscription callback. */ + DBClientTableSubscribeData *data = + (DBClientTableSubscribeData *) callback_data->data; if (data->subscribe_callback) { - data->subscribe_callback(client, client_type, aux_address, - (bool) is_insertion, data->subscribe_context); + data->subscribe_callback(client, client_type, aux_address, is_insertion, + data->subscribe_context); } - free(client_type); - free(aux_address); } void redis_db_client_table_subscribe(TableCallbackData *callback_data) { @@ -1114,15 +983,26 @@ void redis_local_scheduler_table_subscribe_callback(redisAsyncContext *c, if (strcmp(message_type->str, "message") == 0) { /* Handle a local scheduler heartbeat. Parse the payload and call the * subscribe callback. */ - redisReply *payload = reply->element[2]; + auto message = + flatbuffers::GetRoot<LocalSchedulerInfoMessage>(reply->element[2]->str); + + /* Extract the client ID. */ + DBClientID client_id = from_flatbuf(message->db_client_id()); + /* Extract the fields of the local scheduler info struct.
*/ + LocalSchedulerInfo info; + info.total_num_workers = message->total_num_workers(); + info.task_queue_length = message->task_queue_length(); + info.available_workers = message->available_workers(); + for (int i = 0; i < ResourceIndex_MAX; ++i) { + info.static_resources[i] = message->static_resources()->Get(i); + } + for (int i = 0; i < ResourceIndex_MAX; ++i) { + info.dynamic_resources[i] = message->dynamic_resources()->Get(i); + } + + /* Call the subscribe callback. */ LocalSchedulerTableSubscribeData *data = (LocalSchedulerTableSubscribeData *) callback_data->data; - DBClientID client_id; - LocalSchedulerInfo info; - /* The payload should be the concatenation of these two structs. */ - CHECK(sizeof(client_id) + sizeof(info) == payload->len); - memcpy(&client_id, payload->str, sizeof(client_id)); - memcpy(&info, payload->str + sizeof(client_id), sizeof(info)); if (data->subscribe_callback) { data->subscribe_callback(client_id, info, data->subscribe_context); } @@ -1167,10 +1047,22 @@ void redis_local_scheduler_table_send_info(TableCallbackData *callback_data) { DBHandle *db = callback_data->db_handle; LocalSchedulerTableSendInfoData *data = (LocalSchedulerTableSendInfoData *) callback_data->data; + + /* Create a flatbuffer object to serialize and publish. */ + flatbuffers::FlatBufferBuilder fbb; + /* Create the flatbuffers message. 
*/ + LocalSchedulerInfo info = data->info; + auto message = CreateLocalSchedulerInfoMessage( + fbb, to_flatbuf(fbb, db->client), info.total_num_workers, + info.task_queue_length, info.available_workers, + fbb.CreateVector(info.static_resources, ResourceIndex_MAX), + fbb.CreateVector(info.dynamic_resources, ResourceIndex_MAX)); + fbb.Finish(message); + int status = redisAsyncCommand( db->context, redis_local_scheduler_table_send_info_callback, - (void *) callback_data->timer_id, "PUBLISH local_schedulers %b%b", - db->client.id, sizeof(db->client.id), &data->info, sizeof(data->info)); + (void *) callback_data->timer_id, "PUBLISH local_schedulers %b", + fbb.GetBufferPointer(), fbb.GetSize()); if ((status == REDIS_ERR) || db->context->err) { LOG_REDIS_DEBUG(db->context, "error in redis_local_scheduler_table_send_info"); diff --git a/test/stress_tests.py b/test/stress_tests.py index 0ae3c7980..91c458059 100644 --- a/test/stress_tests.py +++ b/test/stress_tests.py @@ -8,6 +8,9 @@ import numpy as np import time import redis +# Import flatbuffer bindings. +from ray.core.generated.TaskReply import TaskReply + class TaskTests(unittest.TestCase): def testSubmittingTasks(self): @@ -164,9 +167,13 @@ class ReconstructionTests(unittest.TestCase): r = redis.StrictRedis(port=self.redis_port) task_ids = r.keys("TT:*") task_ids = [task_id[3:] for task_id in task_ids] - node_ids = [r.execute_command("ray.task_table_get", task_id)[1] for task_id - in task_ids] - self.assertEqual(len(set(node_ids)), self.num_local_schedulers) + local_scheduler_ids = [] + for task_id in task_ids: + message = r.execute_command("ray.task_table_get", task_id) + task_reply_object = TaskReply.GetRootAsTaskReply(message, 0) + local_scheduler_ids.append(task_reply_object.LocalSchedulerId()) + + self.assertEqual(len(set(local_scheduler_ids)), self.num_local_schedulers) # Clean up the Ray cluster. 
ray.worker.cleanup() diff --git a/webui/backend/ray_ui.py b/webui/backend/ray_ui.py index f12382858..e28815256 100644 --- a/webui/backend/ray_ui.py +++ b/webui/backend/ray_ui.py @@ -19,7 +19,7 @@ loop = asyncio.get_event_loop() IDENTIFIER_LENGTH = 20 -# This prefix must match the value defined in ray_redis_module.c. +# This prefix must match the value defined in ray_redis_module.cc. DB_CLIENT_PREFIX = b"CL:" def hex_identifier(identifier):