mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 19:58:40 +08:00
Simplify gRPC service definition for the worker (#6095)
This commit is contained in:
+4
-34
@@ -77,17 +77,6 @@ cc_proto_library(
|
||||
deps = [":object_manager_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "worker_proto",
|
||||
srcs = ["src/ray/protobuf/worker.proto"],
|
||||
deps = [":common_proto"],
|
||||
)
|
||||
|
||||
cc_proto_library(
|
||||
name = "worker_cc_proto",
|
||||
deps = ["worker_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "core_worker_proto",
|
||||
srcs = ["src/ray/protobuf/core_worker.proto"],
|
||||
@@ -95,21 +84,10 @@ proto_library(
|
||||
)
|
||||
|
||||
cc_proto_library(
|
||||
name = "core_worker_cc_proto",
|
||||
name = "worker_cc_proto",
|
||||
deps = ["core_worker_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "direct_actor_proto",
|
||||
srcs = ["src/ray/protobuf/direct_actor.proto"],
|
||||
deps = [":common_proto"],
|
||||
)
|
||||
|
||||
cc_proto_library(
|
||||
name = "direct_actor_cc_proto",
|
||||
deps = ["direct_actor_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "serialization_proto",
|
||||
srcs = ["src/ray/protobuf/serialization.proto"],
|
||||
@@ -193,19 +171,11 @@ cc_library(
|
||||
# Worker gRPC lib.
|
||||
cc_grpc_library(
|
||||
name = "worker_cc_grpc",
|
||||
srcs = [":worker_proto"],
|
||||
srcs = [":core_worker_proto"],
|
||||
grpc_only = True,
|
||||
deps = [":worker_cc_proto"],
|
||||
)
|
||||
|
||||
# direct actor gRPC lib.
|
||||
cc_grpc_library(
|
||||
name = "direct_actor_cc_grpc",
|
||||
srcs = [":direct_actor_proto"],
|
||||
grpc_only = True,
|
||||
deps = [":direct_actor_cc_proto"],
|
||||
)
|
||||
|
||||
# worker server and client.
|
||||
cc_library(
|
||||
name = "worker_rpc",
|
||||
@@ -214,7 +184,6 @@ cc_library(
|
||||
]),
|
||||
copts = COPTS,
|
||||
deps = [
|
||||
"direct_actor_cc_grpc",
|
||||
":grpc_common_lib",
|
||||
":ray_common",
|
||||
":worker_cc_grpc",
|
||||
@@ -363,6 +332,7 @@ cc_library(
|
||||
"src/ray/core_worker/store_provider/*.cc",
|
||||
"src/ray/core_worker/store_provider/memory_store/*.cc",
|
||||
"src/ray/core_worker/transport/*.cc",
|
||||
"src/ray/rpc/worker/*.cc",
|
||||
],
|
||||
exclude = [
|
||||
"src/ray/core_worker/*_test.cc",
|
||||
@@ -379,7 +349,7 @@ cc_library(
|
||||
deps = [
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
":core_worker_cc_proto",
|
||||
":worker_cc_proto",
|
||||
":ray_common",
|
||||
":ray_util",
|
||||
# TODO(hchen): After `raylet_client` is migrated to gRPC, `core_worker_lib`
|
||||
|
||||
@@ -89,7 +89,6 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) {
|
||||
if (task_spec.IsNormalTask()) {
|
||||
RAY_CHECK(current_job_id_.IsNil());
|
||||
SetCurrentJobId(task_spec.JobId());
|
||||
current_actor_use_direct_call_ = false;
|
||||
} else if (task_spec.IsActorCreationTask()) {
|
||||
RAY_CHECK(current_job_id_.IsNil());
|
||||
SetCurrentJobId(task_spec.JobId());
|
||||
|
||||
@@ -47,7 +47,7 @@ class WorkerContext {
|
||||
const WorkerID worker_id_;
|
||||
JobID current_job_id_;
|
||||
ActorID current_actor_id_;
|
||||
bool current_actor_use_direct_call_;
|
||||
bool current_actor_use_direct_call_ = false;
|
||||
int current_actor_max_concurrency_;
|
||||
|
||||
private:
|
||||
|
||||
@@ -86,7 +86,8 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
||||
gcs_client_(gcs_options),
|
||||
memory_store_(std::make_shared<CoreWorkerMemoryStore>()),
|
||||
task_execution_service_work_(task_execution_service_),
|
||||
task_execution_callback_(task_execution_callback) {
|
||||
task_execution_callback_(task_execution_callback),
|
||||
grpc_service_(io_service_, *this) {
|
||||
// Initialize logging if log_dir is passed. Otherwise, it must be initialized
|
||||
// and cleaned up by the caller.
|
||||
if (log_dir_ != "") {
|
||||
@@ -111,14 +112,13 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
||||
// Initialize task receivers.
|
||||
auto execute_task = std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1,
|
||||
std::placeholders::_2, std::placeholders::_3);
|
||||
raylet_task_receiver_ =
|
||||
std::unique_ptr<CoreWorkerRayletTaskReceiver>(new CoreWorkerRayletTaskReceiver(
|
||||
worker_context_, raylet_client_, task_execution_service_, worker_server_,
|
||||
execute_task, exit_handler));
|
||||
raylet_task_receiver_ = std::unique_ptr<CoreWorkerRayletTaskReceiver>(
|
||||
new CoreWorkerRayletTaskReceiver(raylet_client_, execute_task, exit_handler));
|
||||
direct_actor_task_receiver_ = std::unique_ptr<CoreWorkerDirectActorTaskReceiver>(
|
||||
new CoreWorkerDirectActorTaskReceiver(worker_context_, task_execution_service_,
|
||||
worker_server_, execute_task,
|
||||
exit_handler));
|
||||
worker_server_.RegisterService(grpc_service_);
|
||||
}
|
||||
|
||||
// Start RPC server after all the task receivers are properly initialized.
|
||||
@@ -750,4 +750,37 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task,
|
||||
return status;
|
||||
}
|
||||
|
||||
void CoreWorker::HandleAssignTask(const rpc::AssignTaskRequest &request,
|
||||
rpc::AssignTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) {
|
||||
if (worker_context_.CurrentActorUseDirectCall()) {
|
||||
send_reply_callback(Status::Invalid("This actor only accepts direct calls."), nullptr,
|
||||
nullptr);
|
||||
return;
|
||||
} else {
|
||||
task_execution_service_.post([=] {
|
||||
raylet_task_receiver_->HandleAssignTask(request, reply, send_reply_callback);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void CoreWorker::HandleDirectActorAssignTask(
|
||||
const rpc::DirectActorAssignTaskRequest &request,
|
||||
rpc::DirectActorAssignTaskReply *reply, rpc::SendReplyCallback send_reply_callback) {
|
||||
task_execution_service_.post([=] {
|
||||
direct_actor_task_receiver_->HandleDirectActorAssignTask(request, reply,
|
||||
send_reply_callback);
|
||||
});
|
||||
}
|
||||
|
||||
void CoreWorker::HandleDirectActorCallArgWaitComplete(
|
||||
const rpc::DirectActorCallArgWaitCompleteRequest &request,
|
||||
rpc::DirectActorCallArgWaitCompleteReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) {
|
||||
task_execution_service_.post([=] {
|
||||
direct_actor_task_receiver_->HandleDirectActorCallArgWaitComplete(
|
||||
request, reply, send_reply_callback);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -20,6 +20,16 @@
|
||||
#include "ray/rpc/worker/worker_client.h"
|
||||
#include "ray/rpc/worker/worker_server.h"
|
||||
|
||||
/// The set of gRPC handlers and their associated level of concurrency. If you want to
|
||||
/// add a new call to the worker gRPC server, do the following:
|
||||
/// 1) Add the rpc to the WorkerService in core_worker.proto, e.g., "ExampleCall"
|
||||
/// 2) Add a new handler to the macro below: "RAY_CORE_WORKER_RPC_HANDLER(ExampleCall, 1)"
|
||||
/// 3) Add a method to the CoreWorker class below: "CoreWorker::HandleExampleCall"
|
||||
#define RAY_CORE_WORKER_RPC_HANDLERS \
|
||||
RAY_CORE_WORKER_RPC_HANDLER(AssignTask, 5) \
|
||||
RAY_CORE_WORKER_RPC_HANDLER(DirectActorAssignTask, 9999) \
|
||||
RAY_CORE_WORKER_RPC_HANDLER(DirectActorCallArgWaitComplete, 100)
|
||||
|
||||
namespace ray {
|
||||
|
||||
/// The root class that contains all the core and language-independent functionalities
|
||||
@@ -95,7 +105,9 @@ class CoreWorker {
|
||||
// in the heartbeat messsage.
|
||||
void RemoveActiveObjectID(const ObjectID &object_id) LOCKS_EXCLUDED(object_ref_mu_);
|
||||
|
||||
/* Public methods related to storing and retrieving objects. */
|
||||
///
|
||||
/// Public methods related to storing and retrieving objects.
|
||||
///
|
||||
|
||||
/// Set options for this client's interactions with the object store.
|
||||
///
|
||||
@@ -198,7 +210,9 @@ class CoreWorker {
|
||||
/// \return std::string The string describing memory usage.
|
||||
std::string MemoryUsageString();
|
||||
|
||||
/* Public methods related to task submission. */
|
||||
///
|
||||
/// Public methods related to task submission.
|
||||
///
|
||||
|
||||
/// Get the caller ID used to submit tasks from this worker to an actor.
|
||||
///
|
||||
@@ -268,7 +282,9 @@ class CoreWorker {
|
||||
/// \return Status::Invalid if we don't have the specified handle.
|
||||
Status SerializeActorHandle(const ActorID &actor_id, std::string *output) const;
|
||||
|
||||
/* Public methods related to task execution. Should not be used by driver processes. */
|
||||
///
|
||||
/// Public methods related to task execution. Should not be used by driver processes.
|
||||
///
|
||||
|
||||
const ActorID &GetActorId() const { return actor_id_; }
|
||||
|
||||
@@ -295,6 +311,29 @@ class CoreWorker {
|
||||
const std::vector<std::shared_ptr<Buffer>> &metadatas,
|
||||
std::vector<std::shared_ptr<RayObject>> *return_objects);
|
||||
|
||||
/* Handlers for the worker's gRPC server. These are executed on the io_service_ and post
|
||||
* work to the appropriate event loop.
|
||||
*/
|
||||
|
||||
/// Handle an "AssignTask" event corresponding to scheduling a normal or an actor task
|
||||
/// on this worker from the raylet.
|
||||
void HandleAssignTask(const rpc::AssignTaskRequest &request,
|
||||
rpc::AssignTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback);
|
||||
|
||||
/// Handle a "DirectActorAssignTask" event corresponding to scheduling an actor task
|
||||
/// on this worker from another worker.
|
||||
void HandleDirectActorAssignTask(const rpc::DirectActorAssignTaskRequest &request,
|
||||
rpc::DirectActorAssignTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback);
|
||||
|
||||
/// Handle a "DirectActorAssignTask" event corresponding to the raylet notifiying this
|
||||
/// worker that an argument is ready.
|
||||
void HandleDirectActorCallArgWaitComplete(
|
||||
const rpc::DirectActorCallArgWaitCompleteRequest &request,
|
||||
rpc::DirectActorCallArgWaitCompleteReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback);
|
||||
|
||||
private:
|
||||
/// Run the io_service_ event loop. This should be called in a background thread.
|
||||
void RunIOService();
|
||||
@@ -306,7 +345,9 @@ class CoreWorker {
|
||||
/// Send the list of active object IDs to the raylet.
|
||||
void ReportActiveObjectIDs() LOCKS_EXCLUDED(object_ref_mu_);
|
||||
|
||||
/* Private methods related to task submission. */
|
||||
///
|
||||
/// Private methods related to task submission.
|
||||
///
|
||||
|
||||
/// Give this worker a handle to an actor.
|
||||
///
|
||||
@@ -328,7 +369,9 @@ class CoreWorker {
|
||||
/// \return Status::Invalid if we don't have this actor handle.
|
||||
Status GetActorHandle(const ActorID &actor_id, ActorHandle **actor_handle) const;
|
||||
|
||||
/* Private methods related to task execution. Should not be used by driver processes. */
|
||||
///
|
||||
/// Private methods related to task execution. Should not be used by driver processes.
|
||||
///
|
||||
|
||||
/// Execute a task.
|
||||
///
|
||||
@@ -408,7 +451,9 @@ class CoreWorker {
|
||||
// Thread that runs a boost::asio service to process IO events.
|
||||
std::thread io_thread_;
|
||||
|
||||
/* Fields related to ref counting objects. */
|
||||
///
|
||||
/// Fields related to ref counting objects.
|
||||
///
|
||||
|
||||
/// Protects access to the set of active object ids. Since this set is updated
|
||||
/// very frequently, it is faster to lock around accesses rather than serialize
|
||||
@@ -422,7 +467,9 @@ class CoreWorker {
|
||||
/// last time it was sent to the raylet.
|
||||
bool active_object_ids_updated_ GUARDED_BY(object_ref_mu_) = false;
|
||||
|
||||
/* Fields related to storing and retrieving objects. */
|
||||
///
|
||||
/// Fields related to storing and retrieving objects.
|
||||
///
|
||||
|
||||
/// In-memory store for return objects. This is used for `MEMORY` store provider.
|
||||
std::shared_ptr<CoreWorkerMemoryStore> memory_store_;
|
||||
@@ -433,7 +480,9 @@ class CoreWorker {
|
||||
/// In-memory store interface.
|
||||
std::unique_ptr<CoreWorkerMemoryStoreProvider> memory_store_provider_;
|
||||
|
||||
/* Fields related to task submission. */
|
||||
///
|
||||
/// Fields related to task submission.
|
||||
///
|
||||
|
||||
// Interface to submit tasks directly to other actors.
|
||||
std::unique_ptr<CoreWorkerDirectActorTaskSubmitter> direct_actor_submitter_;
|
||||
@@ -441,7 +490,9 @@ class CoreWorker {
|
||||
/// Map from actor ID to a handle to that actor.
|
||||
absl::flat_hash_map<ActorID, std::unique_ptr<ActorHandle>> actor_handles_;
|
||||
|
||||
/* Fields related to task execution. */
|
||||
///
|
||||
/// Fields related to task execution.
|
||||
///
|
||||
|
||||
/// Our actor ID. If this is nil, then we execute only stateless tasks.
|
||||
ActorID actor_id_;
|
||||
@@ -466,6 +517,9 @@ class CoreWorker {
|
||||
// Interface that receives tasks from the raylet.
|
||||
std::unique_ptr<CoreWorkerRayletTaskReceiver> raylet_task_receiver_;
|
||||
|
||||
/// Common rpc service for all worker modules.
|
||||
rpc::WorkerGrpcService grpc_service_;
|
||||
|
||||
// Interface that receives tasks from direct actor calls.
|
||||
std::unique_ptr<CoreWorkerDirectActorTaskReceiver> direct_actor_task_receiver_;
|
||||
|
||||
|
||||
@@ -11,8 +11,6 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
class CoreWorker;
|
||||
|
||||
/// The class provides implementations for accessing local process memory store.
|
||||
/// An example usage for this is to retrieve the returned objects from direct
|
||||
/// actor call (see direct_actor_transport.cc).
|
||||
|
||||
@@ -12,8 +12,6 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
class CoreWorker;
|
||||
|
||||
/// The class provides implementations for accessing plasma store, which includes both
|
||||
/// local and remote stores. Local access goes is done via a
|
||||
/// CoreWorkerLocalPlasmaStoreProvider and remote access goes through the raylet.
|
||||
|
||||
@@ -13,8 +13,7 @@
|
||||
#include "ray/core_worker/store_provider/memory_store_provider.h"
|
||||
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
#include "src/ray/protobuf/direct_actor.grpc.pb.h"
|
||||
#include "src/ray/protobuf/direct_actor.pb.h"
|
||||
#include "src/ray/protobuf/core_worker.pb.h"
|
||||
#include "src/ray/protobuf/gcs.pb.h"
|
||||
#include "src/ray/util/test_util.h"
|
||||
|
||||
@@ -481,7 +480,7 @@ TEST_F(ZeroNodeTest, TestTaskArg) {
|
||||
ASSERT_EQ(*data, *buffer);
|
||||
}
|
||||
|
||||
// Performance batchmark for `PushTaskRequest` creation.
|
||||
// Performance batchmark for `DirectActorAssignTaskRequest` creation.
|
||||
TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
|
||||
// Create a dummy actor handle, and then create a number of `TaskSpec`
|
||||
// to benchmark performance.
|
||||
@@ -500,11 +499,11 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
|
||||
function.GetFunctionDescriptor());
|
||||
|
||||
// Manually create `num_tasks` task specs, and for each of them create a
|
||||
// `PushTaskRequest`, this is to batch performance of TaskSpec
|
||||
// `DirectActorAssignTaskRequest`, this is to batch performance of TaskSpec
|
||||
// creation/copy/destruction.
|
||||
int64_t start_ms = current_time_ms();
|
||||
const auto num_tasks = 10000 * 10;
|
||||
RAY_LOG(INFO) << "start creating " << num_tasks << " PushTaskRequests";
|
||||
RAY_LOG(INFO) << "start creating " << num_tasks << " DirectActorAssignTaskRequests";
|
||||
for (int i = 0; i < num_tasks; i++) {
|
||||
TaskOptions options{1, resources};
|
||||
std::vector<ObjectID> return_ids;
|
||||
@@ -529,10 +528,11 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
|
||||
const auto &task_spec = builder.Build();
|
||||
|
||||
ASSERT_TRUE(task_spec.IsActorTask());
|
||||
auto request = std::unique_ptr<rpc::PushTaskRequest>(new rpc::PushTaskRequest);
|
||||
auto request = std::unique_ptr<rpc::DirectActorAssignTaskRequest>(
|
||||
new rpc::DirectActorAssignTaskRequest);
|
||||
request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage());
|
||||
}
|
||||
RAY_LOG(INFO) << "Finish creating " << num_tasks << " PushTaskRequests"
|
||||
RAY_LOG(INFO) << "Finish creating " << num_tasks << " DirectActorAssignTaskRequests"
|
||||
<< ", which takes " << current_time_ms() - start_ms << " ms";
|
||||
}
|
||||
|
||||
|
||||
@@ -5,15 +5,6 @@ using ray::rpc::ActorTableData;
|
||||
|
||||
namespace ray {
|
||||
|
||||
bool HasByReferenceArgs(const TaskSpecification &spec) {
|
||||
for (size_t i = 0; i < spec.NumArgs(); ++i) {
|
||||
if (spec.ArgIdCount(i) > 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
CoreWorkerDirectActorTaskSubmitter::CoreWorkerDirectActorTaskSubmitter(
|
||||
boost::asio::io_service &io_service,
|
||||
std::unique_ptr<CoreWorkerMemoryStoreProvider> store_provider)
|
||||
@@ -31,7 +22,8 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(
|
||||
const auto task_id = task_spec.TaskId();
|
||||
const auto num_returns = task_spec.NumReturns();
|
||||
|
||||
auto request = std::unique_ptr<rpc::PushTaskRequest>(new rpc::PushTaskRequest);
|
||||
auto request = std::unique_ptr<rpc::DirectActorAssignTaskRequest>(
|
||||
new rpc::DirectActorAssignTaskRequest);
|
||||
request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage());
|
||||
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
@@ -57,7 +49,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(
|
||||
|
||||
// Submit request.
|
||||
auto &client = rpc_clients_[actor_id];
|
||||
PushTask(*client, std::move(request), actor_id, task_id, num_returns);
|
||||
DirectActorAssignTask(*client, std::move(request), actor_id, task_id, num_returns);
|
||||
} else {
|
||||
// Actor is dead, treat the task as failure.
|
||||
RAY_CHECK(iter->second.state_ == ActorTableData::DEAD);
|
||||
@@ -114,8 +106,8 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate(
|
||||
|
||||
void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks(
|
||||
const ActorID &actor_id, std::string ip_address, int port) {
|
||||
std::shared_ptr<rpc::DirectActorClient> grpc_client =
|
||||
rpc::DirectActorClient::make(ip_address, port, client_call_manager_);
|
||||
std::shared_ptr<rpc::WorkerTaskClient> grpc_client =
|
||||
std::make_shared<rpc::WorkerTaskClient>(ip_address, port, client_call_manager_);
|
||||
RAY_CHECK(rpc_clients_.emplace(actor_id, std::move(grpc_client)).second);
|
||||
|
||||
// Submit all pending requests.
|
||||
@@ -125,20 +117,22 @@ void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks(
|
||||
auto request = std::move(requests.front());
|
||||
auto num_returns = request->task_spec().num_returns();
|
||||
auto task_id = TaskID::FromBinary(request->task_spec().task_id());
|
||||
PushTask(*client, std::move(request), actor_id, task_id, num_returns);
|
||||
DirectActorAssignTask(*client, std::move(request), actor_id, task_id, num_returns);
|
||||
requests.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
void CoreWorkerDirectActorTaskSubmitter::PushTask(
|
||||
rpc::DirectActorClient &client, std::unique_ptr<rpc::PushTaskRequest> request,
|
||||
const ActorID &actor_id, const TaskID &task_id, int num_returns) {
|
||||
void CoreWorkerDirectActorTaskSubmitter::DirectActorAssignTask(
|
||||
rpc::WorkerTaskClient &client,
|
||||
std::unique_ptr<rpc::DirectActorAssignTaskRequest> request, const ActorID &actor_id,
|
||||
const TaskID &task_id, int num_returns) {
|
||||
RAY_LOG(DEBUG) << "Pushing task " << task_id << " to actor " << actor_id;
|
||||
waiting_reply_tasks_[actor_id].insert(std::make_pair(task_id, num_returns));
|
||||
|
||||
auto status = client.PushTask(
|
||||
std::move(request), [this, actor_id, task_id, num_returns](
|
||||
Status status, const rpc::PushTaskReply &reply) {
|
||||
auto status = client.DirectActorAssignTask(
|
||||
std::move(request),
|
||||
[this, actor_id, task_id, num_returns](
|
||||
Status status, const rpc::DirectActorAssignTaskReply &reply) {
|
||||
{
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
waiting_reply_tasks_[actor_id].erase(task_id);
|
||||
@@ -203,12 +197,9 @@ CoreWorkerDirectActorTaskReceiver::CoreWorkerDirectActorTaskReceiver(
|
||||
rpc::GrpcServer &server, const TaskHandler &task_handler,
|
||||
const std::function<void()> &exit_handler)
|
||||
: worker_context_(worker_context),
|
||||
task_service_(main_io_service, *this),
|
||||
task_handler_(task_handler),
|
||||
exit_handler_(exit_handler),
|
||||
task_main_io_service_(main_io_service) {
|
||||
server.RegisterService(task_service_);
|
||||
}
|
||||
task_main_io_service_(main_io_service) {}
|
||||
|
||||
void CoreWorkerDirectActorTaskReceiver::Init(RayletClient &raylet_client) {
|
||||
waiter_.reset(new DependencyWaiterImpl(raylet_client));
|
||||
@@ -223,9 +214,9 @@ void CoreWorkerDirectActorTaskReceiver::SetMaxActorConcurrency(int max_concurren
|
||||
}
|
||||
}
|
||||
|
||||
void CoreWorkerDirectActorTaskReceiver::HandlePushTask(
|
||||
const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) {
|
||||
void CoreWorkerDirectActorTaskReceiver::HandleDirectActorAssignTask(
|
||||
const rpc::DirectActorAssignTaskRequest &request,
|
||||
rpc::DirectActorAssignTaskReply *reply, rpc::SendReplyCallback send_reply_callback) {
|
||||
RAY_CHECK(waiter_ != nullptr) << "Must call init() prior to use";
|
||||
const TaskSpecification task_spec(request.task_spec());
|
||||
RAY_LOG(DEBUG) << "Received task " << task_spec.TaskId();
|
||||
@@ -269,7 +260,7 @@ void CoreWorkerDirectActorTaskReceiver::HandlePushTask(
|
||||
std::vector<std::shared_ptr<RayObject>> return_by_value;
|
||||
auto status = task_handler_(task_spec, resource_ids, &return_by_value);
|
||||
if (status.IsSystemExit()) {
|
||||
// In Python, SystemExit cannot be raised except on the main thread. To work
|
||||
// In Python, SystemExit can only be raised on the main thread. To work
|
||||
// around this when we are executing tasks on worker threads, we re-post the
|
||||
// exit event explicitly on the main thread.
|
||||
task_main_io_service_.post([this]() { exit_handler_(); });
|
||||
|
||||
@@ -14,8 +14,8 @@
|
||||
#include "ray/core_worker/context.h"
|
||||
#include "ray/core_worker/store_provider/memory_store_provider.h"
|
||||
#include "ray/gcs/redis_gcs_client.h"
|
||||
#include "ray/rpc/worker/direct_actor_client.h"
|
||||
#include "ray/rpc/worker/direct_actor_server.h"
|
||||
#include "ray/rpc/grpc_server.h"
|
||||
#include "ray/rpc/worker/worker_client.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
@@ -67,9 +67,10 @@ class CoreWorkerDirectActorTaskSubmitter {
|
||||
/// \param[in] task_id The ID of a task.
|
||||
/// \param[in] num_returns Number of return objects.
|
||||
/// \return Void.
|
||||
void PushTask(rpc::DirectActorClient &client,
|
||||
std::unique_ptr<rpc::PushTaskRequest> request, const ActorID &actor_id,
|
||||
const TaskID &task_id, int num_returns);
|
||||
void DirectActorAssignTask(rpc::WorkerTaskClient &client,
|
||||
std::unique_ptr<rpc::DirectActorAssignTaskRequest> request,
|
||||
const ActorID &actor_id, const TaskID &task_id,
|
||||
int num_returns);
|
||||
|
||||
/// Treat a task as failed.
|
||||
///
|
||||
@@ -114,10 +115,11 @@ class CoreWorkerDirectActorTaskSubmitter {
|
||||
///
|
||||
/// TODO(zhijunfu): this will be moved into `actor_states_` later when we can
|
||||
/// subscribe updates for a specific actor.
|
||||
std::unordered_map<ActorID, std::shared_ptr<rpc::DirectActorClient>> rpc_clients_;
|
||||
std::unordered_map<ActorID, std::shared_ptr<rpc::WorkerTaskClient>> rpc_clients_;
|
||||
|
||||
/// Map from actor id to the actor's pending requests.
|
||||
std::unordered_map<ActorID, std::list<std::unique_ptr<rpc::PushTaskRequest>>>
|
||||
std::unordered_map<ActorID,
|
||||
std::list<std::unique_ptr<rpc::DirectActorAssignTaskRequest>>>
|
||||
pending_requests_;
|
||||
|
||||
/// Map from actor id to the tasks that are waiting for reply.
|
||||
@@ -327,7 +329,7 @@ class SchedulingQueue {
|
||||
friend class SchedulingQueueTest;
|
||||
};
|
||||
|
||||
class CoreWorkerDirectActorTaskReceiver : public rpc::DirectActorHandler {
|
||||
class CoreWorkerDirectActorTaskReceiver {
|
||||
public:
|
||||
using TaskHandler = std::function<Status(
|
||||
const TaskSpecification &task_spec, const ResourceMappingType &resource_ids,
|
||||
@@ -342,13 +344,14 @@ class CoreWorkerDirectActorTaskReceiver : public rpc::DirectActorHandler {
|
||||
/// Initialize this receiver. This must be called prior to use.
|
||||
void Init(RayletClient &client);
|
||||
|
||||
/// Handle a `PushTask` request.
|
||||
/// Handle a `DirectActorAssignTask` request.
|
||||
///
|
||||
/// \param[in] request The request message.
|
||||
/// \param[out] reply The reply message.
|
||||
/// \param[in] send_reply_callback The callback to be called when the request is done.
|
||||
void HandlePushTask(const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
void HandleDirectActorAssignTask(const rpc::DirectActorAssignTaskRequest &request,
|
||||
rpc::DirectActorAssignTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback);
|
||||
|
||||
/// Handle a `DirectActorCallArgWaitComplete` request.
|
||||
///
|
||||
@@ -358,7 +361,7 @@ class CoreWorkerDirectActorTaskReceiver : public rpc::DirectActorHandler {
|
||||
void HandleDirectActorCallArgWaitComplete(
|
||||
const rpc::DirectActorCallArgWaitCompleteRequest &request,
|
||||
rpc::DirectActorCallArgWaitCompleteReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
rpc::SendReplyCallback send_reply_callback);
|
||||
|
||||
/// Set the max concurrency at runtime. It cannot be changed once set.
|
||||
void SetMaxActorConcurrency(int max_concurrency);
|
||||
@@ -366,8 +369,6 @@ class CoreWorkerDirectActorTaskReceiver : public rpc::DirectActorHandler {
|
||||
private:
|
||||
// Worker context.
|
||||
WorkerContext &worker_context_;
|
||||
/// The rpc service for `DirectActorService`.
|
||||
rpc::DirectActorGrpcService task_service_;
|
||||
/// The callback function to process a task.
|
||||
TaskHandler task_handler_;
|
||||
/// The callback function to exit the worker.
|
||||
|
||||
@@ -6,16 +6,11 @@
|
||||
namespace ray {
|
||||
|
||||
CoreWorkerRayletTaskReceiver::CoreWorkerRayletTaskReceiver(
|
||||
WorkerContext &worker_context, std::unique_ptr<RayletClient> &raylet_client,
|
||||
boost::asio::io_service &io_service, rpc::GrpcServer &server,
|
||||
const TaskHandler &task_handler, const std::function<void()> &exit_handler)
|
||||
: worker_context_(worker_context),
|
||||
raylet_client_(raylet_client),
|
||||
task_service_(io_service, *this),
|
||||
std::unique_ptr<RayletClient> &raylet_client, const TaskHandler &task_handler,
|
||||
const std::function<void()> &exit_handler)
|
||||
: raylet_client_(raylet_client),
|
||||
task_handler_(task_handler),
|
||||
exit_handler_(exit_handler) {
|
||||
server.RegisterService(task_service_);
|
||||
}
|
||||
exit_handler_(exit_handler) {}
|
||||
|
||||
void CoreWorkerRayletTaskReceiver::HandleAssignTask(
|
||||
const rpc::AssignTaskRequest &request, rpc::AssignTaskReply *reply,
|
||||
@@ -23,11 +18,6 @@ void CoreWorkerRayletTaskReceiver::HandleAssignTask(
|
||||
const Task task(request.task());
|
||||
const auto &task_spec = task.GetTaskSpecification();
|
||||
RAY_LOG(DEBUG) << "Received task " << task_spec.TaskId();
|
||||
if (task_spec.IsActorTask() && worker_context_.CurrentActorUseDirectCall()) {
|
||||
send_reply_callback(Status::Invalid("This actor only accepts direct calls."), nullptr,
|
||||
nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
// Set the resource IDs for this task.
|
||||
// TODO: convert the resource map to protobuf and change this.
|
||||
|
||||
@@ -4,22 +4,19 @@
|
||||
#include <list>
|
||||
|
||||
#include "ray/common/ray_object.h"
|
||||
#include "ray/core_worker/context.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
#include "ray/rpc/worker/worker_server.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
class CoreWorkerRayletTaskReceiver : public rpc::WorkerTaskHandler {
|
||||
class CoreWorkerRayletTaskReceiver {
|
||||
public:
|
||||
using TaskHandler = std::function<Status(
|
||||
const TaskSpecification &task_spec, const ResourceMappingType &resource_ids,
|
||||
std::vector<std::shared_ptr<RayObject>> *return_by_value)>;
|
||||
|
||||
CoreWorkerRayletTaskReceiver(WorkerContext &worker_context,
|
||||
std::unique_ptr<RayletClient> &raylet_client,
|
||||
boost::asio::io_service &io_service,
|
||||
rpc::GrpcServer &server, const TaskHandler &task_handler,
|
||||
CoreWorkerRayletTaskReceiver(std::unique_ptr<RayletClient> &raylet_client,
|
||||
const TaskHandler &task_handler,
|
||||
const std::function<void()> &exit_handler);
|
||||
|
||||
/// Handle a `AssignTask` request.
|
||||
@@ -31,15 +28,11 @@ class CoreWorkerRayletTaskReceiver : public rpc::WorkerTaskHandler {
|
||||
/// \param[in] send_reply_callback The callback to be called when the request is done.
|
||||
void HandleAssignTask(const rpc::AssignTaskRequest &request,
|
||||
rpc::AssignTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
rpc::SendReplyCallback send_reply_callback);
|
||||
|
||||
private:
|
||||
// Worker context.
|
||||
WorkerContext &worker_context_;
|
||||
/// Raylet client.
|
||||
std::unique_ptr<RayletClient> &raylet_client_;
|
||||
/// The rpc service for `WorkerTaskService`.
|
||||
rpc::WorkerTaskGrpcService task_service_;
|
||||
/// The callback function to process a task.
|
||||
TaskHandler task_handler_;
|
||||
/// The callback function to exit the worker.
|
||||
|
||||
@@ -4,6 +4,10 @@ package ray.rpc;
|
||||
|
||||
import "src/ray/protobuf/common.proto";
|
||||
|
||||
message ActiveObjectIDs {
|
||||
repeated bytes object_ids = 1;
|
||||
}
|
||||
|
||||
// Persistent state of an ActorHandle.
|
||||
message ActorHandle {
|
||||
// ID of the actor.
|
||||
@@ -27,3 +31,64 @@ message ActorHandle {
|
||||
// Whether direct actor call is used.
|
||||
bool is_direct_call = 7;
|
||||
}
|
||||
|
||||
message AssignTaskRequest {
|
||||
// The task to be pushed.
|
||||
Task task = 1;
|
||||
// A list of the resources reserved for this worker.
|
||||
// TODO(zhijunfu): `resource_ids` is represented as
|
||||
// flatbutters-serialized bytes, will be moved to protobuf later.
|
||||
bytes resource_ids = 2;
|
||||
}
|
||||
|
||||
message AssignTaskReply {
|
||||
}
|
||||
|
||||
message ReturnObject {
|
||||
// Object ID.
|
||||
bytes object_id = 1;
|
||||
// Data of the object.
|
||||
bytes data = 2;
|
||||
// Metadata of the object.
|
||||
bytes metadata = 3;
|
||||
}
|
||||
|
||||
message DirectActorAssignTaskRequest {
|
||||
// The task to be pushed.
|
||||
TaskSpec task_spec = 1;
|
||||
// The sequence number of the task for this client. This must increase
|
||||
// sequentially starting from zero for each actor handle. The server
|
||||
// will guarantee tasks execute in this sequence, waiting for any
|
||||
// out-of-order request messages to arrive as necessary.
|
||||
int64 sequence_number = 2;
|
||||
// The max sequence number the client has processed responses for. This
|
||||
// is a performance optimization that allows the client to tell the server
|
||||
// to cancel any DirectActorAssignTaskRequests with seqno <= this value, rather than
|
||||
// waiting for the server to time out waiting for missing messages.
|
||||
int64 client_processed_up_to = 3;
|
||||
}
|
||||
|
||||
message DirectActorAssignTaskReply {
|
||||
// The returned objects.
|
||||
repeated ReturnObject return_objects = 1;
|
||||
}
|
||||
|
||||
message DirectActorCallArgWaitCompleteRequest {
|
||||
// Id used to uniquely identify this request. This is sent back to the core
|
||||
// worker to notify the wait has completed.
|
||||
int64 tag = 1;
|
||||
}
|
||||
|
||||
message DirectActorCallArgWaitCompleteReply {
|
||||
}
|
||||
|
||||
service WorkerService {
|
||||
// Push a task to a worker.
|
||||
rpc AssignTask(AssignTaskRequest) returns (AssignTaskReply);
|
||||
// Push a task to a direct-call actor.
|
||||
rpc DirectActorAssignTask(DirectActorAssignTaskRequest)
|
||||
returns (DirectActorAssignTaskReply);
|
||||
// Notify wait for direct actor call args has completed.
|
||||
rpc DirectActorCallArgWaitComplete(DirectActorCallArgWaitCompleteRequest)
|
||||
returns (DirectActorCallArgWaitCompleteReply);
|
||||
}
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package ray.rpc;
|
||||
|
||||
import "src/ray/protobuf/common.proto";
|
||||
|
||||
message ReturnObject {
|
||||
// Object ID.
|
||||
bytes object_id = 1;
|
||||
// Data of the object.
|
||||
bytes data = 2;
|
||||
// Metaata of the object.
|
||||
bytes metadata = 3;
|
||||
}
|
||||
|
||||
message PushTaskRequest {
|
||||
// The task to be pushed.
|
||||
TaskSpec task_spec = 1;
|
||||
// The sequence number of the task for this client. This must increase
|
||||
// sequentially starting from zero for each actor handle. The server
|
||||
// will guarantee tasks execute in this sequence, waiting for any
|
||||
// out-of-order request messages to arrive as necessary.
|
||||
int64 sequence_number = 2;
|
||||
// The max sequence number the client has processed responses for. This
|
||||
// is a performance optimization that allows the client to tell the server
|
||||
// to cancel any PushTaskRequests with seqno <= this value, rather than
|
||||
// waiting for the server to time out waiting for missing messages.
|
||||
int64 client_processed_up_to = 3;
|
||||
}
|
||||
|
||||
message PushTaskReply {
|
||||
// The returned objects.
|
||||
repeated ReturnObject return_objects = 1;
|
||||
}
|
||||
|
||||
message DirectActorCallArgWaitCompleteRequest {
|
||||
// Id used to uniquely identify this request. This is sent back to the core
|
||||
// worker to notify the wait has completed.
|
||||
int64 tag = 1;
|
||||
}
|
||||
|
||||
message DirectActorCallArgWaitCompleteReply {
|
||||
}
|
||||
|
||||
// Service for direct actor.
|
||||
service DirectActorService {
|
||||
// Push a task to a worker.
|
||||
rpc PushTask(PushTaskRequest) returns (PushTaskReply);
|
||||
// Notify wait for direct actor call args has completed
|
||||
rpc DirectActorCallArgWaitComplete(DirectActorCallArgWaitCompleteRequest)
|
||||
returns (DirectActorCallArgWaitCompleteReply);
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package ray.rpc;
|
||||
|
||||
import "src/ray/protobuf/common.proto";
|
||||
|
||||
message ActiveObjectIDs {
|
||||
repeated bytes object_ids = 1;
|
||||
}
|
||||
|
||||
message AssignTaskRequest {
|
||||
// The task to be pushed.
|
||||
Task task = 1;
|
||||
// A list of the resources reserved for this worker.
|
||||
// TODO(zhijunfu): `resource_ids` is represented as
|
||||
// flatbutters-serialized bytes, will be moved to protobuf later.
|
||||
bytes resource_ids = 2;
|
||||
}
|
||||
|
||||
message AssignTaskReply {
|
||||
}
|
||||
|
||||
// Service for worker.
|
||||
service WorkerTaskService {
|
||||
// Push a task to a worker.
|
||||
rpc AssignTask(AssignTaskRequest) returns (AssignTaskReply);
|
||||
}
|
||||
@@ -4,8 +4,8 @@
|
||||
|
||||
#include "ray/raylet/format/node_manager_generated.h"
|
||||
#include "ray/raylet/raylet.h"
|
||||
#include "src/ray/protobuf/direct_actor.grpc.pb.h"
|
||||
#include "src/ray/protobuf/direct_actor.pb.h"
|
||||
#include "src/ray/protobuf/core_worker.grpc.pb.h"
|
||||
#include "src/ray/protobuf/core_worker.pb.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
@@ -27,8 +27,6 @@ Worker::Worker(const WorkerID &worker_id, pid_t pid, const Language &language, i
|
||||
if (port_ > 0) {
|
||||
rpc_client_ = std::unique_ptr<rpc::WorkerTaskClient>(
|
||||
new rpc::WorkerTaskClient("127.0.0.1", port_, client_call_manager_));
|
||||
direct_rpc_client_ = std::unique_ptr<rpc::DirectActorClient>(
|
||||
new rpc::DirectActorClient("127.0.0.1", port_, client_call_manager_));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,7 +161,7 @@ void Worker::DirectActorCallArgWaitComplete(int64_t tag) {
|
||||
RAY_CHECK(port_ > 0);
|
||||
rpc::DirectActorCallArgWaitCompleteRequest request;
|
||||
request.set_tag(tag);
|
||||
auto status = direct_rpc_client_->DirectActorCallArgWaitComplete(
|
||||
auto status = rpc_client_->DirectActorCallArgWaitComplete(
|
||||
request, [](Status status, const rpc::DirectActorCallArgWaitCompleteReply &reply) {
|
||||
if (!status.ok()) {
|
||||
RAY_LOG(ERROR) << "Failed to send wait complete: " << status.ToString();
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
#include "ray/common/task/scheduling_resources.h"
|
||||
#include "ray/common/task/task.h"
|
||||
#include "ray/common/task/task_common.h"
|
||||
#include "ray/rpc/worker/direct_actor_client.h"
|
||||
#include "ray/rpc/worker/worker_client.h"
|
||||
|
||||
#include <unistd.h> // pid_t
|
||||
@@ -109,8 +108,6 @@ class Worker {
|
||||
/// Whether the worker is detached. This is applies when the worker is actor.
|
||||
/// Detached actor means the actor's creator can exit without killing this actor.
|
||||
bool is_detached_actor_;
|
||||
/// The rpc client to send tasks to the direct actor service.
|
||||
std::unique_ptr<rpc::DirectActorClient> direct_rpc_client_;
|
||||
};
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
#ifndef RAY_RPC_DIRECT_ACTOR_CLIENT_H
|
||||
#define RAY_RPC_DIRECT_ACTOR_CLIENT_H
|
||||
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
|
||||
#include <grpcpp/grpcpp.h>
|
||||
#include "absl/base/thread_annotations.h"
|
||||
|
||||
#include "ray/common/status.h"
|
||||
#include "ray/rpc/client_call.h"
|
||||
#include "ray/rpc/worker/direct_actor_common.h"
|
||||
#include "ray/util/logging.h"
|
||||
#include "src/ray/protobuf/direct_actor.grpc.pb.h"
|
||||
#include "src/ray/protobuf/direct_actor.pb.h"
|
||||
|
||||
namespace ray {
|
||||
namespace rpc {
|
||||
|
||||
/// Client used for communicating with a direct actor server.
|
||||
class DirectActorClient : public std::enable_shared_from_this<DirectActorClient> {
|
||||
public:
|
||||
/// Constructor.
|
||||
///
|
||||
/// \param[in] address Address of the direct actor server.
|
||||
/// \param[in] port Port of the direct actor server.
|
||||
/// \param[in] client_call_manager The `ClientCallManager` used for managing requests.
|
||||
static std::shared_ptr<DirectActorClient> make(const std::string &address,
|
||||
const int port,
|
||||
ClientCallManager &client_call_manager) {
|
||||
auto instance = new DirectActorClient(address, port, client_call_manager);
|
||||
return std::shared_ptr<DirectActorClient>(instance);
|
||||
}
|
||||
|
||||
/// Constructor.
|
||||
///
|
||||
/// \param[in] address Address of the direct actor server.
|
||||
/// \param[in] port Port of the direct actor server.
|
||||
/// \param[in] client_call_manager The `ClientCallManager` used for managing requests.
|
||||
DirectActorClient(const std::string &address, const int port,
|
||||
ClientCallManager &client_call_manager)
|
||||
: client_call_manager_(client_call_manager) {
|
||||
std::shared_ptr<grpc::Channel> channel = grpc::CreateChannel(
|
||||
address + ":" + std::to_string(port), grpc::InsecureChannelCredentials());
|
||||
stub_ = DirectActorService::NewStub(channel);
|
||||
};
|
||||
|
||||
/// Push a task.
|
||||
///
|
||||
/// \param[in] request The request message.
|
||||
/// \param[in] callback The callback function that handles reply.
|
||||
/// \return if the rpc call succeeds
|
||||
ray::Status PushTask(std::unique_ptr<PushTaskRequest> request,
|
||||
const ClientCallback<PushTaskReply> &callback) {
|
||||
request->set_sequence_number(request->task_spec().actor_task_spec().actor_counter());
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (request->task_spec().caller_id() != cur_caller_id_) {
|
||||
// We are running a new task, reset the seq no counter.
|
||||
max_finished_seq_no_ = -1;
|
||||
cur_caller_id_ = request->task_spec().caller_id();
|
||||
}
|
||||
send_queue_.push_back(std::make_pair(std::move(request), callback));
|
||||
}
|
||||
SendRequests();
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
/// Notify a wait has completed for direct actor call arguments.
|
||||
///
|
||||
/// \param[in] request The request message.
|
||||
/// \param[in] callback The callback function that handles reply.
|
||||
/// \return if the rpc call succeeds
|
||||
ray::Status DirectActorCallArgWaitComplete(
|
||||
const DirectActorCallArgWaitCompleteRequest &request,
|
||||
const ClientCallback<DirectActorCallArgWaitCompleteReply> &callback) {
|
||||
auto call = client_call_manager_.CreateCall<DirectActorService,
|
||||
DirectActorCallArgWaitCompleteRequest,
|
||||
DirectActorCallArgWaitCompleteReply>(
|
||||
*stub_, &DirectActorService::Stub::PrepareAsyncDirectActorCallArgWaitComplete,
|
||||
request, callback);
|
||||
return call->GetStatus();
|
||||
}
|
||||
|
||||
/// Send as many pending tasks as possible. This method is thread-safe.
|
||||
///
|
||||
/// The client will guarantee no more than kMaxBytesInFlight bytes of RPCs are being
|
||||
/// sent at once. This prevents the server scheduling queue from being overwhelmed.
|
||||
/// See direct_actor.proto for a description of the ordering protocol.
|
||||
void SendRequests() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
auto this_ptr = this->shared_from_this();
|
||||
|
||||
while (!send_queue_.empty() && rpc_bytes_in_flight_ < kMaxBytesInFlight) {
|
||||
auto pair = std::move(*send_queue_.begin());
|
||||
send_queue_.pop_front();
|
||||
|
||||
auto request = std::move(pair.first);
|
||||
auto callback = pair.second;
|
||||
int64_t task_size = RequestSizeInBytes(*request);
|
||||
int64_t seq_no = request->sequence_number();
|
||||
request->set_client_processed_up_to(max_finished_seq_no_);
|
||||
rpc_bytes_in_flight_ += task_size;
|
||||
|
||||
client_call_manager_.CreateCall<DirectActorService, PushTaskRequest, PushTaskReply>(
|
||||
*stub_, &DirectActorService::Stub::PrepareAsyncPushTask, *request,
|
||||
[this, this_ptr, seq_no, task_size, callback](Status status,
|
||||
const rpc::PushTaskReply &reply) {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (seq_no > max_finished_seq_no_) {
|
||||
max_finished_seq_no_ = seq_no;
|
||||
}
|
||||
rpc_bytes_in_flight_ -= task_size;
|
||||
RAY_CHECK(rpc_bytes_in_flight_ >= 0);
|
||||
}
|
||||
SendRequests();
|
||||
callback(status, reply);
|
||||
});
|
||||
}
|
||||
|
||||
if (!send_queue_.empty()) {
|
||||
RAY_LOG(DEBUG) << "client send queue size " << send_queue_.size();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
/// Protects against unsafe concurrent access from the callback thread.
|
||||
std::mutex mutex_;
|
||||
|
||||
/// The gRPC-generated stub.
|
||||
std::unique_ptr<DirectActorService::Stub> stub_;
|
||||
|
||||
/// The `ClientCallManager` used for managing requests.
|
||||
ClientCallManager &client_call_manager_;
|
||||
|
||||
/// Queue of requests to send.
|
||||
std::deque<std::pair<std::unique_ptr<PushTaskRequest>, ClientCallback<PushTaskReply>>>
|
||||
send_queue_ GUARDED_BY(mutex_);
|
||||
|
||||
/// The number of bytes currently in flight.
|
||||
int64_t rpc_bytes_in_flight_ GUARDED_BY(mutex_) = 0;
|
||||
|
||||
/// The max sequence number we have processed responses for.
|
||||
int64_t max_finished_seq_no_ GUARDED_BY(mutex_) = -1;
|
||||
|
||||
/// The task id we are currently sending requests for. When this changes,
|
||||
/// the max finished seq no counter is reset.
|
||||
std::string cur_caller_id_;
|
||||
};
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_RPC_DIRECT_ACTOR_CLIENT_H
|
||||
@@ -1,28 +0,0 @@
|
||||
#ifndef RAY_RPC_DIRECT_ACTOR_COMMON_H
|
||||
#define RAY_RPC_DIRECT_ACTOR_COMMON_H
|
||||
|
||||
#include "src/ray/protobuf/direct_actor.grpc.pb.h"
|
||||
#include "src/ray/protobuf/direct_actor.pb.h"
|
||||
|
||||
namespace ray {
|
||||
namespace rpc {
|
||||
|
||||
/// The maximum number of requests in flight per client.
|
||||
const int64_t kMaxBytesInFlight = 16 * 1024 * 1024;
|
||||
|
||||
/// The base size in bytes per request.
|
||||
const int64_t kBaseRequestSize = 1024;
|
||||
|
||||
/// Get the estimated size in bytes of the given task.
|
||||
const static int64_t RequestSizeInBytes(const PushTaskRequest &request) {
|
||||
int64_t size = kBaseRequestSize;
|
||||
for (auto &arg : request.task_spec().args()) {
|
||||
size += arg.data().size();
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_RPC_DIRECT_ACTOR_COMMON_H
|
||||
@@ -1,88 +0,0 @@
|
||||
#ifndef RAY_RPC_DIRECT_ACTOR_SERVER_H
|
||||
#define RAY_RPC_DIRECT_ACTOR_SERVER_H
|
||||
|
||||
#include "ray/rpc/grpc_server.h"
|
||||
#include "ray/rpc/server_call.h"
|
||||
|
||||
#include "src/ray/protobuf/direct_actor.grpc.pb.h"
|
||||
#include "src/ray/protobuf/direct_actor.pb.h"
|
||||
|
||||
namespace ray {
|
||||
namespace rpc {
|
||||
|
||||
/// Interface of the `DirectActorService`, see `src/ray/protobuf/direct_actor.proto`.
|
||||
class DirectActorHandler {
|
||||
public:
|
||||
/// Handle a `PushTask` request.
|
||||
/// The implementation can handle this request asynchronously. When hanling is done, the
|
||||
/// `done_callback` should be called.
|
||||
///
|
||||
/// \param[in] request The request message.
|
||||
/// \param[out] reply The reply message.
|
||||
/// \param[in] done_callback The callback to be called when the request is done.
|
||||
virtual void HandlePushTask(const PushTaskRequest &request, PushTaskReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
|
||||
/// Handle a wait reply for direct actor call arg dependencies.
|
||||
///
|
||||
/// \param[in] request The request message.
|
||||
/// \param[out] reply The reply message.
|
||||
/// \param[in] send_replay_callback The callback to be called when the request is done.
|
||||
virtual void HandleDirectActorCallArgWaitComplete(
|
||||
const rpc::DirectActorCallArgWaitCompleteRequest &request,
|
||||
rpc::DirectActorCallArgWaitCompleteReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) = 0;
|
||||
};
|
||||
|
||||
/// The `GrpcServer` for `WorkerService`.
|
||||
class DirectActorGrpcService : public GrpcService {
|
||||
public:
|
||||
/// Constructor.
|
||||
///
|
||||
/// \param[in] main_service See super class.
|
||||
/// \param[in] handler The service handler that actually handle the requests.
|
||||
DirectActorGrpcService(boost::asio::io_service &main_service,
|
||||
DirectActorHandler &service_handler)
|
||||
: GrpcService(main_service), service_handler_(service_handler){};
|
||||
|
||||
protected:
|
||||
grpc::Service &GetGrpcService() override { return service_; }
|
||||
|
||||
void InitServerCallFactories(
|
||||
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
|
||||
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
|
||||
*server_call_factories_and_concurrencies) override {
|
||||
// Initialize the Factory for `PushTask` requests.
|
||||
std::unique_ptr<ServerCallFactory> push_task_call_Factory(
|
||||
new ServerCallFactoryImpl<DirectActorService, DirectActorHandler, PushTaskRequest,
|
||||
PushTaskReply>(
|
||||
service_, &DirectActorService::AsyncService::RequestPushTask,
|
||||
service_handler_, &DirectActorHandler::HandlePushTask, cq, main_service_));
|
||||
server_call_factories_and_concurrencies->emplace_back(
|
||||
std::move(push_task_call_Factory), 100);
|
||||
|
||||
// Initialize the Factory for `DirectActorCallArgWaitComplete` requests.
|
||||
std::unique_ptr<ServerCallFactory> wait_complete_call_Factory(
|
||||
new ServerCallFactoryImpl<DirectActorService, DirectActorHandler,
|
||||
DirectActorCallArgWaitCompleteRequest,
|
||||
DirectActorCallArgWaitCompleteReply>(
|
||||
service_,
|
||||
&DirectActorService::AsyncService::RequestDirectActorCallArgWaitComplete,
|
||||
service_handler_, &DirectActorHandler::HandleDirectActorCallArgWaitComplete,
|
||||
cq, main_service_));
|
||||
server_call_factories_and_concurrencies->emplace_back(
|
||||
std::move(wait_complete_call_Factory), 100);
|
||||
}
|
||||
|
||||
private:
|
||||
/// The grpc async service object.
|
||||
DirectActorService::AsyncService service_;
|
||||
|
||||
/// The service handler that actually handle the requests.
|
||||
DirectActorHandler &service_handler_;
|
||||
};
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace ray
|
||||
|
||||
#endif
|
||||
@@ -1,21 +1,40 @@
|
||||
#ifndef RAY_RPC_WORKER_CLIENT_H
|
||||
#define RAY_RPC_WORKER_CLIENT_H
|
||||
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
|
||||
#include <grpcpp/grpcpp.h>
|
||||
#include "absl/base/thread_annotations.h"
|
||||
|
||||
#include "ray/common/status.h"
|
||||
#include "ray/rpc/client_call.h"
|
||||
#include "ray/util/logging.h"
|
||||
#include "src/ray/protobuf/worker.grpc.pb.h"
|
||||
#include "src/ray/protobuf/worker.pb.h"
|
||||
#include "src/ray/protobuf/core_worker.grpc.pb.h"
|
||||
#include "src/ray/protobuf/core_worker.pb.h"
|
||||
|
||||
namespace ray {
|
||||
namespace rpc {
|
||||
|
||||
/// The maximum number of requests in flight per client.
|
||||
const int64_t kMaxBytesInFlight = 16 * 1024 * 1024;
|
||||
|
||||
/// The base size in bytes per request.
|
||||
const int64_t kBaseRequestSize = 1024;
|
||||
|
||||
/// Get the estimated size in bytes of the given task.
|
||||
const static int64_t RequestSizeInBytes(const DirectActorAssignTaskRequest &request) {
|
||||
int64_t size = kBaseRequestSize;
|
||||
for (auto &arg : request.task_spec().args()) {
|
||||
size += arg.data().size();
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
/// Client used for communicating with a remote worker server.
|
||||
class WorkerTaskClient {
|
||||
class WorkerTaskClient : public std::enable_shared_from_this<WorkerTaskClient> {
|
||||
public:
|
||||
/// Constructor.
|
||||
///
|
||||
@@ -27,7 +46,7 @@ class WorkerTaskClient {
|
||||
: client_call_manager_(client_call_manager) {
|
||||
std::shared_ptr<grpc::Channel> channel = grpc::CreateChannel(
|
||||
address + ":" + std::to_string(port), grpc::InsecureChannelCredentials());
|
||||
stub_ = WorkerTaskService::NewStub(channel);
|
||||
stub_ = WorkerService::NewStub(channel);
|
||||
};
|
||||
|
||||
/// Assign a task to the work.
|
||||
@@ -37,19 +56,119 @@ class WorkerTaskClient {
|
||||
/// \return if the rpc call succeeds
|
||||
ray::Status AssignTask(const AssignTaskRequest &request,
|
||||
const ClientCallback<AssignTaskReply> &callback) {
|
||||
auto call = client_call_manager_
|
||||
.CreateCall<WorkerTaskService, AssignTaskRequest, AssignTaskReply>(
|
||||
*stub_, &WorkerTaskService::Stub::PrepareAsyncAssignTask, request,
|
||||
callback);
|
||||
auto call =
|
||||
client_call_manager_
|
||||
.CreateCall<WorkerService, AssignTaskRequest, AssignTaskReply>(
|
||||
*stub_, &WorkerService::Stub::PrepareAsyncAssignTask, request, callback);
|
||||
return call->GetStatus();
|
||||
}
|
||||
|
||||
/// Push a task.
|
||||
///
|
||||
/// \param[in] request The request message.
|
||||
/// \param[in] callback The callback function that handles reply.
|
||||
/// \return if the rpc call succeeds
|
||||
ray::Status DirectActorAssignTask(
|
||||
std::unique_ptr<DirectActorAssignTaskRequest> request,
|
||||
const ClientCallback<DirectActorAssignTaskReply> &callback) {
|
||||
request->set_sequence_number(request->task_spec().actor_task_spec().actor_counter());
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (request->task_spec().caller_id() != cur_caller_id_) {
|
||||
// We are running a new task, reset the seq no counter.
|
||||
max_finished_seq_no_ = -1;
|
||||
cur_caller_id_ = request->task_spec().caller_id();
|
||||
}
|
||||
send_queue_.push_back(std::make_pair(std::move(request), callback));
|
||||
}
|
||||
SendRequests();
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
/// Notify a wait has completed for direct actor call arguments.
|
||||
///
|
||||
/// \param[in] request The request message.
|
||||
/// \param[in] callback The callback function that handles reply.
|
||||
/// \return if the rpc call succeeds
|
||||
ray::Status DirectActorCallArgWaitComplete(
|
||||
const DirectActorCallArgWaitCompleteRequest &request,
|
||||
const ClientCallback<DirectActorCallArgWaitCompleteReply> &callback) {
|
||||
auto call =
|
||||
client_call_manager_
|
||||
.CreateCall<WorkerService, DirectActorCallArgWaitCompleteRequest,
|
||||
DirectActorCallArgWaitCompleteReply>(
|
||||
*stub_, &WorkerService::Stub::PrepareAsyncDirectActorCallArgWaitComplete,
|
||||
request, callback);
|
||||
return call->GetStatus();
|
||||
}
|
||||
|
||||
/// Send as many pending tasks as possible. This method is thread-safe.
|
||||
///
|
||||
/// The client will guarantee no more than kMaxBytesInFlight bytes of RPCs are being
|
||||
/// sent at once. This prevents the server scheduling queue from being overwhelmed.
|
||||
/// See direct_actor.proto for a description of the ordering protocol.
|
||||
void SendRequests() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
auto this_ptr = this->shared_from_this();
|
||||
|
||||
while (!send_queue_.empty() && rpc_bytes_in_flight_ < kMaxBytesInFlight) {
|
||||
auto pair = std::move(*send_queue_.begin());
|
||||
send_queue_.pop_front();
|
||||
|
||||
auto request = std::move(pair.first);
|
||||
auto callback = pair.second;
|
||||
int64_t task_size = RequestSizeInBytes(*request);
|
||||
int64_t seq_no = request->sequence_number();
|
||||
request->set_client_processed_up_to(max_finished_seq_no_);
|
||||
rpc_bytes_in_flight_ += task_size;
|
||||
|
||||
client_call_manager_.CreateCall<WorkerService, DirectActorAssignTaskRequest,
|
||||
DirectActorAssignTaskReply>(
|
||||
*stub_, &WorkerService::Stub::PrepareAsyncDirectActorAssignTask, *request,
|
||||
[this, this_ptr, seq_no, task_size, callback](
|
||||
Status status, const rpc::DirectActorAssignTaskReply &reply) {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (seq_no > max_finished_seq_no_) {
|
||||
max_finished_seq_no_ = seq_no;
|
||||
}
|
||||
rpc_bytes_in_flight_ -= task_size;
|
||||
RAY_CHECK(rpc_bytes_in_flight_ >= 0);
|
||||
}
|
||||
SendRequests();
|
||||
callback(status, reply);
|
||||
});
|
||||
}
|
||||
|
||||
if (!send_queue_.empty()) {
|
||||
RAY_LOG(DEBUG) << "client send queue size " << send_queue_.size();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
/// Protects against unsafe concurrent access from the callback thread.
|
||||
std::mutex mutex_;
|
||||
|
||||
/// The gRPC-generated stub.
|
||||
std::unique_ptr<WorkerTaskService::Stub> stub_;
|
||||
std::unique_ptr<WorkerService::Stub> stub_;
|
||||
|
||||
/// The `ClientCallManager` used for managing requests.
|
||||
ClientCallManager &client_call_manager_;
|
||||
|
||||
/// Queue of requests to send.
|
||||
std::deque<std::pair<std::unique_ptr<DirectActorAssignTaskRequest>,
|
||||
ClientCallback<DirectActorAssignTaskReply>>>
|
||||
send_queue_ GUARDED_BY(mutex_);
|
||||
|
||||
/// The number of bytes currently in flight.
|
||||
int64_t rpc_bytes_in_flight_ GUARDED_BY(mutex_) = 0;
|
||||
|
||||
/// The max sequence number we have processed responses for.
|
||||
int64_t max_finished_seq_no_ GUARDED_BY(mutex_) = -1;
|
||||
|
||||
/// The task id we are currently sending requests for. When this changes,
|
||||
/// the max finished seq no counter is reset.
|
||||
std::string cur_caller_id_;
|
||||
};
|
||||
|
||||
} // namespace rpc
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
#include "ray/rpc/worker/worker_server.h"
|
||||
#include "ray/core_worker/core_worker.h"
|
||||
|
||||
namespace ray {
|
||||
namespace rpc {
|
||||
|
||||
#define RAY_CORE_WORKER_RPC_HANDLER(HANDLER, CONCURRENCY) \
|
||||
std::unique_ptr<ServerCallFactory> HANDLER##_call_factory( \
|
||||
new ServerCallFactoryImpl<WorkerService, CoreWorker, HANDLER##Request, \
|
||||
HANDLER##Reply>( \
|
||||
service_, &WorkerService::AsyncService::Request##HANDLER, core_worker_, \
|
||||
&CoreWorker::Handle##HANDLER, cq, main_service_)); \
|
||||
server_call_factories_and_concurrencies->emplace_back( \
|
||||
std::move(HANDLER##_call_factory), CONCURRENCY);
|
||||
|
||||
WorkerGrpcService::WorkerGrpcService(boost::asio::io_service &main_service,
|
||||
CoreWorker &core_worker)
|
||||
: GrpcService(main_service), core_worker_(core_worker){};
|
||||
|
||||
void WorkerGrpcService::InitServerCallFactories(
|
||||
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
|
||||
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
|
||||
*server_call_factories_and_concurrencies) {
|
||||
RAY_CORE_WORKER_RPC_HANDLERS
|
||||
}
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace ray
|
||||
@@ -4,36 +4,23 @@
|
||||
#include "ray/rpc/grpc_server.h"
|
||||
#include "ray/rpc/server_call.h"
|
||||
|
||||
#include "src/ray/protobuf/worker.grpc.pb.h"
|
||||
#include "src/ray/protobuf/worker.pb.h"
|
||||
#include "src/ray/protobuf/core_worker.grpc.pb.h"
|
||||
#include "src/ray/protobuf/core_worker.pb.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
class CoreWorker;
|
||||
|
||||
namespace rpc {
|
||||
|
||||
/// Interface of the `WorkerService`, see `src/ray/protobuf/worker.proto`.
|
||||
class WorkerTaskHandler {
|
||||
public:
|
||||
/// Handle a `AssignTask` request.
|
||||
/// The implementation can handle this request asynchronously. When handling is done,
|
||||
/// the `send_reply_callback` should be called.
|
||||
///
|
||||
/// \param[in] request The request message.
|
||||
/// \param[out] reply The reply message.
|
||||
/// \param[in] send_reply_callback The callback to be called when the request is done.
|
||||
virtual void HandleAssignTask(const AssignTaskRequest &request, AssignTaskReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
};
|
||||
|
||||
/// The `GrpcServer` for `WorkerService`.
|
||||
class WorkerTaskGrpcService : public GrpcService {
|
||||
class WorkerGrpcService : public GrpcService {
|
||||
public:
|
||||
/// Constructor.
|
||||
///
|
||||
/// \param[in] main_service See super class.
|
||||
/// \param[in] handler The service handler that actually handle the requests.
|
||||
WorkerTaskGrpcService(boost::asio::io_service &main_service,
|
||||
WorkerTaskHandler &service_handler)
|
||||
: GrpcService(main_service), service_handler_(service_handler){};
|
||||
WorkerGrpcService(boost::asio::io_service &main_service, CoreWorker &core_worker);
|
||||
|
||||
protected:
|
||||
grpc::Service &GetGrpcService() override { return service_; }
|
||||
@@ -41,25 +28,14 @@ class WorkerTaskGrpcService : public GrpcService {
|
||||
void InitServerCallFactories(
|
||||
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
|
||||
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
|
||||
*server_call_factories_and_concurrencies) override {
|
||||
// Initialize the Factory for `AssignTask` requests.
|
||||
std::unique_ptr<ServerCallFactory> push_task_call_Factory(
|
||||
new ServerCallFactoryImpl<WorkerTaskService, WorkerTaskHandler, AssignTaskRequest,
|
||||
AssignTaskReply>(
|
||||
service_, &WorkerTaskService::AsyncService::RequestAssignTask,
|
||||
service_handler_, &WorkerTaskHandler::HandleAssignTask, cq, main_service_));
|
||||
|
||||
// Set `AssignTask`'s accept concurrency to 5.
|
||||
server_call_factories_and_concurrencies->emplace_back(
|
||||
std::move(push_task_call_Factory), 5);
|
||||
}
|
||||
*server_call_factories_and_concurrencies) override;
|
||||
|
||||
private:
|
||||
/// The grpc async service object.
|
||||
WorkerTaskService::AsyncService service_;
|
||||
WorkerService::AsyncService service_;
|
||||
|
||||
/// The service handler that actually handle the requests.
|
||||
WorkerTaskHandler &service_handler_;
|
||||
/// The core worker that actually handles the requests.
|
||||
CoreWorker &core_worker_;
|
||||
};
|
||||
|
||||
} // namespace rpc
|
||||
|
||||
Reference in New Issue
Block a user