XRay Task Forwarding Milestone (#1785)

Summary:
Able to run 1000 tasks with object dependencies on a set of distributed Raylets.

Raylet Changes:

Finalized ClientConnection class.
Task forwarding.
NM-to-NM heartbeats.
NM resource accounting for tasks.
Simple scheduling policy with task forwarding.
Creating and maintaining NM 2 NM long-lived connections and reusing them for task forwarding.
LineageCache Changes:

LineageCache without cleanup of tasks committed by remote nodes.
Lineage cache writeback and cleanup implementation.
ObjectManager Changes:

Object manager event loop/ClientConnection refactor.
Multithreaded object manager (disabled in this PR).
Testing Changes:

Integration tests for task submission on multiple Raylets.
Stress tests for object manager (with GCS and object store integration).


Co-authored-by: Stephanie Wang <swang@cs.berkeley.edu>
Co-authored-by: Alexey Tumanov <atumanov@gmail.com>
This commit is contained in:
Melih Elibol
2018-03-31 18:02:58 -07:00
committed by Philipp Moritz
parent 40c9b9cd60
commit 6e06a9e338
91 changed files with 4888 additions and 1799 deletions
+7
View File
@@ -102,6 +102,13 @@ install:
- cd python/ray/core
- bash ../../../src/ray/test/run_gcs_tests.sh
# Raylet tests.
- bash ../../../src/ray/test/run_object_manager_tests.sh
- bash ../../../src/ray/test/run_task_test.sh
- ./src/ray/raylet/task_test
- ./src/ray/raylet/worker_pool_test
- ./src/ray/raylet/lineage_cache_test
- bash ../../../src/common/test/run_tests.sh
- bash ../../../src/plasma/test/run_tests.sh
- bash ../../../src/local_scheduler/test/run_tests.sh
+11 -3
View File
@@ -50,7 +50,14 @@
}
static const char *table_prefixes[] = {
NULL, "TASK:", "TASK:", "CLIENT:", "OBJECT:", "FUNCTION:",
NULL,
"TASK:",
"TASK:",
"CLIENT:",
"OBJECT:",
"FUNCTION:",
"TASK_RECONSTRUCTION:",
"HEARTBEAT:",
};
/// Parse a Redis string into a TablePubsub channel.
@@ -811,8 +818,9 @@ int TableRequestNotifications_RedisCommand(RedisModuleCtx *ctx,
// notifications.
flatbuffers::FlatBufferBuilder fbb;
TableEntryToFlatbuf(table_key, id, fbb);
RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, reinterpret_cast<const char *>(fbb.GetBufferPointer()),
fbb.GetSize());
RedisModule_Call(ctx, "PUBLISH", "sb", client_channel,
reinterpret_cast<const char *>(fbb.GetBufferPointer()),
fbb.GetSize());
}
RedisModule_CloseKey(table_key);
+1 -6
View File
@@ -141,13 +141,8 @@ GlobalSchedulerState *GlobalSchedulerState_init(event_loop *loop,
std::vector<std::string>());
db_attach(state->db, loop, false);
ClientTableDataT client_info;
client_info.client_id = get_db_client_id(state->db).binary();
client_info.node_manager_address = std::string(node_ip_address);
client_info.local_scheduler_port = 0;
client_info.object_manager_port = 0;
RAY_CHECK_OK(state->gcs_client.Connect(std::string(redis_primary_addr),
redis_primary_port, client_info));
redis_primary_port));
RAY_CHECK_OK(state->gcs_client.context()->AttachToEventLoop(loop));
state->policy_state = GlobalSchedulerPolicyState_init();
return state;
+1 -6
View File
@@ -355,13 +355,8 @@ LocalSchedulerState *LocalSchedulerState_init(
"local_scheduler", node_ip_address, db_connect_args);
db_attach(state->db, loop, false);
ClientTableDataT client_info;
client_info.client_id = get_db_client_id(state->db).binary();
client_info.node_manager_address = std::string(node_ip_address);
client_info.local_scheduler_port = 0;
client_info.object_manager_port = 0;
RAY_CHECK_OK(state->gcs_client.Connect(std::string(redis_primary_addr),
redis_primary_port, client_info));
redis_primary_port));
RAY_CHECK_OK(state->gcs_client.context()->AttachToEventLoop(loop));
} else {
state->db = NULL;
+1 -6
View File
@@ -487,13 +487,8 @@ PlasmaManagerState *PlasmaManagerState_init(const char *store_socket_name,
"plasma_manager", manager_addr, db_connect_args);
db_attach(state->db, state->loop, false);
ClientTableDataT client_info;
client_info.client_id = get_db_client_id(state->db).binary();
client_info.node_manager_address = std::string(manager_addr);
client_info.local_scheduler_port = 0;
client_info.object_manager_port = manager_port;
RAY_CHECK_OK(state->gcs_client.Connect(std::string(redis_primary_addr),
redis_primary_port, client_info));
redis_primary_port));
RAY_CHECK_OK(state->gcs_client.context()->AttachToEventLoop(state->loop));
} else {
state->db = NULL;
+4 -1
View File
@@ -38,8 +38,11 @@ set(RAY_SRCS
gcs/asio.cc
common/client_connection.cc
object_manager/object_manager_client_connection.cc
object_manager/object_store_client.cc
object_manager/connection_pool.cc
object_manager/object_store_client_pool.cc
object_manager/object_store_notification_manager.cc
object_manager/object_directory.cc
object_manager/transfer_queue.cc
object_manager/object_manager.cc
raylet/mock_gcs_client.cc
raylet/task.cc
+72 -44
View File
@@ -7,20 +7,82 @@
namespace ray {
ray::Status TcpConnect(boost::asio::ip::tcp::socket &socket,
const std::string &ip_address_string, int port) {
boost::asio::ip::address ip_address =
boost::asio::ip::address::from_string(ip_address_string);
boost::asio::ip::tcp::endpoint endpoint(ip_address, port);
boost::system::error_code error;
socket.connect(endpoint, error);
if (error) {
return ray::Status::IOError(error.message());
} else {
return ray::Status::OK();
}
}
template <class T>
ServerConnection<T>::ServerConnection(boost::asio::basic_stream_socket<T> &&socket)
: socket_(std::move(socket)) {}
template <class T>
void ServerConnection<T>::WriteBuffer(
const std::vector<boost::asio::const_buffer> &buffer, boost::system::error_code &ec) {
boost::asio::write(socket_, buffer, ec);
}
template <class T>
void ServerConnection<T>::ReadBuffer(
const std::vector<boost::asio::mutable_buffer> &buffer,
boost::system::error_code &ec) {
boost::asio::read(socket_, buffer, ec);
}
template <class T>
ray::Status ServerConnection<T>::WriteMessage(int64_t type, int64_t length,
const uint8_t *message) {
std::vector<boost::asio::const_buffer> message_buffers;
auto write_version = RayConfig::instance().ray_protocol_version();
message_buffers.push_back(boost::asio::buffer(&write_version, sizeof(write_version)));
message_buffers.push_back(boost::asio::buffer(&type, sizeof(type)));
message_buffers.push_back(boost::asio::buffer(&length, sizeof(length)));
message_buffers.push_back(boost::asio::buffer(message, length));
// Write the message and then wait for more messages.
// TODO(swang): Does this need to be an async write?
boost::system::error_code error;
boost::asio::write(socket_, message_buffers, error);
if (error) {
return ray::Status::IOError(error.message());
} else {
return ray::Status::OK();
}
}
template <class T>
std::shared_ptr<ClientConnection<T>> ClientConnection<T>::Create(
ClientManager<T> &manager, boost::asio::basic_stream_socket<T> &&socket) {
ClientHandler<T> &client_handler, MessageHandler<T> &message_handler,
boost::asio::basic_stream_socket<T> &&socket) {
std::shared_ptr<ClientConnection<T>> self(
new ClientConnection(manager, std::move(socket)));
new ClientConnection(message_handler, std::move(socket)));
// Let our manager process our new connection.
self->manager_.ProcessNewClient(self);
client_handler(self);
return self;
}
template <class T>
ClientConnection<T>::ClientConnection(ClientManager<T> &manager,
ClientConnection<T>::ClientConnection(MessageHandler<T> &message_handler,
boost::asio::basic_stream_socket<T> &&socket)
: socket_(std::move(socket)), manager_(manager) {}
: ServerConnection<T>(std::move(socket)), message_handler_(message_handler) {}
template <class T>
const ClientID &ClientConnection<T>::GetClientID() {
return client_id_;
}
template <class T>
void ClientConnection<T>::SetClientID(const ClientID &client_id) {
client_id_ = client_id;
}
template <class T>
void ClientConnection<T>::ProcessMessages() {
@@ -31,7 +93,7 @@ void ClientConnection<T>::ProcessMessages() {
header.push_back(boost::asio::buffer(&read_type_, sizeof(read_type_)));
header.push_back(boost::asio::buffer(&read_length_, sizeof(read_length_)));
boost::asio::async_read(
socket_, header,
ServerConnection<T>::socket_, header,
boost::bind(&ClientConnection<T>::ProcessMessageHeader, this->shared_from_this(),
boost::asio::placeholders::error));
}
@@ -52,57 +114,23 @@ void ClientConnection<T>::ProcessMessageHeader(const boost::system::error_code &
read_message_.resize(read_length_);
// Wait for the message to be read.
boost::asio::async_read(
socket_, boost::asio::buffer(read_message_),
ServerConnection<T>::socket_, boost::asio::buffer(read_message_),
boost::bind(&ClientConnection<T>::ProcessMessage, this->shared_from_this(),
boost::asio::placeholders::error));
}
template <class T>
void ClientConnection<T>::WriteMessage(int64_t type, size_t length,
const uint8_t *message) {
std::vector<boost::asio::const_buffer> message_buffers;
write_version_ = RayConfig::instance().ray_protocol_version();
write_type_ = type;
write_length_ = length;
write_message_.assign(message, message + length);
message_buffers.push_back(boost::asio::buffer(&write_version_, sizeof(write_version_)));
message_buffers.push_back(boost::asio::buffer(&write_type_, sizeof(write_type_)));
message_buffers.push_back(boost::asio::buffer(&write_length_, sizeof(write_length_)));
message_buffers.push_back(boost::asio::buffer(write_message_));
boost::system::error_code error;
// Write the message and then wait for more messages.
boost::asio::async_write(
socket_, message_buffers,
boost::bind(&ClientConnection<T>::ProcessMessages, this->shared_from_this(),
boost::asio::placeholders::error));
}
template <class T>
void ClientConnection<T>::ProcessMessage(const boost::system::error_code &error) {
if (error) {
// TODO(hme): Disconnect differently & remove dependency on node_manager_generated.h
read_type_ = protocol::MessageType_DisconnectClient;
}
manager_.ProcessClientMessage(this->shared_from_this(), read_type_,
read_message_.data());
}
template <class T>
void ClientConnection<T>::ProcessMessages(const boost::system::error_code &error) {
if (error) {
ProcessMessage(error);
} else {
ProcessMessages();
}
message_handler_(this->shared_from_this(), read_type_, read_message_.data());
}
template class ServerConnection<boost::asio::local::stream_protocol>;
template class ServerConnection<boost::asio::ip::tcp>;
template class ClientConnection<boost::asio::local::stream_protocol>;
template class ClientConnection<boost::asio::ip::tcp>;
template <class T>
ClientManager<T>::~ClientManager<T>() {}
template class ClientManager<boost::asio::local::stream_protocol>;
template class ClientManager<boost::asio::ip::tcp>;
} // namespace ray
+80 -56
View File
@@ -7,18 +7,74 @@
#include <boost/asio/error.hpp>
#include <boost/enable_shared_from_this.hpp>
#include "ray/id.h"
#include "ray/status.h"
namespace ray {
template <class T>
class ClientManager;
/// \class ClientConnection
/// Connect a TCP socket.
///
/// A generic type representing a client connection on a server. This class can
/// be used to process and write messages asynchronously from and to the
/// client.
template <class T>
class ClientConnection : public std::enable_shared_from_this<ClientConnection<T>> {
/// \param socket The socket to connect.
/// \param ip_address The IP address to connect to.
/// \param port The port to connect to.
/// \return Status.
ray::Status TcpConnect(boost::asio::ip::tcp::socket &socket,
const std::string &ip_address, int port);
/// \typename ServerConnection
///
/// A generic type representing a client connection to a server. This typename
/// can be used to write messages synchronously to the server.
template <typename T>
class ServerConnection {
public:
/// Create a connection to the server.
ServerConnection(boost::asio::basic_stream_socket<T> &&socket);
/// Write a message to the client.
///
/// \param type The message type (e.g., a flatbuffer enum).
/// \param length The size in bytes of the message.
/// \param message A pointer to the message buffer.
/// \return Status.
ray::Status WriteMessage(int64_t type, int64_t length, const uint8_t *message);
/// Write a buffer to this connection.
///
/// \param buffer The buffer.
/// \param ec The error code object in which to store error codes.
void WriteBuffer(const std::vector<boost::asio::const_buffer> &buffer,
boost::system::error_code &ec);
/// Read a buffer from this connection.
///
/// \param buffer The buffer.
/// \param ec The error code object in which to store error codes.
void ReadBuffer(const std::vector<boost::asio::mutable_buffer> &buffer,
boost::system::error_code &ec);
protected:
/// The socket connection to the server.
boost::asio::basic_stream_socket<T> socket_;
};
template <typename T>
class ClientConnection;
template <typename T>
using ClientHandler = std::function<void(std::shared_ptr<ClientConnection<T>>)>;
template <typename T>
using MessageHandler =
std::function<void(std::shared_ptr<ClientConnection<T>>, int64_t, const uint8_t *)>;
/// \typename ClientConnection
///
/// A generic type representing a client connection on a server. In addition to
/// writing messages to the client, like in ServerConnection, this typename can
/// also be used to process messages asynchronously from client.
template <typename T>
class ClientConnection : public ServerConnection<T>,
public std::enable_shared_from_this<ClientConnection<T>> {
public:
/// Allocate a new node client connection.
///
@@ -27,24 +83,23 @@ class ClientConnection : public std::enable_shared_from_this<ClientConnection<T>
/// \param socket The client socket.
/// \return std::shared_ptr<ClientConnection>.
static std::shared_ptr<ClientConnection<T>> Create(
ClientManager<T> &manager, boost::asio::basic_stream_socket<T> &&socket);
ClientHandler<T> &new_client_handler, MessageHandler<T> &message_handler,
boost::asio::basic_stream_socket<T> &&socket);
/// \return The ClientID of the remote client.
const ClientID &GetClientID();
/// \param client_id The ClientID of the remote client.
void SetClientID(const ClientID &client_id);
/// Listen for and process messages from the client connection. Once a
/// message has been fully received, the client manager's
/// ProcessClientMessage handler will be called.
void ProcessMessages();
/// Write a message to the client and then listen for more messages.
///
/// \param type The message type (e.g., a flatbuffer enum).
/// \param length The size in bytes of the message.
/// \param message A pointer to the message buffer. This will be copied into
/// the ClientConnection's buffer.
void WriteMessage(int64_t type, size_t length, const uint8_t *message);
private:
/// A private constructor for a node client connection.
ClientConnection(ClientManager<T> &manager,
ClientConnection(MessageHandler<T> &message_handler,
boost::asio::basic_stream_socket<T> &&socket);
/// Process an error from the last operation, then process the message
/// header from the client.
@@ -52,54 +107,23 @@ class ClientConnection : public std::enable_shared_from_this<ClientConnection<T>
/// Process an error from reading the message header, then process the
/// message from the client.
void ProcessMessage(const boost::system::error_code &error);
/// Process an error from the last operation and then listen for more
/// messages.
void ProcessMessages(const boost::system::error_code &error);
/// The client socket.
boost::asio::basic_stream_socket<T> socket_;
/// A reference to the manager for this client. The manager exposes a handler
/// for all messages processed by this client.
ClientManager<T> &manager_;
/// The ClientID of the remote client.
ClientID client_id_;
/// The handler for a message from the client.
MessageHandler<T> message_handler_;
/// Buffers for the current message being read rom the client.
int64_t read_version_;
int64_t read_type_;
uint64_t read_length_;
std::vector<uint8_t> read_message_;
/// Buffers for the current message being written to the client.
int64_t write_version_;
int64_t write_type_;
uint64_t write_length_;
std::vector<uint8_t> write_message_;
};
using LocalServerConnection = ServerConnection<boost::asio::local::stream_protocol>;
using TcpServerConnection = ServerConnection<boost::asio::ip::tcp>;
using LocalClientConnection = ClientConnection<boost::asio::local::stream_protocol>;
using TcpClientConnection = ClientConnection<boost::asio::ip::tcp>;
/// \class ClientManager
///
/// A virtual cliant manager. Derived classes should define a method for
/// processing a message on the server sent by the client.
template <class T>
class ClientManager {
public:
/// Process a new client connection.
///
/// \param client A shared pointer to the client that connected.
virtual void ProcessNewClient(std::shared_ptr<ClientConnection<T>> client) = 0;
/// Process a message from a client, then listen for more messages if the
/// client is still alive.
///
/// \param client A shared pointer to the client that sent the message.
/// \param message_type The message type (e.g., a flatbuffer enum).
/// \param message A pointer to the message buffer.
virtual void ProcessClientMessage(std::shared_ptr<ClientConnection<T>> client,
int64_t message_type, const uint8_t *message) = 0;
virtual ~ClientManager() = 0;
};
} // namespace ray
#endif // RAY_COMMON_CLIENT_CONNECTION_H
+18
View File
@@ -4,6 +4,24 @@
/// Length of Ray IDs in bytes.
constexpr int64_t kUniqueIDSize = 20;
/// An ObjectID's bytes are split into the task ID itself and the index of the
/// object's creation. This is the maximum width of the object index in bits.
constexpr int kObjectIdIndexSize = 32;
/// The maximum number of objects that can be returned by a task when finishing
/// execution. An ObjectID's bytes are split into the task ID itself and the
/// index of the object's creation. A positive index indicates an object
/// returned by the task, so the maximum number of objects that a task can
/// return is the maximum positive value for an integer with bit-width
/// `kObjectIdIndexSize`.
constexpr int64_t kMaxTaskReturns = ((int64_t)1 << (kObjectIdIndexSize - 1)) - 1;
/// The maximum number of objects that can be put by a task during execution.
/// An ObjectID's bytes are split into the task ID itself and the index of the
/// object's creation. A negative index indicates an object put by the task
/// during execution, so the maximum number of objects that a task can put is
/// the maximum negative value for an integer with bit-width
/// `kObjectIdIndexSize`.
constexpr int64_t kMaxTaskPuts = ((int64_t)1 << (kObjectIdIndexSize - 1));
/// Prefix for the object table keys in redis.
constexpr char kObjectTablePrefix[] = "ObjectTable";
/// Prefix for the task table keys in redis.
+13 -8
View File
@@ -6,19 +6,22 @@ namespace ray {
namespace gcs {
AsyncGcsClient::AsyncGcsClient() {}
AsyncGcsClient::~AsyncGcsClient() {}
Status AsyncGcsClient::Connect(const std::string &address, int port,
const ClientTableDataT &client_info) {
AsyncGcsClient::AsyncGcsClient(const ClientID &client_id) {
context_.reset(new RedisContext());
RAY_RETURN_NOT_OK(context_->Connect(address, port));
client_table_.reset(new ClientTable(context_, this, client_id));
object_table_.reset(new ObjectTable(context_, this));
task_table_.reset(new TaskTable(context_, this));
raylet_task_table_.reset(new raylet::TaskTable(context_, this));
task_reconstruction_log_.reset(new TaskReconstructionLog(context_, this));
client_table_.reset(new ClientTable(context_, this, client_info));
heartbeat_table_.reset(new HeartbeatTable(context_, this));
}
AsyncGcsClient::AsyncGcsClient() : AsyncGcsClient(ClientID::from_random()) {}
AsyncGcsClient::~AsyncGcsClient() {}
Status AsyncGcsClient::Connect(const std::string &address, int port) {
RAY_RETURN_NOT_OK(context_->Connect(address, port));
// TODO(swang): Call the client table's Connect() method here. To do this,
// we need to make sure that we are attached to an event loop first. This
// currently isn't possible because the aeEventLoop, which we use for
@@ -55,6 +58,8 @@ FunctionTable &AsyncGcsClient::function_table() { return *function_table_; }
ClassTable &AsyncGcsClient::class_table() { return *class_table_; }
HeartbeatTable &AsyncGcsClient::heartbeat_table() { return *heartbeat_table_; }
} // namespace gcs
} // namespace ray
+10 -3
View File
@@ -19,6 +19,13 @@ class RedisContext;
class RAY_EXPORT AsyncGcsClient {
public:
/// Start a GCS client with the given client ID. To read from the GCS tables,
/// Connect and then Attach must be called. To read and write from the GCS
/// tables requires a further call to Connect to the client table.
///
/// \param client_id The ID to assign to the client.
AsyncGcsClient(const ClientID &client_id);
/// Start a GCS client with a random client ID.
AsyncGcsClient();
~AsyncGcsClient();
@@ -26,10 +33,8 @@ class RAY_EXPORT AsyncGcsClient {
///
/// \param address The GCS IP address.
/// \param port The GCS port.
/// \param client_info Information about the local client to connect.
/// \return Status.
Status Connect(const std::string &address, int port,
const ClientTableDataT &client_info);
Status Connect(const std::string &address, int port);
/// Attach this client to a plasma event loop. Note that only
/// one event loop should be attached at a time.
Status Attach(plasma::EventLoop &event_loop);
@@ -48,6 +53,7 @@ class RAY_EXPORT AsyncGcsClient {
raylet::TaskTable &raylet_task_table();
TaskReconstructionLog &task_reconstruction_log();
ClientTable &client_table();
HeartbeatTable &heartbeat_table();
inline ErrorTable &error_table();
// We also need something to export generic code to run on workers from the
@@ -67,6 +73,7 @@ class RAY_EXPORT AsyncGcsClient {
std::unique_ptr<TaskTable> task_table_;
std::unique_ptr<raylet::TaskTable> raylet_task_table_;
std::unique_ptr<TaskReconstructionLog> task_reconstruction_log_;
std::unique_ptr<HeartbeatTable> heartbeat_table_;
std::unique_ptr<ClientTable> client_table_;
std::shared_ptr<RedisContext> context_;
std::unique_ptr<RedisAsioClient> asio_async_client_;
+13 -8
View File
@@ -23,12 +23,7 @@ class TestGcs : public ::testing::Test {
public:
TestGcs() : num_callbacks_(0) {
client_ = std::make_shared<gcs::AsyncGcsClient>();
ClientTableDataT client_info;
client_info.client_id = ClientID::from_random().binary();
client_info.node_manager_address = "127.0.0.1";
client_info.local_scheduler_port = 0;
client_info.object_manager_port = 0;
RAY_CHECK_OK(client_->Connect("127.0.0.1", 6379, client_info));
RAY_CHECK_OK(client_->Connect("127.0.0.1", 6379));
job_id_ = JobID::from_random();
}
@@ -747,6 +742,7 @@ void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client
ClientID added_id = client->client_table().GetLocalClientId();
ASSERT_EQ(client_id, added_id);
ASSERT_EQ(ClientID::from_binary(data.client_id), added_id);
ASSERT_EQ(ClientID::from_binary(data.client_id), added_id);
ASSERT_EQ(data.is_insertion, is_insertion);
auto cached_client = client->client_table().GetClient(added_id);
@@ -763,9 +759,14 @@ void TestClientTableConnect(const JobID &job_id,
ClientTableNotification(client, id, data, true);
test->Stop();
});
// Connect and disconnect to client table. We should receive notifications
// for the addition and removal of our own entry.
RAY_CHECK_OK(client->client_table().Connect());
ClientTableDataT local_client_info = client->client_table().GetLocalClient();
local_client_info.node_manager_address = "127.0.0.1";
local_client_info.node_manager_port = 0;
local_client_info.object_manager_port = 0;
RAY_CHECK_OK(client->client_table().Connect(local_client_info));
test->Start();
}
@@ -789,7 +790,11 @@ void TestClientTableDisconnect(const JobID &job_id,
});
// Connect and disconnect to client table. We should receive notifications
// for the addition and removal of our own entry.
RAY_CHECK_OK(client->client_table().Connect());
ClientTableDataT local_client_info = client->client_table().GetLocalClient();
local_client_info.node_manager_address = "127.0.0.1";
local_client_info.node_manager_port = 0;
local_client_info.object_manager_port = 0;
RAY_CHECK_OK(client->client_table().Connect(local_client_info));
RAY_CHECK_OK(client->client_table().Disconnect());
test->Start();
}
+23 -16
View File
@@ -11,7 +11,8 @@ enum TablePrefix:int {
CLIENT,
OBJECT,
FUNCTION,
TASK_RECONSTRUCTION
TASK_RECONSTRUCTION,
HEARTBEAT
}
// The channel that Add operations to the Table should be published on, if any.
@@ -21,7 +22,8 @@ enum TablePubsub:int {
RAYLET_TASK,
CLIENT,
OBJECT,
ACTOR
ACTOR,
HEARTBEAT
}
table GcsTableEntry {
@@ -98,6 +100,13 @@ table CustomSerializerData {
table ConfigTableData {
}
table RayResource {
// The type of the resource.
resource_name: string;
// The total capacity of this resource type.
resource_capacity: double;
}
table ClientTableData {
// The client ID of the client that the message is about.
client_id: string;
@@ -105,26 +114,24 @@ table ClientTableData {
node_manager_address: string;
// The port at which the client's node manager is listening for TCP
// connections from other node managers.
local_scheduler_port: int;
node_manager_port: int;
// The port at which the client's object manager is listening for TCP
// connections from other object managers.
object_manager_port: int;
// True if the message is about the addition of a client and false if it is
// about the deletion of a client.
is_insertion: bool;
resources_total_label: [string];
resources_total_capacity: [double];
}
table Resource {
// The type of the resource.
resource_name: string;
// The total capacity of this resource type.
resource_capacity: double;
}
table NodeManagerHeartbeat {
// The available resources on this node manager. This information may be
// stale.
resources_available: [Resource];
// The total resources on this node manager.
resources_total: [Resource];
table HeartbeatTableData {
// Node manager client id
client_id: string;
// Resource capacity currently available on this node manager.
resources_available_label: [string];
resources_available_capacity: [double];
// Total resource capacity configured for this node manager.
resources_total_label: [string];
resources_total_capacity: [double];
}
+4 -2
View File
@@ -25,7 +25,8 @@ void ProcessCallback(int64_t callback_index, const std::string &data) {
}
}
}
}
} // namespace
namespace ray {
@@ -98,9 +99,10 @@ void SubscribeRedisCallback(void *c, void *r, void *privdata) {
}
int64_t RedisCallbackManager::add(const RedisCallback &function) {
num_callbacks += 1;
callbacks_.emplace(num_callbacks, std::unique_ptr<RedisCallback>(
new RedisCallback(function)));
return num_callbacks++;
return num_callbacks;
}
RedisCallbackManager::RedisCallback &RedisCallbackManager::get(
+2 -1
View File
@@ -49,7 +49,8 @@ class RedisCallbackManager {
class RedisContext {
public:
RedisContext() {}
RedisContext()
: context_(nullptr), async_context_(nullptr), subscribe_context_(nullptr) {}
~RedisContext();
Status Connect(const std::string &address, int port);
Status AttachToEventLoop(aeEventLoop *loop);
+8 -4
View File
@@ -89,8 +89,8 @@ Status Log<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
<< "Client called Subscribe twice on the same table";
auto d = std::shared_ptr<CallbackData>(
new CallbackData({client_id, nullptr, subscribe, done, this, client_}));
int64_t callback_index = RedisCallbackManager::instance().add(
[this, d](const std::string &data) {
int64_t callback_index =
RedisCallbackManager::instance().add([this, d](const std::string &data) {
if (data.empty()) {
// No notification data is provided. This is the callback for the
// initial subscription request.
@@ -115,7 +115,7 @@ Status Log<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
results.emplace_back(std::move(result));
}
(d->callback)(d->client, id, results);
}
}
}
// We do not delete the callback after calling it since there may be
// more subscription messages.
@@ -270,9 +270,12 @@ const ClientID &ClientTable::GetLocalClientId() { return client_id_; }
const ClientTableDataT &ClientTable::GetLocalClient() { return local_client_; }
Status ClientTable::Connect() {
Status ClientTable::Connect(const ClientTableDataT &local_client) {
RAY_CHECK(!disconnected_) << "Tried to reconnect a disconnected client.";
RAY_CHECK(local_client.client_id == local_client_.client_id);
local_client_ = local_client;
auto data = std::make_shared<ClientTableDataT>(local_client_);
data->is_insertion = true;
// Callback for a notification from the client table.
@@ -336,6 +339,7 @@ template class Log<TaskID, ray::protocol::Task>;
template class Table<TaskID, ray::protocol::Task>;
template class Table<TaskID, TaskTableData>;
template class Log<TaskID, TaskReconstructionData>;
template class Table<ClientID, HeartbeatTableData>;
} // namespace gcs
+40 -12
View File
@@ -167,13 +167,23 @@ class Log {
int64_t subscribe_callback_index_;
};
template <typename ID, typename Data>
class TableInterface {
public:
using DataT = typename Data::NativeTableType;
using WriteCallback = typename Log<ID, Data>::WriteCallback;
virtual Status Add(const JobID &job_id, const ID &task_id, std::shared_ptr<DataT> data,
const WriteCallback &done) = 0;
virtual ~TableInterface(){};
};
/// \class Table
///
/// A GCS table where every entry is a single data item.
/// Example tables backed by Log:
/// TaskTable: Stores Task metadata needed for executing the task.
template <typename ID, typename Data>
class Table : private Log<ID, Data> {
class Table : private Log<ID, Data>, public TableInterface<ID, Data> {
public:
using DataT = typename Log<ID, Data>::DataT;
using Callback =
@@ -242,6 +252,17 @@ class ObjectTable : public Log<ObjectID, ObjectTableData> {
pubsub_channel_ = TablePubsub_OBJECT;
prefix_ = TablePrefix_OBJECT;
};
virtual ~ObjectTable(){};
};
class HeartbeatTable : public Table<ClientID, HeartbeatTableData> {
public:
HeartbeatTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: Table(context, client) {
pubsub_channel_ = TablePubsub_HEARTBEAT;
prefix_ = TablePrefix_HEARTBEAT;
}
virtual ~HeartbeatTable() {}
};
class FunctionTable : public Table<ObjectID, FunctionTableData> {
@@ -277,7 +298,8 @@ class TaskTable : public Table<TaskID, ray::protocol::Task> {
prefix_ = TablePrefix_RAYLET_TASK;
}
};
}
} // namespace raylet
class TaskTable : public Table<TaskID, TaskTableData> {
public:
@@ -286,6 +308,7 @@ class TaskTable : public Table<TaskID, TaskTableData> {
pubsub_channel_ = TablePubsub_TASK;
prefix_ = TablePrefix_TASK;
};
~TaskTable(){};
using TestAndUpdateCallback =
std::function<void(AsyncGcsClient *client, const TaskID &id,
@@ -350,12 +373,6 @@ class TaskTable : public Table<TaskID, TaskTableData> {
const Callback &done);
};
using ErrorTable = Table<TaskID, ErrorTableData>;
using CustomSerializerTable = Table<ClassID, CustomSerializerData>;
using ConfigTable = Table<ConfigID, ConfigTableData>;
Status TaskTableAdd(AsyncGcsClient *gcs_client, Task *task);
Status TaskTableTestAndUpdate(AsyncGcsClient *gcs_client, const TaskID &task_id,
@@ -363,6 +380,12 @@ Status TaskTableTestAndUpdate(AsyncGcsClient *gcs_client, const TaskID &task_id,
SchedulingState update_state,
const TaskTable::TestAndUpdateCallback &callback);
using ErrorTable = Table<TaskID, ErrorTableData>;
using CustomSerializerTable = Table<ClassID, CustomSerializerData>;
using ConfigTable = Table<ConfigID, ConfigTableData>;
/// \class ClientTable
///
/// The ClientTable stores information about active and inactive clients. It is
@@ -377,17 +400,20 @@ class ClientTable : private Log<UniqueID, ClientTableData> {
using ClientTableCallback = std::function<void(
AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data)>;
ClientTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client,
const ClientTableDataT &local_client)
const ClientID &client_id)
: Log(context, client),
// We set the client log's key equal to nil so that all instances of
// ClientTable have the same key.
client_log_key_(UniqueID::nil()),
disconnected_(false),
client_id_(ClientID::from_binary(local_client.client_id)),
local_client_(local_client) {
client_id_(client_id),
local_client_() {
pubsub_channel_ = TablePubsub_CLIENT;
prefix_ = TablePrefix_CLIENT;
// Set the local client's ID.
local_client_.client_id = client_id.binary();
// Add a nil client to the cache so that we can serve requests for clients
// that we have not heard about.
ClientTableDataT nil_client;
@@ -398,8 +424,10 @@ class ClientTable : private Log<UniqueID, ClientTableData> {
/// Connect as a client to the GCS. This registers us in the client table
/// and begins subscription to client table notifications.
///
/// \param Information about the connecting client. This must have the
/// same client_id as the one set in the client table.
/// \return Status
ray::Status Connect();
ray::Status Connect(const ClientTableDataT &local_client);
/// Disconnect the client from the GCS. The client ID assigned during
/// registration should never be reused after disconnecting.
+4 -2
View File
@@ -1,13 +1,15 @@
#include "ray/gcs/tables.h"
#include "ray/gcs/client.h"
#include "ray/id.h"
#include "common_protocol.h"
#include "task.h"
// TODO(swang): This file extends tables.cc so that we can separate out the
// part that depends on the Task* datasturcture from the build. This should be
// merged with tables.cc once we get rid of the Task* datastructure.
// part that depends on the legacy Task* data structure from the build. This
// should be merged with tables.cc once we get rid of the legacy Task*
// datastructure.
namespace {
+47
View File
@@ -2,6 +2,9 @@
#include <random>
#include "ray/constants.h"
#include "ray/status.h"
namespace ray {
UniqueID::UniqueID(const plasma::UniqueID &from) {
@@ -83,4 +86,48 @@ std::ostream &operator<<(std::ostream &os, const UniqueID &id) {
return os;
}
const ObjectID ComputeObjectId(TaskID task_id, int64_t object_index) {
RAY_CHECK(object_index <= kMaxTaskReturns && object_index >= -kMaxTaskPuts);
ObjectID return_id = task_id;
int64_t *first_bytes = reinterpret_cast<int64_t *>(&return_id);
// Zero out the lowest kObjectIdIndexSize bits of the first byte of the
// object ID.
uint64_t bitmask = static_cast<uint64_t>(-1) << kObjectIdIndexSize;
*first_bytes = *first_bytes & (bitmask);
// OR the first byte of the object ID with the return index.
*first_bytes = *first_bytes | (object_index & ~bitmask);
return return_id;
}
const TaskID FinishTaskId(const TaskID &task_id) { return ComputeObjectId(task_id, 0); }
const ObjectID ComputeReturnId(TaskID task_id, int64_t return_index) {
RAY_CHECK(return_index >= 1 && return_index <= kMaxTaskReturns);
return ComputeObjectId(task_id, return_index);
}
const ObjectID ComputePutId(TaskID task_id, int64_t put_index) {
RAY_CHECK(put_index >= 1 && put_index <= kMaxTaskPuts);
return ComputeObjectId(task_id, -1 * put_index);
}
const TaskID ComputeTaskId(const ObjectID &object_id) {
TaskID task_id = object_id;
int64_t *first_bytes = reinterpret_cast<int64_t *>(&task_id);
// Zero out the lowest kObjectIdIndexSize bits of the first byte of the
// object ID.
uint64_t bitmask = static_cast<uint64_t>(-1) << kObjectIdIndexSize;
*first_bytes = *first_bytes & (bitmask);
return task_id;
}
int64_t ComputeObjectIndex(const ObjectID &object_id) {
const int64_t *first_bytes = reinterpret_cast<const int64_t *>(&object_id);
uint64_t bitmask = static_cast<uint64_t>(-1) << kObjectIdIndexSize;
int64_t index = *first_bytes & (~bitmask);
index <<= (8 * sizeof(int64_t) - kObjectIdIndexSize);
index >>= (8 * sizeof(int64_t) - kObjectIdIndexSize);
return index;
}
} // namespace ray
+38 -3
View File
@@ -10,9 +10,6 @@
#include "ray/constants.h"
#include "ray/util/visibility.h"
// TODO(swang): Make task ID prefix of any object ID return values and puts so
// that we can co-locate task and object entries in the GCS.
namespace ray {
class RAY_EXPORT UniqueID {
@@ -61,6 +58,44 @@ typedef UniqueID DriverID;
typedef UniqueID ConfigID;
typedef UniqueID ClientID;
// TODO(swang): ObjectID and TaskID should derive from UniqueID. Then, we
// can make these methods of the derived classes.
/// Finish computing a task ID. Since objects created by the task share a
/// prefix of the ID, the suffix of the task ID is zeroed out by this function.
///
/// \param task_id A task ID to finish.
/// \return The finished task ID. It may now be used to compute IDs for objects
/// created by the task.
const TaskID FinishTaskId(const TaskID &task_id);
/// Compute the object ID of an object returned by the task.
///
/// \param task_id The task ID of the task that created the object.
/// \param put_index What number return value this object is in the task.
/// \return The computed object ID.
const ObjectID ComputeReturnId(TaskID task_id, int64_t return_index);
/// Compute the object ID of an object put by the task.
///
/// \param task_id The task ID of the task that created the object.
/// \param put_index What number put this object was created by in the task.
/// \return The computed object ID.
const ObjectID ComputePutId(TaskID task_id, int64_t put_index);
/// Compute the task ID of the task that created the object.
///
/// \param object_id The object ID.
/// \return The task ID of the task that created this object.
const TaskID ComputeTaskId(const ObjectID &object_id);
/// Compute the index of this object in the task that created it.
///
/// \param object_id The object ID.
/// \return The index of object creation according to the task that created
/// this object. This is positive if the task returned the object and negative
/// if created by a put.
int64_t ComputeObjectIndex(const ObjectID &object_id);
} // namespace ray
#endif // RAY_ID_H_
+2 -1
View File
@@ -19,7 +19,8 @@ add_custom_command(
add_custom_target(gen_object_manager_fbs DEPENDS ${OBJECT_MANAGER_FBS_OUTPUT_FILES})
ADD_RAY_TEST(object_manager_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main pthread ${Boost_SYSTEM_LIBRARY})
ADD_RAY_TEST(test/object_manager_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main pthread ${Boost_SYSTEM_LIBRARY})
ADD_RAY_TEST(test/object_manager_stress_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main pthread ${Boost_SYSTEM_LIBRARY})
add_library(object_manager object_manager.cc object_manager.h ${OBJECT_MANAGER_FBS_OUTPUT_FILES})
target_link_libraries(object_manager common ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} ${Boost_SYSTEM_LIBRARY})
+113
View File
@@ -0,0 +1,113 @@
#include "ray/object_manager/connection_pool.h"
namespace ray {
ConnectionPool::ConnectionPool() {}
void ConnectionPool::RegisterReceiver(ConnectionType type, const ClientID &client_id,
std::shared_ptr<TcpClientConnection> &conn) {
std::unique_lock<std::mutex> guard(connection_mutex);
switch (type) {
case ConnectionType::MESSAGE: {
Add(message_receive_connections_, client_id, conn);
} break;
case ConnectionType::TRANSFER: {
Add(transfer_receive_connections_, client_id, conn);
} break;
}
}
void ConnectionPool::RemoveReceiver(ConnectionType type, const ClientID &client_id,
std::shared_ptr<TcpClientConnection> &conn) {
std::unique_lock<std::mutex> guard(connection_mutex);
switch (type) {
case ConnectionType::MESSAGE: {
Remove(message_receive_connections_, client_id, conn);
} break;
case ConnectionType::TRANSFER: {
Remove(transfer_receive_connections_, client_id, conn);
} break;
}
// TODO(hme): appropriately dispose of client connection.
}
void ConnectionPool::RegisterSender(ConnectionType type, const ClientID &client_id,
std::shared_ptr<SenderConnection> &conn) {
std::unique_lock<std::mutex> guard(connection_mutex);
SenderMapType &conn_map = (type == ConnectionType::MESSAGE)
? message_send_connections_
: transfer_send_connections_;
Add(conn_map, client_id, conn);
// Don't add to available connections. It will become available once it is released.
}
ray::Status ConnectionPool::GetSender(ConnectionType type, const ClientID &client_id,
std::shared_ptr<SenderConnection> *conn) {
std::unique_lock<std::mutex> guard(connection_mutex);
SenderMapType &avail_conn_map = (type == ConnectionType::MESSAGE)
? available_message_send_connections_
: available_transfer_send_connections_;
if (Count(avail_conn_map, client_id) > 0) {
*conn = Borrow(avail_conn_map, client_id);
} else {
*conn = nullptr;
}
return ray::Status::OK();
}
ray::Status ConnectionPool::ReleaseSender(ConnectionType type,
std::shared_ptr<SenderConnection> conn) {
std::unique_lock<std::mutex> guard(connection_mutex);
SenderMapType &conn_map = (type == ConnectionType::MESSAGE)
? available_message_send_connections_
: available_transfer_send_connections_;
Return(conn_map, conn->GetClientID(), conn);
return ray::Status::OK();
}
void ConnectionPool::Add(ReceiverMapType &conn_map, const ClientID &client_id,
std::shared_ptr<TcpClientConnection> conn) {
conn_map[client_id].push_back(conn);
}
void ConnectionPool::Add(SenderMapType &conn_map, const ClientID &client_id,
std::shared_ptr<SenderConnection> conn) {
conn_map[client_id].push_back(conn);
}
void ConnectionPool::Remove(ReceiverMapType &conn_map, const ClientID &client_id,
std::shared_ptr<TcpClientConnection> conn) {
if (conn_map.count(client_id) == 0) {
return;
}
std::vector<std::shared_ptr<TcpClientConnection>> &connections = conn_map[client_id];
int64_t pos =
std::find(connections.begin(), connections.end(), conn) - connections.begin();
if (pos >= (int64_t)connections.size()) {
return;
}
connections.erase(connections.begin() + pos);
}
uint64_t ConnectionPool::Count(SenderMapType &conn_map, const ClientID &client_id) {
if (conn_map.count(client_id) == 0) {
return 0;
};
return conn_map[client_id].size();
}
std::shared_ptr<SenderConnection> ConnectionPool::Borrow(SenderMapType &conn_map,
const ClientID &client_id) {
std::shared_ptr<SenderConnection> conn = conn_map[client_id].back();
conn_map[client_id].pop_back();
RAY_LOG(DEBUG) << "Borrow " << client_id << " " << conn_map[client_id].size();
return conn;
}
void ConnectionPool::Return(SenderMapType &conn_map, const ClientID &client_id,
std::shared_ptr<SenderConnection> conn) {
conn_map[client_id].push_back(conn);
RAY_LOG(DEBUG) << "Return " << client_id << " " << conn_map[client_id].size();
}
} // namespace ray
+143
View File
@@ -0,0 +1,143 @@
#ifndef RAY_OBJECT_MANAGER_CONNECTION_POOL_H
#define RAY_OBJECT_MANAGER_CONNECTION_POOL_H
#include <algorithm>
#include <cstdint>
#include <deque>
#include <map>
#include <memory>
#include <thread>
#include <boost/asio.hpp>
#include <boost/asio/error.hpp>
#include <boost/bind.hpp>
#include "ray/id.h"
#include "ray/status.h"
#include <mutex>
#include "ray/object_manager/format/object_manager_generated.h"
#include "ray/object_manager/object_directory.h"
#include "ray/object_manager/object_manager_client_connection.h"
namespace asio = boost::asio;
namespace ray {
class ConnectionPool {
public:
/// Callbacks for GetSender.
using SuccessCallback = std::function<void(std::shared_ptr<SenderConnection>)>;
using FailureCallback = std::function<void()>;
/// Connection type to distinguish between message and transfer connections.
enum class ConnectionType : int { MESSAGE = 0, TRANSFER };
/// Connection pool for all connections needed by the ObjectManager.
ConnectionPool();
/// Register a receiver connection.
///
/// \param type The type of connection.
/// \param client_id The ClientID of the remote object manager.
/// \param conn The actual connection.
void RegisterReceiver(ConnectionType type, const ClientID &client_id,
std::shared_ptr<TcpClientConnection> &conn);
/// Remove a receiver connection.
///
/// \param type The type of connection.
/// \param client_id The ClientID of the remote object manager.
/// \param conn The actual connection.
void RemoveReceiver(ConnectionType type, const ClientID &client_id,
std::shared_ptr<TcpClientConnection> &conn);
/// Register a receiver connection.
///
/// \param type The type of connection.
/// \param client_id The ClientID of the remote object manager.
/// \param conn The actual connection.
void RegisterSender(ConnectionType type, const ClientID &client_id,
std::shared_ptr<SenderConnection> &conn);
/// Get a sender connection from the connection pool.
/// The connection must be released or removed when the operation for which the
/// connection was obtained is completed. If the connection pool is empty, the
/// connection pointer passed in is set to a null pointer.
///
/// \param[in] type The type of connection.
/// \param[in] client_id The ClientID of the remote object manager.
/// \param[out] conn An empty pointer to a shared pointer.
/// \return Status of invoking this method.
ray::Status GetSender(ConnectionType type, const ClientID &client_id,
std::shared_ptr<SenderConnection> *conn);
/// Releases a sender connection, allowing it to be used by another operation.
///
/// \param type The type of connection.
/// \param conn The actual connection.
/// \return Status of invoking this method.
ray::Status ReleaseSender(ConnectionType type, std::shared_ptr<SenderConnection> conn);
// TODO(hme): Implement with error handling.
/// Remove a sender connection. This is invoked if the connection is no longer
/// usable.
///
/// \param type The type of connection.
/// \param conn The actual connection.
/// \return Status of invoking this method.
ray::Status RemoveSender(ConnectionType type, std::shared_ptr<SenderConnection> conn);
/// This object cannot be copied for thread-safety.
ConnectionPool &operator=(const ConnectionPool &o) {
throw std::runtime_error("Can't copy ConnectionPool.");
}
private:
/// A container type that maps ClientID to a connection type.
using SenderMapType =
std::unordered_map<ray::ClientID, std::vector<std::shared_ptr<SenderConnection>>,
ray::UniqueIDHasher>;
using ReceiverMapType =
std::unordered_map<ray::ClientID, std::vector<std::shared_ptr<TcpClientConnection>>,
ray::UniqueIDHasher>;
/// Adds a receiver for ClientID to the given map.
void Add(ReceiverMapType &conn_map, const ClientID &client_id,
std::shared_ptr<TcpClientConnection> conn);
/// Adds a sender for ClientID to the given map.
void Add(SenderMapType &conn_map, const ClientID &client_id,
std::shared_ptr<SenderConnection> conn);
/// Removes the given receiver for ClientID from the given map.
void Remove(ReceiverMapType &conn_map, const ClientID &client_id,
std::shared_ptr<TcpClientConnection> conn);
/// Returns the count of sender connections to ClientID.
uint64_t Count(SenderMapType &conn_map, const ClientID &client_id);
/// Removes a sender connection to ClientID from the pool of available connections.
/// This method assumes conn_map has available connections to ClientID.
std::shared_ptr<SenderConnection> Borrow(SenderMapType &conn_map,
const ClientID &client_id);
/// Returns a sender connection to ClientID to the pool of available connections.
void Return(SenderMapType &conn_map, const ClientID &client_id,
std::shared_ptr<SenderConnection> conn);
// TODO(hme): Optimize with separate mutex per collection.
std::mutex connection_mutex;
SenderMapType message_send_connections_;
SenderMapType transfer_send_connections_;
SenderMapType available_message_send_connections_;
SenderMapType available_transfer_send_connections_;
ReceiverMapType message_receive_connections_;
ReceiverMapType transfer_receive_connections_;
};
} // namespace ray
#endif // RAY_OBJECT_MANAGER_CONNECTION_POOL_H
@@ -1,30 +1,37 @@
// Object Manager protocol specification
namespace ray.object_manager.protocol;
enum OMMessageType:int {
PullRequest = 1
enum MessageType:int {
ConnectClient = 1,
DisconnectClient,
PushRequest,
PullRequest
}
table PushRequest {
table PushRequestMessage {
// The object ID being transferred.
object_id: string;
// The size of the object being transferred.
object_size: ulong;
}
table PullRequest {
table PullRequestMessage {
// ID of the requesting client.
client_id: string;
// Requested ObjectID.
object_id: string;
}
table ClientConnectionInfo {
table ConnectClientMessage {
// ID of the connecting client.
client_id: string;
// Whether this is a transfer connection.
is_transfer: bool;
}
table ObjectHeader {
// The object ID being transferred.
object_id: string;
// The size of the object being transferred.
object_size: ulong;
table DisconnectClientMessage {
// ID of the connecting client.
client_id: string;
// Whether this is a transfer connection.
is_transfer: bool;
}
+60 -56
View File
@@ -1,41 +1,53 @@
#include "object_directory.h"
#include "ray/object_manager/object_directory.h"
namespace ray {
ObjectDirectory::ObjectDirectory(std::shared_ptr<GcsClient> gcs_client) {
ObjectDirectory::ObjectDirectory(std::shared_ptr<gcs::AsyncGcsClient> gcs_client) {
gcs_client_ = gcs_client;
};
ray::Status ObjectDirectory::ReportObjectAdded(const ObjectID &object_id,
const ClientID &client_id) {
return gcs_client_->object_table().Add(object_id, client_id, [] {});
// TODO(hme): Determine whether we need to do lookup to append.
JobID job_id = JobID::from_random();
auto data = std::make_shared<ObjectTableDataT>();
data->manager = client_id.binary();
data->is_eviction = false;
ray::Status status = gcs_client_->object_table().Append(
job_id, object_id, data, [](gcs::AsyncGcsClient *client, const UniqueID &id,
const std::shared_ptr<ObjectTableDataT> data) {
// Do nothing.
});
return status;
};
ray::Status ObjectDirectory::ReportObjectRemoved(const ObjectID &object_id,
const ClientID &client_id) {
return gcs_client_->object_table().Remove(object_id, client_id, [] {});
// TODO(hme): Need corresponding remove method in GCS.
return ray::Status::NotImplemented("ObjectTable.Remove is not implemented");
};
ray::Status ObjectDirectory::GetInformation(const ClientID &client_id,
const InfoSuccessCallback &success_cb,
const InfoFailureCallback &fail_cb) {
gcs_client_->client_table().GetClientInformation(
client_id,
[this, success_cb, client_id](ClientInformation client_info) {
const auto &info =
RemoteConnectionInfo(client_id, client_info.GetIp(), client_info.GetPort());
success_cb(info);
},
fail_cb);
const InfoSuccessCallback &success_callback,
const InfoFailureCallback &fail_callback) {
const ClientTableDataT &data = gcs_client_->client_table().GetClient(client_id);
ClientID result_client_id = ClientID::from_binary(data.client_id);
if (result_client_id == ClientID::nil() || !data.is_insertion) {
fail_callback(ray::Status::RedisError("ClientID not found."));
} else {
const auto &info = RemoteConnectionInfo(client_id, data.node_manager_address,
(uint16_t)data.object_manager_port);
success_callback(info);
}
return ray::Status::OK();
};
ray::Status ObjectDirectory::GetLocations(const ObjectID &object_id,
const OnLocationsSuccess &success_cb,
const OnLocationsFailure &fail_cb) {
const OnLocationsSuccess &success_callback,
const OnLocationsFailure &fail_callback) {
ray::Status status_code = ray::Status::OK();
if (existing_requests_.count(object_id) == 0) {
existing_requests_[object_id] = ODCallbacks({success_cb, fail_cb});
existing_requests_[object_id] = ODCallbacks({success_callback, fail_callback});
status_code = ExecuteGetLocations(object_id);
} else {
// Do nothing. A request is in progress.
@@ -44,52 +56,44 @@ ray::Status ObjectDirectory::GetLocations(const ObjectID &object_id,
};
ray::Status ObjectDirectory::ExecuteGetLocations(const ObjectID &object_id) {
// TODO(hme): Avoid callback hell.
std::vector<RemoteConnectionInfo> remote_connections;
ray::Status status = gcs_client_->object_table().GetObjectClientIDs(
object_id,
[this, object_id, &remote_connections](const std::vector<ClientID> &client_ids) {
gcs_client_->client_table().GetClientInformationSet(
client_ids,
[this, object_id,
&remote_connections](const std::vector<ClientInformation> &info_vec) {
for (const auto &client_info : info_vec) {
RemoteConnectionInfo info =
RemoteConnectionInfo(client_info.GetClientId(), client_info.GetIp(),
client_info.GetPort());
remote_connections.push_back(info);
}
ray::Status cb_completion_status =
GetLocationsComplete(Status::OK(), object_id, remote_connections);
},
[this, object_id, &remote_connections](const Status &status) {
ray::Status cb_completion_status =
GetLocationsComplete(status, object_id, remote_connections);
});
},
[this, object_id, &remote_connections](const Status &status) {
ray::Status cb_completion_status =
GetLocationsComplete(status, object_id, remote_connections);
JobID job_id = JobID::from_random();
// Note: Lookup must be synchronous for thread-safe access.
// For now, this is only accessed by the main thread.
ray::Status status = gcs_client_->object_table().Lookup(
job_id, object_id,
[this, object_id](gcs::AsyncGcsClient *client, const ObjectID &object_id,
const std::vector<ObjectTableDataT> &data) {
GetLocationsComplete(object_id, data);
});
return status;
};
ray::Status ObjectDirectory::GetLocationsComplete(
const ray::Status &status, const ObjectID &object_id,
const std::vector<RemoteConnectionInfo> &remote_connections) {
bool success = status.ok();
// Only invoke a callback if the request was not cancelled.
if (existing_requests_.count(object_id) > 0) {
ODCallbacks cbs = existing_requests_[object_id];
if (success) {
cbs.success_cb(remote_connections, object_id);
void ObjectDirectory::GetLocationsComplete(
const ObjectID &object_id, const std::vector<ObjectTableDataT> &location_entries) {
auto request = existing_requests_.find(object_id);
// Do not invoke a callback if the request was cancelled.
if (request == existing_requests_.end()) {
return;
}
// Build the set of current locations based on the entries in the log.
std::unordered_set<ClientID, UniqueIDHasher> locations;
for (auto entry : location_entries) {
ClientID client_id = ClientID::from_binary(entry.manager);
if (!entry.is_eviction) {
locations.insert(client_id);
} else {
cbs.fail_cb(status, object_id);
locations.erase(client_id);
}
}
existing_requests_.erase(object_id);
return status;
};
// Invoke the callback.
std::vector<ClientID> locations_vector(locations.begin(), locations.end());
if (locations_vector.empty()) {
request->second.fail_cb(object_id);
} else {
request->second.success_cb(locations_vector, object_id);
}
existing_requests_.erase(request);
}
ray::Status ObjectDirectory::Cancel(const ObjectID &object_id) {
existing_requests_.erase(object_id);
+23 -18
View File
@@ -2,17 +2,19 @@
#define RAY_OBJECT_MANAGER_OBJECT_DIRECTORY_H
#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "ray/gcs/client.h"
#include "ray/id.h"
#include "ray/raylet/mock_gcs_client.h"
#include "ray/status.h"
namespace ray {
struct RemoteConnectionInfo {
RemoteConnectionInfo() = default;
RemoteConnectionInfo(const ClientID &id, const std::string &ip_address,
uint16_t port_num)
: client_id(id), ip(ip_address), port(port_num) {}
@@ -42,10 +44,9 @@ class ObjectDirectoryInterface {
const InfoFailureCallback &fail_cb) = 0;
// Callbacks for GetLocations.
using OnLocationsSuccess = std::function<void(
const std::vector<ray::RemoteConnectionInfo> &v, const ray::ObjectID &object_id)>;
using OnLocationsFailure =
std::function<void(ray::Status status, const ray::ObjectID &object_id)>;
using OnLocationsSuccess = std::function<void(const std::vector<ray::ClientID> &v,
const ray::ObjectID &object_id)>;
using OnLocationsFailure = std::function<void(const ray::ObjectID &object_id)>;
/// Asynchronously obtain the locations of an object by ObjectID.
/// This is used to handle object pulls.
@@ -93,11 +94,11 @@ class ObjectDirectory : public ObjectDirectoryInterface {
~ObjectDirectory() override = default;
ray::Status GetInformation(const ClientID &client_id,
const InfoSuccessCallback &success_cb,
const InfoFailureCallback &fail_cb) override;
const InfoSuccessCallback &success_callback,
const InfoFailureCallback &fail_callback) override;
ray::Status GetLocations(const ObjectID &object_id,
const OnLocationsSuccess &success_cb,
const OnLocationsFailure &fail_cb) override;
const OnLocationsSuccess &success_callback,
const OnLocationsFailure &fail_callback) override;
ray::Status Cancel(const ObjectID &object_id) override;
ray::Status Terminate() override;
ray::Status ReportObjectAdded(const ObjectID &object_id,
@@ -105,7 +106,12 @@ class ObjectDirectory : public ObjectDirectoryInterface {
ray::Status ReportObjectRemoved(const ObjectID &object_id,
const ClientID &client_id) override;
/// Ray only (not part of the OD interface).
ObjectDirectory(std::shared_ptr<GcsClient> gcs_client);
ObjectDirectory(std::shared_ptr<gcs::AsyncGcsClient> gcs_client);
/// This object cannot be copied for thread-safety.
ObjectDirectory &operator=(const ObjectDirectory &o) {
throw std::runtime_error("Can't copy ObjectDirectory.");
}
private:
/// Callbacks associated with a call to GetLocations.
@@ -115,18 +121,17 @@ class ObjectDirectory : public ObjectDirectoryInterface {
OnLocationsFailure fail_cb;
};
/// Maintain map of in-flight GetLocation requests.
std::unordered_map<ObjectID, ODCallbacks, UniqueIDHasher> existing_requests_;
/// Reference to the gcs client.
std::shared_ptr<GcsClient> gcs_client_;
/// GetLocations registers a request for locations.
/// This function actually carries out that request.
ray::Status ExecuteGetLocations(const ObjectID &object_id);
/// Invoked when call to ExecuteGetLocations completes.
ray::Status GetLocationsComplete(const ray::Status &status, const ObjectID &object_id,
const std::vector<RemoteConnectionInfo> &v);
void GetLocationsComplete(const ObjectID &object_id,
const std::vector<ObjectTableDataT> &location_entries);
/// Maintain map of in-flight GetLocation requests.
std::unordered_map<ObjectID, ODCallbacks, UniqueIDHasher> existing_requests_;
/// Reference to the gcs client.
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_;
};
} // namespace ray
+425 -344
View File
@@ -1,58 +1,80 @@
#include "object_manager.h"
#include "ray/object_manager/object_manager.h"
namespace asio = boost::asio;
namespace object_manager_protocol = ray::object_manager::protocol;
namespace ray {
ObjectManager::ObjectManager(boost::asio::io_service &io_service,
ObjectManagerConfig config,
std::shared_ptr<ray::GcsClient> gcs_client)
: object_directory_(new ObjectDirectory(gcs_client)), work_(io_service_) {
ObjectManager::ObjectManager(asio::io_service &main_service,
std::unique_ptr<asio::io_service> object_manager_service,
const ObjectManagerConfig &config,
std::shared_ptr<gcs::AsyncGcsClient> gcs_client)
// TODO(hme): Eliminate knowledge of GCS.
: client_id_(gcs_client->client_table().GetLocalClientId()),
object_directory_(new ObjectDirectory(gcs_client)),
store_notification_(main_service, config.store_socket_name),
store_pool_(config.store_socket_name),
object_manager_service_(std::move(object_manager_service)),
work_(*object_manager_service_),
connection_pool_(),
transfer_queue_(),
num_transfers_send_(0),
num_transfers_receive_(0) {
main_service_ = &main_service;
config_ = config;
store_client_ = std::unique_ptr<ObjectStoreClient>(
new ObjectStoreClient(io_service, config.store_socket_name));
store_client_->SubscribeObjAdded(
store_notification_.SubscribeObjAdded(
[this](const ObjectID &oid) { NotifyDirectoryObjectAdd(oid); });
store_client_->SubscribeObjDeleted(
store_notification_.SubscribeObjDeleted(
[this](const ObjectID &oid) { NotifyDirectoryObjectDeleted(oid); });
StartIOService();
};
}
ObjectManager::ObjectManager(boost::asio::io_service &io_service,
ObjectManagerConfig config,
ObjectManager::ObjectManager(asio::io_service &main_service,
std::unique_ptr<asio::io_service> object_manager_service,
const ObjectManagerConfig &config,
std::unique_ptr<ObjectDirectoryInterface> od)
: object_directory_(std::move(od)), work_(io_service_) {
: object_directory_(std::move(od)),
store_notification_(main_service, config.store_socket_name),
store_pool_(config.store_socket_name),
object_manager_service_(std::move(object_manager_service)),
work_(*object_manager_service_),
connection_pool_(),
transfer_queue_(),
num_transfers_send_(0),
num_transfers_receive_(0) {
// TODO(hme) Client ID is never set with this constructor.
main_service_ = &main_service;
config_ = config;
store_client_ = std::unique_ptr<ObjectStoreClient>(
new ObjectStoreClient(io_service, config.store_socket_name));
store_client_->SubscribeObjAdded(
store_notification_.SubscribeObjAdded(
[this](const ObjectID &oid) { NotifyDirectoryObjectAdd(oid); });
store_client_->SubscribeObjDeleted(
store_notification_.SubscribeObjDeleted(
[this](const ObjectID &oid) { NotifyDirectoryObjectDeleted(oid); });
StartIOService();
};
}
void ObjectManager::StartIOService() {
io_thread_ = std::thread(&ObjectManager::IOServiceLoop, this);
// thread_group_.create_thread(boost::bind(&boost::asio::io_service::run,
// &io_service_));
for (int i = 0; i < config_.num_threads; ++i) {
io_threads_.emplace_back(std::thread(&ObjectManager::IOServiceLoop, this));
}
}
void ObjectManager::IOServiceLoop() { io_service_.run(); }
void ObjectManager::IOServiceLoop() { object_manager_service_->run(); }
void ObjectManager::StopIOService() {
io_service_.stop();
io_thread_.join();
// thread_group_.join_all();
object_manager_service_->stop();
for (int i = 0; i < config_.num_threads; ++i) {
io_threads_[i].join();
}
}
void ObjectManager::SetClientID(const ClientID &client_id) { client_id_ = client_id; }
ClientID ObjectManager::GetClientID() { return client_id_; }
void ObjectManager::NotifyDirectoryObjectAdd(const ObjectID &object_id) {
local_objects_.insert(object_id);
ray::Status status = object_directory_->ReportObjectAdded(object_id, client_id_);
}
void ObjectManager::NotifyDirectoryObjectDeleted(const ObjectID &object_id) {
local_objects_.erase(object_id);
ray::Status status = object_directory_->ReportObjectRemoved(object_id, client_id_);
}
@@ -60,391 +82,450 @@ ray::Status ObjectManager::Terminate() {
StopIOService();
ray::Status status_code = object_directory_->Terminate();
// TODO: evaluate store client termination status.
store_client_->Terminate();
store_notification_.Terminate();
store_pool_.Terminate();
return status_code;
};
}
ray::Status ObjectManager::SubscribeObjAdded(
std::function<void(const ObjectID &)> callback) {
store_client_->SubscribeObjAdded(callback);
store_notification_.SubscribeObjAdded(callback);
return ray::Status::OK();
};
}
ray::Status ObjectManager::SubscribeObjDeleted(
std::function<void(const ObjectID &)> callback) {
store_client_->SubscribeObjDeleted(callback);
store_notification_.SubscribeObjDeleted(callback);
return ray::Status::OK();
};
}
ray::Status ObjectManager::Pull(const ObjectID &object_id) {
// TODO(hme): Need to correct. Workaround to get all pull requests on the same thread.
SchedulePull(object_id, 0);
main_service_->dispatch(
[this, object_id]() { RAY_CHECK_OK(PullGetLocations(object_id)); });
return Status::OK();
};
}
void ObjectManager::SchedulePull(const ObjectID &object_id, int wait_ms) {
pull_requests_[object_id] = Timer(new boost::asio::deadline_timer(
io_service_, boost::posix_time::milliseconds(wait_ms)));
pull_requests_[object_id] = std::shared_ptr<boost::asio::deadline_timer>(
new asio::deadline_timer(*main_service_, boost::posix_time::milliseconds(wait_ms)));
pull_requests_[object_id]->async_wait(
[this, object_id](const boost::system::error_code &error_code) {
RAY_CHECK_OK(SchedulePullHandler(object_id));
pull_requests_.erase(object_id);
main_service_->dispatch(
[this, object_id]() { RAY_CHECK_OK(PullGetLocations(object_id)); });
});
}
ray::Status ObjectManager::SchedulePullHandler(const ObjectID &object_id) {
pull_requests_.erase(object_id);
ray::Status ObjectManager::PullGetLocations(const ObjectID &object_id) {
ray::Status status_code = object_directory_->GetLocations(
object_id,
[this](const std::vector<RemoteConnectionInfo> &vec, const ObjectID &object_id) {
return GetLocationsSuccess(vec, object_id);
[this](const std::vector<ClientID> &client_ids, const ObjectID &object_id) {
return GetLocationsSuccess(client_ids, object_id);
},
[this](ray::Status status, const ObjectID &object_id) {
return GetLocationsFailed(status, object_id);
});
[this](const ObjectID &object_id) { return GetLocationsFailed(object_id); });
return status_code;
}
void ObjectManager::GetLocationsSuccess(const std::vector<ray::RemoteConnectionInfo> &vec,
void ObjectManager::GetLocationsSuccess(const std::vector<ray::ClientID> &client_ids,
const ray::ObjectID &object_id) {
RemoteConnectionInfo info = vec.front();
RAY_CHECK(!client_ids.empty());
ClientID client_id = client_ids.front();
pull_requests_.erase(object_id);
ray::Status status_code = Pull(object_id, info.client_id);
};
ray::Status status_code = Pull(object_id, client_id);
}
void ObjectManager::GetLocationsFailed(ray::Status status, const ObjectID &object_id) {
void ObjectManager::GetLocationsFailed(const ObjectID &object_id) {
SchedulePull(object_id, config_.pull_timeout_ms);
};
}
ray::Status ObjectManager::Pull(const ObjectID &object_id, const ClientID &client_id) {
Status status =
GetMsgConnection(client_id, [this, object_id](SenderConnection::pointer client) {
Status status = ExecutePull(object_id, client);
});
return status;
main_service_->dispatch([this, object_id, client_id]() {
RAY_CHECK_OK(PullEstablishConnection(object_id, client_id));
});
return Status::OK();
};
ray::Status ObjectManager::ExecutePull(const ObjectID &object_id,
SenderConnection::pointer conn) {
size_t message_type = OMMessageType_PullRequest;
boost::system::error_code error_code;
boost::asio::write(conn->GetSocket(),
boost::asio::buffer(&message_type, sizeof(message_type)),
error_code);
ray::Status ObjectManager::PullEstablishConnection(const ObjectID &object_id,
const ClientID &client_id) {
// Check if object is already local, and client_id is not itself.
if (local_objects_.count(object_id) != 0 || client_id == client_id_) {
return ray::Status::OK();
}
// Acquire a message connection and send pull request.
ray::Status status;
std::shared_ptr<SenderConnection> conn;
// TODO(hme): There is no cap on the number of pull request connections.
status = connection_pool_.GetSender(ConnectionPool::ConnectionType::MESSAGE, client_id,
&conn);
if (!status.ok()) {
// TODO(hme): Keep track of retries,
// and only retry on object not local
// for now.
SchedulePull(object_id, config_.pull_timeout_ms);
return status;
}
if (conn == nullptr) {
status = object_directory_->GetInformation(
client_id,
[this, object_id, client_id](const RemoteConnectionInfo &connection_info) {
std::shared_ptr<SenderConnection> async_conn = CreateSenderConnection(
ConnectionPool::ConnectionType::MESSAGE, connection_info);
connection_pool_.RegisterSender(ConnectionPool::ConnectionType::MESSAGE,
client_id, async_conn);
RAY_CHECK_OK(PullSendRequest(object_id, async_conn));
},
[this, object_id](const Status &status) {
SchedulePull(object_id, config_.pull_timeout_ms);
});
} else {
RAY_CHECK_OK(PullSendRequest(object_id, conn));
}
return status;
}
ray::Status ObjectManager::PullSendRequest(const ObjectID &object_id,
std::shared_ptr<SenderConnection> conn) {
flatbuffers::FlatBufferBuilder fbb;
auto message = CreatePullRequest(fbb, fbb.CreateString(client_id_.binary()),
fbb.CreateString(object_id.binary()));
auto message = object_manager_protocol::CreatePullRequestMessage(
fbb, fbb.CreateString(client_id_.binary()), fbb.CreateString(object_id.binary()));
fbb.Finish(message);
size_t length = fbb.GetSize();
std::vector<boost::asio::const_buffer> buffer;
buffer.push_back(boost::asio::buffer(&length, sizeof(length)));
buffer.push_back(boost::asio::buffer(fbb.GetBufferPointer(), length));
boost::asio::write(conn->GetSocket(), buffer);
RAY_CHECK_OK(conn->WriteMessage(object_manager_protocol::MessageType_PullRequest,
fbb.GetSize(), fbb.GetBufferPointer()));
RAY_CHECK_OK(
connection_pool_.ReleaseSender(ConnectionPool::ConnectionType::MESSAGE, conn));
return ray::Status::OK();
};
}
ray::Status ObjectManager::Push(const ObjectID &object_id, const ClientID &client_id) {
ray::Status status;
status =
GetTransferConnection(client_id, [this, object_id](SenderConnection::pointer conn) {
ray::Status status = QueuePush(object_id, conn);
});
// TODO(hme): Cache this data in ObjectDirectory.
// Okay for now since the GCS client caches this data.
main_service_->dispatch([this, object_id, client_id]() {
Status status = object_directory_->GetInformation(
client_id,
[this, object_id, client_id](const RemoteConnectionInfo &info) {
transfer_queue_.QueueSend(client_id, object_id, info);
RAY_CHECK_OK(DequeueTransfers());
},
[this](const Status &status) {
// Push is best effort, so do nothing here.
});
RAY_CHECK_OK(status);
});
return ray::Status::OK();
}
ray::Status ObjectManager::DequeueTransfers() {
ray::Status status = ray::Status::OK();
// Dequeue sends.
while (true) {
if (std::atomic_fetch_add(&num_transfers_send_, 1) <= config_.max_sends) {
TransferQueue::SendRequest req;
bool exists = transfer_queue_.DequeueSendIfPresent(&req);
if (exists) {
object_manager_service_->dispatch([this, req]() {
RAY_LOG(DEBUG) << "DequeueSend " << client_id_ << " " << req.object_id << " "
<< num_transfers_send_ << "/" << config_.max_sends;
RAY_CHECK_OK(
ExecuteSendObject(req.object_id, req.client_id, req.connection_info));
});
} else {
std::atomic_fetch_sub(&num_transfers_send_, 1);
break;
}
} else {
std::atomic_fetch_sub(&num_transfers_send_, 1);
break;
}
}
// Dequeue receives.
while (true) {
if (std::atomic_fetch_add(&num_transfers_receive_, 1) <= config_.max_receives) {
TransferQueue::ReceiveRequest req;
bool exists = transfer_queue_.DequeueReceiveIfPresent(&req);
if (exists) {
object_manager_service_->dispatch([this, req]() {
RAY_LOG(DEBUG) << "DequeueReceive " << client_id_ << " " << req.object_id << " "
<< num_transfers_receive_ << "/" << config_.max_receives;
RAY_CHECK_OK(ExecuteReceiveObject(req.client_id, req.object_id, req.object_size,
req.conn));
});
} else {
std::atomic_fetch_sub(&num_transfers_receive_, 1);
break;
}
} else {
std::atomic_fetch_sub(&num_transfers_receive_, 1);
break;
}
}
return status;
}
ray::Status ObjectManager::TransferCompleted(TransferQueue::TransferType type) {
if (type == TransferQueue::TransferType::SEND) {
std::atomic_fetch_sub(&num_transfers_send_, 1);
} else {
std::atomic_fetch_sub(&num_transfers_receive_, 1);
}
return DequeueTransfers();
};
ray::Status ObjectManager::ExecuteSendObject(
const ObjectID &object_id, const ClientID &client_id,
const RemoteConnectionInfo &connection_info) {
ray::Status status;
std::shared_ptr<SenderConnection> conn;
status = connection_pool_.GetSender(ConnectionPool::ConnectionType::TRANSFER, client_id,
&conn);
if (!status.ok()) {
// TODO(hme): Keep track of retries,
// and only retry on object not local
// for now.
RAY_CHECK_OK(Push(object_id, conn->GetClientID()));
return Status::OK();
}
if (conn == nullptr) {
conn =
CreateSenderConnection(ConnectionPool::ConnectionType::TRANSFER, connection_info);
connection_pool_.RegisterSender(ConnectionPool::ConnectionType::TRANSFER, client_id,
conn);
}
status = SendObjectHeaders(object_id, conn);
if (!status.ok()) {
RAY_CHECK_OK(Push(object_id, conn->GetClientID()));
return Status::OK();
}
return Status::OK();
}
ray::Status ObjectManager::SendObjectHeaders(const ObjectID &object_id_const,
std::shared_ptr<SenderConnection> conn) {
ObjectID object_id = ObjectID(object_id_const);
// Allocate and append the request to the transfer queue.
plasma::ObjectBuffer object_buffer;
plasma::ObjectID plasma_id = object_id.to_plasma_id();
std::shared_ptr<plasma::PlasmaClient> store_client = store_pool_.GetObjectStore();
ARROW_CHECK_OK(store_client->Get(&plasma_id, 1, 0, &object_buffer));
if (object_buffer.data_size == -1) {
RAY_LOG(ERROR) << "Failed to get object";
// If the object wasn't locally available, exit immediately. If the object
// later appears locally, the requesting plasma manager should request the
// transfer again.
RAY_CHECK_OK(
connection_pool_.ReleaseSender(ConnectionPool::ConnectionType::TRANSFER, conn));
return ray::Status::IOError(
"Unable to transfer object to requesting plasma manager, object not local.");
}
RAY_CHECK(object_buffer.metadata->data() ==
object_buffer.data->data() + object_buffer.data_size);
TransferQueue::SendContext context;
context.client_id = conn->GetClientID();
context.object_id = object_id;
context.object_size = static_cast<uint64_t>(object_buffer.data_size);
context.data = const_cast<uint8_t *>(object_buffer.data->data());
UniqueID context_id = transfer_queue_.AddContext(context);
// Create buffer.
flatbuffers::FlatBufferBuilder fbb;
// TODO(hme): use to_flatbuf
auto message = object_manager_protocol::CreatePushRequestMessage(
fbb, fbb.CreateString(object_id.binary()), context.object_size);
fbb.Finish(message);
ray::Status status =
conn->WriteMessage(object_manager_protocol::MessageType_PushRequest, fbb.GetSize(),
fbb.GetBufferPointer());
if (!status.ok()) {
// push failed.
// TODO(hme): Trash sender.
RAY_CHECK_OK(
connection_pool_.ReleaseSender(ConnectionPool::ConnectionType::TRANSFER, conn));
return status;
}
// TODO(hme): Make this async.
return SendObjectData(conn, context_id, store_client);
}
ray::Status ObjectManager::SendObjectData(
std::shared_ptr<SenderConnection> conn, const UniqueID &context_id,
std::shared_ptr<plasma::PlasmaClient> store_client) {
TransferQueue::SendContext context = transfer_queue_.GetContext(context_id);
boost::system::error_code ec;
std::vector<asio::const_buffer> buffer;
buffer.push_back(asio::buffer(context.data, context.object_size));
conn->WriteBuffer(buffer, ec);
ray::Status status = ray::Status::OK();
if (ec.value() != 0) {
// push failed.
// TODO(hme): Trash sender.
status = ray::Status::IOError(ec.message());
}
// Do this regardless of whether it failed or succeeded.
ARROW_CHECK_OK(store_client->Release(context.object_id.to_plasma_id()));
store_pool_.ReleaseObjectStore(store_client);
RAY_CHECK_OK(
connection_pool_.ReleaseSender(ConnectionPool::ConnectionType::TRANSFER, conn));
RAY_CHECK_OK(transfer_queue_.RemoveContext(context_id));
RAY_LOG(DEBUG) << "SendCompleted " << client_id_ << " " << context.object_id << " "
<< num_transfers_send_ << "/" << config_.max_sends;
RAY_CHECK_OK(TransferCompleted(TransferQueue::TransferType::SEND));
return status;
}
ray::Status ObjectManager::Cancel(const ObjectID &object_id) {
// TODO(hme): Account for pull timers.
ray::Status status = object_directory_->Cancel(object_id);
return ray::Status::OK();
};
}
ray::Status ObjectManager::Wait(const std::vector<ObjectID> &object_ids,
uint64_t timeout_ms, int num_ready_objects,
const WaitCallback &callback) {
// TODO: Implement wait.
return ray::Status::OK();
};
ray::Status ObjectManager::GetMsgConnection(
const ClientID &client_id, std::function<void(SenderConnection::pointer)> callback) {
ray::Status status = Status::OK();
if (message_send_connections_.count(client_id) > 0) {
callback(message_send_connections_[client_id]);
} else {
status = object_directory_->GetInformation(
client_id,
[this, callback](RemoteConnectionInfo info) {
Status status = CreateMsgConnection(info, callback);
},
[this](const Status &status) {
// TODO: deal with failure.
});
}
return status;
};
ray::Status ObjectManager::CreateMsgConnection(
const RemoteConnectionInfo &info,
std::function<void(SenderConnection::pointer)> callback) {
message_send_connections_.emplace(
info.client_id, SenderConnection::Create(io_service_, info.ip, info.port));
// Prepare client connection info buffer.
flatbuffers::FlatBufferBuilder fbb;
bool is_transfer = false;
auto message =
CreateClientConnectionInfo(fbb, fbb.CreateString(client_id_.binary()), is_transfer);
fbb.Finish(message);
// Pack into asio buffer.
size_t length = fbb.GetSize();
std::vector<boost::asio::const_buffer> buffer;
buffer.push_back(boost::asio::buffer(&length, sizeof(length)));
buffer.push_back(boost::asio::buffer(fbb.GetBufferPointer(), length));
// Send synchronously.
SenderConnection::pointer conn = message_send_connections_[info.client_id];
boost::system::error_code error;
boost::asio::write(conn->GetSocket(), buffer);
// The connection is ready, invoke callback with connection info.
callback(message_send_connections_[info.client_id]);
return Status::OK();
};
ray::Status ObjectManager::GetTransferConnection(
const ClientID &client_id, std::function<void(SenderConnection::pointer)> callback) {
ray::Status status = Status::OK();
if (transfer_send_connections_.count(client_id) > 0) {
callback(transfer_send_connections_[client_id]);
} else {
status = object_directory_->GetInformation(
client_id,
[this, callback](RemoteConnectionInfo info) {
Status status = CreateTransferConnection(info, callback);
},
[this](const Status &status) {
// TODO(hme): deal with failure.
});
}
return status;
};
ray::Status ObjectManager::CreateTransferConnection(
const RemoteConnectionInfo &info,
std::function<void(SenderConnection::pointer)> callback) {
transfer_send_connections_.emplace(
info.client_id, SenderConnection::Create(io_service_, info.ip, info.port));
// Prepare client connection info buffer.
flatbuffers::FlatBufferBuilder fbb;
bool is_transfer = true;
auto message =
CreateClientConnectionInfo(fbb, fbb.CreateString(client_id_.binary()), is_transfer);
fbb.Finish(message);
// Pack into asio buffer.
size_t length = fbb.GetSize();
std::vector<boost::asio::const_buffer> buffer;
buffer.push_back(boost::asio::buffer(&length, sizeof(length)));
buffer.push_back(boost::asio::buffer(fbb.GetBufferPointer(), length));
// Send synchronously.
SenderConnection::pointer conn = transfer_send_connections_[info.client_id];
boost::system::error_code ec;
boost::asio::write(conn->GetSocket(), buffer, ec);
callback(transfer_send_connections_[info.client_id]);
return Status::OK();
};
ray::Status ObjectManager::AcceptConnection(TCPClientConnection::pointer conn) {
boost::system::error_code ec;
// read header
size_t length;
std::vector<boost::asio::mutable_buffer> header;
header.push_back(boost::asio::buffer(&length, sizeof(length)));
boost::asio::read(conn->GetSocket(), header, ec);
// read data
std::vector<uint8_t> message;
message.resize(length);
boost::asio::read(conn->GetSocket(), boost::asio::buffer(message), ec);
// Serialize
auto info = flatbuffers::GetRoot<ClientConnectionInfo>(message.data());
ClientID client_id = ObjectID::from_binary(info->client_id()->str());
bool is_transfer = info->is_transfer();
// TODO: trash connection if either fails.
if (is_transfer) {
transfer_receive_connections_[client_id] = conn;
Status status = WaitPushReceive(conn);
return status;
} else {
message_receive_connections_[client_id] = conn;
Status status = WaitMessage(conn);
return status;
}
};
ray::Status ObjectManager::WaitPushReceive(TCPClientConnection::pointer conn) {
boost::asio::async_read(
conn->GetSocket(),
boost::asio::buffer(&conn->message_length_, sizeof(conn->message_length_)),
boost::bind(&ObjectManager::HandlePushReceive, this, conn,
boost::asio::placeholders::error));
return ray::Status::OK();
}
void ObjectManager::HandlePushReceive(TCPClientConnection::pointer conn,
BoostEC length_ec) {
std::vector<uint8_t> message;
message.resize(conn->message_length_);
boost::system::error_code ec;
boost::asio::read(conn->GetSocket(), boost::asio::buffer(message), ec);
std::shared_ptr<SenderConnection> ObjectManager::CreateSenderConnection(
ConnectionPool::ConnectionType type, RemoteConnectionInfo info) {
std::shared_ptr<SenderConnection> conn = SenderConnection::Create(
*object_manager_service_, info.client_id, info.ip, info.port);
// Prepare client connection info buffer
flatbuffers::FlatBufferBuilder fbb;
bool is_transfer = (type == ConnectionPool::ConnectionType::TRANSFER);
auto message = object_manager_protocol::CreateConnectClientMessage(
fbb, fbb.CreateString(client_id_.binary()), is_transfer);
fbb.Finish(message);
// Send synchronously.
RAY_CHECK_OK(conn->WriteMessage(object_manager_protocol::MessageType_ConnectClient,
fbb.GetSize(), fbb.GetBufferPointer()));
// The connection is ready; return to caller.
return conn;
}
void ObjectManager::ProcessNewClient(std::shared_ptr<TcpClientConnection> conn) {
conn->ProcessMessages();
}
void ObjectManager::ProcessClientMessage(std::shared_ptr<TcpClientConnection> conn,
int64_t message_type, const uint8_t *message) {
switch (message_type) {
case object_manager_protocol::MessageType_PushRequest: {
ReceivePushRequest(conn, message);
break;
}
case object_manager_protocol::MessageType_PullRequest: {
ReceivePullRequest(conn, message);
break;
}
case object_manager_protocol::MessageType_ConnectClient: {
ConnectClient(conn, message);
break;
}
case object_manager_protocol::MessageType_DisconnectClient: {
DisconnectClient(conn, message);
break;
}
default: { RAY_LOG(FATAL) << "invalid request " << message_type; }
}
}
void ObjectManager::ConnectClient(std::shared_ptr<TcpClientConnection> &conn,
const uint8_t *message) {
// TODO: trash connection on failure.
auto info =
flatbuffers::GetRoot<object_manager_protocol::ConnectClientMessage>(message);
ClientID client_id = ObjectID::from_binary(info->client_id()->str());
bool is_transfer = info->is_transfer();
conn->SetClientID(client_id);
if (is_transfer) {
connection_pool_.RegisterReceiver(ConnectionPool::ConnectionType::TRANSFER, client_id,
conn);
} else {
connection_pool_.RegisterReceiver(ConnectionPool::ConnectionType::MESSAGE, client_id,
conn);
}
conn->ProcessMessages();
}
void ObjectManager::DisconnectClient(std::shared_ptr<TcpClientConnection> &conn,
const uint8_t *message) {
auto info =
flatbuffers::GetRoot<object_manager_protocol::DisconnectClientMessage>(message);
ClientID client_id = ObjectID::from_binary(info->client_id()->str());
bool is_transfer = info->is_transfer();
if (is_transfer) {
connection_pool_.RemoveReceiver(ConnectionPool::ConnectionType::TRANSFER, client_id,
conn);
} else {
connection_pool_.RemoveReceiver(ConnectionPool::ConnectionType::MESSAGE, client_id,
conn);
}
}
void ObjectManager::ReceivePullRequest(std::shared_ptr<TcpClientConnection> &conn,
const uint8_t *message) {
// Serialize and push object to requesting client.
auto pr = flatbuffers::GetRoot<object_manager_protocol::PullRequestMessage>(message);
ObjectID object_id = ObjectID::from_binary(pr->object_id()->str());
ClientID client_id = ClientID::from_binary(pr->client_id()->str());
ray::Status push_status = Push(object_id, client_id);
conn->ProcessMessages();
}
void ObjectManager::ReceivePushRequest(std::shared_ptr<TcpClientConnection> conn,
const uint8_t *message) {
// Serialize.
auto object_header = flatbuffers::GetRoot<ObjectHeader>(message.data());
auto object_header =
flatbuffers::GetRoot<object_manager_protocol::PushRequestMessage>(message);
ObjectID object_id = ObjectID::from_binary(object_header->object_id()->str());
int64_t object_size = (int64_t)object_header->object_size();
transfer_queue_.QueueReceive(conn->GetClientID(), object_id, object_size, conn);
RAY_CHECK_OK(DequeueTransfers());
}
ray::Status ObjectManager::ExecuteReceiveObject(
const ClientID &client_id, const ObjectID &object_id, uint64_t object_size,
std::shared_ptr<TcpClientConnection> conn) {
boost::system::error_code ec;
int64_t metadata_size = 0;
const plasma::ObjectID plasma_id = ObjectID(object_id).to_plasma_id();
// Try to create shared buffer.
std::shared_ptr<Buffer> data;
arrow::Status s = store_client_->GetClient().Create(
object_id.to_plasma_id(), object_size, NULL, metadata_size, &data);
std::shared_ptr<plasma::PlasmaClient> store_client = store_pool_.GetObjectStore();
arrow::Status s =
store_client->Create(plasma_id, object_size, NULL, metadata_size, &data);
std::vector<boost::asio::mutable_buffer> buffer;
if (s.ok()) {
// Read object into store.
uint8_t *mutable_data = data->mutable_data();
boost::asio::read(conn->GetSocket(), boost::asio::buffer(mutable_data, object_size),
ec);
buffer.push_back(asio::buffer(mutable_data, object_size));
conn->ReadBuffer(buffer, ec);
if (!ec.value()) {
ARROW_CHECK_OK(store_client_->GetClient().Seal(object_id.to_plasma_id()));
ARROW_CHECK_OK(store_client_->GetClient().Release(object_id.to_plasma_id()));
ARROW_CHECK_OK(store_client->Seal(plasma_id));
ARROW_CHECK_OK(store_client->Release(plasma_id));
} else {
ARROW_CHECK_OK(store_client_->GetClient().Release(object_id.to_plasma_id()));
ARROW_CHECK_OK(store_client_->GetClient().Abort(object_id.to_plasma_id()));
ARROW_CHECK_OK(store_client->Release(plasma_id));
ARROW_CHECK_OK(store_client->Abort(plasma_id));
RAY_LOG(ERROR) << "Receive Failed";
}
} else {
RAY_LOG(ERROR) << "Buffer Create Failed: " << s.message();
// Read object into empty buffer.
uint8_t *mutable_data = (uint8_t *)malloc(object_size + metadata_size);
boost::asio::read(conn->GetSocket(), boost::asio::buffer(mutable_data, object_size),
ec);
std::vector<uint8_t> mutable_data;
mutable_data.resize(object_size + metadata_size);
buffer.push_back(asio::buffer(mutable_data, object_size + metadata_size));
conn->ReadBuffer(buffer, ec);
}
// Wait for another push.
ray::Status ray_status = WaitPushReceive(conn);
};
ray::Status ObjectManager::QueuePush(const ObjectID &object_id_const,
SenderConnection::pointer conn) {
ObjectID object_id = ObjectID(object_id_const);
if (conn->ObjectIdQueued(object_id)) {
// For now, return with status OK if the object is already in the send queue.
return ray::Status::OK();
}
conn->QueueObjectId(object_id);
if (num_transfers_ < max_transfers_) {
return ExecutePushQueue(conn);
}
return ray::Status::OK();
};
ray::Status ObjectManager::ExecutePushQueue(SenderConnection::pointer conn) {
ray::Status status = ray::Status::OK();
while (num_transfers_ < max_transfers_) {
if (conn->IsObjectIdQueueEmpty()) {
return ray::Status::OK();
}
ObjectID object_id = conn->DequeueObjectId();
// The threads that increment/decrement num_transfers_ are different.
// It's important to increment num_transfers_ before executing the push.
num_transfers_ += 1;
status = ExecutePushHeaders(object_id, conn);
}
return status;
};
ray::Status ObjectManager::ExecutePushHeaders(const ObjectID &object_id_const,
SenderConnection::pointer conn) {
ObjectID object_id = ObjectID(object_id_const);
// Allocate and append the request to the transfer queue.
plasma::ObjectBuffer object_buffer;
plasma::ObjectID plasma_id = object_id.to_plasma_id();
ARROW_CHECK_OK(store_client_->GetClientOther().Get(&plasma_id, 1, 0, &object_buffer));
if (object_buffer.data_size == -1) {
RAY_LOG(ERROR) << "Failed to get object";
// If the object wasn't locally available, exit immediately. If the object
// later appears locally, the requesting plasma manager should request the
// transfer again.
return ray::Status::IOError(
"Unable to transfer object to requesting plasma manager, object not local.");
}
RAY_CHECK(object_buffer.metadata->data() ==
object_buffer.data->data() + object_buffer.data_size);
SendRequest send_request;
send_request.object_id = object_id;
send_request.object_size = object_buffer.data_size;
send_request.data = const_cast<uint8_t *>(object_buffer.data->data());
conn->AddSendRequest(object_id, send_request);
// Create buffer.
flatbuffers::FlatBufferBuilder fbb;
auto message = CreateObjectHeader(fbb, fbb.CreateString(object_id.binary()),
send_request.object_size);
fbb.Finish(message);
// Pack into asio buffer.
size_t length = fbb.GetSize();
std::vector<boost::asio::const_buffer> buffer;
buffer.push_back(boost::asio::buffer(&length, sizeof(length)));
buffer.push_back(boost::asio::buffer(fbb.GetBufferPointer(), length));
// Send asynchronously.
boost::asio::async_write(conn->GetSocket(), buffer,
boost::bind(&ObjectManager::ExecutePushObject, this, conn,
object_id, boost::asio::placeholders::error));
return ray::Status::OK();
};
void ObjectManager::ExecutePushObject(SenderConnection::pointer conn,
const ObjectID &object_id,
const boost::system::error_code &header_ec) {
SendRequest &send_request = conn->GetSendRequest(object_id);
boost::system::error_code ec;
boost::asio::write(
conn->GetSocket(),
boost::asio::buffer(send_request.data, (size_t)send_request.object_size), ec);
// Do this regardless of whether it failed or succeeded.
ARROW_CHECK_OK(
store_client_->GetClientOther().Release(send_request.object_id.to_plasma_id()));
ray::Status ray_status = ExecutePushCompleted(object_id, conn);
}
ray::Status ObjectManager::ExecutePushCompleted(const ObjectID &object_id,
SenderConnection::pointer conn) {
conn->RemoveSendRequest(object_id);
num_transfers_ -= 1;
return ExecutePushQueue(conn);
};
ray::Status ObjectManager::WaitMessage(TCPClientConnection::pointer conn) {
boost::asio::async_read(
conn->GetSocket(),
boost::asio::buffer(&conn->message_type_, sizeof(conn->message_type_)),
boost::bind(&ObjectManager::HandleMessage, this, conn,
boost::asio::placeholders::error));
return ray::Status::OK();
}
void ObjectManager::HandleMessage(TCPClientConnection::pointer conn, BoostEC msg_ec) {
switch (conn->message_type_) {
case OMMessageType_PullRequest:
ReceivePullRequest(conn);
}
}
void ObjectManager::ReceivePullRequest(TCPClientConnection::pointer conn) {
boost::asio::read(
conn->GetSocket(),
boost::asio::buffer(&conn->message_length_, sizeof(conn->message_length_)));
std::vector<uint8_t> message;
message.resize(conn->message_length_);
boost::system::error_code error_code;
boost::asio::read(conn->GetSocket(), boost::asio::buffer(message), error_code);
// Serialize.
auto pull_request = flatbuffers::GetRoot<PullRequest>(message.data());
ObjectID object_id = ObjectID::from_binary(pull_request->object_id()->str());
ClientID client_id = ClientID::from_binary(pull_request->client_id()->str());
// Push object to requesting client.
ray::Status push_status = Push(object_id, client_id);
ray::Status wait_status = WaitMessage(conn);
store_pool_.ReleaseObjectStore(store_client);
conn->ProcessMessages();
RAY_LOG(DEBUG) << "ReceiveCompleted " << client_id_ << " " << object_id << " "
<< num_transfers_receive_ << "/" << config_.max_receives;
RAY_CHECK_OK(TransferCompleted(TransferQueue::TransferType::RECEIVE));
return Status::OK();
}
} // namespace ray
+139 -117
View File
@@ -12,58 +12,65 @@
#include <boost/asio/error.hpp>
#include <boost/bind.hpp>
#include "ray/common/client_connection.h"
#include "ray/id.h"
#include "ray/status.h"
#include "plasma/client.h"
#include "plasma/events.h"
#include "plasma/plasma.h"
#include "format/object_manager_generated.h"
#include "object_directory.h"
#include "object_manager_client_connection.h"
#include "object_store_client.h"
#include "ray/id.h"
#include "ray/status.h"
#include "ray/object_manager/connection_pool.h"
#include "ray/object_manager/format/object_manager_generated.h"
#include "ray/object_manager/object_directory.h"
#include "ray/object_manager/object_manager_client_connection.h"
#include "ray/object_manager/object_store_client_pool.h"
#include "ray/object_manager/object_store_notification_manager.h"
#include "ray/object_manager/transfer_queue.h"
namespace ray {
struct ObjectManagerConfig {
// The time in milliseconds to wait before retrying a pull
// that failed due to client id lookup.
/// The time in milliseconds to wait before retrying a pull
/// that failed due to client id lookup.
int pull_timeout_ms = 100;
/// Size of thread pool.
int num_threads = 2;
/// Maximum number of sends allowed.
int max_sends = 20;
/// Maximum number of receives allowed.
int max_receives = 20;
// TODO(hme): Implement num retries (to avoid infinite retries).
std::string store_socket_name;
};
// TODO(hme): Comment everything doxygen-style.
// TODO(hme): Implement connection cleanup.
// TODO(hme): Add success/failure callbacks for push and pull.
// TODO(hme): Use boost thread pool.
// TODO(hme): Add incoming connections to io_service tied to thread pool.
class ObjectManager {
public:
/// Implicitly instantiates Ray implementation of ObjectDirectory.
///
/// \param io_service The asio io_service tied to the object manager.
/// \param main_service The main asio io_service.
/// \param object_manager_service The asio io_service tied to the object manager.
/// \param config ObjectManager configuration.
/// \param gcs_client A client connection to the Ray GCS.
explicit ObjectManager(boost::asio::io_service &io_service, ObjectManagerConfig config,
std::shared_ptr<ray::GcsClient> gcs_client);
explicit ObjectManager(boost::asio::io_service &main_service,
std::unique_ptr<boost::asio::io_service> object_manager_service,
const ObjectManagerConfig &config,
std::shared_ptr<gcs::AsyncGcsClient> gcs_client);
/// Takes user-defined ObjectDirectoryInterface implementation.
/// When this constructor is used, the ObjectManager assumes ownership of
/// the given ObjectDirectory instance.
///
/// \param io_service The asio io_service tied to the object manager.
/// \param main_service The main asio io_service.
/// \param object_manager_service The asio io_service tied to the object manager.
/// \param config ObjectManager configuration.
/// \param od An object implementing the object directory interface.
explicit ObjectManager(boost::asio::io_service &io_service, ObjectManagerConfig config,
explicit ObjectManager(boost::asio::io_service &main_service,
std::unique_ptr<boost::asio::io_service> object_manager_service,
const ObjectManagerConfig &config,
std::unique_ptr<ObjectDirectoryInterface> od);
/// \param client_id Set the client id associated with this node.
void SetClientID(const ClientID &client_id);
/// \return Get the client id associated with this node.
ClientID GetClientID();
/// Subscribe to notifications of objects added to local store.
/// Upon subscribing, the callback will be invoked for all objects that
///
@@ -106,7 +113,17 @@ class ObjectManager {
///
/// \param conn The connection.
/// \return Status of whether the connection was successfully established.
ray::Status AcceptConnection(TCPClientConnection::pointer conn);
void ProcessNewClient(std::shared_ptr<TcpClientConnection> conn);
/// Process messages sent from other nodes. We only establish
/// transfer connections using this method; all other transfer communication
/// is done separately.
///
/// \param conn The connection.
/// \param message_type The message type.
/// \param message A pointer set to the beginning of the message.
void ProcessClientMessage(std::shared_ptr<TcpClientConnection> conn,
int64_t message_type, const uint8_t *message);
/// Cancels all requests (Push/Pull) associated with the given ObjectID.
///
@@ -114,7 +131,7 @@ class ObjectManager {
/// \return Status of whether requests were successfully cancelled.
ray::Status Cancel(const ObjectID &object_id);
// Callback definition for wait.
/// Callback definition for wait.
using WaitCallback = std::function<void(const ray::Status, uint64_t,
const std::vector<ray::ObjectID> &)>;
/// Wait for timeout_ms before invoking the provided callback.
@@ -135,132 +152,137 @@ class ObjectManager {
ray::Status Terminate();
private:
using BoostEC = const boost::system::error_code &;
ClientID client_id_;
ObjectManagerConfig config_;
std::unique_ptr<ObjectDirectoryInterface> object_directory_;
std::unique_ptr<ObjectStoreClient> store_client_;
ObjectStoreNotificationManager store_notification_;
ObjectStoreClientPool store_pool_;
/// An io service for creating connections to other object managers.
boost::asio::io_service io_service_;
/// This runs on a thread pool.
std::unique_ptr<boost::asio::io_service> object_manager_service_;
/// Weak reference to main service. We ensure this object is destroyed before
/// main_service_ is stopped.
boost::asio::io_service *main_service_;
/// Used to create "work" for an io service, so when it's run, it doesn't exit.
boost::asio::io_service::work work_;
/// Single thread for executing asynchronous handlers.
/// This runs the (currently only) io_service, which handles all outgoing requests
/// and object transfers (push).
std::thread io_thread_;
/// Thread pool for executing asynchronous handlers.
/// These run the object_manager_service_, which handle
/// all incoming and outgoing object transfers.
std::vector<std::thread> io_threads_;
/// Relatively simple way to add thread pooling.
/// boost::thread_group thread_group_;
/// Connection pool for reusing outgoing connections to remote object managers.
ConnectionPool connection_pool_;
/// Timeout for failed pull requests.
using Timer = std::shared_ptr<boost::asio::deadline_timer>;
std::unordered_map<ObjectID, Timer, UniqueIDHasher> pull_requests_;
std::unordered_map<ObjectID, std::shared_ptr<boost::asio::deadline_timer>,
UniqueIDHasher>
pull_requests_;
// TODO (hme): This needs to account for receives as well.
/// This number is incremented whenever a push is started.
int num_transfers_ = 0;
// TODO (hme): Allow for concurrent sends.
/// This is the maximum number of pushes allowed.
/// We can only increase this number if we increase the number of
/// plasma client connections.
int max_transfers_ = 1;
/// Allows control of concurrent object transfers. This is a global queue,
/// allowing for concurrent transfers with many object managers as well as
/// concurrent transfers, including both sends and receives, with a single
/// remote object manager.
TransferQueue transfer_queue_;
/// Note that (currently) receives take place on the main thread,
/// and sends take place on a dedicated thread.
std::unordered_map<ray::ClientID, SenderConnection::pointer, ray::UniqueIDHasher>
message_send_connections_;
std::unordered_map<ray::ClientID, SenderConnection::pointer, ray::UniqueIDHasher>
transfer_send_connections_;
/// Variables to track number of concurrent sends and receives.
std::atomic<int> num_transfers_send_;
std::atomic<int> num_transfers_receive_;
std::unordered_map<ray::ClientID, TCPClientConnection::pointer, ray::UniqueIDHasher>
message_receive_connections_;
std::unordered_map<ray::ClientID, TCPClientConnection::pointer, ray::UniqueIDHasher>
transfer_receive_connections_;
/// Cache of locally available objects.
std::unordered_set<ObjectID, UniqueIDHasher> local_objects_;
/// Handle starting, running, and stopping asio io_service.
void StartIOService();
void IOServiceLoop();
void StopIOService();
/// Register object add with directory.
void NotifyDirectoryObjectAdd(const ObjectID &object_id);
/// Register object remove with directory.
void NotifyDirectoryObjectDeleted(const ObjectID &object_id);
/// Wait wait_ms milliseconds before triggering a pull request for object_id.
/// This is invoked when a pull fails. Only point of failure currently considered
/// is GetLocationsFailed.
void SchedulePull(const ObjectID &object_id, int wait_ms);
/// The handler for SchedulePull. Invokes a pull and removes the deadline timer
/// that was added to schedule the pull.
ray::Status SchedulePullHandler(const ObjectID &object_id);
/// Part of an asynchronous sequence of Pull methods.
/// Gets the location of an object before invoking PullEstablishConnection.
/// Guaranteed to execute on main_service_ thread.
/// Executes on main_service_ thread.
ray::Status PullGetLocations(const ObjectID &object_id);
/// Synchronously send a pull request.
/// Invoked once a connection to a remote manager that contains the required ObjectID
/// is established.
ray::Status ExecutePull(const ObjectID &object_id, SenderConnection::pointer conn);
/// Part of an asynchronous sequence of Pull methods.
/// Uses an existing connection or creates a connection to ClientID.
/// Executes on main_service_ thread.
ray::Status PullEstablishConnection(const ObjectID &object_id,
const ClientID &client_id);
/// Invoked once a connection to the remote manager to which the ObjectID
/// is to be sent is established.
ray::Status QueuePush(const ObjectID &object_id, SenderConnection::pointer client);
/// Starts as many queued pushes as possible without exceeding max_transfers_
/// concurrent transfers.
ray::Status ExecutePushQueue(SenderConnection::pointer client);
/// Initiate a push. This method asynchronously sends the object id and object size
/// Private callback implementation for success on get location. Called from
/// ObjectDirectory.
void GetLocationsSuccess(const std::vector<ray::ClientID> &client_ids,
const ray::ObjectID &object_id);
/// Private callback implementation for failure on get location. Called from
/// ObjectDirectory.
void GetLocationsFailed(const ObjectID &object_id);
/// Synchronously send a pull request via remote object manager connection.
/// Executes on main_service_ thread.
ray::Status PullSendRequest(const ObjectID &object_id,
std::shared_ptr<SenderConnection> conn);
/// Starts as many queued sends and receives as possible without exceeding
/// config_.max_sends and config_.max_receives, respectively.
/// Executes on object_manager_service_ thread pool.
ray::Status DequeueTransfers();
std::shared_ptr<SenderConnection> CreateSenderConnection(
ConnectionPool::ConnectionType type, RemoteConnectionInfo info);
/// Invoked when a transfer is completed. Invokes DequeueTransfers after
/// updating variables that track concurrent transfers.
/// Executes on object_manager_service_ thread pool.
ray::Status TransferCompleted(TransferQueue::TransferType type);
/// Begin executing a send.
/// Executes on object_manager_service_ thread pool.
ray::Status ExecuteSendObject(const ObjectID &object_id, const ClientID &client_id,
const RemoteConnectionInfo &connection_info);
/// This method synchronously sends the object id and object size
/// to the remote object manager.
ray::Status ExecutePushHeaders(const ObjectID &object_id,
SenderConnection::pointer client);
/// Called by the handler for ExecutePushMeta.
/// Executes on object_manager_service_ thread pool.
ray::Status SendObjectHeaders(const ObjectID &object_id,
std::shared_ptr<SenderConnection> client);
/// This method initiates the actual object transfer.
void ExecutePushObject(SenderConnection::pointer conn, const ObjectID &object_id,
const boost::system::error_code &header_ec);
/// Invoked when a push is completed. This method will decrement num_transfers_
/// and invoke ExecutePushQueue.
ray::Status ExecutePushCompleted(const ObjectID &object_id,
SenderConnection::pointer client);
/// Executes on object_manager_service_ thread pool.
ray::Status SendObjectData(std::shared_ptr<SenderConnection> conn,
const UniqueID &context_id,
std::shared_ptr<plasma::PlasmaClient> store_client);
/// Private callback implementation for success on get location. Called inside OD.
void GetLocationsSuccess(const std::vector<RemoteConnectionInfo> &vec,
const ObjectID &object_id);
/// Private callback implementation for failure on get location. Called inside OD.
void GetLocationsFailed(ray::Status status, const ObjectID &object_id);
/// Asynchronously obtain a connection to client_id.
/// If a connection to client_id already exists, the callback is invoked immediately.
ray::Status GetMsgConnection(const ClientID &client_id,
std::function<void(SenderConnection::pointer)> callback);
/// Asynchronously create a connection to client_id.
ray::Status CreateMsgConnection(
const RemoteConnectionInfo &info,
std::function<void(SenderConnection::pointer)> callback);
/// Asynchronously create a connection to client_id.
ray::Status GetTransferConnection(
const ClientID &client_id, std::function<void(SenderConnection::pointer)> callback);
/// Asynchronously obtain a connection to client_id.
/// If a connection to client_id already exists, the callback is invoked immediately.
ray::Status CreateTransferConnection(
const RemoteConnectionInfo &info,
std::function<void(SenderConnection::pointer)> callback);
/// A socket connection doing an asynchronous read on a transfer connection that was
/// added by AcceptConnection.
ray::Status WaitPushReceive(TCPClientConnection::pointer conn);
/// Invoked when a remote object manager pushes an object to this object manager.
void HandlePushReceive(TCPClientConnection::pointer conn, BoostEC length_ec);
/// This will queue the receive.
void ReceivePushRequest(std::shared_ptr<TcpClientConnection> conn,
const uint8_t *message);
/// Execute a receive that was in the queue.
ray::Status ExecuteReceiveObject(const ClientID &client_id, const ObjectID &object_id,
uint64_t object_size,
std::shared_ptr<TcpClientConnection> conn);
/// A socket connection doing an asynchronous read on a message connection that was
/// added by AcceptConnection.
ray::Status WaitMessage(TCPClientConnection::pointer conn);
/// Handle messages.
void HandleMessage(TCPClientConnection::pointer conn, BoostEC msg_ec);
/// Process the receive pull request message.
void ReceivePullRequest(TCPClientConnection::pointer conn);
/// Handles receiving a pull request message.
void ReceivePullRequest(std::shared_ptr<TcpClientConnection> &conn,
const uint8_t *message);
/// Register object add with directory.
void NotifyDirectoryObjectAdd(const ObjectID &object_id);
/// Register object remove with directory.
void NotifyDirectoryObjectDeleted(const ObjectID &object_id);
/// Handles connect message of a new client connection.
void ConnectClient(std::shared_ptr<TcpClientConnection> &conn, const uint8_t *message);
/// Handles disconnect message of an existing client connection.
void DisconnectClient(std::shared_ptr<TcpClientConnection> &conn,
const uint8_t *message);
};
} // namespace ray
@@ -1,59 +1,24 @@
#include "object_manager_client_connection.h"
#include "ray/object_manager/object_manager_client_connection.h"
namespace ray {
SenderConnection::pointer SenderConnection::Create(boost::asio::io_service &io_service,
const std::string &ip, uint16_t port) {
return pointer(new SenderConnection(io_service, ip, port));
uint64_t SenderConnection::id_counter_;
std::shared_ptr<SenderConnection> SenderConnection::Create(
boost::asio::io_service &io_service, const ClientID &client_id, const std::string &ip,
uint16_t port) {
boost::asio::ip::tcp::socket socket(io_service);
RAY_CHECK_OK(TcpConnect(socket, ip, port));
std::shared_ptr<TcpServerConnection> conn =
std::make_shared<TcpServerConnection>(std::move(socket));
return std::make_shared<SenderConnection>(conn, client_id);
};
SenderConnection::SenderConnection(boost::asio::io_service &io_service,
const std::string &ip, uint16_t port)
: socket_(io_service), send_queue_() {
boost::asio::ip::address ip_address = boost::asio::ip::address::from_string(ip);
boost::asio::ip::tcp::endpoint endpoint(ip_address, port);
socket_.connect(endpoint);
SenderConnection::SenderConnection(std::shared_ptr<TcpServerConnection> conn,
const ClientID &client_id)
: conn_(conn) {
client_id_ = client_id;
connection_id_ = SenderConnection::id_counter_++;
};
boost::asio::ip::tcp::socket &SenderConnection::GetSocket() { return socket_; };
bool SenderConnection::IsObjectIdQueueEmpty() { return send_queue_.empty(); }
bool SenderConnection::ObjectIdQueued(const ObjectID &object_id) {
return std::find(send_queue_.begin(), send_queue_.end(), object_id) !=
send_queue_.end();
}
void SenderConnection::QueueObjectId(const ObjectID &object_id) {
send_queue_.push_back(ObjectID(object_id));
}
ObjectID SenderConnection::DequeueObjectId() {
ObjectID object_id = send_queue_.front();
send_queue_.pop_front();
return object_id;
}
void SenderConnection::AddSendRequest(const ObjectID &object_id,
SendRequest &send_request) {
send_requests_.emplace(object_id, send_request);
}
void SenderConnection::RemoveSendRequest(const ObjectID &object_id) {
send_requests_.erase(object_id);
}
SendRequest &SenderConnection::GetSendRequest(const ObjectID &object_id) {
return send_requests_[object_id];
};
TCPClientConnection::TCPClientConnection(boost::asio::io_service &io_service)
: socket_(io_service) {}
TCPClientConnection::pointer TCPClientConnection::Create(
boost::asio::io_service &io_service) {
return TCPClientConnection::pointer(new TCPClientConnection(io_service));
}
boost::asio::ip::tcp::socket &TCPClientConnection::GetSocket() { return socket_; }
} // namespace ray
@@ -7,63 +7,72 @@
#include <boost/asio.hpp>
#include <boost/asio/error.hpp>
#include <boost/bind.hpp>
#include <boost/enable_shared_from_this.hpp>
#include "common/state/ray_config.h"
#include "ray/common/client_connection.h"
#include "ray/id.h"
namespace ray {
struct SendRequest {
ObjectID object_id;
ClientID client_id;
int64_t object_size;
uint8_t *data;
};
// TODO(hme): Document public API after integration with common connection.
class SenderConnection : public boost::enable_shared_from_this<SenderConnection> {
public:
typedef boost::shared_ptr<SenderConnection> pointer;
typedef std::unordered_map<ray::ObjectID, SendRequest, UniqueIDHasher> SendRequestsType;
typedef std::deque<ray::ObjectID> SendQueueType;
/// Create a connection for sending data to other object managers.
///
/// \param io_service The service to which the created socket should attach.
/// \param client_id The ClientID of the remote node.
/// \param ip The ip address of the remote node server.
/// \param port The port of the remote node server.
/// \return A connection to the remote object manager.
static std::shared_ptr<SenderConnection> Create(boost::asio::io_service &io_service,
const ClientID &client_id,
const std::string &ip, uint16_t port);
static pointer Create(boost::asio::io_service &io_service, const std::string &ip,
uint16_t port);
/// \param socket A reference to the socket created by the static Create method.
/// \param client_id The ClientID of the remote node.
SenderConnection(std::shared_ptr<TcpServerConnection> conn, const ClientID &client_id);
explicit SenderConnection(boost::asio::io_service &io_service, const std::string &ip,
uint16_t port);
/// Write a message to the client.
///
/// \param type The message type (e.g., a flatbuffer enum).
/// \param length The size in bytes of the message.
/// \param message A pointer to the message buffer.
/// \return Status.
ray::Status WriteMessage(int64_t type, uint64_t length, const uint8_t *message) {
return conn_->WriteMessage(type, length, message);
}
boost::asio::ip::tcp::socket &GetSocket();
/// Write a buffer to this connection.
///
/// \param buffer The buffer.
/// \param ec The error code object in which to store error codes.
void WriteBuffer(const std::vector<boost::asio::const_buffer> &buffer,
boost::system::error_code &ec) {
return conn_->WriteBuffer(buffer, ec);
}
bool IsObjectIdQueueEmpty();
bool ObjectIdQueued(const ObjectID &object_id);
void QueueObjectId(const ObjectID &object_id);
ObjectID DequeueObjectId();
/// Read a buffer from this connection.
///
/// \param buffer The buffer.
/// \param ec The error code object in which to store error codes.
void ReadBuffer(const std::vector<boost::asio::mutable_buffer> &buffer,
boost::system::error_code &ec) {
return conn_->ReadBuffer(buffer, ec);
}
void AddSendRequest(const ObjectID &object_id, SendRequest &send_request);
void RemoveSendRequest(const ObjectID &object_id);
SendRequest &GetSendRequest(const ObjectID &object_id);
/// \return The ClientID of this connection.
const ClientID &GetClientID() { return client_id_; }
private:
boost::asio::ip::tcp::socket socket_;
SendQueueType send_queue_;
SendRequestsType send_requests_;
};
bool operator==(const SenderConnection &rhs) const {
return connection_id_ == rhs.connection_id_;
}
// TODO(hme): Document public API after integration with common connection.
class TCPClientConnection : public boost::enable_shared_from_this<TCPClientConnection> {
public:
typedef boost::shared_ptr<TCPClientConnection> pointer;
static pointer Create(boost::asio::io_service &io_service);
boost::asio::ip::tcp::socket &GetSocket();
TCPClientConnection(boost::asio::io_service &io_service);
int64_t message_type_;
uint64_t message_length_;
private:
boost::asio::ip::tcp::socket socket_;
static uint64_t id_counter_;
uint64_t connection_id_;
ClientID client_id_;
std::shared_ptr<TcpServerConnection> conn_;
};
} // namespace ray
@@ -1 +0,0 @@
// TODO(hme): Move all messaging code here.
@@ -1,6 +0,0 @@
#ifndef RAY_OBJECT_MANAGER_OBJECT_MANAGER_PROTOCOL_H
#define RAY_OBJECT_MANAGER_OBJECT_MANAGER_PROTOCOL_H
// TODO(hme): Move all messaging code here.
#endif // RAY_OBJECT_MANAGER_OBJECT_MANAGER_PROTOCOL_H
@@ -1,137 +0,0 @@
#include <iostream>
#include <thread>
#include "gtest/gtest.h"
#include "plasma/client.h"
#include "plasma/events.h"
#include "plasma/plasma.h"
#include "plasma/protocol.h"
#include "ray/status.h"
#include "ray/object_manager/object_manager.h"
namespace ray {
std::string test_executable; // NOLINT
class TestObjectManager : public ::testing::Test {
public:
TestObjectManager() { RAY_LOG(DEBUG) << "TestObjectManager: started."; }
void SetUp() {
// start store
std::string om_dir = test_executable.substr(0, test_executable.find_last_of("/"));
std::string plasma_dir = om_dir + "./../plasma";
std::string plasma_command =
plasma_dir +
"/plasma_store -m 1000000000 -s /tmp/store 1> /dev/null 2> /dev/null &";
int s = system(plasma_command.c_str());
ASSERT_TRUE(!s);
// Start mock global control store.
mock_gcs_client_ = std::shared_ptr<GcsClient>(new GcsClient());
// mock_gcs_client_->Register();
// Start node server.
// Start object manager 1.
ObjectManagerConfig config;
config.store_socket_name = "/tmp/store";
object_manager_1_ = std::unique_ptr<ObjectManager>(
new ObjectManager(io_service_, config, mock_gcs_client_));
// Start object manager 2.
// ObjectManagerConfig config2;
// config2.store_socket_name = "/tmp/store";
// std::shared_ptr<ObjectDirectory> od2 = std::shared_ptr<ObjectDirectory>(new
// ObjectDirectory());
// od2->InitGcs(mock_gcs_client_);
// object_manager_2_ = std::unique_ptr<ObjectManager>(new ObjectManager(io_service,
// config2, od2));
// Initiate client connection.
ARROW_CHECK_OK(client_.Connect("/tmp/store", "", PLASMA_DEFAULT_RELEASE_DELAY));
this->StartLoop();
}
void TearDown() {
this->StopLoop();
arrow::Status arrow_status = client_.Disconnect();
ASSERT_TRUE(arrow_status.ok());
ray::Status ray_status = object_manager_1_->Terminate();
ASSERT_TRUE(ray_status.ok());
// object_manager_2_->Terminate();
int s = system("killall plasma_store &");
ASSERT_TRUE(!s);
}
void Loop() { io_service_.run(); };
void StartLoop() { process_thread_ = std::thread(&TestObjectManager::Loop, this); };
void StopLoop() {
io_service_.stop();
process_thread_.join();
}
protected:
std::thread process_thread_;
plasma::PlasmaClient client_;
plasma::PlasmaClient client2_;
boost::asio::io_service io_service_;
std::shared_ptr<GcsClient> mock_gcs_client_;
std::unique_ptr<ObjectManager> object_manager_1_;
std::unique_ptr<ObjectManager> object_manager_2_;
};
// TODO: get rid of dead code?
// TEST_F(TestObjectManager, TestPush) {
// // test object push between two object managers.
// ASSERT_TRUE(true);
// sleep(1);
//}
// TEST_F(TestObjectManager, TestPull) {
// ObjectID object_id = ObjectID().from_random();
// ClientID dbc_id = ClientID().from_random();
// RAY_LOG(INFO) << "ObjectID: " << object_id.hex().c_str();
// RAY_LOG(INFO) << "ClientID: " << dbc_id.hex().c_str();
// om->Pull(object_id, dbc_id);
// om->Pull(object_id);
// ASSERT_TRUE(true);
// sleep(1);
//}
void ObjectAdded(const ObjectID &object_id) {
RAY_LOG(INFO) << "ObjectID Added: " << object_id.hex().c_str();
}
TEST_F(TestObjectManager, TestNotifications) {
ray::Status status = object_manager_1_->SubscribeObjAdded(ObjectAdded);
ASSERT_TRUE(status.ok());
// put object
for (int i = 0; i < 10; ++i) {
ObjectID object_id = ObjectID::from_random();
RAY_LOG(INFO) << "ObjectID Created: " << object_id.hex().c_str();
int64_t data_size = 100;
uint8_t metadata[] = {5};
int64_t metadata_size = sizeof(metadata);
std::shared_ptr<Buffer> data;
ARROW_CHECK_OK(client_.Create(object_id.to_plasma_id(), data_size, metadata,
metadata_size, &data));
ARROW_CHECK_OK(client_.Seal(object_id.to_plasma_id()));
}
// TODO(hme): Can we do this without sleeping?
sleep(1);
}
} // namespace ray
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
ray::test_executable = std::string(argv[0]);
return RUN_ALL_TESTS();
}
@@ -1,95 +0,0 @@
#include <future>
#include <iostream>
#include <boost/asio.hpp>
#include <boost/bind.hpp>
#include <boost/function.hpp>
#include "common.h"
#include "common_protocol.h"
#include "ray/object_manager/object_store_client.h"
namespace ray {
// TODO(hme): Dedicate this class to notifications.
// TODO(hme): Create object store client pool for object manager.
ObjectStoreClient::ObjectStoreClient(boost::asio::io_service &io_service,
std::string &store_socket_name)
: client_one_(), client_two_(), socket_(io_service) {
ARROW_CHECK_OK(
client_two_.Connect(store_socket_name.c_str(), "", PLASMA_DEFAULT_RELEASE_DELAY));
ARROW_CHECK_OK(
client_one_.Connect(store_socket_name.c_str(), "", PLASMA_DEFAULT_RELEASE_DELAY));
// Connect to two clients, but subscribe to only one.
ARROW_CHECK_OK(client_one_.Subscribe(&c_socket_));
boost::system::error_code ec;
socket_.assign(boost::asio::local::stream_protocol(), c_socket_, ec);
assert(!ec.value());
NotificationWait();
};
void ObjectStoreClient::Terminate() {
ARROW_CHECK_OK(client_two_.Disconnect());
ARROW_CHECK_OK(client_one_.Disconnect());
}
void ObjectStoreClient::NotificationWait() {
boost::asio::async_read(socket_, boost::asio::buffer(&length_, sizeof(length_)),
boost::bind(&ObjectStoreClient::ProcessStoreLength, this,
boost::asio::placeholders::error));
}
void ObjectStoreClient::ProcessStoreLength(const boost::system::error_code &error) {
notification_.resize(length_);
boost::asio::async_read(socket_, boost::asio::buffer(notification_),
boost::bind(&ObjectStoreClient::ProcessStoreNotification, this,
boost::asio::placeholders::error));
}
void ObjectStoreClient::ProcessStoreNotification(const boost::system::error_code &error) {
if (error) {
throw std::runtime_error("ObjectStore may have died.");
}
auto object_info = flatbuffers::GetRoot<ObjectInfo>(notification_.data());
ObjectID object_id = from_flatbuf(*object_info->object_id());
if (object_info->is_deletion()) {
ProcessStoreRemove(object_id);
} else {
ProcessStoreAdd(object_id);
// why all these params?
// ProcessStoreAdd(
// object_id, object_info->data_size(),
// object_info->metadata_size(),
// (unsigned char *) object_info->digest()->data());
}
NotificationWait();
}
void ObjectStoreClient::ProcessStoreAdd(const ObjectID &object_id) {
for (auto handler : add_handlers_) {
handler(object_id);
}
};
void ObjectStoreClient::ProcessStoreRemove(const ObjectID &object_id) {
for (auto handler : rem_handlers_) {
handler(object_id);
}
};
void ObjectStoreClient::SubscribeObjAdded(
std::function<void(const ObjectID &)> callback) {
add_handlers_.push_back(callback);
};
void ObjectStoreClient::SubscribeObjDeleted(
std::function<void(const ObjectID &)> callback) {
rem_handlers_.push_back(callback);
};
plasma::PlasmaClient &ObjectStoreClient::GetClient() { return client_one_; };
plasma::PlasmaClient &ObjectStoreClient::GetClientOther() { return client_two_; };
} // namespace ray
@@ -0,0 +1,38 @@
#include "object_store_client_pool.h"
namespace ray {
ObjectStoreClientPool::ObjectStoreClientPool(const std::string &store_socket_name)
: store_socket_name_(store_socket_name) {}
std::shared_ptr<plasma::PlasmaClient> ObjectStoreClientPool::GetObjectStore() {
std::lock_guard<std::mutex> lock(pool_mutex);
if (available_clients.empty()) {
Add();
}
std::shared_ptr<plasma::PlasmaClient> client = available_clients.back();
available_clients.pop_back();
return client;
}
void ObjectStoreClientPool::ReleaseObjectStore(
std::shared_ptr<plasma::PlasmaClient> client) {
std::lock_guard<std::mutex> lock(pool_mutex);
available_clients.push_back(client);
}
void ObjectStoreClientPool::Terminate() {
for (const auto &client : clients) {
ARROW_CHECK_OK(client->Disconnect());
}
available_clients.clear();
clients.clear();
}
void ObjectStoreClientPool::Add() {
clients.emplace_back(new plasma::PlasmaClient());
ARROW_CHECK_OK(clients.back()->Connect(store_socket_name_.c_str(), "",
PLASMA_DEFAULT_RELEASE_DELAY));
available_clients.push_back(clients.back());
}
} // namespace ray
@@ -0,0 +1,65 @@
#ifndef RAY_OBJECT_MANAGER_OBJECT_STORE_CLIENT_POOL_H
#define RAY_OBJECT_MANAGER_OBJECT_STORE_CLIENT_POOL_H
#include <list>
#include <memory>
#include <mutex>
#include <vector>
#include <boost/asio.hpp>
#include <boost/asio/error.hpp>
#include <boost/bind.hpp>
#include "plasma/client.h"
#include "plasma/events.h"
#include "plasma/plasma.h"
#include "object_directory.h"
#include "ray/id.h"
#include "ray/status.h"
namespace ray {
/// \class ObjectStoreClientPool
///
/// Provides connections to the object store. Enables concurrent communication with
/// the object store.
class ObjectStoreClientPool {
public:
/// Constructor.
///
/// \param store_socket_name The object store socket name.
ObjectStoreClientPool(const std::string &store_socket_name);
/// This object cannot be copied due to pool_mutex.
RAY_DISALLOW_COPY_AND_ASSIGN(ObjectStoreClientPool);
/// Provides a connection to the object store from the object store pool.
/// This removes the object store client from the pool of available clients.
///
/// \return A connection to the object store.
std::shared_ptr<plasma::PlasmaClient> GetObjectStore();
/// Releases a client object and puts it back into the object store pool
/// for reuse.
/// Once a client is released, it is assumed that it is not being used.
/// \param client The client to return.
/// \param client
void ReleaseObjectStore(std::shared_ptr<plasma::PlasmaClient> client);
/// Terminates this object.
void Terminate();
private:
/// Adds a client to the client pool and mark it as available.
void Add();
std::mutex pool_mutex;
std::vector<std::shared_ptr<plasma::PlasmaClient>> available_clients;
std::vector<std::shared_ptr<plasma::PlasmaClient>> clients;
std::string store_socket_name_;
};
} // namespace ray
#endif // RAY_OBJECT_MANAGER_OBJECT_STORE_CLIENT_POOL_H
@@ -0,0 +1,90 @@
#include <future>
#include <iostream>
#include <boost/asio.hpp>
#include <boost/bind.hpp>
#include <boost/function.hpp>
#include "common/common.h"
#include "common/common_protocol.h"
#include "ray/object_manager/object_store_notification_manager.h"
namespace ray {
ObjectStoreNotificationManager::ObjectStoreNotificationManager(
boost::asio::io_service &io_service, const std::string &store_socket_name)
: store_client_(), socket_(io_service) {
ARROW_CHECK_OK(
store_client_.Connect(store_socket_name.c_str(), "", PLASMA_DEFAULT_RELEASE_DELAY));
ARROW_CHECK_OK(store_client_.Subscribe(&c_socket_));
boost::system::error_code ec;
socket_.assign(boost::asio::local::stream_protocol(), c_socket_, ec);
assert(!ec.value());
NotificationWait();
}
void ObjectStoreNotificationManager::Terminate() {
ARROW_CHECK_OK(store_client_.Disconnect());
}
void ObjectStoreNotificationManager::NotificationWait() {
boost::asio::async_read(socket_, boost::asio::buffer(&length_, sizeof(length_)),
boost::bind(&ObjectStoreNotificationManager::ProcessStoreLength,
this, boost::asio::placeholders::error));
}
void ObjectStoreNotificationManager::ProcessStoreLength(
const boost::system::error_code &error) {
notification_.resize(length_);
boost::asio::async_read(
socket_, boost::asio::buffer(notification_),
boost::bind(&ObjectStoreNotificationManager::ProcessStoreNotification, this,
boost::asio::placeholders::error));
}
void ObjectStoreNotificationManager::ProcessStoreNotification(
const boost::system::error_code &error) {
if (error) {
throw std::runtime_error("ObjectStore may have died.");
}
const auto &object_info = flatbuffers::GetRoot<ObjectInfo>(notification_.data());
const auto &object_id = from_flatbuf(*object_info->object_id());
if (object_info->is_deletion()) {
ProcessStoreRemove(object_id);
} else {
ProcessStoreAdd(object_id);
// TODO(hme): Determine what data is actually needed by consumer of this notification.
// ProcessStoreAdd(
// object_id, object_info->data_size(),
// object_info->metadata_size(),
// (unsigned char *) object_info->digest()->data());
}
NotificationWait();
}
void ObjectStoreNotificationManager::ProcessStoreAdd(const ObjectID &object_id) {
for (auto handler : add_handlers_) {
handler(object_id);
}
}
void ObjectStoreNotificationManager::ProcessStoreRemove(const ObjectID &object_id) {
for (auto handler : rem_handlers_) {
handler(object_id);
}
}
void ObjectStoreNotificationManager::SubscribeObjAdded(
std::function<void(const ObjectID &)> callback) {
add_handlers_.push_back(callback);
}
void ObjectStoreNotificationManager::SubscribeObjDeleted(
std::function<void(const ObjectID &)> callback) {
rem_handlers_.push_back(callback);
}
} // namespace ray
@@ -13,53 +13,58 @@
#include "plasma/events.h"
#include "plasma/plasma.h"
#include "object_directory.h"
#include "ray/id.h"
#include "ray/status.h"
#include "ray/object_manager/object_directory.h"
namespace ray {
// TODO(hme): document public API after refactor.
class ObjectStoreClient {
/// \class ObjectStoreClientPool
///
/// Encapsulates notification handling from the object store.
class ObjectStoreNotificationManager {
public:
// Encapsulates communication with the object store.
ObjectStoreClient(boost::asio::io_service &io_service, std::string &store_socket_name);
/// Constructor.
///
/// \param io_service The asio service to be used.
/// \param store_socket_name The store socket to connect to.
ObjectStoreNotificationManager(boost::asio::io_service &io_service,
const std::string &store_socket_name);
// Subscribe to notifications of objects added to local store.
// Upon subscribing, the callback will be invoked for all objects that
// already exist in the local store.
/// Subscribe to notifications of objects added to local store.
/// Upon subscribing, the callback will be invoked for all objects that
/// already exist in the local store
///
/// \param callback A callback expecting an ObjectID.
void SubscribeObjAdded(std::function<void(const ray::ObjectID &)> callback);
// Subscribe to notifications of objects deleted from local store.
/// Subscribe to notifications of objects deleted from local store.
///
/// \param callback A callback expecting an ObjectID.
void SubscribeObjDeleted(std::function<void(const ray::ObjectID &)> callback);
// TODO(hme): There should be as many client connections as there are threads.
// Two client connections are made to enable concurrent communication with the store.
plasma::PlasmaClient &GetClient();
plasma::PlasmaClient &GetClientOther();
// Terminate this object.
/// Terminate this object.
void Terminate();
private:
std::vector<std::function<void(const ray::ObjectID &)>> add_handlers_;
std::vector<std::function<void(const ray::ObjectID &)>> rem_handlers_;
plasma::PlasmaClient client_one_;
plasma::PlasmaClient client_two_;
int c_socket_;
int64_t length_;
std::vector<uint8_t> notification_;
boost::asio::local::stream_protocol::socket socket_;
// Async loop for handling object store notifications.
/// Async loop for handling object store notifications.
void NotificationWait();
void ProcessStoreLength(const boost::system::error_code &error);
void ProcessStoreNotification(const boost::system::error_code &error);
// Support for rebroadcasting object add/rem events.
/// Support for rebroadcasting object add/rem events.
void ProcessStoreAdd(const ObjectID &object_id);
void ProcessStoreRemove(const ObjectID &object_id);
std::vector<std::function<void(const ray::ObjectID &)>> add_handlers_;
std::vector<std::function<void(const ray::ObjectID &)>> rem_handlers_;
plasma::PlasmaClient store_client_;
int c_socket_;
int64_t length_;
std::vector<uint8_t> notification_;
boost::asio::local::stream_protocol::socket socket_;
};
} // namespace ray
@@ -0,0 +1,440 @@
#include <chrono>
#include <iostream>
#include <random>
#include <thread>
#include "gtest/gtest.h"
#include "ray/object_manager/object_manager.h"
namespace ray {
std::string store_executable;
static inline void flushall_redis(void) {
redisContext *context = redisConnect("127.0.0.1", 6379);
freeReplyObject(redisCommand(context, "FLUSHALL"));
redisFree(context);
}
int64_t current_time_ms() {
std::chrono::milliseconds ms_since_epoch =
std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now().time_since_epoch());
return ms_since_epoch.count();
}
class MockServer {
public:
MockServer(boost::asio::io_service &main_service,
std::unique_ptr<boost::asio::io_service> object_manager_service,
const ObjectManagerConfig &object_manager_config,
std::shared_ptr<gcs::AsyncGcsClient> gcs_client)
: object_manager_acceptor_(
main_service, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), 0)),
object_manager_socket_(main_service),
gcs_client_(gcs_client),
object_manager_(main_service, std::move(object_manager_service),
object_manager_config, gcs_client) {
RAY_CHECK_OK(RegisterGcs(main_service));
// Start listening for clients.
DoAcceptObjectManager();
}
~MockServer() {
RAY_CHECK_OK(gcs_client_->client_table().Disconnect());
RAY_CHECK_OK(object_manager_.Terminate());
}
private:
ray::Status RegisterGcs(boost::asio::io_service &io_service) {
RAY_RETURN_NOT_OK(gcs_client_->Connect("127.0.0.1", 6379));
RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service));
boost::asio::ip::tcp::endpoint endpoint = object_manager_acceptor_.local_endpoint();
std::string ip = endpoint.address().to_string();
unsigned short object_manager_port = endpoint.port();
ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient();
client_info.node_manager_address = ip;
client_info.node_manager_port = object_manager_port;
client_info.object_manager_port = object_manager_port;
return gcs_client_->client_table().Connect(client_info);
}
void DoAcceptObjectManager() {
object_manager_acceptor_.async_accept(
object_manager_socket_, boost::bind(&MockServer::HandleAcceptObjectManager, this,
boost::asio::placeholders::error));
}
void HandleAcceptObjectManager(const boost::system::error_code &error) {
ClientHandler<boost::asio::ip::tcp> client_handler =
[this](std::shared_ptr<TcpClientConnection> client) {
object_manager_.ProcessNewClient(client);
};
MessageHandler<boost::asio::ip::tcp> message_handler = [this](
std::shared_ptr<TcpClientConnection> client, int64_t message_type,
const uint8_t *message) {
object_manager_.ProcessClientMessage(client, message_type, message);
};
// Accept a new local client and dispatch it to the node manager.
auto new_connection = TcpClientConnection::Create(client_handler, message_handler,
std::move(object_manager_socket_));
DoAcceptObjectManager();
}
friend class StressTestObjectManager;
boost::asio::ip::tcp::acceptor object_manager_acceptor_;
boost::asio::ip::tcp::socket object_manager_socket_;
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_;
ObjectManager object_manager_;
};
class TestObjectManagerBase : public ::testing::Test {
public:
TestObjectManagerBase() {}
std::string StartStore(const std::string &id) {
std::string store_id = "/tmp/store";
store_id = store_id + id;
std::string plasma_command = store_executable + " -m 1000000000 -s " + store_id +
" 1> /dev/null 2> /dev/null &";
RAY_LOG(DEBUG) << plasma_command;
int ec = system(plasma_command.c_str());
if (ec != 0) {
throw std::runtime_error("failed to start plasma store.");
};
return store_id;
}
void SetUp() {
flushall_redis();
object_manager_service_1.reset(new boost::asio::io_service());
object_manager_service_2.reset(new boost::asio::io_service());
// start store
std::string store_sock_1 = StartStore("1");
std::string store_sock_2 = StartStore("2");
// start first server
gcs_client_1 = std::shared_ptr<gcs::AsyncGcsClient>(new gcs::AsyncGcsClient());
ObjectManagerConfig om_config_1;
om_config_1.store_socket_name = store_sock_1;
om_config_1.num_threads = 4;
om_config_1.max_sends = 20;
om_config_1.max_receives = 20;
server1.reset(new MockServer(main_service, std::move(object_manager_service_1),
om_config_1, gcs_client_1));
// start second server
gcs_client_2 = std::shared_ptr<gcs::AsyncGcsClient>(new gcs::AsyncGcsClient());
ObjectManagerConfig om_config_2;
om_config_2.store_socket_name = store_sock_2;
om_config_2.num_threads = 4;
om_config_2.max_sends = 20;
om_config_2.max_receives = 20;
server2.reset(new MockServer(main_service, std::move(object_manager_service_2),
om_config_2, gcs_client_2));
// connect to stores.
ARROW_CHECK_OK(client1.Connect(store_sock_1, "", PLASMA_DEFAULT_RELEASE_DELAY));
ARROW_CHECK_OK(client2.Connect(store_sock_2, "", PLASMA_DEFAULT_RELEASE_DELAY));
}
void TearDown() {
arrow::Status client1_status = client1.Disconnect();
arrow::Status client2_status = client2.Disconnect();
ASSERT_TRUE(client1_status.ok() && client2_status.ok());
this->server1.reset();
this->server2.reset();
int s = system("killall plasma_store &");
ASSERT_TRUE(!s);
}
ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size) {
ObjectID object_id = ObjectID::from_random();
RAY_LOG(DEBUG) << "ObjectID Created: " << object_id;
uint8_t metadata[] = {5};
int64_t metadata_size = sizeof(metadata);
std::shared_ptr<Buffer> data;
ARROW_CHECK_OK(client.Create(object_id.to_plasma_id(), data_size, metadata,
metadata_size, &data));
ARROW_CHECK_OK(client.Seal(object_id.to_plasma_id()));
return object_id;
}
void object_added_handler_1(ObjectID object_id) { v1.push_back(object_id); };
void object_added_handler_2(ObjectID object_id) { v2.push_back(object_id); };
protected:
std::thread p;
boost::asio::io_service main_service;
std::unique_ptr<boost::asio::io_service> object_manager_service_1;
std::unique_ptr<boost::asio::io_service> object_manager_service_2;
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_1;
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_2;
std::unique_ptr<MockServer> server1;
std::unique_ptr<MockServer> server2;
plasma::PlasmaClient client1;
plasma::PlasmaClient client2;
std::vector<ObjectID> v1;
std::vector<ObjectID> v2;
};
class StressTestObjectManager : public TestObjectManagerBase {
public:
enum TransferPattern {
PUSH_A_B,
PUSH_B_A,
BIDIRECTIONAL_PUSH,
PULL_A_B,
PULL_B_A,
BIDIRECTIONAL_PULL,
BIDIRECTIONAL_PULL_VARIABLE_DATA_SIZE,
};
int async_loop_index = -1;
uint num_expected_objects;
std::vector<TransferPattern> async_loop_patterns = {
PUSH_A_B,
PUSH_B_A,
BIDIRECTIONAL_PUSH,
PULL_A_B,
PULL_B_A,
BIDIRECTIONAL_PULL,
BIDIRECTIONAL_PULL_VARIABLE_DATA_SIZE};
int num_connected_clients = 0;
ClientID client_id_1;
ClientID client_id_2;
int64_t start_time;
void WaitConnections() {
client_id_1 = gcs_client_1->client_table().GetLocalClientId();
client_id_2 = gcs_client_2->client_table().GetLocalClientId();
gcs_client_1->client_table().RegisterClientAddedCallback([this](
gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) {
ClientID parsed_id = ClientID::from_binary(data.client_id);
if (parsed_id == client_id_1 || parsed_id == client_id_2) {
num_connected_clients += 1;
}
if (num_connected_clients == 2) {
StartTests();
}
});
}
void StartTests() {
TestConnections();
AddTransferTestHandlers();
TransferTestNext();
}
void AddTransferTestHandlers() {
ray::Status status = ray::Status::OK();
status =
server1->object_manager_.SubscribeObjAdded([this](const ObjectID &object_id) {
object_added_handler_1(object_id);
if (v1.size() == num_expected_objects && v1.size() == v2.size()) {
TransferTestComplete();
}
});
RAY_CHECK_OK(status);
status =
server2->object_manager_.SubscribeObjAdded([this](const ObjectID &object_id) {
object_added_handler_2(object_id);
if (v2.size() == num_expected_objects && v1.size() == v2.size()) {
TransferTestComplete();
}
});
RAY_CHECK_OK(status);
}
void TransferTestNext() {
async_loop_index += 1;
if ((uint)async_loop_index < async_loop_patterns.size()) {
TransferPattern pattern = async_loop_patterns[async_loop_index];
TransferTestExecute(1000, 100, pattern);
} else {
main_service.stop();
}
}
plasma::ObjectBuffer GetObject(plasma::PlasmaClient &client, ObjectID &object_id) {
plasma::ObjectBuffer object_buffer;
plasma::ObjectID plasma_id = object_id.to_plasma_id();
ARROW_CHECK_OK(client.Get(&plasma_id, 1, 0, &object_buffer));
return object_buffer;
}
static unsigned char *GetDigest(plasma::PlasmaClient &client, ObjectID &object_id) {
const int64_t size = sizeof(uint64_t);
static unsigned char digest_1[size];
ARROW_CHECK_OK(client.Hash(object_id.to_plasma_id(), &digest_1[0]));
return digest_1;
}
void CompareObjects(ObjectID &object_id_1, ObjectID &object_id_2) {
plasma::ObjectBuffer object_buffer_1 = GetObject(client1, object_id_1);
plasma::ObjectBuffer object_buffer_2 = GetObject(client1, object_id_1);
uint8_t *data_1 = const_cast<uint8_t *>(object_buffer_1.data->data());
uint8_t *data_2 = const_cast<uint8_t *>(object_buffer_2.data->data());
ASSERT_EQ(object_buffer_1.data_size, object_buffer_2.data_size);
for (int i = -1; ++i < object_buffer_1.data_size;) {
ASSERT_TRUE(data_1[i] == data_2[i]);
}
}
void CompareHashes(ObjectID &object_id_1, ObjectID &object_id_2) {
const int64_t size = sizeof(uint64_t);
static unsigned char *digest_1 = GetDigest(client1, object_id_1);
static unsigned char *digest_2 = GetDigest(client2, object_id_2);
for (int i = -1; ++i < size;) {
ASSERT_TRUE(digest_1[i] == digest_2[i]);
}
}
void TransferTestComplete() {
int64_t elapsed = current_time_ms() - start_time;
RAY_LOG(INFO) << "TransferTestComplete: " << async_loop_patterns[async_loop_index]
<< " " << v1.size() << " " << elapsed;
ASSERT_TRUE(v1.size() == v2.size());
for (uint i = 0; i < v1.size(); ++i) {
ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end());
}
// Compare objects and their hashes.
for (uint i = 0; i < v1.size(); ++i) {
ObjectID object_id_2 = v2[i];
ObjectID object_id_1 =
v1[std::distance(v1.begin(), std::find(v1.begin(), v1.end(), v2[i]))];
CompareHashes(object_id_1, object_id_2);
CompareObjects(object_id_1, object_id_2);
}
v1.clear();
v2.clear();
TransferTestNext();
}
void TransferTestExecute(int num_trials, int64_t data_size,
TransferPattern transfer_pattern) {
ClientID client_id_1 = gcs_client_1->client_table().GetLocalClientId();
ClientID client_id_2 = gcs_client_2->client_table().GetLocalClientId();
ray::Status status = ray::Status::OK();
if (transfer_pattern == BIDIRECTIONAL_PULL ||
transfer_pattern == BIDIRECTIONAL_PUSH ||
transfer_pattern == BIDIRECTIONAL_PULL_VARIABLE_DATA_SIZE) {
num_expected_objects = (uint)2 * num_trials;
} else {
num_expected_objects = (uint)num_trials;
}
start_time = current_time_ms();
switch (transfer_pattern) {
case PUSH_A_B: {
for (int i = -1; ++i < num_trials;) {
ObjectID oid1 = WriteDataToClient(client1, data_size);
status = server1->object_manager_.Push(oid1, client_id_2);
}
} break;
case PUSH_B_A: {
for (int i = -1; ++i < num_trials;) {
ObjectID oid2 = WriteDataToClient(client2, data_size);
status = server2->object_manager_.Push(oid2, client_id_1);
}
} break;
case BIDIRECTIONAL_PUSH: {
for (int i = -1; ++i < num_trials;) {
ObjectID oid1 = WriteDataToClient(client1, data_size);
status = server1->object_manager_.Push(oid1, client_id_2);
ObjectID oid2 = WriteDataToClient(client2, data_size);
status = server2->object_manager_.Push(oid2, client_id_1);
}
} break;
case PULL_A_B: {
for (int i = -1; ++i < num_trials;) {
ObjectID oid1 = WriteDataToClient(client1, data_size);
status = server2->object_manager_.Pull(oid1);
}
} break;
case PULL_B_A: {
for (int i = -1; ++i < num_trials;) {
ObjectID oid2 = WriteDataToClient(client2, data_size);
status = server1->object_manager_.Pull(oid2);
}
} break;
case BIDIRECTIONAL_PULL: {
for (int i = -1; ++i < num_trials;) {
ObjectID oid1 = WriteDataToClient(client1, data_size);
status = server2->object_manager_.Pull(oid1);
ObjectID oid2 = WriteDataToClient(client2, data_size);
status = server1->object_manager_.Pull(oid2);
}
} break;
case BIDIRECTIONAL_PULL_VARIABLE_DATA_SIZE: {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(1, 50);
for (int i = -1; ++i < num_trials;) {
ObjectID oid1 = WriteDataToClient(client1, data_size + dis(gen));
status = server2->object_manager_.Pull(oid1);
ObjectID oid2 = WriteDataToClient(client2, data_size + dis(gen));
status = server1->object_manager_.Pull(oid2);
}
} break;
default: {
RAY_LOG(FATAL) << "No case for transfer_pattern " << transfer_pattern;
} break;
}
}
void TestConnections() {
RAY_LOG(DEBUG) << "\n"
<< "Server client ids:"
<< "\n";
ClientID client_id_1 = gcs_client_1->client_table().GetLocalClientId();
ClientID client_id_2 = gcs_client_2->client_table().GetLocalClientId();
RAY_LOG(DEBUG) << "Server 1: " << client_id_1 << "\n"
<< "Server 2: " << client_id_2;
RAY_LOG(DEBUG) << "\n"
<< "All connected clients:"
<< "\n";
const ClientTableDataT &data = gcs_client_1->client_table().GetClient(client_id_1);
RAY_LOG(DEBUG) << "ClientID=" << ClientID::from_binary(data.client_id) << "\n"
<< "ClientIp=" << data.node_manager_address << "\n"
<< "ClientPort=" << data.node_manager_port;
const ClientTableDataT &data2 = gcs_client_1->client_table().GetClient(client_id_2);
RAY_LOG(DEBUG) << "ClientID=" << ClientID::from_binary(data2.client_id) << "\n"
<< "ClientIp=" << data2.node_manager_address << "\n"
<< "ClientPort=" << data2.node_manager_port;
}
};
TEST_F(StressTestObjectManager, StartStressTestObjectManager) {
auto AsyncStartTests = main_service.wrap([this]() { WaitConnections(); });
AsyncStartTests();
main_service.run();
}
} // namespace ray
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
ray::store_executable = std::string(argv[1]);
return RUN_ALL_TESTS();
}
@@ -0,0 +1,256 @@
#include <iostream>
#include <thread>
#include "gtest/gtest.h"
#include "ray/object_manager/object_manager.h"
namespace ray {
static inline void flushall_redis(void) {
redisContext *context = redisConnect("127.0.0.1", 6379);
freeReplyObject(redisCommand(context, "FLUSHALL"));
redisFree(context);
}
std::string store_executable;
class MockServer {
public:
MockServer(boost::asio::io_service &main_service,
std::unique_ptr<boost::asio::io_service> object_manager_service,
const ObjectManagerConfig &object_manager_config,
std::shared_ptr<gcs::AsyncGcsClient> gcs_client)
: object_manager_acceptor_(
main_service, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), 0)),
object_manager_socket_(main_service),
gcs_client_(gcs_client),
object_manager_(main_service, std::move(object_manager_service),
object_manager_config, gcs_client) {
RAY_CHECK_OK(RegisterGcs(main_service));
// Start listening for clients.
DoAcceptObjectManager();
}
~MockServer() {
RAY_CHECK_OK(gcs_client_->client_table().Disconnect());
RAY_CHECK_OK(object_manager_.Terminate());
}
private:
ray::Status RegisterGcs(boost::asio::io_service &io_service) {
RAY_RETURN_NOT_OK(gcs_client_->Connect("127.0.0.1", 6379));
RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service));
boost::asio::ip::tcp::endpoint endpoint = object_manager_acceptor_.local_endpoint();
std::string ip = endpoint.address().to_string();
unsigned short object_manager_port = endpoint.port();
ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient();
client_info.node_manager_address = ip;
client_info.node_manager_port = object_manager_port;
client_info.object_manager_port = object_manager_port;
return gcs_client_->client_table().Connect(client_info);
}
void DoAcceptObjectManager() {
object_manager_acceptor_.async_accept(
object_manager_socket_, boost::bind(&MockServer::HandleAcceptObjectManager, this,
boost::asio::placeholders::error));
}
void HandleAcceptObjectManager(const boost::system::error_code &error) {
ClientHandler<boost::asio::ip::tcp> client_handler =
[this](std::shared_ptr<TcpClientConnection> client) {
object_manager_.ProcessNewClient(client);
};
MessageHandler<boost::asio::ip::tcp> message_handler = [this](
std::shared_ptr<TcpClientConnection> client, int64_t message_type,
const uint8_t *message) {
object_manager_.ProcessClientMessage(client, message_type, message);
};
// Accept a new local client and dispatch it to the node manager.
auto new_connection = TcpClientConnection::Create(client_handler, message_handler,
std::move(object_manager_socket_));
DoAcceptObjectManager();
}
friend class TestObjectManagerCommands;
boost::asio::ip::tcp::acceptor object_manager_acceptor_;
boost::asio::ip::tcp::socket object_manager_socket_;
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_;
ObjectManager object_manager_;
};
class TestObjectManager : public ::testing::Test {
public:
TestObjectManager() {}
std::string StartStore(const std::string &id) {
std::string store_id = "/tmp/store";
store_id = store_id + id;
std::string plasma_command = store_executable + " -m 1000000000 -s " + store_id +
" 1> /dev/null 2> /dev/null &";
RAY_LOG(DEBUG) << plasma_command;
int ec = system(plasma_command.c_str());
if (ec != 0) {
throw std::runtime_error("failed to start plasma store.");
};
return store_id;
}
void SetUp() {
flushall_redis();
object_manager_service_1.reset(new boost::asio::io_service());
object_manager_service_2.reset(new boost::asio::io_service());
// start store
std::string store_sock_1 = StartStore("1");
std::string store_sock_2 = StartStore("2");
// start first server
gcs_client_1 = std::shared_ptr<gcs::AsyncGcsClient>(new gcs::AsyncGcsClient());
ObjectManagerConfig om_config_1;
om_config_1.store_socket_name = store_sock_1;
server1.reset(new MockServer(main_service, std::move(object_manager_service_1),
om_config_1, gcs_client_1));
// start second server
gcs_client_2 = std::shared_ptr<gcs::AsyncGcsClient>(new gcs::AsyncGcsClient());
ObjectManagerConfig om_config_2;
om_config_2.store_socket_name = store_sock_2;
server2.reset(new MockServer(main_service, std::move(object_manager_service_2),
om_config_2, gcs_client_2));
// connect to stores.
ARROW_CHECK_OK(client1.Connect(store_sock_1, "", PLASMA_DEFAULT_RELEASE_DELAY));
ARROW_CHECK_OK(client2.Connect(store_sock_2, "", PLASMA_DEFAULT_RELEASE_DELAY));
}
void TearDown() {
arrow::Status client1_status = client1.Disconnect();
arrow::Status client2_status = client2.Disconnect();
ASSERT_TRUE(client1_status.ok() && client2_status.ok());
this->server1.reset();
this->server2.reset();
int s = system("killall plasma_store &");
ASSERT_TRUE(!s);
}
ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size) {
ObjectID object_id = ObjectID::from_random();
RAY_LOG(DEBUG) << "ObjectID Created: " << object_id;
uint8_t metadata[] = {5};
int64_t metadata_size = sizeof(metadata);
std::shared_ptr<Buffer> data;
ARROW_CHECK_OK(client.Create(object_id.to_plasma_id(), data_size, metadata,
metadata_size, &data));
ARROW_CHECK_OK(client.Seal(object_id.to_plasma_id()));
return object_id;
}
void object_added_handler_1(ObjectID object_id) { v1.push_back(object_id); };
void object_added_handler_2(ObjectID object_id) { v2.push_back(object_id); };
protected:
std::thread p;
boost::asio::io_service main_service;
std::unique_ptr<boost::asio::io_service> object_manager_service_1;
std::unique_ptr<boost::asio::io_service> object_manager_service_2;
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_1;
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_2;
std::unique_ptr<MockServer> server1;
std::unique_ptr<MockServer> server2;
plasma::PlasmaClient client1;
plasma::PlasmaClient client2;
std::vector<ObjectID> v1;
std::vector<ObjectID> v2;
};
class TestObjectManagerCommands : public TestObjectManager {
public:
int num_connected_clients = 0;
uint num_expected_objects;
ClientID client_id_1;
ClientID client_id_2;
ObjectID created_object_id;
void WaitConnections() {
client_id_1 = gcs_client_1->client_table().GetLocalClientId();
client_id_2 = gcs_client_2->client_table().GetLocalClientId();
gcs_client_1->client_table().RegisterClientAddedCallback([this](
gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) {
ClientID parsed_id = ClientID::from_binary(data.client_id);
if (parsed_id == client_id_1 || parsed_id == client_id_2) {
num_connected_clients += 1;
}
if (num_connected_clients == 2) {
StartTests();
}
});
}
void StartTests() {
TestConnections();
TestNotifications();
}
void TestNotifications() {
ray::Status status = ray::Status::OK();
status =
server1->object_manager_.SubscribeObjAdded([this](const ObjectID &object_id) {
object_added_handler_1(object_id);
if (v1.size() == num_expected_objects) {
NotificationTestComplete(created_object_id, object_id);
}
});
RAY_CHECK_OK(status);
num_expected_objects = 1;
uint data_size = 1000000;
created_object_id = WriteDataToClient(client1, data_size);
}
void NotificationTestComplete(ObjectID object_id_1, ObjectID object_id_2) {
ASSERT_EQ(object_id_1, object_id_2);
main_service.stop();
}
void TestConnections() {
RAY_LOG(DEBUG) << "\n"
<< "Server client ids:"
<< "\n";
const ClientTableDataT &data = gcs_client_1->client_table().GetClient(client_id_1);
RAY_LOG(DEBUG) << (ClientID::from_binary(data.client_id) == ClientID::nil());
RAY_LOG(DEBUG) << "Server 1 ClientID=" << ClientID::from_binary(data.client_id);
RAY_LOG(DEBUG) << "Server 1 ClientIp=" << data.node_manager_address;
RAY_LOG(DEBUG) << "Server 1 ClientPort=" << data.node_manager_port;
ASSERT_EQ(client_id_1, ClientID::from_binary(data.client_id));
const ClientTableDataT &data2 = gcs_client_1->client_table().GetClient(client_id_2);
RAY_LOG(DEBUG) << "Server 2 ClientID=" << ClientID::from_binary(data2.client_id);
RAY_LOG(DEBUG) << "Server 2 ClientIp=" << data2.node_manager_address;
RAY_LOG(DEBUG) << "Server 2 ClientPort=" << data2.node_manager_port;
ASSERT_EQ(client_id_2, ClientID::from_binary(data2.client_id));
}
};
TEST_F(TestObjectManagerCommands, StartTestObjectManagerCommands) {
auto AsyncStartTests = main_service.wrap([this]() { WaitConnections(); });
AsyncStartTests();
main_service.run();
}
} // namespace ray
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
ray::store_executable = std::string(argv[1]);
return RUN_ALL_TESTS();
}
+67
View File
@@ -0,0 +1,67 @@
#include "ray/object_manager/transfer_queue.h"
namespace ray {
void TransferQueue::QueueSend(const ClientID &client_id, const ObjectID &object_id,
const RemoteConnectionInfo &info) {
WriteLock guard(send_mutex);
SendRequest req = {client_id, object_id, info};
// TODO(hme): Use a set to speed this up.
if (std::find(send_queue_.begin(), send_queue_.end(), req) != send_queue_.end()) {
// already queued.
return;
}
send_queue_.push_back(req);
}
void TransferQueue::QueueReceive(const ClientID &client_id, const ObjectID &object_id,
uint64_t object_size,
std::shared_ptr<TcpClientConnection> conn) {
WriteLock guard(receive_mutex);
ReceiveRequest req = {client_id, object_id, object_size, conn};
if (std::find(receive_queue_.begin(), receive_queue_.end(), req) !=
receive_queue_.end()) {
// already queued.
return;
}
receive_queue_.push_back(req);
}
bool TransferQueue::DequeueSendIfPresent(TransferQueue::SendRequest *send_ptr) {
WriteLock guard(send_mutex);
if (send_queue_.empty()) {
return false;
}
*send_ptr = send_queue_.front();
send_queue_.pop_front();
return true;
}
bool TransferQueue::DequeueReceiveIfPresent(TransferQueue::ReceiveRequest *receive_ptr) {
WriteLock guard(receive_mutex);
if (receive_queue_.empty()) {
return false;
}
*receive_ptr = receive_queue_.front();
receive_queue_.pop_front();
return true;
}
UniqueID TransferQueue::AddContext(SendContext &context) {
WriteLock guard(context_mutex);
UniqueID id = UniqueID::from_random();
send_context_set_.emplace(id, context);
return id;
}
TransferQueue::SendContext &TransferQueue::GetContext(const UniqueID &id) {
ReadLock guard(context_mutex);
return send_context_set_[id];
}
ray::Status TransferQueue::RemoveContext(const UniqueID &id) {
WriteLock guard(context_mutex);
send_context_set_.erase(id);
return Status::OK();
}
} // namespace ray
+126
View File
@@ -0,0 +1,126 @@
#ifndef RAY_OBJECT_MANAGER_TRANSFER_QUEUE_H
#define RAY_OBJECT_MANAGER_TRANSFER_QUEUE_H
#include <algorithm>
#include <cstdint>
#include <deque>
#include <map>
#include <memory>
#include <mutex>
#include <thread>
#include <boost/asio.hpp>
#include <boost/asio/error.hpp>
#include <boost/bind.hpp>
#include "ray/id.h"
#include "ray/status.h"
#include "ray/object_manager/format/object_manager_generated.h"
#include "ray/object_manager/object_directory.h"
#include "ray/object_manager/object_manager_client_connection.h"
namespace ray {
class TransferQueue {
public:
enum TransferType { SEND = 1, RECEIVE };
/// Context maintained during an object send.
struct SendContext {
ClientID client_id;
ObjectID object_id;
uint64_t object_size;
uint8_t *data;
};
/// The structure used in the send queue.
struct SendRequest {
ClientID client_id;
ObjectID object_id;
RemoteConnectionInfo connection_info;
bool operator==(const SendRequest &rhs) const {
return client_id == rhs.client_id && object_id == rhs.object_id;
}
};
/// The structure used in the receive queue.
struct ReceiveRequest {
ClientID client_id;
ObjectID object_id;
uint64_t object_size;
std::shared_ptr<TcpClientConnection> conn;
bool operator==(const ReceiveRequest &rhs) const {
return client_id == rhs.client_id && object_id == rhs.object_id;
}
};
/// Queues a send.
///
/// \param client_id The ClientID to which the object needs to be sent.
/// \param object_id The ObjectID of the object to be sent.
void QueueSend(const ClientID &client_id, const ObjectID &object_id,
const RemoteConnectionInfo &info);
/// If send_queue_ is not empty, removes a SendRequest from send_queue_ and assigns
/// it to send_ptr. The queue is FIFO.
/// \param send_ptr A pointer to an empty SendRequest.
/// \return A bool indicating whether the queue was empty at the time this method
/// was invoked.
bool DequeueSendIfPresent(TransferQueue::SendRequest *send_ptr);
/// Queues a receive.
///
/// \param client_id The ClientID from which the object is being received.
/// \param object_id The ObjectID of the object to be received.
void QueueReceive(const ClientID &client_id, const ObjectID &object_id,
uint64_t object_size, std::shared_ptr<TcpClientConnection> conn);
/// If receive_queue_ is not empty, removes a ReceiveRequest from receive_queue_ and
/// assigns
/// it to receive_ptr. The queue is FIFO.
/// \param receive_ptr A pointer to an empty ReceiveRequest.
/// \return A bool indicating whether the queue was empty at the time this method
/// was invoked.
bool DequeueReceiveIfPresent(TransferQueue::ReceiveRequest *receive_ptr);
/// Maintain ownership over SendContext for sends in transit.
///
/// \param context The context to maintain.
/// \return A unique identifier identifying the context that was added.
UniqueID AddContext(SendContext &context);
/// Gets the SendContext associated with the given id.
///
/// \param id The unique identifier of the context.
/// \return The context.
SendContext &GetContext(const UniqueID &id);
/// Removes the context associated with the given id.
///
/// \param id The unique identifier of the context.
/// \return The status of invoking this method.
ray::Status RemoveContext(const UniqueID &id);
/// This object cannot be copied for thread-safety.
TransferQueue &operator=(const TransferQueue &o) {
throw std::runtime_error("Can't copy TransferQueue.");
}
private:
// TODO(hme): make this a shared mutex.
typedef std::mutex Lock;
typedef std::unique_lock<Lock> WriteLock;
// TODO(hme): make this a shared lock.
typedef std::unique_lock<Lock> ReadLock;
Lock send_mutex;
Lock receive_mutex;
Lock context_mutex;
std::deque<SendRequest> send_queue_;
std::deque<ReceiveRequest> receive_queue_;
std::unordered_map<ray::UniqueID, SendContext, ray::UniqueIDHasher> send_context_set_;
};
} // namespace ray
#endif // RAY_OBJECT_MANAGER_TRANSFER_QUEUE_H
+18
View File
@@ -0,0 +1,18 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from worker import Worker
parser = argparse.ArgumentParser()
parser.add_argument("raylet_socket_name")
parser.add_argument("object_store_socket_name")
if __name__ == '__main__':
args = parser.parse_args()
worker = Worker(args.raylet_socket_name, args.object_store_socket_name,
is_worker=True)
worker.main_loop()
+33
View File
@@ -0,0 +1,33 @@
import argparse
import ray
from worker import Worker, logger
from ray.utils import random_string
parser = argparse.ArgumentParser()
parser.add_argument("raylet_socket_name")
parser.add_argument("object_store_socket_name")
if __name__ == '__main__':
args = parser.parse_args()
driver = Worker(args.raylet_socket_name, args.object_store_socket_name,
is_worker=False)
task1 = ray.local_scheduler.Task(
ray.local_scheduler.ObjectID(random_string()),
ray.local_scheduler.ObjectID(random_string()),
[],
1,
ray.local_scheduler.ObjectID(random_string()),
0)
logger.debug("submitting", task1.task_id())
driver.node_manager_client.submit(task1)
logger.debug("Return values were", task1.returns())
print("[DRIVER] Return values were", task1.returns())
# Make sure the tasks get executed and we can get the result of the
# last task
obj = driver.get(task1.returns(), timeout_ms=1000)
print("[DRIVER]: task1 driver.get result ", obj)
+40
View File
@@ -0,0 +1,40 @@
import argparse
import ray
from worker import Worker, logger
from ray.utils import random_string
parser = argparse.ArgumentParser()
parser.add_argument("raylet_socket_name")
parser.add_argument("object_store_socket_name")
if __name__ == '__main__':
args = parser.parse_args()
driver = Worker(args.raylet_socket_name, args.object_store_socket_name,
is_worker=False)
task = ray.local_scheduler.Task(
ray.local_scheduler.ObjectID(random_string()),
ray.local_scheduler.ObjectID(random_string()),
[],
1,
ray.local_scheduler.ObjectID(random_string()),
0)
logger.debug("submitting %s", task.task_id())
driver.node_manager_client.submit(task)
logger.debug("Return values were %s", task.returns())
task2 = ray.local_scheduler.Task(
ray.local_scheduler.ObjectID(random_string()),
ray.local_scheduler.ObjectID(random_string()),
task.returns(),
1,
ray.local_scheduler.ObjectID(random_string()),
0)
logger.debug("Submitting dependent task 2 %s", task2.task_id())
driver.node_manager_client.submit(task2)
# Make sure the tasks get executed and we can get the result of the last
# task.
obj = driver.get(task2.returns(), timeout_ms=1000)
+115
View File
@@ -0,0 +1,115 @@
import argparse
import ray
from worker import Worker, logger
from ray.utils import random_string
parser = argparse.ArgumentParser()
parser.add_argument("raylet_socket_name")
parser.add_argument("object_store_socket_name")
def submit_task_withdep(driver_handle, task_object_dependencies=[]):
''' submit a task that depend on a list of @args'''
task = ray.local_scheduler.Task(
ray.local_scheduler.ObjectID(random_string()),
ray.local_scheduler.ObjectID(random_string()),
task_object_dependencies,
1, # num_returns
ray.local_scheduler.ObjectID(random_string()),
0)
logger.debug("[DRIVER]: submitting task ", task.task_id())
driver_handle.node_manager_client.submit(task)
logger.debug("[DRIVER]: task return values", task.returns())
return task.returns()
def submit_tasks_nodep(driver_handle, num_tasks):
''' submit a task that depend on a list of @args'''
for i in range(num_tasks):
task = ray.local_scheduler.Task(
ray.local_scheduler.ObjectID(random_string()),
ray.local_scheduler.ObjectID(random_string()),
[],
1, # num_returns
ray.local_scheduler.ObjectID(random_string()),
0)
logger.debug("[DRIVER]: submitting task ", task.task_id())
driver_handle.node_manager_client.submit(task)
logger.debug("[DRIVER]: task return values", task.returns())
def submit_task_chains(num_chains, tasks_per_chain):
# return task placement map on output
chain_returns = []
task_placement_map_ = {}
for chain_num in range(num_chains):
last_task_returns = []
task_placement_map_[chain_num] = []
for i in range(tasks_per_chain):
task_returns = submit_task_withdep(
driver,
task_object_dependencies=last_task_returns)
last_task_returns = task_returns
task_placement_map_[chain_num].append(task_returns[0])
chain_returns.append(last_task_returns)
logger.debug("chain_returns=", chain_returns)
chain_results = driver.get([r[0] for r in chain_returns], timeout_ms=5000)
print("[DRIVER]: chain return values: ", chain_results)
return task_placement_map_
def TEST_run_task_chains(num_chains, tasks_per_chain):
task_placement_map = submit_task_chains(num_chains=num_chains,
tasks_per_chain=tasks_per_chain)
logger.debug("[DRIVER]: task placement information, per chain:")
task_placement_total = []
for chain_num in range(len(task_placement_map)):
task_placement_list = driver.get(task_placement_map[chain_num],
timeout_ms=5000)
task_placement_total += [t[1] for t in task_placement_list]
logger.debug(chain_num, task_placement_list)
logger.debug("task placement overall: ", task_placement_total)
task_placement_stats = [(v, task_placement_total.count(v))
for v in set(task_placement_total)]
num_total_tasks = sum([t[1] for t in task_placement_stats])
print("total tasks executed = ", num_total_tasks)
assert(num_total_tasks == num_chains * tasks_per_chain)
print("task placement breakdown: total=", task_placement_stats)
def TEST_run_tasks_nodep(num_tasks):
# This test is the same as having num_tasks chains with 1 task per chain
# In this test we assume the num_tasks x 1 chain structure.
task_placement_map = submit_task_chains(num_chains=num_tasks,
tasks_per_chain=1)
logger.debug("[DRIVER]: task placement information, per chain:")
task_placement_total = []
for chain_num in range(len(task_placement_map)):
task_placement_list = driver.get(task_placement_map[chain_num],
timeout_ms=5000)
task_placement_total += [t[1] for t in task_placement_list]
logger.debug(chain_num, task_placement_list)
logger.debug("task placement overall: ", task_placement_total)
task_placement_stats = [(v, task_placement_total.count(v)) for v in
set(task_placement_total)]
num_total_tasks = sum([t[1] for t in task_placement_stats])
print("total tasks executed = ", num_total_tasks)
assert(num_total_tasks == num_tasks)
print("task placement breakdown: total=", task_placement_stats)
if __name__ == '__main__':
args = parser.parse_args()
driver = Worker(args.raylet_socket_name, args.object_store_socket_name,
is_worker=False)
# Set up the experiment : number of chains and tasks per chain.
# TEST_run_task_chains(num_chains=10, tasks_per_chain=100)
TEST_run_tasks_nodep(10000)
+67
View File
@@ -0,0 +1,67 @@
import logging
import ray
import pyarrow
import pyarrow.plasma as plasma
from ray.utils import random_string
logging.basicConfig()
logger = logging.getLogger(__name__)
# The default return value to put in the object store.
RETURN_VALUE = 0
class Worker(object):
total_task_count = 0
def __init__(self, raylet_socket_name, object_store_socket_name,
is_worker):
# Connect to the Raylet and object store.
self.node_manager_client = ray.local_scheduler.LocalSchedulerClient(
raylet_socket_name, random_string(), is_worker)
self.plasma_client = plasma.connect(object_store_socket_name, "", 0)
self.serialization_context = pyarrow.default_serialization_context()
self.raylet_socket_name = raylet_socket_name
self.object_store_socket_name = object_store_socket_name
def main_loop(self):
while True:
self.get_task()
def get(self, object_ids, timeout_ms=-1):
for object_id in object_ids:
self.node_manager_client.reconstruct_object(object_id.id())
plasma_ids = [plasma.ObjectID(argument.id()) for argument in
object_ids]
values = self.plasma_client.get(plasma_ids, timeout_ms,
self.serialization_context)
assert(all(value[0] == RETURN_VALUE for value in values))
return values
def get_task(self):
logger.debug("[WORKER] waiting for task")
task = self.node_manager_client.get_task()
logger.debug("Worker assigned %s with arguments %s",
ray.utils.binary_to_hex(task.task_id().id()),
" ".join([ray.utils.binary_to_hex(argument.id()) for
argument in task.arguments()]))
# Get the arguments. NOTE(swang): This will hang forever if the
# arguments have been evicted.
arguments = self.get(task.arguments())
for object_id in task.returns():
self.plasma_client.put((RETURN_VALUE, self.raylet_socket_name),
plasma.ObjectID(object_id.id()))
objval = self.plasma_client.get([plasma.ObjectID(object_id.id())])
assert(all([o[0] == RETURN_VALUE for o in objval]))
logger.debug("Worker returned %s",
" ".join([ray.utils.binary_to_hex(return_id.id()) for
return_id in task.returns()]))
# Release the arguments.
del arguments
+4 -3
View File
@@ -19,17 +19,18 @@ add_custom_command(
add_custom_target(gen_node_manager_fbs DEPENDS ${NODE_MANAGER_FBS_OUTPUT_FILES})
ADD_RAY_TEST(raylet_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main pthread ${Boost_SYSTEM_LIBRARY})
ADD_RAY_TEST(object_manager_integration_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main pthread ${Boost_SYSTEM_LIBRARY})
ADD_RAY_TEST(worker_pool_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY})
ADD_RAY_TEST(task_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY})
ADD_RAY_TEST(lineage_cache_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY})
add_library(rayletlib raylet.cc ${NODE_MANAGER_FBS_OUTPUT_FILES})
target_link_libraries(rayletlib ray_static ${Boost_SYSTEM_LIBRARY})
add_executable(raylet main.cc)
target_link_libraries(raylet rayletlib ${Boost_SYSTEM_LIBRARY} pthread)
add_executable(raylet_demo remote_dependencies_demo.cc)
target_link_libraries(raylet_demo rayletlib ${Boost_SYSTEM_LIBRARY} pthread)
install(FILES
raylet
+4
View File
@@ -2,10 +2,14 @@
namespace ray {
namespace raylet {
ActorInformation::ActorInformation() : id_(UniqueID::nil()) {}
ActorInformation::~ActorInformation() {}
const ActorID &ActorInformation::GetActorId() const { return this->id_; }
} // namespace raylet
} // namespace ray
+5
View File
@@ -4,6 +4,9 @@
#include "ray/id.h"
namespace ray {
namespace raylet {
class ActorInformation {
public:
/// \brief ActorInformation constructor.
@@ -21,6 +24,8 @@ class ActorInformation {
ActorID id_;
}; // class ActorInformation
} // namespace raylet
} // namespace ray
#endif // RAY_RAYLET_ACTOR_H
+13 -54
View File
@@ -51,7 +51,9 @@ enum MessageType:int {
// counts and execution dependencies, discard any tasks that already executed
// before the checkpoint, and make any tasks on the frontier runnable by
// making their execution dependencies available.
SetActorFrontier
SetActorFrontier,
// A node manager request to process a task forwarded from another node manager.
ForwardTaskRequest
}
table TaskExecutionSpecification {
@@ -82,12 +84,6 @@ table GetTaskReply {
gpu_ids: [int];
}
table EventLogMessage {
key: string;
value: string;
timestamp: double;
}
// This struct is used to register a new worker with the local scheduler.
// It is shipped as part of local_scheduler_connect.
table RegisterClientRequest {
@@ -108,57 +104,20 @@ table RegisterClientReply {
gpu_ids: [int];
}
table DisconnectClient {
}
table ReconstructObject {
// Object ID of the object that needs to be reconstructed.
object_id: string;
}
table PutObject {
// Task ID of the task that performed the put.
task_id: string;
// Object ID of the object that is being put.
object_id: string;
}
// The ActorFrontier is used to represent the current frontier of tasks that
// the local scheduler has marked as runnable for a particular actor. It is
// used to save the point in an actor's lifetime at which a checkpoint was
// taken, so that the same frontier of tasks can be made runnable again if the
// actor is resumed from that checkpoint.
table ActorFrontier {
// Actor ID of the actor whose frontier is described.
actor_id: string;
// A list of handle IDs, representing the callers of the actor that have
// submitted a runnable task to the local scheduler. A nil ID represents the
// creator of the actor.
handle_ids: [string];
// A list representing the number of tasks executed so far, per handle. Each
// count in task_counters corresponds to the handle at the same in index in
// handle_ids.
task_counters: [long];
// A list representing the execution dependency for the next runnable task,
// per handle. Each execution dependency in frontier_dependencies corresponds
// to the handle at the same in index in handle_ids.
frontier_dependencies: [string];
}
table GetActorFrontierRequest {
actor_id: string;
}
table RegisterNodeManagerRequest {
// GCS ClientID of the connecting node manager.
client_id: string;
}
table ForwardTaskRequest {
// The task to be forwarded.
// TODO(swang): Replace with a Task flatbuffer type.
task: string;
// The uncommitted lineage of the forwarded task, according to the sending
// node manager.
uncommitted_lineage: [string];
// The ID of the task to be forwarded.
task_id: string;
// The tasks in the uncommitted lineage of the forwarded task. This
// should include task_id.
uncommitted_tasks: [Task];
}
table ReconstructObject {
// Object ID of the object that needs to be reconstructed.
object_id: string;
}
+281 -13
View File
@@ -2,30 +2,298 @@
namespace ray {
LineageCache::LineageCache() {}
namespace raylet {
ray::Status LineageCache::AddTask(const Task &task) {
throw std::runtime_error("method not implemented");
return ray::Status::OK();
LineageEntry::LineageEntry(const Task &task, GcsStatus status)
: status_(status), task_(task) {}
GcsStatus LineageEntry::GetStatus() const { return status_; }
bool LineageEntry::SetStatus(GcsStatus new_status) {
if (status_ < new_status) {
status_ = new_status;
return true;
} else {
return false;
}
}
ray::Status LineageCache::AddTask(const Task &task, const Lineage &uncommitted_lineage) {
throw std::runtime_error("method not implemented");
return ray::Status::OK();
void LineageEntry::ResetStatus(GcsStatus new_status) {
RAY_CHECK(new_status < status_);
status_ = new_status;
}
ray::Status LineageCache::AddObjectLocation(const ObjectID &object_id) {
throw std::runtime_error("method not implemented");
return ray::Status::OK();
const TaskID LineageEntry::GetEntryId() const {
return task_.GetTaskSpecification().TaskId();
}
Lineage &LineageCache::GetUncommittedLineage(const ObjectID &object_id) {
throw std::runtime_error("method not implemented");
const std::unordered_set<UniqueID, UniqueIDHasher> LineageEntry::GetParentTaskIds()
const {
std::unordered_set<UniqueID, UniqueIDHasher> parent_ids;
// A task's parents are the tasks that created its arguments.
auto dependencies = task_.GetDependencies();
for (auto &dependency : dependencies) {
parent_ids.insert(ComputeTaskId(dependency));
}
return parent_ids;
}
const Task &LineageEntry::TaskData() const { return task_; }
Task &LineageEntry::TaskDataMutable() { return task_; }
Lineage::Lineage() {}
Lineage::Lineage(const protocol::ForwardTaskRequest &task_request) {
// Deserialize and set entries for the uncommitted tasks.
auto tasks = task_request.uncommitted_tasks();
for (auto it = tasks->begin(); it != tasks->end(); it++) {
auto task = Task(**it);
LineageEntry entry(task, GcsStatus_UNCOMMITTED_REMOTE);
RAY_CHECK(SetEntry(std::move(entry)));
}
}
boost::optional<const LineageEntry &> Lineage::GetEntry(const UniqueID &task_id) const {
auto entry = entries_.find(task_id);
if (entry != entries_.end()) {
return entry->second;
} else {
return boost::optional<const LineageEntry &>();
}
}
boost::optional<LineageEntry &> Lineage::GetEntryMutable(const UniqueID &task_id) {
auto entry = entries_.find(task_id);
if (entry != entries_.end()) {
return entry->second;
} else {
return boost::optional<LineageEntry &>();
}
}
bool Lineage::SetEntry(LineageEntry &&new_entry) {
// Get the status of the current entry at the key.
auto task_id = new_entry.GetEntryId();
GcsStatus current_status = GcsStatus_NONE;
auto current_entry = PopEntry(task_id);
if (current_entry) {
current_status = current_entry->GetStatus();
}
if (current_status < new_entry.GetStatus()) {
// If the new status is greater, then overwrite the current entry.
entries_.emplace(std::make_pair(task_id, std::move(new_entry)));
return true;
} else {
// If the new status is not greater, then the new entry is invalid. Replace
// the current entry at the key.
entries_.emplace(std::make_pair(task_id, std::move(*current_entry)));
return false;
}
}
boost::optional<LineageEntry> Lineage::PopEntry(const UniqueID &task_id) {
auto entry = entries_.find(task_id);
if (entry != entries_.end()) {
LineageEntry entry = std::move(entries_.at(task_id));
entries_.erase(task_id);
return entry;
} else {
return boost::optional<LineageEntry>();
}
}
const std::unordered_map<const UniqueID, LineageEntry, UniqueIDHasher>
&Lineage::GetEntries() const {
return entries_;
}
flatbuffers::Offset<protocol::ForwardTaskRequest> Lineage::ToFlatbuffer(
flatbuffers::FlatBufferBuilder &fbb, const TaskID &task_id) const {
RAY_CHECK(GetEntry(task_id));
// Serialize the task and object entries.
std::vector<flatbuffers::Offset<protocol::Task>> uncommitted_tasks;
for (const auto &entry : entries_) {
uncommitted_tasks.push_back(entry.second.TaskData().ToFlatbuffer(fbb));
}
auto request = protocol::CreateForwardTaskRequest(fbb, to_flatbuf(fbb, task_id),
fbb.CreateVector(uncommitted_tasks));
return request;
}
LineageCache::LineageCache(gcs::TableInterface<TaskID, protocol::Task> &task_storage)
: task_storage_(task_storage) {}
/// A helper function to merge one lineage into another, in DFS order.
///
/// \param task_id The current entry to merge from lineage_from into
/// lineage_to.
/// \param lineage_from The lineage to merge entries from. This lineage is
/// traversed by following each entry's parent pointers in DFS order,
/// until an entry is not found or the stopping condition is reached.
/// \param lineage_to The lineage to merge entries into.
/// \param stopping_condition A stopping condition for the DFS over
/// lineage_from. This should return true if the merge should stop.
void MergeLineageHelper(const UniqueID &task_id, const Lineage &lineage_from,
Lineage &lineage_to,
std::function<bool(GcsStatus)> stopping_condition) {
// If the entry is not found in the lineage to merge, then we stop since
// there is nothing to copy into the merged lineage.
auto entry = lineage_from.GetEntry(task_id);
if (!entry) {
return;
}
// Check whether we should stop at this entry in the DFS.
auto status = entry->GetStatus();
if (stopping_condition(status)) {
return;
}
// Insert a copy of the entry into lineage_to.
LineageEntry entry_copy = *entry;
auto parent_ids = entry_copy.GetParentTaskIds();
// If the insert is successful, then continue the DFS. The insert will fail
// if the new entry has an equal or lower GCS status than the current entry
// in lineage_to. This also prevents us from traversing the same node twice.
if (lineage_to.SetEntry(std::move(entry_copy))) {
for (const auto &parent_id : parent_ids) {
MergeLineageHelper(parent_id, lineage_from, lineage_to, stopping_condition);
}
}
}
void LineageCache::AddWaitingTask(const Task &task, const Lineage &uncommitted_lineage) {
auto task_id = task.GetTaskSpecification().TaskId();
// Merge the uncommitted lineage into the lineage cache.
MergeLineageHelper(task_id, uncommitted_lineage, lineage_, [](GcsStatus status) {
if (status != GcsStatus_NONE) {
// We received the uncommitted lineage from a remote node, so make sure
// that all entries in the lineage to merge have status
// UNCOMMITTED_REMOTE.
RAY_CHECK(status == GcsStatus_UNCOMMITTED_REMOTE);
}
// The only stopping condition is that an entry is not found.
return false;
});
// Add the submitted task to the lineage cache as UNCOMMITTED_WAITING. It
// should be marked as UNCOMMITTED_READY once the task starts execution.
LineageEntry task_entry(task, GcsStatus_UNCOMMITTED_WAITING);
RAY_CHECK(lineage_.SetEntry(std::move(task_entry)));
}
void LineageCache::AddReadyTask(const Task &task) {
auto new_entry = LineageEntry(task, GcsStatus_UNCOMMITTED_READY);
RAY_CHECK(lineage_.SetEntry(std::move(new_entry)));
}
void LineageCache::RemoveWaitingTask(const TaskID &task_id) {
auto entry = lineage_.PopEntry(task_id);
// It's only okay to remove a task that is waiting for execution.
// TODO(swang): Is this necessarily true when there is reconstruction?
RAY_CHECK(entry->GetStatus() == GcsStatus_UNCOMMITTED_WAITING);
// Reset the status to REMOTE. We keep the task instead of removing it
// completely in case another task is submitted locally that depends on this
// one.
entry->ResetStatus(GcsStatus_UNCOMMITTED_REMOTE);
RAY_CHECK(lineage_.SetEntry(std::move(*entry)));
}
Lineage LineageCache::GetUncommittedLineage(const TaskID &task_id) const {
Lineage uncommitted_lineage;
// Add all uncommitted ancestors from the lineage cache to the uncommitted
// lineage of the requested task.
MergeLineageHelper(task_id, lineage_, uncommitted_lineage, [](GcsStatus status) {
// The stopping condition for recursion is that the entry has been
// committed to the GCS.
return status == GcsStatus_COMMITTED;
});
return uncommitted_lineage;
}
Status LineageCache::Flush() {
throw std::runtime_error("method not implemented");
// Find all tasks that are READY and whose arguments have been committed in the GCS.
std::vector<TaskID> ready_task_ids;
for (const auto &pair : lineage_.GetEntries()) {
auto task_id = pair.first;
auto entry = pair.second;
// Skip task entries that are not ready to be written yet. These tasks
// either have not started execution yet, are being executed on a remote
// node, or have already been written to the GCS.
if (entry.GetStatus() != GcsStatus_UNCOMMITTED_READY) {
continue;
}
// Check if all arguments have been committed to the GCS before writing
// this task.
bool all_arguments_committed = true;
for (const auto &parent_id : entry.GetParentTaskIds()) {
auto parent = lineage_.GetEntry(parent_id);
// If a parent entry exists in the lineage cache but has not been
// committed yet, then as far as we know, it's still in flight to the
// GCS. Skip this task for now.
if (parent && parent->GetStatus() != GcsStatus_COMMITTED) {
// TODO(swang): Once GCS notifications for the task table are ready,
// request notification for commit of the parent task here.
all_arguments_committed = false;
break;
}
}
if (all_arguments_committed) {
// All arguments have been committed to the GCS. Add this task to the
// list of tasks to write back to the GCS.
ready_task_ids.push_back(task_id);
}
}
// Write back all ready tasks whose arguments have been committed to the GCS.
gcs::raylet::TaskTable::WriteCallback task_callback =
[this](ray::gcs::AsyncGcsClient *client, const TaskID &id,
const std::shared_ptr<protocol::TaskT> data) { HandleEntryCommitted(id); };
for (const auto &ready_task_id : ready_task_ids) {
auto task = lineage_.GetEntry(ready_task_id);
// TODO(swang): Make this better...
flatbuffers::FlatBufferBuilder fbb;
auto message = task->TaskData().ToFlatbuffer(fbb);
fbb.Finish(message);
auto task_data = std::make_shared<protocol::TaskT>();
auto root = flatbuffers::GetRoot<protocol::Task>(fbb.GetBufferPointer());
root->UnPackTo(task_data.get());
RAY_CHECK_OK(task_storage_.Add(task->TaskData().GetTaskSpecification().DriverId(),
ready_task_id, task_data, task_callback));
// We successfully wrote the task, so mark it as committing.
// TODO(swang): Use a batched interface and write with all object entries.
auto entry = lineage_.PopEntry(ready_task_id);
RAY_CHECK(entry->SetStatus(GcsStatus_COMMITTING));
RAY_CHECK(lineage_.SetEntry(std::move(*entry)));
}
return ray::Status::OK();
}
void PopAncestorTasks(const UniqueID &task_id, Lineage &lineage) {
auto entry = lineage.PopEntry(task_id);
if (!entry) {
return;
}
auto status = entry->GetStatus();
RAY_CHECK(status == GcsStatus_UNCOMMITTED_REMOTE || status == GcsStatus_COMMITTED);
for (const auto &parent_id : entry->GetParentTaskIds()) {
PopAncestorTasks(parent_id, lineage);
}
}
void LineageCache::HandleEntryCommitted(const UniqueID &task_id) {
auto entry = lineage_.PopEntry(task_id);
for (const auto &parent_id : entry->GetParentTaskIds()) {
PopAncestorTasks(parent_id, lineage_);
}
RAY_CHECK(entry->SetStatus(GcsStatus_COMMITTED));
RAY_CHECK(lineage_.SetEntry(std::move(*entry)));
}
} // namespace raylet
} // namespace ray
+178 -45
View File
@@ -1,83 +1,216 @@
#ifndef RAY_RAYLET_LINEAGE_CACHE_H
#define RAY_RAYLET_LINEAGE_CACHE_H
#include <boost/optional.hpp>
// clang-format off
#include "common_protocol.h"
#include "ray/raylet/task.h"
#include "ray/gcs/tables.h"
#include "ray/id.h"
#include "ray/status.h"
// clang-format on
namespace ray {
// TODO(swang): Define this class.
class Lineage {};
namespace raylet {
class LineageCacheEntry {
private:
// TODO(swang): This should be an enum of the state of the entry - goes from
// completely local, to dirty, to in flight, to committed.
// bool dirty_;
/// The status of a lineage cache entry according to its status in the GCS.
enum GcsStatus {
/// The task is not in the lineage cache.
GcsStatus_NONE = 0,
/// The task is being executed or created on a remote node.
GcsStatus_UNCOMMITTED_REMOTE,
/// The task is waiting to be executed or created locally.
GcsStatus_UNCOMMITTED_WAITING,
/// The task has started execution, but the entry has not been written to the
/// GCS yet.
GcsStatus_UNCOMMITTED_READY,
/// The task has been written to the GCS and we are waiting for an
/// acknowledgement of the commit.
GcsStatus_COMMITTING,
/// The task has been committed in the GCS. It's safe to remove this entry
/// from the lineage cache.
GcsStatus_COMMITTED,
};
class LineageCacheTaskEntry : public LineageCacheEntry {};
class LineageCacheObjectEntry : public LineageCacheEntry {};
/// \class LineageEntry
///
/// A task entry in the data lineage. Each entry's parents are the tasks that
/// created the entry's arguments.
class LineageEntry {
public:
/// Create an entry for a task.
///
/// \param task The task data to eventually be written back to the GCS.
/// \param status The status of this entry, according to its write status in
/// the GCS.
LineageEntry(const Task &task, GcsStatus status);
/// Get this entry's GCS status.
///
/// \return The entry's status in the GCS.
GcsStatus GetStatus() const;
/// Set this entry's GCS status. The status is only set if the new status
/// is strictly greater than the entry's previous status, according to the
/// GcsStatus enum.
///
/// \param new_status Set the entry's status to this value if it is greater
/// than the current status.
/// \return Whether the entry was set to the new status.
bool SetStatus(GcsStatus new_status);
/// Reset this entry's GCS status to a lower status. The new status must
/// be lower than the current status.
///
/// \param new_status This must be lower than the current status.
void ResetStatus(GcsStatus new_status);
/// Get this entry's ID.
///
/// \return The entry's ID.
const TaskID GetEntryId() const;
/// Get the IDs of this entry's parent tasks. These are the IDs of the tasks
/// that created its arguments.
///
/// \return The IDs of the parent entries.
const std::unordered_set<TaskID, UniqueIDHasher> GetParentTaskIds() const;
/// Get the task data.
///
/// \return The task data.
const Task &TaskData() const;
Task &TaskDataMutable();
private:
/// The current state of this entry according to its status in the GCS.
GcsStatus status_;
/// The task data to be written to the GCS. This is nullptr if the entry is
/// an object.
// const Task task_;
Task task_;
};
/// \class Lineage
///
/// A lineage DAG, according to the data dependency graph. Each node is a task,
/// with an outgoing edge to each of its parent tasks. For a given task, the
/// parents are the tasks that created its arguments. Each entry also records
/// the current status in the GCS for that task or object.
class Lineage {
public:
/// Construct an empty Lineage.
Lineage();
/// Construct a Lineage from a ForwardTaskRequest.
///
/// \param task_request The request to construct the lineage from. All
/// uncommitted tasks in the request will be added to the lineage.
Lineage(const protocol::ForwardTaskRequest &task_request);
/// Get an entry from the lineage.
///
/// \param entry_id The ID of the entry to get.
/// \return An optional reference to the entry. If this is empty, then the
/// entry ID is not in the lineage.
boost::optional<const LineageEntry &> GetEntry(const TaskID &entry_id) const;
boost::optional<LineageEntry &> GetEntryMutable(const UniqueID &task_id);
/// Set an entry in the lineage. If an entry with this ID already exists,
/// then the entry is overwritten if and only if the new entry has a higher
/// GCS status than the current. The current entry's object or task data will
/// also be overwritten.
///
/// \param entry The new entry to set in the lineage, if its GCS status is
/// greater than the current entry.
/// \return Whether the entry was set.
bool SetEntry(LineageEntry &&entry);
/// Delete and return an entry from the lineage.
///
/// \param entry_id The ID of the entry to pop.
/// \return An optional reference to the popped entry. If this is empty, then
/// the entry ID is not in the lineage.
boost::optional<LineageEntry> PopEntry(const TaskID &entry_id);
/// Get all entries in the lineage.
///
/// \return A const reference to the lineage entries.
const std::unordered_map<const TaskID, LineageEntry, UniqueIDHasher> &GetEntries()
const;
/// Serialize this lineage to a ForwardTaskRequest flatbuffer.
///
/// \param entry_id The task ID to include in the ForwardTaskRequest
/// flatbuffer.
/// \return An offset to the serialized lineage. The serialization includes
/// all task and object entries in the lineage.
flatbuffers::Offset<protocol::ForwardTaskRequest> ToFlatbuffer(
flatbuffers::FlatBufferBuilder &fbb, const TaskID &entry_id) const;
private:
/// The lineage entries.
std::unordered_map<const TaskID, LineageEntry, UniqueIDHasher> entries_;
};
/// \class LineageCache
///
/// A cache of the object and task tables. This consists of all tasks that this
/// node owns, as well as their task lineage, that have not yet been added
/// durably to the GCS.
/// A cache of the task table. This consists of all tasks that this node owns,
/// as well as their lineage, that have not yet been added durably to the GCS.
class LineageCache {
public:
/// Create a lineage cache policy.
/// TODO(swang): Pass in the policy (interface?) and a GCS client.
LineageCache();
/// Create a lineage cache for the given task storage system.
/// TODO(swang): Pass in the policy (interface?).
LineageCache(gcs::TableInterface<TaskID, protocol::Task> &task_storage);
/// Add a task and its object outputs asynchronously to the GCS. This
/// overwrites the task's mutable fields in the execution specification.
/// Add a task that is waiting for execution and its uncommitted lineage.
/// These entries will not be written to the GCS until set to ready.
///
/// \param task The task to add.
/// \return Status.
ray::Status AddTask(const Task &task);
/// Add a task and its uncommitted lineage asynchronously to the GCS. The
/// mutable fields for the given task will be overwritten, but not for the
/// tasks in the uncommitted lineage.
///
/// \param task The task to add.
/// \param task The waiting task to add.
/// \param uncommitted_lineage The task's uncommitted lineage. These are the
/// tasks that the given task is data-dependent on, but that have not
/// been made durable in the GCS, as far as we know.
/// \return Status.
ray::Status AddTask(const Task &task, const Lineage &uncommitted_lineage);
/// tasks that the given task is data-dependent on, but that have not
/// been made durable in the GCS, as far the task's submitter knows.
void AddWaitingTask(const Task &task, const Lineage &uncommitted_lineage);
/// Add this node as an object location, to be asynchronously committed to
/// the GCS.
/// Add a task that is ready for GCS writeback. This overwrites the tasks
/// mutable fields in the execution specification.
///
/// \param object_id The object to add a location for.
/// \return Status.
ray::Status AddObjectLocation(const ObjectID &object_id);
/// \param task The task to set as ready.
void AddReadyTask(const Task &task);
/// Get the uncommitted lineage of an object. These are the tasks that the
/// given object is data-dependent on, but that have not been made durable in
void RemoveWaitingTask(const TaskID &entry_id);
/// Get the uncommitted lineage of a task. The uncommitted lineage consists
/// of all tasks in the given task's lineage that have not been committed in
/// the GCS, as far as we know.
///
/// \param object_id The object to get the uncommitted lineage for.
/// \return The uncommitted lineage of the object.
Lineage &GetUncommittedLineage(const ObjectID &object_id);
/// \param entry_id The ID of the task to get the uncommitted lineage for.
/// \return The uncommitted lineage of the task. The returned lineage
/// includes the entry for the requested entry_id.
Lineage GetUncommittedLineage(const TaskID &entry_id) const;
/// Asynchronously write any tasks and object locations that have been added
/// since the last flush to the GCS. When each write is acknowledged, its
/// entry will be marked as committed.
/// Asynchronously write any tasks that have been added since the last flush
/// to the GCS. When each write is acknowledged, its entry will be marked as
/// committed.
///
/// \return Status.
Status Flush();
private:
std::unordered_map<TaskID, LineageCacheTaskEntry, UniqueIDHasher> task_table_;
std::unordered_map<ObjectID, LineageCacheObjectEntry, UniqueIDHasher> object_table_;
void HandleEntryCommitted(const TaskID &unique_id);
/// The durable storage system for task information.
gcs::TableInterface<TaskID, protocol::Task> &task_storage_;
/// All tasks and objects that we are responsible for writing back to the
/// GCS, and the tasks and objects in their lineage.
Lineage lineage_;
};
} // namespace raylet
} // namespace ray
#endif // RAY_RAYLET_LINEAGE_CACHE_H
+270
View File
@@ -0,0 +1,270 @@
#include <list>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ray/raylet/format/node_manager_generated.h"
#include "ray/raylet/lineage_cache.h"
#include "ray/raylet/task.h"
#include "ray/raylet/task_execution_spec.h"
#include "ray/raylet/task_spec.h"
namespace ray {
namespace raylet {
class MockGcs : virtual public gcs::TableInterface<TaskID, protocol::Task> {
public:
MockGcs(){};
Status Add(const JobID &job_id, const TaskID &task_id,
std::shared_ptr<protocol::TaskT> task_data,
const gcs::TableInterface<TaskID, protocol::Task>::WriteCallback &done) {
task_table_[task_id] = task_data;
callbacks_.push_back(
std::pair<gcs::raylet::TaskTable::WriteCallback, TaskID>(done, task_id));
return ray::Status::OK();
};
void Flush() {
for (const auto &callback : callbacks_) {
callback.first(NULL, callback.second, task_table_[callback.second]);
}
callbacks_.clear();
};
const std::unordered_map<TaskID, std::shared_ptr<protocol::TaskT>, UniqueIDHasher>
&TaskTable() const {
return task_table_;
}
private:
std::unordered_map<TaskID, std::shared_ptr<protocol::TaskT>, UniqueIDHasher>
task_table_;
std::vector<std::pair<gcs::raylet::TaskTable::WriteCallback, TaskID>> callbacks_;
};
class LineageCacheTest : public ::testing::Test {
public:
LineageCacheTest() : mock_gcs_(), lineage_cache_(mock_gcs_) {}
protected:
MockGcs mock_gcs_;
LineageCache lineage_cache_;
};
static inline Task ExampleTask(const std::vector<ObjectID> &arguments,
int64_t num_returns) {
std::unordered_map<std::string, double> required_resources;
std::vector<std::shared_ptr<TaskArgument>> task_arguments;
for (auto &argument : arguments) {
std::vector<ObjectID> references = {argument};
task_arguments.emplace_back(std::make_shared<TaskArgumentByReference>(references));
}
auto spec = TaskSpecification(UniqueID::nil(), UniqueID::from_random(), 0,
UniqueID::from_random(), task_arguments, num_returns,
required_resources);
auto execution_spec = TaskExecutionSpecification(std::vector<ObjectID>());
execution_spec.IncrementNumForwards();
Task task = Task(execution_spec, spec);
return task;
}
std::vector<ObjectID> InsertTaskChain(LineageCache &lineage_cache,
std::vector<Task> &inserted_tasks, int chain_size,
const std::vector<ObjectID> &initial_arguments,
int64_t num_returns) {
Lineage empty_lineage;
std::vector<ObjectID> arguments = initial_arguments;
for (int i = 0; i < chain_size; i++) {
auto task = ExampleTask(arguments, num_returns);
lineage_cache.AddWaitingTask(task, empty_lineage);
inserted_tasks.push_back(task);
arguments.clear();
for (int j = 0; j < task.GetTaskSpecification().NumReturns(); j++) {
arguments.push_back(task.GetTaskSpecification().ReturnId(j));
}
}
return arguments;
}
TEST_F(LineageCacheTest, TestGetUncommittedLineage) {
// Insert two independent chains of tasks.
std::vector<Task> tasks1;
auto return_values1 =
InsertTaskChain(lineage_cache_, tasks1, 3, std::vector<ObjectID>(), 1);
std::vector<TaskID> task_ids1;
for (const auto &task : tasks1) {
task_ids1.push_back(task.GetTaskSpecification().TaskId());
}
std::vector<Task> tasks2;
auto return_values2 =
InsertTaskChain(lineage_cache_, tasks2, 2, std::vector<ObjectID>(), 2);
std::vector<TaskID> task_ids2;
for (const auto &task : tasks2) {
task_ids2.push_back(task.GetTaskSpecification().TaskId());
}
// Get the uncommitted lineage for the last task (the leaf) of one of the
// chains.
auto uncommitted_lineage = lineage_cache_.GetUncommittedLineage(task_ids1.back());
// Check that the uncommitted lineage is exactly equal to the first chain of
// tasks.
ASSERT_EQ(task_ids1.size(), uncommitted_lineage.GetEntries().size());
for (auto &task_id : task_ids1) {
ASSERT_TRUE(uncommitted_lineage.GetEntry(task_id));
}
// Insert one task that is dependent on the previous chains of tasks.
std::vector<Task> combined_tasks = tasks1;
combined_tasks.insert(combined_tasks.end(), tasks2.begin(), tasks2.end());
std::vector<ObjectID> combined_arguments = return_values1;
combined_arguments.insert(combined_arguments.end(), return_values2.begin(),
return_values2.end());
InsertTaskChain(lineage_cache_, combined_tasks, 1, combined_arguments, 1);
std::vector<TaskID> combined_task_ids;
for (const auto &task : combined_tasks) {
combined_task_ids.push_back(task.GetTaskSpecification().TaskId());
}
// Get the uncommitted lineage for the inserted task.
uncommitted_lineage = lineage_cache_.GetUncommittedLineage(combined_task_ids.back());
// Check that the uncommitted lineage is exactly equal to the entire set of
// tasks inserted so far.
ASSERT_EQ(combined_task_ids.size(), uncommitted_lineage.GetEntries().size());
for (auto &task_id : combined_task_ids) {
ASSERT_TRUE(uncommitted_lineage.GetEntry(task_id));
}
}
void CheckFlush(LineageCache &lineage_cache, MockGcs &mock_gcs,
size_t num_tasks_flushed) {
RAY_CHECK_OK(lineage_cache.Flush());
ASSERT_EQ(mock_gcs.TaskTable().size(), num_tasks_flushed);
}
TEST_F(LineageCacheTest, TestWritebackNoneReady) {
// Insert a chain of dependent tasks.
size_t num_tasks_flushed = 0;
std::vector<Task> tasks;
auto return_values1 =
InsertTaskChain(lineage_cache_, tasks, 3, std::vector<ObjectID>(), 1);
// Check that when no tasks have been marked as ready, we do not flush any
// entries.
CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed);
}
TEST_F(LineageCacheTest, TestWritebackReady) {
// Insert a chain of dependent tasks.
size_t num_tasks_flushed = 0;
std::vector<Task> tasks;
auto return_values1 =
InsertTaskChain(lineage_cache_, tasks, 3, std::vector<ObjectID>(), 1);
// Check that after marking the first task as ready, we flush only that task.
lineage_cache_.AddReadyTask(tasks.front());
num_tasks_flushed++;
CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed);
}
TEST_F(LineageCacheTest, TestWritebackOrder) {
// Insert a chain of dependent tasks.
size_t num_tasks_flushed = 0;
std::vector<Task> tasks;
auto return_values1 =
InsertTaskChain(lineage_cache_, tasks, 3, std::vector<ObjectID>(), 1);
// Mark all tasks as ready.
for (const auto &task : tasks) {
lineage_cache_.AddReadyTask(task);
}
// Check that we write back the tasks in order of data dependencies.
for (size_t i = 0; i < tasks.size(); i++) {
num_tasks_flushed++;
CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed);
// Flush acknowledgements. The next task should be able to be written.
mock_gcs_.Flush();
}
}
TEST_F(LineageCacheTest, TestWritebackPartiallyReady) {
// Create two independent tasks, task1 and task2, and a dependent task
// that depends on both tasks.
size_t num_tasks_flushed = 0;
auto task1 = ExampleTask({}, 1);
auto task2 = ExampleTask({}, 1);
std::vector<ObjectID> returns;
for (int64_t i = 0; i < task1.GetTaskSpecification().NumReturns(); i++) {
returns.push_back(task1.GetTaskSpecification().ReturnId(i));
}
for (int64_t i = 0; i < task2.GetTaskSpecification().NumReturns(); i++) {
returns.push_back(task2.GetTaskSpecification().ReturnId(i));
}
auto dependent_task = ExampleTask(returns, 1);
auto dependencies = dependent_task.GetDependencies();
// Insert all tasks as waiting for execution.
lineage_cache_.AddWaitingTask(task1, Lineage());
lineage_cache_.AddWaitingTask(task2, Lineage());
lineage_cache_.AddWaitingTask(dependent_task, Lineage());
// Mark one of the independent tasks and the dependent task as ready.
lineage_cache_.AddReadyTask(task1);
lineage_cache_.AddReadyTask(dependent_task);
// Check that only the first independent task is flushed.
num_tasks_flushed++;
CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed);
// Flush acknowledgements. The dependent task should still not be flushed
// since task2 is not committed yet.
mock_gcs_.Flush();
CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed);
// Mark the other independent task as ready.
lineage_cache_.AddReadyTask(task2);
// Check that the other independent task gets flushed.
num_tasks_flushed++;
CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed);
// Flush acknowledgements. The dependent task should now be able to be
// written.
mock_gcs_.Flush();
num_tasks_flushed++;
CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed);
}
TEST_F(LineageCacheTest, TestRemoveWaitingTask) {
// Insert a chain of dependent tasks.
std::vector<Task> tasks;
auto return_values1 =
InsertTaskChain(lineage_cache_, tasks, 3, std::vector<ObjectID>(), 1);
auto task_to_remove = tasks[1];
auto task_id_to_remove = task_to_remove.GetTaskSpecification().TaskId();
auto uncommitted_lineage = lineage_cache_.GetUncommittedLineage(task_id_to_remove);
flatbuffers::FlatBufferBuilder fbb;
auto uncommitted_lineage_message =
uncommitted_lineage.ToFlatbuffer(fbb, task_id_to_remove);
fbb.Finish(uncommitted_lineage_message);
uncommitted_lineage = Lineage(
*flatbuffers::GetRoot<protocol::ForwardTaskRequest>(fbb.GetBufferPointer()));
const Task &task = uncommitted_lineage.GetEntry(task_id_to_remove)->TaskData();
RAY_LOG(INFO) << "removing task " << task.GetTaskSpecification().TaskId()
<< "with numforwards="
<< task.GetTaskExecutionSpecReadonly().NumForwards();
ASSERT_EQ(task.GetTaskExecutionSpecReadonly().NumForwards(), 1);
lineage_cache_.RemoveWaitingTask(task_id_to_remove);
lineage_cache_.AddWaitingTask(task_to_remove, uncommitted_lineage);
}
} // namespace raylet
} // namespace ray
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
+32 -21
View File
@@ -5,34 +5,45 @@
#ifndef RAYLET_TEST
int main(int argc, char *argv[]) {
RAY_CHECK(argc == 2);
RAY_CHECK(argc == 5);
// start store
std::string executable_str = std::string(argv[0]);
std::string exec_dir = executable_str.substr(0, executable_str.find_last_of("/"));
std::string plasma_dir = exec_dir + "./../plasma";
std::string plasma_command =
plasma_dir +
"/plasma_store -m 1000000000 -s /tmp/store 1> /dev/null 2> /dev/null &";
RAY_LOG(INFO) << plasma_command;
int s = system(plasma_command.c_str());
RAY_CHECK(s == 0);
const std::string raylet_socket_name = std::string(argv[1]);
const std::string store_socket_name = std::string(argv[2]);
const std::string redis_address = std::string(argv[3]);
int redis_port = std::stoi(argv[4]);
// configure
// Configuration for the node manager.
ray::raylet::NodeManagerConfig node_manager_config;
std::unordered_map<std::string, double> static_resource_conf;
static_resource_conf = {{"CPU", 1}, {"GPU", 1}};
ray::ResourceSet resource_config(std::move(static_resource_conf));
ray::ObjectManagerConfig om_config;
om_config.store_socket_name = "/tmp/store";
node_manager_config.resource_config =
ray::raylet::ResourceSet(std::move(static_resource_conf));
node_manager_config.num_initial_workers = 0;
// Use a default worker that can execute empty tasks with dependencies.
node_manager_config.worker_command.push_back("python");
node_manager_config.worker_command.push_back(
"../../../src/ray/python/default_worker.py");
node_manager_config.worker_command.push_back(raylet_socket_name.c_str());
node_manager_config.worker_command.push_back(store_socket_name.c_str());
// TODO(swang): Set this from a global config.
node_manager_config.heartbeat_period_ms = 100;
// Configuration for the object manager.
ray::ObjectManagerConfig object_manager_config;
object_manager_config.store_socket_name = store_socket_name;
// initialize mock gcs & object directory
std::shared_ptr<ray::GcsClient> mock_gcs_client =
std::shared_ptr<ray::GcsClient>(new ray::GcsClient());
auto gcs_client = std::make_shared<ray::gcs::AsyncGcsClient>();
RAY_LOG(INFO) << "Initializing GCS client "
<< gcs_client->client_table().GetLocalClientId();
// Initialize the node manager.
boost::asio::io_service io_service;
ray::Raylet server(io_service, std::string(argv[1]), resource_config, om_config,
mock_gcs_client);
io_service.run();
boost::asio::io_service main_service;
std::unique_ptr<boost::asio::io_service> object_manager_service;
object_manager_service.reset(new boost::asio::io_service());
ray::raylet::Raylet server(main_service, std::move(object_manager_service),
raylet_socket_name, redis_address, redis_port,
node_manager_config, object_manager_config, gcs_client);
main_service.run();
}
#endif
+289 -48
View File
@@ -5,20 +5,153 @@
namespace ray {
NodeManager::NodeManager(const std::string &socket_name,
const ResourceSet &resource_config,
ObjectManager &object_manager)
: local_resources_(resource_config),
worker_pool_(WorkerPool(0)),
namespace raylet {
NodeManager::NodeManager(boost::asio::io_service &io_service,
const NodeManagerConfig &config, ObjectManager &object_manager,
std::shared_ptr<gcs::AsyncGcsClient> gcs_client)
: io_service_(io_service),
heartbeat_timer_(io_service),
heartbeat_period_ms_(config.heartbeat_period_ms),
local_resources_(config.resource_config),
worker_pool_(config.num_initial_workers, config.worker_command),
local_queues_(SchedulingQueue()),
scheduling_policy_(local_queues_),
reconstruction_policy_([this](const TaskID &task_id) { ResubmitTask(task_id); }),
task_dependency_manager_(
object_manager,
// reconstruction_policy_,
[this](const TaskID &task_id) { HandleWaitingTaskReady(task_id); }) {
//// TODO(atumanov): need to add the self-knowledge of ClientID, using nill().
// cluster_resource_map_[ClientID::nil()] = local_resources_;
[this](const TaskID &task_id) { HandleWaitingTaskReady(task_id); }),
lineage_cache_(gcs_client->raylet_task_table()),
gcs_client_(gcs_client),
remote_clients_(),
remote_server_connections_(),
object_manager_(object_manager) {
RAY_CHECK(heartbeat_period_ms_ > 0);
// Initialize the resource map with own cluster resource configuration.
ClientID local_client_id = gcs_client_->client_table().GetLocalClientId();
cluster_resource_map_.emplace(local_client_id,
SchedulingResources(config.resource_config));
}
void NodeManager::Heartbeat() {
RAY_LOG(DEBUG) << "[Heartbeat] sending heartbeat.";
auto &heartbeat_table = gcs_client_->heartbeat_table();
auto heartbeat_data = std::make_shared<HeartbeatTableDataT>();
auto client_id = gcs_client_->client_table().GetLocalClientId();
const SchedulingResources &local_resources = cluster_resource_map_[client_id];
heartbeat_data->client_id = client_id.hex();
// TODO(atumanov): modify the heartbeat table protocol to use the ResourceSet directly.
// TODO(atumanov): implement a ResourceSet const_iterator.
for (const auto &resource_pair :
local_resources.GetAvailableResources().GetResourceMap()) {
heartbeat_data->resources_available_label.push_back(resource_pair.first);
heartbeat_data->resources_available_capacity.push_back(resource_pair.second);
}
for (const auto &resource_pair : local_resources.GetTotalResources().GetResourceMap()) {
heartbeat_data->resources_total_label.push_back(resource_pair.first);
heartbeat_data->resources_total_capacity.push_back(resource_pair.second);
}
ray::Status status = heartbeat_table.Add(
UniqueID::nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data,
[](ray::gcs::AsyncGcsClient *client, const ClientID &id,
std::shared_ptr<HeartbeatTableDataT> data) {
RAY_LOG(DEBUG) << "[HEARTBEAT] heartbeat sent callback";
});
if (!status.ok()) {
RAY_LOG(INFO) << "heartbeat failed: string " << status.ToString() << status.message();
RAY_LOG(INFO) << "is redis error: " << status.IsRedisError();
}
RAY_CHECK_OK(status);
// Reset the timer.
auto heartbeat_period = boost::posix_time::milliseconds(heartbeat_period_ms_);
heartbeat_timer_.expires_from_now(heartbeat_period);
heartbeat_timer_.async_wait([this](const boost::system::error_code &error) {
RAY_CHECK(!error);
Heartbeat();
});
}
void NodeManager::ClientAdded(gcs::AsyncGcsClient *client, const UniqueID &id,
const ClientTableDataT &client_data) {
ClientID client_id = ClientID::from_binary(client_data.client_id);
RAY_LOG(DEBUG) << "[ClientAdded] received callback from client id " << client_id.hex();
if (client_id == gcs_client_->client_table().GetLocalClientId()) {
// We got a notification for ourselves, so we are connected to the GCS now.
// Save this NodeManager's resource information in the cluster resource map.
cluster_resource_map_[client_id] = local_resources_;
// Start sending heartbeats to the GCS.
Heartbeat();
// Subscribe to heartbeats.
const auto heartbeat_added = [this](gcs::AsyncGcsClient *client, const ClientID &id,
const HeartbeatTableDataT &heartbeat_data) {
this->HeartbeatAdded(client, id, heartbeat_data);
};
ray::Status status = client->heartbeat_table().Subscribe(
UniqueID::nil(), UniqueID::nil(), heartbeat_added,
[this](gcs::AsyncGcsClient *client) {
RAY_LOG(DEBUG) << "heartbeat table subscription done callback called.";
});
RAY_CHECK_OK(status);
return;
}
// TODO(atumanov): make remote client lookup O(1)
if (std::find(remote_clients_.begin(), remote_clients_.end(), client_id) ==
remote_clients_.end()) {
RAY_LOG(DEBUG) << "a new client: " << client_id.hex();
remote_clients_.push_back(client_id);
} else {
// NodeManager connection to this client was already established.
RAY_LOG(DEBUG) << "received a new client connection that already exists: "
<< client_id.hex();
return;
}
ResourceSet resources_total(client_data.resources_total_label,
client_data.resources_total_capacity);
this->cluster_resource_map_.emplace(client_id, SchedulingResources(resources_total));
// Establish a new NodeManager connection to this GCS client.
auto client_info = gcs_client_->client_table().GetClient(client_id);
RAY_LOG(DEBUG) << "[ClientAdded] CONNECTING TO: "
<< " " << client_info.node_manager_address << " "
<< client_info.node_manager_port;
boost::asio::ip::tcp::socket socket(io_service_);
RAY_CHECK_OK(TcpConnect(socket, client_info.node_manager_address,
client_info.node_manager_port));
auto server_conn = TcpServerConnection(std::move(socket));
remote_server_connections_.emplace(client_id, std::move(server_conn));
}
void NodeManager::HeartbeatAdded(gcs::AsyncGcsClient *client, const ClientID &client_id,
const HeartbeatTableDataT &heartbeat_data) {
RAY_LOG(DEBUG) << "[HeartbeatAdded]: received heartbeat from client id "
<< client_id.hex();
if (client_id == gcs_client_->client_table().GetLocalClientId()) {
// Skip heartbeats from self.
return;
}
// Locate the client id in remote client table and update available resources based on
// the received heartbeat information.
if (this->cluster_resource_map_.count(client_id) == 0) {
// Haven't received the client registration for this client yet, skip this heartbeat.
RAY_LOG(INFO) << "[HeartbeatAdded]: received heartbeat from unknown client id "
<< client_id.hex();
return;
}
SchedulingResources &resources = this->cluster_resource_map_[client_id];
ResourceSet heartbeat_resource_available(heartbeat_data.resources_available_label,
heartbeat_data.resources_available_capacity);
resources.SetAvailableResources(
ResourceSet(heartbeat_data.resources_available_label,
heartbeat_data.resources_available_capacity));
RAY_CHECK(this->cluster_resource_map_[client_id].GetAvailableResources() ==
heartbeat_resource_available);
}
void NodeManager::ProcessNewClient(std::shared_ptr<LocalClientConnection> client) {
@@ -40,18 +173,6 @@ void NodeManager::ProcessClientMessage(std::shared_ptr<LocalClientConnection> cl
// Register the new worker.
worker_pool_.RegisterWorker(std::move(worker));
}
// Build the reply to the worker's registration request. TODO(swang): This
// is legacy code and should be removed once actor creation tasks are
// implemented.
flatbuffers::FlatBufferBuilder fbb;
auto reply =
protocol::CreateRegisterClientReply(fbb, fbb.CreateVector(std::vector<int>()));
fbb.Finish(reply);
// Reply to the worker's registration request, then listen for more
// messages.
client->WriteMessage(protocol::MessageType_RegisterClientReply, fbb.GetSize(),
fbb.GetBufferPointer());
} break;
case protocol::MessageType_GetTask: {
const std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
@@ -74,8 +195,14 @@ void NodeManager::ProcessClientMessage(std::shared_ptr<LocalClientConnection> cl
// Remove the dead worker from the pool and stop listening for messages.
const std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
if (worker) {
if (!worker->GetAssignedTaskId().is_nil()) {
// TODO(swang): Clean up any tasks that were assigned to the worker.
// Release any resources that may be held by this worker.
FinishTask(worker->GetAssignedTaskId());
}
worker_pool_.DisconnectWorker(worker);
}
return;
} break;
case protocol::MessageType_SubmitTask: {
// Read the task submitted by the client.
@@ -84,14 +211,49 @@ void NodeManager::ProcessClientMessage(std::shared_ptr<LocalClientConnection> cl
from_flatbuf(*message->execution_dependencies()));
TaskSpecification task_spec(*message->task_spec());
Task task(task_execution_spec, task_spec);
// Submit the task to the local scheduler.
SubmitTask(task);
// Listen for more messages.
client->ProcessMessages();
// Submit the task to the local scheduler. Since the task was submitted
// locally, there is no uncommitted lineage.
SubmitTask(task, Lineage());
} break;
case protocol::MessageType_ReconstructObject: {
// TODO(hme): handle multiple object ids.
auto message = flatbuffers::GetRoot<protocol::ReconstructObject>(message_data);
ObjectID object_id = from_flatbuf(*message->object_id());
RAY_LOG(DEBUG) << "reconstructing object " << object_id.hex();
RAY_CHECK_OK(object_manager_.Pull(object_id));
} break;
default:
RAY_LOG(FATAL) << "Received unexpected message type " << message_type;
}
// Listen for more messages.
client->ProcessMessages();
}
void NodeManager::ProcessNewNodeManager(
std::shared_ptr<TcpClientConnection> node_manager_client) {
node_manager_client->ProcessMessages();
}
void NodeManager::ProcessNodeManagerMessage(
std::shared_ptr<TcpClientConnection> node_manager_client, int64_t message_type,
const uint8_t *message_data) {
switch (message_type) {
case protocol::MessageType_ForwardTaskRequest: {
auto message = flatbuffers::GetRoot<protocol::ForwardTaskRequest>(message_data);
TaskID task_id = from_flatbuf(*message->task_id());
Lineage uncommitted_lineage(*message);
const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData();
RAY_LOG(DEBUG) << "got task " << task.GetTaskSpecification().TaskId()
<< " spillback=" << task.GetTaskExecutionSpecReadonly().NumForwards();
SubmitTask(task, uncommitted_lineage);
} break;
default:
RAY_LOG(FATAL) << "Received unexpected message type " << message_type;
}
node_manager_client->ProcessMessages();
}
void NodeManager::HandleWaitingTaskReady(const TaskID &task_id) {
@@ -102,29 +264,45 @@ void NodeManager::HandleWaitingTaskReady(const TaskID &task_id) {
}
void NodeManager::ScheduleTasks() {
// Ask policy for scheduling decision.
// TODO(alexey): Give the policy all cluster resources instead of just the
// local one.
std::unordered_map<ClientID, SchedulingResources, UniqueIDHasher> cluster_resource_map;
cluster_resource_map[ClientID::nil()] = local_resources_;
const auto &policy_decision = scheduling_policy_.Schedule(cluster_resource_map);
auto policy_decision = scheduling_policy_.Schedule(
cluster_resource_map_, gcs_client_->client_table().GetLocalClientId(),
remote_clients_);
RAY_LOG(DEBUG) << "[NM ScheduleTasks] policy decision:";
for (const auto &pair : policy_decision) {
TaskID task_id = pair.first;
ClientID client_id = pair.second;
RAY_LOG(DEBUG) << task_id.hex() << " --> " << client_id.hex();
}
// Extract decision for this local scheduler.
// TODO(alexey): Check for this node's own client ID, not for nil.
std::unordered_set<TaskID, UniqueIDHasher> task_ids;
for (auto &task_schedule : policy_decision) {
if (task_schedule.second.is_nil()) {
task_ids.insert(task_schedule.first);
std::unordered_set<TaskID, UniqueIDHasher> local_task_ids;
// Iterate over (taskid, clientid) pairs, extract tasks to run on the local client.
for (const auto &task_schedule : policy_decision) {
TaskID task_id = task_schedule.first;
ClientID client_id = task_schedule.second;
if (client_id == gcs_client_->client_table().GetLocalClientId()) {
local_task_ids.insert(task_id);
} else {
auto tasks = local_queues_.RemoveTasks({task_id});
RAY_CHECK(1 == tasks.size());
Task &task = tasks.front();
// TODO(swang): Handle forward task failure.
// TODO(swang): Unsubscribe this task in the task dependency manager.
RAY_CHECK_OK(ForwardTask(task, client_id));
}
}
// Assign the tasks to workers.
std::vector<Task> tasks = local_queues_.RemoveTasks(task_ids);
std::vector<Task> tasks = local_queues_.RemoveTasks(local_task_ids);
for (auto &task : tasks) {
AssignTask(task);
}
}
void NodeManager::SubmitTask(const Task &task) {
void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineage) {
// Add the task and its uncommitted lineage to the lineage cache.
lineage_cache_.AddWaitingTask(task, uncommitted_lineage);
// Queue the task according to the availability of its arguments.
if (task_dependency_manager_.TaskReady(task)) {
local_queues_.QueueReadyTasks(std::vector<Task>({task}));
ScheduleTasks();
@@ -135,41 +313,104 @@ void NodeManager::SubmitTask(const Task &task) {
}
void NodeManager::AssignTask(const Task &task) {
// Resource accounting: acquire resources for the scheduled task.
const ClientID &my_client_id = gcs_client_->client_table().GetLocalClientId();
RAY_CHECK(this->cluster_resource_map_[my_client_id].Acquire(
task.GetTaskSpecification().GetRequiredResources()));
if (worker_pool_.PoolSize() == 0) {
// Start a new worker.
worker_pool_.StartWorker();
// Queue this task for future assignment. The task will be assigned to a
// worker once one becomes available.
local_queues_.QueueScheduledTasks(std::vector<Task>({task}));
// TODO(swang): Acquire resources here or when a worker becomes available?
return;
}
const TaskSpecification &spec = task.GetTaskSpecification();
std::shared_ptr<Worker> worker = worker_pool_.PopWorker();
RAY_LOG(DEBUG) << "Assigning task to worker with pid " << worker->Pid();
// TODO(swang): Acquire resources for the task.
// local_resources_.Acquire(task.GetTaskSpecification().GetRequiredResources());
worker->AssignTaskId(spec.TaskId());
local_queues_.QueueRunningTasks(std::vector<Task>({task}));
flatbuffers::FlatBufferBuilder fbb;
const TaskSpecification &spec = task.GetTaskSpecification();
auto message = protocol::CreateGetTaskReply(fbb, spec.ToFlatbuffer(fbb),
fbb.CreateVector(std::vector<int>()));
fbb.Finish(message);
worker->Connection()->WriteMessage(protocol::MessageType_ExecuteTask, fbb.GetSize(),
fbb.GetBufferPointer());
worker->AssignTaskId(spec.TaskId());
local_queues_.QueueRunningTasks(std::vector<Task>({task}));
auto status = worker->Connection()->WriteMessage(protocol::MessageType_ExecuteTask,
fbb.GetSize(), fbb.GetBufferPointer());
if (status.ok()) {
// We started running the task, so the task is ready to write to GCS.
lineage_cache_.AddReadyTask(task);
} else {
// We failed to send the task to the worker, so disconnect the worker. The
// task will get queued again during cleanup.
ProcessClientMessage(worker->Connection(), protocol::MessageType_DisconnectClient,
NULL);
}
}
void NodeManager::FinishTask(const TaskID &task_id) {
RAY_LOG(DEBUG) << "Finished task " << task_id.hex();
local_queues_.RemoveTasks({task_id});
// TODO(swang): Release resources that were held for the task.
auto tasks = local_queues_.RemoveTasks({task_id});
RAY_CHECK(tasks.size() == 1);
auto task = *tasks.begin();
// Resource accounting: release task's resources.
RAY_CHECK(
this->cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Release(
task.GetTaskSpecification().GetRequiredResources()));
}
void NodeManager::ResubmitTask(const TaskID &task_id) {
throw std::runtime_error("Method not implemented");
}
ray::Status NodeManager::ForwardTask(Task &task, const ClientID &node_id) {
auto task_id = task.GetTaskSpecification().TaskId();
// Get and serialize the task's uncommitted lineage.
auto uncommitted_lineage = lineage_cache_.GetUncommittedLineage(task_id);
Task &lineage_cache_entry_task =
uncommitted_lineage.GetEntryMutable(task_id)->TaskDataMutable();
// Increment forward count for the forwarded task.
lineage_cache_entry_task.GetTaskExecutionSpec().IncrementNumForwards();
flatbuffers::FlatBufferBuilder fbb;
auto request = uncommitted_lineage.ToFlatbuffer(fbb, task_id);
fbb.Finish(request);
RAY_LOG(DEBUG) << "Forwarding task " << task_id.hex() << " to " << node_id.hex()
<< " spillback="
<< lineage_cache_entry_task.GetTaskExecutionSpec().NumForwards();
auto client_info = gcs_client_->client_table().GetClient(node_id);
// Lookup remote server connection for this node_id and use it to send the request.
if (remote_server_connections_.count(node_id) == 0) {
// TODO(atumanov): caller must handle failure to ensure tasks are not lost.
RAY_LOG(INFO) << "No NodeManager connection found for GCS client id "
<< node_id.hex();
return ray::Status::IOError("NodeManager connection not found");
}
auto &server_conn = remote_server_connections_.at(node_id);
auto status = server_conn.WriteMessage(protocol::MessageType_ForwardTaskRequest,
fbb.GetSize(), fbb.GetBufferPointer());
if (status.ok()) {
// If we were able to forward the task, remove the forwarded task from the
// lineage cache since the receiving node is now responsible for writing
// the task to the GCS.
lineage_cache_.RemoveWaitingTask(task_id);
} else {
// TODO(atumanov): caller must handle ForwardTask failure to ensure tasks are not
// lost.
RAY_LOG(FATAL) << "[NodeManager][ForwardTask] failed to forward task "
<< task_id.hex() << " to node " << node_id.hex();
}
return status;
}
} // namespace raylet
} // namespace ray
+47 -12
View File
@@ -2,11 +2,13 @@
#define RAY_RAYLET_NODE_MANAGER_H
// clang-format off
#include "ray/raylet/task.h"
#include "ray/object_manager/object_manager.h"
#include "ray/common/client_connection.h"
#include "ray/raylet/lineage_cache.h"
#include "ray/raylet/scheduling_policy.h"
#include "ray/raylet/scheduling_queue.h"
#include "ray/raylet/scheduling_resources.h"
#include "ray/object_manager/object_manager.h"
#include "ray/raylet/reconstruction_policy.h"
#include "ray/raylet/task_dependency_manager.h"
#include "ray/raylet/worker_pool.h"
@@ -14,16 +16,24 @@
namespace ray {
class NodeManager : public ClientManager<boost::asio::local::stream_protocol> {
namespace raylet {
struct NodeManagerConfig {
ResourceSet resource_config;
int num_initial_workers;
std::vector<const char *> worker_command;
uint64_t heartbeat_period_ms;
};
class NodeManager {
public:
/// Create a node manager.
///
/// \param socket_name The pathname of the Unix domain socket to listen at
/// for local connections.
/// \param resource_config The initial set of node resources.
/// \param object_manager A reference to the local object manager.
NodeManager(const std::string &socket_name, const ResourceSet &resource_config,
ObjectManager &object_manager);
NodeManager(boost::asio::io_service &io_service, const NodeManagerConfig &config,
ObjectManager &object_manager,
std::shared_ptr<gcs::AsyncGcsClient> gcs_client);
/// Process a new client connection.
void ProcessNewClient(std::shared_ptr<LocalClientConnection> client);
@@ -38,26 +48,41 @@ class NodeManager : public ClientManager<boost::asio::local::stream_protocol> {
void ProcessClientMessage(std::shared_ptr<LocalClientConnection> client,
int64_t message_type, const uint8_t *message);
void ProcessNewNodeManager(std::shared_ptr<TcpClientConnection> node_manager_client);
void ProcessNodeManagerMessage(std::shared_ptr<TcpClientConnection> node_manager_client,
int64_t message_type, const uint8_t *message);
void ClientAdded(gcs::AsyncGcsClient *client, const UniqueID &id,
const ClientTableDataT &data);
void HeartbeatAdded(gcs::AsyncGcsClient *client, const ClientID &id,
const HeartbeatTableDataT &data);
private:
/// Submit a task to this node.
void SubmitTask(const Task &task);
void SubmitTask(const Task &task, const Lineage &uncommitted_lineage);
/// Assign a task.
void AssignTask(const Task &task);
/// Finish a task.
void FinishTask(const TaskID &task_id);
/// Schedule tasks.
void ScheduleTasks();
/// Handle a task whose local dependencies were missing and are now
/// available.
/// Handle a task whose local dependencies were missing and are now available.
void HandleWaitingTaskReady(const TaskID &task_id);
/// Resubmit a task whose return value needs to be reconstructed.
void ResubmitTask(const TaskID &task_id);
ray::Status ForwardTask(Task &task, const ClientID &node_id);
/// Send heartbeats to the GCS.
void Heartbeat();
boost::asio::io_service &io_service_;
boost::asio::deadline_timer heartbeat_timer_;
uint64_t heartbeat_period_ms_;
/// The resources local to this node.
SchedulingResources local_resources_;
const SchedulingResources local_resources_;
// TODO(atumanov): Add resource information from other nodes.
// std::unordered_map<ClientID, SchedulingResources&, UniqueIDHasher>
// cluster_resource_map_;
std::unordered_map<ClientID, SchedulingResources, UniqueIDHasher> cluster_resource_map_;
/// A pool of workers.
WorkerPool worker_pool_;
/// A set of queues to maintain tasks.
@@ -68,8 +93,18 @@ class NodeManager : public ClientManager<boost::asio::local::stream_protocol> {
ReconstructionPolicy reconstruction_policy_;
/// A manager to make waiting tasks's missing object dependencies available.
TaskDependencyManager task_dependency_manager_;
/// The lineage cache for the GCS object and task tables.
LineageCache lineage_cache_;
/// A client connection to the GCS.
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_;
std::vector<ClientID> remote_clients_;
std::unordered_map<ClientID, TcpServerConnection, UniqueIDHasher>
remote_server_connections_;
ObjectManager &object_manager_;
};
} // namespace raylet
} // end namespace ray
#endif // RAY_RAYLET_NODE_MANAGER_H
@@ -0,0 +1,235 @@
#include <iostream>
#include <thread>
#include "gtest/gtest.h"
#include "ray/raylet/raylet.h"
namespace ray {
namespace raylet {
std::string test_executable;
std::string store_executable;
// TODO(hme): Get this working once the dust settles.
class TestObjectManagerBase : public ::testing::Test {
public:
TestObjectManagerBase() { RAY_LOG(INFO) << "TestObjectManagerBase: started."; }
std::string StartStore(const std::string &id) {
std::string store_id = "/tmp/store";
store_id = store_id + id;
std::string plasma_command = store_executable + " -m 1000000000 -s " + store_id +
" 1> /dev/null 2> /dev/null &";
RAY_LOG(INFO) << plasma_command;
int ec = system(plasma_command.c_str());
if (ec != 0) {
throw std::runtime_error("failed to start plasma store.");
};
return store_id;
}
NodeManagerConfig GetNodeManagerConfig(std::string raylet_socket_name,
std::string store_socket_name) {
// Configuration for the node manager.
ray::raylet::NodeManagerConfig node_manager_config;
std::unordered_map<std::string, double> static_resource_conf;
static_resource_conf = {{"CPU", 1}, {"GPU", 1}};
node_manager_config.resource_config =
ray::raylet::ResourceSet(std::move(static_resource_conf));
node_manager_config.num_initial_workers = 0;
// Use a default worker that can execute empty tasks with dependencies.
node_manager_config.worker_command.push_back("python");
node_manager_config.worker_command.push_back(
"../../../src/ray/python/default_worker.py");
node_manager_config.worker_command.push_back(raylet_socket_name.c_str());
node_manager_config.worker_command.push_back(store_socket_name.c_str());
return node_manager_config;
};
void SetUp() {
object_manager_service_1.reset(new boost::asio::io_service());
object_manager_service_2.reset(new boost::asio::io_service());
// start store
std::string store_sock_1 = StartStore("1");
std::string store_sock_2 = StartStore("2");
// start first server
gcs_client_1 = std::shared_ptr<gcs::AsyncGcsClient>(new gcs::AsyncGcsClient());
ObjectManagerConfig om_config_1;
om_config_1.store_socket_name = store_sock_1;
server1.reset(new ray::raylet::Raylet(
main_service, std::move(object_manager_service_1), "raylet_1", "127.0.0.1", 6379,
GetNodeManagerConfig("raylet_1", store_sock_1), om_config_1, gcs_client_1));
// start second server
gcs_client_2 = std::shared_ptr<gcs::AsyncGcsClient>(new gcs::AsyncGcsClient());
ObjectManagerConfig om_config_2;
om_config_2.store_socket_name = store_sock_2;
server2.reset(new ray::raylet::Raylet(
main_service, std::move(object_manager_service_2), "raylet_2", "127.0.0.1", 6379,
GetNodeManagerConfig("raylet_2", store_sock_2), om_config_2, gcs_client_2));
// connect to stores.
ARROW_CHECK_OK(client1.Connect(store_sock_1, "", PLASMA_DEFAULT_RELEASE_DELAY));
ARROW_CHECK_OK(client2.Connect(store_sock_2, "", PLASMA_DEFAULT_RELEASE_DELAY));
}
void TearDown() {
arrow::Status client1_status = client1.Disconnect();
arrow::Status client2_status = client2.Disconnect();
ASSERT_TRUE(client1_status.ok() && client2_status.ok());
this->server1.reset();
this->server2.reset();
int s = system("killall plasma_store &");
ASSERT_TRUE(!s);
std::string cmd_str = test_executable.substr(0, test_executable.find_last_of("/"));
s = system(("rm " + cmd_str + "/raylet_1").c_str());
ASSERT_TRUE(!s);
s = system(("rm " + cmd_str + "/raylet_2").c_str());
ASSERT_TRUE(!s);
}
ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size) {
ObjectID object_id = ObjectID::from_random();
RAY_LOG(DEBUG) << "ObjectID Created: " << object_id;
uint8_t metadata[] = {5};
int64_t metadata_size = sizeof(metadata);
std::shared_ptr<Buffer> data;
ARROW_CHECK_OK(client.Create(object_id.to_plasma_id(), data_size, metadata,
metadata_size, &data));
ARROW_CHECK_OK(client.Seal(object_id.to_plasma_id()));
return object_id;
}
protected:
std::thread p;
boost::asio::io_service main_service;
std::unique_ptr<boost::asio::io_service> object_manager_service_1;
std::unique_ptr<boost::asio::io_service> object_manager_service_2;
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_1;
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_2;
std::unique_ptr<ray::raylet::Raylet> server1;
std::unique_ptr<ray::raylet::Raylet> server2;
plasma::PlasmaClient client1;
plasma::PlasmaClient client2;
std::vector<ObjectID> v1;
std::vector<ObjectID> v2;
};
class TestObjectManagerIntegration : public TestObjectManagerBase {
public:
uint num_expected_objects;
int num_connected_clients = 0;
ClientID client_id_1;
ClientID client_id_2;
void WaitConnections() {
client_id_1 = gcs_client_1->client_table().GetLocalClientId();
client_id_2 = gcs_client_2->client_table().GetLocalClientId();
gcs_client_1->client_table().RegisterClientAddedCallback([this](
gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) {
ClientID parsed_id = ClientID::from_binary(data.client_id);
if (parsed_id == client_id_1 || parsed_id == client_id_2) {
num_connected_clients += 1;
}
if (num_connected_clients == 2) {
StartTests();
}
});
}
void StartTests() {
TestConnections();
AddTransferTestHandlers();
TestPush(100);
}
void AddTransferTestHandlers() {
ray::Status status = ray::Status::OK();
status =
server1->object_manager_.SubscribeObjAdded([this](const ObjectID &object_id) {
v1.push_back(object_id);
if (v1.size() == num_expected_objects && v1.size() == v2.size()) {
TestPushComplete();
}
});
RAY_CHECK_OK(status);
status =
server2->object_manager_.SubscribeObjAdded([this](const ObjectID &object_id) {
v2.push_back(object_id);
if (v2.size() == num_expected_objects && v1.size() == v2.size()) {
TestPushComplete();
}
});
RAY_CHECK_OK(status);
}
void TestPush(int64_t data_size) {
ray::Status status = ray::Status::OK();
num_expected_objects = (uint)1;
ObjectID oid1 = WriteDataToClient(client1, data_size);
status = server1->object_manager_.Push(oid1, client_id_2);
}
void TestPushComplete() {
RAY_LOG(INFO) << "TestPushComplete: "
<< " " << v1.size() << " " << v2.size();
ASSERT_TRUE(v1.size() == v2.size());
for (int i = -1; ++i < (int)v1.size();) {
ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end());
}
v1.clear();
v2.clear();
main_service.stop();
}
void TestConnections() {
RAY_LOG(INFO) << "\n"
<< "Server client ids:"
<< "\n";
ClientID client_id_1 = gcs_client_1->client_table().GetLocalClientId();
ClientID client_id_2 = gcs_client_2->client_table().GetLocalClientId();
RAY_LOG(INFO) << "Server 1: " << client_id_1;
RAY_LOG(INFO) << "Server 2: " << client_id_2;
RAY_LOG(INFO) << "\n"
<< "All connected clients:"
<< "\n";
const ClientTableDataT &data = gcs_client_2->client_table().GetClient(client_id_1);
RAY_LOG(INFO) << (ClientID::from_binary(data.client_id) == ClientID::nil());
RAY_LOG(INFO) << "ClientID=" << ClientID::from_binary(data.client_id);
RAY_LOG(INFO) << "ClientIp=" << data.node_manager_address;
RAY_LOG(INFO) << "ClientPort=" << data.node_manager_port;
const ClientTableDataT &data2 = gcs_client_1->client_table().GetClient(client_id_2);
RAY_LOG(INFO) << "ClientID=" << ClientID::from_binary(data2.client_id);
RAY_LOG(INFO) << "ClientIp=" << data2.node_manager_address;
RAY_LOG(INFO) << "ClientPort=" << data2.node_manager_port;
}
};
TEST_F(TestObjectManagerIntegration, StartTestObjectManagerPush) {
auto AsyncStartTests = main_service.wrap([this]() { WaitConnections(); });
AsyncStartTests();
main_service.run();
}
} // namespace raylet
} // namespace ray
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
ray::raylet::test_executable = std::string(argv[0]);
ray::raylet::store_executable = std::string(argv[1]);
return RUN_ALL_TESTS();
}
+120 -37
View File
@@ -1,56 +1,129 @@
#include "raylet.h"
#include <boost/asio.hpp>
#include <boost/bind.hpp>
#include <boost/date_time/posix_time/posix_time.hpp>
#include <iostream>
#include "ray/status.h"
namespace ray {
Raylet::Raylet(boost::asio::io_service &io_service, const std::string &socket_name,
const ResourceSet &resource_config,
namespace raylet {
Raylet::Raylet(boost::asio::io_service &main_service,
std::unique_ptr<boost::asio::io_service> object_manager_service,
const std::string &socket_name, const std::string &redis_address,
int redis_port, const NodeManagerConfig &node_manager_config,
const ObjectManagerConfig &object_manager_config,
std::shared_ptr<ray::GcsClient> gcs_client)
: acceptor_(io_service, boost::asio::local::stream_protocol::endpoint(socket_name)),
socket_(io_service),
tcp_acceptor_(io_service,
boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), 0)),
tcp_socket_(io_service),
object_manager_(io_service, object_manager_config, gcs_client),
node_manager_(socket_name, resource_config, object_manager_),
gcs_client_(gcs_client) {
ClientID client_id = RegisterGcs();
object_manager_.SetClientID(client_id);
std::shared_ptr<gcs::AsyncGcsClient> gcs_client)
: acceptor_(main_service, boost::asio::local::stream_protocol::endpoint(socket_name)),
socket_(main_service),
object_manager_acceptor_(
main_service, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), 0)),
object_manager_socket_(main_service),
node_manager_acceptor_(
main_service, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), 0)),
node_manager_socket_(main_service),
gcs_client_(gcs_client),
object_manager_(main_service, std::move(object_manager_service),
object_manager_config, gcs_client),
node_manager_(main_service, node_manager_config, object_manager_, gcs_client_) {
// Start listening for clients.
DoAccept();
DoAcceptTcp();
DoAcceptObjectManager();
DoAcceptNodeManager();
RAY_CHECK_OK(RegisterGcs(redis_address, redis_port, main_service, node_manager_config));
RAY_CHECK_OK(RegisterPeriodicTimer(main_service));
}
Raylet::~Raylet() { RAY_CHECK_OK(object_manager_.Terminate()); }
ClientID Raylet::RegisterGcs() {
boost::asio::ip::tcp::endpoint endpoint = tcp_acceptor_.local_endpoint();
std::string ip = endpoint.address().to_string();
uint16_t port = endpoint.port();
ClientID client_id = gcs_client_->Register(ip, port);
return client_id;
Raylet::~Raylet() {
RAY_CHECK_OK(gcs_client_->client_table().Disconnect());
RAY_CHECK_OK(object_manager_.Terminate());
}
void Raylet::DoAcceptTcp() {
TCPClientConnection::pointer new_connection =
TCPClientConnection::Create(acceptor_.get_io_service());
tcp_acceptor_.async_accept(new_connection->GetSocket(),
boost::bind(&Raylet::HandleAcceptTcp, this, new_connection,
boost::asio::placeholders::error));
ray::Status Raylet::RegisterPeriodicTimer(boost::asio::io_service &io_service) {
boost::posix_time::milliseconds timer_period_ms(100);
boost::asio::deadline_timer timer(io_service, timer_period_ms);
return ray::Status::OK();
}
void Raylet::HandleAcceptTcp(TCPClientConnection::pointer new_connection,
const boost::system::error_code &error) {
if (!error) {
// Pass it off to object manager for now.
ray::Status status = object_manager_.AcceptConnection(std::move(new_connection));
ray::Status Raylet::RegisterGcs(const std::string &redis_address, int redis_port,
boost::asio::io_service &io_service,
const NodeManagerConfig &node_manager_config) {
RAY_RETURN_NOT_OK(gcs_client_->Connect(redis_address, redis_port));
RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service));
ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient();
client_info.node_manager_address =
node_manager_acceptor_.local_endpoint().address().to_string();
client_info.object_manager_port = object_manager_acceptor_.local_endpoint().port();
client_info.node_manager_port = node_manager_acceptor_.local_endpoint().port();
// Add resource information.
for (const auto &resource_pair : node_manager_config.resource_config.GetResourceMap()) {
client_info.resources_total_label.push_back(resource_pair.first);
client_info.resources_total_capacity.push_back(resource_pair.second);
}
DoAcceptTcp();
RAY_LOG(DEBUG) << "NM LISTENING ON: IP " << client_info.node_manager_address << " PORT "
<< client_info.node_manager_port;
RAY_RETURN_NOT_OK(gcs_client_->client_table().Connect(client_info));
auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id,
const ClientTableDataT &data) {
node_manager_.ClientAdded(client, id, data);
};
gcs_client_->client_table().RegisterClientAddedCallback(node_manager_client_added);
return Status::OK();
}
void Raylet::DoAcceptNodeManager() {
node_manager_acceptor_.async_accept(node_manager_socket_,
boost::bind(&Raylet::HandleAcceptNodeManager, this,
boost::asio::placeholders::error));
}
void Raylet::HandleAcceptNodeManager(const boost::system::error_code &error) {
if (!error) {
ClientHandler<boost::asio::ip::tcp> client_handler =
[this](std::shared_ptr<TcpClientConnection> client) {
node_manager_.ProcessNewNodeManager(client);
};
MessageHandler<boost::asio::ip::tcp> message_handler = [this](
std::shared_ptr<TcpClientConnection> client, int64_t message_type,
const uint8_t *message) {
node_manager_.ProcessNodeManagerMessage(client, message_type, message);
};
// Accept a new local client and dispatch it to the node manager.
auto new_connection = TcpClientConnection::Create(client_handler, message_handler,
std::move(node_manager_socket_));
}
// We're ready to accept another client.
DoAcceptNodeManager();
}
void Raylet::DoAcceptObjectManager() {
object_manager_acceptor_.async_accept(
object_manager_socket_, boost::bind(&Raylet::HandleAcceptObjectManager, this,
boost::asio::placeholders::error));
}
void Raylet::HandleAcceptObjectManager(const boost::system::error_code &error) {
ClientHandler<boost::asio::ip::tcp> client_handler =
[this](std::shared_ptr<TcpClientConnection> client) {
object_manager_.ProcessNewClient(client);
};
MessageHandler<boost::asio::ip::tcp> message_handler = [this](
std::shared_ptr<TcpClientConnection> client, int64_t message_type,
const uint8_t *message) {
object_manager_.ProcessClientMessage(client, message_type, message);
};
// Accept a new local client and dispatch it to the node manager.
auto new_connection = TcpClientConnection::Create(client_handler, message_handler,
std::move(object_manager_socket_));
DoAcceptObjectManager();
}
void Raylet::DoAccept() {
@@ -60,14 +133,24 @@ void Raylet::DoAccept() {
void Raylet::HandleAccept(const boost::system::error_code &error) {
if (!error) {
// TODO: typedef these handlers.
ClientHandler<boost::asio::local::stream_protocol> client_handler =
[this](std::shared_ptr<LocalClientConnection> client) {
node_manager_.ProcessNewClient(client);
};
MessageHandler<boost::asio::local::stream_protocol> message_handler = [this](
std::shared_ptr<LocalClientConnection> client, int64_t message_type,
const uint8_t *message) {
node_manager_.ProcessClientMessage(client, message_type, message);
};
// Accept a new local client and dispatch it to the node manager.
auto new_connection =
LocalClientConnection::Create(node_manager_, std::move(socket_));
auto new_connection = LocalClientConnection::Create(client_handler, message_handler,
std::move(socket_));
}
// We're ready to accept another client.
DoAccept();
}
ObjectManager &Raylet::GetObjectManager() { return object_manager_; }
} // namespace raylet
} // namespace ray
+30 -18
View File
@@ -14,63 +14,75 @@
namespace ray {
namespace raylet {
class Task;
class NodeManager;
// TODO(swang): Rename class and source files to Raylet.
class Raylet {
public:
/// Create a node manager server and listen for new clients.
///
/// \param io_service The event loop to run the server on.
/// \param main_service The event loop to run the server on.
/// \param object_manager_service The asio io_service tied to the object manager.
/// \param socket_name The Unix domain socket to listen on for local clients.
/// \param resource_config The initial set of resources to start the local
/// \param node_manager_config Configuration to initialize the node manager.
/// scheduler with.
/// \param object_manager_config Configuration to initialize the object
/// manager.
/// \param gcs_client A client connection to the GCS.
Raylet(boost::asio::io_service &io_service, const std::string &socket_name,
const ResourceSet &resource_config,
Raylet(boost::asio::io_service &main_service,
std::unique_ptr<boost::asio::io_service> object_manager_service,
const std::string &socket_name, const std::string &redis_address, int redis_port,
const NodeManagerConfig &node_manager_config,
const ObjectManagerConfig &object_manager_config,
std::shared_ptr<ray::GcsClient> gcs_client);
std::shared_ptr<gcs::AsyncGcsClient> gcs_client);
/// Destroy the NodeServer.
~Raylet();
// TODO(melih): Get rid of this method.
ObjectManager &GetObjectManager();
private:
/// Register GCS client.
ClientID RegisterGcs();
ray::Status RegisterGcs(const std::string &redis_address, int redis_port,
boost::asio::io_service &io_service, const NodeManagerConfig &);
ray::Status RegisterPeriodicTimer(boost::asio::io_service &io_service);
/// Accept a client connection.
void DoAccept();
/// Handle an accepted client connection.
void HandleAccept(const boost::system::error_code &error);
/// Accept a tcp client connection.
void DoAcceptTcp();
void DoAcceptObjectManager();
/// Handle an accepted tcp client connection.
void HandleAcceptTcp(TCPClientConnection::pointer new_connection,
const boost::system::error_code &error);
void HandleAcceptObjectManager(const boost::system::error_code &error);
void DoAcceptNodeManager();
void HandleAcceptNodeManager(const boost::system::error_code &error);
friend class TestObjectManagerIntegration;
/// An acceptor for new clients.
boost::asio::local::stream_protocol::acceptor acceptor_;
/// The socket to listen on for new clients.
boost::asio::local::stream_protocol::socket socket_;
/// An acceptor for new object manager tcp clients.
boost::asio::ip::tcp::acceptor object_manager_acceptor_;
/// The socket to listen on for new object manager tcp clients.
boost::asio::ip::tcp::socket object_manager_socket_;
/// An acceptor for new tcp clients.
boost::asio::ip::tcp::acceptor tcp_acceptor_;
boost::asio::ip::tcp::acceptor node_manager_acceptor_;
/// The socket to listen on for new tcp clients.
boost::asio::ip::tcp::socket tcp_socket_;
boost::asio::ip::tcp::socket node_manager_socket_;
// TODO(swang): Lineage cache.
/// A client connection to the GCS.
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_;
/// Manages client requests for object transfers and availability.
ObjectManager object_manager_;
/// Manages client requests for task submission and execution.
NodeManager node_manager_;
/// A client connection to the GCS.
std::shared_ptr<ray::GcsClient> gcs_client_;
};
} // namespace raylet
} // namespace ray
#endif // RAY_RAYLET_RAYLET_H
-275
View File
@@ -1,275 +0,0 @@
#include <iostream>
#include <thread>
#include "gtest/gtest.h"
#include "ray/raylet/raylet.h"
namespace ray {
std::string test_executable; // NOLINT
class TestRaylet : public ::testing::Test {
public:
TestRaylet() { RAY_LOG(INFO) << "TestRaylet: started."; }
std::string StartStore(const std::string &id) {
std::string store_id = "/tmp/store";
store_id = store_id + id;
std::string test_dir = test_executable.substr(0, test_executable.find_last_of("/"));
std::string plasma_dir = test_dir + "./../plasma";
std::string plasma_command = plasma_dir + "/plasma_store -m 1000000000 -s " +
store_id + " 1> /dev/null 2> /dev/null &";
RAY_LOG(INFO) << plasma_command;
int ec = system(plasma_command.c_str());
if (ec != 0) {
throw std::runtime_error("failed to start plasma store.");
};
return store_id;
}
void SetUp() {
// start store
std::string store_sock_1 = StartStore("1");
std::string store_sock_2 = StartStore("2");
// configure
std::unordered_map<std::string, double> static_resource_config;
static_resource_config = {{"num_cpus", 1}, {"num_gpus", 1}};
ray::ResourceSet resource_config(std::move(static_resource_config));
// start mock gcs
mock_gcs_client = std::shared_ptr<GcsClient>(new GcsClient());
// start first server
ray::ObjectManagerConfig om_config_1;
om_config_1.store_socket_name = store_sock_1;
server1.reset(new Raylet(io_service, std::string("hello1"), resource_config,
om_config_1, mock_gcs_client));
// start second server
ray::ObjectManagerConfig om_config_2;
om_config_2.store_socket_name = store_sock_2;
server2.reset(new Raylet(io_service, std::string("hello2"), resource_config,
om_config_2, mock_gcs_client));
// connect to stores.
ARROW_CHECK_OK(client1.Connect(store_sock_1, "", PLASMA_DEFAULT_RELEASE_DELAY));
ARROW_CHECK_OK(client2.Connect(store_sock_2, "", PLASMA_DEFAULT_RELEASE_DELAY));
this->StartLoop();
}
void TearDown() {
this->StopLoop();
arrow::Status client1_status = client1.Disconnect();
arrow::Status client2_status = client2.Disconnect();
ASSERT_TRUE(client1_status.ok() && client2_status.ok());
this->server1.reset();
this->server2.reset();
int s = system("killall plasma_store &");
ASSERT_TRUE(!s);
std::string cmd_str = test_executable.substr(0, test_executable.find_last_of("/"));
s = system(("rm " + cmd_str + "/hello1").c_str());
ASSERT_TRUE(!s);
s = system(("rm " + cmd_str + "/hello2").c_str());
ASSERT_TRUE(!s);
}
void Loop() { io_service.run(); };
void StartLoop() { p = std::thread(&TestRaylet::Loop, this); };
void StopLoop() {
io_service.stop();
p.join();
}
ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size) {
ObjectID object_id = ObjectID::from_random();
RAY_LOG(DEBUG) << "ObjectID Created: " << object_id.hex().c_str();
uint8_t metadata[] = {5};
int64_t metadata_size = sizeof(metadata);
std::shared_ptr<Buffer> data;
ARROW_CHECK_OK(client.Create(object_id.to_plasma_id(), data_size, metadata,
metadata_size, &data));
ARROW_CHECK_OK(client.Seal(object_id.to_plasma_id()));
return object_id;
}
void object_added_handler_1(const ObjectID &object_id) {
RAY_LOG(INFO) << "Store 1 added: " << object_id.hex();
v1.push_back(object_id);
};
void object_added_handler_2(const ObjectID &object_id) {
RAY_LOG(INFO) << "Store 2 added: " << object_id.hex();
v2.push_back(object_id);
};
protected:
std::thread p;
boost::asio::io_service io_service;
std::shared_ptr<ray::GcsClient> mock_gcs_client;
std::unique_ptr<ray::Raylet> server1;
std::unique_ptr<ray::Raylet> server2;
plasma::PlasmaClient client1;
plasma::PlasmaClient client2;
std::vector<ObjectID> v1;
std::vector<ObjectID> v2;
};
TEST_F(TestRaylet, TestRayletCommands) {
ray::Status status = ray::Status::OK();
// TODO(atumanov): assert status is OK everywhere it's returned.
RAY_LOG(INFO) << "\n"
<< "All connected clients:"
<< "\n";
status = mock_gcs_client->client_table().GetClientIds(
[this](const std::vector<ClientID> &client_ids) {
mock_gcs_client->client_table().GetClientInformationSet(
client_ids,
[this](const std::vector<ClientInformation> &info_vec) {
for (const auto &info : info_vec) {
RAY_LOG(INFO) << "ClientID=" << info.GetClientId().hex();
RAY_LOG(INFO) << "ClientIp=" << info.GetIp();
RAY_LOG(INFO) << "ClientPort=" << info.GetPort();
}
},
[](Status status) {});
});
sleep(1);
RAY_LOG(INFO) << "\n"
<< "Server client ids:"
<< "\n";
status = server1->GetObjectManager().SubscribeObjAdded(
[this](const ObjectID &object_id) { object_added_handler_1(object_id); });
ASSERT_TRUE(status.ok());
status = server2->GetObjectManager().SubscribeObjAdded(
[this](const ObjectID &object_id) { object_added_handler_2(object_id); });
ASSERT_TRUE(status.ok());
ClientID client_id_1 = server1->GetObjectManager().GetClientID();
ClientID client_id_2 = server2->GetObjectManager().GetClientID();
RAY_LOG(INFO) << "Server 1: " << client_id_1.hex();
RAY_LOG(INFO) << "Server 2: " << client_id_2.hex();
sleep(1);
RAY_LOG(INFO) << "\n"
<< "Test bidirectional pull"
<< "\n";
for (int i = -1; ++i < 100;) {
ObjectID oid1 = WriteDataToClient(client1, 100);
ObjectID oid2 = WriteDataToClient(client2, 100);
status = server1->GetObjectManager().Pull(oid2);
status = server2->GetObjectManager().Pull(oid1);
}
sleep(1);
RAY_LOG(INFO) << v1.size() << " " << v2.size();
ASSERT_TRUE(v1.size() == v2.size());
for (int i = -1; ++i < (int)v1.size();) {
ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end());
}
v1.clear();
v2.clear();
RAY_LOG(INFO) << "\n"
<< "Test pull 1 from 2"
<< "\n";
for (int i = -1; ++i < 3;) {
ObjectID oid2 = WriteDataToClient(client2, 100);
status = server1->GetObjectManager().Pull(oid2);
}
sleep(1);
RAY_LOG(INFO) << v1.size() << " " << v2.size();
ASSERT_TRUE(v1.size() == v2.size());
for (int i = -1; ++i < (int)v1.size();) {
ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end());
}
v1.clear();
v2.clear();
RAY_LOG(INFO) << "\n"
<< "Test pull 2 from 1"
<< "\n";
for (int i = -1; ++i < 3;) {
ObjectID oid1 = WriteDataToClient(client1, 100);
status = server2->GetObjectManager().Pull(oid1);
}
sleep(1);
RAY_LOG(INFO) << v1.size() << " " << v2.size();
ASSERT_TRUE(v1.size() == v2.size());
for (int i = -1; ++i < (int)v1.size();) {
ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end());
}
v1.clear();
v2.clear();
RAY_LOG(INFO) << "\n"
<< "Test push 1 to 2"
<< "\n";
for (int i = -1; ++i < 3;) {
ObjectID oid1 = WriteDataToClient(client1, 100);
status = server1->GetObjectManager().Push(oid1, client_id_2);
}
sleep(1);
RAY_LOG(INFO) << v1.size() << " " << v2.size();
ASSERT_TRUE(v1.size() == v2.size());
for (int i = -1; ++i < (int)v1.size();) {
ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end());
}
v1.clear();
v2.clear();
RAY_LOG(INFO) << "\n"
<< "Test push 2 to 1"
<< "\n";
for (int i = -1; ++i < 3;) {
ObjectID oid2 = WriteDataToClient(client2, 100);
status = server2->GetObjectManager().Push(oid2, client_id_1);
}
sleep(1);
RAY_LOG(INFO) << v1.size() << " " << v2.size();
ASSERT_TRUE(v1.size() == v2.size());
for (int i = -1; ++i < (int)v1.size();) {
ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end());
}
v1.clear();
v2.clear();
RAY_LOG(INFO) << "\n"
<< "Test bidirectional push"
<< "\n";
for (int i = -1; ++i < 3;) {
ObjectID oid1 = WriteDataToClient(client1, 100);
ObjectID oid2 = WriteDataToClient(client2, 100);
status = server1->GetObjectManager().Push(oid1, client_id_2);
status = server2->GetObjectManager().Push(oid2, client_id_1);
}
sleep(1);
RAY_LOG(INFO) << v1.size() << " " << v2.size();
ASSERT_TRUE(v1.size() == v2.size());
for (int i = -1; ++i < (int)v1.size();) {
ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end());
}
v1.clear();
v2.clear();
ASSERT_TRUE(true);
}
} // namespace ray
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
ray::test_executable = std::string(argv[0]);
return RUN_ALL_TESTS();
}
+4
View File
@@ -2,8 +2,12 @@
namespace ray {
namespace raylet {
void ReconstructionPolicy::CheckObjectReconstruction(const ObjectID &object) {
throw std::runtime_error("Method not implemented");
}
} // namespace raylet
} // end namespace ray
+4
View File
@@ -7,6 +7,8 @@
namespace ray {
namespace raylet {
// TODO(swang): Use std::function instead of boost.
class ReconstructionPolicy {
@@ -29,6 +31,8 @@ class ReconstructionPolicy {
private:
};
} // namespace raylet
} // namespace ray
#endif // RAY_RAYLET_RECONSTRUCTION_POLICY_H
@@ -1,42 +0,0 @@
#include <iostream>
#include "ray/raylet/raylet.h"
/// A demo that starts two Raylets, with one object store each. The two Raylets
/// share a mock GCS client for communication between the two (e.g., for
/// ObjectManager::Push).
int main(int argc, char *argv[]) {
RAY_CHECK(argc == 3);
std::string store1 = "/tmp/store1";
std::string store2 = "/tmp/store2";
// start store
std::string plasma_dir = "../../plasma";
std::string plasma_command1 = plasma_dir + "/plasma_store -m 1000000000 -s ";
std::string plasma_command2 = " 1> /dev/null 2> /dev/null &";
RAY_LOG(INFO) << plasma_command1 << store1 << plasma_command2;
RAY_LOG(INFO) << plasma_command1 << store2 << plasma_command2;
int s;
s = system((plasma_command1 + store1 + plasma_command2).c_str());
RAY_CHECK(s == 0);
s = system((plasma_command1 + store2 + plasma_command2).c_str());
// configure
std::unordered_map<std::string, double> static_resource_conf;
static_resource_conf = {{"CPU", 1}, {"GPU", 1}};
ray::ResourceSet resource_config(std::move(static_resource_conf));
ray::ObjectManagerConfig om_config;
// initialize mock gcs & object directory
std::shared_ptr<ray::GcsClient> mock_gcs_client =
std::shared_ptr<ray::GcsClient>(new ray::GcsClient());
// Initialize the node manager.
boost::asio::io_service io_service;
om_config.store_socket_name = store1;
ray::Raylet server1(io_service, std::string(argv[1]), resource_config, om_config,
mock_gcs_client);
om_config.store_socket_name = store2;
ray::Raylet server2(io_service, std::string(argv[2]), resource_config, om_config,
mock_gcs_client);
io_service.run();
}
+53 -10
View File
@@ -1,32 +1,75 @@
#include "scheduling_policy.h"
#include "ray/util/logging.h"
namespace ray {
namespace raylet {
SchedulingPolicy::SchedulingPolicy(const SchedulingQueue &scheduling_queue)
: scheduling_queue_(scheduling_queue) {}
: scheduling_queue_(scheduling_queue), gen_(rd_()) {}
std::unordered_map<TaskID, ClientID, UniqueIDHasher> SchedulingPolicy::Schedule(
const std::unordered_map<ClientID, SchedulingResources, UniqueIDHasher>
&cluster_resources) {
static ClientID local_node_id = ClientID::nil();
&cluster_resources,
const ClientID &local_client_id, const std::vector<ClientID> &others) {
// The policy decision to be returned.
std::unordered_map<TaskID, ClientID, UniqueIDHasher> decision;
// TODO(atumanov): consider all cluster resources.
SchedulingResources resource_supply = cluster_resources.at(local_node_id);
const auto &resource_supply_set = resource_supply.GetAvailableResources();
// TODO(atumanov): protect DEBUG code blocks with ifdef DEBUG
RAY_LOG(DEBUG) << "[Schedule] cluster resource map: ";
for (const auto &client_resource_pair : cluster_resources) {
// pair = ClientID, SchedulingResources
const ClientID &client_id = client_resource_pair.first;
const SchedulingResources &resources = client_resource_pair.second;
RAY_LOG(DEBUG) << "client_id: " << client_id << " "
<< resources.GetAvailableResources().ToString();
}
// Iterate over running tasks, get their resource demand and try to schedule.
for (const auto &t : scheduling_queue_.GetReadyTasks()) {
// Get task's resource demand
const auto &resource_demand = t.GetTaskSpecification().GetRequiredResources();
bool task_feasible = resource_demand.IsSubset(resource_supply_set);
if (task_feasible) {
const TaskID &task_id = t.GetTaskSpecification().TaskId();
decision[task_id] = local_node_id;
const TaskID &task_id = t.GetTaskSpecification().TaskId();
RAY_LOG(DEBUG) << "[SchedulingPolicy]: task=" << task_id
<< " numforwards=" << t.GetTaskExecutionSpecReadonly().NumForwards()
<< " resources="
<< t.GetTaskSpecification().GetRequiredResources().ToString();
// TODO(atumanov): replace the simple spillback policy with exponential backoff based
// policy.
if (t.GetTaskExecutionSpecReadonly().NumForwards() >= 1) {
decision[task_id] = local_client_id;
continue;
}
// Construct a set of viable node candidates and randomly pick between them.
// Get all the client id keys and randomly pick.
std::vector<ClientID> client_keys;
for (const auto &client_resource_pair : cluster_resources) {
// pair = ClientID, SchedulingResources
ClientID node_client_id = client_resource_pair.first;
SchedulingResources node_resources = client_resource_pair.second;
RAY_LOG(DEBUG) << "client_id " << node_client_id << " resources: "
<< node_resources.GetAvailableResources().ToString();
if (resource_demand.IsSubset(node_resources.GetTotalResources())) {
// This node is a feasible candidate.
client_keys.push_back(node_client_id);
}
}
RAY_CHECK(!client_keys.empty());
// Choose index at random.
// Initialize a uniform integer distribution over the key space.
// TODO(atumanov): change uniform random to discrete, weighted by resource capacity.
std::uniform_int_distribution<int> distribution(0, client_keys.size() - 1);
int client_key_index = distribution(gen_);
decision[task_id] = client_keys[client_key_index];
RAY_LOG(DEBUG) << "[SchedulingPolicy] idx=" << client_key_index << " " << task_id
<< " --> " << client_keys[client_key_index];
}
return decision;
}
SchedulingPolicy::~SchedulingPolicy() {}
} // namespace raylet
} // namespace ray
+11 -1
View File
@@ -1,6 +1,7 @@
#ifndef RAY_RAYLET_SCHEDULING_POLICY_H
#define RAY_RAYLET_SCHEDULING_POLICY_H
#include <random>
#include <unordered_map>
#include "ray/raylet/scheduling_queue.h"
@@ -8,6 +9,8 @@
namespace ray {
namespace raylet {
/// \class SchedulingPolicy
/// \brief Implements a scheduling policy for the node manager.
class SchedulingPolicy {
@@ -27,7 +30,8 @@ class SchedulingPolicy {
/// \return Scheduling decision, mapping tasks to node managers for placement.
std::unordered_map<TaskID, ClientID, UniqueIDHasher> Schedule(
const std::unordered_map<ClientID, SchedulingResources, UniqueIDHasher>
&cluster_resources);
&cluster_resources,
const ClientID &local_client_id, const std::vector<ClientID> &others);
/// \brief SchedulingPolicy destructor.
virtual ~SchedulingPolicy();
@@ -35,8 +39,14 @@ class SchedulingPolicy {
private:
/// An immutable reference to the scheduling task queues.
const SchedulingQueue &scheduling_queue_;
/// Internally maintained random number engine device.
std::random_device rd_;
/// Internally maintained random number generator.
std::mt19937_64 gen_;
};
} // namespace raylet
} // namespace ray
#endif // RAY_RAYLET_SCHEDULING_POLICY_H
+4
View File
@@ -4,6 +4,8 @@
namespace ray {
namespace raylet {
const std::list<Task> &SchedulingQueue::GetWaitingTasks() const {
return this->waiting_tasks_;
}
@@ -88,4 +90,6 @@ bool SchedulingQueue::RegisterActor(ActorID actor_id,
return true;
}
} // namespace raylet
} // namespace ray
+5
View File
@@ -11,6 +11,8 @@
namespace ray {
namespace raylet {
/// \class SchedulingQueue
///
/// Encapsulates task queues. Each queue represents a scheduling state for a
@@ -103,6 +105,9 @@ class SchedulingQueue {
/// The registry of known actors.
std::unordered_map<ActorID, ActorInformation, UniqueIDHasher> actor_registry_;
};
} // namespace raylet
} // namespace ray
#endif // RAY_RAYLET_SCHEDULING_QUEUE_H
+64 -3
View File
@@ -2,13 +2,25 @@
#include <cmath>
#include "ray/util/logging.h"
namespace ray {
namespace raylet {
ResourceSet::ResourceSet() {}
ResourceSet::ResourceSet(const std::unordered_map<std::string, double> &resource_map)
: resource_capacity_(resource_map) {}
ResourceSet::ResourceSet(const std::vector<std::string> &resource_labels,
const std::vector<double> resource_capacity) {
RAY_CHECK(resource_labels.size() == resource_capacity.size());
for (uint i = 0; i < resource_labels.size(); i++) {
RAY_CHECK(this->AddResource(resource_labels[i], resource_capacity[i]));
}
}
ResourceSet::~ResourceSet() {}
bool ResourceSet::operator==(const ResourceSet &rhs) const {
@@ -43,16 +55,42 @@ bool ResourceSet::IsEqual(const ResourceSet &rhs) const {
}
bool ResourceSet::AddResource(const std::string &resource_name, double capacity) {
throw std::runtime_error("Method not implemented");
this->resource_capacity_[resource_name] = capacity;
return true;
}
bool ResourceSet::RemoveResource(const std::string &resource_name) {
throw std::runtime_error("Method not implemented");
}
bool ResourceSet::SubtractResources(const ResourceSet &other) {
throw std::runtime_error("Method not implemented");
// Return failure if attempting to perform vector subtraction with unknown labels.
// TODO(atumanov): make the implementation atomic. Currently, if false is returned
// the resource capacity may be partially mutated. To reverse, call AddResources.
for (const auto &resource_pair : other.GetResourceMap()) {
const std::string &resource_label = resource_pair.first;
const double &resource_capacity = resource_pair.second;
if (resource_capacity_.count(resource_label) == 0) {
return false;
} else {
resource_capacity_[resource_label] -= resource_capacity;
}
}
return true;
}
bool ResourceSet::AddResources(const ResourceSet &other) {
throw std::runtime_error("Method not implemented");
// Return failure if attempting to perform vector addition with unknown labels.
// TODO(atumanov): make the implementation atomic. Currently, if false is returned
// the resource capacity may be partially mutated. To reverse, call SubtractResources.
for (const auto &resource_pair : other.GetResourceMap()) {
const std::string &resource_label = resource_pair.first;
const double &resource_capacity = resource_pair.second;
if (resource_capacity_.count(resource_label) == 0) {
return false;
} else {
resource_capacity_[resource_label] += resource_capacity;
}
}
return true;
}
bool ResourceSet::GetResource(const std::string &resource_name, double *value) const {
@@ -67,6 +105,19 @@ bool ResourceSet::GetResource(const std::string &resource_name, double *value) c
return true;
}
const std::string ResourceSet::ToString() const {
std::string return_string = "";
for (const auto &resource_pair : this->resource_capacity_) {
return_string +=
"{" + resource_pair.first + "," + std::to_string(resource_pair.second) + "}, ";
}
return return_string;
}
const std::unordered_map<std::string, double> &ResourceSet::GetResourceMap() const {
return this->resource_capacity_;
};
/// SchedulingResources class implementation
SchedulingResources::SchedulingResources()
@@ -93,6 +144,14 @@ const ResourceSet &SchedulingResources::GetAvailableResources() const {
return this->resources_available_;
}
void SchedulingResources::SetAvailableResources(ResourceSet &&newset) {
this->resources_available_ = newset;
}
const ResourceSet &SchedulingResources::GetTotalResources() const {
return this->resources_total_;
}
// Return specified resources back to SchedulingResources.
bool SchedulingResources::Release(const ResourceSet &resources) {
return this->resources_available_.AddResources(resources);
@@ -103,4 +162,6 @@ bool SchedulingResources::Acquire(const ResourceSet &resources) {
return this->resources_available_.SubtractResources(resources);
}
} // namespace raylet
} // namespace ray
+23
View File
@@ -4,9 +4,12 @@
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace ray {
namespace raylet {
/// Resource availability status reports whether the resource requirement is
/// (1) infeasible, (2) feasible but currently unavailable, or (3) available.
typedef enum {
@@ -26,6 +29,11 @@ class ResourceSet {
/// \brief Constructs ResourceSet from the specified resource map.
ResourceSet(const std::unordered_map<std::string, double> &resource_map);
/// \brief Constructs ResourceSet from two equal-length vectors with label and capacity
/// specification.
ResourceSet(const std::vector<std::string> &resource_labels,
const std::vector<double> resource_capacity);
/// \brief Empty ResourceSet destructor.
~ResourceSet();
@@ -89,6 +97,11 @@ class ResourceSet {
/// False otherwise.
bool GetResource(const std::string &resource_name, double *value) const;
// TODO(atumanov): implement const_iterator class for the ResourceSet container.
const std::unordered_map<std::string, double> &GetResourceMap() const;
const std::string ToString() const;
private:
/// Resource capacity map.
std::unordered_map<std::string, double> resource_capacity_;
@@ -125,6 +138,14 @@ class SchedulingResources {
/// \return Immutable set of resources with currently available capacity.
const ResourceSet &GetAvailableResources() const;
/// \brief Overwrite available resource capacity with the specified resource set.
///
/// \param newset: The set of resources that replaces available resource capacity.
/// \return None.
void SetAvailableResources(ResourceSet &&newset);
const ResourceSet &GetTotalResources() const;
/// \brief Release the amount of resources specified.
///
/// \param resources: the amount of resources to be released.
@@ -145,6 +166,8 @@ class SchedulingResources {
/// gpu_map - replace with ResourceMap (for generality).
};
} // namespace raylet
} // namespace ray
#endif // RAY_RAYLET_SCHEDULING_RESOURCES_H
+14 -1
View File
@@ -2,7 +2,18 @@
namespace ray {
const TaskExecutionSpecification &Task::GetTaskExecutionSpec() const {
namespace raylet {
flatbuffers::Offset<protocol::Task> Task::ToFlatbuffer(
flatbuffers::FlatBufferBuilder &fbb) const {
auto task = CreateTask(fbb, task_spec_.ToFlatbuffer(fbb),
task_execution_spec_.ToFlatbuffer(fbb));
return task;
}
TaskExecutionSpecification &Task::GetTaskExecutionSpec() { return task_execution_spec_; }
const TaskExecutionSpecification &Task::GetTaskExecutionSpecReadonly() const {
return task_execution_spec_;
}
@@ -47,4 +58,6 @@ bool Task::DependsOn(const ObjectID &object_id) const {
return false;
}
} // namespace raylet
} // namespace ray
+23 -2
View File
@@ -3,11 +3,14 @@
#include <inttypes.h>
#include "ray/raylet/format/node_manager_generated.h"
#include "ray/raylet/task_execution_spec.h"
#include "ray/raylet/task_spec.h"
namespace ray {
namespace raylet {
/// \class Task
///
/// A Task represents a Ray task and a specification of its execution (e.g.,
@@ -27,13 +30,29 @@ class Task {
const TaskSpecification &task_spec)
: task_execution_spec_(execution_spec), task_spec_(task_spec) {}
/// Create a task from a serialized flatbuffer.
///
/// \param task_flatbuffer The serialized task.
Task(const protocol::Task &task_flatbuffer)
: task_execution_spec_(*task_flatbuffer.task_execution_spec()),
task_spec_(*task_flatbuffer.task_specification()) {}
/// Destroy the task.
virtual ~Task() {}
/// Serialize a task to a flatbuffer.
///
/// \param fbb The flatbuffer builder.
/// \return An offset to the serialized task.
flatbuffers::Offset<protocol::Task> ToFlatbuffer(
flatbuffers::FlatBufferBuilder &fbb) const;
/// Get the execution specification for the task.
///
/// \return The mutable specification for the task.
const TaskExecutionSpecification &GetTaskExecutionSpec() const;
TaskExecutionSpecification &GetTaskExecutionSpec();
const TaskExecutionSpecification &GetTaskExecutionSpecReadonly() const;
/// Get the immutable specification for the task.
///
@@ -64,6 +83,8 @@ class Task {
TaskSpecification task_spec_;
};
} // namespace raylet
} // namespace ray
#endif // RAY_RAYLET_TASK_H
#endif // RAY_RAYLET_TASK_H
@@ -2,6 +2,8 @@
namespace ray {
namespace raylet {
TaskDependencyManager::TaskDependencyManager(
ObjectManager &object_manager,
// ReconstructionPolicy &reconstruction_policy,
@@ -59,6 +61,8 @@ bool TaskDependencyManager::TaskReady(const Task &task) const {
}
void TaskDependencyManager::SubscribeTaskReady(const Task &task) {
// TODO(swang): Don't pull arguments that are going to be created by a queued
// or running task.
TaskID task_id = task.GetTaskSpecification().TaskId();
const std::vector<ObjectID> arguments = task.GetDependencies();
// Add the task's arguments to the table of subscribed tasks.
@@ -104,4 +108,6 @@ void TaskDependencyManager::MarkDependencyReady(const ObjectID &object) {
throw std::runtime_error("Method not implemented");
}
} // namespace raylet
} // namespace ray
+4
View File
@@ -10,6 +10,8 @@
namespace ray {
namespace raylet {
class ReconstructionPolicy;
/// \class TaskDependencyManager
@@ -77,6 +79,8 @@ class TaskDependencyManager {
std::function<void(const TaskID &)> task_ready_callback_;
};
} // namespace raylet
} // namespace ray
#endif // RAY_RAYLET_TASK_DEPENDENCY_MANAGER_H
+42 -20
View File
@@ -2,35 +2,57 @@
namespace ray {
TaskExecutionSpecification::TaskExecutionSpecification(
const std::vector<ObjectID> &&execution_dependencies)
: execution_dependencies_(std::move(execution_dependencies)),
last_timestamp_(0),
spillback_count_(0) {}
namespace raylet {
TaskExecutionSpecification::TaskExecutionSpecification(
const std::vector<ObjectID> &&execution_dependencies, int spillback_count)
: execution_dependencies_(std::move(execution_dependencies)),
last_timestamp_(0),
spillback_count_(spillback_count) {}
const std::vector<ObjectID> &&dependencies) {
SetExecutionDependencies(dependencies);
}
const std::vector<ObjectID> &TaskExecutionSpecification::ExecutionDependencies() const {
return execution_dependencies_;
TaskExecutionSpecification::TaskExecutionSpecification(
const std::vector<ObjectID> &&dependencies, int num_forwards) {
// TaskExecutionSpecification(std::move(dependencies));
SetExecutionDependencies(dependencies);
execution_spec_.num_forwards = num_forwards;
}
flatbuffers::Offset<protocol::TaskExecutionSpecification>
TaskExecutionSpecification::ToFlatbuffer(flatbuffers::FlatBufferBuilder &fbb) const {
fbb.ForceDefaults(true);
return protocol::TaskExecutionSpecification::Pack(fbb, &execution_spec_);
}
std::vector<ObjectID> TaskExecutionSpecification::ExecutionDependencies() const {
std::vector<ObjectID> dependencies;
for (const auto &dependency : execution_spec_.dependencies) {
dependencies.push_back(ObjectID::from_binary(dependency));
}
return dependencies;
}
void TaskExecutionSpecification::SetExecutionDependencies(
const std::vector<ObjectID> &dependencies) {
execution_dependencies_ = dependencies;
for (const auto &dependency : dependencies) {
execution_spec_.dependencies.push_back(dependency.binary());
}
}
int TaskExecutionSpecification::SpillbackCount() const { return spillback_count_; }
void TaskExecutionSpecification::IncrementSpillbackCount() { ++spillback_count_; }
int64_t TaskExecutionSpecification::LastTimeStamp() const { return last_timestamp_; }
void TaskExecutionSpecification::SetLastTimeStamp(int64_t new_timestamp) {
last_timestamp_ = new_timestamp;
int TaskExecutionSpecification::NumForwards() const {
return execution_spec_.num_forwards;
}
void TaskExecutionSpecification::IncrementNumForwards() {
execution_spec_.num_forwards += 1;
}
int64_t TaskExecutionSpecification::LastTimestamp() const {
return execution_spec_.last_timestamp;
}
void TaskExecutionSpecification::SetLastTimestamp(int64_t new_timestamp) {
execution_spec_.last_timestamp = new_timestamp;
}
} // namespace raylet
} // namespace ray
+41 -28
View File
@@ -4,9 +4,12 @@
#include <vector>
#include "ray/id.h"
#include "ray/raylet/format/node_manager_generated.h"
namespace ray {
namespace raylet {
/// \class TaskExecutionSpecification
///
/// The task execution specification encapsulates all mutable information about
@@ -16,60 +19,70 @@ class TaskExecutionSpecification {
public:
/// Create a task execution specification.
///
/// \param execution_dependencies The task's dependencies, determined at
/// execution time.
TaskExecutionSpecification(const std::vector<ObjectID> &&execution_dependencies);
/// \param dependencies The task's dependencies, determined at execution
/// time.
TaskExecutionSpecification(const std::vector<ObjectID> &&dependencies);
/// Create a task execution specification.
///
/// \param execution_dependencies The task's dependencies, determined at
/// execution time.
/// \param spillback_count The number of times this task was spilled back by
/// local schedulers.
TaskExecutionSpecification(const std::vector<ObjectID> &&execution_dependencies,
int spillback_count);
/// \param dependencies The task's dependencies, determined at execution
/// time.
/// \param num_forwards The number of times this task has been forwarded by a
/// node manager.
TaskExecutionSpecification(const std::vector<ObjectID> &&dependencies,
int num_forwards);
/// Create a task execution specification from a serialized flatbuffer.
///
/// \param spec_flatbuffer The serialized specification.
TaskExecutionSpecification(
const protocol::TaskExecutionSpecification &spec_flatbuffer) {
spec_flatbuffer.UnPackTo(&execution_spec_);
}
/// Serialize a task execution specification to a flatbuffer.
///
/// \param fbb The flatbuffer builder.
/// \return An offset to the serialized task execution specification.
flatbuffers::Offset<protocol::TaskExecutionSpecification> ToFlatbuffer(
flatbuffers::FlatBufferBuilder &fbb) const;
/// Get the task's execution dependencies.
///
/// \return A vector of object IDs representing this task's execution
/// dependencies.
const std::vector<ObjectID> &ExecutionDependencies() const;
/// dependencies.
std::vector<ObjectID> ExecutionDependencies() const;
/// Set the task's execution dependencies.
///
/// \param dependencies The value to set the execution dependencies to.
void SetExecutionDependencies(const std::vector<ObjectID> &dependencies);
/// Get the task's spillback count, which tracks the number of times
/// this task was spilled back from local to the global scheduler.
/// Get the number of times this task has been forwarded.
///
/// \return The spillback count for this task.
int SpillbackCount() const;
/// \return The number of times this task has been forwarded.
int NumForwards() const;
/// Increment the spillback count for this task.
void IncrementSpillbackCount();
/// Increment the number of times this task has been forwarded.
void IncrementNumForwards();
/// Get the task's last timestamp.
///
/// \return The timestamp when this task was last received for scheduling.
int64_t LastTimeStamp() const;
int64_t LastTimestamp() const;
/// Set the task's last timestamp to the specified value.
///
/// \param new_timestamp The new timestamp in millisecond to set the task's
/// time stamp to. Tracks the last time this task entered a local
/// scheduler.
void SetLastTimeStamp(int64_t new_timestamp);
/// time stamp to. Tracks the last time this task entered a local scheduler.
void SetLastTimestamp(int64_t new_timestamp);
private:
/// A list of object IDs representing the dependencies of this task that may
/// change at execution time.
std::vector<ObjectID> execution_dependencies_;
/// The last time this task was received for scheduling.
int64_t last_timestamp_;
/// The number of times this task was spilled back by local schedulers.
int spillback_count_;
protocol::TaskExecutionSpecificationT execution_spec_;
};
} // namespace raylet
} // namespace ray
#endif // RAY_RAYLET_TASK_EXECUTION_SPECIFICATION_H
+30 -43
View File
@@ -5,6 +5,8 @@
namespace ray {
namespace raylet {
TaskArgument::~TaskArgument() {}
TaskArgumentByReference::TaskArgumentByReference(const std::vector<ObjectID> &references)
@@ -15,14 +17,6 @@ flatbuffers::Offset<Arg> TaskArgumentByReference::ToFlatbuffer(
return CreateArg(fbb, to_flatbuf(fbb, references_));
}
const BYTE *TaskArgumentByReference::HashData() const {
return reinterpret_cast<const BYTE *>(references_.data());
}
size_t TaskArgumentByReference::HashDataLength() const {
return references_.size() * sizeof(ObjectID);
}
TaskArgumentByValue::TaskArgumentByValue(const uint8_t *value, size_t length) {
value_.assign(value, value + length);
}
@@ -35,27 +29,12 @@ flatbuffers::Offset<Arg> TaskArgumentByValue::ToFlatbuffer(
return CreateArg(fbb, empty_ids, arg);
}
const BYTE *TaskArgumentByValue::HashData() const { return value_.data(); }
size_t TaskArgumentByValue::HashDataLength() const { return value_.size(); }
static const ObjectID task_compute_return_id(TaskID task_id, int64_t return_index) {
// Here, return_indices need to be >= 0, so we can use negative indices for put.
RAY_DCHECK(return_index >= 0);
// TODO(rkn): This line requires object and task IDs to be the same size.
ObjectID return_id = task_id;
int64_t *first_bytes = (int64_t *)&return_id;
// XOR the first bytes of the object ID with the return index.
// We add one so the first return ID is not the same as the task ID.
*first_bytes = *first_bytes ^ (return_index + 1);
return return_id;
void TaskSpecification::AssignSpecification(const uint8_t *spec, size_t spec_size) {
spec_.assign(spec, spec + spec_size);
}
TaskSpecification::TaskSpecification(const uint8_t *spec, size_t spec_size)
: spec_(spec, spec + spec_size) {}
TaskSpecification::TaskSpecification(const flatbuffers::String &string)
: TaskSpecification(reinterpret_cast<const uint8_t *>(string.data()), string.size()) {
TaskSpecification::TaskSpecification(const flatbuffers::String &string) {
AssignSpecification(reinterpret_cast<const uint8_t *>(string.data()), string.size());
}
TaskSpecification::TaskSpecification(
@@ -63,9 +42,10 @@ TaskSpecification::TaskSpecification(
// UniqueID actor_id,
// UniqueID actor_handle_id,
// int64_t actor_counter,
FunctionID function_id, const std::vector<TaskArgument> &task_arguments,
int64_t num_returns,
const std::unordered_map<std::string, double> &required_resources) {
FunctionID function_id,
const std::vector<std::shared_ptr<TaskArgument>> &task_arguments, int64_t num_returns,
const std::unordered_map<std::string, double> &required_resources)
: spec_() {
flatbuffers::FlatBufferBuilder fbb;
// Compute hashes.
@@ -78,14 +58,6 @@ TaskSpecification::TaskSpecification(
// sha256_update(&ctx, (BYTE *) &actor_counter, sizeof(actor_counter));
// sha256_update(&ctx, (BYTE *) &is_actor_checkpoint_method,
// sizeof(is_actor_checkpoint_method));
sha256_update(&ctx, (BYTE *)&function_id, sizeof(function_id));
// Serialize and hash the arguments.
std::vector<flatbuffers::Offset<Arg>> arguments;
for (auto &argument : task_arguments) {
arguments.push_back(argument.ToFlatbuffer(fbb));
sha256_update(&ctx, (BYTE *)argument.HashData(), argument.HashDataLength());
}
// Compute the final task ID from the hash.
BYTE buff[DIGEST_SIZE];
@@ -93,11 +65,17 @@ TaskSpecification::TaskSpecification(
TaskID task_id;
RAY_DCHECK(sizeof(task_id) <= DIGEST_SIZE);
memcpy(&task_id, buff, sizeof(task_id));
task_id = FinishTaskId(task_id);
// Add argument object IDs.
std::vector<flatbuffers::Offset<Arg>> arguments;
for (auto &argument : task_arguments) {
arguments.push_back(argument->ToFlatbuffer(fbb));
}
// Add return object IDs.
std::vector<flatbuffers::Offset<flatbuffers::String>> returns;
for (int64_t i = 0; i < num_returns; i++) {
ObjectID return_id = task_compute_return_id(task_id, i);
for (int64_t i = 1; i < num_returns + 1; i++) {
ObjectID return_id = ComputeReturnId(task_id, i);
returns.push_back(to_flatbuf(fbb, return_id));
}
@@ -110,7 +88,7 @@ TaskSpecification::TaskSpecification(
fbb.CreateVector(arguments), fbb.CreateVector(returns),
map_to_flatbuf(fbb, required_resources));
fbb.Finish(spec);
TaskSpecification(fbb.GetBufferPointer(), fbb.GetSize());
AssignSpecification(fbb.GetBufferPointer(), fbb.GetSize());
}
flatbuffers::Offset<flatbuffers::String> TaskSpecification::ToFlatbuffer(
@@ -130,7 +108,8 @@ TaskID TaskSpecification::TaskId() const {
return from_flatbuf(*message->task_id());
}
UniqueID TaskSpecification::DriverId() const {
throw std::runtime_error("Method not implemented");
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf(*message->driver_id());
}
TaskID TaskSpecification::ParentTaskId() const {
throw std::runtime_error("Method not implemented");
@@ -148,7 +127,13 @@ int64_t TaskSpecification::NumArgs() const {
}
int64_t TaskSpecification::NumReturns() const {
throw std::runtime_error("Method not implemented");
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return message->returns()->size();
}
ObjectID TaskSpecification::ReturnId(int64_t return_index) const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf(*message->returns()->Get(return_index));
}
bool TaskSpecification::ArgByRef(int64_t arg_index) const {
@@ -180,4 +165,6 @@ const ResourceSet TaskSpecification::GetRequiredResources() const {
return ResourceSet(required_resources);
}
} // namespace raylet
} // namespace ray
+14 -18
View File
@@ -6,7 +6,7 @@
#include <unordered_map>
#include <vector>
#include "ray/../common/format/common_generated.h"
#include "format/common_generated.h"
#include "ray/id.h"
#include "ray/raylet/scheduling_resources.h"
@@ -16,6 +16,8 @@ extern "C" {
namespace ray {
namespace raylet {
/// \class TaskArgument
///
/// A virtual class that represents an argument to a task.
@@ -28,16 +30,6 @@ class TaskArgument {
virtual flatbuffers::Offset<Arg> ToFlatbuffer(
flatbuffers::FlatBufferBuilder &fbb) const = 0;
/// Get the hashable byte data.
///
/// \return A pointer to the byte data.
virtual const BYTE *HashData() const = 0;
/// Get the hashable byte data length.
///
/// \return The length of the hashable byte data.
virtual size_t HashDataLength() const = 0;
virtual ~TaskArgument() = 0;
};
@@ -45,14 +37,15 @@ class TaskArgument {
///
/// A task argument consisting of a list of object ID references.
class TaskArgumentByReference : virtual public TaskArgument {
public:
/// Create a task argument by reference from a list of object IDs.
///
/// \param references A list of object ID references.
TaskArgumentByReference(const std::vector<ObjectID> &references);
~TaskArgumentByReference(){};
flatbuffers::Offset<Arg> ToFlatbuffer(flatbuffers::FlatBufferBuilder &fbb) const;
const BYTE *HashData() const;
size_t HashDataLength() const;
private:
/// The object IDs.
@@ -63,6 +56,7 @@ class TaskArgumentByReference : virtual public TaskArgument {
///
/// A task argument containing the raw value.
class TaskArgumentByValue : public TaskArgument {
public:
/// Create a task argument from a raw value.
///
/// \param value A pointer to the raw value.
@@ -70,8 +64,6 @@ class TaskArgumentByValue : public TaskArgument {
TaskArgumentByValue(const uint8_t *value, size_t length);
flatbuffers::Offset<Arg> ToFlatbuffer(flatbuffers::FlatBufferBuilder &fbb) const;
const BYTE *HashData() const;
size_t HashDataLength() const;
private:
/// The raw value.
@@ -106,7 +98,8 @@ class TaskSpecification {
// UniqueID actor_id,
// UniqueID actor_handle_id,
// int64_t actor_counter,
FunctionID function_id, const std::vector<TaskArgument> &arguments,
FunctionID function_id,
const std::vector<std::shared_ptr<TaskArgument>> &arguments,
int64_t num_returns,
const std::unordered_map<std::string, double> &required_resources);
@@ -130,14 +123,15 @@ class TaskSpecification {
bool ArgByRef(int64_t arg_index) const;
int ArgIdCount(int64_t arg_index) const;
ObjectID ArgId(int64_t arg_index, int64_t id_index) const;
ObjectID ReturnId(int64_t return_index) const;
const uint8_t *ArgVal(int64_t arg_index) const;
size_t ArgValLength(int64_t arg_index) const;
double GetRequiredResource(const std::string &resource_name) const;
const ResourceSet GetRequiredResources() const;
private:
/// Task specification constructor from a pointer.
TaskSpecification(const uint8_t *spec, size_t spec_size);
/// Assign the specification data from a pointer.
void AssignSpecification(const uint8_t *spec, size_t spec_size);
/// Get a pointer to the byte data.
const uint8_t *data() const;
/// Get the size in bytes of the task specification.
@@ -147,6 +141,8 @@ class TaskSpecification {
std::vector<uint8_t> spec_;
};
} // namespace raylet
} // namespace ray
#endif // RAY_RAYLET_TASK_SPECIFICATION_H
+45
View File
@@ -0,0 +1,45 @@
#include "gtest/gtest.h"
#include "ray/raylet/task_spec.h"
namespace ray {
namespace raylet {
void TestTaskReturnId(const TaskID &task_id, int64_t return_index) {
// Round trip test for computing the object ID for a task's return value,
// then computing the task ID that created the object.
ObjectID return_id = ComputeReturnId(task_id, return_index);
ASSERT_EQ(ComputeTaskId(return_id), task_id);
ASSERT_EQ(ComputeObjectIndex(return_id), return_index);
}
void TestTaskPutId(const TaskID &task_id, int64_t put_index) {
// Round trip test for computing the object ID for a task's put value, then
// computing the task ID that created the object.
ObjectID put_id = ComputePutId(task_id, put_index);
ASSERT_EQ(ComputeTaskId(put_id), task_id);
ASSERT_EQ(ComputeObjectIndex(put_id), -1 * put_index);
}
TEST(TaskSpecTest, TestTaskReturnIds) {
TaskID task_id = FinishTaskId(TaskID::from_random());
// Check that we can compute between a task ID and the object IDs of its
// return values and puts.
TestTaskReturnId(task_id, 1);
TestTaskReturnId(task_id, 2);
TestTaskReturnId(task_id, kMaxTaskReturns);
TestTaskPutId(task_id, 1);
TestTaskPutId(task_id, 2);
TestTaskPutId(task_id, kMaxTaskPuts);
}
} // namespace raylet
} // namespace ray
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
+4
View File
@@ -8,6 +8,8 @@
namespace ray {
namespace raylet {
/// A constructor responsible for initializing the state of a worker.
Worker::Worker(pid_t pid, std::shared_ptr<LocalClientConnection> connection)
: pid_(pid), connection_(connection), assigned_task_id_(TaskID::nil()) {}
@@ -22,4 +24,6 @@ const std::shared_ptr<LocalClientConnection> Worker::Connection() const {
return connection_;
}
} // namespace raylet
} // end namespace ray
+4
View File
@@ -8,6 +8,8 @@
namespace ray {
namespace raylet {
/// Worker class encapsulates the implementation details of a worker. A worker
/// is the execution container around a unit of Ray work, such as a task or an
/// actor. Ray units of work execute in the context of a Worker.
@@ -32,6 +34,8 @@ class Worker {
TaskID assigned_task_id_;
};
} // namespace raylet
} // namespace ray
#endif // RAY_RAYLET_WORKER_H
+27 -5
View File
@@ -5,8 +5,15 @@
namespace ray {
namespace raylet {
/// A constructor that initializes a worker pool with num_workers workers.
WorkerPool::WorkerPool(int num_workers) {
WorkerPool::WorkerPool(int num_workers, const std::vector<const char *> &worker_command)
: worker_command_(worker_command) {
worker_command_.push_back(NULL);
// Ignore SIGCHLD signals. If we don't do this, then worker processes will
// become zombies instead of dying gracefully.
signal(SIGCHLD, SIG_IGN);
for (int i = 0; i < num_workers; i++) {
StartWorker();
}
@@ -18,10 +25,23 @@ WorkerPool::~WorkerPool() {
registered_workers_.clear();
}
/// Create a new worker and add it to the pool
bool WorkerPool::StartWorker() {
// TODO(swang): Start the worker.
return true;
void WorkerPool::StartWorker() {
RAY_CHECK(!worker_command_.empty()) << "No worker command provided";
// Launch the process to create the worker.
pid_t pid = fork();
if (pid != 0) {
RAY_LOG(DEBUG) << "Started worker with pid " << pid;
return;
}
// Reset the SIGCHLD handler for the worker.
signal(SIGCHLD, SIG_DFL);
// Try to execute the worker command.
int rv = execvp(worker_command_[0], (char *const *)worker_command_.data());
// The worker failed to start. This is a fatal error.
RAY_LOG(FATAL) << "Failed to start worker with return value " << rv;
}
uint32_t WorkerPool::PoolSize() const { return pool_.size(); }
@@ -75,4 +95,6 @@ bool WorkerPool::DisconnectWorker(std::shared_ptr<Worker> worker) {
return removeWorker(pool_, worker);
}
} // namespace raylet
} // namespace ray
+10 -5
View File
@@ -9,6 +9,8 @@
namespace ray {
namespace raylet {
class Worker;
/// \class WorkerPool
@@ -23,7 +25,7 @@ class WorkerPool {
/// pool.
///
/// \param num_workers The number of workers to start.
WorkerPool(int num_workers);
WorkerPool(int num_workers, const std::vector<const char *> &worker_command);
/// Destructor responsible for freeing a set of workers owned by this class.
~WorkerPool();
@@ -35,10 +37,9 @@ class WorkerPool {
/// Asynchronously start a new worker process. Once the worker process has
/// registered with an external server, the process should create and
/// register a new Worker, then add itself to the pool.
///
/// \return Whether the worker process was successfully started.
bool StartWorker();
/// register a new Worker, then add itself to the pool. Failure to start
/// the worker process is a fatal error.
void StartWorker();
/// Register a new worker. The Worker should be added by the caller to the
/// pool after it becomes idle (e.g., requests a work assignment).
@@ -73,6 +74,7 @@ class WorkerPool {
std::shared_ptr<Worker> PopWorker();
private:
std::vector<const char *> worker_command_;
/// The pool of idle workers.
std::list<std::shared_ptr<Worker>> pool_;
/// All workers that have registered and are still connected, including both
@@ -80,6 +82,9 @@ class WorkerPool {
// TODO(swang): Make this a map to make GetRegisteredWorker faster.
std::list<std::shared_ptr<Worker>> registered_workers_;
};
} // namespace raylet
} // namespace ray
#endif // RAY_RAYLET_WORKER_POOL_H
+19 -9
View File
@@ -6,27 +6,35 @@
namespace ray {
class MockClientManager : public ClientManager<boost::asio::local::stream_protocol> {
public:
MOCK_METHOD3(ProcessClientMessage,
void(std::shared_ptr<LocalClientConnection>, int64_t, const uint8_t *));
MOCK_METHOD1(ProcessNewClient, void(std::shared_ptr<LocalClientConnection>));
};
namespace raylet {
class WorkerPoolTest : public ::testing::Test {
public:
WorkerPoolTest() : worker_pool_(0), client_manager_(), io_service_() {}
WorkerPoolTest() : worker_pool_(0, {}), io_service_() {}
std::shared_ptr<Worker> CreateWorker(pid_t pid) {
std::function<void(std::shared_ptr<LocalClientConnection>)> client_handler =
[this](std::shared_ptr<LocalClientConnection> client) {
HandleNewClient(client);
};
std::function<void(std::shared_ptr<LocalClientConnection>, int64_t, const uint8_t *)>
message_handler = [this](std::shared_ptr<LocalClientConnection> client,
int64_t message_type, const uint8_t *message) {
HandleMessage(client, message_type, message);
};
boost::asio::local::stream_protocol::socket socket(io_service_);
auto client = LocalClientConnection::Create(client_manager_, std::move(socket));
auto client =
LocalClientConnection::Create(client_handler, message_handler, std::move(socket));
return std::shared_ptr<Worker>(new Worker(pid, client));
}
protected:
WorkerPool worker_pool_;
MockClientManager client_manager_;
boost::asio::io_service io_service_;
private:
void HandleNewClient(std::shared_ptr<LocalClientConnection>){};
void HandleMessage(std::shared_ptr<LocalClientConnection>, int64_t, const uint8_t *){};
};
TEST_F(WorkerPoolTest, HandleWorkerRegistration) {
@@ -66,6 +74,8 @@ TEST_F(WorkerPoolTest, HandleWorkerPushPop) {
ASSERT_TRUE(workers.count(popped_worker) > 0);
}
} // namespace raylet
} // namespace ray
int main(int argc, char **argv) {
+2
View File
@@ -4,6 +4,7 @@
# Cause the script to exit if a single command fails.
set -e
set -x
# Start Redis.
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 &
@@ -12,3 +13,4 @@ sleep 1s
./src/ray/gcs/client_test
./src/common/thirdparty/redis/src/redis-cli -p 6379 shutdown
sleep 1s
+46
View File
@@ -0,0 +1,46 @@
#!/usr/bin/env bash
# This needs to be run in the build tree, which is normally ray/python/ray/core
# Cause the script to exit if a single command fails.
set -e
set -x
# Get the directory in which this script is executing.
SCRIPT_DIR="`dirname \"$0\"`"
RAY_ROOT="$SCRIPT_DIR/../../.."
RAY_ROOT="`( cd \"$RAY_ROOT\" && pwd )`"
if [ -z "$RAY_ROOT" ] ; then
exit 1
fi
# Ensure we're in the right directory.
if [ ! -d "$RAY_ROOT/python" ]; then
echo "Unable to find root Ray directory. Has this script moved?"
exit 1
fi
CORE_DIR="$RAY_ROOT/python/ray/core"
REDIS_DIR="$CORE_DIR/src/common/thirdparty/redis/src"
REDIS_MODULE="$CORE_DIR/src/common/redis_module/libray_redis_module.so"
STORE_EXEC="$CORE_DIR/src/plasma/plasma_store"
echo "$STORE_EXEC"
echo "$REDIS_DIR/redis-server --loglevel warning --loadmodule $REDIS_MODULE --port 6379"
echo "$REDIS_DIR/redis-cli -p 6379 shutdown"
# Allow cleanup commands to fail.
killall plasma_store || true
$REDIS_DIR/redis-cli -p 6379 shutdown || true
sleep 1s
$REDIS_DIR/redis-server --loglevel warning --loadmodule $REDIS_MODULE --port 6379 &
sleep 1s
# Run tests.
$CORE_DIR/src/ray/object_manager/object_manager_stress_test $STORE_EXEC
sleep 1s
$CORE_DIR/src/ray/object_manager/object_manager_test $STORE_EXEC
$REDIS_DIR/redis-cli -p 6379 shutdown
sleep 1s
# Include raylet integration test once it's ready.
# $CORE_DIR/src/ray/raylet/object_manager_integration_test $STORE_EXEC
+23
View File
@@ -0,0 +1,23 @@
#!/usr/bin/env bash
# This needs to be run in the build tree, which is normally ray/python/ray/core
# Cause the script to exit if a single command fails.
set -e
set -x
# Tear down the Raylet.
#bash ../../../src/ray/test/stop_raylets.sh
# Set up a single Raylet.
bash ../../../src/ray/test/start_raylets.sh
sleep 1
# Connect a driver to the raylet and make sure it completes.
python ../../../src/ray/python/test_driver.py /tmp/raylet1 /tmp/store1
sleep 1
./src/common/thirdparty/redis/src/redis-cli -p 6379 shutdown
bash ../../../src/ray/test/stop_raylets.sh
+29
View File
@@ -0,0 +1,29 @@
#!/usr/bin/env bash
# This needs to be run in the build tree, which is normally ray/python/ray/core
# Cause the script to exit if a single command fails.
set -e
if [[ $1 ]]; then
RAYLET_NUM=$1
else
RAYLET_NUM=1
fi
STORE_SOCKET_NAME="/tmp/store$RAYLET_NUM"
RAYLET_SOCKET_NAME="/tmp/raylet$RAYLET_NUM"
if [[ `stat $RAYLET_SOCKET_NAME` ]]; then
rm $RAYLET_SOCKET_NAME
fi
if [[ `stat $STORE_SOCKET_NAME` ]]; then
rm $STORE_SOCKET_NAME
fi
./src/plasma/plasma_store -m 1000000000 -s $STORE_SOCKET_NAME &
./src/ray/raylet/raylet $RAYLET_SOCKET_NAME $STORE_SOCKET_NAME 127.0.0.1 6379 &
echo
echo "WORKER COMMAND: python ../../../src/ray/python/worker.py $RAYLET_SOCKET_NAME $STORE_SOCKET_NAME"
echo
+36
View File
@@ -0,0 +1,36 @@
#!/usr/bin/env bash
# This needs to be run in the build tree, which is normally ray/python/ray/core
# Cause the script to exit if a single command fails.
set -e
# Start the GCS.
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 >/dev/null &
sleep 1s
if [[ $1 ]]; then
NUM_RAYLETS=$1
else
NUM_RAYLETS=1
fi
for i in `seq 1 $NUM_RAYLETS`; do
STORE_SOCKET_NAME="/tmp/store$i"
RAYLET_SOCKET_NAME="/tmp/raylet$i"
if [[ `stat $RAYLET_SOCKET_NAME` ]]; then
rm $RAYLET_SOCKET_NAME
fi
if [[ `stat $STORE_SOCKET_NAME` ]]; then
rm $STORE_SOCKET_NAME
fi
./src/plasma/plasma_store -m 1000000000 -s $STORE_SOCKET_NAME &
./src/ray/raylet/raylet $RAYLET_SOCKET_NAME $STORE_SOCKET_NAME 127.0.0.1 6379 &
echo
echo "WORKER COMMAND: python ../../../src/ray/python/worker.py $RAYLET_SOCKET_NAME $STORE_SOCKET_NAME"
echo
done
+11
View File
@@ -0,0 +1,11 @@
#!/usr/bin/env bash
# This needs to be run in the build tree, which is normally ray/python/ray/core
# Cause the script to exit if a single command fails.
set -e
# Start the GCS.
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 >/dev/null &
sleep 1s
+9
View File
@@ -0,0 +1,9 @@
#!/usr/bin/env bash
killall raylet
sleep 1
killall plasma_store
sleep 1
killall redis-server
sleep 1
rm /tmp/store* /tmp/raylet*