From 5b45f0bdffe0caed9e489468ae736bcaed0bc317 Mon Sep 17 00:00:00 2001 From: Yucong He Date: Fri, 31 Aug 2018 15:54:30 -0700 Subject: [PATCH] [xray] Implementing Gcs sharding (#2409) Basically a re-implementation of #2281, with modifications of #2298 (A fix of #2334, for rebasing issues.). [+] Implement sharding for gcs tables. [+] Keep ClientTable and ErrorTable managed by the primary_shard. TaskTable is managed by the primary_shard for now, until a good hashing for tasks is implemented. [+] Move AsyncGcsClient's initialization into Connect function. [-] Move GetRedisShard and bool sharding from RedisContext's connect into AsyncGcsClient. This may make the interface cleaner. --- src/global_scheduler/global_scheduler.cc | 5 - src/global_scheduler/global_scheduler.h | 2 - src/local_scheduler/local_scheduler.cc | 5 - src/local_scheduler/local_scheduler_shared.h | 2 - src/plasma/plasma_manager.cc | 9 - src/ray/gcs/client.cc | 170 ++++++++++++++---- src/ray/gcs/client.h | 33 ++-- src/ray/gcs/client_test.cc | 10 +- src/ray/gcs/redis_context.cc | 89 +-------- src/ray/gcs/redis_context.h | 1 + src/ray/gcs/tables.cc | 39 ++-- src/ray/gcs/tables.h | 107 ++++++----- .../test/object_manager_stress_test.cc | 7 +- .../test/object_manager_test.cc | 7 +- src/ray/raylet/main.cc | 2 +- src/ray/raylet/monitor.cc | 3 +- .../raylet/object_manager_integration_test.cc | 6 +- src/ray/raylet/raylet.cc | 1 - test/runtest.py | 6 +- 19 files changed, 269 insertions(+), 235 deletions(-) diff --git a/src/global_scheduler/global_scheduler.cc b/src/global_scheduler/global_scheduler.cc index cd2477420..94dccf636 100644 --- a/src/global_scheduler/global_scheduler.cc +++ b/src/global_scheduler/global_scheduler.cc @@ -132,11 +132,6 @@ GlobalSchedulerState *GlobalSchedulerState_init(event_loop *loop, "global_scheduler", node_ip_address, std::vector()); db_attach(state->db, loop, false); - - RAY_CHECK_OK(state->gcs_client.Connect( - std::string(redis_primary_addr), redis_primary_port, 
/*sharding=*/true)); - RAY_CHECK_OK(state->gcs_client.context()->AttachToEventLoop(loop)); - RAY_CHECK_OK(state->gcs_client.primary_context()->AttachToEventLoop(loop)); state->policy_state = GlobalSchedulerPolicyState_init(); return state; } diff --git a/src/global_scheduler/global_scheduler.h b/src/global_scheduler/global_scheduler.h index 194559393..e1610c555 100644 --- a/src/global_scheduler/global_scheduler.h +++ b/src/global_scheduler/global_scheduler.h @@ -51,8 +51,6 @@ typedef struct { event_loop *loop; /** The global state store database. */ DBHandle *db; - /** The handle to the GCS (modern version of the above). */ - ray::gcs::AsyncGcsClient gcs_client; /** A hash table mapping local scheduler ID to the local schedulers that are * connected to Redis. */ std::unordered_map local_schedulers; diff --git a/src/local_scheduler/local_scheduler.cc b/src/local_scheduler/local_scheduler.cc index 27f14dad8..2f8a684f6 100644 --- a/src/local_scheduler/local_scheduler.cc +++ b/src/local_scheduler/local_scheduler.cc @@ -351,11 +351,6 @@ LocalSchedulerState *LocalSchedulerState_init( state->db = db_connect(std::string(redis_primary_addr), redis_primary_port, "local_scheduler", node_ip_address, db_connect_args); db_attach(state->db, loop, false); - - RAY_CHECK_OK(state->gcs_client.Connect(std::string(redis_primary_addr), - redis_primary_port, true)); - RAY_CHECK_OK(state->gcs_client.context()->AttachToEventLoop(loop)); - RAY_CHECK_OK(state->gcs_client.primary_context()->AttachToEventLoop(loop)); } else { state->db = NULL; } diff --git a/src/local_scheduler/local_scheduler_shared.h b/src/local_scheduler/local_scheduler_shared.h index 013cf7a78..572f14a6f 100644 --- a/src/local_scheduler/local_scheduler_shared.h +++ b/src/local_scheduler/local_scheduler_shared.h @@ -60,8 +60,6 @@ struct LocalSchedulerState { std::unordered_map actor_mapping; /** The handle to the database. */ DBHandle *db; - /** The handle to the GCS (modern version of the above). 
*/ - ray::gcs::AsyncGcsClient gcs_client; /** The Plasma client. */ plasma::PlasmaClient *plasma_conn; /** State for the scheduling algorithm. */ diff --git a/src/plasma/plasma_manager.cc b/src/plasma/plasma_manager.cc index 95c8c08aa..91f74e528 100644 --- a/src/plasma/plasma_manager.cc +++ b/src/plasma/plasma_manager.cc @@ -215,8 +215,6 @@ struct PlasmaManagerState { * other plasma stores. */ std::unordered_map manager_connections; DBHandle *db; - /** The handle to the GCS (modern version of the above). */ - ray::gcs::AsyncGcsClient gcs_client; /** Our address. */ const char *addr; /** Our port. */ @@ -490,13 +488,6 @@ PlasmaManagerState *PlasmaManagerState_init(const char *store_socket_name, state->db = db_connect(std::string(redis_primary_addr), redis_primary_port, "plasma_manager", manager_addr, db_connect_args); db_attach(state->db, state->loop, false); - - RAY_CHECK_OK(state->gcs_client.Connect(std::string(redis_primary_addr), - redis_primary_port, - /*sharding=*/true)); - RAY_CHECK_OK(state->gcs_client.context()->AttachToEventLoop(state->loop)); - RAY_CHECK_OK( - state->gcs_client.primary_context()->AttachToEventLoop(state->loop)); } else { state->db = NULL; RAY_LOG(DEBUG) << "No db connection specified"; diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index 88eda1a5d..182c44a8a 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -2,51 +2,152 @@ #include "ray/gcs/redis_context.h" +static void GetRedisShards(redisContext *context, std::vector &addresses, + std::vector &ports) { + // Get the total number of Redis shards in the system. + int num_attempts = 0; + redisReply *reply = nullptr; + while (num_attempts < RayConfig::instance().redis_db_connect_retries()) { + // Try to read the number of Redis shards from the primary shard. If the + // entry is present, exit. 
+ reply = reinterpret_cast(redisCommand(context, "GET NumRedisShards")); + if (reply->type != REDIS_REPLY_NIL) { + break; + } + + // Sleep for a little, and try again if the entry isn't there yet. + freeReplyObject(reply); + usleep(RayConfig::instance().redis_db_connect_wait_milliseconds() * 1000); + num_attempts++; + } + RAY_CHECK(num_attempts < RayConfig::instance().redis_db_connect_retries()) + << "No entry found for NumRedisShards"; + RAY_CHECK(reply->type == REDIS_REPLY_STRING) << "Expected string, found Redis type " + << reply->type << " for NumRedisShards"; + int num_redis_shards = atoi(reply->str); + RAY_CHECK(num_redis_shards >= 1) << "Expected at least one Redis shard, " + << "found " << num_redis_shards; + freeReplyObject(reply); + + // Get the addresses of all of the Redis shards. + num_attempts = 0; + while (num_attempts < RayConfig::instance().redis_db_connect_retries()) { + // Try to read the Redis shard locations from the primary shard. If we find + // that all of them are present, exit. + reply = + reinterpret_cast(redisCommand(context, "LRANGE RedisShards 0 -1")); + if (static_cast(reply->elements) == num_redis_shards) { + break; + } + + // Sleep for a little, and try again if not all Redis shard addresses have + // been added yet. + freeReplyObject(reply); + usleep(RayConfig::instance().redis_db_connect_wait_milliseconds() * 1000); + num_attempts++; + } + RAY_CHECK(num_attempts < RayConfig::instance().redis_db_connect_retries()) + << "Expected " << num_redis_shards << " Redis shard addresses, found " + << reply->elements; + + // Parse the Redis shard addresses. + for (size_t i = 0; i < reply->elements; ++i) { + // Parse the shard addresses and ports.
+ RAY_CHECK(reply->element[i]->type == REDIS_REPLY_STRING); + std::string addr; + std::stringstream ss(reply->element[i]->str); + getline(ss, addr, ':'); + addresses.push_back(addr); + int port; + ss >> port; + ports.push_back(port); + } + freeReplyObject(reply); +} + namespace ray { namespace gcs { -AsyncGcsClient::AsyncGcsClient(const ClientID &client_id, CommandType command_type) { - context_ = std::make_shared(); +AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, + const ClientID &client_id, CommandType command_type, + bool is_test_client = false) { primary_context_ = std::make_shared(); - client_table_.reset(new ClientTable(primary_context_, this, client_id)); - object_table_.reset(new ObjectTable(context_, this, command_type)); - actor_table_.reset(new ActorTable(context_, this)); - task_table_.reset(new TaskTable(context_, this, command_type)); - raylet_task_table_.reset(new raylet::TaskTable(context_, this, command_type)); - task_reconstruction_log_.reset(new TaskReconstructionLog(context_, this)); - task_lease_table_.reset(new TaskLeaseTable(context_, this)); - heartbeat_table_.reset(new HeartbeatTable(context_, this)); - driver_table_.reset(new DriverTable(primary_context_, this)); - error_table_.reset(new ErrorTable(primary_context_, this)); - profile_table_.reset(new ProfileTable(context_, this)); + + RAY_CHECK_OK(primary_context_->Connect(address, port, /*sharding=*/true)); + + if (!is_test_client) { + // Moving sharding into the constructor means that sharding = true by default. + // This design decision may be worth a second look. + std::vector addresses; + std::vector ports; + GetRedisShards(primary_context_->sync_context(), addresses, ports); + if (addresses.size() == 0 || ports.size() == 0) { + addresses.push_back(address); + ports.push_back(port); + } + + // Populate shard_contexts.
+ for (size_t i = 0; i < addresses.size(); ++i) { + shard_contexts_.push_back(std::make_shared()); + } + + RAY_CHECK(shard_contexts_.size() == addresses.size()); + for (size_t i = 0; i < addresses.size(); ++i) { + RAY_CHECK_OK( + shard_contexts_[i]->Connect(addresses[i], ports[i], /*sharding=*/true)); + } + } else { + shard_contexts_.push_back(std::make_shared()); + RAY_CHECK_OK(shard_contexts_[0]->Connect(address, port, /*sharding=*/true)); + } + + client_table_.reset(new ClientTable({primary_context_}, this, client_id)); + error_table_.reset(new ErrorTable({primary_context_}, this)); + driver_table_.reset(new DriverTable({primary_context_}, this)); + // Tables below would be sharded. + object_table_.reset(new ObjectTable(shard_contexts_, this, command_type)); + actor_table_.reset(new ActorTable(shard_contexts_, this)); + task_table_.reset(new TaskTable(shard_contexts_, this, command_type)); + raylet_task_table_.reset(new raylet::TaskTable(shard_contexts_, this, command_type)); + task_reconstruction_log_.reset(new TaskReconstructionLog(shard_contexts_, this)); + task_lease_table_.reset(new TaskLeaseTable(shard_contexts_, this)); + heartbeat_table_.reset(new HeartbeatTable(shard_contexts_, this)); + profile_table_.reset(new ProfileTable(shard_contexts_, this)); command_type_ = command_type; + + // TODO(swang): Call the client table's Connect() method here. To do this, + // we need to make sure that we are attached to an event loop first. This + // currently isn't possible because the aeEventLoop, which we use for + // testing, requires us to connect to Redis first. } #if RAY_USE_NEW_GCS // Use of kChain currently only applies to Table::Add which affects only the // task table, and when RAY_USE_NEW_GCS is set at compile time. 
-AsyncGcsClient::AsyncGcsClient(const ClientID &client_id) - : AsyncGcsClient(client_id, CommandType::kChain) {} +AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, + const ClientID &client_id, bool is_test_client = false) + : AsyncGcsClient(address, port, client_id, CommandType::kChain, is_test_client) {} #else -AsyncGcsClient::AsyncGcsClient(const ClientID &client_id) - : AsyncGcsClient(client_id, CommandType::kRegular) {} +AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, + const ClientID &client_id, bool is_test_client = false) + : AsyncGcsClient(address, port, client_id, CommandType::kRegular, is_test_client) {} #endif // RAY_USE_NEW_GCS -AsyncGcsClient::AsyncGcsClient(CommandType command_type) - : AsyncGcsClient(ClientID::from_random(), command_type) {} +AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, + CommandType command_type) + : AsyncGcsClient(address, port, ClientID::from_random(), command_type) {} -AsyncGcsClient::AsyncGcsClient() : AsyncGcsClient(ClientID::from_random()) {} +AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, + CommandType command_type, bool is_test_client) + : AsyncGcsClient(address, port, ClientID::from_random(), command_type, + is_test_client) {} -Status AsyncGcsClient::Connect(const std::string &address, int port, bool sharding) { - RAY_RETURN_NOT_OK(context_->Connect(address, port, sharding)); - RAY_RETURN_NOT_OK(primary_context_->Connect(address, port, /*sharding=*/false)); - // TODO(swang): Call the client table's Connect() method here. To do this, - // we need to make sure that we are attached to an event loop first. This - // currently isn't possible because the aeEventLoop, which we use for - // testing, requires us to connect to Redis first. 
- return Status::OK(); -} +AsyncGcsClient::AsyncGcsClient(const std::string &address, int port) + : AsyncGcsClient(address, port, ClientID::from_random()) {} + +AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, bool is_test_client) + : AsyncGcsClient(address, port, ClientID::from_random(), is_test_client) {} Status Attach(plasma::EventLoop &event_loop) { // TODO(pcm): Implement this via @@ -55,9 +156,14 @@ Status Attach(plasma::EventLoop &event_loop) { } Status AsyncGcsClient::Attach(boost::asio::io_service &io_service) { - asio_async_client_.reset(new RedisAsioClient(io_service, context_->async_context())); - asio_subscribe_client_.reset( - new RedisAsioClient(io_service, context_->subscribe_context())); + // Take care of sharding contexts. + RAY_CHECK(shard_asio_async_clients_.empty()) << "Attach shall be called only once"; + for (std::shared_ptr context : shard_contexts_) { + shard_asio_async_clients_.emplace_back( + new RedisAsioClient(io_service, context->async_context())); + shard_asio_subscribe_clients_.emplace_back( + new RedisAsioClient(io_service, context->subscribe_context())); + } asio_async_auxiliary_client_.reset( new RedisAsioClient(io_service, primary_context_->async_context())); asio_subscribe_auxiliary_client_.reset( diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index 9da272e64..d89aadd80 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -24,21 +24,22 @@ class RAY_EXPORT AsyncGcsClient { /// Attach() must be called. To read and write from the GCS tables requires a /// further call to Connect() to the client table. /// - /// \param client_id The ID to assign to the client. - /// \param command_type GCS command type. If CommandType::kChain, chain-replicated - /// versions of the tables might be used, if available. 
- AsyncGcsClient(const ClientID &client_id, CommandType command_type); - AsyncGcsClient(const ClientID &client_id); - AsyncGcsClient(CommandType command_type); - AsyncGcsClient(); - - /// Connect to the GCS. - /// /// \param address The GCS IP address. /// \param port The GCS port. /// \param sharding If true, use sharded redis for the GCS. - /// \return Status. - Status Connect(const std::string &address, int port, bool sharding); + /// \param client_id The ID to assign to the client. + /// \param command_type GCS command type. If CommandType::kChain, chain-replicated + /// versions of the tables might be used, if available. + AsyncGcsClient(const std::string &address, int port, const ClientID &client_id, + CommandType command_type, bool is_test_client); + AsyncGcsClient(const std::string &address, int port, const ClientID &client_id, + bool is_test_client); + AsyncGcsClient(const std::string &address, int port, CommandType command_type); + AsyncGcsClient(const std::string &address, int port, CommandType command_type, + bool is_test_client); + AsyncGcsClient(const std::string &address, int port); + AsyncGcsClient(const std::string &address, int port, bool is_test_client); + /// Attach this client to a plasma event loop. Note that only /// one event loop should be attached at a time. 
Status Attach(plasma::EventLoop &event_loop); @@ -71,7 +72,7 @@ class RAY_EXPORT AsyncGcsClient { Status GetExport(const std::string &driver_id, int64_t export_index, const GetExportCallback &done_callback); - std::shared_ptr context() { return context_; } + std::vector> shard_contexts() { return shard_contexts_; } std::shared_ptr primary_context() { return primary_context_; } private: @@ -88,9 +89,9 @@ class RAY_EXPORT AsyncGcsClient { std::unique_ptr profile_table_; std::unique_ptr client_table_; // The following contexts write to the data shard - std::shared_ptr context_; - std::unique_ptr asio_async_client_; - std::unique_ptr asio_subscribe_client_; + std::vector> shard_contexts_; + std::vector> shard_asio_async_clients_; + std::vector> shard_asio_subscribe_clients_; // The following context writes everything to the primary shard std::shared_ptr primary_context_; std::unique_ptr driver_table_; diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index 715edc797..41d93f296 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -28,9 +28,8 @@ static inline void flushall_redis(void) { class TestGcs : public ::testing::Test { public: TestGcs(CommandType command_type) : num_callbacks_(0), command_type_(command_type) { - client_ = std::make_shared(command_type_); - RAY_CHECK_OK(client_->Connect("127.0.0.1", 6379, /*sharding=*/false)); - + client_ = std::make_shared("127.0.0.1", 6379, command_type_, + /*is_test_client=*/true); job_id_ = JobID::from_random(); } @@ -60,7 +59,10 @@ class TestGcsWithAe : public TestGcs { public: TestGcsWithAe(CommandType command_type) : TestGcs(command_type) { loop_ = aeCreateEventLoop(1024); - RAY_CHECK_OK(client_->context()->AttachToEventLoop(loop_)); + RAY_CHECK_OK(client_->primary_context()->AttachToEventLoop(loop_)); + for (auto &context : client_->shard_contexts()) { + RAY_CHECK_OK(context->AttachToEventLoop(loop_)); + } } TestGcsWithAe() : TestGcsWithAe(CommandType::kRegular) {} diff --git 
a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index 124d39096..abc06a24a 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -135,69 +135,6 @@ RedisContext::~RedisContext() { } } -static void GetRedisShards(redisContext *context, std::vector *addresses, - std::vector *ports) { - // Get the total number of Redis shards in the system. - int num_attempts = 0; - redisReply *reply = nullptr; - while (num_attempts < RayConfig::instance().redis_db_connect_retries()) { - // Try to read the number of Redis shards from the primary shard. If the - // entry is present, exit. - reply = reinterpret_cast(redisCommand(context, "GET NumRedisShards")); - if (reply->type != REDIS_REPLY_NIL) { - break; - } - - // Sleep for a little, and try again if the entry isn't there yet. */ - freeReplyObject(reply); - usleep(RayConfig::instance().redis_db_connect_wait_milliseconds() * 1000); - num_attempts++; - } - RAY_CHECK(num_attempts < RayConfig::instance().redis_db_connect_retries()) - << "No entry found for NumRedisShards"; - RAY_CHECK(reply->type == REDIS_REPLY_STRING) << "Expected string, found Redis type " - << reply->type << " for NumRedisShards"; - int num_redis_shards = atoi(reply->str); - RAY_CHECK(num_redis_shards >= 1) << "Expected at least one Redis shard, " - << "found " << num_redis_shards; - freeReplyObject(reply); - - // Get the addresses of all of the Redis shards. - num_attempts = 0; - while (num_attempts < RayConfig::instance().redis_db_connect_retries()) { - // Try to read the Redis shard locations from the primary shard. If we find - // that all of them are present, exit. - reply = - reinterpret_cast(redisCommand(context, "LRANGE RedisShards 0 -1")); - if (static_cast(reply->elements) == num_redis_shards) { - break; - } - - // Sleep for a little, and try again if not all Redis shard addresses have - // been added yet. 
- freeReplyObject(reply); - usleep(RayConfig::instance().redis_db_connect_wait_milliseconds() * 1000); - num_attempts++; - } - RAY_CHECK(num_attempts < RayConfig::instance().redis_db_connect_retries()) - << "Expected " << num_redis_shards << " Redis shard addresses, found " - << reply->elements; - - // Parse the Redis shard addresses. - for (size_t i = 0; i < reply->elements; ++i) { - // Parse the shard addresses and ports. - RAY_CHECK(reply->element[i]->type == REDIS_REPLY_STRING); - std::string addr; - std::stringstream ss(reply->element[i]->str); - getline(ss, addr, ':'); - addresses->push_back(addr); - int port; - ss >> port; - ports->push_back(port); - } - freeReplyObject(reply); -} - Status RedisContext::Connect(const std::string &address, int port, bool sharding) { int connection_attempts = 0; context_ = redisConnect(address.c_str(), port); @@ -223,31 +160,17 @@ Status RedisContext::Connect(const std::string &address, int port, bool sharding REDIS_CHECK_ERROR(context_, reply); freeReplyObject(reply); - std::string redis_address; - int redis_port; - if (sharding) { - // Get the redis data shard - std::vector addresses; - std::vector ports; - GetRedisShards(context_, &addresses, &ports); - redis_address = addresses[0]; - redis_port = ports[0]; - } else { - redis_address = address; - redis_port = port; - } - // Connect to async context - async_context_ = redisAsyncConnect(redis_address.c_str(), redis_port); + async_context_ = redisAsyncConnect(address.c_str(), port); if (async_context_ == nullptr || async_context_->err) { - RAY_LOG(FATAL) << "Could not establish connection to redis " << redis_address << ":" - << redis_port; + RAY_LOG(FATAL) << "Could not establish connection to redis " << address << ":" + << port; } // Connect to subscribe context - subscribe_context_ = redisAsyncConnect(redis_address.c_str(), redis_port); + subscribe_context_ = redisAsyncConnect(address.c_str(), port); if (subscribe_context_ == nullptr || subscribe_context_->err) { - 
RAY_LOG(FATAL) << "Could not establish subscribe connection to redis " - << redis_address << ":" << redis_port; + RAY_LOG(FATAL) << "Could not establish subscribe connection to redis " << address + << ":" << port; } return Status::OK(); } diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index c54c270de..67bc8197c 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -88,6 +88,7 @@ class RedisContext { /// \return Status. Status SubscribeAsync(const ClientID &client_id, const TablePubsub pubsub_channel, const RedisCallback &redisCallback, int64_t *out_callback_index); + redisContext *sync_context() { return context_; } redisAsyncContext *async_context() { return async_context_; } redisAsyncContext *subscribe_context() { return subscribe_context_; }; diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index ce83f1464..ba0624814 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -47,9 +47,9 @@ Status Log::Append(const JobID &job_id, const ID &id, flatbuffers::FlatBufferBuilder fbb; fbb.ForceDefaults(true); fbb.Finish(Data::Pack(fbb, dataT.get())); - return context_->RunAsync(GetLogAppendCommand(command_type_), id, - fbb.GetBufferPointer(), fbb.GetSize(), prefix_, - pubsub_channel_, std::move(callback)); + return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, + fbb.GetBufferPointer(), fbb.GetSize(), prefix_, + pubsub_channel_, std::move(callback)); } template @@ -71,9 +71,9 @@ Status Log::AppendAt(const JobID &job_id, const ID &id, flatbuffers::FlatBufferBuilder fbb; fbb.ForceDefaults(true); fbb.Finish(Data::Pack(fbb, dataT.get())); - return context_->RunAsync(GetLogAppendCommand(command_type_), id, - fbb.GetBufferPointer(), fbb.GetSize(), prefix_, - pubsub_channel_, std::move(callback), log_length); + return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, + fbb.GetBufferPointer(), fbb.GetSize(), prefix_, + pubsub_channel_, std::move(callback), log_length); 
} template @@ -96,8 +96,8 @@ Status Log::Lookup(const JobID &job_id, const ID &id, const Callback & return true; }; std::vector nil; - return context_->RunAsync("RAY.TABLE_LOOKUP", id, nil.data(), nil.size(), prefix_, - pubsub_channel_, std::move(callback)); + return GetRedisContext(id)->RunAsync("RAY.TABLE_LOOKUP", id, nil.data(), nil.size(), + prefix_, pubsub_channel_, std::move(callback)); } template @@ -136,8 +136,12 @@ Status Log::Subscribe(const JobID &job_id, const ClientID &client_id, // more subscription messages. return false; }; - return context_->SubscribeAsync(client_id, pubsub_channel_, std::move(callback), - &subscribe_callback_index_); + subscribe_callback_index_ = 1; + for (auto &context : shard_contexts_) { + RAY_RETURN_NOT_OK(context->SubscribeAsync(client_id, pubsub_channel_, callback, + &subscribe_callback_index_)); + } + return Status::OK(); } template @@ -145,8 +149,9 @@ Status Log::RequestNotifications(const JobID &job_id, const ID &id, const ClientID &client_id) { RAY_CHECK(subscribe_callback_index_ >= 0) << "Client requested notifications on a key before Subscribe completed"; - return context_->RunAsync("RAY.TABLE_REQUEST_NOTIFICATIONS", id, client_id.data(), - client_id.size(), prefix_, pubsub_channel_, nullptr); + return GetRedisContext(id)->RunAsync("RAY.TABLE_REQUEST_NOTIFICATIONS", id, + client_id.data(), client_id.size(), prefix_, + pubsub_channel_, nullptr); } template @@ -154,8 +159,9 @@ Status Log::CancelNotifications(const JobID &job_id, const ID &id, const ClientID &client_id) { RAY_CHECK(subscribe_callback_index_ >= 0) << "Client canceled notifications on a key before Subscribe completed"; - return context_->RunAsync("RAY.TABLE_CANCEL_NOTIFICATIONS", id, client_id.data(), - client_id.size(), prefix_, pubsub_channel_, nullptr); + return GetRedisContext(id)->RunAsync("RAY.TABLE_CANCEL_NOTIFICATIONS", id, + client_id.data(), client_id.size(), prefix_, + pubsub_channel_, nullptr); } template @@ -170,8 +176,9 @@ Status 
Table::Add(const JobID &job_id, const ID &id, flatbuffers::FlatBufferBuilder fbb; fbb.ForceDefaults(true); fbb.Finish(Data::Pack(fbb, dataT.get())); - return context_->RunAsync(GetTableAddCommand(command_type_), id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, std::move(callback)); + return GetRedisContext(id)->RunAsync(GetTableAddCommand(command_type_), id, + fbb.GetBufferPointer(), fbb.GetSize(), prefix_, + pubsub_channel_, std::move(callback)); } template diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index ba8b561ea..e2f022502 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -99,8 +99,8 @@ class Log : public LogInterface, virtual public PubsubInterface { AsyncGcsClient *client; }; - Log(const std::shared_ptr &context, AsyncGcsClient *client) - : context_(context), + Log(const std::vector> &contexts, AsyncGcsClient *client) + : shard_contexts_(contexts), client_(client), pubsub_channel_(TablePubsub::NO_PUBLISH), prefix_(TablePrefix::UNUSED), @@ -190,8 +190,12 @@ class Log : public LogInterface, virtual public PubsubInterface { const ClientID &client_id); protected: + std::shared_ptr GetRedisContext(const ID &id) { + static std::hash index; + return shard_contexts_[index(id) % shard_contexts_.size()]; + } /// The connection to the GCS. - std::shared_ptr context_; + std::vector> shard_contexts_; /// The GCS client. AsyncGcsClient *client_; /// The pubsub channel to subscribe to for notifications about keys in this @@ -245,8 +249,9 @@ class Table : private Log, /// request and receive notifications. 
using SubscriptionCallback = typename Log::SubscriptionCallback; - Table(const std::shared_ptr &context, AsyncGcsClient *client) - : Log(context, client) {} + Table(const std::vector> &contexts, + AsyncGcsClient *client) + : Log(contexts, client) {} using Log::RequestNotifications; using Log::CancelNotifications; @@ -296,24 +301,26 @@ class Table : private Log, const SubscriptionCallback &done); protected: - using Log::context_; + using Log::shard_contexts_; using Log::client_; using Log::pubsub_channel_; using Log::prefix_; using Log::command_type_; + using Log::GetRedisContext; }; class ObjectTable : public Log { public: - ObjectTable(const std::shared_ptr &context, AsyncGcsClient *client) - : Log(context, client) { + ObjectTable(const std::vector> &contexts, + AsyncGcsClient *client) + : Log(contexts, client) { pubsub_channel_ = TablePubsub::OBJECT; prefix_ = TablePrefix::OBJECT; }; - ObjectTable(const std::shared_ptr &context, AsyncGcsClient *client, - gcs::CommandType command_type) - : ObjectTable(context, client) { + ObjectTable(const std::vector> &contexts, + AsyncGcsClient *client, gcs::CommandType command_type) + : ObjectTable(contexts, client) { command_type_ = command_type; }; @@ -322,8 +329,9 @@ class ObjectTable : public Log { class HeartbeatTable : public Table { public: - HeartbeatTable(const std::shared_ptr &context, AsyncGcsClient *client) - : Table(context, client) { + HeartbeatTable(const std::vector> &contexts, + AsyncGcsClient *client) + : Table(contexts, client) { pubsub_channel_ = TablePubsub::HEARTBEAT; prefix_ = TablePrefix::HEARTBEAT; } @@ -332,8 +340,9 @@ class HeartbeatTable : public Table { class DriverTable : public Log { public: - DriverTable(const std::shared_ptr &context, AsyncGcsClient *client) - : Log(context, client) { + DriverTable(const std::vector> &contexts, + AsyncGcsClient *client) + : Log(contexts, client) { pubsub_channel_ = TablePubsub::DRIVER; prefix_ = TablePrefix::DRIVER; }; @@ -349,8 +358,9 @@ class DriverTable : 
public Log { class FunctionTable : public Table { public: - FunctionTable(const std::shared_ptr &context, AsyncGcsClient *client) - : Table(context, client) { + FunctionTable(const std::vector> &contexts, + AsyncGcsClient *client) + : Table(contexts, client) { pubsub_channel_ = TablePubsub::NO_PUBLISH; prefix_ = TablePrefix::FUNCTION; }; @@ -361,8 +371,9 @@ using ClassTable = Table; // TODO(swang): Set the pubsub channel for the actor table. class ActorTable : public Log { public: - ActorTable(const std::shared_ptr &context, AsyncGcsClient *client) - : Log(context, client) { + ActorTable(const std::vector> &contexts, + AsyncGcsClient *client) + : Log(contexts, client) { pubsub_channel_ = TablePubsub::ACTOR; prefix_ = TablePrefix::ACTOR; } @@ -370,17 +381,18 @@ class ActorTable : public Log { class TaskReconstructionLog : public Log { public: - TaskReconstructionLog(const std::shared_ptr &context, + TaskReconstructionLog(const std::vector> &contexts, AsyncGcsClient *client) - : Log(context, client) { + : Log(contexts, client) { prefix_ = TablePrefix::TASK_RECONSTRUCTION; } }; class TaskLeaseTable : public Table { public: - TaskLeaseTable(const std::shared_ptr &context, AsyncGcsClient *client) - : Table(context, client) { + TaskLeaseTable(const std::vector> &contexts, + AsyncGcsClient *client) + : Table(contexts, client) { pubsub_channel_ = TablePubsub::TASK_LEASE; prefix_ = TablePrefix::TASK_LEASE; } @@ -397,7 +409,8 @@ class TaskLeaseTable : public Table { std::vector args = {"PEXPIRE", EnumNameTablePrefix(prefix_) + id.binary(), std::to_string(data->timeout)}; - return context_->RunArgvAsync(args); + + return GetRedisContext(id)->RunArgvAsync(args); } }; @@ -405,15 +418,16 @@ namespace raylet { class TaskTable : public Table { public: - TaskTable(const std::shared_ptr &context, AsyncGcsClient *client) - : Table(context, client) { + TaskTable(const std::vector> &contexts, + AsyncGcsClient *client) + : Table(contexts, client) { pubsub_channel_ = 
TablePubsub::RAYLET_TASK; prefix_ = TablePrefix::RAYLET_TASK; } - TaskTable(const std::shared_ptr &context, AsyncGcsClient *client, - gcs::CommandType command_type) - : TaskTable(context, client) { + TaskTable(const std::vector> &contexts, + AsyncGcsClient *client, gcs::CommandType command_type) + : TaskTable(contexts, client) { command_type_ = command_type; }; }; @@ -422,15 +436,16 @@ class TaskTable : public Table { class TaskTable : public Table { public: - TaskTable(const std::shared_ptr &context, AsyncGcsClient *client) - : Table(context, client) { + TaskTable(const std::vector> &contexts, + AsyncGcsClient *client) + : Table(contexts, client) { pubsub_channel_ = TablePubsub::TASK; prefix_ = TablePrefix::TASK; }; - TaskTable(const std::shared_ptr &context, AsyncGcsClient *client, - gcs::CommandType command_type) - : TaskTable(context, client) { + TaskTable(const std::vector> &contexts, + AsyncGcsClient *client, gcs::CommandType command_type) + : TaskTable(contexts, client) { command_type_ = command_type; } @@ -466,9 +481,11 @@ class TaskTable : public Table { }; flatbuffers::FlatBufferBuilder fbb; fbb.Finish(TaskTableTestAndUpdate::Pack(fbb, data.get())); - RAY_RETURN_NOT_OK(context_->RunAsync("RAY.TABLE_TEST_AND_UPDATE", id, - fbb.GetBufferPointer(), fbb.GetSize(), prefix_, - pubsub_channel_, redisCallback)); + for (auto context : shard_contexts_) { + RAY_RETURN_NOT_OK(context->RunAsync("RAY.TABLE_TEST_AND_UPDATE", id, + fbb.GetBufferPointer(), fbb.GetSize(), prefix_, + pubsub_channel_, redisCallback)); + } return Status::OK(); } @@ -504,8 +521,9 @@ Status TaskTableTestAndUpdate(AsyncGcsClient *gcs_client, const TaskID &task_id, class ErrorTable : private Log { public: - ErrorTable(const std::shared_ptr &context, AsyncGcsClient *client) - : Log(context, client) { + ErrorTable(const std::vector> &contexts, + AsyncGcsClient *client) + : Log(contexts, client) { pubsub_channel_ = TablePubsub::ERROR_INFO; prefix_ = TablePrefix::ERROR_INFO; }; @@ -528,8 +546,9 @@ 
class ErrorTable : private Log { class ProfileTable : private Log { public: - ProfileTable(const std::shared_ptr &context, AsyncGcsClient *client) - : Log(context, client) { + ProfileTable(const std::vector> &contexts, + AsyncGcsClient *client) + : Log(contexts, client) { prefix_ = TablePrefix::PROFILE; }; @@ -574,9 +593,9 @@ class ClientTable : private Log { public: using ClientTableCallback = std::function; - ClientTable(const std::shared_ptr &context, AsyncGcsClient *client, - const ClientID &client_id) - : Log(context, client), + ClientTable(const std::vector> &contexts, + AsyncGcsClient *client, const ClientID &client_id) + : Log(contexts, client), // We set the client log's key equal to nil so that all instances of // ClientTable have the same key. client_log_key_(UniqueID::nil()), diff --git a/src/ray/object_manager/test/object_manager_stress_test.cc b/src/ray/object_manager/test/object_manager_stress_test.cc index 88cde3986..f421fb9cf 100644 --- a/src/ray/object_manager/test/object_manager_stress_test.cc +++ b/src/ray/object_manager/test/object_manager_stress_test.cc @@ -43,7 +43,6 @@ class MockServer { private: ray::Status RegisterGcs(boost::asio::io_service &io_service) { - RAY_RETURN_NOT_OK(gcs_client_->Connect("127.0.0.1", 6379, /*sharding=*/false)); RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service)); boost::asio::ip::tcp::endpoint endpoint = object_manager_acceptor_.local_endpoint(); @@ -130,7 +129,8 @@ class TestObjectManagerBase : public ::testing::Test { int push_timeout_ms = 10000; // start first server - gcs_client_1 = std::shared_ptr(new gcs::AsyncGcsClient()); + gcs_client_1 = std::shared_ptr( + new gcs::AsyncGcsClient("127.0.0.1", 6379, /*is_test_client=*/true)); ObjectManagerConfig om_config_1; om_config_1.store_socket_name = store_id_1; om_config_1.pull_timeout_ms = pull_timeout_ms; @@ -141,7 +141,8 @@ class TestObjectManagerBase : public ::testing::Test { server1.reset(new MockServer(main_service, om_config_1, gcs_client_1)); // start second 
server - gcs_client_2 = std::shared_ptr(new gcs::AsyncGcsClient()); + gcs_client_2 = std::shared_ptr( + new gcs::AsyncGcsClient("127.0.0.1", 6379, /*is_test_client=*/true)); ObjectManagerConfig om_config_2; om_config_2.store_socket_name = store_id_2; om_config_2.pull_timeout_ms = pull_timeout_ms; diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index 52362a279..68cd0f156 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -34,7 +34,6 @@ class MockServer { private: ray::Status RegisterGcs(boost::asio::io_service &io_service) { - RAY_RETURN_NOT_OK(gcs_client_->Connect("127.0.0.1", 6379, /*sharding=*/false)); RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service)); boost::asio::ip::tcp::endpoint endpoint = object_manager_acceptor_.local_endpoint(); @@ -115,7 +114,8 @@ class TestObjectManagerBase : public ::testing::Test { push_timeout_ms = 1000; // start first server - gcs_client_1 = std::shared_ptr(new gcs::AsyncGcsClient()); + gcs_client_1 = std::shared_ptr( + new gcs::AsyncGcsClient("127.0.0.1", 6379, /*is_test_client=*/true)); ObjectManagerConfig om_config_1; om_config_1.store_socket_name = store_id_1; om_config_1.pull_timeout_ms = pull_timeout_ms; @@ -126,7 +126,8 @@ class TestObjectManagerBase : public ::testing::Test { server1.reset(new MockServer(main_service, om_config_1, gcs_client_1)); // start second server - gcs_client_2 = std::shared_ptr(new gcs::AsyncGcsClient()); + gcs_client_2 = std::shared_ptr( + new gcs::AsyncGcsClient("127.0.0.1", 6379, /*is_test_client=*/true)); ObjectManagerConfig om_config_2; om_config_2.store_socket_name = store_id_2; om_config_2.pull_timeout_ms = pull_timeout_ms; diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index bdf842886..171078cb2 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -89,7 +89,7 @@ int main(int argc, char *argv[]) { << "object_chunk_size = " << 
object_manager_config.object_chunk_size; // initialize mock gcs & object directory - auto gcs_client = std::make_shared(); + auto gcs_client = std::make_shared(redis_address, redis_port); RAY_LOG(DEBUG) << "Initializing GCS client " << gcs_client->client_table().GetLocalClientId(); diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index b9d92bb01..aa845035c 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -15,10 +15,9 @@ namespace raylet { /// the client table, which broadcasts the event to all other Raylets. Monitor::Monitor(boost::asio::io_service &io_service, const std::string &redis_address, int redis_port) - : gcs_client_(), + : gcs_client_(redis_address, redis_port), num_heartbeats_timeout_(RayConfig::instance().num_heartbeats_timeout()), heartbeat_timer_(io_service) { - RAY_CHECK_OK(gcs_client_.Connect(redis_address, redis_port, /*sharding=*/true)); RAY_CHECK_OK(gcs_client_.Attach(io_service)); } diff --git a/src/ray/raylet/object_manager_integration_test.cc b/src/ray/raylet/object_manager_integration_test.cc index c22c05c15..5fe7f774b 100644 --- a/src/ray/raylet/object_manager_integration_test.cc +++ b/src/ray/raylet/object_manager_integration_test.cc @@ -54,7 +54,8 @@ class TestObjectManagerBase : public ::testing::Test { std::string store_sock_2 = StartStore("2"); // start first server - gcs_client_1 = std::shared_ptr(new gcs::AsyncGcsClient()); + gcs_client_1 = std::shared_ptr( + new gcs::AsyncGcsClient("127.0.0.1", 6379, /*is_test_client=*/true)); ObjectManagerConfig om_config_1; om_config_1.store_socket_name = store_sock_1; om_config_1.push_timeout_ms = 10000; @@ -63,7 +64,8 @@ class TestObjectManagerBase : public ::testing::Test { GetNodeManagerConfig("raylet_1", store_sock_1), om_config_1, gcs_client_1)); // start second server - gcs_client_2 = std::shared_ptr(new gcs::AsyncGcsClient()); + gcs_client_2 = std::shared_ptr( + new gcs::AsyncGcsClient("127.0.0.1", 6379, /*is_test_client=*/true)); 
ObjectManagerConfig om_config_2; om_config_2.store_socket_name = store_sock_2; om_config_2.push_timeout_ms = 10000; diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index e6c13c056..df30498f4 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -54,7 +54,6 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address, const std::string &redis_address, int redis_port, boost::asio::io_service &io_service, const NodeManagerConfig &node_manager_config) { - RAY_RETURN_NOT_OK(gcs_client_->Connect(redis_address, redis_port, /*sharding=*/true)); RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service)); ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); diff --git a/test/runtest.py b/test/runtest.py index b814bd3a8..8520d6885 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -1295,11 +1295,7 @@ class APITestSharded(APITest): if kwargs is None: kwargs = {} kwargs["start_ray_local"] = True - if os.environ.get("RAY_USE_XRAY") == "1": - print("XRAY currently supports only a single Redis shard.") - kwargs["num_redis_shards"] = 1 - else: - kwargs["num_redis_shards"] = 20 + kwargs["num_redis_shards"] = 20 kwargs["redirect_output"] = True ray.worker._init(**kwargs)