mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:00:36 +08:00
[C++] Add hash table to Redis-Module (#4911)
This commit is contained in:
+1
-1
@@ -535,7 +535,7 @@ flatbuffer_py_library(
|
||||
"ErrorTableData.py",
|
||||
"ErrorType.py",
|
||||
"FunctionTableData.py",
|
||||
"GcsTableEntry.py",
|
||||
"GcsEntry.py",
|
||||
"HeartbeatBatchTableData.py",
|
||||
"HeartbeatTableData.py",
|
||||
"Language.py",
|
||||
|
||||
+1
-1
@@ -29,7 +29,7 @@ MOCK_MODULES = [
|
||||
"ray.core.generated.EntryType",
|
||||
"ray.core.generated.ErrorTableData",
|
||||
"ray.core.generated.ErrorType",
|
||||
"ray.core.generated.GcsTableEntry",
|
||||
"ray.core.generated.GcsEntry",
|
||||
"ray.core.generated.HeartbeatBatchTableData",
|
||||
"ray.core.generated.HeartbeatTableData",
|
||||
"ray.core.generated.Language",
|
||||
|
||||
+1
-1
@@ -160,7 +160,7 @@ flatbuffers_generated_files = [
|
||||
"ErrorTableData.java",
|
||||
"ErrorType.java",
|
||||
"FunctionTableData.java",
|
||||
"GcsTableEntry.java",
|
||||
"GcsEntry.java",
|
||||
"HeartbeatBatchTableData.java",
|
||||
"HeartbeatTableData.java",
|
||||
"Language.java",
|
||||
|
||||
@@ -9,7 +9,7 @@ from ray.core.generated.ActorCheckpointIdData import ActorCheckpointIdData
|
||||
from ray.core.generated.ClientTableData import ClientTableData
|
||||
from ray.core.generated.DriverTableData import DriverTableData
|
||||
from ray.core.generated.ErrorTableData import ErrorTableData
|
||||
from ray.core.generated.GcsTableEntry import GcsTableEntry
|
||||
from ray.core.generated.GcsEntry import GcsEntry
|
||||
from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData
|
||||
from ray.core.generated.HeartbeatTableData import HeartbeatTableData
|
||||
from ray.core.generated.Language import Language
|
||||
@@ -25,7 +25,7 @@ __all__ = [
|
||||
"ClientTableData",
|
||||
"DriverTableData",
|
||||
"ErrorTableData",
|
||||
"GcsTableEntry",
|
||||
"GcsEntry",
|
||||
"HeartbeatBatchTableData",
|
||||
"HeartbeatTableData",
|
||||
"Language",
|
||||
|
||||
@@ -101,8 +101,7 @@ class Monitor(object):
|
||||
def xray_heartbeat_batch_handler(self, unused_channel, data):
|
||||
"""Handle an xray heartbeat batch message from Redis."""
|
||||
|
||||
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
|
||||
data, 0)
|
||||
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0)
|
||||
heartbeat_data = gcs_entries.Entries(0)
|
||||
|
||||
message = (ray.gcs_utils.HeartbeatBatchTableData.
|
||||
@@ -208,8 +207,7 @@ class Monitor(object):
|
||||
unused_channel: The message channel.
|
||||
data: The message data.
|
||||
"""
|
||||
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
|
||||
data, 0)
|
||||
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0)
|
||||
driver_data = gcs_entries.Entries(0)
|
||||
message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData(
|
||||
driver_data, 0)
|
||||
|
||||
+8
-14
@@ -41,7 +41,7 @@ def _parse_client_table(redis_client):
|
||||
return []
|
||||
|
||||
node_info = {}
|
||||
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(message, 0)
|
||||
gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
|
||||
|
||||
ordered_client_ids = []
|
||||
|
||||
@@ -248,8 +248,7 @@ class GlobalState(object):
|
||||
object_id.binary())
|
||||
if message is None:
|
||||
return {}
|
||||
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
|
||||
message, 0)
|
||||
gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
|
||||
|
||||
assert gcs_entry.EntriesLength() > 0
|
||||
|
||||
@@ -307,8 +306,7 @@ class GlobalState(object):
|
||||
"", task_id.binary())
|
||||
if message is None:
|
||||
return {}
|
||||
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
|
||||
message, 0)
|
||||
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
|
||||
|
||||
assert gcs_entries.EntriesLength() == 1
|
||||
|
||||
@@ -431,8 +429,7 @@ class GlobalState(object):
|
||||
if message is None:
|
||||
return []
|
||||
|
||||
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
|
||||
message, 0)
|
||||
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
|
||||
|
||||
profile_events = []
|
||||
for i in range(gcs_entries.EntriesLength()):
|
||||
@@ -815,9 +812,8 @@ class GlobalState(object):
|
||||
ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL):
|
||||
continue
|
||||
data = raw_message["data"]
|
||||
gcs_entries = (
|
||||
ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
|
||||
data, 0))
|
||||
gcs_entries = (ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(
|
||||
data, 0))
|
||||
heartbeat_data = gcs_entries.Entries(0)
|
||||
message = (ray.gcs_utils.HeartbeatTableData.
|
||||
GetRootAsHeartbeatTableData(heartbeat_data, 0))
|
||||
@@ -871,8 +867,7 @@ class GlobalState(object):
|
||||
if message is None:
|
||||
return []
|
||||
|
||||
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
|
||||
message, 0)
|
||||
gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
|
||||
error_messages = []
|
||||
for i in range(gcs_entries.EntriesLength()):
|
||||
error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData(
|
||||
@@ -934,8 +929,7 @@ class GlobalState(object):
|
||||
)
|
||||
if message is None:
|
||||
return None
|
||||
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
|
||||
message, 0)
|
||||
gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0)
|
||||
entry = (
|
||||
ray.gcs_utils.ActorCheckpointIdData.GetRootAsActorCheckpointIdData(
|
||||
gcs_entry.Entries(0), 0))
|
||||
|
||||
@@ -1656,7 +1656,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
|
||||
if msg is None:
|
||||
threads_stopped.wait(timeout=0.01)
|
||||
continue
|
||||
gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
|
||||
gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(
|
||||
msg["data"], 0)
|
||||
assert gcs_entry.EntriesLength() == 1
|
||||
error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData(
|
||||
|
||||
@@ -120,6 +120,7 @@ AsyncGcsClient::AsyncGcsClient(const std::string &address, int port,
|
||||
profile_table_.reset(new ProfileTable(shard_contexts_, this));
|
||||
actor_checkpoint_table_.reset(new ActorCheckpointTable(shard_contexts_, this));
|
||||
actor_checkpoint_id_table_.reset(new ActorCheckpointIdTable(shard_contexts_, this));
|
||||
resource_table_.reset(new DynamicResourceTable({primary_context_}, this));
|
||||
command_type_ = command_type;
|
||||
|
||||
// TODO(swang): Call the client table's Connect() method here. To do this,
|
||||
@@ -229,6 +230,8 @@ ActorCheckpointIdTable &AsyncGcsClient::actor_checkpoint_id_table() {
|
||||
return *actor_checkpoint_id_table_;
|
||||
}
|
||||
|
||||
DynamicResourceTable &AsyncGcsClient::resource_table() { return *resource_table_; }
|
||||
|
||||
} // namespace gcs
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -62,6 +62,7 @@ class RAY_EXPORT AsyncGcsClient {
|
||||
ProfileTable &profile_table();
|
||||
ActorCheckpointTable &actor_checkpoint_table();
|
||||
ActorCheckpointIdTable &actor_checkpoint_id_table();
|
||||
DynamicResourceTable &resource_table();
|
||||
|
||||
// We also need something to export generic code to run on workers from the
|
||||
// driver (to set the PYTHONPATH)
|
||||
@@ -94,6 +95,7 @@ class RAY_EXPORT AsyncGcsClient {
|
||||
std::unique_ptr<ClientTable> client_table_;
|
||||
std::unique_ptr<ActorCheckpointTable> actor_checkpoint_table_;
|
||||
std::unique_ptr<ActorCheckpointIdTable> actor_checkpoint_id_table_;
|
||||
std::unique_ptr<DynamicResourceTable> resource_table_;
|
||||
// The following contexts write to the data shard
|
||||
std::vector<std::shared_ptr<RedisContext>> shard_contexts_;
|
||||
std::vector<std::unique_ptr<RedisAsioClient>> shard_asio_async_clients_;
|
||||
|
||||
+162
-10
@@ -657,13 +657,12 @@ void TestSetSubscribeAll(const DriverID &driver_id,
|
||||
|
||||
// Callback for a notification.
|
||||
auto notification_callback = [object_ids, managers](
|
||||
gcs::AsyncGcsClient *client, const ObjectID &id,
|
||||
const GcsTableNotificationMode notification_mode,
|
||||
gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode,
|
||||
const std::vector<ObjectTableDataT> data) {
|
||||
if (test->NumCallbacks() < 3 * 3) {
|
||||
ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD);
|
||||
ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD);
|
||||
} else {
|
||||
ASSERT_EQ(notification_mode, GcsTableNotificationMode::REMOVE);
|
||||
ASSERT_EQ(change_mode, GcsChangeMode::REMOVE);
|
||||
}
|
||||
ASSERT_EQ(id, object_ids[test->NumCallbacks() / 3 % 3]);
|
||||
// Check that we get notifications in the same order as the writes.
|
||||
@@ -894,10 +893,9 @@ void TestSetSubscribeId(const DriverID &driver_id,
|
||||
// The callback for a notification from the table. This should only be
|
||||
// received for keys that we requested notifications for.
|
||||
auto notification_callback = [object_id2, managers2](
|
||||
gcs::AsyncGcsClient *client, const ObjectID &id,
|
||||
const GcsTableNotificationMode notification_mode,
|
||||
gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode,
|
||||
const std::vector<ObjectTableDataT> &data) {
|
||||
ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD);
|
||||
ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD);
|
||||
// Check that we only get notifications for the requested key.
|
||||
ASSERT_EQ(id, object_id2);
|
||||
// Check that we get notifications in the same order as the writes.
|
||||
@@ -1111,10 +1109,9 @@ void TestSetSubscribeCancel(const DriverID &driver_id,
|
||||
// The callback for a notification from the object table. This should only be
|
||||
// received for the object that we requested notifications for.
|
||||
auto notification_callback = [object_id, managers](
|
||||
gcs::AsyncGcsClient *client, const ObjectID &id,
|
||||
const GcsTableNotificationMode notification_mode,
|
||||
gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode,
|
||||
const std::vector<ObjectTableDataT> &data) {
|
||||
ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD);
|
||||
ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD);
|
||||
ASSERT_EQ(id, object_id);
|
||||
// Check that we get a duplicate notification for the first write. We get a
|
||||
// duplicate notification because notifications
|
||||
@@ -1307,6 +1304,161 @@ TEST_F(TestGcsWithAsio, TestClientTableMarkDisconnected) {
|
||||
TestClientTableMarkDisconnected(driver_id_, client_);
|
||||
}
|
||||
|
||||
void TestHashTable(const DriverID &driver_id,
|
||||
std::shared_ptr<gcs::AsyncGcsClient> client) {
|
||||
const int expected_count = 14;
|
||||
ClientID client_id = ClientID::FromRandom();
|
||||
// Prepare the first resource map: data_map1.
|
||||
auto cpu_data = std::make_shared<RayResourceT>();
|
||||
cpu_data->resource_name = "CPU";
|
||||
cpu_data->resource_capacity = 100;
|
||||
auto gpu_data = std::make_shared<RayResourceT>();
|
||||
gpu_data->resource_name = "GPU";
|
||||
gpu_data->resource_capacity = 2;
|
||||
DynamicResourceTable::DataMap data_map1;
|
||||
data_map1.emplace("CPU", cpu_data);
|
||||
data_map1.emplace("GPU", gpu_data);
|
||||
// Prepare the second resource map: data_map2 which decreases CPU,
|
||||
// increases GPU and add a new CUSTOM compared to data_map1.
|
||||
auto data_cpu = std::make_shared<RayResourceT>();
|
||||
data_cpu->resource_name = "CPU";
|
||||
data_cpu->resource_capacity = 50;
|
||||
auto data_gpu = std::make_shared<RayResourceT>();
|
||||
data_gpu->resource_name = "GPU";
|
||||
data_gpu->resource_capacity = 10;
|
||||
auto data_custom = std::make_shared<RayResourceT>();
|
||||
data_custom->resource_name = "CUSTOM";
|
||||
data_custom->resource_capacity = 2;
|
||||
DynamicResourceTable::DataMap data_map2;
|
||||
data_map2.emplace("CPU", data_cpu);
|
||||
data_map2.emplace("GPU", data_gpu);
|
||||
data_map2.emplace("CUSTOM", data_custom);
|
||||
data_map2["CPU"]->resource_capacity = 50;
|
||||
// This is a common comparison function for the test.
|
||||
auto compare_test = [](const DynamicResourceTable::DataMap &data1,
|
||||
const DynamicResourceTable::DataMap &data2) {
|
||||
ASSERT_EQ(data1.size(), data2.size());
|
||||
for (const auto &data : data1) {
|
||||
auto iter = data2.find(data.first);
|
||||
ASSERT_TRUE(iter != data2.end());
|
||||
ASSERT_EQ(iter->second->resource_name, data.second->resource_name);
|
||||
ASSERT_EQ(iter->second->resource_capacity, data.second->resource_capacity);
|
||||
}
|
||||
};
|
||||
auto subscribe_callback = [](AsyncGcsClient *client) {
|
||||
ASSERT_TRUE(true);
|
||||
test->IncrementNumCallbacks();
|
||||
};
|
||||
auto notification_callback = [data_map1, data_map2, compare_test](
|
||||
AsyncGcsClient *client, const ClientID &id, const GcsChangeMode change_mode,
|
||||
const DynamicResourceTable::DataMap &data) {
|
||||
if (change_mode == GcsChangeMode::REMOVE) {
|
||||
ASSERT_EQ(data.size(), 2);
|
||||
ASSERT_TRUE(data.find("GPU") != data.end());
|
||||
ASSERT_TRUE(data.find("CUSTOM") != data.end() || data.find("CPU") != data.end());
|
||||
// The key "None-Existent" will not appear in the notification.
|
||||
} else {
|
||||
if (data.size() == 2) {
|
||||
compare_test(data_map1, data);
|
||||
} else if (data.size() == 3) {
|
||||
compare_test(data_map2, data);
|
||||
} else {
|
||||
ASSERT_TRUE(false);
|
||||
}
|
||||
}
|
||||
test->IncrementNumCallbacks();
|
||||
// It is not sure which of the notification or lookup callback will come first.
|
||||
if (test->NumCallbacks() == expected_count) {
|
||||
test->Stop();
|
||||
}
|
||||
};
|
||||
// Step 0: Subscribe the change of the hash table.
|
||||
RAY_CHECK_OK(client->resource_table().Subscribe(
|
||||
driver_id, ClientID::Nil(), notification_callback, subscribe_callback));
|
||||
RAY_CHECK_OK(client->resource_table().RequestNotifications(
|
||||
driver_id, client_id, client->client_table().GetLocalClientId()));
|
||||
|
||||
// Step 1: Add elements to the hash table.
|
||||
auto update_callback1 = [data_map1, compare_test](
|
||||
AsyncGcsClient *client, const ClientID &id,
|
||||
const DynamicResourceTable::DataMap &callback_data) {
|
||||
compare_test(data_map1, callback_data);
|
||||
test->IncrementNumCallbacks();
|
||||
};
|
||||
RAY_CHECK_OK(
|
||||
client->resource_table().Update(driver_id, client_id, data_map1, update_callback1));
|
||||
auto lookup_callback1 = [data_map1, compare_test](
|
||||
AsyncGcsClient *client, const ClientID &id,
|
||||
const DynamicResourceTable::DataMap &callback_data) {
|
||||
compare_test(data_map1, callback_data);
|
||||
test->IncrementNumCallbacks();
|
||||
};
|
||||
RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback1));
|
||||
|
||||
// Step 2: Decrease one element, increase one and add a new one.
|
||||
RAY_CHECK_OK(client->resource_table().Update(driver_id, client_id, data_map2, nullptr));
|
||||
auto lookup_callback2 = [data_map2, compare_test](
|
||||
AsyncGcsClient *client, const ClientID &id,
|
||||
const DynamicResourceTable::DataMap &callback_data) {
|
||||
compare_test(data_map2, callback_data);
|
||||
test->IncrementNumCallbacks();
|
||||
};
|
||||
RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback2));
|
||||
std::vector<std::string> delete_keys({"GPU", "CUSTOM", "None-Existent"});
|
||||
auto remove_callback = [delete_keys](AsyncGcsClient *client, const ClientID &id,
|
||||
const std::vector<std::string> &callback_data) {
|
||||
for (int i = 0; i < callback_data.size(); ++i) {
|
||||
// All deleting keys exist in this argument even if the key doesn't exist.
|
||||
ASSERT_EQ(callback_data[i], delete_keys[i]);
|
||||
}
|
||||
test->IncrementNumCallbacks();
|
||||
};
|
||||
RAY_CHECK_OK(client->resource_table().RemoveEntries(driver_id, client_id, delete_keys,
|
||||
remove_callback));
|
||||
DynamicResourceTable::DataMap data_map3(data_map2);
|
||||
data_map3.erase("GPU");
|
||||
data_map3.erase("CUSTOM");
|
||||
auto lookup_callback3 = [data_map3, compare_test](
|
||||
AsyncGcsClient *client, const ClientID &id,
|
||||
const DynamicResourceTable::DataMap &callback_data) {
|
||||
compare_test(data_map3, callback_data);
|
||||
test->IncrementNumCallbacks();
|
||||
};
|
||||
RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback3));
|
||||
|
||||
// Step 3: Reset the the resources to data_map1.
|
||||
RAY_CHECK_OK(
|
||||
client->resource_table().Update(driver_id, client_id, data_map1, update_callback1));
|
||||
auto lookup_callback4 = [data_map1, compare_test](
|
||||
AsyncGcsClient *client, const ClientID &id,
|
||||
const DynamicResourceTable::DataMap &callback_data) {
|
||||
compare_test(data_map1, callback_data);
|
||||
test->IncrementNumCallbacks();
|
||||
};
|
||||
RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback4));
|
||||
|
||||
// Step 4: Removing all elements will remove the home Hash table from GCS.
|
||||
RAY_CHECK_OK(client->resource_table().RemoveEntries(
|
||||
driver_id, client_id, {"GPU", "CPU", "CUSTOM", "None-Existent"}, nullptr));
|
||||
auto lookup_callback5 = [](AsyncGcsClient *client, const ClientID &id,
|
||||
const DynamicResourceTable::DataMap &callback_data) {
|
||||
ASSERT_EQ(callback_data.size(), 0);
|
||||
test->IncrementNumCallbacks();
|
||||
// It is not sure which of notification or lookup callback will come first.
|
||||
if (test->NumCallbacks() == expected_count) {
|
||||
test->Stop();
|
||||
}
|
||||
};
|
||||
RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback5));
|
||||
test->Start();
|
||||
ASSERT_EQ(test->NumCallbacks(), expected_count);
|
||||
}
|
||||
|
||||
TEST_F(TestGcsWithAsio, TestHashTable) {
|
||||
test = this;
|
||||
TestHashTable(driver_id_, client_);
|
||||
}
|
||||
|
||||
#undef TEST_MACRO
|
||||
|
||||
} // namespace gcs
|
||||
|
||||
@@ -22,6 +22,7 @@ enum TablePrefix:int {
|
||||
TASK_LEASE,
|
||||
ACTOR_CHECKPOINT,
|
||||
ACTOR_CHECKPOINT_ID,
|
||||
NODE_RESOURCE,
|
||||
}
|
||||
|
||||
// The channel that Add operations to the Table should be published on, if any.
|
||||
@@ -37,6 +38,7 @@ enum TablePubsub:int {
|
||||
ERROR_INFO,
|
||||
TASK_LEASE,
|
||||
DRIVER,
|
||||
NODE_RESOURCE,
|
||||
}
|
||||
|
||||
// Enum for the entry type in the ClientTable
|
||||
@@ -113,13 +115,13 @@ table ResourcePair {
|
||||
value: double;
|
||||
}
|
||||
|
||||
enum GcsTableNotificationMode:int {
|
||||
enum GcsChangeMode:int {
|
||||
APPEND_OR_ADD = 0,
|
||||
REMOVE,
|
||||
}
|
||||
|
||||
table GcsTableEntry {
|
||||
notification_mode: GcsTableNotificationMode;
|
||||
table GcsEntry {
|
||||
change_mode: GcsChangeMode;
|
||||
id: string;
|
||||
entries: [string];
|
||||
}
|
||||
|
||||
@@ -179,32 +179,20 @@ flatbuffers::Offset<flatbuffers::String> RedisStringToFlatbuf(
|
||||
return fbb.CreateString(redis_string_str, redis_string_size);
|
||||
}
|
||||
|
||||
/// Publish a notification for an entry update at a key. This publishes a
|
||||
/// notification to all subscribers of the table, as well as every client that
|
||||
/// has requested notifications for this key.
|
||||
/// Helper method to publish formatted data to target channel.
|
||||
///
|
||||
/// \param pubsub_channel_str The pubsub channel name that notifications for
|
||||
/// this key should be published to. When publishing to a specific client, the
|
||||
/// channel name should be <pubsub_channel>:<client_id>.
|
||||
/// \param id The ID of the key that the notification is about.
|
||||
/// \param mode the update mode, such as append or remove.
|
||||
/// \param data The appended/removed data.
|
||||
/// \param data_buffer The data to publish, which is a GcsEntry buffer.
|
||||
/// \return OK if there is no error during a publish.
|
||||
int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str,
|
||||
RedisModuleString *id, GcsTableNotificationMode notification_mode,
|
||||
RedisModuleString *data) {
|
||||
// Serialize the notification to send.
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto data_flatbuf = RedisStringToFlatbuf(fbb, data);
|
||||
auto message =
|
||||
CreateGcsTableEntry(fbb, notification_mode, RedisStringToFlatbuf(fbb, id),
|
||||
fbb.CreateVector(&data_flatbuf, 1));
|
||||
fbb.Finish(message);
|
||||
|
||||
int PublishDataHelper(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str,
|
||||
RedisModuleString *id, RedisModuleString *data_buffer) {
|
||||
// Write the data back to any subscribers that are listening to all table
|
||||
// notifications.
|
||||
RedisModuleCallReply *reply = RedisModule_Call(ctx, "PUBLISH", "sb", pubsub_channel_str,
|
||||
fbb.GetBufferPointer(), fbb.GetSize());
|
||||
RedisModuleCallReply *reply =
|
||||
RedisModule_Call(ctx, "PUBLISH", "ss", pubsub_channel_str, data_buffer);
|
||||
if (reply == NULL) {
|
||||
return RedisModule_ReplyWithError(ctx, "error during PUBLISH");
|
||||
}
|
||||
@@ -221,8 +209,8 @@ int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st
|
||||
// will be garbage collected by redis.
|
||||
auto channel =
|
||||
RedisModule_CreateString(ctx, client_channel.data(), client_channel.size());
|
||||
RedisModuleCallReply *reply = RedisModule_Call(
|
||||
ctx, "PUBLISH", "sb", channel, fbb.GetBufferPointer(), fbb.GetSize());
|
||||
RedisModuleCallReply *reply =
|
||||
RedisModule_Call(ctx, "PUBLISH", "ss", channel, data_buffer);
|
||||
if (reply == NULL) {
|
||||
return RedisModule_ReplyWithError(ctx, "error during PUBLISH");
|
||||
}
|
||||
@@ -231,6 +219,31 @@ int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st
|
||||
return RedisModule_ReplyWithSimpleString(ctx, "OK");
|
||||
}
|
||||
|
||||
/// Publish a notification for an entry update at a key. This publishes a
|
||||
/// notification to all subscribers of the table, as well as every client that
|
||||
/// has requested notifications for this key.
|
||||
///
|
||||
/// \param pubsub_channel_str The pubsub channel name that notifications for
|
||||
/// this key should be published to. When publishing to a specific client, the
|
||||
/// channel name should be <pubsub_channel>:<client_id>.
|
||||
/// \param id The ID of the key that the notification is about.
|
||||
/// \param mode the update mode, such as append or remove.
|
||||
/// \param data The appended/removed data.
|
||||
/// \return OK if there is no error during a publish.
|
||||
int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str,
|
||||
RedisModuleString *id, GcsChangeMode change_mode,
|
||||
RedisModuleString *data) {
|
||||
// Serialize the notification to send.
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto data_flatbuf = RedisStringToFlatbuf(fbb, data);
|
||||
auto message = CreateGcsEntry(fbb, change_mode, RedisStringToFlatbuf(fbb, id),
|
||||
fbb.CreateVector(&data_flatbuf, 1));
|
||||
fbb.Finish(message);
|
||||
auto data_buffer = RedisModule_CreateString(
|
||||
ctx, reinterpret_cast<char *>(fbb.GetBufferPointer()), fbb.GetSize());
|
||||
return PublishDataHelper(ctx, pubsub_channel_str, id, data_buffer);
|
||||
}
|
||||
|
||||
// RAY.TABLE_ADD:
|
||||
// TableAdd_RedisCommand: the actual command handler.
|
||||
// (helper) TableAdd_DoWrite: performs the write to redis state.
|
||||
@@ -266,8 +279,8 @@ int TableAdd_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
|
||||
|
||||
if (pubsub_channel != TablePubsub::NO_PUBLISH) {
|
||||
// All other pubsub channels write the data back directly onto the channel.
|
||||
return PublishTableUpdate(ctx, pubsub_channel_str, id,
|
||||
GcsTableNotificationMode::APPEND_OR_ADD, data);
|
||||
return PublishTableUpdate(ctx, pubsub_channel_str, id, GcsChangeMode::APPEND_OR_ADD,
|
||||
data);
|
||||
} else {
|
||||
return RedisModule_ReplyWithSimpleString(ctx, "OK");
|
||||
}
|
||||
@@ -366,8 +379,8 @@ int TableAppend_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, int /*a
|
||||
if (pubsub_channel != TablePubsub::NO_PUBLISH) {
|
||||
// All other pubsub channels write the data back directly onto the
|
||||
// channel.
|
||||
return PublishTableUpdate(ctx, pubsub_channel_str, id,
|
||||
GcsTableNotificationMode::APPEND_OR_ADD, data);
|
||||
return PublishTableUpdate(ctx, pubsub_channel_str, id, GcsChangeMode::APPEND_OR_ADD,
|
||||
data);
|
||||
} else {
|
||||
return RedisModule_ReplyWithSimpleString(ctx, "OK");
|
||||
}
|
||||
@@ -419,10 +432,9 @@ int Set_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, bool is_add) {
|
||||
if (pubsub_channel != TablePubsub::NO_PUBLISH) {
|
||||
// All other pubsub channels write the data back directly onto the
|
||||
// channel.
|
||||
return PublishTableUpdate(ctx, pubsub_channel_str, id,
|
||||
is_add ? GcsTableNotificationMode::APPEND_OR_ADD
|
||||
: GcsTableNotificationMode::REMOVE,
|
||||
data);
|
||||
return PublishTableUpdate(
|
||||
ctx, pubsub_channel_str, id,
|
||||
is_add ? GcsChangeMode::APPEND_OR_ADD : GcsChangeMode::REMOVE, data);
|
||||
} else {
|
||||
return RedisModule_ReplyWithSimpleString(ctx, "OK");
|
||||
}
|
||||
@@ -518,7 +530,125 @@ int SetRemove_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
|
||||
return RedisModule_ReplyWithSimpleString(ctx, "OK");
|
||||
}
|
||||
|
||||
/// A helper function to create and finish a GcsTableEntry, based on the
|
||||
int Hash_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv) {
|
||||
RedisModuleString *pubsub_channel_str = argv[2];
|
||||
RedisModuleString *id = argv[3];
|
||||
RedisModuleString *data = argv[4];
|
||||
// Publish a message on the requested pubsub channel if necessary.
|
||||
TablePubsub pubsub_channel;
|
||||
REPLY_AND_RETURN_IF_NOT_OK(ParseTablePubsub(&pubsub_channel, pubsub_channel_str));
|
||||
if (pubsub_channel != TablePubsub::NO_PUBLISH) {
|
||||
// All other pubsub channels write the data back directly onto the
|
||||
// channel.
|
||||
return PublishDataHelper(ctx, pubsub_channel_str, id, data);
|
||||
} else {
|
||||
return RedisModule_ReplyWithSimpleString(ctx, "OK");
|
||||
}
|
||||
}
|
||||
|
||||
/// Do the hash table write operation. This is called from by HashUpdate_RedisCommand.
|
||||
///
|
||||
/// \param change_mode Output the mode of the operation: APPEND_OR_ADD or REMOVE.
|
||||
/// \param deleted_data Output data if the deleted data is not the same as required.
|
||||
int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc,
|
||||
GcsChangeMode *change_mode, RedisModuleString **changed_data) {
|
||||
if (argc != 5) {
|
||||
return RedisModule_WrongArity(ctx);
|
||||
}
|
||||
RedisModuleString *prefix_str = argv[1];
|
||||
RedisModuleString *id = argv[3];
|
||||
RedisModuleString *update_data = argv[4];
|
||||
|
||||
RedisModuleKey *key;
|
||||
REPLY_AND_RETURN_IF_NOT_OK(OpenPrefixedKey(
|
||||
&key, ctx, prefix_str, id, REDISMODULE_READ | REDISMODULE_WRITE, nullptr));
|
||||
int type = RedisModule_KeyType(key);
|
||||
REPLY_AND_RETURN_IF_FALSE(
|
||||
type == REDISMODULE_KEYTYPE_HASH || type == REDISMODULE_KEYTYPE_EMPTY,
|
||||
"HashUpdate_DoWrite: entries must be a hash or an empty hash");
|
||||
|
||||
size_t update_data_len = 0;
|
||||
const char *update_data_buf = RedisModule_StringPtrLen(update_data, &update_data_len);
|
||||
|
||||
auto data_vec = flatbuffers::GetRoot<GcsEntry>(update_data_buf);
|
||||
*change_mode = data_vec->change_mode();
|
||||
if (*change_mode == GcsChangeMode::APPEND_OR_ADD) {
|
||||
// This code path means they are updating command.
|
||||
size_t total_size = data_vec->entries()->size();
|
||||
REPLY_AND_RETURN_IF_FALSE(total_size % 2 == 0, "Invalid Hash Update data vector.");
|
||||
for (int i = 0; i < total_size; i += 2) {
|
||||
// Reconstruct a key-value pair from a flattened list.
|
||||
RedisModuleString *entry_key = RedisModule_CreateString(
|
||||
ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size());
|
||||
RedisModuleString *entry_value =
|
||||
RedisModule_CreateString(ctx, data_vec->entries()->Get(i + 1)->data(),
|
||||
data_vec->entries()->Get(i + 1)->size());
|
||||
// Returning 0 if key exists(still updated), 1 if the key is created.
|
||||
RAY_IGNORE_EXPR(
|
||||
RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, entry_value, NULL));
|
||||
}
|
||||
*changed_data = update_data;
|
||||
} else {
|
||||
// This code path means the command wants to remove the entries.
|
||||
size_t total_size = data_vec->entries()->size();
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
std::vector<flatbuffers::Offset<flatbuffers::String>> data;
|
||||
for (int i = 0; i < total_size; i++) {
|
||||
RedisModuleString *entry_key = RedisModule_CreateString(
|
||||
ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size());
|
||||
int deleted_num = RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key,
|
||||
REDISMODULE_HASH_DELETE, NULL);
|
||||
if (deleted_num != 0) {
|
||||
// The corresponding key is removed.
|
||||
data.push_back(fbb.CreateString(data_vec->entries()->Get(i)->data(),
|
||||
data_vec->entries()->Get(i)->size()));
|
||||
}
|
||||
}
|
||||
auto message =
|
||||
CreateGcsEntry(fbb, data_vec->change_mode(),
|
||||
fbb.CreateString(data_vec->id()->data(), data_vec->id()->size()),
|
||||
fbb.CreateVector(data));
|
||||
fbb.Finish(message);
|
||||
*changed_data = RedisModule_CreateString(
|
||||
ctx, reinterpret_cast<char *>(fbb.GetBufferPointer()), fbb.GetSize());
|
||||
auto size = RedisModule_ValueLength(key);
|
||||
if (size == 0) {
|
||||
REPLY_AND_RETURN_IF_FALSE(RedisModule_DeleteKey(key) == REDISMODULE_OK,
|
||||
"ERR Failed to delete empty hash.");
|
||||
}
|
||||
}
|
||||
return REDISMODULE_OK;
|
||||
}
|
||||
|
||||
/// Update entries for a hash table.
|
||||
///
|
||||
/// This is called from a client with the command:
|
||||
//
|
||||
/// RAY.HASH_UPDATE <table_prefix> <pubsub_channel> <id> <data>
|
||||
///
|
||||
/// \param table_prefix The prefix string for keys in this table.
|
||||
/// \param pubsub_channel The pubsub channel name that notifications for this
|
||||
/// key should be published to. When publishing to a specific client, the
|
||||
/// channel name should be <pubsub_channel>:<client_id>.
|
||||
/// \param id The ID of the key to remove from.
|
||||
/// \param data The GcsEntry flatbugger data used to update this hash table.
|
||||
/// 1). For deletion, this is a list of keys.
|
||||
/// 2). For updating, this is a list of pairs with each key followed by the value.
|
||||
/// \return OK if the remove succeeds, or an error message string if the remove
|
||||
/// fails.
|
||||
int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
|
||||
GcsChangeMode mode;
|
||||
RedisModuleString *changed_data = nullptr;
|
||||
if (HashUpdate_DoWrite(ctx, argv, argc, &mode, &changed_data) != REDISMODULE_OK) {
|
||||
return REDISMODULE_ERR;
|
||||
}
|
||||
// Replace the data with the changed data to do the publish.
|
||||
std::vector<RedisModuleString *> new_argv(argv, argv + argc);
|
||||
new_argv[4] = changed_data;
|
||||
return Hash_DoPublish(ctx, new_argv.data());
|
||||
}
|
||||
|
||||
/// A helper function to create and finish a GcsEntry, based on the
|
||||
/// current value or values at the given key.
|
||||
///
|
||||
/// \param ctx The Redis module context.
|
||||
@@ -528,7 +658,7 @@ int SetRemove_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
|
||||
/// \param prefix_str The string prefix associated with the open Redis key.
|
||||
/// When parsed, this is expected to be a TablePrefix.
|
||||
/// \param entry_id The UniqueID associated with the open Redis key.
|
||||
/// \param fbb A flatbuffer builder used to build the GcsTableEntry.
|
||||
/// \param fbb A flatbuffer builder used to build the GcsEntry.
|
||||
Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key,
|
||||
RedisModuleString *prefix_str, RedisModuleString *entry_id,
|
||||
flatbuffers::FlatBufferBuilder &fbb) {
|
||||
@@ -539,12 +669,13 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key,
|
||||
size_t data_len = 0;
|
||||
char *data_buf = RedisModule_StringDMA(table_key, &data_len, REDISMODULE_READ);
|
||||
auto data = fbb.CreateString(data_buf, data_len);
|
||||
auto message = CreateGcsTableEntry(fbb, GcsTableNotificationMode::APPEND_OR_ADD,
|
||||
RedisStringToFlatbuf(fbb, entry_id),
|
||||
fbb.CreateVector(&data, 1));
|
||||
auto message =
|
||||
CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD,
|
||||
RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(&data, 1));
|
||||
fbb.Finish(message);
|
||||
} break;
|
||||
case REDISMODULE_KEYTYPE_LIST:
|
||||
case REDISMODULE_KEYTYPE_HASH:
|
||||
case REDISMODULE_KEYTYPE_SET: {
|
||||
RedisModule_CloseKey(table_key);
|
||||
// Close the key before executing the command. NOTE(swang): According to
|
||||
@@ -561,10 +692,13 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key,
|
||||
case REDISMODULE_KEYTYPE_SET:
|
||||
reply = RedisModule_Call(ctx, "SMEMBERS", "s", table_key_str);
|
||||
break;
|
||||
case REDISMODULE_KEYTYPE_HASH:
|
||||
reply = RedisModule_Call(ctx, "HGETALL", "s", table_key_str);
|
||||
break;
|
||||
}
|
||||
// Build the flatbuffer from the set of log entries.
|
||||
if (reply == nullptr || RedisModule_CallReplyType(reply) != REDISMODULE_REPLY_ARRAY) {
|
||||
return Status::RedisError("Empty list or wrong type");
|
||||
return Status::RedisError("Empty list/set/hash or wrong type");
|
||||
}
|
||||
std::vector<flatbuffers::Offset<flatbuffers::String>> data;
|
||||
for (size_t i = 0; i < RedisModule_CallReplyLength(reply); i++) {
|
||||
@@ -574,13 +708,13 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key,
|
||||
data.push_back(fbb.CreateString(element_str, len));
|
||||
}
|
||||
auto message =
|
||||
CreateGcsTableEntry(fbb, GcsTableNotificationMode::APPEND_OR_ADD,
|
||||
RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(data));
|
||||
CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD,
|
||||
RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(data));
|
||||
fbb.Finish(message);
|
||||
} break;
|
||||
case REDISMODULE_KEYTYPE_EMPTY: {
|
||||
auto message = CreateGcsTableEntry(
|
||||
fbb, GcsTableNotificationMode::APPEND_OR_ADD, RedisStringToFlatbuf(fbb, entry_id),
|
||||
auto message = CreateGcsEntry(
|
||||
fbb, GcsChangeMode::APPEND_OR_ADD, RedisStringToFlatbuf(fbb, entry_id),
|
||||
fbb.CreateVector(std::vector<flatbuffers::Offset<flatbuffers::String>>()));
|
||||
fbb.Finish(message);
|
||||
} break;
|
||||
@@ -637,6 +771,7 @@ static Status DeleteKeyHelper(RedisModuleCtx *ctx, RedisModuleString *prefix_str
|
||||
return Status::RedisError("Key does not exist.");
|
||||
}
|
||||
auto key_type = RedisModule_KeyType(delete_key);
|
||||
// Set/Hash will delete itself when the length is 0.
|
||||
if (key_type == REDISMODULE_KEYTYPE_STRING || key_type == REDISMODULE_KEYTYPE_LIST) {
|
||||
// Current Table or Log only has this two types of entries.
|
||||
RAY_RETURN_NOT_OK(
|
||||
@@ -873,6 +1008,7 @@ int DebugString_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int
|
||||
|
||||
// Wrap all Redis commands with Redis' auto memory management.
|
||||
AUTO_MEMORY(TableAdd_RedisCommand);
|
||||
AUTO_MEMORY(HashUpdate_RedisCommand);
|
||||
AUTO_MEMORY(TableAppend_RedisCommand);
|
||||
AUTO_MEMORY(SetAdd_RedisCommand);
|
||||
AUTO_MEMORY(SetRemove_RedisCommand);
|
||||
@@ -929,6 +1065,11 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
|
||||
return REDISMODULE_ERR;
|
||||
}
|
||||
|
||||
if (RedisModule_CreateCommand(ctx, "ray.hash_update", HashUpdate_RedisCommand,
|
||||
"write pubsub", 0, 0, 0) == REDISMODULE_ERR) {
|
||||
return REDISMODULE_ERR;
|
||||
}
|
||||
|
||||
if (RedisModule_CreateCommand(ctx, "ray.table_request_notifications",
|
||||
TableRequestNotifications_RedisCommand, "write pubsub", 0,
|
||||
0, 0) == REDISMODULE_ERR) {
|
||||
|
||||
+157
-5
@@ -92,7 +92,7 @@ Status Log<ID, Data>::Lookup(const DriverID &driver_id, const ID &id,
|
||||
std::vector<DataT> results;
|
||||
if (!reply.IsNil()) {
|
||||
const auto data = reply.ReadAsString();
|
||||
auto root = flatbuffers::GetRoot<GcsTableEntry>(data.data());
|
||||
auto root = flatbuffers::GetRoot<GcsEntry>(data.data());
|
||||
RAY_CHECK(from_flatbuf<ID>(*root->id()) == id);
|
||||
for (size_t i = 0; i < root->entries()->size(); i++) {
|
||||
DataT result;
|
||||
@@ -114,9 +114,9 @@ Status Log<ID, Data>::Subscribe(const DriverID &driver_id, const ClientID &clien
|
||||
const Callback &subscribe,
|
||||
const SubscriptionCallback &done) {
|
||||
auto subscribe_wrapper = [subscribe](AsyncGcsClient *client, const ID &id,
|
||||
const GcsTableNotificationMode notification_mode,
|
||||
const GcsChangeMode change_mode,
|
||||
const std::vector<DataT> &data) {
|
||||
RAY_CHECK(notification_mode != GcsTableNotificationMode::REMOVE);
|
||||
RAY_CHECK(change_mode != GcsChangeMode::REMOVE);
|
||||
subscribe(client, id, data);
|
||||
};
|
||||
return Subscribe(driver_id, client_id, subscribe_wrapper, done);
|
||||
@@ -141,7 +141,7 @@ Status Log<ID, Data>::Subscribe(const DriverID &driver_id, const ClientID &clien
|
||||
// Data is provided. This is the callback for a message.
|
||||
if (subscribe != nullptr) {
|
||||
// Parse the notification.
|
||||
auto root = flatbuffers::GetRoot<GcsTableEntry>(data.data());
|
||||
auto root = flatbuffers::GetRoot<GcsEntry>(data.data());
|
||||
ID id;
|
||||
if (root->id()->size() > 0) {
|
||||
id = from_flatbuf<ID>(*root->id());
|
||||
@@ -153,7 +153,7 @@ Status Log<ID, Data>::Subscribe(const DriverID &driver_id, const ClientID &clien
|
||||
data_root->UnPackTo(&result);
|
||||
results.emplace_back(std::move(result));
|
||||
}
|
||||
subscribe(client_, id, root->notification_mode(), results);
|
||||
subscribe(client_, id, root->change_mode(), results);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -339,6 +339,155 @@ std::string Set<ID, Data>::DebugString() const {
|
||||
return result.str();
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Hash<ID, Data>::Update(const DriverID &driver_id, const ID &id,
|
||||
const DataMap &data_map, const HashCallback &done) {
|
||||
num_adds_++;
|
||||
auto callback = [this, id, data_map, done](const CallbackReply &reply) {
|
||||
if (done != nullptr) {
|
||||
(done)(client_, id, data_map);
|
||||
}
|
||||
};
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
std::vector<flatbuffers::Offset<flatbuffers::String>> data_vec;
|
||||
data_vec.reserve(data_map.size() * 2);
|
||||
for (auto const &pair : data_map) {
|
||||
// Add the key.
|
||||
data_vec.push_back(fbb.CreateString(pair.first));
|
||||
flatbuffers::FlatBufferBuilder fbb_data;
|
||||
fbb_data.ForceDefaults(true);
|
||||
fbb_data.Finish(Data::Pack(fbb_data, pair.second.get()));
|
||||
std::string data(reinterpret_cast<char *>(fbb_data.GetBufferPointer()),
|
||||
fbb_data.GetSize());
|
||||
// Add the value.
|
||||
data_vec.push_back(fbb.CreateString(data));
|
||||
}
|
||||
|
||||
fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD,
|
||||
fbb.CreateString(id.Binary()), fbb.CreateVector(data_vec)));
|
||||
return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(),
|
||||
fbb.GetSize(), prefix_, pubsub_channel_,
|
||||
std::move(callback));
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Hash<ID, Data>::RemoveEntries(const DriverID &driver_id, const ID &id,
|
||||
const std::vector<std::string> &keys,
|
||||
const HashRemoveCallback &remove_callback) {
|
||||
num_removes_++;
|
||||
auto callback = [this, id, keys, remove_callback](const CallbackReply &reply) {
|
||||
if (remove_callback != nullptr) {
|
||||
(remove_callback)(client_, id, keys);
|
||||
}
|
||||
};
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
std::vector<flatbuffers::Offset<flatbuffers::String>> data_vec;
|
||||
data_vec.reserve(keys.size());
|
||||
// Add the keys.
|
||||
for (auto const &key : keys) {
|
||||
data_vec.push_back(fbb.CreateString(key));
|
||||
}
|
||||
|
||||
fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::REMOVE, fbb.CreateString(id.Binary()),
|
||||
fbb.CreateVector(data_vec)));
|
||||
return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(),
|
||||
fbb.GetSize(), prefix_, pubsub_channel_,
|
||||
std::move(callback));
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
std::string Hash<ID, Data>::DebugString() const {
|
||||
std::stringstream result;
|
||||
result << "num lookups: " << num_lookups_ << ", num adds: " << num_adds_
|
||||
<< ", num removes: " << num_removes_;
|
||||
return result.str();
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Hash<ID, Data>::Lookup(const DriverID &driver_id, const ID &id,
|
||||
const HashCallback &lookup) {
|
||||
num_lookups_++;
|
||||
auto callback = [this, id, lookup](const CallbackReply &reply) {
|
||||
if (lookup != nullptr) {
|
||||
DataMap results;
|
||||
if (!reply.IsNil()) {
|
||||
const auto data = reply.ReadAsString();
|
||||
auto root = flatbuffers::GetRoot<GcsEntry>(data.data());
|
||||
RAY_CHECK(from_flatbuf<ID>(*root->id()) == id);
|
||||
RAY_CHECK(root->entries()->size() % 2 == 0);
|
||||
for (size_t i = 0; i < root->entries()->size(); i += 2) {
|
||||
std::string key(root->entries()->Get(i)->data(),
|
||||
root->entries()->Get(i)->size());
|
||||
auto result = std::make_shared<DataT>();
|
||||
auto data_root =
|
||||
flatbuffers::GetRoot<Data>(root->entries()->Get(i + 1)->data());
|
||||
data_root->UnPackTo(result.get());
|
||||
results.emplace(key, std::move(result));
|
||||
}
|
||||
}
|
||||
lookup(client_, id, results);
|
||||
}
|
||||
};
|
||||
std::vector<uint8_t> nil;
|
||||
return GetRedisContext(id)->RunAsync("RAY.TABLE_LOOKUP", id, nil.data(), nil.size(),
|
||||
prefix_, pubsub_channel_, std::move(callback));
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Hash<ID, Data>::Subscribe(const DriverID &driver_id, const ClientID &client_id,
|
||||
const HashNotificationCallback &subscribe,
|
||||
const SubscriptionCallback &done) {
|
||||
RAY_CHECK(subscribe_callback_index_ == -1)
|
||||
<< "Client called Subscribe twice on the same table";
|
||||
auto callback = [this, subscribe, done](const CallbackReply &reply) {
|
||||
const auto data = reply.ReadAsPubsubData();
|
||||
if (data.empty()) {
|
||||
// No notification data is provided. This is the callback for the
|
||||
// initial subscription request.
|
||||
if (done != nullptr) {
|
||||
done(client_);
|
||||
}
|
||||
} else {
|
||||
// Data is provided. This is the callback for a message.
|
||||
if (subscribe != nullptr) {
|
||||
// Parse the notification.
|
||||
auto root = flatbuffers::GetRoot<GcsEntry>(data.data());
|
||||
DataMap data_map;
|
||||
ID id;
|
||||
if (root->id()->size() > 0) {
|
||||
id = from_flatbuf<ID>(*root->id());
|
||||
}
|
||||
if (root->change_mode() == GcsChangeMode::REMOVE) {
|
||||
for (size_t i = 0; i < root->entries()->size(); i++) {
|
||||
std::string key(root->entries()->Get(i)->data(),
|
||||
root->entries()->Get(i)->size());
|
||||
data_map.emplace(key, std::shared_ptr<DataT>());
|
||||
}
|
||||
} else {
|
||||
RAY_CHECK(root->entries()->size() % 2 == 0);
|
||||
for (size_t i = 0; i < root->entries()->size(); i += 2) {
|
||||
std::string key(root->entries()->Get(i)->data(),
|
||||
root->entries()->Get(i)->size());
|
||||
auto result = std::make_shared<DataT>();
|
||||
auto data_root =
|
||||
flatbuffers::GetRoot<Data>(root->entries()->Get(i + 1)->data());
|
||||
data_root->UnPackTo(result.get());
|
||||
data_map.emplace(key, std::move(result));
|
||||
}
|
||||
}
|
||||
subscribe(client_, id, root->change_mode(), data_map);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
subscribe_callback_index_ = 1;
|
||||
for (auto &context : shard_contexts_) {
|
||||
RAY_RETURN_NOT_OK(context->SubscribeAsync(client_id, pubsub_channel_, callback,
|
||||
&subscribe_callback_index_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::string &type,
|
||||
const std::string &error_message, double timestamp) {
|
||||
auto data = std::make_shared<ErrorTableDataT>();
|
||||
@@ -696,6 +845,9 @@ template class Log<UniqueID, ProfileTableData>;
|
||||
template class Table<ActorCheckpointID, ActorCheckpointData>;
|
||||
template class Table<ActorID, ActorCheckpointIdData>;
|
||||
|
||||
template class Log<ClientID, RayResource>;
|
||||
template class Hash<ClientID, RayResource>;
|
||||
|
||||
} // namespace gcs
|
||||
|
||||
} // namespace ray
|
||||
|
||||
+155
-4
@@ -75,9 +75,9 @@ class Log : public LogInterface<ID, Data>, virtual public PubsubInterface<ID> {
|
||||
using DataT = typename Data::NativeTableType;
|
||||
using Callback = std::function<void(AsyncGcsClient *client, const ID &id,
|
||||
const std::vector<DataT> &data)>;
|
||||
using NotificationCallback = std::function<void(
|
||||
AsyncGcsClient *client, const ID &id,
|
||||
const GcsTableNotificationMode notification_mode, const std::vector<DataT> &data)>;
|
||||
using NotificationCallback = std::function<void(AsyncGcsClient *client, const ID &id,
|
||||
const GcsChangeMode change_mode,
|
||||
const std::vector<DataT> &data)>;
|
||||
/// The callback to call when a write to a key succeeds.
|
||||
using WriteCallback = typename LogInterface<ID, Data>::WriteCallback;
|
||||
/// The callback to call when a SUBSCRIBE call completes and we are ready to
|
||||
@@ -214,7 +214,7 @@ class Log : public LogInterface<ID, Data>, virtual public PubsubInterface<ID> {
|
||||
/// to subscribe to all modifications, or to subscribe only to keys that it
|
||||
/// requests notifications for. This may only be called once per Log
|
||||
/// instance. This function is different from public version due to
|
||||
/// an additional parameter notification_mode in NotificationCallback. Therefore this
|
||||
/// an additional parameter change_mode in NotificationCallback. Therefore this
|
||||
/// function supports notifications of remove operations.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
@@ -451,6 +451,157 @@ class Set : private Log<ID, Data>,
|
||||
using Log<ID, Data>::num_lookups_;
|
||||
};
|
||||
|
||||
template <typename ID, typename Data>
|
||||
class HashInterface {
|
||||
public:
|
||||
using DataT = typename Data::NativeTableType;
|
||||
using DataMap = std::unordered_map<std::string, std::shared_ptr<DataT>>;
|
||||
// Reuse Log's SubscriptionCallback when Subscribe is successfully called.
|
||||
using SubscriptionCallback = typename Log<ID, Data>::SubscriptionCallback;
|
||||
|
||||
/// The callback function used by function Update & Lookup.
|
||||
///
|
||||
/// \param client The client on which the RemoveEntries is called.
|
||||
/// \param id The ID of the Hash Table whose entries are removed.
|
||||
/// \param data Map data contains the change to the Hash Table.
|
||||
/// \return Void
|
||||
using HashCallback =
|
||||
std::function<void(AsyncGcsClient *client, const ID &id, const DataMap &pairs)>;
|
||||
|
||||
/// The callback function used by function RemoveEntries.
|
||||
///
|
||||
/// \param client The client on which the RemoveEntries is called.
|
||||
/// \param id The ID of the Hash Table whose entries are removed.
|
||||
/// \param keys The keys that are moved from this Hash Table.
|
||||
/// \return Void
|
||||
using HashRemoveCallback = std::function<void(AsyncGcsClient *client, const ID &id,
|
||||
const std::vector<std::string> &keys)>;
|
||||
|
||||
/// The notification function used by function Subscribe.
|
||||
///
|
||||
/// \param client The client on which the Subscribe is called.
|
||||
/// \param change_mode The mode to identify the data is removed or updated.
|
||||
/// \param data Map data contains the change to the Hash Table.
|
||||
/// \return Void
|
||||
using HashNotificationCallback =
|
||||
std::function<void(AsyncGcsClient *client, const ID &id,
|
||||
const GcsChangeMode change_mode, const DataMap &data)>;
|
||||
|
||||
/// Add entries of a hash table.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param id The ID of the data that is added to the GCS.
|
||||
/// \param pairs Map data to add to the hash table.
|
||||
/// \param done HashCallback that is called once the request data has been written to
|
||||
/// the GCS.
|
||||
/// \return Status
|
||||
virtual Status Update(const DriverID &driver_id, const ID &id, const DataMap &pairs,
|
||||
const HashCallback &done) = 0;
|
||||
|
||||
/// Remove entries from the hash table.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param id The ID of the data that is removed from the GCS.
|
||||
/// \param keys The entry keys of the hash table.
|
||||
/// \param remove_callback HashRemoveCallback that is called once the data has been
|
||||
/// written to the GCS no matter whether the key exists in the hash table.
|
||||
/// \return Status
|
||||
virtual Status RemoveEntries(const DriverID &driver_id, const ID &id,
|
||||
const std::vector<std::string> &keys,
|
||||
const HashRemoveCallback &remove_callback) = 0;
|
||||
|
||||
/// Lookup the map data of a hash table.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param id The ID of the data that is looked up in the GCS.
|
||||
/// \param lookup HashCallback that is called after lookup. If the callback is
|
||||
/// called with an empty hash table, then there was no data in the callback.
|
||||
/// \return Status
|
||||
virtual Status Lookup(const DriverID &driver_id, const ID &id,
|
||||
const HashCallback &lookup) = 0;
|
||||
|
||||
/// Subscribe to any Update or Remove operations to this hash table.
|
||||
///
|
||||
/// \param driver_id The ID of the driver.
|
||||
/// \param client_id The type of update to listen to. If this is nil, then a
|
||||
/// message for each Update to the table will be received. Else, only
|
||||
/// messages for the given client will be received. In the latter
|
||||
/// case, the client may request notifications on specific keys in the
|
||||
/// table via `RequestNotifications`.
|
||||
/// \param subscribe HashNotificationCallback that is called on each received message.
|
||||
/// \param done SubscriptionCallback that is called when subscription is complete and
|
||||
/// we are ready to receive messages.
|
||||
/// \return Status
|
||||
virtual Status Subscribe(const DriverID &driver_id, const ClientID &client_id,
|
||||
const HashNotificationCallback &subscribe,
|
||||
const SubscriptionCallback &done) = 0;
|
||||
|
||||
virtual ~HashInterface(){};
|
||||
};
|
||||
|
||||
template <typename ID, typename Data>
|
||||
class Hash : private Log<ID, Data>,
|
||||
public HashInterface<ID, Data>,
|
||||
virtual public PubsubInterface<ID> {
|
||||
public:
|
||||
using DataT = typename Log<ID, Data>::DataT;
|
||||
using DataMap = std::unordered_map<std::string, std::shared_ptr<DataT>>;
|
||||
using HashCallback = typename HashInterface<ID, Data>::HashCallback;
|
||||
using HashRemoveCallback = typename HashInterface<ID, Data>::HashRemoveCallback;
|
||||
using HashNotificationCallback =
|
||||
typename HashInterface<ID, Data>::HashNotificationCallback;
|
||||
using SubscriptionCallback = typename Log<ID, Data>::SubscriptionCallback;
|
||||
|
||||
Hash(const std::vector<std::shared_ptr<RedisContext>> &contexts, AsyncGcsClient *client)
|
||||
: Log<ID, Data>(contexts, client) {}
|
||||
|
||||
using Log<ID, Data>::RequestNotifications;
|
||||
using Log<ID, Data>::CancelNotifications;
|
||||
|
||||
Status Update(const DriverID &driver_id, const ID &id, const DataMap &pairs,
|
||||
const HashCallback &done) override;
|
||||
|
||||
Status Subscribe(const DriverID &driver_id, const ClientID &client_id,
|
||||
const HashNotificationCallback &subscribe,
|
||||
const SubscriptionCallback &done) override;
|
||||
|
||||
Status Lookup(const DriverID &driver_id, const ID &id,
|
||||
const HashCallback &lookup) override;
|
||||
|
||||
Status RemoveEntries(const DriverID &driver_id, const ID &id,
|
||||
const std::vector<std::string> &keys,
|
||||
const HashRemoveCallback &remove_callback) override;
|
||||
|
||||
/// Returns debug string for class.
|
||||
///
|
||||
/// \return string.
|
||||
std::string DebugString() const;
|
||||
|
||||
protected:
|
||||
using Log<ID, Data>::shard_contexts_;
|
||||
using Log<ID, Data>::client_;
|
||||
using Log<ID, Data>::pubsub_channel_;
|
||||
using Log<ID, Data>::prefix_;
|
||||
using Log<ID, Data>::subscribe_callback_index_;
|
||||
using Log<ID, Data>::GetRedisContext;
|
||||
|
||||
int64_t num_adds_ = 0;
|
||||
int64_t num_removes_ = 0;
|
||||
using Log<ID, Data>::num_lookups_;
|
||||
};
|
||||
|
||||
class DynamicResourceTable : public Hash<ClientID, RayResource> {
|
||||
public:
|
||||
DynamicResourceTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
|
||||
AsyncGcsClient *client)
|
||||
: Hash(contexts, client) {
|
||||
pubsub_channel_ = TablePubsub::NODE_RESOURCE;
|
||||
prefix_ = TablePrefix::NODE_RESOURCE;
|
||||
};
|
||||
|
||||
virtual ~DynamicResourceTable(){};
|
||||
};
|
||||
|
||||
class ObjectTable : public Set<ObjectID, ObjectTableData> {
|
||||
public:
|
||||
ObjectTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
|
||||
|
||||
@@ -11,16 +11,16 @@ namespace {
|
||||
/// Process a notification of the object table entries and store the result in
|
||||
/// client_ids. This assumes that client_ids already contains the result of the
|
||||
/// object table entries up to but not including this notification.
|
||||
void UpdateObjectLocations(const GcsTableNotificationMode notification_mode,
|
||||
void UpdateObjectLocations(const GcsChangeMode change_mode,
|
||||
const std::vector<ObjectTableDataT> &location_updates,
|
||||
const ray::gcs::ClientTable &client_table,
|
||||
std::unordered_set<ClientID> *client_ids) {
|
||||
// location_updates contains the updates of locations of the object.
|
||||
// with GcsTableNotificationMode, we can determine whether the update mode is
|
||||
// with GcsChangeMode, we can determine whether the update mode is
|
||||
// addition or deletion.
|
||||
for (const auto &object_table_data : location_updates) {
|
||||
ClientID client_id = ClientID::FromBinary(object_table_data.manager);
|
||||
if (notification_mode != GcsTableNotificationMode::REMOVE) {
|
||||
if (change_mode != GcsChangeMode::REMOVE) {
|
||||
client_ids->insert(client_id);
|
||||
} else {
|
||||
client_ids->erase(client_id);
|
||||
@@ -41,7 +41,7 @@ void UpdateObjectLocations(const GcsTableNotificationMode notification_mode,
|
||||
void ObjectDirectory::RegisterBackend() {
|
||||
auto object_notification_callback = [this](
|
||||
gcs::AsyncGcsClient *client, const ObjectID &object_id,
|
||||
const GcsTableNotificationMode notification_mode,
|
||||
const GcsChangeMode change_mode,
|
||||
const std::vector<ObjectTableDataT> &location_updates) {
|
||||
// Objects are added to this map in SubscribeObjectLocations.
|
||||
auto it = listeners_.find(object_id);
|
||||
@@ -54,8 +54,7 @@ void ObjectDirectory::RegisterBackend() {
|
||||
it->second.subscribed = true;
|
||||
|
||||
// Update entries for this object.
|
||||
UpdateObjectLocations(notification_mode, location_updates,
|
||||
gcs_client_->client_table(),
|
||||
UpdateObjectLocations(change_mode, location_updates, gcs_client_->client_table(),
|
||||
&it->second.current_object_locations);
|
||||
// Copy the callbacks so that the callbacks can unsubscribe without interrupting
|
||||
// looping over the callbacks.
|
||||
@@ -135,8 +134,7 @@ void ObjectDirectory::HandleClientRemoved(const ClientID &client_id) {
|
||||
if (listener.second.current_object_locations.count(client_id) > 0) {
|
||||
// If the subscribed object has the removed client as a location, update
|
||||
// its locations with an empty update so that the location will be removed.
|
||||
UpdateObjectLocations(GcsTableNotificationMode::APPEND_OR_ADD, {},
|
||||
gcs_client_->client_table(),
|
||||
UpdateObjectLocations(GcsChangeMode::APPEND_OR_ADD, {}, gcs_client_->client_table(),
|
||||
&listener.second.current_object_locations);
|
||||
// Re-call all the subscribed callbacks for the object, since its
|
||||
// locations have changed.
|
||||
@@ -213,7 +211,7 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id,
|
||||
const std::vector<ObjectTableDataT> &location_updates) {
|
||||
// Build the set of current locations based on the entries in the log.
|
||||
std::unordered_set<ClientID> client_ids;
|
||||
UpdateObjectLocations(GcsTableNotificationMode::APPEND_OR_ADD, location_updates,
|
||||
UpdateObjectLocations(GcsChangeMode::APPEND_OR_ADD, location_updates,
|
||||
gcs_client_->client_table(), &client_ids);
|
||||
// It is safe to call the callback directly since this is already running
|
||||
// in the GCS client's lookup callback stack.
|
||||
|
||||
Reference in New Issue
Block a user