Improve shared_ptr usage (#2030)

[xray] Improve shared_ptr usage
This commit is contained in:
eric-jj
2018-05-12 11:05:04 +08:00
committed by Philipp Moritz
parent a292d7ba32
commit 71997a481b
26 changed files with 221 additions and 261 deletions
+1 -1
View File
@@ -97,7 +97,7 @@ std::shared_ptr<ClientConnection<T>> ClientConnection<T>::Create(
std::shared_ptr<ClientConnection<T>> self(
new ClientConnection(message_handler, std::move(socket)));
// Let our manager process our new connection.
client_handler(self);
client_handler(*self);
return self;
}
+1 -1
View File
@@ -62,7 +62,7 @@ template <typename T>
class ClientConnection;
template <typename T>
using ClientHandler = std::function<void(std::shared_ptr<ClientConnection<T>>)>;
using ClientHandler = std::function<void(ClientConnection<T> &)>;
template <typename T>
using MessageHandler =
std::function<void(std::shared_ptr<ClientConnection<T>>, int64_t, const uint8_t *)>;
+1 -1
View File
@@ -7,7 +7,7 @@ namespace ray {
namespace gcs {
AsyncGcsClient::AsyncGcsClient(const ClientID &client_id) {
context_.reset(new RedisContext());
context_ = std::make_shared<RedisContext>();
client_table_.reset(new ClientTable(context_, this, client_id));
object_table_.reset(new ObjectTable(context_, this));
actor_table_.reset(new ActorTable(context_, this));
+7 -7
View File
@@ -93,9 +93,9 @@ void TestTableLookup(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> c
// Check that we added the correct task.
auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const UniqueID &id,
const std::shared_ptr<protocol::TaskT> d) {
const protocol::TaskT &d) {
ASSERT_EQ(id, task_id);
ASSERT_EQ(data->task_specification, d->task_specification);
ASSERT_EQ(data->task_specification, d.task_specification);
};
// Check that the lookup returns the added task.
@@ -139,9 +139,9 @@ void TestLogLookup(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> cli
data->manager = manager;
// Check that we added the correct object entries.
auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const UniqueID &id,
const std::shared_ptr<ObjectTableDataT> d) {
const ObjectTableDataT &d) {
ASSERT_EQ(id, object_id);
ASSERT_EQ(data->manager, d->manager);
ASSERT_EQ(data->manager, d.manager);
};
RAY_CHECK_OK(client->object_table().Append(job_id, object_id, data, add_callback));
}
@@ -222,7 +222,7 @@ void TestLogAppendAt(const JobID &job_id, std::shared_ptr<gcs::AsyncGcsClient> c
// Check that we added the correct task.
auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const UniqueID &id,
const std::shared_ptr<TaskReconstructionDataT> d) {
const TaskReconstructionDataT &d) {
ASSERT_EQ(id, task_id);
test->IncrementNumCallbacks();
};
@@ -265,8 +265,8 @@ TEST_F(TestGcsWithAsio, TestLogAppendAt) {
// Task table callbacks.
void TaskAdded(gcs::AsyncGcsClient *client, const TaskID &id,
const std::shared_ptr<TaskTableDataT> data) {
ASSERT_EQ(data->scheduling_state, SchedulingState_SCHEDULED);
const TaskTableDataT &data) {
ASSERT_EQ(data.scheduling_state, SchedulingState_SCHEDULED);
}
void TaskLookup(gcs::AsyncGcsClient *client, const TaskID &id,
+8 -7
View File
@@ -100,15 +100,13 @@ void SubscribeRedisCallback(void *c, void *r, void *privdata) {
int64_t RedisCallbackManager::add(const RedisCallback &function) {
num_callbacks += 1;
callbacks_.emplace(num_callbacks, std::unique_ptr<RedisCallback>(
new RedisCallback(function)));
callbacks_.emplace(num_callbacks, function);
return num_callbacks;
}
RedisCallbackManager::RedisCallback &RedisCallbackManager::get(
int64_t callback_index) {
RedisCallback &RedisCallbackManager::get(int64_t callback_index) {
RAY_CHECK(callbacks_.find(callback_index) != callbacks_.end());
return *callbacks_[callback_index];
return callbacks_[callback_index];
}
void RedisCallbackManager::remove(int64_t callback_index) {
@@ -185,7 +183,9 @@ Status RedisContext::AttachToEventLoop(aeEventLoop *loop) {
Status RedisContext::RunAsync(const std::string &command, const UniqueID &id,
const uint8_t *data, int64_t length,
const TablePrefix prefix, const TablePubsub pubsub_channel,
int64_t callback_index, int log_length) {
RedisCallback redisCallback, int log_length) {
int64_t callback_index =
redisCallback != nullptr ? RedisCallbackManager::instance().add(redisCallback) : -1;
if (length > 0) {
if (log_length >= 0) {
std::string redis_command = command + " %d %d %b %b %d";
@@ -222,10 +222,11 @@ Status RedisContext::RunAsync(const std::string &command, const UniqueID &id,
Status RedisContext::SubscribeAsync(const ClientID &client_id,
const TablePubsub pubsub_channel,
int64_t callback_index) {
const RedisCallback &redisCallback) {
RAY_CHECK(pubsub_channel != TablePubsub_NO_PUBLISH)
<< "Client requested subscribe on a table that does not support pubsub";
int64_t callback_index = RedisCallbackManager::instance().add(redisCallback);
int status = 0;
if (client_id.is_nil()) {
// Subscribe to all messages.
+7 -7
View File
@@ -18,13 +18,13 @@ struct aeEventLoop;
namespace ray {
namespace gcs {
/// Every callback should take in a vector of the results from the Redis
/// operation and return a bool indicating whether the callback should be
/// deleted once called.
using RedisCallback = std::function<bool(const std::string &)>;
class RedisCallbackManager {
public:
/// Every callback should take in a vector of the results from the Redis
/// operation and return a bool indicating whether the callback should be
/// deleted once called.
using RedisCallback = std::function<bool(const std::string &)>;
static RedisCallbackManager &instance() {
static RedisCallbackManager instance;
@@ -44,7 +44,7 @@ class RedisCallbackManager {
~RedisCallbackManager() { printf("shut down callback manager\n"); }
int64_t num_callbacks;
std::unordered_map<int64_t, std::unique_ptr<RedisCallback>> callbacks_;
std::unordered_map<int64_t, RedisCallback> callbacks_;
};
class RedisContext {
@@ -70,11 +70,11 @@ class RedisContext {
/// -1 for unused. If set, then data must be provided.
Status RunAsync(const std::string &command, const UniqueID &id, const uint8_t *data,
int64_t length, const TablePrefix prefix,
const TablePubsub pubsub_channel, int64_t callback_index,
const TablePubsub pubsub_channel, RedisCallback redisCallback,
int log_length = -1);
Status SubscribeAsync(const ClientID &client_id, const TablePubsub pubsub_channel,
int64_t callback_index);
const RedisCallback &redisCallback);
redisAsyncContext *async_context() { return async_context_; }
redisAsyncContext *subscribe_context() { return subscribe_context_; };
+88 -108
View File
@@ -9,76 +9,66 @@ namespace gcs {
template <typename ID, typename Data>
Status Log<ID, Data>::Append(const JobID &job_id, const ID &id,
std::shared_ptr<DataT> data, const WriteCallback &done) {
auto d = std::shared_ptr<CallbackData>(
new CallbackData({id, data, nullptr, nullptr, this, client_}));
int64_t callback_index =
RedisCallbackManager::instance().add([d, done](const std::string &data) {
RAY_CHECK(data.empty());
if (done != nullptr) {
(done)(d->client, d->id, d->data);
}
return true;
});
std::shared_ptr<DataT> &dataT, const WriteCallback &done) {
auto callback = [this, id, dataT, done](const std::string &data) {
RAY_CHECK(data.empty());
if (done != nullptr) {
(done)(client_, id, *dataT);
}
return true;
};
flatbuffers::FlatBufferBuilder fbb;
fbb.ForceDefaults(true);
fbb.Finish(Data::Pack(fbb, data.get()));
fbb.Finish(Data::Pack(fbb, dataT.get()));
return context_->RunAsync("RAY.TABLE_APPEND", id, fbb.GetBufferPointer(), fbb.GetSize(),
prefix_, pubsub_channel_, callback_index);
prefix_, pubsub_channel_, std::move(callback));
}
template <typename ID, typename Data>
Status Log<ID, Data>::AppendAt(const JobID &job_id, const ID &id,
std::shared_ptr<DataT> data, const WriteCallback &done,
std::shared_ptr<DataT> &dataT, const WriteCallback &done,
const WriteCallback &failure, int log_length) {
auto d = std::shared_ptr<CallbackData>(
new CallbackData({id, data, nullptr, nullptr, this, client_}));
int64_t callback_index =
RedisCallbackManager::instance().add([d, done, failure](const std::string &data) {
if (data.empty()) {
if (done != nullptr) {
(done)(d->client, d->id, d->data);
}
} else {
if (failure != nullptr) {
(failure)(d->client, d->id, d->data);
}
}
return true;
});
auto callback = [this, id, dataT, done, failure](const std::string &data) {
if (data.empty()) {
if (done != nullptr) {
(done)(client_, id, *dataT);
}
} else {
if (failure != nullptr) {
(failure)(client_, id, *dataT);
}
}
return true;
};
flatbuffers::FlatBufferBuilder fbb;
fbb.ForceDefaults(true);
fbb.Finish(Data::Pack(fbb, data.get()));
fbb.Finish(Data::Pack(fbb, dataT.get()));
return context_->RunAsync("RAY.TABLE_APPEND", id, fbb.GetBufferPointer(), fbb.GetSize(),
prefix_, pubsub_channel_, callback_index, log_length);
prefix_, pubsub_channel_, std::move(callback), log_length);
}
template <typename ID, typename Data>
Status Log<ID, Data>::Lookup(const JobID &job_id, const ID &id, const Callback &lookup) {
auto d = std::shared_ptr<CallbackData>(
new CallbackData({id, nullptr, lookup, nullptr, this, client_}));
int64_t callback_index =
RedisCallbackManager::instance().add([d](const std::string &data) {
if (d->callback != nullptr) {
std::vector<DataT> results;
if (!data.empty()) {
auto root = flatbuffers::GetRoot<GcsTableEntry>(data.data());
RAY_CHECK(from_flatbuf(*root->id()) == d->id);
for (size_t i = 0; i < root->entries()->size(); i++) {
DataT result;
auto data_root =
flatbuffers::GetRoot<Data>(root->entries()->Get(i)->data());
data_root->UnPackTo(&result);
results.emplace_back(std::move(result));
}
}
(d->callback)(d->client, d->id, results);
auto callback = [this, id, lookup](const std::string &data) {
if (lookup != nullptr) {
std::vector<DataT> results;
if (!data.empty()) {
auto root = flatbuffers::GetRoot<GcsTableEntry>(data.data());
RAY_CHECK(from_flatbuf(*root->id()) == id);
for (size_t i = 0; i < root->entries()->size(); i++) {
DataT result;
auto data_root = flatbuffers::GetRoot<Data>(root->entries()->Get(i)->data());
data_root->UnPackTo(&result);
results.emplace_back(std::move(result));
}
return true;
});
}
lookup(client_, id, results);
}
return true;
};
std::vector<uint8_t> nil;
return context_->RunAsync("RAY.TABLE_LOOKUP", id, nil.data(), nil.size(), prefix_,
pubsub_channel_, callback_index);
pubsub_channel_, std::move(callback));
}
template <typename ID, typename Data>
@@ -87,42 +77,38 @@ Status Log<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
const SubscriptionCallback &done) {
RAY_CHECK(subscribe_callback_index_ == -1)
<< "Client called Subscribe twice on the same table";
auto d = std::shared_ptr<CallbackData>(
new CallbackData({client_id, nullptr, subscribe, done, this, client_}));
int64_t callback_index =
RedisCallbackManager::instance().add([d](const std::string &data) {
if (data.empty()) {
// No notification data is provided. This is the callback for the
// initial subscription request.
if (d->subscription_callback != nullptr) {
(d->subscription_callback)(d->client);
}
} else {
// Data is provided. This is the callback for a message.
if (d->callback != nullptr) {
// Parse the notification.
auto root = flatbuffers::GetRoot<GcsTableEntry>(data.data());
ID id = UniqueID::nil();
if (root->id()->size() > 0) {
id = from_flatbuf(*root->id());
}
std::vector<DataT> results;
for (size_t i = 0; i < root->entries()->size(); i++) {
DataT result;
auto data_root =
flatbuffers::GetRoot<Data>(root->entries()->Get(i)->data());
data_root->UnPackTo(&result);
results.emplace_back(std::move(result));
}
(d->callback)(d->client, id, results);
}
auto callback = [this, subscribe, done](const std::string &data) {
if (data.empty()) {
// No notification data is provided. This is the callback for the
// initial subscription request.
if (done != nullptr) {
done(client_);
}
} else {
// Data is provided. This is the callback for a message.
if (subscribe != nullptr) {
// Parse the notification.
auto root = flatbuffers::GetRoot<GcsTableEntry>(data.data());
ID id = UniqueID::nil();
if (root->id()->size() > 0) {
id = from_flatbuf(*root->id());
}
// We do not delete the callback after calling it since there may be
// more subscription messages.
return false;
});
subscribe_callback_index_ = callback_index;
return context_->SubscribeAsync(client_id, pubsub_channel_, callback_index);
std::vector<DataT> results;
for (size_t i = 0; i < root->entries()->size(); i++) {
DataT result;
auto data_root = flatbuffers::GetRoot<Data>(root->entries()->Get(i)->data());
data_root->UnPackTo(&result);
results.emplace_back(std::move(result));
}
subscribe(client_, id, results);
}
}
// We do not delete the callback after calling it since there may be
// more subscription messages.
return false;
};
subscribe_callback_index_ = 1;
return context_->SubscribeAsync(client_id, pubsub_channel_, std::move(callback));
}
template <typename ID, typename Data>
@@ -131,8 +117,7 @@ Status Log<ID, Data>::RequestNotifications(const JobID &job_id, const ID &id,
RAY_CHECK(subscribe_callback_index_ >= 0)
<< "Client requested notifications on a key before Subscribe completed";
return context_->RunAsync("RAY.TABLE_REQUEST_NOTIFICATIONS", id, client_id.data(),
client_id.size(), prefix_, pubsub_channel_,
/*callback_index=*/-1);
client_id.size(), prefix_, pubsub_channel_, nullptr);
}
template <typename ID, typename Data>
@@ -141,27 +126,23 @@ Status Log<ID, Data>::CancelNotifications(const JobID &job_id, const ID &id,
RAY_CHECK(subscribe_callback_index_ >= 0)
<< "Client canceled notifications on a key before Subscribe completed";
return context_->RunAsync("RAY.TABLE_CANCEL_NOTIFICATIONS", id, client_id.data(),
client_id.size(), prefix_, pubsub_channel_,
/*callback_index=*/-1);
client_id.size(), prefix_, pubsub_channel_, nullptr);
}
template <typename ID, typename Data>
Status Table<ID, Data>::Add(const JobID &job_id, const ID &id,
std::shared_ptr<DataT> data, const WriteCallback &done) {
auto d = std::shared_ptr<CallbackData>(
new CallbackData({id, data, nullptr, nullptr, this, client_}));
int64_t callback_index =
RedisCallbackManager::instance().add([d, done](const std::string &data) {
if (done != nullptr) {
(done)(d->client, d->id, d->data);
}
return true;
});
std::shared_ptr<DataT> &dataT, const WriteCallback &done) {
auto callback = [this, id, dataT, done](const std::string &data) {
if (done != nullptr) {
(done)(client_, id, *dataT);
}
return true;
};
flatbuffers::FlatBufferBuilder fbb;
fbb.ForceDefaults(true);
fbb.Finish(Data::Pack(fbb, data.get()));
fbb.Finish(Data::Pack(fbb, dataT.get()));
return context_->RunAsync("RAY.TABLE_ADD", id, fbb.GetBufferPointer(), fbb.GetSize(),
prefix_, pubsub_channel_, callback_index);
prefix_, pubsub_channel_, std::move(callback));
}
template <typename ID, typename Data>
@@ -259,9 +240,8 @@ void ClientTable::HandleNotification(AsyncGcsClient *client,
}
}
void ClientTable::HandleConnected(AsyncGcsClient *client,
const std::shared_ptr<ClientTableDataT> data) {
auto connected_client_id = ClientID::from_binary(data->client_id);
void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientTableDataT &data) {
auto connected_client_id = ClientID::from_binary(data.client_id);
RAY_CHECK(client_id_ == connected_client_id) << connected_client_id << " "
<< client_id_;
}
@@ -282,7 +262,7 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) {
// Callback to handle our own successful connection once we've added
// ourselves.
auto add_callback = [this](AsyncGcsClient *client, const UniqueID &log_key,
std::shared_ptr<ClientTableDataT> data) {
const ClientTableDataT &data) {
RAY_CHECK(log_key == client_log_key_);
HandleConnected(client, data);
@@ -311,7 +291,7 @@ Status ClientTable::Disconnect() {
auto data = std::make_shared<ClientTableDataT>(local_client_);
data->is_insertion = false;
auto add_callback = [this](AsyncGcsClient *client, const ClientID &id,
std::shared_ptr<ClientTableDataT> data) {
const ClientTableDataT &data) {
HandleConnected(client, data);
RAY_CHECK_OK(CancelNotifications(JobID::nil(), client_log_key_, id));
};
+15 -28
View File
@@ -57,8 +57,8 @@ class Log : virtual public PubsubInterface<ID> {
using Callback = std::function<void(AsyncGcsClient *client, const ID &id,
const std::vector<DataT> &data)>;
/// The callback to call when a write to a key succeeds.
using WriteCallback = std::function<void(AsyncGcsClient *client, const ID &id,
std::shared_ptr<DataT> data)>;
using WriteCallback =
std::function<void(AsyncGcsClient *client, const ID &id, const DataT &data)>;
/// The callback to call when a SUBSCRIBE call completes and we are ready to
/// request and receive notifications.
using SubscriptionCallback = std::function<void(AsyncGcsClient *client)>;
@@ -89,7 +89,7 @@ class Log : virtual public PubsubInterface<ID> {
/// \param done Callback that is called once the data has been written to the
/// GCS.
/// \return Status
Status Append(const JobID &job_id, const ID &id, std::shared_ptr<DataT> data,
Status Append(const JobID &job_id, const ID &id, std::shared_ptr<DataT> &data,
const WriteCallback &done);
/// Append a log entry to a key if and only if the log has the given number
@@ -105,7 +105,7 @@ class Log : virtual public PubsubInterface<ID> {
/// \param log_length The number of entries that the log must have for the
/// append to succeed.
/// \return Status
Status AppendAt(const JobID &job_id, const ID &id, std::shared_ptr<DataT> data,
Status AppendAt(const JobID &job_id, const ID &id, std::shared_ptr<DataT> &data,
const WriteCallback &done, const WriteCallback &failure,
int log_length);
@@ -187,7 +187,7 @@ class TableInterface {
public:
using DataT = typename Data::NativeTableType;
using WriteCallback = typename Log<ID, Data>::WriteCallback;
virtual Status Add(const JobID &job_id, const ID &task_id, std::shared_ptr<DataT> data,
virtual Status Add(const JobID &job_id, const ID &task_id, std::shared_ptr<DataT> &data,
const WriteCallback &done) = 0;
virtual ~TableInterface(){};
};
@@ -212,17 +212,6 @@ class Table : private Log<ID, Data>,
/// request and receive notifications.
using SubscriptionCallback = typename Log<ID, Data>::SubscriptionCallback;
struct CallbackData {
ID id;
std::shared_ptr<DataT> data;
Callback callback;
// An optional callback to call for subscription operations, where the
// first message is a notification of subscription success.
SubscriptionCallback subscription_callback;
Log<ID, Data> *log;
AsyncGcsClient *client;
};
Table(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: Log<ID, Data>(context, client) {}
@@ -237,7 +226,7 @@ class Table : private Log<ID, Data>,
/// \param done Callback that is called once the data has been written to the
/// GCS.
/// \return Status
Status Add(const JobID &job_id, const ID &id, std::shared_ptr<DataT> data,
Status Add(const JobID &job_id, const ID &id, std::shared_ptr<DataT> &data,
const WriteCallback &done);
/// Lookup an entry asynchronously.
@@ -358,19 +347,18 @@ class TaskTable : public Table<TaskID, TaskTableData> {
Status TestAndUpdate(const JobID &job_id, const TaskID &id,
std::shared_ptr<TaskTableTestAndUpdateT> data,
const TestAndUpdateCallback &callback) {
int64_t callback_index = RedisCallbackManager::instance().add(
[this, callback, id](const std::string &data) {
auto result = std::make_shared<TaskTableDataT>();
auto root = flatbuffers::GetRoot<TaskTableData>(data.data());
root->UnPackTo(result.get());
callback(client_, id, *result, root->updated());
return true;
});
auto redisCallback = [this, callback, id](const std::string &data) {
auto result = std::make_shared<TaskTableDataT>();
auto root = flatbuffers::GetRoot<TaskTableData>(data.data());
root->UnPackTo(result.get());
callback(client_, id, *result, root->updated());
return true;
};
flatbuffers::FlatBufferBuilder fbb;
fbb.Finish(TaskTableTestAndUpdate::Pack(fbb, data.get()));
RAY_RETURN_NOT_OK(context_->RunAsync("RAY.TABLE_TEST_AND_UPDATE", id,
fbb.GetBufferPointer(), fbb.GetSize(), prefix_,
pubsub_channel_, callback_index));
pubsub_channel_, redisCallback));
return Status::OK();
}
@@ -499,8 +487,7 @@ class ClientTable : private Log<UniqueID, ClientTableData> {
/// Handle a client table notification.
void HandleNotification(AsyncGcsClient *client, const ClientTableDataT &notifications);
/// Handle this client's successful connection to the GCS.
void HandleConnected(AsyncGcsClient *client,
const std::shared_ptr<ClientTableDataT> client_data);
void HandleConnected(AsyncGcsClient *client, const ClientTableDataT &client_data);
/// The key at which the log of client information is stored. This key must
/// be kept the same across all instances of the ClientTable, so that all
+3 -3
View File
@@ -46,9 +46,9 @@ Status TaskTableAdd(AsyncGcsClient *gcs_client, Task *task) {
TaskSpec *spec = execution_spec.Spec();
auto data = MakeTaskTableData(execution_spec, Task_local_scheduler(task),
static_cast<SchedulingState>(Task_state(task)));
return gcs_client->task_table().Add(ray::JobID::nil(), TaskSpec_task_id(spec), data,
[](gcs::AsyncGcsClient *client, const TaskID &id,
std::shared_ptr<TaskTableDataT> data) {});
return gcs_client->task_table().Add(
ray::JobID::nil(), TaskSpec_task_id(spec), data,
[](gcs::AsyncGcsClient *client, const TaskID &id, const TaskTableDataT &data) {});
}
// TODO(pcm): This is a helper method that should go away once we get rid of
+13 -11
View File
@@ -53,7 +53,7 @@ ray::Status ConnectionPool::GetSender(ConnectionType type, const ClientID &clien
}
ray::Status ConnectionPool::ReleaseSender(ConnectionType type,
std::shared_ptr<SenderConnection> conn) {
std::shared_ptr<SenderConnection> &conn) {
std::unique_lock<std::mutex> guard(connection_mutex);
SenderMapType &conn_map = (type == ConnectionType::MESSAGE)
? available_message_send_connections_
@@ -64,20 +64,21 @@ ray::Status ConnectionPool::ReleaseSender(ConnectionType type,
void ConnectionPool::Add(ReceiverMapType &conn_map, const ClientID &client_id,
std::shared_ptr<TcpClientConnection> conn) {
conn_map[client_id].push_back(conn);
conn_map[client_id].push_back(std::move(conn));
}
void ConnectionPool::Add(SenderMapType &conn_map, const ClientID &client_id,
std::shared_ptr<SenderConnection> conn) {
conn_map[client_id].push_back(conn);
conn_map[client_id].push_back(std::move(conn));
}
void ConnectionPool::Remove(ReceiverMapType &conn_map, const ClientID &client_id,
std::shared_ptr<TcpClientConnection> conn) {
if (conn_map.count(client_id) == 0) {
std::shared_ptr<TcpClientConnection> &conn) {
auto it = conn_map.find(client_id);
if (it == conn_map.end()) {
return;
}
std::vector<std::shared_ptr<TcpClientConnection>> &connections = conn_map[client_id];
auto &connections = it->second;
int64_t pos =
std::find(connections.begin(), connections.end(), conn) - connections.begin();
if (pos >= (int64_t)connections.size()) {
@@ -87,15 +88,16 @@ void ConnectionPool::Remove(ReceiverMapType &conn_map, const ClientID &client_id
}
uint64_t ConnectionPool::Count(SenderMapType &conn_map, const ClientID &client_id) {
if (conn_map.count(client_id) == 0) {
auto it = conn_map.find(client_id);
if (it == conn_map.end()) {
return 0;
};
return conn_map[client_id].size();
}
return it->second.size();
}
std::shared_ptr<SenderConnection> ConnectionPool::Borrow(SenderMapType &conn_map,
const ClientID &client_id) {
std::shared_ptr<SenderConnection> conn = conn_map[client_id].back();
std::shared_ptr<SenderConnection> conn = std::move(conn_map[client_id].back());
conn_map[client_id].pop_back();
RAY_LOG(DEBUG) << "Borrow " << client_id << " " << conn_map[client_id].size();
return conn;
@@ -103,7 +105,7 @@ std::shared_ptr<SenderConnection> ConnectionPool::Borrow(SenderMapType &conn_map
void ConnectionPool::Return(SenderMapType &conn_map, const ClientID &client_id,
std::shared_ptr<SenderConnection> conn) {
conn_map[client_id].push_back(conn);
conn_map[client_id].push_back(std::move(conn));
RAY_LOG(DEBUG) << "Return " << client_id << " " << conn_map[client_id].size();
}
+2 -2
View File
@@ -74,7 +74,7 @@ class ConnectionPool {
/// \param type The type of connection.
/// \param conn The actual connection.
/// \return Status of invoking this method.
ray::Status ReleaseSender(ConnectionType type, std::shared_ptr<SenderConnection> conn);
ray::Status ReleaseSender(ConnectionType type, std::shared_ptr<SenderConnection> &conn);
// TODO(hme): Implement with error handling.
/// Remove a sender connection. This is invoked if the connection is no longer
@@ -106,7 +106,7 @@ class ConnectionPool {
/// Removes the given receiver for ClientID from the given map.
void Remove(ReceiverMapType &conn_map, const ClientID &client_id,
std::shared_ptr<TcpClientConnection> conn);
std::shared_ptr<TcpClientConnection> &conn);
/// Returns the count of sender connections to ClientID.
uint64_t Count(SenderMapType &conn_map, const ClientID &client_id);
+2 -2
View File
@@ -16,8 +16,8 @@ ray::Status ObjectDirectory::ReportObjectAdded(const ObjectID &object_id,
data->is_eviction = false;
data->object_size = object_info.data_size;
ray::Status status = gcs_client_->object_table().Append(
job_id, object_id, data, [](gcs::AsyncGcsClient *client, const UniqueID &id,
const std::shared_ptr<ObjectTableDataT> data) {
job_id, object_id, data,
[](gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &data) {
// Do nothing.
});
return status;
+15 -15
View File
@@ -110,8 +110,8 @@ ray::Status ObjectManager::Pull(const ObjectID &object_id) {
}
void ObjectManager::SchedulePull(const ObjectID &object_id, int wait_ms) {
pull_requests_[object_id] = std::shared_ptr<boost::asio::deadline_timer>(
new asio::deadline_timer(*main_service_, boost::posix_time::milliseconds(wait_ms)));
pull_requests_[object_id] = std::make_shared<boost::asio::deadline_timer>(
*main_service_, boost::posix_time::milliseconds(wait_ms));
pull_requests_[object_id]->async_wait(
[this, object_id](const boost::system::error_code &error_code) {
pull_requests_.erase(object_id);
@@ -184,7 +184,7 @@ ray::Status ObjectManager::PullEstablishConnection(const ObjectID &object_id,
}
ray::Status ObjectManager::PullSendRequest(const ObjectID &object_id,
std::shared_ptr<SenderConnection> conn) {
std::shared_ptr<SenderConnection> &conn) {
flatbuffers::FlatBufferBuilder fbb;
auto message = object_manager_protocol::CreatePullRequestMessage(
fbb, fbb.CreateString(client_id_.binary()), fbb.CreateString(object_id.binary()));
@@ -209,7 +209,7 @@ ray::Status ObjectManager::Push(const ObjectID &object_id, const ClientID &clien
Status status = object_directory_->GetInformation(
client_id,
[this, object_id, client_id](const RemoteConnectionInfo &info) {
ObjectInfoT object_info = local_objects_[object_id];
const ObjectInfoT &object_info = local_objects_[object_id];
uint64_t data_size =
static_cast<uint64_t>(object_info.data_size + object_info.metadata_size);
uint64_t metadata_size = static_cast<uint64_t>(object_info.metadata_size);
@@ -251,7 +251,7 @@ void ObjectManager::ExecuteSendObject(const ClientID &client_id,
ray::Status ObjectManager::SendObjectHeaders(const ObjectID &object_id,
uint64_t data_size, uint64_t metadata_size,
uint64_t chunk_index,
std::shared_ptr<SenderConnection> conn) {
std::shared_ptr<SenderConnection> &conn) {
std::pair<const ObjectBufferPool::ChunkInfo &, ray::Status> chunk_status =
buffer_pool_.GetChunk(object_id, data_size, metadata_size, chunk_index);
ObjectBufferPool::ChunkInfo chunk_info = chunk_status.first;
@@ -276,7 +276,7 @@ ray::Status ObjectManager::SendObjectHeaders(const ObjectID &object_id,
ray::Status ObjectManager::SendObjectData(const ObjectID &object_id,
const ObjectBufferPool::ChunkInfo &chunk_info,
std::shared_ptr<SenderConnection> conn) {
std::shared_ptr<SenderConnection> &conn) {
boost::system::error_code ec;
std::vector<asio::const_buffer> buffer;
buffer.push_back(asio::buffer(chunk_info.data, chunk_info.buffer_length));
@@ -328,11 +328,11 @@ std::shared_ptr<SenderConnection> ObjectManager::CreateSenderConnection(
return conn;
}
void ObjectManager::ProcessNewClient(std::shared_ptr<TcpClientConnection> conn) {
conn->ProcessMessages();
void ObjectManager::ProcessNewClient(TcpClientConnection &conn) {
conn.ProcessMessages();
}
void ObjectManager::ProcessClientMessage(std::shared_ptr<TcpClientConnection> conn,
void ObjectManager::ProcessClientMessage(std::shared_ptr<TcpClientConnection> &conn,
int64_t message_type, const uint8_t *message) {
switch (message_type) {
case object_manager_protocol::MessageType_PushRequest: {
@@ -389,7 +389,7 @@ void ObjectManager::ReceivePullRequest(std::shared_ptr<TcpClientConnection> &con
conn->ProcessMessages();
}
void ObjectManager::ReceivePushRequest(std::shared_ptr<TcpClientConnection> conn,
void ObjectManager::ReceivePushRequest(std::shared_ptr<TcpClientConnection> &conn,
const uint8_t *message) {
// Serialize.
auto object_header =
@@ -400,14 +400,14 @@ void ObjectManager::ReceivePushRequest(std::shared_ptr<TcpClientConnection> conn
uint64_t metadata_size = object_header->metadata_size();
receive_service_.post([this, object_id, data_size, metadata_size, chunk_index, conn]() {
ExecuteReceiveObject(conn->GetClientID(), object_id, data_size, metadata_size,
chunk_index, conn);
chunk_index, *conn);
});
}
void ObjectManager::ExecuteReceiveObject(const ClientID &client_id,
const ObjectID &object_id, uint64_t data_size,
uint64_t metadata_size, uint64_t chunk_index,
std::shared_ptr<TcpClientConnection> conn) {
TcpClientConnection &conn) {
RAY_LOG(DEBUG) << "ExecuteReceiveObject " << client_id << " " << object_id << " "
<< chunk_index;
@@ -419,7 +419,7 @@ void ObjectManager::ExecuteReceiveObject(const ClientID &client_id,
std::vector<boost::asio::mutable_buffer> buffer;
buffer.push_back(asio::buffer(chunk_info.data, chunk_info.buffer_length));
boost::system::error_code ec;
conn->ReadBuffer(buffer, ec);
conn.ReadBuffer(buffer, ec);
if (ec.value() == 0) {
buffer_pool_.SealChunk(object_id, chunk_index);
} else {
@@ -435,13 +435,13 @@ void ObjectManager::ExecuteReceiveObject(const ClientID &client_id,
std::vector<boost::asio::mutable_buffer> buffer;
buffer.push_back(asio::buffer(mutable_vec, buffer_length));
boost::system::error_code ec;
conn->ReadBuffer(buffer, ec);
conn.ReadBuffer(buffer, ec);
if (ec.value() != 0) {
RAY_LOG(ERROR) << ec.message();
}
// TODO(hme): If the object isn't local, create a pull request for this chunk.
}
conn->ProcessMessages();
conn.ProcessMessages();
RAY_LOG(DEBUG) << "ReceiveCompleted " << client_id_ << " " << object_id << " "
<< "/" << config_.max_receives;
}
+7 -8
View File
@@ -110,7 +110,7 @@ class ObjectManager {
///
/// \param conn The connection.
/// \return Status of whether the connection was successfully established.
void ProcessNewClient(std::shared_ptr<TcpClientConnection> conn);
void ProcessNewClient(TcpClientConnection &conn);
/// Process messages sent from other nodes. We only establish
/// transfer connections using this method; all other transfer communication
@@ -119,7 +119,7 @@ class ObjectManager {
/// \param conn The connection.
/// \param message_type The message type.
/// \param message A pointer set to the beginning of the message.
void ProcessClientMessage(std::shared_ptr<TcpClientConnection> conn,
void ProcessClientMessage(std::shared_ptr<TcpClientConnection> &conn,
int64_t message_type, const uint8_t *message);
/// Cancels all requests (Push/Pull) associated with the given ObjectID.
@@ -226,7 +226,7 @@ class ObjectManager {
/// Synchronously send a pull request via remote object manager connection.
/// Executes on main_service_ thread.
ray::Status PullSendRequest(const ObjectID &object_id,
std::shared_ptr<SenderConnection> conn);
std::shared_ptr<SenderConnection> &conn);
std::shared_ptr<SenderConnection> CreateSenderConnection(
ConnectionPool::ConnectionType type, RemoteConnectionInfo info);
@@ -241,23 +241,22 @@ class ObjectManager {
/// Executes on send_service_ thread pool.
ray::Status SendObjectHeaders(const ObjectID &object_id, uint64_t data_size,
uint64_t metadata_size, uint64_t chunk_index,
std::shared_ptr<SenderConnection> conn);
std::shared_ptr<SenderConnection> &conn);
/// This method initiates the actual object transfer.
/// Executes on send_service_ thread pool.
ray::Status SendObjectData(const ObjectID &object_id,
const ObjectBufferPool::ChunkInfo &chunk_info,
std::shared_ptr<SenderConnection> conn);
std::shared_ptr<SenderConnection> &conn);
/// Invoked when a remote object manager pushes an object to this object manager.
/// This will invoke the object receive on the receive_service_ thread pool.
void ReceivePushRequest(std::shared_ptr<TcpClientConnection> conn,
void ReceivePushRequest(std::shared_ptr<TcpClientConnection> &conn,
const uint8_t *message);
/// Execute a receive on the receive_service_ thread pool.
void ExecuteReceiveObject(const ClientID &client_id, const ObjectID &object_id,
uint64_t data_size, uint64_t metadata_size,
uint64_t chunk_index,
std::shared_ptr<TcpClientConnection> conn);
uint64_t chunk_index, TcpClientConnection &conn);
/// Handles receiving a pull request message.
void ReceivePullRequest(std::shared_ptr<TcpClientConnection> &conn,
@@ -11,7 +11,7 @@ std::shared_ptr<SenderConnection> SenderConnection::Create(
RAY_CHECK_OK(TcpConnect(socket, ip, port));
std::shared_ptr<TcpServerConnection> conn =
std::make_shared<TcpServerConnection>(std::move(socket));
return std::make_shared<SenderConnection>(conn, client_id);
return std::make_shared<SenderConnection>(std::move(conn), client_id);
};
SenderConnection::SenderConnection(std::shared_ptr<TcpServerConnection> conn,
@@ -65,9 +65,7 @@ class MockServer {
void HandleAcceptObjectManager(const boost::system::error_code &error) {
ClientHandler<boost::asio::ip::tcp> client_handler =
[this](std::shared_ptr<TcpClientConnection> client) {
object_manager_.ProcessNewClient(client);
};
[this](TcpClientConnection &client) { object_manager_.ProcessNewClient(client); };
MessageHandler<boost::asio::ip::tcp> message_handler = [this](
std::shared_ptr<TcpClientConnection> client, int64_t message_type,
const uint8_t *message) {
@@ -56,9 +56,7 @@ class MockServer {
void HandleAcceptObjectManager(const boost::system::error_code &error) {
ClientHandler<boost::asio::ip::tcp> client_handler =
[this](std::shared_ptr<TcpClientConnection> client) {
object_manager_.ProcessNewClient(client);
};
[this](TcpClientConnection &client) { object_manager_.ProcessNewClient(client); };
MessageHandler<boost::asio::ip::tcp> message_handler = [this](
std::shared_ptr<TcpClientConnection> client, int64_t message_type,
const uint8_t *message) {
+3 -2
View File
@@ -258,8 +258,9 @@ Status LineageCache::Flush() {
// Write back all ready tasks whose arguments have been committed to the GCS.
gcs::raylet::TaskTable::WriteCallback task_callback = [this](
ray::gcs::AsyncGcsClient *client, const TaskID &id,
const std::shared_ptr<protocol::TaskT> data) { HandleEntryCommitted(id); };
ray::gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &data) {
HandleEntryCommitted(id);
};
for (const auto &ready_task_id : ready_task_ids) {
auto task = lineage_.GetEntry(ready_task_id);
// TODO(swang): Make this better...
+4 -4
View File
@@ -23,7 +23,7 @@ class MockGcs : public gcs::TableInterface<TaskID, protocol::Task>,
}
Status Add(const JobID &job_id, const TaskID &task_id,
std::shared_ptr<protocol::TaskT> task_data,
std::shared_ptr<protocol::TaskT> &task_data,
const gcs::TableInterface<TaskID, protocol::Task>::WriteCallback &done) {
task_table_[task_id] = task_data;
callbacks_.push_back(
@@ -38,7 +38,7 @@ class MockGcs : public gcs::TableInterface<TaskID, protocol::Task>,
bool send_notification = (subscribed_tasks_.count(task_id) == 1);
auto callback = [this, send_notification](ray::gcs::AsyncGcsClient *client,
const TaskID &task_id,
std::shared_ptr<protocol::TaskT> data) {
const protocol::TaskT &data) {
if (send_notification) {
notification_callback_(client, task_id, data);
}
@@ -63,7 +63,7 @@ class MockGcs : public gcs::TableInterface<TaskID, protocol::Task>,
void Flush() {
for (const auto &callback : callbacks_) {
callback.first(NULL, callback.second, task_table_[callback.second]);
callback.first(NULL, callback.second, *task_table_[callback.second]);
}
callbacks_.clear();
}
@@ -86,7 +86,7 @@ class LineageCacheTest : public ::testing::Test {
LineageCacheTest()
: mock_gcs_(), lineage_cache_(ClientID::from_random(), mock_gcs_, mock_gcs_) {
mock_gcs_.Subscribe([this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id,
std::shared_ptr<ray::protocol::TaskT> data) {
const ray::protocol::TaskT &data) {
lineage_cache_.HandleEntryCommitted(task_id);
});
}
+19 -20
View File
@@ -156,7 +156,7 @@ void NodeManager::Heartbeat() {
ray::Status status = heartbeat_table.Add(
UniqueID::nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data,
[](ray::gcs::AsyncGcsClient *client, const ClientID &id,
std::shared_ptr<HeartbeatTableDataT> data) {
const HeartbeatTableDataT &data) {
RAY_LOG(DEBUG) << "[HEARTBEAT] heartbeat sent callback";
});
@@ -279,9 +279,9 @@ void NodeManager::HandleActorCreation(const ActorID &actor_id,
}
}
void NodeManager::ProcessNewClient(std::shared_ptr<LocalClientConnection> client) {
void NodeManager::ProcessNewClient(LocalClientConnection &client) {
// The new client is a worker, so begin listening for messages.
client->ProcessMessages();
client.ProcessMessages();
}
void NodeManager::DispatchTasks() {
@@ -309,9 +309,9 @@ void NodeManager::DispatchTasks() {
}
}
void NodeManager::ProcessClientMessage(std::shared_ptr<LocalClientConnection> client,
int64_t message_type,
const uint8_t *message_data) {
void NodeManager::ProcessClientMessage(
const std::shared_ptr<LocalClientConnection> &client, int64_t message_type,
const uint8_t *message_data) {
RAY_LOG(DEBUG) << "Message of type " << message_type;
switch (message_type) {
@@ -319,7 +319,7 @@ void NodeManager::ProcessClientMessage(std::shared_ptr<LocalClientConnection> cl
auto message = flatbuffers::GetRoot<protocol::RegisterClientRequest>(message_data);
if (message->is_worker()) {
// Create a new worker from the registration request.
std::shared_ptr<Worker> worker(new Worker(message->worker_pid(), client));
auto worker = std::make_shared<Worker>(message->worker_pid(), client);
// Register the new worker.
worker_pool_.RegisterWorker(std::move(worker));
}
@@ -329,10 +329,10 @@ void NodeManager::ProcessClientMessage(std::shared_ptr<LocalClientConnection> cl
RAY_CHECK(worker);
// If the worker was assigned a task, mark it as finished.
if (!worker->GetAssignedTaskId().is_nil()) {
FinishAssignedTask(worker);
FinishAssignedTask(*worker);
}
// Return the worker to the idle pool.
worker_pool_.PushWorker(worker);
worker_pool_.PushWorker(std::move(worker));
// Call task dispatch to assign work to the new worker.
DispatchTasks();
@@ -436,14 +436,13 @@ void NodeManager::ProcessClientMessage(std::shared_ptr<LocalClientConnection> cl
client->ProcessMessages();
}
void NodeManager::ProcessNewNodeManager(
std::shared_ptr<TcpClientConnection> node_manager_client) {
node_manager_client->ProcessMessages();
void NodeManager::ProcessNewNodeManager(TcpClientConnection &node_manager_client) {
node_manager_client.ProcessMessages();
}
void NodeManager::ProcessNodeManagerMessage(
std::shared_ptr<TcpClientConnection> node_manager_client, int64_t message_type,
const uint8_t *message_data) {
void NodeManager::ProcessNodeManagerMessage(TcpClientConnection &node_manager_client,
int64_t message_type,
const uint8_t *message_data) {
switch (message_type) {
case protocol::MessageType_ForwardTaskRequest: {
auto message = flatbuffers::GetRoot<protocol::ForwardTaskRequest>(message_data);
@@ -458,7 +457,7 @@ void NodeManager::ProcessNodeManagerMessage(
default:
RAY_LOG(FATAL) << "Received unexpected message type " << message_type;
}
node_manager_client->ProcessMessages();
node_manager_client.ProcessMessages();
}
void NodeManager::HandleWaitingTaskReady(const TaskID &task_id) {
@@ -639,8 +638,8 @@ void NodeManager::AssignTask(Task &task) {
}
}
void NodeManager::FinishAssignedTask(std::shared_ptr<Worker> worker) {
TaskID task_id = worker->GetAssignedTaskId();
void NodeManager::FinishAssignedTask(Worker &worker) {
TaskID task_id = worker.GetAssignedTaskId();
RAY_LOG(DEBUG) << "Finished task " << task_id;
auto tasks = local_queues_.RemoveTasks({task_id});
auto task = *tasks.begin();
@@ -648,7 +647,7 @@ void NodeManager::FinishAssignedTask(std::shared_ptr<Worker> worker) {
if (task.GetTaskSpecification().IsActorCreationTask()) {
// If this was an actor creation task, then convert the worker to an actor.
auto actor_id = task.GetTaskSpecification().ActorCreationId();
worker->AssignActorId(actor_id);
worker.AssignActorId(actor_id);
// Publish the actor creation event to all other nodes so that methods for
// the actor will be forwarded directly to this node.
@@ -684,7 +683,7 @@ void NodeManager::FinishAssignedTask(std::shared_ptr<Worker> worker) {
}
// Unset the worker's assigned task.
worker->AssignTaskId(TaskID::nil());
worker.AssignTaskId(TaskID::nil());
}
void NodeManager::ResubmitTask(const TaskID &task_id) {
+5 -5
View File
@@ -37,7 +37,7 @@ class NodeManager {
std::shared_ptr<gcs::AsyncGcsClient> gcs_client);
/// Process a new client connection.
void ProcessNewClient(std::shared_ptr<LocalClientConnection> client);
void ProcessNewClient(LocalClientConnection &client);
/// Process a message from a client. This method is responsible for
/// explicitly listening for more messages from the client if the client is
@@ -46,12 +46,12 @@ class NodeManager {
/// \param client The client that sent the message.
/// \param message_type The message type (e.g., a flatbuffer enum).
/// \param message A pointer to the message data.
void ProcessClientMessage(std::shared_ptr<LocalClientConnection> client,
void ProcessClientMessage(const std::shared_ptr<LocalClientConnection> &client,
int64_t message_type, const uint8_t *message);
void ProcessNewNodeManager(std::shared_ptr<TcpClientConnection> node_manager_client);
void ProcessNewNodeManager(TcpClientConnection &node_manager_client);
void ProcessNodeManagerMessage(std::shared_ptr<TcpClientConnection> node_manager_client,
void ProcessNodeManagerMessage(TcpClientConnection &node_manager_client,
int64_t message_type, const uint8_t *message);
ray::Status RegisterGcs();
@@ -69,7 +69,7 @@ class NodeManager {
/// Assign a task. The task is assumed to not be queued in local_queues_.
void AssignTask(Task &task);
/// Handle a worker finishing its assigned task.
void FinishAssignedTask(std::shared_ptr<Worker> worker);
void FinishAssignedTask(Worker &worker);
/// Schedule tasks.
void ScheduleTasks();
/// Handle a task whose local dependencies were missing and are now available.
+5 -11
View File
@@ -86,14 +86,12 @@ void Raylet::DoAcceptNodeManager() {
void Raylet::HandleAcceptNodeManager(const boost::system::error_code &error) {
if (!error) {
ClientHandler<boost::asio::ip::tcp> client_handler =
[this](std::shared_ptr<TcpClientConnection> client) {
node_manager_.ProcessNewNodeManager(client);
};
ClientHandler<boost::asio::ip::tcp> client_handler = [this](
TcpClientConnection &client) { node_manager_.ProcessNewNodeManager(client); };
MessageHandler<boost::asio::ip::tcp> message_handler = [this](
std::shared_ptr<TcpClientConnection> client, int64_t message_type,
const uint8_t *message) {
node_manager_.ProcessNodeManagerMessage(client, message_type, message);
node_manager_.ProcessNodeManagerMessage(*client, message_type, message);
};
// Accept a new local client and dispatch it to the node manager.
auto new_connection = TcpClientConnection::Create(client_handler, message_handler,
@@ -111,9 +109,7 @@ void Raylet::DoAcceptObjectManager() {
void Raylet::HandleAcceptObjectManager(const boost::system::error_code &error) {
ClientHandler<boost::asio::ip::tcp> client_handler =
[this](std::shared_ptr<TcpClientConnection> client) {
object_manager_.ProcessNewClient(client);
};
[this](TcpClientConnection &client) { object_manager_.ProcessNewClient(client); };
MessageHandler<boost::asio::ip::tcp> message_handler = [this](
std::shared_ptr<TcpClientConnection> client, int64_t message_type,
const uint8_t *message) {
@@ -134,9 +130,7 @@ void Raylet::HandleAccept(const boost::system::error_code &error) {
if (!error) {
// TODO: typedef these handlers.
ClientHandler<boost::asio::local::stream_protocol> client_handler =
[this](std::shared_ptr<LocalClientConnection> client) {
node_manager_.ProcessNewClient(client);
};
[this](LocalClientConnection &client) { node_manager_.ProcessNewClient(client); };
MessageHandler<boost::asio::local::stream_protocol> message_handler = [this](
std::shared_ptr<LocalClientConnection> client, int64_t message_type,
const uint8_t *message) {
+7 -6
View File
@@ -87,14 +87,15 @@ void WorkerPool::StartWorker(bool force_start) {
}
void WorkerPool::RegisterWorker(std::shared_ptr<Worker> worker) {
RAY_LOG(DEBUG) << "Registering worker with pid " << worker->Pid();
registered_workers_.push_back(worker);
RAY_CHECK(started_worker_pids_.count(worker->Pid()) > 0);
started_worker_pids_.erase(worker->Pid());
auto pid = worker->Pid();
RAY_LOG(DEBUG) << "Registering worker with pid " << pid;
registered_workers_.push_back(std::move(worker));
RAY_CHECK(started_worker_pids_.count(pid) > 0);
started_worker_pids_.erase(pid);
}
std::shared_ptr<Worker> WorkerPool::GetRegisteredWorker(
std::shared_ptr<LocalClientConnection> connection) const {
const std::shared_ptr<LocalClientConnection> &connection) const {
for (auto it = registered_workers_.begin(); it != registered_workers_.end(); it++) {
if ((*it)->Connection() == connection) {
return (*it);
@@ -135,7 +136,7 @@ std::shared_ptr<Worker> WorkerPool::PopWorker(const ActorID &actor_id) {
// A helper function to remove a worker from a list. Returns true if the worker
// was found and removed.
bool removeWorker(std::list<std::shared_ptr<Worker>> &worker_pool,
std::shared_ptr<Worker> worker) {
const std::shared_ptr<Worker> &worker) {
for (auto it = worker_pool.begin(); it != worker_pool.end(); it++) {
if (*it == worker) {
worker_pool.erase(it);
+1 -1
View File
@@ -60,7 +60,7 @@ class WorkerPool {
/// \return The Worker that owns the given client connection. Returns nullptr
/// if the client has not registered a worker yet.
std::shared_ptr<Worker> GetRegisteredWorker(
std::shared_ptr<LocalClientConnection> connection) const;
const std::shared_ptr<LocalClientConnection> &connection) const;
/// Disconnect a registered worker.
///
+3 -3
View File
@@ -30,8 +30,8 @@ class WorkerPoolTest : public ::testing::Test {
WorkerPoolTest() : worker_pool_({}), io_service_() {}
std::shared_ptr<Worker> CreateWorker(pid_t pid) {
std::function<void(std::shared_ptr<LocalClientConnection>)> client_handler = [this](
std::shared_ptr<LocalClientConnection> client) { HandleNewClient(client); };
std::function<void(LocalClientConnection &)> client_handler =
[this](LocalClientConnection &client) { HandleNewClient(client); };
std::function<void(std::shared_ptr<LocalClientConnection>, int64_t, const uint8_t *)>
message_handler = [this](std::shared_ptr<LocalClientConnection> client,
int64_t message_type, const uint8_t *message) {
@@ -49,7 +49,7 @@ class WorkerPoolTest : public ::testing::Test {
boost::asio::io_service io_service_;
private:
void HandleNewClient(std::shared_ptr<LocalClientConnection>){};
void HandleNewClient(LocalClientConnection &){};
void HandleMessage(std::shared_ptr<LocalClientConnection>, int64_t, const uint8_t *){};
};
+1 -1
View File
@@ -2036,7 +2036,7 @@ class GlobalStateAPI(unittest.TestCase):
@ray.remote
def f():
return id(ray.worker.global_worker)
return id(ray.worker.global_worker), os.getpid()
# Wait until all of the workers have started.
worker_ids = set()