From 765d470c409dec3f599858b6a8ec7199eb8f2fa8 Mon Sep 17 00:00:00 2001 From: fangfengbin <869218239a@zju.edu.cn> Date: Mon, 25 May 2020 17:21:35 +0800 Subject: [PATCH] Add gcs object manager (#8298) --- BUILD.bazel | 9 +- python/ray/gcs_utils.py | 2 + python/ray/includes/global_state_accessor.pxd | 6 + python/ray/includes/global_state_accessor.pxi | 15 + python/ray/state.py | 80 +++--- python/ray/tests/test_advanced_3.py | 20 -- src/ray/gcs/accessor.h | 7 + .../gcs/gcs_client/global_state_accessor.cc | 40 +++ .../gcs/gcs_client/global_state_accessor.h | 19 +- .../gcs/gcs_client/service_based_accessor.cc | 18 ++ .../gcs/gcs_client/service_based_accessor.h | 2 + .../test/global_state_accessor_test.cc | 22 ++ src/ray/gcs/gcs_server/gcs_object_manager.cc | 271 ++++++++++++++++++ src/ray/gcs/gcs_server/gcs_object_manager.h | 138 +++++++++ src/ray/gcs/gcs_server/gcs_server.cc | 6 +- .../gcs_server/object_info_handler_impl.cc | 115 -------- .../gcs/gcs_server/object_info_handler_impl.h | 52 ---- src/ray/gcs/gcs_server/object_locator.cc | 129 --------- src/ray/gcs/gcs_server/object_locator.h | 93 ------ .../test/gcs_object_manager_test.cc | 154 ++++++++++ .../gcs_server/test/object_locator_test.cc | 99 ------- src/ray/gcs/redis_accessor.h | 5 + src/ray/protobuf/gcs.proto | 5 + src/ray/protobuf/gcs_service.proto | 16 +- src/ray/rpc/gcs_server/gcs_rpc_client.h | 4 + src/ray/rpc/gcs_server/gcs_rpc_server.h | 5 + 26 files changed, 768 insertions(+), 564 deletions(-) create mode 100644 src/ray/gcs/gcs_server/gcs_object_manager.cc create mode 100644 src/ray/gcs/gcs_server/gcs_object_manager.h delete mode 100644 src/ray/gcs/gcs_server/object_info_handler_impl.cc delete mode 100644 src/ray/gcs/gcs_server/object_info_handler_impl.h delete mode 100644 src/ray/gcs/gcs_server/object_locator.cc delete mode 100644 src/ray/gcs/gcs_server/object_locator.h create mode 100644 src/ray/gcs/gcs_server/test/gcs_object_manager_test.cc delete mode 100644 src/ray/gcs/gcs_server/test/object_locator_test.cc diff --git a/BUILD.bazel b/BUILD.bazel index 02f664bb9..8b23694ea 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -973,11 +973,16 @@ cc_test( ) cc_test( - name = "object_locator_test", - srcs = ["src/ray/gcs/gcs_server/test/object_locator_test.cc"], + name = "gcs_object_manager_test", + srcs = [ + "src/ray/gcs/gcs_server/test/gcs_object_manager_test.cc", + "src/ray/gcs/gcs_server/test/gcs_server_test_util.h", + ], copts = COPTS, deps = [ ":gcs_server_lib", + ":gcs_test_util_lib", + "@com_google_googletest//:gtest_main", ], ) diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index 42f7581ee..c0a7fdd22 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -14,6 +14,7 @@ from ray.core.generated.gcs_pb2 import ( TablePubsub, TaskTableData, ResourceTableData, + ObjectLocationInfo, ) __all__ = [ @@ -33,6 +34,7 @@ __all__ = [ "TaskTableData", "ResourceTableData", "construct_error_message", + "ObjectLocationInfo", ] FUNCTION_PREFIX = "RemoteFunction:" diff --git a/python/ray/includes/global_state_accessor.pxd b/python/ray/includes/global_state_accessor.pxd index a5aac7b87..6ef24596b 100644 --- a/python/ray/includes/global_state_accessor.pxd +++ b/python/ray/includes/global_state_accessor.pxd @@ -1,6 +1,10 @@ from libcpp.string cimport string as c_string from libcpp cimport bool as c_bool from libcpp.vector cimport vector as c_vector +from libcpp.memory cimport unique_ptr +from ray.includes.unique_ids cimport ( + CObjectID +) cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil: cdef cppclass CGlobalStateAccessor "ray::gcs::GlobalStateAccessor": @@ -11,3 +15,5 @@ cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil: void Disconnect() c_vector[c_string] GetAllJobInfo() c_vector[c_string] GetAllProfileInfo() + c_vector[c_string] GetAllObjectInfo() + unique_ptr[c_string] GetObjectInfo(const CObjectID &object_id) diff --git a/python/ray/includes/global_state_accessor.pxi b/python/ray/includes/global_state_accessor.pxi index 7400829eb..99e20aa7b 100644 --- a/python/ray/includes/global_state_accessor.pxi +++ b/python/ray/includes/global_state_accessor.pxi @@ -1,7 +1,13 @@ +from ray.includes.unique_ids cimport ( + CObjectID +) + from ray.includes.global_state_accessor cimport ( CGlobalStateAccessor, ) +from libcpp.string cimport string as c_string + cdef class GlobalStateAccessor: """Cython wrapper class of C++ `ray::gcs::GlobalStateAccessor`.""" cdef: @@ -25,3 +31,12 @@ cdef class GlobalStateAccessor: def get_profile_table(self): return self.inner.get().GetAllProfileInfo() + + def get_object_table(self): + return self.inner.get().GetAllObjectInfo() + + def get_object_info(self, object_id): + object_info = self.inner.get().GetObjectInfo(CObjectID.FromBinary(object_id.binary())) + if object_info: + return c_string(object_info.get().data(), object_info.get().size()) + return None diff --git a/python/ray/state.py b/python/ray/state.py index 8687f7b02..70538f811 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -10,8 +10,7 @@ from ray import ( gcs_utils, services, ) -from ray.utils import (decode, binary_to_object_id, binary_to_hex, - hex_to_binary) +from ray.utils import (decode, binary_to_hex, hex_to_binary) from ray._raylet import GlobalStateAccessor @@ -256,38 +255,6 @@ class GlobalState: result.extend(list(client.scan_iter(match=pattern))) return result - def _object_table(self, object_id): - """Fetch and parse the object table information for a single object ID. - - Args: - object_id: An object ID to get information about. - - Returns: - A dictionary with information about the object ID in question. - """ - # Allow the argument to be either an ObjectID or a hex string. - if not isinstance(object_id, ray.ObjectID): - object_id = ray.ObjectID(hex_to_binary(object_id)) - - # Return information about a single object ID. - message = self._execute_command(object_id, "RAY.TABLE_LOOKUP", - gcs_utils.TablePrefix.Value("OBJECT"), - "", object_id.binary()) - if message is None: - return {} - gcs_entry = gcs_utils.GcsEntry.FromString(message) - - assert len(gcs_entry.entries) > 0 - - entry = gcs_utils.ObjectTableData.FromString(gcs_entry.entries[0]) - - object_info = { - "DataSize": entry.object_size, - "Manager": entry.manager, - } - - return object_info - def object_table(self, object_id=None): """Fetch and parse the object table info for one or more object IDs. @@ -299,23 +266,42 @@ class GlobalState: Information from the object table. """ self._check_connected() - if object_id is not None: - # Return information about a single object ID. - return self._object_table(object_id) - else: - # Return the entire object table. - object_keys = self._keys(gcs_utils.TablePrefix_OBJECT_string + "*") - object_ids_binary = { - key[len(gcs_utils.TablePrefix_OBJECT_string):] - for key in object_keys - } + if object_id is not None: + object_id = ray.ObjectID(hex_to_binary(object_id)) + object_info = self.global_state_accessor.get_object_info(object_id) + if object_info is None: + return {} + else: + object_location_info = gcs_utils.ObjectLocationInfo.FromString( + object_info) + return self._gen_object_info(object_location_info) + else: + object_table = self.global_state_accessor.get_object_table() results = {} - for object_id_binary in object_ids_binary: - results[binary_to_object_id(object_id_binary)] = ( - self._object_table(binary_to_object_id(object_id_binary))) + for i in range(len(object_table)): + object_location_info = gcs_utils.ObjectLocationInfo.FromString( + object_table[i]) + results[binary_to_hex(object_location_info.object_id)] = \ + self._gen_object_info(object_location_info) return results + def _gen_object_info(self, object_location_info): + """Parse object location info. + Returns: + Information from object. + """ + locations = [] + for location in object_location_info.locations: + locations.append(ray.utils.binary_to_hex(location.manager)) + + object_info = { + "ObjectID": ray.utils.binary_to_hex( + object_location_info.object_id), + "Locations": locations, + } + return object_info + def _actor_table(self, actor_id): """Fetch and parse the actor table information for a single actor ID. diff --git a/python/ray/tests/test_advanced_3.py b/python/ray/tests/test_advanced_3.py index 340b427f8..9658920fd 100644 --- a/python/ray/tests/test_advanced_3.py +++ b/python/ray/tests/test_advanced_3.py @@ -511,26 +511,6 @@ def test_put_pins_object(ray_start_object_store_memory): ray.get(y_id) -@pytest.mark.parametrize( - "ray_start_object_store_memory", [150 * 1024 * 1024], indirect=True) -def test_redis_lru_with_set(ray_start_object_store_memory): - x = np.zeros(8 * 10**7, dtype=np.uint8) - x_id = ray.put(x, weakref=True) - - # Remove the object from the object table to simulate Redis LRU eviction. - removed = False - start_time = time.time() - while time.time() < start_time + 10: - if ray.state.state.redis_clients[0].delete(b"OBJECT" + - x_id.binary()) == 1: - removed = True - break - assert removed - - # Now evict the object from the object store. - ray.put(x) # This should not crash. - - def test_decorated_function(ray_start_regular): def function_invocation_decorator(f): def new_f(args, kwargs): diff --git a/src/ray/gcs/accessor.h b/src/ray/gcs/accessor.h index 4961a99bc..192da9b00 100644 --- a/src/ray/gcs/accessor.h +++ b/src/ray/gcs/accessor.h @@ -329,6 +329,13 @@ class ObjectInfoAccessor { const ObjectID &object_id, const MultiItemCallback &callback) = 0; + /// Get all object's locations from GCS asynchronously. + /// + /// \param callback Callback that will be called after lookup finished. + /// \return Status + virtual Status AsyncGetAll( + const MultiItemCallback &callback) = 0; + /// Add location of object to GCS asynchronously. /// /// \param object_id The ID of object which location will be added to GCS. diff --git a/src/ray/gcs/gcs_client/global_state_accessor.cc b/src/ray/gcs/gcs_client/global_state_accessor.cc index 977e3081f..1ee4cee2a 100644 --- a/src/ray/gcs/gcs_client/global_state_accessor.cc +++ b/src/ray/gcs/gcs_client/global_state_accessor.cc @@ -97,5 +97,45 @@ std::vector GlobalStateAccessor::GetAllProfileInfo() { return profile_table_data; } +std::vector GlobalStateAccessor::GetAllObjectInfo() { + std::vector all_object_info; + std::promise promise; + auto on_done = [&all_object_info, &promise]( + const Status &status, + const std::vector &result) { + RAY_CHECK_OK(status); + for (auto &data : result) { + all_object_info.push_back(data.SerializeAsString()); + } + promise.set_value(true); + }; + RAY_CHECK_OK(gcs_client_->Objects().AsyncGetAll(on_done)); + promise.get_future().get(); + return all_object_info; +} + +std::unique_ptr GlobalStateAccessor::GetObjectInfo( + const ObjectID &object_id) { + std::unique_ptr object_info; + std::promise promise; + auto on_done = [object_id, &object_info, &promise]( + const Status &status, + const std::vector &result) { + RAY_CHECK_OK(status); + if (!result.empty()) { + rpc::ObjectLocationInfo object_location_info; + object_location_info.set_object_id(object_id.Binary()); + for (auto &data : result) { + object_location_info.add_locations()->CopyFrom(data); + } + object_info.reset(new std::string(object_location_info.SerializeAsString())); + } + promise.set_value(true); + }; + RAY_CHECK_OK(gcs_client_->Objects().AsyncGetLocations(object_id, on_done)); + promise.get_future().get(); + return object_info; +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/gcs_client/global_state_accessor.h b/src/ray/gcs/gcs_client/global_state_accessor.h index 0191a8261..3d4e9b9df 100644 --- a/src/ray/gcs/gcs_client/global_state_accessor.h +++ b/src/ray/gcs/gcs_client/global_state_accessor.h @@ -46,8 +46,8 @@ class GlobalStateAccessor { /// Get information of all jobs from GCS Service. /// - /// \return All job info. To support multi-language, we serialized each JobTableData and - /// returned the serialized string. Where used, it needs to be deserialized with + /// \return All job info. To support multi-language, we serialize each JobTableData and + /// return the serialized string. Where used, it needs to be deserialized with /// protobuf function. std::vector GetAllJobInfo(); @@ -58,6 +58,21 @@ class GlobalStateAccessor { /// deserialized with protobuf function. std::vector GetAllProfileInfo(); + /// Get information of all objects from GCS Service. + /// + /// \return All object info. To support multi-language, we serialize each + /// ObjectTableData and return the serialized string. Where used, it needs to be + /// deserialized with protobuf function. + std::vector GetAllObjectInfo(); + + /// Get information of an object from GCS Service. + /// + /// \param object_id The ID of object to look up in the GCS Service. + /// \return Object info. To support multi-language, we serialize each ObjectTableData + /// and return the serialized string. Where used, it needs to be deserialized with + /// protobuf function. + std::unique_ptr GetObjectInfo(const ObjectID &object_id); + private: /// Whether this client is connected to gcs server. bool is_connected_{false}; diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc index c822332b2..b6676f452 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.cc +++ b/src/ray/gcs/gcs_client/service_based_accessor.cc @@ -968,6 +968,24 @@ Status ServiceBasedObjectInfoAccessor::AsyncGetLocations( return Status::OK(); } +Status ServiceBasedObjectInfoAccessor::AsyncGetAll( + const MultiItemCallback &callback) { + RAY_LOG(DEBUG) << "Getting all object locations."; + rpc::GetAllObjectLocationsRequest request; + client_impl_->GetGcsRpcClient().GetAllObjectLocations( + request, + [callback](const Status &status, const rpc::GetAllObjectLocationsReply &reply) { + std::vector result; + result.reserve((reply.object_location_info_list_size())); + for (int index = 0; index < reply.object_location_info_list_size(); ++index) { + result.emplace_back(reply.object_location_info_list(index)); + } + callback(status, result); + RAY_LOG(DEBUG) << "Finished getting all object locations, status = " << status; + }); + return Status::OK(); +} + Status ServiceBasedObjectInfoAccessor::AsyncAddLocation(const ObjectID &object_id, const ClientID &node_id, const StatusCallback &callback) { diff --git a/src/ray/gcs/gcs_client/service_based_accessor.h b/src/ray/gcs/gcs_client/service_based_accessor.h index 51f0026b0..ddd2b8d7d 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.h +++ b/src/ray/gcs/gcs_client/service_based_accessor.h @@ -266,6 +266,8 @@ class ServiceBasedObjectInfoAccessor : public ObjectInfoAccessor { const ObjectID &object_id, const MultiItemCallback &callback) override; + Status AsyncGetAll(const MultiItemCallback &callback) override; + Status AsyncAddLocation(const ObjectID &object_id, const ClientID &node_id, const StatusCallback &callback) override; diff --git a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc index 34f479ab7..5dacd564f 100644 --- a/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc +++ b/src/ray/gcs/gcs_client/test/global_state_accessor_test.cc @@ -120,6 +120,28 @@ TEST_F(GlobalStateAccessorTest, TestProfileTable) { ASSERT_EQ(global_state_->GetAllProfileInfo().size(), profile_count); } +TEST_F(GlobalStateAccessorTest, TestObjectTable) { + int object_count = 1; + ASSERT_EQ(global_state_->GetAllObjectInfo().size(), 0); + std::vector object_ids; + object_ids.reserve(object_count); + for (int index = 0; index < object_count; ++index) { + ObjectID object_id = ObjectID::FromRandom(); + object_ids.emplace_back(object_id); + ClientID node_id = ClientID::FromRandom(); + std::promise promise; + RAY_CHECK_OK(gcs_client_->Objects().AsyncAddLocation( + object_id, node_id, + [&promise](Status status) { promise.set_value(status.ok()); })); + WaitReady(promise.get_future(), timeout_ms_); + } + ASSERT_EQ(global_state_->GetAllObjectInfo().size(), object_count); + + for (auto &object_id : object_ids) { + ASSERT_TRUE(global_state_->GetObjectInfo(object_id)); + } +} + } // namespace ray int main(int argc, char **argv) { diff --git a/src/ray/gcs/gcs_server/gcs_object_manager.cc b/src/ray/gcs/gcs_server/gcs_object_manager.cc new file mode 100644 index 000000000..770ce6cfd --- /dev/null +++ b/src/ray/gcs/gcs_server/gcs_object_manager.cc @@ -0,0 +1,271 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gcs_object_manager.h" +#include "ray/gcs/pb_util.h" + +namespace ray { + +namespace gcs { + +void GcsObjectManager::HandleGetObjectLocations( + const rpc::GetObjectLocationsRequest &request, rpc::GetObjectLocationsReply *reply, + rpc::SendReplyCallback send_reply_callback) { + ObjectID object_id = ObjectID::FromBinary(request.object_id()); + RAY_LOG(DEBUG) << "Getting object locations, job id = " << object_id.TaskId().JobId() + << ", object id = " << object_id; + auto object_locations = GetObjectLocations(object_id); + for (auto &node_id : object_locations) { + rpc::ObjectTableData object_table_data; + object_table_data.set_manager(node_id.Binary()); + reply->add_object_table_data_list()->CopyFrom(object_table_data); + } + RAY_LOG(DEBUG) << "Finished getting object locations, job id = " + << object_id.TaskId().JobId() << ", object id = " << object_id; + GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); +} + +void GcsObjectManager::HandleGetAllObjectLocations( + const rpc::GetAllObjectLocationsRequest &request, + rpc::GetAllObjectLocationsReply *reply, rpc::SendReplyCallback send_reply_callback) { + RAY_LOG(DEBUG) << "Getting all object locations."; + absl::MutexLock lock(&mutex_); + for (auto &item : object_to_locations_) { + rpc::ObjectLocationInfo object_location_info; + object_location_info.set_object_id(item.first.Binary()); + for (auto &node_id : item.second) { + rpc::ObjectTableData object_table_data; + object_table_data.set_manager(node_id.Binary()); + object_location_info.add_locations()->CopyFrom(object_table_data); + } + reply->add_object_location_info_list()->CopyFrom(object_location_info); + } + RAY_LOG(DEBUG) << "Finished getting all object locations."; + GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); +} + +void GcsObjectManager::HandleAddObjectLocation( + const rpc::AddObjectLocationRequest &request, rpc::AddObjectLocationReply *reply, + rpc::SendReplyCallback send_reply_callback) { + ObjectID object_id = ObjectID::FromBinary(request.object_id()); + ClientID node_id = ClientID::FromBinary(request.node_id()); + RAY_LOG(DEBUG) << "Adding object location, job id = " << object_id.TaskId().JobId() + << ", object id = " << object_id << ", node id = " << node_id; + AddObjectLocationInCache(object_id, node_id); + + auto on_done = [this, object_id, node_id, reply, + send_reply_callback](const Status &status) { + if (status.ok()) { + RAY_CHECK_OK(gcs_pub_sub_->Publish( + OBJECT_CHANNEL, object_id.Hex(), + gcs::CreateObjectLocationChange(node_id, true)->SerializeAsString(), nullptr)); + RAY_LOG(DEBUG) << "Finished adding object location, job id = " + << object_id.TaskId().JobId() << ", object id = " << object_id + << ", node id = " << node_id << ", task id = " << object_id.TaskId(); + } else { + RAY_LOG(ERROR) << "Failed to add object location: " << status.ToString() + << ", job id = " << object_id.TaskId().JobId() + << ", object id = " << object_id << ", node id = " << node_id; + } + // We should only reply after the update is written to storage. + // So, if GCS server crashes before writing storage, GCS client will retry this + // request. + GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); + }; + + absl::MutexLock lock(&mutex_); + auto object_location_set = + GetObjectLocationSet(object_id, /* create_if_not_exist */ false); + auto object_table_data_list = GenObjectTableDataList(*object_location_set); + Status status = + gcs_table_storage_->ObjectTable().Put(object_id, *object_table_data_list, on_done); + if (!status.ok()) { + on_done(status); + } +} + +void GcsObjectManager::HandleRemoveObjectLocation( + const rpc::RemoveObjectLocationRequest &request, + rpc::RemoveObjectLocationReply *reply, rpc::SendReplyCallback send_reply_callback) { + ObjectID object_id = ObjectID::FromBinary(request.object_id()); + ClientID node_id = ClientID::FromBinary(request.node_id()); + RAY_LOG(DEBUG) << "Removing object location, job id = " << object_id.TaskId().JobId() + << ", object id = " << object_id << ", node id = " << node_id; + RemoveObjectLocationInCache(object_id, node_id); + + auto on_done = [this, object_id, node_id, reply, + send_reply_callback](const Status &status) { + if (status.ok()) { + RAY_CHECK_OK(gcs_pub_sub_->Publish( + OBJECT_CHANNEL, object_id.Hex(), + gcs::CreateObjectLocationChange(node_id, false)->SerializeAsString(), nullptr)); + RAY_LOG(DEBUG) << "Finished removing object location, job id = " + << object_id.TaskId().JobId() << ", object id = " << object_id + << ", node id = " << node_id; + } else { + RAY_LOG(ERROR) << "Failed to remove object location: " << status.ToString() + << ", job id = " << object_id.TaskId().JobId() + << ", object id = " << object_id << ", node id = " << node_id; + } + // We should only reply after the update is written to storage. + // So, if GCS server crashes before writing storage, GCS client will retry this + // request. + GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); + }; + + absl::MutexLock lock(&mutex_); + auto object_location_set = + GetObjectLocationSet(object_id, /* create_if_not_exist */ false); + Status status; + if (object_location_set != nullptr) { + auto object_table_data_list = GenObjectTableDataList(*object_location_set); + status = gcs_table_storage_->ObjectTable().Put(object_id, *object_table_data_list, + on_done); + } else { + status = gcs_table_storage_->ObjectTable().Delete(object_id, on_done); + } + + if (!status.ok()) { + on_done(status); + } +} + +void GcsObjectManager::AddObjectsLocation( + const ClientID &node_id, const absl::flat_hash_set &object_ids) { + // TODO(micafan) Optimize the lock when necessary. + // Maybe use read/write lock. Or reduce the granularity of the lock. + absl::MutexLock lock(&mutex_); + + auto *objects_on_node = GetObjectSetByNode(node_id, /* create_if_not_exist */ true); + objects_on_node->insert(object_ids.begin(), object_ids.end()); + + for (const auto &object_id : object_ids) { + auto *object_locations = + GetObjectLocationSet(object_id, /* create_if_not_exist */ true); + object_locations->emplace(node_id); + } +} + +void GcsObjectManager::AddObjectLocationInCache(const ObjectID &object_id, + const ClientID &node_id) { + absl::MutexLock lock(&mutex_); + + auto *objects_on_node = GetObjectSetByNode(node_id, /* create_if_not_exist */ true); + objects_on_node->emplace(object_id); + + auto *object_locations = + GetObjectLocationSet(object_id, /* create_if_not_exist */ true); + object_locations->emplace(node_id); +} + +absl::flat_hash_set GcsObjectManager::GetObjectLocations( + const ObjectID &object_id) { + absl::MutexLock lock(&mutex_); + + auto *object_locations = GetObjectLocationSet(object_id); + if (object_locations) { + return *object_locations; + } + return absl::flat_hash_set{}; +} + +void GcsObjectManager::OnNodeRemoved(const ClientID &node_id) { + absl::MutexLock lock(&mutex_); + + ObjectSet objects_on_node; + auto it = node_to_objects_.find(node_id); + if (it != node_to_objects_.end()) { + objects_on_node.swap(it->second); + node_to_objects_.erase(it); + } + + if (objects_on_node.empty()) { + return; + } + + for (const auto &object_id : objects_on_node) { + auto *object_locations = GetObjectLocationSet(object_id); + if (object_locations) { + object_locations->erase(node_id); + if (object_locations->empty()) { + object_to_locations_.erase(object_id); + } + } + } +} + +void GcsObjectManager::RemoveObjectLocationInCache(const ObjectID &object_id, + const ClientID &node_id) { + absl::MutexLock lock(&mutex_); + + auto *object_locations = GetObjectLocationSet(object_id); + if (object_locations) { + object_locations->erase(node_id); + if (object_locations->empty()) { + object_to_locations_.erase(object_id); + } + } + + auto *objects_on_node = GetObjectSetByNode(node_id); + if (objects_on_node) { + objects_on_node->erase(object_id); + if (objects_on_node->empty()) { + node_to_objects_.erase(node_id); + } + } +} + +GcsObjectManager::LocationSet *GcsObjectManager::GetObjectLocationSet( + const ObjectID &object_id, bool create_if_not_exist) { + LocationSet *object_locations = nullptr; + + auto it = object_to_locations_.find(object_id); + if (it != object_to_locations_.end()) { + object_locations = &it->second; + } else if (create_if_not_exist) { + auto ret = object_to_locations_.emplace(std::make_pair(object_id, LocationSet{})); + RAY_CHECK(ret.second); + object_locations = &(ret.first->second); + } + + return object_locations; +} + +GcsObjectManager::ObjectSet *GcsObjectManager::GetObjectSetByNode( + const ClientID &node_id, bool create_if_not_exist) { + ObjectSet *objects_on_node = nullptr; + + auto it = node_to_objects_.find(node_id); + if (it != node_to_objects_.end()) { + objects_on_node = &it->second; + } else if (create_if_not_exist) { + auto ret = node_to_objects_.emplace(std::make_pair(node_id, ObjectSet{})); + RAY_CHECK(ret.second); + objects_on_node = &(ret.first->second); + } + return objects_on_node; +} + +std::shared_ptr GcsObjectManager::GenObjectTableDataList( + const GcsObjectManager::LocationSet &location_set) const { + auto object_table_data_list = std::make_shared(); + for (auto &node_id : location_set) { + object_table_data_list->add_items()->set_manager(node_id.Binary()); + } + return object_table_data_list; +} + +} // namespace gcs + +} // namespace ray diff --git a/src/ray/gcs/gcs_server/gcs_object_manager.h b/src/ray/gcs/gcs_server/gcs_object_manager.h new file mode 100644 index 000000000..9fce7a0d0 --- /dev/null +++ b/src/ray/gcs/gcs_server/gcs_object_manager.h @@ -0,0 +1,138 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef RAY_GCS_OBJECT_MANAGER_H +#define RAY_GCS_OBJECT_MANAGER_H + +#include "gcs_node_manager.h" +#include "gcs_table_storage.h" +#include "ray/gcs/pubsub/gcs_pub_sub.h" +#include "ray/gcs/redis_gcs_client.h" + +namespace ray { + +namespace gcs { + +class GcsObjectManager : public rpc::ObjectInfoHandler { + public: + explicit GcsObjectManager(std::shared_ptr gcs_table_storage, + std::shared_ptr &gcs_pub_sub, + gcs::GcsNodeManager &gcs_node_manager) + : gcs_table_storage_(std::move(gcs_table_storage)), gcs_pub_sub_(gcs_pub_sub) { + gcs_node_manager.AddNodeRemovedListener( + [this](const std::shared_ptr &node) { + // All of the related actors should be reconstructed when a node is removed from + // the GCS. + OnNodeRemoved(ClientID::FromBinary(node->node_id())); + }); + } + + void HandleGetObjectLocations(const rpc::GetObjectLocationsRequest &request, + rpc::GetObjectLocationsReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + void HandleGetAllObjectLocations(const rpc::GetAllObjectLocationsRequest &request, + rpc::GetAllObjectLocationsReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + void HandleAddObjectLocation(const rpc::AddObjectLocationRequest &request, + rpc::AddObjectLocationReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + void HandleRemoveObjectLocation(const rpc::RemoveObjectLocationRequest &request, + rpc::RemoveObjectLocationReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + protected: + typedef absl::flat_hash_set LocationSet; + + /// Add a location of objects. + /// If the GCS server restarts, this function is used to reload data from storage. + /// + /// \param node_id The object location that will be added. + /// \param object_ids The ids of objects which location will be added. + void AddObjectsLocation(const ClientID &node_id, + const absl::flat_hash_set &object_ids) + LOCKS_EXCLUDED(mutex_); + + /// Add a new location for the given object in local cache. + /// + /// \param object_id The id of object. + /// \param node_id The node id of the new location. + void AddObjectLocationInCache(const ObjectID &object_id, const ClientID &node_id) + LOCKS_EXCLUDED(mutex_); + + /// Get all locations of the given object. + /// + /// \param object_id The id of object to lookup. + /// \return Object locations. + LocationSet GetObjectLocations(const ObjectID &object_id) LOCKS_EXCLUDED(mutex_); + + /// Handler if a node is removed. + /// + /// \param node_id The node that will be removed. + void OnNodeRemoved(const ClientID &node_id) LOCKS_EXCLUDED(mutex_); + + /// Remove object's location. + /// + /// \param object_id The id of the object which location will be removed. + /// \param node_id The location that will be removed. + void RemoveObjectLocationInCache(const ObjectID &object_id, const ClientID &node_id) + LOCKS_EXCLUDED(mutex_); + + private: + typedef absl::flat_hash_set ObjectSet; + + std::shared_ptr GenObjectTableDataList( + const GcsObjectManager::LocationSet &location_set) const; + + /// Get object locations by object id from map. + /// Will create it if not exist and the flag create_if_not_exist is set to true. + /// + /// \param object_id The id of object to lookup. + /// \param create_if_not_exist Whether to create a new one if not exist. + /// \return LocationSet * + GcsObjectManager::LocationSet *GetObjectLocationSet(const ObjectID &object_id, + bool create_if_not_exist = false) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + /// Get objects by node id from map. + /// Will create it if not exist and the flag create_if_not_exist is set to true. + /// + /// \param node_id The id of node to lookup. + /// \param create_if_not_exist Whether to create a new one if not exist. + /// \return ObjectSet * + GcsObjectManager::ObjectSet *GetObjectSetByNode(const ClientID &node_id, + bool create_if_not_exist = false) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + mutable absl::Mutex mutex_; + + /// Mapping from object id to object locations. + /// This is the local cache of objects' locations in the storage. + absl::flat_hash_map object_to_locations_ GUARDED_BY(mutex_); + + /// Mapping from node id to objects that held by the node. + /// This is the local cache of nodes' objects in the storage. + absl::flat_hash_map node_to_objects_ GUARDED_BY(mutex_); + + std::shared_ptr gcs_table_storage_; + std::shared_ptr gcs_pub_sub_; +}; + +} // namespace gcs + +} // namespace ray + +#endif // RAY_GCS_OBJECT_MANAGER_H diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 9f5bad70f..2a6829447 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -18,8 +18,8 @@ #include "error_info_handler_impl.h" #include "gcs_actor_manager.h" #include "gcs_node_manager.h" +#include "gcs_object_manager.h" #include "job_info_handler_impl.h" -#include "object_info_handler_impl.h" #include "ray/common/network_util.h" #include "ray/common/ray_config.h" #include "stats_handler_impl.h" @@ -205,8 +205,8 @@ std::unique_ptr GcsServer::InitActorInfoHandler() { } std::unique_ptr GcsServer::InitObjectInfoHandler() { - return std::unique_ptr( - new rpc::DefaultObjectInfoHandler(*redis_gcs_client_, gcs_pub_sub_)); + return std::unique_ptr( + new GcsObjectManager(gcs_table_storage_, gcs_pub_sub_, *gcs_node_manager_)); } void GcsServer::StoreGcsServerAddressInRedis() { diff --git a/src/ray/gcs/gcs_server/object_info_handler_impl.cc b/src/ray/gcs/gcs_server/object_info_handler_impl.cc deleted file mode 100644 index f7a71f07f..000000000 --- a/src/ray/gcs/gcs_server/object_info_handler_impl.cc +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "object_info_handler_impl.h" -#include "ray/gcs/pb_util.h" -#include "ray/util/logging.h" - -namespace ray { -namespace rpc { - -void DefaultObjectInfoHandler::HandleGetObjectLocations( - const GetObjectLocationsRequest &request, GetObjectLocationsReply *reply, - SendReplyCallback send_reply_callback) { - ObjectID object_id = ObjectID::FromBinary(request.object_id()); - RAY_LOG(DEBUG) << "Getting object locations, job id = " << object_id.TaskId().JobId() - << ", object id = " << object_id; - - auto on_done = [reply, object_id, send_reply_callback]( - const Status &status, - const std::vector &result) { - if (status.ok()) { - for (const rpc::ObjectTableData &object_table_data : result) { - reply->add_object_table_data_list()->CopyFrom(object_table_data); - } - RAY_LOG(DEBUG) << "Finished getting object locations, job id = " - << object_id.TaskId().JobId() << ", object id = " << object_id; - } else { - RAY_LOG(ERROR) << "Failed to get object locations: " << status.ToString() - << ", job id = " << object_id.TaskId().JobId() - << ", object id = " << object_id; - } - GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); - }; - - Status status = gcs_client_.Objects().AsyncGetLocations(object_id, on_done); - if (!status.ok()) { - on_done(status, std::vector()); - } -} - -void DefaultObjectInfoHandler::HandleAddObjectLocation( - const AddObjectLocationRequest &request, AddObjectLocationReply *reply, - SendReplyCallback send_reply_callback) { - ObjectID object_id = ObjectID::FromBinary(request.object_id()); - ClientID node_id = ClientID::FromBinary(request.node_id()); - RAY_LOG(DEBUG) << "Adding object location, job id = " << object_id.TaskId().JobId() - << ", object id = " << object_id << ", node id = " << node_id; - - auto on_done = [this, object_id, node_id, reply, - send_reply_callback](const Status &status) { - if (status.ok()) { - RAY_CHECK_OK(gcs_pub_sub_->Publish( - OBJECT_CHANNEL, object_id.Hex(), - gcs::CreateObjectLocationChange(node_id, true)->SerializeAsString(), nullptr)); - RAY_LOG(DEBUG) << "Finished adding object location, job id = " - << object_id.TaskId().JobId() << ", object id = " << object_id - << ", node id = " << node_id << ", task id = " << object_id.TaskId(); - } else { - RAY_LOG(ERROR) << "Failed to add object location: " << status.ToString() - << ", job id = " << object_id.TaskId().JobId() - << ", object id = " << object_id << ", node id = " << node_id; - } - GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); - }; - - Status status = gcs_client_.Objects().AsyncAddLocation(object_id, node_id, on_done); - if (!status.ok()) { - on_done(status); - } -} - -void DefaultObjectInfoHandler::HandleRemoveObjectLocation( - const RemoveObjectLocationRequest &request, RemoveObjectLocationReply *reply, - SendReplyCallback send_reply_callback) { - ObjectID object_id = ObjectID::FromBinary(request.object_id()); - ClientID node_id = ClientID::FromBinary(request.node_id()); - RAY_LOG(DEBUG) << "Removing object location, job id = " << object_id.TaskId().JobId() - << ", object id = " << object_id << ", node id = " << node_id; - - auto on_done = [this, object_id, node_id, reply, - send_reply_callback](const Status &status) { - if (status.ok()) { - RAY_CHECK_OK(gcs_pub_sub_->Publish( - OBJECT_CHANNEL, object_id.Hex(), - gcs::CreateObjectLocationChange(node_id, false)->SerializeAsString(), nullptr)); - RAY_LOG(DEBUG) << "Finished removing object location, job id = " - << object_id.TaskId().JobId() << ", object id = " << object_id - << ", node id = " << node_id; - } else { - RAY_LOG(ERROR) << "Failed to remove object location: " << status.ToString() - << ", job id = " << object_id.TaskId().JobId() - << ", object id = " << object_id << ", node id = " << node_id; - } - GCS_RPC_SEND_REPLY(send_reply_callback, reply, status); - }; - - Status status = gcs_client_.Objects().AsyncRemoveLocation(object_id, node_id, on_done); - if (!status.ok()) { - on_done(status); - } -} - -} // namespace rpc -} // namespace ray diff --git a/src/ray/gcs/gcs_server/object_info_handler_impl.h b/src/ray/gcs/gcs_server/object_info_handler_impl.h deleted file mode 100644 index f714cef8c..000000000 --- a/src/ray/gcs/gcs_server/object_info_handler_impl.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2017 The Ray Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef RAY_GCS_OBJECT_INFO_HANDLER_IMPL_H -#define RAY_GCS_OBJECT_INFO_HANDLER_IMPL_H - -#include "ray/gcs/pubsub/gcs_pub_sub.h" -#include "ray/gcs/redis_gcs_client.h" -#include "ray/rpc/gcs_server/gcs_rpc_server.h" - -namespace ray { -namespace rpc { - -/// This implementation class of `ObjectInfoHandler`. -class DefaultObjectInfoHandler : public rpc::ObjectInfoHandler { - public: - explicit DefaultObjectInfoHandler(gcs::RedisGcsClient &gcs_client, - std::shared_ptr &gcs_pub_sub) - : gcs_client_(gcs_client), gcs_pub_sub_(gcs_pub_sub) {} - - void HandleGetObjectLocations(const GetObjectLocationsRequest &request, - GetObjectLocationsReply *reply, - SendReplyCallback send_reply_callback) override; - - void HandleAddObjectLocation(const AddObjectLocationRequest &request, - AddObjectLocationReply *reply, - SendReplyCallback send_reply_callback) override; - - void HandleRemoveObjectLocation(const RemoveObjectLocationRequest &request, - RemoveObjectLocationReply *reply, - SendReplyCallback send_reply_callback) override; - - private: - gcs::RedisGcsClient &gcs_client_; - std::shared_ptr gcs_pub_sub_; -}; - -} // namespace rpc -} // namespace ray - -#endif // RAY_GCS_OBJECT_INFO_HANDLER_IMPL_H diff --git a/src/ray/gcs/gcs_server/object_locator.cc b/src/ray/gcs/gcs_server/object_locator.cc deleted file mode 100644 index 3c2e4f794..000000000 --- a/src/ray/gcs/gcs_server/object_locator.cc +++ /dev/null @@ -1,129 +0,0 @@ -#include "ray/gcs/gcs_server/object_locator.h" - -namespace ray { - -namespace gcs { - -ObjectLocator::ObjectLocator() {} - -ObjectLocator::~ObjectLocator() {} - -void ObjectLocator::AddObjectsLocation(const ClientID &node_id, - const std::unordered_set &object_ids) { - // TODO(micafan) Optimize the lock when necessary. - // Maybe use read/write lock. Or reduce the granularity of the lock. - absl::MutexLock lock(&mutex_); - - auto *node_hold_objects = GetNodeHoldObjectSet(node_id, /* create_if_not_exist */ true); - node_hold_objects->insert(object_ids.begin(), object_ids.end()); - - for (const auto &object_id : object_ids) { - auto *object_locations = - GetObjectLocationSet(object_id, /* create_if_not_exist */ true); - object_locations->emplace(node_id); - } -} - -void ObjectLocator::AddObjectLocation(const ObjectID &object_id, - const ClientID &node_id) { - absl::MutexLock lock(&mutex_); - - auto *node_hold_objects = GetNodeHoldObjectSet(node_id, /* create_if_not_exist */ true); - node_hold_objects->emplace(object_id); - - auto *object_locations = - GetObjectLocationSet(object_id, /* create_if_not_exist */ true); - object_locations->emplace(node_id); -} - -std::unordered_set ObjectLocator::GetObjectLocations( - const ObjectID &object_id) { - absl::MutexLock lock(&mutex_); - - auto *object_locations = GetObjectLocationSet(object_id); - if (object_locations) { - return *object_locations; - } - return std::unordered_set{}; -} - -void ObjectLocator::RemoveNode(const ClientID &node_id) { - absl::MutexLock lock(&mutex_); - - ObjectSet node_hold_objects; - auto it = node_to_objects_.find(node_id); - if (it != node_to_objects_.end()) { - node_hold_objects.swap(it->second); - node_to_objects_.erase(it); - } - - if (node_hold_objects.empty()) { - return; - } - - for (const auto &object_id : node_hold_objects) { - auto *object_locations = GetObjectLocationSet(object_id); - if (object_locations) { - object_locations->erase(node_id); - if (object_locations->empty()) { - object_to_locations_.erase(object_id); - } - } - } -} - -void ObjectLocator::RemoveObjectLocation(const ObjectID &object_id, - const ClientID &node_id) { - absl::MutexLock lock(&mutex_); - - auto *object_locations = GetObjectLocationSet(object_id); - if (object_locations) { - object_locations->erase(node_id); - if (object_locations->empty()) { - object_to_locations_.erase(object_id); - } - } - - auto *node_hold_objects = GetNodeHoldObjectSet(node_id); - if (node_hold_objects) { - node_hold_objects->erase(object_id); - if (node_hold_objects->empty()) { - node_to_objects_.erase(node_id); - } - } -} - -ObjectLocator::LocationSet *ObjectLocator::GetObjectLocationSet( - const ObjectID &object_id, bool create_if_not_exist) { - LocationSet *object_locations = nullptr; - - auto it = object_to_locations_.find(object_id); - if (it != object_to_locations_.end()) { - object_locations = &it->second; - } else if (create_if_not_exist) { - auto ret = object_to_locations_.emplace(std::make_pair(object_id, LocationSet{})); - RAY_CHECK(ret.second); - object_locations = &(ret.first->second); - } - - return object_locations; -} - -ObjectLocator::ObjectSet *ObjectLocator::GetNodeHoldObjectSet(const ClientID &node_id, - bool create_if_not_exist) { - ObjectSet *node_hold_objects = nullptr; - - auto it = node_to_objects_.find(node_id); - if (it != node_to_objects_.end()) { - node_hold_objects = &it->second; - } else if (create_if_not_exist) { - auto ret = node_to_objects_.emplace(std::make_pair(node_id, ObjectSet{})); - RAY_CHECK(ret.second); - node_hold_objects = &(ret.first->second); - } - return node_hold_objects; -} - -} // namespace gcs - -} // namespace ray diff --git a/src/ray/gcs/gcs_server/object_locator.h b/src/ray/gcs/gcs_server/object_locator.h deleted file mode 100644 index b074908e8..000000000 --- a/src/ray/gcs/gcs_server/object_locator.h +++ /dev/null @@ -1,93 +0,0 @@ -#ifndef GCS_GCS_SERVER_OBJECT_LOCATOR_H -#define GCS_GCS_SERVER_OBJECT_LOCATOR_H - -#include -#include -#include -#include "absl/base/thread_annotations.h" -#include "absl/synchronization/mutex.h" -#include "ray/common/id.h" -#include "ray/util/logging.h" - -namespace ray { - -namespace gcs { - -class ObjectLocator { - public: - ObjectLocator(); - - ~ObjectLocator(); - - /// Add a location of objects. - /// - /// \param node_id The object location that will be added. - /// \param object_ids The ids of objects which location will be added. - void AddObjectsLocation(const ClientID &node_id, - const std::unordered_set &object_ids) - LOCKS_EXCLUDED(mutex_); - - /// Add a location of an object. - /// - /// \param object_id The id of object which location will be added. - /// \param node_id The object location that will be added. - void AddObjectLocation(const ObjectID &object_id, const ClientID &node_id) - LOCKS_EXCLUDED(mutex_); - - /// Get object's locations. - /// - /// \param object_id The id of object to lookup. - /// \return Object locations. - std::unordered_set GetObjectLocations(const ObjectID &object_id) - LOCKS_EXCLUDED(mutex_); - - /// Remove a node. - /// - /// \param node_id The node that will be removed. - void RemoveNode(const ClientID &node_id) LOCKS_EXCLUDED(mutex_); - - /// Remove object's location. - /// - /// \param object_id The id of the object which location will be removed. - /// \param node_id The location that will be removed. - void RemoveObjectLocation(const ObjectID &object_id, const ClientID &node_id) - LOCKS_EXCLUDED(mutex_); - - private: - typedef std::unordered_set LocationSet; - typedef std::unordered_set ObjectSet; - - /// Get object locations by object id from map. - /// Will create it if not exist and the flag create_if_not_exist is set to true. - /// - /// \param object_id The id of object to lookup. - /// \param create_if_not_exist Whether to create a new one if not exist. - /// \return LocationSet * - ObjectLocator::LocationSet *GetObjectLocationSet(const ObjectID &object_id, - bool create_if_not_exist = false) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - /// Get objects by node id from map. - /// Will create it if not exist and the flag create_if_not_exist is set to true. - /// - /// \param node_id The id of node to lookup. - /// \param create_if_not_exist Whether to create a new one if not exist. - /// \return ObjectSet * - ObjectLocator::ObjectSet *GetNodeHoldObjectSet(const ClientID &node_id, - bool create_if_not_exist = false) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - mutable absl::Mutex mutex_; - - /// Mapping from object id to object locations. - std::unordered_map object_to_locations_ GUARDED_BY(mutex_); - - /// Mapping from node id to objects that held by the node. - std::unordered_map node_to_objects_ GUARDED_BY(mutex_); -}; - -} // namespace gcs - -} // namespace ray - -#endif // GCS_GCS_SERVER_OBJECT_LOCATOR_H diff --git a/src/ray/gcs/gcs_server/test/gcs_object_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_object_manager_test.cc new file mode 100644 index 000000000..25afaf348 --- /dev/null +++ b/src/ray/gcs/gcs_server/test/gcs_object_manager_test.cc @@ -0,0 +1,154 @@ +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ray/gcs/gcs_server/gcs_object_manager.h" +#include +#include +#include "gtest/gtest.h" + +namespace ray { + +class MockedGcsObjectManager : public gcs::GcsObjectManager { + public: + explicit MockedGcsObjectManager(std::shared_ptr gcs_table_storage, + std::shared_ptr &gcs_pub_sub, + gcs::GcsNodeManager &gcs_node_manager) + : gcs::GcsObjectManager(gcs_table_storage, gcs_pub_sub, gcs_node_manager) {} + + public: + void AddObjectsLocation(const ClientID &node_id, + const absl::flat_hash_set &object_ids) { + gcs::GcsObjectManager::AddObjectsLocation(node_id, object_ids); + } + + void AddObjectLocationInCache(const ObjectID &object_id, const ClientID &node_id) { + gcs::GcsObjectManager::AddObjectLocationInCache(object_id, node_id); + } + + absl::flat_hash_set GetObjectLocations(const ObjectID &object_id) { + return gcs::GcsObjectManager::GetObjectLocations(object_id); + } + + void OnNodeRemoved(const ClientID &node_id) { + gcs::GcsObjectManager::OnNodeRemoved(node_id); + } + + void RemoveObjectLocationInCache(const ObjectID &object_id, const ClientID &node_id) { + gcs::GcsObjectManager::RemoveObjectLocationInCache(object_id, node_id); + } +}; + +class GcsObjectManagerTest : public ::testing::Test { + public: + void SetUp() override { + gcs_table_storage_ = std::make_shared(io_service_); + gcs_node_manager_ = std::make_shared( + io_service_, node_info_accessor_, error_info_accessor_, gcs_pub_sub_); + gcs_object_manager_ = std::make_shared( + gcs_table_storage_, gcs_pub_sub_, *gcs_node_manager_); + GenTestData(); + } + + void GenTestData() { + for (size_t i = 0; i < object_count_; ++i) { + ObjectID object_id = ObjectID::FromRandom(); + object_ids_.emplace(object_id); + } + for (size_t i = 0; i < node_count_; ++i) { + ClientID node_id = ClientID::FromRandom(); + node_ids_.emplace(node_id); + } + } + + void CheckLocations(const absl::flat_hash_set &locations) { + ASSERT_EQ(locations.size(), node_ids_.size()); + for (const auto &location : locations) { + auto it = node_ids_.find(location); + ASSERT_TRUE(it != node_ids_.end()); + ASSERT_TRUE(location == *it); + } + } + + protected: + boost::asio::io_service io_service_; + GcsServerMocker::MockedNodeInfoAccessor node_info_accessor_; + GcsServerMocker::MockedErrorInfoAccessor error_info_accessor_; + std::shared_ptr gcs_node_manager_; + std::shared_ptr gcs_client_; + std::shared_ptr gcs_pub_sub_; + std::shared_ptr gcs_object_manager_; + std::shared_ptr gcs_table_storage_; + + size_t object_count_{5}; + size_t node_count_{10}; + absl::flat_hash_set object_ids_; + absl::flat_hash_set node_ids_; +}; + +TEST_F(GcsObjectManagerTest, AddObjectsLocationAndGetLocationTest) { + for (const auto &node_id : node_ids_) { + gcs_object_manager_->AddObjectsLocation(node_id, object_ids_); + } + for (const auto &object_id : object_ids_) { + auto locations = gcs_object_manager_->GetObjectLocations(object_id); + CheckLocations(locations); + } +} + +TEST_F(GcsObjectManagerTest, AddObjectLocationInCacheTest) { + for (const auto &object_id : object_ids_) { + for (const auto &node_id : node_ids_) { + gcs_object_manager_->AddObjectLocationInCache(object_id, node_id); + } + } + + for (const auto &object_id : object_ids_) { + auto locations = gcs_object_manager_->GetObjectLocations(object_id); + CheckLocations(locations); + } +} + +TEST_F(GcsObjectManagerTest, RemoveNodeTest) { + for (const auto &node_id : node_ids_) { + gcs_object_manager_->AddObjectsLocation(node_id, object_ids_); + } + + gcs_object_manager_->OnNodeRemoved(*node_ids_.begin()); + auto locations = gcs_object_manager_->GetObjectLocations(*object_ids_.begin()); + ASSERT_EQ(locations.size() + 1, node_ids_.size()); + + locations.emplace(*node_ids_.begin()); + ASSERT_EQ(locations.size(), node_ids_.size()); +} + +TEST_F(GcsObjectManagerTest, RemoveObjectLocationTest) { + for (const auto &node_id : node_ids_) { + gcs_object_manager_->AddObjectsLocation(node_id, object_ids_); + } + + gcs_object_manager_->RemoveObjectLocationInCache(*object_ids_.begin(), + *node_ids_.begin()); + auto locations = gcs_object_manager_->GetObjectLocations(*object_ids_.begin()); + ASSERT_EQ(locations.size() + 1, node_ids_.size()); + + locations.emplace(*node_ids_.begin()); + ASSERT_EQ(locations.size(), node_ids_.size()); +} + +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/gcs/gcs_server/test/object_locator_test.cc b/src/ray/gcs/gcs_server/test/object_locator_test.cc deleted file mode 100644 index 50e4b5585..000000000 --- a/src/ray/gcs/gcs_server/test/object_locator_test.cc +++ /dev/null @@ -1,99 +0,0 @@ -#include "ray/gcs/gcs_server/object_locator.h" -#include "gtest/gtest.h" - -namespace ray { - -namespace gcs { - -class ObjectLocatorTest : public ::testing::Test { - public: - ObjectLocatorTest() {} - - void SetUp() override { GenTestData(); } - - void GenTestData() { - for (size_t i = 0; i < object_count_; ++i) { - ObjectID object_id = ObjectID::FromRandom(); - object_ids_.emplace(object_id); - } - for (size_t i = 0; i < node_count_; ++i) { - ClientID node_id = ClientID::FromRandom(); - node_ids_.emplace(node_id); - } - } - - void CheckLocations(const std::unordered_set &locations) { - ASSERT_EQ(locations.size(), node_ids_.size()); - for (const auto &location : locations) { - auto it = node_ids_.find(location); - ASSERT_TRUE(it != node_ids_.end()); - ASSERT_TRUE(location == *it); - } - } - - protected: - ObjectLocator object_locator_; - - size_t object_count_{5}; - size_t node_count_{10}; - std::unordered_set object_ids_; - std::unordered_set node_ids_; -}; - -TEST_F(ObjectLocatorTest, AddObjectsLocationAndGetLocationTest) { - for (const auto &node_id : node_ids_) { - object_locator_.AddObjectsLocation(node_id, object_ids_); - } - for (const auto &object_id : object_ids_) { - auto locations = object_locator_.GetObjectLocations(object_id); - CheckLocations(locations); - } -} - -TEST_F(ObjectLocatorTest, AddObjectLocationTest) { - for (const auto &object_id : object_ids_) { - for (const auto &node_id : node_ids_) { - object_locator_.AddObjectLocation(object_id, node_id); - } - } - - for (const auto &object_id : object_ids_) { - auto locations = object_locator_.GetObjectLocations(object_id); - CheckLocations(locations); - } -} - -TEST_F(ObjectLocatorTest, RemoveNodeTest) { - for (const auto &node_id : node_ids_) { - object_locator_.AddObjectsLocation(node_id, object_ids_); - } - - object_locator_.RemoveNode(*node_ids_.begin()); - auto locations = object_locator_.GetObjectLocations(*object_ids_.begin()); - ASSERT_EQ(locations.size() + 1, node_ids_.size()); - - locations.emplace(*node_ids_.begin()); - ASSERT_EQ(locations.size(), node_ids_.size()); -} - -TEST_F(ObjectLocatorTest, RemoveObjectLocationTest) { - for (const auto &node_id : node_ids_) { - object_locator_.AddObjectsLocation(node_id, object_ids_); - } - - object_locator_.RemoveObjectLocation(*object_ids_.begin(), *node_ids_.begin()); - auto locations = object_locator_.GetObjectLocations(*object_ids_.begin()); - ASSERT_EQ(locations.size() + 1, node_ids_.size()); - - locations.emplace(*node_ids_.begin()); - ASSERT_EQ(locations.size(), node_ids_.size()); -} - -} // namespace gcs - -} // namespace ray - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/src/ray/gcs/redis_accessor.h b/src/ray/gcs/redis_accessor.h index 15e5d458f..f1d9dbd78 100644 --- a/src/ray/gcs/redis_accessor.h +++ b/src/ray/gcs/redis_accessor.h @@ -274,6 +274,11 @@ class RedisObjectInfoAccessor : public ObjectInfoAccessor { Status AsyncGetLocations(const ObjectID &object_id, const MultiItemCallback &callback) override; + Status AsyncGetAll( + const MultiItemCallback &callback) override { + return Status::NotImplemented("AsyncGetAll not implemented"); + } + Status AsyncAddLocation(const ObjectID &object_id, const ClientID &node_id, const StatusCallback &callback) override; diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index 2f76b12cd..ff35a40c7 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -314,6 +314,11 @@ message ObjectTableDataList { repeated ObjectTableData items = 1; } +message ObjectLocationInfo { + bytes object_id = 1; + repeated ObjectTableData locations = 2; +} + // A notification message about one object's locations being changed. message ObjectLocationChange { bool is_add = 1; diff --git a/src/ray/protobuf/gcs_service.proto b/src/ray/protobuf/gcs_service.proto index 70669f19a..10e1f4cf3 100644 --- a/src/ray/protobuf/gcs_service.proto +++ b/src/ray/protobuf/gcs_service.proto @@ -141,7 +141,7 @@ message GetActorCheckpointIDReply { // Service for actor info access. service ActorInfoGcsService { - // Create actor via gcs service + // Create actor via gcs service. rpc CreateActor(CreateActorRequest) returns (CreateActorReply); // Get actor data from GCS Service by actor id. rpc GetActorInfo(GetActorInfoRequest) returns (GetActorInfoReply); @@ -256,10 +256,19 @@ message GetObjectLocationsRequest { message GetObjectLocationsReply { GcsStatus status = 1; - // Data of object + // Data of object. repeated ObjectTableData object_table_data_list = 2; } +message GetAllObjectLocationsRequest { +} + +message GetAllObjectLocationsReply { + GcsStatus status = 1; + // Data of object location info. + repeated ObjectLocationInfo object_location_info_list = 2; +} + message AddObjectLocationRequest { // The ID of object which location will be added to GCS Service. bytes object_id = 1; @@ -286,6 +295,9 @@ message RemoveObjectLocationReply { service ObjectInfoGcsService { // Get object's locations from GCS Service. rpc GetObjectLocations(GetObjectLocationsRequest) returns (GetObjectLocationsReply); + // Get all object's locations from GCS Service. + rpc GetAllObjectLocations(GetAllObjectLocationsRequest) + returns (GetAllObjectLocationsReply); // Add location of object to GCS Service. rpc AddObjectLocation(AddObjectLocationRequest) returns (AddObjectLocationReply); // Remove location of object from GCS Service. diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index 9da56705c..2584e6803 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -165,6 +165,10 @@ class GcsRpcClient { VOID_GCS_RPC_CLIENT_METHOD(ObjectInfoGcsService, GetObjectLocations, object_info_grpc_client_, ) + /// Get all object's locations from GCS Service. + VOID_GCS_RPC_CLIENT_METHOD(ObjectInfoGcsService, GetAllObjectLocations, + object_info_grpc_client_, ) + /// Add location of object to GCS Service. VOID_GCS_RPC_CLIENT_METHOD(ObjectInfoGcsService, AddObjectLocation, object_info_grpc_client_, ) diff --git a/src/ray/rpc/gcs_server/gcs_rpc_server.h b/src/ray/rpc/gcs_server/gcs_rpc_server.h index dfd4cbc18..23d84a3d6 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_server.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_server.h @@ -243,6 +243,10 @@ class ObjectInfoGcsServiceHandler { GetObjectLocationsReply *reply, SendReplyCallback send_reply_callback) = 0; + virtual void HandleGetAllObjectLocations(const GetAllObjectLocationsRequest &request, + GetAllObjectLocationsReply *reply, + SendReplyCallback send_reply_callback) = 0; + virtual void HandleAddObjectLocation(const AddObjectLocationRequest &request, AddObjectLocationReply *reply, SendReplyCallback send_reply_callback) = 0; @@ -269,6 +273,7 @@ class ObjectInfoGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories) override { OBJECT_INFO_SERVICE_RPC_HANDLER(GetObjectLocations); + OBJECT_INFO_SERVICE_RPC_HANDLER(GetAllObjectLocations); OBJECT_INFO_SERVICE_RPC_HANDLER(AddObjectLocation); OBJECT_INFO_SERVICE_RPC_HANDLER(RemoveObjectLocation); }