diff --git a/.travis.yml b/.travis.yml index 0168ba459..323433dcb 100644 --- a/.travis.yml +++ b/.travis.yml @@ -102,6 +102,13 @@ install: - cd python/ray/core - bash ../../../src/ray/test/run_gcs_tests.sh + # Raylet tests. + - bash ../../../src/ray/test/run_object_manager_tests.sh + - bash ../../../src/ray/test/run_task_test.sh + - ./src/ray/raylet/task_test + - ./src/ray/raylet/worker_pool_test + - ./src/ray/raylet/lineage_cache_test + - bash ../../../src/common/test/run_tests.sh - bash ../../../src/plasma/test/run_tests.sh - bash ../../../src/local_scheduler/test/run_tests.sh diff --git a/src/common/redis_module/ray_redis_module.cc b/src/common/redis_module/ray_redis_module.cc index b08b07500..984ea30bd 100644 --- a/src/common/redis_module/ray_redis_module.cc +++ b/src/common/redis_module/ray_redis_module.cc @@ -50,7 +50,14 @@ } static const char *table_prefixes[] = { - NULL, "TASK:", "TASK:", "CLIENT:", "OBJECT:", "FUNCTION:", + NULL, + "TASK:", + "TASK:", + "CLIENT:", + "OBJECT:", + "FUNCTION:", + "TASK_RECONSTRUCTION:", + "HEARTBEAT:", }; /// Parse a Redis string into a TablePubsub channel. @@ -811,8 +818,9 @@ int TableRequestNotifications_RedisCommand(RedisModuleCtx *ctx, // notifications. 
flatbuffers::FlatBufferBuilder fbb; TableEntryToFlatbuf(table_key, id, fbb); - RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, reinterpret_cast(fbb.GetBufferPointer()), - fbb.GetSize()); + RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, + reinterpret_cast(fbb.GetBufferPointer()), + fbb.GetSize()); } RedisModule_CloseKey(table_key); diff --git a/src/global_scheduler/global_scheduler.cc b/src/global_scheduler/global_scheduler.cc index f8c49be9c..3dcb12593 100644 --- a/src/global_scheduler/global_scheduler.cc +++ b/src/global_scheduler/global_scheduler.cc @@ -141,13 +141,8 @@ GlobalSchedulerState *GlobalSchedulerState_init(event_loop *loop, std::vector()); db_attach(state->db, loop, false); - ClientTableDataT client_info; - client_info.client_id = get_db_client_id(state->db).binary(); - client_info.node_manager_address = std::string(node_ip_address); - client_info.local_scheduler_port = 0; - client_info.object_manager_port = 0; RAY_CHECK_OK(state->gcs_client.Connect(std::string(redis_primary_addr), - redis_primary_port, client_info)); + redis_primary_port)); RAY_CHECK_OK(state->gcs_client.context()->AttachToEventLoop(loop)); state->policy_state = GlobalSchedulerPolicyState_init(); return state; diff --git a/src/local_scheduler/local_scheduler.cc b/src/local_scheduler/local_scheduler.cc index ecf43e35f..b68490890 100644 --- a/src/local_scheduler/local_scheduler.cc +++ b/src/local_scheduler/local_scheduler.cc @@ -355,13 +355,8 @@ LocalSchedulerState *LocalSchedulerState_init( "local_scheduler", node_ip_address, db_connect_args); db_attach(state->db, loop, false); - ClientTableDataT client_info; - client_info.client_id = get_db_client_id(state->db).binary(); - client_info.node_manager_address = std::string(node_ip_address); - client_info.local_scheduler_port = 0; - client_info.object_manager_port = 0; RAY_CHECK_OK(state->gcs_client.Connect(std::string(redis_primary_addr), - redis_primary_port, client_info)); + redis_primary_port)); 
RAY_CHECK_OK(state->gcs_client.context()->AttachToEventLoop(loop)); } else { state->db = NULL; diff --git a/src/plasma/plasma_manager.cc b/src/plasma/plasma_manager.cc index be7b1aee8..03b6db0b8 100644 --- a/src/plasma/plasma_manager.cc +++ b/src/plasma/plasma_manager.cc @@ -487,13 +487,8 @@ PlasmaManagerState *PlasmaManagerState_init(const char *store_socket_name, "plasma_manager", manager_addr, db_connect_args); db_attach(state->db, state->loop, false); - ClientTableDataT client_info; - client_info.client_id = get_db_client_id(state->db).binary(); - client_info.node_manager_address = std::string(manager_addr); - client_info.local_scheduler_port = 0; - client_info.object_manager_port = manager_port; RAY_CHECK_OK(state->gcs_client.Connect(std::string(redis_primary_addr), - redis_primary_port, client_info)); + redis_primary_port)); RAY_CHECK_OK(state->gcs_client.context()->AttachToEventLoop(state->loop)); } else { state->db = NULL; diff --git a/src/ray/CMakeLists.txt b/src/ray/CMakeLists.txt index 7a4166ac6..513832dc5 100644 --- a/src/ray/CMakeLists.txt +++ b/src/ray/CMakeLists.txt @@ -38,8 +38,11 @@ set(RAY_SRCS gcs/asio.cc common/client_connection.cc object_manager/object_manager_client_connection.cc - object_manager/object_store_client.cc + object_manager/connection_pool.cc + object_manager/object_store_client_pool.cc + object_manager/object_store_notification_manager.cc object_manager/object_directory.cc + object_manager/transfer_queue.cc object_manager/object_manager.cc raylet/mock_gcs_client.cc raylet/task.cc diff --git a/src/ray/common/client_connection.cc b/src/ray/common/client_connection.cc index e2a184654..c1b237775 100644 --- a/src/ray/common/client_connection.cc +++ b/src/ray/common/client_connection.cc @@ -7,20 +7,82 @@ namespace ray { +ray::Status TcpConnect(boost::asio::ip::tcp::socket &socket, + const std::string &ip_address_string, int port) { + boost::asio::ip::address ip_address = + boost::asio::ip::address::from_string(ip_address_string); + 
boost::asio::ip::tcp::endpoint endpoint(ip_address, port); + boost::system::error_code error; + socket.connect(endpoint, error); + if (error) { + return ray::Status::IOError(error.message()); + } else { + return ray::Status::OK(); + } +} + +template +ServerConnection::ServerConnection(boost::asio::basic_stream_socket &&socket) + : socket_(std::move(socket)) {} + +template +void ServerConnection::WriteBuffer( + const std::vector &buffer, boost::system::error_code &ec) { + boost::asio::write(socket_, buffer, ec); +} + +template +void ServerConnection::ReadBuffer( + const std::vector &buffer, + boost::system::error_code &ec) { + boost::asio::read(socket_, buffer, ec); +} + +template +ray::Status ServerConnection::WriteMessage(int64_t type, int64_t length, + const uint8_t *message) { + std::vector message_buffers; + auto write_version = RayConfig::instance().ray_protocol_version(); + message_buffers.push_back(boost::asio::buffer(&write_version, sizeof(write_version))); + message_buffers.push_back(boost::asio::buffer(&type, sizeof(type))); + message_buffers.push_back(boost::asio::buffer(&length, sizeof(length))); + message_buffers.push_back(boost::asio::buffer(message, length)); + // Write the message and then wait for more messages. + // TODO(swang): Does this need to be an async write? + boost::system::error_code error; + boost::asio::write(socket_, message_buffers, error); + if (error) { + return ray::Status::IOError(error.message()); + } else { + return ray::Status::OK(); + } +} + template std::shared_ptr> ClientConnection::Create( - ClientManager &manager, boost::asio::basic_stream_socket &&socket) { + ClientHandler &client_handler, MessageHandler &message_handler, + boost::asio::basic_stream_socket &&socket) { std::shared_ptr> self( - new ClientConnection(manager, std::move(socket))); + new ClientConnection(message_handler, std::move(socket))); // Let our manager process our new connection. 
- self->manager_.ProcessNewClient(self); + client_handler(self); return self; } template -ClientConnection::ClientConnection(ClientManager &manager, +ClientConnection::ClientConnection(MessageHandler &message_handler, boost::asio::basic_stream_socket &&socket) - : socket_(std::move(socket)), manager_(manager) {} + : ServerConnection(std::move(socket)), message_handler_(message_handler) {} + +template +const ClientID &ClientConnection::GetClientID() { + return client_id_; +} + +template +void ClientConnection::SetClientID(const ClientID &client_id) { + client_id_ = client_id; +} template void ClientConnection::ProcessMessages() { @@ -31,7 +93,7 @@ void ClientConnection::ProcessMessages() { header.push_back(boost::asio::buffer(&read_type_, sizeof(read_type_))); header.push_back(boost::asio::buffer(&read_length_, sizeof(read_length_))); boost::asio::async_read( - socket_, header, + ServerConnection::socket_, header, boost::bind(&ClientConnection::ProcessMessageHeader, this->shared_from_this(), boost::asio::placeholders::error)); } @@ -52,57 +114,23 @@ void ClientConnection::ProcessMessageHeader(const boost::system::error_code & read_message_.resize(read_length_); // Wait for the message to be read. 
boost::asio::async_read( - socket_, boost::asio::buffer(read_message_), + ServerConnection::socket_, boost::asio::buffer(read_message_), boost::bind(&ClientConnection::ProcessMessage, this->shared_from_this(), boost::asio::placeholders::error)); } -template -void ClientConnection::WriteMessage(int64_t type, size_t length, - const uint8_t *message) { - std::vector message_buffers; - write_version_ = RayConfig::instance().ray_protocol_version(); - write_type_ = type; - write_length_ = length; - write_message_.assign(message, message + length); - message_buffers.push_back(boost::asio::buffer(&write_version_, sizeof(write_version_))); - message_buffers.push_back(boost::asio::buffer(&write_type_, sizeof(write_type_))); - message_buffers.push_back(boost::asio::buffer(&write_length_, sizeof(write_length_))); - message_buffers.push_back(boost::asio::buffer(write_message_)); - boost::system::error_code error; - // Write the message and then wait for more messages. - boost::asio::async_write( - socket_, message_buffers, - boost::bind(&ClientConnection::ProcessMessages, this->shared_from_this(), - boost::asio::placeholders::error)); -} - template void ClientConnection::ProcessMessage(const boost::system::error_code &error) { if (error) { // TODO(hme): Disconnect differently & remove dependency on node_manager_generated.h read_type_ = protocol::MessageType_DisconnectClient; } - manager_.ProcessClientMessage(this->shared_from_this(), read_type_, - read_message_.data()); -} - -template -void ClientConnection::ProcessMessages(const boost::system::error_code &error) { - if (error) { - ProcessMessage(error); - } else { - ProcessMessages(); - } + message_handler_(this->shared_from_this(), read_type_, read_message_.data()); } +template class ServerConnection; +template class ServerConnection; template class ClientConnection; template class ClientConnection; -template -ClientManager::~ClientManager() {} - -template class ClientManager; -template class ClientManager; - } // namespace 
ray diff --git a/src/ray/common/client_connection.h b/src/ray/common/client_connection.h index 08739aa07..55efed3dd 100644 --- a/src/ray/common/client_connection.h +++ b/src/ray/common/client_connection.h @@ -7,18 +7,74 @@ #include #include +#include "ray/id.h" +#include "ray/status.h" + namespace ray { -template -class ClientManager; - -/// \class ClientConnection +/// Connect a TCP socket. /// -/// A generic type representing a client connection on a server. This class can -/// be used to process and write messages asynchronously from and to the -/// client. -template -class ClientConnection : public std::enable_shared_from_this> { +/// \param socket The socket to connect. +/// \param ip_address The IP address to connect to. +/// \param port The port to connect to. +/// \return Status. +ray::Status TcpConnect(boost::asio::ip::tcp::socket &socket, + const std::string &ip_address, int port); + +/// \typename ServerConnection +/// +/// A generic type representing a client connection to a server. This typename +/// can be used to write messages synchronously to the server. +template +class ServerConnection { + public: + /// Create a connection to the server. + ServerConnection(boost::asio::basic_stream_socket &&socket); + + /// Write a message to the client. + /// + /// \param type The message type (e.g., a flatbuffer enum). + /// \param length The size in bytes of the message. + /// \param message A pointer to the message buffer. + /// \return Status. + ray::Status WriteMessage(int64_t type, int64_t length, const uint8_t *message); + + /// Write a buffer to this connection. + /// + /// \param buffer The buffer. + /// \param ec The error code object in which to store error codes. + void WriteBuffer(const std::vector &buffer, + boost::system::error_code &ec); + + /// Read a buffer from this connection. + /// + /// \param buffer The buffer. + /// \param ec The error code object in which to store error codes. 
+ void ReadBuffer(const std::vector &buffer, + boost::system::error_code &ec); + + protected: + /// The socket connection to the server. + boost::asio::basic_stream_socket socket_; +}; + +template +class ClientConnection; + +template +using ClientHandler = std::function>)>; +template +using MessageHandler = + std::function>, int64_t, const uint8_t *)>; + +/// \typename ClientConnection +/// +/// A generic type representing a client connection on a server. In addition to +/// writing messages to the client, like in ServerConnection, this typename can +/// also be used to process messages asynchronously from client. +template +class ClientConnection : public ServerConnection, + public std::enable_shared_from_this> { public: /// Allocate a new node client connection. /// @@ -27,24 +83,23 @@ class ClientConnection : public std::enable_shared_from_this /// \param socket The client socket. /// \return std::shared_ptr. static std::shared_ptr> Create( - ClientManager &manager, boost::asio::basic_stream_socket &&socket); + ClientHandler &new_client_handler, MessageHandler &message_handler, + boost::asio::basic_stream_socket &&socket); + + /// \return The ClientID of the remote client. + const ClientID &GetClientID(); + + /// \param client_id The ClientID of the remote client. + void SetClientID(const ClientID &client_id); /// Listen for and process messages from the client connection. Once a /// message has been fully received, the client manager's /// ProcessClientMessage handler will be called. void ProcessMessages(); - /// Write a message to the client and then listen for more messages. - /// - /// \param type The message type (e.g., a flatbuffer enum). - /// \param length The size in bytes of the message. - /// \param message A pointer to the message buffer. This will be copied into - /// the ClientConnection's buffer. - void WriteMessage(int64_t type, size_t length, const uint8_t *message); - private: /// A private constructor for a node client connection. 
- ClientConnection(ClientManager &manager, + ClientConnection(MessageHandler &message_handler, boost::asio::basic_stream_socket &&socket); /// Process an error from the last operation, then process the message /// header from the client. @@ -52,54 +107,23 @@ class ClientConnection : public std::enable_shared_from_this /// Process an error from reading the message header, then process the /// message from the client. void ProcessMessage(const boost::system::error_code &error); - /// Process an error from the last operation and then listen for more - /// messages. - void ProcessMessages(const boost::system::error_code &error); - /// The client socket. - boost::asio::basic_stream_socket socket_; - /// A reference to the manager for this client. The manager exposes a handler - /// for all messages processed by this client. - ClientManager &manager_; + /// The ClientID of the remote client. + ClientID client_id_; + /// The handler for a message from the client. + MessageHandler message_handler_; /// Buffers for the current message being read rom the client. int64_t read_version_; int64_t read_type_; uint64_t read_length_; std::vector read_message_; - /// Buffers for the current message being written to the client. - int64_t write_version_; - int64_t write_type_; - uint64_t write_length_; - std::vector write_message_; }; +using LocalServerConnection = ServerConnection; +using TcpServerConnection = ServerConnection; using LocalClientConnection = ClientConnection; using TcpClientConnection = ClientConnection; -/// \class ClientManager -/// -/// A virtual cliant manager. Derived classes should define a method for -/// processing a message on the server sent by the client. -template -class ClientManager { - public: - /// Process a new client connection. - /// - /// \param client A shared pointer to the client that connected. 
- virtual void ProcessNewClient(std::shared_ptr> client) = 0; - - /// Process a message from a client, then listen for more messages if the - /// client is still alive. - /// - /// \param client A shared pointer to the client that sent the message. - /// \param message_type The message type (e.g., a flatbuffer enum). - /// \param message A pointer to the message buffer. - virtual void ProcessClientMessage(std::shared_ptr> client, - int64_t message_type, const uint8_t *message) = 0; - - virtual ~ClientManager() = 0; -}; - } // namespace ray #endif // RAY_COMMON_CLIENT_CONNECTION_H diff --git a/src/ray/constants.h b/src/ray/constants.h index a084543c4..bdae39ff2 100644 --- a/src/ray/constants.h +++ b/src/ray/constants.h @@ -4,6 +4,24 @@ /// Length of Ray IDs in bytes. constexpr int64_t kUniqueIDSize = 20; +/// An ObjectID's bytes are split into the task ID itself and the index of the +/// object's creation. This is the maximum width of the object index in bits. +constexpr int kObjectIdIndexSize = 32; +/// The maximum number of objects that can be returned by a task when finishing +/// execution. An ObjectID's bytes are split into the task ID itself and the +/// index of the object's creation. A positive index indicates an object +/// returned by the task, so the maximum number of objects that a task can +/// return is the maximum positive value for an integer with bit-width +/// `kObjectIdIndexSize`. +constexpr int64_t kMaxTaskReturns = ((int64_t)1 << (kObjectIdIndexSize - 1)) - 1; +/// The maximum number of objects that can be put by a task during execution. +/// An ObjectID's bytes are split into the task ID itself and the index of the +/// object's creation. A negative index indicates an object put by the task +/// during execution, so the maximum number of objects that a task can put is +/// the maximum negative value for an integer with bit-width +/// `kObjectIdIndexSize`. 
+constexpr int64_t kMaxTaskPuts = ((int64_t)1 << (kObjectIdIndexSize - 1)); + /// Prefix for the object table keys in redis. constexpr char kObjectTablePrefix[] = "ObjectTable"; /// Prefix for the task table keys in redis. diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index e67836553..d100a2ed7 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -6,19 +6,22 @@ namespace ray { namespace gcs { -AsyncGcsClient::AsyncGcsClient() {} - -AsyncGcsClient::~AsyncGcsClient() {} - -Status AsyncGcsClient::Connect(const std::string &address, int port, - const ClientTableDataT &client_info) { +AsyncGcsClient::AsyncGcsClient(const ClientID &client_id) { context_.reset(new RedisContext()); - RAY_RETURN_NOT_OK(context_->Connect(address, port)); + client_table_.reset(new ClientTable(context_, this, client_id)); object_table_.reset(new ObjectTable(context_, this)); task_table_.reset(new TaskTable(context_, this)); raylet_task_table_.reset(new raylet::TaskTable(context_, this)); task_reconstruction_log_.reset(new TaskReconstructionLog(context_, this)); - client_table_.reset(new ClientTable(context_, this, client_info)); + heartbeat_table_.reset(new HeartbeatTable(context_, this)); +} + +AsyncGcsClient::AsyncGcsClient() : AsyncGcsClient(ClientID::from_random()) {} + +AsyncGcsClient::~AsyncGcsClient() {} + +Status AsyncGcsClient::Connect(const std::string &address, int port) { + RAY_RETURN_NOT_OK(context_->Connect(address, port)); // TODO(swang): Call the client table's Connect() method here. To do this, // we need to make sure that we are attached to an event loop first. 
This // currently isn't possible because the aeEventLoop, which we use for @@ -55,6 +58,8 @@ FunctionTable &AsyncGcsClient::function_table() { return *function_table_; } ClassTable &AsyncGcsClient::class_table() { return *class_table_; } +HeartbeatTable &AsyncGcsClient::heartbeat_table() { return *heartbeat_table_; } + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index 099d602dc..99d8e6a65 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -19,6 +19,13 @@ class RedisContext; class RAY_EXPORT AsyncGcsClient { public: + /// Start a GCS client with the given client ID. To read from the GCS tables, + /// Connect and then Attach must be called. To read and write from the GCS + /// tables requires a further call to Connect to the client table. + /// + /// \param client_id The ID to assign to the client. + AsyncGcsClient(const ClientID &client_id); + /// Start a GCS client with a random client ID. AsyncGcsClient(); ~AsyncGcsClient(); @@ -26,10 +33,8 @@ class RAY_EXPORT AsyncGcsClient { /// /// \param address The GCS IP address. /// \param port The GCS port. - /// \param client_info Information about the local client to connect. /// \return Status. - Status Connect(const std::string &address, int port, - const ClientTableDataT &client_info); + Status Connect(const std::string &address, int port); /// Attach this client to a plasma event loop. Note that only /// one event loop should be attached at a time. 
Status Attach(plasma::EventLoop &event_loop); @@ -48,6 +53,7 @@ class RAY_EXPORT AsyncGcsClient { raylet::TaskTable &raylet_task_table(); TaskReconstructionLog &task_reconstruction_log(); ClientTable &client_table(); + HeartbeatTable &heartbeat_table(); inline ErrorTable &error_table(); // We also need something to export generic code to run on workers from the @@ -67,6 +73,7 @@ class RAY_EXPORT AsyncGcsClient { std::unique_ptr task_table_; std::unique_ptr raylet_task_table_; std::unique_ptr task_reconstruction_log_; + std::unique_ptr heartbeat_table_; std::unique_ptr client_table_; std::shared_ptr context_; std::unique_ptr asio_async_client_; diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index cf501ab26..6cc0eec2c 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -23,12 +23,7 @@ class TestGcs : public ::testing::Test { public: TestGcs() : num_callbacks_(0) { client_ = std::make_shared(); - ClientTableDataT client_info; - client_info.client_id = ClientID::from_random().binary(); - client_info.node_manager_address = "127.0.0.1"; - client_info.local_scheduler_port = 0; - client_info.object_manager_port = 0; - RAY_CHECK_OK(client_->Connect("127.0.0.1", 6379, client_info)); + RAY_CHECK_OK(client_->Connect("127.0.0.1", 6379)); job_id_ = JobID::from_random(); } @@ -747,6 +742,7 @@ void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client ClientID added_id = client->client_table().GetLocalClientId(); ASSERT_EQ(client_id, added_id); ASSERT_EQ(ClientID::from_binary(data.client_id), added_id); + ASSERT_EQ(ClientID::from_binary(data.client_id), added_id); ASSERT_EQ(data.is_insertion, is_insertion); auto cached_client = client->client_table().GetClient(added_id); @@ -763,9 +759,14 @@ void TestClientTableConnect(const JobID &job_id, ClientTableNotification(client, id, data, true); test->Stop(); }); + // Connect and disconnect to client table. 
We should receive notifications // for the addition and removal of our own entry. - RAY_CHECK_OK(client->client_table().Connect()); + ClientTableDataT local_client_info = client->client_table().GetLocalClient(); + local_client_info.node_manager_address = "127.0.0.1"; + local_client_info.node_manager_port = 0; + local_client_info.object_manager_port = 0; + RAY_CHECK_OK(client->client_table().Connect(local_client_info)); test->Start(); } @@ -789,7 +790,11 @@ void TestClientTableDisconnect(const JobID &job_id, }); // Connect and disconnect to client table. We should receive notifications // for the addition and removal of our own entry. - RAY_CHECK_OK(client->client_table().Connect()); + ClientTableDataT local_client_info = client->client_table().GetLocalClient(); + local_client_info.node_manager_address = "127.0.0.1"; + local_client_info.node_manager_port = 0; + local_client_info.object_manager_port = 0; + RAY_CHECK_OK(client->client_table().Connect(local_client_info)); RAY_CHECK_OK(client->client_table().Disconnect()); test->Start(); } diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 1caa4f93f..085eb0d55 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -11,7 +11,8 @@ enum TablePrefix:int { CLIENT, OBJECT, FUNCTION, - TASK_RECONSTRUCTION + TASK_RECONSTRUCTION, + HEARTBEAT } // The channel that Add operations to the Table should be published on, if any. @@ -21,7 +22,8 @@ enum TablePubsub:int { RAYLET_TASK, CLIENT, OBJECT, - ACTOR + ACTOR, + HEARTBEAT } table GcsTableEntry { @@ -98,6 +100,13 @@ table CustomSerializerData { table ConfigTableData { } +table RayResource { + // The type of the resource. + resource_name: string; + // The total capacity of this resource type. + resource_capacity: double; +} + table ClientTableData { // The client ID of the client that the message is about. 
client_id: string; @@ -105,26 +114,24 @@ table ClientTableData { node_manager_address: string; // The port at which the client's node manager is listening for TCP // connections from other node managers. - local_scheduler_port: int; + node_manager_port: int; // The port at which the client's object manager is listening for TCP // connections from other object managers. object_manager_port: int; // True if the message is about the addition of a client and false if it is // about the deletion of a client. is_insertion: bool; + resources_total_label: [string]; + resources_total_capacity: [double]; } -table Resource { - // The type of the resource. - resource_name: string; - // The total capacity of this resource type. - resource_capacity: double; -} - -table NodeManagerHeartbeat { - // The available resources on this node manager. This information may be - // stale. - resources_available: [Resource]; - // The total resources on this node manager. - resources_total: [Resource]; +table HeartbeatTableData { + // Node manager client id + client_id: string; + // Resource capacity currently available on this node manager. + resources_available_label: [string]; + resources_available_capacity: [double]; + // Total resource capacity configured for this node manager. 
+ resources_total_label: [string]; + resources_total_capacity: [double]; } diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index 82ebf34ca..43778f70b 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -25,7 +25,8 @@ void ProcessCallback(int64_t callback_index, const std::string &data) { } } } -} + +} // namespace namespace ray { @@ -98,9 +99,10 @@ void SubscribeRedisCallback(void *c, void *r, void *privdata) { } int64_t RedisCallbackManager::add(const RedisCallback &function) { + num_callbacks += 1; callbacks_.emplace(num_callbacks, std::unique_ptr( new RedisCallback(function))); - return num_callbacks++; + return num_callbacks; } RedisCallbackManager::RedisCallback &RedisCallbackManager::get( diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index 00814f3d4..4d9a296ee 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -49,7 +49,8 @@ class RedisCallbackManager { class RedisContext { public: - RedisContext() {} + RedisContext() + : context_(nullptr), async_context_(nullptr), subscribe_context_(nullptr) {} ~RedisContext(); Status Connect(const std::string &address, int port); Status AttachToEventLoop(aeEventLoop *loop); diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index f42a326e1..fa288c935 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -89,8 +89,8 @@ Status Log::Subscribe(const JobID &job_id, const ClientID &client_id, << "Client called Subscribe twice on the same table"; auto d = std::shared_ptr( new CallbackData({client_id, nullptr, subscribe, done, this, client_})); - int64_t callback_index = RedisCallbackManager::instance().add( - [this, d](const std::string &data) { + int64_t callback_index = + RedisCallbackManager::instance().add([this, d](const std::string &data) { if (data.empty()) { // No notification data is provided. This is the callback for the // initial subscription request. 
@@ -115,7 +115,7 @@ Status Log::Subscribe(const JobID &job_id, const ClientID &client_id, results.emplace_back(std::move(result)); } (d->callback)(d->client, id, results); - } + } } // We do not delete the callback after calling it since there may be // more subscription messages. @@ -270,9 +270,12 @@ const ClientID &ClientTable::GetLocalClientId() { return client_id_; } const ClientTableDataT &ClientTable::GetLocalClient() { return local_client_; } -Status ClientTable::Connect() { +Status ClientTable::Connect(const ClientTableDataT &local_client) { RAY_CHECK(!disconnected_) << "Tried to reconnect a disconnected client."; + RAY_CHECK(local_client.client_id == local_client_.client_id); + local_client_ = local_client; + auto data = std::make_shared(local_client_); data->is_insertion = true; // Callback for a notification from the client table. @@ -336,6 +339,7 @@ template class Log; template class Table; template class Table; template class Log; +template class Table; } // namespace gcs diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index ec8e411e0..533c258f3 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -167,13 +167,23 @@ class Log { int64_t subscribe_callback_index_; }; +template +class TableInterface { + public: + using DataT = typename Data::NativeTableType; + using WriteCallback = typename Log::WriteCallback; + virtual Status Add(const JobID &job_id, const ID &task_id, std::shared_ptr data, + const WriteCallback &done) = 0; + virtual ~TableInterface(){}; +}; + /// \class Table /// /// A GCS table where every entry is a single data item. /// Example tables backed by Log: /// TaskTable: Stores Task metadata needed for executing the task. 
template -class Table : private Log { +class Table : private Log, public TableInterface { public: using DataT = typename Log::DataT; using Callback = @@ -242,6 +252,17 @@ class ObjectTable : public Log { pubsub_channel_ = TablePubsub_OBJECT; prefix_ = TablePrefix_OBJECT; }; + virtual ~ObjectTable(){}; +}; + +class HeartbeatTable : public Table { + public: + HeartbeatTable(const std::shared_ptr &context, AsyncGcsClient *client) + : Table(context, client) { + pubsub_channel_ = TablePubsub_HEARTBEAT; + prefix_ = TablePrefix_HEARTBEAT; + } + virtual ~HeartbeatTable() {} }; class FunctionTable : public Table { @@ -277,7 +298,8 @@ class TaskTable : public Table { prefix_ = TablePrefix_RAYLET_TASK; } }; -} + +} // namespace raylet class TaskTable : public Table { public: @@ -286,6 +308,7 @@ class TaskTable : public Table { pubsub_channel_ = TablePubsub_TASK; prefix_ = TablePrefix_TASK; }; + ~TaskTable(){}; using TestAndUpdateCallback = std::function { const Callback &done); }; -using ErrorTable = Table; - -using CustomSerializerTable = Table; - -using ConfigTable = Table; - Status TaskTableAdd(AsyncGcsClient *gcs_client, Task *task); Status TaskTableTestAndUpdate(AsyncGcsClient *gcs_client, const TaskID &task_id, @@ -363,6 +380,12 @@ Status TaskTableTestAndUpdate(AsyncGcsClient *gcs_client, const TaskID &task_id, SchedulingState update_state, const TaskTable::TestAndUpdateCallback &callback); +using ErrorTable = Table; + +using CustomSerializerTable = Table; + +using ConfigTable = Table; + /// \class ClientTable /// /// The ClientTable stores information about active and inactive clients. It is @@ -377,17 +400,20 @@ class ClientTable : private Log { using ClientTableCallback = std::function; ClientTable(const std::shared_ptr &context, AsyncGcsClient *client, - const ClientTableDataT &local_client) + const ClientID &client_id) : Log(context, client), // We set the client log's key equal to nil so that all instances of // ClientTable have the same key. 
client_log_key_(UniqueID::nil()), disconnected_(false), - client_id_(ClientID::from_binary(local_client.client_id)), - local_client_(local_client) { + client_id_(client_id), + local_client_() { pubsub_channel_ = TablePubsub_CLIENT; prefix_ = TablePrefix_CLIENT; + // Set the local client's ID. + local_client_.client_id = client_id.binary(); + // Add a nil client to the cache so that we can serve requests for clients // that we have not heard about. ClientTableDataT nil_client; @@ -398,8 +424,10 @@ class ClientTable : private Log { /// Connect as a client to the GCS. This registers us in the client table /// and begins subscription to client table notifications. /// + /// \param Information about the connecting client. This must have the + /// same client_id as the one set in the client table. /// \return Status - ray::Status Connect(); + ray::Status Connect(const ClientTableDataT &local_client); /// Disconnect the client from the GCS. The client ID assigned during /// registration should never be reused after disconnecting. diff --git a/src/ray/gcs/task_table.cc b/src/ray/gcs/task_table.cc index a60ab148e..1e3471cc4 100644 --- a/src/ray/gcs/task_table.cc +++ b/src/ray/gcs/task_table.cc @@ -1,13 +1,15 @@ #include "ray/gcs/tables.h" #include "ray/gcs/client.h" +#include "ray/id.h" #include "common_protocol.h" #include "task.h" // TODO(swang): This file extends tables.cc so that we can separate out the -// part that depends on the Task* datasturcture from the build. This should be -// merged with tables.cc once we get rid of the Task* datastructure. +// part that depends on the legacy Task* data structure from the build. This +// should be merged with tables.cc once we get rid of the legacy Task* +// datastructure. 
namespace { diff --git a/src/ray/id.cc b/src/ray/id.cc index 662ac2e06..f3c2b3212 100644 --- a/src/ray/id.cc +++ b/src/ray/id.cc @@ -2,6 +2,9 @@ #include +#include "ray/constants.h" +#include "ray/status.h" + namespace ray { UniqueID::UniqueID(const plasma::UniqueID &from) { @@ -83,4 +86,48 @@ std::ostream &operator<<(std::ostream &os, const UniqueID &id) { return os; } +const ObjectID ComputeObjectId(TaskID task_id, int64_t object_index) { + RAY_CHECK(object_index <= kMaxTaskReturns && object_index >= -kMaxTaskPuts); + ObjectID return_id = task_id; + int64_t *first_bytes = reinterpret_cast(&return_id); + // Zero out the lowest kObjectIdIndexSize bits of the first byte of the + // object ID. + uint64_t bitmask = static_cast(-1) << kObjectIdIndexSize; + *first_bytes = *first_bytes & (bitmask); + // OR the first byte of the object ID with the return index. + *first_bytes = *first_bytes | (object_index & ~bitmask); + return return_id; +} + +const TaskID FinishTaskId(const TaskID &task_id) { return ComputeObjectId(task_id, 0); } + +const ObjectID ComputeReturnId(TaskID task_id, int64_t return_index) { + RAY_CHECK(return_index >= 1 && return_index <= kMaxTaskReturns); + return ComputeObjectId(task_id, return_index); +} + +const ObjectID ComputePutId(TaskID task_id, int64_t put_index) { + RAY_CHECK(put_index >= 1 && put_index <= kMaxTaskPuts); + return ComputeObjectId(task_id, -1 * put_index); +} + +const TaskID ComputeTaskId(const ObjectID &object_id) { + TaskID task_id = object_id; + int64_t *first_bytes = reinterpret_cast(&task_id); + // Zero out the lowest kObjectIdIndexSize bits of the first byte of the + // object ID. 
+ uint64_t bitmask = static_cast(-1) << kObjectIdIndexSize; + *first_bytes = *first_bytes & (bitmask); + return task_id; +} + +int64_t ComputeObjectIndex(const ObjectID &object_id) { + const int64_t *first_bytes = reinterpret_cast(&object_id); + uint64_t bitmask = static_cast(-1) << kObjectIdIndexSize; + int64_t index = *first_bytes & (~bitmask); + index <<= (8 * sizeof(int64_t) - kObjectIdIndexSize); + index >>= (8 * sizeof(int64_t) - kObjectIdIndexSize); + return index; +} + } // namespace ray diff --git a/src/ray/id.h b/src/ray/id.h index fcdbf69af..db8958cc8 100644 --- a/src/ray/id.h +++ b/src/ray/id.h @@ -10,9 +10,6 @@ #include "ray/constants.h" #include "ray/util/visibility.h" -// TODO(swang): Make task ID prefix of any object ID return values and puts so -// that we can co-locate task and object entries in the GCS. - namespace ray { class RAY_EXPORT UniqueID { @@ -61,6 +58,44 @@ typedef UniqueID DriverID; typedef UniqueID ConfigID; typedef UniqueID ClientID; +// TODO(swang): ObjectID and TaskID should derive from UniqueID. Then, we +// can make these methods of the derived classes. +/// Finish computing a task ID. Since objects created by the task share a +/// prefix of the ID, the suffix of the task ID is zeroed out by this function. +/// +/// \param task_id A task ID to finish. +/// \return The finished task ID. It may now be used to compute IDs for objects +/// created by the task. +const TaskID FinishTaskId(const TaskID &task_id); + +/// Compute the object ID of an object returned by the task. +/// +/// \param task_id The task ID of the task that created the object. +/// \param put_index What number return value this object is in the task. +/// \return The computed object ID. +const ObjectID ComputeReturnId(TaskID task_id, int64_t return_index); + +/// Compute the object ID of an object put by the task. +/// +/// \param task_id The task ID of the task that created the object. +/// \param put_index What number put this object was created by in the task. 
+/// \return The computed object ID. +const ObjectID ComputePutId(TaskID task_id, int64_t put_index); + +/// Compute the task ID of the task that created the object. +/// +/// \param object_id The object ID. +/// \return The task ID of the task that created this object. +const TaskID ComputeTaskId(const ObjectID &object_id); + +/// Compute the index of this object in the task that created it. +/// +/// \param object_id The object ID. +/// \return The index of object creation according to the task that created +/// this object. This is positive if the task returned the object and negative +/// if created by a put. +int64_t ComputeObjectIndex(const ObjectID &object_id); + } // namespace ray #endif // RAY_ID_H_ diff --git a/src/ray/object_manager/CMakeLists.txt b/src/ray/object_manager/CMakeLists.txt index 0068845e8..7b380ead8 100644 --- a/src/ray/object_manager/CMakeLists.txt +++ b/src/ray/object_manager/CMakeLists.txt @@ -19,7 +19,8 @@ add_custom_command( add_custom_target(gen_object_manager_fbs DEPENDS ${OBJECT_MANAGER_FBS_OUTPUT_FILES}) -ADD_RAY_TEST(object_manager_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main pthread ${Boost_SYSTEM_LIBRARY}) +ADD_RAY_TEST(test/object_manager_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main pthread ${Boost_SYSTEM_LIBRARY}) +ADD_RAY_TEST(test/object_manager_stress_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main pthread ${Boost_SYSTEM_LIBRARY}) add_library(object_manager object_manager.cc object_manager.h ${OBJECT_MANAGER_FBS_OUTPUT_FILES}) target_link_libraries(object_manager common ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} ${Boost_SYSTEM_LIBRARY}) diff --git a/src/ray/object_manager/connection_pool.cc b/src/ray/object_manager/connection_pool.cc new file mode 100644 index 000000000..0fe0f2161 --- /dev/null +++ b/src/ray/object_manager/connection_pool.cc @@ -0,0 +1,113 @@ +#include 
"ray/object_manager/connection_pool.h" + +namespace ray { + +ConnectionPool::ConnectionPool() {} + +void ConnectionPool::RegisterReceiver(ConnectionType type, const ClientID &client_id, + std::shared_ptr &conn) { + std::unique_lock guard(connection_mutex); + switch (type) { + case ConnectionType::MESSAGE: { + Add(message_receive_connections_, client_id, conn); + } break; + case ConnectionType::TRANSFER: { + Add(transfer_receive_connections_, client_id, conn); + } break; + } +} + +void ConnectionPool::RemoveReceiver(ConnectionType type, const ClientID &client_id, + std::shared_ptr &conn) { + std::unique_lock guard(connection_mutex); + switch (type) { + case ConnectionType::MESSAGE: { + Remove(message_receive_connections_, client_id, conn); + } break; + case ConnectionType::TRANSFER: { + Remove(transfer_receive_connections_, client_id, conn); + } break; + } + // TODO(hme): appropriately dispose of client connection. +} + +void ConnectionPool::RegisterSender(ConnectionType type, const ClientID &client_id, + std::shared_ptr &conn) { + std::unique_lock guard(connection_mutex); + SenderMapType &conn_map = (type == ConnectionType::MESSAGE) + ? message_send_connections_ + : transfer_send_connections_; + Add(conn_map, client_id, conn); + // Don't add to available connections. It will become available once it is released. +} + +ray::Status ConnectionPool::GetSender(ConnectionType type, const ClientID &client_id, + std::shared_ptr *conn) { + std::unique_lock guard(connection_mutex); + SenderMapType &avail_conn_map = (type == ConnectionType::MESSAGE) + ? 
available_message_send_connections_ + : available_transfer_send_connections_; + if (Count(avail_conn_map, client_id) > 0) { + *conn = Borrow(avail_conn_map, client_id); + } else { + *conn = nullptr; + } + return ray::Status::OK(); +} + +ray::Status ConnectionPool::ReleaseSender(ConnectionType type, + std::shared_ptr conn) { + std::unique_lock guard(connection_mutex); + SenderMapType &conn_map = (type == ConnectionType::MESSAGE) + ? available_message_send_connections_ + : available_transfer_send_connections_; + Return(conn_map, conn->GetClientID(), conn); + return ray::Status::OK(); +} + +void ConnectionPool::Add(ReceiverMapType &conn_map, const ClientID &client_id, + std::shared_ptr conn) { + conn_map[client_id].push_back(conn); +} + +void ConnectionPool::Add(SenderMapType &conn_map, const ClientID &client_id, + std::shared_ptr conn) { + conn_map[client_id].push_back(conn); +} + +void ConnectionPool::Remove(ReceiverMapType &conn_map, const ClientID &client_id, + std::shared_ptr conn) { + if (conn_map.count(client_id) == 0) { + return; + } + std::vector> &connections = conn_map[client_id]; + int64_t pos = + std::find(connections.begin(), connections.end(), conn) - connections.begin(); + if (pos >= (int64_t)connections.size()) { + return; + } + connections.erase(connections.begin() + pos); +} + +uint64_t ConnectionPool::Count(SenderMapType &conn_map, const ClientID &client_id) { + if (conn_map.count(client_id) == 0) { + return 0; + }; + return conn_map[client_id].size(); +} + +std::shared_ptr ConnectionPool::Borrow(SenderMapType &conn_map, + const ClientID &client_id) { + std::shared_ptr conn = conn_map[client_id].back(); + conn_map[client_id].pop_back(); + RAY_LOG(DEBUG) << "Borrow " << client_id << " " << conn_map[client_id].size(); + return conn; +} + +void ConnectionPool::Return(SenderMapType &conn_map, const ClientID &client_id, + std::shared_ptr conn) { + conn_map[client_id].push_back(conn); + RAY_LOG(DEBUG) << "Return " << client_id << " " << 
conn_map[client_id].size(); +} + +} // namespace ray diff --git a/src/ray/object_manager/connection_pool.h b/src/ray/object_manager/connection_pool.h new file mode 100644 index 000000000..0fba1b50c --- /dev/null +++ b/src/ray/object_manager/connection_pool.h @@ -0,0 +1,143 @@ +#ifndef RAY_OBJECT_MANAGER_CONNECTION_POOL_H +#define RAY_OBJECT_MANAGER_CONNECTION_POOL_H + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "ray/id.h" +#include "ray/status.h" + +#include +#include "ray/object_manager/format/object_manager_generated.h" +#include "ray/object_manager/object_directory.h" +#include "ray/object_manager/object_manager_client_connection.h" + +namespace asio = boost::asio; + +namespace ray { + +class ConnectionPool { + public: + /// Callbacks for GetSender. + using SuccessCallback = std::function)>; + using FailureCallback = std::function; + + /// Connection type to distinguish between message and transfer connections. + enum class ConnectionType : int { MESSAGE = 0, TRANSFER }; + + /// Connection pool for all connections needed by the ObjectManager. + ConnectionPool(); + + /// Register a receiver connection. + /// + /// \param type The type of connection. + /// \param client_id The ClientID of the remote object manager. + /// \param conn The actual connection. + void RegisterReceiver(ConnectionType type, const ClientID &client_id, + std::shared_ptr &conn); + + /// Remove a receiver connection. + /// + /// \param type The type of connection. + /// \param client_id The ClientID of the remote object manager. + /// \param conn The actual connection. + void RemoveReceiver(ConnectionType type, const ClientID &client_id, + std::shared_ptr &conn); + + /// Register a receiver connection. + /// + /// \param type The type of connection. + /// \param client_id The ClientID of the remote object manager. + /// \param conn The actual connection. 
+ void RegisterSender(ConnectionType type, const ClientID &client_id, + std::shared_ptr &conn); + + /// Get a sender connection from the connection pool. + /// The connection must be released or removed when the operation for which the + /// connection was obtained is completed. If the connection pool is empty, the + /// connection pointer passed in is set to a null pointer. + /// + /// \param[in] type The type of connection. + /// \param[in] client_id The ClientID of the remote object manager. + /// \param[out] conn An empty pointer to a shared pointer. + /// \return Status of invoking this method. + ray::Status GetSender(ConnectionType type, const ClientID &client_id, + std::shared_ptr *conn); + + /// Releases a sender connection, allowing it to be used by another operation. + /// + /// \param type The type of connection. + /// \param conn The actual connection. + /// \return Status of invoking this method. + ray::Status ReleaseSender(ConnectionType type, std::shared_ptr conn); + + // TODO(hme): Implement with error handling. + /// Remove a sender connection. This is invoked if the connection is no longer + /// usable. + /// + /// \param type The type of connection. + /// \param conn The actual connection. + /// \return Status of invoking this method. + ray::Status RemoveSender(ConnectionType type, std::shared_ptr conn); + + /// This object cannot be copied for thread-safety. + ConnectionPool &operator=(const ConnectionPool &o) { + throw std::runtime_error("Can't copy ConnectionPool."); + } + + private: + /// A container type that maps ClientID to a connection type. + using SenderMapType = + std::unordered_map>, + ray::UniqueIDHasher>; + using ReceiverMapType = + std::unordered_map>, + ray::UniqueIDHasher>; + + /// Adds a receiver for ClientID to the given map. + void Add(ReceiverMapType &conn_map, const ClientID &client_id, + std::shared_ptr conn); + + /// Adds a sender for ClientID to the given map. 
+ void Add(SenderMapType &conn_map, const ClientID &client_id, + std::shared_ptr conn); + + /// Removes the given receiver for ClientID from the given map. + void Remove(ReceiverMapType &conn_map, const ClientID &client_id, + std::shared_ptr conn); + + /// Returns the count of sender connections to ClientID. + uint64_t Count(SenderMapType &conn_map, const ClientID &client_id); + + /// Removes a sender connection to ClientID from the pool of available connections. + /// This method assumes conn_map has available connections to ClientID. + std::shared_ptr Borrow(SenderMapType &conn_map, + const ClientID &client_id); + + /// Returns a sender connection to ClientID to the pool of available connections. + void Return(SenderMapType &conn_map, const ClientID &client_id, + std::shared_ptr conn); + + // TODO(hme): Optimize with separate mutex per collection. + std::mutex connection_mutex; + + SenderMapType message_send_connections_; + SenderMapType transfer_send_connections_; + SenderMapType available_message_send_connections_; + SenderMapType available_transfer_send_connections_; + + ReceiverMapType message_receive_connections_; + ReceiverMapType transfer_receive_connections_; +}; + +} // namespace ray + +#endif // RAY_OBJECT_MANAGER_CONNECTION_POOL_H diff --git a/src/ray/object_manager/format/object_manager.fbs b/src/ray/object_manager/format/object_manager.fbs index 1bb6b5370..d1583e6c1 100644 --- a/src/ray/object_manager/format/object_manager.fbs +++ b/src/ray/object_manager/format/object_manager.fbs @@ -1,30 +1,37 @@ // Object Manager protocol specification +namespace ray.object_manager.protocol; -enum OMMessageType:int { - PullRequest = 1 +enum MessageType:int { + ConnectClient = 1, + DisconnectClient, + PushRequest, + PullRequest } -table PushRequest { - +table PushRequestMessage { + // The object ID being transferred. + object_id: string; + // The size of the object being transferred. 
+ object_size: ulong; } -table PullRequest { +table PullRequestMessage { // ID of the requesting client. client_id: string; // Requested ObjectID. object_id: string; } -table ClientConnectionInfo { +table ConnectClientMessage { // ID of the connecting client. client_id: string; // Whether this is a transfer connection. is_transfer: bool; } -table ObjectHeader { - // The object ID being transferred. - object_id: string; - // The size of the object being transferred. - object_size: ulong; +table DisconnectClientMessage { + // ID of the connecting client. + client_id: string; + // Whether this is a transfer connection. + is_transfer: bool; } diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 406f63523..b4e131bbd 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -1,41 +1,53 @@ -#include "object_directory.h" +#include "ray/object_manager/object_directory.h" namespace ray { -ObjectDirectory::ObjectDirectory(std::shared_ptr gcs_client) { +ObjectDirectory::ObjectDirectory(std::shared_ptr gcs_client) { gcs_client_ = gcs_client; }; ray::Status ObjectDirectory::ReportObjectAdded(const ObjectID &object_id, const ClientID &client_id) { - return gcs_client_->object_table().Add(object_id, client_id, [] {}); + // TODO(hme): Determine whether we need to do lookup to append. + JobID job_id = JobID::from_random(); + auto data = std::make_shared(); + data->manager = client_id.binary(); + data->is_eviction = false; + ray::Status status = gcs_client_->object_table().Append( + job_id, object_id, data, [](gcs::AsyncGcsClient *client, const UniqueID &id, + const std::shared_ptr data) { + // Do nothing. + }); + return status; }; ray::Status ObjectDirectory::ReportObjectRemoved(const ObjectID &object_id, const ClientID &client_id) { - return gcs_client_->object_table().Remove(object_id, client_id, [] {}); + // TODO(hme): Need corresponding remove method in GCS. 
+ return ray::Status::NotImplemented("ObjectTable.Remove is not implemented"); }; ray::Status ObjectDirectory::GetInformation(const ClientID &client_id, - const InfoSuccessCallback &success_cb, - const InfoFailureCallback &fail_cb) { - gcs_client_->client_table().GetClientInformation( - client_id, - [this, success_cb, client_id](ClientInformation client_info) { - const auto &info = - RemoteConnectionInfo(client_id, client_info.GetIp(), client_info.GetPort()); - success_cb(info); - }, - fail_cb); + const InfoSuccessCallback &success_callback, + const InfoFailureCallback &fail_callback) { + const ClientTableDataT &data = gcs_client_->client_table().GetClient(client_id); + ClientID result_client_id = ClientID::from_binary(data.client_id); + if (result_client_id == ClientID::nil() || !data.is_insertion) { + fail_callback(ray::Status::RedisError("ClientID not found.")); + } else { + const auto &info = RemoteConnectionInfo(client_id, data.node_manager_address, + (uint16_t)data.object_manager_port); + success_callback(info); + } return ray::Status::OK(); }; ray::Status ObjectDirectory::GetLocations(const ObjectID &object_id, - const OnLocationsSuccess &success_cb, - const OnLocationsFailure &fail_cb) { + const OnLocationsSuccess &success_callback, + const OnLocationsFailure &fail_callback) { ray::Status status_code = ray::Status::OK(); if (existing_requests_.count(object_id) == 0) { - existing_requests_[object_id] = ODCallbacks({success_cb, fail_cb}); + existing_requests_[object_id] = ODCallbacks({success_callback, fail_callback}); status_code = ExecuteGetLocations(object_id); } else { // Do nothing. A request is in progress. @@ -44,52 +56,44 @@ ray::Status ObjectDirectory::GetLocations(const ObjectID &object_id, }; ray::Status ObjectDirectory::ExecuteGetLocations(const ObjectID &object_id) { - // TODO(hme): Avoid callback hell. 
- std::vector remote_connections; - ray::Status status = gcs_client_->object_table().GetObjectClientIDs( - object_id, - [this, object_id, &remote_connections](const std::vector &client_ids) { - gcs_client_->client_table().GetClientInformationSet( - client_ids, - [this, object_id, - &remote_connections](const std::vector &info_vec) { - for (const auto &client_info : info_vec) { - RemoteConnectionInfo info = - RemoteConnectionInfo(client_info.GetClientId(), client_info.GetIp(), - client_info.GetPort()); - remote_connections.push_back(info); - } - ray::Status cb_completion_status = - GetLocationsComplete(Status::OK(), object_id, remote_connections); - }, - [this, object_id, &remote_connections](const Status &status) { - ray::Status cb_completion_status = - GetLocationsComplete(status, object_id, remote_connections); - }); - }, - [this, object_id, &remote_connections](const Status &status) { - ray::Status cb_completion_status = - GetLocationsComplete(status, object_id, remote_connections); + JobID job_id = JobID::from_random(); + // Note: Lookup must be synchronous for thread-safe access. + // For now, this is only accessed by the main thread. + ray::Status status = gcs_client_->object_table().Lookup( + job_id, object_id, + [this, object_id](gcs::AsyncGcsClient *client, const ObjectID &object_id, + const std::vector &data) { + GetLocationsComplete(object_id, data); }); return status; }; -ray::Status ObjectDirectory::GetLocationsComplete( - const ray::Status &status, const ObjectID &object_id, - const std::vector &remote_connections) { - bool success = status.ok(); - // Only invoke a callback if the request was not cancelled. 
- if (existing_requests_.count(object_id) > 0) { - ODCallbacks cbs = existing_requests_[object_id]; - if (success) { - cbs.success_cb(remote_connections, object_id); +void ObjectDirectory::GetLocationsComplete( + const ObjectID &object_id, const std::vector &location_entries) { + auto request = existing_requests_.find(object_id); + // Do not invoke a callback if the request was cancelled. + if (request == existing_requests_.end()) { + return; + } + // Build the set of current locations based on the entries in the log. + std::unordered_set locations; + for (auto entry : location_entries) { + ClientID client_id = ClientID::from_binary(entry.manager); + if (!entry.is_eviction) { + locations.insert(client_id); } else { - cbs.fail_cb(status, object_id); + locations.erase(client_id); } } - existing_requests_.erase(object_id); - return status; -}; + // Invoke the callback. + std::vector locations_vector(locations.begin(), locations.end()); + if (locations_vector.empty()) { + request->second.fail_cb(object_id); + } else { + request->second.success_cb(locations_vector, object_id); + } + existing_requests_.erase(request); +} ray::Status ObjectDirectory::Cancel(const ObjectID &object_id) { existing_requests_.erase(object_id); diff --git a/src/ray/object_manager/object_directory.h b/src/ray/object_manager/object_directory.h index 463913125..b1a33de99 100644 --- a/src/ray/object_manager/object_directory.h +++ b/src/ray/object_manager/object_directory.h @@ -2,17 +2,19 @@ #define RAY_OBJECT_MANAGER_OBJECT_DIRECTORY_H #include +#include #include #include #include +#include "ray/gcs/client.h" #include "ray/id.h" -#include "ray/raylet/mock_gcs_client.h" #include "ray/status.h" namespace ray { struct RemoteConnectionInfo { + RemoteConnectionInfo() = default; RemoteConnectionInfo(const ClientID &id, const std::string &ip_address, uint16_t port_num) : client_id(id), ip(ip_address), port(port_num) {} @@ -42,10 +44,9 @@ class ObjectDirectoryInterface { const InfoFailureCallback &fail_cb) 
= 0; // Callbacks for GetLocations. - using OnLocationsSuccess = std::function &v, const ray::ObjectID &object_id)>; - using OnLocationsFailure = - std::function; + using OnLocationsSuccess = std::function &v, + const ray::ObjectID &object_id)>; + using OnLocationsFailure = std::function; /// Asynchronously obtain the locations of an object by ObjectID. /// This is used to handle object pulls. @@ -93,11 +94,11 @@ class ObjectDirectory : public ObjectDirectoryInterface { ~ObjectDirectory() override = default; ray::Status GetInformation(const ClientID &client_id, - const InfoSuccessCallback &success_cb, - const InfoFailureCallback &fail_cb) override; + const InfoSuccessCallback &success_callback, + const InfoFailureCallback &fail_callback) override; ray::Status GetLocations(const ObjectID &object_id, - const OnLocationsSuccess &success_cb, - const OnLocationsFailure &fail_cb) override; + const OnLocationsSuccess &success_callback, + const OnLocationsFailure &fail_callback) override; ray::Status Cancel(const ObjectID &object_id) override; ray::Status Terminate() override; ray::Status ReportObjectAdded(const ObjectID &object_id, @@ -105,7 +106,12 @@ class ObjectDirectory : public ObjectDirectoryInterface { ray::Status ReportObjectRemoved(const ObjectID &object_id, const ClientID &client_id) override; /// Ray only (not part of the OD interface). - ObjectDirectory(std::shared_ptr gcs_client); + ObjectDirectory(std::shared_ptr gcs_client); + + /// This object cannot be copied for thread-safety. + ObjectDirectory &operator=(const ObjectDirectory &o) { + throw std::runtime_error("Can't copy ObjectDirectory."); + } private: /// Callbacks associated with a call to GetLocations. @@ -115,18 +121,17 @@ class ObjectDirectory : public ObjectDirectoryInterface { OnLocationsFailure fail_cb; }; - /// Maintain map of in-flight GetLocation requests. - std::unordered_map existing_requests_; - - /// Reference to the gcs client. 
- std::shared_ptr gcs_client_; - /// GetLocations registers a request for locations. /// This function actually carries out that request. ray::Status ExecuteGetLocations(const ObjectID &object_id); /// Invoked when call to ExecuteGetLocations completes. - ray::Status GetLocationsComplete(const ray::Status &status, const ObjectID &object_id, - const std::vector &v); + void GetLocationsComplete(const ObjectID &object_id, + const std::vector &location_entries); + + /// Maintain map of in-flight GetLocation requests. + std::unordered_map existing_requests_; + /// Reference to the gcs client. + std::shared_ptr gcs_client_; }; } // namespace ray diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 508a24ca4..a815efd5e 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -1,58 +1,80 @@ -#include "object_manager.h" +#include "ray/object_manager/object_manager.h" + +namespace asio = boost::asio; + +namespace object_manager_protocol = ray::object_manager::protocol; namespace ray { -ObjectManager::ObjectManager(boost::asio::io_service &io_service, - ObjectManagerConfig config, - std::shared_ptr gcs_client) - : object_directory_(new ObjectDirectory(gcs_client)), work_(io_service_) { +ObjectManager::ObjectManager(asio::io_service &main_service, + std::unique_ptr object_manager_service, + const ObjectManagerConfig &config, + std::shared_ptr gcs_client) + // TODO(hme): Eliminate knowledge of GCS. 
+ : client_id_(gcs_client->client_table().GetLocalClientId()), + object_directory_(new ObjectDirectory(gcs_client)), + store_notification_(main_service, config.store_socket_name), + store_pool_(config.store_socket_name), + object_manager_service_(std::move(object_manager_service)), + work_(*object_manager_service_), + connection_pool_(), + transfer_queue_(), + num_transfers_send_(0), + num_transfers_receive_(0) { + main_service_ = &main_service; config_ = config; - store_client_ = std::unique_ptr( - new ObjectStoreClient(io_service, config.store_socket_name)); - store_client_->SubscribeObjAdded( + store_notification_.SubscribeObjAdded( [this](const ObjectID &oid) { NotifyDirectoryObjectAdd(oid); }); - store_client_->SubscribeObjDeleted( + store_notification_.SubscribeObjDeleted( [this](const ObjectID &oid) { NotifyDirectoryObjectDeleted(oid); }); StartIOService(); -}; +} -ObjectManager::ObjectManager(boost::asio::io_service &io_service, - ObjectManagerConfig config, +ObjectManager::ObjectManager(asio::io_service &main_service, + std::unique_ptr object_manager_service, + const ObjectManagerConfig &config, std::unique_ptr od) - : object_directory_(std::move(od)), work_(io_service_) { + : object_directory_(std::move(od)), + store_notification_(main_service, config.store_socket_name), + store_pool_(config.store_socket_name), + object_manager_service_(std::move(object_manager_service)), + work_(*object_manager_service_), + connection_pool_(), + transfer_queue_(), + num_transfers_send_(0), + num_transfers_receive_(0) { + // TODO(hme) Client ID is never set with this constructor. 
+ main_service_ = &main_service; config_ = config; - store_client_ = std::unique_ptr( - new ObjectStoreClient(io_service, config.store_socket_name)); - store_client_->SubscribeObjAdded( + store_notification_.SubscribeObjAdded( [this](const ObjectID &oid) { NotifyDirectoryObjectAdd(oid); }); - store_client_->SubscribeObjDeleted( + store_notification_.SubscribeObjDeleted( [this](const ObjectID &oid) { NotifyDirectoryObjectDeleted(oid); }); StartIOService(); -}; +} void ObjectManager::StartIOService() { - io_thread_ = std::thread(&ObjectManager::IOServiceLoop, this); - // thread_group_.create_thread(boost::bind(&boost::asio::io_service::run, - // &io_service_)); + for (int i = 0; i < config_.num_threads; ++i) { + io_threads_.emplace_back(std::thread(&ObjectManager::IOServiceLoop, this)); + } } -void ObjectManager::IOServiceLoop() { io_service_.run(); } +void ObjectManager::IOServiceLoop() { object_manager_service_->run(); } void ObjectManager::StopIOService() { - io_service_.stop(); - io_thread_.join(); - // thread_group_.join_all(); + object_manager_service_->stop(); + for (int i = 0; i < config_.num_threads; ++i) { + io_threads_[i].join(); + } } -void ObjectManager::SetClientID(const ClientID &client_id) { client_id_ = client_id; } - -ClientID ObjectManager::GetClientID() { return client_id_; } - void ObjectManager::NotifyDirectoryObjectAdd(const ObjectID &object_id) { + local_objects_.insert(object_id); ray::Status status = object_directory_->ReportObjectAdded(object_id, client_id_); } void ObjectManager::NotifyDirectoryObjectDeleted(const ObjectID &object_id) { + local_objects_.erase(object_id); ray::Status status = object_directory_->ReportObjectRemoved(object_id, client_id_); } @@ -60,391 +82,450 @@ ray::Status ObjectManager::Terminate() { StopIOService(); ray::Status status_code = object_directory_->Terminate(); // TODO: evaluate store client termination status. 
- store_client_->Terminate(); + store_notification_.Terminate(); + store_pool_.Terminate(); return status_code; -}; +} ray::Status ObjectManager::SubscribeObjAdded( std::function callback) { - store_client_->SubscribeObjAdded(callback); + store_notification_.SubscribeObjAdded(callback); return ray::Status::OK(); -}; +} ray::Status ObjectManager::SubscribeObjDeleted( std::function callback) { - store_client_->SubscribeObjDeleted(callback); + store_notification_.SubscribeObjDeleted(callback); return ray::Status::OK(); -}; +} ray::Status ObjectManager::Pull(const ObjectID &object_id) { - // TODO(hme): Need to correct. Workaround to get all pull requests on the same thread. - SchedulePull(object_id, 0); + main_service_->dispatch( + [this, object_id]() { RAY_CHECK_OK(PullGetLocations(object_id)); }); return Status::OK(); -}; +} void ObjectManager::SchedulePull(const ObjectID &object_id, int wait_ms) { - pull_requests_[object_id] = Timer(new boost::asio::deadline_timer( - io_service_, boost::posix_time::milliseconds(wait_ms))); + pull_requests_[object_id] = std::shared_ptr( + new asio::deadline_timer(*main_service_, boost::posix_time::milliseconds(wait_ms))); pull_requests_[object_id]->async_wait( [this, object_id](const boost::system::error_code &error_code) { - RAY_CHECK_OK(SchedulePullHandler(object_id)); + pull_requests_.erase(object_id); + main_service_->dispatch( + [this, object_id]() { RAY_CHECK_OK(PullGetLocations(object_id)); }); }); } -ray::Status ObjectManager::SchedulePullHandler(const ObjectID &object_id) { - pull_requests_.erase(object_id); +ray::Status ObjectManager::PullGetLocations(const ObjectID &object_id) { ray::Status status_code = object_directory_->GetLocations( object_id, - [this](const std::vector &vec, const ObjectID &object_id) { - return GetLocationsSuccess(vec, object_id); + [this](const std::vector &client_ids, const ObjectID &object_id) { + return GetLocationsSuccess(client_ids, object_id); }, - [this](ray::Status status, const ObjectID 
&object_id) { - return GetLocationsFailed(status, object_id); - }); + [this](const ObjectID &object_id) { return GetLocationsFailed(object_id); }); return status_code; } -void ObjectManager::GetLocationsSuccess(const std::vector &vec, +void ObjectManager::GetLocationsSuccess(const std::vector &client_ids, const ray::ObjectID &object_id) { - RemoteConnectionInfo info = vec.front(); + RAY_CHECK(!client_ids.empty()); + ClientID client_id = client_ids.front(); pull_requests_.erase(object_id); - ray::Status status_code = Pull(object_id, info.client_id); -}; + ray::Status status_code = Pull(object_id, client_id); +} -void ObjectManager::GetLocationsFailed(ray::Status status, const ObjectID &object_id) { +void ObjectManager::GetLocationsFailed(const ObjectID &object_id) { SchedulePull(object_id, config_.pull_timeout_ms); -}; +} ray::Status ObjectManager::Pull(const ObjectID &object_id, const ClientID &client_id) { - Status status = - GetMsgConnection(client_id, [this, object_id](SenderConnection::pointer client) { - Status status = ExecutePull(object_id, client); - }); - return status; + main_service_->dispatch([this, object_id, client_id]() { + RAY_CHECK_OK(PullEstablishConnection(object_id, client_id)); + }); + return Status::OK(); }; -ray::Status ObjectManager::ExecutePull(const ObjectID &object_id, - SenderConnection::pointer conn) { - size_t message_type = OMMessageType_PullRequest; - boost::system::error_code error_code; - boost::asio::write(conn->GetSocket(), - boost::asio::buffer(&message_type, sizeof(message_type)), - error_code); +ray::Status ObjectManager::PullEstablishConnection(const ObjectID &object_id, + const ClientID &client_id) { + // Check if object is already local, and client_id is not itself. + if (local_objects_.count(object_id) != 0 || client_id == client_id_) { + return ray::Status::OK(); + } + + // Acquire a message connection and send pull request. 
+ ray::Status status; + std::shared_ptr conn; + // TODO(hme): There is no cap on the number of pull request connections. + status = connection_pool_.GetSender(ConnectionPool::ConnectionType::MESSAGE, client_id, + &conn); + if (!status.ok()) { + // TODO(hme): Keep track of retries, + // and only retry on object not local + // for now. + SchedulePull(object_id, config_.pull_timeout_ms); + return status; + } + if (conn == nullptr) { + status = object_directory_->GetInformation( + client_id, + [this, object_id, client_id](const RemoteConnectionInfo &connection_info) { + std::shared_ptr async_conn = CreateSenderConnection( + ConnectionPool::ConnectionType::MESSAGE, connection_info); + connection_pool_.RegisterSender(ConnectionPool::ConnectionType::MESSAGE, + client_id, async_conn); + RAY_CHECK_OK(PullSendRequest(object_id, async_conn)); + }, + [this, object_id](const Status &status) { + SchedulePull(object_id, config_.pull_timeout_ms); + }); + } else { + RAY_CHECK_OK(PullSendRequest(object_id, conn)); + } + return status; +} + +ray::Status ObjectManager::PullSendRequest(const ObjectID &object_id, + std::shared_ptr conn) { flatbuffers::FlatBufferBuilder fbb; - auto message = CreatePullRequest(fbb, fbb.CreateString(client_id_.binary()), - fbb.CreateString(object_id.binary())); + auto message = object_manager_protocol::CreatePullRequestMessage( + fbb, fbb.CreateString(client_id_.binary()), fbb.CreateString(object_id.binary())); fbb.Finish(message); - size_t length = fbb.GetSize(); - std::vector buffer; - buffer.push_back(boost::asio::buffer(&length, sizeof(length))); - buffer.push_back(boost::asio::buffer(fbb.GetBufferPointer(), length)); - boost::asio::write(conn->GetSocket(), buffer); + RAY_CHECK_OK(conn->WriteMessage(object_manager_protocol::MessageType_PullRequest, + fbb.GetSize(), fbb.GetBufferPointer())); + RAY_CHECK_OK( + connection_pool_.ReleaseSender(ConnectionPool::ConnectionType::MESSAGE, conn)); return ray::Status::OK(); -}; +} ray::Status 
ObjectManager::Push(const ObjectID &object_id, const ClientID &client_id) { - ray::Status status; - status = - GetTransferConnection(client_id, [this, object_id](SenderConnection::pointer conn) { - ray::Status status = QueuePush(object_id, conn); - }); + // TODO(hme): Cache this data in ObjectDirectory. + // Okay for now since the GCS client caches this data. + main_service_->dispatch([this, object_id, client_id]() { + Status status = object_directory_->GetInformation( + client_id, + [this, object_id, client_id](const RemoteConnectionInfo &info) { + transfer_queue_.QueueSend(client_id, object_id, info); + RAY_CHECK_OK(DequeueTransfers()); + }, + [this](const Status &status) { + // Push is best effort, so do nothing here. + }); + RAY_CHECK_OK(status); + }); + return ray::Status::OK(); +} + +ray::Status ObjectManager::DequeueTransfers() { + ray::Status status = ray::Status::OK(); + // Dequeue sends. + while (true) { + if (std::atomic_fetch_add(&num_transfers_send_, 1) <= config_.max_sends) { + TransferQueue::SendRequest req; + bool exists = transfer_queue_.DequeueSendIfPresent(&req); + if (exists) { + object_manager_service_->dispatch([this, req]() { + RAY_LOG(DEBUG) << "DequeueSend " << client_id_ << " " << req.object_id << " " + << num_transfers_send_ << "/" << config_.max_sends; + RAY_CHECK_OK( + ExecuteSendObject(req.object_id, req.client_id, req.connection_info)); + }); + } else { + std::atomic_fetch_sub(&num_transfers_send_, 1); + break; + } + } else { + std::atomic_fetch_sub(&num_transfers_send_, 1); + break; + } + } + // Dequeue receives. 
+ while (true) { + if (std::atomic_fetch_add(&num_transfers_receive_, 1) <= config_.max_receives) { + TransferQueue::ReceiveRequest req; + bool exists = transfer_queue_.DequeueReceiveIfPresent(&req); + if (exists) { + object_manager_service_->dispatch([this, req]() { + RAY_LOG(DEBUG) << "DequeueReceive " << client_id_ << " " << req.object_id << " " + << num_transfers_receive_ << "/" << config_.max_receives; + RAY_CHECK_OK(ExecuteReceiveObject(req.client_id, req.object_id, req.object_size, + req.conn)); + }); + } else { + std::atomic_fetch_sub(&num_transfers_receive_, 1); + break; + } + } else { + std::atomic_fetch_sub(&num_transfers_receive_, 1); + break; + } + } return status; +} + +ray::Status ObjectManager::TransferCompleted(TransferQueue::TransferType type) { + if (type == TransferQueue::TransferType::SEND) { + std::atomic_fetch_sub(&num_transfers_send_, 1); + } else { + std::atomic_fetch_sub(&num_transfers_receive_, 1); + } + return DequeueTransfers(); }; +ray::Status ObjectManager::ExecuteSendObject( + const ObjectID &object_id, const ClientID &client_id, + const RemoteConnectionInfo &connection_info) { + ray::Status status; + std::shared_ptr conn; + status = connection_pool_.GetSender(ConnectionPool::ConnectionType::TRANSFER, client_id, + &conn); + if (!status.ok()) { + // TODO(hme): Keep track of retries, + // and only retry on object not local + // for now. 
+ RAY_CHECK_OK(Push(object_id, conn->GetClientID())); + return Status::OK(); + } + if (conn == nullptr) { + conn = + CreateSenderConnection(ConnectionPool::ConnectionType::TRANSFER, connection_info); + connection_pool_.RegisterSender(ConnectionPool::ConnectionType::TRANSFER, client_id, + conn); + } + status = SendObjectHeaders(object_id, conn); + if (!status.ok()) { + RAY_CHECK_OK(Push(object_id, conn->GetClientID())); + return Status::OK(); + } + return Status::OK(); +} + +ray::Status ObjectManager::SendObjectHeaders(const ObjectID &object_id_const, + std::shared_ptr conn) { + ObjectID object_id = ObjectID(object_id_const); + // Allocate and append the request to the transfer queue. + plasma::ObjectBuffer object_buffer; + plasma::ObjectID plasma_id = object_id.to_plasma_id(); + std::shared_ptr store_client = store_pool_.GetObjectStore(); + ARROW_CHECK_OK(store_client->Get(&plasma_id, 1, 0, &object_buffer)); + if (object_buffer.data_size == -1) { + RAY_LOG(ERROR) << "Failed to get object"; + // If the object wasn't locally available, exit immediately. If the object + // later appears locally, the requesting plasma manager should request the + // transfer again. + RAY_CHECK_OK( + connection_pool_.ReleaseSender(ConnectionPool::ConnectionType::TRANSFER, conn)); + return ray::Status::IOError( + "Unable to transfer object to requesting plasma manager, object not local."); + } + RAY_CHECK(object_buffer.metadata->data() == + object_buffer.data->data() + object_buffer.data_size); + + TransferQueue::SendContext context; + context.client_id = conn->GetClientID(); + context.object_id = object_id; + context.object_size = static_cast(object_buffer.data_size); + context.data = const_cast(object_buffer.data->data()); + UniqueID context_id = transfer_queue_.AddContext(context); + + // Create buffer. 
+ flatbuffers::FlatBufferBuilder fbb; + // TODO(hme): use to_flatbuf + auto message = object_manager_protocol::CreatePushRequestMessage( + fbb, fbb.CreateString(object_id.binary()), context.object_size); + fbb.Finish(message); + ray::Status status = + conn->WriteMessage(object_manager_protocol::MessageType_PushRequest, fbb.GetSize(), + fbb.GetBufferPointer()); + if (!status.ok()) { + // push failed. + // TODO(hme): Trash sender. + RAY_CHECK_OK( + connection_pool_.ReleaseSender(ConnectionPool::ConnectionType::TRANSFER, conn)); + return status; + } + + // TODO(hme): Make this async. + return SendObjectData(conn, context_id, store_client); +} + +ray::Status ObjectManager::SendObjectData( + std::shared_ptr conn, const UniqueID &context_id, + std::shared_ptr store_client) { + TransferQueue::SendContext context = transfer_queue_.GetContext(context_id); + boost::system::error_code ec; + std::vector buffer; + buffer.push_back(asio::buffer(context.data, context.object_size)); + conn->WriteBuffer(buffer, ec); + + ray::Status status = ray::Status::OK(); + if (ec.value() != 0) { + // push failed. + // TODO(hme): Trash sender. + status = ray::Status::IOError(ec.message()); + } + + // Do this regardless of whether it failed or succeeded. + ARROW_CHECK_OK(store_client->Release(context.object_id.to_plasma_id())); + store_pool_.ReleaseObjectStore(store_client); + RAY_CHECK_OK( + connection_pool_.ReleaseSender(ConnectionPool::ConnectionType::TRANSFER, conn)); + RAY_CHECK_OK(transfer_queue_.RemoveContext(context_id)); + RAY_LOG(DEBUG) << "SendCompleted " << client_id_ << " " << context.object_id << " " + << num_transfers_send_ << "/" << config_.max_sends; + RAY_CHECK_OK(TransferCompleted(TransferQueue::TransferType::SEND)); + return status; +} + ray::Status ObjectManager::Cancel(const ObjectID &object_id) { // TODO(hme): Account for pull timers. 
ray::Status status = object_directory_->Cancel(object_id); return ray::Status::OK(); -}; +} ray::Status ObjectManager::Wait(const std::vector &object_ids, uint64_t timeout_ms, int num_ready_objects, const WaitCallback &callback) { // TODO: Implement wait. return ray::Status::OK(); -}; - -ray::Status ObjectManager::GetMsgConnection( - const ClientID &client_id, std::function callback) { - ray::Status status = Status::OK(); - if (message_send_connections_.count(client_id) > 0) { - callback(message_send_connections_[client_id]); - } else { - status = object_directory_->GetInformation( - client_id, - [this, callback](RemoteConnectionInfo info) { - Status status = CreateMsgConnection(info, callback); - }, - [this](const Status &status) { - // TODO: deal with failure. - }); - } - return status; -}; - -ray::Status ObjectManager::CreateMsgConnection( - const RemoteConnectionInfo &info, - std::function callback) { - message_send_connections_.emplace( - info.client_id, SenderConnection::Create(io_service_, info.ip, info.port)); - // Prepare client connection info buffer. - flatbuffers::FlatBufferBuilder fbb; - bool is_transfer = false; - auto message = - CreateClientConnectionInfo(fbb, fbb.CreateString(client_id_.binary()), is_transfer); - fbb.Finish(message); - // Pack into asio buffer. - size_t length = fbb.GetSize(); - std::vector buffer; - buffer.push_back(boost::asio::buffer(&length, sizeof(length))); - buffer.push_back(boost::asio::buffer(fbb.GetBufferPointer(), length)); - // Send synchronously. - SenderConnection::pointer conn = message_send_connections_[info.client_id]; - boost::system::error_code error; - boost::asio::write(conn->GetSocket(), buffer); - // The connection is ready, invoke callback with connection info. 
- callback(message_send_connections_[info.client_id]); - return Status::OK(); -}; - -ray::Status ObjectManager::GetTransferConnection( - const ClientID &client_id, std::function callback) { - ray::Status status = Status::OK(); - if (transfer_send_connections_.count(client_id) > 0) { - callback(transfer_send_connections_[client_id]); - } else { - status = object_directory_->GetInformation( - client_id, - [this, callback](RemoteConnectionInfo info) { - Status status = CreateTransferConnection(info, callback); - }, - [this](const Status &status) { - // TODO(hme): deal with failure. - }); - } - return status; -}; - -ray::Status ObjectManager::CreateTransferConnection( - const RemoteConnectionInfo &info, - std::function callback) { - transfer_send_connections_.emplace( - info.client_id, SenderConnection::Create(io_service_, info.ip, info.port)); - // Prepare client connection info buffer. - flatbuffers::FlatBufferBuilder fbb; - bool is_transfer = true; - auto message = - CreateClientConnectionInfo(fbb, fbb.CreateString(client_id_.binary()), is_transfer); - fbb.Finish(message); - // Pack into asio buffer. - size_t length = fbb.GetSize(); - std::vector buffer; - buffer.push_back(boost::asio::buffer(&length, sizeof(length))); - buffer.push_back(boost::asio::buffer(fbb.GetBufferPointer(), length)); - // Send synchronously. 
- SenderConnection::pointer conn = transfer_send_connections_[info.client_id]; - boost::system::error_code ec; - boost::asio::write(conn->GetSocket(), buffer, ec); - callback(transfer_send_connections_[info.client_id]); - return Status::OK(); -}; - -ray::Status ObjectManager::AcceptConnection(TCPClientConnection::pointer conn) { - boost::system::error_code ec; - // read header - size_t length; - std::vector header; - header.push_back(boost::asio::buffer(&length, sizeof(length))); - boost::asio::read(conn->GetSocket(), header, ec); - // read data - std::vector message; - message.resize(length); - boost::asio::read(conn->GetSocket(), boost::asio::buffer(message), ec); - // Serialize - auto info = flatbuffers::GetRoot(message.data()); - ClientID client_id = ObjectID::from_binary(info->client_id()->str()); - bool is_transfer = info->is_transfer(); - // TODO: trash connection if either fails. - if (is_transfer) { - transfer_receive_connections_[client_id] = conn; - Status status = WaitPushReceive(conn); - return status; - } else { - message_receive_connections_[client_id] = conn; - Status status = WaitMessage(conn); - return status; - } -}; - -ray::Status ObjectManager::WaitPushReceive(TCPClientConnection::pointer conn) { - boost::asio::async_read( - conn->GetSocket(), - boost::asio::buffer(&conn->message_length_, sizeof(conn->message_length_)), - boost::bind(&ObjectManager::HandlePushReceive, this, conn, - boost::asio::placeholders::error)); - return ray::Status::OK(); } -void ObjectManager::HandlePushReceive(TCPClientConnection::pointer conn, - BoostEC length_ec) { - std::vector message; - message.resize(conn->message_length_); - boost::system::error_code ec; - boost::asio::read(conn->GetSocket(), boost::asio::buffer(message), ec); +std::shared_ptr ObjectManager::CreateSenderConnection( + ConnectionPool::ConnectionType type, RemoteConnectionInfo info) { + std::shared_ptr conn = SenderConnection::Create( + *object_manager_service_, info.client_id, info.ip, info.port); 
+ // Prepare client connection info buffer + flatbuffers::FlatBufferBuilder fbb; + bool is_transfer = (type == ConnectionPool::ConnectionType::TRANSFER); + auto message = object_manager_protocol::CreateConnectClientMessage( + fbb, fbb.CreateString(client_id_.binary()), is_transfer); + fbb.Finish(message); + // Send synchronously. + RAY_CHECK_OK(conn->WriteMessage(object_manager_protocol::MessageType_ConnectClient, + fbb.GetSize(), fbb.GetBufferPointer())); + // The connection is ready; return to caller. + return conn; +} + +void ObjectManager::ProcessNewClient(std::shared_ptr conn) { + conn->ProcessMessages(); +} + +void ObjectManager::ProcessClientMessage(std::shared_ptr conn, + int64_t message_type, const uint8_t *message) { + switch (message_type) { + case object_manager_protocol::MessageType_PushRequest: { + ReceivePushRequest(conn, message); + break; + } + case object_manager_protocol::MessageType_PullRequest: { + ReceivePullRequest(conn, message); + break; + } + case object_manager_protocol::MessageType_ConnectClient: { + ConnectClient(conn, message); + break; + } + case object_manager_protocol::MessageType_DisconnectClient: { + DisconnectClient(conn, message); + break; + } + default: { RAY_LOG(FATAL) << "invalid request " << message_type; } + } +} + +void ObjectManager::ConnectClient(std::shared_ptr &conn, + const uint8_t *message) { + // TODO: trash connection on failure. 
+ auto info = + flatbuffers::GetRoot(message); + ClientID client_id = ObjectID::from_binary(info->client_id()->str()); + bool is_transfer = info->is_transfer(); + conn->SetClientID(client_id); + if (is_transfer) { + connection_pool_.RegisterReceiver(ConnectionPool::ConnectionType::TRANSFER, client_id, + conn); + } else { + connection_pool_.RegisterReceiver(ConnectionPool::ConnectionType::MESSAGE, client_id, + conn); + } + conn->ProcessMessages(); +} + +void ObjectManager::DisconnectClient(std::shared_ptr &conn, + const uint8_t *message) { + auto info = + flatbuffers::GetRoot(message); + ClientID client_id = ObjectID::from_binary(info->client_id()->str()); + bool is_transfer = info->is_transfer(); + if (is_transfer) { + connection_pool_.RemoveReceiver(ConnectionPool::ConnectionType::TRANSFER, client_id, + conn); + } else { + connection_pool_.RemoveReceiver(ConnectionPool::ConnectionType::MESSAGE, client_id, + conn); + } +} + +void ObjectManager::ReceivePullRequest(std::shared_ptr &conn, + const uint8_t *message) { + // Serialize and push object to requesting client. + auto pr = flatbuffers::GetRoot(message); + ObjectID object_id = ObjectID::from_binary(pr->object_id()->str()); + ClientID client_id = ClientID::from_binary(pr->client_id()->str()); + ray::Status push_status = Push(object_id, client_id); + conn->ProcessMessages(); +} + +void ObjectManager::ReceivePushRequest(std::shared_ptr conn, + const uint8_t *message) { // Serialize. 
- auto object_header = flatbuffers::GetRoot(message.data()); + auto object_header = + flatbuffers::GetRoot(message); ObjectID object_id = ObjectID::from_binary(object_header->object_id()->str()); int64_t object_size = (int64_t)object_header->object_size(); + transfer_queue_.QueueReceive(conn->GetClientID(), object_id, object_size, conn); + RAY_CHECK_OK(DequeueTransfers()); +} + +ray::Status ObjectManager::ExecuteReceiveObject( + const ClientID &client_id, const ObjectID &object_id, uint64_t object_size, + std::shared_ptr conn) { + boost::system::error_code ec; int64_t metadata_size = 0; + const plasma::ObjectID plasma_id = ObjectID(object_id).to_plasma_id(); // Try to create shared buffer. std::shared_ptr data; - arrow::Status s = store_client_->GetClient().Create( - object_id.to_plasma_id(), object_size, NULL, metadata_size, &data); + std::shared_ptr store_client = store_pool_.GetObjectStore(); + arrow::Status s = + store_client->Create(plasma_id, object_size, NULL, metadata_size, &data); + std::vector buffer; if (s.ok()) { // Read object into store. uint8_t *mutable_data = data->mutable_data(); - boost::asio::read(conn->GetSocket(), boost::asio::buffer(mutable_data, object_size), - ec); + buffer.push_back(asio::buffer(mutable_data, object_size)); + conn->ReadBuffer(buffer, ec); if (!ec.value()) { - ARROW_CHECK_OK(store_client_->GetClient().Seal(object_id.to_plasma_id())); - ARROW_CHECK_OK(store_client_->GetClient().Release(object_id.to_plasma_id())); + ARROW_CHECK_OK(store_client->Seal(plasma_id)); + ARROW_CHECK_OK(store_client->Release(plasma_id)); } else { - ARROW_CHECK_OK(store_client_->GetClient().Release(object_id.to_plasma_id())); - ARROW_CHECK_OK(store_client_->GetClient().Abort(object_id.to_plasma_id())); + ARROW_CHECK_OK(store_client->Release(plasma_id)); + ARROW_CHECK_OK(store_client->Abort(plasma_id)); RAY_LOG(ERROR) << "Receive Failed"; } } else { RAY_LOG(ERROR) << "Buffer Create Failed: " << s.message(); // Read object into empty buffer. 
- uint8_t *mutable_data = (uint8_t *)malloc(object_size + metadata_size); - boost::asio::read(conn->GetSocket(), boost::asio::buffer(mutable_data, object_size), - ec); + std::vector mutable_data; + mutable_data.resize(object_size + metadata_size); + buffer.push_back(asio::buffer(mutable_data, object_size + metadata_size)); + conn->ReadBuffer(buffer, ec); } - // Wait for another push. - ray::Status ray_status = WaitPushReceive(conn); -}; - -ray::Status ObjectManager::QueuePush(const ObjectID &object_id_const, - SenderConnection::pointer conn) { - ObjectID object_id = ObjectID(object_id_const); - if (conn->ObjectIdQueued(object_id)) { - // For now, return with status OK if the object is already in the send queue. - return ray::Status::OK(); - } - conn->QueueObjectId(object_id); - if (num_transfers_ < max_transfers_) { - return ExecutePushQueue(conn); - } - return ray::Status::OK(); -}; - -ray::Status ObjectManager::ExecutePushQueue(SenderConnection::pointer conn) { - ray::Status status = ray::Status::OK(); - while (num_transfers_ < max_transfers_) { - if (conn->IsObjectIdQueueEmpty()) { - return ray::Status::OK(); - } - ObjectID object_id = conn->DequeueObjectId(); - // The threads that increment/decrement num_transfers_ are different. - // It's important to increment num_transfers_ before executing the push. - num_transfers_ += 1; - status = ExecutePushHeaders(object_id, conn); - } - return status; -}; - -ray::Status ObjectManager::ExecutePushHeaders(const ObjectID &object_id_const, - SenderConnection::pointer conn) { - ObjectID object_id = ObjectID(object_id_const); - // Allocate and append the request to the transfer queue. - plasma::ObjectBuffer object_buffer; - plasma::ObjectID plasma_id = object_id.to_plasma_id(); - ARROW_CHECK_OK(store_client_->GetClientOther().Get(&plasma_id, 1, 0, &object_buffer)); - if (object_buffer.data_size == -1) { - RAY_LOG(ERROR) << "Failed to get object"; - // If the object wasn't locally available, exit immediately. 
If the object - // later appears locally, the requesting plasma manager should request the - // transfer again. - return ray::Status::IOError( - "Unable to transfer object to requesting plasma manager, object not local."); - } - RAY_CHECK(object_buffer.metadata->data() == - object_buffer.data->data() + object_buffer.data_size); - SendRequest send_request; - send_request.object_id = object_id; - send_request.object_size = object_buffer.data_size; - send_request.data = const_cast(object_buffer.data->data()); - conn->AddSendRequest(object_id, send_request); - // Create buffer. - flatbuffers::FlatBufferBuilder fbb; - auto message = CreateObjectHeader(fbb, fbb.CreateString(object_id.binary()), - send_request.object_size); - fbb.Finish(message); - // Pack into asio buffer. - size_t length = fbb.GetSize(); - std::vector buffer; - buffer.push_back(boost::asio::buffer(&length, sizeof(length))); - buffer.push_back(boost::asio::buffer(fbb.GetBufferPointer(), length)); - // Send asynchronously. - boost::asio::async_write(conn->GetSocket(), buffer, - boost::bind(&ObjectManager::ExecutePushObject, this, conn, - object_id, boost::asio::placeholders::error)); - return ray::Status::OK(); -}; - -void ObjectManager::ExecutePushObject(SenderConnection::pointer conn, - const ObjectID &object_id, - const boost::system::error_code &header_ec) { - SendRequest &send_request = conn->GetSendRequest(object_id); - boost::system::error_code ec; - boost::asio::write( - conn->GetSocket(), - boost::asio::buffer(send_request.data, (size_t)send_request.object_size), ec); - // Do this regardless of whether it failed or succeeded. 
- ARROW_CHECK_OK( - store_client_->GetClientOther().Release(send_request.object_id.to_plasma_id())); - - ray::Status ray_status = ExecutePushCompleted(object_id, conn); -} - -ray::Status ObjectManager::ExecutePushCompleted(const ObjectID &object_id, - SenderConnection::pointer conn) { - conn->RemoveSendRequest(object_id); - num_transfers_ -= 1; - return ExecutePushQueue(conn); -}; - -ray::Status ObjectManager::WaitMessage(TCPClientConnection::pointer conn) { - boost::asio::async_read( - conn->GetSocket(), - boost::asio::buffer(&conn->message_type_, sizeof(conn->message_type_)), - boost::bind(&ObjectManager::HandleMessage, this, conn, - boost::asio::placeholders::error)); - return ray::Status::OK(); -} - -void ObjectManager::HandleMessage(TCPClientConnection::pointer conn, BoostEC msg_ec) { - switch (conn->message_type_) { - case OMMessageType_PullRequest: - ReceivePullRequest(conn); - } -} - -void ObjectManager::ReceivePullRequest(TCPClientConnection::pointer conn) { - boost::asio::read( - conn->GetSocket(), - boost::asio::buffer(&conn->message_length_, sizeof(conn->message_length_))); - std::vector message; - message.resize(conn->message_length_); - boost::system::error_code error_code; - boost::asio::read(conn->GetSocket(), boost::asio::buffer(message), error_code); - // Serialize. - auto pull_request = flatbuffers::GetRoot(message.data()); - ObjectID object_id = ObjectID::from_binary(pull_request->object_id()->str()); - ClientID client_id = ClientID::from_binary(pull_request->client_id()->str()); - // Push object to requesting client. 
- ray::Status push_status = Push(object_id, client_id); - ray::Status wait_status = WaitMessage(conn); + store_pool_.ReleaseObjectStore(store_client); + conn->ProcessMessages(); + RAY_LOG(DEBUG) << "ReceiveCompleted " << client_id_ << " " << object_id << " " + << num_transfers_receive_ << "/" << config_.max_receives; + RAY_CHECK_OK(TransferCompleted(TransferQueue::TransferType::RECEIVE)); + return Status::OK(); } } // namespace ray diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index 9ead640d4..2ab5ad20b 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -12,58 +12,65 @@ #include #include +#include "ray/common/client_connection.h" +#include "ray/id.h" +#include "ray/status.h" + #include "plasma/client.h" #include "plasma/events.h" #include "plasma/plasma.h" -#include "format/object_manager_generated.h" -#include "object_directory.h" -#include "object_manager_client_connection.h" -#include "object_store_client.h" -#include "ray/id.h" -#include "ray/status.h" +#include "ray/object_manager/connection_pool.h" +#include "ray/object_manager/format/object_manager_generated.h" +#include "ray/object_manager/object_directory.h" +#include "ray/object_manager/object_manager_client_connection.h" +#include "ray/object_manager/object_store_client_pool.h" +#include "ray/object_manager/object_store_notification_manager.h" +#include "ray/object_manager/transfer_queue.h" namespace ray { struct ObjectManagerConfig { - // The time in milliseconds to wait before retrying a pull - // that failed due to client id lookup. + /// The time in milliseconds to wait before retrying a pull + /// that failed due to client id lookup. int pull_timeout_ms = 100; + /// Size of thread pool. + int num_threads = 2; + /// Maximum number of sends allowed. + int max_sends = 20; + /// Maximum number of receives allowed. + int max_receives = 20; // TODO(hme): Implement num retries (to avoid infinite retries). 
std::string store_socket_name; }; -// TODO(hme): Comment everything doxygen-style. -// TODO(hme): Implement connection cleanup. // TODO(hme): Add success/failure callbacks for push and pull. -// TODO(hme): Use boost thread pool. -// TODO(hme): Add incoming connections to io_service tied to thread pool. class ObjectManager { public: /// Implicitly instantiates Ray implementation of ObjectDirectory. /// - /// \param io_service The asio io_service tied to the object manager. + /// \param main_service The main asio io_service. + /// \param object_manager_service The asio io_service tied to the object manager. /// \param config ObjectManager configuration. /// \param gcs_client A client connection to the Ray GCS. - explicit ObjectManager(boost::asio::io_service &io_service, ObjectManagerConfig config, - std::shared_ptr gcs_client); + explicit ObjectManager(boost::asio::io_service &main_service, + std::unique_ptr object_manager_service, + const ObjectManagerConfig &config, + std::shared_ptr gcs_client); /// Takes user-defined ObjectDirectoryInterface implementation. /// When this constructor is used, the ObjectManager assumes ownership of /// the given ObjectDirectory instance. /// - /// \param io_service The asio io_service tied to the object manager. + /// \param main_service The main asio io_service. + /// \param object_manager_service The asio io_service tied to the object manager. /// \param config ObjectManager configuration. /// \param od An object implementing the object directory interface. - explicit ObjectManager(boost::asio::io_service &io_service, ObjectManagerConfig config, + explicit ObjectManager(boost::asio::io_service &main_service, + std::unique_ptr object_manager_service, + const ObjectManagerConfig &config, std::unique_ptr od); - /// \param client_id Set the client id associated with this node. - void SetClientID(const ClientID &client_id); - - /// \return Get the client id associated with this node. 
- ClientID GetClientID(); - /// Subscribe to notifications of objects added to local store. /// Upon subscribing, the callback will be invoked for all objects that /// @@ -106,7 +113,17 @@ class ObjectManager { /// /// \param conn The connection. /// \return Status of whether the connection was successfully established. - ray::Status AcceptConnection(TCPClientConnection::pointer conn); + void ProcessNewClient(std::shared_ptr conn); + + /// Process messages sent from other nodes. We only establish + /// transfer connections using this method; all other transfer communication + /// is done separately. + /// + /// \param conn The connection. + /// \param message_type The message type. + /// \param message A pointer set to the beginning of the message. + void ProcessClientMessage(std::shared_ptr conn, + int64_t message_type, const uint8_t *message); /// Cancels all requests (Push/Pull) associated with the given ObjectID. /// @@ -114,7 +131,7 @@ class ObjectManager { /// \return Status of whether requests were successfully cancelled. ray::Status Cancel(const ObjectID &object_id); - // Callback definition for wait. + /// Callback definition for wait. using WaitCallback = std::function &)>; /// Wait for timeout_ms before invoking the provided callback. @@ -135,132 +152,137 @@ class ObjectManager { ray::Status Terminate(); private: - using BoostEC = const boost::system::error_code &; - ClientID client_id_; ObjectManagerConfig config_; std::unique_ptr object_directory_; - std::unique_ptr store_client_; + ObjectStoreNotificationManager store_notification_; + ObjectStoreClientPool store_pool_; /// An io service for creating connections to other object managers. - boost::asio::io_service io_service_; + /// This runs on a thread pool. + std::unique_ptr object_manager_service_; + /// Weak reference to main service. We ensure this object is destroyed before + /// main_service_ is stopped. 
+ boost::asio::io_service *main_service_; /// Used to create "work" for an io service, so when it's run, it doesn't exit. boost::asio::io_service::work work_; - /// Single thread for executing asynchronous handlers. - /// This runs the (currently only) io_service, which handles all outgoing requests - /// and object transfers (push). - std::thread io_thread_; + /// Thread pool for executing asynchronous handlers. + /// These run the object_manager_service_, which handle + /// all incoming and outgoing object transfers. + std::vector io_threads_; - /// Relatively simple way to add thread pooling. - /// boost::thread_group thread_group_; + /// Connection pool for reusing outgoing connections to remote object managers. + ConnectionPool connection_pool_; /// Timeout for failed pull requests. - using Timer = std::shared_ptr; - std::unordered_map pull_requests_; + std::unordered_map, + UniqueIDHasher> + pull_requests_; - // TODO (hme): This needs to account for receives as well. - /// This number is incremented whenever a push is started. - int num_transfers_ = 0; - // TODO (hme): Allow for concurrent sends. - /// This is the maximum number of pushes allowed. - /// We can only increase this number if we increase the number of - /// plasma client connections. - int max_transfers_ = 1; + /// Allows control of concurrent object transfers. This is a global queue, + /// allowing for concurrent transfers with many object managers as well as + /// concurrent transfers, including both sends and receives, with a single + /// remote object manager. + TransferQueue transfer_queue_; - /// Note that (currently) receives take place on the main thread, - /// and sends take place on a dedicated thread. - std::unordered_map - message_send_connections_; - std::unordered_map - transfer_send_connections_; + /// Variables to track number of concurrent sends and receives. 
+ std::atomic num_transfers_send_; + std::atomic num_transfers_receive_; - std::unordered_map - message_receive_connections_; - std::unordered_map - transfer_receive_connections_; + /// Cache of locally available objects. + std::unordered_set local_objects_; /// Handle starting, running, and stopping asio io_service. void StartIOService(); void IOServiceLoop(); void StopIOService(); + /// Register object add with directory. + void NotifyDirectoryObjectAdd(const ObjectID &object_id); + + /// Register object remove with directory. + void NotifyDirectoryObjectDeleted(const ObjectID &object_id); + /// Wait wait_ms milliseconds before triggering a pull request for object_id. /// This is invoked when a pull fails. Only point of failure currently considered /// is GetLocationsFailed. void SchedulePull(const ObjectID &object_id, int wait_ms); - /// The handler for SchedulePull. Invokes a pull and removes the deadline timer - /// that was added to schedule the pull. - ray::Status SchedulePullHandler(const ObjectID &object_id); + /// Part of an asynchronous sequence of Pull methods. + /// Gets the location of an object before invoking PullEstablishConnection. + /// Guaranteed to execute on main_service_ thread. + /// Executes on main_service_ thread. + ray::Status PullGetLocations(const ObjectID &object_id); - /// Synchronously send a pull request. - /// Invoked once a connection to a remote manager that contains the required ObjectID - /// is established. - ray::Status ExecutePull(const ObjectID &object_id, SenderConnection::pointer conn); + /// Part of an asynchronous sequence of Pull methods. + /// Uses an existing connection or creates a connection to ClientID. + /// Executes on main_service_ thread. + ray::Status PullEstablishConnection(const ObjectID &object_id, + const ClientID &client_id); - /// Invoked once a connection to the remote manager to which the ObjectID - /// is to be sent is established. 
- ray::Status QueuePush(const ObjectID &object_id, SenderConnection::pointer client); - /// Starts as many queued pushes as possible without exceeding max_transfers_ - /// concurrent transfers. - ray::Status ExecutePushQueue(SenderConnection::pointer client); - /// Initiate a push. This method asynchronously sends the object id and object size + /// Private callback implementation for success on get location. Called from + /// ObjectDirectory. + void GetLocationsSuccess(const std::vector &client_ids, + const ray::ObjectID &object_id); + + /// Private callback implementation for failure on get location. Called from + /// ObjectDirectory. + void GetLocationsFailed(const ObjectID &object_id); + + /// Synchronously send a pull request via remote object manager connection. + /// Executes on main_service_ thread. + ray::Status PullSendRequest(const ObjectID &object_id, + std::shared_ptr conn); + + /// Starts as many queued sends and receives as possible without exceeding + /// config_.max_sends and config_.max_receives, respectively. + /// Executes on object_manager_service_ thread pool. + ray::Status DequeueTransfers(); + + std::shared_ptr CreateSenderConnection( + ConnectionPool::ConnectionType type, RemoteConnectionInfo info); + + /// Invoked when a transfer is completed. Invokes DequeueTransfers after + /// updating variables that track concurrent transfers. + /// Executes on object_manager_service_ thread pool. + ray::Status TransferCompleted(TransferQueue::TransferType type); + + /// Begin executing a send. + /// Executes on object_manager_service_ thread pool. + ray::Status ExecuteSendObject(const ObjectID &object_id, const ClientID &client_id, + const RemoteConnectionInfo &connection_info); + /// This method synchronously sends the object id and object size /// to the remote object manager. - ray::Status ExecutePushHeaders(const ObjectID &object_id, - SenderConnection::pointer client); - /// Called by the handler for ExecutePushMeta. 
+ /// Executes on object_manager_service_ thread pool. + ray::Status SendObjectHeaders(const ObjectID &object_id, + std::shared_ptr client); + /// This method initiates the actual object transfer. - void ExecutePushObject(SenderConnection::pointer conn, const ObjectID &object_id, - const boost::system::error_code &header_ec); - /// Invoked when a push is completed. This method will decrement num_transfers_ - /// and invoke ExecutePushQueue. - ray::Status ExecutePushCompleted(const ObjectID &object_id, - SenderConnection::pointer client); + /// Executes on object_manager_service_ thread pool. + ray::Status SendObjectData(std::shared_ptr conn, + const UniqueID &context_id, + std::shared_ptr store_client); - /// Private callback implementation for success on get location. Called inside OD. - void GetLocationsSuccess(const std::vector &vec, - const ObjectID &object_id); - - /// Private callback implementation for failure on get location. Called inside OD. - void GetLocationsFailed(ray::Status status, const ObjectID &object_id); - - /// Asynchronously obtain a connection to client_id. - /// If a connection to client_id already exists, the callback is invoked immediately. - ray::Status GetMsgConnection(const ClientID &client_id, - std::function callback); - /// Asynchronously create a connection to client_id. - ray::Status CreateMsgConnection( - const RemoteConnectionInfo &info, - std::function callback); - /// Asynchronously create a connection to client_id. - ray::Status GetTransferConnection( - const ClientID &client_id, std::function callback); - /// Asynchronously obtain a connection to client_id. - /// If a connection to client_id already exists, the callback is invoked immediately. - ray::Status CreateTransferConnection( - const RemoteConnectionInfo &info, - std::function callback); - - /// A socket connection doing an asynchronous read on a transfer connection that was - /// added by AcceptConnection. 
- ray::Status WaitPushReceive(TCPClientConnection::pointer conn); /// Invoked when a remote object manager pushes an object to this object manager. - void HandlePushReceive(TCPClientConnection::pointer conn, BoostEC length_ec); + /// This will queue the receive. + void ReceivePushRequest(std::shared_ptr conn, + const uint8_t *message); + /// Execute a receive that was in the queue. + ray::Status ExecuteReceiveObject(const ClientID &client_id, const ObjectID &object_id, + uint64_t object_size, + std::shared_ptr conn); - /// A socket connection doing an asynchronous read on a message connection that was - /// added by AcceptConnection. - ray::Status WaitMessage(TCPClientConnection::pointer conn); - /// Handle messages. - void HandleMessage(TCPClientConnection::pointer conn, BoostEC msg_ec); - /// Process the receive pull request message. - void ReceivePullRequest(TCPClientConnection::pointer conn); + /// Handles receiving a pull request message. + void ReceivePullRequest(std::shared_ptr &conn, + const uint8_t *message); - /// Register object add with directory. - void NotifyDirectoryObjectAdd(const ObjectID &object_id); - /// Register object remove with directory. - void NotifyDirectoryObjectDeleted(const ObjectID &object_id); + /// Handles connect message of a new client connection. + void ConnectClient(std::shared_ptr &conn, const uint8_t *message); + /// Handles disconnect message of an existing client connection. 
+ void DisconnectClient(std::shared_ptr &conn, + const uint8_t *message); }; } // namespace ray diff --git a/src/ray/object_manager/object_manager_client_connection.cc b/src/ray/object_manager/object_manager_client_connection.cc index 84bb60714..b904e5d90 100644 --- a/src/ray/object_manager/object_manager_client_connection.cc +++ b/src/ray/object_manager/object_manager_client_connection.cc @@ -1,59 +1,24 @@ -#include "object_manager_client_connection.h" +#include "ray/object_manager/object_manager_client_connection.h" namespace ray { -SenderConnection::pointer SenderConnection::Create(boost::asio::io_service &io_service, - const std::string &ip, uint16_t port) { - return pointer(new SenderConnection(io_service, ip, port)); +uint64_t SenderConnection::id_counter_; + +std::shared_ptr SenderConnection::Create( + boost::asio::io_service &io_service, const ClientID &client_id, const std::string &ip, + uint16_t port) { + boost::asio::ip::tcp::socket socket(io_service); + RAY_CHECK_OK(TcpConnect(socket, ip, port)); + std::shared_ptr conn = + std::make_shared(std::move(socket)); + return std::make_shared(conn, client_id); }; -SenderConnection::SenderConnection(boost::asio::io_service &io_service, - const std::string &ip, uint16_t port) - : socket_(io_service), send_queue_() { - boost::asio::ip::address ip_address = boost::asio::ip::address::from_string(ip); - boost::asio::ip::tcp::endpoint endpoint(ip_address, port); - socket_.connect(endpoint); +SenderConnection::SenderConnection(std::shared_ptr conn, + const ClientID &client_id) + : conn_(conn) { + client_id_ = client_id; + connection_id_ = SenderConnection::id_counter_++; }; -boost::asio::ip::tcp::socket &SenderConnection::GetSocket() { return socket_; }; - -bool SenderConnection::IsObjectIdQueueEmpty() { return send_queue_.empty(); } - -bool SenderConnection::ObjectIdQueued(const ObjectID &object_id) { - return std::find(send_queue_.begin(), send_queue_.end(), object_id) != - send_queue_.end(); -} - -void 
SenderConnection::QueueObjectId(const ObjectID &object_id) { - send_queue_.push_back(ObjectID(object_id)); -} - -ObjectID SenderConnection::DequeueObjectId() { - ObjectID object_id = send_queue_.front(); - send_queue_.pop_front(); - return object_id; -} - -void SenderConnection::AddSendRequest(const ObjectID &object_id, - SendRequest &send_request) { - send_requests_.emplace(object_id, send_request); -} - -void SenderConnection::RemoveSendRequest(const ObjectID &object_id) { - send_requests_.erase(object_id); -} - -SendRequest &SenderConnection::GetSendRequest(const ObjectID &object_id) { - return send_requests_[object_id]; -}; - -TCPClientConnection::TCPClientConnection(boost::asio::io_service &io_service) - : socket_(io_service) {} - -TCPClientConnection::pointer TCPClientConnection::Create( - boost::asio::io_service &io_service) { - return TCPClientConnection::pointer(new TCPClientConnection(io_service)); -} - -boost::asio::ip::tcp::socket &TCPClientConnection::GetSocket() { return socket_; } } // namespace ray diff --git a/src/ray/object_manager/object_manager_client_connection.h b/src/ray/object_manager/object_manager_client_connection.h index 09359580c..ebec1d9d5 100644 --- a/src/ray/object_manager/object_manager_client_connection.h +++ b/src/ray/object_manager/object_manager_client_connection.h @@ -7,63 +7,72 @@ #include #include +#include #include +#include "common/state/ray_config.h" +#include "ray/common/client_connection.h" #include "ray/id.h" namespace ray { -struct SendRequest { - ObjectID object_id; - ClientID client_id; - int64_t object_size; - uint8_t *data; -}; - -// TODO(hme): Document public API after integration with common connection. class SenderConnection : public boost::enable_shared_from_this { public: - typedef boost::shared_ptr pointer; - typedef std::unordered_map SendRequestsType; - typedef std::deque SendQueueType; + /// Create a connection for sending data to other object managers. 
+ /// \param io_service The service to which the created socket should attach. + /// \param client_id The ClientID of the remote node. + /// \param ip The ip address of the remote node server. + /// \param port The port of the remote node server. + /// \return A connection to the remote object manager. + static std::shared_ptr Create(boost::asio::io_service &io_service, + const ClientID &client_id, + const std::string &ip, uint16_t port); - static pointer Create(boost::asio::io_service &io_service, const std::string &ip, - uint16_t port); + /// \param conn A shared pointer to the connection created by the static Create method. + /// \param client_id The ClientID of the remote node. + SenderConnection(std::shared_ptr conn, const ClientID &client_id); - explicit SenderConnection(boost::asio::io_service &io_service, const std::string &ip, - uint16_t port); + /// Write a message to the client. + /// + /// \param type The message type (e.g., a flatbuffer enum). + /// \param length The size in bytes of the message. + /// \param message A pointer to the message buffer. + /// \return Status. + ray::Status WriteMessage(int64_t type, uint64_t length, const uint8_t *message) { + return conn_->WriteMessage(type, length, message); + } - boost::asio::ip::tcp::socket &GetSocket(); + /// Write a buffer to this connection. + /// + /// \param buffer The buffer. + /// \param ec The error code object in which to store error codes. + void WriteBuffer(const std::vector &buffer, + boost::system::error_code &ec) { + return conn_->WriteBuffer(buffer, ec); + } - bool IsObjectIdQueueEmpty(); - bool ObjectIdQueued(const ObjectID &object_id); - void QueueObjectId(const ObjectID &object_id); - ObjectID DequeueObjectId(); + /// Read a buffer from this connection. + /// + /// \param buffer The buffer. + /// \param ec The error code object in which to store error codes.
+ void ReadBuffer(const std::vector &buffer, + boost::system::error_code &ec) { + return conn_->ReadBuffer(buffer, ec); + } - void AddSendRequest(const ObjectID &object_id, SendRequest &send_request); - void RemoveSendRequest(const ObjectID &object_id); - SendRequest &GetSendRequest(const ObjectID &object_id); + /// \return The ClientID of this connection. + const ClientID &GetClientID() { return client_id_; } private: - boost::asio::ip::tcp::socket socket_; - SendQueueType send_queue_; - SendRequestsType send_requests_; -}; + bool operator==(const SenderConnection &rhs) const { + return connection_id_ == rhs.connection_id_; + } -// TODO(hme): Document public API after integration with common connection. -class TCPClientConnection : public boost::enable_shared_from_this { - public: - typedef boost::shared_ptr pointer; - static pointer Create(boost::asio::io_service &io_service); - boost::asio::ip::tcp::socket &GetSocket(); - - TCPClientConnection(boost::asio::io_service &io_service); - - int64_t message_type_; - uint64_t message_length_; - - private: - boost::asio::ip::tcp::socket socket_; + static uint64_t id_counter_; + uint64_t connection_id_; + ClientID client_id_; + std::shared_ptr conn_; }; } // namespace ray diff --git a/src/ray/object_manager/object_manager_protocol.cc b/src/ray/object_manager/object_manager_protocol.cc deleted file mode 100644 index 6d35601df..000000000 --- a/src/ray/object_manager/object_manager_protocol.cc +++ /dev/null @@ -1 +0,0 @@ -// TODO(hme): Move all messaging code here. diff --git a/src/ray/object_manager/object_manager_protocol.h b/src/ray/object_manager/object_manager_protocol.h deleted file mode 100644 index e140543de..000000000 --- a/src/ray/object_manager/object_manager_protocol.h +++ /dev/null @@ -1,6 +0,0 @@ -#ifndef RAY_OBJECT_MANAGER_OBJECT_MANAGER_PROTOCOL_H -#define RAY_OBJECT_MANAGER_OBJECT_MANAGER_PROTOCOL_H - -// TODO(hme): Move all messaging code here. 
- -#endif // RAY_OBJECT_MANAGER_OBJECT_MANAGER_PROTOCOL_H diff --git a/src/ray/object_manager/object_manager_test.cc b/src/ray/object_manager/object_manager_test.cc deleted file mode 100644 index 4cae4e96c..000000000 --- a/src/ray/object_manager/object_manager_test.cc +++ /dev/null @@ -1,137 +0,0 @@ -#include -#include - -#include "gtest/gtest.h" -#include "plasma/client.h" -#include "plasma/events.h" -#include "plasma/plasma.h" -#include "plasma/protocol.h" - -#include "ray/status.h" - -#include "ray/object_manager/object_manager.h" - -namespace ray { - -std::string test_executable; // NOLINT - -class TestObjectManager : public ::testing::Test { - public: - TestObjectManager() { RAY_LOG(DEBUG) << "TestObjectManager: started."; } - - void SetUp() { - // start store - std::string om_dir = test_executable.substr(0, test_executable.find_last_of("/")); - std::string plasma_dir = om_dir + "./../plasma"; - std::string plasma_command = - plasma_dir + - "/plasma_store -m 1000000000 -s /tmp/store 1> /dev/null 2> /dev/null &"; - int s = system(plasma_command.c_str()); - ASSERT_TRUE(!s); - - // Start mock global control store. - mock_gcs_client_ = std::shared_ptr(new GcsClient()); - // mock_gcs_client_->Register(); - - // Start node server. - - // Start object manager 1. - ObjectManagerConfig config; - config.store_socket_name = "/tmp/store"; - object_manager_1_ = std::unique_ptr( - new ObjectManager(io_service_, config, mock_gcs_client_)); - - // Start object manager 2. - // ObjectManagerConfig config2; - // config2.store_socket_name = "/tmp/store"; - // std::shared_ptr od2 = std::shared_ptr(new - // ObjectDirectory()); - // od2->InitGcs(mock_gcs_client_); - // object_manager_2_ = std::unique_ptr(new ObjectManager(io_service, - // config2, od2)); - - // Initiate client connection. 
- ARROW_CHECK_OK(client_.Connect("/tmp/store", "", PLASMA_DEFAULT_RELEASE_DELAY)); - - this->StartLoop(); - } - - void TearDown() { - this->StopLoop(); - arrow::Status arrow_status = client_.Disconnect(); - ASSERT_TRUE(arrow_status.ok()); - ray::Status ray_status = object_manager_1_->Terminate(); - ASSERT_TRUE(ray_status.ok()); - // object_manager_2_->Terminate(); - int s = system("killall plasma_store &"); - ASSERT_TRUE(!s); - } - - void Loop() { io_service_.run(); }; - - void StartLoop() { process_thread_ = std::thread(&TestObjectManager::Loop, this); }; - - void StopLoop() { - io_service_.stop(); - process_thread_.join(); - } - - protected: - std::thread process_thread_; - plasma::PlasmaClient client_; - plasma::PlasmaClient client2_; - boost::asio::io_service io_service_; - - std::shared_ptr mock_gcs_client_; - std::unique_ptr object_manager_1_; - std::unique_ptr object_manager_2_; -}; - -// TODO: get rid of dead code? -// TEST_F(TestObjectManager, TestPush) { -// // test object push between two object managers. 
-// ASSERT_TRUE(true); -// sleep(1); -//} - -// TEST_F(TestObjectManager, TestPull) { -// ObjectID object_id = ObjectID().from_random(); -// ClientID dbc_id = ClientID().from_random(); -// RAY_LOG(INFO) << "ObjectID: " << object_id.hex().c_str(); -// RAY_LOG(INFO) << "ClientID: " << dbc_id.hex().c_str(); -// om->Pull(object_id, dbc_id); -// om->Pull(object_id); -// ASSERT_TRUE(true); -// sleep(1); -//} - -void ObjectAdded(const ObjectID &object_id) { - RAY_LOG(INFO) << "ObjectID Added: " << object_id.hex().c_str(); -} - -TEST_F(TestObjectManager, TestNotifications) { - ray::Status status = object_manager_1_->SubscribeObjAdded(ObjectAdded); - ASSERT_TRUE(status.ok()); - // put object - for (int i = 0; i < 10; ++i) { - ObjectID object_id = ObjectID::from_random(); - RAY_LOG(INFO) << "ObjectID Created: " << object_id.hex().c_str(); - int64_t data_size = 100; - uint8_t metadata[] = {5}; - int64_t metadata_size = sizeof(metadata); - std::shared_ptr data; - ARROW_CHECK_OK(client_.Create(object_id.to_plasma_id(), data_size, metadata, - metadata_size, &data)); - ARROW_CHECK_OK(client_.Seal(object_id.to_plasma_id())); - } - // TODO(hme): Can we do this without sleeping? - sleep(1); -} - -} // namespace ray - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - ray::test_executable = std::string(argv[0]); - return RUN_ALL_TESTS(); -} diff --git a/src/ray/object_manager/object_store_client.cc b/src/ray/object_manager/object_store_client.cc deleted file mode 100644 index 4db9e90d4..000000000 --- a/src/ray/object_manager/object_store_client.cc +++ /dev/null @@ -1,95 +0,0 @@ -#include -#include - -#include -#include -#include - -#include "common.h" -#include "common_protocol.h" -#include "ray/object_manager/object_store_client.h" - -namespace ray { - -// TODO(hme): Dedicate this class to notifications. -// TODO(hme): Create object store client pool for object manager. 
-ObjectStoreClient::ObjectStoreClient(boost::asio::io_service &io_service, - std::string &store_socket_name) - : client_one_(), client_two_(), socket_(io_service) { - ARROW_CHECK_OK( - client_two_.Connect(store_socket_name.c_str(), "", PLASMA_DEFAULT_RELEASE_DELAY)); - ARROW_CHECK_OK( - client_one_.Connect(store_socket_name.c_str(), "", PLASMA_DEFAULT_RELEASE_DELAY)); - - // Connect to two clients, but subscribe to only one. - ARROW_CHECK_OK(client_one_.Subscribe(&c_socket_)); - boost::system::error_code ec; - socket_.assign(boost::asio::local::stream_protocol(), c_socket_, ec); - assert(!ec.value()); - NotificationWait(); -}; - -void ObjectStoreClient::Terminate() { - ARROW_CHECK_OK(client_two_.Disconnect()); - ARROW_CHECK_OK(client_one_.Disconnect()); -} - -void ObjectStoreClient::NotificationWait() { - boost::asio::async_read(socket_, boost::asio::buffer(&length_, sizeof(length_)), - boost::bind(&ObjectStoreClient::ProcessStoreLength, this, - boost::asio::placeholders::error)); -} - -void ObjectStoreClient::ProcessStoreLength(const boost::system::error_code &error) { - notification_.resize(length_); - boost::asio::async_read(socket_, boost::asio::buffer(notification_), - boost::bind(&ObjectStoreClient::ProcessStoreNotification, this, - boost::asio::placeholders::error)); -} - -void ObjectStoreClient::ProcessStoreNotification(const boost::system::error_code &error) { - if (error) { - throw std::runtime_error("ObjectStore may have died."); - } - - auto object_info = flatbuffers::GetRoot(notification_.data()); - ObjectID object_id = from_flatbuf(*object_info->object_id()); - if (object_info->is_deletion()) { - ProcessStoreRemove(object_id); - } else { - ProcessStoreAdd(object_id); - // why all these params? 
- // ProcessStoreAdd( - // object_id, object_info->data_size(), - // object_info->metadata_size(), - // (unsigned char *) object_info->digest()->data()); - } - NotificationWait(); -} - -void ObjectStoreClient::ProcessStoreAdd(const ObjectID &object_id) { - for (auto handler : add_handlers_) { - handler(object_id); - } -}; - -void ObjectStoreClient::ProcessStoreRemove(const ObjectID &object_id) { - for (auto handler : rem_handlers_) { - handler(object_id); - } -}; - -void ObjectStoreClient::SubscribeObjAdded( - std::function callback) { - add_handlers_.push_back(callback); -}; - -void ObjectStoreClient::SubscribeObjDeleted( - std::function callback) { - rem_handlers_.push_back(callback); -}; - -plasma::PlasmaClient &ObjectStoreClient::GetClient() { return client_one_; }; - -plasma::PlasmaClient &ObjectStoreClient::GetClientOther() { return client_two_; }; -} // namespace ray diff --git a/src/ray/object_manager/object_store_client_pool.cc b/src/ray/object_manager/object_store_client_pool.cc new file mode 100644 index 000000000..6742d0d5b --- /dev/null +++ b/src/ray/object_manager/object_store_client_pool.cc @@ -0,0 +1,38 @@ +#include "object_store_client_pool.h" + +namespace ray { + +ObjectStoreClientPool::ObjectStoreClientPool(const std::string &store_socket_name) + : store_socket_name_(store_socket_name) {} + +std::shared_ptr ObjectStoreClientPool::GetObjectStore() { + std::lock_guard lock(pool_mutex); + if (available_clients.empty()) { + Add(); + } + std::shared_ptr client = available_clients.back(); + available_clients.pop_back(); + return client; +} + +void ObjectStoreClientPool::ReleaseObjectStore( + std::shared_ptr client) { + std::lock_guard lock(pool_mutex); + available_clients.push_back(client); +} + +void ObjectStoreClientPool::Terminate() { + for (const auto &client : clients) { + ARROW_CHECK_OK(client->Disconnect()); + } + available_clients.clear(); + clients.clear(); +} + +void ObjectStoreClientPool::Add() { + clients.emplace_back(new 
plasma::PlasmaClient()); + ARROW_CHECK_OK(clients.back()->Connect(store_socket_name_.c_str(), "", + PLASMA_DEFAULT_RELEASE_DELAY)); + available_clients.push_back(clients.back()); +} +} // namespace ray diff --git a/src/ray/object_manager/object_store_client_pool.h b/src/ray/object_manager/object_store_client_pool.h new file mode 100644 index 000000000..a74588536 --- /dev/null +++ b/src/ray/object_manager/object_store_client_pool.h @@ -0,0 +1,65 @@ +#ifndef RAY_OBJECT_MANAGER_OBJECT_STORE_CLIENT_POOL_H +#define RAY_OBJECT_MANAGER_OBJECT_STORE_CLIENT_POOL_H + +#include +#include +#include +#include + +#include +#include +#include + +#include "plasma/client.h" +#include "plasma/events.h" +#include "plasma/plasma.h" + +#include "object_directory.h" +#include "ray/id.h" +#include "ray/status.h" + +namespace ray { + +/// \class ObjectStoreClientPool +/// +/// Provides connections to the object store. Enables concurrent communication with +/// the object store. +class ObjectStoreClientPool { + public: + /// Constructor. + /// + /// \param store_socket_name The object store socket name. + ObjectStoreClientPool(const std::string &store_socket_name); + + /// This object cannot be copied due to pool_mutex. + RAY_DISALLOW_COPY_AND_ASSIGN(ObjectStoreClientPool); + + /// Provides a connection to the object store from the object store pool. + /// This removes the object store client from the pool of available clients. + /// + /// \return A connection to the object store. + std::shared_ptr GetObjectStore(); + + /// Releases a client object and puts it back into the object store pool + /// for reuse. + /// Once a client is released, it is assumed that it is not being used. + /// + /// \param client The client to return. + void ReleaseObjectStore(std::shared_ptr client); + + /// Terminates this object. + void Terminate(); + + private: + /// Adds a client to the client pool and marks it as available.
+ void Add(); + + std::mutex pool_mutex; + std::vector> available_clients; + std::vector> clients; + std::string store_socket_name_; +}; + +} // namespace ray + +#endif // RAY_OBJECT_MANAGER_OBJECT_STORE_CLIENT_POOL_H diff --git a/src/ray/object_manager/object_store_notification_manager.cc b/src/ray/object_manager/object_store_notification_manager.cc new file mode 100644 index 000000000..87d48b9ef --- /dev/null +++ b/src/ray/object_manager/object_store_notification_manager.cc @@ -0,0 +1,90 @@ +#include +#include + +#include +#include +#include + +#include "common/common.h" +#include "common/common_protocol.h" + +#include "ray/object_manager/object_store_notification_manager.h" + +namespace ray { + +ObjectStoreNotificationManager::ObjectStoreNotificationManager( + boost::asio::io_service &io_service, const std::string &store_socket_name) + : store_client_(), socket_(io_service) { + ARROW_CHECK_OK( + store_client_.Connect(store_socket_name.c_str(), "", PLASMA_DEFAULT_RELEASE_DELAY)); + + ARROW_CHECK_OK(store_client_.Subscribe(&c_socket_)); + boost::system::error_code ec; + socket_.assign(boost::asio::local::stream_protocol(), c_socket_, ec); + assert(!ec.value()); + NotificationWait(); +} + +void ObjectStoreNotificationManager::Terminate() { + ARROW_CHECK_OK(store_client_.Disconnect()); +} + +void ObjectStoreNotificationManager::NotificationWait() { + boost::asio::async_read(socket_, boost::asio::buffer(&length_, sizeof(length_)), + boost::bind(&ObjectStoreNotificationManager::ProcessStoreLength, + this, boost::asio::placeholders::error)); +} + +void ObjectStoreNotificationManager::ProcessStoreLength( + const boost::system::error_code &error) { + notification_.resize(length_); + boost::asio::async_read( + socket_, boost::asio::buffer(notification_), + boost::bind(&ObjectStoreNotificationManager::ProcessStoreNotification, this, + boost::asio::placeholders::error)); +} + +void ObjectStoreNotificationManager::ProcessStoreNotification( + const boost::system::error_code 
&error) { + if (error) { + throw std::runtime_error("ObjectStore may have died."); + } + + const auto &object_info = flatbuffers::GetRoot(notification_.data()); + const auto &object_id = from_flatbuf(*object_info->object_id()); + if (object_info->is_deletion()) { + ProcessStoreRemove(object_id); + } else { + ProcessStoreAdd(object_id); + // TODO(hme): Determine what data is actually needed by consumer of this notification. + // ProcessStoreAdd( + // object_id, object_info->data_size(), + // object_info->metadata_size(), + // (unsigned char *) object_info->digest()->data()); + } + NotificationWait(); +} + +void ObjectStoreNotificationManager::ProcessStoreAdd(const ObjectID &object_id) { + for (auto handler : add_handlers_) { + handler(object_id); + } +} + +void ObjectStoreNotificationManager::ProcessStoreRemove(const ObjectID &object_id) { + for (auto handler : rem_handlers_) { + handler(object_id); + } +} + +void ObjectStoreNotificationManager::SubscribeObjAdded( + std::function callback) { + add_handlers_.push_back(callback); +} + +void ObjectStoreNotificationManager::SubscribeObjDeleted( + std::function callback) { + rem_handlers_.push_back(callback); +} + +} // namespace ray diff --git a/src/ray/object_manager/object_store_client.h b/src/ray/object_manager/object_store_notification_manager.h similarity index 54% rename from src/ray/object_manager/object_store_client.h rename to src/ray/object_manager/object_store_notification_manager.h index 9ab235054..bc90d67c6 100644 --- a/src/ray/object_manager/object_store_client.h +++ b/src/ray/object_manager/object_store_notification_manager.h @@ -13,53 +13,58 @@ #include "plasma/events.h" #include "plasma/plasma.h" -#include "object_directory.h" #include "ray/id.h" #include "ray/status.h" +#include "ray/object_manager/object_directory.h" + namespace ray { -// TODO(hme): document public API after refactor. 
-class ObjectStoreClient { +/// \class ObjectStoreNotificationManager +/// +/// Encapsulates notification handling from the object store. +class ObjectStoreNotificationManager { public: - // Encapsulates communication with the object store. - ObjectStoreClient(boost::asio::io_service &io_service, std::string &store_socket_name); + /// Constructor. + /// + /// \param io_service The asio service to be used. + /// \param store_socket_name The store socket to connect to. + ObjectStoreNotificationManager(boost::asio::io_service &io_service, + const std::string &store_socket_name); - // Subscribe to notifications of objects added to local store. - // Upon subscribing, the callback will be invoked for all objects that - // already exist in the local store. + /// Subscribe to notifications of objects added to local store. + /// Upon subscribing, the callback will be invoked for all objects that + /// already exist in the local store. + /// + /// \param callback A callback expecting an ObjectID. void SubscribeObjAdded(std::function callback); - // Subscribe to notifications of objects deleted from local store. + /// Subscribe to notifications of objects deleted from local store. + /// + /// \param callback A callback expecting an ObjectID. void SubscribeObjDeleted(std::function callback); - // TODO(hme): There should be as many client connections as there are threads. - // Two client connections are made to enable concurrent communication with the store. - plasma::PlasmaClient &GetClient(); - plasma::PlasmaClient &GetClientOther(); - - // Terminate this object. + /// Terminate this object. void Terminate(); private: - std::vector> add_handlers_; - std::vector> rem_handlers_; - - plasma::PlasmaClient client_one_; - plasma::PlasmaClient client_two_; - int c_socket_; - int64_t length_; - std::vector notification_; - boost::asio::local::stream_protocol::socket socket_; - - // Async loop for handling object store notifications. + /// Async loop for handling object store notifications.
void NotificationWait(); void ProcessStoreLength(const boost::system::error_code &error); void ProcessStoreNotification(const boost::system::error_code &error); - // Support for rebroadcasting object add/rem events. + /// Support for rebroadcasting object add/rem events. void ProcessStoreAdd(const ObjectID &object_id); void ProcessStoreRemove(const ObjectID &object_id); + + std::vector> add_handlers_; + std::vector> rem_handlers_; + + plasma::PlasmaClient store_client_; + int c_socket_; + int64_t length_; + std::vector notification_; + boost::asio::local::stream_protocol::socket socket_; }; } // namespace ray diff --git a/src/ray/object_manager/test/object_manager_stress_test.cc b/src/ray/object_manager/test/object_manager_stress_test.cc new file mode 100644 index 000000000..28b58e032 --- /dev/null +++ b/src/ray/object_manager/test/object_manager_stress_test.cc @@ -0,0 +1,440 @@ +#include +#include +#include +#include + +#include "gtest/gtest.h" + +#include "ray/object_manager/object_manager.h" + +namespace ray { + +std::string store_executable; + +static inline void flushall_redis(void) { + redisContext *context = redisConnect("127.0.0.1", 6379); + freeReplyObject(redisCommand(context, "FLUSHALL")); + redisFree(context); +} + +int64_t current_time_ms() { + std::chrono::milliseconds ms_since_epoch = + std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()); + return ms_since_epoch.count(); +} + +class MockServer { + public: + MockServer(boost::asio::io_service &main_service, + std::unique_ptr object_manager_service, + const ObjectManagerConfig &object_manager_config, + std::shared_ptr gcs_client) + : object_manager_acceptor_( + main_service, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), 0)), + object_manager_socket_(main_service), + gcs_client_(gcs_client), + object_manager_(main_service, std::move(object_manager_service), + object_manager_config, gcs_client) { + RAY_CHECK_OK(RegisterGcs(main_service)); + // Start listening 
for clients. + DoAcceptObjectManager(); + } + + ~MockServer() { + RAY_CHECK_OK(gcs_client_->client_table().Disconnect()); + RAY_CHECK_OK(object_manager_.Terminate()); + } + + private: + ray::Status RegisterGcs(boost::asio::io_service &io_service) { + RAY_RETURN_NOT_OK(gcs_client_->Connect("127.0.0.1", 6379)); + RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service)); + + boost::asio::ip::tcp::endpoint endpoint = object_manager_acceptor_.local_endpoint(); + std::string ip = endpoint.address().to_string(); + unsigned short object_manager_port = endpoint.port(); + + ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); + client_info.node_manager_address = ip; + client_info.node_manager_port = object_manager_port; + client_info.object_manager_port = object_manager_port; + return gcs_client_->client_table().Connect(client_info); + } + + void DoAcceptObjectManager() { + object_manager_acceptor_.async_accept( + object_manager_socket_, boost::bind(&MockServer::HandleAcceptObjectManager, this, + boost::asio::placeholders::error)); + } + + void HandleAcceptObjectManager(const boost::system::error_code &error) { + ClientHandler client_handler = + [this](std::shared_ptr client) { + object_manager_.ProcessNewClient(client); + }; + MessageHandler message_handler = [this]( + std::shared_ptr client, int64_t message_type, + const uint8_t *message) { + object_manager_.ProcessClientMessage(client, message_type, message); + }; + // Accept a new local client and dispatch it to the node manager. 
+ auto new_connection = TcpClientConnection::Create(client_handler, message_handler, + std::move(object_manager_socket_)); + DoAcceptObjectManager(); + } + + friend class StressTestObjectManager; + + boost::asio::ip::tcp::acceptor object_manager_acceptor_; + boost::asio::ip::tcp::socket object_manager_socket_; + std::shared_ptr gcs_client_; + ObjectManager object_manager_; +}; + +class TestObjectManagerBase : public ::testing::Test { + public: + TestObjectManagerBase() {} + + std::string StartStore(const std::string &id) { + std::string store_id = "/tmp/store"; + store_id = store_id + id; + std::string plasma_command = store_executable + " -m 1000000000 -s " + store_id + + " 1> /dev/null 2> /dev/null &"; + RAY_LOG(DEBUG) << plasma_command; + int ec = system(plasma_command.c_str()); + if (ec != 0) { + throw std::runtime_error("failed to start plasma store."); + }; + return store_id; + } + + void SetUp() { + flushall_redis(); + + object_manager_service_1.reset(new boost::asio::io_service()); + object_manager_service_2.reset(new boost::asio::io_service()); + + // start store + std::string store_sock_1 = StartStore("1"); + std::string store_sock_2 = StartStore("2"); + + // start first server + gcs_client_1 = std::shared_ptr(new gcs::AsyncGcsClient()); + ObjectManagerConfig om_config_1; + om_config_1.store_socket_name = store_sock_1; + om_config_1.num_threads = 4; + om_config_1.max_sends = 20; + om_config_1.max_receives = 20; + server1.reset(new MockServer(main_service, std::move(object_manager_service_1), + om_config_1, gcs_client_1)); + + // start second server + gcs_client_2 = std::shared_ptr(new gcs::AsyncGcsClient()); + ObjectManagerConfig om_config_2; + om_config_2.store_socket_name = store_sock_2; + om_config_2.num_threads = 4; + om_config_2.max_sends = 20; + om_config_2.max_receives = 20; + server2.reset(new MockServer(main_service, std::move(object_manager_service_2), + om_config_2, gcs_client_2)); + + // connect to stores. 
+ ARROW_CHECK_OK(client1.Connect(store_sock_1, "", PLASMA_DEFAULT_RELEASE_DELAY)); + ARROW_CHECK_OK(client2.Connect(store_sock_2, "", PLASMA_DEFAULT_RELEASE_DELAY)); + } + + void TearDown() { + arrow::Status client1_status = client1.Disconnect(); + arrow::Status client2_status = client2.Disconnect(); + ASSERT_TRUE(client1_status.ok() && client2_status.ok()); + + this->server1.reset(); + this->server2.reset(); + + int s = system("killall plasma_store &"); + ASSERT_TRUE(!s); + } + + ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size) { + ObjectID object_id = ObjectID::from_random(); + RAY_LOG(DEBUG) << "ObjectID Created: " << object_id; + uint8_t metadata[] = {5}; + int64_t metadata_size = sizeof(metadata); + std::shared_ptr data; + ARROW_CHECK_OK(client.Create(object_id.to_plasma_id(), data_size, metadata, + metadata_size, &data)); + ARROW_CHECK_OK(client.Seal(object_id.to_plasma_id())); + return object_id; + } + + void object_added_handler_1(ObjectID object_id) { v1.push_back(object_id); }; + + void object_added_handler_2(ObjectID object_id) { v2.push_back(object_id); }; + + protected: + std::thread p; + boost::asio::io_service main_service; + std::unique_ptr object_manager_service_1; + std::unique_ptr object_manager_service_2; + std::shared_ptr gcs_client_1; + std::shared_ptr gcs_client_2; + std::unique_ptr server1; + std::unique_ptr server2; + + plasma::PlasmaClient client1; + plasma::PlasmaClient client2; + std::vector v1; + std::vector v2; +}; + +class StressTestObjectManager : public TestObjectManagerBase { + public: + enum TransferPattern { + PUSH_A_B, + PUSH_B_A, + BIDIRECTIONAL_PUSH, + PULL_A_B, + PULL_B_A, + BIDIRECTIONAL_PULL, + BIDIRECTIONAL_PULL_VARIABLE_DATA_SIZE, + }; + + int async_loop_index = -1; + uint num_expected_objects; + + std::vector async_loop_patterns = { + PUSH_A_B, + PUSH_B_A, + BIDIRECTIONAL_PUSH, + PULL_A_B, + PULL_B_A, + BIDIRECTIONAL_PULL, + BIDIRECTIONAL_PULL_VARIABLE_DATA_SIZE}; + + int num_connected_clients 
= 0; + + ClientID client_id_1; + ClientID client_id_2; + + int64_t start_time; + + void WaitConnections() { + client_id_1 = gcs_client_1->client_table().GetLocalClientId(); + client_id_2 = gcs_client_2->client_table().GetLocalClientId(); + gcs_client_1->client_table().RegisterClientAddedCallback([this]( + gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + ClientID parsed_id = ClientID::from_binary(data.client_id); + if (parsed_id == client_id_1 || parsed_id == client_id_2) { + num_connected_clients += 1; + } + if (num_connected_clients == 2) { + StartTests(); + } + }); + } + + void StartTests() { + TestConnections(); + AddTransferTestHandlers(); + TransferTestNext(); + } + + void AddTransferTestHandlers() { + ray::Status status = ray::Status::OK(); + status = + server1->object_manager_.SubscribeObjAdded([this](const ObjectID &object_id) { + object_added_handler_1(object_id); + if (v1.size() == num_expected_objects && v1.size() == v2.size()) { + TransferTestComplete(); + } + }); + RAY_CHECK_OK(status); + status = + server2->object_manager_.SubscribeObjAdded([this](const ObjectID &object_id) { + object_added_handler_2(object_id); + if (v2.size() == num_expected_objects && v1.size() == v2.size()) { + TransferTestComplete(); + } + }); + RAY_CHECK_OK(status); + } + + void TransferTestNext() { + async_loop_index += 1; + if ((uint)async_loop_index < async_loop_patterns.size()) { + TransferPattern pattern = async_loop_patterns[async_loop_index]; + TransferTestExecute(1000, 100, pattern); + } else { + main_service.stop(); + } + } + + plasma::ObjectBuffer GetObject(plasma::PlasmaClient &client, ObjectID &object_id) { + plasma::ObjectBuffer object_buffer; + plasma::ObjectID plasma_id = object_id.to_plasma_id(); + ARROW_CHECK_OK(client.Get(&plasma_id, 1, 0, &object_buffer)); + return object_buffer; + } + + static unsigned char *GetDigest(plasma::PlasmaClient &client, ObjectID &object_id) { + const int64_t size = sizeof(uint64_t); + static 
unsigned char digest_1[size]; + ARROW_CHECK_OK(client.Hash(object_id.to_plasma_id(), &digest_1[0])); + return digest_1; + } + + /// Assert that the object stored under object_id_1 in store 1 and the object stored + /// under object_id_2 in store 2 have the same size and the same bytes. + void CompareObjects(ObjectID &object_id_1, ObjectID &object_id_2) { + plasma::ObjectBuffer object_buffer_1 = GetObject(client1, object_id_1); + // Fetch the second object from the second store; reading object_id_1 from + // client1 twice would only compare the object with itself. + plasma::ObjectBuffer object_buffer_2 = GetObject(client2, object_id_2); + uint8_t *data_1 = const_cast<uint8_t *>(object_buffer_1.data->data()); + uint8_t *data_2 = const_cast<uint8_t *>(object_buffer_2.data->data()); + ASSERT_EQ(object_buffer_1.data_size, object_buffer_2.data_size); + for (int i = -1; ++i < object_buffer_1.data_size;) { + ASSERT_TRUE(data_1[i] == data_2[i]); + } + } + + /// Assert that the digest of object_id_1 in store 1 equals the digest of + /// object_id_2 in store 2. + void CompareHashes(ObjectID &object_id_1, ObjectID &object_id_2) { + const int64_t size = sizeof(uint64_t); + // Hash into two separate local buffers on every call. GetDigest hands back one + // shared static buffer, and static locals are initialized only on the first + // call, so the digests must not be cached in statics or read through GetDigest. + unsigned char digest_1[size]; + unsigned char digest_2[size]; + ARROW_CHECK_OK(client1.Hash(object_id_1.to_plasma_id(), &digest_1[0])); + ARROW_CHECK_OK(client2.Hash(object_id_2.to_plasma_id(), &digest_2[0])); + for (int i = -1; ++i < size;) { + ASSERT_TRUE(digest_1[i] == digest_2[i]); + } + } + + void TransferTestComplete() { + int64_t elapsed = current_time_ms() - start_time; + RAY_LOG(INFO) << "TransferTestComplete: " << async_loop_patterns[async_loop_index] + << " " << v1.size() << " " << elapsed; + ASSERT_TRUE(v1.size() == v2.size()); + for (uint i = 0; i < v1.size(); ++i) { + ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end()); + } + + // Compare objects and their hashes.
+ for (uint i = 0; i < v1.size(); ++i) { + ObjectID object_id_2 = v2[i]; + ObjectID object_id_1 = + v1[std::distance(v1.begin(), std::find(v1.begin(), v1.end(), v2[i]))]; + CompareHashes(object_id_1, object_id_2); + CompareObjects(object_id_1, object_id_2); + } + + v1.clear(); + v2.clear(); + TransferTestNext(); + } + + void TransferTestExecute(int num_trials, int64_t data_size, + TransferPattern transfer_pattern) { + ClientID client_id_1 = gcs_client_1->client_table().GetLocalClientId(); + ClientID client_id_2 = gcs_client_2->client_table().GetLocalClientId(); + + ray::Status status = ray::Status::OK(); + + if (transfer_pattern == BIDIRECTIONAL_PULL || + transfer_pattern == BIDIRECTIONAL_PUSH || + transfer_pattern == BIDIRECTIONAL_PULL_VARIABLE_DATA_SIZE) { + num_expected_objects = (uint)2 * num_trials; + } else { + num_expected_objects = (uint)num_trials; + } + + start_time = current_time_ms(); + + switch (transfer_pattern) { + case PUSH_A_B: { + for (int i = -1; ++i < num_trials;) { + ObjectID oid1 = WriteDataToClient(client1, data_size); + status = server1->object_manager_.Push(oid1, client_id_2); + } + } break; + case PUSH_B_A: { + for (int i = -1; ++i < num_trials;) { + ObjectID oid2 = WriteDataToClient(client2, data_size); + status = server2->object_manager_.Push(oid2, client_id_1); + } + } break; + case BIDIRECTIONAL_PUSH: { + for (int i = -1; ++i < num_trials;) { + ObjectID oid1 = WriteDataToClient(client1, data_size); + status = server1->object_manager_.Push(oid1, client_id_2); + ObjectID oid2 = WriteDataToClient(client2, data_size); + status = server2->object_manager_.Push(oid2, client_id_1); + } + } break; + case PULL_A_B: { + for (int i = -1; ++i < num_trials;) { + ObjectID oid1 = WriteDataToClient(client1, data_size); + status = server2->object_manager_.Pull(oid1); + } + } break; + case PULL_B_A: { + for (int i = -1; ++i < num_trials;) { + ObjectID oid2 = WriteDataToClient(client2, data_size); + status = server1->object_manager_.Pull(oid2); + } + } 
break; + case BIDIRECTIONAL_PULL: { + for (int i = -1; ++i < num_trials;) { + ObjectID oid1 = WriteDataToClient(client1, data_size); + status = server2->object_manager_.Pull(oid1); + ObjectID oid2 = WriteDataToClient(client2, data_size); + status = server1->object_manager_.Pull(oid2); + } + } break; + case BIDIRECTIONAL_PULL_VARIABLE_DATA_SIZE: { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(1, 50); + for (int i = -1; ++i < num_trials;) { + ObjectID oid1 = WriteDataToClient(client1, data_size + dis(gen)); + status = server2->object_manager_.Pull(oid1); + ObjectID oid2 = WriteDataToClient(client2, data_size + dis(gen)); + status = server1->object_manager_.Pull(oid2); + } + } break; + default: { + RAY_LOG(FATAL) << "No case for transfer_pattern " << transfer_pattern; + } break; + } + } + + void TestConnections() { + RAY_LOG(DEBUG) << "\n" + << "Server client ids:" + << "\n"; + ClientID client_id_1 = gcs_client_1->client_table().GetLocalClientId(); + ClientID client_id_2 = gcs_client_2->client_table().GetLocalClientId(); + RAY_LOG(DEBUG) << "Server 1: " << client_id_1 << "\n" + << "Server 2: " << client_id_2; + + RAY_LOG(DEBUG) << "\n" + << "All connected clients:" + << "\n"; + const ClientTableDataT &data = gcs_client_1->client_table().GetClient(client_id_1); + RAY_LOG(DEBUG) << "ClientID=" << ClientID::from_binary(data.client_id) << "\n" + << "ClientIp=" << data.node_manager_address << "\n" + << "ClientPort=" << data.node_manager_port; + const ClientTableDataT &data2 = gcs_client_1->client_table().GetClient(client_id_2); + RAY_LOG(DEBUG) << "ClientID=" << ClientID::from_binary(data2.client_id) << "\n" + << "ClientIp=" << data2.node_manager_address << "\n" + << "ClientPort=" << data2.node_manager_port; + } +}; + +TEST_F(StressTestObjectManager, StartStressTestObjectManager) { + auto AsyncStartTests = main_service.wrap([this]() { WaitConnections(); }); + AsyncStartTests(); + main_service.run(); +} + +} // namespace ray + +int 
main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + ray::store_executable = std::string(argv[1]); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc new file mode 100644 index 000000000..8415191a7 --- /dev/null +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -0,0 +1,256 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "ray/object_manager/object_manager.h" + +namespace ray { + +static inline void flushall_redis(void) { + redisContext *context = redisConnect("127.0.0.1", 6379); + freeReplyObject(redisCommand(context, "FLUSHALL")); + redisFree(context); +} + +std::string store_executable; + +class MockServer { + public: + MockServer(boost::asio::io_service &main_service, + std::unique_ptr object_manager_service, + const ObjectManagerConfig &object_manager_config, + std::shared_ptr gcs_client) + : object_manager_acceptor_( + main_service, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), 0)), + object_manager_socket_(main_service), + gcs_client_(gcs_client), + object_manager_(main_service, std::move(object_manager_service), + object_manager_config, gcs_client) { + RAY_CHECK_OK(RegisterGcs(main_service)); + // Start listening for clients. 
+ DoAcceptObjectManager(); + } + + ~MockServer() { + RAY_CHECK_OK(gcs_client_->client_table().Disconnect()); + RAY_CHECK_OK(object_manager_.Terminate()); + } + + private: + ray::Status RegisterGcs(boost::asio::io_service &io_service) { + RAY_RETURN_NOT_OK(gcs_client_->Connect("127.0.0.1", 6379)); + RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service)); + + boost::asio::ip::tcp::endpoint endpoint = object_manager_acceptor_.local_endpoint(); + std::string ip = endpoint.address().to_string(); + unsigned short object_manager_port = endpoint.port(); + + ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); + client_info.node_manager_address = ip; + client_info.node_manager_port = object_manager_port; + client_info.object_manager_port = object_manager_port; + return gcs_client_->client_table().Connect(client_info); + } + + void DoAcceptObjectManager() { + object_manager_acceptor_.async_accept( + object_manager_socket_, boost::bind(&MockServer::HandleAcceptObjectManager, this, + boost::asio::placeholders::error)); + } + + void HandleAcceptObjectManager(const boost::system::error_code &error) { + ClientHandler client_handler = + [this](std::shared_ptr client) { + object_manager_.ProcessNewClient(client); + }; + MessageHandler message_handler = [this]( + std::shared_ptr client, int64_t message_type, + const uint8_t *message) { + object_manager_.ProcessClientMessage(client, message_type, message); + }; + // Accept a new local client and dispatch it to the node manager. 
+ auto new_connection = TcpClientConnection::Create(client_handler, message_handler, + std::move(object_manager_socket_)); + DoAcceptObjectManager(); + } + + friend class TestObjectManagerCommands; + + boost::asio::ip::tcp::acceptor object_manager_acceptor_; + boost::asio::ip::tcp::socket object_manager_socket_; + std::shared_ptr gcs_client_; + ObjectManager object_manager_; +}; + +class TestObjectManager : public ::testing::Test { + public: + TestObjectManager() {} + + std::string StartStore(const std::string &id) { + std::string store_id = "/tmp/store"; + store_id = store_id + id; + std::string plasma_command = store_executable + " -m 1000000000 -s " + store_id + + " 1> /dev/null 2> /dev/null &"; + RAY_LOG(DEBUG) << plasma_command; + int ec = system(plasma_command.c_str()); + if (ec != 0) { + throw std::runtime_error("failed to start plasma store."); + }; + return store_id; + } + + void SetUp() { + flushall_redis(); + + object_manager_service_1.reset(new boost::asio::io_service()); + object_manager_service_2.reset(new boost::asio::io_service()); + + // start store + std::string store_sock_1 = StartStore("1"); + std::string store_sock_2 = StartStore("2"); + + // start first server + gcs_client_1 = std::shared_ptr(new gcs::AsyncGcsClient()); + ObjectManagerConfig om_config_1; + om_config_1.store_socket_name = store_sock_1; + server1.reset(new MockServer(main_service, std::move(object_manager_service_1), + om_config_1, gcs_client_1)); + + // start second server + gcs_client_2 = std::shared_ptr(new gcs::AsyncGcsClient()); + ObjectManagerConfig om_config_2; + om_config_2.store_socket_name = store_sock_2; + server2.reset(new MockServer(main_service, std::move(object_manager_service_2), + om_config_2, gcs_client_2)); + + // connect to stores. 
+ ARROW_CHECK_OK(client1.Connect(store_sock_1, "", PLASMA_DEFAULT_RELEASE_DELAY)); + ARROW_CHECK_OK(client2.Connect(store_sock_2, "", PLASMA_DEFAULT_RELEASE_DELAY)); + } + + void TearDown() { + arrow::Status client1_status = client1.Disconnect(); + arrow::Status client2_status = client2.Disconnect(); + ASSERT_TRUE(client1_status.ok() && client2_status.ok()); + + this->server1.reset(); + this->server2.reset(); + + int s = system("killall plasma_store &"); + ASSERT_TRUE(!s); + } + + ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size) { + ObjectID object_id = ObjectID::from_random(); + RAY_LOG(DEBUG) << "ObjectID Created: " << object_id; + uint8_t metadata[] = {5}; + int64_t metadata_size = sizeof(metadata); + std::shared_ptr data; + ARROW_CHECK_OK(client.Create(object_id.to_plasma_id(), data_size, metadata, + metadata_size, &data)); + ARROW_CHECK_OK(client.Seal(object_id.to_plasma_id())); + return object_id; + } + + void object_added_handler_1(ObjectID object_id) { v1.push_back(object_id); }; + + void object_added_handler_2(ObjectID object_id) { v2.push_back(object_id); }; + + protected: + std::thread p; + boost::asio::io_service main_service; + std::unique_ptr object_manager_service_1; + std::unique_ptr object_manager_service_2; + std::shared_ptr gcs_client_1; + std::shared_ptr gcs_client_2; + std::unique_ptr server1; + std::unique_ptr server2; + + plasma::PlasmaClient client1; + plasma::PlasmaClient client2; + std::vector v1; + std::vector v2; +}; + +class TestObjectManagerCommands : public TestObjectManager { + public: + int num_connected_clients = 0; + uint num_expected_objects; + ClientID client_id_1; + ClientID client_id_2; + + ObjectID created_object_id; + + void WaitConnections() { + client_id_1 = gcs_client_1->client_table().GetLocalClientId(); + client_id_2 = gcs_client_2->client_table().GetLocalClientId(); + gcs_client_1->client_table().RegisterClientAddedCallback([this]( + gcs::AsyncGcsClient *client, const ClientID &id, const 
ClientTableDataT &data) { + ClientID parsed_id = ClientID::from_binary(data.client_id); + if (parsed_id == client_id_1 || parsed_id == client_id_2) { + num_connected_clients += 1; + } + if (num_connected_clients == 2) { + StartTests(); + } + }); + } + + void StartTests() { + TestConnections(); + TestNotifications(); + } + + void TestNotifications() { + ray::Status status = ray::Status::OK(); + status = + server1->object_manager_.SubscribeObjAdded([this](const ObjectID &object_id) { + object_added_handler_1(object_id); + if (v1.size() == num_expected_objects) { + NotificationTestComplete(created_object_id, object_id); + } + }); + RAY_CHECK_OK(status); + + num_expected_objects = 1; + uint data_size = 1000000; + created_object_id = WriteDataToClient(client1, data_size); + } + + void NotificationTestComplete(ObjectID object_id_1, ObjectID object_id_2) { + ASSERT_EQ(object_id_1, object_id_2); + main_service.stop(); + } + + void TestConnections() { + RAY_LOG(DEBUG) << "\n" + << "Server client ids:" + << "\n"; + const ClientTableDataT &data = gcs_client_1->client_table().GetClient(client_id_1); + RAY_LOG(DEBUG) << (ClientID::from_binary(data.client_id) == ClientID::nil()); + RAY_LOG(DEBUG) << "Server 1 ClientID=" << ClientID::from_binary(data.client_id); + RAY_LOG(DEBUG) << "Server 1 ClientIp=" << data.node_manager_address; + RAY_LOG(DEBUG) << "Server 1 ClientPort=" << data.node_manager_port; + ASSERT_EQ(client_id_1, ClientID::from_binary(data.client_id)); + const ClientTableDataT &data2 = gcs_client_1->client_table().GetClient(client_id_2); + RAY_LOG(DEBUG) << "Server 2 ClientID=" << ClientID::from_binary(data2.client_id); + RAY_LOG(DEBUG) << "Server 2 ClientIp=" << data2.node_manager_address; + RAY_LOG(DEBUG) << "Server 2 ClientPort=" << data2.node_manager_port; + ASSERT_EQ(client_id_2, ClientID::from_binary(data2.client_id)); + } +}; + +TEST_F(TestObjectManagerCommands, StartTestObjectManagerCommands) { + auto AsyncStartTests = main_service.wrap([this]() { 
WaitConnections(); }); + AsyncStartTests(); + main_service.run(); +} + +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + ray::store_executable = std::string(argv[1]); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/object_manager/transfer_queue.cc b/src/ray/object_manager/transfer_queue.cc new file mode 100644 index 000000000..1d0dc0e33 --- /dev/null +++ b/src/ray/object_manager/transfer_queue.cc @@ -0,0 +1,67 @@ +#include "ray/object_manager/transfer_queue.h" + +namespace ray { + +void TransferQueue::QueueSend(const ClientID &client_id, const ObjectID &object_id, + const RemoteConnectionInfo &info) { + WriteLock guard(send_mutex); + SendRequest req = {client_id, object_id, info}; + // TODO(hme): Use a set to speed this up. + if (std::find(send_queue_.begin(), send_queue_.end(), req) != send_queue_.end()) { + // already queued. + return; + } + send_queue_.push_back(req); +} + +void TransferQueue::QueueReceive(const ClientID &client_id, const ObjectID &object_id, + uint64_t object_size, + std::shared_ptr conn) { + WriteLock guard(receive_mutex); + ReceiveRequest req = {client_id, object_id, object_size, conn}; + if (std::find(receive_queue_.begin(), receive_queue_.end(), req) != + receive_queue_.end()) { + // already queued. 
+ return; + } + receive_queue_.push_back(req); +} + +bool TransferQueue::DequeueSendIfPresent(TransferQueue::SendRequest *send_ptr) { + WriteLock guard(send_mutex); + if (send_queue_.empty()) { + return false; + } + *send_ptr = send_queue_.front(); + send_queue_.pop_front(); + return true; +} + +bool TransferQueue::DequeueReceiveIfPresent(TransferQueue::ReceiveRequest *receive_ptr) { + WriteLock guard(receive_mutex); + if (receive_queue_.empty()) { + return false; + } + *receive_ptr = receive_queue_.front(); + receive_queue_.pop_front(); + return true; +} + +UniqueID TransferQueue::AddContext(SendContext &context) { + WriteLock guard(context_mutex); + UniqueID id = UniqueID::from_random(); + send_context_set_.emplace(id, context); + return id; +} + +TransferQueue::SendContext &TransferQueue::GetContext(const UniqueID &id) { + ReadLock guard(context_mutex); + return send_context_set_[id]; +} + +ray::Status TransferQueue::RemoveContext(const UniqueID &id) { + WriteLock guard(context_mutex); + send_context_set_.erase(id); + return Status::OK(); +} +} // namespace ray diff --git a/src/ray/object_manager/transfer_queue.h b/src/ray/object_manager/transfer_queue.h new file mode 100644 index 000000000..50645fe93 --- /dev/null +++ b/src/ray/object_manager/transfer_queue.h @@ -0,0 +1,126 @@ +#ifndef RAY_OBJECT_MANAGER_TRANSFER_QUEUE_H +#define RAY_OBJECT_MANAGER_TRANSFER_QUEUE_H + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "ray/id.h" +#include "ray/status.h" + +#include "ray/object_manager/format/object_manager_generated.h" +#include "ray/object_manager/object_directory.h" +#include "ray/object_manager/object_manager_client_connection.h" + +namespace ray { + +class TransferQueue { + public: + enum TransferType { SEND = 1, RECEIVE }; + + /// Context maintained during an object send. 
+ struct SendContext { + ClientID client_id; + ObjectID object_id; + uint64_t object_size; + uint8_t *data; + }; + + /// The structure used in the send queue. + struct SendRequest { + ClientID client_id; + ObjectID object_id; + RemoteConnectionInfo connection_info; + bool operator==(const SendRequest &rhs) const { + return client_id == rhs.client_id && object_id == rhs.object_id; + } + }; + + /// The structure used in the receive queue. + struct ReceiveRequest { + ClientID client_id; + ObjectID object_id; + uint64_t object_size; + std::shared_ptr conn; + bool operator==(const ReceiveRequest &rhs) const { + return client_id == rhs.client_id && object_id == rhs.object_id; + } + }; + + /// Queues a send. + /// + /// \param client_id The ClientID to which the object needs to be sent. + /// \param object_id The ObjectID of the object to be sent. + void QueueSend(const ClientID &client_id, const ObjectID &object_id, + const RemoteConnectionInfo &info); + + /// If send_queue_ is not empty, removes a SendRequest from send_queue_ and assigns + /// it to send_ptr. The queue is FIFO. + /// \param send_ptr A pointer to an empty SendRequest. + /// \return True if a request was dequeued into send_ptr, false if the queue was + /// empty at the time this method was invoked. + bool DequeueSendIfPresent(TransferQueue::SendRequest *send_ptr); + + /// Queues a receive. + /// + /// \param client_id The ClientID from which the object is being received. + /// \param object_id The ObjectID of the object to be received. + void QueueReceive(const ClientID &client_id, const ObjectID &object_id, + uint64_t object_size, std::shared_ptr conn); + + /// If receive_queue_ is not empty, removes a ReceiveRequest from receive_queue_ and + /// assigns + /// it to receive_ptr. The queue is FIFO. + /// \param receive_ptr A pointer to an empty ReceiveRequest. + /// \return True if a request was dequeued into receive_ptr, false if the queue was + /// empty at the time this method was invoked.
+ bool DequeueReceiveIfPresent(TransferQueue::ReceiveRequest *receive_ptr); + + /// Maintain ownership over SendContext for sends in transit. + /// + /// \param context The context to maintain. + /// \return A unique identifier identifying the context that was added. + UniqueID AddContext(SendContext &context); + + /// Gets the SendContext associated with the given id. + /// + /// \param id The unique identifier of the context. + /// \return The context. + SendContext &GetContext(const UniqueID &id); + + /// Removes the context associated with the given id. + /// + /// \param id The unique identifier of the context. + /// \return The status of invoking this method. + ray::Status RemoveContext(const UniqueID &id); + + /// This object cannot be copied for thread-safety. + TransferQueue &operator=(const TransferQueue &o) { + throw std::runtime_error("Can't copy TransferQueue."); + } + + private: + // TODO(hme): make this a shared mutex. + typedef std::mutex Lock; + typedef std::unique_lock WriteLock; + // TODO(hme): make this a shared lock. 
+ typedef std::unique_lock ReadLock; + Lock send_mutex; + Lock receive_mutex; + Lock context_mutex; + + std::deque send_queue_; + std::deque receive_queue_; + std::unordered_map send_context_set_; +}; +} // namespace ray + +#endif // RAY_OBJECT_MANAGER_TRANSFER_QUEUE_H diff --git a/src/ray/python/default_worker.py b/src/ray/python/default_worker.py new file mode 100644 index 000000000..877c6220e --- /dev/null +++ b/src/ray/python/default_worker.py @@ -0,0 +1,18 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse + +from worker import Worker + +parser = argparse.ArgumentParser() +parser.add_argument("raylet_socket_name") +parser.add_argument("object_store_socket_name") + +if __name__ == '__main__': + args = parser.parse_args() + + worker = Worker(args.raylet_socket_name, args.object_store_socket_name, + is_worker=True) + worker.main_loop() diff --git a/src/ray/python/one_test_driver.py b/src/ray/python/one_test_driver.py new file mode 100644 index 000000000..b18dc5b20 --- /dev/null +++ b/src/ray/python/one_test_driver.py @@ -0,0 +1,33 @@ +import argparse + +import ray +from worker import Worker, logger +from ray.utils import random_string + + +parser = argparse.ArgumentParser() +parser.add_argument("raylet_socket_name") +parser.add_argument("object_store_socket_name") + +if __name__ == '__main__': + args = parser.parse_args() + + driver = Worker(args.raylet_socket_name, args.object_store_socket_name, + is_worker=False) + + task1 = ray.local_scheduler.Task( + ray.local_scheduler.ObjectID(random_string()), + ray.local_scheduler.ObjectID(random_string()), + [], + 1, + ray.local_scheduler.ObjectID(random_string()), + 0) + logger.debug("submitting %s", task1.task_id()) + driver.node_manager_client.submit(task1) + + logger.debug("Return values were %s", task1.returns()) + print("[DRIVER] Return values were", task1.returns()) + # Make sure the tasks get executed and we can get the result of the + #
last task + obj = driver.get(task1.returns(), timeout_ms=1000) + print("[DRIVER]: task1 driver.get result ", obj) diff --git a/src/ray/python/test_driver.py b/src/ray/python/test_driver.py new file mode 100644 index 000000000..5850f4d78 --- /dev/null +++ b/src/ray/python/test_driver.py @@ -0,0 +1,40 @@ +import argparse + +import ray +from worker import Worker, logger +from ray.utils import random_string + +parser = argparse.ArgumentParser() +parser.add_argument("raylet_socket_name") +parser.add_argument("object_store_socket_name") + +if __name__ == '__main__': + args = parser.parse_args() + + driver = Worker(args.raylet_socket_name, args.object_store_socket_name, + is_worker=False) + + task = ray.local_scheduler.Task( + ray.local_scheduler.ObjectID(random_string()), + ray.local_scheduler.ObjectID(random_string()), + [], + 1, + ray.local_scheduler.ObjectID(random_string()), + 0) + logger.debug("submitting %s", task.task_id()) + driver.node_manager_client.submit(task) + + logger.debug("Return values were %s", task.returns()) + task2 = ray.local_scheduler.Task( + ray.local_scheduler.ObjectID(random_string()), + ray.local_scheduler.ObjectID(random_string()), + task.returns(), + 1, + ray.local_scheduler.ObjectID(random_string()), + 0) + logger.debug("Submitting dependent task 2 %s", task2.task_id()) + driver.node_manager_client.submit(task2) + + # Make sure the tasks get executed and we can get the result of the last + # task. 
+ obj = driver.get(task2.returns(), timeout_ms=1000) diff --git a/src/ray/python/test_driver_taskchains.py b/src/ray/python/test_driver_taskchains.py new file mode 100644 index 000000000..8a14c2f68 --- /dev/null +++ b/src/ray/python/test_driver_taskchains.py @@ -0,0 +1,115 @@ +import argparse + +import ray +from worker import Worker, logger +from ray.utils import random_string + + +parser = argparse.ArgumentParser() +parser.add_argument("raylet_socket_name") +parser.add_argument("object_store_socket_name") + + +def submit_task_withdep(driver_handle, task_object_dependencies=[]): + ''' submit a task that depend on a list of @args''' + task = ray.local_scheduler.Task( + ray.local_scheduler.ObjectID(random_string()), + ray.local_scheduler.ObjectID(random_string()), + task_object_dependencies, + 1, # num_returns + ray.local_scheduler.ObjectID(random_string()), + 0) + logger.debug("[DRIVER]: submitting task ", task.task_id()) + driver_handle.node_manager_client.submit(task) + logger.debug("[DRIVER]: task return values", task.returns()) + return task.returns() + + +def submit_tasks_nodep(driver_handle, num_tasks): + ''' submit a task that depend on a list of @args''' + for i in range(num_tasks): + task = ray.local_scheduler.Task( + ray.local_scheduler.ObjectID(random_string()), + ray.local_scheduler.ObjectID(random_string()), + [], + 1, # num_returns + ray.local_scheduler.ObjectID(random_string()), + 0) + + logger.debug("[DRIVER]: submitting task ", task.task_id()) + driver_handle.node_manager_client.submit(task) + logger.debug("[DRIVER]: task return values", task.returns()) + + +def submit_task_chains(num_chains, tasks_per_chain): + # return task placement map on output + chain_returns = [] + task_placement_map_ = {} + for chain_num in range(num_chains): + last_task_returns = [] + task_placement_map_[chain_num] = [] + for i in range(tasks_per_chain): + task_returns = submit_task_withdep( + driver, + task_object_dependencies=last_task_returns) + last_task_returns = 
task_returns + task_placement_map_[chain_num].append(task_returns[0]) + chain_returns.append(last_task_returns) + + logger.debug("chain_returns=", chain_returns) + chain_results = driver.get([r[0] for r in chain_returns], timeout_ms=5000) + print("[DRIVER]: chain return values: ", chain_results) + + return task_placement_map_ + + +def TEST_run_task_chains(num_chains, tasks_per_chain): + task_placement_map = submit_task_chains(num_chains=num_chains, + tasks_per_chain=tasks_per_chain) + logger.debug("[DRIVER]: task placement information, per chain:") + task_placement_total = [] + for chain_num in range(len(task_placement_map)): + task_placement_list = driver.get(task_placement_map[chain_num], + timeout_ms=5000) + task_placement_total += [t[1] for t in task_placement_list] + logger.debug(chain_num, task_placement_list) + logger.debug("task placement overall: ", task_placement_total) + task_placement_stats = [(v, task_placement_total.count(v)) + for v in set(task_placement_total)] + num_total_tasks = sum([t[1] for t in task_placement_stats]) + print("total tasks executed = ", num_total_tasks) + assert(num_total_tasks == num_chains * tasks_per_chain) + print("task placement breakdown: total=", task_placement_stats) + + +def TEST_run_tasks_nodep(num_tasks): + # This test is the same as having num_tasks chains with 1 task per chain + # In this test we assume the num_tasks x 1 chain structure. 
+ task_placement_map = submit_task_chains(num_chains=num_tasks, + tasks_per_chain=1) + logger.debug("[DRIVER]: task placement information, per chain:") + task_placement_total = [] + for chain_num in range(len(task_placement_map)): + task_placement_list = driver.get(task_placement_map[chain_num], + timeout_ms=5000) + task_placement_total += [t[1] for t in task_placement_list] + logger.debug(chain_num, task_placement_list) + logger.debug("task placement overall: ", task_placement_total) + task_placement_stats = [(v, task_placement_total.count(v)) for v in + set(task_placement_total)] + num_total_tasks = sum([t[1] for t in task_placement_stats]) + print("total tasks executed = ", num_total_tasks) + assert(num_total_tasks == num_tasks) + print("task placement breakdown: total=", task_placement_stats) + + +if __name__ == '__main__': + args = parser.parse_args() + + driver = Worker(args.raylet_socket_name, args.object_store_socket_name, + is_worker=False) + + # Set up the experiment : number of chains and tasks per chain. + # TEST_run_task_chains(num_chains=10, tasks_per_chain=100) + + TEST_run_tasks_nodep(10000) diff --git a/src/ray/python/worker.py b/src/ray/python/worker.py new file mode 100644 index 000000000..072b386f9 --- /dev/null +++ b/src/ray/python/worker.py @@ -0,0 +1,67 @@ +import logging + +import ray +import pyarrow +import pyarrow.plasma as plasma +from ray.utils import random_string + + +logging.basicConfig() +logger = logging.getLogger(__name__) + +# The default return value to put in the object store. +RETURN_VALUE = 0 + + +class Worker(object): + + total_task_count = 0 + + def __init__(self, raylet_socket_name, object_store_socket_name, + is_worker): + # Connect to the Raylet and object store. 
+ self.node_manager_client = ray.local_scheduler.LocalSchedulerClient( + raylet_socket_name, random_string(), is_worker) + self.plasma_client = plasma.connect(object_store_socket_name, "", 0) + self.serialization_context = pyarrow.default_serialization_context() + self.raylet_socket_name = raylet_socket_name + self.object_store_socket_name = object_store_socket_name + + def main_loop(self): + while True: + self.get_task() + + def get(self, object_ids, timeout_ms=-1): + for object_id in object_ids: + self.node_manager_client.reconstruct_object(object_id.id()) + plasma_ids = [plasma.ObjectID(argument.id()) for argument in + object_ids] + values = self.plasma_client.get(plasma_ids, timeout_ms, + self.serialization_context) + assert(all(value[0] == RETURN_VALUE for value in values)) + return values + + def get_task(self): + logger.debug("[WORKER] waiting for task") + task = self.node_manager_client.get_task() + logger.debug("Worker assigned %s with arguments %s", + ray.utils.binary_to_hex(task.task_id().id()), + " ".join([ray.utils.binary_to_hex(argument.id()) for + argument in task.arguments()])) + + # Get the arguments. NOTE(swang): This will hang forever if the + # arguments have been evicted. + arguments = self.get(task.arguments()) + + for object_id in task.returns(): + self.plasma_client.put((RETURN_VALUE, self.raylet_socket_name), + plasma.ObjectID(object_id.id())) + objval = self.plasma_client.get([plasma.ObjectID(object_id.id())]) + assert(all([o[0] == RETURN_VALUE for o in objval])) + + logger.debug("Worker returned %s", + " ".join([ray.utils.binary_to_hex(return_id.id()) for + return_id in task.returns()])) + + # Release the arguments. 
+ del arguments diff --git a/src/ray/raylet/CMakeLists.txt b/src/ray/raylet/CMakeLists.txt index 07420eb6b..c0580552e 100644 --- a/src/ray/raylet/CMakeLists.txt +++ b/src/ray/raylet/CMakeLists.txt @@ -19,17 +19,18 @@ add_custom_command( add_custom_target(gen_node_manager_fbs DEPENDS ${NODE_MANAGER_FBS_OUTPUT_FILES}) -ADD_RAY_TEST(raylet_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main pthread ${Boost_SYSTEM_LIBRARY}) +ADD_RAY_TEST(object_manager_integration_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main pthread ${Boost_SYSTEM_LIBRARY}) ADD_RAY_TEST(worker_pool_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) +ADD_RAY_TEST(task_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) +ADD_RAY_TEST(lineage_cache_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY}) + add_library(rayletlib raylet.cc ${NODE_MANAGER_FBS_OUTPUT_FILES}) target_link_libraries(rayletlib ray_static ${Boost_SYSTEM_LIBRARY}) add_executable(raylet main.cc) target_link_libraries(raylet rayletlib ${Boost_SYSTEM_LIBRARY} pthread) -add_executable(raylet_demo remote_dependencies_demo.cc) -target_link_libraries(raylet_demo rayletlib ${Boost_SYSTEM_LIBRARY} pthread) install(FILES raylet diff --git a/src/ray/raylet/actor.cc b/src/ray/raylet/actor.cc index 4b9dc196d..0dd2487c7 100644 --- a/src/ray/raylet/actor.cc +++ b/src/ray/raylet/actor.cc @@ -2,10 +2,14 @@ namespace ray { +namespace raylet { + ActorInformation::ActorInformation() : id_(UniqueID::nil()) {} ActorInformation::~ActorInformation() {} const ActorID &ActorInformation::GetActorId() const { return this->id_; } +} // namespace raylet + } // namespace ray diff --git a/src/ray/raylet/actor.h b/src/ray/raylet/actor.h index 56bdb4f78..25f9e2dde 100644 --- a/src/ray/raylet/actor.h +++ 
b/src/ray/raylet/actor.h @@ -4,6 +4,9 @@ #include "ray/id.h" namespace ray { + +namespace raylet { + class ActorInformation { public: /// \brief ActorInformation constructor. @@ -21,6 +24,8 @@ class ActorInformation { ActorID id_; }; // class ActorInformation +} // namespace raylet + } // namespace ray #endif // RAY_RAYLET_ACTOR_H diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 7de47dd31..c4d50a217 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -51,7 +51,9 @@ enum MessageType:int { // counts and execution dependencies, discard any tasks that already executed // before the checkpoint, and make any tasks on the frontier runnable by // making their execution dependencies available. - SetActorFrontier + SetActorFrontier, + // A node manager request to process a task forwarded from another node manager. + ForwardTaskRequest } table TaskExecutionSpecification { @@ -82,12 +84,6 @@ table GetTaskReply { gpu_ids: [int]; } -table EventLogMessage { - key: string; - value: string; - timestamp: double; -} - // This struct is used to register a new worker with the local scheduler. // It is shipped as part of local_scheduler_connect. table RegisterClientRequest { @@ -108,57 +104,20 @@ table RegisterClientReply { gpu_ids: [int]; } -table DisconnectClient { -} - -table ReconstructObject { - // Object ID of the object that needs to be reconstructed. - object_id: string; -} - -table PutObject { - // Task ID of the task that performed the put. - task_id: string; - // Object ID of the object that is being put. - object_id: string; -} - -// The ActorFrontier is used to represent the current frontier of tasks that -// the local scheduler has marked as runnable for a particular actor. It is -// used to save the point in an actor's lifetime at which a checkpoint was -// taken, so that the same frontier of tasks can be made runnable again if the -// actor is resumed from that checkpoint. 
-table ActorFrontier { - // Actor ID of the actor whose frontier is described. - actor_id: string; - // A list of handle IDs, representing the callers of the actor that have - // submitted a runnable task to the local scheduler. A nil ID represents the - // creator of the actor. - handle_ids: [string]; - // A list representing the number of tasks executed so far, per handle. Each - // count in task_counters corresponds to the handle at the same in index in - // handle_ids. - task_counters: [long]; - // A list representing the execution dependency for the next runnable task, - // per handle. Each execution dependency in frontier_dependencies corresponds - // to the handle at the same in index in handle_ids. - frontier_dependencies: [string]; -} - -table GetActorFrontierRequest { - actor_id: string; -} - table RegisterNodeManagerRequest { // GCS ClientID of the connecting node manager. client_id: string; } table ForwardTaskRequest { - // The task to be forwarded. - // TODO(swang): Replace with a Task flatbuffer type. - task: string; - // The uncommitted lineage of the forwarded task, according to the sending - // node manager. - uncommitted_lineage: [string]; + // The ID of the task to be forwarded. + task_id: string; + // The tasks in the uncommitted lineage of the forwarded task. This + // should include task_id. + uncommitted_tasks: [Task]; +} + +table ReconstructObject { + // Object ID of the object that needs to be reconstructed. 
+ object_id: string; } diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index e5db72289..30030e9f6 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -2,30 +2,298 @@ namespace ray { -LineageCache::LineageCache() {} +namespace raylet { -ray::Status LineageCache::AddTask(const Task &task) { - throw std::runtime_error("method not implemented"); - return ray::Status::OK(); +LineageEntry::LineageEntry(const Task &task, GcsStatus status) + : status_(status), task_(task) {} + +GcsStatus LineageEntry::GetStatus() const { return status_; } + +bool LineageEntry::SetStatus(GcsStatus new_status) { + if (status_ < new_status) { + status_ = new_status; + return true; + } else { + return false; + } } -ray::Status LineageCache::AddTask(const Task &task, const Lineage &uncommitted_lineage) { - throw std::runtime_error("method not implemented"); - return ray::Status::OK(); +void LineageEntry::ResetStatus(GcsStatus new_status) { + RAY_CHECK(new_status < status_); + status_ = new_status; } -ray::Status LineageCache::AddObjectLocation(const ObjectID &object_id) { - throw std::runtime_error("method not implemented"); - return ray::Status::OK(); +const TaskID LineageEntry::GetEntryId() const { + return task_.GetTaskSpecification().TaskId(); } -Lineage &LineageCache::GetUncommittedLineage(const ObjectID &object_id) { - throw std::runtime_error("method not implemented"); +const std::unordered_set LineageEntry::GetParentTaskIds() + const { + std::unordered_set parent_ids; + // A task's parents are the tasks that created its arguments. 
+ auto dependencies = task_.GetDependencies(); + for (auto &dependency : dependencies) { + parent_ids.insert(ComputeTaskId(dependency)); + } + return parent_ids; +} + +const Task &LineageEntry::TaskData() const { return task_; } + +Task &LineageEntry::TaskDataMutable() { return task_; } + +Lineage::Lineage() {} + +Lineage::Lineage(const protocol::ForwardTaskRequest &task_request) { + // Deserialize and set entries for the uncommitted tasks. + auto tasks = task_request.uncommitted_tasks(); + for (auto it = tasks->begin(); it != tasks->end(); it++) { + auto task = Task(**it); + LineageEntry entry(task, GcsStatus_UNCOMMITTED_REMOTE); + RAY_CHECK(SetEntry(std::move(entry))); + } +} + +boost::optional Lineage::GetEntry(const UniqueID &task_id) const { + auto entry = entries_.find(task_id); + if (entry != entries_.end()) { + return entry->second; + } else { + return boost::optional(); + } +} + +boost::optional Lineage::GetEntryMutable(const UniqueID &task_id) { + auto entry = entries_.find(task_id); + if (entry != entries_.end()) { + return entry->second; + } else { + return boost::optional(); + } +} + +bool Lineage::SetEntry(LineageEntry &&new_entry) { + // Get the status of the current entry at the key. + auto task_id = new_entry.GetEntryId(); + GcsStatus current_status = GcsStatus_NONE; + auto current_entry = PopEntry(task_id); + if (current_entry) { + current_status = current_entry->GetStatus(); + } + + if (current_status < new_entry.GetStatus()) { + // If the new status is greater, then overwrite the current entry. + entries_.emplace(std::make_pair(task_id, std::move(new_entry))); + return true; + } else { + // If the new status is not greater, then the new entry is invalid. Replace + // the current entry at the key. 
+ entries_.emplace(std::make_pair(task_id, std::move(*current_entry))); + return false; + } +} + +boost::optional Lineage::PopEntry(const UniqueID &task_id) { + auto entry = entries_.find(task_id); + if (entry != entries_.end()) { + LineageEntry entry = std::move(entries_.at(task_id)); + entries_.erase(task_id); + return entry; + } else { + return boost::optional(); + } +} + +const std::unordered_map + &Lineage::GetEntries() const { + return entries_; +} + +flatbuffers::Offset Lineage::ToFlatbuffer( + flatbuffers::FlatBufferBuilder &fbb, const TaskID &task_id) const { + RAY_CHECK(GetEntry(task_id)); + // Serialize the task and object entries. + std::vector> uncommitted_tasks; + for (const auto &entry : entries_) { + uncommitted_tasks.push_back(entry.second.TaskData().ToFlatbuffer(fbb)); + } + + auto request = protocol::CreateForwardTaskRequest(fbb, to_flatbuf(fbb, task_id), + fbb.CreateVector(uncommitted_tasks)); + return request; +} + +LineageCache::LineageCache(gcs::TableInterface &task_storage) + : task_storage_(task_storage) {} + +/// A helper function to merge one lineage into another, in DFS order. +/// +/// \param task_id The current entry to merge from lineage_from into +/// lineage_to. +/// \param lineage_from The lineage to merge entries from. This lineage is +/// traversed by following each entry's parent pointers in DFS order, +/// until an entry is not found or the stopping condition is reached. +/// \param lineage_to The lineage to merge entries into. +/// \param stopping_condition A stopping condition for the DFS over +/// lineage_from. This should return true if the merge should stop. +void MergeLineageHelper(const UniqueID &task_id, const Lineage &lineage_from, + Lineage &lineage_to, + std::function stopping_condition) { + // If the entry is not found in the lineage to merge, then we stop since + // there is nothing to copy into the merged lineage. 
+ auto entry = lineage_from.GetEntry(task_id); + if (!entry) { + return; + } + // Check whether we should stop at this entry in the DFS. + auto status = entry->GetStatus(); + if (stopping_condition(status)) { + return; + } + + // Insert a copy of the entry into lineage_to. + LineageEntry entry_copy = *entry; + auto parent_ids = entry_copy.GetParentTaskIds(); + // If the insert is successful, then continue the DFS. The insert will fail + // if the new entry has an equal or lower GCS status than the current entry + // in lineage_to. This also prevents us from traversing the same node twice. + if (lineage_to.SetEntry(std::move(entry_copy))) { + for (const auto &parent_id : parent_ids) { + MergeLineageHelper(parent_id, lineage_from, lineage_to, stopping_condition); + } + } +} + +void LineageCache::AddWaitingTask(const Task &task, const Lineage &uncommitted_lineage) { + auto task_id = task.GetTaskSpecification().TaskId(); + // Merge the uncommitted lineage into the lineage cache. + MergeLineageHelper(task_id, uncommitted_lineage, lineage_, [](GcsStatus status) { + if (status != GcsStatus_NONE) { + // We received the uncommitted lineage from a remote node, so make sure + // that all entries in the lineage to merge have status + // UNCOMMITTED_REMOTE. + RAY_CHECK(status == GcsStatus_UNCOMMITTED_REMOTE); + } + // The only stopping condition is that an entry is not found. + return false; + }); + + // Add the submitted task to the lineage cache as UNCOMMITTED_WAITING. It + // should be marked as UNCOMMITTED_READY once the task starts execution. 
+ LineageEntry task_entry(task, GcsStatus_UNCOMMITTED_WAITING); + RAY_CHECK(lineage_.SetEntry(std::move(task_entry))); +} + +void LineageCache::AddReadyTask(const Task &task) { + auto new_entry = LineageEntry(task, GcsStatus_UNCOMMITTED_READY); + RAY_CHECK(lineage_.SetEntry(std::move(new_entry))); +} + +void LineageCache::RemoveWaitingTask(const TaskID &task_id) { + auto entry = lineage_.PopEntry(task_id); + // The task must be in the cache; fail loudly instead of dereferencing an + // empty optional. + RAY_CHECK(entry); + // It's only okay to remove a task that is waiting for execution. + // TODO(swang): Is this necessarily true when there is reconstruction? + RAY_CHECK(entry->GetStatus() == GcsStatus_UNCOMMITTED_WAITING); + // Reset the status to REMOTE. We keep the task instead of removing it + // completely in case another task is submitted locally that depends on this + // one. + entry->ResetStatus(GcsStatus_UNCOMMITTED_REMOTE); + RAY_CHECK(lineage_.SetEntry(std::move(*entry))); +} + +Lineage LineageCache::GetUncommittedLineage(const TaskID &task_id) const { + Lineage uncommitted_lineage; + // Add all uncommitted ancestors from the lineage cache to the uncommitted + // lineage of the requested task. + MergeLineageHelper(task_id, lineage_, uncommitted_lineage, [](GcsStatus status) { + // The stopping condition for recursion is that the entry has been + // committed to the GCS. + return status == GcsStatus_COMMITTED; + }); + return uncommitted_lineage; } Status LineageCache::Flush() { - throw std::runtime_error("method not implemented"); + // Find all tasks that are READY and whose arguments have been committed in the GCS. + std::vector ready_task_ids; + for (const auto &pair : lineage_.GetEntries()) { + // Bind by reference: the entry is only read here, so avoid copying the task. + const auto &task_id = pair.first; + const auto &entry = pair.second; + // Skip task entries that are not ready to be written yet. These tasks + // either have not started execution yet, are being executed on a remote + // node, or have already been written to the GCS. 
+ if (entry.GetStatus() != GcsStatus_UNCOMMITTED_READY) { + continue; + } + // Check if all arguments have been committed to the GCS before writing + // this task. + bool all_arguments_committed = true; + for (const auto &parent_id : entry.GetParentTaskIds()) { + auto parent = lineage_.GetEntry(parent_id); + // If a parent entry exists in the lineage cache but has not been + // committed yet, then as far as we know, it's still in flight to the + // GCS. Skip this task for now. + if (parent && parent->GetStatus() != GcsStatus_COMMITTED) { + // TODO(swang): Once GCS notifications for the task table are ready, + // request notification for commit of the parent task here. + all_arguments_committed = false; + break; + } + } + if (all_arguments_committed) { + // All arguments have been committed to the GCS. Add this task to the + // list of tasks to write back to the GCS. + ready_task_ids.push_back(task_id); + } + } + + // Write back all ready tasks whose arguments have been committed to the GCS. + gcs::raylet::TaskTable::WriteCallback task_callback = + [this](ray::gcs::AsyncGcsClient *client, const TaskID &id, + const std::shared_ptr data) { HandleEntryCommitted(id); }; + for (const auto &ready_task_id : ready_task_ids) { + auto task = lineage_.GetEntry(ready_task_id); + // TODO(swang): Make this better... + flatbuffers::FlatBufferBuilder fbb; + auto message = task->TaskData().ToFlatbuffer(fbb); + fbb.Finish(message); + auto task_data = std::make_shared(); + auto root = flatbuffers::GetRoot(fbb.GetBufferPointer()); + root->UnPackTo(task_data.get()); + RAY_CHECK_OK(task_storage_.Add(task->TaskData().GetTaskSpecification().DriverId(), + ready_task_id, task_data, task_callback)); + + // We successfully wrote the task, so mark it as committing. + // TODO(swang): Use a batched interface and write with all object entries. 
+ auto entry = lineage_.PopEntry(ready_task_id); + RAY_CHECK(entry->SetStatus(GcsStatus_COMMITTING)); + RAY_CHECK(lineage_.SetEntry(std::move(*entry))); + } + return ray::Status::OK(); } +/// Remove the given task and, recursively, all of its ancestors from the +/// lineage. Each removed entry must be UNCOMMITTED_REMOTE or COMMITTED. +void PopAncestorTasks(const UniqueID &task_id, Lineage &lineage) { + auto entry = lineage.PopEntry(task_id); + if (!entry) { + return; + } + auto status = entry->GetStatus(); + RAY_CHECK(status == GcsStatus_UNCOMMITTED_REMOTE || status == GcsStatus_COMMITTED); + for (const auto &parent_id : entry->GetParentTaskIds()) { + PopAncestorTasks(parent_id, lineage); + } +} + +void LineageCache::HandleEntryCommitted(const UniqueID &task_id) { + auto entry = lineage_.PopEntry(task_id); + // The committed task must still be in the cache; fail loudly instead of + // dereferencing an empty optional. + RAY_CHECK(entry); + for (const auto &parent_id : entry->GetParentTaskIds()) { + PopAncestorTasks(parent_id, lineage_); + } + RAY_CHECK(entry->SetStatus(GcsStatus_COMMITTED)); + RAY_CHECK(lineage_.SetEntry(std::move(*entry))); +} + +} // namespace raylet + } // namespace ray diff --git a/src/ray/raylet/lineage_cache.h b/src/ray/raylet/lineage_cache.h index dfc3350b8..6dac1a3d6 100644 --- a/src/ray/raylet/lineage_cache.h +++ b/src/ray/raylet/lineage_cache.h @@ -1,83 +1,216 @@ #ifndef RAY_RAYLET_LINEAGE_CACHE_H #define RAY_RAYLET_LINEAGE_CACHE_H +#include + // clang-format off +#include "common_protocol.h" #include "ray/raylet/task.h" +#include "ray/gcs/tables.h" #include "ray/id.h" #include "ray/status.h" // clang-format on namespace ray { -// TODO(swang): Define this class. -class Lineage {}; +namespace raylet { -class LineageCacheEntry { - private: - // TODO(swang): This should be an enum of the state of the entry - goes from - // completely local, to dirty, to in flight, to committed. - // bool dirty_; +/// The status of a lineage cache entry according to its status in the GCS. +enum GcsStatus { + /// The task is not in the lineage cache. + GcsStatus_NONE = 0, + /// The task is being executed or created on a remote node. + GcsStatus_UNCOMMITTED_REMOTE, + /// The task is waiting to be executed or created locally. 
+ GcsStatus_UNCOMMITTED_WAITING, + /// The task has started execution, but the entry has not been written to the + /// GCS yet. + GcsStatus_UNCOMMITTED_READY, + /// The task has been written to the GCS and we are waiting for an + /// acknowledgement of the commit. + GcsStatus_COMMITTING, + /// The task has been committed in the GCS. It's safe to remove this entry + /// from the lineage cache. + GcsStatus_COMMITTED, }; -class LineageCacheTaskEntry : public LineageCacheEntry {}; -class LineageCacheObjectEntry : public LineageCacheEntry {}; +/// \class LineageEntry +/// +/// A task entry in the data lineage. Each entry's parents are the tasks that +/// created the entry's arguments. +class LineageEntry { + public: + /// Create an entry for a task. + /// + /// \param task The task data to eventually be written back to the GCS. + /// \param status The status of this entry, according to its write status in + /// the GCS. + LineageEntry(const Task &task, GcsStatus status); + + /// Get this entry's GCS status. + /// + /// \return The entry's status in the GCS. + GcsStatus GetStatus() const; + + /// Set this entry's GCS status. The status is only set if the new status + /// is strictly greater than the entry's previous status, according to the + /// GcsStatus enum. + /// + /// \param new_status Set the entry's status to this value if it is greater + /// than the current status. + /// \return Whether the entry was set to the new status. + bool SetStatus(GcsStatus new_status); + + /// Reset this entry's GCS status to a lower status. The new status must + /// be lower than the current status. + /// + /// \param new_status This must be lower than the current status. + void ResetStatus(GcsStatus new_status); + + /// Get this entry's ID. + /// + /// \return The entry's ID. + const TaskID GetEntryId() const; + + /// Get the IDs of this entry's parent tasks. These are the IDs of the tasks + /// that created its arguments. + /// + /// \return The IDs of the parent entries. 
+ const std::unordered_set GetParentTaskIds() const; + + /// Get the task data. + /// + /// \return The task data. + const Task &TaskData() const; + + /// Get a mutable reference to the task data. + /// + /// \return A mutable reference to the task data. + Task &TaskDataMutable(); + + private: + /// The current state of this entry according to its status in the GCS. + GcsStatus status_; + /// The task data to be written to the GCS. + // const Task task_; + Task task_; +}; + +/// \class Lineage +/// +/// A lineage DAG, according to the data dependency graph. Each node is a task, +/// with an outgoing edge to each of its parent tasks. For a given task, the +/// parents are the tasks that created its arguments. Each entry also records +/// the current status in the GCS for that task or object. +class Lineage { + public: + /// Construct an empty Lineage. + Lineage(); + + /// Construct a Lineage from a ForwardTaskRequest. + /// + /// \param task_request The request to construct the lineage from. All + /// uncommitted tasks in the request will be added to the lineage. + Lineage(const protocol::ForwardTaskRequest &task_request); + + /// Get an entry from the lineage. + /// + /// \param entry_id The ID of the entry to get. + /// \return An optional reference to the entry. If this is empty, then the + /// entry ID is not in the lineage. + boost::optional GetEntry(const TaskID &entry_id) const; + + /// Get a mutable entry from the lineage. + /// + /// \param task_id The ID of the entry to get. + /// \return An optional reference to the entry. If this is empty, then the + /// entry ID is not in the lineage. + boost::optional GetEntryMutable(const UniqueID &task_id); + + /// Set an entry in the lineage. If an entry with this ID already exists, + /// then the entry is overwritten if and only if the new entry has a higher + /// GCS status than the current. The current entry's object or task data will + /// also be overwritten. + /// + /// \param entry The new entry to set in the lineage, if its GCS status is + /// greater than the current entry. + /// \return Whether the entry was set. + bool SetEntry(LineageEntry &&entry); + + /// Delete and return an entry from the lineage. + /// + /// \param entry_id The ID of the entry to pop. 
+ /// \return An optional reference to the popped entry. If this is empty, then + /// the entry ID is not in the lineage. + boost::optional PopEntry(const TaskID &entry_id); + + /// Get all entries in the lineage. + /// + /// \return A const reference to the lineage entries. + const std::unordered_map &GetEntries() + const; + + /// Serialize this lineage to a ForwardTaskRequest flatbuffer. + /// + /// \param entry_id The task ID to include in the ForwardTaskRequest + /// flatbuffer. + /// \return An offset to the serialized lineage. The serialization includes + /// all task and object entries in the lineage. + flatbuffers::Offset ToFlatbuffer( + flatbuffers::FlatBufferBuilder &fbb, const TaskID &entry_id) const; + + private: + /// The lineage entries. + std::unordered_map entries_; +}; /// \class LineageCache /// -/// A cache of the object and task tables. This consists of all tasks that this -/// node owns, as well as their task lineage, that have not yet been added -/// durably to the GCS. +/// A cache of the task table. This consists of all tasks that this node owns, +/// as well as their lineage, that have not yet been added durably to the GCS. class LineageCache { public: - /// Create a lineage cache policy. - /// TODO(swang): Pass in the policy (interface?) and a GCS client. - LineageCache(); + /// Create a lineage cache for the given task storage system. + /// TODO(swang): Pass in the policy (interface?). + LineageCache(gcs::TableInterface &task_storage); - /// Add a task and its object outputs asynchronously to the GCS. This - /// overwrites the task's mutable fields in the execution specification. + /// Add a task that is waiting for execution and its uncommitted lineage. + /// These entries will not be written to the GCS until set to ready. /// - /// \param task The task to add. - /// \return Status. - ray::Status AddTask(const Task &task); - - /// Add a task and its uncommitted lineage asynchronously to the GCS. 
The - /// mutable fields for the given task will be overwritten, but not for the - /// tasks in the uncommitted lineage. - /// - /// \param task The task to add. + /// \param task The waiting task to add. /// \param uncommitted_lineage The task's uncommitted lineage. These are the - /// tasks that the given task is data-dependent on, but that have not - /// been made durable in the GCS, as far as we know. - /// \return Status. - ray::Status AddTask(const Task &task, const Lineage &uncommitted_lineage); + /// tasks that the given task is data-dependent on, but that have not + /// been made durable in the GCS, as far as the task's submitter knows. + void AddWaitingTask(const Task &task, const Lineage &uncommitted_lineage); - /// Add this node as an object location, to be asynchronously committed to - /// the GCS. + /// Add a task that is ready for GCS writeback. This overwrites the task's + /// mutable fields in the execution specification. /// - /// \param object_id The object to add a location for. - /// \return Status. - ray::Status AddObjectLocation(const ObjectID &object_id); + /// \param task The task to set as ready. + void AddReadyTask(const Task &task); - /// Get the uncommitted lineage of an object. These are the tasks that the - /// given object is data-dependent on, but that have not been made durable in + /// Remove a task that was waiting for execution. The entry is kept in the + /// cache, but its status is reset to UNCOMMITTED_REMOTE. + /// + /// \param entry_id The ID of the waiting task to remove. + void RemoveWaitingTask(const TaskID &entry_id); + + /// Get the uncommitted lineage of a task. The uncommitted lineage consists + /// of all tasks in the given task's lineage that have not been committed in /// the GCS, as far as we know. /// - /// \param object_id The object to get the uncommitted lineage for. - /// \return The uncommitted lineage of the object. - Lineage &GetUncommittedLineage(const ObjectID &object_id); + /// \param entry_id The ID of the task to get the uncommitted lineage for. + /// \return The uncommitted lineage of the task. The returned lineage + /// includes the entry for the requested entry_id. 
+ Lineage GetUncommittedLineage(const TaskID &entry_id) const; - /// Asynchronously write any tasks and object locations that have been added - /// since the last flush to the GCS. When each write is acknowledged, its - /// entry will be marked as committed. + /// Asynchronously write any tasks that have been added since the last flush + /// to the GCS. When each write is acknowledged, its entry will be marked as + /// committed. /// /// \return Status. Status Flush(); private: - std::unordered_map task_table_; - std::unordered_map object_table_; + void HandleEntryCommitted(const TaskID &unique_id); + + /// The durable storage system for task information. + gcs::TableInterface &task_storage_; + /// All tasks and objects that we are responsible for writing back to the + /// GCS, and the tasks and objects in their lineage. + Lineage lineage_; }; +} // namespace raylet + } // namespace ray #endif // RAY_RAYLET_LINEAGE_CACHE_H diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc new file mode 100644 index 000000000..0610997ff --- /dev/null +++ b/src/ray/raylet/lineage_cache_test.cc @@ -0,0 +1,270 @@ +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "ray/raylet/format/node_manager_generated.h" +#include "ray/raylet/lineage_cache.h" +#include "ray/raylet/task.h" +#include "ray/raylet/task_execution_spec.h" +#include "ray/raylet/task_spec.h" + +namespace ray { + +namespace raylet { + +class MockGcs : virtual public gcs::TableInterface { + public: + MockGcs(){}; + Status Add(const JobID &job_id, const TaskID &task_id, + std::shared_ptr task_data, + const gcs::TableInterface::WriteCallback &done) { + task_table_[task_id] = task_data; + callbacks_.push_back( + std::pair(done, task_id)); + return ray::Status::OK(); + }; + + void Flush() { + for (const auto &callback : callbacks_) { + callback.first(NULL, callback.second, task_table_[callback.second]); + } + callbacks_.clear(); + }; + + const std::unordered_map, 
UniqueIDHasher> + &TaskTable() const { + return task_table_; + } + + private: + std::unordered_map, UniqueIDHasher> + task_table_; + std::vector> callbacks_; +}; + +class LineageCacheTest : public ::testing::Test { + public: + LineageCacheTest() : mock_gcs_(), lineage_cache_(mock_gcs_) {} + + protected: + MockGcs mock_gcs_; + LineageCache lineage_cache_; +}; + +static inline Task ExampleTask(const std::vector &arguments, + int64_t num_returns) { + std::unordered_map required_resources; + std::vector> task_arguments; + for (auto &argument : arguments) { + std::vector references = {argument}; + task_arguments.emplace_back(std::make_shared(references)); + } + auto spec = TaskSpecification(UniqueID::nil(), UniqueID::from_random(), 0, + UniqueID::from_random(), task_arguments, num_returns, + required_resources); + auto execution_spec = TaskExecutionSpecification(std::vector()); + execution_spec.IncrementNumForwards(); + Task task = Task(execution_spec, spec); + return task; +} + +std::vector InsertTaskChain(LineageCache &lineage_cache, + std::vector &inserted_tasks, int chain_size, + const std::vector &initial_arguments, + int64_t num_returns) { + Lineage empty_lineage; + std::vector arguments = initial_arguments; + for (int i = 0; i < chain_size; i++) { + auto task = ExampleTask(arguments, num_returns); + lineage_cache.AddWaitingTask(task, empty_lineage); + inserted_tasks.push_back(task); + arguments.clear(); + for (int j = 0; j < task.GetTaskSpecification().NumReturns(); j++) { + arguments.push_back(task.GetTaskSpecification().ReturnId(j)); + } + } + return arguments; +} + +TEST_F(LineageCacheTest, TestGetUncommittedLineage) { + // Insert two independent chains of tasks. 
+ std::vector tasks1; + auto return_values1 = + InsertTaskChain(lineage_cache_, tasks1, 3, std::vector(), 1); + std::vector task_ids1; + for (const auto &task : tasks1) { + task_ids1.push_back(task.GetTaskSpecification().TaskId()); + } + + std::vector tasks2; + auto return_values2 = + InsertTaskChain(lineage_cache_, tasks2, 2, std::vector(), 2); + std::vector task_ids2; + for (const auto &task : tasks2) { + task_ids2.push_back(task.GetTaskSpecification().TaskId()); + } + + // Get the uncommitted lineage for the last task (the leaf) of one of the + // chains. + auto uncommitted_lineage = lineage_cache_.GetUncommittedLineage(task_ids1.back()); + // Check that the uncommitted lineage is exactly equal to the first chain of + // tasks. + ASSERT_EQ(task_ids1.size(), uncommitted_lineage.GetEntries().size()); + for (auto &task_id : task_ids1) { + ASSERT_TRUE(uncommitted_lineage.GetEntry(task_id)); + } + + // Insert one task that is dependent on the previous chains of tasks. + std::vector combined_tasks = tasks1; + combined_tasks.insert(combined_tasks.end(), tasks2.begin(), tasks2.end()); + std::vector combined_arguments = return_values1; + combined_arguments.insert(combined_arguments.end(), return_values2.begin(), + return_values2.end()); + InsertTaskChain(lineage_cache_, combined_tasks, 1, combined_arguments, 1); + std::vector combined_task_ids; + for (const auto &task : combined_tasks) { + combined_task_ids.push_back(task.GetTaskSpecification().TaskId()); + } + + // Get the uncommitted lineage for the inserted task. + uncommitted_lineage = lineage_cache_.GetUncommittedLineage(combined_task_ids.back()); + // Check that the uncommitted lineage is exactly equal to the entire set of + // tasks inserted so far. 
+ ASSERT_EQ(combined_task_ids.size(), uncommitted_lineage.GetEntries().size()); + for (auto &task_id : combined_task_ids) { + ASSERT_TRUE(uncommitted_lineage.GetEntry(task_id)); + } +} + +void CheckFlush(LineageCache &lineage_cache, MockGcs &mock_gcs, + size_t num_tasks_flushed) { + RAY_CHECK_OK(lineage_cache.Flush()); + ASSERT_EQ(mock_gcs.TaskTable().size(), num_tasks_flushed); +} + +TEST_F(LineageCacheTest, TestWritebackNoneReady) { + // Insert a chain of dependent tasks. + size_t num_tasks_flushed = 0; + std::vector tasks; + auto return_values1 = + InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); + + // Check that when no tasks have been marked as ready, we do not flush any + // entries. + CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed); +} + +TEST_F(LineageCacheTest, TestWritebackReady) { + // Insert a chain of dependent tasks. + size_t num_tasks_flushed = 0; + std::vector tasks; + auto return_values1 = + InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); + + // Check that after marking the first task as ready, we flush only that task. + lineage_cache_.AddReadyTask(tasks.front()); + num_tasks_flushed++; + CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed); +} + +TEST_F(LineageCacheTest, TestWritebackOrder) { + // Insert a chain of dependent tasks. + size_t num_tasks_flushed = 0; + std::vector tasks; + auto return_values1 = + InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); + + // Mark all tasks as ready. + for (const auto &task : tasks) { + lineage_cache_.AddReadyTask(task); + } + // Check that we write back the tasks in order of data dependencies. + for (size_t i = 0; i < tasks.size(); i++) { + num_tasks_flushed++; + CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed); + // Flush acknowledgements. The next task should be able to be written. 
+ mock_gcs_.Flush(); + } +} + +TEST_F(LineageCacheTest, TestWritebackPartiallyReady) { + // Create two independent tasks, task1 and task2, and a dependent task + // that depends on both tasks. + size_t num_tasks_flushed = 0; + auto task1 = ExampleTask({}, 1); + auto task2 = ExampleTask({}, 1); + std::vector returns; + for (int64_t i = 0; i < task1.GetTaskSpecification().NumReturns(); i++) { + returns.push_back(task1.GetTaskSpecification().ReturnId(i)); + } + for (int64_t i = 0; i < task2.GetTaskSpecification().NumReturns(); i++) { + returns.push_back(task2.GetTaskSpecification().ReturnId(i)); + } + auto dependent_task = ExampleTask(returns, 1); + auto dependencies = dependent_task.GetDependencies(); + + // Insert all tasks as waiting for execution. + lineage_cache_.AddWaitingTask(task1, Lineage()); + lineage_cache_.AddWaitingTask(task2, Lineage()); + lineage_cache_.AddWaitingTask(dependent_task, Lineage()); + + // Mark one of the independent tasks and the dependent task as ready. + lineage_cache_.AddReadyTask(task1); + lineage_cache_.AddReadyTask(dependent_task); + // Check that only the first independent task is flushed. + num_tasks_flushed++; + CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed); + + // Flush acknowledgements. The dependent task should still not be flushed + // since task2 is not committed yet. + mock_gcs_.Flush(); + CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed); + + // Mark the other independent task as ready. + lineage_cache_.AddReadyTask(task2); + // Check that the other independent task gets flushed. + num_tasks_flushed++; + CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed); + + // Flush acknowledgements. The dependent task should now be able to be + // written. + mock_gcs_.Flush(); + num_tasks_flushed++; + CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed); +} + +TEST_F(LineageCacheTest, TestRemoveWaitingTask) { + // Insert a chain of dependent tasks. 
+ std::vector tasks; + auto return_values1 = + InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); + + auto task_to_remove = tasks[1]; + auto task_id_to_remove = task_to_remove.GetTaskSpecification().TaskId(); + auto uncommitted_lineage = lineage_cache_.GetUncommittedLineage(task_id_to_remove); + flatbuffers::FlatBufferBuilder fbb; + auto uncommitted_lineage_message = + uncommitted_lineage.ToFlatbuffer(fbb, task_id_to_remove); + fbb.Finish(uncommitted_lineage_message); + uncommitted_lineage = Lineage( + *flatbuffers::GetRoot(fbb.GetBufferPointer())); + + const Task &task = uncommitted_lineage.GetEntry(task_id_to_remove)->TaskData(); + RAY_LOG(INFO) << "removing task " << task.GetTaskSpecification().TaskId() + << "with numforwards=" + << task.GetTaskExecutionSpecReadonly().NumForwards(); + ASSERT_EQ(task.GetTaskExecutionSpecReadonly().NumForwards(), 1); + + lineage_cache_.RemoveWaitingTask(task_id_to_remove); + lineage_cache_.AddWaitingTask(task_to_remove, uncommitted_lineage); +} + +} // namespace raylet + +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index d4c66c1a7..956223233 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -5,34 +5,45 @@ #ifndef RAYLET_TEST int main(int argc, char *argv[]) { - RAY_CHECK(argc == 2); + RAY_CHECK(argc == 5); - // start store - std::string executable_str = std::string(argv[0]); - std::string exec_dir = executable_str.substr(0, executable_str.find_last_of("/")); - std::string plasma_dir = exec_dir + "./../plasma"; - std::string plasma_command = - plasma_dir + - "/plasma_store -m 1000000000 -s /tmp/store 1> /dev/null 2> /dev/null &"; - RAY_LOG(INFO) << plasma_command; - int s = system(plasma_command.c_str()); - RAY_CHECK(s == 0); + const std::string raylet_socket_name = std::string(argv[1]); + const std::string store_socket_name = std::string(argv[2]); + const 
std::string redis_address = std::string(argv[3]); + int redis_port = std::stoi(argv[4]); - // configure + // Configuration for the node manager. + ray::raylet::NodeManagerConfig node_manager_config; std::unordered_map static_resource_conf; static_resource_conf = {{"CPU", 1}, {"GPU", 1}}; - ray::ResourceSet resource_config(std::move(static_resource_conf)); - ray::ObjectManagerConfig om_config; - om_config.store_socket_name = "/tmp/store"; + node_manager_config.resource_config = + ray::raylet::ResourceSet(std::move(static_resource_conf)); + node_manager_config.num_initial_workers = 0; + // Use a default worker that can execute empty tasks with dependencies. + node_manager_config.worker_command.push_back("python"); + node_manager_config.worker_command.push_back( + "../../../src/ray/python/default_worker.py"); + node_manager_config.worker_command.push_back(raylet_socket_name.c_str()); + node_manager_config.worker_command.push_back(store_socket_name.c_str()); + // TODO(swang): Set this from a global config. + node_manager_config.heartbeat_period_ms = 100; + + // Configuration for the object manager. + ray::ObjectManagerConfig object_manager_config; + object_manager_config.store_socket_name = store_socket_name; // initialize mock gcs & object directory - std::shared_ptr mock_gcs_client = - std::shared_ptr(new ray::GcsClient()); + auto gcs_client = std::make_shared(); + RAY_LOG(INFO) << "Initializing GCS client " + << gcs_client->client_table().GetLocalClientId(); // Initialize the node manager. 
- boost::asio::io_service io_service; - ray::Raylet server(io_service, std::string(argv[1]), resource_config, om_config, - mock_gcs_client); - io_service.run(); + boost::asio::io_service main_service; + std::unique_ptr object_manager_service; + object_manager_service.reset(new boost::asio::io_service()); + ray::raylet::Raylet server(main_service, std::move(object_manager_service), + raylet_socket_name, redis_address, redis_port, + node_manager_config, object_manager_config, gcs_client); + main_service.run(); } #endif diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index d0e144b74..54601a7b1 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -5,20 +5,153 @@ namespace ray { -NodeManager::NodeManager(const std::string &socket_name, - const ResourceSet &resource_config, - ObjectManager &object_manager) - : local_resources_(resource_config), - worker_pool_(WorkerPool(0)), +namespace raylet { + +NodeManager::NodeManager(boost::asio::io_service &io_service, + const NodeManagerConfig &config, ObjectManager &object_manager, + std::shared_ptr gcs_client) + : io_service_(io_service), + heartbeat_timer_(io_service), + heartbeat_period_ms_(config.heartbeat_period_ms), + local_resources_(config.resource_config), + worker_pool_(config.num_initial_workers, config.worker_command), local_queues_(SchedulingQueue()), scheduling_policy_(local_queues_), reconstruction_policy_([this](const TaskID &task_id) { ResubmitTask(task_id); }), task_dependency_manager_( object_manager, // reconstruction_policy_, - [this](const TaskID &task_id) { HandleWaitingTaskReady(task_id); }) { - //// TODO(atumanov): need to add the self-knowledge of ClientID, using nill(). 
- // cluster_resource_map_[ClientID::nil()] = local_resources_; + [this](const TaskID &task_id) { HandleWaitingTaskReady(task_id); }), + lineage_cache_(gcs_client->raylet_task_table()), + gcs_client_(gcs_client), + remote_clients_(), + remote_server_connections_(), + object_manager_(object_manager) { + RAY_CHECK(heartbeat_period_ms_ > 0); + // Initialize the resource map with own cluster resource configuration. + ClientID local_client_id = gcs_client_->client_table().GetLocalClientId(); + cluster_resource_map_.emplace(local_client_id, + SchedulingResources(config.resource_config)); +} + +void NodeManager::Heartbeat() { + RAY_LOG(DEBUG) << "[Heartbeat] sending heartbeat."; + auto &heartbeat_table = gcs_client_->heartbeat_table(); + auto heartbeat_data = std::make_shared(); + auto client_id = gcs_client_->client_table().GetLocalClientId(); + const SchedulingResources &local_resources = cluster_resource_map_[client_id]; + heartbeat_data->client_id = client_id.hex(); + // TODO(atumanov): modify the heartbeat table protocol to use the ResourceSet directly. + // TODO(atumanov): implement a ResourceSet const_iterator. 
+ for (const auto &resource_pair : + local_resources.GetAvailableResources().GetResourceMap()) { + heartbeat_data->resources_available_label.push_back(resource_pair.first); + heartbeat_data->resources_available_capacity.push_back(resource_pair.second); + } + for (const auto &resource_pair : local_resources.GetTotalResources().GetResourceMap()) { + heartbeat_data->resources_total_label.push_back(resource_pair.first); + heartbeat_data->resources_total_capacity.push_back(resource_pair.second); + } + + ray::Status status = heartbeat_table.Add( + UniqueID::nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data, + [](ray::gcs::AsyncGcsClient *client, const ClientID &id, + std::shared_ptr data) { + RAY_LOG(DEBUG) << "[HEARTBEAT] heartbeat sent callback"; + }); + + if (!status.ok()) { + RAY_LOG(INFO) << "heartbeat failed: string " << status.ToString() << status.message(); + RAY_LOG(INFO) << "is redis error: " << status.IsRedisError(); + } + RAY_CHECK_OK(status); + + // Reset the timer. + auto heartbeat_period = boost::posix_time::milliseconds(heartbeat_period_ms_); + heartbeat_timer_.expires_from_now(heartbeat_period); + heartbeat_timer_.async_wait([this](const boost::system::error_code &error) { + RAY_CHECK(!error); + Heartbeat(); + }); +} + +void NodeManager::ClientAdded(gcs::AsyncGcsClient *client, const UniqueID &id, + const ClientTableDataT &client_data) { + ClientID client_id = ClientID::from_binary(client_data.client_id); + RAY_LOG(DEBUG) << "[ClientAdded] received callback from client id " << client_id.hex(); + if (client_id == gcs_client_->client_table().GetLocalClientId()) { + // We got a notification for ourselves, so we are connected to the GCS now. + // Save this NodeManager's resource information in the cluster resource map. + cluster_resource_map_[client_id] = local_resources_; + // Start sending heartbeats to the GCS. + Heartbeat(); + // Subscribe to heartbeats. 
+ const auto heartbeat_added = [this](gcs::AsyncGcsClient *client, const ClientID &id, + const HeartbeatTableDataT &heartbeat_data) { + this->HeartbeatAdded(client, id, heartbeat_data); + }; + ray::Status status = client->heartbeat_table().Subscribe( + UniqueID::nil(), UniqueID::nil(), heartbeat_added, + [this](gcs::AsyncGcsClient *client) { + RAY_LOG(DEBUG) << "heartbeat table subscription done callback called."; + }); + RAY_CHECK_OK(status); + return; + } + + // TODO(atumanov): make remote client lookup O(1) + if (std::find(remote_clients_.begin(), remote_clients_.end(), client_id) == + remote_clients_.end()) { + RAY_LOG(DEBUG) << "a new client: " << client_id.hex(); + remote_clients_.push_back(client_id); + } else { + // NodeManager connection to this client was already established. + RAY_LOG(DEBUG) << "received a new client connection that already exists: " + << client_id.hex(); + return; + } + + ResourceSet resources_total(client_data.resources_total_label, + client_data.resources_total_capacity); + this->cluster_resource_map_.emplace(client_id, SchedulingResources(resources_total)); + + // Establish a new NodeManager connection to this GCS client. 
+ auto client_info = gcs_client_->client_table().GetClient(client_id); + RAY_LOG(DEBUG) << "[ClientAdded] CONNECTING TO: " + << " " << client_info.node_manager_address << " " + << client_info.node_manager_port; + + boost::asio::ip::tcp::socket socket(io_service_); + RAY_CHECK_OK(TcpConnect(socket, client_info.node_manager_address, + client_info.node_manager_port)); + auto server_conn = TcpServerConnection(std::move(socket)); + remote_server_connections_.emplace(client_id, std::move(server_conn)); +} + +void NodeManager::HeartbeatAdded(gcs::AsyncGcsClient *client, const ClientID &client_id, + const HeartbeatTableDataT &heartbeat_data) { + RAY_LOG(DEBUG) << "[HeartbeatAdded]: received heartbeat from client id " + << client_id.hex(); + if (client_id == gcs_client_->client_table().GetLocalClientId()) { + // Skip heartbeats from self. + return; + } + // Locate the client id in remote client table and update available resources based on + // the received heartbeat information. + if (this->cluster_resource_map_.count(client_id) == 0) { + // Haven't received the client registration for this client yet, skip this heartbeat. + RAY_LOG(INFO) << "[HeartbeatAdded]: received heartbeat from unknown client id " + << client_id.hex(); + return; + } + SchedulingResources &resources = this->cluster_resource_map_[client_id]; + ResourceSet heartbeat_resource_available(heartbeat_data.resources_available_label, + heartbeat_data.resources_available_capacity); + resources.SetAvailableResources( + ResourceSet(heartbeat_data.resources_available_label, + heartbeat_data.resources_available_capacity)); + RAY_CHECK(this->cluster_resource_map_[client_id].GetAvailableResources() == + heartbeat_resource_available); } void NodeManager::ProcessNewClient(std::shared_ptr client) { @@ -40,18 +173,6 @@ void NodeManager::ProcessClientMessage(std::shared_ptr cl // Register the new worker. worker_pool_.RegisterWorker(std::move(worker)); } - - // Build the reply to the worker's registration request. 
TODO(swang): This - // is legacy code and should be removed once actor creation tasks are - // implemented. - flatbuffers::FlatBufferBuilder fbb; - auto reply = - protocol::CreateRegisterClientReply(fbb, fbb.CreateVector(std::vector())); - fbb.Finish(reply); - // Reply to the worker's registration request, then listen for more - // messages. - client->WriteMessage(protocol::MessageType_RegisterClientReply, fbb.GetSize(), - fbb.GetBufferPointer()); } break; case protocol::MessageType_GetTask: { const std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); @@ -74,8 +195,14 @@ void NodeManager::ProcessClientMessage(std::shared_ptr cl // Remove the dead worker from the pool and stop listening for messages. const std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); if (worker) { + if (!worker->GetAssignedTaskId().is_nil()) { + // TODO(swang): Clean up any tasks that were assigned to the worker. + // Release any resources that may be held by this worker. + FinishTask(worker->GetAssignedTaskId()); + } worker_pool_.DisconnectWorker(worker); } + return; } break; case protocol::MessageType_SubmitTask: { // Read the task submitted by the client. @@ -84,14 +211,49 @@ void NodeManager::ProcessClientMessage(std::shared_ptr cl from_flatbuf(*message->execution_dependencies())); TaskSpecification task_spec(*message->task_spec()); Task task(task_execution_spec, task_spec); - // Submit the task to the local scheduler. - SubmitTask(task); - // Listen for more messages. - client->ProcessMessages(); + // Submit the task to the local scheduler. Since the task was submitted + // locally, there is no uncommitted lineage. + SubmitTask(task, Lineage()); + } break; + case protocol::MessageType_ReconstructObject: { + // TODO(hme): handle multiple object ids. 
+ auto message = flatbuffers::GetRoot(message_data); + ObjectID object_id = from_flatbuf(*message->object_id()); + RAY_LOG(DEBUG) << "reconstructing object " << object_id.hex(); + RAY_CHECK_OK(object_manager_.Pull(object_id)); + } break; + + default: + RAY_LOG(FATAL) << "Received unexpected message type " << message_type; + } + + // Listen for more messages. + client->ProcessMessages(); +} + +void NodeManager::ProcessNewNodeManager( + std::shared_ptr node_manager_client) { + node_manager_client->ProcessMessages(); +} + +void NodeManager::ProcessNodeManagerMessage( + std::shared_ptr node_manager_client, int64_t message_type, + const uint8_t *message_data) { + switch (message_type) { + case protocol::MessageType_ForwardTaskRequest: { + auto message = flatbuffers::GetRoot(message_data); + TaskID task_id = from_flatbuf(*message->task_id()); + + Lineage uncommitted_lineage(*message); + const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData(); + RAY_LOG(DEBUG) << "got task " << task.GetTaskSpecification().TaskId() + << " spillback=" << task.GetTaskExecutionSpecReadonly().NumForwards(); + SubmitTask(task, uncommitted_lineage); } break; default: RAY_LOG(FATAL) << "Received unexpected message type " << message_type; } + node_manager_client->ProcessMessages(); } void NodeManager::HandleWaitingTaskReady(const TaskID &task_id) { @@ -102,29 +264,45 @@ void NodeManager::HandleWaitingTaskReady(const TaskID &task_id) { } void NodeManager::ScheduleTasks() { - // Ask policy for scheduling decision. - // TODO(alexey): Give the policy all cluster resources instead of just the - // local one. 
- std::unordered_map cluster_resource_map; - cluster_resource_map[ClientID::nil()] = local_resources_; - const auto &policy_decision = scheduling_policy_.Schedule(cluster_resource_map); + auto policy_decision = scheduling_policy_.Schedule( + cluster_resource_map_, gcs_client_->client_table().GetLocalClientId(), + remote_clients_); + RAY_LOG(DEBUG) << "[NM ScheduleTasks] policy decision:"; + for (const auto &pair : policy_decision) { + TaskID task_id = pair.first; + ClientID client_id = pair.second; + RAY_LOG(DEBUG) << task_id.hex() << " --> " << client_id.hex(); + } + // Extract decision for this local scheduler. - // TODO(alexey): Check for this node's own client ID, not for nil. - std::unordered_set task_ids; - for (auto &task_schedule : policy_decision) { - if (task_schedule.second.is_nil()) { - task_ids.insert(task_schedule.first); + std::unordered_set local_task_ids; + // Iterate over (taskid, clientid) pairs, extract tasks to run on the local client. + for (const auto &task_schedule : policy_decision) { + TaskID task_id = task_schedule.first; + ClientID client_id = task_schedule.second; + if (client_id == gcs_client_->client_table().GetLocalClientId()) { + local_task_ids.insert(task_id); + } else { + auto tasks = local_queues_.RemoveTasks({task_id}); + RAY_CHECK(1 == tasks.size()); + Task &task = tasks.front(); + // TODO(swang): Handle forward task failure. + // TODO(swang): Unsubscribe this task in the task dependency manager. + RAY_CHECK_OK(ForwardTask(task, client_id)); } } // Assign the tasks to workers. - std::vector tasks = local_queues_.RemoveTasks(task_ids); + std::vector tasks = local_queues_.RemoveTasks(local_task_ids); for (auto &task : tasks) { AssignTask(task); } } -void NodeManager::SubmitTask(const Task &task) { +void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineage) { + // Add the task and its uncommitted lineage to the lineage cache. 
+ lineage_cache_.AddWaitingTask(task, uncommitted_lineage); + // Queue the task according to the availability of its arguments. if (task_dependency_manager_.TaskReady(task)) { local_queues_.QueueReadyTasks(std::vector({task})); ScheduleTasks(); @@ -135,41 +313,104 @@ void NodeManager::SubmitTask(const Task &task) { } void NodeManager::AssignTask(const Task &task) { + // Resource accounting: acquire resources for the scheduled task. + const ClientID &my_client_id = gcs_client_->client_table().GetLocalClientId(); + RAY_CHECK(this->cluster_resource_map_[my_client_id].Acquire( + task.GetTaskSpecification().GetRequiredResources())); + if (worker_pool_.PoolSize() == 0) { - // Start a new worker. worker_pool_.StartWorker(); // Queue this task for future assignment. The task will be assigned to a // worker once one becomes available. local_queues_.QueueScheduledTasks(std::vector({task})); - // TODO(swang): Acquire resources here or when a worker becomes available? return; } + const TaskSpecification &spec = task.GetTaskSpecification(); std::shared_ptr worker = worker_pool_.PopWorker(); RAY_LOG(DEBUG) << "Assigning task to worker with pid " << worker->Pid(); - // TODO(swang): Acquire resources for the task. 
- // local_resources_.Acquire(task.GetTaskSpecification().GetRequiredResources()); + worker->AssignTaskId(spec.TaskId()); + local_queues_.QueueRunningTasks(std::vector({task})); flatbuffers::FlatBufferBuilder fbb; - const TaskSpecification &spec = task.GetTaskSpecification(); auto message = protocol::CreateGetTaskReply(fbb, spec.ToFlatbuffer(fbb), fbb.CreateVector(std::vector())); fbb.Finish(message); - worker->Connection()->WriteMessage(protocol::MessageType_ExecuteTask, fbb.GetSize(), - fbb.GetBufferPointer()); - worker->AssignTaskId(spec.TaskId()); - local_queues_.QueueRunningTasks(std::vector({task})); + auto status = worker->Connection()->WriteMessage(protocol::MessageType_ExecuteTask, + fbb.GetSize(), fbb.GetBufferPointer()); + if (status.ok()) { + // We started running the task, so the task is ready to write to GCS. + lineage_cache_.AddReadyTask(task); + } else { + // We failed to send the task to the worker, so disconnect the worker. The + // task will get queued again during cleanup. + ProcessClientMessage(worker->Connection(), protocol::MessageType_DisconnectClient, + NULL); + } } void NodeManager::FinishTask(const TaskID &task_id) { RAY_LOG(DEBUG) << "Finished task " << task_id.hex(); - local_queues_.RemoveTasks({task_id}); - // TODO(swang): Release resources that were held for the task. + auto tasks = local_queues_.RemoveTasks({task_id}); + RAY_CHECK(tasks.size() == 1); + auto task = *tasks.begin(); + + // Resource accounting: release task's resources. + RAY_CHECK( + this->cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Release( + task.GetTaskSpecification().GetRequiredResources())); } void NodeManager::ResubmitTask(const TaskID &task_id) { throw std::runtime_error("Method not implemented"); } +ray::Status NodeManager::ForwardTask(Task &task, const ClientID &node_id) { + auto task_id = task.GetTaskSpecification().TaskId(); + + // Get and serialize the task's uncommitted lineage. 
+ auto uncommitted_lineage = lineage_cache_.GetUncommittedLineage(task_id); + Task &lineage_cache_entry_task = + uncommitted_lineage.GetEntryMutable(task_id)->TaskDataMutable(); + // Increment forward count for the forwarded task. + lineage_cache_entry_task.GetTaskExecutionSpec().IncrementNumForwards(); + + flatbuffers::FlatBufferBuilder fbb; + auto request = uncommitted_lineage.ToFlatbuffer(fbb, task_id); + fbb.Finish(request); + + RAY_LOG(DEBUG) << "Forwarding task " << task_id.hex() << " to " << node_id.hex() + << " spillback=" + << lineage_cache_entry_task.GetTaskExecutionSpec().NumForwards(); + + auto client_info = gcs_client_->client_table().GetClient(node_id); + + // Lookup remote server connection for this node_id and use it to send the request. + if (remote_server_connections_.count(node_id) == 0) { + // TODO(atumanov): caller must handle failure to ensure tasks are not lost. + RAY_LOG(INFO) << "No NodeManager connection found for GCS client id " + << node_id.hex(); + return ray::Status::IOError("NodeManager connection not found"); + } + + auto &server_conn = remote_server_connections_.at(node_id); + auto status = server_conn.WriteMessage(protocol::MessageType_ForwardTaskRequest, + fbb.GetSize(), fbb.GetBufferPointer()); + if (status.ok()) { + // If we were able to forward the task, remove the forwarded task from the + // lineage cache since the receiving node is now responsible for writing + // the task to the GCS. + lineage_cache_.RemoveWaitingTask(task_id); + } else { + // TODO(atumanov): caller must handle ForwardTask failure to ensure tasks are not + // lost. 
+ RAY_LOG(FATAL) << "[NodeManager][ForwardTask] failed to forward task " + << task_id.hex() << " to node " << node_id.hex(); + } + return status; +} + +} // namespace raylet + } // namespace ray diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index b3741e4d8..1b356c366 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -2,11 +2,13 @@ #define RAY_RAYLET_NODE_MANAGER_H // clang-format off +#include "ray/raylet/task.h" +#include "ray/object_manager/object_manager.h" #include "ray/common/client_connection.h" +#include "ray/raylet/lineage_cache.h" #include "ray/raylet/scheduling_policy.h" #include "ray/raylet/scheduling_queue.h" #include "ray/raylet/scheduling_resources.h" -#include "ray/object_manager/object_manager.h" #include "ray/raylet/reconstruction_policy.h" #include "ray/raylet/task_dependency_manager.h" #include "ray/raylet/worker_pool.h" @@ -14,16 +16,24 @@ namespace ray { -class NodeManager : public ClientManager { +namespace raylet { + +struct NodeManagerConfig { + ResourceSet resource_config; + int num_initial_workers; + std::vector worker_command; + uint64_t heartbeat_period_ms; +}; + +class NodeManager { public: /// Create a node manager. /// - /// \param socket_name The pathname of the Unix domain socket to listen at - /// for local connections. /// \param resource_config The initial set of node resources. /// \param object_manager A reference to the local object manager. - NodeManager(const std::string &socket_name, const ResourceSet &resource_config, - ObjectManager &object_manager); + NodeManager(boost::asio::io_service &io_service, const NodeManagerConfig &config, + ObjectManager &object_manager, + std::shared_ptr gcs_client); /// Process a new client connection. 
void ProcessNewClient(std::shared_ptr client); @@ -38,26 +48,41 @@ class NodeManager : public ClientManager { void ProcessClientMessage(std::shared_ptr client, int64_t message_type, const uint8_t *message); + void ProcessNewNodeManager(std::shared_ptr node_manager_client); + + void ProcessNodeManagerMessage(std::shared_ptr node_manager_client, + int64_t message_type, const uint8_t *message); + + void ClientAdded(gcs::AsyncGcsClient *client, const UniqueID &id, + const ClientTableDataT &data); + + void HeartbeatAdded(gcs::AsyncGcsClient *client, const ClientID &id, + const HeartbeatTableDataT &data); + private: /// Submit a task to this node. - void SubmitTask(const Task &task); + void SubmitTask(const Task &task, const Lineage &uncommitted_lineage); /// Assign a task. void AssignTask(const Task &task); /// Finish a task. void FinishTask(const TaskID &task_id); /// Schedule tasks. void ScheduleTasks(); - /// Handle a task whose local dependencies were missing and are now - /// available. + /// Handle a task whose local dependencies were missing and are now available. void HandleWaitingTaskReady(const TaskID &task_id); /// Resubmit a task whose return value needs to be reconstructed. void ResubmitTask(const TaskID &task_id); + ray::Status ForwardTask(Task &task, const ClientID &node_id); + /// Send heartbeats to the GCS. + void Heartbeat(); + boost::asio::io_service &io_service_; + boost::asio::deadline_timer heartbeat_timer_; + uint64_t heartbeat_period_ms_; /// The resources local to this node. - SchedulingResources local_resources_; + const SchedulingResources local_resources_; // TODO(atumanov): Add resource information from other nodes. - // std::unordered_map - // cluster_resource_map_; + std::unordered_map cluster_resource_map_; /// A pool of workers. WorkerPool worker_pool_; /// A set of queues to maintain tasks. 
@@ -68,8 +93,18 @@ class NodeManager : public ClientManager { ReconstructionPolicy reconstruction_policy_; /// A manager to make waiting tasks's missing object dependencies available. TaskDependencyManager task_dependency_manager_; + /// The lineage cache for the GCS object and task tables. + LineageCache lineage_cache_; + /// A client connection to the GCS. + std::shared_ptr gcs_client_; + std::vector remote_clients_; + std::unordered_map + remote_server_connections_; + ObjectManager &object_manager_; }; +} // namespace raylet + } // end namespace ray #endif // RAY_RAYLET_NODE_MANAGER_H diff --git a/src/ray/raylet/object_manager_integration_test.cc b/src/ray/raylet/object_manager_integration_test.cc new file mode 100644 index 000000000..013664d1e --- /dev/null +++ b/src/ray/raylet/object_manager_integration_test.cc @@ -0,0 +1,235 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "ray/raylet/raylet.h" + +namespace ray { + +namespace raylet { + +std::string test_executable; +std::string store_executable; + +// TODO(hme): Get this working once the dust settles. +class TestObjectManagerBase : public ::testing::Test { + public: + TestObjectManagerBase() { RAY_LOG(INFO) << "TestObjectManagerBase: started."; } + + std::string StartStore(const std::string &id) { + std::string store_id = "/tmp/store"; + store_id = store_id + id; + std::string plasma_command = store_executable + " -m 1000000000 -s " + store_id + + " 1> /dev/null 2> /dev/null &"; + RAY_LOG(INFO) << plasma_command; + int ec = system(plasma_command.c_str()); + if (ec != 0) { + throw std::runtime_error("failed to start plasma store."); + }; + return store_id; + } + + NodeManagerConfig GetNodeManagerConfig(std::string raylet_socket_name, + std::string store_socket_name) { + // Configuration for the node manager. 
+ ray::raylet::NodeManagerConfig node_manager_config; + std::unordered_map static_resource_conf; + static_resource_conf = {{"CPU", 1}, {"GPU", 1}}; + node_manager_config.resource_config = + ray::raylet::ResourceSet(std::move(static_resource_conf)); + node_manager_config.num_initial_workers = 0; + // Use a default worker that can execute empty tasks with dependencies. + node_manager_config.worker_command.push_back("python"); + node_manager_config.worker_command.push_back( + "../../../src/ray/python/default_worker.py"); + node_manager_config.worker_command.push_back(raylet_socket_name.c_str()); + node_manager_config.worker_command.push_back(store_socket_name.c_str()); + return node_manager_config; + }; + + void SetUp() { + object_manager_service_1.reset(new boost::asio::io_service()); + object_manager_service_2.reset(new boost::asio::io_service()); + + // start store + std::string store_sock_1 = StartStore("1"); + std::string store_sock_2 = StartStore("2"); + + // start first server + gcs_client_1 = std::shared_ptr(new gcs::AsyncGcsClient()); + ObjectManagerConfig om_config_1; + om_config_1.store_socket_name = store_sock_1; + server1.reset(new ray::raylet::Raylet( + main_service, std::move(object_manager_service_1), "raylet_1", "127.0.0.1", 6379, + GetNodeManagerConfig("raylet_1", store_sock_1), om_config_1, gcs_client_1)); + + // start second server + gcs_client_2 = std::shared_ptr(new gcs::AsyncGcsClient()); + ObjectManagerConfig om_config_2; + om_config_2.store_socket_name = store_sock_2; + server2.reset(new ray::raylet::Raylet( + main_service, std::move(object_manager_service_2), "raylet_2", "127.0.0.1", 6379, + GetNodeManagerConfig("raylet_2", store_sock_2), om_config_2, gcs_client_2)); + + // connect to stores. 
+ ARROW_CHECK_OK(client1.Connect(store_sock_1, "", PLASMA_DEFAULT_RELEASE_DELAY)); + ARROW_CHECK_OK(client2.Connect(store_sock_2, "", PLASMA_DEFAULT_RELEASE_DELAY)); + } + + void TearDown() { + arrow::Status client1_status = client1.Disconnect(); + arrow::Status client2_status = client2.Disconnect(); + ASSERT_TRUE(client1_status.ok() && client2_status.ok()); + + this->server1.reset(); + this->server2.reset(); + + int s = system("killall plasma_store &"); + ASSERT_TRUE(!s); + + std::string cmd_str = test_executable.substr(0, test_executable.find_last_of("/")); + s = system(("rm " + cmd_str + "/raylet_1").c_str()); + ASSERT_TRUE(!s); + s = system(("rm " + cmd_str + "/raylet_2").c_str()); + ASSERT_TRUE(!s); + } + + ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size) { + ObjectID object_id = ObjectID::from_random(); + RAY_LOG(DEBUG) << "ObjectID Created: " << object_id; + uint8_t metadata[] = {5}; + int64_t metadata_size = sizeof(metadata); + std::shared_ptr data; + ARROW_CHECK_OK(client.Create(object_id.to_plasma_id(), data_size, metadata, + metadata_size, &data)); + ARROW_CHECK_OK(client.Seal(object_id.to_plasma_id())); + return object_id; + } + + protected: + std::thread p; + boost::asio::io_service main_service; + std::unique_ptr object_manager_service_1; + std::unique_ptr object_manager_service_2; + std::shared_ptr gcs_client_1; + std::shared_ptr gcs_client_2; + std::unique_ptr server1; + std::unique_ptr server2; + + plasma::PlasmaClient client1; + plasma::PlasmaClient client2; + std::vector v1; + std::vector v2; +}; + +class TestObjectManagerIntegration : public TestObjectManagerBase { + public: + uint num_expected_objects; + + int num_connected_clients = 0; + + ClientID client_id_1; + ClientID client_id_2; + + void WaitConnections() { + client_id_1 = gcs_client_1->client_table().GetLocalClientId(); + client_id_2 = gcs_client_2->client_table().GetLocalClientId(); + gcs_client_1->client_table().RegisterClientAddedCallback([this]( + 
gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + ClientID parsed_id = ClientID::from_binary(data.client_id); + if (parsed_id == client_id_1 || parsed_id == client_id_2) { + num_connected_clients += 1; + } + if (num_connected_clients == 2) { + StartTests(); + } + }); + } + + void StartTests() { + TestConnections(); + AddTransferTestHandlers(); + TestPush(100); + } + + void AddTransferTestHandlers() { + ray::Status status = ray::Status::OK(); + status = + server1->object_manager_.SubscribeObjAdded([this](const ObjectID &object_id) { + v1.push_back(object_id); + if (v1.size() == num_expected_objects && v1.size() == v2.size()) { + TestPushComplete(); + } + }); + RAY_CHECK_OK(status); + status = + server2->object_manager_.SubscribeObjAdded([this](const ObjectID &object_id) { + v2.push_back(object_id); + if (v2.size() == num_expected_objects && v1.size() == v2.size()) { + TestPushComplete(); + } + }); + RAY_CHECK_OK(status); + } + + void TestPush(int64_t data_size) { + ray::Status status = ray::Status::OK(); + + num_expected_objects = (uint)1; + ObjectID oid1 = WriteDataToClient(client1, data_size); + status = server1->object_manager_.Push(oid1, client_id_2); + } + + void TestPushComplete() { + RAY_LOG(INFO) << "TestPushComplete: " + << " " << v1.size() << " " << v2.size(); + ASSERT_TRUE(v1.size() == v2.size()); + for (int i = -1; ++i < (int)v1.size();) { + ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end()); + } + v1.clear(); + v2.clear(); + main_service.stop(); + } + + void TestConnections() { + RAY_LOG(INFO) << "\n" + << "Server client ids:" + << "\n"; + ClientID client_id_1 = gcs_client_1->client_table().GetLocalClientId(); + ClientID client_id_2 = gcs_client_2->client_table().GetLocalClientId(); + RAY_LOG(INFO) << "Server 1: " << client_id_1; + RAY_LOG(INFO) << "Server 2: " << client_id_2; + + RAY_LOG(INFO) << "\n" + << "All connected clients:" + << "\n"; + const ClientTableDataT &data = 
gcs_client_2->client_table().GetClient(client_id_1); + RAY_LOG(INFO) << (ClientID::from_binary(data.client_id) == ClientID::nil()); + RAY_LOG(INFO) << "ClientID=" << ClientID::from_binary(data.client_id); + RAY_LOG(INFO) << "ClientIp=" << data.node_manager_address; + RAY_LOG(INFO) << "ClientPort=" << data.node_manager_port; + const ClientTableDataT &data2 = gcs_client_1->client_table().GetClient(client_id_2); + RAY_LOG(INFO) << "ClientID=" << ClientID::from_binary(data2.client_id); + RAY_LOG(INFO) << "ClientIp=" << data2.node_manager_address; + RAY_LOG(INFO) << "ClientPort=" << data2.node_manager_port; + } +}; + +TEST_F(TestObjectManagerIntegration, StartTestObjectManagerPush) { + auto AsyncStartTests = main_service.wrap([this]() { WaitConnections(); }); + AsyncStartTests(); + main_service.run(); +} + +} // namespace raylet + +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + ray::raylet::test_executable = std::string(argv[0]); + ray::raylet::store_executable = std::string(argv[1]); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index 35ee7f35f..73df723c1 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -1,56 +1,129 @@ #include "raylet.h" +#include #include +#include #include #include "ray/status.h" namespace ray { -Raylet::Raylet(boost::asio::io_service &io_service, const std::string &socket_name, - const ResourceSet &resource_config, +namespace raylet { + +Raylet::Raylet(boost::asio::io_service &main_service, + std::unique_ptr object_manager_service, + const std::string &socket_name, const std::string &redis_address, + int redis_port, const NodeManagerConfig &node_manager_config, const ObjectManagerConfig &object_manager_config, - std::shared_ptr gcs_client) - : acceptor_(io_service, boost::asio::local::stream_protocol::endpoint(socket_name)), - socket_(io_service), - tcp_acceptor_(io_service, - 
boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), 0)), - tcp_socket_(io_service), - object_manager_(io_service, object_manager_config, gcs_client), - node_manager_(socket_name, resource_config, object_manager_), - gcs_client_(gcs_client) { - ClientID client_id = RegisterGcs(); - object_manager_.SetClientID(client_id); + std::shared_ptr gcs_client) + : acceptor_(main_service, boost::asio::local::stream_protocol::endpoint(socket_name)), + socket_(main_service), + object_manager_acceptor_( + main_service, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), 0)), + object_manager_socket_(main_service), + node_manager_acceptor_( + main_service, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), 0)), + node_manager_socket_(main_service), + gcs_client_(gcs_client), + object_manager_(main_service, std::move(object_manager_service), + object_manager_config, gcs_client), + node_manager_(main_service, node_manager_config, object_manager_, gcs_client_) { // Start listening for clients. 
DoAccept(); - DoAcceptTcp(); + DoAcceptObjectManager(); + DoAcceptNodeManager(); + + RAY_CHECK_OK(RegisterGcs(redis_address, redis_port, main_service, node_manager_config)); + + RAY_CHECK_OK(RegisterPeriodicTimer(main_service)); } -Raylet::~Raylet() { RAY_CHECK_OK(object_manager_.Terminate()); } - -ClientID Raylet::RegisterGcs() { - boost::asio::ip::tcp::endpoint endpoint = tcp_acceptor_.local_endpoint(); - std::string ip = endpoint.address().to_string(); - uint16_t port = endpoint.port(); - ClientID client_id = gcs_client_->Register(ip, port); - return client_id; +Raylet::~Raylet() { + RAY_CHECK_OK(gcs_client_->client_table().Disconnect()); + RAY_CHECK_OK(object_manager_.Terminate()); } -void Raylet::DoAcceptTcp() { - TCPClientConnection::pointer new_connection = - TCPClientConnection::Create(acceptor_.get_io_service()); - tcp_acceptor_.async_accept(new_connection->GetSocket(), - boost::bind(&Raylet::HandleAcceptTcp, this, new_connection, - boost::asio::placeholders::error)); +ray::Status Raylet::RegisterPeriodicTimer(boost::asio::io_service &io_service) { + boost::posix_time::milliseconds timer_period_ms(100); + boost::asio::deadline_timer timer(io_service, timer_period_ms); + return ray::Status::OK(); } -void Raylet::HandleAcceptTcp(TCPClientConnection::pointer new_connection, - const boost::system::error_code &error) { - if (!error) { - // Pass it off to object manager for now. 
- ray::Status status = object_manager_.AcceptConnection(std::move(new_connection)); +ray::Status Raylet::RegisterGcs(const std::string &redis_address, int redis_port, + boost::asio::io_service &io_service, + const NodeManagerConfig &node_manager_config) { + RAY_RETURN_NOT_OK(gcs_client_->Connect(redis_address, redis_port)); + RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service)); + + ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); + client_info.node_manager_address = + node_manager_acceptor_.local_endpoint().address().to_string(); + client_info.object_manager_port = object_manager_acceptor_.local_endpoint().port(); + client_info.node_manager_port = node_manager_acceptor_.local_endpoint().port(); + // Add resource information. + for (const auto &resource_pair : node_manager_config.resource_config.GetResourceMap()) { + client_info.resources_total_label.push_back(resource_pair.first); + client_info.resources_total_capacity.push_back(resource_pair.second); } - DoAcceptTcp(); + + RAY_LOG(DEBUG) << "NM LISTENING ON: IP " << client_info.node_manager_address << " PORT " + << client_info.node_manager_port; + RAY_RETURN_NOT_OK(gcs_client_->client_table().Connect(client_info)); + + auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id, + const ClientTableDataT &data) { + node_manager_.ClientAdded(client, id, data); + }; + gcs_client_->client_table().RegisterClientAddedCallback(node_manager_client_added); + return Status::OK(); +} + +void Raylet::DoAcceptNodeManager() { + node_manager_acceptor_.async_accept(node_manager_socket_, + boost::bind(&Raylet::HandleAcceptNodeManager, this, + boost::asio::placeholders::error)); +} + +void Raylet::HandleAcceptNodeManager(const boost::system::error_code &error) { + if (!error) { + ClientHandler client_handler = + [this](std::shared_ptr client) { + node_manager_.ProcessNewNodeManager(client); + }; + MessageHandler message_handler = [this]( + std::shared_ptr client, int64_t 
message_type, + const uint8_t *message) { + node_manager_.ProcessNodeManagerMessage(client, message_type, message); + }; + // Accept a new local client and dispatch it to the node manager. + auto new_connection = TcpClientConnection::Create(client_handler, message_handler, + std::move(node_manager_socket_)); + } + // We're ready to accept another client. + DoAcceptNodeManager(); +} + +void Raylet::DoAcceptObjectManager() { + object_manager_acceptor_.async_accept( + object_manager_socket_, boost::bind(&Raylet::HandleAcceptObjectManager, this, + boost::asio::placeholders::error)); +} + +void Raylet::HandleAcceptObjectManager(const boost::system::error_code &error) { + ClientHandler client_handler = + [this](std::shared_ptr client) { + object_manager_.ProcessNewClient(client); + }; + MessageHandler message_handler = [this]( + std::shared_ptr client, int64_t message_type, + const uint8_t *message) { + object_manager_.ProcessClientMessage(client, message_type, message); + }; + // Accept a new local client and dispatch it to the node manager. + auto new_connection = TcpClientConnection::Create(client_handler, message_handler, + std::move(object_manager_socket_)); + DoAcceptObjectManager(); } void Raylet::DoAccept() { @@ -60,14 +133,24 @@ void Raylet::DoAccept() { void Raylet::HandleAccept(const boost::system::error_code &error) { if (!error) { + // TODO: typedef these handlers. + ClientHandler client_handler = + [this](std::shared_ptr client) { + node_manager_.ProcessNewClient(client); + }; + MessageHandler message_handler = [this]( + std::shared_ptr client, int64_t message_type, + const uint8_t *message) { + node_manager_.ProcessClientMessage(client, message_type, message); + }; // Accept a new local client and dispatch it to the node manager. 
- auto new_connection = - LocalClientConnection::Create(node_manager_, std::move(socket_)); + auto new_connection = LocalClientConnection::Create(client_handler, message_handler, + std::move(socket_)); } // We're ready to accept another client. DoAccept(); } -ObjectManager &Raylet::GetObjectManager() { return object_manager_; } +} // namespace raylet } // namespace ray diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h index 5080e0730..45ef60f4e 100644 --- a/src/ray/raylet/raylet.h +++ b/src/ray/raylet/raylet.h @@ -14,63 +14,75 @@ namespace ray { +namespace raylet { + class Task; class NodeManager; -// TODO(swang): Rename class and source files to Raylet. class Raylet { public: /// Create a node manager server and listen for new clients. /// - /// \param io_service The event loop to run the server on. + /// \param main_service The event loop to run the server on. + /// \param object_manager_service The asio io_service tied to the object manager. /// \param socket_name The Unix domain socket to listen on for local clients. - /// \param resource_config The initial set of resources to start the local + /// \param node_manager_config Configuration to initialize the node manager. /// scheduler with. /// \param object_manager_config Configuration to initialize the object /// manager. /// \param gcs_client A client connection to the GCS. - Raylet(boost::asio::io_service &io_service, const std::string &socket_name, - const ResourceSet &resource_config, + Raylet(boost::asio::io_service &main_service, + std::unique_ptr object_manager_service, + const std::string &socket_name, const std::string &redis_address, int redis_port, + const NodeManagerConfig &node_manager_config, const ObjectManagerConfig &object_manager_config, - std::shared_ptr gcs_client); + std::shared_ptr gcs_client); /// Destroy the NodeServer. ~Raylet(); - // TODO(melih): Get rid of this method. - ObjectManager &GetObjectManager(); - private: /// Register GCS client. 
- ClientID RegisterGcs(); + ray::Status RegisterGcs(const std::string &redis_address, int redis_port, + boost::asio::io_service &io_service, const NodeManagerConfig &); + + ray::Status RegisterPeriodicTimer(boost::asio::io_service &io_service); /// Accept a client connection. void DoAccept(); /// Handle an accepted client connection. void HandleAccept(const boost::system::error_code &error); /// Accept a tcp client connection. - void DoAcceptTcp(); + void DoAcceptObjectManager(); /// Handle an accepted tcp client connection. - void HandleAcceptTcp(TCPClientConnection::pointer new_connection, - const boost::system::error_code &error); + void HandleAcceptObjectManager(const boost::system::error_code &error); + void DoAcceptNodeManager(); + void HandleAcceptNodeManager(const boost::system::error_code &error); + + friend class TestObjectManagerIntegration; /// An acceptor for new clients. boost::asio::local::stream_protocol::acceptor acceptor_; /// The socket to listen on for new clients. boost::asio::local::stream_protocol::socket socket_; + /// An acceptor for new object manager tcp clients. + boost::asio::ip::tcp::acceptor object_manager_acceptor_; + /// The socket to listen on for new object manager tcp clients. + boost::asio::ip::tcp::socket object_manager_socket_; /// An acceptor for new tcp clients. - boost::asio::ip::tcp::acceptor tcp_acceptor_; + boost::asio::ip::tcp::acceptor node_manager_acceptor_; /// The socket to listen on for new tcp clients. - boost::asio::ip::tcp::socket tcp_socket_; + boost::asio::ip::tcp::socket node_manager_socket_; - // TODO(swang): Lineage cache. + /// A client connection to the GCS. + std::shared_ptr gcs_client_; /// Manages client requests for object transfers and availability. ObjectManager object_manager_; /// Manages client requests for task submission and execution. NodeManager node_manager_; - /// A client connection to the GCS. 
- std::shared_ptr gcs_client_; }; +} // namespace raylet + } // namespace ray #endif // RAY_RAYLET_RAYLET_H diff --git a/src/ray/raylet/raylet_test.cc b/src/ray/raylet/raylet_test.cc deleted file mode 100644 index ee5bc270d..000000000 --- a/src/ray/raylet/raylet_test.cc +++ /dev/null @@ -1,275 +0,0 @@ -#include -#include - -#include "gtest/gtest.h" - -#include "ray/raylet/raylet.h" - -namespace ray { - -std::string test_executable; // NOLINT - -class TestRaylet : public ::testing::Test { - public: - TestRaylet() { RAY_LOG(INFO) << "TestRaylet: started."; } - - std::string StartStore(const std::string &id) { - std::string store_id = "/tmp/store"; - store_id = store_id + id; - std::string test_dir = test_executable.substr(0, test_executable.find_last_of("/")); - std::string plasma_dir = test_dir + "./../plasma"; - std::string plasma_command = plasma_dir + "/plasma_store -m 1000000000 -s " + - store_id + " 1> /dev/null 2> /dev/null &"; - RAY_LOG(INFO) << plasma_command; - int ec = system(plasma_command.c_str()); - if (ec != 0) { - throw std::runtime_error("failed to start plasma store."); - }; - return store_id; - } - - void SetUp() { - // start store - std::string store_sock_1 = StartStore("1"); - std::string store_sock_2 = StartStore("2"); - - // configure - std::unordered_map static_resource_config; - static_resource_config = {{"num_cpus", 1}, {"num_gpus", 1}}; - ray::ResourceSet resource_config(std::move(static_resource_config)); - - // start mock gcs - mock_gcs_client = std::shared_ptr(new GcsClient()); - - // start first server - ray::ObjectManagerConfig om_config_1; - om_config_1.store_socket_name = store_sock_1; - server1.reset(new Raylet(io_service, std::string("hello1"), resource_config, - om_config_1, mock_gcs_client)); - - // start second server - ray::ObjectManagerConfig om_config_2; - om_config_2.store_socket_name = store_sock_2; - server2.reset(new Raylet(io_service, std::string("hello2"), resource_config, - om_config_2, mock_gcs_client)); - - // 
connect to stores. - ARROW_CHECK_OK(client1.Connect(store_sock_1, "", PLASMA_DEFAULT_RELEASE_DELAY)); - ARROW_CHECK_OK(client2.Connect(store_sock_2, "", PLASMA_DEFAULT_RELEASE_DELAY)); - this->StartLoop(); - } - - void TearDown() { - this->StopLoop(); - arrow::Status client1_status = client1.Disconnect(); - arrow::Status client2_status = client2.Disconnect(); - ASSERT_TRUE(client1_status.ok() && client2_status.ok()); - - this->server1.reset(); - this->server2.reset(); - - int s = system("killall plasma_store &"); - ASSERT_TRUE(!s); - - std::string cmd_str = test_executable.substr(0, test_executable.find_last_of("/")); - s = system(("rm " + cmd_str + "/hello1").c_str()); - ASSERT_TRUE(!s); - s = system(("rm " + cmd_str + "/hello2").c_str()); - ASSERT_TRUE(!s); - } - - void Loop() { io_service.run(); }; - - void StartLoop() { p = std::thread(&TestRaylet::Loop, this); }; - - void StopLoop() { - io_service.stop(); - p.join(); - } - - ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size) { - ObjectID object_id = ObjectID::from_random(); - RAY_LOG(DEBUG) << "ObjectID Created: " << object_id.hex().c_str(); - uint8_t metadata[] = {5}; - int64_t metadata_size = sizeof(metadata); - std::shared_ptr data; - ARROW_CHECK_OK(client.Create(object_id.to_plasma_id(), data_size, metadata, - metadata_size, &data)); - ARROW_CHECK_OK(client.Seal(object_id.to_plasma_id())); - return object_id; - } - - void object_added_handler_1(const ObjectID &object_id) { - RAY_LOG(INFO) << "Store 1 added: " << object_id.hex(); - v1.push_back(object_id); - }; - - void object_added_handler_2(const ObjectID &object_id) { - RAY_LOG(INFO) << "Store 2 added: " << object_id.hex(); - v2.push_back(object_id); - }; - - protected: - std::thread p; - boost::asio::io_service io_service; - std::shared_ptr mock_gcs_client; - std::unique_ptr server1; - std::unique_ptr server2; - - plasma::PlasmaClient client1; - plasma::PlasmaClient client2; - std::vector v1; - std::vector v2; -}; - 
-TEST_F(TestRaylet, TestRayletCommands) { - ray::Status status = ray::Status::OK(); - // TODO(atumanov): assert status is OK everywhere it's returned. - RAY_LOG(INFO) << "\n" - << "All connected clients:" - << "\n"; - status = mock_gcs_client->client_table().GetClientIds( - [this](const std::vector &client_ids) { - mock_gcs_client->client_table().GetClientInformationSet( - client_ids, - [this](const std::vector &info_vec) { - for (const auto &info : info_vec) { - RAY_LOG(INFO) << "ClientID=" << info.GetClientId().hex(); - RAY_LOG(INFO) << "ClientIp=" << info.GetIp(); - RAY_LOG(INFO) << "ClientPort=" << info.GetPort(); - } - }, - [](Status status) {}); - }); - - sleep(1); - - RAY_LOG(INFO) << "\n" - << "Server client ids:" - << "\n"; - - status = server1->GetObjectManager().SubscribeObjAdded( - [this](const ObjectID &object_id) { object_added_handler_1(object_id); }); - ASSERT_TRUE(status.ok()); - - status = server2->GetObjectManager().SubscribeObjAdded( - [this](const ObjectID &object_id) { object_added_handler_2(object_id); }); - ASSERT_TRUE(status.ok()); - - ClientID client_id_1 = server1->GetObjectManager().GetClientID(); - ClientID client_id_2 = server2->GetObjectManager().GetClientID(); - RAY_LOG(INFO) << "Server 1: " << client_id_1.hex(); - RAY_LOG(INFO) << "Server 2: " << client_id_2.hex(); - - sleep(1); - - RAY_LOG(INFO) << "\n" - << "Test bidirectional pull" - << "\n"; - for (int i = -1; ++i < 100;) { - ObjectID oid1 = WriteDataToClient(client1, 100); - ObjectID oid2 = WriteDataToClient(client2, 100); - status = server1->GetObjectManager().Pull(oid2); - status = server2->GetObjectManager().Pull(oid1); - } - sleep(1); - RAY_LOG(INFO) << v1.size() << " " << v2.size(); - ASSERT_TRUE(v1.size() == v2.size()); - for (int i = -1; ++i < (int)v1.size();) { - ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end()); - } - v1.clear(); - v2.clear(); - - RAY_LOG(INFO) << "\n" - << "Test pull 1 from 2" - << "\n"; - for (int i = -1; ++i < 3;) { - ObjectID oid2 = 
WriteDataToClient(client2, 100); - status = server1->GetObjectManager().Pull(oid2); - } - sleep(1); - RAY_LOG(INFO) << v1.size() << " " << v2.size(); - ASSERT_TRUE(v1.size() == v2.size()); - for (int i = -1; ++i < (int)v1.size();) { - ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end()); - } - v1.clear(); - v2.clear(); - - RAY_LOG(INFO) << "\n" - << "Test pull 2 from 1" - << "\n"; - for (int i = -1; ++i < 3;) { - ObjectID oid1 = WriteDataToClient(client1, 100); - status = server2->GetObjectManager().Pull(oid1); - } - sleep(1); - RAY_LOG(INFO) << v1.size() << " " << v2.size(); - ASSERT_TRUE(v1.size() == v2.size()); - for (int i = -1; ++i < (int)v1.size();) { - ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end()); - } - v1.clear(); - v2.clear(); - - RAY_LOG(INFO) << "\n" - << "Test push 1 to 2" - << "\n"; - for (int i = -1; ++i < 3;) { - ObjectID oid1 = WriteDataToClient(client1, 100); - status = server1->GetObjectManager().Push(oid1, client_id_2); - } - sleep(1); - RAY_LOG(INFO) << v1.size() << " " << v2.size(); - ASSERT_TRUE(v1.size() == v2.size()); - for (int i = -1; ++i < (int)v1.size();) { - ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end()); - } - v1.clear(); - v2.clear(); - - RAY_LOG(INFO) << "\n" - << "Test push 2 to 1" - << "\n"; - for (int i = -1; ++i < 3;) { - ObjectID oid2 = WriteDataToClient(client2, 100); - status = server2->GetObjectManager().Push(oid2, client_id_1); - } - sleep(1); - RAY_LOG(INFO) << v1.size() << " " << v2.size(); - ASSERT_TRUE(v1.size() == v2.size()); - for (int i = -1; ++i < (int)v1.size();) { - ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end()); - } - v1.clear(); - v2.clear(); - - RAY_LOG(INFO) << "\n" - << "Test bidirectional push" - << "\n"; - for (int i = -1; ++i < 3;) { - ObjectID oid1 = WriteDataToClient(client1, 100); - ObjectID oid2 = WriteDataToClient(client2, 100); - status = server1->GetObjectManager().Push(oid1, client_id_2); - status = 
server2->GetObjectManager().Push(oid2, client_id_1); - } - sleep(1); - RAY_LOG(INFO) << v1.size() << " " << v2.size(); - ASSERT_TRUE(v1.size() == v2.size()); - for (int i = -1; ++i < (int)v1.size();) { - ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end()); - } - v1.clear(); - v2.clear(); - - ASSERT_TRUE(true); -} - -} // namespace ray - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - ray::test_executable = std::string(argv[0]); - return RUN_ALL_TESTS(); -} diff --git a/src/ray/raylet/reconstruction_policy.cc b/src/ray/raylet/reconstruction_policy.cc index ddfa3733e..d22c8c966 100644 --- a/src/ray/raylet/reconstruction_policy.cc +++ b/src/ray/raylet/reconstruction_policy.cc @@ -2,8 +2,12 @@ namespace ray { +namespace raylet { + void ReconstructionPolicy::CheckObjectReconstruction(const ObjectID &object) { throw std::runtime_error("Method not implemented"); } +} // namespace raylet + } // end namespace ray diff --git a/src/ray/raylet/reconstruction_policy.h b/src/ray/raylet/reconstruction_policy.h index 664e44d28..4f8ca7069 100644 --- a/src/ray/raylet/reconstruction_policy.h +++ b/src/ray/raylet/reconstruction_policy.h @@ -7,6 +7,8 @@ namespace ray { +namespace raylet { + // TODO(swang): Use std::function instead of boost. class ReconstructionPolicy { @@ -29,6 +31,8 @@ class ReconstructionPolicy { private: }; +} // namespace raylet + } // namespace ray #endif // RAY_RAYLET_RECONSTRUCTION_POLICY_H diff --git a/src/ray/raylet/remote_dependencies_demo.cc b/src/ray/raylet/remote_dependencies_demo.cc deleted file mode 100644 index d8cd18b96..000000000 --- a/src/ray/raylet/remote_dependencies_demo.cc +++ /dev/null @@ -1,42 +0,0 @@ -#include - -#include "ray/raylet/raylet.h" - -/// A demo that starts two Raylets, with one object store each. The two Raylets -/// share a mock GCS client for communication between the two (e.g., for -/// ObjectManager::Push). 
-int main(int argc, char *argv[]) { - RAY_CHECK(argc == 3); - std::string store1 = "/tmp/store1"; - std::string store2 = "/tmp/store2"; - // start store - std::string plasma_dir = "../../plasma"; - std::string plasma_command1 = plasma_dir + "/plasma_store -m 1000000000 -s "; - std::string plasma_command2 = " 1> /dev/null 2> /dev/null &"; - RAY_LOG(INFO) << plasma_command1 << store1 << plasma_command2; - RAY_LOG(INFO) << plasma_command1 << store2 << plasma_command2; - int s; - s = system((plasma_command1 + store1 + plasma_command2).c_str()); - RAY_CHECK(s == 0); - s = system((plasma_command1 + store2 + plasma_command2).c_str()); - - // configure - std::unordered_map static_resource_conf; - static_resource_conf = {{"CPU", 1}, {"GPU", 1}}; - ray::ResourceSet resource_config(std::move(static_resource_conf)); - ray::ObjectManagerConfig om_config; - - // initialize mock gcs & object directory - std::shared_ptr mock_gcs_client = - std::shared_ptr(new ray::GcsClient()); - - // Initialize the node manager. 
- boost::asio::io_service io_service; - om_config.store_socket_name = store1; - ray::Raylet server1(io_service, std::string(argv[1]), resource_config, om_config, - mock_gcs_client); - om_config.store_socket_name = store2; - ray::Raylet server2(io_service, std::string(argv[2]), resource_config, om_config, - mock_gcs_client); - io_service.run(); -} diff --git a/src/ray/raylet/scheduling_policy.cc b/src/ray/raylet/scheduling_policy.cc index 9119131e0..0ec8f3d08 100644 --- a/src/ray/raylet/scheduling_policy.cc +++ b/src/ray/raylet/scheduling_policy.cc @@ -1,32 +1,75 @@ #include "scheduling_policy.h" +#include "ray/util/logging.h" + namespace ray { +namespace raylet { + SchedulingPolicy::SchedulingPolicy(const SchedulingQueue &scheduling_queue) - : scheduling_queue_(scheduling_queue) {} + : scheduling_queue_(scheduling_queue), gen_(rd_()) {} std::unordered_map SchedulingPolicy::Schedule( const std::unordered_map - &cluster_resources) { - static ClientID local_node_id = ClientID::nil(); + &cluster_resources, + const ClientID &local_client_id, const std::vector &others) { + // The policy decision to be returned. std::unordered_map decision; - // TODO(atumanov): consider all cluster resources. - SchedulingResources resource_supply = cluster_resources.at(local_node_id); - const auto &resource_supply_set = resource_supply.GetAvailableResources(); + // TODO(atumanov): protect DEBUG code blocks with ifdef DEBUG + RAY_LOG(DEBUG) << "[Schedule] cluster resource map: "; + for (const auto &client_resource_pair : cluster_resources) { + // pair = ClientID, SchedulingResources + const ClientID &client_id = client_resource_pair.first; + const SchedulingResources &resources = client_resource_pair.second; + RAY_LOG(DEBUG) << "client_id: " << client_id << " " + << resources.GetAvailableResources().ToString(); + } // Iterate over running tasks, get their resource demand and try to schedule. 
for (const auto &t : scheduling_queue_.GetReadyTasks()) { // Get task's resource demand const auto &resource_demand = t.GetTaskSpecification().GetRequiredResources(); - bool task_feasible = resource_demand.IsSubset(resource_supply_set); - if (task_feasible) { - const TaskID &task_id = t.GetTaskSpecification().TaskId(); - decision[task_id] = local_node_id; + const TaskID &task_id = t.GetTaskSpecification().TaskId(); + RAY_LOG(DEBUG) << "[SchedulingPolicy]: task=" << task_id + << " numforwards=" << t.GetTaskExecutionSpecReadonly().NumForwards() + << " resources=" + << t.GetTaskSpecification().GetRequiredResources().ToString(); + // TODO(atumanov): replace the simple spillback policy with exponential backoff based + // policy. + if (t.GetTaskExecutionSpecReadonly().NumForwards() >= 1) { + decision[task_id] = local_client_id; + continue; } + // Construct a set of viable node candidates and randomly pick between them. + // Get all the client id keys and randomly pick. + std::vector client_keys; + for (const auto &client_resource_pair : cluster_resources) { + // pair = ClientID, SchedulingResources + ClientID node_client_id = client_resource_pair.first; + SchedulingResources node_resources = client_resource_pair.second; + RAY_LOG(DEBUG) << "client_id " << node_client_id << " resources: " + << node_resources.GetAvailableResources().ToString(); + if (resource_demand.IsSubset(node_resources.GetTotalResources())) { + // This node is a feasible candidate. + client_keys.push_back(node_client_id); + } + } + RAY_CHECK(!client_keys.empty()); + + // Choose index at random. + // Initialize a uniform integer distribution over the key space. + // TODO(atumanov): change uniform random to discrete, weighted by resource capacity. 
+ std::uniform_int_distribution distribution(0, client_keys.size() - 1); + int client_key_index = distribution(gen_); + decision[task_id] = client_keys[client_key_index]; + RAY_LOG(DEBUG) << "[SchedulingPolicy] idx=" << client_key_index << " " << task_id + << " --> " << client_keys[client_key_index]; } return decision; } SchedulingPolicy::~SchedulingPolicy() {} +} // namespace raylet + } // namespace ray diff --git a/src/ray/raylet/scheduling_policy.h b/src/ray/raylet/scheduling_policy.h index 47f0ab712..f049cfc22 100644 --- a/src/ray/raylet/scheduling_policy.h +++ b/src/ray/raylet/scheduling_policy.h @@ -1,6 +1,7 @@ #ifndef RAY_RAYLET_SCHEDULING_POLICY_H #define RAY_RAYLET_SCHEDULING_POLICY_H +#include #include #include "ray/raylet/scheduling_queue.h" @@ -8,6 +9,8 @@ namespace ray { +namespace raylet { + /// \class SchedulingPolicy /// \brief Implements a scheduling policy for the node manager. class SchedulingPolicy { @@ -27,7 +30,8 @@ class SchedulingPolicy { /// \return Scheduling decision, mapping tasks to node managers for placement. std::unordered_map Schedule( const std::unordered_map - &cluster_resources); + &cluster_resources, + const ClientID &local_client_id, const std::vector &others); /// \brief SchedulingPolicy destructor. virtual ~SchedulingPolicy(); @@ -35,8 +39,14 @@ class SchedulingPolicy { private: /// An immutable reference to the scheduling task queues. const SchedulingQueue &scheduling_queue_; + /// Internally maintained random number engine device. + std::random_device rd_; + /// Internally maintained random number generator. 
+ std::mt19937_64 gen_; }; +} // namespace raylet + } // namespace ray #endif // RAY_RAYLET_SCHEDULING_POLICY_H diff --git a/src/ray/raylet/scheduling_queue.cc b/src/ray/raylet/scheduling_queue.cc index 39c5ec321..63d8869bd 100644 --- a/src/ray/raylet/scheduling_queue.cc +++ b/src/ray/raylet/scheduling_queue.cc @@ -4,6 +4,8 @@ namespace ray { +namespace raylet { + const std::list &SchedulingQueue::GetWaitingTasks() const { return this->waiting_tasks_; } @@ -88,4 +90,6 @@ bool SchedulingQueue::RegisterActor(ActorID actor_id, return true; } +} // namespace raylet + } // namespace ray diff --git a/src/ray/raylet/scheduling_queue.h b/src/ray/raylet/scheduling_queue.h index 5d980e569..304fd78d8 100644 --- a/src/ray/raylet/scheduling_queue.h +++ b/src/ray/raylet/scheduling_queue.h @@ -11,6 +11,8 @@ namespace ray { +namespace raylet { + /// \class SchedulingQueue /// /// Encapsulates task queues. Each queue represents a scheduling state for a @@ -103,6 +105,9 @@ class SchedulingQueue { /// The registry of known actors. 
std::unordered_map actor_registry_; }; + +} // namespace raylet + } // namespace ray #endif // RAY_RAYLET_SCHEDULING_QUEUE_H diff --git a/src/ray/raylet/scheduling_resources.cc b/src/ray/raylet/scheduling_resources.cc index f3648b5ad..c5bfa6435 100644 --- a/src/ray/raylet/scheduling_resources.cc +++ b/src/ray/raylet/scheduling_resources.cc @@ -2,13 +2,25 @@ #include +#include "ray/util/logging.h" + namespace ray { +namespace raylet { + ResourceSet::ResourceSet() {} ResourceSet::ResourceSet(const std::unordered_map &resource_map) : resource_capacity_(resource_map) {} +ResourceSet::ResourceSet(const std::vector &resource_labels, + const std::vector resource_capacity) { + RAY_CHECK(resource_labels.size() == resource_capacity.size()); + for (uint i = 0; i < resource_labels.size(); i++) { + RAY_CHECK(this->AddResource(resource_labels[i], resource_capacity[i])); + } +} + ResourceSet::~ResourceSet() {} bool ResourceSet::operator==(const ResourceSet &rhs) const { @@ -43,16 +55,42 @@ bool ResourceSet::IsEqual(const ResourceSet &rhs) const { } bool ResourceSet::AddResource(const std::string &resource_name, double capacity) { - throw std::runtime_error("Method not implemented"); + this->resource_capacity_[resource_name] = capacity; + return true; } bool ResourceSet::RemoveResource(const std::string &resource_name) { throw std::runtime_error("Method not implemented"); } bool ResourceSet::SubtractResources(const ResourceSet &other) { - throw std::runtime_error("Method not implemented"); + // Return failure if attempting to perform vector subtraction with unknown labels. + // TODO(atumanov): make the implementation atomic. Currently, if false is returned + // the resource capacity may be partially mutated. To reverse, call AddResources. 
+ for (const auto &resource_pair : other.GetResourceMap()) { + const std::string &resource_label = resource_pair.first; + const double &resource_capacity = resource_pair.second; + if (resource_capacity_.count(resource_label) == 0) { + return false; + } else { + resource_capacity_[resource_label] -= resource_capacity; + } + } + return true; } + bool ResourceSet::AddResources(const ResourceSet &other) { - throw std::runtime_error("Method not implemented"); + // Return failure if attempting to perform vector addition with unknown labels. + // TODO(atumanov): make the implementation atomic. Currently, if false is returned + // the resource capacity may be partially mutated. To reverse, call SubtractResources. + for (const auto &resource_pair : other.GetResourceMap()) { + const std::string &resource_label = resource_pair.first; + const double &resource_capacity = resource_pair.second; + if (resource_capacity_.count(resource_label) == 0) { + return false; + } else { + resource_capacity_[resource_label] += resource_capacity; + } + } + return true; } bool ResourceSet::GetResource(const std::string &resource_name, double *value) const { @@ -67,6 +105,19 @@ bool ResourceSet::GetResource(const std::string &resource_name, double *value) c return true; } +const std::string ResourceSet::ToString() const { + std::string return_string = ""; + for (const auto &resource_pair : this->resource_capacity_) { + return_string += + "{" + resource_pair.first + "," + std::to_string(resource_pair.second) + "}, "; + } + return return_string; +} + +const std::unordered_map &ResourceSet::GetResourceMap() const { + return this->resource_capacity_; +}; + /// SchedulingResources class implementation SchedulingResources::SchedulingResources() @@ -93,6 +144,14 @@ const ResourceSet &SchedulingResources::GetAvailableResources() const { return this->resources_available_; } +void SchedulingResources::SetAvailableResources(ResourceSet &&newset) { + this->resources_available_ = newset; +} + +const 
ResourceSet &SchedulingResources::GetTotalResources() const { + return this->resources_total_; +} + // Return specified resources back to SchedulingResources. bool SchedulingResources::Release(const ResourceSet &resources) { return this->resources_available_.AddResources(resources); @@ -103,4 +162,6 @@ bool SchedulingResources::Acquire(const ResourceSet &resources) { return this->resources_available_.SubtractResources(resources); } +} // namespace raylet + } // namespace ray diff --git a/src/ray/raylet/scheduling_resources.h b/src/ray/raylet/scheduling_resources.h index 7918b2869..febcbb392 100644 --- a/src/ray/raylet/scheduling_resources.h +++ b/src/ray/raylet/scheduling_resources.h @@ -4,9 +4,12 @@ #include #include #include +#include namespace ray { +namespace raylet { + /// Resource availability status reports whether the resource requirement is /// (1) infeasible, (2) feasible but currently unavailable, or (3) available. typedef enum { @@ -26,6 +29,11 @@ class ResourceSet { /// \brief Constructs ResourceSet from the specified resource map. ResourceSet(const std::unordered_map &resource_map); + /// \brief Constructs ResourceSet from two equal-length vectors with label and capacity + /// specification. + ResourceSet(const std::vector &resource_labels, + const std::vector resource_capacity); + /// \brief Empty ResourceSet destructor. ~ResourceSet(); @@ -89,6 +97,11 @@ class ResourceSet { /// False otherwise. bool GetResource(const std::string &resource_name, double *value) const; + // TODO(atumanov): implement const_iterator class for the ResourceSet container. + const std::unordered_map &GetResourceMap() const; + + const std::string ToString() const; + private: /// Resource capacity map. std::unordered_map resource_capacity_; @@ -125,6 +138,14 @@ class SchedulingResources { /// \return Immutable set of resources with currently available capacity. 
const ResourceSet &GetAvailableResources() const; + /// \brief Overwrite available resource capacity with the specified resource set. + /// + /// \param newset: The set of resources that replaces available resource capacity. + /// \return None. + void SetAvailableResources(ResourceSet &&newset); + + const ResourceSet &GetTotalResources() const; + /// \brief Release the amount of resources specified. /// /// \param resources: the amount of resources to be released. @@ -145,6 +166,8 @@ class SchedulingResources { /// gpu_map - replace with ResourceMap (for generality). }; +} // namespace raylet + } // namespace ray #endif // RAY_RAYLET_SCHEDULING_RESOURCES_H diff --git a/src/ray/raylet/task.cc b/src/ray/raylet/task.cc index 45607e424..0f209d33b 100644 --- a/src/ray/raylet/task.cc +++ b/src/ray/raylet/task.cc @@ -2,7 +2,18 @@ namespace ray { -const TaskExecutionSpecification &Task::GetTaskExecutionSpec() const { +namespace raylet { + +flatbuffers::Offset Task::ToFlatbuffer( + flatbuffers::FlatBufferBuilder &fbb) const { + auto task = CreateTask(fbb, task_spec_.ToFlatbuffer(fbb), + task_execution_spec_.ToFlatbuffer(fbb)); + return task; +} + +TaskExecutionSpecification &Task::GetTaskExecutionSpec() { return task_execution_spec_; } + +const TaskExecutionSpecification &Task::GetTaskExecutionSpecReadonly() const { return task_execution_spec_; } @@ -47,4 +58,6 @@ bool Task::DependsOn(const ObjectID &object_id) const { return false; } +} // namespace raylet + } // namespace ray diff --git a/src/ray/raylet/task.h b/src/ray/raylet/task.h index 547584464..63bdb35f3 100644 --- a/src/ray/raylet/task.h +++ b/src/ray/raylet/task.h @@ -3,11 +3,14 @@ #include +#include "ray/raylet/format/node_manager_generated.h" #include "ray/raylet/task_execution_spec.h" #include "ray/raylet/task_spec.h" namespace ray { +namespace raylet { + /// \class Task /// /// A Task represents a Ray task and a specification of its execution (e.g., @@ -27,13 +30,29 @@ class Task { const TaskSpecification 
&task_spec) : task_execution_spec_(execution_spec), task_spec_(task_spec) {} + /// Create a task from a serialized flatbuffer. + /// + /// \param task_flatbuffer The serialized task. + Task(const protocol::Task &task_flatbuffer) + : task_execution_spec_(*task_flatbuffer.task_execution_spec()), + task_spec_(*task_flatbuffer.task_specification()) {} + /// Destroy the task. virtual ~Task() {} + /// Serialize a task to a flatbuffer. + /// + /// \param fbb The flatbuffer builder. + /// \return An offset to the serialized task. + flatbuffers::Offset ToFlatbuffer( + flatbuffers::FlatBufferBuilder &fbb) const; + /// Get the execution specification for the task. /// /// \return The mutable specification for the task. - const TaskExecutionSpecification &GetTaskExecutionSpec() const; + TaskExecutionSpecification &GetTaskExecutionSpec(); + + const TaskExecutionSpecification &GetTaskExecutionSpecReadonly() const; /// Get the immutable specification for the task. /// @@ -64,6 +83,8 @@ class Task { TaskSpecification task_spec_; }; +} // namespace raylet + } // namespace ray -#endif // RAY_RAYLET_TASK_H \ No newline at end of file +#endif // RAY_RAYLET_TASK_H diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index dad1f5c8f..4ade555f3 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -2,6 +2,8 @@ namespace ray { +namespace raylet { + TaskDependencyManager::TaskDependencyManager( ObjectManager &object_manager, // ReconstructionPolicy &reconstruction_policy, @@ -59,6 +61,8 @@ bool TaskDependencyManager::TaskReady(const Task &task) const { } void TaskDependencyManager::SubscribeTaskReady(const Task &task) { + // TODO(swang): Don't pull arguments that are going to be created by a queued + // or running task. TaskID task_id = task.GetTaskSpecification().TaskId(); const std::vector arguments = task.GetDependencies(); // Add the task's arguments to the table of subscribed tasks. 
@@ -104,4 +108,6 @@ void TaskDependencyManager::MarkDependencyReady(const ObjectID &object) { throw std::runtime_error("Method not implemented"); } +} // namespace raylet + } // namespace ray diff --git a/src/ray/raylet/task_dependency_manager.h b/src/ray/raylet/task_dependency_manager.h index d1120ecfc..cfee6e8c7 100644 --- a/src/ray/raylet/task_dependency_manager.h +++ b/src/ray/raylet/task_dependency_manager.h @@ -10,6 +10,8 @@ namespace ray { +namespace raylet { + class ReconstructionPolicy; /// \class TaskDependencyManager @@ -77,6 +79,8 @@ class TaskDependencyManager { std::function task_ready_callback_; }; +} // namespace raylet + } // namespace ray #endif // RAY_RAYLET_TASK_DEPENDENCY_MANAGER_H diff --git a/src/ray/raylet/task_execution_spec.cc b/src/ray/raylet/task_execution_spec.cc index 8cbd0e067..91473f557 100644 --- a/src/ray/raylet/task_execution_spec.cc +++ b/src/ray/raylet/task_execution_spec.cc @@ -2,35 +2,57 @@ namespace ray { -TaskExecutionSpecification::TaskExecutionSpecification( - const std::vector &&execution_dependencies) - : execution_dependencies_(std::move(execution_dependencies)), - last_timestamp_(0), - spillback_count_(0) {} +namespace raylet { TaskExecutionSpecification::TaskExecutionSpecification( - const std::vector &&execution_dependencies, int spillback_count) - : execution_dependencies_(std::move(execution_dependencies)), - last_timestamp_(0), - spillback_count_(spillback_count) {} + const std::vector &&dependencies) { + SetExecutionDependencies(dependencies); +} -const std::vector &TaskExecutionSpecification::ExecutionDependencies() const { - return execution_dependencies_; +TaskExecutionSpecification::TaskExecutionSpecification( + const std::vector &&dependencies, int num_forwards) { + // TaskExecutionSpecification(std::move(dependencies)); + SetExecutionDependencies(dependencies); + execution_spec_.num_forwards = num_forwards; +} + +flatbuffers::Offset +TaskExecutionSpecification::ToFlatbuffer(flatbuffers::FlatBufferBuilder 
&fbb) const { + fbb.ForceDefaults(true); + return protocol::TaskExecutionSpecification::Pack(fbb, &execution_spec_); +} + +std::vector TaskExecutionSpecification::ExecutionDependencies() const { + std::vector dependencies; + for (const auto &dependency : execution_spec_.dependencies) { + dependencies.push_back(ObjectID::from_binary(dependency)); + } + return dependencies; } void TaskExecutionSpecification::SetExecutionDependencies( const std::vector &dependencies) { - execution_dependencies_ = dependencies; + for (const auto &dependency : dependencies) { + execution_spec_.dependencies.push_back(dependency.binary()); + } } -int TaskExecutionSpecification::SpillbackCount() const { return spillback_count_; } - -void TaskExecutionSpecification::IncrementSpillbackCount() { ++spillback_count_; } - -int64_t TaskExecutionSpecification::LastTimeStamp() const { return last_timestamp_; } - -void TaskExecutionSpecification::SetLastTimeStamp(int64_t new_timestamp) { - last_timestamp_ = new_timestamp; +int TaskExecutionSpecification::NumForwards() const { + return execution_spec_.num_forwards; } +void TaskExecutionSpecification::IncrementNumForwards() { + execution_spec_.num_forwards += 1; +} + +int64_t TaskExecutionSpecification::LastTimestamp() const { + return execution_spec_.last_timestamp; +} + +void TaskExecutionSpecification::SetLastTimestamp(int64_t new_timestamp) { + execution_spec_.last_timestamp = new_timestamp; +} + +} // namespace raylet + } // namespace ray diff --git a/src/ray/raylet/task_execution_spec.h b/src/ray/raylet/task_execution_spec.h index 82d992d95..6005a40e8 100644 --- a/src/ray/raylet/task_execution_spec.h +++ b/src/ray/raylet/task_execution_spec.h @@ -4,9 +4,12 @@ #include #include "ray/id.h" +#include "ray/raylet/format/node_manager_generated.h" namespace ray { +namespace raylet { + /// \class TaskExecutionSpecification /// /// The task execution specification encapsulates all mutable information about @@ -16,60 +19,70 @@ class 
TaskExecutionSpecification { public: /// Create a task execution specification. /// - /// \param execution_dependencies The task's dependencies, determined at - /// execution time. - TaskExecutionSpecification(const std::vector &&execution_dependencies); + /// \param dependencies The task's dependencies, determined at execution + /// time. + TaskExecutionSpecification(const std::vector &&dependencies); /// Create a task execution specification. /// - /// \param execution_dependencies The task's dependencies, determined at - /// execution time. - /// \param spillback_count The number of times this task was spilled back by - /// local schedulers. - TaskExecutionSpecification(const std::vector &&execution_dependencies, - int spillback_count); + /// \param dependencies The task's dependencies, determined at execution + /// time. + /// \param num_forwards The number of times this task has been forwarded by a + /// node manager. + TaskExecutionSpecification(const std::vector &&dependencies, + int num_forwards); + + /// Create a task execution specification from a serialized flatbuffer. + /// + /// \param spec_flatbuffer The serialized specification. + TaskExecutionSpecification( + const protocol::TaskExecutionSpecification &spec_flatbuffer) { + spec_flatbuffer.UnPackTo(&execution_spec_); + } + + /// Serialize a task execution specification to a flatbuffer. + /// + /// \param fbb The flatbuffer builder. + /// \return An offset to the serialized task execution specification. + flatbuffers::Offset ToFlatbuffer( + flatbuffers::FlatBufferBuilder &fbb) const; /// Get the task's execution dependencies. /// /// \return A vector of object IDs representing this task's execution - /// dependencies. - const std::vector &ExecutionDependencies() const; + /// dependencies. + std::vector ExecutionDependencies() const; /// Set the task's execution dependencies. /// /// \param dependencies The value to set the execution dependencies to. 
void SetExecutionDependencies(const std::vector &dependencies); - /// Get the task's spillback count, which tracks the number of times - /// this task was spilled back from local to the global scheduler. + /// Get the number of times this task has been forwarded. /// - /// \return The spillback count for this task. - int SpillbackCount() const; + /// \return The number of times this task has been forwarded. + int NumForwards() const; - /// Increment the spillback count for this task. - void IncrementSpillbackCount(); + /// Increment the number of times this task has been forwarded. + void IncrementNumForwards(); /// Get the task's last timestamp. /// /// \return The timestamp when this task was last received for scheduling. - int64_t LastTimeStamp() const; + int64_t LastTimestamp() const; /// Set the task's last timestamp to the specified value. /// /// \param new_timestamp The new timestamp in millisecond to set the task's - /// time stamp to. Tracks the last time this task entered a local - /// scheduler. - void SetLastTimeStamp(int64_t new_timestamp); + /// time stamp to. Tracks the last time this task entered a local scheduler. + void SetLastTimestamp(int64_t new_timestamp); private: - /// A list of object IDs representing the dependencies of this task that may - /// change at execution time. - std::vector execution_dependencies_; - /// The last time this task was received for scheduling. - int64_t last_timestamp_; - /// The number of times this task was spilled back by local schedulers. 
- int spillback_count_; + protocol::TaskExecutionSpecificationT execution_spec_; }; +} // namespace raylet + } // namespace ray + #endif // RAY_RAYLET_TASK_EXECUTION_SPECIFICATION_H diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index 285d1c91f..b11ed64c4 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -5,6 +5,8 @@ namespace ray { +namespace raylet { + TaskArgument::~TaskArgument() {} TaskArgumentByReference::TaskArgumentByReference(const std::vector &references) @@ -15,14 +17,6 @@ flatbuffers::Offset TaskArgumentByReference::ToFlatbuffer( return CreateArg(fbb, to_flatbuf(fbb, references_)); } -const BYTE *TaskArgumentByReference::HashData() const { - return reinterpret_cast(references_.data()); -} - -size_t TaskArgumentByReference::HashDataLength() const { - return references_.size() * sizeof(ObjectID); -} - TaskArgumentByValue::TaskArgumentByValue(const uint8_t *value, size_t length) { value_.assign(value, value + length); } @@ -35,27 +29,12 @@ flatbuffers::Offset TaskArgumentByValue::ToFlatbuffer( return CreateArg(fbb, empty_ids, arg); } -const BYTE *TaskArgumentByValue::HashData() const { return value_.data(); } - -size_t TaskArgumentByValue::HashDataLength() const { return value_.size(); } - -static const ObjectID task_compute_return_id(TaskID task_id, int64_t return_index) { - // Here, return_indices need to be >= 0, so we can use negative indices for put. - RAY_DCHECK(return_index >= 0); - // TODO(rkn): This line requires object and task IDs to be the same size. - ObjectID return_id = task_id; - int64_t *first_bytes = (int64_t *)&return_id; - // XOR the first bytes of the object ID with the return index. - // We add one so the first return ID is not the same as the task ID. 
- *first_bytes = *first_bytes ^ (return_index + 1); - return return_id; +void TaskSpecification::AssignSpecification(const uint8_t *spec, size_t spec_size) { + spec_.assign(spec, spec + spec_size); } -TaskSpecification::TaskSpecification(const uint8_t *spec, size_t spec_size) - : spec_(spec, spec + spec_size) {} - -TaskSpecification::TaskSpecification(const flatbuffers::String &string) - : TaskSpecification(reinterpret_cast(string.data()), string.size()) { +TaskSpecification::TaskSpecification(const flatbuffers::String &string) { + AssignSpecification(reinterpret_cast(string.data()), string.size()); } TaskSpecification::TaskSpecification( @@ -63,9 +42,10 @@ TaskSpecification::TaskSpecification( // UniqueID actor_id, // UniqueID actor_handle_id, // int64_t actor_counter, - FunctionID function_id, const std::vector &task_arguments, - int64_t num_returns, - const std::unordered_map &required_resources) { + FunctionID function_id, + const std::vector> &task_arguments, int64_t num_returns, + const std::unordered_map &required_resources) + : spec_() { flatbuffers::FlatBufferBuilder fbb; // Compute hashes. @@ -78,14 +58,6 @@ TaskSpecification::TaskSpecification( // sha256_update(&ctx, (BYTE *) &actor_counter, sizeof(actor_counter)); // sha256_update(&ctx, (BYTE *) &is_actor_checkpoint_method, // sizeof(is_actor_checkpoint_method)); - sha256_update(&ctx, (BYTE *)&function_id, sizeof(function_id)); - - // Serialize and hash the arguments. - std::vector> arguments; - for (auto &argument : task_arguments) { - arguments.push_back(argument.ToFlatbuffer(fbb)); - sha256_update(&ctx, (BYTE *)argument.HashData(), argument.HashDataLength()); - } // Compute the final task ID from the hash. BYTE buff[DIGEST_SIZE]; @@ -93,11 +65,17 @@ TaskSpecification::TaskSpecification( TaskID task_id; RAY_DCHECK(sizeof(task_id) <= DIGEST_SIZE); memcpy(&task_id, buff, sizeof(task_id)); + task_id = FinishTaskId(task_id); + // Add argument object IDs. 
+ std::vector> arguments; + for (auto &argument : task_arguments) { + arguments.push_back(argument->ToFlatbuffer(fbb)); + } // Add return object IDs. std::vector> returns; - for (int64_t i = 0; i < num_returns; i++) { - ObjectID return_id = task_compute_return_id(task_id, i); + for (int64_t i = 1; i < num_returns + 1; i++) { + ObjectID return_id = ComputeReturnId(task_id, i); returns.push_back(to_flatbuf(fbb, return_id)); } @@ -110,7 +88,7 @@ TaskSpecification::TaskSpecification( fbb.CreateVector(arguments), fbb.CreateVector(returns), map_to_flatbuf(fbb, required_resources)); fbb.Finish(spec); - TaskSpecification(fbb.GetBufferPointer(), fbb.GetSize()); + AssignSpecification(fbb.GetBufferPointer(), fbb.GetSize()); } flatbuffers::Offset TaskSpecification::ToFlatbuffer( @@ -130,7 +108,8 @@ TaskID TaskSpecification::TaskId() const { return from_flatbuf(*message->task_id()); } UniqueID TaskSpecification::DriverId() const { - throw std::runtime_error("Method not implemented"); + auto message = flatbuffers::GetRoot(spec_.data()); + return from_flatbuf(*message->driver_id()); } TaskID TaskSpecification::ParentTaskId() const { throw std::runtime_error("Method not implemented"); @@ -148,7 +127,13 @@ int64_t TaskSpecification::NumArgs() const { } int64_t TaskSpecification::NumReturns() const { - throw std::runtime_error("Method not implemented"); + auto message = flatbuffers::GetRoot(spec_.data()); + return message->returns()->size(); +} + +ObjectID TaskSpecification::ReturnId(int64_t return_index) const { + auto message = flatbuffers::GetRoot(spec_.data()); + return from_flatbuf(*message->returns()->Get(return_index)); } bool TaskSpecification::ArgByRef(int64_t arg_index) const { @@ -180,4 +165,6 @@ const ResourceSet TaskSpecification::GetRequiredResources() const { return ResourceSet(required_resources); } +} // namespace raylet + } // namespace ray diff --git a/src/ray/raylet/task_spec.h b/src/ray/raylet/task_spec.h index 2ab555ee1..9bc91595c 100644 --- 
a/src/ray/raylet/task_spec.h +++ b/src/ray/raylet/task_spec.h @@ -6,7 +6,7 @@ #include #include -#include "ray/../common/format/common_generated.h" +#include "format/common_generated.h" #include "ray/id.h" #include "ray/raylet/scheduling_resources.h" @@ -16,6 +16,8 @@ extern "C" { namespace ray { +namespace raylet { + /// \class TaskArgument /// /// A virtual class that represents an argument to a task. @@ -28,16 +30,6 @@ class TaskArgument { virtual flatbuffers::Offset ToFlatbuffer( flatbuffers::FlatBufferBuilder &fbb) const = 0; - /// Get the hashable byte data. - /// - /// \return A pointer to the byte data. - virtual const BYTE *HashData() const = 0; - - /// Get the hashable byte data length. - /// - /// \return The length of the hashable byte data. - virtual size_t HashDataLength() const = 0; - virtual ~TaskArgument() = 0; }; @@ -45,14 +37,15 @@ class TaskArgument { /// /// A task argument consisting of a list of object ID references. class TaskArgumentByReference : virtual public TaskArgument { + public: /// Create a task argument by reference from a list of object IDs. /// /// \param references A list of object ID references. TaskArgumentByReference(const std::vector &references); + ~TaskArgumentByReference(){}; + flatbuffers::Offset ToFlatbuffer(flatbuffers::FlatBufferBuilder &fbb) const; - const BYTE *HashData() const; - size_t HashDataLength() const; private: /// The object IDs. @@ -63,6 +56,7 @@ class TaskArgumentByReference : virtual public TaskArgument { /// /// A task argument containing the raw value. class TaskArgumentByValue : public TaskArgument { + public: /// Create a task argument from a raw value. /// /// \param value A pointer to the raw value. @@ -70,8 +64,6 @@ class TaskArgumentByValue : public TaskArgument { TaskArgumentByValue(const uint8_t *value, size_t length); flatbuffers::Offset ToFlatbuffer(flatbuffers::FlatBufferBuilder &fbb) const; - const BYTE *HashData() const; - size_t HashDataLength() const; private: /// The raw value. 
@@ -106,7 +98,8 @@ class TaskSpecification { // UniqueID actor_id, // UniqueID actor_handle_id, // int64_t actor_counter, - FunctionID function_id, const std::vector &arguments, + FunctionID function_id, + const std::vector> &arguments, int64_t num_returns, const std::unordered_map &required_resources); @@ -130,14 +123,15 @@ class TaskSpecification { bool ArgByRef(int64_t arg_index) const; int ArgIdCount(int64_t arg_index) const; ObjectID ArgId(int64_t arg_index, int64_t id_index) const; + ObjectID ReturnId(int64_t return_index) const; const uint8_t *ArgVal(int64_t arg_index) const; size_t ArgValLength(int64_t arg_index) const; double GetRequiredResource(const std::string &resource_name) const; const ResourceSet GetRequiredResources() const; private: - /// Task specification constructor from a pointer. - TaskSpecification(const uint8_t *spec, size_t spec_size); + /// Assign the specification data from a pointer. + void AssignSpecification(const uint8_t *spec, size_t spec_size); /// Get a pointer to the byte data. const uint8_t *data() const; /// Get the size in bytes of the task specification. @@ -147,6 +141,8 @@ class TaskSpecification { std::vector spec_; }; +} // namespace raylet + } // namespace ray #endif // RAY_RAYLET_TASK_SPECIFICATION_H diff --git a/src/ray/raylet/task_test.cc b/src/ray/raylet/task_test.cc new file mode 100644 index 000000000..9f3545bdf --- /dev/null +++ b/src/ray/raylet/task_test.cc @@ -0,0 +1,45 @@ +#include "gtest/gtest.h" + +#include "ray/raylet/task_spec.h" + +namespace ray { + +namespace raylet { + +void TestTaskReturnId(const TaskID &task_id, int64_t return_index) { + // Round trip test for computing the object ID for a task's return value, + // then computing the task ID that created the object. 
+ ObjectID return_id = ComputeReturnId(task_id, return_index); + ASSERT_EQ(ComputeTaskId(return_id), task_id); + ASSERT_EQ(ComputeObjectIndex(return_id), return_index); +} + +void TestTaskPutId(const TaskID &task_id, int64_t put_index) { + // Round trip test for computing the object ID for a task's put value, then + // computing the task ID that created the object. + ObjectID put_id = ComputePutId(task_id, put_index); + ASSERT_EQ(ComputeTaskId(put_id), task_id); + ASSERT_EQ(ComputeObjectIndex(put_id), -1 * put_index); +} + +TEST(TaskSpecTest, TestTaskReturnIds) { + TaskID task_id = FinishTaskId(TaskID::from_random()); + + // Check that we can compute between a task ID and the object IDs of its + // return values and puts. + TestTaskReturnId(task_id, 1); + TestTaskReturnId(task_id, 2); + TestTaskReturnId(task_id, kMaxTaskReturns); + TestTaskPutId(task_id, 1); + TestTaskPutId(task_id, 2); + TestTaskPutId(task_id, kMaxTaskPuts); +} + +} // namespace raylet + +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index b303b3f0d..ef7eb75c1 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -8,6 +8,8 @@ namespace ray { +namespace raylet { + /// A constructor responsible for initializing the state of a worker. Worker::Worker(pid_t pid, std::shared_ptr connection) : pid_(pid), connection_(connection), assigned_task_id_(TaskID::nil()) {} @@ -22,4 +24,6 @@ const std::shared_ptr Worker::Connection() const { return connection_; } +} // namespace raylet + } // end namespace ray diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index 8e02f1367..8fc8827b2 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -8,6 +8,8 @@ namespace ray { +namespace raylet { + /// Worker class encapsulates the implementation details of a worker. 
A worker /// is the execution container around a unit of Ray work, such as a task or an /// actor. Ray units of work execute in the context of a Worker. @@ -32,6 +34,8 @@ class Worker { TaskID assigned_task_id_; }; +} // namespace raylet + } // namespace ray #endif // RAY_RAYLET_WORKER_H diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 7f6c2bed8..c9551c2c5 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -5,8 +5,15 @@ namespace ray { +namespace raylet { + /// A constructor that initializes a worker pool with num_workers workers. -WorkerPool::WorkerPool(int num_workers) { +WorkerPool::WorkerPool(int num_workers, const std::vector &worker_command) + : worker_command_(worker_command) { + worker_command_.push_back(NULL); + // Ignore SIGCHLD signals. If we don't do this, then worker processes will + // become zombies instead of dying gracefully. + signal(SIGCHLD, SIG_IGN); for (int i = 0; i < num_workers; i++) { StartWorker(); } @@ -18,10 +25,23 @@ WorkerPool::~WorkerPool() { registered_workers_.clear(); } -/// Create a new worker and add it to the pool -bool WorkerPool::StartWorker() { - // TODO(swang): Start the worker. - return true; +void WorkerPool::StartWorker() { + RAY_CHECK(!worker_command_.empty()) << "No worker command provided"; + + // Launch the process to create the worker. + pid_t pid = fork(); + if (pid != 0) { + RAY_LOG(DEBUG) << "Started worker with pid " << pid; + return; + } + + // Reset the SIGCHLD handler for the worker. + signal(SIGCHLD, SIG_DFL); + // Try to execute the worker command. + + int rv = execvp(worker_command_[0], (char *const *)worker_command_.data()); + // The worker failed to start. This is a fatal error. 
+ RAY_LOG(FATAL) << "Failed to start worker with return value " << rv; } uint32_t WorkerPool::PoolSize() const { return pool_.size(); } @@ -75,4 +95,6 @@ bool WorkerPool::DisconnectWorker(std::shared_ptr worker) { return removeWorker(pool_, worker); } +} // namespace raylet + } // namespace ray diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 6a37047ed..7c76eea11 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -9,6 +9,8 @@ namespace ray { +namespace raylet { + class Worker; /// \class WorkerPool @@ -23,7 +25,7 @@ class WorkerPool { /// pool. /// /// \param num_workers The number of workers to start. - WorkerPool(int num_workers); + WorkerPool(int num_workers, const std::vector &worker_command); /// Destructor responsible for freeing a set of workers owned by this class. ~WorkerPool(); @@ -35,10 +37,9 @@ class WorkerPool { /// Asynchronously start a new worker process. Once the worker process has /// registered with an external server, the process should create and - /// register a new Worker, then add itself to the pool. - /// - /// \return Whether the worker process was successfully started. - bool StartWorker(); + /// register a new Worker, then add itself to the pool. Failure to start + /// the worker process is a fatal error. + void StartWorker(); /// Register a new worker. The Worker should be added by the caller to the /// pool after it becomes idle (e.g., requests a work assignment). @@ -73,6 +74,7 @@ class WorkerPool { std::shared_ptr PopWorker(); private: + std::vector worker_command_; /// The pool of idle workers. std::list> pool_; /// All workers that have registered and are still connected, including both @@ -80,6 +82,9 @@ class WorkerPool { // TODO(swang): Make this a map to make GetRegisteredWorker faster. 
std::list> registered_workers_; }; + +} // namespace raylet + } // namespace ray #endif // RAY_RAYLET_WORKER_POOL_H diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 49a5daf12..f2dcb8978 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -6,27 +6,35 @@ namespace ray { -class MockClientManager : public ClientManager { - public: - MOCK_METHOD3(ProcessClientMessage, - void(std::shared_ptr, int64_t, const uint8_t *)); - MOCK_METHOD1(ProcessNewClient, void(std::shared_ptr)); -}; +namespace raylet { class WorkerPoolTest : public ::testing::Test { public: - WorkerPoolTest() : worker_pool_(0), client_manager_(), io_service_() {} + WorkerPoolTest() : worker_pool_(0, {}), io_service_() {} std::shared_ptr CreateWorker(pid_t pid) { + std::function)> client_handler = + [this](std::shared_ptr client) { + HandleNewClient(client); + }; + std::function, int64_t, const uint8_t *)> + message_handler = [this](std::shared_ptr client, + int64_t message_type, const uint8_t *message) { + HandleMessage(client, message_type, message); + }; boost::asio::local::stream_protocol::socket socket(io_service_); - auto client = LocalClientConnection::Create(client_manager_, std::move(socket)); + auto client = + LocalClientConnection::Create(client_handler, message_handler, std::move(socket)); return std::shared_ptr(new Worker(pid, client)); } protected: WorkerPool worker_pool_; - MockClientManager client_manager_; boost::asio::io_service io_service_; + + private: + void HandleNewClient(std::shared_ptr){}; + void HandleMessage(std::shared_ptr, int64_t, const uint8_t *){}; }; TEST_F(WorkerPoolTest, HandleWorkerRegistration) { @@ -66,6 +74,8 @@ TEST_F(WorkerPoolTest, HandleWorkerPushPop) { ASSERT_TRUE(workers.count(popped_worker) > 0); } +} // namespace raylet + } // namespace ray int main(int argc, char **argv) { diff --git a/src/ray/test/run_gcs_tests.sh b/src/ray/test/run_gcs_tests.sh index 624424e5e..145582257 100644 
--- a/src/ray/test/run_gcs_tests.sh +++ b/src/ray/test/run_gcs_tests.sh @@ -4,6 +4,7 @@ # Cause the script to exit if a single command fails. set -e +set -x # Start Redis. ./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 & @@ -12,3 +13,4 @@ sleep 1s ./src/ray/gcs/client_test ./src/common/thirdparty/redis/src/redis-cli -p 6379 shutdown +sleep 1s diff --git a/src/ray/test/run_object_manager_tests.sh b/src/ray/test/run_object_manager_tests.sh new file mode 100644 index 000000000..2c4d66768 --- /dev/null +++ b/src/ray/test/run_object_manager_tests.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash + +# This needs to be run in the build tree, which is normally ray/python/ray/core + +# Cause the script to exit if a single command fails. +set -e +set -x + +# Get the directory in which this script is executing. +SCRIPT_DIR="`dirname \"$0\"`" +RAY_ROOT="$SCRIPT_DIR/../../.." +RAY_ROOT="`( cd \"$RAY_ROOT\" && pwd )`" +if [ -z "$RAY_ROOT" ] ; then + exit 1 +fi +# Ensure we're in the right directory. +if [ ! -d "$RAY_ROOT/python" ]; then + echo "Unable to find root Ray directory. Has this script moved?" + exit 1 +fi + +CORE_DIR="$RAY_ROOT/python/ray/core" +REDIS_DIR="$CORE_DIR/src/common/thirdparty/redis/src" +REDIS_MODULE="$CORE_DIR/src/common/redis_module/libray_redis_module.so" +STORE_EXEC="$CORE_DIR/src/plasma/plasma_store" + +echo "$STORE_EXEC" +echo "$REDIS_DIR/redis-server --loglevel warning --loadmodule $REDIS_MODULE --port 6379" +echo "$REDIS_DIR/redis-cli -p 6379 shutdown" + +# Allow cleanup commands to fail. +killall plasma_store || true +$REDIS_DIR/redis-cli -p 6379 shutdown || true +sleep 1s +$REDIS_DIR/redis-server --loglevel warning --loadmodule $REDIS_MODULE --port 6379 & +sleep 1s + +# Run tests. 
+$CORE_DIR/src/ray/object_manager/object_manager_stress_test $STORE_EXEC +sleep 1s +$CORE_DIR/src/ray/object_manager/object_manager_test $STORE_EXEC +$REDIS_DIR/redis-cli -p 6379 shutdown +sleep 1s + +# Include raylet integration test once it's ready. +# $CORE_DIR/src/ray/raylet/object_manager_integration_test $STORE_EXEC diff --git a/src/ray/test/run_task_test.sh b/src/ray/test/run_task_test.sh new file mode 100644 index 000000000..56a593a8d --- /dev/null +++ b/src/ray/test/run_task_test.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +# This needs to be run in the build tree, which is normally ray/python/ray/core + +# Cause the script to exit if a single command fails. +set -e +set -x + +# Tear down the Raylet. +#bash ../../../src/ray/test/stop_raylets.sh + +# Set up a single Raylet. +bash ../../../src/ray/test/start_raylets.sh + +sleep 1 + +# Connect a driver to the raylet and make sure it completes. +python ../../../src/ray/python/test_driver.py /tmp/raylet1 /tmp/store1 + +sleep 1 + +./src/common/thirdparty/redis/src/redis-cli -p 6379 shutdown +bash ../../../src/ray/test/stop_raylets.sh diff --git a/src/ray/test/start_raylet.sh b/src/ray/test/start_raylet.sh new file mode 100644 index 000000000..afff9ed3c --- /dev/null +++ b/src/ray/test/start_raylet.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash + +# This needs to be run in the build tree, which is normally ray/python/ray/core + +# Cause the script to exit if a single command fails. 
+set -e + +if [[ $1 ]]; then + RAYLET_NUM=$1 +else + RAYLET_NUM=1 +fi + +STORE_SOCKET_NAME="/tmp/store$RAYLET_NUM" +RAYLET_SOCKET_NAME="/tmp/raylet$RAYLET_NUM" + +if [[ `stat $RAYLET_SOCKET_NAME` ]]; then + rm $RAYLET_SOCKET_NAME +fi +if [[ `stat $STORE_SOCKET_NAME` ]]; then + rm $STORE_SOCKET_NAME +fi + +./src/plasma/plasma_store -m 1000000000 -s $STORE_SOCKET_NAME & +./src/ray/raylet/raylet $RAYLET_SOCKET_NAME $STORE_SOCKET_NAME 127.0.0.1 6379 & + +echo +echo "WORKER COMMAND: python ../../../src/ray/python/worker.py $RAYLET_SOCKET_NAME $STORE_SOCKET_NAME" +echo diff --git a/src/ray/test/start_raylets.sh b/src/ray/test/start_raylets.sh new file mode 100644 index 000000000..8ad6dde18 --- /dev/null +++ b/src/ray/test/start_raylets.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash + +# This needs to be run in the build tree, which is normally ray/python/ray/core + +# Cause the script to exit if a single command fails. +set -e + +# Start the GCS. +./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 >/dev/null & +sleep 1s + +if [[ $1 ]]; then + NUM_RAYLETS=$1 +else + NUM_RAYLETS=1 +fi + + +for i in `seq 1 $NUM_RAYLETS`; do + STORE_SOCKET_NAME="/tmp/store$i" + RAYLET_SOCKET_NAME="/tmp/raylet$i" + + if [[ `stat $RAYLET_SOCKET_NAME` ]]; then + rm $RAYLET_SOCKET_NAME + fi + if [[ `stat $STORE_SOCKET_NAME` ]]; then + rm $STORE_SOCKET_NAME + fi + + ./src/plasma/plasma_store -m 1000000000 -s $STORE_SOCKET_NAME & + ./src/ray/raylet/raylet $RAYLET_SOCKET_NAME $STORE_SOCKET_NAME 127.0.0.1 6379 & + + echo + echo "WORKER COMMAND: python ../../../src/ray/python/worker.py $RAYLET_SOCKET_NAME $STORE_SOCKET_NAME" + echo +done diff --git a/src/ray/test/start_redis.sh b/src/ray/test/start_redis.sh new file mode 100644 index 000000000..0b1c26ed9 --- /dev/null +++ b/src/ray/test/start_redis.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash + +# This needs to be run in the build tree, which is normally 
ray/python/ray/core + +# Cause the script to exit if a single command fails. +set -e + +# Start the GCS. +./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 >/dev/null & +sleep 1s + diff --git a/src/ray/test/stop_raylets.sh b/src/ray/test/stop_raylets.sh new file mode 100644 index 000000000..8387430ef --- /dev/null +++ b/src/ray/test/stop_raylets.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +killall raylet +sleep 1 +killall plasma_store +sleep 1 +killall redis-server +sleep 1 +rm /tmp/store* /tmp/raylet*