diff --git a/BUILD.bazel b/BUILD.bazel index ed0af09ae..8867ff384 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -240,16 +240,6 @@ cc_library( ], ) -cc_test( - name = "common_test", - srcs = glob(["src/ray/common/**/*_test.cc"]), - copts = COPTS, - deps = [ - ":ray_common", - "@com_google_googletest//:gtest_main", - ], -) - cc_binary( name = "raylet", srcs = ["src/ray/raylet/main.cc"], @@ -491,6 +481,16 @@ cc_test( ], ) +cc_test( + name = "sample_test", + srcs = ["src/ray/util/sample_test.cc"], + copts = COPTS, + deps = [ + ":ray_common", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "task_dependency_manager_test", srcs = ["src/ray/raylet/task_dependency_manager_test.cc"], @@ -595,6 +595,7 @@ cc_library( deps = [ ":sha256", "@com_github_google_glog//:glog", + "@com_google_absl//absl/random", "@plasma//:plasma_client", ], ) diff --git a/bazel/ray_deps_setup.bzl b/bazel/ray_deps_setup.bzl index f1a0818fb..944c4994b 100644 --- a/bazel/ray_deps_setup.bzl +++ b/bazel/ray_deps_setup.bzl @@ -88,7 +88,7 @@ def ray_deps_setup(): # This is how diamond dependencies are prevented. git_repository( name = "com_google_absl", - commit = "5b65c4af5107176555b23a638e5947686410ac1f", + commit = "aa844899c937bde5d2b24f276b59997e5b668bde", remote = "https://github.com/abseil/abseil-cpp.git", ) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 93f25ab32..28afc8588 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -725,3 +725,15 @@ cdef class CoreWorker: check_status(self.core_worker.get().SerializeActorHandle( c_actor_id, &output)) return output + + def add_active_object_id(self, ObjectID object_id): + cdef: + CObjectID c_object_id = object_id.native() + with nogil: + self.core_worker.get().AddActiveObjectID(c_object_id) + + def remove_active_object_id(self, ObjectID object_id): + cdef: + CObjectID c_object_id = object_id.native() + with nogil: + self.core_worker.get().RemoveActiveObjectID(c_object_id) diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 5c992b810..a1e0aff84 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -90,3 +90,5 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: CActorID DeserializeAndRegisterActorHandle(const c_string &bytes) CRayStatus SerializeActorHandle(const CActorID &actor_id, c_string *bytes) + void AddActiveObjectID(const CObjectID &object_id) + void RemoveActiveObjectID(const CObjectID &object_id) diff --git a/python/ray/includes/ray_config.pxd b/python/ray/includes/ray_config.pxd index 4d3a76ec7..fa412f527 100644 --- a/python/ray/includes/ray_config.pxd +++ b/python/ray/includes/ray_config.pxd @@ -12,7 +12,7 @@ cdef extern from "ray/common/ray_config.h" nogil: int64_t handler_warning_timeout_ms() const - int64_t heartbeat_timeout_milliseconds() const + int64_t raylet_heartbeat_timeout_milliseconds() const int64_t debug_dump_period_milliseconds() const diff --git a/python/ray/includes/ray_config.pxi b/python/ray/includes/ray_config.pxi index 7171884ee..2b33995ee 100644 --- a/python/ray/includes/ray_config.pxi +++ b/python/ray/includes/ray_config.pxi @@ -10,8 +10,8 @@ cdef class Config: return RayConfig.instance().handler_warning_timeout_ms() @staticmethod - def heartbeat_timeout_milliseconds(): - return RayConfig.instance().heartbeat_timeout_milliseconds() + def raylet_heartbeat_timeout_milliseconds(): + return RayConfig.instance().raylet_heartbeat_timeout_milliseconds() @staticmethod def debug_dump_period_milliseconds(): diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index bf6183978..13ebca964 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -22,6 +22,7 @@ from ray.includes.unique_ids cimport ( CWorkerID, ) +import ray from ray.utils import decode @@ -128,13 +129,32 @@ cdef class UniqueID(BaseID): cdef class ObjectID(BaseID): - cdef CObjectID data - cdef object buffer_ref + cdef: + CObjectID data + object buffer_ref + # Flag indicating whether or not this object ID was added to the set + # of active IDs in the core worker so we know whether we should clean + # it up. + c_bool in_core_worker def __init__(self, id): check_id(id) self.data = CObjectID.FromBinary(id) + worker = ray.worker.global_worker + # TODO(edoakes): there are dummy object IDs being created in + # includes/task.pxi before the core worker is initialized. + if hasattr(worker, "core_worker"): + worker.core_worker.add_active_object_id(self) + self.in_core_worker = True + else: + self.in_core_worker = False + + def __dealloc__(self): + worker = ray.worker.global_worker + if self.in_core_worker and hasattr(worker, "core_worker"): + worker.core_worker.remove_active_object_id(self) + cdef CObjectID native(self): return self.data diff --git a/python/ray/monitor.py b/python/ray/monitor.py index 12c30614f..caa02a82c 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -354,7 +354,8 @@ class Monitor(object): # Wait for a heartbeat interval before processing the next round of # messages. - time.sleep(ray._config.heartbeat_timeout_milliseconds() * 1e-3) + time.sleep( + ray._config.raylet_heartbeat_timeout_milliseconds() * 1e-3) def run(self): try: diff --git a/python/ray/state.py b/python/ray/state.py index 35113f8ce..e75141da4 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -30,7 +30,7 @@ def _parse_client_table(redis_client): Returns: A list of information about the nodes in the cluster. """ - NIL_CLIENT_ID = ray.ObjectID.nil().binary() + NIL_CLIENT_ID = ray.ClientID.nil().binary() message = redis_client.execute_command( "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("CLIENT"), "", NIL_CLIENT_ID) diff --git a/src/ray/common/common_protocol.h b/src/ray/common/common_protocol.h index 95dc0c1b1..e8e482128 100644 --- a/src/ray/common/common_protocol.h +++ b/src/ray/common/common_protocol.h @@ -2,7 +2,7 @@ #define COMMON_PROTOCOL_H #include -#include +#include #include "ray/common/id.h" #include "ray/util/logging.h" @@ -31,6 +31,14 @@ template const std::vector from_flatbuf( const flatbuffers::Vector> &vector); +/// Convert a flatbuffer vector of strings to an unordered_set of unique IDs. +/// +/// @param vector The flatbuffer vector. +/// @return The unordered set of IDs. +template +const std::unordered_set unordered_set_from_flatbuf( + const flatbuffers::Vector> &vector); + /// Convert a flatbuffer of string that concatenated /// unique IDs to a vector of unique IDs. /// @@ -68,6 +76,15 @@ template flatbuffers::Offset>> to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::vector &ids); +/// Convert an unordered_set of unique IDs to a flatbuffer vector of strings. +/// +/// @param fbb Reference to the flatbuffer builder. +/// @param ids Unordered set of IDs. +/// @return Flatbuffer vector of strings. +template +flatbuffers::Offset>> +to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::unordered_set &ids); + /// Convert a flatbuffer string to a std::string. /// /// @param fbb Reference to the flatbuffer builder. @@ -103,6 +120,16 @@ const std::vector from_flatbuf( return ids; } +template +const std::unordered_set unordered_set_from_flatbuf( + const flatbuffers::Vector> &vector) { + std::unordered_set ids; + for (int64_t i = 0; i < vector.Length(); i++) { + ids.insert(from_flatbuf(*vector.Get(i))); + } + return ids; +} + template const std::vector ids_from_flatbuf(const flatbuffers::String &string) { const auto &ids = string_from_flatbuf(string); @@ -151,4 +178,14 @@ to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::vector &ids) { return fbb.CreateVector(results); } +template +flatbuffers::Offset>> +to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::unordered_set &ids) { + std::vector> results; + for (auto id : ids) { + results.push_back(to_flatbuf(fbb, id)); + } + return fbb.CreateVector(results); +} + #endif diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index db6547d3d..ef62b960c 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -20,8 +20,8 @@ RAY_CONFIG(int64_t, ray_cookie, 0x5241590000000000) /// warning is logged that the handler is taking too long. RAY_CONFIG(int64_t, handler_warning_timeout_ms, 100) -/// The duration between heartbeats. These are sent by the raylet. -RAY_CONFIG(int64_t, heartbeat_timeout_milliseconds, 100) +/// The duration between heartbeats sent by the raylets. +RAY_CONFIG(int64_t, raylet_heartbeat_timeout_milliseconds, 100) /// If a component has not sent a heartbeat in the last num_heartbeats_timeout /// heartbeat intervals, the raylet monitor process will report /// it as dead to the db_client table. @@ -46,6 +46,10 @@ RAY_CONFIG(bool, fair_queueing_enabled, true) /// this many milliseconds. RAY_CONFIG(int64_t, initial_reconstruction_timeout_milliseconds, 10000) +/// The duration between heartbeats sent from the workers to the raylet. +/// If set to a negative value, the heartbeats will not be sent. +RAY_CONFIG(int64_t, worker_heartbeat_timeout_milliseconds, 500) + /// These are used by the worker to set timeouts and to batch requests when /// getting objects. RAY_CONFIG(int64_t, get_timeout_milliseconds, 1000) @@ -85,6 +89,9 @@ RAY_CONFIG(int64_t, max_num_to_reconstruct, 10000) /// regular raylet fetch timeout handler. RAY_CONFIG(int64_t, raylet_fetch_request_size, 10000) +/// The maximum number of active object IDs to report in a heartbeat. +RAY_CONFIG(size_t, raylet_max_active_object_ids, 1000) + /// The duration that we wait after sending a worker SIGTERM before sending /// the worker SIGKILL. RAY_CONFIG(int64_t, kill_worker_timeout_milliseconds, 100) diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 7b3927efe..b463ff760 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1,5 +1,6 @@ #include +#include "ray/common/ray_config.h" #include "ray/common/task/task_util.h" #include "ray/core_worker/context.h" #include "ray/core_worker/core_worker.h" @@ -53,7 +54,8 @@ CoreWorker::CoreWorker( raylet_socket_(raylet_socket), log_dir_(log_dir), worker_context_(worker_type, job_id), - io_work_(io_service_) { + io_work_(io_service_), + heartbeat_timer_(io_service_) { // Initialize logging if log_dir is passed. Otherwise, it must be initialized // and cleaned up by the caller. if (log_dir_ != "") { @@ -108,6 +110,14 @@ CoreWorker::CoreWorker( (worker_type_ == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(), language_, rpc_server_port)); + // Set timer to periodically send heartbeats containing active object IDs to the raylet. + // If the heartbeat timeout is < 0, the heartbeats are disabled. + if (RayConfig::instance().worker_heartbeat_timeout_milliseconds() >= 0) { + heartbeat_timer_.expires_from_now(boost::asio::chrono::milliseconds( + RayConfig::instance().worker_heartbeat_timeout_milliseconds())); + heartbeat_timer_.async_wait(boost::bind(&CoreWorker::ReportActiveObjectIDs, this)); + } + io_thread_ = std::thread(&CoreWorker::StartIOService, this); // Create an entry for the driver task in the task table. This task is @@ -139,6 +149,46 @@ CoreWorker::CoreWorker( object_interface_->CreateStoreProvider(StoreProviderType::MEMORY))); } +void CoreWorker::AddActiveObjectID(const ObjectID &object_id) { + io_service_.post([this, object_id]() -> void { + active_object_ids_.insert(object_id); + active_object_ids_updated_ = true; + }); +} + +void CoreWorker::RemoveActiveObjectID(const ObjectID &object_id) { + io_service_.post([this, object_id]() -> void { + if (active_object_ids_.erase(object_id)) { + active_object_ids_updated_ = true; + } else { + RAY_LOG(WARNING) << "Tried to erase non-existent object ID" << object_id; + } + }); +} + +void CoreWorker::ReportActiveObjectIDs() { + // Only send a heartbeat when the set of active object IDs has changed because the + // raylet only modifies the set of IDs when it receives a heartbeat. + if (active_object_ids_updated_) { + RAY_LOG(DEBUG) << "Sending " << active_object_ids_.size() << " object IDs to raylet."; + if (active_object_ids_.size() > + RayConfig::instance().raylet_max_active_object_ids()) { + RAY_LOG(WARNING) << active_object_ids_.size() + << "object IDs are currently in scope. " + << "This may lead to required objects being garbage collected."; + } + RAY_CHECK_OK(raylet_client_->ReportActiveObjectIDs(active_object_ids_)); + } + + // Reset the timer from the previous expiration time to avoid drift. + heartbeat_timer_.expires_at( + heartbeat_timer_.expiry() + + boost::asio::chrono::milliseconds( + RayConfig::instance().worker_heartbeat_timeout_milliseconds())); + heartbeat_timer_.async_wait(boost::bind(&CoreWorker::ReportActiveObjectIDs, this)); + active_object_ids_updated_ = false; +} + CoreWorker::~CoreWorker() { io_service_.stop(); io_thread_.join(); @@ -159,8 +209,6 @@ void CoreWorker::Disconnect() { } } -void CoreWorker::StartIOService() { io_service_.run(); } - std::unique_ptr CoreWorker::CreateProfileEvent( const std::string &event_type) { return std::unique_ptr( diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 82452390d..d0cd16c71 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -158,6 +158,14 @@ class CoreWorker { /// \return Status::Invalid if we don't have the specified handle. Status SerializeActorHandle(const ActorID &actor_id, std::string *output) const; + // Add this object ID to the set of active object IDs that is sent to the raylet + // in the heartbeat messsage. + void AddActiveObjectID(const ObjectID &object_id); + + // Remove this object ID from the set of active object IDs that is sent to the raylet + // in the heartbeat messsage. + void RemoveActiveObjectID(const ObjectID &object_id); + private: /// Give this worker a handle to an actor. /// @@ -179,7 +187,9 @@ class CoreWorker { /// \return Status::Invalid if we don't have this actor handle. Status GetActorHandle(const ActorID &actor_id, ActorHandle **actor_handle) const; - void StartIOService(); + void StartIOService() { io_service_.run(); } + + void ReportActiveObjectIDs(); const WorkerType worker_type_; const Language language_; @@ -197,6 +207,9 @@ class CoreWorker { boost::asio::io_service io_service_; /// Keeps the io_service_ alive. boost::asio::io_service::work io_work_; + /// Timer used to periodically send heartbeat containing active object IDs to the + /// raylet. + boost::asio::steady_timer heartbeat_timer_; std::thread io_thread_; std::shared_ptr profiler_; @@ -206,7 +219,14 @@ class CoreWorker { std::unique_ptr object_interface_; /// Map from actor ID to a handle to that actor. - std::unordered_map> actor_handles_; + std::unordered_map > actor_handles_; + + /// Set of object IDs that are in scope in the language worker. + std::unordered_set active_object_ids_; + + /// Indicates whether or not the active_object_ids map has changed since the + /// last time it was sent to the raylet. + bool active_object_ids_updated_ = false; /// Only available if it's not a driver. std::unique_ptr task_execution_interface_; diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index 8041cc40f..4ea542b4b 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -198,6 +198,8 @@ message HeartbeatTableData { // Aggregate outstanding resource load on this node manager. repeated string resource_load_label = 6; repeated double resource_load_capacity = 7; + // Object IDs that are in use by workers on this node manager's node. + repeated bytes active_object_id = 8; } message HeartbeatBatchTableData { diff --git a/src/ray/protobuf/worker.proto b/src/ray/protobuf/worker.proto index 3c2d30ab6..e8cdb8e38 100644 --- a/src/ray/protobuf/worker.proto +++ b/src/ray/protobuf/worker.proto @@ -4,6 +4,10 @@ package ray.rpc; import "src/ray/protobuf/common.proto"; +message ActiveObjectIDs { + repeated bytes object_ids = 1; +} + message AssignTaskRequest { // The task to be pushed. Task task = 1; diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 705a9fdba..39ef1fc49 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -74,8 +74,10 @@ enum MessageType:int { NotifyActorResumedFromCheckpoint, // A node manager requests to connect to another node manager. ConnectClient, - // Set dynamic custom resource + // Set dynamic custom resource. SetResourceRequest, + // Update the active set of object IDs in use on this worker. + ReportActiveObjectIDs, } table TaskExecutionSpecification { @@ -238,11 +240,16 @@ table ConnectClient { client_id: string; } -table SetResourceRequest{ - // Name of the resource to be set +table SetResourceRequest { + // Name of the resource to be set. resource_name: string; - // Capacity of the resource to be set + // Capacity of the resource to be set. capacity: double; - // Client ID where this resource will be set + // Client ID where this resource will be set. client_id: string; } + +table ReportActiveObjectIDs { + // Object IDs that are active in the worker. + object_ids: [string]; +} diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index ead799574..05f53b272 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -120,7 +120,7 @@ int main(int argc, char *argv[]) { } node_manager_config.heartbeat_period_ms = - RayConfig::instance().heartbeat_timeout_milliseconds(); + RayConfig::instance().raylet_heartbeat_timeout_milliseconds(); node_manager_config.debug_dump_period_ms = RayConfig::instance().debug_dump_period_milliseconds(); node_manager_config.fair_queueing_enabled = diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index 2113c1bb1..202d3e40e 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -94,7 +94,7 @@ void Monitor::Tick() { } auto heartbeat_period = boost::posix_time::milliseconds( - RayConfig::instance().heartbeat_timeout_milliseconds()); + RayConfig::instance().raylet_heartbeat_timeout_milliseconds()); heartbeat_timer_.expires_from_now(heartbeat_period); heartbeat_timer_.async_wait([this](const boost::system::error_code &error) { RAY_CHECK(!error); diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index c191ca0b1..2b0aae0da 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -9,6 +9,7 @@ #include "ray/common/id.h" #include "ray/raylet/format/node_manager_generated.h" #include "ray/stats/stats.h" +#include "ray/util/sample.h" namespace { @@ -300,7 +301,7 @@ void NodeManager::Heartbeat() { uint64_t now_ms = current_time_ms(); uint64_t interval = now_ms - last_heartbeat_at_ms_; if (interval > RayConfig::instance().num_heartbeats_warning() * - RayConfig::instance().heartbeat_timeout_milliseconds()) { + RayConfig::instance().raylet_heartbeat_timeout_milliseconds()) { RAY_LOG(WARNING) << "Last heartbeat was sent " << interval << " ms ago "; } last_heartbeat_at_ms_ = now_ms; @@ -328,6 +329,25 @@ void NodeManager::Heartbeat() { heartbeat_data->add_resource_load_capacity(resource_pair.second); } + size_t max_size = RayConfig::instance().raylet_max_active_object_ids(); + std::unordered_set active_object_ids = worker_pool_.GetActiveObjectIDs(); + if (active_object_ids.size() <= max_size) { + for (const auto &object_id : active_object_ids) { + heartbeat_data->add_active_object_id(object_id.Binary()); + } + } else { + // If there are more than the configured maximum number of object IDs to send per + // heartbeat, sample from them randomly. + // TODO(edoakes): we might want to improve the sampling technique here, for example + // preferring object IDs with the earliest last-refreshed timestamp. + std::vector downsampled; + random_sample(active_object_ids.begin(), active_object_ids.end(), max_size, + &downsampled); + for (const auto &object_id : downsampled) { + heartbeat_data->add_active_object_id(object_id.Binary()); + } + } + ray::Status status = heartbeat_table.Add( JobID::Nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data, /*success_callback=*/nullptr); @@ -655,7 +675,13 @@ void NodeManager::HeartbeatAdded(const ClientID &client_id, void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableData &heartbeat_batch) { const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); // Update load information provided by each heartbeat. + // TODO(edoakes): this isn't currently used, but will be used to refresh the LRU + // cache in the object store. + std::unordered_set active_object_ids; for (const auto &heartbeat_data : heartbeat_batch.batch()) { + for (int i = 0; i < heartbeat_data.active_object_id_size(); i++) { + active_object_ids.insert(ObjectID::FromBinary(heartbeat_data.active_object_id(i))); + } const ClientID &client_id = ClientID::FromBinary(heartbeat_data.client_id()); if (client_id == local_client_id) { // Skip heartbeats from self. @@ -663,6 +689,7 @@ void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableData &heartbeat_b } HeartbeatAdded(client_id, heartbeat_data); } + RAY_LOG(DEBUG) << "Total active object IDs received: " << active_object_ids.size(); } void NodeManager::HandleActorStateTransition(const ActorID &actor_id, @@ -906,6 +933,9 @@ void NodeManager::ProcessClientMessage( case protocol::MessageType::NotifyActorResumedFromCheckpoint: { ProcessNotifyActorResumedFromCheckpoint(message_data); } break; + case protocol::MessageType::ReportActiveObjectIDs: { + ProcessReportActiveObjectIDs(client, message_data); + } break; default: RAY_LOG(FATAL) << "Received unexpected message type " << message_type; @@ -1309,6 +1339,19 @@ void NodeManager::ProcessNotifyActorResumedFromCheckpoint(const uint8_t *message checkpoint_id_to_restore_.emplace(actor_id, checkpoint_id); } +void NodeManager::ProcessReportActiveObjectIDs( + const std::shared_ptr &client, const uint8_t *message_data) { + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + if (!worker) { + worker = worker_pool_.GetRegisteredDriver(client); + RAY_CHECK(worker); + } + + auto message = flatbuffers::GetRoot(message_data); + worker->SetActiveObjectIds( + unordered_set_from_flatbuf(*message->object_ids())); +} + void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request, rpc::ForwardTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 173f1e0b7..2d9779d22 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -446,6 +446,13 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param message_data A pointer to the message data. void ProcessNotifyActorResumedFromCheckpoint(const uint8_t *message_data); + /// Process client message of ReportActiveObjectIDs. + /// + /// \param client The client that sent the message. + /// \param message_data A pointer to the message data. + void ProcessReportActiveObjectIDs(const std::shared_ptr &client, + const uint8_t *message_data); + /// Update actor frontier when a task finishes. /// If the task is an actor creation task and the actor was resumed from a checkpoint, /// restore the frontier from the checkpoint. Otherwise, just extend actor frontier. diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 1c8871bf0..36dd2caa7 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -390,3 +390,13 @@ ray::Status RayletClient::SetResource(const std::string &resource_name, fbb.Finish(message); return conn_->WriteMessage(MessageType::SetResourceRequest, &fbb); } + +ray::Status RayletClient::ReportActiveObjectIDs( + const std::unordered_set &object_ids) { + flatbuffers::FlatBufferBuilder fbb; + auto message = + ray::protocol::CreateReportActiveObjectIDs(fbb, to_flatbuf(fbb, object_ids)); + fbb.Finish(message); + + return conn_->WriteMessage(MessageType::ReportActiveObjectIDs, &fbb); +} diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index 235ba9cfb..71d0075b0 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -176,6 +176,11 @@ class RayletClient { ray::Status SetResource(const std::string &resource_name, const double capacity, const ray::ClientID &client_Id); + /// Notifies the raylet of the object IDs currently in use on this worker. + /// \param object_ids The set of object IDs currently in use. + /// \return ray::Status + ray::Status ReportActiveObjectIDs(const std::unordered_set &object_ids); + Language GetLanguage() const { return language_; } WorkerID GetWorkerID() const { return worker_id_; } diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index 947775d6d..e8256bf78 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -113,6 +113,14 @@ void Worker::AcquireTaskCpuResources(const ResourceIdSet &cpu_resources) { task_resource_ids_.Release(cpu_resources); } +const std::unordered_set &Worker::GetActiveObjectIds() const { + return active_object_ids_; +} + +void Worker::SetActiveObjectIds(const std::unordered_set &&object_ids) { + active_object_ids_ = object_ids; +} + bool Worker::UsePush() const { return rpc_client_ != nullptr; } void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set, diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index 93531ac8a..79642527e 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -57,6 +57,9 @@ class Worker { ResourceIdSet ReleaseTaskCpuResources(); void AcquireTaskCpuResources(const ResourceIdSet &cpu_resources); + const std::unordered_set &GetActiveObjectIds() const; + void SetActiveObjectIds(const std::unordered_set &&object_ids); + bool UsePush() const; void AssignTask(const Task &task, const ResourceIdSet &resource_id_set, const std::function finish_assign_callback); @@ -91,6 +94,8 @@ class Worker { // of a task. ResourceIdSet task_resource_ids_; std::unordered_set blocked_task_ids_; + /// The set of object IDs that are currently in use on the worker. + std::unordered_set active_object_ids_; /// The `ClientCallManager` object that is shared by `WorkerTaskClient` from all /// workers. rpc::ClientCallManager &client_call_manager_; diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index ade829870..b303fb602 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -425,6 +425,21 @@ bool WorkerPool::HasPendingWorkerForTask(const Language &language, return it != state.tasks_to_dedicated_workers.end(); } +std::unordered_set WorkerPool::GetActiveObjectIDs() const { + std::unordered_set active_object_ids; + for (const auto &entry : states_by_lang_) { + for (const auto &worker : entry.second.registered_workers) { + active_object_ids.insert(worker->GetActiveObjectIds().begin(), + worker->GetActiveObjectIds().end()); + } + for (const auto &driver : entry.second.registered_drivers) { + active_object_ids.insert(driver->GetActiveObjectIds().begin(), + driver->GetActiveObjectIds().end()); + } + } + return active_object_ids; +} + std::string WorkerPool::DebugString() const { std::stringstream result; result << "WorkerPool:"; diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index fe89140f1..ff1703be7 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -121,6 +121,10 @@ class WorkerPool { /// \param task_id The task that we want to query. bool HasPendingWorkerForTask(const Language &language, const TaskID &task_id); + /// Get the set of active object IDs from all workers in the worker pool. + /// \return A set containing the active object IDs. + std::unordered_set GetActiveObjectIDs() const; + /// Returns debug string for class. /// /// \return string. diff --git a/src/ray/util/sample.h b/src/ray/util/sample.h new file mode 100644 index 000000000..40c1bc0bf --- /dev/null +++ b/src/ray/util/sample.h @@ -0,0 +1,33 @@ +#ifndef RAY_UTIL_SAMPLE_H +#define RAY_UTIL_SAMPLE_H + +#include "absl/random/random.h" +#include "absl/random/uniform_int_distribution.h" + +// Randomly samples num_elements from the elements between first and last using reservoir +// sampling. +template ::value_type> +void random_sample(Iterator begin, Iterator end, size_t num_elements, + std::vector *out) { + out->resize(0); + absl::BitGen gen; + if (num_elements == 0) { + return; + } + + size_t current_index = 0; + for (auto it = begin; it != end; it++) { + if (current_index < num_elements) { + out->push_back(*it); + } else { + size_t random_index = absl::uniform_int_distribution(0, current_index)(gen); + if (random_index < num_elements) { + out->at(random_index) = *it; + } + } + current_index++; + } + return; +} + +#endif // RAY_UTIL_SAMPLE_H diff --git a/src/ray/util/sample_test.cc b/src/ray/util/sample_test.cc new file mode 100644 index 000000000..8621ad305 --- /dev/null +++ b/src/ray/util/sample_test.cc @@ -0,0 +1,68 @@ +#include + +#include "gtest/gtest.h" +#include "ray/util/sample.h" + +namespace ray { + +class RandomSampleTest : public ::testing::Test { + protected: + std::vector *sample; + std::vector *test_vector; + virtual void SetUp() { + sample = new std::vector(); + test_vector = new std::vector(); + for (int i = 0; i < 10; i++) { + test_vector->push_back(i); + } + } + + virtual void TearDown() { + delete sample; + delete test_vector; + } +}; + +TEST_F(RandomSampleTest, TestEmpty) { + random_sample(test_vector->begin(), test_vector->end(), 0, sample); + ASSERT_EQ(sample->size(), 0); +} + +TEST_F(RandomSampleTest, TestSmallerThanSampleSize) { + random_sample(test_vector->begin(), test_vector->end(), test_vector->size() + 1, + sample); + ASSERT_EQ(sample->size(), test_vector->size()); +} + +TEST_F(RandomSampleTest, TestEqualToSampleSize) { + random_sample(test_vector->begin(), test_vector->end(), test_vector->size(), sample); + ASSERT_EQ(sample->size(), test_vector->size()); +} + +TEST_F(RandomSampleTest, TestLargerThanSampleSize) { + random_sample(test_vector->begin(), test_vector->end(), test_vector->size() - 1, + sample); + ASSERT_EQ(sample->size(), test_vector->size() - 1); +} + +TEST_F(RandomSampleTest, TestEqualOccurrenceChance) { + int trials = 100000; + std::vector occurrences(test_vector->size(), 0); + for (int i = 0; i < trials; i++) { + random_sample(test_vector->begin(), test_vector->end(), test_vector->size() / 2, + sample); + for (int idx : *sample) { + occurrences[idx]++; + } + } + for (int count : occurrences) { + ASSERT_NEAR(trials / 2, count, 0.05 * trials / 2); + } +} + +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}