[xray] Basic actor support (#1835)

This commit is contained in:
Stephanie Wang
2018-04-06 00:17:14 -07:00
committed by Robert Nishihara
parent 313b864e66
commit bf194db4bc
27 changed files with 652 additions and 181 deletions
+2
View File
@@ -802,6 +802,8 @@ def actor_handle_from_class(Class, class_id, actor_creation_resources,
actor_creation_resources,
actor_method_cpus,
ray.worker.global_worker)
# Increment the actor counter to account for the creation task.
actor_counter += 1
# Instantiate the actor handle.
actor_object = cls.__new__(cls)
+4 -7
View File
@@ -49,14 +49,11 @@
return RedisModule_ReplyWithError(ctx, (MESSAGE)); \
}
// NOTE(swang): The order of prefixes here must match the TablePrefix enum
// defined in src/ray/gcs/format/gcs.fbs.
static const char *table_prefixes[] = {
NULL,
"TASK:",
"TASK:",
"CLIENT:",
"OBJECT:",
"FUNCTION:",
"TASK_RECONSTRUCTION:",
NULL, "TASK:", "TASK:", "CLIENT:",
"OBJECT:", "ACTOR:", "FUNCTION:", "TASK_RECONSTRUCTION:",
"HEARTBEAT:",
};
@@ -358,6 +358,12 @@ void handle_convert_worker_to_actor(
* filled out, so fill out the correct worker field now. */
algorithm_state->local_actor_infos[actor_id].worker = worker;
}
/* Increment the task counter for the creator's handle to account for the
* actor creation task. */
auto &task_counters =
algorithm_state->local_actor_infos[actor_id].task_counters;
RAY_CHECK(task_counters[ActorHandleID::nil()] == 0);
task_counters[ActorHandleID::nil()]++;
}
/**
+1 -1
View File
@@ -52,7 +52,7 @@ set(RAY_SRCS
raylet/worker.cc
raylet/worker_pool.cc
raylet/scheduling_resources.cc
raylet/actor.cc
raylet/actor_registration.cc
raylet/scheduling_queue.cc
raylet/scheduling_policy.cc
raylet/task_dependency_manager.cc
+3
View File
@@ -10,6 +10,7 @@ AsyncGcsClient::AsyncGcsClient(const ClientID &client_id) {
context_.reset(new RedisContext());
client_table_.reset(new ClientTable(context_, this, client_id));
object_table_.reset(new ObjectTable(context_, this));
actor_table_.reset(new ActorTable(context_, this));
task_table_.reset(new TaskTable(context_, this));
raylet_task_table_.reset(new raylet::TaskTable(context_, this));
task_reconstruction_log_.reset(new TaskReconstructionLog(context_, this));
@@ -48,6 +49,8 @@ TaskTable &AsyncGcsClient::task_table() { return *task_table_; }
raylet::TaskTable &AsyncGcsClient::raylet_task_table() { return *raylet_task_table_; }
ActorTable &AsyncGcsClient::actor_table() { return *actor_table_; }
TaskReconstructionLog &AsyncGcsClient::task_reconstruction_log() {
return *task_reconstruction_log_;
}
+2 -1
View File
@@ -45,12 +45,12 @@ class RAY_EXPORT AsyncGcsClient {
inline FunctionTable &function_table();
// TODO: Some API for getting the error on the driver
inline ClassTable &class_table();
inline ActorTable &actor_table();
inline CustomSerializerTable &custom_serializer_table();
inline ConfigTable &config_table();
ObjectTable &object_table();
TaskTable &task_table();
raylet::TaskTable &raylet_task_table();
ActorTable &actor_table();
TaskReconstructionLog &task_reconstruction_log();
ClientTable &client_table();
HeartbeatTable &heartbeat_table();
@@ -72,6 +72,7 @@ class RAY_EXPORT AsyncGcsClient {
std::unique_ptr<ObjectTable> object_table_;
std::unique_ptr<TaskTable> task_table_;
std::unique_ptr<raylet::TaskTable> raylet_task_table_;
std::unique_ptr<ActorTable> actor_table_;
std::unique_ptr<TaskReconstructionLog> task_reconstruction_log_;
std::unique_ptr<HeartbeatTable> heartbeat_table_;
std::unique_ptr<ClientTable> client_table_;
+12 -1
View File
@@ -10,9 +10,10 @@ enum TablePrefix:int {
RAYLET_TASK,
CLIENT,
OBJECT,
ACTOR,
FUNCTION,
TASK_RECONSTRUCTION,
HEARTBEAT
HEARTBEAT,
}
// The channel that Add operations to the Table should be published on, if any.
@@ -89,6 +90,16 @@ table ClassTableData {
}
table ActorTableData {
// The ID of the actor that was created.
actor_id: string;
// The dummy object ID returned by the actor creation task. If the actor
// dies, then this is the object that should be reconstructed for the actor
// to be recreated.
actor_creation_dummy_object_id: string;
// The ID of the driver that created the actor.
driver_id: string;
// The ID of the node manager that created the actor.
node_manager_id: string;
}
table ErrorTableData {
+1
View File
@@ -344,6 +344,7 @@ template class Log<ObjectID, ObjectTableData>;
template class Log<TaskID, ray::protocol::Task>;
template class Table<TaskID, ray::protocol::Task>;
template class Table<TaskID, TaskTableData>;
template class Log<ActorID, ActorTableData>;
template class Log<TaskID, TaskReconstructionData>;
template class Table<ClientID, HeartbeatTableData>;
+10 -2
View File
@@ -277,14 +277,22 @@ class FunctionTable : public Table<ObjectID, FunctionTableData> {
using ClassTable = Table<ClassID, ClassTableData>;
// TODO(swang): Set the pubsub channel for the actor table.
using ActorTable = Table<ActorID, ActorTableData>;
class ActorTable : public Log<ActorID, ActorTableData> {
public:
ActorTable(const std::shared_ptr<RedisContext> &context, AsyncGcsClient *client)
: Log(context, client) {
pubsub_channel_ = TablePubsub_ACTOR;
prefix_ = TablePrefix_TASK_RECONSTRUCTION;
}
};
class TaskReconstructionLog : public Log<TaskID, TaskReconstructionData> {
public:
TaskReconstructionLog(const std::shared_ptr<RedisContext> &context,
AsyncGcsClient *client)
: Log(context, client) {
prefix_ = TablePrefix_TASK_RECONSTRUCTION;
pubsub_channel_ = TablePubsub_ACTOR;
prefix_ = TablePrefix_ACTOR;
}
};
-15
View File
@@ -1,15 +0,0 @@
#include "actor.h"
namespace ray {
namespace raylet {
ActorInformation::ActorInformation() : id_(UniqueID::nil()) {}
ActorInformation::~ActorInformation() {}
const ActorID &ActorInformation::GetActorId() const { return this->id_; }
} // namespace raylet
} // namespace ray
-31
View File
@@ -1,31 +0,0 @@
#ifndef RAY_RAYLET_ACTOR_H
#define RAY_RAYLET_ACTOR_H
#include "ray/id.h"
namespace ray {
namespace raylet {
class ActorInformation {
public:
/// \brief ActorInformation constructor.
ActorInformation();
/// \brief ActorInformation destructor.
~ActorInformation();
/// \brief Return the id of this actor.
/// \return actor id.
const ActorID &GetActorId() const;
private:
/// Unique identifier for this actor.
ActorID id_;
}; // class ActorInformation
} // namespace raylet
} // namespace ray
#endif // RAY_RAYLET_ACTOR_H
+41
View File
@@ -0,0 +1,41 @@
#include "ray/raylet/actor_registration.h"
#include "ray/util/logging.h"
namespace ray {
namespace raylet {
ActorRegistration::ActorRegistration(const ActorTableDataT &actor_table_data)
: actor_table_data_(actor_table_data),
execution_dependency_(ObjectID::nil()),
frontier_() {}
const ClientID ActorRegistration::GetNodeManagerId() const {
return ClientID::from_binary(actor_table_data_.node_manager_id);
}
const ObjectID ActorRegistration::GetActorCreationDependency() const {
return ObjectID::from_binary(actor_table_data_.actor_creation_dummy_object_id);
}
const ObjectID ActorRegistration::GetExecutionDependency() const {
return execution_dependency_;
}
const std::unordered_map<ActorHandleID, ActorRegistration::FrontierLeaf, UniqueIDHasher>
&ActorRegistration::GetFrontier() const {
return frontier_;
}
void ActorRegistration::ExtendFrontier(const ActorHandleID &handle_id,
const ObjectID &execution_dependency) {
auto &frontier_entry = frontier_[handle_id];
frontier_entry.task_counter++;
frontier_entry.execution_dependency = execution_dependency;
execution_dependency_ = execution_dependency;
}
} // namespace raylet
} // namespace ray
+96
View File
@@ -0,0 +1,96 @@
#ifndef RAY_RAYLET_ACTOR_REGISTRATION_H
#define RAY_RAYLET_ACTOR_REGISTRATION_H
#include <unordered_map>
#include "ray/gcs/format/gcs_generated.h"
#include "ray/id.h"
namespace ray {
namespace raylet {
/// \class ActorRegistration
///
/// Information about an actor registered in the system. This includes the
/// actor's current node manager location, and if local, information about its
/// current execution state, used for reconstruction purposes.
class ActorRegistration {
public:
/// Create an actor registration.
///
/// \param actor_table_data Information from the global actor table about
/// this actor. This includes the actor's node manager location.
ActorRegistration(const ActorTableDataT &actor_table_data);
/// Each actor may have multiple callers, or "handles". A frontier leaf
/// represents the execution state of the actor with respect to a single
/// handle.
struct FrontierLeaf {
/// The number of tasks submitted by this handle that have executed on the
/// actor so far.
int64_t task_counter;
/// The execution dependency returned by the task submitted by this handle
/// that most recently executed on the actor.
ObjectID execution_dependency;
};
/// Get the actor's node manager location.
///
/// \return The actor's node manager location. All tasks for the actor should
/// be forwarded to this node.
const ClientID GetNodeManagerId() const;
/// Get the object that represents the actor's initial state. This is the
/// execution dependency returned by this actor's creation task. If
/// reconstructed, this will recreate the actor.
///
/// \return The execution dependency returned by the actor's creation task.
const ObjectID GetActorCreationDependency() const;
/// Get the object that represents the actor's current state. This is the
/// execution dependency returned by the task most recently executed on the
/// actor. The next task to execute on the actor should be marked as
/// execution-dependent on this object.
///
/// \return The execution dependency returned by the most recently executed
/// task.
const ObjectID GetExecutionDependency() const;
/// Get the execution frontier of the actor, indexed by handle. This captures
/// the execution state of the actor, a summary of which tasks have executed
/// so far.
///
/// \return The actor frontier, a map from handle ID to execution state for
/// that handle.
const std::unordered_map<ActorHandleID, FrontierLeaf, UniqueIDHasher> &GetFrontier()
const;
/// Extend the frontier of the actor by a single task. This should be called
/// whenever the actor executes a task.
///
/// \param handle_id The ID of the handle that submitted the task.
/// \param execution_dependency The object representing the actor's new
/// state. This is the execution dependency returned by the task.
void ExtendFrontier(const ActorHandleID &handle_id,
const ObjectID &execution_dependency);
private:
/// Information from the global actor table about this actor, including the
/// node manager location.
ActorTableDataT actor_table_data_;
/// The object representing the state following the actor's most recently
/// executed task. The next task to execute on the actor should be marked as
/// execution-dependent on this object.
ObjectID execution_dependency_;
/// The execution frontier of the actor, which represents which tasks have
/// executed so far and which tasks may execute next, based on execution
/// dependencies. This is indexed by handle.
std::unordered_map<ActorHandleID, FrontierLeaf, UniqueIDHasher> frontier_;
};
} // namespace raylet
} // namespace ray
#endif // RAY_RAYLET_ACTOR_REGISTRATION_H
+268 -49
View File
@@ -3,6 +3,37 @@
#include "common_protocol.h"
#include "ray/raylet/format/node_manager_generated.h"
namespace {
/// A helper function to determine whether a given actor task has already been executed
/// according to the given actor registry. Returns true if the task is a duplicate.
bool CheckDuplicateActorTask(
const std::unordered_map<ActorID, ray::raylet::ActorRegistration, UniqueIDHasher>
&actor_registry,
const ray::raylet::TaskSpecification &spec) {
auto actor_entry = actor_registry.find(spec.ActorId());
RAY_CHECK(actor_entry != actor_registry.end());
const auto &frontier = actor_entry->second.GetFrontier();
int64_t expected_task_counter = 0;
auto frontier_entry = frontier.find(spec.ActorHandleId());
if (frontier_entry != frontier.end()) {
expected_task_counter = frontier_entry->second.task_counter;
}
if (spec.ActorCounter() < expected_task_counter) {
// The assigned task counter is less than expected. The actor has already
// executed past this task, so do not assign the task again.
RAY_LOG(WARNING) << "A task was resubmitted, so we are ignoring it. This "
<< "should only happen during reconstruction.";
return true;
}
RAY_CHECK(spec.ActorCounter() == expected_task_counter)
<< "Expected actor counter: " << expected_task_counter
<< ", got: " << spec.ActorCounter();
return false;
};
} // namespace
namespace ray {
namespace raylet {
@@ -26,7 +57,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
gcs_client_(gcs_client),
remote_clients_(),
remote_server_connections_(),
object_manager_(object_manager) {
object_manager_(object_manager),
actor_registry_() {
RAY_CHECK(heartbeat_period_ms_ > 0);
// Initialize the resource map with own cluster resource configuration.
ClientID local_client_id = gcs_client_->client_table().GetLocalClientId();
@@ -34,6 +66,39 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
SchedulingResources(config.resource_config));
}
ray::Status NodeManager::RegisterGcs() {
// Register a callback for actor creation notifications.
auto actor_creation_callback = [this](
gcs::AsyncGcsClient *client, const ActorID &actor_id,
const std::vector<ActorTableDataT> &data) { HandleActorCreation(actor_id, data); };
RAY_RETURN_NOT_OK(gcs_client_->actor_table().Subscribe(
UniqueID::nil(), UniqueID::nil(), actor_creation_callback, nullptr));
// Register a callback on the client table for new clients.
auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id,
const ClientTableDataT &data) {
ClientAdded(data);
};
gcs_client_->client_table().RegisterClientAddedCallback(node_manager_client_added);
// Subscribe to node manager heartbeats.
const auto heartbeat_added = [this](gcs::AsyncGcsClient *client, const ClientID &id,
const HeartbeatTableDataT &heartbeat_data) {
HeartbeatAdded(client, id, heartbeat_data);
};
RAY_RETURN_NOT_OK(gcs_client_->heartbeat_table().Subscribe(
UniqueID::nil(), UniqueID::nil(), heartbeat_added,
[this](gcs::AsyncGcsClient *client) {
RAY_LOG(DEBUG) << "heartbeat table subscription done callback called.";
}));
// Start sending heartbeats to the GCS.
Heartbeat();
return ray::Status::OK();
}
void NodeManager::Heartbeat() {
RAY_LOG(DEBUG) << "[Heartbeat] sending heartbeat.";
auto &heartbeat_table = gcs_client_->heartbeat_table();
@@ -75,27 +140,13 @@ void NodeManager::Heartbeat() {
});
}
void NodeManager::ClientAdded(gcs::AsyncGcsClient *client, const UniqueID &id,
const ClientTableDataT &client_data) {
void NodeManager::ClientAdded(const ClientTableDataT &client_data) {
ClientID client_id = ClientID::from_binary(client_data.client_id);
RAY_LOG(DEBUG) << "[ClientAdded] received callback from client id " << client_id.hex();
if (client_id == gcs_client_->client_table().GetLocalClientId()) {
// We got a notification for ourselves, so we are connected to the GCS now.
// Save this NodeManager's resource information in the cluster resource map.
cluster_resource_map_[client_id] = local_resources_;
// Start sending heartbeats to the GCS.
Heartbeat();
// Subscribe to heartbeats.
const auto heartbeat_added = [this](gcs::AsyncGcsClient *client, const ClientID &id,
const HeartbeatTableDataT &heartbeat_data) {
this->HeartbeatAdded(client, id, heartbeat_data);
};
ray::Status status = client->heartbeat_table().Subscribe(
UniqueID::nil(), UniqueID::nil(), heartbeat_added,
[](gcs::AsyncGcsClient *client) {
RAY_LOG(DEBUG) << "heartbeat table subscription done callback called.";
});
RAY_CHECK_OK(status);
return;
}
@@ -154,6 +205,46 @@ void NodeManager::HeartbeatAdded(gcs::AsyncGcsClient *client, const ClientID &cl
heartbeat_resource_available);
}
void NodeManager::HandleActorCreation(const ActorID &actor_id,
const std::vector<ActorTableDataT> &data) {
RAY_LOG(DEBUG) << "Actor creation notification received: " << actor_id;
// TODO(swang): In presence of failures, data may have size > 1, since the
// actor will have been created multiple times. In that case, we should
// only consider the last entry as valid. All previous entries should have
// a dead node_manager_id.
RAY_CHECK(data.size() == 1);
// Register the new actor.
ActorRegistration actor_registration(data.back());
// Extend the frontier to include the actor creation task. NOTE(swang): The
// creator of the actor is always assigned nil as the actor handle ID.
actor_registration.ExtendFrontier(ActorHandleID::nil(),
actor_registration.GetActorCreationDependency());
auto inserted = actor_registry_.emplace(actor_id, std::move(actor_registration));
RAY_CHECK(inserted.second);
// Dequeue any methods that were submitted before the actor's location was
// known.
const auto &methods = local_queues_.GetUncreatedActorMethods();
std::unordered_set<TaskID, UniqueIDHasher> created_actor_method_ids;
for (const auto &method : methods) {
if (method.GetTaskSpecification().ActorId() == actor_id) {
created_actor_method_ids.insert(method.GetTaskSpecification().TaskId());
}
}
// Resubmit the methods that were submitted before the actor's location was
// known.
auto created_actor_methods = local_queues_.RemoveTasks(created_actor_method_ids);
for (const auto &method : created_actor_methods) {
lineage_cache_.RemoveWaitingTask(method.GetTaskSpecification().TaskId());
// The task's uncommitted lineage was already added to the local lineage
// cache upon the initial submission, so it's okay to resubmit it with an
// empty lineage this time.
SubmitTask(method, Lineage());
}
}
void NodeManager::ProcessNewClient(std::shared_ptr<LocalClientConnection> client) {
// The new client is a worker, so begin listening for messages.
client->ProcessMessages();
@@ -175,31 +266,39 @@ void NodeManager::ProcessClientMessage(std::shared_ptr<LocalClientConnection> cl
}
} break;
case protocol::MessageType_GetTask: {
const std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
RAY_CHECK(worker);
// If the worker was assigned a task, mark it as finished.
if (!worker->GetAssignedTaskId().is_nil()) {
FinishTask(worker->GetAssignedTaskId());
FinishAssignedTask(worker);
}
// Return the worker to the idle pool.
worker_pool_.PushWorker(worker);
// Check if there is a scheduled task that can now be assigned to the newly
// idle worker.
auto scheduled_tasks = local_queues_.GetScheduledTasks();
if (!scheduled_tasks.empty()) {
const TaskID &scheduled_task_id =
scheduled_tasks.front().GetTaskSpecification().TaskId();
auto scheduled_tasks = local_queues_.RemoveTasks({scheduled_task_id});
AssignTask(scheduled_tasks.front());
// Find a scheduled task that whose actor ID matches that of the newly
// idle worker.
auto worker_actor_id = worker->GetActorId();
for (const auto &task : scheduled_tasks) {
if (task.GetTaskSpecification().ActorId() == worker_actor_id) {
auto scheduled_tasks =
local_queues_.RemoveTasks({task.GetTaskSpecification().TaskId()});
AssignTask(scheduled_tasks.front());
}
}
}
} break;
case protocol::MessageType_DisconnectClient: {
// Remove the dead worker from the pool and stop listening for messages.
const std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
if (worker) {
if (!worker->GetAssignedTaskId().is_nil()) {
// TODO(swang): Clean up any tasks that were assigned to the worker.
// Release any resources that may be held by this worker.
FinishTask(worker->GetAssignedTaskId());
}
// TODO(swang): Handle the case where the worker is killed while
// executing a task. Clean up the assigned task's resources, return an
// error to the driver.
// RAY_CHECK(worker->GetAssignedTaskId().is_nil())
// << "Worker died while executing task: " << worker->GetAssignedTaskId();
worker_pool_.DisconnectWorker(worker);
}
return;
@@ -300,9 +399,57 @@ void NodeManager::ScheduleTasks() {
}
void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineage) {
const TaskSpecification &spec = task.GetTaskSpecification();
// Add the task and its uncommitted lineage to the lineage cache.
lineage_cache_.AddWaitingTask(task, uncommitted_lineage);
// Queue the task according to the availability of its arguments.
if (spec.IsActorTask()) {
// Check whether we know the location of the actor.
const auto actor_entry = actor_registry_.find(spec.ActorId());
if (actor_entry != actor_registry_.end()) {
// We have a known location for the actor.
auto node_manager_id = actor_entry->second.GetNodeManagerId();
if (node_manager_id == gcs_client_->client_table().GetLocalClientId()) {
// The actor is local. Queue the task for local execution.
QueueTask(task);
} else {
// The actor is remote. Forward the task to the node manager that owns
// the actor.
// TODO(swang): Handle forward task failure.
RAY_CHECK_OK(ForwardTask(task, node_manager_id));
}
} else {
// We do not have a registered location for the object, so either the
// actor has not yet been created or we missed the notification for the
// actor creation because this node joined the cluster after the actor
// was already created. Look up the actor's registered location in case
// we missed the creation notification.
// NOTE(swang): This codepath needs to be tested in a cluster setting.
auto lookup_callback = [this](gcs::AsyncGcsClient *client, const ActorID &actor_id,
const std::vector<ActorTableDataT> &data) {
if (!data.empty()) {
// The actor has been created.
HandleActorCreation(actor_id, data);
} else {
// The actor has not yet been created.
// TODO(swang): Set a timer for reconstructing the actor creation
// task.
}
};
RAY_CHECK_OK(gcs_client_->actor_table().Lookup(JobID::nil(), spec.ActorId(),
lookup_callback));
// Keep the task queued until we discover the actor's location.
local_queues_.QueueUncreatedActorMethods({task});
}
} else {
// This is a non-actor task. Queue the task for local execution.
QueueTask(task);
}
}
void NodeManager::QueueTask(const Task &task) {
// Queue the task depending on the availability of its arguments.
if (task_dependency_manager_.TaskReady(task)) {
local_queues_.QueueReadyTasks(std::vector<Task>({task}));
ScheduleTasks();
@@ -312,27 +459,38 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag
}
}
void NodeManager::AssignTask(const Task &task) {
void NodeManager::AssignTask(Task &task) {
const TaskSpecification &spec = task.GetTaskSpecification();
// If this is an actor task, check that the new task has the correct counter.
if (spec.IsActorTask()) {
if (CheckDuplicateActorTask(actor_registry_, spec)) {
// Drop tasks that have already been executed.
return;
}
}
// Resource accounting: acquire resources for the scheduled task.
const ClientID &my_client_id = gcs_client_->client_table().GetLocalClientId();
RAY_CHECK(this->cluster_resource_map_[my_client_id].Acquire(
task.GetTaskSpecification().GetRequiredResources()));
RAY_CHECK(
this->cluster_resource_map_[my_client_id].Acquire(spec.GetRequiredResources()));
if (worker_pool_.PoolSize() == 0) {
worker_pool_.StartWorker();
// Try to get an idle worker that can execute this task.
std::shared_ptr<Worker> worker = worker_pool_.PopWorker(spec.ActorId());
if (worker == nullptr) {
// There are no workers that can execute this task.
if (!spec.IsActorTask()) {
// There are no more non-actor workers available to execute this task.
// Start a new worker.
worker_pool_.StartWorker();
}
// Queue this task for future assignment. The task will be assigned to a
// worker once one becomes available.
local_queues_.QueueScheduledTasks(std::vector<Task>({task}));
return;
}
const TaskSpecification &spec = task.GetTaskSpecification();
std::shared_ptr<Worker> worker = worker_pool_.PopWorker();
RAY_LOG(DEBUG) << "Assigning task to worker with pid " << worker->Pid();
worker->AssignTaskId(spec.TaskId());
local_queues_.QueueRunningTasks(std::vector<Task>({task}));
flatbuffers::FlatBufferBuilder fbb;
auto message = protocol::CreateGetTaskReply(fbb, spec.ToFlatbuffer(fbb),
fbb.CreateVector(std::vector<int>()));
@@ -340,33 +498,94 @@ void NodeManager::AssignTask(const Task &task) {
auto status = worker->Connection()->WriteMessage(protocol::MessageType_ExecuteTask,
fbb.GetSize(), fbb.GetBufferPointer());
if (status.ok()) {
// We successfully assigned the task to the worker.
worker->AssignTaskId(spec.TaskId());
// If the task was an actor task, then record this execution to guarantee
// consistency in the case of reconstruction.
if (spec.IsActorTask()) {
// Extend the frontier to include the executing task.
auto actor_entry = actor_registry_.find(spec.ActorId());
RAY_CHECK(actor_entry != actor_registry_.end());
actor_entry->second.ExtendFrontier(spec.ActorHandleId(), spec.ActorDummyObject());
// Update the task's execution dependencies to reflect the actual
// execution order, to support deterministic reconstruction.
// NOTE(swang): The update of an actor task's execution dependencies is
// performed asynchronously. This means that if this node manager dies,
// we may lose updates that are in flight to the task table. We only
// guarantee deterministic reconstruction ordering for tasks whose
// updates are reflected in the task table.
TaskExecutionSpecification &mutable_spec = task.GetTaskExecutionSpec();
mutable_spec.SetExecutionDependencies(
{actor_entry->second.GetExecutionDependency()});
}
// We started running the task, so the task is ready to write to GCS.
lineage_cache_.AddReadyTask(task);
// Mark the task as running.
local_queues_.QueueRunningTasks(std::vector<Task>({task}));
} else {
// We failed to send the task to the worker, so disconnect the worker. The
// task will get queued again during cleanup.
RAY_LOG(WARNING) << "Failed to send task to worker, disconnecting client";
// We failed to send the task to the worker, so disconnect the worker.
ProcessClientMessage(worker->Connection(), protocol::MessageType_DisconnectClient,
NULL);
// Queue this task for future assignment. The task will be assigned to a
// worker once one becomes available.
local_queues_.QueueScheduledTasks(std::vector<Task>({task}));
}
}
void NodeManager::FinishTask(const TaskID &task_id) {
RAY_LOG(DEBUG) << "Finished task " << task_id.hex();
void NodeManager::FinishAssignedTask(std::shared_ptr<Worker> worker) {
TaskID task_id = worker->GetAssignedTaskId();
RAY_LOG(DEBUG) << "Finished task " << task_id;
auto tasks = local_queues_.RemoveTasks({task_id});
RAY_CHECK(tasks.size() == 1);
auto task = *tasks.begin();
// Resource accounting: release task's resources.
RAY_CHECK(
this->cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Release(
task.GetTaskSpecification().GetRequiredResources()));
if (task.GetTaskSpecification().IsActorCreationTask()) {
// If this was an actor creation task, then convert the worker to an actor.
auto actor_id = task.GetTaskSpecification().ActorCreationId();
worker->AssignActorId(actor_id);
// Publish the actor creation event to all other nodes so that methods for
// the actor will be forwarded directly to this node.
auto actor_notification = std::make_shared<ActorTableDataT>();
actor_notification->actor_id = actor_id.binary();
actor_notification->actor_creation_dummy_object_id =
task.GetTaskSpecification().ActorCreationDummyObjectId().binary();
// TODO(swang): The driver ID.
actor_notification->driver_id = JobID::nil().binary();
actor_notification->node_manager_id =
gcs_client_->client_table().GetLocalClientId().binary();
RAY_LOG(DEBUG) << "Publishing actor creation: " << actor_id;
RAY_CHECK_OK(gcs_client_->actor_table().Append(JobID::nil(), actor_id,
actor_notification, nullptr));
// Resources required by an actor creation task are acquired for the
// lifetime of the actor, so we do not release any resources here.
} else {
// Release task's resources.
RAY_CHECK(this->cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()]
.Release(task.GetTaskSpecification().GetRequiredResources()));
}
// If the finished task was an actor task, mark the returned dummy object as
// locally available. This is not added to the object table, so the update
// will be invisible to both the local object manager and the other nodes.
// NOTE(swang): These objects are never cleaned up. We should consider
// removing the objects, e.g., when an actor is terminated.
if (task.GetTaskSpecification().IsActorCreationTask() ||
task.GetTaskSpecification().IsActorTask()) {
auto dummy_object = task.GetTaskSpecification().ActorDummyObject();
task_dependency_manager_.MarkDependencyReady(dummy_object);
}
// Unset the worker's assigned task.
worker->AssignTaskId(TaskID::nil());
}
void NodeManager::ResubmitTask(const TaskID &task_id) {
throw std::runtime_error("Method not implemented");
}
ray::Status NodeManager::ForwardTask(Task &task, const ClientID &node_id) {
ray::Status NodeManager::ForwardTask(const Task &task, const ClientID &node_id) {
auto task_id = task.GetTaskSpecification().TaskId();
// Get and serialize the task's uncommitted lineage.
+17 -7
View File
@@ -5,6 +5,7 @@
#include "ray/raylet/task.h"
#include "ray/object_manager/object_manager.h"
#include "ray/common/client_connection.h"
#include "ray/raylet/actor_registration.h"
#include "ray/raylet/lineage_cache.h"
#include "ray/raylet/scheduling_policy.h"
#include "ray/raylet/scheduling_queue.h"
@@ -53,26 +54,34 @@ class NodeManager {
void ProcessNodeManagerMessage(std::shared_ptr<TcpClientConnection> node_manager_client,
int64_t message_type, const uint8_t *message);
void ClientAdded(gcs::AsyncGcsClient *client, const UniqueID &id,
const ClientTableDataT &data);
ray::Status RegisterGcs();
void HeartbeatAdded(gcs::AsyncGcsClient *client, const ClientID &id,
const HeartbeatTableDataT &data);
private:
// Handler for the addition of a new GCS client.
void ClientAdded(const ClientTableDataT &data);
// Handler for the creation of an actor, possibly on a remote node.
void HandleActorCreation(const ActorID &actor_id,
const std::vector<ActorTableDataT> &data);
// Queue a task for local execution.
void QueueTask(const Task &task);
/// Submit a task to this node.
void SubmitTask(const Task &task, const Lineage &uncommitted_lineage);
/// Assign a task.
void AssignTask(const Task &task);
/// Finish a task.
void FinishTask(const TaskID &task_id);
/// Assign a task. The task is assumed to not be queued in local_queues_.
void AssignTask(Task &task);
/// Handle a worker finishing its assigned task.
void FinishAssignedTask(std::shared_ptr<Worker> worker);
/// Schedule tasks.
void ScheduleTasks();
/// Handle a task whose local dependencies were missing and are now available.
void HandleWaitingTaskReady(const TaskID &task_id);
/// Resubmit a task whose return value needs to be reconstructed.
void ResubmitTask(const TaskID &task_id);
ray::Status ForwardTask(Task &task, const ClientID &node_id);
/// Forward a task to another node to execute. The task is assumed to not be
/// queued in local_queues_.
ray::Status ForwardTask(const Task &task, const ClientID &node_id);
/// Send heartbeats to the GCS.
void Heartbeat();
@@ -101,6 +110,7 @@ class NodeManager {
std::unordered_map<ClientID, TcpServerConnection, UniqueIDHasher>
remote_server_connections_;
ObjectManager &object_manager_;
std::unordered_map<ActorID, ActorRegistration, UniqueIDHasher> actor_registry_;
};
} // namespace raylet
+2 -5
View File
@@ -73,11 +73,8 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address,
<< " port " << client_info.node_manager_port;
RAY_RETURN_NOT_OK(gcs_client_->client_table().Connect(client_info));
auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id,
const ClientTableDataT &data) {
node_manager_.ClientAdded(client, id, data);
};
gcs_client_->client_table().RegisterClientAddedCallback(node_manager_client_added);
RAY_RETURN_NOT_OK(node_manager_.RegisterGcs());
return Status::OK();
}
+9 -8
View File
@@ -6,6 +6,10 @@ namespace ray {
namespace raylet {
const std::list<Task> &SchedulingQueue::GetUncreatedActorMethods() const {
return this->uncreated_actor_methods_;
}
const std::list<Task> &SchedulingQueue::GetWaitingTasks() const {
return this->waiting_tasks_;
}
@@ -56,6 +60,7 @@ std::vector<Task> SchedulingQueue::RemoveTasks(
std::vector<Task> removed_tasks;
// Try to find the tasks to remove from the waiting tasks.
removeTasksFromQueue(uncreated_actor_methods_, task_ids, removed_tasks);
removeTasksFromQueue(waiting_tasks_, task_ids, removed_tasks);
removeTasksFromQueue(ready_tasks_, task_ids, removed_tasks);
removeTasksFromQueue(scheduled_tasks_, task_ids, removed_tasks);
@@ -66,6 +71,10 @@ std::vector<Task> SchedulingQueue::RemoveTasks(
return removed_tasks;
}
void SchedulingQueue::QueueUncreatedActorMethods(const std::vector<Task> &tasks) {
queueTasks(uncreated_actor_methods_, tasks);
}
void SchedulingQueue::QueueWaitingTasks(const std::vector<Task> &tasks) {
queueTasks(waiting_tasks_, tasks);
}
@@ -82,14 +91,6 @@ void SchedulingQueue::QueueRunningTasks(const std::vector<Task> &tasks) {
queueTasks(running_tasks_, tasks);
}
// RegisterActor is responsible for recording provided actor_information
// in the actor registry.
bool SchedulingQueue::RegisterActor(ActorID actor_id,
const ActorInformation &actor_information) {
actor_registry_[actor_id] = actor_information;
return true;
}
} // namespace raylet
} // namespace ray
+14 -9
View File
@@ -6,7 +6,6 @@
#include <unordered_set>
#include <vector>
#include "ray/raylet/actor.h"
#include "ray/raylet/task.h"
namespace ray {
@@ -29,6 +28,13 @@ class SchedulingQueue {
/// SchedulingQueue destructor.
virtual ~SchedulingQueue() {}
/// Get the queue of tasks that are destined for actors that have not yet
/// been created.
///
/// \return A const reference to the queue of tasks that are destined for
/// actors that have not yet been created.
const std::list<Task> &GetUncreatedActorMethods() const;
/// Get the queue of tasks in the waiting state.
///
/// \return A const reference to the queue of tasks that are waiting for
@@ -66,6 +72,11 @@ class SchedulingQueue {
/// \return A vector of the tasks that were removed.
std::vector<Task> RemoveTasks(std::unordered_set<TaskID, UniqueIDHasher> tasks);
/// Queue tasks that are destined for actors that have not yet been created.
///
/// \param tasks The tasks to queue.
void QueueUncreatedActorMethods(const std::vector<Task> &tasks);
/// Queue tasks in the waiting state.
///
/// \param tasks The tasks to queue.
@@ -86,13 +97,9 @@ class SchedulingQueue {
/// \param tasks The tasks to queue.
void QueueRunningTasks(const std::vector<Task> &tasks);
/// Register an actor.
///
/// \param actor_id The ID of the actor to register.
/// \param actor_information Information about the actor.
bool RegisterActor(ActorID actor_id, const ActorInformation &actor_information);
private:
/// Tasks that are destined for actors that have not yet been created.
std::list<Task> uncreated_actor_methods_;
/// Tasks that are waiting for an object dependency to appear locally.
std::list<Task> waiting_tasks_;
/// Tasks whose object dependencies are locally available, but that are
@@ -102,8 +109,6 @@ class SchedulingQueue {
std::list<Task> scheduled_tasks_;
/// Tasks that are running on a worker.
std::list<Task> running_tasks_;
/// The registry of known actors.
std::unordered_map<ActorID, ActorInformation, UniqueIDHasher> actor_registry_;
};
} // namespace raylet
+1 -1
View File
@@ -105,7 +105,7 @@ void TaskDependencyManager::UnsubscribeTaskReady(const TaskID &task_id) {
}
void TaskDependencyManager::MarkDependencyReady(const ObjectID &object) {
throw std::runtime_error("Method not implemented");
handleObjectReady(object);
}
} // namespace raylet
+54 -14
View File
@@ -38,11 +38,19 @@ TaskSpecification::TaskSpecification(const flatbuffers::String &string) {
}
TaskSpecification::TaskSpecification(
UniqueID driver_id, TaskID parent_task_id, int64_t parent_counter,
// UniqueID actor_id,
// UniqueID actor_handle_id,
// int64_t actor_counter,
FunctionID function_id,
const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter,
const FunctionID &function_id,
const std::vector<std::shared_ptr<TaskArgument>> &task_arguments, int64_t num_returns,
const std::unordered_map<std::string, double> &required_resources)
: TaskSpecification(driver_id, parent_task_id, parent_counter, ActorID::nil(),
ObjectID::nil(), ActorID::nil(), ActorHandleID::nil(), -1,
function_id, task_arguments, num_returns, required_resources) {}
TaskSpecification::TaskSpecification(
const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter,
const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id,
const ActorID &actor_id, const ActorHandleID &actor_handle_id, int64_t actor_counter,
const FunctionID &function_id,
const std::vector<std::shared_ptr<TaskArgument>> &task_arguments, int64_t num_returns,
const std::unordered_map<std::string, double> &required_resources)
: spec_() {
@@ -54,10 +62,6 @@ TaskSpecification::TaskSpecification(
sha256_update(&ctx, (BYTE *)&driver_id, sizeof(driver_id));
sha256_update(&ctx, (BYTE *)&parent_task_id, sizeof(parent_task_id));
sha256_update(&ctx, (BYTE *)&parent_counter, sizeof(parent_counter));
// sha256_update(&ctx, (BYTE *) &actor_id, sizeof(actor_id));
// sha256_update(&ctx, (BYTE *) &actor_counter, sizeof(actor_counter));
// sha256_update(&ctx, (BYTE *) &is_actor_checkpoint_method,
// sizeof(is_actor_checkpoint_method));
// Compute the final task ID from the hash.
BYTE buff[DIGEST_SIZE];
@@ -82,11 +86,11 @@ TaskSpecification::TaskSpecification(
// Serialize the TaskSpecification.
auto spec = CreateTaskInfo(
fbb, to_flatbuf(fbb, driver_id), to_flatbuf(fbb, task_id),
to_flatbuf(fbb, parent_task_id), parent_counter, to_flatbuf(fbb, ActorID::nil()),
to_flatbuf(fbb, ActorID::nil()), to_flatbuf(fbb, WorkerID::nil()),
to_flatbuf(fbb, ActorHandleID::nil()), 0, false, to_flatbuf(fbb, function_id),
fbb.CreateVector(arguments), fbb.CreateVector(returns),
map_to_flatbuf(fbb, required_resources));
to_flatbuf(fbb, parent_task_id), parent_counter, to_flatbuf(fbb, actor_creation_id),
to_flatbuf(fbb, actor_creation_dummy_object_id), to_flatbuf(fbb, actor_id),
to_flatbuf(fbb, actor_handle_id), actor_counter, false,
to_flatbuf(fbb, function_id), fbb.CreateVector(arguments),
fbb.CreateVector(returns), map_to_flatbuf(fbb, required_resources));
fbb.Finish(spec);
AssignSpecification(fbb.GetBufferPointer(), fbb.GetSize());
}
@@ -165,6 +169,42 @@ const ResourceSet TaskSpecification::GetRequiredResources() const {
return ResourceSet(required_resources);
}
bool TaskSpecification::IsActorCreationTask() const {
return !ActorCreationId().is_nil();
}
bool TaskSpecification::IsActorTask() const { return !ActorId().is_nil(); }
ActorID TaskSpecification::ActorCreationId() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf(*message->actor_creation_id());
}
ObjectID TaskSpecification::ActorCreationDummyObjectId() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf(*message->actor_creation_dummy_object_id());
}
ActorID TaskSpecification::ActorId() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf(*message->actor_id());
}
ActorHandleID TaskSpecification::ActorHandleId() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return from_flatbuf(*message->actor_handle_id());
}
int64_t TaskSpecification::ActorCounter() const {
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
return message->actor_counter();
}
ObjectID TaskSpecification::ActorDummyObject() const {
RAY_CHECK(IsActorTask() || IsActorCreationTask());
return ReturnId(NumReturns() - 1);
}
} // namespace raylet
} // namespace ray
+21 -5
View File
@@ -94,15 +94,21 @@ class TaskSpecification {
/// \param arguments The list of task arguments.
/// \param num_returns The number of values returned by the task.
/// \param required_resources The task's resource demands.
TaskSpecification(UniqueID driver_id, TaskID parent_task_id, int64_t parent_counter,
// UniqueID actor_id,
// UniqueID actor_handle_id,
// int64_t actor_counter,
FunctionID function_id,
TaskSpecification(const UniqueID &driver_id, const TaskID &parent_task_id,
int64_t parent_counter, const FunctionID &function_id,
const std::vector<std::shared_ptr<TaskArgument>> &arguments,
int64_t num_returns,
const std::unordered_map<std::string, double> &required_resources);
TaskSpecification(const UniqueID &driver_id, const TaskID &parent_task_id,
int64_t parent_counter, const ActorID &actor_creation_id,
const ObjectID &actor_creation_dummy_object_id,
const ActorID &actor_id, const ActorHandleID &actor_handle_id,
int64_t actor_counter, const FunctionID &function_id,
const std::vector<std::shared_ptr<TaskArgument>> &task_arguments,
int64_t num_returns,
const std::unordered_map<std::string, double> &required_resources);
~TaskSpecification() {}
/// Serialize the TaskSpecification to a flatbuffer.
@@ -129,6 +135,16 @@ class TaskSpecification {
double GetRequiredResource(const std::string &resource_name) const;
const ResourceSet GetRequiredResources() const;
// Methods specific to actor tasks.
bool IsActorCreationTask() const;
bool IsActorTask() const;
ActorID ActorCreationId() const;
ObjectID ActorCreationDummyObjectId() const;
ActorID ActorId() const;
ActorHandleID ActorHandleId() const;
int64_t ActorCounter() const;
ObjectID ActorDummyObject() const;
private:
/// Assign the specification data from a pointer.
void AssignSpecification(const uint8_t *spec, size_t spec_size);
+13 -1
View File
@@ -12,7 +12,10 @@ namespace raylet {
/// A constructor responsible for initializing the state of a worker.
Worker::Worker(pid_t pid, std::shared_ptr<LocalClientConnection> connection)
: pid_(pid), connection_(connection), assigned_task_id_(TaskID::nil()) {}
: pid_(pid),
connection_(connection),
assigned_task_id_(TaskID::nil()),
actor_id_(ActorID::nil()) {}
pid_t Worker::Pid() const { return pid_; }
@@ -20,6 +23,15 @@ void Worker::AssignTaskId(const TaskID &task_id) { assigned_task_id_ = task_id;
const TaskID &Worker::GetAssignedTaskId() const { return assigned_task_id_; }
void Worker::AssignActorId(const ActorID &actor_id) {
RAY_CHECK(actor_id_.is_nil())
<< "A worker that is already an actor cannot be assigned an actor ID again.";
RAY_CHECK(!actor_id.is_nil());
actor_id_ = actor_id;
}
const ActorID &Worker::GetActorId() const { return actor_id_; }
const std::shared_ptr<LocalClientConnection> Worker::Connection() const {
return connection_;
}
+5
View File
@@ -23,6 +23,8 @@ class Worker {
pid_t Pid() const;
void AssignTaskId(const TaskID &task_id);
const TaskID &GetAssignedTaskId() const;
void AssignActorId(const ActorID &actor_id);
const ActorID &GetActorId() const;
/// Return the worker's connection.
const std::shared_ptr<LocalClientConnection> Connection() const;
@@ -31,7 +33,10 @@ class Worker {
pid_t pid_;
/// Connection state of a worker.
std::shared_ptr<LocalClientConnection> connection_;
/// The worker's currently assigned task.
TaskID assigned_task_id_;
/// The worker's actor ID. If this is nil, then the worker is not an actor.
ActorID actor_id_;
};
} // namespace raylet
+21 -10
View File
@@ -51,14 +51,12 @@ void WorkerPool::StartWorker() {
RAY_LOG(FATAL) << "Failed to start worker with return value " << rv;
}
uint32_t WorkerPool::PoolSize() const { return pool_.size(); }
void WorkerPool::RegisterWorker(std::shared_ptr<Worker> worker) {
RAY_LOG(DEBUG) << "Registering worker with pid " << worker->Pid();
registered_workers_.push_back(worker);
}
const std::shared_ptr<Worker> WorkerPool::GetRegisteredWorker(
std::shared_ptr<Worker> WorkerPool::GetRegisteredWorker(
std::shared_ptr<LocalClientConnection> connection) const {
for (auto it = registered_workers_.begin(); it != registered_workers_.end(); it++) {
if ((*it)->Connection() == connection) {
@@ -70,17 +68,30 @@ const std::shared_ptr<Worker> WorkerPool::GetRegisteredWorker(
void WorkerPool::PushWorker(std::shared_ptr<Worker> worker) {
// Since the worker is now idle, unset its assigned task ID.
worker->AssignTaskId(TaskID::nil());
RAY_CHECK(worker->GetAssignedTaskId().is_nil())
<< "Idle workers cannot have an assigned task ID";
// Add the worker to the idle pool.
pool_.push_back(std::move(worker));
if (worker->GetActorId().is_nil()) {
pool_.push_back(std::move(worker));
} else {
actor_pool_[worker->GetActorId()] = std::move(worker);
}
}
std::shared_ptr<Worker> WorkerPool::PopWorker() {
if (pool_.empty()) {
return nullptr;
std::shared_ptr<Worker> WorkerPool::PopWorker(const ActorID &actor_id) {
std::shared_ptr<Worker> worker = nullptr;
if (actor_id.is_nil()) {
if (!pool_.empty()) {
worker = std::move(pool_.back());
pool_.pop_back();
}
} else {
auto actor_entry = actor_pool_.find(actor_id);
if (actor_entry != actor_pool_.end()) {
worker = std::move(actor_entry->second);
actor_pool_.erase(actor_entry);
}
}
std::shared_ptr<Worker> worker = std::move(pool_.back());
pool_.pop_back();
return worker;
}
+9 -10
View File
@@ -3,6 +3,7 @@
#include <inttypes.h>
#include <list>
#include <unordered_map>
#include "ray/common/client_connection.h"
#include "ray/raylet/worker.h"
@@ -30,11 +31,6 @@ class WorkerPool {
/// Destructor responsible for freeing a set of workers owned by this class.
~WorkerPool();
/// Get the number of idle workers in the pool.
///
/// \return The number of idle workers.
uint32_t PoolSize() const;
/// Asynchronously start a new worker process. Once the worker process has
/// registered with an external server, the process should create and
/// register a new Worker, then add itself to the pool. Failure to start
@@ -52,7 +48,7 @@ class WorkerPool {
/// \param The client connection owned by a registered worker.
/// \return The Worker that owns the given client connection. Returns nullptr
/// if the client has not registered a worker yet.
const std::shared_ptr<Worker> GetRegisteredWorker(
std::shared_ptr<Worker> GetRegisteredWorker(
std::shared_ptr<LocalClientConnection> connection) const;
/// Disconnect a registered worker.
@@ -61,8 +57,7 @@ class WorkerPool {
/// \return Whether the given worker was in the pool of idle workers.
bool DisconnectWorker(std::shared_ptr<Worker> worker);
/// Add an idle worker to the pool. The worker's task assignment will be
/// reset.
/// Add an idle worker to the pool.
///
/// \param The idle worker to add.
void PushWorker(std::shared_ptr<Worker> worker);
@@ -70,13 +65,17 @@ class WorkerPool {
/// Pop an idle worker from the pool. The caller is responsible for pushing
/// the worker back onto the pool once the worker has completed its work.
///
/// \return An idle worker. Returns nullptr if the pool is empty.
std::shared_ptr<Worker> PopWorker();
/// \param actor_id The returned worker must have this actor ID.
/// \return An idle worker with the requested actor ID. Returns nullptr if no
/// such worker exists.
std::shared_ptr<Worker> PopWorker(const ActorID &actor_id);
private:
std::vector<std::string> worker_command_;
/// The pool of idle workers.
std::list<std::shared_ptr<Worker>> pool_;
/// The pool of idle actor workers.
std::unordered_map<ActorID, std::shared_ptr<Worker>, UniqueIDHasher> actor_pool_;
/// All workers that have registered and are still connected, including both
/// idle and executing.
// TODO(swang): Make this a map to make GetRegisteredWorker faster.
+25 -4
View File
@@ -50,7 +50,7 @@ TEST_F(WorkerPoolTest, HandleWorkerRegistration) {
TEST_F(WorkerPoolTest, HandleWorkerPushPop) {
// Try to pop a worker from the empty pool and make sure we don't get one.
std::shared_ptr<Worker> popped_worker;
popped_worker = worker_pool_.PopWorker();
popped_worker = worker_pool_.PopWorker(ActorID::nil());
ASSERT_EQ(popped_worker, nullptr);
// Create some workers.
@@ -61,15 +61,36 @@ TEST_F(WorkerPoolTest, HandleWorkerPushPop) {
for (auto &worker : workers) {
worker_pool_.PushWorker(worker);
}
ASSERT_EQ(worker_pool_.PoolSize(), workers.size());
// Pop two workers and make sure they're one of the workers we created.
popped_worker = worker_pool_.PopWorker();
popped_worker = worker_pool_.PopWorker(ActorID::nil());
ASSERT_NE(popped_worker, nullptr);
ASSERT_TRUE(workers.count(popped_worker) > 0);
popped_worker = worker_pool_.PopWorker();
popped_worker = worker_pool_.PopWorker(ActorID::nil());
ASSERT_NE(popped_worker, nullptr);
ASSERT_TRUE(workers.count(popped_worker) > 0);
popped_worker = worker_pool_.PopWorker(ActorID::nil());
ASSERT_EQ(popped_worker, nullptr);
}
TEST_F(WorkerPoolTest, PopActorWorker) {
// Create a worker.
auto worker = CreateWorker(1234);
// Add the worker to the pool.
worker_pool_.PushWorker(worker);
// Assign an actor ID to the worker.
auto actor = worker_pool_.PopWorker(ActorID::nil());
auto actor_id = ActorID::from_random();
actor->AssignActorId(actor_id);
worker_pool_.PushWorker(actor);
// Check that there are no more non-actor workers.
ASSERT_EQ(worker_pool_.PopWorker(ActorID::nil()), nullptr);
// Check that we can pop the actor worker.
actor = worker_pool_.PopWorker(actor_id);
ASSERT_EQ(actor, worker);
ASSERT_EQ(actor->GetActorId(), actor_id);
}
} // namespace raylet
+15
View File
@@ -47,3 +47,18 @@ def test_basic_task_api(ray_start):
# Test arguments passed by ID.
# Test keyword arguments.
def test_actor_api(ray_start):
@ray.remote
class Foo(object):
def __init__(self, val):
self.x = val
def get(self):
return self.x
x = 1
f = Foo.remote(x)
assert (ray.get(f.get.remote()) == x)