From bf194db4bcd8bbcec2217727b22ec6b8ac2b245b Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Fri, 6 Apr 2018 00:17:14 -0700 Subject: [PATCH] [xray] Basic actor support (#1835) --- python/ray/actor.py | 2 + src/common/redis_module/ray_redis_module.cc | 11 +- .../local_scheduler_algorithm.cc | 6 + src/ray/CMakeLists.txt | 2 +- src/ray/gcs/client.cc | 3 + src/ray/gcs/client.h | 3 +- src/ray/gcs/format/gcs.fbs | 13 +- src/ray/gcs/tables.cc | 1 + src/ray/gcs/tables.h | 12 +- src/ray/raylet/actor.cc | 15 - src/ray/raylet/actor.h | 31 -- src/ray/raylet/actor_registration.cc | 41 +++ src/ray/raylet/actor_registration.h | 96 ++++++ src/ray/raylet/node_manager.cc | 317 +++++++++++++++--- src/ray/raylet/node_manager.h | 24 +- src/ray/raylet/raylet.cc | 7 +- src/ray/raylet/scheduling_queue.cc | 17 +- src/ray/raylet/scheduling_queue.h | 23 +- src/ray/raylet/task_dependency_manager.cc | 2 +- src/ray/raylet/task_spec.cc | 68 +++- src/ray/raylet/task_spec.h | 26 +- src/ray/raylet/worker.cc | 14 +- src/ray/raylet/worker.h | 5 + src/ray/raylet/worker_pool.cc | 31 +- src/ray/raylet/worker_pool.h | 19 +- src/ray/raylet/worker_pool_test.cc | 29 +- test/xray_test.py | 15 + 27 files changed, 652 insertions(+), 181 deletions(-) delete mode 100644 src/ray/raylet/actor.cc delete mode 100644 src/ray/raylet/actor.h create mode 100644 src/ray/raylet/actor_registration.cc create mode 100644 src/ray/raylet/actor_registration.h diff --git a/python/ray/actor.py b/python/ray/actor.py index 04d78e926..dd843cffa 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -802,6 +802,8 @@ def actor_handle_from_class(Class, class_id, actor_creation_resources, actor_creation_resources, actor_method_cpus, ray.worker.global_worker) + # Increment the actor counter to account for the creation task. + actor_counter += 1 # Instantiate the actor handle. actor_object = cls.__new__(cls) diff --git a/src/common/redis_module/ray_redis_module.cc b/src/common/redis_module/ray_redis_module.cc index 984ea30bd..14d241c24 100644 --- a/src/common/redis_module/ray_redis_module.cc +++ b/src/common/redis_module/ray_redis_module.cc @@ -49,14 +49,11 @@ return RedisModule_ReplyWithError(ctx, (MESSAGE)); \ } +// NOTE(swang): The order of prefixes here must match the TablePrefix enum +// defined in src/ray/gcs/format/gcs.fbs. static const char *table_prefixes[] = { - NULL, - "TASK:", - "TASK:", - "CLIENT:", - "OBJECT:", - "FUNCTION:", - "TASK_RECONSTRUCTION:", + NULL, "TASK:", "TASK:", "CLIENT:", + "OBJECT:", "ACTOR:", "FUNCTION:", "TASK_RECONSTRUCTION:", "HEARTBEAT:", }; diff --git a/src/local_scheduler/local_scheduler_algorithm.cc b/src/local_scheduler/local_scheduler_algorithm.cc index f809390e7..d13d26a16 100644 --- a/src/local_scheduler/local_scheduler_algorithm.cc +++ b/src/local_scheduler/local_scheduler_algorithm.cc @@ -358,6 +358,12 @@ void handle_convert_worker_to_actor( * filled out, so fill out the correct worker field now. */ algorithm_state->local_actor_infos[actor_id].worker = worker; } + /* Increment the task counter for the creator's handle to account for the + * actor creation task. */ + auto &task_counters = + algorithm_state->local_actor_infos[actor_id].task_counters; + RAY_CHECK(task_counters[ActorHandleID::nil()] == 0); + task_counters[ActorHandleID::nil()]++; } /** diff --git a/src/ray/CMakeLists.txt b/src/ray/CMakeLists.txt index a0a34f61e..e19e9d9bb 100644 --- a/src/ray/CMakeLists.txt +++ b/src/ray/CMakeLists.txt @@ -52,7 +52,7 @@ set(RAY_SRCS raylet/worker.cc raylet/worker_pool.cc raylet/scheduling_resources.cc - raylet/actor.cc + raylet/actor_registration.cc raylet/scheduling_queue.cc raylet/scheduling_policy.cc raylet/task_dependency_manager.cc diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index d100a2ed7..f29c1f6ff 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -10,6 +10,7 @@ AsyncGcsClient::AsyncGcsClient(const ClientID &client_id) { context_.reset(new RedisContext()); client_table_.reset(new ClientTable(context_, this, client_id)); object_table_.reset(new ObjectTable(context_, this)); + actor_table_.reset(new ActorTable(context_, this)); task_table_.reset(new TaskTable(context_, this)); raylet_task_table_.reset(new raylet::TaskTable(context_, this)); task_reconstruction_log_.reset(new TaskReconstructionLog(context_, this)); @@ -48,6 +49,8 @@ TaskTable &AsyncGcsClient::task_table() { return *task_table_; } raylet::TaskTable &AsyncGcsClient::raylet_task_table() { return *raylet_task_table_; } +ActorTable &AsyncGcsClient::actor_table() { return *actor_table_; } + TaskReconstructionLog &AsyncGcsClient::task_reconstruction_log() { return *task_reconstruction_log_; } diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index 99d8e6a65..bfe75ebba 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -45,12 +45,12 @@ class RAY_EXPORT AsyncGcsClient { inline FunctionTable &function_table(); // TODO: Some API for getting the error on the driver inline ClassTable &class_table(); - inline ActorTable &actor_table(); inline CustomSerializerTable &custom_serializer_table(); inline ConfigTable &config_table(); ObjectTable &object_table(); TaskTable &task_table(); raylet::TaskTable &raylet_task_table(); + ActorTable &actor_table(); TaskReconstructionLog &task_reconstruction_log(); ClientTable &client_table(); HeartbeatTable &heartbeat_table(); @@ -72,6 +72,7 @@ class RAY_EXPORT AsyncGcsClient { std::unique_ptr object_table_; std::unique_ptr task_table_; std::unique_ptr raylet_task_table_; + std::unique_ptr actor_table_; std::unique_ptr task_reconstruction_log_; std::unique_ptr heartbeat_table_; std::unique_ptr client_table_; diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 085eb0d55..63de1a607 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -10,9 +10,10 @@ enum TablePrefix:int { RAYLET_TASK, CLIENT, OBJECT, + ACTOR, FUNCTION, TASK_RECONSTRUCTION, - HEARTBEAT + HEARTBEAT, } // The channel that Add operations to the Table should be published on, if any. @@ -89,6 +90,16 @@ table ClassTableData { } table ActorTableData { + // The ID of the actor that was created. + actor_id: string; + // The dummy object ID returned by the actor creation task. If the actor + // dies, then this is the object that should be reconstructed for the actor + // to be recreated. + actor_creation_dummy_object_id: string; + // The ID of the driver that created the actor. + driver_id: string; + // The ID of the node manager that created the actor. + node_manager_id: string; } table ErrorTableData { diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index 293608ac6..c8ef17d46 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -344,6 +344,7 @@ template class Log; template class Log; template class Table; template class Table; +template class Log; template class Log; template class Table; diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index e2053b2c9..0afa70853 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -277,14 +277,22 @@ class FunctionTable : public Table { using ClassTable = Table; // TODO(swang): Set the pubsub channel for the actor table. -using ActorTable = Table; +class ActorTable : public Log { + public: + ActorTable(const std::shared_ptr &context, AsyncGcsClient *client) + : Log(context, client) { + pubsub_channel_ = TablePubsub_ACTOR; + prefix_ = TablePrefix_TASK_RECONSTRUCTION; + } +}; class TaskReconstructionLog : public Log { public: TaskReconstructionLog(const std::shared_ptr &context, AsyncGcsClient *client) : Log(context, client) { - prefix_ = TablePrefix_TASK_RECONSTRUCTION; + pubsub_channel_ = TablePubsub_ACTOR; + prefix_ = TablePrefix_ACTOR; } }; diff --git a/src/ray/raylet/actor.cc b/src/ray/raylet/actor.cc deleted file mode 100644 index 0dd2487c7..000000000 --- a/src/ray/raylet/actor.cc +++ /dev/null @@ -1,15 +0,0 @@ -#include "actor.h" - -namespace ray { - -namespace raylet { - -ActorInformation::ActorInformation() : id_(UniqueID::nil()) {} - -ActorInformation::~ActorInformation() {} - -const ActorID &ActorInformation::GetActorId() const { return this->id_; } - -} // namespace raylet - -} // namespace ray diff --git a/src/ray/raylet/actor.h b/src/ray/raylet/actor.h deleted file mode 100644 index 25f9e2dde..000000000 --- a/src/ray/raylet/actor.h +++ /dev/null @@ -1,31 +0,0 @@ -#ifndef RAY_RAYLET_ACTOR_H -#define RAY_RAYLET_ACTOR_H - -#include "ray/id.h" - -namespace ray { - -namespace raylet { - -class ActorInformation { - public: - /// \brief ActorInformation constructor. - ActorInformation(); - - /// \brief ActorInformation destructor. - ~ActorInformation(); - - /// \brief Return the id of this actor. - /// \return actor id. - const ActorID &GetActorId() const; - - private: - /// Unique identifier for this actor. - ActorID id_; -}; // class ActorInformation - -} // namespace raylet - -} // namespace ray - -#endif // RAY_RAYLET_ACTOR_H diff --git a/src/ray/raylet/actor_registration.cc b/src/ray/raylet/actor_registration.cc new file mode 100644 index 000000000..c1e6303fb --- /dev/null +++ b/src/ray/raylet/actor_registration.cc @@ -0,0 +1,41 @@ +#include "ray/raylet/actor_registration.h" + +#include "ray/util/logging.h" + +namespace ray { + +namespace raylet { + +ActorRegistration::ActorRegistration(const ActorTableDataT &actor_table_data) + : actor_table_data_(actor_table_data), + execution_dependency_(ObjectID::nil()), + frontier_() {} + +const ClientID ActorRegistration::GetNodeManagerId() const { + return ClientID::from_binary(actor_table_data_.node_manager_id); +} + +const ObjectID ActorRegistration::GetActorCreationDependency() const { + return ObjectID::from_binary(actor_table_data_.actor_creation_dummy_object_id); +} + +const ObjectID ActorRegistration::GetExecutionDependency() const { + return execution_dependency_; +} + +const std::unordered_map + &ActorRegistration::GetFrontier() const { + return frontier_; +} + +void ActorRegistration::ExtendFrontier(const ActorHandleID &handle_id, + const ObjectID &execution_dependency) { + auto &frontier_entry = frontier_[handle_id]; + frontier_entry.task_counter++; + frontier_entry.execution_dependency = execution_dependency; + execution_dependency_ = execution_dependency; +} + +} // namespace raylet + +} // namespace ray diff --git a/src/ray/raylet/actor_registration.h b/src/ray/raylet/actor_registration.h new file mode 100644 index 000000000..486be2719 --- /dev/null +++ b/src/ray/raylet/actor_registration.h @@ -0,0 +1,96 @@ +#ifndef RAY_RAYLET_ACTOR_REGISTRATION_H +#define RAY_RAYLET_ACTOR_REGISTRATION_H + +#include + +#include "ray/gcs/format/gcs_generated.h" +#include "ray/id.h" + +namespace ray { + +namespace raylet { + +/// \class ActorRegistration +/// +/// Information about an actor registered in the system. This includes the +/// actor's current node manager location, and if local, information about its +/// current execution state, used for reconstruction purposes. +class ActorRegistration { + public: + /// Create an actor registration. + /// + /// \param actor_table_data Information from the global actor table about + /// this actor. This includes the actor's node manager location. + ActorRegistration(const ActorTableDataT &actor_table_data); + + /// Each actor may have multiple callers, or "handles". A frontier leaf + /// represents the execution state of the actor with respect to a single + /// handle. + struct FrontierLeaf { + /// The number of tasks submitted by this handle that have executed on the + /// actor so far. + int64_t task_counter; + /// The execution dependency returned by the task submitted by this handle + /// that most recently executed on the actor. + ObjectID execution_dependency; + }; + + /// Get the actor's node manager location. + /// + /// \return The actor's node manager location. All tasks for the actor should + /// be forwarded to this node. + const ClientID GetNodeManagerId() const; + + /// Get the object that represents the actor's initial state. This is the + /// execution dependency returned by this actor's creation task. If + /// reconstructed, this will recreate the actor. + /// + /// \return The execution dependency returned by the actor's creation task. + const ObjectID GetActorCreationDependency() const; + + /// Get the object that represents the actor's current state. This is the + /// execution dependency returned by the task most recently executed on the + /// actor. The next task to execute on the actor should be marked as + /// execution-dependent on this object. + /// + /// \return The execution dependency returned by the most recently executed + /// task. + const ObjectID GetExecutionDependency() const; + + /// Get the execution frontier of the actor, indexed by handle. This captures + /// the execution state of the actor, a summary of which tasks have executed + /// so far. + /// + /// \return The actor frontier, a map from handle ID to execution state for + /// that handle. + const std::unordered_map &GetFrontier() + const; + + /// Extend the frontier of the actor by a single task. This should be called + /// whenever the actor executes a task. + /// + /// \param handle_id The ID of the handle that submitted the task. + /// \param execution_dependency The object representing the actor's new + /// state. This is the execution dependency returned by the task. + void ExtendFrontier(const ActorHandleID &handle_id, + const ObjectID &execution_dependency); + + private: + /// Information from the global actor table about this actor, including the + /// node manager location. + ActorTableDataT actor_table_data_; + /// The object representing the state following the actor's most recently + /// executed task. The next task to execute on the actor should be marked as + /// execution-dependent on this object. + ObjectID execution_dependency_; + /// The execution frontier of the actor, which represents which tasks have + /// executed so far and which tasks may execute next, based on execution + /// dependencies. This is indexed by handle. + std::unordered_map frontier_; +}; + +} // namespace raylet + +} // namespace ray + +#endif // RAY_RAYLET_ACTOR_REGISTRATION_H diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 5ed6bc150..2d388b5ba 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -3,6 +3,37 @@ #include "common_protocol.h" #include "ray/raylet/format/node_manager_generated.h" +namespace { + +/// A helper function to determine whether a given actor task has already been executed +/// according to the given actor registry. Returns true if the task is a duplicate. +bool CheckDuplicateActorTask( + const std::unordered_map + &actor_registry, + const ray::raylet::TaskSpecification &spec) { + auto actor_entry = actor_registry.find(spec.ActorId()); + RAY_CHECK(actor_entry != actor_registry.end()); + const auto &frontier = actor_entry->second.GetFrontier(); + int64_t expected_task_counter = 0; + auto frontier_entry = frontier.find(spec.ActorHandleId()); + if (frontier_entry != frontier.end()) { + expected_task_counter = frontier_entry->second.task_counter; + } + if (spec.ActorCounter() < expected_task_counter) { + // The assigned task counter is less than expected. The actor has already + // executed past this task, so do not assign the task again. + RAY_LOG(WARNING) << "A task was resubmitted, so we are ignoring it. This " + << "should only happen during reconstruction."; + return true; + } + RAY_CHECK(spec.ActorCounter() == expected_task_counter) + << "Expected actor counter: " << expected_task_counter + << ", got: " << spec.ActorCounter(); + return false; +}; + +} // namespace + namespace ray { namespace raylet { @@ -26,7 +57,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, gcs_client_(gcs_client), remote_clients_(), remote_server_connections_(), - object_manager_(object_manager) { + object_manager_(object_manager), + actor_registry_() { RAY_CHECK(heartbeat_period_ms_ > 0); // Initialize the resource map with own cluster resource configuration. ClientID local_client_id = gcs_client_->client_table().GetLocalClientId(); @@ -34,6 +66,39 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, SchedulingResources(config.resource_config)); } +ray::Status NodeManager::RegisterGcs() { + // Register a callback for actor creation notifications. + auto actor_creation_callback = [this]( + gcs::AsyncGcsClient *client, const ActorID &actor_id, + const std::vector &data) { HandleActorCreation(actor_id, data); }; + + RAY_RETURN_NOT_OK(gcs_client_->actor_table().Subscribe( + UniqueID::nil(), UniqueID::nil(), actor_creation_callback, nullptr)); + + // Register a callback on the client table for new clients. + auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id, + const ClientTableDataT &data) { + ClientAdded(data); + }; + gcs_client_->client_table().RegisterClientAddedCallback(node_manager_client_added); + + // Subscribe to node manager heartbeats. + const auto heartbeat_added = [this](gcs::AsyncGcsClient *client, const ClientID &id, + const HeartbeatTableDataT &heartbeat_data) { + HeartbeatAdded(client, id, heartbeat_data); + }; + RAY_RETURN_NOT_OK(gcs_client_->heartbeat_table().Subscribe( + UniqueID::nil(), UniqueID::nil(), heartbeat_added, + [this](gcs::AsyncGcsClient *client) { + RAY_LOG(DEBUG) << "heartbeat table subscription done callback called."; + })); + + // Start sending heartbeats to the GCS. + Heartbeat(); + + return ray::Status::OK(); +} + void NodeManager::Heartbeat() { RAY_LOG(DEBUG) << "[Heartbeat] sending heartbeat."; auto &heartbeat_table = gcs_client_->heartbeat_table(); @@ -75,27 +140,13 @@ void NodeManager::Heartbeat() { }); } -void NodeManager::ClientAdded(gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &client_data) { +void NodeManager::ClientAdded(const ClientTableDataT &client_data) { ClientID client_id = ClientID::from_binary(client_data.client_id); RAY_LOG(DEBUG) << "[ClientAdded] received callback from client id " << client_id.hex(); if (client_id == gcs_client_->client_table().GetLocalClientId()) { // We got a notification for ourselves, so we are connected to the GCS now. // Save this NodeManager's resource information in the cluster resource map. cluster_resource_map_[client_id] = local_resources_; - // Start sending heartbeats to the GCS. - Heartbeat(); - // Subscribe to heartbeats. - const auto heartbeat_added = [this](gcs::AsyncGcsClient *client, const ClientID &id, - const HeartbeatTableDataT &heartbeat_data) { - this->HeartbeatAdded(client, id, heartbeat_data); - }; - ray::Status status = client->heartbeat_table().Subscribe( - UniqueID::nil(), UniqueID::nil(), heartbeat_added, - [](gcs::AsyncGcsClient *client) { - RAY_LOG(DEBUG) << "heartbeat table subscription done callback called."; - }); - RAY_CHECK_OK(status); return; } @@ -154,6 +205,46 @@ void NodeManager::HeartbeatAdded(gcs::AsyncGcsClient *client, const ClientID &cl heartbeat_resource_available); } +void NodeManager::HandleActorCreation(const ActorID &actor_id, + const std::vector &data) { + RAY_LOG(DEBUG) << "Actor creation notification received: " << actor_id; + + // TODO(swang): In presence of failures, data may have size > 1, since the + // actor will have been created multiple times. In that case, we should + // only consider the last entry as valid. All previous entries should have + // a dead node_manager_id. + RAY_CHECK(data.size() == 1); + + // Register the new actor. + ActorRegistration actor_registration(data.back()); + // Extend the frontier to include the actor creation task. NOTE(swang): The + // creator of the actor is always assigned nil as the actor handle ID. + actor_registration.ExtendFrontier(ActorHandleID::nil(), + actor_registration.GetActorCreationDependency()); + auto inserted = actor_registry_.emplace(actor_id, std::move(actor_registration)); + RAY_CHECK(inserted.second); + + // Dequeue any methods that were submitted before the actor's location was + // known. + const auto &methods = local_queues_.GetUncreatedActorMethods(); + std::unordered_set created_actor_method_ids; + for (const auto &method : methods) { + if (method.GetTaskSpecification().ActorId() == actor_id) { + created_actor_method_ids.insert(method.GetTaskSpecification().TaskId()); + } + } + // Resubmit the methods that were submitted before the actor's location was + // known. + auto created_actor_methods = local_queues_.RemoveTasks(created_actor_method_ids); + for (const auto &method : created_actor_methods) { + lineage_cache_.RemoveWaitingTask(method.GetTaskSpecification().TaskId()); + // The task's uncommitted lineage was already added to the local lineage + // cache upon the initial submission, so it's okay to resubmit it with an + // empty lineage this time. + SubmitTask(method, Lineage()); + } +} + void NodeManager::ProcessNewClient(std::shared_ptr client) { // The new client is a worker, so begin listening for messages. client->ProcessMessages(); @@ -175,31 +266,39 @@ void NodeManager::ProcessClientMessage(std::shared_ptr cl } } break; case protocol::MessageType_GetTask: { - const std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); RAY_CHECK(worker); // If the worker was assigned a task, mark it as finished. if (!worker->GetAssignedTaskId().is_nil()) { - FinishTask(worker->GetAssignedTaskId()); + FinishAssignedTask(worker); } // Return the worker to the idle pool. worker_pool_.PushWorker(worker); + // Check if there is a scheduled task that can now be assigned to the newly + // idle worker. auto scheduled_tasks = local_queues_.GetScheduledTasks(); if (!scheduled_tasks.empty()) { - const TaskID &scheduled_task_id = - scheduled_tasks.front().GetTaskSpecification().TaskId(); - auto scheduled_tasks = local_queues_.RemoveTasks({scheduled_task_id}); - AssignTask(scheduled_tasks.front()); + // Find a scheduled task that whose actor ID matches that of the newly + // idle worker. + auto worker_actor_id = worker->GetActorId(); + for (const auto &task : scheduled_tasks) { + if (task.GetTaskSpecification().ActorId() == worker_actor_id) { + auto scheduled_tasks = + local_queues_.RemoveTasks({task.GetTaskSpecification().TaskId()}); + AssignTask(scheduled_tasks.front()); + } + } } } break; case protocol::MessageType_DisconnectClient: { // Remove the dead worker from the pool and stop listening for messages. const std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); if (worker) { - if (!worker->GetAssignedTaskId().is_nil()) { - // TODO(swang): Clean up any tasks that were assigned to the worker. - // Release any resources that may be held by this worker. - FinishTask(worker->GetAssignedTaskId()); - } + // TODO(swang): Handle the case where the worker is killed while + // executing a task. Clean up the assigned task's resources, return an + // error to the driver. + // RAY_CHECK(worker->GetAssignedTaskId().is_nil()) + // << "Worker died while executing task: " << worker->GetAssignedTaskId(); worker_pool_.DisconnectWorker(worker); } return; @@ -300,9 +399,57 @@ void NodeManager::ScheduleTasks() { } void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineage) { + const TaskSpecification &spec = task.GetTaskSpecification(); + // Add the task and its uncommitted lineage to the lineage cache. lineage_cache_.AddWaitingTask(task, uncommitted_lineage); - // Queue the task according to the availability of its arguments. + + if (spec.IsActorTask()) { + // Check whether we know the location of the actor. + const auto actor_entry = actor_registry_.find(spec.ActorId()); + if (actor_entry != actor_registry_.end()) { + // We have a known location for the actor. + auto node_manager_id = actor_entry->second.GetNodeManagerId(); + if (node_manager_id == gcs_client_->client_table().GetLocalClientId()) { + // The actor is local. Queue the task for local execution. + QueueTask(task); + } else { + // The actor is remote. Forward the task to the node manager that owns + // the actor. + // TODO(swang): Handle forward task failure. + RAY_CHECK_OK(ForwardTask(task, node_manager_id)); + } + } else { + // We do not have a registered location for the object, so either the + // actor has not yet been created or we missed the notification for the + // actor creation because this node joined the cluster after the actor + // was already created. Look up the actor's registered location in case + // we missed the creation notification. + // NOTE(swang): This codepath needs to be tested in a cluster setting. + auto lookup_callback = [this](gcs::AsyncGcsClient *client, const ActorID &actor_id, + const std::vector &data) { + if (!data.empty()) { + // The actor has been created. + HandleActorCreation(actor_id, data); + } else { + // The actor has not yet been created. + // TODO(swang): Set a timer for reconstructing the actor creation + // task. + } + }; + RAY_CHECK_OK(gcs_client_->actor_table().Lookup(JobID::nil(), spec.ActorId(), + lookup_callback)); + // Keep the task queued until we discover the actor's location. + local_queues_.QueueUncreatedActorMethods({task}); + } + } else { + // This is a non-actor task. Queue the task for local execution. + QueueTask(task); + } +} + +void NodeManager::QueueTask(const Task &task) { + // Queue the task depending on the availability of its arguments. if (task_dependency_manager_.TaskReady(task)) { local_queues_.QueueReadyTasks(std::vector({task})); ScheduleTasks(); @@ -312,27 +459,38 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag } } -void NodeManager::AssignTask(const Task &task) { +void NodeManager::AssignTask(Task &task) { + const TaskSpecification &spec = task.GetTaskSpecification(); + + // If this is an actor task, check that the new task has the correct counter. + if (spec.IsActorTask()) { + if (CheckDuplicateActorTask(actor_registry_, spec)) { + // Drop tasks that have already been executed. + return; + } + } + // Resource accounting: acquire resources for the scheduled task. const ClientID &my_client_id = gcs_client_->client_table().GetLocalClientId(); - RAY_CHECK(this->cluster_resource_map_[my_client_id].Acquire( - task.GetTaskSpecification().GetRequiredResources())); + RAY_CHECK( + this->cluster_resource_map_[my_client_id].Acquire(spec.GetRequiredResources())); - if (worker_pool_.PoolSize() == 0) { - worker_pool_.StartWorker(); + // Try to get an idle worker that can execute this task. + std::shared_ptr worker = worker_pool_.PopWorker(spec.ActorId()); + if (worker == nullptr) { + // There are no workers that can execute this task. + if (!spec.IsActorTask()) { + // There are no more non-actor workers available to execute this task. + // Start a new worker. + worker_pool_.StartWorker(); + } // Queue this task for future assignment. The task will be assigned to a // worker once one becomes available. local_queues_.QueueScheduledTasks(std::vector({task})); return; } - const TaskSpecification &spec = task.GetTaskSpecification(); - std::shared_ptr worker = worker_pool_.PopWorker(); RAY_LOG(DEBUG) << "Assigning task to worker with pid " << worker->Pid(); - - worker->AssignTaskId(spec.TaskId()); - local_queues_.QueueRunningTasks(std::vector({task})); - flatbuffers::FlatBufferBuilder fbb; auto message = protocol::CreateGetTaskReply(fbb, spec.ToFlatbuffer(fbb), fbb.CreateVector(std::vector())); @@ -340,33 +498,94 @@ void NodeManager::AssignTask(const Task &task) { auto status = worker->Connection()->WriteMessage(protocol::MessageType_ExecuteTask, fbb.GetSize(), fbb.GetBufferPointer()); if (status.ok()) { + // We successfully assigned the task to the worker. + worker->AssignTaskId(spec.TaskId()); + // If the task was an actor task, then record this execution to guarantee + // consistency in the case of reconstruction. + if (spec.IsActorTask()) { + // Extend the frontier to include the executing task. + auto actor_entry = actor_registry_.find(spec.ActorId()); + RAY_CHECK(actor_entry != actor_registry_.end()); + actor_entry->second.ExtendFrontier(spec.ActorHandleId(), spec.ActorDummyObject()); + // Update the task's execution dependencies to reflect the actual + // execution order, to support deterministic reconstruction. + // NOTE(swang): The update of an actor task's execution dependencies is + // performed asynchronously. This means that if this node manager dies, + // we may lose updates that are in flight to the task table. We only + // guarantee deterministic reconstruction ordering for tasks whose + // updates are reflected in the task table. + TaskExecutionSpecification &mutable_spec = task.GetTaskExecutionSpec(); + mutable_spec.SetExecutionDependencies( + {actor_entry->second.GetExecutionDependency()}); + } // We started running the task, so the task is ready to write to GCS. lineage_cache_.AddReadyTask(task); + // Mark the task as running. + local_queues_.QueueRunningTasks(std::vector({task})); } else { - // We failed to send the task to the worker, so disconnect the worker. The - // task will get queued again during cleanup. + RAY_LOG(WARNING) << "Failed to send task to worker, disconnecting client"; + // We failed to send the task to the worker, so disconnect the worker. ProcessClientMessage(worker->Connection(), protocol::MessageType_DisconnectClient, NULL); + // Queue this task for future assignment. The task will be assigned to a + // worker once one becomes available. + local_queues_.QueueScheduledTasks(std::vector({task})); } } -void NodeManager::FinishTask(const TaskID &task_id) { - RAY_LOG(DEBUG) << "Finished task " << task_id.hex(); +void NodeManager::FinishAssignedTask(std::shared_ptr worker) { + TaskID task_id = worker->GetAssignedTaskId(); + RAY_LOG(DEBUG) << "Finished task " << task_id; auto tasks = local_queues_.RemoveTasks({task_id}); - RAY_CHECK(tasks.size() == 1); auto task = *tasks.begin(); - // Resource accounting: release task's resources. - RAY_CHECK( - this->cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Release( - task.GetTaskSpecification().GetRequiredResources())); + if (task.GetTaskSpecification().IsActorCreationTask()) { + // If this was an actor creation task, then convert the worker to an actor. + auto actor_id = task.GetTaskSpecification().ActorCreationId(); + worker->AssignActorId(actor_id); + + // Publish the actor creation event to all other nodes so that methods for + // the actor will be forwarded directly to this node. + auto actor_notification = std::make_shared(); + actor_notification->actor_id = actor_id.binary(); + actor_notification->actor_creation_dummy_object_id = + task.GetTaskSpecification().ActorCreationDummyObjectId().binary(); + // TODO(swang): The driver ID. + actor_notification->driver_id = JobID::nil().binary(); + actor_notification->node_manager_id = + gcs_client_->client_table().GetLocalClientId().binary(); + RAY_LOG(DEBUG) << "Publishing actor creation: " << actor_id; + RAY_CHECK_OK(gcs_client_->actor_table().Append(JobID::nil(), actor_id, + actor_notification, nullptr)); + + // Resources required by an actor creation task are acquired for the + // lifetime of the actor, so we do not release any resources here. + } else { + // Release task's resources. + RAY_CHECK(this->cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()] + .Release(task.GetTaskSpecification().GetRequiredResources())); + } + + // If the finished task was an actor task, mark the returned dummy object as + // locally available. This is not added to the object table, so the update + // will be invisible to both the local object manager and the other nodes. + // NOTE(swang): These objects are never cleaned up. We should consider + // removing the objects, e.g., when an actor is terminated. + if (task.GetTaskSpecification().IsActorCreationTask() || + task.GetTaskSpecification().IsActorTask()) { + auto dummy_object = task.GetTaskSpecification().ActorDummyObject(); + task_dependency_manager_.MarkDependencyReady(dummy_object); + } + + // Unset the worker's assigned task. + worker->AssignTaskId(TaskID::nil()); } void NodeManager::ResubmitTask(const TaskID &task_id) { throw std::runtime_error("Method not implemented"); } -ray::Status NodeManager::ForwardTask(Task &task, const ClientID &node_id) { +ray::Status NodeManager::ForwardTask(const Task &task, const ClientID &node_id) { auto task_id = task.GetTaskSpecification().TaskId(); // Get and serialize the task's uncommitted lineage. diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index f69f43244..bfee64c6b 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -5,6 +5,7 @@ #include "ray/raylet/task.h" #include "ray/object_manager/object_manager.h" #include "ray/common/client_connection.h" +#include "ray/raylet/actor_registration.h" #include "ray/raylet/lineage_cache.h" #include "ray/raylet/scheduling_policy.h" #include "ray/raylet/scheduling_queue.h" @@ -53,26 +54,34 @@ class NodeManager { void ProcessNodeManagerMessage(std::shared_ptr node_manager_client, int64_t message_type, const uint8_t *message); - void ClientAdded(gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data); + ray::Status RegisterGcs(); void HeartbeatAdded(gcs::AsyncGcsClient *client, const ClientID &id, const HeartbeatTableDataT &data); private: + // Handler for the addition of a new GCS client. + void ClientAdded(const ClientTableDataT &data); + // Handler for the creation of an actor, possibly on a remote node. + void HandleActorCreation(const ActorID &actor_id, + const std::vector &data); + // Queue a task for local execution. + void QueueTask(const Task &task); /// Submit a task to this node. void SubmitTask(const Task &task, const Lineage &uncommitted_lineage); - /// Assign a task. - void AssignTask(const Task &task); - /// Finish a task. - void FinishTask(const TaskID &task_id); + /// Assign a task. The task is assumed to not be queued in local_queues_. + void AssignTask(Task &task); + /// Handle a worker finishing its assigned task. + void FinishAssignedTask(std::shared_ptr worker); /// Schedule tasks. void ScheduleTasks(); /// Handle a task whose local dependencies were missing and are now available. void HandleWaitingTaskReady(const TaskID &task_id); /// Resubmit a task whose return value needs to be reconstructed. void ResubmitTask(const TaskID &task_id); - ray::Status ForwardTask(Task &task, const ClientID &node_id); + /// Forward a task to another node to execute. The task is assumed to not be + /// queued in local_queues_. + ray::Status ForwardTask(const Task &task, const ClientID &node_id); /// Send heartbeats to the GCS. void Heartbeat(); @@ -101,6 +110,7 @@ class NodeManager { std::unordered_map remote_server_connections_; ObjectManager &object_manager_; + std::unordered_map actor_registry_; }; } // namespace raylet diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index af6c15082..75295f1ac 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -73,11 +73,8 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address, << " port " << client_info.node_manager_port; RAY_RETURN_NOT_OK(gcs_client_->client_table().Connect(client_info)); - auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { - node_manager_.ClientAdded(client, id, data); - }; - gcs_client_->client_table().RegisterClientAddedCallback(node_manager_client_added); + RAY_RETURN_NOT_OK(node_manager_.RegisterGcs()); + return Status::OK(); } diff --git a/src/ray/raylet/scheduling_queue.cc b/src/ray/raylet/scheduling_queue.cc index 63d8869bd..30919ded6 100644 --- a/src/ray/raylet/scheduling_queue.cc +++ b/src/ray/raylet/scheduling_queue.cc @@ -6,6 +6,10 @@ namespace ray { namespace raylet { +const std::list &SchedulingQueue::GetUncreatedActorMethods() const { + return this->uncreated_actor_methods_; +} + const std::list &SchedulingQueue::GetWaitingTasks() const { return this->waiting_tasks_; } @@ -56,6 +60,7 @@ std::vector SchedulingQueue::RemoveTasks( std::vector removed_tasks; // Try to find the tasks to remove from the waiting tasks. + removeTasksFromQueue(uncreated_actor_methods_, task_ids, removed_tasks); removeTasksFromQueue(waiting_tasks_, task_ids, removed_tasks); removeTasksFromQueue(ready_tasks_, task_ids, removed_tasks); removeTasksFromQueue(scheduled_tasks_, task_ids, removed_tasks); @@ -66,6 +71,10 @@ std::vector SchedulingQueue::RemoveTasks( return removed_tasks; } +void SchedulingQueue::QueueUncreatedActorMethods(const std::vector &tasks) { + queueTasks(uncreated_actor_methods_, tasks); +} + void SchedulingQueue::QueueWaitingTasks(const std::vector &tasks) { queueTasks(waiting_tasks_, tasks); } @@ -82,14 +91,6 @@ void SchedulingQueue::QueueRunningTasks(const std::vector &tasks) { queueTasks(running_tasks_, tasks); } -// RegisterActor is responsible for recording provided actor_information -// in the actor registry. -bool SchedulingQueue::RegisterActor(ActorID actor_id, - const ActorInformation &actor_information) { - actor_registry_[actor_id] = actor_information; - return true; -} - } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/scheduling_queue.h b/src/ray/raylet/scheduling_queue.h index 304fd78d8..ad47da4f7 100644 --- a/src/ray/raylet/scheduling_queue.h +++ b/src/ray/raylet/scheduling_queue.h @@ -6,7 +6,6 @@ #include #include -#include "ray/raylet/actor.h" #include "ray/raylet/task.h" namespace ray { @@ -29,6 +28,13 @@ class SchedulingQueue { /// SchedulingQueue destructor. virtual ~SchedulingQueue() {} + /// Get the queue of tasks that are destined for actors that have not yet + /// been created. + /// + /// \return A const reference to the queue of tasks that are destined for + /// actors that have not yet been created. + const std::list &GetUncreatedActorMethods() const; + /// Get the queue of tasks in the waiting state. /// /// \return A const reference to the queue of tasks that are waiting for @@ -66,6 +72,11 @@ class SchedulingQueue { /// \return A vector of the tasks that were removed. std::vector RemoveTasks(std::unordered_set tasks); + /// Queue tasks that are destined for actors that have not yet been created. + /// + /// \param tasks The tasks to queue. + void QueueUncreatedActorMethods(const std::vector &tasks); + /// Queue tasks in the waiting state. /// /// \param tasks The tasks to queue. @@ -86,13 +97,9 @@ class SchedulingQueue { /// \param tasks The tasks to queue. void QueueRunningTasks(const std::vector &tasks); - /// Register an actor. - /// - /// \param actor_id The ID of the actor to register. - /// \param actor_information Information about the actor. - bool RegisterActor(ActorID actor_id, const ActorInformation &actor_information); - private: + /// Tasks that are destined for actors that have not yet been created. + std::list uncreated_actor_methods_; /// Tasks that are waiting for an object dependency to appear locally. std::list waiting_tasks_; /// Tasks whose object dependencies are locally available, but that are @@ -102,8 +109,6 @@ class SchedulingQueue { std::list scheduled_tasks_; /// Tasks that are running on a worker. std::list running_tasks_; - /// The registry of known actors. - std::unordered_map actor_registry_; }; } // namespace raylet diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index 4ade555f3..1a8278153 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -105,7 +105,7 @@ void TaskDependencyManager::UnsubscribeTaskReady(const TaskID &task_id) { } void TaskDependencyManager::MarkDependencyReady(const ObjectID &object) { - throw std::runtime_error("Method not implemented"); + handleObjectReady(object); } } // namespace raylet diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index b11ed64c4..8488da3c4 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -38,11 +38,19 @@ TaskSpecification::TaskSpecification(const flatbuffers::String &string) { } TaskSpecification::TaskSpecification( - UniqueID driver_id, TaskID parent_task_id, int64_t parent_counter, - // UniqueID actor_id, - // UniqueID actor_handle_id, - // int64_t actor_counter, - FunctionID function_id, + const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, + const FunctionID &function_id, + const std::vector> &task_arguments, int64_t num_returns, + const std::unordered_map &required_resources) + : TaskSpecification(driver_id, parent_task_id, parent_counter, ActorID::nil(), + ObjectID::nil(), ActorID::nil(), ActorHandleID::nil(), -1, + function_id, task_arguments, num_returns, required_resources) {} + +TaskSpecification::TaskSpecification( + const UniqueID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, + const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id, + const ActorID &actor_id, const ActorHandleID &actor_handle_id, int64_t actor_counter, + const FunctionID &function_id, const std::vector> &task_arguments, int64_t num_returns, const std::unordered_map &required_resources) : spec_() { @@ -54,10 +62,6 @@ TaskSpecification::TaskSpecification( sha256_update(&ctx, (BYTE *)&driver_id, sizeof(driver_id)); sha256_update(&ctx, (BYTE *)&parent_task_id, sizeof(parent_task_id)); sha256_update(&ctx, (BYTE *)&parent_counter, sizeof(parent_counter)); - // sha256_update(&ctx, (BYTE *) &actor_id, sizeof(actor_id)); - // sha256_update(&ctx, (BYTE *) &actor_counter, sizeof(actor_counter)); - // sha256_update(&ctx, (BYTE *) &is_actor_checkpoint_method, - // sizeof(is_actor_checkpoint_method)); // Compute the final task ID from the hash. BYTE buff[DIGEST_SIZE]; @@ -82,11 +86,11 @@ TaskSpecification::TaskSpecification( // Serialize the TaskSpecification. auto spec = CreateTaskInfo( fbb, to_flatbuf(fbb, driver_id), to_flatbuf(fbb, task_id), - to_flatbuf(fbb, parent_task_id), parent_counter, to_flatbuf(fbb, ActorID::nil()), - to_flatbuf(fbb, ActorID::nil()), to_flatbuf(fbb, WorkerID::nil()), - to_flatbuf(fbb, ActorHandleID::nil()), 0, false, to_flatbuf(fbb, function_id), - fbb.CreateVector(arguments), fbb.CreateVector(returns), - map_to_flatbuf(fbb, required_resources)); + to_flatbuf(fbb, parent_task_id), parent_counter, to_flatbuf(fbb, actor_creation_id), + to_flatbuf(fbb, actor_creation_dummy_object_id), to_flatbuf(fbb, actor_id), + to_flatbuf(fbb, actor_handle_id), actor_counter, false, + to_flatbuf(fbb, function_id), fbb.CreateVector(arguments), + fbb.CreateVector(returns), map_to_flatbuf(fbb, required_resources)); fbb.Finish(spec); AssignSpecification(fbb.GetBufferPointer(), fbb.GetSize()); } @@ -165,6 +169,42 @@ const ResourceSet TaskSpecification::GetRequiredResources() const { return ResourceSet(required_resources); } +bool TaskSpecification::IsActorCreationTask() const { + return !ActorCreationId().is_nil(); +} + +bool TaskSpecification::IsActorTask() const { return !ActorId().is_nil(); } + +ActorID TaskSpecification::ActorCreationId() const { + auto message = flatbuffers::GetRoot(spec_.data()); + return from_flatbuf(*message->actor_creation_id()); +} + +ObjectID TaskSpecification::ActorCreationDummyObjectId() const { + auto message = flatbuffers::GetRoot(spec_.data()); + return from_flatbuf(*message->actor_creation_dummy_object_id()); +} + +ActorID TaskSpecification::ActorId() const { + auto message = flatbuffers::GetRoot(spec_.data()); + return from_flatbuf(*message->actor_id()); +} + +ActorHandleID TaskSpecification::ActorHandleId() const { + auto message = flatbuffers::GetRoot(spec_.data()); + return from_flatbuf(*message->actor_handle_id()); +} + +int64_t TaskSpecification::ActorCounter() const { + auto message = flatbuffers::GetRoot(spec_.data()); + return message->actor_counter(); +} + +ObjectID TaskSpecification::ActorDummyObject() const { + RAY_CHECK(IsActorTask() || IsActorCreationTask()); + return ReturnId(NumReturns() - 1); +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/task_spec.h b/src/ray/raylet/task_spec.h index 9bc91595c..d9e51dc96 100644 --- a/src/ray/raylet/task_spec.h +++ b/src/ray/raylet/task_spec.h @@ -94,15 +94,21 @@ class TaskSpecification { /// \param arguments The list of task arguments. /// \param num_returns The number of values returned by the task. /// \param required_resources The task's resource demands. - TaskSpecification(UniqueID driver_id, TaskID parent_task_id, int64_t parent_counter, - // UniqueID actor_id, - // UniqueID actor_handle_id, - // int64_t actor_counter, - FunctionID function_id, + TaskSpecification(const UniqueID &driver_id, const TaskID &parent_task_id, + int64_t parent_counter, const FunctionID &function_id, const std::vector> &arguments, int64_t num_returns, const std::unordered_map &required_resources); + TaskSpecification(const UniqueID &driver_id, const TaskID &parent_task_id, + int64_t parent_counter, const ActorID &actor_creation_id, + const ObjectID &actor_creation_dummy_object_id, + const ActorID &actor_id, const ActorHandleID &actor_handle_id, + int64_t actor_counter, const FunctionID &function_id, + const std::vector> &task_arguments, + int64_t num_returns, + const std::unordered_map &required_resources); + ~TaskSpecification() {} /// Serialize the TaskSpecification to a flatbuffer. @@ -129,6 +135,16 @@ class TaskSpecification { double GetRequiredResource(const std::string &resource_name) const; const ResourceSet GetRequiredResources() const; + // Methods specific to actor tasks. + bool IsActorCreationTask() const; + bool IsActorTask() const; + ActorID ActorCreationId() const; + ObjectID ActorCreationDummyObjectId() const; + ActorID ActorId() const; + ActorHandleID ActorHandleId() const; + int64_t ActorCounter() const; + ObjectID ActorDummyObject() const; + private: /// Assign the specification data from a pointer. void AssignSpecification(const uint8_t *spec, size_t spec_size); diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index ef7eb75c1..cec388346 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -12,7 +12,10 @@ namespace raylet { /// A constructor responsible for initializing the state of a worker. Worker::Worker(pid_t pid, std::shared_ptr connection) - : pid_(pid), connection_(connection), assigned_task_id_(TaskID::nil()) {} + : pid_(pid), + connection_(connection), + assigned_task_id_(TaskID::nil()), + actor_id_(ActorID::nil()) {} pid_t Worker::Pid() const { return pid_; } @@ -20,6 +23,15 @@ void Worker::AssignTaskId(const TaskID &task_id) { assigned_task_id_ = task_id; const TaskID &Worker::GetAssignedTaskId() const { return assigned_task_id_; } +void Worker::AssignActorId(const ActorID &actor_id) { + RAY_CHECK(actor_id_.is_nil()) + << "A worker that is already an actor cannot be assigned an actor ID again."; + RAY_CHECK(!actor_id.is_nil()); + actor_id_ = actor_id; +} + +const ActorID &Worker::GetActorId() const { return actor_id_; } + const std::shared_ptr Worker::Connection() const { return connection_; } diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index 8fc8827b2..3017521ff 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -23,6 +23,8 @@ class Worker { pid_t Pid() const; void AssignTaskId(const TaskID &task_id); const TaskID &GetAssignedTaskId() const; + void AssignActorId(const ActorID &actor_id); + const ActorID &GetActorId() const; /// Return the worker's connection. const std::shared_ptr Connection() const; @@ -31,7 +33,10 @@ class Worker { pid_t pid_; /// Connection state of a worker. std::shared_ptr connection_; + /// The worker's currently assigned task. TaskID assigned_task_id_; + /// The worker's actor ID. If this is nil, then the worker is not an actor. + ActorID actor_id_; }; } // namespace raylet diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 666c95e2c..d866f9888 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -51,14 +51,12 @@ void WorkerPool::StartWorker() { RAY_LOG(FATAL) << "Failed to start worker with return value " << rv; } -uint32_t WorkerPool::PoolSize() const { return pool_.size(); } - void WorkerPool::RegisterWorker(std::shared_ptr worker) { RAY_LOG(DEBUG) << "Registering worker with pid " << worker->Pid(); registered_workers_.push_back(worker); } -const std::shared_ptr WorkerPool::GetRegisteredWorker( +std::shared_ptr WorkerPool::GetRegisteredWorker( std::shared_ptr connection) const { for (auto it = registered_workers_.begin(); it != registered_workers_.end(); it++) { if ((*it)->Connection() == connection) { @@ -70,17 +68,30 @@ const std::shared_ptr WorkerPool::GetRegisteredWorker( void WorkerPool::PushWorker(std::shared_ptr worker) { // Since the worker is now idle, unset its assigned task ID. - worker->AssignTaskId(TaskID::nil()); + RAY_CHECK(worker->GetAssignedTaskId().is_nil()) + << "Idle workers cannot have an assigned task ID"; // Add the worker to the idle pool. - pool_.push_back(std::move(worker)); + if (worker->GetActorId().is_nil()) { + pool_.push_back(std::move(worker)); + } else { + actor_pool_[worker->GetActorId()] = std::move(worker); + } } -std::shared_ptr WorkerPool::PopWorker() { - if (pool_.empty()) { - return nullptr; +std::shared_ptr WorkerPool::PopWorker(const ActorID &actor_id) { + std::shared_ptr worker = nullptr; + if (actor_id.is_nil()) { + if (!pool_.empty()) { + worker = std::move(pool_.back()); + pool_.pop_back(); + } + } else { + auto actor_entry = actor_pool_.find(actor_id); + if (actor_entry != actor_pool_.end()) { + worker = std::move(actor_entry->second); + actor_pool_.erase(actor_entry); + } } - std::shared_ptr worker = std::move(pool_.back()); - pool_.pop_back(); return worker; } diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index a9ce9824a..2486c57c6 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -3,6 +3,7 @@ #include #include +#include #include "ray/common/client_connection.h" #include "ray/raylet/worker.h" @@ -30,11 +31,6 @@ class WorkerPool { /// Destructor responsible for freeing a set of workers owned by this class. ~WorkerPool(); - /// Get the number of idle workers in the pool. - /// - /// \return The number of idle workers. - uint32_t PoolSize() const; - /// Asynchronously start a new worker process. Once the worker process has /// registered with an external server, the process should create and /// register a new Worker, then add itself to the pool. Failure to start @@ -52,7 +48,7 @@ class WorkerPool { /// \param The client connection owned by a registered worker. /// \return The Worker that owns the given client connection. Returns nullptr /// if the client has not registered a worker yet. - const std::shared_ptr GetRegisteredWorker( + std::shared_ptr GetRegisteredWorker( std::shared_ptr connection) const; /// Disconnect a registered worker. @@ -61,8 +57,7 @@ class WorkerPool { /// \return Whether the given worker was in the pool of idle workers. bool DisconnectWorker(std::shared_ptr worker); - /// Add an idle worker to the pool. The worker's task assignment will be - /// reset. + /// Add an idle worker to the pool. /// /// \param The idle worker to add. void PushWorker(std::shared_ptr worker); @@ -70,13 +65,17 @@ class WorkerPool { /// Pop an idle worker from the pool. The caller is responsible for pushing /// the worker back onto the pool once the worker has completed its work. /// - /// \return An idle worker. Returns nullptr if the pool is empty. - std::shared_ptr PopWorker(); + /// \param actor_id The returned worker must have this actor ID. + /// \return An idle worker with the requested actor ID. Returns nullptr if no + /// such worker exists. + std::shared_ptr PopWorker(const ActorID &actor_id); private: std::vector worker_command_; /// The pool of idle workers. std::list> pool_; + /// The pool of idle actor workers. + std::unordered_map, UniqueIDHasher> actor_pool_; /// All workers that have registered and are still connected, including both /// idle and executing. // TODO(swang): Make this a map to make GetRegisteredWorker faster. diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 2ba6954dd..b7aa199e4 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -50,7 +50,7 @@ TEST_F(WorkerPoolTest, HandleWorkerRegistration) { TEST_F(WorkerPoolTest, HandleWorkerPushPop) { // Try to pop a worker from the empty pool and make sure we don't get one. std::shared_ptr popped_worker; - popped_worker = worker_pool_.PopWorker(); + popped_worker = worker_pool_.PopWorker(ActorID::nil()); ASSERT_EQ(popped_worker, nullptr); // Create some workers. @@ -61,15 +61,36 @@ TEST_F(WorkerPoolTest, HandleWorkerPushPop) { for (auto &worker : workers) { worker_pool_.PushWorker(worker); } - ASSERT_EQ(worker_pool_.PoolSize(), workers.size()); // Pop two workers and make sure they're one of the workers we created. - popped_worker = worker_pool_.PopWorker(); + popped_worker = worker_pool_.PopWorker(ActorID::nil()); ASSERT_NE(popped_worker, nullptr); ASSERT_TRUE(workers.count(popped_worker) > 0); - popped_worker = worker_pool_.PopWorker(); + popped_worker = worker_pool_.PopWorker(ActorID::nil()); ASSERT_NE(popped_worker, nullptr); ASSERT_TRUE(workers.count(popped_worker) > 0); + popped_worker = worker_pool_.PopWorker(ActorID::nil()); + ASSERT_EQ(popped_worker, nullptr); +} + +TEST_F(WorkerPoolTest, PopActorWorker) { + // Create a worker. + auto worker = CreateWorker(1234); + // Add the worker to the pool. + worker_pool_.PushWorker(worker); + + // Assign an actor ID to the worker. + auto actor = worker_pool_.PopWorker(ActorID::nil()); + auto actor_id = ActorID::from_random(); + actor->AssignActorId(actor_id); + worker_pool_.PushWorker(actor); + + // Check that there are no more non-actor workers. + ASSERT_EQ(worker_pool_.PopWorker(ActorID::nil()), nullptr); + // Check that we can pop the actor worker. + actor = worker_pool_.PopWorker(actor_id); + ASSERT_EQ(actor, worker); + ASSERT_EQ(actor->GetActorId(), actor_id); } } // namespace raylet diff --git a/test/xray_test.py b/test/xray_test.py index 7e5fc9699..034099db6 100644 --- a/test/xray_test.py +++ b/test/xray_test.py @@ -47,3 +47,18 @@ def test_basic_task_api(ray_start): # Test arguments passed by ID. # Test keyword arguments. + + +def test_actor_api(ray_start): + + @ray.remote + class Foo(object): + def __init__(self, val): + self.x = val + + def get(self): + return self.x + + x = 1 + f = Foo.remote(x) + assert (ray.get(f.get.remote()) == x)