mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 20:00:22 +08:00
XRay Task Forwarding Milestone (#1785)
Summary: Able to run 1000 tasks with object dependencies on a set of distributed Raylets. Raylet Changes: Finalized ClientConnection class. Task forwarding. NM-to-NM heartbeats. NM resource accounting for tasks. Simple scheduling policy with task forwarding. Creating and maintaining NM 2 NM long-lived connections and reusing them for task forwarding. LineageCache Changes: LineageCache without cleanup of tasks committed by remote nodes. Lineage cache writeback and cleanup implementation. ObjectManager Changes: Object manager event loop/ClientConnection refactor. Multithreaded object manager (disabled in this PR). Testing Changes: Integration tests for task submission on multiple Raylets. Stress tests for object manager (with GCS and object store integration). Co-authored-by: Stephanie Wang <swang@cs.berkeley.edu> Co-authored-by: Alexey Tumanov <atumanov@gmail.com>
This commit is contained in:
committed by
Philipp Moritz
parent
40c9b9cd60
commit
6e06a9e338
@@ -102,6 +102,13 @@ install:
|
||||
|
||||
- cd python/ray/core
|
||||
- bash ../../../src/ray/test/run_gcs_tests.sh
|
||||
# Raylet tests.
|
||||
- bash ../../../src/ray/test/run_object_manager_tests.sh
|
||||
- bash ../../../src/ray/test/run_task_test.sh
|
||||
- ./src/ray/raylet/task_test
|
||||
- ./src/ray/raylet/worker_pool_test
|
||||
- ./src/ray/raylet/lineage_cache_test
|
||||
|
||||
- bash ../../../src/common/test/run_tests.sh
|
||||
- bash ../../../src/plasma/test/run_tests.sh
|
||||
- bash ../../../src/local_scheduler/test/run_tests.sh
|
||||
|
||||
@@ -50,7 +50,14 @@
|
||||
}
|
||||
|
||||
static const char *table_prefixes[] = {
|
||||
NULL, "TASK:", "TASK:", "CLIENT:", "OBJECT:", "FUNCTION:",
|
||||
NULL,
|
||||
"TASK:",
|
||||
"TASK:",
|
||||
"CLIENT:",
|
||||
"OBJECT:",
|
||||
"FUNCTION:",
|
||||
"TASK_RECONSTRUCTION:",
|
||||
"HEARTBEAT:",
|
||||
};
|
||||
|
||||
/// Parse a Redis string into a TablePubsub channel.
|
||||
@@ -811,8 +818,9 @@ int TableRequestNotifications_RedisCommand(RedisModuleCtx *ctx,
|
||||
// notifications.
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
TableEntryToFlatbuf(table_key, id, fbb);
|
||||
RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, reinterpret_cast<const char *>(fbb.GetBufferPointer()),
|
||||
fbb.GetSize());
|
||||
RedisModule_Call(ctx, "PUBLISH", "sb", client_channel,
|
||||
reinterpret_cast<const char *>(fbb.GetBufferPointer()),
|
||||
fbb.GetSize());
|
||||
}
|
||||
RedisModule_CloseKey(table_key);
|
||||
|
||||
|
||||
@@ -141,13 +141,8 @@ GlobalSchedulerState *GlobalSchedulerState_init(event_loop *loop,
|
||||
std::vector<std::string>());
|
||||
db_attach(state->db, loop, false);
|
||||
|
||||
ClientTableDataT client_info;
|
||||
client_info.client_id = get_db_client_id(state->db).binary();
|
||||
client_info.node_manager_address = std::string(node_ip_address);
|
||||
client_info.local_scheduler_port = 0;
|
||||
client_info.object_manager_port = 0;
|
||||
RAY_CHECK_OK(state->gcs_client.Connect(std::string(redis_primary_addr),
|
||||
redis_primary_port, client_info));
|
||||
redis_primary_port));
|
||||
RAY_CHECK_OK(state->gcs_client.context()->AttachToEventLoop(loop));
|
||||
state->policy_state = GlobalSchedulerPolicyState_init();
|
||||
return state;
|
||||
|
||||
@@ -355,13 +355,8 @@ LocalSchedulerState *LocalSchedulerState_init(
|
||||
"local_scheduler", node_ip_address, db_connect_args);
|
||||
db_attach(state->db, loop, false);
|
||||
|
||||
ClientTableDataT client_info;
|
||||
client_info.client_id = get_db_client_id(state->db).binary();
|
||||
client_info.node_manager_address = std::string(node_ip_address);
|
||||
client_info.local_scheduler_port = 0;
|
||||
client_info.object_manager_port = 0;
|
||||
RAY_CHECK_OK(state->gcs_client.Connect(std::string(redis_primary_addr),
|
||||
redis_primary_port, client_info));
|
||||
redis_primary_port));
|
||||
RAY_CHECK_OK(state->gcs_client.context()->AttachToEventLoop(loop));
|
||||
} else {
|
||||
state->db = NULL;
|
||||
|
||||
@@ -487,13 +487,8 @@ PlasmaManagerState *PlasmaManagerState_init(const char *store_socket_name,
|
||||
"plasma_manager", manager_addr, db_connect_args);
|
||||
db_attach(state->db, state->loop, false);
|
||||
|
||||
ClientTableDataT client_info;
|
||||
client_info.client_id = get_db_client_id(state->db).binary();
|
||||
client_info.node_manager_address = std::string(manager_addr);
|
||||
client_info.local_scheduler_port = 0;
|
||||
client_info.object_manager_port = manager_port;
|
||||
RAY_CHECK_OK(state->gcs_client.Connect(std::string(redis_primary_addr),
|
||||
redis_primary_port, client_info));
|
||||
redis_primary_port));
|
||||
RAY_CHECK_OK(state->gcs_client.context()->AttachToEventLoop(state->loop));
|
||||
} else {
|
||||
state->db = NULL;
|
||||
|
||||
@@ -38,8 +38,11 @@ set(RAY_SRCS
|
||||
gcs/asio.cc
|
||||
common/client_connection.cc
|
||||
object_manager/object_manager_client_connection.cc
|
||||
object_manager/object_store_client.cc
|
||||
object_manager/connection_pool.cc
|
||||
object_manager/object_store_client_pool.cc
|
||||
object_manager/object_store_notification_manager.cc
|
||||
object_manager/object_directory.cc
|
||||
object_manager/transfer_queue.cc
|
||||
object_manager/object_manager.cc
|
||||
raylet/mock_gcs_client.cc
|
||||
raylet/task.cc
|
||||
|
||||
@@ -7,20 +7,82 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
ray::Status TcpConnect(boost::asio::ip::tcp::socket &socket,
|
||||
const std::string &ip_address_string, int port) {
|
||||
boost::asio::ip::address ip_address =
|
||||
boost::asio::ip::address::from_string(ip_address_string);
|
||||
boost::asio::ip::tcp::endpoint endpoint(ip_address, port);
|
||||
boost::system::error_code error;
|
||||
socket.connect(endpoint, error);
|
||||
if (error) {
|
||||
return ray::Status::IOError(error.message());
|
||||
} else {
|
||||
return ray::Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
ServerConnection<T>::ServerConnection(boost::asio::basic_stream_socket<T> &&socket)
|
||||
: socket_(std::move(socket)) {}
|
||||
|
||||
template <class T>
|
||||
void ServerConnection<T>::WriteBuffer(
|
||||
const std::vector<boost::asio::const_buffer> &buffer, boost::system::error_code &ec) {
|
||||
boost::asio::write(socket_, buffer, ec);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void ServerConnection<T>::ReadBuffer(
|
||||
const std::vector<boost::asio::mutable_buffer> &buffer,
|
||||
boost::system::error_code &ec) {
|
||||
boost::asio::read(socket_, buffer, ec);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
ray::Status ServerConnection<T>::WriteMessage(int64_t type, int64_t length,
|
||||
const uint8_t *message) {
|
||||
std::vector<boost::asio::const_buffer> message_buffers;
|
||||
auto write_version = RayConfig::instance().ray_protocol_version();
|
||||
message_buffers.push_back(boost::asio::buffer(&write_version, sizeof(write_version)));
|
||||
message_buffers.push_back(boost::asio::buffer(&type, sizeof(type)));
|
||||
message_buffers.push_back(boost::asio::buffer(&length, sizeof(length)));
|
||||
message_buffers.push_back(boost::asio::buffer(message, length));
|
||||
// Write the message and then wait for more messages.
|
||||
// TODO(swang): Does this need to be an async write?
|
||||
boost::system::error_code error;
|
||||
boost::asio::write(socket_, message_buffers, error);
|
||||
if (error) {
|
||||
return ray::Status::IOError(error.message());
|
||||
} else {
|
||||
return ray::Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
std::shared_ptr<ClientConnection<T>> ClientConnection<T>::Create(
|
||||
ClientManager<T> &manager, boost::asio::basic_stream_socket<T> &&socket) {
|
||||
ClientHandler<T> &client_handler, MessageHandler<T> &message_handler,
|
||||
boost::asio::basic_stream_socket<T> &&socket) {
|
||||
std::shared_ptr<ClientConnection<T>> self(
|
||||
new ClientConnection(manager, std::move(socket)));
|
||||
new ClientConnection(message_handler, std::move(socket)));
|
||||
// Let our manager process our new connection.
|
||||
self->manager_.ProcessNewClient(self);
|
||||
client_handler(self);
|
||||
return self;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
ClientConnection<T>::ClientConnection(ClientManager<T> &manager,
|
||||
ClientConnection<T>::ClientConnection(MessageHandler<T> &message_handler,
|
||||
boost::asio::basic_stream_socket<T> &&socket)
|
||||
: socket_(std::move(socket)), manager_(manager) {}
|
||||
: ServerConnection<T>(std::move(socket)), message_handler_(message_handler) {}
|
||||
|
||||
template <class T>
|
||||
const ClientID &ClientConnection<T>::GetClientID() {
|
||||
return client_id_;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void ClientConnection<T>::SetClientID(const ClientID &client_id) {
|
||||
client_id_ = client_id;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void ClientConnection<T>::ProcessMessages() {
|
||||
@@ -31,7 +93,7 @@ void ClientConnection<T>::ProcessMessages() {
|
||||
header.push_back(boost::asio::buffer(&read_type_, sizeof(read_type_)));
|
||||
header.push_back(boost::asio::buffer(&read_length_, sizeof(read_length_)));
|
||||
boost::asio::async_read(
|
||||
socket_, header,
|
||||
ServerConnection<T>::socket_, header,
|
||||
boost::bind(&ClientConnection<T>::ProcessMessageHeader, this->shared_from_this(),
|
||||
boost::asio::placeholders::error));
|
||||
}
|
||||
@@ -52,57 +114,23 @@ void ClientConnection<T>::ProcessMessageHeader(const boost::system::error_code &
|
||||
read_message_.resize(read_length_);
|
||||
// Wait for the message to be read.
|
||||
boost::asio::async_read(
|
||||
socket_, boost::asio::buffer(read_message_),
|
||||
ServerConnection<T>::socket_, boost::asio::buffer(read_message_),
|
||||
boost::bind(&ClientConnection<T>::ProcessMessage, this->shared_from_this(),
|
||||
boost::asio::placeholders::error));
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void ClientConnection<T>::WriteMessage(int64_t type, size_t length,
|
||||
const uint8_t *message) {
|
||||
std::vector<boost::asio::const_buffer> message_buffers;
|
||||
write_version_ = RayConfig::instance().ray_protocol_version();
|
||||
write_type_ = type;
|
||||
write_length_ = length;
|
||||
write_message_.assign(message, message + length);
|
||||
message_buffers.push_back(boost::asio::buffer(&write_version_, sizeof(write_version_)));
|
||||
message_buffers.push_back(boost::asio::buffer(&write_type_, sizeof(write_type_)));
|
||||
message_buffers.push_back(boost::asio::buffer(&write_length_, sizeof(write_length_)));
|
||||
message_buffers.push_back(boost::asio::buffer(write_message_));
|
||||
boost::system::error_code error;
|
||||
// Write the message and then wait for more messages.
|
||||
boost::asio::async_write(
|
||||
socket_, message_buffers,
|
||||
boost::bind(&ClientConnection<T>::ProcessMessages, this->shared_from_this(),
|
||||
boost::asio::placeholders::error));
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void ClientConnection<T>::ProcessMessage(const boost::system::error_code &error) {
|
||||
if (error) {
|
||||
// TODO(hme): Disconnect differently & remove dependency on node_manager_generated.h
|
||||
read_type_ = protocol::MessageType_DisconnectClient;
|
||||
}
|
||||
manager_.ProcessClientMessage(this->shared_from_this(), read_type_,
|
||||
read_message_.data());
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void ClientConnection<T>::ProcessMessages(const boost::system::error_code &error) {
|
||||
if (error) {
|
||||
ProcessMessage(error);
|
||||
} else {
|
||||
ProcessMessages();
|
||||
}
|
||||
message_handler_(this->shared_from_this(), read_type_, read_message_.data());
|
||||
}
|
||||
|
||||
template class ServerConnection<boost::asio::local::stream_protocol>;
|
||||
template class ServerConnection<boost::asio::ip::tcp>;
|
||||
template class ClientConnection<boost::asio::local::stream_protocol>;
|
||||
template class ClientConnection<boost::asio::ip::tcp>;
|
||||
|
||||
template <class T>
|
||||
ClientManager<T>::~ClientManager<T>() {}
|
||||
|
||||
template class ClientManager<boost::asio::local::stream_protocol>;
|
||||
template class ClientManager<boost::asio::ip::tcp>;
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -7,18 +7,74 @@
|
||||
#include <boost/asio/error.hpp>
|
||||
#include <boost/enable_shared_from_this.hpp>
|
||||
|
||||
#include "ray/id.h"
|
||||
#include "ray/status.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
template <class T>
|
||||
class ClientManager;
|
||||
|
||||
/// \class ClientConnection
|
||||
/// Connect a TCP socket.
|
||||
///
|
||||
/// A generic type representing a client connection on a server. This class can
|
||||
/// be used to process and write messages asynchronously from and to the
|
||||
/// client.
|
||||
template <class T>
|
||||
class ClientConnection : public std::enable_shared_from_this<ClientConnection<T>> {
|
||||
/// \param socket The socket to connect.
|
||||
/// \param ip_address The IP address to connect to.
|
||||
/// \param port The port to connect to.
|
||||
/// \return Status.
|
||||
ray::Status TcpConnect(boost::asio::ip::tcp::socket &socket,
|
||||
const std::string &ip_address, int port);
|
||||
|
||||
/// \typename ServerConnection
|
||||
///
|
||||
/// A generic type representing a client connection to a server. This typename
|
||||
/// can be used to write messages synchronously to the server.
|
||||
template <typename T>
|
||||
class ServerConnection {
|
||||
public:
|
||||
/// Create a connection to the server.
|
||||
ServerConnection(boost::asio::basic_stream_socket<T> &&socket);
|
||||
|
||||
/// Write a message to the client.
|
||||
///
|
||||
/// \param type The message type (e.g., a flatbuffer enum).
|
||||
/// \param length The size in bytes of the message.
|
||||
/// \param message A pointer to the message buffer.
|
||||
/// \return Status.
|
||||
ray::Status WriteMessage(int64_t type, int64_t length, const uint8_t *message);
|
||||
|
||||
/// Write a buffer to this connection.
|
||||
///
|
||||
/// \param buffer The buffer.
|
||||
/// \param ec The error code object in which to store error codes.
|
||||
void WriteBuffer(const std::vector<boost::asio::const_buffer> &buffer,
|
||||
boost::system::error_code &ec);
|
||||
|
||||
/// Read a buffer from this connection.
|
||||
///
|
||||
/// \param buffer The buffer.
|
||||
/// \param ec The error code object in which to store error codes.
|
||||
void ReadBuffer(const std::vector<boost::asio::mutable_buffer> &buffer,
|
||||
boost::system::error_code &ec);
|
||||
|
||||
protected:
|
||||
/// The socket connection to the server.
|
||||
boost::asio::basic_stream_socket<T> socket_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class ClientConnection;
|
||||
|
||||
template <typename T>
|
||||
using ClientHandler = std::function<void(std::shared_ptr<ClientConnection<T>>)>;
|
||||
template <typename T>
|
||||
using MessageHandler =
|
||||
std::function<void(std::shared_ptr<ClientConnection<T>>, int64_t, const uint8_t *)>;
|
||||
|
||||
/// \typename ClientConnection
|
||||
///
|
||||
/// A generic type representing a client connection on a server. In addition to
|
||||
/// writing messages to the client, like in ServerConnection, this typename can
|
||||
/// also be used to process messages asynchronously from client.
|
||||
template <typename T>
|
||||
class ClientConnection : public ServerConnection<T>,
|
||||
public std::enable_shared_from_this<ClientConnection<T>> {
|
||||
public:
|
||||
/// Allocate a new node client connection.
|
||||
///
|
||||
@@ -27,24 +83,23 @@ class ClientConnection : public std::enable_shared_from_this<ClientConnection<T>
|
||||
/// \param socket The client socket.
|
||||
/// \return std::shared_ptr<ClientConnection>.
|
||||
static std::shared_ptr<ClientConnection<T>> Create(
|
||||
ClientManager<T> &manager, boost::asio::basic_stream_socket<T> &&socket);
|
||||
ClientHandler<T> &new_client_handler, MessageHandler<T> &message_handler,
|
||||
boost::asio::basic_stream_socket<T> &&socket);
|
||||
|
||||
/// \return The ClientID of the remote client.
|
||||
const ClientID &GetClientID();
|
||||
|
||||
/// \param client_id The ClientID of the remote client.
|
||||
void SetClientID(const ClientID &client_id);
|
||||
|
||||
/// Listen for and process messages from the client connection. Once a
|
||||
/// message has been fully received, the client manager's
|
||||
/// ProcessClientMessage handler will be called.
|
||||
void ProcessMessages();
|
||||
|
||||
/// Write a message to the client and then listen for more messages.
|
||||
///
|
||||
/// \param type The message type (e.g., a flatbuffer enum).
|
||||
/// \param length The size in bytes of the message.
|
||||
/// \param message A pointer to the message buffer. This will be copied into
|
||||
/// the ClientConnection's buffer.
|
||||
void WriteMessage(int64_t type, size_t length, const uint8_t *message);
|
||||
|
||||
private:
|
||||
/// A private constructor for a node client connection.
|
||||
ClientConnection(ClientManager<T> &manager,
|
||||
ClientConnection(MessageHandler<T> &message_handler,
|
||||
boost::asio::basic_stream_socket<T> &&socket);
|
||||
/// Process an error from the last operation, then process the message
|
||||
/// header from the client.
|
||||
@@ -52,54 +107,23 @@ class ClientConnection : public std::enable_shared_from_this<ClientConnection<T>
|
||||
/// Process an error from reading the message header, then process the
|
||||
/// message from the client.
|
||||
void ProcessMessage(const boost::system::error_code &error);
|
||||
/// Process an error from the last operation and then listen for more
|
||||
/// messages.
|
||||
void ProcessMessages(const boost::system::error_code &error);
|
||||
|
||||
/// The client socket.
|
||||
boost::asio::basic_stream_socket<T> socket_;
|
||||
/// A reference to the manager for this client. The manager exposes a handler
|
||||
/// for all messages processed by this client.
|
||||
ClientManager<T> &manager_;
|
||||
/// The ClientID of the remote client.
|
||||
ClientID client_id_;
|
||||
/// The handler for a message from the client.
|
||||
MessageHandler<T> message_handler_;
|
||||
/// Buffers for the current message being read rom the client.
|
||||
int64_t read_version_;
|
||||
int64_t read_type_;
|
||||
uint64_t read_length_;
|
||||
std::vector<uint8_t> read_message_;
|
||||
/// Buffers for the current message being written to the client.
|
||||
int64_t write_version_;
|
||||
int64_t write_type_;
|
||||
uint64_t write_length_;
|
||||
std::vector<uint8_t> write_message_;
|
||||
};
|
||||
|
||||
using LocalServerConnection = ServerConnection<boost::asio::local::stream_protocol>;
|
||||
using TcpServerConnection = ServerConnection<boost::asio::ip::tcp>;
|
||||
using LocalClientConnection = ClientConnection<boost::asio::local::stream_protocol>;
|
||||
using TcpClientConnection = ClientConnection<boost::asio::ip::tcp>;
|
||||
|
||||
/// \class ClientManager
|
||||
///
|
||||
/// A virtual cliant manager. Derived classes should define a method for
|
||||
/// processing a message on the server sent by the client.
|
||||
template <class T>
|
||||
class ClientManager {
|
||||
public:
|
||||
/// Process a new client connection.
|
||||
///
|
||||
/// \param client A shared pointer to the client that connected.
|
||||
virtual void ProcessNewClient(std::shared_ptr<ClientConnection<T>> client) = 0;
|
||||
|
||||
/// Process a message from a client, then listen for more messages if the
|
||||
/// client is still alive.
|
||||
///
|
||||
/// \param client A shared pointer to the client that sent the message.
|
||||
/// \param message_type The message type (e.g., a flatbuffer enum).
|
||||
/// \param message A pointer to the message buffer.
|
||||
virtual void ProcessClientMessage(std::shared_ptr<ClientConnection<T>> client,
|
||||
int64_t message_type, const uint8_t *message) = 0;
|
||||
|
||||
virtual ~ClientManager() = 0;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_COMMON_CLIENT_CONNECTION_H
|
||||
|
||||
@@ -4,6 +4,24 @@
|
||||
/// Length of Ray IDs in bytes.
|
||||
constexpr int64_t kUniqueIDSize = 20;
|
||||
|
||||
/// An ObjectID's bytes are split into the task ID itself and the index of the
|
||||
/// object's creation. This is the maximum width of the object index in bits.
|
||||
constexpr int kObjectIdIndexSize = 32;
|
||||
/// The maximum number of objects that can be returned by a task when finishing
|
||||
/// execution. An ObjectID's bytes are split into the task ID itself and the
|
||||
/// index of the object's creation. A positive index indicates an object
|
||||
/// returned by the task, so the maximum number of objects that a task can
|
||||
/// return is the maximum positive value for an integer with bit-width
|
||||
/// `kObjectIdIndexSize`.
|
||||
constexpr int64_t kMaxTaskReturns = ((int64_t)1 << (kObjectIdIndexSize - 1)) - 1;
|
||||
/// The maximum number of objects that can be put by a task during execution.
|
||||
/// An ObjectID's bytes are split into the task ID itself and the index of the
|
||||
/// object's creation. A negative index indicates an object put by the task
|
||||
/// during execution, so the maximum number of objects that a task can put is
|
||||
/// the maximum negative value for an integer with bit-width
|
||||
/// `kObjectIdIndexSize`.
|
||||
constexpr int64_t kMaxTaskPuts = ((int64_t)1 << (kObjectIdIndexSize - 1));
|
||||
|
||||
/// Prefix for the object table keys in redis.
|
||||
constexpr char kObjectTablePrefix[] = "ObjectTable";
|
||||
/// Prefix for the task table keys in redis.
|
||||
|
||||
+13
-8
@@ -6,19 +6,22 @@ namespace ray {
|
||||
|
||||
namespace gcs {
|
||||
|
||||
AsyncGcsClient::AsyncGcsClient() {}
|
||||
|
||||
AsyncGcsClient::~AsyncGcsClient() {}
|
||||
|
||||
Status AsyncGcsClient::Connect(const std::string &address, int port,
|
||||
const ClientTableDataT &client_info) {
|
||||
AsyncGcsClient::AsyncGcsClient(const ClientID &client_id) {
|
||||
context_.reset(new RedisContext());
|
||||
RAY_RETURN_NOT_OK(context_->Connect(address, port));
|
||||
client_table_.reset(new ClientTable(context_, this, client_id));
|
||||
object_table_.reset(new ObjectTable(context_, this));
|
||||
task_table_.reset(new TaskTable(context_, this));
|
||||
raylet_task_table_.reset(new raylet::TaskTable(context_, this));
|
||||
task_reconstruction_log_.reset(new TaskReconstructionLog(context_, this));
|
||||
client_table_.reset(new ClientTable(context_, this, client_info));
|
||||
heartbeat_table_.reset(new HeartbeatTable(context_, this));
|
||||
}
|
||||
|
||||
AsyncGcsClient::AsyncGcsClient() : AsyncGcsClient(ClientID::from_random()) {}
|
||||
|
||||
AsyncGcsClient::~AsyncGcsClient() {}
|
||||
|
||||
Status AsyncGcsClient::Connect(const std::string &address, int port) {
|
||||
RAY_RETURN_NOT_OK(context_->Connect(address, port));
|
||||
// TODO(swang): Call the client table's Connect() method here. To do this,
|
||||
// we need to make sure that we are attached to an event loop first. This
|
||||
// currently isn't possible because the aeEventLoop, which we use for
|
||||
@@ -55,6 +58,8 @@ FunctionTable &AsyncGcsClient::function_table() { return *function_table_; }
|
||||
|
||||
ClassTable &AsyncGcsClient::class_table() { return *class_table_; }
|
||||
|
||||
HeartbeatTable &AsyncGcsClient::heartbeat_table() { return *heartbeat_table_; }
|
||||
|
||||
} // namespace gcs
|
||||
|
||||
} // namespace ray
|
||||
|
||||
+10
-3
@@ -19,6 +19,13 @@ class RedisContext;
|
||||
|
||||
class RAY_EXPORT AsyncGcsClient {
|
||||
public:
|
||||
/// Start a GCS client with the given client ID. To read from the GCS tables,
|
||||
/// Connect and then Attach must be called. To read and write from the GCS
|
||||
/// tables requires a further call to Connect to the client table.
|
||||
///
|
||||
/// \param client_id The ID to assign to the client.
|
||||
AsyncGcsClient(const ClientID &client_id);
|
||||
/// Start a GCS client with a random client ID.
|
||||
AsyncGcsClient();
|
||||
~AsyncGcsClient();
|
||||
|
||||
@@ -26,10 +33,8 @@ class RAY_EXPORT AsyncGcsClient {
|
||||
///
|
||||
/// \param address The GCS IP address.
|
||||
/// \param port The GCS port.
|
||||
/// \param client_info Information about the local client to connect.
|
||||
/// \return Status.
|
||||
Status Connect(const std::string &address, int port,
|
||||
const ClientTableDataT &client_info);
|
||||
Status Connect(const std::string &address, int port);
|
||||
/// Attach this client to a plasma event loop. Note that only
|
||||
/// one event loop should be attached at a time.
|
||||
Status Attach(plasma::EventLoop &event_loop);
|
||||
@@ -48,6 +53,7 @@ class RAY_EXPORT AsyncGcsClient {
|
||||
raylet::TaskTable &raylet_task_table();
|
||||
TaskReconstructionLog &task_reconstruction_log();
|
||||
ClientTable &client_table();
|
||||
HeartbeatTable &heartbeat_table();
|
||||
inline ErrorTable &error_table();
|
||||
|
||||
// We also need something to export generic code to run on workers from the
|
||||
@@ -67,6 +73,7 @@ class RAY_EXPORT AsyncGcsClient {
|
||||
std::unique_ptr<TaskTable> task_table_;
|
||||
std::unique_ptr<raylet::TaskTable> raylet_task_table_;
|
||||
std::unique_ptr<TaskReconstructionLog> task_reconstruction_log_;
|
||||
std::unique_ptr<HeartbeatTable> heartbeat_table_;
|
||||
std::unique_ptr<ClientTable> client_table_;
|
||||
std::shared_ptr<RedisContext> context_;
|
||||
std::unique_ptr<RedisAsioClient> asio_async_client_;
|
||||
|
||||
@@ -23,12 +23,7 @@ class TestGcs : public ::testing::Test {
|
||||
public:
|
||||
TestGcs() : num_callbacks_(0) {
|
||||
client_ = std::make_shared<gcs::AsyncGcsClient>();
|
||||
ClientTableDataT client_info;
|
||||
client_info.client_id = ClientID::from_random().binary();
|
||||
client_info.node_manager_address = "127.0.0.1";
|
||||
client_info.local_scheduler_port = 0;
|
||||
client_info.object_manager_port = 0;
|
||||
RAY_CHECK_OK(client_->Connect("127.0.0.1", 6379, client_info));
|
||||
RAY_CHECK_OK(client_->Connect("127.0.0.1", 6379));
|
||||
|
||||
job_id_ = JobID::from_random();
|
||||
}
|
||||
@@ -747,6 +742,7 @@ void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client
|
||||
ClientID added_id = client->client_table().GetLocalClientId();
|
||||
ASSERT_EQ(client_id, added_id);
|
||||
ASSERT_EQ(ClientID::from_binary(data.client_id), added_id);
|
||||
ASSERT_EQ(ClientID::from_binary(data.client_id), added_id);
|
||||
ASSERT_EQ(data.is_insertion, is_insertion);
|
||||
|
||||
auto cached_client = client->client_table().GetClient(added_id);
|
||||
@@ -763,9 +759,14 @@ void TestClientTableConnect(const JobID &job_id,
|
||||
ClientTableNotification(client, id, data, true);
|
||||
test->Stop();
|
||||
});
|
||||
|
||||
// Connect and disconnect to client table. We should receive notifications
|
||||
// for the addition and removal of our own entry.
|
||||
RAY_CHECK_OK(client->client_table().Connect());
|
||||
ClientTableDataT local_client_info = client->client_table().GetLocalClient();
|
||||
local_client_info.node_manager_address = "127.0.0.1";
|
||||
local_client_info.node_manager_port = 0;
|
||||
local_client_info.object_manager_port = 0;
|
||||
RAY_CHECK_OK(client->client_table().Connect(local_client_info));
|
||||
test->Start();
|
||||
}
|
||||
|
||||
@@ -789,7 +790,11 @@ void TestClientTableDisconnect(const JobID &job_id,
|
||||
});
|
||||
// Connect and disconnect to client table. We should receive notifications
|
||||
// for the addition and removal of our own entry.
|
||||
RAY_CHECK_OK(client->client_table().Connect());
|
||||
ClientTableDataT local_client_info = client->client_table().GetLocalClient();
|
||||
local_client_info.node_manager_address = "127.0.0.1";
|
||||
local_client_info.node_manager_port = 0;
|
||||
local_client_info.object_manager_port = 0;
|
||||
RAY_CHECK_OK(client->client_table().Connect(local_client_info));
|
||||
RAY_CHECK_OK(client->client_table().Disconnect());
|
||||
test->Start();
|
||||
}
|
||||
|
||||
+23
-16
@@ -11,7 +11,8 @@ enum TablePrefix:int {
|
||||
CLIENT,
|
||||
OBJECT,
|
||||
FUNCTION,
|
||||
TASK_RECONSTRUCTION
|
||||
TASK_RECONSTRUCTION,
|
||||
HEARTBEAT
|
||||
}
|
||||
|
||||
// The channel that Add operations to the Table should be published on, if any.
|
||||
@@ -21,7 +22,8 @@ enum TablePubsub:int {
|
||||
RAYLET_TASK,
|
||||
CLIENT,
|
||||
OBJECT,
|
||||
ACTOR
|
||||
ACTOR,
|
||||
HEARTBEAT
|
||||
}
|
||||
|
||||
table GcsTableEntry {
|
||||
@@ -98,6 +100,13 @@ table CustomSerializerData {
|
||||
table ConfigTableData {
|
||||
}
|
||||
|
||||
table RayResource {
|
||||
// The type of the resource.
|
||||
resource_name: string;
|
||||
// The total capacity of this resource type.
|
||||
resource_capacity: double;
|
||||
}
|
||||
|
||||
table ClientTableData {
|
||||
// The client ID of the client that the message is about.
|
||||
client_id: string;
|
||||
@@ -105,26 +114,24 @@ table ClientTableData {
|
||||
node_manager_address: string;
|
||||
// The port at which the client's node manager is listening for TCP
|
||||
// connections from other node managers.
|
||||
local_scheduler_port: int;
|
||||
node_manager_port: int;
|
||||
// The port at which the client's object manager is listening for TCP
|
||||
// connections from other object managers.
|
||||
object_manager_port: int;
|
||||
// True if the message is about the addition of a client and false if it is
|
||||
// about the deletion of a client.
|
||||
is_insertion: bool;
|
||||
resources_total_label: [string];
|
||||
resources_total_capacity: [double];
|
||||
}
|
||||
|
||||
table Resource {
|
||||
// The type of the resource.
|
||||
resource_name: string;
|
||||
// The total capacity of this resource type.
|
||||
resource_capacity: double;
|
||||
}
|
||||
|
||||
table NodeManagerHeartbeat {
|
||||
// The available resources on this node manager. This information may be
|
||||
// stale.
|
||||
resources_available: [Resource];
|
||||
// The total resources on this node manager.
|
||||
resources_total: [Resource];
|
||||
table HeartbeatTableData {
|
||||
// Node manager client id
|
||||
client_id: string;
|
||||
// Resource capacity currently available on this node manager.
|
||||
resources_available_label: [string];
|
||||
resources_available_capacity: [double];
|
||||
// Total resource capacity configured for this node manager.
|
||||
resources_total_label: [string];
|
||||
resources_total_capacity: [double];
|
||||
}
|
||||
|
||||
@@ -25,7 +25,8 @@ void ProcessCallback(int64_t callback_index, const std::string &data) {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace ray {
|
||||
|
||||
@@ -98,9 +99,10 @@ void SubscribeRedisCallback(void *c, void *r, void *privdata) {
|
||||
}
|
||||
|
||||
int64_t RedisCallbackManager::add(const RedisCallback &function) {
|
||||
num_callbacks += 1;
|
||||
callbacks_.emplace(num_callbacks, std::unique_ptr<RedisCallback>(
|
||||
new RedisCallback(function)));
|
||||
return num_callbacks++;
|
||||
return num_callbacks;
|
||||
}
|
||||
|
||||
RedisCallbackManager::RedisCallback &RedisCallbackManager::get(
|
||||
|
||||
@@ -49,7 +49,8 @@ class RedisCallbackManager {
|
||||
|
||||
class RedisContext {
|
||||
public:
|
||||
RedisContext() {}
|
||||
RedisContext()
|
||||
: context_(nullptr), async_context_(nullptr), subscribe_context_(nullptr) {}
|
||||
~RedisContext();
|
||||
Status Connect(const std::string &address, int port);
|
||||
Status AttachToEventLoop(aeEventLoop *loop);
|
||||
|
||||
@@ -89,8 +89,8 @@ Status Log<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
|
||||
<< "Client called Subscribe twice on the same table";
|
||||
auto d = std::shared_ptr<CallbackData>(
|
||||
new CallbackData({client_id, nullptr, subscribe, done, this, client_}));
|
||||
int64_t callback_index = RedisCallbackManager::instance().add(
|
||||
[this, d](const std::string &data) {
|
||||
int64_t callback_index =
|
||||
RedisCallbackManager::instance().add([this, d](const std::string &data) {
|
||||
if (data.empty()) {
|
||||
// No notification data is provided. This is the callback for the
|
||||
// initial subscription request.
|
||||
@@ -115,7 +115,7 @@ Status Log<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
|
||||
results.emplace_back(std::move(result));
|
||||
}
|
||||
(d->callback)(d->client, id, results);
|
||||
}
|
||||
}
|
||||
}
|
||||
// We do not delete the callback after calling it since there may be
|
||||
// more subscription messages.
|
||||
@@ -270,9 +270,12 @@ const ClientID &ClientTable::GetLocalClientId() { return client_id_; }
|
||||
|
||||
const ClientTableDataT &ClientTable::GetLocalClient() { return local_client_; }
|
||||
|
||||
Status ClientTable::Connect() {
|
||||
Status ClientTable::Connect(const ClientTableDataT &local_client) {
|
||||
RAY_CHECK(!disconnected_) << "Tried to reconnect a disconnected client.";
|
||||
|
||||
RAY_CHECK(local_client.client_id == local_client_.client_id);
|
||||
local_client_ = local_client;
|
||||
|
||||
auto data = std::make_shared<ClientTableDataT>(local_client_);
|
||||
data->is_insertion = true;
|
||||
// Callback for a notification from the client table.
|
||||
@@ -336,6 +339,7 @@ template class Log<TaskID, ray::protocol::Task>;
|
||||
template class Table<TaskID, ray::protocol::Task>;
|
||||
template class Table<TaskID, TaskTableData>;
|
||||
template class Log<TaskID, TaskReconstructionData>;
|
||||
template class Table<ClientID, HeartbeatTableData>;
|
||||
|
||||
} // namespace gcs
|
||||
|
||||
|
||||
+40
-12
@@ -167,13 +167,23 @@ class Log {
|
||||
int64_t subscribe_callback_index_;
|
||||
};
|
||||
|
||||
template <typename ID, typename Data>
|
||||
class TableInterface {
|
||||
public:
|
||||
using DataT = typename Data::NativeTableType;
|
||||
using WriteCallback = typename Log<ID, Data>::WriteCallback;
|
||||
virtual Status Add(const JobID &job_id, const ID &task_id, std::shared_ptr<DataT> data,
|
||||
const WriteCallback &done) = 0;
|
||||
virtual ~TableInterface(){};
|
||||
};
|
||||
|
||||
/// \class Table
|
||||
///
|
||||
/// A GCS table where every entry is a single data item.
|
||||
/// Example tables backed by Log:
|
||||
/// TaskTable: Stores Task metadata needed for executing the task.
|
||||
template <typename ID, typename Data>
|
||||
class Table : private Log<ID, Data> {
|
||||
class Table : private Log<ID, Data>, public TableInterface<ID, Data> {
|
||||
public:
|
||||
using DataT = typename Log<ID, Data>::DataT;
|
||||
using Callback =
|
||||
@@ -242,6 +252,17 @@ class ObjectTable : public Log<ObjectID, ObjectTableData> {
|
||||
pubsub_channel_ = TablePubsub_OBJECT;
|
||||
prefix_ = TablePrefix_OBJECT;
|
||||
};
|
||||
virtual ~ObjectTable(){};
|
||||
};
|
||||
|
||||
class HeartbeatTable : public Table<ClientID, HeartbeatTableData> {
|
||||
public:
|
||||
HeartbeatTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
|
||||
: Table(context, client) {
|
||||
pubsub_channel_ = TablePubsub_HEARTBEAT;
|
||||
prefix_ = TablePrefix_HEARTBEAT;
|
||||
}
|
||||
virtual ~HeartbeatTable() {}
|
||||
};
|
||||
|
||||
class FunctionTable : public Table<ObjectID, FunctionTableData> {
|
||||
@@ -277,7 +298,8 @@ class TaskTable : public Table<TaskID, ray::protocol::Task> {
|
||||
prefix_ = TablePrefix_RAYLET_TASK;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
class TaskTable : public Table<TaskID, TaskTableData> {
|
||||
public:
|
||||
@@ -286,6 +308,7 @@ class TaskTable : public Table<TaskID, TaskTableData> {
|
||||
pubsub_channel_ = TablePubsub_TASK;
|
||||
prefix_ = TablePrefix_TASK;
|
||||
};
|
||||
~TaskTable(){};
|
||||
|
||||
using TestAndUpdateCallback =
|
||||
std::function<void(AsyncGcsClient *client, const TaskID &id,
|
||||
@@ -350,12 +373,6 @@ class TaskTable : public Table<TaskID, TaskTableData> {
|
||||
const Callback &done);
|
||||
};
|
||||
|
||||
using ErrorTable = Table<TaskID, ErrorTableData>;
|
||||
|
||||
using CustomSerializerTable = Table<ClassID, CustomSerializerData>;
|
||||
|
||||
using ConfigTable = Table<ConfigID, ConfigTableData>;
|
||||
|
||||
Status TaskTableAdd(AsyncGcsClient *gcs_client, Task *task);
|
||||
|
||||
Status TaskTableTestAndUpdate(AsyncGcsClient *gcs_client, const TaskID &task_id,
|
||||
@@ -363,6 +380,12 @@ Status TaskTableTestAndUpdate(AsyncGcsClient *gcs_client, const TaskID &task_id,
|
||||
SchedulingState update_state,
|
||||
const TaskTable::TestAndUpdateCallback &callback);
|
||||
|
||||
using ErrorTable = Table<TaskID, ErrorTableData>;
|
||||
|
||||
using CustomSerializerTable = Table<ClassID, CustomSerializerData>;
|
||||
|
||||
using ConfigTable = Table<ConfigID, ConfigTableData>;
|
||||
|
||||
/// \class ClientTable
|
||||
///
|
||||
/// The ClientTable stores information about active and inactive clients. It is
|
||||
@@ -377,17 +400,20 @@ class ClientTable : private Log<UniqueID, ClientTableData> {
|
||||
using ClientTableCallback = std::function<void(
|
||||
AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data)>;
|
||||
ClientTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client,
|
||||
const ClientTableDataT &local_client)
|
||||
const ClientID &client_id)
|
||||
: Log(context, client),
|
||||
// We set the client log's key equal to nil so that all instances of
|
||||
// ClientTable have the same key.
|
||||
client_log_key_(UniqueID::nil()),
|
||||
disconnected_(false),
|
||||
client_id_(ClientID::from_binary(local_client.client_id)),
|
||||
local_client_(local_client) {
|
||||
client_id_(client_id),
|
||||
local_client_() {
|
||||
pubsub_channel_ = TablePubsub_CLIENT;
|
||||
prefix_ = TablePrefix_CLIENT;
|
||||
|
||||
// Set the local client's ID.
|
||||
local_client_.client_id = client_id.binary();
|
||||
|
||||
// Add a nil client to the cache so that we can serve requests for clients
|
||||
// that we have not heard about.
|
||||
ClientTableDataT nil_client;
|
||||
@@ -398,8 +424,10 @@ class ClientTable : private Log<UniqueID, ClientTableData> {
|
||||
/// Connect as a client to the GCS. This registers us in the client table
|
||||
/// and begins subscription to client table notifications.
|
||||
///
|
||||
/// \param Information about the connecting client. This must have the
|
||||
/// same client_id as the one set in the client table.
|
||||
/// \return Status
|
||||
ray::Status Connect();
|
||||
ray::Status Connect(const ClientTableDataT &local_client);
|
||||
|
||||
/// Disconnect the client from the GCS. The client ID assigned during
|
||||
/// registration should never be reused after disconnecting.
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
#include "ray/gcs/tables.h"
|
||||
|
||||
#include "ray/gcs/client.h"
|
||||
#include "ray/id.h"
|
||||
|
||||
#include "common_protocol.h"
|
||||
#include "task.h"
|
||||
|
||||
// TODO(swang): This file extends tables.cc so that we can separate out the
|
||||
// part that depends on the Task* datasturcture from the build. This should be
|
||||
// merged with tables.cc once we get rid of the Task* datastructure.
|
||||
// part that depends on the legacy Task* data structure from the build. This
|
||||
// should be merged with tables.cc once we get rid of the legacy Task*
|
||||
// datastructure.
|
||||
|
||||
namespace {
|
||||
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
|
||||
#include <random>
|
||||
|
||||
#include "ray/constants.h"
|
||||
#include "ray/status.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
UniqueID::UniqueID(const plasma::UniqueID &from) {
|
||||
@@ -83,4 +86,48 @@ std::ostream &operator<<(std::ostream &os, const UniqueID &id) {
|
||||
return os;
|
||||
}
|
||||
|
||||
const ObjectID ComputeObjectId(TaskID task_id, int64_t object_index) {
|
||||
RAY_CHECK(object_index <= kMaxTaskReturns && object_index >= -kMaxTaskPuts);
|
||||
ObjectID return_id = task_id;
|
||||
int64_t *first_bytes = reinterpret_cast<int64_t *>(&return_id);
|
||||
// Zero out the lowest kObjectIdIndexSize bits of the first byte of the
|
||||
// object ID.
|
||||
uint64_t bitmask = static_cast<uint64_t>(-1) << kObjectIdIndexSize;
|
||||
*first_bytes = *first_bytes & (bitmask);
|
||||
// OR the first byte of the object ID with the return index.
|
||||
*first_bytes = *first_bytes | (object_index & ~bitmask);
|
||||
return return_id;
|
||||
}
|
||||
|
||||
const TaskID FinishTaskId(const TaskID &task_id) { return ComputeObjectId(task_id, 0); }
|
||||
|
||||
const ObjectID ComputeReturnId(TaskID task_id, int64_t return_index) {
|
||||
RAY_CHECK(return_index >= 1 && return_index <= kMaxTaskReturns);
|
||||
return ComputeObjectId(task_id, return_index);
|
||||
}
|
||||
|
||||
const ObjectID ComputePutId(TaskID task_id, int64_t put_index) {
|
||||
RAY_CHECK(put_index >= 1 && put_index <= kMaxTaskPuts);
|
||||
return ComputeObjectId(task_id, -1 * put_index);
|
||||
}
|
||||
|
||||
const TaskID ComputeTaskId(const ObjectID &object_id) {
|
||||
TaskID task_id = object_id;
|
||||
int64_t *first_bytes = reinterpret_cast<int64_t *>(&task_id);
|
||||
// Zero out the lowest kObjectIdIndexSize bits of the first byte of the
|
||||
// object ID.
|
||||
uint64_t bitmask = static_cast<uint64_t>(-1) << kObjectIdIndexSize;
|
||||
*first_bytes = *first_bytes & (bitmask);
|
||||
return task_id;
|
||||
}
|
||||
|
||||
int64_t ComputeObjectIndex(const ObjectID &object_id) {
|
||||
const int64_t *first_bytes = reinterpret_cast<const int64_t *>(&object_id);
|
||||
uint64_t bitmask = static_cast<uint64_t>(-1) << kObjectIdIndexSize;
|
||||
int64_t index = *first_bytes & (~bitmask);
|
||||
index <<= (8 * sizeof(int64_t) - kObjectIdIndexSize);
|
||||
index >>= (8 * sizeof(int64_t) - kObjectIdIndexSize);
|
||||
return index;
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
+38
-3
@@ -10,9 +10,6 @@
|
||||
#include "ray/constants.h"
|
||||
#include "ray/util/visibility.h"
|
||||
|
||||
// TODO(swang): Make task ID prefix of any object ID return values and puts so
|
||||
// that we can co-locate task and object entries in the GCS.
|
||||
|
||||
namespace ray {
|
||||
|
||||
class RAY_EXPORT UniqueID {
|
||||
@@ -61,6 +58,44 @@ typedef UniqueID DriverID;
|
||||
typedef UniqueID ConfigID;
|
||||
typedef UniqueID ClientID;
|
||||
|
||||
// TODO(swang): ObjectID and TaskID should derive from UniqueID. Then, we
|
||||
// can make these methods of the derived classes.
|
||||
/// Finish computing a task ID. Since objects created by the task share a
|
||||
/// prefix of the ID, the suffix of the task ID is zeroed out by this function.
|
||||
///
|
||||
/// \param task_id A task ID to finish.
|
||||
/// \return The finished task ID. It may now be used to compute IDs for objects
|
||||
/// created by the task.
|
||||
const TaskID FinishTaskId(const TaskID &task_id);
|
||||
|
||||
/// Compute the object ID of an object returned by the task.
|
||||
///
|
||||
/// \param task_id The task ID of the task that created the object.
|
||||
/// \param put_index What number return value this object is in the task.
|
||||
/// \return The computed object ID.
|
||||
const ObjectID ComputeReturnId(TaskID task_id, int64_t return_index);
|
||||
|
||||
/// Compute the object ID of an object put by the task.
|
||||
///
|
||||
/// \param task_id The task ID of the task that created the object.
|
||||
/// \param put_index What number put this object was created by in the task.
|
||||
/// \return The computed object ID.
|
||||
const ObjectID ComputePutId(TaskID task_id, int64_t put_index);
|
||||
|
||||
/// Compute the task ID of the task that created the object.
|
||||
///
|
||||
/// \param object_id The object ID.
|
||||
/// \return The task ID of the task that created this object.
|
||||
const TaskID ComputeTaskId(const ObjectID &object_id);
|
||||
|
||||
/// Compute the index of this object in the task that created it.
|
||||
///
|
||||
/// \param object_id The object ID.
|
||||
/// \return The index of object creation according to the task that created
|
||||
/// this object. This is positive if the task returned the object and negative
|
||||
/// if created by a put.
|
||||
int64_t ComputeObjectIndex(const ObjectID &object_id);
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_ID_H_
|
||||
|
||||
@@ -19,7 +19,8 @@ add_custom_command(
|
||||
|
||||
add_custom_target(gen_object_manager_fbs DEPENDS ${OBJECT_MANAGER_FBS_OUTPUT_FILES})
|
||||
|
||||
ADD_RAY_TEST(object_manager_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main pthread ${Boost_SYSTEM_LIBRARY})
|
||||
ADD_RAY_TEST(test/object_manager_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main pthread ${Boost_SYSTEM_LIBRARY})
|
||||
ADD_RAY_TEST(test/object_manager_stress_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main pthread ${Boost_SYSTEM_LIBRARY})
|
||||
|
||||
add_library(object_manager object_manager.cc object_manager.h ${OBJECT_MANAGER_FBS_OUTPUT_FILES})
|
||||
target_link_libraries(object_manager common ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} ${Boost_SYSTEM_LIBRARY})
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
#include "ray/object_manager/connection_pool.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
ConnectionPool::ConnectionPool() {}
|
||||
|
||||
void ConnectionPool::RegisterReceiver(ConnectionType type, const ClientID &client_id,
|
||||
std::shared_ptr<TcpClientConnection> &conn) {
|
||||
std::unique_lock<std::mutex> guard(connection_mutex);
|
||||
switch (type) {
|
||||
case ConnectionType::MESSAGE: {
|
||||
Add(message_receive_connections_, client_id, conn);
|
||||
} break;
|
||||
case ConnectionType::TRANSFER: {
|
||||
Add(transfer_receive_connections_, client_id, conn);
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
void ConnectionPool::RemoveReceiver(ConnectionType type, const ClientID &client_id,
|
||||
std::shared_ptr<TcpClientConnection> &conn) {
|
||||
std::unique_lock<std::mutex> guard(connection_mutex);
|
||||
switch (type) {
|
||||
case ConnectionType::MESSAGE: {
|
||||
Remove(message_receive_connections_, client_id, conn);
|
||||
} break;
|
||||
case ConnectionType::TRANSFER: {
|
||||
Remove(transfer_receive_connections_, client_id, conn);
|
||||
} break;
|
||||
}
|
||||
// TODO(hme): appropriately dispose of client connection.
|
||||
}
|
||||
|
||||
void ConnectionPool::RegisterSender(ConnectionType type, const ClientID &client_id,
|
||||
std::shared_ptr<SenderConnection> &conn) {
|
||||
std::unique_lock<std::mutex> guard(connection_mutex);
|
||||
SenderMapType &conn_map = (type == ConnectionType::MESSAGE)
|
||||
? message_send_connections_
|
||||
: transfer_send_connections_;
|
||||
Add(conn_map, client_id, conn);
|
||||
// Don't add to available connections. It will become available once it is released.
|
||||
}
|
||||
|
||||
ray::Status ConnectionPool::GetSender(ConnectionType type, const ClientID &client_id,
|
||||
std::shared_ptr<SenderConnection> *conn) {
|
||||
std::unique_lock<std::mutex> guard(connection_mutex);
|
||||
SenderMapType &avail_conn_map = (type == ConnectionType::MESSAGE)
|
||||
? available_message_send_connections_
|
||||
: available_transfer_send_connections_;
|
||||
if (Count(avail_conn_map, client_id) > 0) {
|
||||
*conn = Borrow(avail_conn_map, client_id);
|
||||
} else {
|
||||
*conn = nullptr;
|
||||
}
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
ray::Status ConnectionPool::ReleaseSender(ConnectionType type,
|
||||
std::shared_ptr<SenderConnection> conn) {
|
||||
std::unique_lock<std::mutex> guard(connection_mutex);
|
||||
SenderMapType &conn_map = (type == ConnectionType::MESSAGE)
|
||||
? available_message_send_connections_
|
||||
: available_transfer_send_connections_;
|
||||
Return(conn_map, conn->GetClientID(), conn);
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
void ConnectionPool::Add(ReceiverMapType &conn_map, const ClientID &client_id,
|
||||
std::shared_ptr<TcpClientConnection> conn) {
|
||||
conn_map[client_id].push_back(conn);
|
||||
}
|
||||
|
||||
void ConnectionPool::Add(SenderMapType &conn_map, const ClientID &client_id,
|
||||
std::shared_ptr<SenderConnection> conn) {
|
||||
conn_map[client_id].push_back(conn);
|
||||
}
|
||||
|
||||
void ConnectionPool::Remove(ReceiverMapType &conn_map, const ClientID &client_id,
|
||||
std::shared_ptr<TcpClientConnection> conn) {
|
||||
if (conn_map.count(client_id) == 0) {
|
||||
return;
|
||||
}
|
||||
std::vector<std::shared_ptr<TcpClientConnection>> &connections = conn_map[client_id];
|
||||
int64_t pos =
|
||||
std::find(connections.begin(), connections.end(), conn) - connections.begin();
|
||||
if (pos >= (int64_t)connections.size()) {
|
||||
return;
|
||||
}
|
||||
connections.erase(connections.begin() + pos);
|
||||
}
|
||||
|
||||
uint64_t ConnectionPool::Count(SenderMapType &conn_map, const ClientID &client_id) {
|
||||
if (conn_map.count(client_id) == 0) {
|
||||
return 0;
|
||||
};
|
||||
return conn_map[client_id].size();
|
||||
}
|
||||
|
||||
std::shared_ptr<SenderConnection> ConnectionPool::Borrow(SenderMapType &conn_map,
|
||||
const ClientID &client_id) {
|
||||
std::shared_ptr<SenderConnection> conn = conn_map[client_id].back();
|
||||
conn_map[client_id].pop_back();
|
||||
RAY_LOG(DEBUG) << "Borrow " << client_id << " " << conn_map[client_id].size();
|
||||
return conn;
|
||||
}
|
||||
|
||||
void ConnectionPool::Return(SenderMapType &conn_map, const ClientID &client_id,
|
||||
std::shared_ptr<SenderConnection> conn) {
|
||||
conn_map[client_id].push_back(conn);
|
||||
RAY_LOG(DEBUG) << "Return " << client_id << " " << conn_map[client_id].size();
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
@@ -0,0 +1,143 @@
|
||||
#ifndef RAY_OBJECT_MANAGER_CONNECTION_POOL_H
|
||||
#define RAY_OBJECT_MANAGER_CONNECTION_POOL_H
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <deque>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
|
||||
#include <boost/asio.hpp>
|
||||
#include <boost/asio/error.hpp>
|
||||
#include <boost/bind.hpp>
|
||||
|
||||
#include "ray/id.h"
|
||||
#include "ray/status.h"
|
||||
|
||||
#include <mutex>
|
||||
#include "ray/object_manager/format/object_manager_generated.h"
|
||||
#include "ray/object_manager/object_directory.h"
|
||||
#include "ray/object_manager/object_manager_client_connection.h"
|
||||
|
||||
namespace asio = boost::asio;
|
||||
|
||||
namespace ray {
|
||||
|
||||
class ConnectionPool {
|
||||
public:
|
||||
/// Callbacks for GetSender.
|
||||
using SuccessCallback = std::function<void(std::shared_ptr<SenderConnection>)>;
|
||||
using FailureCallback = std::function<void()>;
|
||||
|
||||
/// Connection type to distinguish between message and transfer connections.
|
||||
enum class ConnectionType : int { MESSAGE = 0, TRANSFER };
|
||||
|
||||
/// Connection pool for all connections needed by the ObjectManager.
|
||||
ConnectionPool();
|
||||
|
||||
/// Register a receiver connection.
|
||||
///
|
||||
/// \param type The type of connection.
|
||||
/// \param client_id The ClientID of the remote object manager.
|
||||
/// \param conn The actual connection.
|
||||
void RegisterReceiver(ConnectionType type, const ClientID &client_id,
|
||||
std::shared_ptr<TcpClientConnection> &conn);
|
||||
|
||||
/// Remove a receiver connection.
|
||||
///
|
||||
/// \param type The type of connection.
|
||||
/// \param client_id The ClientID of the remote object manager.
|
||||
/// \param conn The actual connection.
|
||||
void RemoveReceiver(ConnectionType type, const ClientID &client_id,
|
||||
std::shared_ptr<TcpClientConnection> &conn);
|
||||
|
||||
/// Register a receiver connection.
|
||||
///
|
||||
/// \param type The type of connection.
|
||||
/// \param client_id The ClientID of the remote object manager.
|
||||
/// \param conn The actual connection.
|
||||
void RegisterSender(ConnectionType type, const ClientID &client_id,
|
||||
std::shared_ptr<SenderConnection> &conn);
|
||||
|
||||
/// Get a sender connection from the connection pool.
|
||||
/// The connection must be released or removed when the operation for which the
|
||||
/// connection was obtained is completed. If the connection pool is empty, the
|
||||
/// connection pointer passed in is set to a null pointer.
|
||||
///
|
||||
/// \param[in] type The type of connection.
|
||||
/// \param[in] client_id The ClientID of the remote object manager.
|
||||
/// \param[out] conn An empty pointer to a shared pointer.
|
||||
/// \return Status of invoking this method.
|
||||
ray::Status GetSender(ConnectionType type, const ClientID &client_id,
|
||||
std::shared_ptr<SenderConnection> *conn);
|
||||
|
||||
/// Releases a sender connection, allowing it to be used by another operation.
|
||||
///
|
||||
/// \param type The type of connection.
|
||||
/// \param conn The actual connection.
|
||||
/// \return Status of invoking this method.
|
||||
ray::Status ReleaseSender(ConnectionType type, std::shared_ptr<SenderConnection> conn);
|
||||
|
||||
// TODO(hme): Implement with error handling.
|
||||
/// Remove a sender connection. This is invoked if the connection is no longer
|
||||
/// usable.
|
||||
///
|
||||
/// \param type The type of connection.
|
||||
/// \param conn The actual connection.
|
||||
/// \return Status of invoking this method.
|
||||
ray::Status RemoveSender(ConnectionType type, std::shared_ptr<SenderConnection> conn);
|
||||
|
||||
/// This object cannot be copied for thread-safety.
|
||||
ConnectionPool &operator=(const ConnectionPool &o) {
|
||||
throw std::runtime_error("Can't copy ConnectionPool.");
|
||||
}
|
||||
|
||||
private:
|
||||
/// A container type that maps ClientID to a connection type.
|
||||
using SenderMapType =
|
||||
std::unordered_map<ray::ClientID, std::vector<std::shared_ptr<SenderConnection>>,
|
||||
ray::UniqueIDHasher>;
|
||||
using ReceiverMapType =
|
||||
std::unordered_map<ray::ClientID, std::vector<std::shared_ptr<TcpClientConnection>>,
|
||||
ray::UniqueIDHasher>;
|
||||
|
||||
/// Adds a receiver for ClientID to the given map.
|
||||
void Add(ReceiverMapType &conn_map, const ClientID &client_id,
|
||||
std::shared_ptr<TcpClientConnection> conn);
|
||||
|
||||
/// Adds a sender for ClientID to the given map.
|
||||
void Add(SenderMapType &conn_map, const ClientID &client_id,
|
||||
std::shared_ptr<SenderConnection> conn);
|
||||
|
||||
/// Removes the given receiver for ClientID from the given map.
|
||||
void Remove(ReceiverMapType &conn_map, const ClientID &client_id,
|
||||
std::shared_ptr<TcpClientConnection> conn);
|
||||
|
||||
/// Returns the count of sender connections to ClientID.
|
||||
uint64_t Count(SenderMapType &conn_map, const ClientID &client_id);
|
||||
|
||||
/// Removes a sender connection to ClientID from the pool of available connections.
|
||||
/// This method assumes conn_map has available connections to ClientID.
|
||||
std::shared_ptr<SenderConnection> Borrow(SenderMapType &conn_map,
|
||||
const ClientID &client_id);
|
||||
|
||||
/// Returns a sender connection to ClientID to the pool of available connections.
|
||||
void Return(SenderMapType &conn_map, const ClientID &client_id,
|
||||
std::shared_ptr<SenderConnection> conn);
|
||||
|
||||
// TODO(hme): Optimize with separate mutex per collection.
|
||||
std::mutex connection_mutex;
|
||||
|
||||
SenderMapType message_send_connections_;
|
||||
SenderMapType transfer_send_connections_;
|
||||
SenderMapType available_message_send_connections_;
|
||||
SenderMapType available_transfer_send_connections_;
|
||||
|
||||
ReceiverMapType message_receive_connections_;
|
||||
ReceiverMapType transfer_receive_connections_;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_OBJECT_MANAGER_CONNECTION_POOL_H
|
||||
@@ -1,30 +1,37 @@
|
||||
// Object Manager protocol specification
|
||||
namespace ray.object_manager.protocol;
|
||||
|
||||
enum OMMessageType:int {
|
||||
PullRequest = 1
|
||||
enum MessageType:int {
|
||||
ConnectClient = 1,
|
||||
DisconnectClient,
|
||||
PushRequest,
|
||||
PullRequest
|
||||
}
|
||||
|
||||
table PushRequest {
|
||||
|
||||
table PushRequestMessage {
|
||||
// The object ID being transferred.
|
||||
object_id: string;
|
||||
// The size of the object being transferred.
|
||||
object_size: ulong;
|
||||
}
|
||||
|
||||
table PullRequest {
|
||||
table PullRequestMessage {
|
||||
// ID of the requesting client.
|
||||
client_id: string;
|
||||
// Requested ObjectID.
|
||||
object_id: string;
|
||||
}
|
||||
|
||||
table ClientConnectionInfo {
|
||||
table ConnectClientMessage {
|
||||
// ID of the connecting client.
|
||||
client_id: string;
|
||||
// Whether this is a transfer connection.
|
||||
is_transfer: bool;
|
||||
}
|
||||
|
||||
table ObjectHeader {
|
||||
// The object ID being transferred.
|
||||
object_id: string;
|
||||
// The size of the object being transferred.
|
||||
object_size: ulong;
|
||||
table DisconnectClientMessage {
|
||||
// ID of the connecting client.
|
||||
client_id: string;
|
||||
// Whether this is a transfer connection.
|
||||
is_transfer: bool;
|
||||
}
|
||||
|
||||
@@ -1,41 +1,53 @@
|
||||
#include "object_directory.h"
|
||||
#include "ray/object_manager/object_directory.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
ObjectDirectory::ObjectDirectory(std::shared_ptr<GcsClient> gcs_client) {
|
||||
ObjectDirectory::ObjectDirectory(std::shared_ptr<gcs::AsyncGcsClient> gcs_client) {
|
||||
gcs_client_ = gcs_client;
|
||||
};
|
||||
|
||||
ray::Status ObjectDirectory::ReportObjectAdded(const ObjectID &object_id,
|
||||
const ClientID &client_id) {
|
||||
return gcs_client_->object_table().Add(object_id, client_id, [] {});
|
||||
// TODO(hme): Determine whether we need to do lookup to append.
|
||||
JobID job_id = JobID::from_random();
|
||||
auto data = std::make_shared<ObjectTableDataT>();
|
||||
data->manager = client_id.binary();
|
||||
data->is_eviction = false;
|
||||
ray::Status status = gcs_client_->object_table().Append(
|
||||
job_id, object_id, data, [](gcs::AsyncGcsClient *client, const UniqueID &id,
|
||||
const std::shared_ptr<ObjectTableDataT> data) {
|
||||
// Do nothing.
|
||||
});
|
||||
return status;
|
||||
};
|
||||
|
||||
ray::Status ObjectDirectory::ReportObjectRemoved(const ObjectID &object_id,
|
||||
const ClientID &client_id) {
|
||||
return gcs_client_->object_table().Remove(object_id, client_id, [] {});
|
||||
// TODO(hme): Need corresponding remove method in GCS.
|
||||
return ray::Status::NotImplemented("ObjectTable.Remove is not implemented");
|
||||
};
|
||||
|
||||
ray::Status ObjectDirectory::GetInformation(const ClientID &client_id,
|
||||
const InfoSuccessCallback &success_cb,
|
||||
const InfoFailureCallback &fail_cb) {
|
||||
gcs_client_->client_table().GetClientInformation(
|
||||
client_id,
|
||||
[this, success_cb, client_id](ClientInformation client_info) {
|
||||
const auto &info =
|
||||
RemoteConnectionInfo(client_id, client_info.GetIp(), client_info.GetPort());
|
||||
success_cb(info);
|
||||
},
|
||||
fail_cb);
|
||||
const InfoSuccessCallback &success_callback,
|
||||
const InfoFailureCallback &fail_callback) {
|
||||
const ClientTableDataT &data = gcs_client_->client_table().GetClient(client_id);
|
||||
ClientID result_client_id = ClientID::from_binary(data.client_id);
|
||||
if (result_client_id == ClientID::nil() || !data.is_insertion) {
|
||||
fail_callback(ray::Status::RedisError("ClientID not found."));
|
||||
} else {
|
||||
const auto &info = RemoteConnectionInfo(client_id, data.node_manager_address,
|
||||
(uint16_t)data.object_manager_port);
|
||||
success_callback(info);
|
||||
}
|
||||
return ray::Status::OK();
|
||||
};
|
||||
|
||||
ray::Status ObjectDirectory::GetLocations(const ObjectID &object_id,
|
||||
const OnLocationsSuccess &success_cb,
|
||||
const OnLocationsFailure &fail_cb) {
|
||||
const OnLocationsSuccess &success_callback,
|
||||
const OnLocationsFailure &fail_callback) {
|
||||
ray::Status status_code = ray::Status::OK();
|
||||
if (existing_requests_.count(object_id) == 0) {
|
||||
existing_requests_[object_id] = ODCallbacks({success_cb, fail_cb});
|
||||
existing_requests_[object_id] = ODCallbacks({success_callback, fail_callback});
|
||||
status_code = ExecuteGetLocations(object_id);
|
||||
} else {
|
||||
// Do nothing. A request is in progress.
|
||||
@@ -44,52 +56,44 @@ ray::Status ObjectDirectory::GetLocations(const ObjectID &object_id,
|
||||
};
|
||||
|
||||
ray::Status ObjectDirectory::ExecuteGetLocations(const ObjectID &object_id) {
|
||||
// TODO(hme): Avoid callback hell.
|
||||
std::vector<RemoteConnectionInfo> remote_connections;
|
||||
ray::Status status = gcs_client_->object_table().GetObjectClientIDs(
|
||||
object_id,
|
||||
[this, object_id, &remote_connections](const std::vector<ClientID> &client_ids) {
|
||||
gcs_client_->client_table().GetClientInformationSet(
|
||||
client_ids,
|
||||
[this, object_id,
|
||||
&remote_connections](const std::vector<ClientInformation> &info_vec) {
|
||||
for (const auto &client_info : info_vec) {
|
||||
RemoteConnectionInfo info =
|
||||
RemoteConnectionInfo(client_info.GetClientId(), client_info.GetIp(),
|
||||
client_info.GetPort());
|
||||
remote_connections.push_back(info);
|
||||
}
|
||||
ray::Status cb_completion_status =
|
||||
GetLocationsComplete(Status::OK(), object_id, remote_connections);
|
||||
},
|
||||
[this, object_id, &remote_connections](const Status &status) {
|
||||
ray::Status cb_completion_status =
|
||||
GetLocationsComplete(status, object_id, remote_connections);
|
||||
});
|
||||
},
|
||||
[this, object_id, &remote_connections](const Status &status) {
|
||||
ray::Status cb_completion_status =
|
||||
GetLocationsComplete(status, object_id, remote_connections);
|
||||
JobID job_id = JobID::from_random();
|
||||
// Note: Lookup must be synchronous for thread-safe access.
|
||||
// For now, this is only accessed by the main thread.
|
||||
ray::Status status = gcs_client_->object_table().Lookup(
|
||||
job_id, object_id,
|
||||
[this, object_id](gcs::AsyncGcsClient *client, const ObjectID &object_id,
|
||||
const std::vector<ObjectTableDataT> &data) {
|
||||
GetLocationsComplete(object_id, data);
|
||||
});
|
||||
return status;
|
||||
};
|
||||
|
||||
ray::Status ObjectDirectory::GetLocationsComplete(
|
||||
const ray::Status &status, const ObjectID &object_id,
|
||||
const std::vector<RemoteConnectionInfo> &remote_connections) {
|
||||
bool success = status.ok();
|
||||
// Only invoke a callback if the request was not cancelled.
|
||||
if (existing_requests_.count(object_id) > 0) {
|
||||
ODCallbacks cbs = existing_requests_[object_id];
|
||||
if (success) {
|
||||
cbs.success_cb(remote_connections, object_id);
|
||||
void ObjectDirectory::GetLocationsComplete(
|
||||
const ObjectID &object_id, const std::vector<ObjectTableDataT> &location_entries) {
|
||||
auto request = existing_requests_.find(object_id);
|
||||
// Do not invoke a callback if the request was cancelled.
|
||||
if (request == existing_requests_.end()) {
|
||||
return;
|
||||
}
|
||||
// Build the set of current locations based on the entries in the log.
|
||||
std::unordered_set<ClientID, UniqueIDHasher> locations;
|
||||
for (auto entry : location_entries) {
|
||||
ClientID client_id = ClientID::from_binary(entry.manager);
|
||||
if (!entry.is_eviction) {
|
||||
locations.insert(client_id);
|
||||
} else {
|
||||
cbs.fail_cb(status, object_id);
|
||||
locations.erase(client_id);
|
||||
}
|
||||
}
|
||||
existing_requests_.erase(object_id);
|
||||
return status;
|
||||
};
|
||||
// Invoke the callback.
|
||||
std::vector<ClientID> locations_vector(locations.begin(), locations.end());
|
||||
if (locations_vector.empty()) {
|
||||
request->second.fail_cb(object_id);
|
||||
} else {
|
||||
request->second.success_cb(locations_vector, object_id);
|
||||
}
|
||||
existing_requests_.erase(request);
|
||||
}
|
||||
|
||||
ray::Status ObjectDirectory::Cancel(const ObjectID &object_id) {
|
||||
existing_requests_.erase(object_id);
|
||||
|
||||
@@ -2,17 +2,19 @@
|
||||
#define RAY_OBJECT_MANAGER_OBJECT_DIRECTORY_H
|
||||
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "ray/gcs/client.h"
|
||||
#include "ray/id.h"
|
||||
#include "ray/raylet/mock_gcs_client.h"
|
||||
#include "ray/status.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
struct RemoteConnectionInfo {
|
||||
RemoteConnectionInfo() = default;
|
||||
RemoteConnectionInfo(const ClientID &id, const std::string &ip_address,
|
||||
uint16_t port_num)
|
||||
: client_id(id), ip(ip_address), port(port_num) {}
|
||||
@@ -42,10 +44,9 @@ class ObjectDirectoryInterface {
|
||||
const InfoFailureCallback &fail_cb) = 0;
|
||||
|
||||
// Callbacks for GetLocations.
|
||||
using OnLocationsSuccess = std::function<void(
|
||||
const std::vector<ray::RemoteConnectionInfo> &v, const ray::ObjectID &object_id)>;
|
||||
using OnLocationsFailure =
|
||||
std::function<void(ray::Status status, const ray::ObjectID &object_id)>;
|
||||
using OnLocationsSuccess = std::function<void(const std::vector<ray::ClientID> &v,
|
||||
const ray::ObjectID &object_id)>;
|
||||
using OnLocationsFailure = std::function<void(const ray::ObjectID &object_id)>;
|
||||
|
||||
/// Asynchronously obtain the locations of an object by ObjectID.
|
||||
/// This is used to handle object pulls.
|
||||
@@ -93,11 +94,11 @@ class ObjectDirectory : public ObjectDirectoryInterface {
|
||||
~ObjectDirectory() override = default;
|
||||
|
||||
ray::Status GetInformation(const ClientID &client_id,
|
||||
const InfoSuccessCallback &success_cb,
|
||||
const InfoFailureCallback &fail_cb) override;
|
||||
const InfoSuccessCallback &success_callback,
|
||||
const InfoFailureCallback &fail_callback) override;
|
||||
ray::Status GetLocations(const ObjectID &object_id,
|
||||
const OnLocationsSuccess &success_cb,
|
||||
const OnLocationsFailure &fail_cb) override;
|
||||
const OnLocationsSuccess &success_callback,
|
||||
const OnLocationsFailure &fail_callback) override;
|
||||
ray::Status Cancel(const ObjectID &object_id) override;
|
||||
ray::Status Terminate() override;
|
||||
ray::Status ReportObjectAdded(const ObjectID &object_id,
|
||||
@@ -105,7 +106,12 @@ class ObjectDirectory : public ObjectDirectoryInterface {
|
||||
ray::Status ReportObjectRemoved(const ObjectID &object_id,
|
||||
const ClientID &client_id) override;
|
||||
/// Ray only (not part of the OD interface).
|
||||
ObjectDirectory(std::shared_ptr<GcsClient> gcs_client);
|
||||
ObjectDirectory(std::shared_ptr<gcs::AsyncGcsClient> gcs_client);
|
||||
|
||||
/// This object cannot be copied for thread-safety.
|
||||
ObjectDirectory &operator=(const ObjectDirectory &o) {
|
||||
throw std::runtime_error("Can't copy ObjectDirectory.");
|
||||
}
|
||||
|
||||
private:
|
||||
/// Callbacks associated with a call to GetLocations.
|
||||
@@ -115,18 +121,17 @@ class ObjectDirectory : public ObjectDirectoryInterface {
|
||||
OnLocationsFailure fail_cb;
|
||||
};
|
||||
|
||||
/// Maintain map of in-flight GetLocation requests.
|
||||
std::unordered_map<ObjectID, ODCallbacks, UniqueIDHasher> existing_requests_;
|
||||
|
||||
/// Reference to the gcs client.
|
||||
std::shared_ptr<GcsClient> gcs_client_;
|
||||
|
||||
/// GetLocations registers a request for locations.
|
||||
/// This function actually carries out that request.
|
||||
ray::Status ExecuteGetLocations(const ObjectID &object_id);
|
||||
/// Invoked when call to ExecuteGetLocations completes.
|
||||
ray::Status GetLocationsComplete(const ray::Status &status, const ObjectID &object_id,
|
||||
const std::vector<RemoteConnectionInfo> &v);
|
||||
void GetLocationsComplete(const ObjectID &object_id,
|
||||
const std::vector<ObjectTableDataT> &location_entries);
|
||||
|
||||
/// Maintain map of in-flight GetLocation requests.
|
||||
std::unordered_map<ObjectID, ODCallbacks, UniqueIDHasher> existing_requests_;
|
||||
/// Reference to the gcs client.
|
||||
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -1,58 +1,80 @@
|
||||
#include "object_manager.h"
|
||||
#include "ray/object_manager/object_manager.h"
|
||||
|
||||
namespace asio = boost::asio;
|
||||
|
||||
namespace object_manager_protocol = ray::object_manager::protocol;
|
||||
|
||||
namespace ray {
|
||||
|
||||
ObjectManager::ObjectManager(boost::asio::io_service &io_service,
|
||||
ObjectManagerConfig config,
|
||||
std::shared_ptr<ray::GcsClient> gcs_client)
|
||||
: object_directory_(new ObjectDirectory(gcs_client)), work_(io_service_) {
|
||||
ObjectManager::ObjectManager(asio::io_service &main_service,
|
||||
std::unique_ptr<asio::io_service> object_manager_service,
|
||||
const ObjectManagerConfig &config,
|
||||
std::shared_ptr<gcs::AsyncGcsClient> gcs_client)
|
||||
// TODO(hme): Eliminate knowledge of GCS.
|
||||
: client_id_(gcs_client->client_table().GetLocalClientId()),
|
||||
object_directory_(new ObjectDirectory(gcs_client)),
|
||||
store_notification_(main_service, config.store_socket_name),
|
||||
store_pool_(config.store_socket_name),
|
||||
object_manager_service_(std::move(object_manager_service)),
|
||||
work_(*object_manager_service_),
|
||||
connection_pool_(),
|
||||
transfer_queue_(),
|
||||
num_transfers_send_(0),
|
||||
num_transfers_receive_(0) {
|
||||
main_service_ = &main_service;
|
||||
config_ = config;
|
||||
store_client_ = std::unique_ptr<ObjectStoreClient>(
|
||||
new ObjectStoreClient(io_service, config.store_socket_name));
|
||||
store_client_->SubscribeObjAdded(
|
||||
store_notification_.SubscribeObjAdded(
|
||||
[this](const ObjectID &oid) { NotifyDirectoryObjectAdd(oid); });
|
||||
store_client_->SubscribeObjDeleted(
|
||||
store_notification_.SubscribeObjDeleted(
|
||||
[this](const ObjectID &oid) { NotifyDirectoryObjectDeleted(oid); });
|
||||
StartIOService();
|
||||
};
|
||||
}
|
||||
|
||||
ObjectManager::ObjectManager(boost::asio::io_service &io_service,
|
||||
ObjectManagerConfig config,
|
||||
ObjectManager::ObjectManager(asio::io_service &main_service,
|
||||
std::unique_ptr<asio::io_service> object_manager_service,
|
||||
const ObjectManagerConfig &config,
|
||||
std::unique_ptr<ObjectDirectoryInterface> od)
|
||||
: object_directory_(std::move(od)), work_(io_service_) {
|
||||
: object_directory_(std::move(od)),
|
||||
store_notification_(main_service, config.store_socket_name),
|
||||
store_pool_(config.store_socket_name),
|
||||
object_manager_service_(std::move(object_manager_service)),
|
||||
work_(*object_manager_service_),
|
||||
connection_pool_(),
|
||||
transfer_queue_(),
|
||||
num_transfers_send_(0),
|
||||
num_transfers_receive_(0) {
|
||||
// TODO(hme) Client ID is never set with this constructor.
|
||||
main_service_ = &main_service;
|
||||
config_ = config;
|
||||
store_client_ = std::unique_ptr<ObjectStoreClient>(
|
||||
new ObjectStoreClient(io_service, config.store_socket_name));
|
||||
store_client_->SubscribeObjAdded(
|
||||
store_notification_.SubscribeObjAdded(
|
||||
[this](const ObjectID &oid) { NotifyDirectoryObjectAdd(oid); });
|
||||
store_client_->SubscribeObjDeleted(
|
||||
store_notification_.SubscribeObjDeleted(
|
||||
[this](const ObjectID &oid) { NotifyDirectoryObjectDeleted(oid); });
|
||||
StartIOService();
|
||||
};
|
||||
}
|
||||
|
||||
void ObjectManager::StartIOService() {
|
||||
io_thread_ = std::thread(&ObjectManager::IOServiceLoop, this);
|
||||
// thread_group_.create_thread(boost::bind(&boost::asio::io_service::run,
|
||||
// &io_service_));
|
||||
for (int i = 0; i < config_.num_threads; ++i) {
|
||||
io_threads_.emplace_back(std::thread(&ObjectManager::IOServiceLoop, this));
|
||||
}
|
||||
}
|
||||
|
||||
void ObjectManager::IOServiceLoop() { io_service_.run(); }
|
||||
void ObjectManager::IOServiceLoop() { object_manager_service_->run(); }
|
||||
|
||||
void ObjectManager::StopIOService() {
|
||||
io_service_.stop();
|
||||
io_thread_.join();
|
||||
// thread_group_.join_all();
|
||||
object_manager_service_->stop();
|
||||
for (int i = 0; i < config_.num_threads; ++i) {
|
||||
io_threads_[i].join();
|
||||
}
|
||||
}
|
||||
|
||||
void ObjectManager::SetClientID(const ClientID &client_id) { client_id_ = client_id; }
|
||||
|
||||
ClientID ObjectManager::GetClientID() { return client_id_; }
|
||||
|
||||
void ObjectManager::NotifyDirectoryObjectAdd(const ObjectID &object_id) {
|
||||
local_objects_.insert(object_id);
|
||||
ray::Status status = object_directory_->ReportObjectAdded(object_id, client_id_);
|
||||
}
|
||||
|
||||
void ObjectManager::NotifyDirectoryObjectDeleted(const ObjectID &object_id) {
|
||||
local_objects_.erase(object_id);
|
||||
ray::Status status = object_directory_->ReportObjectRemoved(object_id, client_id_);
|
||||
}
|
||||
|
||||
@@ -60,391 +82,450 @@ ray::Status ObjectManager::Terminate() {
|
||||
StopIOService();
|
||||
ray::Status status_code = object_directory_->Terminate();
|
||||
// TODO: evaluate store client termination status.
|
||||
store_client_->Terminate();
|
||||
store_notification_.Terminate();
|
||||
store_pool_.Terminate();
|
||||
return status_code;
|
||||
};
|
||||
}
|
||||
|
||||
ray::Status ObjectManager::SubscribeObjAdded(
|
||||
std::function<void(const ObjectID &)> callback) {
|
||||
store_client_->SubscribeObjAdded(callback);
|
||||
store_notification_.SubscribeObjAdded(callback);
|
||||
return ray::Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
ray::Status ObjectManager::SubscribeObjDeleted(
|
||||
std::function<void(const ObjectID &)> callback) {
|
||||
store_client_->SubscribeObjDeleted(callback);
|
||||
store_notification_.SubscribeObjDeleted(callback);
|
||||
return ray::Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
ray::Status ObjectManager::Pull(const ObjectID &object_id) {
|
||||
// TODO(hme): Need to correct. Workaround to get all pull requests on the same thread.
|
||||
SchedulePull(object_id, 0);
|
||||
main_service_->dispatch(
|
||||
[this, object_id]() { RAY_CHECK_OK(PullGetLocations(object_id)); });
|
||||
return Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
void ObjectManager::SchedulePull(const ObjectID &object_id, int wait_ms) {
|
||||
pull_requests_[object_id] = Timer(new boost::asio::deadline_timer(
|
||||
io_service_, boost::posix_time::milliseconds(wait_ms)));
|
||||
pull_requests_[object_id] = std::shared_ptr<boost::asio::deadline_timer>(
|
||||
new asio::deadline_timer(*main_service_, boost::posix_time::milliseconds(wait_ms)));
|
||||
pull_requests_[object_id]->async_wait(
|
||||
[this, object_id](const boost::system::error_code &error_code) {
|
||||
RAY_CHECK_OK(SchedulePullHandler(object_id));
|
||||
pull_requests_.erase(object_id);
|
||||
main_service_->dispatch(
|
||||
[this, object_id]() { RAY_CHECK_OK(PullGetLocations(object_id)); });
|
||||
});
|
||||
}
|
||||
|
||||
ray::Status ObjectManager::SchedulePullHandler(const ObjectID &object_id) {
|
||||
pull_requests_.erase(object_id);
|
||||
ray::Status ObjectManager::PullGetLocations(const ObjectID &object_id) {
|
||||
ray::Status status_code = object_directory_->GetLocations(
|
||||
object_id,
|
||||
[this](const std::vector<RemoteConnectionInfo> &vec, const ObjectID &object_id) {
|
||||
return GetLocationsSuccess(vec, object_id);
|
||||
[this](const std::vector<ClientID> &client_ids, const ObjectID &object_id) {
|
||||
return GetLocationsSuccess(client_ids, object_id);
|
||||
},
|
||||
[this](ray::Status status, const ObjectID &object_id) {
|
||||
return GetLocationsFailed(status, object_id);
|
||||
});
|
||||
[this](const ObjectID &object_id) { return GetLocationsFailed(object_id); });
|
||||
return status_code;
|
||||
}
|
||||
|
||||
void ObjectManager::GetLocationsSuccess(const std::vector<ray::RemoteConnectionInfo> &vec,
|
||||
void ObjectManager::GetLocationsSuccess(const std::vector<ray::ClientID> &client_ids,
|
||||
const ray::ObjectID &object_id) {
|
||||
RemoteConnectionInfo info = vec.front();
|
||||
RAY_CHECK(!client_ids.empty());
|
||||
ClientID client_id = client_ids.front();
|
||||
pull_requests_.erase(object_id);
|
||||
ray::Status status_code = Pull(object_id, info.client_id);
|
||||
};
|
||||
ray::Status status_code = Pull(object_id, client_id);
|
||||
}
|
||||
|
||||
void ObjectManager::GetLocationsFailed(ray::Status status, const ObjectID &object_id) {
|
||||
void ObjectManager::GetLocationsFailed(const ObjectID &object_id) {
|
||||
SchedulePull(object_id, config_.pull_timeout_ms);
|
||||
};
|
||||
}
|
||||
|
||||
ray::Status ObjectManager::Pull(const ObjectID &object_id, const ClientID &client_id) {
|
||||
Status status =
|
||||
GetMsgConnection(client_id, [this, object_id](SenderConnection::pointer client) {
|
||||
Status status = ExecutePull(object_id, client);
|
||||
});
|
||||
return status;
|
||||
main_service_->dispatch([this, object_id, client_id]() {
|
||||
RAY_CHECK_OK(PullEstablishConnection(object_id, client_id));
|
||||
});
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
ray::Status ObjectManager::ExecutePull(const ObjectID &object_id,
|
||||
SenderConnection::pointer conn) {
|
||||
size_t message_type = OMMessageType_PullRequest;
|
||||
boost::system::error_code error_code;
|
||||
boost::asio::write(conn->GetSocket(),
|
||||
boost::asio::buffer(&message_type, sizeof(message_type)),
|
||||
error_code);
|
||||
ray::Status ObjectManager::PullEstablishConnection(const ObjectID &object_id,
|
||||
const ClientID &client_id) {
|
||||
// Check if object is already local, and client_id is not itself.
|
||||
if (local_objects_.count(object_id) != 0 || client_id == client_id_) {
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
// Acquire a message connection and send pull request.
|
||||
ray::Status status;
|
||||
std::shared_ptr<SenderConnection> conn;
|
||||
// TODO(hme): There is no cap on the number of pull request connections.
|
||||
status = connection_pool_.GetSender(ConnectionPool::ConnectionType::MESSAGE, client_id,
|
||||
&conn);
|
||||
if (!status.ok()) {
|
||||
// TODO(hme): Keep track of retries,
|
||||
// and only retry on object not local
|
||||
// for now.
|
||||
SchedulePull(object_id, config_.pull_timeout_ms);
|
||||
return status;
|
||||
}
|
||||
if (conn == nullptr) {
|
||||
status = object_directory_->GetInformation(
|
||||
client_id,
|
||||
[this, object_id, client_id](const RemoteConnectionInfo &connection_info) {
|
||||
std::shared_ptr<SenderConnection> async_conn = CreateSenderConnection(
|
||||
ConnectionPool::ConnectionType::MESSAGE, connection_info);
|
||||
connection_pool_.RegisterSender(ConnectionPool::ConnectionType::MESSAGE,
|
||||
client_id, async_conn);
|
||||
RAY_CHECK_OK(PullSendRequest(object_id, async_conn));
|
||||
},
|
||||
[this, object_id](const Status &status) {
|
||||
SchedulePull(object_id, config_.pull_timeout_ms);
|
||||
});
|
||||
} else {
|
||||
RAY_CHECK_OK(PullSendRequest(object_id, conn));
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
ray::Status ObjectManager::PullSendRequest(const ObjectID &object_id,
|
||||
std::shared_ptr<SenderConnection> conn) {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = CreatePullRequest(fbb, fbb.CreateString(client_id_.binary()),
|
||||
fbb.CreateString(object_id.binary()));
|
||||
auto message = object_manager_protocol::CreatePullRequestMessage(
|
||||
fbb, fbb.CreateString(client_id_.binary()), fbb.CreateString(object_id.binary()));
|
||||
fbb.Finish(message);
|
||||
size_t length = fbb.GetSize();
|
||||
std::vector<boost::asio::const_buffer> buffer;
|
||||
buffer.push_back(boost::asio::buffer(&length, sizeof(length)));
|
||||
buffer.push_back(boost::asio::buffer(fbb.GetBufferPointer(), length));
|
||||
boost::asio::write(conn->GetSocket(), buffer);
|
||||
RAY_CHECK_OK(conn->WriteMessage(object_manager_protocol::MessageType_PullRequest,
|
||||
fbb.GetSize(), fbb.GetBufferPointer()));
|
||||
RAY_CHECK_OK(
|
||||
connection_pool_.ReleaseSender(ConnectionPool::ConnectionType::MESSAGE, conn));
|
||||
return ray::Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
ray::Status ObjectManager::Push(const ObjectID &object_id, const ClientID &client_id) {
|
||||
ray::Status status;
|
||||
status =
|
||||
GetTransferConnection(client_id, [this, object_id](SenderConnection::pointer conn) {
|
||||
ray::Status status = QueuePush(object_id, conn);
|
||||
});
|
||||
// TODO(hme): Cache this data in ObjectDirectory.
|
||||
// Okay for now since the GCS client caches this data.
|
||||
main_service_->dispatch([this, object_id, client_id]() {
|
||||
Status status = object_directory_->GetInformation(
|
||||
client_id,
|
||||
[this, object_id, client_id](const RemoteConnectionInfo &info) {
|
||||
transfer_queue_.QueueSend(client_id, object_id, info);
|
||||
RAY_CHECK_OK(DequeueTransfers());
|
||||
},
|
||||
[this](const Status &status) {
|
||||
// Push is best effort, so do nothing here.
|
||||
});
|
||||
RAY_CHECK_OK(status);
|
||||
});
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
ray::Status ObjectManager::DequeueTransfers() {
|
||||
ray::Status status = ray::Status::OK();
|
||||
// Dequeue sends.
|
||||
while (true) {
|
||||
if (std::atomic_fetch_add(&num_transfers_send_, 1) <= config_.max_sends) {
|
||||
TransferQueue::SendRequest req;
|
||||
bool exists = transfer_queue_.DequeueSendIfPresent(&req);
|
||||
if (exists) {
|
||||
object_manager_service_->dispatch([this, req]() {
|
||||
RAY_LOG(DEBUG) << "DequeueSend " << client_id_ << " " << req.object_id << " "
|
||||
<< num_transfers_send_ << "/" << config_.max_sends;
|
||||
RAY_CHECK_OK(
|
||||
ExecuteSendObject(req.object_id, req.client_id, req.connection_info));
|
||||
});
|
||||
} else {
|
||||
std::atomic_fetch_sub(&num_transfers_send_, 1);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
std::atomic_fetch_sub(&num_transfers_send_, 1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Dequeue receives.
|
||||
while (true) {
|
||||
if (std::atomic_fetch_add(&num_transfers_receive_, 1) <= config_.max_receives) {
|
||||
TransferQueue::ReceiveRequest req;
|
||||
bool exists = transfer_queue_.DequeueReceiveIfPresent(&req);
|
||||
if (exists) {
|
||||
object_manager_service_->dispatch([this, req]() {
|
||||
RAY_LOG(DEBUG) << "DequeueReceive " << client_id_ << " " << req.object_id << " "
|
||||
<< num_transfers_receive_ << "/" << config_.max_receives;
|
||||
RAY_CHECK_OK(ExecuteReceiveObject(req.client_id, req.object_id, req.object_size,
|
||||
req.conn));
|
||||
});
|
||||
} else {
|
||||
std::atomic_fetch_sub(&num_transfers_receive_, 1);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
std::atomic_fetch_sub(&num_transfers_receive_, 1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
ray::Status ObjectManager::TransferCompleted(TransferQueue::TransferType type) {
|
||||
if (type == TransferQueue::TransferType::SEND) {
|
||||
std::atomic_fetch_sub(&num_transfers_send_, 1);
|
||||
} else {
|
||||
std::atomic_fetch_sub(&num_transfers_receive_, 1);
|
||||
}
|
||||
return DequeueTransfers();
|
||||
};
|
||||
|
||||
ray::Status ObjectManager::ExecuteSendObject(
|
||||
const ObjectID &object_id, const ClientID &client_id,
|
||||
const RemoteConnectionInfo &connection_info) {
|
||||
ray::Status status;
|
||||
std::shared_ptr<SenderConnection> conn;
|
||||
status = connection_pool_.GetSender(ConnectionPool::ConnectionType::TRANSFER, client_id,
|
||||
&conn);
|
||||
if (!status.ok()) {
|
||||
// TODO(hme): Keep track of retries,
|
||||
// and only retry on object not local
|
||||
// for now.
|
||||
RAY_CHECK_OK(Push(object_id, conn->GetClientID()));
|
||||
return Status::OK();
|
||||
}
|
||||
if (conn == nullptr) {
|
||||
conn =
|
||||
CreateSenderConnection(ConnectionPool::ConnectionType::TRANSFER, connection_info);
|
||||
connection_pool_.RegisterSender(ConnectionPool::ConnectionType::TRANSFER, client_id,
|
||||
conn);
|
||||
}
|
||||
status = SendObjectHeaders(object_id, conn);
|
||||
if (!status.ok()) {
|
||||
RAY_CHECK_OK(Push(object_id, conn->GetClientID()));
|
||||
return Status::OK();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
ray::Status ObjectManager::SendObjectHeaders(const ObjectID &object_id_const,
|
||||
std::shared_ptr<SenderConnection> conn) {
|
||||
ObjectID object_id = ObjectID(object_id_const);
|
||||
// Allocate and append the request to the transfer queue.
|
||||
plasma::ObjectBuffer object_buffer;
|
||||
plasma::ObjectID plasma_id = object_id.to_plasma_id();
|
||||
std::shared_ptr<plasma::PlasmaClient> store_client = store_pool_.GetObjectStore();
|
||||
ARROW_CHECK_OK(store_client->Get(&plasma_id, 1, 0, &object_buffer));
|
||||
if (object_buffer.data_size == -1) {
|
||||
RAY_LOG(ERROR) << "Failed to get object";
|
||||
// If the object wasn't locally available, exit immediately. If the object
|
||||
// later appears locally, the requesting plasma manager should request the
|
||||
// transfer again.
|
||||
RAY_CHECK_OK(
|
||||
connection_pool_.ReleaseSender(ConnectionPool::ConnectionType::TRANSFER, conn));
|
||||
return ray::Status::IOError(
|
||||
"Unable to transfer object to requesting plasma manager, object not local.");
|
||||
}
|
||||
RAY_CHECK(object_buffer.metadata->data() ==
|
||||
object_buffer.data->data() + object_buffer.data_size);
|
||||
|
||||
TransferQueue::SendContext context;
|
||||
context.client_id = conn->GetClientID();
|
||||
context.object_id = object_id;
|
||||
context.object_size = static_cast<uint64_t>(object_buffer.data_size);
|
||||
context.data = const_cast<uint8_t *>(object_buffer.data->data());
|
||||
UniqueID context_id = transfer_queue_.AddContext(context);
|
||||
|
||||
// Create buffer.
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
// TODO(hme): use to_flatbuf
|
||||
auto message = object_manager_protocol::CreatePushRequestMessage(
|
||||
fbb, fbb.CreateString(object_id.binary()), context.object_size);
|
||||
fbb.Finish(message);
|
||||
ray::Status status =
|
||||
conn->WriteMessage(object_manager_protocol::MessageType_PushRequest, fbb.GetSize(),
|
||||
fbb.GetBufferPointer());
|
||||
if (!status.ok()) {
|
||||
// push failed.
|
||||
// TODO(hme): Trash sender.
|
||||
RAY_CHECK_OK(
|
||||
connection_pool_.ReleaseSender(ConnectionPool::ConnectionType::TRANSFER, conn));
|
||||
return status;
|
||||
}
|
||||
|
||||
// TODO(hme): Make this async.
|
||||
return SendObjectData(conn, context_id, store_client);
|
||||
}
|
||||
|
||||
ray::Status ObjectManager::SendObjectData(
|
||||
std::shared_ptr<SenderConnection> conn, const UniqueID &context_id,
|
||||
std::shared_ptr<plasma::PlasmaClient> store_client) {
|
||||
TransferQueue::SendContext context = transfer_queue_.GetContext(context_id);
|
||||
boost::system::error_code ec;
|
||||
std::vector<asio::const_buffer> buffer;
|
||||
buffer.push_back(asio::buffer(context.data, context.object_size));
|
||||
conn->WriteBuffer(buffer, ec);
|
||||
|
||||
ray::Status status = ray::Status::OK();
|
||||
if (ec.value() != 0) {
|
||||
// push failed.
|
||||
// TODO(hme): Trash sender.
|
||||
status = ray::Status::IOError(ec.message());
|
||||
}
|
||||
|
||||
// Do this regardless of whether it failed or succeeded.
|
||||
ARROW_CHECK_OK(store_client->Release(context.object_id.to_plasma_id()));
|
||||
store_pool_.ReleaseObjectStore(store_client);
|
||||
RAY_CHECK_OK(
|
||||
connection_pool_.ReleaseSender(ConnectionPool::ConnectionType::TRANSFER, conn));
|
||||
RAY_CHECK_OK(transfer_queue_.RemoveContext(context_id));
|
||||
RAY_LOG(DEBUG) << "SendCompleted " << client_id_ << " " << context.object_id << " "
|
||||
<< num_transfers_send_ << "/" << config_.max_sends;
|
||||
RAY_CHECK_OK(TransferCompleted(TransferQueue::TransferType::SEND));
|
||||
return status;
|
||||
}
|
||||
|
||||
ray::Status ObjectManager::Cancel(const ObjectID &object_id) {
|
||||
// TODO(hme): Account for pull timers.
|
||||
ray::Status status = object_directory_->Cancel(object_id);
|
||||
return ray::Status::OK();
|
||||
};
|
||||
}
|
||||
|
||||
ray::Status ObjectManager::Wait(const std::vector<ObjectID> &object_ids,
|
||||
uint64_t timeout_ms, int num_ready_objects,
|
||||
const WaitCallback &callback) {
|
||||
// TODO: Implement wait.
|
||||
return ray::Status::OK();
|
||||
};
|
||||
|
||||
ray::Status ObjectManager::GetMsgConnection(
|
||||
const ClientID &client_id, std::function<void(SenderConnection::pointer)> callback) {
|
||||
ray::Status status = Status::OK();
|
||||
if (message_send_connections_.count(client_id) > 0) {
|
||||
callback(message_send_connections_[client_id]);
|
||||
} else {
|
||||
status = object_directory_->GetInformation(
|
||||
client_id,
|
||||
[this, callback](RemoteConnectionInfo info) {
|
||||
Status status = CreateMsgConnection(info, callback);
|
||||
},
|
||||
[this](const Status &status) {
|
||||
// TODO: deal with failure.
|
||||
});
|
||||
}
|
||||
return status;
|
||||
};
|
||||
|
||||
ray::Status ObjectManager::CreateMsgConnection(
|
||||
const RemoteConnectionInfo &info,
|
||||
std::function<void(SenderConnection::pointer)> callback) {
|
||||
message_send_connections_.emplace(
|
||||
info.client_id, SenderConnection::Create(io_service_, info.ip, info.port));
|
||||
// Prepare client connection info buffer.
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
bool is_transfer = false;
|
||||
auto message =
|
||||
CreateClientConnectionInfo(fbb, fbb.CreateString(client_id_.binary()), is_transfer);
|
||||
fbb.Finish(message);
|
||||
// Pack into asio buffer.
|
||||
size_t length = fbb.GetSize();
|
||||
std::vector<boost::asio::const_buffer> buffer;
|
||||
buffer.push_back(boost::asio::buffer(&length, sizeof(length)));
|
||||
buffer.push_back(boost::asio::buffer(fbb.GetBufferPointer(), length));
|
||||
// Send synchronously.
|
||||
SenderConnection::pointer conn = message_send_connections_[info.client_id];
|
||||
boost::system::error_code error;
|
||||
boost::asio::write(conn->GetSocket(), buffer);
|
||||
// The connection is ready, invoke callback with connection info.
|
||||
callback(message_send_connections_[info.client_id]);
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
ray::Status ObjectManager::GetTransferConnection(
|
||||
const ClientID &client_id, std::function<void(SenderConnection::pointer)> callback) {
|
||||
ray::Status status = Status::OK();
|
||||
if (transfer_send_connections_.count(client_id) > 0) {
|
||||
callback(transfer_send_connections_[client_id]);
|
||||
} else {
|
||||
status = object_directory_->GetInformation(
|
||||
client_id,
|
||||
[this, callback](RemoteConnectionInfo info) {
|
||||
Status status = CreateTransferConnection(info, callback);
|
||||
},
|
||||
[this](const Status &status) {
|
||||
// TODO(hme): deal with failure.
|
||||
});
|
||||
}
|
||||
return status;
|
||||
};
|
||||
|
||||
ray::Status ObjectManager::CreateTransferConnection(
|
||||
const RemoteConnectionInfo &info,
|
||||
std::function<void(SenderConnection::pointer)> callback) {
|
||||
transfer_send_connections_.emplace(
|
||||
info.client_id, SenderConnection::Create(io_service_, info.ip, info.port));
|
||||
// Prepare client connection info buffer.
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
bool is_transfer = true;
|
||||
auto message =
|
||||
CreateClientConnectionInfo(fbb, fbb.CreateString(client_id_.binary()), is_transfer);
|
||||
fbb.Finish(message);
|
||||
// Pack into asio buffer.
|
||||
size_t length = fbb.GetSize();
|
||||
std::vector<boost::asio::const_buffer> buffer;
|
||||
buffer.push_back(boost::asio::buffer(&length, sizeof(length)));
|
||||
buffer.push_back(boost::asio::buffer(fbb.GetBufferPointer(), length));
|
||||
// Send synchronously.
|
||||
SenderConnection::pointer conn = transfer_send_connections_[info.client_id];
|
||||
boost::system::error_code ec;
|
||||
boost::asio::write(conn->GetSocket(), buffer, ec);
|
||||
callback(transfer_send_connections_[info.client_id]);
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
ray::Status ObjectManager::AcceptConnection(TCPClientConnection::pointer conn) {
|
||||
boost::system::error_code ec;
|
||||
// read header
|
||||
size_t length;
|
||||
std::vector<boost::asio::mutable_buffer> header;
|
||||
header.push_back(boost::asio::buffer(&length, sizeof(length)));
|
||||
boost::asio::read(conn->GetSocket(), header, ec);
|
||||
// read data
|
||||
std::vector<uint8_t> message;
|
||||
message.resize(length);
|
||||
boost::asio::read(conn->GetSocket(), boost::asio::buffer(message), ec);
|
||||
// Serialize
|
||||
auto info = flatbuffers::GetRoot<ClientConnectionInfo>(message.data());
|
||||
ClientID client_id = ObjectID::from_binary(info->client_id()->str());
|
||||
bool is_transfer = info->is_transfer();
|
||||
// TODO: trash connection if either fails.
|
||||
if (is_transfer) {
|
||||
transfer_receive_connections_[client_id] = conn;
|
||||
Status status = WaitPushReceive(conn);
|
||||
return status;
|
||||
} else {
|
||||
message_receive_connections_[client_id] = conn;
|
||||
Status status = WaitMessage(conn);
|
||||
return status;
|
||||
}
|
||||
};
|
||||
|
||||
ray::Status ObjectManager::WaitPushReceive(TCPClientConnection::pointer conn) {
|
||||
boost::asio::async_read(
|
||||
conn->GetSocket(),
|
||||
boost::asio::buffer(&conn->message_length_, sizeof(conn->message_length_)),
|
||||
boost::bind(&ObjectManager::HandlePushReceive, this, conn,
|
||||
boost::asio::placeholders::error));
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
void ObjectManager::HandlePushReceive(TCPClientConnection::pointer conn,
|
||||
BoostEC length_ec) {
|
||||
std::vector<uint8_t> message;
|
||||
message.resize(conn->message_length_);
|
||||
boost::system::error_code ec;
|
||||
boost::asio::read(conn->GetSocket(), boost::asio::buffer(message), ec);
|
||||
std::shared_ptr<SenderConnection> ObjectManager::CreateSenderConnection(
|
||||
ConnectionPool::ConnectionType type, RemoteConnectionInfo info) {
|
||||
std::shared_ptr<SenderConnection> conn = SenderConnection::Create(
|
||||
*object_manager_service_, info.client_id, info.ip, info.port);
|
||||
// Prepare client connection info buffer
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
bool is_transfer = (type == ConnectionPool::ConnectionType::TRANSFER);
|
||||
auto message = object_manager_protocol::CreateConnectClientMessage(
|
||||
fbb, fbb.CreateString(client_id_.binary()), is_transfer);
|
||||
fbb.Finish(message);
|
||||
// Send synchronously.
|
||||
RAY_CHECK_OK(conn->WriteMessage(object_manager_protocol::MessageType_ConnectClient,
|
||||
fbb.GetSize(), fbb.GetBufferPointer()));
|
||||
// The connection is ready; return to caller.
|
||||
return conn;
|
||||
}
|
||||
|
||||
void ObjectManager::ProcessNewClient(std::shared_ptr<TcpClientConnection> conn) {
|
||||
conn->ProcessMessages();
|
||||
}
|
||||
|
||||
void ObjectManager::ProcessClientMessage(std::shared_ptr<TcpClientConnection> conn,
|
||||
int64_t message_type, const uint8_t *message) {
|
||||
switch (message_type) {
|
||||
case object_manager_protocol::MessageType_PushRequest: {
|
||||
ReceivePushRequest(conn, message);
|
||||
break;
|
||||
}
|
||||
case object_manager_protocol::MessageType_PullRequest: {
|
||||
ReceivePullRequest(conn, message);
|
||||
break;
|
||||
}
|
||||
case object_manager_protocol::MessageType_ConnectClient: {
|
||||
ConnectClient(conn, message);
|
||||
break;
|
||||
}
|
||||
case object_manager_protocol::MessageType_DisconnectClient: {
|
||||
DisconnectClient(conn, message);
|
||||
break;
|
||||
}
|
||||
default: { RAY_LOG(FATAL) << "invalid request " << message_type; }
|
||||
}
|
||||
}
|
||||
|
||||
void ObjectManager::ConnectClient(std::shared_ptr<TcpClientConnection> &conn,
|
||||
const uint8_t *message) {
|
||||
// TODO: trash connection on failure.
|
||||
auto info =
|
||||
flatbuffers::GetRoot<object_manager_protocol::ConnectClientMessage>(message);
|
||||
ClientID client_id = ObjectID::from_binary(info->client_id()->str());
|
||||
bool is_transfer = info->is_transfer();
|
||||
conn->SetClientID(client_id);
|
||||
if (is_transfer) {
|
||||
connection_pool_.RegisterReceiver(ConnectionPool::ConnectionType::TRANSFER, client_id,
|
||||
conn);
|
||||
} else {
|
||||
connection_pool_.RegisterReceiver(ConnectionPool::ConnectionType::MESSAGE, client_id,
|
||||
conn);
|
||||
}
|
||||
conn->ProcessMessages();
|
||||
}
|
||||
|
||||
void ObjectManager::DisconnectClient(std::shared_ptr<TcpClientConnection> &conn,
|
||||
const uint8_t *message) {
|
||||
auto info =
|
||||
flatbuffers::GetRoot<object_manager_protocol::DisconnectClientMessage>(message);
|
||||
ClientID client_id = ObjectID::from_binary(info->client_id()->str());
|
||||
bool is_transfer = info->is_transfer();
|
||||
if (is_transfer) {
|
||||
connection_pool_.RemoveReceiver(ConnectionPool::ConnectionType::TRANSFER, client_id,
|
||||
conn);
|
||||
} else {
|
||||
connection_pool_.RemoveReceiver(ConnectionPool::ConnectionType::MESSAGE, client_id,
|
||||
conn);
|
||||
}
|
||||
}
|
||||
|
||||
void ObjectManager::ReceivePullRequest(std::shared_ptr<TcpClientConnection> &conn,
|
||||
const uint8_t *message) {
|
||||
// Serialize and push object to requesting client.
|
||||
auto pr = flatbuffers::GetRoot<object_manager_protocol::PullRequestMessage>(message);
|
||||
ObjectID object_id = ObjectID::from_binary(pr->object_id()->str());
|
||||
ClientID client_id = ClientID::from_binary(pr->client_id()->str());
|
||||
ray::Status push_status = Push(object_id, client_id);
|
||||
conn->ProcessMessages();
|
||||
}
|
||||
|
||||
void ObjectManager::ReceivePushRequest(std::shared_ptr<TcpClientConnection> conn,
|
||||
const uint8_t *message) {
|
||||
// Serialize.
|
||||
auto object_header = flatbuffers::GetRoot<ObjectHeader>(message.data());
|
||||
auto object_header =
|
||||
flatbuffers::GetRoot<object_manager_protocol::PushRequestMessage>(message);
|
||||
ObjectID object_id = ObjectID::from_binary(object_header->object_id()->str());
|
||||
int64_t object_size = (int64_t)object_header->object_size();
|
||||
transfer_queue_.QueueReceive(conn->GetClientID(), object_id, object_size, conn);
|
||||
RAY_CHECK_OK(DequeueTransfers());
|
||||
}
|
||||
|
||||
ray::Status ObjectManager::ExecuteReceiveObject(
|
||||
const ClientID &client_id, const ObjectID &object_id, uint64_t object_size,
|
||||
std::shared_ptr<TcpClientConnection> conn) {
|
||||
boost::system::error_code ec;
|
||||
int64_t metadata_size = 0;
|
||||
const plasma::ObjectID plasma_id = ObjectID(object_id).to_plasma_id();
|
||||
// Try to create shared buffer.
|
||||
std::shared_ptr<Buffer> data;
|
||||
arrow::Status s = store_client_->GetClient().Create(
|
||||
object_id.to_plasma_id(), object_size, NULL, metadata_size, &data);
|
||||
std::shared_ptr<plasma::PlasmaClient> store_client = store_pool_.GetObjectStore();
|
||||
arrow::Status s =
|
||||
store_client->Create(plasma_id, object_size, NULL, metadata_size, &data);
|
||||
std::vector<boost::asio::mutable_buffer> buffer;
|
||||
if (s.ok()) {
|
||||
// Read object into store.
|
||||
uint8_t *mutable_data = data->mutable_data();
|
||||
boost::asio::read(conn->GetSocket(), boost::asio::buffer(mutable_data, object_size),
|
||||
ec);
|
||||
buffer.push_back(asio::buffer(mutable_data, object_size));
|
||||
conn->ReadBuffer(buffer, ec);
|
||||
if (!ec.value()) {
|
||||
ARROW_CHECK_OK(store_client_->GetClient().Seal(object_id.to_plasma_id()));
|
||||
ARROW_CHECK_OK(store_client_->GetClient().Release(object_id.to_plasma_id()));
|
||||
ARROW_CHECK_OK(store_client->Seal(plasma_id));
|
||||
ARROW_CHECK_OK(store_client->Release(plasma_id));
|
||||
} else {
|
||||
ARROW_CHECK_OK(store_client_->GetClient().Release(object_id.to_plasma_id()));
|
||||
ARROW_CHECK_OK(store_client_->GetClient().Abort(object_id.to_plasma_id()));
|
||||
ARROW_CHECK_OK(store_client->Release(plasma_id));
|
||||
ARROW_CHECK_OK(store_client->Abort(plasma_id));
|
||||
RAY_LOG(ERROR) << "Receive Failed";
|
||||
}
|
||||
} else {
|
||||
RAY_LOG(ERROR) << "Buffer Create Failed: " << s.message();
|
||||
// Read object into empty buffer.
|
||||
uint8_t *mutable_data = (uint8_t *)malloc(object_size + metadata_size);
|
||||
boost::asio::read(conn->GetSocket(), boost::asio::buffer(mutable_data, object_size),
|
||||
ec);
|
||||
std::vector<uint8_t> mutable_data;
|
||||
mutable_data.resize(object_size + metadata_size);
|
||||
buffer.push_back(asio::buffer(mutable_data, object_size + metadata_size));
|
||||
conn->ReadBuffer(buffer, ec);
|
||||
}
|
||||
// Wait for another push.
|
||||
ray::Status ray_status = WaitPushReceive(conn);
|
||||
};
|
||||
|
||||
ray::Status ObjectManager::QueuePush(const ObjectID &object_id_const,
|
||||
SenderConnection::pointer conn) {
|
||||
ObjectID object_id = ObjectID(object_id_const);
|
||||
if (conn->ObjectIdQueued(object_id)) {
|
||||
// For now, return with status OK if the object is already in the send queue.
|
||||
return ray::Status::OK();
|
||||
}
|
||||
conn->QueueObjectId(object_id);
|
||||
if (num_transfers_ < max_transfers_) {
|
||||
return ExecutePushQueue(conn);
|
||||
}
|
||||
return ray::Status::OK();
|
||||
};
|
||||
|
||||
ray::Status ObjectManager::ExecutePushQueue(SenderConnection::pointer conn) {
|
||||
ray::Status status = ray::Status::OK();
|
||||
while (num_transfers_ < max_transfers_) {
|
||||
if (conn->IsObjectIdQueueEmpty()) {
|
||||
return ray::Status::OK();
|
||||
}
|
||||
ObjectID object_id = conn->DequeueObjectId();
|
||||
// The threads that increment/decrement num_transfers_ are different.
|
||||
// It's important to increment num_transfers_ before executing the push.
|
||||
num_transfers_ += 1;
|
||||
status = ExecutePushHeaders(object_id, conn);
|
||||
}
|
||||
return status;
|
||||
};
|
||||
|
||||
ray::Status ObjectManager::ExecutePushHeaders(const ObjectID &object_id_const,
|
||||
SenderConnection::pointer conn) {
|
||||
ObjectID object_id = ObjectID(object_id_const);
|
||||
// Allocate and append the request to the transfer queue.
|
||||
plasma::ObjectBuffer object_buffer;
|
||||
plasma::ObjectID plasma_id = object_id.to_plasma_id();
|
||||
ARROW_CHECK_OK(store_client_->GetClientOther().Get(&plasma_id, 1, 0, &object_buffer));
|
||||
if (object_buffer.data_size == -1) {
|
||||
RAY_LOG(ERROR) << "Failed to get object";
|
||||
// If the object wasn't locally available, exit immediately. If the object
|
||||
// later appears locally, the requesting plasma manager should request the
|
||||
// transfer again.
|
||||
return ray::Status::IOError(
|
||||
"Unable to transfer object to requesting plasma manager, object not local.");
|
||||
}
|
||||
RAY_CHECK(object_buffer.metadata->data() ==
|
||||
object_buffer.data->data() + object_buffer.data_size);
|
||||
SendRequest send_request;
|
||||
send_request.object_id = object_id;
|
||||
send_request.object_size = object_buffer.data_size;
|
||||
send_request.data = const_cast<uint8_t *>(object_buffer.data->data());
|
||||
conn->AddSendRequest(object_id, send_request);
|
||||
// Create buffer.
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = CreateObjectHeader(fbb, fbb.CreateString(object_id.binary()),
|
||||
send_request.object_size);
|
||||
fbb.Finish(message);
|
||||
// Pack into asio buffer.
|
||||
size_t length = fbb.GetSize();
|
||||
std::vector<boost::asio::const_buffer> buffer;
|
||||
buffer.push_back(boost::asio::buffer(&length, sizeof(length)));
|
||||
buffer.push_back(boost::asio::buffer(fbb.GetBufferPointer(), length));
|
||||
// Send asynchronously.
|
||||
boost::asio::async_write(conn->GetSocket(), buffer,
|
||||
boost::bind(&ObjectManager::ExecutePushObject, this, conn,
|
||||
object_id, boost::asio::placeholders::error));
|
||||
return ray::Status::OK();
|
||||
};
|
||||
|
||||
void ObjectManager::ExecutePushObject(SenderConnection::pointer conn,
|
||||
const ObjectID &object_id,
|
||||
const boost::system::error_code &header_ec) {
|
||||
SendRequest &send_request = conn->GetSendRequest(object_id);
|
||||
boost::system::error_code ec;
|
||||
boost::asio::write(
|
||||
conn->GetSocket(),
|
||||
boost::asio::buffer(send_request.data, (size_t)send_request.object_size), ec);
|
||||
// Do this regardless of whether it failed or succeeded.
|
||||
ARROW_CHECK_OK(
|
||||
store_client_->GetClientOther().Release(send_request.object_id.to_plasma_id()));
|
||||
|
||||
ray::Status ray_status = ExecutePushCompleted(object_id, conn);
|
||||
}
|
||||
|
||||
ray::Status ObjectManager::ExecutePushCompleted(const ObjectID &object_id,
|
||||
SenderConnection::pointer conn) {
|
||||
conn->RemoveSendRequest(object_id);
|
||||
num_transfers_ -= 1;
|
||||
return ExecutePushQueue(conn);
|
||||
};
|
||||
|
||||
ray::Status ObjectManager::WaitMessage(TCPClientConnection::pointer conn) {
|
||||
boost::asio::async_read(
|
||||
conn->GetSocket(),
|
||||
boost::asio::buffer(&conn->message_type_, sizeof(conn->message_type_)),
|
||||
boost::bind(&ObjectManager::HandleMessage, this, conn,
|
||||
boost::asio::placeholders::error));
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
void ObjectManager::HandleMessage(TCPClientConnection::pointer conn, BoostEC msg_ec) {
|
||||
switch (conn->message_type_) {
|
||||
case OMMessageType_PullRequest:
|
||||
ReceivePullRequest(conn);
|
||||
}
|
||||
}
|
||||
|
||||
void ObjectManager::ReceivePullRequest(TCPClientConnection::pointer conn) {
|
||||
boost::asio::read(
|
||||
conn->GetSocket(),
|
||||
boost::asio::buffer(&conn->message_length_, sizeof(conn->message_length_)));
|
||||
std::vector<uint8_t> message;
|
||||
message.resize(conn->message_length_);
|
||||
boost::system::error_code error_code;
|
||||
boost::asio::read(conn->GetSocket(), boost::asio::buffer(message), error_code);
|
||||
// Serialize.
|
||||
auto pull_request = flatbuffers::GetRoot<PullRequest>(message.data());
|
||||
ObjectID object_id = ObjectID::from_binary(pull_request->object_id()->str());
|
||||
ClientID client_id = ClientID::from_binary(pull_request->client_id()->str());
|
||||
// Push object to requesting client.
|
||||
ray::Status push_status = Push(object_id, client_id);
|
||||
ray::Status wait_status = WaitMessage(conn);
|
||||
store_pool_.ReleaseObjectStore(store_client);
|
||||
conn->ProcessMessages();
|
||||
RAY_LOG(DEBUG) << "ReceiveCompleted " << client_id_ << " " << object_id << " "
|
||||
<< num_transfers_receive_ << "/" << config_.max_receives;
|
||||
RAY_CHECK_OK(TransferCompleted(TransferQueue::TransferType::RECEIVE));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -12,58 +12,65 @@
|
||||
#include <boost/asio/error.hpp>
|
||||
#include <boost/bind.hpp>
|
||||
|
||||
#include "ray/common/client_connection.h"
|
||||
#include "ray/id.h"
|
||||
#include "ray/status.h"
|
||||
|
||||
#include "plasma/client.h"
|
||||
#include "plasma/events.h"
|
||||
#include "plasma/plasma.h"
|
||||
|
||||
#include "format/object_manager_generated.h"
|
||||
#include "object_directory.h"
|
||||
#include "object_manager_client_connection.h"
|
||||
#include "object_store_client.h"
|
||||
#include "ray/id.h"
|
||||
#include "ray/status.h"
|
||||
#include "ray/object_manager/connection_pool.h"
|
||||
#include "ray/object_manager/format/object_manager_generated.h"
|
||||
#include "ray/object_manager/object_directory.h"
|
||||
#include "ray/object_manager/object_manager_client_connection.h"
|
||||
#include "ray/object_manager/object_store_client_pool.h"
|
||||
#include "ray/object_manager/object_store_notification_manager.h"
|
||||
#include "ray/object_manager/transfer_queue.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
struct ObjectManagerConfig {
|
||||
// The time in milliseconds to wait before retrying a pull
|
||||
// that failed due to client id lookup.
|
||||
/// The time in milliseconds to wait before retrying a pull
|
||||
/// that failed due to client id lookup.
|
||||
int pull_timeout_ms = 100;
|
||||
/// Size of thread pool.
|
||||
int num_threads = 2;
|
||||
/// Maximum number of sends allowed.
|
||||
int max_sends = 20;
|
||||
/// Maximum number of receives allowed.
|
||||
int max_receives = 20;
|
||||
// TODO(hme): Implement num retries (to avoid infinite retries).
|
||||
std::string store_socket_name;
|
||||
};
|
||||
|
||||
// TODO(hme): Comment everything doxygen-style.
|
||||
// TODO(hme): Implement connection cleanup.
|
||||
// TODO(hme): Add success/failure callbacks for push and pull.
|
||||
// TODO(hme): Use boost thread pool.
|
||||
// TODO(hme): Add incoming connections to io_service tied to thread pool.
|
||||
class ObjectManager {
|
||||
public:
|
||||
/// Implicitly instantiates Ray implementation of ObjectDirectory.
|
||||
///
|
||||
/// \param io_service The asio io_service tied to the object manager.
|
||||
/// \param main_service The main asio io_service.
|
||||
/// \param object_manager_service The asio io_service tied to the object manager.
|
||||
/// \param config ObjectManager configuration.
|
||||
/// \param gcs_client A client connection to the Ray GCS.
|
||||
explicit ObjectManager(boost::asio::io_service &io_service, ObjectManagerConfig config,
|
||||
std::shared_ptr<ray::GcsClient> gcs_client);
|
||||
explicit ObjectManager(boost::asio::io_service &main_service,
|
||||
std::unique_ptr<boost::asio::io_service> object_manager_service,
|
||||
const ObjectManagerConfig &config,
|
||||
std::shared_ptr<gcs::AsyncGcsClient> gcs_client);
|
||||
|
||||
/// Takes user-defined ObjectDirectoryInterface implementation.
|
||||
/// When this constructor is used, the ObjectManager assumes ownership of
|
||||
/// the given ObjectDirectory instance.
|
||||
///
|
||||
/// \param io_service The asio io_service tied to the object manager.
|
||||
/// \param main_service The main asio io_service.
|
||||
/// \param object_manager_service The asio io_service tied to the object manager.
|
||||
/// \param config ObjectManager configuration.
|
||||
/// \param od An object implementing the object directory interface.
|
||||
explicit ObjectManager(boost::asio::io_service &io_service, ObjectManagerConfig config,
|
||||
explicit ObjectManager(boost::asio::io_service &main_service,
|
||||
std::unique_ptr<boost::asio::io_service> object_manager_service,
|
||||
const ObjectManagerConfig &config,
|
||||
std::unique_ptr<ObjectDirectoryInterface> od);
|
||||
|
||||
/// \param client_id Set the client id associated with this node.
|
||||
void SetClientID(const ClientID &client_id);
|
||||
|
||||
/// \return Get the client id associated with this node.
|
||||
ClientID GetClientID();
|
||||
|
||||
/// Subscribe to notifications of objects added to local store.
|
||||
/// Upon subscribing, the callback will be invoked for all objects that
|
||||
///
|
||||
@@ -106,7 +113,17 @@ class ObjectManager {
|
||||
///
|
||||
/// \param conn The connection.
|
||||
/// \return Status of whether the connection was successfully established.
|
||||
ray::Status AcceptConnection(TCPClientConnection::pointer conn);
|
||||
void ProcessNewClient(std::shared_ptr<TcpClientConnection> conn);
|
||||
|
||||
/// Process messages sent from other nodes. We only establish
|
||||
/// transfer connections using this method; all other transfer communication
|
||||
/// is done separately.
|
||||
///
|
||||
/// \param conn The connection.
|
||||
/// \param message_type The message type.
|
||||
/// \param message A pointer set to the beginning of the message.
|
||||
void ProcessClientMessage(std::shared_ptr<TcpClientConnection> conn,
|
||||
int64_t message_type, const uint8_t *message);
|
||||
|
||||
/// Cancels all requests (Push/Pull) associated with the given ObjectID.
|
||||
///
|
||||
@@ -114,7 +131,7 @@ class ObjectManager {
|
||||
/// \return Status of whether requests were successfully cancelled.
|
||||
ray::Status Cancel(const ObjectID &object_id);
|
||||
|
||||
// Callback definition for wait.
|
||||
/// Callback definition for wait.
|
||||
using WaitCallback = std::function<void(const ray::Status, uint64_t,
|
||||
const std::vector<ray::ObjectID> &)>;
|
||||
/// Wait for timeout_ms before invoking the provided callback.
|
||||
@@ -135,132 +152,137 @@ class ObjectManager {
|
||||
ray::Status Terminate();
|
||||
|
||||
private:
|
||||
using BoostEC = const boost::system::error_code &;
|
||||
|
||||
ClientID client_id_;
|
||||
ObjectManagerConfig config_;
|
||||
std::unique_ptr<ObjectDirectoryInterface> object_directory_;
|
||||
std::unique_ptr<ObjectStoreClient> store_client_;
|
||||
ObjectStoreNotificationManager store_notification_;
|
||||
ObjectStoreClientPool store_pool_;
|
||||
|
||||
/// An io service for creating connections to other object managers.
|
||||
boost::asio::io_service io_service_;
|
||||
/// This runs on a thread pool.
|
||||
std::unique_ptr<boost::asio::io_service> object_manager_service_;
|
||||
/// Weak reference to main service. We ensure this object is destroyed before
|
||||
/// main_service_ is stopped.
|
||||
boost::asio::io_service *main_service_;
|
||||
|
||||
/// Used to create "work" for an io service, so when it's run, it doesn't exit.
|
||||
boost::asio::io_service::work work_;
|
||||
|
||||
/// Single thread for executing asynchronous handlers.
|
||||
/// This runs the (currently only) io_service, which handles all outgoing requests
|
||||
/// and object transfers (push).
|
||||
std::thread io_thread_;
|
||||
/// Thread pool for executing asynchronous handlers.
|
||||
/// These run the object_manager_service_, which handle
|
||||
/// all incoming and outgoing object transfers.
|
||||
std::vector<std::thread> io_threads_;
|
||||
|
||||
/// Relatively simple way to add thread pooling.
|
||||
/// boost::thread_group thread_group_;
|
||||
/// Connection pool for reusing outgoing connections to remote object managers.
|
||||
ConnectionPool connection_pool_;
|
||||
|
||||
/// Timeout for failed pull requests.
|
||||
using Timer = std::shared_ptr<boost::asio::deadline_timer>;
|
||||
std::unordered_map<ObjectID, Timer, UniqueIDHasher> pull_requests_;
|
||||
std::unordered_map<ObjectID, std::shared_ptr<boost::asio::deadline_timer>,
|
||||
UniqueIDHasher>
|
||||
pull_requests_;
|
||||
|
||||
// TODO (hme): This needs to account for receives as well.
|
||||
/// This number is incremented whenever a push is started.
|
||||
int num_transfers_ = 0;
|
||||
// TODO (hme): Allow for concurrent sends.
|
||||
/// This is the maximum number of pushes allowed.
|
||||
/// We can only increase this number if we increase the number of
|
||||
/// plasma client connections.
|
||||
int max_transfers_ = 1;
|
||||
/// Allows control of concurrent object transfers. This is a global queue,
|
||||
/// allowing for concurrent transfers with many object managers as well as
|
||||
/// concurrent transfers, including both sends and receives, with a single
|
||||
/// remote object manager.
|
||||
TransferQueue transfer_queue_;
|
||||
|
||||
/// Note that (currently) receives take place on the main thread,
|
||||
/// and sends take place on a dedicated thread.
|
||||
std::unordered_map<ray::ClientID, SenderConnection::pointer, ray::UniqueIDHasher>
|
||||
message_send_connections_;
|
||||
std::unordered_map<ray::ClientID, SenderConnection::pointer, ray::UniqueIDHasher>
|
||||
transfer_send_connections_;
|
||||
/// Variables to track number of concurrent sends and receives.
|
||||
std::atomic<int> num_transfers_send_;
|
||||
std::atomic<int> num_transfers_receive_;
|
||||
|
||||
std::unordered_map<ray::ClientID, TCPClientConnection::pointer, ray::UniqueIDHasher>
|
||||
message_receive_connections_;
|
||||
std::unordered_map<ray::ClientID, TCPClientConnection::pointer, ray::UniqueIDHasher>
|
||||
transfer_receive_connections_;
|
||||
/// Cache of locally available objects.
|
||||
std::unordered_set<ObjectID, UniqueIDHasher> local_objects_;
|
||||
|
||||
/// Handle starting, running, and stopping asio io_service.
|
||||
void StartIOService();
|
||||
void IOServiceLoop();
|
||||
void StopIOService();
|
||||
|
||||
/// Register object add with directory.
|
||||
void NotifyDirectoryObjectAdd(const ObjectID &object_id);
|
||||
|
||||
/// Register object remove with directory.
|
||||
void NotifyDirectoryObjectDeleted(const ObjectID &object_id);
|
||||
|
||||
/// Wait wait_ms milliseconds before triggering a pull request for object_id.
|
||||
/// This is invoked when a pull fails. Only point of failure currently considered
|
||||
/// is GetLocationsFailed.
|
||||
void SchedulePull(const ObjectID &object_id, int wait_ms);
|
||||
|
||||
/// The handler for SchedulePull. Invokes a pull and removes the deadline timer
|
||||
/// that was added to schedule the pull.
|
||||
ray::Status SchedulePullHandler(const ObjectID &object_id);
|
||||
/// Part of an asynchronous sequence of Pull methods.
|
||||
/// Gets the location of an object before invoking PullEstablishConnection.
|
||||
/// Guaranteed to execute on main_service_ thread.
|
||||
/// Executes on main_service_ thread.
|
||||
ray::Status PullGetLocations(const ObjectID &object_id);
|
||||
|
||||
/// Synchronously send a pull request.
|
||||
/// Invoked once a connection to a remote manager that contains the required ObjectID
|
||||
/// is established.
|
||||
ray::Status ExecutePull(const ObjectID &object_id, SenderConnection::pointer conn);
|
||||
/// Part of an asynchronous sequence of Pull methods.
|
||||
/// Uses an existing connection or creates a connection to ClientID.
|
||||
/// Executes on main_service_ thread.
|
||||
ray::Status PullEstablishConnection(const ObjectID &object_id,
|
||||
const ClientID &client_id);
|
||||
|
||||
/// Invoked once a connection to the remote manager to which the ObjectID
|
||||
/// is to be sent is established.
|
||||
ray::Status QueuePush(const ObjectID &object_id, SenderConnection::pointer client);
|
||||
/// Starts as many queued pushes as possible without exceeding max_transfers_
|
||||
/// concurrent transfers.
|
||||
ray::Status ExecutePushQueue(SenderConnection::pointer client);
|
||||
/// Initiate a push. This method asynchronously sends the object id and object size
|
||||
/// Private callback implementation for success on get location. Called from
|
||||
/// ObjectDirectory.
|
||||
void GetLocationsSuccess(const std::vector<ray::ClientID> &client_ids,
|
||||
const ray::ObjectID &object_id);
|
||||
|
||||
/// Private callback implementation for failure on get location. Called from
|
||||
/// ObjectDirectory.
|
||||
void GetLocationsFailed(const ObjectID &object_id);
|
||||
|
||||
/// Synchronously send a pull request via remote object manager connection.
|
||||
/// Executes on main_service_ thread.
|
||||
ray::Status PullSendRequest(const ObjectID &object_id,
|
||||
std::shared_ptr<SenderConnection> conn);
|
||||
|
||||
/// Starts as many queued sends and receives as possible without exceeding
|
||||
/// config_.max_sends and config_.max_receives, respectively.
|
||||
/// Executes on object_manager_service_ thread pool.
|
||||
ray::Status DequeueTransfers();
|
||||
|
||||
std::shared_ptr<SenderConnection> CreateSenderConnection(
|
||||
ConnectionPool::ConnectionType type, RemoteConnectionInfo info);
|
||||
|
||||
/// Invoked when a transfer is completed. Invokes DequeueTransfers after
|
||||
/// updating variables that track concurrent transfers.
|
||||
/// Executes on object_manager_service_ thread pool.
|
||||
ray::Status TransferCompleted(TransferQueue::TransferType type);
|
||||
|
||||
/// Begin executing a send.
|
||||
/// Executes on object_manager_service_ thread pool.
|
||||
ray::Status ExecuteSendObject(const ObjectID &object_id, const ClientID &client_id,
|
||||
const RemoteConnectionInfo &connection_info);
|
||||
/// This method synchronously sends the object id and object size
|
||||
/// to the remote object manager.
|
||||
ray::Status ExecutePushHeaders(const ObjectID &object_id,
|
||||
SenderConnection::pointer client);
|
||||
/// Called by the handler for ExecutePushMeta.
|
||||
/// Executes on object_manager_service_ thread pool.
|
||||
ray::Status SendObjectHeaders(const ObjectID &object_id,
|
||||
std::shared_ptr<SenderConnection> client);
|
||||
|
||||
/// This method initiates the actual object transfer.
|
||||
void ExecutePushObject(SenderConnection::pointer conn, const ObjectID &object_id,
|
||||
const boost::system::error_code &header_ec);
|
||||
/// Invoked when a push is completed. This method will decrement num_transfers_
|
||||
/// and invoke ExecutePushQueue.
|
||||
ray::Status ExecutePushCompleted(const ObjectID &object_id,
|
||||
SenderConnection::pointer client);
|
||||
/// Executes on object_manager_service_ thread pool.
|
||||
ray::Status SendObjectData(std::shared_ptr<SenderConnection> conn,
|
||||
const UniqueID &context_id,
|
||||
std::shared_ptr<plasma::PlasmaClient> store_client);
|
||||
|
||||
/// Private callback implementation for success on get location. Called inside OD.
|
||||
void GetLocationsSuccess(const std::vector<RemoteConnectionInfo> &vec,
|
||||
const ObjectID &object_id);
|
||||
|
||||
/// Private callback implementation for failure on get location. Called inside OD.
|
||||
void GetLocationsFailed(ray::Status status, const ObjectID &object_id);
|
||||
|
||||
/// Asynchronously obtain a connection to client_id.
|
||||
/// If a connection to client_id already exists, the callback is invoked immediately.
|
||||
ray::Status GetMsgConnection(const ClientID &client_id,
|
||||
std::function<void(SenderConnection::pointer)> callback);
|
||||
/// Asynchronously create a connection to client_id.
|
||||
ray::Status CreateMsgConnection(
|
||||
const RemoteConnectionInfo &info,
|
||||
std::function<void(SenderConnection::pointer)> callback);
|
||||
/// Asynchronously create a connection to client_id.
|
||||
ray::Status GetTransferConnection(
|
||||
const ClientID &client_id, std::function<void(SenderConnection::pointer)> callback);
|
||||
/// Asynchronously obtain a connection to client_id.
|
||||
/// If a connection to client_id already exists, the callback is invoked immediately.
|
||||
ray::Status CreateTransferConnection(
|
||||
const RemoteConnectionInfo &info,
|
||||
std::function<void(SenderConnection::pointer)> callback);
|
||||
|
||||
/// A socket connection doing an asynchronous read on a transfer connection that was
|
||||
/// added by AcceptConnection.
|
||||
ray::Status WaitPushReceive(TCPClientConnection::pointer conn);
|
||||
/// Invoked when a remote object manager pushes an object to this object manager.
|
||||
void HandlePushReceive(TCPClientConnection::pointer conn, BoostEC length_ec);
|
||||
/// This will queue the receive.
|
||||
void ReceivePushRequest(std::shared_ptr<TcpClientConnection> conn,
|
||||
const uint8_t *message);
|
||||
/// Execute a receive that was in the queue.
|
||||
ray::Status ExecuteReceiveObject(const ClientID &client_id, const ObjectID &object_id,
|
||||
uint64_t object_size,
|
||||
std::shared_ptr<TcpClientConnection> conn);
|
||||
|
||||
/// A socket connection doing an asynchronous read on a message connection that was
|
||||
/// added by AcceptConnection.
|
||||
ray::Status WaitMessage(TCPClientConnection::pointer conn);
|
||||
/// Handle messages.
|
||||
void HandleMessage(TCPClientConnection::pointer conn, BoostEC msg_ec);
|
||||
/// Process the receive pull request message.
|
||||
void ReceivePullRequest(TCPClientConnection::pointer conn);
|
||||
/// Handles receiving a pull request message.
|
||||
void ReceivePullRequest(std::shared_ptr<TcpClientConnection> &conn,
|
||||
const uint8_t *message);
|
||||
|
||||
/// Register object add with directory.
|
||||
void NotifyDirectoryObjectAdd(const ObjectID &object_id);
|
||||
/// Register object remove with directory.
|
||||
void NotifyDirectoryObjectDeleted(const ObjectID &object_id);
|
||||
/// Handles connect message of a new client connection.
|
||||
void ConnectClient(std::shared_ptr<TcpClientConnection> &conn, const uint8_t *message);
|
||||
/// Handles disconnect message of an existing client connection.
|
||||
void DisconnectClient(std::shared_ptr<TcpClientConnection> &conn,
|
||||
const uint8_t *message);
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -1,59 +1,24 @@
|
||||
#include "object_manager_client_connection.h"
|
||||
#include "ray/object_manager/object_manager_client_connection.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
SenderConnection::pointer SenderConnection::Create(boost::asio::io_service &io_service,
|
||||
const std::string &ip, uint16_t port) {
|
||||
return pointer(new SenderConnection(io_service, ip, port));
|
||||
uint64_t SenderConnection::id_counter_;
|
||||
|
||||
std::shared_ptr<SenderConnection> SenderConnection::Create(
|
||||
boost::asio::io_service &io_service, const ClientID &client_id, const std::string &ip,
|
||||
uint16_t port) {
|
||||
boost::asio::ip::tcp::socket socket(io_service);
|
||||
RAY_CHECK_OK(TcpConnect(socket, ip, port));
|
||||
std::shared_ptr<TcpServerConnection> conn =
|
||||
std::make_shared<TcpServerConnection>(std::move(socket));
|
||||
return std::make_shared<SenderConnection>(conn, client_id);
|
||||
};
|
||||
|
||||
SenderConnection::SenderConnection(boost::asio::io_service &io_service,
|
||||
const std::string &ip, uint16_t port)
|
||||
: socket_(io_service), send_queue_() {
|
||||
boost::asio::ip::address ip_address = boost::asio::ip::address::from_string(ip);
|
||||
boost::asio::ip::tcp::endpoint endpoint(ip_address, port);
|
||||
socket_.connect(endpoint);
|
||||
SenderConnection::SenderConnection(std::shared_ptr<TcpServerConnection> conn,
|
||||
const ClientID &client_id)
|
||||
: conn_(conn) {
|
||||
client_id_ = client_id;
|
||||
connection_id_ = SenderConnection::id_counter_++;
|
||||
};
|
||||
|
||||
boost::asio::ip::tcp::socket &SenderConnection::GetSocket() { return socket_; };
|
||||
|
||||
bool SenderConnection::IsObjectIdQueueEmpty() { return send_queue_.empty(); }
|
||||
|
||||
bool SenderConnection::ObjectIdQueued(const ObjectID &object_id) {
|
||||
return std::find(send_queue_.begin(), send_queue_.end(), object_id) !=
|
||||
send_queue_.end();
|
||||
}
|
||||
|
||||
void SenderConnection::QueueObjectId(const ObjectID &object_id) {
|
||||
send_queue_.push_back(ObjectID(object_id));
|
||||
}
|
||||
|
||||
ObjectID SenderConnection::DequeueObjectId() {
|
||||
ObjectID object_id = send_queue_.front();
|
||||
send_queue_.pop_front();
|
||||
return object_id;
|
||||
}
|
||||
|
||||
void SenderConnection::AddSendRequest(const ObjectID &object_id,
|
||||
SendRequest &send_request) {
|
||||
send_requests_.emplace(object_id, send_request);
|
||||
}
|
||||
|
||||
void SenderConnection::RemoveSendRequest(const ObjectID &object_id) {
|
||||
send_requests_.erase(object_id);
|
||||
}
|
||||
|
||||
SendRequest &SenderConnection::GetSendRequest(const ObjectID &object_id) {
|
||||
return send_requests_[object_id];
|
||||
};
|
||||
|
||||
TCPClientConnection::TCPClientConnection(boost::asio::io_service &io_service)
|
||||
: socket_(io_service) {}
|
||||
|
||||
TCPClientConnection::pointer TCPClientConnection::Create(
|
||||
boost::asio::io_service &io_service) {
|
||||
return TCPClientConnection::pointer(new TCPClientConnection(io_service));
|
||||
}
|
||||
|
||||
boost::asio::ip::tcp::socket &TCPClientConnection::GetSocket() { return socket_; }
|
||||
} // namespace ray
|
||||
|
||||
@@ -7,63 +7,72 @@
|
||||
|
||||
#include <boost/asio.hpp>
|
||||
#include <boost/asio/error.hpp>
|
||||
#include <boost/bind.hpp>
|
||||
#include <boost/enable_shared_from_this.hpp>
|
||||
|
||||
#include "common/state/ray_config.h"
|
||||
#include "ray/common/client_connection.h"
|
||||
#include "ray/id.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
struct SendRequest {
|
||||
ObjectID object_id;
|
||||
ClientID client_id;
|
||||
int64_t object_size;
|
||||
uint8_t *data;
|
||||
};
|
||||
|
||||
// TODO(hme): Document public API after integration with common connection.
|
||||
class SenderConnection : public boost::enable_shared_from_this<SenderConnection> {
|
||||
public:
|
||||
typedef boost::shared_ptr<SenderConnection> pointer;
|
||||
typedef std::unordered_map<ray::ObjectID, SendRequest, UniqueIDHasher> SendRequestsType;
|
||||
typedef std::deque<ray::ObjectID> SendQueueType;
|
||||
/// Create a connection for sending data to other object managers.
|
||||
///
|
||||
/// \param io_service The service to which the created socket should attach.
|
||||
/// \param client_id The ClientID of the remote node.
|
||||
/// \param ip The ip address of the remote node server.
|
||||
/// \param port The port of the remote node server.
|
||||
/// \return A connection to the remote object manager.
|
||||
static std::shared_ptr<SenderConnection> Create(boost::asio::io_service &io_service,
|
||||
const ClientID &client_id,
|
||||
const std::string &ip, uint16_t port);
|
||||
|
||||
static pointer Create(boost::asio::io_service &io_service, const std::string &ip,
|
||||
uint16_t port);
|
||||
/// \param socket A reference to the socket created by the static Create method.
|
||||
/// \param client_id The ClientID of the remote node.
|
||||
SenderConnection(std::shared_ptr<TcpServerConnection> conn, const ClientID &client_id);
|
||||
|
||||
explicit SenderConnection(boost::asio::io_service &io_service, const std::string &ip,
|
||||
uint16_t port);
|
||||
/// Write a message to the client.
|
||||
///
|
||||
/// \param type The message type (e.g., a flatbuffer enum).
|
||||
/// \param length The size in bytes of the message.
|
||||
/// \param message A pointer to the message buffer.
|
||||
/// \return Status.
|
||||
ray::Status WriteMessage(int64_t type, uint64_t length, const uint8_t *message) {
|
||||
return conn_->WriteMessage(type, length, message);
|
||||
}
|
||||
|
||||
boost::asio::ip::tcp::socket &GetSocket();
|
||||
/// Write a buffer to this connection.
|
||||
///
|
||||
/// \param buffer The buffer.
|
||||
/// \param ec The error code object in which to store error codes.
|
||||
void WriteBuffer(const std::vector<boost::asio::const_buffer> &buffer,
|
||||
boost::system::error_code &ec) {
|
||||
return conn_->WriteBuffer(buffer, ec);
|
||||
}
|
||||
|
||||
bool IsObjectIdQueueEmpty();
|
||||
bool ObjectIdQueued(const ObjectID &object_id);
|
||||
void QueueObjectId(const ObjectID &object_id);
|
||||
ObjectID DequeueObjectId();
|
||||
/// Read a buffer from this connection.
|
||||
///
|
||||
/// \param buffer The buffer.
|
||||
/// \param ec The error code object in which to store error codes.
|
||||
void ReadBuffer(const std::vector<boost::asio::mutable_buffer> &buffer,
|
||||
boost::system::error_code &ec) {
|
||||
return conn_->ReadBuffer(buffer, ec);
|
||||
}
|
||||
|
||||
void AddSendRequest(const ObjectID &object_id, SendRequest &send_request);
|
||||
void RemoveSendRequest(const ObjectID &object_id);
|
||||
SendRequest &GetSendRequest(const ObjectID &object_id);
|
||||
/// \return The ClientID of this connection.
|
||||
const ClientID &GetClientID() { return client_id_; }
|
||||
|
||||
private:
|
||||
boost::asio::ip::tcp::socket socket_;
|
||||
SendQueueType send_queue_;
|
||||
SendRequestsType send_requests_;
|
||||
};
|
||||
bool operator==(const SenderConnection &rhs) const {
|
||||
return connection_id_ == rhs.connection_id_;
|
||||
}
|
||||
|
||||
// TODO(hme): Document public API after integration with common connection.
|
||||
class TCPClientConnection : public boost::enable_shared_from_this<TCPClientConnection> {
|
||||
public:
|
||||
typedef boost::shared_ptr<TCPClientConnection> pointer;
|
||||
static pointer Create(boost::asio::io_service &io_service);
|
||||
boost::asio::ip::tcp::socket &GetSocket();
|
||||
|
||||
TCPClientConnection(boost::asio::io_service &io_service);
|
||||
|
||||
int64_t message_type_;
|
||||
uint64_t message_length_;
|
||||
|
||||
private:
|
||||
boost::asio::ip::tcp::socket socket_;
|
||||
static uint64_t id_counter_;
|
||||
uint64_t connection_id_;
|
||||
ClientID client_id_;
|
||||
std::shared_ptr<TcpServerConnection> conn_;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
// TODO(hme): Move all messaging code here.
|
||||
@@ -1,6 +0,0 @@
|
||||
#ifndef RAY_OBJECT_MANAGER_OBJECT_MANAGER_PROTOCOL_H
|
||||
#define RAY_OBJECT_MANAGER_OBJECT_MANAGER_PROTOCOL_H
|
||||
|
||||
// TODO(hme): Move all messaging code here.
|
||||
|
||||
#endif // RAY_OBJECT_MANAGER_OBJECT_MANAGER_PROTOCOL_H
|
||||
@@ -1,137 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <thread>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "plasma/client.h"
|
||||
#include "plasma/events.h"
|
||||
#include "plasma/plasma.h"
|
||||
#include "plasma/protocol.h"
|
||||
|
||||
#include "ray/status.h"
|
||||
|
||||
#include "ray/object_manager/object_manager.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
std::string test_executable; // NOLINT
|
||||
|
||||
class TestObjectManager : public ::testing::Test {
|
||||
public:
|
||||
TestObjectManager() { RAY_LOG(DEBUG) << "TestObjectManager: started."; }
|
||||
|
||||
void SetUp() {
|
||||
// start store
|
||||
std::string om_dir = test_executable.substr(0, test_executable.find_last_of("/"));
|
||||
std::string plasma_dir = om_dir + "./../plasma";
|
||||
std::string plasma_command =
|
||||
plasma_dir +
|
||||
"/plasma_store -m 1000000000 -s /tmp/store 1> /dev/null 2> /dev/null &";
|
||||
int s = system(plasma_command.c_str());
|
||||
ASSERT_TRUE(!s);
|
||||
|
||||
// Start mock global control store.
|
||||
mock_gcs_client_ = std::shared_ptr<GcsClient>(new GcsClient());
|
||||
// mock_gcs_client_->Register();
|
||||
|
||||
// Start node server.
|
||||
|
||||
// Start object manager 1.
|
||||
ObjectManagerConfig config;
|
||||
config.store_socket_name = "/tmp/store";
|
||||
object_manager_1_ = std::unique_ptr<ObjectManager>(
|
||||
new ObjectManager(io_service_, config, mock_gcs_client_));
|
||||
|
||||
// Start object manager 2.
|
||||
// ObjectManagerConfig config2;
|
||||
// config2.store_socket_name = "/tmp/store";
|
||||
// std::shared_ptr<ObjectDirectory> od2 = std::shared_ptr<ObjectDirectory>(new
|
||||
// ObjectDirectory());
|
||||
// od2->InitGcs(mock_gcs_client_);
|
||||
// object_manager_2_ = std::unique_ptr<ObjectManager>(new ObjectManager(io_service,
|
||||
// config2, od2));
|
||||
|
||||
// Initiate client connection.
|
||||
ARROW_CHECK_OK(client_.Connect("/tmp/store", "", PLASMA_DEFAULT_RELEASE_DELAY));
|
||||
|
||||
this->StartLoop();
|
||||
}
|
||||
|
||||
void TearDown() {
|
||||
this->StopLoop();
|
||||
arrow::Status arrow_status = client_.Disconnect();
|
||||
ASSERT_TRUE(arrow_status.ok());
|
||||
ray::Status ray_status = object_manager_1_->Terminate();
|
||||
ASSERT_TRUE(ray_status.ok());
|
||||
// object_manager_2_->Terminate();
|
||||
int s = system("killall plasma_store &");
|
||||
ASSERT_TRUE(!s);
|
||||
}
|
||||
|
||||
void Loop() { io_service_.run(); };
|
||||
|
||||
void StartLoop() { process_thread_ = std::thread(&TestObjectManager::Loop, this); };
|
||||
|
||||
void StopLoop() {
|
||||
io_service_.stop();
|
||||
process_thread_.join();
|
||||
}
|
||||
|
||||
protected:
|
||||
std::thread process_thread_;
|
||||
plasma::PlasmaClient client_;
|
||||
plasma::PlasmaClient client2_;
|
||||
boost::asio::io_service io_service_;
|
||||
|
||||
std::shared_ptr<GcsClient> mock_gcs_client_;
|
||||
std::unique_ptr<ObjectManager> object_manager_1_;
|
||||
std::unique_ptr<ObjectManager> object_manager_2_;
|
||||
};
|
||||
|
||||
// TODO: get rid of dead code?
|
||||
// TEST_F(TestObjectManager, TestPush) {
|
||||
// // test object push between two object managers.
|
||||
// ASSERT_TRUE(true);
|
||||
// sleep(1);
|
||||
//}
|
||||
|
||||
// TEST_F(TestObjectManager, TestPull) {
|
||||
// ObjectID object_id = ObjectID().from_random();
|
||||
// ClientID dbc_id = ClientID().from_random();
|
||||
// RAY_LOG(INFO) << "ObjectID: " << object_id.hex().c_str();
|
||||
// RAY_LOG(INFO) << "ClientID: " << dbc_id.hex().c_str();
|
||||
// om->Pull(object_id, dbc_id);
|
||||
// om->Pull(object_id);
|
||||
// ASSERT_TRUE(true);
|
||||
// sleep(1);
|
||||
//}
|
||||
|
||||
void ObjectAdded(const ObjectID &object_id) {
|
||||
RAY_LOG(INFO) << "ObjectID Added: " << object_id.hex().c_str();
|
||||
}
|
||||
|
||||
TEST_F(TestObjectManager, TestNotifications) {
|
||||
ray::Status status = object_manager_1_->SubscribeObjAdded(ObjectAdded);
|
||||
ASSERT_TRUE(status.ok());
|
||||
// put object
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
ObjectID object_id = ObjectID::from_random();
|
||||
RAY_LOG(INFO) << "ObjectID Created: " << object_id.hex().c_str();
|
||||
int64_t data_size = 100;
|
||||
uint8_t metadata[] = {5};
|
||||
int64_t metadata_size = sizeof(metadata);
|
||||
std::shared_ptr<Buffer> data;
|
||||
ARROW_CHECK_OK(client_.Create(object_id.to_plasma_id(), data_size, metadata,
|
||||
metadata_size, &data));
|
||||
ARROW_CHECK_OK(client_.Seal(object_id.to_plasma_id()));
|
||||
}
|
||||
// TODO(hme): Can we do this without sleeping?
|
||||
sleep(1);
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
ray::test_executable = std::string(argv[0]);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
@@ -1,95 +0,0 @@
|
||||
#include <future>
|
||||
#include <iostream>
|
||||
|
||||
#include <boost/asio.hpp>
|
||||
#include <boost/bind.hpp>
|
||||
#include <boost/function.hpp>
|
||||
|
||||
#include "common.h"
|
||||
#include "common_protocol.h"
|
||||
#include "ray/object_manager/object_store_client.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
// TODO(hme): Dedicate this class to notifications.
|
||||
// TODO(hme): Create object store client pool for object manager.
|
||||
ObjectStoreClient::ObjectStoreClient(boost::asio::io_service &io_service,
|
||||
std::string &store_socket_name)
|
||||
: client_one_(), client_two_(), socket_(io_service) {
|
||||
ARROW_CHECK_OK(
|
||||
client_two_.Connect(store_socket_name.c_str(), "", PLASMA_DEFAULT_RELEASE_DELAY));
|
||||
ARROW_CHECK_OK(
|
||||
client_one_.Connect(store_socket_name.c_str(), "", PLASMA_DEFAULT_RELEASE_DELAY));
|
||||
|
||||
// Connect to two clients, but subscribe to only one.
|
||||
ARROW_CHECK_OK(client_one_.Subscribe(&c_socket_));
|
||||
boost::system::error_code ec;
|
||||
socket_.assign(boost::asio::local::stream_protocol(), c_socket_, ec);
|
||||
assert(!ec.value());
|
||||
NotificationWait();
|
||||
};
|
||||
|
||||
void ObjectStoreClient::Terminate() {
|
||||
ARROW_CHECK_OK(client_two_.Disconnect());
|
||||
ARROW_CHECK_OK(client_one_.Disconnect());
|
||||
}
|
||||
|
||||
void ObjectStoreClient::NotificationWait() {
|
||||
boost::asio::async_read(socket_, boost::asio::buffer(&length_, sizeof(length_)),
|
||||
boost::bind(&ObjectStoreClient::ProcessStoreLength, this,
|
||||
boost::asio::placeholders::error));
|
||||
}
|
||||
|
||||
void ObjectStoreClient::ProcessStoreLength(const boost::system::error_code &error) {
|
||||
notification_.resize(length_);
|
||||
boost::asio::async_read(socket_, boost::asio::buffer(notification_),
|
||||
boost::bind(&ObjectStoreClient::ProcessStoreNotification, this,
|
||||
boost::asio::placeholders::error));
|
||||
}
|
||||
|
||||
void ObjectStoreClient::ProcessStoreNotification(const boost::system::error_code &error) {
|
||||
if (error) {
|
||||
throw std::runtime_error("ObjectStore may have died.");
|
||||
}
|
||||
|
||||
auto object_info = flatbuffers::GetRoot<ObjectInfo>(notification_.data());
|
||||
ObjectID object_id = from_flatbuf(*object_info->object_id());
|
||||
if (object_info->is_deletion()) {
|
||||
ProcessStoreRemove(object_id);
|
||||
} else {
|
||||
ProcessStoreAdd(object_id);
|
||||
// why all these params?
|
||||
// ProcessStoreAdd(
|
||||
// object_id, object_info->data_size(),
|
||||
// object_info->metadata_size(),
|
||||
// (unsigned char *) object_info->digest()->data());
|
||||
}
|
||||
NotificationWait();
|
||||
}
|
||||
|
||||
void ObjectStoreClient::ProcessStoreAdd(const ObjectID &object_id) {
|
||||
for (auto handler : add_handlers_) {
|
||||
handler(object_id);
|
||||
}
|
||||
};
|
||||
|
||||
void ObjectStoreClient::ProcessStoreRemove(const ObjectID &object_id) {
|
||||
for (auto handler : rem_handlers_) {
|
||||
handler(object_id);
|
||||
}
|
||||
};
|
||||
|
||||
void ObjectStoreClient::SubscribeObjAdded(
|
||||
std::function<void(const ObjectID &)> callback) {
|
||||
add_handlers_.push_back(callback);
|
||||
};
|
||||
|
||||
void ObjectStoreClient::SubscribeObjDeleted(
|
||||
std::function<void(const ObjectID &)> callback) {
|
||||
rem_handlers_.push_back(callback);
|
||||
};
|
||||
|
||||
plasma::PlasmaClient &ObjectStoreClient::GetClient() { return client_one_; };
|
||||
|
||||
plasma::PlasmaClient &ObjectStoreClient::GetClientOther() { return client_two_; };
|
||||
} // namespace ray
|
||||
@@ -0,0 +1,38 @@
|
||||
#include "object_store_client_pool.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
ObjectStoreClientPool::ObjectStoreClientPool(const std::string &store_socket_name)
|
||||
: store_socket_name_(store_socket_name) {}
|
||||
|
||||
std::shared_ptr<plasma::PlasmaClient> ObjectStoreClientPool::GetObjectStore() {
|
||||
std::lock_guard<std::mutex> lock(pool_mutex);
|
||||
if (available_clients.empty()) {
|
||||
Add();
|
||||
}
|
||||
std::shared_ptr<plasma::PlasmaClient> client = available_clients.back();
|
||||
available_clients.pop_back();
|
||||
return client;
|
||||
}
|
||||
|
||||
void ObjectStoreClientPool::ReleaseObjectStore(
|
||||
std::shared_ptr<plasma::PlasmaClient> client) {
|
||||
std::lock_guard<std::mutex> lock(pool_mutex);
|
||||
available_clients.push_back(client);
|
||||
}
|
||||
|
||||
void ObjectStoreClientPool::Terminate() {
|
||||
for (const auto &client : clients) {
|
||||
ARROW_CHECK_OK(client->Disconnect());
|
||||
}
|
||||
available_clients.clear();
|
||||
clients.clear();
|
||||
}
|
||||
|
||||
void ObjectStoreClientPool::Add() {
|
||||
clients.emplace_back(new plasma::PlasmaClient());
|
||||
ARROW_CHECK_OK(clients.back()->Connect(store_socket_name_.c_str(), "",
|
||||
PLASMA_DEFAULT_RELEASE_DELAY));
|
||||
available_clients.push_back(clients.back());
|
||||
}
|
||||
} // namespace ray
|
||||
@@ -0,0 +1,65 @@
|
||||
#ifndef RAY_OBJECT_MANAGER_OBJECT_STORE_CLIENT_POOL_H
|
||||
#define RAY_OBJECT_MANAGER_OBJECT_STORE_CLIENT_POOL_H
|
||||
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
|
||||
#include <boost/asio.hpp>
|
||||
#include <boost/asio/error.hpp>
|
||||
#include <boost/bind.hpp>
|
||||
|
||||
#include "plasma/client.h"
|
||||
#include "plasma/events.h"
|
||||
#include "plasma/plasma.h"
|
||||
|
||||
#include "object_directory.h"
|
||||
#include "ray/id.h"
|
||||
#include "ray/status.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
/// \class ObjectStoreClientPool
|
||||
///
|
||||
/// Provides connections to the object store. Enables concurrent communication with
|
||||
/// the object store.
|
||||
class ObjectStoreClientPool {
|
||||
public:
|
||||
/// Constructor.
|
||||
///
|
||||
/// \param store_socket_name The object store socket name.
|
||||
ObjectStoreClientPool(const std::string &store_socket_name);
|
||||
|
||||
/// This object cannot be copied due to pool_mutex.
|
||||
RAY_DISALLOW_COPY_AND_ASSIGN(ObjectStoreClientPool);
|
||||
|
||||
/// Provides a connection to the object store from the object store pool.
|
||||
/// This removes the object store client from the pool of available clients.
|
||||
///
|
||||
/// \return A connection to the object store.
|
||||
std::shared_ptr<plasma::PlasmaClient> GetObjectStore();
|
||||
|
||||
/// Releases a client object and puts it back into the object store pool
|
||||
/// for reuse.
|
||||
/// Once a client is released, it is assumed that it is not being used.
|
||||
/// \param client The client to return.
|
||||
/// \param client
|
||||
void ReleaseObjectStore(std::shared_ptr<plasma::PlasmaClient> client);
|
||||
|
||||
/// Terminates this object.
|
||||
void Terminate();
|
||||
|
||||
private:
|
||||
/// Adds a client to the client pool and mark it as available.
|
||||
void Add();
|
||||
|
||||
std::mutex pool_mutex;
|
||||
std::vector<std::shared_ptr<plasma::PlasmaClient>> available_clients;
|
||||
std::vector<std::shared_ptr<plasma::PlasmaClient>> clients;
|
||||
std::string store_socket_name_;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_OBJECT_MANAGER_OBJECT_STORE_CLIENT_POOL_H
|
||||
@@ -0,0 +1,90 @@
|
||||
#include <future>
|
||||
#include <iostream>
|
||||
|
||||
#include <boost/asio.hpp>
|
||||
#include <boost/bind.hpp>
|
||||
#include <boost/function.hpp>
|
||||
|
||||
#include "common/common.h"
|
||||
#include "common/common_protocol.h"
|
||||
|
||||
#include "ray/object_manager/object_store_notification_manager.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
ObjectStoreNotificationManager::ObjectStoreNotificationManager(
|
||||
boost::asio::io_service &io_service, const std::string &store_socket_name)
|
||||
: store_client_(), socket_(io_service) {
|
||||
ARROW_CHECK_OK(
|
||||
store_client_.Connect(store_socket_name.c_str(), "", PLASMA_DEFAULT_RELEASE_DELAY));
|
||||
|
||||
ARROW_CHECK_OK(store_client_.Subscribe(&c_socket_));
|
||||
boost::system::error_code ec;
|
||||
socket_.assign(boost::asio::local::stream_protocol(), c_socket_, ec);
|
||||
assert(!ec.value());
|
||||
NotificationWait();
|
||||
}
|
||||
|
||||
void ObjectStoreNotificationManager::Terminate() {
|
||||
ARROW_CHECK_OK(store_client_.Disconnect());
|
||||
}
|
||||
|
||||
void ObjectStoreNotificationManager::NotificationWait() {
|
||||
boost::asio::async_read(socket_, boost::asio::buffer(&length_, sizeof(length_)),
|
||||
boost::bind(&ObjectStoreNotificationManager::ProcessStoreLength,
|
||||
this, boost::asio::placeholders::error));
|
||||
}
|
||||
|
||||
void ObjectStoreNotificationManager::ProcessStoreLength(
|
||||
const boost::system::error_code &error) {
|
||||
notification_.resize(length_);
|
||||
boost::asio::async_read(
|
||||
socket_, boost::asio::buffer(notification_),
|
||||
boost::bind(&ObjectStoreNotificationManager::ProcessStoreNotification, this,
|
||||
boost::asio::placeholders::error));
|
||||
}
|
||||
|
||||
void ObjectStoreNotificationManager::ProcessStoreNotification(
|
||||
const boost::system::error_code &error) {
|
||||
if (error) {
|
||||
throw std::runtime_error("ObjectStore may have died.");
|
||||
}
|
||||
|
||||
const auto &object_info = flatbuffers::GetRoot<ObjectInfo>(notification_.data());
|
||||
const auto &object_id = from_flatbuf(*object_info->object_id());
|
||||
if (object_info->is_deletion()) {
|
||||
ProcessStoreRemove(object_id);
|
||||
} else {
|
||||
ProcessStoreAdd(object_id);
|
||||
// TODO(hme): Determine what data is actually needed by consumer of this notification.
|
||||
// ProcessStoreAdd(
|
||||
// object_id, object_info->data_size(),
|
||||
// object_info->metadata_size(),
|
||||
// (unsigned char *) object_info->digest()->data());
|
||||
}
|
||||
NotificationWait();
|
||||
}
|
||||
|
||||
void ObjectStoreNotificationManager::ProcessStoreAdd(const ObjectID &object_id) {
|
||||
for (auto handler : add_handlers_) {
|
||||
handler(object_id);
|
||||
}
|
||||
}
|
||||
|
||||
void ObjectStoreNotificationManager::ProcessStoreRemove(const ObjectID &object_id) {
|
||||
for (auto handler : rem_handlers_) {
|
||||
handler(object_id);
|
||||
}
|
||||
}
|
||||
|
||||
void ObjectStoreNotificationManager::SubscribeObjAdded(
|
||||
std::function<void(const ObjectID &)> callback) {
|
||||
add_handlers_.push_back(callback);
|
||||
}
|
||||
|
||||
void ObjectStoreNotificationManager::SubscribeObjDeleted(
|
||||
std::function<void(const ObjectID &)> callback) {
|
||||
rem_handlers_.push_back(callback);
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
+32
-27
@@ -13,53 +13,58 @@
|
||||
#include "plasma/events.h"
|
||||
#include "plasma/plasma.h"
|
||||
|
||||
#include "object_directory.h"
|
||||
#include "ray/id.h"
|
||||
#include "ray/status.h"
|
||||
|
||||
#include "ray/object_manager/object_directory.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
// TODO(hme): document public API after refactor.
|
||||
class ObjectStoreClient {
|
||||
/// \class ObjectStoreClientPool
|
||||
///
|
||||
/// Encapsulates notification handling from the object store.
|
||||
class ObjectStoreNotificationManager {
|
||||
public:
|
||||
// Encapsulates communication with the object store.
|
||||
ObjectStoreClient(boost::asio::io_service &io_service, std::string &store_socket_name);
|
||||
/// Constructor.
|
||||
///
|
||||
/// \param io_service The asio service to be used.
|
||||
/// \param store_socket_name The store socket to connect to.
|
||||
ObjectStoreNotificationManager(boost::asio::io_service &io_service,
|
||||
const std::string &store_socket_name);
|
||||
|
||||
// Subscribe to notifications of objects added to local store.
|
||||
// Upon subscribing, the callback will be invoked for all objects that
|
||||
// already exist in the local store.
|
||||
/// Subscribe to notifications of objects added to local store.
|
||||
/// Upon subscribing, the callback will be invoked for all objects that
|
||||
/// already exist in the local store
|
||||
///
|
||||
/// \param callback A callback expecting an ObjectID.
|
||||
void SubscribeObjAdded(std::function<void(const ray::ObjectID &)> callback);
|
||||
|
||||
// Subscribe to notifications of objects deleted from local store.
|
||||
/// Subscribe to notifications of objects deleted from local store.
|
||||
///
|
||||
/// \param callback A callback expecting an ObjectID.
|
||||
void SubscribeObjDeleted(std::function<void(const ray::ObjectID &)> callback);
|
||||
|
||||
// TODO(hme): There should be as many client connections as there are threads.
|
||||
// Two client connections are made to enable concurrent communication with the store.
|
||||
plasma::PlasmaClient &GetClient();
|
||||
plasma::PlasmaClient &GetClientOther();
|
||||
|
||||
// Terminate this object.
|
||||
/// Terminate this object.
|
||||
void Terminate();
|
||||
|
||||
private:
|
||||
std::vector<std::function<void(const ray::ObjectID &)>> add_handlers_;
|
||||
std::vector<std::function<void(const ray::ObjectID &)>> rem_handlers_;
|
||||
|
||||
plasma::PlasmaClient client_one_;
|
||||
plasma::PlasmaClient client_two_;
|
||||
int c_socket_;
|
||||
int64_t length_;
|
||||
std::vector<uint8_t> notification_;
|
||||
boost::asio::local::stream_protocol::socket socket_;
|
||||
|
||||
// Async loop for handling object store notifications.
|
||||
/// Async loop for handling object store notifications.
|
||||
void NotificationWait();
|
||||
void ProcessStoreLength(const boost::system::error_code &error);
|
||||
void ProcessStoreNotification(const boost::system::error_code &error);
|
||||
|
||||
// Support for rebroadcasting object add/rem events.
|
||||
/// Support for rebroadcasting object add/rem events.
|
||||
void ProcessStoreAdd(const ObjectID &object_id);
|
||||
void ProcessStoreRemove(const ObjectID &object_id);
|
||||
|
||||
std::vector<std::function<void(const ray::ObjectID &)>> add_handlers_;
|
||||
std::vector<std::function<void(const ray::ObjectID &)>> rem_handlers_;
|
||||
|
||||
plasma::PlasmaClient store_client_;
|
||||
int c_socket_;
|
||||
int64_t length_;
|
||||
std::vector<uint8_t> notification_;
|
||||
boost::asio::local::stream_protocol::socket socket_;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
@@ -0,0 +1,440 @@
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <thread>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ray/object_manager/object_manager.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
std::string store_executable;
|
||||
|
||||
static inline void flushall_redis(void) {
|
||||
redisContext *context = redisConnect("127.0.0.1", 6379);
|
||||
freeReplyObject(redisCommand(context, "FLUSHALL"));
|
||||
redisFree(context);
|
||||
}
|
||||
|
||||
int64_t current_time_ms() {
|
||||
std::chrono::milliseconds ms_since_epoch =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
std::chrono::steady_clock::now().time_since_epoch());
|
||||
return ms_since_epoch.count();
|
||||
}
|
||||
|
||||
class MockServer {
|
||||
public:
|
||||
MockServer(boost::asio::io_service &main_service,
|
||||
std::unique_ptr<boost::asio::io_service> object_manager_service,
|
||||
const ObjectManagerConfig &object_manager_config,
|
||||
std::shared_ptr<gcs::AsyncGcsClient> gcs_client)
|
||||
: object_manager_acceptor_(
|
||||
main_service, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), 0)),
|
||||
object_manager_socket_(main_service),
|
||||
gcs_client_(gcs_client),
|
||||
object_manager_(main_service, std::move(object_manager_service),
|
||||
object_manager_config, gcs_client) {
|
||||
RAY_CHECK_OK(RegisterGcs(main_service));
|
||||
// Start listening for clients.
|
||||
DoAcceptObjectManager();
|
||||
}
|
||||
|
||||
~MockServer() {
|
||||
RAY_CHECK_OK(gcs_client_->client_table().Disconnect());
|
||||
RAY_CHECK_OK(object_manager_.Terminate());
|
||||
}
|
||||
|
||||
private:
|
||||
ray::Status RegisterGcs(boost::asio::io_service &io_service) {
|
||||
RAY_RETURN_NOT_OK(gcs_client_->Connect("127.0.0.1", 6379));
|
||||
RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service));
|
||||
|
||||
boost::asio::ip::tcp::endpoint endpoint = object_manager_acceptor_.local_endpoint();
|
||||
std::string ip = endpoint.address().to_string();
|
||||
unsigned short object_manager_port = endpoint.port();
|
||||
|
||||
ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient();
|
||||
client_info.node_manager_address = ip;
|
||||
client_info.node_manager_port = object_manager_port;
|
||||
client_info.object_manager_port = object_manager_port;
|
||||
return gcs_client_->client_table().Connect(client_info);
|
||||
}
|
||||
|
||||
void DoAcceptObjectManager() {
|
||||
object_manager_acceptor_.async_accept(
|
||||
object_manager_socket_, boost::bind(&MockServer::HandleAcceptObjectManager, this,
|
||||
boost::asio::placeholders::error));
|
||||
}
|
||||
|
||||
void HandleAcceptObjectManager(const boost::system::error_code &error) {
|
||||
ClientHandler<boost::asio::ip::tcp> client_handler =
|
||||
[this](std::shared_ptr<TcpClientConnection> client) {
|
||||
object_manager_.ProcessNewClient(client);
|
||||
};
|
||||
MessageHandler<boost::asio::ip::tcp> message_handler = [this](
|
||||
std::shared_ptr<TcpClientConnection> client, int64_t message_type,
|
||||
const uint8_t *message) {
|
||||
object_manager_.ProcessClientMessage(client, message_type, message);
|
||||
};
|
||||
// Accept a new local client and dispatch it to the node manager.
|
||||
auto new_connection = TcpClientConnection::Create(client_handler, message_handler,
|
||||
std::move(object_manager_socket_));
|
||||
DoAcceptObjectManager();
|
||||
}
|
||||
|
||||
friend class StressTestObjectManager;
|
||||
|
||||
boost::asio::ip::tcp::acceptor object_manager_acceptor_;
|
||||
boost::asio::ip::tcp::socket object_manager_socket_;
|
||||
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_;
|
||||
ObjectManager object_manager_;
|
||||
};
|
||||
|
||||
class TestObjectManagerBase : public ::testing::Test {
|
||||
public:
|
||||
TestObjectManagerBase() {}
|
||||
|
||||
std::string StartStore(const std::string &id) {
|
||||
std::string store_id = "/tmp/store";
|
||||
store_id = store_id + id;
|
||||
std::string plasma_command = store_executable + " -m 1000000000 -s " + store_id +
|
||||
" 1> /dev/null 2> /dev/null &";
|
||||
RAY_LOG(DEBUG) << plasma_command;
|
||||
int ec = system(plasma_command.c_str());
|
||||
if (ec != 0) {
|
||||
throw std::runtime_error("failed to start plasma store.");
|
||||
};
|
||||
return store_id;
|
||||
}
|
||||
|
||||
void SetUp() {
|
||||
flushall_redis();
|
||||
|
||||
object_manager_service_1.reset(new boost::asio::io_service());
|
||||
object_manager_service_2.reset(new boost::asio::io_service());
|
||||
|
||||
// start store
|
||||
std::string store_sock_1 = StartStore("1");
|
||||
std::string store_sock_2 = StartStore("2");
|
||||
|
||||
// start first server
|
||||
gcs_client_1 = std::shared_ptr<gcs::AsyncGcsClient>(new gcs::AsyncGcsClient());
|
||||
ObjectManagerConfig om_config_1;
|
||||
om_config_1.store_socket_name = store_sock_1;
|
||||
om_config_1.num_threads = 4;
|
||||
om_config_1.max_sends = 20;
|
||||
om_config_1.max_receives = 20;
|
||||
server1.reset(new MockServer(main_service, std::move(object_manager_service_1),
|
||||
om_config_1, gcs_client_1));
|
||||
|
||||
// start second server
|
||||
gcs_client_2 = std::shared_ptr<gcs::AsyncGcsClient>(new gcs::AsyncGcsClient());
|
||||
ObjectManagerConfig om_config_2;
|
||||
om_config_2.store_socket_name = store_sock_2;
|
||||
om_config_2.num_threads = 4;
|
||||
om_config_2.max_sends = 20;
|
||||
om_config_2.max_receives = 20;
|
||||
server2.reset(new MockServer(main_service, std::move(object_manager_service_2),
|
||||
om_config_2, gcs_client_2));
|
||||
|
||||
// connect to stores.
|
||||
ARROW_CHECK_OK(client1.Connect(store_sock_1, "", PLASMA_DEFAULT_RELEASE_DELAY));
|
||||
ARROW_CHECK_OK(client2.Connect(store_sock_2, "", PLASMA_DEFAULT_RELEASE_DELAY));
|
||||
}
|
||||
|
||||
void TearDown() {
|
||||
arrow::Status client1_status = client1.Disconnect();
|
||||
arrow::Status client2_status = client2.Disconnect();
|
||||
ASSERT_TRUE(client1_status.ok() && client2_status.ok());
|
||||
|
||||
this->server1.reset();
|
||||
this->server2.reset();
|
||||
|
||||
int s = system("killall plasma_store &");
|
||||
ASSERT_TRUE(!s);
|
||||
}
|
||||
|
||||
ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size) {
|
||||
ObjectID object_id = ObjectID::from_random();
|
||||
RAY_LOG(DEBUG) << "ObjectID Created: " << object_id;
|
||||
uint8_t metadata[] = {5};
|
||||
int64_t metadata_size = sizeof(metadata);
|
||||
std::shared_ptr<Buffer> data;
|
||||
ARROW_CHECK_OK(client.Create(object_id.to_plasma_id(), data_size, metadata,
|
||||
metadata_size, &data));
|
||||
ARROW_CHECK_OK(client.Seal(object_id.to_plasma_id()));
|
||||
return object_id;
|
||||
}
|
||||
|
||||
void object_added_handler_1(ObjectID object_id) { v1.push_back(object_id); };
|
||||
|
||||
void object_added_handler_2(ObjectID object_id) { v2.push_back(object_id); };
|
||||
|
||||
protected:
|
||||
std::thread p;
|
||||
boost::asio::io_service main_service;
|
||||
std::unique_ptr<boost::asio::io_service> object_manager_service_1;
|
||||
std::unique_ptr<boost::asio::io_service> object_manager_service_2;
|
||||
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_1;
|
||||
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_2;
|
||||
std::unique_ptr<MockServer> server1;
|
||||
std::unique_ptr<MockServer> server2;
|
||||
|
||||
plasma::PlasmaClient client1;
|
||||
plasma::PlasmaClient client2;
|
||||
std::vector<ObjectID> v1;
|
||||
std::vector<ObjectID> v2;
|
||||
};
|
||||
|
||||
class StressTestObjectManager : public TestObjectManagerBase {
|
||||
public:
|
||||
enum TransferPattern {
|
||||
PUSH_A_B,
|
||||
PUSH_B_A,
|
||||
BIDIRECTIONAL_PUSH,
|
||||
PULL_A_B,
|
||||
PULL_B_A,
|
||||
BIDIRECTIONAL_PULL,
|
||||
BIDIRECTIONAL_PULL_VARIABLE_DATA_SIZE,
|
||||
};
|
||||
|
||||
int async_loop_index = -1;
|
||||
uint num_expected_objects;
|
||||
|
||||
std::vector<TransferPattern> async_loop_patterns = {
|
||||
PUSH_A_B,
|
||||
PUSH_B_A,
|
||||
BIDIRECTIONAL_PUSH,
|
||||
PULL_A_B,
|
||||
PULL_B_A,
|
||||
BIDIRECTIONAL_PULL,
|
||||
BIDIRECTIONAL_PULL_VARIABLE_DATA_SIZE};
|
||||
|
||||
int num_connected_clients = 0;
|
||||
|
||||
ClientID client_id_1;
|
||||
ClientID client_id_2;
|
||||
|
||||
int64_t start_time;
|
||||
|
||||
void WaitConnections() {
|
||||
client_id_1 = gcs_client_1->client_table().GetLocalClientId();
|
||||
client_id_2 = gcs_client_2->client_table().GetLocalClientId();
|
||||
gcs_client_1->client_table().RegisterClientAddedCallback([this](
|
||||
gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) {
|
||||
ClientID parsed_id = ClientID::from_binary(data.client_id);
|
||||
if (parsed_id == client_id_1 || parsed_id == client_id_2) {
|
||||
num_connected_clients += 1;
|
||||
}
|
||||
if (num_connected_clients == 2) {
|
||||
StartTests();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void StartTests() {
|
||||
TestConnections();
|
||||
AddTransferTestHandlers();
|
||||
TransferTestNext();
|
||||
}
|
||||
|
||||
void AddTransferTestHandlers() {
|
||||
ray::Status status = ray::Status::OK();
|
||||
status =
|
||||
server1->object_manager_.SubscribeObjAdded([this](const ObjectID &object_id) {
|
||||
object_added_handler_1(object_id);
|
||||
if (v1.size() == num_expected_objects && v1.size() == v2.size()) {
|
||||
TransferTestComplete();
|
||||
}
|
||||
});
|
||||
RAY_CHECK_OK(status);
|
||||
status =
|
||||
server2->object_manager_.SubscribeObjAdded([this](const ObjectID &object_id) {
|
||||
object_added_handler_2(object_id);
|
||||
if (v2.size() == num_expected_objects && v1.size() == v2.size()) {
|
||||
TransferTestComplete();
|
||||
}
|
||||
});
|
||||
RAY_CHECK_OK(status);
|
||||
}
|
||||
|
||||
void TransferTestNext() {
|
||||
async_loop_index += 1;
|
||||
if ((uint)async_loop_index < async_loop_patterns.size()) {
|
||||
TransferPattern pattern = async_loop_patterns[async_loop_index];
|
||||
TransferTestExecute(1000, 100, pattern);
|
||||
} else {
|
||||
main_service.stop();
|
||||
}
|
||||
}
|
||||
|
||||
plasma::ObjectBuffer GetObject(plasma::PlasmaClient &client, ObjectID &object_id) {
|
||||
plasma::ObjectBuffer object_buffer;
|
||||
plasma::ObjectID plasma_id = object_id.to_plasma_id();
|
||||
ARROW_CHECK_OK(client.Get(&plasma_id, 1, 0, &object_buffer));
|
||||
return object_buffer;
|
||||
}
|
||||
|
||||
static unsigned char *GetDigest(plasma::PlasmaClient &client, ObjectID &object_id) {
|
||||
const int64_t size = sizeof(uint64_t);
|
||||
static unsigned char digest_1[size];
|
||||
ARROW_CHECK_OK(client.Hash(object_id.to_plasma_id(), &digest_1[0]));
|
||||
return digest_1;
|
||||
}
|
||||
|
||||
void CompareObjects(ObjectID &object_id_1, ObjectID &object_id_2) {
|
||||
plasma::ObjectBuffer object_buffer_1 = GetObject(client1, object_id_1);
|
||||
plasma::ObjectBuffer object_buffer_2 = GetObject(client1, object_id_1);
|
||||
uint8_t *data_1 = const_cast<uint8_t *>(object_buffer_1.data->data());
|
||||
uint8_t *data_2 = const_cast<uint8_t *>(object_buffer_2.data->data());
|
||||
ASSERT_EQ(object_buffer_1.data_size, object_buffer_2.data_size);
|
||||
for (int i = -1; ++i < object_buffer_1.data_size;) {
|
||||
ASSERT_TRUE(data_1[i] == data_2[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void CompareHashes(ObjectID &object_id_1, ObjectID &object_id_2) {
|
||||
const int64_t size = sizeof(uint64_t);
|
||||
static unsigned char *digest_1 = GetDigest(client1, object_id_1);
|
||||
static unsigned char *digest_2 = GetDigest(client2, object_id_2);
|
||||
for (int i = -1; ++i < size;) {
|
||||
ASSERT_TRUE(digest_1[i] == digest_2[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void TransferTestComplete() {
|
||||
int64_t elapsed = current_time_ms() - start_time;
|
||||
RAY_LOG(INFO) << "TransferTestComplete: " << async_loop_patterns[async_loop_index]
|
||||
<< " " << v1.size() << " " << elapsed;
|
||||
ASSERT_TRUE(v1.size() == v2.size());
|
||||
for (uint i = 0; i < v1.size(); ++i) {
|
||||
ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end());
|
||||
}
|
||||
|
||||
// Compare objects and their hashes.
|
||||
for (uint i = 0; i < v1.size(); ++i) {
|
||||
ObjectID object_id_2 = v2[i];
|
||||
ObjectID object_id_1 =
|
||||
v1[std::distance(v1.begin(), std::find(v1.begin(), v1.end(), v2[i]))];
|
||||
CompareHashes(object_id_1, object_id_2);
|
||||
CompareObjects(object_id_1, object_id_2);
|
||||
}
|
||||
|
||||
v1.clear();
|
||||
v2.clear();
|
||||
TransferTestNext();
|
||||
}
|
||||
|
||||
void TransferTestExecute(int num_trials, int64_t data_size,
|
||||
TransferPattern transfer_pattern) {
|
||||
ClientID client_id_1 = gcs_client_1->client_table().GetLocalClientId();
|
||||
ClientID client_id_2 = gcs_client_2->client_table().GetLocalClientId();
|
||||
|
||||
ray::Status status = ray::Status::OK();
|
||||
|
||||
if (transfer_pattern == BIDIRECTIONAL_PULL ||
|
||||
transfer_pattern == BIDIRECTIONAL_PUSH ||
|
||||
transfer_pattern == BIDIRECTIONAL_PULL_VARIABLE_DATA_SIZE) {
|
||||
num_expected_objects = (uint)2 * num_trials;
|
||||
} else {
|
||||
num_expected_objects = (uint)num_trials;
|
||||
}
|
||||
|
||||
start_time = current_time_ms();
|
||||
|
||||
switch (transfer_pattern) {
|
||||
case PUSH_A_B: {
|
||||
for (int i = -1; ++i < num_trials;) {
|
||||
ObjectID oid1 = WriteDataToClient(client1, data_size);
|
||||
status = server1->object_manager_.Push(oid1, client_id_2);
|
||||
}
|
||||
} break;
|
||||
case PUSH_B_A: {
|
||||
for (int i = -1; ++i < num_trials;) {
|
||||
ObjectID oid2 = WriteDataToClient(client2, data_size);
|
||||
status = server2->object_manager_.Push(oid2, client_id_1);
|
||||
}
|
||||
} break;
|
||||
case BIDIRECTIONAL_PUSH: {
|
||||
for (int i = -1; ++i < num_trials;) {
|
||||
ObjectID oid1 = WriteDataToClient(client1, data_size);
|
||||
status = server1->object_manager_.Push(oid1, client_id_2);
|
||||
ObjectID oid2 = WriteDataToClient(client2, data_size);
|
||||
status = server2->object_manager_.Push(oid2, client_id_1);
|
||||
}
|
||||
} break;
|
||||
case PULL_A_B: {
|
||||
for (int i = -1; ++i < num_trials;) {
|
||||
ObjectID oid1 = WriteDataToClient(client1, data_size);
|
||||
status = server2->object_manager_.Pull(oid1);
|
||||
}
|
||||
} break;
|
||||
case PULL_B_A: {
|
||||
for (int i = -1; ++i < num_trials;) {
|
||||
ObjectID oid2 = WriteDataToClient(client2, data_size);
|
||||
status = server1->object_manager_.Pull(oid2);
|
||||
}
|
||||
} break;
|
||||
case BIDIRECTIONAL_PULL: {
|
||||
for (int i = -1; ++i < num_trials;) {
|
||||
ObjectID oid1 = WriteDataToClient(client1, data_size);
|
||||
status = server2->object_manager_.Pull(oid1);
|
||||
ObjectID oid2 = WriteDataToClient(client2, data_size);
|
||||
status = server1->object_manager_.Pull(oid2);
|
||||
}
|
||||
} break;
|
||||
case BIDIRECTIONAL_PULL_VARIABLE_DATA_SIZE: {
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_int_distribution<> dis(1, 50);
|
||||
for (int i = -1; ++i < num_trials;) {
|
||||
ObjectID oid1 = WriteDataToClient(client1, data_size + dis(gen));
|
||||
status = server2->object_manager_.Pull(oid1);
|
||||
ObjectID oid2 = WriteDataToClient(client2, data_size + dis(gen));
|
||||
status = server1->object_manager_.Pull(oid2);
|
||||
}
|
||||
} break;
|
||||
default: {
|
||||
RAY_LOG(FATAL) << "No case for transfer_pattern " << transfer_pattern;
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
void TestConnections() {
|
||||
RAY_LOG(DEBUG) << "\n"
|
||||
<< "Server client ids:"
|
||||
<< "\n";
|
||||
ClientID client_id_1 = gcs_client_1->client_table().GetLocalClientId();
|
||||
ClientID client_id_2 = gcs_client_2->client_table().GetLocalClientId();
|
||||
RAY_LOG(DEBUG) << "Server 1: " << client_id_1 << "\n"
|
||||
<< "Server 2: " << client_id_2;
|
||||
|
||||
RAY_LOG(DEBUG) << "\n"
|
||||
<< "All connected clients:"
|
||||
<< "\n";
|
||||
const ClientTableDataT &data = gcs_client_1->client_table().GetClient(client_id_1);
|
||||
RAY_LOG(DEBUG) << "ClientID=" << ClientID::from_binary(data.client_id) << "\n"
|
||||
<< "ClientIp=" << data.node_manager_address << "\n"
|
||||
<< "ClientPort=" << data.node_manager_port;
|
||||
const ClientTableDataT &data2 = gcs_client_1->client_table().GetClient(client_id_2);
|
||||
RAY_LOG(DEBUG) << "ClientID=" << ClientID::from_binary(data2.client_id) << "\n"
|
||||
<< "ClientIp=" << data2.node_manager_address << "\n"
|
||||
<< "ClientPort=" << data2.node_manager_port;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(StressTestObjectManager, StartStressTestObjectManager) {
|
||||
auto AsyncStartTests = main_service.wrap([this]() { WaitConnections(); });
|
||||
AsyncStartTests();
|
||||
main_service.run();
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
ray::store_executable = std::string(argv[1]);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
@@ -0,0 +1,256 @@
|
||||
#include <iostream>
|
||||
#include <thread>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ray/object_manager/object_manager.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
static inline void flushall_redis(void) {
|
||||
redisContext *context = redisConnect("127.0.0.1", 6379);
|
||||
freeReplyObject(redisCommand(context, "FLUSHALL"));
|
||||
redisFree(context);
|
||||
}
|
||||
|
||||
std::string store_executable;
|
||||
|
||||
class MockServer {
|
||||
public:
|
||||
MockServer(boost::asio::io_service &main_service,
|
||||
std::unique_ptr<boost::asio::io_service> object_manager_service,
|
||||
const ObjectManagerConfig &object_manager_config,
|
||||
std::shared_ptr<gcs::AsyncGcsClient> gcs_client)
|
||||
: object_manager_acceptor_(
|
||||
main_service, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), 0)),
|
||||
object_manager_socket_(main_service),
|
||||
gcs_client_(gcs_client),
|
||||
object_manager_(main_service, std::move(object_manager_service),
|
||||
object_manager_config, gcs_client) {
|
||||
RAY_CHECK_OK(RegisterGcs(main_service));
|
||||
// Start listening for clients.
|
||||
DoAcceptObjectManager();
|
||||
}
|
||||
|
||||
~MockServer() {
|
||||
RAY_CHECK_OK(gcs_client_->client_table().Disconnect());
|
||||
RAY_CHECK_OK(object_manager_.Terminate());
|
||||
}
|
||||
|
||||
private:
|
||||
ray::Status RegisterGcs(boost::asio::io_service &io_service) {
|
||||
RAY_RETURN_NOT_OK(gcs_client_->Connect("127.0.0.1", 6379));
|
||||
RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service));
|
||||
|
||||
boost::asio::ip::tcp::endpoint endpoint = object_manager_acceptor_.local_endpoint();
|
||||
std::string ip = endpoint.address().to_string();
|
||||
unsigned short object_manager_port = endpoint.port();
|
||||
|
||||
ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient();
|
||||
client_info.node_manager_address = ip;
|
||||
client_info.node_manager_port = object_manager_port;
|
||||
client_info.object_manager_port = object_manager_port;
|
||||
return gcs_client_->client_table().Connect(client_info);
|
||||
}
|
||||
|
||||
void DoAcceptObjectManager() {
|
||||
object_manager_acceptor_.async_accept(
|
||||
object_manager_socket_, boost::bind(&MockServer::HandleAcceptObjectManager, this,
|
||||
boost::asio::placeholders::error));
|
||||
}
|
||||
|
||||
void HandleAcceptObjectManager(const boost::system::error_code &error) {
|
||||
ClientHandler<boost::asio::ip::tcp> client_handler =
|
||||
[this](std::shared_ptr<TcpClientConnection> client) {
|
||||
object_manager_.ProcessNewClient(client);
|
||||
};
|
||||
MessageHandler<boost::asio::ip::tcp> message_handler = [this](
|
||||
std::shared_ptr<TcpClientConnection> client, int64_t message_type,
|
||||
const uint8_t *message) {
|
||||
object_manager_.ProcessClientMessage(client, message_type, message);
|
||||
};
|
||||
// Accept a new local client and dispatch it to the node manager.
|
||||
auto new_connection = TcpClientConnection::Create(client_handler, message_handler,
|
||||
std::move(object_manager_socket_));
|
||||
DoAcceptObjectManager();
|
||||
}
|
||||
|
||||
friend class TestObjectManagerCommands;
|
||||
|
||||
boost::asio::ip::tcp::acceptor object_manager_acceptor_;
|
||||
boost::asio::ip::tcp::socket object_manager_socket_;
|
||||
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_;
|
||||
ObjectManager object_manager_;
|
||||
};
|
||||
|
||||
class TestObjectManager : public ::testing::Test {
|
||||
public:
|
||||
TestObjectManager() {}
|
||||
|
||||
std::string StartStore(const std::string &id) {
|
||||
std::string store_id = "/tmp/store";
|
||||
store_id = store_id + id;
|
||||
std::string plasma_command = store_executable + " -m 1000000000 -s " + store_id +
|
||||
" 1> /dev/null 2> /dev/null &";
|
||||
RAY_LOG(DEBUG) << plasma_command;
|
||||
int ec = system(plasma_command.c_str());
|
||||
if (ec != 0) {
|
||||
throw std::runtime_error("failed to start plasma store.");
|
||||
};
|
||||
return store_id;
|
||||
}
|
||||
|
||||
void SetUp() {
|
||||
flushall_redis();
|
||||
|
||||
object_manager_service_1.reset(new boost::asio::io_service());
|
||||
object_manager_service_2.reset(new boost::asio::io_service());
|
||||
|
||||
// start store
|
||||
std::string store_sock_1 = StartStore("1");
|
||||
std::string store_sock_2 = StartStore("2");
|
||||
|
||||
// start first server
|
||||
gcs_client_1 = std::shared_ptr<gcs::AsyncGcsClient>(new gcs::AsyncGcsClient());
|
||||
ObjectManagerConfig om_config_1;
|
||||
om_config_1.store_socket_name = store_sock_1;
|
||||
server1.reset(new MockServer(main_service, std::move(object_manager_service_1),
|
||||
om_config_1, gcs_client_1));
|
||||
|
||||
// start second server
|
||||
gcs_client_2 = std::shared_ptr<gcs::AsyncGcsClient>(new gcs::AsyncGcsClient());
|
||||
ObjectManagerConfig om_config_2;
|
||||
om_config_2.store_socket_name = store_sock_2;
|
||||
server2.reset(new MockServer(main_service, std::move(object_manager_service_2),
|
||||
om_config_2, gcs_client_2));
|
||||
|
||||
// connect to stores.
|
||||
ARROW_CHECK_OK(client1.Connect(store_sock_1, "", PLASMA_DEFAULT_RELEASE_DELAY));
|
||||
ARROW_CHECK_OK(client2.Connect(store_sock_2, "", PLASMA_DEFAULT_RELEASE_DELAY));
|
||||
}
|
||||
|
||||
void TearDown() {
|
||||
arrow::Status client1_status = client1.Disconnect();
|
||||
arrow::Status client2_status = client2.Disconnect();
|
||||
ASSERT_TRUE(client1_status.ok() && client2_status.ok());
|
||||
|
||||
this->server1.reset();
|
||||
this->server2.reset();
|
||||
|
||||
int s = system("killall plasma_store &");
|
||||
ASSERT_TRUE(!s);
|
||||
}
|
||||
|
||||
ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size) {
|
||||
ObjectID object_id = ObjectID::from_random();
|
||||
RAY_LOG(DEBUG) << "ObjectID Created: " << object_id;
|
||||
uint8_t metadata[] = {5};
|
||||
int64_t metadata_size = sizeof(metadata);
|
||||
std::shared_ptr<Buffer> data;
|
||||
ARROW_CHECK_OK(client.Create(object_id.to_plasma_id(), data_size, metadata,
|
||||
metadata_size, &data));
|
||||
ARROW_CHECK_OK(client.Seal(object_id.to_plasma_id()));
|
||||
return object_id;
|
||||
}
|
||||
|
||||
void object_added_handler_1(ObjectID object_id) { v1.push_back(object_id); };
|
||||
|
||||
void object_added_handler_2(ObjectID object_id) { v2.push_back(object_id); };
|
||||
|
||||
protected:
|
||||
std::thread p;
|
||||
boost::asio::io_service main_service;
|
||||
std::unique_ptr<boost::asio::io_service> object_manager_service_1;
|
||||
std::unique_ptr<boost::asio::io_service> object_manager_service_2;
|
||||
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_1;
|
||||
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_2;
|
||||
std::unique_ptr<MockServer> server1;
|
||||
std::unique_ptr<MockServer> server2;
|
||||
|
||||
plasma::PlasmaClient client1;
|
||||
plasma::PlasmaClient client2;
|
||||
std::vector<ObjectID> v1;
|
||||
std::vector<ObjectID> v2;
|
||||
};
|
||||
|
||||
class TestObjectManagerCommands : public TestObjectManager {
|
||||
public:
|
||||
int num_connected_clients = 0;
|
||||
uint num_expected_objects;
|
||||
ClientID client_id_1;
|
||||
ClientID client_id_2;
|
||||
|
||||
ObjectID created_object_id;
|
||||
|
||||
void WaitConnections() {
|
||||
client_id_1 = gcs_client_1->client_table().GetLocalClientId();
|
||||
client_id_2 = gcs_client_2->client_table().GetLocalClientId();
|
||||
gcs_client_1->client_table().RegisterClientAddedCallback([this](
|
||||
gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) {
|
||||
ClientID parsed_id = ClientID::from_binary(data.client_id);
|
||||
if (parsed_id == client_id_1 || parsed_id == client_id_2) {
|
||||
num_connected_clients += 1;
|
||||
}
|
||||
if (num_connected_clients == 2) {
|
||||
StartTests();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void StartTests() {
|
||||
TestConnections();
|
||||
TestNotifications();
|
||||
}
|
||||
|
||||
void TestNotifications() {
|
||||
ray::Status status = ray::Status::OK();
|
||||
status =
|
||||
server1->object_manager_.SubscribeObjAdded([this](const ObjectID &object_id) {
|
||||
object_added_handler_1(object_id);
|
||||
if (v1.size() == num_expected_objects) {
|
||||
NotificationTestComplete(created_object_id, object_id);
|
||||
}
|
||||
});
|
||||
RAY_CHECK_OK(status);
|
||||
|
||||
num_expected_objects = 1;
|
||||
uint data_size = 1000000;
|
||||
created_object_id = WriteDataToClient(client1, data_size);
|
||||
}
|
||||
|
||||
void NotificationTestComplete(ObjectID object_id_1, ObjectID object_id_2) {
|
||||
ASSERT_EQ(object_id_1, object_id_2);
|
||||
main_service.stop();
|
||||
}
|
||||
|
||||
void TestConnections() {
|
||||
RAY_LOG(DEBUG) << "\n"
|
||||
<< "Server client ids:"
|
||||
<< "\n";
|
||||
const ClientTableDataT &data = gcs_client_1->client_table().GetClient(client_id_1);
|
||||
RAY_LOG(DEBUG) << (ClientID::from_binary(data.client_id) == ClientID::nil());
|
||||
RAY_LOG(DEBUG) << "Server 1 ClientID=" << ClientID::from_binary(data.client_id);
|
||||
RAY_LOG(DEBUG) << "Server 1 ClientIp=" << data.node_manager_address;
|
||||
RAY_LOG(DEBUG) << "Server 1 ClientPort=" << data.node_manager_port;
|
||||
ASSERT_EQ(client_id_1, ClientID::from_binary(data.client_id));
|
||||
const ClientTableDataT &data2 = gcs_client_1->client_table().GetClient(client_id_2);
|
||||
RAY_LOG(DEBUG) << "Server 2 ClientID=" << ClientID::from_binary(data2.client_id);
|
||||
RAY_LOG(DEBUG) << "Server 2 ClientIp=" << data2.node_manager_address;
|
||||
RAY_LOG(DEBUG) << "Server 2 ClientPort=" << data2.node_manager_port;
|
||||
ASSERT_EQ(client_id_2, ClientID::from_binary(data2.client_id));
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(TestObjectManagerCommands, StartTestObjectManagerCommands) {
|
||||
auto AsyncStartTests = main_service.wrap([this]() { WaitConnections(); });
|
||||
AsyncStartTests();
|
||||
main_service.run();
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
ray::store_executable = std::string(argv[1]);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
#include "ray/object_manager/transfer_queue.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
void TransferQueue::QueueSend(const ClientID &client_id, const ObjectID &object_id,
|
||||
const RemoteConnectionInfo &info) {
|
||||
WriteLock guard(send_mutex);
|
||||
SendRequest req = {client_id, object_id, info};
|
||||
// TODO(hme): Use a set to speed this up.
|
||||
if (std::find(send_queue_.begin(), send_queue_.end(), req) != send_queue_.end()) {
|
||||
// already queued.
|
||||
return;
|
||||
}
|
||||
send_queue_.push_back(req);
|
||||
}
|
||||
|
||||
void TransferQueue::QueueReceive(const ClientID &client_id, const ObjectID &object_id,
|
||||
uint64_t object_size,
|
||||
std::shared_ptr<TcpClientConnection> conn) {
|
||||
WriteLock guard(receive_mutex);
|
||||
ReceiveRequest req = {client_id, object_id, object_size, conn};
|
||||
if (std::find(receive_queue_.begin(), receive_queue_.end(), req) !=
|
||||
receive_queue_.end()) {
|
||||
// already queued.
|
||||
return;
|
||||
}
|
||||
receive_queue_.push_back(req);
|
||||
}
|
||||
|
||||
bool TransferQueue::DequeueSendIfPresent(TransferQueue::SendRequest *send_ptr) {
|
||||
WriteLock guard(send_mutex);
|
||||
if (send_queue_.empty()) {
|
||||
return false;
|
||||
}
|
||||
*send_ptr = send_queue_.front();
|
||||
send_queue_.pop_front();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TransferQueue::DequeueReceiveIfPresent(TransferQueue::ReceiveRequest *receive_ptr) {
|
||||
WriteLock guard(receive_mutex);
|
||||
if (receive_queue_.empty()) {
|
||||
return false;
|
||||
}
|
||||
*receive_ptr = receive_queue_.front();
|
||||
receive_queue_.pop_front();
|
||||
return true;
|
||||
}
|
||||
|
||||
UniqueID TransferQueue::AddContext(SendContext &context) {
|
||||
WriteLock guard(context_mutex);
|
||||
UniqueID id = UniqueID::from_random();
|
||||
send_context_set_.emplace(id, context);
|
||||
return id;
|
||||
}
|
||||
|
||||
TransferQueue::SendContext &TransferQueue::GetContext(const UniqueID &id) {
|
||||
ReadLock guard(context_mutex);
|
||||
return send_context_set_[id];
|
||||
}
|
||||
|
||||
ray::Status TransferQueue::RemoveContext(const UniqueID &id) {
|
||||
WriteLock guard(context_mutex);
|
||||
send_context_set_.erase(id);
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace ray
|
||||
@@ -0,0 +1,126 @@
|
||||
#ifndef RAY_OBJECT_MANAGER_TRANSFER_QUEUE_H
|
||||
#define RAY_OBJECT_MANAGER_TRANSFER_QUEUE_H
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <deque>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
|
||||
#include <boost/asio.hpp>
|
||||
#include <boost/asio/error.hpp>
|
||||
#include <boost/bind.hpp>
|
||||
|
||||
#include "ray/id.h"
|
||||
#include "ray/status.h"
|
||||
|
||||
#include "ray/object_manager/format/object_manager_generated.h"
|
||||
#include "ray/object_manager/object_directory.h"
|
||||
#include "ray/object_manager/object_manager_client_connection.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
class TransferQueue {
|
||||
public:
|
||||
enum TransferType { SEND = 1, RECEIVE };
|
||||
|
||||
/// Context maintained during an object send.
|
||||
struct SendContext {
|
||||
ClientID client_id;
|
||||
ObjectID object_id;
|
||||
uint64_t object_size;
|
||||
uint8_t *data;
|
||||
};
|
||||
|
||||
/// The structure used in the send queue.
|
||||
struct SendRequest {
|
||||
ClientID client_id;
|
||||
ObjectID object_id;
|
||||
RemoteConnectionInfo connection_info;
|
||||
bool operator==(const SendRequest &rhs) const {
|
||||
return client_id == rhs.client_id && object_id == rhs.object_id;
|
||||
}
|
||||
};
|
||||
|
||||
/// The structure used in the receive queue.
|
||||
struct ReceiveRequest {
|
||||
ClientID client_id;
|
||||
ObjectID object_id;
|
||||
uint64_t object_size;
|
||||
std::shared_ptr<TcpClientConnection> conn;
|
||||
bool operator==(const ReceiveRequest &rhs) const {
|
||||
return client_id == rhs.client_id && object_id == rhs.object_id;
|
||||
}
|
||||
};
|
||||
|
||||
/// Queues a send.
|
||||
///
|
||||
/// \param client_id The ClientID to which the object needs to be sent.
|
||||
/// \param object_id The ObjectID of the object to be sent.
|
||||
void QueueSend(const ClientID &client_id, const ObjectID &object_id,
|
||||
const RemoteConnectionInfo &info);
|
||||
|
||||
/// If send_queue_ is not empty, removes a SendRequest from send_queue_ and assigns
|
||||
/// it to send_ptr. The queue is FIFO.
|
||||
/// \param send_ptr A pointer to an empty SendRequest.
|
||||
/// \return A bool indicating whether the queue was empty at the time this method
|
||||
/// was invoked.
|
||||
bool DequeueSendIfPresent(TransferQueue::SendRequest *send_ptr);
|
||||
|
||||
/// Queues a receive.
|
||||
///
|
||||
/// \param client_id The ClientID from which the object is being received.
|
||||
/// \param object_id The ObjectID of the object to be received.
|
||||
void QueueReceive(const ClientID &client_id, const ObjectID &object_id,
|
||||
uint64_t object_size, std::shared_ptr<TcpClientConnection> conn);
|
||||
|
||||
/// If receive_queue_ is not empty, removes a ReceiveRequest from receive_queue_ and
|
||||
/// assigns
|
||||
/// it to receive_ptr. The queue is FIFO.
|
||||
/// \param receive_ptr A pointer to an empty ReceiveRequest.
|
||||
/// \return A bool indicating whether the queue was empty at the time this method
|
||||
/// was invoked.
|
||||
bool DequeueReceiveIfPresent(TransferQueue::ReceiveRequest *receive_ptr);
|
||||
|
||||
/// Maintain ownership over SendContext for sends in transit.
|
||||
///
|
||||
/// \param context The context to maintain.
|
||||
/// \return A unique identifier identifying the context that was added.
|
||||
UniqueID AddContext(SendContext &context);
|
||||
|
||||
/// Gets the SendContext associated with the given id.
|
||||
///
|
||||
/// \param id The unique identifier of the context.
|
||||
/// \return The context.
|
||||
SendContext &GetContext(const UniqueID &id);
|
||||
|
||||
/// Removes the context associated with the given id.
|
||||
///
|
||||
/// \param id The unique identifier of the context.
|
||||
/// \return The status of invoking this method.
|
||||
ray::Status RemoveContext(const UniqueID &id);
|
||||
|
||||
/// This object cannot be copied for thread-safety.
|
||||
TransferQueue &operator=(const TransferQueue &o) {
|
||||
throw std::runtime_error("Can't copy TransferQueue.");
|
||||
}
|
||||
|
||||
private:
|
||||
// TODO(hme): make this a shared mutex.
|
||||
typedef std::mutex Lock;
|
||||
typedef std::unique_lock<Lock> WriteLock;
|
||||
// TODO(hme): make this a shared lock.
|
||||
typedef std::unique_lock<Lock> ReadLock;
|
||||
Lock send_mutex;
|
||||
Lock receive_mutex;
|
||||
Lock context_mutex;
|
||||
|
||||
std::deque<SendRequest> send_queue_;
|
||||
std::deque<ReceiveRequest> receive_queue_;
|
||||
std::unordered_map<ray::UniqueID, SendContext, ray::UniqueIDHasher> send_context_set_;
|
||||
};
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_OBJECT_MANAGER_TRANSFER_QUEUE_H
|
||||
@@ -0,0 +1,18 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
|
||||
from worker import Worker
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("raylet_socket_name")
|
||||
parser.add_argument("object_store_socket_name")
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parser.parse_args()
|
||||
|
||||
worker = Worker(args.raylet_socket_name, args.object_store_socket_name,
|
||||
is_worker=True)
|
||||
worker.main_loop()
|
||||
@@ -0,0 +1,33 @@
|
||||
import argparse
|
||||
|
||||
import ray
|
||||
from worker import Worker, logger
|
||||
from ray.utils import random_string
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("raylet_socket_name")
|
||||
parser.add_argument("object_store_socket_name")
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parser.parse_args()
|
||||
|
||||
driver = Worker(args.raylet_socket_name, args.object_store_socket_name,
|
||||
is_worker=False)
|
||||
|
||||
task1 = ray.local_scheduler.Task(
|
||||
ray.local_scheduler.ObjectID(random_string()),
|
||||
ray.local_scheduler.ObjectID(random_string()),
|
||||
[],
|
||||
1,
|
||||
ray.local_scheduler.ObjectID(random_string()),
|
||||
0)
|
||||
logger.debug("submitting", task1.task_id())
|
||||
driver.node_manager_client.submit(task1)
|
||||
|
||||
logger.debug("Return values were", task1.returns())
|
||||
print("[DRIVER] Return values were", task1.returns())
|
||||
# Make sure the tasks get executed and we can get the result of the
|
||||
# last task
|
||||
obj = driver.get(task1.returns(), timeout_ms=1000)
|
||||
print("[DRIVER]: task1 driver.get result ", obj)
|
||||
@@ -0,0 +1,40 @@
|
||||
import argparse
|
||||
|
||||
import ray
|
||||
from worker import Worker, logger
|
||||
from ray.utils import random_string
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("raylet_socket_name")
|
||||
parser.add_argument("object_store_socket_name")
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parser.parse_args()
|
||||
|
||||
driver = Worker(args.raylet_socket_name, args.object_store_socket_name,
|
||||
is_worker=False)
|
||||
|
||||
task = ray.local_scheduler.Task(
|
||||
ray.local_scheduler.ObjectID(random_string()),
|
||||
ray.local_scheduler.ObjectID(random_string()),
|
||||
[],
|
||||
1,
|
||||
ray.local_scheduler.ObjectID(random_string()),
|
||||
0)
|
||||
logger.debug("submitting %s", task.task_id())
|
||||
driver.node_manager_client.submit(task)
|
||||
|
||||
logger.debug("Return values were %s", task.returns())
|
||||
task2 = ray.local_scheduler.Task(
|
||||
ray.local_scheduler.ObjectID(random_string()),
|
||||
ray.local_scheduler.ObjectID(random_string()),
|
||||
task.returns(),
|
||||
1,
|
||||
ray.local_scheduler.ObjectID(random_string()),
|
||||
0)
|
||||
logger.debug("Submitting dependent task 2 %s", task2.task_id())
|
||||
driver.node_manager_client.submit(task2)
|
||||
|
||||
# Make sure the tasks get executed and we can get the result of the last
|
||||
# task.
|
||||
obj = driver.get(task2.returns(), timeout_ms=1000)
|
||||
@@ -0,0 +1,115 @@
|
||||
import argparse
|
||||
|
||||
import ray
|
||||
from worker import Worker, logger
|
||||
from ray.utils import random_string
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("raylet_socket_name")
|
||||
parser.add_argument("object_store_socket_name")
|
||||
|
||||
|
||||
def submit_task_withdep(driver_handle, task_object_dependencies=[]):
|
||||
''' submit a task that depend on a list of @args'''
|
||||
task = ray.local_scheduler.Task(
|
||||
ray.local_scheduler.ObjectID(random_string()),
|
||||
ray.local_scheduler.ObjectID(random_string()),
|
||||
task_object_dependencies,
|
||||
1, # num_returns
|
||||
ray.local_scheduler.ObjectID(random_string()),
|
||||
0)
|
||||
logger.debug("[DRIVER]: submitting task ", task.task_id())
|
||||
driver_handle.node_manager_client.submit(task)
|
||||
logger.debug("[DRIVER]: task return values", task.returns())
|
||||
return task.returns()
|
||||
|
||||
|
||||
def submit_tasks_nodep(driver_handle, num_tasks):
|
||||
''' submit a task that depend on a list of @args'''
|
||||
for i in range(num_tasks):
|
||||
task = ray.local_scheduler.Task(
|
||||
ray.local_scheduler.ObjectID(random_string()),
|
||||
ray.local_scheduler.ObjectID(random_string()),
|
||||
[],
|
||||
1, # num_returns
|
||||
ray.local_scheduler.ObjectID(random_string()),
|
||||
0)
|
||||
|
||||
logger.debug("[DRIVER]: submitting task ", task.task_id())
|
||||
driver_handle.node_manager_client.submit(task)
|
||||
logger.debug("[DRIVER]: task return values", task.returns())
|
||||
|
||||
|
||||
def submit_task_chains(num_chains, tasks_per_chain):
|
||||
# return task placement map on output
|
||||
chain_returns = []
|
||||
task_placement_map_ = {}
|
||||
for chain_num in range(num_chains):
|
||||
last_task_returns = []
|
||||
task_placement_map_[chain_num] = []
|
||||
for i in range(tasks_per_chain):
|
||||
task_returns = submit_task_withdep(
|
||||
driver,
|
||||
task_object_dependencies=last_task_returns)
|
||||
last_task_returns = task_returns
|
||||
task_placement_map_[chain_num].append(task_returns[0])
|
||||
chain_returns.append(last_task_returns)
|
||||
|
||||
logger.debug("chain_returns=", chain_returns)
|
||||
chain_results = driver.get([r[0] for r in chain_returns], timeout_ms=5000)
|
||||
print("[DRIVER]: chain return values: ", chain_results)
|
||||
|
||||
return task_placement_map_
|
||||
|
||||
|
||||
def TEST_run_task_chains(num_chains, tasks_per_chain):
|
||||
task_placement_map = submit_task_chains(num_chains=num_chains,
|
||||
tasks_per_chain=tasks_per_chain)
|
||||
logger.debug("[DRIVER]: task placement information, per chain:")
|
||||
task_placement_total = []
|
||||
for chain_num in range(len(task_placement_map)):
|
||||
task_placement_list = driver.get(task_placement_map[chain_num],
|
||||
timeout_ms=5000)
|
||||
task_placement_total += [t[1] for t in task_placement_list]
|
||||
logger.debug(chain_num, task_placement_list)
|
||||
logger.debug("task placement overall: ", task_placement_total)
|
||||
task_placement_stats = [(v, task_placement_total.count(v))
|
||||
for v in set(task_placement_total)]
|
||||
num_total_tasks = sum([t[1] for t in task_placement_stats])
|
||||
print("total tasks executed = ", num_total_tasks)
|
||||
assert(num_total_tasks == num_chains * tasks_per_chain)
|
||||
print("task placement breakdown: total=", task_placement_stats)
|
||||
|
||||
|
||||
def TEST_run_tasks_nodep(num_tasks):
|
||||
# This test is the same as having num_tasks chains with 1 task per chain
|
||||
# In this test we assume the num_tasks x 1 chain structure.
|
||||
task_placement_map = submit_task_chains(num_chains=num_tasks,
|
||||
tasks_per_chain=1)
|
||||
logger.debug("[DRIVER]: task placement information, per chain:")
|
||||
task_placement_total = []
|
||||
for chain_num in range(len(task_placement_map)):
|
||||
task_placement_list = driver.get(task_placement_map[chain_num],
|
||||
timeout_ms=5000)
|
||||
task_placement_total += [t[1] for t in task_placement_list]
|
||||
logger.debug(chain_num, task_placement_list)
|
||||
logger.debug("task placement overall: ", task_placement_total)
|
||||
task_placement_stats = [(v, task_placement_total.count(v)) for v in
|
||||
set(task_placement_total)]
|
||||
num_total_tasks = sum([t[1] for t in task_placement_stats])
|
||||
print("total tasks executed = ", num_total_tasks)
|
||||
assert(num_total_tasks == num_tasks)
|
||||
print("task placement breakdown: total=", task_placement_stats)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parser.parse_args()
|
||||
|
||||
driver = Worker(args.raylet_socket_name, args.object_store_socket_name,
|
||||
is_worker=False)
|
||||
|
||||
# Set up the experiment : number of chains and tasks per chain.
|
||||
# TEST_run_task_chains(num_chains=10, tasks_per_chain=100)
|
||||
|
||||
TEST_run_tasks_nodep(10000)
|
||||
@@ -0,0 +1,67 @@
|
||||
import logging
|
||||
|
||||
import ray
|
||||
import pyarrow
|
||||
import pyarrow.plasma as plasma
|
||||
from ray.utils import random_string
|
||||
|
||||
|
||||
logging.basicConfig()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# The default return value to put in the object store.
|
||||
RETURN_VALUE = 0
|
||||
|
||||
|
||||
class Worker(object):
|
||||
|
||||
total_task_count = 0
|
||||
|
||||
def __init__(self, raylet_socket_name, object_store_socket_name,
|
||||
is_worker):
|
||||
# Connect to the Raylet and object store.
|
||||
self.node_manager_client = ray.local_scheduler.LocalSchedulerClient(
|
||||
raylet_socket_name, random_string(), is_worker)
|
||||
self.plasma_client = plasma.connect(object_store_socket_name, "", 0)
|
||||
self.serialization_context = pyarrow.default_serialization_context()
|
||||
self.raylet_socket_name = raylet_socket_name
|
||||
self.object_store_socket_name = object_store_socket_name
|
||||
|
||||
def main_loop(self):
|
||||
while True:
|
||||
self.get_task()
|
||||
|
||||
def get(self, object_ids, timeout_ms=-1):
|
||||
for object_id in object_ids:
|
||||
self.node_manager_client.reconstruct_object(object_id.id())
|
||||
plasma_ids = [plasma.ObjectID(argument.id()) for argument in
|
||||
object_ids]
|
||||
values = self.plasma_client.get(plasma_ids, timeout_ms,
|
||||
self.serialization_context)
|
||||
assert(all(value[0] == RETURN_VALUE for value in values))
|
||||
return values
|
||||
|
||||
def get_task(self):
|
||||
logger.debug("[WORKER] waiting for task")
|
||||
task = self.node_manager_client.get_task()
|
||||
logger.debug("Worker assigned %s with arguments %s",
|
||||
ray.utils.binary_to_hex(task.task_id().id()),
|
||||
" ".join([ray.utils.binary_to_hex(argument.id()) for
|
||||
argument in task.arguments()]))
|
||||
|
||||
# Get the arguments. NOTE(swang): This will hang forever if the
|
||||
# arguments have been evicted.
|
||||
arguments = self.get(task.arguments())
|
||||
|
||||
for object_id in task.returns():
|
||||
self.plasma_client.put((RETURN_VALUE, self.raylet_socket_name),
|
||||
plasma.ObjectID(object_id.id()))
|
||||
objval = self.plasma_client.get([plasma.ObjectID(object_id.id())])
|
||||
assert(all([o[0] == RETURN_VALUE for o in objval]))
|
||||
|
||||
logger.debug("Worker returned %s",
|
||||
" ".join([ray.utils.binary_to_hex(return_id.id()) for
|
||||
return_id in task.returns()]))
|
||||
|
||||
# Release the arguments.
|
||||
del arguments
|
||||
@@ -19,17 +19,18 @@ add_custom_command(
|
||||
|
||||
add_custom_target(gen_node_manager_fbs DEPENDS ${NODE_MANAGER_FBS_OUTPUT_FILES})
|
||||
|
||||
ADD_RAY_TEST(raylet_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main pthread ${Boost_SYSTEM_LIBRARY})
|
||||
ADD_RAY_TEST(object_manager_integration_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main pthread ${Boost_SYSTEM_LIBRARY})
|
||||
|
||||
ADD_RAY_TEST(worker_pool_test STATIC_LINK_LIBS ray_static ${PLASMA_STATIC_LIB} ${ARROW_STATIC_LIB} gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY})
|
||||
|
||||
ADD_RAY_TEST(task_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY})
|
||||
ADD_RAY_TEST(lineage_cache_test STATIC_LINK_LIBS ray_static gtest gtest_main gmock_main pthread ${Boost_SYSTEM_LIBRARY})
|
||||
|
||||
add_library(rayletlib raylet.cc ${NODE_MANAGER_FBS_OUTPUT_FILES})
|
||||
target_link_libraries(rayletlib ray_static ${Boost_SYSTEM_LIBRARY})
|
||||
|
||||
add_executable(raylet main.cc)
|
||||
target_link_libraries(raylet rayletlib ${Boost_SYSTEM_LIBRARY} pthread)
|
||||
add_executable(raylet_demo remote_dependencies_demo.cc)
|
||||
target_link_libraries(raylet_demo rayletlib ${Boost_SYSTEM_LIBRARY} pthread)
|
||||
|
||||
install(FILES
|
||||
raylet
|
||||
|
||||
@@ -2,10 +2,14 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
ActorInformation::ActorInformation() : id_(UniqueID::nil()) {}
|
||||
|
||||
ActorInformation::~ActorInformation() {}
|
||||
|
||||
const ActorID &ActorInformation::GetActorId() const { return this->id_; }
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -4,6 +4,9 @@
|
||||
#include "ray/id.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
class ActorInformation {
|
||||
public:
|
||||
/// \brief ActorInformation constructor.
|
||||
@@ -21,6 +24,8 @@ class ActorInformation {
|
||||
ActorID id_;
|
||||
}; // class ActorInformation
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_RAYLET_ACTOR_H
|
||||
|
||||
@@ -51,7 +51,9 @@ enum MessageType:int {
|
||||
// counts and execution dependencies, discard any tasks that already executed
|
||||
// before the checkpoint, and make any tasks on the frontier runnable by
|
||||
// making their execution dependencies available.
|
||||
SetActorFrontier
|
||||
SetActorFrontier,
|
||||
// A node manager request to process a task forwarded from another node manager.
|
||||
ForwardTaskRequest
|
||||
}
|
||||
|
||||
table TaskExecutionSpecification {
|
||||
@@ -82,12 +84,6 @@ table GetTaskReply {
|
||||
gpu_ids: [int];
|
||||
}
|
||||
|
||||
table EventLogMessage {
|
||||
key: string;
|
||||
value: string;
|
||||
timestamp: double;
|
||||
}
|
||||
|
||||
// This struct is used to register a new worker with the local scheduler.
|
||||
// It is shipped as part of local_scheduler_connect.
|
||||
table RegisterClientRequest {
|
||||
@@ -108,57 +104,20 @@ table RegisterClientReply {
|
||||
gpu_ids: [int];
|
||||
}
|
||||
|
||||
table DisconnectClient {
|
||||
}
|
||||
|
||||
table ReconstructObject {
|
||||
// Object ID of the object that needs to be reconstructed.
|
||||
object_id: string;
|
||||
}
|
||||
|
||||
table PutObject {
|
||||
// Task ID of the task that performed the put.
|
||||
task_id: string;
|
||||
// Object ID of the object that is being put.
|
||||
object_id: string;
|
||||
}
|
||||
|
||||
// The ActorFrontier is used to represent the current frontier of tasks that
|
||||
// the local scheduler has marked as runnable for a particular actor. It is
|
||||
// used to save the point in an actor's lifetime at which a checkpoint was
|
||||
// taken, so that the same frontier of tasks can be made runnable again if the
|
||||
// actor is resumed from that checkpoint.
|
||||
table ActorFrontier {
|
||||
// Actor ID of the actor whose frontier is described.
|
||||
actor_id: string;
|
||||
// A list of handle IDs, representing the callers of the actor that have
|
||||
// submitted a runnable task to the local scheduler. A nil ID represents the
|
||||
// creator of the actor.
|
||||
handle_ids: [string];
|
||||
// A list representing the number of tasks executed so far, per handle. Each
|
||||
// count in task_counters corresponds to the handle at the same in index in
|
||||
// handle_ids.
|
||||
task_counters: [long];
|
||||
// A list representing the execution dependency for the next runnable task,
|
||||
// per handle. Each execution dependency in frontier_dependencies corresponds
|
||||
// to the handle at the same in index in handle_ids.
|
||||
frontier_dependencies: [string];
|
||||
}
|
||||
|
||||
table GetActorFrontierRequest {
|
||||
actor_id: string;
|
||||
}
|
||||
|
||||
table RegisterNodeManagerRequest {
|
||||
// GCS ClientID of the connecting node manager.
|
||||
client_id: string;
|
||||
}
|
||||
|
||||
table ForwardTaskRequest {
|
||||
// The task to be forwarded.
|
||||
// TODO(swang): Replace with a Task flatbuffer type.
|
||||
task: string;
|
||||
// The uncommitted lineage of the forwarded task, according to the sending
|
||||
// node manager.
|
||||
uncommitted_lineage: [string];
|
||||
// The ID of the task to be forwarded.
|
||||
task_id: string;
|
||||
// The tasks in the uncommitted lineage of the forwarded task. This
|
||||
// should include task_id.
|
||||
uncommitted_tasks: [Task];
|
||||
}
|
||||
|
||||
table ReconstructObject {
|
||||
// Object ID of the object that needs to be reconstructed.
|
||||
object_id: string;
|
||||
}
|
||||
|
||||
+281
-13
@@ -2,30 +2,298 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
LineageCache::LineageCache() {}
|
||||
namespace raylet {
|
||||
|
||||
ray::Status LineageCache::AddTask(const Task &task) {
|
||||
throw std::runtime_error("method not implemented");
|
||||
return ray::Status::OK();
|
||||
LineageEntry::LineageEntry(const Task &task, GcsStatus status)
|
||||
: status_(status), task_(task) {}
|
||||
|
||||
GcsStatus LineageEntry::GetStatus() const { return status_; }
|
||||
|
||||
bool LineageEntry::SetStatus(GcsStatus new_status) {
|
||||
if (status_ < new_status) {
|
||||
status_ = new_status;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
ray::Status LineageCache::AddTask(const Task &task, const Lineage &uncommitted_lineage) {
|
||||
throw std::runtime_error("method not implemented");
|
||||
return ray::Status::OK();
|
||||
void LineageEntry::ResetStatus(GcsStatus new_status) {
|
||||
RAY_CHECK(new_status < status_);
|
||||
status_ = new_status;
|
||||
}
|
||||
|
||||
ray::Status LineageCache::AddObjectLocation(const ObjectID &object_id) {
|
||||
throw std::runtime_error("method not implemented");
|
||||
return ray::Status::OK();
|
||||
const TaskID LineageEntry::GetEntryId() const {
|
||||
return task_.GetTaskSpecification().TaskId();
|
||||
}
|
||||
|
||||
Lineage &LineageCache::GetUncommittedLineage(const ObjectID &object_id) {
|
||||
throw std::runtime_error("method not implemented");
|
||||
const std::unordered_set<UniqueID, UniqueIDHasher> LineageEntry::GetParentTaskIds()
|
||||
const {
|
||||
std::unordered_set<UniqueID, UniqueIDHasher> parent_ids;
|
||||
// A task's parents are the tasks that created its arguments.
|
||||
auto dependencies = task_.GetDependencies();
|
||||
for (auto &dependency : dependencies) {
|
||||
parent_ids.insert(ComputeTaskId(dependency));
|
||||
}
|
||||
return parent_ids;
|
||||
}
|
||||
|
||||
const Task &LineageEntry::TaskData() const { return task_; }
|
||||
|
||||
Task &LineageEntry::TaskDataMutable() { return task_; }
|
||||
|
||||
Lineage::Lineage() {}
|
||||
|
||||
Lineage::Lineage(const protocol::ForwardTaskRequest &task_request) {
|
||||
// Deserialize and set entries for the uncommitted tasks.
|
||||
auto tasks = task_request.uncommitted_tasks();
|
||||
for (auto it = tasks->begin(); it != tasks->end(); it++) {
|
||||
auto task = Task(**it);
|
||||
LineageEntry entry(task, GcsStatus_UNCOMMITTED_REMOTE);
|
||||
RAY_CHECK(SetEntry(std::move(entry)));
|
||||
}
|
||||
}
|
||||
|
||||
boost::optional<const LineageEntry &> Lineage::GetEntry(const UniqueID &task_id) const {
|
||||
auto entry = entries_.find(task_id);
|
||||
if (entry != entries_.end()) {
|
||||
return entry->second;
|
||||
} else {
|
||||
return boost::optional<const LineageEntry &>();
|
||||
}
|
||||
}
|
||||
|
||||
boost::optional<LineageEntry &> Lineage::GetEntryMutable(const UniqueID &task_id) {
|
||||
auto entry = entries_.find(task_id);
|
||||
if (entry != entries_.end()) {
|
||||
return entry->second;
|
||||
} else {
|
||||
return boost::optional<LineageEntry &>();
|
||||
}
|
||||
}
|
||||
|
||||
bool Lineage::SetEntry(LineageEntry &&new_entry) {
|
||||
// Get the status of the current entry at the key.
|
||||
auto task_id = new_entry.GetEntryId();
|
||||
GcsStatus current_status = GcsStatus_NONE;
|
||||
auto current_entry = PopEntry(task_id);
|
||||
if (current_entry) {
|
||||
current_status = current_entry->GetStatus();
|
||||
}
|
||||
|
||||
if (current_status < new_entry.GetStatus()) {
|
||||
// If the new status is greater, then overwrite the current entry.
|
||||
entries_.emplace(std::make_pair(task_id, std::move(new_entry)));
|
||||
return true;
|
||||
} else {
|
||||
// If the new status is not greater, then the new entry is invalid. Replace
|
||||
// the current entry at the key.
|
||||
entries_.emplace(std::make_pair(task_id, std::move(*current_entry)));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
boost::optional<LineageEntry> Lineage::PopEntry(const UniqueID &task_id) {
|
||||
auto entry = entries_.find(task_id);
|
||||
if (entry != entries_.end()) {
|
||||
LineageEntry entry = std::move(entries_.at(task_id));
|
||||
entries_.erase(task_id);
|
||||
return entry;
|
||||
} else {
|
||||
return boost::optional<LineageEntry>();
|
||||
}
|
||||
}
|
||||
|
||||
const std::unordered_map<const UniqueID, LineageEntry, UniqueIDHasher>
|
||||
&Lineage::GetEntries() const {
|
||||
return entries_;
|
||||
}
|
||||
|
||||
flatbuffers::Offset<protocol::ForwardTaskRequest> Lineage::ToFlatbuffer(
|
||||
flatbuffers::FlatBufferBuilder &fbb, const TaskID &task_id) const {
|
||||
RAY_CHECK(GetEntry(task_id));
|
||||
// Serialize the task and object entries.
|
||||
std::vector<flatbuffers::Offset<protocol::Task>> uncommitted_tasks;
|
||||
for (const auto &entry : entries_) {
|
||||
uncommitted_tasks.push_back(entry.second.TaskData().ToFlatbuffer(fbb));
|
||||
}
|
||||
|
||||
auto request = protocol::CreateForwardTaskRequest(fbb, to_flatbuf(fbb, task_id),
|
||||
fbb.CreateVector(uncommitted_tasks));
|
||||
return request;
|
||||
}
|
||||
|
||||
LineageCache::LineageCache(gcs::TableInterface<TaskID, protocol::Task> &task_storage)
|
||||
: task_storage_(task_storage) {}
|
||||
|
||||
/// A helper function to merge one lineage into another, in DFS order.
|
||||
///
|
||||
/// \param task_id The current entry to merge from lineage_from into
|
||||
/// lineage_to.
|
||||
/// \param lineage_from The lineage to merge entries from. This lineage is
|
||||
/// traversed by following each entry's parent pointers in DFS order,
|
||||
/// until an entry is not found or the stopping condition is reached.
|
||||
/// \param lineage_to The lineage to merge entries into.
|
||||
/// \param stopping_condition A stopping condition for the DFS over
|
||||
/// lineage_from. This should return true if the merge should stop.
|
||||
void MergeLineageHelper(const UniqueID &task_id, const Lineage &lineage_from,
|
||||
Lineage &lineage_to,
|
||||
std::function<bool(GcsStatus)> stopping_condition) {
|
||||
// If the entry is not found in the lineage to merge, then we stop since
|
||||
// there is nothing to copy into the merged lineage.
|
||||
auto entry = lineage_from.GetEntry(task_id);
|
||||
if (!entry) {
|
||||
return;
|
||||
}
|
||||
// Check whether we should stop at this entry in the DFS.
|
||||
auto status = entry->GetStatus();
|
||||
if (stopping_condition(status)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Insert a copy of the entry into lineage_to.
|
||||
LineageEntry entry_copy = *entry;
|
||||
auto parent_ids = entry_copy.GetParentTaskIds();
|
||||
// If the insert is successful, then continue the DFS. The insert will fail
|
||||
// if the new entry has an equal or lower GCS status than the current entry
|
||||
// in lineage_to. This also prevents us from traversing the same node twice.
|
||||
if (lineage_to.SetEntry(std::move(entry_copy))) {
|
||||
for (const auto &parent_id : parent_ids) {
|
||||
MergeLineageHelper(parent_id, lineage_from, lineage_to, stopping_condition);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void LineageCache::AddWaitingTask(const Task &task, const Lineage &uncommitted_lineage) {
|
||||
auto task_id = task.GetTaskSpecification().TaskId();
|
||||
// Merge the uncommitted lineage into the lineage cache.
|
||||
MergeLineageHelper(task_id, uncommitted_lineage, lineage_, [](GcsStatus status) {
|
||||
if (status != GcsStatus_NONE) {
|
||||
// We received the uncommitted lineage from a remote node, so make sure
|
||||
// that all entries in the lineage to merge have status
|
||||
// UNCOMMITTED_REMOTE.
|
||||
RAY_CHECK(status == GcsStatus_UNCOMMITTED_REMOTE);
|
||||
}
|
||||
// The only stopping condition is that an entry is not found.
|
||||
return false;
|
||||
});
|
||||
|
||||
// Add the submitted task to the lineage cache as UNCOMMITTED_WAITING. It
|
||||
// should be marked as UNCOMMITTED_READY once the task starts execution.
|
||||
LineageEntry task_entry(task, GcsStatus_UNCOMMITTED_WAITING);
|
||||
RAY_CHECK(lineage_.SetEntry(std::move(task_entry)));
|
||||
}
|
||||
|
||||
void LineageCache::AddReadyTask(const Task &task) {
|
||||
auto new_entry = LineageEntry(task, GcsStatus_UNCOMMITTED_READY);
|
||||
RAY_CHECK(lineage_.SetEntry(std::move(new_entry)));
|
||||
}
|
||||
|
||||
void LineageCache::RemoveWaitingTask(const TaskID &task_id) {
|
||||
auto entry = lineage_.PopEntry(task_id);
|
||||
// It's only okay to remove a task that is waiting for execution.
|
||||
// TODO(swang): Is this necessarily true when there is reconstruction?
|
||||
RAY_CHECK(entry->GetStatus() == GcsStatus_UNCOMMITTED_WAITING);
|
||||
// Reset the status to REMOTE. We keep the task instead of removing it
|
||||
// completely in case another task is submitted locally that depends on this
|
||||
// one.
|
||||
entry->ResetStatus(GcsStatus_UNCOMMITTED_REMOTE);
|
||||
RAY_CHECK(lineage_.SetEntry(std::move(*entry)));
|
||||
}
|
||||
|
||||
Lineage LineageCache::GetUncommittedLineage(const TaskID &task_id) const {
|
||||
Lineage uncommitted_lineage;
|
||||
// Add all uncommitted ancestors from the lineage cache to the uncommitted
|
||||
// lineage of the requested task.
|
||||
MergeLineageHelper(task_id, lineage_, uncommitted_lineage, [](GcsStatus status) {
|
||||
// The stopping condition for recursion is that the entry has been
|
||||
// committed to the GCS.
|
||||
return status == GcsStatus_COMMITTED;
|
||||
});
|
||||
return uncommitted_lineage;
|
||||
}
|
||||
|
||||
Status LineageCache::Flush() {
|
||||
throw std::runtime_error("method not implemented");
|
||||
// Find all tasks that are READY and whose arguments have been committed in the GCS.
|
||||
std::vector<TaskID> ready_task_ids;
|
||||
for (const auto &pair : lineage_.GetEntries()) {
|
||||
auto task_id = pair.first;
|
||||
auto entry = pair.second;
|
||||
// Skip task entries that are not ready to be written yet. These tasks
|
||||
// either have not started execution yet, are being executed on a remote
|
||||
// node, or have already been written to the GCS.
|
||||
if (entry.GetStatus() != GcsStatus_UNCOMMITTED_READY) {
|
||||
continue;
|
||||
}
|
||||
// Check if all arguments have been committed to the GCS before writing
|
||||
// this task.
|
||||
bool all_arguments_committed = true;
|
||||
for (const auto &parent_id : entry.GetParentTaskIds()) {
|
||||
auto parent = lineage_.GetEntry(parent_id);
|
||||
// If a parent entry exists in the lineage cache but has not been
|
||||
// committed yet, then as far as we know, it's still in flight to the
|
||||
// GCS. Skip this task for now.
|
||||
if (parent && parent->GetStatus() != GcsStatus_COMMITTED) {
|
||||
// TODO(swang): Once GCS notifications for the task table are ready,
|
||||
// request notification for commit of the parent task here.
|
||||
all_arguments_committed = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (all_arguments_committed) {
|
||||
// All arguments have been committed to the GCS. Add this task to the
|
||||
// list of tasks to write back to the GCS.
|
||||
ready_task_ids.push_back(task_id);
|
||||
}
|
||||
}
|
||||
|
||||
// Write back all ready tasks whose arguments have been committed to the GCS.
|
||||
gcs::raylet::TaskTable::WriteCallback task_callback =
|
||||
[this](ray::gcs::AsyncGcsClient *client, const TaskID &id,
|
||||
const std::shared_ptr<protocol::TaskT> data) { HandleEntryCommitted(id); };
|
||||
for (const auto &ready_task_id : ready_task_ids) {
|
||||
auto task = lineage_.GetEntry(ready_task_id);
|
||||
// TODO(swang): Make this better...
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = task->TaskData().ToFlatbuffer(fbb);
|
||||
fbb.Finish(message);
|
||||
auto task_data = std::make_shared<protocol::TaskT>();
|
||||
auto root = flatbuffers::GetRoot<protocol::Task>(fbb.GetBufferPointer());
|
||||
root->UnPackTo(task_data.get());
|
||||
RAY_CHECK_OK(task_storage_.Add(task->TaskData().GetTaskSpecification().DriverId(),
|
||||
ready_task_id, task_data, task_callback));
|
||||
|
||||
// We successfully wrote the task, so mark it as committing.
|
||||
// TODO(swang): Use a batched interface and write with all object entries.
|
||||
auto entry = lineage_.PopEntry(ready_task_id);
|
||||
RAY_CHECK(entry->SetStatus(GcsStatus_COMMITTING));
|
||||
RAY_CHECK(lineage_.SetEntry(std::move(*entry)));
|
||||
}
|
||||
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
void PopAncestorTasks(const UniqueID &task_id, Lineage &lineage) {
|
||||
auto entry = lineage.PopEntry(task_id);
|
||||
if (!entry) {
|
||||
return;
|
||||
}
|
||||
auto status = entry->GetStatus();
|
||||
RAY_CHECK(status == GcsStatus_UNCOMMITTED_REMOTE || status == GcsStatus_COMMITTED);
|
||||
for (const auto &parent_id : entry->GetParentTaskIds()) {
|
||||
PopAncestorTasks(parent_id, lineage);
|
||||
}
|
||||
}
|
||||
|
||||
void LineageCache::HandleEntryCommitted(const UniqueID &task_id) {
|
||||
auto entry = lineage_.PopEntry(task_id);
|
||||
for (const auto &parent_id : entry->GetParentTaskIds()) {
|
||||
PopAncestorTasks(parent_id, lineage_);
|
||||
}
|
||||
RAY_CHECK(entry->SetStatus(GcsStatus_COMMITTED));
|
||||
RAY_CHECK(lineage_.SetEntry(std::move(*entry)));
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
+178
-45
@@ -1,83 +1,216 @@
|
||||
#ifndef RAY_RAYLET_LINEAGE_CACHE_H
|
||||
#define RAY_RAYLET_LINEAGE_CACHE_H
|
||||
|
||||
#include <boost/optional.hpp>
|
||||
|
||||
// clang-format off
|
||||
#include "common_protocol.h"
|
||||
#include "ray/raylet/task.h"
|
||||
#include "ray/gcs/tables.h"
|
||||
#include "ray/id.h"
|
||||
#include "ray/status.h"
|
||||
// clang-format on
|
||||
|
||||
namespace ray {
|
||||
|
||||
// TODO(swang): Define this class.
|
||||
class Lineage {};
|
||||
namespace raylet {
|
||||
|
||||
class LineageCacheEntry {
|
||||
private:
|
||||
// TODO(swang): This should be an enum of the state of the entry - goes from
|
||||
// completely local, to dirty, to in flight, to committed.
|
||||
// bool dirty_;
|
||||
/// The status of a lineage cache entry according to its status in the GCS.
|
||||
enum GcsStatus {
|
||||
/// The task is not in the lineage cache.
|
||||
GcsStatus_NONE = 0,
|
||||
/// The task is being executed or created on a remote node.
|
||||
GcsStatus_UNCOMMITTED_REMOTE,
|
||||
/// The task is waiting to be executed or created locally.
|
||||
GcsStatus_UNCOMMITTED_WAITING,
|
||||
/// The task has started execution, but the entry has not been written to the
|
||||
/// GCS yet.
|
||||
GcsStatus_UNCOMMITTED_READY,
|
||||
/// The task has been written to the GCS and we are waiting for an
|
||||
/// acknowledgement of the commit.
|
||||
GcsStatus_COMMITTING,
|
||||
/// The task has been committed in the GCS. It's safe to remove this entry
|
||||
/// from the lineage cache.
|
||||
GcsStatus_COMMITTED,
|
||||
};
|
||||
|
||||
class LineageCacheTaskEntry : public LineageCacheEntry {};
|
||||
class LineageCacheObjectEntry : public LineageCacheEntry {};
|
||||
/// \class LineageEntry
|
||||
///
|
||||
/// A task entry in the data lineage. Each entry's parents are the tasks that
|
||||
/// created the entry's arguments.
|
||||
class LineageEntry {
|
||||
public:
|
||||
/// Create an entry for a task.
|
||||
///
|
||||
/// \param task The task data to eventually be written back to the GCS.
|
||||
/// \param status The status of this entry, according to its write status in
|
||||
/// the GCS.
|
||||
LineageEntry(const Task &task, GcsStatus status);
|
||||
|
||||
/// Get this entry's GCS status.
|
||||
///
|
||||
/// \return The entry's status in the GCS.
|
||||
GcsStatus GetStatus() const;
|
||||
|
||||
/// Set this entry's GCS status. The status is only set if the new status
|
||||
/// is strictly greater than the entry's previous status, according to the
|
||||
/// GcsStatus enum.
|
||||
///
|
||||
/// \param new_status Set the entry's status to this value if it is greater
|
||||
/// than the current status.
|
||||
/// \return Whether the entry was set to the new status.
|
||||
bool SetStatus(GcsStatus new_status);
|
||||
|
||||
/// Reset this entry's GCS status to a lower status. The new status must
|
||||
/// be lower than the current status.
|
||||
///
|
||||
/// \param new_status This must be lower than the current status.
|
||||
void ResetStatus(GcsStatus new_status);
|
||||
|
||||
/// Get this entry's ID.
|
||||
///
|
||||
/// \return The entry's ID.
|
||||
const TaskID GetEntryId() const;
|
||||
|
||||
/// Get the IDs of this entry's parent tasks. These are the IDs of the tasks
|
||||
/// that created its arguments.
|
||||
///
|
||||
/// \return The IDs of the parent entries.
|
||||
const std::unordered_set<TaskID, UniqueIDHasher> GetParentTaskIds() const;
|
||||
|
||||
/// Get the task data.
|
||||
///
|
||||
/// \return The task data.
|
||||
const Task &TaskData() const;
|
||||
|
||||
Task &TaskDataMutable();
|
||||
|
||||
private:
|
||||
/// The current state of this entry according to its status in the GCS.
|
||||
GcsStatus status_;
|
||||
/// The task data to be written to the GCS. This is nullptr if the entry is
|
||||
/// an object.
|
||||
// const Task task_;
|
||||
Task task_;
|
||||
};
|
||||
|
||||
/// \class Lineage
|
||||
///
|
||||
/// A lineage DAG, according to the data dependency graph. Each node is a task,
|
||||
/// with an outgoing edge to each of its parent tasks. For a given task, the
|
||||
/// parents are the tasks that created its arguments. Each entry also records
|
||||
/// the current status in the GCS for that task or object.
|
||||
class Lineage {
|
||||
public:
|
||||
/// Construct an empty Lineage.
|
||||
Lineage();
|
||||
|
||||
/// Construct a Lineage from a ForwardTaskRequest.
|
||||
///
|
||||
/// \param task_request The request to construct the lineage from. All
|
||||
/// uncommitted tasks in the request will be added to the lineage.
|
||||
Lineage(const protocol::ForwardTaskRequest &task_request);
|
||||
|
||||
/// Get an entry from the lineage.
|
||||
///
|
||||
/// \param entry_id The ID of the entry to get.
|
||||
/// \return An optional reference to the entry. If this is empty, then the
|
||||
/// entry ID is not in the lineage.
|
||||
boost::optional<const LineageEntry &> GetEntry(const TaskID &entry_id) const;
|
||||
boost::optional<LineageEntry &> GetEntryMutable(const UniqueID &task_id);
|
||||
|
||||
/// Set an entry in the lineage. If an entry with this ID already exists,
|
||||
/// then the entry is overwritten if and only if the new entry has a higher
|
||||
/// GCS status than the current. The current entry's object or task data will
|
||||
/// also be overwritten.
|
||||
///
|
||||
/// \param entry The new entry to set in the lineage, if its GCS status is
|
||||
/// greater than the current entry.
|
||||
/// \return Whether the entry was set.
|
||||
bool SetEntry(LineageEntry &&entry);
|
||||
|
||||
/// Delete and return an entry from the lineage.
|
||||
///
|
||||
/// \param entry_id The ID of the entry to pop.
|
||||
/// \return An optional reference to the popped entry. If this is empty, then
|
||||
/// the entry ID is not in the lineage.
|
||||
boost::optional<LineageEntry> PopEntry(const TaskID &entry_id);
|
||||
|
||||
/// Get all entries in the lineage.
|
||||
///
|
||||
/// \return A const reference to the lineage entries.
|
||||
const std::unordered_map<const TaskID, LineageEntry, UniqueIDHasher> &GetEntries()
|
||||
const;
|
||||
|
||||
/// Serialize this lineage to a ForwardTaskRequest flatbuffer.
|
||||
///
|
||||
/// \param entry_id The task ID to include in the ForwardTaskRequest
|
||||
/// flatbuffer.
|
||||
/// \return An offset to the serialized lineage. The serialization includes
|
||||
/// all task and object entries in the lineage.
|
||||
flatbuffers::Offset<protocol::ForwardTaskRequest> ToFlatbuffer(
|
||||
flatbuffers::FlatBufferBuilder &fbb, const TaskID &entry_id) const;
|
||||
|
||||
private:
|
||||
/// The lineage entries.
|
||||
std::unordered_map<const TaskID, LineageEntry, UniqueIDHasher> entries_;
|
||||
};
|
||||
|
||||
/// \class LineageCache
|
||||
///
|
||||
/// A cache of the object and task tables. This consists of all tasks that this
|
||||
/// node owns, as well as their task lineage, that have not yet been added
|
||||
/// durably to the GCS.
|
||||
/// A cache of the task table. This consists of all tasks that this node owns,
|
||||
/// as well as their lineage, that have not yet been added durably to the GCS.
|
||||
class LineageCache {
|
||||
public:
|
||||
/// Create a lineage cache policy.
|
||||
/// TODO(swang): Pass in the policy (interface?) and a GCS client.
|
||||
LineageCache();
|
||||
/// Create a lineage cache for the given task storage system.
|
||||
/// TODO(swang): Pass in the policy (interface?).
|
||||
LineageCache(gcs::TableInterface<TaskID, protocol::Task> &task_storage);
|
||||
|
||||
/// Add a task and its object outputs asynchronously to the GCS. This
|
||||
/// overwrites the task's mutable fields in the execution specification.
|
||||
/// Add a task that is waiting for execution and its uncommitted lineage.
|
||||
/// These entries will not be written to the GCS until set to ready.
|
||||
///
|
||||
/// \param task The task to add.
|
||||
/// \return Status.
|
||||
ray::Status AddTask(const Task &task);
|
||||
|
||||
/// Add a task and its uncommitted lineage asynchronously to the GCS. The
|
||||
/// mutable fields for the given task will be overwritten, but not for the
|
||||
/// tasks in the uncommitted lineage.
|
||||
///
|
||||
/// \param task The task to add.
|
||||
/// \param task The waiting task to add.
|
||||
/// \param uncommitted_lineage The task's uncommitted lineage. These are the
|
||||
/// tasks that the given task is data-dependent on, but that have not
|
||||
/// been made durable in the GCS, as far as we know.
|
||||
/// \return Status.
|
||||
ray::Status AddTask(const Task &task, const Lineage &uncommitted_lineage);
|
||||
/// tasks that the given task is data-dependent on, but that have not
|
||||
/// been made durable in the GCS, as far the task's submitter knows.
|
||||
void AddWaitingTask(const Task &task, const Lineage &uncommitted_lineage);
|
||||
|
||||
/// Add this node as an object location, to be asynchronously committed to
|
||||
/// the GCS.
|
||||
/// Add a task that is ready for GCS writeback. This overwrites the task’s
|
||||
/// mutable fields in the execution specification.
|
||||
///
|
||||
/// \param object_id The object to add a location for.
|
||||
/// \return Status.
|
||||
ray::Status AddObjectLocation(const ObjectID &object_id);
|
||||
/// \param task The task to set as ready.
|
||||
void AddReadyTask(const Task &task);
|
||||
|
||||
/// Get the uncommitted lineage of an object. These are the tasks that the
|
||||
/// given object is data-dependent on, but that have not been made durable in
|
||||
void RemoveWaitingTask(const TaskID &entry_id);
|
||||
|
||||
/// Get the uncommitted lineage of a task. The uncommitted lineage consists
|
||||
/// of all tasks in the given task's lineage that have not been committed in
|
||||
/// the GCS, as far as we know.
|
||||
///
|
||||
/// \param object_id The object to get the uncommitted lineage for.
|
||||
/// \return The uncommitted lineage of the object.
|
||||
Lineage &GetUncommittedLineage(const ObjectID &object_id);
|
||||
/// \param entry_id The ID of the task to get the uncommitted lineage for.
|
||||
/// \return The uncommitted lineage of the task. The returned lineage
|
||||
/// includes the entry for the requested entry_id.
|
||||
Lineage GetUncommittedLineage(const TaskID &entry_id) const;
|
||||
|
||||
/// Asynchronously write any tasks and object locations that have been added
|
||||
/// since the last flush to the GCS. When each write is acknowledged, its
|
||||
/// entry will be marked as committed.
|
||||
/// Asynchronously write any tasks that have been added since the last flush
|
||||
/// to the GCS. When each write is acknowledged, its entry will be marked as
|
||||
/// committed.
|
||||
///
|
||||
/// \return Status.
|
||||
Status Flush();
|
||||
|
||||
private:
|
||||
std::unordered_map<TaskID, LineageCacheTaskEntry, UniqueIDHasher> task_table_;
|
||||
std::unordered_map<ObjectID, LineageCacheObjectEntry, UniqueIDHasher> object_table_;
|
||||
void HandleEntryCommitted(const TaskID &unique_id);
|
||||
|
||||
/// The durable storage system for task information.
|
||||
gcs::TableInterface<TaskID, protocol::Task> &task_storage_;
|
||||
/// All tasks and objects that we are responsible for writing back to the
|
||||
/// GCS, and the tasks and objects in their lineage.
|
||||
Lineage lineage_;
|
||||
};
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_RAYLET_LINEAGE_CACHE_H
|
||||
|
||||
@@ -0,0 +1,270 @@
|
||||
#include <list>
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ray/raylet/format/node_manager_generated.h"
|
||||
#include "ray/raylet/lineage_cache.h"
|
||||
#include "ray/raylet/task.h"
|
||||
#include "ray/raylet/task_execution_spec.h"
|
||||
#include "ray/raylet/task_spec.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
class MockGcs : virtual public gcs::TableInterface<TaskID, protocol::Task> {
|
||||
public:
|
||||
MockGcs(){};
|
||||
Status Add(const JobID &job_id, const TaskID &task_id,
|
||||
std::shared_ptr<protocol::TaskT> task_data,
|
||||
const gcs::TableInterface<TaskID, protocol::Task>::WriteCallback &done) {
|
||||
task_table_[task_id] = task_data;
|
||||
callbacks_.push_back(
|
||||
std::pair<gcs::raylet::TaskTable::WriteCallback, TaskID>(done, task_id));
|
||||
return ray::Status::OK();
|
||||
};
|
||||
|
||||
void Flush() {
|
||||
for (const auto &callback : callbacks_) {
|
||||
callback.first(NULL, callback.second, task_table_[callback.second]);
|
||||
}
|
||||
callbacks_.clear();
|
||||
};
|
||||
|
||||
const std::unordered_map<TaskID, std::shared_ptr<protocol::TaskT>, UniqueIDHasher>
|
||||
&TaskTable() const {
|
||||
return task_table_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<TaskID, std::shared_ptr<protocol::TaskT>, UniqueIDHasher>
|
||||
task_table_;
|
||||
std::vector<std::pair<gcs::raylet::TaskTable::WriteCallback, TaskID>> callbacks_;
|
||||
};
|
||||
|
||||
class LineageCacheTest : public ::testing::Test {
|
||||
public:
|
||||
LineageCacheTest() : mock_gcs_(), lineage_cache_(mock_gcs_) {}
|
||||
|
||||
protected:
|
||||
MockGcs mock_gcs_;
|
||||
LineageCache lineage_cache_;
|
||||
};
|
||||
|
||||
static inline Task ExampleTask(const std::vector<ObjectID> &arguments,
|
||||
int64_t num_returns) {
|
||||
std::unordered_map<std::string, double> required_resources;
|
||||
std::vector<std::shared_ptr<TaskArgument>> task_arguments;
|
||||
for (auto &argument : arguments) {
|
||||
std::vector<ObjectID> references = {argument};
|
||||
task_arguments.emplace_back(std::make_shared<TaskArgumentByReference>(references));
|
||||
}
|
||||
auto spec = TaskSpecification(UniqueID::nil(), UniqueID::from_random(), 0,
|
||||
UniqueID::from_random(), task_arguments, num_returns,
|
||||
required_resources);
|
||||
auto execution_spec = TaskExecutionSpecification(std::vector<ObjectID>());
|
||||
execution_spec.IncrementNumForwards();
|
||||
Task task = Task(execution_spec, spec);
|
||||
return task;
|
||||
}
|
||||
|
||||
std::vector<ObjectID> InsertTaskChain(LineageCache &lineage_cache,
|
||||
std::vector<Task> &inserted_tasks, int chain_size,
|
||||
const std::vector<ObjectID> &initial_arguments,
|
||||
int64_t num_returns) {
|
||||
Lineage empty_lineage;
|
||||
std::vector<ObjectID> arguments = initial_arguments;
|
||||
for (int i = 0; i < chain_size; i++) {
|
||||
auto task = ExampleTask(arguments, num_returns);
|
||||
lineage_cache.AddWaitingTask(task, empty_lineage);
|
||||
inserted_tasks.push_back(task);
|
||||
arguments.clear();
|
||||
for (int j = 0; j < task.GetTaskSpecification().NumReturns(); j++) {
|
||||
arguments.push_back(task.GetTaskSpecification().ReturnId(j));
|
||||
}
|
||||
}
|
||||
return arguments;
|
||||
}
|
||||
|
||||
TEST_F(LineageCacheTest, TestGetUncommittedLineage) {
|
||||
// Insert two independent chains of tasks.
|
||||
std::vector<Task> tasks1;
|
||||
auto return_values1 =
|
||||
InsertTaskChain(lineage_cache_, tasks1, 3, std::vector<ObjectID>(), 1);
|
||||
std::vector<TaskID> task_ids1;
|
||||
for (const auto &task : tasks1) {
|
||||
task_ids1.push_back(task.GetTaskSpecification().TaskId());
|
||||
}
|
||||
|
||||
std::vector<Task> tasks2;
|
||||
auto return_values2 =
|
||||
InsertTaskChain(lineage_cache_, tasks2, 2, std::vector<ObjectID>(), 2);
|
||||
std::vector<TaskID> task_ids2;
|
||||
for (const auto &task : tasks2) {
|
||||
task_ids2.push_back(task.GetTaskSpecification().TaskId());
|
||||
}
|
||||
|
||||
// Get the uncommitted lineage for the last task (the leaf) of one of the
|
||||
// chains.
|
||||
auto uncommitted_lineage = lineage_cache_.GetUncommittedLineage(task_ids1.back());
|
||||
// Check that the uncommitted lineage is exactly equal to the first chain of
|
||||
// tasks.
|
||||
ASSERT_EQ(task_ids1.size(), uncommitted_lineage.GetEntries().size());
|
||||
for (auto &task_id : task_ids1) {
|
||||
ASSERT_TRUE(uncommitted_lineage.GetEntry(task_id));
|
||||
}
|
||||
|
||||
// Insert one task that is dependent on the previous chains of tasks.
|
||||
std::vector<Task> combined_tasks = tasks1;
|
||||
combined_tasks.insert(combined_tasks.end(), tasks2.begin(), tasks2.end());
|
||||
std::vector<ObjectID> combined_arguments = return_values1;
|
||||
combined_arguments.insert(combined_arguments.end(), return_values2.begin(),
|
||||
return_values2.end());
|
||||
InsertTaskChain(lineage_cache_, combined_tasks, 1, combined_arguments, 1);
|
||||
std::vector<TaskID> combined_task_ids;
|
||||
for (const auto &task : combined_tasks) {
|
||||
combined_task_ids.push_back(task.GetTaskSpecification().TaskId());
|
||||
}
|
||||
|
||||
// Get the uncommitted lineage for the inserted task.
|
||||
uncommitted_lineage = lineage_cache_.GetUncommittedLineage(combined_task_ids.back());
|
||||
// Check that the uncommitted lineage is exactly equal to the entire set of
|
||||
// tasks inserted so far.
|
||||
ASSERT_EQ(combined_task_ids.size(), uncommitted_lineage.GetEntries().size());
|
||||
for (auto &task_id : combined_task_ids) {
|
||||
ASSERT_TRUE(uncommitted_lineage.GetEntry(task_id));
|
||||
}
|
||||
}
|
||||
|
||||
void CheckFlush(LineageCache &lineage_cache, MockGcs &mock_gcs,
|
||||
size_t num_tasks_flushed) {
|
||||
RAY_CHECK_OK(lineage_cache.Flush());
|
||||
ASSERT_EQ(mock_gcs.TaskTable().size(), num_tasks_flushed);
|
||||
}
|
||||
|
||||
TEST_F(LineageCacheTest, TestWritebackNoneReady) {
|
||||
// Insert a chain of dependent tasks.
|
||||
size_t num_tasks_flushed = 0;
|
||||
std::vector<Task> tasks;
|
||||
auto return_values1 =
|
||||
InsertTaskChain(lineage_cache_, tasks, 3, std::vector<ObjectID>(), 1);
|
||||
|
||||
// Check that when no tasks have been marked as ready, we do not flush any
|
||||
// entries.
|
||||
CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed);
|
||||
}
|
||||
|
||||
TEST_F(LineageCacheTest, TestWritebackReady) {
|
||||
// Insert a chain of dependent tasks.
|
||||
size_t num_tasks_flushed = 0;
|
||||
std::vector<Task> tasks;
|
||||
auto return_values1 =
|
||||
InsertTaskChain(lineage_cache_, tasks, 3, std::vector<ObjectID>(), 1);
|
||||
|
||||
// Check that after marking the first task as ready, we flush only that task.
|
||||
lineage_cache_.AddReadyTask(tasks.front());
|
||||
num_tasks_flushed++;
|
||||
CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed);
|
||||
}
|
||||
|
||||
TEST_F(LineageCacheTest, TestWritebackOrder) {
|
||||
// Insert a chain of dependent tasks.
|
||||
size_t num_tasks_flushed = 0;
|
||||
std::vector<Task> tasks;
|
||||
auto return_values1 =
|
||||
InsertTaskChain(lineage_cache_, tasks, 3, std::vector<ObjectID>(), 1);
|
||||
|
||||
// Mark all tasks as ready.
|
||||
for (const auto &task : tasks) {
|
||||
lineage_cache_.AddReadyTask(task);
|
||||
}
|
||||
// Check that we write back the tasks in order of data dependencies.
|
||||
for (size_t i = 0; i < tasks.size(); i++) {
|
||||
num_tasks_flushed++;
|
||||
CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed);
|
||||
// Flush acknowledgements. The next task should be able to be written.
|
||||
mock_gcs_.Flush();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(LineageCacheTest, TestWritebackPartiallyReady) {
|
||||
// Create two independent tasks, task1 and task2, and a dependent task
|
||||
// that depends on both tasks.
|
||||
size_t num_tasks_flushed = 0;
|
||||
auto task1 = ExampleTask({}, 1);
|
||||
auto task2 = ExampleTask({}, 1);
|
||||
std::vector<ObjectID> returns;
|
||||
for (int64_t i = 0; i < task1.GetTaskSpecification().NumReturns(); i++) {
|
||||
returns.push_back(task1.GetTaskSpecification().ReturnId(i));
|
||||
}
|
||||
for (int64_t i = 0; i < task2.GetTaskSpecification().NumReturns(); i++) {
|
||||
returns.push_back(task2.GetTaskSpecification().ReturnId(i));
|
||||
}
|
||||
auto dependent_task = ExampleTask(returns, 1);
|
||||
auto dependencies = dependent_task.GetDependencies();
|
||||
|
||||
// Insert all tasks as waiting for execution.
|
||||
lineage_cache_.AddWaitingTask(task1, Lineage());
|
||||
lineage_cache_.AddWaitingTask(task2, Lineage());
|
||||
lineage_cache_.AddWaitingTask(dependent_task, Lineage());
|
||||
|
||||
// Mark one of the independent tasks and the dependent task as ready.
|
||||
lineage_cache_.AddReadyTask(task1);
|
||||
lineage_cache_.AddReadyTask(dependent_task);
|
||||
// Check that only the first independent task is flushed.
|
||||
num_tasks_flushed++;
|
||||
CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed);
|
||||
|
||||
// Flush acknowledgements. The dependent task should still not be flushed
|
||||
// since task2 is not committed yet.
|
||||
mock_gcs_.Flush();
|
||||
CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed);
|
||||
|
||||
// Mark the other independent task as ready.
|
||||
lineage_cache_.AddReadyTask(task2);
|
||||
// Check that the other independent task gets flushed.
|
||||
num_tasks_flushed++;
|
||||
CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed);
|
||||
|
||||
// Flush acknowledgements. The dependent task should now be able to be
|
||||
// written.
|
||||
mock_gcs_.Flush();
|
||||
num_tasks_flushed++;
|
||||
CheckFlush(lineage_cache_, mock_gcs_, num_tasks_flushed);
|
||||
}
|
||||
|
||||
TEST_F(LineageCacheTest, TestRemoveWaitingTask) {
|
||||
// Insert a chain of dependent tasks.
|
||||
std::vector<Task> tasks;
|
||||
auto return_values1 =
|
||||
InsertTaskChain(lineage_cache_, tasks, 3, std::vector<ObjectID>(), 1);
|
||||
|
||||
auto task_to_remove = tasks[1];
|
||||
auto task_id_to_remove = task_to_remove.GetTaskSpecification().TaskId();
|
||||
auto uncommitted_lineage = lineage_cache_.GetUncommittedLineage(task_id_to_remove);
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto uncommitted_lineage_message =
|
||||
uncommitted_lineage.ToFlatbuffer(fbb, task_id_to_remove);
|
||||
fbb.Finish(uncommitted_lineage_message);
|
||||
uncommitted_lineage = Lineage(
|
||||
*flatbuffers::GetRoot<protocol::ForwardTaskRequest>(fbb.GetBufferPointer()));
|
||||
|
||||
const Task &task = uncommitted_lineage.GetEntry(task_id_to_remove)->TaskData();
|
||||
RAY_LOG(INFO) << "removing task " << task.GetTaskSpecification().TaskId()
|
||||
<< "with numforwards="
|
||||
<< task.GetTaskExecutionSpecReadonly().NumForwards();
|
||||
ASSERT_EQ(task.GetTaskExecutionSpecReadonly().NumForwards(), 1);
|
||||
|
||||
lineage_cache_.RemoveWaitingTask(task_id_to_remove);
|
||||
lineage_cache_.AddWaitingTask(task_to_remove, uncommitted_lineage);
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
+32
-21
@@ -5,34 +5,45 @@
|
||||
|
||||
#ifndef RAYLET_TEST
|
||||
int main(int argc, char *argv[]) {
|
||||
RAY_CHECK(argc == 2);
|
||||
RAY_CHECK(argc == 5);
|
||||
|
||||
// start store
|
||||
std::string executable_str = std::string(argv[0]);
|
||||
std::string exec_dir = executable_str.substr(0, executable_str.find_last_of("/"));
|
||||
std::string plasma_dir = exec_dir + "./../plasma";
|
||||
std::string plasma_command =
|
||||
plasma_dir +
|
||||
"/plasma_store -m 1000000000 -s /tmp/store 1> /dev/null 2> /dev/null &";
|
||||
RAY_LOG(INFO) << plasma_command;
|
||||
int s = system(plasma_command.c_str());
|
||||
RAY_CHECK(s == 0);
|
||||
const std::string raylet_socket_name = std::string(argv[1]);
|
||||
const std::string store_socket_name = std::string(argv[2]);
|
||||
const std::string redis_address = std::string(argv[3]);
|
||||
int redis_port = std::stoi(argv[4]);
|
||||
|
||||
// configure
|
||||
// Configuration for the node manager.
|
||||
ray::raylet::NodeManagerConfig node_manager_config;
|
||||
std::unordered_map<std::string, double> static_resource_conf;
|
||||
static_resource_conf = {{"CPU", 1}, {"GPU", 1}};
|
||||
ray::ResourceSet resource_config(std::move(static_resource_conf));
|
||||
ray::ObjectManagerConfig om_config;
|
||||
om_config.store_socket_name = "/tmp/store";
|
||||
node_manager_config.resource_config =
|
||||
ray::raylet::ResourceSet(std::move(static_resource_conf));
|
||||
node_manager_config.num_initial_workers = 0;
|
||||
// Use a default worker that can execute empty tasks with dependencies.
|
||||
node_manager_config.worker_command.push_back("python");
|
||||
node_manager_config.worker_command.push_back(
|
||||
"../../../src/ray/python/default_worker.py");
|
||||
node_manager_config.worker_command.push_back(raylet_socket_name.c_str());
|
||||
node_manager_config.worker_command.push_back(store_socket_name.c_str());
|
||||
// TODO(swang): Set this from a global config.
|
||||
node_manager_config.heartbeat_period_ms = 100;
|
||||
|
||||
// Configuration for the object manager.
|
||||
ray::ObjectManagerConfig object_manager_config;
|
||||
object_manager_config.store_socket_name = store_socket_name;
|
||||
|
||||
// initialize mock gcs & object directory
|
||||
std::shared_ptr<ray::GcsClient> mock_gcs_client =
|
||||
std::shared_ptr<ray::GcsClient>(new ray::GcsClient());
|
||||
auto gcs_client = std::make_shared<ray::gcs::AsyncGcsClient>();
|
||||
RAY_LOG(INFO) << "Initializing GCS client "
|
||||
<< gcs_client->client_table().GetLocalClientId();
|
||||
|
||||
// Initialize the node manager.
|
||||
boost::asio::io_service io_service;
|
||||
ray::Raylet server(io_service, std::string(argv[1]), resource_config, om_config,
|
||||
mock_gcs_client);
|
||||
io_service.run();
|
||||
boost::asio::io_service main_service;
|
||||
std::unique_ptr<boost::asio::io_service> object_manager_service;
|
||||
object_manager_service.reset(new boost::asio::io_service());
|
||||
ray::raylet::Raylet server(main_service, std::move(object_manager_service),
|
||||
raylet_socket_name, redis_address, redis_port,
|
||||
node_manager_config, object_manager_config, gcs_client);
|
||||
main_service.run();
|
||||
}
|
||||
#endif
|
||||
|
||||
+289
-48
@@ -5,20 +5,153 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
NodeManager::NodeManager(const std::string &socket_name,
|
||||
const ResourceSet &resource_config,
|
||||
ObjectManager &object_manager)
|
||||
: local_resources_(resource_config),
|
||||
worker_pool_(WorkerPool(0)),
|
||||
namespace raylet {
|
||||
|
||||
NodeManager::NodeManager(boost::asio::io_service &io_service,
|
||||
const NodeManagerConfig &config, ObjectManager &object_manager,
|
||||
std::shared_ptr<gcs::AsyncGcsClient> gcs_client)
|
||||
: io_service_(io_service),
|
||||
heartbeat_timer_(io_service),
|
||||
heartbeat_period_ms_(config.heartbeat_period_ms),
|
||||
local_resources_(config.resource_config),
|
||||
worker_pool_(config.num_initial_workers, config.worker_command),
|
||||
local_queues_(SchedulingQueue()),
|
||||
scheduling_policy_(local_queues_),
|
||||
reconstruction_policy_([this](const TaskID &task_id) { ResubmitTask(task_id); }),
|
||||
task_dependency_manager_(
|
||||
object_manager,
|
||||
// reconstruction_policy_,
|
||||
[this](const TaskID &task_id) { HandleWaitingTaskReady(task_id); }) {
|
||||
//// TODO(atumanov): need to add the self-knowledge of ClientID, using nill().
|
||||
// cluster_resource_map_[ClientID::nil()] = local_resources_;
|
||||
[this](const TaskID &task_id) { HandleWaitingTaskReady(task_id); }),
|
||||
lineage_cache_(gcs_client->raylet_task_table()),
|
||||
gcs_client_(gcs_client),
|
||||
remote_clients_(),
|
||||
remote_server_connections_(),
|
||||
object_manager_(object_manager) {
|
||||
RAY_CHECK(heartbeat_period_ms_ > 0);
|
||||
// Initialize the resource map with own cluster resource configuration.
|
||||
ClientID local_client_id = gcs_client_->client_table().GetLocalClientId();
|
||||
cluster_resource_map_.emplace(local_client_id,
|
||||
SchedulingResources(config.resource_config));
|
||||
}
|
||||
|
||||
void NodeManager::Heartbeat() {
|
||||
RAY_LOG(DEBUG) << "[Heartbeat] sending heartbeat.";
|
||||
auto &heartbeat_table = gcs_client_->heartbeat_table();
|
||||
auto heartbeat_data = std::make_shared<HeartbeatTableDataT>();
|
||||
auto client_id = gcs_client_->client_table().GetLocalClientId();
|
||||
const SchedulingResources &local_resources = cluster_resource_map_[client_id];
|
||||
heartbeat_data->client_id = client_id.hex();
|
||||
// TODO(atumanov): modify the heartbeat table protocol to use the ResourceSet directly.
|
||||
// TODO(atumanov): implement a ResourceSet const_iterator.
|
||||
for (const auto &resource_pair :
|
||||
local_resources.GetAvailableResources().GetResourceMap()) {
|
||||
heartbeat_data->resources_available_label.push_back(resource_pair.first);
|
||||
heartbeat_data->resources_available_capacity.push_back(resource_pair.second);
|
||||
}
|
||||
for (const auto &resource_pair : local_resources.GetTotalResources().GetResourceMap()) {
|
||||
heartbeat_data->resources_total_label.push_back(resource_pair.first);
|
||||
heartbeat_data->resources_total_capacity.push_back(resource_pair.second);
|
||||
}
|
||||
|
||||
ray::Status status = heartbeat_table.Add(
|
||||
UniqueID::nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data,
|
||||
[](ray::gcs::AsyncGcsClient *client, const ClientID &id,
|
||||
std::shared_ptr<HeartbeatTableDataT> data) {
|
||||
RAY_LOG(DEBUG) << "[HEARTBEAT] heartbeat sent callback";
|
||||
});
|
||||
|
||||
if (!status.ok()) {
|
||||
RAY_LOG(INFO) << "heartbeat failed: string " << status.ToString() << status.message();
|
||||
RAY_LOG(INFO) << "is redis error: " << status.IsRedisError();
|
||||
}
|
||||
RAY_CHECK_OK(status);
|
||||
|
||||
// Reset the timer.
|
||||
auto heartbeat_period = boost::posix_time::milliseconds(heartbeat_period_ms_);
|
||||
heartbeat_timer_.expires_from_now(heartbeat_period);
|
||||
heartbeat_timer_.async_wait([this](const boost::system::error_code &error) {
|
||||
RAY_CHECK(!error);
|
||||
Heartbeat();
|
||||
});
|
||||
}
|
||||
|
||||
void NodeManager::ClientAdded(gcs::AsyncGcsClient *client, const UniqueID &id,
|
||||
const ClientTableDataT &client_data) {
|
||||
ClientID client_id = ClientID::from_binary(client_data.client_id);
|
||||
RAY_LOG(DEBUG) << "[ClientAdded] received callback from client id " << client_id.hex();
|
||||
if (client_id == gcs_client_->client_table().GetLocalClientId()) {
|
||||
// We got a notification for ourselves, so we are connected to the GCS now.
|
||||
// Save this NodeManager's resource information in the cluster resource map.
|
||||
cluster_resource_map_[client_id] = local_resources_;
|
||||
// Start sending heartbeats to the GCS.
|
||||
Heartbeat();
|
||||
// Subscribe to heartbeats.
|
||||
const auto heartbeat_added = [this](gcs::AsyncGcsClient *client, const ClientID &id,
|
||||
const HeartbeatTableDataT &heartbeat_data) {
|
||||
this->HeartbeatAdded(client, id, heartbeat_data);
|
||||
};
|
||||
ray::Status status = client->heartbeat_table().Subscribe(
|
||||
UniqueID::nil(), UniqueID::nil(), heartbeat_added,
|
||||
[this](gcs::AsyncGcsClient *client) {
|
||||
RAY_LOG(DEBUG) << "heartbeat table subscription done callback called.";
|
||||
});
|
||||
RAY_CHECK_OK(status);
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO(atumanov): make remote client lookup O(1)
|
||||
if (std::find(remote_clients_.begin(), remote_clients_.end(), client_id) ==
|
||||
remote_clients_.end()) {
|
||||
RAY_LOG(DEBUG) << "a new client: " << client_id.hex();
|
||||
remote_clients_.push_back(client_id);
|
||||
} else {
|
||||
// NodeManager connection to this client was already established.
|
||||
RAY_LOG(DEBUG) << "received a new client connection that already exists: "
|
||||
<< client_id.hex();
|
||||
return;
|
||||
}
|
||||
|
||||
ResourceSet resources_total(client_data.resources_total_label,
|
||||
client_data.resources_total_capacity);
|
||||
this->cluster_resource_map_.emplace(client_id, SchedulingResources(resources_total));
|
||||
|
||||
// Establish a new NodeManager connection to this GCS client.
|
||||
auto client_info = gcs_client_->client_table().GetClient(client_id);
|
||||
RAY_LOG(DEBUG) << "[ClientAdded] CONNECTING TO: "
|
||||
<< " " << client_info.node_manager_address << " "
|
||||
<< client_info.node_manager_port;
|
||||
|
||||
boost::asio::ip::tcp::socket socket(io_service_);
|
||||
RAY_CHECK_OK(TcpConnect(socket, client_info.node_manager_address,
|
||||
client_info.node_manager_port));
|
||||
auto server_conn = TcpServerConnection(std::move(socket));
|
||||
remote_server_connections_.emplace(client_id, std::move(server_conn));
|
||||
}
|
||||
|
||||
void NodeManager::HeartbeatAdded(gcs::AsyncGcsClient *client, const ClientID &client_id,
|
||||
const HeartbeatTableDataT &heartbeat_data) {
|
||||
RAY_LOG(DEBUG) << "[HeartbeatAdded]: received heartbeat from client id "
|
||||
<< client_id.hex();
|
||||
if (client_id == gcs_client_->client_table().GetLocalClientId()) {
|
||||
// Skip heartbeats from self.
|
||||
return;
|
||||
}
|
||||
// Locate the client id in remote client table and update available resources based on
|
||||
// the received heartbeat information.
|
||||
if (this->cluster_resource_map_.count(client_id) == 0) {
|
||||
// Haven't received the client registration for this client yet, skip this heartbeat.
|
||||
RAY_LOG(INFO) << "[HeartbeatAdded]: received heartbeat from unknown client id "
|
||||
<< client_id.hex();
|
||||
return;
|
||||
}
|
||||
SchedulingResources &resources = this->cluster_resource_map_[client_id];
|
||||
ResourceSet heartbeat_resource_available(heartbeat_data.resources_available_label,
|
||||
heartbeat_data.resources_available_capacity);
|
||||
resources.SetAvailableResources(
|
||||
ResourceSet(heartbeat_data.resources_available_label,
|
||||
heartbeat_data.resources_available_capacity));
|
||||
RAY_CHECK(this->cluster_resource_map_[client_id].GetAvailableResources() ==
|
||||
heartbeat_resource_available);
|
||||
}
|
||||
|
||||
void NodeManager::ProcessNewClient(std::shared_ptr<LocalClientConnection> client) {
|
||||
@@ -40,18 +173,6 @@ void NodeManager::ProcessClientMessage(std::shared_ptr<LocalClientConnection> cl
|
||||
// Register the new worker.
|
||||
worker_pool_.RegisterWorker(std::move(worker));
|
||||
}
|
||||
|
||||
// Build the reply to the worker's registration request. TODO(swang): This
|
||||
// is legacy code and should be removed once actor creation tasks are
|
||||
// implemented.
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto reply =
|
||||
protocol::CreateRegisterClientReply(fbb, fbb.CreateVector(std::vector<int>()));
|
||||
fbb.Finish(reply);
|
||||
// Reply to the worker's registration request, then listen for more
|
||||
// messages.
|
||||
client->WriteMessage(protocol::MessageType_RegisterClientReply, fbb.GetSize(),
|
||||
fbb.GetBufferPointer());
|
||||
} break;
|
||||
case protocol::MessageType_GetTask: {
|
||||
const std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
|
||||
@@ -74,8 +195,14 @@ void NodeManager::ProcessClientMessage(std::shared_ptr<LocalClientConnection> cl
|
||||
// Remove the dead worker from the pool and stop listening for messages.
|
||||
const std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
|
||||
if (worker) {
|
||||
if (!worker->GetAssignedTaskId().is_nil()) {
|
||||
// TODO(swang): Clean up any tasks that were assigned to the worker.
|
||||
// Release any resources that may be held by this worker.
|
||||
FinishTask(worker->GetAssignedTaskId());
|
||||
}
|
||||
worker_pool_.DisconnectWorker(worker);
|
||||
}
|
||||
return;
|
||||
} break;
|
||||
case protocol::MessageType_SubmitTask: {
|
||||
// Read the task submitted by the client.
|
||||
@@ -84,14 +211,49 @@ void NodeManager::ProcessClientMessage(std::shared_ptr<LocalClientConnection> cl
|
||||
from_flatbuf(*message->execution_dependencies()));
|
||||
TaskSpecification task_spec(*message->task_spec());
|
||||
Task task(task_execution_spec, task_spec);
|
||||
// Submit the task to the local scheduler.
|
||||
SubmitTask(task);
|
||||
// Listen for more messages.
|
||||
client->ProcessMessages();
|
||||
// Submit the task to the local scheduler. Since the task was submitted
|
||||
// locally, there is no uncommitted lineage.
|
||||
SubmitTask(task, Lineage());
|
||||
} break;
|
||||
case protocol::MessageType_ReconstructObject: {
|
||||
// TODO(hme): handle multiple object ids.
|
||||
auto message = flatbuffers::GetRoot<protocol::ReconstructObject>(message_data);
|
||||
ObjectID object_id = from_flatbuf(*message->object_id());
|
||||
RAY_LOG(DEBUG) << "reconstructing object " << object_id.hex();
|
||||
RAY_CHECK_OK(object_manager_.Pull(object_id));
|
||||
} break;
|
||||
|
||||
default:
|
||||
RAY_LOG(FATAL) << "Received unexpected message type " << message_type;
|
||||
}
|
||||
|
||||
// Listen for more messages.
|
||||
client->ProcessMessages();
|
||||
}
|
||||
|
||||
void NodeManager::ProcessNewNodeManager(
|
||||
std::shared_ptr<TcpClientConnection> node_manager_client) {
|
||||
node_manager_client->ProcessMessages();
|
||||
}
|
||||
|
||||
void NodeManager::ProcessNodeManagerMessage(
|
||||
std::shared_ptr<TcpClientConnection> node_manager_client, int64_t message_type,
|
||||
const uint8_t *message_data) {
|
||||
switch (message_type) {
|
||||
case protocol::MessageType_ForwardTaskRequest: {
|
||||
auto message = flatbuffers::GetRoot<protocol::ForwardTaskRequest>(message_data);
|
||||
TaskID task_id = from_flatbuf(*message->task_id());
|
||||
|
||||
Lineage uncommitted_lineage(*message);
|
||||
const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData();
|
||||
RAY_LOG(DEBUG) << "got task " << task.GetTaskSpecification().TaskId()
|
||||
<< " spillback=" << task.GetTaskExecutionSpecReadonly().NumForwards();
|
||||
SubmitTask(task, uncommitted_lineage);
|
||||
} break;
|
||||
default:
|
||||
RAY_LOG(FATAL) << "Received unexpected message type " << message_type;
|
||||
}
|
||||
node_manager_client->ProcessMessages();
|
||||
}
|
||||
|
||||
void NodeManager::HandleWaitingTaskReady(const TaskID &task_id) {
|
||||
@@ -102,29 +264,45 @@ void NodeManager::HandleWaitingTaskReady(const TaskID &task_id) {
|
||||
}
|
||||
|
||||
void NodeManager::ScheduleTasks() {
|
||||
// Ask policy for scheduling decision.
|
||||
// TODO(alexey): Give the policy all cluster resources instead of just the
|
||||
// local one.
|
||||
std::unordered_map<ClientID, SchedulingResources, UniqueIDHasher> cluster_resource_map;
|
||||
cluster_resource_map[ClientID::nil()] = local_resources_;
|
||||
const auto &policy_decision = scheduling_policy_.Schedule(cluster_resource_map);
|
||||
auto policy_decision = scheduling_policy_.Schedule(
|
||||
cluster_resource_map_, gcs_client_->client_table().GetLocalClientId(),
|
||||
remote_clients_);
|
||||
RAY_LOG(DEBUG) << "[NM ScheduleTasks] policy decision:";
|
||||
for (const auto &pair : policy_decision) {
|
||||
TaskID task_id = pair.first;
|
||||
ClientID client_id = pair.second;
|
||||
RAY_LOG(DEBUG) << task_id.hex() << " --> " << client_id.hex();
|
||||
}
|
||||
|
||||
// Extract decision for this local scheduler.
|
||||
// TODO(alexey): Check for this node's own client ID, not for nil.
|
||||
std::unordered_set<TaskID, UniqueIDHasher> task_ids;
|
||||
for (auto &task_schedule : policy_decision) {
|
||||
if (task_schedule.second.is_nil()) {
|
||||
task_ids.insert(task_schedule.first);
|
||||
std::unordered_set<TaskID, UniqueIDHasher> local_task_ids;
|
||||
// Iterate over (taskid, clientid) pairs, extract tasks to run on the local client.
|
||||
for (const auto &task_schedule : policy_decision) {
|
||||
TaskID task_id = task_schedule.first;
|
||||
ClientID client_id = task_schedule.second;
|
||||
if (client_id == gcs_client_->client_table().GetLocalClientId()) {
|
||||
local_task_ids.insert(task_id);
|
||||
} else {
|
||||
auto tasks = local_queues_.RemoveTasks({task_id});
|
||||
RAY_CHECK(1 == tasks.size());
|
||||
Task &task = tasks.front();
|
||||
// TODO(swang): Handle forward task failure.
|
||||
// TODO(swang): Unsubscribe this task in the task dependency manager.
|
||||
RAY_CHECK_OK(ForwardTask(task, client_id));
|
||||
}
|
||||
}
|
||||
|
||||
// Assign the tasks to workers.
|
||||
std::vector<Task> tasks = local_queues_.RemoveTasks(task_ids);
|
||||
std::vector<Task> tasks = local_queues_.RemoveTasks(local_task_ids);
|
||||
for (auto &task : tasks) {
|
||||
AssignTask(task);
|
||||
}
|
||||
}
|
||||
|
||||
void NodeManager::SubmitTask(const Task &task) {
|
||||
void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineage) {
|
||||
// Add the task and its uncommitted lineage to the lineage cache.
|
||||
lineage_cache_.AddWaitingTask(task, uncommitted_lineage);
|
||||
// Queue the task according to the availability of its arguments.
|
||||
if (task_dependency_manager_.TaskReady(task)) {
|
||||
local_queues_.QueueReadyTasks(std::vector<Task>({task}));
|
||||
ScheduleTasks();
|
||||
@@ -135,41 +313,104 @@ void NodeManager::SubmitTask(const Task &task) {
|
||||
}
|
||||
|
||||
void NodeManager::AssignTask(const Task &task) {
|
||||
// Resource accounting: acquire resources for the scheduled task.
|
||||
const ClientID &my_client_id = gcs_client_->client_table().GetLocalClientId();
|
||||
RAY_CHECK(this->cluster_resource_map_[my_client_id].Acquire(
|
||||
task.GetTaskSpecification().GetRequiredResources()));
|
||||
|
||||
if (worker_pool_.PoolSize() == 0) {
|
||||
// Start a new worker.
|
||||
worker_pool_.StartWorker();
|
||||
// Queue this task for future assignment. The task will be assigned to a
|
||||
// worker once one becomes available.
|
||||
local_queues_.QueueScheduledTasks(std::vector<Task>({task}));
|
||||
// TODO(swang): Acquire resources here or when a worker becomes available?
|
||||
return;
|
||||
}
|
||||
|
||||
const TaskSpecification &spec = task.GetTaskSpecification();
|
||||
std::shared_ptr<Worker> worker = worker_pool_.PopWorker();
|
||||
RAY_LOG(DEBUG) << "Assigning task to worker with pid " << worker->Pid();
|
||||
|
||||
// TODO(swang): Acquire resources for the task.
|
||||
// local_resources_.Acquire(task.GetTaskSpecification().GetRequiredResources());
|
||||
worker->AssignTaskId(spec.TaskId());
|
||||
local_queues_.QueueRunningTasks(std::vector<Task>({task}));
|
||||
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
const TaskSpecification &spec = task.GetTaskSpecification();
|
||||
auto message = protocol::CreateGetTaskReply(fbb, spec.ToFlatbuffer(fbb),
|
||||
fbb.CreateVector(std::vector<int>()));
|
||||
fbb.Finish(message);
|
||||
worker->Connection()->WriteMessage(protocol::MessageType_ExecuteTask, fbb.GetSize(),
|
||||
fbb.GetBufferPointer());
|
||||
worker->AssignTaskId(spec.TaskId());
|
||||
local_queues_.QueueRunningTasks(std::vector<Task>({task}));
|
||||
auto status = worker->Connection()->WriteMessage(protocol::MessageType_ExecuteTask,
|
||||
fbb.GetSize(), fbb.GetBufferPointer());
|
||||
if (status.ok()) {
|
||||
// We started running the task, so the task is ready to write to GCS.
|
||||
lineage_cache_.AddReadyTask(task);
|
||||
} else {
|
||||
// We failed to send the task to the worker, so disconnect the worker. The
|
||||
// task will get queued again during cleanup.
|
||||
ProcessClientMessage(worker->Connection(), protocol::MessageType_DisconnectClient,
|
||||
NULL);
|
||||
}
|
||||
}
|
||||
|
||||
void NodeManager::FinishTask(const TaskID &task_id) {
|
||||
RAY_LOG(DEBUG) << "Finished task " << task_id.hex();
|
||||
local_queues_.RemoveTasks({task_id});
|
||||
// TODO(swang): Release resources that were held for the task.
|
||||
auto tasks = local_queues_.RemoveTasks({task_id});
|
||||
RAY_CHECK(tasks.size() == 1);
|
||||
auto task = *tasks.begin();
|
||||
|
||||
// Resource accounting: release task's resources.
|
||||
RAY_CHECK(
|
||||
this->cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Release(
|
||||
task.GetTaskSpecification().GetRequiredResources()));
|
||||
}
|
||||
|
||||
void NodeManager::ResubmitTask(const TaskID &task_id) {
|
||||
throw std::runtime_error("Method not implemented");
|
||||
}
|
||||
|
||||
ray::Status NodeManager::ForwardTask(Task &task, const ClientID &node_id) {
|
||||
auto task_id = task.GetTaskSpecification().TaskId();
|
||||
|
||||
// Get and serialize the task's uncommitted lineage.
|
||||
auto uncommitted_lineage = lineage_cache_.GetUncommittedLineage(task_id);
|
||||
Task &lineage_cache_entry_task =
|
||||
uncommitted_lineage.GetEntryMutable(task_id)->TaskDataMutable();
|
||||
// Increment forward count for the forwarded task.
|
||||
lineage_cache_entry_task.GetTaskExecutionSpec().IncrementNumForwards();
|
||||
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto request = uncommitted_lineage.ToFlatbuffer(fbb, task_id);
|
||||
fbb.Finish(request);
|
||||
|
||||
RAY_LOG(DEBUG) << "Forwarding task " << task_id.hex() << " to " << node_id.hex()
|
||||
<< " spillback="
|
||||
<< lineage_cache_entry_task.GetTaskExecutionSpec().NumForwards();
|
||||
|
||||
auto client_info = gcs_client_->client_table().GetClient(node_id);
|
||||
|
||||
// Lookup remote server connection for this node_id and use it to send the request.
|
||||
if (remote_server_connections_.count(node_id) == 0) {
|
||||
// TODO(atumanov): caller must handle failure to ensure tasks are not lost.
|
||||
RAY_LOG(INFO) << "No NodeManager connection found for GCS client id "
|
||||
<< node_id.hex();
|
||||
return ray::Status::IOError("NodeManager connection not found");
|
||||
}
|
||||
|
||||
auto &server_conn = remote_server_connections_.at(node_id);
|
||||
auto status = server_conn.WriteMessage(protocol::MessageType_ForwardTaskRequest,
|
||||
fbb.GetSize(), fbb.GetBufferPointer());
|
||||
if (status.ok()) {
|
||||
// If we were able to forward the task, remove the forwarded task from the
|
||||
// lineage cache since the receiving node is now responsible for writing
|
||||
// the task to the GCS.
|
||||
lineage_cache_.RemoveWaitingTask(task_id);
|
||||
} else {
|
||||
// TODO(atumanov): caller must handle ForwardTask failure to ensure tasks are not
|
||||
// lost.
|
||||
RAY_LOG(FATAL) << "[NodeManager][ForwardTask] failed to forward task "
|
||||
<< task_id.hex() << " to node " << node_id.hex();
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -2,11 +2,13 @@
|
||||
#define RAY_RAYLET_NODE_MANAGER_H
|
||||
|
||||
// clang-format off
|
||||
#include "ray/raylet/task.h"
|
||||
#include "ray/object_manager/object_manager.h"
|
||||
#include "ray/common/client_connection.h"
|
||||
#include "ray/raylet/lineage_cache.h"
|
||||
#include "ray/raylet/scheduling_policy.h"
|
||||
#include "ray/raylet/scheduling_queue.h"
|
||||
#include "ray/raylet/scheduling_resources.h"
|
||||
#include "ray/object_manager/object_manager.h"
|
||||
#include "ray/raylet/reconstruction_policy.h"
|
||||
#include "ray/raylet/task_dependency_manager.h"
|
||||
#include "ray/raylet/worker_pool.h"
|
||||
@@ -14,16 +16,24 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
class NodeManager : public ClientManager<boost::asio::local::stream_protocol> {
|
||||
namespace raylet {
|
||||
|
||||
struct NodeManagerConfig {
|
||||
ResourceSet resource_config;
|
||||
int num_initial_workers;
|
||||
std::vector<const char *> worker_command;
|
||||
uint64_t heartbeat_period_ms;
|
||||
};
|
||||
|
||||
class NodeManager {
|
||||
public:
|
||||
/// Create a node manager.
|
||||
///
|
||||
/// \param socket_name The pathname of the Unix domain socket to listen at
|
||||
/// for local connections.
|
||||
/// \param resource_config The initial set of node resources.
|
||||
/// \param object_manager A reference to the local object manager.
|
||||
NodeManager(const std::string &socket_name, const ResourceSet &resource_config,
|
||||
ObjectManager &object_manager);
|
||||
NodeManager(boost::asio::io_service &io_service, const NodeManagerConfig &config,
|
||||
ObjectManager &object_manager,
|
||||
std::shared_ptr<gcs::AsyncGcsClient> gcs_client);
|
||||
|
||||
/// Process a new client connection.
|
||||
void ProcessNewClient(std::shared_ptr<LocalClientConnection> client);
|
||||
@@ -38,26 +48,41 @@ class NodeManager : public ClientManager<boost::asio::local::stream_protocol> {
|
||||
void ProcessClientMessage(std::shared_ptr<LocalClientConnection> client,
|
||||
int64_t message_type, const uint8_t *message);
|
||||
|
||||
void ProcessNewNodeManager(std::shared_ptr<TcpClientConnection> node_manager_client);
|
||||
|
||||
void ProcessNodeManagerMessage(std::shared_ptr<TcpClientConnection> node_manager_client,
|
||||
int64_t message_type, const uint8_t *message);
|
||||
|
||||
void ClientAdded(gcs::AsyncGcsClient *client, const UniqueID &id,
|
||||
const ClientTableDataT &data);
|
||||
|
||||
void HeartbeatAdded(gcs::AsyncGcsClient *client, const ClientID &id,
|
||||
const HeartbeatTableDataT &data);
|
||||
|
||||
private:
|
||||
/// Submit a task to this node.
|
||||
void SubmitTask(const Task &task);
|
||||
void SubmitTask(const Task &task, const Lineage &uncommitted_lineage);
|
||||
/// Assign a task.
|
||||
void AssignTask(const Task &task);
|
||||
/// Finish a task.
|
||||
void FinishTask(const TaskID &task_id);
|
||||
/// Schedule tasks.
|
||||
void ScheduleTasks();
|
||||
/// Handle a task whose local dependencies were missing and are now
|
||||
/// available.
|
||||
/// Handle a task whose local dependencies were missing and are now available.
|
||||
void HandleWaitingTaskReady(const TaskID &task_id);
|
||||
/// Resubmit a task whose return value needs to be reconstructed.
|
||||
void ResubmitTask(const TaskID &task_id);
|
||||
ray::Status ForwardTask(Task &task, const ClientID &node_id);
|
||||
/// Send heartbeats to the GCS.
|
||||
void Heartbeat();
|
||||
|
||||
boost::asio::io_service &io_service_;
|
||||
boost::asio::deadline_timer heartbeat_timer_;
|
||||
uint64_t heartbeat_period_ms_;
|
||||
/// The resources local to this node.
|
||||
SchedulingResources local_resources_;
|
||||
const SchedulingResources local_resources_;
|
||||
// TODO(atumanov): Add resource information from other nodes.
|
||||
// std::unordered_map<ClientID, SchedulingResources&, UniqueIDHasher>
|
||||
// cluster_resource_map_;
|
||||
std::unordered_map<ClientID, SchedulingResources, UniqueIDHasher> cluster_resource_map_;
|
||||
/// A pool of workers.
|
||||
WorkerPool worker_pool_;
|
||||
/// A set of queues to maintain tasks.
|
||||
@@ -68,8 +93,18 @@ class NodeManager : public ClientManager<boost::asio::local::stream_protocol> {
|
||||
ReconstructionPolicy reconstruction_policy_;
|
||||
/// A manager to make waiting tasks's missing object dependencies available.
|
||||
TaskDependencyManager task_dependency_manager_;
|
||||
/// The lineage cache for the GCS object and task tables.
|
||||
LineageCache lineage_cache_;
|
||||
/// A client connection to the GCS.
|
||||
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_;
|
||||
std::vector<ClientID> remote_clients_;
|
||||
std::unordered_map<ClientID, TcpServerConnection, UniqueIDHasher>
|
||||
remote_server_connections_;
|
||||
ObjectManager &object_manager_;
|
||||
};
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // end namespace ray
|
||||
|
||||
#endif // RAY_RAYLET_NODE_MANAGER_H
|
||||
|
||||
@@ -0,0 +1,235 @@
|
||||
#include <iostream>
|
||||
#include <thread>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ray/raylet/raylet.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
std::string test_executable;
|
||||
std::string store_executable;
|
||||
|
||||
// TODO(hme): Get this working once the dust settles.
|
||||
class TestObjectManagerBase : public ::testing::Test {
|
||||
public:
|
||||
TestObjectManagerBase() { RAY_LOG(INFO) << "TestObjectManagerBase: started."; }
|
||||
|
||||
std::string StartStore(const std::string &id) {
|
||||
std::string store_id = "/tmp/store";
|
||||
store_id = store_id + id;
|
||||
std::string plasma_command = store_executable + " -m 1000000000 -s " + store_id +
|
||||
" 1> /dev/null 2> /dev/null &";
|
||||
RAY_LOG(INFO) << plasma_command;
|
||||
int ec = system(plasma_command.c_str());
|
||||
if (ec != 0) {
|
||||
throw std::runtime_error("failed to start plasma store.");
|
||||
};
|
||||
return store_id;
|
||||
}
|
||||
|
||||
NodeManagerConfig GetNodeManagerConfig(std::string raylet_socket_name,
|
||||
std::string store_socket_name) {
|
||||
// Configuration for the node manager.
|
||||
ray::raylet::NodeManagerConfig node_manager_config;
|
||||
std::unordered_map<std::string, double> static_resource_conf;
|
||||
static_resource_conf = {{"CPU", 1}, {"GPU", 1}};
|
||||
node_manager_config.resource_config =
|
||||
ray::raylet::ResourceSet(std::move(static_resource_conf));
|
||||
node_manager_config.num_initial_workers = 0;
|
||||
// Use a default worker that can execute empty tasks with dependencies.
|
||||
node_manager_config.worker_command.push_back("python");
|
||||
node_manager_config.worker_command.push_back(
|
||||
"../../../src/ray/python/default_worker.py");
|
||||
node_manager_config.worker_command.push_back(raylet_socket_name.c_str());
|
||||
node_manager_config.worker_command.push_back(store_socket_name.c_str());
|
||||
return node_manager_config;
|
||||
};
|
||||
|
||||
void SetUp() {
|
||||
object_manager_service_1.reset(new boost::asio::io_service());
|
||||
object_manager_service_2.reset(new boost::asio::io_service());
|
||||
|
||||
// start store
|
||||
std::string store_sock_1 = StartStore("1");
|
||||
std::string store_sock_2 = StartStore("2");
|
||||
|
||||
// start first server
|
||||
gcs_client_1 = std::shared_ptr<gcs::AsyncGcsClient>(new gcs::AsyncGcsClient());
|
||||
ObjectManagerConfig om_config_1;
|
||||
om_config_1.store_socket_name = store_sock_1;
|
||||
server1.reset(new ray::raylet::Raylet(
|
||||
main_service, std::move(object_manager_service_1), "raylet_1", "127.0.0.1", 6379,
|
||||
GetNodeManagerConfig("raylet_1", store_sock_1), om_config_1, gcs_client_1));
|
||||
|
||||
// start second server
|
||||
gcs_client_2 = std::shared_ptr<gcs::AsyncGcsClient>(new gcs::AsyncGcsClient());
|
||||
ObjectManagerConfig om_config_2;
|
||||
om_config_2.store_socket_name = store_sock_2;
|
||||
server2.reset(new ray::raylet::Raylet(
|
||||
main_service, std::move(object_manager_service_2), "raylet_2", "127.0.0.1", 6379,
|
||||
GetNodeManagerConfig("raylet_2", store_sock_2), om_config_2, gcs_client_2));
|
||||
|
||||
// connect to stores.
|
||||
ARROW_CHECK_OK(client1.Connect(store_sock_1, "", PLASMA_DEFAULT_RELEASE_DELAY));
|
||||
ARROW_CHECK_OK(client2.Connect(store_sock_2, "", PLASMA_DEFAULT_RELEASE_DELAY));
|
||||
}
|
||||
|
||||
void TearDown() {
|
||||
arrow::Status client1_status = client1.Disconnect();
|
||||
arrow::Status client2_status = client2.Disconnect();
|
||||
ASSERT_TRUE(client1_status.ok() && client2_status.ok());
|
||||
|
||||
this->server1.reset();
|
||||
this->server2.reset();
|
||||
|
||||
int s = system("killall plasma_store &");
|
||||
ASSERT_TRUE(!s);
|
||||
|
||||
std::string cmd_str = test_executable.substr(0, test_executable.find_last_of("/"));
|
||||
s = system(("rm " + cmd_str + "/raylet_1").c_str());
|
||||
ASSERT_TRUE(!s);
|
||||
s = system(("rm " + cmd_str + "/raylet_2").c_str());
|
||||
ASSERT_TRUE(!s);
|
||||
}
|
||||
|
||||
ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size) {
|
||||
ObjectID object_id = ObjectID::from_random();
|
||||
RAY_LOG(DEBUG) << "ObjectID Created: " << object_id;
|
||||
uint8_t metadata[] = {5};
|
||||
int64_t metadata_size = sizeof(metadata);
|
||||
std::shared_ptr<Buffer> data;
|
||||
ARROW_CHECK_OK(client.Create(object_id.to_plasma_id(), data_size, metadata,
|
||||
metadata_size, &data));
|
||||
ARROW_CHECK_OK(client.Seal(object_id.to_plasma_id()));
|
||||
return object_id;
|
||||
}
|
||||
|
||||
protected:
|
||||
std::thread p;
|
||||
boost::asio::io_service main_service;
|
||||
std::unique_ptr<boost::asio::io_service> object_manager_service_1;
|
||||
std::unique_ptr<boost::asio::io_service> object_manager_service_2;
|
||||
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_1;
|
||||
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_2;
|
||||
std::unique_ptr<ray::raylet::Raylet> server1;
|
||||
std::unique_ptr<ray::raylet::Raylet> server2;
|
||||
|
||||
plasma::PlasmaClient client1;
|
||||
plasma::PlasmaClient client2;
|
||||
std::vector<ObjectID> v1;
|
||||
std::vector<ObjectID> v2;
|
||||
};
|
||||
|
||||
class TestObjectManagerIntegration : public TestObjectManagerBase {
|
||||
public:
|
||||
uint num_expected_objects;
|
||||
|
||||
int num_connected_clients = 0;
|
||||
|
||||
ClientID client_id_1;
|
||||
ClientID client_id_2;
|
||||
|
||||
void WaitConnections() {
|
||||
client_id_1 = gcs_client_1->client_table().GetLocalClientId();
|
||||
client_id_2 = gcs_client_2->client_table().GetLocalClientId();
|
||||
gcs_client_1->client_table().RegisterClientAddedCallback([this](
|
||||
gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) {
|
||||
ClientID parsed_id = ClientID::from_binary(data.client_id);
|
||||
if (parsed_id == client_id_1 || parsed_id == client_id_2) {
|
||||
num_connected_clients += 1;
|
||||
}
|
||||
if (num_connected_clients == 2) {
|
||||
StartTests();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void StartTests() {
|
||||
TestConnections();
|
||||
AddTransferTestHandlers();
|
||||
TestPush(100);
|
||||
}
|
||||
|
||||
void AddTransferTestHandlers() {
|
||||
ray::Status status = ray::Status::OK();
|
||||
status =
|
||||
server1->object_manager_.SubscribeObjAdded([this](const ObjectID &object_id) {
|
||||
v1.push_back(object_id);
|
||||
if (v1.size() == num_expected_objects && v1.size() == v2.size()) {
|
||||
TestPushComplete();
|
||||
}
|
||||
});
|
||||
RAY_CHECK_OK(status);
|
||||
status =
|
||||
server2->object_manager_.SubscribeObjAdded([this](const ObjectID &object_id) {
|
||||
v2.push_back(object_id);
|
||||
if (v2.size() == num_expected_objects && v1.size() == v2.size()) {
|
||||
TestPushComplete();
|
||||
}
|
||||
});
|
||||
RAY_CHECK_OK(status);
|
||||
}
|
||||
|
||||
void TestPush(int64_t data_size) {
|
||||
ray::Status status = ray::Status::OK();
|
||||
|
||||
num_expected_objects = (uint)1;
|
||||
ObjectID oid1 = WriteDataToClient(client1, data_size);
|
||||
status = server1->object_manager_.Push(oid1, client_id_2);
|
||||
}
|
||||
|
||||
void TestPushComplete() {
|
||||
RAY_LOG(INFO) << "TestPushComplete: "
|
||||
<< " " << v1.size() << " " << v2.size();
|
||||
ASSERT_TRUE(v1.size() == v2.size());
|
||||
for (int i = -1; ++i < (int)v1.size();) {
|
||||
ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end());
|
||||
}
|
||||
v1.clear();
|
||||
v2.clear();
|
||||
main_service.stop();
|
||||
}
|
||||
|
||||
void TestConnections() {
|
||||
RAY_LOG(INFO) << "\n"
|
||||
<< "Server client ids:"
|
||||
<< "\n";
|
||||
ClientID client_id_1 = gcs_client_1->client_table().GetLocalClientId();
|
||||
ClientID client_id_2 = gcs_client_2->client_table().GetLocalClientId();
|
||||
RAY_LOG(INFO) << "Server 1: " << client_id_1;
|
||||
RAY_LOG(INFO) << "Server 2: " << client_id_2;
|
||||
|
||||
RAY_LOG(INFO) << "\n"
|
||||
<< "All connected clients:"
|
||||
<< "\n";
|
||||
const ClientTableDataT &data = gcs_client_2->client_table().GetClient(client_id_1);
|
||||
RAY_LOG(INFO) << (ClientID::from_binary(data.client_id) == ClientID::nil());
|
||||
RAY_LOG(INFO) << "ClientID=" << ClientID::from_binary(data.client_id);
|
||||
RAY_LOG(INFO) << "ClientIp=" << data.node_manager_address;
|
||||
RAY_LOG(INFO) << "ClientPort=" << data.node_manager_port;
|
||||
const ClientTableDataT &data2 = gcs_client_1->client_table().GetClient(client_id_2);
|
||||
RAY_LOG(INFO) << "ClientID=" << ClientID::from_binary(data2.client_id);
|
||||
RAY_LOG(INFO) << "ClientIp=" << data2.node_manager_address;
|
||||
RAY_LOG(INFO) << "ClientPort=" << data2.node_manager_port;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(TestObjectManagerIntegration, StartTestObjectManagerPush) {
|
||||
auto AsyncStartTests = main_service.wrap([this]() { WaitConnections(); });
|
||||
AsyncStartTests();
|
||||
main_service.run();
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
ray::raylet::test_executable = std::string(argv[0]);
|
||||
ray::raylet::store_executable = std::string(argv[1]);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
+120
-37
@@ -1,56 +1,129 @@
|
||||
#include "raylet.h"
|
||||
|
||||
#include <boost/asio.hpp>
|
||||
#include <boost/bind.hpp>
|
||||
#include <boost/date_time/posix_time/posix_time.hpp>
|
||||
#include <iostream>
|
||||
|
||||
#include "ray/status.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
Raylet::Raylet(boost::asio::io_service &io_service, const std::string &socket_name,
|
||||
const ResourceSet &resource_config,
|
||||
namespace raylet {
|
||||
|
||||
Raylet::Raylet(boost::asio::io_service &main_service,
|
||||
std::unique_ptr<boost::asio::io_service> object_manager_service,
|
||||
const std::string &socket_name, const std::string &redis_address,
|
||||
int redis_port, const NodeManagerConfig &node_manager_config,
|
||||
const ObjectManagerConfig &object_manager_config,
|
||||
std::shared_ptr<ray::GcsClient> gcs_client)
|
||||
: acceptor_(io_service, boost::asio::local::stream_protocol::endpoint(socket_name)),
|
||||
socket_(io_service),
|
||||
tcp_acceptor_(io_service,
|
||||
boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), 0)),
|
||||
tcp_socket_(io_service),
|
||||
object_manager_(io_service, object_manager_config, gcs_client),
|
||||
node_manager_(socket_name, resource_config, object_manager_),
|
||||
gcs_client_(gcs_client) {
|
||||
ClientID client_id = RegisterGcs();
|
||||
object_manager_.SetClientID(client_id);
|
||||
std::shared_ptr<gcs::AsyncGcsClient> gcs_client)
|
||||
: acceptor_(main_service, boost::asio::local::stream_protocol::endpoint(socket_name)),
|
||||
socket_(main_service),
|
||||
object_manager_acceptor_(
|
||||
main_service, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), 0)),
|
||||
object_manager_socket_(main_service),
|
||||
node_manager_acceptor_(
|
||||
main_service, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), 0)),
|
||||
node_manager_socket_(main_service),
|
||||
gcs_client_(gcs_client),
|
||||
object_manager_(main_service, std::move(object_manager_service),
|
||||
object_manager_config, gcs_client),
|
||||
node_manager_(main_service, node_manager_config, object_manager_, gcs_client_) {
|
||||
// Start listening for clients.
|
||||
DoAccept();
|
||||
DoAcceptTcp();
|
||||
DoAcceptObjectManager();
|
||||
DoAcceptNodeManager();
|
||||
|
||||
RAY_CHECK_OK(RegisterGcs(redis_address, redis_port, main_service, node_manager_config));
|
||||
|
||||
RAY_CHECK_OK(RegisterPeriodicTimer(main_service));
|
||||
}
|
||||
|
||||
Raylet::~Raylet() { RAY_CHECK_OK(object_manager_.Terminate()); }
|
||||
|
||||
ClientID Raylet::RegisterGcs() {
|
||||
boost::asio::ip::tcp::endpoint endpoint = tcp_acceptor_.local_endpoint();
|
||||
std::string ip = endpoint.address().to_string();
|
||||
uint16_t port = endpoint.port();
|
||||
ClientID client_id = gcs_client_->Register(ip, port);
|
||||
return client_id;
|
||||
Raylet::~Raylet() {
|
||||
RAY_CHECK_OK(gcs_client_->client_table().Disconnect());
|
||||
RAY_CHECK_OK(object_manager_.Terminate());
|
||||
}
|
||||
|
||||
void Raylet::DoAcceptTcp() {
|
||||
TCPClientConnection::pointer new_connection =
|
||||
TCPClientConnection::Create(acceptor_.get_io_service());
|
||||
tcp_acceptor_.async_accept(new_connection->GetSocket(),
|
||||
boost::bind(&Raylet::HandleAcceptTcp, this, new_connection,
|
||||
boost::asio::placeholders::error));
|
||||
ray::Status Raylet::RegisterPeriodicTimer(boost::asio::io_service &io_service) {
|
||||
boost::posix_time::milliseconds timer_period_ms(100);
|
||||
boost::asio::deadline_timer timer(io_service, timer_period_ms);
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
void Raylet::HandleAcceptTcp(TCPClientConnection::pointer new_connection,
|
||||
const boost::system::error_code &error) {
|
||||
if (!error) {
|
||||
// Pass it off to object manager for now.
|
||||
ray::Status status = object_manager_.AcceptConnection(std::move(new_connection));
|
||||
ray::Status Raylet::RegisterGcs(const std::string &redis_address, int redis_port,
|
||||
boost::asio::io_service &io_service,
|
||||
const NodeManagerConfig &node_manager_config) {
|
||||
RAY_RETURN_NOT_OK(gcs_client_->Connect(redis_address, redis_port));
|
||||
RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service));
|
||||
|
||||
ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient();
|
||||
client_info.node_manager_address =
|
||||
node_manager_acceptor_.local_endpoint().address().to_string();
|
||||
client_info.object_manager_port = object_manager_acceptor_.local_endpoint().port();
|
||||
client_info.node_manager_port = node_manager_acceptor_.local_endpoint().port();
|
||||
// Add resource information.
|
||||
for (const auto &resource_pair : node_manager_config.resource_config.GetResourceMap()) {
|
||||
client_info.resources_total_label.push_back(resource_pair.first);
|
||||
client_info.resources_total_capacity.push_back(resource_pair.second);
|
||||
}
|
||||
DoAcceptTcp();
|
||||
|
||||
RAY_LOG(DEBUG) << "NM LISTENING ON: IP " << client_info.node_manager_address << " PORT "
|
||||
<< client_info.node_manager_port;
|
||||
RAY_RETURN_NOT_OK(gcs_client_->client_table().Connect(client_info));
|
||||
|
||||
auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id,
|
||||
const ClientTableDataT &data) {
|
||||
node_manager_.ClientAdded(client, id, data);
|
||||
};
|
||||
gcs_client_->client_table().RegisterClientAddedCallback(node_manager_client_added);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void Raylet::DoAcceptNodeManager() {
|
||||
node_manager_acceptor_.async_accept(node_manager_socket_,
|
||||
boost::bind(&Raylet::HandleAcceptNodeManager, this,
|
||||
boost::asio::placeholders::error));
|
||||
}
|
||||
|
||||
void Raylet::HandleAcceptNodeManager(const boost::system::error_code &error) {
|
||||
if (!error) {
|
||||
ClientHandler<boost::asio::ip::tcp> client_handler =
|
||||
[this](std::shared_ptr<TcpClientConnection> client) {
|
||||
node_manager_.ProcessNewNodeManager(client);
|
||||
};
|
||||
MessageHandler<boost::asio::ip::tcp> message_handler = [this](
|
||||
std::shared_ptr<TcpClientConnection> client, int64_t message_type,
|
||||
const uint8_t *message) {
|
||||
node_manager_.ProcessNodeManagerMessage(client, message_type, message);
|
||||
};
|
||||
// Accept a new local client and dispatch it to the node manager.
|
||||
auto new_connection = TcpClientConnection::Create(client_handler, message_handler,
|
||||
std::move(node_manager_socket_));
|
||||
}
|
||||
// We're ready to accept another client.
|
||||
DoAcceptNodeManager();
|
||||
}
|
||||
|
||||
void Raylet::DoAcceptObjectManager() {
|
||||
object_manager_acceptor_.async_accept(
|
||||
object_manager_socket_, boost::bind(&Raylet::HandleAcceptObjectManager, this,
|
||||
boost::asio::placeholders::error));
|
||||
}
|
||||
|
||||
void Raylet::HandleAcceptObjectManager(const boost::system::error_code &error) {
|
||||
ClientHandler<boost::asio::ip::tcp> client_handler =
|
||||
[this](std::shared_ptr<TcpClientConnection> client) {
|
||||
object_manager_.ProcessNewClient(client);
|
||||
};
|
||||
MessageHandler<boost::asio::ip::tcp> message_handler = [this](
|
||||
std::shared_ptr<TcpClientConnection> client, int64_t message_type,
|
||||
const uint8_t *message) {
|
||||
object_manager_.ProcessClientMessage(client, message_type, message);
|
||||
};
|
||||
// Accept a new local client and dispatch it to the node manager.
|
||||
auto new_connection = TcpClientConnection::Create(client_handler, message_handler,
|
||||
std::move(object_manager_socket_));
|
||||
DoAcceptObjectManager();
|
||||
}
|
||||
|
||||
void Raylet::DoAccept() {
|
||||
@@ -60,14 +133,24 @@ void Raylet::DoAccept() {
|
||||
|
||||
void Raylet::HandleAccept(const boost::system::error_code &error) {
|
||||
if (!error) {
|
||||
// TODO: typedef these handlers.
|
||||
ClientHandler<boost::asio::local::stream_protocol> client_handler =
|
||||
[this](std::shared_ptr<LocalClientConnection> client) {
|
||||
node_manager_.ProcessNewClient(client);
|
||||
};
|
||||
MessageHandler<boost::asio::local::stream_protocol> message_handler = [this](
|
||||
std::shared_ptr<LocalClientConnection> client, int64_t message_type,
|
||||
const uint8_t *message) {
|
||||
node_manager_.ProcessClientMessage(client, message_type, message);
|
||||
};
|
||||
// Accept a new local client and dispatch it to the node manager.
|
||||
auto new_connection =
|
||||
LocalClientConnection::Create(node_manager_, std::move(socket_));
|
||||
auto new_connection = LocalClientConnection::Create(client_handler, message_handler,
|
||||
std::move(socket_));
|
||||
}
|
||||
// We're ready to accept another client.
|
||||
DoAccept();
|
||||
}
|
||||
|
||||
ObjectManager &Raylet::GetObjectManager() { return object_manager_; }
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
+30
-18
@@ -14,63 +14,75 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
class Task;
|
||||
class NodeManager;
|
||||
|
||||
// TODO(swang): Rename class and source files to Raylet.
|
||||
class Raylet {
|
||||
public:
|
||||
/// Create a node manager server and listen for new clients.
|
||||
///
|
||||
/// \param io_service The event loop to run the server on.
|
||||
/// \param main_service The event loop to run the server on.
|
||||
/// \param object_manager_service The asio io_service tied to the object manager.
|
||||
/// \param socket_name The Unix domain socket to listen on for local clients.
|
||||
/// \param resource_config The initial set of resources to start the local
|
||||
/// \param node_manager_config Configuration to initialize the node manager.
|
||||
/// scheduler with.
|
||||
/// \param object_manager_config Configuration to initialize the object
|
||||
/// manager.
|
||||
/// \param gcs_client A client connection to the GCS.
|
||||
Raylet(boost::asio::io_service &io_service, const std::string &socket_name,
|
||||
const ResourceSet &resource_config,
|
||||
Raylet(boost::asio::io_service &main_service,
|
||||
std::unique_ptr<boost::asio::io_service> object_manager_service,
|
||||
const std::string &socket_name, const std::string &redis_address, int redis_port,
|
||||
const NodeManagerConfig &node_manager_config,
|
||||
const ObjectManagerConfig &object_manager_config,
|
||||
std::shared_ptr<ray::GcsClient> gcs_client);
|
||||
std::shared_ptr<gcs::AsyncGcsClient> gcs_client);
|
||||
|
||||
/// Destroy the NodeServer.
|
||||
~Raylet();
|
||||
|
||||
// TODO(melih): Get rid of this method.
|
||||
ObjectManager &GetObjectManager();
|
||||
|
||||
private:
|
||||
/// Register GCS client.
|
||||
ClientID RegisterGcs();
|
||||
ray::Status RegisterGcs(const std::string &redis_address, int redis_port,
|
||||
boost::asio::io_service &io_service, const NodeManagerConfig &);
|
||||
|
||||
ray::Status RegisterPeriodicTimer(boost::asio::io_service &io_service);
|
||||
/// Accept a client connection.
|
||||
void DoAccept();
|
||||
/// Handle an accepted client connection.
|
||||
void HandleAccept(const boost::system::error_code &error);
|
||||
/// Accept a tcp client connection.
|
||||
void DoAcceptTcp();
|
||||
void DoAcceptObjectManager();
|
||||
/// Handle an accepted tcp client connection.
|
||||
void HandleAcceptTcp(TCPClientConnection::pointer new_connection,
|
||||
const boost::system::error_code &error);
|
||||
void HandleAcceptObjectManager(const boost::system::error_code &error);
|
||||
void DoAcceptNodeManager();
|
||||
void HandleAcceptNodeManager(const boost::system::error_code &error);
|
||||
|
||||
friend class TestObjectManagerIntegration;
|
||||
|
||||
/// An acceptor for new clients.
|
||||
boost::asio::local::stream_protocol::acceptor acceptor_;
|
||||
/// The socket to listen on for new clients.
|
||||
boost::asio::local::stream_protocol::socket socket_;
|
||||
/// An acceptor for new object manager tcp clients.
|
||||
boost::asio::ip::tcp::acceptor object_manager_acceptor_;
|
||||
/// The socket to listen on for new object manager tcp clients.
|
||||
boost::asio::ip::tcp::socket object_manager_socket_;
|
||||
/// An acceptor for new tcp clients.
|
||||
boost::asio::ip::tcp::acceptor tcp_acceptor_;
|
||||
boost::asio::ip::tcp::acceptor node_manager_acceptor_;
|
||||
/// The socket to listen on for new tcp clients.
|
||||
boost::asio::ip::tcp::socket tcp_socket_;
|
||||
boost::asio::ip::tcp::socket node_manager_socket_;
|
||||
|
||||
// TODO(swang): Lineage cache.
|
||||
/// A client connection to the GCS.
|
||||
std::shared_ptr<gcs::AsyncGcsClient> gcs_client_;
|
||||
/// Manages client requests for object transfers and availability.
|
||||
ObjectManager object_manager_;
|
||||
/// Manages client requests for task submission and execution.
|
||||
NodeManager node_manager_;
|
||||
/// A client connection to the GCS.
|
||||
std::shared_ptr<ray::GcsClient> gcs_client_;
|
||||
};
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_RAYLET_RAYLET_H
|
||||
|
||||
@@ -1,275 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <thread>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ray/raylet/raylet.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
std::string test_executable; // NOLINT
|
||||
|
||||
class TestRaylet : public ::testing::Test {
|
||||
public:
|
||||
TestRaylet() { RAY_LOG(INFO) << "TestRaylet: started."; }
|
||||
|
||||
std::string StartStore(const std::string &id) {
|
||||
std::string store_id = "/tmp/store";
|
||||
store_id = store_id + id;
|
||||
std::string test_dir = test_executable.substr(0, test_executable.find_last_of("/"));
|
||||
std::string plasma_dir = test_dir + "./../plasma";
|
||||
std::string plasma_command = plasma_dir + "/plasma_store -m 1000000000 -s " +
|
||||
store_id + " 1> /dev/null 2> /dev/null &";
|
||||
RAY_LOG(INFO) << plasma_command;
|
||||
int ec = system(plasma_command.c_str());
|
||||
if (ec != 0) {
|
||||
throw std::runtime_error("failed to start plasma store.");
|
||||
};
|
||||
return store_id;
|
||||
}
|
||||
|
||||
void SetUp() {
|
||||
// start store
|
||||
std::string store_sock_1 = StartStore("1");
|
||||
std::string store_sock_2 = StartStore("2");
|
||||
|
||||
// configure
|
||||
std::unordered_map<std::string, double> static_resource_config;
|
||||
static_resource_config = {{"num_cpus", 1}, {"num_gpus", 1}};
|
||||
ray::ResourceSet resource_config(std::move(static_resource_config));
|
||||
|
||||
// start mock gcs
|
||||
mock_gcs_client = std::shared_ptr<GcsClient>(new GcsClient());
|
||||
|
||||
// start first server
|
||||
ray::ObjectManagerConfig om_config_1;
|
||||
om_config_1.store_socket_name = store_sock_1;
|
||||
server1.reset(new Raylet(io_service, std::string("hello1"), resource_config,
|
||||
om_config_1, mock_gcs_client));
|
||||
|
||||
// start second server
|
||||
ray::ObjectManagerConfig om_config_2;
|
||||
om_config_2.store_socket_name = store_sock_2;
|
||||
server2.reset(new Raylet(io_service, std::string("hello2"), resource_config,
|
||||
om_config_2, mock_gcs_client));
|
||||
|
||||
// connect to stores.
|
||||
ARROW_CHECK_OK(client1.Connect(store_sock_1, "", PLASMA_DEFAULT_RELEASE_DELAY));
|
||||
ARROW_CHECK_OK(client2.Connect(store_sock_2, "", PLASMA_DEFAULT_RELEASE_DELAY));
|
||||
this->StartLoop();
|
||||
}
|
||||
|
||||
void TearDown() {
|
||||
this->StopLoop();
|
||||
arrow::Status client1_status = client1.Disconnect();
|
||||
arrow::Status client2_status = client2.Disconnect();
|
||||
ASSERT_TRUE(client1_status.ok() && client2_status.ok());
|
||||
|
||||
this->server1.reset();
|
||||
this->server2.reset();
|
||||
|
||||
int s = system("killall plasma_store &");
|
||||
ASSERT_TRUE(!s);
|
||||
|
||||
std::string cmd_str = test_executable.substr(0, test_executable.find_last_of("/"));
|
||||
s = system(("rm " + cmd_str + "/hello1").c_str());
|
||||
ASSERT_TRUE(!s);
|
||||
s = system(("rm " + cmd_str + "/hello2").c_str());
|
||||
ASSERT_TRUE(!s);
|
||||
}
|
||||
|
||||
void Loop() { io_service.run(); };
|
||||
|
||||
void StartLoop() { p = std::thread(&TestRaylet::Loop, this); };
|
||||
|
||||
void StopLoop() {
|
||||
io_service.stop();
|
||||
p.join();
|
||||
}
|
||||
|
||||
ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size) {
|
||||
ObjectID object_id = ObjectID::from_random();
|
||||
RAY_LOG(DEBUG) << "ObjectID Created: " << object_id.hex().c_str();
|
||||
uint8_t metadata[] = {5};
|
||||
int64_t metadata_size = sizeof(metadata);
|
||||
std::shared_ptr<Buffer> data;
|
||||
ARROW_CHECK_OK(client.Create(object_id.to_plasma_id(), data_size, metadata,
|
||||
metadata_size, &data));
|
||||
ARROW_CHECK_OK(client.Seal(object_id.to_plasma_id()));
|
||||
return object_id;
|
||||
}
|
||||
|
||||
void object_added_handler_1(const ObjectID &object_id) {
|
||||
RAY_LOG(INFO) << "Store 1 added: " << object_id.hex();
|
||||
v1.push_back(object_id);
|
||||
};
|
||||
|
||||
void object_added_handler_2(const ObjectID &object_id) {
|
||||
RAY_LOG(INFO) << "Store 2 added: " << object_id.hex();
|
||||
v2.push_back(object_id);
|
||||
};
|
||||
|
||||
protected:
|
||||
std::thread p;
|
||||
boost::asio::io_service io_service;
|
||||
std::shared_ptr<ray::GcsClient> mock_gcs_client;
|
||||
std::unique_ptr<ray::Raylet> server1;
|
||||
std::unique_ptr<ray::Raylet> server2;
|
||||
|
||||
plasma::PlasmaClient client1;
|
||||
plasma::PlasmaClient client2;
|
||||
std::vector<ObjectID> v1;
|
||||
std::vector<ObjectID> v2;
|
||||
};
|
||||
|
||||
TEST_F(TestRaylet, TestRayletCommands) {
|
||||
ray::Status status = ray::Status::OK();
|
||||
// TODO(atumanov): assert status is OK everywhere it's returned.
|
||||
RAY_LOG(INFO) << "\n"
|
||||
<< "All connected clients:"
|
||||
<< "\n";
|
||||
status = mock_gcs_client->client_table().GetClientIds(
|
||||
[this](const std::vector<ClientID> &client_ids) {
|
||||
mock_gcs_client->client_table().GetClientInformationSet(
|
||||
client_ids,
|
||||
[this](const std::vector<ClientInformation> &info_vec) {
|
||||
for (const auto &info : info_vec) {
|
||||
RAY_LOG(INFO) << "ClientID=" << info.GetClientId().hex();
|
||||
RAY_LOG(INFO) << "ClientIp=" << info.GetIp();
|
||||
RAY_LOG(INFO) << "ClientPort=" << info.GetPort();
|
||||
}
|
||||
},
|
||||
[](Status status) {});
|
||||
});
|
||||
|
||||
sleep(1);
|
||||
|
||||
RAY_LOG(INFO) << "\n"
|
||||
<< "Server client ids:"
|
||||
<< "\n";
|
||||
|
||||
status = server1->GetObjectManager().SubscribeObjAdded(
|
||||
[this](const ObjectID &object_id) { object_added_handler_1(object_id); });
|
||||
ASSERT_TRUE(status.ok());
|
||||
|
||||
status = server2->GetObjectManager().SubscribeObjAdded(
|
||||
[this](const ObjectID &object_id) { object_added_handler_2(object_id); });
|
||||
ASSERT_TRUE(status.ok());
|
||||
|
||||
ClientID client_id_1 = server1->GetObjectManager().GetClientID();
|
||||
ClientID client_id_2 = server2->GetObjectManager().GetClientID();
|
||||
RAY_LOG(INFO) << "Server 1: " << client_id_1.hex();
|
||||
RAY_LOG(INFO) << "Server 2: " << client_id_2.hex();
|
||||
|
||||
sleep(1);
|
||||
|
||||
RAY_LOG(INFO) << "\n"
|
||||
<< "Test bidirectional pull"
|
||||
<< "\n";
|
||||
for (int i = -1; ++i < 100;) {
|
||||
ObjectID oid1 = WriteDataToClient(client1, 100);
|
||||
ObjectID oid2 = WriteDataToClient(client2, 100);
|
||||
status = server1->GetObjectManager().Pull(oid2);
|
||||
status = server2->GetObjectManager().Pull(oid1);
|
||||
}
|
||||
sleep(1);
|
||||
RAY_LOG(INFO) << v1.size() << " " << v2.size();
|
||||
ASSERT_TRUE(v1.size() == v2.size());
|
||||
for (int i = -1; ++i < (int)v1.size();) {
|
||||
ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end());
|
||||
}
|
||||
v1.clear();
|
||||
v2.clear();
|
||||
|
||||
RAY_LOG(INFO) << "\n"
|
||||
<< "Test pull 1 from 2"
|
||||
<< "\n";
|
||||
for (int i = -1; ++i < 3;) {
|
||||
ObjectID oid2 = WriteDataToClient(client2, 100);
|
||||
status = server1->GetObjectManager().Pull(oid2);
|
||||
}
|
||||
sleep(1);
|
||||
RAY_LOG(INFO) << v1.size() << " " << v2.size();
|
||||
ASSERT_TRUE(v1.size() == v2.size());
|
||||
for (int i = -1; ++i < (int)v1.size();) {
|
||||
ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end());
|
||||
}
|
||||
v1.clear();
|
||||
v2.clear();
|
||||
|
||||
RAY_LOG(INFO) << "\n"
|
||||
<< "Test pull 2 from 1"
|
||||
<< "\n";
|
||||
for (int i = -1; ++i < 3;) {
|
||||
ObjectID oid1 = WriteDataToClient(client1, 100);
|
||||
status = server2->GetObjectManager().Pull(oid1);
|
||||
}
|
||||
sleep(1);
|
||||
RAY_LOG(INFO) << v1.size() << " " << v2.size();
|
||||
ASSERT_TRUE(v1.size() == v2.size());
|
||||
for (int i = -1; ++i < (int)v1.size();) {
|
||||
ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end());
|
||||
}
|
||||
v1.clear();
|
||||
v2.clear();
|
||||
|
||||
RAY_LOG(INFO) << "\n"
|
||||
<< "Test push 1 to 2"
|
||||
<< "\n";
|
||||
for (int i = -1; ++i < 3;) {
|
||||
ObjectID oid1 = WriteDataToClient(client1, 100);
|
||||
status = server1->GetObjectManager().Push(oid1, client_id_2);
|
||||
}
|
||||
sleep(1);
|
||||
RAY_LOG(INFO) << v1.size() << " " << v2.size();
|
||||
ASSERT_TRUE(v1.size() == v2.size());
|
||||
for (int i = -1; ++i < (int)v1.size();) {
|
||||
ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end());
|
||||
}
|
||||
v1.clear();
|
||||
v2.clear();
|
||||
|
||||
RAY_LOG(INFO) << "\n"
|
||||
<< "Test push 2 to 1"
|
||||
<< "\n";
|
||||
for (int i = -1; ++i < 3;) {
|
||||
ObjectID oid2 = WriteDataToClient(client2, 100);
|
||||
status = server2->GetObjectManager().Push(oid2, client_id_1);
|
||||
}
|
||||
sleep(1);
|
||||
RAY_LOG(INFO) << v1.size() << " " << v2.size();
|
||||
ASSERT_TRUE(v1.size() == v2.size());
|
||||
for (int i = -1; ++i < (int)v1.size();) {
|
||||
ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end());
|
||||
}
|
||||
v1.clear();
|
||||
v2.clear();
|
||||
|
||||
RAY_LOG(INFO) << "\n"
|
||||
<< "Test bidirectional push"
|
||||
<< "\n";
|
||||
for (int i = -1; ++i < 3;) {
|
||||
ObjectID oid1 = WriteDataToClient(client1, 100);
|
||||
ObjectID oid2 = WriteDataToClient(client2, 100);
|
||||
status = server1->GetObjectManager().Push(oid1, client_id_2);
|
||||
status = server2->GetObjectManager().Push(oid2, client_id_1);
|
||||
}
|
||||
sleep(1);
|
||||
RAY_LOG(INFO) << v1.size() << " " << v2.size();
|
||||
ASSERT_TRUE(v1.size() == v2.size());
|
||||
for (int i = -1; ++i < (int)v1.size();) {
|
||||
ASSERT_TRUE(std::find(v1.begin(), v1.end(), v2[i]) != v1.end());
|
||||
}
|
||||
v1.clear();
|
||||
v2.clear();
|
||||
|
||||
ASSERT_TRUE(true);
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
ray::test_executable = std::string(argv[0]);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
@@ -2,8 +2,12 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
void ReconstructionPolicy::CheckObjectReconstruction(const ObjectID &object) {
|
||||
throw std::runtime_error("Method not implemented");
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // end namespace ray
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
// TODO(swang): Use std::function instead of boost.
|
||||
|
||||
class ReconstructionPolicy {
|
||||
@@ -29,6 +31,8 @@ class ReconstructionPolicy {
|
||||
private:
|
||||
};
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_RAYLET_RECONSTRUCTION_POLICY_H
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
#include <iostream>
|
||||
|
||||
#include "ray/raylet/raylet.h"
|
||||
|
||||
/// A demo that starts two Raylets, with one object store each. The two Raylets
|
||||
/// share a mock GCS client for communication between the two (e.g., for
|
||||
/// ObjectManager::Push).
|
||||
int main(int argc, char *argv[]) {
|
||||
RAY_CHECK(argc == 3);
|
||||
std::string store1 = "/tmp/store1";
|
||||
std::string store2 = "/tmp/store2";
|
||||
// start store
|
||||
std::string plasma_dir = "../../plasma";
|
||||
std::string plasma_command1 = plasma_dir + "/plasma_store -m 1000000000 -s ";
|
||||
std::string plasma_command2 = " 1> /dev/null 2> /dev/null &";
|
||||
RAY_LOG(INFO) << plasma_command1 << store1 << plasma_command2;
|
||||
RAY_LOG(INFO) << plasma_command1 << store2 << plasma_command2;
|
||||
int s;
|
||||
s = system((plasma_command1 + store1 + plasma_command2).c_str());
|
||||
RAY_CHECK(s == 0);
|
||||
s = system((plasma_command1 + store2 + plasma_command2).c_str());
|
||||
|
||||
// configure
|
||||
std::unordered_map<std::string, double> static_resource_conf;
|
||||
static_resource_conf = {{"CPU", 1}, {"GPU", 1}};
|
||||
ray::ResourceSet resource_config(std::move(static_resource_conf));
|
||||
ray::ObjectManagerConfig om_config;
|
||||
|
||||
// initialize mock gcs & object directory
|
||||
std::shared_ptr<ray::GcsClient> mock_gcs_client =
|
||||
std::shared_ptr<ray::GcsClient>(new ray::GcsClient());
|
||||
|
||||
// Initialize the node manager.
|
||||
boost::asio::io_service io_service;
|
||||
om_config.store_socket_name = store1;
|
||||
ray::Raylet server1(io_service, std::string(argv[1]), resource_config, om_config,
|
||||
mock_gcs_client);
|
||||
om_config.store_socket_name = store2;
|
||||
ray::Raylet server2(io_service, std::string(argv[2]), resource_config, om_config,
|
||||
mock_gcs_client);
|
||||
io_service.run();
|
||||
}
|
||||
@@ -1,32 +1,75 @@
|
||||
#include "scheduling_policy.h"
|
||||
|
||||
#include "ray/util/logging.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
SchedulingPolicy::SchedulingPolicy(const SchedulingQueue &scheduling_queue)
|
||||
: scheduling_queue_(scheduling_queue) {}
|
||||
: scheduling_queue_(scheduling_queue), gen_(rd_()) {}
|
||||
|
||||
std::unordered_map<TaskID, ClientID, UniqueIDHasher> SchedulingPolicy::Schedule(
|
||||
const std::unordered_map<ClientID, SchedulingResources, UniqueIDHasher>
|
||||
&cluster_resources) {
|
||||
static ClientID local_node_id = ClientID::nil();
|
||||
&cluster_resources,
|
||||
const ClientID &local_client_id, const std::vector<ClientID> &others) {
|
||||
// The policy decision to be returned.
|
||||
std::unordered_map<TaskID, ClientID, UniqueIDHasher> decision;
|
||||
// TODO(atumanov): consider all cluster resources.
|
||||
SchedulingResources resource_supply = cluster_resources.at(local_node_id);
|
||||
const auto &resource_supply_set = resource_supply.GetAvailableResources();
|
||||
// TODO(atumanov): protect DEBUG code blocks with ifdef DEBUG
|
||||
RAY_LOG(DEBUG) << "[Schedule] cluster resource map: ";
|
||||
for (const auto &client_resource_pair : cluster_resources) {
|
||||
// pair = ClientID, SchedulingResources
|
||||
const ClientID &client_id = client_resource_pair.first;
|
||||
const SchedulingResources &resources = client_resource_pair.second;
|
||||
RAY_LOG(DEBUG) << "client_id: " << client_id << " "
|
||||
<< resources.GetAvailableResources().ToString();
|
||||
}
|
||||
|
||||
// Iterate over running tasks, get their resource demand and try to schedule.
|
||||
for (const auto &t : scheduling_queue_.GetReadyTasks()) {
|
||||
// Get task's resource demand
|
||||
const auto &resource_demand = t.GetTaskSpecification().GetRequiredResources();
|
||||
bool task_feasible = resource_demand.IsSubset(resource_supply_set);
|
||||
if (task_feasible) {
|
||||
const TaskID &task_id = t.GetTaskSpecification().TaskId();
|
||||
decision[task_id] = local_node_id;
|
||||
const TaskID &task_id = t.GetTaskSpecification().TaskId();
|
||||
RAY_LOG(DEBUG) << "[SchedulingPolicy]: task=" << task_id
|
||||
<< " numforwards=" << t.GetTaskExecutionSpecReadonly().NumForwards()
|
||||
<< " resources="
|
||||
<< t.GetTaskSpecification().GetRequiredResources().ToString();
|
||||
// TODO(atumanov): replace the simple spillback policy with exponential backoff based
|
||||
// policy.
|
||||
if (t.GetTaskExecutionSpecReadonly().NumForwards() >= 1) {
|
||||
decision[task_id] = local_client_id;
|
||||
continue;
|
||||
}
|
||||
// Construct a set of viable node candidates and randomly pick between them.
|
||||
// Get all the client id keys and randomly pick.
|
||||
std::vector<ClientID> client_keys;
|
||||
for (const auto &client_resource_pair : cluster_resources) {
|
||||
// pair = ClientID, SchedulingResources
|
||||
ClientID node_client_id = client_resource_pair.first;
|
||||
SchedulingResources node_resources = client_resource_pair.second;
|
||||
RAY_LOG(DEBUG) << "client_id " << node_client_id << " resources: "
|
||||
<< node_resources.GetAvailableResources().ToString();
|
||||
if (resource_demand.IsSubset(node_resources.GetTotalResources())) {
|
||||
// This node is a feasible candidate.
|
||||
client_keys.push_back(node_client_id);
|
||||
}
|
||||
}
|
||||
RAY_CHECK(!client_keys.empty());
|
||||
|
||||
// Choose index at random.
|
||||
// Initialize a uniform integer distribution over the key space.
|
||||
// TODO(atumanov): change uniform random to discrete, weighted by resource capacity.
|
||||
std::uniform_int_distribution<int> distribution(0, client_keys.size() - 1);
|
||||
int client_key_index = distribution(gen_);
|
||||
decision[task_id] = client_keys[client_key_index];
|
||||
RAY_LOG(DEBUG) << "[SchedulingPolicy] idx=" << client_key_index << " " << task_id
|
||||
<< " --> " << client_keys[client_key_index];
|
||||
}
|
||||
return decision;
|
||||
}
|
||||
|
||||
SchedulingPolicy::~SchedulingPolicy() {}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#ifndef RAY_RAYLET_SCHEDULING_POLICY_H
|
||||
#define RAY_RAYLET_SCHEDULING_POLICY_H
|
||||
|
||||
#include <random>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "ray/raylet/scheduling_queue.h"
|
||||
@@ -8,6 +9,8 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
/// \class SchedulingPolicy
|
||||
/// \brief Implements a scheduling policy for the node manager.
|
||||
class SchedulingPolicy {
|
||||
@@ -27,7 +30,8 @@ class SchedulingPolicy {
|
||||
/// \return Scheduling decision, mapping tasks to node managers for placement.
|
||||
std::unordered_map<TaskID, ClientID, UniqueIDHasher> Schedule(
|
||||
const std::unordered_map<ClientID, SchedulingResources, UniqueIDHasher>
|
||||
&cluster_resources);
|
||||
&cluster_resources,
|
||||
const ClientID &local_client_id, const std::vector<ClientID> &others);
|
||||
|
||||
/// \brief SchedulingPolicy destructor.
|
||||
virtual ~SchedulingPolicy();
|
||||
@@ -35,8 +39,14 @@ class SchedulingPolicy {
|
||||
private:
|
||||
/// An immutable reference to the scheduling task queues.
|
||||
const SchedulingQueue &scheduling_queue_;
|
||||
/// Internally maintained random number engine device.
|
||||
std::random_device rd_;
|
||||
/// Internally maintained random number generator.
|
||||
std::mt19937_64 gen_;
|
||||
};
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_RAYLET_SCHEDULING_POLICY_H
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
const std::list<Task> &SchedulingQueue::GetWaitingTasks() const {
|
||||
return this->waiting_tasks_;
|
||||
}
|
||||
@@ -88,4 +90,6 @@ bool SchedulingQueue::RegisterActor(ActorID actor_id,
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -11,6 +11,8 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
/// \class SchedulingQueue
|
||||
///
|
||||
/// Encapsulates task queues. Each queue represents a scheduling state for a
|
||||
@@ -103,6 +105,9 @@ class SchedulingQueue {
|
||||
/// The registry of known actors.
|
||||
std::unordered_map<ActorID, ActorInformation, UniqueIDHasher> actor_registry_;
|
||||
};
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_RAYLET_SCHEDULING_QUEUE_H
|
||||
|
||||
@@ -2,13 +2,25 @@
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "ray/util/logging.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
ResourceSet::ResourceSet() {}
|
||||
|
||||
ResourceSet::ResourceSet(const std::unordered_map<std::string, double> &resource_map)
|
||||
: resource_capacity_(resource_map) {}
|
||||
|
||||
ResourceSet::ResourceSet(const std::vector<std::string> &resource_labels,
|
||||
const std::vector<double> resource_capacity) {
|
||||
RAY_CHECK(resource_labels.size() == resource_capacity.size());
|
||||
for (uint i = 0; i < resource_labels.size(); i++) {
|
||||
RAY_CHECK(this->AddResource(resource_labels[i], resource_capacity[i]));
|
||||
}
|
||||
}
|
||||
|
||||
ResourceSet::~ResourceSet() {}
|
||||
|
||||
bool ResourceSet::operator==(const ResourceSet &rhs) const {
|
||||
@@ -43,16 +55,42 @@ bool ResourceSet::IsEqual(const ResourceSet &rhs) const {
|
||||
}
|
||||
|
||||
bool ResourceSet::AddResource(const std::string &resource_name, double capacity) {
|
||||
throw std::runtime_error("Method not implemented");
|
||||
this->resource_capacity_[resource_name] = capacity;
|
||||
return true;
|
||||
}
|
||||
bool ResourceSet::RemoveResource(const std::string &resource_name) {
|
||||
throw std::runtime_error("Method not implemented");
|
||||
}
|
||||
bool ResourceSet::SubtractResources(const ResourceSet &other) {
|
||||
throw std::runtime_error("Method not implemented");
|
||||
// Return failure if attempting to perform vector subtraction with unknown labels.
|
||||
// TODO(atumanov): make the implementation atomic. Currently, if false is returned
|
||||
// the resource capacity may be partially mutated. To reverse, call AddResources.
|
||||
for (const auto &resource_pair : other.GetResourceMap()) {
|
||||
const std::string &resource_label = resource_pair.first;
|
||||
const double &resource_capacity = resource_pair.second;
|
||||
if (resource_capacity_.count(resource_label) == 0) {
|
||||
return false;
|
||||
} else {
|
||||
resource_capacity_[resource_label] -= resource_capacity;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ResourceSet::AddResources(const ResourceSet &other) {
|
||||
throw std::runtime_error("Method not implemented");
|
||||
// Return failure if attempting to perform vector addition with unknown labels.
|
||||
// TODO(atumanov): make the implementation atomic. Currently, if false is returned
|
||||
// the resource capacity may be partially mutated. To reverse, call SubtractResources.
|
||||
for (const auto &resource_pair : other.GetResourceMap()) {
|
||||
const std::string &resource_label = resource_pair.first;
|
||||
const double &resource_capacity = resource_pair.second;
|
||||
if (resource_capacity_.count(resource_label) == 0) {
|
||||
return false;
|
||||
} else {
|
||||
resource_capacity_[resource_label] += resource_capacity;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ResourceSet::GetResource(const std::string &resource_name, double *value) const {
|
||||
@@ -67,6 +105,19 @@ bool ResourceSet::GetResource(const std::string &resource_name, double *value) c
|
||||
return true;
|
||||
}
|
||||
|
||||
const std::string ResourceSet::ToString() const {
|
||||
std::string return_string = "";
|
||||
for (const auto &resource_pair : this->resource_capacity_) {
|
||||
return_string +=
|
||||
"{" + resource_pair.first + "," + std::to_string(resource_pair.second) + "}, ";
|
||||
}
|
||||
return return_string;
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, double> &ResourceSet::GetResourceMap() const {
|
||||
return this->resource_capacity_;
|
||||
};
|
||||
|
||||
/// SchedulingResources class implementation
|
||||
|
||||
SchedulingResources::SchedulingResources()
|
||||
@@ -93,6 +144,14 @@ const ResourceSet &SchedulingResources::GetAvailableResources() const {
|
||||
return this->resources_available_;
|
||||
}
|
||||
|
||||
void SchedulingResources::SetAvailableResources(ResourceSet &&newset) {
|
||||
this->resources_available_ = newset;
|
||||
}
|
||||
|
||||
const ResourceSet &SchedulingResources::GetTotalResources() const {
|
||||
return this->resources_total_;
|
||||
}
|
||||
|
||||
// Return specified resources back to SchedulingResources.
|
||||
bool SchedulingResources::Release(const ResourceSet &resources) {
|
||||
return this->resources_available_.AddResources(resources);
|
||||
@@ -103,4 +162,6 @@ bool SchedulingResources::Acquire(const ResourceSet &resources) {
|
||||
return this->resources_available_.SubtractResources(resources);
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -4,9 +4,12 @@
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
/// Resource availability status reports whether the resource requirement is
|
||||
/// (1) infeasible, (2) feasible but currently unavailable, or (3) available.
|
||||
typedef enum {
|
||||
@@ -26,6 +29,11 @@ class ResourceSet {
|
||||
/// \brief Constructs ResourceSet from the specified resource map.
|
||||
ResourceSet(const std::unordered_map<std::string, double> &resource_map);
|
||||
|
||||
/// \brief Constructs ResourceSet from two equal-length vectors with label and capacity
|
||||
/// specification.
|
||||
ResourceSet(const std::vector<std::string> &resource_labels,
|
||||
const std::vector<double> resource_capacity);
|
||||
|
||||
/// \brief Empty ResourceSet destructor.
|
||||
~ResourceSet();
|
||||
|
||||
@@ -89,6 +97,11 @@ class ResourceSet {
|
||||
/// False otherwise.
|
||||
bool GetResource(const std::string &resource_name, double *value) const;
|
||||
|
||||
// TODO(atumanov): implement const_iterator class for the ResourceSet container.
|
||||
const std::unordered_map<std::string, double> &GetResourceMap() const;
|
||||
|
||||
const std::string ToString() const;
|
||||
|
||||
private:
|
||||
/// Resource capacity map.
|
||||
std::unordered_map<std::string, double> resource_capacity_;
|
||||
@@ -125,6 +138,14 @@ class SchedulingResources {
|
||||
/// \return Immutable set of resources with currently available capacity.
|
||||
const ResourceSet &GetAvailableResources() const;
|
||||
|
||||
/// \brief Overwrite available resource capacity with the specified resource set.
|
||||
///
|
||||
/// \param newset: The set of resources that replaces available resource capacity.
|
||||
/// \return None.
|
||||
void SetAvailableResources(ResourceSet &&newset);
|
||||
|
||||
const ResourceSet &GetTotalResources() const;
|
||||
|
||||
/// \brief Release the amount of resources specified.
|
||||
///
|
||||
/// \param resources: the amount of resources to be released.
|
||||
@@ -145,6 +166,8 @@ class SchedulingResources {
|
||||
/// gpu_map - replace with ResourceMap (for generality).
|
||||
};
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_RAYLET_SCHEDULING_RESOURCES_H
|
||||
|
||||
+14
-1
@@ -2,7 +2,18 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
const TaskExecutionSpecification &Task::GetTaskExecutionSpec() const {
|
||||
namespace raylet {
|
||||
|
||||
flatbuffers::Offset<protocol::Task> Task::ToFlatbuffer(
|
||||
flatbuffers::FlatBufferBuilder &fbb) const {
|
||||
auto task = CreateTask(fbb, task_spec_.ToFlatbuffer(fbb),
|
||||
task_execution_spec_.ToFlatbuffer(fbb));
|
||||
return task;
|
||||
}
|
||||
|
||||
TaskExecutionSpecification &Task::GetTaskExecutionSpec() { return task_execution_spec_; }
|
||||
|
||||
const TaskExecutionSpecification &Task::GetTaskExecutionSpecReadonly() const {
|
||||
return task_execution_spec_;
|
||||
}
|
||||
|
||||
@@ -47,4 +58,6 @@ bool Task::DependsOn(const ObjectID &object_id) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
+23
-2
@@ -3,11 +3,14 @@
|
||||
|
||||
#include <inttypes.h>
|
||||
|
||||
#include "ray/raylet/format/node_manager_generated.h"
|
||||
#include "ray/raylet/task_execution_spec.h"
|
||||
#include "ray/raylet/task_spec.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
/// \class Task
|
||||
///
|
||||
/// A Task represents a Ray task and a specification of its execution (e.g.,
|
||||
@@ -27,13 +30,29 @@ class Task {
|
||||
const TaskSpecification &task_spec)
|
||||
: task_execution_spec_(execution_spec), task_spec_(task_spec) {}
|
||||
|
||||
/// Create a task from a serialized flatbuffer.
|
||||
///
|
||||
/// \param task_flatbuffer The serialized task.
|
||||
Task(const protocol::Task &task_flatbuffer)
|
||||
: task_execution_spec_(*task_flatbuffer.task_execution_spec()),
|
||||
task_spec_(*task_flatbuffer.task_specification()) {}
|
||||
|
||||
/// Destroy the task.
|
||||
virtual ~Task() {}
|
||||
|
||||
/// Serialize a task to a flatbuffer.
|
||||
///
|
||||
/// \param fbb The flatbuffer builder.
|
||||
/// \return An offset to the serialized task.
|
||||
flatbuffers::Offset<protocol::Task> ToFlatbuffer(
|
||||
flatbuffers::FlatBufferBuilder &fbb) const;
|
||||
|
||||
/// Get the execution specification for the task.
|
||||
///
|
||||
/// \return The mutable specification for the task.
|
||||
const TaskExecutionSpecification &GetTaskExecutionSpec() const;
|
||||
TaskExecutionSpecification &GetTaskExecutionSpec();
|
||||
|
||||
const TaskExecutionSpecification &GetTaskExecutionSpecReadonly() const;
|
||||
|
||||
/// Get the immutable specification for the task.
|
||||
///
|
||||
@@ -64,6 +83,8 @@ class Task {
|
||||
TaskSpecification task_spec_;
|
||||
};
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_RAYLET_TASK_H
|
||||
#endif // RAY_RAYLET_TASK_H
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
TaskDependencyManager::TaskDependencyManager(
|
||||
ObjectManager &object_manager,
|
||||
// ReconstructionPolicy &reconstruction_policy,
|
||||
@@ -59,6 +61,8 @@ bool TaskDependencyManager::TaskReady(const Task &task) const {
|
||||
}
|
||||
|
||||
void TaskDependencyManager::SubscribeTaskReady(const Task &task) {
|
||||
// TODO(swang): Don't pull arguments that are going to be created by a queued
|
||||
// or running task.
|
||||
TaskID task_id = task.GetTaskSpecification().TaskId();
|
||||
const std::vector<ObjectID> arguments = task.GetDependencies();
|
||||
// Add the task's arguments to the table of subscribed tasks.
|
||||
@@ -104,4 +108,6 @@ void TaskDependencyManager::MarkDependencyReady(const ObjectID &object) {
|
||||
throw std::runtime_error("Method not implemented");
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -10,6 +10,8 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
class ReconstructionPolicy;
|
||||
|
||||
/// \class TaskDependencyManager
|
||||
@@ -77,6 +79,8 @@ class TaskDependencyManager {
|
||||
std::function<void(const TaskID &)> task_ready_callback_;
|
||||
};
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_RAYLET_TASK_DEPENDENCY_MANAGER_H
|
||||
|
||||
@@ -2,35 +2,57 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
TaskExecutionSpecification::TaskExecutionSpecification(
|
||||
const std::vector<ObjectID> &&execution_dependencies)
|
||||
: execution_dependencies_(std::move(execution_dependencies)),
|
||||
last_timestamp_(0),
|
||||
spillback_count_(0) {}
|
||||
namespace raylet {
|
||||
|
||||
TaskExecutionSpecification::TaskExecutionSpecification(
|
||||
const std::vector<ObjectID> &&execution_dependencies, int spillback_count)
|
||||
: execution_dependencies_(std::move(execution_dependencies)),
|
||||
last_timestamp_(0),
|
||||
spillback_count_(spillback_count) {}
|
||||
const std::vector<ObjectID> &&dependencies) {
|
||||
SetExecutionDependencies(dependencies);
|
||||
}
|
||||
|
||||
const std::vector<ObjectID> &TaskExecutionSpecification::ExecutionDependencies() const {
|
||||
return execution_dependencies_;
|
||||
TaskExecutionSpecification::TaskExecutionSpecification(
|
||||
const std::vector<ObjectID> &&dependencies, int num_forwards) {
|
||||
// TaskExecutionSpecification(std::move(dependencies));
|
||||
SetExecutionDependencies(dependencies);
|
||||
execution_spec_.num_forwards = num_forwards;
|
||||
}
|
||||
|
||||
flatbuffers::Offset<protocol::TaskExecutionSpecification>
|
||||
TaskExecutionSpecification::ToFlatbuffer(flatbuffers::FlatBufferBuilder &fbb) const {
|
||||
fbb.ForceDefaults(true);
|
||||
return protocol::TaskExecutionSpecification::Pack(fbb, &execution_spec_);
|
||||
}
|
||||
|
||||
std::vector<ObjectID> TaskExecutionSpecification::ExecutionDependencies() const {
|
||||
std::vector<ObjectID> dependencies;
|
||||
for (const auto &dependency : execution_spec_.dependencies) {
|
||||
dependencies.push_back(ObjectID::from_binary(dependency));
|
||||
}
|
||||
return dependencies;
|
||||
}
|
||||
|
||||
void TaskExecutionSpecification::SetExecutionDependencies(
|
||||
const std::vector<ObjectID> &dependencies) {
|
||||
execution_dependencies_ = dependencies;
|
||||
for (const auto &dependency : dependencies) {
|
||||
execution_spec_.dependencies.push_back(dependency.binary());
|
||||
}
|
||||
}
|
||||
|
||||
int TaskExecutionSpecification::SpillbackCount() const { return spillback_count_; }
|
||||
|
||||
void TaskExecutionSpecification::IncrementSpillbackCount() { ++spillback_count_; }
|
||||
|
||||
int64_t TaskExecutionSpecification::LastTimeStamp() const { return last_timestamp_; }
|
||||
|
||||
void TaskExecutionSpecification::SetLastTimeStamp(int64_t new_timestamp) {
|
||||
last_timestamp_ = new_timestamp;
|
||||
int TaskExecutionSpecification::NumForwards() const {
|
||||
return execution_spec_.num_forwards;
|
||||
}
|
||||
|
||||
void TaskExecutionSpecification::IncrementNumForwards() {
|
||||
execution_spec_.num_forwards += 1;
|
||||
}
|
||||
|
||||
int64_t TaskExecutionSpecification::LastTimestamp() const {
|
||||
return execution_spec_.last_timestamp;
|
||||
}
|
||||
|
||||
void TaskExecutionSpecification::SetLastTimestamp(int64_t new_timestamp) {
|
||||
execution_spec_.last_timestamp = new_timestamp;
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -4,9 +4,12 @@
|
||||
#include <vector>
|
||||
|
||||
#include "ray/id.h"
|
||||
#include "ray/raylet/format/node_manager_generated.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
/// \class TaskExecutionSpecification
|
||||
///
|
||||
/// The task execution specification encapsulates all mutable information about
|
||||
@@ -16,60 +19,70 @@ class TaskExecutionSpecification {
|
||||
public:
|
||||
/// Create a task execution specification.
|
||||
///
|
||||
/// \param execution_dependencies The task's dependencies, determined at
|
||||
/// execution time.
|
||||
TaskExecutionSpecification(const std::vector<ObjectID> &&execution_dependencies);
|
||||
/// \param dependencies The task's dependencies, determined at execution
|
||||
/// time.
|
||||
TaskExecutionSpecification(const std::vector<ObjectID> &&dependencies);
|
||||
|
||||
/// Create a task execution specification.
|
||||
///
|
||||
/// \param execution_dependencies The task's dependencies, determined at
|
||||
/// execution time.
|
||||
/// \param spillback_count The number of times this task was spilled back by
|
||||
/// local schedulers.
|
||||
TaskExecutionSpecification(const std::vector<ObjectID> &&execution_dependencies,
|
||||
int spillback_count);
|
||||
/// \param dependencies The task's dependencies, determined at execution
|
||||
/// time.
|
||||
/// \param num_forwards The number of times this task has been forwarded by a
|
||||
/// node manager.
|
||||
TaskExecutionSpecification(const std::vector<ObjectID> &&dependencies,
|
||||
int num_forwards);
|
||||
|
||||
/// Create a task execution specification from a serialized flatbuffer.
|
||||
///
|
||||
/// \param spec_flatbuffer The serialized specification.
|
||||
TaskExecutionSpecification(
|
||||
const protocol::TaskExecutionSpecification &spec_flatbuffer) {
|
||||
spec_flatbuffer.UnPackTo(&execution_spec_);
|
||||
}
|
||||
|
||||
/// Serialize a task execution specification to a flatbuffer.
|
||||
///
|
||||
/// \param fbb The flatbuffer builder.
|
||||
/// \return An offset to the serialized task execution specification.
|
||||
flatbuffers::Offset<protocol::TaskExecutionSpecification> ToFlatbuffer(
|
||||
flatbuffers::FlatBufferBuilder &fbb) const;
|
||||
|
||||
/// Get the task's execution dependencies.
|
||||
///
|
||||
/// \return A vector of object IDs representing this task's execution
|
||||
/// dependencies.
|
||||
const std::vector<ObjectID> &ExecutionDependencies() const;
|
||||
/// dependencies.
|
||||
std::vector<ObjectID> ExecutionDependencies() const;
|
||||
|
||||
/// Set the task's execution dependencies.
|
||||
///
|
||||
/// \param dependencies The value to set the execution dependencies to.
|
||||
void SetExecutionDependencies(const std::vector<ObjectID> &dependencies);
|
||||
|
||||
/// Get the task's spillback count, which tracks the number of times
|
||||
/// this task was spilled back from local to the global scheduler.
|
||||
/// Get the number of times this task has been forwarded.
|
||||
///
|
||||
/// \return The spillback count for this task.
|
||||
int SpillbackCount() const;
|
||||
/// \return The number of times this task has been forwarded.
|
||||
int NumForwards() const;
|
||||
|
||||
/// Increment the spillback count for this task.
|
||||
void IncrementSpillbackCount();
|
||||
/// Increment the number of times this task has been forwarded.
|
||||
void IncrementNumForwards();
|
||||
|
||||
/// Get the task's last timestamp.
|
||||
///
|
||||
/// \return The timestamp when this task was last received for scheduling.
|
||||
int64_t LastTimeStamp() const;
|
||||
int64_t LastTimestamp() const;
|
||||
|
||||
/// Set the task's last timestamp to the specified value.
|
||||
///
|
||||
/// \param new_timestamp The new timestamp in millisecond to set the task's
|
||||
/// time stamp to. Tracks the last time this task entered a local
|
||||
/// scheduler.
|
||||
void SetLastTimeStamp(int64_t new_timestamp);
|
||||
/// time stamp to. Tracks the last time this task entered a local scheduler.
|
||||
void SetLastTimestamp(int64_t new_timestamp);
|
||||
|
||||
private:
|
||||
/// A list of object IDs representing the dependencies of this task that may
|
||||
/// change at execution time.
|
||||
std::vector<ObjectID> execution_dependencies_;
|
||||
/// The last time this task was received for scheduling.
|
||||
int64_t last_timestamp_;
|
||||
/// The number of times this task was spilled back by local schedulers.
|
||||
int spillback_count_;
|
||||
protocol::TaskExecutionSpecificationT execution_spec_;
|
||||
};
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_RAYLET_TASK_EXECUTION_SPECIFICATION_H
|
||||
|
||||
+30
-43
@@ -5,6 +5,8 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
TaskArgument::~TaskArgument() {}
|
||||
|
||||
TaskArgumentByReference::TaskArgumentByReference(const std::vector<ObjectID> &references)
|
||||
@@ -15,14 +17,6 @@ flatbuffers::Offset<Arg> TaskArgumentByReference::ToFlatbuffer(
|
||||
return CreateArg(fbb, to_flatbuf(fbb, references_));
|
||||
}
|
||||
|
||||
const BYTE *TaskArgumentByReference::HashData() const {
|
||||
return reinterpret_cast<const BYTE *>(references_.data());
|
||||
}
|
||||
|
||||
size_t TaskArgumentByReference::HashDataLength() const {
|
||||
return references_.size() * sizeof(ObjectID);
|
||||
}
|
||||
|
||||
TaskArgumentByValue::TaskArgumentByValue(const uint8_t *value, size_t length) {
|
||||
value_.assign(value, value + length);
|
||||
}
|
||||
@@ -35,27 +29,12 @@ flatbuffers::Offset<Arg> TaskArgumentByValue::ToFlatbuffer(
|
||||
return CreateArg(fbb, empty_ids, arg);
|
||||
}
|
||||
|
||||
const BYTE *TaskArgumentByValue::HashData() const { return value_.data(); }
|
||||
|
||||
size_t TaskArgumentByValue::HashDataLength() const { return value_.size(); }
|
||||
|
||||
static const ObjectID task_compute_return_id(TaskID task_id, int64_t return_index) {
|
||||
// Here, return_indices need to be >= 0, so we can use negative indices for put.
|
||||
RAY_DCHECK(return_index >= 0);
|
||||
// TODO(rkn): This line requires object and task IDs to be the same size.
|
||||
ObjectID return_id = task_id;
|
||||
int64_t *first_bytes = (int64_t *)&return_id;
|
||||
// XOR the first bytes of the object ID with the return index.
|
||||
// We add one so the first return ID is not the same as the task ID.
|
||||
*first_bytes = *first_bytes ^ (return_index + 1);
|
||||
return return_id;
|
||||
void TaskSpecification::AssignSpecification(const uint8_t *spec, size_t spec_size) {
|
||||
spec_.assign(spec, spec + spec_size);
|
||||
}
|
||||
|
||||
TaskSpecification::TaskSpecification(const uint8_t *spec, size_t spec_size)
|
||||
: spec_(spec, spec + spec_size) {}
|
||||
|
||||
TaskSpecification::TaskSpecification(const flatbuffers::String &string)
|
||||
: TaskSpecification(reinterpret_cast<const uint8_t *>(string.data()), string.size()) {
|
||||
TaskSpecification::TaskSpecification(const flatbuffers::String &string) {
|
||||
AssignSpecification(reinterpret_cast<const uint8_t *>(string.data()), string.size());
|
||||
}
|
||||
|
||||
TaskSpecification::TaskSpecification(
|
||||
@@ -63,9 +42,10 @@ TaskSpecification::TaskSpecification(
|
||||
// UniqueID actor_id,
|
||||
// UniqueID actor_handle_id,
|
||||
// int64_t actor_counter,
|
||||
FunctionID function_id, const std::vector<TaskArgument> &task_arguments,
|
||||
int64_t num_returns,
|
||||
const std::unordered_map<std::string, double> &required_resources) {
|
||||
FunctionID function_id,
|
||||
const std::vector<std::shared_ptr<TaskArgument>> &task_arguments, int64_t num_returns,
|
||||
const std::unordered_map<std::string, double> &required_resources)
|
||||
: spec_() {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
|
||||
// Compute hashes.
|
||||
@@ -78,14 +58,6 @@ TaskSpecification::TaskSpecification(
|
||||
// sha256_update(&ctx, (BYTE *) &actor_counter, sizeof(actor_counter));
|
||||
// sha256_update(&ctx, (BYTE *) &is_actor_checkpoint_method,
|
||||
// sizeof(is_actor_checkpoint_method));
|
||||
sha256_update(&ctx, (BYTE *)&function_id, sizeof(function_id));
|
||||
|
||||
// Serialize and hash the arguments.
|
||||
std::vector<flatbuffers::Offset<Arg>> arguments;
|
||||
for (auto &argument : task_arguments) {
|
||||
arguments.push_back(argument.ToFlatbuffer(fbb));
|
||||
sha256_update(&ctx, (BYTE *)argument.HashData(), argument.HashDataLength());
|
||||
}
|
||||
|
||||
// Compute the final task ID from the hash.
|
||||
BYTE buff[DIGEST_SIZE];
|
||||
@@ -93,11 +65,17 @@ TaskSpecification::TaskSpecification(
|
||||
TaskID task_id;
|
||||
RAY_DCHECK(sizeof(task_id) <= DIGEST_SIZE);
|
||||
memcpy(&task_id, buff, sizeof(task_id));
|
||||
task_id = FinishTaskId(task_id);
|
||||
// Add argument object IDs.
|
||||
std::vector<flatbuffers::Offset<Arg>> arguments;
|
||||
for (auto &argument : task_arguments) {
|
||||
arguments.push_back(argument->ToFlatbuffer(fbb));
|
||||
}
|
||||
|
||||
// Add return object IDs.
|
||||
std::vector<flatbuffers::Offset<flatbuffers::String>> returns;
|
||||
for (int64_t i = 0; i < num_returns; i++) {
|
||||
ObjectID return_id = task_compute_return_id(task_id, i);
|
||||
for (int64_t i = 1; i < num_returns + 1; i++) {
|
||||
ObjectID return_id = ComputeReturnId(task_id, i);
|
||||
returns.push_back(to_flatbuf(fbb, return_id));
|
||||
}
|
||||
|
||||
@@ -110,7 +88,7 @@ TaskSpecification::TaskSpecification(
|
||||
fbb.CreateVector(arguments), fbb.CreateVector(returns),
|
||||
map_to_flatbuf(fbb, required_resources));
|
||||
fbb.Finish(spec);
|
||||
TaskSpecification(fbb.GetBufferPointer(), fbb.GetSize());
|
||||
AssignSpecification(fbb.GetBufferPointer(), fbb.GetSize());
|
||||
}
|
||||
|
||||
flatbuffers::Offset<flatbuffers::String> TaskSpecification::ToFlatbuffer(
|
||||
@@ -130,7 +108,8 @@ TaskID TaskSpecification::TaskId() const {
|
||||
return from_flatbuf(*message->task_id());
|
||||
}
|
||||
UniqueID TaskSpecification::DriverId() const {
|
||||
throw std::runtime_error("Method not implemented");
|
||||
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
|
||||
return from_flatbuf(*message->driver_id());
|
||||
}
|
||||
TaskID TaskSpecification::ParentTaskId() const {
|
||||
throw std::runtime_error("Method not implemented");
|
||||
@@ -148,7 +127,13 @@ int64_t TaskSpecification::NumArgs() const {
|
||||
}
|
||||
|
||||
int64_t TaskSpecification::NumReturns() const {
|
||||
throw std::runtime_error("Method not implemented");
|
||||
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
|
||||
return message->returns()->size();
|
||||
}
|
||||
|
||||
ObjectID TaskSpecification::ReturnId(int64_t return_index) const {
|
||||
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
|
||||
return from_flatbuf(*message->returns()->Get(return_index));
|
||||
}
|
||||
|
||||
bool TaskSpecification::ArgByRef(int64_t arg_index) const {
|
||||
@@ -180,4 +165,6 @@ const ResourceSet TaskSpecification::GetRequiredResources() const {
|
||||
return ResourceSet(required_resources);
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
+14
-18
@@ -6,7 +6,7 @@
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "ray/../common/format/common_generated.h"
|
||||
#include "format/common_generated.h"
|
||||
#include "ray/id.h"
|
||||
#include "ray/raylet/scheduling_resources.h"
|
||||
|
||||
@@ -16,6 +16,8 @@ extern "C" {
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
/// \class TaskArgument
|
||||
///
|
||||
/// A virtual class that represents an argument to a task.
|
||||
@@ -28,16 +30,6 @@ class TaskArgument {
|
||||
virtual flatbuffers::Offset<Arg> ToFlatbuffer(
|
||||
flatbuffers::FlatBufferBuilder &fbb) const = 0;
|
||||
|
||||
/// Get the hashable byte data.
|
||||
///
|
||||
/// \return A pointer to the byte data.
|
||||
virtual const BYTE *HashData() const = 0;
|
||||
|
||||
/// Get the hashable byte data length.
|
||||
///
|
||||
/// \return The length of the hashable byte data.
|
||||
virtual size_t HashDataLength() const = 0;
|
||||
|
||||
virtual ~TaskArgument() = 0;
|
||||
};
|
||||
|
||||
@@ -45,14 +37,15 @@ class TaskArgument {
|
||||
///
|
||||
/// A task argument consisting of a list of object ID references.
|
||||
class TaskArgumentByReference : virtual public TaskArgument {
|
||||
public:
|
||||
/// Create a task argument by reference from a list of object IDs.
|
||||
///
|
||||
/// \param references A list of object ID references.
|
||||
TaskArgumentByReference(const std::vector<ObjectID> &references);
|
||||
|
||||
~TaskArgumentByReference(){};
|
||||
|
||||
flatbuffers::Offset<Arg> ToFlatbuffer(flatbuffers::FlatBufferBuilder &fbb) const;
|
||||
const BYTE *HashData() const;
|
||||
size_t HashDataLength() const;
|
||||
|
||||
private:
|
||||
/// The object IDs.
|
||||
@@ -63,6 +56,7 @@ class TaskArgumentByReference : virtual public TaskArgument {
|
||||
///
|
||||
/// A task argument containing the raw value.
|
||||
class TaskArgumentByValue : public TaskArgument {
|
||||
public:
|
||||
/// Create a task argument from a raw value.
|
||||
///
|
||||
/// \param value A pointer to the raw value.
|
||||
@@ -70,8 +64,6 @@ class TaskArgumentByValue : public TaskArgument {
|
||||
TaskArgumentByValue(const uint8_t *value, size_t length);
|
||||
|
||||
flatbuffers::Offset<Arg> ToFlatbuffer(flatbuffers::FlatBufferBuilder &fbb) const;
|
||||
const BYTE *HashData() const;
|
||||
size_t HashDataLength() const;
|
||||
|
||||
private:
|
||||
/// The raw value.
|
||||
@@ -106,7 +98,8 @@ class TaskSpecification {
|
||||
// UniqueID actor_id,
|
||||
// UniqueID actor_handle_id,
|
||||
// int64_t actor_counter,
|
||||
FunctionID function_id, const std::vector<TaskArgument> &arguments,
|
||||
FunctionID function_id,
|
||||
const std::vector<std::shared_ptr<TaskArgument>> &arguments,
|
||||
int64_t num_returns,
|
||||
const std::unordered_map<std::string, double> &required_resources);
|
||||
|
||||
@@ -130,14 +123,15 @@ class TaskSpecification {
|
||||
bool ArgByRef(int64_t arg_index) const;
|
||||
int ArgIdCount(int64_t arg_index) const;
|
||||
ObjectID ArgId(int64_t arg_index, int64_t id_index) const;
|
||||
ObjectID ReturnId(int64_t return_index) const;
|
||||
const uint8_t *ArgVal(int64_t arg_index) const;
|
||||
size_t ArgValLength(int64_t arg_index) const;
|
||||
double GetRequiredResource(const std::string &resource_name) const;
|
||||
const ResourceSet GetRequiredResources() const;
|
||||
|
||||
private:
|
||||
/// Task specification constructor from a pointer.
|
||||
TaskSpecification(const uint8_t *spec, size_t spec_size);
|
||||
/// Assign the specification data from a pointer.
|
||||
void AssignSpecification(const uint8_t *spec, size_t spec_size);
|
||||
/// Get a pointer to the byte data.
|
||||
const uint8_t *data() const;
|
||||
/// Get the size in bytes of the task specification.
|
||||
@@ -147,6 +141,8 @@ class TaskSpecification {
|
||||
std::vector<uint8_t> spec_;
|
||||
};
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_RAYLET_TASK_SPECIFICATION_H
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ray/raylet/task_spec.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
void TestTaskReturnId(const TaskID &task_id, int64_t return_index) {
|
||||
// Round trip test for computing the object ID for a task's return value,
|
||||
// then computing the task ID that created the object.
|
||||
ObjectID return_id = ComputeReturnId(task_id, return_index);
|
||||
ASSERT_EQ(ComputeTaskId(return_id), task_id);
|
||||
ASSERT_EQ(ComputeObjectIndex(return_id), return_index);
|
||||
}
|
||||
|
||||
void TestTaskPutId(const TaskID &task_id, int64_t put_index) {
|
||||
// Round trip test for computing the object ID for a task's put value, then
|
||||
// computing the task ID that created the object.
|
||||
ObjectID put_id = ComputePutId(task_id, put_index);
|
||||
ASSERT_EQ(ComputeTaskId(put_id), task_id);
|
||||
ASSERT_EQ(ComputeObjectIndex(put_id), -1 * put_index);
|
||||
}
|
||||
|
||||
TEST(TaskSpecTest, TestTaskReturnIds) {
|
||||
TaskID task_id = FinishTaskId(TaskID::from_random());
|
||||
|
||||
// Check that we can compute between a task ID and the object IDs of its
|
||||
// return values and puts.
|
||||
TestTaskReturnId(task_id, 1);
|
||||
TestTaskReturnId(task_id, 2);
|
||||
TestTaskReturnId(task_id, kMaxTaskReturns);
|
||||
TestTaskPutId(task_id, 1);
|
||||
TestTaskPutId(task_id, 2);
|
||||
TestTaskPutId(task_id, kMaxTaskPuts);
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
@@ -8,6 +8,8 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
/// A constructor responsible for initializing the state of a worker.
|
||||
Worker::Worker(pid_t pid, std::shared_ptr<LocalClientConnection> connection)
|
||||
: pid_(pid), connection_(connection), assigned_task_id_(TaskID::nil()) {}
|
||||
@@ -22,4 +24,6 @@ const std::shared_ptr<LocalClientConnection> Worker::Connection() const {
|
||||
return connection_;
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // end namespace ray
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
/// Worker class encapsulates the implementation details of a worker. A worker
|
||||
/// is the execution container around a unit of Ray work, such as a task or an
|
||||
/// actor. Ray units of work execute in the context of a Worker.
|
||||
@@ -32,6 +34,8 @@ class Worker {
|
||||
TaskID assigned_task_id_;
|
||||
};
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_RAYLET_WORKER_H
|
||||
|
||||
@@ -5,8 +5,15 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
/// A constructor that initializes a worker pool with num_workers workers.
|
||||
WorkerPool::WorkerPool(int num_workers) {
|
||||
WorkerPool::WorkerPool(int num_workers, const std::vector<const char *> &worker_command)
|
||||
: worker_command_(worker_command) {
|
||||
worker_command_.push_back(NULL);
|
||||
// Ignore SIGCHLD signals. If we don't do this, then worker processes will
|
||||
// become zombies instead of dying gracefully.
|
||||
signal(SIGCHLD, SIG_IGN);
|
||||
for (int i = 0; i < num_workers; i++) {
|
||||
StartWorker();
|
||||
}
|
||||
@@ -18,10 +25,23 @@ WorkerPool::~WorkerPool() {
|
||||
registered_workers_.clear();
|
||||
}
|
||||
|
||||
/// Create a new worker and add it to the pool
|
||||
bool WorkerPool::StartWorker() {
|
||||
// TODO(swang): Start the worker.
|
||||
return true;
|
||||
void WorkerPool::StartWorker() {
|
||||
RAY_CHECK(!worker_command_.empty()) << "No worker command provided";
|
||||
|
||||
// Launch the process to create the worker.
|
||||
pid_t pid = fork();
|
||||
if (pid != 0) {
|
||||
RAY_LOG(DEBUG) << "Started worker with pid " << pid;
|
||||
return;
|
||||
}
|
||||
|
||||
// Reset the SIGCHLD handler for the worker.
|
||||
signal(SIGCHLD, SIG_DFL);
|
||||
// Try to execute the worker command.
|
||||
|
||||
int rv = execvp(worker_command_[0], (char *const *)worker_command_.data());
|
||||
// The worker failed to start. This is a fatal error.
|
||||
RAY_LOG(FATAL) << "Failed to start worker with return value " << rv;
|
||||
}
|
||||
|
||||
uint32_t WorkerPool::PoolSize() const { return pool_.size(); }
|
||||
@@ -75,4 +95,6 @@ bool WorkerPool::DisconnectWorker(std::shared_ptr<Worker> worker) {
|
||||
return removeWorker(pool_, worker);
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
|
||||
class Worker;
|
||||
|
||||
/// \class WorkerPool
|
||||
@@ -23,7 +25,7 @@ class WorkerPool {
|
||||
/// pool.
|
||||
///
|
||||
/// \param num_workers The number of workers to start.
|
||||
WorkerPool(int num_workers);
|
||||
WorkerPool(int num_workers, const std::vector<const char *> &worker_command);
|
||||
|
||||
/// Destructor responsible for freeing a set of workers owned by this class.
|
||||
~WorkerPool();
|
||||
@@ -35,10 +37,9 @@ class WorkerPool {
|
||||
|
||||
/// Asynchronously start a new worker process. Once the worker process has
|
||||
/// registered with an external server, the process should create and
|
||||
/// register a new Worker, then add itself to the pool.
|
||||
///
|
||||
/// \return Whether the worker process was successfully started.
|
||||
bool StartWorker();
|
||||
/// register a new Worker, then add itself to the pool. Failure to start
|
||||
/// the worker process is a fatal error.
|
||||
void StartWorker();
|
||||
|
||||
/// Register a new worker. The Worker should be added by the caller to the
|
||||
/// pool after it becomes idle (e.g., requests a work assignment).
|
||||
@@ -73,6 +74,7 @@ class WorkerPool {
|
||||
std::shared_ptr<Worker> PopWorker();
|
||||
|
||||
private:
|
||||
std::vector<const char *> worker_command_;
|
||||
/// The pool of idle workers.
|
||||
std::list<std::shared_ptr<Worker>> pool_;
|
||||
/// All workers that have registered and are still connected, including both
|
||||
@@ -80,6 +82,9 @@ class WorkerPool {
|
||||
// TODO(swang): Make this a map to make GetRegisteredWorker faster.
|
||||
std::list<std::shared_ptr<Worker>> registered_workers_;
|
||||
};
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_RAYLET_WORKER_POOL_H
|
||||
|
||||
@@ -6,27 +6,35 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
class MockClientManager : public ClientManager<boost::asio::local::stream_protocol> {
|
||||
public:
|
||||
MOCK_METHOD3(ProcessClientMessage,
|
||||
void(std::shared_ptr<LocalClientConnection>, int64_t, const uint8_t *));
|
||||
MOCK_METHOD1(ProcessNewClient, void(std::shared_ptr<LocalClientConnection>));
|
||||
};
|
||||
namespace raylet {
|
||||
|
||||
class WorkerPoolTest : public ::testing::Test {
|
||||
public:
|
||||
WorkerPoolTest() : worker_pool_(0), client_manager_(), io_service_() {}
|
||||
WorkerPoolTest() : worker_pool_(0, {}), io_service_() {}
|
||||
|
||||
std::shared_ptr<Worker> CreateWorker(pid_t pid) {
|
||||
std::function<void(std::shared_ptr<LocalClientConnection>)> client_handler =
|
||||
[this](std::shared_ptr<LocalClientConnection> client) {
|
||||
HandleNewClient(client);
|
||||
};
|
||||
std::function<void(std::shared_ptr<LocalClientConnection>, int64_t, const uint8_t *)>
|
||||
message_handler = [this](std::shared_ptr<LocalClientConnection> client,
|
||||
int64_t message_type, const uint8_t *message) {
|
||||
HandleMessage(client, message_type, message);
|
||||
};
|
||||
boost::asio::local::stream_protocol::socket socket(io_service_);
|
||||
auto client = LocalClientConnection::Create(client_manager_, std::move(socket));
|
||||
auto client =
|
||||
LocalClientConnection::Create(client_handler, message_handler, std::move(socket));
|
||||
return std::shared_ptr<Worker>(new Worker(pid, client));
|
||||
}
|
||||
|
||||
protected:
|
||||
WorkerPool worker_pool_;
|
||||
MockClientManager client_manager_;
|
||||
boost::asio::io_service io_service_;
|
||||
|
||||
private:
|
||||
void HandleNewClient(std::shared_ptr<LocalClientConnection>){};
|
||||
void HandleMessage(std::shared_ptr<LocalClientConnection>, int64_t, const uint8_t *){};
|
||||
};
|
||||
|
||||
TEST_F(WorkerPoolTest, HandleWorkerRegistration) {
|
||||
@@ -66,6 +74,8 @@ TEST_F(WorkerPoolTest, HandleWorkerPushPop) {
|
||||
ASSERT_TRUE(workers.count(popped_worker) > 0);
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
# Cause the script to exit if a single command fails.
|
||||
set -e
|
||||
set -x
|
||||
|
||||
# Start Redis.
|
||||
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 &
|
||||
@@ -12,3 +13,4 @@ sleep 1s
|
||||
./src/ray/gcs/client_test
|
||||
|
||||
./src/common/thirdparty/redis/src/redis-cli -p 6379 shutdown
|
||||
sleep 1s
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# This needs to be run in the build tree, which is normally ray/python/ray/core
|
||||
|
||||
# Cause the script to exit if a single command fails.
|
||||
set -e
|
||||
set -x
|
||||
|
||||
# Get the directory in which this script is executing.
|
||||
SCRIPT_DIR="`dirname \"$0\"`"
|
||||
RAY_ROOT="$SCRIPT_DIR/../../.."
|
||||
RAY_ROOT="`( cd \"$RAY_ROOT\" && pwd )`"
|
||||
if [ -z "$RAY_ROOT" ] ; then
|
||||
exit 1
|
||||
fi
|
||||
# Ensure we're in the right directory.
|
||||
if [ ! -d "$RAY_ROOT/python" ]; then
|
||||
echo "Unable to find root Ray directory. Has this script moved?"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
CORE_DIR="$RAY_ROOT/python/ray/core"
|
||||
REDIS_DIR="$CORE_DIR/src/common/thirdparty/redis/src"
|
||||
REDIS_MODULE="$CORE_DIR/src/common/redis_module/libray_redis_module.so"
|
||||
STORE_EXEC="$CORE_DIR/src/plasma/plasma_store"
|
||||
|
||||
echo "$STORE_EXEC"
|
||||
echo "$REDIS_DIR/redis-server --loglevel warning --loadmodule $REDIS_MODULE --port 6379"
|
||||
echo "$REDIS_DIR/redis-cli -p 6379 shutdown"
|
||||
|
||||
# Allow cleanup commands to fail.
|
||||
killall plasma_store || true
|
||||
$REDIS_DIR/redis-cli -p 6379 shutdown || true
|
||||
sleep 1s
|
||||
$REDIS_DIR/redis-server --loglevel warning --loadmodule $REDIS_MODULE --port 6379 &
|
||||
sleep 1s
|
||||
|
||||
# Run tests.
|
||||
$CORE_DIR/src/ray/object_manager/object_manager_stress_test $STORE_EXEC
|
||||
sleep 1s
|
||||
$CORE_DIR/src/ray/object_manager/object_manager_test $STORE_EXEC
|
||||
$REDIS_DIR/redis-cli -p 6379 shutdown
|
||||
sleep 1s
|
||||
|
||||
# Include raylet integration test once it's ready.
|
||||
# $CORE_DIR/src/ray/raylet/object_manager_integration_test $STORE_EXEC
|
||||
@@ -0,0 +1,23 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# This needs to be run in the build tree, which is normally ray/python/ray/core
|
||||
|
||||
# Cause the script to exit if a single command fails.
|
||||
set -e
|
||||
set -x
|
||||
|
||||
# Tear down the Raylet.
|
||||
#bash ../../../src/ray/test/stop_raylets.sh
|
||||
|
||||
# Set up a single Raylet.
|
||||
bash ../../../src/ray/test/start_raylets.sh
|
||||
|
||||
sleep 1
|
||||
|
||||
# Connect a driver to the raylet and make sure it completes.
|
||||
python ../../../src/ray/python/test_driver.py /tmp/raylet1 /tmp/store1
|
||||
|
||||
sleep 1
|
||||
|
||||
./src/common/thirdparty/redis/src/redis-cli -p 6379 shutdown
|
||||
bash ../../../src/ray/test/stop_raylets.sh
|
||||
@@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# This needs to be run in the build tree, which is normally ray/python/ray/core
|
||||
|
||||
# Cause the script to exit if a single command fails.
|
||||
set -e
|
||||
|
||||
if [[ $1 ]]; then
|
||||
RAYLET_NUM=$1
|
||||
else
|
||||
RAYLET_NUM=1
|
||||
fi
|
||||
|
||||
STORE_SOCKET_NAME="/tmp/store$RAYLET_NUM"
|
||||
RAYLET_SOCKET_NAME="/tmp/raylet$RAYLET_NUM"
|
||||
|
||||
if [[ `stat $RAYLET_SOCKET_NAME` ]]; then
|
||||
rm $RAYLET_SOCKET_NAME
|
||||
fi
|
||||
if [[ `stat $STORE_SOCKET_NAME` ]]; then
|
||||
rm $STORE_SOCKET_NAME
|
||||
fi
|
||||
|
||||
./src/plasma/plasma_store -m 1000000000 -s $STORE_SOCKET_NAME &
|
||||
./src/ray/raylet/raylet $RAYLET_SOCKET_NAME $STORE_SOCKET_NAME 127.0.0.1 6379 &
|
||||
|
||||
echo
|
||||
echo "WORKER COMMAND: python ../../../src/ray/python/worker.py $RAYLET_SOCKET_NAME $STORE_SOCKET_NAME"
|
||||
echo
|
||||
@@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# This needs to be run in the build tree, which is normally ray/python/ray/core
|
||||
|
||||
# Cause the script to exit if a single command fails.
|
||||
set -e
|
||||
|
||||
# Start the GCS.
|
||||
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 >/dev/null &
|
||||
sleep 1s
|
||||
|
||||
if [[ $1 ]]; then
|
||||
NUM_RAYLETS=$1
|
||||
else
|
||||
NUM_RAYLETS=1
|
||||
fi
|
||||
|
||||
|
||||
for i in `seq 1 $NUM_RAYLETS`; do
|
||||
STORE_SOCKET_NAME="/tmp/store$i"
|
||||
RAYLET_SOCKET_NAME="/tmp/raylet$i"
|
||||
|
||||
if [[ `stat $RAYLET_SOCKET_NAME` ]]; then
|
||||
rm $RAYLET_SOCKET_NAME
|
||||
fi
|
||||
if [[ `stat $STORE_SOCKET_NAME` ]]; then
|
||||
rm $STORE_SOCKET_NAME
|
||||
fi
|
||||
|
||||
./src/plasma/plasma_store -m 1000000000 -s $STORE_SOCKET_NAME &
|
||||
./src/ray/raylet/raylet $RAYLET_SOCKET_NAME $STORE_SOCKET_NAME 127.0.0.1 6379 &
|
||||
|
||||
echo
|
||||
echo "WORKER COMMAND: python ../../../src/ray/python/worker.py $RAYLET_SOCKET_NAME $STORE_SOCKET_NAME"
|
||||
echo
|
||||
done
|
||||
@@ -0,0 +1,11 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# This needs to be run in the build tree, which is normally ray/python/ray/core
|
||||
|
||||
# Cause the script to exit if a single command fails.
|
||||
set -e
|
||||
|
||||
# Start the GCS.
|
||||
./src/common/thirdparty/redis/src/redis-server --loglevel warning --loadmodule ./src/common/redis_module/libray_redis_module.so --port 6379 >/dev/null &
|
||||
sleep 1s
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
killall raylet
|
||||
sleep 1
|
||||
killall plasma_store
|
||||
sleep 1
|
||||
killall redis-server
|
||||
sleep 1
|
||||
rm /tmp/store* /tmp/raylet*
|
||||
Reference in New Issue
Block a user