Simplify gRPC service definition for the worker (#6095)

This commit is contained in:
Edward Oakes
2019-11-06 13:00:39 -08:00
committed by Eric Liang
parent 3f83b2daa9
commit 9820c10a09
23 changed files with 390 additions and 532 deletions
+4 -34
View File
@@ -77,17 +77,6 @@ cc_proto_library(
deps = [":object_manager_proto"],
)
proto_library(
name = "worker_proto",
srcs = ["src/ray/protobuf/worker.proto"],
deps = [":common_proto"],
)
cc_proto_library(
name = "worker_cc_proto",
deps = ["worker_proto"],
)
proto_library(
name = "core_worker_proto",
srcs = ["src/ray/protobuf/core_worker.proto"],
@@ -95,21 +84,10 @@ proto_library(
)
cc_proto_library(
name = "core_worker_cc_proto",
name = "worker_cc_proto",
deps = ["core_worker_proto"],
)
proto_library(
name = "direct_actor_proto",
srcs = ["src/ray/protobuf/direct_actor.proto"],
deps = [":common_proto"],
)
cc_proto_library(
name = "direct_actor_cc_proto",
deps = ["direct_actor_proto"],
)
proto_library(
name = "serialization_proto",
srcs = ["src/ray/protobuf/serialization.proto"],
@@ -193,19 +171,11 @@ cc_library(
# Worker gRPC lib.
cc_grpc_library(
name = "worker_cc_grpc",
srcs = [":worker_proto"],
srcs = [":core_worker_proto"],
grpc_only = True,
deps = [":worker_cc_proto"],
)
# direct actor gRPC lib.
cc_grpc_library(
name = "direct_actor_cc_grpc",
srcs = [":direct_actor_proto"],
grpc_only = True,
deps = [":direct_actor_cc_proto"],
)
# worker server and client.
cc_library(
name = "worker_rpc",
@@ -214,7 +184,6 @@ cc_library(
]),
copts = COPTS,
deps = [
"direct_actor_cc_grpc",
":grpc_common_lib",
":ray_common",
":worker_cc_grpc",
@@ -363,6 +332,7 @@ cc_library(
"src/ray/core_worker/store_provider/*.cc",
"src/ray/core_worker/store_provider/memory_store/*.cc",
"src/ray/core_worker/transport/*.cc",
"src/ray/rpc/worker/*.cc",
],
exclude = [
"src/ray/core_worker/*_test.cc",
@@ -379,7 +349,7 @@ cc_library(
deps = [
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
":core_worker_cc_proto",
":worker_cc_proto",
":ray_common",
":ray_util",
# TODO(hchen): After `raylet_client` is migrated to gRPC, `core_worker_lib`
-1
View File
@@ -89,7 +89,6 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) {
if (task_spec.IsNormalTask()) {
RAY_CHECK(current_job_id_.IsNil());
SetCurrentJobId(task_spec.JobId());
current_actor_use_direct_call_ = false;
} else if (task_spec.IsActorCreationTask()) {
RAY_CHECK(current_job_id_.IsNil());
SetCurrentJobId(task_spec.JobId());
+1 -1
View File
@@ -47,7 +47,7 @@ class WorkerContext {
const WorkerID worker_id_;
JobID current_job_id_;
ActorID current_actor_id_;
bool current_actor_use_direct_call_;
bool current_actor_use_direct_call_ = false;
int current_actor_max_concurrency_;
private:
+38 -5
View File
@@ -86,7 +86,8 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
gcs_client_(gcs_options),
memory_store_(std::make_shared<CoreWorkerMemoryStore>()),
task_execution_service_work_(task_execution_service_),
task_execution_callback_(task_execution_callback) {
task_execution_callback_(task_execution_callback),
grpc_service_(io_service_, *this) {
// Initialize logging if log_dir is passed. Otherwise, it must be initialized
// and cleaned up by the caller.
if (log_dir_ != "") {
@@ -111,14 +112,13 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
// Initialize task receivers.
auto execute_task = std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3);
raylet_task_receiver_ =
std::unique_ptr<CoreWorkerRayletTaskReceiver>(new CoreWorkerRayletTaskReceiver(
worker_context_, raylet_client_, task_execution_service_, worker_server_,
execute_task, exit_handler));
raylet_task_receiver_ = std::unique_ptr<CoreWorkerRayletTaskReceiver>(
new CoreWorkerRayletTaskReceiver(raylet_client_, execute_task, exit_handler));
direct_actor_task_receiver_ = std::unique_ptr<CoreWorkerDirectActorTaskReceiver>(
new CoreWorkerDirectActorTaskReceiver(worker_context_, task_execution_service_,
worker_server_, execute_task,
exit_handler));
worker_server_.RegisterService(grpc_service_);
}
// Start RPC server after all the task receivers are properly initialized.
@@ -750,4 +750,37 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task,
return status;
}
void CoreWorker::HandleAssignTask(const rpc::AssignTaskRequest &request,
rpc::AssignTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
if (worker_context_.CurrentActorUseDirectCall()) {
send_reply_callback(Status::Invalid("This actor only accepts direct calls."), nullptr,
nullptr);
return;
} else {
task_execution_service_.post([=] {
raylet_task_receiver_->HandleAssignTask(request, reply, send_reply_callback);
});
}
}
void CoreWorker::HandleDirectActorAssignTask(
const rpc::DirectActorAssignTaskRequest &request,
rpc::DirectActorAssignTaskReply *reply, rpc::SendReplyCallback send_reply_callback) {
task_execution_service_.post([=] {
direct_actor_task_receiver_->HandleDirectActorAssignTask(request, reply,
send_reply_callback);
});
}
void CoreWorker::HandleDirectActorCallArgWaitComplete(
const rpc::DirectActorCallArgWaitCompleteRequest &request,
rpc::DirectActorCallArgWaitCompleteReply *reply,
rpc::SendReplyCallback send_reply_callback) {
task_execution_service_.post([=] {
direct_actor_task_receiver_->HandleDirectActorCallArgWaitComplete(
request, reply, send_reply_callback);
});
}
} // namespace ray
+63 -9
View File
@@ -20,6 +20,16 @@
#include "ray/rpc/worker/worker_client.h"
#include "ray/rpc/worker/worker_server.h"
/// The set of gRPC handlers and their associated level of concurrency. If you want to
/// add a new call to the worker gRPC server, do the following:
/// 1) Add the rpc to the WorkerService in core_worker.proto, e.g., "ExampleCall"
/// 2) Add a new handler to the macro below: "RAY_CORE_WORKER_RPC_HANDLER(ExampleCall, 1)"
/// 3) Add a method to the CoreWorker class below: "CoreWorker::HandleExampleCall"
#define RAY_CORE_WORKER_RPC_HANDLERS \
RAY_CORE_WORKER_RPC_HANDLER(AssignTask, 5) \
RAY_CORE_WORKER_RPC_HANDLER(DirectActorAssignTask, 9999) \
RAY_CORE_WORKER_RPC_HANDLER(DirectActorCallArgWaitComplete, 100)
namespace ray {
/// The root class that contains all the core and language-independent functionalities
@@ -95,7 +105,9 @@ class CoreWorker {
// in the heartbeat messsage.
void RemoveActiveObjectID(const ObjectID &object_id) LOCKS_EXCLUDED(object_ref_mu_);
/* Public methods related to storing and retrieving objects. */
///
/// Public methods related to storing and retrieving objects.
///
/// Set options for this client's interactions with the object store.
///
@@ -198,7 +210,9 @@ class CoreWorker {
/// \return std::string The string describing memory usage.
std::string MemoryUsageString();
/* Public methods related to task submission. */
///
/// Public methods related to task submission.
///
/// Get the caller ID used to submit tasks from this worker to an actor.
///
@@ -268,7 +282,9 @@ class CoreWorker {
/// \return Status::Invalid if we don't have the specified handle.
Status SerializeActorHandle(const ActorID &actor_id, std::string *output) const;
/* Public methods related to task execution. Should not be used by driver processes. */
///
/// Public methods related to task execution. Should not be used by driver processes.
///
const ActorID &GetActorId() const { return actor_id_; }
@@ -295,6 +311,29 @@ class CoreWorker {
const std::vector<std::shared_ptr<Buffer>> &metadatas,
std::vector<std::shared_ptr<RayObject>> *return_objects);
/* Handlers for the worker's gRPC server. These are executed on the io_service_ and post
* work to the appropriate event loop.
*/
/// Handle an "AssignTask" event corresponding to scheduling a normal or an actor task
/// on this worker from the raylet.
void HandleAssignTask(const rpc::AssignTaskRequest &request,
rpc::AssignTaskReply *reply,
rpc::SendReplyCallback send_reply_callback);
/// Handle a "DirectActorAssignTask" event corresponding to scheduling an actor task
/// on this worker from another worker.
void HandleDirectActorAssignTask(const rpc::DirectActorAssignTaskRequest &request,
rpc::DirectActorAssignTaskReply *reply,
rpc::SendReplyCallback send_reply_callback);
/// Handle a "DirectActorAssignTask" event corresponding to the raylet notifiying this
/// worker that an argument is ready.
void HandleDirectActorCallArgWaitComplete(
const rpc::DirectActorCallArgWaitCompleteRequest &request,
rpc::DirectActorCallArgWaitCompleteReply *reply,
rpc::SendReplyCallback send_reply_callback);
private:
/// Run the io_service_ event loop. This should be called in a background thread.
void RunIOService();
@@ -306,7 +345,9 @@ class CoreWorker {
/// Send the list of active object IDs to the raylet.
void ReportActiveObjectIDs() LOCKS_EXCLUDED(object_ref_mu_);
/* Private methods related to task submission. */
///
/// Private methods related to task submission.
///
/// Give this worker a handle to an actor.
///
@@ -328,7 +369,9 @@ class CoreWorker {
/// \return Status::Invalid if we don't have this actor handle.
Status GetActorHandle(const ActorID &actor_id, ActorHandle **actor_handle) const;
/* Private methods related to task execution. Should not be used by driver processes. */
///
/// Private methods related to task execution. Should not be used by driver processes.
///
/// Execute a task.
///
@@ -408,7 +451,9 @@ class CoreWorker {
// Thread that runs a boost::asio service to process IO events.
std::thread io_thread_;
/* Fields related to ref counting objects. */
///
/// Fields related to ref counting objects.
///
/// Protects access to the set of active object ids. Since this set is updated
/// very frequently, it is faster to lock around accesses rather than serialize
@@ -422,7 +467,9 @@ class CoreWorker {
/// last time it was sent to the raylet.
bool active_object_ids_updated_ GUARDED_BY(object_ref_mu_) = false;
/* Fields related to storing and retrieving objects. */
///
/// Fields related to storing and retrieving objects.
///
/// In-memory store for return objects. This is used for `MEMORY` store provider.
std::shared_ptr<CoreWorkerMemoryStore> memory_store_;
@@ -433,7 +480,9 @@ class CoreWorker {
/// In-memory store interface.
std::unique_ptr<CoreWorkerMemoryStoreProvider> memory_store_provider_;
/* Fields related to task submission. */
///
/// Fields related to task submission.
///
// Interface to submit tasks directly to other actors.
std::unique_ptr<CoreWorkerDirectActorTaskSubmitter> direct_actor_submitter_;
@@ -441,7 +490,9 @@ class CoreWorker {
/// Map from actor ID to a handle to that actor.
absl::flat_hash_map<ActorID, std::unique_ptr<ActorHandle>> actor_handles_;
/* Fields related to task execution. */
///
/// Fields related to task execution.
///
/// Our actor ID. If this is nil, then we execute only stateless tasks.
ActorID actor_id_;
@@ -466,6 +517,9 @@ class CoreWorker {
// Interface that receives tasks from the raylet.
std::unique_ptr<CoreWorkerRayletTaskReceiver> raylet_task_receiver_;
/// Common rpc service for all worker modules.
rpc::WorkerGrpcService grpc_service_;
// Interface that receives tasks from direct actor calls.
std::unique_ptr<CoreWorkerDirectActorTaskReceiver> direct_actor_task_receiver_;
@@ -11,8 +11,6 @@
namespace ray {
class CoreWorker;
/// The class provides implementations for accessing local process memory store.
/// An example usage for this is to retrieve the returned objects from direct
/// actor call (see direct_actor_transport.cc).
@@ -12,8 +12,6 @@
namespace ray {
class CoreWorker;
/// The class provides implementations for accessing plasma store, which includes both
/// local and remote stores. Local access goes is done via a
/// CoreWorkerLocalPlasmaStoreProvider and remote access goes through the raylet.
+7 -7
View File
@@ -13,8 +13,7 @@
#include "ray/core_worker/store_provider/memory_store_provider.h"
#include "ray/raylet/raylet_client.h"
#include "src/ray/protobuf/direct_actor.grpc.pb.h"
#include "src/ray/protobuf/direct_actor.pb.h"
#include "src/ray/protobuf/core_worker.pb.h"
#include "src/ray/protobuf/gcs.pb.h"
#include "src/ray/util/test_util.h"
@@ -481,7 +480,7 @@ TEST_F(ZeroNodeTest, TestTaskArg) {
ASSERT_EQ(*data, *buffer);
}
// Performance batchmark for `PushTaskRequest` creation.
// Performance batchmark for `DirectActorAssignTaskRequest` creation.
TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
// Create a dummy actor handle, and then create a number of `TaskSpec`
// to benchmark performance.
@@ -500,11 +499,11 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
function.GetFunctionDescriptor());
// Manually create `num_tasks` task specs, and for each of them create a
// `PushTaskRequest`, this is to batch performance of TaskSpec
// `DirectActorAssignTaskRequest`, this is to batch performance of TaskSpec
// creation/copy/destruction.
int64_t start_ms = current_time_ms();
const auto num_tasks = 10000 * 10;
RAY_LOG(INFO) << "start creating " << num_tasks << " PushTaskRequests";
RAY_LOG(INFO) << "start creating " << num_tasks << " DirectActorAssignTaskRequests";
for (int i = 0; i < num_tasks; i++) {
TaskOptions options{1, resources};
std::vector<ObjectID> return_ids;
@@ -529,10 +528,11 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
const auto &task_spec = builder.Build();
ASSERT_TRUE(task_spec.IsActorTask());
auto request = std::unique_ptr<rpc::PushTaskRequest>(new rpc::PushTaskRequest);
auto request = std::unique_ptr<rpc::DirectActorAssignTaskRequest>(
new rpc::DirectActorAssignTaskRequest);
request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage());
}
RAY_LOG(INFO) << "Finish creating " << num_tasks << " PushTaskRequests"
RAY_LOG(INFO) << "Finish creating " << num_tasks << " DirectActorAssignTaskRequests"
<< ", which takes " << current_time_ms() - start_ms << " ms";
}
@@ -5,15 +5,6 @@ using ray::rpc::ActorTableData;
namespace ray {
bool HasByReferenceArgs(const TaskSpecification &spec) {
for (size_t i = 0; i < spec.NumArgs(); ++i) {
if (spec.ArgIdCount(i) > 0) {
return true;
}
}
return false;
}
CoreWorkerDirectActorTaskSubmitter::CoreWorkerDirectActorTaskSubmitter(
boost::asio::io_service &io_service,
std::unique_ptr<CoreWorkerMemoryStoreProvider> store_provider)
@@ -31,7 +22,8 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(
const auto task_id = task_spec.TaskId();
const auto num_returns = task_spec.NumReturns();
auto request = std::unique_ptr<rpc::PushTaskRequest>(new rpc::PushTaskRequest);
auto request = std::unique_ptr<rpc::DirectActorAssignTaskRequest>(
new rpc::DirectActorAssignTaskRequest);
request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage());
std::unique_lock<std::mutex> guard(mutex_);
@@ -57,7 +49,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(
// Submit request.
auto &client = rpc_clients_[actor_id];
PushTask(*client, std::move(request), actor_id, task_id, num_returns);
DirectActorAssignTask(*client, std::move(request), actor_id, task_id, num_returns);
} else {
// Actor is dead, treat the task as failure.
RAY_CHECK(iter->second.state_ == ActorTableData::DEAD);
@@ -114,8 +106,8 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate(
void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks(
const ActorID &actor_id, std::string ip_address, int port) {
std::shared_ptr<rpc::DirectActorClient> grpc_client =
rpc::DirectActorClient::make(ip_address, port, client_call_manager_);
std::shared_ptr<rpc::WorkerTaskClient> grpc_client =
std::make_shared<rpc::WorkerTaskClient>(ip_address, port, client_call_manager_);
RAY_CHECK(rpc_clients_.emplace(actor_id, std::move(grpc_client)).second);
// Submit all pending requests.
@@ -125,20 +117,22 @@ void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks(
auto request = std::move(requests.front());
auto num_returns = request->task_spec().num_returns();
auto task_id = TaskID::FromBinary(request->task_spec().task_id());
PushTask(*client, std::move(request), actor_id, task_id, num_returns);
DirectActorAssignTask(*client, std::move(request), actor_id, task_id, num_returns);
requests.pop_front();
}
}
void CoreWorkerDirectActorTaskSubmitter::PushTask(
rpc::DirectActorClient &client, std::unique_ptr<rpc::PushTaskRequest> request,
const ActorID &actor_id, const TaskID &task_id, int num_returns) {
void CoreWorkerDirectActorTaskSubmitter::DirectActorAssignTask(
rpc::WorkerTaskClient &client,
std::unique_ptr<rpc::DirectActorAssignTaskRequest> request, const ActorID &actor_id,
const TaskID &task_id, int num_returns) {
RAY_LOG(DEBUG) << "Pushing task " << task_id << " to actor " << actor_id;
waiting_reply_tasks_[actor_id].insert(std::make_pair(task_id, num_returns));
auto status = client.PushTask(
std::move(request), [this, actor_id, task_id, num_returns](
Status status, const rpc::PushTaskReply &reply) {
auto status = client.DirectActorAssignTask(
std::move(request),
[this, actor_id, task_id, num_returns](
Status status, const rpc::DirectActorAssignTaskReply &reply) {
{
std::unique_lock<std::mutex> guard(mutex_);
waiting_reply_tasks_[actor_id].erase(task_id);
@@ -203,12 +197,9 @@ CoreWorkerDirectActorTaskReceiver::CoreWorkerDirectActorTaskReceiver(
rpc::GrpcServer &server, const TaskHandler &task_handler,
const std::function<void()> &exit_handler)
: worker_context_(worker_context),
task_service_(main_io_service, *this),
task_handler_(task_handler),
exit_handler_(exit_handler),
task_main_io_service_(main_io_service) {
server.RegisterService(task_service_);
}
task_main_io_service_(main_io_service) {}
void CoreWorkerDirectActorTaskReceiver::Init(RayletClient &raylet_client) {
waiter_.reset(new DependencyWaiterImpl(raylet_client));
@@ -223,9 +214,9 @@ void CoreWorkerDirectActorTaskReceiver::SetMaxActorConcurrency(int max_concurren
}
}
void CoreWorkerDirectActorTaskReceiver::HandlePushTask(
const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
void CoreWorkerDirectActorTaskReceiver::HandleDirectActorAssignTask(
const rpc::DirectActorAssignTaskRequest &request,
rpc::DirectActorAssignTaskReply *reply, rpc::SendReplyCallback send_reply_callback) {
RAY_CHECK(waiter_ != nullptr) << "Must call init() prior to use";
const TaskSpecification task_spec(request.task_spec());
RAY_LOG(DEBUG) << "Received task " << task_spec.TaskId();
@@ -269,7 +260,7 @@ void CoreWorkerDirectActorTaskReceiver::HandlePushTask(
std::vector<std::shared_ptr<RayObject>> return_by_value;
auto status = task_handler_(task_spec, resource_ids, &return_by_value);
if (status.IsSystemExit()) {
// In Python, SystemExit cannot be raised except on the main thread. To work
// In Python, SystemExit can only be raised on the main thread. To work
// around this when we are executing tasks on worker threads, we re-post the
// exit event explicitly on the main thread.
task_main_io_service_.post([this]() { exit_handler_(); });
@@ -14,8 +14,8 @@
#include "ray/core_worker/context.h"
#include "ray/core_worker/store_provider/memory_store_provider.h"
#include "ray/gcs/redis_gcs_client.h"
#include "ray/rpc/worker/direct_actor_client.h"
#include "ray/rpc/worker/direct_actor_server.h"
#include "ray/rpc/grpc_server.h"
#include "ray/rpc/worker/worker_client.h"
namespace ray {
@@ -67,9 +67,10 @@ class CoreWorkerDirectActorTaskSubmitter {
/// \param[in] task_id The ID of a task.
/// \param[in] num_returns Number of return objects.
/// \return Void.
void PushTask(rpc::DirectActorClient &client,
std::unique_ptr<rpc::PushTaskRequest> request, const ActorID &actor_id,
const TaskID &task_id, int num_returns);
void DirectActorAssignTask(rpc::WorkerTaskClient &client,
std::unique_ptr<rpc::DirectActorAssignTaskRequest> request,
const ActorID &actor_id, const TaskID &task_id,
int num_returns);
/// Treat a task as failed.
///
@@ -114,10 +115,11 @@ class CoreWorkerDirectActorTaskSubmitter {
///
/// TODO(zhijunfu): this will be moved into `actor_states_` later when we can
/// subscribe updates for a specific actor.
std::unordered_map<ActorID, std::shared_ptr<rpc::DirectActorClient>> rpc_clients_;
std::unordered_map<ActorID, std::shared_ptr<rpc::WorkerTaskClient>> rpc_clients_;
/// Map from actor id to the actor's pending requests.
std::unordered_map<ActorID, std::list<std::unique_ptr<rpc::PushTaskRequest>>>
std::unordered_map<ActorID,
std::list<std::unique_ptr<rpc::DirectActorAssignTaskRequest>>>
pending_requests_;
/// Map from actor id to the tasks that are waiting for reply.
@@ -327,7 +329,7 @@ class SchedulingQueue {
friend class SchedulingQueueTest;
};
class CoreWorkerDirectActorTaskReceiver : public rpc::DirectActorHandler {
class CoreWorkerDirectActorTaskReceiver {
public:
using TaskHandler = std::function<Status(
const TaskSpecification &task_spec, const ResourceMappingType &resource_ids,
@@ -342,13 +344,14 @@ class CoreWorkerDirectActorTaskReceiver : public rpc::DirectActorHandler {
/// Initialize this receiver. This must be called prior to use.
void Init(RayletClient &client);
/// Handle a `PushTask` request.
/// Handle a `DirectActorAssignTask` request.
///
/// \param[in] request The request message.
/// \param[out] reply The reply message.
/// \param[in] send_reply_callback The callback to be called when the request is done.
void HandlePushTask(const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
void HandleDirectActorAssignTask(const rpc::DirectActorAssignTaskRequest &request,
rpc::DirectActorAssignTaskReply *reply,
rpc::SendReplyCallback send_reply_callback);
/// Handle a `DirectActorCallArgWaitComplete` request.
///
@@ -358,7 +361,7 @@ class CoreWorkerDirectActorTaskReceiver : public rpc::DirectActorHandler {
void HandleDirectActorCallArgWaitComplete(
const rpc::DirectActorCallArgWaitCompleteRequest &request,
rpc::DirectActorCallArgWaitCompleteReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
rpc::SendReplyCallback send_reply_callback);
/// Set the max concurrency at runtime. It cannot be changed once set.
void SetMaxActorConcurrency(int max_concurrency);
@@ -366,8 +369,6 @@ class CoreWorkerDirectActorTaskReceiver : public rpc::DirectActorHandler {
private:
// Worker context.
WorkerContext &worker_context_;
/// The rpc service for `DirectActorService`.
rpc::DirectActorGrpcService task_service_;
/// The callback function to process a task.
TaskHandler task_handler_;
/// The callback function to exit the worker.
@@ -6,16 +6,11 @@
namespace ray {
CoreWorkerRayletTaskReceiver::CoreWorkerRayletTaskReceiver(
WorkerContext &worker_context, std::unique_ptr<RayletClient> &raylet_client,
boost::asio::io_service &io_service, rpc::GrpcServer &server,
const TaskHandler &task_handler, const std::function<void()> &exit_handler)
: worker_context_(worker_context),
raylet_client_(raylet_client),
task_service_(io_service, *this),
std::unique_ptr<RayletClient> &raylet_client, const TaskHandler &task_handler,
const std::function<void()> &exit_handler)
: raylet_client_(raylet_client),
task_handler_(task_handler),
exit_handler_(exit_handler) {
server.RegisterService(task_service_);
}
exit_handler_(exit_handler) {}
void CoreWorkerRayletTaskReceiver::HandleAssignTask(
const rpc::AssignTaskRequest &request, rpc::AssignTaskReply *reply,
@@ -23,11 +18,6 @@ void CoreWorkerRayletTaskReceiver::HandleAssignTask(
const Task task(request.task());
const auto &task_spec = task.GetTaskSpecification();
RAY_LOG(DEBUG) << "Received task " << task_spec.TaskId();
if (task_spec.IsActorTask() && worker_context_.CurrentActorUseDirectCall()) {
send_reply_callback(Status::Invalid("This actor only accepts direct calls."), nullptr,
nullptr);
return;
}
// Set the resource IDs for this task.
// TODO: convert the resource map to protobuf and change this.
@@ -4,22 +4,19 @@
#include <list>
#include "ray/common/ray_object.h"
#include "ray/core_worker/context.h"
#include "ray/raylet/raylet_client.h"
#include "ray/rpc/worker/worker_server.h"
namespace ray {
class CoreWorkerRayletTaskReceiver : public rpc::WorkerTaskHandler {
class CoreWorkerRayletTaskReceiver {
public:
using TaskHandler = std::function<Status(
const TaskSpecification &task_spec, const ResourceMappingType &resource_ids,
std::vector<std::shared_ptr<RayObject>> *return_by_value)>;
CoreWorkerRayletTaskReceiver(WorkerContext &worker_context,
std::unique_ptr<RayletClient> &raylet_client,
boost::asio::io_service &io_service,
rpc::GrpcServer &server, const TaskHandler &task_handler,
CoreWorkerRayletTaskReceiver(std::unique_ptr<RayletClient> &raylet_client,
const TaskHandler &task_handler,
const std::function<void()> &exit_handler);
/// Handle a `AssignTask` request.
@@ -31,15 +28,11 @@ class CoreWorkerRayletTaskReceiver : public rpc::WorkerTaskHandler {
/// \param[in] send_reply_callback The callback to be called when the request is done.
void HandleAssignTask(const rpc::AssignTaskRequest &request,
rpc::AssignTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
rpc::SendReplyCallback send_reply_callback);
private:
// Worker context.
WorkerContext &worker_context_;
/// Raylet client.
std::unique_ptr<RayletClient> &raylet_client_;
/// The rpc service for `WorkerTaskService`.
rpc::WorkerTaskGrpcService task_service_;
/// The callback function to process a task.
TaskHandler task_handler_;
/// The callback function to exit the worker.
+65
View File
@@ -4,6 +4,10 @@ package ray.rpc;
import "src/ray/protobuf/common.proto";
message ActiveObjectIDs {
repeated bytes object_ids = 1;
}
// Persistent state of an ActorHandle.
message ActorHandle {
// ID of the actor.
@@ -27,3 +31,64 @@ message ActorHandle {
// Whether direct actor call is used.
bool is_direct_call = 7;
}
message AssignTaskRequest {
// The task to be pushed.
Task task = 1;
// A list of the resources reserved for this worker.
// TODO(zhijunfu): `resource_ids` is represented as
// flatbutters-serialized bytes, will be moved to protobuf later.
bytes resource_ids = 2;
}
message AssignTaskReply {
}
message ReturnObject {
// Object ID.
bytes object_id = 1;
// Data of the object.
bytes data = 2;
// Metadata of the object.
bytes metadata = 3;
}
message DirectActorAssignTaskRequest {
// The task to be pushed.
TaskSpec task_spec = 1;
// The sequence number of the task for this client. This must increase
// sequentially starting from zero for each actor handle. The server
// will guarantee tasks execute in this sequence, waiting for any
// out-of-order request messages to arrive as necessary.
int64 sequence_number = 2;
// The max sequence number the client has processed responses for. This
// is a performance optimization that allows the client to tell the server
// to cancel any DirectActorAssignTaskRequests with seqno <= this value, rather than
// waiting for the server to time out waiting for missing messages.
int64 client_processed_up_to = 3;
}
message DirectActorAssignTaskReply {
// The returned objects.
repeated ReturnObject return_objects = 1;
}
message DirectActorCallArgWaitCompleteRequest {
// Id used to uniquely identify this request. This is sent back to the core
// worker to notify the wait has completed.
int64 tag = 1;
}
message DirectActorCallArgWaitCompleteReply {
}
service WorkerService {
// Push a task to a worker.
rpc AssignTask(AssignTaskRequest) returns (AssignTaskReply);
// Push a task to a direct-call actor.
rpc DirectActorAssignTask(DirectActorAssignTaskRequest)
returns (DirectActorAssignTaskReply);
// Notify wait for direct actor call args has completed.
rpc DirectActorCallArgWaitComplete(DirectActorCallArgWaitCompleteRequest)
returns (DirectActorCallArgWaitCompleteReply);
}
-52
View File
@@ -1,52 +0,0 @@
syntax = "proto3";
package ray.rpc;
import "src/ray/protobuf/common.proto";
message ReturnObject {
// Object ID.
bytes object_id = 1;
// Data of the object.
bytes data = 2;
// Metaata of the object.
bytes metadata = 3;
}
message PushTaskRequest {
// The task to be pushed.
TaskSpec task_spec = 1;
// The sequence number of the task for this client. This must increase
// sequentially starting from zero for each actor handle. The server
// will guarantee tasks execute in this sequence, waiting for any
// out-of-order request messages to arrive as necessary.
int64 sequence_number = 2;
// The max sequence number the client has processed responses for. This
// is a performance optimization that allows the client to tell the server
// to cancel any PushTaskRequests with seqno <= this value, rather than
// waiting for the server to time out waiting for missing messages.
int64 client_processed_up_to = 3;
}
message PushTaskReply {
// The returned objects.
repeated ReturnObject return_objects = 1;
}
message DirectActorCallArgWaitCompleteRequest {
// Id used to uniquely identify this request. This is sent back to the core
// worker to notify the wait has completed.
int64 tag = 1;
}
message DirectActorCallArgWaitCompleteReply {
}
// Service for direct actor.
service DirectActorService {
// Push a task to a worker.
rpc PushTask(PushTaskRequest) returns (PushTaskReply);
// Notify wait for direct actor call args has completed
rpc DirectActorCallArgWaitComplete(DirectActorCallArgWaitCompleteRequest)
returns (DirectActorCallArgWaitCompleteReply);
}
-27
View File
@@ -1,27 +0,0 @@
syntax = "proto3";
package ray.rpc;
import "src/ray/protobuf/common.proto";
message ActiveObjectIDs {
repeated bytes object_ids = 1;
}
message AssignTaskRequest {
// The task to be pushed.
Task task = 1;
// A list of the resources reserved for this worker.
// TODO(zhijunfu): `resource_ids` is represented as
// flatbutters-serialized bytes, will be moved to protobuf later.
bytes resource_ids = 2;
}
message AssignTaskReply {
}
// Service for worker.
service WorkerTaskService {
// Push a task to a worker.
rpc AssignTask(AssignTaskRequest) returns (AssignTaskReply);
}
+3 -5
View File
@@ -4,8 +4,8 @@
#include "ray/raylet/format/node_manager_generated.h"
#include "ray/raylet/raylet.h"
#include "src/ray/protobuf/direct_actor.grpc.pb.h"
#include "src/ray/protobuf/direct_actor.pb.h"
#include "src/ray/protobuf/core_worker.grpc.pb.h"
#include "src/ray/protobuf/core_worker.pb.h"
namespace ray {
@@ -27,8 +27,6 @@ Worker::Worker(const WorkerID &worker_id, pid_t pid, const Language &language, i
if (port_ > 0) {
rpc_client_ = std::unique_ptr<rpc::WorkerTaskClient>(
new rpc::WorkerTaskClient("127.0.0.1", port_, client_call_manager_));
direct_rpc_client_ = std::unique_ptr<rpc::DirectActorClient>(
new rpc::DirectActorClient("127.0.0.1", port_, client_call_manager_));
}
}
@@ -163,7 +161,7 @@ void Worker::DirectActorCallArgWaitComplete(int64_t tag) {
RAY_CHECK(port_ > 0);
rpc::DirectActorCallArgWaitCompleteRequest request;
request.set_tag(tag);
auto status = direct_rpc_client_->DirectActorCallArgWaitComplete(
auto status = rpc_client_->DirectActorCallArgWaitComplete(
request, [](Status status, const rpc::DirectActorCallArgWaitCompleteReply &reply) {
if (!status.ok()) {
RAY_LOG(ERROR) << "Failed to send wait complete: " << status.ToString();
-3
View File
@@ -8,7 +8,6 @@
#include "ray/common/task/scheduling_resources.h"
#include "ray/common/task/task.h"
#include "ray/common/task/task_common.h"
#include "ray/rpc/worker/direct_actor_client.h"
#include "ray/rpc/worker/worker_client.h"
#include <unistd.h> // pid_t
@@ -109,8 +108,6 @@ class Worker {
/// Whether the worker is detached. This is applies when the worker is actor.
/// Detached actor means the actor's creator can exit without killing this actor.
bool is_detached_actor_;
/// The rpc client to send tasks to the direct actor service.
std::unique_ptr<rpc::DirectActorClient> direct_rpc_client_;
};
} // namespace raylet
-157
View File
@@ -1,157 +0,0 @@
#ifndef RAY_RPC_DIRECT_ACTOR_CLIENT_H
#define RAY_RPC_DIRECT_ACTOR_CLIENT_H
#include <deque>
#include <memory>
#include <mutex>
#include <thread>
#include <grpcpp/grpcpp.h>
#include "absl/base/thread_annotations.h"
#include "ray/common/status.h"
#include "ray/rpc/client_call.h"
#include "ray/rpc/worker/direct_actor_common.h"
#include "ray/util/logging.h"
#include "src/ray/protobuf/direct_actor.grpc.pb.h"
#include "src/ray/protobuf/direct_actor.pb.h"
namespace ray {
namespace rpc {
/// Client used for communicating with a direct actor server.
class DirectActorClient : public std::enable_shared_from_this<DirectActorClient> {
public:
/// Constructor.
///
/// \param[in] address Address of the direct actor server.
/// \param[in] port Port of the direct actor server.
/// \param[in] client_call_manager The `ClientCallManager` used for managing requests.
static std::shared_ptr<DirectActorClient> make(const std::string &address,
const int port,
ClientCallManager &client_call_manager) {
auto instance = new DirectActorClient(address, port, client_call_manager);
return std::shared_ptr<DirectActorClient>(instance);
}
/// Constructor.
///
/// \param[in] address Address of the direct actor server.
/// \param[in] port Port of the direct actor server.
/// \param[in] client_call_manager The `ClientCallManager` used for managing requests.
DirectActorClient(const std::string &address, const int port,
ClientCallManager &client_call_manager)
: client_call_manager_(client_call_manager) {
std::shared_ptr<grpc::Channel> channel = grpc::CreateChannel(
address + ":" + std::to_string(port), grpc::InsecureChannelCredentials());
stub_ = DirectActorService::NewStub(channel);
};
/// Push a task.
///
/// \param[in] request The request message.
/// \param[in] callback The callback function that handles reply.
/// \return if the rpc call succeeds
ray::Status PushTask(std::unique_ptr<PushTaskRequest> request,
const ClientCallback<PushTaskReply> &callback) {
request->set_sequence_number(request->task_spec().actor_task_spec().actor_counter());
{
std::lock_guard<std::mutex> lock(mutex_);
if (request->task_spec().caller_id() != cur_caller_id_) {
// We are running a new task, reset the seq no counter.
max_finished_seq_no_ = -1;
cur_caller_id_ = request->task_spec().caller_id();
}
send_queue_.push_back(std::make_pair(std::move(request), callback));
}
SendRequests();
return ray::Status::OK();
}
/// Notify a wait has completed for direct actor call arguments.
///
/// \param[in] request The request message.
/// \param[in] callback The callback function that handles reply.
/// \return if the rpc call succeeds
ray::Status DirectActorCallArgWaitComplete(
const DirectActorCallArgWaitCompleteRequest &request,
const ClientCallback<DirectActorCallArgWaitCompleteReply> &callback) {
auto call = client_call_manager_.CreateCall<DirectActorService,
DirectActorCallArgWaitCompleteRequest,
DirectActorCallArgWaitCompleteReply>(
*stub_, &DirectActorService::Stub::PrepareAsyncDirectActorCallArgWaitComplete,
request, callback);
return call->GetStatus();
}
/// Send as many pending tasks as possible. This method is thread-safe.
///
/// The client will guarantee no more than kMaxBytesInFlight bytes of RPCs are being
/// sent at once. This prevents the server scheduling queue from being overwhelmed.
/// See direct_actor.proto for a description of the ordering protocol.
void SendRequests() {
std::lock_guard<std::mutex> lock(mutex_);
auto this_ptr = this->shared_from_this();
while (!send_queue_.empty() && rpc_bytes_in_flight_ < kMaxBytesInFlight) {
auto pair = std::move(*send_queue_.begin());
send_queue_.pop_front();
auto request = std::move(pair.first);
auto callback = pair.second;
int64_t task_size = RequestSizeInBytes(*request);
int64_t seq_no = request->sequence_number();
request->set_client_processed_up_to(max_finished_seq_no_);
rpc_bytes_in_flight_ += task_size;
client_call_manager_.CreateCall<DirectActorService, PushTaskRequest, PushTaskReply>(
*stub_, &DirectActorService::Stub::PrepareAsyncPushTask, *request,
[this, this_ptr, seq_no, task_size, callback](Status status,
const rpc::PushTaskReply &reply) {
{
std::lock_guard<std::mutex> lock(mutex_);
if (seq_no > max_finished_seq_no_) {
max_finished_seq_no_ = seq_no;
}
rpc_bytes_in_flight_ -= task_size;
RAY_CHECK(rpc_bytes_in_flight_ >= 0);
}
SendRequests();
callback(status, reply);
});
}
if (!send_queue_.empty()) {
RAY_LOG(DEBUG) << "client send queue size " << send_queue_.size();
}
}
private:
/// Protects against unsafe concurrent access from the callback thread.
std::mutex mutex_;
/// The gRPC-generated stub.
std::unique_ptr<DirectActorService::Stub> stub_;
/// The `ClientCallManager` used for managing requests.
ClientCallManager &client_call_manager_;
/// Queue of requests to send.
std::deque<std::pair<std::unique_ptr<PushTaskRequest>, ClientCallback<PushTaskReply>>>
send_queue_ GUARDED_BY(mutex_);
/// The number of bytes currently in flight.
int64_t rpc_bytes_in_flight_ GUARDED_BY(mutex_) = 0;
/// The max sequence number we have processed responses for.
int64_t max_finished_seq_no_ GUARDED_BY(mutex_) = -1;
/// The task id we are currently sending requests for. When this changes,
/// the max finished seq no counter is reset.
std::string cur_caller_id_;
};
} // namespace rpc
} // namespace ray
#endif // RAY_RPC_DIRECT_ACTOR_CLIENT_H
-28
View File
@@ -1,28 +0,0 @@
#ifndef RAY_RPC_DIRECT_ACTOR_COMMON_H
#define RAY_RPC_DIRECT_ACTOR_COMMON_H
#include "src/ray/protobuf/direct_actor.grpc.pb.h"
#include "src/ray/protobuf/direct_actor.pb.h"
namespace ray {
namespace rpc {
/// The maximum number of requests in flight per client.
const int64_t kMaxBytesInFlight = 16 * 1024 * 1024;
/// The base size in bytes per request.
const int64_t kBaseRequestSize = 1024;
/// Get the estimated size in bytes of the given task.
const static int64_t RequestSizeInBytes(const PushTaskRequest &request) {
int64_t size = kBaseRequestSize;
for (auto &arg : request.task_spec().args()) {
size += arg.data().size();
}
return size;
}
} // namespace rpc
} // namespace ray
#endif // RAY_RPC_DIRECT_ACTOR_COMMON_H
-88
View File
@@ -1,88 +0,0 @@
#ifndef RAY_RPC_DIRECT_ACTOR_SERVER_H
#define RAY_RPC_DIRECT_ACTOR_SERVER_H
#include "ray/rpc/grpc_server.h"
#include "ray/rpc/server_call.h"
#include "src/ray/protobuf/direct_actor.grpc.pb.h"
#include "src/ray/protobuf/direct_actor.pb.h"
namespace ray {
namespace rpc {
/// Interface of the `DirectActorService`, see `src/ray/protobuf/direct_actor.proto`.
class DirectActorHandler {
public:
/// Handle a `PushTask` request.
/// The implementation can handle this request asynchronously. When hanling is done, the
/// `done_callback` should be called.
///
/// \param[in] request The request message.
/// \param[out] reply The reply message.
/// \param[in] done_callback The callback to be called when the request is done.
virtual void HandlePushTask(const PushTaskRequest &request, PushTaskReply *reply,
SendReplyCallback send_reply_callback) = 0;
/// Handle a wait reply for direct actor call arg dependencies.
///
/// \param[in] request The request message.
/// \param[out] reply The reply message.
/// \param[in] send_replay_callback The callback to be called when the request is done.
virtual void HandleDirectActorCallArgWaitComplete(
const rpc::DirectActorCallArgWaitCompleteRequest &request,
rpc::DirectActorCallArgWaitCompleteReply *reply,
rpc::SendReplyCallback send_reply_callback) = 0;
};
/// The `GrpcServer` for `WorkerService`.
class DirectActorGrpcService : public GrpcService {
public:
/// Constructor.
///
/// \param[in] main_service See super class.
/// \param[in] handler The service handler that actually handle the requests.
DirectActorGrpcService(boost::asio::io_service &main_service,
DirectActorHandler &service_handler)
: GrpcService(main_service), service_handler_(service_handler){};
protected:
grpc::Service &GetGrpcService() override { return service_; }
void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
*server_call_factories_and_concurrencies) override {
// Initialize the Factory for `PushTask` requests.
std::unique_ptr<ServerCallFactory> push_task_call_Factory(
new ServerCallFactoryImpl<DirectActorService, DirectActorHandler, PushTaskRequest,
PushTaskReply>(
service_, &DirectActorService::AsyncService::RequestPushTask,
service_handler_, &DirectActorHandler::HandlePushTask, cq, main_service_));
server_call_factories_and_concurrencies->emplace_back(
std::move(push_task_call_Factory), 100);
// Initialize the Factory for `DirectActorCallArgWaitComplete` requests.
std::unique_ptr<ServerCallFactory> wait_complete_call_Factory(
new ServerCallFactoryImpl<DirectActorService, DirectActorHandler,
DirectActorCallArgWaitCompleteRequest,
DirectActorCallArgWaitCompleteReply>(
service_,
&DirectActorService::AsyncService::RequestDirectActorCallArgWaitComplete,
service_handler_, &DirectActorHandler::HandleDirectActorCallArgWaitComplete,
cq, main_service_));
server_call_factories_and_concurrencies->emplace_back(
std::move(wait_complete_call_Factory), 100);
}
private:
/// The grpc async service object.
DirectActorService::AsyncService service_;
/// The service handler that actually handle the requests.
DirectActorHandler &service_handler_;
};
} // namespace rpc
} // namespace ray
#endif
+128 -9
View File
@@ -1,21 +1,40 @@
#ifndef RAY_RPC_WORKER_CLIENT_H
#define RAY_RPC_WORKER_CLIENT_H
#include <deque>
#include <memory>
#include <mutex>
#include <thread>
#include <grpcpp/grpcpp.h>
#include "absl/base/thread_annotations.h"
#include "ray/common/status.h"
#include "ray/rpc/client_call.h"
#include "ray/util/logging.h"
#include "src/ray/protobuf/worker.grpc.pb.h"
#include "src/ray/protobuf/worker.pb.h"
#include "src/ray/protobuf/core_worker.grpc.pb.h"
#include "src/ray/protobuf/core_worker.pb.h"
namespace ray {
namespace rpc {
/// The maximum number of requests in flight per client.
const int64_t kMaxBytesInFlight = 16 * 1024 * 1024;
/// The base size in bytes per request.
const int64_t kBaseRequestSize = 1024;
/// Get the estimated size in bytes of the given task.
const static int64_t RequestSizeInBytes(const DirectActorAssignTaskRequest &request) {
int64_t size = kBaseRequestSize;
for (auto &arg : request.task_spec().args()) {
size += arg.data().size();
}
return size;
}
/// Client used for communicating with a remote worker server.
class WorkerTaskClient {
class WorkerTaskClient : public std::enable_shared_from_this<WorkerTaskClient> {
public:
/// Constructor.
///
@@ -27,7 +46,7 @@ class WorkerTaskClient {
: client_call_manager_(client_call_manager) {
std::shared_ptr<grpc::Channel> channel = grpc::CreateChannel(
address + ":" + std::to_string(port), grpc::InsecureChannelCredentials());
stub_ = WorkerTaskService::NewStub(channel);
stub_ = WorkerService::NewStub(channel);
};
/// Assign a task to the work.
@@ -37,19 +56,119 @@ class WorkerTaskClient {
/// \return if the rpc call succeeds
ray::Status AssignTask(const AssignTaskRequest &request,
const ClientCallback<AssignTaskReply> &callback) {
auto call = client_call_manager_
.CreateCall<WorkerTaskService, AssignTaskRequest, AssignTaskReply>(
*stub_, &WorkerTaskService::Stub::PrepareAsyncAssignTask, request,
callback);
auto call =
client_call_manager_
.CreateCall<WorkerService, AssignTaskRequest, AssignTaskReply>(
*stub_, &WorkerService::Stub::PrepareAsyncAssignTask, request, callback);
return call->GetStatus();
}
/// Push a task.
///
/// \param[in] request The request message.
/// \param[in] callback The callback function that handles reply.
/// \return if the rpc call succeeds
ray::Status DirectActorAssignTask(
std::unique_ptr<DirectActorAssignTaskRequest> request,
const ClientCallback<DirectActorAssignTaskReply> &callback) {
request->set_sequence_number(request->task_spec().actor_task_spec().actor_counter());
{
std::lock_guard<std::mutex> lock(mutex_);
if (request->task_spec().caller_id() != cur_caller_id_) {
// We are running a new task, reset the seq no counter.
max_finished_seq_no_ = -1;
cur_caller_id_ = request->task_spec().caller_id();
}
send_queue_.push_back(std::make_pair(std::move(request), callback));
}
SendRequests();
return ray::Status::OK();
}
/// Notify a wait has completed for direct actor call arguments.
///
/// \param[in] request The request message.
/// \param[in] callback The callback function that handles reply.
/// \return if the rpc call succeeds
ray::Status DirectActorCallArgWaitComplete(
const DirectActorCallArgWaitCompleteRequest &request,
const ClientCallback<DirectActorCallArgWaitCompleteReply> &callback) {
auto call =
client_call_manager_
.CreateCall<WorkerService, DirectActorCallArgWaitCompleteRequest,
DirectActorCallArgWaitCompleteReply>(
*stub_, &WorkerService::Stub::PrepareAsyncDirectActorCallArgWaitComplete,
request, callback);
return call->GetStatus();
}
/// Send as many pending tasks as possible. This method is thread-safe.
///
/// The client will guarantee no more than kMaxBytesInFlight bytes of RPCs are being
/// sent at once. This prevents the server scheduling queue from being overwhelmed.
/// See direct_actor.proto for a description of the ordering protocol.
void SendRequests() {
std::lock_guard<std::mutex> lock(mutex_);
auto this_ptr = this->shared_from_this();
while (!send_queue_.empty() && rpc_bytes_in_flight_ < kMaxBytesInFlight) {
auto pair = std::move(*send_queue_.begin());
send_queue_.pop_front();
auto request = std::move(pair.first);
auto callback = pair.second;
int64_t task_size = RequestSizeInBytes(*request);
int64_t seq_no = request->sequence_number();
request->set_client_processed_up_to(max_finished_seq_no_);
rpc_bytes_in_flight_ += task_size;
client_call_manager_.CreateCall<WorkerService, DirectActorAssignTaskRequest,
DirectActorAssignTaskReply>(
*stub_, &WorkerService::Stub::PrepareAsyncDirectActorAssignTask, *request,
[this, this_ptr, seq_no, task_size, callback](
Status status, const rpc::DirectActorAssignTaskReply &reply) {
{
std::lock_guard<std::mutex> lock(mutex_);
if (seq_no > max_finished_seq_no_) {
max_finished_seq_no_ = seq_no;
}
rpc_bytes_in_flight_ -= task_size;
RAY_CHECK(rpc_bytes_in_flight_ >= 0);
}
SendRequests();
callback(status, reply);
});
}
if (!send_queue_.empty()) {
RAY_LOG(DEBUG) << "client send queue size " << send_queue_.size();
}
}
private:
/// Protects against unsafe concurrent access from the callback thread.
std::mutex mutex_;
/// The gRPC-generated stub.
std::unique_ptr<WorkerTaskService::Stub> stub_;
std::unique_ptr<WorkerService::Stub> stub_;
/// The `ClientCallManager` used for managing requests.
ClientCallManager &client_call_manager_;
/// Queue of requests to send.
std::deque<std::pair<std::unique_ptr<DirectActorAssignTaskRequest>,
ClientCallback<DirectActorAssignTaskReply>>>
send_queue_ GUARDED_BY(mutex_);
/// The number of bytes currently in flight.
int64_t rpc_bytes_in_flight_ GUARDED_BY(mutex_) = 0;
/// The max sequence number we have processed responses for.
int64_t max_finished_seq_no_ GUARDED_BY(mutex_) = -1;
/// The task id we are currently sending requests for. When this changes,
/// the max finished seq no counter is reset.
std::string cur_caller_id_;
};
} // namespace rpc
+28
View File
@@ -0,0 +1,28 @@
#include "ray/rpc/worker/worker_server.h"
#include "ray/core_worker/core_worker.h"
namespace ray {
namespace rpc {
#define RAY_CORE_WORKER_RPC_HANDLER(HANDLER, CONCURRENCY) \
std::unique_ptr<ServerCallFactory> HANDLER##_call_factory( \
new ServerCallFactoryImpl<WorkerService, CoreWorker, HANDLER##Request, \
HANDLER##Reply>( \
service_, &WorkerService::AsyncService::Request##HANDLER, core_worker_, \
&CoreWorker::Handle##HANDLER, cq, main_service_)); \
server_call_factories_and_concurrencies->emplace_back( \
std::move(HANDLER##_call_factory), CONCURRENCY);
WorkerGrpcService::WorkerGrpcService(boost::asio::io_service &main_service,
CoreWorker &core_worker)
: GrpcService(main_service), core_worker_(core_worker){};
void WorkerGrpcService::InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
*server_call_factories_and_concurrencies) {
RAY_CORE_WORKER_RPC_HANDLERS
}
} // namespace rpc
} // namespace ray
+11 -35
View File
@@ -4,36 +4,23 @@
#include "ray/rpc/grpc_server.h"
#include "ray/rpc/server_call.h"
#include "src/ray/protobuf/worker.grpc.pb.h"
#include "src/ray/protobuf/worker.pb.h"
#include "src/ray/protobuf/core_worker.grpc.pb.h"
#include "src/ray/protobuf/core_worker.pb.h"
namespace ray {
class CoreWorker;
namespace rpc {
/// Interface of the `WorkerService`, see `src/ray/protobuf/worker.proto`.
class WorkerTaskHandler {
public:
/// Handle a `AssignTask` request.
/// The implementation can handle this request asynchronously. When handling is done,
/// the `send_reply_callback` should be called.
///
/// \param[in] request The request message.
/// \param[out] reply The reply message.
/// \param[in] send_reply_callback The callback to be called when the request is done.
virtual void HandleAssignTask(const AssignTaskRequest &request, AssignTaskReply *reply,
SendReplyCallback send_reply_callback) = 0;
};
/// The `GrpcServer` for `WorkerService`.
class WorkerTaskGrpcService : public GrpcService {
class WorkerGrpcService : public GrpcService {
public:
/// Constructor.
///
/// \param[in] main_service See super class.
/// \param[in] handler The service handler that actually handle the requests.
WorkerTaskGrpcService(boost::asio::io_service &main_service,
WorkerTaskHandler &service_handler)
: GrpcService(main_service), service_handler_(service_handler){};
WorkerGrpcService(boost::asio::io_service &main_service, CoreWorker &core_worker);
protected:
grpc::Service &GetGrpcService() override { return service_; }
@@ -41,25 +28,14 @@ class WorkerTaskGrpcService : public GrpcService {
void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
*server_call_factories_and_concurrencies) override {
// Initialize the Factory for `AssignTask` requests.
std::unique_ptr<ServerCallFactory> push_task_call_Factory(
new ServerCallFactoryImpl<WorkerTaskService, WorkerTaskHandler, AssignTaskRequest,
AssignTaskReply>(
service_, &WorkerTaskService::AsyncService::RequestAssignTask,
service_handler_, &WorkerTaskHandler::HandleAssignTask, cq, main_service_));
// Set `AssignTask`'s accept concurrency to 5.
server_call_factories_and_concurrencies->emplace_back(
std::move(push_task_call_Factory), 5);
}
*server_call_factories_and_concurrencies) override;
private:
/// The grpc async service object.
WorkerTaskService::AsyncService service_;
WorkerService::AsyncService service_;
/// The service handler that actually handle the requests.
WorkerTaskHandler &service_handler_;
/// The core worker that actually handles the requests.
CoreWorker &core_worker_;
};
} // namespace rpc