diff --git a/BUILD.bazel b/BUILD.bazel index 1eb15c48c..26e60b6ba 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -362,6 +362,7 @@ cc_library( ":gcs", ":gcs_pub_sub_lib", ":gcs_service_rpc", + ":gcs_table_storage_lib", ":node_manager_rpc", ":raylet_lib", ":worker_rpc", @@ -749,11 +750,23 @@ cc_test( ], ) +cc_library( + name = "gcs_test_util_lib", + hdrs = [ + "src/ray/gcs/test/accessor_test_base.h", + "src/ray/gcs/test/gcs_test_util.h", + ], + copts = COPTS, + deps = [ + ":gcs", + ":gcs_service_rpc", + ], +) + cc_test( name = "gcs_server_rpc_test", srcs = [ "src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc", - "src/ray/gcs/test/gcs_test_util.h", ], args = ["$(location redis-server) $(location redis-cli) $(location libray_redis_module.so)"], copts = COPTS, @@ -764,6 +777,7 @@ cc_test( ], deps = [ ":gcs_server_lib", + ":gcs_test_util_lib", "@com_google_googletest//:gtest_main", ], ) @@ -772,11 +786,12 @@ cc_test( name = "gcs_node_manager_test", srcs = [ "src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc", - "src/ray/gcs/test/gcs_test_util.h", + "src/ray/gcs/gcs_server/test/gcs_server_test_util.h", ], copts = COPTS, deps = [ ":gcs_server_lib", + ":gcs_test_util_lib", "@com_google_googletest//:gtest_main", ], ) @@ -785,11 +800,12 @@ cc_test( name = "gcs_actor_scheduler_test", srcs = [ "src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc", - "src/ray/gcs/test/gcs_test_util.h", + "src/ray/gcs/gcs_server/test/gcs_server_test_util.h", ], copts = COPTS, deps = [ ":gcs_server_lib", + ":gcs_test_util_lib", "@com_google_googletest//:gtest_main", ], ) @@ -798,11 +814,12 @@ cc_test( name = "gcs_actor_manager_test", srcs = [ "src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc", - "src/ray/gcs/test/gcs_test_util.h", + "src/ray/gcs/gcs_server/test/gcs_server_test_util.h", ], copts = COPTS, deps = [ ":gcs_server_lib", + ":gcs_test_util_lib", "@com_google_googletest//:gtest_main", ], ) @@ -822,16 +839,27 @@ cc_library( copts = COPTS, deps = [ ":gcs", + ":gcs_in_memory_store_client", ":ray_common", ":redis_store_client", ], ) +cc_library( + name = "gcs_table_storage_test_lib", + hdrs = [ + "src/ray/gcs/gcs_server/test/gcs_table_storage_test_base.h", + ], + copts = COPTS, + deps = [ + "redis_store_client", + ], +) + cc_test( name = "redis_gcs_table_storage_test", srcs = [ "src/ray/gcs/gcs_server/test/redis_gcs_table_storage_test.cc", - "src/ray/gcs/test/gcs_test_util.h", ], args = ["$(location redis-server) $(location redis-cli) $(location libray_redis_module.so)"], copts = COPTS, @@ -841,8 +869,24 @@ cc_test( "//:redis-server", ], deps = [ - ":gcs_server_lib", ":gcs_table_storage_lib", + ":gcs_table_storage_test_lib", + ":gcs_test_util_lib", + ":store_client_test_lib", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "in_memory_gcs_table_storage_test", + srcs = [ + "src/ray/gcs/gcs_server/test/in_memory_gcs_table_storage_test.cc", + ], + copts = COPTS, + deps = [ + ":gcs_table_storage_lib", + ":gcs_table_storage_test_lib", + ":gcs_test_util_lib", ":store_client_test_lib", "@com_google_googletest//:gtest_main", ], @@ -873,7 +917,6 @@ cc_test( name = "gcs_server_test", srcs = [ "src/ray/gcs/gcs_client/test/service_based_gcs_client_test.cc", - "src/ray/gcs/test/gcs_test_util.h", ], args = ["$(location redis-server) $(location redis-cli) $(location libray_redis_module.so)"], copts = COPTS, @@ -884,6 +927,7 @@ cc_test( ], deps = [ ":gcs_server_lib", + ":gcs_test_util_lib", ":service_based_gcs_client_lib", "@com_google_googletest//:gtest_main", ], @@ -1153,7 +1197,6 @@ cc_library( ), hdrs = glob([ "src/ray/gcs/*.h", - "src/ray/gcs/test/*.h", ]), copts = COPTS, deps = [ @@ -1197,6 +1240,7 @@ cc_test( ], deps = [ ":gcs", + ":gcs_test_util_lib", "@com_google_googletest//:gtest_main", ], ) @@ -1213,6 +1257,7 @@ cc_test( ], deps = [ ":gcs", + ":gcs_test_util_lib", "@com_google_googletest//:gtest_main", ], ) @@ -1229,6 +1274,7 @@ cc_test( ], deps = [ ":gcs", + ":gcs_test_util_lib", "@com_google_googletest//:gtest_main", ], ) @@ -1245,6 +1291,7 @@ cc_test( ], deps = [ ":gcs", + ":gcs_test_util_lib", "@com_google_googletest//:gtest_main", ], ) @@ -1261,6 +1308,7 @@ cc_test( ], deps = [ ":gcs", + ":gcs_test_util_lib", "@com_google_googletest//:gtest_main", ], ) diff --git a/src/ray/gcs/gcs_server/gcs_table_storage.h b/src/ray/gcs/gcs_server/gcs_table_storage.h index 3ef728779..b94014abf 100644 --- a/src/ray/gcs/gcs_server/gcs_table_storage.h +++ b/src/ray/gcs/gcs_server/gcs_table_storage.h @@ -17,6 +17,7 @@ #include +#include "ray/gcs/store_client/in_memory_store_client.h" #include "ray/gcs/store_client/redis_store_client.h" #include "ray/protobuf/gcs.pb.h" @@ -400,6 +401,31 @@ class RedisGcsTableStorage : public GcsTableStorage { } }; +/// \class InMemoryGcsTableStorage +/// InMemoryGcsTableStorage is an implementation of `GcsTableStorage` +/// that uses memory as storage. +class InMemoryGcsTableStorage : public GcsTableStorage { + public: + explicit InMemoryGcsTableStorage(boost::asio::io_service &main_io_service) { + store_client_ = std::make_shared(main_io_service); + job_table_.reset(new GcsJobTable(store_client_)); + actor_table_.reset(new GcsActorTable(store_client_)); + actor_checkpoint_table_.reset(new GcsActorCheckpointTable(store_client_)); + actor_checkpoint_id_table_.reset(new GcsActorCheckpointIdTable(store_client_)); + task_table_.reset(new GcsTaskTable(store_client_)); + task_lease_table_.reset(new GcsTaskLeaseTable(store_client_)); + task_reconstruction_table_.reset(new GcsTaskReconstructionTable(store_client_)); + object_table_.reset(new GcsObjectTable(store_client_)); + node_table_.reset(new GcsNodeTable(store_client_)); + node_resource_table_.reset(new GcsNodeResourceTable(store_client_)); + heartbeat_table_.reset(new GcsHeartbeatTable(store_client_)); + heartbeat_batch_table_.reset(new GcsHeartbeatBatchTable(store_client_)); + error_info_table_.reset(new GcsErrorInfoTable(store_client_)); + profile_table_.reset(new GcsProfileTable(store_client_)); + worker_failure_table_.reset(new GcsWorkerFailureTable(store_client_)); + } +}; + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc index e7e4d85f1..48c50741a 100644 --- a/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include @@ -28,7 +29,7 @@ class MockedGcsActorManager : public gcs::GcsActorManager { rpc::ClientFactoryFn client_factory = nullptr) : gcs::GcsActorManager(io_context, actor_info_accessor, gcs_node_manager, lease_client_factory, client_factory) { - gcs_actor_scheduler_.reset(new Mocker::MockedGcsActorScheduler( + gcs_actor_scheduler_.reset(new GcsServerMocker::MockedGcsActorScheduler( io_context, actor_info_accessor, gcs_node_manager, /*schedule_failure_handler=*/ [this](std::shared_ptr actor) { @@ -47,14 +48,14 @@ class MockedGcsActorManager : public gcs::GcsActorManager { public: void ResetLeaseClientFactory(gcs::LeaseClientFactoryFn lease_client_factory) { - auto gcs_actor_scheduler = - dynamic_cast(gcs_actor_scheduler_.get()); + auto gcs_actor_scheduler = dynamic_cast( + gcs_actor_scheduler_.get()); gcs_actor_scheduler->ResetLeaseClientFactory(std::move(lease_client_factory)); } void ResetClientFactory(rpc::ClientFactoryFn client_factory) { - auto gcs_actor_scheduler = - dynamic_cast(gcs_actor_scheduler_.get()); + auto gcs_actor_scheduler = dynamic_cast( + gcs_actor_scheduler_.get()); gcs_actor_scheduler->ResetClientFactory(std::move(client_factory)); } @@ -71,8 +72,8 @@ class MockedGcsActorManager : public gcs::GcsActorManager { class GcsActorManagerTest : public ::testing::Test { public: void SetUp() override { - raylet_client_ = std::make_shared(); - worker_client_ = std::make_shared(); + raylet_client_ = std::make_shared(); + worker_client_ = std::make_shared(); gcs_node_manager_ = std::make_shared( io_service_, node_info_accessor_, error_info_accessor_); gcs_actor_manager_ = std::make_shared( @@ -85,12 +86,12 @@ class GcsActorManagerTest : public ::testing::Test { protected: boost::asio::io_service io_service_; - Mocker::MockedActorInfoAccessor actor_info_accessor_; - Mocker::MockedNodeInfoAccessor node_info_accessor_; - Mocker::MockedErrorInfoAccessor error_info_accessor_; + GcsServerMocker::MockedActorInfoAccessor actor_info_accessor_; + GcsServerMocker::MockedNodeInfoAccessor node_info_accessor_; + GcsServerMocker::MockedErrorInfoAccessor error_info_accessor_; - std::shared_ptr raylet_client_; - std::shared_ptr worker_client_; + std::shared_ptr raylet_client_; + std::shared_ptr worker_client_; std::shared_ptr gcs_node_manager_; std::shared_ptr gcs_actor_manager_; }; diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc index 49865b4c9..c2f04e265 100644 --- a/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_actor_scheduler_test.cc @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include -#include "gmock/gmock.h" #include "gtest/gtest.h" namespace ray { @@ -23,11 +23,11 @@ namespace ray { class GcsActorSchedulerTest : public ::testing::Test { public: void SetUp() override { - raylet_client_ = std::make_shared(); - worker_client_ = std::make_shared(); + raylet_client_ = std::make_shared(); + worker_client_ = std::make_shared(); gcs_node_manager_ = std::make_shared( io_service_, node_info_accessor_, error_info_accessor_); - gcs_actor_scheduler_ = std::make_shared( + gcs_actor_scheduler_ = std::make_shared( io_service_, actor_info_accessor_, *gcs_node_manager_, /*schedule_failure_handler=*/ [this](std::shared_ptr actor) { @@ -45,14 +45,14 @@ class GcsActorSchedulerTest : public ::testing::Test { protected: boost::asio::io_service io_service_; - Mocker::MockedActorInfoAccessor actor_info_accessor_; - Mocker::MockedNodeInfoAccessor node_info_accessor_; - Mocker::MockedErrorInfoAccessor error_info_accessor_; + GcsServerMocker::MockedActorInfoAccessor actor_info_accessor_; + GcsServerMocker::MockedNodeInfoAccessor node_info_accessor_; + GcsServerMocker::MockedErrorInfoAccessor error_info_accessor_; - std::shared_ptr raylet_client_; - std::shared_ptr worker_client_; + std::shared_ptr raylet_client_; + std::shared_ptr worker_client_; std::shared_ptr gcs_node_manager_; - std::shared_ptr gcs_actor_scheduler_; + std::shared_ptr gcs_actor_scheduler_; std::vector> success_actors_; std::vector> failure_actors_; }; diff --git a/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc index bb7f7dd39..b0e387d46 100644 --- a/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_node_manager_test.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include @@ -22,8 +23,8 @@ class GcsNodeManagerTest : public ::testing::Test {}; TEST_F(GcsNodeManagerTest, TestManagement) { boost::asio::io_service io_service; - auto node_info_accessor = Mocker::MockedNodeInfoAccessor(); - auto error_info_accessor = Mocker::MockedErrorInfoAccessor(); + auto node_info_accessor = GcsServerMocker::MockedNodeInfoAccessor(); + auto error_info_accessor = GcsServerMocker::MockedErrorInfoAccessor(); gcs::GcsNodeManager node_manager(io_service, node_info_accessor, error_info_accessor); // Test Add/Get/Remove functionality. auto node = Mocker::GenNodeInfo(); @@ -38,8 +39,8 @@ TEST_F(GcsNodeManagerTest, TestManagement) { TEST_F(GcsNodeManagerTest, TestListener) { boost::asio::io_service io_service; - auto node_info_accessor = Mocker::MockedNodeInfoAccessor(); - auto error_info_accessor = Mocker::MockedErrorInfoAccessor(); + auto node_info_accessor = GcsServerMocker::MockedNodeInfoAccessor(); + auto error_info_accessor = GcsServerMocker::MockedErrorInfoAccessor(); gcs::GcsNodeManager node_manager(io_service, node_info_accessor, error_info_accessor); // Test AddNodeAddedListener. int node_count = 1000; diff --git a/src/ray/gcs/gcs_server/test/gcs_server_test_util.h b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h new file mode 100644 index 000000000..cb467cfa8 --- /dev/null +++ b/src/ray/gcs/gcs_server/test/gcs_server_test_util.h @@ -0,0 +1,362 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef RAY_GCS_SERVER_TEST_UTIL_H +#define RAY_GCS_SERVER_TEST_UTIL_H + +#include +#include + +#include "src/ray/common/task/task.h" +#include "src/ray/common/task/task_util.h" +#include "src/ray/common/test_util.h" +#include "src/ray/gcs/gcs_server/gcs_actor_manager.h" +#include "src/ray/gcs/gcs_server/gcs_actor_scheduler.h" +#include "src/ray/gcs/gcs_server/gcs_node_manager.h" +#include "src/ray/util/asio_util.h" + +namespace ray { + +struct GcsServerMocker { + class MockWorkerClient : public rpc::CoreWorkerClientInterface { + public: + ray::Status PushNormalTask( + std::unique_ptr request, + const rpc::ClientCallback &callback) override { + callbacks.push_back(callback); + if (enable_auto_reply) { + ReplyPushTask(); + } + return Status::OK(); + } + + bool ReplyPushTask(Status status = Status::OK(), bool exit = false) { + if (callbacks.size() == 0) { + return false; + } + auto callback = callbacks.front(); + auto reply = rpc::PushTaskReply(); + if (exit) { + reply.set_worker_exiting(true); + } + callback(status, reply); + callbacks.pop_front(); + return true; + } + + bool enable_auto_reply = false; + std::list> callbacks; + }; + + class MockRayletClient : public WorkerLeaseInterface { + public: + ray::Status ReturnWorker(int worker_port, const WorkerID &worker_id, + bool disconnect_worker) override { + if (disconnect_worker) { + num_workers_disconnected++; + } else { + num_workers_returned++; + } + return Status::OK(); + } + + ray::Status RequestWorkerLease( + const ray::TaskSpecification &resource_spec, + const rpc::ClientCallback &callback) override { + num_workers_requested += 1; + callbacks.push_back(callback); + if (!auto_grant_node_id.IsNil()) { + GrantWorkerLease("", 0, WorkerID::FromRandom(), auto_grant_node_id, + ClientID::Nil()); + } + return Status::OK(); + } + + ray::Status CancelWorkerLease( + const TaskID &task_id, + const rpc::ClientCallback &callback) override { + num_leases_canceled += 1; + cancel_callbacks.push_back(callback); + return Status::OK(); + } + + // Trigger reply to RequestWorkerLease. + bool GrantWorkerLease(const std::string &address, int port, const WorkerID &worker_id, + const ClientID &raylet_id, const ClientID &retry_at_raylet_id, + Status status = Status::OK()) { + rpc::RequestWorkerLeaseReply reply; + if (!retry_at_raylet_id.IsNil()) { + reply.mutable_retry_at_raylet_address()->set_ip_address(address); + reply.mutable_retry_at_raylet_address()->set_port(port); + reply.mutable_retry_at_raylet_address()->set_raylet_id( + retry_at_raylet_id.Binary()); + } else { + reply.mutable_worker_address()->set_ip_address(address); + reply.mutable_worker_address()->set_port(port); + reply.mutable_worker_address()->set_raylet_id(raylet_id.Binary()); + reply.mutable_worker_address()->set_worker_id(worker_id.Binary()); + } + if (callbacks.size() == 0) { + return false; + } else { + auto callback = callbacks.front(); + callback(status, reply); + callbacks.pop_front(); + return true; + } + } + + bool ReplyCancelWorkerLease(bool success = true) { + rpc::CancelWorkerLeaseReply reply; + reply.set_success(success); + if (cancel_callbacks.size() == 0) { + return false; + } else { + auto callback = cancel_callbacks.front(); + callback(Status::OK(), reply); + cancel_callbacks.pop_front(); + return true; + } + } + + ~MockRayletClient() {} + + int num_workers_requested = 0; + int num_workers_returned = 0; + int num_workers_disconnected = 0; + int num_leases_canceled = 0; + ClientID auto_grant_node_id; + std::list> callbacks = {}; + std::list> cancel_callbacks = {}; + }; + + class MockedGcsActorScheduler : public gcs::GcsActorScheduler { + public: + using gcs::GcsActorScheduler::GcsActorScheduler; + + void ResetLeaseClientFactory(gcs::LeaseClientFactoryFn lease_client_factory) { + lease_client_factory_ = std::move(lease_client_factory); + } + + void ResetClientFactory(rpc::ClientFactoryFn client_factory) { + client_factory_ = std::move(client_factory); + } + + protected: + void RetryLeasingWorkerFromNode(std::shared_ptr actor, + std::shared_ptr node) override { + ++num_retry_leasing_count_; + DoRetryLeasingWorkerFromNode(actor, node); + } + + void RetryCreatingActorOnWorker(std::shared_ptr actor, + std::shared_ptr worker) override { + ++num_retry_creating_count_; + DoRetryCreatingActorOnWorker(actor, worker); + } + + public: + int num_retry_leasing_count_ = 0; + int num_retry_creating_count_ = 0; + }; + + class MockedActorInfoAccessor : public gcs::ActorInfoAccessor { + public: + Status GetAll(std::vector *actor_table_data_list) override { + return Status::NotImplemented(""); + } + + Status AsyncGet( + const ActorID &actor_id, + const gcs::OptionalItemCallback &callback) override { + return Status::NotImplemented(""); + } + + Status AsyncCreateActor(const TaskSpecification &task_spec, + const gcs::StatusCallback &callback) override { + return Status::NotImplemented(""); + } + + Status AsyncRegister(const std::shared_ptr &data_ptr, + const gcs::StatusCallback &callback) override { + return Status::NotImplemented(""); + } + + Status AsyncUpdate(const ActorID &actor_id, + const std::shared_ptr &data_ptr, + const gcs::StatusCallback &callback) override { + if (callback) { + callback(Status::OK()); + } + return Status::OK(); + } + + Status AsyncSubscribeAll( + const gcs::SubscribeCallback &subscribe, + const gcs::StatusCallback &done) override { + return Status::NotImplemented(""); + } + + Status AsyncSubscribe( + const ActorID &actor_id, + const gcs::SubscribeCallback &subscribe, + const gcs::StatusCallback &done) override { + return Status::NotImplemented(""); + } + + Status AsyncUnsubscribe(const ActorID &actor_id, + const gcs::StatusCallback &done) override { + return Status::NotImplemented(""); + } + + Status AsyncAddCheckpoint(const std::shared_ptr &data_ptr, + const gcs::StatusCallback &callback) override { + return Status::NotImplemented(""); + } + + Status AsyncGetCheckpoint( + const ActorCheckpointID &checkpoint_id, const ActorID &actor_id, + const gcs::OptionalItemCallback &callback) override { + return Status::NotImplemented(""); + } + + Status AsyncGetCheckpointID( + const ActorID &actor_id, + const gcs::OptionalItemCallback &callback) override { + return Status::NotImplemented(""); + } + }; + + class MockedNodeInfoAccessor : public gcs::NodeInfoAccessor { + public: + Status RegisterSelf(const rpc::GcsNodeInfo &local_node_info) override { + return Status::NotImplemented(""); + } + + Status UnregisterSelf() override { return Status::NotImplemented(""); } + + const ClientID &GetSelfId() const override { + static ClientID node_id; + return node_id; + } + + const rpc::GcsNodeInfo &GetSelfInfo() const override { + static rpc::GcsNodeInfo node_info; + return node_info; + } + + Status AsyncRegister(const rpc::GcsNodeInfo &node_info, + const gcs::StatusCallback &callback) override { + return Status::NotImplemented(""); + } + + Status AsyncUnregister(const ClientID &node_id, + const gcs::StatusCallback &callback) override { + if (callback) { + callback(Status::OK()); + } + return Status::OK(); + } + + Status AsyncGetAll( + const gcs::MultiItemCallback &callback) override { + if (callback) { + callback(Status::OK(), {}); + } + return Status::OK(); + } + + Status AsyncSubscribeToNodeChange( + const gcs::SubscribeCallback &subscribe, + const gcs::StatusCallback &done) override { + return Status::NotImplemented(""); + } + + boost::optional Get(const ClientID &node_id) const override { + return boost::none; + } + + const std::unordered_map &GetAll() const override { + static std::unordered_map node_info_list; + return node_info_list; + } + + bool IsRemoved(const ClientID &node_id) const override { return false; } + + Status AsyncGetResources( + const ClientID &node_id, + const gcs::OptionalItemCallback &callback) override { + return Status::NotImplemented(""); + } + + Status AsyncUpdateResources(const ClientID &node_id, const ResourceMap &resources, + const gcs::StatusCallback &callback) override { + return Status::NotImplemented(""); + } + + Status AsyncDeleteResources(const ClientID &node_id, + const std::vector &resource_names, + const gcs::StatusCallback &callback) override { + return Status::NotImplemented(""); + } + + Status AsyncSubscribeToResources( + const gcs::SubscribeCallback + &subscribe, + const gcs::StatusCallback &done) override { + return Status::NotImplemented(""); + } + + Status AsyncReportHeartbeat(const std::shared_ptr &data_ptr, + const gcs::StatusCallback &callback) override { + return Status::NotImplemented(""); + } + + Status AsyncSubscribeHeartbeat( + const gcs::SubscribeCallback &subscribe, + const gcs::StatusCallback &done) override { + return Status::NotImplemented(""); + } + + Status AsyncReportBatchHeartbeat( + const std::shared_ptr &data_ptr, + const gcs::StatusCallback &callback) override { + if (callback) { + callback(Status::OK()); + } + return Status::OK(); + } + + Status AsyncSubscribeBatchHeartbeat( + const gcs::ItemCallback &subscribe, + const gcs::StatusCallback &done) override { + return Status::NotImplemented(""); + } + }; + + class MockedErrorInfoAccessor : public gcs::ErrorInfoAccessor { + public: + Status AsyncReportJobError(const std::shared_ptr &data_ptr, + const gcs::StatusCallback &callback) override { + if (callback) { + callback(Status::OK()); + } + return Status::OK(); + } + }; +}; + +} // namespace ray + +#endif // RAY_GCS_SERVER_TEST_UTIL_H diff --git a/src/ray/gcs/gcs_server/test/gcs_table_storage_test_base.h b/src/ray/gcs/gcs_server/test/gcs_table_storage_test_base.h new file mode 100644 index 000000000..ddd6c59c7 --- /dev/null +++ b/src/ray/gcs/gcs_server/test/gcs_table_storage_test_base.h @@ -0,0 +1,129 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" +#include "ray/common/id.h" +#include "ray/common/test_util.h" +#include "ray/gcs/gcs_server/gcs_table_storage.h" +#include "ray/gcs/test/gcs_test_util.h" + +namespace ray { + +namespace gcs { + +class GcsTableStorageTestBase : public ::testing::Test { + public: + GcsTableStorageTestBase() { + io_service_pool_ = std::make_shared(io_service_num_); + io_service_pool_->Run(); + } + + virtual ~GcsTableStorageTestBase() { io_service_pool_->Stop(); } + + protected: + void TestGcsTableApi() { + auto table = gcs_table_storage_->JobTable(); + JobID job1_id = JobID::FromInt(1); + JobID job2_id = JobID::FromInt(2); + auto job1_table_data = Mocker::GenJobTableData(job1_id); + auto job2_table_data = Mocker::GenJobTableData(job2_id); + + // Put. + Put(table, job1_id, *job1_table_data); + Put(table, job2_id, *job2_table_data); + + // Get. + std::vector values; + ASSERT_EQ(Get(table, job2_id, values), 1); + ASSERT_EQ(Get(table, job2_id, values), 1); + + // Delete. + Delete(table, job1_id); + ASSERT_EQ(Get(table, job1_id, values), 0); + ASSERT_EQ(Get(table, job2_id, values), 1); + } + + void TestGcsTableWithJobIdApi() { + auto table = gcs_table_storage_->ActorTable(); + JobID job_id = JobID::FromInt(3); + auto actor_table_data = Mocker::GenActorTableData(job_id); + ActorID actor_id = ActorID::FromBinary(actor_table_data->actor_id()); + + // Put. + Put(table, actor_id, *actor_table_data); + + // Get. + std::vector values; + ASSERT_EQ(Get(table, actor_id, values), 1); + + // Delete. + Delete(table, actor_id); + ASSERT_EQ(Get(table, actor_id, values), 0); + } + + template + void Put(TABLE &table, const KEY &key, const VALUE &value) { + auto on_done = [this](Status status) { --pending_count_; }; + ++pending_count_; + RAY_CHECK_OK(table.Put(key, value, on_done)); + WaitPendingDone(); + } + + template + int Get(TABLE &table, const KEY &key, std::vector &values) { + auto on_done = [this, &values](Status status, const boost::optional &result) { + RAY_CHECK_OK(status); + --pending_count_; + values.clear(); + if (result) { + values.push_back(*result); + } + }; + ++pending_count_; + RAY_CHECK_OK(table.Get(key, on_done)); + WaitPendingDone(); + return values.size(); + } + + template + void Delete(TABLE &table, const KEY &key) { + auto on_done = [this](Status status) { + RAY_CHECK_OK(status); + --pending_count_; + }; + ++pending_count_; + RAY_CHECK_OK(table.Delete(key, on_done)); + WaitPendingDone(); + } + + void WaitPendingDone() { WaitPendingDone(pending_count_); } + + void WaitPendingDone(std::atomic &pending_count) { + auto condition = [&pending_count]() { return pending_count == 0; }; + EXPECT_TRUE(WaitForCondition(condition, wait_pending_timeout_.count())); + } + + protected: + size_t io_service_num_{2}; + std::shared_ptr io_service_pool_; + + std::shared_ptr gcs_table_storage_; + + std::atomic pending_count_{0}; + std::chrono::milliseconds wait_pending_timeout_{5000}; +}; + +} // namespace gcs + +} // namespace ray diff --git a/src/ray/gcs/gcs_server/test/in_memory_gcs_table_storage_test.cc b/src/ray/gcs/gcs_server/test/in_memory_gcs_table_storage_test.cc new file mode 100644 index 000000000..0a6327753 --- /dev/null +++ b/src/ray/gcs/gcs_server/test/in_memory_gcs_table_storage_test.cc @@ -0,0 +1,42 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" +#include "ray/common/test_util.h" +#include "ray/gcs/gcs_server/gcs_table_storage.h" +#include "ray/gcs/gcs_server/test/gcs_table_storage_test_base.h" +#include "ray/gcs/store_client/in_memory_store_client.h" + +namespace ray { + +class InMemoryGcsTableStorageTest : public gcs::GcsTableStorageTestBase { + public: + void SetUp() override { + gcs_table_storage_ = + std::make_shared(*(io_service_pool_->Get())); + } +}; + +TEST_F(InMemoryGcsTableStorageTest, TestGcsTableApi) { TestGcsTableApi(); } + +TEST_F(InMemoryGcsTableStorageTest, TestGcsTableWithJobIdApi) { + TestGcsTableWithJobIdApi(); +} + +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/gcs/gcs_server/test/redis_gcs_table_storage_test.cc b/src/ray/gcs/gcs_server/test/redis_gcs_table_storage_test.cc index 588550bf3..543a7666e 100644 --- a/src/ray/gcs/gcs_server/test/redis_gcs_table_storage_test.cc +++ b/src/ray/gcs/gcs_server/test/redis_gcs_table_storage_test.cc @@ -15,23 +15,18 @@ #include "gtest/gtest.h" #include "ray/common/test_util.h" #include "ray/gcs/gcs_server/gcs_table_storage.h" +#include "ray/gcs/gcs_server/test/gcs_table_storage_test_base.h" #include "ray/gcs/store_client/redis_store_client.h" -#include "ray/gcs/store_client/test/store_client_test_base.h" -#include "ray/gcs/test/gcs_test_util.h" namespace ray { -class GcsTableStorageTest : public gcs::StoreClientTestBase { +class RedisGcsTableStorageTest : public gcs::GcsTableStorageTestBase { public: - GcsTableStorageTest() {} - - virtual ~GcsTableStorageTest() {} - static void SetUpTestCase() { RedisServiceManagerForTest::SetUpTestCase(); } static void TearDownTestCase() { RedisServiceManagerForTest::TearDownTestCase(); } - void InitStoreClient() override { + void SetUp() override { gcs::RedisClientOptions options("127.0.0.1", REDIS_SERVER_PORT, "", true); redis_client_ = std::make_shared(options); RAY_CHECK_OK(redis_client_->Connect(io_service_pool_->GetAll())); @@ -39,87 +34,15 @@ class GcsTableStorageTest : public gcs::StoreClientTestBase { gcs_table_storage_ = std::make_shared(redis_client_); } - void DisconnectStoreClient() override { redis_client_->Disconnect(); } + void TearDown() override { redis_client_->Disconnect(); } protected: - template - void Put(TABLE &table, const KEY &key, const VALUE &value) { - auto on_done = [this](Status status) { --pending_count_; }; - ++pending_count_; - RAY_CHECK_OK(table.Put(key, value, on_done)); - WaitPendingDone(); - } - - template - int Get(TABLE &table, const KEY &key, std::vector &values) { - auto on_done = [this, &values](Status status, const boost::optional &result) { - RAY_CHECK_OK(status); - --pending_count_; - values.clear(); - if (result) { - values.push_back(*result); - } - }; - ++pending_count_; - RAY_CHECK_OK(table.Get(key, on_done)); - WaitPendingDone(); - return values.size(); - } - - template - void Delete(TABLE &table, const KEY &key) { - auto on_done = [this](Status status) { - RAY_CHECK_OK(status); - --pending_count_; - }; - ++pending_count_; - RAY_CHECK_OK(table.Delete(key, on_done)); - WaitPendingDone(); - } - std::shared_ptr redis_client_; - std::shared_ptr gcs_table_storage_; }; -TEST_F(GcsTableStorageTest, TestGcsTableApi) { - auto table = gcs_table_storage_->JobTable(); - JobID job1_id = JobID::FromInt(1); - JobID job2_id = JobID::FromInt(2); - auto job1_table_data = Mocker::GenJobTableData(job1_id); - auto job2_table_data = Mocker::GenJobTableData(job2_id); +TEST_F(RedisGcsTableStorageTest, TestGcsTableApi) { TestGcsTableApi(); } - // Put. - Put(table, job1_id, *job1_table_data); - Put(table, job2_id, *job2_table_data); - - // Get. - std::vector values; - ASSERT_EQ(Get(table, job2_id, values), 1); - ASSERT_EQ(Get(table, job2_id, values), 1); - - // Delete. - Delete(table, job1_id); - ASSERT_EQ(Get(table, job1_id, values), 0); - ASSERT_EQ(Get(table, job2_id, values), 1); -} - -TEST_F(GcsTableStorageTest, TestGcsTableWithJobIdApi) { - auto table = gcs_table_storage_->ActorTable(); - JobID job_id = JobID::FromInt(3); - auto actor_table_data = Mocker::GenActorTableData(job_id); - ActorID actor_id = ActorID::FromBinary(actor_table_data->actor_id()); - - // Put. - Put(table, actor_id, *actor_table_data); - - // Get. - std::vector values; - ASSERT_EQ(Get(table, actor_id, values), 1); - - // Delete. - Delete(table, actor_id); - ASSERT_EQ(Get(table, actor_id, values), 0); -} +TEST_F(RedisGcsTableStorageTest, TestGcsTableWithJobIdApi) { TestGcsTableWithJobIdApi(); } } // namespace ray diff --git a/src/ray/gcs/test/gcs_test_util.h b/src/ray/gcs/test/gcs_test_util.h index c7a310835..94f383700 100644 --- a/src/ray/gcs/test/gcs_test_util.h +++ b/src/ray/gcs/test/gcs_test_util.h @@ -21,11 +21,10 @@ #include "src/ray/common/task/task.h" #include "src/ray/common/task/task_util.h" #include "src/ray/common/test_util.h" -#include "src/ray/gcs/gcs_server/gcs_actor_manager.h" -#include "src/ray/gcs/gcs_server/gcs_actor_scheduler.h" -#include "src/ray/gcs/gcs_server/gcs_node_manager.h" #include "src/ray/util/asio_util.h" +#include "src/ray/protobuf/gcs_service.grpc.pb.h" + namespace ray { struct Mocker { @@ -119,333 +118,6 @@ struct Mocker { worker_failure_data->set_timestamp(std::time(nullptr)); return worker_failure_data; } - - class MockWorkerClient : public rpc::CoreWorkerClientInterface { - public: - ray::Status PushNormalTask( - std::unique_ptr request, - const rpc::ClientCallback &callback) override { - callbacks.push_back(callback); - if (enable_auto_reply) { - ReplyPushTask(); - } - return Status::OK(); - } - - bool ReplyPushTask(Status status = Status::OK(), bool exit = false) { - if (callbacks.size() == 0) { - return false; - } - auto callback = callbacks.front(); - auto reply = rpc::PushTaskReply(); - if (exit) { - reply.set_worker_exiting(true); - } - callback(status, reply); - callbacks.pop_front(); - return true; - } - - bool enable_auto_reply = false; - std::list> callbacks; - }; - - class MockRayletClient : public WorkerLeaseInterface { - public: - ray::Status ReturnWorker(int worker_port, const WorkerID &worker_id, - bool disconnect_worker) override { - if (disconnect_worker) { - num_workers_disconnected++; - } else { - num_workers_returned++; - } - return Status::OK(); - } - - ray::Status RequestWorkerLease( - const ray::TaskSpecification &resource_spec, - const rpc::ClientCallback &callback) override { - num_workers_requested += 1; - callbacks.push_back(callback); - if (!auto_grant_node_id.IsNil()) { - GrantWorkerLease("", 0, WorkerID::FromRandom(), auto_grant_node_id, - ClientID::Nil()); - } - return Status::OK(); - } - - ray::Status CancelWorkerLease( - const TaskID &task_id, - const rpc::ClientCallback &callback) override { - num_leases_canceled += 1; - cancel_callbacks.push_back(callback); - return Status::OK(); - } - - // Trigger reply to RequestWorkerLease. - bool GrantWorkerLease(const std::string &address, int port, const WorkerID &worker_id, - const ClientID &raylet_id, const ClientID &retry_at_raylet_id, - Status status = Status::OK()) { - rpc::RequestWorkerLeaseReply reply; - if (!retry_at_raylet_id.IsNil()) { - reply.mutable_retry_at_raylet_address()->set_ip_address(address); - reply.mutable_retry_at_raylet_address()->set_port(port); - reply.mutable_retry_at_raylet_address()->set_raylet_id( - retry_at_raylet_id.Binary()); - } else { - reply.mutable_worker_address()->set_ip_address(address); - reply.mutable_worker_address()->set_port(port); - reply.mutable_worker_address()->set_raylet_id(raylet_id.Binary()); - reply.mutable_worker_address()->set_worker_id(worker_id.Binary()); - } - if (callbacks.size() == 0) { - return false; - } else { - auto callback = callbacks.front(); - callback(status, reply); - callbacks.pop_front(); - return true; - } - } - - bool ReplyCancelWorkerLease(bool success = true) { - rpc::CancelWorkerLeaseReply reply; - reply.set_success(success); - if (cancel_callbacks.size() == 0) { - return false; - } else { - auto callback = cancel_callbacks.front(); - callback(Status::OK(), reply); - cancel_callbacks.pop_front(); - return true; - } - } - - ~MockRayletClient() {} - - int num_workers_requested = 0; - int num_workers_returned = 0; - int num_workers_disconnected = 0; - int num_leases_canceled = 0; - ClientID auto_grant_node_id; - std::list> callbacks = {}; - std::list> cancel_callbacks = {}; - }; - - class MockedGcsActorScheduler : public gcs::GcsActorScheduler { - public: - using gcs::GcsActorScheduler::GcsActorScheduler; - - void ResetLeaseClientFactory(gcs::LeaseClientFactoryFn lease_client_factory) { - lease_client_factory_ = std::move(lease_client_factory); - } - - void ResetClientFactory(rpc::ClientFactoryFn client_factory) { - client_factory_ = std::move(client_factory); - } - - protected: - void RetryLeasingWorkerFromNode(std::shared_ptr actor, - std::shared_ptr node) override { - ++num_retry_leasing_count_; - DoRetryLeasingWorkerFromNode(actor, node); - } - - void RetryCreatingActorOnWorker(std::shared_ptr actor, - std::shared_ptr worker) override { - ++num_retry_creating_count_; - DoRetryCreatingActorOnWorker(actor, worker); - } - - public: - int num_retry_leasing_count_ = 0; - int num_retry_creating_count_ = 0; - }; - - class MockedActorInfoAccessor : public gcs::ActorInfoAccessor { - public: - Status GetAll(std::vector *actor_table_data_list) override { - return Status::NotImplemented(""); - } - - Status AsyncGet( - const ActorID &actor_id, - const gcs::OptionalItemCallback &callback) override { - return Status::NotImplemented(""); - } - - Status AsyncCreateActor(const TaskSpecification &task_spec, - const gcs::StatusCallback &callback) override { - return Status::NotImplemented(""); - } - - Status AsyncRegister(const std::shared_ptr &data_ptr, - const gcs::StatusCallback &callback) override { - return Status::NotImplemented(""); - } - - Status AsyncUpdate(const ActorID &actor_id, - const std::shared_ptr &data_ptr, - const gcs::StatusCallback &callback) override { - if (callback) { - callback(Status::OK()); - } - return Status::OK(); - } - - Status AsyncSubscribeAll( - const gcs::SubscribeCallback &subscribe, - const gcs::StatusCallback &done) override { - return Status::NotImplemented(""); - } - - Status AsyncSubscribe( - const ActorID &actor_id, - const gcs::SubscribeCallback &subscribe, - const gcs::StatusCallback &done) override { - return Status::NotImplemented(""); - } - - Status AsyncUnsubscribe(const ActorID &actor_id, - const gcs::StatusCallback &done) override { - return Status::NotImplemented(""); - } - - Status AsyncAddCheckpoint(const std::shared_ptr &data_ptr, - const gcs::StatusCallback &callback) override { - return Status::NotImplemented(""); - } - - Status AsyncGetCheckpoint( - const ActorCheckpointID &checkpoint_id, const ActorID &actor_id, - const gcs::OptionalItemCallback &callback) override { - return Status::NotImplemented(""); - } - - Status AsyncGetCheckpointID( - const ActorID &actor_id, - const gcs::OptionalItemCallback &callback) override { - return Status::NotImplemented(""); - } - }; - - class MockedNodeInfoAccessor : public gcs::NodeInfoAccessor { - public: - Status RegisterSelf(const rpc::GcsNodeInfo &local_node_info) override { - return Status::NotImplemented(""); - } - - Status UnregisterSelf() override { return Status::NotImplemented(""); } - - const ClientID &GetSelfId() const override { - static ClientID node_id; - return node_id; - } - - const rpc::GcsNodeInfo &GetSelfInfo() const override { - static rpc::GcsNodeInfo node_info; - return node_info; - } - - Status AsyncRegister(const rpc::GcsNodeInfo &node_info, - const gcs::StatusCallback &callback) override { - return Status::NotImplemented(""); - } - - Status AsyncUnregister(const ClientID &node_id, - const gcs::StatusCallback &callback) override { - if (callback) { - callback(Status::OK()); - } - return Status::OK(); - } - - Status AsyncGetAll( - const gcs::MultiItemCallback &callback) override { - if (callback) { - callback(Status::OK(), {}); - } - return Status::OK(); - } - - Status AsyncSubscribeToNodeChange( - const gcs::SubscribeCallback &subscribe, - const gcs::StatusCallback &done) override { - return Status::NotImplemented(""); - } - - boost::optional Get(const ClientID &node_id) const override { - return boost::none; - } - - const std::unordered_map &GetAll() const override { - static std::unordered_map node_info_list; - return node_info_list; - } - - bool IsRemoved(const ClientID &node_id) const override { return false; } - - Status AsyncGetResources( - const ClientID &node_id, - const gcs::OptionalItemCallback &callback) override { - return Status::NotImplemented(""); - } - - Status AsyncUpdateResources(const ClientID &node_id, const ResourceMap &resources, - const gcs::StatusCallback &callback) override { - return Status::NotImplemented(""); - } - - Status AsyncDeleteResources(const ClientID &node_id, - const std::vector &resource_names, - const gcs::StatusCallback &callback) override { - return Status::NotImplemented(""); - } - - Status AsyncSubscribeToResources( - const gcs::SubscribeCallback - &subscribe, - const gcs::StatusCallback &done) override { - return Status::NotImplemented(""); - } - - Status AsyncReportHeartbeat(const std::shared_ptr &data_ptr, - const gcs::StatusCallback &callback) override { - return Status::NotImplemented(""); - } - - Status AsyncSubscribeHeartbeat( - const gcs::SubscribeCallback &subscribe, - const gcs::StatusCallback &done) override { - return Status::NotImplemented(""); - } - - Status AsyncReportBatchHeartbeat( - const std::shared_ptr &data_ptr, - const gcs::StatusCallback &callback) override { - if (callback) { - callback(Status::OK()); - } - return Status::OK(); - } - - Status AsyncSubscribeBatchHeartbeat( - const gcs::ItemCallback &subscribe, - const gcs::StatusCallback &done) override { - return Status::NotImplemented(""); - } - }; - - class MockedErrorInfoAccessor : public gcs::ErrorInfoAccessor { - public: - Status AsyncReportJobError(const std::shared_ptr &data_ptr, - const gcs::StatusCallback &callback) override { - if (callback) { - callback(Status::OK()); - } - return Status::OK(); - } - }; }; } // namespace ray