Send active object IDs to the raylet (#5803)

* Send active object IDs to the raylet

* comment

* comments

* dedup

* signed int in config

* comments

* Remove object ID from monitor

* Fix test

* re-add check

* fix cast

* check if core worker

* Add comment

* Reservoir sampling

* Fix lint

* Pointer return

* tmp

* Fix merge

* Initialize object ids properly

* Fix lint
This commit is contained in:
Edward Oakes
2019-10-20 22:05:28 -07:00
committed by GitHub
parent f286356e06
commit fc56872012
28 changed files with 393 additions and 34 deletions
+11 -10
View File
@@ -240,16 +240,6 @@ cc_library(
],
)
cc_test(
name = "common_test",
srcs = glob(["src/ray/common/**/*_test.cc"]),
copts = COPTS,
deps = [
":ray_common",
"@com_google_googletest//:gtest_main",
],
)
cc_binary(
name = "raylet",
srcs = ["src/ray/raylet/main.cc"],
@@ -491,6 +481,16 @@ cc_test(
],
)
cc_test(
name = "sample_test",
srcs = ["src/ray/util/sample_test.cc"],
copts = COPTS,
deps = [
":ray_common",
"@com_google_googletest//:gtest_main",
],
)
cc_test(
name = "task_dependency_manager_test",
srcs = ["src/ray/raylet/task_dependency_manager_test.cc"],
@@ -595,6 +595,7 @@ cc_library(
deps = [
":sha256",
"@com_github_google_glog//:glog",
"@com_google_absl//absl/random",
"@plasma//:plasma_client",
],
)
+1 -1
View File
@@ -88,7 +88,7 @@ def ray_deps_setup():
# This is how diamond dependencies are prevented.
git_repository(
name = "com_google_absl",
commit = "5b65c4af5107176555b23a638e5947686410ac1f",
commit = "aa844899c937bde5d2b24f276b59997e5b668bde",
remote = "https://github.com/abseil/abseil-cpp.git",
)
+12
View File
@@ -725,3 +725,15 @@ cdef class CoreWorker:
check_status(self.core_worker.get().SerializeActorHandle(
c_actor_id, &output))
return output
def add_active_object_id(self, ObjectID object_id):
cdef:
CObjectID c_object_id = object_id.native()
with nogil:
self.core_worker.get().AddActiveObjectID(c_object_id)
def remove_active_object_id(self, ObjectID object_id):
cdef:
CObjectID c_object_id = object_id.native()
with nogil:
self.core_worker.get().RemoveActiveObjectID(c_object_id)
+2
View File
@@ -90,3 +90,5 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
CActorID DeserializeAndRegisterActorHandle(const c_string &bytes)
CRayStatus SerializeActorHandle(const CActorID &actor_id, c_string
*bytes)
void AddActiveObjectID(const CObjectID &object_id)
void RemoveActiveObjectID(const CObjectID &object_id)
+1 -1
View File
@@ -12,7 +12,7 @@ cdef extern from "ray/common/ray_config.h" nogil:
int64_t handler_warning_timeout_ms() const
int64_t heartbeat_timeout_milliseconds() const
int64_t raylet_heartbeat_timeout_milliseconds() const
int64_t debug_dump_period_milliseconds() const
+2 -2
View File
@@ -10,8 +10,8 @@ cdef class Config:
return RayConfig.instance().handler_warning_timeout_ms()
@staticmethod
def heartbeat_timeout_milliseconds():
return RayConfig.instance().heartbeat_timeout_milliseconds()
def raylet_heartbeat_timeout_milliseconds():
return RayConfig.instance().raylet_heartbeat_timeout_milliseconds()
@staticmethod
def debug_dump_period_milliseconds():
+22 -2
View File
@@ -22,6 +22,7 @@ from ray.includes.unique_ids cimport (
CWorkerID,
)
import ray
from ray.utils import decode
@@ -128,13 +129,32 @@ cdef class UniqueID(BaseID):
cdef class ObjectID(BaseID):
cdef CObjectID data
cdef object buffer_ref
cdef:
CObjectID data
object buffer_ref
# Flag indicating whether or not this object ID was added to the set
# of active IDs in the core worker so we know whether we should clean
# it up.
c_bool in_core_worker
def __init__(self, id):
check_id(id)
self.data = CObjectID.FromBinary(<c_string>id)
worker = ray.worker.global_worker
# TODO(edoakes): there are dummy object IDs being created in
# includes/task.pxi before the core worker is initialized.
if hasattr(worker, "core_worker"):
worker.core_worker.add_active_object_id(self)
self.in_core_worker = True
else:
self.in_core_worker = False
def __dealloc__(self):
worker = ray.worker.global_worker
if self.in_core_worker and hasattr(worker, "core_worker"):
worker.core_worker.remove_active_object_id(self)
cdef CObjectID native(self):
return <CObjectID>self.data
+2 -1
View File
@@ -354,7 +354,8 @@ class Monitor(object):
# Wait for a heartbeat interval before processing the next round of
# messages.
time.sleep(ray._config.heartbeat_timeout_milliseconds() * 1e-3)
time.sleep(
ray._config.raylet_heartbeat_timeout_milliseconds() * 1e-3)
def run(self):
try:
+1 -1
View File
@@ -30,7 +30,7 @@ def _parse_client_table(redis_client):
Returns:
A list of information about the nodes in the cluster.
"""
NIL_CLIENT_ID = ray.ObjectID.nil().binary()
NIL_CLIENT_ID = ray.ClientID.nil().binary()
message = redis_client.execute_command(
"RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("CLIENT"), "",
NIL_CLIENT_ID)
+38 -1
View File
@@ -2,7 +2,7 @@
#define COMMON_PROTOCOL_H
#include <flatbuffers/flatbuffers.h>
#include <unordered_map>
#include <unordered_set>
#include "ray/common/id.h"
#include "ray/util/logging.h"
@@ -31,6 +31,14 @@ template <typename ID>
const std::vector<ID> from_flatbuf(
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> &vector);
/// Convert a flatbuffer vector of strings to an unordered_set of unique IDs.
///
/// @param vector The flatbuffer vector.
/// @return The unordered set of IDs.
template <typename ID>
const std::unordered_set<ID> unordered_set_from_flatbuf(
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> &vector);
/// Convert a flatbuffer of string that concatenated
/// unique IDs to a vector of unique IDs.
///
@@ -68,6 +76,15 @@ template <typename ID>
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::vector<ID> &ids);
/// Convert an unordered_set of unique IDs to a flatbuffer vector of strings.
///
/// @param fbb Reference to the flatbuffer builder.
/// @param ids Unordered set of IDs.
/// @return Flatbuffer vector of strings.
template <typename ID>
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::unordered_set<ID> &ids);
/// Convert a flatbuffer string to a std::string.
///
/// @param fbb Reference to the flatbuffer builder.
@@ -103,6 +120,16 @@ const std::vector<ID> from_flatbuf(
return ids;
}
/// Convert a flatbuffer vector of strings to an unordered_set of unique IDs.
/// Entries that decode to the same ID collapse into a single set element.
///
/// @param vector The flatbuffer vector; every element is decoded with from_flatbuf<ID>.
/// @return The unordered set of IDs.
template <typename ID>
const std::unordered_set<ID> unordered_set_from_flatbuf(
    const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> &vector) {
  std::unordered_set<ID> ids;
  // Size the hash table once up front instead of rehashing while inserting.
  ids.reserve(vector.size());
  // Index with the vector's own offset type so Get() receives its native argument type.
  for (flatbuffers::uoffset_t i = 0; i < vector.size(); ++i) {
    ids.insert(from_flatbuf<ID>(*vector.Get(i)));
  }
  return ids;
}
template <typename ID>
const std::vector<ID> ids_from_flatbuf(const flatbuffers::String &string) {
const auto &ids = string_from_flatbuf(string);
@@ -151,4 +178,14 @@ to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::vector<ID> &ids) {
return fbb.CreateVector(results);
}
/// Convert an unordered_set of unique IDs to a flatbuffer vector of strings.
/// The order of the output follows the iteration order of the set.
///
/// @param fbb Reference to the flatbuffer builder.
/// @param ids Unordered set of IDs.
/// @return Flatbuffer vector of strings.
template <typename ID>
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, const std::unordered_set<ID> &ids) {
  std::vector<flatbuffers::Offset<flatbuffers::String>> results;
  // The final size is known, so allocate once.
  results.reserve(ids.size());
  // Iterate by const reference: the loop variable does not need its own copy of each ID.
  for (const auto &id : ids) {
    results.push_back(to_flatbuf(fbb, id));
  }
  return fbb.CreateVector(results);
}
#endif
+9 -2
View File
@@ -20,8 +20,8 @@ RAY_CONFIG(int64_t, ray_cookie, 0x5241590000000000)
/// warning is logged that the handler is taking too long.
RAY_CONFIG(int64_t, handler_warning_timeout_ms, 100)
/// The duration between heartbeats. These are sent by the raylet.
RAY_CONFIG(int64_t, heartbeat_timeout_milliseconds, 100)
/// The duration between heartbeats sent by the raylets.
RAY_CONFIG(int64_t, raylet_heartbeat_timeout_milliseconds, 100)
/// If a component has not sent a heartbeat in the last num_heartbeats_timeout
/// heartbeat intervals, the raylet monitor process will report
/// it as dead to the db_client table.
@@ -46,6 +46,10 @@ RAY_CONFIG(bool, fair_queueing_enabled, true)
/// this many milliseconds.
RAY_CONFIG(int64_t, initial_reconstruction_timeout_milliseconds, 10000)
/// The duration between heartbeats sent from the workers to the raylet.
/// If set to a negative value, the heartbeats will not be sent.
RAY_CONFIG(int64_t, worker_heartbeat_timeout_milliseconds, 500)
/// These are used by the worker to set timeouts and to batch requests when
/// getting objects.
RAY_CONFIG(int64_t, get_timeout_milliseconds, 1000)
@@ -85,6 +89,9 @@ RAY_CONFIG(int64_t, max_num_to_reconstruct, 10000)
/// regular raylet fetch timeout handler.
RAY_CONFIG(int64_t, raylet_fetch_request_size, 10000)
/// The maximum number of active object IDs to report in a heartbeat.
RAY_CONFIG(size_t, raylet_max_active_object_ids, 1000)
/// The duration that we wait after sending a worker SIGTERM before sending
/// the worker SIGKILL.
RAY_CONFIG(int64_t, kill_worker_timeout_milliseconds, 100)
+51 -3
View File
@@ -1,5 +1,6 @@
#include <boost/asio/signal_set.hpp>
#include "ray/common/ray_config.h"
#include "ray/common/task/task_util.h"
#include "ray/core_worker/context.h"
#include "ray/core_worker/core_worker.h"
@@ -53,7 +54,8 @@ CoreWorker::CoreWorker(
raylet_socket_(raylet_socket),
log_dir_(log_dir),
worker_context_(worker_type, job_id),
io_work_(io_service_) {
io_work_(io_service_),
heartbeat_timer_(io_service_) {
// Initialize logging if log_dir is passed. Otherwise, it must be initialized
// and cleaned up by the caller.
if (log_dir_ != "") {
@@ -108,6 +110,14 @@ CoreWorker::CoreWorker(
(worker_type_ == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(),
language_, rpc_server_port));
// Set timer to periodically send heartbeats containing active object IDs to the raylet.
// If the heartbeat timeout is < 0, the heartbeats are disabled.
if (RayConfig::instance().worker_heartbeat_timeout_milliseconds() >= 0) {
heartbeat_timer_.expires_from_now(boost::asio::chrono::milliseconds(
RayConfig::instance().worker_heartbeat_timeout_milliseconds()));
heartbeat_timer_.async_wait(boost::bind(&CoreWorker::ReportActiveObjectIDs, this));
}
io_thread_ = std::thread(&CoreWorker::StartIOService, this);
// Create an entry for the driver task in the task table. This task is
@@ -139,6 +149,46 @@ CoreWorker::CoreWorker(
object_interface_->CreateStoreProvider(StoreProviderType::MEMORY)));
}
// Marks `object_id` as in scope on this worker. The update is not applied inline:
// it is queued on io_service_, so it executes on the thread running that service
// together with the other accesses to active_object_ids_.
void CoreWorker::AddActiveObjectID(const ObjectID &object_id) {
  // The ID is captured by value because the handler runs after this call returns.
  auto mark_active = [this, object_id]() {
    active_object_ids_.insert(object_id);
    // Flag the set as dirty so the next heartbeat reports it to the raylet.
    active_object_ids_updated_ = true;
  };
  io_service_.post(mark_active);
}
// Removes `object_id` from the set of active object IDs. Like AddActiveObjectID,
// the mutation is queued on io_service_ rather than applied inline. Removing an ID
// that is not in the set is logged as a warning and otherwise ignored.
void CoreWorker::RemoveActiveObjectID(const ObjectID &object_id) {
  // The ID is captured by value because the handler runs after this call returns.
  io_service_.post([this, object_id]() -> void {
    if (active_object_ids_.erase(object_id) > 0) {
      // Something was actually removed, so the next heartbeat must report the change.
      active_object_ids_updated_ = true;
    } else {
      RAY_LOG(WARNING) << "Tried to erase non-existent object ID " << object_id;
    }
  });
}
// Heartbeat timer handler: sends the current set of active object IDs to the raylet
// if it changed since the last report, then re-arms the timer for the next period.
void CoreWorker::ReportActiveObjectIDs() {
  // Only send a heartbeat when the set of active object IDs has changed because the
  // raylet only modifies the set of IDs when it receives a heartbeat.
  if (active_object_ids_updated_) {
    const size_t num_active = active_object_ids_.size();
    RAY_LOG(DEBUG) << "Sending " << num_active << " object IDs to raylet.";
    if (num_active > RayConfig::instance().raylet_max_active_object_ids()) {
      RAY_LOG(WARNING) << num_active << " object IDs are currently in scope. "
                       << "This may lead to required objects being garbage collected.";
    }
    RAY_CHECK_OK(raylet_client_->ReportActiveObjectIDs(active_object_ids_));
    // The raylet now has the current set; wait for the next change before resending.
    active_object_ids_updated_ = false;
  }

  // Reset the timer from the previous expiration time to avoid drift.
  heartbeat_timer_.expires_at(
      heartbeat_timer_.expiry() +
      boost::asio::chrono::milliseconds(
          RayConfig::instance().worker_heartbeat_timeout_milliseconds()));
  heartbeat_timer_.async_wait(boost::bind(&CoreWorker::ReportActiveObjectIDs, this));
}
CoreWorker::~CoreWorker() {
io_service_.stop();
io_thread_.join();
@@ -159,8 +209,6 @@ void CoreWorker::Disconnect() {
}
}
void CoreWorker::StartIOService() { io_service_.run(); }
std::unique_ptr<worker::ProfileEvent> CoreWorker::CreateProfileEvent(
const std::string &event_type) {
return std::unique_ptr<worker::ProfileEvent>(
+22 -2
View File
@@ -158,6 +158,14 @@ class CoreWorker {
/// \return Status::Invalid if we don't have the specified handle.
Status SerializeActorHandle(const ActorID &actor_id, std::string *output) const;
// Add this object ID to the set of active object IDs that is sent to the raylet
// in the heartbeat message.
void AddActiveObjectID(const ObjectID &object_id);
// Remove this object ID from the set of active object IDs that is sent to the raylet
// in the heartbeat message.
void RemoveActiveObjectID(const ObjectID &object_id);
private:
/// Give this worker a handle to an actor.
///
@@ -179,7 +187,9 @@ class CoreWorker {
/// \return Status::Invalid if we don't have this actor handle.
Status GetActorHandle(const ActorID &actor_id, ActorHandle **actor_handle) const;
void StartIOService();
void StartIOService() { io_service_.run(); }
void ReportActiveObjectIDs();
const WorkerType worker_type_;
const Language language_;
@@ -197,6 +207,9 @@ class CoreWorker {
boost::asio::io_service io_service_;
/// Keeps the io_service_ alive.
boost::asio::io_service::work io_work_;
/// Timer used to periodically send heartbeat containing active object IDs to the
/// raylet.
boost::asio::steady_timer heartbeat_timer_;
std::thread io_thread_;
std::shared_ptr<worker::Profiler> profiler_;
@@ -206,7 +219,14 @@ class CoreWorker {
std::unique_ptr<CoreWorkerObjectInterface> object_interface_;
/// Map from actor ID to a handle to that actor.
std::unordered_map<ActorID, std::unique_ptr<ActorHandle>> actor_handles_;
std::unordered_map<ActorID, std::unique_ptr<ActorHandle> > actor_handles_;
/// Set of object IDs that are in scope in the language worker.
std::unordered_set<ObjectID> active_object_ids_;
/// Indicates whether or not the active_object_ids_ set has changed since the
/// last time it was sent to the raylet.
bool active_object_ids_updated_ = false;
/// Only available if it's not a driver.
std::unique_ptr<CoreWorkerTaskExecutionInterface> task_execution_interface_;
+2
View File
@@ -198,6 +198,8 @@ message HeartbeatTableData {
// Aggregate outstanding resource load on this node manager.
repeated string resource_load_label = 6;
repeated double resource_load_capacity = 7;
// Object IDs that are in use by workers on this node manager's node.
repeated bytes active_object_id = 8;
}
message HeartbeatBatchTableData {
+4
View File
@@ -4,6 +4,10 @@ package ray.rpc;
import "src/ray/protobuf/common.proto";
message ActiveObjectIDs {
repeated bytes object_ids = 1;
}
message AssignTaskRequest {
// The task to be pushed.
Task task = 1;
+12 -5
View File
@@ -74,8 +74,10 @@ enum MessageType:int {
NotifyActorResumedFromCheckpoint,
// A node manager requests to connect to another node manager.
ConnectClient,
// Set dynamic custom resource
// Set dynamic custom resource.
SetResourceRequest,
// Update the active set of object IDs in use on this worker.
ReportActiveObjectIDs,
}
table TaskExecutionSpecification {
@@ -238,11 +240,16 @@ table ConnectClient {
client_id: string;
}
table SetResourceRequest{
// Name of the resource to be set
table SetResourceRequest {
// Name of the resource to be set.
resource_name: string;
// Capacity of the resource to be set
// Capacity of the resource to be set.
capacity: double;
// Client ID where this resource will be set
// Client ID where this resource will be set.
client_id: string;
}
table ReportActiveObjectIDs {
// Object IDs that are active in the worker.
object_ids: [string];
}
+1 -1
View File
@@ -120,7 +120,7 @@ int main(int argc, char *argv[]) {
}
node_manager_config.heartbeat_period_ms =
RayConfig::instance().heartbeat_timeout_milliseconds();
RayConfig::instance().raylet_heartbeat_timeout_milliseconds();
node_manager_config.debug_dump_period_ms =
RayConfig::instance().debug_dump_period_milliseconds();
node_manager_config.fair_queueing_enabled =
+1 -1
View File
@@ -94,7 +94,7 @@ void Monitor::Tick() {
}
auto heartbeat_period = boost::posix_time::milliseconds(
RayConfig::instance().heartbeat_timeout_milliseconds());
RayConfig::instance().raylet_heartbeat_timeout_milliseconds());
heartbeat_timer_.expires_from_now(heartbeat_period);
heartbeat_timer_.async_wait([this](const boost::system::error_code &error) {
RAY_CHECK(!error);
+44 -1
View File
@@ -9,6 +9,7 @@
#include "ray/common/id.h"
#include "ray/raylet/format/node_manager_generated.h"
#include "ray/stats/stats.h"
#include "ray/util/sample.h"
namespace {
@@ -300,7 +301,7 @@ void NodeManager::Heartbeat() {
uint64_t now_ms = current_time_ms();
uint64_t interval = now_ms - last_heartbeat_at_ms_;
if (interval > RayConfig::instance().num_heartbeats_warning() *
RayConfig::instance().heartbeat_timeout_milliseconds()) {
RayConfig::instance().raylet_heartbeat_timeout_milliseconds()) {
RAY_LOG(WARNING) << "Last heartbeat was sent " << interval << " ms ago ";
}
last_heartbeat_at_ms_ = now_ms;
@@ -328,6 +329,25 @@ void NodeManager::Heartbeat() {
heartbeat_data->add_resource_load_capacity(resource_pair.second);
}
size_t max_size = RayConfig::instance().raylet_max_active_object_ids();
std::unordered_set<ObjectID> active_object_ids = worker_pool_.GetActiveObjectIDs();
if (active_object_ids.size() <= max_size) {
for (const auto &object_id : active_object_ids) {
heartbeat_data->add_active_object_id(object_id.Binary());
}
} else {
// If there are more than the configured maximum number of object IDs to send per
// heartbeat, sample from them randomly.
// TODO(edoakes): we might want to improve the sampling technique here, for example
// preferring object IDs with the earliest last-refreshed timestamp.
std::vector<ObjectID> downsampled;
random_sample(active_object_ids.begin(), active_object_ids.end(), max_size,
&downsampled);
for (const auto &object_id : downsampled) {
heartbeat_data->add_active_object_id(object_id.Binary());
}
}
ray::Status status = heartbeat_table.Add(
JobID::Nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data,
/*success_callback=*/nullptr);
@@ -655,7 +675,13 @@ void NodeManager::HeartbeatAdded(const ClientID &client_id,
void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableData &heartbeat_batch) {
const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId();
// Update load information provided by each heartbeat.
// TODO(edoakes): this isn't currently used, but will be used to refresh the LRU
// cache in the object store.
std::unordered_set<ObjectID> active_object_ids;
for (const auto &heartbeat_data : heartbeat_batch.batch()) {
for (int i = 0; i < heartbeat_data.active_object_id_size(); i++) {
active_object_ids.insert(ObjectID::FromBinary(heartbeat_data.active_object_id(i)));
}
const ClientID &client_id = ClientID::FromBinary(heartbeat_data.client_id());
if (client_id == local_client_id) {
// Skip heartbeats from self.
@@ -663,6 +689,7 @@ void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableData &heartbeat_b
}
HeartbeatAdded(client_id, heartbeat_data);
}
RAY_LOG(DEBUG) << "Total active object IDs received: " << active_object_ids.size();
}
void NodeManager::HandleActorStateTransition(const ActorID &actor_id,
@@ -906,6 +933,9 @@ void NodeManager::ProcessClientMessage(
case protocol::MessageType::NotifyActorResumedFromCheckpoint: {
ProcessNotifyActorResumedFromCheckpoint(message_data);
} break;
case protocol::MessageType::ReportActiveObjectIDs: {
ProcessReportActiveObjectIDs(client, message_data);
} break;
default:
RAY_LOG(FATAL) << "Received unexpected message type " << message_type;
@@ -1309,6 +1339,19 @@ void NodeManager::ProcessNotifyActorResumedFromCheckpoint(const uint8_t *message
checkpoint_id_to_restore_.emplace(actor_id, checkpoint_id);
}
void NodeManager::ProcessReportActiveObjectIDs(
const std::shared_ptr<LocalClientConnection> &client, const uint8_t *message_data) {
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
if (!worker) {
worker = worker_pool_.GetRegisteredDriver(client);
RAY_CHECK(worker);
}
auto message = flatbuffers::GetRoot<protocol::ReportActiveObjectIDs>(message_data);
worker->SetActiveObjectIds(
unordered_set_from_flatbuf<ObjectID>(*message->object_ids()));
}
void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request,
rpc::ForwardTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
+7
View File
@@ -446,6 +446,13 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
/// \param message_data A pointer to the message data.
void ProcessNotifyActorResumedFromCheckpoint(const uint8_t *message_data);
/// Process client message of ReportActiveObjectIDs.
///
/// \param client The client that sent the message.
/// \param message_data A pointer to the message data.
void ProcessReportActiveObjectIDs(const std::shared_ptr<LocalClientConnection> &client,
const uint8_t *message_data);
/// Update actor frontier when a task finishes.
/// If the task is an actor creation task and the actor was resumed from a checkpoint,
/// restore the frontier from the checkpoint. Otherwise, just extend actor frontier.
+10
View File
@@ -390,3 +390,13 @@ ray::Status RayletClient::SetResource(const std::string &resource_name,
fbb.Finish(message);
return conn_->WriteMessage(MessageType::SetResourceRequest, &fbb);
}
// Serializes `object_ids` into a ReportActiveObjectIDs flatbuffer message and writes
// it to the raylet connection. Returns the status of the write.
ray::Status RayletClient::ReportActiveObjectIDs(
    const std::unordered_set<ObjectID> &object_ids) {
  flatbuffers::FlatBufferBuilder builder;
  // Serialize the IDs first, then wrap them in the message table.
  const auto serialized_ids = to_flatbuf(builder, object_ids);
  builder.Finish(ray::protocol::CreateReportActiveObjectIDs(builder, serialized_ids));
  return conn_->WriteMessage(MessageType::ReportActiveObjectIDs, &builder);
}
+5
View File
@@ -176,6 +176,11 @@ class RayletClient {
ray::Status SetResource(const std::string &resource_name, const double capacity,
const ray::ClientID &client_Id);
/// Notifies the raylet of the object IDs currently in use on this worker.
/// \param object_ids The set of object IDs currently in use.
/// \return ray::Status
ray::Status ReportActiveObjectIDs(const std::unordered_set<ObjectID> &object_ids);
Language GetLanguage() const { return language_; }
WorkerID GetWorkerID() const { return worker_id_; }
+8
View File
@@ -113,6 +113,14 @@ void Worker::AcquireTaskCpuResources(const ResourceIdSet &cpu_resources) {
task_resource_ids_.Release(cpu_resources);
}
// Returns the set of object IDs last stored on this worker by SetActiveObjectIds.
// The result is a const reference to the member itself, not a copy.
const std::unordered_set<ObjectID> &Worker::GetActiveObjectIds() const {
return active_object_ids_;
}
// Replaces this worker's set of active object IDs with `object_ids`.
// NOTE(review): the parameter is a const rvalue reference and is assigned without
// std::move, so the assignment below copies the set. Making it a real move would
// require changing the declaration (and callers that pass a const temporary).
void Worker::SetActiveObjectIds(const std::unordered_set<ObjectID> &&object_ids) {
active_object_ids_ = object_ids;
}
bool Worker::UsePush() const { return rpc_client_ != nullptr; }
void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set,
+5
View File
@@ -57,6 +57,9 @@ class Worker {
ResourceIdSet ReleaseTaskCpuResources();
void AcquireTaskCpuResources(const ResourceIdSet &cpu_resources);
const std::unordered_set<ObjectID> &GetActiveObjectIds() const;
void SetActiveObjectIds(const std::unordered_set<ObjectID> &&object_ids);
bool UsePush() const;
void AssignTask(const Task &task, const ResourceIdSet &resource_id_set,
const std::function<void(Status)> finish_assign_callback);
@@ -91,6 +94,8 @@ class Worker {
// of a task.
ResourceIdSet task_resource_ids_;
std::unordered_set<TaskID> blocked_task_ids_;
/// The set of object IDs that are currently in use on the worker.
std::unordered_set<ObjectID> active_object_ids_;
/// The `ClientCallManager` object that is shared by `WorkerTaskClient` from all
/// workers.
rpc::ClientCallManager &client_call_manager_;
+15
View File
@@ -425,6 +425,21 @@ bool WorkerPool::HasPendingWorkerForTask(const Language &language,
return it != state.tasks_to_dedicated_workers.end();
}
std::unordered_set<ObjectID> WorkerPool::GetActiveObjectIDs() const {
std::unordered_set<ObjectID> active_object_ids;
for (const auto &entry : states_by_lang_) {
for (const auto &worker : entry.second.registered_workers) {
active_object_ids.insert(worker->GetActiveObjectIds().begin(),
worker->GetActiveObjectIds().end());
}
for (const auto &driver : entry.second.registered_drivers) {
active_object_ids.insert(driver->GetActiveObjectIds().begin(),
driver->GetActiveObjectIds().end());
}
}
return active_object_ids;
}
std::string WorkerPool::DebugString() const {
std::stringstream result;
result << "WorkerPool:";
+4
View File
@@ -121,6 +121,10 @@ class WorkerPool {
/// \param task_id The task that we want to query.
bool HasPendingWorkerForTask(const Language &language, const TaskID &task_id);
/// Get the set of active object IDs from all workers in the worker pool.
/// \return A set containing the active object IDs.
std::unordered_set<ObjectID> GetActiveObjectIDs() const;
/// Returns debug string for class.
///
/// \return string.
+33
View File
@@ -0,0 +1,33 @@
#ifndef RAY_UTIL_SAMPLE_H
#define RAY_UTIL_SAMPLE_H
#include <cstddef>
#include <iterator>
#include <vector>

#include "absl/random/random.h"
#include "absl/random/uniform_int_distribution.h"
// Randomly samples up to num_elements from the range [begin, end) using reservoir
// sampling, making a single pass over the range.
//
// \param begin Iterator to the first element of the population.
// \param end Iterator one past the last element of the population.
// \param num_elements Maximum number of elements to sample. If the range holds no
// more than this many elements, every element is returned in iteration order.
// \param out Output vector. Its previous contents are discarded.
template <class Iterator, class T = typename std::iterator_traits<Iterator>::value_type>
void random_sample(Iterator begin, Iterator end, size_t num_elements,
                   std::vector<T> *out) {
  out->clear();
  if (num_elements == 0) {
    return;
  }

  // Phase 1: fill the reservoir with the first num_elements elements. No capacity is
  // reserved up front because num_elements may be far larger than the population.
  Iterator it = begin;
  size_t num_seen = 0;
  for (; it != end && num_seen < num_elements; ++it, ++num_seen) {
    out->push_back(*it);
  }
  if (it == end) {
    // The whole population fit into the reservoir, so no random draw is needed and
    // the random generator never has to be constructed.
    return;
  }

  // Phase 2: for the element at zero-based position num_seen, draw a slot uniformly
  // from [0, num_seen] (bounds are inclusive) and overwrite the reservoir entry only
  // if the slot falls inside the reservoir.
  absl::BitGen gen;
  for (; it != end; ++it, ++num_seen) {
    const size_t slot = absl::uniform_int_distribution<size_t>(0, num_seen)(gen);
    if (slot < num_elements) {
      // slot < num_elements == out->size(), so unchecked indexing is safe here.
      (*out)[slot] = *it;
    }
  }
}
#endif // RAY_UTIL_SAMPLE_H
+68
View File
@@ -0,0 +1,68 @@
#include <vector>
#include "gtest/gtest.h"
#include "ray/util/sample.h"
namespace ray {
class RandomSampleTest : public ::testing::Test {
protected:
std::vector<int> *sample;
std::vector<int> *test_vector;
virtual void SetUp() {
sample = new std::vector<int>();
test_vector = new std::vector<int>();
for (int i = 0; i < 10; i++) {
test_vector->push_back(i);
}
}
virtual void TearDown() {
delete sample;
delete test_vector;
}
};
// Requesting zero elements must leave the output vector empty.
TEST_F(RandomSampleTest, TestEmpty) {
  random_sample(test_vector->begin(), test_vector->end(), 0, sample);
  // Checked via empty() rather than comparing the unsigned size() with the int 0.
  ASSERT_TRUE(sample->empty());
}
// Asking for one more element than the population holds must yield a sample whose
// size equals the population size.
TEST_F(RandomSampleTest, TestSmallerThanSampleSize) {
  const auto population = test_vector->size();
  random_sample(test_vector->begin(), test_vector->end(), population + 1, sample);
  ASSERT_EQ(sample->size(), population);
}
// Asking for exactly as many elements as the population holds must yield a sample
// of that same size.
TEST_F(RandomSampleTest, TestEqualToSampleSize) {
  const auto population = test_vector->size();
  random_sample(test_vector->begin(), test_vector->end(), population, sample);
  ASSERT_EQ(sample->size(), population);
}
// Asking for one element fewer than the population holds must yield a sample of
// exactly the requested size.
TEST_F(RandomSampleTest, TestLargerThanSampleSize) {
  const auto requested = test_vector->size() - 1;
  random_sample(test_vector->begin(), test_vector->end(), requested, sample);
  ASSERT_EQ(sample->size(), requested);
}
// Samples half of the population many times and checks that every element was
// picked in roughly half of the trials (within 5% of the expected count).
TEST_F(RandomSampleTest, TestEqualOccurrenceChance) {
  const int trials = 100000;
  const auto sample_size = test_vector->size() / 2;
  // The population values 0..9 double as indices into the tally.
  std::vector<int> occurrences(test_vector->size(), 0);
  for (int trial = 0; trial < trials; ++trial) {
    random_sample(test_vector->begin(), test_vector->end(), sample_size, sample);
    for (const int value : *sample) {
      ++occurrences[value];
    }
  }

  const int expected = trials / 2;
  const double tolerance = 0.05 * trials / 2;
  for (const int count : occurrences) {
    ASSERT_NEAR(expected, count, tolerance);
  }
}
} // namespace ray
// Test entry point: hands the command-line arguments to googletest, runs every
// registered test and returns the result as the process exit code.
int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  const int exit_code = RUN_ALL_TESTS();
  return exit_code;
}