[C++] Add hash table to Redis-Module (#4911)

This commit is contained in:
Yuhong Guo
2019-06-07 16:11:37 +08:00
committed by Hao Chen
parent cbc67fc750
commit 5eff47b657
15 changed files with 686 additions and 93 deletions
+1 -1
View File
@@ -535,7 +535,7 @@ flatbuffer_py_library(
"ErrorTableData.py",
"ErrorType.py",
"FunctionTableData.py",
"GcsTableEntry.py",
"GcsEntry.py",
"HeartbeatBatchTableData.py",
"HeartbeatTableData.py",
"Language.py",
+1 -1
View File
@@ -29,7 +29,7 @@ MOCK_MODULES = [
"ray.core.generated.EntryType",
"ray.core.generated.ErrorTableData",
"ray.core.generated.ErrorType",
"ray.core.generated.GcsTableEntry",
"ray.core.generated.GcsEntry",
"ray.core.generated.HeartbeatBatchTableData",
"ray.core.generated.HeartbeatTableData",
"ray.core.generated.Language",
+1 -1
View File
@@ -160,7 +160,7 @@ flatbuffers_generated_files = [
"ErrorTableData.java",
"ErrorType.java",
"FunctionTableData.java",
"GcsTableEntry.java",
"GcsEntry.java",
"HeartbeatBatchTableData.java",
"HeartbeatTableData.java",
"Language.java",
+2 -2
View File
@@ -9,7 +9,7 @@ from ray.core.generated.ActorCheckpointIdData import ActorCheckpointIdData
from ray.core.generated.ClientTableData import ClientTableData
from ray.core.generated.DriverTableData import DriverTableData
from ray.core.generated.ErrorTableData import ErrorTableData
from ray.core.generated.GcsTableEntry import GcsTableEntry
from ray.core.generated.GcsEntry import GcsEntry
from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData
from ray.core.generated.HeartbeatTableData import HeartbeatTableData
from ray.core.generated.Language import Language
@@ -25,7 +25,7 @@ __all__ = [
"ClientTableData",
"DriverTableData",
"ErrorTableData",
"GcsTableEntry",
"GcsEntry",
"HeartbeatBatchTableData",
"HeartbeatTableData",
"Language",
+2 -4
View File
@@ -101,8 +101,7 @@ class Monitor(object):
def xray_heartbeat_batch_handler(self, unused_channel, data):
"""Handle an xray heartbeat batch message from Redis."""
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
data, 0)
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0)
heartbeat_data = gcs_entries.Entries(0)
message = (ray.gcs_utils.HeartbeatBatchTableData.
@@ -208,8 +207,7 @@ class Monitor(object):
unused_channel: The message channel.
data: The message data.
"""
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
data, 0)
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0)
driver_data = gcs_entries.Entries(0)
message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData(
driver_data, 0)
+8 -14
View File
@@ -41,7 +41,7 @@ def _parse_client_table(redis_client):
return []
node_info = {}
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(message, 0)
gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
ordered_client_ids = []
@@ -248,8 +248,7 @@ class GlobalState(object):
object_id.binary())
if message is None:
return {}
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
assert gcs_entry.EntriesLength() > 0
@@ -307,8 +306,7 @@ class GlobalState(object):
"", task_id.binary())
if message is None:
return {}
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
assert gcs_entries.EntriesLength() == 1
@@ -431,8 +429,7 @@ class GlobalState(object):
if message is None:
return []
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
profile_events = []
for i in range(gcs_entries.EntriesLength()):
@@ -815,9 +812,8 @@ class GlobalState(object):
ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL):
continue
data = raw_message["data"]
gcs_entries = (
ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
data, 0))
gcs_entries = (ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(
data, 0))
heartbeat_data = gcs_entries.Entries(0)
message = (ray.gcs_utils.HeartbeatTableData.
GetRootAsHeartbeatTableData(heartbeat_data, 0))
@@ -871,8 +867,7 @@ class GlobalState(object):
if message is None:
return []
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
error_messages = []
for i in range(gcs_entries.EntriesLength()):
error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData(
@@ -934,8 +929,7 @@ class GlobalState(object):
)
if message is None:
return None
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
entry = (
ray.gcs_utils.ActorCheckpointIdData.GetRootAsActorCheckpointIdData(
gcs_entry.Entries(0), 0))
+1 -1
View File
@@ -1656,7 +1656,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
if msg is None:
threads_stopped.wait(timeout=0.01)
continue
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(
msg["data"], 0)
assert gcs_entry.EntriesLength() == 1
error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData(
+3
View File
@@ -120,6 +120,7 @@ AsyncGcsClient::AsyncGcsClient(const std::string &address, int port,
profile_table_.reset(new ProfileTable(shard_contexts_, this));
actor_checkpoint_table_.reset(new ActorCheckpointTable(shard_contexts_, this));
actor_checkpoint_id_table_.reset(new ActorCheckpointIdTable(shard_contexts_, this));
resource_table_.reset(new DynamicResourceTable({primary_context_}, this));
command_type_ = command_type;
// TODO(swang): Call the client table's Connect() method here. To do this,
@@ -229,6 +230,8 @@ ActorCheckpointIdTable &AsyncGcsClient::actor_checkpoint_id_table() {
return *actor_checkpoint_id_table_;
}
DynamicResourceTable &AsyncGcsClient::resource_table() { return *resource_table_; }
} // namespace gcs
} // namespace ray
+2
View File
@@ -62,6 +62,7 @@ class RAY_EXPORT AsyncGcsClient {
ProfileTable &profile_table();
ActorCheckpointTable &actor_checkpoint_table();
ActorCheckpointIdTable &actor_checkpoint_id_table();
DynamicResourceTable &resource_table();
// We also need something to export generic code to run on workers from the
// driver (to set the PYTHONPATH)
@@ -94,6 +95,7 @@ class RAY_EXPORT AsyncGcsClient {
std::unique_ptr<ClientTable> client_table_;
std::unique_ptr<ActorCheckpointTable> actor_checkpoint_table_;
std::unique_ptr<ActorCheckpointIdTable> actor_checkpoint_id_table_;
std::unique_ptr<DynamicResourceTable> resource_table_;
// The following contexts write to the data shard
std::vector<std::shared_ptr<RedisContext>> shard_contexts_;
std::vector<std::unique_ptr<RedisAsioClient>> shard_asio_async_clients_;
+162 -10
View File
@@ -657,13 +657,12 @@ void TestSetSubscribeAll(const DriverID &driver_id,
// Callback for a notification.
auto notification_callback = [object_ids, managers](
gcs::AsyncGcsClient *client, const ObjectID &id,
const GcsTableNotificationMode notification_mode,
gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode,
const std::vector<ObjectTableDataT> data) {
if (test->NumCallbacks() < 3 * 3) {
ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD);
ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD);
} else {
ASSERT_EQ(notification_mode, GcsTableNotificationMode::REMOVE);
ASSERT_EQ(change_mode, GcsChangeMode::REMOVE);
}
ASSERT_EQ(id, object_ids[test->NumCallbacks() / 3 % 3]);
// Check that we get notifications in the same order as the writes.
@@ -894,10 +893,9 @@ void TestSetSubscribeId(const DriverID &driver_id,
// The callback for a notification from the table. This should only be
// received for keys that we requested notifications for.
auto notification_callback = [object_id2, managers2](
gcs::AsyncGcsClient *client, const ObjectID &id,
const GcsTableNotificationMode notification_mode,
gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode,
const std::vector<ObjectTableDataT> &data) {
ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD);
ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD);
// Check that we only get notifications for the requested key.
ASSERT_EQ(id, object_id2);
// Check that we get notifications in the same order as the writes.
@@ -1111,10 +1109,9 @@ void TestSetSubscribeCancel(const DriverID &driver_id,
// The callback for a notification from the object table. This should only be
// received for the object that we requested notifications for.
auto notification_callback = [object_id, managers](
gcs::AsyncGcsClient *client, const ObjectID &id,
const GcsTableNotificationMode notification_mode,
gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode,
const std::vector<ObjectTableDataT> &data) {
ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD);
ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD);
ASSERT_EQ(id, object_id);
// Check that we get a duplicate notification for the first write. We get a
// duplicate notification because notifications
@@ -1307,6 +1304,161 @@ TEST_F(TestGcsWithAsio, TestClientTableMarkDisconnected) {
TestClientTableMarkDisconnected(driver_id_, client_);
}
void TestHashTable(const DriverID &driver_id,
std::shared_ptr<gcs::AsyncGcsClient> client) {
const int expected_count = 14;
ClientID client_id = ClientID::FromRandom();
// Prepare the first resource map: data_map1.
auto cpu_data = std::make_shared<RayResourceT>();
cpu_data->resource_name = "CPU";
cpu_data->resource_capacity = 100;
auto gpu_data = std::make_shared<RayResourceT>();
gpu_data->resource_name = "GPU";
gpu_data->resource_capacity = 2;
DynamicResourceTable::DataMap data_map1;
data_map1.emplace("CPU", cpu_data);
data_map1.emplace("GPU", gpu_data);
// Prepare the second resource map: data_map2 which decreases CPU,
// increases GPU and add a new CUSTOM compared to data_map1.
auto data_cpu = std::make_shared<RayResourceT>();
data_cpu->resource_name = "CPU";
data_cpu->resource_capacity = 50;
auto data_gpu = std::make_shared<RayResourceT>();
data_gpu->resource_name = "GPU";
data_gpu->resource_capacity = 10;
auto data_custom = std::make_shared<RayResourceT>();
data_custom->resource_name = "CUSTOM";
data_custom->resource_capacity = 2;
DynamicResourceTable::DataMap data_map2;
data_map2.emplace("CPU", data_cpu);
data_map2.emplace("GPU", data_gpu);
data_map2.emplace("CUSTOM", data_custom);
data_map2["CPU"]->resource_capacity = 50;
// This is a common comparison function for the test.
auto compare_test = [](const DynamicResourceTable::DataMap &data1,
const DynamicResourceTable::DataMap &data2) {
ASSERT_EQ(data1.size(), data2.size());
for (const auto &data : data1) {
auto iter = data2.find(data.first);
ASSERT_TRUE(iter != data2.end());
ASSERT_EQ(iter->second->resource_name, data.second->resource_name);
ASSERT_EQ(iter->second->resource_capacity, data.second->resource_capacity);
}
};
auto subscribe_callback = [](AsyncGcsClient *client) {
ASSERT_TRUE(true);
test->IncrementNumCallbacks();
};
auto notification_callback = [data_map1, data_map2, compare_test](
AsyncGcsClient *client, const ClientID &id, const GcsChangeMode change_mode,
const DynamicResourceTable::DataMap &data) {
if (change_mode == GcsChangeMode::REMOVE) {
ASSERT_EQ(data.size(), 2);
ASSERT_TRUE(data.find("GPU") != data.end());
ASSERT_TRUE(data.find("CUSTOM") != data.end() || data.find("CPU") != data.end());
// The key "None-Existent" will not appear in the notification.
} else {
if (data.size() == 2) {
compare_test(data_map1, data);
} else if (data.size() == 3) {
compare_test(data_map2, data);
} else {
ASSERT_TRUE(false);
}
}
test->IncrementNumCallbacks();
// It is not sure which of the notification or lookup callback will come first.
if (test->NumCallbacks() == expected_count) {
test->Stop();
}
};
// Step 0: Subscribe the change of the hash table.
RAY_CHECK_OK(client->resource_table().Subscribe(
driver_id, ClientID::Nil(), notification_callback, subscribe_callback));
RAY_CHECK_OK(client->resource_table().RequestNotifications(
driver_id, client_id, client->client_table().GetLocalClientId()));
// Step 1: Add elements to the hash table.
auto update_callback1 = [data_map1, compare_test](
AsyncGcsClient *client, const ClientID &id,
const DynamicResourceTable::DataMap &callback_data) {
compare_test(data_map1, callback_data);
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(
client->resource_table().Update(driver_id, client_id, data_map1, update_callback1));
auto lookup_callback1 = [data_map1, compare_test](
AsyncGcsClient *client, const ClientID &id,
const DynamicResourceTable::DataMap &callback_data) {
compare_test(data_map1, callback_data);
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback1));
// Step 2: Decrease one element, increase one and add a new one.
RAY_CHECK_OK(client->resource_table().Update(driver_id, client_id, data_map2, nullptr));
auto lookup_callback2 = [data_map2, compare_test](
AsyncGcsClient *client, const ClientID &id,
const DynamicResourceTable::DataMap &callback_data) {
compare_test(data_map2, callback_data);
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback2));
std::vector<std::string> delete_keys({"GPU", "CUSTOM", "None-Existent"});
auto remove_callback = [delete_keys](AsyncGcsClient *client, const ClientID &id,
const std::vector<std::string> &callback_data) {
for (int i = 0; i < callback_data.size(); ++i) {
// All deleting keys exist in this argument even if the key doesn't exist.
ASSERT_EQ(callback_data[i], delete_keys[i]);
}
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(client->resource_table().RemoveEntries(driver_id, client_id, delete_keys,
remove_callback));
DynamicResourceTable::DataMap data_map3(data_map2);
data_map3.erase("GPU");
data_map3.erase("CUSTOM");
auto lookup_callback3 = [data_map3, compare_test](
AsyncGcsClient *client, const ClientID &id,
const DynamicResourceTable::DataMap &callback_data) {
compare_test(data_map3, callback_data);
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback3));
// Step 3: Reset the the resources to data_map1.
RAY_CHECK_OK(
client->resource_table().Update(driver_id, client_id, data_map1, update_callback1));
auto lookup_callback4 = [data_map1, compare_test](
AsyncGcsClient *client, const ClientID &id,
const DynamicResourceTable::DataMap &callback_data) {
compare_test(data_map1, callback_data);
test->IncrementNumCallbacks();
};
RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback4));
// Step 4: Removing all elements will remove the home Hash table from GCS.
RAY_CHECK_OK(client->resource_table().RemoveEntries(
driver_id, client_id, {"GPU", "CPU", "CUSTOM", "None-Existent"}, nullptr));
auto lookup_callback5 = [](AsyncGcsClient *client, const ClientID &id,
const DynamicResourceTable::DataMap &callback_data) {
ASSERT_EQ(callback_data.size(), 0);
test->IncrementNumCallbacks();
// It is not sure which of notification or lookup callback will come first.
if (test->NumCallbacks() == expected_count) {
test->Stop();
}
};
RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback5));
test->Start();
ASSERT_EQ(test->NumCallbacks(), expected_count);
}
TEST_F(TestGcsWithAsio, TestHashTable) {
test = this;
TestHashTable(driver_id_, client_);
}
#undef TEST_MACRO
} // namespace gcs
+5 -3
View File
@@ -22,6 +22,7 @@ enum TablePrefix:int {
TASK_LEASE,
ACTOR_CHECKPOINT,
ACTOR_CHECKPOINT_ID,
NODE_RESOURCE,
}
// The channel that Add operations to the Table should be published on, if any.
@@ -37,6 +38,7 @@ enum TablePubsub:int {
ERROR_INFO,
TASK_LEASE,
DRIVER,
NODE_RESOURCE,
}
// Enum for the entry type in the ClientTable
@@ -113,13 +115,13 @@ table ResourcePair {
value: double;
}
enum GcsTableNotificationMode:int {
enum GcsChangeMode:int {
APPEND_OR_ADD = 0,
REMOVE,
}
table GcsTableEntry {
notification_mode: GcsTableNotificationMode;
table GcsEntry {
change_mode: GcsChangeMode;
id: string;
entries: [string];
}
+179 -38
View File
@@ -179,32 +179,20 @@ flatbuffers::Offset<flatbuffers::String> RedisStringToFlatbuf(
return fbb.CreateString(redis_string_str, redis_string_size);
}
/// Publish a notification for an entry update at a key. This publishes a
/// notification to all subscribers of the table, as well as every client that
/// has requested notifications for this key.
/// Helper method to publish formatted data to target channel.
///
/// \param pubsub_channel_str The pubsub channel name that notifications for
/// this key should be published to. When publishing to a specific client, the
/// channel name should be <pubsub_channel>:<client_id>.
/// \param id The ID of the key that the notification is about.
/// \param mode the update mode, such as append or remove.
/// \param data The appended/removed data.
/// \param data_buffer The data to publish, which is a GcsEntry buffer.
/// \return OK if there is no error during a publish.
int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str,
RedisModuleString *id, GcsTableNotificationMode notification_mode,
RedisModuleString *data) {
// Serialize the notification to send.
flatbuffers::FlatBufferBuilder fbb;
auto data_flatbuf = RedisStringToFlatbuf(fbb, data);
auto message =
CreateGcsTableEntry(fbb, notification_mode, RedisStringToFlatbuf(fbb, id),
fbb.CreateVector(&data_flatbuf, 1));
fbb.Finish(message);
int PublishDataHelper(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str,
RedisModuleString *id, RedisModuleString *data_buffer) {
// Write the data back to any subscribers that are listening to all table
// notifications.
RedisModuleCallReply *reply = RedisModule_Call(ctx, "PUBLISH", "sb", pubsub_channel_str,
fbb.GetBufferPointer(), fbb.GetSize());
RedisModuleCallReply *reply =
RedisModule_Call(ctx, "PUBLISH", "ss", pubsub_channel_str, data_buffer);
if (reply == NULL) {
return RedisModule_ReplyWithError(ctx, "error during PUBLISH");
}
@@ -221,8 +209,8 @@ int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st
// will be garbage collected by redis.
auto channel =
RedisModule_CreateString(ctx, client_channel.data(), client_channel.size());
RedisModuleCallReply *reply = RedisModule_Call(
ctx, "PUBLISH", "sb", channel, fbb.GetBufferPointer(), fbb.GetSize());
RedisModuleCallReply *reply =
RedisModule_Call(ctx, "PUBLISH", "ss", channel, data_buffer);
if (reply == NULL) {
return RedisModule_ReplyWithError(ctx, "error during PUBLISH");
}
@@ -231,6 +219,31 @@ int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st
return RedisModule_ReplyWithSimpleString(ctx, "OK");
}
/// Publish a notification for an entry update at a key. This publishes a
/// notification to all subscribers of the table, as well as every client that
/// has requested notifications for this key.
///
/// \param pubsub_channel_str The pubsub channel name that notifications for
/// this key should be published to. When publishing to a specific client, the
/// channel name should be <pubsub_channel>:<client_id>.
/// \param id The ID of the key that the notification is about.
/// \param mode the update mode, such as append or remove.
/// \param data The appended/removed data.
/// \return OK if there is no error during a publish.
int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str,
RedisModuleString *id, GcsChangeMode change_mode,
RedisModuleString *data) {
// Serialize the notification to send.
flatbuffers::FlatBufferBuilder fbb;
auto data_flatbuf = RedisStringToFlatbuf(fbb, data);
auto message = CreateGcsEntry(fbb, change_mode, RedisStringToFlatbuf(fbb, id),
fbb.CreateVector(&data_flatbuf, 1));
fbb.Finish(message);
auto data_buffer = RedisModule_CreateString(
ctx, reinterpret_cast<char *>(fbb.GetBufferPointer()), fbb.GetSize());
return PublishDataHelper(ctx, pubsub_channel_str, id, data_buffer);
}
// RAY.TABLE_ADD:
// TableAdd_RedisCommand: the actual command handler.
// (helper) TableAdd_DoWrite: performs the write to redis state.
@@ -266,8 +279,8 @@ int TableAdd_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
if (pubsub_channel != TablePubsub::NO_PUBLISH) {
// All other pubsub channels write the data back directly onto the channel.
return PublishTableUpdate(ctx, pubsub_channel_str, id,
GcsTableNotificationMode::APPEND_OR_ADD, data);
return PublishTableUpdate(ctx, pubsub_channel_str, id, GcsChangeMode::APPEND_OR_ADD,
data);
} else {
return RedisModule_ReplyWithSimpleString(ctx, "OK");
}
@@ -366,8 +379,8 @@ int TableAppend_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, int /*a
if (pubsub_channel != TablePubsub::NO_PUBLISH) {
// All other pubsub channels write the data back directly onto the
// channel.
return PublishTableUpdate(ctx, pubsub_channel_str, id,
GcsTableNotificationMode::APPEND_OR_ADD, data);
return PublishTableUpdate(ctx, pubsub_channel_str, id, GcsChangeMode::APPEND_OR_ADD,
data);
} else {
return RedisModule_ReplyWithSimpleString(ctx, "OK");
}
@@ -419,10 +432,9 @@ int Set_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, bool is_add) {
if (pubsub_channel != TablePubsub::NO_PUBLISH) {
// All other pubsub channels write the data back directly onto the
// channel.
return PublishTableUpdate(ctx, pubsub_channel_str, id,
is_add ? GcsTableNotificationMode::APPEND_OR_ADD
: GcsTableNotificationMode::REMOVE,
data);
return PublishTableUpdate(
ctx, pubsub_channel_str, id,
is_add ? GcsChangeMode::APPEND_OR_ADD : GcsChangeMode::REMOVE, data);
} else {
return RedisModule_ReplyWithSimpleString(ctx, "OK");
}
@@ -518,7 +530,125 @@ int SetRemove_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
return RedisModule_ReplyWithSimpleString(ctx, "OK");
}
/// A helper function to create and finish a GcsTableEntry, based on the
int Hash_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv) {
RedisModuleString *pubsub_channel_str = argv[2];
RedisModuleString *id = argv[3];
RedisModuleString *data = argv[4];
// Publish a message on the requested pubsub channel if necessary.
TablePubsub pubsub_channel;
REPLY_AND_RETURN_IF_NOT_OK(ParseTablePubsub(&pubsub_channel, pubsub_channel_str));
if (pubsub_channel != TablePubsub::NO_PUBLISH) {
// All other pubsub channels write the data back directly onto the
// channel.
return PublishDataHelper(ctx, pubsub_channel_str, id, data);
} else {
return RedisModule_ReplyWithSimpleString(ctx, "OK");
}
}
/// Do the hash table write operation. This is called from by HashUpdate_RedisCommand.
///
/// \param change_mode Output the mode of the operation: APPEND_OR_ADD or REMOVE.
/// \param deleted_data Output data if the deleted data is not the same as required.
int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc,
GcsChangeMode *change_mode, RedisModuleString **changed_data) {
if (argc != 5) {
return RedisModule_WrongArity(ctx);
}
RedisModuleString *prefix_str = argv[1];
RedisModuleString *id = argv[3];
RedisModuleString *update_data = argv[4];
RedisModuleKey *key;
REPLY_AND_RETURN_IF_NOT_OK(OpenPrefixedKey(
&key, ctx, prefix_str, id, REDISMODULE_READ | REDISMODULE_WRITE, nullptr));
int type = RedisModule_KeyType(key);
REPLY_AND_RETURN_IF_FALSE(
type == REDISMODULE_KEYTYPE_HASH || type == REDISMODULE_KEYTYPE_EMPTY,
"HashUpdate_DoWrite: entries must be a hash or an empty hash");
size_t update_data_len = 0;
const char *update_data_buf = RedisModule_StringPtrLen(update_data, &update_data_len);
auto data_vec = flatbuffers::GetRoot<GcsEntry>(update_data_buf);
*change_mode = data_vec->change_mode();
if (*change_mode == GcsChangeMode::APPEND_OR_ADD) {
// This code path means they are updating command.
size_t total_size = data_vec->entries()->size();
REPLY_AND_RETURN_IF_FALSE(total_size % 2 == 0, "Invalid Hash Update data vector.");
for (int i = 0; i < total_size; i += 2) {
// Reconstruct a key-value pair from a flattened list.
RedisModuleString *entry_key = RedisModule_CreateString(
ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size());
RedisModuleString *entry_value =
RedisModule_CreateString(ctx, data_vec->entries()->Get(i + 1)->data(),
data_vec->entries()->Get(i + 1)->size());
// Returning 0 if key exists(still updated), 1 if the key is created.
RAY_IGNORE_EXPR(
RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, entry_value, NULL));
}
*changed_data = update_data;
} else {
// This code path means the command wants to remove the entries.
size_t total_size = data_vec->entries()->size();
flatbuffers::FlatBufferBuilder fbb;
std::vector<flatbuffers::Offset<flatbuffers::String>> data;
for (int i = 0; i < total_size; i++) {
RedisModuleString *entry_key = RedisModule_CreateString(
ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size());
int deleted_num = RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key,
REDISMODULE_HASH_DELETE, NULL);
if (deleted_num != 0) {
// The corresponding key is removed.
data.push_back(fbb.CreateString(data_vec->entries()->Get(i)->data(),
data_vec->entries()->Get(i)->size()));
}
}
auto message =
CreateGcsEntry(fbb, data_vec->change_mode(),
fbb.CreateString(data_vec->id()->data(), data_vec->id()->size()),
fbb.CreateVector(data));
fbb.Finish(message);
*changed_data = RedisModule_CreateString(
ctx, reinterpret_cast<char *>(fbb.GetBufferPointer()), fbb.GetSize());
auto size = RedisModule_ValueLength(key);
if (size == 0) {
REPLY_AND_RETURN_IF_FALSE(RedisModule_DeleteKey(key) == REDISMODULE_OK,
"ERR Failed to delete empty hash.");
}
}
return REDISMODULE_OK;
}
/// Update entries for a hash table.
///
/// This is called from a client with the command:
//
/// RAY.HASH_UPDATE <table_prefix> <pubsub_channel> <id> <data>
///
/// \param table_prefix The prefix string for keys in this table.
/// \param pubsub_channel The pubsub channel name that notifications for this
/// key should be published to. When publishing to a specific client, the
/// channel name should be <pubsub_channel>:<client_id>.
/// \param id The ID of the key to remove from.
/// \param data The GcsEntry flatbugger data used to update this hash table.
/// 1). For deletion, this is a list of keys.
/// 2). For updating, this is a list of pairs with each key followed by the value.
/// \return OK if the remove succeeds, or an error message string if the remove
/// fails.
int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
GcsChangeMode mode;
RedisModuleString *changed_data = nullptr;
if (HashUpdate_DoWrite(ctx, argv, argc, &mode, &changed_data) != REDISMODULE_OK) {
return REDISMODULE_ERR;
}
// Replace the data with the changed data to do the publish.
std::vector<RedisModuleString *> new_argv(argv, argv + argc);
new_argv[4] = changed_data;
return Hash_DoPublish(ctx, new_argv.data());
}
/// A helper function to create and finish a GcsEntry, based on the
/// current value or values at the given key.
///
/// \param ctx The Redis module context.
@@ -528,7 +658,7 @@ int SetRemove_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
/// \param prefix_str The string prefix associated with the open Redis key.
/// When parsed, this is expected to be a TablePrefix.
/// \param entry_id The UniqueID associated with the open Redis key.
/// \param fbb A flatbuffer builder used to build the GcsTableEntry.
/// \param fbb A flatbuffer builder used to build the GcsEntry.
Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key,
RedisModuleString *prefix_str, RedisModuleString *entry_id,
flatbuffers::FlatBufferBuilder &fbb) {
@@ -539,12 +669,13 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key,
size_t data_len = 0;
char *data_buf = RedisModule_StringDMA(table_key, &data_len, REDISMODULE_READ);
auto data = fbb.CreateString(data_buf, data_len);
auto message = CreateGcsTableEntry(fbb, GcsTableNotificationMode::APPEND_OR_ADD,
RedisStringToFlatbuf(fbb, entry_id),
fbb.CreateVector(&data, 1));
auto message =
CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD,
RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(&data, 1));
fbb.Finish(message);
} break;
case REDISMODULE_KEYTYPE_LIST:
case REDISMODULE_KEYTYPE_HASH:
case REDISMODULE_KEYTYPE_SET: {
RedisModule_CloseKey(table_key);
// Close the key before executing the command. NOTE(swang): According to
@@ -561,10 +692,13 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key,
case REDISMODULE_KEYTYPE_SET:
reply = RedisModule_Call(ctx, "SMEMBERS", "s", table_key_str);
break;
case REDISMODULE_KEYTYPE_HASH:
reply = RedisModule_Call(ctx, "HGETALL", "s", table_key_str);
break;
}
// Build the flatbuffer from the set of log entries.
if (reply == nullptr || RedisModule_CallReplyType(reply) != REDISMODULE_REPLY_ARRAY) {
return Status::RedisError("Empty list or wrong type");
return Status::RedisError("Empty list/set/hash or wrong type");
}
std::vector<flatbuffers::Offset<flatbuffers::String>> data;
for (size_t i = 0; i < RedisModule_CallReplyLength(reply); i++) {
@@ -574,13 +708,13 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key,
data.push_back(fbb.CreateString(element_str, len));
}
auto message =
CreateGcsTableEntry(fbb, GcsTableNotificationMode::APPEND_OR_ADD,
RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(data));
CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD,
RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(data));
fbb.Finish(message);
} break;
case REDISMODULE_KEYTYPE_EMPTY: {
auto message = CreateGcsTableEntry(
fbb, GcsTableNotificationMode::APPEND_OR_ADD, RedisStringToFlatbuf(fbb, entry_id),
auto message = CreateGcsEntry(
fbb, GcsChangeMode::APPEND_OR_ADD, RedisStringToFlatbuf(fbb, entry_id),
fbb.CreateVector(std::vector<flatbuffers::Offset<flatbuffers::String>>()));
fbb.Finish(message);
} break;
@@ -637,6 +771,7 @@ static Status DeleteKeyHelper(RedisModuleCtx *ctx, RedisModuleString *prefix_str
return Status::RedisError("Key does not exist.");
}
auto key_type = RedisModule_KeyType(delete_key);
// Set/Hash will delete itself when the length is 0.
if (key_type == REDISMODULE_KEYTYPE_STRING || key_type == REDISMODULE_KEYTYPE_LIST) {
// Current Table or Log only has this two types of entries.
RAY_RETURN_NOT_OK(
@@ -873,6 +1008,7 @@ int DebugString_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int
// Wrap all Redis commands with Redis' auto memory management.
AUTO_MEMORY(TableAdd_RedisCommand);
AUTO_MEMORY(HashUpdate_RedisCommand);
AUTO_MEMORY(TableAppend_RedisCommand);
AUTO_MEMORY(SetAdd_RedisCommand);
AUTO_MEMORY(SetRemove_RedisCommand);
@@ -929,6 +1065,11 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
return REDISMODULE_ERR;
}
if (RedisModule_CreateCommand(ctx, "ray.hash_update", HashUpdate_RedisCommand,
"write pubsub", 0, 0, 0) == REDISMODULE_ERR) {
return REDISMODULE_ERR;
}
if (RedisModule_CreateCommand(ctx, "ray.table_request_notifications",
TableRequestNotifications_RedisCommand, "write pubsub", 0,
0, 0) == REDISMODULE_ERR) {
+157 -5
View File
@@ -92,7 +92,7 @@ Status Log<ID, Data>::Lookup(const DriverID &driver_id, const ID &id,
std::vector<DataT> results;
if (!reply.IsNil()) {
const auto data = reply.ReadAsString();
auto root = flatbuffers::GetRoot<GcsTableEntry>(data.data());
auto root = flatbuffers::GetRoot<GcsEntry>(data.data());
RAY_CHECK(from_flatbuf<ID>(*root->id()) == id);
for (size_t i = 0; i < root->entries()->size(); i++) {
DataT result;
@@ -114,9 +114,9 @@ Status Log<ID, Data>::Subscribe(const DriverID &driver_id, const ClientID &clien
const Callback &subscribe,
const SubscriptionCallback &done) {
auto subscribe_wrapper = [subscribe](AsyncGcsClient *client, const ID &id,
const GcsTableNotificationMode notification_mode,
const GcsChangeMode change_mode,
const std::vector<DataT> &data) {
RAY_CHECK(notification_mode != GcsTableNotificationMode::REMOVE);
RAY_CHECK(change_mode != GcsChangeMode::REMOVE);
subscribe(client, id, data);
};
return Subscribe(driver_id, client_id, subscribe_wrapper, done);
@@ -141,7 +141,7 @@ Status Log<ID, Data>::Subscribe(const DriverID &driver_id, const ClientID &clien
// Data is provided. This is the callback for a message.
if (subscribe != nullptr) {
// Parse the notification.
auto root = flatbuffers::GetRoot<GcsTableEntry>(data.data());
auto root = flatbuffers::GetRoot<GcsEntry>(data.data());
ID id;
if (root->id()->size() > 0) {
id = from_flatbuf<ID>(*root->id());
@@ -153,7 +153,7 @@ Status Log<ID, Data>::Subscribe(const DriverID &driver_id, const ClientID &clien
data_root->UnPackTo(&result);
results.emplace_back(std::move(result));
}
subscribe(client_, id, root->notification_mode(), results);
subscribe(client_, id, root->change_mode(), results);
}
}
};
@@ -339,6 +339,155 @@ std::string Set<ID, Data>::DebugString() const {
return result.str();
}
template <typename ID, typename Data>
Status Hash<ID, Data>::Update(const DriverID &driver_id, const ID &id,
const DataMap &data_map, const HashCallback &done) {
num_adds_++;
auto callback = [this, id, data_map, done](const CallbackReply &reply) {
if (done != nullptr) {
(done)(client_, id, data_map);
}
};
flatbuffers::FlatBufferBuilder fbb;
std::vector<flatbuffers::Offset<flatbuffers::String>> data_vec;
data_vec.reserve(data_map.size() * 2);
for (auto const &pair : data_map) {
// Add the key.
data_vec.push_back(fbb.CreateString(pair.first));
flatbuffers::FlatBufferBuilder fbb_data;
fbb_data.ForceDefaults(true);
fbb_data.Finish(Data::Pack(fbb_data, pair.second.get()));
std::string data(reinterpret_cast<char *>(fbb_data.GetBufferPointer()),
fbb_data.GetSize());
// Add the value.
data_vec.push_back(fbb.CreateString(data));
}
fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD,
fbb.CreateString(id.Binary()), fbb.CreateVector(data_vec)));
return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(),
fbb.GetSize(), prefix_, pubsub_channel_,
std::move(callback));
}
template <typename ID, typename Data>
Status Hash<ID, Data>::RemoveEntries(const DriverID &driver_id, const ID &id,
const std::vector<std::string> &keys,
const HashRemoveCallback &remove_callback) {
num_removes_++;
auto callback = [this, id, keys, remove_callback](const CallbackReply &reply) {
if (remove_callback != nullptr) {
(remove_callback)(client_, id, keys);
}
};
flatbuffers::FlatBufferBuilder fbb;
std::vector<flatbuffers::Offset<flatbuffers::String>> data_vec;
data_vec.reserve(keys.size());
// Add the keys.
for (auto const &key : keys) {
data_vec.push_back(fbb.CreateString(key));
}
fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::REMOVE, fbb.CreateString(id.Binary()),
fbb.CreateVector(data_vec)));
return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(),
fbb.GetSize(), prefix_, pubsub_channel_,
std::move(callback));
}
template <typename ID, typename Data>
std::string Hash<ID, Data>::DebugString() const {
std::stringstream result;
result << "num lookups: " << num_lookups_ << ", num adds: " << num_adds_
<< ", num removes: " << num_removes_;
return result.str();
}
template <typename ID, typename Data>
Status Hash<ID, Data>::Lookup(const DriverID &driver_id, const ID &id,
const HashCallback &lookup) {
num_lookups_++;
auto callback = [this, id, lookup](const CallbackReply &reply) {
if (lookup != nullptr) {
DataMap results;
if (!reply.IsNil()) {
const auto data = reply.ReadAsString();
auto root = flatbuffers::GetRoot<GcsEntry>(data.data());
RAY_CHECK(from_flatbuf<ID>(*root->id()) == id);
RAY_CHECK(root->entries()->size() % 2 == 0);
for (size_t i = 0; i < root->entries()->size(); i += 2) {
std::string key(root->entries()->Get(i)->data(),
root->entries()->Get(i)->size());
auto result = std::make_shared<DataT>();
auto data_root =
flatbuffers::GetRoot<Data>(root->entries()->Get(i + 1)->data());
data_root->UnPackTo(result.get());
results.emplace(key, std::move(result));
}
}
lookup(client_, id, results);
}
};
std::vector<uint8_t> nil;
return GetRedisContext(id)->RunAsync("RAY.TABLE_LOOKUP", id, nil.data(), nil.size(),
prefix_, pubsub_channel_, std::move(callback));
}
template <typename ID, typename Data>
Status Hash<ID, Data>::Subscribe(const DriverID &driver_id, const ClientID &client_id,
const HashNotificationCallback &subscribe,
const SubscriptionCallback &done) {
RAY_CHECK(subscribe_callback_index_ == -1)
<< "Client called Subscribe twice on the same table";
auto callback = [this, subscribe, done](const CallbackReply &reply) {
const auto data = reply.ReadAsPubsubData();
if (data.empty()) {
// No notification data is provided. This is the callback for the
// initial subscription request.
if (done != nullptr) {
done(client_);
}
} else {
// Data is provided. This is the callback for a message.
if (subscribe != nullptr) {
// Parse the notification.
auto root = flatbuffers::GetRoot<GcsEntry>(data.data());
DataMap data_map;
ID id;
if (root->id()->size() > 0) {
id = from_flatbuf<ID>(*root->id());
}
if (root->change_mode() == GcsChangeMode::REMOVE) {
for (size_t i = 0; i < root->entries()->size(); i++) {
std::string key(root->entries()->Get(i)->data(),
root->entries()->Get(i)->size());
data_map.emplace(key, std::shared_ptr<DataT>());
}
} else {
RAY_CHECK(root->entries()->size() % 2 == 0);
for (size_t i = 0; i < root->entries()->size(); i += 2) {
std::string key(root->entries()->Get(i)->data(),
root->entries()->Get(i)->size());
auto result = std::make_shared<DataT>();
auto data_root =
flatbuffers::GetRoot<Data>(root->entries()->Get(i + 1)->data());
data_root->UnPackTo(result.get());
data_map.emplace(key, std::move(result));
}
}
subscribe(client_, id, root->change_mode(), data_map);
}
}
};
subscribe_callback_index_ = 1;
for (auto &context : shard_contexts_) {
RAY_RETURN_NOT_OK(context->SubscribeAsync(client_id, pubsub_channel_, callback,
&subscribe_callback_index_));
}
return Status::OK();
}
Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::string &type,
const std::string &error_message, double timestamp) {
auto data = std::make_shared<ErrorTableDataT>();
@@ -696,6 +845,9 @@ template class Log<UniqueID, ProfileTableData>;
template class Table<ActorCheckpointID, ActorCheckpointData>;
template class Table<ActorID, ActorCheckpointIdData>;
template class Log<ClientID, RayResource>;
template class Hash<ClientID, RayResource>;
} // namespace gcs
} // namespace ray
+155 -4
View File
@@ -75,9 +75,9 @@ class Log : public LogInterface<ID, Data>, virtual public PubsubInterface<ID> {
using DataT = typename Data::NativeTableType;
using Callback = std::function<void(AsyncGcsClient *client, const ID &id,
const std::vector<DataT> &data)>;
using NotificationCallback = std::function<void(
AsyncGcsClient *client, const ID &id,
const GcsTableNotificationMode notification_mode, const std::vector<DataT> &data)>;
using NotificationCallback = std::function<void(AsyncGcsClient *client, const ID &id,
const GcsChangeMode change_mode,
const std::vector<DataT> &data)>;
/// The callback to call when a write to a key succeeds.
using WriteCallback = typename LogInterface<ID, Data>::WriteCallback;
/// The callback to call when a SUBSCRIBE call completes and we are ready to
@@ -214,7 +214,7 @@ class Log : public LogInterface<ID, Data>, virtual public PubsubInterface<ID> {
/// to subscribe to all modifications, or to subscribe only to keys that it
/// requests notifications for. This may only be called once per Log
/// instance. This function is different from public version due to
/// an additional parameter notification_mode in NotificationCallback. Therefore this
/// an additional parameter change_mode in NotificationCallback. Therefore this
/// function supports notifications of remove operations.
///
/// \param driver_id The ID of the job (= driver).
@@ -451,6 +451,157 @@ class Set : private Log<ID, Data>,
using Log<ID, Data>::num_lookups_;
};
template <typename ID, typename Data>
class HashInterface {
public:
using DataT = typename Data::NativeTableType;
using DataMap = std::unordered_map<std::string, std::shared_ptr<DataT>>;
// Reuse Log's SubscriptionCallback when Subscribe is successfully called.
using SubscriptionCallback = typename Log<ID, Data>::SubscriptionCallback;
/// The callback function used by function Update & Lookup.
///
/// \param client The client on which the RemoveEntries is called.
/// \param id The ID of the Hash Table whose entries are removed.
/// \param data Map data contains the change to the Hash Table.
/// \return Void
using HashCallback =
std::function<void(AsyncGcsClient *client, const ID &id, const DataMap &pairs)>;
/// The callback function used by function RemoveEntries.
///
/// \param client The client on which the RemoveEntries is called.
/// \param id The ID of the Hash Table whose entries are removed.
/// \param keys The keys that are moved from this Hash Table.
/// \return Void
using HashRemoveCallback = std::function<void(AsyncGcsClient *client, const ID &id,
const std::vector<std::string> &keys)>;
/// The notification function used by function Subscribe.
///
/// \param client The client on which the Subscribe is called.
/// \param change_mode The mode to identify the data is removed or updated.
/// \param data Map data contains the change to the Hash Table.
/// \return Void
using HashNotificationCallback =
std::function<void(AsyncGcsClient *client, const ID &id,
const GcsChangeMode change_mode, const DataMap &data)>;
/// Add entries of a hash table.
///
/// \param driver_id The ID of the job (= driver).
/// \param id The ID of the data that is added to the GCS.
/// \param pairs Map data to add to the hash table.
/// \param done HashCallback that is called once the request data has been written to
/// the GCS.
/// \return Status
virtual Status Update(const DriverID &driver_id, const ID &id, const DataMap &pairs,
const HashCallback &done) = 0;
/// Remove entries from the hash table.
///
/// \param driver_id The ID of the job (= driver).
/// \param id The ID of the data that is removed from the GCS.
/// \param keys The entry keys of the hash table.
/// \param remove_callback HashRemoveCallback that is called once the data has been
/// written to the GCS no matter whether the key exists in the hash table.
/// \return Status
virtual Status RemoveEntries(const DriverID &driver_id, const ID &id,
const std::vector<std::string> &keys,
const HashRemoveCallback &remove_callback) = 0;
/// Lookup the map data of a hash table.
///
/// \param driver_id The ID of the job (= driver).
/// \param id The ID of the data that is looked up in the GCS.
/// \param lookup HashCallback that is called after lookup. If the callback is
/// called with an empty hash table, then there was no data in the callback.
/// \return Status
virtual Status Lookup(const DriverID &driver_id, const ID &id,
const HashCallback &lookup) = 0;
/// Subscribe to any Update or Remove operations to this hash table.
///
/// \param driver_id The ID of the driver.
/// \param client_id The type of update to listen to. If this is nil, then a
/// message for each Update to the table will be received. Else, only
/// messages for the given client will be received. In the latter
/// case, the client may request notifications on specific keys in the
/// table via `RequestNotifications`.
/// \param subscribe HashNotificationCallback that is called on each received message.
/// \param done SubscriptionCallback that is called when subscription is complete and
/// we are ready to receive messages.
/// \return Status
virtual Status Subscribe(const DriverID &driver_id, const ClientID &client_id,
const HashNotificationCallback &subscribe,
const SubscriptionCallback &done) = 0;
virtual ~HashInterface(){};
};
template <typename ID, typename Data>
class Hash : private Log<ID, Data>,
public HashInterface<ID, Data>,
virtual public PubsubInterface<ID> {
public:
using DataT = typename Log<ID, Data>::DataT;
using DataMap = std::unordered_map<std::string, std::shared_ptr<DataT>>;
using HashCallback = typename HashInterface<ID, Data>::HashCallback;
using HashRemoveCallback = typename HashInterface<ID, Data>::HashRemoveCallback;
using HashNotificationCallback =
typename HashInterface<ID, Data>::HashNotificationCallback;
using SubscriptionCallback = typename Log<ID, Data>::SubscriptionCallback;
Hash(const std::vector<std::shared_ptr<RedisContext>> &contexts, AsyncGcsClient *client)
: Log<ID, Data>(contexts, client) {}
using Log<ID, Data>::RequestNotifications;
using Log<ID, Data>::CancelNotifications;
Status Update(const DriverID &driver_id, const ID &id, const DataMap &pairs,
const HashCallback &done) override;
Status Subscribe(const DriverID &driver_id, const ClientID &client_id,
const HashNotificationCallback &subscribe,
const SubscriptionCallback &done) override;
Status Lookup(const DriverID &driver_id, const ID &id,
const HashCallback &lookup) override;
Status RemoveEntries(const DriverID &driver_id, const ID &id,
const std::vector<std::string> &keys,
const HashRemoveCallback &remove_callback) override;
/// Returns debug string for class.
///
/// \return string.
std::string DebugString() const;
protected:
using Log<ID, Data>::shard_contexts_;
using Log<ID, Data>::client_;
using Log<ID, Data>::pubsub_channel_;
using Log<ID, Data>::prefix_;
using Log<ID, Data>::subscribe_callback_index_;
using Log<ID, Data>::GetRedisContext;
int64_t num_adds_ = 0;
int64_t num_removes_ = 0;
using Log<ID, Data>::num_lookups_;
};
class DynamicResourceTable : public Hash<ClientID, RayResource> {
public:
DynamicResourceTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
AsyncGcsClient *client)
: Hash(contexts, client) {
pubsub_channel_ = TablePubsub::NODE_RESOURCE;
prefix_ = TablePrefix::NODE_RESOURCE;
};
virtual ~DynamicResourceTable(){};
};
class ObjectTable : public Set<ObjectID, ObjectTableData> {
public:
ObjectTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
+7 -9
View File
@@ -11,16 +11,16 @@ namespace {
/// Process a notification of the object table entries and store the result in
/// client_ids. This assumes that client_ids already contains the result of the
/// object table entries up to but not including this notification.
void UpdateObjectLocations(const GcsTableNotificationMode notification_mode,
void UpdateObjectLocations(const GcsChangeMode change_mode,
const std::vector<ObjectTableDataT> &location_updates,
const ray::gcs::ClientTable &client_table,
std::unordered_set<ClientID> *client_ids) {
// location_updates contains the updates of locations of the object.
// with GcsTableNotificationMode, we can determine whether the update mode is
// with GcsChangeMode, we can determine whether the update mode is
// addition or deletion.
for (const auto &object_table_data : location_updates) {
ClientID client_id = ClientID::FromBinary(object_table_data.manager);
if (notification_mode != GcsTableNotificationMode::REMOVE) {
if (change_mode != GcsChangeMode::REMOVE) {
client_ids->insert(client_id);
} else {
client_ids->erase(client_id);
@@ -41,7 +41,7 @@ void UpdateObjectLocations(const GcsTableNotificationMode notification_mode,
void ObjectDirectory::RegisterBackend() {
auto object_notification_callback = [this](
gcs::AsyncGcsClient *client, const ObjectID &object_id,
const GcsTableNotificationMode notification_mode,
const GcsChangeMode change_mode,
const std::vector<ObjectTableDataT> &location_updates) {
// Objects are added to this map in SubscribeObjectLocations.
auto it = listeners_.find(object_id);
@@ -54,8 +54,7 @@ void ObjectDirectory::RegisterBackend() {
it->second.subscribed = true;
// Update entries for this object.
UpdateObjectLocations(notification_mode, location_updates,
gcs_client_->client_table(),
UpdateObjectLocations(change_mode, location_updates, gcs_client_->client_table(),
&it->second.current_object_locations);
// Copy the callbacks so that the callbacks can unsubscribe without interrupting
// looping over the callbacks.
@@ -135,8 +134,7 @@ void ObjectDirectory::HandleClientRemoved(const ClientID &client_id) {
if (listener.second.current_object_locations.count(client_id) > 0) {
// If the subscribed object has the removed client as a location, update
// its locations with an empty update so that the location will be removed.
UpdateObjectLocations(GcsTableNotificationMode::APPEND_OR_ADD, {},
gcs_client_->client_table(),
UpdateObjectLocations(GcsChangeMode::APPEND_OR_ADD, {}, gcs_client_->client_table(),
&listener.second.current_object_locations);
// Re-call all the subscribed callbacks for the object, since its
// locations have changed.
@@ -213,7 +211,7 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id,
const std::vector<ObjectTableDataT> &location_updates) {
// Build the set of current locations based on the entries in the log.
std::unordered_set<ClientID> client_ids;
UpdateObjectLocations(GcsTableNotificationMode::APPEND_OR_ADD, location_updates,
UpdateObjectLocations(GcsChangeMode::APPEND_OR_ADD, location_updates,
gcs_client_->client_table(), &client_ids);
// It is safe to call the callback directly since this is already running
// in the GCS client's lookup callback stack.