mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 16:58:23 +08:00
@@ -97,7 +97,7 @@ std::shared_ptr<ClientConnection<T>> ClientConnection<T>::Create(
|
||||
std::shared_ptr<ClientConnection<T>> self(
|
||||
new ClientConnection(message_handler, std::move(socket)));
|
||||
// Let our manager process our new connection.
|
||||
client_handler(self);
|
||||
client_handler(*self);
|
||||
return self;
|
||||
}
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ template <typename T>
|
||||
class ClientConnection;
|
||||
|
||||
template <typename T>
|
||||
using ClientHandler = std::function<void(std::shared_ptr<ClientConnection<T>>)>;
|
||||
using ClientHandler = std::function<void(ClientConnection<T> &)>;
|
||||
template <typename T>
|
||||
using MessageHandler =
|
||||
std::function<void(std::shared_ptr<ClientConnection<T>>, int64_t, const uint8_t *)>;
|
||||
|
||||
@@ -7,7 +7,7 @@ namespace ray {
|
||||
namespace gcs {
|
||||
|
||||
AsyncGcsClient::AsyncGcsClient(const ClientID &client_id) {
|
||||
context_.reset(new RedisContext());
|
||||
context_ = std::make_shared<RedisContext>();
|
||||
client_table_.reset(new ClientTable(context_, this, client_id));
|
||||
object_table_.reset(new ObjectTable(context_, this));
|
||||
actor_table_.reset(new ActorTable(context_, this));
|
||||
|
||||
@@ -93,9 +93,9 @@ void TestTableLookup(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> c
|
||||
|
||||
// Check that we added the correct task.
|
||||
auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const UniqueID &id,
|
||||
const std::shared_ptr<protocol::TaskT> d) {
|
||||
const protocol::TaskT &d) {
|
||||
ASSERT_EQ(id, task_id);
|
||||
ASSERT_EQ(data->task_specification, d->task_specification);
|
||||
ASSERT_EQ(data->task_specification, d.task_specification);
|
||||
};
|
||||
|
||||
// Check that the lookup returns the added task.
|
||||
@@ -139,9 +139,9 @@ void TestLogLookup(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> cli
|
||||
data->manager = manager;
|
||||
// Check that we added the correct object entries.
|
||||
auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const UniqueID &id,
|
||||
const std::shared_ptr<ObjectTableDataT> d) {
|
||||
const ObjectTableDataT &d) {
|
||||
ASSERT_EQ(id, object_id);
|
||||
ASSERT_EQ(data->manager, d->manager);
|
||||
ASSERT_EQ(data->manager, d.manager);
|
||||
};
|
||||
RAY_CHECK_OK(client->object_table().Append(job_id, object_id, data, add_callback));
|
||||
}
|
||||
@@ -222,7 +222,7 @@ void TestLogAppendAt(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> c
|
||||
|
||||
// Check that we added the correct task.
|
||||
auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const UniqueID &id,
|
||||
const std::shared_ptr<TaskReconstructionDataT> d) {
|
||||
const TaskReconstructionDataT &d) {
|
||||
ASSERT_EQ(id, task_id);
|
||||
test->IncrementNumCallbacks();
|
||||
};
|
||||
@@ -265,8 +265,8 @@ TEST_F(TestGcsWithAsio, TestLogAppendAt) {
|
||||
|
||||
// Task table callbacks.
|
||||
void TaskAdded(gcs::AsyncGcsClient *client, const TaskID &id,
|
||||
const std::shared_ptr<TaskTableDataT> data) {
|
||||
ASSERT_EQ(data->scheduling_state, SchedulingState_SCHEDULED);
|
||||
const TaskTableDataT &data) {
|
||||
ASSERT_EQ(data.scheduling_state, SchedulingState_SCHEDULED);
|
||||
}
|
||||
|
||||
void TaskLookup(gcs::AsyncGcsClient *client, const TaskID &id,
|
||||
|
||||
@@ -100,15 +100,13 @@ void SubscribeRedisCallback(void *c, void *r, void *privdata) {
|
||||
|
||||
int64_t RedisCallbackManager::add(const RedisCallback &function) {
|
||||
num_callbacks += 1;
|
||||
callbacks_.emplace(num_callbacks, std::unique_ptr<RedisCallback>(
|
||||
new RedisCallback(function)));
|
||||
callbacks_.emplace(num_callbacks, function);
|
||||
return num_callbacks;
|
||||
}
|
||||
|
||||
RedisCallbackManager::RedisCallback &RedisCallbackManager::get(
|
||||
int64_t callback_index) {
|
||||
RedisCallback &RedisCallbackManager::get(int64_t callback_index) {
|
||||
RAY_CHECK(callbacks_.find(callback_index) != callbacks_.end());
|
||||
return *callbacks_[callback_index];
|
||||
return callbacks_[callback_index];
|
||||
}
|
||||
|
||||
void RedisCallbackManager::remove(int64_t callback_index) {
|
||||
@@ -185,7 +183,9 @@ Status RedisContext::AttachToEventLoop(aeEventLoop *loop) {
|
||||
Status RedisContext::RunAsync(const std::string &command, const UniqueID &id,
|
||||
const uint8_t *data, int64_t length,
|
||||
const TablePrefix prefix, const TablePubsub pubsub_channel,
|
||||
int64_t callback_index, int log_length) {
|
||||
RedisCallback redisCallback, int log_length) {
|
||||
int64_t callback_index =
|
||||
redisCallback != nullptr ? RedisCallbackManager::instance().add(redisCallback) : -1;
|
||||
if (length > 0) {
|
||||
if (log_length >= 0) {
|
||||
std::string redis_command = command + " %d %d %b %b %d";
|
||||
@@ -222,10 +222,11 @@ Status RedisContext::RunAsync(const std::string &command, const UniqueID &id,
|
||||
|
||||
Status RedisContext::SubscribeAsync(const ClientID &client_id,
|
||||
const TablePubsub pubsub_channel,
|
||||
int64_t callback_index) {
|
||||
const RedisCallback &redisCallback) {
|
||||
RAY_CHECK(pubsub_channel != TablePubsub_NO_PUBLISH)
|
||||
<< "Client requested subscribe on a table that does not support pubsub";
|
||||
|
||||
int64_t callback_index = RedisCallbackManager::instance().add(redisCallback);
|
||||
int status = 0;
|
||||
if (client_id.is_nil()) {
|
||||
// Subscribe to all messages.
|
||||
|
||||
@@ -18,13 +18,13 @@ struct aeEventLoop;
|
||||
namespace ray {
|
||||
|
||||
namespace gcs {
|
||||
/// Every callback should take in a vector of the results from the Redis
|
||||
/// operation and return a bool indicating whether the callback should be
|
||||
/// deleted once called.
|
||||
using RedisCallback = std::function<bool(const std::string &)>;
|
||||
|
||||
class RedisCallbackManager {
|
||||
public:
|
||||
/// Every callback should take in a vector of the results from the Redis
|
||||
/// operation and return a bool indicating whether the callback should be
|
||||
/// deleted once called.
|
||||
using RedisCallback = std::function<bool(const std::string &)>;
|
||||
|
||||
static RedisCallbackManager &instance() {
|
||||
static RedisCallbackManager instance;
|
||||
@@ -44,7 +44,7 @@ class RedisCallbackManager {
|
||||
~RedisCallbackManager() { printf("shut down callback manager\n"); }
|
||||
|
||||
int64_t num_callbacks;
|
||||
std::unordered_map<int64_t, std::unique_ptr<RedisCallback>> callbacks_;
|
||||
std::unordered_map<int64_t, RedisCallback> callbacks_;
|
||||
};
|
||||
|
||||
class RedisContext {
|
||||
@@ -70,11 +70,11 @@ class RedisContext {
|
||||
/// -1 for unused. If set, then data must be provided.
|
||||
Status RunAsync(const std::string &command, const UniqueID &id, const uint8_t *data,
|
||||
int64_t length, const TablePrefix prefix,
|
||||
const TablePubsub pubsub_channel, int64_t callback_index,
|
||||
const TablePubsub pubsub_channel, RedisCallback redisCallback,
|
||||
int log_length = -1);
|
||||
|
||||
Status SubscribeAsync(const ClientID &client_id, const TablePubsub pubsub_channel,
|
||||
int64_t callback_index);
|
||||
const RedisCallback &redisCallback);
|
||||
redisAsyncContext *async_context() { return async_context_; }
|
||||
redisAsyncContext *subscribe_context() { return subscribe_context_; };
|
||||
|
||||
|
||||
+88
-108
@@ -9,76 +9,66 @@ namespace gcs {
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Log<ID, Data>::Append(const JobID &job_id, const ID &id,
|
||||
std::shared_ptr<DataT> data, const WriteCallback &done) {
|
||||
auto d = std::shared_ptr<CallbackData>(
|
||||
new CallbackData({id, data, nullptr, nullptr, this, client_}));
|
||||
int64_t callback_index =
|
||||
RedisCallbackManager::instance().add([d, done](const std::string &data) {
|
||||
RAY_CHECK(data.empty());
|
||||
if (done != nullptr) {
|
||||
(done)(d->client, d->id, d->data);
|
||||
}
|
||||
return true;
|
||||
});
|
||||
std::shared_ptr<DataT> &dataT, const WriteCallback &done) {
|
||||
auto callback = [this, id, dataT, done](const std::string &data) {
|
||||
RAY_CHECK(data.empty());
|
||||
if (done != nullptr) {
|
||||
(done)(client_, id, *dataT);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
fbb.ForceDefaults(true);
|
||||
fbb.Finish(Data::Pack(fbb, data.get()));
|
||||
fbb.Finish(Data::Pack(fbb, dataT.get()));
|
||||
return context_->RunAsync("RAY.TABLE_APPEND", id, fbb.GetBufferPointer(), fbb.GetSize(),
|
||||
prefix_, pubsub_channel_, callback_index);
|
||||
prefix_, pubsub_channel_, std::move(callback));
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Log<ID, Data>::AppendAt(const JobID &job_id, const ID &id,
|
||||
std::shared_ptr<DataT> data, const WriteCallback &done,
|
||||
std::shared_ptr<DataT> &dataT, const WriteCallback &done,
|
||||
const WriteCallback &failure, int log_length) {
|
||||
auto d = std::shared_ptr<CallbackData>(
|
||||
new CallbackData({id, data, nullptr, nullptr, this, client_}));
|
||||
int64_t callback_index =
|
||||
RedisCallbackManager::instance().add([d, done, failure](const std::string &data) {
|
||||
if (data.empty()) {
|
||||
if (done != nullptr) {
|
||||
(done)(d->client, d->id, d->data);
|
||||
}
|
||||
} else {
|
||||
if (failure != nullptr) {
|
||||
(failure)(d->client, d->id, d->data);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
});
|
||||
auto callback = [this, id, dataT, done, failure](const std::string &data) {
|
||||
if (data.empty()) {
|
||||
if (done != nullptr) {
|
||||
(done)(client_, id, *dataT);
|
||||
}
|
||||
} else {
|
||||
if (failure != nullptr) {
|
||||
(failure)(client_, id, *dataT);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
fbb.ForceDefaults(true);
|
||||
fbb.Finish(Data::Pack(fbb, data.get()));
|
||||
fbb.Finish(Data::Pack(fbb, dataT.get()));
|
||||
return context_->RunAsync("RAY.TABLE_APPEND", id, fbb.GetBufferPointer(), fbb.GetSize(),
|
||||
prefix_, pubsub_channel_, callback_index, log_length);
|
||||
prefix_, pubsub_channel_, std::move(callback), log_length);
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Log<ID, Data>::Lookup(const JobID &job_id, const ID &id, const Callback &lookup) {
|
||||
auto d = std::shared_ptr<CallbackData>(
|
||||
new CallbackData({id, nullptr, lookup, nullptr, this, client_}));
|
||||
int64_t callback_index =
|
||||
RedisCallbackManager::instance().add([d](const std::string &data) {
|
||||
if (d->callback != nullptr) {
|
||||
std::vector<DataT> results;
|
||||
if (!data.empty()) {
|
||||
auto root = flatbuffers::GetRoot<GcsTableEntry>(data.data());
|
||||
RAY_CHECK(from_flatbuf(*root->id()) == d->id);
|
||||
for (size_t i = 0; i < root->entries()->size(); i++) {
|
||||
DataT result;
|
||||
auto data_root =
|
||||
flatbuffers::GetRoot<Data>(root->entries()->Get(i)->data());
|
||||
data_root->UnPackTo(&result);
|
||||
results.emplace_back(std::move(result));
|
||||
}
|
||||
}
|
||||
(d->callback)(d->client, d->id, results);
|
||||
auto callback = [this, id, lookup](const std::string &data) {
|
||||
if (lookup != nullptr) {
|
||||
std::vector<DataT> results;
|
||||
if (!data.empty()) {
|
||||
auto root = flatbuffers::GetRoot<GcsTableEntry>(data.data());
|
||||
RAY_CHECK(from_flatbuf(*root->id()) == id);
|
||||
for (size_t i = 0; i < root->entries()->size(); i++) {
|
||||
DataT result;
|
||||
auto data_root = flatbuffers::GetRoot<Data>(root->entries()->Get(i)->data());
|
||||
data_root->UnPackTo(&result);
|
||||
results.emplace_back(std::move(result));
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
lookup(client_, id, results);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
std::vector<uint8_t> nil;
|
||||
return context_->RunAsync("RAY.TABLE_LOOKUP", id, nil.data(), nil.size(), prefix_,
|
||||
pubsub_channel_, callback_index);
|
||||
pubsub_channel_, std::move(callback));
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
@@ -87,42 +77,38 @@ Status Log<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
|
||||
const SubscriptionCallback &done) {
|
||||
RAY_CHECK(subscribe_callback_index_ == -1)
|
||||
<< "Client called Subscribe twice on the same table";
|
||||
auto d = std::shared_ptr<CallbackData>(
|
||||
new CallbackData({client_id, nullptr, subscribe, done, this, client_}));
|
||||
int64_t callback_index =
|
||||
RedisCallbackManager::instance().add([d](const std::string &data) {
|
||||
if (data.empty()) {
|
||||
// No notification data is provided. This is the callback for the
|
||||
// initial subscription request.
|
||||
if (d->subscription_callback != nullptr) {
|
||||
(d->subscription_callback)(d->client);
|
||||
}
|
||||
} else {
|
||||
// Data is provided. This is the callback for a message.
|
||||
if (d->callback != nullptr) {
|
||||
// Parse the notification.
|
||||
auto root = flatbuffers::GetRoot<GcsTableEntry>(data.data());
|
||||
ID id = UniqueID::nil();
|
||||
if (root->id()->size() > 0) {
|
||||
id = from_flatbuf(*root->id());
|
||||
}
|
||||
std::vector<DataT> results;
|
||||
for (size_t i = 0; i < root->entries()->size(); i++) {
|
||||
DataT result;
|
||||
auto data_root =
|
||||
flatbuffers::GetRoot<Data>(root->entries()->Get(i)->data());
|
||||
data_root->UnPackTo(&result);
|
||||
results.emplace_back(std::move(result));
|
||||
}
|
||||
(d->callback)(d->client, id, results);
|
||||
}
|
||||
auto callback = [this, subscribe, done](const std::string &data) {
|
||||
if (data.empty()) {
|
||||
// No notification data is provided. This is the callback for the
|
||||
// initial subscription request.
|
||||
if (done != nullptr) {
|
||||
done(client_);
|
||||
}
|
||||
} else {
|
||||
// Data is provided. This is the callback for a message.
|
||||
if (subscribe != nullptr) {
|
||||
// Parse the notification.
|
||||
auto root = flatbuffers::GetRoot<GcsTableEntry>(data.data());
|
||||
ID id = UniqueID::nil();
|
||||
if (root->id()->size() > 0) {
|
||||
id = from_flatbuf(*root->id());
|
||||
}
|
||||
// We do not delete the callback after calling it since there may be
|
||||
// more subscription messages.
|
||||
return false;
|
||||
});
|
||||
subscribe_callback_index_ = callback_index;
|
||||
return context_->SubscribeAsync(client_id, pubsub_channel_, callback_index);
|
||||
std::vector<DataT> results;
|
||||
for (size_t i = 0; i < root->entries()->size(); i++) {
|
||||
DataT result;
|
||||
auto data_root = flatbuffers::GetRoot<Data>(root->entries()->Get(i)->data());
|
||||
data_root->UnPackTo(&result);
|
||||
results.emplace_back(std::move(result));
|
||||
}
|
||||
subscribe(client_, id, results);
|
||||
}
|
||||
}
|
||||
// We do not delete the callback after calling it since there may be
|
||||
// more subscription messages.
|
||||
return false;
|
||||
};
|
||||
subscribe_callback_index_ = 1;
|
||||
return context_->SubscribeAsync(client_id, pubsub_channel_, std::move(callback));
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
@@ -131,8 +117,7 @@ Status Log<ID, Data>::RequestNotifications(const JobID &job_id, const ID &id,
|
||||
RAY_CHECK(subscribe_callback_index_ >= 0)
|
||||
<< "Client requested notifications on a key before Subscribe completed";
|
||||
return context_->RunAsync("RAY.TABLE_REQUEST_NOTIFICATIONS", id, client_id.data(),
|
||||
client_id.size(), prefix_, pubsub_channel_,
|
||||
/*callback_index=*/-1);
|
||||
client_id.size(), prefix_, pubsub_channel_, nullptr);
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
@@ -141,27 +126,23 @@ Status Log<ID, Data>::CancelNotifications(const JobID &job_id, const ID &id,
|
||||
RAY_CHECK(subscribe_callback_index_ >= 0)
|
||||
<< "Client canceled notifications on a key before Subscribe completed";
|
||||
return context_->RunAsync("RAY.TABLE_CANCEL_NOTIFICATIONS", id, client_id.data(),
|
||||
client_id.size(), prefix_, pubsub_channel_,
|
||||
/*callback_index=*/-1);
|
||||
client_id.size(), prefix_, pubsub_channel_, nullptr);
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Table<ID, Data>::Add(const JobID &job_id, const ID &id,
|
||||
std::shared_ptr<DataT> data, const WriteCallback &done) {
|
||||
auto d = std::shared_ptr<CallbackData>(
|
||||
new CallbackData({id, data, nullptr, nullptr, this, client_}));
|
||||
int64_t callback_index =
|
||||
RedisCallbackManager::instance().add([d, done](const std::string &data) {
|
||||
if (done != nullptr) {
|
||||
(done)(d->client, d->id, d->data);
|
||||
}
|
||||
return true;
|
||||
});
|
||||
std::shared_ptr<DataT> &dataT, const WriteCallback &done) {
|
||||
auto callback = [this, id, dataT, done](const std::string &data) {
|
||||
if (done != nullptr) {
|
||||
(done)(client_, id, *dataT);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
fbb.ForceDefaults(true);
|
||||
fbb.Finish(Data::Pack(fbb, data.get()));
|
||||
fbb.Finish(Data::Pack(fbb, dataT.get()));
|
||||
return context_->RunAsync("RAY.TABLE_ADD", id, fbb.GetBufferPointer(), fbb.GetSize(),
|
||||
prefix_, pubsub_channel_, callback_index);
|
||||
prefix_, pubsub_channel_, std::move(callback));
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
@@ -259,9 +240,8 @@ void ClientTable::HandleNotification(AsyncGcsClient *client,
|
||||
}
|
||||
}
|
||||
|
||||
void ClientTable::HandleConnected(AsyncGcsClient *client,
|
||||
const std::shared_ptr<ClientTableDataT> data) {
|
||||
auto connected_client_id = ClientID::from_binary(data->client_id);
|
||||
void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientTableDataT &data) {
|
||||
auto connected_client_id = ClientID::from_binary(data.client_id);
|
||||
RAY_CHECK(client_id_ == connected_client_id) << connected_client_id << " "
|
||||
<< client_id_;
|
||||
}
|
||||
@@ -282,7 +262,7 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) {
|
||||
// Callback to handle our own successful connection once we've added
|
||||
// ourselves.
|
||||
auto add_callback = [this](AsyncGcsClient *client, const UniqueID &log_key,
|
||||
std::shared_ptr<ClientTableDataT> data) {
|
||||
const ClientTableDataT &data) {
|
||||
RAY_CHECK(log_key == client_log_key_);
|
||||
HandleConnected(client, data);
|
||||
|
||||
@@ -311,7 +291,7 @@ Status ClientTable::Disconnect() {
|
||||
auto data = std::make_shared<ClientTableDataT>(local_client_);
|
||||
data->is_insertion = false;
|
||||
auto add_callback = [this](AsyncGcsClient *client, const ClientID &id,
|
||||
std::shared_ptr<ClientTableDataT> data) {
|
||||
const ClientTableDataT &data) {
|
||||
HandleConnected(client, data);
|
||||
RAY_CHECK_OK(CancelNotifications(JobID::nil(), client_log_key_, id));
|
||||
};
|
||||
|
||||
+15
-28
@@ -57,8 +57,8 @@ class Log : virtual public PubsubInterface<ID> {
|
||||
using Callback = std::function<void(AsyncGcsClient *client, const ID &id,
|
||||
const std::vector<DataT> &data)>;
|
||||
/// The callback to call when a write to a key succeeds.
|
||||
using WriteCallback = std::function<void(AsyncGcsClient *client, const ID &id,
|
||||
std::shared_ptr<DataT> data)>;
|
||||
using WriteCallback =
|
||||
std::function<void(AsyncGcsClient *client, const ID &id, const DataT &data)>;
|
||||
/// The callback to call when a SUBSCRIBE call completes and we are ready to
|
||||
/// request and receive notifications.
|
||||
using SubscriptionCallback = std::function<void(AsyncGcsClient *client)>;
|
||||
@@ -89,7 +89,7 @@ class Log : virtual public PubsubInterface<ID> {
|
||||
/// \param done Callback that is called once the data has been written to the
|
||||
/// GCS.
|
||||
/// \return Status
|
||||
Status Append(const JobID &job_id, const ID &id, std::shared_ptr<DataT> data,
|
||||
Status Append(const JobID &job_id, const ID &id, std::shared_ptr<DataT> &data,
|
||||
const WriteCallback &done);
|
||||
|
||||
/// Append a log entry to a key if and only if the log has the given number
|
||||
@@ -105,7 +105,7 @@ class Log : virtual public PubsubInterface<ID> {
|
||||
/// \param log_length The number of entries that the log must have for the
|
||||
/// append to succeed.
|
||||
/// \return Status
|
||||
Status AppendAt(const JobID &job_id, const ID &id, std::shared_ptr<DataT> data,
|
||||
Status AppendAt(const JobID &job_id, const ID &id, std::shared_ptr<DataT> &data,
|
||||
const WriteCallback &done, const WriteCallback &failure,
|
||||
int log_length);
|
||||
|
||||
@@ -187,7 +187,7 @@ class TableInterface {
|
||||
public:
|
||||
using DataT = typename Data::NativeTableType;
|
||||
using WriteCallback = typename Log<ID, Data>::WriteCallback;
|
||||
virtual Status Add(const JobID &job_id, const ID &task_id, std::shared_ptr<DataT> data,
|
||||
virtual Status Add(const JobID &job_id, const ID &task_id, std::shared_ptr<DataT> &data,
|
||||
const WriteCallback &done) = 0;
|
||||
virtual ~TableInterface(){};
|
||||
};
|
||||
@@ -212,17 +212,6 @@ class Table : private Log<ID, Data>,
|
||||
/// request and receive notifications.
|
||||
using SubscriptionCallback = typename Log<ID, Data>::SubscriptionCallback;
|
||||
|
||||
struct CallbackData {
|
||||
ID id;
|
||||
std::shared_ptr<DataT> data;
|
||||
Callback callback;
|
||||
// An optional callback to call for subscription operations, where the
|
||||
// first message is a notification of subscription success.
|
||||
SubscriptionCallback subscription_callback;
|
||||
Log<ID, Data> *log;
|
||||
AsyncGcsClient *client;
|
||||
};
|
||||
|
||||
Table(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
|
||||
: Log<ID, Data>(context, client) {}
|
||||
|
||||
@@ -237,7 +226,7 @@ class Table : private Log<ID, Data>,
|
||||
/// \param done Callback that is called once the data has been written to the
|
||||
/// GCS.
|
||||
/// \return Status
|
||||
Status Add(const JobID &job_id, const ID &id, std::shared_ptr<DataT> data,
|
||||
Status Add(const JobID &job_id, const ID &id, std::shared_ptr<DataT> &data,
|
||||
const WriteCallback &done);
|
||||
|
||||
/// Lookup an entry asynchronously.
|
||||
@@ -358,19 +347,18 @@ class TaskTable : public Table<TaskID, TaskTableData> {
|
||||
Status TestAndUpdate(const JobID &job_id, const TaskID &id,
|
||||
std::shared_ptr<TaskTableTestAndUpdateT> data,
|
||||
const TestAndUpdateCallback &callback) {
|
||||
int64_t callback_index = RedisCallbackManager::instance().add(
|
||||
[this, callback, id](const std::string &data) {
|
||||
auto result = std::make_shared<TaskTableDataT>();
|
||||
auto root = flatbuffers::GetRoot<TaskTableData>(data.data());
|
||||
root->UnPackTo(result.get());
|
||||
callback(client_, id, *result, root->updated());
|
||||
return true;
|
||||
});
|
||||
auto redisCallback = [this, callback, id](const std::string &data) {
|
||||
auto result = std::make_shared<TaskTableDataT>();
|
||||
auto root = flatbuffers::GetRoot<TaskTableData>(data.data());
|
||||
root->UnPackTo(result.get());
|
||||
callback(client_, id, *result, root->updated());
|
||||
return true;
|
||||
};
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
fbb.Finish(TaskTableTestAndUpdate::Pack(fbb, data.get()));
|
||||
RAY_RETURN_NOT_OK(context_->RunAsync("RAY.TABLE_TEST_AND_UPDATE", id,
|
||||
fbb.GetBufferPointer(), fbb.GetSize(), prefix_,
|
||||
pubsub_channel_, callback_index));
|
||||
pubsub_channel_, redisCallback));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@@ -499,8 +487,7 @@ class ClientTable : private Log<UniqueID, ClientTableData> {
|
||||
/// Handle a client table notification.
|
||||
void HandleNotification(AsyncGcsClient *client, const ClientTableDataT ¬ifications);
|
||||
/// Handle this client's successful connection to the GCS.
|
||||
void HandleConnected(AsyncGcsClient *client,
|
||||
const std::shared_ptr<ClientTableDataT> client_data);
|
||||
void HandleConnected(AsyncGcsClient *client, const ClientTableDataT &client_data);
|
||||
|
||||
/// The key at which the log of client information is stored. This key must
|
||||
/// be kept the same across all instances of the ClientTable, so that all
|
||||
|
||||
@@ -46,9 +46,9 @@ Status TaskTableAdd(AsyncGcsClient *gcs_client, Task *task) {
|
||||
TaskSpec *spec = execution_spec.Spec();
|
||||
auto data = MakeTaskTableData(execution_spec, Task_local_scheduler(task),
|
||||
static_cast<SchedulingState>(Task_state(task)));
|
||||
return gcs_client->task_table().Add(ray::JobID::nil(), TaskSpec_task_id(spec), data,
|
||||
[](gcs::AsyncGcsClient *client, const TaskID &id,
|
||||
std::shared_ptr<TaskTableDataT> data) {});
|
||||
return gcs_client->task_table().Add(
|
||||
ray::JobID::nil(), TaskSpec_task_id(spec), data,
|
||||
[](gcs::AsyncGcsClient *client, const TaskID &id, const TaskTableDataT &data) {});
|
||||
}
|
||||
|
||||
// TODO(pcm): This is a helper method that should go away once we get rid of
|
||||
|
||||
@@ -53,7 +53,7 @@ ray::Status ConnectionPool::GetSender(ConnectionType type, const ClientID &clien
|
||||
}
|
||||
|
||||
ray::Status ConnectionPool::ReleaseSender(ConnectionType type,
|
||||
std::shared_ptr<SenderConnection> conn) {
|
||||
std::shared_ptr<SenderConnection> &conn) {
|
||||
std::unique_lock<std::mutex> guard(connection_mutex);
|
||||
SenderMapType &conn_map = (type == ConnectionType::MESSAGE)
|
||||
? available_message_send_connections_
|
||||
@@ -64,20 +64,21 @@ ray::Status ConnectionPool::ReleaseSender(ConnectionType type,
|
||||
|
||||
void ConnectionPool::Add(ReceiverMapType &conn_map, const ClientID &client_id,
|
||||
std::shared_ptr<TcpClientConnection> conn) {
|
||||
conn_map[client_id].push_back(conn);
|
||||
conn_map[client_id].push_back(std::move(conn));
|
||||
}
|
||||
|
||||
void ConnectionPool::Add(SenderMapType &conn_map, const ClientID &client_id,
|
||||
std::shared_ptr<SenderConnection> conn) {
|
||||
conn_map[client_id].push_back(conn);
|
||||
conn_map[client_id].push_back(std::move(conn));
|
||||
}
|
||||
|
||||
void ConnectionPool::Remove(ReceiverMapType &conn_map, const ClientID &client_id,
|
||||
std::shared_ptr<TcpClientConnection> conn) {
|
||||
if (conn_map.count(client_id) == 0) {
|
||||
std::shared_ptr<TcpClientConnection> &conn) {
|
||||
auto it = conn_map.find(client_id);
|
||||
if (it == conn_map.end()) {
|
||||
return;
|
||||
}
|
||||
std::vector<std::shared_ptr<TcpClientConnection>> &connections = conn_map[client_id];
|
||||
auto &connections = it->second;
|
||||
int64_t pos =
|
||||
std::find(connections.begin(), connections.end(), conn) - connections.begin();
|
||||
if (pos >= (int64_t)connections.size()) {
|
||||
@@ -87,15 +88,16 @@ void ConnectionPool::Remove(ReceiverMapType &conn_map, const ClientID &client_id
|
||||
}
|
||||
|
||||
uint64_t ConnectionPool::Count(SenderMapType &conn_map, const ClientID &client_id) {
|
||||
if (conn_map.count(client_id) == 0) {
|
||||
auto it = conn_map.find(client_id);
|
||||
if (it == conn_map.end()) {
|
||||
return 0;
|
||||
};
|
||||
return conn_map[client_id].size();
|
||||
}
|
||||
return it->second.size();
|
||||
}
|
||||
|
||||
std::shared_ptr<SenderConnection> ConnectionPool::Borrow(SenderMapType &conn_map,
|
||||
const ClientID &client_id) {
|
||||
std::shared_ptr<SenderConnection> conn = conn_map[client_id].back();
|
||||
std::shared_ptr<SenderConnection> conn = std::move(conn_map[client_id].back());
|
||||
conn_map[client_id].pop_back();
|
||||
RAY_LOG(DEBUG) << "Borrow " << client_id << " " << conn_map[client_id].size();
|
||||
return conn;
|
||||
@@ -103,7 +105,7 @@ std::shared_ptr<SenderConnection> ConnectionPool::Borrow(SenderMapType &conn_map
|
||||
|
||||
void ConnectionPool::Return(SenderMapType &conn_map, const ClientID &client_id,
|
||||
std::shared_ptr<SenderConnection> conn) {
|
||||
conn_map[client_id].push_back(conn);
|
||||
conn_map[client_id].push_back(std::move(conn));
|
||||
RAY_LOG(DEBUG) << "Return " << client_id << " " << conn_map[client_id].size();
|
||||
}
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ class ConnectionPool {
|
||||
/// \param type The type of connection.
|
||||
/// \param conn The actual connection.
|
||||
/// \return Status of invoking this method.
|
||||
ray::Status ReleaseSender(ConnectionType type, std::shared_ptr<SenderConnection> conn);
|
||||
ray::Status ReleaseSender(ConnectionType type, std::shared_ptr<SenderConnection> &conn);
|
||||
|
||||
// TODO(hme): Implement with error handling.
|
||||
/// Remove a sender connection. This is invoked if the connection is no longer
|
||||
@@ -106,7 +106,7 @@ class ConnectionPool {
|
||||
|
||||
/// Removes the given receiver for ClientID from the given map.
|
||||
void Remove(ReceiverMapType &conn_map, const ClientID &client_id,
|
||||
std::shared_ptr<TcpClientConnection> conn);
|
||||
std::shared_ptr<TcpClientConnection> &conn);
|
||||
|
||||
/// Returns the count of sender connections to ClientID.
|
||||
uint64_t Count(SenderMapType &conn_map, const ClientID &client_id);
|
||||
|
||||
@@ -16,8 +16,8 @@ ray::Status ObjectDirectory::ReportObjectAdded(const ObjectID &object_id,
|
||||
data->is_eviction = false;
|
||||
data->object_size = object_info.data_size;
|
||||
ray::Status status = gcs_client_->object_table().Append(
|
||||
job_id, object_id, data, [](gcs::AsyncGcsClient *client, const UniqueID &id,
|
||||
const std::shared_ptr<ObjectTableDataT> data) {
|
||||
job_id, object_id, data,
|
||||
[](gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &data) {
|
||||
// Do nothing.
|
||||
});
|
||||
return status;
|
||||
|
||||
@@ -110,8 +110,8 @@ ray::Status ObjectManager::Pull(const ObjectID &object_id) {
|
||||
}
|
||||
|
||||
void ObjectManager::SchedulePull(const ObjectID &object_id, int wait_ms) {
|
||||
pull_requests_[object_id] = std::shared_ptr<boost::asio::deadline_timer>(
|
||||
new asio::deadline_timer(*main_service_, boost::posix_time::milliseconds(wait_ms)));
|
||||
pull_requests_[object_id] = std::make_shared<boost::asio::deadline_timer>(
|
||||
*main_service_, boost::posix_time::milliseconds(wait_ms));
|
||||
pull_requests_[object_id]->async_wait(
|
||||
[this, object_id](const boost::system::error_code &error_code) {
|
||||
pull_requests_.erase(object_id);
|
||||
@@ -184,7 +184,7 @@ ray::Status ObjectManager::PullEstablishConnection(const ObjectID &object_id,
|
||||
}
|
||||
|
||||
ray::Status ObjectManager::PullSendRequest(const ObjectID &object_id,
|
||||
std::shared_ptr<SenderConnection> conn) {
|
||||
std::shared_ptr<SenderConnection> &conn) {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = object_manager_protocol::CreatePullRequestMessage(
|
||||
fbb, fbb.CreateString(client_id_.binary()), fbb.CreateString(object_id.binary()));
|
||||
@@ -209,7 +209,7 @@ ray::Status ObjectManager::Push(const ObjectID &object_id, const ClientID &clien
|
||||
Status status = object_directory_->GetInformation(
|
||||
client_id,
|
||||
[this, object_id, client_id](const RemoteConnectionInfo &info) {
|
||||
ObjectInfoT object_info = local_objects_[object_id];
|
||||
const ObjectInfoT &object_info = local_objects_[object_id];
|
||||
uint64_t data_size =
|
||||
static_cast<uint64_t>(object_info.data_size + object_info.metadata_size);
|
||||
uint64_t metadata_size = static_cast<uint64_t>(object_info.metadata_size);
|
||||
@@ -251,7 +251,7 @@ void ObjectManager::ExecuteSendObject(const ClientID &client_id,
|
||||
ray::Status ObjectManager::SendObjectHeaders(const ObjectID &object_id,
|
||||
uint64_t data_size, uint64_t metadata_size,
|
||||
uint64_t chunk_index,
|
||||
std::shared_ptr<SenderConnection> conn) {
|
||||
std::shared_ptr<SenderConnection> &conn) {
|
||||
std::pair<const ObjectBufferPool::ChunkInfo &, ray::Status> chunk_status =
|
||||
buffer_pool_.GetChunk(object_id, data_size, metadata_size, chunk_index);
|
||||
ObjectBufferPool::ChunkInfo chunk_info = chunk_status.first;
|
||||
@@ -276,7 +276,7 @@ ray::Status ObjectManager::SendObjectHeaders(const ObjectID &object_id,
|
||||
|
||||
ray::Status ObjectManager::SendObjectData(const ObjectID &object_id,
|
||||
const ObjectBufferPool::ChunkInfo &chunk_info,
|
||||
std::shared_ptr<SenderConnection> conn) {
|
||||
std::shared_ptr<SenderConnection> &conn) {
|
||||
boost::system::error_code ec;
|
||||
std::vector<asio::const_buffer> buffer;
|
||||
buffer.push_back(asio::buffer(chunk_info.data, chunk_info.buffer_length));
|
||||
@@ -328,11 +328,11 @@ std::shared_ptr<SenderConnection> ObjectManager::CreateSenderConnection(
|
||||
return conn;
|
||||
}
|
||||
|
||||
void ObjectManager::ProcessNewClient(std::shared_ptr<TcpClientConnection> conn) {
|
||||
conn->ProcessMessages();
|
||||
void ObjectManager::ProcessNewClient(TcpClientConnection &conn) {
|
||||
conn.ProcessMessages();
|
||||
}
|
||||
|
||||
void ObjectManager::ProcessClientMessage(std::shared_ptr<TcpClientConnection> conn,
|
||||
void ObjectManager::ProcessClientMessage(std::shared_ptr<TcpClientConnection> &conn,
|
||||
int64_t message_type, const uint8_t *message) {
|
||||
switch (message_type) {
|
||||
case object_manager_protocol::MessageType_PushRequest: {
|
||||
@@ -389,7 +389,7 @@ void ObjectManager::ReceivePullRequest(std::shared_ptr<TcpClientConnection> &con
|
||||
conn->ProcessMessages();
|
||||
}
|
||||
|
||||
void ObjectManager::ReceivePushRequest(std::shared_ptr<TcpClientConnection> conn,
|
||||
void ObjectManager::ReceivePushRequest(std::shared_ptr<TcpClientConnection> &conn,
|
||||
const uint8_t *message) {
|
||||
// Serialize.
|
||||
auto object_header =
|
||||
@@ -400,14 +400,14 @@ void ObjectManager::ReceivePushRequest(std::shared_ptr<TcpClientConnection> conn
|
||||
uint64_t metadata_size = object_header->metadata_size();
|
||||
receive_service_.post([this, object_id, data_size, metadata_size, chunk_index, conn]() {
|
||||
ExecuteReceiveObject(conn->GetClientID(), object_id, data_size, metadata_size,
|
||||
chunk_index, conn);
|
||||
chunk_index, *conn);
|
||||
});
|
||||
}
|
||||
|
||||
void ObjectManager::ExecuteReceiveObject(const ClientID &client_id,
|
||||
const ObjectID &object_id, uint64_t data_size,
|
||||
uint64_t metadata_size, uint64_t chunk_index,
|
||||
std::shared_ptr<TcpClientConnection> conn) {
|
||||
TcpClientConnection &conn) {
|
||||
RAY_LOG(DEBUG) << "ExecuteReceiveObject " << client_id << " " << object_id << " "
|
||||
<< chunk_index;
|
||||
|
||||
@@ -419,7 +419,7 @@ void ObjectManager::ExecuteReceiveObject(const ClientID &client_id,
|
||||
std::vector<boost::asio::mutable_buffer> buffer;
|
||||
buffer.push_back(asio::buffer(chunk_info.data, chunk_info.buffer_length));
|
||||
boost::system::error_code ec;
|
||||
conn->ReadBuffer(buffer, ec);
|
||||
conn.ReadBuffer(buffer, ec);
|
||||
if (ec.value() == 0) {
|
||||
buffer_pool_.SealChunk(object_id, chunk_index);
|
||||
} else {
|
||||
@@ -435,13 +435,13 @@ void ObjectManager::ExecuteReceiveObject(const ClientID &client_id,
|
||||
std::vector<boost::asio::mutable_buffer> buffer;
|
||||
buffer.push_back(asio::buffer(mutable_vec, buffer_length));
|
||||
boost::system::error_code ec;
|
||||
conn->ReadBuffer(buffer, ec);
|
||||
conn.ReadBuffer(buffer, ec);
|
||||
if (ec.value() != 0) {
|
||||
RAY_LOG(ERROR) << ec.message();
|
||||
}
|
||||
// TODO(hme): If the object isn't local, create a pull request for this chunk.
|
||||
}
|
||||
conn->ProcessMessages();
|
||||
conn.ProcessMessages();
|
||||
RAY_LOG(DEBUG) << "ReceiveCompleted " << client_id_ << " " << object_id << " "
|
||||
<< "/" << config_.max_receives;
|
||||
}
|
||||
|
||||
@@ -110,7 +110,7 @@ class ObjectManager {
|
||||
///
|
||||
/// \param conn The connection.
|
||||
/// \return Status of whether the connection was successfully established.
|
||||
void ProcessNewClient(std::shared_ptr<TcpClientConnection> conn);
|
||||
void ProcessNewClient(TcpClientConnection &conn);
|
||||
|
||||
/// Process messages sent from other nodes. We only establish
|
||||
/// transfer connections using this method; all other transfer communication
|
||||
@@ -119,7 +119,7 @@ class ObjectManager {
|
||||
/// \param conn The connection.
|
||||
/// \param message_type The message type.
|
||||
/// \param message A pointer set to the beginning of the message.
|
||||
void ProcessClientMessage(std::shared_ptr<TcpClientConnection> conn,
|
||||
void ProcessClientMessage(std::shared_ptr<TcpClientConnection> &conn,
|
||||
int64_t message_type, const uint8_t *message);
|
||||
|
||||
/// Cancels all requests (Push/Pull) associated with the given ObjectID.
|
||||
@@ -226,7 +226,7 @@ class ObjectManager {
|
||||
/// Synchronously send a pull request via remote object manager connection.
|
||||
/// Executes on main_service_ thread.
|
||||
ray::Status PullSendRequest(const ObjectID &object_id,
|
||||
std::shared_ptr<SenderConnection> conn);
|
||||
std::shared_ptr<SenderConnection> &conn);
|
||||
|
||||
std::shared_ptr<SenderConnection> CreateSenderConnection(
|
||||
ConnectionPool::ConnectionType type, RemoteConnectionInfo info);
|
||||
@@ -241,23 +241,22 @@ class ObjectManager {
|
||||
/// Executes on send_service_ thread pool.
|
||||
ray::Status SendObjectHeaders(const ObjectID &object_id, uint64_t data_size,
|
||||
uint64_t metadata_size, uint64_t chunk_index,
|
||||
std::shared_ptr<SenderConnection> conn);
|
||||
std::shared_ptr<SenderConnection> &conn);
|
||||
|
||||
/// This method initiates the actual object transfer.
|
||||
/// Executes on send_service_ thread pool.
|
||||
ray::Status SendObjectData(const ObjectID &object_id,
|
||||
const ObjectBufferPool::ChunkInfo &chunk_info,
|
||||
std::shared_ptr<SenderConnection> conn);
|
||||
std::shared_ptr<SenderConnection> &conn);
|
||||
|
||||
/// Invoked when a remote object manager pushes an object to this object manager.
|
||||
/// This will invoke the object receive on the receive_service_ thread pool.
|
||||
void ReceivePushRequest(std::shared_ptr<TcpClientConnection> conn,
|
||||
void ReceivePushRequest(std::shared_ptr<TcpClientConnection> &conn,
|
||||
const uint8_t *message);
|
||||
/// Execute a receive on the receive_service_ thread pool.
|
||||
void ExecuteReceiveObject(const ClientID &client_id, const ObjectID &object_id,
|
||||
uint64_t data_size, uint64_t metadata_size,
|
||||
uint64_t chunk_index,
|
||||
std::shared_ptr<TcpClientConnection> conn);
|
||||
uint64_t chunk_index, TcpClientConnection &conn);
|
||||
|
||||
/// Handles receiving a pull request message.
|
||||
void ReceivePullRequest(std::shared_ptr<TcpClientConnection> &conn,
|
||||
|
||||
@@ -11,7 +11,7 @@ std::shared_ptr<SenderConnection> SenderConnection::Create(
|
||||
RAY_CHECK_OK(TcpConnect(socket, ip, port));
|
||||
std::shared_ptr<TcpServerConnection> conn =
|
||||
std::make_shared<TcpServerConnection>(std::move(socket));
|
||||
return std::make_shared<SenderConnection>(conn, client_id);
|
||||
return std::make_shared<SenderConnection>(std::move(conn), client_id);
|
||||
};
|
||||
|
||||
SenderConnection::SenderConnection(std::shared_ptr<TcpServerConnection> conn,
|
||||
|
||||
@@ -65,9 +65,7 @@ class MockServer {
|
||||
|
||||
void HandleAcceptObjectManager(const boost::system::error_code &error) {
|
||||
ClientHandler<boost::asio::ip::tcp> client_handler =
|
||||
[this](std::shared_ptr<TcpClientConnection> client) {
|
||||
object_manager_.ProcessNewClient(client);
|
||||
};
|
||||
[this](TcpClientConnection &client) { object_manager_.ProcessNewClient(client); };
|
||||
MessageHandler<boost::asio::ip::tcp> message_handler = [this](
|
||||
std::shared_ptr<TcpClientConnection> client, int64_t message_type,
|
||||
const uint8_t *message) {
|
||||
|
||||
@@ -56,9 +56,7 @@ class MockServer {
|
||||
|
||||
void HandleAcceptObjectManager(const boost::system::error_code &error) {
|
||||
ClientHandler<boost::asio::ip::tcp> client_handler =
|
||||
[this](std::shared_ptr<TcpClientConnection> client) {
|
||||
object_manager_.ProcessNewClient(client);
|
||||
};
|
||||
[this](TcpClientConnection &client) { object_manager_.ProcessNewClient(client); };
|
||||
MessageHandler<boost::asio::ip::tcp> message_handler = [this](
|
||||
std::shared_ptr<TcpClientConnection> client, int64_t message_type,
|
||||
const uint8_t *message) {
|
||||
|
||||
@@ -258,8 +258,9 @@ Status LineageCache::Flush() {
|
||||
|
||||
// Write back all ready tasks whose arguments have been committed to the GCS.
|
||||
gcs::raylet::TaskTable::WriteCallback task_callback = [this](
|
||||
ray::gcs::AsyncGcsClient *client, const TaskID &id,
|
||||
const std::shared_ptr<protocol::TaskT> data) { HandleEntryCommitted(id); };
|
||||
ray::gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &data) {
|
||||
HandleEntryCommitted(id);
|
||||
};
|
||||
for (const auto &ready_task_id : ready_task_ids) {
|
||||
auto task = lineage_.GetEntry(ready_task_id);
|
||||
// TODO(swang): Make this better...
|
||||
|
||||
@@ -23,7 +23,7 @@ class MockGcs : public gcs::TableInterface<TaskID, protocol::Task>,
|
||||
}
|
||||
|
||||
Status Add(const JobID &job_id, const TaskID &task_id,
|
||||
std::shared_ptr<protocol::TaskT> task_data,
|
||||
std::shared_ptr<protocol::TaskT> &task_data,
|
||||
const gcs::TableInterface<TaskID, protocol::Task>::WriteCallback &done) {
|
||||
task_table_[task_id] = task_data;
|
||||
callbacks_.push_back(
|
||||
@@ -38,7 +38,7 @@ class MockGcs : public gcs::TableInterface<TaskID, protocol::Task>,
|
||||
bool send_notification = (subscribed_tasks_.count(task_id) == 1);
|
||||
auto callback = [this, send_notification](ray::gcs::AsyncGcsClient *client,
|
||||
const TaskID &task_id,
|
||||
std::shared_ptr<protocol::TaskT> data) {
|
||||
const protocol::TaskT &data) {
|
||||
if (send_notification) {
|
||||
notification_callback_(client, task_id, data);
|
||||
}
|
||||
@@ -63,7 +63,7 @@ class MockGcs : public gcs::TableInterface<TaskID, protocol::Task>,
|
||||
|
||||
void Flush() {
|
||||
for (const auto &callback : callbacks_) {
|
||||
callback.first(NULL, callback.second, task_table_[callback.second]);
|
||||
callback.first(NULL, callback.second, *task_table_[callback.second]);
|
||||
}
|
||||
callbacks_.clear();
|
||||
}
|
||||
@@ -86,7 +86,7 @@ class LineageCacheTest : public ::testing::Test {
|
||||
LineageCacheTest()
|
||||
: mock_gcs_(), lineage_cache_(ClientID::from_random(), mock_gcs_, mock_gcs_) {
|
||||
mock_gcs_.Subscribe([this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id,
|
||||
std::shared_ptr<ray::protocol::TaskT> data) {
|
||||
const ray::protocol::TaskT &data) {
|
||||
lineage_cache_.HandleEntryCommitted(task_id);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -156,7 +156,7 @@ void NodeManager::Heartbeat() {
|
||||
ray::Status status = heartbeat_table.Add(
|
||||
UniqueID::nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data,
|
||||
[](ray::gcs::AsyncGcsClient *client, const ClientID &id,
|
||||
std::shared_ptr<HeartbeatTableDataT> data) {
|
||||
const HeartbeatTableDataT &data) {
|
||||
RAY_LOG(DEBUG) << "[HEARTBEAT] heartbeat sent callback";
|
||||
});
|
||||
|
||||
@@ -279,9 +279,9 @@ void NodeManager::HandleActorCreation(const ActorID &actor_id,
|
||||
}
|
||||
}
|
||||
|
||||
void NodeManager::ProcessNewClient(std::shared_ptr<LocalClientConnection> client) {
|
||||
void NodeManager::ProcessNewClient(LocalClientConnection &client) {
|
||||
// The new client is a worker, so begin listening for messages.
|
||||
client->ProcessMessages();
|
||||
client.ProcessMessages();
|
||||
}
|
||||
|
||||
void NodeManager::DispatchTasks() {
|
||||
@@ -309,9 +309,9 @@ void NodeManager::DispatchTasks() {
|
||||
}
|
||||
}
|
||||
|
||||
void NodeManager::ProcessClientMessage(std::shared_ptr<LocalClientConnection> client,
|
||||
int64_t message_type,
|
||||
const uint8_t *message_data) {
|
||||
void NodeManager::ProcessClientMessage(
|
||||
const std::shared_ptr<LocalClientConnection> &client, int64_t message_type,
|
||||
const uint8_t *message_data) {
|
||||
RAY_LOG(DEBUG) << "Message of type " << message_type;
|
||||
|
||||
switch (message_type) {
|
||||
@@ -319,7 +319,7 @@ void NodeManager::ProcessClientMessage(std::shared_ptr<LocalClientConnection> cl
|
||||
auto message = flatbuffers::GetRoot<protocol::RegisterClientRequest>(message_data);
|
||||
if (message->is_worker()) {
|
||||
// Create a new worker from the registration request.
|
||||
std::shared_ptr<Worker> worker(new Worker(message->worker_pid(), client));
|
||||
auto worker = std::make_shared<Worker>(message->worker_pid(), client);
|
||||
// Register the new worker.
|
||||
worker_pool_.RegisterWorker(std::move(worker));
|
||||
}
|
||||
@@ -329,10 +329,10 @@ void NodeManager::ProcessClientMessage(std::shared_ptr<LocalClientConnection> cl
|
||||
RAY_CHECK(worker);
|
||||
// If the worker was assigned a task, mark it as finished.
|
||||
if (!worker->GetAssignedTaskId().is_nil()) {
|
||||
FinishAssignedTask(worker);
|
||||
FinishAssignedTask(*worker);
|
||||
}
|
||||
// Return the worker to the idle pool.
|
||||
worker_pool_.PushWorker(worker);
|
||||
worker_pool_.PushWorker(std::move(worker));
|
||||
// Call task dispatch to assign work to the new worker.
|
||||
DispatchTasks();
|
||||
|
||||
@@ -436,14 +436,13 @@ void NodeManager::ProcessClientMessage(std::shared_ptr<LocalClientConnection> cl
|
||||
client->ProcessMessages();
|
||||
}
|
||||
|
||||
void NodeManager::ProcessNewNodeManager(
|
||||
std::shared_ptr<TcpClientConnection> node_manager_client) {
|
||||
node_manager_client->ProcessMessages();
|
||||
void NodeManager::ProcessNewNodeManager(TcpClientConnection &node_manager_client) {
|
||||
node_manager_client.ProcessMessages();
|
||||
}
|
||||
|
||||
void NodeManager::ProcessNodeManagerMessage(
|
||||
std::shared_ptr<TcpClientConnection> node_manager_client, int64_t message_type,
|
||||
const uint8_t *message_data) {
|
||||
void NodeManager::ProcessNodeManagerMessage(TcpClientConnection &node_manager_client,
|
||||
int64_t message_type,
|
||||
const uint8_t *message_data) {
|
||||
switch (message_type) {
|
||||
case protocol::MessageType_ForwardTaskRequest: {
|
||||
auto message = flatbuffers::GetRoot<protocol::ForwardTaskRequest>(message_data);
|
||||
@@ -458,7 +457,7 @@ void NodeManager::ProcessNodeManagerMessage(
|
||||
default:
|
||||
RAY_LOG(FATAL) << "Received unexpected message type " << message_type;
|
||||
}
|
||||
node_manager_client->ProcessMessages();
|
||||
node_manager_client.ProcessMessages();
|
||||
}
|
||||
|
||||
void NodeManager::HandleWaitingTaskReady(const TaskID &task_id) {
|
||||
@@ -639,8 +638,8 @@ void NodeManager::AssignTask(Task &task) {
|
||||
}
|
||||
}
|
||||
|
||||
void NodeManager::FinishAssignedTask(std::shared_ptr<Worker> worker) {
|
||||
TaskID task_id = worker->GetAssignedTaskId();
|
||||
void NodeManager::FinishAssignedTask(Worker &worker) {
|
||||
TaskID task_id = worker.GetAssignedTaskId();
|
||||
RAY_LOG(DEBUG) << "Finished task " << task_id;
|
||||
auto tasks = local_queues_.RemoveTasks({task_id});
|
||||
auto task = *tasks.begin();
|
||||
@@ -648,7 +647,7 @@ void NodeManager::FinishAssignedTask(std::shared_ptr<Worker> worker) {
|
||||
if (task.GetTaskSpecification().IsActorCreationTask()) {
|
||||
// If this was an actor creation task, then convert the worker to an actor.
|
||||
auto actor_id = task.GetTaskSpecification().ActorCreationId();
|
||||
worker->AssignActorId(actor_id);
|
||||
worker.AssignActorId(actor_id);
|
||||
|
||||
// Publish the actor creation event to all other nodes so that methods for
|
||||
// the actor will be forwarded directly to this node.
|
||||
@@ -684,7 +683,7 @@ void NodeManager::FinishAssignedTask(std::shared_ptr<Worker> worker) {
|
||||
}
|
||||
|
||||
// Unset the worker's assigned task.
|
||||
worker->AssignTaskId(TaskID::nil());
|
||||
worker.AssignTaskId(TaskID::nil());
|
||||
}
|
||||
|
||||
void NodeManager::ResubmitTask(const TaskID &task_id) {
|
||||
|
||||
@@ -37,7 +37,7 @@ class NodeManager {
|
||||
std::shared_ptr<gcs::AsyncGcsClient> gcs_client);
|
||||
|
||||
/// Process a new client connection.
|
||||
void ProcessNewClient(std::shared_ptr<LocalClientConnection> client);
|
||||
void ProcessNewClient(LocalClientConnection &client);
|
||||
|
||||
/// Process a message from a client. This method is responsible for
|
||||
/// explicitly listening for more messages from the client if the client is
|
||||
@@ -46,12 +46,12 @@ class NodeManager {
|
||||
/// \param client The client that sent the message.
|
||||
/// \param message_type The message type (e.g., a flatbuffer enum).
|
||||
/// \param message A pointer to the message data.
|
||||
void ProcessClientMessage(std::shared_ptr<LocalClientConnection> client,
|
||||
void ProcessClientMessage(const std::shared_ptr<LocalClientConnection> &client,
|
||||
int64_t message_type, const uint8_t *message);
|
||||
|
||||
void ProcessNewNodeManager(std::shared_ptr<TcpClientConnection> node_manager_client);
|
||||
void ProcessNewNodeManager(TcpClientConnection &node_manager_client);
|
||||
|
||||
void ProcessNodeManagerMessage(std::shared_ptr<TcpClientConnection> node_manager_client,
|
||||
void ProcessNodeManagerMessage(TcpClientConnection &node_manager_client,
|
||||
int64_t message_type, const uint8_t *message);
|
||||
|
||||
ray::Status RegisterGcs();
|
||||
@@ -69,7 +69,7 @@ class NodeManager {
|
||||
/// Assign a task. The task is assumed to not be queued in local_queues_.
|
||||
void AssignTask(Task &task);
|
||||
/// Handle a worker finishing its assigned task.
|
||||
void FinishAssignedTask(std::shared_ptr<Worker> worker);
|
||||
void FinishAssignedTask(Worker &worker);
|
||||
/// Schedule tasks.
|
||||
void ScheduleTasks();
|
||||
/// Handle a task whose local dependencies were missing and are now available.
|
||||
|
||||
@@ -86,14 +86,12 @@ void Raylet::DoAcceptNodeManager() {
|
||||
|
||||
void Raylet::HandleAcceptNodeManager(const boost::system::error_code &error) {
|
||||
if (!error) {
|
||||
ClientHandler<boost::asio::ip::tcp> client_handler =
|
||||
[this](std::shared_ptr<TcpClientConnection> client) {
|
||||
node_manager_.ProcessNewNodeManager(client);
|
||||
};
|
||||
ClientHandler<boost::asio::ip::tcp> client_handler = [this](
|
||||
TcpClientConnection &client) { node_manager_.ProcessNewNodeManager(client); };
|
||||
MessageHandler<boost::asio::ip::tcp> message_handler = [this](
|
||||
std::shared_ptr<TcpClientConnection> client, int64_t message_type,
|
||||
const uint8_t *message) {
|
||||
node_manager_.ProcessNodeManagerMessage(client, message_type, message);
|
||||
node_manager_.ProcessNodeManagerMessage(*client, message_type, message);
|
||||
};
|
||||
// Accept a new local client and dispatch it to the node manager.
|
||||
auto new_connection = TcpClientConnection::Create(client_handler, message_handler,
|
||||
@@ -111,9 +109,7 @@ void Raylet::DoAcceptObjectManager() {
|
||||
|
||||
void Raylet::HandleAcceptObjectManager(const boost::system::error_code &error) {
|
||||
ClientHandler<boost::asio::ip::tcp> client_handler =
|
||||
[this](std::shared_ptr<TcpClientConnection> client) {
|
||||
object_manager_.ProcessNewClient(client);
|
||||
};
|
||||
[this](TcpClientConnection &client) { object_manager_.ProcessNewClient(client); };
|
||||
MessageHandler<boost::asio::ip::tcp> message_handler = [this](
|
||||
std::shared_ptr<TcpClientConnection> client, int64_t message_type,
|
||||
const uint8_t *message) {
|
||||
@@ -134,9 +130,7 @@ void Raylet::HandleAccept(const boost::system::error_code &error) {
|
||||
if (!error) {
|
||||
// TODO: typedef these handlers.
|
||||
ClientHandler<boost::asio::local::stream_protocol> client_handler =
|
||||
[this](std::shared_ptr<LocalClientConnection> client) {
|
||||
node_manager_.ProcessNewClient(client);
|
||||
};
|
||||
[this](LocalClientConnection &client) { node_manager_.ProcessNewClient(client); };
|
||||
MessageHandler<boost::asio::local::stream_protocol> message_handler = [this](
|
||||
std::shared_ptr<LocalClientConnection> client, int64_t message_type,
|
||||
const uint8_t *message) {
|
||||
|
||||
@@ -87,14 +87,15 @@ void WorkerPool::StartWorker(bool force_start) {
|
||||
}
|
||||
|
||||
void WorkerPool::RegisterWorker(std::shared_ptr<Worker> worker) {
|
||||
RAY_LOG(DEBUG) << "Registering worker with pid " << worker->Pid();
|
||||
registered_workers_.push_back(worker);
|
||||
RAY_CHECK(started_worker_pids_.count(worker->Pid()) > 0);
|
||||
started_worker_pids_.erase(worker->Pid());
|
||||
auto pid = worker->Pid();
|
||||
RAY_LOG(DEBUG) << "Registering worker with pid " << pid;
|
||||
registered_workers_.push_back(std::move(worker));
|
||||
RAY_CHECK(started_worker_pids_.count(pid) > 0);
|
||||
started_worker_pids_.erase(pid);
|
||||
}
|
||||
|
||||
std::shared_ptr<Worker> WorkerPool::GetRegisteredWorker(
|
||||
std::shared_ptr<LocalClientConnection> connection) const {
|
||||
const std::shared_ptr<LocalClientConnection> &connection) const {
|
||||
for (auto it = registered_workers_.begin(); it != registered_workers_.end(); it++) {
|
||||
if ((*it)->Connection() == connection) {
|
||||
return (*it);
|
||||
@@ -135,7 +136,7 @@ std::shared_ptr<Worker> WorkerPool::PopWorker(const ActorID &actor_id) {
|
||||
// A helper function to remove a worker from a list. Returns true if the worker
|
||||
// was found and removed.
|
||||
bool removeWorker(std::list<std::shared_ptr<Worker>> &worker_pool,
|
||||
std::shared_ptr<Worker> worker) {
|
||||
const std::shared_ptr<Worker> &worker) {
|
||||
for (auto it = worker_pool.begin(); it != worker_pool.end(); it++) {
|
||||
if (*it == worker) {
|
||||
worker_pool.erase(it);
|
||||
|
||||
@@ -60,7 +60,7 @@ class WorkerPool {
|
||||
/// \return The Worker that owns the given client connection. Returns nullptr
|
||||
/// if the client has not registered a worker yet.
|
||||
std::shared_ptr<Worker> GetRegisteredWorker(
|
||||
std::shared_ptr<LocalClientConnection> connection) const;
|
||||
const std::shared_ptr<LocalClientConnection> &connection) const;
|
||||
|
||||
/// Disconnect a registered worker.
|
||||
///
|
||||
|
||||
@@ -30,8 +30,8 @@ class WorkerPoolTest : public ::testing::Test {
|
||||
WorkerPoolTest() : worker_pool_({}), io_service_() {}
|
||||
|
||||
std::shared_ptr<Worker> CreateWorker(pid_t pid) {
|
||||
std::function<void(std::shared_ptr<LocalClientConnection>)> client_handler = [this](
|
||||
std::shared_ptr<LocalClientConnection> client) { HandleNewClient(client); };
|
||||
std::function<void(LocalClientConnection &)> client_handler =
|
||||
[this](LocalClientConnection &client) { HandleNewClient(client); };
|
||||
std::function<void(std::shared_ptr<LocalClientConnection>, int64_t, const uint8_t *)>
|
||||
message_handler = [this](std::shared_ptr<LocalClientConnection> client,
|
||||
int64_t message_type, const uint8_t *message) {
|
||||
@@ -49,7 +49,7 @@ class WorkerPoolTest : public ::testing::Test {
|
||||
boost::asio::io_service io_service_;
|
||||
|
||||
private:
|
||||
void HandleNewClient(std::shared_ptr<LocalClientConnection>){};
|
||||
void HandleNewClient(LocalClientConnection &){};
|
||||
void HandleMessage(std::shared_ptr<LocalClientConnection>, int64_t, const uint8_t *){};
|
||||
};
|
||||
|
||||
|
||||
+1
-1
@@ -2036,7 +2036,7 @@ class GlobalStateAPI(unittest.TestCase):
|
||||
|
||||
@ray.remote
|
||||
def f():
|
||||
return id(ray.worker.global_worker)
|
||||
return id(ray.worker.global_worker), os.getpid()
|
||||
|
||||
# Wait until all of the workers have started.
|
||||
worker_ids = set()
|
||||
|
||||
Reference in New Issue
Block a user