diff --git a/BUILD.bazel b/BUILD.bazel index d15dd7fd8..0ea928129 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -77,17 +77,6 @@ cc_proto_library( deps = [":object_manager_proto"], ) -proto_library( - name = "worker_proto", - srcs = ["src/ray/protobuf/worker.proto"], - deps = [":common_proto"], -) - -cc_proto_library( - name = "worker_cc_proto", - deps = ["worker_proto"], -) - proto_library( name = "core_worker_proto", srcs = ["src/ray/protobuf/core_worker.proto"], @@ -95,21 +84,10 @@ proto_library( ) cc_proto_library( - name = "core_worker_cc_proto", + name = "worker_cc_proto", deps = ["core_worker_proto"], ) -proto_library( - name = "direct_actor_proto", - srcs = ["src/ray/protobuf/direct_actor.proto"], - deps = [":common_proto"], -) - -cc_proto_library( - name = "direct_actor_cc_proto", - deps = ["direct_actor_proto"], -) - proto_library( name = "serialization_proto", srcs = ["src/ray/protobuf/serialization.proto"], @@ -193,19 +171,11 @@ cc_library( # Worker gRPC lib. cc_grpc_library( name = "worker_cc_grpc", - srcs = [":worker_proto"], + srcs = [":core_worker_proto"], grpc_only = True, deps = [":worker_cc_proto"], ) -# direct actor gRPC lib. -cc_grpc_library( - name = "direct_actor_cc_grpc", - srcs = [":direct_actor_proto"], - grpc_only = True, - deps = [":direct_actor_cc_proto"], -) - # worker server and client. 
cc_library( name = "worker_rpc", @@ -214,7 +184,6 @@ cc_library( ]), copts = COPTS, deps = [ - "direct_actor_cc_grpc", ":grpc_common_lib", ":ray_common", ":worker_cc_grpc", @@ -363,6 +332,7 @@ cc_library( "src/ray/core_worker/store_provider/*.cc", "src/ray/core_worker/store_provider/memory_store/*.cc", "src/ray/core_worker/transport/*.cc", + "src/ray/rpc/worker/*.cc", ], exclude = [ "src/ray/core_worker/*_test.cc", @@ -379,7 +349,7 @@ cc_library( deps = [ "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - ":core_worker_cc_proto", + ":worker_cc_proto", ":ray_common", ":ray_util", # TODO(hchen): After `raylet_client` is migrated to gRPC, `core_worker_lib` diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index 8ffbffdbd..e2dc5fd98 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -89,7 +89,6 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) { if (task_spec.IsNormalTask()) { RAY_CHECK(current_job_id_.IsNil()); SetCurrentJobId(task_spec.JobId()); - current_actor_use_direct_call_ = false; } else if (task_spec.IsActorCreationTask()) { RAY_CHECK(current_job_id_.IsNil()); SetCurrentJobId(task_spec.JobId()); diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index d8fb63ae5..531706cca 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -47,7 +47,7 @@ class WorkerContext { const WorkerID worker_id_; JobID current_job_id_; ActorID current_actor_id_; - bool current_actor_use_direct_call_; + bool current_actor_use_direct_call_ = false; int current_actor_max_concurrency_; private: diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index a31aa335d..d93ef6242 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -86,7 +86,8 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, 
gcs_client_(gcs_options), memory_store_(std::make_shared()), task_execution_service_work_(task_execution_service_), - task_execution_callback_(task_execution_callback) { + task_execution_callback_(task_execution_callback), + grpc_service_(io_service_, *this) { // Initialize logging if log_dir is passed. Otherwise, it must be initialized // and cleaned up by the caller. if (log_dir_ != "") { @@ -111,14 +112,13 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, // Initialize task receivers. auto execute_task = std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3); - raylet_task_receiver_ = - std::unique_ptr(new CoreWorkerRayletTaskReceiver( - worker_context_, raylet_client_, task_execution_service_, worker_server_, - execute_task, exit_handler)); + raylet_task_receiver_ = std::unique_ptr( + new CoreWorkerRayletTaskReceiver(raylet_client_, execute_task, exit_handler)); direct_actor_task_receiver_ = std::unique_ptr( new CoreWorkerDirectActorTaskReceiver(worker_context_, task_execution_service_, worker_server_, execute_task, exit_handler)); + worker_server_.RegisterService(grpc_service_); } // Start RPC server after all the task receivers are properly initialized. 
@@ -750,4 +750,37 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task, return status; } +void CoreWorker::HandleAssignTask(const rpc::AssignTaskRequest &request, + rpc::AssignTaskReply *reply, + rpc::SendReplyCallback send_reply_callback) { + if (worker_context_.CurrentActorUseDirectCall()) { + send_reply_callback(Status::Invalid("This actor only accepts direct calls."), nullptr, + nullptr); + return; + } else { + task_execution_service_.post([=] { + raylet_task_receiver_->HandleAssignTask(request, reply, send_reply_callback); + }); + } +} + +void CoreWorker::HandleDirectActorAssignTask( + const rpc::DirectActorAssignTaskRequest &request, + rpc::DirectActorAssignTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { + task_execution_service_.post([=] { + direct_actor_task_receiver_->HandleDirectActorAssignTask(request, reply, + send_reply_callback); + }); +} + +void CoreWorker::HandleDirectActorCallArgWaitComplete( + const rpc::DirectActorCallArgWaitCompleteRequest &request, + rpc::DirectActorCallArgWaitCompleteReply *reply, + rpc::SendReplyCallback send_reply_callback) { + task_execution_service_.post([=] { + direct_actor_task_receiver_->HandleDirectActorCallArgWaitComplete( + request, reply, send_reply_callback); + }); +} + } // namespace ray diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 40b099c93..4cb78a030 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -20,6 +20,16 @@ #include "ray/rpc/worker/worker_client.h" #include "ray/rpc/worker/worker_server.h" +/// The set of gRPC handlers and their associated level of concurrency. 
If you want to +/// add a new call to the worker gRPC server, do the following: +/// 1) Add the rpc to the WorkerService in core_worker.proto, e.g., "ExampleCall" +/// 2) Add a new handler to the macro below: "RAY_CORE_WORKER_RPC_HANDLER(ExampleCall, 1)" +/// 3) Add a method to the CoreWorker class below: "CoreWorker::HandleExampleCall" +#define RAY_CORE_WORKER_RPC_HANDLERS \ + RAY_CORE_WORKER_RPC_HANDLER(AssignTask, 5) \ + RAY_CORE_WORKER_RPC_HANDLER(DirectActorAssignTask, 9999) \ + RAY_CORE_WORKER_RPC_HANDLER(DirectActorCallArgWaitComplete, 100) + namespace ray { /// The root class that contains all the core and language-independent functionalities @@ -95,7 +105,9 @@ class CoreWorker { // in the heartbeat messsage. void RemoveActiveObjectID(const ObjectID &object_id) LOCKS_EXCLUDED(object_ref_mu_); - /* Public methods related to storing and retrieving objects. */ + /// + /// Public methods related to storing and retrieving objects. + /// /// Set options for this client's interactions with the object store. /// @@ -198,7 +210,9 @@ class CoreWorker { /// \return std::string The string describing memory usage. std::string MemoryUsageString(); - /* Public methods related to task submission. */ + /// + /// Public methods related to task submission. + /// /// Get the caller ID used to submit tasks from this worker to an actor. /// @@ -268,7 +282,9 @@ class CoreWorker { /// \return Status::Invalid if we don't have the specified handle. Status SerializeActorHandle(const ActorID &actor_id, std::string *output) const; - /* Public methods related to task execution. Should not be used by driver processes. */ + /// + /// Public methods related to task execution. Should not be used by driver processes. + /// const ActorID &GetActorId() const { return actor_id_; } @@ -295,6 +311,29 @@ class CoreWorker { const std::vector> &metadatas, std::vector> *return_objects); + /* Handlers for the worker's gRPC server. 
These are executed on the io_service_ and post + * work to the appropriate event loop. + */ + + /// Handle an "AssignTask" event corresponding to scheduling a normal or an actor task + /// on this worker from the raylet. + void HandleAssignTask(const rpc::AssignTaskRequest &request, + rpc::AssignTaskReply *reply, + rpc::SendReplyCallback send_reply_callback); + + /// Handle a "DirectActorAssignTask" event corresponding to scheduling an actor task + /// on this worker from another worker. + void HandleDirectActorAssignTask(const rpc::DirectActorAssignTaskRequest &request, + rpc::DirectActorAssignTaskReply *reply, + rpc::SendReplyCallback send_reply_callback); + + /// Handle a "DirectActorCallArgWaitComplete" event corresponding to the raylet + /// notifying this worker that an argument is ready. + void HandleDirectActorCallArgWaitComplete( + const rpc::DirectActorCallArgWaitCompleteRequest &request, + rpc::DirectActorCallArgWaitCompleteReply *reply, + rpc::SendReplyCallback send_reply_callback); + private: /// Run the io_service_ event loop. This should be called in a background thread. void RunIOService(); @@ -306,7 +345,9 @@ class CoreWorker { /// Send the list of active object IDs to the raylet. void ReportActiveObjectIDs() LOCKS_EXCLUDED(object_ref_mu_); - /* Private methods related to task submission. */ + /// + /// Private methods related to task submission. + /// /// Give this worker a handle to an actor. /// @@ -328,7 +369,9 @@ class CoreWorker { /// \return Status::Invalid if we don't have this actor handle. Status GetActorHandle(const ActorID &actor_id, ActorHandle **actor_handle) const; - /* Private methods related to task execution. Should not be used by driver processes. */ + /// + /// Private methods related to task execution. Should not be used by driver processes. + /// /// Execute a task. /// @@ -408,7 +451,9 @@ class CoreWorker { // Thread that runs a boost::asio service to process IO events. 
std::thread io_thread_; - /* Fields related to ref counting objects. */ + /// + /// Fields related to ref counting objects. + /// /// Protects access to the set of active object ids. Since this set is updated /// very frequently, it is faster to lock around accesses rather than serialize @@ -422,7 +467,9 @@ class CoreWorker { /// last time it was sent to the raylet. bool active_object_ids_updated_ GUARDED_BY(object_ref_mu_) = false; - /* Fields related to storing and retrieving objects. */ + /// + /// Fields related to storing and retrieving objects. + /// /// In-memory store for return objects. This is used for `MEMORY` store provider. std::shared_ptr memory_store_; @@ -433,7 +480,9 @@ class CoreWorker { /// In-memory store interface. std::unique_ptr memory_store_provider_; - /* Fields related to task submission. */ + /// + /// Fields related to task submission. + /// // Interface to submit tasks directly to other actors. std::unique_ptr direct_actor_submitter_; @@ -441,7 +490,9 @@ class CoreWorker { /// Map from actor ID to a handle to that actor. absl::flat_hash_map> actor_handles_; - /* Fields related to task execution. */ + /// + /// Fields related to task execution. + /// /// Our actor ID. If this is nil, then we execute only stateless tasks. ActorID actor_id_; @@ -466,6 +517,9 @@ class CoreWorker { // Interface that receives tasks from the raylet. std::unique_ptr raylet_task_receiver_; + /// Common rpc service for all worker modules. + rpc::WorkerGrpcService grpc_service_; + // Interface that receives tasks from direct actor calls. 
std::unique_ptr direct_actor_task_receiver_; diff --git a/src/ray/core_worker/store_provider/memory_store_provider.h b/src/ray/core_worker/store_provider/memory_store_provider.h index 32ee88509..d93c1abd5 100644 --- a/src/ray/core_worker/store_provider/memory_store_provider.h +++ b/src/ray/core_worker/store_provider/memory_store_provider.h @@ -11,8 +11,6 @@ namespace ray { -class CoreWorker; - /// The class provides implementations for accessing local process memory store. /// An example usage for this is to retrieve the returned objects from direct /// actor call (see direct_actor_transport.cc). diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h index a65f803a1..5b94cd536 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.h +++ b/src/ray/core_worker/store_provider/plasma_store_provider.h @@ -12,8 +12,6 @@ namespace ray { -class CoreWorker; - /// The class provides implementations for accessing plasma store, which includes both /// local and remote stores. Local access goes is done via a /// CoreWorkerLocalPlasmaStoreProvider and remote access goes through the raylet. diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 1d1924923..9c8367bb9 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -13,8 +13,7 @@ #include "ray/core_worker/store_provider/memory_store_provider.h" #include "ray/raylet/raylet_client.h" -#include "src/ray/protobuf/direct_actor.grpc.pb.h" -#include "src/ray/protobuf/direct_actor.pb.h" +#include "src/ray/protobuf/core_worker.pb.h" #include "src/ray/protobuf/gcs.pb.h" #include "src/ray/util/test_util.h" @@ -481,7 +480,7 @@ TEST_F(ZeroNodeTest, TestTaskArg) { ASSERT_EQ(*data, *buffer); } -// Performance batchmark for `PushTaskRequest` creation. +// Performance benchmark for `DirectActorAssignTaskRequest` creation. 
TEST_F(ZeroNodeTest, TestTaskSpecPerf) { // Create a dummy actor handle, and then create a number of `TaskSpec` // to benchmark performance. @@ -500,11 +499,11 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { function.GetFunctionDescriptor()); // Manually create `num_tasks` task specs, and for each of them create a - // `PushTaskRequest`, this is to batch performance of TaskSpec + // `DirectActorAssignTaskRequest`, this is to benchmark performance of TaskSpec // creation/copy/destruction. int64_t start_ms = current_time_ms(); const auto num_tasks = 10000 * 10; - RAY_LOG(INFO) << "start creating " << num_tasks << " PushTaskRequests"; + RAY_LOG(INFO) << "start creating " << num_tasks << " DirectActorAssignTaskRequests"; for (int i = 0; i < num_tasks; i++) { TaskOptions options{1, resources}; std::vector return_ids; @@ -529,10 +528,11 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { const auto &task_spec = builder.Build(); ASSERT_TRUE(task_spec.IsActorTask()); - auto request = std::unique_ptr(new rpc::PushTaskRequest); + auto request = std::unique_ptr( + new rpc::DirectActorAssignTaskRequest); request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage()); } - RAY_LOG(INFO) << "Finish creating " << num_tasks << " PushTaskRequests" + RAY_LOG(INFO) << "Finish creating " << num_tasks << " DirectActorAssignTaskRequests" << ", which takes " << current_time_ms() - start_ms << " ms"; } diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index b0e75a2c8..6c69213fa 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -5,15 +5,6 @@ using ray::rpc::ActorTableData; namespace ray { -bool HasByReferenceArgs(const TaskSpecification &spec) { - for (size_t i = 0; i < spec.NumArgs(); ++i) { - if (spec.ArgIdCount(i) > 0) { - return true; - } - } - return false; -} - CoreWorkerDirectActorTaskSubmitter::CoreWorkerDirectActorTaskSubmitter( 
boost::asio::io_service &io_service, std::unique_ptr store_provider) @@ -31,7 +22,8 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask( const auto task_id = task_spec.TaskId(); const auto num_returns = task_spec.NumReturns(); - auto request = std::unique_ptr(new rpc::PushTaskRequest); + auto request = std::unique_ptr( + new rpc::DirectActorAssignTaskRequest); request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage()); std::unique_lock guard(mutex_); @@ -57,7 +49,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask( // Submit request. auto &client = rpc_clients_[actor_id]; - PushTask(*client, std::move(request), actor_id, task_id, num_returns); + DirectActorAssignTask(*client, std::move(request), actor_id, task_id, num_returns); } else { // Actor is dead, treat the task as failure. RAY_CHECK(iter->second.state_ == ActorTableData::DEAD); @@ -114,8 +106,8 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate( void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks( const ActorID &actor_id, std::string ip_address, int port) { - std::shared_ptr grpc_client = - rpc::DirectActorClient::make(ip_address, port, client_call_manager_); + std::shared_ptr grpc_client = + std::make_shared(ip_address, port, client_call_manager_); RAY_CHECK(rpc_clients_.emplace(actor_id, std::move(grpc_client)).second); // Submit all pending requests. 
@@ -125,20 +117,22 @@ void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks( auto request = std::move(requests.front()); auto num_returns = request->task_spec().num_returns(); auto task_id = TaskID::FromBinary(request->task_spec().task_id()); - PushTask(*client, std::move(request), actor_id, task_id, num_returns); + DirectActorAssignTask(*client, std::move(request), actor_id, task_id, num_returns); requests.pop_front(); } } -void CoreWorkerDirectActorTaskSubmitter::PushTask( - rpc::DirectActorClient &client, std::unique_ptr request, - const ActorID &actor_id, const TaskID &task_id, int num_returns) { +void CoreWorkerDirectActorTaskSubmitter::DirectActorAssignTask( + rpc::WorkerTaskClient &client, + std::unique_ptr request, const ActorID &actor_id, + const TaskID &task_id, int num_returns) { RAY_LOG(DEBUG) << "Pushing task " << task_id << " to actor " << actor_id; waiting_reply_tasks_[actor_id].insert(std::make_pair(task_id, num_returns)); - auto status = client.PushTask( - std::move(request), [this, actor_id, task_id, num_returns]( - Status status, const rpc::PushTaskReply &reply) { + auto status = client.DirectActorAssignTask( + std::move(request), + [this, actor_id, task_id, num_returns]( + Status status, const rpc::DirectActorAssignTaskReply &reply) { { std::unique_lock guard(mutex_); waiting_reply_tasks_[actor_id].erase(task_id); @@ -203,12 +197,9 @@ CoreWorkerDirectActorTaskReceiver::CoreWorkerDirectActorTaskReceiver( rpc::GrpcServer &server, const TaskHandler &task_handler, const std::function &exit_handler) : worker_context_(worker_context), - task_service_(main_io_service, *this), task_handler_(task_handler), exit_handler_(exit_handler), - task_main_io_service_(main_io_service) { - server.RegisterService(task_service_); -} + task_main_io_service_(main_io_service) {} void CoreWorkerDirectActorTaskReceiver::Init(RayletClient &raylet_client) { waiter_.reset(new DependencyWaiterImpl(raylet_client)); @@ -223,9 +214,9 @@ void 
CoreWorkerDirectActorTaskReceiver::SetMaxActorConcurrency(int max_concurren } } -void CoreWorkerDirectActorTaskReceiver::HandlePushTask( - const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply, - rpc::SendReplyCallback send_reply_callback) { +void CoreWorkerDirectActorTaskReceiver::HandleDirectActorAssignTask( + const rpc::DirectActorAssignTaskRequest &request, + rpc::DirectActorAssignTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { RAY_CHECK(waiter_ != nullptr) << "Must call init() prior to use"; const TaskSpecification task_spec(request.task_spec()); RAY_LOG(DEBUG) << "Received task " << task_spec.TaskId(); @@ -269,7 +260,7 @@ void CoreWorkerDirectActorTaskReceiver::HandlePushTask( std::vector> return_by_value; auto status = task_handler_(task_spec, resource_ids, &return_by_value); if (status.IsSystemExit()) { - // In Python, SystemExit cannot be raised except on the main thread. To work + // In Python, SystemExit can only be raised on the main thread. To work // around this when we are executing tasks on worker threads, we re-post the // exit event explicitly on the main thread. task_main_io_service_.post([this]() { exit_handler_(); }); diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index 9fb604ebb..a178bc7f6 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -14,8 +14,8 @@ #include "ray/core_worker/context.h" #include "ray/core_worker/store_provider/memory_store_provider.h" #include "ray/gcs/redis_gcs_client.h" -#include "ray/rpc/worker/direct_actor_client.h" -#include "ray/rpc/worker/direct_actor_server.h" +#include "ray/rpc/grpc_server.h" +#include "ray/rpc/worker/worker_client.h" namespace ray { @@ -67,9 +67,10 @@ class CoreWorkerDirectActorTaskSubmitter { /// \param[in] task_id The ID of a task. /// \param[in] num_returns Number of return objects. /// \return Void. 
- void PushTask(rpc::DirectActorClient &client, - std::unique_ptr request, const ActorID &actor_id, - const TaskID &task_id, int num_returns); + void DirectActorAssignTask(rpc::WorkerTaskClient &client, + std::unique_ptr request, + const ActorID &actor_id, const TaskID &task_id, + int num_returns); /// Treat a task as failed. /// @@ -114,10 +115,11 @@ class CoreWorkerDirectActorTaskSubmitter { /// /// TODO(zhijunfu): this will be moved into `actor_states_` later when we can /// subscribe updates for a specific actor. - std::unordered_map> rpc_clients_; + std::unordered_map> rpc_clients_; /// Map from actor id to the actor's pending requests. - std::unordered_map>> + std::unordered_map>> pending_requests_; /// Map from actor id to the tasks that are waiting for reply. @@ -327,7 +329,7 @@ class SchedulingQueue { friend class SchedulingQueueTest; }; -class CoreWorkerDirectActorTaskReceiver : public rpc::DirectActorHandler { +class CoreWorkerDirectActorTaskReceiver { public: using TaskHandler = std::function &raylet_client, - boost::asio::io_service &io_service, rpc::GrpcServer &server, - const TaskHandler &task_handler, const std::function &exit_handler) - : worker_context_(worker_context), - raylet_client_(raylet_client), - task_service_(io_service, *this), + std::unique_ptr &raylet_client, const TaskHandler &task_handler, + const std::function &exit_handler) + : raylet_client_(raylet_client), task_handler_(task_handler), - exit_handler_(exit_handler) { - server.RegisterService(task_service_); -} + exit_handler_(exit_handler) {} void CoreWorkerRayletTaskReceiver::HandleAssignTask( const rpc::AssignTaskRequest &request, rpc::AssignTaskReply *reply, @@ -23,11 +18,6 @@ void CoreWorkerRayletTaskReceiver::HandleAssignTask( const Task task(request.task()); const auto &task_spec = task.GetTaskSpecification(); RAY_LOG(DEBUG) << "Received task " << task_spec.TaskId(); - if (task_spec.IsActorTask() && worker_context_.CurrentActorUseDirectCall()) { - 
send_reply_callback(Status::Invalid("This actor only accepts direct calls."), nullptr, - nullptr); - return; - } // Set the resource IDs for this task. // TODO: convert the resource map to protobuf and change this. diff --git a/src/ray/core_worker/transport/raylet_transport.h b/src/ray/core_worker/transport/raylet_transport.h index 958dea4dc..98d53ceb7 100644 --- a/src/ray/core_worker/transport/raylet_transport.h +++ b/src/ray/core_worker/transport/raylet_transport.h @@ -4,22 +4,19 @@ #include #include "ray/common/ray_object.h" -#include "ray/core_worker/context.h" #include "ray/raylet/raylet_client.h" #include "ray/rpc/worker/worker_server.h" namespace ray { -class CoreWorkerRayletTaskReceiver : public rpc::WorkerTaskHandler { +class CoreWorkerRayletTaskReceiver { public: using TaskHandler = std::function> *return_by_value)>; - CoreWorkerRayletTaskReceiver(WorkerContext &worker_context, - std::unique_ptr &raylet_client, - boost::asio::io_service &io_service, - rpc::GrpcServer &server, const TaskHandler &task_handler, + CoreWorkerRayletTaskReceiver(std::unique_ptr &raylet_client, + const TaskHandler &task_handler, const std::function &exit_handler); /// Handle a `AssignTask` request. @@ -31,15 +28,11 @@ class CoreWorkerRayletTaskReceiver : public rpc::WorkerTaskHandler { /// \param[in] send_reply_callback The callback to be called when the request is done. void HandleAssignTask(const rpc::AssignTaskRequest &request, rpc::AssignTaskReply *reply, - rpc::SendReplyCallback send_reply_callback) override; + rpc::SendReplyCallback send_reply_callback); private: - // Worker context. - WorkerContext &worker_context_; /// Raylet client. std::unique_ptr &raylet_client_; - /// The rpc service for `WorkerTaskService`. - rpc::WorkerTaskGrpcService task_service_; /// The callback function to process a task. TaskHandler task_handler_; /// The callback function to exit the worker. 
diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 764529bf1..1e976ff48 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -4,6 +4,10 @@ package ray.rpc; import "src/ray/protobuf/common.proto"; +message ActiveObjectIDs { + repeated bytes object_ids = 1; +} + // Persistent state of an ActorHandle. message ActorHandle { // ID of the actor. @@ -27,3 +31,64 @@ message ActorHandle { // Whether direct actor call is used. bool is_direct_call = 7; } + +message AssignTaskRequest { + // The task to be pushed. + Task task = 1; + // A list of the resources reserved for this worker. + // TODO(zhijunfu): `resource_ids` is represented as + // flatbuffers-serialized bytes, will be moved to protobuf later. + bytes resource_ids = 2; +} + +message AssignTaskReply { +} + +message ReturnObject { + // Object ID. + bytes object_id = 1; + // Data of the object. + bytes data = 2; + // Metadata of the object. + bytes metadata = 3; +} + +message DirectActorAssignTaskRequest { + // The task to be pushed. + TaskSpec task_spec = 1; + // The sequence number of the task for this client. This must increase + // sequentially starting from zero for each actor handle. The server + // will guarantee tasks execute in this sequence, waiting for any + // out-of-order request messages to arrive as necessary. + int64 sequence_number = 2; + // The max sequence number the client has processed responses for. This + // is a performance optimization that allows the client to tell the server + // to cancel any DirectActorAssignTaskRequests with seqno <= this value, rather than + // waiting for the server to time out waiting for missing messages. + int64 client_processed_up_to = 3; +} + +message DirectActorAssignTaskReply { + // The returned objects. + repeated ReturnObject return_objects = 1; +} + +message DirectActorCallArgWaitCompleteRequest { + // Id used to uniquely identify this request. 
This is sent back to the core + // worker to notify the wait has completed. + int64 tag = 1; +} + +message DirectActorCallArgWaitCompleteReply { +} + +service WorkerService { + // Push a task to a worker. + rpc AssignTask(AssignTaskRequest) returns (AssignTaskReply); + // Push a task to a direct-call actor. + rpc DirectActorAssignTask(DirectActorAssignTaskRequest) + returns (DirectActorAssignTaskReply); + // Notify wait for direct actor call args has completed. + rpc DirectActorCallArgWaitComplete(DirectActorCallArgWaitCompleteRequest) + returns (DirectActorCallArgWaitCompleteReply); +} diff --git a/src/ray/protobuf/direct_actor.proto b/src/ray/protobuf/direct_actor.proto deleted file mode 100644 index b767f1db1..000000000 --- a/src/ray/protobuf/direct_actor.proto +++ /dev/null @@ -1,52 +0,0 @@ -syntax = "proto3"; - -package ray.rpc; - -import "src/ray/protobuf/common.proto"; - -message ReturnObject { - // Object ID. - bytes object_id = 1; - // Data of the object. - bytes data = 2; - // Metaata of the object. - bytes metadata = 3; -} - -message PushTaskRequest { - // The task to be pushed. - TaskSpec task_spec = 1; - // The sequence number of the task for this client. This must increase - // sequentially starting from zero for each actor handle. The server - // will guarantee tasks execute in this sequence, waiting for any - // out-of-order request messages to arrive as necessary. - int64 sequence_number = 2; - // The max sequence number the client has processed responses for. This - // is a performance optimization that allows the client to tell the server - // to cancel any PushTaskRequests with seqno <= this value, rather than - // waiting for the server to time out waiting for missing messages. - int64 client_processed_up_to = 3; -} - -message PushTaskReply { - // The returned objects. - repeated ReturnObject return_objects = 1; -} - -message DirectActorCallArgWaitCompleteRequest { - // Id used to uniquely identify this request. 
This is sent back to the core - // worker to notify the wait has completed. - int64 tag = 1; -} - -message DirectActorCallArgWaitCompleteReply { -} - -// Service for direct actor. -service DirectActorService { - // Push a task to a worker. - rpc PushTask(PushTaskRequest) returns (PushTaskReply); - // Notify wait for direct actor call args has completed - rpc DirectActorCallArgWaitComplete(DirectActorCallArgWaitCompleteRequest) - returns (DirectActorCallArgWaitCompleteReply); -} diff --git a/src/ray/protobuf/worker.proto b/src/ray/protobuf/worker.proto deleted file mode 100644 index e8cdb8e38..000000000 --- a/src/ray/protobuf/worker.proto +++ /dev/null @@ -1,27 +0,0 @@ -syntax = "proto3"; - -package ray.rpc; - -import "src/ray/protobuf/common.proto"; - -message ActiveObjectIDs { - repeated bytes object_ids = 1; -} - -message AssignTaskRequest { - // The task to be pushed. - Task task = 1; - // A list of the resources reserved for this worker. - // TODO(zhijunfu): `resource_ids` is represented as - // flatbutters-serialized bytes, will be moved to protobuf later. - bytes resource_ids = 2; -} - -message AssignTaskReply { -} - -// Service for worker. -service WorkerTaskService { - // Push a task to a worker. 
- rpc AssignTask(AssignTaskRequest) returns (AssignTaskReply); -} diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index 0218196fd..df6535098 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -4,8 +4,8 @@ #include "ray/raylet/format/node_manager_generated.h" #include "ray/raylet/raylet.h" -#include "src/ray/protobuf/direct_actor.grpc.pb.h" -#include "src/ray/protobuf/direct_actor.pb.h" +#include "src/ray/protobuf/core_worker.grpc.pb.h" +#include "src/ray/protobuf/core_worker.pb.h" namespace ray { @@ -27,8 +27,6 @@ Worker::Worker(const WorkerID &worker_id, pid_t pid, const Language &language, i if (port_ > 0) { rpc_client_ = std::unique_ptr( new rpc::WorkerTaskClient("127.0.0.1", port_, client_call_manager_)); - direct_rpc_client_ = std::unique_ptr( - new rpc::DirectActorClient("127.0.0.1", port_, client_call_manager_)); } } @@ -163,7 +161,7 @@ void Worker::DirectActorCallArgWaitComplete(int64_t tag) { RAY_CHECK(port_ > 0); rpc::DirectActorCallArgWaitCompleteRequest request; request.set_tag(tag); - auto status = direct_rpc_client_->DirectActorCallArgWaitComplete( + auto status = rpc_client_->DirectActorCallArgWaitComplete( request, [](Status status, const rpc::DirectActorCallArgWaitCompleteReply &reply) { if (!status.ok()) { RAY_LOG(ERROR) << "Failed to send wait complete: " << status.ToString(); diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index fdf4d5dfc..a67b3a76e 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -8,7 +8,6 @@ #include "ray/common/task/scheduling_resources.h" #include "ray/common/task/task.h" #include "ray/common/task/task_common.h" -#include "ray/rpc/worker/direct_actor_client.h" #include "ray/rpc/worker/worker_client.h" #include // pid_t @@ -109,8 +108,6 @@ class Worker { /// Whether the worker is detached. This is applies when the worker is actor. /// Detached actor means the actor's creator can exit without killing this actor. 
bool is_detached_actor_; - /// The rpc client to send tasks to the direct actor service. - std::unique_ptr direct_rpc_client_; }; } // namespace raylet diff --git a/src/ray/rpc/worker/direct_actor_client.h b/src/ray/rpc/worker/direct_actor_client.h deleted file mode 100644 index 5ca280739..000000000 --- a/src/ray/rpc/worker/direct_actor_client.h +++ /dev/null @@ -1,157 +0,0 @@ -#ifndef RAY_RPC_DIRECT_ACTOR_CLIENT_H -#define RAY_RPC_DIRECT_ACTOR_CLIENT_H - -#include -#include -#include -#include - -#include -#include "absl/base/thread_annotations.h" - -#include "ray/common/status.h" -#include "ray/rpc/client_call.h" -#include "ray/rpc/worker/direct_actor_common.h" -#include "ray/util/logging.h" -#include "src/ray/protobuf/direct_actor.grpc.pb.h" -#include "src/ray/protobuf/direct_actor.pb.h" - -namespace ray { -namespace rpc { - -/// Client used for communicating with a direct actor server. -class DirectActorClient : public std::enable_shared_from_this { - public: - /// Constructor. - /// - /// \param[in] address Address of the direct actor server. - /// \param[in] port Port of the direct actor server. - /// \param[in] client_call_manager The `ClientCallManager` used for managing requests. - static std::shared_ptr make(const std::string &address, - const int port, - ClientCallManager &client_call_manager) { - auto instance = new DirectActorClient(address, port, client_call_manager); - return std::shared_ptr(instance); - } - - /// Constructor. - /// - /// \param[in] address Address of the direct actor server. - /// \param[in] port Port of the direct actor server. - /// \param[in] client_call_manager The `ClientCallManager` used for managing requests. 
- DirectActorClient(const std::string &address, const int port, - ClientCallManager &client_call_manager) - : client_call_manager_(client_call_manager) { - std::shared_ptr channel = grpc::CreateChannel( - address + ":" + std::to_string(port), grpc::InsecureChannelCredentials()); - stub_ = DirectActorService::NewStub(channel); - }; - - /// Push a task. - /// - /// \param[in] request The request message. - /// \param[in] callback The callback function that handles reply. - /// \return if the rpc call succeeds - ray::Status PushTask(std::unique_ptr request, - const ClientCallback &callback) { - request->set_sequence_number(request->task_spec().actor_task_spec().actor_counter()); - { - std::lock_guard lock(mutex_); - if (request->task_spec().caller_id() != cur_caller_id_) { - // We are running a new task, reset the seq no counter. - max_finished_seq_no_ = -1; - cur_caller_id_ = request->task_spec().caller_id(); - } - send_queue_.push_back(std::make_pair(std::move(request), callback)); - } - SendRequests(); - return ray::Status::OK(); - } - - /// Notify a wait has completed for direct actor call arguments. - /// - /// \param[in] request The request message. - /// \param[in] callback The callback function that handles reply. - /// \return if the rpc call succeeds - ray::Status DirectActorCallArgWaitComplete( - const DirectActorCallArgWaitCompleteRequest &request, - const ClientCallback &callback) { - auto call = client_call_manager_.CreateCall( - *stub_, &DirectActorService::Stub::PrepareAsyncDirectActorCallArgWaitComplete, - request, callback); - return call->GetStatus(); - } - - /// Send as many pending tasks as possible. This method is thread-safe. - /// - /// The client will guarantee no more than kMaxBytesInFlight bytes of RPCs are being - /// sent at once. This prevents the server scheduling queue from being overwhelmed. - /// See direct_actor.proto for a description of the ordering protocol. 
- void SendRequests() { - std::lock_guard lock(mutex_); - auto this_ptr = this->shared_from_this(); - - while (!send_queue_.empty() && rpc_bytes_in_flight_ < kMaxBytesInFlight) { - auto pair = std::move(*send_queue_.begin()); - send_queue_.pop_front(); - - auto request = std::move(pair.first); - auto callback = pair.second; - int64_t task_size = RequestSizeInBytes(*request); - int64_t seq_no = request->sequence_number(); - request->set_client_processed_up_to(max_finished_seq_no_); - rpc_bytes_in_flight_ += task_size; - - client_call_manager_.CreateCall( - *stub_, &DirectActorService::Stub::PrepareAsyncPushTask, *request, - [this, this_ptr, seq_no, task_size, callback](Status status, - const rpc::PushTaskReply &reply) { - { - std::lock_guard lock(mutex_); - if (seq_no > max_finished_seq_no_) { - max_finished_seq_no_ = seq_no; - } - rpc_bytes_in_flight_ -= task_size; - RAY_CHECK(rpc_bytes_in_flight_ >= 0); - } - SendRequests(); - callback(status, reply); - }); - } - - if (!send_queue_.empty()) { - RAY_LOG(DEBUG) << "client send queue size " << send_queue_.size(); - } - } - - private: - /// Protects against unsafe concurrent access from the callback thread. - std::mutex mutex_; - - /// The gRPC-generated stub. - std::unique_ptr stub_; - - /// The `ClientCallManager` used for managing requests. - ClientCallManager &client_call_manager_; - - /// Queue of requests to send. - std::deque, ClientCallback>> - send_queue_ GUARDED_BY(mutex_); - - /// The number of bytes currently in flight. - int64_t rpc_bytes_in_flight_ GUARDED_BY(mutex_) = 0; - - /// The max sequence number we have processed responses for. - int64_t max_finished_seq_no_ GUARDED_BY(mutex_) = -1; - - /// The task id we are currently sending requests for. When this changes, - /// the max finished seq no counter is reset. 
- std::string cur_caller_id_; -}; - -} // namespace rpc -} // namespace ray - -#endif // RAY_RPC_DIRECT_ACTOR_CLIENT_H diff --git a/src/ray/rpc/worker/direct_actor_common.h b/src/ray/rpc/worker/direct_actor_common.h deleted file mode 100644 index f63cd27fb..000000000 --- a/src/ray/rpc/worker/direct_actor_common.h +++ /dev/null @@ -1,28 +0,0 @@ -#ifndef RAY_RPC_DIRECT_ACTOR_COMMON_H -#define RAY_RPC_DIRECT_ACTOR_COMMON_H - -#include "src/ray/protobuf/direct_actor.grpc.pb.h" -#include "src/ray/protobuf/direct_actor.pb.h" - -namespace ray { -namespace rpc { - -/// The maximum number of requests in flight per client. -const int64_t kMaxBytesInFlight = 16 * 1024 * 1024; - -/// The base size in bytes per request. -const int64_t kBaseRequestSize = 1024; - -/// Get the estimated size in bytes of the given task. -const static int64_t RequestSizeInBytes(const PushTaskRequest &request) { - int64_t size = kBaseRequestSize; - for (auto &arg : request.task_spec().args()) { - size += arg.data().size(); - } - return size; -} - -} // namespace rpc -} // namespace ray - -#endif // RAY_RPC_DIRECT_ACTOR_COMMON_H diff --git a/src/ray/rpc/worker/direct_actor_server.h b/src/ray/rpc/worker/direct_actor_server.h deleted file mode 100644 index dc6c6b4fe..000000000 --- a/src/ray/rpc/worker/direct_actor_server.h +++ /dev/null @@ -1,88 +0,0 @@ -#ifndef RAY_RPC_DIRECT_ACTOR_SERVER_H -#define RAY_RPC_DIRECT_ACTOR_SERVER_H - -#include "ray/rpc/grpc_server.h" -#include "ray/rpc/server_call.h" - -#include "src/ray/protobuf/direct_actor.grpc.pb.h" -#include "src/ray/protobuf/direct_actor.pb.h" - -namespace ray { -namespace rpc { - -/// Interface of the `DirectActorService`, see `src/ray/protobuf/direct_actor.proto`. -class DirectActorHandler { - public: - /// Handle a `PushTask` request. - /// The implementation can handle this request asynchronously. When hanling is done, the - /// `done_callback` should be called. - /// - /// \param[in] request The request message. 
- /// \param[out] reply The reply message. - /// \param[in] done_callback The callback to be called when the request is done. - virtual void HandlePushTask(const PushTaskRequest &request, PushTaskReply *reply, - SendReplyCallback send_reply_callback) = 0; - - /// Handle a wait reply for direct actor call arg dependencies. - /// - /// \param[in] request The request message. - /// \param[out] reply The reply message. - /// \param[in] send_replay_callback The callback to be called when the request is done. - virtual void HandleDirectActorCallArgWaitComplete( - const rpc::DirectActorCallArgWaitCompleteRequest &request, - rpc::DirectActorCallArgWaitCompleteReply *reply, - rpc::SendReplyCallback send_reply_callback) = 0; -}; - -/// The `GrpcServer` for `WorkerService`. -class DirectActorGrpcService : public GrpcService { - public: - /// Constructor. - /// - /// \param[in] main_service See super class. - /// \param[in] handler The service handler that actually handle the requests. - DirectActorGrpcService(boost::asio::io_service &main_service, - DirectActorHandler &service_handler) - : GrpcService(main_service), service_handler_(service_handler){}; - - protected: - grpc::Service &GetGrpcService() override { return service_; } - - void InitServerCallFactories( - const std::unique_ptr &cq, - std::vector, int>> - *server_call_factories_and_concurrencies) override { - // Initialize the Factory for `PushTask` requests. - std::unique_ptr push_task_call_Factory( - new ServerCallFactoryImpl( - service_, &DirectActorService::AsyncService::RequestPushTask, - service_handler_, &DirectActorHandler::HandlePushTask, cq, main_service_)); - server_call_factories_and_concurrencies->emplace_back( - std::move(push_task_call_Factory), 100); - - // Initialize the Factory for `DirectActorCallArgWaitComplete` requests. 
- std::unique_ptr wait_complete_call_Factory( - new ServerCallFactoryImpl( - service_, - &DirectActorService::AsyncService::RequestDirectActorCallArgWaitComplete, - service_handler_, &DirectActorHandler::HandleDirectActorCallArgWaitComplete, - cq, main_service_)); - server_call_factories_and_concurrencies->emplace_back( - std::move(wait_complete_call_Factory), 100); - } - - private: - /// The grpc async service object. - DirectActorService::AsyncService service_; - - /// The service handler that actually handle the requests. - DirectActorHandler &service_handler_; -}; - -} // namespace rpc -} // namespace ray - -#endif diff --git a/src/ray/rpc/worker/worker_client.h b/src/ray/rpc/worker/worker_client.h index 91ded57ac..29d0bf3c2 100644 --- a/src/ray/rpc/worker/worker_client.h +++ b/src/ray/rpc/worker/worker_client.h @@ -1,21 +1,40 @@ #ifndef RAY_RPC_WORKER_CLIENT_H #define RAY_RPC_WORKER_CLIENT_H +#include +#include +#include #include #include +#include "absl/base/thread_annotations.h" #include "ray/common/status.h" #include "ray/rpc/client_call.h" #include "ray/util/logging.h" -#include "src/ray/protobuf/worker.grpc.pb.h" -#include "src/ray/protobuf/worker.pb.h" +#include "src/ray/protobuf/core_worker.grpc.pb.h" +#include "src/ray/protobuf/core_worker.pb.h" namespace ray { namespace rpc { +/// The maximum number of requests in flight per client. +const int64_t kMaxBytesInFlight = 16 * 1024 * 1024; + +/// The base size in bytes per request. +const int64_t kBaseRequestSize = 1024; + +/// Get the estimated size in bytes of the given task. +const static int64_t RequestSizeInBytes(const DirectActorAssignTaskRequest &request) { + int64_t size = kBaseRequestSize; + for (auto &arg : request.task_spec().args()) { + size += arg.data().size(); + } + return size; +} + /// Client used for communicating with a remote worker server. -class WorkerTaskClient { +class WorkerTaskClient : public std::enable_shared_from_this { public: /// Constructor. 
/// @@ -27,7 +46,7 @@ class WorkerTaskClient { : client_call_manager_(client_call_manager) { std::shared_ptr channel = grpc::CreateChannel( address + ":" + std::to_string(port), grpc::InsecureChannelCredentials()); - stub_ = WorkerTaskService::NewStub(channel); + stub_ = WorkerService::NewStub(channel); }; /// Assign a task to the work. @@ -37,19 +56,119 @@ class WorkerTaskClient { /// \return if the rpc call succeeds ray::Status AssignTask(const AssignTaskRequest &request, const ClientCallback &callback) { - auto call = client_call_manager_ - .CreateCall( - *stub_, &WorkerTaskService::Stub::PrepareAsyncAssignTask, request, - callback); + auto call = + client_call_manager_ + .CreateCall( + *stub_, &WorkerService::Stub::PrepareAsyncAssignTask, request, callback); return call->GetStatus(); } + /// Push a task. + /// + /// \param[in] request The request message. + /// \param[in] callback The callback function that handles reply. + /// \return if the rpc call succeeds + ray::Status DirectActorAssignTask( + std::unique_ptr request, + const ClientCallback &callback) { + request->set_sequence_number(request->task_spec().actor_task_spec().actor_counter()); + { + std::lock_guard lock(mutex_); + if (request->task_spec().caller_id() != cur_caller_id_) { + // We are running a new task, reset the seq no counter. + max_finished_seq_no_ = -1; + cur_caller_id_ = request->task_spec().caller_id(); + } + send_queue_.push_back(std::make_pair(std::move(request), callback)); + } + SendRequests(); + return ray::Status::OK(); + } + + /// Notify a wait has completed for direct actor call arguments. + /// + /// \param[in] request The request message. + /// \param[in] callback The callback function that handles reply. 
+ /// \return if the rpc call succeeds + ray::Status DirectActorCallArgWaitComplete( + const DirectActorCallArgWaitCompleteRequest &request, + const ClientCallback &callback) { + auto call = + client_call_manager_ + .CreateCall( + *stub_, &WorkerService::Stub::PrepareAsyncDirectActorCallArgWaitComplete, + request, callback); + return call->GetStatus(); + } + + /// Send as many pending tasks as possible. This method is thread-safe. + /// + /// The client will guarantee no more than kMaxBytesInFlight bytes of RPCs are being + /// sent at once. This prevents the server scheduling queue from being overwhelmed. + /// See direct_actor.proto for a description of the ordering protocol. + void SendRequests() { + std::lock_guard lock(mutex_); + auto this_ptr = this->shared_from_this(); + + while (!send_queue_.empty() && rpc_bytes_in_flight_ < kMaxBytesInFlight) { + auto pair = std::move(*send_queue_.begin()); + send_queue_.pop_front(); + + auto request = std::move(pair.first); + auto callback = pair.second; + int64_t task_size = RequestSizeInBytes(*request); + int64_t seq_no = request->sequence_number(); + request->set_client_processed_up_to(max_finished_seq_no_); + rpc_bytes_in_flight_ += task_size; + + client_call_manager_.CreateCall( + *stub_, &WorkerService::Stub::PrepareAsyncDirectActorAssignTask, *request, + [this, this_ptr, seq_no, task_size, callback]( + Status status, const rpc::DirectActorAssignTaskReply &reply) { + { + std::lock_guard lock(mutex_); + if (seq_no > max_finished_seq_no_) { + max_finished_seq_no_ = seq_no; + } + rpc_bytes_in_flight_ -= task_size; + RAY_CHECK(rpc_bytes_in_flight_ >= 0); + } + SendRequests(); + callback(status, reply); + }); + } + + if (!send_queue_.empty()) { + RAY_LOG(DEBUG) << "client send queue size " << send_queue_.size(); + } + } + private: + /// Protects against unsafe concurrent access from the callback thread. + std::mutex mutex_; + /// The gRPC-generated stub. 
- std::unique_ptr stub_; + std::unique_ptr stub_; /// The `ClientCallManager` used for managing requests. ClientCallManager &client_call_manager_; + + /// Queue of requests to send. + std::deque, + ClientCallback>> + send_queue_ GUARDED_BY(mutex_); + + /// The number of bytes currently in flight. + int64_t rpc_bytes_in_flight_ GUARDED_BY(mutex_) = 0; + + /// The max sequence number we have processed responses for. + int64_t max_finished_seq_no_ GUARDED_BY(mutex_) = -1; + + /// The task id we are currently sending requests for. When this changes, + /// the max finished seq no counter is reset. + std::string cur_caller_id_; }; } // namespace rpc diff --git a/src/ray/rpc/worker/worker_server.cc b/src/ray/rpc/worker/worker_server.cc new file mode 100644 index 000000000..44b3c7f7b --- /dev/null +++ b/src/ray/rpc/worker/worker_server.cc @@ -0,0 +1,28 @@ +#include "ray/rpc/worker/worker_server.h" +#include "ray/core_worker/core_worker.h" + +namespace ray { +namespace rpc { + +#define RAY_CORE_WORKER_RPC_HANDLER(HANDLER, CONCURRENCY) \ + std::unique_ptr HANDLER##_call_factory( \ + new ServerCallFactoryImpl( \ + service_, &WorkerService::AsyncService::Request##HANDLER, core_worker_, \ + &CoreWorker::Handle##HANDLER, cq, main_service_)); \ + server_call_factories_and_concurrencies->emplace_back( \ + std::move(HANDLER##_call_factory), CONCURRENCY); + +WorkerGrpcService::WorkerGrpcService(boost::asio::io_service &main_service, + CoreWorker &core_worker) + : GrpcService(main_service), core_worker_(core_worker){}; + +void WorkerGrpcService::InitServerCallFactories( + const std::unique_ptr &cq, + std::vector, int>> + *server_call_factories_and_concurrencies) { + RAY_CORE_WORKER_RPC_HANDLERS +} + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/worker/worker_server.h b/src/ray/rpc/worker/worker_server.h index fe25f9c94..d9f87126a 100644 --- a/src/ray/rpc/worker/worker_server.h +++ b/src/ray/rpc/worker/worker_server.h @@ -4,36 +4,23 @@ #include 
"ray/rpc/grpc_server.h" #include "ray/rpc/server_call.h" -#include "src/ray/protobuf/worker.grpc.pb.h" -#include "src/ray/protobuf/worker.pb.h" +#include "src/ray/protobuf/core_worker.grpc.pb.h" +#include "src/ray/protobuf/core_worker.pb.h" namespace ray { + +class CoreWorker; + namespace rpc { -/// Interface of the `WorkerService`, see `src/ray/protobuf/worker.proto`. -class WorkerTaskHandler { - public: - /// Handle a `AssignTask` request. - /// The implementation can handle this request asynchronously. When handling is done, - /// the `send_reply_callback` should be called. - /// - /// \param[in] request The request message. - /// \param[out] reply The reply message. - /// \param[in] send_reply_callback The callback to be called when the request is done. - virtual void HandleAssignTask(const AssignTaskRequest &request, AssignTaskReply *reply, - SendReplyCallback send_reply_callback) = 0; -}; - /// The `GrpcServer` for `WorkerService`. -class WorkerTaskGrpcService : public GrpcService { +class WorkerGrpcService : public GrpcService { public: /// Constructor. /// /// \param[in] main_service See super class. /// \param[in] handler The service handler that actually handle the requests. - WorkerTaskGrpcService(boost::asio::io_service &main_service, - WorkerTaskHandler &service_handler) - : GrpcService(main_service), service_handler_(service_handler){}; + WorkerGrpcService(boost::asio::io_service &main_service, CoreWorker &core_worker); protected: grpc::Service &GetGrpcService() override { return service_; } @@ -41,25 +28,14 @@ class WorkerTaskGrpcService : public GrpcService { void InitServerCallFactories( const std::unique_ptr &cq, std::vector, int>> - *server_call_factories_and_concurrencies) override { - // Initialize the Factory for `AssignTask` requests. 
- std::unique_ptr push_task_call_Factory( - new ServerCallFactoryImpl( - service_, &WorkerTaskService::AsyncService::RequestAssignTask, - service_handler_, &WorkerTaskHandler::HandleAssignTask, cq, main_service_)); - - // Set `AssignTask`'s accept concurrency to 5. - server_call_factories_and_concurrencies->emplace_back( - std::move(push_task_call_Factory), 5); - } + *server_call_factories_and_concurrencies) override; private: /// The grpc async service object. - WorkerTaskService::AsyncService service_; + WorkerService::AsyncService service_; - /// The service handler that actually handle the requests. - WorkerTaskHandler &service_handler_; + /// The core worker that actually handles the requests. + CoreWorker &core_worker_; }; } // namespace rpc